gpjax 0.9.1__tar.gz → 0.9.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {gpjax-0.9.1 → gpjax-0.9.2}/PKG-INFO +1 -1
- {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/__init__.py +1 -1
- {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/gps.py +8 -1
- {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_gps.py +13 -6
- gpjax-0.9.1/.github/workflows/labeler.yml +0 -18
- {gpjax-0.9.1 → gpjax-0.9.2}/.github/CODE_OF_CONDUCT.md +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/.github/ISSUE_TEMPLATE/01_BUG_REPORT.md +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/.github/ISSUE_TEMPLATE/02_FEATURE_REQUEST.md +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/.github/ISSUE_TEMPLATE/03_CODEBASE_IMPROVEMENT.md +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/.github/ISSUE_TEMPLATE/04_DOCS_IMPROVEMENT.md +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/.github/ISSUE_TEMPLATE/config.yml +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/.github/codecov.yml +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/.github/labels.yml +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/.github/pull_request_template.md +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/.github/release-drafter.yml +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/.github/workflows/build_docs.yml +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/.github/workflows/integration.yml +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/.github/workflows/pr_greeting.yml +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/.github/workflows/ruff.yml +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/.github/workflows/stale_prs.yml +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/.github/workflows/test_docs.yml +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/.github/workflows/tests.yml +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/.gitignore +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/CITATION.bib +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/LICENSE +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/Makefile +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/README.md +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/docs/CODE_OF_CONDUCT.md +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/docs/GOVERNANCE.md +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/docs/contributing.md +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/docs/design.md +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/docs/index.md +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/docs/index.rst +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/docs/installation.md +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/docs/javascripts/katex.js +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/docs/refs.bib +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/docs/scripts/gen_examples.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/docs/scripts/gen_pages.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/docs/scripts/notebook_converter.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/docs/scripts/sharp_bits_figure.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/docs/sharp_bits.md +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/docs/static/GP.pdf +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/docs/static/GP.svg +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/docs/static/bijector_figure.svg +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/docs/static/css/gpjax_theme.css +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/docs/static/favicon.ico +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/docs/static/gpjax.mplstyle +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/docs/static/gpjax_logo.pdf +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/docs/static/gpjax_logo.svg +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/docs/static/jaxkern/lato.ttf +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/docs/static/jaxkern/logo.png +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/docs/static/jaxkern/logo.svg +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/docs/static/jaxkern/main.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/docs/static/step_size_figure.png +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/docs/static/step_size_figure.svg +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/docs/stylesheets/extra.css +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/docs/stylesheets/permalinks.css +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/examples/backend.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/examples/barycentres/barycentre_gp.gif +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/examples/barycentres.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/examples/bayesian_optimisation.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/examples/classification.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/examples/collapsed_vi.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/examples/constructing_new_kernels.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/examples/data/max_tempeature_switzerland.csv +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/examples/data/yacht_hydrodynamics.data +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/examples/decision_making.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/examples/deep_kernels.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/examples/gpjax.mplstyle +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/examples/graph_kernels.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/examples/intro_to_gps/decomposed_mll.png +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/examples/intro_to_gps/generating_process.png +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/examples/intro_to_gps.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/examples/intro_to_kernels.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/examples/likelihoods_guide.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/examples/oceanmodelling.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/examples/poisson.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/examples/regression.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/examples/uncollapsed_vi.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/examples/utils.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/examples/yacht.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/citation.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/dataset.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/decision_making/__init__.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/decision_making/decision_maker.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/decision_making/posterior_handler.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/decision_making/search_space.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/decision_making/test_functions/__init__.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/decision_making/test_functions/continuous_functions.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/decision_making/test_functions/non_conjugate_functions.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/decision_making/utility_functions/__init__.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/decision_making/utility_functions/base.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/decision_making/utility_functions/expected_improvement.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/decision_making/utility_functions/probability_of_improvement.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/decision_making/utility_functions/thompson_sampling.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/decision_making/utility_maximizer.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/decision_making/utils.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/distributions.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/fit.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/integrators.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/kernels/__init__.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/kernels/approximations/__init__.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/kernels/approximations/rff.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/kernels/base.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/kernels/computations/__init__.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/kernels/computations/base.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/kernels/computations/basis_functions.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/kernels/computations/constant_diagonal.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/kernels/computations/dense.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/kernels/computations/diagonal.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/kernels/computations/eigen.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/kernels/non_euclidean/__init__.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/kernels/non_euclidean/graph.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/kernels/non_euclidean/utils.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/kernels/nonstationary/__init__.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/kernels/nonstationary/arccosine.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/kernels/nonstationary/linear.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/kernels/nonstationary/polynomial.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/kernels/stationary/__init__.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/kernels/stationary/base.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/kernels/stationary/matern12.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/kernels/stationary/matern32.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/kernels/stationary/matern52.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/kernels/stationary/periodic.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/kernels/stationary/powered_exponential.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/kernels/stationary/rational_quadratic.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/kernels/stationary/rbf.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/kernels/stationary/utils.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/kernels/stationary/white.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/likelihoods.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/lower_cholesky.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/mean_functions.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/objectives.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/parameters.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/scan.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/typing.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/gpjax/variational_families.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/mkdocs.yml +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/pyproject.toml +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/static/CONTRIBUTING.md +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/static/paper.bib +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/static/paper.md +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/static/paper.pdf +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/tests/__init__.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/tests/conftest.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/tests/integration_tests.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_citations.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_dataset.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_decision_making/__init__.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_decision_making/test_decision_maker.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_decision_making/test_posterior_handler.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_decision_making/test_search_space.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_decision_making/test_test_functions/__init__.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_decision_making/test_test_functions/test_continuous_functions.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_decision_making/test_test_functions/test_non_conjugate_functions.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_decision_making/test_utility_functions/__init__.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_decision_making/test_utility_functions/test_base.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_decision_making/test_utility_functions/test_expected_improvement.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_decision_making/test_utility_functions/test_probability_of_improvement.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_decision_making/test_utility_functions/test_thompson_sampling.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_decision_making/test_utility_functions/test_utility_functions.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_decision_making/test_utility_maximizer.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_decision_making/test_utils.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_decision_making/utils.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_fit.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_gaussian_distribution.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_integrators.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_kernels/__init__.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_kernels/test_approximations.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_kernels/test_base.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_kernels/test_computation.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_kernels/test_non_euclidean.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_kernels/test_nonstationary.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_kernels/test_stationary.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_kernels/test_utils.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_likelihoods.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_lower_cholesky.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_markdown.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_mean_functions.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_objectives.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_parameters.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.2}/tests/test_variational_families.py +0 -0
|
@@ -40,7 +40,7 @@ __license__ = "MIT"
|
|
|
40
40
|
__description__ = "Didactic Gaussian processes in JAX"
|
|
41
41
|
__url__ = "https://github.com/JaxGaussianProcesses/GPJax"
|
|
42
42
|
__contributors__ = "https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors"
|
|
43
|
-
__version__ = "0.9.1"
|
|
43
|
+
__version__ = "0.9.2"
|
|
44
44
|
|
|
45
45
|
__all__ = [
|
|
46
46
|
"base",
|
|
@@ -17,6 +17,7 @@ from abc import abstractmethod
|
|
|
17
17
|
|
|
18
18
|
import beartype.typing as tp
|
|
19
19
|
from cola.annotations import PSD
|
|
20
|
+
from cola.linalg.algorithm_base import Algorithm
|
|
20
21
|
from cola.linalg.decompositions.decompositions import Cholesky
|
|
21
22
|
from cola.linalg.inverse.inv import solve
|
|
22
23
|
from cola.ops.operators import I_like
|
|
@@ -530,6 +531,7 @@ class ConjugatePosterior(AbstractPosterior[P, GL]):
|
|
|
530
531
|
train_data: Dataset,
|
|
531
532
|
key: KeyArray,
|
|
532
533
|
num_features: int | None = 100,
|
|
534
|
+
solver_algorithm: tp.Optional[Algorithm] = Cholesky(),
|
|
533
535
|
) -> FunctionalSample:
|
|
534
536
|
r"""Draw approximate samples from the Gaussian process posterior.
|
|
535
537
|
|
|
@@ -563,6 +565,11 @@ class ConjugatePosterior(AbstractPosterior[P, GL]):
|
|
|
563
565
|
key (KeyArray): The random seed used for the sample(s).
|
|
564
566
|
num_features (int): The number of features used when approximating the
|
|
565
567
|
kernel.
|
|
568
|
+
solver_algorithm (Optional[Algorithm], optional): The algorithm to use for the solves of
|
|
569
|
+
the inverse of the covariance matrix. See the
|
|
570
|
+
[CoLA documentation](https://cola.readthedocs.io/en/latest/package/cola.linalg.html#algorithms)
|
|
571
|
+
for which solver to pick. For PSD matrices, CoLA currently recommends Cholesky() for small
|
|
572
|
+
matrices and CG() for larger matrices. Select Auto() to let CoLA decide. Defaults to Cholesky().
|
|
566
573
|
|
|
567
574
|
Returns:
|
|
568
575
|
FunctionalSample: A function representing an approximate sample from the Gaussian
|
|
@@ -588,7 +595,7 @@ class ConjugatePosterior(AbstractPosterior[P, GL]):
|
|
|
588
595
|
canonical_weights = solve(
|
|
589
596
|
Sigma,
|
|
590
597
|
y + eps - jnp.inner(Phi, fourier_weights),
|
|
591
|
-
|
|
598
|
+
solver_algorithm,
|
|
592
599
|
) # [N, B]
|
|
593
600
|
|
|
594
601
|
def sample_fn(test_inputs: Float[Array, "n D"]) -> Float[Array, "n B"]:
|
|
@@ -25,13 +25,15 @@ from typing import (
|
|
|
25
25
|
Type,
|
|
26
26
|
)
|
|
27
27
|
|
|
28
|
+
from cola.linalg.algorithm_base import Auto
|
|
29
|
+
from cola.linalg.decompositions.decompositions import Cholesky
|
|
30
|
+
from cola.linalg.inverse.cg import CG
|
|
28
31
|
from jax import config
|
|
29
32
|
import jax.numpy as jnp
|
|
30
33
|
import jax.random as jr
|
|
31
34
|
import pytest
|
|
32
35
|
import tensorflow_probability.substrates.jax.distributions as tfd
|
|
33
36
|
|
|
34
|
-
# from gpjax.dataset import Dataset
|
|
35
37
|
from gpjax.dataset import Dataset
|
|
36
38
|
from gpjax.distributions import GaussianDistribution
|
|
37
39
|
from gpjax.gps import (
|
|
@@ -283,7 +285,10 @@ def test_prior_sample_approx(num_datapoints, kernel, mean_function):
|
|
|
283
285
|
@pytest.mark.parametrize("num_datapoints", [1, 5])
|
|
284
286
|
@pytest.mark.parametrize("kernel", [RBF, Matern52])
|
|
285
287
|
@pytest.mark.parametrize("mean_function", [Zero, Constant])
|
|
286
|
-
|
|
288
|
+
@pytest.mark.parametrize("solver_algorithm", [Cholesky(), CG(), Auto()])
|
|
289
|
+
def test_conjugate_posterior_sample_approx(
|
|
290
|
+
num_datapoints, kernel, mean_function, solver_algorithm
|
|
291
|
+
):
|
|
287
292
|
kern = kernel(lengthscale=jnp.array([5.0, 1.0]), variance=0.1)
|
|
288
293
|
p = Prior(kernel=kern, mean_function=mean_function()) * Gaussian(
|
|
289
294
|
num_datapoints=num_datapoints
|
|
@@ -310,26 +315,28 @@ def test_conjugate_posterior_sample_approx(num_datapoints, kernel, mean_function
|
|
|
310
315
|
# with pytest.raises(ValidationErrors):
|
|
311
316
|
# p.sample_approx(1, D, key, 0.5)
|
|
312
317
|
|
|
313
|
-
sampled_fn = p.sample_approx(1, D, key, 100)
|
|
318
|
+
sampled_fn = p.sample_approx(1, D, key, 100, solver_algorithm=solver_algorithm)
|
|
314
319
|
assert isinstance(sampled_fn, Callable) # check type
|
|
315
320
|
|
|
316
321
|
x = jr.uniform(key=key, minval=-2.0, maxval=2.0, shape=(num_datapoints, 2))
|
|
317
322
|
evals = sampled_fn(x)
|
|
318
323
|
assert evals.shape == (num_datapoints, 1.0) # check shape
|
|
319
324
|
|
|
320
|
-
sampled_fn_2 = p.sample_approx(1, D, key, 100)
|
|
325
|
+
sampled_fn_2 = p.sample_approx(1, D, key, 100, solver_algorithm=solver_algorithm)
|
|
321
326
|
evals_2 = sampled_fn_2(x)
|
|
322
327
|
max_delta = jnp.max(jnp.abs(evals - evals_2))
|
|
323
328
|
assert max_delta == 0.0 # samples same for same seed
|
|
324
329
|
|
|
325
330
|
new_key = jr.key(12345)
|
|
326
|
-
sampled_fn_3 = p.sample_approx(
|
|
331
|
+
sampled_fn_3 = p.sample_approx(
|
|
332
|
+
1, D, new_key, 100, solver_algorithm=solver_algorithm
|
|
333
|
+
)
|
|
327
334
|
evals_3 = sampled_fn_3(x)
|
|
328
335
|
max_delta = jnp.max(jnp.abs(evals - evals_3))
|
|
329
336
|
assert max_delta > 0.01 # samples different for different seed
|
|
330
337
|
|
|
331
338
|
# Check validty of samples using Monte-Carlo
|
|
332
|
-
sampled_fn = p.sample_approx(10_000, D, key, 100)
|
|
339
|
+
sampled_fn = p.sample_approx(10_000, D, key, 100, solver_algorithm=solver_algorithm)
|
|
333
340
|
sampled_evals = sampled_fn(x)
|
|
334
341
|
approx_mean = jnp.mean(sampled_evals, -1)
|
|
335
342
|
approx_var = jnp.var(sampled_evals, -1)
|
|
@@ -1,18 +0,0 @@
|
|
|
1
|
-
name: Labeler
|
|
2
|
-
|
|
3
|
-
on:
|
|
4
|
-
push:
|
|
5
|
-
branches:
|
|
6
|
-
- main
|
|
7
|
-
|
|
8
|
-
jobs:
|
|
9
|
-
labeler:
|
|
10
|
-
runs-on: ubuntu-latest
|
|
11
|
-
steps:
|
|
12
|
-
- name: Check out the repository
|
|
13
|
-
uses: actions/checkout@v3.5.2
|
|
14
|
-
|
|
15
|
-
- name: Run Labeler
|
|
16
|
-
uses: crazy-max/ghaction-github-labeler@v4.1.0
|
|
17
|
-
with:
|
|
18
|
-
skip-delete: true
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{gpjax-0.9.1 → gpjax-0.9.2}/gpjax/decision_making/utility_functions/probability_of_improvement.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|