gpjax 0.9.2.tar.gz → 0.9.4.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181)
  1. {gpjax-0.9.2 → gpjax-0.9.4}/PKG-INFO +18 -18
  2. {gpjax-0.9.2 → gpjax-0.9.4}/README.md +15 -15
  3. {gpjax-0.9.2 → gpjax-0.9.4}/docs/index.md +2 -3
  4. {gpjax-0.9.2 → gpjax-0.9.4}/examples/backend.py +4 -4
  5. {gpjax-0.9.2 → gpjax-0.9.4}/examples/barycentres.py +3 -3
  6. {gpjax-0.9.2 → gpjax-0.9.4}/examples/bayesian_optimisation.py +3 -3
  7. {gpjax-0.9.2 → gpjax-0.9.4}/examples/classification.py +6 -1
  8. {gpjax-0.9.2 → gpjax-0.9.4}/examples/collapsed_vi.py +2 -2
  9. {gpjax-0.9.2 → gpjax-0.9.4}/examples/constructing_new_kernels.py +12 -6
  10. {gpjax-0.9.2 → gpjax-0.9.4}/examples/decision_making.py +5 -5
  11. {gpjax-0.9.2 → gpjax-0.9.4}/examples/deep_kernels.py +2 -2
  12. {gpjax-0.9.2 → gpjax-0.9.4}/examples/graph_kernels.py +5 -3
  13. {gpjax-0.9.2 → gpjax-0.9.4}/examples/intro_to_gps.py +38 -12
  14. {gpjax-0.9.2 → gpjax-0.9.4}/examples/intro_to_kernels.py +42 -21
  15. {gpjax-0.9.2 → gpjax-0.9.4}/examples/likelihoods_guide.py +5 -3
  16. {gpjax-0.9.2 → gpjax-0.9.4}/examples/oceanmodelling.py +6 -4
  17. {gpjax-0.9.2 → gpjax-0.9.4}/examples/poisson.py +12 -23
  18. {gpjax-0.9.2 → gpjax-0.9.4}/examples/regression.py +1 -1
  19. {gpjax-0.9.2 → gpjax-0.9.4}/examples/uncollapsed_vi.py +3 -4
  20. {gpjax-0.9.2 → gpjax-0.9.4}/examples/yacht.py +5 -5
  21. {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/__init__.py +1 -1
  22. {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/decision_making/test_functions/non_conjugate_functions.py +2 -2
  23. {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/gps.py +2 -1
  24. {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/likelihoods.py +3 -5
  25. {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/scan.py +10 -10
  26. {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/variational_families.py +33 -21
  27. {gpjax-0.9.2 → gpjax-0.9.4}/mkdocs.yml +2 -2
  28. {gpjax-0.9.2 → gpjax-0.9.4}/pyproject.toml +2 -2
  29. {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_decision_making/test_utility_functions/test_expected_improvement.py +1 -1
  30. {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_fit.py +2 -2
  31. {gpjax-0.9.2 → gpjax-0.9.4}/.github/CODE_OF_CONDUCT.md +0 -0
  32. {gpjax-0.9.2 → gpjax-0.9.4}/.github/ISSUE_TEMPLATE/01_BUG_REPORT.md +0 -0
  33. {gpjax-0.9.2 → gpjax-0.9.4}/.github/ISSUE_TEMPLATE/02_FEATURE_REQUEST.md +0 -0
  34. {gpjax-0.9.2 → gpjax-0.9.4}/.github/ISSUE_TEMPLATE/03_CODEBASE_IMPROVEMENT.md +0 -0
  35. {gpjax-0.9.2 → gpjax-0.9.4}/.github/ISSUE_TEMPLATE/04_DOCS_IMPROVEMENT.md +0 -0
  36. {gpjax-0.9.2 → gpjax-0.9.4}/.github/ISSUE_TEMPLATE/config.yml +0 -0
  37. {gpjax-0.9.2 → gpjax-0.9.4}/.github/codecov.yml +0 -0
  38. {gpjax-0.9.2 → gpjax-0.9.4}/.github/labels.yml +0 -0
  39. {gpjax-0.9.2 → gpjax-0.9.4}/.github/pull_request_template.md +0 -0
  40. {gpjax-0.9.2 → gpjax-0.9.4}/.github/release-drafter.yml +0 -0
  41. {gpjax-0.9.2 → gpjax-0.9.4}/.github/workflows/build_docs.yml +0 -0
  42. {gpjax-0.9.2 → gpjax-0.9.4}/.github/workflows/integration.yml +0 -0
  43. {gpjax-0.9.2 → gpjax-0.9.4}/.github/workflows/pr_greeting.yml +0 -0
  44. {gpjax-0.9.2 → gpjax-0.9.4}/.github/workflows/ruff.yml +0 -0
  45. {gpjax-0.9.2 → gpjax-0.9.4}/.github/workflows/stale_prs.yml +0 -0
  46. {gpjax-0.9.2 → gpjax-0.9.4}/.github/workflows/test_docs.yml +0 -0
  47. {gpjax-0.9.2 → gpjax-0.9.4}/.github/workflows/tests.yml +0 -0
  48. {gpjax-0.9.2 → gpjax-0.9.4}/.gitignore +0 -0
  49. {gpjax-0.9.2 → gpjax-0.9.4}/CITATION.bib +0 -0
  50. {gpjax-0.9.2 → gpjax-0.9.4}/LICENSE +0 -0
  51. {gpjax-0.9.2 → gpjax-0.9.4}/Makefile +0 -0
  52. {gpjax-0.9.2 → gpjax-0.9.4}/docs/CODE_OF_CONDUCT.md +0 -0
  53. {gpjax-0.9.2 → gpjax-0.9.4}/docs/GOVERNANCE.md +0 -0
  54. {gpjax-0.9.2 → gpjax-0.9.4}/docs/contributing.md +0 -0
  55. {gpjax-0.9.2 → gpjax-0.9.4}/docs/design.md +0 -0
  56. {gpjax-0.9.2 → gpjax-0.9.4}/docs/index.rst +0 -0
  57. {gpjax-0.9.2 → gpjax-0.9.4}/docs/installation.md +0 -0
  58. {gpjax-0.9.2 → gpjax-0.9.4}/docs/javascripts/katex.js +0 -0
  59. {gpjax-0.9.2 → gpjax-0.9.4}/docs/refs.bib +0 -0
  60. {gpjax-0.9.2 → gpjax-0.9.4}/docs/scripts/gen_examples.py +0 -0
  61. {gpjax-0.9.2 → gpjax-0.9.4}/docs/scripts/gen_pages.py +0 -0
  62. {gpjax-0.9.2 → gpjax-0.9.4}/docs/scripts/notebook_converter.py +0 -0
  63. {gpjax-0.9.2 → gpjax-0.9.4}/docs/scripts/sharp_bits_figure.py +0 -0
  64. {gpjax-0.9.2 → gpjax-0.9.4}/docs/sharp_bits.md +0 -0
  65. {gpjax-0.9.2 → gpjax-0.9.4}/docs/static/GP.pdf +0 -0
  66. {gpjax-0.9.2 → gpjax-0.9.4}/docs/static/GP.svg +0 -0
  67. {gpjax-0.9.2 → gpjax-0.9.4}/docs/static/bijector_figure.svg +0 -0
  68. {gpjax-0.9.2 → gpjax-0.9.4}/docs/static/css/gpjax_theme.css +0 -0
  69. {gpjax-0.9.2 → gpjax-0.9.4}/docs/static/favicon.ico +0 -0
  70. {gpjax-0.9.2 → gpjax-0.9.4}/docs/static/gpjax.mplstyle +0 -0
  71. {gpjax-0.9.2 → gpjax-0.9.4}/docs/static/gpjax_logo.pdf +0 -0
  72. {gpjax-0.9.2 → gpjax-0.9.4}/docs/static/gpjax_logo.svg +0 -0
  73. {gpjax-0.9.2 → gpjax-0.9.4}/docs/static/jaxkern/lato.ttf +0 -0
  74. {gpjax-0.9.2 → gpjax-0.9.4}/docs/static/jaxkern/logo.png +0 -0
  75. {gpjax-0.9.2 → gpjax-0.9.4}/docs/static/jaxkern/logo.svg +0 -0
  76. {gpjax-0.9.2 → gpjax-0.9.4}/docs/static/jaxkern/main.py +0 -0
  77. {gpjax-0.9.2 → gpjax-0.9.4}/docs/static/step_size_figure.png +0 -0
  78. {gpjax-0.9.2 → gpjax-0.9.4}/docs/static/step_size_figure.svg +0 -0
  79. {gpjax-0.9.2 → gpjax-0.9.4}/docs/stylesheets/extra.css +0 -0
  80. {gpjax-0.9.2 → gpjax-0.9.4}/docs/stylesheets/permalinks.css +0 -0
  81. {gpjax-0.9.2 → gpjax-0.9.4}/examples/barycentres/barycentre_gp.gif +0 -0
  82. {gpjax-0.9.2 → gpjax-0.9.4}/examples/data/max_tempeature_switzerland.csv +0 -0
  83. {gpjax-0.9.2 → gpjax-0.9.4}/examples/data/yacht_hydrodynamics.data +0 -0
  84. {gpjax-0.9.2 → gpjax-0.9.4}/examples/gpjax.mplstyle +0 -0
  85. {gpjax-0.9.2 → gpjax-0.9.4}/examples/intro_to_gps/decomposed_mll.png +0 -0
  86. {gpjax-0.9.2 → gpjax-0.9.4}/examples/intro_to_gps/generating_process.png +0 -0
  87. {gpjax-0.9.2 → gpjax-0.9.4}/examples/utils.py +0 -0
  88. {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/citation.py +0 -0
  89. {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/dataset.py +0 -0
  90. {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/decision_making/__init__.py +0 -0
  91. {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/decision_making/decision_maker.py +0 -0
  92. {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/decision_making/posterior_handler.py +0 -0
  93. {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/decision_making/search_space.py +0 -0
  94. {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/decision_making/test_functions/__init__.py +0 -0
  95. {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/decision_making/test_functions/continuous_functions.py +0 -0
  96. {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/decision_making/utility_functions/__init__.py +0 -0
  97. {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/decision_making/utility_functions/base.py +0 -0
  98. {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/decision_making/utility_functions/expected_improvement.py +0 -0
  99. {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/decision_making/utility_functions/probability_of_improvement.py +0 -0
  100. {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/decision_making/utility_functions/thompson_sampling.py +0 -0
  101. {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/decision_making/utility_maximizer.py +0 -0
  102. {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/decision_making/utils.py +0 -0
  103. {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/distributions.py +0 -0
  104. {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/fit.py +0 -0
  105. {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/integrators.py +0 -0
  106. {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/kernels/__init__.py +0 -0
  107. {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/kernels/approximations/__init__.py +0 -0
  108. {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/kernels/approximations/rff.py +0 -0
  109. {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/kernels/base.py +0 -0
  110. {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/kernels/computations/__init__.py +0 -0
  111. {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/kernels/computations/base.py +0 -0
  112. {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/kernels/computations/basis_functions.py +0 -0
  113. {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/kernels/computations/constant_diagonal.py +0 -0
  114. {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/kernels/computations/dense.py +0 -0
  115. {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/kernels/computations/diagonal.py +0 -0
  116. {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/kernels/computations/eigen.py +0 -0
  117. {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/kernels/non_euclidean/__init__.py +0 -0
  118. {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/kernels/non_euclidean/graph.py +0 -0
  119. {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/kernels/non_euclidean/utils.py +0 -0
  120. {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/kernels/nonstationary/__init__.py +0 -0
  121. {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/kernels/nonstationary/arccosine.py +0 -0
  122. {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/kernels/nonstationary/linear.py +0 -0
  123. {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/kernels/nonstationary/polynomial.py +0 -0
  124. {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/kernels/stationary/__init__.py +0 -0
  125. {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/kernels/stationary/base.py +0 -0
  126. {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/kernels/stationary/matern12.py +0 -0
  127. {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/kernels/stationary/matern32.py +0 -0
  128. {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/kernels/stationary/matern52.py +0 -0
  129. {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/kernels/stationary/periodic.py +0 -0
  130. {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/kernels/stationary/powered_exponential.py +0 -0
  131. {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/kernels/stationary/rational_quadratic.py +0 -0
  132. {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/kernels/stationary/rbf.py +0 -0
  133. {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/kernels/stationary/utils.py +0 -0
  134. {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/kernels/stationary/white.py +0 -0
  135. {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/lower_cholesky.py +0 -0
  136. {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/mean_functions.py +0 -0
  137. {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/objectives.py +0 -0
  138. {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/parameters.py +0 -0
  139. {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/typing.py +0 -0
  140. {gpjax-0.9.2 → gpjax-0.9.4}/static/CONTRIBUTING.md +0 -0
  141. {gpjax-0.9.2 → gpjax-0.9.4}/static/paper.bib +0 -0
  142. {gpjax-0.9.2 → gpjax-0.9.4}/static/paper.md +0 -0
  143. {gpjax-0.9.2 → gpjax-0.9.4}/static/paper.pdf +0 -0
  144. {gpjax-0.9.2 → gpjax-0.9.4}/tests/__init__.py +0 -0
  145. {gpjax-0.9.2 → gpjax-0.9.4}/tests/conftest.py +0 -0
  146. {gpjax-0.9.2 → gpjax-0.9.4}/tests/integration_tests.py +0 -0
  147. {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_citations.py +0 -0
  148. {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_dataset.py +0 -0
  149. {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_decision_making/__init__.py +0 -0
  150. {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_decision_making/test_decision_maker.py +0 -0
  151. {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_decision_making/test_posterior_handler.py +0 -0
  152. {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_decision_making/test_search_space.py +0 -0
  153. {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_decision_making/test_test_functions/__init__.py +0 -0
  154. {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_decision_making/test_test_functions/test_continuous_functions.py +0 -0
  155. {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_decision_making/test_test_functions/test_non_conjugate_functions.py +0 -0
  156. {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_decision_making/test_utility_functions/__init__.py +0 -0
  157. {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_decision_making/test_utility_functions/test_base.py +0 -0
  158. {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_decision_making/test_utility_functions/test_probability_of_improvement.py +0 -0
  159. {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_decision_making/test_utility_functions/test_thompson_sampling.py +0 -0
  160. {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_decision_making/test_utility_functions/test_utility_functions.py +0 -0
  161. {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_decision_making/test_utility_maximizer.py +0 -0
  162. {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_decision_making/test_utils.py +0 -0
  163. {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_decision_making/utils.py +0 -0
  164. {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_gaussian_distribution.py +0 -0
  165. {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_gps.py +0 -0
  166. {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_integrators.py +0 -0
  167. {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_kernels/__init__.py +0 -0
  168. {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_kernels/test_approximations.py +0 -0
  169. {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_kernels/test_base.py +0 -0
  170. {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_kernels/test_computation.py +0 -0
  171. {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_kernels/test_non_euclidean.py +0 -0
  172. {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_kernels/test_nonstationary.py +0 -0
  173. {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_kernels/test_stationary.py +0 -0
  174. {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_kernels/test_utils.py +0 -0
  175. {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_likelihoods.py +0 -0
  176. {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_lower_cholesky.py +0 -0
  177. {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_markdown.py +0 -0
  178. {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_mean_functions.py +0 -0
  179. {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_objectives.py +0 -0
  180. {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_parameters.py +0 -0
  181. {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_variational_families.py +0 -0
@@ -1,6 +1,6 @@
- Metadata-Version: 2.3
+ Metadata-Version: 2.4
  Name: gpjax
- Version: 0.9.2
+ Version: 0.9.4
  Summary: Gaussian processes in JAX.
  Project-URL: Documentation, https://docs.jaxgaussianprocesses.com/
  Project-URL: Issues, https://github.com/JaxGaussianProcesses/GPJax/issues
@@ -19,7 +19,7 @@ Classifier: Programming Language :: Python :: Implementation :: PyPy
  Requires-Python: <3.13,>=3.10
  Requires-Dist: beartype>0.16.1
  Requires-Dist: cola-ml==0.0.5
- Requires-Dist: flax>=0.8.5
+ Requires-Dist: flax<0.10.0
  Requires-Dist: jax<0.4.28
  Requires-Dist: jaxlib<0.4.28
  Requires-Dist: jaxopt==0.8.2
@@ -103,23 +103,23 @@ helped to shape GPJax into the package it is today.

  ## Notebook examples

- > - [**Conjugate Inference**](https://docs.jaxgaussianprocesses.com/examples/regression/)
- > - [**Classification**](https://docs.jaxgaussianprocesses.com/examples/classification/)
- > - [**Sparse Variational Inference**](https://docs.jaxgaussianprocesses.com/examples/collapsed_vi/)
- > - [**Stochastic Variational Inference**](https://docs.jaxgaussianprocesses.com/examples/uncollapsed_vi/)
- > - [**Laplace Approximation**](https://docs.jaxgaussianprocesses.com/examples/classification/#laplace-approximation)
- > - [**Inference on Non-Euclidean Spaces**](https://docs.jaxgaussianprocesses.com/examples/constructing_new_kernels/#custom-kernel)
- > - [**Inference on Graphs**](https://docs.jaxgaussianprocesses.com/examples/graph_kernels/)
- > - [**Pathwise Sampling**](https://docs.jaxgaussianprocesses.com/examples/spatial/)
- > - [**Learning Gaussian Process Barycentres**](https://docs.jaxgaussianprocesses.com/examples/barycentres/)
- > - [**Deep Kernel Regression**](https://docs.jaxgaussianprocesses.com/examples/deep_kernels/)
- > - [**Poisson Regression**](https://docs.jaxgaussianprocesses.com/examples/poisson/)
- > - [**Bayesian Optimisation**](https://docs.jaxgaussianprocesses.com/examples/bayesian_optimisation/)
+ > - [**Conjugate Inference**](https://docs.jaxgaussianprocesses.com/_examples/regression/)
+ > - [**Classification**](https://docs.jaxgaussianprocesses.com/_examples/classification/)
+ > - [**Sparse Variational Inference**](https://docs.jaxgaussianprocesses.com/_examples/collapsed_vi/)
+ > - [**Stochastic Variational Inference**](https://docs.jaxgaussianprocesses.com/_examples/uncollapsed_vi/)
+ > - [**Laplace Approximation**](https://docs.jaxgaussianprocesses.com/_examples/classification/#laplace-approximation)
+ > - [**Inference on Non-Euclidean Spaces**](https://docs.jaxgaussianprocesses.com/_examples/constructing_new_kernels/#custom-kernel)
+ > - [**Inference on Graphs**](https://docs.jaxgaussianprocesses.com/_examples/graph_kernels/)
+ > - [**Pathwise Sampling**](https://docs.jaxgaussianprocesses.com/_examples/spatial/)
+ > - [**Learning Gaussian Process Barycentres**](https://docs.jaxgaussianprocesses.com/_examples/barycentres/)
+ > - [**Deep Kernel Regression**](https://docs.jaxgaussianprocesses.com/_examples/deep_kernels/)
+ > - [**Poisson Regression**](https://docs.jaxgaussianprocesses.com/_examples/poisson/)
+ > - [**Bayesian Optimisation**](https://docs.jaxgaussianprocesses.com/_examples/bayesian_optimisation/)

  ## Guides for customisation
  >
- > - [**Custom kernels**](https://docs.jaxgaussianprocesses.com/examples/constructing_new_kernels/#custom-kernel)
- > - [**UCI regression**](https://docs.jaxgaussianprocesses.com/examples/yacht/)
+ > - [**Custom kernels**](https://docs.jaxgaussianprocesses.com/_examples/constructing_new_kernels/#custom-kernel)
+ > - [**UCI regression**](https://docs.jaxgaussianprocesses.com/_examples/yacht/)

  ## Conversion between `.ipynb` and `.py`
  Above examples are stored in [examples](docs/examples) directory in the double
@@ -180,7 +180,7 @@ optimiser = ox.adam(learning_rate=1e-2)
  # Obtain Type 2 MLEs of the hyperparameters
  opt_posterior, history = gpx.fit(
  model=posterior,
- objective=gpx.objectives.conjugate_mll,
+ objective=lambda p, d: -gpx.objectives.conjugate_mll(p, d),
  train_data=D,
  optim=optimiser,
  num_iters=500,
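A minimal, self-contained sketch of the updated call above (not code shipped in the package): `fit` treats its `objective` as a loss to minimise, hence the negated conjugate marginal log-likelihood. The toy data, RBF prior, and optimiser settings are illustrative assumptions, and the exact `fit` keyword set may vary between releases.

```python
import gpjax as gpx
import jax.numpy as jnp
import jax.random as jr
import optax as ox

key = jr.PRNGKey(123)
x = jnp.linspace(-3.0, 3.0, 50).reshape(-1, 1)
y = jnp.sin(x) + 0.1 * jr.normal(key, x.shape)   # toy training data
D = gpx.Dataset(X=x, y=y)

prior = gpx.gps.Prior(mean_function=gpx.mean_functions.Zero(), kernel=gpx.kernels.RBF())
posterior = prior * gpx.likelihoods.Gaussian(num_datapoints=D.n)

# `fit` minimises its objective, so the (to-be-maximised) marginal log-likelihood
# is wrapped in a negating lambda, exactly as in the hunk above.
opt_posterior, history = gpx.fit(
    model=posterior,
    objective=lambda p, d: -gpx.objectives.conjugate_mll(p, d),
    train_data=D,
    optim=ox.adam(learning_rate=1e-2),
    num_iters=500,
    key=key,  # assumed keyword used for mini-batching
)
```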
@@ -71,23 +71,23 @@ helped to shape GPJax into the package it is today.

  ## Notebook examples

- > - [**Conjugate Inference**](https://docs.jaxgaussianprocesses.com/examples/regression/)
- > - [**Classification**](https://docs.jaxgaussianprocesses.com/examples/classification/)
- > - [**Sparse Variational Inference**](https://docs.jaxgaussianprocesses.com/examples/collapsed_vi/)
- > - [**Stochastic Variational Inference**](https://docs.jaxgaussianprocesses.com/examples/uncollapsed_vi/)
- > - [**Laplace Approximation**](https://docs.jaxgaussianprocesses.com/examples/classification/#laplace-approximation)
- > - [**Inference on Non-Euclidean Spaces**](https://docs.jaxgaussianprocesses.com/examples/constructing_new_kernels/#custom-kernel)
- > - [**Inference on Graphs**](https://docs.jaxgaussianprocesses.com/examples/graph_kernels/)
- > - [**Pathwise Sampling**](https://docs.jaxgaussianprocesses.com/examples/spatial/)
- > - [**Learning Gaussian Process Barycentres**](https://docs.jaxgaussianprocesses.com/examples/barycentres/)
- > - [**Deep Kernel Regression**](https://docs.jaxgaussianprocesses.com/examples/deep_kernels/)
- > - [**Poisson Regression**](https://docs.jaxgaussianprocesses.com/examples/poisson/)
- > - [**Bayesian Optimisation**](https://docs.jaxgaussianprocesses.com/examples/bayesian_optimisation/)
+ > - [**Conjugate Inference**](https://docs.jaxgaussianprocesses.com/_examples/regression/)
+ > - [**Classification**](https://docs.jaxgaussianprocesses.com/_examples/classification/)
+ > - [**Sparse Variational Inference**](https://docs.jaxgaussianprocesses.com/_examples/collapsed_vi/)
+ > - [**Stochastic Variational Inference**](https://docs.jaxgaussianprocesses.com/_examples/uncollapsed_vi/)
+ > - [**Laplace Approximation**](https://docs.jaxgaussianprocesses.com/_examples/classification/#laplace-approximation)
+ > - [**Inference on Non-Euclidean Spaces**](https://docs.jaxgaussianprocesses.com/_examples/constructing_new_kernels/#custom-kernel)
+ > - [**Inference on Graphs**](https://docs.jaxgaussianprocesses.com/_examples/graph_kernels/)
+ > - [**Pathwise Sampling**](https://docs.jaxgaussianprocesses.com/_examples/spatial/)
+ > - [**Learning Gaussian Process Barycentres**](https://docs.jaxgaussianprocesses.com/_examples/barycentres/)
+ > - [**Deep Kernel Regression**](https://docs.jaxgaussianprocesses.com/_examples/deep_kernels/)
+ > - [**Poisson Regression**](https://docs.jaxgaussianprocesses.com/_examples/poisson/)
+ > - [**Bayesian Optimisation**](https://docs.jaxgaussianprocesses.com/_examples/bayesian_optimisation/)

  ## Guides for customisation
  >
- > - [**Custom kernels**](https://docs.jaxgaussianprocesses.com/examples/constructing_new_kernels/#custom-kernel)
- > - [**UCI regression**](https://docs.jaxgaussianprocesses.com/examples/yacht/)
+ > - [**Custom kernels**](https://docs.jaxgaussianprocesses.com/_examples/constructing_new_kernels/#custom-kernel)
+ > - [**UCI regression**](https://docs.jaxgaussianprocesses.com/_examples/yacht/)

  ## Conversion between `.ipynb` and `.py`
  Above examples are stored in [examples](docs/examples) directory in the double
@@ -148,7 +148,7 @@ optimiser = ox.adam(learning_rate=1e-2)
  # Obtain Type 2 MLEs of the hyperparameters
  opt_posterior, history = gpx.fit(
  model=posterior,
- objective=gpx.objectives.conjugate_mll,
+ objective=lambda p, d: -gpx.objectives.conjugate_mll(p, d),
  train_data=D,
  optim=optimiser,
  num_iters=500,
@@ -1,4 +1,4 @@
- # Welcome to GPJax!
+ # Welcome to GPJax

  GPJax is a didactic Gaussian process (GP) library in JAX, supporting GPU
  acceleration and just-in-time compilation. We seek to provide a flexible
@@ -6,7 +6,6 @@ API to enable researchers to rapidly prototype and develop new ideas.

  ![Gaussian process posterior.](static/GP.svg)

-
  ## "Hello, GP!"

  Typing GP models is as simple as the maths we
@@ -53,7 +52,7 @@ would write on paper, as shown below.
  !!! Begin

  Looking for a good place to start? Then why not begin with our [regression
- notebook](https://docs.jaxgaussianprocesses.com/examples/regression/).
+ notebook](https://docs.jaxgaussianprocesses.com/_examples/regression/).

  ## Citing GPJax

@@ -122,7 +122,7 @@ print(constant_param._tag)
  # For most users, you will not need to worry about this as we provide a set of default
  # bijectors that are defined for all the parameter types we support. However, see our
  # [Kernel Guide
- # Notebook](https://docs.jaxgaussianprocesses.com/examples/constructing_new_kernels/) to
+ # Notebook](https://docs.jaxgaussianprocesses.com/_examples/constructing_new_kernels/) to
  # see how you can define your own bijectors and parameter types.

  # %%
@@ -156,7 +156,7 @@ transform(_close_to_zero_state, DEFAULT_BIJECTION, inverse=True)
  # may be nested within several functions e.g., a kernel function within a GP model.
  # Fortunately, transforming several parameters is a simple operation that we here
  # demonstrate for a conjugate GP posterior (see our [Regression
- # Notebook](https://docs.jaxgaussianprocesses.com/examples/regression/) for detailed
+ # Notebook](https://docs.jaxgaussianprocesses.com/_examples/regression/) for detailed
  # explanation of this model.).

  # %%
@@ -239,7 +239,7 @@ print(positive_reals)
  # useful as it allows us to efficiently operate on a subset of the parameters whilst
  # leaving the others untouched. Looking forward, we hope to use this functionality in
  # our [Variational Inference
- # Approximations](https://docs.jaxgaussianprocesses.com/examples/uncollapsed_vi/) to
+ # Approximations](https://docs.jaxgaussianprocesses.com/_examples/uncollapsed_vi/) to
  # perform more efficient updates of the variational parameters and then the model's
  # hyperparameters.

@@ -361,7 +361,7 @@ ax.set(xlabel="x", ylabel="m(x)")
  # In this notebook we have explored how GPJax's Flax-based backend may be easily
  # manipulated and extended. For a more applied look at this, see how we construct a
  # kernel on polar coordinates in our [Kernel
- # Guide](https://docs.jaxgaussianprocesses.com/examples/constructing_new_kernels/#custom-kernel)
+ # Guide](https://docs.jaxgaussianprocesses.com/_examples/constructing_new_kernels/#custom-kernel)
  # notebook.
  #
  # ## System configuration
@@ -8,7 +8,7 @@
  # extension: .py
  # format_name: percent
  # format_version: '1.3'
- # jupytext_version: 1.16.4
+ # jupytext_version: 1.16.6
  # kernelspec:
  # display_name: gpjax
  # language: python
@@ -154,9 +154,9 @@ plt.show()
  # We'll now independently learn Gaussian process posterior distributions for each
  # dataset. We won't spend any time here discussing how GP hyperparameters are
  # optimised. For advice on achieving this, see the
- # [Regression notebook](https://docs.jaxgaussianprocesses.com/examples/regression/)
+ # [Regression notebook](https://docs.jaxgaussianprocesses.com/_examples/regression/)
  # for advice on optimisation and the
- # [Kernels notebook](https://docs.jaxgaussianprocesses.com/examples/constructing_new_kernels/) for
+ # [Kernels notebook](https://docs.jaxgaussianprocesses.com/_examples/constructing_new_kernels/) for
  # advice on selecting an appropriate kernel.


@@ -8,7 +8,7 @@
  # extension: .py
  # format_name: percent
  # format_version: '1.3'
- # jupytext_version: 1.16.4
+ # jupytext_version: 1.16.6
  # kernelspec:
  # display_name: gpjax
  # language: python
@@ -20,7 +20,7 @@
  #
  # In this guide we introduce the Bayesian Optimisation (BO) paradigm for
  # optimising black-box functions. We'll assume an understanding of Gaussian processes
- # (GPs), so if you're not familiar with them, check out our [GP introduction notebook](https://docs.jaxgaussianprocesses.com/examples/intro_to_gps/).
+ # (GPs), so if you're not familiar with them, check out our [GP introduction notebook](https://docs.jaxgaussianprocesses.com/_examples/intro_to_gps/).

  # %%
  from typing import (
@@ -278,7 +278,7 @@ opt_posterior = return_optimised_posterior(D, prior, key)
  # will do this using the `sample_approx` method, which generates an approximate sample
  # from the posterior using decoupled sampling introduced in ([Wilson et al.,
  # 2020](https://proceedings.mlr.press/v119/wilson20a.html)) and discussed in our [Pathwise
- # Sampling Notebook](https://docs.jaxgaussianprocesses.com/examples/spatial/). This method
+ # Sampling Notebook](https://docs.jaxgaussianprocesses.com/_examples/spatial/). This method
  # is used as it enables us to sample from the posterior in a manner which scales linearly
  # with the number of points sampled, $O(N)$, mitigating the cubic cost associated with
  # drawing exact samples from a GP posterior, $O(N^3)$. It also generates more accurate
@@ -8,7 +8,7 @@
  # extension: .py
  # format_name: percent
  # format_version: '1.3'
- # jupytext_version: 1.16.4
+ # jupytext_version: 1.16.6
  # kernelspec:
  # display_name: gpjax
  # language: python
@@ -193,15 +193,20 @@ ax.legend()
  # $\boldsymbol{x}$, we can expand the log of this about the posterior mode
  # $\hat{\boldsymbol{f}}$ via a Taylor expansion. This gives:
  #
+ # $$
  # \begin{align}
  # \log\tilde{p}(\boldsymbol{f}|\mathcal{D}) = \log\tilde{p}(\hat{\boldsymbol{f}}|\mathcal{D}) + \left[\nabla \log\tilde{p}({\boldsymbol{f}}|\mathcal{D})|_{\hat{\boldsymbol{f}}}\right]^{T} (\boldsymbol{f}-\hat{\boldsymbol{f}}) + \frac{1}{2} (\boldsymbol{f}-\hat{\boldsymbol{f}})^{T} \left[\nabla^2 \tilde{p}(\boldsymbol{y}|\boldsymbol{f})|_{\hat{\boldsymbol{f}}} \right] (\boldsymbol{f}-\hat{\boldsymbol{f}}) + \mathcal{O}(\lVert \boldsymbol{f} - \hat{\boldsymbol{f}} \rVert^3).
  # \end{align}
+ # $$
  #
  # Since $\nabla \log\tilde{p}({\boldsymbol{f}}|\mathcal{D})$ is zero at the mode,
  # this suggests the following approximation
+ #
+ # $$
  # \begin{align}
  # \tilde{p}(\boldsymbol{f}|\mathcal{D}) \approx \log\tilde{p}(\hat{\boldsymbol{f}}|\mathcal{D}) \exp\left\{ \frac{1}{2} (\boldsymbol{f}-\hat{\boldsymbol{f}})^{T} \left[-\nabla^2 \tilde{p}(\boldsymbol{y}|\boldsymbol{f})|_{\hat{\boldsymbol{f}}} \right] (\boldsymbol{f}-\hat{\boldsymbol{f}}) \right\}
  # \end{align},
+ # $$
  #
  # that we identify as a Gaussian distribution,
  # $p(\boldsymbol{f}| \mathcal{D}) \approx q(\boldsymbol{f}) := \mathcal{N}(\hat{\boldsymbol{f}}, [-\nabla^2 \tilde{p}(\boldsymbol{y}|\boldsymbol{f})|_{\hat{\boldsymbol{f}}} ]^{-1} )$.
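The Laplace recipe in the hunk above is easy to sketch outside GPJax: the approximate covariance is the negative inverse Hessian of the log-posterior at its mode. The snippet below is a minimal, self-contained illustration with a toy log-posterior (not the notebook's model); `f_hat` stands in for the mode an optimiser would return.

```python
import jax
import jax.numpy as jnp

def log_posterior(f):
    # Toy unnormalised log-posterior: Gaussian prior plus a logistic-style likelihood.
    return -0.5 * jnp.sum(f**2) - jnp.sum(jnp.logaddexp(0.0, -f))

f_hat = jnp.zeros(3)                       # stand-in for the posterior mode
hess = jax.hessian(log_posterior)(f_hat)   # ∇² log p̃(f | D) evaluated at f_hat
laplace_cov = jnp.linalg.inv(-hess)        # q(f) = N(f_hat, [-∇² log p̃]⁻¹)
```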
@@ -7,7 +7,7 @@
  # extension: .py
  # format_name: percent
  # format_version: '1.3'
- # jupytext_version: 1.16.4
+ # jupytext_version: 1.16.6
  # kernelspec:
  # display_name: gpjax_beartype
  # language: python
@@ -131,7 +131,7 @@ q = gpx.variational_families.CollapsedVariationalGaussian(
  # %% [markdown]
  # We now train our model akin to a Gaussian process regression model via the `fit`
  # abstraction. Unlike the regression example given in the
- # [conjugate regression notebook](https://docs.jaxgaussianprocesses.com/examples/regression/),
+ # [conjugate regression notebook](https://docs.jaxgaussianprocesses.com/_examples/regression/),
  # the inducing locations that induce our variational posterior distribution are now
  # part of the model's parameters. Using a gradient-based optimiser, we can then
  # _optimise_ their location such that the evidence lower bound is maximised.
@@ -8,7 +8,7 @@
  # extension: .py
  # format_name: percent
  # format_version: '1.3'
- # jupytext_version: 1.16.4
+ # jupytext_version: 1.16.6
  # kernelspec:
  # display_name: gpjax
  # language: python
@@ -71,7 +71,7 @@ cols = plt.rcParams["axes.prop_cycle"].by_key()["color"]
  # * White noise
  # * Linear.
  # * Polynomial.
- # * [Graph kernels](https://docs.jaxgaussianprocesses.com/examples/graph_kernels/).
+ # * [Graph kernels](https://docs.jaxgaussianprocesses.com/_examples/graph_kernels/).
  #
  # While the syntax is consistent, each kernel's type influences the
  # characteristics of the sample paths drawn. We visualise this below with 10
@@ -185,7 +185,7 @@ fig.colorbar(im3, ax=ax[3], fraction=0.05)
  # We'll demonstrate this process now for a circular kernel --- an adaption of
  # the excellent guide given in the PYMC3 documentation. We encourage curious
  # readers to visit their notebook
- # [here](https://www.pymc.io/projects/docs/en/v3/pymc-examples/examples/gaussian_processes/GP-Circular.html).
+ # [here](https://www.pymc.io/projects/docs/en/v3/pymc-_examples/_examples/gaussian_processes/GP-Circular.html).
  #
  # ### Circular kernel
  #
@@ -198,9 +198,15 @@ fig.colorbar(im3, ax=ax[3], fraction=0.05)
  # kernels do not exhibit this behaviour and instead _wrap_ around the boundary
  # points to create a smooth function. Such a kernel was given in [Padonou &
  # Roustant (2015)](https://hal.inria.fr/hal-01119942v1) where any two angles
- # $\theta$ and $\theta'$ are written as $$W_c(\theta, \theta') = \left\lvert
+ # $\theta$ and $\theta'$ are written as
+ #
+ # $$
+ # \begin{align}
+ # W_c(\theta, \theta') & = \left\lvert
  # \left(1 + \tau \frac{d(\theta, \theta')}{c} \right) \left(1 - \frac{d(\theta,
- # \theta')}{c} \right)^{\tau} \right\rvert \quad \tau \geq 4 \tag{1}.$$
+ # \theta')}{c} \right)^{\tau} \right\rvert \quad \tau \geq 4 \tag{1}.
+ # \end{align}
+ # $$
  #
  # Here the hyperparameter $\tau$ is analogous to a lengthscale for Euclidean
  # stationary kernels, controlling the correlation between pairs of observations.
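Equation (1) above is straightforward to prototype directly in JAX. The sketch below is illustrative only (it is not the notebook's `Polar` kernel class) and assumes $d(\theta, \theta')$ is the shortest angular distance, wrapped onto $[0, c]$.

```python
import jax.numpy as jnp

def angular_distance(theta, theta_prime, c):
    # Shortest distance between two angles, wrapped onto the interval [0, c].
    return jnp.abs((theta - theta_prime + c) % (c * 2) - c)

def circular_kernel(theta, theta_prime, tau=5.0, c=jnp.pi):
    # Equation (1): W_c(θ, θ') = |(1 + τ d/c)(1 - d/c)^τ|, valid for τ ≥ 4.
    d = angular_distance(theta, theta_prime, c)
    return jnp.abs((1.0 + tau * d / c) * (1.0 - d / c) ** tau)
```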
@@ -266,7 +272,7 @@ class Polar(gpx.kernels.AbstractKernel):
  #
  # We proceed to fit a GP with our custom circular kernel to a random sequence of
  # points on a circle (see the
- # [Regression notebook](https://docs.jaxgaussianprocesses.com/examples/regression/)
+ # [Regression notebook](https://docs.jaxgaussianprocesses.com/_examples/regression/)
  # for further details on this process).

  # %%
@@ -7,7 +7,7 @@
  # extension: .py
  # format_name: percent
  # format_version: '1.3'
- # jupytext_version: 1.16.4
+ # jupytext_version: 1.16.6
  # kernelspec:
  # display_name: gpjax
  # language: python
@@ -22,7 +22,7 @@
  # such problems include Bayesian optimisation (BO) and experimental design. For an
  # in-depth introduction to Bayesian optimisation itself, be sure to checkout out our
  # [Introduction to BO
- # Notebook](https://docs.jaxgaussianprocesses.com/examples/bayesian_optimisation/).
+ # Notebook](https://docs.jaxgaussianprocesses.com/_examples/bayesian_optimisation/).
  #
  # We'll be using BO as a case study to demonstrate how one may use the decision making
  # module to solve sequential decision making problems. The goal of the decision making
@@ -76,7 +76,7 @@ cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"]
  # ## The Black-Box Objective Function
  #
  # We'll be using the same problem as in the [Introduction to BO
- # Notebook](https://docs.jaxgaussianprocesses.com/examples/bayesian_optimisation/), but
+ # Notebook](https://docs.jaxgaussianprocesses.com/_examples/bayesian_optimisation/), but
  # rather than focussing on the mechanics of BO we'll be looking at how one may use the
  # abstractions provided by the decision making module to implement the BO loop.
  #
@@ -181,7 +181,7 @@ likelihood_builder = lambda n: gpx.likelihoods.Gaussian(
  # this for us. This class takes as input a `prior` and `likeligood_builder`, which we have
  # defined above. We tend to also optimise the hyperparameters of the GP prior when
  # "fitting" our GP, as demonstrated in the [Regression
- # notebook](https://docs.jaxgaussianprocesses.com/examples/regression/). This will be
+ # notebook](https://docs.jaxgaussianprocesses.com/_examples/regression/). This will be
  # using the GPJax `fit` method under the hood, which requires an `optimization_objective`,
  # `optimizer` and `num_optimization_iters`. Therefore, we also pass these to the
  # `PosteriorHandler` as demonstrated below:
@@ -257,7 +257,7 @@ acquisition_maximizer = ContinuousSinglePointUtilityMaximizer(
  #
  # It is worth noting that `ThompsonSampling` is not the only utility function we could use,
  # since our module also provides e.g. `ProbabilityOfImprovement`, `ExpectedImprovment`,
- # which were briefly discussed in [our previous introduction to Bayesian optimisation](https://docs.jaxgaussianprocesses.com/examples/bayesian_optimisation/).
+ # which were briefly discussed in [our previous introduction to Bayesian optimisation](https://docs.jaxgaussianprocesses.com/_examples/bayesian_optimisation/).


  # %% [markdown]
@@ -8,7 +8,7 @@
  # extension: .py
  # format_name: percent
  # format_version: '1.3'
- # jupytext_version: 1.16.4
+ # jupytext_version: 1.16.6
  # kernelspec:
  # display_name: gpjax
  # language: python
@@ -141,7 +141,7 @@ class DeepKernelFunction(AbstractKernel):
  # activation functions between the layers. The first hidden layer contains 64 units,
  # while the second layer contains 32 units. Finally, we'll make the output of our
  # network a three units wide. The corresponding kernel that we define will then be of
- # [ARD form](https://docs.jaxgaussianprocesses.com/examples/constructing_new_kernels/#active-dimensions)
+ # [ARD form](https://docs.jaxgaussianprocesses.com/_examples/constructing_new_kernels/#active-dimensions)
  # to allow for different lengthscales in each dimension of the feature space.
  # Users may wish to design more intricate network structures for more complex tasks,
  # which functionality is supported well in Haiku.
@@ -8,7 +8,7 @@
  # extension: .py
  # format_name: percent
  # format_version: '1.3'
- # jupytext_version: 1.16.4
+ # jupytext_version: 1.16.6
  # kernelspec:
  # display_name: gpjax
  # language: python
@@ -22,7 +22,7 @@
  # of a graph using a Gaussian process with a Matérn kernel presented in
  # <strong data-cite="borovitskiy2021matern"></strong>. For a general discussion of the
  # kernels supported within GPJax, see the
- # [kernels notebook](https://docs.jaxgaussianprocesses.com/examples/constructing_new_kernels).
+ # [kernels notebook](https://docs.jaxgaussianprocesses.com/_examples/constructing_new_kernels).

  # %%
  import random
@@ -88,7 +88,9 @@ nx.draw(
  #
  # Graph kernels use the _Laplacian matrix_ $L$ to quantify the smoothness of a signal
  # (or function) on a graph
+ #
  # $$L=D-A,$$
+ #
  # where $D$ is the diagonal _degree matrix_ containing each vertices' degree and $A$
  # is the _adjacency matrix_ that has an $(i,j)^{\text{th}}$ entry of 1 if $v_i, v_j$
  # are connected and 0 otherwise. [Networkx](https://networkx.org) gives us an easy
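The identity $L = D - A$ referenced above can be checked in a few lines. The snippet below uses a hypothetical stand-in graph rather than the notebook's own graph.

```python
import networkx as nx
import jax.numpy as jnp

G = nx.barbell_graph(4, 2)                          # illustrative stand-in graph
A = jnp.asarray(nx.adjacency_matrix(G).todense())   # adjacency matrix A
D = jnp.diag(A.sum(axis=1))                         # degree matrix D
L = jnp.asarray(nx.laplacian_matrix(G).todense())   # NetworkX's Laplacian
assert jnp.allclose(L, D - A)                       # L = D - A
```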
@@ -151,7 +153,7 @@ cbar = plt.colorbar(sm, ax=ax)
  # non-Euclidean, our likelihood is still Gaussian and the model is still conjugate.
  # For this reason, we simply perform gradient descent on the GP's marginal
  # log-likelihood term as in the
- # [regression notebook](https://docs.jaxgaussianprocesses.com/examples/regression/).
+ # [regression notebook](https://docs.jaxgaussianprocesses.com/_examples/regression/).
  # We do this using the BFGS optimiser.

  # %%
@@ -7,7 +7,7 @@
  # extension: .py
  # format_name: percent
  # format_version: '1.3'
- # jupytext_version: 1.16.4
+ # jupytext_version: 1.16.6
  # kernelspec:
  # display_name: gpjax
  # language: python
@@ -17,6 +17,7 @@
  # %% [markdown]
  # # New to Gaussian Processes?
  #
+ #
  # Fantastic that you're here! This notebook is designed to be a gentle
  # introduction to the mathematics of Gaussian processes (GPs). No prior
  # knowledge of Bayesian inference or GPs is assumed, and this notebook is
@@ -33,10 +34,11 @@
  # model are unknown, and our goal is to conduct inference to determine their
  # range of likely values. To achieve this, we apply Bayes' theorem
  #
+ # $$
  # \begin{align}
- # \label{eq:BayesTheorem}
- # p(\theta\,|\, \mathbf{y}) = \frac{p(\theta)p(\mathbf{y}\,|\,\theta)}{p(\mathbf{y})} = \frac{p(\theta)p(\mathbf{y}\,|\,\theta)}{\int_{\theta}p(\mathbf{y}, \theta)\mathrm{d}\theta}\,,
+ # p(\theta\mid\mathbf{y}) = \frac{p(\theta)p(\mathbf{y}\mid\theta)}{p(\mathbf{y})} = \frac{p(\theta)p(\mathbf{y}\mid\theta)}{\int_{\theta}p(\mathbf{y}, \theta)\mathrm{d}\theta},
  # \end{align}
+ # $$
  #
  # where $p(\mathbf{y}\,|\,\theta)$ denotes the _likelihood_, or model, and
  # quantifies how likely the observed dataset $\mathbf{y}$ is, given the
@@ -58,7 +60,7 @@
  # family, then there exists a conjugate prior. However, the conjugate prior may
  # not have a form that precisely reflects the practitioner's belief surrounding
  # the parameter. For this reason, conjugate models seldom appear; one exception
- # to this is GP regression that we present fully in our [Regression notebook](https://docs.jaxgaussianprocesses.com/examples/regression/).
+ # to this is GP regression that we present fully in our [Regression notebook](https://docs.jaxgaussianprocesses.com/_examples/regression/).
  #
  # For models that do not contain a conjugate prior, the marginal log-likelihood
  # must be calculated to normalise the posterior distribution and ensure it
@@ -74,9 +76,13 @@
  # new points $\mathbf{y}^{\star}$ through the _posterior predictive
  # distribution_. This is achieved by integrating out the parameter set $\theta$
  # from our posterior distribution through
+ #
+ # $$
  # \begin{align}
  # p(\mathbf{y}^{\star}\mid \mathbf{y}) = \int p(\mathbf{y}^{\star} \,|\, \theta, \mathbf{y} ) p(\theta\,|\, \mathbf{y})\mathrm{d}\theta\,.
  # \end{align}
+ # $$
+ #
  # As with the marginal log-likelihood, evaluating this quantity requires
  # computing an integral which may not be tractable, particularly when $\theta$
  # is high-dimensional.
@@ -85,13 +91,16 @@
  # distribution, so we often compute and report moments of the posterior
  # distribution. Most commonly, we report the first moment and the centred second
  # moment
+ #
  # $$
  # \begin{alignat}{2}
- # \mu = \mathbb{E}[\theta\,|\,\mathbf{y}] & = \int \theta p(\theta\mid\mathbf{y})\mathrm{d}\theta\\
+ # \mu = \mathbb{E}[\theta\,|\,\mathbf{y}] & = \int \theta
+ # p(\theta\mid\mathbf{y})\mathrm{d}\theta \quad \\
  # \sigma^2 = \mathbb{V}[\theta\,|\,\mathbf{y}] & = \int \left(\theta -
  # \mathbb{E}[\theta\,|\,\mathbf{y}]\right)^2p(\theta\,|\,\mathbf{y})\mathrm{d}\theta&\,.
  # \end{alignat}
  # $$
+ #
  # Through this pair of statistics, we can communicate our beliefs about the most
  # likely value of $\theta$ i.e., $\mu$, and the uncertainty $\sigma$ around the
  # expected value. However, as with the marginal log-likelihood and predictive
@@ -209,9 +218,7 @@ for a, t, d in zip([ax0, ax1, ax2], titles, dists):
  d_prob = d.prob(jnp.hstack([xx.reshape(-1, 1), yy.reshape(-1, 1)])).reshape(
  xx.shape
  )
- cntf = a.contourf(xx, yy, jnp.exp(d_prob), levels=20, antialiased=True, cmap=cmap)
- for c in cntf.collections:
- c.set_edgecolor("face")
+ cntf = a.contourf(xx, yy, jnp.exp(d_prob), levels=20, antialiased=True, cmap=cmap, edgecolor="face")
  a.set_xlim(-2.75, 2.75)
  a.set_ylim(-2.75, 2.75)
  samples = d.sample(seed=key, sample_shape=(5000,))
@@ -228,13 +235,16 @@ for a, t, d in zip([ax0, ax1, ax2], titles, dists):
  # %% [markdown]
  # Extending the intuition given for the moments of a univariate Gaussian random
  # variables, we can obtain the mean and covariance by
+ #
  # $$
  # \begin{align}
- # \mathbb{E}[\mathbf{y}] = \mathbf{\mu}, \quad \operatorname{Cov}(\mathbf{y}) & = \mathbf{E}\left[(\mathbf{y} - \mathbf{\mu)}(\mathbf{y} - \mathbf{\mu)}^{\top} \right]\\
+ # \mathbb{E}[\mathbf{y}] & = \mathbf{\mu}, \\
+ # \operatorname{Cov}(\mathbf{y}) & = \mathbf{E}\left[(\mathbf{y} - \mathbf{\mu})(\mathbf{y} - \mathbf{\mu})^{\top} \right] \\
  # & =\mathbb{E}[\mathbf{y}\mathbf{y}^{\top}] - \mathbb{E}[\mathbf{y}]\mathbb{E}[\mathbf{y}]^{\top} \\
  # & =\mathbf{\Sigma}\,.
  # \end{align}
  # $$
+ #
  # The covariance matrix is a symmetric positive definite matrix that generalises
  # the notion of variance to multiple dimensions. The matrix's diagonal entries
  # contain the variance of each element, whilst the off-diagonal entries quantify
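As a quick numerical companion to the moment identities above (not part of the notebook), sampling from a known multivariate Gaussian and forming empirical moments recovers $\mathbf{\mu}$ and $\mathbf{\Sigma}$ up to Monte Carlo error; the mean vector and covariance used here are arbitrary.

```python
import jax.numpy as jnp
import jax.random as jr

key = jr.PRNGKey(0)
mu = jnp.array([1.0, -1.0])
Sigma = jnp.array([[1.0, 0.6], [0.6, 2.0]])
y = jr.multivariate_normal(key, mu, Sigma, shape=(100_000,))

mean_hat = y.mean(axis=0)                                  # ≈ E[y] = μ
cov_hat = (y - mean_hat).T @ (y - mean_hat) / y.shape[0]   # ≈ E[(y - μ)(y - μ)ᵀ] = Σ
```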
@@ -336,6 +346,7 @@ with warnings.catch_warnings():
  # $\mathbf{x}\sim\mathcal{N}(\boldsymbol{\mu}_{\mathbf{x}}, \boldsymbol{\Sigma}_{\mathbf{xx}})$ and
  # $\mathbf{y}\sim\mathcal{N}(\boldsymbol{\mu}_{\mathbf{y}}, \boldsymbol{\Sigma}_{\mathbf{yy}})$.
  # We define the joint distribution as
+ #
  # $$
  # \begin{align}
  # p\left(\begin{bmatrix}
@@ -348,6 +359,7 @@ with warnings.catch_warnings():
  # \end{bmatrix} \right)\,,
  # \end{align}
  # $$
+ #
  # where $\boldsymbol{\Sigma}_{\mathbf{x}\mathbf{y}}$ is the cross-covariance
  # matrix of $\mathbf{x}$ and $\mathbf{y}$.
  #
@@ -363,6 +375,7 @@ with warnings.catch_warnings():
  #
  # For a joint Gaussian random variable, the marginalisation of $\mathbf{x}$ or
  # $\mathbf{y}$ is given by
+ #
  # $$
  # \begin{alignat}{3}
  # & \int p(\mathbf{x}, \mathbf{y})\mathrm{d}\mathbf{y} && = p(\mathbf{x})
@@ -372,7 +385,9 @@ with warnings.catch_warnings():
  # \boldsymbol{\Sigma}_{\mathbf{yy}})\,.
  # \end{alignat}
  # $$
+ #
  # The conditional distributions are given by
+ #
  # $$
  # \begin{align}
  # p(\mathbf{y}\,|\, \mathbf{x}) & = \mathcal{N}\left(\boldsymbol{\mu}_{\mathbf{y}} + \boldsymbol{\Sigma}_{\mathbf{yx}}\boldsymbol{\Sigma}_{\mathbf{xx}}^{-1}(\mathbf{x}-\boldsymbol{\mu}_{\mathbf{x}}), \boldsymbol{\Sigma}_{\mathbf{yy}}-\boldsymbol{\Sigma}_{\mathbf{yx}}\boldsymbol{\Sigma}_{\mathbf{xx}}^{-1}\boldsymbol{\Sigma}_{\mathbf{xy}}\right)\,.
@@ -401,6 +416,7 @@ with warnings.catch_warnings():
  # We aim to capture the relationship between $\mathbf{X}$ and $\mathbf{y}$ using
  # a model $f$ with which we may make predictions at an unseen set of test points
  # $\mathbf{X}^{\star}\subset\mathcal{X}$. We formalise this by
+ #
  # $$
  # \begin{align}
  # y = f(\mathbf{X}) + \varepsilon\,,
@@ -430,6 +446,7 @@ with warnings.catch_warnings():
  # convenience in the remainder of this article.
  #
  # We define a joint GP prior over the latent function
+ #
  # $$
  # \begin{align}
  # p(\mathbf{f}, \mathbf{f}^{\star}) = \mathcal{N}\left(\mathbf{0}, \begin{bmatrix}
@@ -437,14 +454,17 @@ with warnings.catch_warnings():
  # \end{bmatrix}\right)\,,
  # \end{align}
  # $$
+ #
  # where $\mathbf{f}^{\star} = f(\mathbf{X}^{\star})$. Conditional on the GP's
  # latent function $f$, we assume a factorising likelihood generates our
  # observations
+ #
  # $$
  # \begin{align}
  # p(\mathbf{y}\,|\,\mathbf{f}) = \prod_{i=1}^n p(y_i\,|\, f_i)\,.
  # \end{align}
  # $$
+ #
  # Strictly speaking, the likelihood function is
  # $p(\mathbf{y}\,|\,\phi(\mathbf{f}))$ where $\phi$ is the likelihood function's
  # associated link function. Example link functions include the probit or
@@ -453,7 +473,7 @@ with warnings.catch_warnings():
  # considers Gaussian likelihood functions where the role of $\phi$ is
  # superfluous. However, this intuition will be helpful for models with a
  # non-Gaussian likelihood, such as those encountered in
- # [classification](https://docs.jaxgaussianprocesses.com/examples/classification).
+ # [classification](https://docs.jaxgaussianprocesses.com/_examples/classification).
  #
  # Applying Bayes' theorem \eqref{eq:BayesTheorem} yields the joint posterior distribution over the
  # latent function
@@ -470,7 +490,7 @@ with warnings.catch_warnings():
  # function with parameters $\boldsymbol{\theta}$ that maps pairs of inputs
  # $\mathbf{X}, \mathbf{X}' \in \mathcal{X}$ onto the real line. We dedicate the
  # entirety of the [Introduction to Kernels
- # notebook](https://docs.jaxgaussianprocesses.com/examples/intro_to_kernels) to
+ # notebook](https://docs.jaxgaussianprocesses.com/_examples/intro_to_kernels) to
  # exploring the different GPs each kernel can yield.
  #
  # ## Gaussian process regression
@@ -479,20 +499,25 @@ with warnings.catch_warnings():
  # $p(y_i\,|\, f_i) = \mathcal{N}(y_i\,|\, f_i, \sigma_n^2)$,
  # marginalising $\mathbf{f}$ from the joint posterior to obtain
  # the posterior predictive distribution is exact
+ #
  # $$
  # \begin{align}
  # p(\mathbf{f}^{\star}\mid \mathbf{y}) = \mathcal{N}(\mathbf{f}^{\star}\,|\,\boldsymbol{\mu}_{\,|\,\mathbf{y}}, \Sigma_{\,|\,\mathbf{y}})\,,
  # \end{align}
  # $$
+ #
  # where
+ #
  # $$
  # \begin{align}
  # \mathbf{\mu}_{\mid \mathbf{y}} & = \mathbf{K}_{\star f}\left( \mathbf{K}_{ff}+\sigma^2_n\mathbf{I}_n\right)^{-1}\mathbf{y} \\
  # \Sigma_{\,|\,\mathbf{y}} & = \mathbf{K}_{\star\star} - \mathbf{K}_{xf}\left(\mathbf{K}_{ff} + \sigma_n^2\mathbf{I}_n\right)^{-1}\mathbf{K}_{fx} \,.
  # \end{align}
  # $$
+ #
  # Further, the log of the marginal likelihood of the GP can
  # be analytically expressed as
+ #
  # $$
  # \begin{align}
  # & = 0.5\left(-\underbrace{\mathbf{y}^{\top}\left(\mathbf{K}_{ff} + \sigma_n^2\mathbf{I}_n \right)^{-1}\mathbf{y}}_{\text{Data fit}} -\underbrace{\log\lvert \mathbf{K}_{ff} + \sigma^2_n\rvert}_{\text{Complexity}} -\underbrace{n\log 2\pi}_{\text{Constant}} \right)\,.
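The closed-form predictive moments in the hunk above translate almost verbatim into JAX. The following is an illustrative sketch (not GPJax's `predict` implementation); `Kff`, `Ksf`, `Kss`, `y` and `noise` are assumed to be precomputed Gram matrices, training targets, and the observation-noise variance.

```python
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve

def predictive_moments(Kff, Ksf, Kss, y, noise):
    # μ_|y = K_*f (K_ff + σ_n² I)⁻¹ y
    # Σ_|y = K_** - K_*f (K_ff + σ_n² I)⁻¹ K_f*
    chol = cho_factor(Kff + noise * jnp.eye(y.shape[0]), lower=True)
    mean = Ksf @ cho_solve(chol, y)
    cov = Kss - Ksf @ cho_solve(chol, Ksf.T)
    return mean, cov
```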
@@ -505,6 +530,7 @@ with warnings.catch_warnings():
  # we call these terms the model hyperparameters
  # $\boldsymbol{\xi} = \{\boldsymbol{\theta},\sigma_n^2\}$
  # from which the maximum likelihood estimate is given by
+ #
  # $$
  # \begin{align*}
  # \boldsymbol{\xi}^{\star} = \operatorname{argmax}_{\boldsymbol{\xi} \in \Xi} \log p(\mathbf{y})\,.
@@ -532,7 +558,7 @@ with warnings.catch_warnings():
  # Bayes' theorem and the definition of a Gaussian random variable. Using the
  # ideas presented in this notebook, the user should be in a position to dive
  # into our [Regression
- # notebook](https://docs.jaxgaussianprocesses.com/examples/regression/) and
+ # notebook](https://docs.jaxgaussianprocesses.com/_examples/regression/) and
  # start getting their hands on some code. For those looking to learn more about
  # the underling theory of GPs, an excellent starting point is the [Gaussian
  # Processes for Machine Learning](http://gaussianprocess.org/gpml/) textbook.