gpjax-0.13.1-py3-none-any.whl → gpjax-0.13.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
gpjax/__init__.py CHANGED
@@ -40,7 +40,7 @@ __license__ = "MIT"
40
40
  __description__ = "Gaussian processes in JAX and Flax"
41
41
  __url__ = "https://github.com/JaxGaussianProcesses/GPJax"
42
42
  __contributors__ = "https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors"
43
- __version__ = "0.13.1"
43
+ __version__ = "0.13.2"
44
44
 
45
45
  __all__ = [
46
46
  "gps",
gpjax/kernels/computations/basis_functions.py CHANGED
@@ -6,9 +6,7 @@ from jaxtyping import Float
6
6
  import gpjax
7
7
  from gpjax.kernels.computations.base import AbstractKernelComputation
8
8
  from gpjax.linalg import (
9
- Dense,
10
9
  Diagonal,
11
- psd,
12
10
  )
13
11
  from gpjax.typing import Array
14
12
 
@@ -27,9 +25,9 @@ class BasisFunctionComputation(AbstractKernelComputation):
27
25
  z2 = self.compute_features(kernel, y)
28
26
  return self.scaling(kernel) * jnp.matmul(z1, z2.T)
29
27
 
30
- def _gram(self, kernel: K, inputs: Float[Array, "N D"]) -> Dense:
28
+ def _gram(self, kernel: K, inputs: Float[Array, "N D"]) -> Float[Array, "N N"]:
31
29
  z1 = self.compute_features(kernel, inputs)
32
- return psd(Dense(self.scaling(kernel) * jnp.matmul(z1, z1.T)))
30
+ return self.scaling(kernel) * jnp.matmul(z1, z1.T)
33
31
 
34
32
  def diagonal(self, kernel: K, inputs: Float[Array, "N D"]) -> Diagonal:
35
33
  r"""For a given kernel, compute the elementwise diagonal of the
gpjax-0.13.2.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: gpjax
3
- Version: 0.13.1
3
+ Version: 0.13.2
4
4
  Summary: Gaussian processes in JAX.
5
5
  Project-URL: Documentation, https://docs.jaxgaussianprocesses.com/
6
6
  Project-URL: Issues, https://github.com/JaxGaussianProcesses/GPJax/issues
gpjax-0.13.2.dist-info/RECORD CHANGED
@@ -1,4 +1,4 @@
1
- gpjax/__init__.py,sha256=asMWra4r95NSlYQbniJhCQV6pEk39ONOTvnkm-wy8OA,1641
1
+ gpjax/__init__.py,sha256=jpuNxARmW3gOmXFAhJvrXeMk_rvYDn49E-8wy3ATZKg,1641
2
2
  gpjax/citation.py,sha256=pwFS8h1J-LE5ieRS0zDyuwhmQHNxkFHYE7iSMlVNmQc,3928
3
3
  gpjax/dataset.py,sha256=NsToLKq4lOsHnfLfukrUIRKvhOEuoUk8aHTF0oAqRbU,4079
4
4
  gpjax/distributions.py,sha256=iKmeQ_NN2CIjRiuOeJlwEGASzGROi4ZCerVi1uY7zRM,7758
@@ -19,7 +19,7 @@ gpjax/kernels/approximations/__init__.py,sha256=bK9HlGd-PZeGrqtG5RpXxUTXNUrZTgfj
19
19
  gpjax/kernels/approximations/rff.py,sha256=WJ2Bw4RRW5bLXGKfStxph65H-ozDelvt_4mxXQm-g7w,4074
20
20
  gpjax/kernels/computations/__init__.py,sha256=uTVkqvnZVesFLDN92h0ZR0jfR69Eo2WyjOlmSYmCPJ8,1379
21
21
  gpjax/kernels/computations/base.py,sha256=L6K0roxZbrYeJKxEw-yaTiK9Mtcv0YtZfWI2Xnau7i8,3616
22
- gpjax/kernels/computations/basis_functions.py,sha256=_SFv4Tiwne40bxr1uVYpEjjZgjIQHKseLmss2Zgl1L4,2484
22
+ gpjax/kernels/computations/basis_functions.py,sha256=Y6OS4UEvo-Mdn7yRMWPde5rOvewDHIleaUIVBOpCTd4,2466
23
23
  gpjax/kernels/computations/constant_diagonal.py,sha256=JkQhLj7cK48IhOER4ivkALNhD1oQleKe-Rr9BtUJ6es,1984
24
24
  gpjax/kernels/computations/dense.py,sha256=vnW6XKQe4_gzpXRWTctxhgMA9-9TebdtiXzAqh_-j6g,1392
25
25
  gpjax/kernels/computations/diagonal.py,sha256=k1KqW0DwWRIBvbb7jzcKktXRfhXbcos3ncWrFplJ4W0,1768
@@ -46,7 +46,7 @@ gpjax/linalg/__init__.py,sha256=F8mxk_9Zc2nFd7Q-unjJ50_6rXEKzZj572WsU_jUKqI,547
46
46
  gpjax/linalg/operations.py,sha256=xvhOy5P4FmUCPWjIVNdg1yDXaoFQ48anFUfR-Tnfr6k,6480
47
47
  gpjax/linalg/operators.py,sha256=arxRGwcoAy_RqUYqBpZ3XG6OXbjShUl7m8sTpg85npE,11608
48
48
  gpjax/linalg/utils.py,sha256=fKV8G_iKZVhNkNvN20D_dQEi93-8xosGbXBP-v7UEyo,2020
49
- gpjax-0.13.1.dist-info/METADATA,sha256=qhsF8IgUu4oYTuLdqT56XQ4EKkDnUH_4D7imLxL_nPQ,10400
50
- gpjax-0.13.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
51
- gpjax-0.13.1.dist-info/licenses/LICENSE.txt,sha256=3umwi0h8wmKXOZO8XwRBwSl3vJt2hpWKEqSrSXLR7-I,1084
52
- gpjax-0.13.1.dist-info/RECORD,,
49
+ gpjax-0.13.2.dist-info/METADATA,sha256=_dkTRApmtHexaDFVnH8FhpOf_YkIyci2Gihg9elVeII,10400
50
+ gpjax-0.13.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
51
+ gpjax-0.13.2.dist-info/licenses/LICENSE.txt,sha256=3umwi0h8wmKXOZO8XwRBwSl3vJt2hpWKEqSrSXLR7-I,1084
52
+ gpjax-0.13.2.dist-info/RECORD,,
File without changes