gpjax 0.11.2__py3-none-any.whl → 0.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
gpjax/objectives.py CHANGED
@@ -1,13 +1,5 @@
  from typing import TypeVar

- from cola.annotations import PSD
- from cola.linalg.decompositions.decompositions import Cholesky
- from cola.linalg.inverse.inv import (
-     inv,
-     solve,
- )
- from cola.linalg.trace.diag_trace import diag
- from cola.ops.operators import I_like
  from flax import nnx
  from jax import vmap
  import jax.numpy as jnp
@@ -22,7 +14,12 @@ from gpjax.gps import (
      ConjugatePosterior,
      NonConjugatePosterior,
  )
- from gpjax.lower_cholesky import lower_cholesky
+ from gpjax.linalg import (
+     Dense,
+     lower_cholesky,
+     psd,
+     solve,
+ )
  from gpjax.typing import (
      Array,
      ScalarFloat,
@@ -100,9 +97,9 @@ def conjugate_mll(posterior: ConjugatePosterior, data: Dataset) -> ScalarFloat:

      # Σ = (Kxx + Io²) = LLᵀ
      Kxx = posterior.prior.kernel.gram(x)
-     Kxx += I_like(Kxx) * posterior.prior.jitter
-     Sigma = Kxx + I_like(Kxx) * obs_noise
-     Sigma = PSD(Sigma)
+     Kxx_dense = Kxx.to_dense() + jnp.eye(Kxx.shape[0]) * posterior.prior.jitter
+     Sigma_dense = Kxx_dense + jnp.eye(Kxx.shape[0]) * obs_noise
+     Sigma = psd(Dense(Sigma_dense))

      # p(y | x, θ), where θ are the model hyperparameters:
      mll = GaussianDistribution(jnp.atleast_1d(mx.squeeze()), Sigma)
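This hunk shows the release's central refactor: CoLA's `I_like` and `PSD` helpers are replaced by explicit densification, `jnp.eye` jitter, and the new `psd(Dense(...))` wrapper from `gpjax.linalg`. The same three-line pattern recurs in every hunk below; a minimal sketch of it as a standalone helper (the helper name is hypothetical, while `Dense`, `psd`, `.to_dense()`, and `.shape` usage are taken from the diff):

```python
import jax.numpy as jnp

from gpjax.linalg import Dense, psd


def stabilised_psd(K, jitter):
    # Hypothetical helper mirroring the inline pattern above: densify the
    # Gram operator, add diagonal jitter, and re-annotate the result as PSD.
    K_dense = K.to_dense() + jnp.eye(K.shape[0]) * jitter
    return psd(Dense(K_dense))
```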
@@ -164,11 +161,14 @@ def conjugate_loocv(posterior: ConjugatePosterior, data: Dataset) -> ScalarFloat

      # Σ = (Kxx + Io²)
      Kxx = posterior.prior.kernel.gram(x)
-     Sigma = Kxx + I_like(Kxx) * (obs_var + posterior.prior.jitter)
-     Sigma = PSD(Sigma)  # [N, N]
+     Sigma_dense = Kxx.to_dense() + jnp.eye(Kxx.shape[0]) * (
+         obs_var + posterior.prior.jitter
+     )
+     Sigma = psd(Dense(Sigma_dense))  # [N, N]

-     Sigma_inv_y = solve(Sigma, y - mx, Cholesky())  # [N, 1]
-     Sigma_inv_diag = diag(inv(Sigma, Cholesky()))[:, None]  # [N, 1]
+     Sigma_inv_y = solve(Sigma, y - mx)  # [N, 1]
+     Sigma_inv = jnp.linalg.inv(Sigma.to_dense())
+     Sigma_inv_diag = jnp.diag(Sigma_inv)[:, None]  # [N, 1]

      loocv_means = mx + (y - mx) - Sigma_inv_y / Sigma_inv_diag
      loocv_stds = jnp.sqrt(1.0 / Sigma_inv_diag)
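For reference, these lines implement the standard closed-form leave-one-out identities for a Gaussian model (Rasmussen & Williams, Eqs. 5.10–5.12), with the CoLA `diag(inv(Σ, Cholesky()))` call traded for an explicit dense inverse:

```latex
\mu_i^{\mathrm{LOO}} = y_i - \frac{\bigl[\Sigma^{-1}(y - m)\bigr]_i}{\bigl[\Sigma^{-1}\bigr]_{ii}},
\qquad
\bigl(\sigma_i^{\mathrm{LOO}}\bigr)^2 = \frac{1}{\bigl[\Sigma^{-1}\bigr]_{ii}},
\qquad
\Sigma = K_{xx} + \sigma^2 I.
```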
@@ -213,8 +213,8 @@ def log_posterior_density(

      # Gram matrix
      Kxx = posterior.prior.kernel.gram(x)
-     Kxx += I_like(Kxx) * posterior.prior.jitter
-     Kxx = PSD(Kxx)
+     Kxx_dense = Kxx.to_dense() + jnp.eye(Kxx.shape[0]) * posterior.prior.jitter
+     Kxx = psd(Dense(Kxx_dense))
      Lx = lower_cholesky(Kxx)

      # Compute the prior mean function
@@ -349,8 +349,8 @@ def collapsed_elbo(variational_family: VF, data: Dataset) -> ScalarFloat:
      noise = variational_family.posterior.likelihood.obs_stddev.value**2
      z = variational_family.inducing_inputs.value
      Kzz = kernel.gram(z)
-     Kzz += I_like(Kzz) * variational_family.jitter
-     Kzz = PSD(Kzz)
+     Kzz_dense = Kzz.to_dense() + jnp.eye(Kzz.shape[0]) * variational_family.jitter
+     Kzz = psd(Dense(Kzz_dense))
      Kzx = kernel.cross_covariance(z, x)
      Kxx_diag = vmap(kernel, in_axes=(0, 0))(x, x)
      μx = mean_function(x)
@@ -383,7 +383,7 @@ def collapsed_elbo(variational_family: VF, data: Dataset) -> ScalarFloat:
      #
      # with A and B defined as above.

-     A = solve(Lz, Kzx, Cholesky()) / jnp.sqrt(noise)
+     A = solve(Lz, Kzx) / jnp.sqrt(noise)

      # AAᵀ
      AAT = jnp.matmul(A, A.T)
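Note the `solve` signature change running through these hunks: CoLA's three-argument `solve(A, b, Cholesky())` becomes a two-argument `gpjax.linalg.solve(A, b)`, with the factorisation chosen internally. A tiny sketch under that assumption (the operator and right-hand side are invented for illustration):

```python
import jax.numpy as jnp

from gpjax.linalg import Dense, psd, solve

A = psd(Dense(2.0 * jnp.eye(3)))  # a trivially PSD operator
b = jnp.ones((3, 1))

x = solve(A, b)  # 0.12.0 API: no algorithm argument; expect x == b / 2
```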
gpjax/parameters.py CHANGED
@@ -21,23 +21,20 @@ def transform(
      r"""Transforms parameters using a bijector.

      Example:
-         ```pycon
          >>> from gpjax.parameters import PositiveReal, transform
          >>> import jax.numpy as jnp
          >>> import numpyro.distributions.transforms as npt
          >>> from flax import nnx
          >>> params = nnx.State(
-         >>>     {
-         >>>         "a": PositiveReal(jnp.array([1.0])),
-         >>>         "b": PositiveReal(jnp.array([2.0])),
-         >>>     }
-         >>> )
+         ...     {
+         ...         "a": PositiveReal(jnp.array([1.0])),
+         ...         "b": PositiveReal(jnp.array([2.0])),
+         ...     }
+         ... )
          >>> params_bijection = {'positive': npt.SoftplusTransform()}
          >>> transformed_params = transform(params, params_bijection)
          >>> print(transformed_params["a"].value)
-         [1.3132617]
-         ```
-
+         [1.3132617]

      Args:
          params: A nnx.State object containing parameters to be transformed.
@@ -49,7 +46,7 @@ def transform(
      """

      def _inner(param):
-         bijector = params_bijection.get(param._tag, npt.IdentityTransform())
+         bijector = params_bijection.get(param.tag, npt.IdentityTransform())
          if inverse:
              transformed_value = bijector.inv(param.value)
          else:
@@ -60,10 +57,11 @@ def transform(

      gp_params, *other_params = params.split(Parameter, ...)

+     # Transform each parameter in the state
      transformed_gp_params: nnx.State = jtu.tree_map(
-         lambda x: _inner(x),
+         lambda x: _inner(x) if isinstance(x, Parameter) else x,
          gp_params,
-         is_leaf=lambda x: isinstance(x, nnx.VariableState),
+         is_leaf=lambda x: isinstance(x, Parameter),
      )
      return nnx.State.merge(transformed_gp_params, *other_params)

@@ -79,7 +77,7 @@ class Parameter(nnx.Variable[T]):
          _check_is_arraylike(value)

          super().__init__(value=jnp.asarray(value), **kwargs)
-         self._tag = tag
+         self.tag = tag


  class NonNegativeReal(Parameter[T]):
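The `_tag` → `tag` rename makes the bijection key a public attribute, which is what `transform` now reads via `param.tag`. A short sketch, assuming `PositiveReal` carries the `'positive'` tag used in the doctest above:

```python
import jax.numpy as jnp
import numpyro.distributions.transforms as npt
from flax import nnx

from gpjax.parameters import PositiveReal, transform

params = nnx.State({"lengthscale": PositiveReal(jnp.array([0.5]))})
print(params["lengthscale"].tag)  # assumed to print 'positive'

# Build the bijection mapping keyed on the (now public) tag, then map the
# parameter to unconstrained space via the inverse transform.
bijection = {params["lengthscale"].tag: npt.SoftplusTransform()}
unconstrained = transform(params, bijection, inverse=True)
```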
gpjax/variational_families.py CHANGED
@@ -16,15 +16,6 @@
  import abc

  import beartype.typing as tp
- from cola.annotations import PSD
- from cola.linalg.decompositions.decompositions import Cholesky
- from cola.linalg.inverse.inv import solve
- from cola.ops.operators import (
-     Dense,
-     I_like,
-     Identity,
-     Triangular,
- )
  from flax import nnx
  import jax.numpy as jnp
  import jax.scipy as jsp
@@ -41,7 +32,14 @@ from gpjax.likelihoods import (
      Gaussian,
      NonGaussian,
  )
- from gpjax.lower_cholesky import lower_cholesky
+ from gpjax.linalg import (
+     Dense,
+     Identity,
+     Triangular,
+     lower_cholesky,
+     psd,
+     solve,
+ )
  from gpjax.mean_functions import AbstractMeanFunction
  from gpjax.parameters import (
      LowerTriangular,
@@ -189,7 +187,7 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):

          muz = mean_function(z)
          Kzz = kernel.gram(z)
-         Kzz = PSD(Kzz + I_like(Kzz) * self.jitter)
+         Kzz = psd(Dense(Kzz.to_dense() + jnp.eye(Kzz.shape[0]) * self.jitter))

          sqrt = Triangular(sqrt)
          S = sqrt @ sqrt.T
@@ -226,7 +224,8 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):
          kernel = self.posterior.prior.kernel

          Kzz = kernel.gram(z)
-         Kzz += I_like(Kzz) * self.jitter
+         Kzz_dense = Kzz.to_dense() + jnp.eye(Kzz.shape[0]) * self.jitter
+         Kzz = psd(Dense(Kzz_dense))
          Lz = lower_cholesky(Kzz)
          muz = mean_function(z)

@@ -238,10 +237,10 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):
          mut = mean_function(t)

          # Lz⁻¹ Kzt
-         Lz_inv_Kzt = solve(Lz, Kzt, Cholesky())
+         Lz_inv_Kzt = solve(Lz, Kzt)

          # Kzz⁻¹ Kzt
-         Kzz_inv_Kzt = solve(Lz.T, Lz_inv_Kzt, Cholesky())
+         Kzz_inv_Kzt = solve(Lz.T, Lz_inv_Kzt)

          # Ktz Kzz⁻¹ sqrt
          Ktz_Kzz_inv_sqrt = jnp.matmul(Kzz_inv_Kzt.T, sqrt)
@@ -255,7 +254,7 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):
              - jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt)
              + jnp.matmul(Ktz_Kzz_inv_sqrt, Ktz_Kzz_inv_sqrt.T)
          )
-         covariance += I_like(covariance) * self.jitter
+         covariance += jnp.eye(covariance.shape[0]) * self.jitter

          return GaussianDistribution(
              loc=jnp.atleast_1d(mean.squeeze()), scale=covariance
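The covariance assembled in this method matches the standard sparse variational predictive for q(u) = N(m, S) with S = sqrt sqrtᵀ: the two matmul terms above are Ktz Kzz⁻¹ Kzt and Ktz Kzz⁻¹ S Kzz⁻¹ Kzt respectively, i.e.

```latex
\operatorname{Cov}\bigl[f(t)\bigr]
  = K_{tt} - K_{tz} K_{zz}^{-1} K_{zt}
           + K_{tz} K_{zz}^{-1} S \, K_{zz}^{-1} K_{zt}.
```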
@@ -330,7 +329,8 @@ class WhitenedVariationalGaussian(VariationalGaussian[L]):
          kernel = self.posterior.prior.kernel

          Kzz = kernel.gram(z)
-         Kzz += I_like(Kzz) * self.jitter
+         Kzz_dense = Kzz.to_dense() + jnp.eye(Kzz.shape[0]) * self.jitter
+         Kzz = psd(Dense(Kzz_dense))
          Lz = lower_cholesky(Kzz)

          # Unpack test inputs
@@ -341,7 +341,7 @@ class WhitenedVariationalGaussian(VariationalGaussian[L]):
          mut = mean_function(t)

          # Lz⁻¹ Kzt
-         Lz_inv_Kzt = solve(Lz, Kzt, Cholesky())
+         Lz_inv_Kzt = solve(Lz, Kzt)

          # Ktz Lz⁻ᵀ sqrt
          Ktz_Lz_invT_sqrt = jnp.matmul(Lz_inv_Kzt.T, sqrt)
@@ -355,7 +355,7 @@ class WhitenedVariationalGaussian(VariationalGaussian[L]):
              - jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt)
              + jnp.matmul(Ktz_Lz_invT_sqrt, Ktz_Lz_invT_sqrt.T)
          )
-         covariance += I_like(covariance) * self.jitter
+         covariance += jnp.eye(covariance.shape[0]) * self.jitter

          return GaussianDistribution(
              loc=jnp.atleast_1d(mean.squeeze()), scale=covariance
@@ -441,7 +441,8 @@ class NaturalVariationalGaussian(AbstractVariationalGaussian[L]):

          muz = mean_function(z)
          Kzz = kernel.gram(z)
-         Kzz += I_like(Kzz) * self.jitter
+         Kzz_dense = Kzz.to_dense() + jnp.eye(Kzz.shape[0]) * self.jitter
+         Kzz = psd(Dense(Kzz_dense))

          qu = GaussianDistribution(loc=jnp.atleast_1d(mu.squeeze()), scale=S)
          pu = GaussianDistribution(loc=jnp.atleast_1d(muz.squeeze()), scale=Kzz)
@@ -492,7 +493,8 @@ class NaturalVariationalGaussian(AbstractVariationalGaussian[L]):
          mu = jnp.matmul(S, natural_vector)

          Kzz = kernel.gram(z)
-         Kzz += I_like(Kzz) * self.jitter
+         Kzz_dense = Kzz.to_dense() + jnp.eye(Kzz.shape[0]) * self.jitter
+         Kzz = psd(Dense(Kzz_dense))
          Lz = lower_cholesky(Kzz)
          muz = mean_function(z)

@@ -501,10 +503,10 @@ class NaturalVariationalGaussian(AbstractVariationalGaussian[L]):
          mut = mean_function(test_inputs)

          # Lz⁻¹ Kzt
-         Lz_inv_Kzt = solve(Lz, Kzt, Cholesky())
+         Lz_inv_Kzt = solve(Lz, Kzt)

          # Kzz⁻¹ Kzt
-         Kzz_inv_Kzt = solve(Lz.T, Lz_inv_Kzt, Cholesky())
+         Kzz_inv_Kzt = solve(Lz.T, Lz_inv_Kzt)

          # Ktz Kzz⁻¹ L
          Ktz_Kzz_inv_L = jnp.matmul(Kzz_inv_Kzt.T, sqrt)
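For orientation, the `mu = jnp.matmul(S, natural_vector)` line in the previous hunk recovers the mean from the first natural parameter of q(u) = N(μ, S):

```latex
\theta_1 = S^{-1}\mu, \qquad \theta_2 = -\tfrac{1}{2} S^{-1}
\qquad\Longrightarrow\qquad \mu = S\,\theta_1.
```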
@@ -518,7 +520,7 @@ class NaturalVariationalGaussian(AbstractVariationalGaussian[L]):
              - jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt)
              + jnp.matmul(Ktz_Kzz_inv_L, Ktz_Kzz_inv_L.T)
          )
-         covariance += I_like(covariance) * self.jitter
+         covariance += jnp.eye(covariance.shape[0]) * self.jitter

          return GaussianDistribution(
              loc=jnp.atleast_1d(mean.squeeze()), scale=covariance
@@ -592,12 +594,14 @@ class ExpectationVariationalGaussian(AbstractVariationalGaussian[L]):

          # S = η₂ - η₁ η₁ᵀ
          S = expectation_matrix - jnp.outer(mu, mu)
-         S = PSD(Dense(S))
-         S += I_like(S) * self.jitter
+         S = psd(Dense(S))
+         S_dense = S.to_dense() + jnp.eye(S.shape[0]) * self.jitter
+         S = psd(Dense(S_dense))

          muz = mean_function(z)
          Kzz = kernel.gram(z)
-         Kzz += I_like(Kzz) * self.jitter
+         Kzz_dense = Kzz.to_dense() + jnp.eye(Kzz.shape[0]) * self.jitter
+         Kzz = psd(Dense(Kzz_dense))

          qu = GaussianDistribution(loc=jnp.atleast_1d(mu.squeeze()), scale=S)
          pu = GaussianDistribution(loc=jnp.atleast_1d(muz.squeeze()), scale=Kzz)
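The `# S = η₂ - η₁ η₁ᵀ` comment refers to the expectation parameters of a Gaussian q(u) = N(μ, S):

```latex
\eta_1 = \mathbb{E}[u] = \mu, \qquad
\eta_2 = \mathbb{E}\bigl[u u^\top\bigr] = S + \mu \mu^\top
\qquad\Longrightarrow\qquad
S = \eta_2 - \eta_1 \eta_1^\top.
```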
@@ -636,14 +640,15 @@ class ExpectationVariationalGaussian(AbstractVariationalGaussian[L]):

          # S = η₂ - η₁ η₁ᵀ
          S = expectation_matrix - jnp.matmul(mu, mu.T)
-         S = Dense(S) + I_like(S) * self.jitter
-         S = PSD(S)
+         S = Dense(S + jnp.eye(S.shape[0]) * self.jitter)
+         S = psd(S)

          # S = sqrt sqrtᵀ
          sqrt = lower_cholesky(S)

          Kzz = kernel.gram(z)
-         Kzz += I_like(Kzz) * self.jitter
+         Kzz_dense = Kzz.to_dense() + jnp.eye(Kzz.shape[0]) * self.jitter
+         Kzz = psd(Dense(Kzz_dense))
          Lz = lower_cholesky(Kzz)
          muz = mean_function(z)

@@ -655,10 +660,10 @@ class ExpectationVariationalGaussian(AbstractVariationalGaussian[L]):
          mut = mean_function(t)

          # Lz⁻¹ Kzt
-         Lz_inv_Kzt = solve(Lz, Kzt, Cholesky())
+         Lz_inv_Kzt = solve(Lz, Kzt)

          # Kzz⁻¹ Kzt
-         Kzz_inv_Kzt = solve(Lz.T, Lz_inv_Kzt, Cholesky())
+         Kzz_inv_Kzt = solve(Lz.T, Lz_inv_Kzt)

          # Ktz Kzz⁻¹ sqrt
          Ktz_Kzz_inv_sqrt = Kzz_inv_Kzt.T @ sqrt
@@ -672,7 +677,7 @@ class ExpectationVariationalGaussian(AbstractVariationalGaussian[L]):
              - jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt)
              + jnp.matmul(Ktz_Kzz_inv_sqrt, Ktz_Kzz_inv_sqrt.T)
          )
-         covariance += I_like(covariance) * self.jitter
+         covariance += jnp.eye(covariance.shape[0]) * self.jitter

          return GaussianDistribution(
              loc=jnp.atleast_1d(mean.squeeze()), scale=covariance
@@ -729,13 +734,14 @@ class CollapsedVariationalGaussian(AbstractVariationalGaussian[GL]):

          Kzx = kernel.cross_covariance(z, x)
          Kzz = kernel.gram(z)
-         Kzz += I_like(Kzz) * self.jitter
+         Kzz_dense = Kzz.to_dense() + jnp.eye(Kzz.shape[0]) * self.jitter
+         Kzz = psd(Dense(Kzz_dense))

          # Lz Lzᵀ = Kzz
          Lz = lower_cholesky(Kzz)

          # Lz⁻¹ Kzx
-         Lz_inv_Kzx = solve(Lz, Kzx, Cholesky())
+         Lz_inv_Kzx = solve(Lz, Kzx)

          # A = Lz⁻¹ Kzt / o
          A = Lz_inv_Kzx / self.posterior.likelihood.obs_stddev.value
@@ -753,14 +759,14 @@ class CollapsedVariationalGaussian(AbstractVariationalGaussian[GL]):
          Lz_inv_Kzx_diff = jsp.linalg.cho_solve((L, True), jnp.matmul(Lz_inv_Kzx, diff))

          # Kzz⁻¹ Kzx (y - μx)
-         Kzz_inv_Kzx_diff = solve(Lz.T, Lz_inv_Kzx_diff, Cholesky())
+         Kzz_inv_Kzx_diff = solve(Lz.T, Lz_inv_Kzx_diff)

          Ktt = kernel.gram(t)
          Kzt = kernel.cross_covariance(z, t)
          mut = mean_function(t)

          # Lz⁻¹ Kzt
-         Lz_inv_Kzt = solve(Lz, Kzt, Cholesky())
+         Lz_inv_Kzt = solve(Lz, Kzt)

          # L⁻¹ Lz⁻¹ Kzt
          L_inv_Lz_inv_Kzt = jsp.linalg.solve_triangular(L, Lz_inv_Kzt, lower=True)
@@ -774,7 +780,7 @@ class CollapsedVariationalGaussian(AbstractVariationalGaussian[GL]):
              - jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt)
              + jnp.matmul(L_inv_Lz_inv_Kzt.T, L_inv_Lz_inv_Kzt)
          )
-         covariance += I_like(covariance) * self.jitter
+         covariance += jnp.eye(covariance.shape[0]) * self.jitter

          return GaussianDistribution(
              loc=jnp.atleast_1d(mean.squeeze()), scale=covariance
gpjax-0.11.2.dist-info/METADATA → gpjax-0.12.0.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: gpjax
- Version: 0.11.2
+ Version: 0.12.0
  Summary: Gaussian processes in JAX.
  Project-URL: Documentation, https://docs.jaxgaussianprocesses.com/
  Project-URL: Issues, https://github.com/JaxGaussianProcesses/GPJax/issues
@@ -14,11 +14,11 @@ Classifier: Programming Language :: Python
  Classifier: Programming Language :: Python :: 3.10
  Classifier: Programming Language :: Python :: 3.11
  Classifier: Programming Language :: Python :: 3.12
+ Classifier: Programming Language :: Python :: 3.13
  Classifier: Programming Language :: Python :: Implementation :: CPython
  Classifier: Programming Language :: Python :: Implementation :: PyPy
- Requires-Python: <3.13,>=3.10
+ Requires-Python: <=3.13,>=3.10
  Requires-Dist: beartype>0.16.1
- Requires-Dist: cola-ml>=0.0.7
  Requires-Dist: flax>=0.10.0
  Requires-Dist: jax>=0.5.0
  Requires-Dist: jaxlib>=0.5.0
@@ -27,6 +27,47 @@ Requires-Dist: numpy>=2.0.0
  Requires-Dist: numpyro
  Requires-Dist: optax>0.2.1
  Requires-Dist: tqdm>4.66.2
+ Provides-Extra: dev
+ Requires-Dist: absolufy-imports>=0.3.1; extra == 'dev'
+ Requires-Dist: autoflake; extra == 'dev'
+ Requires-Dist: black; extra == 'dev'
+ Requires-Dist: codespell>=2.2.4; extra == 'dev'
+ Requires-Dist: coverage>=7.2.2; extra == 'dev'
+ Requires-Dist: interrogate>=1.5.0; extra == 'dev'
+ Requires-Dist: isort; extra == 'dev'
+ Requires-Dist: jupytext; extra == 'dev'
+ Requires-Dist: mktestdocs>=0.2.1; extra == 'dev'
+ Requires-Dist: networkx; extra == 'dev'
+ Requires-Dist: pre-commit>=3.2.2; extra == 'dev'
+ Requires-Dist: pytest-beartype; extra == 'dev'
+ Requires-Dist: pytest-cov>=4.0.0; extra == 'dev'
+ Requires-Dist: pytest-pretty>=1.1.1; extra == 'dev'
+ Requires-Dist: pytest-xdist>=3.2.1; extra == 'dev'
+ Requires-Dist: pytest>=7.2.2; extra == 'dev'
+ Requires-Dist: ruff>=0.6; extra == 'dev'
+ Requires-Dist: xdoctest>=1.1.1; extra == 'dev'
+ Provides-Extra: docs
+ Requires-Dist: blackjax>=0.9.6; extra == 'docs'
+ Requires-Dist: ipykernel>=6.22.0; extra == 'docs'
+ Requires-Dist: ipython>=8.11.0; extra == 'docs'
+ Requires-Dist: ipywidgets>=8.0.5; extra == 'docs'
+ Requires-Dist: jupytext>=1.14.5; extra == 'docs'
+ Requires-Dist: markdown-katex>=202406.1035; extra == 'docs'
+ Requires-Dist: matplotlib>=3.7.1; extra == 'docs'
+ Requires-Dist: mkdocs-gen-files>=0.5.0; extra == 'docs'
+ Requires-Dist: mkdocs-git-authors-plugin>=0.7.0; extra == 'docs'
+ Requires-Dist: mkdocs-jupyter>=0.24.3; extra == 'docs'
+ Requires-Dist: mkdocs-literate-nav>=0.6.0; extra == 'docs'
+ Requires-Dist: mkdocs-material>=9.5.12; extra == 'docs'
+ Requires-Dist: mkdocs>=1.5.3; extra == 'docs'
+ Requires-Dist: mkdocstrings[python]<0.28.0; extra == 'docs'
+ Requires-Dist: nbconvert>=7.16.2; extra == 'docs'
+ Requires-Dist: networkx>=3.0; extra == 'docs'
+ Requires-Dist: pandas>=1.5.3; extra == 'docs'
+ Requires-Dist: pymdown-extensions>=10.7.1; extra == 'docs'
+ Requires-Dist: scikit-learn>=1.5.1; extra == 'docs'
+ Requires-Dist: seaborn>=0.12.2; extra == 'docs'
+ Requires-Dist: watermark>=2.3.1; extra == 'docs'
  Description-Content-Type: text/markdown

  <!-- <h1 align='center'>GPJax</h1>
@@ -41,12 +82,12 @@ Description-Content-Type: text/markdown
  [![PyPI version](https://badge.fury.io/py/GPJax.svg)](https://badge.fury.io/py/GPJax)
  [![DOI](https://joss.theoj.org/papers/10.21105/joss.04455/status.svg)](https://doi.org/10.21105/joss.04455)
  [![Downloads](https://pepy.tech/badge/gpjax)](https://pepy.tech/project/gpjax)
- [![Slack Invite](https://img.shields.io/badge/Slack_Invite--blue?style=social&logo=slack)](https://join.slack.com/t/gpjax/shared_invite/zt-1da57pmjn-rdBCVg9kApirEEn2E5Q2Zw)
+ [![Slack Invite](https://img.shields.io/badge/Slack_Invite--blue?style=social&logo=slack)](https://join.slack.com/t/gpjax/shared_invite/zt-3cesiykcx-nzajjRdnV3ohw7~~eMlCYA)

  [**Quickstart**](#simple-example)
  | [**Install guide**](#installation)
  | [**Documentation**](https://docs.jaxgaussianprocesses.com/)
- | [**Slack Community**](https://join.slack.com/t/gpjax/shared_invite/zt-1da57pmjn-rdBCVg9kApirEEn2E5Q2Zw)
+ | [**Slack Community**](https://join.slack.com/t/gpjax/shared_invite/zt-3cesiykcx-nzajjRdnV3ohw7~~eMlCYA)

  GPJax aims to provide a low-level interface to Gaussian process (GP) models in
  [Jax](https://github.com/google/jax), structured to give researchers maximum
@@ -81,7 +122,7 @@ behaviours through [this form](https://jaxgaussianprocesses.com/contact/) or rea
  one of the project's [_gardeners_](https://docs.jaxgaussianprocesses.com/GOVERNANCE/#roles).

  Feel free to join our [Slack
- Channel](https://join.slack.com/t/gpjax/shared_invite/zt-1da57pmjn-rdBCVg9kApirEEn2E5Q2Zw),
+ Channel](https://join.slack.com/t/gpjax/shared_invite/zt-3cesiykcx-nzajjRdnV3ohw7~~eMlCYA),
  where we can discuss the development of GPJax and broader support for Gaussian
  process modelling.

@@ -177,14 +218,13 @@ configuration in development mode.
  ```bash
  git clone https://github.com/JaxGaussianProcesses/GPJax.git
  cd GPJax
- hatch env create
- hatch shell
+ uv sync --extra dev
  ```

  > We recommend you check your installation passes the supplied unit tests:
  >
  > ```python
- > hatch run dev:test
+ > uv run pytest --beartype-packages='gpjax'
  > ```

  # Citing GPJax
gpjax-0.11.2.dist-info/RECORD → gpjax-0.12.0.dist-info/RECORD CHANGED
@@ -1,29 +1,28 @@
- gpjax/__init__.py,sha256=ylCFMtXwcMS2zxm4pI3KsnRdnX6bdh26TSdTfUh9l9o,1686
+ gpjax/__init__.py,sha256=FSrKDFSQ7xDqwQGBWwEPqqjvYxEbhUPPestKLoAPjWA,1686
  gpjax/citation.py,sha256=pwFS8h1J-LE5ieRS0zDyuwhmQHNxkFHYE7iSMlVNmQc,3928
  gpjax/dataset.py,sha256=NsToLKq4lOsHnfLfukrUIRKvhOEuoUk8aHTF0oAqRbU,4079
- gpjax/distributions.py,sha256=8LWmfmRVHOX29Uy8PkKFi2UhcCiunuu-4TMI_5-krHc,9299
+ gpjax/distributions.py,sha256=iKmeQ_NN2CIjRiuOeJlwEGASzGROi4ZCerVi1uY7zRM,7758
  gpjax/fit.py,sha256=R4TIPvBNHYSg9vBVp6is_QYENldRLIU_FklGE85C-aA,15046
- gpjax/gps.py,sha256=97lYGrsmsufQxKEd8qz5wPNvui6FKXTF_Ps-sMFIjnY,31246
+ gpjax/gps.py,sha256=-Log0pcU8qmB5fUxfzoNjD0S64gpiypAjFzjGXX6w7I,30301
  gpjax/integrators.py,sha256=eyJPqWNPKj6pKP5da0fEj4HW7BVyevqeGrurEuy_XPw,5694
  gpjax/likelihoods.py,sha256=99oTZoWld1M7vxgGM0pNY5Hnt2Ajd2lQNqawzrLmwtk,9308
- gpjax/lower_cholesky.py,sha256=3pnHaBrlGckFsrfYJ9Lsbd0pGmO7NIXdyY4aGm48MpY,1952
  gpjax/mean_functions.py,sha256=-sVYO1_LWE8f34rllUOuaT5sgGGAdxo99v5kRo2d4oM,6490
  gpjax/numpyro_extras.py,sha256=-vWJ7SpZVNhSdCjjrlxIkovMFrM1IzpsMJK3B4LioGE,3411
- gpjax/objectives.py,sha256=I_ZqnwTNYIAUAZ9KQNenIl0ish1jDOXb7KaNmjz3Su4,15340
- gpjax/parameters.py,sha256=H-DiXmotdBZCbf-GOjRaJoS_isk3GgFrpKHTq5GpnoA,6998
+ gpjax/objectives.py,sha256=Tm36h8fz_nWkZPlufMQzZWKK1ytrtT9yvvP8YdxYKNw,15359
+ gpjax/parameters.py,sha256=qIEqyMKNd2n2Ak15PisCmqhX5qhsoRgng_s4doL96rE,7044
  gpjax/scan.py,sha256=jStQvwkE9MGttB89frxam1kaeXdWih7cVxkGywyaeHQ,5365
  gpjax/typing.py,sha256=M3CvWsYtZ3PFUvBvvbRNjpwerNII0w4yGuP0I-sLeYI,1705
- gpjax/variational_families.py,sha256=Y9J1H91tXPm_hMy3ri_PgjAxqc_3r-BqKV83HRvB_m4,28295
+ gpjax/variational_families.py,sha256=rE3LarwIAkvDvLlWrz8Ww6BUBz88YHdV4ceY97r3IBw,28637
  gpjax/kernels/__init__.py,sha256=WZanH0Tpdkt0f7VfMqnalm_VZAMVwBqeOVaICNj6xQU,1901
- gpjax/kernels/base.py,sha256=wXsrpm5ofy9S5MNgUkJk4lx2umcIJL6dDNhXY7cmTGk,11616
+ gpjax/kernels/base.py,sha256=hOUXwarspDFnuI2_QreyIVPdz2fzRVJj4p3Zdu1touw,11606
  gpjax/kernels/approximations/__init__.py,sha256=bK9HlGd-PZeGrqtG5RpXxUTXNUrZTgfjH1dP626yNMA,68
  gpjax/kernels/approximations/rff.py,sha256=VbitjNuahFE5_IvCj1A0SxHhJXU0O0Qq0FMMVq8xA3E,4125
  gpjax/kernels/computations/__init__.py,sha256=uTVkqvnZVesFLDN92h0ZR0jfR69Eo2WyjOlmSYmCPJ8,1379
- gpjax/kernels/computations/base.py,sha256=zzabLN_-FkTWo6cBYjzjvUGYa7vrYyHxyrhQZxLgHBk,3651
- gpjax/kernels/computations/basis_functions.py,sha256=zY4rUDZDLOYvQPY9xosRmCLPdiXfbzAN5GICjQhFOis,2528
- gpjax/kernels/computations/constant_diagonal.py,sha256=_4fIFPq76Z416-9dIr8N335WP291dGluO-RqqUsK9ZY,2058
+ gpjax/kernels/computations/base.py,sha256=L6K0roxZbrYeJKxEw-yaTiK9Mtcv0YtZfWI2Xnau7i8,3616
+ gpjax/kernels/computations/basis_functions.py,sha256=MPSo40NEx_ngnSLTa9ntVJzma_jugvm5dMpZd5MtG5M,2490
+ gpjax/kernels/computations/constant_diagonal.py,sha256=JkQhLj7cK48IhOER4ivkALNhD1oQleKe-Rr9BtUJ6es,1984
  gpjax/kernels/computations/dense.py,sha256=vnW6XKQe4_gzpXRWTctxhgMA9-9TebdtiXzAqh_-j6g,1392
- gpjax/kernels/computations/diagonal.py,sha256=z2JpUue7oY-pL-c0Pc6Bngv_IJCR6z4MaW7kN0wgjxk,1803
+ gpjax/kernels/computations/diagonal.py,sha256=k1KqW0DwWRIBvbb7jzcKktXRfhXbcos3ncWrFplJ4W0,1768
  gpjax/kernels/computations/eigen.py,sha256=w7I7LK42j0ouchHCI1ltXx0lpwqvK1bRb4HclnF3rKs,1936
  gpjax/kernels/non_euclidean/__init__.py,sha256=RT7puRPqCTpyxZ16q596EuOQEQi1LK1v3J9_fWz1NlY,790
  gpjax/kernels/non_euclidean/graph.py,sha256=K4WIdX-dx1SsWuNHZnNjHFw8ElKZxGcReUiA3w4aCOI,4204
@@ -43,7 +42,11 @@ gpjax/kernels/stationary/rational_quadratic.py,sha256=dYONp3i4rnKj3ET8UyxAKXv6UO
  gpjax/kernels/stationary/rbf.py,sha256=euHUs6FdfRICQcabAWE4MX-7GEDr2TxgZWdFQiXr9Bw,1690
  gpjax/kernels/stationary/utils.py,sha256=6BI9EBcCzeeKx-XH-MfW1ORmtU__tPX5zyvfLhpkBsU,2180
  gpjax/kernels/stationary/white.py,sha256=TkdXXZCCjDs7JwR_gj5uvn2s1wyfRbe1vyHhUMJ8jjI,2212
- gpjax-0.11.2.dist-info/METADATA,sha256=lTQVlrUbkxI7fU9Gdnac_eoNRyjCHEoEuEvvWbKmDqM,8558
- gpjax-0.11.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- gpjax-0.11.2.dist-info/licenses/LICENSE.txt,sha256=3umwi0h8wmKXOZO8XwRBwSl3vJt2hpWKEqSrSXLR7-I,1084
- gpjax-0.11.2.dist-info/RECORD,,
+ gpjax/linalg/__init__.py,sha256=F8mxk_9Zc2nFd7Q-unjJ50_6rXEKzZj572WsU_jUKqI,547
+ gpjax/linalg/operations.py,sha256=xvhOy5P4FmUCPWjIVNdg1yDXaoFQ48anFUfR-Tnfr6k,6480
+ gpjax/linalg/operators.py,sha256=arxRGwcoAy_RqUYqBpZ3XG6OXbjShUl7m8sTpg85npE,11608
+ gpjax/linalg/utils.py,sha256=DGX40TDhmfYn7JBxElpBm_9W0cetm0HZUK7B3j74xxo,895
+ gpjax-0.12.0.dist-info/METADATA,sha256=8lLQb5SUvWvniry-zBOR3wzm03tXvHe7Lzry_Ho3peE,10562
+ gpjax-0.12.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ gpjax-0.12.0.dist-info/licenses/LICENSE.txt,sha256=3umwi0h8wmKXOZO8XwRBwSl3vJt2hpWKEqSrSXLR7-I,1084
+ gpjax-0.12.0.dist-info/RECORD,,
gpjax/lower_cholesky.py DELETED
@@ -1,69 +0,0 @@
- # Copyright 2023 The GPJax Contributors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
-
- from cola.annotations import PSD
- from cola.fns import dispatch
- from cola.ops.operator_base import LinearOperator
- from cola.ops.operators import (
-     BlockDiag,
-     Diagonal,
-     Identity,
-     Kronecker,
-     Triangular,
- )
- import jax.numpy as jnp
-
- # TODO: Once this functionality is supported in CoLA, remove this.
-
-
- @dispatch
- def lower_cholesky(A: LinearOperator) -> Triangular:  # noqa: F811
-     """Returns the lower Cholesky factor of a linear operator.
-
-     Args:
-         A: The input linear operator.
-
-     Returns:
-         Triangular: The lower Cholesky factor of A.
-     """
-
-     if PSD not in A.annotations:
-         raise ValueError(
-             "Expected LinearOperator to be PSD, did you forget to use cola.PSD?"
-         )
-
-     return Triangular(jnp.linalg.cholesky(A.to_dense()), lower=True)
-
-
- @lower_cholesky.dispatch
- def _(A: Diagonal):  # noqa: F811
-     return Diagonal(jnp.sqrt(A.diag))
-
-
- @lower_cholesky.dispatch
- def _(A: Identity):  # noqa: F811
-     return A
-
-
- @lower_cholesky.dispatch
- def _(A: Kronecker):  # noqa: F811
-     return Kronecker(*[lower_cholesky(Ai) for Ai in A.Ms])
-
-
- @lower_cholesky.dispatch
- def _(A: BlockDiag):  # noqa: F811
-     return BlockDiag(
-         *[lower_cholesky(Ai) for Ai in A.Ms], multiplicities=A.multiplicities
-     )
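The structured dispatch rules deleted here (for `Diagonal`, `Identity`, `Kronecker`, and `BlockDiag` operators) go out with the CoLA dependency; per the import hunks above and the new `gpjax/linalg/` entries in RECORD, `lower_cholesky` now lives in `gpjax.linalg`. A usage sketch against the new module (numeric values invented for illustration):

```python
import jax.numpy as jnp

from gpjax.linalg import Dense, lower_cholesky, psd

K = psd(Dense(jnp.array([[4.0, 1.0], [1.0, 9.0]])))
L = lower_cholesky(K)  # lower-triangular factor, L @ L.T ≈ K
```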