gpjax 0.9.1__py3-none-any.whl → 0.9.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
gpjax/__init__.py CHANGED
@@ -40,7 +40,7 @@ __license__ = "MIT"
40
40
  __description__ = "Didactic Gaussian processes in JAX"
41
41
  __url__ = "https://github.com/JaxGaussianProcesses/GPJax"
42
42
  __contributors__ = "https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors"
43
- __version__ = "0.9.1"
43
+ __version__ = "0.9.2"
44
44
 
45
45
  __all__ = [
46
46
  "base",
gpjax/gps.py CHANGED
@@ -17,6 +17,7 @@ from abc import abstractmethod
17
17
 
18
18
  import beartype.typing as tp
19
19
  from cola.annotations import PSD
20
+ from cola.linalg.algorithm_base import Algorithm
20
21
  from cola.linalg.decompositions.decompositions import Cholesky
21
22
  from cola.linalg.inverse.inv import solve
22
23
  from cola.ops.operators import I_like
@@ -530,6 +531,7 @@ class ConjugatePosterior(AbstractPosterior[P, GL]):
530
531
  train_data: Dataset,
531
532
  key: KeyArray,
532
533
  num_features: int | None = 100,
534
+ solver_algorithm: tp.Optional[Algorithm] = Cholesky(),
533
535
  ) -> FunctionalSample:
534
536
  r"""Draw approximate samples from the Gaussian process posterior.
535
537
 
@@ -563,6 +565,11 @@ class ConjugatePosterior(AbstractPosterior[P, GL]):
563
565
  key (KeyArray): The random seed used for the sample(s).
564
566
  num_features (int): The number of features used when approximating the
565
567
  kernel.
568
+ solver_algorithm (Optional[Algorithm], optional): The algorithm to use for the solves of
569
+ the inverse of the covariance matrix. See the
570
+ [CoLA documentation](https://cola.readthedocs.io/en/latest/package/cola.linalg.html#algorithms)
571
+ for which solver to pick. For PSD matrices, CoLA currently recommends Cholesky() for small
572
+ matrices and CG() for larger matrices. Select Auto() to let CoLA decide. Defaults to Cholesky().
566
573
 
567
574
  Returns:
568
575
  FunctionalSample: A function representing an approximate sample from the Gaussian
@@ -588,7 +595,7 @@ class ConjugatePosterior(AbstractPosterior[P, GL]):
588
595
  canonical_weights = solve(
589
596
  Sigma,
590
597
  y + eps - jnp.inner(Phi, fourier_weights),
591
- Cholesky(),
598
+ solver_algorithm,
592
599
  ) # [N, B]
593
600
 
594
601
  def sample_fn(test_inputs: Float[Array, "n D"]) -> Float[Array, "n B"]:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: gpjax
3
- Version: 0.9.1
3
+ Version: 0.9.2
4
4
  Summary: Gaussian processes in JAX.
5
5
  Project-URL: Documentation, https://docs.jaxgaussianprocesses.com/
6
6
  Project-URL: Issues, https://github.com/JaxGaussianProcesses/GPJax/issues
@@ -1,9 +1,9 @@
1
- gpjax/__init__.py,sha256=WI2T3AUoqBPcFQe8_pflRkV1k516x6ljUJsqD_7FUBY,1697
1
+ gpjax/__init__.py,sha256=Bx5JFaveeVk3qJMTzbmrKOFy0U7fNcQ_JnVo5m0ACGA,1697
2
2
  gpjax/citation.py,sha256=R4Pmvjt0ndA0avEDSvIbxDxKapkRRYXWX7RRWBvZCRQ,5306
3
3
  gpjax/dataset.py,sha256=NsToLKq4lOsHnfLfukrUIRKvhOEuoUk8aHTF0oAqRbU,4079
4
4
  gpjax/distributions.py,sha256=zxkSEZIlTg0PHvvgj0BQuIFEg-ugx6_NkEwSsbqWUM0,9325
5
5
  gpjax/fit.py,sha256=OHv8jUHxa1ndpqMERSDRtYtUDzubk9rMPVIhfCiIH5Q,11551
6
- gpjax/gps.py,sha256=6v0P6fw9CmyHfE8S3WwxCE2qECQ_kaq_SQWHkz6UIPE,30612
6
+ gpjax/gps.py,sha256=NO18geRfcjo4mA3PGkuGont_Mj_yRqfvWzJqYmoKwiY,31225
7
7
  gpjax/integrators.py,sha256=eyJPqWNPKj6pKP5da0fEj4HW7BVyevqeGrurEuy_XPw,5694
8
8
  gpjax/likelihoods.py,sha256=Uh4kgLTod8ODw178L--G3w4olpm9XvCdcAZ8l7FwkF4,9255
9
9
  gpjax/lower_cholesky.py,sha256=3pnHaBrlGckFsrfYJ9Lsbd0pGmO7NIXdyY4aGm48MpY,1952
@@ -56,7 +56,7 @@ gpjax/kernels/stationary/rational_quadratic.py,sha256=dYONp3i4rnKj3ET8UyxAKXv6UO
56
56
  gpjax/kernels/stationary/rbf.py,sha256=G13gg5phO7ite7D9QgoCy7gB2_y0FM6GZhgFW4RL6Xw,1734
57
57
  gpjax/kernels/stationary/utils.py,sha256=Xa9EEnxgFqEi08ZSFAZYYHhJ85_3Ac-ZUyUk18B63M4,2225
58
58
  gpjax/kernels/stationary/white.py,sha256=TkdXXZCCjDs7JwR_gj5uvn2s1wyfRbe1vyHhUMJ8jjI,2212
59
- gpjax-0.9.1.dist-info/METADATA,sha256=aR6Xzmn8HdUzxvE7cdKuL6lMI21kHoDQaaUbM9wkzFw,9976
60
- gpjax-0.9.1.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
61
- gpjax-0.9.1.dist-info/licenses/LICENSE,sha256=tAkwu8-AdEyGxGoSvJ2gVmQdcicWw3j1ZZueVV74M-E,11357
62
- gpjax-0.9.1.dist-info/RECORD,,
59
+ gpjax-0.9.2.dist-info/METADATA,sha256=JWT3cDW7onuKnTYUGqa15WxG4L7oEboJKPHYyAggYZ0,9976
60
+ gpjax-0.9.2.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
61
+ gpjax-0.9.2.dist-info/licenses/LICENSE,sha256=tAkwu8-AdEyGxGoSvJ2gVmQdcicWw3j1ZZueVV74M-E,11357
62
+ gpjax-0.9.2.dist-info/RECORD,,
File without changes