gpjax 0.13.0__tar.gz → 0.13.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {gpjax-0.13.0 → gpjax-0.13.1}/.github/workflows/build_docs.yml +2 -2
- {gpjax-0.13.0 → gpjax-0.13.1}/.github/workflows/integration.yml +1 -1
- {gpjax-0.13.0 → gpjax-0.13.1}/.github/workflows/pr_greeting.yml +2 -2
- {gpjax-0.13.0 → gpjax-0.13.1}/.github/workflows/release.yml +7 -7
- {gpjax-0.13.0 → gpjax-0.13.1}/.github/workflows/security-analysis.yml +2 -2
- {gpjax-0.13.0 → gpjax-0.13.1}/.github/workflows/test_docs.yml +1 -1
- {gpjax-0.13.0 → gpjax-0.13.1}/.github/workflows/tests.yml +1 -1
- {gpjax-0.13.0 → gpjax-0.13.1}/PKG-INFO +1 -1
- {gpjax-0.13.0 → gpjax-0.13.1}/gpjax/__init__.py +1 -1
- {gpjax-0.13.0 → gpjax-0.13.1}/gpjax/kernels/computations/eigen.py +1 -15
- {gpjax-0.13.0 → gpjax-0.13.1}/gpjax/kernels/non_euclidean/graph.py +7 -6
- {gpjax-0.13.0 → gpjax-0.13.1}/gpjax/kernels/non_euclidean/utils.py +30 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/gpjax/variational_families.py +69 -5
- {gpjax-0.13.0 → gpjax-0.13.1}/tests/test_variational_families.py +59 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/.github/CODE_OF_CONDUCT.md +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/.github/ISSUE_TEMPLATE/01_BUG_REPORT.md +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/.github/ISSUE_TEMPLATE/02_FEATURE_REQUEST.md +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/.github/ISSUE_TEMPLATE/03_CODEBASE_IMPROVEMENT.md +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/.github/ISSUE_TEMPLATE/04_DOCS_IMPROVEMENT.md +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/.github/ISSUE_TEMPLATE/config.yml +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/.github/codecov.yml +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/.github/commitlint.config.js +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/.github/dependabot.yml +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/.github/labeler.yml +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/.github/labels.yml +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/.github/pull_request_template.md +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/.github/release-drafter.yml +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/.github/workflows/auto-label.yml +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/.github/workflows/commit-lint.yml +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/.github/workflows/ruff.yml +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/.gitignore +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/CITATION.bib +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/LICENSE.txt +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/Makefile +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/README.md +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/docs/CODE_OF_CONDUCT.md +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/docs/GOVERNANCE.md +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/docs/contributing.md +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/docs/design.md +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/docs/index.md +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/docs/index.rst +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/docs/installation.md +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/docs/javascripts/katex.js +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/docs/refs.bib +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/docs/scripts/gen_examples.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/docs/scripts/gen_pages.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/docs/scripts/notebook_converter.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/docs/scripts/sharp_bits_figure.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/docs/sharp_bits.md +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/docs/static/GP.pdf +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/docs/static/GP.svg +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/docs/static/bijector_figure.svg +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/docs/static/css/gpjax_theme.css +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/docs/static/favicon.ico +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/docs/static/gpjax.mplstyle +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/docs/static/gpjax_logo.pdf +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/docs/static/gpjax_logo.svg +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/docs/static/jaxkern/lato.ttf +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/docs/static/jaxkern/logo.png +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/docs/static/jaxkern/logo.svg +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/docs/static/jaxkern/main.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/docs/static/step_size_figure.png +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/docs/static/step_size_figure.svg +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/docs/stylesheets/extra.css +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/docs/stylesheets/permalinks.css +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/examples/backend.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/examples/barycentres/barycentre_gp.gif +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/examples/barycentres.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/examples/classification.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/examples/collapsed_vi.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/examples/constructing_new_kernels.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/examples/data/max_tempeature_switzerland.csv +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/examples/data/yacht_hydrodynamics.data +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/examples/deep_kernels.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/examples/gpjax.mplstyle +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/examples/graph_kernels.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/examples/intro_to_gps/decomposed_mll.png +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/examples/intro_to_gps/generating_process.png +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/examples/intro_to_gps.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/examples/intro_to_kernels.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/examples/likelihoods_guide.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/examples/oceanmodelling.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/examples/poisson.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/examples/regression.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/examples/uncollapsed_vi.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/examples/utils.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/examples/yacht.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/gpjax/citation.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/gpjax/dataset.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/gpjax/distributions.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/gpjax/fit.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/gpjax/gps.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/gpjax/integrators.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/gpjax/kernels/__init__.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/gpjax/kernels/approximations/__init__.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/gpjax/kernels/approximations/rff.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/gpjax/kernels/base.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/gpjax/kernels/computations/__init__.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/gpjax/kernels/computations/base.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/gpjax/kernels/computations/basis_functions.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/gpjax/kernels/computations/constant_diagonal.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/gpjax/kernels/computations/dense.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/gpjax/kernels/computations/diagonal.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/gpjax/kernels/non_euclidean/__init__.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/gpjax/kernels/nonstationary/__init__.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/gpjax/kernels/nonstationary/arccosine.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/gpjax/kernels/nonstationary/linear.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/gpjax/kernels/nonstationary/polynomial.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/gpjax/kernels/stationary/__init__.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/gpjax/kernels/stationary/base.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/gpjax/kernels/stationary/matern12.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/gpjax/kernels/stationary/matern32.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/gpjax/kernels/stationary/matern52.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/gpjax/kernels/stationary/periodic.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/gpjax/kernels/stationary/powered_exponential.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/gpjax/kernels/stationary/rational_quadratic.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/gpjax/kernels/stationary/rbf.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/gpjax/kernels/stationary/utils.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/gpjax/kernels/stationary/white.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/gpjax/likelihoods.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/gpjax/linalg/__init__.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/gpjax/linalg/operations.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/gpjax/linalg/operators.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/gpjax/linalg/utils.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/gpjax/mean_functions.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/gpjax/numpyro_extras.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/gpjax/objectives.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/gpjax/parameters.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/gpjax/scan.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/gpjax/typing.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/mkdocs.yml +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/pyproject.toml +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/static/CONTRIBUTING.md +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/static/paper.bib +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/static/paper.md +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/static/paper.pdf +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/tests/__init__.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/tests/conftest.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/tests/integration_tests.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/tests/test_citations.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/tests/test_dataset.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/tests/test_fit.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/tests/test_gaussian_distribution.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/tests/test_gps.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/tests/test_imports.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/tests/test_integrators.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/tests/test_kernels/__init__.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/tests/test_kernels/test_approximations.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/tests/test_kernels/test_base.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/tests/test_kernels/test_computation.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/tests/test_kernels/test_non_euclidean.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/tests/test_kernels/test_nonstationary.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/tests/test_kernels/test_stationary.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/tests/test_kernels/test_utils.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/tests/test_likelihoods.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/tests/test_linalg.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/tests/test_markdown.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/tests/test_mean_functions.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/tests/test_numpyro_extras.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/tests/test_objectives.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/tests/test_parameters.py +0 -0
- {gpjax-0.13.0 → gpjax-0.13.1}/uv.lock +0 -0
|
@@ -40,7 +40,7 @@ jobs:
|
|
|
40
40
|
|
|
41
41
|
# Install katex for math support
|
|
42
42
|
- name: Install NPM
|
|
43
|
-
uses: actions/setup-node@
|
|
43
|
+
uses: actions/setup-node@v6
|
|
44
44
|
with:
|
|
45
45
|
node-version: 16
|
|
46
46
|
- name: Install KaTeX
|
|
@@ -49,7 +49,7 @@ jobs:
|
|
|
49
49
|
|
|
50
50
|
# Install uv
|
|
51
51
|
- name: Install uv
|
|
52
|
-
uses: astral-sh/setup-uv@
|
|
52
|
+
uses: astral-sh/setup-uv@v7
|
|
53
53
|
with:
|
|
54
54
|
version: "latest"
|
|
55
55
|
|
|
@@ -80,7 +80,7 @@ jobs:
|
|
|
80
80
|
python-version: ${{ matrix.python-version }}
|
|
81
81
|
|
|
82
82
|
- name: Install uv
|
|
83
|
-
uses: astral-sh/setup-uv@
|
|
83
|
+
uses: astral-sh/setup-uv@v7
|
|
84
84
|
|
|
85
85
|
- name: Install dependencies
|
|
86
86
|
run: |
|
|
@@ -116,7 +116,7 @@ jobs:
|
|
|
116
116
|
python-version: '3.11'
|
|
117
117
|
|
|
118
118
|
- name: Install uv
|
|
119
|
-
uses: astral-sh/setup-uv@
|
|
119
|
+
uses: astral-sh/setup-uv@v7
|
|
120
120
|
|
|
121
121
|
- name: Install dependencies
|
|
122
122
|
run: |
|
|
@@ -132,7 +132,7 @@ jobs:
|
|
|
132
132
|
uv run bandit -r gpjax/ -f json -o bandit-report.json || echo "Bandit scan completed with warnings"
|
|
133
133
|
|
|
134
134
|
- name: Upload security reports
|
|
135
|
-
uses: actions/upload-artifact@
|
|
135
|
+
uses: actions/upload-artifact@v5
|
|
136
136
|
with:
|
|
137
137
|
name: security-reports
|
|
138
138
|
path: |
|
|
@@ -154,7 +154,7 @@ jobs:
|
|
|
154
154
|
python-version: '3.11'
|
|
155
155
|
|
|
156
156
|
- name: Install uv
|
|
157
|
-
uses: astral-sh/setup-uv@
|
|
157
|
+
uses: astral-sh/setup-uv@v7
|
|
158
158
|
|
|
159
159
|
- name: Build package
|
|
160
160
|
run: |
|
|
@@ -166,7 +166,7 @@ jobs:
|
|
|
166
166
|
uv run twine check dist/*
|
|
167
167
|
|
|
168
168
|
- name: Upload build artifacts
|
|
169
|
-
uses: actions/upload-artifact@
|
|
169
|
+
uses: actions/upload-artifact@v5
|
|
170
170
|
with:
|
|
171
171
|
name: dist-packages
|
|
172
172
|
path: dist/
|
|
@@ -264,7 +264,7 @@ jobs:
|
|
|
264
264
|
uses: actions/checkout@v5
|
|
265
265
|
|
|
266
266
|
- name: Download build artifacts
|
|
267
|
-
uses: actions/download-artifact@
|
|
267
|
+
uses: actions/download-artifact@v6
|
|
268
268
|
with:
|
|
269
269
|
name: dist-packages
|
|
270
270
|
path: dist/
|
|
@@ -294,7 +294,7 @@ jobs:
|
|
|
294
294
|
|
|
295
295
|
steps:
|
|
296
296
|
- name: Download build artifacts
|
|
297
|
-
uses: actions/download-artifact@
|
|
297
|
+
uses: actions/download-artifact@v6
|
|
298
298
|
with:
|
|
299
299
|
name: dist-packages
|
|
300
300
|
path: dist/
|
|
@@ -28,7 +28,7 @@ jobs:
|
|
|
28
28
|
python-version: '3.11'
|
|
29
29
|
|
|
30
30
|
- name: Install uv
|
|
31
|
-
uses: astral-sh/setup-uv@
|
|
31
|
+
uses: astral-sh/setup-uv@v7
|
|
32
32
|
with:
|
|
33
33
|
version: "latest"
|
|
34
34
|
|
|
@@ -47,7 +47,7 @@ jobs:
|
|
|
47
47
|
uv run bandit -r gpjax/ -f json -o bandit-report.json || true
|
|
48
48
|
|
|
49
49
|
- name: Upload dependency scan results
|
|
50
|
-
uses: actions/upload-artifact@
|
|
50
|
+
uses: actions/upload-artifact@v5
|
|
51
51
|
if: always()
|
|
52
52
|
with:
|
|
53
53
|
name: security-scan-results
|
|
@@ -40,7 +40,7 @@ __license__ = "MIT"
|
|
|
40
40
|
__description__ = "Gaussian processes in JAX and Flax"
|
|
41
41
|
__url__ = "https://github.com/JaxGaussianProcesses/GPJax"
|
|
42
42
|
__contributors__ = "https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors"
|
|
43
|
-
__version__ = "0.13.0"
|
|
43
|
+
__version__ = "0.13.1"
|
|
44
44
|
|
|
45
45
|
__all__ = [
|
|
46
46
|
"gps",
|
|
@@ -15,7 +15,6 @@
|
|
|
15
15
|
|
|
16
16
|
|
|
17
17
|
import beartype.typing as tp
|
|
18
|
-
import jax.numpy as jnp
|
|
19
18
|
from jaxtyping import (
|
|
20
19
|
Float,
|
|
21
20
|
Num,
|
|
@@ -39,17 +38,4 @@ class EigenKernelComputation(AbstractKernelComputation):
|
|
|
39
38
|
def _cross_covariance(
|
|
40
39
|
self, kernel: Kernel, x: Num[Array, "N D"], y: Num[Array, "M D"]
|
|
41
40
|
) -> Float[Array, "N M"]:
|
|
42
|
-
|
|
43
|
-
# RBF kernel's SPDE form.
|
|
44
|
-
S = jnp.power(
|
|
45
|
-
kernel.eigenvalues
|
|
46
|
-
+ 2
|
|
47
|
-
* kernel.smoothness.value
|
|
48
|
-
/ kernel.lengthscale.value
|
|
49
|
-
/ kernel.lengthscale.value,
|
|
50
|
-
-kernel.smoothness.value,
|
|
51
|
-
)
|
|
52
|
-
S = jnp.multiply(S, kernel.num_vertex / jnp.sum(S))
|
|
53
|
-
# Scale the transform eigenvalues by the kernel variance
|
|
54
|
-
S = jnp.multiply(S, kernel.variance.value)
|
|
55
|
-
return kernel(x, y, S=S)
|
|
41
|
+
return kernel(x, y)
|
|
@@ -25,7 +25,10 @@ from gpjax.kernels.computations import (
|
|
|
25
25
|
AbstractKernelComputation,
|
|
26
26
|
EigenKernelComputation,
|
|
27
27
|
)
|
|
28
|
-
from gpjax.kernels.non_euclidean.utils import jax_gather_nd
|
|
28
|
+
from gpjax.kernels.non_euclidean.utils import (
|
|
29
|
+
calculate_heat_semigroup,
|
|
30
|
+
jax_gather_nd,
|
|
31
|
+
)
|
|
29
32
|
from gpjax.kernels.stationary.base import StationaryKernel
|
|
30
33
|
from gpjax.parameters import (
|
|
31
34
|
Parameter,
|
|
@@ -98,14 +101,12 @@ class GraphKernel(StationaryKernel):
|
|
|
98
101
|
|
|
99
102
|
super().__init__(active_dims, lengthscale, variance, n_dims, compute_engine)
|
|
100
103
|
|
|
101
|
-
def __call__(
|
|
104
|
+
def __call__(
|
|
102
105
|
self,
|
|
103
106
|
x: Int[Array, "N 1"],
|
|
104
|
-
y: Int[Array, "
|
|
105
|
-
*,
|
|
106
|
-
S,
|
|
107
|
-
**kwargs,
|
|
107
|
+
y: Int[Array, "M 1"],
|
|
108
108
|
):
|
|
109
|
+
S = calculate_heat_semigroup(self)
|
|
109
110
|
Kxx = (jax_gather_nd(self.eigenvectors, x) * S.squeeze()) @ jnp.transpose(
|
|
110
111
|
jax_gather_nd(self.eigenvectors, y)
|
|
111
112
|
) # shape (n,n)
|
|
@@ -13,6 +13,10 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
|
|
16
|
+
from __future__ import annotations
|
|
17
|
+
|
|
18
|
+
import beartype.typing as tp
|
|
19
|
+
import jax.numpy as jnp
|
|
16
20
|
from jaxtyping import (
|
|
17
21
|
Float,
|
|
18
22
|
Int,
|
|
@@ -20,6 +24,9 @@ from jaxtyping import (
|
|
|
20
24
|
|
|
21
25
|
from gpjax.typing import Array
|
|
22
26
|
|
|
27
|
+
if tp.TYPE_CHECKING:
|
|
28
|
+
from gpjax.kernels.non_euclidean.graph import GraphKernel
|
|
29
|
+
|
|
23
30
|
|
|
24
31
|
def jax_gather_nd(
|
|
25
32
|
params: Float[Array, " N *rest"], indices: Int[Array, " M 1"]
|
|
@@ -41,3 +48,26 @@ def jax_gather_nd(
|
|
|
41
48
|
"""
|
|
42
49
|
tuple_indices = tuple(indices[..., i] for i in range(indices.shape[-1]))
|
|
43
50
|
return params[tuple_indices]
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def calculate_heat_semigroup(kernel: GraphKernel) -> Float[Array, "N M"]:
|
|
54
|
+
r"""Returns the rescaled heat semigroup, S
|
|
55
|
+
|
|
56
|
+
Args:
|
|
57
|
+
kernel: instance of the graph kernel
|
|
58
|
+
|
|
59
|
+
Returns:
|
|
60
|
+
S
|
|
61
|
+
"""
|
|
62
|
+
S = jnp.power(
|
|
63
|
+
kernel.eigenvalues
|
|
64
|
+
+ 2
|
|
65
|
+
* kernel.smoothness.value
|
|
66
|
+
/ kernel.lengthscale.value
|
|
67
|
+
/ kernel.lengthscale.value,
|
|
68
|
+
-kernel.smoothness.value,
|
|
69
|
+
)
|
|
70
|
+
S = jnp.multiply(S, kernel.num_vertex / jnp.sum(S))
|
|
71
|
+
# Scale the transform eigenvalues by the kernel variance
|
|
72
|
+
S = jnp.multiply(S, kernel.variance.value)
|
|
73
|
+
return S
|
|
@@ -19,7 +19,10 @@ import beartype.typing as tp
|
|
|
19
19
|
from flax import nnx
|
|
20
20
|
import jax.numpy as jnp
|
|
21
21
|
import jax.scipy as jsp
|
|
22
|
-
from jaxtyping import
|
|
22
|
+
from jaxtyping import (
|
|
23
|
+
Float,
|
|
24
|
+
Int,
|
|
25
|
+
)
|
|
23
26
|
|
|
24
27
|
from gpjax.dataset import Dataset
|
|
25
28
|
from gpjax.distributions import GaussianDistribution
|
|
@@ -108,6 +111,7 @@ class AbstractVariationalGaussian(AbstractVariationalFamily[L]):
|
|
|
108
111
|
self,
|
|
109
112
|
posterior: AbstractPosterior[P, L],
|
|
110
113
|
inducing_inputs: tp.Union[
|
|
114
|
+
Int[Array, "N D"],
|
|
111
115
|
Float[Array, "N D"],
|
|
112
116
|
Real,
|
|
113
117
|
],
|
|
@@ -140,7 +144,7 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):
|
|
|
140
144
|
def __init__(
|
|
141
145
|
self,
|
|
142
146
|
posterior: AbstractPosterior[P, L],
|
|
143
|
-
inducing_inputs: Float[Array, "N D"],
|
|
147
|
+
inducing_inputs: tp.Union[Int[Array, "N D"], Float[Array, "N D"]],
|
|
144
148
|
variational_mean: tp.Union[Float[Array, "N 1"], None] = None,
|
|
145
149
|
variational_root_covariance: tp.Union[Float[Array, "N N"], None] = None,
|
|
146
150
|
jitter: ScalarFloat = 1e-6,
|
|
@@ -156,6 +160,12 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):
|
|
|
156
160
|
self.variational_mean = Real(variational_mean)
|
|
157
161
|
self.variational_root_covariance = LowerTriangular(variational_root_covariance)
|
|
158
162
|
|
|
163
|
+
def _fmt_Kzt_Ktt(self, Kzt, Ktt):
|
|
164
|
+
return Kzt, Ktt
|
|
165
|
+
|
|
166
|
+
def _fmt_inducing_inputs(self):
|
|
167
|
+
return self.inducing_inputs.value
|
|
168
|
+
|
|
159
169
|
def prior_kl(self) -> ScalarFloat:
|
|
160
170
|
r"""Compute the prior KL divergence.
|
|
161
171
|
|
|
@@ -178,7 +188,7 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):
|
|
|
178
188
|
# Unpack variational parameters
|
|
179
189
|
variational_mean = self.variational_mean.value
|
|
180
190
|
variational_sqrt = self.variational_root_covariance.value
|
|
181
|
-
inducing_inputs = self.inducing_inputs.value
|
|
191
|
+
inducing_inputs = self._fmt_inducing_inputs()
|
|
182
192
|
|
|
183
193
|
# Unpack mean function and kernel
|
|
184
194
|
mean_function = self.posterior.prior.mean_function
|
|
@@ -202,7 +212,9 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):
|
|
|
202
212
|
|
|
203
213
|
return q_inducing.kl_divergence(p_inducing)
|
|
204
214
|
|
|
205
|
-
def predict(
|
|
215
|
+
def predict(
|
|
216
|
+
self, test_inputs: tp.Union[Int[Array, "N D"], Float[Array, "N D"]]
|
|
217
|
+
) -> GaussianDistribution:
|
|
206
218
|
r"""Compute the predictive distribution of the GP at the test inputs t.
|
|
207
219
|
|
|
208
220
|
This is the integral $q(f(t)) = \int p(f(t)\mid u) q(u) \mathrm{d}u$, which
|
|
@@ -222,7 +234,7 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):
|
|
|
222
234
|
# Unpack variational parameters
|
|
223
235
|
variational_mean = self.variational_mean.value
|
|
224
236
|
variational_sqrt = self.variational_root_covariance.value
|
|
225
|
-
inducing_inputs = self.inducing_inputs.value
|
|
237
|
+
inducing_inputs = self._fmt_inducing_inputs()
|
|
226
238
|
|
|
227
239
|
# Unpack mean function and kernel
|
|
228
240
|
mean_function = self.posterior.prior.mean_function
|
|
@@ -241,6 +253,8 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):
|
|
|
241
253
|
Kzt = kernel.cross_covariance(inducing_inputs, test_points)
|
|
242
254
|
test_mean = mean_function(test_points)
|
|
243
255
|
|
|
256
|
+
Kzt, Ktt = self._fmt_Kzt_Ktt(Kzt, Ktt)
|
|
257
|
+
|
|
244
258
|
# Lz⁻¹ Kzt
|
|
245
259
|
Lz_inv_Kzt = solve(Lz, Kzt)
|
|
246
260
|
|
|
@@ -259,8 +273,10 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):
|
|
|
259
273
|
- jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt)
|
|
260
274
|
+ jnp.matmul(Ktz_Kzz_inv_sqrt, Ktz_Kzz_inv_sqrt.T)
|
|
261
275
|
)
|
|
276
|
+
|
|
262
277
|
if hasattr(covariance, "to_dense"):
|
|
263
278
|
covariance = covariance.to_dense()
|
|
279
|
+
|
|
264
280
|
covariance = add_jitter(covariance, self.jitter)
|
|
265
281
|
covariance = Dense(covariance)
|
|
266
282
|
|
|
@@ -269,6 +285,53 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):
|
|
|
269
285
|
)
|
|
270
286
|
|
|
271
287
|
|
|
288
|
+
class GraphVariationalGaussian(VariationalGaussian[L]):
|
|
289
|
+
r"""A variational Gaussian defined over graph-structured inducing inputs.
|
|
290
|
+
|
|
291
|
+
This subclass adapts the :class:`VariationalGaussian` family to the
|
|
292
|
+
case where the inducing inputs are discrete graph node indices rather
|
|
293
|
+
than continuous spatial coordinates.
|
|
294
|
+
|
|
295
|
+
The main differences are:
|
|
296
|
+
* Inducing inputs are integer node IDs.
|
|
297
|
+
* Kernel matrices are ensured to be dense and 2D.
|
|
298
|
+
"""
|
|
299
|
+
|
|
300
|
+
def __init__(
|
|
301
|
+
self,
|
|
302
|
+
posterior: AbstractPosterior[P, L],
|
|
303
|
+
inducing_inputs: Int[Array, "N D"],
|
|
304
|
+
variational_mean: tp.Union[Float[Array, "N 1"], None] = None,
|
|
305
|
+
variational_root_covariance: tp.Union[Float[Array, "N N"], None] = None,
|
|
306
|
+
jitter: ScalarFloat = 1e-6,
|
|
307
|
+
):
|
|
308
|
+
super().__init__(
|
|
309
|
+
posterior,
|
|
310
|
+
inducing_inputs,
|
|
311
|
+
variational_mean,
|
|
312
|
+
variational_root_covariance,
|
|
313
|
+
jitter,
|
|
314
|
+
)
|
|
315
|
+
self.inducing_inputs = self.inducing_inputs.value.astype(jnp.int64)
|
|
316
|
+
|
|
317
|
+
def _fmt_Kzt_Ktt(self, Kzt, Ktt):
|
|
318
|
+
Ktt = Ktt.to_dense() if hasattr(Ktt, "to_dense") else Ktt
|
|
319
|
+
Kzt = Kzt.to_dense() if hasattr(Kzt, "to_dense") else Kzt
|
|
320
|
+
Ktt = jnp.atleast_2d(Ktt)
|
|
321
|
+
Kzt = (
|
|
322
|
+
jnp.transpose(jnp.atleast_2d(Kzt)) if Kzt.ndim < 2 else jnp.atleast_2d(Kzt)
|
|
323
|
+
)
|
|
324
|
+
return Kzt, Ktt
|
|
325
|
+
|
|
326
|
+
def _fmt_inducing_inputs(self):
|
|
327
|
+
return self.inducing_inputs
|
|
328
|
+
|
|
329
|
+
@property
|
|
330
|
+
def num_inducing(self) -> int:
|
|
331
|
+
"""The number of inducing inputs."""
|
|
332
|
+
return self.inducing_inputs.shape[0]
|
|
333
|
+
|
|
334
|
+
|
|
272
335
|
class WhitenedVariationalGaussian(VariationalGaussian[L]):
|
|
273
336
|
r"""The whitened variational Gaussian family of probability distributions.
|
|
274
337
|
|
|
@@ -811,6 +874,7 @@ __all__ = [
|
|
|
811
874
|
"AbstractVariationalFamily",
|
|
812
875
|
"AbstractVariationalGaussian",
|
|
813
876
|
"VariationalGaussian",
|
|
877
|
+
"GraphVariationalGaussian",
|
|
814
878
|
"WhitenedVariationalGaussian",
|
|
815
879
|
"NaturalVariationalGaussian",
|
|
816
880
|
"ExpectationVariationalGaussian",
|
|
@@ -25,6 +25,8 @@ from jaxtyping import (
|
|
|
25
25
|
Array,
|
|
26
26
|
Float,
|
|
27
27
|
)
|
|
28
|
+
import networkx as nx
|
|
29
|
+
import numpy as np
|
|
28
30
|
import numpyro.distributions as npd
|
|
29
31
|
from numpyro.distributions import Distribution as NumpyroDistribution
|
|
30
32
|
import pytest
|
|
@@ -35,6 +37,7 @@ from gpjax.variational_families import (
|
|
|
35
37
|
AbstractVariationalFamily,
|
|
36
38
|
CollapsedVariationalGaussian,
|
|
37
39
|
ExpectationVariationalGaussian,
|
|
40
|
+
GraphVariationalGaussian,
|
|
38
41
|
NaturalVariationalGaussian,
|
|
39
42
|
VariationalGaussian,
|
|
40
43
|
WhitenedVariationalGaussian,
|
|
@@ -118,6 +121,7 @@ def test_variational_gaussians(
|
|
|
118
121
|
)
|
|
119
122
|
likelihood = gpx.likelihoods.Gaussian(123)
|
|
120
123
|
inducing_inputs = jnp.linspace(-5.0, 5.0, n_inducing).reshape(-1, 1)
|
|
124
|
+
|
|
121
125
|
test_inputs = jnp.linspace(-5.0, 5.0, n_test).reshape(-1, 1)
|
|
122
126
|
|
|
123
127
|
posterior = prior * likelihood
|
|
@@ -174,6 +178,61 @@ def test_variational_gaussians(
|
|
|
174
178
|
assert sigma.shape == (n_test, n_test)
|
|
175
179
|
|
|
176
180
|
|
|
181
|
+
@pytest.mark.parametrize("n_test", [10, 20])
|
|
182
|
+
@pytest.mark.parametrize("n_inducing", [10, 20])
|
|
183
|
+
@pytest.mark.parametrize(
|
|
184
|
+
"variational_family",
|
|
185
|
+
[
|
|
186
|
+
GraphVariationalGaussian,
|
|
187
|
+
],
|
|
188
|
+
)
|
|
189
|
+
def test_graph_variational_gaussian(
|
|
190
|
+
n_test: int,
|
|
191
|
+
n_inducing: int,
|
|
192
|
+
variational_family: AbstractVariationalFamily,
|
|
193
|
+
) -> None:
|
|
194
|
+
G = nx.barbell_graph(100, 0)
|
|
195
|
+
L = nx.laplacian_matrix(G).toarray()
|
|
196
|
+
|
|
197
|
+
kernel = gpx.kernels.GraphKernel(
|
|
198
|
+
laplacian=L,
|
|
199
|
+
lengthscale=2.3,
|
|
200
|
+
variance=3.2,
|
|
201
|
+
smoothness=6.1,
|
|
202
|
+
)
|
|
203
|
+
meanf = gpx.mean_functions.Constant()
|
|
204
|
+
prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)
|
|
205
|
+
likelihood = gpx.likelihoods.Bernoulli(num_datapoints=G.number_of_nodes())
|
|
206
|
+
|
|
207
|
+
inducing_inputs = jnp.array(
|
|
208
|
+
np.random.randint(low=1, high=100, size=(n_inducing, 1))
|
|
209
|
+
).astype(jnp.int64)
|
|
210
|
+
|
|
211
|
+
test_inputs = jnp.array(np.random.randint(low=0, high=1, size=(n_test, 1))).astype(
|
|
212
|
+
jnp.int64
|
|
213
|
+
)
|
|
214
|
+
|
|
215
|
+
posterior = prior * likelihood
|
|
216
|
+
q = variational_family(posterior=posterior, inducing_inputs=inducing_inputs)
|
|
217
|
+
# Test KL
|
|
218
|
+
kl = q.prior_kl()
|
|
219
|
+
assert isinstance(kl, jnp.ndarray)
|
|
220
|
+
assert kl.shape == ()
|
|
221
|
+
assert kl >= 0.0
|
|
222
|
+
|
|
223
|
+
# Test predictions
|
|
224
|
+
predictive_dist = q(test_inputs)
|
|
225
|
+
assert isinstance(predictive_dist, NumpyroDistribution)
|
|
226
|
+
|
|
227
|
+
mu = predictive_dist.mean
|
|
228
|
+
sigma = predictive_dist.covariance()
|
|
229
|
+
|
|
230
|
+
assert isinstance(mu, jnp.ndarray)
|
|
231
|
+
assert isinstance(sigma, jnp.ndarray)
|
|
232
|
+
assert mu.shape == (n_test,)
|
|
233
|
+
assert sigma.shape == (n_test, n_test)
|
|
234
|
+
|
|
235
|
+
|
|
177
236
|
@pytest.mark.parametrize("n_test", [1, 10])
|
|
178
237
|
@pytest.mark.parametrize("n_datapoints", [1, 10])
|
|
179
238
|
@pytest.mark.parametrize("n_inducing", [1, 10, 20])
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|