gpjax 0.9.5__tar.gz → 0.10.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (153)
  1. gpjax-0.10.1/.cursorrules +37 -0
  2. {gpjax-0.9.5 → gpjax-0.10.1}/PKG-INFO +6 -6
  3. gpjax-0.10.1/examples/oak_example.py +214 -0
  4. {gpjax-0.9.5 → gpjax-0.10.1}/gpjax/__init__.py +1 -1
  5. {gpjax-0.9.5 → gpjax-0.10.1}/gpjax/kernels/base.py +4 -1
  6. gpjax-0.10.1/gpjax/kernels/nonstationary/oak.py +406 -0
  7. {gpjax-0.9.5 → gpjax-0.10.1}/gpjax/mean_functions.py +4 -3
  8. {gpjax-0.9.5 → gpjax-0.10.1}/pyproject.toml +5 -5
  9. gpjax-0.10.1/tests/kernels/nonstationary/test_oak.py +208 -0
  10. {gpjax-0.9.5 → gpjax-0.10.1}/tests/test_mean_functions.py +25 -32
  11. {gpjax-0.9.5 → gpjax-0.10.1}/.github/CODE_OF_CONDUCT.md +0 -0
  12. {gpjax-0.9.5 → gpjax-0.10.1}/.github/ISSUE_TEMPLATE/01_BUG_REPORT.md +0 -0
  13. {gpjax-0.9.5 → gpjax-0.10.1}/.github/ISSUE_TEMPLATE/02_FEATURE_REQUEST.md +0 -0
  14. {gpjax-0.9.5 → gpjax-0.10.1}/.github/ISSUE_TEMPLATE/03_CODEBASE_IMPROVEMENT.md +0 -0
  15. {gpjax-0.9.5 → gpjax-0.10.1}/.github/ISSUE_TEMPLATE/04_DOCS_IMPROVEMENT.md +0 -0
  16. {gpjax-0.9.5 → gpjax-0.10.1}/.github/ISSUE_TEMPLATE/config.yml +0 -0
  17. {gpjax-0.9.5 → gpjax-0.10.1}/.github/codecov.yml +0 -0
  18. {gpjax-0.9.5 → gpjax-0.10.1}/.github/labels.yml +0 -0
  19. {gpjax-0.9.5 → gpjax-0.10.1}/.github/pull_request_template.md +0 -0
  20. {gpjax-0.9.5 → gpjax-0.10.1}/.github/release-drafter.yml +0 -0
  21. {gpjax-0.9.5 → gpjax-0.10.1}/.github/workflows/build_docs.yml +0 -0
  22. {gpjax-0.9.5 → gpjax-0.10.1}/.github/workflows/integration.yml +0 -0
  23. {gpjax-0.9.5 → gpjax-0.10.1}/.github/workflows/pr_greeting.yml +0 -0
  24. {gpjax-0.9.5 → gpjax-0.10.1}/.github/workflows/ruff.yml +0 -0
  25. {gpjax-0.9.5 → gpjax-0.10.1}/.github/workflows/stale_prs.yml +0 -0
  26. {gpjax-0.9.5 → gpjax-0.10.1}/.github/workflows/test_docs.yml +0 -0
  27. {gpjax-0.9.5 → gpjax-0.10.1}/.github/workflows/tests.yml +0 -0
  28. {gpjax-0.9.5 → gpjax-0.10.1}/.gitignore +0 -0
  29. {gpjax-0.9.5 → gpjax-0.10.1}/CITATION.bib +0 -0
  30. {gpjax-0.9.5 → gpjax-0.10.1}/LICENSE.txt +0 -0
  31. {gpjax-0.9.5 → gpjax-0.10.1}/Makefile +0 -0
  32. {gpjax-0.9.5 → gpjax-0.10.1}/README.md +0 -0
  33. {gpjax-0.9.5 → gpjax-0.10.1}/docs/CODE_OF_CONDUCT.md +0 -0
  34. {gpjax-0.9.5 → gpjax-0.10.1}/docs/GOVERNANCE.md +0 -0
  35. {gpjax-0.9.5 → gpjax-0.10.1}/docs/contributing.md +0 -0
  36. {gpjax-0.9.5 → gpjax-0.10.1}/docs/design.md +0 -0
  37. {gpjax-0.9.5 → gpjax-0.10.1}/docs/index.md +0 -0
  38. {gpjax-0.9.5 → gpjax-0.10.1}/docs/index.rst +0 -0
  39. {gpjax-0.9.5 → gpjax-0.10.1}/docs/installation.md +0 -0
  40. {gpjax-0.9.5 → gpjax-0.10.1}/docs/javascripts/katex.js +0 -0
  41. {gpjax-0.9.5 → gpjax-0.10.1}/docs/refs.bib +0 -0
  42. {gpjax-0.9.5 → gpjax-0.10.1}/docs/scripts/gen_examples.py +0 -0
  43. {gpjax-0.9.5 → gpjax-0.10.1}/docs/scripts/gen_pages.py +0 -0
  44. {gpjax-0.9.5 → gpjax-0.10.1}/docs/scripts/notebook_converter.py +0 -0
  45. {gpjax-0.9.5 → gpjax-0.10.1}/docs/scripts/sharp_bits_figure.py +0 -0
  46. {gpjax-0.9.5 → gpjax-0.10.1}/docs/sharp_bits.md +0 -0
  47. {gpjax-0.9.5 → gpjax-0.10.1}/docs/static/GP.pdf +0 -0
  48. {gpjax-0.9.5 → gpjax-0.10.1}/docs/static/GP.svg +0 -0
  49. {gpjax-0.9.5 → gpjax-0.10.1}/docs/static/bijector_figure.svg +0 -0
  50. {gpjax-0.9.5 → gpjax-0.10.1}/docs/static/css/gpjax_theme.css +0 -0
  51. {gpjax-0.9.5 → gpjax-0.10.1}/docs/static/favicon.ico +0 -0
  52. {gpjax-0.9.5 → gpjax-0.10.1}/docs/static/gpjax.mplstyle +0 -0
  53. {gpjax-0.9.5 → gpjax-0.10.1}/docs/static/gpjax_logo.pdf +0 -0
  54. {gpjax-0.9.5 → gpjax-0.10.1}/docs/static/gpjax_logo.svg +0 -0
  55. {gpjax-0.9.5 → gpjax-0.10.1}/docs/static/jaxkern/lato.ttf +0 -0
  56. {gpjax-0.9.5 → gpjax-0.10.1}/docs/static/jaxkern/logo.png +0 -0
  57. {gpjax-0.9.5 → gpjax-0.10.1}/docs/static/jaxkern/logo.svg +0 -0
  58. {gpjax-0.9.5 → gpjax-0.10.1}/docs/static/jaxkern/main.py +0 -0
  59. {gpjax-0.9.5 → gpjax-0.10.1}/docs/static/step_size_figure.png +0 -0
  60. {gpjax-0.9.5 → gpjax-0.10.1}/docs/static/step_size_figure.svg +0 -0
  61. {gpjax-0.9.5 → gpjax-0.10.1}/docs/stylesheets/extra.css +0 -0
  62. {gpjax-0.9.5 → gpjax-0.10.1}/docs/stylesheets/permalinks.css +0 -0
  63. {gpjax-0.9.5 → gpjax-0.10.1}/examples/backend.py +0 -0
  64. {gpjax-0.9.5 → gpjax-0.10.1}/examples/barycentres/barycentre_gp.gif +0 -0
  65. {gpjax-0.9.5 → gpjax-0.10.1}/examples/barycentres.py +0 -0
  66. {gpjax-0.9.5 → gpjax-0.10.1}/examples/classification.py +0 -0
  67. {gpjax-0.9.5 → gpjax-0.10.1}/examples/collapsed_vi.py +0 -0
  68. {gpjax-0.9.5 → gpjax-0.10.1}/examples/constructing_new_kernels.py +0 -0
  69. {gpjax-0.9.5 → gpjax-0.10.1}/examples/data/max_tempeature_switzerland.csv +0 -0
  70. {gpjax-0.9.5 → gpjax-0.10.1}/examples/data/yacht_hydrodynamics.data +0 -0
  71. {gpjax-0.9.5 → gpjax-0.10.1}/examples/deep_kernels.py +0 -0
  72. {gpjax-0.9.5 → gpjax-0.10.1}/examples/gpjax.mplstyle +0 -0
  73. {gpjax-0.9.5 → gpjax-0.10.1}/examples/graph_kernels.py +0 -0
  74. {gpjax-0.9.5 → gpjax-0.10.1}/examples/intro_to_gps/decomposed_mll.png +0 -0
  75. {gpjax-0.9.5 → gpjax-0.10.1}/examples/intro_to_gps/generating_process.png +0 -0
  76. {gpjax-0.9.5 → gpjax-0.10.1}/examples/intro_to_gps.py +0 -0
  77. {gpjax-0.9.5 → gpjax-0.10.1}/examples/intro_to_kernels.py +0 -0
  78. {gpjax-0.9.5 → gpjax-0.10.1}/examples/likelihoods_guide.py +0 -0
  79. {gpjax-0.9.5 → gpjax-0.10.1}/examples/oceanmodelling.py +0 -0
  80. {gpjax-0.9.5 → gpjax-0.10.1}/examples/poisson.py +0 -0
  81. {gpjax-0.9.5 → gpjax-0.10.1}/examples/regression.py +0 -0
  82. {gpjax-0.9.5 → gpjax-0.10.1}/examples/uncollapsed_vi.py +0 -0
  83. {gpjax-0.9.5 → gpjax-0.10.1}/examples/utils.py +0 -0
  84. {gpjax-0.9.5 → gpjax-0.10.1}/examples/yacht.py +0 -0
  85. {gpjax-0.9.5 → gpjax-0.10.1}/gpjax/citation.py +0 -0
  86. {gpjax-0.9.5 → gpjax-0.10.1}/gpjax/dataset.py +0 -0
  87. {gpjax-0.9.5 → gpjax-0.10.1}/gpjax/distributions.py +0 -0
  88. {gpjax-0.9.5 → gpjax-0.10.1}/gpjax/fit.py +0 -0
  89. {gpjax-0.9.5 → gpjax-0.10.1}/gpjax/gps.py +0 -0
  90. {gpjax-0.9.5 → gpjax-0.10.1}/gpjax/integrators.py +0 -0
  91. {gpjax-0.9.5 → gpjax-0.10.1}/gpjax/kernels/__init__.py +0 -0
  92. {gpjax-0.9.5 → gpjax-0.10.1}/gpjax/kernels/approximations/__init__.py +0 -0
  93. {gpjax-0.9.5 → gpjax-0.10.1}/gpjax/kernels/approximations/rff.py +0 -0
  94. {gpjax-0.9.5 → gpjax-0.10.1}/gpjax/kernels/computations/__init__.py +0 -0
  95. {gpjax-0.9.5 → gpjax-0.10.1}/gpjax/kernels/computations/base.py +0 -0
  96. {gpjax-0.9.5 → gpjax-0.10.1}/gpjax/kernels/computations/basis_functions.py +0 -0
  97. {gpjax-0.9.5 → gpjax-0.10.1}/gpjax/kernels/computations/constant_diagonal.py +0 -0
  98. {gpjax-0.9.5 → gpjax-0.10.1}/gpjax/kernels/computations/dense.py +0 -0
  99. {gpjax-0.9.5 → gpjax-0.10.1}/gpjax/kernels/computations/diagonal.py +0 -0
  100. {gpjax-0.9.5 → gpjax-0.10.1}/gpjax/kernels/computations/eigen.py +0 -0
  101. {gpjax-0.9.5 → gpjax-0.10.1}/gpjax/kernels/non_euclidean/__init__.py +0 -0
  102. {gpjax-0.9.5 → gpjax-0.10.1}/gpjax/kernels/non_euclidean/graph.py +0 -0
  103. {gpjax-0.9.5 → gpjax-0.10.1}/gpjax/kernels/non_euclidean/utils.py +0 -0
  104. {gpjax-0.9.5 → gpjax-0.10.1}/gpjax/kernels/nonstationary/__init__.py +0 -0
  105. {gpjax-0.9.5 → gpjax-0.10.1}/gpjax/kernels/nonstationary/arccosine.py +0 -0
  106. {gpjax-0.9.5 → gpjax-0.10.1}/gpjax/kernels/nonstationary/linear.py +0 -0
  107. {gpjax-0.9.5 → gpjax-0.10.1}/gpjax/kernels/nonstationary/polynomial.py +0 -0
  108. {gpjax-0.9.5 → gpjax-0.10.1}/gpjax/kernels/stationary/__init__.py +0 -0
  109. {gpjax-0.9.5 → gpjax-0.10.1}/gpjax/kernels/stationary/base.py +0 -0
  110. {gpjax-0.9.5 → gpjax-0.10.1}/gpjax/kernels/stationary/matern12.py +0 -0
  111. {gpjax-0.9.5 → gpjax-0.10.1}/gpjax/kernels/stationary/matern32.py +0 -0
  112. {gpjax-0.9.5 → gpjax-0.10.1}/gpjax/kernels/stationary/matern52.py +0 -0
  113. {gpjax-0.9.5 → gpjax-0.10.1}/gpjax/kernels/stationary/periodic.py +0 -0
  114. {gpjax-0.9.5 → gpjax-0.10.1}/gpjax/kernels/stationary/powered_exponential.py +0 -0
  115. {gpjax-0.9.5 → gpjax-0.10.1}/gpjax/kernels/stationary/rational_quadratic.py +0 -0
  116. {gpjax-0.9.5 → gpjax-0.10.1}/gpjax/kernels/stationary/rbf.py +0 -0
  117. {gpjax-0.9.5 → gpjax-0.10.1}/gpjax/kernels/stationary/utils.py +0 -0
  118. {gpjax-0.9.5 → gpjax-0.10.1}/gpjax/kernels/stationary/white.py +0 -0
  119. {gpjax-0.9.5 → gpjax-0.10.1}/gpjax/likelihoods.py +0 -0
  120. {gpjax-0.9.5 → gpjax-0.10.1}/gpjax/lower_cholesky.py +0 -0
  121. {gpjax-0.9.5 → gpjax-0.10.1}/gpjax/objectives.py +0 -0
  122. {gpjax-0.9.5 → gpjax-0.10.1}/gpjax/parameters.py +0 -0
  123. {gpjax-0.9.5 → gpjax-0.10.1}/gpjax/scan.py +0 -0
  124. {gpjax-0.9.5 → gpjax-0.10.1}/gpjax/typing.py +0 -0
  125. {gpjax-0.9.5 → gpjax-0.10.1}/gpjax/variational_families.py +0 -0
  126. {gpjax-0.9.5 → gpjax-0.10.1}/mkdocs.yml +0 -0
  127. {gpjax-0.9.5 → gpjax-0.10.1}/static/CONTRIBUTING.md +0 -0
  128. {gpjax-0.9.5 → gpjax-0.10.1}/static/paper.bib +0 -0
  129. {gpjax-0.9.5 → gpjax-0.10.1}/static/paper.md +0 -0
  130. {gpjax-0.9.5 → gpjax-0.10.1}/static/paper.pdf +0 -0
  131. {gpjax-0.9.5 → gpjax-0.10.1}/tests/__init__.py +0 -0
  132. {gpjax-0.9.5 → gpjax-0.10.1}/tests/conftest.py +0 -0
  133. {gpjax-0.9.5 → gpjax-0.10.1}/tests/integration_tests.py +0 -0
  134. {gpjax-0.9.5 → gpjax-0.10.1}/tests/test_citations.py +0 -0
  135. {gpjax-0.9.5 → gpjax-0.10.1}/tests/test_dataset.py +0 -0
  136. {gpjax-0.9.5 → gpjax-0.10.1}/tests/test_fit.py +0 -0
  137. {gpjax-0.9.5 → gpjax-0.10.1}/tests/test_gaussian_distribution.py +0 -0
  138. {gpjax-0.9.5 → gpjax-0.10.1}/tests/test_gps.py +0 -0
  139. {gpjax-0.9.5 → gpjax-0.10.1}/tests/test_integrators.py +0 -0
  140. {gpjax-0.9.5 → gpjax-0.10.1}/tests/test_kernels/__init__.py +0 -0
  141. {gpjax-0.9.5 → gpjax-0.10.1}/tests/test_kernels/test_approximations.py +0 -0
  142. {gpjax-0.9.5 → gpjax-0.10.1}/tests/test_kernels/test_base.py +0 -0
  143. {gpjax-0.9.5 → gpjax-0.10.1}/tests/test_kernels/test_computation.py +0 -0
  144. {gpjax-0.9.5 → gpjax-0.10.1}/tests/test_kernels/test_non_euclidean.py +0 -0
  145. {gpjax-0.9.5 → gpjax-0.10.1}/tests/test_kernels/test_nonstationary.py +0 -0
  146. {gpjax-0.9.5 → gpjax-0.10.1}/tests/test_kernels/test_stationary.py +0 -0
  147. {gpjax-0.9.5 → gpjax-0.10.1}/tests/test_kernels/test_utils.py +0 -0
  148. {gpjax-0.9.5 → gpjax-0.10.1}/tests/test_likelihoods.py +0 -0
  149. {gpjax-0.9.5 → gpjax-0.10.1}/tests/test_lower_cholesky.py +0 -0
  150. {gpjax-0.9.5 → gpjax-0.10.1}/tests/test_markdown.py +0 -0
  151. {gpjax-0.9.5 → gpjax-0.10.1}/tests/test_objectives.py +0 -0
  152. {gpjax-0.9.5 → gpjax-0.10.1}/tests/test_parameters.py +0 -0
  153. {gpjax-0.9.5 → gpjax-0.10.1}/tests/test_variational_families.py +0 -0
@@ -0,0 +1,37 @@
+ You are an AI assistant specialized in Python development and machine learning. Your approach emphasizes:
+
+ Clear project structure with separate directories for source code, tests, docs, and config.
+
+ Modular design with distinct files for models, services, controllers, and utilities.
+
+ Configuration management using environment variables.
+
+ Robust error handling and logging, including context capture.
+
+ Comprehensive testing with pytest.
+
+ Detailed documentation using docstrings and README files.
+
+ Code style consistency using Ruff.
+
+ CI/CD implementation with GitHub Actions or GitLab CI.
+
+ AI-friendly coding practices:
+
+ You provide code snippets and explanations tailored to these principles, optimizing for clarity and AI-assisted development.
+
+ Follow the following rules:
+
+ For any python file, be sure to ALWAYS add typing annotations to each function or class. Be sure to include return types when necessary. Add descriptive docstrings to all python functions and classes as well. Please use pep257 convention. Update existing docstrings if need be.
+
+ Make sure you keep any comments that exist in a file.
+
+ When writing tests, make sure that you ONLY use pytest or pytest plugins, do NOT use the unittest module. All tests should have typing annotations as well. All tests should be in ./tests. Be sure to create all necessary files and folders. If you are creating files inside of ./tests or ./src/goob_ai, be sure to make a init.py file if one does not exist.
+
+ All tests should be fully annotated and should contain docstrings. Be sure to import the following if TYPE_CHECKING:
+
+ from _pytest.capture import CaptureFixture
+ from _pytest.fixtures import FixtureRequest
+ from _pytest.logging import LogCaptureFixture
+ from _pytest.monkeypatch import MonkeyPatch
+ from pytest_mock.plugin import MockerFixture
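As an illustration of the testing conventions this new .cursorrules file asks for (pytest only, full typing annotations, docstrings, TYPE_CHECKING-guarded pytest imports), here is a minimal sketch of a conforming test; the add_jitter helper and the test name are hypothetical and not part of the package:

"""Tests for a hypothetical add_jitter helper, written to the conventions above."""
from __future__ import annotations

from typing import TYPE_CHECKING

import jax.numpy as jnp
import pytest

if TYPE_CHECKING:
    from _pytest.logging import LogCaptureFixture  # noqa: F401


def add_jitter(matrix: jnp.ndarray, jitter: float = 1e-6) -> jnp.ndarray:
    """Add a small positive value to the diagonal of a square matrix."""
    return matrix + jitter * jnp.eye(matrix.shape[0])


@pytest.mark.parametrize("jitter", [1e-6, 1e-3])
def test_add_jitter_increases_diagonal(jitter: float) -> None:
    """The diagonal should grow by exactly the jitter amount."""
    mat = jnp.zeros((3, 3))
    assert jnp.allclose(jnp.diag(add_jitter(mat, jitter)), jitter)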
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: gpjax
- Version: 0.9.5
+ Version: 0.10.1
  Summary: Gaussian processes in JAX.
  Project-URL: Documentation, https://docs.jaxgaussianprocesses.com/
  Project-URL: Issues, https://github.com/JaxGaussianProcesses/GPJax/issues
@@ -18,12 +18,12 @@ Classifier: Programming Language :: Python :: Implementation :: CPython
  Classifier: Programming Language :: Python :: Implementation :: PyPy
  Requires-Python: <3.13,>=3.10
  Requires-Dist: beartype>0.16.1
- Requires-Dist: cola-ml==0.0.5
- Requires-Dist: flax<0.10.0
- Requires-Dist: jax<0.4.28
- Requires-Dist: jaxlib<0.4.28
+ Requires-Dist: cola-ml>=0.0.7
+ Requires-Dist: flax>=0.10.0
+ Requires-Dist: jax>=0.5.0
+ Requires-Dist: jaxlib>=0.5.0
  Requires-Dist: jaxtyping>0.2.10
- Requires-Dist: numpy<2.0.0
+ Requires-Dist: numpy>=2.0.0
  Requires-Dist: optax>0.2.1
  Requires-Dist: tensorflow-probability>=0.24.0
  Requires-Dist: tqdm>4.66.2
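For downstream users, the headline change in this hunk is the move from upper-bounded pins to newer lower bounds (jax/jaxlib >= 0.5.0, numpy >= 2.0.0, flax >= 0.10.0, cola-ml >= 0.0.7). A purely illustrative sketch (not part of the package) for checking an existing environment against these new minimums:

# Illustrative only: print installed versions next to the new lower bounds from PKG-INFO.
from importlib.metadata import PackageNotFoundError, version

minimums = {"jax": "0.5.0", "jaxlib": "0.5.0", "numpy": "2.0.0", "flax": "0.10.0", "cola-ml": "0.0.7"}
for name, minimum in minimums.items():
    try:
        print(f"{name}: installed {version(name)}, requires >= {minimum}")
    except PackageNotFoundError:
        print(f"{name}: not installed, requires >= {minimum}")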
@@ -0,0 +1,214 @@
+ # -*- coding: utf-8 -*-
+ # ---
+ # jupyter:
+ #   jupytext:
+ #     cell_metadata_filter: -all
+ #     custom_cell_magics: kql
+ #     text_representation:
+ #       extension: .py
+ #       format_name: percent
+ #       format_version: '1.3'
+ #       jupytext_version: 1.11.2
+ #   kernelspec:
+ #     display_name: docs
+ #     language: python
+ #     name: python3
+ # ---
+
+ # %% [markdown]
+ # Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ # %%
+ """Example of using the OrthogonalAdditiveKernel."""
+
+ # %%
+ import jax
+ from jax import config
+
+ config.update("jax_enable_x64", True)  # Enable Float64 precision
+
+ import jax.numpy as jnp
+ import matplotlib.pyplot as plt
+ from matplotlib.colors import ListedColormap
+ import optax
+
+ import gpjax as gpx
+ from gpjax.dataset import Dataset
+ from gpjax.kernels import OrthogonalAdditiveKernel, RBF
+ from gpjax.typing import KeyArray
+
+
+ # %%
+ def f(x: jnp.ndarray) -> jnp.ndarray:
+     """Additive function with mixed dependencies:
+     f(x) = sin(π*x₁) + 2*cos(2π*x₂) + 0.5*sin(3π*x₁*x₂)
+
+     Args:
+         x: Input points array with shape (..., 2)
+
+     Returns:
+         Function values at the input points
+     """
+     return (
+         jnp.sin(jnp.pi * x[..., 0])
+         + 2.0 * jnp.cos(2.0 * jnp.pi * x[..., 1])
+         + 0.5 * jnp.sin(3.0 * jnp.pi * x[..., 0] * x[..., 1])
+     )
+
+
+ # %%
+ def generate_data(
+     key: KeyArray, n_train: int = 100, noise_std: float = 0.1
+ ) -> tuple[Dataset, jnp.ndarray, jnp.ndarray]:
+     """Generate synthetic training data.
+
+     Args:
+         key: JAX PRNG key for random number generation
+         n_train: Number of training points to generate
+         noise_std: Standard deviation of Gaussian observation noise
+
+     Returns:
+         Tuple of (training_data, X_test, meshgrid_for_plotting)
+     """
+     key1, key2, key3 = jax.random.split(key, 3)
+
+     # Generate training data
+     X_train = jax.random.uniform(key1, (n_train, 2))
+     y_train = f(X_train) + noise_std * jax.random.normal(key2, (n_train,))
+
+     training_data = Dataset(X=X_train, y=y_train[:, None])
+
+     # Generate test points for prediction
+     n_test = 20
+     x_range = jnp.linspace(0.0, 1.0, n_test)
+     X1, X2 = jnp.meshgrid(x_range, x_range)
+     X_test = jnp.vstack([X1.flatten(), X2.flatten()]).T
+
+     return training_data, X_test, (X1, X2)
+
+
+ # %%
+ def main():
+     # Set random seed for reproducibility
+     key = jax.random.PRNGKey(42)
+
+     # Generate synthetic training data
+     training_data, X_test, (X1, X2) = generate_data(key, n_train=100, noise_std=0.1)
+
+     # Create base kernel (RBF)
+     base_kernel = RBF(lengthscale=0.2)
+
+     # Create OAK kernel with second-order interactions
+     oak_kernel = OrthogonalAdditiveKernel(
+         base_kernel=base_kernel,
+         dim=2,
+         quad_deg=20,
+         second_order=True,
+     )
+
+     # Create a GP prior model
+     prior = gpx.gps.Prior(
+         mean_function=gpx.mean_functions.Zero(),
+         kernel=oak_kernel,
+     )
+
+     # Create a likelihood
+     likelihood = gpx.likelihoods.Gaussian(num_datapoints=training_data.n)
+
+     # Create the posterior
+     posterior = prior * likelihood
+
+     # Create parameter optimizer
+     optimizer = optax.adam(learning_rate=0.01)
+
+     # Define objective function for training
+     def objective(model, data):
+         return -model.mll(model.params, data)
+
+     # Optimize hyperparameters
+     opt_posterior, history = gpx.fit(
+         model=posterior,
+         objective=objective,
+         train_data=training_data,
+         optim=optimizer,
+         num_iters=300,
+         key=key,
+         verbose=True,
+     )
+
+     # Plot training curve
+     plt.figure(figsize=(10, 4))
+     plt.subplot(1, 2, 1)
+     plt.plot(history)
+     plt.title("Negative Log Marginal Likelihood")
+     plt.xlabel("Iteration")
+     plt.ylabel("NLML")
+
+     # Get posterior predictions
+     latent_dist = opt_posterior.predict(params=opt_posterior.params, x=X_test)
+     predictive_dist = opt_posterior.likelihood.condition(
+         latent_dist, opt_posterior.params
+     )
+     mu = predictive_dist.mean().reshape(X1.shape)
+     std = predictive_dist.stddev().reshape(X1.shape)
+
+     # Plot predictions
+     plt.subplot(1, 2, 2)
+     plt.contourf(X1, X2, mu, 50, cmap="viridis")
+     plt.colorbar(label="Predicted Mean")
+     plt.scatter(
+         training_data.X[:, 0],
+         training_data.X[:, 1],
+         c=training_data.y,
+         cmap=ListedColormap(["red", "blue"]),
+         alpha=0.6,
+         s=20,
+         edgecolors="k",
+     )
+     plt.title("OAK GP Predictions")
+     plt.xlabel("$x_1$")
+     plt.ylabel("$x_2$")
+
+     plt.tight_layout()
+     plt.savefig("oak_example.png", dpi=300)
+     plt.show()
+
+     # Print learned kernel parameters
+     print("\nLearned Parameters:")
+     print(f"Offset coefficient: {opt_posterior.params.kernel.offset.value}")
+     print(f"First-order coefficients: {opt_posterior.params.kernel.coeffs_1.value}")
+
+     # Analyze the importance of each dimension
+     importance_1st_order = opt_posterior.params.kernel.coeffs_1.value
+     total_importance = jnp.sum(importance_1st_order)
+     relative_importance = importance_1st_order / total_importance
+
+     print("\nRelative Importance of Input Dimensions:")
+     for i, imp in enumerate(relative_importance):
+         print(f"Dimension {i+1}: {imp:.4f}")
+
+     if opt_posterior.params.kernel.coeffs_2 is not None:
+         # Analyze second-order interactions
+         coeffs_2 = opt_posterior.params.kernel.coeffs_2
+         print("\nSecond-order Interaction Coefficient:")
+         print(f"{coeffs_2[0, 1]:.4f}")
+
+
+ # %%
+ if __name__ == "__main__":
+     main()
+
+ # %%
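A more compact way to exercise the new OrthogonalAdditiveKernel, separate from the full example above, is to build it and evaluate its Gram matrix directly. This is a sketch: the constructor arguments are copied from the example, while the gram() / to_dense() calls assume GPJax's generic kernel-computation interface and are not shown in this diff:

import jax.numpy as jnp
from gpjax.kernels import RBF, OrthogonalAdditiveKernel

# Same construction as in examples/oak_example.py above.
oak = OrthogonalAdditiveKernel(
    base_kernel=RBF(lengthscale=0.2), dim=2, quad_deg=20, second_order=True
)

# Evaluate the kernel on a small batch of 2-D inputs (assumed interface).
X = jnp.array([[0.1, 0.2], [0.5, 0.8], [0.9, 0.4]])
K = oak.gram(X)            # covariance operator over the three inputs
print(K.to_dense().shape)  # expected: (3, 3)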
@@ -39,7 +39,7 @@ __license__ = "MIT"
  __description__ = "Didactic Gaussian processes in JAX"
  __url__ = "https://github.com/JaxGaussianProcesses/GPJax"
  __contributors__ = "https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors"
- __version__ = "0.9.5"
+ __version__ = "0.10.1"

  __all__ = [
      "base",
@@ -32,6 +32,7 @@ from gpjax.kernels.computations import (
  from gpjax.parameters import (
      Parameter,
      Real,
+     Static,
  )
  from gpjax.typing import (
      Array,
@@ -220,7 +221,9 @@ class Constant(AbstractKernel):
      def __init__(
          self,
          active_dims: tp.Union[list[int], slice, None] = None,
-         constant: tp.Union[ScalarFloat, Parameter[ScalarFloat]] = jnp.array(0.0),
+         constant: tp.Union[
+             ScalarFloat, Parameter[ScalarFloat], Static[ScalarFloat]
+         ] = jnp.array(0.0),
          compute_engine: AbstractKernelComputation = DenseKernelComputation(),
      ):
          if isinstance(constant, Parameter):
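The widened signature means the Constant kernel now also accepts a Static-wrapped value. A minimal sketch of the two construction paths, assuming (as the import added above suggests) that gpjax.parameters.Static marks a value as non-trainable:

import jax.numpy as jnp
from gpjax.kernels.base import Constant
from gpjax.parameters import Static

# Assumed behaviour: a Static-wrapped constant is held fixed during optimisation.
fixed_kernel = Constant(constant=Static(jnp.array(2.0)))

# Passing a raw scalar still yields a trainable parameter, as before.
trainable_kernel = Constant(constant=jnp.array(2.0))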