gpjax 0.13.4.tar.gz → 0.13.5.tar.gz
This diff compares the contents of two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- {gpjax-0.13.4 → gpjax-0.13.5}/.github/workflows/auto-label.yml +1 -1
- {gpjax-0.13.4 → gpjax-0.13.5}/.github/workflows/build_docs.yml +2 -2
- {gpjax-0.13.4 → gpjax-0.13.5}/.github/workflows/commit-lint.yml +1 -1
- {gpjax-0.13.4 → gpjax-0.13.5}/.github/workflows/integration.yml +1 -1
- {gpjax-0.13.4 → gpjax-0.13.5}/.github/workflows/release.yml +10 -10
- {gpjax-0.13.4 → gpjax-0.13.5}/.github/workflows/ruff.yml +1 -1
- {gpjax-0.13.4 → gpjax-0.13.5}/.github/workflows/security-analysis.yml +3 -3
- {gpjax-0.13.4 → gpjax-0.13.5}/.github/workflows/test_docs.yml +1 -1
- {gpjax-0.13.4 → gpjax-0.13.5}/.github/workflows/tests.yml +1 -1
- {gpjax-0.13.4 → gpjax-0.13.5}/PKG-INFO +3 -2
- {gpjax-0.13.4 → gpjax-0.13.5}/docs/sharp_bits.md +32 -28
- {gpjax-0.13.4 → gpjax-0.13.5}/examples/heteroscedastic_inference.py +29 -24
- {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/__init__.py +1 -1
- {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/parameters.py +2 -1
- {gpjax-0.13.4 → gpjax-0.13.5}/mkdocs.yml +1 -2
- {gpjax-0.13.4 → gpjax-0.13.5}/pyproject.toml +5 -4
- {gpjax-0.13.4 → gpjax-0.13.5}/tests/test_parameters.py +20 -0
- gpjax-0.13.5/uv.lock +4279 -0
- gpjax-0.13.4/uv.lock +0 -3558
- {gpjax-0.13.4 → gpjax-0.13.5}/.github/CODE_OF_CONDUCT.md +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/.github/FUNDING.yml +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/.github/ISSUE_TEMPLATE/01_BUG_REPORT.md +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/.github/ISSUE_TEMPLATE/02_FEATURE_REQUEST.md +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/.github/ISSUE_TEMPLATE/03_CODEBASE_IMPROVEMENT.md +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/.github/ISSUE_TEMPLATE/04_DOCS_IMPROVEMENT.md +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/.github/ISSUE_TEMPLATE/config.yml +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/.github/codecov.yml +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/.github/commitlint.config.js +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/.github/dependabot.yml +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/.github/labeler.yml +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/.github/labels.yml +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/.github/pull_request_template.md +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/.github/release-drafter.yml +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/.gitignore +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/CITATION.bib +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/LICENSE.txt +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/Makefile +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/README.md +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/docs/CODE_OF_CONDUCT.md +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/docs/GOVERNANCE.md +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/docs/contributing.md +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/docs/design.md +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/docs/index.md +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/docs/index.rst +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/docs/installation.md +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/docs/javascripts/katex.js +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/docs/refs.bib +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/docs/scripts/gen_examples.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/docs/scripts/gen_pages.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/docs/scripts/notebook_converter.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/docs/scripts/sharp_bits_figure.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/docs/static/GP.pdf +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/docs/static/GP.svg +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/docs/static/bijector_figure.svg +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/docs/static/css/gpjax_theme.css +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/docs/static/favicon.ico +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/docs/static/gpjax.mplstyle +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/docs/static/gpjax_logo.pdf +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/docs/static/gpjax_logo.svg +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/docs/static/jaxkern/lato.ttf +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/docs/static/jaxkern/logo.png +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/docs/static/jaxkern/logo.svg +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/docs/static/jaxkern/main.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/docs/static/step_size_figure.png +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/docs/static/step_size_figure.svg +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/docs/stylesheets/extra.css +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/docs/stylesheets/permalinks.css +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/examples/backend.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/examples/barycentres/barycentre_gp.gif +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/examples/barycentres.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/examples/classification.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/examples/collapsed_vi.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/examples/constructing_new_kernels.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/examples/data/max_tempeature_switzerland.csv +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/examples/data/yacht_hydrodynamics.data +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/examples/deep_kernels.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/examples/gpjax.mplstyle +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/examples/graph_kernels.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/examples/intro_to_gps/decomposed_mll.png +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/examples/intro_to_gps/generating_process.png +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/examples/intro_to_gps.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/examples/intro_to_kernels.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/examples/likelihoods_guide.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/examples/oceanmodelling.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/examples/poisson.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/examples/regression.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/examples/uncollapsed_vi.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/examples/utils.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/examples/yacht.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/citation.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/dataset.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/distributions.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/fit.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/gps.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/integrators.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/kernels/__init__.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/kernels/approximations/__init__.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/kernels/approximations/rff.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/kernels/base.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/kernels/computations/__init__.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/kernels/computations/base.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/kernels/computations/basis_functions.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/kernels/computations/constant_diagonal.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/kernels/computations/dense.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/kernels/computations/diagonal.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/kernels/computations/eigen.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/kernels/non_euclidean/__init__.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/kernels/non_euclidean/graph.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/kernels/non_euclidean/utils.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/kernels/nonstationary/__init__.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/kernels/nonstationary/arccosine.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/kernels/nonstationary/linear.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/kernels/nonstationary/polynomial.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/kernels/stationary/__init__.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/kernels/stationary/base.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/kernels/stationary/matern12.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/kernels/stationary/matern32.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/kernels/stationary/matern52.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/kernels/stationary/periodic.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/kernels/stationary/powered_exponential.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/kernels/stationary/rational_quadratic.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/kernels/stationary/rbf.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/kernels/stationary/utils.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/kernels/stationary/white.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/likelihoods.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/linalg/__init__.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/linalg/operations.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/linalg/operators.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/linalg/utils.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/mean_functions.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/numpyro_extras.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/objectives.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/scan.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/typing.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/variational_families.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/static/CONTRIBUTING.md +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/static/paper.bib +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/static/paper.md +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/static/paper.pdf +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/tests/__init__.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/tests/conftest.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/tests/integration_tests.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/tests/test_citations.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/tests/test_dataset.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/tests/test_fit.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/tests/test_gaussian_distribution.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/tests/test_gps.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/tests/test_heteroscedastic.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/tests/test_imports.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/tests/test_integrators.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/tests/test_kernels/__init__.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/tests/test_kernels/test_approximations.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/tests/test_kernels/test_base.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/tests/test_kernels/test_computation.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/tests/test_kernels/test_non_euclidean.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/tests/test_kernels/test_nonstationary.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/tests/test_kernels/test_stationary.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/tests/test_kernels/test_utils.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/tests/test_likelihoods.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/tests/test_linalg.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/tests/test_markdown.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/tests/test_mean_functions.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/tests/test_numpyro_extras.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/tests/test_objectives.py +0 -0
- {gpjax-0.13.4 → gpjax-0.13.5}/tests/test_variational_families.py +0 -0
{gpjax-0.13.4 → gpjax-0.13.5}/.github/workflows/*.yml

@@ -27,7 +27,7 @@ jobs:
    steps:
      # Grap the latest commit from the branch
      - name: Checkout the branch
-       uses: actions/checkout@
+       uses: actions/checkout@v6
        with:
          persist-credentials: false

@@ -61,7 +61,7 @@ jobs:
          uv run mkdocs build

      - name: Deploy Page 🚀
-       uses: JamesIves/github-pages-deploy-action@v4.
+       uses: JamesIves/github-pages-deploy-action@v4.8.0
        with:
          branch: gh-pages
          folder: site

@@ -29,7 +29,7 @@ jobs:

    steps:
      - name: Checkout repository
-       uses: actions/checkout@
+       uses: actions/checkout@v6
        with:
          fetch-depth: 0

@@ -72,7 +72,7 @@ jobs:

    steps:
      - name: Checkout repository
-       uses: actions/checkout@
+       uses: actions/checkout@v6

      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v6

@@ -108,7 +108,7 @@ jobs:

    steps:
      - name: Checkout repository
-       uses: actions/checkout@
+       uses: actions/checkout@v6

      - name: Set up Python
        uses: actions/setup-python@v6

@@ -132,7 +132,7 @@ jobs:
          uv run bandit -r gpjax/ -f json -o bandit-report.json || echo "Bandit scan completed with warnings"

      - name: Upload security reports
-       uses: actions/upload-artifact@
+       uses: actions/upload-artifact@v6
        with:
          name: security-reports
          path: |

@@ -146,7 +146,7 @@ jobs:

    steps:
      - name: Checkout repository
-       uses: actions/checkout@
+       uses: actions/checkout@v6

      - name: Set up Python
        uses: actions/setup-python@v6

@@ -166,7 +166,7 @@ jobs:
          uv run twine check dist/*

      - name: Upload build artifacts
-       uses: actions/upload-artifact@
+       uses: actions/upload-artifact@v6
        with:
          name: dist-packages
          path: dist/

@@ -181,7 +181,7 @@ jobs:

    steps:
      - name: Checkout repository
-       uses: actions/checkout@
+       uses: actions/checkout@v6
        with:
          fetch-depth: 0

@@ -261,10 +261,10 @@ jobs:

    steps:
      - name: Checkout repository
-       uses: actions/checkout@
+       uses: actions/checkout@v6

      - name: Download build artifacts
-       uses: actions/download-artifact@
+       uses: actions/download-artifact@v7
        with:
          name: dist-packages
          path: dist/

@@ -294,7 +294,7 @@ jobs:

    steps:
      - name: Download build artifacts
-       uses: actions/download-artifact@
+       uses: actions/download-artifact@v7
        with:
          name: dist-packages
          path: dist/

@@ -20,7 +20,7 @@ jobs:

    steps:
      - name: Checkout repository
-       uses: actions/checkout@
+       uses: actions/checkout@v6

      - name: Set up Python
        uses: actions/setup-python@v6

@@ -47,7 +47,7 @@ jobs:
          uv run bandit -r gpjax/ -f json -o bandit-report.json || true

      - name: Upload dependency scan results
-       uses: actions/upload-artifact@
+       uses: actions/upload-artifact@v6
        if: always()
        with:
          name: security-scan-results

@@ -62,7 +62,7 @@ jobs:

    steps:
      - name: Checkout repository
-       uses: actions/checkout@
+       uses: actions/checkout@v6
        with:
          fetch-depth: 0 # Fetch full history for comprehensive scanning
{gpjax-0.13.4 → gpjax-0.13.5}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: gpjax
-Version: 0.13.4
+Version: 0.13.5
 Summary: Gaussian processes in JAX.
 Project-URL: Documentation, https://docs.jaxgaussianprocesses.com/
 Project-URL: Issues, https://github.com/thomaspinder/GPJax/issues

@@ -25,6 +25,7 @@ Requires-Dist: jaxtyping>0.2.10
 Requires-Dist: numpy>=2.0.0
 Requires-Dist: numpyro
 Requires-Dist: optax>0.2.1
+Requires-Dist: tensorstore!=0.1.76; sys_platform == 'darwin'
 Requires-Dist: tqdm>4.66.2
 Provides-Extra: dev
 Requires-Dist: absolufy-imports>=0.3.1; extra == 'dev'

@@ -59,7 +60,7 @@ Requires-Dist: mkdocs-jupyter>=0.24.3; extra == 'docs'
 Requires-Dist: mkdocs-literate-nav>=0.6.0; extra == 'docs'
 Requires-Dist: mkdocs-material>=9.5.12; extra == 'docs'
 Requires-Dist: mkdocs>=1.5.3; extra == 'docs'
-Requires-Dist: mkdocstrings[python]<
+Requires-Dist: mkdocstrings[python]<1.1.0; extra == 'docs'
 Requires-Dist: nbconvert>=7.16.2; extra == 'docs'
 Requires-Dist: networkx>=3.0; extra == 'docs'
 Requires-Dist: pandas>=1.5.3; extra == 'docs'
{gpjax-0.13.4 → gpjax-0.13.5}/docs/sharp_bits.md

@@ -178,12 +178,19 @@ points. We demonstrate its use in

 ## JIT compilation

-
-
-
-
-
-
+GPJax validates parameters at construction time using two kinds of checks:
+
+1. **Type checks** — plain Python `isinstance` checks that verify values are array-like.
+2. **Value checks** — JAX-compatible assertions (via `checkify`) that verify constraints
+   like positivity or bounds.
+
+During JIT tracing, concrete values are replaced by abstract tracers. The type checks
+use `isinstance`, which is a pure Python operation that cannot be intercepted by JAX's
+`checkify` transformation. This means that constructing GPJax objects (kernels, mean
+functions, likelihoods, etc.) **inside** a JIT boundary will fail.
+
+As an example, consider the following code that constructs a kernel inside a
+JIT-compiled function:

 ```python
 import jax

@@ -192,43 +199,40 @@ import gpjax as gpx

 x = jnp.linspace(0, 1, 10)[:, None]

-def
+def compute_gram_bad(lengthscale):
     k = gpx.kernels.RBF(active_dims=[0], lengthscale=lengthscale, variance=jnp.array(1.0))
     return k.gram(x)

-
+compute_gram_bad(1.0) # works fine outside JIT
 ```

-
+If we try to JIT compile this function, we get a `TypeError` because the kernel
+constructor receives a JAX tracer instead of a concrete array:

 ```python
-
+jit_compute_gram_bad = jax.jit(compute_gram_bad)
 try:
-
+    jit_compute_gram_bad(1.0)
 except Exception as e:
     print(e)
 ```

-
-that the lengthscale is positive. It does not matter that the assertion is satisfied;
-the very presence of the assertion will break JIT compilation.
+### The fix: construct objects outside JIT

-
-
+The solution is to construct GPJax objects **outside** the JIT boundary and only JIT the
+computation itself. This follows the standard JAX pattern of keeping object construction
+separate from traced computation:

 ```python
-
+k = gpx.kernels.RBF(active_dims=[0], lengthscale=1.0, variance=jnp.array(1.0))

-
-
-
-By virtue of the `checkify.checkify`, a tuple is returned where the first element is the
-output of the assertion, and the second element is the value of the function.
+@jax.jit
+def compute_gram(x):
+    return k.gram(x)

-
-
-guardrails in a less intrusive manner. However, for now, should you try to JIT compile
-a component of GPJax wherein there is an assertion, you will need to wrap the function
-in `checkify.checkify` as shown above.
+result = compute_gram(x)
+```

-
+More generally, any GPJax object should be constructed outside of `jax.jit`, `jax.vmap`,
+or `jax.grad` boundaries. Once constructed, their methods can be freely used inside
+these JAX transformations.
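As a small aside on that closing guidance, the sketch below builds on the snippet in the documentation diff: the kernel is constructed once outside any transformation, and its evaluation is then vectorised with `jax.vmap`. It assumes GPJax kernels expose the pointwise call signature `k(x, y)`; treat it as an illustration under that assumption rather than part of the documented example.

```python
import jax
import jax.numpy as jnp
import gpjax as gpx

# Build the kernel once, outside any JAX transformation, as the guidance requires.
k = gpx.kernels.RBF(active_dims=[0], lengthscale=1.0, variance=jnp.array(1.0))

x = jnp.linspace(0.0, 1.0, 10)[:, None]

# The constructed kernel's methods can then be traced freely; here we vmap the
# assumed pointwise evaluation k(x_i, x_0) over the rows of x.
kx0 = jax.vmap(lambda xi: k(xi, x[0]))(x)  # one kernel value per row of x
```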
{gpjax-0.13.4 → gpjax-0.13.5}/examples/heteroscedastic_inference.py

@@ -16,7 +16,7 @@
 # ---

 # %% [markdown]
-# # Heteroscedastic
+# # Heteroscedastic Inference
 #
 # This notebook shows how to fit a heteroscedastic Gaussian processes (GPs) that
 # allows one to perform regression where there exists non-constant, or

@@ -293,13 +293,13 @@ posterior_adv = mean_prior * likelihood_adv
 # The signal requires a richer inducing set to capture its oscillations, whereas the
 # noise process can be summarised with fewer points because the burst is localised.
 z_signal = jnp.linspace(-2.0, 2.0, 30)[:, None]
-z_noise = jnp.linspace(-2.0, 2.0,
+z_noise = jnp.linspace(-2.0, 2.0, 20)[:, None]

 # Use VariationalGaussianInit to pass specific configurations
 q_init_f = VariationalGaussianInit(inducing_inputs=z_signal)
 q_init_g = VariationalGaussianInit(inducing_inputs=z_noise)

-
+q_sparse = HeteroscedasticVariationalFamily(
     posterior=posterior_adv,
     signal_init=q_init_f,
     noise_init=q_init_g,

@@ -315,19 +315,19 @@ q_adv = HeteroscedasticVariationalFamily(
 # Optimize
 objective_adv = lambda model, data: -gpx.objectives.heteroscedastic_elbo(model, data)
 optimiser_adv = ox.adam(1e-2)
-
-    model=
+q_sparse_trained, _ = gpx.fit(
+    model=q_sparse,
     objective=objective_adv,
     train_data=data_adv,
     optim=optimiser_adv,
-    num_iters=
+    num_iters=10000,
     verbose=False,
 )

 # %%
 # Plotting
-xtest = jnp.linspace(-2.2, 2.2,
-pred =
+xtest = jnp.linspace(-2.2, 2.2, 300)[:, None]
+pred = q_sparse_trained.predict(xtest)

 # Unpack the named tuple
 mf = pred.mean_f

@@ -339,36 +339,41 @@ vg = pred.variance_g
 # The likelihood expects the *latent* noise distribution to compute the predictive
 # but here we can just use the transformed expected variance for plotting.
 # For accurate predictive intervals, we should use likelihood.predict.
-signal_dist, noise_dist =
+signal_dist, noise_dist = q_sparse_trained.predict_latents(xtest)
 predictive_dist = likelihood_adv.predict(signal_dist, noise_dist)
 predictive_mean = predictive_dist.mean
 predictive_std = jnp.sqrt(jnp.diag(predictive_dist.covariance_matrix))

-fig, ax = plt.subplots()
-ax.plot(x, y, "
-
+fig, ax = plt.subplots(figsize=(6, 2.5))
+ax.plot(x, y, "x", color="black", alpha=0.5, label="Data")
+
+# Plot total uncertainty (signal + noise)
+ax.plot(xtest, predictive_mean, "--", color=cols[1], linewidth=2)
 ax.fill_between(
     xtest.squeeze(),
-
-
-    color=
-    alpha=0.
-    label="
+    predictive_mean - predictive_std,
+    predictive_mean + predictive_std,
+    color=cols[1],
+    alpha=0.3,
+    label="One std. dev.",
 )
-
-
-ax.plot(xtest, predictive_mean, "--", color="C1", alpha=0.5)
+ax.plot(xtest.squeeze(), predictive_mean - predictive_std, "--", color=cols[1], alpha=0.5, linewidth=0.75)
+ax.plot(xtest.squeeze(), predictive_mean + predictive_std, "--", color=cols[1], alpha=0.5, linewidth=0.75)
 ax.fill_between(
     xtest.squeeze(),
     predictive_mean - 2 * predictive_std,
     predictive_mean + 2 * predictive_std,
-    color=
+    color=cols[1],
     alpha=0.1,
-    label="
+    label="Two std. dev.",
 )
+ax.plot(xtest.squeeze(), predictive_mean - 2 * predictive_std, "--", color=cols[1], alpha=0.5, linewidth=0.75)
+ax.plot(xtest.squeeze(), predictive_mean + 2 * predictive_std, "--", color=cols[1], alpha=0.5, linewidth=0.75)

-ax.set_title("Heteroscedastic Regression
-ax.legend(loc="
+ax.set_title("Sparse Heteroscedastic Regression")
+ax.legend(loc="best", fontsize="small")
+ax.set_xlabel("$x$")
+ax.set_ylabel("$y$")

 # %% [markdown]
 # ## Takeaways
{gpjax-0.13.4 → gpjax-0.13.5}/gpjax/__init__.py

@@ -40,7 +40,7 @@ __license__ = "MIT"
 __description__ = "Gaussian processes in JAX and Flax"
 __url__ = "https://github.com/thomaspinder/GPJax"
 __contributors__ = "https://github.com/thomaspinder/GPJax/graphs/contributors"
-__version__ = "0.13.4"
+__version__ = "0.13.5"

 __all__ = [
     "gps",
{gpjax-0.13.4 → gpjax-0.13.5}/gpjax/parameters.py

@@ -1,6 +1,7 @@
 import typing as tp

 from flax import nnx
+import jax
 from jax.experimental import checkify
 import jax.numpy as jnp
 import jax.tree_util as jtu

@@ -162,7 +163,7 @@ def _check_is_arraylike(value: T) -> None:
     Raises:
         TypeError: If the value is not array-like.
     """
-    if not isinstance(value, (ArrayLike, list)):
+    if not isinstance(value, (jax.Array, ArrayLike, list)):
         raise TypeError(
             f"Expected parameter value to be an array-like type. Got {value}."
         )
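The practical effect of adding `jax.Array` to the `isinstance` tuple is that parameter construction now also succeeds when the incoming value is a tracer, for example under `jax.grad`. A minimal sketch of the behaviour exercised by the new regression test shown at the end of this diff; the import path for `PositiveReal` is assumed to match how the test module uses it:

```python
import jax
import jax.numpy as jnp

from gpjax.parameters import PositiveReal  # import path assumed, mirroring the test module


def loss(lengthscale):
    # Under jax.grad, `lengthscale` is a tracer; with the widened check
    # (jax.Array accepted) the parameter can be constructed inside the traced function.
    return PositiveReal(lengthscale).value.sum()


grad = jax.grad(loss)(jnp.array(1.0))
print(grad.shape)  # matches the input's shape, as the new regression test asserts
```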
{gpjax-0.13.4 → gpjax-0.13.5}/mkdocs.yml

@@ -118,10 +118,9 @@ plugins:
       handlers:
         python:
           paths: ["gpjax"]
-
+          options:
             show_symbol_type_toc: true
             show_signature_annotations: true
-          options:
             members_order: source
             inherited_members: true
             show_source: false
{gpjax-0.13.4 → gpjax-0.13.5}/pyproject.toml

@@ -30,13 +30,14 @@ dependencies = [
     "beartype>0.16.1",
     "flax>=0.12.0",
     "numpy>=2.0.0",
+    "tensorstore!=0.1.76; sys_platform == 'darwin'",
 ]

 [project.optional-dependencies]
 docs = [
     "mkdocs>=1.5.3",
     "mkdocs-material>=9.5.12",
-    "mkdocstrings[python]<
+    "mkdocstrings[python]<1.1.0",
     "mkdocs-jupyter>=0.24.3",
     "mkdocs-gen-files>=0.5.0",
     "mkdocs-literate-nav>=0.6.0",

@@ -78,6 +79,7 @@ dev = [
 ]

 [tool.uv]
+exclude-newer = "7 days"
 managed = true
 dev-dependencies = [
     "ruff>=0.6",

@@ -260,12 +262,11 @@ convention = "numpy"

 [tool.ruff.lint.per-file-ignores]
 "gpjax/__init__.py" = ['I', 'F401', 'E402', 'D104']
-"gpjax/progress_bar.py" = ["TCH004"]
 "gpjax/scan.py" = ["PLR0913"]
 "gpjax/citation.py" = ["F811"]
-"tests/test_base/test_module.py" = ["PLR0915"]
 "tests/test_objectives.py" = ["PLR0913"]
-"
+"examples/barycentres.py" = ["PLR0913"]
+"tests/*.py" = ["PLW0108"]

 [tool.isort]
 profile = "black"
{gpjax-0.13.4 → gpjax-0.13.5}/tests/test_parameters.py

@@ -1,3 +1,4 @@
+import jax
 from flax import nnx
 from jax import jit
 from jax.experimental import checkify

@@ -109,3 +110,22 @@ def test_check_in_bounds():
     _safe_assert(
         _check_in_bounds, jnp.array(1.5), low=jnp.array(0.0), high=jnp.array(1.0)
     )
+
+
+@pytest.mark.parametrize(
+    "param_cls, value",
+    [
+        (PositiveReal, jnp.array(1.0)),
+        (PositiveReal, jnp.array([1.0, 2.0])),
+        (Real, jnp.array(1.0)),
+        (NonNegativeReal, jnp.array(1.0)),
+    ],
+)
+def test_parameter_construction_under_grad(param_cls, value):
+    """Regression test for #592: parameter construction must accept JAX tracers."""
+
+    def f(x):
+        return param_cls(x).value.sum()
+
+    grad = jax.grad(f)(value)
+    assert grad.shape == value.shape