gpjax 0.11.2__py3-none-any.whl → 0.12.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
gpjax/__init__.py CHANGED
@@ -40,10 +40,9 @@ __license__ = "MIT"
  __description__ = "Gaussian processes in JAX and Flax"
  __url__ = "https://github.com/JaxGaussianProcesses/GPJax"
  __contributors__ = "https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors"
- __version__ = "0.11.2"
+ __version__ = "0.12.2"
 
  __all__ = [
- "base",
  "gps",
  "integrators",
  "kernels",
@@ -55,8 +54,6 @@ __all__ = [
  "Dataset",
  "cite",
  "fit",
- "Module",
- "param_field",
  "fit_lbfgs",
  "fit_scipy",
  ]
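The `Module` and `param_field` exports (and the `base` submodule) are gone from the public API: models are now plain `flax.nnx.Module` subclasses whose trainable leaves are `gpjax.parameters` objects, as illustrated by the updated `fit` docstring later in this diff. A minimal, hedged sketch of the replacement imports:

# 0.11.x style (removed in 0.12.x):
#   from gpjax import Module, param_field
# 0.12.x style, as used throughout this diff:
from flax import nnx
from gpjax.parameters import Parameter, PositiveReal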
gpjax/distributions.py CHANGED
@@ -17,11 +17,6 @@
  from beartype.typing import (
  Optional,
  )
- import cola
- from cola.linalg.decompositions import Cholesky
- from cola.ops import (
- LinearOperator,
- )
  from jax import vmap
  import jax.numpy as jnp
  import jax.random as jr
@@ -30,7 +25,14 @@ from numpyro.distributions import constraints
  from numpyro.distributions.distribution import Distribution
  from numpyro.distributions.util import is_prng_key
 
- from gpjax.lower_cholesky import lower_cholesky
+ from gpjax.linalg.operations import (
+ diag,
+ logdet,
+ lower_cholesky,
+ solve,
+ )
+ from gpjax.linalg.operators import LinearOperator
+ from gpjax.linalg.utils import psd
  from gpjax.typing import (
  Array,
  ScalarFloat,
@@ -47,7 +49,7 @@ class GaussianDistribution(Distribution):
  validate_args=None,
  ):
  self.loc = loc
- self.scale = cola.PSD(scale)
+ self.scale = psd(scale)
  batch_shape = ()
  event_shape = jnp.shape(self.loc)
  super().__init__(batch_shape, event_shape, validate_args=validate_args)
@@ -76,13 +78,12 @@ class GaussianDistribution(Distribution):
  @property
  def variance(self) -> Float[Array, " N"]:
  r"""Calculates the variance."""
- return cola.diag(self.scale)
+ return diag(self.scale)
 
  def entropy(self) -> ScalarFloat:
  r"""Calculates the entropy of the distribution."""
  return 0.5 * (
- self.event_shape[0] * (1.0 + jnp.log(2.0 * jnp.pi))
- + cola.logdet(self.scale, Cholesky(), Cholesky())
+ self.event_shape[0] * (1.0 + jnp.log(2.0 * jnp.pi)) + logdet(self.scale)
  )
 
  def median(self) -> Float[Array, " N"]:
@@ -104,7 +105,7 @@ class GaussianDistribution(Distribution):
 
  def stddev(self) -> Float[Array, " N"]:
  r"""Calculates the standard deviation."""
- return jnp.sqrt(cola.diag(self.scale))
+ return jnp.sqrt(diag(self.scale))
 
  # @property
  # def event_shape(self) -> Tuple:
@@ -129,9 +130,7 @@ class GaussianDistribution(Distribution):
 
  # compute the pdf, -1/2[ n log(2π) + log|Σ| + (y - µ)ᵀΣ⁻¹(y - µ) ]
  return -0.5 * (
- n * jnp.log(2.0 * jnp.pi)
- + cola.logdet(sigma, Cholesky(), Cholesky())
- + diff.T @ cola.solve(sigma, diff, Cholesky())
+ n * jnp.log(2.0 * jnp.pi) + logdet(sigma) + diff.T @ solve(sigma, diff)
  )
 
  # def _sample_n(self, key: KeyArray, n: int) -> Float[Array, "n N"]:
@@ -219,53 +218,14 @@ def _kl_divergence(q: GaussianDistribution, p: GaussianDistribution) -> ScalarFl
 
  # trace term, tr[Σp⁻¹ Σq] = tr[(LpLpᵀ)⁻¹(LqLqᵀ)] = tr[(Lp⁻¹Lq)(Lp⁻¹Lq)ᵀ] = (fr[LqLp⁻¹])²
  trace = _frobenius_norm_squared(
- cola.solve(sqrt_p, sqrt_q.to_dense(), Cholesky())
+ solve(sqrt_p, sqrt_q.to_dense())
  ) # TODO: Not most efficient, given the `to_dense()` call (e.g., consider diagonal p and q). Need to abstract solving linear operator against another linear operator.
 
  # Mahalanobis term, (μp - μq)ᵀ Σp⁻¹ (μp - μq) = tr [(μp - μq)ᵀ [LpLpᵀ]⁻¹ (μp - μq)] = (fr[Lp⁻¹(μp - μq)])²
- mahalanobis = jnp.sum(jnp.square(cola.solve(sqrt_p, diff, Cholesky())))
+ mahalanobis = jnp.sum(jnp.square(solve(sqrt_p, diff)))
 
  # KL[q(x)||p(x)] = [ [(μp - μq)ᵀ Σp⁻¹ (μp - μq)] - n - log|Σq| + log|Σp| + tr[Σp⁻¹ Σq] ] / 2
- return (
- mahalanobis
- - n_dim
- - cola.logdet(sigma_q, Cholesky(), Cholesky())
- + cola.logdet(sigma_p, Cholesky(), Cholesky())
- + trace
- ) / 2.0
-
-
- # def _check_loc_scale(loc: Optional[Any], scale: Optional[Any]) -> None:
- # r"""Checks that the inputs are correct."""
- # if loc is None and scale is None:
- # raise ValueError("At least one of `loc` or `scale` must be specified.")
- #
- # if loc is not None and loc.ndim < 1:
- # raise ValueError("The parameter `loc` must have at least one dimension.")
- #
- # if scale is not None and len(scale.shape) < 2: # scale.ndim < 2:
- # raise ValueError(
- # "The `scale` must have at least two dimensions, but "
- # f"`scale.shape = {scale.shape}`."
- # )
- #
- # if scale is not None and not isinstance(scale, LinearOperator):
- # raise ValueError(
- # f"The `scale` must be a CoLA LinearOperator but got {type(scale)}"
- # )
- #
- # if scale is not None and (scale.shape[-1] != scale.shape[-2]):
- # raise ValueError(
- # f"The `scale` must be a square matrix, but `scale.shape = {scale.shape}`."
- # )
- #
- # if loc is not None:
- # num_dims = loc.shape[-1]
- # if scale is not None and (scale.shape[-1] != num_dims):
- # raise ValueError(
- # f"Shapes are not compatible: `loc.shape = {loc.shape}` and "
- # f"`scale.shape = {scale.shape}`."
- # )
+ return (mahalanobis - n_dim - logdet(sigma_q) + logdet(sigma_p) + trace) / 2.0
 
 
  __all__ = [
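Throughout this module the CoLA calls (`cola.PSD`, `cola.diag`, `cola.logdet(..., Cholesky(), Cholesky())`, `cola.solve(..., Cholesky())`) are replaced by in-house `gpjax.linalg` helpers that default to Cholesky-based routines and take no explicit algorithm argument. A rough usage sketch, assuming only the operators and helper signatures visible in this diff:

import jax.numpy as jnp
from gpjax.distributions import GaussianDistribution
from gpjax.linalg import Dense, psd
from gpjax.linalg.operations import diag, logdet, solve

# Wrap a dense covariance in a PSD-tagged operator and build a distribution from it.
cov = psd(Dense(jnp.array([[1.0, 0.5], [0.5, 1.0]])))
dist = GaussianDistribution(loc=jnp.zeros(2), scale=cov)

var = dist.variance            # diag(cov) under the hood, replacing cola.diag
ent = dist.entropy()           # uses logdet(cov) instead of cola.logdet(cov, Cholesky(), Cholesky())
sol = solve(cov, jnp.ones(2))  # Cholesky-based solve with no explicit algorithm argument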
gpjax/fit.py CHANGED
@@ -48,6 +48,7 @@ def fit( # noqa: PLR0913
  train_data: Dataset,
  optim: ox.GradientTransformation,
  params_bijection: tp.Union[dict[Parameter, Transform], None] = DEFAULT_BIJECTION,
+ trainable: nnx.filterlib.Filter = Parameter,
  key: KeyArray = jr.PRNGKey(42),
  num_iters: int = 100,
  batch_size: int = -1,
@@ -65,7 +66,7 @@ def fit( # noqa: PLR0913
  >>> import jax.random as jr
  >>> import optax as ox
  >>> import gpjax as gpx
- >>> from gpjax.parameters import PositiveReal, Static
+ >>> from gpjax.parameters import PositiveReal
  >>>
  >>> # (1) Create a dataset:
  >>> X = jnp.linspace(0.0, 10.0, 100)[:, None]
@@ -75,10 +76,10 @@ def fit( # noqa: PLR0913
  >>> class LinearModel(nnx.Module):
  >>> def __init__(self, weight: float, bias: float):
  >>> self.weight = PositiveReal(weight)
- >>> self.bias = Static(bias)
+ >>> self.bias = bias
  >>>
  >>> def __call__(self, x):
- >>> return self.weight.value * x + self.bias.value
+ >>> return self.weight.value * x + self.bias
  >>>
  >>> model = LinearModel(weight=1.0, bias=1.0)
  >>>
@@ -100,6 +101,8 @@ def fit( # noqa: PLR0913
  train_data (Dataset): The training data to be used for the optimisation.
  optim (GradientTransformation): The Optax optimiser that is to be used for
  learning a parameter set.
+ trainable (nnx.filterlib.Filter): Filter to determine which parameters are trainable.
+ Defaults to nnx.Param (all Parameter instances).
  num_iters (int): The number of optimisation steps to run. Defaults
  to 100.
  batch_size (int): The size of the mini-batch to use. Defaults to -1
@@ -127,7 +130,7 @@ def fit( # noqa: PLR0913
  _check_verbose(verbose)
 
  # Model state filtering
- graphdef, params, *static_state = nnx.split(model, Parameter, ...)
+ graphdef, params, *static_state = nnx.split(model, trainable, ...)
 
  # Parameters bijection to unconstrained space
  if params_bijection is not None:
@@ -182,6 +185,7 @@ def fit_scipy( # noqa: PLR0913
  model: Model,
  objective: Objective,
  train_data: Dataset,
+ trainable: nnx.filterlib.Filter = Parameter,
  max_iters: int = 500,
  verbose: bool = True,
  safe: bool = True,
@@ -210,7 +214,7 @@ def fit_scipy( # noqa: PLR0913
  _check_verbose(verbose)
 
  # Model state filtering
- graphdef, params, *static_state = nnx.split(model, Parameter, ...)
+ graphdef, params, *static_state = nnx.split(model, trainable, ...)
 
  # Parameters bijection to unconstrained space
  params = transform(params, DEFAULT_BIJECTION, inverse=True)
@@ -258,6 +262,7 @@ def fit_lbfgs(
  objective: Objective,
  train_data: Dataset,
  params_bijection: tp.Union[dict[Parameter, Transform], None] = DEFAULT_BIJECTION,
+ trainable: nnx.filterlib.Filter = Parameter,
  max_iters: int = 100,
  safe: bool = True,
  max_linesearch_steps: int = 32,
@@ -290,7 +295,7 @@ def fit_lbfgs(
  _check_num_iters(max_iters)
 
  # Model state filtering
- graphdef, params, *static_state = nnx.split(model, Parameter, ...)
+ graphdef, params, *static_state = nnx.split(model, trainable, ...)
 
  # Parameters bijection to unconstrained space
  if params_bijection is not None:
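`fit`, `fit_scipy`, and `fit_lbfgs` all gain a `trainable` filter that is passed straight to `nnx.split`, so part of a model can be frozen without the removed `Static` wrapper. A hedged usage sketch, assuming `model`, `objective`, and `data` are defined as in the docstring example above and that `gpx.fit` keeps its usual `(model, history)` return:

import optax as ox
import gpjax as gpx
from gpjax.parameters import PositiveReal

fitted_model, history = gpx.fit(
    model=model,
    objective=objective,
    train_data=data,
    optim=ox.adam(1e-2),
    trainable=PositiveReal,  # optimise only PositiveReal leaves; all other state stays fixed
    num_iters=100,
)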
gpjax/gps.py CHANGED
@@ -16,14 +16,9 @@
  from abc import abstractmethod
 
  import beartype.typing as tp
- from cola.annotations import PSD
- from cola.linalg.algorithm_base import Algorithm
- from cola.linalg.decompositions.decompositions import Cholesky
- from cola.linalg.inverse.inv import solve
- from cola.ops.operators import I_like
- from flax import nnx
  import jax.numpy as jnp
  import jax.random as jr
+ from flax import nnx
  from jaxtyping import (
  Float,
  Num,
@@ -38,12 +33,17 @@ from gpjax.likelihoods import (
  Gaussian,
  NonGaussian,
  )
- from gpjax.lower_cholesky import lower_cholesky
+ from gpjax.linalg import (
+ Dense,
+ psd,
+ solve,
+ )
+ from gpjax.linalg.operations import lower_cholesky
+ from gpjax.linalg.utils import add_jitter
  from gpjax.mean_functions import AbstractMeanFunction
  from gpjax.parameters import (
  Parameter,
  Real,
- Static,
  )
  from gpjax.typing import (
  Array,
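For downstream code that used the old CoLA objects directly, the import hunks above give the full replacement mapping. A comment-only summary, based purely on the substitutions made in this diff:

# cola.annotations.PSD(A)                 ->  gpjax.linalg.psd(A)
# cola.ops.operators.Dense(M)             ->  gpjax.linalg.Dense(M)
# cola.ops.Diagonal(diag=v)               ->  gpjax.linalg.Diagonal(v)
# cola.solve(A, b, Cholesky())            ->  gpjax.linalg.solve(A, b)
# cola.logdet(A, Cholesky(), Cholesky())  ->  gpjax.linalg.operations.logdet(A)
# gpjax.lower_cholesky.lower_cholesky(A)  ->  gpjax.linalg.operations.lower_cholesky(A)
# Kxx += I_like(Kxx) * jitter             ->  add_jitter(Kxx.to_dense(), jitter), then re-wrap with Dense(...)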
@@ -77,7 +77,7 @@ class AbstractPrior(nnx.Module, tp.Generic[M, K]):
  self.mean_function = mean_function
  self.jitter = jitter
 
- def __call__(self, *args: tp.Any, **kwargs: tp.Any) -> GaussianDistribution:
+ def __call__(self, test_inputs: Num[Array, "N D"]) -> GaussianDistribution:
  r"""Evaluate the Gaussian process at the given points.
 
  The output of this function is a
@@ -90,17 +90,16 @@ class AbstractPrior(nnx.Module, tp.Generic[M, K]):
  `__call__` method and should instead define a `predict` method.
 
  Args:
- *args (Any): The arguments to pass to the GP's `predict` method.
- **kwargs (Any): The keyword arguments to pass to the GP's `predict` method.
+ test_inputs: Input locations where the GP should be evaluated.
 
  Returns:
  GaussianDistribution: A multivariate normal random variable representation
  of the Gaussian process.
  """
- return self.predict(*args, **kwargs)
+ return self.predict(test_inputs)
 
  @abstractmethod
- def predict(self, *args: tp.Any, **kwargs: tp.Any) -> GaussianDistribution:
+ def predict(self, test_inputs: Num[Array, "N D"]) -> GaussianDistribution:
  r"""Evaluate the predictive distribution.
 
  Compute the latent function's multivariate normal distribution for a
@@ -108,8 +107,7 @@ class AbstractPrior(nnx.Module, tp.Generic[M, K]):
  this method must be implemented.
 
  Args:
- *args (Any): Arguments to the predict method.
- **kwargs (Any): Keyword arguments to the predict method.
+ test_inputs: Input locations where the GP should be evaluated.
 
  Returns:
  GaussianDistribution: A multivariate normal random variable representation
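With the `*args`/`**kwargs` indirection removed, the call signatures are now explicit: a prior is evaluated at test inputs alone, while a posterior (see the matching `AbstractPosterior` change below) also takes the training data. A brief sketch of the calling convention, with `prior`, `posterior`, `xtest`, and `D` assumed to already exist:

prior_dist = prior(xtest)        # same as prior.predict(xtest)
post_dist = posterior(xtest, D)  # same as posterior.predict(xtest, train_data=D)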
@@ -248,13 +246,12 @@ class Prior(AbstractPrior[M, K]):
  GaussianDistribution: A multivariate normal random variable representation
  of the Gaussian process.
  """
- x = test_inputs
- mx = self.mean_function(x)
- Kxx = self.kernel.gram(x)
- Kxx += I_like(Kxx) * self.jitter
- Kxx = PSD(Kxx)
+ mean_at_test = self.mean_function(test_inputs)
+ Kxx = self.kernel.gram(test_inputs)
+ Kxx_dense = add_jitter(Kxx.to_dense(), self.jitter)
+ Kxx = psd(Dense(Kxx_dense))
 
- return GaussianDistribution(jnp.atleast_1d(mx.squeeze()), Kxx)
+ return GaussianDistribution(jnp.atleast_1d(mean_at_test.squeeze()), Kxx)
 
  def sample_approx(
  self,
@@ -315,15 +312,13 @@ class Prior(AbstractPrior[M, K]):
  if (not isinstance(num_samples, int)) or num_samples <= 0:
  raise ValueError("num_samples must be a positive integer")
 
- # sample fourier features
  fourier_feature_fn = _build_fourier_features_fn(self, num_features, key)
 
- # sample fourier weights
- feature_weights = jr.normal(key, [num_samples, 2 * num_features]) # [B, L]
+ feature_weights = jr.normal(key, [num_samples, 2 * num_features])
 
  def sample_fn(test_inputs: Float[Array, "N D"]) -> Float[Array, "N B"]:
- feature_evals = fourier_feature_fn(test_inputs) # [N, L]
- evaluated_sample = jnp.inner(feature_evals, feature_weights) # [N, B]
+ feature_evals = fourier_feature_fn(test_inputs)
+ evaluated_sample = jnp.inner(feature_evals, feature_weights)
  return self.mean_function(test_inputs) + evaluated_sample
 
  return sample_fn
@@ -360,7 +355,9 @@ class AbstractPosterior(nnx.Module, tp.Generic[P, L]):
  self.likelihood = likelihood
  self.jitter = jitter
 
- def __call__(self, *args: tp.Any, **kwargs: tp.Any) -> GaussianDistribution:
+ def __call__(
+ self, test_inputs: Num[Array, "N D"], train_data: Dataset
+ ) -> GaussianDistribution:
  r"""Evaluate the Gaussian process posterior at the given points.
 
  The output of this function is a
@@ -369,28 +366,30 @@ class AbstractPosterior(nnx.Module, tp.Generic[P, L]):
  evaluated and the distribution can be sampled.
 
  Under the hood, `__call__` is calling the objects `predict` method. For this
- reasons, classes inheriting the `AbstractPrior` class, should not overwrite the
+ reasons, classes inheriting the `AbstractPosterior` class, should not overwrite the
  `__call__` method and should instead define a `predict` method.
 
  Args:
- *args (Any): The arguments to pass to the GP's `predict` method.
- **kwargs (Any): The keyword arguments to pass to the GP's `predict` method.
+ test_inputs: Input locations where the GP should be evaluated.
+ train_data: Training dataset to condition on.
 
  Returns:
  GaussianDistribution: A multivariate normal random variable representation
  of the Gaussian process.
  """
- return self.predict(*args, **kwargs)
+ return self.predict(test_inputs, train_data)
 
  @abstractmethod
- def predict(self, *args: tp.Any, **kwargs: tp.Any) -> GaussianDistribution:
+ def predict(
+ self, test_inputs: Num[Array, "N D"], train_data: Dataset
+ ) -> GaussianDistribution:
  r"""Compute the latent function's multivariate normal distribution for a
- given set of parameters. For any class inheriting the `AbstractPrior` class,
+ given set of parameters. For any class inheriting the `AbstractPosterior` class,
  this method must be implemented.
 
  Args:
- *args (Any): Arguments to the predict method.
- **kwargs (Any): Keyword arguments to the predict method.
+ test_inputs: Input locations where the GP should be evaluated.
+ train_data: Training dataset to condition on.
 
  Returns:
  GaussianDistribution: A multivariate normal random variable representation
@@ -504,24 +503,25 @@ class ConjugatePosterior(AbstractPosterior[P, GL]):
 
  # Precompute Gram matrix, Kxx, at training inputs, x
  Kxx = self.prior.kernel.gram(x)
- Kxx += I_like(Kxx) * self.jitter
+ Kxx_dense = add_jitter(Kxx.to_dense(), self.jitter)
+ Kxx = Dense(Kxx_dense)
 
- # Σ = Kxx + Io²
- Sigma = Kxx + I_like(Kxx) * obs_noise
- Sigma = PSD(Sigma)
+ Sigma_dense = Kxx.to_dense() + jnp.eye(Kxx.shape[0]) * obs_noise
+ Sigma = psd(Dense(Sigma_dense))
+ L_sigma = lower_cholesky(Sigma)
 
  mean_t = self.prior.mean_function(t)
  Ktt = self.prior.kernel.gram(t)
  Kxt = self.prior.kernel.cross_covariance(x, t)
- Sigma_inv_Kxt = solve(Sigma, Kxt, Cholesky())
 
- # μt + Ktx (Kxx + Io²)⁻¹ (y - μx)
- mean = mean_t + jnp.matmul(Sigma_inv_Kxt.T, y - mx)
+ L_inv_Kxt = solve(L_sigma, Kxt)
+ L_inv_y_diff = solve(L_sigma, y - mx)
+
+ mean = mean_t + jnp.matmul(L_inv_Kxt.T, L_inv_y_diff)
 
- # Ktt - Ktx (Kxx + Io²)⁻¹ Kxt, TODO: Take advantage of covariance structure to compute Schur complement more efficiently.
- covariance = Ktt - jnp.matmul(Kxt.T, Sigma_inv_Kxt)
- covariance += I_like(covariance) * self.prior.jitter
- covariance = PSD(covariance)
+ covariance = Ktt.to_dense() - jnp.matmul(L_inv_Kxt.T, L_inv_Kxt)
+ covariance = add_jitter(covariance, self.prior.jitter)
+ covariance = psd(Dense(covariance))
 
  return GaussianDistribution(jnp.atleast_1d(mean.squeeze()), covariance)
 
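The single `solve(Sigma, Kxt, Cholesky())` is replaced by two solves against the Cholesky factor `L_sigma`. With Σ = LLᵀ this is the same posterior: (L⁻¹Kxt)ᵀ(L⁻¹(y − μx)) = Ktx Σ⁻¹ (y − μx) and (L⁻¹Kxt)ᵀ(L⁻¹Kxt) = Ktx Σ⁻¹ Kxt, so the Schur complement Ktt − Ktx Σ⁻¹ Kxt is unchanged. A small self-contained check of the identity in plain jax.numpy (illustrative values only):

import jax.numpy as jnp

A = jnp.array([[2.0, 0.5], [0.5, 1.5]])  # stands in for Sigma = Kxx + I o²
B = jnp.array([[1.0, 0.2], [0.3, 0.4]])  # stands in for Kxt
r = jnp.array([0.7, -1.1])               # stands in for y - mx
L = jnp.linalg.cholesky(A)               # A = L Lᵀ

direct = B.T @ jnp.linalg.solve(A, r)                         # Ktx Σ⁻¹ (y - μx)
whitened = jnp.linalg.solve(L, B).T @ jnp.linalg.solve(L, r)  # (L⁻¹Kxt)ᵀ (L⁻¹(y - μx))
assert jnp.allclose(direct, whitened, atol=1e-6)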
@@ -531,7 +531,6 @@ class ConjugatePosterior(AbstractPosterior[P, GL]):
  train_data: Dataset,
  key: KeyArray,
  num_features: int | None = 100,
- solver_algorithm: tp.Optional[Algorithm] = Cholesky(),
  ) -> FunctionalSample:
  r"""Draw approximate samples from the Gaussian process posterior.
 
@@ -565,11 +564,6 @@ class ConjugatePosterior(AbstractPosterior[P, GL]):
  key (KeyArray): The random seed used for the sample(s).
  num_features (int): The number of features used when approximating the
  kernel.
- solver_algorithm (Optional[Algorithm], optional): The algorithm to use for the solves of
- the inverse of the covariance matrix. See the
- [CoLA documentation](https://cola.readthedocs.io/en/latest/package/cola.linalg.html#algorithms)
- for which solver to pick. For PSD matrices, CoLA currently recommends Cholesky() for small
- matrices and CG() for larger matrices. Select Auto() to let CoLA decide. Defaults to Cholesky().
 
  Returns:
  FunctionalSample: A function representing an approximate sample from the Gaussian
@@ -581,31 +575,25 @@ class ConjugatePosterior(AbstractPosterior[P, GL]):
  # sample fourier features
  fourier_feature_fn = _build_fourier_features_fn(self.prior, num_features, key)
 
- # sample fourier weights
- fourier_weights = jr.normal(key, [num_samples, 2 * num_features]) # [B, L]
+ fourier_weights = jr.normal(key, [num_samples, 2 * num_features])
 
- # sample weights v for canonical features
- # v = Σ⁻¹ (y + ε - ɸ⍵) for Σ = Kxx + Io² and ε ᯈ N(0, o²)
  obs_var = self.likelihood.obs_stddev.value**2
- Kxx = self.prior.kernel.gram(train_data.X) # [N, N]
- Sigma = Kxx + I_like(Kxx) * (obs_var + self.jitter) # [N, N]
- eps = jnp.sqrt(obs_var) * jr.normal(key, [train_data.n, num_samples]) # [N, B]
- y = train_data.y - self.prior.mean_function(train_data.X) # account for mean
+ Kxx = self.prior.kernel.gram(train_data.X)
+ Sigma = Dense(add_jitter(Kxx.to_dense(), obs_var + self.jitter))
+ eps = jnp.sqrt(obs_var) * jr.normal(key, [train_data.n, num_samples])
+ y = train_data.y - self.prior.mean_function(train_data.X)
  Phi = fourier_feature_fn(train_data.X)
  canonical_weights = solve(
  Sigma,
  y + eps - jnp.inner(Phi, fourier_weights),
- solver_algorithm,
  ) # [N, B]
 
  def sample_fn(test_inputs: Float[Array, "n D"]) -> Float[Array, "n B"]:
- fourier_features = fourier_feature_fn(test_inputs) # [n, L]
- weight_space_contribution = jnp.inner(
- fourier_features, fourier_weights
- ) # [n, B]
+ fourier_features = fourier_feature_fn(test_inputs)
+ weight_space_contribution = jnp.inner(fourier_features, fourier_weights)
  canonical_features = self.prior.kernel.cross_covariance(
  test_inputs, train_data.X
- ) # [n, N]
+ )
  function_space_contribution = jnp.matmul(
  canonical_features, canonical_weights
  )
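The deleted inline comments spelled out what `sample_fn` still computes; in the notation used by the remaining comments in this file, the returned sample combines a weight-space term and a function-space correction,

Φ(x*)⍵ + k(x*, X) v,   with v = Σ⁻¹ (y + ε − Φ(X)⍵),   Σ = Kxx + (o² + jitter)I,   ε ~ N(0, o²),

where Φ are the random Fourier features of the prior kernel, ⍵ the sampled feature weights, and ε the simulated observation noise (the same decoupled-sampling construction as in the previous release).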
@@ -657,7 +645,7 @@ class NonConjugatePosterior(AbstractPosterior[P, NGL]):
 
  # TODO: static or intermediate?
  self.latent = latent if isinstance(latent, Parameter) else Real(latent)
- self.key = Static(key)
+ self.key = key
 
  def predict(
  self, test_inputs: Num[Array, "N D"], train_data: Dataset
@@ -689,8 +677,8 @@ class NonConjugatePosterior(AbstractPosterior[P, NGL]):
 
  # Precompute lower triangular of Gram matrix, Lx, at training inputs, x
  Kxx = kernel.gram(x)
- Kxx += I_like(Kxx) * self.prior.jitter
- Kxx = PSD(Kxx)
+ Kxx_dense = add_jitter(Kxx.to_dense(), self.prior.jitter)
+ Kxx = psd(Dense(Kxx_dense))
  Lx = lower_cholesky(Kxx)
 
  # Unpack test inputs
@@ -702,7 +690,7 @@ class NonConjugatePosterior(AbstractPosterior[P, NGL]):
  mean_t = mean_function(t)
 
  # Lx⁻¹ Kxt
- Lx_inv_Kxt = solve(Lx, Ktx.T, Cholesky())
+ Lx_inv_Kxt = solve(Lx, Ktx.T)
 
  # Whitened function values, wx, corresponding to the inputs, x
  wx = self.latent.value
@@ -711,9 +699,9 @@ class NonConjugatePosterior(AbstractPosterior[P, NGL]):
  mean = mean_t + jnp.matmul(Lx_inv_Kxt.T, wx)
 
  # Ktt - Ktx Kxx⁻¹ Kxt, TODO: Take advantage of covariance structure to compute Schur complement more efficiently.
- covariance = Ktt - jnp.matmul(Lx_inv_Kxt.T, Lx_inv_Kxt)
- covariance += I_like(covariance) * self.prior.jitter
- covariance = PSD(covariance)
+ covariance = Ktt.to_dense() - jnp.matmul(Lx_inv_Kxt.T, Lx_inv_Kxt)
+ covariance = add_jitter(covariance, self.prior.jitter)
+ covariance = psd(Dense(covariance))
 
  return GaussianDistribution(jnp.atleast_1d(mean.squeeze()), covariance)
 
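The non-conjugate posterior keeps its whitened parameterisation; only the operator plumbing changed. Writing the latent function values at the training inputs as f(x) = μx + Lx wx with Kxx = Lx Lxᵀ, the usual conditional mean becomes

μt + Ktx Kxx⁻¹ (f(x) − μx) = μt + Kxtᵀ (Lx Lxᵀ)⁻¹ Lx wx = μt + (Lx⁻¹ Kxt)ᵀ wx,

which is exactly the `mean_t + jnp.matmul(Lx_inv_Kxt.T, wx)` line above; the covariance Ktt − (Lx⁻¹Kxt)ᵀ(Lx⁻¹Kxt) is the corresponding Schur complement.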
@@ -7,7 +7,6 @@ from jaxtyping import Float
  from gpjax.kernels.base import AbstractKernel
  from gpjax.kernels.computations import BasisFunctionComputation
  from gpjax.kernels.stationary.base import StationaryKernel
- from gpjax.parameters import Static
  from gpjax.typing import (
  Array,
  KeyArray,
@@ -66,10 +65,8 @@ class RFF(AbstractKernel):
  "Please specify the n_dims argument for the base kernel."
  )
 
- self.frequencies = Static(
- self.base_kernel.spectral_density.sample(
- key=key, sample_shape=(self.num_basis_fns, n_dims)
- )
+ self.frequencies = self.base_kernel.spectral_density.sample(
+ key=key, sample_shape=(self.num_basis_fns, n_dims)
  )
  self.name = f"{self.base_kernel.name} (RFF)"
 
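`RFF.frequencies` is now a plain array rather than a `Static` leaf, so downstream code reads it directly; the `BasisFunctionComputation` hunk further below drops the `.value` accessor for the same reason. A two-line sketch, with `rff_kernel` assumed to be an existing `RFF` instance:

# 0.11.x: frequencies lived inside a Static wrapper
# freqs = rff_kernel.frequencies.value
# 0.12.x: frequencies is a plain array attribute of shape (num_basis_fns, n_dims)
freqs = rff_kernel.frequencies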
gpjax/kernels/base.py CHANGED
@@ -17,7 +17,6 @@ import abc
  import functools as ft
 
  import beartype.typing as tp
- from cola.ops.operator_base import LinearOperator
  from flax import nnx
  import jax.numpy as jnp
  from jaxtyping import (
@@ -29,10 +28,10 @@ from gpjax.kernels.computations import (
  AbstractKernelComputation,
  DenseKernelComputation,
  )
+ from gpjax.linalg import LinearOperator
  from gpjax.parameters import (
  Parameter,
  Real,
- Static,
  )
  from gpjax.typing import (
  Array,
@@ -221,9 +220,7 @@ class Constant(AbstractKernel):
  def __init__(
  self,
  active_dims: tp.Union[list[int], slice, None] = None,
- constant: tp.Union[
- ScalarFloat, Parameter[ScalarFloat], Static[ScalarFloat]
- ] = jnp.array(0.0),
+ constant: tp.Union[ScalarFloat, Parameter[ScalarFloat]] = jnp.array(0.0),
  compute_engine: AbstractKernelComputation = DenseKernelComputation(),
  ):
  if isinstance(constant, Parameter):
@@ -16,11 +16,6 @@
  import abc
  import typing as tp
 
- from cola.annotations import PSD
- from cola.ops.operators import (
- Dense,
- Diagonal,
- )
  from jax import vmap
  from jaxtyping import (
  Float,
@@ -28,6 +23,11 @@ from jaxtyping import (
  )
 
  import gpjax
+ from gpjax.linalg import (
+ Dense,
+ Diagonal,
+ psd,
+ )
  from gpjax.typing import Array
 
  K = tp.TypeVar("K", bound="gpjax.kernels.base.AbstractKernel") # noqa: F821
@@ -69,7 +69,7 @@ class AbstractKernelComputation:
  The Gram covariance of the kernel function as a linear operator.
  """
  Kxx = self.cross_covariance(kernel, x, x)
- return PSD(Dense(Kxx))
+ return psd(Dense(Kxx))
 
  @abc.abstractmethod
  def _cross_covariance(
@@ -93,7 +93,7 @@ class AbstractKernelComputation:
  return self._cross_covariance(kernel, x, y)
 
  def _diagonal(self, kernel: K, inputs: Num[Array, "N D"]) -> Diagonal:
- return PSD(Diagonal(diag=vmap(lambda x: kernel(x, x))(inputs)))
+ return psd(Diagonal(vmap(lambda x: kernel(x, x))(inputs)))
 
  def diagonal(self, kernel: K, inputs: Num[Array, "N D"]) -> Diagonal:
  r"""For a given kernel, compute the elementwise diagonal of the
@@ -1,18 +1,19 @@
  import typing as tp
 
- from cola.annotations import PSD
- from cola.ops.operators import Dense
  import jax.numpy as jnp
  from jaxtyping import Float
 
  import gpjax
  from gpjax.kernels.computations.base import AbstractKernelComputation
+ from gpjax.linalg import (
+ Dense,
+ Diagonal,
+ psd,
+ )
  from gpjax.typing import Array
 
  K = tp.TypeVar("K", bound="gpjax.kernels.approximations.RFF") # noqa: F821
 
- from cola.ops import Diagonal
-
  # TODO: Use low rank linear operator!
 
 
@@ -28,7 +29,7 @@ class BasisFunctionComputation(AbstractKernelComputation):
 
  def _gram(self, kernel: K, inputs: Float[Array, "N D"]) -> Dense:
  z1 = self.compute_features(kernel, inputs)
- return PSD(Dense(self.scaling(kernel) * jnp.matmul(z1, z1.T)))
+ return psd(Dense(self.scaling(kernel) * jnp.matmul(z1, z1.T)))
 
  def diagonal(self, kernel: K, inputs: Float[Array, "N D"]) -> Diagonal:
  r"""For a given kernel, compute the elementwise diagonal of the
@@ -56,7 +57,7 @@ class BasisFunctionComputation(AbstractKernelComputation):
  Returns:
  A matrix of shape $N \times L$ representing the random fourier features where $L = 2M$.
  """
- frequencies = kernel.frequencies.value
+ frequencies = kernel.frequencies
  scaling_factor = kernel.base_kernel.lengthscale.value
  z = jnp.matmul(x, (frequencies / scaling_factor).T)
  z = jnp.concatenate([jnp.cos(z), jnp.sin(z)], axis=-1)
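The feature map itself is unchanged: inputs are projected onto the sampled frequencies (scaled by the base kernel's lengthscale) and cosine and sine components are stacked, so the approximate Gram matrix is scaling * Z Zᵀ. A standalone sketch in plain jax.numpy mirroring `compute_features` and `_gram` above; the scaling constant is kernel-dependent and left as a placeholder here:

import jax.numpy as jnp
import jax.random as jr

def rff_features(x, frequencies, lengthscale):
    # x: [N, D], frequencies: [M, D]  ->  features: [N, 2M]
    z = jnp.matmul(x, (frequencies / lengthscale).T)
    return jnp.concatenate([jnp.cos(z), jnp.sin(z)], axis=-1)

key = jr.PRNGKey(0)
x = jr.normal(key, (5, 2))
frequencies = jr.normal(key, (10, 2))
Z = rff_features(x, frequencies, lengthscale=1.0)
gram_approx = jnp.matmul(Z, Z.T)  # multiply by the kernel's scaling factor to match _gram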