gpjax 0.12.2__py3-none-any.whl → 0.13.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
gpjax/__init__.py CHANGED
@@ -40,7 +40,7 @@ __license__ = "MIT"
40
40
  __description__ = "Gaussian processes in JAX and Flax"
41
41
  __url__ = "https://github.com/JaxGaussianProcesses/GPJax"
42
42
  __contributors__ = "https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors"
43
- __version__ = "0.12.2"
43
+ __version__ = "0.13.1"
44
44
 
45
45
  __all__ = [
46
46
  "gps",
gpjax/gps.py CHANGED
@@ -16,9 +16,9 @@
16
16
  from abc import abstractmethod
17
17
 
18
18
  import beartype.typing as tp
19
+ from flax import nnx
19
20
  import jax.numpy as jnp
20
21
  import jax.random as jr
21
- from flax import nnx
22
22
  from jaxtyping import (
23
23
  Float,
24
24
  Num,
@@ -1,6 +1,7 @@
1
1
  """Compute Random Fourier Feature (RFF) kernel approximations."""
2
2
 
3
3
  import beartype.typing as tp
4
+ from flax import nnx
4
5
  import jax.random as jr
5
6
  from jaxtyping import Float
6
7
 
@@ -54,7 +55,7 @@ class RFF(AbstractKernel):
54
55
  self._check_valid_base_kernel(base_kernel)
55
56
  self.base_kernel = base_kernel
56
57
  self.num_basis_fns = num_basis_fns
57
- self.frequencies = frequencies
58
+ self.frequencies = nnx.data(frequencies)
58
59
  self.compute_engine = compute_engine
59
60
 
60
61
  if self.frequencies is None:
gpjax/kernels/base.py CHANGED
@@ -253,7 +253,7 @@ class CombinationKernel(AbstractKernel):
253
253
  compute_engine: AbstractKernelComputation = DenseKernelComputation(),
254
254
  ):
255
255
  # Add kernels to a list, flattening out instances of this class therein, as in GPFlow kernels.
256
- kernels_list: list[AbstractKernel] = []
256
+ kernels_list: list[AbstractKernel] = nnx.List([])
257
257
  for kernel in kernels:
258
258
  if not isinstance(kernel, AbstractKernel):
259
259
  raise TypeError("can only combine Kernel instances") # pragma: no cover
@@ -15,7 +15,6 @@
15
15
 
16
16
 
17
17
  import beartype.typing as tp
18
- import jax.numpy as jnp
19
18
  from jaxtyping import (
20
19
  Float,
21
20
  Num,
@@ -39,17 +38,4 @@ class EigenKernelComputation(AbstractKernelComputation):
39
38
  def _cross_covariance(
40
39
  self, kernel: Kernel, x: Num[Array, "N D"], y: Num[Array, "M D"]
41
40
  ) -> Float[Array, "N M"]:
42
- # Transform the eigenvalues of the graph Laplacian according to the
43
- # RBF kernel's SPDE form.
44
- S = jnp.power(
45
- kernel.eigenvalues
46
- + 2
47
- * kernel.smoothness.value
48
- / kernel.lengthscale.value
49
- / kernel.lengthscale.value,
50
- -kernel.smoothness.value,
51
- )
52
- S = jnp.multiply(S, kernel.num_vertex / jnp.sum(S))
53
- # Scale the transform eigenvalues by the kernel variance
54
- S = jnp.multiply(S, kernel.variance.value)
55
- return kernel(x, y, S=S)
41
+ return kernel(x, y)
@@ -25,7 +25,10 @@ from gpjax.kernels.computations import (
25
25
  AbstractKernelComputation,
26
26
  EigenKernelComputation,
27
27
  )
28
- from gpjax.kernels.non_euclidean.utils import jax_gather_nd
28
+ from gpjax.kernels.non_euclidean.utils import (
29
+ calculate_heat_semigroup,
30
+ jax_gather_nd,
31
+ )
29
32
  from gpjax.kernels.stationary.base import StationaryKernel
30
33
  from gpjax.parameters import (
31
34
  Parameter,
@@ -98,14 +101,12 @@ class GraphKernel(StationaryKernel):
98
101
 
99
102
  super().__init__(active_dims, lengthscale, variance, n_dims, compute_engine)
100
103
 
101
- def __call__( # TODO not consistent with general kernel interface
104
+ def __call__(
102
105
  self,
103
106
  x: Int[Array, "N 1"],
104
- y: Int[Array, "N 1"],
105
- *,
106
- S,
107
- **kwargs,
107
+ y: Int[Array, "M 1"],
108
108
  ):
109
+ S = calculate_heat_semigroup(self)
109
110
  Kxx = (jax_gather_nd(self.eigenvectors, x) * S.squeeze()) @ jnp.transpose(
110
111
  jax_gather_nd(self.eigenvectors, y)
111
112
  ) # shape (n,n)
@@ -13,6 +13,10 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
+ from __future__ import annotations
17
+
18
+ import beartype.typing as tp
19
+ import jax.numpy as jnp
16
20
  from jaxtyping import (
17
21
  Float,
18
22
  Int,
@@ -20,6 +24,9 @@ from jaxtyping import (
20
24
 
21
25
  from gpjax.typing import Array
22
26
 
27
+ if tp.TYPE_CHECKING:
28
+ from gpjax.kernels.non_euclidean.graph import GraphKernel
29
+
23
30
 
24
31
  def jax_gather_nd(
25
32
  params: Float[Array, " N *rest"], indices: Int[Array, " M 1"]
@@ -41,3 +48,26 @@ def jax_gather_nd(
41
48
  """
42
49
  tuple_indices = tuple(indices[..., i] for i in range(indices.shape[-1]))
43
50
  return params[tuple_indices]
51
+
52
+
53
+ def calculate_heat_semigroup(kernel: GraphKernel) -> Float[Array, "N M"]:
54
+ r"""Returns the rescaled heat semigroup, S
55
+
56
+ Args:
57
+ kernel: instance of the graph kernel
58
+
59
+ Returns:
60
+ S
61
+ """
62
+ S = jnp.power(
63
+ kernel.eigenvalues
64
+ + 2
65
+ * kernel.smoothness.value
66
+ / kernel.lengthscale.value
67
+ / kernel.lengthscale.value,
68
+ -kernel.smoothness.value,
69
+ )
70
+ S = jnp.multiply(S, kernel.num_vertex / jnp.sum(S))
71
+ # Scale the transform eigenvalues by the kernel variance
72
+ S = jnp.multiply(S, kernel.variance.value)
73
+ return S
@@ -127,7 +127,7 @@ def _check_lengthscale_dims_compat(
127
127
  """
128
128
 
129
129
  if isinstance(lengthscale, nnx.Variable):
130
- return _check_lengthscale_dims_compat_old(lengthscale.value, n_dims)
130
+ return _check_lengthscale_dims_compat(lengthscale.value, n_dims)
131
131
 
132
132
  lengthscale = jnp.asarray(lengthscale)
133
133
  ls_shape = jnp.shape(lengthscale)
@@ -146,35 +146,6 @@ def _check_lengthscale_dims_compat(
146
146
  return n_dims
147
147
 
148
148
 
149
- def _check_lengthscale_dims_compat_old(
150
- lengthscale: tp.Union[LengthscaleCompatible, nnx.Variable[Lengthscale]],
151
- n_dims: tp.Union[int, None],
152
- ):
153
- r"""Check that the lengthscale is compatible with n_dims.
154
-
155
- If possible, infer the number of input dimensions from the lengthscale.
156
- """
157
-
158
- if isinstance(lengthscale, nnx.Variable):
159
- return _check_lengthscale_dims_compat_old(lengthscale.value, n_dims)
160
-
161
- lengthscale = jnp.asarray(lengthscale)
162
- ls_shape = jnp.shape(lengthscale)
163
-
164
- if ls_shape == ():
165
- return lengthscale, n_dims
166
- elif ls_shape != () and n_dims is None:
167
- return lengthscale, ls_shape[0]
168
- elif ls_shape != () and n_dims is not None:
169
- if ls_shape != (n_dims,):
170
- raise ValueError(
171
- "Expected `lengthscale` to be compatible with the number "
172
- f"of input dimensions. Got `lengthscale` with shape {ls_shape}, "
173
- f"but the number of input dimensions is {n_dims}."
174
- )
175
- return lengthscale, n_dims
176
-
177
-
178
149
  def _check_lengthscale(lengthscale: tp.Any):
179
150
  """Check that the lengthscale is a valid value."""
180
151
 
@@ -35,7 +35,7 @@ class Matern12(StationaryKernel):
35
35
  lengthscale parameter $\ell$ and variance $\sigma^2$.
36
36
 
37
37
  $$
38
- k(x, y) = \sigma^2\exp\Bigg(-\frac{\lvert x-y \rvert}{2\ell^2}\Bigg)
38
+ k(x, y) = \sigma^2\exp\Bigg(-\frac{\lvert x-y \rvert}{2\ell}\Bigg)
39
39
  $$
40
40
  """
41
41
 
@@ -32,7 +32,7 @@ class Matern32(StationaryKernel):
32
32
  lengthscale parameter $\ell$ and variance $\sigma^2$.
33
33
 
34
34
  $$
35
- k(x, y) = \sigma^2 \exp \Bigg(1+ \frac{\sqrt{3}\lvert x-y \rvert}{\ell^2} \ \Bigg)\exp\Bigg(-\frac{\sqrt{3}\lvert x-y\rvert}{\ell^2} \Bigg)
35
+ k(x, y) = \sigma^2 \exp \Bigg(1+ \frac{\sqrt{3}\lvert x-y \rvert}{\ell} \ \Bigg)\exp\Bigg(-\frac{\sqrt{3}\lvert x-y\rvert}{\ell^2} \Bigg)
36
36
  $$
37
37
  """
38
38
 
@@ -33,7 +33,7 @@ class Matern52(StationaryKernel):
33
33
  lengthscale parameter $\ell$ and variance $\sigma^2$.
34
34
 
35
35
  $$
36
- k(x, y) = \sigma^2 \exp \Bigg(1+ \frac{\sqrt{5}\lvert x-y \rvert}{\ell^2} + \frac{5\lvert x - y \rvert^2}{3\ell^2} \Bigg)\exp\Bigg(-\frac{\sqrt{5}\lvert x-y\rvert}{\ell^2} \Bigg)
36
+ k(x, y) = \sigma^2 \exp \Bigg(1+ \frac{\sqrt{5}\lvert x-y \rvert}{\ell} + \frac{5\lvert x - y \rvert^2}{3\ell^2} \Bigg)\exp\Bigg(-\frac{\sqrt{5}\lvert x-y\rvert}{\ell^2} \Bigg)
37
37
  $$
38
38
  """
39
39
 
gpjax/mean_functions.py CHANGED
@@ -176,7 +176,7 @@ class CombinationMeanFunction(AbstractMeanFunction):
176
176
  super().__init__(**kwargs)
177
177
 
178
178
  # Add means to a list, flattening out instances of this class therein, as in GPFlow kernels.
179
- items_list: list[AbstractMeanFunction] = []
179
+ items_list: list[AbstractMeanFunction] = nnx.List([])
180
180
 
181
181
  for item in means:
182
182
  if not isinstance(item, AbstractMeanFunction):
@@ -19,7 +19,10 @@ import beartype.typing as tp
19
19
  from flax import nnx
20
20
  import jax.numpy as jnp
21
21
  import jax.scipy as jsp
22
- from jaxtyping import Float
22
+ from jaxtyping import (
23
+ Float,
24
+ Int,
25
+ )
23
26
 
24
27
  from gpjax.dataset import Dataset
25
28
  from gpjax.distributions import GaussianDistribution
@@ -108,6 +111,7 @@ class AbstractVariationalGaussian(AbstractVariationalFamily[L]):
108
111
  self,
109
112
  posterior: AbstractPosterior[P, L],
110
113
  inducing_inputs: tp.Union[
114
+ Int[Array, "N D"],
111
115
  Float[Array, "N D"],
112
116
  Real,
113
117
  ],
@@ -140,7 +144,7 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):
140
144
  def __init__(
141
145
  self,
142
146
  posterior: AbstractPosterior[P, L],
143
- inducing_inputs: Float[Array, "N D"],
147
+ inducing_inputs: tp.Union[Int[Array, "N D"], Float[Array, "N D"]],
144
148
  variational_mean: tp.Union[Float[Array, "N 1"], None] = None,
145
149
  variational_root_covariance: tp.Union[Float[Array, "N N"], None] = None,
146
150
  jitter: ScalarFloat = 1e-6,
@@ -156,6 +160,12 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):
156
160
  self.variational_mean = Real(variational_mean)
157
161
  self.variational_root_covariance = LowerTriangular(variational_root_covariance)
158
162
 
163
+ def _fmt_Kzt_Ktt(self, Kzt, Ktt):
164
+ return Kzt, Ktt
165
+
166
+ def _fmt_inducing_inputs(self):
167
+ return self.inducing_inputs.value
168
+
159
169
  def prior_kl(self) -> ScalarFloat:
160
170
  r"""Compute the prior KL divergence.
161
171
 
@@ -178,7 +188,7 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):
178
188
  # Unpack variational parameters
179
189
  variational_mean = self.variational_mean.value
180
190
  variational_sqrt = self.variational_root_covariance.value
181
- inducing_inputs = self.inducing_inputs.value
191
+ inducing_inputs = self._fmt_inducing_inputs()
182
192
 
183
193
  # Unpack mean function and kernel
184
194
  mean_function = self.posterior.prior.mean_function
@@ -202,7 +212,9 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):
202
212
 
203
213
  return q_inducing.kl_divergence(p_inducing)
204
214
 
205
- def predict(self, test_inputs: Float[Array, "N D"]) -> GaussianDistribution:
215
+ def predict(
216
+ self, test_inputs: tp.Union[Int[Array, "N D"], Float[Array, "N D"]]
217
+ ) -> GaussianDistribution:
206
218
  r"""Compute the predictive distribution of the GP at the test inputs t.
207
219
 
208
220
  This is the integral $q(f(t)) = \int p(f(t)\mid u) q(u) \mathrm{d}u$, which
@@ -222,7 +234,7 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):
222
234
  # Unpack variational parameters
223
235
  variational_mean = self.variational_mean.value
224
236
  variational_sqrt = self.variational_root_covariance.value
225
- inducing_inputs = self.inducing_inputs.value
237
+ inducing_inputs = self._fmt_inducing_inputs()
226
238
 
227
239
  # Unpack mean function and kernel
228
240
  mean_function = self.posterior.prior.mean_function
@@ -241,6 +253,8 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):
241
253
  Kzt = kernel.cross_covariance(inducing_inputs, test_points)
242
254
  test_mean = mean_function(test_points)
243
255
 
256
+ Kzt, Ktt = self._fmt_Kzt_Ktt(Kzt, Ktt)
257
+
244
258
  # Lz⁻¹ Kzt
245
259
  Lz_inv_Kzt = solve(Lz, Kzt)
246
260
 
@@ -259,8 +273,10 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):
259
273
  - jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt)
260
274
  + jnp.matmul(Ktz_Kzz_inv_sqrt, Ktz_Kzz_inv_sqrt.T)
261
275
  )
276
+
262
277
  if hasattr(covariance, "to_dense"):
263
278
  covariance = covariance.to_dense()
279
+
264
280
  covariance = add_jitter(covariance, self.jitter)
265
281
  covariance = Dense(covariance)
266
282
 
@@ -269,6 +285,53 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):
269
285
  )
270
286
 
271
287
 
288
+ class GraphVariationalGaussian(VariationalGaussian[L]):
289
+ r"""A variational Gaussian defined over graph-structured inducing inputs.
290
+
291
+ This subclass adapts the :class:`VariationalGaussian` family to the
292
+ case where the inducing inputs are discrete graph node indices rather
293
+ than continuous spatial coordinates.
294
+
295
+ The main differences are:
296
+ * Inducing inputs are integer node IDs.
297
+ * Kernel matrices are ensured to be dense and 2D.
298
+ """
299
+
300
+ def __init__(
301
+ self,
302
+ posterior: AbstractPosterior[P, L],
303
+ inducing_inputs: Int[Array, "N D"],
304
+ variational_mean: tp.Union[Float[Array, "N 1"], None] = None,
305
+ variational_root_covariance: tp.Union[Float[Array, "N N"], None] = None,
306
+ jitter: ScalarFloat = 1e-6,
307
+ ):
308
+ super().__init__(
309
+ posterior,
310
+ inducing_inputs,
311
+ variational_mean,
312
+ variational_root_covariance,
313
+ jitter,
314
+ )
315
+ self.inducing_inputs = self.inducing_inputs.value.astype(jnp.int64)
316
+
317
+ def _fmt_Kzt_Ktt(self, Kzt, Ktt):
318
+ Ktt = Ktt.to_dense() if hasattr(Ktt, "to_dense") else Ktt
319
+ Kzt = Kzt.to_dense() if hasattr(Kzt, "to_dense") else Kzt
320
+ Ktt = jnp.atleast_2d(Ktt)
321
+ Kzt = (
322
+ jnp.transpose(jnp.atleast_2d(Kzt)) if Kzt.ndim < 2 else jnp.atleast_2d(Kzt)
323
+ )
324
+ return Kzt, Ktt
325
+
326
+ def _fmt_inducing_inputs(self):
327
+ return self.inducing_inputs
328
+
329
+ @property
330
+ def num_inducing(self) -> int:
331
+ """The number of inducing inputs."""
332
+ return self.inducing_inputs.shape[0]
333
+
334
+
272
335
  class WhitenedVariationalGaussian(VariationalGaussian[L]):
273
336
  r"""The whitened variational Gaussian family of probability distributions.
274
337
 
@@ -811,6 +874,7 @@ __all__ = [
811
874
  "AbstractVariationalFamily",
812
875
  "AbstractVariationalGaussian",
813
876
  "VariationalGaussian",
877
+ "GraphVariationalGaussian",
814
878
  "WhitenedVariationalGaussian",
815
879
  "NaturalVariationalGaussian",
816
880
  "ExpectationVariationalGaussian",
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: gpjax
3
- Version: 0.12.2
3
+ Version: 0.13.1
4
4
  Summary: Gaussian processes in JAX.
5
5
  Project-URL: Documentation, https://docs.jaxgaussianprocesses.com/
6
6
  Project-URL: Issues, https://github.com/JaxGaussianProcesses/GPJax/issues
@@ -11,15 +11,14 @@ License-File: LICENSE.txt
11
11
  Keywords: gaussian-processes jax machine-learning bayesian
12
12
  Classifier: Development Status :: 4 - Beta
13
13
  Classifier: Programming Language :: Python
14
- Classifier: Programming Language :: Python :: 3.10
15
14
  Classifier: Programming Language :: Python :: 3.11
16
15
  Classifier: Programming Language :: Python :: 3.12
17
16
  Classifier: Programming Language :: Python :: 3.13
18
17
  Classifier: Programming Language :: Python :: Implementation :: CPython
19
18
  Classifier: Programming Language :: Python :: Implementation :: PyPy
20
- Requires-Python: >=3.10
19
+ Requires-Python: >=3.11
21
20
  Requires-Dist: beartype>0.16.1
22
- Requires-Dist: flax>=0.10.0
21
+ Requires-Dist: flax>=0.12.0
23
22
  Requires-Dist: jax>=0.5.0
24
23
  Requires-Dist: jaxlib>=0.5.0
25
24
  Requires-Dist: jaxtyping>0.2.10
@@ -80,6 +79,7 @@ Description-Content-Type: text/markdown
80
79
  [![CodeFactor](https://www.codefactor.io/repository/github/jaxgaussianprocesses/gpjax/badge)](https://www.codefactor.io/repository/github/jaxgaussianprocesses/gpjax)
81
80
  [![Netlify Status](https://api.netlify.com/api/v1/badges/d3950e6f-321f-4508-9e52-426b5dae2715/deploy-status)](https://app.netlify.com/sites/endearing-crepe-c2d5fe/deploys)
82
81
  [![PyPI version](https://badge.fury.io/py/GPJax.svg)](https://badge.fury.io/py/GPJax)
82
+ [![Conda Version](https://img.shields.io/conda/vn/conda-forge/gpjax.svg)](https://anaconda.org/conda-forge/gpjax)
83
83
  [![DOI](https://joss.theoj.org/papers/10.21105/joss.04455/status.svg)](https://doi.org/10.21105/joss.04455)
84
84
  [![Downloads](https://pepy.tech/badge/gpjax)](https://pepy.tech/project/gpjax)
85
85
  [![Slack Invite](https://img.shields.io/badge/Slack_Invite--blue?style=social&logo=slack)](https://join.slack.com/t/gpjax/shared_invite/zt-3cesiykcx-nzajjRdnV3ohw7~~eMlCYA)
@@ -174,13 +174,21 @@ jupytext --to py:percent example.ipynb
174
174
 
175
175
  ## Stable version
176
176
 
177
- The latest stable version of GPJax can be installed via
178
- pip:
177
+ The latest stable version of GPJax can be installed from [PyPI](https://pypi.org/project/gpjax/):
179
178
 
180
179
  ```bash
181
180
  pip install gpjax
182
181
  ```
183
182
 
183
+ or from [conda-forge](https://github.com/conda-forge/gpjax-feedstock):
184
+
185
+ ```bash
186
+ # with Pixi
187
+ pixi add gpjax
188
+ # or with conda
189
+ conda install --channel conda-forge gpjax
190
+ ```
191
+
184
192
  > **Note**
185
193
  >
186
194
  > We recommend you check your installation version:
@@ -199,7 +207,7 @@ pip install gpjax
199
207
  >
200
208
  > We advise you create virtual environment before installing:
201
209
  > ```
202
- > conda create -n gpjax_experimental python=3.10.0
210
+ > conda create -n gpjax_experimental python=3.11.0
203
211
  > conda activate gpjax_experimental
204
212
  > ```
205
213
 
@@ -209,14 +217,14 @@ configuration in development mode.
209
217
  ```bash
210
218
  git clone https://github.com/JaxGaussianProcesses/GPJax.git
211
219
  cd GPJax
212
- hatch env create
213
- hatch shell
220
+ uv venv
221
+ uv sync --extra dev
214
222
  ```
215
223
 
216
224
  > We recommend you check your installation passes the supplied unit tests:
217
225
  >
218
226
  > ```python
219
- > hatch run dev:test
227
+ > uv run poe all-tests
220
228
  > ```
221
229
 
222
230
  # Citing GPJax
@@ -1,41 +1,41 @@
1
- gpjax/__init__.py,sha256=RzwpixFXn6HNHLVLy4LVXhFUk2c-_ce6n1gjZ2B93F0,1641
1
+ gpjax/__init__.py,sha256=asMWra4r95NSlYQbniJhCQV6pEk39ONOTvnkm-wy8OA,1641
2
2
  gpjax/citation.py,sha256=pwFS8h1J-LE5ieRS0zDyuwhmQHNxkFHYE7iSMlVNmQc,3928
3
3
  gpjax/dataset.py,sha256=NsToLKq4lOsHnfLfukrUIRKvhOEuoUk8aHTF0oAqRbU,4079
4
4
  gpjax/distributions.py,sha256=iKmeQ_NN2CIjRiuOeJlwEGASzGROi4ZCerVi1uY7zRM,7758
5
5
  gpjax/fit.py,sha256=I2sJVuKZii_d7MEcelHIivfM8ExYGMgdBuKKOT7Dw-A,15326
6
- gpjax/gps.py,sha256=ipaeYMnPffhKK_JsEHe4fF8GmolQIjXB1YbyfUIL8H4,30118
6
+ gpjax/gps.py,sha256=fSbHjfQ1vKUZ3CnqdNgJVgFsuTkCwK3yK_c0SQsWYo0,30118
7
7
  gpjax/integrators.py,sha256=eyJPqWNPKj6pKP5da0fEj4HW7BVyevqeGrurEuy_XPw,5694
8
8
  gpjax/likelihoods.py,sha256=xwnSQpn6Aa-FPpEoDn_3xpBdPQAmHP97jP-9iJmT4G8,9087
9
- gpjax/mean_functions.py,sha256=KiHQXI-b7o0Vi5KQxGm6RNsUjitJc9jEOCq2GrSx4II,6531
9
+ gpjax/mean_functions.py,sha256=Aq09I5h5DZe1WokRxLa0Mpj8U0kJxpQ8CmdJJyCjNOc,6541
10
10
  gpjax/numpyro_extras.py,sha256=-vWJ7SpZVNhSdCjjrlxIkovMFrM1IzpsMJK3B4LioGE,3411
11
11
  gpjax/objectives.py,sha256=GvKbDIPqYjsc9FpiTccmZwRdHr2lCykgfxI9BX9I_GA,15362
12
12
  gpjax/parameters.py,sha256=hnyIKr6uIzd7Kb3KZC9WowR88ruQwUvdcto3cx2ZDv4,6756
13
13
  gpjax/scan.py,sha256=jStQvwkE9MGttB89frxam1kaeXdWih7cVxkGywyaeHQ,5365
14
14
  gpjax/typing.py,sha256=M3CvWsYtZ3PFUvBvvbRNjpwerNII0w4yGuP0I-sLeYI,1705
15
- gpjax/variational_families.py,sha256=TJGGkwkE805X4PQb-C32FxvD9B_OsFLWf6I-ZZvOUWk,29628
15
+ gpjax/variational_families.py,sha256=x4VnUh9GiW77ijnDEwrYzH0aWFOW6NLPQXEp--I9R-g,31566
16
16
  gpjax/kernels/__init__.py,sha256=WZanH0Tpdkt0f7VfMqnalm_VZAMVwBqeOVaICNj6xQU,1901
17
- gpjax/kernels/base.py,sha256=4Lx8y3kPX4WqQZGRGAsBkqn_i6FlfoAhSn9Tv415xuQ,11551
17
+ gpjax/kernels/base.py,sha256=4vUV6_wnbV64EUaOecVWE9Vz4G5Dg6xUMNsQSis8A1o,11561
18
18
  gpjax/kernels/approximations/__init__.py,sha256=bK9HlGd-PZeGrqtG5RpXxUTXNUrZTgfjH1dP626yNMA,68
19
- gpjax/kernels/approximations/rff.py,sha256=GbNUmDPEKEKuMwxUcocxl_9IFR3Q9KEPZXzjy_ZD-2w,4043
19
+ gpjax/kernels/approximations/rff.py,sha256=WJ2Bw4RRW5bLXGKfStxph65H-ozDelvt_4mxXQm-g7w,4074
20
20
  gpjax/kernels/computations/__init__.py,sha256=uTVkqvnZVesFLDN92h0ZR0jfR69Eo2WyjOlmSYmCPJ8,1379
21
21
  gpjax/kernels/computations/base.py,sha256=L6K0roxZbrYeJKxEw-yaTiK9Mtcv0YtZfWI2Xnau7i8,3616
22
22
  gpjax/kernels/computations/basis_functions.py,sha256=_SFv4Tiwne40bxr1uVYpEjjZgjIQHKseLmss2Zgl1L4,2484
23
23
  gpjax/kernels/computations/constant_diagonal.py,sha256=JkQhLj7cK48IhOER4ivkALNhD1oQleKe-Rr9BtUJ6es,1984
24
24
  gpjax/kernels/computations/dense.py,sha256=vnW6XKQe4_gzpXRWTctxhgMA9-9TebdtiXzAqh_-j6g,1392
25
25
  gpjax/kernels/computations/diagonal.py,sha256=k1KqW0DwWRIBvbb7jzcKktXRfhXbcos3ncWrFplJ4W0,1768
26
- gpjax/kernels/computations/eigen.py,sha256=NTHm-cn-RepYuXFrvXo2ih7Gtu1YR_pAg4Jb7IhE_o8,1930
26
+ gpjax/kernels/computations/eigen.py,sha256=LuwYVPK0AuDGNSccPmAS8wyDJ2ngkRbc6kNZzqiJEOg,1380
27
27
  gpjax/kernels/non_euclidean/__init__.py,sha256=RT7puRPqCTpyxZ16q596EuOQEQi1LK1v3J9_fWz1NlY,790
28
- gpjax/kernels/non_euclidean/graph.py,sha256=xTrx6ro8ubRXgM7Wgg6NmOyyEjEcGhzydY7KXueknCc,4120
29
- gpjax/kernels/non_euclidean/utils.py,sha256=z42aw8ga0zuREzHawemR9okttgrAUPmq-aN5HMt4SuY,1578
28
+ gpjax/kernels/non_euclidean/graph.py,sha256=MukL4ZghUywkau4tVWmHsrAjzTcl9kHN6onEAvxSvjc,4109
29
+ gpjax/kernels/non_euclidean/utils.py,sha256=vxjRX209tOL3jKsw69hQfVtBqkASDZ__4tJOGI4oLjo,2342
30
30
  gpjax/kernels/nonstationary/__init__.py,sha256=YpWQfOy_cqOKc5ezn37vqoK3Z6jznYiJz28BD_8F7AY,930
31
31
  gpjax/kernels/nonstationary/arccosine.py,sha256=cqb8sqaNwW3fEbrA7MY9OF2KJFTkxHhqwmQtABE3G8w,5408
32
32
  gpjax/kernels/nonstationary/linear.py,sha256=UIMoCq2hg6dQKr4J5UGiiPqotBleQuYfy00Ia1NaMOo,2571
33
33
  gpjax/kernels/nonstationary/polynomial.py,sha256=CKc02C7Utgo-hhcOOCcKLdln5lj4vud_8M-JY7SevJ8,3388
34
34
  gpjax/kernels/stationary/__init__.py,sha256=j4BMTaQlIx2kNAT1Dkf4iO2rm-f7_oSVWNrk1bN0tqE,1406
35
- gpjax/kernels/stationary/base.py,sha256=25qDqpZP4gNtzbyzDCW-6u7rJfMqkg0dW88XUmTTupU,7078
36
- gpjax/kernels/stationary/matern12.py,sha256=DGjqw6VveYsyy0TrufyJJvCei7p9slnm2f0TgRGG7_U,1773
37
- gpjax/kernels/stationary/matern32.py,sha256=laLsJWJozJzpYHBzlkPUq0rWxz1eWEwGC36P2nPJuaQ,1966
38
- gpjax/kernels/stationary/matern52.py,sha256=VSByD2sb7k-DzRFjaz31P3Rtc4bPPhHvMshrxZNFnns,2019
35
+ gpjax/kernels/stationary/base.py,sha256=YSgur73wqAFZBzt8D6CfujvyjowXoPOK1po10gjzMJo,6038
36
+ gpjax/kernels/stationary/matern12.py,sha256=tfrIP-pJML-kSqVwf8wdVUwC53mhFohMnf7z5OMX3xE,1771
37
+ gpjax/kernels/stationary/matern32.py,sha256=vdrszTtH03PB9D4AKg-YC_56xCupK3tO9L8vH0GYFcc,1964
38
+ gpjax/kernels/stationary/matern52.py,sha256=IVsIKDo9bTZnA98yBmA_Q31peFdSCTXjf0_cVlLQh6k,2017
39
39
  gpjax/kernels/stationary/periodic.py,sha256=f4PhWhKg-pJsEBGzEMK9pdbylO84GPKhzHlBC83ZVWw,3528
40
40
  gpjax/kernels/stationary/powered_exponential.py,sha256=xuFGuIK0mKNMU3iLtZMXZTHXJuMFAMoX7gAtXefCdqU,3679
41
41
  gpjax/kernels/stationary/rational_quadratic.py,sha256=zHo2LVW65T52XET4Hx9JaKO0TfxylV8WRUtP7sUUOx0,3418
@@ -46,7 +46,7 @@ gpjax/linalg/__init__.py,sha256=F8mxk_9Zc2nFd7Q-unjJ50_6rXEKzZj572WsU_jUKqI,547
46
46
  gpjax/linalg/operations.py,sha256=xvhOy5P4FmUCPWjIVNdg1yDXaoFQ48anFUfR-Tnfr6k,6480
47
47
  gpjax/linalg/operators.py,sha256=arxRGwcoAy_RqUYqBpZ3XG6OXbjShUl7m8sTpg85npE,11608
48
48
  gpjax/linalg/utils.py,sha256=fKV8G_iKZVhNkNvN20D_dQEi93-8xosGbXBP-v7UEyo,2020
49
- gpjax-0.12.2.dist-info/METADATA,sha256=eckQKXiBXi8XbBeJFviBAIPdBGVWGFQg7wVZwMfPPxs,10129
50
- gpjax-0.12.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
51
- gpjax-0.12.2.dist-info/licenses/LICENSE.txt,sha256=3umwi0h8wmKXOZO8XwRBwSl3vJt2hpWKEqSrSXLR7-I,1084
52
- gpjax-0.12.2.dist-info/RECORD,,
49
+ gpjax-0.13.1.dist-info/METADATA,sha256=qhsF8IgUu4oYTuLdqT56XQ4EKkDnUH_4D7imLxL_nPQ,10400
50
+ gpjax-0.13.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
51
+ gpjax-0.13.1.dist-info/licenses/LICENSE.txt,sha256=3umwi0h8wmKXOZO8XwRBwSl3vJt2hpWKEqSrSXLR7-I,1084
52
+ gpjax-0.13.1.dist-info/RECORD,,
File without changes