gpjax 0.9.1__tar.gz → 0.9.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {gpjax-0.9.1 → gpjax-0.9.3}/PKG-INFO +1 -1
- {gpjax-0.9.1 → gpjax-0.9.3}/examples/intro_to_kernels.py +1 -1
- {gpjax-0.9.1 → gpjax-0.9.3}/examples/poisson.py +11 -22
- {gpjax-0.9.1 → gpjax-0.9.3}/examples/uncollapsed_vi.py +1 -2
- {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/__init__.py +1 -1
- {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/gps.py +8 -1
- {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/likelihoods.py +3 -5
- {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/scan.py +10 -10
- {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/variational_families.py +9 -2
- {gpjax-0.9.1 → gpjax-0.9.3}/mkdocs.yml +2 -2
- gpjax-0.9.3/publish/gpjax-0.9.3-py3-none-any.whl +0 -0
- gpjax-0.9.3/publish/gpjax-0.9.3.tar.gz +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_gps.py +13 -6
- gpjax-0.9.1/.github/workflows/labeler.yml +0 -18
- {gpjax-0.9.1 → gpjax-0.9.3}/.github/CODE_OF_CONDUCT.md +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/.github/ISSUE_TEMPLATE/01_BUG_REPORT.md +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/.github/ISSUE_TEMPLATE/02_FEATURE_REQUEST.md +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/.github/ISSUE_TEMPLATE/03_CODEBASE_IMPROVEMENT.md +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/.github/ISSUE_TEMPLATE/04_DOCS_IMPROVEMENT.md +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/.github/ISSUE_TEMPLATE/config.yml +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/.github/codecov.yml +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/.github/labels.yml +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/.github/pull_request_template.md +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/.github/release-drafter.yml +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/.github/workflows/build_docs.yml +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/.github/workflows/integration.yml +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/.github/workflows/pr_greeting.yml +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/.github/workflows/ruff.yml +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/.github/workflows/stale_prs.yml +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/.github/workflows/test_docs.yml +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/.github/workflows/tests.yml +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/.gitignore +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/CITATION.bib +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/LICENSE +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/Makefile +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/README.md +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/docs/CODE_OF_CONDUCT.md +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/docs/GOVERNANCE.md +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/docs/contributing.md +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/docs/design.md +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/docs/index.md +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/docs/index.rst +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/docs/installation.md +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/docs/javascripts/katex.js +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/docs/refs.bib +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/docs/scripts/gen_examples.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/docs/scripts/gen_pages.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/docs/scripts/notebook_converter.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/docs/scripts/sharp_bits_figure.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/docs/sharp_bits.md +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/docs/static/GP.pdf +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/docs/static/GP.svg +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/docs/static/bijector_figure.svg +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/docs/static/css/gpjax_theme.css +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/docs/static/favicon.ico +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/docs/static/gpjax.mplstyle +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/docs/static/gpjax_logo.pdf +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/docs/static/gpjax_logo.svg +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/docs/static/jaxkern/lato.ttf +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/docs/static/jaxkern/logo.png +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/docs/static/jaxkern/logo.svg +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/docs/static/jaxkern/main.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/docs/static/step_size_figure.png +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/docs/static/step_size_figure.svg +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/docs/stylesheets/extra.css +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/docs/stylesheets/permalinks.css +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/examples/backend.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/examples/barycentres/barycentre_gp.gif +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/examples/barycentres.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/examples/bayesian_optimisation.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/examples/classification.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/examples/collapsed_vi.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/examples/constructing_new_kernels.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/examples/data/max_tempeature_switzerland.csv +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/examples/data/yacht_hydrodynamics.data +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/examples/decision_making.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/examples/deep_kernels.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/examples/gpjax.mplstyle +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/examples/graph_kernels.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/examples/intro_to_gps/decomposed_mll.png +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/examples/intro_to_gps/generating_process.png +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/examples/intro_to_gps.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/examples/likelihoods_guide.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/examples/oceanmodelling.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/examples/regression.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/examples/utils.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/examples/yacht.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/citation.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/dataset.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/decision_making/__init__.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/decision_making/decision_maker.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/decision_making/posterior_handler.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/decision_making/search_space.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/decision_making/test_functions/__init__.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/decision_making/test_functions/continuous_functions.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/decision_making/test_functions/non_conjugate_functions.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/decision_making/utility_functions/__init__.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/decision_making/utility_functions/base.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/decision_making/utility_functions/expected_improvement.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/decision_making/utility_functions/probability_of_improvement.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/decision_making/utility_functions/thompson_sampling.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/decision_making/utility_maximizer.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/decision_making/utils.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/distributions.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/fit.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/integrators.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/kernels/__init__.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/kernels/approximations/__init__.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/kernels/approximations/rff.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/kernels/base.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/kernels/computations/__init__.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/kernels/computations/base.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/kernels/computations/basis_functions.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/kernels/computations/constant_diagonal.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/kernels/computations/dense.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/kernels/computations/diagonal.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/kernels/computations/eigen.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/kernels/non_euclidean/__init__.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/kernels/non_euclidean/graph.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/kernels/non_euclidean/utils.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/kernels/nonstationary/__init__.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/kernels/nonstationary/arccosine.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/kernels/nonstationary/linear.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/kernels/nonstationary/polynomial.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/kernels/stationary/__init__.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/kernels/stationary/base.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/kernels/stationary/matern12.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/kernels/stationary/matern32.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/kernels/stationary/matern52.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/kernels/stationary/periodic.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/kernels/stationary/powered_exponential.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/kernels/stationary/rational_quadratic.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/kernels/stationary/rbf.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/kernels/stationary/utils.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/kernels/stationary/white.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/lower_cholesky.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/mean_functions.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/objectives.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/parameters.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/gpjax/typing.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/pyproject.toml +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/static/CONTRIBUTING.md +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/static/paper.bib +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/static/paper.md +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/static/paper.pdf +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/tests/__init__.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/tests/conftest.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/tests/integration_tests.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_citations.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_dataset.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_decision_making/__init__.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_decision_making/test_decision_maker.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_decision_making/test_posterior_handler.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_decision_making/test_search_space.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_decision_making/test_test_functions/__init__.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_decision_making/test_test_functions/test_continuous_functions.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_decision_making/test_test_functions/test_non_conjugate_functions.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_decision_making/test_utility_functions/__init__.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_decision_making/test_utility_functions/test_base.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_decision_making/test_utility_functions/test_expected_improvement.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_decision_making/test_utility_functions/test_probability_of_improvement.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_decision_making/test_utility_functions/test_thompson_sampling.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_decision_making/test_utility_functions/test_utility_functions.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_decision_making/test_utility_maximizer.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_decision_making/test_utils.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_decision_making/utils.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_fit.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_gaussian_distribution.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_integrators.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_kernels/__init__.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_kernels/test_approximations.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_kernels/test_base.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_kernels/test_computation.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_kernels/test_non_euclidean.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_kernels/test_nonstationary.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_kernels/test_stationary.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_kernels/test_utils.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_likelihoods.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_lower_cholesky.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_markdown.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_mean_functions.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_objectives.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_parameters.py +0 -0
- {gpjax-0.9.1 → gpjax-0.9.3}/tests/test_variational_families.py +0 -0
|
@@ -246,7 +246,7 @@ kernel = gpx.kernels.Matern52(
|
|
|
246
246
|
prior = gpx.gps.Prior(mean_function=mean, kernel=kernel)
|
|
247
247
|
|
|
248
248
|
likelihood = gpx.likelihoods.Gaussian(
|
|
249
|
-
num_datapoints=D.n,
|
|
249
|
+
num_datapoints=D.n, obs_stdev=Static(jnp.array(1e-3))
|
|
250
250
|
) # Our function is noise-free, so we set the observation noise's standard deviation to a very small value
|
|
251
251
|
|
|
252
252
|
no_opt_posterior = prior * likelihood
|
|
@@ -154,33 +154,22 @@ def logprob_fn(params):
|
|
|
154
154
|
return gpx.objectives.log_posterior_density(model, D)
|
|
155
155
|
|
|
156
156
|
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
157
|
+
step_size = 1e-3
|
|
158
|
+
inverse_mass_matrix = jnp.ones(53)
|
|
159
|
+
nuts = blackjax.nuts(logprob_fn, step_size, inverse_mass_matrix)
|
|
160
160
|
|
|
161
|
+
state = nuts.init(params)
|
|
161
162
|
|
|
162
|
-
|
|
163
|
-
blackjax.nuts, logprob_fn, num_adapt, target_acceptance_rate=0.65, progress_bar=True
|
|
164
|
-
)
|
|
165
|
-
|
|
166
|
-
# Initialise the chain
|
|
167
|
-
last_state, kernel, _ = adapt.run(key, params)
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
def inference_loop(rng_key, kernel, initial_state, num_samples):
|
|
171
|
-
def one_step(state, rng_key):
|
|
172
|
-
state, info = kernel(rng_key, state)
|
|
173
|
-
return state, (state, info)
|
|
174
|
-
|
|
175
|
-
keys = jax.random.split(rng_key, num_samples)
|
|
176
|
-
_, (states, infos) = jax.lax.scan(one_step, initial_state, keys, unroll=10)
|
|
163
|
+
step = jax.jit(nuts.step)
|
|
177
164
|
|
|
178
|
-
return states, infos
|
|
179
165
|
|
|
166
|
+
def one_step(state, rng_key):
|
|
167
|
+
state, info = step(rng_key, state)
|
|
168
|
+
return state, (state, info)
|
|
180
169
|
|
|
181
|
-
# Sample from the posterior distribution
|
|
182
|
-
states, infos = inference_loop(key, kernel, last_state, num_samples)
|
|
183
170
|
|
|
171
|
+
keys = jax.random.split(key, num_samples)
|
|
172
|
+
_, (states, infos) = jax.lax.scan(one_step, state, keys, unroll=10)
|
|
184
173
|
|
|
185
174
|
# %% [markdown]
|
|
186
175
|
# ### Sampler efficiency
|
|
@@ -190,7 +179,7 @@ states, infos = inference_loop(key, kernel, last_state, num_samples)
|
|
|
190
179
|
# proposed sample, divided by the total number of steps run by the chain).
|
|
191
180
|
|
|
192
181
|
# %%
|
|
193
|
-
acceptance_rate = jnp.mean(infos.
|
|
182
|
+
acceptance_rate = jnp.mean(infos.acceptance_rate)
|
|
194
183
|
print(f"Acceptance rate: {acceptance_rate:.2f}")
|
|
195
184
|
|
|
196
185
|
# %%
|
|
@@ -8,7 +8,7 @@
|
|
|
8
8
|
# extension: .py
|
|
9
9
|
# format_name: percent
|
|
10
10
|
# format_version: '1.3'
|
|
11
|
-
# jupytext_version: 1.
|
|
11
|
+
# jupytext_version: 1.11.2
|
|
12
12
|
# kernelspec:
|
|
13
13
|
# display_name: gpjax_beartype
|
|
14
14
|
# language: python
|
|
@@ -319,7 +319,6 @@ opt_rep, history = gpx.fit(
|
|
|
319
319
|
model=q,
|
|
320
320
|
objective=lambda p, d: -gpx.objectives.elbo(p, d),
|
|
321
321
|
train_data=D,
|
|
322
|
-
params_bijection=params_bijection,
|
|
323
322
|
optim=ox.adam(learning_rate=0.01),
|
|
324
323
|
num_iters=3000,
|
|
325
324
|
key=jr.key(42),
|
|
@@ -40,7 +40,7 @@ __license__ = "MIT"
|
|
|
40
40
|
__description__ = "Didactic Gaussian processes in JAX"
|
|
41
41
|
__url__ = "https://github.com/JaxGaussianProcesses/GPJax"
|
|
42
42
|
__contributors__ = "https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors"
|
|
43
|
-
__version__ = "0.9.
|
|
43
|
+
__version__ = "0.9.3"
|
|
44
44
|
|
|
45
45
|
__all__ = [
|
|
46
46
|
"base",
|
|
@@ -17,6 +17,7 @@ from abc import abstractmethod
|
|
|
17
17
|
|
|
18
18
|
import beartype.typing as tp
|
|
19
19
|
from cola.annotations import PSD
|
|
20
|
+
from cola.linalg.algorithm_base import Algorithm
|
|
20
21
|
from cola.linalg.decompositions.decompositions import Cholesky
|
|
21
22
|
from cola.linalg.inverse.inv import solve
|
|
22
23
|
from cola.ops.operators import I_like
|
|
@@ -530,6 +531,7 @@ class ConjugatePosterior(AbstractPosterior[P, GL]):
|
|
|
530
531
|
train_data: Dataset,
|
|
531
532
|
key: KeyArray,
|
|
532
533
|
num_features: int | None = 100,
|
|
534
|
+
solver_algorithm: tp.Optional[Algorithm] = Cholesky(),
|
|
533
535
|
) -> FunctionalSample:
|
|
534
536
|
r"""Draw approximate samples from the Gaussian process posterior.
|
|
535
537
|
|
|
@@ -563,6 +565,11 @@ class ConjugatePosterior(AbstractPosterior[P, GL]):
|
|
|
563
565
|
key (KeyArray): The random seed used for the sample(s).
|
|
564
566
|
num_features (int): The number of features used when approximating the
|
|
565
567
|
kernel.
|
|
568
|
+
solver_algorithm (Optional[Algorithm], optional): The algorithm to use for the solves of
|
|
569
|
+
the inverse of the covariance matrix. See the
|
|
570
|
+
[CoLA documentation](https://cola.readthedocs.io/en/latest/package/cola.linalg.html#algorithms)
|
|
571
|
+
for which solver to pick. For PSD matrices, CoLA currently recommends Cholesky() for small
|
|
572
|
+
matrices and CG() for larger matrices. Select Auto() to let CoLA decide. Defaults to Cholesky().
|
|
566
573
|
|
|
567
574
|
Returns:
|
|
568
575
|
FunctionalSample: A function representing an approximate sample from the Gaussian
|
|
@@ -588,7 +595,7 @@ class ConjugatePosterior(AbstractPosterior[P, GL]):
|
|
|
588
595
|
canonical_weights = solve(
|
|
589
596
|
Sigma,
|
|
590
597
|
y + eps - jnp.inner(Phi, fourier_weights),
|
|
591
|
-
|
|
598
|
+
solver_algorithm,
|
|
592
599
|
) # [N, B]
|
|
593
600
|
|
|
594
601
|
def sample_fn(test_inputs: Float[Array, "n D"]) -> Float[Array, "n B"]:
|
|
@@ -28,7 +28,6 @@ from gpjax.integrators import (
|
|
|
28
28
|
GHQuadratureIntegrator,
|
|
29
29
|
)
|
|
30
30
|
from gpjax.parameters import (
|
|
31
|
-
Parameter,
|
|
32
31
|
PositiveReal,
|
|
33
32
|
Static,
|
|
34
33
|
)
|
|
@@ -152,10 +151,9 @@ class Gaussian(AbstractLikelihood):
|
|
|
152
151
|
likelihoods. Must be an instance of `AbstractIntegrator`. For the Gaussian likelihood, this defaults to
|
|
153
152
|
the `AnalyticalGaussianIntegrator`, as the expected log likelihood can be computed analytically.
|
|
154
153
|
"""
|
|
155
|
-
if isinstance(obs_stddev,
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
self.obs_stddev = PositiveReal(jnp.asarray(obs_stddev))
|
|
154
|
+
if not isinstance(obs_stddev, (PositiveReal, Static)):
|
|
155
|
+
obs_stddev = PositiveReal(jnp.asarray(obs_stddev))
|
|
156
|
+
self.obs_stddev = obs_stddev
|
|
159
157
|
|
|
160
158
|
super().__init__(num_datapoints, integrator)
|
|
161
159
|
|
|
@@ -22,7 +22,6 @@ from beartype.typing import (
|
|
|
22
22
|
)
|
|
23
23
|
import jax
|
|
24
24
|
from jax import lax
|
|
25
|
-
from jax.experimental import host_callback as hcb
|
|
26
25
|
import jax.numpy as jnp
|
|
27
26
|
import jax.tree_util as jtu
|
|
28
27
|
from jaxtyping import (
|
|
@@ -54,7 +53,8 @@ def _callback(cond: ScalarBool, func: Callable, *args: Any) -> None:
|
|
|
54
53
|
|
|
55
54
|
def _do_callback(_) -> int:
|
|
56
55
|
"""Perform the callback."""
|
|
57
|
-
|
|
56
|
+
jax.debug.callback(func, *args)
|
|
57
|
+
return _dummy_result
|
|
58
58
|
|
|
59
59
|
def _not_callback(_) -> int:
|
|
60
60
|
"""Do nothing."""
|
|
@@ -113,19 +113,19 @@ def vscan(
|
|
|
113
113
|
_progress_bar = trange(_length)
|
|
114
114
|
_progress_bar.set_description("Compiling...", refresh=True)
|
|
115
115
|
|
|
116
|
-
def _set_running(args: Any
|
|
116
|
+
def _set_running(*args: Any) -> None:
|
|
117
117
|
"""Set the tqdm progress bar to running."""
|
|
118
118
|
_progress_bar.set_description("Running", refresh=False)
|
|
119
119
|
|
|
120
|
-
def _update_tqdm(args: Any
|
|
120
|
+
def _update_tqdm(*args: Any) -> None:
|
|
121
121
|
"""Update the tqdm progress bar with the latest objective value."""
|
|
122
122
|
_value, _iter_num = args
|
|
123
|
-
_progress_bar.update(_iter_num)
|
|
123
|
+
_progress_bar.update(_iter_num.item())
|
|
124
124
|
|
|
125
125
|
if log_value and _value is not None:
|
|
126
126
|
_progress_bar.set_postfix({"Value": f"{_value: .2f}"})
|
|
127
127
|
|
|
128
|
-
def _close_tqdm(args: Any
|
|
128
|
+
def _close_tqdm(*args: Any) -> None:
|
|
129
129
|
"""Close the tqdm progress bar."""
|
|
130
130
|
_progress_bar.close()
|
|
131
131
|
|
|
@@ -145,16 +145,16 @@ def vscan(
|
|
|
145
145
|
_is_last: bool = iter_num == _length - 1
|
|
146
146
|
|
|
147
147
|
# Update progress bar, if first of log_rate.
|
|
148
|
-
_callback(_is_first, _set_running
|
|
148
|
+
_callback(_is_first, _set_running)
|
|
149
149
|
|
|
150
150
|
# Update progress bar, if multiple of log_rate.
|
|
151
|
-
_callback(_is_multiple, _update_tqdm,
|
|
151
|
+
_callback(_is_multiple, _update_tqdm, y, log_rate)
|
|
152
152
|
|
|
153
153
|
# Update progress bar, if remainder.
|
|
154
|
-
_callback(_is_remainder, _update_tqdm,
|
|
154
|
+
_callback(_is_remainder, _update_tqdm, y, _remainder)
|
|
155
155
|
|
|
156
156
|
# Close progress bar, if last iteration.
|
|
157
|
-
_callback(_is_last, _close_tqdm
|
|
157
|
+
_callback(_is_last, _close_tqdm)
|
|
158
158
|
|
|
159
159
|
return carry, y
|
|
160
160
|
|
|
@@ -108,10 +108,17 @@ class AbstractVariationalGaussian(AbstractVariationalFamily[L]):
|
|
|
108
108
|
def __init__(
|
|
109
109
|
self,
|
|
110
110
|
posterior: AbstractPosterior[P, L],
|
|
111
|
-
inducing_inputs:
|
|
111
|
+
inducing_inputs: tp.Union[
|
|
112
|
+
Float[Array, "N D"],
|
|
113
|
+
Real,
|
|
114
|
+
Static,
|
|
115
|
+
],
|
|
112
116
|
jitter: ScalarFloat = 1e-6,
|
|
113
117
|
):
|
|
114
|
-
|
|
118
|
+
if not isinstance(inducing_inputs, (Real, Static)):
|
|
119
|
+
inducing_inputs = Real(inducing_inputs)
|
|
120
|
+
|
|
121
|
+
self.inducing_inputs = inducing_inputs
|
|
115
122
|
self.jitter = jitter
|
|
116
123
|
|
|
117
124
|
super().__init__(posterior)
|
|
@@ -24,8 +24,8 @@ nav:
|
|
|
24
24
|
- Barycentres: _examples/barycentres.md
|
|
25
25
|
- Deep kernel learning: _examples/deep_kernels.md
|
|
26
26
|
- Graph kernels: _examples/graph_kernels.md
|
|
27
|
-
- Sparse GPs: _examples/
|
|
28
|
-
- Stochastic sparse GPs: _examples/
|
|
27
|
+
- Sparse GPs: _examples/collapsed_vi.md
|
|
28
|
+
- Stochastic sparse GPs: _examples/uncollapsed_vi.md
|
|
29
29
|
- Bayesian Optimisation: _examples/bayesian_optimisation.md
|
|
30
30
|
- Decision Making: _examples/decision_making.md
|
|
31
31
|
- Multi-output GPs for Ocean Modelling: _examples/oceanmodelling.md
|
|
Binary file
|
|
Binary file
|
|
@@ -25,13 +25,15 @@ from typing import (
|
|
|
25
25
|
Type,
|
|
26
26
|
)
|
|
27
27
|
|
|
28
|
+
from cola.linalg.algorithm_base import Auto
|
|
29
|
+
from cola.linalg.decompositions.decompositions import Cholesky
|
|
30
|
+
from cola.linalg.inverse.cg import CG
|
|
28
31
|
from jax import config
|
|
29
32
|
import jax.numpy as jnp
|
|
30
33
|
import jax.random as jr
|
|
31
34
|
import pytest
|
|
32
35
|
import tensorflow_probability.substrates.jax.distributions as tfd
|
|
33
36
|
|
|
34
|
-
# from gpjax.dataset import Dataset
|
|
35
37
|
from gpjax.dataset import Dataset
|
|
36
38
|
from gpjax.distributions import GaussianDistribution
|
|
37
39
|
from gpjax.gps import (
|
|
@@ -283,7 +285,10 @@ def test_prior_sample_approx(num_datapoints, kernel, mean_function):
|
|
|
283
285
|
@pytest.mark.parametrize("num_datapoints", [1, 5])
|
|
284
286
|
@pytest.mark.parametrize("kernel", [RBF, Matern52])
|
|
285
287
|
@pytest.mark.parametrize("mean_function", [Zero, Constant])
|
|
286
|
-
|
|
288
|
+
@pytest.mark.parametrize("solver_algorithm", [Cholesky(), CG(), Auto()])
|
|
289
|
+
def test_conjugate_posterior_sample_approx(
|
|
290
|
+
num_datapoints, kernel, mean_function, solver_algorithm
|
|
291
|
+
):
|
|
287
292
|
kern = kernel(lengthscale=jnp.array([5.0, 1.0]), variance=0.1)
|
|
288
293
|
p = Prior(kernel=kern, mean_function=mean_function()) * Gaussian(
|
|
289
294
|
num_datapoints=num_datapoints
|
|
@@ -310,26 +315,28 @@ def test_conjugate_posterior_sample_approx(num_datapoints, kernel, mean_function
|
|
|
310
315
|
# with pytest.raises(ValidationErrors):
|
|
311
316
|
# p.sample_approx(1, D, key, 0.5)
|
|
312
317
|
|
|
313
|
-
sampled_fn = p.sample_approx(1, D, key, 100)
|
|
318
|
+
sampled_fn = p.sample_approx(1, D, key, 100, solver_algorithm=solver_algorithm)
|
|
314
319
|
assert isinstance(sampled_fn, Callable) # check type
|
|
315
320
|
|
|
316
321
|
x = jr.uniform(key=key, minval=-2.0, maxval=2.0, shape=(num_datapoints, 2))
|
|
317
322
|
evals = sampled_fn(x)
|
|
318
323
|
assert evals.shape == (num_datapoints, 1.0) # check shape
|
|
319
324
|
|
|
320
|
-
sampled_fn_2 = p.sample_approx(1, D, key, 100)
|
|
325
|
+
sampled_fn_2 = p.sample_approx(1, D, key, 100, solver_algorithm=solver_algorithm)
|
|
321
326
|
evals_2 = sampled_fn_2(x)
|
|
322
327
|
max_delta = jnp.max(jnp.abs(evals - evals_2))
|
|
323
328
|
assert max_delta == 0.0 # samples same for same seed
|
|
324
329
|
|
|
325
330
|
new_key = jr.key(12345)
|
|
326
|
-
sampled_fn_3 = p.sample_approx(
|
|
331
|
+
sampled_fn_3 = p.sample_approx(
|
|
332
|
+
1, D, new_key, 100, solver_algorithm=solver_algorithm
|
|
333
|
+
)
|
|
327
334
|
evals_3 = sampled_fn_3(x)
|
|
328
335
|
max_delta = jnp.max(jnp.abs(evals - evals_3))
|
|
329
336
|
assert max_delta > 0.01 # samples different for different seed
|
|
330
337
|
|
|
331
338
|
# Check validty of samples using Monte-Carlo
|
|
332
|
-
sampled_fn = p.sample_approx(10_000, D, key, 100)
|
|
339
|
+
sampled_fn = p.sample_approx(10_000, D, key, 100, solver_algorithm=solver_algorithm)
|
|
333
340
|
sampled_evals = sampled_fn(x)
|
|
334
341
|
approx_mean = jnp.mean(sampled_evals, -1)
|
|
335
342
|
approx_var = jnp.var(sampled_evals, -1)
|
|
@@ -1,18 +0,0 @@
|
|
|
1
|
-
name: Labeler
|
|
2
|
-
|
|
3
|
-
on:
|
|
4
|
-
push:
|
|
5
|
-
branches:
|
|
6
|
-
- main
|
|
7
|
-
|
|
8
|
-
jobs:
|
|
9
|
-
labeler:
|
|
10
|
-
runs-on: ubuntu-latest
|
|
11
|
-
steps:
|
|
12
|
-
- name: Check out the repository
|
|
13
|
-
uses: actions/checkout@v3.5.2
|
|
14
|
-
|
|
15
|
-
- name: Run Labeler
|
|
16
|
-
uses: crazy-max/ghaction-github-labeler@v4.1.0
|
|
17
|
-
with:
|
|
18
|
-
skip-delete: true
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{gpjax-0.9.1 → gpjax-0.9.3}/gpjax/decision_making/utility_functions/probability_of_improvement.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|