gpjax 0.13.1__py3-none-any.whl → 0.13.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gpjax/__init__.py +3 -3
- gpjax/dataset.py +1 -1
- gpjax/fit.py +1 -1
- gpjax/gps.py +197 -64
- gpjax/kernels/__init__.py +1 -1
- gpjax/kernels/base.py +2 -2
- gpjax/kernels/computations/__init__.py +1 -1
- gpjax/kernels/computations/base.py +1 -1
- gpjax/kernels/computations/basis_functions.py +2 -4
- gpjax/kernels/computations/constant_diagonal.py +1 -1
- gpjax/kernels/computations/dense.py +1 -1
- gpjax/kernels/computations/diagonal.py +1 -1
- gpjax/kernels/computations/eigen.py +1 -1
- gpjax/kernels/non_euclidean/__init__.py +1 -1
- gpjax/kernels/non_euclidean/graph.py +18 -6
- gpjax/kernels/non_euclidean/utils.py +1 -1
- gpjax/kernels/nonstationary/__init__.py +1 -1
- gpjax/kernels/nonstationary/arccosine.py +1 -1
- gpjax/kernels/nonstationary/linear.py +1 -1
- gpjax/kernels/nonstationary/polynomial.py +1 -1
- gpjax/kernels/stationary/__init__.py +1 -1
- gpjax/kernels/stationary/base.py +1 -1
- gpjax/kernels/stationary/matern12.py +1 -1
- gpjax/kernels/stationary/matern32.py +1 -1
- gpjax/kernels/stationary/matern52.py +1 -1
- gpjax/kernels/stationary/periodic.py +1 -1
- gpjax/kernels/stationary/powered_exponential.py +1 -1
- gpjax/kernels/stationary/rational_quadratic.py +1 -1
- gpjax/kernels/stationary/rbf.py +1 -1
- gpjax/kernels/stationary/utils.py +1 -1
- gpjax/kernels/stationary/white.py +1 -1
- gpjax/scan.py +1 -1
- {gpjax-0.13.1.dist-info → gpjax-0.13.3.dist-info}/METADATA +12 -12
- gpjax-0.13.3.dist-info/RECORD +52 -0
- gpjax-0.13.1.dist-info/RECORD +0 -52
- {gpjax-0.13.1.dist-info → gpjax-0.13.3.dist-info}/WHEEL +0 -0
- {gpjax-0.13.1.dist-info → gpjax-0.13.3.dist-info}/licenses/LICENSE.txt +0 -0
gpjax/__init__.py
CHANGED
|
@@ -38,9 +38,9 @@ from gpjax.fit import (
|
|
|
38
38
|
|
|
39
39
|
__license__ = "MIT"
|
|
40
40
|
__description__ = "Gaussian processes in JAX and Flax"
|
|
41
|
-
__url__ = "https://github.com/
|
|
42
|
-
__contributors__ = "https://github.com/
|
|
43
|
-
__version__ = "0.13.
|
|
41
|
+
__url__ = "https://github.com/thomaspinder/GPJax"
|
|
42
|
+
__contributors__ = "https://github.com/thomaspinder/GPJax/graphs/contributors"
|
|
43
|
+
__version__ = "0.13.3"
|
|
44
44
|
|
|
45
45
|
__all__ = [
|
|
46
46
|
"gps",
|
gpjax/dataset.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright 2022 The
|
|
1
|
+
# Copyright 2022 The thomaspinder Contributors. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
gpjax/fit.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright 2023 The
|
|
1
|
+
# Copyright 2023 The thomaspinder Contributors. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
gpjax/gps.py
CHANGED
|
@@ -4,7 +4,7 @@
|
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
5
5
|
# You may obtain a copy of the License at
|
|
6
6
|
#
|
|
7
|
-
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
8
|
#
|
|
9
9
|
# Unless required by applicable law or agreed to in writing, software
|
|
10
10
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
@@ -13,10 +13,13 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
# from __future__ import annotations
|
|
16
|
+
|
|
16
17
|
from abc import abstractmethod
|
|
18
|
+
from typing import Literal
|
|
17
19
|
|
|
18
20
|
import beartype.typing as tp
|
|
19
21
|
from flax import nnx
|
|
22
|
+
import jax
|
|
20
23
|
import jax.numpy as jnp
|
|
21
24
|
import jax.random as jr
|
|
22
25
|
from jaxtyping import (
|
|
@@ -35,10 +38,13 @@ from gpjax.likelihoods import (
|
|
|
35
38
|
)
|
|
36
39
|
from gpjax.linalg import (
|
|
37
40
|
Dense,
|
|
41
|
+
Diagonal,
|
|
38
42
|
psd,
|
|
39
43
|
solve,
|
|
40
44
|
)
|
|
41
|
-
from gpjax.linalg.operations import
|
|
45
|
+
from gpjax.linalg.operations import (
|
|
46
|
+
lower_cholesky,
|
|
47
|
+
)
|
|
42
48
|
from gpjax.linalg.utils import add_jitter
|
|
43
49
|
from gpjax.mean_functions import AbstractMeanFunction
|
|
44
50
|
from gpjax.parameters import (
|
|
@@ -77,7 +83,12 @@ class AbstractPrior(nnx.Module, tp.Generic[M, K]):
|
|
|
77
83
|
self.mean_function = mean_function
|
|
78
84
|
self.jitter = jitter
|
|
79
85
|
|
|
80
|
-
def __call__(
|
|
86
|
+
def __call__(
|
|
87
|
+
self,
|
|
88
|
+
test_inputs: Num[Array, "N D"],
|
|
89
|
+
*,
|
|
90
|
+
return_covariance_type: Literal["dense", "diagonal"] = "dense",
|
|
91
|
+
) -> GaussianDistribution:
|
|
81
92
|
r"""Evaluate the Gaussian process at the given points.
|
|
82
93
|
|
|
83
94
|
The output of this function is a
|
|
@@ -91,15 +102,27 @@ class AbstractPrior(nnx.Module, tp.Generic[M, K]):
|
|
|
91
102
|
|
|
92
103
|
Args:
|
|
93
104
|
test_inputs: Input locations where the GP should be evaluated.
|
|
105
|
+
return_covariance_type: Literal denoting whether to return the full covariance
|
|
106
|
+
of the joint predictive distribution at the test_inputs (dense)
|
|
107
|
+
or just the the standard-deviation of the predictive distribution at
|
|
108
|
+
the test_inputs.
|
|
94
109
|
|
|
95
110
|
Returns:
|
|
96
111
|
GaussianDistribution: A multivariate normal random variable representation
|
|
97
112
|
of the Gaussian process.
|
|
98
113
|
"""
|
|
99
|
-
return self.predict(
|
|
114
|
+
return self.predict(
|
|
115
|
+
test_inputs,
|
|
116
|
+
return_covariance_type=return_covariance_type,
|
|
117
|
+
)
|
|
100
118
|
|
|
101
119
|
@abstractmethod
|
|
102
|
-
def predict(
|
|
120
|
+
def predict(
|
|
121
|
+
self,
|
|
122
|
+
test_inputs: Num[Array, "N D"],
|
|
123
|
+
*,
|
|
124
|
+
return_covariance_type: Literal["dense", "diagonal"] = "dense",
|
|
125
|
+
) -> GaussianDistribution:
|
|
103
126
|
r"""Evaluate the predictive distribution.
|
|
104
127
|
|
|
105
128
|
Compute the latent function's multivariate normal distribution for a
|
|
@@ -108,6 +131,10 @@ class AbstractPrior(nnx.Module, tp.Generic[M, K]):
|
|
|
108
131
|
|
|
109
132
|
Args:
|
|
110
133
|
test_inputs: Input locations where the GP should be evaluated.
|
|
134
|
+
return_covariance_type: Literal denoting whether to return the full covariance
|
|
135
|
+
of the joint predictive distribution at the test_inputs (dense)
|
|
136
|
+
or just the the standard-deviation of the predictive distribution at
|
|
137
|
+
the test_inputs.
|
|
111
138
|
|
|
112
139
|
Returns:
|
|
113
140
|
GaussianDistribution: A multivariate normal random variable representation
|
|
@@ -123,8 +150,8 @@ class Prior(AbstractPrior[M, K]):
|
|
|
123
150
|
r"""A Gaussian process prior object.
|
|
124
151
|
|
|
125
152
|
The GP is parameterised by a
|
|
126
|
-
[mean](https://docs.
|
|
127
|
-
and [kernel](https://docs.
|
|
153
|
+
[mean](https://docs.thomaspinder.com/api/mean_functions/)
|
|
154
|
+
and [kernel](https://docs.thomaspinder.com/api/kernels/base/)
|
|
128
155
|
function.
|
|
129
156
|
|
|
130
157
|
A Gaussian process prior parameterised by a mean function $m(\cdot)$ and a kernel
|
|
@@ -220,7 +247,12 @@ class Prior(AbstractPrior[M, K]):
|
|
|
220
247
|
"""
|
|
221
248
|
return self.__mul__(other)
|
|
222
249
|
|
|
223
|
-
def predict(
|
|
250
|
+
def predict(
|
|
251
|
+
self,
|
|
252
|
+
test_inputs: Num[Array, "N D"],
|
|
253
|
+
*,
|
|
254
|
+
return_covariance_type: Literal["dense", "diagonal"] = "dense",
|
|
255
|
+
) -> GaussianDistribution:
|
|
224
256
|
r"""Compute the predictive prior distribution for a given set of
|
|
225
257
|
parameters. The output of this function is a function that computes
|
|
226
258
|
a TFP distribution for a given set of inputs.
|
|
@@ -241,17 +273,43 @@ class Prior(AbstractPrior[M, K]):
|
|
|
241
273
|
Args:
|
|
242
274
|
test_inputs (Float[Array, "N D"]): The inputs at which to evaluate the
|
|
243
275
|
prior distribution.
|
|
276
|
+
return_covariance_type: Literal denoting whether to return the full covariance
|
|
277
|
+
of the joint predictive distribution at the test_inputs (dense)
|
|
278
|
+
or just the the standard-deviation of the predictive distribution at
|
|
279
|
+
the test_inputs.
|
|
244
280
|
|
|
245
281
|
Returns:
|
|
246
282
|
GaussianDistribution: A multivariate normal random variable representation
|
|
247
283
|
of the Gaussian process.
|
|
248
284
|
"""
|
|
285
|
+
|
|
286
|
+
def _return_full_covariance(
|
|
287
|
+
t: Num[Array, "N D"],
|
|
288
|
+
) -> Dense:
|
|
289
|
+
Kxx = self.kernel.gram(t)
|
|
290
|
+
Kxx_dense = add_jitter(Kxx.to_dense(), self.jitter)
|
|
291
|
+
Kxx = psd(Dense(Kxx_dense))
|
|
292
|
+
return Kxx
|
|
293
|
+
|
|
294
|
+
def _return_diagonal_covariance(
|
|
295
|
+
t: Num[Array, "N D"],
|
|
296
|
+
) -> Dense:
|
|
297
|
+
Kxx = self.kernel.diagonal(t).diagonal
|
|
298
|
+
Kxx += self.jitter
|
|
299
|
+
Kxx = psd(Dense(Diagonal(Kxx).to_dense()))
|
|
300
|
+
return Kxx
|
|
301
|
+
|
|
249
302
|
mean_at_test = self.mean_function(test_inputs)
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
303
|
+
cov = jax.lax.cond(
|
|
304
|
+
return_covariance_type == "dense",
|
|
305
|
+
_return_full_covariance,
|
|
306
|
+
_return_diagonal_covariance,
|
|
307
|
+
test_inputs,
|
|
308
|
+
)
|
|
253
309
|
|
|
254
|
-
return GaussianDistribution(
|
|
310
|
+
return GaussianDistribution(
|
|
311
|
+
loc=jnp.atleast_1d(mean_at_test.squeeze()), scale=cov
|
|
312
|
+
)
|
|
255
313
|
|
|
256
314
|
def sample_approx(
|
|
257
315
|
self,
|
|
@@ -329,7 +387,7 @@ P = tp.TypeVar("P", bound=AbstractPrior)
|
|
|
329
387
|
|
|
330
388
|
#######################
|
|
331
389
|
# GP Posteriors
|
|
332
|
-
#######################
|
|
390
|
+
#######################from gpjax.linalg.operators import LinearOperator
|
|
333
391
|
class AbstractPosterior(nnx.Module, tp.Generic[P, L]):
|
|
334
392
|
r"""Abstract Gaussian process posterior.
|
|
335
393
|
|
|
@@ -356,7 +414,11 @@ class AbstractPosterior(nnx.Module, tp.Generic[P, L]):
|
|
|
356
414
|
self.jitter = jitter
|
|
357
415
|
|
|
358
416
|
def __call__(
|
|
359
|
-
self,
|
|
417
|
+
self,
|
|
418
|
+
test_inputs: Num[Array, "N D"],
|
|
419
|
+
train_data: Dataset,
|
|
420
|
+
*,
|
|
421
|
+
return_covariance_type: Literal["dense", "diagonal"] = "dense",
|
|
360
422
|
) -> GaussianDistribution:
|
|
361
423
|
r"""Evaluate the Gaussian process posterior at the given points.
|
|
362
424
|
|
|
@@ -372,16 +434,28 @@ class AbstractPosterior(nnx.Module, tp.Generic[P, L]):
|
|
|
372
434
|
Args:
|
|
373
435
|
test_inputs: Input locations where the GP should be evaluated.
|
|
374
436
|
train_data: Training dataset to condition on.
|
|
437
|
+
return_covariance_type: Literal denoting whether to return the full covariance
|
|
438
|
+
of the joint predictive distribution at the test_inputs (dense)
|
|
439
|
+
or just the the standard-deviation of the predictive distribution at
|
|
440
|
+
the test_inputs.
|
|
375
441
|
|
|
376
442
|
Returns:
|
|
377
443
|
GaussianDistribution: A multivariate normal random variable representation
|
|
378
444
|
of the Gaussian process.
|
|
379
445
|
"""
|
|
380
|
-
return self.predict(
|
|
446
|
+
return self.predict(
|
|
447
|
+
test_inputs,
|
|
448
|
+
train_data,
|
|
449
|
+
return_covariance_type=return_covariance_type,
|
|
450
|
+
)
|
|
381
451
|
|
|
382
452
|
@abstractmethod
|
|
383
453
|
def predict(
|
|
384
|
-
self,
|
|
454
|
+
self,
|
|
455
|
+
test_inputs: Num[Array, "N D"],
|
|
456
|
+
train_data: Dataset,
|
|
457
|
+
*,
|
|
458
|
+
return_covariance_type: Literal["dense", "diagonal"] = "dense",
|
|
385
459
|
) -> GaussianDistribution:
|
|
386
460
|
r"""Compute the latent function's multivariate normal distribution for a
|
|
387
461
|
given set of parameters. For any class inheriting the `AbstractPosterior` class,
|
|
@@ -390,6 +464,10 @@ class AbstractPosterior(nnx.Module, tp.Generic[P, L]):
|
|
|
390
464
|
Args:
|
|
391
465
|
test_inputs: Input locations where the GP should be evaluated.
|
|
392
466
|
train_data: Training dataset to condition on.
|
|
467
|
+
return_covariance_type: Literal denoting whether to return the full covariance
|
|
468
|
+
of the joint predictive distribution at the test_inputs (dense)
|
|
469
|
+
or just the the standard-deviation of the predictive distribution at
|
|
470
|
+
the test_inputs.
|
|
393
471
|
|
|
394
472
|
Returns:
|
|
395
473
|
GaussianDistribution: A multivariate normal random variable representation
|
|
@@ -442,8 +520,10 @@ class ConjugatePosterior(AbstractPosterior[P, GL]):
|
|
|
442
520
|
|
|
443
521
|
def predict(
|
|
444
522
|
self,
|
|
445
|
-
test_inputs: Num[Array, "
|
|
523
|
+
test_inputs: Num[Array, "M D"],
|
|
446
524
|
train_data: Dataset,
|
|
525
|
+
*,
|
|
526
|
+
return_covariance_type: Literal["dense", "diagonal"] = "dense",
|
|
447
527
|
) -> GaussianDistribution:
|
|
448
528
|
r"""Query the predictive posterior distribution.
|
|
449
529
|
|
|
@@ -454,13 +534,13 @@ class ConjugatePosterior(AbstractPosterior[P, GL]):
|
|
|
454
534
|
|
|
455
535
|
The predictive distribution of a conjugate GP is given by
|
|
456
536
|
$$
|
|
457
|
-
|
|
458
|
-
|
|
537
|
+
p(\mathbf{f}^{\star}\mid \mathbf{y}) & = \int p(\mathbf{f}^{\star} \mathbf{f} \mid \mathbf{y})\\
|
|
538
|
+
& =\mathcal{N}(\mathbf{f}^{\star} \boldsymbol{\mu}_{\mid \mathbf{y}}, \boldsymbol{\Sigma}_{\mid \mathbf{y}}
|
|
459
539
|
$$
|
|
460
540
|
where
|
|
461
541
|
$$
|
|
462
|
-
|
|
463
|
-
|
|
542
|
+
\boldsymbol{\mu}_{\mid \mathbf{y}} & = k(\mathbf{x}^{\star}, \mathbf{x})\left(k(\mathbf{x}, \mathbf{x}')+\sigma^2\mathbf{I}_n\right)^{-1}\mathbf{y} \\
|
|
543
|
+
\boldsymbol{\Sigma}_{\mid \mathbf{y}} & =k(\mathbf{x}^{\star}, \mathbf{x}^{\star\prime}) -k(\mathbf{x}^{\star}, \mathbf{x})\left( k(\mathbf{x}, \mathbf{x}') + \sigma^2\mathbf{I}_n \right)^{-1}k(\mathbf{x}, \mathbf{x}^{\star}).
|
|
464
544
|
$$
|
|
465
545
|
|
|
466
546
|
The conditioning set is a GPJax `Dataset` object, whilst predictions
|
|
@@ -486,44 +566,65 @@ class ConjugatePosterior(AbstractPosterior[P, GL]):
|
|
|
486
566
|
predictive distribution is evaluated.
|
|
487
567
|
train_data (Dataset): A `gpx.Dataset` object that contains the input and
|
|
488
568
|
output data used for training dataset.
|
|
569
|
+
return_covariance_type: Literal denoting whether to return the full covariance
|
|
570
|
+
of the joint predictive distribution at the test_inputs (dense)
|
|
571
|
+
or just the the standard-deviation of the predictive distribution at
|
|
572
|
+
the test_inputs.
|
|
489
573
|
|
|
490
574
|
Returns:
|
|
491
575
|
GaussianDistribution: A function that accepts an input array and
|
|
492
576
|
returns the predictive distribution as a `GaussianDistribution`.
|
|
493
577
|
"""
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
# Unpack test inputs
|
|
498
|
-
t = test_inputs
|
|
499
|
-
|
|
578
|
+
x = train_data.X
|
|
579
|
+
y = train_data.y
|
|
500
580
|
# Observation noise o²
|
|
501
|
-
obs_noise = self.likelihood.obs_stddev.value
|
|
581
|
+
obs_noise = jnp.square(self.likelihood.obs_stddev.value)
|
|
502
582
|
mx = self.prior.mean_function(x)
|
|
503
|
-
|
|
504
583
|
# Precompute Gram matrix, Kxx, at training inputs, x
|
|
505
584
|
Kxx = self.prior.kernel.gram(x)
|
|
506
|
-
|
|
507
|
-
Kxx = Dense(Kxx_dense)
|
|
585
|
+
Kxx = add_jitter(Kxx.to_dense(), self.jitter)
|
|
508
586
|
|
|
509
|
-
Sigma_dense = Kxx
|
|
587
|
+
Sigma_dense = Kxx + jnp.eye(Kxx.shape[0]) * obs_noise
|
|
510
588
|
Sigma = psd(Dense(Sigma_dense))
|
|
511
589
|
L_sigma = lower_cholesky(Sigma)
|
|
512
590
|
|
|
513
|
-
|
|
514
|
-
Ktt = self.prior.kernel.gram(t)
|
|
515
|
-
Kxt = self.prior.kernel.cross_covariance(x, t)
|
|
591
|
+
Kxt = self.prior.kernel.cross_covariance(x, test_inputs)
|
|
516
592
|
|
|
517
593
|
L_inv_Kxt = solve(L_sigma, Kxt)
|
|
518
594
|
L_inv_y_diff = solve(L_sigma, y - mx)
|
|
519
595
|
|
|
596
|
+
mean_t = self.prior.mean_function(test_inputs)
|
|
520
597
|
mean = mean_t + jnp.matmul(L_inv_Kxt.T, L_inv_y_diff)
|
|
521
598
|
|
|
522
|
-
|
|
523
|
-
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
|
|
599
|
+
def _return_full_covariance(
|
|
600
|
+
L_inv_Kxt: Num[Array, "N M"],
|
|
601
|
+
t: Num[Array, "M D"],
|
|
602
|
+
) -> Dense:
|
|
603
|
+
Ktt = self.prior.kernel.gram(t)
|
|
604
|
+
covariance = Ktt.to_dense() - jnp.matmul(L_inv_Kxt.T, L_inv_Kxt)
|
|
605
|
+
covariance = add_jitter(covariance, self.prior.jitter)
|
|
606
|
+
covariance = psd(Dense(covariance))
|
|
607
|
+
return covariance
|
|
608
|
+
|
|
609
|
+
def _return_diagonal_covariance(
|
|
610
|
+
L_inv_Kxt: Num[Array, "N M"],
|
|
611
|
+
t: Num[Array, "M D"],
|
|
612
|
+
) -> Dense:
|
|
613
|
+
Ktt = self.prior.kernel.diagonal(t).diagonal
|
|
614
|
+
covariance = Ktt - jnp.einsum("ij, ji->i", L_inv_Kxt.T, L_inv_Kxt)
|
|
615
|
+
covariance += self.prior.jitter
|
|
616
|
+
covariance = psd(Dense(jnp.diag(jnp.atleast_1d(covariance.squeeze()))))
|
|
617
|
+
return covariance
|
|
618
|
+
|
|
619
|
+
cov = jax.lax.cond(
|
|
620
|
+
return_covariance_type == "dense",
|
|
621
|
+
_return_full_covariance,
|
|
622
|
+
_return_diagonal_covariance,
|
|
623
|
+
L_inv_Kxt,
|
|
624
|
+
test_inputs,
|
|
625
|
+
)
|
|
626
|
+
|
|
627
|
+
return GaussianDistribution(loc=jnp.atleast_1d(mean.squeeze()), scale=cov)
|
|
527
628
|
|
|
528
629
|
def sample_approx(
|
|
529
630
|
self,
|
|
@@ -567,7 +668,7 @@ class ConjugatePosterior(AbstractPosterior[P, GL]):
|
|
|
567
668
|
|
|
568
669
|
Returns:
|
|
569
670
|
FunctionalSample: A function representing an approximate sample from the Gaussian
|
|
570
|
-
|
|
671
|
+
process prior.
|
|
571
672
|
"""
|
|
572
673
|
if (not isinstance(num_samples, int)) or num_samples <= 0:
|
|
573
674
|
raise ValueError("num_samples must be a positive integer")
|
|
@@ -586,7 +687,7 @@ class ConjugatePosterior(AbstractPosterior[P, GL]):
|
|
|
586
687
|
canonical_weights = solve(
|
|
587
688
|
Sigma,
|
|
588
689
|
y + eps - jnp.inner(Phi, fourier_weights),
|
|
589
|
-
) #
|
|
690
|
+
) # [N, B]
|
|
590
691
|
|
|
591
692
|
def sample_fn(test_inputs: Float[Array, "n D"]) -> Float[Array, "n B"]:
|
|
592
693
|
fourier_features = fourier_feature_fn(test_inputs)
|
|
@@ -648,7 +749,11 @@ class NonConjugatePosterior(AbstractPosterior[P, NGL]):
|
|
|
648
749
|
self.key = key
|
|
649
750
|
|
|
650
751
|
def predict(
|
|
651
|
-
self,
|
|
752
|
+
self,
|
|
753
|
+
test_inputs: Num[Array, "M D"],
|
|
754
|
+
train_data: Dataset,
|
|
755
|
+
*,
|
|
756
|
+
return_covariance_type: Literal["dense", "diagonal"] = "dense",
|
|
652
757
|
) -> GaussianDistribution:
|
|
653
758
|
r"""Query the predictive posterior distribution.
|
|
654
759
|
|
|
@@ -660,50 +765,78 @@ class NonConjugatePosterior(AbstractPosterior[P, NGL]):
|
|
|
660
765
|
transformed through the likelihood function's inverse link function.
|
|
661
766
|
|
|
662
767
|
Args:
|
|
663
|
-
|
|
664
|
-
|
|
768
|
+
test_inputs (Num[Array, "N D"]): A Jax array of test inputs at which the
|
|
769
|
+
predictive distribution is evaluated.
|
|
770
|
+
train_data (Dataset): A `gpx.Dataset` object that contains the input
|
|
771
|
+
and output data used for training dataset.
|
|
772
|
+
return_covariance_type: Literal denoting whether to return the full covariance
|
|
773
|
+
of the joint predictive distribution at the test_inputs (dense)
|
|
774
|
+
or just the the standard-deviation of the predictive distribution at
|
|
775
|
+
the test_inputs.
|
|
665
776
|
|
|
666
777
|
Returns:
|
|
667
778
|
GaussianDistribution: A function that accepts an
|
|
668
779
|
input array and returns the predictive distribution as
|
|
669
780
|
a `dx.Distribution`.
|
|
670
781
|
"""
|
|
671
|
-
# Unpack training data
|
|
672
782
|
x = train_data.X
|
|
673
|
-
|
|
674
|
-
# Unpack mean function and kernel
|
|
783
|
+
t = test_inputs
|
|
675
784
|
mean_function = self.prior.mean_function
|
|
676
785
|
kernel = self.prior.kernel
|
|
677
786
|
|
|
678
|
-
# Precompute lower triangular of Gram matrix
|
|
787
|
+
# Precompute lower triangular of Gram matrix
|
|
679
788
|
Kxx = kernel.gram(x)
|
|
680
789
|
Kxx_dense = add_jitter(Kxx.to_dense(), self.prior.jitter)
|
|
681
790
|
Kxx = psd(Dense(Kxx_dense))
|
|
682
791
|
Lx = lower_cholesky(Kxx)
|
|
683
792
|
|
|
684
|
-
|
|
685
|
-
t = test_inputs
|
|
686
|
-
|
|
687
|
-
# Compute terms of the posterior predictive distribution
|
|
688
|
-
Ktx = kernel.cross_covariance(t, x)
|
|
689
|
-
Ktt = kernel.gram(t)
|
|
690
|
-
mean_t = mean_function(t)
|
|
691
|
-
|
|
793
|
+
Kxt = kernel.cross_covariance(x, t)
|
|
692
794
|
# Lx⁻¹ Kxt
|
|
693
|
-
Lx_inv_Kxt = solve(Lx,
|
|
795
|
+
Lx_inv_Kxt = solve(Lx, Kxt)
|
|
694
796
|
|
|
797
|
+
mean_t = mean_function(t)
|
|
695
798
|
# Whitened function values, wx, corresponding to the inputs, x
|
|
696
799
|
wx = self.latent.value
|
|
697
800
|
|
|
698
801
|
# μt + Ktx Lx⁻¹ wx
|
|
699
802
|
mean = mean_t + jnp.matmul(Lx_inv_Kxt.T, wx)
|
|
700
803
|
|
|
701
|
-
|
|
702
|
-
|
|
703
|
-
|
|
704
|
-
|
|
705
|
-
|
|
706
|
-
|
|
804
|
+
def _return_full_covariance(
|
|
805
|
+
Lx_inv_Kxt: Num[Array, "N M"],
|
|
806
|
+
t: Num[Array, "M D"],
|
|
807
|
+
) -> Dense:
|
|
808
|
+
Ktt = kernel.gram(t)
|
|
809
|
+
covariance = Ktt.to_dense() - jnp.matmul(Lx_inv_Kxt.T, Lx_inv_Kxt)
|
|
810
|
+
covariance = add_jitter(covariance, self.prior.jitter)
|
|
811
|
+
covariance = psd(Dense(covariance))
|
|
812
|
+
|
|
813
|
+
return covariance
|
|
814
|
+
|
|
815
|
+
def _return_diagonal_covariance(
|
|
816
|
+
Lx_inv_Kxt: Num[Array, "N M"],
|
|
817
|
+
t: Num[Array, "M D"],
|
|
818
|
+
) -> Dense:
|
|
819
|
+
Ktt = kernel.diagonal(t).diagonal
|
|
820
|
+
covariance = Ktt - jnp.einsum("ij, ji->i", Lx_inv_Kxt.T, Lx_inv_Kxt)
|
|
821
|
+
covariance += self.prior.jitter
|
|
822
|
+
# It would be nice to return a Diagonal here, but the pytree needs
|
|
823
|
+
# to be the same for both cond branches and the other branch needs
|
|
824
|
+
# to return a Dense.
|
|
825
|
+
# They are both LinearOperators, but they inherit from that class
|
|
826
|
+
# and hence are not the same pytree anymore.
|
|
827
|
+
covariance = psd(Dense(jnp.diag(jnp.atleast_1d(covariance.squeeze()))))
|
|
828
|
+
|
|
829
|
+
return covariance
|
|
830
|
+
|
|
831
|
+
cov = jax.lax.cond(
|
|
832
|
+
return_covariance_type == "dense",
|
|
833
|
+
_return_full_covariance,
|
|
834
|
+
_return_diagonal_covariance,
|
|
835
|
+
Lx_inv_Kxt,
|
|
836
|
+
test_inputs,
|
|
837
|
+
)
|
|
838
|
+
|
|
839
|
+
return GaussianDistribution(jnp.atleast_1d(mean.squeeze()), cov)
|
|
707
840
|
|
|
708
841
|
|
|
709
842
|
#######################
|
gpjax/kernels/__init__.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright 2022 The
|
|
1
|
+
# Copyright 2022 The thomaspinder Contributors. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
gpjax/kernels/base.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright 2022 The
|
|
1
|
+
# Copyright 2022 The thomaspinder Contributors. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -123,7 +123,7 @@ class AbstractKernel(nnx.Module):
|
|
|
123
123
|
"""
|
|
124
124
|
return self.compute_engine.gram(self, x)
|
|
125
125
|
|
|
126
|
-
def diagonal(self, x: Num[Array, "N D"]) ->
|
|
126
|
+
def diagonal(self, x: Num[Array, "N D"]) -> LinearOperator:
|
|
127
127
|
r"""Compute the diagonal of the gram matrix of the kernel.
|
|
128
128
|
|
|
129
129
|
Args:
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright 2022 The
|
|
1
|
+
# Copyright 2022 The thomaspinder Contributors. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright 2022 The
|
|
1
|
+
# Copyright 2022 The thomaspinder Contributors. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -6,9 +6,7 @@ from jaxtyping import Float
|
|
|
6
6
|
import gpjax
|
|
7
7
|
from gpjax.kernels.computations.base import AbstractKernelComputation
|
|
8
8
|
from gpjax.linalg import (
|
|
9
|
-
Dense,
|
|
10
9
|
Diagonal,
|
|
11
|
-
psd,
|
|
12
10
|
)
|
|
13
11
|
from gpjax.typing import Array
|
|
14
12
|
|
|
@@ -27,9 +25,9 @@ class BasisFunctionComputation(AbstractKernelComputation):
|
|
|
27
25
|
z2 = self.compute_features(kernel, y)
|
|
28
26
|
return self.scaling(kernel) * jnp.matmul(z1, z2.T)
|
|
29
27
|
|
|
30
|
-
def _gram(self, kernel: K, inputs: Float[Array, "N D"]) ->
|
|
28
|
+
def _gram(self, kernel: K, inputs: Float[Array, "N D"]) -> Float[Array, "N N"]:
|
|
31
29
|
z1 = self.compute_features(kernel, inputs)
|
|
32
|
-
return
|
|
30
|
+
return self.scaling(kernel) * jnp.matmul(z1, z1.T)
|
|
33
31
|
|
|
34
32
|
def diagonal(self, kernel: K, inputs: Float[Array, "N D"]) -> Diagonal:
|
|
35
33
|
r"""For a given kernel, compute the elementwise diagonal of the
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright 2022 The
|
|
1
|
+
# Copyright 2022 The thomaspinder Contributors. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright 2022 The
|
|
1
|
+
# Copyright 2022 The thomaspinder Contributors. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright 2022 The
|
|
1
|
+
# Copyright 2022 The thomaspinder Contributors. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright 2022 The
|
|
1
|
+
# Copyright 2022 The thomaspinder Contributors. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright 2022 The
|
|
1
|
+
# Copyright 2022 The thomaspinder Contributors. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright 2022 The
|
|
1
|
+
# Copyright 2022 The thomaspinder Contributors. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -17,7 +17,7 @@ import beartype.typing as tp
|
|
|
17
17
|
import jax.numpy as jnp
|
|
18
18
|
from jaxtyping import (
|
|
19
19
|
Float,
|
|
20
|
-
|
|
20
|
+
Integer,
|
|
21
21
|
Num,
|
|
22
22
|
)
|
|
23
23
|
|
|
@@ -103,11 +103,23 @@ class GraphKernel(StationaryKernel):
|
|
|
103
103
|
|
|
104
104
|
def __call__(
|
|
105
105
|
self,
|
|
106
|
-
x:
|
|
107
|
-
y:
|
|
106
|
+
x: ScalarInt | Integer[Array, " N"] | Integer[Array, "N 1"],
|
|
107
|
+
y: ScalarInt | Integer[Array, " M"] | Integer[Array, "M 1"],
|
|
108
108
|
):
|
|
109
|
+
x_idx = self._prepare_indices(x)
|
|
110
|
+
y_idx = self._prepare_indices(y)
|
|
109
111
|
S = calculate_heat_semigroup(self)
|
|
110
|
-
Kxx = (jax_gather_nd(self.eigenvectors,
|
|
111
|
-
jax_gather_nd(self.eigenvectors,
|
|
112
|
+
Kxx = (jax_gather_nd(self.eigenvectors, x_idx) * S.squeeze()) @ jnp.transpose(
|
|
113
|
+
jax_gather_nd(self.eigenvectors, y_idx)
|
|
112
114
|
) # shape (n,n)
|
|
113
115
|
return Kxx.squeeze()
|
|
116
|
+
|
|
117
|
+
def _prepare_indices(
|
|
118
|
+
self,
|
|
119
|
+
indices: ScalarInt | Integer[Array, " N"] | Integer[Array, "N 1"],
|
|
120
|
+
) -> Integer[Array, "N 1"]:
|
|
121
|
+
"""Ensure index arrays are integer column vectors regardless of caller shape."""
|
|
122
|
+
|
|
123
|
+
idx = jnp.asarray(indices, dtype=jnp.int32)
|
|
124
|
+
idx = jnp.atleast_1d(idx)
|
|
125
|
+
return idx.reshape(-1, 1)
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright 2022 The
|
|
1
|
+
# Copyright 2022 The thomaspinder Contributors. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright 2022 The
|
|
1
|
+
# Copyright 2022 The thomaspinder Contributors. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright 2022 The
|
|
1
|
+
# Copyright 2022 The thomaspinder Contributors. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright 2022 The
|
|
1
|
+
# Copyright 2022 The thomaspinder Contributors. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright 2022 The
|
|
1
|
+
# Copyright 2022 The thomaspinder Contributors. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright 2022 The
|
|
1
|
+
# Copyright 2022 The thomaspinder Contributors. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
gpjax/kernels/stationary/base.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright 2022 The
|
|
1
|
+
# Copyright 2022 The thomaspinder Contributors. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright 2022 The
|
|
1
|
+
# Copyright 2022 The thomaspinder Contributors. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright 2022 The
|
|
1
|
+
# Copyright 2022 The thomaspinder Contributors. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright 2022 The
|
|
1
|
+
# Copyright 2022 The thomaspinder Contributors. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright 2022 The
|
|
1
|
+
# Copyright 2022 The thomaspinder Contributors. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright 2022 The
|
|
1
|
+
# Copyright 2022 The thomaspinder Contributors. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright 2022 The
|
|
1
|
+
# Copyright 2022 The thomaspinder Contributors. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
gpjax/kernels/stationary/rbf.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright 2022 The
|
|
1
|
+
# Copyright 2022 The thomaspinder Contributors. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright 2022 The
|
|
1
|
+
# Copyright 2022 The thomaspinder Contributors. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright 2022 The
|
|
1
|
+
# Copyright 2022 The thomaspinder Contributors. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
gpjax/scan.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright 2023 The
|
|
1
|
+
# Copyright 2023 The thomaspinder Contributors. All Rights Reserved.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -1,10 +1,10 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: gpjax
|
|
3
|
-
Version: 0.13.
|
|
3
|
+
Version: 0.13.3
|
|
4
4
|
Summary: Gaussian processes in JAX.
|
|
5
5
|
Project-URL: Documentation, https://docs.jaxgaussianprocesses.com/
|
|
6
|
-
Project-URL: Issues, https://github.com/
|
|
7
|
-
Project-URL: Source, https://github.com/
|
|
6
|
+
Project-URL: Issues, https://github.com/thomaspinder/GPJax/issues
|
|
7
|
+
Project-URL: Source, https://github.com/thomaspinder/GPJax
|
|
8
8
|
Author-email: Thomas Pinder <tompinder@live.co.uk>
|
|
9
9
|
License: MIT
|
|
10
10
|
License-File: LICENSE.txt
|
|
@@ -72,11 +72,11 @@ Description-Content-Type: text/markdown
|
|
|
72
72
|
<!-- <h1 align='center'>GPJax</h1>
|
|
73
73
|
<h2 align='center'>Gaussian processes in Jax.</h2> -->
|
|
74
74
|
<p align="center">
|
|
75
|
-
<img width="700" height="300" src="https://raw.githubusercontent.com/
|
|
75
|
+
<img width="700" height="300" src="https://raw.githubusercontent.com/thomaspinder/GPJax/main/docs/static/gpjax_logo.svg" alt="GPJax's logo">
|
|
76
76
|
</p>
|
|
77
77
|
|
|
78
|
-
[](https://codecov.io/gh/thomaspinder/GPJax)
|
|
79
|
+
[](https://www.codefactor.io/repository/github/thomaspinder/GPJax)
|
|
80
80
|
[](https://app.netlify.com/sites/endearing-crepe-c2d5fe/deploys)
|
|
81
81
|
[](https://badge.fury.io/py/GPJax)
|
|
82
82
|
[](https://anaconda.org/conda-forge/gpjax)
|
|
@@ -101,19 +101,19 @@ with GP models.
|
|
|
101
101
|
|
|
102
102
|
We would be delighted to receive contributions from interested individuals and
|
|
103
103
|
groups. To learn how you can get involved, please read our [guide for
|
|
104
|
-
contributing](https://github.com/
|
|
104
|
+
contributing](https://github.com/thomaspinder/GPJax/blob/main/docs/contributing.md).
|
|
105
105
|
If you have any questions, we encourage you to [open an
|
|
106
|
-
issue](https://github.com/
|
|
106
|
+
issue](https://github.com/thomaspinder/GPJax/issues/new/choose). For
|
|
107
107
|
broader conversations, such as best GP fitting practices or questions about the
|
|
108
108
|
mathematics of GPs, we invite you to [open a
|
|
109
|
-
discussion](https://github.com/
|
|
109
|
+
discussion](https://github.com/thomaspinder/GPJax/discussions).
|
|
110
110
|
|
|
111
111
|
Another way you can contribute to GPJax is through [issue
|
|
112
112
|
triaging](https://www.codetriage.com/what). This can include reproducing bug reports,
|
|
113
113
|
asking for vital information such as version numbers and reproduction instructions, or
|
|
114
114
|
identifying stale issues. If you would like to begin triaging issues, an easy way to get
|
|
115
115
|
started is to
|
|
116
|
-
[subscribe to GPJax on CodeTriage](https://www.codetriage.com/
|
|
116
|
+
[subscribe to GPJax on CodeTriage](https://www.codetriage.com/thomaspinder/GPJax).
|
|
117
117
|
|
|
118
118
|
As a contributor to GPJax, you are expected to abide by our [code of
|
|
119
119
|
conduct](docs/CODE_OF_CONDUCT.md). If you feel that you have either experienced or
|
|
@@ -127,7 +127,7 @@ where we can discuss the development of GPJax and broader support for Gaussian
|
|
|
127
127
|
process modelling.
|
|
128
128
|
|
|
129
129
|
We appreciate all [the contributors to
|
|
130
|
-
GPJax](https://github.com/
|
|
130
|
+
GPJax](https://github.com/thomaspinder/GPJax/graphs/contributors) who have helped to shape
|
|
131
131
|
GPJax into the package it is today.
|
|
132
132
|
|
|
133
133
|
# Supported methods and interfaces
|
|
@@ -215,7 +215,7 @@ conda install --channel conda-forge gpjax
|
|
|
215
215
|
Clone a copy of the repository to your local machine and run the setup
|
|
216
216
|
configuration in development mode.
|
|
217
217
|
```bash
|
|
218
|
-
git clone https://github.com/
|
|
218
|
+
git clone https://github.com/thomaspinder/GPJax.git
|
|
219
219
|
cd GPJax
|
|
220
220
|
uv venv
|
|
221
221
|
uv sync --extra dev
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
gpjax/__init__.py,sha256=8EULPS_vtq4TeN6MjdLVHLlTVVyQADOoHtuVoRN8z5Y,1625
|
|
2
|
+
gpjax/citation.py,sha256=pwFS8h1J-LE5ieRS0zDyuwhmQHNxkFHYE7iSMlVNmQc,3928
|
|
3
|
+
gpjax/dataset.py,sha256=Ef5JGrl4jJS1mQmL3JdO0fdqbVmflT_Cu5VrlpYdJY4,4071
|
|
4
|
+
gpjax/distributions.py,sha256=iKmeQ_NN2CIjRiuOeJlwEGASzGROi4ZCerVi1uY7zRM,7758
|
|
5
|
+
gpjax/fit.py,sha256=tOXmM3l5-N3Jlnq8MOVkSpRj-0fRWOy2t9GxQRyUqxY,15318
|
|
6
|
+
gpjax/gps.py,sha256=NcPXkkx0kXrSBTUje4QR6PS0DGrD7p5c-DyiOATwUz8,35338
|
|
7
|
+
gpjax/integrators.py,sha256=eyJPqWNPKj6pKP5da0fEj4HW7BVyevqeGrurEuy_XPw,5694
|
|
8
|
+
gpjax/likelihoods.py,sha256=xwnSQpn6Aa-FPpEoDn_3xpBdPQAmHP97jP-9iJmT4G8,9087
|
|
9
|
+
gpjax/mean_functions.py,sha256=Aq09I5h5DZe1WokRxLa0Mpj8U0kJxpQ8CmdJJyCjNOc,6541
|
|
10
|
+
gpjax/numpyro_extras.py,sha256=-vWJ7SpZVNhSdCjjrlxIkovMFrM1IzpsMJK3B4LioGE,3411
|
|
11
|
+
gpjax/objectives.py,sha256=GvKbDIPqYjsc9FpiTccmZwRdHr2lCykgfxI9BX9I_GA,15362
|
|
12
|
+
gpjax/parameters.py,sha256=hnyIKr6uIzd7Kb3KZC9WowR88ruQwUvdcto3cx2ZDv4,6756
|
|
13
|
+
gpjax/scan.py,sha256=Z_V1yd76dL4B7_rnJnr_fohom6xzN_WRYxlTAAdqfa0,5357
|
|
14
|
+
gpjax/typing.py,sha256=M3CvWsYtZ3PFUvBvvbRNjpwerNII0w4yGuP0I-sLeYI,1705
|
|
15
|
+
gpjax/variational_families.py,sha256=x4VnUh9GiW77ijnDEwrYzH0aWFOW6NLPQXEp--I9R-g,31566
|
|
16
|
+
gpjax/kernels/__init__.py,sha256=GFiku1s8KPPzvfwDYzEjENPQReywvJ29M9Kl2JjkcoU,1893
|
|
17
|
+
gpjax/kernels/base.py,sha256=gov60BgXOQK0PxbKQRVgjU55kCVZYAkDZrAInODvugo,11549
|
|
18
|
+
gpjax/kernels/approximations/__init__.py,sha256=bK9HlGd-PZeGrqtG5RpXxUTXNUrZTgfjH1dP626yNMA,68
|
|
19
|
+
gpjax/kernels/approximations/rff.py,sha256=WJ2Bw4RRW5bLXGKfStxph65H-ozDelvt_4mxXQm-g7w,4074
|
|
20
|
+
gpjax/kernels/computations/__init__.py,sha256=gSkYASLcyPwcIg12Jjv7vxW9rixMVUhnlcl07iNvN_8,1371
|
|
21
|
+
gpjax/kernels/computations/base.py,sha256=Fh_ppC7JwrbqKNJj1JqDOaUUGm23Zdm7RRLnGbQtePI,3608
|
|
22
|
+
gpjax/kernels/computations/basis_functions.py,sha256=Y6OS4UEvo-Mdn7yRMWPde5rOvewDHIleaUIVBOpCTd4,2466
|
|
23
|
+
gpjax/kernels/computations/constant_diagonal.py,sha256=ihxR0rcwcLi_sv3a026nEXRObMvmScBLOKj09XGwuns,1976
|
|
24
|
+
gpjax/kernels/computations/dense.py,sha256=TGbAXSoaqo3Erzsaj7NztrEN2a_sHW5mi3bPbA1W9nc,1384
|
|
25
|
+
gpjax/kernels/computations/diagonal.py,sha256=4V8aVXqC61GdgmFTvxDhmtMw1uSPi_C-m-d9-Fo6B8U,1760
|
|
26
|
+
gpjax/kernels/computations/eigen.py,sha256=koVBKeLrHZX77PLlrykM-h2FN5g7l7S5xoyGUKD69Oc,1372
|
|
27
|
+
gpjax/kernels/non_euclidean/__init__.py,sha256=Jb3K9a5-H-kNOirkrJJ46e3sgJIvoQS1-LfSsNXspJg,782
|
|
28
|
+
gpjax/kernels/non_euclidean/graph.py,sha256=SevOCNeVZ5eK0MPhYGLnohAvNqig0C5c-oOpKuUKxeA,4631
|
|
29
|
+
gpjax/kernels/non_euclidean/utils.py,sha256=PANh8fPOXnV8x6liTCJ08YYnDIQ33PrpLD-9hzjSg9w,2334
|
|
30
|
+
gpjax/kernels/nonstationary/__init__.py,sha256=X6WVSMrejtLH6k_ct-1uZ9dmdmd_IEBIneweceXgXI0,922
|
|
31
|
+
gpjax/kernels/nonstationary/arccosine.py,sha256=hjh6umWkSYs2xWa0rJQDPnmmkvk-vAwbGhMrkCev70o,5400
|
|
32
|
+
gpjax/kernels/nonstationary/linear.py,sha256=xjeG4DN5K3-fgRw-HTFvPfcxYfYO5vPxdZc6K5qvnW8,2563
|
|
33
|
+
gpjax/kernels/nonstationary/polynomial.py,sha256=D_czngVP0VGqFlqi_uJBayKegYs1Q-sHTCWCRDUWVAo,3380
|
|
34
|
+
gpjax/kernels/stationary/__init__.py,sha256=b8KuCIENnW1nVQv3KdlZmOicfgIdoZzqgvVKlkgzkUk,1398
|
|
35
|
+
gpjax/kernels/stationary/base.py,sha256=zIg21SK5nTg9kW0RJy6xXIp7XF7TjZxJT9Vvpsscz4M,6030
|
|
36
|
+
gpjax/kernels/stationary/matern12.py,sha256=FbhdY5xlkgoCCOizVd7ZKK2XjJOCMDmfn2bTdxTOIKw,1763
|
|
37
|
+
gpjax/kernels/stationary/matern32.py,sha256=tHHMN3pNMf5sGZNyCowi1RUUB0SENdkjHgQE8ZS3kMU,1956
|
|
38
|
+
gpjax/kernels/stationary/matern52.py,sha256=TjS2_R9tYga8ePpSd5UIBTKDCsGWJX8AlUVvElYAKYk,2009
|
|
39
|
+
gpjax/kernels/stationary/periodic.py,sha256=sCGWFYSTM6P91AFw9qTThAtzP5spatNrcU66Lr9qQjU,3520
|
|
40
|
+
gpjax/kernels/stationary/powered_exponential.py,sha256=Z0q9B6mZysL8pKwFPf_kYsJ8hUOf2E8rSlxhbl0U3f8,3671
|
|
41
|
+
gpjax/kernels/stationary/rational_quadratic.py,sha256=9X6X6YEIW8021XPzkCOMUihdWS0_JnBVO1b2JToFa2w,3410
|
|
42
|
+
gpjax/kernels/stationary/rbf.py,sha256=uuhDlV-kZZ3gL0ZRpKWAOh3pVTpkddEVIILsWdEsMmM,1682
|
|
43
|
+
gpjax/kernels/stationary/utils.py,sha256=wcS4rRJmNIeL3Log7gcK7klWWprulF4krQJq803alhk,2172
|
|
44
|
+
gpjax/kernels/stationary/white.py,sha256=rPVf2xzfv2WpPpb8E7LXGuguiyZ6QKa7M3ltzLItmpk,2204
|
|
45
|
+
gpjax/linalg/__init__.py,sha256=F8mxk_9Zc2nFd7Q-unjJ50_6rXEKzZj572WsU_jUKqI,547
|
|
46
|
+
gpjax/linalg/operations.py,sha256=xvhOy5P4FmUCPWjIVNdg1yDXaoFQ48anFUfR-Tnfr6k,6480
|
|
47
|
+
gpjax/linalg/operators.py,sha256=arxRGwcoAy_RqUYqBpZ3XG6OXbjShUl7m8sTpg85npE,11608
|
|
48
|
+
gpjax/linalg/utils.py,sha256=fKV8G_iKZVhNkNvN20D_dQEi93-8xosGbXBP-v7UEyo,2020
|
|
49
|
+
gpjax-0.13.3.dist-info/METADATA,sha256=1Au0xXIXWvS64SPLP8R_mPQ6nuUMayx2qdNLD-HdhW8,10296
|
|
50
|
+
gpjax-0.13.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
51
|
+
gpjax-0.13.3.dist-info/licenses/LICENSE.txt,sha256=3umwi0h8wmKXOZO8XwRBwSl3vJt2hpWKEqSrSXLR7-I,1084
|
|
52
|
+
gpjax-0.13.3.dist-info/RECORD,,
|
gpjax-0.13.1.dist-info/RECORD
DELETED
|
@@ -1,52 +0,0 @@
|
|
|
1
|
-
gpjax/__init__.py,sha256=asMWra4r95NSlYQbniJhCQV6pEk39ONOTvnkm-wy8OA,1641
|
|
2
|
-
gpjax/citation.py,sha256=pwFS8h1J-LE5ieRS0zDyuwhmQHNxkFHYE7iSMlVNmQc,3928
|
|
3
|
-
gpjax/dataset.py,sha256=NsToLKq4lOsHnfLfukrUIRKvhOEuoUk8aHTF0oAqRbU,4079
|
|
4
|
-
gpjax/distributions.py,sha256=iKmeQ_NN2CIjRiuOeJlwEGASzGROi4ZCerVi1uY7zRM,7758
|
|
5
|
-
gpjax/fit.py,sha256=I2sJVuKZii_d7MEcelHIivfM8ExYGMgdBuKKOT7Dw-A,15326
|
|
6
|
-
gpjax/gps.py,sha256=fSbHjfQ1vKUZ3CnqdNgJVgFsuTkCwK3yK_c0SQsWYo0,30118
|
|
7
|
-
gpjax/integrators.py,sha256=eyJPqWNPKj6pKP5da0fEj4HW7BVyevqeGrurEuy_XPw,5694
|
|
8
|
-
gpjax/likelihoods.py,sha256=xwnSQpn6Aa-FPpEoDn_3xpBdPQAmHP97jP-9iJmT4G8,9087
|
|
9
|
-
gpjax/mean_functions.py,sha256=Aq09I5h5DZe1WokRxLa0Mpj8U0kJxpQ8CmdJJyCjNOc,6541
|
|
10
|
-
gpjax/numpyro_extras.py,sha256=-vWJ7SpZVNhSdCjjrlxIkovMFrM1IzpsMJK3B4LioGE,3411
|
|
11
|
-
gpjax/objectives.py,sha256=GvKbDIPqYjsc9FpiTccmZwRdHr2lCykgfxI9BX9I_GA,15362
|
|
12
|
-
gpjax/parameters.py,sha256=hnyIKr6uIzd7Kb3KZC9WowR88ruQwUvdcto3cx2ZDv4,6756
|
|
13
|
-
gpjax/scan.py,sha256=jStQvwkE9MGttB89frxam1kaeXdWih7cVxkGywyaeHQ,5365
|
|
14
|
-
gpjax/typing.py,sha256=M3CvWsYtZ3PFUvBvvbRNjpwerNII0w4yGuP0I-sLeYI,1705
|
|
15
|
-
gpjax/variational_families.py,sha256=x4VnUh9GiW77ijnDEwrYzH0aWFOW6NLPQXEp--I9R-g,31566
|
|
16
|
-
gpjax/kernels/__init__.py,sha256=WZanH0Tpdkt0f7VfMqnalm_VZAMVwBqeOVaICNj6xQU,1901
|
|
17
|
-
gpjax/kernels/base.py,sha256=4vUV6_wnbV64EUaOecVWE9Vz4G5Dg6xUMNsQSis8A1o,11561
|
|
18
|
-
gpjax/kernels/approximations/__init__.py,sha256=bK9HlGd-PZeGrqtG5RpXxUTXNUrZTgfjH1dP626yNMA,68
|
|
19
|
-
gpjax/kernels/approximations/rff.py,sha256=WJ2Bw4RRW5bLXGKfStxph65H-ozDelvt_4mxXQm-g7w,4074
|
|
20
|
-
gpjax/kernels/computations/__init__.py,sha256=uTVkqvnZVesFLDN92h0ZR0jfR69Eo2WyjOlmSYmCPJ8,1379
|
|
21
|
-
gpjax/kernels/computations/base.py,sha256=L6K0roxZbrYeJKxEw-yaTiK9Mtcv0YtZfWI2Xnau7i8,3616
|
|
22
|
-
gpjax/kernels/computations/basis_functions.py,sha256=_SFv4Tiwne40bxr1uVYpEjjZgjIQHKseLmss2Zgl1L4,2484
|
|
23
|
-
gpjax/kernels/computations/constant_diagonal.py,sha256=JkQhLj7cK48IhOER4ivkALNhD1oQleKe-Rr9BtUJ6es,1984
|
|
24
|
-
gpjax/kernels/computations/dense.py,sha256=vnW6XKQe4_gzpXRWTctxhgMA9-9TebdtiXzAqh_-j6g,1392
|
|
25
|
-
gpjax/kernels/computations/diagonal.py,sha256=k1KqW0DwWRIBvbb7jzcKktXRfhXbcos3ncWrFplJ4W0,1768
|
|
26
|
-
gpjax/kernels/computations/eigen.py,sha256=LuwYVPK0AuDGNSccPmAS8wyDJ2ngkRbc6kNZzqiJEOg,1380
|
|
27
|
-
gpjax/kernels/non_euclidean/__init__.py,sha256=RT7puRPqCTpyxZ16q596EuOQEQi1LK1v3J9_fWz1NlY,790
|
|
28
|
-
gpjax/kernels/non_euclidean/graph.py,sha256=MukL4ZghUywkau4tVWmHsrAjzTcl9kHN6onEAvxSvjc,4109
|
|
29
|
-
gpjax/kernels/non_euclidean/utils.py,sha256=vxjRX209tOL3jKsw69hQfVtBqkASDZ__4tJOGI4oLjo,2342
|
|
30
|
-
gpjax/kernels/nonstationary/__init__.py,sha256=YpWQfOy_cqOKc5ezn37vqoK3Z6jznYiJz28BD_8F7AY,930
|
|
31
|
-
gpjax/kernels/nonstationary/arccosine.py,sha256=cqb8sqaNwW3fEbrA7MY9OF2KJFTkxHhqwmQtABE3G8w,5408
|
|
32
|
-
gpjax/kernels/nonstationary/linear.py,sha256=UIMoCq2hg6dQKr4J5UGiiPqotBleQuYfy00Ia1NaMOo,2571
|
|
33
|
-
gpjax/kernels/nonstationary/polynomial.py,sha256=CKc02C7Utgo-hhcOOCcKLdln5lj4vud_8M-JY7SevJ8,3388
|
|
34
|
-
gpjax/kernels/stationary/__init__.py,sha256=j4BMTaQlIx2kNAT1Dkf4iO2rm-f7_oSVWNrk1bN0tqE,1406
|
|
35
|
-
gpjax/kernels/stationary/base.py,sha256=YSgur73wqAFZBzt8D6CfujvyjowXoPOK1po10gjzMJo,6038
|
|
36
|
-
gpjax/kernels/stationary/matern12.py,sha256=tfrIP-pJML-kSqVwf8wdVUwC53mhFohMnf7z5OMX3xE,1771
|
|
37
|
-
gpjax/kernels/stationary/matern32.py,sha256=vdrszTtH03PB9D4AKg-YC_56xCupK3tO9L8vH0GYFcc,1964
|
|
38
|
-
gpjax/kernels/stationary/matern52.py,sha256=IVsIKDo9bTZnA98yBmA_Q31peFdSCTXjf0_cVlLQh6k,2017
|
|
39
|
-
gpjax/kernels/stationary/periodic.py,sha256=f4PhWhKg-pJsEBGzEMK9pdbylO84GPKhzHlBC83ZVWw,3528
|
|
40
|
-
gpjax/kernels/stationary/powered_exponential.py,sha256=xuFGuIK0mKNMU3iLtZMXZTHXJuMFAMoX7gAtXefCdqU,3679
|
|
41
|
-
gpjax/kernels/stationary/rational_quadratic.py,sha256=zHo2LVW65T52XET4Hx9JaKO0TfxylV8WRUtP7sUUOx0,3418
|
|
42
|
-
gpjax/kernels/stationary/rbf.py,sha256=euHUs6FdfRICQcabAWE4MX-7GEDr2TxgZWdFQiXr9Bw,1690
|
|
43
|
-
gpjax/kernels/stationary/utils.py,sha256=6BI9EBcCzeeKx-XH-MfW1ORmtU__tPX5zyvfLhpkBsU,2180
|
|
44
|
-
gpjax/kernels/stationary/white.py,sha256=TkdXXZCCjDs7JwR_gj5uvn2s1wyfRbe1vyHhUMJ8jjI,2212
|
|
45
|
-
gpjax/linalg/__init__.py,sha256=F8mxk_9Zc2nFd7Q-unjJ50_6rXEKzZj572WsU_jUKqI,547
|
|
46
|
-
gpjax/linalg/operations.py,sha256=xvhOy5P4FmUCPWjIVNdg1yDXaoFQ48anFUfR-Tnfr6k,6480
|
|
47
|
-
gpjax/linalg/operators.py,sha256=arxRGwcoAy_RqUYqBpZ3XG6OXbjShUl7m8sTpg85npE,11608
|
|
48
|
-
gpjax/linalg/utils.py,sha256=fKV8G_iKZVhNkNvN20D_dQEi93-8xosGbXBP-v7UEyo,2020
|
|
49
|
-
gpjax-0.13.1.dist-info/METADATA,sha256=qhsF8IgUu4oYTuLdqT56XQ4EKkDnUH_4D7imLxL_nPQ,10400
|
|
50
|
-
gpjax-0.13.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
51
|
-
gpjax-0.13.1.dist-info/licenses/LICENSE.txt,sha256=3umwi0h8wmKXOZO8XwRBwSl3vJt2hpWKEqSrSXLR7-I,1084
|
|
52
|
-
gpjax-0.13.1.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|