gpjax 0.11.1__py3-none-any.whl → 0.11.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
gpjax/__init__.py CHANGED
@@ -40,7 +40,7 @@ __license__ = "MIT"
40
40
  __description__ = "Gaussian processes in JAX and Flax"
41
41
  __url__ = "https://github.com/JaxGaussianProcesses/GPJax"
42
42
  __contributors__ = "https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors"
43
- __version__ = "0.11.1"
43
+ __version__ = "0.11.2"
44
44
 
45
45
  __all__ = [
46
46
  "base",
gpjax/citation.py CHANGED
@@ -8,7 +8,12 @@ from beartype.typing import (
8
8
  Dict,
9
9
  Union,
10
10
  )
11
- from jaxlib.xla_extension import PjitFunction
11
+
12
+ try:
13
+ # safely removable once jax>=0.6.0
14
+ from jaxlib.xla_extension import PjitFunction
15
+ except ModuleNotFoundError:
16
+ from jaxlib._jax import PjitFunction
12
17
 
13
18
  from gpjax.kernels import (
14
19
  RFF,
@@ -45,7 +50,7 @@ class AbstractCitation:
45
50
 
46
51
 
47
52
  class NullCitation(AbstractCitation):
48
- def __str__(self) -> str:
53
+ def as_str(self) -> str:
49
54
  return (
50
55
  "No citation available. If you think this is an error, please open a pull"
51
56
  " request."
gpjax/fit.py CHANGED
@@ -15,13 +15,13 @@
15
15
 
16
16
  import typing as tp
17
17
 
18
+ from flax import nnx
18
19
  import jax
20
+ from jax.flatten_util import ravel_pytree
19
21
  import jax.numpy as jnp
20
22
  import jax.random as jr
21
- import optax as ox
22
- from flax import nnx
23
- from jax.flatten_util import ravel_pytree
24
23
  from numpyro.distributions.transforms import Transform
24
+ import optax as ox
25
25
  from scipy.optimize import minimize
26
26
 
27
27
  from gpjax.dataset import Dataset
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: gpjax
3
- Version: 0.11.1
3
+ Version: 0.11.2
4
4
  Summary: Gaussian processes in JAX.
5
5
  Project-URL: Documentation, https://docs.jaxgaussianprocesses.com/
6
6
  Project-URL: Issues, https://github.com/JaxGaussianProcesses/GPJax/issues
@@ -1,8 +1,8 @@
1
- gpjax/__init__.py,sha256=TjAAfeZTCEl_zsibA8pV76M1jcHkeFhNfWk_SllfgHY,1686
2
- gpjax/citation.py,sha256=f2Hzj5MLyCE7l0hHAzsEQoTORZH5hgV_eis4uoBiWvE,3811
1
+ gpjax/__init__.py,sha256=ylCFMtXwcMS2zxm4pI3KsnRdnX6bdh26TSdTfUh9l9o,1686
2
+ gpjax/citation.py,sha256=pwFS8h1J-LE5ieRS0zDyuwhmQHNxkFHYE7iSMlVNmQc,3928
3
3
  gpjax/dataset.py,sha256=NsToLKq4lOsHnfLfukrUIRKvhOEuoUk8aHTF0oAqRbU,4079
4
4
  gpjax/distributions.py,sha256=8LWmfmRVHOX29Uy8PkKFi2UhcCiunuu-4TMI_5-krHc,9299
5
- gpjax/fit.py,sha256=7L2veA6aRNiozZD8fWa-MVDoYFUKjGJahmvjz8Wp-P0,15046
5
+ gpjax/fit.py,sha256=R4TIPvBNHYSg9vBVp6is_QYENldRLIU_FklGE85C-aA,15046
6
6
  gpjax/gps.py,sha256=97lYGrsmsufQxKEd8qz5wPNvui6FKXTF_Ps-sMFIjnY,31246
7
7
  gpjax/integrators.py,sha256=eyJPqWNPKj6pKP5da0fEj4HW7BVyevqeGrurEuy_XPw,5694
8
8
  gpjax/likelihoods.py,sha256=99oTZoWld1M7vxgGM0pNY5Hnt2Ajd2lQNqawzrLmwtk,9308
@@ -43,7 +43,7 @@ gpjax/kernels/stationary/rational_quadratic.py,sha256=dYONp3i4rnKj3ET8UyxAKXv6UO
43
43
  gpjax/kernels/stationary/rbf.py,sha256=euHUs6FdfRICQcabAWE4MX-7GEDr2TxgZWdFQiXr9Bw,1690
44
44
  gpjax/kernels/stationary/utils.py,sha256=6BI9EBcCzeeKx-XH-MfW1ORmtU__tPX5zyvfLhpkBsU,2180
45
45
  gpjax/kernels/stationary/white.py,sha256=TkdXXZCCjDs7JwR_gj5uvn2s1wyfRbe1vyHhUMJ8jjI,2212
46
- gpjax-0.11.1.dist-info/METADATA,sha256=02crI6D0dsht6XJ8N1ZqNj5ZktmS5NymVfY45pPmEgM,8558
47
- gpjax-0.11.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
48
- gpjax-0.11.1.dist-info/licenses/LICENSE.txt,sha256=3umwi0h8wmKXOZO8XwRBwSl3vJt2hpWKEqSrSXLR7-I,1084
49
- gpjax-0.11.1.dist-info/RECORD,,
46
+ gpjax-0.11.2.dist-info/METADATA,sha256=lTQVlrUbkxI7fU9Gdnac_eoNRyjCHEoEuEvvWbKmDqM,8558
47
+ gpjax-0.11.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
48
+ gpjax-0.11.2.dist-info/licenses/LICENSE.txt,sha256=3umwi0h8wmKXOZO8XwRBwSl3vJt2hpWKEqSrSXLR7-I,1084
49
+ gpjax-0.11.2.dist-info/RECORD,,
File without changes