gpjax 0.13.3__tar.gz → 0.13.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {gpjax-0.13.3 → gpjax-0.13.4}/.gitignore +5 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/PKG-INFO +2 -2
- {gpjax-0.13.3 → gpjax-0.13.4}/README.md +1 -1
- gpjax-0.13.4/examples/heteroscedastic_inference.py +389 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/examples/regression.py +24 -23
- {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/__init__.py +1 -1
- {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/citation.py +13 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/gps.py +77 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/likelihoods.py +234 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/mean_functions.py +2 -2
- {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/objectives.py +56 -1
- {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/parameters.py +8 -1
- {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/variational_families.py +129 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/mkdocs.yml +1 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/pyproject.toml +2 -1
- {gpjax-0.13.3 → gpjax-0.13.4}/tests/conftest.py +7 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/tests/integration_tests.py +9 -2
- {gpjax-0.13.3 → gpjax-0.13.4}/tests/test_citations.py +16 -0
- gpjax-0.13.4/tests/test_heteroscedastic.py +407 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/tests/test_mean_functions.py +16 -1
- {gpjax-0.13.3 → gpjax-0.13.4}/uv.lock +23 -0
- gpjax-0.13.3/.github/workflows/pr_greeting.yml +0 -62
- {gpjax-0.13.3 → gpjax-0.13.4}/.github/CODE_OF_CONDUCT.md +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/.github/FUNDING.yml +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/.github/ISSUE_TEMPLATE/01_BUG_REPORT.md +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/.github/ISSUE_TEMPLATE/02_FEATURE_REQUEST.md +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/.github/ISSUE_TEMPLATE/03_CODEBASE_IMPROVEMENT.md +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/.github/ISSUE_TEMPLATE/04_DOCS_IMPROVEMENT.md +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/.github/ISSUE_TEMPLATE/config.yml +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/.github/codecov.yml +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/.github/commitlint.config.js +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/.github/dependabot.yml +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/.github/labeler.yml +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/.github/labels.yml +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/.github/pull_request_template.md +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/.github/release-drafter.yml +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/.github/workflows/auto-label.yml +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/.github/workflows/build_docs.yml +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/.github/workflows/commit-lint.yml +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/.github/workflows/integration.yml +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/.github/workflows/release.yml +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/.github/workflows/ruff.yml +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/.github/workflows/security-analysis.yml +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/.github/workflows/test_docs.yml +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/.github/workflows/tests.yml +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/CITATION.bib +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/LICENSE.txt +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/Makefile +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/docs/CODE_OF_CONDUCT.md +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/docs/GOVERNANCE.md +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/docs/contributing.md +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/docs/design.md +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/docs/index.md +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/docs/index.rst +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/docs/installation.md +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/docs/javascripts/katex.js +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/docs/refs.bib +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/docs/scripts/gen_examples.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/docs/scripts/gen_pages.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/docs/scripts/notebook_converter.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/docs/scripts/sharp_bits_figure.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/docs/sharp_bits.md +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/docs/static/GP.pdf +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/docs/static/GP.svg +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/docs/static/bijector_figure.svg +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/docs/static/css/gpjax_theme.css +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/docs/static/favicon.ico +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/docs/static/gpjax.mplstyle +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/docs/static/gpjax_logo.pdf +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/docs/static/gpjax_logo.svg +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/docs/static/jaxkern/lato.ttf +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/docs/static/jaxkern/logo.png +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/docs/static/jaxkern/logo.svg +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/docs/static/jaxkern/main.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/docs/static/step_size_figure.png +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/docs/static/step_size_figure.svg +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/docs/stylesheets/extra.css +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/docs/stylesheets/permalinks.css +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/examples/backend.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/examples/barycentres/barycentre_gp.gif +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/examples/barycentres.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/examples/classification.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/examples/collapsed_vi.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/examples/constructing_new_kernels.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/examples/data/max_tempeature_switzerland.csv +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/examples/data/yacht_hydrodynamics.data +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/examples/deep_kernels.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/examples/gpjax.mplstyle +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/examples/graph_kernels.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/examples/intro_to_gps/decomposed_mll.png +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/examples/intro_to_gps/generating_process.png +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/examples/intro_to_gps.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/examples/intro_to_kernels.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/examples/likelihoods_guide.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/examples/oceanmodelling.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/examples/poisson.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/examples/uncollapsed_vi.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/examples/utils.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/examples/yacht.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/dataset.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/distributions.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/fit.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/integrators.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/kernels/__init__.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/kernels/approximations/__init__.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/kernels/approximations/rff.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/kernels/base.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/kernels/computations/__init__.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/kernels/computations/base.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/kernels/computations/basis_functions.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/kernels/computations/constant_diagonal.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/kernels/computations/dense.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/kernels/computations/diagonal.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/kernels/computations/eigen.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/kernels/non_euclidean/__init__.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/kernels/non_euclidean/graph.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/kernels/non_euclidean/utils.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/kernels/nonstationary/__init__.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/kernels/nonstationary/arccosine.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/kernels/nonstationary/linear.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/kernels/nonstationary/polynomial.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/kernels/stationary/__init__.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/kernels/stationary/base.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/kernels/stationary/matern12.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/kernels/stationary/matern32.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/kernels/stationary/matern52.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/kernels/stationary/periodic.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/kernels/stationary/powered_exponential.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/kernels/stationary/rational_quadratic.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/kernels/stationary/rbf.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/kernels/stationary/utils.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/kernels/stationary/white.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/linalg/__init__.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/linalg/operations.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/linalg/operators.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/linalg/utils.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/numpyro_extras.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/scan.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/typing.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/static/CONTRIBUTING.md +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/static/paper.bib +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/static/paper.md +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/static/paper.pdf +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/tests/__init__.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/tests/test_dataset.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/tests/test_fit.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/tests/test_gaussian_distribution.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/tests/test_gps.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/tests/test_imports.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/tests/test_integrators.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/tests/test_kernels/__init__.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/tests/test_kernels/test_approximations.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/tests/test_kernels/test_base.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/tests/test_kernels/test_computation.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/tests/test_kernels/test_non_euclidean.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/tests/test_kernels/test_nonstationary.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/tests/test_kernels/test_stationary.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/tests/test_kernels/test_utils.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/tests/test_likelihoods.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/tests/test_linalg.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/tests/test_markdown.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/tests/test_numpyro_extras.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/tests/test_objectives.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/tests/test_parameters.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.4}/tests/test_variational_families.py +0 -0
{gpjax-0.13.3 → gpjax-0.13.4}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: gpjax
-Version: 0.13.3
+Version: 0.13.4
 Summary: Gaussian processes in JAX.
 Project-URL: Documentation, https://docs.jaxgaussianprocesses.com/
 Project-URL: Issues, https://github.com/thomaspinder/GPJax/issues
@@ -141,7 +141,7 @@ GPJax into the package it is today.
 > - [**Laplace Approximation**](https://docs.jaxgaussianprocesses.com/_examples/classification/#laplace-approximation)
 > - [**Inference on Non-Euclidean Spaces**](https://docs.jaxgaussianprocesses.com/_examples/constructing_new_kernels/#custom-kernel)
 > - [**Inference on Graphs**](https://docs.jaxgaussianprocesses.com/_examples/graph_kernels/)
-> - [**
+> - [**Heteroscedastic Inference**](https://docs.jaxgaussianprocesses.com/_examples/heteroscedastic_inference/)
 > - [**Learning Gaussian Process Barycentres**](https://docs.jaxgaussianprocesses.com/_examples/barycentres/)
 > - [**Deep Kernel Regression**](https://docs.jaxgaussianprocesses.com/_examples/deep_kernels/)
 > - [**Poisson Regression**](https://docs.jaxgaussianprocesses.com/_examples/poisson/)
{gpjax-0.13.3 → gpjax-0.13.4}/README.md

@@ -70,7 +70,7 @@ GPJax into the package it is today.
 > - [**Laplace Approximation**](https://docs.jaxgaussianprocesses.com/_examples/classification/#laplace-approximation)
 > - [**Inference on Non-Euclidean Spaces**](https://docs.jaxgaussianprocesses.com/_examples/constructing_new_kernels/#custom-kernel)
 > - [**Inference on Graphs**](https://docs.jaxgaussianprocesses.com/_examples/graph_kernels/)
-> - [**
+> - [**Heteroscedastic Inference**](https://docs.jaxgaussianprocesses.com/_examples/heteroscedastic_inference/)
 > - [**Learning Gaussian Process Barycentres**](https://docs.jaxgaussianprocesses.com/_examples/barycentres/)
 > - [**Deep Kernel Regression**](https://docs.jaxgaussianprocesses.com/_examples/deep_kernels/)
 > - [**Poisson Regression**](https://docs.jaxgaussianprocesses.com/_examples/poisson/)
gpjax-0.13.4/examples/heteroscedastic_inference.py

@@ -0,0 +1,389 @@
+# -*- coding: utf-8 -*-
+# ---
+# jupyter:
+#   jupytext:
+#     cell_metadata_filter: -all
+#     custom_cell_magics: kql
+#     text_representation:
+#       extension: .py
+#       format_name: percent
+#       format_version: '1.3'
+#       jupytext_version: 1.17.3
+#   kernelspec:
+#     display_name: .venv
+#     language: python
+#     name: python3
+# ---
+
+# %% [markdown]
+# # Heteroscedastic inference for regression and classification
+#
+# This notebook shows how to fit a heteroscedastic Gaussian process (GP) that
+# allows one to perform regression where there exists non-constant, or
+# input-dependent, noise.
+#
+#
+# ## Background
+# A heteroscedastic GP couples two latent functions:
+# - A **signal GP** $f(\cdot)$ for the mean response.
+# - A **noise GP** $g(\cdot)$ that maps to a positive variance
+# $\sigma^2(x) = \phi(g(x))$ via a positivity transform $\phi$ (typically
+# ${\rm exp}$ or ${\rm softplus}$). Intuitively, we are introducing a pair of GPs;
+# one to model the latent mean, and a second that models the log-noise variance. This
+# is in direct contrast to a
+# [homoscedastic GP](https://docs.jaxgaussianprocesses.com/_examples/regression/)
+# where we learn a constant value for the noise.
+#
+# In the Gaussian case, the observed response follows
+# $$y \mid f, g \sim \mathcal{N}(f, \sigma^2(x)).$$
+# Variational inference works with independent posteriors $q(f)q(g)$, combining the
+# moments of each into an ELBO. For non-Gaussian likelihoods the same structure
+# remains; only the expected log-likelihood changes.
+
+# %%
+from jax import config
+import jax.numpy as jnp
+import jax.random as jr
+import matplotlib as mpl
+import matplotlib.pyplot as plt
+import optax as ox
+
+from examples.utils import use_mpl_style
+import gpjax as gpx
+from gpjax.likelihoods import (
+    HeteroscedasticGaussian,
+    LogNormalTransform,
+    SoftplusTransform,
+)
+from gpjax.variational_families import (
+    HeteroscedasticVariationalFamily,
+    VariationalGaussianInit,
+)
+
+# Enable Float64 for stable linear algebra.
+config.update("jax_enable_x64", True)
+
+
+use_mpl_style()
+key = jr.key(123)
+cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"]
+
+
+# %% [markdown]
+# ## Dataset simulation
+# We simulate a dataset whose mean and noise levels vary with
+# the input. We sample inputs $x \sim \mathcal{U}(0, 1)$ and define the
+# latent signal to be
+# $$f(x) = (x - 0.5)^2 + 0.05;$$
+# a smooth bowl-shaped curve. The observation standard deviation is chosen to be
+# proportional to the signal,
+# $$\sigma(x) = 0.5\,f(x),$$
+# which yields the heteroscedastic generative model
+# $$y \mid x \sim \mathcal{N}\!\big(f(x), \sigma^2(x)\big).$$
+# This construction makes the noise small near the minimum of the bowl and much
+# larger in the tails. We also create a dense test grid that we shall use later for
+# visualising posterior fits and predictive uncertainty.
+
+# %%
+# Create data with input-dependent variance.
+key, x_key, noise_key = jr.split(key, 3)
+n = 200
+x = jr.uniform(x_key, (n, 1), minval=0.0, maxval=1.0)
+signal = (x - 0.5) ** 2 + 0.05
+noise_scale = 0.5 * signal
+noise = noise_scale * jr.normal(noise_key, shape=(n, 1))
+y = signal + noise
+train = gpx.Dataset(X=x, y=y)
+
+xtest = jnp.linspace(-0.1, 1.1, 200)[:, None]
+signal_test = (xtest - 0.5) ** 2 + 0.05
+noise_scale_test = 0.5 * signal_test
+noise_test = noise_scale_test * jr.normal(noise_key, shape=(200, 1))
+ytest = signal_test + noise_test
+
+fig, ax = plt.subplots()
+ax.plot(x, y, "o", label="Observations", alpha=0.7, color=cols[0])
+ax.plot(xtest, signal_test, label="Signal", alpha=0.7, color=cols[1])
+ax.plot(xtest, noise_scale_test, label="Noise scale", alpha=0.7, color=cols[2])
+ax.set_xlabel("$x$")
+ax.set_ylabel("$y$")
+ax.legend(loc="upper left")
+
+# %% [markdown]
+# For a homoscedastic baseline, compare this figure with the
+# [Gaussian process regression notebook](https://docs.jaxgaussianprocesses.com/_examples/regression/)
+# (`examples/regression.py`), where a single latent GP is paired with constant
+# observation noise.
+
+# %% [markdown]
+# ## Prior specification
+# We place independent Gaussian process priors on the signal and noise processes:
+# $$f \sim \mathcal{GP}\big(0, k_f\big), \qquad g \sim \mathcal{GP}\big(0, k_g\big),$$
+# where $k_f$ and $k_g$ are stationary squared-exponential kernels with unit
+# variance and lengthscale of one. The noise process $g$ is mapped to the variance
+# via the logarithmic transform in `LogNormalTransform`, giving
+# $\sigma^2(x) = \exp\big(g(x)\big)$. The joint prior over $(f, g)$ combines with
+# the heteroscedastic Gaussian likelihood,
+# $$p(\mathbf{y} \mid f, g) = \prod_{i=1}^n
+# \mathcal{N}\!\big(y_i \mid f(x_i), \exp(g(x_i))\big),$$
+# to form the posterior target that we shall approximate variationally. The product
+# syntax `signal_prior * likelihood` used below constructs this augmented GP model.
+
+# %%
+# Signal and noise priors.
+signal_prior = gpx.gps.Prior(
+    mean_function=gpx.mean_functions.Zero(),
+    kernel=gpx.kernels.RBF(),
+)
+noise_prior = gpx.gps.Prior(
+    mean_function=gpx.mean_functions.Zero(),
+    kernel=gpx.kernels.RBF(),
+)
+likelihood = HeteroscedasticGaussian(
+    num_datapoints=train.n,
+    noise_prior=noise_prior,
+    noise_transform=LogNormalTransform(),
+)
+posterior = signal_prior * likelihood
+
+# Variational family over both processes.
+z = jnp.linspace(-3.2, 3.2, 25)[:, None]
+q = HeteroscedasticVariationalFamily(
+    posterior=posterior,
+    inducing_inputs=z,
+    inducing_inputs_g=z,
+)
+
+# %% [markdown]
+# The variational family introduces inducing variables for both latent functions,
+# located at the set $Z = \{z_m\}_{m=1}^M$. These inducing variables summarise the
+# infinite-dimensional GP priors in terms of multivariate Gaussian parameters.
+# Optimising the evidence lower bound (ELBO) corresponds to adjusting the means and
+# covariances of the variational posteriors $q(f)$ and $q(g)$ so that they best
+# explain the observed data whilst remaining close to the prior. For a deeper look at
+# these constructions in the homoscedastic setting, refer to the
+# [Sparse Gaussian Process Regression](https://docs.jaxgaussianprocesses.com/_examples/collapsed_vi/)
+# (`examples/collapsed_vi.py`) and
+# [Sparse Stochastic Variational Inference](https://docs.jaxgaussianprocesses.com/_examples/uncollapsed_vi/)
+# (`examples/uncollapsed_vi.py`) notebooks.
+
+# %% [markdown]
+# ### Optimisation
+# With the model specified, we minimise the negative ELBO,
+# $$\mathcal{L} = \mathbb{E}_{q(f)q(g)}\!\big[\log p(\mathbf{y}\mid f, g)\big]
+# - \mathrm{KL}\!\left[q(f) \,\|\, p(f)\right]
+# - \mathrm{KL}\!\left[q(g) \,\|\, p(g)\right],$$
+# using the Adam optimiser. GPJax automatically selects the tight bound of
+# Lázaro-Gredilla & Titsias (2011) when the likelihood is Gaussian, yielding an
+# analytically tractable expectation over the latent noise process. The resulting
+# optimisation iteratively updates the inducing posteriors for both latent GPs.
+
+# %%
+# Optimise the heteroscedastic ELBO (selects tighter bound).
+objective = lambda model, data: -gpx.objectives.heteroscedastic_elbo(model, data)
+optimiser = ox.adam(1e-2)
+q_trained, history = gpx.fit(
+    model=q,
+    objective=objective,
+    train_data=train,
+    optim=optimiser,
+    num_iters=10000,
+    verbose=False,
+)
+
+loss_trace = jnp.asarray(history)
+print(f"Final regression ELBO: {-loss_trace[-1]:.3f}")
+
+# %% [markdown]
+# ## Prediction
+# After training we obtain posterior marginals for both latent functions. To make a
+# prediction we evaluate two quantities:
+# 1. The latent posterior over $f$ (mean and variance), which reflects uncertainty
+# in the latent function **prior** to observing noise.
+# 2. The marginal predictive over observations, which integrates out both $f$ and
+# $g$ to provide predictive intervals for future noisy measurements.
+# The helper method `likelihood.predict` performs the second integration for us.
+
+# %%
+# Predict on a dense grid.
+xtest = jnp.linspace(-0.1, 1.1, 200)[:, None]
+mf, vf, mg, vg = q_trained.predict(xtest)
+
+signal_pred, noise_pred = q_trained.predict_latents(xtest)
+predictive = likelihood.predict(signal_pred, noise_pred)
+
+fig, ax = plt.subplots()
+ax.plot(train.X, train.y, "o", label="Observations", alpha=0.5)
+ax.plot(xtest, mf, color="C0", label="Posterior mean")
+ax.fill_between(
+    xtest.squeeze(),
+    (mf.squeeze() - 2 * jnp.sqrt(vf.squeeze())).squeeze(),
+    (mf.squeeze() + 2 * jnp.sqrt(vf.squeeze())).squeeze(),
+    color="C0",
+    alpha=0.15,
+    label="±2 std (latent)",
+)
+ax.fill_between(
+    xtest.squeeze(),
+    predictive.mean - 2 * jnp.sqrt(jnp.diag(predictive.covariance_matrix)),
+    predictive.mean + 2 * jnp.sqrt(jnp.diag(predictive.covariance_matrix)),
+    color="C1",
+    alpha=0.15,
+    label="±2 std (observed)",
+)
+ax.set_xlabel("$x$")
+ax.set_ylabel("$y$")
+ax.legend(loc="upper left")
+ax.set_title("Heteroscedastic regression")
+
+# %% [markdown]
+# The latent intervals quantify epistemic uncertainty about $f$, whereas the broader
+# observed band adds the aleatoric noise predicted by $g$. The widening of the orange
+# band in the right half matches the ground-truth construction of the dataset.
+
+# %% [markdown]
+# ## Sparse Heteroscedastic Regression
+#
+# We now demonstrate how the aforementioned heteroscedastic approach can be extended
+# into sparse scenarios, thus offering more favourable scalability as the size of our
+# dataset grows. To achieve this we define inducing points for both the signal and
+# noise processes. Decoupling these grids allows us to focus modelling
+# capacity where each latent function varies the most. The synthetic dataset below
+# contains a smooth sinusoidal signal but exhibits a sharply peaked noise shock,
+# mimicking the situation where certain regions of the input space are far noisier
+# than others.
+
+# %%
+# Generate data
+key, x_key, noise_key = jr.split(key, 3)
+n = 300
+x = jr.uniform(x_key, (n, 1), minval=-2.0, maxval=2.0)
+signal = jnp.sin(2.0 * x)
+# Gaussian bump of noise
+noise_std = 0.1 + 0.5 * jnp.exp(-0.5 * ((x - 0.5) / 0.4) ** 2)
+y = signal + noise_std * jr.normal(noise_key, shape=(n, 1))
+data_adv = gpx.Dataset(X=x, y=y)
+
+# %% [markdown]
+# ### Model components
+# We again adopt RBF priors for both processes but now apply a `SoftplusTransform`
+# to the noise GP. This alternative map enforces positivity whilst avoiding the
+# heavier tails induced by the log-normal transform. The `HeteroscedasticGaussian`
+# likelihood seamlessly accepts the new transform.
+
+# %%
+# Define model components
+mean_prior = gpx.gps.Prior(
+    mean_function=gpx.mean_functions.Zero(),
+    kernel=gpx.kernels.RBF(),
+)
+noise_prior_adv = gpx.gps.Prior(
+    mean_function=gpx.mean_functions.Zero(),
+    kernel=gpx.kernels.RBF(),
+)
+likelihood_adv = HeteroscedasticGaussian(
+    num_datapoints=data_adv.n,
+    noise_prior=noise_prior_adv,
+    noise_transform=SoftplusTransform(),
+)
+posterior_adv = mean_prior * likelihood_adv
+
+# %%
+# Configure variational family
+# The signal requires a richer inducing set to capture its oscillations, whereas the
+# noise process can be summarised with fewer points because the burst is localised.
+z_signal = jnp.linspace(-2.0, 2.0, 30)[:, None]
+z_noise = jnp.linspace(-2.0, 2.0, 15)[:, None]
+
+# Use VariationalGaussianInit to pass specific configurations
+q_init_f = VariationalGaussianInit(inducing_inputs=z_signal)
+q_init_g = VariationalGaussianInit(inducing_inputs=z_noise)
+
+q_adv = HeteroscedasticVariationalFamily(
+    posterior=posterior_adv,
+    signal_init=q_init_f,
+    noise_init=q_init_g,
+)
+
+# %% [markdown]
+# The initialisation objects `VariationalGaussianInit` allow us to prescribe
+# different inducing grids and initial covariance structures for $f$ and $g$. This
+# flexibility is invaluable when working with large datasets where the latent
+# functions have markedly different smoothness properties.
+
+# %%
+# Optimize
+objective_adv = lambda model, data: -gpx.objectives.heteroscedastic_elbo(model, data)
+optimiser_adv = ox.adam(1e-2)
+q_adv_trained, _ = gpx.fit(
+    model=q_adv,
+    objective=objective_adv,
+    train_data=data_adv,
+    optim=optimiser_adv,
+    num_iters=8000,
+    verbose=False,
+)
+
+# %%
+# Plotting
+xtest = jnp.linspace(-2.2, 2.2, 200)[:, None]
+pred = q_adv_trained.predict(xtest)
+
+# Unpack the named tuple
+mf = pred.mean_f
+vf = pred.variance_f
+mg = pred.mean_g
+vg = pred.variance_g
+
+# Calculate total predictive variance
+# The likelihood expects the *latent* noise distribution to compute the predictive
+# but here we can just use the transformed expected variance for plotting.
+# For accurate predictive intervals, we should use likelihood.predict.
+signal_dist, noise_dist = q_adv_trained.predict_latents(xtest)
+predictive_dist = likelihood_adv.predict(signal_dist, noise_dist)
+predictive_mean = predictive_dist.mean
+predictive_std = jnp.sqrt(jnp.diag(predictive_dist.covariance_matrix))
+
+fig, ax = plt.subplots()
+ax.plot(x, y, "o", color="black", alpha=0.3, label="Data")
+ax.plot(xtest, mf, color="C0", label="Signal Mean")
+ax.fill_between(
+    xtest.squeeze(),
+    mf.squeeze() - 2 * jnp.sqrt(vf.squeeze()),
+    mf.squeeze() + 2 * jnp.sqrt(vf.squeeze()),
+    color="C0",
+    alpha=0.2,
+    label="Signal Uncertainty",
+)
+
+# Plot total uncertainty (signal + noise)
+ax.plot(xtest, predictive_mean, "--", color="C1", alpha=0.5)
+ax.fill_between(
+    xtest.squeeze(),
+    predictive_mean - 2 * predictive_std,
+    predictive_mean + 2 * predictive_std,
+    color="C1",
+    alpha=0.1,
+    label="Predictive Uncertainty (95%)",
+)
+
+ax.set_title("Heteroscedastic Regression with Custom Inducing Points")
+ax.legend(loc="upper left", fontsize="small")
+
+# %% [markdown]
+# ## Takeaways
+# - The heteroscedastic GP model couples two latent GPs, enabling separate control of
+# epistemic and aleatoric uncertainties.
+# - We support multiple positivity transforms for the noise process; the choice
+# affects the implied variance tails and should reflect prior beliefs.
+# - Inducing points for the signal and noise processes can be tuned independently to
+# balance computational budget against the local complexity of each function.
+# - The ELBO implementation automatically selects the tightest analytical bound
+# available, streamlining heteroscedastic inference workflows.
+
+# %% [markdown]
+# ## System configuration
+
+# %%
+# %reload_ext watermark
+# %watermark -n -u -v -iv -w -a 'Thomas Pinder'
{gpjax-0.13.3 → gpjax-0.13.4}/examples/regression.py

@@ -29,7 +29,6 @@ import matplotlib as mpl
 import matplotlib.pyplot as plt
 
 from examples.utils import (
-    clean_legend,
     use_mpl_style,
 )
 
@@ -129,26 +128,26 @@ prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)
 
 # %%
 # %% [markdown]
-prior_dist = prior.predict(xtest, return_covariance_type="dense")
-
-prior_mean = prior_dist.mean
-prior_std = prior_dist.variance
-samples = prior_dist.sample(key=key, sample_shape=(20,))
-
-
-fig, ax = plt.subplots()
-ax.plot(xtest, samples.T, alpha=0.5, color=cols[0], label="Prior samples")
-ax.plot(xtest, prior_mean, color=cols[1], label="Prior mean")
-ax.fill_between(
-    xtest.flatten(),
-    prior_mean - prior_std,
-    prior_mean + prior_std,
-    alpha=0.3,
-    color=cols[1],
-    label="Prior variance",
-)
-ax.legend(loc="best")
-ax = clean_legend(ax)
+# prior_dist = prior.predict(xtest, return_covariance_type="dense")
+#
+# prior_mean = prior_dist.mean
+# prior_std = prior_dist.variance
+# samples = prior_dist.sample(key=key, sample_shape=(20,))
+#
+#
+# fig, ax = plt.subplots()
+# ax.plot(xtest, samples.T, alpha=0.5, color=cols[0], label="Prior samples")
+# ax.plot(xtest, prior_mean, color=cols[1], label="Prior mean")
+# ax.fill_between(
+#     xtest.flatten(),
+#     prior_mean - prior_std,
+#     prior_mean + prior_std,
+#     alpha=0.3,
+#     color=cols[1],
+#     label="Prior variance",
+# )
+# ax.legend(loc="best")
+# ax = clean_legend(ax)
 
 # %% [markdown]
 # ## Constructing the posterior
@@ -217,13 +216,15 @@ print(-gpx.objectives.conjugate_mll(opt_posterior, D))
 # this, we use our defined `posterior` and `likelihood` at our test inputs to obtain
 # the predictive distribution as a `Distrax` multivariate Gaussian upon which `mean`
 # and `stddev` can be used to extract the predictive mean and standard deviatation.
-#
+#
 # We are only concerned here about the variance between the test points and themselves, so
 # we can just copute the diagonal version of the covariance. We enforce this by using
 # `return_covariance_type = "diagonal"` in the `predict` call.
 
 # %%
-latent_dist = opt_posterior.predict(
+latent_dist = opt_posterior.predict(
+    xtest, train_data=D, return_covariance_type="diagonal"
+)
 predictive_dist = opt_posterior.likelihood(latent_dist)
 
 predictive_mean = predictive_dist.mean
{gpjax-0.13.3 → gpjax-0.13.4}/gpjax/__init__.py

@@ -40,7 +40,7 @@ __license__ = "MIT"
 __description__ = "Gaussian processes in JAX and Flax"
 __url__ = "https://github.com/thomaspinder/GPJax"
 __contributors__ = "https://github.com/thomaspinder/GPJax/graphs/contributors"
-__version__ = "0.13.3"
+__version__ = "0.13.4"
 
 __all__ = [
     "gps",
{gpjax-0.13.3 → gpjax-0.13.4}/gpjax/citation.py

@@ -23,6 +23,7 @@ from gpjax.kernels import (
     Matern32,
     Matern52,
 )
+from gpjax.likelihoods import HeteroscedasticGaussian
 
 CitationType = Union[None, str, Dict[str, str]]
 
@@ -149,3 +150,15 @@ def _(tree) -> PaperCitation:
         booktitle="Advances in neural information processing systems",
         citation_type="article",
     )
+
+
+@cite.register(HeteroscedasticGaussian)
+def _(tree) -> PaperCitation:
+    return PaperCitation(
+        citation_key="lazaro2011variational",
+        authors="Lázaro-Gredilla, Miguel and Titsias, Michalis",
+        title="Variational heteroscedastic Gaussian process regression",
+        year="2011",
+        booktitle="Proceedings of the 28th International Conference on Machine Learning (ICML)",
+        citation_type="inproceedings",
+    )
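The hunk above registers `HeteroscedasticGaussian` with GPJax's citation dispatcher so that the Lázaro-Gredilla & Titsias (2011) reference is attached to the new likelihood. As a minimal usage sketch (not part of the release diff; it assumes `cite` is the dispatcher defined in gpjax/citation.py and reuses the constructor arguments shown in the new example notebook):

import gpjax as gpx
from gpjax.citation import cite
from gpjax.likelihoods import HeteroscedasticGaussian, LogNormalTransform

# Build a heteroscedastic likelihood as in examples/heteroscedastic_inference.py.
noise_prior = gpx.gps.Prior(
    mean_function=gpx.mean_functions.Zero(),
    kernel=gpx.kernels.RBF(),
)
likelihood = HeteroscedasticGaussian(
    num_datapoints=100,
    noise_prior=noise_prior,
    noise_transform=LogNormalTransform(),
)

# With the registration above, the dispatcher should resolve to the
# "lazaro2011variational" PaperCitation entry.
print(cite(likelihood))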
{gpjax-0.13.3 → gpjax-0.13.4}/gpjax/gps.py

@@ -32,8 +32,10 @@ from gpjax.distributions import GaussianDistribution
 from gpjax.kernels import RFF
 from gpjax.kernels.base import AbstractKernel
 from gpjax.likelihoods import (
+    AbstractHeteroscedasticLikelihood,
     AbstractLikelihood,
     Gaussian,
+    HeteroscedasticGaussian,
     NonGaussian,
 )
 from gpjax.linalg import (
@@ -62,6 +64,7 @@ M = tp.TypeVar("M", bound=AbstractMeanFunction)
 L = tp.TypeVar("L", bound=AbstractLikelihood)
 NGL = tp.TypeVar("NGL", bound=NonGaussian)
 GL = tp.TypeVar("GL", bound=Gaussian)
+HL = tp.TypeVar("HL", bound=AbstractHeteroscedasticLikelihood)
 
 
 class AbstractPrior(nnx.Module, tp.Generic[M, K]):
@@ -476,6 +479,22 @@ class AbstractPosterior(nnx.Module, tp.Generic[P, L]):
         raise NotImplementedError
 
 
+class LatentPosterior(AbstractPosterior[P, L]):
+    r"""A posterior shell used to expose prior structure without inference."""
+
+    def predict(
+        self,
+        test_inputs: Num[Array, "N D"],
+        train_data: Dataset,
+        *,
+        return_covariance_type: Literal["dense", "diagonal"] = "dense",
+    ) -> GaussianDistribution:
+        raise NotImplementedError(
+            "LatentPosteriors are a lightweight wrapper for priors and do not "
+            "implement predictive distributions. Use a variational family for inference."
+        )
+
+
 class ConjugatePosterior(AbstractPosterior[P, GL]):
     r"""A Conjuate Gaussian process posterior object.
 
@@ -839,6 +858,40 @@ class NonConjugatePosterior(AbstractPosterior[P, NGL]):
         return GaussianDistribution(jnp.atleast_1d(mean.squeeze()), cov)
 
 
+class HeteroscedasticPosterior(LatentPosterior[P, HL]):
+    r"""Posterior shell for heteroscedastic likelihoods.
+
+    The posterior retains both the signal and noise priors; inference is delegated
+    to variational families and specialised objectives.
+    """
+
+    def __init__(
+        self,
+        prior: AbstractPrior[M, K],
+        likelihood: HL,
+        jitter: float = 1e-6,
+    ):
+        if likelihood.noise_prior is None:
+            raise ValueError("Heteroscedastic likelihoods require a noise_prior.")
+        super().__init__(prior=prior, likelihood=likelihood, jitter=jitter)
+        self.noise_prior = likelihood.noise_prior
+        self.noise_posterior = LatentPosterior(
+            prior=self.noise_prior, likelihood=likelihood, jitter=jitter
+        )
+
+
+class ChainedPosterior(HeteroscedasticPosterior[P, HL]):
+    r"""Posterior routed for heteroscedastic likelihoods using chained bounds."""
+
+    def __init__(
+        self,
+        prior: AbstractPrior[M, K],
+        likelihood: HL,
+        jitter: float = 1e-6,
+    ):
+        super().__init__(prior=prior, likelihood=likelihood, jitter=jitter)
+
+
 #######################
 # Utils
 #######################
@@ -854,6 +907,18 @@ def construct_posterior( # noqa: F811
 ) -> NonConjugatePosterior[P, NGL]: ...
 
 
+@tp.overload
+def construct_posterior(  # noqa: F811
+    prior: P, likelihood: HeteroscedasticGaussian
+) -> HeteroscedasticPosterior[P, HeteroscedasticGaussian]: ...
+
+
+@tp.overload
+def construct_posterior(  # noqa: F811
+    prior: P, likelihood: AbstractHeteroscedasticLikelihood
+) -> ChainedPosterior[P, AbstractHeteroscedasticLikelihood]: ...
+
+
 def construct_posterior(prior, likelihood):  # noqa: F811
     r"""Utility function for constructing a posterior object from a prior and
     likelihood. The function will automatically select the correct posterior
@@ -873,6 +938,15 @@ def construct_posterior(prior, likelihood): # noqa: F811
     if isinstance(likelihood, Gaussian):
         return ConjugatePosterior(prior=prior, likelihood=likelihood)
 
+    if (
+        isinstance(likelihood, HeteroscedasticGaussian)
+        and likelihood.supports_tight_bound()
+    ):
+        return HeteroscedasticPosterior(prior=prior, likelihood=likelihood)
+
+    if isinstance(likelihood, AbstractHeteroscedasticLikelihood):
+        return ChainedPosterior(prior=prior, likelihood=likelihood)
+
     return NonConjugatePosterior(prior=prior, likelihood=likelihood)
 
 
@@ -911,7 +985,10 @@ __all__ = [
     "AbstractPrior",
     "Prior",
     "AbstractPosterior",
+    "LatentPosterior",
     "ConjugatePosterior",
     "NonConjugatePosterior",
+    "HeteroscedasticPosterior",
+    "ChainedPosterior",
     "construct_posterior",
 ]
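Taken together, the overloads and routing added to `construct_posterior` mean that pairing a prior with the new likelihood yields a `HeteroscedasticPosterior` (or a `ChainedPosterior` when the tight bound is unavailable). A minimal sketch of that routing (not part of the release diff; it assumes the constructor arguments used in the new example notebook and the `supports_tight_bound` behaviour implied by the hunk above):

import gpjax as gpx
from gpjax.gps import HeteroscedasticPosterior, construct_posterior
from gpjax.likelihoods import HeteroscedasticGaussian, LogNormalTransform

signal_prior = gpx.gps.Prior(
    mean_function=gpx.mean_functions.Zero(), kernel=gpx.kernels.RBF()
)
noise_prior = gpx.gps.Prior(
    mean_function=gpx.mean_functions.Zero(), kernel=gpx.kernels.RBF()
)
likelihood = HeteroscedasticGaussian(
    num_datapoints=200,
    noise_prior=noise_prior,
    noise_transform=LogNormalTransform(),
)

# Explicit construction; the `signal_prior * likelihood` syntax used in the
# example notebook routes through the same factory function.
posterior = construct_posterior(signal_prior, likelihood)
print(isinstance(posterior, HeteroscedasticPosterior))  # expected: True when the tight bound applies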