gpjax 0.9.1__py3-none-any.whl → 0.9.2__py3-none-any.whl
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
gpjax/__init__.py
CHANGED
|
@@ -40,7 +40,7 @@ __license__ = "MIT"
|
|
|
40
40
|
__description__ = "Didactic Gaussian processes in JAX"
|
|
41
41
|
__url__ = "https://github.com/JaxGaussianProcesses/GPJax"
|
|
42
42
|
__contributors__ = "https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors"
|
|
43
|
-
__version__ = "0.9.1"
|
|
43
|
+
__version__ = "0.9.2"
|
|
44
44
|
|
|
45
45
|
__all__ = [
|
|
46
46
|
"base",
|
gpjax/gps.py
CHANGED
|
@@ -17,6 +17,7 @@ from abc import abstractmethod
|
|
|
17
17
|
|
|
18
18
|
import beartype.typing as tp
|
|
19
19
|
from cola.annotations import PSD
|
|
20
|
+
from cola.linalg.algorithm_base import Algorithm
|
|
20
21
|
from cola.linalg.decompositions.decompositions import Cholesky
|
|
21
22
|
from cola.linalg.inverse.inv import solve
|
|
22
23
|
from cola.ops.operators import I_like
|
|
@@ -530,6 +531,7 @@ class ConjugatePosterior(AbstractPosterior[P, GL]):
|
|
|
530
531
|
train_data: Dataset,
|
|
531
532
|
key: KeyArray,
|
|
532
533
|
num_features: int | None = 100,
|
|
534
|
+
solver_algorithm: tp.Optional[Algorithm] = Cholesky(),
|
|
533
535
|
) -> FunctionalSample:
|
|
534
536
|
r"""Draw approximate samples from the Gaussian process posterior.
|
|
535
537
|
|
|
@@ -563,6 +565,11 @@ class ConjugatePosterior(AbstractPosterior[P, GL]):
|
|
|
563
565
|
key (KeyArray): The random seed used for the sample(s).
|
|
564
566
|
num_features (int): The number of features used when approximating the
|
|
565
567
|
kernel.
|
|
568
|
+
solver_algorithm (Optional[Algorithm], optional): The algorithm to use for the solves of
|
|
569
|
+
the inverse of the covariance matrix. See the
|
|
570
|
+
[CoLA documentation](https://cola.readthedocs.io/en/latest/package/cola.linalg.html#algorithms)
|
|
571
|
+
for which solver to pick. For PSD matrices, CoLA currently recommends Cholesky() for small
|
|
572
|
+
matrices and CG() for larger matrices. Select Auto() to let CoLA decide. Defaults to Cholesky().
|
|
566
573
|
|
|
567
574
|
Returns:
|
|
568
575
|
FunctionalSample: A function representing an approximate sample from the Gaussian
|
|
@@ -588,7 +595,7 @@ class ConjugatePosterior(AbstractPosterior[P, GL]):
|
|
|
588
595
|
canonical_weights = solve(
|
|
589
596
|
Sigma,
|
|
590
597
|
y + eps - jnp.inner(Phi, fourier_weights),
|
|
591
|
-
|
|
598
|
+
solver_algorithm,
|
|
592
599
|
) # [N, B]
|
|
593
600
|
|
|
594
601
|
def sample_fn(test_inputs: Float[Array, "n D"]) -> Float[Array, "n B"]:
|
|
@@ -1,9 +1,9 @@
|
|
|
1
|
-
gpjax/__init__.py,sha256=
|
|
1
|
+
gpjax/__init__.py,sha256=Bx5JFaveeVk3qJMTzbmrKOFy0U7fNcQ_JnVo5m0ACGA,1697
|
|
2
2
|
gpjax/citation.py,sha256=R4Pmvjt0ndA0avEDSvIbxDxKapkRRYXWX7RRWBvZCRQ,5306
|
|
3
3
|
gpjax/dataset.py,sha256=NsToLKq4lOsHnfLfukrUIRKvhOEuoUk8aHTF0oAqRbU,4079
|
|
4
4
|
gpjax/distributions.py,sha256=zxkSEZIlTg0PHvvgj0BQuIFEg-ugx6_NkEwSsbqWUM0,9325
|
|
5
5
|
gpjax/fit.py,sha256=OHv8jUHxa1ndpqMERSDRtYtUDzubk9rMPVIhfCiIH5Q,11551
|
|
6
|
-
gpjax/gps.py,sha256=
|
|
6
|
+
gpjax/gps.py,sha256=NO18geRfcjo4mA3PGkuGont_Mj_yRqfvWzJqYmoKwiY,31225
|
|
7
7
|
gpjax/integrators.py,sha256=eyJPqWNPKj6pKP5da0fEj4HW7BVyevqeGrurEuy_XPw,5694
|
|
8
8
|
gpjax/likelihoods.py,sha256=Uh4kgLTod8ODw178L--G3w4olpm9XvCdcAZ8l7FwkF4,9255
|
|
9
9
|
gpjax/lower_cholesky.py,sha256=3pnHaBrlGckFsrfYJ9Lsbd0pGmO7NIXdyY4aGm48MpY,1952
|
|
@@ -56,7 +56,7 @@ gpjax/kernels/stationary/rational_quadratic.py,sha256=dYONp3i4rnKj3ET8UyxAKXv6UO
|
|
|
56
56
|
gpjax/kernels/stationary/rbf.py,sha256=G13gg5phO7ite7D9QgoCy7gB2_y0FM6GZhgFW4RL6Xw,1734
|
|
57
57
|
gpjax/kernels/stationary/utils.py,sha256=Xa9EEnxgFqEi08ZSFAZYYHhJ85_3Ac-ZUyUk18B63M4,2225
|
|
58
58
|
gpjax/kernels/stationary/white.py,sha256=TkdXXZCCjDs7JwR_gj5uvn2s1wyfRbe1vyHhUMJ8jjI,2212
|
|
59
|
-
gpjax-0.9.
|
|
60
|
-
gpjax-0.9.
|
|
61
|
-
gpjax-0.9.
|
|
62
|
-
gpjax-0.9.
|
|
59
|
+
gpjax-0.9.2.dist-info/METADATA,sha256=JWT3cDW7onuKnTYUGqa15WxG4L7oEboJKPHYyAggYZ0,9976
|
|
60
|
+
gpjax-0.9.2.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
|
|
61
|
+
gpjax-0.9.2.dist-info/licenses/LICENSE,sha256=tAkwu8-AdEyGxGoSvJ2gVmQdcicWw3j1ZZueVV74M-E,11357
|
|
62
|
+
gpjax-0.9.2.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|