gpjax 0.12.0__py3-none-any.whl → 0.13.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gpjax/__init__.py +1 -4
- gpjax/fit.py +11 -6
- gpjax/gps.py +35 -33
- gpjax/kernels/approximations/rff.py +4 -6
- gpjax/kernels/base.py +2 -5
- gpjax/kernels/computations/basis_functions.py +1 -1
- gpjax/kernels/computations/eigen.py +1 -1
- gpjax/kernels/non_euclidean/graph.py +10 -11
- gpjax/kernels/nonstationary/arccosine.py +13 -21
- gpjax/kernels/nonstationary/polynomial.py +7 -8
- gpjax/kernels/stationary/base.py +1 -30
- gpjax/kernels/stationary/matern12.py +1 -1
- gpjax/kernels/stationary/matern32.py +1 -1
- gpjax/kernels/stationary/matern52.py +1 -1
- gpjax/kernels/stationary/periodic.py +3 -6
- gpjax/kernels/stationary/powered_exponential.py +3 -8
- gpjax/kernels/stationary/rational_quadratic.py +5 -8
- gpjax/likelihoods.py +11 -14
- gpjax/linalg/utils.py +32 -0
- gpjax/mean_functions.py +9 -8
- gpjax/objectives.py +4 -3
- gpjax/parameters.py +0 -10
- gpjax/variational_families.py +65 -45
- {gpjax-0.12.0.dist-info → gpjax-0.13.0.dist-info}/METADATA +21 -21
- gpjax-0.13.0.dist-info/RECORD +52 -0
- gpjax-0.12.0.dist-info/RECORD +0 -52
- {gpjax-0.12.0.dist-info → gpjax-0.13.0.dist-info}/WHEEL +0 -0
- {gpjax-0.12.0.dist-info → gpjax-0.13.0.dist-info}/licenses/LICENSE.txt +0 -0
gpjax/kernels/stationary/rational_quadratic.py
CHANGED

@@ -23,7 +23,6 @@ from gpjax.kernels.computations import (
 )
 from gpjax.kernels.stationary.base import StationaryKernel
 from gpjax.kernels.stationary.utils import squared_distance
-from gpjax.parameters import PositiveReal
 from gpjax.typing import (
     Array,
     ScalarArray,

@@ -70,17 +69,15 @@ class RationalQuadratic(StationaryKernel):
         compute_engine: The computation engine that the kernel uses to compute the
             covariance matrix.
         """
-
-            self.alpha = alpha
-        else:
-            self.alpha = PositiveReal(alpha)
+        self.alpha = alpha

         super().__init__(active_dims, lengthscale, variance, n_dims, compute_engine)

     def __call__(self, x: Float[Array, " D"], y: Float[Array, " D"]) -> ScalarFloat:
         x = self.slice_input(x) / self.lengthscale.value
         y = self.slice_input(y) / self.lengthscale.value
-
-
-
+        alpha_val = self.alpha.value if hasattr(self.alpha, "value") else self.alpha
+        K = self.variance.value * (1 + 0.5 * squared_distance(x, y) / alpha_val) ** (
+            -alpha_val
+        )
         return K.squeeze()
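The net effect of the two hunks above is that `alpha` is stored as given and only unwrapped at call time. A minimal sketch of what that means for calling code, assuming `RationalQuadratic` and `PositiveReal` keep the keyword names used in the diff (illustrative, not a verified 0.13.0 example):

```python
# Hedged sketch: alpha may now be a plain float or a Parameter-like object.
import jax.numpy as jnp
from gpjax.kernels import RationalQuadratic
from gpjax.parameters import PositiveReal

x = jnp.array([0.0, 1.0])
y = jnp.array([0.5, 0.5])

k_plain = RationalQuadratic(alpha=2.0)                # stored as-is
k_param = RationalQuadratic(alpha=PositiveReal(2.0))  # unwrapped via .value

# Both evaluate k(x, y); __call__ uses hasattr(self.alpha, "value") to decide.
print(k_plain(x, y), k_param(x, y))
```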
gpjax/likelihoods.py
CHANGED
@@ -29,7 +29,6 @@ from gpjax.integrators import (
 )
 from gpjax.parameters import (
     NonNegativeReal,
-    Static,
 )
 from gpjax.typing import (
     Array,

@@ -59,27 +58,27 @@ class AbstractLikelihood(nnx.Module):
         self.num_datapoints = num_datapoints
         self.integrator = integrator

-    def __call__(
+    def __call__(
+        self, dist: tp.Union[npd.MultivariateNormal, GaussianDistribution]
+    ) -> npd.Distribution:
         r"""Evaluate the likelihood function at a given predictive distribution.

         Args:
-
-            **kwargs (Any): Keyword arguments to be passed to the likelihood's
-                `predict` method.
+            dist: The predictive distribution to evaluate the likelihood at.

         Returns:
             The predictive distribution.
         """
-        return self.predict(
+        return self.predict(dist)

     @abc.abstractmethod
-    def predict(
+    def predict(
+        self, dist: tp.Union[npd.MultivariateNormal, GaussianDistribution]
+    ) -> npd.Distribution:
         r"""Evaluate the likelihood function at a given predictive distribution.

         Args:
-
-            **kwargs (Any): Keyword arguments to be passed to the likelihood's
-                `predict` method.
+            dist: The predictive distribution to evaluate the likelihood at.

         Returns:
             npd.Distribution: The predictive distribution.

@@ -133,9 +132,7 @@ class Gaussian(AbstractLikelihood):
     def __init__(
         self,
         num_datapoints: int,
-        obs_stddev: tp.Union[
-            ScalarFloat, Float[Array, "#N"], NonNegativeReal, Static
-        ] = 1.0,
+        obs_stddev: tp.Union[ScalarFloat, Float[Array, "#N"], NonNegativeReal] = 1.0,
         integrator: AbstractIntegrator = AnalyticalGaussianIntegrator(),
     ):
         r"""Initializes the Gaussian likelihood.

@@ -148,7 +145,7 @@ class Gaussian(AbstractLikelihood):
             likelihoods. Must be an instance of `AbstractIntegrator`. For the Gaussian likelihood, this defaults to
             the `AnalyticalGaussianIntegrator`, as the expected log likelihood can be computed analytically.
         """
-        if not isinstance(obs_stddev,
+        if not isinstance(obs_stddev, NonNegativeReal):
             obs_stddev = NonNegativeReal(jnp.asarray(obs_stddev))
         self.obs_stddev = obs_stddev
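For context, a hedged sketch of the tightened `__call__`/`predict` signature above: the likelihood is now applied directly to a predictive distribution rather than forwarding `*args`/`**kwargs`. The construction below is illustrative only; it assumes `Gaussian` and `numpyro.distributions` are available as the diff suggests.

```python
# Hedged sketch of the new likelihood call signature; illustrative only.
import jax.numpy as jnp
import numpyro.distributions as npd
from gpjax.likelihoods import Gaussian

likelihood = Gaussian(num_datapoints=3, obs_stddev=0.1)

# A latent predictive distribution over three test points.
latent = npd.MultivariateNormal(loc=jnp.zeros(3), covariance_matrix=jnp.eye(3))

# 0.13.0 style: pass the distribution itself; no *args/**kwargs are accepted.
predictive = likelihood(latent)  # same as likelihood.predict(latent)
```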
gpjax/linalg/utils.py
CHANGED
@@ -1,5 +1,8 @@
 """Utility functions for the linear algebra module."""

+import jax.numpy as jnp
+from jaxtyping import Array
+
 from gpjax.linalg.operators import LinearOperator


@@ -31,3 +34,32 @@ def psd(A: LinearOperator) -> LinearOperator:
     A.annotations = set()
     A.annotations.add(PSD)
     return A
+
+
+def add_jitter(matrix: Array, jitter: float | Array = 1e-6) -> Array:
+    """Add jitter to the diagonal of a matrix for numerical stability.
+
+    This function adds a small positive value (jitter) to the diagonal elements
+    of a square matrix to improve numerical stability, particularly for
+    Cholesky decompositions and matrix inversions.
+
+    Args:
+        matrix: A square matrix to which jitter will be added.
+        jitter: The jitter value to add to the diagonal. Defaults to 1e-6.
+
+    Returns:
+        The matrix with jitter added to its diagonal.
+
+    Examples:
+        >>> import jax.numpy as jnp
+        >>> from gpjax.linalg.utils import add_jitter
+        >>> matrix = jnp.array([[1.0, 0.5], [0.5, 1.0]])
+        >>> jittered_matrix = add_jitter(matrix, jitter=0.01)
+    """
+    if matrix.ndim != 2:
+        raise ValueError(f"Expected 2D matrix, got {matrix.ndim}D array")
+
+    if matrix.shape[0] != matrix.shape[1]:
+        raise ValueError(f"Expected square matrix, got shape {matrix.shape}")
+
+    return matrix + jnp.eye(matrix.shape[0]) * jitter
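Since `add_jitter` is the helper that the `objectives.py` and `variational_families.py` hunks below thread through every Gram-matrix densification, here is a small usage sketch based only on the behaviour shown above:

```python
# Sketch: jitter on the diagonal stabilises a Cholesky factorisation.
import jax.numpy as jnp
from gpjax.linalg.utils import add_jitter

# A Gram matrix that is positive semi-definite but singular.
K = jnp.array([[1.0, 1.0], [1.0, 1.0]])

L = jnp.linalg.cholesky(add_jitter(K, jitter=1e-6))  # finite entries
```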
gpjax/mean_functions.py
CHANGED
@@ -27,8 +27,6 @@ from jaxtyping import (

 from gpjax.parameters import (
     Parameter,
-    Real,
-    Static,
 )
 from gpjax.typing import (
     Array,

@@ -132,12 +130,12 @@ class Constant(AbstractMeanFunction):

     def __init__(
         self,
-        constant: tp.Union[ScalarFloat, Float[Array, " O"], Parameter
+        constant: tp.Union[ScalarFloat, Float[Array, " O"], Parameter] = 0.0,
     ):
-        if isinstance(constant, Parameter)
+        if isinstance(constant, Parameter):
             self.constant = constant
         else:
-            self.constant =
+            self.constant = jnp.array(constant)

     def __call__(self, x: Num[Array, "N D"]) -> Float[Array, "N O"]:
         r"""Evaluate the mean function at the given points.

@@ -148,7 +146,10 @@ class Constant(AbstractMeanFunction):
         Returns:
             Float[Array, "1"]: The evaluated mean function.
         """
-
+        if isinstance(self.constant, Parameter):
+            return jnp.ones((x.shape[0], 1)) * self.constant.value
+        else:
+            return jnp.ones((x.shape[0], 1)) * self.constant


 class Zero(Constant):

@@ -160,7 +161,7 @@ class Zero(Constant):
     """

     def __init__(self):
-        super().__init__(constant=
+        super().__init__(constant=0.0)


 class CombinationMeanFunction(AbstractMeanFunction):

@@ -175,7 +176,7 @@ class CombinationMeanFunction(AbstractMeanFunction):
         super().__init__(**kwargs)

         # Add means to a list, flattening out instances of this class therein, as in GPFlow kernels.
-        items_list: list[AbstractMeanFunction] = []
+        items_list: list[AbstractMeanFunction] = nnx.List([])

         for item in means:
             if not isinstance(item, AbstractMeanFunction):
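A brief sketch of the `Constant` behaviour after the hunks above (hypothetical usage, not a verified example): a plain float is stored via `jnp.array` and broadcast at call time, while a `Parameter` is used through its `.value`.

```python
# Hedged sketch of the Constant mean function after the change above.
import jax.numpy as jnp
from gpjax.mean_functions import Constant
from gpjax.parameters import Real

X = jnp.linspace(0.0, 1.0, 5).reshape(-1, 1)

m_float = Constant(constant=2.0)        # stored as jnp.array(2.0)
m_param = Constant(constant=Real(2.0))  # stored as a trainable Parameter

print(m_float(X).shape, m_param(X).shape)  # both (5, 1)
```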
gpjax/objectives.py
CHANGED
@@ -20,6 +20,7 @@ from gpjax.linalg import (
     psd,
     solve,
 )
+from gpjax.linalg.utils import add_jitter
 from gpjax.typing import (
     Array,
     ScalarFloat,

@@ -97,7 +98,7 @@ def conjugate_mll(posterior: ConjugatePosterior, data: Dataset) -> ScalarFloat:

     # Σ = (Kxx + Io²) = LLᵀ
     Kxx = posterior.prior.kernel.gram(x)
-    Kxx_dense = Kxx.to_dense()
+    Kxx_dense = add_jitter(Kxx.to_dense(), posterior.prior.jitter)
     Sigma_dense = Kxx_dense + jnp.eye(Kxx.shape[0]) * obs_noise
     Sigma = psd(Dense(Sigma_dense))

@@ -213,7 +214,7 @@ def log_posterior_density(

     # Gram matrix
     Kxx = posterior.prior.kernel.gram(x)
-    Kxx_dense = Kxx.to_dense()
+    Kxx_dense = add_jitter(Kxx.to_dense(), posterior.prior.jitter)
     Kxx = psd(Dense(Kxx_dense))
     Lx = lower_cholesky(Kxx)

@@ -349,7 +350,7 @@ def collapsed_elbo(variational_family: VF, data: Dataset) -> ScalarFloat:
     noise = variational_family.posterior.likelihood.obs_stddev.value**2
     z = variational_family.inducing_inputs.value
     Kzz = kernel.gram(z)
-    Kzz_dense = Kzz.to_dense()
+    Kzz_dense = add_jitter(Kzz.to_dense(), variational_family.jitter)
     Kzz = psd(Dense(Kzz_dense))
     Kzx = kernel.cross_covariance(z, x)
     Kxx_diag = vmap(kernel, in_axes=(0, 0))(x, x)
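In each hunk above the prior's `jitter` is folded onto the diagonal of the dense Gram matrix before the noise term is added, i.e. Σ = Kxx + jitter·I + σ²·I. A stand-alone sketch of that construction, with `Kxx_dense`, `obs_noise` and `jitter` standing in for the values named in the hunks (not a public API):

```python
# Hedged sketch of the stabilised Σ used by conjugate_mll above.
import jax.numpy as jnp

def stabilised_sigma(Kxx_dense, obs_noise, jitter=1e-6):
    # Σ = Kxx + jitter·I + σ²·I, i.e. add_jitter(Kxx) followed by the noise term.
    n = Kxx_dense.shape[0]
    return Kxx_dense + jnp.eye(n) * jitter + jnp.eye(n) * obs_noise

Sigma = stabilised_sigma(jnp.eye(3), obs_noise=0.25)
```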
gpjax/parameters.py
CHANGED
@@ -122,16 +122,6 @@ class SigmoidBounded(Parameter[T]):
         )


-class Static(nnx.Variable[T]):
-    """Static parameter that is not trainable."""
-
-    def __init__(self, value: T, tag: ParameterTag = "static", **kwargs):
-        _check_is_arraylike(value)
-
-        super().__init__(value=jnp.asarray(value), tag=tag, **kwargs)
-        self._tag = tag
-
-
 class LowerTriangular(Parameter[T]):
     """Parameter that is a lower triangular matrix."""
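With `Static` gone, the remaining hunks in this diff (likelihoods, mean functions, variational families) drop it from their type unions and wrap raw values in the surviving `Parameter` subclasses instead. One plausible migration, shown only as a hedged sketch and not official upgrade guidance:

```python
# Hedged migration sketch for code that previously wrapped values in Static.
import jax.numpy as jnp
from gpjax.likelihoods import Gaussian

# 0.12.0 style (no longer available):
#     likelihood = Gaussian(num_datapoints=50, obs_stddev=Static(0.1))
# 0.13.0: pass the raw value; it is wrapped in NonNegativeReal internally.
likelihood = Gaussian(num_datapoints=50, obs_stddev=jnp.array(0.1))
```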
gpjax/variational_families.py
CHANGED
@@ -40,11 +40,11 @@ from gpjax.linalg import (
     psd,
     solve,
 )
+from gpjax.linalg.utils import add_jitter
 from gpjax.mean_functions import AbstractMeanFunction
 from gpjax.parameters import (
     LowerTriangular,
     Real,
-    Static,
 )
 from gpjax.typing import (
     Array,

@@ -110,11 +110,10 @@ class AbstractVariationalGaussian(AbstractVariationalFamily[L]):
         inducing_inputs: tp.Union[
             Float[Array, "N D"],
             Real,
-            Static,
         ],
         jitter: ScalarFloat = 1e-6,
     ):
-        if not isinstance(inducing_inputs,
+        if not isinstance(inducing_inputs, Real):
             inducing_inputs = Real(inducing_inputs)

         self.inducing_inputs = inducing_inputs

@@ -177,25 +176,31 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):
         approximation and the GP prior.
         """
         # Unpack variational parameters
-
-
-
+        variational_mean = self.variational_mean.value
+        variational_sqrt = self.variational_root_covariance.value
+        inducing_inputs = self.inducing_inputs.value

         # Unpack mean function and kernel
         mean_function = self.posterior.prior.mean_function
         kernel = self.posterior.prior.kernel

-
-        Kzz = kernel.gram(
-        Kzz = psd(Dense(Kzz.to_dense()
+        inducing_mean = mean_function(inducing_inputs)
+        Kzz = kernel.gram(inducing_inputs)
+        Kzz = psd(Dense(add_jitter(Kzz.to_dense(), self.jitter)))

-
-
+        variational_sqrt_triangular = Triangular(variational_sqrt)
+        variational_covariance = (
+            variational_sqrt_triangular @ variational_sqrt_triangular.T
+        )

-
-
+        q_inducing = GaussianDistribution(
+            loc=jnp.atleast_1d(variational_mean.squeeze()), scale=variational_covariance
+        )
+        p_inducing = GaussianDistribution(
+            loc=jnp.atleast_1d(inducing_mean.squeeze()), scale=Kzz
+        )

-        return
+        return q_inducing.kl_divergence(p_inducing)

     def predict(self, test_inputs: Float[Array, "N D"]) -> GaussianDistribution:
         r"""Compute the predictive distribution of the GP at the test inputs t.

@@ -215,26 +220,26 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):
         the test inputs.
         """
         # Unpack variational parameters
-
-
-
+        variational_mean = self.variational_mean.value
+        variational_sqrt = self.variational_root_covariance.value
+        inducing_inputs = self.inducing_inputs.value

         # Unpack mean function and kernel
         mean_function = self.posterior.prior.mean_function
         kernel = self.posterior.prior.kernel

-        Kzz = kernel.gram(
-        Kzz_dense = Kzz.to_dense()
+        Kzz = kernel.gram(inducing_inputs)
+        Kzz_dense = add_jitter(Kzz.to_dense(), self.jitter)
         Kzz = psd(Dense(Kzz_dense))
         Lz = lower_cholesky(Kzz)
-
+        inducing_mean = mean_function(inducing_inputs)

         # Unpack test inputs
-
+        test_points = test_inputs

-        Ktt = kernel.gram(
-        Kzt = kernel.cross_covariance(
-
+        Ktt = kernel.gram(test_points)
+        Kzt = kernel.cross_covariance(inducing_inputs, test_points)
+        test_mean = mean_function(test_points)

         # Lz⁻¹ Kzt
         Lz_inv_Kzt = solve(Lz, Kzt)

@@ -243,10 +248,10 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):
         Kzz_inv_Kzt = solve(Lz.T, Lz_inv_Kzt)

         # Ktz Kzz⁻¹ sqrt
-        Ktz_Kzz_inv_sqrt = jnp.matmul(Kzz_inv_Kzt.T,
+        Ktz_Kzz_inv_sqrt = jnp.matmul(Kzz_inv_Kzt.T, variational_sqrt)

         # μt + Ktz Kzz⁻¹ (μ - μz)
-        mean =
+        mean = test_mean + jnp.matmul(Kzz_inv_Kzt.T, variational_mean - inducing_mean)

         # Ktt - Ktz Kzz⁻¹ Kzt + Ktz Kzz⁻¹ S Kzz⁻¹ Kzt [recall S = sqrt sqrtᵀ]
         covariance = (

@@ -254,7 +259,10 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):
             - jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt)
             + jnp.matmul(Ktz_Kzz_inv_sqrt, Ktz_Kzz_inv_sqrt.T)
         )
-
+        if hasattr(covariance, "to_dense"):
+            covariance = covariance.to_dense()
+        covariance = add_jitter(covariance, self.jitter)
+        covariance = Dense(covariance)

         return GaussianDistribution(
             loc=jnp.atleast_1d(mean.squeeze()), scale=covariance

@@ -329,7 +337,7 @@ class WhitenedVariationalGaussian(VariationalGaussian[L]):
         kernel = self.posterior.prior.kernel

         Kzz = kernel.gram(z)
-        Kzz_dense = Kzz.to_dense()
+        Kzz_dense = add_jitter(Kzz.to_dense(), self.jitter)
         Kzz = psd(Dense(Kzz_dense))
         Lz = lower_cholesky(Kzz)


@@ -355,7 +363,10 @@ class WhitenedVariationalGaussian(VariationalGaussian[L]):
             - jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt)
             + jnp.matmul(Ktz_Lz_invT_sqrt, Ktz_Lz_invT_sqrt.T)
         )
-
+        if hasattr(covariance, "to_dense"):
+            covariance = covariance.to_dense()
+        covariance = add_jitter(covariance, self.jitter)
+        covariance = Dense(covariance)

         return GaussianDistribution(
             loc=jnp.atleast_1d(mean.squeeze()), scale=covariance

@@ -390,8 +401,8 @@ class NaturalVariationalGaussian(AbstractVariationalGaussian[L]):
         if natural_matrix is None:
             natural_matrix = -0.5 * jnp.eye(self.num_inducing)

-        self.natural_vector =
-        self.natural_matrix =
+        self.natural_vector = Real(natural_vector)
+        self.natural_matrix = Real(natural_matrix)

     def prior_kl(self) -> ScalarFloat:
         r"""Compute the KL-divergence between our current variational approximation

@@ -422,7 +433,7 @@ class NaturalVariationalGaussian(AbstractVariationalGaussian[L]):

         # S⁻¹ = -2θ₂
         S_inv = -2 * natural_matrix
-        S_inv
+        S_inv = add_jitter(S_inv, self.jitter)

         # Compute L⁻¹, where LLᵀ = S, via a trick found in the NumPyro source code and https://nbviewer.org/gist/fehiepsi/5ef8e09e61604f10607380467eb82006#Precision-to-scale_tril:
         sqrt_inv = jnp.swapaxes(

@@ -441,7 +452,7 @@ class NaturalVariationalGaussian(AbstractVariationalGaussian[L]):

         muz = mean_function(z)
         Kzz = kernel.gram(z)
-        Kzz_dense = Kzz.to_dense()
+        Kzz_dense = add_jitter(Kzz.to_dense(), self.jitter)
         Kzz = psd(Dense(Kzz_dense))

         qu = GaussianDistribution(loc=jnp.atleast_1d(mu.squeeze()), scale=S)

@@ -476,7 +487,7 @@ class NaturalVariationalGaussian(AbstractVariationalGaussian[L]):

         # S⁻¹ = -2θ₂
         S_inv = -2 * natural_matrix
-        S_inv
+        S_inv = add_jitter(S_inv, self.jitter)

         # Compute L⁻¹, where LLᵀ = S, via a trick found in the NumPyro source code and https://nbviewer.org/gist/fehiepsi/5ef8e09e61604f10607380467eb82006#Precision-to-scale_tril:
         sqrt_inv = jnp.swapaxes(

@@ -493,7 +504,7 @@ class NaturalVariationalGaussian(AbstractVariationalGaussian[L]):
         mu = jnp.matmul(S, natural_vector)

         Kzz = kernel.gram(z)
-        Kzz_dense = Kzz.to_dense()
+        Kzz_dense = add_jitter(Kzz.to_dense(), self.jitter)
         Kzz = psd(Dense(Kzz_dense))
         Lz = lower_cholesky(Kzz)
         muz = mean_function(z)

@@ -520,7 +531,10 @@ class NaturalVariationalGaussian(AbstractVariationalGaussian[L]):
             - jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt)
             + jnp.matmul(Ktz_Kzz_inv_L, Ktz_Kzz_inv_L.T)
         )
-
+        if hasattr(covariance, "to_dense"):
+            covariance = covariance.to_dense()
+        covariance = add_jitter(covariance, self.jitter)
+        covariance = Dense(covariance)

         return GaussianDistribution(
             loc=jnp.atleast_1d(mean.squeeze()), scale=covariance

@@ -556,8 +570,8 @@ class ExpectationVariationalGaussian(AbstractVariationalGaussian[L]):
         if expectation_matrix is None:
             expectation_matrix = jnp.eye(self.num_inducing)

-        self.expectation_vector =
-        self.expectation_matrix =
+        self.expectation_vector = Real(expectation_vector)
+        self.expectation_matrix = Real(expectation_matrix)

     def prior_kl(self) -> ScalarFloat:
         r"""Evaluate the prior KL-divergence.

@@ -595,12 +609,12 @@ class ExpectationVariationalGaussian(AbstractVariationalGaussian[L]):
         # S = η₂ - η₁ η₁ᵀ
         S = expectation_matrix - jnp.outer(mu, mu)
         S = psd(Dense(S))
-        S_dense = S.to_dense()
+        S_dense = add_jitter(S.to_dense(), self.jitter)
         S = psd(Dense(S_dense))

         muz = mean_function(z)
         Kzz = kernel.gram(z)
-        Kzz_dense = Kzz.to_dense()
+        Kzz_dense = add_jitter(Kzz.to_dense(), self.jitter)
         Kzz = psd(Dense(Kzz_dense))

         qu = GaussianDistribution(loc=jnp.atleast_1d(mu.squeeze()), scale=S)

@@ -640,14 +654,14 @@ class ExpectationVariationalGaussian(AbstractVariationalGaussian[L]):

         # S = η₂ - η₁ η₁ᵀ
         S = expectation_matrix - jnp.matmul(mu, mu.T)
-        S = Dense(
+        S = Dense(add_jitter(S, self.jitter))
         S = psd(S)

         # S = sqrt sqrtᵀ
         sqrt = lower_cholesky(S)

         Kzz = kernel.gram(z)
-        Kzz_dense = Kzz.to_dense()
+        Kzz_dense = add_jitter(Kzz.to_dense(), self.jitter)
         Kzz = psd(Dense(Kzz_dense))
         Lz = lower_cholesky(Kzz)
         muz = mean_function(z)

@@ -677,7 +691,10 @@ class ExpectationVariationalGaussian(AbstractVariationalGaussian[L]):
             - jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt)
             + jnp.matmul(Ktz_Kzz_inv_sqrt, Ktz_Kzz_inv_sqrt.T)
         )
-
+        if hasattr(covariance, "to_dense"):
+            covariance = covariance.to_dense()
+        covariance = add_jitter(covariance, self.jitter)
+        covariance = Dense(covariance)

         return GaussianDistribution(
             loc=jnp.atleast_1d(mean.squeeze()), scale=covariance

@@ -734,7 +751,7 @@ class CollapsedVariationalGaussian(AbstractVariationalGaussian[GL]):

         Kzx = kernel.cross_covariance(z, x)
         Kzz = kernel.gram(z)
-        Kzz_dense = Kzz.to_dense()
+        Kzz_dense = add_jitter(Kzz.to_dense(), self.jitter)
         Kzz = psd(Dense(Kzz_dense))

         # Lz Lzᵀ = Kzz

@@ -780,7 +797,10 @@ class CollapsedVariationalGaussian(AbstractVariationalGaussian[GL]):
             - jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt)
             + jnp.matmul(L_inv_Lz_inv_Kzt.T, L_inv_Lz_inv_Kzt)
        )
-
+        if hasattr(covariance, "to_dense"):
+            covariance = covariance.to_dense()
+        covariance = add_jitter(covariance, self.jitter)
+        covariance = Dense(covariance)

         return GaussianDistribution(
             loc=jnp.atleast_1d(mean.squeeze()), scale=covariance
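The hunks above repeat one pattern: every dense Kzz (and S) passes through `add_jitter` before being wrapped in `psd(Dense(...))`, and each predictive covariance is densified, jittered, and re-wrapped as `Dense` before constructing the `GaussianDistribution`. A hedged sketch of that pattern, assuming `Dense` and `psd` are importable from `gpjax.linalg` as the surrounding hunks suggest:

```python
# Hedged sketch of the recurring stabilisation pattern; `kernel` and `z` are
# placeholders, not a verified public API.
from gpjax.linalg import Dense, psd
from gpjax.linalg.utils import add_jitter

def stabilised_gram(kernel, z, jitter=1e-6):
    Kzz = kernel.gram(z)                            # LinearOperator
    Kzz_dense = add_jitter(Kzz.to_dense(), jitter)  # jitter on the diagonal
    return psd(Dense(Kzz_dense))                    # PSD-annotated dense operator
```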
{gpjax-0.12.0.dist-info → gpjax-0.13.0.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: gpjax
-Version: 0.12.0
+Version: 0.13.0
 Summary: Gaussian processes in JAX.
 Project-URL: Documentation, https://docs.jaxgaussianprocesses.com/
 Project-URL: Issues, https://github.com/JaxGaussianProcesses/GPJax/issues

@@ -11,15 +11,14 @@ License-File: LICENSE.txt
 Keywords: gaussian-processes jax machine-learning bayesian
 Classifier: Development Status :: 4 - Beta
 Classifier: Programming Language :: Python
-Classifier: Programming Language :: Python :: 3.10
 Classifier: Programming Language :: Python :: 3.11
 Classifier: Programming Language :: Python :: 3.12
 Classifier: Programming Language :: Python :: 3.13
 Classifier: Programming Language :: Python :: Implementation :: CPython
 Classifier: Programming Language :: Python :: Implementation :: PyPy
-Requires-Python:
+Requires-Python: >=3.11
 Requires-Dist: beartype>0.16.1
-Requires-Dist: flax>=0.
+Requires-Dist: flax>=0.12.0
 Requires-Dist: jax>=0.5.0
 Requires-Dist: jaxlib>=0.5.0
 Requires-Dist: jaxtyping>0.2.10

@@ -60,7 +59,7 @@ Requires-Dist: mkdocs-jupyter>=0.24.3; extra == 'docs'
 Requires-Dist: mkdocs-literate-nav>=0.6.0; extra == 'docs'
 Requires-Dist: mkdocs-material>=9.5.12; extra == 'docs'
 Requires-Dist: mkdocs>=1.5.3; extra == 'docs'
-Requires-Dist: mkdocstrings[python]<0.
+Requires-Dist: mkdocstrings[python]<0.31.0; extra == 'docs'
 Requires-Dist: nbconvert>=7.16.2; extra == 'docs'
 Requires-Dist: networkx>=3.0; extra == 'docs'
 Requires-Dist: pandas>=1.5.3; extra == 'docs'

@@ -80,6 +79,7 @@ Description-Content-Type: text/markdown
 [](https://www.codefactor.io/repository/github/jaxgaussianprocesses/gpjax)
 [](https://app.netlify.com/sites/endearing-crepe-c2d5fe/deploys)
 [](https://badge.fury.io/py/GPJax)
+[](https://anaconda.org/conda-forge/gpjax)
 [](https://doi.org/10.21105/joss.04455)
 [](https://pepy.tech/project/gpjax)
 [](https://join.slack.com/t/gpjax/shared_invite/zt-3cesiykcx-nzajjRdnV3ohw7~~eMlCYA)

@@ -126,18 +126,9 @@ Channel](https://join.slack.com/t/gpjax/shared_invite/zt-3cesiykcx-nzajjRdnV3ohw
 where we can discuss the development of GPJax and broader support for Gaussian
 process modelling.

-
-
-
-GPJax was founded by [Thomas Pinder](https://github.com/thomaspinder). Today, the
-project's gardeners are [daniel-dodd@](https://github.com/daniel-dodd),
-[henrymoss@](https://github.com/henrymoss), [st--@](https://github.com/st--), and
-[thomaspinder@](https://github.com/thomaspinder), listed in alphabetical order. The full
-governance structure of GPJax is detailed [here](docs/GOVERNANCE.md). We appreciate all
-[the contributors to
-GPJax](https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors) who have
-helped to shape GPJax into the package it is today.
-
+We appreciate all [the contributors to
+GPJax](https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors) who have helped to shape
+GPJax into the package it is today.

 # Supported methods and interfaces

@@ -183,13 +174,21 @@ jupytext --to py:percent example.ipynb

 ## Stable version

-The latest stable version of GPJax can be installed
-pip:
+The latest stable version of GPJax can be installed from [PyPI](https://pypi.org/project/gpjax/):

 ```bash
 pip install gpjax
 ```

+or from [conda-forge](https://github.com/conda-forge/gpjax-feedstock):
+
+```bash
+# with Pixi
+pixi add gpjax
+# or with conda
+conda install --channel conda-forge gpjax
+```
+
 > **Note**
 >
 > We recommend you check your installation version:

@@ -208,7 +207,7 @@ pip install gpjax
 >
 > We advise you create virtual environment before installing:
 > ```
-> conda create -n gpjax_experimental python=3.
+> conda create -n gpjax_experimental python=3.11.0
 > conda activate gpjax_experimental
 > ```

@@ -218,13 +217,14 @@ configuration in development mode.
 ```bash
 git clone https://github.com/JaxGaussianProcesses/GPJax.git
 cd GPJax
+uv venv
 uv sync --extra dev
 ```

 > We recommend you check your installation passes the supplied unit tests:
 >
 > ```python
-> uv run
+> uv run poe all-tests
 > ```

 # Citing GPJax