gpjax 0.12.0__py3-none-any.whl → 0.12.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gpjax/__init__.py +1 -4
- gpjax/fit.py +11 -6
- gpjax/gps.py +36 -34
- gpjax/kernels/approximations/rff.py +2 -5
- gpjax/kernels/base.py +1 -4
- gpjax/kernels/computations/basis_functions.py +1 -1
- gpjax/kernels/computations/eigen.py +1 -1
- gpjax/kernels/non_euclidean/graph.py +10 -11
- gpjax/kernels/nonstationary/arccosine.py +13 -21
- gpjax/kernels/nonstationary/polynomial.py +7 -8
- gpjax/kernels/stationary/periodic.py +3 -6
- gpjax/kernels/stationary/powered_exponential.py +3 -8
- gpjax/kernels/stationary/rational_quadratic.py +5 -8
- gpjax/likelihoods.py +11 -14
- gpjax/linalg/utils.py +32 -0
- gpjax/mean_functions.py +8 -7
- gpjax/objectives.py +4 -3
- gpjax/parameters.py +0 -10
- gpjax/variational_families.py +65 -45
- {gpjax-0.12.0.dist-info → gpjax-0.12.2.dist-info}/METADATA +9 -17
- {gpjax-0.12.0.dist-info → gpjax-0.12.2.dist-info}/RECORD +23 -23
- {gpjax-0.12.0.dist-info → gpjax-0.12.2.dist-info}/WHEEL +0 -0
- {gpjax-0.12.0.dist-info → gpjax-0.12.2.dist-info}/licenses/LICENSE.txt +0 -0
gpjax/__init__.py
CHANGED
|
@@ -40,10 +40,9 @@ __license__ = "MIT"
|
|
|
40
40
|
__description__ = "Gaussian processes in JAX and Flax"
|
|
41
41
|
__url__ = "https://github.com/JaxGaussianProcesses/GPJax"
|
|
42
42
|
__contributors__ = "https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors"
|
|
43
|
-
__version__ = "0.12.0"
|
|
43
|
+
__version__ = "0.12.2"
|
|
44
44
|
|
|
45
45
|
__all__ = [
|
|
46
|
-
"base",
|
|
47
46
|
"gps",
|
|
48
47
|
"integrators",
|
|
49
48
|
"kernels",
|
|
@@ -55,8 +54,6 @@ __all__ = [
|
|
|
55
54
|
"Dataset",
|
|
56
55
|
"cite",
|
|
57
56
|
"fit",
|
|
58
|
-
"Module",
|
|
59
|
-
"param_field",
|
|
60
57
|
"fit_lbfgs",
|
|
61
58
|
"fit_scipy",
|
|
62
59
|
]
|
gpjax/fit.py
CHANGED
|
@@ -48,6 +48,7 @@ def fit( # noqa: PLR0913
|
|
|
48
48
|
train_data: Dataset,
|
|
49
49
|
optim: ox.GradientTransformation,
|
|
50
50
|
params_bijection: tp.Union[dict[Parameter, Transform], None] = DEFAULT_BIJECTION,
|
|
51
|
+
trainable: nnx.filterlib.Filter = Parameter,
|
|
51
52
|
key: KeyArray = jr.PRNGKey(42),
|
|
52
53
|
num_iters: int = 100,
|
|
53
54
|
batch_size: int = -1,
|
|
@@ -65,7 +66,7 @@ def fit( # noqa: PLR0913
|
|
|
65
66
|
>>> import jax.random as jr
|
|
66
67
|
>>> import optax as ox
|
|
67
68
|
>>> import gpjax as gpx
|
|
68
|
-
>>> from gpjax.parameters import PositiveReal
|
|
69
|
+
>>> from gpjax.parameters import PositiveReal
|
|
69
70
|
>>>
|
|
70
71
|
>>> # (1) Create a dataset:
|
|
71
72
|
>>> X = jnp.linspace(0.0, 10.0, 100)[:, None]
|
|
@@ -75,10 +76,10 @@ def fit( # noqa: PLR0913
|
|
|
75
76
|
>>> class LinearModel(nnx.Module):
|
|
76
77
|
>>> def __init__(self, weight: float, bias: float):
|
|
77
78
|
>>> self.weight = PositiveReal(weight)
|
|
78
|
-
>>> self.bias =
|
|
79
|
+
>>> self.bias = bias
|
|
79
80
|
>>>
|
|
80
81
|
>>> def __call__(self, x):
|
|
81
|
-
>>> return self.weight.value * x + self.bias
|
|
82
|
+
>>> return self.weight.value * x + self.bias
|
|
82
83
|
>>>
|
|
83
84
|
>>> model = LinearModel(weight=1.0, bias=1.0)
|
|
84
85
|
>>>
|
|
@@ -100,6 +101,8 @@ def fit( # noqa: PLR0913
|
|
|
100
101
|
train_data (Dataset): The training data to be used for the optimisation.
|
|
101
102
|
optim (GradientTransformation): The Optax optimiser that is to be used for
|
|
102
103
|
learning a parameter set.
|
|
104
|
+
trainable (nnx.filterlib.Filter): Filter to determine which parameters are trainable.
|
|
105
|
+
Defaults to nnx.Param (all Parameter instances).
|
|
103
106
|
num_iters (int): The number of optimisation steps to run. Defaults
|
|
104
107
|
to 100.
|
|
105
108
|
batch_size (int): The size of the mini-batch to use. Defaults to -1
|
|
@@ -127,7 +130,7 @@ def fit( # noqa: PLR0913
|
|
|
127
130
|
_check_verbose(verbose)
|
|
128
131
|
|
|
129
132
|
# Model state filtering
|
|
130
|
-
graphdef, params, *static_state = nnx.split(model,
|
|
133
|
+
graphdef, params, *static_state = nnx.split(model, trainable, ...)
|
|
131
134
|
|
|
132
135
|
# Parameters bijection to unconstrained space
|
|
133
136
|
if params_bijection is not None:
|
|
@@ -182,6 +185,7 @@ def fit_scipy( # noqa: PLR0913
|
|
|
182
185
|
model: Model,
|
|
183
186
|
objective: Objective,
|
|
184
187
|
train_data: Dataset,
|
|
188
|
+
trainable: nnx.filterlib.Filter = Parameter,
|
|
185
189
|
max_iters: int = 500,
|
|
186
190
|
verbose: bool = True,
|
|
187
191
|
safe: bool = True,
|
|
@@ -210,7 +214,7 @@ def fit_scipy( # noqa: PLR0913
|
|
|
210
214
|
_check_verbose(verbose)
|
|
211
215
|
|
|
212
216
|
# Model state filtering
|
|
213
|
-
graphdef, params, *static_state = nnx.split(model,
|
|
217
|
+
graphdef, params, *static_state = nnx.split(model, trainable, ...)
|
|
214
218
|
|
|
215
219
|
# Parameters bijection to unconstrained space
|
|
216
220
|
params = transform(params, DEFAULT_BIJECTION, inverse=True)
|
|
@@ -258,6 +262,7 @@ def fit_lbfgs(
|
|
|
258
262
|
objective: Objective,
|
|
259
263
|
train_data: Dataset,
|
|
260
264
|
params_bijection: tp.Union[dict[Parameter, Transform], None] = DEFAULT_BIJECTION,
|
|
265
|
+
trainable: nnx.filterlib.Filter = Parameter,
|
|
261
266
|
max_iters: int = 100,
|
|
262
267
|
safe: bool = True,
|
|
263
268
|
max_linesearch_steps: int = 32,
|
|
@@ -290,7 +295,7 @@ def fit_lbfgs(
|
|
|
290
295
|
_check_num_iters(max_iters)
|
|
291
296
|
|
|
292
297
|
# Model state filtering
|
|
293
|
-
graphdef, params, *static_state = nnx.split(model,
|
|
298
|
+
graphdef, params, *static_state = nnx.split(model, trainable, ...)
|
|
294
299
|
|
|
295
300
|
# Parameters bijection to unconstrained space
|
|
296
301
|
if params_bijection is not None:
|
gpjax/gps.py
CHANGED
|
@@ -16,9 +16,9 @@
|
|
|
16
16
|
from abc import abstractmethod
|
|
17
17
|
|
|
18
18
|
import beartype.typing as tp
|
|
19
|
-
from flax import nnx
|
|
20
19
|
import jax.numpy as jnp
|
|
21
20
|
import jax.random as jr
|
|
21
|
+
from flax import nnx
|
|
22
22
|
from jaxtyping import (
|
|
23
23
|
Float,
|
|
24
24
|
Num,
|
|
@@ -35,16 +35,15 @@ from gpjax.likelihoods import (
|
|
|
35
35
|
)
|
|
36
36
|
from gpjax.linalg import (
|
|
37
37
|
Dense,
|
|
38
|
-
Identity,
|
|
39
38
|
psd,
|
|
40
39
|
solve,
|
|
41
40
|
)
|
|
42
41
|
from gpjax.linalg.operations import lower_cholesky
|
|
42
|
+
from gpjax.linalg.utils import add_jitter
|
|
43
43
|
from gpjax.mean_functions import AbstractMeanFunction
|
|
44
44
|
from gpjax.parameters import (
|
|
45
45
|
Parameter,
|
|
46
46
|
Real,
|
|
47
|
-
Static,
|
|
48
47
|
)
|
|
49
48
|
from gpjax.typing import (
|
|
50
49
|
Array,
|
|
@@ -78,7 +77,7 @@ class AbstractPrior(nnx.Module, tp.Generic[M, K]):
|
|
|
78
77
|
self.mean_function = mean_function
|
|
79
78
|
self.jitter = jitter
|
|
80
79
|
|
|
81
|
-
def __call__(self,
|
|
80
|
+
def __call__(self, test_inputs: Num[Array, "N D"]) -> GaussianDistribution:
|
|
82
81
|
r"""Evaluate the Gaussian process at the given points.
|
|
83
82
|
|
|
84
83
|
The output of this function is a
|
|
@@ -91,17 +90,16 @@ class AbstractPrior(nnx.Module, tp.Generic[M, K]):
|
|
|
91
90
|
`__call__` method and should instead define a `predict` method.
|
|
92
91
|
|
|
93
92
|
Args:
|
|
94
|
-
|
|
95
|
-
**kwargs (Any): The keyword arguments to pass to the GP's `predict` method.
|
|
93
|
+
test_inputs: Input locations where the GP should be evaluated.
|
|
96
94
|
|
|
97
95
|
Returns:
|
|
98
96
|
GaussianDistribution: A multivariate normal random variable representation
|
|
99
97
|
of the Gaussian process.
|
|
100
98
|
"""
|
|
101
|
-
return self.predict(
|
|
99
|
+
return self.predict(test_inputs)
|
|
102
100
|
|
|
103
101
|
@abstractmethod
|
|
104
|
-
def predict(self,
|
|
102
|
+
def predict(self, test_inputs: Num[Array, "N D"]) -> GaussianDistribution:
|
|
105
103
|
r"""Evaluate the predictive distribution.
|
|
106
104
|
|
|
107
105
|
Compute the latent function's multivariate normal distribution for a
|
|
@@ -109,8 +107,7 @@ class AbstractPrior(nnx.Module, tp.Generic[M, K]):
|
|
|
109
107
|
this method must be implemented.
|
|
110
108
|
|
|
111
109
|
Args:
|
|
112
|
-
|
|
113
|
-
**kwargs (Any): Keyword arguments to the predict method.
|
|
110
|
+
test_inputs: Input locations where the GP should be evaluated.
|
|
114
111
|
|
|
115
112
|
Returns:
|
|
116
113
|
GaussianDistribution: A multivariate normal random variable representation
|
|
@@ -249,13 +246,12 @@ class Prior(AbstractPrior[M, K]):
|
|
|
249
246
|
GaussianDistribution: A multivariate normal random variable representation
|
|
250
247
|
of the Gaussian process.
|
|
251
248
|
"""
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
Kxx_dense = Kxx.to_dense() + Identity(Kxx.shape).to_dense() * self.jitter
|
|
249
|
+
mean_at_test = self.mean_function(test_inputs)
|
|
250
|
+
Kxx = self.kernel.gram(test_inputs)
|
|
251
|
+
Kxx_dense = add_jitter(Kxx.to_dense(), self.jitter)
|
|
256
252
|
Kxx = psd(Dense(Kxx_dense))
|
|
257
253
|
|
|
258
|
-
return GaussianDistribution(jnp.atleast_1d(
|
|
254
|
+
return GaussianDistribution(jnp.atleast_1d(mean_at_test.squeeze()), Kxx)
|
|
259
255
|
|
|
260
256
|
def sample_approx(
|
|
261
257
|
self,
|
|
@@ -359,7 +355,9 @@ class AbstractPosterior(nnx.Module, tp.Generic[P, L]):
|
|
|
359
355
|
self.likelihood = likelihood
|
|
360
356
|
self.jitter = jitter
|
|
361
357
|
|
|
362
|
-
def __call__(
|
|
358
|
+
def __call__(
|
|
359
|
+
self, test_inputs: Num[Array, "N D"], train_data: Dataset
|
|
360
|
+
) -> GaussianDistribution:
|
|
363
361
|
r"""Evaluate the Gaussian process posterior at the given points.
|
|
364
362
|
|
|
365
363
|
The output of this function is a
|
|
@@ -368,28 +366,30 @@ class AbstractPosterior(nnx.Module, tp.Generic[P, L]):
|
|
|
368
366
|
evaluated and the distribution can be sampled.
|
|
369
367
|
|
|
370
368
|
Under the hood, `__call__` is calling the objects `predict` method. For this
|
|
371
|
-
reasons, classes inheriting the `
|
|
369
|
+
reasons, classes inheriting the `AbstractPosterior` class, should not overwrite the
|
|
372
370
|
`__call__` method and should instead define a `predict` method.
|
|
373
371
|
|
|
374
372
|
Args:
|
|
375
|
-
|
|
376
|
-
|
|
373
|
+
test_inputs: Input locations where the GP should be evaluated.
|
|
374
|
+
train_data: Training dataset to condition on.
|
|
377
375
|
|
|
378
376
|
Returns:
|
|
379
377
|
GaussianDistribution: A multivariate normal random variable representation
|
|
380
378
|
of the Gaussian process.
|
|
381
379
|
"""
|
|
382
|
-
return self.predict(
|
|
380
|
+
return self.predict(test_inputs, train_data)
|
|
383
381
|
|
|
384
382
|
@abstractmethod
|
|
385
|
-
def predict(
|
|
383
|
+
def predict(
|
|
384
|
+
self, test_inputs: Num[Array, "N D"], train_data: Dataset
|
|
385
|
+
) -> GaussianDistribution:
|
|
386
386
|
r"""Compute the latent function's multivariate normal distribution for a
|
|
387
|
-
given set of parameters. For any class inheriting the `
|
|
387
|
+
given set of parameters. For any class inheriting the `AbstractPosterior` class,
|
|
388
388
|
this method must be implemented.
|
|
389
389
|
|
|
390
390
|
Args:
|
|
391
|
-
|
|
392
|
-
|
|
391
|
+
test_inputs: Input locations where the GP should be evaluated.
|
|
392
|
+
train_data: Training dataset to condition on.
|
|
393
393
|
|
|
394
394
|
Returns:
|
|
395
395
|
GaussianDistribution: A multivariate normal random variable representation
|
|
@@ -503,22 +503,24 @@ class ConjugatePosterior(AbstractPosterior[P, GL]):
|
|
|
503
503
|
|
|
504
504
|
# Precompute Gram matrix, Kxx, at training inputs, x
|
|
505
505
|
Kxx = self.prior.kernel.gram(x)
|
|
506
|
-
Kxx_dense =
|
|
506
|
+
Kxx_dense = add_jitter(Kxx.to_dense(), self.jitter)
|
|
507
507
|
Kxx = Dense(Kxx_dense)
|
|
508
508
|
|
|
509
509
|
Sigma_dense = Kxx.to_dense() + jnp.eye(Kxx.shape[0]) * obs_noise
|
|
510
510
|
Sigma = psd(Dense(Sigma_dense))
|
|
511
|
+
L_sigma = lower_cholesky(Sigma)
|
|
511
512
|
|
|
512
513
|
mean_t = self.prior.mean_function(t)
|
|
513
514
|
Ktt = self.prior.kernel.gram(t)
|
|
514
515
|
Kxt = self.prior.kernel.cross_covariance(x, t)
|
|
515
|
-
Sigma_inv_Kxt = solve(Sigma, Kxt)
|
|
516
516
|
|
|
517
|
-
|
|
517
|
+
L_inv_Kxt = solve(L_sigma, Kxt)
|
|
518
|
+
L_inv_y_diff = solve(L_sigma, y - mx)
|
|
519
|
+
|
|
520
|
+
mean = mean_t + jnp.matmul(L_inv_Kxt.T, L_inv_y_diff)
|
|
518
521
|
|
|
519
|
-
|
|
520
|
-
covariance =
|
|
521
|
-
covariance += jnp.eye(covariance.shape[0]) * self.prior.jitter
|
|
522
|
+
covariance = Ktt.to_dense() - jnp.matmul(L_inv_Kxt.T, L_inv_Kxt)
|
|
523
|
+
covariance = add_jitter(covariance, self.prior.jitter)
|
|
522
524
|
covariance = psd(Dense(covariance))
|
|
523
525
|
|
|
524
526
|
return GaussianDistribution(jnp.atleast_1d(mean.squeeze()), covariance)
|
|
@@ -577,7 +579,7 @@ class ConjugatePosterior(AbstractPosterior[P, GL]):
|
|
|
577
579
|
|
|
578
580
|
obs_var = self.likelihood.obs_stddev.value**2
|
|
579
581
|
Kxx = self.prior.kernel.gram(train_data.X)
|
|
580
|
-
Sigma = Kxx
|
|
582
|
+
Sigma = Dense(add_jitter(Kxx.to_dense(), obs_var + self.jitter))
|
|
581
583
|
eps = jnp.sqrt(obs_var) * jr.normal(key, [train_data.n, num_samples])
|
|
582
584
|
y = train_data.y - self.prior.mean_function(train_data.X)
|
|
583
585
|
Phi = fourier_feature_fn(train_data.X)
|
|
@@ -643,7 +645,7 @@ class NonConjugatePosterior(AbstractPosterior[P, NGL]):
|
|
|
643
645
|
|
|
644
646
|
# TODO: static or intermediate?
|
|
645
647
|
self.latent = latent if isinstance(latent, Parameter) else Real(latent)
|
|
646
|
-
self.key =
|
|
648
|
+
self.key = key
|
|
647
649
|
|
|
648
650
|
def predict(
|
|
649
651
|
self, test_inputs: Num[Array, "N D"], train_data: Dataset
|
|
@@ -675,7 +677,7 @@ class NonConjugatePosterior(AbstractPosterior[P, NGL]):
|
|
|
675
677
|
|
|
676
678
|
# Precompute lower triangular of Gram matrix, Lx, at training inputs, x
|
|
677
679
|
Kxx = kernel.gram(x)
|
|
678
|
-
Kxx_dense = Kxx.to_dense()
|
|
680
|
+
Kxx_dense = add_jitter(Kxx.to_dense(), self.prior.jitter)
|
|
679
681
|
Kxx = psd(Dense(Kxx_dense))
|
|
680
682
|
Lx = lower_cholesky(Kxx)
|
|
681
683
|
|
|
@@ -698,7 +700,7 @@ class NonConjugatePosterior(AbstractPosterior[P, NGL]):
|
|
|
698
700
|
|
|
699
701
|
# Ktt - Ktx Kxx⁻¹ Kxt, TODO: Take advantage of covariance structure to compute Schur complement more efficiently.
|
|
700
702
|
covariance = Ktt.to_dense() - jnp.matmul(Lx_inv_Kxt.T, Lx_inv_Kxt)
|
|
701
|
-
covariance
|
|
703
|
+
covariance = add_jitter(covariance, self.prior.jitter)
|
|
702
704
|
covariance = psd(Dense(covariance))
|
|
703
705
|
|
|
704
706
|
return GaussianDistribution(jnp.atleast_1d(mean.squeeze()), covariance)
|
|
@@ -7,7 +7,6 @@ from jaxtyping import Float
|
|
|
7
7
|
from gpjax.kernels.base import AbstractKernel
|
|
8
8
|
from gpjax.kernels.computations import BasisFunctionComputation
|
|
9
9
|
from gpjax.kernels.stationary.base import StationaryKernel
|
|
10
|
-
from gpjax.parameters import Static
|
|
11
10
|
from gpjax.typing import (
|
|
12
11
|
Array,
|
|
13
12
|
KeyArray,
|
|
@@ -66,10 +65,8 @@ class RFF(AbstractKernel):
|
|
|
66
65
|
"Please specify the n_dims argument for the base kernel."
|
|
67
66
|
)
|
|
68
67
|
|
|
69
|
-
self.frequencies =
|
|
70
|
-
self.
|
|
71
|
-
key=key, sample_shape=(self.num_basis_fns, n_dims)
|
|
72
|
-
)
|
|
68
|
+
self.frequencies = self.base_kernel.spectral_density.sample(
|
|
69
|
+
key=key, sample_shape=(self.num_basis_fns, n_dims)
|
|
73
70
|
)
|
|
74
71
|
self.name = f"{self.base_kernel.name} (RFF)"
|
|
75
72
|
|
gpjax/kernels/base.py
CHANGED
|
@@ -32,7 +32,6 @@ from gpjax.linalg import LinearOperator
|
|
|
32
32
|
from gpjax.parameters import (
|
|
33
33
|
Parameter,
|
|
34
34
|
Real,
|
|
35
|
-
Static,
|
|
36
35
|
)
|
|
37
36
|
from gpjax.typing import (
|
|
38
37
|
Array,
|
|
@@ -221,9 +220,7 @@ class Constant(AbstractKernel):
|
|
|
221
220
|
def __init__(
|
|
222
221
|
self,
|
|
223
222
|
active_dims: tp.Union[list[int], slice, None] = None,
|
|
224
|
-
constant: tp.Union[
|
|
225
|
-
ScalarFloat, Parameter[ScalarFloat], Static[ScalarFloat]
|
|
226
|
-
] = jnp.array(0.0),
|
|
223
|
+
constant: tp.Union[ScalarFloat, Parameter[ScalarFloat]] = jnp.array(0.0),
|
|
227
224
|
compute_engine: AbstractKernelComputation = DenseKernelComputation(),
|
|
228
225
|
):
|
|
229
226
|
if isinstance(constant, Parameter):
|
|
@@ -57,7 +57,7 @@ class BasisFunctionComputation(AbstractKernelComputation):
|
|
|
57
57
|
Returns:
|
|
58
58
|
A matrix of shape $N \times L$ representing the random fourier features where $L = 2M$.
|
|
59
59
|
"""
|
|
60
|
-
frequencies = kernel.frequencies
|
|
60
|
+
frequencies = kernel.frequencies
|
|
61
61
|
scaling_factor = kernel.base_kernel.lengthscale.value
|
|
62
62
|
z = jnp.matmul(x, (frequencies / scaling_factor).T)
|
|
63
63
|
z = jnp.concatenate([jnp.cos(z), jnp.sin(z)], axis=-1)
|
|
@@ -42,7 +42,7 @@ class EigenKernelComputation(AbstractKernelComputation):
|
|
|
42
42
|
# Transform the eigenvalues of the graph Laplacian according to the
|
|
43
43
|
# RBF kernel's SPDE form.
|
|
44
44
|
S = jnp.power(
|
|
45
|
-
kernel.eigenvalues
|
|
45
|
+
kernel.eigenvalues
|
|
46
46
|
+ 2
|
|
47
47
|
* kernel.smoothness.value
|
|
48
48
|
/ kernel.lengthscale.value
|
|
@@ -30,7 +30,6 @@ from gpjax.kernels.stationary.base import StationaryKernel
|
|
|
30
30
|
from gpjax.parameters import (
|
|
31
31
|
Parameter,
|
|
32
32
|
PositiveReal,
|
|
33
|
-
Static,
|
|
34
33
|
)
|
|
35
34
|
from gpjax.typing import (
|
|
36
35
|
Array,
|
|
@@ -55,9 +54,9 @@ class GraphKernel(StationaryKernel):
|
|
|
55
54
|
"""
|
|
56
55
|
|
|
57
56
|
num_vertex: tp.Union[ScalarInt, None]
|
|
58
|
-
laplacian:
|
|
59
|
-
eigenvalues:
|
|
60
|
-
eigenvectors:
|
|
57
|
+
laplacian: Float[Array, "N N"]
|
|
58
|
+
eigenvalues: Float[Array, "N 1"]
|
|
59
|
+
eigenvectors: Float[Array, "N N"]
|
|
61
60
|
name: str = "Graph Matérn"
|
|
62
61
|
|
|
63
62
|
def __init__(
|
|
@@ -91,11 +90,11 @@ class GraphKernel(StationaryKernel):
|
|
|
91
90
|
else:
|
|
92
91
|
self.smoothness = PositiveReal(smoothness)
|
|
93
92
|
|
|
94
|
-
self.laplacian =
|
|
95
|
-
evals, eigenvectors = jnp.linalg.eigh(self.laplacian
|
|
96
|
-
self.eigenvectors =
|
|
97
|
-
self.eigenvalues =
|
|
98
|
-
self.num_vertex = self.eigenvalues.
|
|
93
|
+
self.laplacian = laplacian
|
|
94
|
+
evals, eigenvectors = jnp.linalg.eigh(self.laplacian)
|
|
95
|
+
self.eigenvectors = eigenvectors
|
|
96
|
+
self.eigenvalues = evals.reshape(-1, 1)
|
|
97
|
+
self.num_vertex = self.eigenvalues.shape[0]
|
|
99
98
|
|
|
100
99
|
super().__init__(active_dims, lengthscale, variance, n_dims, compute_engine)
|
|
101
100
|
|
|
@@ -107,7 +106,7 @@ class GraphKernel(StationaryKernel):
|
|
|
107
106
|
S,
|
|
108
107
|
**kwargs,
|
|
109
108
|
):
|
|
110
|
-
Kxx = (jax_gather_nd(self.eigenvectors
|
|
111
|
-
jax_gather_nd(self.eigenvectors
|
|
109
|
+
Kxx = (jax_gather_nd(self.eigenvectors, x) * S.squeeze()) @ jnp.transpose(
|
|
110
|
+
jax_gather_nd(self.eigenvectors, y)
|
|
112
111
|
) # shape (n,n)
|
|
113
112
|
return Kxx.squeeze()
|
|
@@ -25,7 +25,6 @@ from gpjax.kernels.computations import (
|
|
|
25
25
|
)
|
|
26
26
|
from gpjax.parameters import (
|
|
27
27
|
NonNegativeReal,
|
|
28
|
-
PositiveReal,
|
|
29
28
|
)
|
|
30
29
|
from gpjax.typing import (
|
|
31
30
|
Array,
|
|
@@ -82,30 +81,13 @@ class ArcCosine(AbstractKernel):
|
|
|
82
81
|
|
|
83
82
|
self.order = order
|
|
84
83
|
|
|
85
|
-
|
|
86
|
-
self.weight_variance = weight_variance
|
|
87
|
-
else:
|
|
88
|
-
self.weight_variance = PositiveReal(weight_variance)
|
|
89
|
-
if tp.TYPE_CHECKING:
|
|
90
|
-
self.weight_variance = tp.cast(
|
|
91
|
-
PositiveReal[WeightVariance], self.weight_variance
|
|
92
|
-
)
|
|
84
|
+
self.weight_variance = weight_variance
|
|
93
85
|
|
|
94
86
|
if isinstance(variance, nnx.Variable):
|
|
95
87
|
self.variance = variance
|
|
96
88
|
else:
|
|
97
89
|
self.variance = NonNegativeReal(variance)
|
|
98
|
-
|
|
99
|
-
self.variance = tp.cast(NonNegativeReal[ScalarArray], self.variance)
|
|
100
|
-
|
|
101
|
-
if isinstance(bias_variance, nnx.Variable):
|
|
102
|
-
self.bias_variance = bias_variance
|
|
103
|
-
else:
|
|
104
|
-
self.bias_variance = PositiveReal(bias_variance)
|
|
105
|
-
if tp.TYPE_CHECKING:
|
|
106
|
-
self.bias_variance = tp.cast(
|
|
107
|
-
PositiveReal[ScalarArray], self.bias_variance
|
|
108
|
-
)
|
|
90
|
+
self.bias_variance = bias_variance
|
|
109
91
|
|
|
110
92
|
self.name = f"ArcCosine (order {self.order})"
|
|
111
93
|
|
|
@@ -141,7 +123,17 @@ class ArcCosine(AbstractKernel):
|
|
|
141
123
|
Returns:
|
|
142
124
|
ScalarFloat: The value of the weighted product between the two arguments``.
|
|
143
125
|
"""
|
|
144
|
-
|
|
126
|
+
weight_var = (
|
|
127
|
+
self.weight_variance.value
|
|
128
|
+
if hasattr(self.weight_variance, "value")
|
|
129
|
+
else self.weight_variance
|
|
130
|
+
)
|
|
131
|
+
bias_var = (
|
|
132
|
+
self.bias_variance.value
|
|
133
|
+
if hasattr(self.bias_variance, "value")
|
|
134
|
+
else self.bias_variance
|
|
135
|
+
)
|
|
136
|
+
return jnp.inner(weight_var * x, y) + bias_var
|
|
145
137
|
|
|
146
138
|
def _J(self, theta: ScalarFloat) -> ScalarFloat:
|
|
147
139
|
r"""Evaluate the angular dependency function corresponding to the desired order.
|
|
@@ -69,12 +69,9 @@ class Polynomial(AbstractKernel):
|
|
|
69
69
|
|
|
70
70
|
self.degree = degree
|
|
71
71
|
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
self.shift = PositiveReal(shift)
|
|
76
|
-
if tp.TYPE_CHECKING:
|
|
77
|
-
self.shift = tp.cast(PositiveReal[ScalarArray], self.shift)
|
|
72
|
+
self.shift = shift
|
|
73
|
+
if tp.TYPE_CHECKING and not isinstance(shift, nnx.Variable):
|
|
74
|
+
self.shift = tp.cast(PositiveReal[ScalarArray], self.shift)
|
|
78
75
|
|
|
79
76
|
if isinstance(variance, nnx.Variable):
|
|
80
77
|
self.variance = variance
|
|
@@ -88,7 +85,9 @@ class Polynomial(AbstractKernel):
|
|
|
88
85
|
def __call__(self, x: Float[Array, " D"], y: Float[Array, " D"]) -> ScalarFloat:
|
|
89
86
|
x = self.slice_input(x)
|
|
90
87
|
y = self.slice_input(y)
|
|
91
|
-
|
|
92
|
-
|
|
88
|
+
shift_val = self.shift.value if hasattr(self.shift, "value") else self.shift
|
|
89
|
+
variance_val = (
|
|
90
|
+
self.variance.value if hasattr(self.variance, "value") else self.variance
|
|
93
91
|
)
|
|
92
|
+
K = jnp.power(shift_val + variance_val * jnp.dot(x, y), self.degree)
|
|
94
93
|
return K.squeeze()
|
|
@@ -23,7 +23,6 @@ from gpjax.kernels.computations import (
|
|
|
23
23
|
DenseKernelComputation,
|
|
24
24
|
)
|
|
25
25
|
from gpjax.kernels.stationary.base import StationaryKernel
|
|
26
|
-
from gpjax.parameters import PositiveReal
|
|
27
26
|
from gpjax.typing import (
|
|
28
27
|
Array,
|
|
29
28
|
ScalarArray,
|
|
@@ -72,10 +71,7 @@ class Periodic(StationaryKernel):
|
|
|
72
71
|
covariance matrix.
|
|
73
72
|
"""
|
|
74
73
|
|
|
75
|
-
|
|
76
|
-
self.period = period
|
|
77
|
-
else:
|
|
78
|
-
self.period = PositiveReal(period)
|
|
74
|
+
self.period = period
|
|
79
75
|
|
|
80
76
|
super().__init__(active_dims, lengthscale, variance, n_dims, compute_engine)
|
|
81
77
|
|
|
@@ -84,8 +80,9 @@ class Periodic(StationaryKernel):
|
|
|
84
80
|
) -> Float[Array, ""]:
|
|
85
81
|
x = self.slice_input(x)
|
|
86
82
|
y = self.slice_input(y)
|
|
83
|
+
period_val = self.period.value if hasattr(self.period, "value") else self.period
|
|
87
84
|
sine_squared = (
|
|
88
|
-
jnp.sin(jnp.pi * (x - y) /
|
|
85
|
+
jnp.sin(jnp.pi * (x - y) / period_val) / self.lengthscale.value
|
|
89
86
|
) ** 2
|
|
90
87
|
K = self.variance.value * jnp.exp(-0.5 * jnp.sum(sine_squared, axis=0))
|
|
91
88
|
return K.squeeze()
|
|
@@ -24,7 +24,6 @@ from gpjax.kernels.computations import (
|
|
|
24
24
|
)
|
|
25
25
|
from gpjax.kernels.stationary.base import StationaryKernel
|
|
26
26
|
from gpjax.kernels.stationary.utils import euclidean_distance
|
|
27
|
-
from gpjax.parameters import SigmoidBounded
|
|
28
27
|
from gpjax.typing import (
|
|
29
28
|
Array,
|
|
30
29
|
ScalarArray,
|
|
@@ -76,10 +75,7 @@ class PoweredExponential(StationaryKernel):
|
|
|
76
75
|
compute_engine: the computation engine that the kernel uses to compute the
|
|
77
76
|
covariance matrix.
|
|
78
77
|
"""
|
|
79
|
-
|
|
80
|
-
self.power = power
|
|
81
|
-
else:
|
|
82
|
-
self.power = SigmoidBounded(power)
|
|
78
|
+
self.power = power
|
|
83
79
|
|
|
84
80
|
super().__init__(active_dims, lengthscale, variance, n_dims, compute_engine)
|
|
85
81
|
|
|
@@ -88,7 +84,6 @@ class PoweredExponential(StationaryKernel):
|
|
|
88
84
|
) -> Float[Array, ""]:
|
|
89
85
|
x = self.slice_input(x) / self.lengthscale.value
|
|
90
86
|
y = self.slice_input(y) / self.lengthscale.value
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
)
|
|
87
|
+
power_val = self.power.value if hasattr(self.power, "value") else self.power
|
|
88
|
+
K = self.variance.value * jnp.exp(-(euclidean_distance(x, y) ** power_val))
|
|
94
89
|
return K.squeeze()
|
|
@@ -23,7 +23,6 @@ from gpjax.kernels.computations import (
|
|
|
23
23
|
)
|
|
24
24
|
from gpjax.kernels.stationary.base import StationaryKernel
|
|
25
25
|
from gpjax.kernels.stationary.utils import squared_distance
|
|
26
|
-
from gpjax.parameters import PositiveReal
|
|
27
26
|
from gpjax.typing import (
|
|
28
27
|
Array,
|
|
29
28
|
ScalarArray,
|
|
@@ -70,17 +69,15 @@ class RationalQuadratic(StationaryKernel):
|
|
|
70
69
|
compute_engine: The computation engine that the kernel uses to compute the
|
|
71
70
|
covariance matrix.
|
|
72
71
|
"""
|
|
73
|
-
|
|
74
|
-
self.alpha = alpha
|
|
75
|
-
else:
|
|
76
|
-
self.alpha = PositiveReal(alpha)
|
|
72
|
+
self.alpha = alpha
|
|
77
73
|
|
|
78
74
|
super().__init__(active_dims, lengthscale, variance, n_dims, compute_engine)
|
|
79
75
|
|
|
80
76
|
def __call__(self, x: Float[Array, " D"], y: Float[Array, " D"]) -> ScalarFloat:
|
|
81
77
|
x = self.slice_input(x) / self.lengthscale.value
|
|
82
78
|
y = self.slice_input(y) / self.lengthscale.value
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
79
|
+
alpha_val = self.alpha.value if hasattr(self.alpha, "value") else self.alpha
|
|
80
|
+
K = self.variance.value * (1 + 0.5 * squared_distance(x, y) / alpha_val) ** (
|
|
81
|
+
-alpha_val
|
|
82
|
+
)
|
|
86
83
|
return K.squeeze()
|
gpjax/likelihoods.py
CHANGED
|
@@ -29,7 +29,6 @@ from gpjax.integrators import (
|
|
|
29
29
|
)
|
|
30
30
|
from gpjax.parameters import (
|
|
31
31
|
NonNegativeReal,
|
|
32
|
-
Static,
|
|
33
32
|
)
|
|
34
33
|
from gpjax.typing import (
|
|
35
34
|
Array,
|
|
@@ -59,27 +58,27 @@ class AbstractLikelihood(nnx.Module):
|
|
|
59
58
|
self.num_datapoints = num_datapoints
|
|
60
59
|
self.integrator = integrator
|
|
61
60
|
|
|
62
|
-
def __call__(
|
|
61
|
+
def __call__(
|
|
62
|
+
self, dist: tp.Union[npd.MultivariateNormal, GaussianDistribution]
|
|
63
|
+
) -> npd.Distribution:
|
|
63
64
|
r"""Evaluate the likelihood function at a given predictive distribution.
|
|
64
65
|
|
|
65
66
|
Args:
|
|
66
|
-
|
|
67
|
-
**kwargs (Any): Keyword arguments to be passed to the likelihood's
|
|
68
|
-
`predict` method.
|
|
67
|
+
dist: The predictive distribution to evaluate the likelihood at.
|
|
69
68
|
|
|
70
69
|
Returns:
|
|
71
70
|
The predictive distribution.
|
|
72
71
|
"""
|
|
73
|
-
return self.predict(
|
|
72
|
+
return self.predict(dist)
|
|
74
73
|
|
|
75
74
|
@abc.abstractmethod
|
|
76
|
-
def predict(
|
|
75
|
+
def predict(
|
|
76
|
+
self, dist: tp.Union[npd.MultivariateNormal, GaussianDistribution]
|
|
77
|
+
) -> npd.Distribution:
|
|
77
78
|
r"""Evaluate the likelihood function at a given predictive distribution.
|
|
78
79
|
|
|
79
80
|
Args:
|
|
80
|
-
|
|
81
|
-
**kwargs (Any): Keyword arguments to be passed to the likelihood's
|
|
82
|
-
`predict` method.
|
|
81
|
+
dist: The predictive distribution to evaluate the likelihood at.
|
|
83
82
|
|
|
84
83
|
Returns:
|
|
85
84
|
npd.Distribution: The predictive distribution.
|
|
@@ -133,9 +132,7 @@ class Gaussian(AbstractLikelihood):
|
|
|
133
132
|
def __init__(
|
|
134
133
|
self,
|
|
135
134
|
num_datapoints: int,
|
|
136
|
-
obs_stddev: tp.Union[
|
|
137
|
-
ScalarFloat, Float[Array, "#N"], NonNegativeReal, Static
|
|
138
|
-
] = 1.0,
|
|
135
|
+
obs_stddev: tp.Union[ScalarFloat, Float[Array, "#N"], NonNegativeReal] = 1.0,
|
|
139
136
|
integrator: AbstractIntegrator = AnalyticalGaussianIntegrator(),
|
|
140
137
|
):
|
|
141
138
|
r"""Initializes the Gaussian likelihood.
|
|
@@ -148,7 +145,7 @@ class Gaussian(AbstractLikelihood):
|
|
|
148
145
|
likelihoods. Must be an instance of `AbstractIntegrator`. For the Gaussian likelihood, this defaults to
|
|
149
146
|
the `AnalyticalGaussianIntegrator`, as the expected log likelihood can be computed analytically.
|
|
150
147
|
"""
|
|
151
|
-
if not isinstance(obs_stddev,
|
|
148
|
+
if not isinstance(obs_stddev, NonNegativeReal):
|
|
152
149
|
obs_stddev = NonNegativeReal(jnp.asarray(obs_stddev))
|
|
153
150
|
self.obs_stddev = obs_stddev
|
|
154
151
|
|
gpjax/linalg/utils.py
CHANGED
|
@@ -1,5 +1,8 @@
|
|
|
1
1
|
"""Utility functions for the linear algebra module."""
|
|
2
2
|
|
|
3
|
+
import jax.numpy as jnp
|
|
4
|
+
from jaxtyping import Array
|
|
5
|
+
|
|
3
6
|
from gpjax.linalg.operators import LinearOperator
|
|
4
7
|
|
|
5
8
|
|
|
@@ -31,3 +34,32 @@ def psd(A: LinearOperator) -> LinearOperator:
|
|
|
31
34
|
A.annotations = set()
|
|
32
35
|
A.annotations.add(PSD)
|
|
33
36
|
return A
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def add_jitter(matrix: Array, jitter: float | Array = 1e-6) -> Array:
|
|
40
|
+
"""Add jitter to the diagonal of a matrix for numerical stability.
|
|
41
|
+
|
|
42
|
+
This function adds a small positive value (jitter) to the diagonal elements
|
|
43
|
+
of a square matrix to improve numerical stability, particularly for
|
|
44
|
+
Cholesky decompositions and matrix inversions.
|
|
45
|
+
|
|
46
|
+
Args:
|
|
47
|
+
matrix: A square matrix to which jitter will be added.
|
|
48
|
+
jitter: The jitter value to add to the diagonal. Defaults to 1e-6.
|
|
49
|
+
|
|
50
|
+
Returns:
|
|
51
|
+
The matrix with jitter added to its diagonal.
|
|
52
|
+
|
|
53
|
+
Examples:
|
|
54
|
+
>>> import jax.numpy as jnp
|
|
55
|
+
>>> from gpjax.linalg.utils import add_jitter
|
|
56
|
+
>>> matrix = jnp.array([[1.0, 0.5], [0.5, 1.0]])
|
|
57
|
+
>>> jittered_matrix = add_jitter(matrix, jitter=0.01)
|
|
58
|
+
"""
|
|
59
|
+
if matrix.ndim != 2:
|
|
60
|
+
raise ValueError(f"Expected 2D matrix, got {matrix.ndim}D array")
|
|
61
|
+
|
|
62
|
+
if matrix.shape[0] != matrix.shape[1]:
|
|
63
|
+
raise ValueError(f"Expected square matrix, got shape {matrix.shape}")
|
|
64
|
+
|
|
65
|
+
return matrix + jnp.eye(matrix.shape[0]) * jitter
|
gpjax/mean_functions.py
CHANGED
|
@@ -27,8 +27,6 @@ from jaxtyping import (
|
|
|
27
27
|
|
|
28
28
|
from gpjax.parameters import (
|
|
29
29
|
Parameter,
|
|
30
|
-
Real,
|
|
31
|
-
Static,
|
|
32
30
|
)
|
|
33
31
|
from gpjax.typing import (
|
|
34
32
|
Array,
|
|
@@ -132,12 +130,12 @@ class Constant(AbstractMeanFunction):
|
|
|
132
130
|
|
|
133
131
|
def __init__(
|
|
134
132
|
self,
|
|
135
|
-
constant: tp.Union[ScalarFloat, Float[Array, " O"], Parameter
|
|
133
|
+
constant: tp.Union[ScalarFloat, Float[Array, " O"], Parameter] = 0.0,
|
|
136
134
|
):
|
|
137
|
-
if isinstance(constant, Parameter)
|
|
135
|
+
if isinstance(constant, Parameter):
|
|
138
136
|
self.constant = constant
|
|
139
137
|
else:
|
|
140
|
-
self.constant =
|
|
138
|
+
self.constant = jnp.array(constant)
|
|
141
139
|
|
|
142
140
|
def __call__(self, x: Num[Array, "N D"]) -> Float[Array, "N O"]:
|
|
143
141
|
r"""Evaluate the mean function at the given points.
|
|
@@ -148,7 +146,10 @@ class Constant(AbstractMeanFunction):
|
|
|
148
146
|
Returns:
|
|
149
147
|
Float[Array, "1"]: The evaluated mean function.
|
|
150
148
|
"""
|
|
151
|
-
|
|
149
|
+
if isinstance(self.constant, Parameter):
|
|
150
|
+
return jnp.ones((x.shape[0], 1)) * self.constant.value
|
|
151
|
+
else:
|
|
152
|
+
return jnp.ones((x.shape[0], 1)) * self.constant
|
|
152
153
|
|
|
153
154
|
|
|
154
155
|
class Zero(Constant):
|
|
@@ -160,7 +161,7 @@ class Zero(Constant):
|
|
|
160
161
|
"""
|
|
161
162
|
|
|
162
163
|
def __init__(self):
|
|
163
|
-
super().__init__(constant=
|
|
164
|
+
super().__init__(constant=0.0)
|
|
164
165
|
|
|
165
166
|
|
|
166
167
|
class CombinationMeanFunction(AbstractMeanFunction):
|
gpjax/objectives.py
CHANGED
|
@@ -20,6 +20,7 @@ from gpjax.linalg import (
|
|
|
20
20
|
psd,
|
|
21
21
|
solve,
|
|
22
22
|
)
|
|
23
|
+
from gpjax.linalg.utils import add_jitter
|
|
23
24
|
from gpjax.typing import (
|
|
24
25
|
Array,
|
|
25
26
|
ScalarFloat,
|
|
@@ -97,7 +98,7 @@ def conjugate_mll(posterior: ConjugatePosterior, data: Dataset) -> ScalarFloat:
|
|
|
97
98
|
|
|
98
99
|
# Σ = (Kxx + Io²) = LLᵀ
|
|
99
100
|
Kxx = posterior.prior.kernel.gram(x)
|
|
100
|
-
Kxx_dense = Kxx.to_dense()
|
|
101
|
+
Kxx_dense = add_jitter(Kxx.to_dense(), posterior.prior.jitter)
|
|
101
102
|
Sigma_dense = Kxx_dense + jnp.eye(Kxx.shape[0]) * obs_noise
|
|
102
103
|
Sigma = psd(Dense(Sigma_dense))
|
|
103
104
|
|
|
@@ -213,7 +214,7 @@ def log_posterior_density(
|
|
|
213
214
|
|
|
214
215
|
# Gram matrix
|
|
215
216
|
Kxx = posterior.prior.kernel.gram(x)
|
|
216
|
-
Kxx_dense = Kxx.to_dense()
|
|
217
|
+
Kxx_dense = add_jitter(Kxx.to_dense(), posterior.prior.jitter)
|
|
217
218
|
Kxx = psd(Dense(Kxx_dense))
|
|
218
219
|
Lx = lower_cholesky(Kxx)
|
|
219
220
|
|
|
@@ -349,7 +350,7 @@ def collapsed_elbo(variational_family: VF, data: Dataset) -> ScalarFloat:
|
|
|
349
350
|
noise = variational_family.posterior.likelihood.obs_stddev.value**2
|
|
350
351
|
z = variational_family.inducing_inputs.value
|
|
351
352
|
Kzz = kernel.gram(z)
|
|
352
|
-
Kzz_dense = Kzz.to_dense()
|
|
353
|
+
Kzz_dense = add_jitter(Kzz.to_dense(), variational_family.jitter)
|
|
353
354
|
Kzz = psd(Dense(Kzz_dense))
|
|
354
355
|
Kzx = kernel.cross_covariance(z, x)
|
|
355
356
|
Kxx_diag = vmap(kernel, in_axes=(0, 0))(x, x)
|
gpjax/parameters.py
CHANGED
|
@@ -122,16 +122,6 @@ class SigmoidBounded(Parameter[T]):
|
|
|
122
122
|
)
|
|
123
123
|
|
|
124
124
|
|
|
125
|
-
class Static(nnx.Variable[T]):
|
|
126
|
-
"""Static parameter that is not trainable."""
|
|
127
|
-
|
|
128
|
-
def __init__(self, value: T, tag: ParameterTag = "static", **kwargs):
|
|
129
|
-
_check_is_arraylike(value)
|
|
130
|
-
|
|
131
|
-
super().__init__(value=jnp.asarray(value), tag=tag, **kwargs)
|
|
132
|
-
self._tag = tag
|
|
133
|
-
|
|
134
|
-
|
|
135
125
|
class LowerTriangular(Parameter[T]):
|
|
136
126
|
"""Parameter that is a lower triangular matrix."""
|
|
137
127
|
|
gpjax/variational_families.py
CHANGED
|
@@ -40,11 +40,11 @@ from gpjax.linalg import (
|
|
|
40
40
|
psd,
|
|
41
41
|
solve,
|
|
42
42
|
)
|
|
43
|
+
from gpjax.linalg.utils import add_jitter
|
|
43
44
|
from gpjax.mean_functions import AbstractMeanFunction
|
|
44
45
|
from gpjax.parameters import (
|
|
45
46
|
LowerTriangular,
|
|
46
47
|
Real,
|
|
47
|
-
Static,
|
|
48
48
|
)
|
|
49
49
|
from gpjax.typing import (
|
|
50
50
|
Array,
|
|
@@ -110,11 +110,10 @@ class AbstractVariationalGaussian(AbstractVariationalFamily[L]):
|
|
|
110
110
|
inducing_inputs: tp.Union[
|
|
111
111
|
Float[Array, "N D"],
|
|
112
112
|
Real,
|
|
113
|
-
Static,
|
|
114
113
|
],
|
|
115
114
|
jitter: ScalarFloat = 1e-6,
|
|
116
115
|
):
|
|
117
|
-
if not isinstance(inducing_inputs,
|
|
116
|
+
if not isinstance(inducing_inputs, Real):
|
|
118
117
|
inducing_inputs = Real(inducing_inputs)
|
|
119
118
|
|
|
120
119
|
self.inducing_inputs = inducing_inputs
|
|
@@ -177,25 +176,31 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):
|
|
|
177
176
|
approximation and the GP prior.
|
|
178
177
|
"""
|
|
179
178
|
# Unpack variational parameters
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
179
|
+
variational_mean = self.variational_mean.value
|
|
180
|
+
variational_sqrt = self.variational_root_covariance.value
|
|
181
|
+
inducing_inputs = self.inducing_inputs.value
|
|
183
182
|
|
|
184
183
|
# Unpack mean function and kernel
|
|
185
184
|
mean_function = self.posterior.prior.mean_function
|
|
186
185
|
kernel = self.posterior.prior.kernel
|
|
187
186
|
|
|
188
|
-
|
|
189
|
-
Kzz = kernel.gram(
|
|
190
|
-
Kzz = psd(Dense(Kzz.to_dense()
|
|
187
|
+
inducing_mean = mean_function(inducing_inputs)
|
|
188
|
+
Kzz = kernel.gram(inducing_inputs)
|
|
189
|
+
Kzz = psd(Dense(add_jitter(Kzz.to_dense(), self.jitter)))
|
|
191
190
|
|
|
192
|
-
|
|
193
|
-
|
|
191
|
+
variational_sqrt_triangular = Triangular(variational_sqrt)
|
|
192
|
+
variational_covariance = (
|
|
193
|
+
variational_sqrt_triangular @ variational_sqrt_triangular.T
|
|
194
|
+
)
|
|
194
195
|
|
|
195
|
-
|
|
196
|
-
|
|
196
|
+
q_inducing = GaussianDistribution(
|
|
197
|
+
loc=jnp.atleast_1d(variational_mean.squeeze()), scale=variational_covariance
|
|
198
|
+
)
|
|
199
|
+
p_inducing = GaussianDistribution(
|
|
200
|
+
loc=jnp.atleast_1d(inducing_mean.squeeze()), scale=Kzz
|
|
201
|
+
)
|
|
197
202
|
|
|
198
|
-
return
|
|
203
|
+
return q_inducing.kl_divergence(p_inducing)
|
|
199
204
|
|
|
200
205
|
def predict(self, test_inputs: Float[Array, "N D"]) -> GaussianDistribution:
|
|
201
206
|
r"""Compute the predictive distribution of the GP at the test inputs t.
|
|
@@ -215,26 +220,26 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):
|
|
|
215
220
|
the test inputs.
|
|
216
221
|
"""
|
|
217
222
|
# Unpack variational parameters
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
223
|
+
variational_mean = self.variational_mean.value
|
|
224
|
+
variational_sqrt = self.variational_root_covariance.value
|
|
225
|
+
inducing_inputs = self.inducing_inputs.value
|
|
221
226
|
|
|
222
227
|
# Unpack mean function and kernel
|
|
223
228
|
mean_function = self.posterior.prior.mean_function
|
|
224
229
|
kernel = self.posterior.prior.kernel
|
|
225
230
|
|
|
226
|
-
Kzz = kernel.gram(
|
|
227
|
-
Kzz_dense = Kzz.to_dense()
|
|
231
|
+
Kzz = kernel.gram(inducing_inputs)
|
|
232
|
+
Kzz_dense = add_jitter(Kzz.to_dense(), self.jitter)
|
|
228
233
|
Kzz = psd(Dense(Kzz_dense))
|
|
229
234
|
Lz = lower_cholesky(Kzz)
|
|
230
|
-
|
|
235
|
+
inducing_mean = mean_function(inducing_inputs)
|
|
231
236
|
|
|
232
237
|
# Unpack test inputs
|
|
233
|
-
|
|
238
|
+
test_points = test_inputs
|
|
234
239
|
|
|
235
|
-
Ktt = kernel.gram(
|
|
236
|
-
Kzt = kernel.cross_covariance(
|
|
237
|
-
|
|
240
|
+
Ktt = kernel.gram(test_points)
|
|
241
|
+
Kzt = kernel.cross_covariance(inducing_inputs, test_points)
|
|
242
|
+
test_mean = mean_function(test_points)
|
|
238
243
|
|
|
239
244
|
# Lz⁻¹ Kzt
|
|
240
245
|
Lz_inv_Kzt = solve(Lz, Kzt)
|
|
@@ -243,10 +248,10 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):
|
|
|
243
248
|
Kzz_inv_Kzt = solve(Lz.T, Lz_inv_Kzt)
|
|
244
249
|
|
|
245
250
|
# Ktz Kzz⁻¹ sqrt
|
|
246
|
-
Ktz_Kzz_inv_sqrt = jnp.matmul(Kzz_inv_Kzt.T,
|
|
251
|
+
Ktz_Kzz_inv_sqrt = jnp.matmul(Kzz_inv_Kzt.T, variational_sqrt)
|
|
247
252
|
|
|
248
253
|
# μt + Ktz Kzz⁻¹ (μ - μz)
|
|
249
|
-
mean =
|
|
254
|
+
mean = test_mean + jnp.matmul(Kzz_inv_Kzt.T, variational_mean - inducing_mean)
|
|
250
255
|
|
|
251
256
|
# Ktt - Ktz Kzz⁻¹ Kzt + Ktz Kzz⁻¹ S Kzz⁻¹ Kzt [recall S = sqrt sqrtᵀ]
|
|
252
257
|
covariance = (
|
|
@@ -254,7 +259,10 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):
|
|
|
254
259
|
- jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt)
|
|
255
260
|
+ jnp.matmul(Ktz_Kzz_inv_sqrt, Ktz_Kzz_inv_sqrt.T)
|
|
256
261
|
)
|
|
257
|
-
|
|
262
|
+
if hasattr(covariance, "to_dense"):
|
|
263
|
+
covariance = covariance.to_dense()
|
|
264
|
+
covariance = add_jitter(covariance, self.jitter)
|
|
265
|
+
covariance = Dense(covariance)
|
|
258
266
|
|
|
259
267
|
return GaussianDistribution(
|
|
260
268
|
loc=jnp.atleast_1d(mean.squeeze()), scale=covariance
|
|
@@ -329,7 +337,7 @@ class WhitenedVariationalGaussian(VariationalGaussian[L]):
|
|
|
329
337
|
kernel = self.posterior.prior.kernel
|
|
330
338
|
|
|
331
339
|
Kzz = kernel.gram(z)
|
|
332
|
-
Kzz_dense = Kzz.to_dense()
|
|
340
|
+
Kzz_dense = add_jitter(Kzz.to_dense(), self.jitter)
|
|
333
341
|
Kzz = psd(Dense(Kzz_dense))
|
|
334
342
|
Lz = lower_cholesky(Kzz)
|
|
335
343
|
|
|
@@ -355,7 +363,10 @@ class WhitenedVariationalGaussian(VariationalGaussian[L]):
|
|
|
355
363
|
- jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt)
|
|
356
364
|
+ jnp.matmul(Ktz_Lz_invT_sqrt, Ktz_Lz_invT_sqrt.T)
|
|
357
365
|
)
|
|
358
|
-
|
|
366
|
+
if hasattr(covariance, "to_dense"):
|
|
367
|
+
covariance = covariance.to_dense()
|
|
368
|
+
covariance = add_jitter(covariance, self.jitter)
|
|
369
|
+
covariance = Dense(covariance)
|
|
359
370
|
|
|
360
371
|
return GaussianDistribution(
|
|
361
372
|
loc=jnp.atleast_1d(mean.squeeze()), scale=covariance
|
|
@@ -390,8 +401,8 @@ class NaturalVariationalGaussian(AbstractVariationalGaussian[L]):
|
|
|
390
401
|
if natural_matrix is None:
|
|
391
402
|
natural_matrix = -0.5 * jnp.eye(self.num_inducing)
|
|
392
403
|
|
|
393
|
-
self.natural_vector =
|
|
394
|
-
self.natural_matrix =
|
|
404
|
+
self.natural_vector = Real(natural_vector)
|
|
405
|
+
self.natural_matrix = Real(natural_matrix)
|
|
395
406
|
|
|
396
407
|
def prior_kl(self) -> ScalarFloat:
|
|
397
408
|
r"""Compute the KL-divergence between our current variational approximation
|
|
@@ -422,7 +433,7 @@ class NaturalVariationalGaussian(AbstractVariationalGaussian[L]):
|
|
|
422
433
|
|
|
423
434
|
# S⁻¹ = -2θ₂
|
|
424
435
|
S_inv = -2 * natural_matrix
|
|
425
|
-
S_inv
|
|
436
|
+
S_inv = add_jitter(S_inv, self.jitter)
|
|
426
437
|
|
|
427
438
|
# Compute L⁻¹, where LLᵀ = S, via a trick found in the NumPyro source code and https://nbviewer.org/gist/fehiepsi/5ef8e09e61604f10607380467eb82006#Precision-to-scale_tril:
|
|
428
439
|
sqrt_inv = jnp.swapaxes(
|
|
@@ -441,7 +452,7 @@ class NaturalVariationalGaussian(AbstractVariationalGaussian[L]):
|
|
|
441
452
|
|
|
442
453
|
muz = mean_function(z)
|
|
443
454
|
Kzz = kernel.gram(z)
|
|
444
|
-
Kzz_dense = Kzz.to_dense()
|
|
455
|
+
Kzz_dense = add_jitter(Kzz.to_dense(), self.jitter)
|
|
445
456
|
Kzz = psd(Dense(Kzz_dense))
|
|
446
457
|
|
|
447
458
|
qu = GaussianDistribution(loc=jnp.atleast_1d(mu.squeeze()), scale=S)
|
|
@@ -476,7 +487,7 @@ class NaturalVariationalGaussian(AbstractVariationalGaussian[L]):
|
|
|
476
487
|
|
|
477
488
|
# S⁻¹ = -2θ₂
|
|
478
489
|
S_inv = -2 * natural_matrix
|
|
479
|
-
S_inv
|
|
490
|
+
S_inv = add_jitter(S_inv, self.jitter)
|
|
480
491
|
|
|
481
492
|
# Compute L⁻¹, where LLᵀ = S, via a trick found in the NumPyro source code and https://nbviewer.org/gist/fehiepsi/5ef8e09e61604f10607380467eb82006#Precision-to-scale_tril:
|
|
482
493
|
sqrt_inv = jnp.swapaxes(
|
|
@@ -493,7 +504,7 @@ class NaturalVariationalGaussian(AbstractVariationalGaussian[L]):
|
|
|
493
504
|
mu = jnp.matmul(S, natural_vector)
|
|
494
505
|
|
|
495
506
|
Kzz = kernel.gram(z)
|
|
496
|
-
Kzz_dense = Kzz.to_dense()
|
|
507
|
+
Kzz_dense = add_jitter(Kzz.to_dense(), self.jitter)
|
|
497
508
|
Kzz = psd(Dense(Kzz_dense))
|
|
498
509
|
Lz = lower_cholesky(Kzz)
|
|
499
510
|
muz = mean_function(z)
|
|
@@ -520,7 +531,10 @@ class NaturalVariationalGaussian(AbstractVariationalGaussian[L]):
|
|
|
520
531
|
- jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt)
|
|
521
532
|
+ jnp.matmul(Ktz_Kzz_inv_L, Ktz_Kzz_inv_L.T)
|
|
522
533
|
)
|
|
523
|
-
|
|
534
|
+
if hasattr(covariance, "to_dense"):
|
|
535
|
+
covariance = covariance.to_dense()
|
|
536
|
+
covariance = add_jitter(covariance, self.jitter)
|
|
537
|
+
covariance = Dense(covariance)
|
|
524
538
|
|
|
525
539
|
return GaussianDistribution(
|
|
526
540
|
loc=jnp.atleast_1d(mean.squeeze()), scale=covariance
|
|
@@ -556,8 +570,8 @@ class ExpectationVariationalGaussian(AbstractVariationalGaussian[L]):
|
|
|
556
570
|
if expectation_matrix is None:
|
|
557
571
|
expectation_matrix = jnp.eye(self.num_inducing)
|
|
558
572
|
|
|
559
|
-
self.expectation_vector =
|
|
560
|
-
self.expectation_matrix =
|
|
573
|
+
self.expectation_vector = Real(expectation_vector)
|
|
574
|
+
self.expectation_matrix = Real(expectation_matrix)
|
|
561
575
|
|
|
562
576
|
def prior_kl(self) -> ScalarFloat:
|
|
563
577
|
r"""Evaluate the prior KL-divergence.
|
|
@@ -595,12 +609,12 @@ class ExpectationVariationalGaussian(AbstractVariationalGaussian[L]):
|
|
|
595
609
|
# S = η₂ - η₁ η₁ᵀ
|
|
596
610
|
S = expectation_matrix - jnp.outer(mu, mu)
|
|
597
611
|
S = psd(Dense(S))
|
|
598
|
-
S_dense = S.to_dense()
|
|
612
|
+
S_dense = add_jitter(S.to_dense(), self.jitter)
|
|
599
613
|
S = psd(Dense(S_dense))
|
|
600
614
|
|
|
601
615
|
muz = mean_function(z)
|
|
602
616
|
Kzz = kernel.gram(z)
|
|
603
|
-
Kzz_dense = Kzz.to_dense()
|
|
617
|
+
Kzz_dense = add_jitter(Kzz.to_dense(), self.jitter)
|
|
604
618
|
Kzz = psd(Dense(Kzz_dense))
|
|
605
619
|
|
|
606
620
|
qu = GaussianDistribution(loc=jnp.atleast_1d(mu.squeeze()), scale=S)
|
|
@@ -640,14 +654,14 @@ class ExpectationVariationalGaussian(AbstractVariationalGaussian[L]):
|
|
|
640
654
|
|
|
641
655
|
# S = η₂ - η₁ η₁ᵀ
|
|
642
656
|
S = expectation_matrix - jnp.matmul(mu, mu.T)
|
|
643
|
-
S = Dense(
|
|
657
|
+
S = Dense(add_jitter(S, self.jitter))
|
|
644
658
|
S = psd(S)
|
|
645
659
|
|
|
646
660
|
# S = sqrt sqrtᵀ
|
|
647
661
|
sqrt = lower_cholesky(S)
|
|
648
662
|
|
|
649
663
|
Kzz = kernel.gram(z)
|
|
650
|
-
Kzz_dense = Kzz.to_dense()
|
|
664
|
+
Kzz_dense = add_jitter(Kzz.to_dense(), self.jitter)
|
|
651
665
|
Kzz = psd(Dense(Kzz_dense))
|
|
652
666
|
Lz = lower_cholesky(Kzz)
|
|
653
667
|
muz = mean_function(z)
|
|
@@ -677,7 +691,10 @@ class ExpectationVariationalGaussian(AbstractVariationalGaussian[L]):
|
|
|
677
691
|
- jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt)
|
|
678
692
|
+ jnp.matmul(Ktz_Kzz_inv_sqrt, Ktz_Kzz_inv_sqrt.T)
|
|
679
693
|
)
|
|
680
|
-
|
|
694
|
+
if hasattr(covariance, "to_dense"):
|
|
695
|
+
covariance = covariance.to_dense()
|
|
696
|
+
covariance = add_jitter(covariance, self.jitter)
|
|
697
|
+
covariance = Dense(covariance)
|
|
681
698
|
|
|
682
699
|
return GaussianDistribution(
|
|
683
700
|
loc=jnp.atleast_1d(mean.squeeze()), scale=covariance
|
|
@@ -734,7 +751,7 @@ class CollapsedVariationalGaussian(AbstractVariationalGaussian[GL]):
|
|
|
734
751
|
|
|
735
752
|
Kzx = kernel.cross_covariance(z, x)
|
|
736
753
|
Kzz = kernel.gram(z)
|
|
737
|
-
Kzz_dense = Kzz.to_dense()
|
|
754
|
+
Kzz_dense = add_jitter(Kzz.to_dense(), self.jitter)
|
|
738
755
|
Kzz = psd(Dense(Kzz_dense))
|
|
739
756
|
|
|
740
757
|
# Lz Lzᵀ = Kzz
|
|
@@ -780,7 +797,10 @@ class CollapsedVariationalGaussian(AbstractVariationalGaussian[GL]):
|
|
|
780
797
|
- jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt)
|
|
781
798
|
+ jnp.matmul(L_inv_Lz_inv_Kzt.T, L_inv_Lz_inv_Kzt)
|
|
782
799
|
)
|
|
783
|
-
|
|
800
|
+
if hasattr(covariance, "to_dense"):
|
|
801
|
+
covariance = covariance.to_dense()
|
|
802
|
+
covariance = add_jitter(covariance, self.jitter)
|
|
803
|
+
covariance = Dense(covariance)
|
|
784
804
|
|
|
785
805
|
return GaussianDistribution(
|
|
786
806
|
loc=jnp.atleast_1d(mean.squeeze()), scale=covariance
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: gpjax
|
|
3
|
-
Version: 0.12.
|
|
3
|
+
Version: 0.12.2
|
|
4
4
|
Summary: Gaussian processes in JAX.
|
|
5
5
|
Project-URL: Documentation, https://docs.jaxgaussianprocesses.com/
|
|
6
6
|
Project-URL: Issues, https://github.com/JaxGaussianProcesses/GPJax/issues
|
|
@@ -17,7 +17,7 @@ Classifier: Programming Language :: Python :: 3.12
|
|
|
17
17
|
Classifier: Programming Language :: Python :: 3.13
|
|
18
18
|
Classifier: Programming Language :: Python :: Implementation :: CPython
|
|
19
19
|
Classifier: Programming Language :: Python :: Implementation :: PyPy
|
|
20
|
-
Requires-Python:
|
|
20
|
+
Requires-Python: >=3.10
|
|
21
21
|
Requires-Dist: beartype>0.16.1
|
|
22
22
|
Requires-Dist: flax>=0.10.0
|
|
23
23
|
Requires-Dist: jax>=0.5.0
|
|
@@ -60,7 +60,7 @@ Requires-Dist: mkdocs-jupyter>=0.24.3; extra == 'docs'
|
|
|
60
60
|
Requires-Dist: mkdocs-literate-nav>=0.6.0; extra == 'docs'
|
|
61
61
|
Requires-Dist: mkdocs-material>=9.5.12; extra == 'docs'
|
|
62
62
|
Requires-Dist: mkdocs>=1.5.3; extra == 'docs'
|
|
63
|
-
Requires-Dist: mkdocstrings[python]<0.
|
|
63
|
+
Requires-Dist: mkdocstrings[python]<0.31.0; extra == 'docs'
|
|
64
64
|
Requires-Dist: nbconvert>=7.16.2; extra == 'docs'
|
|
65
65
|
Requires-Dist: networkx>=3.0; extra == 'docs'
|
|
66
66
|
Requires-Dist: pandas>=1.5.3; extra == 'docs'
|
|
@@ -126,18 +126,9 @@ Channel](https://join.slack.com/t/gpjax/shared_invite/zt-3cesiykcx-nzajjRdnV3ohw
|
|
|
126
126
|
where we can discuss the development of GPJax and broader support for Gaussian
|
|
127
127
|
process modelling.
|
|
128
128
|
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
GPJax was founded by [Thomas Pinder](https://github.com/thomaspinder). Today, the
|
|
133
|
-
project's gardeners are [daniel-dodd@](https://github.com/daniel-dodd),
|
|
134
|
-
[henrymoss@](https://github.com/henrymoss), [st--@](https://github.com/st--), and
|
|
135
|
-
[thomaspinder@](https://github.com/thomaspinder), listed in alphabetical order. The full
|
|
136
|
-
governance structure of GPJax is detailed [here](docs/GOVERNANCE.md). We appreciate all
|
|
137
|
-
[the contributors to
|
|
138
|
-
GPJax](https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors) who have
|
|
139
|
-
helped to shape GPJax into the package it is today.
|
|
140
|
-
|
|
129
|
+
We appreciate all [the contributors to
|
|
130
|
+
GPJax](https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors) who have helped to shape
|
|
131
|
+
GPJax into the package it is today.
|
|
141
132
|
|
|
142
133
|
# Supported methods and interfaces
|
|
143
134
|
|
|
@@ -218,13 +209,14 @@ configuration in development mode.
|
|
|
218
209
|
```bash
|
|
219
210
|
git clone https://github.com/JaxGaussianProcesses/GPJax.git
|
|
220
211
|
cd GPJax
|
|
221
|
-
|
|
212
|
+
hatch env create
|
|
213
|
+
hatch shell
|
|
222
214
|
```
|
|
223
215
|
|
|
224
216
|
> We recommend you check your installation passes the supplied unit tests:
|
|
225
217
|
>
|
|
226
218
|
> ```python
|
|
227
|
-
>
|
|
219
|
+
> hatch run dev:test
|
|
228
220
|
> ```
|
|
229
221
|
|
|
230
222
|
# Citing GPJax
|
|
@@ -1,52 +1,52 @@
|
|
|
1
|
-
gpjax/__init__.py,sha256=
|
|
1
|
+
gpjax/__init__.py,sha256=RzwpixFXn6HNHLVLy4LVXhFUk2c-_ce6n1gjZ2B93F0,1641
|
|
2
2
|
gpjax/citation.py,sha256=pwFS8h1J-LE5ieRS0zDyuwhmQHNxkFHYE7iSMlVNmQc,3928
|
|
3
3
|
gpjax/dataset.py,sha256=NsToLKq4lOsHnfLfukrUIRKvhOEuoUk8aHTF0oAqRbU,4079
|
|
4
4
|
gpjax/distributions.py,sha256=iKmeQ_NN2CIjRiuOeJlwEGASzGROi4ZCerVi1uY7zRM,7758
|
|
5
|
-
gpjax/fit.py,sha256=
|
|
6
|
-
gpjax/gps.py,sha256
|
|
5
|
+
gpjax/fit.py,sha256=I2sJVuKZii_d7MEcelHIivfM8ExYGMgdBuKKOT7Dw-A,15326
|
|
6
|
+
gpjax/gps.py,sha256=ipaeYMnPffhKK_JsEHe4fF8GmolQIjXB1YbyfUIL8H4,30118
|
|
7
7
|
gpjax/integrators.py,sha256=eyJPqWNPKj6pKP5da0fEj4HW7BVyevqeGrurEuy_XPw,5694
|
|
8
|
-
gpjax/likelihoods.py,sha256=
|
|
9
|
-
gpjax/mean_functions.py,sha256
|
|
8
|
+
gpjax/likelihoods.py,sha256=xwnSQpn6Aa-FPpEoDn_3xpBdPQAmHP97jP-9iJmT4G8,9087
|
|
9
|
+
gpjax/mean_functions.py,sha256=KiHQXI-b7o0Vi5KQxGm6RNsUjitJc9jEOCq2GrSx4II,6531
|
|
10
10
|
gpjax/numpyro_extras.py,sha256=-vWJ7SpZVNhSdCjjrlxIkovMFrM1IzpsMJK3B4LioGE,3411
|
|
11
|
-
gpjax/objectives.py,sha256=
|
|
12
|
-
gpjax/parameters.py,sha256=
|
|
11
|
+
gpjax/objectives.py,sha256=GvKbDIPqYjsc9FpiTccmZwRdHr2lCykgfxI9BX9I_GA,15362
|
|
12
|
+
gpjax/parameters.py,sha256=hnyIKr6uIzd7Kb3KZC9WowR88ruQwUvdcto3cx2ZDv4,6756
|
|
13
13
|
gpjax/scan.py,sha256=jStQvwkE9MGttB89frxam1kaeXdWih7cVxkGywyaeHQ,5365
|
|
14
14
|
gpjax/typing.py,sha256=M3CvWsYtZ3PFUvBvvbRNjpwerNII0w4yGuP0I-sLeYI,1705
|
|
15
|
-
gpjax/variational_families.py,sha256=
|
|
15
|
+
gpjax/variational_families.py,sha256=TJGGkwkE805X4PQb-C32FxvD9B_OsFLWf6I-ZZvOUWk,29628
|
|
16
16
|
gpjax/kernels/__init__.py,sha256=WZanH0Tpdkt0f7VfMqnalm_VZAMVwBqeOVaICNj6xQU,1901
|
|
17
|
-
gpjax/kernels/base.py,sha256=
|
|
17
|
+
gpjax/kernels/base.py,sha256=4Lx8y3kPX4WqQZGRGAsBkqn_i6FlfoAhSn9Tv415xuQ,11551
|
|
18
18
|
gpjax/kernels/approximations/__init__.py,sha256=bK9HlGd-PZeGrqtG5RpXxUTXNUrZTgfjH1dP626yNMA,68
|
|
19
|
-
gpjax/kernels/approximations/rff.py,sha256=
|
|
19
|
+
gpjax/kernels/approximations/rff.py,sha256=GbNUmDPEKEKuMwxUcocxl_9IFR3Q9KEPZXzjy_ZD-2w,4043
|
|
20
20
|
gpjax/kernels/computations/__init__.py,sha256=uTVkqvnZVesFLDN92h0ZR0jfR69Eo2WyjOlmSYmCPJ8,1379
|
|
21
21
|
gpjax/kernels/computations/base.py,sha256=L6K0roxZbrYeJKxEw-yaTiK9Mtcv0YtZfWI2Xnau7i8,3616
|
|
22
|
-
gpjax/kernels/computations/basis_functions.py,sha256=
|
|
22
|
+
gpjax/kernels/computations/basis_functions.py,sha256=_SFv4Tiwne40bxr1uVYpEjjZgjIQHKseLmss2Zgl1L4,2484
|
|
23
23
|
gpjax/kernels/computations/constant_diagonal.py,sha256=JkQhLj7cK48IhOER4ivkALNhD1oQleKe-Rr9BtUJ6es,1984
|
|
24
24
|
gpjax/kernels/computations/dense.py,sha256=vnW6XKQe4_gzpXRWTctxhgMA9-9TebdtiXzAqh_-j6g,1392
|
|
25
25
|
gpjax/kernels/computations/diagonal.py,sha256=k1KqW0DwWRIBvbb7jzcKktXRfhXbcos3ncWrFplJ4W0,1768
|
|
26
|
-
gpjax/kernels/computations/eigen.py,sha256=
|
|
26
|
+
gpjax/kernels/computations/eigen.py,sha256=NTHm-cn-RepYuXFrvXo2ih7Gtu1YR_pAg4Jb7IhE_o8,1930
|
|
27
27
|
gpjax/kernels/non_euclidean/__init__.py,sha256=RT7puRPqCTpyxZ16q596EuOQEQi1LK1v3J9_fWz1NlY,790
|
|
28
|
-
gpjax/kernels/non_euclidean/graph.py,sha256=
|
|
28
|
+
gpjax/kernels/non_euclidean/graph.py,sha256=xTrx6ro8ubRXgM7Wgg6NmOyyEjEcGhzydY7KXueknCc,4120
|
|
29
29
|
gpjax/kernels/non_euclidean/utils.py,sha256=z42aw8ga0zuREzHawemR9okttgrAUPmq-aN5HMt4SuY,1578
|
|
30
30
|
gpjax/kernels/nonstationary/__init__.py,sha256=YpWQfOy_cqOKc5ezn37vqoK3Z6jznYiJz28BD_8F7AY,930
|
|
31
|
-
gpjax/kernels/nonstationary/arccosine.py,sha256=
|
|
31
|
+
gpjax/kernels/nonstationary/arccosine.py,sha256=cqb8sqaNwW3fEbrA7MY9OF2KJFTkxHhqwmQtABE3G8w,5408
|
|
32
32
|
gpjax/kernels/nonstationary/linear.py,sha256=UIMoCq2hg6dQKr4J5UGiiPqotBleQuYfy00Ia1NaMOo,2571
|
|
33
|
-
gpjax/kernels/nonstationary/polynomial.py,sha256=
|
|
33
|
+
gpjax/kernels/nonstationary/polynomial.py,sha256=CKc02C7Utgo-hhcOOCcKLdln5lj4vud_8M-JY7SevJ8,3388
|
|
34
34
|
gpjax/kernels/stationary/__init__.py,sha256=j4BMTaQlIx2kNAT1Dkf4iO2rm-f7_oSVWNrk1bN0tqE,1406
|
|
35
35
|
gpjax/kernels/stationary/base.py,sha256=25qDqpZP4gNtzbyzDCW-6u7rJfMqkg0dW88XUmTTupU,7078
|
|
36
36
|
gpjax/kernels/stationary/matern12.py,sha256=DGjqw6VveYsyy0TrufyJJvCei7p9slnm2f0TgRGG7_U,1773
|
|
37
37
|
gpjax/kernels/stationary/matern32.py,sha256=laLsJWJozJzpYHBzlkPUq0rWxz1eWEwGC36P2nPJuaQ,1966
|
|
38
38
|
gpjax/kernels/stationary/matern52.py,sha256=VSByD2sb7k-DzRFjaz31P3Rtc4bPPhHvMshrxZNFnns,2019
|
|
39
|
-
gpjax/kernels/stationary/periodic.py,sha256=
|
|
40
|
-
gpjax/kernels/stationary/powered_exponential.py,sha256=
|
|
41
|
-
gpjax/kernels/stationary/rational_quadratic.py,sha256=
|
|
39
|
+
gpjax/kernels/stationary/periodic.py,sha256=f4PhWhKg-pJsEBGzEMK9pdbylO84GPKhzHlBC83ZVWw,3528
|
|
40
|
+
gpjax/kernels/stationary/powered_exponential.py,sha256=xuFGuIK0mKNMU3iLtZMXZTHXJuMFAMoX7gAtXefCdqU,3679
|
|
41
|
+
gpjax/kernels/stationary/rational_quadratic.py,sha256=zHo2LVW65T52XET4Hx9JaKO0TfxylV8WRUtP7sUUOx0,3418
|
|
42
42
|
gpjax/kernels/stationary/rbf.py,sha256=euHUs6FdfRICQcabAWE4MX-7GEDr2TxgZWdFQiXr9Bw,1690
|
|
43
43
|
gpjax/kernels/stationary/utils.py,sha256=6BI9EBcCzeeKx-XH-MfW1ORmtU__tPX5zyvfLhpkBsU,2180
|
|
44
44
|
gpjax/kernels/stationary/white.py,sha256=TkdXXZCCjDs7JwR_gj5uvn2s1wyfRbe1vyHhUMJ8jjI,2212
|
|
45
45
|
gpjax/linalg/__init__.py,sha256=F8mxk_9Zc2nFd7Q-unjJ50_6rXEKzZj572WsU_jUKqI,547
|
|
46
46
|
gpjax/linalg/operations.py,sha256=xvhOy5P4FmUCPWjIVNdg1yDXaoFQ48anFUfR-Tnfr6k,6480
|
|
47
47
|
gpjax/linalg/operators.py,sha256=arxRGwcoAy_RqUYqBpZ3XG6OXbjShUl7m8sTpg85npE,11608
|
|
48
|
-
gpjax/linalg/utils.py,sha256=
|
|
49
|
-
gpjax-0.12.
|
|
50
|
-
gpjax-0.12.
|
|
51
|
-
gpjax-0.12.
|
|
52
|
-
gpjax-0.12.
|
|
48
|
+
gpjax/linalg/utils.py,sha256=fKV8G_iKZVhNkNvN20D_dQEi93-8xosGbXBP-v7UEyo,2020
|
|
49
|
+
gpjax-0.12.2.dist-info/METADATA,sha256=eckQKXiBXi8XbBeJFviBAIPdBGVWGFQg7wVZwMfPPxs,10129
|
|
50
|
+
gpjax-0.12.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
51
|
+
gpjax-0.12.2.dist-info/licenses/LICENSE.txt,sha256=3umwi0h8wmKXOZO8XwRBwSl3vJt2hpWKEqSrSXLR7-I,1084
|
|
52
|
+
gpjax-0.12.2.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|