gpjax 0.13.0__py3-none-any.whl → 0.13.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
gpjax/__init__.py CHANGED
@@ -40,7 +40,7 @@ __license__ = "MIT"
40
40
  __description__ = "Gaussian processes in JAX and Flax"
41
41
  __url__ = "https://github.com/JaxGaussianProcesses/GPJax"
42
42
  __contributors__ = "https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors"
43
- __version__ = "0.13.0"
43
+ __version__ = "0.13.2"
44
44
 
45
45
  __all__ = [
46
46
  "gps",
@@ -6,9 +6,7 @@ from jaxtyping import Float
6
6
  import gpjax
7
7
  from gpjax.kernels.computations.base import AbstractKernelComputation
8
8
  from gpjax.linalg import (
9
- Dense,
10
9
  Diagonal,
11
- psd,
12
10
  )
13
11
  from gpjax.typing import Array
14
12
 
@@ -27,9 +25,9 @@ class BasisFunctionComputation(AbstractKernelComputation):
27
25
  z2 = self.compute_features(kernel, y)
28
26
  return self.scaling(kernel) * jnp.matmul(z1, z2.T)
29
27
 
30
- def _gram(self, kernel: K, inputs: Float[Array, "N D"]) -> Dense:
28
+ def _gram(self, kernel: K, inputs: Float[Array, "N D"]) -> Float[Array, "N N"]:
31
29
  z1 = self.compute_features(kernel, inputs)
32
- return psd(Dense(self.scaling(kernel) * jnp.matmul(z1, z1.T)))
30
+ return self.scaling(kernel) * jnp.matmul(z1, z1.T)
33
31
 
34
32
  def diagonal(self, kernel: K, inputs: Float[Array, "N D"]) -> Diagonal:
35
33
  r"""For a given kernel, compute the elementwise diagonal of the
@@ -15,7 +15,6 @@
15
15
 
16
16
 
17
17
  import beartype.typing as tp
18
- import jax.numpy as jnp
19
18
  from jaxtyping import (
20
19
  Float,
21
20
  Num,
@@ -39,17 +38,4 @@ class EigenKernelComputation(AbstractKernelComputation):
39
38
  def _cross_covariance(
40
39
  self, kernel: Kernel, x: Num[Array, "N D"], y: Num[Array, "M D"]
41
40
  ) -> Float[Array, "N M"]:
42
- # Transform the eigenvalues of the graph Laplacian according to the
43
- # RBF kernel's SPDE form.
44
- S = jnp.power(
45
- kernel.eigenvalues
46
- + 2
47
- * kernel.smoothness.value
48
- / kernel.lengthscale.value
49
- / kernel.lengthscale.value,
50
- -kernel.smoothness.value,
51
- )
52
- S = jnp.multiply(S, kernel.num_vertex / jnp.sum(S))
53
- # Scale the transform eigenvalues by the kernel variance
54
- S = jnp.multiply(S, kernel.variance.value)
55
- return kernel(x, y, S=S)
41
+ return kernel(x, y)
@@ -25,7 +25,10 @@ from gpjax.kernels.computations import (
25
25
  AbstractKernelComputation,
26
26
  EigenKernelComputation,
27
27
  )
28
- from gpjax.kernels.non_euclidean.utils import jax_gather_nd
28
+ from gpjax.kernels.non_euclidean.utils import (
29
+ calculate_heat_semigroup,
30
+ jax_gather_nd,
31
+ )
29
32
  from gpjax.kernels.stationary.base import StationaryKernel
30
33
  from gpjax.parameters import (
31
34
  Parameter,
@@ -98,14 +101,12 @@ class GraphKernel(StationaryKernel):
98
101
 
99
102
  super().__init__(active_dims, lengthscale, variance, n_dims, compute_engine)
100
103
 
101
- def __call__( # TODO not consistent with general kernel interface
104
+ def __call__(
102
105
  self,
103
106
  x: Int[Array, "N 1"],
104
- y: Int[Array, "N 1"],
105
- *,
106
- S,
107
- **kwargs,
107
+ y: Int[Array, "M 1"],
108
108
  ):
109
+ S = calculate_heat_semigroup(self)
109
110
  Kxx = (jax_gather_nd(self.eigenvectors, x) * S.squeeze()) @ jnp.transpose(
110
111
  jax_gather_nd(self.eigenvectors, y)
111
112
  ) # shape (n,n)
@@ -13,6 +13,10 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
+ from __future__ import annotations
17
+
18
+ import beartype.typing as tp
19
+ import jax.numpy as jnp
16
20
  from jaxtyping import (
17
21
  Float,
18
22
  Int,
@@ -20,6 +24,9 @@ from jaxtyping import (
20
24
 
21
25
  from gpjax.typing import Array
22
26
 
27
+ if tp.TYPE_CHECKING:
28
+ from gpjax.kernels.non_euclidean.graph import GraphKernel
29
+
23
30
 
24
31
  def jax_gather_nd(
25
32
  params: Float[Array, " N *rest"], indices: Int[Array, " M 1"]
@@ -41,3 +48,26 @@ def jax_gather_nd(
41
48
  """
42
49
  tuple_indices = tuple(indices[..., i] for i in range(indices.shape[-1]))
43
50
  return params[tuple_indices]
51
+
52
+
53
+ def calculate_heat_semigroup(kernel: GraphKernel) -> Float[Array, "N M"]:
54
+ r"""Returns the rescaled heat semigroup, S
55
+
56
+ Args:
57
+ kernel: instance of the graph kernel
58
+
59
+ Returns:
60
+ S
61
+ """
62
+ S = jnp.power(
63
+ kernel.eigenvalues
64
+ + 2
65
+ * kernel.smoothness.value
66
+ / kernel.lengthscale.value
67
+ / kernel.lengthscale.value,
68
+ -kernel.smoothness.value,
69
+ )
70
+ S = jnp.multiply(S, kernel.num_vertex / jnp.sum(S))
71
+ # Scale the transform eigenvalues by the kernel variance
72
+ S = jnp.multiply(S, kernel.variance.value)
73
+ return S
@@ -19,7 +19,10 @@ import beartype.typing as tp
19
19
  from flax import nnx
20
20
  import jax.numpy as jnp
21
21
  import jax.scipy as jsp
22
- from jaxtyping import Float
22
+ from jaxtyping import (
23
+ Float,
24
+ Int,
25
+ )
23
26
 
24
27
  from gpjax.dataset import Dataset
25
28
  from gpjax.distributions import GaussianDistribution
@@ -108,6 +111,7 @@ class AbstractVariationalGaussian(AbstractVariationalFamily[L]):
108
111
  self,
109
112
  posterior: AbstractPosterior[P, L],
110
113
  inducing_inputs: tp.Union[
114
+ Int[Array, "N D"],
111
115
  Float[Array, "N D"],
112
116
  Real,
113
117
  ],
@@ -140,7 +144,7 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):
140
144
  def __init__(
141
145
  self,
142
146
  posterior: AbstractPosterior[P, L],
143
- inducing_inputs: Float[Array, "N D"],
147
+ inducing_inputs: tp.Union[Int[Array, "N D"], Float[Array, "N D"]],
144
148
  variational_mean: tp.Union[Float[Array, "N 1"], None] = None,
145
149
  variational_root_covariance: tp.Union[Float[Array, "N N"], None] = None,
146
150
  jitter: ScalarFloat = 1e-6,
@@ -156,6 +160,12 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):
156
160
  self.variational_mean = Real(variational_mean)
157
161
  self.variational_root_covariance = LowerTriangular(variational_root_covariance)
158
162
 
163
+ def _fmt_Kzt_Ktt(self, Kzt, Ktt):
164
+ return Kzt, Ktt
165
+
166
+ def _fmt_inducing_inputs(self):
167
+ return self.inducing_inputs.value
168
+
159
169
  def prior_kl(self) -> ScalarFloat:
160
170
  r"""Compute the prior KL divergence.
161
171
 
@@ -178,7 +188,7 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):
178
188
  # Unpack variational parameters
179
189
  variational_mean = self.variational_mean.value
180
190
  variational_sqrt = self.variational_root_covariance.value
181
- inducing_inputs = self.inducing_inputs.value
191
+ inducing_inputs = self._fmt_inducing_inputs()
182
192
 
183
193
  # Unpack mean function and kernel
184
194
  mean_function = self.posterior.prior.mean_function
@@ -202,7 +212,9 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):
202
212
 
203
213
  return q_inducing.kl_divergence(p_inducing)
204
214
 
205
- def predict(self, test_inputs: Float[Array, "N D"]) -> GaussianDistribution:
215
+ def predict(
216
+ self, test_inputs: tp.Union[Int[Array, "N D"], Float[Array, "N D"]]
217
+ ) -> GaussianDistribution:
206
218
  r"""Compute the predictive distribution of the GP at the test inputs t.
207
219
 
208
220
  This is the integral $q(f(t)) = \int p(f(t)\mid u) q(u) \mathrm{d}u$, which
@@ -222,7 +234,7 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):
222
234
  # Unpack variational parameters
223
235
  variational_mean = self.variational_mean.value
224
236
  variational_sqrt = self.variational_root_covariance.value
225
- inducing_inputs = self.inducing_inputs.value
237
+ inducing_inputs = self._fmt_inducing_inputs()
226
238
 
227
239
  # Unpack mean function and kernel
228
240
  mean_function = self.posterior.prior.mean_function
@@ -241,6 +253,8 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):
241
253
  Kzt = kernel.cross_covariance(inducing_inputs, test_points)
242
254
  test_mean = mean_function(test_points)
243
255
 
256
+ Kzt, Ktt = self._fmt_Kzt_Ktt(Kzt, Ktt)
257
+
244
258
  # Lz⁻¹ Kzt
245
259
  Lz_inv_Kzt = solve(Lz, Kzt)
246
260
 
@@ -259,8 +273,10 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):
259
273
  - jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt)
260
274
  + jnp.matmul(Ktz_Kzz_inv_sqrt, Ktz_Kzz_inv_sqrt.T)
261
275
  )
276
+
262
277
  if hasattr(covariance, "to_dense"):
263
278
  covariance = covariance.to_dense()
279
+
264
280
  covariance = add_jitter(covariance, self.jitter)
265
281
  covariance = Dense(covariance)
266
282
 
@@ -269,6 +285,53 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):
269
285
  )
270
286
 
271
287
 
288
+ class GraphVariationalGaussian(VariationalGaussian[L]):
289
+ r"""A variational Gaussian defined over graph-structured inducing inputs.
290
+
291
+ This subclass adapts the :class:`VariationalGaussian` family to the
292
+ case where the inducing inputs are discrete graph node indices rather
293
+ than continuous spatial coordinates.
294
+
295
+ The main differences are:
296
+ * Inducing inputs are integer node IDs.
297
+ * Kernel matrices are ensured to be dense and 2D.
298
+ """
299
+
300
+ def __init__(
301
+ self,
302
+ posterior: AbstractPosterior[P, L],
303
+ inducing_inputs: Int[Array, "N D"],
304
+ variational_mean: tp.Union[Float[Array, "N 1"], None] = None,
305
+ variational_root_covariance: tp.Union[Float[Array, "N N"], None] = None,
306
+ jitter: ScalarFloat = 1e-6,
307
+ ):
308
+ super().__init__(
309
+ posterior,
310
+ inducing_inputs,
311
+ variational_mean,
312
+ variational_root_covariance,
313
+ jitter,
314
+ )
315
+ self.inducing_inputs = self.inducing_inputs.value.astype(jnp.int64)
316
+
317
+ def _fmt_Kzt_Ktt(self, Kzt, Ktt):
318
+ Ktt = Ktt.to_dense() if hasattr(Ktt, "to_dense") else Ktt
319
+ Kzt = Kzt.to_dense() if hasattr(Kzt, "to_dense") else Kzt
320
+ Ktt = jnp.atleast_2d(Ktt)
321
+ Kzt = (
322
+ jnp.transpose(jnp.atleast_2d(Kzt)) if Kzt.ndim < 2 else jnp.atleast_2d(Kzt)
323
+ )
324
+ return Kzt, Ktt
325
+
326
+ def _fmt_inducing_inputs(self):
327
+ return self.inducing_inputs
328
+
329
+ @property
330
+ def num_inducing(self) -> int:
331
+ """The number of inducing inputs."""
332
+ return self.inducing_inputs.shape[0]
333
+
334
+
272
335
  class WhitenedVariationalGaussian(VariationalGaussian[L]):
273
336
  r"""The whitened variational Gaussian family of probability distributions.
274
337
 
@@ -811,6 +874,7 @@ __all__ = [
811
874
  "AbstractVariationalFamily",
812
875
  "AbstractVariationalGaussian",
813
876
  "VariationalGaussian",
877
+ "GraphVariationalGaussian",
814
878
  "WhitenedVariationalGaussian",
815
879
  "NaturalVariationalGaussian",
816
880
  "ExpectationVariationalGaussian",
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: gpjax
3
- Version: 0.13.0
3
+ Version: 0.13.2
4
4
  Summary: Gaussian processes in JAX.
5
5
  Project-URL: Documentation, https://docs.jaxgaussianprocesses.com/
6
6
  Project-URL: Issues, https://github.com/JaxGaussianProcesses/GPJax/issues
@@ -1,4 +1,4 @@
1
- gpjax/__init__.py,sha256=dB99f4sP0k5SfxABCyVMYkLVf-UVCXY-iR6153RRpv8,1641
1
+ gpjax/__init__.py,sha256=jpuNxARmW3gOmXFAhJvrXeMk_rvYDn49E-8wy3ATZKg,1641
2
2
  gpjax/citation.py,sha256=pwFS8h1J-LE5ieRS0zDyuwhmQHNxkFHYE7iSMlVNmQc,3928
3
3
  gpjax/dataset.py,sha256=NsToLKq4lOsHnfLfukrUIRKvhOEuoUk8aHTF0oAqRbU,4079
4
4
  gpjax/distributions.py,sha256=iKmeQ_NN2CIjRiuOeJlwEGASzGROi4ZCerVi1uY7zRM,7758
@@ -12,21 +12,21 @@ gpjax/objectives.py,sha256=GvKbDIPqYjsc9FpiTccmZwRdHr2lCykgfxI9BX9I_GA,15362
12
12
  gpjax/parameters.py,sha256=hnyIKr6uIzd7Kb3KZC9WowR88ruQwUvdcto3cx2ZDv4,6756
13
13
  gpjax/scan.py,sha256=jStQvwkE9MGttB89frxam1kaeXdWih7cVxkGywyaeHQ,5365
14
14
  gpjax/typing.py,sha256=M3CvWsYtZ3PFUvBvvbRNjpwerNII0w4yGuP0I-sLeYI,1705
15
- gpjax/variational_families.py,sha256=TJGGkwkE805X4PQb-C32FxvD9B_OsFLWf6I-ZZvOUWk,29628
15
+ gpjax/variational_families.py,sha256=x4VnUh9GiW77ijnDEwrYzH0aWFOW6NLPQXEp--I9R-g,31566
16
16
  gpjax/kernels/__init__.py,sha256=WZanH0Tpdkt0f7VfMqnalm_VZAMVwBqeOVaICNj6xQU,1901
17
17
  gpjax/kernels/base.py,sha256=4vUV6_wnbV64EUaOecVWE9Vz4G5Dg6xUMNsQSis8A1o,11561
18
18
  gpjax/kernels/approximations/__init__.py,sha256=bK9HlGd-PZeGrqtG5RpXxUTXNUrZTgfjH1dP626yNMA,68
19
19
  gpjax/kernels/approximations/rff.py,sha256=WJ2Bw4RRW5bLXGKfStxph65H-ozDelvt_4mxXQm-g7w,4074
20
20
  gpjax/kernels/computations/__init__.py,sha256=uTVkqvnZVesFLDN92h0ZR0jfR69Eo2WyjOlmSYmCPJ8,1379
21
21
  gpjax/kernels/computations/base.py,sha256=L6K0roxZbrYeJKxEw-yaTiK9Mtcv0YtZfWI2Xnau7i8,3616
22
- gpjax/kernels/computations/basis_functions.py,sha256=_SFv4Tiwne40bxr1uVYpEjjZgjIQHKseLmss2Zgl1L4,2484
22
+ gpjax/kernels/computations/basis_functions.py,sha256=Y6OS4UEvo-Mdn7yRMWPde5rOvewDHIleaUIVBOpCTd4,2466
23
23
  gpjax/kernels/computations/constant_diagonal.py,sha256=JkQhLj7cK48IhOER4ivkALNhD1oQleKe-Rr9BtUJ6es,1984
24
24
  gpjax/kernels/computations/dense.py,sha256=vnW6XKQe4_gzpXRWTctxhgMA9-9TebdtiXzAqh_-j6g,1392
25
25
  gpjax/kernels/computations/diagonal.py,sha256=k1KqW0DwWRIBvbb7jzcKktXRfhXbcos3ncWrFplJ4W0,1768
26
- gpjax/kernels/computations/eigen.py,sha256=NTHm-cn-RepYuXFrvXo2ih7Gtu1YR_pAg4Jb7IhE_o8,1930
26
+ gpjax/kernels/computations/eigen.py,sha256=LuwYVPK0AuDGNSccPmAS8wyDJ2ngkRbc6kNZzqiJEOg,1380
27
27
  gpjax/kernels/non_euclidean/__init__.py,sha256=RT7puRPqCTpyxZ16q596EuOQEQi1LK1v3J9_fWz1NlY,790
28
- gpjax/kernels/non_euclidean/graph.py,sha256=xTrx6ro8ubRXgM7Wgg6NmOyyEjEcGhzydY7KXueknCc,4120
29
- gpjax/kernels/non_euclidean/utils.py,sha256=z42aw8ga0zuREzHawemR9okttgrAUPmq-aN5HMt4SuY,1578
28
+ gpjax/kernels/non_euclidean/graph.py,sha256=MukL4ZghUywkau4tVWmHsrAjzTcl9kHN6onEAvxSvjc,4109
29
+ gpjax/kernels/non_euclidean/utils.py,sha256=vxjRX209tOL3jKsw69hQfVtBqkASDZ__4tJOGI4oLjo,2342
30
30
  gpjax/kernels/nonstationary/__init__.py,sha256=YpWQfOy_cqOKc5ezn37vqoK3Z6jznYiJz28BD_8F7AY,930
31
31
  gpjax/kernels/nonstationary/arccosine.py,sha256=cqb8sqaNwW3fEbrA7MY9OF2KJFTkxHhqwmQtABE3G8w,5408
32
32
  gpjax/kernels/nonstationary/linear.py,sha256=UIMoCq2hg6dQKr4J5UGiiPqotBleQuYfy00Ia1NaMOo,2571
@@ -46,7 +46,7 @@ gpjax/linalg/__init__.py,sha256=F8mxk_9Zc2nFd7Q-unjJ50_6rXEKzZj572WsU_jUKqI,547
46
46
  gpjax/linalg/operations.py,sha256=xvhOy5P4FmUCPWjIVNdg1yDXaoFQ48anFUfR-Tnfr6k,6480
47
47
  gpjax/linalg/operators.py,sha256=arxRGwcoAy_RqUYqBpZ3XG6OXbjShUl7m8sTpg85npE,11608
48
48
  gpjax/linalg/utils.py,sha256=fKV8G_iKZVhNkNvN20D_dQEi93-8xosGbXBP-v7UEyo,2020
49
- gpjax-0.13.0.dist-info/METADATA,sha256=LZllpjzOW4QDzpUX-haSNrLMXzkiREGpaYveG6jeufE,10400
50
- gpjax-0.13.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
51
- gpjax-0.13.0.dist-info/licenses/LICENSE.txt,sha256=3umwi0h8wmKXOZO8XwRBwSl3vJt2hpWKEqSrSXLR7-I,1084
52
- gpjax-0.13.0.dist-info/RECORD,,
49
+ gpjax-0.13.2.dist-info/METADATA,sha256=_dkTRApmtHexaDFVnH8FhpOf_YkIyci2Gihg9elVeII,10400
50
+ gpjax-0.13.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
51
+ gpjax-0.13.2.dist-info/licenses/LICENSE.txt,sha256=3umwi0h8wmKXOZO8XwRBwSl3vJt2hpWKEqSrSXLR7-I,1084
52
+ gpjax-0.13.2.dist-info/RECORD,,
File without changes