gpjax 0.11.0.tar.gz → 0.11.2.tar.gz
This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
- {gpjax-0.11.0 → gpjax-0.11.2}/PKG-INFO +1 -1
- {gpjax-0.11.0 → gpjax-0.11.2}/examples/constructing_new_kernels.py +0 -3
- {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/__init__.py +4 -2
- {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/citation.py +7 -2
- {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/fit.py +104 -1
- {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/kernels/nonstationary/arccosine.py +6 -3
- {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/kernels/nonstationary/linear.py +3 -3
- {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/kernels/nonstationary/polynomial.py +6 -3
- {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/kernels/stationary/base.py +6 -3
- {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/likelihoods.py +4 -4
- {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/mean_functions.py +1 -1
- {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/parameters.py +16 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/tests/test_fit.py +190 -7
- {gpjax-0.11.0 → gpjax-0.11.2}/tests/test_kernels/test_nonstationary.py +5 -5
- {gpjax-0.11.0 → gpjax-0.11.2}/tests/test_kernels/test_stationary.py +5 -4
- {gpjax-0.11.0 → gpjax-0.11.2}/tests/test_likelihoods.py +2 -2
- gpjax-0.11.2/tests/test_mean_functions.py +249 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/tests/test_numpyro_extras.py +76 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/tests/test_parameters.py +4 -0
- gpjax-0.11.2/uv.lock +832 -0
- gpjax-0.11.0/.cursorrules +0 -37
- gpjax-0.11.0/tests/test_mean_functions.py +0 -81
- {gpjax-0.11.0 → gpjax-0.11.2}/.github/CODE_OF_CONDUCT.md +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/.github/ISSUE_TEMPLATE/01_BUG_REPORT.md +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/.github/ISSUE_TEMPLATE/02_FEATURE_REQUEST.md +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/.github/ISSUE_TEMPLATE/03_CODEBASE_IMPROVEMENT.md +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/.github/ISSUE_TEMPLATE/04_DOCS_IMPROVEMENT.md +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/.github/ISSUE_TEMPLATE/config.yml +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/.github/codecov.yml +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/.github/labels.yml +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/.github/pull_request_template.md +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/.github/release-drafter.yml +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/.github/workflows/build_docs.yml +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/.github/workflows/integration.yml +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/.github/workflows/pr_greeting.yml +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/.github/workflows/ruff.yml +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/.github/workflows/stale_prs.yml +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/.github/workflows/test_docs.yml +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/.github/workflows/tests.yml +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/.gitignore +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/CITATION.bib +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/LICENSE.txt +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/Makefile +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/README.md +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/docs/CODE_OF_CONDUCT.md +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/docs/GOVERNANCE.md +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/docs/contributing.md +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/docs/design.md +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/docs/index.md +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/docs/index.rst +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/docs/installation.md +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/docs/javascripts/katex.js +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/docs/refs.bib +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/docs/scripts/gen_examples.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/docs/scripts/gen_pages.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/docs/scripts/notebook_converter.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/docs/scripts/sharp_bits_figure.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/docs/sharp_bits.md +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/docs/static/GP.pdf +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/docs/static/GP.svg +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/docs/static/bijector_figure.svg +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/docs/static/css/gpjax_theme.css +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/docs/static/favicon.ico +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/docs/static/gpjax.mplstyle +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/docs/static/gpjax_logo.pdf +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/docs/static/gpjax_logo.svg +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/docs/static/jaxkern/lato.ttf +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/docs/static/jaxkern/logo.png +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/docs/static/jaxkern/logo.svg +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/docs/static/jaxkern/main.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/docs/static/step_size_figure.png +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/docs/static/step_size_figure.svg +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/docs/stylesheets/extra.css +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/docs/stylesheets/permalinks.css +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/examples/backend.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/examples/barycentres/barycentre_gp.gif +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/examples/barycentres.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/examples/classification.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/examples/collapsed_vi.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/examples/data/max_tempeature_switzerland.csv +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/examples/data/yacht_hydrodynamics.data +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/examples/deep_kernels.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/examples/gpjax.mplstyle +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/examples/graph_kernels.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/examples/intro_to_gps/decomposed_mll.png +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/examples/intro_to_gps/generating_process.png +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/examples/intro_to_gps.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/examples/intro_to_kernels.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/examples/likelihoods_guide.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/examples/oceanmodelling.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/examples/poisson.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/examples/regression.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/examples/uncollapsed_vi.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/examples/utils.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/examples/yacht.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/dataset.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/distributions.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/gps.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/integrators.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/kernels/__init__.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/kernels/approximations/__init__.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/kernels/approximations/rff.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/kernels/base.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/kernels/computations/__init__.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/kernels/computations/base.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/kernels/computations/basis_functions.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/kernels/computations/constant_diagonal.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/kernels/computations/dense.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/kernels/computations/diagonal.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/kernels/computations/eigen.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/kernels/non_euclidean/__init__.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/kernels/non_euclidean/graph.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/kernels/non_euclidean/utils.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/kernels/nonstationary/__init__.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/kernels/stationary/__init__.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/kernels/stationary/matern12.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/kernels/stationary/matern32.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/kernels/stationary/matern52.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/kernels/stationary/periodic.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/kernels/stationary/powered_exponential.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/kernels/stationary/rational_quadratic.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/kernels/stationary/rbf.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/kernels/stationary/utils.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/kernels/stationary/white.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/lower_cholesky.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/numpyro_extras.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/objectives.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/scan.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/typing.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/gpjax/variational_families.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/mkdocs.yml +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/pyproject.toml +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/static/CONTRIBUTING.md +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/static/paper.bib +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/static/paper.md +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/static/paper.pdf +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/tests/__init__.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/tests/conftest.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/tests/integration_tests.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/tests/test_citations.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/tests/test_dataset.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/tests/test_gaussian_distribution.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/tests/test_gps.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/tests/test_integrators.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/tests/test_kernels/__init__.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/tests/test_kernels/test_approximations.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/tests/test_kernels/test_base.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/tests/test_kernels/test_computation.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/tests/test_kernels/test_non_euclidean.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/tests/test_kernels/test_utils.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/tests/test_lower_cholesky.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/tests/test_markdown.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/tests/test_objectives.py +0 -0
- {gpjax-0.11.0 → gpjax-0.11.2}/tests/test_variational_families.py +0 -0
{gpjax-0.11.0 → gpjax-0.11.2}/examples/constructing_new_kernels.py

```diff
@@ -33,7 +33,6 @@ from jaxtyping import (
     install_import_hook,
 )
 import matplotlib.pyplot as plt
-import numpyro.distributions as npd
 from numpyro.distributions import constraints
 import numpyro.distributions.transforms as npt
 
@@ -52,8 +51,6 @@ with install_import_hook("gpjax", "beartype.beartype"):
     import gpjax as gpx
 
 
-tfb = tfp.bijectors
-
 # set the default style for plotting
 use_mpl_style()
 
```
{gpjax-0.11.0 → gpjax-0.11.2}/gpjax/__init__.py

```diff
@@ -32,14 +32,15 @@ from gpjax.citation import cite
 from gpjax.dataset import Dataset
 from gpjax.fit import (
     fit,
+    fit_lbfgs,
     fit_scipy,
 )
 
 __license__ = "MIT"
-__description__ = "
+__description__ = "Gaussian processes in JAX and Flax"
 __url__ = "https://github.com/JaxGaussianProcesses/GPJax"
 __contributors__ = "https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors"
-__version__ = "0.11.0"
+__version__ = "0.11.2"
 
 __all__ = [
     "base",
@@ -56,5 +57,6 @@ __all__ = [
     "fit",
     "Module",
     "param_field",
+    "fit_lbfgs",
     "fit_scipy",
 ]
```
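The packaging hunks are mechanical: `fit_lbfgs` joins the public exports and the version metadata moves to 0.11.2. A quick smoke check, assuming a standard install:

```python
import gpjax as gpx

print(gpx.__version__)          # "0.11.2"
print(callable(gpx.fit_lbfgs))  # True: exported alongside fit and fit_scipy
```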
{gpjax-0.11.0 → gpjax-0.11.2}/gpjax/citation.py

```diff
@@ -8,7 +8,12 @@ from beartype.typing import (
     Dict,
     Union,
 )
-
+
+try:
+    # safely removable once jax>=0.6.0
+    from jaxlib.xla_extension import PjitFunction
+except ModuleNotFoundError:
+    from jaxlib._jax import PjitFunction
 
 from gpjax.kernels import (
     RFF,
@@ -45,7 +50,7 @@ class AbstractCitation:
 
 
 class NullCitation(AbstractCitation):
-    def
+    def as_str(self) -> str:
         return (
             "No citation available. If you think this is an error, please open a pull"
             " request."
```
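The `citation.py` import shim tolerates a `jaxlib` layout change: `PjitFunction` is imported from its old home in `jaxlib.xla_extension` where available, falling back to `jaxlib._jax` on newer wheels. A sketch of the pattern; the `isinstance` check on a jitted function reflects how such a symbol is typically used for dispatch and is an assumption, not shown in the diff:

```python
import jax

# Version-tolerant import of a symbol that moved between jaxlib layouts.
try:
    from jaxlib.xla_extension import PjitFunction  # older layout; removable once jax>=0.6.0
except ModuleNotFoundError:
    from jaxlib._jax import PjitFunction  # newer layout


@jax.jit
def f(x):
    return x + 1


# citation.py imports PjitFunction so `cite` can recognise jit-wrapped callables.
print(isinstance(f, PjitFunction))
```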
{gpjax-0.11.0 → gpjax-0.11.2}/gpjax/fit.py

```diff
@@ -127,7 +127,6 @@ def fit(  # noqa: PLR0913
     _check_verbose(verbose)
 
     # Model state filtering
-
     graphdef, params, *static_state = nnx.split(model, Parameter, ...)
 
     # Parameters bijection to unconstrained space
@@ -253,6 +252,110 @@ def fit_scipy(  # noqa: PLR0913
     return model, history
 
 
+def fit_lbfgs(
+    *,
+    model: Model,
+    objective: Objective,
+    train_data: Dataset,
+    params_bijection: tp.Union[dict[Parameter, Transform], None] = DEFAULT_BIJECTION,
+    max_iters: int = 100,
+    safe: bool = True,
+    max_linesearch_steps: int = 32,
+    gtol: float = 1e-5,
+) -> tuple[Model, jax.Array]:
+    r"""Train a Module model with respect to a supplied Objective function.
+
+    Uses Optax's LBFGS implementation and a jax.lax.while loop.
+
+    Args:
+        model: the model Module to be optimised.
+        objective: The objective function that we are optimising with
+            respect to.
+        train_data (Dataset): The training data to be used for the optimisation.
+        max_iters (int): The maximum number of optimisation steps to run. Defaults
+            to 500.
+        safe (bool): Whether to check the types of the inputs.
+        max_linesearch_steps (int): The maximum number of linesearch steps to use
+            for finding the stepsize.
+        gtol (float): Terminate the optimisation if the L2 norm of the gradient is
+            below this threshold.
+
+    Returns:
+        A tuple comprising the optimised model and final loss.
+    """
+    if safe:
+        # Check inputs
+        _check_model(model)
+        _check_train_data(train_data)
+        _check_num_iters(max_iters)
+
+    # Model state filtering
+    graphdef, params, *static_state = nnx.split(model, Parameter, ...)
+
+    # Parameters bijection to unconstrained space
+    if params_bijection is not None:
+        params = transform(params, params_bijection, inverse=True)
+
+    # Loss definition
+    def loss(params: nnx.State) -> ScalarFloat:
+        params = transform(params, params_bijection)
+        model = nnx.merge(graphdef, params, *static_state)
+        return objective(model, train_data)
+
+    # Initialise optimiser
+    optim = ox.lbfgs(
+        linesearch=ox.scale_by_zoom_linesearch(
+            max_linesearch_steps=max_linesearch_steps,
+            initial_guess_strategy="one",
+        )
+    )
+    opt_state = optim.init(params)
+    loss_value_and_grad = ox.value_and_grad_from_state(loss)
+
+    # Optimisation step.
+    def step(carry):
+        params, opt_state = carry
+
+        # Using optax's value_and_grad_from_state is more efficient given LBFGS uses a linesearch
+        # See https://optax.readthedocs.io/en/latest/api/utilities.html#optax.value_and_grad_from_state
+        loss_val, loss_gradient = loss_value_and_grad(params, state=opt_state)
+        updates, opt_state = optim.update(
+            loss_gradient,
+            opt_state,
+            params,
+            value=loss_val,
+            grad=loss_gradient,
+            value_fn=loss,
+        )
+        params = ox.apply_updates(params, updates)
+
+        return params, opt_state
+
+    def continue_fn(carry):
+        _, opt_state = carry
+        n = ox.tree_utils.tree_get(opt_state, "count")
+        g = ox.tree_utils.tree_get(opt_state, "grad")
+        g_l2_norm = ox.tree_utils.tree_l2_norm(g)
+        return (n == 0) | ((n < max_iters) & (g_l2_norm >= gtol))
+
+    # Optimisation loop
+    params, opt_state = jax.lax.while_loop(
+        continue_fn,
+        step,
+        (params, opt_state),
+    )
+    final_loss = ox.tree_utils.tree_get(opt_state, "value")
+
+    # Parameters bijection to constrained space
+    if params_bijection is not None:
+        params = transform(params, params_bijection)
+
+    # Reconstruct model
+    model = nnx.merge(graphdef, params, *static_state)
+
+    return model, final_loss
+
+
 def get_batch(train_data: Dataset, batch_size: int, key: KeyArray) -> Dataset:
     """Batch the data into mini-batches. Sampling is done with replacement.
 
```
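The headline addition in 0.11.2 is `fit_lbfgs` above: a keyword-only trainer that runs Optax's L-BFGS with a zoom linesearch inside `jax.lax.while_loop` and returns `(model, final_loss)` rather than a per-step history. A minimal usage sketch; the dataset, kernel choice, and the `conjugate_mll` import path are illustrative assumptions about the surrounding GPJax API, not part of this diff:

```python
import jax.numpy as jnp
import jax.random as jr

import gpjax as gpx
from gpjax.objectives import conjugate_mll

# Illustrative 1D regression data.
key = jr.PRNGKey(0)
x = jnp.linspace(0.0, 5.0, 50).reshape(-1, 1)
y = jnp.sin(x) + 0.1 * jr.normal(key, x.shape)
D = gpx.Dataset(X=x, y=y)

prior = gpx.gps.Prior(
    mean_function=gpx.mean_functions.Zero(), kernel=gpx.kernels.RBF()
)
posterior = prior * gpx.likelihoods.Gaussian(num_datapoints=D.n)

# Keyword-only API; returns the optimised model and the final loss value,
# not a per-iteration history like fit/fit_scipy.
opt_posterior, final_loss = gpx.fit_lbfgs(
    model=posterior,
    objective=conjugate_mll,
    train_data=D,
    max_iters=100,
    gtol=1e-5,
)
```

Because the loop is a `lax.while_loop` with a gradient-norm stopping rule (`continue_fn` above), the optimisation can terminate before `max_iters` once the L2 norm of the gradient drops below `gtol`.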
{gpjax-0.11.0 → gpjax-0.11.2}/gpjax/kernels/nonstationary/arccosine.py

```diff
@@ -23,7 +23,10 @@ from gpjax.kernels.computations import (
     AbstractKernelComputation,
     DenseKernelComputation,
 )
-from gpjax.parameters import
+from gpjax.parameters import (
+    NonNegativeReal,
+    PositiveReal,
+)
 from gpjax.typing import (
     Array,
     ScalarArray,
@@ -91,9 +94,9 @@ class ArcCosine(AbstractKernel):
         if isinstance(variance, nnx.Variable):
             self.variance = variance
         else:
-            self.variance =
+            self.variance = NonNegativeReal(variance)
         if tp.TYPE_CHECKING:
-            self.variance = tp.cast(
+            self.variance = tp.cast(NonNegativeReal[ScalarArray], self.variance)
 
         if isinstance(bias_variance, nnx.Variable):
             self.bias_variance = bias_variance
```
{gpjax-0.11.0 → gpjax-0.11.2}/gpjax/kernels/nonstationary/linear.py

```diff
@@ -23,7 +23,7 @@ from gpjax.kernels.computations import (
     AbstractKernelComputation,
     DenseKernelComputation,
 )
-from gpjax.parameters import
+from gpjax.parameters import NonNegativeReal
 from gpjax.typing import (
     Array,
     ScalarArray,
@@ -64,9 +64,9 @@ class Linear(AbstractKernel):
         if isinstance(variance, nnx.Variable):
             self.variance = variance
         else:
-            self.variance =
+            self.variance = NonNegativeReal(variance)
         if tp.TYPE_CHECKING:
-            self.variance = tp.cast(
+            self.variance = tp.cast(NonNegativeReal[ScalarArray], self.variance)
 
     def __call__(
         self,
```
{gpjax-0.11.0 → gpjax-0.11.2}/gpjax/kernels/nonstationary/polynomial.py

```diff
@@ -23,7 +23,10 @@ from gpjax.kernels.computations import (
     AbstractKernelComputation,
     DenseKernelComputation,
 )
-from gpjax.parameters import
+from gpjax.parameters import (
+    NonNegativeReal,
+    PositiveReal,
+)
 from gpjax.typing import (
     Array,
     ScalarArray,
@@ -76,9 +79,9 @@ class Polynomial(AbstractKernel):
         if isinstance(variance, nnx.Variable):
             self.variance = variance
         else:
-            self.variance =
+            self.variance = NonNegativeReal(variance)
         if tp.TYPE_CHECKING:
-            self.variance = tp.cast(
+            self.variance = tp.cast(NonNegativeReal[ScalarArray], self.variance)
 
         self.name = f"Polynomial (degree {self.degree})"
 
```
{gpjax-0.11.0 → gpjax-0.11.2}/gpjax/kernels/stationary/base.py

```diff
@@ -25,7 +25,10 @@ from gpjax.kernels.computations import (
     AbstractKernelComputation,
     DenseKernelComputation,
 )
-from gpjax.parameters import
+from gpjax.parameters import (
+    NonNegativeReal,
+    PositiveReal,
+)
 from gpjax.typing import (
     Array,
     ScalarArray,
@@ -85,11 +88,11 @@ class StationaryKernel(AbstractKernel):
         if isinstance(variance, nnx.Variable):
             self.variance = variance
         else:
-            self.variance =
+            self.variance = NonNegativeReal(variance)
 
         # static typing
         if tp.TYPE_CHECKING:
-            self.variance = tp.cast(
+            self.variance = tp.cast(NonNegativeReal[ScalarFloat], self.variance)
 
     @property
     def spectral_density(self) -> npd.Normal | npd.StudentT:
```
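The four kernel hunks (`arccosine.py`, `linear.py`, `polynomial.py`, `stationary/base.py`) all make the same substitution: a raw `variance` is wrapped in `NonNegativeReal` instead of a strictly positive parameter type, so a variance of exactly zero is now constructible. A sketch under that assumption:

```python
import jax.numpy as jnp

from gpjax.kernels import RBF, Linear
from gpjax.parameters import NonNegativeReal

# A zero variance now passes validation at construction time.
k = RBF(variance=0.0)
assert isinstance(k.variance, NonNegativeReal)

# Pre-wrapped nnx.Variable parameters are still passed through untouched.
k2 = Linear(variance=NonNegativeReal(jnp.asarray(1.0)))
assert isinstance(k2.variance, NonNegativeReal)
```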
{gpjax-0.11.0 → gpjax-0.11.2}/gpjax/likelihoods.py

```diff
@@ -28,7 +28,7 @@ from gpjax.integrators import (
     GHQuadratureIntegrator,
 )
 from gpjax.parameters import (
-
+    NonNegativeReal,
     Static,
 )
 from gpjax.typing import (
@@ -134,7 +134,7 @@ class Gaussian(AbstractLikelihood):
         self,
         num_datapoints: int,
         obs_stddev: tp.Union[
-            ScalarFloat, Float[Array, "#N"],
+            ScalarFloat, Float[Array, "#N"], NonNegativeReal, Static
         ] = 1.0,
         integrator: AbstractIntegrator = AnalyticalGaussianIntegrator(),
     ):
@@ -148,8 +148,8 @@ class Gaussian(AbstractLikelihood):
             likelihoods. Must be an instance of `AbstractIntegrator`. For the Gaussian likelihood, this defaults to
             the `AnalyticalGaussianIntegrator`, as the expected log likelihood can be computed analytically.
         """
-        if not isinstance(obs_stddev, (
-            obs_stddev =
+        if not isinstance(obs_stddev, (NonNegativeReal, Static)):
+            obs_stddev = NonNegativeReal(jnp.asarray(obs_stddev))
         self.obs_stddev = obs_stddev
 
         super().__init__(num_datapoints, integrator)
```
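For the Gaussian likelihood, `obs_stddev` now accepts pre-wrapped `NonNegativeReal` or `Static` parameters and coerces anything else through `jnp.asarray` into a `NonNegativeReal`. A sketch of the call patterns this enables; `Static` marking a value as non-trainable is an assumption carried over from its use elsewhere in GPJax:

```python
import jax.numpy as jnp

from gpjax.likelihoods import Gaussian
from gpjax.parameters import NonNegativeReal, Static

# A raw value is coerced to NonNegativeReal(jnp.asarray(0.5)).
lik = Gaussian(num_datapoints=100, obs_stddev=0.5)
assert isinstance(lik.obs_stddev, NonNegativeReal)

# A pre-wrapped Static value is kept as-is (e.g. a fixed noise level).
lik_fixed = Gaussian(num_datapoints=100, obs_stddev=Static(jnp.asarray(0.5)))

# Zero observation noise is representable under the non-negative constraint.
lik_noiseless = Gaussian(num_datapoints=100, obs_stddev=jnp.asarray(0.0))
```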
{gpjax-0.11.0 → gpjax-0.11.2}/gpjax/mean_functions.py

```diff
@@ -207,5 +207,5 @@ SumMeanFunction = ft.partial(
     CombinationMeanFunction, operator=ft.partial(jnp.sum, axis=0)
 )
 ProductMeanFunction = ft.partial(
-    CombinationMeanFunction, operator=ft.partial(jnp.
+    CombinationMeanFunction, operator=ft.partial(jnp.prod, axis=0)
 )
```
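This one-line fix sets the `ProductMeanFunction` operator to `jnp.prod` over the stacked constituent means (the truncated old line shows only that a different `jnp.` reduction was used before). A hedged sketch of the resulting behaviour; the `means=` keyword and the `Constant` signature are assumptions about the surrounding API, not shown in the diff:

```python
import jax.numpy as jnp

from gpjax.mean_functions import Constant, ProductMeanFunction

# Hypothetical combination of two constant means, 2.0 and 3.0.
product = ProductMeanFunction(means=[Constant(2.0), Constant(3.0)])

x = jnp.zeros((1, 1))
print(product(x))  # expected [[6.0]]: jnp.prod multiplies the stacked means
```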
{gpjax-0.11.0 → gpjax-0.11.2}/gpjax/parameters.py

```diff
@@ -82,6 +82,14 @@ class Parameter(nnx.Variable[T]):
         self._tag = tag
 
 
+class NonNegativeReal(Parameter[T]):
+    """Parameter that is non-negative."""
+
+    def __init__(self, value: T, tag: ParameterTag = "non_negative", **kwargs):
+        super().__init__(value=value, tag=tag, **kwargs)
+        _safe_assert(_check_is_non_negative, self.value)
+
+
 class PositiveReal(Parameter[T]):
     """Parameter that is strictly positive."""
 
@@ -143,6 +151,7 @@ class LowerTriangular(Parameter[T]):
 
 DEFAULT_BIJECTION = {
     "positive": npt.SoftplusTransform(),
+    "non_negative": npt.SoftplusTransform(),
     "real": npt.IdentityTransform(),
     "sigmoid": npt.SigmoidTransform(),
     "lower_triangular": FillTriangularTransform(),
@@ -164,6 +173,13 @@ def _check_is_arraylike(value: T) -> None:
     )
 
 
+@checkify.checkify
+def _check_is_non_negative(value):
+    checkify.check(
+        jnp.all(value >= 0), "value needs to be non-negative, got {value}", value=value
+    )
+
+
 @checkify.checkify
 def _check_is_positive(value):
     checkify.check(
```
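Taken together, the three `parameters.py` hunks add a `NonNegativeReal` parameter type, register its `"non_negative"` tag against the same Softplus bijector used for `"positive"`, and add the matching `checkify` predicate. A small sketch of the intended semantics (the concrete exception type raised on failure depends on `_safe_assert`, which is outside this diff):

```python
import jax.numpy as jnp

from gpjax.parameters import DEFAULT_BIJECTION, NonNegativeReal

# Zero is a valid value for the new parameter type...
v = NonNegativeReal(jnp.asarray(0.0))
print(v.value)  # 0.0

# ...while negative values fail the new checkify predicate at construction.
try:
    NonNegativeReal(jnp.asarray(-1.0))
except Exception as err:  # concrete error type depends on _safe_assert
    print(type(err).__name__)

# "non_negative" reuses the Softplus bijector, so training in unconstrained
# space behaves exactly as it does for "positive" parameters.
print(type(DEFAULT_BIJECTION["non_negative"]).__name__)  # SoftplusTransform
```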
{gpjax-0.11.0 → gpjax-0.11.2}/tests/test_fit.py

```diff
@@ -13,6 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 
+from beartype.typing import Any
 from flax import nnx
 import jax.numpy as jnp
 import jax.random as jr
@@ -26,7 +27,15 @@ import scipy
 
 from gpjax.dataset import Dataset
 from gpjax.fit import (
+    _check_batch_size,
+    _check_log_rate,
+    _check_model,
+    _check_num_iters,
+    _check_optim,
+    _check_train_data,
+    _check_verbose,
     fit,
+    fit_lbfgs,
     fit_scipy,
     get_batch,
 )
@@ -141,6 +150,46 @@ def test_fit_scipy_simple():
     assert trained_model.bias.value == 1.0
 
 
+def test_fit_lbfgs_simple():
+    # Create dataset:
+    X = jnp.linspace(0.0, 10.0, 100).reshape(-1, 1)
+    y = 2.0 * X + 1.0 + 10 * jr.normal(jr.PRNGKey(0), X.shape).reshape(-1, 1)
+    D = Dataset(X, y)
+
+    # Define linear model:
+    class LinearModel(nnx.Module):
+        def __init__(self, weight: float, bias: float):
+            self.weight = PositiveReal(weight)
+            self.bias = Static(bias)
+
+        def __call__(self, x):
+            return self.weight.value * x + self.bias.value
+
+    model = LinearModel(weight=1.0, bias=1.0)
+
+    # Define loss function:
+    def mse(model, data):
+        pred = model(data.X)
+        return jnp.mean((pred - data.y) ** 2)
+
+    # Train with bfgs!
+    trained_model, final_loss = fit_lbfgs(
+        model=model,
+        objective=mse,
+        train_data=D,
+        max_iters=10,
+    )
+
+    # Ensure we return a model of the same class
+    assert isinstance(trained_model, LinearModel)
+
+    # Test reduction in loss:
+    assert mse(trained_model, D) < mse(model, D)
+
+    # Test stop_gradient on bias:
+    assert trained_model.bias.value == 1.0
+
+
 @pytest.mark.parametrize("n_data", [20])
 @pytest.mark.parametrize("verbose", [True, False])
 def test_fit_gp_regression(n_data: int, verbose: bool) -> None:
@@ -179,8 +228,7 @@ def test_fit_gp_regression(n_data: int, verbose: bool) -> None:
 
 
 @pytest.mark.parametrize("n_data", [20])
-
-def test_fit_scipy_gp_regression(n_data: int, verbose: bool) -> None:
+def test_fit_lbfgs_gp_regression(n_data: int) -> None:
     # Create dataset:
     key = jr.PRNGKey(123)
     x = jnp.sort(
@@ -195,20 +243,16 @@ def test_fit_lbfgs_gp_regression(n_data: int) -> None:
     posterior = prior * likelihood
 
     # Train with BFGS!
-    trained_model_bfgs,
+    trained_model_bfgs, final_loss = fit_lbfgs(
         model=posterior,
         objective=conjugate_mll,
         train_data=D,
         max_iters=40,
-        verbose=verbose,
     )
 
     # Ensure the trained model is a Gaussian process posterior
     assert isinstance(trained_model_bfgs, ConjugatePosterior)
 
-    # Ensure we return a history_bfgs of the correct length
-    assert len(history_bfgs) > 2
-
     # Ensure we reduce the loss
     assert conjugate_mll(trained_model_bfgs, D) < conjugate_mll(posterior, D)
 
@@ -324,3 +368,142 @@ def test_get_batch(n_data: int, n_dim: int, batch_size: int):
     assert New.y.shape[1:] == y.shape[1:]
     assert jnp.sum(New.X == B.X) <= n_dim * batch_size / n_data
     assert jnp.sum(New.y == B.y) <= n_dim * batch_size / n_data
+
+
+@pytest.fixture
+def valid_model() -> nnx.Module:
+    """Return a valid model for testing."""
+
+    class LinearModel(nnx.Module):
+        def __init__(self, weight: float, bias: float) -> None:
+            self.weight = PositiveReal(weight)
+            self.bias = Static(bias)
+
+        def __call__(self, x: Any) -> Any:
+            return self.weight.value * x + self.bias.value
+
+    return LinearModel(weight=1.0, bias=1.0)
+
+
+@pytest.fixture
+def valid_dataset() -> Dataset:
+    """Return a valid dataset for testing."""
+    X = jnp.array([[1.0], [2.0], [3.0]])
+    y = jnp.array([[1.0], [2.0], [3.0]])
+    return Dataset(X=X, y=y)
+
+
+def test_check_model_valid(valid_model: nnx.Module) -> None:
+    """Test that a valid model passes validation."""
+    _check_model(valid_model)
+
+
+def test_check_model_invalid() -> None:
+    """Test that an invalid model raises a TypeError."""
+    model = "not a model"
+    with pytest.raises(
+        TypeError, match="Expected model to be a subclass of nnx.Module"
+    ):
+        _check_model(model)
+
+
+def test_check_train_data_valid(valid_dataset: Dataset) -> None:
+    """Test that valid training data passes validation."""
+    _check_train_data(valid_dataset)
+
+
+def test_check_train_data_invalid() -> None:
+    """Test that invalid training data raises a TypeError."""
+    train_data = "not a dataset"
+    with pytest.raises(
+        TypeError, match="Expected train_data to be of type gpjax.Dataset"
+    ):
+        _check_train_data(train_data)
+
+
+def test_check_optim_valid() -> None:
+    """Test that a valid optimiser passes validation."""
+    optim = ox.sgd(0.1)
+    _check_optim(optim)
+
+
+def test_check_optim_invalid() -> None:
+    """Test that an invalid optimiser raises a TypeError."""
+    optim = "not an optimiser"
+    with pytest.raises(
+        TypeError, match="Expected optim to be of type optax.GradientTransformation"
+    ):
+        _check_optim(optim)
+
+
+@pytest.mark.parametrize("num_iters", [1, 10, 100])
+def test_check_num_iters_valid(num_iters: int) -> None:
+    """Test that valid number of iterations passes validation."""
+    _check_num_iters(num_iters)
+
+
+def test_check_num_iters_invalid_type() -> None:
+    """Test that an invalid num_iters type raises a TypeError."""
+    num_iters = "not an int"
+    with pytest.raises(TypeError, match="Expected num_iters to be of type int"):
+        _check_num_iters(num_iters)
+
+
+@pytest.mark.parametrize("num_iters", [0, -5])
+def test_check_num_iters_invalid_value(num_iters: int) -> None:
+    """Test that an invalid num_iters value raises a ValueError."""
+    with pytest.raises(ValueError, match="Expected num_iters to be positive"):
+        _check_num_iters(num_iters)
+
+
+@pytest.mark.parametrize("log_rate", [1, 10, 100])
+def test_check_log_rate_valid(log_rate: int) -> None:
+    """Test that a valid log rate passes validation."""
+    _check_log_rate(log_rate)
+
+
+def test_check_log_rate_invalid_type() -> None:
+    """Test that an invalid log_rate type raises a TypeError."""
+    log_rate = "not an int"
+    with pytest.raises(TypeError, match="Expected log_rate to be of type int"):
+        _check_log_rate(log_rate)
+
+
+@pytest.mark.parametrize("log_rate", [0, -5])
+def test_check_log_rate_invalid_value(log_rate: int) -> None:
+    """Test that an invalid log_rate value raises a ValueError."""
+    with pytest.raises(ValueError, match="Expected log_rate to be positive"):
+        _check_log_rate(log_rate)
+
+
+@pytest.mark.parametrize("verbose", [True, False])
+def test_check_verbose_valid(verbose: bool) -> None:
+    """Test that valid verbose values pass validation."""
+    _check_verbose(verbose)
+
+
+def test_check_verbose_invalid() -> None:
+    """Test that an invalid verbose value raises a TypeError."""
+    verbose = "not a bool"
+    with pytest.raises(TypeError, match="Expected verbose to be of type bool"):
+        _check_verbose(verbose)
+
+
+@pytest.mark.parametrize("batch_size", [1, 10, 100, -1])
+def test_check_batch_size_valid(batch_size: int) -> None:
+    """Test that valid batch sizes pass validation."""
+    _check_batch_size(batch_size)
+
+
+def test_check_batch_size_invalid_type() -> None:
+    """Test that an invalid batch_size type raises a TypeError."""
+    batch_size = "not an int"
+    with pytest.raises(TypeError, match="Expected batch_size to be of type int"):
+        _check_batch_size(batch_size)
+
+
+@pytest.mark.parametrize("batch_size", [0, -2, -5])
+def test_check_batch_size_invalid_value(batch_size: int) -> None:
+    """Test that invalid batch_size values raise a ValueError."""
+    with pytest.raises(ValueError, match="Expected batch_size to be positive or -1"):
+        _check_batch_size(batch_size)
```
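The new tests pin down the behaviour of fit's private `_check_*` validators, which `fit_lbfgs` also runs when `safe=True`. A sketch of the fail-fast behaviour these tests assert, reusing the tests' own error messages:

```python
from flax import nnx
import pytest

import gpjax as gpx
from gpjax.parameters import PositiveReal


class TinyModel(nnx.Module):
    def __init__(self) -> None:
        self.w = PositiveReal(1.0)


# With safe=True (the default), a malformed dataset is rejected before any
# optimisation work happens.
with pytest.raises(TypeError, match="Expected train_data to be of type gpjax.Dataset"):
    gpx.fit_lbfgs(
        model=TinyModel(),
        objective=lambda model, data: 0.0,
        train_data="not a dataset",
        max_iters=1,
    )
```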
{gpjax-0.11.0 → gpjax-0.11.2}/tests/test_kernels/test_nonstationary.py

```diff
@@ -31,7 +31,7 @@ from gpjax.kernels.nonstationary import (
     Polynomial,
 )
 from gpjax.parameters import (
-
+    NonNegativeReal,
     Static,
 )
 
@@ -96,8 +96,8 @@ def test_init_override_paramtype(kernel_request):
             continue
         new_params[param] = Static(value)
 
-    k = kernel(**new_params, variance=
-    assert isinstance(k.variance,
+    k = kernel(**new_params, variance=NonNegativeReal(variance))
+    assert isinstance(k.variance, NonNegativeReal)
 
     for param in params.keys():
         if param in ("degree", "order"):
@@ -112,7 +112,7 @@ def test_init_defaults(kernel: type[AbstractKernel]):
 
     # Check that the parameters are set correctly
     assert isinstance(k.compute_engine, type(AbstractKernelComputation()))
-    assert isinstance(k.variance,
+    assert isinstance(k.variance, NonNegativeReal)
 
 
 @pytest.mark.parametrize("kernel", [k[0] for k in TESTED_KERNELS])
@@ -122,7 +122,7 @@ def test_init_variances(kernel: type[AbstractKernel], variance):
     k = kernel(variance=variance)
 
     # Check that the parameters are set correctly
-    assert isinstance(k.variance,
+    assert isinstance(k.variance, NonNegativeReal)
     assert jnp.allclose(k.variance.value, jnp.asarray(variance))
 
     # Check that error is raised if variance is not valid
```
{gpjax-0.11.0 → gpjax-0.11.2}/tests/test_kernels/test_stationary.py

```diff
@@ -35,6 +35,7 @@ from gpjax.kernels.stationary import (
 )
 from gpjax.kernels.stationary.base import StationaryKernel
 from gpjax.parameters import (
+    NonNegativeReal,
     PositiveReal,
     Static,
 )
@@ -106,12 +107,12 @@ def test_init_override_paramtype(kernel_request):
     for param, value in params.items():
         new_params[param] = Static(value)
 
-    kwargs = {**new_params, "variance":
+    kwargs = {**new_params, "variance": NonNegativeReal(variance)}
     if kernel != White:
         kwargs["lengthscale"] = PositiveReal(lengthscale)
 
     k = kernel(**kwargs)
-    assert isinstance(k.variance,
+    assert isinstance(k.variance, NonNegativeReal)
 
     for param in params.keys():
         assert isinstance(getattr(k, param), Static)
@@ -124,7 +125,7 @@ def test_init_defaults(kernel: type[StationaryKernel]):
 
     # Check that the parameters are set correctly
     assert isinstance(k.compute_engine, type(AbstractKernelComputation()))
-    assert isinstance(k.variance,
+    assert isinstance(k.variance, NonNegativeReal)
     assert isinstance(k.lengthscale, PositiveReal)
 
 
@@ -167,7 +168,7 @@ def test_init_variances(kernel: type[StationaryKernel], variance):
     k = kernel(variance=variance)
 
     # Check that the parameters are set correctly
-    assert isinstance(k.variance,
+    assert isinstance(k.variance, NonNegativeReal)
     assert jnp.allclose(k.variance.value, jnp.asarray(variance))
 
     # Check that error is raised if variance is not valid
```