gpjax 0.11.2__py3-none-any.whl → 0.12.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -16,15 +16,6 @@
16
16
  import abc
17
17
 
18
18
  import beartype.typing as tp
19
- from cola.annotations import PSD
20
- from cola.linalg.decompositions.decompositions import Cholesky
21
- from cola.linalg.inverse.inv import solve
22
- from cola.ops.operators import (
23
- Dense,
24
- I_like,
25
- Identity,
26
- Triangular,
27
- )
28
19
  from flax import nnx
29
20
  import jax.numpy as jnp
30
21
  import jax.scipy as jsp
@@ -41,12 +32,19 @@ from gpjax.likelihoods import (
41
32
  Gaussian,
42
33
  NonGaussian,
43
34
  )
44
- from gpjax.lower_cholesky import lower_cholesky
35
+ from gpjax.linalg import (
36
+ Dense,
37
+ Identity,
38
+ Triangular,
39
+ lower_cholesky,
40
+ psd,
41
+ solve,
42
+ )
43
+ from gpjax.linalg.utils import add_jitter
45
44
  from gpjax.mean_functions import AbstractMeanFunction
46
45
  from gpjax.parameters import (
47
46
  LowerTriangular,
48
47
  Real,
49
- Static,
50
48
  )
51
49
  from gpjax.typing import (
52
50
  Array,
@@ -112,11 +110,10 @@ class AbstractVariationalGaussian(AbstractVariationalFamily[L]):
112
110
  inducing_inputs: tp.Union[
113
111
  Float[Array, "N D"],
114
112
  Real,
115
- Static,
116
113
  ],
117
114
  jitter: ScalarFloat = 1e-6,
118
115
  ):
119
- if not isinstance(inducing_inputs, (Real, Static)):
116
+ if not isinstance(inducing_inputs, Real):
120
117
  inducing_inputs = Real(inducing_inputs)
121
118
 
122
119
  self.inducing_inputs = inducing_inputs
@@ -179,25 +176,31 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):
179
176
  approximation and the GP prior.
180
177
  """
181
178
  # Unpack variational parameters
182
- mu = self.variational_mean.value
183
- sqrt = self.variational_root_covariance.value
184
- z = self.inducing_inputs.value
179
+ variational_mean = self.variational_mean.value
180
+ variational_sqrt = self.variational_root_covariance.value
181
+ inducing_inputs = self.inducing_inputs.value
185
182
 
186
183
  # Unpack mean function and kernel
187
184
  mean_function = self.posterior.prior.mean_function
188
185
  kernel = self.posterior.prior.kernel
189
186
 
190
- muz = mean_function(z)
191
- Kzz = kernel.gram(z)
192
- Kzz = PSD(Kzz + I_like(Kzz) * self.jitter)
187
+ inducing_mean = mean_function(inducing_inputs)
188
+ Kzz = kernel.gram(inducing_inputs)
189
+ Kzz = psd(Dense(add_jitter(Kzz.to_dense(), self.jitter)))
193
190
 
194
- sqrt = Triangular(sqrt)
195
- S = sqrt @ sqrt.T
191
+ variational_sqrt_triangular = Triangular(variational_sqrt)
192
+ variational_covariance = (
193
+ variational_sqrt_triangular @ variational_sqrt_triangular.T
194
+ )
196
195
 
197
- qu = GaussianDistribution(loc=jnp.atleast_1d(mu.squeeze()), scale=S)
198
- pu = GaussianDistribution(loc=jnp.atleast_1d(muz.squeeze()), scale=Kzz)
196
+ q_inducing = GaussianDistribution(
197
+ loc=jnp.atleast_1d(variational_mean.squeeze()), scale=variational_covariance
198
+ )
199
+ p_inducing = GaussianDistribution(
200
+ loc=jnp.atleast_1d(inducing_mean.squeeze()), scale=Kzz
201
+ )
199
202
 
200
- return qu.kl_divergence(pu)
203
+ return q_inducing.kl_divergence(p_inducing)
201
204
 
202
205
  def predict(self, test_inputs: Float[Array, "N D"]) -> GaussianDistribution:
203
206
  r"""Compute the predictive distribution of the GP at the test inputs t.
@@ -217,37 +220,38 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):
217
220
  the test inputs.
218
221
  """
219
222
  # Unpack variational parameters
220
- mu = self.variational_mean.value
221
- sqrt = self.variational_root_covariance.value
222
- z = self.inducing_inputs.value
223
+ variational_mean = self.variational_mean.value
224
+ variational_sqrt = self.variational_root_covariance.value
225
+ inducing_inputs = self.inducing_inputs.value
223
226
 
224
227
  # Unpack mean function and kernel
225
228
  mean_function = self.posterior.prior.mean_function
226
229
  kernel = self.posterior.prior.kernel
227
230
 
228
- Kzz = kernel.gram(z)
229
- Kzz += I_like(Kzz) * self.jitter
231
+ Kzz = kernel.gram(inducing_inputs)
232
+ Kzz_dense = add_jitter(Kzz.to_dense(), self.jitter)
233
+ Kzz = psd(Dense(Kzz_dense))
230
234
  Lz = lower_cholesky(Kzz)
231
- muz = mean_function(z)
235
+ inducing_mean = mean_function(inducing_inputs)
232
236
 
233
237
  # Unpack test inputs
234
- t = test_inputs
238
+ test_points = test_inputs
235
239
 
236
- Ktt = kernel.gram(t)
237
- Kzt = kernel.cross_covariance(z, t)
238
- mut = mean_function(t)
240
+ Ktt = kernel.gram(test_points)
241
+ Kzt = kernel.cross_covariance(inducing_inputs, test_points)
242
+ test_mean = mean_function(test_points)
239
243
 
240
244
  # Lz⁻¹ Kzt
241
- Lz_inv_Kzt = solve(Lz, Kzt, Cholesky())
245
+ Lz_inv_Kzt = solve(Lz, Kzt)
242
246
 
243
247
  # Kzz⁻¹ Kzt
244
- Kzz_inv_Kzt = solve(Lz.T, Lz_inv_Kzt, Cholesky())
248
+ Kzz_inv_Kzt = solve(Lz.T, Lz_inv_Kzt)
245
249
 
246
250
  # Ktz Kzz⁻¹ sqrt
247
- Ktz_Kzz_inv_sqrt = jnp.matmul(Kzz_inv_Kzt.T, sqrt)
251
+ Ktz_Kzz_inv_sqrt = jnp.matmul(Kzz_inv_Kzt.T, variational_sqrt)
248
252
 
249
253
  # μt + Ktz Kzz⁻¹ (μ - μz)
250
- mean = mut + jnp.matmul(Kzz_inv_Kzt.T, mu - muz)
254
+ mean = test_mean + jnp.matmul(Kzz_inv_Kzt.T, variational_mean - inducing_mean)
251
255
 
252
256
  # Ktt - Ktz Kzz⁻¹ Kzt + Ktz Kzz⁻¹ S Kzz⁻¹ Kzt [recall S = sqrt sqrtᵀ]
253
257
  covariance = (
@@ -255,7 +259,10 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):
255
259
  - jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt)
256
260
  + jnp.matmul(Ktz_Kzz_inv_sqrt, Ktz_Kzz_inv_sqrt.T)
257
261
  )
258
- covariance += I_like(covariance) * self.jitter
262
+ if hasattr(covariance, "to_dense"):
263
+ covariance = covariance.to_dense()
264
+ covariance = add_jitter(covariance, self.jitter)
265
+ covariance = Dense(covariance)
259
266
 
260
267
  return GaussianDistribution(
261
268
  loc=jnp.atleast_1d(mean.squeeze()), scale=covariance
@@ -330,7 +337,8 @@ class WhitenedVariationalGaussian(VariationalGaussian[L]):
330
337
  kernel = self.posterior.prior.kernel
331
338
 
332
339
  Kzz = kernel.gram(z)
333
- Kzz += I_like(Kzz) * self.jitter
340
+ Kzz_dense = add_jitter(Kzz.to_dense(), self.jitter)
341
+ Kzz = psd(Dense(Kzz_dense))
334
342
  Lz = lower_cholesky(Kzz)
335
343
 
336
344
  # Unpack test inputs
@@ -341,7 +349,7 @@ class WhitenedVariationalGaussian(VariationalGaussian[L]):
341
349
  mut = mean_function(t)
342
350
 
343
351
  # Lz⁻¹ Kzt
344
- Lz_inv_Kzt = solve(Lz, Kzt, Cholesky())
352
+ Lz_inv_Kzt = solve(Lz, Kzt)
345
353
 
346
354
  # Ktz Lz⁻ᵀ sqrt
347
355
  Ktz_Lz_invT_sqrt = jnp.matmul(Lz_inv_Kzt.T, sqrt)
@@ -355,7 +363,10 @@ class WhitenedVariationalGaussian(VariationalGaussian[L]):
355
363
  - jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt)
356
364
  + jnp.matmul(Ktz_Lz_invT_sqrt, Ktz_Lz_invT_sqrt.T)
357
365
  )
358
- covariance += I_like(covariance) * self.jitter
366
+ if hasattr(covariance, "to_dense"):
367
+ covariance = covariance.to_dense()
368
+ covariance = add_jitter(covariance, self.jitter)
369
+ covariance = Dense(covariance)
359
370
 
360
371
  return GaussianDistribution(
361
372
  loc=jnp.atleast_1d(mean.squeeze()), scale=covariance
@@ -390,8 +401,8 @@ class NaturalVariationalGaussian(AbstractVariationalGaussian[L]):
390
401
  if natural_matrix is None:
391
402
  natural_matrix = -0.5 * jnp.eye(self.num_inducing)
392
403
 
393
- self.natural_vector = Static(natural_vector)
394
- self.natural_matrix = Static(natural_matrix)
404
+ self.natural_vector = Real(natural_vector)
405
+ self.natural_matrix = Real(natural_matrix)
395
406
 
396
407
  def prior_kl(self) -> ScalarFloat:
397
408
  r"""Compute the KL-divergence between our current variational approximation
@@ -422,7 +433,7 @@ class NaturalVariationalGaussian(AbstractVariationalGaussian[L]):
422
433
 
423
434
  # S⁻¹ = -2θ₂
424
435
  S_inv = -2 * natural_matrix
425
- S_inv += jnp.eye(m) * self.jitter
436
+ S_inv = add_jitter(S_inv, self.jitter)
426
437
 
427
438
  # Compute L⁻¹, where LLᵀ = S, via a trick found in the NumPyro source code and https://nbviewer.org/gist/fehiepsi/5ef8e09e61604f10607380467eb82006#Precision-to-scale_tril:
428
439
  sqrt_inv = jnp.swapaxes(
@@ -441,7 +452,8 @@ class NaturalVariationalGaussian(AbstractVariationalGaussian[L]):
441
452
 
442
453
  muz = mean_function(z)
443
454
  Kzz = kernel.gram(z)
444
- Kzz += I_like(Kzz) * self.jitter
455
+ Kzz_dense = add_jitter(Kzz.to_dense(), self.jitter)
456
+ Kzz = psd(Dense(Kzz_dense))
445
457
 
446
458
  qu = GaussianDistribution(loc=jnp.atleast_1d(mu.squeeze()), scale=S)
447
459
  pu = GaussianDistribution(loc=jnp.atleast_1d(muz.squeeze()), scale=Kzz)
@@ -475,7 +487,7 @@ class NaturalVariationalGaussian(AbstractVariationalGaussian[L]):
475
487
 
476
488
  # S⁻¹ = -2θ₂
477
489
  S_inv = -2 * natural_matrix
478
- S_inv += jnp.eye(m) * self.jitter
490
+ S_inv = add_jitter(S_inv, self.jitter)
479
491
 
480
492
  # Compute L⁻¹, where LLᵀ = S, via a trick found in the NumPyro source code and https://nbviewer.org/gist/fehiepsi/5ef8e09e61604f10607380467eb82006#Precision-to-scale_tril:
481
493
  sqrt_inv = jnp.swapaxes(
@@ -492,7 +504,8 @@ class NaturalVariationalGaussian(AbstractVariationalGaussian[L]):
492
504
  mu = jnp.matmul(S, natural_vector)
493
505
 
494
506
  Kzz = kernel.gram(z)
495
- Kzz += I_like(Kzz) * self.jitter
507
+ Kzz_dense = add_jitter(Kzz.to_dense(), self.jitter)
508
+ Kzz = psd(Dense(Kzz_dense))
496
509
  Lz = lower_cholesky(Kzz)
497
510
  muz = mean_function(z)
498
511
 
@@ -501,10 +514,10 @@ class NaturalVariationalGaussian(AbstractVariationalGaussian[L]):
501
514
  mut = mean_function(test_inputs)
502
515
 
503
516
  # Lz⁻¹ Kzt
504
- Lz_inv_Kzt = solve(Lz, Kzt, Cholesky())
517
+ Lz_inv_Kzt = solve(Lz, Kzt)
505
518
 
506
519
  # Kzz⁻¹ Kzt
507
- Kzz_inv_Kzt = solve(Lz.T, Lz_inv_Kzt, Cholesky())
520
+ Kzz_inv_Kzt = solve(Lz.T, Lz_inv_Kzt)
508
521
 
509
522
  # Ktz Kzz⁻¹ L
510
523
  Ktz_Kzz_inv_L = jnp.matmul(Kzz_inv_Kzt.T, sqrt)
@@ -518,7 +531,10 @@ class NaturalVariationalGaussian(AbstractVariationalGaussian[L]):
518
531
  - jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt)
519
532
  + jnp.matmul(Ktz_Kzz_inv_L, Ktz_Kzz_inv_L.T)
520
533
  )
521
- covariance += I_like(covariance) * self.jitter
534
+ if hasattr(covariance, "to_dense"):
535
+ covariance = covariance.to_dense()
536
+ covariance = add_jitter(covariance, self.jitter)
537
+ covariance = Dense(covariance)
522
538
 
523
539
  return GaussianDistribution(
524
540
  loc=jnp.atleast_1d(mean.squeeze()), scale=covariance
@@ -554,8 +570,8 @@ class ExpectationVariationalGaussian(AbstractVariationalGaussian[L]):
554
570
  if expectation_matrix is None:
555
571
  expectation_matrix = jnp.eye(self.num_inducing)
556
572
 
557
- self.expectation_vector = Static(expectation_vector)
558
- self.expectation_matrix = Static(expectation_matrix)
573
+ self.expectation_vector = Real(expectation_vector)
574
+ self.expectation_matrix = Real(expectation_matrix)
559
575
 
560
576
  def prior_kl(self) -> ScalarFloat:
561
577
  r"""Evaluate the prior KL-divergence.
@@ -592,12 +608,14 @@ class ExpectationVariationalGaussian(AbstractVariationalGaussian[L]):
592
608
 
593
609
  # S = η₂ - η₁ η₁ᵀ
594
610
  S = expectation_matrix - jnp.outer(mu, mu)
595
- S = PSD(Dense(S))
596
- S += I_like(S) * self.jitter
611
+ S = psd(Dense(S))
612
+ S_dense = add_jitter(S.to_dense(), self.jitter)
613
+ S = psd(Dense(S_dense))
597
614
 
598
615
  muz = mean_function(z)
599
616
  Kzz = kernel.gram(z)
600
- Kzz += I_like(Kzz) * self.jitter
617
+ Kzz_dense = add_jitter(Kzz.to_dense(), self.jitter)
618
+ Kzz = psd(Dense(Kzz_dense))
601
619
 
602
620
  qu = GaussianDistribution(loc=jnp.atleast_1d(mu.squeeze()), scale=S)
603
621
  pu = GaussianDistribution(loc=jnp.atleast_1d(muz.squeeze()), scale=Kzz)
@@ -636,14 +654,15 @@ class ExpectationVariationalGaussian(AbstractVariationalGaussian[L]):
636
654
 
637
655
  # S = η₂ - η₁ η₁ᵀ
638
656
  S = expectation_matrix - jnp.matmul(mu, mu.T)
639
- S = Dense(S) + I_like(S) * self.jitter
640
- S = PSD(S)
657
+ S = Dense(add_jitter(S, self.jitter))
658
+ S = psd(S)
641
659
 
642
660
  # S = sqrt sqrtᵀ
643
661
  sqrt = lower_cholesky(S)
644
662
 
645
663
  Kzz = kernel.gram(z)
646
- Kzz += I_like(Kzz) * self.jitter
664
+ Kzz_dense = add_jitter(Kzz.to_dense(), self.jitter)
665
+ Kzz = psd(Dense(Kzz_dense))
647
666
  Lz = lower_cholesky(Kzz)
648
667
  muz = mean_function(z)
649
668
 
@@ -655,10 +674,10 @@ class ExpectationVariationalGaussian(AbstractVariationalGaussian[L]):
655
674
  mut = mean_function(t)
656
675
 
657
676
  # Lz⁻¹ Kzt
658
- Lz_inv_Kzt = solve(Lz, Kzt, Cholesky())
677
+ Lz_inv_Kzt = solve(Lz, Kzt)
659
678
 
660
679
  # Kzz⁻¹ Kzt
661
- Kzz_inv_Kzt = solve(Lz.T, Lz_inv_Kzt, Cholesky())
680
+ Kzz_inv_Kzt = solve(Lz.T, Lz_inv_Kzt)
662
681
 
663
682
  # Ktz Kzz⁻¹ sqrt
664
683
  Ktz_Kzz_inv_sqrt = Kzz_inv_Kzt.T @ sqrt
@@ -672,7 +691,10 @@ class ExpectationVariationalGaussian(AbstractVariationalGaussian[L]):
672
691
  - jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt)
673
692
  + jnp.matmul(Ktz_Kzz_inv_sqrt, Ktz_Kzz_inv_sqrt.T)
674
693
  )
675
- covariance += I_like(covariance) * self.jitter
694
+ if hasattr(covariance, "to_dense"):
695
+ covariance = covariance.to_dense()
696
+ covariance = add_jitter(covariance, self.jitter)
697
+ covariance = Dense(covariance)
676
698
 
677
699
  return GaussianDistribution(
678
700
  loc=jnp.atleast_1d(mean.squeeze()), scale=covariance
@@ -729,13 +751,14 @@ class CollapsedVariationalGaussian(AbstractVariationalGaussian[GL]):
729
751
 
730
752
  Kzx = kernel.cross_covariance(z, x)
731
753
  Kzz = kernel.gram(z)
732
- Kzz += I_like(Kzz) * self.jitter
754
+ Kzz_dense = add_jitter(Kzz.to_dense(), self.jitter)
755
+ Kzz = psd(Dense(Kzz_dense))
733
756
 
734
757
  # Lz Lzᵀ = Kzz
735
758
  Lz = lower_cholesky(Kzz)
736
759
 
737
760
  # Lz⁻¹ Kzx
738
- Lz_inv_Kzx = solve(Lz, Kzx, Cholesky())
761
+ Lz_inv_Kzx = solve(Lz, Kzx)
739
762
 
740
763
  # A = Lz⁻¹ Kzt / o
741
764
  A = Lz_inv_Kzx / self.posterior.likelihood.obs_stddev.value
@@ -753,14 +776,14 @@ class CollapsedVariationalGaussian(AbstractVariationalGaussian[GL]):
753
776
  Lz_inv_Kzx_diff = jsp.linalg.cho_solve((L, True), jnp.matmul(Lz_inv_Kzx, diff))
754
777
 
755
778
  # Kzz⁻¹ Kzx (y - μx)
756
- Kzz_inv_Kzx_diff = solve(Lz.T, Lz_inv_Kzx_diff, Cholesky())
779
+ Kzz_inv_Kzx_diff = solve(Lz.T, Lz_inv_Kzx_diff)
757
780
 
758
781
  Ktt = kernel.gram(t)
759
782
  Kzt = kernel.cross_covariance(z, t)
760
783
  mut = mean_function(t)
761
784
 
762
785
  # Lz⁻¹ Kzt
763
- Lz_inv_Kzt = solve(Lz, Kzt, Cholesky())
786
+ Lz_inv_Kzt = solve(Lz, Kzt)
764
787
 
765
788
  # L⁻¹ Lz⁻¹ Kzt
766
789
  L_inv_Lz_inv_Kzt = jsp.linalg.solve_triangular(L, Lz_inv_Kzt, lower=True)
@@ -774,7 +797,10 @@ class CollapsedVariationalGaussian(AbstractVariationalGaussian[GL]):
774
797
  - jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt)
775
798
  + jnp.matmul(L_inv_Lz_inv_Kzt.T, L_inv_Lz_inv_Kzt)
776
799
  )
777
- covariance += I_like(covariance) * self.jitter
800
+ if hasattr(covariance, "to_dense"):
801
+ covariance = covariance.to_dense()
802
+ covariance = add_jitter(covariance, self.jitter)
803
+ covariance = Dense(covariance)
778
804
 
779
805
  return GaussianDistribution(
780
806
  loc=jnp.atleast_1d(mean.squeeze()), scale=covariance
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: gpjax
3
- Version: 0.11.2
3
+ Version: 0.12.2
4
4
  Summary: Gaussian processes in JAX.
5
5
  Project-URL: Documentation, https://docs.jaxgaussianprocesses.com/
6
6
  Project-URL: Issues, https://github.com/JaxGaussianProcesses/GPJax/issues
@@ -14,11 +14,11 @@ Classifier: Programming Language :: Python
14
14
  Classifier: Programming Language :: Python :: 3.10
15
15
  Classifier: Programming Language :: Python :: 3.11
16
16
  Classifier: Programming Language :: Python :: 3.12
17
+ Classifier: Programming Language :: Python :: 3.13
17
18
  Classifier: Programming Language :: Python :: Implementation :: CPython
18
19
  Classifier: Programming Language :: Python :: Implementation :: PyPy
19
- Requires-Python: <3.13,>=3.10
20
+ Requires-Python: >=3.10
20
21
  Requires-Dist: beartype>0.16.1
21
- Requires-Dist: cola-ml>=0.0.7
22
22
  Requires-Dist: flax>=0.10.0
23
23
  Requires-Dist: jax>=0.5.0
24
24
  Requires-Dist: jaxlib>=0.5.0
@@ -27,6 +27,47 @@ Requires-Dist: numpy>=2.0.0
27
27
  Requires-Dist: numpyro
28
28
  Requires-Dist: optax>0.2.1
29
29
  Requires-Dist: tqdm>4.66.2
30
+ Provides-Extra: dev
31
+ Requires-Dist: absolufy-imports>=0.3.1; extra == 'dev'
32
+ Requires-Dist: autoflake; extra == 'dev'
33
+ Requires-Dist: black; extra == 'dev'
34
+ Requires-Dist: codespell>=2.2.4; extra == 'dev'
35
+ Requires-Dist: coverage>=7.2.2; extra == 'dev'
36
+ Requires-Dist: interrogate>=1.5.0; extra == 'dev'
37
+ Requires-Dist: isort; extra == 'dev'
38
+ Requires-Dist: jupytext; extra == 'dev'
39
+ Requires-Dist: mktestdocs>=0.2.1; extra == 'dev'
40
+ Requires-Dist: networkx; extra == 'dev'
41
+ Requires-Dist: pre-commit>=3.2.2; extra == 'dev'
42
+ Requires-Dist: pytest-beartype; extra == 'dev'
43
+ Requires-Dist: pytest-cov>=4.0.0; extra == 'dev'
44
+ Requires-Dist: pytest-pretty>=1.1.1; extra == 'dev'
45
+ Requires-Dist: pytest-xdist>=3.2.1; extra == 'dev'
46
+ Requires-Dist: pytest>=7.2.2; extra == 'dev'
47
+ Requires-Dist: ruff>=0.6; extra == 'dev'
48
+ Requires-Dist: xdoctest>=1.1.1; extra == 'dev'
49
+ Provides-Extra: docs
50
+ Requires-Dist: blackjax>=0.9.6; extra == 'docs'
51
+ Requires-Dist: ipykernel>=6.22.0; extra == 'docs'
52
+ Requires-Dist: ipython>=8.11.0; extra == 'docs'
53
+ Requires-Dist: ipywidgets>=8.0.5; extra == 'docs'
54
+ Requires-Dist: jupytext>=1.14.5; extra == 'docs'
55
+ Requires-Dist: markdown-katex>=202406.1035; extra == 'docs'
56
+ Requires-Dist: matplotlib>=3.7.1; extra == 'docs'
57
+ Requires-Dist: mkdocs-gen-files>=0.5.0; extra == 'docs'
58
+ Requires-Dist: mkdocs-git-authors-plugin>=0.7.0; extra == 'docs'
59
+ Requires-Dist: mkdocs-jupyter>=0.24.3; extra == 'docs'
60
+ Requires-Dist: mkdocs-literate-nav>=0.6.0; extra == 'docs'
61
+ Requires-Dist: mkdocs-material>=9.5.12; extra == 'docs'
62
+ Requires-Dist: mkdocs>=1.5.3; extra == 'docs'
63
+ Requires-Dist: mkdocstrings[python]<0.31.0; extra == 'docs'
64
+ Requires-Dist: nbconvert>=7.16.2; extra == 'docs'
65
+ Requires-Dist: networkx>=3.0; extra == 'docs'
66
+ Requires-Dist: pandas>=1.5.3; extra == 'docs'
67
+ Requires-Dist: pymdown-extensions>=10.7.1; extra == 'docs'
68
+ Requires-Dist: scikit-learn>=1.5.1; extra == 'docs'
69
+ Requires-Dist: seaborn>=0.12.2; extra == 'docs'
70
+ Requires-Dist: watermark>=2.3.1; extra == 'docs'
30
71
  Description-Content-Type: text/markdown
31
72
 
32
73
  <!-- <h1 align='center'>GPJax</h1>
@@ -41,12 +82,12 @@ Description-Content-Type: text/markdown
41
82
  [![PyPI version](https://badge.fury.io/py/GPJax.svg)](https://badge.fury.io/py/GPJax)
42
83
  [![DOI](https://joss.theoj.org/papers/10.21105/joss.04455/status.svg)](https://doi.org/10.21105/joss.04455)
43
84
  [![Downloads](https://pepy.tech/badge/gpjax)](https://pepy.tech/project/gpjax)
44
- [![Slack Invite](https://img.shields.io/badge/Slack_Invite--blue?style=social&logo=slack)](https://join.slack.com/t/gpjax/shared_invite/zt-1da57pmjn-rdBCVg9kApirEEn2E5Q2Zw)
85
+ [![Slack Invite](https://img.shields.io/badge/Slack_Invite--blue?style=social&logo=slack)](https://join.slack.com/t/gpjax/shared_invite/zt-3cesiykcx-nzajjRdnV3ohw7~~eMlCYA)
45
86
 
46
87
  [**Quickstart**](#simple-example)
47
88
  | [**Install guide**](#installation)
48
89
  | [**Documentation**](https://docs.jaxgaussianprocesses.com/)
49
- | [**Slack Community**](https://join.slack.com/t/gpjax/shared_invite/zt-1da57pmjn-rdBCVg9kApirEEn2E5Q2Zw)
90
+ | [**Slack Community**](https://join.slack.com/t/gpjax/shared_invite/zt-3cesiykcx-nzajjRdnV3ohw7~~eMlCYA)
50
91
 
51
92
  GPJax aims to provide a low-level interface to Gaussian process (GP) models in
52
93
  [Jax](https://github.com/google/jax), structured to give researchers maximum
@@ -81,22 +122,13 @@ behaviours through [this form](https://jaxgaussianprocesses.com/contact/) or rea
81
122
  one of the project's [_gardeners_](https://docs.jaxgaussianprocesses.com/GOVERNANCE/#roles).
82
123
 
83
124
  Feel free to join our [Slack
84
- Channel](https://join.slack.com/t/gpjax/shared_invite/zt-1da57pmjn-rdBCVg9kApirEEn2E5Q2Zw),
125
+ Channel](https://join.slack.com/t/gpjax/shared_invite/zt-3cesiykcx-nzajjRdnV3ohw7~~eMlCYA),
85
126
  where we can discuss the development of GPJax and broader support for Gaussian
86
127
  process modelling.
87
128
 
88
-
89
- ## Governance
90
-
91
- GPJax was founded by [Thomas Pinder](https://github.com/thomaspinder). Today, the
92
- project's gardeners are [daniel-dodd@](https://github.com/daniel-dodd),
93
- [henrymoss@](https://github.com/henrymoss), [st--@](https://github.com/st--), and
94
- [thomaspinder@](https://github.com/thomaspinder), listed in alphabetical order. The full
95
- governance structure of GPJax is detailed [here](docs/GOVERNANCE.md). We appreciate all
96
- [the contributors to
97
- GPJax](https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors) who have
98
- helped to shape GPJax into the package it is today.
99
-
129
+ We appreciate all [the contributors to
130
+ GPJax](https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors) who have helped to shape
131
+ GPJax into the package it is today.
100
132
 
101
133
  # Supported methods and interfaces
102
134
 
@@ -0,0 +1,52 @@
1
+ gpjax/__init__.py,sha256=RzwpixFXn6HNHLVLy4LVXhFUk2c-_ce6n1gjZ2B93F0,1641
2
+ gpjax/citation.py,sha256=pwFS8h1J-LE5ieRS0zDyuwhmQHNxkFHYE7iSMlVNmQc,3928
3
+ gpjax/dataset.py,sha256=NsToLKq4lOsHnfLfukrUIRKvhOEuoUk8aHTF0oAqRbU,4079
4
+ gpjax/distributions.py,sha256=iKmeQ_NN2CIjRiuOeJlwEGASzGROi4ZCerVi1uY7zRM,7758
5
+ gpjax/fit.py,sha256=I2sJVuKZii_d7MEcelHIivfM8ExYGMgdBuKKOT7Dw-A,15326
6
+ gpjax/gps.py,sha256=ipaeYMnPffhKK_JsEHe4fF8GmolQIjXB1YbyfUIL8H4,30118
7
+ gpjax/integrators.py,sha256=eyJPqWNPKj6pKP5da0fEj4HW7BVyevqeGrurEuy_XPw,5694
8
+ gpjax/likelihoods.py,sha256=xwnSQpn6Aa-FPpEoDn_3xpBdPQAmHP97jP-9iJmT4G8,9087
9
+ gpjax/mean_functions.py,sha256=KiHQXI-b7o0Vi5KQxGm6RNsUjitJc9jEOCq2GrSx4II,6531
10
+ gpjax/numpyro_extras.py,sha256=-vWJ7SpZVNhSdCjjrlxIkovMFrM1IzpsMJK3B4LioGE,3411
11
+ gpjax/objectives.py,sha256=GvKbDIPqYjsc9FpiTccmZwRdHr2lCykgfxI9BX9I_GA,15362
12
+ gpjax/parameters.py,sha256=hnyIKr6uIzd7Kb3KZC9WowR88ruQwUvdcto3cx2ZDv4,6756
13
+ gpjax/scan.py,sha256=jStQvwkE9MGttB89frxam1kaeXdWih7cVxkGywyaeHQ,5365
14
+ gpjax/typing.py,sha256=M3CvWsYtZ3PFUvBvvbRNjpwerNII0w4yGuP0I-sLeYI,1705
15
+ gpjax/variational_families.py,sha256=TJGGkwkE805X4PQb-C32FxvD9B_OsFLWf6I-ZZvOUWk,29628
16
+ gpjax/kernels/__init__.py,sha256=WZanH0Tpdkt0f7VfMqnalm_VZAMVwBqeOVaICNj6xQU,1901
17
+ gpjax/kernels/base.py,sha256=4Lx8y3kPX4WqQZGRGAsBkqn_i6FlfoAhSn9Tv415xuQ,11551
18
+ gpjax/kernels/approximations/__init__.py,sha256=bK9HlGd-PZeGrqtG5RpXxUTXNUrZTgfjH1dP626yNMA,68
19
+ gpjax/kernels/approximations/rff.py,sha256=GbNUmDPEKEKuMwxUcocxl_9IFR3Q9KEPZXzjy_ZD-2w,4043
20
+ gpjax/kernels/computations/__init__.py,sha256=uTVkqvnZVesFLDN92h0ZR0jfR69Eo2WyjOlmSYmCPJ8,1379
21
+ gpjax/kernels/computations/base.py,sha256=L6K0roxZbrYeJKxEw-yaTiK9Mtcv0YtZfWI2Xnau7i8,3616
22
+ gpjax/kernels/computations/basis_functions.py,sha256=_SFv4Tiwne40bxr1uVYpEjjZgjIQHKseLmss2Zgl1L4,2484
23
+ gpjax/kernels/computations/constant_diagonal.py,sha256=JkQhLj7cK48IhOER4ivkALNhD1oQleKe-Rr9BtUJ6es,1984
24
+ gpjax/kernels/computations/dense.py,sha256=vnW6XKQe4_gzpXRWTctxhgMA9-9TebdtiXzAqh_-j6g,1392
25
+ gpjax/kernels/computations/diagonal.py,sha256=k1KqW0DwWRIBvbb7jzcKktXRfhXbcos3ncWrFplJ4W0,1768
26
+ gpjax/kernels/computations/eigen.py,sha256=NTHm-cn-RepYuXFrvXo2ih7Gtu1YR_pAg4Jb7IhE_o8,1930
27
+ gpjax/kernels/non_euclidean/__init__.py,sha256=RT7puRPqCTpyxZ16q596EuOQEQi1LK1v3J9_fWz1NlY,790
28
+ gpjax/kernels/non_euclidean/graph.py,sha256=xTrx6ro8ubRXgM7Wgg6NmOyyEjEcGhzydY7KXueknCc,4120
29
+ gpjax/kernels/non_euclidean/utils.py,sha256=z42aw8ga0zuREzHawemR9okttgrAUPmq-aN5HMt4SuY,1578
30
+ gpjax/kernels/nonstationary/__init__.py,sha256=YpWQfOy_cqOKc5ezn37vqoK3Z6jznYiJz28BD_8F7AY,930
31
+ gpjax/kernels/nonstationary/arccosine.py,sha256=cqb8sqaNwW3fEbrA7MY9OF2KJFTkxHhqwmQtABE3G8w,5408
32
+ gpjax/kernels/nonstationary/linear.py,sha256=UIMoCq2hg6dQKr4J5UGiiPqotBleQuYfy00Ia1NaMOo,2571
33
+ gpjax/kernels/nonstationary/polynomial.py,sha256=CKc02C7Utgo-hhcOOCcKLdln5lj4vud_8M-JY7SevJ8,3388
34
+ gpjax/kernels/stationary/__init__.py,sha256=j4BMTaQlIx2kNAT1Dkf4iO2rm-f7_oSVWNrk1bN0tqE,1406
35
+ gpjax/kernels/stationary/base.py,sha256=25qDqpZP4gNtzbyzDCW-6u7rJfMqkg0dW88XUmTTupU,7078
36
+ gpjax/kernels/stationary/matern12.py,sha256=DGjqw6VveYsyy0TrufyJJvCei7p9slnm2f0TgRGG7_U,1773
37
+ gpjax/kernels/stationary/matern32.py,sha256=laLsJWJozJzpYHBzlkPUq0rWxz1eWEwGC36P2nPJuaQ,1966
38
+ gpjax/kernels/stationary/matern52.py,sha256=VSByD2sb7k-DzRFjaz31P3Rtc4bPPhHvMshrxZNFnns,2019
39
+ gpjax/kernels/stationary/periodic.py,sha256=f4PhWhKg-pJsEBGzEMK9pdbylO84GPKhzHlBC83ZVWw,3528
40
+ gpjax/kernels/stationary/powered_exponential.py,sha256=xuFGuIK0mKNMU3iLtZMXZTHXJuMFAMoX7gAtXefCdqU,3679
41
+ gpjax/kernels/stationary/rational_quadratic.py,sha256=zHo2LVW65T52XET4Hx9JaKO0TfxylV8WRUtP7sUUOx0,3418
42
+ gpjax/kernels/stationary/rbf.py,sha256=euHUs6FdfRICQcabAWE4MX-7GEDr2TxgZWdFQiXr9Bw,1690
43
+ gpjax/kernels/stationary/utils.py,sha256=6BI9EBcCzeeKx-XH-MfW1ORmtU__tPX5zyvfLhpkBsU,2180
44
+ gpjax/kernels/stationary/white.py,sha256=TkdXXZCCjDs7JwR_gj5uvn2s1wyfRbe1vyHhUMJ8jjI,2212
45
+ gpjax/linalg/__init__.py,sha256=F8mxk_9Zc2nFd7Q-unjJ50_6rXEKzZj572WsU_jUKqI,547
46
+ gpjax/linalg/operations.py,sha256=xvhOy5P4FmUCPWjIVNdg1yDXaoFQ48anFUfR-Tnfr6k,6480
47
+ gpjax/linalg/operators.py,sha256=arxRGwcoAy_RqUYqBpZ3XG6OXbjShUl7m8sTpg85npE,11608
48
+ gpjax/linalg/utils.py,sha256=fKV8G_iKZVhNkNvN20D_dQEi93-8xosGbXBP-v7UEyo,2020
49
+ gpjax-0.12.2.dist-info/METADATA,sha256=eckQKXiBXi8XbBeJFviBAIPdBGVWGFQg7wVZwMfPPxs,10129
50
+ gpjax-0.12.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
51
+ gpjax-0.12.2.dist-info/licenses/LICENSE.txt,sha256=3umwi0h8wmKXOZO8XwRBwSl3vJt2hpWKEqSrSXLR7-I,1084
52
+ gpjax-0.12.2.dist-info/RECORD,,
gpjax/lower_cholesky.py DELETED
@@ -1,69 +0,0 @@
1
- # Copyright 2023 The GPJax Contributors. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- from cola.annotations import PSD
17
- from cola.fns import dispatch
18
- from cola.ops.operator_base import LinearOperator
19
- from cola.ops.operators import (
20
- BlockDiag,
21
- Diagonal,
22
- Identity,
23
- Kronecker,
24
- Triangular,
25
- )
26
- import jax.numpy as jnp
27
-
28
- # TODO: Once this functionality is supported in CoLA, remove this.
29
-
30
-
31
- @dispatch
32
- def lower_cholesky(A: LinearOperator) -> Triangular: # noqa: F811
33
- """Returns the lower Cholesky factor of a linear operator.
34
-
35
- Args:
36
- A: The input linear operator.
37
-
38
- Returns:
39
- Triangular: The lower Cholesky factor of A.
40
- """
41
-
42
- if PSD not in A.annotations:
43
- raise ValueError(
44
- "Expected LinearOperator to be PSD, did you forget to use cola.PSD?"
45
- )
46
-
47
- return Triangular(jnp.linalg.cholesky(A.to_dense()), lower=True)
48
-
49
-
50
- @lower_cholesky.dispatch
51
- def _(A: Diagonal): # noqa: F811
52
- return Diagonal(jnp.sqrt(A.diag))
53
-
54
-
55
- @lower_cholesky.dispatch
56
- def _(A: Identity): # noqa: F811
57
- return A
58
-
59
-
60
- @lower_cholesky.dispatch
61
- def _(A: Kronecker): # noqa: F811
62
- return Kronecker(*[lower_cholesky(Ai) for Ai in A.Ms])
63
-
64
-
65
- @lower_cholesky.dispatch
66
- def _(A: BlockDiag): # noqa: F811
67
- return BlockDiag(
68
- *[lower_cholesky(Ai) for Ai in A.Ms], multiplicities=A.multiplicities
69
- )
@@ -1,49 +0,0 @@
1
- gpjax/__init__.py,sha256=ylCFMtXwcMS2zxm4pI3KsnRdnX6bdh26TSdTfUh9l9o,1686
2
- gpjax/citation.py,sha256=pwFS8h1J-LE5ieRS0zDyuwhmQHNxkFHYE7iSMlVNmQc,3928
3
- gpjax/dataset.py,sha256=NsToLKq4lOsHnfLfukrUIRKvhOEuoUk8aHTF0oAqRbU,4079
4
- gpjax/distributions.py,sha256=8LWmfmRVHOX29Uy8PkKFi2UhcCiunuu-4TMI_5-krHc,9299
5
- gpjax/fit.py,sha256=R4TIPvBNHYSg9vBVp6is_QYENldRLIU_FklGE85C-aA,15046
6
- gpjax/gps.py,sha256=97lYGrsmsufQxKEd8qz5wPNvui6FKXTF_Ps-sMFIjnY,31246
7
- gpjax/integrators.py,sha256=eyJPqWNPKj6pKP5da0fEj4HW7BVyevqeGrurEuy_XPw,5694
8
- gpjax/likelihoods.py,sha256=99oTZoWld1M7vxgGM0pNY5Hnt2Ajd2lQNqawzrLmwtk,9308
9
- gpjax/lower_cholesky.py,sha256=3pnHaBrlGckFsrfYJ9Lsbd0pGmO7NIXdyY4aGm48MpY,1952
10
- gpjax/mean_functions.py,sha256=-sVYO1_LWE8f34rllUOuaT5sgGGAdxo99v5kRo2d4oM,6490
11
- gpjax/numpyro_extras.py,sha256=-vWJ7SpZVNhSdCjjrlxIkovMFrM1IzpsMJK3B4LioGE,3411
12
- gpjax/objectives.py,sha256=I_ZqnwTNYIAUAZ9KQNenIl0ish1jDOXb7KaNmjz3Su4,15340
13
- gpjax/parameters.py,sha256=H-DiXmotdBZCbf-GOjRaJoS_isk3GgFrpKHTq5GpnoA,6998
14
- gpjax/scan.py,sha256=jStQvwkE9MGttB89frxam1kaeXdWih7cVxkGywyaeHQ,5365
15
- gpjax/typing.py,sha256=M3CvWsYtZ3PFUvBvvbRNjpwerNII0w4yGuP0I-sLeYI,1705
16
- gpjax/variational_families.py,sha256=Y9J1H91tXPm_hMy3ri_PgjAxqc_3r-BqKV83HRvB_m4,28295
17
- gpjax/kernels/__init__.py,sha256=WZanH0Tpdkt0f7VfMqnalm_VZAMVwBqeOVaICNj6xQU,1901
18
- gpjax/kernels/base.py,sha256=wXsrpm5ofy9S5MNgUkJk4lx2umcIJL6dDNhXY7cmTGk,11616
19
- gpjax/kernels/approximations/__init__.py,sha256=bK9HlGd-PZeGrqtG5RpXxUTXNUrZTgfjH1dP626yNMA,68
20
- gpjax/kernels/approximations/rff.py,sha256=VbitjNuahFE5_IvCj1A0SxHhJXU0O0Qq0FMMVq8xA3E,4125
21
- gpjax/kernels/computations/__init__.py,sha256=uTVkqvnZVesFLDN92h0ZR0jfR69Eo2WyjOlmSYmCPJ8,1379
22
- gpjax/kernels/computations/base.py,sha256=zzabLN_-FkTWo6cBYjzjvUGYa7vrYyHxyrhQZxLgHBk,3651
23
- gpjax/kernels/computations/basis_functions.py,sha256=zY4rUDZDLOYvQPY9xosRmCLPdiXfbzAN5GICjQhFOis,2528
24
- gpjax/kernels/computations/constant_diagonal.py,sha256=_4fIFPq76Z416-9dIr8N335WP291dGluO-RqqUsK9ZY,2058
25
- gpjax/kernels/computations/dense.py,sha256=vnW6XKQe4_gzpXRWTctxhgMA9-9TebdtiXzAqh_-j6g,1392
26
- gpjax/kernels/computations/diagonal.py,sha256=z2JpUue7oY-pL-c0Pc6Bngv_IJCR6z4MaW7kN0wgjxk,1803
27
- gpjax/kernels/computations/eigen.py,sha256=w7I7LK42j0ouchHCI1ltXx0lpwqvK1bRb4HclnF3rKs,1936
28
- gpjax/kernels/non_euclidean/__init__.py,sha256=RT7puRPqCTpyxZ16q596EuOQEQi1LK1v3J9_fWz1NlY,790
29
- gpjax/kernels/non_euclidean/graph.py,sha256=K4WIdX-dx1SsWuNHZnNjHFw8ElKZxGcReUiA3w4aCOI,4204
30
- gpjax/kernels/non_euclidean/utils.py,sha256=z42aw8ga0zuREzHawemR9okttgrAUPmq-aN5HMt4SuY,1578
31
- gpjax/kernels/nonstationary/__init__.py,sha256=YpWQfOy_cqOKc5ezn37vqoK3Z6jznYiJz28BD_8F7AY,930
32
- gpjax/kernels/nonstationary/arccosine.py,sha256=2WV6aM0Z3-xXZnoPw-77n2CW62n-AZuJy-7AQ9xrMco,5858
33
- gpjax/kernels/nonstationary/linear.py,sha256=UIMoCq2hg6dQKr4J5UGiiPqotBleQuYfy00Ia1NaMOo,2571
34
- gpjax/kernels/nonstationary/polynomial.py,sha256=arP8DK0jnBOaayDWcFvHF0pdu9FVhwzXdqjnHUAL2VI,3293
35
- gpjax/kernels/stationary/__init__.py,sha256=j4BMTaQlIx2kNAT1Dkf4iO2rm-f7_oSVWNrk1bN0tqE,1406
36
- gpjax/kernels/stationary/base.py,sha256=25qDqpZP4gNtzbyzDCW-6u7rJfMqkg0dW88XUmTTupU,7078
37
- gpjax/kernels/stationary/matern12.py,sha256=DGjqw6VveYsyy0TrufyJJvCei7p9slnm2f0TgRGG7_U,1773
38
- gpjax/kernels/stationary/matern32.py,sha256=laLsJWJozJzpYHBzlkPUq0rWxz1eWEwGC36P2nPJuaQ,1966
39
- gpjax/kernels/stationary/matern52.py,sha256=VSByD2sb7k-DzRFjaz31P3Rtc4bPPhHvMshrxZNFnns,2019
40
- gpjax/kernels/stationary/periodic.py,sha256=IAbCxURtJEHGdmYzbdrsqRZ3zJ8F8tGQF9O7sggafZk,3598
41
- gpjax/kernels/stationary/powered_exponential.py,sha256=8qT91IWKJK7PpEtFcX4MVu1ahWMOFOZierPko4JCjKA,3776
42
- gpjax/kernels/stationary/rational_quadratic.py,sha256=dYONp3i4rnKj3ET8UyxAKXv6UOl8uOFT3lCutleSvo4,3496
43
- gpjax/kernels/stationary/rbf.py,sha256=euHUs6FdfRICQcabAWE4MX-7GEDr2TxgZWdFQiXr9Bw,1690
44
- gpjax/kernels/stationary/utils.py,sha256=6BI9EBcCzeeKx-XH-MfW1ORmtU__tPX5zyvfLhpkBsU,2180
45
- gpjax/kernels/stationary/white.py,sha256=TkdXXZCCjDs7JwR_gj5uvn2s1wyfRbe1vyHhUMJ8jjI,2212
46
- gpjax-0.11.2.dist-info/METADATA,sha256=lTQVlrUbkxI7fU9Gdnac_eoNRyjCHEoEuEvvWbKmDqM,8558
47
- gpjax-0.11.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
48
- gpjax-0.11.2.dist-info/licenses/LICENSE.txt,sha256=3umwi0h8wmKXOZO8XwRBwSl3vJt2hpWKEqSrSXLR7-I,1084
49
- gpjax-0.11.2.dist-info/RECORD,,
File without changes