gpjax 0.9.1__tar.gz → 0.9.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (184) hide show
  1. {gpjax-0.9.1 → gpjax-0.9.3}/PKG-INFO +1 -1
  2. {gpjax-0.9.1 → gpjax-0.9.3}/examples/intro_to_kernels.py +1 -1
  3. {gpjax-0.9.1 → gpjax-0.9.3}/examples/poisson.py +11 -22
  4. {gpjax-0.9.1 → gpjax-0.9.3}/examples/uncollapsed_vi.py +1 -2
  5. {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/__init__.py +1 -1
  6. {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/gps.py +8 -1
  7. {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/likelihoods.py +3 -5
  8. {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/scan.py +10 -10
  9. {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/variational_families.py +9 -2
  10. {gpjax-0.9.1 → gpjax-0.9.3}/mkdocs.yml +2 -2
  11. gpjax-0.9.3/publish/gpjax-0.9.3-py3-none-any.whl +0 -0
  12. gpjax-0.9.3/publish/gpjax-0.9.3.tar.gz +0 -0
  13. {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_gps.py +13 -6
  14. gpjax-0.9.1/.github/workflows/labeler.yml +0 -18
  15. {gpjax-0.9.1 → gpjax-0.9.3}/.github/CODE_OF_CONDUCT.md +0 -0
  16. {gpjax-0.9.1 → gpjax-0.9.3}/.github/ISSUE_TEMPLATE/01_BUG_REPORT.md +0 -0
  17. {gpjax-0.9.1 → gpjax-0.9.3}/.github/ISSUE_TEMPLATE/02_FEATURE_REQUEST.md +0 -0
  18. {gpjax-0.9.1 → gpjax-0.9.3}/.github/ISSUE_TEMPLATE/03_CODEBASE_IMPROVEMENT.md +0 -0
  19. {gpjax-0.9.1 → gpjax-0.9.3}/.github/ISSUE_TEMPLATE/04_DOCS_IMPROVEMENT.md +0 -0
  20. {gpjax-0.9.1 → gpjax-0.9.3}/.github/ISSUE_TEMPLATE/config.yml +0 -0
  21. {gpjax-0.9.1 → gpjax-0.9.3}/.github/codecov.yml +0 -0
  22. {gpjax-0.9.1 → gpjax-0.9.3}/.github/labels.yml +0 -0
  23. {gpjax-0.9.1 → gpjax-0.9.3}/.github/pull_request_template.md +0 -0
  24. {gpjax-0.9.1 → gpjax-0.9.3}/.github/release-drafter.yml +0 -0
  25. {gpjax-0.9.1 → gpjax-0.9.3}/.github/workflows/build_docs.yml +0 -0
  26. {gpjax-0.9.1 → gpjax-0.9.3}/.github/workflows/integration.yml +0 -0
  27. {gpjax-0.9.1 → gpjax-0.9.3}/.github/workflows/pr_greeting.yml +0 -0
  28. {gpjax-0.9.1 → gpjax-0.9.3}/.github/workflows/ruff.yml +0 -0
  29. {gpjax-0.9.1 → gpjax-0.9.3}/.github/workflows/stale_prs.yml +0 -0
  30. {gpjax-0.9.1 → gpjax-0.9.3}/.github/workflows/test_docs.yml +0 -0
  31. {gpjax-0.9.1 → gpjax-0.9.3}/.github/workflows/tests.yml +0 -0
  32. {gpjax-0.9.1 → gpjax-0.9.3}/.gitignore +0 -0
  33. {gpjax-0.9.1 → gpjax-0.9.3}/CITATION.bib +0 -0
  34. {gpjax-0.9.1 → gpjax-0.9.3}/LICENSE +0 -0
  35. {gpjax-0.9.1 → gpjax-0.9.3}/Makefile +0 -0
  36. {gpjax-0.9.1 → gpjax-0.9.3}/README.md +0 -0
  37. {gpjax-0.9.1 → gpjax-0.9.3}/docs/CODE_OF_CONDUCT.md +0 -0
  38. {gpjax-0.9.1 → gpjax-0.9.3}/docs/GOVERNANCE.md +0 -0
  39. {gpjax-0.9.1 → gpjax-0.9.3}/docs/contributing.md +0 -0
  40. {gpjax-0.9.1 → gpjax-0.9.3}/docs/design.md +0 -0
  41. {gpjax-0.9.1 → gpjax-0.9.3}/docs/index.md +0 -0
  42. {gpjax-0.9.1 → gpjax-0.9.3}/docs/index.rst +0 -0
  43. {gpjax-0.9.1 → gpjax-0.9.3}/docs/installation.md +0 -0
  44. {gpjax-0.9.1 → gpjax-0.9.3}/docs/javascripts/katex.js +0 -0
  45. {gpjax-0.9.1 → gpjax-0.9.3}/docs/refs.bib +0 -0
  46. {gpjax-0.9.1 → gpjax-0.9.3}/docs/scripts/gen_examples.py +0 -0
  47. {gpjax-0.9.1 → gpjax-0.9.3}/docs/scripts/gen_pages.py +0 -0
  48. {gpjax-0.9.1 → gpjax-0.9.3}/docs/scripts/notebook_converter.py +0 -0
  49. {gpjax-0.9.1 → gpjax-0.9.3}/docs/scripts/sharp_bits_figure.py +0 -0
  50. {gpjax-0.9.1 → gpjax-0.9.3}/docs/sharp_bits.md +0 -0
  51. {gpjax-0.9.1 → gpjax-0.9.3}/docs/static/GP.pdf +0 -0
  52. {gpjax-0.9.1 → gpjax-0.9.3}/docs/static/GP.svg +0 -0
  53. {gpjax-0.9.1 → gpjax-0.9.3}/docs/static/bijector_figure.svg +0 -0
  54. {gpjax-0.9.1 → gpjax-0.9.3}/docs/static/css/gpjax_theme.css +0 -0
  55. {gpjax-0.9.1 → gpjax-0.9.3}/docs/static/favicon.ico +0 -0
  56. {gpjax-0.9.1 → gpjax-0.9.3}/docs/static/gpjax.mplstyle +0 -0
  57. {gpjax-0.9.1 → gpjax-0.9.3}/docs/static/gpjax_logo.pdf +0 -0
  58. {gpjax-0.9.1 → gpjax-0.9.3}/docs/static/gpjax_logo.svg +0 -0
  59. {gpjax-0.9.1 → gpjax-0.9.3}/docs/static/jaxkern/lato.ttf +0 -0
  60. {gpjax-0.9.1 → gpjax-0.9.3}/docs/static/jaxkern/logo.png +0 -0
  61. {gpjax-0.9.1 → gpjax-0.9.3}/docs/static/jaxkern/logo.svg +0 -0
  62. {gpjax-0.9.1 → gpjax-0.9.3}/docs/static/jaxkern/main.py +0 -0
  63. {gpjax-0.9.1 → gpjax-0.9.3}/docs/static/step_size_figure.png +0 -0
  64. {gpjax-0.9.1 → gpjax-0.9.3}/docs/static/step_size_figure.svg +0 -0
  65. {gpjax-0.9.1 → gpjax-0.9.3}/docs/stylesheets/extra.css +0 -0
  66. {gpjax-0.9.1 → gpjax-0.9.3}/docs/stylesheets/permalinks.css +0 -0
  67. {gpjax-0.9.1 → gpjax-0.9.3}/examples/backend.py +0 -0
  68. {gpjax-0.9.1 → gpjax-0.9.3}/examples/barycentres/barycentre_gp.gif +0 -0
  69. {gpjax-0.9.1 → gpjax-0.9.3}/examples/barycentres.py +0 -0
  70. {gpjax-0.9.1 → gpjax-0.9.3}/examples/bayesian_optimisation.py +0 -0
  71. {gpjax-0.9.1 → gpjax-0.9.3}/examples/classification.py +0 -0
  72. {gpjax-0.9.1 → gpjax-0.9.3}/examples/collapsed_vi.py +0 -0
  73. {gpjax-0.9.1 → gpjax-0.9.3}/examples/constructing_new_kernels.py +0 -0
  74. {gpjax-0.9.1 → gpjax-0.9.3}/examples/data/max_tempeature_switzerland.csv +0 -0
  75. {gpjax-0.9.1 → gpjax-0.9.3}/examples/data/yacht_hydrodynamics.data +0 -0
  76. {gpjax-0.9.1 → gpjax-0.9.3}/examples/decision_making.py +0 -0
  77. {gpjax-0.9.1 → gpjax-0.9.3}/examples/deep_kernels.py +0 -0
  78. {gpjax-0.9.1 → gpjax-0.9.3}/examples/gpjax.mplstyle +0 -0
  79. {gpjax-0.9.1 → gpjax-0.9.3}/examples/graph_kernels.py +0 -0
  80. {gpjax-0.9.1 → gpjax-0.9.3}/examples/intro_to_gps/decomposed_mll.png +0 -0
  81. {gpjax-0.9.1 → gpjax-0.9.3}/examples/intro_to_gps/generating_process.png +0 -0
  82. {gpjax-0.9.1 → gpjax-0.9.3}/examples/intro_to_gps.py +0 -0
  83. {gpjax-0.9.1 → gpjax-0.9.3}/examples/likelihoods_guide.py +0 -0
  84. {gpjax-0.9.1 → gpjax-0.9.3}/examples/oceanmodelling.py +0 -0
  85. {gpjax-0.9.1 → gpjax-0.9.3}/examples/regression.py +0 -0
  86. {gpjax-0.9.1 → gpjax-0.9.3}/examples/utils.py +0 -0
  87. {gpjax-0.9.1 → gpjax-0.9.3}/examples/yacht.py +0 -0
  88. {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/citation.py +0 -0
  89. {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/dataset.py +0 -0
  90. {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/decision_making/__init__.py +0 -0
  91. {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/decision_making/decision_maker.py +0 -0
  92. {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/decision_making/posterior_handler.py +0 -0
  93. {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/decision_making/search_space.py +0 -0
  94. {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/decision_making/test_functions/__init__.py +0 -0
  95. {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/decision_making/test_functions/continuous_functions.py +0 -0
  96. {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/decision_making/test_functions/non_conjugate_functions.py +0 -0
  97. {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/decision_making/utility_functions/__init__.py +0 -0
  98. {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/decision_making/utility_functions/base.py +0 -0
  99. {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/decision_making/utility_functions/expected_improvement.py +0 -0
  100. {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/decision_making/utility_functions/probability_of_improvement.py +0 -0
  101. {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/decision_making/utility_functions/thompson_sampling.py +0 -0
  102. {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/decision_making/utility_maximizer.py +0 -0
  103. {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/decision_making/utils.py +0 -0
  104. {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/distributions.py +0 -0
  105. {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/fit.py +0 -0
  106. {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/integrators.py +0 -0
  107. {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/kernels/__init__.py +0 -0
  108. {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/kernels/approximations/__init__.py +0 -0
  109. {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/kernels/approximations/rff.py +0 -0
  110. {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/kernels/base.py +0 -0
  111. {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/kernels/computations/__init__.py +0 -0
  112. {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/kernels/computations/base.py +0 -0
  113. {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/kernels/computations/basis_functions.py +0 -0
  114. {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/kernels/computations/constant_diagonal.py +0 -0
  115. {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/kernels/computations/dense.py +0 -0
  116. {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/kernels/computations/diagonal.py +0 -0
  117. {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/kernels/computations/eigen.py +0 -0
  118. {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/kernels/non_euclidean/__init__.py +0 -0
  119. {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/kernels/non_euclidean/graph.py +0 -0
  120. {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/kernels/non_euclidean/utils.py +0 -0
  121. {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/kernels/nonstationary/__init__.py +0 -0
  122. {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/kernels/nonstationary/arccosine.py +0 -0
  123. {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/kernels/nonstationary/linear.py +0 -0
  124. {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/kernels/nonstationary/polynomial.py +0 -0
  125. {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/kernels/stationary/__init__.py +0 -0
  126. {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/kernels/stationary/base.py +0 -0
  127. {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/kernels/stationary/matern12.py +0 -0
  128. {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/kernels/stationary/matern32.py +0 -0
  129. {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/kernels/stationary/matern52.py +0 -0
  130. {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/kernels/stationary/periodic.py +0 -0
  131. {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/kernels/stationary/powered_exponential.py +0 -0
  132. {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/kernels/stationary/rational_quadratic.py +0 -0
  133. {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/kernels/stationary/rbf.py +0 -0
  134. {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/kernels/stationary/utils.py +0 -0
  135. {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/kernels/stationary/white.py +0 -0
  136. {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/lower_cholesky.py +0 -0
  137. {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/mean_functions.py +0 -0
  138. {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/objectives.py +0 -0
  139. {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/parameters.py +0 -0
  140. {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/typing.py +0 -0
  141. {gpjax-0.9.1 → gpjax-0.9.3}/pyproject.toml +0 -0
  142. {gpjax-0.9.1 → gpjax-0.9.3}/static/CONTRIBUTING.md +0 -0
  143. {gpjax-0.9.1 → gpjax-0.9.3}/static/paper.bib +0 -0
  144. {gpjax-0.9.1 → gpjax-0.9.3}/static/paper.md +0 -0
  145. {gpjax-0.9.1 → gpjax-0.9.3}/static/paper.pdf +0 -0
  146. {gpjax-0.9.1 → gpjax-0.9.3}/tests/__init__.py +0 -0
  147. {gpjax-0.9.1 → gpjax-0.9.3}/tests/conftest.py +0 -0
  148. {gpjax-0.9.1 → gpjax-0.9.3}/tests/integration_tests.py +0 -0
  149. {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_citations.py +0 -0
  150. {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_dataset.py +0 -0
  151. {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_decision_making/__init__.py +0 -0
  152. {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_decision_making/test_decision_maker.py +0 -0
  153. {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_decision_making/test_posterior_handler.py +0 -0
  154. {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_decision_making/test_search_space.py +0 -0
  155. {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_decision_making/test_test_functions/__init__.py +0 -0
  156. {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_decision_making/test_test_functions/test_continuous_functions.py +0 -0
  157. {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_decision_making/test_test_functions/test_non_conjugate_functions.py +0 -0
  158. {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_decision_making/test_utility_functions/__init__.py +0 -0
  159. {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_decision_making/test_utility_functions/test_base.py +0 -0
  160. {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_decision_making/test_utility_functions/test_expected_improvement.py +0 -0
  161. {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_decision_making/test_utility_functions/test_probability_of_improvement.py +0 -0
  162. {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_decision_making/test_utility_functions/test_thompson_sampling.py +0 -0
  163. {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_decision_making/test_utility_functions/test_utility_functions.py +0 -0
  164. {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_decision_making/test_utility_maximizer.py +0 -0
  165. {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_decision_making/test_utils.py +0 -0
  166. {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_decision_making/utils.py +0 -0
  167. {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_fit.py +0 -0
  168. {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_gaussian_distribution.py +0 -0
  169. {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_integrators.py +0 -0
  170. {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_kernels/__init__.py +0 -0
  171. {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_kernels/test_approximations.py +0 -0
  172. {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_kernels/test_base.py +0 -0
  173. {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_kernels/test_computation.py +0 -0
  174. {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_kernels/test_non_euclidean.py +0 -0
  175. {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_kernels/test_nonstationary.py +0 -0
  176. {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_kernels/test_stationary.py +0 -0
  177. {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_kernels/test_utils.py +0 -0
  178. {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_likelihoods.py +0 -0
  179. {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_lower_cholesky.py +0 -0
  180. {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_markdown.py +0 -0
  181. {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_mean_functions.py +0 -0
  182. {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_objectives.py +0 -0
  183. {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_parameters.py +0 -0
  184. {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_variational_families.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: gpjax
3
- Version: 0.9.1
3
+ Version: 0.9.3
4
4
  Summary: Gaussian processes in JAX.
5
5
  Project-URL: Documentation, https://docs.jaxgaussianprocesses.com/
6
6
  Project-URL: Issues, https://github.com/JaxGaussianProcesses/GPJax/issues
@@ -246,7 +246,7 @@ kernel = gpx.kernels.Matern52(
246
246
  prior = gpx.gps.Prior(mean_function=mean, kernel=kernel)
247
247
 
248
248
  likelihood = gpx.likelihoods.Gaussian(
249
- num_datapoints=D.n, obs_stddev=PositiveReal(value=jnp.array(1e-3), tag="Static")
249
+ num_datapoints=D.n, obs_stdev=Static(jnp.array(1e-3))
250
250
  ) # Our function is noise-free, so we set the observation noise's standard deviation to a very small value
251
251
 
252
252
  no_opt_posterior = prior * likelihood
@@ -154,33 +154,22 @@ def logprob_fn(params):
154
154
  return gpx.objectives.log_posterior_density(model, D)
155
155
 
156
156
 
157
- # jit compile
158
- logprob_fn = jax.jit(logprob_fn)
159
- _ = logprob_fn(params)
157
+ step_size = 1e-3
158
+ inverse_mass_matrix = jnp.ones(53)
159
+ nuts = blackjax.nuts(logprob_fn, step_size, inverse_mass_matrix)
160
160
 
161
+ state = nuts.init(params)
161
162
 
162
- adapt = blackjax.window_adaptation(
163
- blackjax.nuts, logprob_fn, num_adapt, target_acceptance_rate=0.65, progress_bar=True
164
- )
165
-
166
- # Initialise the chain
167
- last_state, kernel, _ = adapt.run(key, params)
168
-
169
-
170
- def inference_loop(rng_key, kernel, initial_state, num_samples):
171
- def one_step(state, rng_key):
172
- state, info = kernel(rng_key, state)
173
- return state, (state, info)
174
-
175
- keys = jax.random.split(rng_key, num_samples)
176
- _, (states, infos) = jax.lax.scan(one_step, initial_state, keys, unroll=10)
163
+ step = jax.jit(nuts.step)
177
164
 
178
- return states, infos
179
165
 
166
+ def one_step(state, rng_key):
167
+ state, info = step(rng_key, state)
168
+ return state, (state, info)
180
169
 
181
- # Sample from the posterior distribution
182
- states, infos = inference_loop(key, kernel, last_state, num_samples)
183
170
 
171
+ keys = jax.random.split(key, num_samples)
172
+ _, (states, infos) = jax.lax.scan(one_step, state, keys, unroll=10)
184
173
 
185
174
  # %% [markdown]
186
175
  # ### Sampler efficiency
@@ -190,7 +179,7 @@ states, infos = inference_loop(key, kernel, last_state, num_samples)
190
179
  # proposed sample, divided by the total number of steps run by the chain).
191
180
 
192
181
  # %%
193
- acceptance_rate = jnp.mean(infos.acceptance_probability)
182
+ acceptance_rate = jnp.mean(infos.acceptance_rate)
194
183
  print(f"Acceptance rate: {acceptance_rate:.2f}")
195
184
 
196
185
  # %%
@@ -8,7 +8,7 @@
8
8
  # extension: .py
9
9
  # format_name: percent
10
10
  # format_version: '1.3'
11
- # jupytext_version: 1.16.4
11
+ # jupytext_version: 1.11.2
12
12
  # kernelspec:
13
13
  # display_name: gpjax_beartype
14
14
  # language: python
@@ -319,7 +319,6 @@ opt_rep, history = gpx.fit(
319
319
  model=q,
320
320
  objective=lambda p, d: -gpx.objectives.elbo(p, d),
321
321
  train_data=D,
322
- params_bijection=params_bijection,
323
322
  optim=ox.adam(learning_rate=0.01),
324
323
  num_iters=3000,
325
324
  key=jr.key(42),
@@ -40,7 +40,7 @@ __license__ = "MIT"
40
40
  __description__ = "Didactic Gaussian processes in JAX"
41
41
  __url__ = "https://github.com/JaxGaussianProcesses/GPJax"
42
42
  __contributors__ = "https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors"
43
- __version__ = "0.9.1"
43
+ __version__ = "0.9.3"
44
44
 
45
45
  __all__ = [
46
46
  "base",
@@ -17,6 +17,7 @@ from abc import abstractmethod
17
17
 
18
18
  import beartype.typing as tp
19
19
  from cola.annotations import PSD
20
+ from cola.linalg.algorithm_base import Algorithm
20
21
  from cola.linalg.decompositions.decompositions import Cholesky
21
22
  from cola.linalg.inverse.inv import solve
22
23
  from cola.ops.operators import I_like
@@ -530,6 +531,7 @@ class ConjugatePosterior(AbstractPosterior[P, GL]):
530
531
  train_data: Dataset,
531
532
  key: KeyArray,
532
533
  num_features: int | None = 100,
534
+ solver_algorithm: tp.Optional[Algorithm] = Cholesky(),
533
535
  ) -> FunctionalSample:
534
536
  r"""Draw approximate samples from the Gaussian process posterior.
535
537
 
@@ -563,6 +565,11 @@ class ConjugatePosterior(AbstractPosterior[P, GL]):
563
565
  key (KeyArray): The random seed used for the sample(s).
564
566
  num_features (int): The number of features used when approximating the
565
567
  kernel.
568
+ solver_algorithm (Optional[Algorithm], optional): The algorithm to use for the solves of
569
+ the inverse of the covariance matrix. See the
570
+ [CoLA documentation](https://cola.readthedocs.io/en/latest/package/cola.linalg.html#algorithms)
571
+ for which solver to pick. For PSD matrices, CoLA currently recommends Cholesky() for small
572
+ matrices and CG() for larger matrices. Select Auto() to let CoLA decide. Defaults to Cholesky().
566
573
 
567
574
  Returns:
568
575
  FunctionalSample: A function representing an approximate sample from the Gaussian
@@ -588,7 +595,7 @@ class ConjugatePosterior(AbstractPosterior[P, GL]):
588
595
  canonical_weights = solve(
589
596
  Sigma,
590
597
  y + eps - jnp.inner(Phi, fourier_weights),
591
- Cholesky(),
598
+ solver_algorithm,
592
599
  ) # [N, B]
593
600
 
594
601
  def sample_fn(test_inputs: Float[Array, "n D"]) -> Float[Array, "n B"]:
@@ -28,7 +28,6 @@ from gpjax.integrators import (
28
28
  GHQuadratureIntegrator,
29
29
  )
30
30
  from gpjax.parameters import (
31
- Parameter,
32
31
  PositiveReal,
33
32
  Static,
34
33
  )
@@ -152,10 +151,9 @@ class Gaussian(AbstractLikelihood):
152
151
  likelihoods. Must be an instance of `AbstractIntegrator`. For the Gaussian likelihood, this defaults to
153
152
  the `AnalyticalGaussianIntegrator`, as the expected log likelihood can be computed analytically.
154
153
  """
155
- if isinstance(obs_stddev, Parameter):
156
- self.obs_stddev = obs_stddev
157
- else:
158
- self.obs_stddev = PositiveReal(jnp.asarray(obs_stddev))
154
+ if not isinstance(obs_stddev, (PositiveReal, Static)):
155
+ obs_stddev = PositiveReal(jnp.asarray(obs_stddev))
156
+ self.obs_stddev = obs_stddev
159
157
 
160
158
  super().__init__(num_datapoints, integrator)
161
159
 
@@ -22,7 +22,6 @@ from beartype.typing import (
22
22
  )
23
23
  import jax
24
24
  from jax import lax
25
- from jax.experimental import host_callback as hcb
26
25
  import jax.numpy as jnp
27
26
  import jax.tree_util as jtu
28
27
  from jaxtyping import (
@@ -54,7 +53,8 @@ def _callback(cond: ScalarBool, func: Callable, *args: Any) -> None:
54
53
 
55
54
  def _do_callback(_) -> int:
56
55
  """Perform the callback."""
57
- return hcb.id_tap(func, *args, result=_dummy_result)
56
+ jax.debug.callback(func, *args)
57
+ return _dummy_result
58
58
 
59
59
  def _not_callback(_) -> int:
60
60
  """Do nothing."""
@@ -113,19 +113,19 @@ def vscan(
113
113
  _progress_bar = trange(_length)
114
114
  _progress_bar.set_description("Compiling...", refresh=True)
115
115
 
116
- def _set_running(args: Any, transform: Any) -> None:
116
+ def _set_running(*args: Any) -> None:
117
117
  """Set the tqdm progress bar to running."""
118
118
  _progress_bar.set_description("Running", refresh=False)
119
119
 
120
- def _update_tqdm(args: Any, transform: Any) -> None:
120
+ def _update_tqdm(*args: Any) -> None:
121
121
  """Update the tqdm progress bar with the latest objective value."""
122
122
  _value, _iter_num = args
123
- _progress_bar.update(_iter_num)
123
+ _progress_bar.update(_iter_num.item())
124
124
 
125
125
  if log_value and _value is not None:
126
126
  _progress_bar.set_postfix({"Value": f"{_value: .2f}"})
127
127
 
128
- def _close_tqdm(args: Any, transform: Any) -> None:
128
+ def _close_tqdm(*args: Any) -> None:
129
129
  """Close the tqdm progress bar."""
130
130
  _progress_bar.close()
131
131
 
@@ -145,16 +145,16 @@ def vscan(
145
145
  _is_last: bool = iter_num == _length - 1
146
146
 
147
147
  # Update progress bar, if first of log_rate.
148
- _callback(_is_first, _set_running, (y, log_rate))
148
+ _callback(_is_first, _set_running)
149
149
 
150
150
  # Update progress bar, if multiple of log_rate.
151
- _callback(_is_multiple, _update_tqdm, (y, log_rate))
151
+ _callback(_is_multiple, _update_tqdm, y, log_rate)
152
152
 
153
153
  # Update progress bar, if remainder.
154
- _callback(_is_remainder, _update_tqdm, (y, _remainder))
154
+ _callback(_is_remainder, _update_tqdm, y, _remainder)
155
155
 
156
156
  # Close progress bar, if last iteration.
157
- _callback(_is_last, _close_tqdm, (y, None))
157
+ _callback(_is_last, _close_tqdm)
158
158
 
159
159
  return carry, y
160
160
 
@@ -108,10 +108,17 @@ class AbstractVariationalGaussian(AbstractVariationalFamily[L]):
108
108
  def __init__(
109
109
  self,
110
110
  posterior: AbstractPosterior[P, L],
111
- inducing_inputs: Float[Array, "N D"],
111
+ inducing_inputs: tp.Union[
112
+ Float[Array, "N D"],
113
+ Real,
114
+ Static,
115
+ ],
112
116
  jitter: ScalarFloat = 1e-6,
113
117
  ):
114
- self.inducing_inputs = Static(inducing_inputs)
118
+ if not isinstance(inducing_inputs, (Real, Static)):
119
+ inducing_inputs = Real(inducing_inputs)
120
+
121
+ self.inducing_inputs = inducing_inputs
115
122
  self.jitter = jitter
116
123
 
117
124
  super().__init__(posterior)
@@ -24,8 +24,8 @@ nav:
24
24
  - Barycentres: _examples/barycentres.md
25
25
  - Deep kernel learning: _examples/deep_kernels.md
26
26
  - Graph kernels: _examples/graph_kernels.md
27
- - Sparse GPs: _examples/uncollapsed_vi.md
28
- - Stochastic sparse GPs: _examples/collapsed_vi.md
27
+ - Sparse GPs: _examples/collapsed_vi.md
28
+ - Stochastic sparse GPs: _examples/uncollapsed_vi.md
29
29
  - Bayesian Optimisation: _examples/bayesian_optimisation.md
30
30
  - Decision Making: _examples/decision_making.md
31
31
  - Multi-output GPs for Ocean Modelling: _examples/oceanmodelling.md
Binary file
@@ -25,13 +25,15 @@ from typing import (
25
25
  Type,
26
26
  )
27
27
 
28
+ from cola.linalg.algorithm_base import Auto
29
+ from cola.linalg.decompositions.decompositions import Cholesky
30
+ from cola.linalg.inverse.cg import CG
28
31
  from jax import config
29
32
  import jax.numpy as jnp
30
33
  import jax.random as jr
31
34
  import pytest
32
35
  import tensorflow_probability.substrates.jax.distributions as tfd
33
36
 
34
- # from gpjax.dataset import Dataset
35
37
  from gpjax.dataset import Dataset
36
38
  from gpjax.distributions import GaussianDistribution
37
39
  from gpjax.gps import (
@@ -283,7 +285,10 @@ def test_prior_sample_approx(num_datapoints, kernel, mean_function):
283
285
  @pytest.mark.parametrize("num_datapoints", [1, 5])
284
286
  @pytest.mark.parametrize("kernel", [RBF, Matern52])
285
287
  @pytest.mark.parametrize("mean_function", [Zero, Constant])
286
- def test_conjugate_posterior_sample_approx(num_datapoints, kernel, mean_function):
288
+ @pytest.mark.parametrize("solver_algorithm", [Cholesky(), CG(), Auto()])
289
+ def test_conjugate_posterior_sample_approx(
290
+ num_datapoints, kernel, mean_function, solver_algorithm
291
+ ):
287
292
  kern = kernel(lengthscale=jnp.array([5.0, 1.0]), variance=0.1)
288
293
  p = Prior(kernel=kern, mean_function=mean_function()) * Gaussian(
289
294
  num_datapoints=num_datapoints
@@ -310,26 +315,28 @@ def test_conjugate_posterior_sample_approx(num_datapoints, kernel, mean_function
310
315
  # with pytest.raises(ValidationErrors):
311
316
  # p.sample_approx(1, D, key, 0.5)
312
317
 
313
- sampled_fn = p.sample_approx(1, D, key, 100)
318
+ sampled_fn = p.sample_approx(1, D, key, 100, solver_algorithm=solver_algorithm)
314
319
  assert isinstance(sampled_fn, Callable) # check type
315
320
 
316
321
  x = jr.uniform(key=key, minval=-2.0, maxval=2.0, shape=(num_datapoints, 2))
317
322
  evals = sampled_fn(x)
318
323
  assert evals.shape == (num_datapoints, 1.0) # check shape
319
324
 
320
- sampled_fn_2 = p.sample_approx(1, D, key, 100)
325
+ sampled_fn_2 = p.sample_approx(1, D, key, 100, solver_algorithm=solver_algorithm)
321
326
  evals_2 = sampled_fn_2(x)
322
327
  max_delta = jnp.max(jnp.abs(evals - evals_2))
323
328
  assert max_delta == 0.0 # samples same for same seed
324
329
 
325
330
  new_key = jr.key(12345)
326
- sampled_fn_3 = p.sample_approx(1, D, new_key, 100)
331
+ sampled_fn_3 = p.sample_approx(
332
+ 1, D, new_key, 100, solver_algorithm=solver_algorithm
333
+ )
327
334
  evals_3 = sampled_fn_3(x)
328
335
  max_delta = jnp.max(jnp.abs(evals - evals_3))
329
336
  assert max_delta > 0.01 # samples different for different seed
330
337
 
331
338
  # Check validty of samples using Monte-Carlo
332
- sampled_fn = p.sample_approx(10_000, D, key, 100)
339
+ sampled_fn = p.sample_approx(10_000, D, key, 100, solver_algorithm=solver_algorithm)
333
340
  sampled_evals = sampled_fn(x)
334
341
  approx_mean = jnp.mean(sampled_evals, -1)
335
342
  approx_var = jnp.var(sampled_evals, -1)
@@ -1,18 +0,0 @@
1
- name: Labeler
2
-
3
- on:
4
- push:
5
- branches:
6
- - main
7
-
8
- jobs:
9
- labeler:
10
- runs-on: ubuntu-latest
11
- steps:
12
- - name: Check out the repository
13
- uses: actions/checkout@v3.5.2
14
-
15
- - name: Run Labeler
16
- uses: crazy-max/ghaction-github-labeler@v4.1.0
17
- with:
18
- skip-delete: true
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes