gpjax 0.9.2__tar.gz → 0.9.4__tar.gz
This diff shows the changes between the two publicly released package versions as they appear in their public registry, and is provided for informational purposes only.
- {gpjax-0.9.2 → gpjax-0.9.4}/PKG-INFO +18 -18
- {gpjax-0.9.2 → gpjax-0.9.4}/README.md +15 -15
- {gpjax-0.9.2 → gpjax-0.9.4}/docs/index.md +2 -3
- {gpjax-0.9.2 → gpjax-0.9.4}/examples/backend.py +4 -4
- {gpjax-0.9.2 → gpjax-0.9.4}/examples/barycentres.py +3 -3
- {gpjax-0.9.2 → gpjax-0.9.4}/examples/bayesian_optimisation.py +3 -3
- {gpjax-0.9.2 → gpjax-0.9.4}/examples/classification.py +6 -1
- {gpjax-0.9.2 → gpjax-0.9.4}/examples/collapsed_vi.py +2 -2
- {gpjax-0.9.2 → gpjax-0.9.4}/examples/constructing_new_kernels.py +12 -6
- {gpjax-0.9.2 → gpjax-0.9.4}/examples/decision_making.py +5 -5
- {gpjax-0.9.2 → gpjax-0.9.4}/examples/deep_kernels.py +2 -2
- {gpjax-0.9.2 → gpjax-0.9.4}/examples/graph_kernels.py +5 -3
- {gpjax-0.9.2 → gpjax-0.9.4}/examples/intro_to_gps.py +38 -12
- {gpjax-0.9.2 → gpjax-0.9.4}/examples/intro_to_kernels.py +42 -21
- {gpjax-0.9.2 → gpjax-0.9.4}/examples/likelihoods_guide.py +5 -3
- {gpjax-0.9.2 → gpjax-0.9.4}/examples/oceanmodelling.py +6 -4
- {gpjax-0.9.2 → gpjax-0.9.4}/examples/poisson.py +12 -23
- {gpjax-0.9.2 → gpjax-0.9.4}/examples/regression.py +1 -1
- {gpjax-0.9.2 → gpjax-0.9.4}/examples/uncollapsed_vi.py +3 -4
- {gpjax-0.9.2 → gpjax-0.9.4}/examples/yacht.py +5 -5
- {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/__init__.py +1 -1
- {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/decision_making/test_functions/non_conjugate_functions.py +2 -2
- {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/gps.py +2 -1
- {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/likelihoods.py +3 -5
- {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/scan.py +10 -10
- {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/variational_families.py +33 -21
- {gpjax-0.9.2 → gpjax-0.9.4}/mkdocs.yml +2 -2
- {gpjax-0.9.2 → gpjax-0.9.4}/pyproject.toml +2 -2
- {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_decision_making/test_utility_functions/test_expected_improvement.py +1 -1
- {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_fit.py +2 -2
- {gpjax-0.9.2 → gpjax-0.9.4}/.github/CODE_OF_CONDUCT.md +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/.github/ISSUE_TEMPLATE/01_BUG_REPORT.md +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/.github/ISSUE_TEMPLATE/02_FEATURE_REQUEST.md +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/.github/ISSUE_TEMPLATE/03_CODEBASE_IMPROVEMENT.md +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/.github/ISSUE_TEMPLATE/04_DOCS_IMPROVEMENT.md +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/.github/ISSUE_TEMPLATE/config.yml +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/.github/codecov.yml +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/.github/labels.yml +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/.github/pull_request_template.md +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/.github/release-drafter.yml +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/.github/workflows/build_docs.yml +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/.github/workflows/integration.yml +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/.github/workflows/pr_greeting.yml +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/.github/workflows/ruff.yml +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/.github/workflows/stale_prs.yml +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/.github/workflows/test_docs.yml +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/.github/workflows/tests.yml +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/.gitignore +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/CITATION.bib +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/LICENSE +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/Makefile +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/docs/CODE_OF_CONDUCT.md +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/docs/GOVERNANCE.md +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/docs/contributing.md +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/docs/design.md +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/docs/index.rst +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/docs/installation.md +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/docs/javascripts/katex.js +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/docs/refs.bib +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/docs/scripts/gen_examples.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/docs/scripts/gen_pages.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/docs/scripts/notebook_converter.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/docs/scripts/sharp_bits_figure.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/docs/sharp_bits.md +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/docs/static/GP.pdf +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/docs/static/GP.svg +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/docs/static/bijector_figure.svg +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/docs/static/css/gpjax_theme.css +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/docs/static/favicon.ico +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/docs/static/gpjax.mplstyle +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/docs/static/gpjax_logo.pdf +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/docs/static/gpjax_logo.svg +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/docs/static/jaxkern/lato.ttf +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/docs/static/jaxkern/logo.png +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/docs/static/jaxkern/logo.svg +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/docs/static/jaxkern/main.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/docs/static/step_size_figure.png +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/docs/static/step_size_figure.svg +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/docs/stylesheets/extra.css +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/docs/stylesheets/permalinks.css +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/examples/barycentres/barycentre_gp.gif +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/examples/data/max_tempeature_switzerland.csv +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/examples/data/yacht_hydrodynamics.data +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/examples/gpjax.mplstyle +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/examples/intro_to_gps/decomposed_mll.png +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/examples/intro_to_gps/generating_process.png +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/examples/utils.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/citation.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/dataset.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/decision_making/__init__.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/decision_making/decision_maker.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/decision_making/posterior_handler.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/decision_making/search_space.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/decision_making/test_functions/__init__.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/decision_making/test_functions/continuous_functions.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/decision_making/utility_functions/__init__.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/decision_making/utility_functions/base.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/decision_making/utility_functions/expected_improvement.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/decision_making/utility_functions/probability_of_improvement.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/decision_making/utility_functions/thompson_sampling.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/decision_making/utility_maximizer.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/decision_making/utils.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/distributions.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/fit.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/integrators.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/kernels/__init__.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/kernels/approximations/__init__.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/kernels/approximations/rff.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/kernels/base.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/kernels/computations/__init__.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/kernels/computations/base.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/kernels/computations/basis_functions.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/kernels/computations/constant_diagonal.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/kernels/computations/dense.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/kernels/computations/diagonal.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/kernels/computations/eigen.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/kernels/non_euclidean/__init__.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/kernels/non_euclidean/graph.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/kernels/non_euclidean/utils.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/kernels/nonstationary/__init__.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/kernels/nonstationary/arccosine.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/kernels/nonstationary/linear.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/kernels/nonstationary/polynomial.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/kernels/stationary/__init__.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/kernels/stationary/base.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/kernels/stationary/matern12.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/kernels/stationary/matern32.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/kernels/stationary/matern52.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/kernels/stationary/periodic.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/kernels/stationary/powered_exponential.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/kernels/stationary/rational_quadratic.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/kernels/stationary/rbf.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/kernels/stationary/utils.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/kernels/stationary/white.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/lower_cholesky.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/mean_functions.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/objectives.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/parameters.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/gpjax/typing.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/static/CONTRIBUTING.md +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/static/paper.bib +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/static/paper.md +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/static/paper.pdf +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/tests/__init__.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/tests/conftest.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/tests/integration_tests.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_citations.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_dataset.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_decision_making/__init__.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_decision_making/test_decision_maker.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_decision_making/test_posterior_handler.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_decision_making/test_search_space.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_decision_making/test_test_functions/__init__.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_decision_making/test_test_functions/test_continuous_functions.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_decision_making/test_test_functions/test_non_conjugate_functions.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_decision_making/test_utility_functions/__init__.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_decision_making/test_utility_functions/test_base.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_decision_making/test_utility_functions/test_probability_of_improvement.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_decision_making/test_utility_functions/test_thompson_sampling.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_decision_making/test_utility_functions/test_utility_functions.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_decision_making/test_utility_maximizer.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_decision_making/test_utils.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_decision_making/utils.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_gaussian_distribution.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_gps.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_integrators.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_kernels/__init__.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_kernels/test_approximations.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_kernels/test_base.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_kernels/test_computation.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_kernels/test_non_euclidean.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_kernels/test_nonstationary.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_kernels/test_stationary.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_kernels/test_utils.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_likelihoods.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_lower_cholesky.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_markdown.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_mean_functions.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_objectives.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_parameters.py +0 -0
- {gpjax-0.9.2 → gpjax-0.9.4}/tests/test_variational_families.py +0 -0

{gpjax-0.9.2 → gpjax-0.9.4}/PKG-INFO

@@ -1,6 +1,6 @@
-Metadata-Version: 2.
+Metadata-Version: 2.4
 Name: gpjax
-Version: 0.9.2
+Version: 0.9.4
 Summary: Gaussian processes in JAX.
 Project-URL: Documentation, https://docs.jaxgaussianprocesses.com/
 Project-URL: Issues, https://github.com/JaxGaussianProcesses/GPJax/issues
@@ -19,7 +19,7 @@ Classifier: Programming Language :: Python :: Implementation :: PyPy
 Requires-Python: <3.13,>=3.10
 Requires-Dist: beartype>0.16.1
 Requires-Dist: cola-ml==0.0.5
-Requires-Dist: flax
+Requires-Dist: flax<0.10.0
 Requires-Dist: jax<0.4.28
 Requires-Dist: jaxlib<0.4.28
 Requires-Dist: jaxopt==0.8.2
@@ -103,23 +103,23 @@ helped to shape GPJax into the package it is today.
 
 ## Notebook examples
 
-> - [**Conjugate Inference**](https://docs.jaxgaussianprocesses.com/
-> - [**Classification**](https://docs.jaxgaussianprocesses.com/
-> - [**Sparse Variational Inference**](https://docs.jaxgaussianprocesses.com/
-> - [**Stochastic Variational Inference**](https://docs.jaxgaussianprocesses.com/
-> - [**Laplace Approximation**](https://docs.jaxgaussianprocesses.com/
-> - [**Inference on Non-Euclidean Spaces**](https://docs.jaxgaussianprocesses.com/
-> - [**Inference on Graphs**](https://docs.jaxgaussianprocesses.com/
-> - [**Pathwise Sampling**](https://docs.jaxgaussianprocesses.com/
-> - [**Learning Gaussian Process Barycentres**](https://docs.jaxgaussianprocesses.com/
-> - [**Deep Kernel Regression**](https://docs.jaxgaussianprocesses.com/
-> - [**Poisson Regression**](https://docs.jaxgaussianprocesses.com/
-> - [**Bayesian Optimisation**](https://docs.jaxgaussianprocesses.com/
+> - [**Conjugate Inference**](https://docs.jaxgaussianprocesses.com/_examples/regression/)
+> - [**Classification**](https://docs.jaxgaussianprocesses.com/_examples/classification/)
+> - [**Sparse Variational Inference**](https://docs.jaxgaussianprocesses.com/_examples/collapsed_vi/)
+> - [**Stochastic Variational Inference**](https://docs.jaxgaussianprocesses.com/_examples/uncollapsed_vi/)
+> - [**Laplace Approximation**](https://docs.jaxgaussianprocesses.com/_examples/classification/#laplace-approximation)
+> - [**Inference on Non-Euclidean Spaces**](https://docs.jaxgaussianprocesses.com/_examples/constructing_new_kernels/#custom-kernel)
+> - [**Inference on Graphs**](https://docs.jaxgaussianprocesses.com/_examples/graph_kernels/)
+> - [**Pathwise Sampling**](https://docs.jaxgaussianprocesses.com/_examples/spatial/)
+> - [**Learning Gaussian Process Barycentres**](https://docs.jaxgaussianprocesses.com/_examples/barycentres/)
+> - [**Deep Kernel Regression**](https://docs.jaxgaussianprocesses.com/_examples/deep_kernels/)
+> - [**Poisson Regression**](https://docs.jaxgaussianprocesses.com/_examples/poisson/)
+> - [**Bayesian Optimisation**](https://docs.jaxgaussianprocesses.com/_examples/bayesian_optimisation/)
 
 ## Guides for customisation
 >
-> - [**Custom kernels**](https://docs.jaxgaussianprocesses.com/
-> - [**UCI regression**](https://docs.jaxgaussianprocesses.com/
+> - [**Custom kernels**](https://docs.jaxgaussianprocesses.com/_examples/constructing_new_kernels/#custom-kernel)
+> - [**UCI regression**](https://docs.jaxgaussianprocesses.com/_examples/yacht/)
 
 ## Conversion between `.ipynb` and `.py`
 Above examples are stored in [examples](docs/examples) directory in the double
@@ -180,7 +180,7 @@ optimiser = ox.adam(learning_rate=1e-2)
 # Obtain Type 2 MLEs of the hyperparameters
 opt_posterior, history = gpx.fit(
     model=posterior,
-    objective=gpx.objectives.conjugate_mll,
+    objective=lambda p, d: -gpx.objectives.conjugate_mll(p, d),
     train_data=D,
     optim=optimiser,
     num_iters=500,
{gpjax-0.9.2 → gpjax-0.9.4}/README.md

@@ -71,23 +71,23 @@ helped to shape GPJax into the package it is today.
 
 ## Notebook examples
 
-> - [**Conjugate Inference**](https://docs.jaxgaussianprocesses.com/
-> - [**Classification**](https://docs.jaxgaussianprocesses.com/
-> - [**Sparse Variational Inference**](https://docs.jaxgaussianprocesses.com/
-> - [**Stochastic Variational Inference**](https://docs.jaxgaussianprocesses.com/
-> - [**Laplace Approximation**](https://docs.jaxgaussianprocesses.com/
-> - [**Inference on Non-Euclidean Spaces**](https://docs.jaxgaussianprocesses.com/
-> - [**Inference on Graphs**](https://docs.jaxgaussianprocesses.com/
-> - [**Pathwise Sampling**](https://docs.jaxgaussianprocesses.com/
-> - [**Learning Gaussian Process Barycentres**](https://docs.jaxgaussianprocesses.com/
-> - [**Deep Kernel Regression**](https://docs.jaxgaussianprocesses.com/
-> - [**Poisson Regression**](https://docs.jaxgaussianprocesses.com/
-> - [**Bayesian Optimisation**](https://docs.jaxgaussianprocesses.com/
+> - [**Conjugate Inference**](https://docs.jaxgaussianprocesses.com/_examples/regression/)
+> - [**Classification**](https://docs.jaxgaussianprocesses.com/_examples/classification/)
+> - [**Sparse Variational Inference**](https://docs.jaxgaussianprocesses.com/_examples/collapsed_vi/)
+> - [**Stochastic Variational Inference**](https://docs.jaxgaussianprocesses.com/_examples/uncollapsed_vi/)
+> - [**Laplace Approximation**](https://docs.jaxgaussianprocesses.com/_examples/classification/#laplace-approximation)
+> - [**Inference on Non-Euclidean Spaces**](https://docs.jaxgaussianprocesses.com/_examples/constructing_new_kernels/#custom-kernel)
+> - [**Inference on Graphs**](https://docs.jaxgaussianprocesses.com/_examples/graph_kernels/)
+> - [**Pathwise Sampling**](https://docs.jaxgaussianprocesses.com/_examples/spatial/)
+> - [**Learning Gaussian Process Barycentres**](https://docs.jaxgaussianprocesses.com/_examples/barycentres/)
+> - [**Deep Kernel Regression**](https://docs.jaxgaussianprocesses.com/_examples/deep_kernels/)
+> - [**Poisson Regression**](https://docs.jaxgaussianprocesses.com/_examples/poisson/)
+> - [**Bayesian Optimisation**](https://docs.jaxgaussianprocesses.com/_examples/bayesian_optimisation/)
 
 ## Guides for customisation
 >
-> - [**Custom kernels**](https://docs.jaxgaussianprocesses.com/
-> - [**UCI regression**](https://docs.jaxgaussianprocesses.com/
+> - [**Custom kernels**](https://docs.jaxgaussianprocesses.com/_examples/constructing_new_kernels/#custom-kernel)
+> - [**UCI regression**](https://docs.jaxgaussianprocesses.com/_examples/yacht/)
 
 ## Conversion between `.ipynb` and `.py`
 Above examples are stored in [examples](docs/examples) directory in the double
@@ -148,7 +148,7 @@ optimiser = ox.adam(learning_rate=1e-2)
 # Obtain Type 2 MLEs of the hyperparameters
 opt_posterior, history = gpx.fit(
     model=posterior,
-    objective=gpx.objectives.conjugate_mll,
+    objective=lambda p, d: -gpx.objectives.conjugate_mll(p, d),
    train_data=D,
     optim=optimiser,
     num_iters=500,
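Both PKG-INFO and README.md above carry the same user-facing change: the quick-start now passes a loss to be *minimised* to `gpx.fit`, i.e. the negated conjugate marginal log-likelihood. The sketch below reconstructs the README's regression example around that hunk; the toy data, the prior/likelihood scaffolding, and the `key` argument are assumptions for illustration rather than text taken verbatim from the diff.

```python
import gpjax as gpx
import jax.numpy as jnp
import jax.random as jr
import optax as ox

key = jr.PRNGKey(123)
x = jnp.sort(jr.uniform(key, shape=(100, 1), minval=-3.0, maxval=3.0), axis=0)
y = jnp.sin(4 * x) + 0.2 * jr.normal(key, shape=x.shape)
D = gpx.Dataset(X=x, y=y)

prior = gpx.gps.Prior(mean_function=gpx.mean_functions.Zero(), kernel=gpx.kernels.RBF())
posterior = prior * gpx.likelihoods.Gaussian(num_datapoints=D.n)

optimiser = ox.adam(learning_rate=1e-2)

# Obtain Type 2 MLEs of the hyperparameters. `conjugate_mll` is a log-likelihood
# (to be maximised), so it is negated to give `fit` a loss to minimise.
opt_posterior, history = gpx.fit(
    model=posterior,
    objective=lambda p, d: -gpx.objectives.conjugate_mll(p, d),
    train_data=D,
    optim=optimiser,
    num_iters=500,
    key=key,
)
```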
{gpjax-0.9.2 → gpjax-0.9.4}/docs/index.md

@@ -1,4 +1,4 @@
-# Welcome to GPJax
+# Welcome to GPJax
 
 GPJax is a didactic Gaussian process (GP) library in JAX, supporting GPU
 acceleration and just-in-time compilation. We seek to provide a flexible
@@ -6,7 +6,6 @@ API to enable researchers to rapidly prototype and develop new ideas.
 
 ![Gaussian process posterior.](static/GP.svg)
 
-
 ## "Hello, GP!"
 
 Typing GP models is as simple as the maths we
@@ -53,7 +52,7 @@ would write on paper, as shown below.
 !!! Begin
 
 Looking for a good place to start? Then why not begin with our [regression
-notebook](https://docs.jaxgaussianprocesses.com/
+notebook](https://docs.jaxgaussianprocesses.com/_examples/regression/).
 
 ## Citing GPJax
 
{gpjax-0.9.2 → gpjax-0.9.4}/examples/backend.py

@@ -122,7 +122,7 @@ print(constant_param._tag)
 # For most users, you will not need to worry about this as we provide a set of default
 # bijectors that are defined for all the parameter types we support. However, see our
 # [Kernel Guide
-# Notebook](https://docs.jaxgaussianprocesses.com/
+# Notebook](https://docs.jaxgaussianprocesses.com/_examples/constructing_new_kernels/) to
 # see how you can define your own bijectors and parameter types.
 
 # %%
@@ -156,7 +156,7 @@ transform(_close_to_zero_state, DEFAULT_BIJECTION, inverse=True)
 # may be nested within several functions e.g., a kernel function within a GP model.
 # Fortunately, transforming several parameters is a simple operation that we here
 # demonstrate for a conjugate GP posterior (see our [Regression
-# Notebook](https://docs.jaxgaussianprocesses.com/
+# Notebook](https://docs.jaxgaussianprocesses.com/_examples/regression/) for detailed
 # explanation of this model.).
 
 # %%
@@ -239,7 +239,7 @@ print(positive_reals)
 # useful as it allows us to efficiently operate on a subset of the parameters whilst
 # leaving the others untouched. Looking forward, we hope to use this functionality in
 # our [Variational Inference
-# Approximations](https://docs.jaxgaussianprocesses.com/
+# Approximations](https://docs.jaxgaussianprocesses.com/_examples/uncollapsed_vi/) to
 # perform more efficient updates of the variational parameters and then the model's
 # hyperparameters.
 
@@ -361,7 +361,7 @@ ax.set(xlabel="x", ylabel="m(x)")
 # In this notebook we have explored how GPJax's Flax-based backend may be easily
 # manipulated and extended. For a more applied look at this, see how we construct a
 # kernel on polar coordinates in our [Kernel
-# Guide](https://docs.jaxgaussianprocesses.com/
+# Guide](https://docs.jaxgaussianprocesses.com/_examples/constructing_new_kernels/#custom-kernel)
 # notebook.
 #
 # ## System configuration
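The backend.py hunks above are documentation-only, but the idea they reference (parameters carry bijectors so constrained values can be optimised in an unconstrained space) can be illustrated generically. A toy sketch with a softplus bijector; this is not GPJax's `transform`/`DEFAULT_BIJECTION` machinery, just the underlying idea.

```python
import jax.numpy as jnp

# Softplus maps an unconstrained real to a positive value; its inverse undoes that.
softplus = lambda z: jnp.logaddexp(z, 0.0)      # unconstrained -> positive
inv_softplus = lambda p: jnp.log(jnp.expm1(p))  # positive -> unconstrained

lengthscale = jnp.array(0.5)   # a positive hyperparameter (constrained space)
z = inv_softplus(lengthscale)  # value a gradient-based optimiser would update
assert jnp.allclose(softplus(z), lengthscale, atol=1e-6)
```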
{gpjax-0.9.2 → gpjax-0.9.4}/examples/barycentres.py

@@ -8,7 +8,7 @@
 # extension: .py
 # format_name: percent
 # format_version: '1.3'
-# jupytext_version: 1.16.
+# jupytext_version: 1.16.6
 # kernelspec:
 # display_name: gpjax
 # language: python
@@ -154,9 +154,9 @@ plt.show()
 # We'll now independently learn Gaussian process posterior distributions for each
 # dataset. We won't spend any time here discussing how GP hyperparameters are
 # optimised. For advice on achieving this, see the
-# [Regression notebook](https://docs.jaxgaussianprocesses.com/
+# [Regression notebook](https://docs.jaxgaussianprocesses.com/_examples/regression/)
 # for advice on optimisation and the
-# [Kernels notebook](https://docs.jaxgaussianprocesses.com/
+# [Kernels notebook](https://docs.jaxgaussianprocesses.com/_examples/constructing_new_kernels/) for
 # advice on selecting an appropriate kernel.
 
 
{gpjax-0.9.2 → gpjax-0.9.4}/examples/bayesian_optimisation.py

@@ -8,7 +8,7 @@
 # extension: .py
 # format_name: percent
 # format_version: '1.3'
-# jupytext_version: 1.16.
+# jupytext_version: 1.16.6
 # kernelspec:
 # display_name: gpjax
 # language: python
@@ -20,7 +20,7 @@
 #
 # In this guide we introduce the Bayesian Optimisation (BO) paradigm for
 # optimising black-box functions. We'll assume an understanding of Gaussian processes
-# (GPs), so if you're not familiar with them, check out our [GP introduction notebook](https://docs.jaxgaussianprocesses.com/
+# (GPs), so if you're not familiar with them, check out our [GP introduction notebook](https://docs.jaxgaussianprocesses.com/_examples/intro_to_gps/).
 
 # %%
 from typing import (
@@ -278,7 +278,7 @@ opt_posterior = return_optimised_posterior(D, prior, key)
 # will do this using the `sample_approx` method, which generates an approximate sample
 # from the posterior using decoupled sampling introduced in ([Wilson et al.,
 # 2020](https://proceedings.mlr.press/v119/wilson20a.html)) and discussed in our [Pathwise
-# Sampling Notebook](https://docs.jaxgaussianprocesses.com/
+# Sampling Notebook](https://docs.jaxgaussianprocesses.com/_examples/spatial/). This method
 # is used as it enables us to sample from the posterior in a manner which scales linearly
 # with the number of points sampled, $O(N)$, mitigating the cubic cost associated with
 # drawing exact samples from a GP posterior, $O(N^3)$. It also generates more accurate
{gpjax-0.9.2 → gpjax-0.9.4}/examples/classification.py

@@ -8,7 +8,7 @@
 # extension: .py
 # format_name: percent
 # format_version: '1.3'
-# jupytext_version: 1.16.
+# jupytext_version: 1.16.6
 # kernelspec:
 # display_name: gpjax
 # language: python
@@ -193,15 +193,20 @@ ax.legend()
 # $\boldsymbol{x}$, we can expand the log of this about the posterior mode
 # $\hat{\boldsymbol{f}}$ via a Taylor expansion. This gives:
 #
+# $$
 # \begin{align}
 # \log\tilde{p}(\boldsymbol{f}|\mathcal{D}) = \log\tilde{p}(\hat{\boldsymbol{f}}|\mathcal{D}) + \left[\nabla \log\tilde{p}({\boldsymbol{f}}|\mathcal{D})|_{\hat{\boldsymbol{f}}}\right]^{T} (\boldsymbol{f}-\hat{\boldsymbol{f}}) + \frac{1}{2} (\boldsymbol{f}-\hat{\boldsymbol{f}})^{T} \left[\nabla^2 \tilde{p}(\boldsymbol{y}|\boldsymbol{f})|_{\hat{\boldsymbol{f}}} \right] (\boldsymbol{f}-\hat{\boldsymbol{f}}) + \mathcal{O}(\lVert \boldsymbol{f} - \hat{\boldsymbol{f}} \rVert^3).
 # \end{align}
+# $$
 #
 # Since $\nabla \log\tilde{p}({\boldsymbol{f}}|\mathcal{D})$ is zero at the mode,
 # this suggests the following approximation
+#
+# $$
 # \begin{align}
 # \tilde{p}(\boldsymbol{f}|\mathcal{D}) \approx \log\tilde{p}(\hat{\boldsymbol{f}}|\mathcal{D}) \exp\left\{ \frac{1}{2} (\boldsymbol{f}-\hat{\boldsymbol{f}})^{T} \left[-\nabla^2 \tilde{p}(\boldsymbol{y}|\boldsymbol{f})|_{\hat{\boldsymbol{f}}} \right] (\boldsymbol{f}-\hat{\boldsymbol{f}}) \right\}
 # \end{align},
+# $$
 #
 # that we identify as a Gaussian distribution,
 # $p(\boldsymbol{f}| \mathcal{D}) \approx q(\boldsymbol{f}) := \mathcal{N}(\hat{\boldsymbol{f}}, [-\nabla^2 \tilde{p}(\boldsymbol{y}|\boldsymbol{f})|_{\hat{\boldsymbol{f}}} ]^{-1} )$.
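The classification.py hunk only wraps the existing Laplace-approximation derivation in `$$` delimiters, but the construction it describes (expand the unnormalised log-posterior about its mode, keep the quadratic term, read off a Gaussian) is easy to demo in one dimension. A hypothetical toy example, unrelated to the notebook's GP model:

```python
import jax
import jax.numpy as jnp

# Unnormalised log-density: Gaussian prior times a sigmoid-like likelihood (arbitrary toy choice).
log_p = lambda f: -0.5 * (f - 1.0) ** 2 - jnp.logaddexp(0.0, -3.0 * f)

grad = jax.grad(log_p)
hess = jax.grad(grad)

f_hat = jnp.array(0.0)
for _ in range(25):                  # Newton iterations to locate the posterior mode
    f_hat = f_hat - grad(f_hat) / hess(f_hat)

var = -1.0 / hess(f_hat)             # inverse negative curvature at the mode
print(f_hat, var)                    # Laplace approximation: q(f) = N(f_hat, var)
```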
{gpjax-0.9.2 → gpjax-0.9.4}/examples/collapsed_vi.py

@@ -7,7 +7,7 @@
 # extension: .py
 # format_name: percent
 # format_version: '1.3'
-# jupytext_version: 1.16.
+# jupytext_version: 1.16.6
 # kernelspec:
 # display_name: gpjax_beartype
 # language: python
@@ -131,7 +131,7 @@ q = gpx.variational_families.CollapsedVariationalGaussian(
 # %% [markdown]
 # We now train our model akin to a Gaussian process regression model via the `fit`
 # abstraction. Unlike the regression example given in the
-# [conjugate regression notebook](https://docs.jaxgaussianprocesses.com/
+# [conjugate regression notebook](https://docs.jaxgaussianprocesses.com/_examples/regression/),
 # the inducing locations that induce our variational posterior distribution are now
 # part of the model's parameters. Using a gradient-based optimiser, we can then
 # _optimise_ their location such that the evidence lower bound is maximised.
{gpjax-0.9.2 → gpjax-0.9.4}/examples/constructing_new_kernels.py

@@ -8,7 +8,7 @@
 # extension: .py
 # format_name: percent
 # format_version: '1.3'
-# jupytext_version: 1.16.
+# jupytext_version: 1.16.6
 # kernelspec:
 # display_name: gpjax
 # language: python
@@ -71,7 +71,7 @@ cols = plt.rcParams["axes.prop_cycle"].by_key()["color"]
 # * White noise
 # * Linear.
 # * Polynomial.
-# * [Graph kernels](https://docs.jaxgaussianprocesses.com/
+# * [Graph kernels](https://docs.jaxgaussianprocesses.com/_examples/graph_kernels/).
 #
 # While the syntax is consistent, each kernel's type influences the
 # characteristics of the sample paths drawn. We visualise this below with 10
@@ -185,7 +185,7 @@ fig.colorbar(im3, ax=ax[3], fraction=0.05)
 # We'll demonstrate this process now for a circular kernel --- an adaption of
 # the excellent guide given in the PYMC3 documentation. We encourage curious
 # readers to visit their notebook
-# [here](https://www.pymc.io/projects/docs/en/v3/pymc-
+# [here](https://www.pymc.io/projects/docs/en/v3/pymc-_examples/_examples/gaussian_processes/GP-Circular.html).
 #
 # ### Circular kernel
 #
@@ -198,9 +198,15 @@ fig.colorbar(im3, ax=ax[3], fraction=0.05)
 # kernels do not exhibit this behaviour and instead _wrap_ around the boundary
 # points to create a smooth function. Such a kernel was given in [Padonou &
 # Roustant (2015)](https://hal.inria.fr/hal-01119942v1) where any two angles
-# $\theta$ and $\theta'$ are written as
+# $\theta$ and $\theta'$ are written as
+#
+# $$
+# \begin{align}
+# W_c(\theta, \theta') & = \left\lvert
 # \left(1 + \tau \frac{d(\theta, \theta')}{c} \right) \left(1 - \frac{d(\theta,
-# \theta')}{c} \right)^{\tau} \right\rvert \quad \tau \geq 4 \tag{1}
+# \theta')}{c} \right)^{\tau} \right\rvert \quad \tau \geq 4 \tag{1}.
+# \end{align}
+# $$
 #
 # Here the hyperparameter $\tau$ is analogous to a lengthscale for Euclidean
 # stationary kernels, controlling the correlation between pairs of observations.
@@ -266,7 +272,7 @@ class Polar(gpx.kernels.AbstractKernel):
 #
 # We proceed to fit a GP with our custom circular kernel to a random sequence of
 # points on a circle (see the
-# [Regression notebook](https://docs.jaxgaussianprocesses.com/
+# [Regression notebook](https://docs.jaxgaussianprocesses.com/_examples/regression/)
 # for further details on this process).
 
 # %%
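Equation (1), as reconstructed above, is compact enough to sanity-check on its own. A standalone JAX sketch of the circular kernel on angles, with an illustrative wrap-around distance and default values τ = 4, c = π; this is a toy function, not the notebook's `Polar` kernel class.

```python
import jax.numpy as jnp

def angular_distance(theta, theta_prime):
    # Shortest distance between two angles, folded onto [0, pi].
    return jnp.abs((theta - theta_prime + jnp.pi) % (2.0 * jnp.pi) - jnp.pi)

def circular_kernel(theta, theta_prime, tau=4.0, c=jnp.pi):
    # Equation (1): W_c(theta, theta') = |(1 + tau d/c)(1 - d/c)^tau|, with tau >= 4.
    d = angular_distance(theta, theta_prime)
    return jnp.abs((1.0 + tau * d / c) * jnp.maximum(1.0 - d / c, 0.0) ** tau)

# Unlike a Euclidean stationary kernel on the raw angle, the boundary points wrap:
print(circular_kernel(jnp.array(0.0), jnp.array(2.0 * jnp.pi - 0.1)))  # close to 1
```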
{gpjax-0.9.2 → gpjax-0.9.4}/examples/decision_making.py

@@ -7,7 +7,7 @@
 # extension: .py
 # format_name: percent
 # format_version: '1.3'
-# jupytext_version: 1.16.
+# jupytext_version: 1.16.6
 # kernelspec:
 # display_name: gpjax
 # language: python
@@ -22,7 +22,7 @@
 # such problems include Bayesian optimisation (BO) and experimental design. For an
 # in-depth introduction to Bayesian optimisation itself, be sure to checkout out our
 # [Introduction to BO
-# Notebook](https://docs.jaxgaussianprocesses.com/
+# Notebook](https://docs.jaxgaussianprocesses.com/_examples/bayesian_optimisation/).
 #
 # We'll be using BO as a case study to demonstrate how one may use the decision making
 # module to solve sequential decision making problems. The goal of the decision making
@@ -76,7 +76,7 @@ cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"]
 # ## The Black-Box Objective Function
 #
 # We'll be using the same problem as in the [Introduction to BO
-# Notebook](https://docs.jaxgaussianprocesses.com/
+# Notebook](https://docs.jaxgaussianprocesses.com/_examples/bayesian_optimisation/), but
 # rather than focussing on the mechanics of BO we'll be looking at how one may use the
 # abstractions provided by the decision making module to implement the BO loop.
 #
@@ -181,7 +181,7 @@ likelihood_builder = lambda n: gpx.likelihoods.Gaussian(
 # this for us. This class takes as input a `prior` and `likeligood_builder`, which we have
 # defined above. We tend to also optimise the hyperparameters of the GP prior when
 # "fitting" our GP, as demonstrated in the [Regression
-# notebook](https://docs.jaxgaussianprocesses.com/
+# notebook](https://docs.jaxgaussianprocesses.com/_examples/regression/). This will be
 # using the GPJax `fit` method under the hood, which requires an `optimization_objective`,
 # `optimizer` and `num_optimization_iters`. Therefore, we also pass these to the
 # `PosteriorHandler` as demonstrated below:
@@ -257,7 +257,7 @@ acquisition_maximizer = ContinuousSinglePointUtilityMaximizer(
 #
 # It is worth noting that `ThompsonSampling` is not the only utility function we could use,
 # since our module also provides e.g. `ProbabilityOfImprovement`, `ExpectedImprovment`,
-# which were briefly discussed in [our previous introduction to Bayesian optimisation](https://docs.jaxgaussianprocesses.com/
+# which were briefly discussed in [our previous introduction to Bayesian optimisation](https://docs.jaxgaussianprocesses.com/_examples/bayesian_optimisation/).
 
 
 # %% [markdown]
{gpjax-0.9.2 → gpjax-0.9.4}/examples/deep_kernels.py

@@ -8,7 +8,7 @@
 # extension: .py
 # format_name: percent
 # format_version: '1.3'
-# jupytext_version: 1.16.
+# jupytext_version: 1.16.6
 # kernelspec:
 # display_name: gpjax
 # language: python
@@ -141,7 +141,7 @@ class DeepKernelFunction(AbstractKernel):
 # activation functions between the layers. The first hidden layer contains 64 units,
 # while the second layer contains 32 units. Finally, we'll make the output of our
 # network a three units wide. The corresponding kernel that we define will then be of
-# [ARD form](https://docs.jaxgaussianprocesses.com/
+# [ARD form](https://docs.jaxgaussianprocesses.com/_examples/constructing_new_kernels/#active-dimensions)
 # to allow for different lengthscales in each dimension of the feature space.
 # Users may wish to design more intricate network structures for more complex tasks,
 # which functionality is supported well in Haiku.
{gpjax-0.9.2 → gpjax-0.9.4}/examples/graph_kernels.py

@@ -8,7 +8,7 @@
 # extension: .py
 # format_name: percent
 # format_version: '1.3'
-# jupytext_version: 1.16.
+# jupytext_version: 1.16.6
 # kernelspec:
 # display_name: gpjax
 # language: python
@@ -22,7 +22,7 @@
 # of a graph using a Gaussian process with a Matérn kernel presented in
 # <strong data-cite="borovitskiy2021matern"></strong>. For a general discussion of the
 # kernels supported within GPJax, see the
-# [kernels notebook](https://docs.jaxgaussianprocesses.com/
+# [kernels notebook](https://docs.jaxgaussianprocesses.com/_examples/constructing_new_kernels).
 
 # %%
 import random
@@ -88,7 +88,9 @@ nx.draw(
 #
 # Graph kernels use the _Laplacian matrix_ $L$ to quantify the smoothness of a signal
 # (or function) on a graph
+#
 # $$L=D-A,$$
+#
 # where $D$ is the diagonal _degree matrix_ containing each vertices' degree and $A$
 # is the _adjacency matrix_ that has an $(i,j)^{\text{th}}$ entry of 1 if $v_i, v_j$
 # are connected and 0 otherwise. [Networkx](https://networkx.org) gives us an easy
@@ -151,7 +153,7 @@ cbar = plt.colorbar(sm, ax=ax)
 # non-Euclidean, our likelihood is still Gaussian and the model is still conjugate.
 # For this reason, we simply perform gradient descent on the GP's marginal
 # log-likelihood term as in the
-# [regression notebook](https://docs.jaxgaussianprocesses.com/
+# [regression notebook](https://docs.jaxgaussianprocesses.com/_examples/regression/).
 # We do this using the BFGS optimiser.
 
 # %%
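The definition quoted in the hunk above, $L = D - A$, can be checked directly against networkx on any small graph; the barbell graph below is an arbitrary choice for illustration and is not part of the notebook.

```python
import jax.numpy as jnp
import networkx as nx

G = nx.barbell_graph(3, 1)              # small example graph
A = jnp.asarray(nx.to_numpy_array(G))   # adjacency matrix
D = jnp.diag(A.sum(axis=1))             # degree matrix
L = D - A                               # graph Laplacian, L = D - A

assert jnp.allclose(L, jnp.asarray(nx.laplacian_matrix(G).toarray()))
```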
{gpjax-0.9.2 → gpjax-0.9.4}/examples/intro_to_gps.py

@@ -7,7 +7,7 @@
 # extension: .py
 # format_name: percent
 # format_version: '1.3'
-# jupytext_version: 1.16.
+# jupytext_version: 1.16.6
 # kernelspec:
 # display_name: gpjax
 # language: python
@@ -17,6 +17,7 @@
 # %% [markdown]
 # # New to Gaussian Processes?
 #
+#
 # Fantastic that you're here! This notebook is designed to be a gentle
 # introduction to the mathematics of Gaussian processes (GPs). No prior
 # knowledge of Bayesian inference or GPs is assumed, and this notebook is
@@ -33,10 +34,11 @@
 # model are unknown, and our goal is to conduct inference to determine their
 # range of likely values. To achieve this, we apply Bayes' theorem
 #
+# $$
 # \begin{align}
-# \
-# p(\theta\,|\, \mathbf{y}) = \frac{p(\theta)p(\mathbf{y}\,|\,\theta)}{p(\mathbf{y})} = \frac{p(\theta)p(\mathbf{y}\,|\,\theta)}{\int_{\theta}p(\mathbf{y}, \theta)\mathrm{d}\theta}\,,
+# p(\theta\mid\mathbf{y}) = \frac{p(\theta)p(\mathbf{y}\mid\theta)}{p(\mathbf{y})} = \frac{p(\theta)p(\mathbf{y}\mid\theta)}{\int_{\theta}p(\mathbf{y}, \theta)\mathrm{d}\theta},
 # \end{align}
+# $$
 #
 # where $p(\mathbf{y}\,|\,\theta)$ denotes the _likelihood_, or model, and
 # quantifies how likely the observed dataset $\mathbf{y}$ is, given the
@@ -58,7 +60,7 @@
 # family, then there exists a conjugate prior. However, the conjugate prior may
 # not have a form that precisely reflects the practitioner's belief surrounding
 # the parameter. For this reason, conjugate models seldom appear; one exception
-# to this is GP regression that we present fully in our [Regression notebook](https://docs.jaxgaussianprocesses.com/
+# to this is GP regression that we present fully in our [Regression notebook](https://docs.jaxgaussianprocesses.com/_examples/regression/).
 #
 # For models that do not contain a conjugate prior, the marginal log-likelihood
 # must be calculated to normalise the posterior distribution and ensure it
@@ -74,9 +76,13 @@
 # new points $\mathbf{y}^{\star}$ through the _posterior predictive
 # distribution_. This is achieved by integrating out the parameter set $\theta$
 # from our posterior distribution through
+#
+# $$
 # \begin{align}
 # p(\mathbf{y}^{\star}\mid \mathbf{y}) = \int p(\mathbf{y}^{\star} \,|\, \theta, \mathbf{y} ) p(\theta\,|\, \mathbf{y})\mathrm{d}\theta\,.
 # \end{align}
+# $$
+#
 # As with the marginal log-likelihood, evaluating this quantity requires
 # computing an integral which may not be tractable, particularly when $\theta$
 # is high-dimensional.
@@ -85,13 +91,16 @@
 # distribution, so we often compute and report moments of the posterior
 # distribution. Most commonly, we report the first moment and the centred second
 # moment
+#
 # $$
 # \begin{alignat}{2}
-# \mu = \mathbb{E}[\theta\,|\,\mathbf{y}] & = \int \theta
+# \mu = \mathbb{E}[\theta\,|\,\mathbf{y}] & = \int \theta
+# p(\theta\mid\mathbf{y})\mathrm{d}\theta \quad \\
 # \sigma^2 = \mathbb{V}[\theta\,|\,\mathbf{y}] & = \int \left(\theta -
 # \mathbb{E}[\theta\,|\,\mathbf{y}]\right)^2p(\theta\,|\,\mathbf{y})\mathrm{d}\theta&\,.
 # \end{alignat}
 # $$
+#
 # Through this pair of statistics, we can communicate our beliefs about the most
 # likely value of $\theta$ i.e., $\mu$, and the uncertainty $\sigma$ around the
 # expected value. However, as with the marginal log-likelihood and predictive
@@ -209,9 +218,7 @@ for a, t, d in zip([ax0, ax1, ax2], titles, dists):
     d_prob = d.prob(jnp.hstack([xx.reshape(-1, 1), yy.reshape(-1, 1)])).reshape(
         xx.shape
     )
-    cntf = a.contourf(xx, yy, jnp.exp(d_prob), levels=20, antialiased=True, cmap=cmap)
-    for c in cntf.collections:
-        c.set_edgecolor("face")
+    cntf = a.contourf(xx, yy, jnp.exp(d_prob), levels=20, antialiased=True, cmap=cmap, edgecolor="face")
     a.set_xlim(-2.75, 2.75)
     a.set_ylim(-2.75, 2.75)
     samples = d.sample(seed=key, sample_shape=(5000,))
@@ -228,13 +235,16 @@ for a, t, d in zip([ax0, ax1, ax2], titles, dists):
 # %% [markdown]
 # Extending the intuition given for the moments of a univariate Gaussian random
 # variables, we can obtain the mean and covariance by
+#
 # $$
 # \begin{align}
-#
+# \mathbb{E}[\mathbf{y}] & = \mathbf{\mu}, \\
+# \operatorname{Cov}(\mathbf{y}) & = \mathbf{E}\left[(\mathbf{y} - \mathbf{\mu})(\mathbf{y} - \mathbf{\mu})^{\top} \right] \\
 # & =\mathbb{E}[\mathbf{y}\mathbf{y}^{\top}] - \mathbb{E}[\mathbf{y}]\mathbb{E}[\mathbf{y}]^{\top} \\
 # & =\mathbf{\Sigma}\,.
 # \end{align}
 # $$
+#
 # The covariance matrix is a symmetric positive definite matrix that generalises
 # the notion of variance to multiple dimensions. The matrix's diagonal entries
 # contain the variance of each element, whilst the off-diagonal entries quantify
@@ -336,6 +346,7 @@ with warnings.catch_warnings():
 # $\mathbf{x}\sim\mathcal{N}(\boldsymbol{\mu}_{\mathbf{x}}, \boldsymbol{\Sigma}_{\mathbf{xx}})$ and
 # $\mathbf{y}\sim\mathcal{N}(\boldsymbol{\mu}_{\mathbf{y}}, \boldsymbol{\Sigma}_{\mathbf{yy}})$.
 # We define the joint distribution as
+#
 # $$
 # \begin{align}
 # p\left(\begin{bmatrix}
@@ -348,6 +359,7 @@ with warnings.catch_warnings():
 # \end{bmatrix} \right)\,,
 # \end{align}
 # $$
+#
 # where $\boldsymbol{\Sigma}_{\mathbf{x}\mathbf{y}}$ is the cross-covariance
 # matrix of $\mathbf{x}$ and $\mathbf{y}$.
 #
@@ -363,6 +375,7 @@ with warnings.catch_warnings():
 #
 # For a joint Gaussian random variable, the marginalisation of $\mathbf{x}$ or
 # $\mathbf{y}$ is given by
+#
 # $$
 # \begin{alignat}{3}
 # & \int p(\mathbf{x}, \mathbf{y})\mathrm{d}\mathbf{y} && = p(\mathbf{x})
@@ -372,7 +385,9 @@ with warnings.catch_warnings():
 # \boldsymbol{\Sigma}_{\mathbf{yy}})\,.
 # \end{alignat}
 # $$
+#
 # The conditional distributions are given by
+#
 # $$
 # \begin{align}
 # p(\mathbf{y}\,|\, \mathbf{x}) & = \mathcal{N}\left(\boldsymbol{\mu}_{\mathbf{y}} + \boldsymbol{\Sigma}_{\mathbf{yx}}\boldsymbol{\Sigma}_{\mathbf{xx}}^{-1}(\mathbf{x}-\boldsymbol{\mu}_{\mathbf{x}}), \boldsymbol{\Sigma}_{\mathbf{yy}}-\boldsymbol{\Sigma}_{\mathbf{yx}}\boldsymbol{\Sigma}_{\mathbf{xx}}^{-1}\boldsymbol{\Sigma}_{\mathbf{xy}}\right)\,.
@@ -401,6 +416,7 @@ with warnings.catch_warnings():
 # We aim to capture the relationship between $\mathbf{X}$ and $\mathbf{y}$ using
 # a model $f$ with which we may make predictions at an unseen set of test points
 # $\mathbf{X}^{\star}\subset\mathcal{X}$. We formalise this by
+#
 # $$
 # \begin{align}
 # y = f(\mathbf{X}) + \varepsilon\,,
@@ -430,6 +446,7 @@ with warnings.catch_warnings():
 # convenience in the remainder of this article.
 #
 # We define a joint GP prior over the latent function
+#
 # $$
 # \begin{align}
 # p(\mathbf{f}, \mathbf{f}^{\star}) = \mathcal{N}\left(\mathbf{0}, \begin{bmatrix}
@@ -437,14 +454,17 @@ with warnings.catch_warnings():
 # \end{bmatrix}\right)\,,
 # \end{align}
 # $$
+#
 # where $\mathbf{f}^{\star} = f(\mathbf{X}^{\star})$. Conditional on the GP's
 # latent function $f$, we assume a factorising likelihood generates our
 # observations
+#
 # $$
 # \begin{align}
 # p(\mathbf{y}\,|\,\mathbf{f}) = \prod_{i=1}^n p(y_i\,|\, f_i)\,.
 # \end{align}
 # $$
+#
 # Strictly speaking, the likelihood function is
 # $p(\mathbf{y}\,|\,\phi(\mathbf{f}))$ where $\phi$ is the likelihood function's
 # associated link function. Example link functions include the probit or
@@ -453,7 +473,7 @@ with warnings.catch_warnings():
 # considers Gaussian likelihood functions where the role of $\phi$ is
 # superfluous. However, this intuition will be helpful for models with a
 # non-Gaussian likelihood, such as those encountered in
-# [classification](https://docs.jaxgaussianprocesses.com/
+# [classification](https://docs.jaxgaussianprocesses.com/_examples/classification).
 #
 # Applying Bayes' theorem \eqref{eq:BayesTheorem} yields the joint posterior distribution over the
 # latent function
@@ -470,7 +490,7 @@ with warnings.catch_warnings():
 # function with parameters $\boldsymbol{\theta}$ that maps pairs of inputs
 # $\mathbf{X}, \mathbf{X}' \in \mathcal{X}$ onto the real line. We dedicate the
 # entirety of the [Introduction to Kernels
-# notebook](https://docs.jaxgaussianprocesses.com/
+# notebook](https://docs.jaxgaussianprocesses.com/_examples/intro_to_kernels) to
 # exploring the different GPs each kernel can yield.
 #
 # ## Gaussian process regression
@@ -479,20 +499,25 @@ with warnings.catch_warnings():
 # $p(y_i\,|\, f_i) = \mathcal{N}(y_i\,|\, f_i, \sigma_n^2)$,
 # marginalising $\mathbf{f}$ from the joint posterior to obtain
 # the posterior predictive distribution is exact
+#
 # $$
 # \begin{align}
 # p(\mathbf{f}^{\star}\mid \mathbf{y}) = \mathcal{N}(\mathbf{f}^{\star}\,|\,\boldsymbol{\mu}_{\,|\,\mathbf{y}}, \Sigma_{\,|\,\mathbf{y}})\,,
 # \end{align}
 # $$
+#
 # where
+#
 # $$
 # \begin{align}
 # \mathbf{\mu}_{\mid \mathbf{y}} & = \mathbf{K}_{\star f}\left( \mathbf{K}_{ff}+\sigma^2_n\mathbf{I}_n\right)^{-1}\mathbf{y} \\
 # \Sigma_{\,|\,\mathbf{y}} & = \mathbf{K}_{\star\star} - \mathbf{K}_{xf}\left(\mathbf{K}_{ff} + \sigma_n^2\mathbf{I}_n\right)^{-1}\mathbf{K}_{fx} \,.
 # \end{align}
 # $$
+#
 # Further, the log of the marginal likelihood of the GP can
 # be analytically expressed as
+#
 # $$
 # \begin{align}
 # & = 0.5\left(-\underbrace{\mathbf{y}^{\top}\left(\mathbf{K}_{ff} + \sigma_n^2\mathbf{I}_n \right)^{-1}\mathbf{y}}_{\text{Data fit}} -\underbrace{\log\lvert \mathbf{K}_{ff} + \sigma^2_n\rvert}_{\text{Complexity}} -\underbrace{n\log 2\pi}_{\text{Constant}} \right)\,.
@@ -505,6 +530,7 @@ with warnings.catch_warnings():
 # we call these terms the model hyperparameters
 # $\boldsymbol{\xi} = \{\boldsymbol{\theta},\sigma_n^2\}$
 # from which the maximum likelihood estimate is given by
+#
 # $$
 # \begin{align*}
 # \boldsymbol{\xi}^{\star} = \operatorname{argmax}_{\boldsymbol{\xi} \in \Xi} \log p(\mathbf{y})\,.
@@ -532,7 +558,7 @@ with warnings.catch_warnings():
 # Bayes' theorem and the definition of a Gaussian random variable. Using the
 # ideas presented in this notebook, the user should be in a position to dive
 # into our [Regression
-# notebook](https://docs.jaxgaussianprocesses.com/
+# notebook](https://docs.jaxgaussianprocesses.com/_examples/regression/) and
 # start getting their hands on some code. For those looking to learn more about
 # the underling theory of GPs, an excellent starting point is the [Gaussian
 # Processes for Machine Learning](http://gaussianprocess.org/gpml/) textbook.