gpjax 0.11.2__tar.gz → 0.12.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gpjax-0.12.0/.claude/settings.local.json +32 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/.github/pull_request_template.md +1 -1
- {gpjax-0.11.2 → gpjax-0.12.0}/.github/workflows/build_docs.yml +9 -5
- {gpjax-0.11.2 → gpjax-0.12.0}/.github/workflows/integration.yml +9 -5
- {gpjax-0.11.2 → gpjax-0.12.0}/.github/workflows/test_docs.yml +9 -5
- {gpjax-0.11.2 → gpjax-0.12.0}/.github/workflows/tests.yml +11 -8
- gpjax-0.12.0/CLAUDE.md +100 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/Makefile +8 -3
- {gpjax-0.11.2 → gpjax-0.12.0}/PKG-INFO +49 -9
- {gpjax-0.11.2 → gpjax-0.12.0}/README.md +5 -6
- {gpjax-0.11.2 → gpjax-0.12.0}/docs/contributing.md +4 -4
- {gpjax-0.11.2 → gpjax-0.12.0}/docs/installation.md +2 -2
- {gpjax-0.11.2 → gpjax-0.12.0}/examples/barycentres.py +1 -1
- {gpjax-0.11.2 → gpjax-0.12.0}/examples/classification.py +6 -7
- {gpjax-0.11.2 → gpjax-0.12.0}/examples/collapsed_vi.py +2 -2
- {gpjax-0.11.2 → gpjax-0.12.0}/examples/constructing_new_kernels.py +3 -3
- {gpjax-0.11.2 → gpjax-0.12.0}/examples/deep_kernels.py +1 -1
- {gpjax-0.11.2 → gpjax-0.12.0}/examples/graph_kernels.py +1 -1
- {gpjax-0.11.2 → gpjax-0.12.0}/examples/intro_to_gps.py +1 -1
- {gpjax-0.11.2 → gpjax-0.12.0}/examples/intro_to_kernels.py +1 -1
- {gpjax-0.11.2 → gpjax-0.12.0}/examples/likelihoods_guide.py +1 -1
- gpjax-0.12.0/examples/oak_autompg_example.py +603 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/examples/oceanmodelling.py +1 -1
- {gpjax-0.11.2 → gpjax-0.12.0}/examples/poisson.py +1 -1
- {gpjax-0.11.2 → gpjax-0.12.0}/examples/regression.py +2 -2
- {gpjax-0.11.2 → gpjax-0.12.0}/examples/uncollapsed_vi.py +2 -2
- {gpjax-0.11.2 → gpjax-0.12.0}/examples/yacht.py +1 -1
- {gpjax-0.11.2 → gpjax-0.12.0}/gpjax/__init__.py +1 -1
- {gpjax-0.11.2 → gpjax-0.12.0}/gpjax/distributions.py +16 -56
- {gpjax-0.11.2 → gpjax-0.12.0}/gpjax/gps.py +34 -48
- {gpjax-0.11.2 → gpjax-0.12.0}/gpjax/kernels/base.py +1 -1
- {gpjax-0.11.2 → gpjax-0.12.0}/gpjax/kernels/computations/base.py +7 -7
- {gpjax-0.11.2 → gpjax-0.12.0}/gpjax/kernels/computations/basis_functions.py +6 -5
- {gpjax-0.11.2 → gpjax-0.12.0}/gpjax/kernels/computations/constant_diagonal.py +10 -12
- {gpjax-0.11.2 → gpjax-0.12.0}/gpjax/kernels/computations/diagonal.py +6 -6
- gpjax-0.12.0/gpjax/linalg/__init__.py +37 -0
- gpjax-0.12.0/gpjax/linalg/operations.py +237 -0
- gpjax-0.12.0/gpjax/linalg/operators.py +411 -0
- gpjax-0.12.0/gpjax/linalg/utils.py +33 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/gpjax/objectives.py +21 -21
- {gpjax-0.11.2 → gpjax-0.12.0}/gpjax/parameters.py +11 -13
- {gpjax-0.11.2 → gpjax-0.12.0}/gpjax/variational_families.py +43 -37
- {gpjax-0.11.2 → gpjax-0.12.0}/pyproject.toml +85 -37
- gpjax-0.12.0/scripts/format.sh +20 -0
- gpjax-0.12.0/scripts/format_code.py +48 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/static/CONTRIBUTING.md +4 -4
- gpjax-0.12.0/tasks/prd-advanced-state-management.md +51 -0
- gpjax-0.12.0/tasks/prd-dynamic-parameter-training.md +62 -0
- gpjax-0.12.0/tasks/tasks-prd-advanced-state-management.md +37 -0
- gpjax-0.12.0/tasks/tasks-prd-dynamic-parameter-training.md +53 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/tests/test_gaussian_distribution.py +9 -10
- {gpjax-0.11.2 → gpjax-0.12.0}/tests/test_gps.py +6 -14
- {gpjax-0.11.2 → gpjax-0.12.0}/tests/test_kernels/test_approximations.py +1 -1
- {gpjax-0.11.2 → gpjax-0.12.0}/tests/test_kernels/test_computation.py +8 -8
- {gpjax-0.11.2 → gpjax-0.12.0}/tests/test_kernels/test_non_euclidean.py +2 -2
- {gpjax-0.11.2 → gpjax-0.12.0}/tests/test_kernels/test_nonstationary.py +1 -1
- {gpjax-0.11.2 → gpjax-0.12.0}/tests/test_kernels/test_stationary.py +1 -1
- gpjax-0.12.0/tests/test_linalg.py +484 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/tests/test_mean_functions.py +17 -12
- {gpjax-0.11.2 → gpjax-0.12.0}/tests/test_parameters.py +2 -2
- gpjax-0.12.0/uv.lock +3695 -0
- gpjax-0.11.2/.github/workflows/stale_prs.yml +0 -45
- gpjax-0.11.2/gpjax/lower_cholesky.py +0 -69
- gpjax-0.11.2/tests/test_lower_cholesky.py +0 -110
- gpjax-0.11.2/uv.lock +0 -832
- {gpjax-0.11.2 → gpjax-0.12.0}/.github/CODE_OF_CONDUCT.md +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/.github/ISSUE_TEMPLATE/01_BUG_REPORT.md +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/.github/ISSUE_TEMPLATE/02_FEATURE_REQUEST.md +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/.github/ISSUE_TEMPLATE/03_CODEBASE_IMPROVEMENT.md +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/.github/ISSUE_TEMPLATE/04_DOCS_IMPROVEMENT.md +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/.github/ISSUE_TEMPLATE/config.yml +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/.github/codecov.yml +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/.github/labels.yml +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/.github/release-drafter.yml +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/.github/workflows/pr_greeting.yml +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/.github/workflows/ruff.yml +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/.gitignore +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/CITATION.bib +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/LICENSE.txt +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/docs/CODE_OF_CONDUCT.md +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/docs/GOVERNANCE.md +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/docs/design.md +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/docs/index.md +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/docs/index.rst +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/docs/javascripts/katex.js +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/docs/refs.bib +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/docs/scripts/gen_examples.py +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/docs/scripts/gen_pages.py +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/docs/scripts/notebook_converter.py +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/docs/scripts/sharp_bits_figure.py +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/docs/sharp_bits.md +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/docs/static/GP.pdf +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/docs/static/GP.svg +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/docs/static/bijector_figure.svg +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/docs/static/css/gpjax_theme.css +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/docs/static/favicon.ico +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/docs/static/gpjax.mplstyle +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/docs/static/gpjax_logo.pdf +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/docs/static/gpjax_logo.svg +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/docs/static/jaxkern/lato.ttf +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/docs/static/jaxkern/logo.png +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/docs/static/jaxkern/logo.svg +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/docs/static/jaxkern/main.py +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/docs/static/step_size_figure.png +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/docs/static/step_size_figure.svg +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/docs/stylesheets/extra.css +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/docs/stylesheets/permalinks.css +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/examples/backend.py +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/examples/barycentres/barycentre_gp.gif +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/examples/data/max_tempeature_switzerland.csv +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/examples/data/yacht_hydrodynamics.data +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/examples/gpjax.mplstyle +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/examples/intro_to_gps/decomposed_mll.png +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/examples/intro_to_gps/generating_process.png +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/examples/utils.py +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/gpjax/citation.py +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/gpjax/dataset.py +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/gpjax/fit.py +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/gpjax/integrators.py +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/gpjax/kernels/__init__.py +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/gpjax/kernels/approximations/__init__.py +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/gpjax/kernels/approximations/rff.py +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/gpjax/kernels/computations/__init__.py +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/gpjax/kernels/computations/dense.py +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/gpjax/kernels/computations/eigen.py +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/gpjax/kernels/non_euclidean/__init__.py +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/gpjax/kernels/non_euclidean/graph.py +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/gpjax/kernels/non_euclidean/utils.py +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/gpjax/kernels/nonstationary/__init__.py +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/gpjax/kernels/nonstationary/arccosine.py +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/gpjax/kernels/nonstationary/linear.py +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/gpjax/kernels/nonstationary/polynomial.py +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/gpjax/kernels/stationary/__init__.py +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/gpjax/kernels/stationary/base.py +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/gpjax/kernels/stationary/matern12.py +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/gpjax/kernels/stationary/matern32.py +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/gpjax/kernels/stationary/matern52.py +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/gpjax/kernels/stationary/periodic.py +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/gpjax/kernels/stationary/powered_exponential.py +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/gpjax/kernels/stationary/rational_quadratic.py +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/gpjax/kernels/stationary/rbf.py +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/gpjax/kernels/stationary/utils.py +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/gpjax/kernels/stationary/white.py +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/gpjax/likelihoods.py +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/gpjax/mean_functions.py +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/gpjax/numpyro_extras.py +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/gpjax/scan.py +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/gpjax/typing.py +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/mkdocs.yml +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/static/paper.bib +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/static/paper.md +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/static/paper.pdf +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/tests/__init__.py +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/tests/conftest.py +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/tests/integration_tests.py +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/tests/test_citations.py +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/tests/test_dataset.py +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/tests/test_fit.py +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/tests/test_integrators.py +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/tests/test_kernels/__init__.py +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/tests/test_kernels/test_base.py +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/tests/test_kernels/test_utils.py +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/tests/test_likelihoods.py +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/tests/test_markdown.py +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/tests/test_numpyro_extras.py +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/tests/test_objectives.py +0 -0
- {gpjax-0.11.2 → gpjax-0.12.0}/tests/test_variational_families.py +0 -0
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
{
|
|
2
|
+
"permissions": {
|
|
3
|
+
"allow": [
|
|
4
|
+
"WebFetch(domain:raw.githubusercontent.com)",
|
|
5
|
+
"Bash(hatch run dev:test:*)",
|
|
6
|
+
"Bash(hatch run dev:python:*)",
|
|
7
|
+
"Bash(hatch run dev:lint:*)",
|
|
8
|
+
"Bash(hatch env:*)",
|
|
9
|
+
"Bash(hatch run dev:format:*)",
|
|
10
|
+
"Bash(hatch run docs:ipython:*)",
|
|
11
|
+
"Bash(timeout 30s hatch run docs:ipython examples/orthogonal_additive_kernel.py)",
|
|
12
|
+
"Bash(gtimeout:*)",
|
|
13
|
+
"mcp__fetch__imageFetch",
|
|
14
|
+
"Bash(hatch run:*)",
|
|
15
|
+
"Bash(mkdir:*)",
|
|
16
|
+
"Bash(find:*)",
|
|
17
|
+
"Bash(grep:*)",
|
|
18
|
+
"WebFetch(domain:flax.readthedocs.io)",
|
|
19
|
+
"WebSearch",
|
|
20
|
+
"WebFetch(domain:docs.jax.dev)",
|
|
21
|
+
"mcp__github__search_pull_requests",
|
|
22
|
+
"mcp__github__get_me",
|
|
23
|
+
"mcp__github__get_pull_request_status",
|
|
24
|
+
"mcp__github__list_workflows",
|
|
25
|
+
"mcp__github__list_workflow_runs",
|
|
26
|
+
"mcp__github__get_job_logs",
|
|
27
|
+
"Bash(uv run:*)"
|
|
28
|
+
],
|
|
29
|
+
"deny": [],
|
|
30
|
+
"ask": []
|
|
31
|
+
}
|
|
32
|
+
}
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
## Checklist
|
|
2
2
|
|
|
3
|
-
- [ ] I've formatted the new code by running `
|
|
3
|
+
- [ ] I've formatted the new code by running `uv run poe format` before committing.
|
|
4
4
|
- [ ] I've added tests for new code.
|
|
5
5
|
- [ ] I've added docstrings for the new code.
|
|
6
6
|
|
|
@@ -22,7 +22,7 @@ jobs:
|
|
|
22
22
|
strategy:
|
|
23
23
|
matrix:
|
|
24
24
|
os: ["ubuntu-latest"]
|
|
25
|
-
python-version: ["3.
|
|
25
|
+
python-version: ["3.11"]
|
|
26
26
|
|
|
27
27
|
steps:
|
|
28
28
|
# Grab the latest commit from the branch
|
|
@@ -47,14 +47,18 @@ jobs:
|
|
|
47
47
|
run: |
|
|
48
48
|
npm install katex
|
|
49
49
|
|
|
50
|
-
# Install
|
|
51
|
-
- name: Install
|
|
52
|
-
uses:
|
|
50
|
+
# Install uv
|
|
51
|
+
- name: Install uv
|
|
52
|
+
uses: astral-sh/setup-uv@v3
|
|
53
|
+
with:
|
|
54
|
+
version: "latest"
|
|
53
55
|
|
|
54
56
|
- name: Build the documentation with MKDocs
|
|
55
57
|
run: |
|
|
56
58
|
conda install pandoc
|
|
57
|
-
|
|
59
|
+
uv sync --extra docs
|
|
60
|
+
uv run python docs/scripts/gen_examples.py --execute
|
|
61
|
+
uv run mkdocs build
|
|
58
62
|
|
|
59
63
|
- name: Deploy Page 🚀
|
|
60
64
|
uses: JamesIves/github-pages-deploy-action@v4.4.1
|
|
@@ -13,7 +13,7 @@ jobs:
|
|
|
13
13
|
matrix:
|
|
14
14
|
# Select the Python versions to test against
|
|
15
15
|
os: ["ubuntu-latest", "macos-latest"]
|
|
16
|
-
python-version: ["3.10", "3.11"]
|
|
16
|
+
python-version: ["3.10", "3.11", "3.12", "3.13"]
|
|
17
17
|
fail-fast: true
|
|
18
18
|
steps:
|
|
19
19
|
- name: Check out the code
|
|
@@ -25,10 +25,14 @@ jobs:
|
|
|
25
25
|
with:
|
|
26
26
|
python-version: ${{ matrix.python-version }}
|
|
27
27
|
|
|
28
|
-
# Install
|
|
29
|
-
- name: Install
|
|
30
|
-
uses:
|
|
28
|
+
# Install uv
|
|
29
|
+
- name: Install uv
|
|
30
|
+
uses: astral-sh/setup-uv@v3
|
|
31
|
+
with:
|
|
32
|
+
version: "latest"
|
|
31
33
|
|
|
32
34
|
# Run the unit tests and build the coverage report
|
|
33
35
|
- name: Run Integration Tests
|
|
34
|
-
run:
|
|
36
|
+
run: |
|
|
37
|
+
uv sync --extra docs
|
|
38
|
+
uv run python tests/integration_tests.py
|
|
@@ -17,7 +17,7 @@ jobs:
|
|
|
17
17
|
strategy:
|
|
18
18
|
matrix:
|
|
19
19
|
os: ["ubuntu-latest"]
|
|
20
|
-
python-version: ["3.
|
|
20
|
+
python-version: ["3.11"]
|
|
21
21
|
|
|
22
22
|
steps:
|
|
23
23
|
# Grab the latest commit from the branch
|
|
@@ -33,11 +33,15 @@ jobs:
|
|
|
33
33
|
auto-update-conda: true
|
|
34
34
|
python-version: ${{ matrix.python-version }}
|
|
35
35
|
|
|
36
|
-
# Install
|
|
37
|
-
- name: Install
|
|
38
|
-
uses:
|
|
36
|
+
# Install uv
|
|
37
|
+
- name: Install uv
|
|
38
|
+
uses: astral-sh/setup-uv@v3
|
|
39
|
+
with:
|
|
40
|
+
version: "latest"
|
|
39
41
|
|
|
40
42
|
- name: Build the documentation with MKDocs
|
|
41
43
|
run: |
|
|
42
44
|
conda install pandoc
|
|
43
|
-
|
|
45
|
+
uv sync --extra docs
|
|
46
|
+
uv run python docs/scripts/gen_examples.py --execute
|
|
47
|
+
uv run mkdocs build
|
|
@@ -8,13 +8,12 @@ on:
|
|
|
8
8
|
jobs:
|
|
9
9
|
unit-tests:
|
|
10
10
|
name: Run Tests
|
|
11
|
-
runs-on: ubuntu-latest
|
|
12
11
|
strategy:
|
|
13
12
|
matrix:
|
|
14
|
-
# Select the Python versions to test against
|
|
15
13
|
os: ["ubuntu-latest", "macos-latest"]
|
|
16
|
-
python-version: ["3.10", "3.11"]
|
|
14
|
+
python-version: ["3.10", "3.11", "3.12", "3.13"]
|
|
17
15
|
fail-fast: true
|
|
16
|
+
runs-on: ${{ matrix.os }}
|
|
18
17
|
steps:
|
|
19
18
|
- name: Check out the code
|
|
20
19
|
uses: actions/checkout@v3.5.2
|
|
@@ -25,18 +24,22 @@ jobs:
|
|
|
25
24
|
with:
|
|
26
25
|
python-version: ${{ matrix.python-version }}
|
|
27
26
|
|
|
28
|
-
# Install
|
|
29
|
-
- name: Install
|
|
30
|
-
uses:
|
|
27
|
+
# Install uv
|
|
28
|
+
- name: Install uv
|
|
29
|
+
uses: astral-sh/setup-uv@v3
|
|
30
|
+
with:
|
|
31
|
+
version: "latest"
|
|
31
32
|
|
|
32
33
|
# Install the dependencies
|
|
33
34
|
- name: Check docstrings
|
|
34
35
|
run: |
|
|
35
|
-
|
|
36
|
+
uv sync --extra dev
|
|
37
|
+
uv run xdoctest ./gpjax
|
|
36
38
|
|
|
37
39
|
# Run the unit tests and build the coverage report
|
|
38
40
|
- name: Run Tests
|
|
39
|
-
run:
|
|
41
|
+
run: uv run pytest . -v --cov=./gpjax --cov-report=xml:./coverage.xml
|
|
42
|
+
|
|
40
43
|
|
|
41
44
|
- name: Upload code coverage
|
|
42
45
|
uses: codecov/codecov-action@v3
|
gpjax-0.12.0/CLAUDE.md
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
# CLAUDE.md
|
|
2
|
+
|
|
3
|
+
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
|
4
|
+
|
|
5
|
+
## Development Commands
|
|
6
|
+
|
|
7
|
+
### Testing
|
|
8
|
+
- **Run all tests**: `hatch run dev:test` or `pytest . -v -n 4 --beartype-packages='gpjax'`
|
|
9
|
+
- **Run tests with coverage**: `hatch run dev:coverage`
|
|
10
|
+
- **Run a single test file**: `pytest tests/test_file.py -v`
|
|
11
|
+
- **Run docstring tests**: `hatch run dev:docstrings`
|
|
12
|
+
- **Full test suite**: `hatch run dev:all-tests` (includes format check, docstrings, and tests)
|
|
13
|
+
|
|
14
|
+
### Code Formatting and Linting
|
|
15
|
+
- **Format all code**: `hatch run dev:format` (runs black, isort, and ruff format)
|
|
16
|
+
- **Check formatting**: `hatch run dev:check` (black-check, imports-check, lint-check)
|
|
17
|
+
- **Format with black**: `hatch run dev:black-format`
|
|
18
|
+
- **Format imports**: `hatch run dev:imports-format`
|
|
19
|
+
- **Lint and auto-fix**: `hatch run dev:lint-check`
|
|
20
|
+
- **Remove unused imports**: `hatch run dev:remove-unused`
|
|
21
|
+
|
|
22
|
+
### Documentation
|
|
23
|
+
- **Build docs**: `hatch run docs:build`
|
|
24
|
+
- **Serve docs locally**: `hatch run docs:serve`
|
|
25
|
+
- **Run integration tests**: `hatch run docs:integration`
|
|
26
|
+
|
|
27
|
+
### Build and Installation
|
|
28
|
+
- **Install for development**: `uv sync --extra dev` (or, with Hatch, `hatch env create` then `hatch shell`)
|
|
29
|
+
- **Install stable version**: `pip install gpjax`
|
|
30
|
+
|
|
31
|
+
## High-Level Architecture
|
|
32
|
+
|
|
33
|
+
GPJax is a Gaussian Process library built on JAX and Flax (nnx), designed with a modular architecture that closely mirrors mathematical abstractions:
|
|
34
|
+
|
|
35
|
+
### Core Components
|
|
36
|
+
|
|
37
|
+
1. **Gaussian Processes (`gpjax.gps`)**
|
|
38
|
+
- `AbstractPrior`: Base class for GP priors combining kernels and mean functions
|
|
39
|
+
- `Prior`: Standard GP prior implementation
|
|
40
|
+
- `AbstractPosterior`: Base class for posterior inference
|
|
41
|
+
- `ConjugatePosterior`: Exact inference for Gaussian likelihoods
|
|
42
|
+
- `NonConjugatePosterior`: Approximate inference for non-Gaussian likelihoods
|
|
43
|
+
- Uses Flax nnx.Module for parameter management
|
|
44
|
+
|
|
45
|
+
2. **Kernels (`gpjax.kernels`)**
|
|
46
|
+
- `AbstractKernel`: Base kernel class with composition support (+, *)
|
|
47
|
+
- Stationary kernels: RBF, Matern12/32/52, RationalQuadratic, Periodic, White
|
|
48
|
+
- Non-stationary kernels: Linear, Polynomial, ArcCosine
|
|
49
|
+
- Non-Euclidean kernels: Graph kernels for structured data
|
|
50
|
+
- Kernel computations: Dense, Diagonal, Eigen decomposition strategies
|
|
51
|
+
- Approximations: Random Fourier Features (RFF)
|
|
52
|
+
|
|
53
|
+
3. **Likelihoods (`gpjax.likelihoods`)**
|
|
54
|
+
- `AbstractLikelihood`: Base class defining the observation model
|
|
55
|
+
- `Gaussian`: Standard Gaussian likelihood with observation noise
|
|
56
|
+
- Non-Gaussian: Bernoulli, Poisson for classification and count data
|
|
57
|
+
- Links prediction and observation spaces
|
|
58
|
+
|
|
59
|
+
4. **Variational Inference (`gpjax.variational_families`)**
|
|
60
|
+
- `AbstractVariationalFamily`: Base for variational approximations
|
|
61
|
+
- `VariationalGaussian`: Mean-field Gaussian variational distribution
|
|
62
|
+
- `WhitenedVariationalGaussian`: Whitened parameterization
|
|
63
|
+
- `NaturalVariationalGaussian`: Natural gradient parameterization
|
|
64
|
+
- Supports both collapsed and uncollapsed VI
|
|
65
|
+
|
|
66
|
+
5. **Objectives (`gpjax.objectives`)**
|
|
67
|
+
- `AbstractObjective`: Base class for optimization objectives
|
|
68
|
+
- `ConjugateMLL`: Marginal likelihood for exact inference
|
|
69
|
+
- `NonConjugateMLL`: Laplace approximation for non-Gaussian likelihoods
|
|
70
|
+
- `ELBO`: Evidence lower bound for variational inference
|
|
71
|
+
- `CollapsedELBO`: Analytically integrated ELBO
|
|
72
|
+
|
|
73
|
+
6. **Optimization (`gpjax.fit`)**
|
|
74
|
+
- `fit()`: General optimizer using Optax optimizers
|
|
75
|
+
- `fit_lbfgs()`: L-BFGS optimization via Optax
|
|
76
|
+
- `fit_scipy()`: Interface to scipy optimizers via JAXopt
|
|
77
|
+
- Supports custom loss functions and stopping criteria
|
|
78
|
+
|
|
79
|
+
### Key Design Patterns
|
|
80
|
+
|
|
81
|
+
- **Functional Design**: Functions are first-class, mirroring mathematical notation
|
|
82
|
+
- **Composability**: Kernels support arithmetic operations for easy combination
|
|
83
|
+
- **JAX Integration**: Full support for JIT compilation, automatic differentiation, and vectorization
|
|
84
|
+
- **Type Safety**: Extensive use of jaxtyping and beartype for runtime type checking
|
|
85
|
+
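- **Linear Algebra**: Uses the in-house `gpjax.linalg` module (e.g. the `PSD` operator and the `lower_cholesky` and `solve` operations), which replaces the former CoLA (Compositional Linear Algebra) dependency, for efficient linear algebra operations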
|
|
86
|
+
- **Parameter Management**: Uses Flax nnx for trainable parameters with PyTree support
|
|
87
|
+
|
|
88
|
+
### Data Structures
|
|
89
|
+
|
|
90
|
+
- `Dataset`: Simple dataclass for (X, y) pairs with optional y
|
|
91
|
+
- `GaussianDistribution`: Represents multivariate Gaussians with efficient sampling
|
|
92
|
+
- Parameters use `nnx.Param` with transformation support (e.g., `SoftplusTransformation`)
|
|
93
|
+
|
|
94
|
+
### Important Implementation Details
|
|
95
|
+
|
|
96
|
+
- All kernels must implement `__call__(x, y)` for single point evaluation
|
|
97
|
+
- Kernel matrices are computed via `compute_engine.gram()` or `compute_engine.cross_covariance()`
|
|
98
|
+
- Jitter (small diagonal noise) is added for numerical stability, default 1e-6
|
|
99
|
+
- Uses `gpjax.linalg.PSD` annotations for positive semi-definite matrices
|
|
100
|
+
- Cholesky decompositions use the custom `lower_cholesky` from `gpjax.linalg` for better gradients
|
|
@@ -21,11 +21,16 @@ black: ## Format code in-place using black.
|
|
|
21
21
|
isort: ## Format imports in-place using isort.
|
|
22
22
|
isort ${PKGROOT}/ tests/
|
|
23
23
|
|
|
24
|
-
format: ## Code styling - black, isort
|
|
25
|
-
black ${PKGROOT}/ tests/
|
|
24
|
+
format: ## Code styling - black, isort, ruff
|
|
25
|
+
uv run black ${PKGROOT}/ tests/
|
|
26
26
|
@printf "\033[1;34mBlack passes!\033[0m\n\n"
|
|
27
|
-
|
|
27
|
+
uv run jupytext --pipe black examples/*.py
|
|
28
|
+
@printf "\033[1;34mJupytext passes!\033[0m\n\n"
|
|
29
|
+
uv run isort ${PKGROOT}/ tests/
|
|
30
|
+
uv run isort examples/*.py --treat-comment-as-code '# %%' --float-to-top
|
|
28
31
|
@printf "\033[1;34misort passes!\033[0m\n\n"
|
|
32
|
+
uv run ruff format ${PKGROOT}/ tests/ examples/
|
|
33
|
+
@printf "\033[1;34mRuff passes!\033[0m\n\n"
|
|
29
34
|
|
|
30
35
|
##@ Testing
|
|
31
36
|
test: ## Test code using pytest.
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: gpjax
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.12.0
|
|
4
4
|
Summary: Gaussian processes in JAX.
|
|
5
5
|
Project-URL: Documentation, https://docs.jaxgaussianprocesses.com/
|
|
6
6
|
Project-URL: Issues, https://github.com/JaxGaussianProcesses/GPJax/issues
|
|
@@ -14,11 +14,11 @@ Classifier: Programming Language :: Python
|
|
|
14
14
|
Classifier: Programming Language :: Python :: 3.10
|
|
15
15
|
Classifier: Programming Language :: Python :: 3.11
|
|
16
16
|
Classifier: Programming Language :: Python :: 3.12
|
|
17
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
17
18
|
Classifier: Programming Language :: Python :: Implementation :: CPython
|
|
18
19
|
Classifier: Programming Language :: Python :: Implementation :: PyPy
|
|
19
|
-
Requires-Python:
|
|
20
|
+
Requires-Python: <=3.13,>=3.10
|
|
20
21
|
Requires-Dist: beartype>0.16.1
|
|
21
|
-
Requires-Dist: cola-ml>=0.0.7
|
|
22
22
|
Requires-Dist: flax>=0.10.0
|
|
23
23
|
Requires-Dist: jax>=0.5.0
|
|
24
24
|
Requires-Dist: jaxlib>=0.5.0
|
|
@@ -27,6 +27,47 @@ Requires-Dist: numpy>=2.0.0
|
|
|
27
27
|
Requires-Dist: numpyro
|
|
28
28
|
Requires-Dist: optax>0.2.1
|
|
29
29
|
Requires-Dist: tqdm>4.66.2
|
|
30
|
+
Provides-Extra: dev
|
|
31
|
+
Requires-Dist: absolufy-imports>=0.3.1; extra == 'dev'
|
|
32
|
+
Requires-Dist: autoflake; extra == 'dev'
|
|
33
|
+
Requires-Dist: black; extra == 'dev'
|
|
34
|
+
Requires-Dist: codespell>=2.2.4; extra == 'dev'
|
|
35
|
+
Requires-Dist: coverage>=7.2.2; extra == 'dev'
|
|
36
|
+
Requires-Dist: interrogate>=1.5.0; extra == 'dev'
|
|
37
|
+
Requires-Dist: isort; extra == 'dev'
|
|
38
|
+
Requires-Dist: jupytext; extra == 'dev'
|
|
39
|
+
Requires-Dist: mktestdocs>=0.2.1; extra == 'dev'
|
|
40
|
+
Requires-Dist: networkx; extra == 'dev'
|
|
41
|
+
Requires-Dist: pre-commit>=3.2.2; extra == 'dev'
|
|
42
|
+
Requires-Dist: pytest-beartype; extra == 'dev'
|
|
43
|
+
Requires-Dist: pytest-cov>=4.0.0; extra == 'dev'
|
|
44
|
+
Requires-Dist: pytest-pretty>=1.1.1; extra == 'dev'
|
|
45
|
+
Requires-Dist: pytest-xdist>=3.2.1; extra == 'dev'
|
|
46
|
+
Requires-Dist: pytest>=7.2.2; extra == 'dev'
|
|
47
|
+
Requires-Dist: ruff>=0.6; extra == 'dev'
|
|
48
|
+
Requires-Dist: xdoctest>=1.1.1; extra == 'dev'
|
|
49
|
+
Provides-Extra: docs
|
|
50
|
+
Requires-Dist: blackjax>=0.9.6; extra == 'docs'
|
|
51
|
+
Requires-Dist: ipykernel>=6.22.0; extra == 'docs'
|
|
52
|
+
Requires-Dist: ipython>=8.11.0; extra == 'docs'
|
|
53
|
+
Requires-Dist: ipywidgets>=8.0.5; extra == 'docs'
|
|
54
|
+
Requires-Dist: jupytext>=1.14.5; extra == 'docs'
|
|
55
|
+
Requires-Dist: markdown-katex>=202406.1035; extra == 'docs'
|
|
56
|
+
Requires-Dist: matplotlib>=3.7.1; extra == 'docs'
|
|
57
|
+
Requires-Dist: mkdocs-gen-files>=0.5.0; extra == 'docs'
|
|
58
|
+
Requires-Dist: mkdocs-git-authors-plugin>=0.7.0; extra == 'docs'
|
|
59
|
+
Requires-Dist: mkdocs-jupyter>=0.24.3; extra == 'docs'
|
|
60
|
+
Requires-Dist: mkdocs-literate-nav>=0.6.0; extra == 'docs'
|
|
61
|
+
Requires-Dist: mkdocs-material>=9.5.12; extra == 'docs'
|
|
62
|
+
Requires-Dist: mkdocs>=1.5.3; extra == 'docs'
|
|
63
|
+
Requires-Dist: mkdocstrings[python]<0.28.0; extra == 'docs'
|
|
64
|
+
Requires-Dist: nbconvert>=7.16.2; extra == 'docs'
|
|
65
|
+
Requires-Dist: networkx>=3.0; extra == 'docs'
|
|
66
|
+
Requires-Dist: pandas>=1.5.3; extra == 'docs'
|
|
67
|
+
Requires-Dist: pymdown-extensions>=10.7.1; extra == 'docs'
|
|
68
|
+
Requires-Dist: scikit-learn>=1.5.1; extra == 'docs'
|
|
69
|
+
Requires-Dist: seaborn>=0.12.2; extra == 'docs'
|
|
70
|
+
Requires-Dist: watermark>=2.3.1; extra == 'docs'
|
|
30
71
|
Description-Content-Type: text/markdown
|
|
31
72
|
|
|
32
73
|
<!-- <h1 align='center'>GPJax</h1>
|
|
@@ -41,12 +82,12 @@ Description-Content-Type: text/markdown
|
|
|
41
82
|
[](https://badge.fury.io/py/GPJax)
|
|
42
83
|
[](https://doi.org/10.21105/joss.04455)
|
|
43
84
|
[](https://pepy.tech/project/gpjax)
|
|
44
|
-
[](https://join.slack.com/t/gpjax/shared_invite/zt-
|
|
85
|
+
[](https://join.slack.com/t/gpjax/shared_invite/zt-3cesiykcx-nzajjRdnV3ohw7~~eMlCYA)
|
|
45
86
|
|
|
46
87
|
[**Quickstart**](#simple-example)
|
|
47
88
|
| [**Install guide**](#installation)
|
|
48
89
|
| [**Documentation**](https://docs.jaxgaussianprocesses.com/)
|
|
49
|
-
| [**Slack Community**](https://join.slack.com/t/gpjax/shared_invite/zt-
|
|
90
|
+
| [**Slack Community**](https://join.slack.com/t/gpjax/shared_invite/zt-3cesiykcx-nzajjRdnV3ohw7~~eMlCYA)
|
|
50
91
|
|
|
51
92
|
GPJax aims to provide a low-level interface to Gaussian process (GP) models in
|
|
52
93
|
[Jax](https://github.com/google/jax), structured to give researchers maximum
|
|
@@ -81,7 +122,7 @@ behaviours through [this form](https://jaxgaussianprocesses.com/contact/) or rea
|
|
|
81
122
|
one of the project's [_gardeners_](https://docs.jaxgaussianprocesses.com/GOVERNANCE/#roles).
|
|
82
123
|
|
|
83
124
|
Feel free to join our [Slack
|
|
84
|
-
Channel](https://join.slack.com/t/gpjax/shared_invite/zt-
|
|
125
|
+
Channel](https://join.slack.com/t/gpjax/shared_invite/zt-3cesiykcx-nzajjRdnV3ohw7~~eMlCYA),
|
|
85
126
|
where we can discuss the development of GPJax and broader support for Gaussian
|
|
86
127
|
process modelling.
|
|
87
128
|
|
|
@@ -177,14 +218,13 @@ configuration in development mode.
|
|
|
177
218
|
```bash
|
|
178
219
|
git clone https://github.com/JaxGaussianProcesses/GPJax.git
|
|
179
220
|
cd GPJax
|
|
180
|
-
|
|
181
|
-
hatch shell
|
|
221
|
+
uv sync --extra dev
|
|
182
222
|
```
|
|
183
223
|
|
|
184
224
|
> We recommend you check your installation passes the supplied unit tests:
|
|
185
225
|
>
|
|
186
226
|
> ```python
|
|
187
|
-
>
|
|
227
|
+
> uv run pytest --beartype-packages='gpjax'
|
|
188
228
|
> ```
|
|
189
229
|
|
|
190
230
|
# Citing GPJax
|
|
@@ -10,12 +10,12 @@
|
|
|
10
10
|
[](https://badge.fury.io/py/GPJax)
|
|
11
11
|
[](https://doi.org/10.21105/joss.04455)
|
|
12
12
|
[](https://pepy.tech/project/gpjax)
|
|
13
|
-
[](https://join.slack.com/t/gpjax/shared_invite/zt-
|
|
13
|
+
[](https://join.slack.com/t/gpjax/shared_invite/zt-3cesiykcx-nzajjRdnV3ohw7~~eMlCYA)
|
|
14
14
|
|
|
15
15
|
[**Quickstart**](#simple-example)
|
|
16
16
|
| [**Install guide**](#installation)
|
|
17
17
|
| [**Documentation**](https://docs.jaxgaussianprocesses.com/)
|
|
18
|
-
| [**Slack Community**](https://join.slack.com/t/gpjax/shared_invite/zt-
|
|
18
|
+
| [**Slack Community**](https://join.slack.com/t/gpjax/shared_invite/zt-3cesiykcx-nzajjRdnV3ohw7~~eMlCYA)
|
|
19
19
|
|
|
20
20
|
GPJax aims to provide a low-level interface to Gaussian process (GP) models in
|
|
21
21
|
[Jax](https://github.com/google/jax), structured to give researchers maximum
|
|
@@ -50,7 +50,7 @@ behaviours through [this form](https://jaxgaussianprocesses.com/contact/) or rea
|
|
|
50
50
|
one of the project's [_gardeners_](https://docs.jaxgaussianprocesses.com/GOVERNANCE/#roles).
|
|
51
51
|
|
|
52
52
|
Feel free to join our [Slack
|
|
53
|
-
Channel](https://join.slack.com/t/gpjax/shared_invite/zt-
|
|
53
|
+
Channel](https://join.slack.com/t/gpjax/shared_invite/zt-3cesiykcx-nzajjRdnV3ohw7~~eMlCYA),
|
|
54
54
|
where we can discuss the development of GPJax and broader support for Gaussian
|
|
55
55
|
process modelling.
|
|
56
56
|
|
|
@@ -146,14 +146,13 @@ configuration in development mode.
|
|
|
146
146
|
```bash
|
|
147
147
|
git clone https://github.com/JaxGaussianProcesses/GPJax.git
|
|
148
148
|
cd GPJax
|
|
149
|
-
|
|
150
|
-
hatch shell
|
|
149
|
+
uv sync --extra dev
|
|
151
150
|
```
|
|
152
151
|
|
|
153
152
|
> We recommend you check your installation passes the supplied unit tests:
|
|
154
153
|
>
|
|
155
154
|
> ```python
|
|
156
|
-
>
|
|
155
|
+
> uv run pytest --beartype-packages='gpjax'
|
|
157
156
|
> ```
|
|
158
157
|
|
|
159
158
|
# Citing GPJax
|
|
@@ -72,16 +72,16 @@ you through every detail!
|
|
|
72
72
|
Always use a `feature` branch. It's good practice to avoid
|
|
73
73
|
work on the ``main`` branch of any repository.
|
|
74
74
|
|
|
75
|
-
4. We use [
|
|
75
|
+
4. We use [uv](https://docs.astral.sh/uv/) for packaging and dependency management. Project requirements are in ``pyproject.toml``. To install GPJax with uv, run:
|
|
76
76
|
|
|
77
77
|
```bash
|
|
78
|
-
$
|
|
78
|
+
$ uv sync --extra dev
|
|
79
79
|
```
|
|
80
80
|
|
|
81
81
|
At this point we recommend you check your installation passes the supplied unit tests:
|
|
82
82
|
|
|
83
83
|
```bash
|
|
84
|
-
$
|
|
84
|
+
$ uv run poe all-tests
|
|
85
85
|
```
|
|
86
86
|
|
|
87
87
|
5. Add changed files using `git add` and then `git commit` files to record your
|
|
@@ -142,7 +142,7 @@ request, we recommend you check the following:
|
|
|
142
142
|
accepted. Test coverage can be checked with:
|
|
143
143
|
|
|
144
144
|
```bash
|
|
145
|
-
$
|
|
145
|
+
$ uv run poe coverage
|
|
146
146
|
```
|
|
147
147
|
|
|
148
148
|
Navigate to the newly created folder `htmlcov` and open `index.html` to view
|
|
@@ -31,7 +31,7 @@ hardware acceleration support as detailed in the
|
|
|
31
31
|
```bash
|
|
32
32
|
git clone https://github.com/thomaspinder/GPJax.git
|
|
33
33
|
cd GPJax
|
|
34
|
-
|
|
34
|
+
uv sync --extra dev
|
|
35
35
|
```
|
|
36
36
|
|
|
37
37
|
!!! tip
|
|
@@ -45,5 +45,5 @@ hardware acceleration support as detailed in the
|
|
|
45
45
|
and recommend you check your installation passes the supplied unit tests:
|
|
46
46
|
|
|
47
47
|
```bash
|
|
48
|
-
|
|
48
|
+
uv run poe all-tests
|
|
49
49
|
```
|
|
@@ -8,9 +8,9 @@
|
|
|
8
8
|
# extension: .py
|
|
9
9
|
# format_name: percent
|
|
10
10
|
# format_version: '1.3'
|
|
11
|
-
# jupytext_version: 1.
|
|
11
|
+
# jupytext_version: 1.11.2
|
|
12
12
|
# kernelspec:
|
|
13
|
-
# display_name:
|
|
13
|
+
# display_name: .venv
|
|
14
14
|
# language: python
|
|
15
15
|
# name: python3
|
|
16
16
|
# ---
|
|
@@ -22,7 +22,6 @@
|
|
|
22
22
|
# with non-Gaussian likelihoods via maximum a posteriori (MAP). We focus on a classification task here.
|
|
23
23
|
|
|
24
24
|
# %%
|
|
25
|
-
import cola
|
|
26
25
|
from flax import nnx
|
|
27
26
|
import jax
|
|
28
27
|
|
|
@@ -41,7 +40,7 @@ import numpyro.distributions as npd
|
|
|
41
40
|
import optax as ox
|
|
42
41
|
|
|
43
42
|
from examples.utils import use_mpl_style
|
|
44
|
-
from gpjax.
|
|
43
|
+
from gpjax.linalg import lower_cholesky, PSD, solve
|
|
45
44
|
|
|
46
45
|
config.update("jax_enable_x64", True)
|
|
47
46
|
|
|
@@ -219,7 +218,7 @@ jitter = 1e-6
|
|
|
219
218
|
# Compute (latent) function value map estimates at training points:
|
|
220
219
|
Kxx = opt_posterior.prior.kernel.gram(x)
|
|
221
220
|
Kxx += identity_matrix(D.n) * jitter
|
|
222
|
-
Kxx =
|
|
221
|
+
Kxx = PSD(Kxx)
|
|
223
222
|
Lx = lower_cholesky(Kxx)
|
|
224
223
|
f_hat = Lx @ opt_posterior.latent.value
|
|
225
224
|
|
|
@@ -267,10 +266,10 @@ def construct_laplace(test_inputs: Float[Array, "N D"]) -> npd.MultivariateNorma
|
|
|
267
266
|
Kxt = opt_posterior.prior.kernel.cross_covariance(x, test_inputs)
|
|
268
267
|
Kxx = opt_posterior.prior.kernel.gram(x)
|
|
269
268
|
Kxx += identity_matrix(D.n) * jitter
|
|
270
|
-
Kxx =
|
|
269
|
+
Kxx = PSD(Kxx)
|
|
271
270
|
|
|
272
271
|
# Kxx⁻¹ Kxt
|
|
273
|
-
Kxx_inv_Kxt =
|
|
272
|
+
Kxx_inv_Kxt = solve(Kxx, Kxt)
|
|
274
273
|
|
|
275
274
|
# Ktx Kxx⁻¹[ H⁻¹ ] Kxx⁻¹ Kxt
|
|
276
275
|
laplace_cov_term = jnp.matmul(jnp.matmul(Kxx_inv_Kxt.T, H_inv), Kxx_inv_Kxt)
|
|
@@ -8,7 +8,7 @@
|
|
|
8
8
|
# extension: .py
|
|
9
9
|
# format_name: percent
|
|
10
10
|
# format_version: '1.3'
|
|
11
|
-
# jupytext_version: 1.
|
|
11
|
+
# jupytext_version: 1.17.3
|
|
12
12
|
# kernelspec:
|
|
13
13
|
# display_name: gpjax
|
|
14
14
|
# language: python
|
|
@@ -95,7 +95,7 @@ meanf = gpx.mean_functions.Zero()
|
|
|
95
95
|
for k, ax, c in zip(kernels, axes.ravel(), cols, strict=False):
|
|
96
96
|
prior = gpx.gps.Prior(mean_function=meanf, kernel=k)
|
|
97
97
|
rv = prior(x)
|
|
98
|
-
y = rv.sample(
|
|
98
|
+
y = rv.sample(key=jr.key(22), sample_shape=(10,))
|
|
99
99
|
ax.plot(x, y.T, alpha=0.7, color=c)
|
|
100
100
|
ax.set_title(k.name)
|
|
101
101
|
|
|
@@ -326,7 +326,7 @@ opt_posterior, history = gpx.fit_scipy(
|
|
|
326
326
|
# %%
|
|
327
327
|
posterior_rv = opt_posterior.likelihood(opt_posterior.predict(angles, train_data=D))
|
|
328
328
|
mu = posterior_rv.mean
|
|
329
|
-
one_sigma = posterior_rv.
|
|
329
|
+
one_sigma = jnp.sqrt(posterior_rv.variance)
|
|
330
330
|
|
|
331
331
|
# %%
|
|
332
332
|
fig = plt.figure(figsize=(7, 3.5))
|