gpjax 0.13.3__tar.gz → 0.13.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {gpjax-0.13.3 → gpjax-0.13.5}/.github/workflows/auto-label.yml +1 -1
- {gpjax-0.13.3 → gpjax-0.13.5}/.github/workflows/build_docs.yml +2 -2
- {gpjax-0.13.3 → gpjax-0.13.5}/.github/workflows/commit-lint.yml +1 -1
- {gpjax-0.13.3 → gpjax-0.13.5}/.github/workflows/integration.yml +1 -1
- {gpjax-0.13.3 → gpjax-0.13.5}/.github/workflows/release.yml +10 -10
- {gpjax-0.13.3 → gpjax-0.13.5}/.github/workflows/ruff.yml +1 -1
- {gpjax-0.13.3 → gpjax-0.13.5}/.github/workflows/security-analysis.yml +3 -3
- {gpjax-0.13.3 → gpjax-0.13.5}/.github/workflows/test_docs.yml +1 -1
- {gpjax-0.13.3 → gpjax-0.13.5}/.github/workflows/tests.yml +1 -1
- {gpjax-0.13.3 → gpjax-0.13.5}/.gitignore +5 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/PKG-INFO +4 -3
- {gpjax-0.13.3 → gpjax-0.13.5}/README.md +1 -1
- {gpjax-0.13.3 → gpjax-0.13.5}/docs/sharp_bits.md +32 -28
- gpjax-0.13.5/examples/heteroscedastic_inference.py +394 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/examples/regression.py +24 -23
- {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/__init__.py +1 -1
- {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/citation.py +13 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/gps.py +77 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/likelihoods.py +234 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/mean_functions.py +2 -2
- {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/objectives.py +56 -1
- {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/parameters.py +10 -2
- {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/variational_families.py +129 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/mkdocs.yml +2 -2
- {gpjax-0.13.3 → gpjax-0.13.5}/pyproject.toml +7 -5
- {gpjax-0.13.3 → gpjax-0.13.5}/tests/conftest.py +7 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/tests/integration_tests.py +9 -2
- {gpjax-0.13.3 → gpjax-0.13.5}/tests/test_citations.py +16 -0
- gpjax-0.13.5/tests/test_heteroscedastic.py +407 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/tests/test_mean_functions.py +16 -1
- {gpjax-0.13.3 → gpjax-0.13.5}/tests/test_parameters.py +20 -0
- gpjax-0.13.5/uv.lock +4279 -0
- gpjax-0.13.3/.github/workflows/pr_greeting.yml +0 -62
- gpjax-0.13.3/uv.lock +0 -3535
- {gpjax-0.13.3 → gpjax-0.13.5}/.github/CODE_OF_CONDUCT.md +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/.github/FUNDING.yml +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/.github/ISSUE_TEMPLATE/01_BUG_REPORT.md +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/.github/ISSUE_TEMPLATE/02_FEATURE_REQUEST.md +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/.github/ISSUE_TEMPLATE/03_CODEBASE_IMPROVEMENT.md +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/.github/ISSUE_TEMPLATE/04_DOCS_IMPROVEMENT.md +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/.github/ISSUE_TEMPLATE/config.yml +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/.github/codecov.yml +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/.github/commitlint.config.js +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/.github/dependabot.yml +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/.github/labeler.yml +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/.github/labels.yml +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/.github/pull_request_template.md +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/.github/release-drafter.yml +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/CITATION.bib +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/LICENSE.txt +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/Makefile +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/docs/CODE_OF_CONDUCT.md +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/docs/GOVERNANCE.md +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/docs/contributing.md +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/docs/design.md +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/docs/index.md +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/docs/index.rst +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/docs/installation.md +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/docs/javascripts/katex.js +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/docs/refs.bib +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/docs/scripts/gen_examples.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/docs/scripts/gen_pages.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/docs/scripts/notebook_converter.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/docs/scripts/sharp_bits_figure.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/docs/static/GP.pdf +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/docs/static/GP.svg +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/docs/static/bijector_figure.svg +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/docs/static/css/gpjax_theme.css +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/docs/static/favicon.ico +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/docs/static/gpjax.mplstyle +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/docs/static/gpjax_logo.pdf +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/docs/static/gpjax_logo.svg +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/docs/static/jaxkern/lato.ttf +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/docs/static/jaxkern/logo.png +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/docs/static/jaxkern/logo.svg +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/docs/static/jaxkern/main.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/docs/static/step_size_figure.png +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/docs/static/step_size_figure.svg +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/docs/stylesheets/extra.css +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/docs/stylesheets/permalinks.css +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/examples/backend.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/examples/barycentres/barycentre_gp.gif +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/examples/barycentres.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/examples/classification.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/examples/collapsed_vi.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/examples/constructing_new_kernels.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/examples/data/max_tempeature_switzerland.csv +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/examples/data/yacht_hydrodynamics.data +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/examples/deep_kernels.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/examples/gpjax.mplstyle +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/examples/graph_kernels.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/examples/intro_to_gps/decomposed_mll.png +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/examples/intro_to_gps/generating_process.png +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/examples/intro_to_gps.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/examples/intro_to_kernels.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/examples/likelihoods_guide.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/examples/oceanmodelling.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/examples/poisson.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/examples/uncollapsed_vi.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/examples/utils.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/examples/yacht.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/dataset.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/distributions.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/fit.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/integrators.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/kernels/__init__.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/kernels/approximations/__init__.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/kernels/approximations/rff.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/kernels/base.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/kernels/computations/__init__.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/kernels/computations/base.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/kernels/computations/basis_functions.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/kernels/computations/constant_diagonal.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/kernels/computations/dense.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/kernels/computations/diagonal.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/kernels/computations/eigen.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/kernels/non_euclidean/__init__.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/kernels/non_euclidean/graph.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/kernels/non_euclidean/utils.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/kernels/nonstationary/__init__.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/kernels/nonstationary/arccosine.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/kernels/nonstationary/linear.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/kernels/nonstationary/polynomial.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/kernels/stationary/__init__.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/kernels/stationary/base.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/kernels/stationary/matern12.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/kernels/stationary/matern32.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/kernels/stationary/matern52.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/kernels/stationary/periodic.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/kernels/stationary/powered_exponential.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/kernels/stationary/rational_quadratic.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/kernels/stationary/rbf.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/kernels/stationary/utils.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/kernels/stationary/white.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/linalg/__init__.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/linalg/operations.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/linalg/operators.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/linalg/utils.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/numpyro_extras.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/scan.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/typing.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/static/CONTRIBUTING.md +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/static/paper.bib +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/static/paper.md +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/static/paper.pdf +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/tests/__init__.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/tests/test_dataset.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/tests/test_fit.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/tests/test_gaussian_distribution.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/tests/test_gps.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/tests/test_imports.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/tests/test_integrators.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/tests/test_kernels/__init__.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/tests/test_kernels/test_approximations.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/tests/test_kernels/test_base.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/tests/test_kernels/test_computation.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/tests/test_kernels/test_non_euclidean.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/tests/test_kernels/test_nonstationary.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/tests/test_kernels/test_stationary.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/tests/test_kernels/test_utils.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/tests/test_likelihoods.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/tests/test_linalg.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/tests/test_markdown.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/tests/test_numpyro_extras.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/tests/test_objectives.py +0 -0
- {gpjax-0.13.3 → gpjax-0.13.5}/tests/test_variational_families.py +0 -0
|
@@ -27,7 +27,7 @@ jobs:
|
|
|
27
27
|
steps:
|
|
28
28
|
# Grab the latest commit from the branch
|
|
29
29
|
- name: Checkout the branch
|
|
30
|
-
uses: actions/checkout@
|
|
30
|
+
uses: actions/checkout@v6
|
|
31
31
|
with:
|
|
32
32
|
persist-credentials: false
|
|
33
33
|
|
|
@@ -61,7 +61,7 @@ jobs:
|
|
|
61
61
|
uv run mkdocs build
|
|
62
62
|
|
|
63
63
|
- name: Deploy Page 🚀
|
|
64
|
-
uses: JamesIves/github-pages-deploy-action@v4.
|
|
64
|
+
uses: JamesIves/github-pages-deploy-action@v4.8.0
|
|
65
65
|
with:
|
|
66
66
|
branch: gh-pages
|
|
67
67
|
folder: site
|
|
@@ -29,7 +29,7 @@ jobs:
|
|
|
29
29
|
|
|
30
30
|
steps:
|
|
31
31
|
- name: Checkout repository
|
|
32
|
-
uses: actions/checkout@
|
|
32
|
+
uses: actions/checkout@v6
|
|
33
33
|
with:
|
|
34
34
|
fetch-depth: 0
|
|
35
35
|
|
|
@@ -72,7 +72,7 @@ jobs:
|
|
|
72
72
|
|
|
73
73
|
steps:
|
|
74
74
|
- name: Checkout repository
|
|
75
|
-
uses: actions/checkout@
|
|
75
|
+
uses: actions/checkout@v6
|
|
76
76
|
|
|
77
77
|
- name: Set up Python ${{ matrix.python-version }}
|
|
78
78
|
uses: actions/setup-python@v6
|
|
@@ -108,7 +108,7 @@ jobs:
|
|
|
108
108
|
|
|
109
109
|
steps:
|
|
110
110
|
- name: Checkout repository
|
|
111
|
-
uses: actions/checkout@
|
|
111
|
+
uses: actions/checkout@v6
|
|
112
112
|
|
|
113
113
|
- name: Set up Python
|
|
114
114
|
uses: actions/setup-python@v6
|
|
@@ -132,7 +132,7 @@ jobs:
|
|
|
132
132
|
uv run bandit -r gpjax/ -f json -o bandit-report.json || echo "Bandit scan completed with warnings"
|
|
133
133
|
|
|
134
134
|
- name: Upload security reports
|
|
135
|
-
uses: actions/upload-artifact@
|
|
135
|
+
uses: actions/upload-artifact@v6
|
|
136
136
|
with:
|
|
137
137
|
name: security-reports
|
|
138
138
|
path: |
|
|
@@ -146,7 +146,7 @@ jobs:
|
|
|
146
146
|
|
|
147
147
|
steps:
|
|
148
148
|
- name: Checkout repository
|
|
149
|
-
uses: actions/checkout@
|
|
149
|
+
uses: actions/checkout@v6
|
|
150
150
|
|
|
151
151
|
- name: Set up Python
|
|
152
152
|
uses: actions/setup-python@v6
|
|
@@ -166,7 +166,7 @@ jobs:
|
|
|
166
166
|
uv run twine check dist/*
|
|
167
167
|
|
|
168
168
|
- name: Upload build artifacts
|
|
169
|
-
uses: actions/upload-artifact@
|
|
169
|
+
uses: actions/upload-artifact@v6
|
|
170
170
|
with:
|
|
171
171
|
name: dist-packages
|
|
172
172
|
path: dist/
|
|
@@ -181,7 +181,7 @@ jobs:
|
|
|
181
181
|
|
|
182
182
|
steps:
|
|
183
183
|
- name: Checkout repository
|
|
184
|
-
uses: actions/checkout@
|
|
184
|
+
uses: actions/checkout@v6
|
|
185
185
|
with:
|
|
186
186
|
fetch-depth: 0
|
|
187
187
|
|
|
@@ -261,10 +261,10 @@ jobs:
|
|
|
261
261
|
|
|
262
262
|
steps:
|
|
263
263
|
- name: Checkout repository
|
|
264
|
-
uses: actions/checkout@
|
|
264
|
+
uses: actions/checkout@v6
|
|
265
265
|
|
|
266
266
|
- name: Download build artifacts
|
|
267
|
-
uses: actions/download-artifact@
|
|
267
|
+
uses: actions/download-artifact@v7
|
|
268
268
|
with:
|
|
269
269
|
name: dist-packages
|
|
270
270
|
path: dist/
|
|
@@ -294,7 +294,7 @@ jobs:
|
|
|
294
294
|
|
|
295
295
|
steps:
|
|
296
296
|
- name: Download build artifacts
|
|
297
|
-
uses: actions/download-artifact@
|
|
297
|
+
uses: actions/download-artifact@v7
|
|
298
298
|
with:
|
|
299
299
|
name: dist-packages
|
|
300
300
|
path: dist/
|
|
@@ -20,7 +20,7 @@ jobs:
|
|
|
20
20
|
|
|
21
21
|
steps:
|
|
22
22
|
- name: Checkout repository
|
|
23
|
-
uses: actions/checkout@
|
|
23
|
+
uses: actions/checkout@v6
|
|
24
24
|
|
|
25
25
|
- name: Set up Python
|
|
26
26
|
uses: actions/setup-python@v6
|
|
@@ -47,7 +47,7 @@ jobs:
|
|
|
47
47
|
uv run bandit -r gpjax/ -f json -o bandit-report.json || true
|
|
48
48
|
|
|
49
49
|
- name: Upload dependency scan results
|
|
50
|
-
uses: actions/upload-artifact@
|
|
50
|
+
uses: actions/upload-artifact@v6
|
|
51
51
|
if: always()
|
|
52
52
|
with:
|
|
53
53
|
name: security-scan-results
|
|
@@ -62,7 +62,7 @@ jobs:
|
|
|
62
62
|
|
|
63
63
|
steps:
|
|
64
64
|
- name: Checkout repository
|
|
65
|
-
uses: actions/checkout@
|
|
65
|
+
uses: actions/checkout@v6
|
|
66
66
|
with:
|
|
67
67
|
fetch-depth: 0 # Fetch full history for comprehensive scanning
|
|
68
68
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: gpjax
|
|
3
|
-
Version: 0.13.3
|
|
3
|
+
Version: 0.13.5
|
|
4
4
|
Summary: Gaussian processes in JAX.
|
|
5
5
|
Project-URL: Documentation, https://docs.jaxgaussianprocesses.com/
|
|
6
6
|
Project-URL: Issues, https://github.com/thomaspinder/GPJax/issues
|
|
@@ -25,6 +25,7 @@ Requires-Dist: jaxtyping>0.2.10
|
|
|
25
25
|
Requires-Dist: numpy>=2.0.0
|
|
26
26
|
Requires-Dist: numpyro
|
|
27
27
|
Requires-Dist: optax>0.2.1
|
|
28
|
+
Requires-Dist: tensorstore!=0.1.76; sys_platform == 'darwin'
|
|
28
29
|
Requires-Dist: tqdm>4.66.2
|
|
29
30
|
Provides-Extra: dev
|
|
30
31
|
Requires-Dist: absolufy-imports>=0.3.1; extra == 'dev'
|
|
@@ -59,7 +60,7 @@ Requires-Dist: mkdocs-jupyter>=0.24.3; extra == 'docs'
|
|
|
59
60
|
Requires-Dist: mkdocs-literate-nav>=0.6.0; extra == 'docs'
|
|
60
61
|
Requires-Dist: mkdocs-material>=9.5.12; extra == 'docs'
|
|
61
62
|
Requires-Dist: mkdocs>=1.5.3; extra == 'docs'
|
|
62
|
-
Requires-Dist: mkdocstrings[python]<
|
|
63
|
+
Requires-Dist: mkdocstrings[python]<1.1.0; extra == 'docs'
|
|
63
64
|
Requires-Dist: nbconvert>=7.16.2; extra == 'docs'
|
|
64
65
|
Requires-Dist: networkx>=3.0; extra == 'docs'
|
|
65
66
|
Requires-Dist: pandas>=1.5.3; extra == 'docs'
|
|
@@ -141,7 +142,7 @@ GPJax into the package it is today.
|
|
|
141
142
|
> - [**Laplace Approximation**](https://docs.jaxgaussianprocesses.com/_examples/classification/#laplace-approximation)
|
|
142
143
|
> - [**Inference on Non-Euclidean Spaces**](https://docs.jaxgaussianprocesses.com/_examples/constructing_new_kernels/#custom-kernel)
|
|
143
144
|
> - [**Inference on Graphs**](https://docs.jaxgaussianprocesses.com/_examples/graph_kernels/)
|
|
144
|
-
> - [**
|
|
145
|
+
> - [**Heteroscedastic Inference**](https://docs.jaxgaussianprocesses.com/_examples/heteroscedastic_inference/)
|
|
145
146
|
> - [**Learning Gaussian Process Barycentres**](https://docs.jaxgaussianprocesses.com/_examples/barycentres/)
|
|
146
147
|
> - [**Deep Kernel Regression**](https://docs.jaxgaussianprocesses.com/_examples/deep_kernels/)
|
|
147
148
|
> - [**Poisson Regression**](https://docs.jaxgaussianprocesses.com/_examples/poisson/)
|
|
@@ -70,7 +70,7 @@ GPJax into the package it is today.
|
|
|
70
70
|
> - [**Laplace Approximation**](https://docs.jaxgaussianprocesses.com/_examples/classification/#laplace-approximation)
|
|
71
71
|
> - [**Inference on Non-Euclidean Spaces**](https://docs.jaxgaussianprocesses.com/_examples/constructing_new_kernels/#custom-kernel)
|
|
72
72
|
> - [**Inference on Graphs**](https://docs.jaxgaussianprocesses.com/_examples/graph_kernels/)
|
|
73
|
-
> - [**
|
|
73
|
+
> - [**Heteroscedastic Inference**](https://docs.jaxgaussianprocesses.com/_examples/heteroscedastic_inference/)
|
|
74
74
|
> - [**Learning Gaussian Process Barycentres**](https://docs.jaxgaussianprocesses.com/_examples/barycentres/)
|
|
75
75
|
> - [**Deep Kernel Regression**](https://docs.jaxgaussianprocesses.com/_examples/deep_kernels/)
|
|
76
76
|
> - [**Poisson Regression**](https://docs.jaxgaussianprocesses.com/_examples/poisson/)
|
|
@@ -178,12 +178,19 @@ points. We demonstrate its use in
|
|
|
178
178
|
|
|
179
179
|
## JIT compilation
|
|
180
180
|
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
181
|
+
GPJax validates parameters at construction time using two kinds of checks:
|
|
182
|
+
|
|
183
|
+
1. **Type checks** — plain Python `isinstance` checks that verify values are array-like.
|
|
184
|
+
2. **Value checks** — JAX-compatible assertions (via `checkify`) that verify constraints
|
|
185
|
+
like positivity or bounds.
|
|
186
|
+
|
|
187
|
+
During JIT tracing, concrete values are replaced by abstract tracers. The type checks
|
|
188
|
+
use `isinstance`, which is a pure Python operation that cannot be intercepted by JAX's
|
|
189
|
+
`checkify` transformation. This means that constructing GPJax objects (kernels, mean
|
|
190
|
+
functions, likelihoods, etc.) **inside** a JIT boundary will fail.
|
|
191
|
+
|
|
192
|
+
As an example, consider the following code that constructs a kernel inside a
|
|
193
|
+
JIT-compiled function:
|
|
187
194
|
|
|
188
195
|
```python
|
|
189
196
|
import jax
|
|
@@ -192,43 +199,40 @@ import gpjax as gpx
|
|
|
192
199
|
|
|
193
200
|
x = jnp.linspace(0, 1, 10)[:, None]
|
|
194
201
|
|
|
195
|
-
def
|
|
202
|
+
def compute_gram_bad(lengthscale):
|
|
196
203
|
k = gpx.kernels.RBF(active_dims=[0], lengthscale=lengthscale, variance=jnp.array(1.0))
|
|
197
204
|
return k.gram(x)
|
|
198
205
|
|
|
199
|
-
|
|
206
|
+
compute_gram_bad(1.0) # works fine outside JIT
|
|
200
207
|
```
|
|
201
208
|
|
|
202
|
-
|
|
209
|
+
If we try to JIT compile this function, we get a `TypeError` because the kernel
|
|
210
|
+
constructor receives a JAX tracer instead of a concrete array:
|
|
203
211
|
|
|
204
212
|
```python
|
|
205
|
-
|
|
213
|
+
jit_compute_gram_bad = jax.jit(compute_gram_bad)
|
|
206
214
|
try:
|
|
207
|
-
|
|
215
|
+
jit_compute_gram_bad(1.0)
|
|
208
216
|
except Exception as e:
|
|
209
217
|
print(e)
|
|
210
218
|
```
|
|
211
219
|
|
|
212
|
-
|
|
213
|
-
that the lengthscale is positive. It does not matter that the assertion is satisfied;
|
|
214
|
-
the very presence of the assertion will break JIT compilation.
|
|
220
|
+
### The fix: construct objects outside JIT
|
|
215
221
|
|
|
216
|
-
|
|
217
|
-
|
|
222
|
+
The solution is to construct GPJax objects **outside** the JIT boundary and only JIT the
|
|
223
|
+
computation itself. This follows the standard JAX pattern of keeping object construction
|
|
224
|
+
separate from traced computation:
|
|
218
225
|
|
|
219
226
|
```python
|
|
220
|
-
|
|
227
|
+
k = gpx.kernels.RBF(active_dims=[0], lengthscale=1.0, variance=jnp.array(1.0))
|
|
221
228
|
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
By virtue of the `checkify.checkify`, a tuple is returned where the first element is the
|
|
226
|
-
output of the assertion, and the second element is the value of the function.
|
|
229
|
+
@jax.jit
|
|
230
|
+
def compute_gram(x):
|
|
231
|
+
return k.gram(x)
|
|
227
232
|
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
guardrails in a less intrusive manner. However, for now, should you try to JIT compile
|
|
231
|
-
a component of GPJax wherein there is an assertion, you will need to wrap the function
|
|
232
|
-
in `checkify.checkify` as shown above.
|
|
233
|
+
result = compute_gram(x)
|
|
234
|
+
```
|
|
233
235
|
|
|
234
|
-
|
|
236
|
+
More generally, any GPJax object should be constructed outside of `jax.jit`, `jax.vmap`,
|
|
237
|
+
or `jax.grad` boundaries. Once constructed, their methods can be freely used inside
|
|
238
|
+
these JAX transformations.
|