gpjax 0.11.2__py3-none-any.whl → 0.12.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gpjax/__init__.py +1 -4
- gpjax/distributions.py +16 -56
- gpjax/fit.py +11 -6
- gpjax/gps.py +61 -73
- gpjax/kernels/approximations/rff.py +2 -5
- gpjax/kernels/base.py +2 -5
- gpjax/kernels/computations/base.py +7 -7
- gpjax/kernels/computations/basis_functions.py +7 -6
- gpjax/kernels/computations/constant_diagonal.py +10 -12
- gpjax/kernels/computations/diagonal.py +6 -6
- gpjax/kernels/computations/eigen.py +1 -1
- gpjax/kernels/non_euclidean/graph.py +10 -11
- gpjax/kernels/nonstationary/arccosine.py +13 -21
- gpjax/kernels/nonstationary/polynomial.py +7 -8
- gpjax/kernels/stationary/periodic.py +3 -6
- gpjax/kernels/stationary/powered_exponential.py +3 -8
- gpjax/kernels/stationary/rational_quadratic.py +5 -8
- gpjax/likelihoods.py +11 -14
- gpjax/linalg/__init__.py +37 -0
- gpjax/linalg/operations.py +237 -0
- gpjax/linalg/operators.py +411 -0
- gpjax/linalg/utils.py +65 -0
- gpjax/mean_functions.py +8 -7
- gpjax/objectives.py +22 -21
- gpjax/parameters.py +11 -23
- gpjax/variational_families.py +93 -67
- {gpjax-0.11.2.dist-info → gpjax-0.12.2.dist-info}/METADATA +50 -18
- gpjax-0.12.2.dist-info/RECORD +52 -0
- gpjax/lower_cholesky.py +0 -69
- gpjax-0.11.2.dist-info/RECORD +0 -49
- {gpjax-0.11.2.dist-info → gpjax-0.12.2.dist-info}/WHEEL +0 -0
- {gpjax-0.11.2.dist-info → gpjax-0.12.2.dist-info}/licenses/LICENSE.txt +0 -0
gpjax/variational_families.py
CHANGED
|
@@ -16,15 +16,6 @@
|
|
|
16
16
|
import abc
|
|
17
17
|
|
|
18
18
|
import beartype.typing as tp
|
|
19
|
-
from cola.annotations import PSD
|
|
20
|
-
from cola.linalg.decompositions.decompositions import Cholesky
|
|
21
|
-
from cola.linalg.inverse.inv import solve
|
|
22
|
-
from cola.ops.operators import (
|
|
23
|
-
Dense,
|
|
24
|
-
I_like,
|
|
25
|
-
Identity,
|
|
26
|
-
Triangular,
|
|
27
|
-
)
|
|
28
19
|
from flax import nnx
|
|
29
20
|
import jax.numpy as jnp
|
|
30
21
|
import jax.scipy as jsp
|
|
@@ -41,12 +32,19 @@ from gpjax.likelihoods import (
|
|
|
41
32
|
Gaussian,
|
|
42
33
|
NonGaussian,
|
|
43
34
|
)
|
|
44
|
-
from gpjax.
|
|
35
|
+
from gpjax.linalg import (
|
|
36
|
+
Dense,
|
|
37
|
+
Identity,
|
|
38
|
+
Triangular,
|
|
39
|
+
lower_cholesky,
|
|
40
|
+
psd,
|
|
41
|
+
solve,
|
|
42
|
+
)
|
|
43
|
+
from gpjax.linalg.utils import add_jitter
|
|
45
44
|
from gpjax.mean_functions import AbstractMeanFunction
|
|
46
45
|
from gpjax.parameters import (
|
|
47
46
|
LowerTriangular,
|
|
48
47
|
Real,
|
|
49
|
-
Static,
|
|
50
48
|
)
|
|
51
49
|
from gpjax.typing import (
|
|
52
50
|
Array,
|
|
@@ -112,11 +110,10 @@ class AbstractVariationalGaussian(AbstractVariationalFamily[L]):
|
|
|
112
110
|
inducing_inputs: tp.Union[
|
|
113
111
|
Float[Array, "N D"],
|
|
114
112
|
Real,
|
|
115
|
-
Static,
|
|
116
113
|
],
|
|
117
114
|
jitter: ScalarFloat = 1e-6,
|
|
118
115
|
):
|
|
119
|
-
if not isinstance(inducing_inputs,
|
|
116
|
+
if not isinstance(inducing_inputs, Real):
|
|
120
117
|
inducing_inputs = Real(inducing_inputs)
|
|
121
118
|
|
|
122
119
|
self.inducing_inputs = inducing_inputs
|
|
@@ -179,25 +176,31 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):
|
|
|
179
176
|
approximation and the GP prior.
|
|
180
177
|
"""
|
|
181
178
|
# Unpack variational parameters
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
179
|
+
variational_mean = self.variational_mean.value
|
|
180
|
+
variational_sqrt = self.variational_root_covariance.value
|
|
181
|
+
inducing_inputs = self.inducing_inputs.value
|
|
185
182
|
|
|
186
183
|
# Unpack mean function and kernel
|
|
187
184
|
mean_function = self.posterior.prior.mean_function
|
|
188
185
|
kernel = self.posterior.prior.kernel
|
|
189
186
|
|
|
190
|
-
|
|
191
|
-
Kzz = kernel.gram(
|
|
192
|
-
Kzz =
|
|
187
|
+
inducing_mean = mean_function(inducing_inputs)
|
|
188
|
+
Kzz = kernel.gram(inducing_inputs)
|
|
189
|
+
Kzz = psd(Dense(add_jitter(Kzz.to_dense(), self.jitter)))
|
|
193
190
|
|
|
194
|
-
|
|
195
|
-
|
|
191
|
+
variational_sqrt_triangular = Triangular(variational_sqrt)
|
|
192
|
+
variational_covariance = (
|
|
193
|
+
variational_sqrt_triangular @ variational_sqrt_triangular.T
|
|
194
|
+
)
|
|
196
195
|
|
|
197
|
-
|
|
198
|
-
|
|
196
|
+
q_inducing = GaussianDistribution(
|
|
197
|
+
loc=jnp.atleast_1d(variational_mean.squeeze()), scale=variational_covariance
|
|
198
|
+
)
|
|
199
|
+
p_inducing = GaussianDistribution(
|
|
200
|
+
loc=jnp.atleast_1d(inducing_mean.squeeze()), scale=Kzz
|
|
201
|
+
)
|
|
199
202
|
|
|
200
|
-
return
|
|
203
|
+
return q_inducing.kl_divergence(p_inducing)
|
|
201
204
|
|
|
202
205
|
def predict(self, test_inputs: Float[Array, "N D"]) -> GaussianDistribution:
|
|
203
206
|
r"""Compute the predictive distribution of the GP at the test inputs t.
|
|
@@ -217,37 +220,38 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):
|
|
|
217
220
|
the test inputs.
|
|
218
221
|
"""
|
|
219
222
|
# Unpack variational parameters
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
+
variational_mean = self.variational_mean.value
|
|
224
|
+
variational_sqrt = self.variational_root_covariance.value
|
|
225
|
+
inducing_inputs = self.inducing_inputs.value
|
|
223
226
|
|
|
224
227
|
# Unpack mean function and kernel
|
|
225
228
|
mean_function = self.posterior.prior.mean_function
|
|
226
229
|
kernel = self.posterior.prior.kernel
|
|
227
230
|
|
|
228
|
-
Kzz = kernel.gram(
|
|
229
|
-
|
|
231
|
+
Kzz = kernel.gram(inducing_inputs)
|
|
232
|
+
Kzz_dense = add_jitter(Kzz.to_dense(), self.jitter)
|
|
233
|
+
Kzz = psd(Dense(Kzz_dense))
|
|
230
234
|
Lz = lower_cholesky(Kzz)
|
|
231
|
-
|
|
235
|
+
inducing_mean = mean_function(inducing_inputs)
|
|
232
236
|
|
|
233
237
|
# Unpack test inputs
|
|
234
|
-
|
|
238
|
+
test_points = test_inputs
|
|
235
239
|
|
|
236
|
-
Ktt = kernel.gram(
|
|
237
|
-
Kzt = kernel.cross_covariance(
|
|
238
|
-
|
|
240
|
+
Ktt = kernel.gram(test_points)
|
|
241
|
+
Kzt = kernel.cross_covariance(inducing_inputs, test_points)
|
|
242
|
+
test_mean = mean_function(test_points)
|
|
239
243
|
|
|
240
244
|
# Lz⁻¹ Kzt
|
|
241
|
-
Lz_inv_Kzt = solve(Lz, Kzt
|
|
245
|
+
Lz_inv_Kzt = solve(Lz, Kzt)
|
|
242
246
|
|
|
243
247
|
# Kzz⁻¹ Kzt
|
|
244
|
-
Kzz_inv_Kzt = solve(Lz.T, Lz_inv_Kzt
|
|
248
|
+
Kzz_inv_Kzt = solve(Lz.T, Lz_inv_Kzt)
|
|
245
249
|
|
|
246
250
|
# Ktz Kzz⁻¹ sqrt
|
|
247
|
-
Ktz_Kzz_inv_sqrt = jnp.matmul(Kzz_inv_Kzt.T,
|
|
251
|
+
Ktz_Kzz_inv_sqrt = jnp.matmul(Kzz_inv_Kzt.T, variational_sqrt)
|
|
248
252
|
|
|
249
253
|
# μt + Ktz Kzz⁻¹ (μ - μz)
|
|
250
|
-
mean =
|
|
254
|
+
mean = test_mean + jnp.matmul(Kzz_inv_Kzt.T, variational_mean - inducing_mean)
|
|
251
255
|
|
|
252
256
|
# Ktt - Ktz Kzz⁻¹ Kzt + Ktz Kzz⁻¹ S Kzz⁻¹ Kzt [recall S = sqrt sqrtᵀ]
|
|
253
257
|
covariance = (
|
|
@@ -255,7 +259,10 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):
|
|
|
255
259
|
- jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt)
|
|
256
260
|
+ jnp.matmul(Ktz_Kzz_inv_sqrt, Ktz_Kzz_inv_sqrt.T)
|
|
257
261
|
)
|
|
258
|
-
|
|
262
|
+
if hasattr(covariance, "to_dense"):
|
|
263
|
+
covariance = covariance.to_dense()
|
|
264
|
+
covariance = add_jitter(covariance, self.jitter)
|
|
265
|
+
covariance = Dense(covariance)
|
|
259
266
|
|
|
260
267
|
return GaussianDistribution(
|
|
261
268
|
loc=jnp.atleast_1d(mean.squeeze()), scale=covariance
|
|
@@ -330,7 +337,8 @@ class WhitenedVariationalGaussian(VariationalGaussian[L]):
|
|
|
330
337
|
kernel = self.posterior.prior.kernel
|
|
331
338
|
|
|
332
339
|
Kzz = kernel.gram(z)
|
|
333
|
-
|
|
340
|
+
Kzz_dense = add_jitter(Kzz.to_dense(), self.jitter)
|
|
341
|
+
Kzz = psd(Dense(Kzz_dense))
|
|
334
342
|
Lz = lower_cholesky(Kzz)
|
|
335
343
|
|
|
336
344
|
# Unpack test inputs
|
|
@@ -341,7 +349,7 @@ class WhitenedVariationalGaussian(VariationalGaussian[L]):
|
|
|
341
349
|
mut = mean_function(t)
|
|
342
350
|
|
|
343
351
|
# Lz⁻¹ Kzt
|
|
344
|
-
Lz_inv_Kzt = solve(Lz, Kzt
|
|
352
|
+
Lz_inv_Kzt = solve(Lz, Kzt)
|
|
345
353
|
|
|
346
354
|
# Ktz Lz⁻ᵀ sqrt
|
|
347
355
|
Ktz_Lz_invT_sqrt = jnp.matmul(Lz_inv_Kzt.T, sqrt)
|
|
@@ -355,7 +363,10 @@ class WhitenedVariationalGaussian(VariationalGaussian[L]):
|
|
|
355
363
|
- jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt)
|
|
356
364
|
+ jnp.matmul(Ktz_Lz_invT_sqrt, Ktz_Lz_invT_sqrt.T)
|
|
357
365
|
)
|
|
358
|
-
|
|
366
|
+
if hasattr(covariance, "to_dense"):
|
|
367
|
+
covariance = covariance.to_dense()
|
|
368
|
+
covariance = add_jitter(covariance, self.jitter)
|
|
369
|
+
covariance = Dense(covariance)
|
|
359
370
|
|
|
360
371
|
return GaussianDistribution(
|
|
361
372
|
loc=jnp.atleast_1d(mean.squeeze()), scale=covariance
|
|
@@ -390,8 +401,8 @@ class NaturalVariationalGaussian(AbstractVariationalGaussian[L]):
|
|
|
390
401
|
if natural_matrix is None:
|
|
391
402
|
natural_matrix = -0.5 * jnp.eye(self.num_inducing)
|
|
392
403
|
|
|
393
|
-
self.natural_vector =
|
|
394
|
-
self.natural_matrix =
|
|
404
|
+
self.natural_vector = Real(natural_vector)
|
|
405
|
+
self.natural_matrix = Real(natural_matrix)
|
|
395
406
|
|
|
396
407
|
def prior_kl(self) -> ScalarFloat:
|
|
397
408
|
r"""Compute the KL-divergence between our current variational approximation
|
|
@@ -422,7 +433,7 @@ class NaturalVariationalGaussian(AbstractVariationalGaussian[L]):
|
|
|
422
433
|
|
|
423
434
|
# S⁻¹ = -2θ₂
|
|
424
435
|
S_inv = -2 * natural_matrix
|
|
425
|
-
S_inv
|
|
436
|
+
S_inv = add_jitter(S_inv, self.jitter)
|
|
426
437
|
|
|
427
438
|
# Compute L⁻¹, where LLᵀ = S, via a trick found in the NumPyro source code and https://nbviewer.org/gist/fehiepsi/5ef8e09e61604f10607380467eb82006#Precision-to-scale_tril:
|
|
428
439
|
sqrt_inv = jnp.swapaxes(
|
|
@@ -441,7 +452,8 @@ class NaturalVariationalGaussian(AbstractVariationalGaussian[L]):
|
|
|
441
452
|
|
|
442
453
|
muz = mean_function(z)
|
|
443
454
|
Kzz = kernel.gram(z)
|
|
444
|
-
|
|
455
|
+
Kzz_dense = add_jitter(Kzz.to_dense(), self.jitter)
|
|
456
|
+
Kzz = psd(Dense(Kzz_dense))
|
|
445
457
|
|
|
446
458
|
qu = GaussianDistribution(loc=jnp.atleast_1d(mu.squeeze()), scale=S)
|
|
447
459
|
pu = GaussianDistribution(loc=jnp.atleast_1d(muz.squeeze()), scale=Kzz)
|
|
@@ -475,7 +487,7 @@ class NaturalVariationalGaussian(AbstractVariationalGaussian[L]):
|
|
|
475
487
|
|
|
476
488
|
# S⁻¹ = -2θ₂
|
|
477
489
|
S_inv = -2 * natural_matrix
|
|
478
|
-
S_inv
|
|
490
|
+
S_inv = add_jitter(S_inv, self.jitter)
|
|
479
491
|
|
|
480
492
|
# Compute L⁻¹, where LLᵀ = S, via a trick found in the NumPyro source code and https://nbviewer.org/gist/fehiepsi/5ef8e09e61604f10607380467eb82006#Precision-to-scale_tril:
|
|
481
493
|
sqrt_inv = jnp.swapaxes(
|
|
@@ -492,7 +504,8 @@ class NaturalVariationalGaussian(AbstractVariationalGaussian[L]):
|
|
|
492
504
|
mu = jnp.matmul(S, natural_vector)
|
|
493
505
|
|
|
494
506
|
Kzz = kernel.gram(z)
|
|
495
|
-
|
|
507
|
+
Kzz_dense = add_jitter(Kzz.to_dense(), self.jitter)
|
|
508
|
+
Kzz = psd(Dense(Kzz_dense))
|
|
496
509
|
Lz = lower_cholesky(Kzz)
|
|
497
510
|
muz = mean_function(z)
|
|
498
511
|
|
|
@@ -501,10 +514,10 @@ class NaturalVariationalGaussian(AbstractVariationalGaussian[L]):
|
|
|
501
514
|
mut = mean_function(test_inputs)
|
|
502
515
|
|
|
503
516
|
# Lz⁻¹ Kzt
|
|
504
|
-
Lz_inv_Kzt = solve(Lz, Kzt
|
|
517
|
+
Lz_inv_Kzt = solve(Lz, Kzt)
|
|
505
518
|
|
|
506
519
|
# Kzz⁻¹ Kzt
|
|
507
|
-
Kzz_inv_Kzt = solve(Lz.T, Lz_inv_Kzt
|
|
520
|
+
Kzz_inv_Kzt = solve(Lz.T, Lz_inv_Kzt)
|
|
508
521
|
|
|
509
522
|
# Ktz Kzz⁻¹ L
|
|
510
523
|
Ktz_Kzz_inv_L = jnp.matmul(Kzz_inv_Kzt.T, sqrt)
|
|
@@ -518,7 +531,10 @@ class NaturalVariationalGaussian(AbstractVariationalGaussian[L]):
|
|
|
518
531
|
- jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt)
|
|
519
532
|
+ jnp.matmul(Ktz_Kzz_inv_L, Ktz_Kzz_inv_L.T)
|
|
520
533
|
)
|
|
521
|
-
|
|
534
|
+
if hasattr(covariance, "to_dense"):
|
|
535
|
+
covariance = covariance.to_dense()
|
|
536
|
+
covariance = add_jitter(covariance, self.jitter)
|
|
537
|
+
covariance = Dense(covariance)
|
|
522
538
|
|
|
523
539
|
return GaussianDistribution(
|
|
524
540
|
loc=jnp.atleast_1d(mean.squeeze()), scale=covariance
|
|
@@ -554,8 +570,8 @@ class ExpectationVariationalGaussian(AbstractVariationalGaussian[L]):
|
|
|
554
570
|
if expectation_matrix is None:
|
|
555
571
|
expectation_matrix = jnp.eye(self.num_inducing)
|
|
556
572
|
|
|
557
|
-
self.expectation_vector =
|
|
558
|
-
self.expectation_matrix =
|
|
573
|
+
self.expectation_vector = Real(expectation_vector)
|
|
574
|
+
self.expectation_matrix = Real(expectation_matrix)
|
|
559
575
|
|
|
560
576
|
def prior_kl(self) -> ScalarFloat:
|
|
561
577
|
r"""Evaluate the prior KL-divergence.
|
|
@@ -592,12 +608,14 @@ class ExpectationVariationalGaussian(AbstractVariationalGaussian[L]):
|
|
|
592
608
|
|
|
593
609
|
# S = η₂ - η₁ η₁ᵀ
|
|
594
610
|
S = expectation_matrix - jnp.outer(mu, mu)
|
|
595
|
-
S =
|
|
596
|
-
|
|
611
|
+
S = psd(Dense(S))
|
|
612
|
+
S_dense = add_jitter(S.to_dense(), self.jitter)
|
|
613
|
+
S = psd(Dense(S_dense))
|
|
597
614
|
|
|
598
615
|
muz = mean_function(z)
|
|
599
616
|
Kzz = kernel.gram(z)
|
|
600
|
-
|
|
617
|
+
Kzz_dense = add_jitter(Kzz.to_dense(), self.jitter)
|
|
618
|
+
Kzz = psd(Dense(Kzz_dense))
|
|
601
619
|
|
|
602
620
|
qu = GaussianDistribution(loc=jnp.atleast_1d(mu.squeeze()), scale=S)
|
|
603
621
|
pu = GaussianDistribution(loc=jnp.atleast_1d(muz.squeeze()), scale=Kzz)
|
|
@@ -636,14 +654,15 @@ class ExpectationVariationalGaussian(AbstractVariationalGaussian[L]):
|
|
|
636
654
|
|
|
637
655
|
# S = η₂ - η₁ η₁ᵀ
|
|
638
656
|
S = expectation_matrix - jnp.matmul(mu, mu.T)
|
|
639
|
-
S = Dense(
|
|
640
|
-
S =
|
|
657
|
+
S = Dense(add_jitter(S, self.jitter))
|
|
658
|
+
S = psd(S)
|
|
641
659
|
|
|
642
660
|
# S = sqrt sqrtᵀ
|
|
643
661
|
sqrt = lower_cholesky(S)
|
|
644
662
|
|
|
645
663
|
Kzz = kernel.gram(z)
|
|
646
|
-
|
|
664
|
+
Kzz_dense = add_jitter(Kzz.to_dense(), self.jitter)
|
|
665
|
+
Kzz = psd(Dense(Kzz_dense))
|
|
647
666
|
Lz = lower_cholesky(Kzz)
|
|
648
667
|
muz = mean_function(z)
|
|
649
668
|
|
|
@@ -655,10 +674,10 @@ class ExpectationVariationalGaussian(AbstractVariationalGaussian[L]):
|
|
|
655
674
|
mut = mean_function(t)
|
|
656
675
|
|
|
657
676
|
# Lz⁻¹ Kzt
|
|
658
|
-
Lz_inv_Kzt = solve(Lz, Kzt
|
|
677
|
+
Lz_inv_Kzt = solve(Lz, Kzt)
|
|
659
678
|
|
|
660
679
|
# Kzz⁻¹ Kzt
|
|
661
|
-
Kzz_inv_Kzt = solve(Lz.T, Lz_inv_Kzt
|
|
680
|
+
Kzz_inv_Kzt = solve(Lz.T, Lz_inv_Kzt)
|
|
662
681
|
|
|
663
682
|
# Ktz Kzz⁻¹ sqrt
|
|
664
683
|
Ktz_Kzz_inv_sqrt = Kzz_inv_Kzt.T @ sqrt
|
|
@@ -672,7 +691,10 @@ class ExpectationVariationalGaussian(AbstractVariationalGaussian[L]):
|
|
|
672
691
|
- jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt)
|
|
673
692
|
+ jnp.matmul(Ktz_Kzz_inv_sqrt, Ktz_Kzz_inv_sqrt.T)
|
|
674
693
|
)
|
|
675
|
-
|
|
694
|
+
if hasattr(covariance, "to_dense"):
|
|
695
|
+
covariance = covariance.to_dense()
|
|
696
|
+
covariance = add_jitter(covariance, self.jitter)
|
|
697
|
+
covariance = Dense(covariance)
|
|
676
698
|
|
|
677
699
|
return GaussianDistribution(
|
|
678
700
|
loc=jnp.atleast_1d(mean.squeeze()), scale=covariance
|
|
@@ -729,13 +751,14 @@ class CollapsedVariationalGaussian(AbstractVariationalGaussian[GL]):
|
|
|
729
751
|
|
|
730
752
|
Kzx = kernel.cross_covariance(z, x)
|
|
731
753
|
Kzz = kernel.gram(z)
|
|
732
|
-
|
|
754
|
+
Kzz_dense = add_jitter(Kzz.to_dense(), self.jitter)
|
|
755
|
+
Kzz = psd(Dense(Kzz_dense))
|
|
733
756
|
|
|
734
757
|
# Lz Lzᵀ = Kzz
|
|
735
758
|
Lz = lower_cholesky(Kzz)
|
|
736
759
|
|
|
737
760
|
# Lz⁻¹ Kzx
|
|
738
|
-
Lz_inv_Kzx = solve(Lz, Kzx
|
|
761
|
+
Lz_inv_Kzx = solve(Lz, Kzx)
|
|
739
762
|
|
|
740
763
|
# A = Lz⁻¹ Kzt / o
|
|
741
764
|
A = Lz_inv_Kzx / self.posterior.likelihood.obs_stddev.value
|
|
@@ -753,14 +776,14 @@ class CollapsedVariationalGaussian(AbstractVariationalGaussian[GL]):
|
|
|
753
776
|
Lz_inv_Kzx_diff = jsp.linalg.cho_solve((L, True), jnp.matmul(Lz_inv_Kzx, diff))
|
|
754
777
|
|
|
755
778
|
# Kzz⁻¹ Kzx (y - μx)
|
|
756
|
-
Kzz_inv_Kzx_diff = solve(Lz.T, Lz_inv_Kzx_diff
|
|
779
|
+
Kzz_inv_Kzx_diff = solve(Lz.T, Lz_inv_Kzx_diff)
|
|
757
780
|
|
|
758
781
|
Ktt = kernel.gram(t)
|
|
759
782
|
Kzt = kernel.cross_covariance(z, t)
|
|
760
783
|
mut = mean_function(t)
|
|
761
784
|
|
|
762
785
|
# Lz⁻¹ Kzt
|
|
763
|
-
Lz_inv_Kzt = solve(Lz, Kzt
|
|
786
|
+
Lz_inv_Kzt = solve(Lz, Kzt)
|
|
764
787
|
|
|
765
788
|
# L⁻¹ Lz⁻¹ Kzt
|
|
766
789
|
L_inv_Lz_inv_Kzt = jsp.linalg.solve_triangular(L, Lz_inv_Kzt, lower=True)
|
|
@@ -774,7 +797,10 @@ class CollapsedVariationalGaussian(AbstractVariationalGaussian[GL]):
|
|
|
774
797
|
- jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt)
|
|
775
798
|
+ jnp.matmul(L_inv_Lz_inv_Kzt.T, L_inv_Lz_inv_Kzt)
|
|
776
799
|
)
|
|
777
|
-
|
|
800
|
+
if hasattr(covariance, "to_dense"):
|
|
801
|
+
covariance = covariance.to_dense()
|
|
802
|
+
covariance = add_jitter(covariance, self.jitter)
|
|
803
|
+
covariance = Dense(covariance)
|
|
778
804
|
|
|
779
805
|
return GaussianDistribution(
|
|
780
806
|
loc=jnp.atleast_1d(mean.squeeze()), scale=covariance
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: gpjax
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.12.2
|
|
4
4
|
Summary: Gaussian processes in JAX.
|
|
5
5
|
Project-URL: Documentation, https://docs.jaxgaussianprocesses.com/
|
|
6
6
|
Project-URL: Issues, https://github.com/JaxGaussianProcesses/GPJax/issues
|
|
@@ -14,11 +14,11 @@ Classifier: Programming Language :: Python
|
|
|
14
14
|
Classifier: Programming Language :: Python :: 3.10
|
|
15
15
|
Classifier: Programming Language :: Python :: 3.11
|
|
16
16
|
Classifier: Programming Language :: Python :: 3.12
|
|
17
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
17
18
|
Classifier: Programming Language :: Python :: Implementation :: CPython
|
|
18
19
|
Classifier: Programming Language :: Python :: Implementation :: PyPy
|
|
19
|
-
Requires-Python:
|
|
20
|
+
Requires-Python: >=3.10
|
|
20
21
|
Requires-Dist: beartype>0.16.1
|
|
21
|
-
Requires-Dist: cola-ml>=0.0.7
|
|
22
22
|
Requires-Dist: flax>=0.10.0
|
|
23
23
|
Requires-Dist: jax>=0.5.0
|
|
24
24
|
Requires-Dist: jaxlib>=0.5.0
|
|
@@ -27,6 +27,47 @@ Requires-Dist: numpy>=2.0.0
|
|
|
27
27
|
Requires-Dist: numpyro
|
|
28
28
|
Requires-Dist: optax>0.2.1
|
|
29
29
|
Requires-Dist: tqdm>4.66.2
|
|
30
|
+
Provides-Extra: dev
|
|
31
|
+
Requires-Dist: absolufy-imports>=0.3.1; extra == 'dev'
|
|
32
|
+
Requires-Dist: autoflake; extra == 'dev'
|
|
33
|
+
Requires-Dist: black; extra == 'dev'
|
|
34
|
+
Requires-Dist: codespell>=2.2.4; extra == 'dev'
|
|
35
|
+
Requires-Dist: coverage>=7.2.2; extra == 'dev'
|
|
36
|
+
Requires-Dist: interrogate>=1.5.0; extra == 'dev'
|
|
37
|
+
Requires-Dist: isort; extra == 'dev'
|
|
38
|
+
Requires-Dist: jupytext; extra == 'dev'
|
|
39
|
+
Requires-Dist: mktestdocs>=0.2.1; extra == 'dev'
|
|
40
|
+
Requires-Dist: networkx; extra == 'dev'
|
|
41
|
+
Requires-Dist: pre-commit>=3.2.2; extra == 'dev'
|
|
42
|
+
Requires-Dist: pytest-beartype; extra == 'dev'
|
|
43
|
+
Requires-Dist: pytest-cov>=4.0.0; extra == 'dev'
|
|
44
|
+
Requires-Dist: pytest-pretty>=1.1.1; extra == 'dev'
|
|
45
|
+
Requires-Dist: pytest-xdist>=3.2.1; extra == 'dev'
|
|
46
|
+
Requires-Dist: pytest>=7.2.2; extra == 'dev'
|
|
47
|
+
Requires-Dist: ruff>=0.6; extra == 'dev'
|
|
48
|
+
Requires-Dist: xdoctest>=1.1.1; extra == 'dev'
|
|
49
|
+
Provides-Extra: docs
|
|
50
|
+
Requires-Dist: blackjax>=0.9.6; extra == 'docs'
|
|
51
|
+
Requires-Dist: ipykernel>=6.22.0; extra == 'docs'
|
|
52
|
+
Requires-Dist: ipython>=8.11.0; extra == 'docs'
|
|
53
|
+
Requires-Dist: ipywidgets>=8.0.5; extra == 'docs'
|
|
54
|
+
Requires-Dist: jupytext>=1.14.5; extra == 'docs'
|
|
55
|
+
Requires-Dist: markdown-katex>=202406.1035; extra == 'docs'
|
|
56
|
+
Requires-Dist: matplotlib>=3.7.1; extra == 'docs'
|
|
57
|
+
Requires-Dist: mkdocs-gen-files>=0.5.0; extra == 'docs'
|
|
58
|
+
Requires-Dist: mkdocs-git-authors-plugin>=0.7.0; extra == 'docs'
|
|
59
|
+
Requires-Dist: mkdocs-jupyter>=0.24.3; extra == 'docs'
|
|
60
|
+
Requires-Dist: mkdocs-literate-nav>=0.6.0; extra == 'docs'
|
|
61
|
+
Requires-Dist: mkdocs-material>=9.5.12; extra == 'docs'
|
|
62
|
+
Requires-Dist: mkdocs>=1.5.3; extra == 'docs'
|
|
63
|
+
Requires-Dist: mkdocstrings[python]<0.31.0; extra == 'docs'
|
|
64
|
+
Requires-Dist: nbconvert>=7.16.2; extra == 'docs'
|
|
65
|
+
Requires-Dist: networkx>=3.0; extra == 'docs'
|
|
66
|
+
Requires-Dist: pandas>=1.5.3; extra == 'docs'
|
|
67
|
+
Requires-Dist: pymdown-extensions>=10.7.1; extra == 'docs'
|
|
68
|
+
Requires-Dist: scikit-learn>=1.5.1; extra == 'docs'
|
|
69
|
+
Requires-Dist: seaborn>=0.12.2; extra == 'docs'
|
|
70
|
+
Requires-Dist: watermark>=2.3.1; extra == 'docs'
|
|
30
71
|
Description-Content-Type: text/markdown
|
|
31
72
|
|
|
32
73
|
<!-- <h1 align='center'>GPJax</h1>
|
|
@@ -41,12 +82,12 @@ Description-Content-Type: text/markdown
|
|
|
41
82
|
[](https://badge.fury.io/py/GPJax)
|
|
42
83
|
[](https://doi.org/10.21105/joss.04455)
|
|
43
84
|
[](https://pepy.tech/project/gpjax)
|
|
44
|
-
[](https://join.slack.com/t/gpjax/shared_invite/zt-
|
|
85
|
+
[](https://join.slack.com/t/gpjax/shared_invite/zt-3cesiykcx-nzajjRdnV3ohw7~~eMlCYA)
|
|
45
86
|
|
|
46
87
|
[**Quickstart**](#simple-example)
|
|
47
88
|
| [**Install guide**](#installation)
|
|
48
89
|
| [**Documentation**](https://docs.jaxgaussianprocesses.com/)
|
|
49
|
-
| [**Slack Community**](https://join.slack.com/t/gpjax/shared_invite/zt-
|
|
90
|
+
| [**Slack Community**](https://join.slack.com/t/gpjax/shared_invite/zt-3cesiykcx-nzajjRdnV3ohw7~~eMlCYA)
|
|
50
91
|
|
|
51
92
|
GPJax aims to provide a low-level interface to Gaussian process (GP) models in
|
|
52
93
|
[Jax](https://github.com/google/jax), structured to give researchers maximum
|
|
@@ -81,22 +122,13 @@ behaviours through [this form](https://jaxgaussianprocesses.com/contact/) or rea
|
|
|
81
122
|
one of the project's [_gardeners_](https://docs.jaxgaussianprocesses.com/GOVERNANCE/#roles).
|
|
82
123
|
|
|
83
124
|
Feel free to join our [Slack
|
|
84
|
-
Channel](https://join.slack.com/t/gpjax/shared_invite/zt-
|
|
125
|
+
Channel](https://join.slack.com/t/gpjax/shared_invite/zt-3cesiykcx-nzajjRdnV3ohw7~~eMlCYA),
|
|
85
126
|
where we can discuss the development of GPJax and broader support for Gaussian
|
|
86
127
|
process modelling.
|
|
87
128
|
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
GPJax was founded by [Thomas Pinder](https://github.com/thomaspinder). Today, the
|
|
92
|
-
project's gardeners are [daniel-dodd@](https://github.com/daniel-dodd),
|
|
93
|
-
[henrymoss@](https://github.com/henrymoss), [st--@](https://github.com/st--), and
|
|
94
|
-
[thomaspinder@](https://github.com/thomaspinder), listed in alphabetical order. The full
|
|
95
|
-
governance structure of GPJax is detailed [here](docs/GOVERNANCE.md). We appreciate all
|
|
96
|
-
[the contributors to
|
|
97
|
-
GPJax](https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors) who have
|
|
98
|
-
helped to shape GPJax into the package it is today.
|
|
99
|
-
|
|
129
|
+
We appreciate all [the contributors to
|
|
130
|
+
GPJax](https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors) who have helped to shape
|
|
131
|
+
GPJax into the package it is today.
|
|
100
132
|
|
|
101
133
|
# Supported methods and interfaces
|
|
102
134
|
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
gpjax/__init__.py,sha256=RzwpixFXn6HNHLVLy4LVXhFUk2c-_ce6n1gjZ2B93F0,1641
|
|
2
|
+
gpjax/citation.py,sha256=pwFS8h1J-LE5ieRS0zDyuwhmQHNxkFHYE7iSMlVNmQc,3928
|
|
3
|
+
gpjax/dataset.py,sha256=NsToLKq4lOsHnfLfukrUIRKvhOEuoUk8aHTF0oAqRbU,4079
|
|
4
|
+
gpjax/distributions.py,sha256=iKmeQ_NN2CIjRiuOeJlwEGASzGROi4ZCerVi1uY7zRM,7758
|
|
5
|
+
gpjax/fit.py,sha256=I2sJVuKZii_d7MEcelHIivfM8ExYGMgdBuKKOT7Dw-A,15326
|
|
6
|
+
gpjax/gps.py,sha256=ipaeYMnPffhKK_JsEHe4fF8GmolQIjXB1YbyfUIL8H4,30118
|
|
7
|
+
gpjax/integrators.py,sha256=eyJPqWNPKj6pKP5da0fEj4HW7BVyevqeGrurEuy_XPw,5694
|
|
8
|
+
gpjax/likelihoods.py,sha256=xwnSQpn6Aa-FPpEoDn_3xpBdPQAmHP97jP-9iJmT4G8,9087
|
|
9
|
+
gpjax/mean_functions.py,sha256=KiHQXI-b7o0Vi5KQxGm6RNsUjitJc9jEOCq2GrSx4II,6531
|
|
10
|
+
gpjax/numpyro_extras.py,sha256=-vWJ7SpZVNhSdCjjrlxIkovMFrM1IzpsMJK3B4LioGE,3411
|
|
11
|
+
gpjax/objectives.py,sha256=GvKbDIPqYjsc9FpiTccmZwRdHr2lCykgfxI9BX9I_GA,15362
|
|
12
|
+
gpjax/parameters.py,sha256=hnyIKr6uIzd7Kb3KZC9WowR88ruQwUvdcto3cx2ZDv4,6756
|
|
13
|
+
gpjax/scan.py,sha256=jStQvwkE9MGttB89frxam1kaeXdWih7cVxkGywyaeHQ,5365
|
|
14
|
+
gpjax/typing.py,sha256=M3CvWsYtZ3PFUvBvvbRNjpwerNII0w4yGuP0I-sLeYI,1705
|
|
15
|
+
gpjax/variational_families.py,sha256=TJGGkwkE805X4PQb-C32FxvD9B_OsFLWf6I-ZZvOUWk,29628
|
|
16
|
+
gpjax/kernels/__init__.py,sha256=WZanH0Tpdkt0f7VfMqnalm_VZAMVwBqeOVaICNj6xQU,1901
|
|
17
|
+
gpjax/kernels/base.py,sha256=4Lx8y3kPX4WqQZGRGAsBkqn_i6FlfoAhSn9Tv415xuQ,11551
|
|
18
|
+
gpjax/kernels/approximations/__init__.py,sha256=bK9HlGd-PZeGrqtG5RpXxUTXNUrZTgfjH1dP626yNMA,68
|
|
19
|
+
gpjax/kernels/approximations/rff.py,sha256=GbNUmDPEKEKuMwxUcocxl_9IFR3Q9KEPZXzjy_ZD-2w,4043
|
|
20
|
+
gpjax/kernels/computations/__init__.py,sha256=uTVkqvnZVesFLDN92h0ZR0jfR69Eo2WyjOlmSYmCPJ8,1379
|
|
21
|
+
gpjax/kernels/computations/base.py,sha256=L6K0roxZbrYeJKxEw-yaTiK9Mtcv0YtZfWI2Xnau7i8,3616
|
|
22
|
+
gpjax/kernels/computations/basis_functions.py,sha256=_SFv4Tiwne40bxr1uVYpEjjZgjIQHKseLmss2Zgl1L4,2484
|
|
23
|
+
gpjax/kernels/computations/constant_diagonal.py,sha256=JkQhLj7cK48IhOER4ivkALNhD1oQleKe-Rr9BtUJ6es,1984
|
|
24
|
+
gpjax/kernels/computations/dense.py,sha256=vnW6XKQe4_gzpXRWTctxhgMA9-9TebdtiXzAqh_-j6g,1392
|
|
25
|
+
gpjax/kernels/computations/diagonal.py,sha256=k1KqW0DwWRIBvbb7jzcKktXRfhXbcos3ncWrFplJ4W0,1768
|
|
26
|
+
gpjax/kernels/computations/eigen.py,sha256=NTHm-cn-RepYuXFrvXo2ih7Gtu1YR_pAg4Jb7IhE_o8,1930
|
|
27
|
+
gpjax/kernels/non_euclidean/__init__.py,sha256=RT7puRPqCTpyxZ16q596EuOQEQi1LK1v3J9_fWz1NlY,790
|
|
28
|
+
gpjax/kernels/non_euclidean/graph.py,sha256=xTrx6ro8ubRXgM7Wgg6NmOyyEjEcGhzydY7KXueknCc,4120
|
|
29
|
+
gpjax/kernels/non_euclidean/utils.py,sha256=z42aw8ga0zuREzHawemR9okttgrAUPmq-aN5HMt4SuY,1578
|
|
30
|
+
gpjax/kernels/nonstationary/__init__.py,sha256=YpWQfOy_cqOKc5ezn37vqoK3Z6jznYiJz28BD_8F7AY,930
|
|
31
|
+
gpjax/kernels/nonstationary/arccosine.py,sha256=cqb8sqaNwW3fEbrA7MY9OF2KJFTkxHhqwmQtABE3G8w,5408
|
|
32
|
+
gpjax/kernels/nonstationary/linear.py,sha256=UIMoCq2hg6dQKr4J5UGiiPqotBleQuYfy00Ia1NaMOo,2571
|
|
33
|
+
gpjax/kernels/nonstationary/polynomial.py,sha256=CKc02C7Utgo-hhcOOCcKLdln5lj4vud_8M-JY7SevJ8,3388
|
|
34
|
+
gpjax/kernels/stationary/__init__.py,sha256=j4BMTaQlIx2kNAT1Dkf4iO2rm-f7_oSVWNrk1bN0tqE,1406
|
|
35
|
+
gpjax/kernels/stationary/base.py,sha256=25qDqpZP4gNtzbyzDCW-6u7rJfMqkg0dW88XUmTTupU,7078
|
|
36
|
+
gpjax/kernels/stationary/matern12.py,sha256=DGjqw6VveYsyy0TrufyJJvCei7p9slnm2f0TgRGG7_U,1773
|
|
37
|
+
gpjax/kernels/stationary/matern32.py,sha256=laLsJWJozJzpYHBzlkPUq0rWxz1eWEwGC36P2nPJuaQ,1966
|
|
38
|
+
gpjax/kernels/stationary/matern52.py,sha256=VSByD2sb7k-DzRFjaz31P3Rtc4bPPhHvMshrxZNFnns,2019
|
|
39
|
+
gpjax/kernels/stationary/periodic.py,sha256=f4PhWhKg-pJsEBGzEMK9pdbylO84GPKhzHlBC83ZVWw,3528
|
|
40
|
+
gpjax/kernels/stationary/powered_exponential.py,sha256=xuFGuIK0mKNMU3iLtZMXZTHXJuMFAMoX7gAtXefCdqU,3679
|
|
41
|
+
gpjax/kernels/stationary/rational_quadratic.py,sha256=zHo2LVW65T52XET4Hx9JaKO0TfxylV8WRUtP7sUUOx0,3418
|
|
42
|
+
gpjax/kernels/stationary/rbf.py,sha256=euHUs6FdfRICQcabAWE4MX-7GEDr2TxgZWdFQiXr9Bw,1690
|
|
43
|
+
gpjax/kernels/stationary/utils.py,sha256=6BI9EBcCzeeKx-XH-MfW1ORmtU__tPX5zyvfLhpkBsU,2180
|
|
44
|
+
gpjax/kernels/stationary/white.py,sha256=TkdXXZCCjDs7JwR_gj5uvn2s1wyfRbe1vyHhUMJ8jjI,2212
|
|
45
|
+
gpjax/linalg/__init__.py,sha256=F8mxk_9Zc2nFd7Q-unjJ50_6rXEKzZj572WsU_jUKqI,547
|
|
46
|
+
gpjax/linalg/operations.py,sha256=xvhOy5P4FmUCPWjIVNdg1yDXaoFQ48anFUfR-Tnfr6k,6480
|
|
47
|
+
gpjax/linalg/operators.py,sha256=arxRGwcoAy_RqUYqBpZ3XG6OXbjShUl7m8sTpg85npE,11608
|
|
48
|
+
gpjax/linalg/utils.py,sha256=fKV8G_iKZVhNkNvN20D_dQEi93-8xosGbXBP-v7UEyo,2020
|
|
49
|
+
gpjax-0.12.2.dist-info/METADATA,sha256=eckQKXiBXi8XbBeJFviBAIPdBGVWGFQg7wVZwMfPPxs,10129
|
|
50
|
+
gpjax-0.12.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
51
|
+
gpjax-0.12.2.dist-info/licenses/LICENSE.txt,sha256=3umwi0h8wmKXOZO8XwRBwSl3vJt2hpWKEqSrSXLR7-I,1084
|
|
52
|
+
gpjax-0.12.2.dist-info/RECORD,,
|
gpjax/lower_cholesky.py
DELETED
|
@@ -1,69 +0,0 @@
|
|
|
1
|
-
# Copyright 2023 The GPJax Contributors. All Rights Reserved.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ==============================================================================
|
|
15
|
-
|
|
16
|
-
from cola.annotations import PSD
|
|
17
|
-
from cola.fns import dispatch
|
|
18
|
-
from cola.ops.operator_base import LinearOperator
|
|
19
|
-
from cola.ops.operators import (
|
|
20
|
-
BlockDiag,
|
|
21
|
-
Diagonal,
|
|
22
|
-
Identity,
|
|
23
|
-
Kronecker,
|
|
24
|
-
Triangular,
|
|
25
|
-
)
|
|
26
|
-
import jax.numpy as jnp
|
|
27
|
-
|
|
28
|
-
# TODO: Once this functionality is supported in CoLA, remove this.
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
@dispatch
|
|
32
|
-
def lower_cholesky(A: LinearOperator) -> Triangular: # noqa: F811
|
|
33
|
-
"""Returns the lower Cholesky factor of a linear operator.
|
|
34
|
-
|
|
35
|
-
Args:
|
|
36
|
-
A: The input linear operator.
|
|
37
|
-
|
|
38
|
-
Returns:
|
|
39
|
-
Triangular: The lower Cholesky factor of A.
|
|
40
|
-
"""
|
|
41
|
-
|
|
42
|
-
if PSD not in A.annotations:
|
|
43
|
-
raise ValueError(
|
|
44
|
-
"Expected LinearOperator to be PSD, did you forget to use cola.PSD?"
|
|
45
|
-
)
|
|
46
|
-
|
|
47
|
-
return Triangular(jnp.linalg.cholesky(A.to_dense()), lower=True)
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
@lower_cholesky.dispatch
|
|
51
|
-
def _(A: Diagonal): # noqa: F811
|
|
52
|
-
return Diagonal(jnp.sqrt(A.diag))
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
@lower_cholesky.dispatch
|
|
56
|
-
def _(A: Identity): # noqa: F811
|
|
57
|
-
return A
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
@lower_cholesky.dispatch
|
|
61
|
-
def _(A: Kronecker): # noqa: F811
|
|
62
|
-
return Kronecker(*[lower_cholesky(Ai) for Ai in A.Ms])
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
@lower_cholesky.dispatch
|
|
66
|
-
def _(A: BlockDiag): # noqa: F811
|
|
67
|
-
return BlockDiag(
|
|
68
|
-
*[lower_cholesky(Ai) for Ai in A.Ms], multiplicities=A.multiplicities
|
|
69
|
-
)
|
gpjax-0.11.2.dist-info/RECORD
DELETED
|
@@ -1,49 +0,0 @@
|
|
|
1
|
-
gpjax/__init__.py,sha256=ylCFMtXwcMS2zxm4pI3KsnRdnX6bdh26TSdTfUh9l9o,1686
|
|
2
|
-
gpjax/citation.py,sha256=pwFS8h1J-LE5ieRS0zDyuwhmQHNxkFHYE7iSMlVNmQc,3928
|
|
3
|
-
gpjax/dataset.py,sha256=NsToLKq4lOsHnfLfukrUIRKvhOEuoUk8aHTF0oAqRbU,4079
|
|
4
|
-
gpjax/distributions.py,sha256=8LWmfmRVHOX29Uy8PkKFi2UhcCiunuu-4TMI_5-krHc,9299
|
|
5
|
-
gpjax/fit.py,sha256=R4TIPvBNHYSg9vBVp6is_QYENldRLIU_FklGE85C-aA,15046
|
|
6
|
-
gpjax/gps.py,sha256=97lYGrsmsufQxKEd8qz5wPNvui6FKXTF_Ps-sMFIjnY,31246
|
|
7
|
-
gpjax/integrators.py,sha256=eyJPqWNPKj6pKP5da0fEj4HW7BVyevqeGrurEuy_XPw,5694
|
|
8
|
-
gpjax/likelihoods.py,sha256=99oTZoWld1M7vxgGM0pNY5Hnt2Ajd2lQNqawzrLmwtk,9308
|
|
9
|
-
gpjax/lower_cholesky.py,sha256=3pnHaBrlGckFsrfYJ9Lsbd0pGmO7NIXdyY4aGm48MpY,1952
|
|
10
|
-
gpjax/mean_functions.py,sha256=-sVYO1_LWE8f34rllUOuaT5sgGGAdxo99v5kRo2d4oM,6490
|
|
11
|
-
gpjax/numpyro_extras.py,sha256=-vWJ7SpZVNhSdCjjrlxIkovMFrM1IzpsMJK3B4LioGE,3411
|
|
12
|
-
gpjax/objectives.py,sha256=I_ZqnwTNYIAUAZ9KQNenIl0ish1jDOXb7KaNmjz3Su4,15340
|
|
13
|
-
gpjax/parameters.py,sha256=H-DiXmotdBZCbf-GOjRaJoS_isk3GgFrpKHTq5GpnoA,6998
|
|
14
|
-
gpjax/scan.py,sha256=jStQvwkE9MGttB89frxam1kaeXdWih7cVxkGywyaeHQ,5365
|
|
15
|
-
gpjax/typing.py,sha256=M3CvWsYtZ3PFUvBvvbRNjpwerNII0w4yGuP0I-sLeYI,1705
|
|
16
|
-
gpjax/variational_families.py,sha256=Y9J1H91tXPm_hMy3ri_PgjAxqc_3r-BqKV83HRvB_m4,28295
|
|
17
|
-
gpjax/kernels/__init__.py,sha256=WZanH0Tpdkt0f7VfMqnalm_VZAMVwBqeOVaICNj6xQU,1901
|
|
18
|
-
gpjax/kernels/base.py,sha256=wXsrpm5ofy9S5MNgUkJk4lx2umcIJL6dDNhXY7cmTGk,11616
|
|
19
|
-
gpjax/kernels/approximations/__init__.py,sha256=bK9HlGd-PZeGrqtG5RpXxUTXNUrZTgfjH1dP626yNMA,68
|
|
20
|
-
gpjax/kernels/approximations/rff.py,sha256=VbitjNuahFE5_IvCj1A0SxHhJXU0O0Qq0FMMVq8xA3E,4125
|
|
21
|
-
gpjax/kernels/computations/__init__.py,sha256=uTVkqvnZVesFLDN92h0ZR0jfR69Eo2WyjOlmSYmCPJ8,1379
|
|
22
|
-
gpjax/kernels/computations/base.py,sha256=zzabLN_-FkTWo6cBYjzjvUGYa7vrYyHxyrhQZxLgHBk,3651
|
|
23
|
-
gpjax/kernels/computations/basis_functions.py,sha256=zY4rUDZDLOYvQPY9xosRmCLPdiXfbzAN5GICjQhFOis,2528
|
|
24
|
-
gpjax/kernels/computations/constant_diagonal.py,sha256=_4fIFPq76Z416-9dIr8N335WP291dGluO-RqqUsK9ZY,2058
|
|
25
|
-
gpjax/kernels/computations/dense.py,sha256=vnW6XKQe4_gzpXRWTctxhgMA9-9TebdtiXzAqh_-j6g,1392
|
|
26
|
-
gpjax/kernels/computations/diagonal.py,sha256=z2JpUue7oY-pL-c0Pc6Bngv_IJCR6z4MaW7kN0wgjxk,1803
|
|
27
|
-
gpjax/kernels/computations/eigen.py,sha256=w7I7LK42j0ouchHCI1ltXx0lpwqvK1bRb4HclnF3rKs,1936
|
|
28
|
-
gpjax/kernels/non_euclidean/__init__.py,sha256=RT7puRPqCTpyxZ16q596EuOQEQi1LK1v3J9_fWz1NlY,790
|
|
29
|
-
gpjax/kernels/non_euclidean/graph.py,sha256=K4WIdX-dx1SsWuNHZnNjHFw8ElKZxGcReUiA3w4aCOI,4204
|
|
30
|
-
gpjax/kernels/non_euclidean/utils.py,sha256=z42aw8ga0zuREzHawemR9okttgrAUPmq-aN5HMt4SuY,1578
|
|
31
|
-
gpjax/kernels/nonstationary/__init__.py,sha256=YpWQfOy_cqOKc5ezn37vqoK3Z6jznYiJz28BD_8F7AY,930
|
|
32
|
-
gpjax/kernels/nonstationary/arccosine.py,sha256=2WV6aM0Z3-xXZnoPw-77n2CW62n-AZuJy-7AQ9xrMco,5858
|
|
33
|
-
gpjax/kernels/nonstationary/linear.py,sha256=UIMoCq2hg6dQKr4J5UGiiPqotBleQuYfy00Ia1NaMOo,2571
|
|
34
|
-
gpjax/kernels/nonstationary/polynomial.py,sha256=arP8DK0jnBOaayDWcFvHF0pdu9FVhwzXdqjnHUAL2VI,3293
|
|
35
|
-
gpjax/kernels/stationary/__init__.py,sha256=j4BMTaQlIx2kNAT1Dkf4iO2rm-f7_oSVWNrk1bN0tqE,1406
|
|
36
|
-
gpjax/kernels/stationary/base.py,sha256=25qDqpZP4gNtzbyzDCW-6u7rJfMqkg0dW88XUmTTupU,7078
|
|
37
|
-
gpjax/kernels/stationary/matern12.py,sha256=DGjqw6VveYsyy0TrufyJJvCei7p9slnm2f0TgRGG7_U,1773
|
|
38
|
-
gpjax/kernels/stationary/matern32.py,sha256=laLsJWJozJzpYHBzlkPUq0rWxz1eWEwGC36P2nPJuaQ,1966
|
|
39
|
-
gpjax/kernels/stationary/matern52.py,sha256=VSByD2sb7k-DzRFjaz31P3Rtc4bPPhHvMshrxZNFnns,2019
|
|
40
|
-
gpjax/kernels/stationary/periodic.py,sha256=IAbCxURtJEHGdmYzbdrsqRZ3zJ8F8tGQF9O7sggafZk,3598
|
|
41
|
-
gpjax/kernels/stationary/powered_exponential.py,sha256=8qT91IWKJK7PpEtFcX4MVu1ahWMOFOZierPko4JCjKA,3776
|
|
42
|
-
gpjax/kernels/stationary/rational_quadratic.py,sha256=dYONp3i4rnKj3ET8UyxAKXv6UOl8uOFT3lCutleSvo4,3496
|
|
43
|
-
gpjax/kernels/stationary/rbf.py,sha256=euHUs6FdfRICQcabAWE4MX-7GEDr2TxgZWdFQiXr9Bw,1690
|
|
44
|
-
gpjax/kernels/stationary/utils.py,sha256=6BI9EBcCzeeKx-XH-MfW1ORmtU__tPX5zyvfLhpkBsU,2180
|
|
45
|
-
gpjax/kernels/stationary/white.py,sha256=TkdXXZCCjDs7JwR_gj5uvn2s1wyfRbe1vyHhUMJ8jjI,2212
|
|
46
|
-
gpjax-0.11.2.dist-info/METADATA,sha256=lTQVlrUbkxI7fU9Gdnac_eoNRyjCHEoEuEvvWbKmDqM,8558
|
|
47
|
-
gpjax-0.11.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
48
|
-
gpjax-0.11.2.dist-info/licenses/LICENSE.txt,sha256=3umwi0h8wmKXOZO8XwRBwSl3vJt2hpWKEqSrSXLR7-I,1084
|
|
49
|
-
gpjax-0.11.2.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|