gpjax 0.11.0__tar.gz → 0.11.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {gpjax-0.11.0 → gpjax-0.11.1}/PKG-INFO +1 -1
- {gpjax-0.11.0 → gpjax-0.11.1}/examples/constructing_new_kernels.py +0 -3
- {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/__init__.py +4 -2
- {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/fit.py +107 -4
- {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/kernels/nonstationary/arccosine.py +6 -3
- {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/kernels/nonstationary/linear.py +3 -3
- {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/kernels/nonstationary/polynomial.py +6 -3
- {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/kernels/stationary/base.py +6 -3
- {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/likelihoods.py +4 -4
- {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/mean_functions.py +1 -1
- {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/parameters.py +16 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/tests/test_fit.py +195 -13
- {gpjax-0.11.0 → gpjax-0.11.1}/tests/test_kernels/test_nonstationary.py +5 -5
- {gpjax-0.11.0 → gpjax-0.11.1}/tests/test_kernels/test_stationary.py +5 -4
- {gpjax-0.11.0 → gpjax-0.11.1}/tests/test_likelihoods.py +2 -2
- gpjax-0.11.1/tests/test_mean_functions.py +249 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/tests/test_numpyro_extras.py +76 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/tests/test_parameters.py +4 -0
- gpjax-0.11.0/.cursorrules +0 -37
- gpjax-0.11.0/tests/test_mean_functions.py +0 -81
- {gpjax-0.11.0 → gpjax-0.11.1}/.github/CODE_OF_CONDUCT.md +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/.github/ISSUE_TEMPLATE/01_BUG_REPORT.md +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/.github/ISSUE_TEMPLATE/02_FEATURE_REQUEST.md +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/.github/ISSUE_TEMPLATE/03_CODEBASE_IMPROVEMENT.md +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/.github/ISSUE_TEMPLATE/04_DOCS_IMPROVEMENT.md +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/.github/ISSUE_TEMPLATE/config.yml +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/.github/codecov.yml +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/.github/labels.yml +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/.github/pull_request_template.md +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/.github/release-drafter.yml +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/.github/workflows/build_docs.yml +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/.github/workflows/integration.yml +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/.github/workflows/pr_greeting.yml +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/.github/workflows/ruff.yml +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/.github/workflows/stale_prs.yml +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/.github/workflows/test_docs.yml +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/.github/workflows/tests.yml +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/.gitignore +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/CITATION.bib +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/LICENSE.txt +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/Makefile +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/README.md +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/docs/CODE_OF_CONDUCT.md +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/docs/GOVERNANCE.md +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/docs/contributing.md +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/docs/design.md +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/docs/index.md +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/docs/index.rst +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/docs/installation.md +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/docs/javascripts/katex.js +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/docs/refs.bib +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/docs/scripts/gen_examples.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/docs/scripts/gen_pages.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/docs/scripts/notebook_converter.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/docs/scripts/sharp_bits_figure.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/docs/sharp_bits.md +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/docs/static/GP.pdf +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/docs/static/GP.svg +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/docs/static/bijector_figure.svg +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/docs/static/css/gpjax_theme.css +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/docs/static/favicon.ico +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/docs/static/gpjax.mplstyle +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/docs/static/gpjax_logo.pdf +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/docs/static/gpjax_logo.svg +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/docs/static/jaxkern/lato.ttf +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/docs/static/jaxkern/logo.png +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/docs/static/jaxkern/logo.svg +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/docs/static/jaxkern/main.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/docs/static/step_size_figure.png +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/docs/static/step_size_figure.svg +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/docs/stylesheets/extra.css +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/docs/stylesheets/permalinks.css +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/examples/backend.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/examples/barycentres/barycentre_gp.gif +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/examples/barycentres.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/examples/classification.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/examples/collapsed_vi.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/examples/data/max_tempeature_switzerland.csv +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/examples/data/yacht_hydrodynamics.data +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/examples/deep_kernels.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/examples/gpjax.mplstyle +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/examples/graph_kernels.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/examples/intro_to_gps/decomposed_mll.png +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/examples/intro_to_gps/generating_process.png +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/examples/intro_to_gps.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/examples/intro_to_kernels.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/examples/likelihoods_guide.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/examples/oceanmodelling.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/examples/poisson.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/examples/regression.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/examples/uncollapsed_vi.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/examples/utils.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/examples/yacht.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/citation.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/dataset.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/distributions.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/gps.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/integrators.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/kernels/__init__.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/kernels/approximations/__init__.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/kernels/approximations/rff.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/kernels/base.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/kernels/computations/__init__.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/kernels/computations/base.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/kernels/computations/basis_functions.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/kernels/computations/constant_diagonal.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/kernels/computations/dense.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/kernels/computations/diagonal.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/kernels/computations/eigen.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/kernels/non_euclidean/__init__.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/kernels/non_euclidean/graph.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/kernels/non_euclidean/utils.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/kernels/nonstationary/__init__.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/kernels/stationary/__init__.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/kernels/stationary/matern12.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/kernels/stationary/matern32.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/kernels/stationary/matern52.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/kernels/stationary/periodic.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/kernels/stationary/powered_exponential.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/kernels/stationary/rational_quadratic.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/kernels/stationary/rbf.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/kernels/stationary/utils.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/kernels/stationary/white.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/lower_cholesky.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/numpyro_extras.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/objectives.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/scan.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/typing.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/gpjax/variational_families.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/mkdocs.yml +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/pyproject.toml +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/static/CONTRIBUTING.md +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/static/paper.bib +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/static/paper.md +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/static/paper.pdf +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/tests/__init__.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/tests/conftest.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/tests/integration_tests.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/tests/test_citations.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/tests/test_dataset.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/tests/test_gaussian_distribution.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/tests/test_gps.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/tests/test_integrators.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/tests/test_kernels/__init__.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/tests/test_kernels/test_approximations.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/tests/test_kernels/test_base.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/tests/test_kernels/test_computation.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/tests/test_kernels/test_non_euclidean.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/tests/test_kernels/test_utils.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/tests/test_lower_cholesky.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/tests/test_markdown.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/tests/test_objectives.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.1}/tests/test_variational_families.py +0 -0
|
@@ -33,7 +33,6 @@ from jaxtyping import (
|
|
|
33
33
|
install_import_hook,
|
|
34
34
|
)
|
|
35
35
|
import matplotlib.pyplot as plt
|
|
36
|
-
import numpyro.distributions as npd
|
|
37
36
|
from numpyro.distributions import constraints
|
|
38
37
|
import numpyro.distributions.transforms as npt
|
|
39
38
|
|
|
@@ -52,8 +51,6 @@ with install_import_hook("gpjax", "beartype.beartype"):
|
|
|
52
51
|
import gpjax as gpx
|
|
53
52
|
|
|
54
53
|
|
|
55
|
-
tfb = tfp.bijectors
|
|
56
|
-
|
|
57
54
|
# set the default style for plotting
|
|
58
55
|
use_mpl_style()
|
|
59
56
|
|
|
@@ -32,14 +32,15 @@ from gpjax.citation import cite
|
|
|
32
32
|
from gpjax.dataset import Dataset
|
|
33
33
|
from gpjax.fit import (
|
|
34
34
|
fit,
|
|
35
|
+
fit_lbfgs,
|
|
35
36
|
fit_scipy,
|
|
36
37
|
)
|
|
37
38
|
|
|
38
39
|
__license__ = "MIT"
|
|
39
|
-
__description__ = "
|
|
40
|
+
__description__ = "Gaussian processes in JAX and Flax"
|
|
40
41
|
__url__ = "https://github.com/JaxGaussianProcesses/GPJax"
|
|
41
42
|
__contributors__ = "https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors"
|
|
42
|
-
__version__ = "0.11.
|
|
43
|
+
__version__ = "0.11.1"
|
|
43
44
|
|
|
44
45
|
__all__ = [
|
|
45
46
|
"base",
|
|
@@ -56,5 +57,6 @@ __all__ = [
|
|
|
56
57
|
"fit",
|
|
57
58
|
"Module",
|
|
58
59
|
"param_field",
|
|
60
|
+
"fit_lbfgs",
|
|
59
61
|
"fit_scipy",
|
|
60
62
|
]
|
|
@@ -15,13 +15,13 @@
|
|
|
15
15
|
|
|
16
16
|
import typing as tp
|
|
17
17
|
|
|
18
|
-
from flax import nnx
|
|
19
18
|
import jax
|
|
20
|
-
from jax.flatten_util import ravel_pytree
|
|
21
19
|
import jax.numpy as jnp
|
|
22
20
|
import jax.random as jr
|
|
23
|
-
from numpyro.distributions.transforms import Transform
|
|
24
21
|
import optax as ox
|
|
22
|
+
from flax import nnx
|
|
23
|
+
from jax.flatten_util import ravel_pytree
|
|
24
|
+
from numpyro.distributions.transforms import Transform
|
|
25
25
|
from scipy.optimize import minimize
|
|
26
26
|
|
|
27
27
|
from gpjax.dataset import Dataset
|
|
@@ -127,7 +127,6 @@ def fit( # noqa: PLR0913
|
|
|
127
127
|
_check_verbose(verbose)
|
|
128
128
|
|
|
129
129
|
# Model state filtering
|
|
130
|
-
|
|
131
130
|
graphdef, params, *static_state = nnx.split(model, Parameter, ...)
|
|
132
131
|
|
|
133
132
|
# Parameters bijection to unconstrained space
|
|
@@ -253,6 +252,110 @@ def fit_scipy( # noqa: PLR0913
|
|
|
253
252
|
return model, history
|
|
254
253
|
|
|
255
254
|
|
|
255
|
+
def fit_lbfgs(
|
|
256
|
+
*,
|
|
257
|
+
model: Model,
|
|
258
|
+
objective: Objective,
|
|
259
|
+
train_data: Dataset,
|
|
260
|
+
params_bijection: tp.Union[dict[Parameter, Transform], None] = DEFAULT_BIJECTION,
|
|
261
|
+
max_iters: int = 100,
|
|
262
|
+
safe: bool = True,
|
|
263
|
+
max_linesearch_steps: int = 32,
|
|
264
|
+
gtol: float = 1e-5,
|
|
265
|
+
) -> tuple[Model, jax.Array]:
|
|
266
|
+
r"""Train a Module model with respect to a supplied Objective function.
|
|
267
|
+
|
|
268
|
+
Uses Optax's LBFGS implementation and a jax.lax.while_loop.
|
|
269
|
+
|
|
270
|
+
Args:
|
|
271
|
+
model: the model Module to be optimised.
|
|
272
|
+
objective: The objective function that we are optimising with
|
|
273
|
+
respect to.
|
|
274
|
+
train_data (Dataset): The training data to be used for the optimisation.
|
|
275
|
+
max_iters (int): The maximum number of optimisation steps to run. Defaults
|
|
276
|
+
to 100.
|
|
277
|
+
safe (bool): Whether to check the types of the inputs.
|
|
278
|
+
max_linesearch_steps (int): The maximum number of linesearch steps to use
|
|
279
|
+
for finding the stepsize.
|
|
280
|
+
gtol (float): Terminate the optimisation if the L2 norm of the gradient is
|
|
281
|
+
below this threshold.
|
|
282
|
+
|
|
283
|
+
Returns:
|
|
284
|
+
A tuple comprising the optimised model and final loss.
|
|
285
|
+
"""
|
|
286
|
+
if safe:
|
|
287
|
+
# Check inputs
|
|
288
|
+
_check_model(model)
|
|
289
|
+
_check_train_data(train_data)
|
|
290
|
+
_check_num_iters(max_iters)
|
|
291
|
+
|
|
292
|
+
# Model state filtering
|
|
293
|
+
graphdef, params, *static_state = nnx.split(model, Parameter, ...)
|
|
294
|
+
|
|
295
|
+
# Parameters bijection to unconstrained space
|
|
296
|
+
if params_bijection is not None:
|
|
297
|
+
params = transform(params, params_bijection, inverse=True)
|
|
298
|
+
|
|
299
|
+
# Loss definition
|
|
300
|
+
def loss(params: nnx.State) -> ScalarFloat:
|
|
301
|
+
params = transform(params, params_bijection)
|
|
302
|
+
model = nnx.merge(graphdef, params, *static_state)
|
|
303
|
+
return objective(model, train_data)
|
|
304
|
+
|
|
305
|
+
# Initialise optimiser
|
|
306
|
+
optim = ox.lbfgs(
|
|
307
|
+
linesearch=ox.scale_by_zoom_linesearch(
|
|
308
|
+
max_linesearch_steps=max_linesearch_steps,
|
|
309
|
+
initial_guess_strategy="one",
|
|
310
|
+
)
|
|
311
|
+
)
|
|
312
|
+
opt_state = optim.init(params)
|
|
313
|
+
loss_value_and_grad = ox.value_and_grad_from_state(loss)
|
|
314
|
+
|
|
315
|
+
# Optimisation step.
|
|
316
|
+
def step(carry):
|
|
317
|
+
params, opt_state = carry
|
|
318
|
+
|
|
319
|
+
# Using optax's value_and_grad_from_state is more efficient given LBFGS uses a linesearch
|
|
320
|
+
# See https://optax.readthedocs.io/en/latest/api/utilities.html#optax.value_and_grad_from_state
|
|
321
|
+
loss_val, loss_gradient = loss_value_and_grad(params, state=opt_state)
|
|
322
|
+
updates, opt_state = optim.update(
|
|
323
|
+
loss_gradient,
|
|
324
|
+
opt_state,
|
|
325
|
+
params,
|
|
326
|
+
value=loss_val,
|
|
327
|
+
grad=loss_gradient,
|
|
328
|
+
value_fn=loss,
|
|
329
|
+
)
|
|
330
|
+
params = ox.apply_updates(params, updates)
|
|
331
|
+
|
|
332
|
+
return params, opt_state
|
|
333
|
+
|
|
334
|
+
def continue_fn(carry):
|
|
335
|
+
_, opt_state = carry
|
|
336
|
+
n = ox.tree_utils.tree_get(opt_state, "count")
|
|
337
|
+
g = ox.tree_utils.tree_get(opt_state, "grad")
|
|
338
|
+
g_l2_norm = ox.tree_utils.tree_l2_norm(g)
|
|
339
|
+
return (n == 0) | ((n < max_iters) & (g_l2_norm >= gtol))
|
|
340
|
+
|
|
341
|
+
# Optimisation loop
|
|
342
|
+
params, opt_state = jax.lax.while_loop(
|
|
343
|
+
continue_fn,
|
|
344
|
+
step,
|
|
345
|
+
(params, opt_state),
|
|
346
|
+
)
|
|
347
|
+
final_loss = ox.tree_utils.tree_get(opt_state, "value")
|
|
348
|
+
|
|
349
|
+
# Parameters bijection to constrained space
|
|
350
|
+
if params_bijection is not None:
|
|
351
|
+
params = transform(params, params_bijection)
|
|
352
|
+
|
|
353
|
+
# Reconstruct model
|
|
354
|
+
model = nnx.merge(graphdef, params, *static_state)
|
|
355
|
+
|
|
356
|
+
return model, final_loss
|
|
357
|
+
|
|
358
|
+
|
|
256
359
|
def get_batch(train_data: Dataset, batch_size: int, key: KeyArray) -> Dataset:
|
|
257
360
|
"""Batch the data into mini-batches. Sampling is done with replacement.
|
|
258
361
|
|
|
@@ -23,7 +23,10 @@ from gpjax.kernels.computations import (
|
|
|
23
23
|
AbstractKernelComputation,
|
|
24
24
|
DenseKernelComputation,
|
|
25
25
|
)
|
|
26
|
-
from gpjax.parameters import
|
|
26
|
+
from gpjax.parameters import (
|
|
27
|
+
NonNegativeReal,
|
|
28
|
+
PositiveReal,
|
|
29
|
+
)
|
|
27
30
|
from gpjax.typing import (
|
|
28
31
|
Array,
|
|
29
32
|
ScalarArray,
|
|
@@ -91,9 +94,9 @@ class ArcCosine(AbstractKernel):
|
|
|
91
94
|
if isinstance(variance, nnx.Variable):
|
|
92
95
|
self.variance = variance
|
|
93
96
|
else:
|
|
94
|
-
self.variance =
|
|
97
|
+
self.variance = NonNegativeReal(variance)
|
|
95
98
|
if tp.TYPE_CHECKING:
|
|
96
|
-
self.variance = tp.cast(
|
|
99
|
+
self.variance = tp.cast(NonNegativeReal[ScalarArray], self.variance)
|
|
97
100
|
|
|
98
101
|
if isinstance(bias_variance, nnx.Variable):
|
|
99
102
|
self.bias_variance = bias_variance
|
|
@@ -23,7 +23,7 @@ from gpjax.kernels.computations import (
|
|
|
23
23
|
AbstractKernelComputation,
|
|
24
24
|
DenseKernelComputation,
|
|
25
25
|
)
|
|
26
|
-
from gpjax.parameters import
|
|
26
|
+
from gpjax.parameters import NonNegativeReal
|
|
27
27
|
from gpjax.typing import (
|
|
28
28
|
Array,
|
|
29
29
|
ScalarArray,
|
|
@@ -64,9 +64,9 @@ class Linear(AbstractKernel):
|
|
|
64
64
|
if isinstance(variance, nnx.Variable):
|
|
65
65
|
self.variance = variance
|
|
66
66
|
else:
|
|
67
|
-
self.variance =
|
|
67
|
+
self.variance = NonNegativeReal(variance)
|
|
68
68
|
if tp.TYPE_CHECKING:
|
|
69
|
-
self.variance = tp.cast(
|
|
69
|
+
self.variance = tp.cast(NonNegativeReal[ScalarArray], self.variance)
|
|
70
70
|
|
|
71
71
|
def __call__(
|
|
72
72
|
self,
|
|
@@ -23,7 +23,10 @@ from gpjax.kernels.computations import (
|
|
|
23
23
|
AbstractKernelComputation,
|
|
24
24
|
DenseKernelComputation,
|
|
25
25
|
)
|
|
26
|
-
from gpjax.parameters import
|
|
26
|
+
from gpjax.parameters import (
|
|
27
|
+
NonNegativeReal,
|
|
28
|
+
PositiveReal,
|
|
29
|
+
)
|
|
27
30
|
from gpjax.typing import (
|
|
28
31
|
Array,
|
|
29
32
|
ScalarArray,
|
|
@@ -76,9 +79,9 @@ class Polynomial(AbstractKernel):
|
|
|
76
79
|
if isinstance(variance, nnx.Variable):
|
|
77
80
|
self.variance = variance
|
|
78
81
|
else:
|
|
79
|
-
self.variance =
|
|
82
|
+
self.variance = NonNegativeReal(variance)
|
|
80
83
|
if tp.TYPE_CHECKING:
|
|
81
|
-
self.variance = tp.cast(
|
|
84
|
+
self.variance = tp.cast(NonNegativeReal[ScalarArray], self.variance)
|
|
82
85
|
|
|
83
86
|
self.name = f"Polynomial (degree {self.degree})"
|
|
84
87
|
|
|
@@ -25,7 +25,10 @@ from gpjax.kernels.computations import (
|
|
|
25
25
|
AbstractKernelComputation,
|
|
26
26
|
DenseKernelComputation,
|
|
27
27
|
)
|
|
28
|
-
from gpjax.parameters import
|
|
28
|
+
from gpjax.parameters import (
|
|
29
|
+
NonNegativeReal,
|
|
30
|
+
PositiveReal,
|
|
31
|
+
)
|
|
29
32
|
from gpjax.typing import (
|
|
30
33
|
Array,
|
|
31
34
|
ScalarArray,
|
|
@@ -85,11 +88,11 @@ class StationaryKernel(AbstractKernel):
|
|
|
85
88
|
if isinstance(variance, nnx.Variable):
|
|
86
89
|
self.variance = variance
|
|
87
90
|
else:
|
|
88
|
-
self.variance =
|
|
91
|
+
self.variance = NonNegativeReal(variance)
|
|
89
92
|
|
|
90
93
|
# static typing
|
|
91
94
|
if tp.TYPE_CHECKING:
|
|
92
|
-
self.variance = tp.cast(
|
|
95
|
+
self.variance = tp.cast(NonNegativeReal[ScalarFloat], self.variance)
|
|
93
96
|
|
|
94
97
|
@property
|
|
95
98
|
def spectral_density(self) -> npd.Normal | npd.StudentT:
|
|
@@ -28,7 +28,7 @@ from gpjax.integrators import (
|
|
|
28
28
|
GHQuadratureIntegrator,
|
|
29
29
|
)
|
|
30
30
|
from gpjax.parameters import (
|
|
31
|
-
|
|
31
|
+
NonNegativeReal,
|
|
32
32
|
Static,
|
|
33
33
|
)
|
|
34
34
|
from gpjax.typing import (
|
|
@@ -134,7 +134,7 @@ class Gaussian(AbstractLikelihood):
|
|
|
134
134
|
self,
|
|
135
135
|
num_datapoints: int,
|
|
136
136
|
obs_stddev: tp.Union[
|
|
137
|
-
ScalarFloat, Float[Array, "#N"],
|
|
137
|
+
ScalarFloat, Float[Array, "#N"], NonNegativeReal, Static
|
|
138
138
|
] = 1.0,
|
|
139
139
|
integrator: AbstractIntegrator = AnalyticalGaussianIntegrator(),
|
|
140
140
|
):
|
|
@@ -148,8 +148,8 @@ class Gaussian(AbstractLikelihood):
|
|
|
148
148
|
likelihoods. Must be an instance of `AbstractIntegrator`. For the Gaussian likelihood, this defaults to
|
|
149
149
|
the `AnalyticalGaussianIntegrator`, as the expected log likelihood can be computed analytically.
|
|
150
150
|
"""
|
|
151
|
-
if not isinstance(obs_stddev, (
|
|
152
|
-
obs_stddev =
|
|
151
|
+
if not isinstance(obs_stddev, (NonNegativeReal, Static)):
|
|
152
|
+
obs_stddev = NonNegativeReal(jnp.asarray(obs_stddev))
|
|
153
153
|
self.obs_stddev = obs_stddev
|
|
154
154
|
|
|
155
155
|
super().__init__(num_datapoints, integrator)
|
|
@@ -207,5 +207,5 @@ SumMeanFunction = ft.partial(
|
|
|
207
207
|
CombinationMeanFunction, operator=ft.partial(jnp.sum, axis=0)
|
|
208
208
|
)
|
|
209
209
|
ProductMeanFunction = ft.partial(
|
|
210
|
-
CombinationMeanFunction, operator=ft.partial(jnp.
|
|
210
|
+
CombinationMeanFunction, operator=ft.partial(jnp.prod, axis=0)
|
|
211
211
|
)
|
|
@@ -82,6 +82,14 @@ class Parameter(nnx.Variable[T]):
|
|
|
82
82
|
self._tag = tag
|
|
83
83
|
|
|
84
84
|
|
|
85
|
+
class NonNegativeReal(Parameter[T]):
|
|
86
|
+
"""Parameter that is non-negative."""
|
|
87
|
+
|
|
88
|
+
def __init__(self, value: T, tag: ParameterTag = "non_negative", **kwargs):
|
|
89
|
+
super().__init__(value=value, tag=tag, **kwargs)
|
|
90
|
+
_safe_assert(_check_is_non_negative, self.value)
|
|
91
|
+
|
|
92
|
+
|
|
85
93
|
class PositiveReal(Parameter[T]):
|
|
86
94
|
"""Parameter that is strictly positive."""
|
|
87
95
|
|
|
@@ -143,6 +151,7 @@ class LowerTriangular(Parameter[T]):
|
|
|
143
151
|
|
|
144
152
|
DEFAULT_BIJECTION = {
|
|
145
153
|
"positive": npt.SoftplusTransform(),
|
|
154
|
+
"non_negative": npt.SoftplusTransform(),
|
|
146
155
|
"real": npt.IdentityTransform(),
|
|
147
156
|
"sigmoid": npt.SigmoidTransform(),
|
|
148
157
|
"lower_triangular": FillTriangularTransform(),
|
|
@@ -164,6 +173,13 @@ def _check_is_arraylike(value: T) -> None:
|
|
|
164
173
|
)
|
|
165
174
|
|
|
166
175
|
|
|
176
|
+
@checkify.checkify
|
|
177
|
+
def _check_is_non_negative(value):
|
|
178
|
+
checkify.check(
|
|
179
|
+
jnp.all(value >= 0), "value needs to be non-negative, got {value}", value=value
|
|
180
|
+
)
|
|
181
|
+
|
|
182
|
+
|
|
167
183
|
@checkify.checkify
|
|
168
184
|
def _check_is_positive(value):
|
|
169
185
|
checkify.check(
|
|
@@ -13,20 +13,24 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
|
|
16
|
-
from flax import nnx
|
|
17
16
|
import jax.numpy as jnp
|
|
18
17
|
import jax.random as jr
|
|
19
|
-
from jaxtyping import (
|
|
20
|
-
Float,
|
|
21
|
-
Num,
|
|
22
|
-
)
|
|
23
18
|
import optax as ox
|
|
24
19
|
import pytest
|
|
25
20
|
import scipy
|
|
26
|
-
|
|
21
|
+
from beartype.typing import Any
|
|
22
|
+
from flax import nnx
|
|
27
23
|
from gpjax.dataset import Dataset
|
|
28
24
|
from gpjax.fit import (
|
|
25
|
+
_check_batch_size,
|
|
26
|
+
_check_log_rate,
|
|
27
|
+
_check_model,
|
|
28
|
+
_check_num_iters,
|
|
29
|
+
_check_optim,
|
|
30
|
+
_check_train_data,
|
|
31
|
+
_check_verbose,
|
|
29
32
|
fit,
|
|
33
|
+
fit_lbfgs,
|
|
30
34
|
fit_scipy,
|
|
31
35
|
get_batch,
|
|
32
36
|
)
|
|
@@ -50,6 +54,10 @@ from gpjax.parameters import (
|
|
|
50
54
|
)
|
|
51
55
|
from gpjax.typing import Array
|
|
52
56
|
from gpjax.variational_families import VariationalGaussian
|
|
57
|
+
from jaxtyping import (
|
|
58
|
+
Float,
|
|
59
|
+
Num,
|
|
60
|
+
)
|
|
53
61
|
|
|
54
62
|
|
|
55
63
|
def test_fit_simple() -> None:
|
|
@@ -141,6 +149,46 @@ def test_fit_scipy_simple():
|
|
|
141
149
|
assert trained_model.bias.value == 1.0
|
|
142
150
|
|
|
143
151
|
|
|
152
|
+
def test_fit_lbfgs_simple():
|
|
153
|
+
# Create dataset:
|
|
154
|
+
X = jnp.linspace(0.0, 10.0, 100).reshape(-1, 1)
|
|
155
|
+
y = 2.0 * X + 1.0 + 10 * jr.normal(jr.PRNGKey(0), X.shape).reshape(-1, 1)
|
|
156
|
+
D = Dataset(X, y)
|
|
157
|
+
|
|
158
|
+
# Define linear model:
|
|
159
|
+
class LinearModel(nnx.Module):
|
|
160
|
+
def __init__(self, weight: float, bias: float):
|
|
161
|
+
self.weight = PositiveReal(weight)
|
|
162
|
+
self.bias = Static(bias)
|
|
163
|
+
|
|
164
|
+
def __call__(self, x):
|
|
165
|
+
return self.weight.value * x + self.bias.value
|
|
166
|
+
|
|
167
|
+
model = LinearModel(weight=1.0, bias=1.0)
|
|
168
|
+
|
|
169
|
+
# Define loss function:
|
|
170
|
+
def mse(model, data):
|
|
171
|
+
pred = model(data.X)
|
|
172
|
+
return jnp.mean((pred - data.y) ** 2)
|
|
173
|
+
|
|
174
|
+
# Train with bfgs!
|
|
175
|
+
trained_model, final_loss = fit_lbfgs(
|
|
176
|
+
model=model,
|
|
177
|
+
objective=mse,
|
|
178
|
+
train_data=D,
|
|
179
|
+
max_iters=10,
|
|
180
|
+
)
|
|
181
|
+
|
|
182
|
+
# Ensure we return a model of the same class
|
|
183
|
+
assert isinstance(trained_model, LinearModel)
|
|
184
|
+
|
|
185
|
+
# Test reduction in loss:
|
|
186
|
+
assert mse(trained_model, D) < mse(model, D)
|
|
187
|
+
|
|
188
|
+
# Test stop_gradient on bias:
|
|
189
|
+
assert trained_model.bias.value == 1.0
|
|
190
|
+
|
|
191
|
+
|
|
144
192
|
@pytest.mark.parametrize("n_data", [20])
|
|
145
193
|
@pytest.mark.parametrize("verbose", [True, False])
|
|
146
194
|
def test_fit_gp_regression(n_data: int, verbose: bool) -> None:
|
|
@@ -179,8 +227,7 @@ def test_fit_gp_regression(n_data: int, verbose: bool) -> None:
|
|
|
179
227
|
|
|
180
228
|
|
|
181
229
|
@pytest.mark.parametrize("n_data", [20])
|
|
182
|
-
|
|
183
|
-
def test_fit_scipy_gp_regression(n_data: int, verbose: bool) -> None:
|
|
230
|
+
def test_fit_lbfgs_gp_regression(n_data: int) -> None:
|
|
184
231
|
# Create dataset:
|
|
185
232
|
key = jr.PRNGKey(123)
|
|
186
233
|
x = jnp.sort(
|
|
@@ -195,20 +242,16 @@ def test_fit_scipy_gp_regression(n_data: int, verbose: bool) -> None:
|
|
|
195
242
|
posterior = prior * likelihood
|
|
196
243
|
|
|
197
244
|
# Train with BFGS!
|
|
198
|
-
trained_model_bfgs,
|
|
245
|
+
trained_model_bfgs, final_loss = fit_lbfgs(
|
|
199
246
|
model=posterior,
|
|
200
247
|
objective=conjugate_mll,
|
|
201
248
|
train_data=D,
|
|
202
249
|
max_iters=40,
|
|
203
|
-
verbose=verbose,
|
|
204
250
|
)
|
|
205
251
|
|
|
206
252
|
# Ensure the trained model is a Gaussian process posterior
|
|
207
253
|
assert isinstance(trained_model_bfgs, ConjugatePosterior)
|
|
208
254
|
|
|
209
|
-
# Ensure we return a history_bfgs of the correct length
|
|
210
|
-
assert len(history_bfgs) > 2
|
|
211
|
-
|
|
212
255
|
# Ensure we reduce the loss
|
|
213
256
|
assert conjugate_mll(trained_model_bfgs, D) < conjugate_mll(posterior, D)
|
|
214
257
|
|
|
@@ -324,3 +367,142 @@ def test_get_batch(n_data: int, n_dim: int, batch_size: int):
|
|
|
324
367
|
assert New.y.shape[1:] == y.shape[1:]
|
|
325
368
|
assert jnp.sum(New.X == B.X) <= n_dim * batch_size / n_data
|
|
326
369
|
assert jnp.sum(New.y == B.y) <= n_dim * batch_size / n_data
|
|
370
|
+
|
|
371
|
+
|
|
372
|
+
@pytest.fixture
|
|
373
|
+
def valid_model() -> nnx.Module:
|
|
374
|
+
"""Return a valid model for testing."""
|
|
375
|
+
|
|
376
|
+
class LinearModel(nnx.Module):
|
|
377
|
+
def __init__(self, weight: float, bias: float) -> None:
|
|
378
|
+
self.weight = PositiveReal(weight)
|
|
379
|
+
self.bias = Static(bias)
|
|
380
|
+
|
|
381
|
+
def __call__(self, x: Any) -> Any:
|
|
382
|
+
return self.weight.value * x + self.bias.value
|
|
383
|
+
|
|
384
|
+
return LinearModel(weight=1.0, bias=1.0)
|
|
385
|
+
|
|
386
|
+
|
|
387
|
+
@pytest.fixture
|
|
388
|
+
def valid_dataset() -> Dataset:
|
|
389
|
+
"""Return a valid dataset for testing."""
|
|
390
|
+
X = jnp.array([[1.0], [2.0], [3.0]])
|
|
391
|
+
y = jnp.array([[1.0], [2.0], [3.0]])
|
|
392
|
+
return Dataset(X=X, y=y)
|
|
393
|
+
|
|
394
|
+
|
|
395
|
+
def test_check_model_valid(valid_model: nnx.Module) -> None:
|
|
396
|
+
"""Test that a valid model passes validation."""
|
|
397
|
+
_check_model(valid_model)
|
|
398
|
+
|
|
399
|
+
|
|
400
|
+
def test_check_model_invalid() -> None:
|
|
401
|
+
"""Test that an invalid model raises a TypeError."""
|
|
402
|
+
model = "not a model"
|
|
403
|
+
with pytest.raises(
|
|
404
|
+
TypeError, match="Expected model to be a subclass of nnx.Module"
|
|
405
|
+
):
|
|
406
|
+
_check_model(model)
|
|
407
|
+
|
|
408
|
+
|
|
409
|
+
def test_check_train_data_valid(valid_dataset: Dataset) -> None:
|
|
410
|
+
"""Test that valid training data passes validation."""
|
|
411
|
+
_check_train_data(valid_dataset)
|
|
412
|
+
|
|
413
|
+
|
|
414
|
+
def test_check_train_data_invalid() -> None:
|
|
415
|
+
"""Test that invalid training data raises a TypeError."""
|
|
416
|
+
train_data = "not a dataset"
|
|
417
|
+
with pytest.raises(
|
|
418
|
+
TypeError, match="Expected train_data to be of type gpjax.Dataset"
|
|
419
|
+
):
|
|
420
|
+
_check_train_data(train_data)
|
|
421
|
+
|
|
422
|
+
|
|
423
|
+
def test_check_optim_valid() -> None:
|
|
424
|
+
"""Test that a valid optimiser passes validation."""
|
|
425
|
+
optim = ox.sgd(0.1)
|
|
426
|
+
_check_optim(optim)
|
|
427
|
+
|
|
428
|
+
|
|
429
|
+
def test_check_optim_invalid() -> None:
|
|
430
|
+
"""Test that an invalid optimiser raises a TypeError."""
|
|
431
|
+
optim = "not an optimiser"
|
|
432
|
+
with pytest.raises(
|
|
433
|
+
TypeError, match="Expected optim to be of type optax.GradientTransformation"
|
|
434
|
+
):
|
|
435
|
+
_check_optim(optim)
|
|
436
|
+
|
|
437
|
+
|
|
438
|
+
@pytest.mark.parametrize("num_iters", [1, 10, 100])
|
|
439
|
+
def test_check_num_iters_valid(num_iters: int) -> None:
|
|
440
|
+
"""Test that valid number of iterations passes validation."""
|
|
441
|
+
_check_num_iters(num_iters)
|
|
442
|
+
|
|
443
|
+
|
|
444
|
+
def test_check_num_iters_invalid_type() -> None:
|
|
445
|
+
"""Test that an invalid num_iters type raises a TypeError."""
|
|
446
|
+
num_iters = "not an int"
|
|
447
|
+
with pytest.raises(TypeError, match="Expected num_iters to be of type int"):
|
|
448
|
+
_check_num_iters(num_iters)
|
|
449
|
+
|
|
450
|
+
|
|
451
|
+
@pytest.mark.parametrize("num_iters", [0, -5])
|
|
452
|
+
def test_check_num_iters_invalid_value(num_iters: int) -> None:
|
|
453
|
+
"""Test that an invalid num_iters value raises a ValueError."""
|
|
454
|
+
with pytest.raises(ValueError, match="Expected num_iters to be positive"):
|
|
455
|
+
_check_num_iters(num_iters)
|
|
456
|
+
|
|
457
|
+
|
|
458
|
+
@pytest.mark.parametrize("log_rate", [1, 10, 100])
|
|
459
|
+
def test_check_log_rate_valid(log_rate: int) -> None:
|
|
460
|
+
"""Test that a valid log rate passes validation."""
|
|
461
|
+
_check_log_rate(log_rate)
|
|
462
|
+
|
|
463
|
+
|
|
464
|
+
def test_check_log_rate_invalid_type() -> None:
|
|
465
|
+
"""Test that an invalid log_rate type raises a TypeError."""
|
|
466
|
+
log_rate = "not an int"
|
|
467
|
+
with pytest.raises(TypeError, match="Expected log_rate to be of type int"):
|
|
468
|
+
_check_log_rate(log_rate)
|
|
469
|
+
|
|
470
|
+
|
|
471
|
+
@pytest.mark.parametrize("log_rate", [0, -5])
|
|
472
|
+
def test_check_log_rate_invalid_value(log_rate: int) -> None:
|
|
473
|
+
"""Test that an invalid log_rate value raises a ValueError."""
|
|
474
|
+
with pytest.raises(ValueError, match="Expected log_rate to be positive"):
|
|
475
|
+
_check_log_rate(log_rate)
|
|
476
|
+
|
|
477
|
+
|
|
478
|
+
@pytest.mark.parametrize("verbose", [True, False])
|
|
479
|
+
def test_check_verbose_valid(verbose: bool) -> None:
|
|
480
|
+
"""Test that valid verbose values pass validation."""
|
|
481
|
+
_check_verbose(verbose)
|
|
482
|
+
|
|
483
|
+
|
|
484
|
+
def test_check_verbose_invalid() -> None:
|
|
485
|
+
"""Test that an invalid verbose value raises a TypeError."""
|
|
486
|
+
verbose = "not a bool"
|
|
487
|
+
with pytest.raises(TypeError, match="Expected verbose to be of type bool"):
|
|
488
|
+
_check_verbose(verbose)
|
|
489
|
+
|
|
490
|
+
|
|
491
|
+
@pytest.mark.parametrize("batch_size", [1, 10, 100, -1])
|
|
492
|
+
def test_check_batch_size_valid(batch_size: int) -> None:
|
|
493
|
+
"""Test that valid batch sizes pass validation."""
|
|
494
|
+
_check_batch_size(batch_size)
|
|
495
|
+
|
|
496
|
+
|
|
497
|
+
def test_check_batch_size_invalid_type() -> None:
|
|
498
|
+
"""Test that an invalid batch_size type raises a TypeError."""
|
|
499
|
+
batch_size = "not an int"
|
|
500
|
+
with pytest.raises(TypeError, match="Expected batch_size to be of type int"):
|
|
501
|
+
_check_batch_size(batch_size)
|
|
502
|
+
|
|
503
|
+
|
|
504
|
+
@pytest.mark.parametrize("batch_size", [0, -2, -5])
|
|
505
|
+
def test_check_batch_size_invalid_value(batch_size: int) -> None:
|
|
506
|
+
"""Test that invalid batch_size values raise a ValueError."""
|
|
507
|
+
with pytest.raises(ValueError, match="Expected batch_size to be positive or -1"):
|
|
508
|
+
_check_batch_size(batch_size)
|
|
@@ -31,7 +31,7 @@ from gpjax.kernels.nonstationary import (
|
|
|
31
31
|
Polynomial,
|
|
32
32
|
)
|
|
33
33
|
from gpjax.parameters import (
|
|
34
|
-
|
|
34
|
+
NonNegativeReal,
|
|
35
35
|
Static,
|
|
36
36
|
)
|
|
37
37
|
|
|
@@ -96,8 +96,8 @@ def test_init_override_paramtype(kernel_request):
|
|
|
96
96
|
continue
|
|
97
97
|
new_params[param] = Static(value)
|
|
98
98
|
|
|
99
|
-
k = kernel(**new_params, variance=
|
|
100
|
-
assert isinstance(k.variance,
|
|
99
|
+
k = kernel(**new_params, variance=NonNegativeReal(variance))
|
|
100
|
+
assert isinstance(k.variance, NonNegativeReal)
|
|
101
101
|
|
|
102
102
|
for param in params.keys():
|
|
103
103
|
if param in ("degree", "order"):
|
|
@@ -112,7 +112,7 @@ def test_init_defaults(kernel: type[AbstractKernel]):
|
|
|
112
112
|
|
|
113
113
|
# Check that the parameters are set correctly
|
|
114
114
|
assert isinstance(k.compute_engine, type(AbstractKernelComputation()))
|
|
115
|
-
assert isinstance(k.variance,
|
|
115
|
+
assert isinstance(k.variance, NonNegativeReal)
|
|
116
116
|
|
|
117
117
|
|
|
118
118
|
@pytest.mark.parametrize("kernel", [k[0] for k in TESTED_KERNELS])
|
|
@@ -122,7 +122,7 @@ def test_init_variances(kernel: type[AbstractKernel], variance):
|
|
|
122
122
|
k = kernel(variance=variance)
|
|
123
123
|
|
|
124
124
|
# Check that the parameters are set correctly
|
|
125
|
-
assert isinstance(k.variance,
|
|
125
|
+
assert isinstance(k.variance, NonNegativeReal)
|
|
126
126
|
assert jnp.allclose(k.variance.value, jnp.asarray(variance))
|
|
127
127
|
|
|
128
128
|
# Check that error is raised if variance is not valid
|