gpjax 0.11.1__tar.gz → 0.11.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {gpjax-0.11.1 → gpjax-0.11.2}/PKG-INFO +1 -1
- {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/__init__.py +1 -1
- {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/citation.py +7 -2
- {gpjax-0.11.1 → gpjax-0.11.2}/tests/test_fit.py +7 -6
- gpjax-0.11.2/uv.lock +832 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/.github/CODE_OF_CONDUCT.md +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/.github/ISSUE_TEMPLATE/01_BUG_REPORT.md +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/.github/ISSUE_TEMPLATE/02_FEATURE_REQUEST.md +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/.github/ISSUE_TEMPLATE/03_CODEBASE_IMPROVEMENT.md +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/.github/ISSUE_TEMPLATE/04_DOCS_IMPROVEMENT.md +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/.github/ISSUE_TEMPLATE/config.yml +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/.github/codecov.yml +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/.github/labels.yml +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/.github/pull_request_template.md +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/.github/release-drafter.yml +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/.github/workflows/build_docs.yml +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/.github/workflows/integration.yml +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/.github/workflows/pr_greeting.yml +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/.github/workflows/ruff.yml +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/.github/workflows/stale_prs.yml +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/.github/workflows/test_docs.yml +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/.github/workflows/tests.yml +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/.gitignore +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/CITATION.bib +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/LICENSE.txt +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/Makefile +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/README.md +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/docs/CODE_OF_CONDUCT.md +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/docs/GOVERNANCE.md +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/docs/contributing.md +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/docs/design.md +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/docs/index.md +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/docs/index.rst +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/docs/installation.md +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/docs/javascripts/katex.js +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/docs/refs.bib +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/docs/scripts/gen_examples.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/docs/scripts/gen_pages.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/docs/scripts/notebook_converter.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/docs/scripts/sharp_bits_figure.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/docs/sharp_bits.md +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/docs/static/GP.pdf +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/docs/static/GP.svg +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/docs/static/bijector_figure.svg +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/docs/static/css/gpjax_theme.css +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/docs/static/favicon.ico +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/docs/static/gpjax.mplstyle +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/docs/static/gpjax_logo.pdf +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/docs/static/gpjax_logo.svg +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/docs/static/jaxkern/lato.ttf +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/docs/static/jaxkern/logo.png +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/docs/static/jaxkern/logo.svg +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/docs/static/jaxkern/main.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/docs/static/step_size_figure.png +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/docs/static/step_size_figure.svg +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/docs/stylesheets/extra.css +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/docs/stylesheets/permalinks.css +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/examples/backend.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/examples/barycentres/barycentre_gp.gif +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/examples/barycentres.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/examples/classification.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/examples/collapsed_vi.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/examples/constructing_new_kernels.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/examples/data/max_tempeature_switzerland.csv +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/examples/data/yacht_hydrodynamics.data +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/examples/deep_kernels.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/examples/gpjax.mplstyle +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/examples/graph_kernels.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/examples/intro_to_gps/decomposed_mll.png +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/examples/intro_to_gps/generating_process.png +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/examples/intro_to_gps.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/examples/intro_to_kernels.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/examples/likelihoods_guide.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/examples/oceanmodelling.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/examples/poisson.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/examples/regression.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/examples/uncollapsed_vi.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/examples/utils.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/examples/yacht.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/dataset.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/distributions.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/fit.py +3 -3
- {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/gps.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/integrators.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/kernels/__init__.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/kernels/approximations/__init__.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/kernels/approximations/rff.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/kernels/base.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/kernels/computations/__init__.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/kernels/computations/base.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/kernels/computations/basis_functions.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/kernels/computations/constant_diagonal.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/kernels/computations/dense.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/kernels/computations/diagonal.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/kernels/computations/eigen.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/kernels/non_euclidean/__init__.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/kernels/non_euclidean/graph.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/kernels/non_euclidean/utils.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/kernels/nonstationary/__init__.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/kernels/nonstationary/arccosine.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/kernels/nonstationary/linear.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/kernels/nonstationary/polynomial.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/kernels/stationary/__init__.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/kernels/stationary/base.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/kernels/stationary/matern12.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/kernels/stationary/matern32.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/kernels/stationary/matern52.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/kernels/stationary/periodic.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/kernels/stationary/powered_exponential.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/kernels/stationary/rational_quadratic.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/kernels/stationary/rbf.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/kernels/stationary/utils.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/kernels/stationary/white.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/likelihoods.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/lower_cholesky.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/mean_functions.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/numpyro_extras.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/objectives.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/parameters.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/scan.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/typing.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/variational_families.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/mkdocs.yml +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/pyproject.toml +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/static/CONTRIBUTING.md +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/static/paper.bib +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/static/paper.md +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/static/paper.pdf +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/tests/__init__.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/tests/conftest.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/tests/integration_tests.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/tests/test_citations.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/tests/test_dataset.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/tests/test_gaussian_distribution.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/tests/test_gps.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/tests/test_integrators.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/tests/test_kernels/__init__.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/tests/test_kernels/test_approximations.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/tests/test_kernels/test_base.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/tests/test_kernels/test_computation.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/tests/test_kernels/test_non_euclidean.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/tests/test_kernels/test_nonstationary.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/tests/test_kernels/test_stationary.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/tests/test_kernels/test_utils.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/tests/test_likelihoods.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/tests/test_lower_cholesky.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/tests/test_markdown.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/tests/test_mean_functions.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/tests/test_numpyro_extras.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/tests/test_objectives.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/tests/test_parameters.py +0 -0
- {gpjax-0.11.1 → gpjax-0.11.2}/tests/test_variational_families.py +0 -0
|
@@ -40,7 +40,7 @@ __license__ = "MIT"
|
|
|
40
40
|
__description__ = "Gaussian processes in JAX and Flax"
|
|
41
41
|
__url__ = "https://github.com/JaxGaussianProcesses/GPJax"
|
|
42
42
|
__contributors__ = "https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors"
|
|
43
|
-
__version__ = "0.11.1"
|
|
43
|
+
__version__ = "0.11.2"
|
|
44
44
|
|
|
45
45
|
__all__ = [
|
|
46
46
|
"base",
|
|
@@ -8,7 +8,12 @@ from beartype.typing import (
|
|
|
8
8
|
Dict,
|
|
9
9
|
Union,
|
|
10
10
|
)
|
|
11
|
-
|
|
11
|
+
|
|
12
|
+
try:
|
|
13
|
+
# safely removable once jax>=0.6.0
|
|
14
|
+
from jaxlib.xla_extension import PjitFunction
|
|
15
|
+
except ModuleNotFoundError:
|
|
16
|
+
from jaxlib._jax import PjitFunction
|
|
12
17
|
|
|
13
18
|
from gpjax.kernels import (
|
|
14
19
|
RFF,
|
|
@@ -45,7 +50,7 @@ class AbstractCitation:
|
|
|
45
50
|
|
|
46
51
|
|
|
47
52
|
class NullCitation(AbstractCitation):
|
|
48
|
-
def
|
|
53
|
+
def as_str(self) -> str:
|
|
49
54
|
return (
|
|
50
55
|
"No citation available. If you think this is an error, please open a pull"
|
|
51
56
|
" request."
|
|
@@ -13,13 +13,18 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
|
|
16
|
+
from beartype.typing import Any
|
|
17
|
+
from flax import nnx
|
|
16
18
|
import jax.numpy as jnp
|
|
17
19
|
import jax.random as jr
|
|
20
|
+
from jaxtyping import (
|
|
21
|
+
Float,
|
|
22
|
+
Num,
|
|
23
|
+
)
|
|
18
24
|
import optax as ox
|
|
19
25
|
import pytest
|
|
20
26
|
import scipy
|
|
21
|
-
|
|
22
|
-
from flax import nnx
|
|
27
|
+
|
|
23
28
|
from gpjax.dataset import Dataset
|
|
24
29
|
from gpjax.fit import (
|
|
25
30
|
_check_batch_size,
|
|
@@ -54,10 +59,6 @@ from gpjax.parameters import (
|
|
|
54
59
|
)
|
|
55
60
|
from gpjax.typing import Array
|
|
56
61
|
from gpjax.variational_families import VariationalGaussian
|
|
57
|
-
from jaxtyping import (
|
|
58
|
-
Float,
|
|
59
|
-
Num,
|
|
60
|
-
)
|
|
61
62
|
|
|
62
63
|
|
|
63
64
|
def test_fit_simple() -> None:
|