gpjax 0.12.0__py3-none-any.whl → 0.13.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
@@ -23,7 +23,6 @@ from gpjax.kernels.computations import (
 )
 from gpjax.kernels.stationary.base import StationaryKernel
 from gpjax.kernels.stationary.utils import squared_distance
-from gpjax.parameters import PositiveReal
 from gpjax.typing import (
     Array,
     ScalarArray,
@@ -70,17 +69,15 @@ class RationalQuadratic(StationaryKernel):
             compute_engine: The computation engine that the kernel uses to compute the
                 covariance matrix.
         """
-        if isinstance(alpha, nnx.Variable):
-            self.alpha = alpha
-        else:
-            self.alpha = PositiveReal(alpha)
+        self.alpha = alpha

         super().__init__(active_dims, lengthscale, variance, n_dims, compute_engine)

     def __call__(self, x: Float[Array, " D"], y: Float[Array, " D"]) -> ScalarFloat:
         x = self.slice_input(x) / self.lengthscale.value
         y = self.slice_input(y) / self.lengthscale.value
-        K = self.variance.value * (
-            1 + 0.5 * squared_distance(x, y) / self.alpha.value
-        ) ** (-self.alpha.value)
+        alpha_val = self.alpha.value if hasattr(self.alpha, "value") else self.alpha
+        K = self.variance.value * (1 + 0.5 * squared_distance(x, y) / alpha_val) ** (
+            -alpha_val
+        )
         return K.squeeze()
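With this change, `RationalQuadratic` stores `alpha` as given and only unwraps `.value` when that attribute exists, so a raw float and a parameter object should both be accepted. A minimal sketch of the two call patterns, assuming the public `RationalQuadratic` and `PositiveReal` names behave as in earlier releases:

```python
import jax.numpy as jnp

from gpjax.kernels import RationalQuadratic
from gpjax.parameters import PositiveReal

x, y = jnp.array([0.5]), jnp.array([1.0])

# alpha passed as a plain float is stored as-is
k_float = RationalQuadratic(alpha=2.0)

# alpha passed as a parameter is unwrapped via .value inside __call__
k_param = RationalQuadratic(alpha=PositiveReal(2.0))

# both should evaluate to the same covariance value k(x, y)
print(k_float(x, y), k_param(x, y))
```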
gpjax/likelihoods.py CHANGED
@@ -29,7 +29,6 @@ from gpjax.integrators import (
 )
 from gpjax.parameters import (
     NonNegativeReal,
-    Static,
 )
 from gpjax.typing import (
     Array,
@@ -59,27 +58,27 @@ class AbstractLikelihood(nnx.Module):
         self.num_datapoints = num_datapoints
         self.integrator = integrator

-    def __call__(self, *args: tp.Any, **kwargs: tp.Any) -> npd.Distribution:
+    def __call__(
+        self, dist: tp.Union[npd.MultivariateNormal, GaussianDistribution]
+    ) -> npd.Distribution:
         r"""Evaluate the likelihood function at a given predictive distribution.

         Args:
-            *args (Any): Arguments to be passed to the likelihood's `predict` method.
-            **kwargs (Any): Keyword arguments to be passed to the likelihood's
-                `predict` method.
+            dist: The predictive distribution to evaluate the likelihood at.

         Returns:
             The predictive distribution.
         """
-        return self.predict(*args, **kwargs)
+        return self.predict(dist)

     @abc.abstractmethod
-    def predict(self, *args: tp.Any, **kwargs: tp.Any) -> npd.Distribution:
+    def predict(
+        self, dist: tp.Union[npd.MultivariateNormal, GaussianDistribution]
+    ) -> npd.Distribution:
         r"""Evaluate the likelihood function at a given predictive distribution.

         Args:
-            *args (Any): Arguments to be passed to the likelihood's `predict` method.
-            **kwargs (Any): Keyword arguments to be passed to the likelihood's
-                `predict` method.
+            dist: The predictive distribution to evaluate the likelihood at.

         Returns:
             npd.Distribution: The predictive distribution.
@@ -133,9 +132,7 @@ class Gaussian(AbstractLikelihood):
     def __init__(
         self,
         num_datapoints: int,
-        obs_stddev: tp.Union[
-            ScalarFloat, Float[Array, "#N"], NonNegativeReal, Static
-        ] = 1.0,
+        obs_stddev: tp.Union[ScalarFloat, Float[Array, "#N"], NonNegativeReal] = 1.0,
         integrator: AbstractIntegrator = AnalyticalGaussianIntegrator(),
     ):
         r"""Initializes the Gaussian likelihood.
@@ -148,7 +145,7 @@ class Gaussian(AbstractLikelihood):
                 likelihoods. Must be an instance of `AbstractIntegrator`. For the Gaussian likelihood, this defaults to
                 the `AnalyticalGaussianIntegrator`, as the expected log likelihood can be computed analytically.
         """
-        if not isinstance(obs_stddev, (NonNegativeReal, Static)):
+        if not isinstance(obs_stddev, NonNegativeReal):
             obs_stddev = NonNegativeReal(jnp.asarray(obs_stddev))
         self.obs_stddev = obs_stddev

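The likelihood call signature is now explicit: `__call__` and `predict` take a single predictive distribution rather than `*args`/`**kwargs`. A sketch of the calling convention, assuming the standard GPJax regression API (`gpx.Dataset`, `gpx.gps.Prior`, `posterior.predict`) is otherwise unchanged:

```python
import jax.numpy as jnp

import gpjax as gpx

X = jnp.linspace(0.0, 1.0, 20).reshape(-1, 1)
y = jnp.sin(3.0 * X)
D = gpx.Dataset(X=X, y=y)

prior = gpx.gps.Prior(mean_function=gpx.mean_functions.Zero(), kernel=gpx.kernels.RBF())
likelihood = gpx.likelihoods.Gaussian(num_datapoints=D.n)
posterior = prior * likelihood

# latent (noise-free) predictive distribution at some test points
latent_dist = posterior.predict(jnp.linspace(0.0, 1.0, 5).reshape(-1, 1), train_data=D)

# 0.13.0: the likelihood is called with the distribution itself, not *args/**kwargs
predictive_dist = likelihood(latent_dist)
```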
gpjax/linalg/utils.py CHANGED
@@ -1,5 +1,8 @@
 """Utility functions for the linear algebra module."""

+import jax.numpy as jnp
+from jaxtyping import Array
+
 from gpjax.linalg.operators import LinearOperator


@@ -31,3 +34,32 @@ def psd(A: LinearOperator) -> LinearOperator:
         A.annotations = set()
     A.annotations.add(PSD)
     return A
+
+
+def add_jitter(matrix: Array, jitter: float | Array = 1e-6) -> Array:
+    """Add jitter to the diagonal of a matrix for numerical stability.
+
+    This function adds a small positive value (jitter) to the diagonal elements
+    of a square matrix to improve numerical stability, particularly for
+    Cholesky decompositions and matrix inversions.
+
+    Args:
+        matrix: A square matrix to which jitter will be added.
+        jitter: The jitter value to add to the diagonal. Defaults to 1e-6.
+
+    Returns:
+        The matrix with jitter added to its diagonal.
+
+    Examples:
+        >>> import jax.numpy as jnp
+        >>> from gpjax.linalg.utils import add_jitter
+        >>> matrix = jnp.array([[1.0, 0.5], [0.5, 1.0]])
+        >>> jittered_matrix = add_jitter(matrix, jitter=0.01)
+    """
+    if matrix.ndim != 2:
+        raise ValueError(f"Expected 2D matrix, got {matrix.ndim}D array")
+
+    if matrix.shape[0] != matrix.shape[1]:
+        raise ValueError(f"Expected square matrix, got shape {matrix.shape}")
+
+    return matrix + jnp.eye(matrix.shape[0]) * jitter
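The new `add_jitter` helper centralises the `K + εI` stabilisation that was previously written inline throughout the codebase (see the `gpjax/objectives.py` hunks below). A small sketch of the equivalence:

```python
import jax.numpy as jnp

from gpjax.linalg.utils import add_jitter

K = jnp.array([[1.0, 0.99], [0.99, 1.0]])
jitter = 1e-6

# pre-0.13.0 inline pattern
K_old = K + jnp.eye(K.shape[0]) * jitter

# 0.13.0 helper
K_new = add_jitter(K, jitter)

assert jnp.allclose(K_old, K_new)

# the jittered matrix can then be Cholesky-factorised safely
L = jnp.linalg.cholesky(K_new)
```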
gpjax/mean_functions.py CHANGED
@@ -27,8 +27,6 @@ from jaxtyping import (

 from gpjax.parameters import (
     Parameter,
-    Real,
-    Static,
 )
 from gpjax.typing import (
     Array,
@@ -132,12 +130,12 @@ class Constant(AbstractMeanFunction):

     def __init__(
         self,
-        constant: tp.Union[ScalarFloat, Float[Array, " O"], Parameter, Static] = 0.0,
+        constant: tp.Union[ScalarFloat, Float[Array, " O"], Parameter] = 0.0,
     ):
-        if isinstance(constant, Parameter) or isinstance(constant, Static):
+        if isinstance(constant, Parameter):
             self.constant = constant
         else:
-            self.constant = Real(jnp.array(constant))
+            self.constant = jnp.array(constant)

     def __call__(self, x: Num[Array, "N D"]) -> Float[Array, "N O"]:
         r"""Evaluate the mean function at the given points.
@@ -148,7 +146,10 @@ class Constant(AbstractMeanFunction):
         Returns:
             Float[Array, "1"]: The evaluated mean function.
         """
-        return jnp.ones((x.shape[0], 1)) * self.constant.value
+        if isinstance(self.constant, Parameter):
+            return jnp.ones((x.shape[0], 1)) * self.constant.value
+        else:
+            return jnp.ones((x.shape[0], 1)) * self.constant


@@ -160,7 +161,7 @@ class Zero(Constant):
     """

     def __init__(self):
-        super().__init__(constant=Static(jnp.array(0.0)))
+        super().__init__(constant=0.0)


 class CombinationMeanFunction(AbstractMeanFunction):
@@ -175,7 +176,7 @@ class CombinationMeanFunction(AbstractMeanFunction):
         super().__init__(**kwargs)

         # Add means to a list, flattening out instances of this class therein, as in GPFlow kernels.
-        items_list: list[AbstractMeanFunction] = []
+        items_list: list[AbstractMeanFunction] = nnx.List([])

         for item in means:
             if not isinstance(item, AbstractMeanFunction):
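After this change, `Constant` only keeps its argument wrapped when given a `Parameter`; a raw float or array is stored as a plain `jnp.array`, and `Zero` is simply `Constant(constant=0.0)`. A hedged sketch of the two construction styles; whether the plain-array form is excluded from the trainable state is an assumption based on it no longer being a parameter object:

```python
import jax.numpy as jnp

from gpjax.mean_functions import Constant, Zero
from gpjax.parameters import Real

X = jnp.linspace(0.0, 1.0, 4).reshape(-1, 1)

# raw value: stored as a plain array
m_fixed = Constant(constant=2.0)

# Parameter value: stored as-is, so it remains part of the model's parameters
m_trainable = Constant(constant=Real(jnp.array(2.0)))

m_zero = Zero()  # equivalent to Constant(constant=0.0)

print(m_fixed(X).shape, m_trainable(X).shape, m_zero(X).squeeze())
```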
gpjax/objectives.py CHANGED
@@ -20,6 +20,7 @@ from gpjax.linalg import (
     psd,
     solve,
 )
+from gpjax.linalg.utils import add_jitter
 from gpjax.typing import (
     Array,
     ScalarFloat,
@@ -97,7 +98,7 @@ def conjugate_mll(posterior: ConjugatePosterior, data: Dataset) -> ScalarFloat:

     # Σ = (Kxx + Io²) = LLᵀ
     Kxx = posterior.prior.kernel.gram(x)
-    Kxx_dense = Kxx.to_dense() + jnp.eye(Kxx.shape[0]) * posterior.prior.jitter
+    Kxx_dense = add_jitter(Kxx.to_dense(), posterior.prior.jitter)
     Sigma_dense = Kxx_dense + jnp.eye(Kxx.shape[0]) * obs_noise
     Sigma = psd(Dense(Sigma_dense))

@@ -213,7 +214,7 @@ def log_posterior_density(

     # Gram matrix
     Kxx = posterior.prior.kernel.gram(x)
-    Kxx_dense = Kxx.to_dense() + jnp.eye(Kxx.shape[0]) * posterior.prior.jitter
+    Kxx_dense = add_jitter(Kxx.to_dense(), posterior.prior.jitter)
     Kxx = psd(Dense(Kxx_dense))
     Lx = lower_cholesky(Kxx)

@@ -349,7 +350,7 @@ def collapsed_elbo(variational_family: VF, data: Dataset) -> ScalarFloat:
     noise = variational_family.posterior.likelihood.obs_stddev.value**2
     z = variational_family.inducing_inputs.value
     Kzz = kernel.gram(z)
-    Kzz_dense = Kzz.to_dense() + jnp.eye(Kzz.shape[0]) * variational_family.jitter
+    Kzz_dense = add_jitter(Kzz.to_dense(), variational_family.jitter)
     Kzz = psd(Dense(Kzz_dense))
     Kzx = kernel.cross_covariance(z, x)
     Kxx_diag = vmap(kernel, in_axes=(0, 0))(x, x)
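These hunks only swap the inline jitter expression for the new helper, so the objectives' values are unchanged. A short sketch of evaluating `conjugate_mll` after the refactor, assuming the usual regression setup from the GPJax docs:

```python
import jax.numpy as jnp

import gpjax as gpx
from gpjax.objectives import conjugate_mll

X = jnp.linspace(0.0, 1.0, 20).reshape(-1, 1)
y = jnp.sin(3.0 * X)
D = gpx.Dataset(X=X, y=y)

prior = gpx.gps.Prior(mean_function=gpx.mean_functions.Zero(), kernel=gpx.kernels.RBF())
posterior = prior * gpx.likelihoods.Gaussian(num_datapoints=D.n)

# internally, Kxx now passes through add_jitter rather than an inline K + eps * I
mll = conjugate_mll(posterior, D)
```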
gpjax/parameters.py CHANGED
@@ -122,16 +122,6 @@ class SigmoidBounded(Parameter[T]):
         )


-class Static(nnx.Variable[T]):
-    """Static parameter that is not trainable."""
-
-    def __init__(self, value: T, tag: ParameterTag = "static", **kwargs):
-        _check_is_arraylike(value)
-
-        super().__init__(value=jnp.asarray(value), tag=tag, **kwargs)
-        self._tag = tag
-
-
 class LowerTriangular(Parameter[T]):
     """Parameter that is a lower triangular matrix."""

@@ -40,11 +40,11 @@ from gpjax.linalg import (
     psd,
     solve,
 )
+from gpjax.linalg.utils import add_jitter
 from gpjax.mean_functions import AbstractMeanFunction
 from gpjax.parameters import (
     LowerTriangular,
     Real,
-    Static,
 )
 from gpjax.typing import (
     Array,
@@ -110,11 +110,10 @@ class AbstractVariationalGaussian(AbstractVariationalFamily[L]):
         inducing_inputs: tp.Union[
             Float[Array, "N D"],
             Real,
-            Static,
         ],
         jitter: ScalarFloat = 1e-6,
     ):
-        if not isinstance(inducing_inputs, (Real, Static)):
+        if not isinstance(inducing_inputs, Real):
             inducing_inputs = Real(inducing_inputs)

         self.inducing_inputs = inducing_inputs
@@ -177,25 +176,31 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):
         approximation and the GP prior.
         """
         # Unpack variational parameters
-        mu = self.variational_mean.value
-        sqrt = self.variational_root_covariance.value
-        z = self.inducing_inputs.value
+        variational_mean = self.variational_mean.value
+        variational_sqrt = self.variational_root_covariance.value
+        inducing_inputs = self.inducing_inputs.value

         # Unpack mean function and kernel
         mean_function = self.posterior.prior.mean_function
         kernel = self.posterior.prior.kernel

-        muz = mean_function(z)
-        Kzz = kernel.gram(z)
-        Kzz = psd(Dense(Kzz.to_dense() + jnp.eye(Kzz.shape[0]) * self.jitter))
+        inducing_mean = mean_function(inducing_inputs)
+        Kzz = kernel.gram(inducing_inputs)
+        Kzz = psd(Dense(add_jitter(Kzz.to_dense(), self.jitter)))

-        sqrt = Triangular(sqrt)
-        S = sqrt @ sqrt.T
+        variational_sqrt_triangular = Triangular(variational_sqrt)
+        variational_covariance = (
+            variational_sqrt_triangular @ variational_sqrt_triangular.T
+        )

-        qu = GaussianDistribution(loc=jnp.atleast_1d(mu.squeeze()), scale=S)
-        pu = GaussianDistribution(loc=jnp.atleast_1d(muz.squeeze()), scale=Kzz)
+        q_inducing = GaussianDistribution(
+            loc=jnp.atleast_1d(variational_mean.squeeze()), scale=variational_covariance
+        )
+        p_inducing = GaussianDistribution(
+            loc=jnp.atleast_1d(inducing_mean.squeeze()), scale=Kzz
+        )

-        return qu.kl_divergence(pu)
+        return q_inducing.kl_divergence(p_inducing)

     def predict(self, test_inputs: Float[Array, "N D"]) -> GaussianDistribution:
         r"""Compute the predictive distribution of the GP at the test inputs t.
@@ -215,26 +220,26 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):
         the test inputs.
         """
         # Unpack variational parameters
-        mu = self.variational_mean.value
-        sqrt = self.variational_root_covariance.value
-        z = self.inducing_inputs.value
+        variational_mean = self.variational_mean.value
+        variational_sqrt = self.variational_root_covariance.value
+        inducing_inputs = self.inducing_inputs.value

         # Unpack mean function and kernel
         mean_function = self.posterior.prior.mean_function
         kernel = self.posterior.prior.kernel

-        Kzz = kernel.gram(z)
-        Kzz_dense = Kzz.to_dense() + jnp.eye(Kzz.shape[0]) * self.jitter
+        Kzz = kernel.gram(inducing_inputs)
+        Kzz_dense = add_jitter(Kzz.to_dense(), self.jitter)
         Kzz = psd(Dense(Kzz_dense))
         Lz = lower_cholesky(Kzz)
-        muz = mean_function(z)
+        inducing_mean = mean_function(inducing_inputs)

         # Unpack test inputs
-        t = test_inputs
+        test_points = test_inputs

-        Ktt = kernel.gram(t)
-        Kzt = kernel.cross_covariance(z, t)
-        mut = mean_function(t)
+        Ktt = kernel.gram(test_points)
+        Kzt = kernel.cross_covariance(inducing_inputs, test_points)
+        test_mean = mean_function(test_points)

         # Lz⁻¹ Kzt
         Lz_inv_Kzt = solve(Lz, Kzt)
@@ -243,10 +248,10 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):
         Kzz_inv_Kzt = solve(Lz.T, Lz_inv_Kzt)

         # Ktz Kzz⁻¹ sqrt
-        Ktz_Kzz_inv_sqrt = jnp.matmul(Kzz_inv_Kzt.T, sqrt)
+        Ktz_Kzz_inv_sqrt = jnp.matmul(Kzz_inv_Kzt.T, variational_sqrt)

         # μt + Ktz Kzz⁻¹ (μ - μz)
-        mean = mut + jnp.matmul(Kzz_inv_Kzt.T, mu - muz)
+        mean = test_mean + jnp.matmul(Kzz_inv_Kzt.T, variational_mean - inducing_mean)

         # Ktt - Ktz Kzz⁻¹ Kzt + Ktz Kzz⁻¹ S Kzz⁻¹ Kzt [recall S = sqrt sqrtᵀ]
         covariance = (
@@ -254,7 +259,10 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):
             - jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt)
             + jnp.matmul(Ktz_Kzz_inv_sqrt, Ktz_Kzz_inv_sqrt.T)
         )
-        covariance += jnp.eye(covariance.shape[0]) * self.jitter
+        if hasattr(covariance, "to_dense"):
+            covariance = covariance.to_dense()
+        covariance = add_jitter(covariance, self.jitter)
+        covariance = Dense(covariance)

         return GaussianDistribution(
             loc=jnp.atleast_1d(mean.squeeze()), scale=covariance
@@ -329,7 +337,7 @@ class WhitenedVariationalGaussian(VariationalGaussian[L]):
         kernel = self.posterior.prior.kernel

         Kzz = kernel.gram(z)
-        Kzz_dense = Kzz.to_dense() + jnp.eye(Kzz.shape[0]) * self.jitter
+        Kzz_dense = add_jitter(Kzz.to_dense(), self.jitter)
         Kzz = psd(Dense(Kzz_dense))
         Lz = lower_cholesky(Kzz)

@@ -355,7 +363,10 @@ class WhitenedVariationalGaussian(VariationalGaussian[L]):
             - jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt)
             + jnp.matmul(Ktz_Lz_invT_sqrt, Ktz_Lz_invT_sqrt.T)
         )
-        covariance += jnp.eye(covariance.shape[0]) * self.jitter
+        if hasattr(covariance, "to_dense"):
+            covariance = covariance.to_dense()
+        covariance = add_jitter(covariance, self.jitter)
+        covariance = Dense(covariance)

         return GaussianDistribution(
             loc=jnp.atleast_1d(mean.squeeze()), scale=covariance
@@ -390,8 +401,8 @@ class NaturalVariationalGaussian(AbstractVariationalGaussian[L]):
         if natural_matrix is None:
             natural_matrix = -0.5 * jnp.eye(self.num_inducing)

-        self.natural_vector = Static(natural_vector)
-        self.natural_matrix = Static(natural_matrix)
+        self.natural_vector = Real(natural_vector)
+        self.natural_matrix = Real(natural_matrix)

     def prior_kl(self) -> ScalarFloat:
         r"""Compute the KL-divergence between our current variational approximation
@@ -422,7 +433,7 @@ class NaturalVariationalGaussian(AbstractVariationalGaussian[L]):

         # S⁻¹ = -2θ₂
         S_inv = -2 * natural_matrix
-        S_inv += jnp.eye(m) * self.jitter
+        S_inv = add_jitter(S_inv, self.jitter)

         # Compute L⁻¹, where LLᵀ = S, via a trick found in the NumPyro source code and https://nbviewer.org/gist/fehiepsi/5ef8e09e61604f10607380467eb82006#Precision-to-scale_tril:
         sqrt_inv = jnp.swapaxes(
@@ -441,7 +452,7 @@ class NaturalVariationalGaussian(AbstractVariationalGaussian[L]):

         muz = mean_function(z)
         Kzz = kernel.gram(z)
-        Kzz_dense = Kzz.to_dense() + jnp.eye(Kzz.shape[0]) * self.jitter
+        Kzz_dense = add_jitter(Kzz.to_dense(), self.jitter)
         Kzz = psd(Dense(Kzz_dense))
         qu = GaussianDistribution(loc=jnp.atleast_1d(mu.squeeze()), scale=S)
@@ -476,7 +487,7 @@ class NaturalVariationalGaussian(AbstractVariationalGaussian[L]):

         # S⁻¹ = -2θ₂
         S_inv = -2 * natural_matrix
-        S_inv += jnp.eye(m) * self.jitter
+        S_inv = add_jitter(S_inv, self.jitter)

         # Compute L⁻¹, where LLᵀ = S, via a trick found in the NumPyro source code and https://nbviewer.org/gist/fehiepsi/5ef8e09e61604f10607380467eb82006#Precision-to-scale_tril:
         sqrt_inv = jnp.swapaxes(
@@ -493,7 +504,7 @@ class NaturalVariationalGaussian(AbstractVariationalGaussian[L]):
         mu = jnp.matmul(S, natural_vector)

         Kzz = kernel.gram(z)
-        Kzz_dense = Kzz.to_dense() + jnp.eye(Kzz.shape[0]) * self.jitter
+        Kzz_dense = add_jitter(Kzz.to_dense(), self.jitter)
         Kzz = psd(Dense(Kzz_dense))
         Lz = lower_cholesky(Kzz)
         muz = mean_function(z)
@@ -520,7 +531,10 @@ class NaturalVariationalGaussian(AbstractVariationalGaussian[L]):
             - jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt)
             + jnp.matmul(Ktz_Kzz_inv_L, Ktz_Kzz_inv_L.T)
         )
-        covariance += jnp.eye(covariance.shape[0]) * self.jitter
+        if hasattr(covariance, "to_dense"):
+            covariance = covariance.to_dense()
+        covariance = add_jitter(covariance, self.jitter)
+        covariance = Dense(covariance)

         return GaussianDistribution(
             loc=jnp.atleast_1d(mean.squeeze()), scale=covariance
@@ -556,8 +570,8 @@ class ExpectationVariationalGaussian(AbstractVariationalGaussian[L]):
         if expectation_matrix is None:
             expectation_matrix = jnp.eye(self.num_inducing)

-        self.expectation_vector = Static(expectation_vector)
-        self.expectation_matrix = Static(expectation_matrix)
+        self.expectation_vector = Real(expectation_vector)
+        self.expectation_matrix = Real(expectation_matrix)

     def prior_kl(self) -> ScalarFloat:
         r"""Evaluate the prior KL-divergence.
@@ -595,12 +609,12 @@ class ExpectationVariationalGaussian(AbstractVariationalGaussian[L]):
         # S = η₂ - η₁ η₁ᵀ
         S = expectation_matrix - jnp.outer(mu, mu)
         S = psd(Dense(S))
-        S_dense = S.to_dense() + jnp.eye(S.shape[0]) * self.jitter
+        S_dense = add_jitter(S.to_dense(), self.jitter)
         S = psd(Dense(S_dense))

         muz = mean_function(z)
         Kzz = kernel.gram(z)
-        Kzz_dense = Kzz.to_dense() + jnp.eye(Kzz.shape[0]) * self.jitter
+        Kzz_dense = add_jitter(Kzz.to_dense(), self.jitter)
         Kzz = psd(Dense(Kzz_dense))

         qu = GaussianDistribution(loc=jnp.atleast_1d(mu.squeeze()), scale=S)
@@ -640,14 +654,14 @@ class ExpectationVariationalGaussian(AbstractVariationalGaussian[L]):

         # S = η₂ - η₁ η₁ᵀ
         S = expectation_matrix - jnp.matmul(mu, mu.T)
-        S = Dense(S + jnp.eye(S.shape[0]) * self.jitter)
+        S = Dense(add_jitter(S, self.jitter))
         S = psd(S)

         # S = sqrt sqrtᵀ
         sqrt = lower_cholesky(S)

         Kzz = kernel.gram(z)
-        Kzz_dense = Kzz.to_dense() + jnp.eye(Kzz.shape[0]) * self.jitter
+        Kzz_dense = add_jitter(Kzz.to_dense(), self.jitter)
         Kzz = psd(Dense(Kzz_dense))
         Lz = lower_cholesky(Kzz)
         muz = mean_function(z)
@@ -677,7 +691,10 @@ class ExpectationVariationalGaussian(AbstractVariationalGaussian[L]):
             - jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt)
             + jnp.matmul(Ktz_Kzz_inv_sqrt, Ktz_Kzz_inv_sqrt.T)
         )
-        covariance += jnp.eye(covariance.shape[0]) * self.jitter
+        if hasattr(covariance, "to_dense"):
+            covariance = covariance.to_dense()
+        covariance = add_jitter(covariance, self.jitter)
+        covariance = Dense(covariance)

         return GaussianDistribution(
             loc=jnp.atleast_1d(mean.squeeze()), scale=covariance
@@ -734,7 +751,7 @@ class CollapsedVariationalGaussian(AbstractVariationalGaussian[GL]):

         Kzx = kernel.cross_covariance(z, x)
         Kzz = kernel.gram(z)
-        Kzz_dense = Kzz.to_dense() + jnp.eye(Kzz.shape[0]) * self.jitter
+        Kzz_dense = add_jitter(Kzz.to_dense(), self.jitter)
         Kzz = psd(Dense(Kzz_dense))

         # Lz Lzᵀ = Kzz
@@ -780,7 +797,10 @@ class CollapsedVariationalGaussian(AbstractVariationalGaussian[GL]):
             - jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt)
             + jnp.matmul(L_inv_Lz_inv_Kzt.T, L_inv_Lz_inv_Kzt)
         )
-        covariance += jnp.eye(covariance.shape[0]) * self.jitter
+        if hasattr(covariance, "to_dense"):
+            covariance = covariance.to_dense()
+        covariance = add_jitter(covariance, self.jitter)
+        covariance = Dense(covariance)

         return GaussianDistribution(
             loc=jnp.atleast_1d(mean.squeeze()), scale=covariance
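The `Static` parameter class is gone from `gpjax.parameters`, and its uses across the library are replaced either with `Real` (natural and expectation parameters, inducing inputs) or with plain arrays (the constant mean). A hedged migration sketch for downstream code that imported `Static`; the choice between the two replacements below mirrors the patterns in this diff and is not an official migration guide:

```python
import jax.numpy as jnp

from gpjax.parameters import Real

value = jnp.array([1.0, 2.0])

# pre-0.13.0 (no longer available):
#   from gpjax.parameters import Static
#   field = Static(value)

# option 1: a trainable Real parameter, as this diff does for natural/expectation
# parameters and inducing inputs
field_trainable = Real(value)

# option 2: a plain array, as this diff does for the constant mean function,
# when the value should not be treated as a parameter
field_fixed = jnp.asarray(value)
```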
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: gpjax
-Version: 0.12.0
+Version: 0.13.0
 Summary: Gaussian processes in JAX.
 Project-URL: Documentation, https://docs.jaxgaussianprocesses.com/
 Project-URL: Issues, https://github.com/JaxGaussianProcesses/GPJax/issues
@@ -11,15 +11,14 @@ License-File: LICENSE.txt
 Keywords: gaussian-processes jax machine-learning bayesian
 Classifier: Development Status :: 4 - Beta
 Classifier: Programming Language :: Python
-Classifier: Programming Language :: Python :: 3.10
 Classifier: Programming Language :: Python :: 3.11
 Classifier: Programming Language :: Python :: 3.12
 Classifier: Programming Language :: Python :: 3.13
 Classifier: Programming Language :: Python :: Implementation :: CPython
 Classifier: Programming Language :: Python :: Implementation :: PyPy
-Requires-Python: <=3.13,>=3.10
+Requires-Python: >=3.11
 Requires-Dist: beartype>0.16.1
-Requires-Dist: flax>=0.10.0
+Requires-Dist: flax>=0.12.0
 Requires-Dist: jax>=0.5.0
 Requires-Dist: jaxlib>=0.5.0
 Requires-Dist: jaxtyping>0.2.10
@@ -60,7 +59,7 @@ Requires-Dist: mkdocs-jupyter>=0.24.3; extra == 'docs'
 Requires-Dist: mkdocs-literate-nav>=0.6.0; extra == 'docs'
 Requires-Dist: mkdocs-material>=9.5.12; extra == 'docs'
 Requires-Dist: mkdocs>=1.5.3; extra == 'docs'
-Requires-Dist: mkdocstrings[python]<0.28.0; extra == 'docs'
+Requires-Dist: mkdocstrings[python]<0.31.0; extra == 'docs'
 Requires-Dist: nbconvert>=7.16.2; extra == 'docs'
 Requires-Dist: networkx>=3.0; extra == 'docs'
 Requires-Dist: pandas>=1.5.3; extra == 'docs'
@@ -80,6 +79,7 @@ Description-Content-Type: text/markdown
 [![CodeFactor](https://www.codefactor.io/repository/github/jaxgaussianprocesses/gpjax/badge)](https://www.codefactor.io/repository/github/jaxgaussianprocesses/gpjax)
 [![Netlify Status](https://api.netlify.com/api/v1/badges/d3950e6f-321f-4508-9e52-426b5dae2715/deploy-status)](https://app.netlify.com/sites/endearing-crepe-c2d5fe/deploys)
 [![PyPI version](https://badge.fury.io/py/GPJax.svg)](https://badge.fury.io/py/GPJax)
+[![Conda Version](https://img.shields.io/conda/vn/conda-forge/gpjax.svg)](https://anaconda.org/conda-forge/gpjax)
 [![DOI](https://joss.theoj.org/papers/10.21105/joss.04455/status.svg)](https://doi.org/10.21105/joss.04455)
 [![Downloads](https://pepy.tech/badge/gpjax)](https://pepy.tech/project/gpjax)
 [![Slack Invite](https://img.shields.io/badge/Slack_Invite--blue?style=social&logo=slack)](https://join.slack.com/t/gpjax/shared_invite/zt-3cesiykcx-nzajjRdnV3ohw7~~eMlCYA)
@@ -126,18 +126,9 @@ Channel](https://join.slack.com/t/gpjax/shared_invite/zt-3cesiykcx-nzajjRdnV3ohw
 where we can discuss the development of GPJax and broader support for Gaussian
 process modelling.

-
-## Governance
-
-GPJax was founded by [Thomas Pinder](https://github.com/thomaspinder). Today, the
-project's gardeners are [daniel-dodd@](https://github.com/daniel-dodd),
-[henrymoss@](https://github.com/henrymoss), [st--@](https://github.com/st--), and
-[thomaspinder@](https://github.com/thomaspinder), listed in alphabetical order. The full
-governance structure of GPJax is detailed [here](docs/GOVERNANCE.md). We appreciate all
-[the contributors to
-GPJax](https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors) who have
-helped to shape GPJax into the package it is today.
-
+We appreciate all [the contributors to
+GPJax](https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors) who have helped to shape
+GPJax into the package it is today.

 # Supported methods and interfaces

@@ -183,13 +174,21 @@ jupytext --to py:percent example.ipynb

 ## Stable version

-The latest stable version of GPJax can be installed via
-pip:
+The latest stable version of GPJax can be installed from [PyPI](https://pypi.org/project/gpjax/):

 ```bash
 pip install gpjax
 ```

+or from [conda-forge](https://github.com/conda-forge/gpjax-feedstock):
+
+```bash
+# with Pixi
+pixi add gpjax
+# or with conda
+conda install --channel conda-forge gpjax
+```
+
 > **Note**
 >
 > We recommend you check your installation version:
@@ -208,7 +207,7 @@ pip install gpjax
 >
 > We advise you create virtual environment before installing:
 > ```
-> conda create -n gpjax_experimental python=3.10.0
+> conda create -n gpjax_experimental python=3.11.0
 > conda activate gpjax_experimental
 > ```

@@ -218,13 +217,14 @@ configuration in development mode.
 ```bash
 git clone https://github.com/JaxGaussianProcesses/GPJax.git
 cd GPJax
+uv venv
 uv sync --extra dev
 ```

 > We recommend you check your installation passes the supplied unit tests:
 >
 > ```python
-> uv run pytest --beartype-packages='gpjax'
+> uv run poe all-tests
 > ```

 # Citing GPJax