gpjax 0.12.2__py3-none-any.whl → 0.13.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gpjax/__init__.py +1 -1
- gpjax/gps.py +1 -1
- gpjax/kernels/approximations/rff.py +2 -1
- gpjax/kernels/base.py +1 -1
- gpjax/kernels/computations/eigen.py +1 -15
- gpjax/kernels/non_euclidean/graph.py +7 -6
- gpjax/kernels/non_euclidean/utils.py +30 -0
- gpjax/kernels/stationary/base.py +1 -30
- gpjax/kernels/stationary/matern12.py +1 -1
- gpjax/kernels/stationary/matern32.py +1 -1
- gpjax/kernels/stationary/matern52.py +1 -1
- gpjax/mean_functions.py +1 -1
- gpjax/variational_families.py +69 -5
- {gpjax-0.12.2.dist-info → gpjax-0.13.1.dist-info}/METADATA +18 -10
- {gpjax-0.12.2.dist-info → gpjax-0.13.1.dist-info}/RECORD +17 -17
- {gpjax-0.12.2.dist-info → gpjax-0.13.1.dist-info}/WHEEL +0 -0
- {gpjax-0.12.2.dist-info → gpjax-0.13.1.dist-info}/licenses/LICENSE.txt +0 -0
gpjax/__init__.py
CHANGED
|
@@ -40,7 +40,7 @@ __license__ = "MIT"
|
|
|
40
40
|
__description__ = "Gaussian processes in JAX and Flax"
|
|
41
41
|
__url__ = "https://github.com/JaxGaussianProcesses/GPJax"
|
|
42
42
|
__contributors__ = "https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors"
|
|
43
|
-
__version__ = "0.12.2"
|
|
43
|
+
__version__ = "0.13.1"
|
|
44
44
|
|
|
45
45
|
__all__ = [
|
|
46
46
|
"gps",
|
gpjax/kernels/approximations/rff.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
"""Compute Random Fourier Feature (RFF) kernel approximations."""
|
|
2
2
|
|
|
3
3
|
import beartype.typing as tp
|
|
4
|
+
from flax import nnx
|
|
4
5
|
import jax.random as jr
|
|
5
6
|
from jaxtyping import Float
|
|
6
7
|
|
|
@@ -54,7 +55,7 @@ class RFF(AbstractKernel):
|
|
|
54
55
|
self._check_valid_base_kernel(base_kernel)
|
|
55
56
|
self.base_kernel = base_kernel
|
|
56
57
|
self.num_basis_fns = num_basis_fns
|
|
57
|
-
self.frequencies = frequencies
|
|
58
|
+
self.frequencies = nnx.data(frequencies)
|
|
58
59
|
self.compute_engine = compute_engine
|
|
59
60
|
|
|
60
61
|
if self.frequencies is None:
|
gpjax/kernels/base.py
CHANGED
|
@@ -253,7 +253,7 @@ class CombinationKernel(AbstractKernel):
|
|
|
253
253
|
compute_engine: AbstractKernelComputation = DenseKernelComputation(),
|
|
254
254
|
):
|
|
255
255
|
# Add kernels to a list, flattening out instances of this class therein, as in GPFlow kernels.
|
|
256
|
-
kernels_list: list[AbstractKernel] = []
|
|
256
|
+
kernels_list: list[AbstractKernel] = nnx.List([])
|
|
257
257
|
for kernel in kernels:
|
|
258
258
|
if not isinstance(kernel, AbstractKernel):
|
|
259
259
|
raise TypeError("can only combine Kernel instances") # pragma: no cover
|
|
gpjax/kernels/computations/eigen.py
CHANGED
|
@@ -15,7 +15,6 @@
|
|
|
15
15
|
|
|
16
16
|
|
|
17
17
|
import beartype.typing as tp
|
|
18
|
-
import jax.numpy as jnp
|
|
19
18
|
from jaxtyping import (
|
|
20
19
|
Float,
|
|
21
20
|
Num,
|
|
@@ -39,17 +38,4 @@ class EigenKernelComputation(AbstractKernelComputation):
|
|
|
39
38
|
def _cross_covariance(
|
|
40
39
|
self, kernel: Kernel, x: Num[Array, "N D"], y: Num[Array, "M D"]
|
|
41
40
|
) -> Float[Array, "N M"]:
|
|
42
|
-
|
|
43
|
-
# RBF kernel's SPDE form.
|
|
44
|
-
S = jnp.power(
|
|
45
|
-
kernel.eigenvalues
|
|
46
|
-
+ 2
|
|
47
|
-
* kernel.smoothness.value
|
|
48
|
-
/ kernel.lengthscale.value
|
|
49
|
-
/ kernel.lengthscale.value,
|
|
50
|
-
-kernel.smoothness.value,
|
|
51
|
-
)
|
|
52
|
-
S = jnp.multiply(S, kernel.num_vertex / jnp.sum(S))
|
|
53
|
-
# Scale the transform eigenvalues by the kernel variance
|
|
54
|
-
S = jnp.multiply(S, kernel.variance.value)
|
|
55
|
-
return kernel(x, y, S=S)
|
|
41
|
+
return kernel(x, y)
|
|
gpjax/kernels/non_euclidean/graph.py
CHANGED
|
@@ -25,7 +25,10 @@ from gpjax.kernels.computations import (
|
|
|
25
25
|
AbstractKernelComputation,
|
|
26
26
|
EigenKernelComputation,
|
|
27
27
|
)
|
|
28
|
-
from gpjax.kernels.non_euclidean.utils import jax_gather_nd
|
|
28
|
+
from gpjax.kernels.non_euclidean.utils import (
|
|
29
|
+
calculate_heat_semigroup,
|
|
30
|
+
jax_gather_nd,
|
|
31
|
+
)
|
|
29
32
|
from gpjax.kernels.stationary.base import StationaryKernel
|
|
30
33
|
from gpjax.parameters import (
|
|
31
34
|
Parameter,
|
|
@@ -98,14 +101,12 @@ class GraphKernel(StationaryKernel):
|
|
|
98
101
|
|
|
99
102
|
super().__init__(active_dims, lengthscale, variance, n_dims, compute_engine)
|
|
100
103
|
|
|
101
|
-
def __call__(
|
|
104
|
+
def __call__(
|
|
102
105
|
self,
|
|
103
106
|
x: Int[Array, "N 1"],
|
|
104
|
-
y: Int[Array, "N 1"],
|
|
105
|
-
*,
|
|
106
|
-
S,
|
|
107
|
-
**kwargs,
|
|
107
|
+
y: Int[Array, "M 1"],
|
|
108
108
|
):
|
|
109
|
+
S = calculate_heat_semigroup(self)
|
|
109
110
|
Kxx = (jax_gather_nd(self.eigenvectors, x) * S.squeeze()) @ jnp.transpose(
|
|
110
111
|
jax_gather_nd(self.eigenvectors, y)
|
|
111
112
|
) # shape (n,n)
|
|
gpjax/kernels/non_euclidean/utils.py
CHANGED
|
@@ -13,6 +13,10 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
|
|
16
|
+
from __future__ import annotations
|
|
17
|
+
|
|
18
|
+
import beartype.typing as tp
|
|
19
|
+
import jax.numpy as jnp
|
|
16
20
|
from jaxtyping import (
|
|
17
21
|
Float,
|
|
18
22
|
Int,
|
|
@@ -20,6 +24,9 @@ from jaxtyping import (
|
|
|
20
24
|
|
|
21
25
|
from gpjax.typing import Array
|
|
22
26
|
|
|
27
|
+
if tp.TYPE_CHECKING:
|
|
28
|
+
from gpjax.kernels.non_euclidean.graph import GraphKernel
|
|
29
|
+
|
|
23
30
|
|
|
24
31
|
def jax_gather_nd(
|
|
25
32
|
params: Float[Array, " N *rest"], indices: Int[Array, " M 1"]
|
|
@@ -41,3 +48,26 @@ def jax_gather_nd(
|
|
|
41
48
|
"""
|
|
42
49
|
tuple_indices = tuple(indices[..., i] for i in range(indices.shape[-1]))
|
|
43
50
|
return params[tuple_indices]
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def calculate_heat_semigroup(kernel: GraphKernel) -> Float[Array, "N M"]:
|
|
54
|
+
r"""Returns the rescaled heat semigroup, S
|
|
55
|
+
|
|
56
|
+
Args:
|
|
57
|
+
kernel: instance of the graph kernel
|
|
58
|
+
|
|
59
|
+
Returns:
|
|
60
|
+
S
|
|
61
|
+
"""
|
|
62
|
+
S = jnp.power(
|
|
63
|
+
kernel.eigenvalues
|
|
64
|
+
+ 2
|
|
65
|
+
* kernel.smoothness.value
|
|
66
|
+
/ kernel.lengthscale.value
|
|
67
|
+
/ kernel.lengthscale.value,
|
|
68
|
+
-kernel.smoothness.value,
|
|
69
|
+
)
|
|
70
|
+
S = jnp.multiply(S, kernel.num_vertex / jnp.sum(S))
|
|
71
|
+
# Scale the transform eigenvalues by the kernel variance
|
|
72
|
+
S = jnp.multiply(S, kernel.variance.value)
|
|
73
|
+
return S
|
gpjax/kernels/stationary/base.py
CHANGED
|
@@ -127,7 +127,7 @@ def _check_lengthscale_dims_compat(
|
|
|
127
127
|
"""
|
|
128
128
|
|
|
129
129
|
if isinstance(lengthscale, nnx.Variable):
|
|
130
|
-
return
|
|
130
|
+
return _check_lengthscale_dims_compat(lengthscale.value, n_dims)
|
|
131
131
|
|
|
132
132
|
lengthscale = jnp.asarray(lengthscale)
|
|
133
133
|
ls_shape = jnp.shape(lengthscale)
|
|
@@ -146,35 +146,6 @@ def _check_lengthscale_dims_compat(
|
|
|
146
146
|
return n_dims
|
|
147
147
|
|
|
148
148
|
|
|
149
|
-
def _check_lengthscale_dims_compat_old(
|
|
150
|
-
lengthscale: tp.Union[LengthscaleCompatible, nnx.Variable[Lengthscale]],
|
|
151
|
-
n_dims: tp.Union[int, None],
|
|
152
|
-
):
|
|
153
|
-
r"""Check that the lengthscale is compatible with n_dims.
|
|
154
|
-
|
|
155
|
-
If possible, infer the number of input dimensions from the lengthscale.
|
|
156
|
-
"""
|
|
157
|
-
|
|
158
|
-
if isinstance(lengthscale, nnx.Variable):
|
|
159
|
-
return _check_lengthscale_dims_compat_old(lengthscale.value, n_dims)
|
|
160
|
-
|
|
161
|
-
lengthscale = jnp.asarray(lengthscale)
|
|
162
|
-
ls_shape = jnp.shape(lengthscale)
|
|
163
|
-
|
|
164
|
-
if ls_shape == ():
|
|
165
|
-
return lengthscale, n_dims
|
|
166
|
-
elif ls_shape != () and n_dims is None:
|
|
167
|
-
return lengthscale, ls_shape[0]
|
|
168
|
-
elif ls_shape != () and n_dims is not None:
|
|
169
|
-
if ls_shape != (n_dims,):
|
|
170
|
-
raise ValueError(
|
|
171
|
-
"Expected `lengthscale` to be compatible with the number "
|
|
172
|
-
f"of input dimensions. Got `lengthscale` with shape {ls_shape}, "
|
|
173
|
-
f"but the number of input dimensions is {n_dims}."
|
|
174
|
-
)
|
|
175
|
-
return lengthscale, n_dims
|
|
176
|
-
|
|
177
|
-
|
|
178
149
|
def _check_lengthscale(lengthscale: tp.Any):
|
|
179
150
|
"""Check that the lengthscale is a valid value."""
|
|
180
151
|
|
|
gpjax/kernels/stationary/matern32.py
CHANGED
|
@@ -32,7 +32,7 @@ class Matern32(StationaryKernel):
|
|
|
32
32
|
lengthscale parameter $\ell$ and variance $\sigma^2$.
|
|
33
33
|
|
|
34
34
|
$$
|
|
35
|
-
k(x, y) = \sigma^2 \exp \Bigg(1+ \frac{\sqrt{3}\lvert x-y \rvert}{\ell
|
|
35
|
+
k(x, y) = \sigma^2 \exp \Bigg(1+ \frac{\sqrt{3}\lvert x-y \rvert}{\ell} \ \Bigg)\exp\Bigg(-\frac{\sqrt{3}\lvert x-y\rvert}{\ell^2} \Bigg)
|
|
36
36
|
$$
|
|
37
37
|
"""
|
|
38
38
|
|
|
gpjax/kernels/stationary/matern52.py
CHANGED
|
@@ -33,7 +33,7 @@ class Matern52(StationaryKernel):
|
|
|
33
33
|
lengthscale parameter $\ell$ and variance $\sigma^2$.
|
|
34
34
|
|
|
35
35
|
$$
|
|
36
|
-
k(x, y) = \sigma^2 \exp \Bigg(1+ \frac{\sqrt{5}\lvert x-y \rvert}{\ell
|
|
36
|
+
k(x, y) = \sigma^2 \exp \Bigg(1+ \frac{\sqrt{5}\lvert x-y \rvert}{\ell} + \frac{5\lvert x - y \rvert^2}{3\ell^2} \Bigg)\exp\Bigg(-\frac{\sqrt{5}\lvert x-y\rvert}{\ell^2} \Bigg)
|
|
37
37
|
$$
|
|
38
38
|
"""
|
|
39
39
|
|
gpjax/mean_functions.py
CHANGED
|
@@ -176,7 +176,7 @@ class CombinationMeanFunction(AbstractMeanFunction):
|
|
|
176
176
|
super().__init__(**kwargs)
|
|
177
177
|
|
|
178
178
|
# Add means to a list, flattening out instances of this class therein, as in GPFlow kernels.
|
|
179
|
-
items_list: list[AbstractMeanFunction] = []
|
|
179
|
+
items_list: list[AbstractMeanFunction] = nnx.List([])
|
|
180
180
|
|
|
181
181
|
for item in means:
|
|
182
182
|
if not isinstance(item, AbstractMeanFunction):
|
gpjax/variational_families.py
CHANGED
|
@@ -19,7 +19,10 @@ import beartype.typing as tp
|
|
|
19
19
|
from flax import nnx
|
|
20
20
|
import jax.numpy as jnp
|
|
21
21
|
import jax.scipy as jsp
|
|
22
|
-
from jaxtyping import Float
|
|
22
|
+
from jaxtyping import (
|
|
23
|
+
Float,
|
|
24
|
+
Int,
|
|
25
|
+
)
|
|
23
26
|
|
|
24
27
|
from gpjax.dataset import Dataset
|
|
25
28
|
from gpjax.distributions import GaussianDistribution
|
|
@@ -108,6 +111,7 @@ class AbstractVariationalGaussian(AbstractVariationalFamily[L]):
|
|
|
108
111
|
self,
|
|
109
112
|
posterior: AbstractPosterior[P, L],
|
|
110
113
|
inducing_inputs: tp.Union[
|
|
114
|
+
Int[Array, "N D"],
|
|
111
115
|
Float[Array, "N D"],
|
|
112
116
|
Real,
|
|
113
117
|
],
|
|
@@ -140,7 +144,7 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):
|
|
|
140
144
|
def __init__(
|
|
141
145
|
self,
|
|
142
146
|
posterior: AbstractPosterior[P, L],
|
|
143
|
-
inducing_inputs: Float[Array, "N D"],
|
|
147
|
+
inducing_inputs: tp.Union[Int[Array, "N D"], Float[Array, "N D"]],
|
|
144
148
|
variational_mean: tp.Union[Float[Array, "N 1"], None] = None,
|
|
145
149
|
variational_root_covariance: tp.Union[Float[Array, "N N"], None] = None,
|
|
146
150
|
jitter: ScalarFloat = 1e-6,
|
|
@@ -156,6 +160,12 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):
|
|
|
156
160
|
self.variational_mean = Real(variational_mean)
|
|
157
161
|
self.variational_root_covariance = LowerTriangular(variational_root_covariance)
|
|
158
162
|
|
|
163
|
+
def _fmt_Kzt_Ktt(self, Kzt, Ktt):
|
|
164
|
+
return Kzt, Ktt
|
|
165
|
+
|
|
166
|
+
def _fmt_inducing_inputs(self):
|
|
167
|
+
return self.inducing_inputs.value
|
|
168
|
+
|
|
159
169
|
def prior_kl(self) -> ScalarFloat:
|
|
160
170
|
r"""Compute the prior KL divergence.
|
|
161
171
|
|
|
@@ -178,7 +188,7 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):
|
|
|
178
188
|
# Unpack variational parameters
|
|
179
189
|
variational_mean = self.variational_mean.value
|
|
180
190
|
variational_sqrt = self.variational_root_covariance.value
|
|
181
|
-
inducing_inputs = self.inducing_inputs.value
|
|
191
|
+
inducing_inputs = self._fmt_inducing_inputs()
|
|
182
192
|
|
|
183
193
|
# Unpack mean function and kernel
|
|
184
194
|
mean_function = self.posterior.prior.mean_function
|
|
@@ -202,7 +212,9 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):
|
|
|
202
212
|
|
|
203
213
|
return q_inducing.kl_divergence(p_inducing)
|
|
204
214
|
|
|
205
|
-
def predict(
|
|
215
|
+
def predict(
|
|
216
|
+
self, test_inputs: tp.Union[Int[Array, "N D"], Float[Array, "N D"]]
|
|
217
|
+
) -> GaussianDistribution:
|
|
206
218
|
r"""Compute the predictive distribution of the GP at the test inputs t.
|
|
207
219
|
|
|
208
220
|
This is the integral $q(f(t)) = \int p(f(t)\mid u) q(u) \mathrm{d}u$, which
|
|
@@ -222,7 +234,7 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):
|
|
|
222
234
|
# Unpack variational parameters
|
|
223
235
|
variational_mean = self.variational_mean.value
|
|
224
236
|
variational_sqrt = self.variational_root_covariance.value
|
|
225
|
-
inducing_inputs = self.inducing_inputs.value
|
|
237
|
+
inducing_inputs = self._fmt_inducing_inputs()
|
|
226
238
|
|
|
227
239
|
# Unpack mean function and kernel
|
|
228
240
|
mean_function = self.posterior.prior.mean_function
|
|
@@ -241,6 +253,8 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):
|
|
|
241
253
|
Kzt = kernel.cross_covariance(inducing_inputs, test_points)
|
|
242
254
|
test_mean = mean_function(test_points)
|
|
243
255
|
|
|
256
|
+
Kzt, Ktt = self._fmt_Kzt_Ktt(Kzt, Ktt)
|
|
257
|
+
|
|
244
258
|
# Lz⁻¹ Kzt
|
|
245
259
|
Lz_inv_Kzt = solve(Lz, Kzt)
|
|
246
260
|
|
|
@@ -259,8 +273,10 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):
|
|
|
259
273
|
- jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt)
|
|
260
274
|
+ jnp.matmul(Ktz_Kzz_inv_sqrt, Ktz_Kzz_inv_sqrt.T)
|
|
261
275
|
)
|
|
276
|
+
|
|
262
277
|
if hasattr(covariance, "to_dense"):
|
|
263
278
|
covariance = covariance.to_dense()
|
|
279
|
+
|
|
264
280
|
covariance = add_jitter(covariance, self.jitter)
|
|
265
281
|
covariance = Dense(covariance)
|
|
266
282
|
|
|
@@ -269,6 +285,53 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):
|
|
|
269
285
|
)
|
|
270
286
|
|
|
271
287
|
|
|
288
|
+
class GraphVariationalGaussian(VariationalGaussian[L]):
|
|
289
|
+
r"""A variational Gaussian defined over graph-structured inducing inputs.
|
|
290
|
+
|
|
291
|
+
This subclass adapts the :class:`VariationalGaussian` family to the
|
|
292
|
+
case where the inducing inputs are discrete graph node indices rather
|
|
293
|
+
than continuous spatial coordinates.
|
|
294
|
+
|
|
295
|
+
The main differences are:
|
|
296
|
+
* Inducing inputs are integer node IDs.
|
|
297
|
+
* Kernel matrices are ensured to be dense and 2D.
|
|
298
|
+
"""
|
|
299
|
+
|
|
300
|
+
def __init__(
|
|
301
|
+
self,
|
|
302
|
+
posterior: AbstractPosterior[P, L],
|
|
303
|
+
inducing_inputs: Int[Array, "N D"],
|
|
304
|
+
variational_mean: tp.Union[Float[Array, "N 1"], None] = None,
|
|
305
|
+
variational_root_covariance: tp.Union[Float[Array, "N N"], None] = None,
|
|
306
|
+
jitter: ScalarFloat = 1e-6,
|
|
307
|
+
):
|
|
308
|
+
super().__init__(
|
|
309
|
+
posterior,
|
|
310
|
+
inducing_inputs,
|
|
311
|
+
variational_mean,
|
|
312
|
+
variational_root_covariance,
|
|
313
|
+
jitter,
|
|
314
|
+
)
|
|
315
|
+
self.inducing_inputs = self.inducing_inputs.value.astype(jnp.int64)
|
|
316
|
+
|
|
317
|
+
def _fmt_Kzt_Ktt(self, Kzt, Ktt):
|
|
318
|
+
Ktt = Ktt.to_dense() if hasattr(Ktt, "to_dense") else Ktt
|
|
319
|
+
Kzt = Kzt.to_dense() if hasattr(Kzt, "to_dense") else Kzt
|
|
320
|
+
Ktt = jnp.atleast_2d(Ktt)
|
|
321
|
+
Kzt = (
|
|
322
|
+
jnp.transpose(jnp.atleast_2d(Kzt)) if Kzt.ndim < 2 else jnp.atleast_2d(Kzt)
|
|
323
|
+
)
|
|
324
|
+
return Kzt, Ktt
|
|
325
|
+
|
|
326
|
+
def _fmt_inducing_inputs(self):
|
|
327
|
+
return self.inducing_inputs
|
|
328
|
+
|
|
329
|
+
@property
|
|
330
|
+
def num_inducing(self) -> int:
|
|
331
|
+
"""The number of inducing inputs."""
|
|
332
|
+
return self.inducing_inputs.shape[0]
|
|
333
|
+
|
|
334
|
+
|
|
272
335
|
class WhitenedVariationalGaussian(VariationalGaussian[L]):
|
|
273
336
|
r"""The whitened variational Gaussian family of probability distributions.
|
|
274
337
|
|
|
@@ -811,6 +874,7 @@ __all__ = [
|
|
|
811
874
|
"AbstractVariationalFamily",
|
|
812
875
|
"AbstractVariationalGaussian",
|
|
813
876
|
"VariationalGaussian",
|
|
877
|
+
"GraphVariationalGaussian",
|
|
814
878
|
"WhitenedVariationalGaussian",
|
|
815
879
|
"NaturalVariationalGaussian",
|
|
816
880
|
"ExpectationVariationalGaussian",
|
|
{gpjax-0.12.2.dist-info → gpjax-0.13.1.dist-info}/METADATA
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: gpjax
|
|
3
|
-
Version: 0.12.2
|
|
3
|
+
Version: 0.13.1
|
|
4
4
|
Summary: Gaussian processes in JAX.
|
|
5
5
|
Project-URL: Documentation, https://docs.jaxgaussianprocesses.com/
|
|
6
6
|
Project-URL: Issues, https://github.com/JaxGaussianProcesses/GPJax/issues
|
|
@@ -11,15 +11,14 @@ License-File: LICENSE.txt
|
|
|
11
11
|
Keywords: gaussian-processes jax machine-learning bayesian
|
|
12
12
|
Classifier: Development Status :: 4 - Beta
|
|
13
13
|
Classifier: Programming Language :: Python
|
|
14
|
-
Classifier: Programming Language :: Python :: 3.10
|
|
15
14
|
Classifier: Programming Language :: Python :: 3.11
|
|
16
15
|
Classifier: Programming Language :: Python :: 3.12
|
|
17
16
|
Classifier: Programming Language :: Python :: 3.13
|
|
18
17
|
Classifier: Programming Language :: Python :: Implementation :: CPython
|
|
19
18
|
Classifier: Programming Language :: Python :: Implementation :: PyPy
|
|
20
|
-
Requires-Python: >=3.10
|
|
19
|
+
Requires-Python: >=3.11
|
|
21
20
|
Requires-Dist: beartype>0.16.1
|
|
22
|
-
Requires-Dist: flax>=0.
|
|
21
|
+
Requires-Dist: flax>=0.12.0
|
|
23
22
|
Requires-Dist: jax>=0.5.0
|
|
24
23
|
Requires-Dist: jaxlib>=0.5.0
|
|
25
24
|
Requires-Dist: jaxtyping>0.2.10
|
|
@@ -80,6 +79,7 @@ Description-Content-Type: text/markdown
|
|
|
80
79
|
[](https://www.codefactor.io/repository/github/jaxgaussianprocesses/gpjax)
|
|
81
80
|
[](https://app.netlify.com/sites/endearing-crepe-c2d5fe/deploys)
|
|
82
81
|
[](https://badge.fury.io/py/GPJax)
|
|
82
|
+
[](https://anaconda.org/conda-forge/gpjax)
|
|
83
83
|
[](https://doi.org/10.21105/joss.04455)
|
|
84
84
|
[](https://pepy.tech/project/gpjax)
|
|
85
85
|
[](https://join.slack.com/t/gpjax/shared_invite/zt-3cesiykcx-nzajjRdnV3ohw7~~eMlCYA)
|
|
@@ -174,13 +174,21 @@ jupytext --to py:percent example.ipynb
|
|
|
174
174
|
|
|
175
175
|
## Stable version
|
|
176
176
|
|
|
177
|
-
The latest stable version of GPJax can be installed
|
|
178
|
-
pip:
|
|
177
|
+
The latest stable version of GPJax can be installed from [PyPI](https://pypi.org/project/gpjax/):
|
|
179
178
|
|
|
180
179
|
```bash
|
|
181
180
|
pip install gpjax
|
|
182
181
|
```
|
|
183
182
|
|
|
183
|
+
or from [conda-forge](https://github.com/conda-forge/gpjax-feedstock):
|
|
184
|
+
|
|
185
|
+
```bash
|
|
186
|
+
# with Pixi
|
|
187
|
+
pixi add gpjax
|
|
188
|
+
# or with conda
|
|
189
|
+
conda install --channel conda-forge gpjax
|
|
190
|
+
```
|
|
191
|
+
|
|
184
192
|
> **Note**
|
|
185
193
|
>
|
|
186
194
|
> We recommend you check your installation version:
|
|
@@ -199,7 +207,7 @@ pip install gpjax
|
|
|
199
207
|
>
|
|
200
208
|
> We advise you create virtual environment before installing:
|
|
201
209
|
> ```
|
|
202
|
-
> conda create -n gpjax_experimental python=3.10.0
|
|
210
|
+
> conda create -n gpjax_experimental python=3.11.0
|
|
203
211
|
> conda activate gpjax_experimental
|
|
204
212
|
> ```
|
|
205
213
|
|
|
@@ -209,14 +217,14 @@ configuration in development mode.
|
|
|
209
217
|
```bash
|
|
210
218
|
git clone https://github.com/JaxGaussianProcesses/GPJax.git
|
|
211
219
|
cd GPJax
|
|
212
|
-
|
|
213
|
-
|
|
220
|
+
uv venv
|
|
221
|
+
uv sync --extra dev
|
|
214
222
|
```
|
|
215
223
|
|
|
216
224
|
> We recommend you check your installation passes the supplied unit tests:
|
|
217
225
|
>
|
|
218
226
|
> ```python
|
|
219
|
-
>
|
|
227
|
+
> uv run poe all-tests
|
|
220
228
|
> ```
|
|
221
229
|
|
|
222
230
|
# Citing GPJax
|
|
{gpjax-0.12.2.dist-info → gpjax-0.13.1.dist-info}/RECORD
CHANGED
|
@@ -1,41 +1,41 @@
|
|
|
1
|
-
gpjax/__init__.py,sha256=
|
|
1
|
+
gpjax/__init__.py,sha256=asMWra4r95NSlYQbniJhCQV6pEk39ONOTvnkm-wy8OA,1641
|
|
2
2
|
gpjax/citation.py,sha256=pwFS8h1J-LE5ieRS0zDyuwhmQHNxkFHYE7iSMlVNmQc,3928
|
|
3
3
|
gpjax/dataset.py,sha256=NsToLKq4lOsHnfLfukrUIRKvhOEuoUk8aHTF0oAqRbU,4079
|
|
4
4
|
gpjax/distributions.py,sha256=iKmeQ_NN2CIjRiuOeJlwEGASzGROi4ZCerVi1uY7zRM,7758
|
|
5
5
|
gpjax/fit.py,sha256=I2sJVuKZii_d7MEcelHIivfM8ExYGMgdBuKKOT7Dw-A,15326
|
|
6
|
-
gpjax/gps.py,sha256=
|
|
6
|
+
gpjax/gps.py,sha256=fSbHjfQ1vKUZ3CnqdNgJVgFsuTkCwK3yK_c0SQsWYo0,30118
|
|
7
7
|
gpjax/integrators.py,sha256=eyJPqWNPKj6pKP5da0fEj4HW7BVyevqeGrurEuy_XPw,5694
|
|
8
8
|
gpjax/likelihoods.py,sha256=xwnSQpn6Aa-FPpEoDn_3xpBdPQAmHP97jP-9iJmT4G8,9087
|
|
9
|
-
gpjax/mean_functions.py,sha256=
|
|
9
|
+
gpjax/mean_functions.py,sha256=Aq09I5h5DZe1WokRxLa0Mpj8U0kJxpQ8CmdJJyCjNOc,6541
|
|
10
10
|
gpjax/numpyro_extras.py,sha256=-vWJ7SpZVNhSdCjjrlxIkovMFrM1IzpsMJK3B4LioGE,3411
|
|
11
11
|
gpjax/objectives.py,sha256=GvKbDIPqYjsc9FpiTccmZwRdHr2lCykgfxI9BX9I_GA,15362
|
|
12
12
|
gpjax/parameters.py,sha256=hnyIKr6uIzd7Kb3KZC9WowR88ruQwUvdcto3cx2ZDv4,6756
|
|
13
13
|
gpjax/scan.py,sha256=jStQvwkE9MGttB89frxam1kaeXdWih7cVxkGywyaeHQ,5365
|
|
14
14
|
gpjax/typing.py,sha256=M3CvWsYtZ3PFUvBvvbRNjpwerNII0w4yGuP0I-sLeYI,1705
|
|
15
|
-
gpjax/variational_families.py,sha256=
|
|
15
|
+
gpjax/variational_families.py,sha256=x4VnUh9GiW77ijnDEwrYzH0aWFOW6NLPQXEp--I9R-g,31566
|
|
16
16
|
gpjax/kernels/__init__.py,sha256=WZanH0Tpdkt0f7VfMqnalm_VZAMVwBqeOVaICNj6xQU,1901
|
|
17
|
-
gpjax/kernels/base.py,sha256=
|
|
17
|
+
gpjax/kernels/base.py,sha256=4vUV6_wnbV64EUaOecVWE9Vz4G5Dg6xUMNsQSis8A1o,11561
|
|
18
18
|
gpjax/kernels/approximations/__init__.py,sha256=bK9HlGd-PZeGrqtG5RpXxUTXNUrZTgfjH1dP626yNMA,68
|
|
19
|
-
gpjax/kernels/approximations/rff.py,sha256=
|
|
19
|
+
gpjax/kernels/approximations/rff.py,sha256=WJ2Bw4RRW5bLXGKfStxph65H-ozDelvt_4mxXQm-g7w,4074
|
|
20
20
|
gpjax/kernels/computations/__init__.py,sha256=uTVkqvnZVesFLDN92h0ZR0jfR69Eo2WyjOlmSYmCPJ8,1379
|
|
21
21
|
gpjax/kernels/computations/base.py,sha256=L6K0roxZbrYeJKxEw-yaTiK9Mtcv0YtZfWI2Xnau7i8,3616
|
|
22
22
|
gpjax/kernels/computations/basis_functions.py,sha256=_SFv4Tiwne40bxr1uVYpEjjZgjIQHKseLmss2Zgl1L4,2484
|
|
23
23
|
gpjax/kernels/computations/constant_diagonal.py,sha256=JkQhLj7cK48IhOER4ivkALNhD1oQleKe-Rr9BtUJ6es,1984
|
|
24
24
|
gpjax/kernels/computations/dense.py,sha256=vnW6XKQe4_gzpXRWTctxhgMA9-9TebdtiXzAqh_-j6g,1392
|
|
25
25
|
gpjax/kernels/computations/diagonal.py,sha256=k1KqW0DwWRIBvbb7jzcKktXRfhXbcos3ncWrFplJ4W0,1768
|
|
26
|
-
gpjax/kernels/computations/eigen.py,sha256=
|
|
26
|
+
gpjax/kernels/computations/eigen.py,sha256=LuwYVPK0AuDGNSccPmAS8wyDJ2ngkRbc6kNZzqiJEOg,1380
|
|
27
27
|
gpjax/kernels/non_euclidean/__init__.py,sha256=RT7puRPqCTpyxZ16q596EuOQEQi1LK1v3J9_fWz1NlY,790
|
|
28
|
-
gpjax/kernels/non_euclidean/graph.py,sha256=
|
|
29
|
-
gpjax/kernels/non_euclidean/utils.py,sha256=
|
|
28
|
+
gpjax/kernels/non_euclidean/graph.py,sha256=MukL4ZghUywkau4tVWmHsrAjzTcl9kHN6onEAvxSvjc,4109
|
|
29
|
+
gpjax/kernels/non_euclidean/utils.py,sha256=vxjRX209tOL3jKsw69hQfVtBqkASDZ__4tJOGI4oLjo,2342
|
|
30
30
|
gpjax/kernels/nonstationary/__init__.py,sha256=YpWQfOy_cqOKc5ezn37vqoK3Z6jznYiJz28BD_8F7AY,930
|
|
31
31
|
gpjax/kernels/nonstationary/arccosine.py,sha256=cqb8sqaNwW3fEbrA7MY9OF2KJFTkxHhqwmQtABE3G8w,5408
|
|
32
32
|
gpjax/kernels/nonstationary/linear.py,sha256=UIMoCq2hg6dQKr4J5UGiiPqotBleQuYfy00Ia1NaMOo,2571
|
|
33
33
|
gpjax/kernels/nonstationary/polynomial.py,sha256=CKc02C7Utgo-hhcOOCcKLdln5lj4vud_8M-JY7SevJ8,3388
|
|
34
34
|
gpjax/kernels/stationary/__init__.py,sha256=j4BMTaQlIx2kNAT1Dkf4iO2rm-f7_oSVWNrk1bN0tqE,1406
|
|
35
|
-
gpjax/kernels/stationary/base.py,sha256=
|
|
36
|
-
gpjax/kernels/stationary/matern12.py,sha256=
|
|
37
|
-
gpjax/kernels/stationary/matern32.py,sha256=
|
|
38
|
-
gpjax/kernels/stationary/matern52.py,sha256=
|
|
35
|
+
gpjax/kernels/stationary/base.py,sha256=YSgur73wqAFZBzt8D6CfujvyjowXoPOK1po10gjzMJo,6038
|
|
36
|
+
gpjax/kernels/stationary/matern12.py,sha256=tfrIP-pJML-kSqVwf8wdVUwC53mhFohMnf7z5OMX3xE,1771
|
|
37
|
+
gpjax/kernels/stationary/matern32.py,sha256=vdrszTtH03PB9D4AKg-YC_56xCupK3tO9L8vH0GYFcc,1964
|
|
38
|
+
gpjax/kernels/stationary/matern52.py,sha256=IVsIKDo9bTZnA98yBmA_Q31peFdSCTXjf0_cVlLQh6k,2017
|
|
39
39
|
gpjax/kernels/stationary/periodic.py,sha256=f4PhWhKg-pJsEBGzEMK9pdbylO84GPKhzHlBC83ZVWw,3528
|
|
40
40
|
gpjax/kernels/stationary/powered_exponential.py,sha256=xuFGuIK0mKNMU3iLtZMXZTHXJuMFAMoX7gAtXefCdqU,3679
|
|
41
41
|
gpjax/kernels/stationary/rational_quadratic.py,sha256=zHo2LVW65T52XET4Hx9JaKO0TfxylV8WRUtP7sUUOx0,3418
|
|
@@ -46,7 +46,7 @@ gpjax/linalg/__init__.py,sha256=F8mxk_9Zc2nFd7Q-unjJ50_6rXEKzZj572WsU_jUKqI,547
|
|
|
46
46
|
gpjax/linalg/operations.py,sha256=xvhOy5P4FmUCPWjIVNdg1yDXaoFQ48anFUfR-Tnfr6k,6480
|
|
47
47
|
gpjax/linalg/operators.py,sha256=arxRGwcoAy_RqUYqBpZ3XG6OXbjShUl7m8sTpg85npE,11608
|
|
48
48
|
gpjax/linalg/utils.py,sha256=fKV8G_iKZVhNkNvN20D_dQEi93-8xosGbXBP-v7UEyo,2020
|
|
49
|
-
gpjax-0.
|
|
50
|
-
gpjax-0.
|
|
51
|
-
gpjax-0.
|
|
52
|
-
gpjax-0.
|
|
49
|
+
gpjax-0.13.1.dist-info/METADATA,sha256=qhsF8IgUu4oYTuLdqT56XQ4EKkDnUH_4D7imLxL_nPQ,10400
|
|
50
|
+
gpjax-0.13.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
51
|
+
gpjax-0.13.1.dist-info/licenses/LICENSE.txt,sha256=3umwi0h8wmKXOZO8XwRBwSl3vJt2hpWKEqSrSXLR7-I,1084
|
|
52
|
+
gpjax-0.13.1.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|