gpjax 0.11.0.tar.gz → 0.11.2.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (154)
  1. {gpjax-0.11.0 → gpjax-0.11.2}/PKG-INFO +1 -1
  2. {gpjax-0.11.0 → gpjax-0.11.2}/examples/constructing_new_kernels.py +0 -3
  3. {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/__init__.py +4 -2
  4. {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/citation.py +7 -2
  5. {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/fit.py +104 -1
  6. {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/kernels/nonstationary/arccosine.py +6 -3
  7. {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/kernels/nonstationary/linear.py +3 -3
  8. {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/kernels/nonstationary/polynomial.py +6 -3
  9. {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/kernels/stationary/base.py +6 -3
  10. {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/likelihoods.py +4 -4
  11. {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/mean_functions.py +1 -1
  12. {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/parameters.py +16 -0
  13. {gpjax-0.11.0 → gpjax-0.11.2}/tests/test_fit.py +190 -7
  14. {gpjax-0.11.0 → gpjax-0.11.2}/tests/test_kernels/test_nonstationary.py +5 -5
  15. {gpjax-0.11.0 → gpjax-0.11.2}/tests/test_kernels/test_stationary.py +5 -4
  16. {gpjax-0.11.0 → gpjax-0.11.2}/tests/test_likelihoods.py +2 -2
  17. gpjax-0.11.2/tests/test_mean_functions.py +249 -0
  18. {gpjax-0.11.0 → gpjax-0.11.2}/tests/test_numpyro_extras.py +76 -0
  19. {gpjax-0.11.0 → gpjax-0.11.2}/tests/test_parameters.py +4 -0
  20. gpjax-0.11.2/uv.lock +832 -0
  21. gpjax-0.11.0/.cursorrules +0 -37
  22. gpjax-0.11.0/tests/test_mean_functions.py +0 -81
  23. {gpjax-0.11.0 → gpjax-0.11.2}/.github/CODE_OF_CONDUCT.md +0 -0
  24. {gpjax-0.11.0 → gpjax-0.11.2}/.github/ISSUE_TEMPLATE/01_BUG_REPORT.md +0 -0
  25. {gpjax-0.11.0 → gpjax-0.11.2}/.github/ISSUE_TEMPLATE/02_FEATURE_REQUEST.md +0 -0
  26. {gpjax-0.11.0 → gpjax-0.11.2}/.github/ISSUE_TEMPLATE/03_CODEBASE_IMPROVEMENT.md +0 -0
  27. {gpjax-0.11.0 → gpjax-0.11.2}/.github/ISSUE_TEMPLATE/04_DOCS_IMPROVEMENT.md +0 -0
  28. {gpjax-0.11.0 → gpjax-0.11.2}/.github/ISSUE_TEMPLATE/config.yml +0 -0
  29. {gpjax-0.11.0 → gpjax-0.11.2}/.github/codecov.yml +0 -0
  30. {gpjax-0.11.0 → gpjax-0.11.2}/.github/labels.yml +0 -0
  31. {gpjax-0.11.0 → gpjax-0.11.2}/.github/pull_request_template.md +0 -0
  32. {gpjax-0.11.0 → gpjax-0.11.2}/.github/release-drafter.yml +0 -0
  33. {gpjax-0.11.0 → gpjax-0.11.2}/.github/workflows/build_docs.yml +0 -0
  34. {gpjax-0.11.0 → gpjax-0.11.2}/.github/workflows/integration.yml +0 -0
  35. {gpjax-0.11.0 → gpjax-0.11.2}/.github/workflows/pr_greeting.yml +0 -0
  36. {gpjax-0.11.0 → gpjax-0.11.2}/.github/workflows/ruff.yml +0 -0
  37. {gpjax-0.11.0 → gpjax-0.11.2}/.github/workflows/stale_prs.yml +0 -0
  38. {gpjax-0.11.0 → gpjax-0.11.2}/.github/workflows/test_docs.yml +0 -0
  39. {gpjax-0.11.0 → gpjax-0.11.2}/.github/workflows/tests.yml +0 -0
  40. {gpjax-0.11.0 → gpjax-0.11.2}/.gitignore +0 -0
  41. {gpjax-0.11.0 → gpjax-0.11.2}/CITATION.bib +0 -0
  42. {gpjax-0.11.0 → gpjax-0.11.2}/LICENSE.txt +0 -0
  43. {gpjax-0.11.0 → gpjax-0.11.2}/Makefile +0 -0
  44. {gpjax-0.11.0 → gpjax-0.11.2}/README.md +0 -0
  45. {gpjax-0.11.0 → gpjax-0.11.2}/docs/CODE_OF_CONDUCT.md +0 -0
  46. {gpjax-0.11.0 → gpjax-0.11.2}/docs/GOVERNANCE.md +0 -0
  47. {gpjax-0.11.0 → gpjax-0.11.2}/docs/contributing.md +0 -0
  48. {gpjax-0.11.0 → gpjax-0.11.2}/docs/design.md +0 -0
  49. {gpjax-0.11.0 → gpjax-0.11.2}/docs/index.md +0 -0
  50. {gpjax-0.11.0 → gpjax-0.11.2}/docs/index.rst +0 -0
  51. {gpjax-0.11.0 → gpjax-0.11.2}/docs/installation.md +0 -0
  52. {gpjax-0.11.0 → gpjax-0.11.2}/docs/javascripts/katex.js +0 -0
  53. {gpjax-0.11.0 → gpjax-0.11.2}/docs/refs.bib +0 -0
  54. {gpjax-0.11.0 → gpjax-0.11.2}/docs/scripts/gen_examples.py +0 -0
  55. {gpjax-0.11.0 → gpjax-0.11.2}/docs/scripts/gen_pages.py +0 -0
  56. {gpjax-0.11.0 → gpjax-0.11.2}/docs/scripts/notebook_converter.py +0 -0
  57. {gpjax-0.11.0 → gpjax-0.11.2}/docs/scripts/sharp_bits_figure.py +0 -0
  58. {gpjax-0.11.0 → gpjax-0.11.2}/docs/sharp_bits.md +0 -0
  59. {gpjax-0.11.0 → gpjax-0.11.2}/docs/static/GP.pdf +0 -0
  60. {gpjax-0.11.0 → gpjax-0.11.2}/docs/static/GP.svg +0 -0
  61. {gpjax-0.11.0 → gpjax-0.11.2}/docs/static/bijector_figure.svg +0 -0
  62. {gpjax-0.11.0 → gpjax-0.11.2}/docs/static/css/gpjax_theme.css +0 -0
  63. {gpjax-0.11.0 → gpjax-0.11.2}/docs/static/favicon.ico +0 -0
  64. {gpjax-0.11.0 → gpjax-0.11.2}/docs/static/gpjax.mplstyle +0 -0
  65. {gpjax-0.11.0 → gpjax-0.11.2}/docs/static/gpjax_logo.pdf +0 -0
  66. {gpjax-0.11.0 → gpjax-0.11.2}/docs/static/gpjax_logo.svg +0 -0
  67. {gpjax-0.11.0 → gpjax-0.11.2}/docs/static/jaxkern/lato.ttf +0 -0
  68. {gpjax-0.11.0 → gpjax-0.11.2}/docs/static/jaxkern/logo.png +0 -0
  69. {gpjax-0.11.0 → gpjax-0.11.2}/docs/static/jaxkern/logo.svg +0 -0
  70. {gpjax-0.11.0 → gpjax-0.11.2}/docs/static/jaxkern/main.py +0 -0
  71. {gpjax-0.11.0 → gpjax-0.11.2}/docs/static/step_size_figure.png +0 -0
  72. {gpjax-0.11.0 → gpjax-0.11.2}/docs/static/step_size_figure.svg +0 -0
  73. {gpjax-0.11.0 → gpjax-0.11.2}/docs/stylesheets/extra.css +0 -0
  74. {gpjax-0.11.0 → gpjax-0.11.2}/docs/stylesheets/permalinks.css +0 -0
  75. {gpjax-0.11.0 → gpjax-0.11.2}/examples/backend.py +0 -0
  76. {gpjax-0.11.0 → gpjax-0.11.2}/examples/barycentres/barycentre_gp.gif +0 -0
  77. {gpjax-0.11.0 → gpjax-0.11.2}/examples/barycentres.py +0 -0
  78. {gpjax-0.11.0 → gpjax-0.11.2}/examples/classification.py +0 -0
  79. {gpjax-0.11.0 → gpjax-0.11.2}/examples/collapsed_vi.py +0 -0
  80. {gpjax-0.11.0 → gpjax-0.11.2}/examples/data/max_tempeature_switzerland.csv +0 -0
  81. {gpjax-0.11.0 → gpjax-0.11.2}/examples/data/yacht_hydrodynamics.data +0 -0
  82. {gpjax-0.11.0 → gpjax-0.11.2}/examples/deep_kernels.py +0 -0
  83. {gpjax-0.11.0 → gpjax-0.11.2}/examples/gpjax.mplstyle +0 -0
  84. {gpjax-0.11.0 → gpjax-0.11.2}/examples/graph_kernels.py +0 -0
  85. {gpjax-0.11.0 → gpjax-0.11.2}/examples/intro_to_gps/decomposed_mll.png +0 -0
  86. {gpjax-0.11.0 → gpjax-0.11.2}/examples/intro_to_gps/generating_process.png +0 -0
  87. {gpjax-0.11.0 → gpjax-0.11.2}/examples/intro_to_gps.py +0 -0
  88. {gpjax-0.11.0 → gpjax-0.11.2}/examples/intro_to_kernels.py +0 -0
  89. {gpjax-0.11.0 → gpjax-0.11.2}/examples/likelihoods_guide.py +0 -0
  90. {gpjax-0.11.0 → gpjax-0.11.2}/examples/oceanmodelling.py +0 -0
  91. {gpjax-0.11.0 → gpjax-0.11.2}/examples/poisson.py +0 -0
  92. {gpjax-0.11.0 → gpjax-0.11.2}/examples/regression.py +0 -0
  93. {gpjax-0.11.0 → gpjax-0.11.2}/examples/uncollapsed_vi.py +0 -0
  94. {gpjax-0.11.0 → gpjax-0.11.2}/examples/utils.py +0 -0
  95. {gpjax-0.11.0 → gpjax-0.11.2}/examples/yacht.py +0 -0
  96. {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/dataset.py +0 -0
  97. {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/distributions.py +0 -0
  98. {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/gps.py +0 -0
  99. {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/integrators.py +0 -0
  100. {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/kernels/__init__.py +0 -0
  101. {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/kernels/approximations/__init__.py +0 -0
  102. {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/kernels/approximations/rff.py +0 -0
  103. {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/kernels/base.py +0 -0
  104. {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/kernels/computations/__init__.py +0 -0
  105. {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/kernels/computations/base.py +0 -0
  106. {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/kernels/computations/basis_functions.py +0 -0
  107. {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/kernels/computations/constant_diagonal.py +0 -0
  108. {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/kernels/computations/dense.py +0 -0
  109. {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/kernels/computations/diagonal.py +0 -0
  110. {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/kernels/computations/eigen.py +0 -0
  111. {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/kernels/non_euclidean/__init__.py +0 -0
  112. {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/kernels/non_euclidean/graph.py +0 -0
  113. {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/kernels/non_euclidean/utils.py +0 -0
  114. {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/kernels/nonstationary/__init__.py +0 -0
  115. {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/kernels/stationary/__init__.py +0 -0
  116. {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/kernels/stationary/matern12.py +0 -0
  117. {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/kernels/stationary/matern32.py +0 -0
  118. {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/kernels/stationary/matern52.py +0 -0
  119. {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/kernels/stationary/periodic.py +0 -0
  120. {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/kernels/stationary/powered_exponential.py +0 -0
  121. {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/kernels/stationary/rational_quadratic.py +0 -0
  122. {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/kernels/stationary/rbf.py +0 -0
  123. {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/kernels/stationary/utils.py +0 -0
  124. {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/kernels/stationary/white.py +0 -0
  125. {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/lower_cholesky.py +0 -0
  126. {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/numpyro_extras.py +0 -0
  127. {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/objectives.py +0 -0
  128. {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/scan.py +0 -0
  129. {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/typing.py +0 -0
  130. {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/variational_families.py +0 -0
  131. {gpjax-0.11.0 → gpjax-0.11.2}/mkdocs.yml +0 -0
  132. {gpjax-0.11.0 → gpjax-0.11.2}/pyproject.toml +0 -0
  133. {gpjax-0.11.0 → gpjax-0.11.2}/static/CONTRIBUTING.md +0 -0
  134. {gpjax-0.11.0 → gpjax-0.11.2}/static/paper.bib +0 -0
  135. {gpjax-0.11.0 → gpjax-0.11.2}/static/paper.md +0 -0
  136. {gpjax-0.11.0 → gpjax-0.11.2}/static/paper.pdf +0 -0
  137. {gpjax-0.11.0 → gpjax-0.11.2}/tests/__init__.py +0 -0
  138. {gpjax-0.11.0 → gpjax-0.11.2}/tests/conftest.py +0 -0
  139. {gpjax-0.11.0 → gpjax-0.11.2}/tests/integration_tests.py +0 -0
  140. {gpjax-0.11.0 → gpjax-0.11.2}/tests/test_citations.py +0 -0
  141. {gpjax-0.11.0 → gpjax-0.11.2}/tests/test_dataset.py +0 -0
  142. {gpjax-0.11.0 → gpjax-0.11.2}/tests/test_gaussian_distribution.py +0 -0
  143. {gpjax-0.11.0 → gpjax-0.11.2}/tests/test_gps.py +0 -0
  144. {gpjax-0.11.0 → gpjax-0.11.2}/tests/test_integrators.py +0 -0
  145. {gpjax-0.11.0 → gpjax-0.11.2}/tests/test_kernels/__init__.py +0 -0
  146. {gpjax-0.11.0 → gpjax-0.11.2}/tests/test_kernels/test_approximations.py +0 -0
  147. {gpjax-0.11.0 → gpjax-0.11.2}/tests/test_kernels/test_base.py +0 -0
  148. {gpjax-0.11.0 → gpjax-0.11.2}/tests/test_kernels/test_computation.py +0 -0
  149. {gpjax-0.11.0 → gpjax-0.11.2}/tests/test_kernels/test_non_euclidean.py +0 -0
  150. {gpjax-0.11.0 → gpjax-0.11.2}/tests/test_kernels/test_utils.py +0 -0
  151. {gpjax-0.11.0 → gpjax-0.11.2}/tests/test_lower_cholesky.py +0 -0
  152. {gpjax-0.11.0 → gpjax-0.11.2}/tests/test_markdown.py +0 -0
  153. {gpjax-0.11.0 → gpjax-0.11.2}/tests/test_objectives.py +0 -0
  154. {gpjax-0.11.0 → gpjax-0.11.2}/tests/test_variational_families.py +0 -0
--- a/PKG-INFO
+++ b/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: gpjax
-Version: 0.11.0
+Version: 0.11.2
 Summary: Gaussian processes in JAX.
 Project-URL: Documentation, https://docs.jaxgaussianprocesses.com/
 Project-URL: Issues, https://github.com/JaxGaussianProcesses/GPJax/issues
--- a/examples/constructing_new_kernels.py
+++ b/examples/constructing_new_kernels.py
@@ -33,7 +33,6 @@ from jaxtyping import (
     install_import_hook,
 )
 import matplotlib.pyplot as plt
-import numpyro.distributions as npd
 from numpyro.distributions import constraints
 import numpyro.distributions.transforms as npt
 
@@ -52,8 +51,6 @@ with install_import_hook("gpjax", "beartype.beartype"):
     import gpjax as gpx
 
 
-tfb = tfp.bijectors
-
 # set the default style for plotting
 use_mpl_style()
 
--- a/gpjax/__init__.py
+++ b/gpjax/__init__.py
@@ -32,14 +32,15 @@ from gpjax.citation import cite
 from gpjax.dataset import Dataset
 from gpjax.fit import (
     fit,
+    fit_lbfgs,
     fit_scipy,
 )
 
 __license__ = "MIT"
-__description__ = "Didactic Gaussian processes in JAX"
+__description__ = "Gaussian processes in JAX and Flax"
 __url__ = "https://github.com/JaxGaussianProcesses/GPJax"
 __contributors__ = "https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors"
-__version__ = "0.11.0"
+__version__ = "0.11.2"
 
 __all__ = [
     "base",
@@ -56,5 +57,6 @@ __all__ = [
     "fit",
     "Module",
     "param_field",
+    "fit_lbfgs",
     "fit_scipy",
 ]
--- a/gpjax/citation.py
+++ b/gpjax/citation.py
@@ -8,7 +8,12 @@ from beartype.typing import (
     Dict,
     Union,
 )
-from jaxlib.xla_extension import PjitFunction
+
+try:
+    # safely removable once jax>=0.6.0
+    from jaxlib.xla_extension import PjitFunction
+except ModuleNotFoundError:
+    from jaxlib._jax import PjitFunction
 
 from gpjax.kernels import (
     RFF,
@@ -45,7 +50,7 @@ class AbstractCitation:
 
 
 class NullCitation(AbstractCitation):
-    def __str__(self) -> str:
+    def as_str(self) -> str:
         return (
             "No citation available. If you think this is an error, please open a pull"
             " request."
--- a/gpjax/fit.py
+++ b/gpjax/fit.py
@@ -127,7 +127,6 @@ def fit(  # noqa: PLR0913
     _check_verbose(verbose)
 
     # Model state filtering
-
     graphdef, params, *static_state = nnx.split(model, Parameter, ...)
 
     # Parameters bijection to unconstrained space
@@ -253,6 +252,110 @@ def fit_scipy(  # noqa: PLR0913
     return model, history
 
 
+def fit_lbfgs(
+    *,
+    model: Model,
+    objective: Objective,
+    train_data: Dataset,
+    params_bijection: tp.Union[dict[Parameter, Transform], None] = DEFAULT_BIJECTION,
+    max_iters: int = 100,
+    safe: bool = True,
+    max_linesearch_steps: int = 32,
+    gtol: float = 1e-5,
+) -> tuple[Model, jax.Array]:
+    r"""Train a Module model with respect to a supplied Objective function.
+
+    Uses Optax's LBFGS implementation and a jax.lax.while_loop.
+
+    Args:
+        model: the model Module to be optimised.
+        objective: The objective function that we are optimising with
+            respect to.
+        train_data (Dataset): The training data to be used for the optimisation.
+        max_iters (int): The maximum number of optimisation steps to run. Defaults
+            to 100.
+        safe (bool): Whether to check the types of the inputs.
+        max_linesearch_steps (int): The maximum number of linesearch steps to use
+            for finding the stepsize.
+        gtol (float): Terminate the optimisation if the L2 norm of the gradient is
+            below this threshold.
+
+    Returns:
+        A tuple comprising the optimised model and the final loss.
+    """
+    if safe:
+        # Check inputs
+        _check_model(model)
+        _check_train_data(train_data)
+        _check_num_iters(max_iters)
+
+    # Model state filtering
+    graphdef, params, *static_state = nnx.split(model, Parameter, ...)
+
+    # Parameters bijection to unconstrained space
+    if params_bijection is not None:
+        params = transform(params, params_bijection, inverse=True)
+
+    # Loss definition
+    def loss(params: nnx.State) -> ScalarFloat:
+        params = transform(params, params_bijection)
+        model = nnx.merge(graphdef, params, *static_state)
+        return objective(model, train_data)
+
+    # Initialise optimiser
+    optim = ox.lbfgs(
+        linesearch=ox.scale_by_zoom_linesearch(
+            max_linesearch_steps=max_linesearch_steps,
+            initial_guess_strategy="one",
+        )
+    )
+    opt_state = optim.init(params)
+    loss_value_and_grad = ox.value_and_grad_from_state(loss)
+
+    # Optimisation step.
+    def step(carry):
+        params, opt_state = carry
+
+        # Using optax's value_and_grad_from_state is more efficient given LBFGS uses a linesearch.
+        # See https://optax.readthedocs.io/en/latest/api/utilities.html#optax.value_and_grad_from_state
+        loss_val, loss_gradient = loss_value_and_grad(params, state=opt_state)
+        updates, opt_state = optim.update(
+            loss_gradient,
+            opt_state,
+            params,
+            value=loss_val,
+            grad=loss_gradient,
+            value_fn=loss,
+        )
+        params = ox.apply_updates(params, updates)
+
+        return params, opt_state
+
+    def continue_fn(carry):
+        _, opt_state = carry
+        n = ox.tree_utils.tree_get(opt_state, "count")
+        g = ox.tree_utils.tree_get(opt_state, "grad")
+        g_l2_norm = ox.tree_utils.tree_l2_norm(g)
+        return (n == 0) | ((n < max_iters) & (g_l2_norm >= gtol))
+
+    # Optimisation loop
+    params, opt_state = jax.lax.while_loop(
+        continue_fn,
+        step,
+        (params, opt_state),
+    )
+    final_loss = ox.tree_utils.tree_get(opt_state, "value")
+
+    # Parameters bijection to constrained space
+    if params_bijection is not None:
+        params = transform(params, params_bijection)
+
+    # Reconstruct model
+    model = nnx.merge(graphdef, params, *static_state)
+
+    return model, final_loss
+
+
 def get_batch(train_data: Dataset, batch_size: int, key: KeyArray) -> Dataset:
     """Batch the data into mini-batches. Sampling is done with replacement.
 
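The hunk above adds fit_lbfgs, which mirrors fit and fit_scipy but drives optax.lbfgs inside a jax.lax.while_loop and returns the final loss rather than a loss history. A minimal usage sketch, adapted from the test_fit_lbfgs_simple test added later in this diff (the LinearModel and mse objective are illustrative stand-ins, not part of the library):

import jax.numpy as jnp
import jax.random as jr
from flax import nnx

import gpjax as gpx
from gpjax.parameters import PositiveReal, Static


# Toy model: only the PositiveReal weight is trainable; the Static bias
# is filtered out of the optimised state by nnx.split.
class LinearModel(nnx.Module):
    def __init__(self, weight: float, bias: float):
        self.weight = PositiveReal(weight)
        self.bias = Static(bias)

    def __call__(self, x):
        return self.weight.value * x + self.bias.value


X = jnp.linspace(0.0, 10.0, 100).reshape(-1, 1)
y = 2.0 * X + 1.0 + jr.normal(jr.PRNGKey(0), X.shape)
D = gpx.Dataset(X, y)


def mse(model, data):
    return jnp.mean((model(data.X) - data.y) ** 2)


# Keyword-only API; iterates until max_iters steps or until the gradient
# L2 norm falls below gtol.
trained, final_loss = gpx.fit_lbfgs(
    model=LinearModel(weight=1.0, bias=1.0),
    objective=mse,
    train_data=D,
    max_iters=100,
)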
--- a/gpjax/kernels/nonstationary/arccosine.py
+++ b/gpjax/kernels/nonstationary/arccosine.py
@@ -23,7 +23,10 @@ from gpjax.kernels.computations import (
     AbstractKernelComputation,
     DenseKernelComputation,
 )
-from gpjax.parameters import PositiveReal
+from gpjax.parameters import (
+    NonNegativeReal,
+    PositiveReal,
+)
 from gpjax.typing import (
     Array,
     ScalarArray,
@@ -91,9 +94,9 @@ class ArcCosine(AbstractKernel):
         if isinstance(variance, nnx.Variable):
             self.variance = variance
         else:
-            self.variance = PositiveReal(variance)
+            self.variance = NonNegativeReal(variance)
         if tp.TYPE_CHECKING:
-            self.variance = tp.cast(PositiveReal[ScalarArray], self.variance)
+            self.variance = tp.cast(NonNegativeReal[ScalarArray], self.variance)
 
         if isinstance(bias_variance, nnx.Variable):
             self.bias_variance = bias_variance
--- a/gpjax/kernels/nonstationary/linear.py
+++ b/gpjax/kernels/nonstationary/linear.py
@@ -23,7 +23,7 @@ from gpjax.kernels.computations import (
     AbstractKernelComputation,
     DenseKernelComputation,
 )
-from gpjax.parameters import PositiveReal
+from gpjax.parameters import NonNegativeReal
 from gpjax.typing import (
     Array,
     ScalarArray,
@@ -64,9 +64,9 @@ class Linear(AbstractKernel):
         if isinstance(variance, nnx.Variable):
             self.variance = variance
         else:
-            self.variance = PositiveReal(variance)
+            self.variance = NonNegativeReal(variance)
         if tp.TYPE_CHECKING:
-            self.variance = tp.cast(PositiveReal[ScalarArray], self.variance)
+            self.variance = tp.cast(NonNegativeReal[ScalarArray], self.variance)
 
     def __call__(
         self,
--- a/gpjax/kernels/nonstationary/polynomial.py
+++ b/gpjax/kernels/nonstationary/polynomial.py
@@ -23,7 +23,10 @@ from gpjax.kernels.computations import (
     AbstractKernelComputation,
     DenseKernelComputation,
 )
-from gpjax.parameters import PositiveReal
+from gpjax.parameters import (
+    NonNegativeReal,
+    PositiveReal,
+)
 from gpjax.typing import (
     Array,
     ScalarArray,
@@ -76,9 +79,9 @@ class Polynomial(AbstractKernel):
         if isinstance(variance, nnx.Variable):
             self.variance = variance
         else:
-            self.variance = PositiveReal(variance)
+            self.variance = NonNegativeReal(variance)
         if tp.TYPE_CHECKING:
-            self.variance = tp.cast(PositiveReal[ScalarArray], self.variance)
+            self.variance = tp.cast(NonNegativeReal[ScalarArray], self.variance)
 
         self.name = f"Polynomial (degree {self.degree})"
 
--- a/gpjax/kernels/stationary/base.py
+++ b/gpjax/kernels/stationary/base.py
@@ -25,7 +25,10 @@ from gpjax.kernels.computations import (
     AbstractKernelComputation,
     DenseKernelComputation,
 )
-from gpjax.parameters import PositiveReal
+from gpjax.parameters import (
+    NonNegativeReal,
+    PositiveReal,
+)
 from gpjax.typing import (
     Array,
     ScalarArray,
@@ -85,11 +88,11 @@ class StationaryKernel(AbstractKernel):
         if isinstance(variance, nnx.Variable):
             self.variance = variance
         else:
-            self.variance = PositiveReal(variance)
+            self.variance = NonNegativeReal(variance)
 
         # static typing
         if tp.TYPE_CHECKING:
-            self.variance = tp.cast(PositiveReal[ScalarFloat], self.variance)
+            self.variance = tp.cast(NonNegativeReal[ScalarFloat], self.variance)
 
     @property
     def spectral_density(self) -> npd.Normal | npd.StudentT:
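Taken together, the four kernel hunks above relax the variance parameter from PositiveReal to NonNegativeReal. A brief sketch of what this permits, assuming construction-time validation is the only gate (the RBF choice and values are illustrative):

import jax.numpy as jnp
import gpjax as gpx

# Under 0.11.0, PositiveReal validation rejected a zero variance at
# construction time; NonNegativeReal now accepts it. A zero-variance
# kernel may still need jitter downstream to keep Gram matrices well
# conditioned.
kernel = gpx.kernels.RBF(variance=jnp.array(0.0))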
--- a/gpjax/likelihoods.py
+++ b/gpjax/likelihoods.py
@@ -28,7 +28,7 @@ from gpjax.integrators import (
     GHQuadratureIntegrator,
 )
 from gpjax.parameters import (
-    PositiveReal,
+    NonNegativeReal,
     Static,
 )
 from gpjax.typing import (
@@ -134,7 +134,7 @@ class Gaussian(AbstractLikelihood):
         self,
         num_datapoints: int,
         obs_stddev: tp.Union[
-            ScalarFloat, Float[Array, "#N"], PositiveReal, Static
+            ScalarFloat, Float[Array, "#N"], NonNegativeReal, Static
         ] = 1.0,
         integrator: AbstractIntegrator = AnalyticalGaussianIntegrator(),
     ):
@@ -148,8 +148,8 @@ class Gaussian(AbstractLikelihood):
             likelihoods. Must be an instance of `AbstractIntegrator`. For the Gaussian likelihood, this defaults to
             the `AnalyticalGaussianIntegrator`, as the expected log likelihood can be computed analytically.
         """
-        if not isinstance(obs_stddev, (PositiveReal, Static)):
-            obs_stddev = PositiveReal(jnp.asarray(obs_stddev))
+        if not isinstance(obs_stddev, (NonNegativeReal, Static)):
+            obs_stddev = NonNegativeReal(jnp.asarray(obs_stddev))
         self.obs_stddev = obs_stddev
 
         super().__init__(num_datapoints, integrator)
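The same relaxation applies to the Gaussian likelihood's obs_stddev, so an exactly noise-free likelihood is now constructible. A small sketch (the values are illustrative):

import jax.numpy as jnp
import gpjax as gpx

# 0.11.0 wrapped obs_stddev in PositiveReal, which rejects 0.0;
# NonNegativeReal accepts it, permitting noise-free observations.
noise_free = gpx.likelihoods.Gaussian(num_datapoints=50, obs_stddev=jnp.array(0.0))

# Strictly positive values behave exactly as before.
noisy = gpx.likelihoods.Gaussian(num_datapoints=50, obs_stddev=jnp.array(0.1))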
--- a/gpjax/mean_functions.py
+++ b/gpjax/mean_functions.py
@@ -207,5 +207,5 @@ SumMeanFunction = ft.partial(
     CombinationMeanFunction, operator=ft.partial(jnp.sum, axis=0)
 )
 ProductMeanFunction = ft.partial(
-    CombinationMeanFunction, operator=ft.partial(jnp.sum, axis=0)
+    CombinationMeanFunction, operator=ft.partial(jnp.prod, axis=0)
 )
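This one-line hunk is a genuine bug fix: ProductMeanFunction was built with the same summing operator as SumMeanFunction, so it added its constituent means rather than multiplying them. A sketch of the behavioural difference (the values are illustrative):

import jax.numpy as jnp

# Stacked outputs of two constituent mean functions at three inputs.
means = jnp.stack([jnp.array([1.0, 2.0, 3.0]),
                   jnp.array([2.0, 2.0, 2.0])])

jnp.sum(means, axis=0)   # old (buggy) operator: [3., 4., 5.]
jnp.prod(means, axis=0)  # corrected operator:   [2., 4., 6.]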
--- a/gpjax/parameters.py
+++ b/gpjax/parameters.py
@@ -82,6 +82,14 @@ class Parameter(nnx.Variable[T]):
         self._tag = tag
 
 
+class NonNegativeReal(Parameter[T]):
+    """Parameter that is non-negative."""
+
+    def __init__(self, value: T, tag: ParameterTag = "non_negative", **kwargs):
+        super().__init__(value=value, tag=tag, **kwargs)
+        _safe_assert(_check_is_non_negative, self.value)
+
+
 class PositiveReal(Parameter[T]):
     """Parameter that is strictly positive."""
 
@@ -143,6 +151,7 @@ class LowerTriangular(Parameter[T]):
 
 DEFAULT_BIJECTION = {
    "positive": npt.SoftplusTransform(),
+    "non_negative": npt.SoftplusTransform(),
    "real": npt.IdentityTransform(),
    "sigmoid": npt.SigmoidTransform(),
    "lower_triangular": FillTriangularTransform(),
@@ -164,6 +173,13 @@ def _check_is_arraylike(value: T) -> None:
     )
 
 
+@checkify.checkify
+def _check_is_non_negative(value):
+    checkify.check(
+        jnp.all(value >= 0), "value needs to be non-negative, got {value}", value=value
+    )
+
+
 @checkify.checkify
 def _check_is_positive(value):
     checkify.check(
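The new NonNegativeReal differs from PositiveReal only in its validation predicate (>= 0 versus > 0) and its tag; per the DEFAULT_BIJECTION hunk above, both tags map to the softplus bijection during fitting. A quick sketch, assuming the checkify-backed _safe_assert raises on a failed check as it does for the existing parameter types:

import jax.numpy as jnp
from gpjax.parameters import NonNegativeReal, PositiveReal

NonNegativeReal(jnp.array(0.0))  # accepted: jnp.all(value >= 0) holds

try:
    PositiveReal(jnp.array(0.0))  # rejected: value must be strictly positive
except Exception as err:  # exact error type depends on _safe_assert
    print(f"rejected as expected: {err}")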
--- a/tests/test_fit.py
+++ b/tests/test_fit.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 
+from beartype.typing import Any
 from flax import nnx
 import jax.numpy as jnp
 import jax.random as jr
@@ -26,7 +27,15 @@ import scipy
 
 from gpjax.dataset import Dataset
 from gpjax.fit import (
+    _check_batch_size,
+    _check_log_rate,
+    _check_model,
+    _check_num_iters,
+    _check_optim,
+    _check_train_data,
+    _check_verbose,
     fit,
+    fit_lbfgs,
     fit_scipy,
     get_batch,
 )
@@ -141,6 +150,46 @@ def test_fit_scipy_simple():
     assert trained_model.bias.value == 1.0
 
 
+def test_fit_lbfgs_simple():
+    # Create dataset:
+    X = jnp.linspace(0.0, 10.0, 100).reshape(-1, 1)
+    y = 2.0 * X + 1.0 + 10 * jr.normal(jr.PRNGKey(0), X.shape).reshape(-1, 1)
+    D = Dataset(X, y)
+
+    # Define linear model:
+    class LinearModel(nnx.Module):
+        def __init__(self, weight: float, bias: float):
+            self.weight = PositiveReal(weight)
+            self.bias = Static(bias)
+
+        def __call__(self, x):
+            return self.weight.value * x + self.bias.value
+
+    model = LinearModel(weight=1.0, bias=1.0)
+
+    # Define loss function:
+    def mse(model, data):
+        pred = model(data.X)
+        return jnp.mean((pred - data.y) ** 2)
+
+    # Train with bfgs!
+    trained_model, final_loss = fit_lbfgs(
+        model=model,
+        objective=mse,
+        train_data=D,
+        max_iters=10,
+    )
+
+    # Ensure we return a model of the same class
+    assert isinstance(trained_model, LinearModel)
+
+    # Test reduction in loss:
+    assert mse(trained_model, D) < mse(model, D)
+
+    # Test stop_gradient on bias:
+    assert trained_model.bias.value == 1.0
+
+
 @pytest.mark.parametrize("n_data", [20])
 @pytest.mark.parametrize("verbose", [True, False])
 def test_fit_gp_regression(n_data: int, verbose: bool) -> None:
@@ -179,8 +228,7 @@ def test_fit_gp_regression(n_data: int, verbose: bool) -> None:
 
 
 @pytest.mark.parametrize("n_data", [20])
-@pytest.mark.parametrize("verbose", [True, False])
-def test_fit_scipy_gp_regression(n_data: int, verbose: bool) -> None:
+def test_fit_lbfgs_gp_regression(n_data: int) -> None:
     # Create dataset:
     key = jr.PRNGKey(123)
     x = jnp.sort(
@@ -195,20 +243,16 @@ def test_fit_scipy_gp_regression(n_data: int, verbose: bool) -> None:
 
     posterior = prior * likelihood
     # Train with BFGS!
-    trained_model_bfgs, history_bfgs = fit_scipy(
+    trained_model_bfgs, final_loss = fit_lbfgs(
         model=posterior,
         objective=conjugate_mll,
         train_data=D,
         max_iters=40,
-        verbose=verbose,
     )
 
     # Ensure the trained model is a Gaussian process posterior
     assert isinstance(trained_model_bfgs, ConjugatePosterior)
 
-    # Ensure we return a history_bfgs of the correct length
-    assert len(history_bfgs) > 2
-
     # Ensure we reduce the loss
     assert conjugate_mll(trained_model_bfgs, D) < conjugate_mll(posterior, D)
 
@@ -324,3 +368,142 @@ def test_get_batch(n_data: int, n_dim: int, batch_size: int):
     assert New.y.shape[1:] == y.shape[1:]
     assert jnp.sum(New.X == B.X) <= n_dim * batch_size / n_data
     assert jnp.sum(New.y == B.y) <= n_dim * batch_size / n_data
+
+
+@pytest.fixture
+def valid_model() -> nnx.Module:
+    """Return a valid model for testing."""
+
+    class LinearModel(nnx.Module):
+        def __init__(self, weight: float, bias: float) -> None:
+            self.weight = PositiveReal(weight)
+            self.bias = Static(bias)
+
+        def __call__(self, x: Any) -> Any:
+            return self.weight.value * x + self.bias.value
+
+    return LinearModel(weight=1.0, bias=1.0)
+
+
+@pytest.fixture
+def valid_dataset() -> Dataset:
+    """Return a valid dataset for testing."""
+    X = jnp.array([[1.0], [2.0], [3.0]])
+    y = jnp.array([[1.0], [2.0], [3.0]])
+    return Dataset(X=X, y=y)
+
+
+def test_check_model_valid(valid_model: nnx.Module) -> None:
+    """Test that a valid model passes validation."""
+    _check_model(valid_model)
+
+
+def test_check_model_invalid() -> None:
+    """Test that an invalid model raises a TypeError."""
+    model = "not a model"
+    with pytest.raises(
+        TypeError, match="Expected model to be a subclass of nnx.Module"
+    ):
+        _check_model(model)
+
+
+def test_check_train_data_valid(valid_dataset: Dataset) -> None:
+    """Test that valid training data passes validation."""
+    _check_train_data(valid_dataset)
+
+
+def test_check_train_data_invalid() -> None:
+    """Test that invalid training data raises a TypeError."""
+    train_data = "not a dataset"
+    with pytest.raises(
+        TypeError, match="Expected train_data to be of type gpjax.Dataset"
+    ):
+        _check_train_data(train_data)
+
+
+def test_check_optim_valid() -> None:
+    """Test that a valid optimiser passes validation."""
+    optim = ox.sgd(0.1)
+    _check_optim(optim)
+
+
+def test_check_optim_invalid() -> None:
+    """Test that an invalid optimiser raises a TypeError."""
+    optim = "not an optimiser"
+    with pytest.raises(
+        TypeError, match="Expected optim to be of type optax.GradientTransformation"
+    ):
+        _check_optim(optim)
+
+
+@pytest.mark.parametrize("num_iters", [1, 10, 100])
+def test_check_num_iters_valid(num_iters: int) -> None:
+    """Test that valid number of iterations passes validation."""
+    _check_num_iters(num_iters)
+
+
+def test_check_num_iters_invalid_type() -> None:
+    """Test that an invalid num_iters type raises a TypeError."""
+    num_iters = "not an int"
+    with pytest.raises(TypeError, match="Expected num_iters to be of type int"):
+        _check_num_iters(num_iters)
+
+
+@pytest.mark.parametrize("num_iters", [0, -5])
+def test_check_num_iters_invalid_value(num_iters: int) -> None:
+    """Test that an invalid num_iters value raises a ValueError."""
+    with pytest.raises(ValueError, match="Expected num_iters to be positive"):
+        _check_num_iters(num_iters)
+
+
+@pytest.mark.parametrize("log_rate", [1, 10, 100])
+def test_check_log_rate_valid(log_rate: int) -> None:
+    """Test that a valid log rate passes validation."""
+    _check_log_rate(log_rate)
+
+
+def test_check_log_rate_invalid_type() -> None:
+    """Test that an invalid log_rate type raises a TypeError."""
+    log_rate = "not an int"
+    with pytest.raises(TypeError, match="Expected log_rate to be of type int"):
+        _check_log_rate(log_rate)
+
+
+@pytest.mark.parametrize("log_rate", [0, -5])
+def test_check_log_rate_invalid_value(log_rate: int) -> None:
+    """Test that an invalid log_rate value raises a ValueError."""
+    with pytest.raises(ValueError, match="Expected log_rate to be positive"):
+        _check_log_rate(log_rate)
+
+
+@pytest.mark.parametrize("verbose", [True, False])
+def test_check_verbose_valid(verbose: bool) -> None:
+    """Test that valid verbose values pass validation."""
+    _check_verbose(verbose)
+
+
+def test_check_verbose_invalid() -> None:
+    """Test that an invalid verbose value raises a TypeError."""
+    verbose = "not a bool"
+    with pytest.raises(TypeError, match="Expected verbose to be of type bool"):
+        _check_verbose(verbose)
+
+
+@pytest.mark.parametrize("batch_size", [1, 10, 100, -1])
+def test_check_batch_size_valid(batch_size: int) -> None:
+    """Test that valid batch sizes pass validation."""
+    _check_batch_size(batch_size)
+
+
+def test_check_batch_size_invalid_type() -> None:
+    """Test that an invalid batch_size type raises a TypeError."""
+    batch_size = "not an int"
+    with pytest.raises(TypeError, match="Expected batch_size to be of type int"):
+        _check_batch_size(batch_size)
+
+
+@pytest.mark.parametrize("batch_size", [0, -2, -5])
+def test_check_batch_size_invalid_value(batch_size: int) -> None:
+    """Test that invalid batch_size values raise a ValueError."""
+    with pytest.raises(ValueError, match="Expected batch_size to be positive or -1"):
+        _check_batch_size(batch_size)
--- a/tests/test_kernels/test_nonstationary.py
+++ b/tests/test_kernels/test_nonstationary.py
@@ -31,7 +31,7 @@ from gpjax.kernels.nonstationary import (
     Polynomial,
 )
 from gpjax.parameters import (
-    PositiveReal,
+    NonNegativeReal,
     Static,
 )
 
@@ -96,8 +96,8 @@ def test_init_override_paramtype(kernel_request):
             continue
         new_params[param] = Static(value)
 
-    k = kernel(**new_params, variance=PositiveReal(variance))
-    assert isinstance(k.variance, PositiveReal)
+    k = kernel(**new_params, variance=NonNegativeReal(variance))
+    assert isinstance(k.variance, NonNegativeReal)
 
     for param in params.keys():
         if param in ("degree", "order"):
@@ -112,7 +112,7 @@ def test_init_defaults(kernel: type[AbstractKernel]):
 
     # Check that the parameters are set correctly
     assert isinstance(k.compute_engine, type(AbstractKernelComputation()))
-    assert isinstance(k.variance, PositiveReal)
+    assert isinstance(k.variance, NonNegativeReal)
 
 
 @pytest.mark.parametrize("kernel", [k[0] for k in TESTED_KERNELS])
@@ -122,7 +122,7 @@ def test_init_variances(kernel: type[AbstractKernel], variance):
     k = kernel(variance=variance)
 
     # Check that the parameters are set correctly
-    assert isinstance(k.variance, PositiveReal)
+    assert isinstance(k.variance, NonNegativeReal)
     assert jnp.allclose(k.variance.value, jnp.asarray(variance))
 
     # Check that error is raised if variance is not valid
--- a/tests/test_kernels/test_stationary.py
+++ b/tests/test_kernels/test_stationary.py
@@ -35,6 +35,7 @@ from gpjax.kernels.stationary import (
 )
 from gpjax.kernels.stationary.base import StationaryKernel
 from gpjax.parameters import (
+    NonNegativeReal,
     PositiveReal,
     Static,
 )
@@ -106,12 +107,12 @@ def test_init_override_paramtype(kernel_request):
     for param, value in params.items():
         new_params[param] = Static(value)
 
-    kwargs = {**new_params, "variance": PositiveReal(variance)}
+    kwargs = {**new_params, "variance": NonNegativeReal(variance)}
     if kernel != White:
         kwargs["lengthscale"] = PositiveReal(lengthscale)
 
     k = kernel(**kwargs)
-    assert isinstance(k.variance, PositiveReal)
+    assert isinstance(k.variance, NonNegativeReal)
 
     for param in params.keys():
         assert isinstance(getattr(k, param), Static)
@@ -124,7 +125,7 @@ def test_init_defaults(kernel: type[StationaryKernel]):
 
     # Check that the parameters are set correctly
     assert isinstance(k.compute_engine, type(AbstractKernelComputation()))
-    assert isinstance(k.variance, PositiveReal)
+    assert isinstance(k.variance, NonNegativeReal)
     assert isinstance(k.lengthscale, PositiveReal)
 
 
@@ -167,7 +168,7 @@ def test_init_variances(kernel: type[StationaryKernel], variance):
     k = kernel(variance=variance)
 
     # Check that the parameters are set correctly
-    assert isinstance(k.variance, PositiveReal)
+    assert isinstance(k.variance, NonNegativeReal)
     assert jnp.allclose(k.variance.value, jnp.asarray(variance))
 
     # Check that error is raised if variance is not valid