gpjax 0.9.1__tar.gz → 0.9.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (182)
  1. {gpjax-0.9.1 → gpjax-0.9.2}/PKG-INFO +1 -1
  2. {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/__init__.py +1 -1
  3. {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/gps.py +8 -1
  4. {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_gps.py +13 -6
  5. gpjax-0.9.1/.github/workflows/labeler.yml +0 -18
  6. {gpjax-0.9.1 → gpjax-0.9.2}/.github/CODE_OF_CONDUCT.md +0 -0
  7. {gpjax-0.9.1 → gpjax-0.9.2}/.github/ISSUE_TEMPLATE/01_BUG_REPORT.md +0 -0
  8. {gpjax-0.9.1 → gpjax-0.9.2}/.github/ISSUE_TEMPLATE/02_FEATURE_REQUEST.md +0 -0
  9. {gpjax-0.9.1 → gpjax-0.9.2}/.github/ISSUE_TEMPLATE/03_CODEBASE_IMPROVEMENT.md +0 -0
  10. {gpjax-0.9.1 → gpjax-0.9.2}/.github/ISSUE_TEMPLATE/04_DOCS_IMPROVEMENT.md +0 -0
  11. {gpjax-0.9.1 → gpjax-0.9.2}/.github/ISSUE_TEMPLATE/config.yml +0 -0
  12. {gpjax-0.9.1 → gpjax-0.9.2}/.github/codecov.yml +0 -0
  13. {gpjax-0.9.1 → gpjax-0.9.2}/.github/labels.yml +0 -0
  14. {gpjax-0.9.1 → gpjax-0.9.2}/.github/pull_request_template.md +0 -0
  15. {gpjax-0.9.1 → gpjax-0.9.2}/.github/release-drafter.yml +0 -0
  16. {gpjax-0.9.1 → gpjax-0.9.2}/.github/workflows/build_docs.yml +0 -0
  17. {gpjax-0.9.1 → gpjax-0.9.2}/.github/workflows/integration.yml +0 -0
  18. {gpjax-0.9.1 → gpjax-0.9.2}/.github/workflows/pr_greeting.yml +0 -0
  19. {gpjax-0.9.1 → gpjax-0.9.2}/.github/workflows/ruff.yml +0 -0
  20. {gpjax-0.9.1 → gpjax-0.9.2}/.github/workflows/stale_prs.yml +0 -0
  21. {gpjax-0.9.1 → gpjax-0.9.2}/.github/workflows/test_docs.yml +0 -0
  22. {gpjax-0.9.1 → gpjax-0.9.2}/.github/workflows/tests.yml +0 -0
  23. {gpjax-0.9.1 → gpjax-0.9.2}/.gitignore +0 -0
  24. {gpjax-0.9.1 → gpjax-0.9.2}/CITATION.bib +0 -0
  25. {gpjax-0.9.1 → gpjax-0.9.2}/LICENSE +0 -0
  26. {gpjax-0.9.1 → gpjax-0.9.2}/Makefile +0 -0
  27. {gpjax-0.9.1 → gpjax-0.9.2}/README.md +0 -0
  28. {gpjax-0.9.1 → gpjax-0.9.2}/docs/CODE_OF_CONDUCT.md +0 -0
  29. {gpjax-0.9.1 → gpjax-0.9.2}/docs/GOVERNANCE.md +0 -0
  30. {gpjax-0.9.1 → gpjax-0.9.2}/docs/contributing.md +0 -0
  31. {gpjax-0.9.1 → gpjax-0.9.2}/docs/design.md +0 -0
  32. {gpjax-0.9.1 → gpjax-0.9.2}/docs/index.md +0 -0
  33. {gpjax-0.9.1 → gpjax-0.9.2}/docs/index.rst +0 -0
  34. {gpjax-0.9.1 → gpjax-0.9.2}/docs/installation.md +0 -0
  35. {gpjax-0.9.1 → gpjax-0.9.2}/docs/javascripts/katex.js +0 -0
  36. {gpjax-0.9.1 → gpjax-0.9.2}/docs/refs.bib +0 -0
  37. {gpjax-0.9.1 → gpjax-0.9.2}/docs/scripts/gen_examples.py +0 -0
  38. {gpjax-0.9.1 → gpjax-0.9.2}/docs/scripts/gen_pages.py +0 -0
  39. {gpjax-0.9.1 → gpjax-0.9.2}/docs/scripts/notebook_converter.py +0 -0
  40. {gpjax-0.9.1 → gpjax-0.9.2}/docs/scripts/sharp_bits_figure.py +0 -0
  41. {gpjax-0.9.1 → gpjax-0.9.2}/docs/sharp_bits.md +0 -0
  42. {gpjax-0.9.1 → gpjax-0.9.2}/docs/static/GP.pdf +0 -0
  43. {gpjax-0.9.1 → gpjax-0.9.2}/docs/static/GP.svg +0 -0
  44. {gpjax-0.9.1 → gpjax-0.9.2}/docs/static/bijector_figure.svg +0 -0
  45. {gpjax-0.9.1 → gpjax-0.9.2}/docs/static/css/gpjax_theme.css +0 -0
  46. {gpjax-0.9.1 → gpjax-0.9.2}/docs/static/favicon.ico +0 -0
  47. {gpjax-0.9.1 → gpjax-0.9.2}/docs/static/gpjax.mplstyle +0 -0
  48. {gpjax-0.9.1 → gpjax-0.9.2}/docs/static/gpjax_logo.pdf +0 -0
  49. {gpjax-0.9.1 → gpjax-0.9.2}/docs/static/gpjax_logo.svg +0 -0
  50. {gpjax-0.9.1 → gpjax-0.9.2}/docs/static/jaxkern/lato.ttf +0 -0
  51. {gpjax-0.9.1 → gpjax-0.9.2}/docs/static/jaxkern/logo.png +0 -0
  52. {gpjax-0.9.1 → gpjax-0.9.2}/docs/static/jaxkern/logo.svg +0 -0
  53. {gpjax-0.9.1 → gpjax-0.9.2}/docs/static/jaxkern/main.py +0 -0
  54. {gpjax-0.9.1 → gpjax-0.9.2}/docs/static/step_size_figure.png +0 -0
  55. {gpjax-0.9.1 → gpjax-0.9.2}/docs/static/step_size_figure.svg +0 -0
  56. {gpjax-0.9.1 → gpjax-0.9.2}/docs/stylesheets/extra.css +0 -0
  57. {gpjax-0.9.1 → gpjax-0.9.2}/docs/stylesheets/permalinks.css +0 -0
  58. {gpjax-0.9.1 → gpjax-0.9.2}/examples/backend.py +0 -0
  59. {gpjax-0.9.1 → gpjax-0.9.2}/examples/barycentres/barycentre_gp.gif +0 -0
  60. {gpjax-0.9.1 → gpjax-0.9.2}/examples/barycentres.py +0 -0
  61. {gpjax-0.9.1 → gpjax-0.9.2}/examples/bayesian_optimisation.py +0 -0
  62. {gpjax-0.9.1 → gpjax-0.9.2}/examples/classification.py +0 -0
  63. {gpjax-0.9.1 → gpjax-0.9.2}/examples/collapsed_vi.py +0 -0
  64. {gpjax-0.9.1 → gpjax-0.9.2}/examples/constructing_new_kernels.py +0 -0
  65. {gpjax-0.9.1 → gpjax-0.9.2}/examples/data/max_tempeature_switzerland.csv +0 -0
  66. {gpjax-0.9.1 → gpjax-0.9.2}/examples/data/yacht_hydrodynamics.data +0 -0
  67. {gpjax-0.9.1 → gpjax-0.9.2}/examples/decision_making.py +0 -0
  68. {gpjax-0.9.1 → gpjax-0.9.2}/examples/deep_kernels.py +0 -0
  69. {gpjax-0.9.1 → gpjax-0.9.2}/examples/gpjax.mplstyle +0 -0
  70. {gpjax-0.9.1 → gpjax-0.9.2}/examples/graph_kernels.py +0 -0
  71. {gpjax-0.9.1 → gpjax-0.9.2}/examples/intro_to_gps/decomposed_mll.png +0 -0
  72. {gpjax-0.9.1 → gpjax-0.9.2}/examples/intro_to_gps/generating_process.png +0 -0
  73. {gpjax-0.9.1 → gpjax-0.9.2}/examples/intro_to_gps.py +0 -0
  74. {gpjax-0.9.1 → gpjax-0.9.2}/examples/intro_to_kernels.py +0 -0
  75. {gpjax-0.9.1 → gpjax-0.9.2}/examples/likelihoods_guide.py +0 -0
  76. {gpjax-0.9.1 → gpjax-0.9.2}/examples/oceanmodelling.py +0 -0
  77. {gpjax-0.9.1 → gpjax-0.9.2}/examples/poisson.py +0 -0
  78. {gpjax-0.9.1 → gpjax-0.9.2}/examples/regression.py +0 -0
  79. {gpjax-0.9.1 → gpjax-0.9.2}/examples/uncollapsed_vi.py +0 -0
  80. {gpjax-0.9.1 → gpjax-0.9.2}/examples/utils.py +0 -0
  81. {gpjax-0.9.1 → gpjax-0.9.2}/examples/yacht.py +0 -0
  82. {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/citation.py +0 -0
  83. {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/dataset.py +0 -0
  84. {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/decision_making/__init__.py +0 -0
  85. {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/decision_making/decision_maker.py +0 -0
  86. {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/decision_making/posterior_handler.py +0 -0
  87. {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/decision_making/search_space.py +0 -0
  88. {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/decision_making/test_functions/__init__.py +0 -0
  89. {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/decision_making/test_functions/continuous_functions.py +0 -0
  90. {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/decision_making/test_functions/non_conjugate_functions.py +0 -0
  91. {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/decision_making/utility_functions/__init__.py +0 -0
  92. {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/decision_making/utility_functions/base.py +0 -0
  93. {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/decision_making/utility_functions/expected_improvement.py +0 -0
  94. {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/decision_making/utility_functions/probability_of_improvement.py +0 -0
  95. {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/decision_making/utility_functions/thompson_sampling.py +0 -0
  96. {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/decision_making/utility_maximizer.py +0 -0
  97. {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/decision_making/utils.py +0 -0
  98. {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/distributions.py +0 -0
  99. {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/fit.py +0 -0
  100. {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/integrators.py +0 -0
  101. {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/kernels/__init__.py +0 -0
  102. {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/kernels/approximations/__init__.py +0 -0
  103. {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/kernels/approximations/rff.py +0 -0
  104. {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/kernels/base.py +0 -0
  105. {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/kernels/computations/__init__.py +0 -0
  106. {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/kernels/computations/base.py +0 -0
  107. {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/kernels/computations/basis_functions.py +0 -0
  108. {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/kernels/computations/constant_diagonal.py +0 -0
  109. {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/kernels/computations/dense.py +0 -0
  110. {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/kernels/computations/diagonal.py +0 -0
  111. {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/kernels/computations/eigen.py +0 -0
  112. {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/kernels/non_euclidean/__init__.py +0 -0
  113. {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/kernels/non_euclidean/graph.py +0 -0
  114. {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/kernels/non_euclidean/utils.py +0 -0
  115. {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/kernels/nonstationary/__init__.py +0 -0
  116. {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/kernels/nonstationary/arccosine.py +0 -0
  117. {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/kernels/nonstationary/linear.py +0 -0
  118. {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/kernels/nonstationary/polynomial.py +0 -0
  119. {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/kernels/stationary/__init__.py +0 -0
  120. {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/kernels/stationary/base.py +0 -0
  121. {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/kernels/stationary/matern12.py +0 -0
  122. {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/kernels/stationary/matern32.py +0 -0
  123. {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/kernels/stationary/matern52.py +0 -0
  124. {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/kernels/stationary/periodic.py +0 -0
  125. {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/kernels/stationary/powered_exponential.py +0 -0
  126. {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/kernels/stationary/rational_quadratic.py +0 -0
  127. {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/kernels/stationary/rbf.py +0 -0
  128. {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/kernels/stationary/utils.py +0 -0
  129. {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/kernels/stationary/white.py +0 -0
  130. {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/likelihoods.py +0 -0
  131. {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/lower_cholesky.py +0 -0
  132. {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/mean_functions.py +0 -0
  133. {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/objectives.py +0 -0
  134. {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/parameters.py +0 -0
  135. {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/scan.py +0 -0
  136. {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/typing.py +0 -0
  137. {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/variational_families.py +0 -0
  138. {gpjax-0.9.1 → gpjax-0.9.2}/mkdocs.yml +0 -0
  139. {gpjax-0.9.1 → gpjax-0.9.2}/pyproject.toml +0 -0
  140. {gpjax-0.9.1 → gpjax-0.9.2}/static/CONTRIBUTING.md +0 -0
  141. {gpjax-0.9.1 → gpjax-0.9.2}/static/paper.bib +0 -0
  142. {gpjax-0.9.1 → gpjax-0.9.2}/static/paper.md +0 -0
  143. {gpjax-0.9.1 → gpjax-0.9.2}/static/paper.pdf +0 -0
  144. {gpjax-0.9.1 → gpjax-0.9.2}/tests/__init__.py +0 -0
  145. {gpjax-0.9.1 → gpjax-0.9.2}/tests/conftest.py +0 -0
  146. {gpjax-0.9.1 → gpjax-0.9.2}/tests/integration_tests.py +0 -0
  147. {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_citations.py +0 -0
  148. {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_dataset.py +0 -0
  149. {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_decision_making/__init__.py +0 -0
  150. {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_decision_making/test_decision_maker.py +0 -0
  151. {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_decision_making/test_posterior_handler.py +0 -0
  152. {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_decision_making/test_search_space.py +0 -0
  153. {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_decision_making/test_test_functions/__init__.py +0 -0
  154. {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_decision_making/test_test_functions/test_continuous_functions.py +0 -0
  155. {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_decision_making/test_test_functions/test_non_conjugate_functions.py +0 -0
  156. {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_decision_making/test_utility_functions/__init__.py +0 -0
  157. {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_decision_making/test_utility_functions/test_base.py +0 -0
  158. {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_decision_making/test_utility_functions/test_expected_improvement.py +0 -0
  159. {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_decision_making/test_utility_functions/test_probability_of_improvement.py +0 -0
  160. {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_decision_making/test_utility_functions/test_thompson_sampling.py +0 -0
  161. {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_decision_making/test_utility_functions/test_utility_functions.py +0 -0
  162. {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_decision_making/test_utility_maximizer.py +0 -0
  163. {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_decision_making/test_utils.py +0 -0
  164. {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_decision_making/utils.py +0 -0
  165. {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_fit.py +0 -0
  166. {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_gaussian_distribution.py +0 -0
  167. {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_integrators.py +0 -0
  168. {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_kernels/__init__.py +0 -0
  169. {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_kernels/test_approximations.py +0 -0
  170. {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_kernels/test_base.py +0 -0
  171. {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_kernels/test_computation.py +0 -0
  172. {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_kernels/test_non_euclidean.py +0 -0
  173. {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_kernels/test_nonstationary.py +0 -0
  174. {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_kernels/test_stationary.py +0 -0
  175. {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_kernels/test_utils.py +0 -0
  176. {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_likelihoods.py +0 -0
  177. {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_lower_cholesky.py +0 -0
  178. {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_markdown.py +0 -0
  179. {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_mean_functions.py +0 -0
  180. {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_objectives.py +0 -0
  181. {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_parameters.py +0 -0
  182. {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_variational_families.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: gpjax
3
- Version: 0.9.1
3
+ Version: 0.9.2
4
4
  Summary: Gaussian processes in JAX.
5
5
  Project-URL: Documentation, https://docs.jaxgaussianprocesses.com/
6
6
  Project-URL: Issues, https://github.com/JaxGaussianProcesses/GPJax/issues
@@ -40,7 +40,7 @@ __license__ = "MIT"
40
40
  __description__ = "Didactic Gaussian processes in JAX"
41
41
  __url__ = "https://github.com/JaxGaussianProcesses/GPJax"
42
42
  __contributors__ = "https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors"
43
- __version__ = "0.9.1"
43
+ __version__ = "0.9.2"
44
44
 
45
45
  __all__ = [
46
46
  "base",
@@ -17,6 +17,7 @@ from abc import abstractmethod
17
17
 
18
18
  import beartype.typing as tp
19
19
  from cola.annotations import PSD
20
+ from cola.linalg.algorithm_base import Algorithm
20
21
  from cola.linalg.decompositions.decompositions import Cholesky
21
22
  from cola.linalg.inverse.inv import solve
22
23
  from cola.ops.operators import I_like
@@ -530,6 +531,7 @@ class ConjugatePosterior(AbstractPosterior[P, GL]):
530
531
  train_data: Dataset,
531
532
  key: KeyArray,
532
533
  num_features: int | None = 100,
534
+ solver_algorithm: tp.Optional[Algorithm] = Cholesky(),
533
535
  ) -> FunctionalSample:
534
536
  r"""Draw approximate samples from the Gaussian process posterior.
535
537
 
@@ -563,6 +565,11 @@ class ConjugatePosterior(AbstractPosterior[P, GL]):
563
565
  key (KeyArray): The random seed used for the sample(s).
564
566
  num_features (int): The number of features used when approximating the
565
567
  kernel.
568
+ solver_algorithm (Optional[Algorithm], optional): The algorithm to use for the solves of
569
+ the inverse of the covariance matrix. See the
570
+ [CoLA documentation](https://cola.readthedocs.io/en/latest/package/cola.linalg.html#algorithms)
571
+ for which solver to pick. For PSD matrices, CoLA currently recommends Cholesky() for small
572
+ matrices and CG() for larger matrices. Select Auto() to let CoLA decide. Defaults to Cholesky().
566
573
 
567
574
  Returns:
568
575
  FunctionalSample: A function representing an approximate sample from the Gaussian
@@ -588,7 +595,7 @@ class ConjugatePosterior(AbstractPosterior[P, GL]):
588
595
  canonical_weights = solve(
589
596
  Sigma,
590
597
  y + eps - jnp.inner(Phi, fourier_weights),
591
- Cholesky(),
598
+ solver_algorithm,
592
599
  ) # [N, B]
593
600
 
594
601
  def sample_fn(test_inputs: Float[Array, "n D"]) -> Float[Array, "n B"]:
@@ -25,13 +25,15 @@ from typing import (
25
25
  Type,
26
26
  )
27
27
 
28
+ from cola.linalg.algorithm_base import Auto
29
+ from cola.linalg.decompositions.decompositions import Cholesky
30
+ from cola.linalg.inverse.cg import CG
28
31
  from jax import config
29
32
  import jax.numpy as jnp
30
33
  import jax.random as jr
31
34
  import pytest
32
35
  import tensorflow_probability.substrates.jax.distributions as tfd
33
36
 
34
- # from gpjax.dataset import Dataset
35
37
  from gpjax.dataset import Dataset
36
38
  from gpjax.distributions import GaussianDistribution
37
39
  from gpjax.gps import (
@@ -283,7 +285,10 @@ def test_prior_sample_approx(num_datapoints, kernel, mean_function):
283
285
  @pytest.mark.parametrize("num_datapoints", [1, 5])
284
286
  @pytest.mark.parametrize("kernel", [RBF, Matern52])
285
287
  @pytest.mark.parametrize("mean_function", [Zero, Constant])
286
- def test_conjugate_posterior_sample_approx(num_datapoints, kernel, mean_function):
288
+ @pytest.mark.parametrize("solver_algorithm", [Cholesky(), CG(), Auto()])
289
+ def test_conjugate_posterior_sample_approx(
290
+ num_datapoints, kernel, mean_function, solver_algorithm
291
+ ):
287
292
  kern = kernel(lengthscale=jnp.array([5.0, 1.0]), variance=0.1)
288
293
  p = Prior(kernel=kern, mean_function=mean_function()) * Gaussian(
289
294
  num_datapoints=num_datapoints
@@ -310,26 +315,28 @@ def test_conjugate_posterior_sample_approx(num_datapoints, kernel, mean_function
310
315
  # with pytest.raises(ValidationErrors):
311
316
  # p.sample_approx(1, D, key, 0.5)
312
317
 
313
- sampled_fn = p.sample_approx(1, D, key, 100)
318
+ sampled_fn = p.sample_approx(1, D, key, 100, solver_algorithm=solver_algorithm)
314
319
  assert isinstance(sampled_fn, Callable) # check type
315
320
 
316
321
  x = jr.uniform(key=key, minval=-2.0, maxval=2.0, shape=(num_datapoints, 2))
317
322
  evals = sampled_fn(x)
318
323
  assert evals.shape == (num_datapoints, 1.0) # check shape
319
324
 
320
- sampled_fn_2 = p.sample_approx(1, D, key, 100)
325
+ sampled_fn_2 = p.sample_approx(1, D, key, 100, solver_algorithm=solver_algorithm)
321
326
  evals_2 = sampled_fn_2(x)
322
327
  max_delta = jnp.max(jnp.abs(evals - evals_2))
323
328
  assert max_delta == 0.0 # samples same for same seed
324
329
 
325
330
  new_key = jr.key(12345)
326
- sampled_fn_3 = p.sample_approx(1, D, new_key, 100)
331
+ sampled_fn_3 = p.sample_approx(
332
+ 1, D, new_key, 100, solver_algorithm=solver_algorithm
333
+ )
327
334
  evals_3 = sampled_fn_3(x)
328
335
  max_delta = jnp.max(jnp.abs(evals - evals_3))
329
336
  assert max_delta > 0.01 # samples different for different seed
330
337
 
331
338
  # Check validty of samples using Monte-Carlo
332
- sampled_fn = p.sample_approx(10_000, D, key, 100)
339
+ sampled_fn = p.sample_approx(10_000, D, key, 100, solver_algorithm=solver_algorithm)
333
340
  sampled_evals = sampled_fn(x)
334
341
  approx_mean = jnp.mean(sampled_evals, -1)
335
342
  approx_var = jnp.var(sampled_evals, -1)
@@ -1,18 +0,0 @@
1
- name: Labeler
2
-
3
- on:
4
- push:
5
- branches:
6
- - main
7
-
8
- jobs:
9
- labeler:
10
- runs-on: ubuntu-latest
11
- steps:
12
- - name: Check out the repository
13
- uses: actions/checkout@v3.5.2
14
-
15
- - name: Run Labeler
16
- uses: crazy-max/ghaction-github-labeler@v4.1.0
17
- with:
18
- skip-delete: true
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes