gpjax 0.11.2__py3-none-any.whl → 0.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
gpjax/__init__.py CHANGED
@@ -40,7 +40,7 @@ __license__ = "MIT"
 __description__ = "Gaussian processes in JAX and Flax"
 __url__ = "https://github.com/JaxGaussianProcesses/GPJax"
 __contributors__ = "https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors"
-__version__ = "0.11.2"
+__version__ = "0.12.0"
 
 __all__ = [
     "base",
gpjax/distributions.py CHANGED
@@ -17,11 +17,6 @@
 from beartype.typing import (
     Optional,
 )
-import cola
-from cola.linalg.decompositions import Cholesky
-from cola.ops import (
-    LinearOperator,
-)
 from jax import vmap
 import jax.numpy as jnp
 import jax.random as jr
@@ -30,7 +25,14 @@ from numpyro.distributions import constraints
 from numpyro.distributions.distribution import Distribution
 from numpyro.distributions.util import is_prng_key
 
-from gpjax.lower_cholesky import lower_cholesky
+from gpjax.linalg.operations import (
+    diag,
+    logdet,
+    lower_cholesky,
+    solve,
+)
+from gpjax.linalg.operators import LinearOperator
+from gpjax.linalg.utils import psd
 from gpjax.typing import (
     Array,
     ScalarFloat,
@@ -47,7 +49,7 @@ class GaussianDistribution(Distribution):
         validate_args=None,
     ):
         self.loc = loc
-        self.scale = cola.PSD(scale)
+        self.scale = psd(scale)
         batch_shape = ()
         event_shape = jnp.shape(self.loc)
         super().__init__(batch_shape, event_shape, validate_args=validate_args)
@@ -76,13 +78,12 @@ class GaussianDistribution(Distribution):
     @property
     def variance(self) -> Float[Array, " N"]:
         r"""Calculates the variance."""
-        return cola.diag(self.scale)
+        return diag(self.scale)
 
     def entropy(self) -> ScalarFloat:
         r"""Calculates the entropy of the distribution."""
         return 0.5 * (
-            self.event_shape[0] * (1.0 + jnp.log(2.0 * jnp.pi))
-            + cola.logdet(self.scale, Cholesky(), Cholesky())
+            self.event_shape[0] * (1.0 + jnp.log(2.0 * jnp.pi)) + logdet(self.scale)
         )
 
     def median(self) -> Float[Array, " N"]:
@@ -104,7 +105,7 @@ class GaussianDistribution(Distribution):
 
     def stddev(self) -> Float[Array, " N"]:
         r"""Calculates the standard deviation."""
-        return jnp.sqrt(cola.diag(self.scale))
+        return jnp.sqrt(diag(self.scale))
 
     # @property
     # def event_shape(self) -> Tuple:
@@ -129,9 +130,7 @@ class GaussianDistribution(Distribution):
 
         # compute the pdf, -1/2[ n log(2π) + log|Σ| + (y - µ)ᵀΣ⁻¹(y - µ) ]
         return -0.5 * (
-            n * jnp.log(2.0 * jnp.pi)
-            + cola.logdet(sigma, Cholesky(), Cholesky())
-            + diff.T @ cola.solve(sigma, diff, Cholesky())
+            n * jnp.log(2.0 * jnp.pi) + logdet(sigma) + diff.T @ solve(sigma, diff)
         )
 
     # def _sample_n(self, key: KeyArray, n: int) -> Float[Array, "n N"]:
@@ -219,53 +218,14 @@ def _kl_divergence(q: GaussianDistribution, p: GaussianDistribution) -> ScalarFl
 
     # trace term, tr[Σp⁻¹ Σq] = tr[(LpLpᵀ)⁻¹(LqLqᵀ)] = tr[(Lp⁻¹Lq)(Lp⁻¹Lq)ᵀ] = (fr[LqLp⁻¹])²
     trace = _frobenius_norm_squared(
-        cola.solve(sqrt_p, sqrt_q.to_dense(), Cholesky())
+        solve(sqrt_p, sqrt_q.to_dense())
     )  # TODO: Not most efficient, given the `to_dense()` call (e.g., consider diagonal p and q). Need to abstract solving linear operator against another linear operator.
 
     # Mahalanobis term, (μp - μq)ᵀ Σp⁻¹ (μp - μq) = tr [(μp - μq)ᵀ [LpLpᵀ]⁻¹ (μp - μq)] = (fr[Lp⁻¹(μp - μq)])²
-    mahalanobis = jnp.sum(jnp.square(cola.solve(sqrt_p, diff, Cholesky())))
+    mahalanobis = jnp.sum(jnp.square(solve(sqrt_p, diff)))
 
     # KL[q(x)||p(x)] = [ [(μp - μq)ᵀ Σp⁻¹ (μp - μq)] - n - log|Σq| + log|Σp| + tr[Σp⁻¹ Σq] ] / 2
-    return (
-        mahalanobis
-        - n_dim
-        - cola.logdet(sigma_q, Cholesky(), Cholesky())
-        + cola.logdet(sigma_p, Cholesky(), Cholesky())
-        + trace
-    ) / 2.0
-
-
-# def _check_loc_scale(loc: Optional[Any], scale: Optional[Any]) -> None:
-#     r"""Checks that the inputs are correct."""
-#     if loc is None and scale is None:
-#         raise ValueError("At least one of `loc` or `scale` must be specified.")
-
-#     if loc is not None and loc.ndim < 1:
-#         raise ValueError("The parameter `loc` must have at least one dimension.")
-
-#     if scale is not None and len(scale.shape) < 2:  # scale.ndim < 2:
-#         raise ValueError(
-#             "The `scale` must have at least two dimensions, but "
-#             f"`scale.shape = {scale.shape}`."
-#         )
-
-#     if scale is not None and not isinstance(scale, LinearOperator):
-#         raise ValueError(
-#             f"The `scale` must be a CoLA LinearOperator but got {type(scale)}"
-#         )
-
-#     if scale is not None and (scale.shape[-1] != scale.shape[-2]):
-#         raise ValueError(
-#             f"The `scale` must be a square matrix, but `scale.shape = {scale.shape}`."
-#         )
-
-#     if loc is not None:
-#         num_dims = loc.shape[-1]
-#         if scale is not None and (scale.shape[-1] != num_dims):
-#             raise ValueError(
-#                 f"Shapes are not compatible: `loc.shape = {loc.shape}` and "
-#                 f"`scale.shape = {scale.shape}`."
-#             )
+    return (mahalanobis - n_dim - logdet(sigma_q) + logdet(sigma_p) + trace) / 2.0
 
 
 __all__ = [
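
The net effect of the distributions.py changes is that GaussianDistribution now consumes operators from the new gpjax.linalg module instead of CoLA, and the solver choice (previously an explicit Cholesky()) is handled internally. A minimal, hedged usage sketch, using only the names this diff shows gpjax.linalg exporting:

import jax.numpy as jnp

from gpjax.distributions import GaussianDistribution
from gpjax.linalg import Dense, psd

loc = jnp.zeros(3)
cov = 2.0 * jnp.eye(3)

# The scale is a gpjax.linalg operator; psd() replaces the old cola.PSD annotation.
dist = GaussianDistribution(loc=loc, scale=psd(Dense(cov)))

dist.log_prob(jnp.ones(3))  # log-density computed with logdet/solve from gpjax.linalg
dist.variance               # diag(self.scale), per the variance property above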
gpjax/gps.py CHANGED
@@ -16,11 +16,6 @@
 from abc import abstractmethod
 
 import beartype.typing as tp
-from cola.annotations import PSD
-from cola.linalg.algorithm_base import Algorithm
-from cola.linalg.decompositions.decompositions import Cholesky
-from cola.linalg.inverse.inv import solve
-from cola.ops.operators import I_like
 from flax import nnx
 import jax.numpy as jnp
 import jax.random as jr
@@ -38,7 +33,13 @@ from gpjax.likelihoods import (
     Gaussian,
     NonGaussian,
 )
-from gpjax.lower_cholesky import lower_cholesky
+from gpjax.linalg import (
+    Dense,
+    Identity,
+    psd,
+    solve,
+)
+from gpjax.linalg.operations import lower_cholesky
 from gpjax.mean_functions import AbstractMeanFunction
 from gpjax.parameters import (
     Parameter,
@@ -251,8 +252,8 @@ class Prior(AbstractPrior[M, K]):
         x = test_inputs
         mx = self.mean_function(x)
         Kxx = self.kernel.gram(x)
-        Kxx += I_like(Kxx) * self.jitter
-        Kxx = PSD(Kxx)
+        Kxx_dense = Kxx.to_dense() + Identity(Kxx.shape).to_dense() * self.jitter
+        Kxx = psd(Dense(Kxx_dense))
 
         return GaussianDistribution(jnp.atleast_1d(mx.squeeze()), Kxx)
 
@@ -315,15 +316,13 @@ class Prior(AbstractPrior[M, K]):
         if (not isinstance(num_samples, int)) or num_samples <= 0:
             raise ValueError("num_samples must be a positive integer")
 
-        # sample fourier features
         fourier_feature_fn = _build_fourier_features_fn(self, num_features, key)
 
-        # sample fourier weights
-        feature_weights = jr.normal(key, [num_samples, 2 * num_features])  # [B, L]
+        feature_weights = jr.normal(key, [num_samples, 2 * num_features])
 
         def sample_fn(test_inputs: Float[Array, "N D"]) -> Float[Array, "N B"]:
-            feature_evals = fourier_feature_fn(test_inputs)  # [N, L]
-            evaluated_sample = jnp.inner(feature_evals, feature_weights)  # [N, B]
+            feature_evals = fourier_feature_fn(test_inputs)
+            evaluated_sample = jnp.inner(feature_evals, feature_weights)
             return self.mean_function(test_inputs) + evaluated_sample
 
         return sample_fn
@@ -504,24 +503,23 @@ class ConjugatePosterior(AbstractPosterior[P, GL]):
 
         # Precompute Gram matrix, Kxx, at training inputs, x
         Kxx = self.prior.kernel.gram(x)
-        Kxx += I_like(Kxx) * self.jitter
+        Kxx_dense = Kxx.to_dense() + Identity(Kxx.shape).to_dense() * self.jitter
+        Kxx = Dense(Kxx_dense)
 
-        # Σ = Kxx + Io²
-        Sigma = Kxx + I_like(Kxx) * obs_noise
-        Sigma = PSD(Sigma)
+        Sigma_dense = Kxx.to_dense() + jnp.eye(Kxx.shape[0]) * obs_noise
+        Sigma = psd(Dense(Sigma_dense))
 
         mean_t = self.prior.mean_function(t)
         Ktt = self.prior.kernel.gram(t)
         Kxt = self.prior.kernel.cross_covariance(x, t)
-        Sigma_inv_Kxt = solve(Sigma, Kxt, Cholesky())
+        Sigma_inv_Kxt = solve(Sigma, Kxt)
 
-        # μt + Ktx (Kxx + Io²)⁻¹ (y - μx)
         mean = mean_t + jnp.matmul(Sigma_inv_Kxt.T, y - mx)
 
         # Ktt - Ktx (Kxx + Io²)⁻¹ Kxt, TODO: Take advantage of covariance structure to compute Schur complement more efficiently.
-        covariance = Ktt - jnp.matmul(Kxt.T, Sigma_inv_Kxt)
-        covariance += I_like(covariance) * self.prior.jitter
-        covariance = PSD(covariance)
+        covariance = Ktt.to_dense() - jnp.matmul(Kxt.T, Sigma_inv_Kxt)
+        covariance += jnp.eye(covariance.shape[0]) * self.prior.jitter
+        covariance = psd(Dense(covariance))
 
         return GaussianDistribution(jnp.atleast_1d(mean.squeeze()), covariance)
 
@@ -531,7 +529,6 @@ class ConjugatePosterior(AbstractPosterior[P, GL]):
         train_data: Dataset,
         key: KeyArray,
         num_features: int | None = 100,
-        solver_algorithm: tp.Optional[Algorithm] = Cholesky(),
     ) -> FunctionalSample:
         r"""Draw approximate samples from the Gaussian process posterior.
 
@@ -565,11 +562,6 @@ class ConjugatePosterior(AbstractPosterior[P, GL]):
             key (KeyArray): The random seed used for the sample(s).
             num_features (int): The number of features used when approximating the
                 kernel.
-            solver_algorithm (Optional[Algorithm], optional): The algorithm to use for the solves of
-                the inverse of the covariance matrix. See the
-                [CoLA documentation](https://cola.readthedocs.io/en/latest/package/cola.linalg.html#algorithms)
-                for which solver to pick. For PSD matrices, CoLA currently recommends Cholesky() for small
-                matrices and CG() for larger matrices. Select Auto() to let CoLA decide. Defaults to Cholesky().
 
         Returns:
             FunctionalSample: A function representing an approximate sample from the Gaussian
@@ -581,31 +573,25 @@ class ConjugatePosterior(AbstractPosterior[P, GL]):
         # sample fourier features
         fourier_feature_fn = _build_fourier_features_fn(self.prior, num_features, key)
 
-        # sample fourier weights
-        fourier_weights = jr.normal(key, [num_samples, 2 * num_features])  # [B, L]
+        fourier_weights = jr.normal(key, [num_samples, 2 * num_features])
 
-        # sample weights v for canonical features
-        # v = Σ⁻¹ (y + ε - ɸ⍵) for Σ = Kxx + Io² and ε ᯈ N(0, o²)
         obs_var = self.likelihood.obs_stddev.value**2
-        Kxx = self.prior.kernel.gram(train_data.X)  # [N, N]
-        Sigma = Kxx + I_like(Kxx) * (obs_var + self.jitter)  # [N, N]
-        eps = jnp.sqrt(obs_var) * jr.normal(key, [train_data.n, num_samples])  # [N, B]
-        y = train_data.y - self.prior.mean_function(train_data.X)  # account for mean
+        Kxx = self.prior.kernel.gram(train_data.X)
+        Sigma = Kxx + jnp.eye(Kxx.shape[0]) * (obs_var + self.jitter)
+        eps = jnp.sqrt(obs_var) * jr.normal(key, [train_data.n, num_samples])
+        y = train_data.y - self.prior.mean_function(train_data.X)
         Phi = fourier_feature_fn(train_data.X)
         canonical_weights = solve(
             Sigma,
             y + eps - jnp.inner(Phi, fourier_weights),
-            solver_algorithm,
         )  # [N, B]
 
         def sample_fn(test_inputs: Float[Array, "n D"]) -> Float[Array, "n B"]:
-            fourier_features = fourier_feature_fn(test_inputs)  # [n, L]
-            weight_space_contribution = jnp.inner(
-                fourier_features, fourier_weights
-            )  # [n, B]
+            fourier_features = fourier_feature_fn(test_inputs)
+            weight_space_contribution = jnp.inner(fourier_features, fourier_weights)
             canonical_features = self.prior.kernel.cross_covariance(
                 test_inputs, train_data.X
-            )  # [n, N]
+            )
             function_space_contribution = jnp.matmul(
                 canonical_features, canonical_weights
             )
@@ -689,8 +675,8 @@ class NonConjugatePosterior(AbstractPosterior[P, NGL]):
 
         # Precompute lower triangular of Gram matrix, Lx, at training inputs, x
         Kxx = kernel.gram(x)
-        Kxx += I_like(Kxx) * self.prior.jitter
-        Kxx = PSD(Kxx)
+        Kxx_dense = Kxx.to_dense() + jnp.eye(Kxx.shape[0]) * self.prior.jitter
+        Kxx = psd(Dense(Kxx_dense))
         Lx = lower_cholesky(Kxx)
 
         # Unpack test inputs
@@ -702,7 +688,7 @@ class NonConjugatePosterior(AbstractPosterior[P, NGL]):
         mean_t = mean_function(t)
 
         # Lx⁻¹ Kxt
-        Lx_inv_Kxt = solve(Lx, Ktx.T, Cholesky())
+        Lx_inv_Kxt = solve(Lx, Ktx.T)
 
         # Whitened function values, wx, corresponding to the inputs, x
         wx = self.latent.value
@@ -711,9 +697,9 @@ class NonConjugatePosterior(AbstractPosterior[P, NGL]):
         mean = mean_t + jnp.matmul(Lx_inv_Kxt.T, wx)
 
         # Ktt - Ktx Kxx⁻¹ Kxt, TODO: Take advantage of covariance structure to compute Schur complement more efficiently.
-        covariance = Ktt - jnp.matmul(Lx_inv_Kxt.T, Lx_inv_Kxt)
-        covariance += I_like(covariance) * self.prior.jitter
-        covariance = PSD(covariance)
+        covariance = Ktt.to_dense() - jnp.matmul(Lx_inv_Kxt.T, Lx_inv_Kxt)
+        covariance += jnp.eye(covariance.shape[0]) * self.prior.jitter
+        covariance = psd(Dense(covariance))
 
         return GaussianDistribution(jnp.atleast_1d(mean.squeeze()), covariance)
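
For users of the high-level API these gps.py changes are internal: prediction still returns a GaussianDistribution, now backed by gpjax.linalg. A sketch of the unchanged workflow, assuming the 0.11.x-style public API (Prior, Gaussian likelihood, Dataset) carries over to 0.12.0:

import gpjax as gpx
import jax.numpy as jnp
import jax.random as jr

key = jr.PRNGKey(0)
X = jnp.linspace(0.0, 1.0, 20).reshape(-1, 1)
y = jnp.sin(2.0 * jnp.pi * X) + 0.1 * jr.normal(key, X.shape)
D = gpx.Dataset(X=X, y=y)

prior = gpx.gps.Prior(mean_function=gpx.mean_functions.Zero(), kernel=gpx.kernels.RBF())
posterior = prior * gpx.likelihoods.Gaussian(num_datapoints=D.n)

# predict() now assembles Sigma via psd(Dense(...)) and calls gpjax.linalg.solve internally.
latent_dist = posterior.predict(X, train_data=D)
latent_dist.loc       # predictive mean
latent_dist.variance  # predictive variance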
 
gpjax/kernels/base.py CHANGED
@@ -17,7 +17,6 @@ import abc
 import functools as ft
 
 import beartype.typing as tp
-from cola.ops.operator_base import LinearOperator
 from flax import nnx
 import jax.numpy as jnp
 from jaxtyping import (
@@ -29,6 +28,7 @@ from gpjax.kernels.computations import (
     AbstractKernelComputation,
     DenseKernelComputation,
 )
+from gpjax.linalg import LinearOperator
 from gpjax.parameters import (
     Parameter,
     Real,
gpjax/kernels/computations/base.py CHANGED
@@ -16,11 +16,6 @@
 import abc
 import typing as tp
 
-from cola.annotations import PSD
-from cola.ops.operators import (
-    Dense,
-    Diagonal,
-)
 from jax import vmap
 from jaxtyping import (
     Float,
@@ -28,6 +23,11 @@ from jaxtyping import (
 )
 
 import gpjax
+from gpjax.linalg import (
+    Dense,
+    Diagonal,
+    psd,
+)
 from gpjax.typing import Array
 
 K = tp.TypeVar("K", bound="gpjax.kernels.base.AbstractKernel")  # noqa: F821
@@ -69,7 +69,7 @@ class AbstractKernelComputation:
             The Gram covariance of the kernel function as a linear operator.
         """
         Kxx = self.cross_covariance(kernel, x, x)
-        return PSD(Dense(Kxx))
+        return psd(Dense(Kxx))
 
     @abc.abstractmethod
     def _cross_covariance(
@@ -93,7 +93,7 @@ class AbstractKernelComputation:
         return self._cross_covariance(kernel, x, y)
 
     def _diagonal(self, kernel: K, inputs: Num[Array, "N D"]) -> Diagonal:
-        return PSD(Diagonal(diag=vmap(lambda x: kernel(x, x))(inputs)))
+        return psd(Diagonal(vmap(lambda x: kernel(x, x))(inputs)))
 
     def diagonal(self, kernel: K, inputs: Num[Array, "N D"]) -> Diagonal:
         r"""For a given kernel, compute the elementwise diagonal of the
gpjax/kernels/computations/basis_functions.py CHANGED
@@ -1,18 +1,19 @@
 import typing as tp
 
-from cola.annotations import PSD
-from cola.ops.operators import Dense
 import jax.numpy as jnp
 from jaxtyping import Float
 
 import gpjax
 from gpjax.kernels.computations.base import AbstractKernelComputation
+from gpjax.linalg import (
+    Dense,
+    Diagonal,
+    psd,
+)
 from gpjax.typing import Array
 
 K = tp.TypeVar("K", bound="gpjax.kernels.approximations.RFF")  # noqa: F821
 
-from cola.ops import Diagonal
-
 # TODO: Use low rank linear operator!
 
 
@@ -28,7 +29,7 @@ class BasisFunctionComputation(AbstractKernelComputation):
 
     def _gram(self, kernel: K, inputs: Float[Array, "N D"]) -> Dense:
         z1 = self.compute_features(kernel, inputs)
-        return PSD(Dense(self.scaling(kernel) * jnp.matmul(z1, z1.T)))
+        return psd(Dense(self.scaling(kernel) * jnp.matmul(z1, z1.T)))
 
     def diagonal(self, kernel: K, inputs: Float[Array, "N D"]) -> Diagonal:
         r"""For a given kernel, compute the elementwise diagonal of the
gpjax/kernels/computations/constant_diagonal.py CHANGED
@@ -15,36 +15,34 @@
 
 import typing as tp
 
-from cola.annotations import PSD
-from cola.ops.operators import (
-    Diagonal,
-    Identity,
-    Product,
-)
 from jax import vmap
 import jax.numpy as jnp
 from jaxtyping import Float
 
 import gpjax
 from gpjax.kernels.computations import AbstractKernelComputation
+from gpjax.linalg import (
+    Diagonal,
+    psd,
+)
 from gpjax.typing import Array
 
 K = tp.TypeVar("K", bound="gpjax.kernels.base.AbstractKernel")  # noqa: F821
-ConstantDiagonalType = Product
+ConstantDiagonalType = Diagonal
 
 
 class ConstantDiagonalKernelComputation(AbstractKernelComputation):
     r"""Computation engine for constant diagonal kernels."""
 
-    def gram(self, kernel: K, x: Float[Array, "N D"]) -> Product:
+    def gram(self, kernel: K, x: Float[Array, "N D"]) -> Diagonal:
         value = kernel(x[0], x[0])
-        dtype = value.dtype
-        shape = (x.shape[0], x.shape[0])
-        return PSD(jnp.atleast_1d(value) * Identity(shape=shape, dtype=dtype))
+        # Create a diagonal matrix with constant values
+        diag = jnp.full(x.shape[0], value)
+        return psd(Diagonal(diag))
 
     def _diagonal(self, kernel: K, inputs: Float[Array, "N D"]) -> Diagonal:
        diag = vmap(lambda x: kernel(x, x))(inputs)
-        return PSD(Diagonal(diag=diag))
+        return psd(Diagonal(diag))
 
     def _cross_covariance(
         self, kernel: K, x: Float[Array, "N D"], y: Float[Array, "M D"]
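
The behavioural change here is representational: a constant-diagonal Gram matrix is now a Diagonal operator filled with the constant value, rather than a CoLA Product of a scalar and an Identity. Roughly, and assuming to_dense() behaves as it is used on operators elsewhere in this diff:

import jax.numpy as jnp

from gpjax.linalg import Diagonal, psd

value = jnp.array(2.5)   # kernel(x[0], x[0]) for a constant-diagonal kernel
n = 4                    # number of training inputs
gram = psd(Diagonal(jnp.full(n, value)))
gram.to_dense()          # 4x4 matrix with 2.5 on the diagonal, zeros elsewhere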
gpjax/kernels/computations/diagonal.py CHANGED
@@ -14,16 +14,16 @@
 # ==============================================================================
 
 import beartype.typing as tp
-from cola.annotations import PSD
-from cola.ops.operators import (
-    Diagonal,
-    LinearOperator,
-)
 from jax import vmap
 from jaxtyping import Float
 
 import gpjax  # noqa: F401
 from gpjax.kernels.computations import AbstractKernelComputation
+from gpjax.linalg import (
+    Diagonal,
+    LinearOperator,
+    psd,
+)
 from gpjax.typing import Array
 
 Kernel = tp.TypeVar("Kernel", bound="gpjax.kernels.base.AbstractKernel")  # noqa: F821
@@ -35,7 +35,7 @@ class DiagonalKernelComputation(AbstractKernelComputation):
     """
 
     def gram(self, kernel: Kernel, x: Float[Array, "N D"]) -> LinearOperator:
-        return PSD(Diagonal(diag=vmap(lambda x: kernel(x, x))(x)))
+        return psd(Diagonal(vmap(lambda x: kernel(x, x))(x)))
 
     def _cross_covariance(
         self, kernel: Kernel, x: Float[Array, "N D"], y: Float[Array, "M D"]
gpjax/linalg/__init__.py ADDED
@@ -0,0 +1,37 @@
+"""Linear algebra module for GPJax."""
+
+from gpjax.linalg.operations import (
+    diag,
+    logdet,
+    lower_cholesky,
+    solve,
+)
+from gpjax.linalg.operators import (
+    BlockDiag,
+    Dense,
+    Diagonal,
+    Identity,
+    Kronecker,
+    LinearOperator,
+    Triangular,
+)
+from gpjax.linalg.utils import (
+    PSD,
+    psd,
+)
+
+__all__ = [
+    "LinearOperator",
+    "Dense",
+    "Diagonal",
+    "Identity",
+    "Triangular",
+    "BlockDiag",
+    "Kronecker",
+    "lower_cholesky",
+    "solve",
+    "logdet",
+    "diag",
+    "psd",
+    "PSD",
+]
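
This new module collects the operators and operations the rest of this diff switches to. A hedged usage sketch, using only names from the __all__ list above; the Cholesky-backed behaviour of solve and logdet on PSD-tagged operators is inferred from the CoLA calls they replace elsewhere in this diff:

import jax.numpy as jnp

from gpjax.linalg import Dense, diag, logdet, lower_cholesky, psd, solve

A = psd(Dense(jnp.array([[4.0, 1.0], [1.0, 3.0]])))
b = jnp.array([1.0, 2.0])

L = lower_cholesky(A)   # lower-triangular factor with A = L Lᵀ
x = solve(A, b)         # replaces cola.solve(A, b, Cholesky())
ld = logdet(A)          # replaces cola.logdet(A, Cholesky(), Cholesky())
d = diag(A)             # replaces cola.diag(A)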