gpjax 0.11.0__tar.gz → 0.11.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (153)
  1. {gpjax-0.11.0 → gpjax-0.11.1}/PKG-INFO +1 -1
  2. {gpjax-0.11.0 → gpjax-0.11.1}/examples/constructing_new_kernels.py +0 -3
  3. {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/__init__.py +4 -2
  4. {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/fit.py +107 -4
  5. {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/kernels/nonstationary/arccosine.py +6 -3
  6. {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/kernels/nonstationary/linear.py +3 -3
  7. {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/kernels/nonstationary/polynomial.py +6 -3
  8. {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/kernels/stationary/base.py +6 -3
  9. {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/likelihoods.py +4 -4
  10. {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/mean_functions.py +1 -1
  11. {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/parameters.py +16 -0
  12. {gpjax-0.11.0 → gpjax-0.11.1}/tests/test_fit.py +195 -13
  13. {gpjax-0.11.0 → gpjax-0.11.1}/tests/test_kernels/test_nonstationary.py +5 -5
  14. {gpjax-0.11.0 → gpjax-0.11.1}/tests/test_kernels/test_stationary.py +5 -4
  15. {gpjax-0.11.0 → gpjax-0.11.1}/tests/test_likelihoods.py +2 -2
  16. gpjax-0.11.1/tests/test_mean_functions.py +249 -0
  17. {gpjax-0.11.0 → gpjax-0.11.1}/tests/test_numpyro_extras.py +76 -0
  18. {gpjax-0.11.0 → gpjax-0.11.1}/tests/test_parameters.py +4 -0
  19. gpjax-0.11.0/.cursorrules +0 -37
  20. gpjax-0.11.0/tests/test_mean_functions.py +0 -81
  21. {gpjax-0.11.0 → gpjax-0.11.1}/.github/CODE_OF_CONDUCT.md +0 -0
  22. {gpjax-0.11.0 → gpjax-0.11.1}/.github/ISSUE_TEMPLATE/01_BUG_REPORT.md +0 -0
  23. {gpjax-0.11.0 → gpjax-0.11.1}/.github/ISSUE_TEMPLATE/02_FEATURE_REQUEST.md +0 -0
  24. {gpjax-0.11.0 → gpjax-0.11.1}/.github/ISSUE_TEMPLATE/03_CODEBASE_IMPROVEMENT.md +0 -0
  25. {gpjax-0.11.0 → gpjax-0.11.1}/.github/ISSUE_TEMPLATE/04_DOCS_IMPROVEMENT.md +0 -0
  26. {gpjax-0.11.0 → gpjax-0.11.1}/.github/ISSUE_TEMPLATE/config.yml +0 -0
  27. {gpjax-0.11.0 → gpjax-0.11.1}/.github/codecov.yml +0 -0
  28. {gpjax-0.11.0 → gpjax-0.11.1}/.github/labels.yml +0 -0
  29. {gpjax-0.11.0 → gpjax-0.11.1}/.github/pull_request_template.md +0 -0
  30. {gpjax-0.11.0 → gpjax-0.11.1}/.github/release-drafter.yml +0 -0
  31. {gpjax-0.11.0 → gpjax-0.11.1}/.github/workflows/build_docs.yml +0 -0
  32. {gpjax-0.11.0 → gpjax-0.11.1}/.github/workflows/integration.yml +0 -0
  33. {gpjax-0.11.0 → gpjax-0.11.1}/.github/workflows/pr_greeting.yml +0 -0
  34. {gpjax-0.11.0 → gpjax-0.11.1}/.github/workflows/ruff.yml +0 -0
  35. {gpjax-0.11.0 → gpjax-0.11.1}/.github/workflows/stale_prs.yml +0 -0
  36. {gpjax-0.11.0 → gpjax-0.11.1}/.github/workflows/test_docs.yml +0 -0
  37. {gpjax-0.11.0 → gpjax-0.11.1}/.github/workflows/tests.yml +0 -0
  38. {gpjax-0.11.0 → gpjax-0.11.1}/.gitignore +0 -0
  39. {gpjax-0.11.0 → gpjax-0.11.1}/CITATION.bib +0 -0
  40. {gpjax-0.11.0 → gpjax-0.11.1}/LICENSE.txt +0 -0
  41. {gpjax-0.11.0 → gpjax-0.11.1}/Makefile +0 -0
  42. {gpjax-0.11.0 → gpjax-0.11.1}/README.md +0 -0
  43. {gpjax-0.11.0 → gpjax-0.11.1}/docs/CODE_OF_CONDUCT.md +0 -0
  44. {gpjax-0.11.0 → gpjax-0.11.1}/docs/GOVERNANCE.md +0 -0
  45. {gpjax-0.11.0 → gpjax-0.11.1}/docs/contributing.md +0 -0
  46. {gpjax-0.11.0 → gpjax-0.11.1}/docs/design.md +0 -0
  47. {gpjax-0.11.0 → gpjax-0.11.1}/docs/index.md +0 -0
  48. {gpjax-0.11.0 → gpjax-0.11.1}/docs/index.rst +0 -0
  49. {gpjax-0.11.0 → gpjax-0.11.1}/docs/installation.md +0 -0
  50. {gpjax-0.11.0 → gpjax-0.11.1}/docs/javascripts/katex.js +0 -0
  51. {gpjax-0.11.0 → gpjax-0.11.1}/docs/refs.bib +0 -0
  52. {gpjax-0.11.0 → gpjax-0.11.1}/docs/scripts/gen_examples.py +0 -0
  53. {gpjax-0.11.0 → gpjax-0.11.1}/docs/scripts/gen_pages.py +0 -0
  54. {gpjax-0.11.0 → gpjax-0.11.1}/docs/scripts/notebook_converter.py +0 -0
  55. {gpjax-0.11.0 → gpjax-0.11.1}/docs/scripts/sharp_bits_figure.py +0 -0
  56. {gpjax-0.11.0 → gpjax-0.11.1}/docs/sharp_bits.md +0 -0
  57. {gpjax-0.11.0 → gpjax-0.11.1}/docs/static/GP.pdf +0 -0
  58. {gpjax-0.11.0 → gpjax-0.11.1}/docs/static/GP.svg +0 -0
  59. {gpjax-0.11.0 → gpjax-0.11.1}/docs/static/bijector_figure.svg +0 -0
  60. {gpjax-0.11.0 → gpjax-0.11.1}/docs/static/css/gpjax_theme.css +0 -0
  61. {gpjax-0.11.0 → gpjax-0.11.1}/docs/static/favicon.ico +0 -0
  62. {gpjax-0.11.0 → gpjax-0.11.1}/docs/static/gpjax.mplstyle +0 -0
  63. {gpjax-0.11.0 → gpjax-0.11.1}/docs/static/gpjax_logo.pdf +0 -0
  64. {gpjax-0.11.0 → gpjax-0.11.1}/docs/static/gpjax_logo.svg +0 -0
  65. {gpjax-0.11.0 → gpjax-0.11.1}/docs/static/jaxkern/lato.ttf +0 -0
  66. {gpjax-0.11.0 → gpjax-0.11.1}/docs/static/jaxkern/logo.png +0 -0
  67. {gpjax-0.11.0 → gpjax-0.11.1}/docs/static/jaxkern/logo.svg +0 -0
  68. {gpjax-0.11.0 → gpjax-0.11.1}/docs/static/jaxkern/main.py +0 -0
  69. {gpjax-0.11.0 → gpjax-0.11.1}/docs/static/step_size_figure.png +0 -0
  70. {gpjax-0.11.0 → gpjax-0.11.1}/docs/static/step_size_figure.svg +0 -0
  71. {gpjax-0.11.0 → gpjax-0.11.1}/docs/stylesheets/extra.css +0 -0
  72. {gpjax-0.11.0 → gpjax-0.11.1}/docs/stylesheets/permalinks.css +0 -0
  73. {gpjax-0.11.0 → gpjax-0.11.1}/examples/backend.py +0 -0
  74. {gpjax-0.11.0 → gpjax-0.11.1}/examples/barycentres/barycentre_gp.gif +0 -0
  75. {gpjax-0.11.0 → gpjax-0.11.1}/examples/barycentres.py +0 -0
  76. {gpjax-0.11.0 → gpjax-0.11.1}/examples/classification.py +0 -0
  77. {gpjax-0.11.0 → gpjax-0.11.1}/examples/collapsed_vi.py +0 -0
  78. {gpjax-0.11.0 → gpjax-0.11.1}/examples/data/max_tempeature_switzerland.csv +0 -0
  79. {gpjax-0.11.0 → gpjax-0.11.1}/examples/data/yacht_hydrodynamics.data +0 -0
  80. {gpjax-0.11.0 → gpjax-0.11.1}/examples/deep_kernels.py +0 -0
  81. {gpjax-0.11.0 → gpjax-0.11.1}/examples/gpjax.mplstyle +0 -0
  82. {gpjax-0.11.0 → gpjax-0.11.1}/examples/graph_kernels.py +0 -0
  83. {gpjax-0.11.0 → gpjax-0.11.1}/examples/intro_to_gps/decomposed_mll.png +0 -0
  84. {gpjax-0.11.0 → gpjax-0.11.1}/examples/intro_to_gps/generating_process.png +0 -0
  85. {gpjax-0.11.0 → gpjax-0.11.1}/examples/intro_to_gps.py +0 -0
  86. {gpjax-0.11.0 → gpjax-0.11.1}/examples/intro_to_kernels.py +0 -0
  87. {gpjax-0.11.0 → gpjax-0.11.1}/examples/likelihoods_guide.py +0 -0
  88. {gpjax-0.11.0 → gpjax-0.11.1}/examples/oceanmodelling.py +0 -0
  89. {gpjax-0.11.0 → gpjax-0.11.1}/examples/poisson.py +0 -0
  90. {gpjax-0.11.0 → gpjax-0.11.1}/examples/regression.py +0 -0
  91. {gpjax-0.11.0 → gpjax-0.11.1}/examples/uncollapsed_vi.py +0 -0
  92. {gpjax-0.11.0 → gpjax-0.11.1}/examples/utils.py +0 -0
  93. {gpjax-0.11.0 → gpjax-0.11.1}/examples/yacht.py +0 -0
  94. {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/citation.py +0 -0
  95. {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/dataset.py +0 -0
  96. {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/distributions.py +0 -0
  97. {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/gps.py +0 -0
  98. {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/integrators.py +0 -0
  99. {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/kernels/__init__.py +0 -0
  100. {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/kernels/approximations/__init__.py +0 -0
  101. {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/kernels/approximations/rff.py +0 -0
  102. {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/kernels/base.py +0 -0
  103. {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/kernels/computations/__init__.py +0 -0
  104. {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/kernels/computations/base.py +0 -0
  105. {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/kernels/computations/basis_functions.py +0 -0
  106. {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/kernels/computations/constant_diagonal.py +0 -0
  107. {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/kernels/computations/dense.py +0 -0
  108. {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/kernels/computations/diagonal.py +0 -0
  109. {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/kernels/computations/eigen.py +0 -0
  110. {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/kernels/non_euclidean/__init__.py +0 -0
  111. {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/kernels/non_euclidean/graph.py +0 -0
  112. {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/kernels/non_euclidean/utils.py +0 -0
  113. {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/kernels/nonstationary/__init__.py +0 -0
  114. {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/kernels/stationary/__init__.py +0 -0
  115. {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/kernels/stationary/matern12.py +0 -0
  116. {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/kernels/stationary/matern32.py +0 -0
  117. {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/kernels/stationary/matern52.py +0 -0
  118. {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/kernels/stationary/periodic.py +0 -0
  119. {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/kernels/stationary/powered_exponential.py +0 -0
  120. {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/kernels/stationary/rational_quadratic.py +0 -0
  121. {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/kernels/stationary/rbf.py +0 -0
  122. {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/kernels/stationary/utils.py +0 -0
  123. {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/kernels/stationary/white.py +0 -0
  124. {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/lower_cholesky.py +0 -0
  125. {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/numpyro_extras.py +0 -0
  126. {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/objectives.py +0 -0
  127. {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/scan.py +0 -0
  128. {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/typing.py +0 -0
  129. {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/variational_families.py +0 -0
  130. {gpjax-0.11.0 → gpjax-0.11.1}/mkdocs.yml +0 -0
  131. {gpjax-0.11.0 → gpjax-0.11.1}/pyproject.toml +0 -0
  132. {gpjax-0.11.0 → gpjax-0.11.1}/static/CONTRIBUTING.md +0 -0
  133. {gpjax-0.11.0 → gpjax-0.11.1}/static/paper.bib +0 -0
  134. {gpjax-0.11.0 → gpjax-0.11.1}/static/paper.md +0 -0
  135. {gpjax-0.11.0 → gpjax-0.11.1}/static/paper.pdf +0 -0
  136. {gpjax-0.11.0 → gpjax-0.11.1}/tests/__init__.py +0 -0
  137. {gpjax-0.11.0 → gpjax-0.11.1}/tests/conftest.py +0 -0
  138. {gpjax-0.11.0 → gpjax-0.11.1}/tests/integration_tests.py +0 -0
  139. {gpjax-0.11.0 → gpjax-0.11.1}/tests/test_citations.py +0 -0
  140. {gpjax-0.11.0 → gpjax-0.11.1}/tests/test_dataset.py +0 -0
  141. {gpjax-0.11.0 → gpjax-0.11.1}/tests/test_gaussian_distribution.py +0 -0
  142. {gpjax-0.11.0 → gpjax-0.11.1}/tests/test_gps.py +0 -0
  143. {gpjax-0.11.0 → gpjax-0.11.1}/tests/test_integrators.py +0 -0
  144. {gpjax-0.11.0 → gpjax-0.11.1}/tests/test_kernels/__init__.py +0 -0
  145. {gpjax-0.11.0 → gpjax-0.11.1}/tests/test_kernels/test_approximations.py +0 -0
  146. {gpjax-0.11.0 → gpjax-0.11.1}/tests/test_kernels/test_base.py +0 -0
  147. {gpjax-0.11.0 → gpjax-0.11.1}/tests/test_kernels/test_computation.py +0 -0
  148. {gpjax-0.11.0 → gpjax-0.11.1}/tests/test_kernels/test_non_euclidean.py +0 -0
  149. {gpjax-0.11.0 → gpjax-0.11.1}/tests/test_kernels/test_utils.py +0 -0
  150. {gpjax-0.11.0 → gpjax-0.11.1}/tests/test_lower_cholesky.py +0 -0
  151. {gpjax-0.11.0 → gpjax-0.11.1}/tests/test_markdown.py +0 -0
  152. {gpjax-0.11.0 → gpjax-0.11.1}/tests/test_objectives.py +0 -0
  153. {gpjax-0.11.0 → gpjax-0.11.1}/tests/test_variational_families.py +0 -0
{gpjax-0.11.0 → gpjax-0.11.1}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: gpjax
-Version: 0.11.0
+Version: 0.11.1
 Summary: Gaussian processes in JAX.
 Project-URL: Documentation, https://docs.jaxgaussianprocesses.com/
 Project-URL: Issues, https://github.com/JaxGaussianProcesses/GPJax/issues
{gpjax-0.11.0 → gpjax-0.11.1}/examples/constructing_new_kernels.py

@@ -33,7 +33,6 @@ from jaxtyping import (
     install_import_hook,
 )
 import matplotlib.pyplot as plt
-import numpyro.distributions as npd
 from numpyro.distributions import constraints
 import numpyro.distributions.transforms as npt
 
@@ -52,8 +51,6 @@ with install_import_hook("gpjax", "beartype.beartype"):
     import gpjax as gpx
 
 
-tfb = tfp.bijectors
-
 # set the default style for plotting
 use_mpl_style()
 
{gpjax-0.11.0 → gpjax-0.11.1}/gpjax/__init__.py

@@ -32,14 +32,15 @@ from gpjax.citation import cite
 from gpjax.dataset import Dataset
 from gpjax.fit import (
     fit,
+    fit_lbfgs,
     fit_scipy,
 )
 
 __license__ = "MIT"
-__description__ = "Didactic Gaussian processes in JAX"
+__description__ = "Gaussian processes in JAX and Flax"
 __url__ = "https://github.com/JaxGaussianProcesses/GPJax"
 __contributors__ = "https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors"
-__version__ = "0.11.0"
+__version__ = "0.11.1"
 
 __all__ = [
     "base",
@@ -56,5 +57,6 @@ __all__ = [
     "fit",
     "Module",
     "param_field",
+    "fit_lbfgs",
     "fit_scipy",
 ]
{gpjax-0.11.0 → gpjax-0.11.1}/gpjax/fit.py

@@ -15,13 +15,13 @@
 
 import typing as tp
 
-from flax import nnx
 import jax
-from jax.flatten_util import ravel_pytree
 import jax.numpy as jnp
 import jax.random as jr
-from numpyro.distributions.transforms import Transform
 import optax as ox
+from flax import nnx
+from jax.flatten_util import ravel_pytree
+from numpyro.distributions.transforms import Transform
 from scipy.optimize import minimize
 
 from gpjax.dataset import Dataset
@@ -127,7 +127,6 @@ def fit(  # noqa: PLR0913
     _check_verbose(verbose)
 
     # Model state filtering
-
     graphdef, params, *static_state = nnx.split(model, Parameter, ...)
 
     # Parameters bijection to unconstrained space
@@ -253,6 +252,110 @@ def fit_scipy(  # noqa: PLR0913
     return model, history
 
 
+def fit_lbfgs(
+    *,
+    model: Model,
+    objective: Objective,
+    train_data: Dataset,
+    params_bijection: tp.Union[dict[Parameter, Transform], None] = DEFAULT_BIJECTION,
+    max_iters: int = 100,
+    safe: bool = True,
+    max_linesearch_steps: int = 32,
+    gtol: float = 1e-5,
+) -> tuple[Model, jax.Array]:
+    r"""Train a Module model with respect to a supplied Objective function.
+
+    Uses Optax's L-BFGS implementation and a jax.lax.while_loop.
+
+    Args:
+        model: The model Module to be optimised.
+        objective: The objective function that we are optimising with
+            respect to.
+        train_data (Dataset): The training data to be used for the optimisation.
+        max_iters (int): The maximum number of optimisation steps to run. Defaults
+            to 100.
+        safe (bool): Whether to check the types of the inputs.
+        max_linesearch_steps (int): The maximum number of linesearch steps to use
+            for finding the step size.
+        gtol (float): Terminate the optimisation if the L2 norm of the gradient is
+            below this threshold.
+
+    Returns:
+        A tuple comprising the optimised model and the final loss.
+    """
+    if safe:
+        # Check inputs
+        _check_model(model)
+        _check_train_data(train_data)
+        _check_num_iters(max_iters)
+
+    # Model state filtering
+    graphdef, params, *static_state = nnx.split(model, Parameter, ...)
+
+    # Parameters bijection to unconstrained space
+    if params_bijection is not None:
+        params = transform(params, params_bijection, inverse=True)
+
+    # Loss definition
+    def loss(params: nnx.State) -> ScalarFloat:
+        params = transform(params, params_bijection)
+        model = nnx.merge(graphdef, params, *static_state)
+        return objective(model, train_data)
+
+    # Initialise optimiser
+    optim = ox.lbfgs(
+        linesearch=ox.scale_by_zoom_linesearch(
+            max_linesearch_steps=max_linesearch_steps,
+            initial_guess_strategy="one",
+        )
+    )
+    opt_state = optim.init(params)
+    loss_value_and_grad = ox.value_and_grad_from_state(loss)
+
+    # Optimisation step.
+    def step(carry):
+        params, opt_state = carry
+
+        # Using optax's value_and_grad_from_state is more efficient since L-BFGS uses a linesearch.
+        # See https://optax.readthedocs.io/en/latest/api/utilities.html#optax.value_and_grad_from_state
+        loss_val, loss_gradient = loss_value_and_grad(params, state=opt_state)
+        updates, opt_state = optim.update(
+            loss_gradient,
+            opt_state,
+            params,
+            value=loss_val,
+            grad=loss_gradient,
+            value_fn=loss,
+        )
+        params = ox.apply_updates(params, updates)
+
+        return params, opt_state
+
+    def continue_fn(carry):
+        _, opt_state = carry
+        n = ox.tree_utils.tree_get(opt_state, "count")
+        g = ox.tree_utils.tree_get(opt_state, "grad")
+        g_l2_norm = ox.tree_utils.tree_l2_norm(g)
+        return (n == 0) | ((n < max_iters) & (g_l2_norm >= gtol))
+
+    # Optimisation loop
+    params, opt_state = jax.lax.while_loop(
+        continue_fn,
+        step,
+        (params, opt_state),
+    )
+    final_loss = ox.tree_utils.tree_get(opt_state, "value")
+
+    # Parameters bijection to constrained space
+    if params_bijection is not None:
+        params = transform(params, params_bijection)
+
+    # Reconstruct model
+    model = nnx.merge(graphdef, params, *static_state)
+
+    return model, final_loss
+
+
 def get_batch(train_data: Dataset, batch_size: int, key: KeyArray) -> Dataset:
     """Batch the data into mini-batches. Sampling is done with replacement.
 
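For orientation, here is a minimal sketch of how the new fit_lbfgs entry point might be used, based on the signature and docstring added above. The toy dataset and model are illustrative, not part of the diff, and the import paths (gpx.gps.Prior, gpx.objectives.conjugate_mll, etc.) are assumed from GPJax's documented public API:

    import gpjax as gpx
    import jax.numpy as jnp
    import jax.random as jr

    # Toy 1D regression data (illustrative only).
    key = jr.PRNGKey(0)
    x = jnp.sort(jr.uniform(key, (50, 1), minval=-3.0, maxval=3.0), axis=0)
    y = jnp.sin(x) + 0.1 * jr.normal(key, x.shape)
    D = gpx.Dataset(X=x, y=y)

    # A standard conjugate posterior, as in the GPJax regression example.
    prior = gpx.gps.Prior(
        mean_function=gpx.mean_functions.Zero(), kernel=gpx.kernels.RBF()
    )
    posterior = prior * gpx.likelihoods.Gaussian(num_datapoints=D.n)

    # Unlike fit/fit_scipy, fit_lbfgs returns the final loss rather than a
    # loss history, and stops early once the gradient norm drops below gtol.
    opt_posterior, final_loss = gpx.fit_lbfgs(
        model=posterior,
        objective=gpx.objectives.conjugate_mll,
        train_data=D,
        max_iters=100,
    )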
@@ -23,7 +23,10 @@ from gpjax.kernels.computations import (
23
23
  AbstractKernelComputation,
24
24
  DenseKernelComputation,
25
25
  )
26
- from gpjax.parameters import PositiveReal
26
+ from gpjax.parameters import (
27
+ NonNegativeReal,
28
+ PositiveReal,
29
+ )
27
30
  from gpjax.typing import (
28
31
  Array,
29
32
  ScalarArray,
@@ -91,9 +94,9 @@ class ArcCosine(AbstractKernel):
91
94
  if isinstance(variance, nnx.Variable):
92
95
  self.variance = variance
93
96
  else:
94
- self.variance = PositiveReal(variance)
97
+ self.variance = NonNegativeReal(variance)
95
98
  if tp.TYPE_CHECKING:
96
- self.variance = tp.cast(PositiveReal[ScalarArray], self.variance)
99
+ self.variance = tp.cast(NonNegativeReal[ScalarArray], self.variance)
97
100
 
98
101
  if isinstance(bias_variance, nnx.Variable):
99
102
  self.bias_variance = bias_variance
@@ -23,7 +23,7 @@ from gpjax.kernels.computations import (
23
23
  AbstractKernelComputation,
24
24
  DenseKernelComputation,
25
25
  )
26
- from gpjax.parameters import PositiveReal
26
+ from gpjax.parameters import NonNegativeReal
27
27
  from gpjax.typing import (
28
28
  Array,
29
29
  ScalarArray,
@@ -64,9 +64,9 @@ class Linear(AbstractKernel):
64
64
  if isinstance(variance, nnx.Variable):
65
65
  self.variance = variance
66
66
  else:
67
- self.variance = PositiveReal(variance)
67
+ self.variance = NonNegativeReal(variance)
68
68
  if tp.TYPE_CHECKING:
69
- self.variance = tp.cast(PositiveReal[ScalarArray], self.variance)
69
+ self.variance = tp.cast(NonNegativeReal[ScalarArray], self.variance)
70
70
 
71
71
  def __call__(
72
72
  self,
@@ -23,7 +23,10 @@ from gpjax.kernels.computations import (
23
23
  AbstractKernelComputation,
24
24
  DenseKernelComputation,
25
25
  )
26
- from gpjax.parameters import PositiveReal
26
+ from gpjax.parameters import (
27
+ NonNegativeReal,
28
+ PositiveReal,
29
+ )
27
30
  from gpjax.typing import (
28
31
  Array,
29
32
  ScalarArray,
@@ -76,9 +79,9 @@ class Polynomial(AbstractKernel):
76
79
  if isinstance(variance, nnx.Variable):
77
80
  self.variance = variance
78
81
  else:
79
- self.variance = PositiveReal(variance)
82
+ self.variance = NonNegativeReal(variance)
80
83
  if tp.TYPE_CHECKING:
81
- self.variance = tp.cast(PositiveReal[ScalarArray], self.variance)
84
+ self.variance = tp.cast(NonNegativeReal[ScalarArray], self.variance)
82
85
 
83
86
  self.name = f"Polynomial (degree {self.degree})"
84
87
 
@@ -25,7 +25,10 @@ from gpjax.kernels.computations import (
25
25
  AbstractKernelComputation,
26
26
  DenseKernelComputation,
27
27
  )
28
- from gpjax.parameters import PositiveReal
28
+ from gpjax.parameters import (
29
+ NonNegativeReal,
30
+ PositiveReal,
31
+ )
29
32
  from gpjax.typing import (
30
33
  Array,
31
34
  ScalarArray,
@@ -85,11 +88,11 @@ class StationaryKernel(AbstractKernel):
85
88
  if isinstance(variance, nnx.Variable):
86
89
  self.variance = variance
87
90
  else:
88
- self.variance = PositiveReal(variance)
91
+ self.variance = NonNegativeReal(variance)
89
92
 
90
93
  # static typing
91
94
  if tp.TYPE_CHECKING:
92
- self.variance = tp.cast(PositiveReal[ScalarFloat], self.variance)
95
+ self.variance = tp.cast(NonNegativeReal[ScalarFloat], self.variance)
93
96
 
94
97
  @property
95
98
  def spectral_density(self) -> npd.Normal | npd.StudentT:
@@ -28,7 +28,7 @@ from gpjax.integrators import (
28
28
  GHQuadratureIntegrator,
29
29
  )
30
30
  from gpjax.parameters import (
31
- PositiveReal,
31
+ NonNegativeReal,
32
32
  Static,
33
33
  )
34
34
  from gpjax.typing import (
@@ -134,7 +134,7 @@ class Gaussian(AbstractLikelihood):
134
134
  self,
135
135
  num_datapoints: int,
136
136
  obs_stddev: tp.Union[
137
- ScalarFloat, Float[Array, "#N"], PositiveReal, Static
137
+ ScalarFloat, Float[Array, "#N"], NonNegativeReal, Static
138
138
  ] = 1.0,
139
139
  integrator: AbstractIntegrator = AnalyticalGaussianIntegrator(),
140
140
  ):
@@ -148,8 +148,8 @@ class Gaussian(AbstractLikelihood):
148
148
  likelihoods. Must be an instance of `AbstractIntegrator`. For the Gaussian likelihood, this defaults to
149
149
  the `AnalyticalGaussianIntegrator`, as the expected log likelihood can be computed analytically.
150
150
  """
151
- if not isinstance(obs_stddev, (PositiveReal, Static)):
152
- obs_stddev = PositiveReal(jnp.asarray(obs_stddev))
151
+ if not isinstance(obs_stddev, (NonNegativeReal, Static)):
152
+ obs_stddev = NonNegativeReal(jnp.asarray(obs_stddev))
153
153
  self.obs_stddev = obs_stddev
154
154
 
155
155
  super().__init__(num_datapoints, integrator)
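A practical consequence of this relaxation, sketched below, is that a Gaussian likelihood can now plausibly be constructed with zero observation noise (the NonNegativeReal parameter itself is defined in the gpjax/parameters.py hunks further down); the exact value here is illustrative:

    import gpjax as gpx

    # obs_stddev=0.0 would have been rejected by the old PositiveReal
    # validation; NonNegativeReal accepts it, permitting noise-free models.
    likelihood = gpx.likelihoods.Gaussian(num_datapoints=100, obs_stddev=0.0)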
{gpjax-0.11.0 → gpjax-0.11.1}/gpjax/mean_functions.py

@@ -207,5 +207,5 @@ SumMeanFunction = ft.partial(
     CombinationMeanFunction, operator=ft.partial(jnp.sum, axis=0)
 )
 ProductMeanFunction = ft.partial(
-    CombinationMeanFunction, operator=ft.partial(jnp.sum, axis=0)
+    CombinationMeanFunction, operator=ft.partial(jnp.prod, axis=0)
 )
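This one-line change is a genuine bug fix: ProductMeanFunction previously summed its constituent means, behaving identically to SumMeanFunction, and now takes their product. A small numerical sketch of the difference (the array is illustrative):

    import jax.numpy as jnp

    # Stacked outputs of two constituent mean functions at one input.
    means = jnp.array([[2.0], [3.0]])
    jnp.sum(means, axis=0)   # [5.0] -- old (incorrect) behaviour
    jnp.prod(means, axis=0)  # [6.0] -- new behaviour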
{gpjax-0.11.0 → gpjax-0.11.1}/gpjax/parameters.py

@@ -82,6 +82,14 @@ class Parameter(nnx.Variable[T]):
         self._tag = tag
 
 
+class NonNegativeReal(Parameter[T]):
+    """Parameter that is non-negative."""
+
+    def __init__(self, value: T, tag: ParameterTag = "non_negative", **kwargs):
+        super().__init__(value=value, tag=tag, **kwargs)
+        _safe_assert(_check_is_non_negative, self.value)
+
+
 class PositiveReal(Parameter[T]):
     """Parameter that is strictly positive."""
 
@@ -143,6 +151,7 @@ class LowerTriangular(Parameter[T]):
 
 DEFAULT_BIJECTION = {
     "positive": npt.SoftplusTransform(),
+    "non_negative": npt.SoftplusTransform(),
     "real": npt.IdentityTransform(),
     "sigmoid": npt.SigmoidTransform(),
     "lower_triangular": FillTriangularTransform(),
@@ -164,6 +173,13 @@ def _check_is_arraylike(value: T) -> None:
     )
 
 
+@checkify.checkify
+def _check_is_non_negative(value):
+    checkify.check(
+        jnp.all(value >= 0), "value needs to be non-negative, got {value}", value=value
+    )
+
+
 @checkify.checkify
 def _check_is_positive(value):
     checkify.check(
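Taken together, the three hunks above define the NonNegativeReal parameter that underpins this release: construction-time validation via checkify (value >= 0) and the same softplus bijection as "positive" parameters for unconstrained optimisation. A minimal sketch of the relaxed constraint, assuming the public gpjax.parameters import path:

    import jax.numpy as jnp
    from gpjax.parameters import NonNegativeReal, PositiveReal

    NonNegativeReal(jnp.array(0.0))  # accepted: a zero variance is now valid
    PositiveReal(jnp.array(0.0))     # fails validation: must be strictly positive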
{gpjax-0.11.0 → gpjax-0.11.1}/tests/test_fit.py

@@ -13,20 +13,24 @@
 # limitations under the License.
 # ==============================================================================
 
-from flax import nnx
 import jax.numpy as jnp
 import jax.random as jr
-from jaxtyping import (
-    Float,
-    Num,
-)
 import optax as ox
 import pytest
 import scipy
-
+from beartype.typing import Any
+from flax import nnx
 from gpjax.dataset import Dataset
 from gpjax.fit import (
+    _check_batch_size,
+    _check_log_rate,
+    _check_model,
+    _check_num_iters,
+    _check_optim,
+    _check_train_data,
+    _check_verbose,
     fit,
+    fit_lbfgs,
     fit_scipy,
     get_batch,
 )
@@ -50,6 +54,10 @@ from gpjax.parameters import (
 )
 from gpjax.typing import Array
 from gpjax.variational_families import VariationalGaussian
+from jaxtyping import (
+    Float,
+    Num,
+)
 
 
 def test_fit_simple() -> None:
@@ -141,6 +149,46 @@ def test_fit_scipy_simple():
     assert trained_model.bias.value == 1.0
 
 
+def test_fit_lbfgs_simple():
+    # Create dataset:
+    X = jnp.linspace(0.0, 10.0, 100).reshape(-1, 1)
+    y = 2.0 * X + 1.0 + 10 * jr.normal(jr.PRNGKey(0), X.shape).reshape(-1, 1)
+    D = Dataset(X, y)
+
+    # Define linear model:
+    class LinearModel(nnx.Module):
+        def __init__(self, weight: float, bias: float):
+            self.weight = PositiveReal(weight)
+            self.bias = Static(bias)
+
+        def __call__(self, x):
+            return self.weight.value * x + self.bias.value
+
+    model = LinearModel(weight=1.0, bias=1.0)
+
+    # Define loss function:
+    def mse(model, data):
+        pred = model(data.X)
+        return jnp.mean((pred - data.y) ** 2)
+
+    # Train with L-BFGS!
+    trained_model, final_loss = fit_lbfgs(
+        model=model,
+        objective=mse,
+        train_data=D,
+        max_iters=10,
+    )
+
+    # Ensure we return a model of the same class
+    assert isinstance(trained_model, LinearModel)
+
+    # Test reduction in loss:
+    assert mse(trained_model, D) < mse(model, D)
+
+    # Test stop_gradient on bias:
+    assert trained_model.bias.value == 1.0
+
+
 @pytest.mark.parametrize("n_data", [20])
 @pytest.mark.parametrize("verbose", [True, False])
 def test_fit_gp_regression(n_data: int, verbose: bool) -> None:
@@ -179,8 +227,7 @@ def test_fit_gp_regression(n_data: int, verbose: bool) -> None:
 
 
 @pytest.mark.parametrize("n_data", [20])
-@pytest.mark.parametrize("verbose", [True, False])
-def test_fit_scipy_gp_regression(n_data: int, verbose: bool) -> None:
+def test_fit_lbfgs_gp_regression(n_data: int) -> None:
     # Create dataset:
     key = jr.PRNGKey(123)
     x = jnp.sort(
@@ -195,20 +242,16 @@ def test_fit_scipy_gp_regression(n_data: int, verbose: bool) -> None:
     posterior = prior * likelihood
 
     # Train with BFGS!
-    trained_model_bfgs, history_bfgs = fit_scipy(
+    trained_model_bfgs, final_loss = fit_lbfgs(
         model=posterior,
         objective=conjugate_mll,
         train_data=D,
         max_iters=40,
-        verbose=verbose,
     )
 
     # Ensure the trained model is a Gaussian process posterior
     assert isinstance(trained_model_bfgs, ConjugatePosterior)
 
-    # Ensure we return a history_bfgs of the correct length
-    assert len(history_bfgs) > 2
-
     # Ensure we reduce the loss
     assert conjugate_mll(trained_model_bfgs, D) < conjugate_mll(posterior, D)
 
@@ -324,3 +367,142 @@ def test_get_batch(n_data: int, n_dim: int, batch_size: int):
     assert New.y.shape[1:] == y.shape[1:]
     assert jnp.sum(New.X == B.X) <= n_dim * batch_size / n_data
     assert jnp.sum(New.y == B.y) <= n_dim * batch_size / n_data
+
+
+@pytest.fixture
+def valid_model() -> nnx.Module:
+    """Return a valid model for testing."""
+
+    class LinearModel(nnx.Module):
+        def __init__(self, weight: float, bias: float) -> None:
+            self.weight = PositiveReal(weight)
+            self.bias = Static(bias)
+
+        def __call__(self, x: Any) -> Any:
+            return self.weight.value * x + self.bias.value
+
+    return LinearModel(weight=1.0, bias=1.0)
+
+
+@pytest.fixture
+def valid_dataset() -> Dataset:
+    """Return a valid dataset for testing."""
+    X = jnp.array([[1.0], [2.0], [3.0]])
+    y = jnp.array([[1.0], [2.0], [3.0]])
+    return Dataset(X=X, y=y)
+
+
+def test_check_model_valid(valid_model: nnx.Module) -> None:
+    """Test that a valid model passes validation."""
+    _check_model(valid_model)
+
+
+def test_check_model_invalid() -> None:
+    """Test that an invalid model raises a TypeError."""
+    model = "not a model"
+    with pytest.raises(
+        TypeError, match="Expected model to be a subclass of nnx.Module"
+    ):
+        _check_model(model)
+
+
+def test_check_train_data_valid(valid_dataset: Dataset) -> None:
+    """Test that valid training data passes validation."""
+    _check_train_data(valid_dataset)
+
+
+def test_check_train_data_invalid() -> None:
+    """Test that invalid training data raises a TypeError."""
+    train_data = "not a dataset"
+    with pytest.raises(
+        TypeError, match="Expected train_data to be of type gpjax.Dataset"
+    ):
+        _check_train_data(train_data)
+
+
+def test_check_optim_valid() -> None:
+    """Test that a valid optimiser passes validation."""
+    optim = ox.sgd(0.1)
+    _check_optim(optim)
+
+
+def test_check_optim_invalid() -> None:
+    """Test that an invalid optimiser raises a TypeError."""
+    optim = "not an optimiser"
+    with pytest.raises(
+        TypeError, match="Expected optim to be of type optax.GradientTransformation"
+    ):
+        _check_optim(optim)
+
+
+@pytest.mark.parametrize("num_iters", [1, 10, 100])
+def test_check_num_iters_valid(num_iters: int) -> None:
+    """Test that a valid number of iterations passes validation."""
+    _check_num_iters(num_iters)
+
+
+def test_check_num_iters_invalid_type() -> None:
+    """Test that an invalid num_iters type raises a TypeError."""
+    num_iters = "not an int"
+    with pytest.raises(TypeError, match="Expected num_iters to be of type int"):
+        _check_num_iters(num_iters)
+
+
+@pytest.mark.parametrize("num_iters", [0, -5])
+def test_check_num_iters_invalid_value(num_iters: int) -> None:
+    """Test that an invalid num_iters value raises a ValueError."""
+    with pytest.raises(ValueError, match="Expected num_iters to be positive"):
+        _check_num_iters(num_iters)
+
+
+@pytest.mark.parametrize("log_rate", [1, 10, 100])
+def test_check_log_rate_valid(log_rate: int) -> None:
+    """Test that a valid log rate passes validation."""
+    _check_log_rate(log_rate)
+
+
+def test_check_log_rate_invalid_type() -> None:
+    """Test that an invalid log_rate type raises a TypeError."""
+    log_rate = "not an int"
+    with pytest.raises(TypeError, match="Expected log_rate to be of type int"):
+        _check_log_rate(log_rate)
+
+
+@pytest.mark.parametrize("log_rate", [0, -5])
+def test_check_log_rate_invalid_value(log_rate: int) -> None:
+    """Test that an invalid log_rate value raises a ValueError."""
+    with pytest.raises(ValueError, match="Expected log_rate to be positive"):
+        _check_log_rate(log_rate)
+
+
+@pytest.mark.parametrize("verbose", [True, False])
+def test_check_verbose_valid(verbose: bool) -> None:
+    """Test that valid verbose values pass validation."""
+    _check_verbose(verbose)
+
+
+def test_check_verbose_invalid() -> None:
+    """Test that an invalid verbose value raises a TypeError."""
+    verbose = "not a bool"
+    with pytest.raises(TypeError, match="Expected verbose to be of type bool"):
+        _check_verbose(verbose)
+
+
+@pytest.mark.parametrize("batch_size", [1, 10, 100, -1])
+def test_check_batch_size_valid(batch_size: int) -> None:
+    """Test that valid batch sizes pass validation."""
+    _check_batch_size(batch_size)
+
+
+def test_check_batch_size_invalid_type() -> None:
+    """Test that an invalid batch_size type raises a TypeError."""
+    batch_size = "not an int"
+    with pytest.raises(TypeError, match="Expected batch_size to be of type int"):
+        _check_batch_size(batch_size)
+
+
+@pytest.mark.parametrize("batch_size", [0, -2, -5])
+def test_check_batch_size_invalid_value(batch_size: int) -> None:
+    """Test that invalid batch_size values raise a ValueError."""
+    with pytest.raises(ValueError, match="Expected batch_size to be positive or -1"):
+        _check_batch_size(batch_size)
@@ -31,7 +31,7 @@ from gpjax.kernels.nonstationary import (
31
31
  Polynomial,
32
32
  )
33
33
  from gpjax.parameters import (
34
- PositiveReal,
34
+ NonNegativeReal,
35
35
  Static,
36
36
  )
37
37
 
@@ -96,8 +96,8 @@ def test_init_override_paramtype(kernel_request):
96
96
  continue
97
97
  new_params[param] = Static(value)
98
98
 
99
- k = kernel(**new_params, variance=PositiveReal(variance))
100
- assert isinstance(k.variance, PositiveReal)
99
+ k = kernel(**new_params, variance=NonNegativeReal(variance))
100
+ assert isinstance(k.variance, NonNegativeReal)
101
101
 
102
102
  for param in params.keys():
103
103
  if param in ("degree", "order"):
@@ -112,7 +112,7 @@ def test_init_defaults(kernel: type[AbstractKernel]):
112
112
 
113
113
  # Check that the parameters are set correctly
114
114
  assert isinstance(k.compute_engine, type(AbstractKernelComputation()))
115
- assert isinstance(k.variance, PositiveReal)
115
+ assert isinstance(k.variance, NonNegativeReal)
116
116
 
117
117
 
118
118
  @pytest.mark.parametrize("kernel", [k[0] for k in TESTED_KERNELS])
@@ -122,7 +122,7 @@ def test_init_variances(kernel: type[AbstractKernel], variance):
122
122
  k = kernel(variance=variance)
123
123
 
124
124
  # Check that the parameters are set correctly
125
- assert isinstance(k.variance, PositiveReal)
125
+ assert isinstance(k.variance, NonNegativeReal)
126
126
  assert jnp.allclose(k.variance.value, jnp.asarray(variance))
127
127
 
128
128
  # Check that error is raised if variance is not valid