gpjax 0.9.3.tar.gz → 0.9.5.tar.gz

This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the two versions exactly as they appear in that registry.
Files changed (184)
  1. gpjax-0.9.5/LICENSE.txt +19 -0
  2. {gpjax-0.9.3 → gpjax-0.9.5}/PKG-INFO +20 -21
  3. {gpjax-0.9.3 → gpjax-0.9.5}/README.md +15 -15
  4. {gpjax-0.9.3 → gpjax-0.9.5}/docs/index.md +2 -3
  5. {gpjax-0.9.3 → gpjax-0.9.5}/examples/backend.py +4 -4
  6. {gpjax-0.9.3 → gpjax-0.9.5}/examples/barycentres.py +3 -3
  7. {gpjax-0.9.3 → gpjax-0.9.5}/examples/classification.py +6 -1
  8. {gpjax-0.9.3 → gpjax-0.9.5}/examples/collapsed_vi.py +2 -2
  9. {gpjax-0.9.3 → gpjax-0.9.5}/examples/constructing_new_kernels.py +13 -7
  10. {gpjax-0.9.3 → gpjax-0.9.5}/examples/deep_kernels.py +2 -2
  11. {gpjax-0.9.3 → gpjax-0.9.5}/examples/graph_kernels.py +5 -3
  12. {gpjax-0.9.3 → gpjax-0.9.5}/examples/intro_to_gps.py +39 -13
  13. {gpjax-0.9.3 → gpjax-0.9.5}/examples/intro_to_kernels.py +40 -22
  14. {gpjax-0.9.3 → gpjax-0.9.5}/examples/likelihoods_guide.py +5 -3
  15. {gpjax-0.9.3 → gpjax-0.9.5}/examples/oceanmodelling.py +6 -4
  16. {gpjax-0.9.3 → gpjax-0.9.5}/examples/poisson.py +1 -1
  17. {gpjax-0.9.3 → gpjax-0.9.5}/examples/regression.py +1 -1
  18. {gpjax-0.9.3 → gpjax-0.9.5}/examples/uncollapsed_vi.py +3 -3
  19. {gpjax-0.9.3 → gpjax-0.9.5}/examples/utils.py +1 -1
  20. {gpjax-0.9.3 → gpjax-0.9.5}/examples/yacht.py +5 -5
  21. {gpjax-0.9.3 → gpjax-0.9.5}/gpjax/__init__.py +1 -3
  22. {gpjax-0.9.3 → gpjax-0.9.5}/gpjax/citation.py +0 -43
  23. {gpjax-0.9.3 → gpjax-0.9.5}/gpjax/distributions.py +3 -1
  24. {gpjax-0.9.3 → gpjax-0.9.5}/gpjax/gps.py +2 -1
  25. {gpjax-0.9.3 → gpjax-0.9.5}/gpjax/variational_families.py +24 -19
  26. {gpjax-0.9.3 → gpjax-0.9.5}/mkdocs.yml +0 -2
  27. {gpjax-0.9.3 → gpjax-0.9.5}/pyproject.toml +3 -4
  28. {gpjax-0.9.3 → gpjax-0.9.5}/tests/test_citations.py +0 -64
  29. {gpjax-0.9.3 → gpjax-0.9.5}/tests/test_dataset.py +3 -2
  30. {gpjax-0.9.3 → gpjax-0.9.5}/tests/test_fit.py +2 -2
  31. {gpjax-0.9.3 → gpjax-0.9.5}/tests/test_gps.py +2 -1
  32. {gpjax-0.9.3 → gpjax-0.9.5}/tests/test_parameters.py +1 -1
  33. gpjax-0.9.3/LICENSE +0 -201
  34. gpjax-0.9.3/examples/bayesian_optimisation.py +0 -790
  35. gpjax-0.9.3/examples/decision_making.py +0 -411
  36. gpjax-0.9.3/gpjax/decision_making/__init__.py +0 -63
  37. gpjax-0.9.3/gpjax/decision_making/decision_maker.py +0 -302
  38. gpjax-0.9.3/gpjax/decision_making/posterior_handler.py +0 -152
  39. gpjax-0.9.3/gpjax/decision_making/search_space.py +0 -96
  40. gpjax-0.9.3/gpjax/decision_making/test_functions/__init__.py +0 -31
  41. gpjax-0.9.3/gpjax/decision_making/test_functions/continuous_functions.py +0 -169
  42. gpjax-0.9.3/gpjax/decision_making/test_functions/non_conjugate_functions.py +0 -90
  43. gpjax-0.9.3/gpjax/decision_making/utility_functions/__init__.py +0 -37
  44. gpjax-0.9.3/gpjax/decision_making/utility_functions/base.py +0 -106
  45. gpjax-0.9.3/gpjax/decision_making/utility_functions/expected_improvement.py +0 -112
  46. gpjax-0.9.3/gpjax/decision_making/utility_functions/probability_of_improvement.py +0 -125
  47. gpjax-0.9.3/gpjax/decision_making/utility_functions/thompson_sampling.py +0 -101
  48. gpjax-0.9.3/gpjax/decision_making/utility_maximizer.py +0 -157
  49. gpjax-0.9.3/gpjax/decision_making/utils.py +0 -64
  50. gpjax-0.9.3/publish/gpjax-0.9.3-py3-none-any.whl +0 -0
  51. gpjax-0.9.3/publish/gpjax-0.9.3.tar.gz +0 -0
  52. gpjax-0.9.3/tests/test_decision_making/__init__.py +0 -0
  53. gpjax-0.9.3/tests/test_decision_making/test_decision_maker.py +0 -474
  54. gpjax-0.9.3/tests/test_decision_making/test_posterior_handler.py +0 -315
  55. gpjax-0.9.3/tests/test_decision_making/test_search_space.py +0 -206
  56. gpjax-0.9.3/tests/test_decision_making/test_test_functions/__init__.py +0 -0
  57. gpjax-0.9.3/tests/test_decision_making/test_test_functions/test_continuous_functions.py +0 -185
  58. gpjax-0.9.3/tests/test_decision_making/test_test_functions/test_non_conjugate_functions.py +0 -112
  59. gpjax-0.9.3/tests/test_decision_making/test_utility_functions/__init__.py +0 -0
  60. gpjax-0.9.3/tests/test_decision_making/test_utility_functions/test_base.py +0 -26
  61. gpjax-0.9.3/tests/test_decision_making/test_utility_functions/test_expected_improvement.py +0 -67
  62. gpjax-0.9.3/tests/test_decision_making/test_utility_functions/test_probability_of_improvement.py +0 -64
  63. gpjax-0.9.3/tests/test_decision_making/test_utility_functions/test_thompson_sampling.py +0 -120
  64. gpjax-0.9.3/tests/test_decision_making/test_utility_functions/test_utility_functions.py +0 -167
  65. gpjax-0.9.3/tests/test_decision_making/test_utility_maximizer.py +0 -165
  66. gpjax-0.9.3/tests/test_decision_making/test_utils.py +0 -84
  67. gpjax-0.9.3/tests/test_decision_making/utils.py +0 -90
  68. {gpjax-0.9.3 → gpjax-0.9.5}/.github/CODE_OF_CONDUCT.md +0 -0
  69. {gpjax-0.9.3 → gpjax-0.9.5}/.github/ISSUE_TEMPLATE/01_BUG_REPORT.md +0 -0
  70. {gpjax-0.9.3 → gpjax-0.9.5}/.github/ISSUE_TEMPLATE/02_FEATURE_REQUEST.md +0 -0
  71. {gpjax-0.9.3 → gpjax-0.9.5}/.github/ISSUE_TEMPLATE/03_CODEBASE_IMPROVEMENT.md +0 -0
  72. {gpjax-0.9.3 → gpjax-0.9.5}/.github/ISSUE_TEMPLATE/04_DOCS_IMPROVEMENT.md +0 -0
  73. {gpjax-0.9.3 → gpjax-0.9.5}/.github/ISSUE_TEMPLATE/config.yml +0 -0
  74. {gpjax-0.9.3 → gpjax-0.9.5}/.github/codecov.yml +0 -0
  75. {gpjax-0.9.3 → gpjax-0.9.5}/.github/labels.yml +0 -0
  76. {gpjax-0.9.3 → gpjax-0.9.5}/.github/pull_request_template.md +0 -0
  77. {gpjax-0.9.3 → gpjax-0.9.5}/.github/release-drafter.yml +0 -0
  78. {gpjax-0.9.3 → gpjax-0.9.5}/.github/workflows/build_docs.yml +0 -0
  79. {gpjax-0.9.3 → gpjax-0.9.5}/.github/workflows/integration.yml +0 -0
  80. {gpjax-0.9.3 → gpjax-0.9.5}/.github/workflows/pr_greeting.yml +0 -0
  81. {gpjax-0.9.3 → gpjax-0.9.5}/.github/workflows/ruff.yml +0 -0
  82. {gpjax-0.9.3 → gpjax-0.9.5}/.github/workflows/stale_prs.yml +0 -0
  83. {gpjax-0.9.3 → gpjax-0.9.5}/.github/workflows/test_docs.yml +0 -0
  84. {gpjax-0.9.3 → gpjax-0.9.5}/.github/workflows/tests.yml +0 -0
  85. {gpjax-0.9.3 → gpjax-0.9.5}/.gitignore +0 -0
  86. {gpjax-0.9.3 → gpjax-0.9.5}/CITATION.bib +0 -0
  87. {gpjax-0.9.3 → gpjax-0.9.5}/Makefile +0 -0
  88. {gpjax-0.9.3 → gpjax-0.9.5}/docs/CODE_OF_CONDUCT.md +0 -0
  89. {gpjax-0.9.3 → gpjax-0.9.5}/docs/GOVERNANCE.md +0 -0
  90. {gpjax-0.9.3 → gpjax-0.9.5}/docs/contributing.md +0 -0
  91. {gpjax-0.9.3 → gpjax-0.9.5}/docs/design.md +0 -0
  92. {gpjax-0.9.3 → gpjax-0.9.5}/docs/index.rst +0 -0
  93. {gpjax-0.9.3 → gpjax-0.9.5}/docs/installation.md +0 -0
  94. {gpjax-0.9.3 → gpjax-0.9.5}/docs/javascripts/katex.js +0 -0
  95. {gpjax-0.9.3 → gpjax-0.9.5}/docs/refs.bib +0 -0
  96. {gpjax-0.9.3 → gpjax-0.9.5}/docs/scripts/gen_examples.py +0 -0
  97. {gpjax-0.9.3 → gpjax-0.9.5}/docs/scripts/gen_pages.py +0 -0
  98. {gpjax-0.9.3 → gpjax-0.9.5}/docs/scripts/notebook_converter.py +0 -0
  99. {gpjax-0.9.3 → gpjax-0.9.5}/docs/scripts/sharp_bits_figure.py +0 -0
  100. {gpjax-0.9.3 → gpjax-0.9.5}/docs/sharp_bits.md +0 -0
  101. {gpjax-0.9.3 → gpjax-0.9.5}/docs/static/GP.pdf +0 -0
  102. {gpjax-0.9.3 → gpjax-0.9.5}/docs/static/GP.svg +0 -0
  103. {gpjax-0.9.3 → gpjax-0.9.5}/docs/static/bijector_figure.svg +0 -0
  104. {gpjax-0.9.3 → gpjax-0.9.5}/docs/static/css/gpjax_theme.css +0 -0
  105. {gpjax-0.9.3 → gpjax-0.9.5}/docs/static/favicon.ico +0 -0
  106. {gpjax-0.9.3 → gpjax-0.9.5}/docs/static/gpjax.mplstyle +0 -0
  107. {gpjax-0.9.3 → gpjax-0.9.5}/docs/static/gpjax_logo.pdf +0 -0
  108. {gpjax-0.9.3 → gpjax-0.9.5}/docs/static/gpjax_logo.svg +0 -0
  109. {gpjax-0.9.3 → gpjax-0.9.5}/docs/static/jaxkern/lato.ttf +0 -0
  110. {gpjax-0.9.3 → gpjax-0.9.5}/docs/static/jaxkern/logo.png +0 -0
  111. {gpjax-0.9.3 → gpjax-0.9.5}/docs/static/jaxkern/logo.svg +0 -0
  112. {gpjax-0.9.3 → gpjax-0.9.5}/docs/static/jaxkern/main.py +0 -0
  113. {gpjax-0.9.3 → gpjax-0.9.5}/docs/static/step_size_figure.png +0 -0
  114. {gpjax-0.9.3 → gpjax-0.9.5}/docs/static/step_size_figure.svg +0 -0
  115. {gpjax-0.9.3 → gpjax-0.9.5}/docs/stylesheets/extra.css +0 -0
  116. {gpjax-0.9.3 → gpjax-0.9.5}/docs/stylesheets/permalinks.css +0 -0
  117. {gpjax-0.9.3 → gpjax-0.9.5}/examples/barycentres/barycentre_gp.gif +0 -0
  118. {gpjax-0.9.3 → gpjax-0.9.5}/examples/data/max_tempeature_switzerland.csv +0 -0
  119. {gpjax-0.9.3 → gpjax-0.9.5}/examples/data/yacht_hydrodynamics.data +0 -0
  120. {gpjax-0.9.3 → gpjax-0.9.5}/examples/gpjax.mplstyle +0 -0
  121. {gpjax-0.9.3 → gpjax-0.9.5}/examples/intro_to_gps/decomposed_mll.png +0 -0
  122. {gpjax-0.9.3 → gpjax-0.9.5}/examples/intro_to_gps/generating_process.png +0 -0
  123. {gpjax-0.9.3 → gpjax-0.9.5}/gpjax/dataset.py +0 -0
  124. {gpjax-0.9.3 → gpjax-0.9.5}/gpjax/fit.py +0 -0
  125. {gpjax-0.9.3 → gpjax-0.9.5}/gpjax/integrators.py +0 -0
  126. {gpjax-0.9.3 → gpjax-0.9.5}/gpjax/kernels/__init__.py +0 -0
  127. {gpjax-0.9.3 → gpjax-0.9.5}/gpjax/kernels/approximations/__init__.py +0 -0
  128. {gpjax-0.9.3 → gpjax-0.9.5}/gpjax/kernels/approximations/rff.py +0 -0
  129. {gpjax-0.9.3 → gpjax-0.9.5}/gpjax/kernels/base.py +0 -0
  130. {gpjax-0.9.3 → gpjax-0.9.5}/gpjax/kernels/computations/__init__.py +0 -0
  131. {gpjax-0.9.3 → gpjax-0.9.5}/gpjax/kernels/computations/base.py +0 -0
  132. {gpjax-0.9.3 → gpjax-0.9.5}/gpjax/kernels/computations/basis_functions.py +0 -0
  133. {gpjax-0.9.3 → gpjax-0.9.5}/gpjax/kernels/computations/constant_diagonal.py +0 -0
  134. {gpjax-0.9.3 → gpjax-0.9.5}/gpjax/kernels/computations/dense.py +0 -0
  135. {gpjax-0.9.3 → gpjax-0.9.5}/gpjax/kernels/computations/diagonal.py +0 -0
  136. {gpjax-0.9.3 → gpjax-0.9.5}/gpjax/kernels/computations/eigen.py +0 -0
  137. {gpjax-0.9.3 → gpjax-0.9.5}/gpjax/kernels/non_euclidean/__init__.py +0 -0
  138. {gpjax-0.9.3 → gpjax-0.9.5}/gpjax/kernels/non_euclidean/graph.py +0 -0
  139. {gpjax-0.9.3 → gpjax-0.9.5}/gpjax/kernels/non_euclidean/utils.py +0 -0
  140. {gpjax-0.9.3 → gpjax-0.9.5}/gpjax/kernels/nonstationary/__init__.py +0 -0
  141. {gpjax-0.9.3 → gpjax-0.9.5}/gpjax/kernels/nonstationary/arccosine.py +0 -0
  142. {gpjax-0.9.3 → gpjax-0.9.5}/gpjax/kernels/nonstationary/linear.py +0 -0
  143. {gpjax-0.9.3 → gpjax-0.9.5}/gpjax/kernels/nonstationary/polynomial.py +0 -0
  144. {gpjax-0.9.3 → gpjax-0.9.5}/gpjax/kernels/stationary/__init__.py +0 -0
  145. {gpjax-0.9.3 → gpjax-0.9.5}/gpjax/kernels/stationary/base.py +0 -0
  146. {gpjax-0.9.3 → gpjax-0.9.5}/gpjax/kernels/stationary/matern12.py +0 -0
  147. {gpjax-0.9.3 → gpjax-0.9.5}/gpjax/kernels/stationary/matern32.py +0 -0
  148. {gpjax-0.9.3 → gpjax-0.9.5}/gpjax/kernels/stationary/matern52.py +0 -0
  149. {gpjax-0.9.3 → gpjax-0.9.5}/gpjax/kernels/stationary/periodic.py +0 -0
  150. {gpjax-0.9.3 → gpjax-0.9.5}/gpjax/kernels/stationary/powered_exponential.py +0 -0
  151. {gpjax-0.9.3 → gpjax-0.9.5}/gpjax/kernels/stationary/rational_quadratic.py +0 -0
  152. {gpjax-0.9.3 → gpjax-0.9.5}/gpjax/kernels/stationary/rbf.py +0 -0
  153. {gpjax-0.9.3 → gpjax-0.9.5}/gpjax/kernels/stationary/utils.py +0 -0
  154. {gpjax-0.9.3 → gpjax-0.9.5}/gpjax/kernels/stationary/white.py +0 -0
  155. {gpjax-0.9.3 → gpjax-0.9.5}/gpjax/likelihoods.py +0 -0
  156. {gpjax-0.9.3 → gpjax-0.9.5}/gpjax/lower_cholesky.py +0 -0
  157. {gpjax-0.9.3 → gpjax-0.9.5}/gpjax/mean_functions.py +0 -0
  158. {gpjax-0.9.3 → gpjax-0.9.5}/gpjax/objectives.py +0 -0
  159. {gpjax-0.9.3 → gpjax-0.9.5}/gpjax/parameters.py +0 -0
  160. {gpjax-0.9.3 → gpjax-0.9.5}/gpjax/scan.py +0 -0
  161. {gpjax-0.9.3 → gpjax-0.9.5}/gpjax/typing.py +0 -0
  162. {gpjax-0.9.3 → gpjax-0.9.5}/static/CONTRIBUTING.md +0 -0
  163. {gpjax-0.9.3 → gpjax-0.9.5}/static/paper.bib +0 -0
  164. {gpjax-0.9.3 → gpjax-0.9.5}/static/paper.md +0 -0
  165. {gpjax-0.9.3 → gpjax-0.9.5}/static/paper.pdf +0 -0
  166. {gpjax-0.9.3 → gpjax-0.9.5}/tests/__init__.py +0 -0
  167. {gpjax-0.9.3 → gpjax-0.9.5}/tests/conftest.py +0 -0
  168. {gpjax-0.9.3 → gpjax-0.9.5}/tests/integration_tests.py +0 -0
  169. {gpjax-0.9.3 → gpjax-0.9.5}/tests/test_gaussian_distribution.py +0 -0
  170. {gpjax-0.9.3 → gpjax-0.9.5}/tests/test_integrators.py +0 -0
  171. {gpjax-0.9.3 → gpjax-0.9.5}/tests/test_kernels/__init__.py +0 -0
  172. {gpjax-0.9.3 → gpjax-0.9.5}/tests/test_kernels/test_approximations.py +0 -0
  173. {gpjax-0.9.3 → gpjax-0.9.5}/tests/test_kernels/test_base.py +0 -0
  174. {gpjax-0.9.3 → gpjax-0.9.5}/tests/test_kernels/test_computation.py +0 -0
  175. {gpjax-0.9.3 → gpjax-0.9.5}/tests/test_kernels/test_non_euclidean.py +0 -0
  176. {gpjax-0.9.3 → gpjax-0.9.5}/tests/test_kernels/test_nonstationary.py +0 -0
  177. {gpjax-0.9.3 → gpjax-0.9.5}/tests/test_kernels/test_stationary.py +0 -0
  178. {gpjax-0.9.3 → gpjax-0.9.5}/tests/test_kernels/test_utils.py +0 -0
  179. {gpjax-0.9.3 → gpjax-0.9.5}/tests/test_likelihoods.py +0 -0
  180. {gpjax-0.9.3 → gpjax-0.9.5}/tests/test_lower_cholesky.py +0 -0
  181. {gpjax-0.9.3 → gpjax-0.9.5}/tests/test_markdown.py +0 -0
  182. {gpjax-0.9.3 → gpjax-0.9.5}/tests/test_mean_functions.py +0 -0
  183. {gpjax-0.9.3 → gpjax-0.9.5}/tests/test_objectives.py +0 -0
  184. {gpjax-0.9.3 → gpjax-0.9.5}/tests/test_variational_families.py +0 -0
@@ -0,0 +1,19 @@
1
+ (C) Copyright 2019 Hewlett Packard Enterprise Development LP
2
+
3
+ Permission is hereby granted, free of charge, to any person obtaining a
4
+ copy of this software and associated documentation files (the "Software"),
5
+ to deal in the Software without restriction, including without limitation
6
+ the rights to use, copy, modify, merge, publish, distribute, sublicense,
7
+ and/or sell copies of the Software, and to permit persons to whom the
8
+ Software is furnished to do so, subject to the following conditions:
9
+
10
+ The above copyright notice and this permission notice shall be included
11
+ in all copies or substantial portions of the Software.
12
+
13
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
16
+ THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR
17
+ OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
18
+ ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
19
+ OTHER DEALINGS IN THE SOFTWARE.
@@ -1,13 +1,13 @@
1
- Metadata-Version: 2.3
1
+ Metadata-Version: 2.4
2
2
  Name: gpjax
3
- Version: 0.9.3
3
+ Version: 0.9.5
4
4
  Summary: Gaussian processes in JAX.
5
5
  Project-URL: Documentation, https://docs.jaxgaussianprocesses.com/
6
6
  Project-URL: Issues, https://github.com/JaxGaussianProcesses/GPJax/issues
7
7
  Project-URL: Source, https://github.com/JaxGaussianProcesses/GPJax
8
8
  Author-email: Thomas Pinder <tompinder@live.co.uk>
9
- License-Expression: Apache-2.0
10
- License-File: LICENSE
9
+ License: MIT
10
+ License-File: LICENSE.txt
11
11
  Keywords: gaussian-processes jax machine-learning bayesian
12
12
  Classifier: Development Status :: 4 - Beta
13
13
  Classifier: Programming Language :: Python
@@ -19,10 +19,9 @@ Classifier: Programming Language :: Python :: Implementation :: PyPy
19
19
  Requires-Python: <3.13,>=3.10
20
20
  Requires-Dist: beartype>0.16.1
21
21
  Requires-Dist: cola-ml==0.0.5
22
- Requires-Dist: flax>=0.8.5
22
+ Requires-Dist: flax<0.10.0
23
23
  Requires-Dist: jax<0.4.28
24
24
  Requires-Dist: jaxlib<0.4.28
25
- Requires-Dist: jaxopt==0.8.2
26
25
  Requires-Dist: jaxtyping>0.2.10
27
26
  Requires-Dist: numpy<2.0.0
28
27
  Requires-Dist: optax>0.2.1
@@ -103,23 +102,23 @@ helped to shape GPJax into the package it is today.
103
102
 
104
103
  ## Notebook examples
105
104
 
106
- > - [**Conjugate Inference**](https://docs.jaxgaussianprocesses.com/examples/regression/)
107
- > - [**Classification**](https://docs.jaxgaussianprocesses.com/examples/classification/)
108
- > - [**Sparse Variational Inference**](https://docs.jaxgaussianprocesses.com/examples/collapsed_vi/)
109
- > - [**Stochastic Variational Inference**](https://docs.jaxgaussianprocesses.com/examples/uncollapsed_vi/)
110
- > - [**Laplace Approximation**](https://docs.jaxgaussianprocesses.com/examples/classification/#laplace-approximation)
111
- > - [**Inference on Non-Euclidean Spaces**](https://docs.jaxgaussianprocesses.com/examples/constructing_new_kernels/#custom-kernel)
112
- > - [**Inference on Graphs**](https://docs.jaxgaussianprocesses.com/examples/graph_kernels/)
113
- > - [**Pathwise Sampling**](https://docs.jaxgaussianprocesses.com/examples/spatial/)
114
- > - [**Learning Gaussian Process Barycentres**](https://docs.jaxgaussianprocesses.com/examples/barycentres/)
115
- > - [**Deep Kernel Regression**](https://docs.jaxgaussianprocesses.com/examples/deep_kernels/)
116
- > - [**Poisson Regression**](https://docs.jaxgaussianprocesses.com/examples/poisson/)
117
- > - [**Bayesian Optimisation**](https://docs.jaxgaussianprocesses.com/examples/bayesian_optimisation/)
105
+ > - [**Conjugate Inference**](https://docs.jaxgaussianprocesses.com/_examples/regression/)
106
+ > - [**Classification**](https://docs.jaxgaussianprocesses.com/_examples/classification/)
107
+ > - [**Sparse Variational Inference**](https://docs.jaxgaussianprocesses.com/_examples/collapsed_vi/)
108
+ > - [**Stochastic Variational Inference**](https://docs.jaxgaussianprocesses.com/_examples/uncollapsed_vi/)
109
+ > - [**Laplace Approximation**](https://docs.jaxgaussianprocesses.com/_examples/classification/#laplace-approximation)
110
+ > - [**Inference on Non-Euclidean Spaces**](https://docs.jaxgaussianprocesses.com/_examples/constructing_new_kernels/#custom-kernel)
111
+ > - [**Inference on Graphs**](https://docs.jaxgaussianprocesses.com/_examples/graph_kernels/)
112
+ > - [**Pathwise Sampling**](https://docs.jaxgaussianprocesses.com/_examples/spatial/)
113
+ > - [**Learning Gaussian Process Barycentres**](https://docs.jaxgaussianprocesses.com/_examples/barycentres/)
114
+ > - [**Deep Kernel Regression**](https://docs.jaxgaussianprocesses.com/_examples/deep_kernels/)
115
+ > - [**Poisson Regression**](https://docs.jaxgaussianprocesses.com/_examples/poisson/)
116
+ > - [**Bayesian Optimisation**](https://docs.jaxgaussianprocesses.com/_examples/bayesian_optimisation/)
118
117
 
119
118
  ## Guides for customisation
120
119
  >
121
- > - [**Custom kernels**](https://docs.jaxgaussianprocesses.com/examples/constructing_new_kernels/#custom-kernel)
122
- > - [**UCI regression**](https://docs.jaxgaussianprocesses.com/examples/yacht/)
120
+ > - [**Custom kernels**](https://docs.jaxgaussianprocesses.com/_examples/constructing_new_kernels/#custom-kernel)
121
+ > - [**UCI regression**](https://docs.jaxgaussianprocesses.com/_examples/yacht/)
123
122
 
124
123
  ## Conversion between `.ipynb` and `.py`
125
124
  Above examples are stored in [examples](docs/examples) directory in the double
@@ -180,7 +179,7 @@ optimiser = ox.adam(learning_rate=1e-2)
180
179
  # Obtain Type 2 MLEs of the hyperparameters
181
180
  opt_posterior, history = gpx.fit(
182
181
  model=posterior,
183
- objective=gpx.objectives.conjugate_mll,
182
+ objective=lambda p, d: -gpx.objectives.conjugate_mll(p, d),
184
183
  train_data=D,
185
184
  optim=optimiser,
186
185
  num_iters=500,
@@ -71,23 +71,23 @@ helped to shape GPJax into the package it is today.
71
71
 
72
72
  ## Notebook examples
73
73
 
74
- > - [**Conjugate Inference**](https://docs.jaxgaussianprocesses.com/examples/regression/)
75
- > - [**Classification**](https://docs.jaxgaussianprocesses.com/examples/classification/)
76
- > - [**Sparse Variational Inference**](https://docs.jaxgaussianprocesses.com/examples/collapsed_vi/)
77
- > - [**Stochastic Variational Inference**](https://docs.jaxgaussianprocesses.com/examples/uncollapsed_vi/)
78
- > - [**Laplace Approximation**](https://docs.jaxgaussianprocesses.com/examples/classification/#laplace-approximation)
79
- > - [**Inference on Non-Euclidean Spaces**](https://docs.jaxgaussianprocesses.com/examples/constructing_new_kernels/#custom-kernel)
80
- > - [**Inference on Graphs**](https://docs.jaxgaussianprocesses.com/examples/graph_kernels/)
81
- > - [**Pathwise Sampling**](https://docs.jaxgaussianprocesses.com/examples/spatial/)
82
- > - [**Learning Gaussian Process Barycentres**](https://docs.jaxgaussianprocesses.com/examples/barycentres/)
83
- > - [**Deep Kernel Regression**](https://docs.jaxgaussianprocesses.com/examples/deep_kernels/)
84
- > - [**Poisson Regression**](https://docs.jaxgaussianprocesses.com/examples/poisson/)
85
- > - [**Bayesian Optimisation**](https://docs.jaxgaussianprocesses.com/examples/bayesian_optimisation/)
74
+ > - [**Conjugate Inference**](https://docs.jaxgaussianprocesses.com/_examples/regression/)
75
+ > - [**Classification**](https://docs.jaxgaussianprocesses.com/_examples/classification/)
76
+ > - [**Sparse Variational Inference**](https://docs.jaxgaussianprocesses.com/_examples/collapsed_vi/)
77
+ > - [**Stochastic Variational Inference**](https://docs.jaxgaussianprocesses.com/_examples/uncollapsed_vi/)
78
+ > - [**Laplace Approximation**](https://docs.jaxgaussianprocesses.com/_examples/classification/#laplace-approximation)
79
+ > - [**Inference on Non-Euclidean Spaces**](https://docs.jaxgaussianprocesses.com/_examples/constructing_new_kernels/#custom-kernel)
80
+ > - [**Inference on Graphs**](https://docs.jaxgaussianprocesses.com/_examples/graph_kernels/)
81
+ > - [**Pathwise Sampling**](https://docs.jaxgaussianprocesses.com/_examples/spatial/)
82
+ > - [**Learning Gaussian Process Barycentres**](https://docs.jaxgaussianprocesses.com/_examples/barycentres/)
83
+ > - [**Deep Kernel Regression**](https://docs.jaxgaussianprocesses.com/_examples/deep_kernels/)
84
+ > - [**Poisson Regression**](https://docs.jaxgaussianprocesses.com/_examples/poisson/)
85
+ > - [**Bayesian Optimisation**](https://docs.jaxgaussianprocesses.com/_examples/bayesian_optimisation/)
86
86
 
87
87
  ## Guides for customisation
88
88
  >
89
- > - [**Custom kernels**](https://docs.jaxgaussianprocesses.com/examples/constructing_new_kernels/#custom-kernel)
90
- > - [**UCI regression**](https://docs.jaxgaussianprocesses.com/examples/yacht/)
89
+ > - [**Custom kernels**](https://docs.jaxgaussianprocesses.com/_examples/constructing_new_kernels/#custom-kernel)
90
+ > - [**UCI regression**](https://docs.jaxgaussianprocesses.com/_examples/yacht/)
91
91
 
92
92
  ## Conversion between `.ipynb` and `.py`
93
93
  Above examples are stored in [examples](docs/examples) directory in the double
@@ -148,7 +148,7 @@ optimiser = ox.adam(learning_rate=1e-2)
148
148
  # Obtain Type 2 MLEs of the hyperparameters
149
149
  opt_posterior, history = gpx.fit(
150
150
  model=posterior,
151
- objective=gpx.objectives.conjugate_mll,
151
+ objective=lambda p, d: -gpx.objectives.conjugate_mll(p, d),
152
152
  train_data=D,
153
153
  optim=optimiser,
154
154
  num_iters=500,
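The README change above (mirrored in PKG-INFO) switches `gpx.fit` to take an explicit loss callable, negating `conjugate_mll` so the optimiser minimises it. A minimal sketch of the updated usage, assembled from the README fragments visible in this diff; the toy dataset, RBF prior, and `key` argument are illustrative assumptions rather than part of the diff:

```python
import jax.numpy as jnp
import jax.random as jr
import optax as ox
import gpjax as gpx

key = jr.key(123)
x = jnp.linspace(0.0, 10.0, 100).reshape(-1, 1)
y = jnp.sin(x) + 0.1 * jr.normal(key, shape=x.shape)
D = gpx.Dataset(X=x, y=y)

# Prior and conjugate posterior, as in the README example.
prior = gpx.gps.Prior(mean_function=gpx.mean_functions.Zero(), kernel=gpx.kernels.RBF())
posterior = prior * gpx.likelihoods.Gaussian(num_datapoints=D.n)

# 0.9.5 convention: the objective is a loss to be minimised, hence the negated MLL.
opt_posterior, history = gpx.fit(
    model=posterior,
    objective=lambda p, d: -gpx.objectives.conjugate_mll(p, d),
    train_data=D,
    optim=ox.adam(learning_rate=1e-2),
    num_iters=500,
    key=key,
)
```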
@@ -1,4 +1,4 @@
1
- # Welcome to GPJax!
1
+ # Welcome to GPJax
2
2
 
3
3
  GPJax is a didactic Gaussian process (GP) library in JAX, supporting GPU
4
4
  acceleration and just-in-time compilation. We seek to provide a flexible
@@ -6,7 +6,6 @@ API to enable researchers to rapidly prototype and develop new ideas.
6
6
 
7
7
  ![Gaussian process posterior.](static/GP.svg)
8
8
 
9
-
10
9
  ## "Hello, GP!"
11
10
 
12
11
  Typing GP models is as simple as the maths we
@@ -53,7 +52,7 @@ would write on paper, as shown below.
53
52
  !!! Begin
54
53
 
55
54
  Looking for a good place to start? Then why not begin with our [regression
56
- notebook](https://docs.jaxgaussianprocesses.com/examples/regression/).
55
+ notebook](https://docs.jaxgaussianprocesses.com/_examples/regression/).
57
56
 
58
57
  ## Citing GPJax
59
58
 
@@ -122,7 +122,7 @@ print(constant_param._tag)
122
122
  # For most users, you will not need to worry about this as we provide a set of default
123
123
  # bijectors that are defined for all the parameter types we support. However, see our
124
124
  # [Kernel Guide
125
- # Notebook](https://docs.jaxgaussianprocesses.com/examples/constructing_new_kernels/) to
125
+ # Notebook](https://docs.jaxgaussianprocesses.com/_examples/constructing_new_kernels/) to
126
126
  # see how you can define your own bijectors and parameter types.
127
127
 
128
128
  # %%
@@ -156,7 +156,7 @@ transform(_close_to_zero_state, DEFAULT_BIJECTION, inverse=True)
156
156
  # may be nested within several functions e.g., a kernel function within a GP model.
157
157
  # Fortunately, transforming several parameters is a simple operation that we here
158
158
  # demonstrate for a conjugate GP posterior (see our [Regression
159
- # Notebook](https://docs.jaxgaussianprocesses.com/examples/regression/) for detailed
159
+ # Notebook](https://docs.jaxgaussianprocesses.com/_examples/regression/) for detailed
160
160
  # explanation of this model.).
161
161
 
162
162
  # %%
@@ -239,7 +239,7 @@ print(positive_reals)
239
239
  # useful as it allows us to efficiently operate on a subset of the parameters whilst
240
240
  # leaving the others untouched. Looking forward, we hope to use this functionality in
241
241
  # our [Variational Inference
242
- # Approximations](https://docs.jaxgaussianprocesses.com/examples/uncollapsed_vi/) to
242
+ # Approximations](https://docs.jaxgaussianprocesses.com/_examples/uncollapsed_vi/) to
243
243
  # perform more efficient updates of the variational parameters and then the model's
244
244
  # hyperparameters.
245
245
 
@@ -361,7 +361,7 @@ ax.set(xlabel="x", ylabel="m(x)")
361
361
  # In this notebook we have explored how GPJax's Flax-based backend may be easily
362
362
  # manipulated and extended. For a more applied look at this, see how we construct a
363
363
  # kernel on polar coordinates in our [Kernel
364
- # Guide](https://docs.jaxgaussianprocesses.com/examples/constructing_new_kernels/#custom-kernel)
364
+ # Guide](https://docs.jaxgaussianprocesses.com/_examples/constructing_new_kernels/#custom-kernel)
365
365
  # notebook.
366
366
  #
367
367
  # ## System configuration
@@ -8,7 +8,7 @@
8
8
  # extension: .py
9
9
  # format_name: percent
10
10
  # format_version: '1.3'
11
- # jupytext_version: 1.16.4
11
+ # jupytext_version: 1.16.6
12
12
  # kernelspec:
13
13
  # display_name: gpjax
14
14
  # language: python
@@ -154,9 +154,9 @@ plt.show()
154
154
  # We'll now independently learn Gaussian process posterior distributions for each
155
155
  # dataset. We won't spend any time here discussing how GP hyperparameters are
156
156
  # optimised. For advice on achieving this, see the
157
- # [Regression notebook](https://docs.jaxgaussianprocesses.com/examples/regression/)
157
+ # [Regression notebook](https://docs.jaxgaussianprocesses.com/_examples/regression/)
158
158
  # for advice on optimisation and the
159
- # [Kernels notebook](https://docs.jaxgaussianprocesses.com/examples/constructing_new_kernels/) for
159
+ # [Kernels notebook](https://docs.jaxgaussianprocesses.com/_examples/constructing_new_kernels/) for
160
160
  # advice on selecting an appropriate kernel.
161
161
 
162
162
 
@@ -8,7 +8,7 @@
8
8
  # extension: .py
9
9
  # format_name: percent
10
10
  # format_version: '1.3'
11
- # jupytext_version: 1.16.4
11
+ # jupytext_version: 1.16.6
12
12
  # kernelspec:
13
13
  # display_name: gpjax
14
14
  # language: python
@@ -193,15 +193,20 @@ ax.legend()
193
193
  # $\boldsymbol{x}$, we can expand the log of this about the posterior mode
194
194
  # $\hat{\boldsymbol{f}}$ via a Taylor expansion. This gives:
195
195
  #
196
+ # $$
196
197
  # \begin{align}
197
198
  # \log\tilde{p}(\boldsymbol{f}|\mathcal{D}) = \log\tilde{p}(\hat{\boldsymbol{f}}|\mathcal{D}) + \left[\nabla \log\tilde{p}({\boldsymbol{f}}|\mathcal{D})|_{\hat{\boldsymbol{f}}}\right]^{T} (\boldsymbol{f}-\hat{\boldsymbol{f}}) + \frac{1}{2} (\boldsymbol{f}-\hat{\boldsymbol{f}})^{T} \left[\nabla^2 \tilde{p}(\boldsymbol{y}|\boldsymbol{f})|_{\hat{\boldsymbol{f}}} \right] (\boldsymbol{f}-\hat{\boldsymbol{f}}) + \mathcal{O}(\lVert \boldsymbol{f} - \hat{\boldsymbol{f}} \rVert^3).
198
199
  # \end{align}
200
+ # $$
199
201
  #
200
202
  # Since $\nabla \log\tilde{p}({\boldsymbol{f}}|\mathcal{D})$ is zero at the mode,
201
203
  # this suggests the following approximation
204
+ #
205
+ # $$
202
206
  # \begin{align}
203
207
  # \tilde{p}(\boldsymbol{f}|\mathcal{D}) \approx \log\tilde{p}(\hat{\boldsymbol{f}}|\mathcal{D}) \exp\left\{ \frac{1}{2} (\boldsymbol{f}-\hat{\boldsymbol{f}})^{T} \left[-\nabla^2 \tilde{p}(\boldsymbol{y}|\boldsymbol{f})|_{\hat{\boldsymbol{f}}} \right] (\boldsymbol{f}-\hat{\boldsymbol{f}}) \right\}
204
208
  # \end{align},
209
+ # $$
205
210
  #
206
211
  # that we identify as a Gaussian distribution,
207
212
  # $p(\boldsymbol{f}| \mathcal{D}) \approx q(\boldsymbol{f}) := \mathcal{N}(\hat{\boldsymbol{f}}, [-\nabla^2 \tilde{p}(\boldsymbol{y}|\boldsymbol{f})|_{\hat{\boldsymbol{f}}} ]^{-1} )$.
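The quadratic expansion in the notebook text above is the recipe behind a Laplace approximation: find the posterior mode, then take the inverse Hessian of the negative log posterior at that mode as the Gaussian covariance. A generic sketch of that recipe, not GPJax's own implementation; the gradient-descent mode search and step size are illustrative assumptions:

```python
import jax
import jax.numpy as jnp

def laplace_approximation(neg_log_posterior, f_init, num_steps=200, step_size=0.1):
    # Locate the mode f_hat of the (unnormalised) posterior by gradient descent.
    f = f_init
    grad_fn = jax.grad(neg_log_posterior)
    for _ in range(num_steps):
        f = f - step_size * grad_fn(f)
    # q(f) = N(f_hat, [∇²(-log p̃(f|D))|_{f_hat}]⁻¹), per the expansion above.
    hess = jax.hessian(neg_log_posterior)(f)
    return f, jnp.linalg.inv(hess)
```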
@@ -7,7 +7,7 @@
7
7
  # extension: .py
8
8
  # format_name: percent
9
9
  # format_version: '1.3'
10
- # jupytext_version: 1.16.4
10
+ # jupytext_version: 1.16.6
11
11
  # kernelspec:
12
12
  # display_name: gpjax_beartype
13
13
  # language: python
@@ -131,7 +131,7 @@ q = gpx.variational_families.CollapsedVariationalGaussian(
131
131
  # %% [markdown]
132
132
  # We now train our model akin to a Gaussian process regression model via the `fit`
133
133
  # abstraction. Unlike the regression example given in the
134
- # [conjugate regression notebook](https://docs.jaxgaussianprocesses.com/examples/regression/),
134
+ # [conjugate regression notebook](https://docs.jaxgaussianprocesses.com/_examples/regression/),
135
135
  # the inducing locations that induce our variational posterior distribution are now
136
136
  # part of the model's parameters. Using a gradient-based optimiser, we can then
137
137
  # _optimise_ their location such that the evidence lower bound is maximised.
@@ -8,7 +8,7 @@
8
8
  # extension: .py
9
9
  # format_name: percent
10
10
  # format_version: '1.3'
11
- # jupytext_version: 1.16.4
11
+ # jupytext_version: 1.16.6
12
12
  # kernelspec:
13
13
  # display_name: gpjax
14
14
  # language: python
@@ -71,7 +71,7 @@ cols = plt.rcParams["axes.prop_cycle"].by_key()["color"]
71
71
  # * White noise
72
72
  # * Linear.
73
73
  # * Polynomial.
74
- # * [Graph kernels](https://docs.jaxgaussianprocesses.com/examples/graph_kernels/).
74
+ # * [Graph kernels](https://docs.jaxgaussianprocesses.com/_examples/graph_kernels/).
75
75
  #
76
76
  # While the syntax is consistent, each kernel's type influences the
77
77
  # characteristics of the sample paths drawn. We visualise this below with 10
@@ -92,7 +92,7 @@ x = jnp.linspace(-3.0, 3.0, num=200).reshape(-1, 1)
92
92
 
93
93
  meanf = gpx.mean_functions.Zero()
94
94
 
95
- for k, ax, c in zip(kernels, axes.ravel(), cols):
95
+ for k, ax, c in zip(kernels, axes.ravel(), cols, strict=False):
96
96
  prior = gpx.gps.Prior(mean_function=meanf, kernel=k)
97
97
  rv = prior(x)
98
98
  y = rv.sample(seed=key, sample_shape=(10,))
@@ -185,7 +185,7 @@ fig.colorbar(im3, ax=ax[3], fraction=0.05)
185
185
  # We'll demonstrate this process now for a circular kernel --- an adaption of
186
186
  # the excellent guide given in the PYMC3 documentation. We encourage curious
187
187
  # readers to visit their notebook
188
- # [here](https://www.pymc.io/projects/docs/en/v3/pymc-examples/examples/gaussian_processes/GP-Circular.html).
188
+ # [here](https://www.pymc.io/projects/docs/en/v3/pymc-_examples/_examples/gaussian_processes/GP-Circular.html).
189
189
  #
190
190
  # ### Circular kernel
191
191
  #
@@ -198,9 +198,15 @@ fig.colorbar(im3, ax=ax[3], fraction=0.05)
198
198
  # kernels do not exhibit this behaviour and instead _wrap_ around the boundary
199
199
  # points to create a smooth function. Such a kernel was given in [Padonou &
200
200
  # Roustant (2015)](https://hal.inria.fr/hal-01119942v1) where any two angles
201
- # $\theta$ and $\theta'$ are written as $$W_c(\theta, \theta') = \left\lvert
201
+ # $\theta$ and $\theta'$ are written as
202
+ #
203
+ # $$
204
+ # \begin{align}
205
+ # W_c(\theta, \theta') & = \left\lvert
202
206
  # \left(1 + \tau \frac{d(\theta, \theta')}{c} \right) \left(1 - \frac{d(\theta,
203
- # \theta')}{c} \right)^{\tau} \right\rvert \quad \tau \geq 4 \tag{1}.$$
207
+ # \theta')}{c} \right)^{\tau} \right\rvert \quad \tau \geq 4 \tag{1}.
208
+ # \end{align}
209
+ # $$
204
210
  #
205
211
  # Here the hyperparameter $\tau$ is analogous to a lengthscale for Euclidean
206
212
  # stationary kernels, controlling the correlation between pairs of observations.
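Equation (1) from the notebook text above is simple to evaluate directly. A rough numerical sketch, assuming the angular distance d(θ, θ') wraps at the period c; the helper below is illustrative and is not the notebook's `Polar` kernel class:

```python
import jax.numpy as jnp

def angular_distance(x, y, c):
    # Shortest wrapped distance between two angles, in [0, c].
    return jnp.abs((x - y + c) % (c * 2) - c)

def polar_covariance(x, y, tau=4.0, c=jnp.pi):
    # Equation (1): |(1 + tau d/c)(1 - d/c)^tau|, with tau >= 4.
    d = angular_distance(x, y, c)
    return jnp.abs((1.0 + tau * d / c) * jnp.clip(1.0 - d / c, 0.0) ** tau)

# Points either side of the boundary remain highly correlated.
print(polar_covariance(jnp.array(0.1), jnp.array(2.0 * jnp.pi - 0.1)))
```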
@@ -266,7 +272,7 @@ class Polar(gpx.kernels.AbstractKernel):
266
272
  #
267
273
  # We proceed to fit a GP with our custom circular kernel to a random sequence of
268
274
  # points on a circle (see the
269
- # [Regression notebook](https://docs.jaxgaussianprocesses.com/examples/regression/)
275
+ # [Regression notebook](https://docs.jaxgaussianprocesses.com/_examples/regression/)
270
276
  # for further details on this process).
271
277
 
272
278
  # %%
@@ -8,7 +8,7 @@
8
8
  # extension: .py
9
9
  # format_name: percent
10
10
  # format_version: '1.3'
11
- # jupytext_version: 1.16.4
11
+ # jupytext_version: 1.16.6
12
12
  # kernelspec:
13
13
  # display_name: gpjax
14
14
  # language: python
@@ -141,7 +141,7 @@ class DeepKernelFunction(AbstractKernel):
141
141
  # activation functions between the layers. The first hidden layer contains 64 units,
142
142
  # while the second layer contains 32 units. Finally, we'll make the output of our
143
143
  # network a three units wide. The corresponding kernel that we define will then be of
144
- # [ARD form](https://docs.jaxgaussianprocesses.com/examples/constructing_new_kernels/#active-dimensions)
144
+ # [ARD form](https://docs.jaxgaussianprocesses.com/_examples/constructing_new_kernels/#active-dimensions)
145
145
  # to allow for different lengthscales in each dimension of the feature space.
146
146
  # Users may wish to design more intricate network structures for more complex tasks,
147
147
  # which functionality is supported well in Haiku.
@@ -8,7 +8,7 @@
8
8
  # extension: .py
9
9
  # format_name: percent
10
10
  # format_version: '1.3'
11
- # jupytext_version: 1.16.4
11
+ # jupytext_version: 1.16.6
12
12
  # kernelspec:
13
13
  # display_name: gpjax
14
14
  # language: python
@@ -22,7 +22,7 @@
22
22
  # of a graph using a Gaussian process with a Matérn kernel presented in
23
23
  # <strong data-cite="borovitskiy2021matern"></strong>. For a general discussion of the
24
24
  # kernels supported within GPJax, see the
25
- # [kernels notebook](https://docs.jaxgaussianprocesses.com/examples/constructing_new_kernels).
25
+ # [kernels notebook](https://docs.jaxgaussianprocesses.com/_examples/constructing_new_kernels).
26
26
 
27
27
  # %%
28
28
  import random
@@ -88,7 +88,9 @@ nx.draw(
88
88
  #
89
89
  # Graph kernels use the _Laplacian matrix_ $L$ to quantify the smoothness of a signal
90
90
  # (or function) on a graph
91
+ #
91
92
  # $$L=D-A,$$
93
+ #
92
94
  # where $D$ is the diagonal _degree matrix_ containing each vertices' degree and $A$
93
95
  # is the _adjacency matrix_ that has an $(i,j)^{\text{th}}$ entry of 1 if $v_i, v_j$
94
96
  # are connected and 0 otherwise. [Networkx](https://networkx.org) gives us an easy
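The relation L = D - A described in the notebook text above can be checked in a few lines; a quick sketch on an arbitrary toy graph (the graph choice is purely illustrative):

```python
import networkx as nx
import numpy as np

G = nx.barbell_graph(3, 0)            # any small undirected graph will do
A = nx.to_numpy_array(G)              # adjacency matrix
D = np.diag(A.sum(axis=1))            # degree matrix
L = D - A

# Agrees with networkx's own Laplacian.
assert np.allclose(L, nx.laplacian_matrix(G).toarray())
```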
@@ -151,7 +153,7 @@ cbar = plt.colorbar(sm, ax=ax)
151
153
  # non-Euclidean, our likelihood is still Gaussian and the model is still conjugate.
152
154
  # For this reason, we simply perform gradient descent on the GP's marginal
153
155
  # log-likelihood term as in the
154
- # [regression notebook](https://docs.jaxgaussianprocesses.com/examples/regression/).
156
+ # [regression notebook](https://docs.jaxgaussianprocesses.com/_examples/regression/).
155
157
  # We do this using the BFGS optimiser.
156
158
 
157
159
  # %%
@@ -7,7 +7,7 @@
7
7
  # extension: .py
8
8
  # format_name: percent
9
9
  # format_version: '1.3'
10
- # jupytext_version: 1.16.4
10
+ # jupytext_version: 1.16.6
11
11
  # kernelspec:
12
12
  # display_name: gpjax
13
13
  # language: python
@@ -17,6 +17,7 @@
17
17
  # %% [markdown]
18
18
  # # New to Gaussian Processes?
19
19
  #
20
+ #
20
21
  # Fantastic that you're here! This notebook is designed to be a gentle
21
22
  # introduction to the mathematics of Gaussian processes (GPs). No prior
22
23
  # knowledge of Bayesian inference or GPs is assumed, and this notebook is
@@ -33,10 +34,11 @@
33
34
  # model are unknown, and our goal is to conduct inference to determine their
34
35
  # range of likely values. To achieve this, we apply Bayes' theorem
35
36
  #
37
+ # $$
36
38
  # \begin{align}
37
- # \label{eq:BayesTheorem}
38
- # p(\theta\,|\, \mathbf{y}) = \frac{p(\theta)p(\mathbf{y}\,|\,\theta)}{p(\mathbf{y})} = \frac{p(\theta)p(\mathbf{y}\,|\,\theta)}{\int_{\theta}p(\mathbf{y}, \theta)\mathrm{d}\theta}\,,
39
+ # p(\theta\mid\mathbf{y}) = \frac{p(\theta)p(\mathbf{y}\mid\theta)}{p(\mathbf{y})} = \frac{p(\theta)p(\mathbf{y}\mid\theta)}{\int_{\theta}p(\mathbf{y}, \theta)\mathrm{d}\theta},
39
40
  # \end{align}
41
+ # $$
40
42
  #
41
43
  # where $p(\mathbf{y}\,|\,\theta)$ denotes the _likelihood_, or model, and
42
44
  # quantifies how likely the observed dataset $\mathbf{y}$ is, given the
@@ -58,7 +60,7 @@
58
60
  # family, then there exists a conjugate prior. However, the conjugate prior may
59
61
  # not have a form that precisely reflects the practitioner's belief surrounding
60
62
  # the parameter. For this reason, conjugate models seldom appear; one exception
61
- # to this is GP regression that we present fully in our [Regression notebook](https://docs.jaxgaussianprocesses.com/examples/regression/).
63
+ # to this is GP regression that we present fully in our [Regression notebook](https://docs.jaxgaussianprocesses.com/_examples/regression/).
62
64
  #
63
65
  # For models that do not contain a conjugate prior, the marginal log-likelihood
64
66
  # must be calculated to normalise the posterior distribution and ensure it
@@ -74,9 +76,13 @@
74
76
  # new points $\mathbf{y}^{\star}$ through the _posterior predictive
75
77
  # distribution_. This is achieved by integrating out the parameter set $\theta$
76
78
  # from our posterior distribution through
79
+ #
80
+ # $$
77
81
  # \begin{align}
78
82
  # p(\mathbf{y}^{\star}\mid \mathbf{y}) = \int p(\mathbf{y}^{\star} \,|\, \theta, \mathbf{y} ) p(\theta\,|\, \mathbf{y})\mathrm{d}\theta\,.
79
83
  # \end{align}
84
+ # $$
85
+ #
80
86
  # As with the marginal log-likelihood, evaluating this quantity requires
81
87
  # computing an integral which may not be tractable, particularly when $\theta$
82
88
  # is high-dimensional.
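When the predictive integral above is intractable, a common fallback is a Monte Carlo estimate over posterior samples. A sketch, assuming hypothetical callables `sample_posterior(key, n)` and `loglik(y_star, theta)` are available:

```python
import jax
import jax.numpy as jnp

def predictive_density_mc(key, sample_posterior, loglik, y_star, num_samples=1000):
    # p(y* | y) ≈ (1/S) Σ_s p(y* | θ_s), with θ_s ~ p(θ | y).
    thetas = sample_posterior(key, num_samples)
    log_terms = jax.vmap(lambda theta: loglik(y_star, theta))(thetas)
    # Average in log-space for numerical stability.
    return jnp.exp(jax.scipy.special.logsumexp(log_terms) - jnp.log(num_samples))
```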
@@ -85,13 +91,16 @@
85
91
  # distribution, so we often compute and report moments of the posterior
86
92
  # distribution. Most commonly, we report the first moment and the centred second
87
93
  # moment
94
+ #
88
95
  # $$
89
96
  # \begin{alignat}{2}
90
- # \mu = \mathbb{E}[\theta\,|\,\mathbf{y}] & = \int \theta p(\theta\mid\mathbf{y})\mathrm{d}\theta\\
97
+ # \mu = \mathbb{E}[\theta\,|\,\mathbf{y}] & = \int \theta
98
+ # p(\theta\mid\mathbf{y})\mathrm{d}\theta \quad \\
91
99
  # \sigma^2 = \mathbb{V}[\theta\,|\,\mathbf{y}] & = \int \left(\theta -
92
100
  # \mathbb{E}[\theta\,|\,\mathbf{y}]\right)^2p(\theta\,|\,\mathbf{y})\mathrm{d}\theta&\,.
93
101
  # \end{alignat}
94
102
  # $$
103
+ #
95
104
  # Through this pair of statistics, we can communicate our beliefs about the most
96
105
  # likely value of $\theta$ i.e., $\mu$, and the uncertainty $\sigma$ around the
97
106
  # expected value. However, as with the marginal log-likelihood and predictive
@@ -205,13 +214,11 @@ titles = [r"$\rho = 0$", r"$\rho = 0.9$", r"$\rho = -0.5$"]
205
214
 
206
215
  cmap = mpl.colors.LinearSegmentedColormap.from_list("custom", ["white", cols[1]], N=256)
207
216
 
208
- for a, t, d in zip([ax0, ax1, ax2], titles, dists):
217
+ for a, t, d in zip([ax0, ax1, ax2], titles, dists, strict=False):
209
218
  d_prob = d.prob(jnp.hstack([xx.reshape(-1, 1), yy.reshape(-1, 1)])).reshape(
210
219
  xx.shape
211
220
  )
212
- cntf = a.contourf(xx, yy, jnp.exp(d_prob), levels=20, antialiased=True, cmap=cmap)
213
- for c in cntf.collections:
214
- c.set_edgecolor("face")
221
+ cntf = a.contourf(xx, yy, jnp.exp(d_prob), levels=20, antialiased=True, cmap=cmap, edgecolor="face")
215
222
  a.set_xlim(-2.75, 2.75)
216
223
  a.set_ylim(-2.75, 2.75)
217
224
  samples = d.sample(seed=key, sample_shape=(5000,))
@@ -228,13 +235,16 @@ for a, t, d in zip([ax0, ax1, ax2], titles, dists):
228
235
  # %% [markdown]
229
236
  # Extending the intuition given for the moments of a univariate Gaussian random
230
237
  # variables, we can obtain the mean and covariance by
238
+ #
231
239
  # $$
232
240
  # \begin{align}
233
- # \mathbb{E}[\mathbf{y}] = \mathbf{\mu}, \quad \operatorname{Cov}(\mathbf{y}) & = \mathbf{E}\left[(\mathbf{y} - \mathbf{\mu)}(\mathbf{y} - \mathbf{\mu)}^{\top} \right]\\
241
+ # \mathbb{E}[\mathbf{y}] & = \mathbf{\mu}, \\
242
+ # \operatorname{Cov}(\mathbf{y}) & = \mathbf{E}\left[(\mathbf{y} - \mathbf{\mu})(\mathbf{y} - \mathbf{\mu})^{\top} \right] \\
234
243
  # & =\mathbb{E}[\mathbf{y}\mathbf{y}^{\top}] - \mathbb{E}[\mathbf{y}]\mathbb{E}[\mathbf{y}]^{\top} \\
235
244
  # & =\mathbf{\Sigma}\,.
236
245
  # \end{align}
237
246
  # $$
247
+ #
238
248
  # The covariance matrix is a symmetric positive definite matrix that generalises
239
249
  # the notion of variance to multiple dimensions. The matrix's diagonal entries
240
250
  # contain the variance of each element, whilst the off-diagonal entries quantify
@@ -336,6 +346,7 @@ with warnings.catch_warnings():
336
346
  # $\mathbf{x}\sim\mathcal{N}(\boldsymbol{\mu}_{\mathbf{x}}, \boldsymbol{\Sigma}_{\mathbf{xx}})$ and
337
347
  # $\mathbf{y}\sim\mathcal{N}(\boldsymbol{\mu}_{\mathbf{y}}, \boldsymbol{\Sigma}_{\mathbf{yy}})$.
338
348
  # We define the joint distribution as
349
+ #
339
350
  # $$
340
351
  # \begin{align}
341
352
  # p\left(\begin{bmatrix}
@@ -348,6 +359,7 @@ with warnings.catch_warnings():
348
359
  # \end{bmatrix} \right)\,,
349
360
  # \end{align}
350
361
  # $$
362
+ #
351
363
  # where $\boldsymbol{\Sigma}_{\mathbf{x}\mathbf{y}}$ is the cross-covariance
352
364
  # matrix of $\mathbf{x}$ and $\mathbf{y}$.
353
365
  #
@@ -363,6 +375,7 @@ with warnings.catch_warnings():
363
375
  #
364
376
  # For a joint Gaussian random variable, the marginalisation of $\mathbf{x}$ or
365
377
  # $\mathbf{y}$ is given by
378
+ #
366
379
  # $$
367
380
  # \begin{alignat}{3}
368
381
  # & \int p(\mathbf{x}, \mathbf{y})\mathrm{d}\mathbf{y} && = p(\mathbf{x})
@@ -372,7 +385,9 @@ with warnings.catch_warnings():
372
385
  # \boldsymbol{\Sigma}_{\mathbf{yy}})\,.
373
386
  # \end{alignat}
374
387
  # $$
388
+ #
375
389
  # The conditional distributions are given by
390
+ #
376
391
  # $$
377
392
  # \begin{align}
378
393
  # p(\mathbf{y}\,|\, \mathbf{x}) & = \mathcal{N}\left(\boldsymbol{\mu}_{\mathbf{y}} + \boldsymbol{\Sigma}_{\mathbf{yx}}\boldsymbol{\Sigma}_{\mathbf{xx}}^{-1}(\mathbf{x}-\boldsymbol{\mu}_{\mathbf{x}}), \boldsymbol{\Sigma}_{\mathbf{yy}}-\boldsymbol{\Sigma}_{\mathbf{yx}}\boldsymbol{\Sigma}_{\mathbf{xx}}^{-1}\boldsymbol{\Sigma}_{\mathbf{xy}}\right)\,.
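The conditional distribution above is a direct matrix computation. A small sketch of the conditioning rule, using a plain matrix inverse for clarity where a library would use a Cholesky factorisation:

```python
import jax.numpy as jnp

def condition_gaussian(mu_x, mu_y, Sxx, Syy, Syx, x):
    # p(y | x) = N(mu_y + Syx Sxx⁻¹ (x - mu_x),  Syy - Syx Sxx⁻¹ Sxy)
    gain = Syx @ jnp.linalg.inv(Sxx)
    cond_mean = mu_y + gain @ (x - mu_x)
    cond_cov = Syy - gain @ Syx.T          # Sxy = Syxᵀ by symmetry
    return cond_mean, cond_cov
```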
@@ -401,6 +416,7 @@ with warnings.catch_warnings():
401
416
  # We aim to capture the relationship between $\mathbf{X}$ and $\mathbf{y}$ using
402
417
  # a model $f$ with which we may make predictions at an unseen set of test points
403
418
  # $\mathbf{X}^{\star}\subset\mathcal{X}$. We formalise this by
419
+ #
404
420
  # $$
405
421
  # \begin{align}
406
422
  # y = f(\mathbf{X}) + \varepsilon\,,
@@ -430,6 +446,7 @@ with warnings.catch_warnings():
430
446
  # convenience in the remainder of this article.
431
447
  #
432
448
  # We define a joint GP prior over the latent function
449
+ #
433
450
  # $$
434
451
  # \begin{align}
435
452
  # p(\mathbf{f}, \mathbf{f}^{\star}) = \mathcal{N}\left(\mathbf{0}, \begin{bmatrix}
@@ -437,14 +454,17 @@ with warnings.catch_warnings():
437
454
  # \end{bmatrix}\right)\,,
438
455
  # \end{align}
439
456
  # $$
457
+ #
440
458
  # where $\mathbf{f}^{\star} = f(\mathbf{X}^{\star})$. Conditional on the GP's
441
459
  # latent function $f$, we assume a factorising likelihood generates our
442
460
  # observations
461
+ #
443
462
  # $$
444
463
  # \begin{align}
445
464
  # p(\mathbf{y}\,|\,\mathbf{f}) = \prod_{i=1}^n p(y_i\,|\, f_i)\,.
446
465
  # \end{align}
447
466
  # $$
467
+ #
448
468
  # Strictly speaking, the likelihood function is
449
469
  # $p(\mathbf{y}\,|\,\phi(\mathbf{f}))$ where $\phi$ is the likelihood function's
450
470
  # associated link function. Example link functions include the probit or
@@ -453,7 +473,7 @@ with warnings.catch_warnings():
453
473
  # considers Gaussian likelihood functions where the role of $\phi$ is
454
474
  # superfluous. However, this intuition will be helpful for models with a
455
475
  # non-Gaussian likelihood, such as those encountered in
456
- # [classification](https://docs.jaxgaussianprocesses.com/examples/classification).
476
+ # [classification](https://docs.jaxgaussianprocesses.com/_examples/classification).
457
477
  #
458
478
  # Applying Bayes' theorem \eqref{eq:BayesTheorem} yields the joint posterior distribution over the
459
479
  # latent function
@@ -470,7 +490,7 @@ with warnings.catch_warnings():
470
490
  # function with parameters $\boldsymbol{\theta}$ that maps pairs of inputs
471
491
  # $\mathbf{X}, \mathbf{X}' \in \mathcal{X}$ onto the real line. We dedicate the
472
492
  # entirety of the [Introduction to Kernels
473
- # notebook](https://docs.jaxgaussianprocesses.com/examples/intro_to_kernels) to
493
+ # notebook](https://docs.jaxgaussianprocesses.com/_examples/intro_to_kernels) to
474
494
  # exploring the different GPs each kernel can yield.
475
495
  #
476
496
  # ## Gaussian process regression
@@ -479,20 +499,25 @@ with warnings.catch_warnings():
479
499
  # $p(y_i\,|\, f_i) = \mathcal{N}(y_i\,|\, f_i, \sigma_n^2)$,
480
500
  # marginalising $\mathbf{f}$ from the joint posterior to obtain
481
501
  # the posterior predictive distribution is exact
502
+ #
482
503
  # $$
483
504
  # \begin{align}
484
505
  # p(\mathbf{f}^{\star}\mid \mathbf{y}) = \mathcal{N}(\mathbf{f}^{\star}\,|\,\boldsymbol{\mu}_{\,|\,\mathbf{y}}, \Sigma_{\,|\,\mathbf{y}})\,,
485
506
  # \end{align}
486
507
  # $$
508
+ #
487
509
  # where
510
+ #
488
511
  # $$
489
512
  # \begin{align}
490
513
  # \mathbf{\mu}_{\mid \mathbf{y}} & = \mathbf{K}_{\star f}\left( \mathbf{K}_{ff}+\sigma^2_n\mathbf{I}_n\right)^{-1}\mathbf{y} \\
491
514
  # \Sigma_{\,|\,\mathbf{y}} & = \mathbf{K}_{\star\star} - \mathbf{K}_{xf}\left(\mathbf{K}_{ff} + \sigma_n^2\mathbf{I}_n\right)^{-1}\mathbf{K}_{fx} \,.
492
515
  # \end{align}
493
516
  # $$
517
+ #
494
518
  # Further, the log of the marginal likelihood of the GP can
495
519
  # be analytically expressed as
520
+ #
496
521
  # $$
497
522
  # \begin{align}
498
523
  # & = 0.5\left(-\underbrace{\mathbf{y}^{\top}\left(\mathbf{K}_{ff} + \sigma_n^2\mathbf{I}_n \right)^{-1}\mathbf{y}}_{\text{Data fit}} -\underbrace{\log\lvert \mathbf{K}_{ff} + \sigma^2_n\rvert}_{\text{Complexity}} -\underbrace{n\log 2\pi}_{\text{Constant}} \right)\,.
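The predictive moments and the decomposed marginal log-likelihood above translate almost line for line into code. A numerical sketch with an illustrative RBF kernel, not GPJax's implementation (which works with Cholesky factors rather than explicit solves of this form):

```python
import jax.numpy as jnp

def rbf(x1, x2, lengthscale=1.0, variance=1.0):
    sq_dist = ((x1[:, None, :] - x2[None, :, :]) ** 2).sum(-1)
    return variance * jnp.exp(-0.5 * sq_dist / lengthscale**2)

def predict_and_mll(X, y, Xstar, noise=0.1):
    n = X.shape[0]
    Kff = rbf(X, X) + noise**2 * jnp.eye(n)
    Ksf = rbf(Xstar, X)
    alpha = jnp.linalg.solve(Kff, y)                              # (Kff + σ²I)⁻¹ y
    mean = Ksf @ alpha                                            # predictive mean μ_{|y}
    cov = rbf(Xstar, Xstar) - Ksf @ jnp.linalg.solve(Kff, Ksf.T)  # predictive covariance Σ_{|y}
    mll = 0.5 * (-y @ alpha                                       # data fit
                 - jnp.linalg.slogdet(Kff)[1]                     # complexity
                 - n * jnp.log(2.0 * jnp.pi))                     # constant
    return mean, cov, mll
```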
@@ -505,6 +530,7 @@ with warnings.catch_warnings():
505
530
  # we call these terms the model hyperparameters
506
531
  # $\boldsymbol{\xi} = \{\boldsymbol{\theta},\sigma_n^2\}$
507
532
  # from which the maximum likelihood estimate is given by
533
+ #
508
534
  # $$
509
535
  # \begin{align*}
510
536
  # \boldsymbol{\xi}^{\star} = \operatorname{argmax}_{\boldsymbol{\xi} \in \Xi} \log p(\mathbf{y})\,.
@@ -532,7 +558,7 @@ with warnings.catch_warnings():
532
558
  # Bayes' theorem and the definition of a Gaussian random variable. Using the
533
559
  # ideas presented in this notebook, the user should be in a position to dive
534
560
  # into our [Regression
535
- # notebook](https://docs.jaxgaussianprocesses.com/examples/regression/) and
561
+ # notebook](https://docs.jaxgaussianprocesses.com/_examples/regression/) and
536
562
  # start getting their hands on some code. For those looking to learn more about
537
563
  # the underling theory of GPs, an excellent starting point is the [Gaussian
538
564
  # Processes for Machine Learning](http://gaussianprocess.org/gpml/) textbook.