gpjax 0.10.1__tar.gz → 0.10.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {gpjax-0.10.1 → gpjax-0.10.2}/PKG-INFO +1 -1
- {gpjax-0.10.1 → gpjax-0.10.2}/docs/sharp_bits.md +57 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/examples/oak_example.py +9 -7
- {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/__init__.py +1 -1
- {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/kernels/nonstationary/polynomial.py +1 -1
- {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/parameters.py +88 -26
- gpjax-0.10.2/tests/test_parameters.py +113 -0
- gpjax-0.10.1/gpjax/kernels/nonstationary/oak.py +0 -406
- gpjax-0.10.1/tests/kernels/nonstationary/test_oak.py +0 -208
- gpjax-0.10.1/tests/test_parameters.py +0 -56
- {gpjax-0.10.1 → gpjax-0.10.2}/.cursorrules +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/.github/CODE_OF_CONDUCT.md +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/.github/ISSUE_TEMPLATE/01_BUG_REPORT.md +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/.github/ISSUE_TEMPLATE/02_FEATURE_REQUEST.md +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/.github/ISSUE_TEMPLATE/03_CODEBASE_IMPROVEMENT.md +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/.github/ISSUE_TEMPLATE/04_DOCS_IMPROVEMENT.md +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/.github/ISSUE_TEMPLATE/config.yml +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/.github/codecov.yml +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/.github/labels.yml +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/.github/pull_request_template.md +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/.github/release-drafter.yml +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/.github/workflows/build_docs.yml +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/.github/workflows/integration.yml +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/.github/workflows/pr_greeting.yml +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/.github/workflows/ruff.yml +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/.github/workflows/stale_prs.yml +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/.github/workflows/test_docs.yml +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/.github/workflows/tests.yml +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/.gitignore +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/CITATION.bib +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/LICENSE.txt +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/Makefile +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/README.md +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/docs/CODE_OF_CONDUCT.md +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/docs/GOVERNANCE.md +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/docs/contributing.md +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/docs/design.md +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/docs/index.md +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/docs/index.rst +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/docs/installation.md +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/docs/javascripts/katex.js +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/docs/refs.bib +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/docs/scripts/gen_examples.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/docs/scripts/gen_pages.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/docs/scripts/notebook_converter.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/docs/scripts/sharp_bits_figure.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/docs/static/GP.pdf +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/docs/static/GP.svg +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/docs/static/bijector_figure.svg +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/docs/static/css/gpjax_theme.css +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/docs/static/favicon.ico +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/docs/static/gpjax.mplstyle +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/docs/static/gpjax_logo.pdf +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/docs/static/gpjax_logo.svg +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/docs/static/jaxkern/lato.ttf +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/docs/static/jaxkern/logo.png +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/docs/static/jaxkern/logo.svg +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/docs/static/jaxkern/main.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/docs/static/step_size_figure.png +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/docs/static/step_size_figure.svg +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/docs/stylesheets/extra.css +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/docs/stylesheets/permalinks.css +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/examples/backend.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/examples/barycentres/barycentre_gp.gif +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/examples/barycentres.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/examples/classification.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/examples/collapsed_vi.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/examples/constructing_new_kernels.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/examples/data/max_tempeature_switzerland.csv +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/examples/data/yacht_hydrodynamics.data +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/examples/deep_kernels.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/examples/gpjax.mplstyle +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/examples/graph_kernels.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/examples/intro_to_gps/decomposed_mll.png +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/examples/intro_to_gps/generating_process.png +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/examples/intro_to_gps.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/examples/intro_to_kernels.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/examples/likelihoods_guide.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/examples/oceanmodelling.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/examples/poisson.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/examples/regression.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/examples/uncollapsed_vi.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/examples/utils.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/examples/yacht.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/citation.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/dataset.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/distributions.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/fit.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/gps.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/integrators.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/kernels/__init__.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/kernels/approximations/__init__.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/kernels/approximations/rff.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/kernels/base.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/kernels/computations/__init__.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/kernels/computations/base.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/kernels/computations/basis_functions.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/kernels/computations/constant_diagonal.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/kernels/computations/dense.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/kernels/computations/diagonal.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/kernels/computations/eigen.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/kernels/non_euclidean/__init__.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/kernels/non_euclidean/graph.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/kernels/non_euclidean/utils.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/kernels/nonstationary/__init__.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/kernels/nonstationary/arccosine.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/kernels/nonstationary/linear.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/kernels/stationary/__init__.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/kernels/stationary/base.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/kernels/stationary/matern12.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/kernels/stationary/matern32.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/kernels/stationary/matern52.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/kernels/stationary/periodic.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/kernels/stationary/powered_exponential.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/kernels/stationary/rational_quadratic.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/kernels/stationary/rbf.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/kernels/stationary/utils.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/kernels/stationary/white.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/likelihoods.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/lower_cholesky.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/mean_functions.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/objectives.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/scan.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/typing.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/variational_families.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/mkdocs.yml +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/pyproject.toml +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/static/CONTRIBUTING.md +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/static/paper.bib +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/static/paper.md +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/static/paper.pdf +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/tests/__init__.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/tests/conftest.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/tests/integration_tests.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/tests/test_citations.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/tests/test_dataset.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/tests/test_fit.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/tests/test_gaussian_distribution.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/tests/test_gps.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/tests/test_integrators.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/tests/test_kernels/__init__.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/tests/test_kernels/test_approximations.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/tests/test_kernels/test_base.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/tests/test_kernels/test_computation.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/tests/test_kernels/test_non_euclidean.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/tests/test_kernels/test_nonstationary.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/tests/test_kernels/test_stationary.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/tests/test_kernels/test_utils.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/tests/test_likelihoods.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/tests/test_lower_cholesky.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/tests/test_markdown.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/tests/test_mean_functions.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/tests/test_objectives.py +0 -0
- {gpjax-0.10.1 → gpjax-0.10.2}/tests/test_variational_families.py +0 -0
|
@@ -175,3 +175,60 @@ mini-batch optimisation of the parameters of your sparse Gaussian process model.
|
|
|
175
175
|
model will scale linearly in the batch size and quadratically in the number of inducing
|
|
176
176
|
points. We demonstrate its use in
|
|
177
177
|
[our sparse stochastic variational inference notebook](_examples/uncollapsed_vi.md).
|
|
178
|
+
|
|
179
|
+
## JIT compilation
|
|
180
|
+
|
|
181
|
+
There is a subset of operations in GPJax that are not JIT-compatible by default. This
|
|
182
|
+
is because we have assertions in place to check the properties of the parameters. For
|
|
183
|
+
example, we check that the lengthscale parameter that a user provides is positive. This
|
|
184
|
+
makes for a better user experience as we can provide more informative error messages;
|
|
185
|
+
however, JIT compiling functions wherein these assertions are made will break the code.
|
|
186
|
+
As an example, consider the following code:
|
|
187
|
+
|
|
188
|
+
```python
|
|
189
|
+
import jax
|
|
190
|
+
import jax.numpy as jnp
|
|
191
|
+
import gpjax as gpx
|
|
192
|
+
|
|
193
|
+
x = jnp.linspace(0, 1, 10)[:, None]
|
|
194
|
+
|
|
195
|
+
def compute_gram(lengthscale):
|
|
196
|
+
k = gpx.kernels.RBF(active_dims=[0], lengthscale=lengthscale, variance=jnp.array(1.0))
|
|
197
|
+
return k.gram(x)
|
|
198
|
+
|
|
199
|
+
compute_gram(1.0)
|
|
200
|
+
```
|
|
201
|
+
|
|
202
|
+
So far, so good. However, if we try to JIT compile this function, we will get an error:
|
|
203
|
+
|
|
204
|
+
```python
|
|
205
|
+
jit_compute_gram = jax.jit(compute_gram)
|
|
206
|
+
try:
|
|
207
|
+
jit_compute_gram(1.0)
|
|
208
|
+
except Exception as e:
|
|
209
|
+
print(e)
|
|
210
|
+
```
|
|
211
|
+
|
|
212
|
+
This error is due to the fact that the `RBF` kernel contains an assertion that checks
|
|
213
|
+
that the lengthscale is positive. It does not matter that the assertion is satisfied;
|
|
214
|
+
the very presence of the assertion will break JIT compilation.
|
|
215
|
+
|
|
216
|
+
To resolve this, we can use the `checkify` transformation to functionalize the assertion. This will
|
|
217
|
+
allow the function to be JIT compiled.
|
|
218
|
+
|
|
219
|
+
```python
|
|
220
|
+
from jax.experimental import checkify
|
|
221
|
+
|
|
222
|
+
jit_compute_gram = jax.jit(checkify.checkify(compute_gram))
|
|
223
|
+
error, value = jit_compute_gram(1.0)
|
|
224
|
+
```
|
|
225
|
+
By virtue of `checkify.checkify`, a tuple is returned where the first element is the
|
|
226
|
+
error produced by the assertion (if any), and the second element is the value of the function.
|
|
227
|
+
|
|
228
|
+
This design is not perfect, and in an ideal world we would not require the user to wrap
|
|
229
|
+
their code in `checkify.checkify`. We are actively looking into cleaner ways to provide
|
|
230
|
+
guardrails in a less intrusive manner. However, for now, should you try to JIT compile
|
|
231
|
+
a component of GPJax wherein there is an assertion, you will need to wrap the function
|
|
232
|
+
in `checkify.checkify` as shown above.
|
|
233
|
+
|
|
234
|
+
For more on `checkify`, please see the [JAX Checkify Doc](https://docs.jax.dev/en/latest/debugging/checkify_guide.html).
|
|
@@ -8,7 +8,7 @@
|
|
|
8
8
|
# extension: .py
|
|
9
9
|
# format_name: percent
|
|
10
10
|
# format_version: '1.3'
|
|
11
|
-
# jupytext_version: 1.
|
|
11
|
+
# jupytext_version: 1.16.7
|
|
12
12
|
# kernelspec:
|
|
13
13
|
# display_name: docs
|
|
14
14
|
# language: python
|
|
@@ -37,19 +37,21 @@
|
|
|
37
37
|
# %%
|
|
38
38
|
import jax
|
|
39
39
|
from jax import config
|
|
40
|
-
|
|
41
|
-
config.update("jax_enable_x64", True) # Enable Float64 precision
|
|
42
|
-
|
|
43
40
|
import jax.numpy as jnp
|
|
44
|
-
import matplotlib.pyplot as plt
|
|
45
41
|
from matplotlib.colors import ListedColormap
|
|
42
|
+
import matplotlib.pyplot as plt
|
|
46
43
|
import optax
|
|
47
44
|
|
|
48
45
|
import gpjax as gpx
|
|
49
46
|
from gpjax.dataset import Dataset
|
|
50
|
-
from gpjax.kernels import
|
|
47
|
+
from gpjax.kernels import (
|
|
48
|
+
RBF,
|
|
49
|
+
OrthogonalAdditiveKernel,
|
|
50
|
+
)
|
|
51
51
|
from gpjax.typing import KeyArray
|
|
52
52
|
|
|
53
|
+
config.update("jax_enable_x64", True) # Enable Float64 precision
|
|
54
|
+
|
|
53
55
|
|
|
54
56
|
# %%
|
|
55
57
|
def f(x: jnp.ndarray) -> jnp.ndarray:
|
|
@@ -198,7 +200,7 @@ def main():
|
|
|
198
200
|
|
|
199
201
|
print("\nRelative Importance of Input Dimensions:")
|
|
200
202
|
for i, imp in enumerate(relative_importance):
|
|
201
|
-
print(f"Dimension {i+1}: {imp:.4f}")
|
|
203
|
+
print(f"Dimension {i + 1}: {imp:.4f}")
|
|
202
204
|
|
|
203
205
|
if opt_posterior.params.kernel.coeffs_2 is not None:
|
|
204
206
|
# Analyze second-order interactions
|
|
@@ -39,7 +39,7 @@ __license__ = "MIT"
|
|
|
39
39
|
__description__ = "Didactic Gaussian processes in JAX"
|
|
40
40
|
__url__ = "https://github.com/JaxGaussianProcesses/GPJax"
|
|
41
41
|
__contributors__ = "https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors"
|
|
42
|
-
__version__ = "0.10.
|
|
42
|
+
__version__ = "0.10.2"
|
|
43
43
|
|
|
44
44
|
__all__ = [
|
|
45
45
|
"base",
|
|
@@ -46,7 +46,7 @@ class Polynomial(AbstractKernel):
|
|
|
46
46
|
self,
|
|
47
47
|
active_dims: tp.Union[list[int], slice, None] = None,
|
|
48
48
|
degree: int = 2,
|
|
49
|
-
shift: tp.Union[ScalarFloat, nnx.Variable[ScalarArray]] =
|
|
49
|
+
shift: tp.Union[ScalarFloat, nnx.Variable[ScalarArray]] = 1.0,
|
|
50
50
|
variance: tp.Union[ScalarFloat, nnx.Variable[ScalarArray]] = 1.0,
|
|
51
51
|
n_dims: tp.Union[int, None] = None,
|
|
52
52
|
compute_engine: AbstractKernelComputation = DenseKernelComputation(),
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
import typing as tp
|
|
2
2
|
|
|
3
3
|
from flax import nnx
|
|
4
|
+
from jax.experimental import checkify
|
|
4
5
|
import jax.numpy as jnp
|
|
5
6
|
import jax.tree_util as jtu
|
|
6
7
|
from jax.typing import ArrayLike
|
|
@@ -84,8 +85,7 @@ class PositiveReal(Parameter[T]):
|
|
|
84
85
|
|
|
85
86
|
def __init__(self, value: T, tag: ParameterTag = "positive", **kwargs):
|
|
86
87
|
super().__init__(value=value, tag=tag, **kwargs)
|
|
87
|
-
|
|
88
|
-
_check_is_positive(self.value)
|
|
88
|
+
_safe_assert(_check_is_positive, self.value)
|
|
89
89
|
|
|
90
90
|
|
|
91
91
|
class Real(Parameter[T]):
|
|
@@ -101,7 +101,17 @@ class SigmoidBounded(Parameter[T]):
|
|
|
101
101
|
def __init__(self, value: T, tag: ParameterTag = "sigmoid", **kwargs):
|
|
102
102
|
super().__init__(value=value, tag=tag, **kwargs)
|
|
103
103
|
|
|
104
|
-
|
|
104
|
+
# Only perform validation in non-JIT contexts
|
|
105
|
+
if (
|
|
106
|
+
not isinstance(value, jnp.ndarray)
|
|
107
|
+
or not getattr(value, "aval", None) is None
|
|
108
|
+
):
|
|
109
|
+
_safe_assert(
|
|
110
|
+
_check_in_bounds,
|
|
111
|
+
self.value,
|
|
112
|
+
low=jnp.array(0.0),
|
|
113
|
+
high=jnp.array(1.0),
|
|
114
|
+
)
|
|
105
115
|
|
|
106
116
|
|
|
107
117
|
class Static(nnx.Variable[T]):
|
|
@@ -120,8 +130,13 @@ class LowerTriangular(Parameter[T]):
|
|
|
120
130
|
def __init__(self, value: T, tag: ParameterTag = "lower_triangular", **kwargs):
|
|
121
131
|
super().__init__(value=value, tag=tag, **kwargs)
|
|
122
132
|
|
|
123
|
-
|
|
124
|
-
|
|
133
|
+
# Only perform validation in non-JIT contexts
|
|
134
|
+
if (
|
|
135
|
+
not isinstance(value, jnp.ndarray)
|
|
136
|
+
or not getattr(value, "aval", None) is None
|
|
137
|
+
):
|
|
138
|
+
_safe_assert(_check_is_square, self.value)
|
|
139
|
+
_safe_assert(_check_is_lower_triangular, self.value)
|
|
125
140
|
|
|
126
141
|
|
|
127
142
|
DEFAULT_BIJECTION = {
|
|
@@ -132,36 +147,83 @@ DEFAULT_BIJECTION = {
|
|
|
132
147
|
}
|
|
133
148
|
|
|
134
149
|
|
|
135
|
-
def _check_is_arraylike(value: T):
|
|
150
|
+
def _check_is_arraylike(value: T) -> None:
|
|
151
|
+
"""Check if a value is array-like.
|
|
152
|
+
|
|
153
|
+
Args:
|
|
154
|
+
value: The value to check.
|
|
155
|
+
|
|
156
|
+
Raises:
|
|
157
|
+
TypeError: If the value is not array-like.
|
|
158
|
+
"""
|
|
136
159
|
if not isinstance(value, (ArrayLike, list)):
|
|
137
160
|
raise TypeError(
|
|
138
161
|
f"Expected parameter value to be an array-like type. Got {value}."
|
|
139
162
|
)
|
|
140
163
|
|
|
141
164
|
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
165
|
+
@checkify.checkify
|
|
166
|
+
def _check_is_positive(value):
|
|
167
|
+
checkify.check(
|
|
168
|
+
jnp.all(value > 0), "value needs to be positive, got {value}", value=value
|
|
169
|
+
)
|
|
147
170
|
|
|
148
171
|
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
f"Expected parameter value to be a square matrix. Got {value}."
|
|
153
|
-
)
|
|
172
|
+
@checkify.checkify
|
|
173
|
+
def _check_is_square(value: T) -> None:
|
|
174
|
+
"""Check if a value is a square matrix.
|
|
154
175
|
|
|
176
|
+
Args:
|
|
177
|
+
value: The value to check.
|
|
155
178
|
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
179
|
+
Raises:
|
|
180
|
+
ValueError: If the value is not a square matrix.
|
|
181
|
+
"""
|
|
182
|
+
checkify.check(
|
|
183
|
+
value.shape[0] == value.shape[1],
|
|
184
|
+
"value needs to be a square matrix, got {value}",
|
|
185
|
+
value=value,
|
|
186
|
+
)
|
|
161
187
|
|
|
162
188
|
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
189
|
+
@checkify.checkify
|
|
190
|
+
def _check_is_lower_triangular(value: T) -> None:
|
|
191
|
+
"""Check if a value is a lower triangular matrix.
|
|
192
|
+
|
|
193
|
+
Args:
|
|
194
|
+
value: The value to check.
|
|
195
|
+
|
|
196
|
+
Raises:
|
|
197
|
+
ValueError: If the value is not a lower triangular matrix.
|
|
198
|
+
"""
|
|
199
|
+
checkify.check(
|
|
200
|
+
jnp.all(jnp.tril(value) == value),
|
|
201
|
+
"value needs to be a lower triangular matrix, got {value}",
|
|
202
|
+
value=value,
|
|
203
|
+
)
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
@checkify.checkify
|
|
207
|
+
def _check_in_bounds(value: T, low: T, high: T) -> None:
|
|
208
|
+
"""Check if a value is bounded between low and high.
|
|
209
|
+
|
|
210
|
+
Args:
|
|
211
|
+
value: The value to check.
|
|
212
|
+
low: The lower bound.
|
|
213
|
+
high: The upper bound.
|
|
214
|
+
|
|
215
|
+
Raises:
|
|
216
|
+
ValueError: If any element of value is outside the bounds.
|
|
217
|
+
"""
|
|
218
|
+
checkify.check(
|
|
219
|
+
jnp.all((value >= low) & (value <= high)),
|
|
220
|
+
"value needs to be bounded between {low} and {high}, got {value}",
|
|
221
|
+
value=value,
|
|
222
|
+
low=low,
|
|
223
|
+
high=high,
|
|
224
|
+
)
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
def _safe_assert(fn: tp.Callable[[tp.Any], None], value: T, **kwargs) -> None:
|
|
228
|
+
error, _ = fn(value, **kwargs)
|
|
229
|
+
checkify.check_error(error)
|
|
@@ -0,0 +1,113 @@
|
|
|
1
|
+
from flax import nnx
|
|
2
|
+
from jax import jit
|
|
3
|
+
from jax.experimental import checkify
|
|
4
|
+
import jax.numpy as jnp
|
|
5
|
+
import pytest
|
|
6
|
+
|
|
7
|
+
from gpjax.parameters import (
|
|
8
|
+
DEFAULT_BIJECTION,
|
|
9
|
+
LowerTriangular,
|
|
10
|
+
Parameter,
|
|
11
|
+
PositiveReal,
|
|
12
|
+
Real,
|
|
13
|
+
SigmoidBounded,
|
|
14
|
+
Static,
|
|
15
|
+
_check_in_bounds,
|
|
16
|
+
_check_is_lower_triangular,
|
|
17
|
+
_check_is_positive,
|
|
18
|
+
_check_is_square,
|
|
19
|
+
_safe_assert,
|
|
20
|
+
transform,
|
|
21
|
+
)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@pytest.mark.parametrize(
|
|
25
|
+
"param, value",
|
|
26
|
+
[
|
|
27
|
+
(PositiveReal, 1.0),
|
|
28
|
+
(Real, 2.0),
|
|
29
|
+
(SigmoidBounded, 0.5),
|
|
30
|
+
],
|
|
31
|
+
)
|
|
32
|
+
def test_transform(param, value):
|
|
33
|
+
# Create mock parameters and bijectors
|
|
34
|
+
params = nnx.State(
|
|
35
|
+
{
|
|
36
|
+
"param1": param(value),
|
|
37
|
+
"param2": Parameter(2.0, tag="real"),
|
|
38
|
+
}
|
|
39
|
+
)
|
|
40
|
+
|
|
41
|
+
# Test forward transformation
|
|
42
|
+
t_params = transform(params, DEFAULT_BIJECTION)
|
|
43
|
+
t_param1_expected = DEFAULT_BIJECTION[params["param1"]._tag].forward(value)
|
|
44
|
+
assert jnp.allclose(t_params["param1"].value, t_param1_expected)
|
|
45
|
+
assert jnp.allclose(t_params["param2"].value, 2.0)
|
|
46
|
+
|
|
47
|
+
# Test inverse transformation
|
|
48
|
+
it_params = transform(t_params, DEFAULT_BIJECTION, inverse=True)
|
|
49
|
+
assert repr(it_params) == repr(params)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
@pytest.mark.parametrize(
|
|
53
|
+
"param, tag",
|
|
54
|
+
[
|
|
55
|
+
(PositiveReal(1.0), "positive"),
|
|
56
|
+
(Real(2.0), "real"),
|
|
57
|
+
(SigmoidBounded(0.5), "sigmoid"),
|
|
58
|
+
(Static(2.0), "static"),
|
|
59
|
+
(LowerTriangular(jnp.eye(2)), "lower_triangular"),
|
|
60
|
+
],
|
|
61
|
+
)
|
|
62
|
+
def test_default_tags(param, tag):
|
|
63
|
+
assert param._tag == tag
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def test_check_is_positive():
|
|
67
|
+
# Check singleton
|
|
68
|
+
_safe_assert(_check_is_positive, jnp.array(3.0))
|
|
69
|
+
# Check array
|
|
70
|
+
_safe_assert(_check_is_positive, jnp.array([3.0, 4.0]))
|
|
71
|
+
|
|
72
|
+
# Check negative singleton
|
|
73
|
+
with pytest.raises(ValueError):
|
|
74
|
+
_safe_assert(_check_is_positive, jnp.array(-3.0))
|
|
75
|
+
|
|
76
|
+
# Check negative array
|
|
77
|
+
with pytest.raises(ValueError):
|
|
78
|
+
_safe_assert(_check_is_positive, jnp.array([-3.0, 4.0]))
|
|
79
|
+
|
|
80
|
+
# Test that functions wrapping _check_is_positive are jittable
|
|
81
|
+
def _dummy_fn(value):
|
|
82
|
+
_safe_assert(_check_is_positive, value)
|
|
83
|
+
|
|
84
|
+
jitted_fn = jit(checkify.checkify(_dummy_fn))
|
|
85
|
+
jitted_fn(jnp.array(3.0))
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def test_check_is_square():
|
|
89
|
+
# Check square matrix
|
|
90
|
+
_safe_assert(_check_is_square, jnp.full((2, 2), 1.0))
|
|
91
|
+
# Check non-square matrix
|
|
92
|
+
with pytest.raises(ValueError):
|
|
93
|
+
_safe_assert(_check_is_square, jnp.full((2, 3), 1.0))
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def test_check_is_lower_triangular():
|
|
97
|
+
# Check lower triangular matrix
|
|
98
|
+
_safe_assert(_check_is_lower_triangular, jnp.tril(jnp.eye(2)))
|
|
99
|
+
# Check non-lower triangular matrix
|
|
100
|
+
with pytest.raises(ValueError):
|
|
101
|
+
_safe_assert(_check_is_lower_triangular, jnp.linspace(0.0, 1.0, 4))
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def test_check_in_bounds():
|
|
105
|
+
# Check in bounds
|
|
106
|
+
_safe_assert(
|
|
107
|
+
_check_in_bounds, jnp.array(0.5), low=jnp.array(0.0), high=jnp.array(1.0)
|
|
108
|
+
)
|
|
109
|
+
# Check out of bounds
|
|
110
|
+
with pytest.raises(ValueError):
|
|
111
|
+
_safe_assert(
|
|
112
|
+
_check_in_bounds, jnp.array(1.5), low=jnp.array(0.0), high=jnp.array(1.0)
|
|
113
|
+
)
|