gpjax 0.10.1__tar.gz → 0.10.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (154) hide show
  1. {gpjax-0.10.1 → gpjax-0.10.2}/PKG-INFO +1 -1
  2. {gpjax-0.10.1 → gpjax-0.10.2}/docs/sharp_bits.md +57 -0
  3. {gpjax-0.10.1 → gpjax-0.10.2}/examples/oak_example.py +9 -7
  4. {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/__init__.py +1 -1
  5. {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/kernels/nonstationary/polynomial.py +1 -1
  6. {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/parameters.py +88 -26
  7. gpjax-0.10.2/tests/test_parameters.py +113 -0
  8. gpjax-0.10.1/gpjax/kernels/nonstationary/oak.py +0 -406
  9. gpjax-0.10.1/tests/kernels/nonstationary/test_oak.py +0 -208
  10. gpjax-0.10.1/tests/test_parameters.py +0 -56
  11. {gpjax-0.10.1 → gpjax-0.10.2}/.cursorrules +0 -0
  12. {gpjax-0.10.1 → gpjax-0.10.2}/.github/CODE_OF_CONDUCT.md +0 -0
  13. {gpjax-0.10.1 → gpjax-0.10.2}/.github/ISSUE_TEMPLATE/01_BUG_REPORT.md +0 -0
  14. {gpjax-0.10.1 → gpjax-0.10.2}/.github/ISSUE_TEMPLATE/02_FEATURE_REQUEST.md +0 -0
  15. {gpjax-0.10.1 → gpjax-0.10.2}/.github/ISSUE_TEMPLATE/03_CODEBASE_IMPROVEMENT.md +0 -0
  16. {gpjax-0.10.1 → gpjax-0.10.2}/.github/ISSUE_TEMPLATE/04_DOCS_IMPROVEMENT.md +0 -0
  17. {gpjax-0.10.1 → gpjax-0.10.2}/.github/ISSUE_TEMPLATE/config.yml +0 -0
  18. {gpjax-0.10.1 → gpjax-0.10.2}/.github/codecov.yml +0 -0
  19. {gpjax-0.10.1 → gpjax-0.10.2}/.github/labels.yml +0 -0
  20. {gpjax-0.10.1 → gpjax-0.10.2}/.github/pull_request_template.md +0 -0
  21. {gpjax-0.10.1 → gpjax-0.10.2}/.github/release-drafter.yml +0 -0
  22. {gpjax-0.10.1 → gpjax-0.10.2}/.github/workflows/build_docs.yml +0 -0
  23. {gpjax-0.10.1 → gpjax-0.10.2}/.github/workflows/integration.yml +0 -0
  24. {gpjax-0.10.1 → gpjax-0.10.2}/.github/workflows/pr_greeting.yml +0 -0
  25. {gpjax-0.10.1 → gpjax-0.10.2}/.github/workflows/ruff.yml +0 -0
  26. {gpjax-0.10.1 → gpjax-0.10.2}/.github/workflows/stale_prs.yml +0 -0
  27. {gpjax-0.10.1 → gpjax-0.10.2}/.github/workflows/test_docs.yml +0 -0
  28. {gpjax-0.10.1 → gpjax-0.10.2}/.github/workflows/tests.yml +0 -0
  29. {gpjax-0.10.1 → gpjax-0.10.2}/.gitignore +0 -0
  30. {gpjax-0.10.1 → gpjax-0.10.2}/CITATION.bib +0 -0
  31. {gpjax-0.10.1 → gpjax-0.10.2}/LICENSE.txt +0 -0
  32. {gpjax-0.10.1 → gpjax-0.10.2}/Makefile +0 -0
  33. {gpjax-0.10.1 → gpjax-0.10.2}/README.md +0 -0
  34. {gpjax-0.10.1 → gpjax-0.10.2}/docs/CODE_OF_CONDUCT.md +0 -0
  35. {gpjax-0.10.1 → gpjax-0.10.2}/docs/GOVERNANCE.md +0 -0
  36. {gpjax-0.10.1 → gpjax-0.10.2}/docs/contributing.md +0 -0
  37. {gpjax-0.10.1 → gpjax-0.10.2}/docs/design.md +0 -0
  38. {gpjax-0.10.1 → gpjax-0.10.2}/docs/index.md +0 -0
  39. {gpjax-0.10.1 → gpjax-0.10.2}/docs/index.rst +0 -0
  40. {gpjax-0.10.1 → gpjax-0.10.2}/docs/installation.md +0 -0
  41. {gpjax-0.10.1 → gpjax-0.10.2}/docs/javascripts/katex.js +0 -0
  42. {gpjax-0.10.1 → gpjax-0.10.2}/docs/refs.bib +0 -0
  43. {gpjax-0.10.1 → gpjax-0.10.2}/docs/scripts/gen_examples.py +0 -0
  44. {gpjax-0.10.1 → gpjax-0.10.2}/docs/scripts/gen_pages.py +0 -0
  45. {gpjax-0.10.1 → gpjax-0.10.2}/docs/scripts/notebook_converter.py +0 -0
  46. {gpjax-0.10.1 → gpjax-0.10.2}/docs/scripts/sharp_bits_figure.py +0 -0
  47. {gpjax-0.10.1 → gpjax-0.10.2}/docs/static/GP.pdf +0 -0
  48. {gpjax-0.10.1 → gpjax-0.10.2}/docs/static/GP.svg +0 -0
  49. {gpjax-0.10.1 → gpjax-0.10.2}/docs/static/bijector_figure.svg +0 -0
  50. {gpjax-0.10.1 → gpjax-0.10.2}/docs/static/css/gpjax_theme.css +0 -0
  51. {gpjax-0.10.1 → gpjax-0.10.2}/docs/static/favicon.ico +0 -0
  52. {gpjax-0.10.1 → gpjax-0.10.2}/docs/static/gpjax.mplstyle +0 -0
  53. {gpjax-0.10.1 → gpjax-0.10.2}/docs/static/gpjax_logo.pdf +0 -0
  54. {gpjax-0.10.1 → gpjax-0.10.2}/docs/static/gpjax_logo.svg +0 -0
  55. {gpjax-0.10.1 → gpjax-0.10.2}/docs/static/jaxkern/lato.ttf +0 -0
  56. {gpjax-0.10.1 → gpjax-0.10.2}/docs/static/jaxkern/logo.png +0 -0
  57. {gpjax-0.10.1 → gpjax-0.10.2}/docs/static/jaxkern/logo.svg +0 -0
  58. {gpjax-0.10.1 → gpjax-0.10.2}/docs/static/jaxkern/main.py +0 -0
  59. {gpjax-0.10.1 → gpjax-0.10.2}/docs/static/step_size_figure.png +0 -0
  60. {gpjax-0.10.1 → gpjax-0.10.2}/docs/static/step_size_figure.svg +0 -0
  61. {gpjax-0.10.1 → gpjax-0.10.2}/docs/stylesheets/extra.css +0 -0
  62. {gpjax-0.10.1 → gpjax-0.10.2}/docs/stylesheets/permalinks.css +0 -0
  63. {gpjax-0.10.1 → gpjax-0.10.2}/examples/backend.py +0 -0
  64. {gpjax-0.10.1 → gpjax-0.10.2}/examples/barycentres/barycentre_gp.gif +0 -0
  65. {gpjax-0.10.1 → gpjax-0.10.2}/examples/barycentres.py +0 -0
  66. {gpjax-0.10.1 → gpjax-0.10.2}/examples/classification.py +0 -0
  67. {gpjax-0.10.1 → gpjax-0.10.2}/examples/collapsed_vi.py +0 -0
  68. {gpjax-0.10.1 → gpjax-0.10.2}/examples/constructing_new_kernels.py +0 -0
  69. {gpjax-0.10.1 → gpjax-0.10.2}/examples/data/max_tempeature_switzerland.csv +0 -0
  70. {gpjax-0.10.1 → gpjax-0.10.2}/examples/data/yacht_hydrodynamics.data +0 -0
  71. {gpjax-0.10.1 → gpjax-0.10.2}/examples/deep_kernels.py +0 -0
  72. {gpjax-0.10.1 → gpjax-0.10.2}/examples/gpjax.mplstyle +0 -0
  73. {gpjax-0.10.1 → gpjax-0.10.2}/examples/graph_kernels.py +0 -0
  74. {gpjax-0.10.1 → gpjax-0.10.2}/examples/intro_to_gps/decomposed_mll.png +0 -0
  75. {gpjax-0.10.1 → gpjax-0.10.2}/examples/intro_to_gps/generating_process.png +0 -0
  76. {gpjax-0.10.1 → gpjax-0.10.2}/examples/intro_to_gps.py +0 -0
  77. {gpjax-0.10.1 → gpjax-0.10.2}/examples/intro_to_kernels.py +0 -0
  78. {gpjax-0.10.1 → gpjax-0.10.2}/examples/likelihoods_guide.py +0 -0
  79. {gpjax-0.10.1 → gpjax-0.10.2}/examples/oceanmodelling.py +0 -0
  80. {gpjax-0.10.1 → gpjax-0.10.2}/examples/poisson.py +0 -0
  81. {gpjax-0.10.1 → gpjax-0.10.2}/examples/regression.py +0 -0
  82. {gpjax-0.10.1 → gpjax-0.10.2}/examples/uncollapsed_vi.py +0 -0
  83. {gpjax-0.10.1 → gpjax-0.10.2}/examples/utils.py +0 -0
  84. {gpjax-0.10.1 → gpjax-0.10.2}/examples/yacht.py +0 -0
  85. {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/citation.py +0 -0
  86. {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/dataset.py +0 -0
  87. {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/distributions.py +0 -0
  88. {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/fit.py +0 -0
  89. {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/gps.py +0 -0
  90. {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/integrators.py +0 -0
  91. {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/kernels/__init__.py +0 -0
  92. {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/kernels/approximations/__init__.py +0 -0
  93. {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/kernels/approximations/rff.py +0 -0
  94. {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/kernels/base.py +0 -0
  95. {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/kernels/computations/__init__.py +0 -0
  96. {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/kernels/computations/base.py +0 -0
  97. {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/kernels/computations/basis_functions.py +0 -0
  98. {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/kernels/computations/constant_diagonal.py +0 -0
  99. {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/kernels/computations/dense.py +0 -0
  100. {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/kernels/computations/diagonal.py +0 -0
  101. {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/kernels/computations/eigen.py +0 -0
  102. {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/kernels/non_euclidean/__init__.py +0 -0
  103. {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/kernels/non_euclidean/graph.py +0 -0
  104. {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/kernels/non_euclidean/utils.py +0 -0
  105. {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/kernels/nonstationary/__init__.py +0 -0
  106. {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/kernels/nonstationary/arccosine.py +0 -0
  107. {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/kernels/nonstationary/linear.py +0 -0
  108. {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/kernels/stationary/__init__.py +0 -0
  109. {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/kernels/stationary/base.py +0 -0
  110. {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/kernels/stationary/matern12.py +0 -0
  111. {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/kernels/stationary/matern32.py +0 -0
  112. {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/kernels/stationary/matern52.py +0 -0
  113. {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/kernels/stationary/periodic.py +0 -0
  114. {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/kernels/stationary/powered_exponential.py +0 -0
  115. {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/kernels/stationary/rational_quadratic.py +0 -0
  116. {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/kernels/stationary/rbf.py +0 -0
  117. {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/kernels/stationary/utils.py +0 -0
  118. {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/kernels/stationary/white.py +0 -0
  119. {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/likelihoods.py +0 -0
  120. {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/lower_cholesky.py +0 -0
  121. {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/mean_functions.py +0 -0
  122. {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/objectives.py +0 -0
  123. {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/scan.py +0 -0
  124. {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/typing.py +0 -0
  125. {gpjax-0.10.1 → gpjax-0.10.2}/gpjax/variational_families.py +0 -0
  126. {gpjax-0.10.1 → gpjax-0.10.2}/mkdocs.yml +0 -0
  127. {gpjax-0.10.1 → gpjax-0.10.2}/pyproject.toml +0 -0
  128. {gpjax-0.10.1 → gpjax-0.10.2}/static/CONTRIBUTING.md +0 -0
  129. {gpjax-0.10.1 → gpjax-0.10.2}/static/paper.bib +0 -0
  130. {gpjax-0.10.1 → gpjax-0.10.2}/static/paper.md +0 -0
  131. {gpjax-0.10.1 → gpjax-0.10.2}/static/paper.pdf +0 -0
  132. {gpjax-0.10.1 → gpjax-0.10.2}/tests/__init__.py +0 -0
  133. {gpjax-0.10.1 → gpjax-0.10.2}/tests/conftest.py +0 -0
  134. {gpjax-0.10.1 → gpjax-0.10.2}/tests/integration_tests.py +0 -0
  135. {gpjax-0.10.1 → gpjax-0.10.2}/tests/test_citations.py +0 -0
  136. {gpjax-0.10.1 → gpjax-0.10.2}/tests/test_dataset.py +0 -0
  137. {gpjax-0.10.1 → gpjax-0.10.2}/tests/test_fit.py +0 -0
  138. {gpjax-0.10.1 → gpjax-0.10.2}/tests/test_gaussian_distribution.py +0 -0
  139. {gpjax-0.10.1 → gpjax-0.10.2}/tests/test_gps.py +0 -0
  140. {gpjax-0.10.1 → gpjax-0.10.2}/tests/test_integrators.py +0 -0
  141. {gpjax-0.10.1 → gpjax-0.10.2}/tests/test_kernels/__init__.py +0 -0
  142. {gpjax-0.10.1 → gpjax-0.10.2}/tests/test_kernels/test_approximations.py +0 -0
  143. {gpjax-0.10.1 → gpjax-0.10.2}/tests/test_kernels/test_base.py +0 -0
  144. {gpjax-0.10.1 → gpjax-0.10.2}/tests/test_kernels/test_computation.py +0 -0
  145. {gpjax-0.10.1 → gpjax-0.10.2}/tests/test_kernels/test_non_euclidean.py +0 -0
  146. {gpjax-0.10.1 → gpjax-0.10.2}/tests/test_kernels/test_nonstationary.py +0 -0
  147. {gpjax-0.10.1 → gpjax-0.10.2}/tests/test_kernels/test_stationary.py +0 -0
  148. {gpjax-0.10.1 → gpjax-0.10.2}/tests/test_kernels/test_utils.py +0 -0
  149. {gpjax-0.10.1 → gpjax-0.10.2}/tests/test_likelihoods.py +0 -0
  150. {gpjax-0.10.1 → gpjax-0.10.2}/tests/test_lower_cholesky.py +0 -0
  151. {gpjax-0.10.1 → gpjax-0.10.2}/tests/test_markdown.py +0 -0
  152. {gpjax-0.10.1 → gpjax-0.10.2}/tests/test_mean_functions.py +0 -0
  153. {gpjax-0.10.1 → gpjax-0.10.2}/tests/test_objectives.py +0 -0
  154. {gpjax-0.10.1 → gpjax-0.10.2}/tests/test_variational_families.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: gpjax
3
- Version: 0.10.1
3
+ Version: 0.10.2
4
4
  Summary: Gaussian processes in JAX.
5
5
  Project-URL: Documentation, https://docs.jaxgaussianprocesses.com/
6
6
  Project-URL: Issues, https://github.com/JaxGaussianProcesses/GPJax/issues
@@ -175,3 +175,60 @@ mini-batch optimisation of the parameters of your sparse Gaussian process model.
175
175
  model will scale linearly in the batch size and quadratically in the number of inducing
176
176
  points. We demonstrate its use in
177
177
  [our sparse stochastic variational inference notebook](_examples/uncollapsed_vi.md).
178
+
179
+ ## JIT compilation
180
+
181
+ There is a subset of operations in GPJax that are not JIT-compatible by default. This
182
+ is because we have assertions in place to check the properties of the parameters. For
183
+ example, we check that the lengthscale parameter that a user provides is positive. This
184
+ makes for a better user experience as we can provide more informative error messages;
185
+ however, JIT-compiling a function in which these assertions are made will raise an error.
186
+ As an example, consider the following code:
187
+
188
+ ```python
189
+ import jax
190
+ import jax.numpy as jnp
191
+ import gpjax as gpx
192
+
193
+ x = jnp.linspace(0, 1, 10)[:, None]
194
+
195
+ def compute_gram(lengthscale):
196
+ k = gpx.kernels.RBF(active_dims=[0], lengthscale=lengthscale, variance=jnp.array(1.0))
197
+ return k.gram(x)
198
+
199
+ compute_gram(1.0)
200
+ ```
201
+
202
+ So far, so good. However, if we try to JIT-compile this function, we will get an error:
203
+
204
+ ```python
205
+ jit_compute_gram = jax.jit(compute_gram)
206
+ try:
207
+ jit_compute_gram(1.0)
208
+ except Exception as e:
209
+ print(e)
210
+ ```
211
+
212
+ This error occurs because the `RBF` kernel contains an assertion that checks
213
+ that the lengthscale is positive. It does not matter that the assertion is satisfied;
214
+ the very presence of the assertion will break JIT compilation.
215
+
216
+ To resolve this, we can wrap the function in `checkify.checkify`, which functionalizes the assertion: instead of raising, a failed check is returned as an error value. This will
217
+ allow the function to be JIT compiled.
218
+
219
+ ```python
220
+ from jax.experimental import checkify
221
+
222
+ jit_compute_gram = jax.jit(checkify.checkify(compute_gram))
223
+ error, value = jit_compute_gram(1.0)
224
+ ```
225
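+ Because the function is wrapped in `checkify.checkify`, it returns a tuple whose first element is a `checkify.Error` describing the outcome of the assertion
226
+ and whose second element is the function's return value. Call `error.throw()` outside the jitted function to raise if the check failed.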
227
+
228
+ This design is not perfect, and in an ideal world we would not require the user to wrap
229
+ their code in `checkify.checkify`. We are actively looking into cleaner ways to provide
230
+ guardrails in a less intrusive manner. However, for now, should you try to JIT compile
231
+ a component of GPJax that contains an assertion, you will need to wrap the function
232
+ in `checkify.checkify` as shown above.
233
+
234
+ For more on `checkify`, please see the [JAX Checkify Doc](https://docs.jax.dev/en/latest/debugging/checkify_guide.html).
@@ -8,7 +8,7 @@
8
8
  # extension: .py
9
9
  # format_name: percent
10
10
  # format_version: '1.3'
11
- # jupytext_version: 1.11.2
11
+ # jupytext_version: 1.16.7
12
12
  # kernelspec:
13
13
  # display_name: docs
14
14
  # language: python
@@ -37,19 +37,21 @@
37
37
  # %%
38
38
  import jax
39
39
  from jax import config
40
-
41
- config.update("jax_enable_x64", True) # Enable Float64 precision
42
-
43
40
  import jax.numpy as jnp
44
- import matplotlib.pyplot as plt
45
41
  from matplotlib.colors import ListedColormap
42
+ import matplotlib.pyplot as plt
46
43
  import optax
47
44
 
48
45
  import gpjax as gpx
49
46
  from gpjax.dataset import Dataset
50
- from gpjax.kernels import OrthogonalAdditiveKernel, RBF
47
+ from gpjax.kernels import (
48
+ RBF,
49
+ OrthogonalAdditiveKernel,
50
+ )
51
51
  from gpjax.typing import KeyArray
52
52
 
53
+ config.update("jax_enable_x64", True) # Enable Float64 precision
54
+
53
55
 
54
56
  # %%
55
57
  def f(x: jnp.ndarray) -> jnp.ndarray:
@@ -198,7 +200,7 @@ def main():
198
200
 
199
201
  print("\nRelative Importance of Input Dimensions:")
200
202
  for i, imp in enumerate(relative_importance):
201
- print(f"Dimension {i+1}: {imp:.4f}")
203
+ print(f"Dimension {i + 1}: {imp:.4f}")
202
204
 
203
205
  if opt_posterior.params.kernel.coeffs_2 is not None:
204
206
  # Analyze second-order interactions
@@ -39,7 +39,7 @@ __license__ = "MIT"
39
39
  __description__ = "Didactic Gaussian processes in JAX"
40
40
  __url__ = "https://github.com/JaxGaussianProcesses/GPJax"
41
41
  __contributors__ = "https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors"
42
- __version__ = "0.10.1"
42
+ __version__ = "0.10.2"
43
43
 
44
44
  __all__ = [
45
45
  "base",
@@ -46,7 +46,7 @@ class Polynomial(AbstractKernel):
46
46
  self,
47
47
  active_dims: tp.Union[list[int], slice, None] = None,
48
48
  degree: int = 2,
49
- shift: tp.Union[ScalarFloat, nnx.Variable[ScalarArray]] = 0.0,
49
+ shift: tp.Union[ScalarFloat, nnx.Variable[ScalarArray]] = 1.0,
50
50
  variance: tp.Union[ScalarFloat, nnx.Variable[ScalarArray]] = 1.0,
51
51
  n_dims: tp.Union[int, None] = None,
52
52
  compute_engine: AbstractKernelComputation = DenseKernelComputation(),
@@ -1,6 +1,7 @@
1
1
  import typing as tp
2
2
 
3
3
  from flax import nnx
4
+ from jax.experimental import checkify
4
5
  import jax.numpy as jnp
5
6
  import jax.tree_util as jtu
6
7
  from jax.typing import ArrayLike
@@ -84,8 +85,7 @@ class PositiveReal(Parameter[T]):
84
85
 
85
86
  def __init__(self, value: T, tag: ParameterTag = "positive", **kwargs):
86
87
  super().__init__(value=value, tag=tag, **kwargs)
87
-
88
- _check_is_positive(self.value)
88
+ _safe_assert(_check_is_positive, self.value)
89
89
 
90
90
 
91
91
  class Real(Parameter[T]):
@@ -101,7 +101,17 @@ class SigmoidBounded(Parameter[T]):
101
101
  def __init__(self, value: T, tag: ParameterTag = "sigmoid", **kwargs):
102
102
  super().__init__(value=value, tag=tag, **kwargs)
103
103
 
104
- _check_in_bounds(self.value, 0.0, 1.0)
104
+ # Only perform validation in non-JIT contexts
105
+ if (
106
+ not isinstance(value, jnp.ndarray)
107
+ or not getattr(value, "aval", None) is None
108
+ ):
109
+ _safe_assert(
110
+ _check_in_bounds,
111
+ self.value,
112
+ low=jnp.array(0.0),
113
+ high=jnp.array(1.0),
114
+ )
105
115
 
106
116
 
107
117
  class Static(nnx.Variable[T]):
@@ -120,8 +130,13 @@ class LowerTriangular(Parameter[T]):
120
130
  def __init__(self, value: T, tag: ParameterTag = "lower_triangular", **kwargs):
121
131
  super().__init__(value=value, tag=tag, **kwargs)
122
132
 
123
- _check_is_square(self.value)
124
- _check_is_lower_triangular(self.value)
133
+ # Only perform validation in non-JIT contexts
134
+ if (
135
+ not isinstance(value, jnp.ndarray)
136
+ or not getattr(value, "aval", None) is None
137
+ ):
138
+ _safe_assert(_check_is_square, self.value)
139
+ _safe_assert(_check_is_lower_triangular, self.value)
125
140
 
126
141
 
127
142
  DEFAULT_BIJECTION = {
@@ -132,36 +147,83 @@ DEFAULT_BIJECTION = {
132
147
  }
133
148
 
134
149
 
135
- def _check_is_arraylike(value: T):
150
+ def _check_is_arraylike(value: T) -> None:
151
+ """Check if a value is array-like.
152
+
153
+ Args:
154
+ value: The value to check.
155
+
156
+ Raises:
157
+ TypeError: If the value is not array-like.
158
+ """
136
159
  if not isinstance(value, (ArrayLike, list)):
137
160
  raise TypeError(
138
161
  f"Expected parameter value to be an array-like type. Got {value}."
139
162
  )
140
163
 
141
164
 
142
- def _check_is_positive(value: T):
143
- if jnp.any(value < 0):
144
- raise ValueError(
145
- f"Expected parameter value to be strictly positive. Got {value}."
146
- )
165
+ @checkify.checkify
166
+ def _check_is_positive(value):
167
+ checkify.check(
168
+ jnp.all(value > 0), "value needs to be positive, got {value}", value=value
169
+ )
147
170
 
148
171
 
149
- def _check_is_square(value: T):
150
- if value.shape[0] != value.shape[1]:
151
- raise ValueError(
152
- f"Expected parameter value to be a square matrix. Got {value}."
153
- )
172
+ @checkify.checkify
173
+ def _check_is_square(value: T) -> None:
174
+ """Check if a value is a square matrix.
154
175
 
176
+ Args:
177
+ value: The value to check.
155
178
 
156
- def _check_is_lower_triangular(value: T):
157
- if not jnp.all(jnp.tril(value) == value):
158
- raise ValueError(
159
- f"Expected parameter value to be a lower triangular matrix. Got {value}."
160
- )
179
+ Raises:
180
+ ValueError: If the value is not a square matrix.
181
+ """
182
+ checkify.check(
183
+ value.shape[0] == value.shape[1],
184
+ "value needs to be a square matrix, got {value}",
185
+ value=value,
186
+ )
161
187
 
162
188
 
163
- def _check_in_bounds(value: T, low: float, high: float):
164
- if jnp.any((value < low) | (value > high)):
165
- raise ValueError(
166
- f"Expected parameter value to be bounded between {low} and {high}. Got {value}."
167
- )
189
+ @checkify.checkify
190
+ def _check_is_lower_triangular(value: T) -> None:
191
+ """Check if a value is a lower triangular matrix.
192
+
193
+ Args:
194
+ value: The value to check.
195
+
196
+ Raises:
197
+ ValueError: If the value is not a lower triangular matrix.
198
+ """
199
+ checkify.check(
200
+ jnp.all(jnp.tril(value) == value),
201
+ "value needs to be a lower triangular matrix, got {value}",
202
+ value=value,
203
+ )
204
+
205
+
206
+ @checkify.checkify
207
+ def _check_in_bounds(value: T, low: T, high: T) -> None:
208
+ """Check if a value is bounded between low and high.
209
+
210
+ Args:
211
+ value: The value to check.
212
+ low: The lower bound.
213
+ high: The upper bound.
214
+
215
+ Raises:
216
+ ValueError: If any element of value is outside the bounds.
217
+ """
218
+ checkify.check(
219
+ jnp.all((value >= low) & (value <= high)),
220
+ "value needs to be bounded between {low} and {high}, got {value}",
221
+ value=value,
222
+ low=low,
223
+ high=high,
224
+ )
225
+
226
+
227
+ def _safe_assert(fn: tp.Callable[[tp.Any], None], value: T, **kwargs) -> None:
228
+ error, _ = fn(value, **kwargs)
229
+ checkify.check_error(error)
@@ -0,0 +1,113 @@
1
+ from flax import nnx
2
+ from jax import jit
3
+ from jax.experimental import checkify
4
+ import jax.numpy as jnp
5
+ import pytest
6
+
7
+ from gpjax.parameters import (
8
+ DEFAULT_BIJECTION,
9
+ LowerTriangular,
10
+ Parameter,
11
+ PositiveReal,
12
+ Real,
13
+ SigmoidBounded,
14
+ Static,
15
+ _check_in_bounds,
16
+ _check_is_lower_triangular,
17
+ _check_is_positive,
18
+ _check_is_square,
19
+ _safe_assert,
20
+ transform,
21
+ )
22
+
23
+
24
+ @pytest.mark.parametrize(
25
+ "param, value",
26
+ [
27
+ (PositiveReal, 1.0),
28
+ (Real, 2.0),
29
+ (SigmoidBounded, 0.5),
30
+ ],
31
+ )
32
+ def test_transform(param, value):
33
+ # Create mock parameters and bijectors
34
+ params = nnx.State(
35
+ {
36
+ "param1": param(value),
37
+ "param2": Parameter(2.0, tag="real"),
38
+ }
39
+ )
40
+
41
+ # Test forward transformation
42
+ t_params = transform(params, DEFAULT_BIJECTION)
43
+ t_param1_expected = DEFAULT_BIJECTION[params["param1"]._tag].forward(value)
44
+ assert jnp.allclose(t_params["param1"].value, t_param1_expected)
45
+ assert jnp.allclose(t_params["param2"].value, 2.0)
46
+
47
+ # Test inverse transformation
48
+ it_params = transform(t_params, DEFAULT_BIJECTION, inverse=True)
49
+ assert repr(it_params) == repr(params)
50
+
51
+
52
+ @pytest.mark.parametrize(
53
+ "param, tag",
54
+ [
55
+ (PositiveReal(1.0), "positive"),
56
+ (Real(2.0), "real"),
57
+ (SigmoidBounded(0.5), "sigmoid"),
58
+ (Static(2.0), "static"),
59
+ (LowerTriangular(jnp.eye(2)), "lower_triangular"),
60
+ ],
61
+ )
62
+ def test_default_tags(param, tag):
63
+ assert param._tag == tag
64
+
65
+
66
+ def test_check_is_positive():
67
+ # Check singleton
68
+ _safe_assert(_check_is_positive, jnp.array(3.0))
69
+ # Check array
70
+ _safe_assert(_check_is_positive, jnp.array([3.0, 4.0]))
71
+
72
+ # Check negative singleton
73
+ with pytest.raises(ValueError):
74
+ _safe_assert(_check_is_positive, jnp.array(-3.0))
75
+
76
+ # Check negative array
77
+ with pytest.raises(ValueError):
78
+ _safe_assert(_check_is_positive, jnp.array([-3.0, 4.0]))
79
+
80
+ # Test that functions wrapping _check_is_positive are jittable
81
+ def _dummy_fn(value):
82
+ _safe_assert(_check_is_positive, value)
83
+
84
+ jitted_fn = jit(checkify.checkify(_dummy_fn))
85
+ jitted_fn(jnp.array(3.0))
86
+
87
+
88
+ def test_check_is_square():
89
+ # Check square matrix
90
+ _safe_assert(_check_is_square, jnp.full((2, 2), 1.0))
91
+ # Check non-square matrix
92
+ with pytest.raises(ValueError):
93
+ _safe_assert(_check_is_square, jnp.full((2, 3), 1.0))
94
+
95
+
96
+ def test_check_is_lower_triangular():
97
+ # Check lower triangular matrix
98
+ _safe_assert(_check_is_lower_triangular, jnp.tril(jnp.eye(2)))
99
+ # Check non-lower triangular matrix
100
+ with pytest.raises(ValueError):
101
+ _safe_assert(_check_is_lower_triangular, jnp.linspace(0.0, 1.0, 4))
102
+
103
+
104
+ def test_check_in_bounds():
105
+ # Check in bounds
106
+ _safe_assert(
107
+ _check_in_bounds, jnp.array(0.5), low=jnp.array(0.0), high=jnp.array(1.0)
108
+ )
109
+ # Check out of bounds
110
+ with pytest.raises(ValueError):
111
+ _safe_assert(
112
+ _check_in_bounds, jnp.array(1.5), low=jnp.array(0.0), high=jnp.array(1.0)
113
+ )