gpjax 0.13.3__tar.gz → 0.13.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (165)
  1. {gpjax-0.13.3 → gpjax-0.13.4}/.gitignore +5 -0
  2. {gpjax-0.13.3 → gpjax-0.13.4}/PKG-INFO +2 -2
  3. {gpjax-0.13.3 → gpjax-0.13.4}/README.md +1 -1
  4. gpjax-0.13.4/examples/heteroscedastic_inference.py +389 -0
  5. {gpjax-0.13.3 → gpjax-0.13.4}/examples/regression.py +24 -23
  6. {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/__init__.py +1 -1
  7. {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/citation.py +13 -0
  8. {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/gps.py +77 -0
  9. {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/likelihoods.py +234 -0
  10. {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/mean_functions.py +2 -2
  11. {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/objectives.py +56 -1
  12. {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/parameters.py +8 -1
  13. {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/variational_families.py +129 -0
  14. {gpjax-0.13.3 → gpjax-0.13.4}/mkdocs.yml +1 -0
  15. {gpjax-0.13.3 → gpjax-0.13.4}/pyproject.toml +2 -1
  16. {gpjax-0.13.3 → gpjax-0.13.4}/tests/conftest.py +7 -0
  17. {gpjax-0.13.3 → gpjax-0.13.4}/tests/integration_tests.py +9 -2
  18. {gpjax-0.13.3 → gpjax-0.13.4}/tests/test_citations.py +16 -0
  19. gpjax-0.13.4/tests/test_heteroscedastic.py +407 -0
  20. {gpjax-0.13.3 → gpjax-0.13.4}/tests/test_mean_functions.py +16 -1
  21. {gpjax-0.13.3 → gpjax-0.13.4}/uv.lock +23 -0
  22. gpjax-0.13.3/.github/workflows/pr_greeting.yml +0 -62
  23. {gpjax-0.13.3 → gpjax-0.13.4}/.github/CODE_OF_CONDUCT.md +0 -0
  24. {gpjax-0.13.3 → gpjax-0.13.4}/.github/FUNDING.yml +0 -0
  25. {gpjax-0.13.3 → gpjax-0.13.4}/.github/ISSUE_TEMPLATE/01_BUG_REPORT.md +0 -0
  26. {gpjax-0.13.3 → gpjax-0.13.4}/.github/ISSUE_TEMPLATE/02_FEATURE_REQUEST.md +0 -0
  27. {gpjax-0.13.3 → gpjax-0.13.4}/.github/ISSUE_TEMPLATE/03_CODEBASE_IMPROVEMENT.md +0 -0
  28. {gpjax-0.13.3 → gpjax-0.13.4}/.github/ISSUE_TEMPLATE/04_DOCS_IMPROVEMENT.md +0 -0
  29. {gpjax-0.13.3 → gpjax-0.13.4}/.github/ISSUE_TEMPLATE/config.yml +0 -0
  30. {gpjax-0.13.3 → gpjax-0.13.4}/.github/codecov.yml +0 -0
  31. {gpjax-0.13.3 → gpjax-0.13.4}/.github/commitlint.config.js +0 -0
  32. {gpjax-0.13.3 → gpjax-0.13.4}/.github/dependabot.yml +0 -0
  33. {gpjax-0.13.3 → gpjax-0.13.4}/.github/labeler.yml +0 -0
  34. {gpjax-0.13.3 → gpjax-0.13.4}/.github/labels.yml +0 -0
  35. {gpjax-0.13.3 → gpjax-0.13.4}/.github/pull_request_template.md +0 -0
  36. {gpjax-0.13.3 → gpjax-0.13.4}/.github/release-drafter.yml +0 -0
  37. {gpjax-0.13.3 → gpjax-0.13.4}/.github/workflows/auto-label.yml +0 -0
  38. {gpjax-0.13.3 → gpjax-0.13.4}/.github/workflows/build_docs.yml +0 -0
  39. {gpjax-0.13.3 → gpjax-0.13.4}/.github/workflows/commit-lint.yml +0 -0
  40. {gpjax-0.13.3 → gpjax-0.13.4}/.github/workflows/integration.yml +0 -0
  41. {gpjax-0.13.3 → gpjax-0.13.4}/.github/workflows/release.yml +0 -0
  42. {gpjax-0.13.3 → gpjax-0.13.4}/.github/workflows/ruff.yml +0 -0
  43. {gpjax-0.13.3 → gpjax-0.13.4}/.github/workflows/security-analysis.yml +0 -0
  44. {gpjax-0.13.3 → gpjax-0.13.4}/.github/workflows/test_docs.yml +0 -0
  45. {gpjax-0.13.3 → gpjax-0.13.4}/.github/workflows/tests.yml +0 -0
  46. {gpjax-0.13.3 → gpjax-0.13.4}/CITATION.bib +0 -0
  47. {gpjax-0.13.3 → gpjax-0.13.4}/LICENSE.txt +0 -0
  48. {gpjax-0.13.3 → gpjax-0.13.4}/Makefile +0 -0
  49. {gpjax-0.13.3 → gpjax-0.13.4}/docs/CODE_OF_CONDUCT.md +0 -0
  50. {gpjax-0.13.3 → gpjax-0.13.4}/docs/GOVERNANCE.md +0 -0
  51. {gpjax-0.13.3 → gpjax-0.13.4}/docs/contributing.md +0 -0
  52. {gpjax-0.13.3 → gpjax-0.13.4}/docs/design.md +0 -0
  53. {gpjax-0.13.3 → gpjax-0.13.4}/docs/index.md +0 -0
  54. {gpjax-0.13.3 → gpjax-0.13.4}/docs/index.rst +0 -0
  55. {gpjax-0.13.3 → gpjax-0.13.4}/docs/installation.md +0 -0
  56. {gpjax-0.13.3 → gpjax-0.13.4}/docs/javascripts/katex.js +0 -0
  57. {gpjax-0.13.3 → gpjax-0.13.4}/docs/refs.bib +0 -0
  58. {gpjax-0.13.3 → gpjax-0.13.4}/docs/scripts/gen_examples.py +0 -0
  59. {gpjax-0.13.3 → gpjax-0.13.4}/docs/scripts/gen_pages.py +0 -0
  60. {gpjax-0.13.3 → gpjax-0.13.4}/docs/scripts/notebook_converter.py +0 -0
  61. {gpjax-0.13.3 → gpjax-0.13.4}/docs/scripts/sharp_bits_figure.py +0 -0
  62. {gpjax-0.13.3 → gpjax-0.13.4}/docs/sharp_bits.md +0 -0
  63. {gpjax-0.13.3 → gpjax-0.13.4}/docs/static/GP.pdf +0 -0
  64. {gpjax-0.13.3 → gpjax-0.13.4}/docs/static/GP.svg +0 -0
  65. {gpjax-0.13.3 → gpjax-0.13.4}/docs/static/bijector_figure.svg +0 -0
  66. {gpjax-0.13.3 → gpjax-0.13.4}/docs/static/css/gpjax_theme.css +0 -0
  67. {gpjax-0.13.3 → gpjax-0.13.4}/docs/static/favicon.ico +0 -0
  68. {gpjax-0.13.3 → gpjax-0.13.4}/docs/static/gpjax.mplstyle +0 -0
  69. {gpjax-0.13.3 → gpjax-0.13.4}/docs/static/gpjax_logo.pdf +0 -0
  70. {gpjax-0.13.3 → gpjax-0.13.4}/docs/static/gpjax_logo.svg +0 -0
  71. {gpjax-0.13.3 → gpjax-0.13.4}/docs/static/jaxkern/lato.ttf +0 -0
  72. {gpjax-0.13.3 → gpjax-0.13.4}/docs/static/jaxkern/logo.png +0 -0
  73. {gpjax-0.13.3 → gpjax-0.13.4}/docs/static/jaxkern/logo.svg +0 -0
  74. {gpjax-0.13.3 → gpjax-0.13.4}/docs/static/jaxkern/main.py +0 -0
  75. {gpjax-0.13.3 → gpjax-0.13.4}/docs/static/step_size_figure.png +0 -0
  76. {gpjax-0.13.3 → gpjax-0.13.4}/docs/static/step_size_figure.svg +0 -0
  77. {gpjax-0.13.3 → gpjax-0.13.4}/docs/stylesheets/extra.css +0 -0
  78. {gpjax-0.13.3 → gpjax-0.13.4}/docs/stylesheets/permalinks.css +0 -0
  79. {gpjax-0.13.3 → gpjax-0.13.4}/examples/backend.py +0 -0
  80. {gpjax-0.13.3 → gpjax-0.13.4}/examples/barycentres/barycentre_gp.gif +0 -0
  81. {gpjax-0.13.3 → gpjax-0.13.4}/examples/barycentres.py +0 -0
  82. {gpjax-0.13.3 → gpjax-0.13.4}/examples/classification.py +0 -0
  83. {gpjax-0.13.3 → gpjax-0.13.4}/examples/collapsed_vi.py +0 -0
  84. {gpjax-0.13.3 → gpjax-0.13.4}/examples/constructing_new_kernels.py +0 -0
  85. {gpjax-0.13.3 → gpjax-0.13.4}/examples/data/max_tempeature_switzerland.csv +0 -0
  86. {gpjax-0.13.3 → gpjax-0.13.4}/examples/data/yacht_hydrodynamics.data +0 -0
  87. {gpjax-0.13.3 → gpjax-0.13.4}/examples/deep_kernels.py +0 -0
  88. {gpjax-0.13.3 → gpjax-0.13.4}/examples/gpjax.mplstyle +0 -0
  89. {gpjax-0.13.3 → gpjax-0.13.4}/examples/graph_kernels.py +0 -0
  90. {gpjax-0.13.3 → gpjax-0.13.4}/examples/intro_to_gps/decomposed_mll.png +0 -0
  91. {gpjax-0.13.3 → gpjax-0.13.4}/examples/intro_to_gps/generating_process.png +0 -0
  92. {gpjax-0.13.3 → gpjax-0.13.4}/examples/intro_to_gps.py +0 -0
  93. {gpjax-0.13.3 → gpjax-0.13.4}/examples/intro_to_kernels.py +0 -0
  94. {gpjax-0.13.3 → gpjax-0.13.4}/examples/likelihoods_guide.py +0 -0
  95. {gpjax-0.13.3 → gpjax-0.13.4}/examples/oceanmodelling.py +0 -0
  96. {gpjax-0.13.3 → gpjax-0.13.4}/examples/poisson.py +0 -0
  97. {gpjax-0.13.3 → gpjax-0.13.4}/examples/uncollapsed_vi.py +0 -0
  98. {gpjax-0.13.3 → gpjax-0.13.4}/examples/utils.py +0 -0
  99. {gpjax-0.13.3 → gpjax-0.13.4}/examples/yacht.py +0 -0
  100. {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/dataset.py +0 -0
  101. {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/distributions.py +0 -0
  102. {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/fit.py +0 -0
  103. {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/integrators.py +0 -0
  104. {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/kernels/__init__.py +0 -0
  105. {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/kernels/approximations/__init__.py +0 -0
  106. {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/kernels/approximations/rff.py +0 -0
  107. {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/kernels/base.py +0 -0
  108. {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/kernels/computations/__init__.py +0 -0
  109. {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/kernels/computations/base.py +0 -0
  110. {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/kernels/computations/basis_functions.py +0 -0
  111. {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/kernels/computations/constant_diagonal.py +0 -0
  112. {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/kernels/computations/dense.py +0 -0
  113. {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/kernels/computations/diagonal.py +0 -0
  114. {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/kernels/computations/eigen.py +0 -0
  115. {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/kernels/non_euclidean/__init__.py +0 -0
  116. {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/kernels/non_euclidean/graph.py +0 -0
  117. {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/kernels/non_euclidean/utils.py +0 -0
  118. {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/kernels/nonstationary/__init__.py +0 -0
  119. {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/kernels/nonstationary/arccosine.py +0 -0
  120. {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/kernels/nonstationary/linear.py +0 -0
  121. {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/kernels/nonstationary/polynomial.py +0 -0
  122. {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/kernels/stationary/__init__.py +0 -0
  123. {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/kernels/stationary/base.py +0 -0
  124. {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/kernels/stationary/matern12.py +0 -0
  125. {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/kernels/stationary/matern32.py +0 -0
  126. {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/kernels/stationary/matern52.py +0 -0
  127. {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/kernels/stationary/periodic.py +0 -0
  128. {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/kernels/stationary/powered_exponential.py +0 -0
  129. {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/kernels/stationary/rational_quadratic.py +0 -0
  130. {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/kernels/stationary/rbf.py +0 -0
  131. {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/kernels/stationary/utils.py +0 -0
  132. {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/kernels/stationary/white.py +0 -0
  133. {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/linalg/__init__.py +0 -0
  134. {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/linalg/operations.py +0 -0
  135. {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/linalg/operators.py +0 -0
  136. {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/linalg/utils.py +0 -0
  137. {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/numpyro_extras.py +0 -0
  138. {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/scan.py +0 -0
  139. {gpjax-0.13.3 → gpjax-0.13.4}/gpjax/typing.py +0 -0
  140. {gpjax-0.13.3 → gpjax-0.13.4}/static/CONTRIBUTING.md +0 -0
  141. {gpjax-0.13.3 → gpjax-0.13.4}/static/paper.bib +0 -0
  142. {gpjax-0.13.3 → gpjax-0.13.4}/static/paper.md +0 -0
  143. {gpjax-0.13.3 → gpjax-0.13.4}/static/paper.pdf +0 -0
  144. {gpjax-0.13.3 → gpjax-0.13.4}/tests/__init__.py +0 -0
  145. {gpjax-0.13.3 → gpjax-0.13.4}/tests/test_dataset.py +0 -0
  146. {gpjax-0.13.3 → gpjax-0.13.4}/tests/test_fit.py +0 -0
  147. {gpjax-0.13.3 → gpjax-0.13.4}/tests/test_gaussian_distribution.py +0 -0
  148. {gpjax-0.13.3 → gpjax-0.13.4}/tests/test_gps.py +0 -0
  149. {gpjax-0.13.3 → gpjax-0.13.4}/tests/test_imports.py +0 -0
  150. {gpjax-0.13.3 → gpjax-0.13.4}/tests/test_integrators.py +0 -0
  151. {gpjax-0.13.3 → gpjax-0.13.4}/tests/test_kernels/__init__.py +0 -0
  152. {gpjax-0.13.3 → gpjax-0.13.4}/tests/test_kernels/test_approximations.py +0 -0
  153. {gpjax-0.13.3 → gpjax-0.13.4}/tests/test_kernels/test_base.py +0 -0
  154. {gpjax-0.13.3 → gpjax-0.13.4}/tests/test_kernels/test_computation.py +0 -0
  155. {gpjax-0.13.3 → gpjax-0.13.4}/tests/test_kernels/test_non_euclidean.py +0 -0
  156. {gpjax-0.13.3 → gpjax-0.13.4}/tests/test_kernels/test_nonstationary.py +0 -0
  157. {gpjax-0.13.3 → gpjax-0.13.4}/tests/test_kernels/test_stationary.py +0 -0
  158. {gpjax-0.13.3 → gpjax-0.13.4}/tests/test_kernels/test_utils.py +0 -0
  159. {gpjax-0.13.3 → gpjax-0.13.4}/tests/test_likelihoods.py +0 -0
  160. {gpjax-0.13.3 → gpjax-0.13.4}/tests/test_linalg.py +0 -0
  161. {gpjax-0.13.3 → gpjax-0.13.4}/tests/test_markdown.py +0 -0
  162. {gpjax-0.13.3 → gpjax-0.13.4}/tests/test_numpyro_extras.py +0 -0
  163. {gpjax-0.13.3 → gpjax-0.13.4}/tests/test_objectives.py +0 -0
  164. {gpjax-0.13.3 → gpjax-0.13.4}/tests/test_parameters.py +0 -0
  165. {gpjax-0.13.3 → gpjax-0.13.4}/tests/test_variational_families.py +0 -0
@@ -153,3 +153,8 @@ node_modules/
 
  docs/api
  docs/_examples
+ local_libs/
+ local_papers/
+ GEMINI.md
+ AGENTS.md
+ plans/
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: gpjax
- Version: 0.13.3
+ Version: 0.13.4
  Summary: Gaussian processes in JAX.
  Project-URL: Documentation, https://docs.jaxgaussianprocesses.com/
  Project-URL: Issues, https://github.com/thomaspinder/GPJax/issues
@@ -141,7 +141,7 @@ GPJax into the package it is today.
  > - [**Laplace Approximation**](https://docs.jaxgaussianprocesses.com/_examples/classification/#laplace-approximation)
  > - [**Inference on Non-Euclidean Spaces**](https://docs.jaxgaussianprocesses.com/_examples/constructing_new_kernels/#custom-kernel)
  > - [**Inference on Graphs**](https://docs.jaxgaussianprocesses.com/_examples/graph_kernels/)
- > - [**Pathwise Sampling**](https://docs.jaxgaussianprocesses.com/_examples/spatial/)
+ > - [**Heteroscedastic Inference**](https://docs.jaxgaussianprocesses.com/_examples/heteroscedastic_inference/)
  > - [**Learning Gaussian Process Barycentres**](https://docs.jaxgaussianprocesses.com/_examples/barycentres/)
  > - [**Deep Kernel Regression**](https://docs.jaxgaussianprocesses.com/_examples/deep_kernels/)
  > - [**Poisson Regression**](https://docs.jaxgaussianprocesses.com/_examples/poisson/)
@@ -70,7 +70,7 @@ GPJax into the package it is today.
  > - [**Laplace Approximation**](https://docs.jaxgaussianprocesses.com/_examples/classification/#laplace-approximation)
  > - [**Inference on Non-Euclidean Spaces**](https://docs.jaxgaussianprocesses.com/_examples/constructing_new_kernels/#custom-kernel)
  > - [**Inference on Graphs**](https://docs.jaxgaussianprocesses.com/_examples/graph_kernels/)
- > - [**Pathwise Sampling**](https://docs.jaxgaussianprocesses.com/_examples/spatial/)
+ > - [**Heteroscedastic Inference**](https://docs.jaxgaussianprocesses.com/_examples/heteroscedastic_inference/)
  > - [**Learning Gaussian Process Barycentres**](https://docs.jaxgaussianprocesses.com/_examples/barycentres/)
  > - [**Deep Kernel Regression**](https://docs.jaxgaussianprocesses.com/_examples/deep_kernels/)
  > - [**Poisson Regression**](https://docs.jaxgaussianprocesses.com/_examples/poisson/)
@@ -0,0 +1,389 @@
+ # -*- coding: utf-8 -*-
+ # ---
+ # jupyter:
+ #   jupytext:
+ #     cell_metadata_filter: -all
+ #     custom_cell_magics: kql
+ #     text_representation:
+ #       extension: .py
+ #       format_name: percent
+ #       format_version: '1.3'
+ #       jupytext_version: 1.17.3
+ #   kernelspec:
+ #     display_name: .venv
+ #     language: python
+ #     name: python3
+ # ---
+
+ # %% [markdown]
+ # # Heteroscedastic inference for regression and classification
+ #
+ # This notebook shows how to fit a heteroscedastic Gaussian process (GP) that
+ # allows one to perform regression where the noise is non-constant, or
+ # input-dependent.
+ #
+ #
+ # ## Background
+ # A heteroscedastic GP couples two latent functions:
+ # - A **signal GP** $f(\cdot)$ for the mean response.
+ # - A **noise GP** $g(\cdot)$ that maps to a positive variance
+ # $\sigma^2(x) = \phi(g(x))$ via a positivity transform $\phi$ (typically
+ # ${\rm exp}$ or ${\rm softplus}$). Intuitively, we are introducing a pair of GPs:
+ # one to model the latent mean, and a second to model the log-noise variance. This
+ # is in direct contrast to a
+ # [homoscedastic GP](https://docs.jaxgaussianprocesses.com/_examples/regression/)
+ # where we learn a constant value for the noise.
+ #
+ # In the Gaussian case, the observed response follows
+ # $$y \mid f, g \sim \mathcal{N}(f(x), \sigma^2(x)).$$
+ # Variational inference works with independent posteriors $q(f)q(g)$, combining the
+ # moments of each into an ELBO. For non-Gaussian likelihoods the same structure
+ # remains; only the expected log-likelihood changes.
+
+ # %%
+ from jax import config
+ import jax.numpy as jnp
+ import jax.random as jr
+ import matplotlib as mpl
+ import matplotlib.pyplot as plt
+ import optax as ox
+
+ from examples.utils import use_mpl_style
+ import gpjax as gpx
+ from gpjax.likelihoods import (
+     HeteroscedasticGaussian,
+     LogNormalTransform,
+     SoftplusTransform,
+ )
+ from gpjax.variational_families import (
+     HeteroscedasticVariationalFamily,
+     VariationalGaussianInit,
+ )
+
+ # Enable Float64 for stable linear algebra.
+ config.update("jax_enable_x64", True)
+
+
+ use_mpl_style()
+ key = jr.key(123)
+ cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"]
+
+
+ # %% [markdown]
+ # ## Dataset simulation
+ # We simulate a dataset whose mean and noise levels vary with
+ # the input. We sample inputs $x \sim \mathcal{U}(0, 1)$ and define the
+ # latent signal to be
+ # $$f(x) = (x - 0.5)^2 + 0.05;$$
+ # a smooth bowl-shaped curve. The observation standard deviation is chosen to be
+ # proportional to the signal,
+ # $$\sigma(x) = 0.5\,f(x),$$
+ # which yields the heteroscedastic generative model
+ # $$y \mid x \sim \mathcal{N}\!\big(f(x), \sigma^2(x)\big).$$
+ # This construction makes the noise small near the minimum of the bowl and much
+ # larger in the tails. We also create a dense test grid that we shall use later for
+ # visualising posterior fits and predictive uncertainty.
+
+ # %%
+ # Create data with input-dependent variance.
+ key, x_key, noise_key = jr.split(key, 3)
+ n = 200
+ x = jr.uniform(x_key, (n, 1), minval=0.0, maxval=1.0)
+ signal = (x - 0.5) ** 2 + 0.05
+ noise_scale = 0.5 * signal
+ noise = noise_scale * jr.normal(noise_key, shape=(n, 1))
+ y = signal + noise
+ train = gpx.Dataset(X=x, y=y)
+
+ xtest = jnp.linspace(-0.1, 1.1, 200)[:, None]
+ signal_test = (xtest - 0.5) ** 2 + 0.05
+ noise_scale_test = 0.5 * signal_test
+ noise_test = noise_scale_test * jr.normal(noise_key, shape=(200, 1))
+ ytest = signal_test + noise_test
+
+ fig, ax = plt.subplots()
+ ax.plot(x, y, "o", label="Observations", alpha=0.7, color=cols[0])
+ ax.plot(xtest, signal_test, label="Signal", alpha=0.7, color=cols[1])
+ ax.plot(xtest, noise_scale_test, label="Noise scale", alpha=0.7, color=cols[2])
+ ax.set_xlabel("$x$")
+ ax.set_ylabel("$y$")
+ ax.legend(loc="upper left")
+
+ # %% [markdown]
+ # For a homoscedastic baseline, compare this figure with the
+ # [Gaussian process regression notebook](https://docs.jaxgaussianprocesses.com/_examples/regression/)
+ # (`examples/regression.py`), where a single latent GP is paired with constant
+ # observation noise.
+
+ # %% [markdown]
+ # ## Prior specification
+ # We place independent Gaussian process priors on the signal and noise processes:
+ # $$f \sim \mathcal{GP}\big(0, k_f\big), \qquad g \sim \mathcal{GP}\big(0, k_g\big),$$
+ # where $k_f$ and $k_g$ are stationary squared-exponential kernels with unit
+ # variance and lengthscale of one. The noise process $g$ is mapped to the variance
+ # via the log-normal transform `LogNormalTransform`, giving
+ # $\sigma^2(x) = \exp\big(g(x)\big)$. The joint prior over $(f, g)$ combines with
+ # the heteroscedastic Gaussian likelihood,
+ # $$p(\mathbf{y} \mid f, g) = \prod_{i=1}^n
+ # \mathcal{N}\!\big(y_i \mid f(x_i), \exp(g(x_i))\big),$$
+ # to form the posterior target that we shall approximate variationally. The product
+ # syntax `signal_prior * likelihood` used below constructs this augmented GP model.
+
+ # %%
+ # Signal and noise priors.
+ signal_prior = gpx.gps.Prior(
+     mean_function=gpx.mean_functions.Zero(),
+     kernel=gpx.kernels.RBF(),
+ )
+ noise_prior = gpx.gps.Prior(
+     mean_function=gpx.mean_functions.Zero(),
+     kernel=gpx.kernels.RBF(),
+ )
+ likelihood = HeteroscedasticGaussian(
+     num_datapoints=train.n,
+     noise_prior=noise_prior,
+     noise_transform=LogNormalTransform(),
+ )
+ posterior = signal_prior * likelihood
+
+ # Variational family over both processes.
+ z = jnp.linspace(-3.2, 3.2, 25)[:, None]
+ q = HeteroscedasticVariationalFamily(
+     posterior=posterior,
+     inducing_inputs=z,
+     inducing_inputs_g=z,
+ )
+
+ # %% [markdown]
+ # The variational family introduces inducing variables for both latent functions,
+ # located at the set $Z = \{z_m\}_{m=1}^M$. These inducing variables summarise the
+ # infinite-dimensional GP priors in terms of multivariate Gaussian parameters.
+ # Optimising the evidence lower bound (ELBO) corresponds to adjusting the means and
+ # covariances of the variational posteriors $q(f)$ and $q(g)$ so that they best
+ # explain the observed data whilst remaining close to the prior. For a deeper look at
+ # these constructions in the homoscedastic setting, refer to the
+ # [Sparse Gaussian Process Regression](https://docs.jaxgaussianprocesses.com/_examples/collapsed_vi/)
+ # (`examples/collapsed_vi.py`) and
+ # [Sparse Stochastic Variational Inference](https://docs.jaxgaussianprocesses.com/_examples/uncollapsed_vi/)
+ # (`examples/uncollapsed_vi.py`) notebooks.
+
+ # %% [markdown]
+ # ### Optimisation
+ # With the model specified, we minimise the negative ELBO,
+ # $$\mathcal{L} = \mathbb{E}_{q(f)q(g)}\!\big[\log p(\mathbf{y}\mid f, g)\big]
+ # - \mathrm{KL}\!\left[q(f) \,\|\, p(f)\right]
+ # - \mathrm{KL}\!\left[q(g) \,\|\, p(g)\right],$$
+ # using the Adam optimiser. GPJax automatically selects the tight bound of
+ # Lázaro-Gredilla & Titsias (2011) when the likelihood is Gaussian, yielding an
+ # analytically tractable expectation over the latent noise process. The resulting
+ # optimisation iteratively updates the inducing posteriors for both latent GPs.
+
+ # %%
+ # Optimise the heteroscedastic ELBO (selects tighter bound).
+ objective = lambda model, data: -gpx.objectives.heteroscedastic_elbo(model, data)
+ optimiser = ox.adam(1e-2)
+ q_trained, history = gpx.fit(
+     model=q,
+     objective=objective,
+     train_data=train,
+     optim=optimiser,
+     num_iters=10000,
+     verbose=False,
+ )
+
+ loss_trace = jnp.asarray(history)
+ print(f"Final regression ELBO: {-loss_trace[-1]:.3f}")
+
+ # %% [markdown]
+ # ## Prediction
+ # After training we obtain posterior marginals for both latent functions. To make a
+ # prediction we evaluate two quantities:
+ # 1. The latent posterior over $f$ (mean and variance), which reflects uncertainty
+ # in the latent function **prior** to observing noise.
+ # 2. The marginal predictive over observations, which integrates out both $f$ and
+ # $g$ to provide predictive intervals for future noisy measurements.
+ # The helper method `likelihood.predict` performs the second integration for us.
+
+ # %%
+ # Predict on a dense grid.
+ xtest = jnp.linspace(-0.1, 1.1, 200)[:, None]
+ mf, vf, mg, vg = q_trained.predict(xtest)
+
+ signal_pred, noise_pred = q_trained.predict_latents(xtest)
+ predictive = likelihood.predict(signal_pred, noise_pred)
+
+ fig, ax = plt.subplots()
+ ax.plot(train.X, train.y, "o", label="Observations", alpha=0.5)
+ ax.plot(xtest, mf, color="C0", label="Posterior mean")
+ ax.fill_between(
+     xtest.squeeze(),
+     (mf.squeeze() - 2 * jnp.sqrt(vf.squeeze())).squeeze(),
+     (mf.squeeze() + 2 * jnp.sqrt(vf.squeeze())).squeeze(),
+     color="C0",
+     alpha=0.15,
+     label="±2 std (latent)",
+ )
+ ax.fill_between(
+     xtest.squeeze(),
+     predictive.mean - 2 * jnp.sqrt(jnp.diag(predictive.covariance_matrix)),
+     predictive.mean + 2 * jnp.sqrt(jnp.diag(predictive.covariance_matrix)),
+     color="C1",
+     alpha=0.15,
+     label="±2 std (observed)",
+ )
+ ax.set_xlabel("$x$")
+ ax.set_ylabel("$y$")
+ ax.legend(loc="upper left")
+ ax.set_title("Heteroscedastic regression")
+
+ # %% [markdown]
+ # The latent intervals quantify epistemic uncertainty about $f$, whereas the broader
+ # observed band adds the aleatoric noise predicted by $g$. The widening of the orange
+ # band in the right half matches the ground-truth construction of the dataset.
+
+ # %% [markdown]
+ # ## Sparse Heteroscedastic Regression
+ #
+ # We now demonstrate how the aforementioned heteroscedastic approach can be extended
+ # to sparse scenarios, thus offering more favourable scalability as the size of our
+ # dataset grows. To achieve this we define inducing points for both the signal and
+ # noise processes. Decoupling these grids allows us to focus modelling
+ # capacity where each latent function varies the most. The synthetic dataset below
+ # contains a smooth sinusoidal signal but exhibits a sharply peaked noise burst,
+ # mimicking the situation where certain regions of the input space are far noisier
+ # than others.
+
+ # %%
+ # Generate data
+ key, x_key, noise_key = jr.split(key, 3)
+ n = 300
+ x = jr.uniform(x_key, (n, 1), minval=-2.0, maxval=2.0)
+ signal = jnp.sin(2.0 * x)
+ # Gaussian bump of noise
+ noise_std = 0.1 + 0.5 * jnp.exp(-0.5 * ((x - 0.5) / 0.4) ** 2)
+ y = signal + noise_std * jr.normal(noise_key, shape=(n, 1))
+ data_adv = gpx.Dataset(X=x, y=y)
+
+ # %% [markdown]
+ # ### Model components
+ # We again adopt RBF priors for both processes but now apply a `SoftplusTransform`
+ # to the noise GP. This alternative map enforces positivity whilst avoiding the
+ # heavier tails induced by the log-normal transform. The `HeteroscedasticGaussian`
+ # likelihood seamlessly accepts the new transform.
+
+ # %%
+ # Define model components
+ mean_prior = gpx.gps.Prior(
+     mean_function=gpx.mean_functions.Zero(),
+     kernel=gpx.kernels.RBF(),
+ )
+ noise_prior_adv = gpx.gps.Prior(
+     mean_function=gpx.mean_functions.Zero(),
+     kernel=gpx.kernels.RBF(),
+ )
+ likelihood_adv = HeteroscedasticGaussian(
+     num_datapoints=data_adv.n,
+     noise_prior=noise_prior_adv,
+     noise_transform=SoftplusTransform(),
+ )
+ posterior_adv = mean_prior * likelihood_adv
+
+ # %%
+ # Configure variational family
+ # The signal requires a richer inducing set to capture its oscillations, whereas the
+ # noise process can be summarised with fewer points because the burst is localised.
+ z_signal = jnp.linspace(-2.0, 2.0, 30)[:, None]
+ z_noise = jnp.linspace(-2.0, 2.0, 15)[:, None]
+
+ # Use VariationalGaussianInit to pass specific configurations
+ q_init_f = VariationalGaussianInit(inducing_inputs=z_signal)
+ q_init_g = VariationalGaussianInit(inducing_inputs=z_noise)
+
+ q_adv = HeteroscedasticVariationalFamily(
+     posterior=posterior_adv,
+     signal_init=q_init_f,
+     noise_init=q_init_g,
+ )
+
+ # %% [markdown]
+ # The initialisation objects `VariationalGaussianInit` allow us to prescribe
+ # different inducing grids and initial covariance structures for $f$ and $g$. This
+ # flexibility is invaluable when working with large datasets where the latent
+ # functions have markedly different smoothness properties.
+
+ # %%
+ # Optimize
+ objective_adv = lambda model, data: -gpx.objectives.heteroscedastic_elbo(model, data)
+ optimiser_adv = ox.adam(1e-2)
+ q_adv_trained, _ = gpx.fit(
+     model=q_adv,
+     objective=objective_adv,
+     train_data=data_adv,
+     optim=optimiser_adv,
+     num_iters=8000,
+     verbose=False,
+ )
+
+ # %%
+ # Plotting
+ xtest = jnp.linspace(-2.2, 2.2, 200)[:, None]
+ pred = q_adv_trained.predict(xtest)
+
+ # Unpack the named tuple
+ mf = pred.mean_f
+ vf = pred.variance_f
+ mg = pred.mean_g
+ vg = pred.variance_g
+
+ # Calculate total predictive variance
+ # The transformed expected variance from g could be used directly for plotting,
+ # but for accurate predictive intervals we use likelihood.predict, which
+ # integrates over the latent noise distribution.
+ signal_dist, noise_dist = q_adv_trained.predict_latents(xtest)
+ predictive_dist = likelihood_adv.predict(signal_dist, noise_dist)
+ predictive_mean = predictive_dist.mean
+ predictive_std = jnp.sqrt(jnp.diag(predictive_dist.covariance_matrix))
+
+ fig, ax = plt.subplots()
+ ax.plot(x, y, "o", color="black", alpha=0.3, label="Data")
+ ax.plot(xtest, mf, color="C0", label="Signal Mean")
+ ax.fill_between(
+     xtest.squeeze(),
+     mf.squeeze() - 2 * jnp.sqrt(vf.squeeze()),
+     mf.squeeze() + 2 * jnp.sqrt(vf.squeeze()),
+     color="C0",
+     alpha=0.2,
+     label="Signal Uncertainty",
+ )
+
+ # Plot total uncertainty (signal + noise)
+ ax.plot(xtest, predictive_mean, "--", color="C1", alpha=0.5)
+ ax.fill_between(
+     xtest.squeeze(),
+     predictive_mean - 2 * predictive_std,
+     predictive_mean + 2 * predictive_std,
+     color="C1",
+     alpha=0.1,
+     label="Predictive Uncertainty (95%)",
+ )
+
+ ax.set_title("Heteroscedastic Regression with Custom Inducing Points")
+ ax.legend(loc="upper left", fontsize="small")
+
+ # %% [markdown]
+ # ## Takeaways
+ # - The heteroscedastic GP model couples two latent GPs, enabling separate control of
+ # epistemic and aleatoric uncertainties.
+ # - We support multiple positivity transforms for the noise process; the choice
+ # affects the implied variance tails and should reflect prior beliefs.
+ # - Inducing points for the signal and noise processes can be tuned independently to
+ # balance computational budget against the local complexity of each function.
+ # - The ELBO implementation automatically selects the tightest analytical bound
+ # available, streamlining heteroscedastic inference workflows.
+
+ # %% [markdown]
+ # ## System configuration
+
+ # %%
+ # %reload_ext watermark
+ # %watermark -n -u -v -iv -w -a 'Thomas Pinder'
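
The Optimisation cell above relies on an analytically tractable expectation over the latent noise process. As a brief sketch of why the Gaussian case admits a closed form (assuming marginals $q(f(x_i)) = \mathcal{N}(m_i, v_i)$ and $q(g(x_i)) = \mathcal{N}(\mu_i, \sigma_i^2)$ and the exponential map used by `LogNormalTransform`; the exact parameterisation inside `gpx.objectives.heteroscedastic_elbo` may differ):

$$\mathbb{E}_{q(f)q(g)}\big[\log \mathcal{N}\!\big(y_i \mid f(x_i), e^{g(x_i)}\big)\big] = -\tfrac{1}{2}\log(2\pi) - \tfrac{\mu_i}{2} - \tfrac{1}{2}\big[(y_i - m_i)^2 + v_i\big]\, e^{-\mu_i + \sigma_i^2/2},$$

which follows from $\mathbb{E}[e^{-g}] = e^{-\mu + \sigma^2/2}$ for $g \sim \mathcal{N}(\mu, \sigma^2)$. Summing over datapoints and subtracting the two KL terms yields the tight bound of Lázaro-Gredilla & Titsias (2011) referenced in the example.
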
@@ -29,7 +29,6 @@ import matplotlib as mpl
  import matplotlib.pyplot as plt
 
  from examples.utils import (
-     clean_legend,
      use_mpl_style,
  )
 
@@ -129,26 +128,26 @@ prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)
 
  # %%
  # %% [markdown]
- prior_dist = prior.predict(xtest, return_covariance_type="dense")
-
- prior_mean = prior_dist.mean
- prior_std = prior_dist.variance
- samples = prior_dist.sample(key=key, sample_shape=(20,))
-
-
- fig, ax = plt.subplots()
- ax.plot(xtest, samples.T, alpha=0.5, color=cols[0], label="Prior samples")
- ax.plot(xtest, prior_mean, color=cols[1], label="Prior mean")
- ax.fill_between(
-     xtest.flatten(),
-     prior_mean - prior_std,
-     prior_mean + prior_std,
-     alpha=0.3,
-     color=cols[1],
-     label="Prior variance",
- )
- ax.legend(loc="best")
- ax = clean_legend(ax)
+ # prior_dist = prior.predict(xtest, return_covariance_type="dense")
+ #
+ # prior_mean = prior_dist.mean
+ # prior_std = prior_dist.variance
+ # samples = prior_dist.sample(key=key, sample_shape=(20,))
+ #
+ #
+ # fig, ax = plt.subplots()
+ # ax.plot(xtest, samples.T, alpha=0.5, color=cols[0], label="Prior samples")
+ # ax.plot(xtest, prior_mean, color=cols[1], label="Prior mean")
+ # ax.fill_between(
+ #     xtest.flatten(),
+ #     prior_mean - prior_std,
+ #     prior_mean + prior_std,
+ #     alpha=0.3,
+ #     color=cols[1],
+ #     label="Prior variance",
+ # )
+ # ax.legend(loc="best")
+ # ax = clean_legend(ax)
 
  # %% [markdown]
  # ## Constructing the posterior
@@ -217,13 +216,15 @@ print(-gpx.objectives.conjugate_mll(opt_posterior, D))
  # this, we use our defined `posterior` and `likelihood` at our test inputs to obtain
  # the predictive distribution as a `Distrax` multivariate Gaussian upon which `mean`
  # and `stddev` can be used to extract the predictive mean and standard deviatation.
- #
+ #
  # We are only concerned here about the variance between the test points and themselves, so
  # we can just copute the diagonal version of the covariance. We enforce this by using
  # `return_covariance_type = "diagonal"` in the `predict` call.
 
  # %%
- latent_dist = opt_posterior.predict(xtest, train_data=D, return_covariance_type="diagonal")
+ latent_dist = opt_posterior.predict(
+     xtest, train_data=D, return_covariance_type="diagonal"
+ )
  predictive_dist = opt_posterior.likelihood(latent_dist)
 
  predictive_mean = predictive_dist.mean
@@ -40,7 +40,7 @@ __license__ = "MIT"
  __description__ = "Gaussian processes in JAX and Flax"
  __url__ = "https://github.com/thomaspinder/GPJax"
  __contributors__ = "https://github.com/thomaspinder/GPJax/graphs/contributors"
- __version__ = "0.13.3"
+ __version__ = "0.13.4"
 
  __all__ = [
      "gps",
@@ -23,6 +23,7 @@ from gpjax.kernels import (
      Matern32,
      Matern52,
  )
+ from gpjax.likelihoods import HeteroscedasticGaussian
 
  CitationType = Union[None, str, Dict[str, str]]
 
@@ -149,3 +150,15 @@ def _(tree) -> PaperCitation:
          booktitle="Advances in neural information processing systems",
          citation_type="article",
      )
+
+
+ @cite.register(HeteroscedasticGaussian)
+ def _(tree) -> PaperCitation:
+     return PaperCitation(
+         citation_key="lazaro2011variational",
+         authors="Lázaro-Gredilla, Miguel and Titsias, Michalis",
+         title="Variational heteroscedastic Gaussian process regression",
+         year="2011",
+         booktitle="Proceedings of the 28th International Conference on Machine Learning (ICML)",
+         citation_type="inproceedings",
+     )
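
Since the hunk above registers `HeteroscedasticGaussian` with the `cite` dispatcher, the new reference should be retrievable in the same way as the existing entries. A minimal sketch, assuming `cite` remains importable from `gpjax.citation` and reusing the constructor arguments shown in the new example notebook:

import gpjax as gpx
from gpjax.citation import cite
from gpjax.likelihoods import HeteroscedasticGaussian, LogNormalTransform

# Hypothetical usage: build a heteroscedastic likelihood and ask for its citation.
noise_prior = gpx.gps.Prior(
    mean_function=gpx.mean_functions.Zero(), kernel=gpx.kernels.RBF()
)
likelihood = HeteroscedasticGaussian(
    num_datapoints=100,
    noise_prior=noise_prior,
    noise_transform=LogNormalTransform(),
)
print(cite(likelihood))  # expected: the Lázaro-Gredilla & Titsias (2011) entry registered above
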
@@ -32,8 +32,10 @@ from gpjax.distributions import GaussianDistribution
  from gpjax.kernels import RFF
  from gpjax.kernels.base import AbstractKernel
  from gpjax.likelihoods import (
+     AbstractHeteroscedasticLikelihood,
      AbstractLikelihood,
      Gaussian,
+     HeteroscedasticGaussian,
      NonGaussian,
  )
  from gpjax.linalg import (
@@ -62,6 +64,7 @@ M = tp.TypeVar("M", bound=AbstractMeanFunction)
  L = tp.TypeVar("L", bound=AbstractLikelihood)
  NGL = tp.TypeVar("NGL", bound=NonGaussian)
  GL = tp.TypeVar("GL", bound=Gaussian)
+ HL = tp.TypeVar("HL", bound=AbstractHeteroscedasticLikelihood)
 
 
  class AbstractPrior(nnx.Module, tp.Generic[M, K]):
@@ -476,6 +479,22 @@ class AbstractPosterior(nnx.Module, tp.Generic[P, L]):
          raise NotImplementedError
 
 
+ class LatentPosterior(AbstractPosterior[P, L]):
+     r"""A posterior shell used to expose prior structure without inference."""
+
+     def predict(
+         self,
+         test_inputs: Num[Array, "N D"],
+         train_data: Dataset,
+         *,
+         return_covariance_type: Literal["dense", "diagonal"] = "dense",
+     ) -> GaussianDistribution:
+         raise NotImplementedError(
+             "LatentPosteriors are a lightweight wrapper for priors and do not "
+             "implement predictive distributions. Use a variational family for inference."
+         )
+
+
  class ConjugatePosterior(AbstractPosterior[P, GL]):
      r"""A Conjuate Gaussian process posterior object.
 
@@ -839,6 +858,40 @@ class NonConjugatePosterior(AbstractPosterior[P, NGL]):
          return GaussianDistribution(jnp.atleast_1d(mean.squeeze()), cov)
 
 
+ class HeteroscedasticPosterior(LatentPosterior[P, HL]):
+     r"""Posterior shell for heteroscedastic likelihoods.
+
+     The posterior retains both the signal and noise priors; inference is delegated
+     to variational families and specialised objectives.
+     """
+
+     def __init__(
+         self,
+         prior: AbstractPrior[M, K],
+         likelihood: HL,
+         jitter: float = 1e-6,
+     ):
+         if likelihood.noise_prior is None:
+             raise ValueError("Heteroscedastic likelihoods require a noise_prior.")
+         super().__init__(prior=prior, likelihood=likelihood, jitter=jitter)
+         self.noise_prior = likelihood.noise_prior
+         self.noise_posterior = LatentPosterior(
+             prior=self.noise_prior, likelihood=likelihood, jitter=jitter
+         )
+
+
+ class ChainedPosterior(HeteroscedasticPosterior[P, HL]):
+     r"""Posterior routed for heteroscedastic likelihoods using chained bounds."""
+
+     def __init__(
+         self,
+         prior: AbstractPrior[M, K],
+         likelihood: HL,
+         jitter: float = 1e-6,
+     ):
+         super().__init__(prior=prior, likelihood=likelihood, jitter=jitter)
+
+
  #######################
  # Utils
  #######################
@@ -854,6 +907,18 @@ def construct_posterior( # noqa: F811
  ) -> NonConjugatePosterior[P, NGL]: ...
 
 
+ @tp.overload
+ def construct_posterior( # noqa: F811
+     prior: P, likelihood: HeteroscedasticGaussian
+ ) -> HeteroscedasticPosterior[P, HeteroscedasticGaussian]: ...
+
+
+ @tp.overload
+ def construct_posterior( # noqa: F811
+     prior: P, likelihood: AbstractHeteroscedasticLikelihood
+ ) -> ChainedPosterior[P, AbstractHeteroscedasticLikelihood]: ...
+
+
  def construct_posterior(prior, likelihood): # noqa: F811
      r"""Utility function for constructing a posterior object from a prior and
      likelihood. The function will automatically select the correct posterior
@@ -873,6 +938,15 @@ def construct_posterior(prior, likelihood): # noqa: F811
      if isinstance(likelihood, Gaussian):
          return ConjugatePosterior(prior=prior, likelihood=likelihood)
 
+     if (
+         isinstance(likelihood, HeteroscedasticGaussian)
+         and likelihood.supports_tight_bound()
+     ):
+         return HeteroscedasticPosterior(prior=prior, likelihood=likelihood)
+
+     if isinstance(likelihood, AbstractHeteroscedasticLikelihood):
+         return ChainedPosterior(prior=prior, likelihood=likelihood)
+
      return NonConjugatePosterior(prior=prior, likelihood=likelihood)
 
 
@@ -911,7 +985,10 @@ __all__ = [
      "AbstractPrior",
      "Prior",
      "AbstractPosterior",
+     "LatentPosterior",
      "ConjugatePosterior",
      "NonConjugatePosterior",
+     "HeteroscedasticPosterior",
+     "ChainedPosterior",
      "construct_posterior",
  ]
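
Taken together, the gpjax/gps.py changes mean that `construct_posterior` (and therefore the `prior * likelihood` sugar used in the new example) now routes heteroscedastic likelihoods to the new posterior classes. A small sketch of the dispatch implied by the branches above; it assumes `supports_tight_bound()` returns `True` for a default `HeteroscedasticGaussian`, which the naming suggests but the diff does not show:

import gpjax as gpx
from gpjax.gps import HeteroscedasticPosterior, construct_posterior
from gpjax.likelihoods import HeteroscedasticGaussian, LogNormalTransform

# Mirror the model built in the new example notebook.
signal_prior = gpx.gps.Prior(
    mean_function=gpx.mean_functions.Zero(), kernel=gpx.kernels.RBF()
)
noise_prior = gpx.gps.Prior(
    mean_function=gpx.mean_functions.Zero(), kernel=gpx.kernels.RBF()
)
likelihood = HeteroscedasticGaussian(
    num_datapoints=50,
    noise_prior=noise_prior,
    noise_transform=LogNormalTransform(),
)

# The first new branch should fire for a Gaussian heteroscedastic likelihood,
# with ChainedPosterior acting as the fallback for other heteroscedastic likelihoods.
posterior = construct_posterior(signal_prior, likelihood)
assert isinstance(posterior, HeteroscedasticPosterior)
assert isinstance(signal_prior * likelihood, HeteroscedasticPosterior)
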