gpjax 0.11.1.tar.gz → 0.11.2.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (152)
  1. {gpjax-0.11.1 → gpjax-0.11.2}/PKG-INFO +1 -1
  2. {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/__init__.py +1 -1
  3. {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/citation.py +7 -2
  4. {gpjax-0.11.1 → gpjax-0.11.2}/tests/test_fit.py +7 -6
  5. gpjax-0.11.2/uv.lock +832 -0
  6. {gpjax-0.11.1 → gpjax-0.11.2}/.github/CODE_OF_CONDUCT.md +0 -0
  7. {gpjax-0.11.1 → gpjax-0.11.2}/.github/ISSUE_TEMPLATE/01_BUG_REPORT.md +0 -0
  8. {gpjax-0.11.1 → gpjax-0.11.2}/.github/ISSUE_TEMPLATE/02_FEATURE_REQUEST.md +0 -0
  9. {gpjax-0.11.1 → gpjax-0.11.2}/.github/ISSUE_TEMPLATE/03_CODEBASE_IMPROVEMENT.md +0 -0
  10. {gpjax-0.11.1 → gpjax-0.11.2}/.github/ISSUE_TEMPLATE/04_DOCS_IMPROVEMENT.md +0 -0
  11. {gpjax-0.11.1 → gpjax-0.11.2}/.github/ISSUE_TEMPLATE/config.yml +0 -0
  12. {gpjax-0.11.1 → gpjax-0.11.2}/.github/codecov.yml +0 -0
  13. {gpjax-0.11.1 → gpjax-0.11.2}/.github/labels.yml +0 -0
  14. {gpjax-0.11.1 → gpjax-0.11.2}/.github/pull_request_template.md +0 -0
  15. {gpjax-0.11.1 → gpjax-0.11.2}/.github/release-drafter.yml +0 -0
  16. {gpjax-0.11.1 → gpjax-0.11.2}/.github/workflows/build_docs.yml +0 -0
  17. {gpjax-0.11.1 → gpjax-0.11.2}/.github/workflows/integration.yml +0 -0
  18. {gpjax-0.11.1 → gpjax-0.11.2}/.github/workflows/pr_greeting.yml +0 -0
  19. {gpjax-0.11.1 → gpjax-0.11.2}/.github/workflows/ruff.yml +0 -0
  20. {gpjax-0.11.1 → gpjax-0.11.2}/.github/workflows/stale_prs.yml +0 -0
  21. {gpjax-0.11.1 → gpjax-0.11.2}/.github/workflows/test_docs.yml +0 -0
  22. {gpjax-0.11.1 → gpjax-0.11.2}/.github/workflows/tests.yml +0 -0
  23. {gpjax-0.11.1 → gpjax-0.11.2}/.gitignore +0 -0
  24. {gpjax-0.11.1 → gpjax-0.11.2}/CITATION.bib +0 -0
  25. {gpjax-0.11.1 → gpjax-0.11.2}/LICENSE.txt +0 -0
  26. {gpjax-0.11.1 → gpjax-0.11.2}/Makefile +0 -0
  27. {gpjax-0.11.1 → gpjax-0.11.2}/README.md +0 -0
  28. {gpjax-0.11.1 → gpjax-0.11.2}/docs/CODE_OF_CONDUCT.md +0 -0
  29. {gpjax-0.11.1 → gpjax-0.11.2}/docs/GOVERNANCE.md +0 -0
  30. {gpjax-0.11.1 → gpjax-0.11.2}/docs/contributing.md +0 -0
  31. {gpjax-0.11.1 → gpjax-0.11.2}/docs/design.md +0 -0
  32. {gpjax-0.11.1 → gpjax-0.11.2}/docs/index.md +0 -0
  33. {gpjax-0.11.1 → gpjax-0.11.2}/docs/index.rst +0 -0
  34. {gpjax-0.11.1 → gpjax-0.11.2}/docs/installation.md +0 -0
  35. {gpjax-0.11.1 → gpjax-0.11.2}/docs/javascripts/katex.js +0 -0
  36. {gpjax-0.11.1 → gpjax-0.11.2}/docs/refs.bib +0 -0
  37. {gpjax-0.11.1 → gpjax-0.11.2}/docs/scripts/gen_examples.py +0 -0
  38. {gpjax-0.11.1 → gpjax-0.11.2}/docs/scripts/gen_pages.py +0 -0
  39. {gpjax-0.11.1 → gpjax-0.11.2}/docs/scripts/notebook_converter.py +0 -0
  40. {gpjax-0.11.1 → gpjax-0.11.2}/docs/scripts/sharp_bits_figure.py +0 -0
  41. {gpjax-0.11.1 → gpjax-0.11.2}/docs/sharp_bits.md +0 -0
  42. {gpjax-0.11.1 → gpjax-0.11.2}/docs/static/GP.pdf +0 -0
  43. {gpjax-0.11.1 → gpjax-0.11.2}/docs/static/GP.svg +0 -0
  44. {gpjax-0.11.1 → gpjax-0.11.2}/docs/static/bijector_figure.svg +0 -0
  45. {gpjax-0.11.1 → gpjax-0.11.2}/docs/static/css/gpjax_theme.css +0 -0
  46. {gpjax-0.11.1 → gpjax-0.11.2}/docs/static/favicon.ico +0 -0
  47. {gpjax-0.11.1 → gpjax-0.11.2}/docs/static/gpjax.mplstyle +0 -0
  48. {gpjax-0.11.1 → gpjax-0.11.2}/docs/static/gpjax_logo.pdf +0 -0
  49. {gpjax-0.11.1 → gpjax-0.11.2}/docs/static/gpjax_logo.svg +0 -0
  50. {gpjax-0.11.1 → gpjax-0.11.2}/docs/static/jaxkern/lato.ttf +0 -0
  51. {gpjax-0.11.1 → gpjax-0.11.2}/docs/static/jaxkern/logo.png +0 -0
  52. {gpjax-0.11.1 → gpjax-0.11.2}/docs/static/jaxkern/logo.svg +0 -0
  53. {gpjax-0.11.1 → gpjax-0.11.2}/docs/static/jaxkern/main.py +0 -0
  54. {gpjax-0.11.1 → gpjax-0.11.2}/docs/static/step_size_figure.png +0 -0
  55. {gpjax-0.11.1 → gpjax-0.11.2}/docs/static/step_size_figure.svg +0 -0
  56. {gpjax-0.11.1 → gpjax-0.11.2}/docs/stylesheets/extra.css +0 -0
  57. {gpjax-0.11.1 → gpjax-0.11.2}/docs/stylesheets/permalinks.css +0 -0
  58. {gpjax-0.11.1 → gpjax-0.11.2}/examples/backend.py +0 -0
  59. {gpjax-0.11.1 → gpjax-0.11.2}/examples/barycentres/barycentre_gp.gif +0 -0
  60. {gpjax-0.11.1 → gpjax-0.11.2}/examples/barycentres.py +0 -0
  61. {gpjax-0.11.1 → gpjax-0.11.2}/examples/classification.py +0 -0
  62. {gpjax-0.11.1 → gpjax-0.11.2}/examples/collapsed_vi.py +0 -0
  63. {gpjax-0.11.1 → gpjax-0.11.2}/examples/constructing_new_kernels.py +0 -0
  64. {gpjax-0.11.1 → gpjax-0.11.2}/examples/data/max_tempeature_switzerland.csv +0 -0
  65. {gpjax-0.11.1 → gpjax-0.11.2}/examples/data/yacht_hydrodynamics.data +0 -0
  66. {gpjax-0.11.1 → gpjax-0.11.2}/examples/deep_kernels.py +0 -0
  67. {gpjax-0.11.1 → gpjax-0.11.2}/examples/gpjax.mplstyle +0 -0
  68. {gpjax-0.11.1 → gpjax-0.11.2}/examples/graph_kernels.py +0 -0
  69. {gpjax-0.11.1 → gpjax-0.11.2}/examples/intro_to_gps/decomposed_mll.png +0 -0
  70. {gpjax-0.11.1 → gpjax-0.11.2}/examples/intro_to_gps/generating_process.png +0 -0
  71. {gpjax-0.11.1 → gpjax-0.11.2}/examples/intro_to_gps.py +0 -0
  72. {gpjax-0.11.1 → gpjax-0.11.2}/examples/intro_to_kernels.py +0 -0
  73. {gpjax-0.11.1 → gpjax-0.11.2}/examples/likelihoods_guide.py +0 -0
  74. {gpjax-0.11.1 → gpjax-0.11.2}/examples/oceanmodelling.py +0 -0
  75. {gpjax-0.11.1 → gpjax-0.11.2}/examples/poisson.py +0 -0
  76. {gpjax-0.11.1 → gpjax-0.11.2}/examples/regression.py +0 -0
  77. {gpjax-0.11.1 → gpjax-0.11.2}/examples/uncollapsed_vi.py +0 -0
  78. {gpjax-0.11.1 → gpjax-0.11.2}/examples/utils.py +0 -0
  79. {gpjax-0.11.1 → gpjax-0.11.2}/examples/yacht.py +0 -0
  80. {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/dataset.py +0 -0
  81. {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/distributions.py +0 -0
  82. {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/fit.py +3 -3
  83. {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/gps.py +0 -0
  84. {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/integrators.py +0 -0
  85. {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/kernels/__init__.py +0 -0
  86. {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/kernels/approximations/__init__.py +0 -0
  87. {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/kernels/approximations/rff.py +0 -0
  88. {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/kernels/base.py +0 -0
  89. {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/kernels/computations/__init__.py +0 -0
  90. {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/kernels/computations/base.py +0 -0
  91. {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/kernels/computations/basis_functions.py +0 -0
  92. {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/kernels/computations/constant_diagonal.py +0 -0
  93. {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/kernels/computations/dense.py +0 -0
  94. {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/kernels/computations/diagonal.py +0 -0
  95. {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/kernels/computations/eigen.py +0 -0
  96. {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/kernels/non_euclidean/__init__.py +0 -0
  97. {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/kernels/non_euclidean/graph.py +0 -0
  98. {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/kernels/non_euclidean/utils.py +0 -0
  99. {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/kernels/nonstationary/__init__.py +0 -0
  100. {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/kernels/nonstationary/arccosine.py +0 -0
  101. {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/kernels/nonstationary/linear.py +0 -0
  102. {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/kernels/nonstationary/polynomial.py +0 -0
  103. {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/kernels/stationary/__init__.py +0 -0
  104. {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/kernels/stationary/base.py +0 -0
  105. {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/kernels/stationary/matern12.py +0 -0
  106. {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/kernels/stationary/matern32.py +0 -0
  107. {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/kernels/stationary/matern52.py +0 -0
  108. {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/kernels/stationary/periodic.py +0 -0
  109. {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/kernels/stationary/powered_exponential.py +0 -0
  110. {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/kernels/stationary/rational_quadratic.py +0 -0
  111. {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/kernels/stationary/rbf.py +0 -0
  112. {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/kernels/stationary/utils.py +0 -0
  113. {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/kernels/stationary/white.py +0 -0
  114. {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/likelihoods.py +0 -0
  115. {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/lower_cholesky.py +0 -0
  116. {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/mean_functions.py +0 -0
  117. {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/numpyro_extras.py +0 -0
  118. {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/objectives.py +0 -0
  119. {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/parameters.py +0 -0
  120. {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/scan.py +0 -0
  121. {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/typing.py +0 -0
  122. {gpjax-0.11.1 → gpjax-0.11.2}/gpjax/variational_families.py +0 -0
  123. {gpjax-0.11.1 → gpjax-0.11.2}/mkdocs.yml +0 -0
  124. {gpjax-0.11.1 → gpjax-0.11.2}/pyproject.toml +0 -0
  125. {gpjax-0.11.1 → gpjax-0.11.2}/static/CONTRIBUTING.md +0 -0
  126. {gpjax-0.11.1 → gpjax-0.11.2}/static/paper.bib +0 -0
  127. {gpjax-0.11.1 → gpjax-0.11.2}/static/paper.md +0 -0
  128. {gpjax-0.11.1 → gpjax-0.11.2}/static/paper.pdf +0 -0
  129. {gpjax-0.11.1 → gpjax-0.11.2}/tests/__init__.py +0 -0
  130. {gpjax-0.11.1 → gpjax-0.11.2}/tests/conftest.py +0 -0
  131. {gpjax-0.11.1 → gpjax-0.11.2}/tests/integration_tests.py +0 -0
  132. {gpjax-0.11.1 → gpjax-0.11.2}/tests/test_citations.py +0 -0
  133. {gpjax-0.11.1 → gpjax-0.11.2}/tests/test_dataset.py +0 -0
  134. {gpjax-0.11.1 → gpjax-0.11.2}/tests/test_gaussian_distribution.py +0 -0
  135. {gpjax-0.11.1 → gpjax-0.11.2}/tests/test_gps.py +0 -0
  136. {gpjax-0.11.1 → gpjax-0.11.2}/tests/test_integrators.py +0 -0
  137. {gpjax-0.11.1 → gpjax-0.11.2}/tests/test_kernels/__init__.py +0 -0
  138. {gpjax-0.11.1 → gpjax-0.11.2}/tests/test_kernels/test_approximations.py +0 -0
  139. {gpjax-0.11.1 → gpjax-0.11.2}/tests/test_kernels/test_base.py +0 -0
  140. {gpjax-0.11.1 → gpjax-0.11.2}/tests/test_kernels/test_computation.py +0 -0
  141. {gpjax-0.11.1 → gpjax-0.11.2}/tests/test_kernels/test_non_euclidean.py +0 -0
  142. {gpjax-0.11.1 → gpjax-0.11.2}/tests/test_kernels/test_nonstationary.py +0 -0
  143. {gpjax-0.11.1 → gpjax-0.11.2}/tests/test_kernels/test_stationary.py +0 -0
  144. {gpjax-0.11.1 → gpjax-0.11.2}/tests/test_kernels/test_utils.py +0 -0
  145. {gpjax-0.11.1 → gpjax-0.11.2}/tests/test_likelihoods.py +0 -0
  146. {gpjax-0.11.1 → gpjax-0.11.2}/tests/test_lower_cholesky.py +0 -0
  147. {gpjax-0.11.1 → gpjax-0.11.2}/tests/test_markdown.py +0 -0
  148. {gpjax-0.11.1 → gpjax-0.11.2}/tests/test_mean_functions.py +0 -0
  149. {gpjax-0.11.1 → gpjax-0.11.2}/tests/test_numpyro_extras.py +0 -0
  150. {gpjax-0.11.1 → gpjax-0.11.2}/tests/test_objectives.py +0 -0
  151. {gpjax-0.11.1 → gpjax-0.11.2}/tests/test_parameters.py +0 -0
  152. {gpjax-0.11.1 → gpjax-0.11.2}/tests/test_variational_families.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: gpjax
3
- Version: 0.11.1
3
+ Version: 0.11.2
4
4
  Summary: Gaussian processes in JAX.
5
5
  Project-URL: Documentation, https://docs.jaxgaussianprocesses.com/
6
6
  Project-URL: Issues, https://github.com/JaxGaussianProcesses/GPJax/issues
@@ -40,7 +40,7 @@ __license__ = "MIT"
40
40
  __description__ = "Gaussian processes in JAX and Flax"
41
41
  __url__ = "https://github.com/JaxGaussianProcesses/GPJax"
42
42
  __contributors__ = "https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors"
43
- __version__ = "0.11.1"
43
+ __version__ = "0.11.2"
44
44
 
45
45
  __all__ = [
46
46
  "base",
@@ -8,7 +8,12 @@ from beartype.typing import (
8
8
  Dict,
9
9
  Union,
10
10
  )
11
- from jaxlib.xla_extension import PjitFunction
11
+
12
+ try:
13
+ # safely removable once jax>=0.6.0
14
+ from jaxlib.xla_extension import PjitFunction
15
+ except ModuleNotFoundError:
16
+ from jaxlib._jax import PjitFunction
12
17
 
13
18
  from gpjax.kernels import (
14
19
  RFF,
@@ -45,7 +50,7 @@ class AbstractCitation:
45
50
 
46
51
 
47
52
  class NullCitation(AbstractCitation):
48
- def __str__(self) -> str:
53
+ def as_str(self) -> str:
49
54
  return (
50
55
  "No citation available. If you think this is an error, please open a pull"
51
56
  " request."
@@ -13,13 +13,18 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
+ from beartype.typing import Any
17
+ from flax import nnx
16
18
  import jax.numpy as jnp
17
19
  import jax.random as jr
20
+ from jaxtyping import (
21
+ Float,
22
+ Num,
23
+ )
18
24
  import optax as ox
19
25
  import pytest
20
26
  import scipy
21
- from beartype.typing import Any
22
- from flax import nnx
27
+
23
28
  from gpjax.dataset import Dataset
24
29
  from gpjax.fit import (
25
30
  _check_batch_size,
@@ -54,10 +59,6 @@ from gpjax.parameters import (
54
59
  )
55
60
  from gpjax.typing import Array
56
61
  from gpjax.variational_families import VariationalGaussian
57
- from jaxtyping import (
58
- Float,
59
- Num,
60
- )
61
62
 
62
63
 
63
64
  def test_fit_simple() -> None: