gpjax 0.11.2__py3-none-any.whl → 0.12.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- gpjax/__init__.py +1 -1
- gpjax/distributions.py +16 -56
- gpjax/gps.py +34 -48
- gpjax/kernels/base.py +1 -1
- gpjax/kernels/computations/base.py +7 -7
- gpjax/kernels/computations/basis_functions.py +6 -5
- gpjax/kernels/computations/constant_diagonal.py +10 -12
- gpjax/kernels/computations/diagonal.py +6 -6
- gpjax/linalg/__init__.py +37 -0
- gpjax/linalg/operations.py +237 -0
- gpjax/linalg/operators.py +411 -0
- gpjax/linalg/utils.py +33 -0
- gpjax/objectives.py +21 -21
- gpjax/parameters.py +11 -13
- gpjax/variational_families.py +43 -37
- {gpjax-0.11.2.dist-info → gpjax-0.12.0.dist-info}/METADATA +49 -9
- {gpjax-0.11.2.dist-info → gpjax-0.12.0.dist-info}/RECORD +19 -16
- gpjax/lower_cholesky.py +0 -69
- {gpjax-0.11.2.dist-info → gpjax-0.12.0.dist-info}/WHEEL +0 -0
- {gpjax-0.11.2.dist-info → gpjax-0.12.0.dist-info}/licenses/LICENSE.txt +0 -0
gpjax/objectives.py
CHANGED
@@ -1,13 +1,5 @@
 from typing import TypeVar
 
-from cola.annotations import PSD
-from cola.linalg.decompositions.decompositions import Cholesky
-from cola.linalg.inverse.inv import (
-    inv,
-    solve,
-)
-from cola.linalg.trace.diag_trace import diag
-from cola.ops.operators import I_like
 from flax import nnx
 from jax import vmap
 import jax.numpy as jnp
@@ -22,7 +14,12 @@ from gpjax.gps import (
     ConjugatePosterior,
     NonConjugatePosterior,
 )
-from gpjax.lower_cholesky import lower_cholesky
+from gpjax.linalg import (
+    Dense,
+    lower_cholesky,
+    psd,
+    solve,
+)
 from gpjax.typing import (
     Array,
     ScalarFloat,
@@ -100,9 +97,9 @@ def conjugate_mll(posterior: ConjugatePosterior, data: Dataset) -> ScalarFloat:
 
     # Σ = (Kxx + Io²) = LLᵀ
     Kxx = posterior.prior.kernel.gram(x)
-    Kxx += I_like(Kxx) * posterior.prior.jitter
-
-    Sigma = PSD(Kxx + I_like(Kxx) * obs_noise)
+    Kxx_dense = Kxx.to_dense() + jnp.eye(Kxx.shape[0]) * posterior.prior.jitter
+    Sigma_dense = Kxx_dense + jnp.eye(Kxx.shape[0]) * obs_noise
+    Sigma = psd(Dense(Sigma_dense))
 
     # p(y | x, θ), where θ are the model hyperparameters:
     mll = GaussianDistribution(jnp.atleast_1d(mx.squeeze()), Sigma)
@@ -164,11 +161,14 @@ def conjugate_loocv(posterior: ConjugatePosterior, data: Dataset) -> ScalarFloat
 
     # Σ = (Kxx + Io²)
     Kxx = posterior.prior.kernel.gram(x)
-    Kxx += I_like(Kxx) * posterior.prior.jitter
-    Sigma = PSD(Kxx + I_like(Kxx) * obs_var)
+    Sigma_dense = Kxx.to_dense() + jnp.eye(Kxx.shape[0]) * (
+        obs_var + posterior.prior.jitter
+    )
+    Sigma = psd(Dense(Sigma_dense))  # [N, N]
 
-    Sigma_inv_y = solve(Sigma, y - mx, Cholesky())  # [N, 1]
-    Sigma_inv_diag = diag(inv(Sigma, Cholesky()))[:, None]  # [N, 1]
+    Sigma_inv_y = solve(Sigma, y - mx)  # [N, 1]
+    Sigma_inv = jnp.linalg.inv(Sigma.to_dense())
+    Sigma_inv_diag = jnp.diag(Sigma_inv)[:, None]  # [N, 1]
 
     loocv_means = mx + (y - mx) - Sigma_inv_y / Sigma_inv_diag
     loocv_stds = jnp.sqrt(1.0 / Sigma_inv_diag)
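For reference, the hunk above computes the standard closed-form leave-one-out identities for a Gaussian model; the derivation is not part of the diff itself, but it matches `loocv_means` and `loocv_stds`:

$$\mu_i^{\mathrm{loo}} = y_i - \frac{\big[\Sigma^{-1}(y - m)\big]_i}{\big[\Sigma^{-1}\big]_{ii}}, \qquad \sigma_i^{\mathrm{loo}} = \sqrt{\frac{1}{\big[\Sigma^{-1}\big]_{ii}}}.$$

Note that `Sigma_inv = jnp.linalg.inv(Sigma.to_dense())` materialises the full dense N×N inverse, where the 0.11.x code appears to have obtained the diagonal through CoLA's `inv` and `diag`.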
@@ -213,8 +213,8 @@ def log_posterior_density(
 
     # Gram matrix
     Kxx = posterior.prior.kernel.gram(x)
-    Kxx += I_like(Kxx) * posterior.prior.jitter
-    Kxx = PSD(Kxx)
+    Kxx_dense = Kxx.to_dense() + jnp.eye(Kxx.shape[0]) * posterior.prior.jitter
+    Kxx = psd(Dense(Kxx_dense))
     Lx = lower_cholesky(Kxx)
 
     # Compute the prior mean function
@@ -349,8 +349,8 @@ def collapsed_elbo(variational_family: VF, data: Dataset) -> ScalarFloat:
    noise = variational_family.posterior.likelihood.obs_stddev.value**2
    z = variational_family.inducing_inputs.value
    Kzz = kernel.gram(z)
-   Kzz += I_like(Kzz) * variational_family.jitter
-   Kzz = PSD(Kzz)
+   Kzz_dense = Kzz.to_dense() + jnp.eye(Kzz.shape[0]) * variational_family.jitter
+   Kzz = psd(Dense(Kzz_dense))
    Kzx = kernel.cross_covariance(z, x)
    Kxx_diag = vmap(kernel, in_axes=(0, 0))(x, x)
    μx = mean_function(x)
@@ -383,7 +383,7 @@ def collapsed_elbo(variational_family: VF, data: Dataset) -> ScalarFloat:
    #
    # with A and B defined as above.
 
-   A = solve(Lz, Kzx, Cholesky()) / jnp.sqrt(noise)
+   A = solve(Lz, Kzx) / jnp.sqrt(noise)
 
    # AAᵀ
    AAT = jnp.matmul(A, A.T)
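The hunks above replace CoLA calls with GPJax's new in-house `gpjax.linalg` module (added in this release as `gpjax/linalg/`). A minimal sketch of the new pattern, using only the names imported in the diff; exact operator signatures and return types are assumptions based on the usage shown:

```python
import jax.numpy as jnp

from gpjax.linalg import Dense, lower_cholesky, psd, solve

# Wrap a dense, jitter-stabilised kernel matrix and mark it PSD,
# mirroring `Sigma = psd(Dense(Sigma_dense))` in conjugate_mll.
K = jnp.array([[1.0, 0.5], [0.5, 1.0]])
Sigma = psd(Dense(K + 1e-6 * jnp.eye(K.shape[0])))

L = lower_cholesky(Sigma)   # lower factor with L Lᵀ ≈ Σ
rhs = jnp.ones((2, 1))
alpha = solve(Sigma, rhs)   # Σ⁻¹ rhs; no explicit Cholesky() algorithm argument anymore
```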
gpjax/parameters.py
CHANGED
@@ -21,23 +21,20 @@ def transform(
     r"""Transforms parameters using a bijector.
 
     Example:
-        ```pycon
         >>> from gpjax.parameters import PositiveReal, transform
         >>> import jax.numpy as jnp
         >>> import numpyro.distributions.transforms as npt
         >>> from flax import nnx
         >>> params = nnx.State(
-
-
-
-
-
+        ...     {
+        ...         "a": PositiveReal(jnp.array([1.0])),
+        ...         "b": PositiveReal(jnp.array([2.0])),
+        ...     }
+        ... )
         >>> params_bijection = {'positive': npt.SoftplusTransform()}
         >>> transformed_params = transform(params, params_bijection)
         >>> print(transformed_params["a"].value)
-
-        ```
-
+        [1.3132617]
 
     Args:
         params: A nnx.State object containing parameters to be transformed.
@@ -49,7 +46,7 @@ def transform(
     """
 
     def _inner(param):
-        bijector = params_bijection.get(param.
+        bijector = params_bijection.get(param.tag, npt.IdentityTransform())
         if inverse:
            transformed_value = bijector.inv(param.value)
         else:
@@ -60,10 +57,11 @@ def transform(
 
     gp_params, *other_params = params.split(Parameter, ...)
 
+    # Transform each parameter in the state
     transformed_gp_params: nnx.State = jtu.tree_map(
-        lambda x: _inner(x),
+        lambda x: _inner(x) if isinstance(x, Parameter) else x,
         gp_params,
-        is_leaf=lambda x: isinstance(x,
+        is_leaf=lambda x: isinstance(x, Parameter),
     )
     return nnx.State.merge(transformed_gp_params, *other_params)
 
@@ -79,7 +77,7 @@ class Parameter(nnx.Variable[T]):
         _check_is_arraylike(value)
 
         super().__init__(value=jnp.asarray(value), **kwargs)
-        self.
+        self.tag = tag
 
 
 class NonNegativeReal(Parameter[T]):
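For illustration, a hedged sketch of the round trip that `transform` now performs via the public `tag` attribute; the `'positive'` tag for `PositiveReal` is taken from the docstring example above, and the parameter name is hypothetical:

```python
import jax.numpy as jnp
import numpyro.distributions.transforms as npt
from flax import nnx

from gpjax.parameters import PositiveReal, transform

params = nnx.State({"lengthscale": PositiveReal(jnp.array([0.5]))})

# Map each parameter tag to a bijector; `transform` looks up `param.tag`
# and falls back to the identity transform for untagged parameters.
bijection = {"positive": npt.SoftplusTransform()}

unconstrained = transform(params, bijection, inverse=True)  # constrained -> ℝ
restored = transform(unconstrained, bijection)               # back to ℝ₊
print(restored["lengthscale"].value)  # ≈ [0.5]
```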
gpjax/variational_families.py
CHANGED
@@ -16,15 +16,6 @@
 import abc
 
 import beartype.typing as tp
-from cola.annotations import PSD
-from cola.linalg.decompositions.decompositions import Cholesky
-from cola.linalg.inverse.inv import solve
-from cola.ops.operators import (
-    Dense,
-    I_like,
-    Identity,
-    Triangular,
-)
 from flax import nnx
 import jax.numpy as jnp
 import jax.scipy as jsp
@@ -41,7 +32,14 @@ from gpjax.likelihoods import (
     Gaussian,
     NonGaussian,
 )
-from gpjax.lower_cholesky import lower_cholesky
+from gpjax.linalg import (
+    Dense,
+    Identity,
+    Triangular,
+    lower_cholesky,
+    psd,
+    solve,
+)
 from gpjax.mean_functions import AbstractMeanFunction
 from gpjax.parameters import (
     LowerTriangular,
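The operator types previously imported from `cola.ops.operators` now ship with GPJax itself. A small sketch of the `Triangular` pattern that appears in `VariationalGaussian` below; construction details beyond what the diff shows are assumptions:

```python
import jax.numpy as jnp

from gpjax.linalg import Triangular

# Wrap a dense lower-triangular square root, as in `sqrt = Triangular(sqrt)`.
sqrt = Triangular(jnp.tril(jnp.array([[1.0, 0.0], [0.3, 0.8]])))

# S = sqrt sqrtᵀ, the variational covariance built from its factor.
S = sqrt @ sqrt.T
```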
@@ -189,7 +187,7 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):
 
        muz = mean_function(z)
        Kzz = kernel.gram(z)
-       Kzz = PSD(Kzz + I_like(Kzz) * self.jitter)
+       Kzz = psd(Dense(Kzz.to_dense() + jnp.eye(Kzz.shape[0]) * self.jitter))
 
        sqrt = Triangular(sqrt)
        S = sqrt @ sqrt.T
@@ -226,7 +224,8 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):
        kernel = self.posterior.prior.kernel
 
        Kzz = kernel.gram(z)
-       Kzz = PSD(Kzz + I_like(Kzz) * self.jitter)
+       Kzz_dense = Kzz.to_dense() + jnp.eye(Kzz.shape[0]) * self.jitter
+       Kzz = psd(Dense(Kzz_dense))
        Lz = lower_cholesky(Kzz)
        muz = mean_function(z)
 
@@ -238,10 +237,10 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):
        mut = mean_function(t)
 
        # Lz⁻¹ Kzt
-       Lz_inv_Kzt = solve(Lz, Kzt, Cholesky())
+       Lz_inv_Kzt = solve(Lz, Kzt)
 
        # Kzz⁻¹ Kzt
-       Kzz_inv_Kzt = solve(Lz.T, Lz_inv_Kzt, Cholesky())
+       Kzz_inv_Kzt = solve(Lz.T, Lz_inv_Kzt)
 
        # Ktz Kzz⁻¹ sqrt
        Ktz_Kzz_inv_sqrt = jnp.matmul(Kzz_inv_Kzt.T, sqrt)
@@ -255,7 +254,7 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):
            - jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt)
            + jnp.matmul(Ktz_Kzz_inv_sqrt, Ktz_Kzz_inv_sqrt.T)
        )
-       covariance += I_like(covariance) * self.jitter
+       covariance += jnp.eye(covariance.shape[0]) * self.jitter
 
        return GaussianDistribution(
            loc=jnp.atleast_1d(mean.squeeze()), scale=covariance
@@ -330,7 +329,8 @@ class WhitenedVariationalGaussian(VariationalGaussian[L]):
        kernel = self.posterior.prior.kernel
 
        Kzz = kernel.gram(z)
-       Kzz = PSD(Kzz + I_like(Kzz) * self.jitter)
+       Kzz_dense = Kzz.to_dense() + jnp.eye(Kzz.shape[0]) * self.jitter
+       Kzz = psd(Dense(Kzz_dense))
        Lz = lower_cholesky(Kzz)
 
        # Unpack test inputs
@@ -341,7 +341,7 @@ class WhitenedVariationalGaussian(VariationalGaussian[L]):
        mut = mean_function(t)
 
        # Lz⁻¹ Kzt
-       Lz_inv_Kzt = solve(Lz, Kzt, Cholesky())
+       Lz_inv_Kzt = solve(Lz, Kzt)
 
        # Ktz Lz⁻ᵀ sqrt
        Ktz_Lz_invT_sqrt = jnp.matmul(Lz_inv_Kzt.T, sqrt)
@@ -355,7 +355,7 @@ class WhitenedVariationalGaussian(VariationalGaussian[L]):
            - jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt)
            + jnp.matmul(Ktz_Lz_invT_sqrt, Ktz_Lz_invT_sqrt.T)
        )
-       covariance += I_like(covariance) * self.jitter
+       covariance += jnp.eye(covariance.shape[0]) * self.jitter
 
        return GaussianDistribution(
            loc=jnp.atleast_1d(mean.squeeze()), scale=covariance
@@ -441,7 +441,8 @@ class NaturalVariationalGaussian(AbstractVariationalGaussian[L]):
 
        muz = mean_function(z)
        Kzz = kernel.gram(z)
-       Kzz = PSD(Kzz + I_like(Kzz) * self.jitter)
+       Kzz_dense = Kzz.to_dense() + jnp.eye(Kzz.shape[0]) * self.jitter
+       Kzz = psd(Dense(Kzz_dense))
 
        qu = GaussianDistribution(loc=jnp.atleast_1d(mu.squeeze()), scale=S)
        pu = GaussianDistribution(loc=jnp.atleast_1d(muz.squeeze()), scale=Kzz)
@@ -492,7 +493,8 @@ class NaturalVariationalGaussian(AbstractVariationalGaussian[L]):
        mu = jnp.matmul(S, natural_vector)
 
        Kzz = kernel.gram(z)
-       Kzz = PSD(Kzz + I_like(Kzz) * self.jitter)
+       Kzz_dense = Kzz.to_dense() + jnp.eye(Kzz.shape[0]) * self.jitter
+       Kzz = psd(Dense(Kzz_dense))
        Lz = lower_cholesky(Kzz)
        muz = mean_function(z)
 
@@ -501,10 +503,10 @@ class NaturalVariationalGaussian(AbstractVariationalGaussian[L]):
        mut = mean_function(test_inputs)
 
        # Lz⁻¹ Kzt
-       Lz_inv_Kzt = solve(Lz, Kzt, Cholesky())
+       Lz_inv_Kzt = solve(Lz, Kzt)
 
        # Kzz⁻¹ Kzt
-       Kzz_inv_Kzt = solve(Lz.T, Lz_inv_Kzt, Cholesky())
+       Kzz_inv_Kzt = solve(Lz.T, Lz_inv_Kzt)
 
        # Ktz Kzz⁻¹ L
        Ktz_Kzz_inv_L = jnp.matmul(Kzz_inv_Kzt.T, sqrt)
@@ -518,7 +520,7 @@ class NaturalVariationalGaussian(AbstractVariationalGaussian[L]):
            - jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt)
            + jnp.matmul(Ktz_Kzz_inv_L, Ktz_Kzz_inv_L.T)
        )
-       covariance += I_like(covariance) * self.jitter
+       covariance += jnp.eye(covariance.shape[0]) * self.jitter
 
        return GaussianDistribution(
            loc=jnp.atleast_1d(mean.squeeze()), scale=covariance
@@ -592,12 +594,14 @@ class ExpectationVariationalGaussian(AbstractVariationalGaussian[L]):
 
        # S = η₂ - η₁ η₁ᵀ
        S = expectation_matrix - jnp.outer(mu, mu)
-       S = PSD(Dense(S))
-       S += I_like(S) * self.jitter
+       S = psd(Dense(S))
+       S_dense = S.to_dense() + jnp.eye(S.shape[0]) * self.jitter
+       S = psd(Dense(S_dense))
 
        muz = mean_function(z)
        Kzz = kernel.gram(z)
-       Kzz = PSD(Kzz + I_like(Kzz) * self.jitter)
+       Kzz_dense = Kzz.to_dense() + jnp.eye(Kzz.shape[0]) * self.jitter
+       Kzz = psd(Dense(Kzz_dense))
 
        qu = GaussianDistribution(loc=jnp.atleast_1d(mu.squeeze()), scale=S)
        pu = GaussianDistribution(loc=jnp.atleast_1d(muz.squeeze()), scale=Kzz)
@@ -636,14 +640,15 @@ class ExpectationVariationalGaussian(AbstractVariationalGaussian[L]):
 
        # S = η₂ - η₁ η₁ᵀ
        S = expectation_matrix - jnp.matmul(mu, mu.T)
-       S = Dense(S)
-       S = PSD(S + I_like(S) * self.jitter)
+       S = Dense(S + jnp.eye(S.shape[0]) * self.jitter)
+       S = psd(S)
 
        # S = sqrt sqrtᵀ
        sqrt = lower_cholesky(S)
 
        Kzz = kernel.gram(z)
-       Kzz = PSD(Kzz + I_like(Kzz) * self.jitter)
+       Kzz_dense = Kzz.to_dense() + jnp.eye(Kzz.shape[0]) * self.jitter
+       Kzz = psd(Dense(Kzz_dense))
        Lz = lower_cholesky(Kzz)
        muz = mean_function(z)
 
@@ -655,10 +660,10 @@ class ExpectationVariationalGaussian(AbstractVariationalGaussian[L]):
        mut = mean_function(t)
 
        # Lz⁻¹ Kzt
-       Lz_inv_Kzt = solve(Lz, Kzt, Cholesky())
+       Lz_inv_Kzt = solve(Lz, Kzt)
 
        # Kzz⁻¹ Kzt
-       Kzz_inv_Kzt = solve(Lz.T, Lz_inv_Kzt, Cholesky())
+       Kzz_inv_Kzt = solve(Lz.T, Lz_inv_Kzt)
 
        # Ktz Kzz⁻¹ sqrt
        Ktz_Kzz_inv_sqrt = Kzz_inv_Kzt.T @ sqrt
@@ -672,7 +677,7 @@ class ExpectationVariationalGaussian(AbstractVariationalGaussian[L]):
            - jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt)
            + jnp.matmul(Ktz_Kzz_inv_sqrt, Ktz_Kzz_inv_sqrt.T)
        )
-       covariance += I_like(covariance) * self.jitter
+       covariance += jnp.eye(covariance.shape[0]) * self.jitter
 
        return GaussianDistribution(
            loc=jnp.atleast_1d(mean.squeeze()), scale=covariance
@@ -729,13 +734,14 @@ class CollapsedVariationalGaussian(AbstractVariationalGaussian[GL]):
 
        Kzx = kernel.cross_covariance(z, x)
        Kzz = kernel.gram(z)
-       Kzz = PSD(Kzz + I_like(Kzz) * self.jitter)
+       Kzz_dense = Kzz.to_dense() + jnp.eye(Kzz.shape[0]) * self.jitter
+       Kzz = psd(Dense(Kzz_dense))
 
        # Lz Lzᵀ = Kzz
        Lz = lower_cholesky(Kzz)
 
        # Lz⁻¹ Kzx
-       Lz_inv_Kzx = solve(Lz, Kzx, Cholesky())
+       Lz_inv_Kzx = solve(Lz, Kzx)
 
        # A = Lz⁻¹ Kzt / o
        A = Lz_inv_Kzx / self.posterior.likelihood.obs_stddev.value
@@ -753,14 +759,14 @@ class CollapsedVariationalGaussian(AbstractVariationalGaussian[GL]):
        Lz_inv_Kzx_diff = jsp.linalg.cho_solve((L, True), jnp.matmul(Lz_inv_Kzx, diff))
 
        # Kzz⁻¹ Kzx (y - μx)
-       Kzz_inv_Kzx_diff = solve(Lz.T, Lz_inv_Kzx_diff, Cholesky())
+       Kzz_inv_Kzx_diff = solve(Lz.T, Lz_inv_Kzx_diff)
 
        Ktt = kernel.gram(t)
        Kzt = kernel.cross_covariance(z, t)
        mut = mean_function(t)
 
        # Lz⁻¹ Kzt
-       Lz_inv_Kzt = solve(Lz, Kzt, Cholesky())
+       Lz_inv_Kzt = solve(Lz, Kzt)
 
        # L⁻¹ Lz⁻¹ Kzt
        L_inv_Lz_inv_Kzt = jsp.linalg.solve_triangular(L, Lz_inv_Kzt, lower=True)
@@ -774,7 +780,7 @@ class CollapsedVariationalGaussian(AbstractVariationalGaussian[GL]):
            - jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt)
            + jnp.matmul(L_inv_Lz_inv_Kzt.T, L_inv_Lz_inv_Kzt)
        )
-       covariance += I_like(covariance) * self.jitter
+       covariance += jnp.eye(covariance.shape[0]) * self.jitter
 
        return GaussianDistribution(
            loc=jnp.atleast_1d(mean.squeeze()), scale=covariance
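Every predictive method above repeats the same stabilise, factorise, solve chain. A condensed sketch of that chain; the helper name `stabilised_gram` is ours, not GPJax's:

```python
import jax.numpy as jnp

from gpjax.linalg import Dense, lower_cholesky, psd, solve

def stabilised_gram(K_dense: jnp.ndarray, jitter: float):
    """Kzz -> psd(Dense(...)) -> Lz, the pattern repeated in each hunk above."""
    Kzz = psd(Dense(K_dense + jnp.eye(K_dense.shape[0]) * jitter))
    return Kzz, lower_cholesky(Kzz)

Kzz, Lz = stabilised_gram(jnp.array([[1.0, 0.2], [0.2, 1.0]]), jitter=1e-6)

Kzt = jnp.ones((2, 3))
Lz_inv_Kzt = solve(Lz, Kzt)            # Lz⁻¹ Kzt
Kzz_inv_Kzt = solve(Lz.T, Lz_inv_Kzt)  # Kzz⁻¹ Kzt via two triangular solves
```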
{gpjax-0.11.2.dist-info → gpjax-0.12.0.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: gpjax
-Version: 0.11.2
+Version: 0.12.0
 Summary: Gaussian processes in JAX.
 Project-URL: Documentation, https://docs.jaxgaussianprocesses.com/
 Project-URL: Issues, https://github.com/JaxGaussianProcesses/GPJax/issues
@@ -14,11 +14,11 @@ Classifier: Programming Language :: Python
 Classifier: Programming Language :: Python :: 3.10
 Classifier: Programming Language :: Python :: 3.11
 Classifier: Programming Language :: Python :: 3.12
+Classifier: Programming Language :: Python :: 3.13
 Classifier: Programming Language :: Python :: Implementation :: CPython
 Classifier: Programming Language :: Python :: Implementation :: PyPy
-Requires-Python:
+Requires-Python: <=3.13,>=3.10
 Requires-Dist: beartype>0.16.1
-Requires-Dist: cola-ml>=0.0.7
 Requires-Dist: flax>=0.10.0
 Requires-Dist: jax>=0.5.0
 Requires-Dist: jaxlib>=0.5.0
@@ -27,6 +27,47 @@ Requires-Dist: numpy>=2.0.0
 Requires-Dist: numpyro
 Requires-Dist: optax>0.2.1
 Requires-Dist: tqdm>4.66.2
+Provides-Extra: dev
+Requires-Dist: absolufy-imports>=0.3.1; extra == 'dev'
+Requires-Dist: autoflake; extra == 'dev'
+Requires-Dist: black; extra == 'dev'
+Requires-Dist: codespell>=2.2.4; extra == 'dev'
+Requires-Dist: coverage>=7.2.2; extra == 'dev'
+Requires-Dist: interrogate>=1.5.0; extra == 'dev'
+Requires-Dist: isort; extra == 'dev'
+Requires-Dist: jupytext; extra == 'dev'
+Requires-Dist: mktestdocs>=0.2.1; extra == 'dev'
+Requires-Dist: networkx; extra == 'dev'
+Requires-Dist: pre-commit>=3.2.2; extra == 'dev'
+Requires-Dist: pytest-beartype; extra == 'dev'
+Requires-Dist: pytest-cov>=4.0.0; extra == 'dev'
+Requires-Dist: pytest-pretty>=1.1.1; extra == 'dev'
+Requires-Dist: pytest-xdist>=3.2.1; extra == 'dev'
+Requires-Dist: pytest>=7.2.2; extra == 'dev'
+Requires-Dist: ruff>=0.6; extra == 'dev'
+Requires-Dist: xdoctest>=1.1.1; extra == 'dev'
+Provides-Extra: docs
+Requires-Dist: blackjax>=0.9.6; extra == 'docs'
+Requires-Dist: ipykernel>=6.22.0; extra == 'docs'
+Requires-Dist: ipython>=8.11.0; extra == 'docs'
+Requires-Dist: ipywidgets>=8.0.5; extra == 'docs'
+Requires-Dist: jupytext>=1.14.5; extra == 'docs'
+Requires-Dist: markdown-katex>=202406.1035; extra == 'docs'
+Requires-Dist: matplotlib>=3.7.1; extra == 'docs'
+Requires-Dist: mkdocs-gen-files>=0.5.0; extra == 'docs'
+Requires-Dist: mkdocs-git-authors-plugin>=0.7.0; extra == 'docs'
+Requires-Dist: mkdocs-jupyter>=0.24.3; extra == 'docs'
+Requires-Dist: mkdocs-literate-nav>=0.6.0; extra == 'docs'
+Requires-Dist: mkdocs-material>=9.5.12; extra == 'docs'
+Requires-Dist: mkdocs>=1.5.3; extra == 'docs'
+Requires-Dist: mkdocstrings[python]<0.28.0; extra == 'docs'
+Requires-Dist: nbconvert>=7.16.2; extra == 'docs'
+Requires-Dist: networkx>=3.0; extra == 'docs'
+Requires-Dist: pandas>=1.5.3; extra == 'docs'
+Requires-Dist: pymdown-extensions>=10.7.1; extra == 'docs'
+Requires-Dist: scikit-learn>=1.5.1; extra == 'docs'
+Requires-Dist: seaborn>=0.12.2; extra == 'docs'
+Requires-Dist: watermark>=2.3.1; extra == 'docs'
 Description-Content-Type: text/markdown
 
 <!-- <h1 align='center'>GPJax</h1>
@@ -41,12 +82,12 @@ Description-Content-Type: text/markdown
 [](https://badge.fury.io/py/GPJax)
 [](https://doi.org/10.21105/joss.04455)
 [](https://pepy.tech/project/gpjax)
-[](https://join.slack.com/t/gpjax/shared_invite/zt-
+[](https://join.slack.com/t/gpjax/shared_invite/zt-3cesiykcx-nzajjRdnV3ohw7~~eMlCYA)
 
 [**Quickstart**](#simple-example)
 | [**Install guide**](#installation)
 | [**Documentation**](https://docs.jaxgaussianprocesses.com/)
-| [**Slack Community**](https://join.slack.com/t/gpjax/shared_invite/zt-
+| [**Slack Community**](https://join.slack.com/t/gpjax/shared_invite/zt-3cesiykcx-nzajjRdnV3ohw7~~eMlCYA)
 
 GPJax aims to provide a low-level interface to Gaussian process (GP) models in
 [Jax](https://github.com/google/jax), structured to give researchers maximum
@@ -81,7 +122,7 @@ behaviours through [this form](https://jaxgaussianprocesses.com/contact/) or rea
 one of the project's [_gardeners_](https://docs.jaxgaussianprocesses.com/GOVERNANCE/#roles).
 
 Feel free to join our [Slack
-Channel](https://join.slack.com/t/gpjax/shared_invite/zt-
+Channel](https://join.slack.com/t/gpjax/shared_invite/zt-3cesiykcx-nzajjRdnV3ohw7~~eMlCYA),
 where we can discuss the development of GPJax and broader support for Gaussian
 process modelling.
 
@@ -177,14 +218,13 @@ configuration in development mode.
 ```bash
 git clone https://github.com/JaxGaussianProcesses/GPJax.git
 cd GPJax
-
-hatch shell
+uv sync --extra dev
 ```
 
 > We recommend you check your installation passes the supplied unit tests:
 >
 > ```python
->
+> uv run pytest --beartype-packages='gpjax'
 > ```
 
 # Citing GPJax
{gpjax-0.11.2.dist-info → gpjax-0.12.0.dist-info}/RECORD
CHANGED

@@ -1,29 +1,28 @@
-gpjax/__init__.py,sha256=
+gpjax/__init__.py,sha256=FSrKDFSQ7xDqwQGBWwEPqqjvYxEbhUPPestKLoAPjWA,1686
 gpjax/citation.py,sha256=pwFS8h1J-LE5ieRS0zDyuwhmQHNxkFHYE7iSMlVNmQc,3928
 gpjax/dataset.py,sha256=NsToLKq4lOsHnfLfukrUIRKvhOEuoUk8aHTF0oAqRbU,4079
-gpjax/distributions.py,sha256=
+gpjax/distributions.py,sha256=iKmeQ_NN2CIjRiuOeJlwEGASzGROi4ZCerVi1uY7zRM,7758
 gpjax/fit.py,sha256=R4TIPvBNHYSg9vBVp6is_QYENldRLIU_FklGE85C-aA,15046
-gpjax/gps.py,sha256=
+gpjax/gps.py,sha256=-Log0pcU8qmB5fUxfzoNjD0S64gpiypAjFzjGXX6w7I,30301
 gpjax/integrators.py,sha256=eyJPqWNPKj6pKP5da0fEj4HW7BVyevqeGrurEuy_XPw,5694
 gpjax/likelihoods.py,sha256=99oTZoWld1M7vxgGM0pNY5Hnt2Ajd2lQNqawzrLmwtk,9308
-gpjax/lower_cholesky.py,sha256=3pnHaBrlGckFsrfYJ9Lsbd0pGmO7NIXdyY4aGm48MpY,1952
 gpjax/mean_functions.py,sha256=-sVYO1_LWE8f34rllUOuaT5sgGGAdxo99v5kRo2d4oM,6490
 gpjax/numpyro_extras.py,sha256=-vWJ7SpZVNhSdCjjrlxIkovMFrM1IzpsMJK3B4LioGE,3411
-gpjax/objectives.py,sha256=
-gpjax/parameters.py,sha256=
+gpjax/objectives.py,sha256=Tm36h8fz_nWkZPlufMQzZWKK1ytrtT9yvvP8YdxYKNw,15359
+gpjax/parameters.py,sha256=qIEqyMKNd2n2Ak15PisCmqhX5qhsoRgng_s4doL96rE,7044
 gpjax/scan.py,sha256=jStQvwkE9MGttB89frxam1kaeXdWih7cVxkGywyaeHQ,5365
 gpjax/typing.py,sha256=M3CvWsYtZ3PFUvBvvbRNjpwerNII0w4yGuP0I-sLeYI,1705
-gpjax/variational_families.py,sha256=
+gpjax/variational_families.py,sha256=rE3LarwIAkvDvLlWrz8Ww6BUBz88YHdV4ceY97r3IBw,28637
 gpjax/kernels/__init__.py,sha256=WZanH0Tpdkt0f7VfMqnalm_VZAMVwBqeOVaICNj6xQU,1901
-gpjax/kernels/base.py,sha256=
+gpjax/kernels/base.py,sha256=hOUXwarspDFnuI2_QreyIVPdz2fzRVJj4p3Zdu1touw,11606
 gpjax/kernels/approximations/__init__.py,sha256=bK9HlGd-PZeGrqtG5RpXxUTXNUrZTgfjH1dP626yNMA,68
 gpjax/kernels/approximations/rff.py,sha256=VbitjNuahFE5_IvCj1A0SxHhJXU0O0Qq0FMMVq8xA3E,4125
 gpjax/kernels/computations/__init__.py,sha256=uTVkqvnZVesFLDN92h0ZR0jfR69Eo2WyjOlmSYmCPJ8,1379
-gpjax/kernels/computations/base.py,sha256=
-gpjax/kernels/computations/basis_functions.py,sha256=
-gpjax/kernels/computations/constant_diagonal.py,sha256=
+gpjax/kernels/computations/base.py,sha256=L6K0roxZbrYeJKxEw-yaTiK9Mtcv0YtZfWI2Xnau7i8,3616
+gpjax/kernels/computations/basis_functions.py,sha256=MPSo40NEx_ngnSLTa9ntVJzma_jugvm5dMpZd5MtG5M,2490
+gpjax/kernels/computations/constant_diagonal.py,sha256=JkQhLj7cK48IhOER4ivkALNhD1oQleKe-Rr9BtUJ6es,1984
 gpjax/kernels/computations/dense.py,sha256=vnW6XKQe4_gzpXRWTctxhgMA9-9TebdtiXzAqh_-j6g,1392
-gpjax/kernels/computations/diagonal.py,sha256=
+gpjax/kernels/computations/diagonal.py,sha256=k1KqW0DwWRIBvbb7jzcKktXRfhXbcos3ncWrFplJ4W0,1768
 gpjax/kernels/computations/eigen.py,sha256=w7I7LK42j0ouchHCI1ltXx0lpwqvK1bRb4HclnF3rKs,1936
 gpjax/kernels/non_euclidean/__init__.py,sha256=RT7puRPqCTpyxZ16q596EuOQEQi1LK1v3J9_fWz1NlY,790
 gpjax/kernels/non_euclidean/graph.py,sha256=K4WIdX-dx1SsWuNHZnNjHFw8ElKZxGcReUiA3w4aCOI,4204
@@ -43,7 +42,11 @@ gpjax/kernels/stationary/rational_quadratic.py,sha256=dYONp3i4rnKj3ET8UyxAKXv6UO
 gpjax/kernels/stationary/rbf.py,sha256=euHUs6FdfRICQcabAWE4MX-7GEDr2TxgZWdFQiXr9Bw,1690
 gpjax/kernels/stationary/utils.py,sha256=6BI9EBcCzeeKx-XH-MfW1ORmtU__tPX5zyvfLhpkBsU,2180
 gpjax/kernels/stationary/white.py,sha256=TkdXXZCCjDs7JwR_gj5uvn2s1wyfRbe1vyHhUMJ8jjI,2212
-gpjax-0.11.2.dist-info/METADATA,sha256=
-gpjax-0.11.2.dist-info/WHEEL,sha256=
-gpjax-0.11.2.dist-info/licenses/LICENSE.txt,sha256=
-gpjax-0.11.2.dist-info/RECORD,,
+gpjax/linalg/__init__.py,sha256=F8mxk_9Zc2nFd7Q-unjJ50_6rXEKzZj572WsU_jUKqI,547
+gpjax/linalg/operations.py,sha256=xvhOy5P4FmUCPWjIVNdg1yDXaoFQ48anFUfR-Tnfr6k,6480
+gpjax/linalg/operators.py,sha256=arxRGwcoAy_RqUYqBpZ3XG6OXbjShUl7m8sTpg85npE,11608
+gpjax/linalg/utils.py,sha256=DGX40TDhmfYn7JBxElpBm_9W0cetm0HZUK7B3j74xxo,895
+gpjax-0.12.0.dist-info/METADATA,sha256=8lLQb5SUvWvniry-zBOR3wzm03tXvHe7Lzry_Ho3peE,10562
+gpjax-0.12.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+gpjax-0.12.0.dist-info/licenses/LICENSE.txt,sha256=3umwi0h8wmKXOZO8XwRBwSl3vJt2hpWKEqSrSXLR7-I,1084
+gpjax-0.12.0.dist-info/RECORD,,
gpjax/lower_cholesky.py
DELETED
@@ -1,69 +0,0 @@
-# Copyright 2023 The GPJax Contributors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-from cola.annotations import PSD
-from cola.fns import dispatch
-from cola.ops.operator_base import LinearOperator
-from cola.ops.operators import (
-    BlockDiag,
-    Diagonal,
-    Identity,
-    Kronecker,
-    Triangular,
-)
-import jax.numpy as jnp
-
-# TODO: Once this functionality is supported in CoLA, remove this.
-
-
-@dispatch
-def lower_cholesky(A: LinearOperator) -> Triangular:  # noqa: F811
-    """Returns the lower Cholesky factor of a linear operator.
-
-    Args:
-        A: The input linear operator.
-
-    Returns:
-        Triangular: The lower Cholesky factor of A.
-    """
-
-    if PSD not in A.annotations:
-        raise ValueError(
-            "Expected LinearOperator to be PSD, did you forget to use cola.PSD?"
-        )
-
-    return Triangular(jnp.linalg.cholesky(A.to_dense()), lower=True)
-
-
-@lower_cholesky.dispatch
-def _(A: Diagonal):  # noqa: F811
-    return Diagonal(jnp.sqrt(A.diag))
-
-
-@lower_cholesky.dispatch
-def _(A: Identity):  # noqa: F811
-    return A
-
-
-@lower_cholesky.dispatch
-def _(A: Kronecker):  # noqa: F811
-    return Kronecker(*[lower_cholesky(Ai) for Ai in A.Ms])
-
-
-@lower_cholesky.dispatch
-def _(A: BlockDiag):  # noqa: F811
-    return BlockDiag(
-        *[lower_cholesky(Ai) for Ai in A.Ms], multiplicities=A.multiplicities
-    )
|
|
File without changes
|