gpjax 0.11.1__py3-none-any.whl → 0.12.0__py3-none-any.whl
This diff compares publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between versions as they appear in their public registries.
- gpjax/__init__.py +1 -1
- gpjax/citation.py +7 -2
- gpjax/distributions.py +16 -56
- gpjax/fit.py +3 -3
- gpjax/gps.py +34 -48
- gpjax/kernels/base.py +1 -1
- gpjax/kernels/computations/base.py +7 -7
- gpjax/kernels/computations/basis_functions.py +6 -5
- gpjax/kernels/computations/constant_diagonal.py +10 -12
- gpjax/kernels/computations/diagonal.py +6 -6
- gpjax/linalg/__init__.py +37 -0
- gpjax/linalg/operations.py +237 -0
- gpjax/linalg/operators.py +411 -0
- gpjax/linalg/utils.py +33 -0
- gpjax/objectives.py +21 -21
- gpjax/parameters.py +11 -13
- gpjax/variational_families.py +43 -37
- {gpjax-0.11.1.dist-info → gpjax-0.12.0.dist-info}/METADATA +49 -9
- {gpjax-0.11.1.dist-info → gpjax-0.12.0.dist-info}/RECORD +21 -18
- gpjax/lower_cholesky.py +0 -69
- {gpjax-0.11.1.dist-info → gpjax-0.12.0.dist-info}/WHEEL +0 -0
- {gpjax-0.11.1.dist-info → gpjax-0.12.0.dist-info}/licenses/LICENSE.txt +0 -0
gpjax/__init__.py
CHANGED

@@ -40,7 +40,7 @@ __license__ = "MIT"
 __description__ = "Gaussian processes in JAX and Flax"
 __url__ = "https://github.com/JaxGaussianProcesses/GPJax"
 __contributors__ = "https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors"
-__version__ = "0.11.1"
+__version__ = "0.12.0"
 
 __all__ = [
     "base",
gpjax/citation.py
CHANGED

@@ -8,7 +8,12 @@ from beartype.typing import (
     Dict,
     Union,
 )
-
+
+try:
+    # safely removable once jax>=0.6.0
+    from jaxlib.xla_extension import PjitFunction
+except ModuleNotFoundError:
+    from jaxlib._jax import PjitFunction
 
 from gpjax.kernels import (
     RFF,

@@ -45,7 +50,7 @@ class AbstractCitation:
 
 
 class NullCitation(AbstractCitation):
-    def
+    def as_str(self) -> str:
         return (
             "No citation available. If you think this is an error, please open a pull"
             " request."
gpjax/distributions.py
CHANGED

@@ -17,11 +17,6 @@
 from beartype.typing import (
     Optional,
 )
-import cola
-from cola.linalg.decompositions import Cholesky
-from cola.ops import (
-    LinearOperator,
-)
 from jax import vmap
 import jax.numpy as jnp
 import jax.random as jr

@@ -30,7 +25,14 @@ from numpyro.distributions import constraints
 from numpyro.distributions.distribution import Distribution
 from numpyro.distributions.util import is_prng_key
 
-from gpjax.
+from gpjax.linalg.operations import (
+    diag,
+    logdet,
+    lower_cholesky,
+    solve,
+)
+from gpjax.linalg.operators import LinearOperator
+from gpjax.linalg.utils import psd
 from gpjax.typing import (
     Array,
     ScalarFloat,

@@ -47,7 +49,7 @@ class GaussianDistribution(Distribution):
         validate_args=None,
     ):
         self.loc = loc
-        self.scale =
+        self.scale = psd(scale)
         batch_shape = ()
         event_shape = jnp.shape(self.loc)
         super().__init__(batch_shape, event_shape, validate_args=validate_args)

@@ -76,13 +78,12 @@
     @property
     def variance(self) -> Float[Array, " N"]:
         r"""Calculates the variance."""
-        return
+        return diag(self.scale)
 
     def entropy(self) -> ScalarFloat:
         r"""Calculates the entropy of the distribution."""
         return 0.5 * (
-            self.event_shape[0] * (1.0 + jnp.log(2.0 * jnp.pi))
-            + cola.logdet(self.scale, Cholesky(), Cholesky())
+            self.event_shape[0] * (1.0 + jnp.log(2.0 * jnp.pi)) + logdet(self.scale)
         )
 
     def median(self) -> Float[Array, " N"]:

@@ -104,7 +105,7 @@
 
     def stddev(self) -> Float[Array, " N"]:
         r"""Calculates the standard deviation."""
-        return jnp.sqrt(
+        return jnp.sqrt(diag(self.scale))
 
     # @property
     # def event_shape(self) -> Tuple:

@@ -129,9 +130,7 @@
 
         # compute the pdf, -1/2[ n log(2π) + log|Σ| + (y - µ)ᵀΣ⁻¹(y - µ) ]
         return -0.5 * (
-            n * jnp.log(2.0 * jnp.pi)
-            + cola.logdet(sigma, Cholesky(), Cholesky())
-            + diff.T @ cola.solve(sigma, diff, Cholesky())
+            n * jnp.log(2.0 * jnp.pi) + logdet(sigma) + diff.T @ solve(sigma, diff)
        )
 
    # def _sample_n(self, key: KeyArray, n: int) -> Float[Array, "n N"]:

@@ -219,53 +218,14 @@ def _kl_divergence(q: GaussianDistribution, p: GaussianDistribution) -> ScalarFloat:
 
     # trace term, tr[Σp⁻¹ Σq] = tr[(LpLpᵀ)⁻¹(LqLqᵀ)] = tr[(Lp⁻¹Lq)(Lp⁻¹Lq)ᵀ] = (fr[LqLp⁻¹])²
     trace = _frobenius_norm_squared(
-
+        solve(sqrt_p, sqrt_q.to_dense())
     )  # TODO: Not most efficient, given the `to_dense()` call (e.g., consider diagonal p and q). Need to abstract solving linear operator against another linear operator.
 
     # Mahalanobis term, (μp - μq)ᵀ Σp⁻¹ (μp - μq) = tr [(μp - μq)ᵀ [LpLpᵀ]⁻¹ (μp - μq)] = (fr[Lp⁻¹(μp - μq)])²
-    mahalanobis = jnp.sum(jnp.square(
+    mahalanobis = jnp.sum(jnp.square(solve(sqrt_p, diff)))
 
     # KL[q(x)||p(x)] = [ [(μp - μq)ᵀ Σp⁻¹ (μp - μq)] - n - log|Σq| + log|Σp| + tr[Σp⁻¹ Σq] ] / 2
-    return (
-        mahalanobis
-        - n_dim
-        - cola.logdet(sigma_q, Cholesky(), Cholesky())
-        + cola.logdet(sigma_p, Cholesky(), Cholesky())
-        + trace
-    ) / 2.0
-
-
-# def _check_loc_scale(loc: Optional[Any], scale: Optional[Any]) -> None:
-#     r"""Checks that the inputs are correct."""
-#     if loc is None and scale is None:
-#         raise ValueError("At least one of `loc` or `scale` must be specified.")
-
-#     if loc is not None and loc.ndim < 1:
-#         raise ValueError("The parameter `loc` must have at least one dimension.")
-
-#     if scale is not None and len(scale.shape) < 2:  # scale.ndim < 2:
-#         raise ValueError(
-#             "The `scale` must have at least two dimensions, but "
-#             f"`scale.shape = {scale.shape}`."
-#         )
-
-#     if scale is not None and not isinstance(scale, LinearOperator):
-#         raise ValueError(
-#             f"The `scale` must be a CoLA LinearOperator but got {type(scale)}"
-#         )
-
-#     if scale is not None and (scale.shape[-1] != scale.shape[-2]):
-#         raise ValueError(
-#             f"The `scale` must be a square matrix, but `scale.shape = {scale.shape}`."
-#         )
-
-#     if loc is not None:
-#         num_dims = loc.shape[-1]
-#         if scale is not None and (scale.shape[-1] != num_dims):
-#             raise ValueError(
-#                 f"Shapes are not compatible: `loc.shape = {loc.shape}` and "
-#                 f"`scale.shape = {scale.shape}`."
-#             )
+    return (mahalanobis - n_dim - logdet(sigma_q) + logdet(sigma_p) + trace) / 2.0
 
 
 __all__ = [
gpjax/fit.py
CHANGED

@@ -15,13 +15,13 @@
 
 import typing as tp
 
+from flax import nnx
 import jax
+from jax.flatten_util import ravel_pytree
 import jax.numpy as jnp
 import jax.random as jr
-import optax as ox
-from flax import nnx
-from jax.flatten_util import ravel_pytree
 from numpyro.distributions.transforms import Transform
+import optax as ox
 from scipy.optimize import minimize
 
 from gpjax.dataset import Dataset
gpjax/gps.py
CHANGED

@@ -16,11 +16,6 @@
 from abc import abstractmethod
 
 import beartype.typing as tp
-from cola.annotations import PSD
-from cola.linalg.algorithm_base import Algorithm
-from cola.linalg.decompositions.decompositions import Cholesky
-from cola.linalg.inverse.inv import solve
-from cola.ops.operators import I_like
 from flax import nnx
 import jax.numpy as jnp
 import jax.random as jr

@@ -38,7 +33,13 @@ from gpjax.likelihoods import (
     Gaussian,
     NonGaussian,
 )
-from gpjax.
+from gpjax.linalg import (
+    Dense,
+    Identity,
+    psd,
+    solve,
+)
+from gpjax.linalg.operations import lower_cholesky
 from gpjax.mean_functions import AbstractMeanFunction
 from gpjax.parameters import (
     Parameter,

@@ -251,8 +252,8 @@ class Prior(AbstractPrior[M, K]):
         x = test_inputs
         mx = self.mean_function(x)
         Kxx = self.kernel.gram(x)
-        Kxx
-        Kxx =
+        Kxx_dense = Kxx.to_dense() + Identity(Kxx.shape).to_dense() * self.jitter
+        Kxx = psd(Dense(Kxx_dense))
 
         return GaussianDistribution(jnp.atleast_1d(mx.squeeze()), Kxx)
 

@@ -315,15 +316,13 @@
         if (not isinstance(num_samples, int)) or num_samples <= 0:
             raise ValueError("num_samples must be a positive integer")
 
-        # sample fourier features
         fourier_feature_fn = _build_fourier_features_fn(self, num_features, key)
 
-
-        feature_weights = jr.normal(key, [num_samples, 2 * num_features])  # [B, L]
+        feature_weights = jr.normal(key, [num_samples, 2 * num_features])
 
         def sample_fn(test_inputs: Float[Array, "N D"]) -> Float[Array, "N B"]:
-            feature_evals = fourier_feature_fn(test_inputs)
-            evaluated_sample = jnp.inner(feature_evals, feature_weights)
+            feature_evals = fourier_feature_fn(test_inputs)
+            evaluated_sample = jnp.inner(feature_evals, feature_weights)
             return self.mean_function(test_inputs) + evaluated_sample
 
         return sample_fn

@@ -504,24 +503,23 @@ class ConjugatePosterior(AbstractPosterior[P, GL]):
 
         # Precompute Gram matrix, Kxx, at training inputs, x
         Kxx = self.prior.kernel.gram(x)
-        Kxx
+        Kxx_dense = Kxx.to_dense() + Identity(Kxx.shape).to_dense() * self.jitter
+        Kxx = Dense(Kxx_dense)
 
-
-        Sigma =
-        Sigma = PSD(Sigma)
+        Sigma_dense = Kxx.to_dense() + jnp.eye(Kxx.shape[0]) * obs_noise
+        Sigma = psd(Dense(Sigma_dense))
 
         mean_t = self.prior.mean_function(t)
         Ktt = self.prior.kernel.gram(t)
         Kxt = self.prior.kernel.cross_covariance(x, t)
-        Sigma_inv_Kxt = solve(Sigma, Kxt
+        Sigma_inv_Kxt = solve(Sigma, Kxt)
 
-        # μt + Ktx (Kxx + Io²)⁻¹ (y - μx)
         mean = mean_t + jnp.matmul(Sigma_inv_Kxt.T, y - mx)
 
         # Ktt - Ktx (Kxx + Io²)⁻¹ Kxt, TODO: Take advantage of covariance structure to compute Schur complement more efficiently.
-        covariance = Ktt - jnp.matmul(Kxt.T, Sigma_inv_Kxt)
-        covariance +=
-        covariance =
+        covariance = Ktt.to_dense() - jnp.matmul(Kxt.T, Sigma_inv_Kxt)
+        covariance += jnp.eye(covariance.shape[0]) * self.prior.jitter
+        covariance = psd(Dense(covariance))
 
         return GaussianDistribution(jnp.atleast_1d(mean.squeeze()), covariance)
 

@@ -531,7 +529,6 @@ class ConjugatePosterior(AbstractPosterior[P, GL]):
         train_data: Dataset,
         key: KeyArray,
         num_features: int | None = 100,
-        solver_algorithm: tp.Optional[Algorithm] = Cholesky(),
     ) -> FunctionalSample:
         r"""Draw approximate samples from the Gaussian process posterior.
 

@@ -565,11 +562,6 @@
             key (KeyArray): The random seed used for the sample(s).
             num_features (int): The number of features used when approximating the
                 kernel.
-            solver_algorithm (Optional[Algorithm], optional): The algorithm to use for the solves of
-                the inverse of the covariance matrix. See the
-                [CoLA documentation](https://cola.readthedocs.io/en/latest/package/cola.linalg.html#algorithms)
-                for which solver to pick. For PSD matrices, CoLA currently recommends Cholesky() for small
-                matrices and CG() for larger matrices. Select Auto() to let CoLA decide. Defaults to Cholesky().
 
         Returns:
             FunctionalSample: A function representing an approximate sample from the Gaussian

@@ -581,31 +573,25 @@
         # sample fourier features
         fourier_feature_fn = _build_fourier_features_fn(self.prior, num_features, key)
 
-
-        fourier_weights = jr.normal(key, [num_samples, 2 * num_features])  # [B, L]
+        fourier_weights = jr.normal(key, [num_samples, 2 * num_features])
 
-        # sample weights v for canonical features
-        # v = Σ⁻¹ (y + ε - ɸ⍵) for Σ = Kxx + Io² and ε ᯈ N(0, o²)
         obs_var = self.likelihood.obs_stddev.value**2
-        Kxx = self.prior.kernel.gram(train_data.X)
-        Sigma = Kxx +
-        eps = jnp.sqrt(obs_var) * jr.normal(key, [train_data.n, num_samples])
-        y = train_data.y - self.prior.mean_function(train_data.X)
+        Kxx = self.prior.kernel.gram(train_data.X)
+        Sigma = Kxx + jnp.eye(Kxx.shape[0]) * (obs_var + self.jitter)
+        eps = jnp.sqrt(obs_var) * jr.normal(key, [train_data.n, num_samples])
+        y = train_data.y - self.prior.mean_function(train_data.X)
         Phi = fourier_feature_fn(train_data.X)
         canonical_weights = solve(
             Sigma,
             y + eps - jnp.inner(Phi, fourier_weights),
-            solver_algorithm,
         )  # [N, B]
 
         def sample_fn(test_inputs: Float[Array, "n D"]) -> Float[Array, "n B"]:
-            fourier_features = fourier_feature_fn(test_inputs)
-            weight_space_contribution = jnp.inner(
-                fourier_features, fourier_weights
-            )  # [n, B]
+            fourier_features = fourier_feature_fn(test_inputs)
+            weight_space_contribution = jnp.inner(fourier_features, fourier_weights)
             canonical_features = self.prior.kernel.cross_covariance(
                 test_inputs, train_data.X
-            )
+            )
             function_space_contribution = jnp.matmul(
                 canonical_features, canonical_weights
             )

@@ -689,8 +675,8 @@ class NonConjugatePosterior(AbstractPosterior[P, NGL]):
 
         # Precompute lower triangular of Gram matrix, Lx, at training inputs, x
         Kxx = kernel.gram(x)
-        Kxx
-        Kxx =
+        Kxx_dense = Kxx.to_dense() + jnp.eye(Kxx.shape[0]) * self.prior.jitter
+        Kxx = psd(Dense(Kxx_dense))
         Lx = lower_cholesky(Kxx)
 
         # Unpack test inputs

@@ -702,7 +688,7 @@
         mean_t = mean_function(t)
 
         # Lx⁻¹ Kxt
-        Lx_inv_Kxt = solve(Lx, Ktx.T
+        Lx_inv_Kxt = solve(Lx, Ktx.T)
 
         # Whitened function values, wx, corresponding to the inputs, x
         wx = self.latent.value

@@ -711,9 +697,9 @@
         mean = mean_t + jnp.matmul(Lx_inv_Kxt.T, wx)
 
         # Ktt - Ktx Kxx⁻¹ Kxt, TODO: Take advantage of covariance structure to compute Schur complement more efficiently.
-        covariance = Ktt - jnp.matmul(Lx_inv_Kxt.T, Lx_inv_Kxt)
-        covariance +=
-        covariance =
+        covariance = Ktt.to_dense() - jnp.matmul(Lx_inv_Kxt.T, Lx_inv_Kxt)
+        covariance += jnp.eye(covariance.shape[0]) * self.prior.jitter
+        covariance = psd(Dense(covariance))
 
         return GaussianDistribution(jnp.atleast_1d(mean.squeeze()), covariance)
 
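The user-visible change in this file is that `sample_approx` loses its `solver_algorithm` keyword: every solve now goes through `gpjax.linalg.solve` against a jitter-stabilised `Sigma`, with no CoLA algorithm object to choose. A sketch of a 0.12.0-style call — constructor names follow GPJax's public API, and the RBF/Zero choices are purely illustrative:

```python
import jax.numpy as jnp
import jax.random as jr

import gpjax as gpx

key = jr.PRNGKey(42)
X = jnp.linspace(0.0, 1.0, 20).reshape(-1, 1)
y = jnp.sin(2.0 * jnp.pi * X)
D = gpx.Dataset(X=X, y=y)

prior = gpx.gps.Prior(mean_function=gpx.mean_functions.Zero(), kernel=gpx.kernels.RBF())
posterior = prior * gpx.likelihoods.Gaussian(num_datapoints=D.n)

# 0.11.x additionally accepted solver_algorithm=Cholesky(); 0.12.0 removes it.
sample_fn = posterior.sample_approx(
    num_samples=5, train_data=D, key=key, num_features=100
)
print(sample_fn(jnp.linspace(0.0, 1.0, 50).reshape(-1, 1)).shape)  # (50, 5)
```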
gpjax/kernels/base.py
CHANGED

@@ -17,7 +17,6 @@ import abc
 import functools as ft
 
 import beartype.typing as tp
-from cola.ops.operator_base import LinearOperator
 from flax import nnx
 import jax.numpy as jnp
 from jaxtyping import (

@@ -29,6 +28,7 @@ from gpjax.kernels.computations import (
     AbstractKernelComputation,
     DenseKernelComputation,
 )
+from gpjax.linalg import LinearOperator
 from gpjax.parameters import (
     Parameter,
     Real,
gpjax/kernels/computations/base.py
CHANGED

@@ -16,11 +16,6 @@
 import abc
 import typing as tp
 
-from cola.annotations import PSD
-from cola.ops.operators import (
-    Dense,
-    Diagonal,
-)
 from jax import vmap
 from jaxtyping import (
     Float,

@@ -28,6 +23,11 @@ from jaxtyping import (
 )
 
 import gpjax
+from gpjax.linalg import (
+    Dense,
+    Diagonal,
+    psd,
+)
 from gpjax.typing import Array
 
 K = tp.TypeVar("K", bound="gpjax.kernels.base.AbstractKernel")  # noqa: F821

@@ -69,7 +69,7 @@ class AbstractKernelComputation:
         The Gram covariance of the kernel function as a linear operator.
         """
         Kxx = self.cross_covariance(kernel, x, x)
-        return
+        return psd(Dense(Kxx))
 
     @abc.abstractmethod
     def _cross_covariance(

@@ -93,7 +93,7 @@ class AbstractKernelComputation:
         return self._cross_covariance(kernel, x, y)
 
     def _diagonal(self, kernel: K, inputs: Num[Array, "N D"]) -> Diagonal:
-        return
+        return psd(Diagonal(vmap(lambda x: kernel(x, x))(inputs)))
 
     def diagonal(self, kernel: K, inputs: Num[Array, "N D"]) -> Diagonal:
         r"""For a given kernel, compute the elementwise diagonal of the
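These two defaults mean every computation engine now returns `gpjax.linalg` operators: `gram` wraps the dense N×N cross-covariance in `psd(Dense(...))`, while `_diagonal` vmaps k(xᵢ, xᵢ) over the inputs into a `Diagonal`. A sketch of what a caller sees, assuming a kernel on the default dense engine (operator attributes such as `.to_dense()` are inferred from their use elsewhere in this diff):

```python
import jax.numpy as jnp

from gpjax.kernels import RBF

kernel = RBF()
x = jnp.linspace(0.0, 1.0, 5).reshape(-1, 1)

Kxx = kernel.gram(x)         # psd-tagged Dense operator over the 5 training points
print(Kxx.shape)             # (5, 5)
print(Kxx.to_dense()[0, 0])  # k(x0, x0)

# the diagonal path evaluates only the N pointwise variances
Kdiag = kernel.compute_engine.diagonal(kernel, x)
print(Kdiag.shape)           # (5, 5), but stored as 5 diagonal entries
```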
gpjax/kernels/computations/basis_functions.py
CHANGED

@@ -1,18 +1,19 @@
 import typing as tp
 
-from cola.annotations import PSD
-from cola.ops.operators import Dense
 import jax.numpy as jnp
 from jaxtyping import Float
 
 import gpjax
 from gpjax.kernels.computations.base import AbstractKernelComputation
+from gpjax.linalg import (
+    Dense,
+    Diagonal,
+    psd,
+)
 from gpjax.typing import Array
 
 K = tp.TypeVar("K", bound="gpjax.kernels.approximations.RFF")  # noqa: F821
 
-from cola.ops import Diagonal
-
 # TODO: Use low rank linear operator!
 
 

@@ -28,7 +29,7 @@ class BasisFunctionComputation(AbstractKernelComputation):
 
     def _gram(self, kernel: K, inputs: Float[Array, "N D"]) -> Dense:
         z1 = self.compute_features(kernel, inputs)
-        return
+        return psd(Dense(self.scaling(kernel) * jnp.matmul(z1, z1.T)))
 
     def diagonal(self, kernel: K, inputs: Float[Array, "N D"]) -> Diagonal:
         r"""For a given kernel, compute the elementwise diagonal of the
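The retained `# TODO: Use low rank linear operator!` is the key observation here: the RFF Gram `scaling * Z Zᵀ` has rank at most 2L for an [N, 2L] feature matrix Z, yet `_gram` still materialises and stores it densely. A numeric illustration of that structure, with `Z` and `scaling` standing in for `compute_features(kernel, inputs)` and `self.scaling(kernel)`:

```python
import jax.numpy as jnp
import jax.random as jr

N, L = 50, 4
Z = jr.normal(jr.PRNGKey(0), (N, 2 * L))  # stands in for compute_features(kernel, inputs)
scaling = 1.0 / (2 * L)                   # illustrative scaling constant
K = scaling * jnp.matmul(Z, Z.T)          # what _gram wraps in psd(Dense(...))

print(K.shape)                    # (50, 50) dense storage, but ...
print(jnp.linalg.matrix_rank(K))  # ... rank at most 2L = 8
```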
gpjax/kernels/computations/constant_diagonal.py
CHANGED

@@ -15,36 +15,34 @@
 
 import typing as tp
 
-from cola.annotations import PSD
-from cola.ops.operators import (
-    Diagonal,
-    Identity,
-    Product,
-)
 from jax import vmap
 import jax.numpy as jnp
 from jaxtyping import Float
 
 import gpjax
 from gpjax.kernels.computations import AbstractKernelComputation
+from gpjax.linalg import (
+    Diagonal,
+    psd,
+)
 from gpjax.typing import Array
 
 K = tp.TypeVar("K", bound="gpjax.kernels.base.AbstractKernel")  # noqa: F821
-ConstantDiagonalType =
+ConstantDiagonalType = Diagonal
 
 
 class ConstantDiagonalKernelComputation(AbstractKernelComputation):
     r"""Computation engine for constant diagonal kernels."""
 
-    def gram(self, kernel: K, x: Float[Array, "N D"]) ->
+    def gram(self, kernel: K, x: Float[Array, "N D"]) -> Diagonal:
         value = kernel(x[0], x[0])
-
-
-        return
+        # Create a diagonal matrix with constant values
+        diag = jnp.full(x.shape[0], value)
+        return psd(Diagonal(diag))
 
     def _diagonal(self, kernel: K, inputs: Float[Array, "N D"]) -> Diagonal:
         diag = vmap(lambda x: kernel(x, x))(inputs)
-        return
+        return psd(Diagonal(diag))
 
     def _cross_covariance(
         self, kernel: K, x: Float[Array, "N D"], y: Float[Array, "M D"]
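The removed `Identity` and `Product` imports suggest the 0.11.x engine represented this Gram matrix structurally, as roughly `value * Identity`, whereas 0.12.0 materialises the repeated value with `jnp.full` and wraps it in `Diagonal`. The two representations agree, as a quick check shows (`value` standing in for `kernel(x[0], x[0])`):

```python
import jax.numpy as jnp

value, n = 2.5, 4                         # value stands in for kernel(x[0], x[0])
old_style = value * jnp.eye(n)            # value * Identity, densified
new_style = jnp.diag(jnp.full(n, value))  # Diagonal(jnp.full(n, value)), densified
assert jnp.allclose(old_style, new_style)
```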
gpjax/kernels/computations/diagonal.py
CHANGED

@@ -14,16 +14,16 @@
 # ==============================================================================
 
 import beartype.typing as tp
-from cola.annotations import PSD
-from cola.ops.operators import (
-    Diagonal,
-    LinearOperator,
-)
 from jax import vmap
 from jaxtyping import Float
 
 import gpjax  # noqa: F401
 from gpjax.kernels.computations import AbstractKernelComputation
+from gpjax.linalg import (
+    Diagonal,
+    LinearOperator,
+    psd,
+)
 from gpjax.typing import Array
 
 Kernel = tp.TypeVar("Kernel", bound="gpjax.kernels.base.AbstractKernel")  # noqa: F821

@@ -35,7 +35,7 @@ class DiagonalKernelComputation(AbstractKernelComputation):
     """
 
     def gram(self, kernel: Kernel, x: Float[Array, "N D"]) -> LinearOperator:
-        return
+        return psd(Diagonal(vmap(lambda x: kernel(x, x))(x)))
 
     def _cross_covariance(
         self, kernel: Kernel, x: Float[Array, "N D"], y: Float[Array, "M D"]
gpjax/linalg/__init__.py
ADDED

@@ -0,0 +1,37 @@
+"""Linear algebra module for GPJax."""
+
+from gpjax.linalg.operations import (
+    diag,
+    logdet,
+    lower_cholesky,
+    solve,
+)
+from gpjax.linalg.operators import (
+    BlockDiag,
+    Dense,
+    Diagonal,
+    Identity,
+    Kronecker,
+    LinearOperator,
+    Triangular,
+)
+from gpjax.linalg.utils import (
+    PSD,
+    psd,
+)
+
+__all__ = [
+    "LinearOperator",
+    "Dense",
+    "Diagonal",
+    "Identity",
+    "Triangular",
+    "BlockDiag",
+    "Kronecker",
+    "lower_cholesky",
+    "solve",
+    "logdet",
+    "diag",
+    "psd",
+    "PSD",
+]