gpjax 0.13.2__py3-none-any.whl → 0.13.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. gpjax/__init__.py +3 -3
  2. gpjax/dataset.py +1 -1
  3. gpjax/fit.py +1 -1
  4. gpjax/gps.py +197 -64
  5. gpjax/kernels/__init__.py +1 -1
  6. gpjax/kernels/base.py +2 -2
  7. gpjax/kernels/computations/__init__.py +1 -1
  8. gpjax/kernels/computations/base.py +1 -1
  9. gpjax/kernels/computations/constant_diagonal.py +1 -1
  10. gpjax/kernels/computations/dense.py +1 -1
  11. gpjax/kernels/computations/diagonal.py +1 -1
  12. gpjax/kernels/computations/eigen.py +1 -1
  13. gpjax/kernels/non_euclidean/__init__.py +1 -1
  14. gpjax/kernels/non_euclidean/graph.py +18 -6
  15. gpjax/kernels/non_euclidean/utils.py +1 -1
  16. gpjax/kernels/nonstationary/__init__.py +1 -1
  17. gpjax/kernels/nonstationary/arccosine.py +1 -1
  18. gpjax/kernels/nonstationary/linear.py +1 -1
  19. gpjax/kernels/nonstationary/polynomial.py +1 -1
  20. gpjax/kernels/stationary/__init__.py +1 -1
  21. gpjax/kernels/stationary/base.py +1 -1
  22. gpjax/kernels/stationary/matern12.py +1 -1
  23. gpjax/kernels/stationary/matern32.py +1 -1
  24. gpjax/kernels/stationary/matern52.py +1 -1
  25. gpjax/kernels/stationary/periodic.py +1 -1
  26. gpjax/kernels/stationary/powered_exponential.py +1 -1
  27. gpjax/kernels/stationary/rational_quadratic.py +1 -1
  28. gpjax/kernels/stationary/rbf.py +1 -1
  29. gpjax/kernels/stationary/utils.py +1 -1
  30. gpjax/kernels/stationary/white.py +1 -1
  31. gpjax/scan.py +1 -1
  32. {gpjax-0.13.2.dist-info → gpjax-0.13.3.dist-info}/METADATA +12 -12
  33. gpjax-0.13.3.dist-info/RECORD +52 -0
  34. gpjax-0.13.2.dist-info/RECORD +0 -52
  35. {gpjax-0.13.2.dist-info → gpjax-0.13.3.dist-info}/WHEEL +0 -0
  36. {gpjax-0.13.2.dist-info → gpjax-0.13.3.dist-info}/licenses/LICENSE.txt +0 -0
gpjax/__init__.py CHANGED
@@ -38,9 +38,9 @@ from gpjax.fit import (
 
 __license__ = "MIT"
 __description__ = "Gaussian processes in JAX and Flax"
-__url__ = "https://github.com/JaxGaussianProcesses/GPJax"
-__contributors__ = "https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors"
-__version__ = "0.13.2"
+__url__ = "https://github.com/thomaspinder/GPJax"
+__contributors__ = "https://github.com/thomaspinder/GPJax/graphs/contributors"
+__version__ = "0.13.3"
 
 __all__ = [
     "gps",
gpjax/dataset.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved.
+# Copyright 2022 The thomaspinder Contributors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
gpjax/fit.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2023 The JaxGaussianProcesses Contributors. All Rights Reserved.
+# Copyright 2023 The thomaspinder Contributors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
gpjax/gps.py CHANGED
@@ -4,7 +4,7 @@
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
-# http://www.apache.org/licenses/LICENSE-2.0
+# http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
@@ -13,10 +13,13 @@
 # limitations under the License.
 # ==============================================================================
 # from __future__ import annotations
+
 from abc import abstractmethod
+from typing import Literal
 
 import beartype.typing as tp
 from flax import nnx
+import jax
 import jax.numpy as jnp
 import jax.random as jr
 from jaxtyping import (
@@ -35,10 +38,13 @@ from gpjax.likelihoods import (
 )
 from gpjax.linalg import (
     Dense,
+    Diagonal,
     psd,
     solve,
 )
-from gpjax.linalg.operations import lower_cholesky
+from gpjax.linalg.operations import (
+    lower_cholesky,
+)
 from gpjax.linalg.utils import add_jitter
 from gpjax.mean_functions import AbstractMeanFunction
 from gpjax.parameters import (
@@ -77,7 +83,12 @@ class AbstractPrior(nnx.Module, tp.Generic[M, K]):
         self.mean_function = mean_function
         self.jitter = jitter
 
-    def __call__(self, test_inputs: Num[Array, "N D"]) -> GaussianDistribution:
+    def __call__(
+        self,
+        test_inputs: Num[Array, "N D"],
+        *,
+        return_covariance_type: Literal["dense", "diagonal"] = "dense",
+    ) -> GaussianDistribution:
         r"""Evaluate the Gaussian process at the given points.
 
         The output of this function is a
@@ -91,15 +102,27 @@ class AbstractPrior(nnx.Module, tp.Generic[M, K]):
 
         Args:
             test_inputs: Input locations where the GP should be evaluated.
+            return_covariance_type: Literal denoting whether to return the full covariance
+                of the joint predictive distribution at the test_inputs (dense)
+                or just the standard-deviation of the predictive distribution at
+                the test_inputs.
 
         Returns:
             GaussianDistribution: A multivariate normal random variable representation
                 of the Gaussian process.
         """
-        return self.predict(test_inputs)
+        return self.predict(
+            test_inputs,
+            return_covariance_type=return_covariance_type,
+        )
 
     @abstractmethod
-    def predict(self, test_inputs: Num[Array, "N D"]) -> GaussianDistribution:
+    def predict(
+        self,
+        test_inputs: Num[Array, "N D"],
+        *,
+        return_covariance_type: Literal["dense", "diagonal"] = "dense",
+    ) -> GaussianDistribution:
         r"""Evaluate the predictive distribution.
 
         Compute the latent function's multivariate normal distribution for a
@@ -108,6 +131,10 @@ class AbstractPrior(nnx.Module, tp.Generic[M, K]):
 
         Args:
             test_inputs: Input locations where the GP should be evaluated.
+            return_covariance_type: Literal denoting whether to return the full covariance
+                of the joint predictive distribution at the test_inputs (dense)
+                or just the standard-deviation of the predictive distribution at
+                the test_inputs.
 
         Returns:
             GaussianDistribution: A multivariate normal random variable representation
@@ -123,8 +150,8 @@ class Prior(AbstractPrior[M, K]):
     r"""A Gaussian process prior object.
 
     The GP is parameterised by a
-    [mean](https://docs.jaxgaussianprocesses.com/api/mean_functions/)
-    and [kernel](https://docs.jaxgaussianprocesses.com/api/kernels/base/)
+    [mean](https://docs.thomaspinder.com/api/mean_functions/)
+    and [kernel](https://docs.thomaspinder.com/api/kernels/base/)
     function.
 
     A Gaussian process prior parameterised by a mean function $m(\cdot)$ and a kernel
@@ -220,7 +247,12 @@ class Prior(AbstractPrior[M, K]):
         """
         return self.__mul__(other)
 
-    def predict(self, test_inputs: Num[Array, "N D"]) -> GaussianDistribution:
+    def predict(
+        self,
+        test_inputs: Num[Array, "N D"],
+        *,
+        return_covariance_type: Literal["dense", "diagonal"] = "dense",
+    ) -> GaussianDistribution:
         r"""Compute the predictive prior distribution for a given set of
         parameters. The output of this function is a function that computes
         a TFP distribution for a given set of inputs.
@@ -241,17 +273,43 @@ class Prior(AbstractPrior[M, K]):
         Args:
             test_inputs (Float[Array, "N D"]): The inputs at which to evaluate the
                 prior distribution.
+            return_covariance_type: Literal denoting whether to return the full covariance
+                of the joint predictive distribution at the test_inputs (dense)
+                or just the standard-deviation of the predictive distribution at
+                the test_inputs.
 
         Returns:
             GaussianDistribution: A multivariate normal random variable representation
                 of the Gaussian process.
         """
+
+        def _return_full_covariance(
+            t: Num[Array, "N D"],
+        ) -> Dense:
+            Kxx = self.kernel.gram(t)
+            Kxx_dense = add_jitter(Kxx.to_dense(), self.jitter)
+            Kxx = psd(Dense(Kxx_dense))
+            return Kxx
+
+        def _return_diagonal_covariance(
+            t: Num[Array, "N D"],
+        ) -> Dense:
+            Kxx = self.kernel.diagonal(t).diagonal
+            Kxx += self.jitter
+            Kxx = psd(Dense(Diagonal(Kxx).to_dense()))
+            return Kxx
+
         mean_at_test = self.mean_function(test_inputs)
-        Kxx = self.kernel.gram(test_inputs)
-        Kxx_dense = add_jitter(Kxx.to_dense(), self.jitter)
-        Kxx = psd(Dense(Kxx_dense))
+        cov = jax.lax.cond(
+            return_covariance_type == "dense",
+            _return_full_covariance,
+            _return_diagonal_covariance,
+            test_inputs,
+        )
 
-        return GaussianDistribution(jnp.atleast_1d(mean_at_test.squeeze()), Kxx)
+        return GaussianDistribution(
+            loc=jnp.atleast_1d(mean_at_test.squeeze()), scale=cov
+        )
 
     def sample_approx(
         self,
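The new keyword lets callers request only the marginal variances at the test points instead of materialising the full Gram matrix. A minimal usage sketch (the constructor names are assumed from GPJax's public API and may differ in this fork):

```python
import gpjax as gpx
import jax.numpy as jnp

# Assumed public constructors; hypothetical names for illustration only.
prior = gpx.gps.Prior(
    mean_function=gpx.mean_functions.Zero(),
    kernel=gpx.kernels.RBF(),
)

xtest = jnp.linspace(-3.0, 3.0, 100).reshape(-1, 1)

dense_dist = prior(xtest)  # default: full 100 x 100 covariance
diag_dist = prior(xtest, return_covariance_type="diagonal")  # O(N) variances
```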
@@ -329,7 +387,7 @@ P = tp.TypeVar("P", bound=AbstractPrior)
 
 #######################
 # GP Posteriors
-#######################
+#######################from gpjax.linalg.operators import LinearOperator
 class AbstractPosterior(nnx.Module, tp.Generic[P, L]):
     r"""Abstract Gaussian process posterior.
 
@@ -356,7 +414,11 @@ class AbstractPosterior(nnx.Module, tp.Generic[P, L]):
         self.jitter = jitter
 
     def __call__(
-        self, test_inputs: Num[Array, "N D"], train_data: Dataset
+        self,
+        test_inputs: Num[Array, "N D"],
+        train_data: Dataset,
+        *,
+        return_covariance_type: Literal["dense", "diagonal"] = "dense",
     ) -> GaussianDistribution:
         r"""Evaluate the Gaussian process posterior at the given points.
 
@@ -372,16 +434,28 @@ class AbstractPosterior(nnx.Module, tp.Generic[P, L]):
         Args:
             test_inputs: Input locations where the GP should be evaluated.
             train_data: Training dataset to condition on.
+            return_covariance_type: Literal denoting whether to return the full covariance
+                of the joint predictive distribution at the test_inputs (dense)
+                or just the standard-deviation of the predictive distribution at
+                the test_inputs.
 
         Returns:
             GaussianDistribution: A multivariate normal random variable representation
                 of the Gaussian process.
         """
-        return self.predict(test_inputs, train_data)
+        return self.predict(
+            test_inputs,
+            train_data,
+            return_covariance_type=return_covariance_type,
+        )
 
     @abstractmethod
     def predict(
-        self, test_inputs: Num[Array, "N D"], train_data: Dataset
+        self,
+        test_inputs: Num[Array, "N D"],
+        train_data: Dataset,
+        *,
+        return_covariance_type: Literal["dense", "diagonal"] = "dense",
     ) -> GaussianDistribution:
         r"""Compute the latent function's multivariate normal distribution for a
         given set of parameters. For any class inheriting the `AbstractPosterior` class,
@@ -390,6 +464,10 @@ class AbstractPosterior(nnx.Module, tp.Generic[P, L]):
         Args:
             test_inputs: Input locations where the GP should be evaluated.
             train_data: Training dataset to condition on.
+            return_covariance_type: Literal denoting whether to return the full covariance
+                of the joint predictive distribution at the test_inputs (dense)
+                or just the standard-deviation of the predictive distribution at
+                the test_inputs.
 
         Returns:
             GaussianDistribution: A multivariate normal random variable representation
@@ -442,8 +520,10 @@ class ConjugatePosterior(AbstractPosterior[P, GL]):
 
     def predict(
         self,
-        test_inputs: Num[Array, "N D"],
+        test_inputs: Num[Array, "M D"],
         train_data: Dataset,
+        *,
+        return_covariance_type: Literal["dense", "diagonal"] = "dense",
     ) -> GaussianDistribution:
         r"""Query the predictive posterior distribution.
 
@@ -454,13 +534,13 @@ class ConjugatePosterior(AbstractPosterior[P, GL]):
 
         The predictive distribution of a conjugate GP is given by
         $$
-        p(\mathbf{f}^{\star}\mid \mathbf{y}) & = \int p(\mathbf{f}^{\star} \mathbf{f} \mid \mathbf{y})\\
-        & =\mathcal{N}(\mathbf{f}^{\star} \boldsymbol{\mu}_{\mid \mathbf{y}}, \boldsymbol{\Sigma}_{\mid \mathbf{y}}
+        p(\mathbf{f}^{\star}\mid \mathbf{y}) & = \int p(\mathbf{f}^{\star} \mathbf{f} \mid \mathbf{y})\\
+        & =\mathcal{N}(\mathbf{f}^{\star} \boldsymbol{\mu}_{\mid \mathbf{y}}, \boldsymbol{\Sigma}_{\mid \mathbf{y}}
         $$
         where
         $$
-        \boldsymbol{\mu}_{\mid \mathbf{y}} & = k(\mathbf{x}^{\star}, \mathbf{x})\left(k(\mathbf{x}, \mathbf{x}')+\sigma^2\mathbf{I}_n\right)^{-1}\mathbf{y} \\
-        \boldsymbol{\Sigma}_{\mid \mathbf{y}} & =k(\mathbf{x}^{\star}, \mathbf{x}^{\star\prime}) -k(\mathbf{x}^{\star}, \mathbf{x})\left( k(\mathbf{x}, \mathbf{x}') + \sigma^2\mathbf{I}_n \right)^{-1}k(\mathbf{x}, \mathbf{x}^{\star}).
+        \boldsymbol{\mu}_{\mid \mathbf{y}} & = k(\mathbf{x}^{\star}, \mathbf{x})\left(k(\mathbf{x}, \mathbf{x}')+\sigma^2\mathbf{I}_n\right)^{-1}\mathbf{y} \\
+        \boldsymbol{\Sigma}_{\mid \mathbf{y}} & =k(\mathbf{x}^{\star}, \mathbf{x}^{\star\prime}) -k(\mathbf{x}^{\star}, \mathbf{x})\left( k(\mathbf{x}, \mathbf{x}') + \sigma^2\mathbf{I}_n \right)^{-1}k(\mathbf{x}, \mathbf{x}^{\star}).
         $$
 
         The conditioning set is a GPJax `Dataset` object, whilst predictions
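As quoted, the docstring's display equations drop the integration differential, a comma between the conditioned terms, and a closing parenthesis; written out in full, the standard conjugate-GP predictive equations they intend are:

```latex
p(\mathbf{f}^{\star}\mid \mathbf{y})
  = \int p(\mathbf{f}^{\star}\mid \mathbf{f})\, p(\mathbf{f}\mid \mathbf{y})\, \mathrm{d}\mathbf{f}
  = \mathcal{N}\left(\boldsymbol{\mu}_{\mid \mathbf{y}},\, \boldsymbol{\Sigma}_{\mid \mathbf{y}}\right),
\qquad
\boldsymbol{\mu}_{\mid \mathbf{y}} = k(\mathbf{x}^{\star}, \mathbf{x})\left(k(\mathbf{x}, \mathbf{x}') + \sigma^2\mathbf{I}_n\right)^{-1}\mathbf{y},
\qquad
\boldsymbol{\Sigma}_{\mid \mathbf{y}} = k(\mathbf{x}^{\star}, \mathbf{x}^{\star\prime}) - k(\mathbf{x}^{\star}, \mathbf{x})\left(k(\mathbf{x}, \mathbf{x}') + \sigma^2\mathbf{I}_n\right)^{-1}k(\mathbf{x}, \mathbf{x}^{\star}).
```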
@@ -486,44 +566,65 @@ class ConjugatePosterior(AbstractPosterior[P, GL]):
             predictive distribution is evaluated.
             train_data (Dataset): A `gpx.Dataset` object that contains the input and
                 output data used for training dataset.
+            return_covariance_type: Literal denoting whether to return the full covariance
+                of the joint predictive distribution at the test_inputs (dense)
+                or just the standard-deviation of the predictive distribution at
+                the test_inputs.
 
         Returns:
             GaussianDistribution: A function that accepts an input array and
                 returns the predictive distribution as a `GaussianDistribution`.
         """
-        # Unpack training data
-        x, y = train_data.X, train_data.y
-
-        # Unpack test inputs
-        t = test_inputs
-
+        x = train_data.X
+        y = train_data.y
         # Observation noise o²
-        obs_noise = self.likelihood.obs_stddev.value**2
+        obs_noise = jnp.square(self.likelihood.obs_stddev.value)
         mx = self.prior.mean_function(x)
-
         # Precompute Gram matrix, Kxx, at training inputs, x
         Kxx = self.prior.kernel.gram(x)
-        Kxx_dense = add_jitter(Kxx.to_dense(), self.jitter)
-        Kxx = Dense(Kxx_dense)
+        Kxx = add_jitter(Kxx.to_dense(), self.jitter)
 
-        Sigma_dense = Kxx.to_dense() + jnp.eye(Kxx.shape[0]) * obs_noise
+        Sigma_dense = Kxx + jnp.eye(Kxx.shape[0]) * obs_noise
         Sigma = psd(Dense(Sigma_dense))
         L_sigma = lower_cholesky(Sigma)
 
-        mean_t = self.prior.mean_function(t)
-        Ktt = self.prior.kernel.gram(t)
-        Kxt = self.prior.kernel.cross_covariance(x, t)
+        Kxt = self.prior.kernel.cross_covariance(x, test_inputs)
 
         L_inv_Kxt = solve(L_sigma, Kxt)
         L_inv_y_diff = solve(L_sigma, y - mx)
 
+        mean_t = self.prior.mean_function(test_inputs)
         mean = mean_t + jnp.matmul(L_inv_Kxt.T, L_inv_y_diff)
 
-        covariance = Ktt.to_dense() - jnp.matmul(L_inv_Kxt.T, L_inv_Kxt)
-        covariance = add_jitter(covariance, self.prior.jitter)
-        covariance = psd(Dense(covariance))
-
-        return GaussianDistribution(jnp.atleast_1d(mean.squeeze()), covariance)
+        def _return_full_covariance(
+            L_inv_Kxt: Num[Array, "N M"],
+            t: Num[Array, "M D"],
+        ) -> Dense:
+            Ktt = self.prior.kernel.gram(t)
+            covariance = Ktt.to_dense() - jnp.matmul(L_inv_Kxt.T, L_inv_Kxt)
+            covariance = add_jitter(covariance, self.prior.jitter)
+            covariance = psd(Dense(covariance))
+            return covariance
+
+        def _return_diagonal_covariance(
+            L_inv_Kxt: Num[Array, "N M"],
+            t: Num[Array, "M D"],
+        ) -> Dense:
+            Ktt = self.prior.kernel.diagonal(t).diagonal
+            covariance = Ktt - jnp.einsum("ij, ji->i", L_inv_Kxt.T, L_inv_Kxt)
+            covariance += self.prior.jitter
+            covariance = psd(Dense(jnp.diag(jnp.atleast_1d(covariance.squeeze()))))
+            return covariance
+
+        cov = jax.lax.cond(
+            return_covariance_type == "dense",
+            _return_full_covariance,
+            _return_diagonal_covariance,
+            L_inv_Kxt,
+            test_inputs,
+        )
+
+        return GaussianDistribution(loc=jnp.atleast_1d(mean.squeeze()), scale=cov)
 
     def sample_approx(
         self,
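With `"diagonal"`, the einsum above computes only the diagonal of the Schur complement, avoiding the M x M test-point Gram matrix. A hedged sketch of the call (the `Dataset` and `Gaussian` likelihood constructors are assumed from GPJax's public API, and `x_train`, `y_train`, `xtest`, and `prior` are hypothetical placeholders):

```python
import gpjax as gpx

# x_train: [N, 1], y_train: [N, 1], xtest: [M, 1] -- assumed to exist.
D = gpx.Dataset(X=x_train, y=y_train)
posterior = prior * gpx.likelihoods.Gaussian(num_datapoints=D.n)

# Marginal predictive variances only; the result is still packed into a
# Dense operator so both lax.cond branches return the same pytree.
latent = posterior.predict(xtest, train_data=D, return_covariance_type="diagonal")
```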
@@ -567,7 +668,7 @@ class ConjugatePosterior(AbstractPosterior[P, GL]):
 
         Returns:
             FunctionalSample: A function representing an approximate sample from the Gaussian
-                process prior.
+                process prior.
         """
         if (not isinstance(num_samples, int)) or num_samples <= 0:
             raise ValueError("num_samples must be a positive integer")
@@ -586,7 +687,7 @@ class ConjugatePosterior(AbstractPosterior[P, GL]):
         canonical_weights = solve(
             Sigma,
             y + eps - jnp.inner(Phi, fourier_weights),
-        )  # [N, B]
+        )  # [N, B]
 
         def sample_fn(test_inputs: Float[Array, "n D"]) -> Float[Array, "n B"]:
             fourier_features = fourier_feature_fn(test_inputs)
@@ -648,7 +749,11 @@ class NonConjugatePosterior(AbstractPosterior[P, NGL]):
         self.key = key
 
     def predict(
-        self, test_inputs: Num[Array, "N D"], train_data: Dataset
+        self,
+        test_inputs: Num[Array, "M D"],
+        train_data: Dataset,
+        *,
+        return_covariance_type: Literal["dense", "diagonal"] = "dense",
     ) -> GaussianDistribution:
         r"""Query the predictive posterior distribution.
 
@@ -660,50 +765,78 @@ class NonConjugatePosterior(AbstractPosterior[P, NGL]):
         transformed through the likelihood function's inverse link function.
 
         Args:
-            train_data (Dataset): A `gpx.Dataset` object that contains the input
-                and output data used for training dataset.
+            test_inputs (Num[Array, "N D"]): A Jax array of test inputs at which the
+                predictive distribution is evaluated.
+            train_data (Dataset): A `gpx.Dataset` object that contains the input
+                and output data used for training dataset.
+            return_covariance_type: Literal denoting whether to return the full covariance
+                of the joint predictive distribution at the test_inputs (dense)
+                or just the standard-deviation of the predictive distribution at
+                the test_inputs.
 
         Returns:
             GaussianDistribution: A function that accepts an
                 input array and returns the predictive distribution as
                 a `dx.Distribution`.
         """
-        # Unpack training data
         x = train_data.X
-
-        # Unpack mean function and kernel
+        t = test_inputs
         mean_function = self.prior.mean_function
         kernel = self.prior.kernel
 
-        # Precompute lower triangular of Gram matrix, Lx, at training inputs, x
+        # Precompute lower triangular of Gram matrix
         Kxx = kernel.gram(x)
         Kxx_dense = add_jitter(Kxx.to_dense(), self.prior.jitter)
         Kxx = psd(Dense(Kxx_dense))
         Lx = lower_cholesky(Kxx)
 
-        # Unpack test inputs
-        t = test_inputs
-
-        # Compute terms of the posterior predictive distribution
-        Ktx = kernel.cross_covariance(t, x)
-        Ktt = kernel.gram(t)
-        mean_t = mean_function(t)
-
+        Kxt = kernel.cross_covariance(x, t)
         # Lx⁻¹ Kxt
-        Lx_inv_Kxt = solve(Lx, Ktx.T)
+        Lx_inv_Kxt = solve(Lx, Kxt)
 
+        mean_t = mean_function(t)
         # Whitened function values, wx, corresponding to the inputs, x
         wx = self.latent.value
 
         # μt + Ktx Lx⁻¹ wx
         mean = mean_t + jnp.matmul(Lx_inv_Kxt.T, wx)
 
-        # Ktt - Ktx Kxx⁻¹ Kxt, TODO: Take advantage of covariance structure to compute Schur complement more efficiently.
-        covariance = Ktt.to_dense() - jnp.matmul(Lx_inv_Kxt.T, Lx_inv_Kxt)
-        covariance = add_jitter(covariance, self.prior.jitter)
-        covariance = psd(Dense(covariance))
-
-        return GaussianDistribution(jnp.atleast_1d(mean.squeeze()), covariance)
+        def _return_full_covariance(
+            Lx_inv_Kxt: Num[Array, "N M"],
+            t: Num[Array, "M D"],
+        ) -> Dense:
+            Ktt = kernel.gram(t)
+            covariance = Ktt.to_dense() - jnp.matmul(Lx_inv_Kxt.T, Lx_inv_Kxt)
+            covariance = add_jitter(covariance, self.prior.jitter)
+            covariance = psd(Dense(covariance))
+
+            return covariance
+
+        def _return_diagonal_covariance(
+            Lx_inv_Kxt: Num[Array, "N M"],
+            t: Num[Array, "M D"],
+        ) -> Dense:
+            Ktt = kernel.diagonal(t).diagonal
+            covariance = Ktt - jnp.einsum("ij, ji->i", Lx_inv_Kxt.T, Lx_inv_Kxt)
+            covariance += self.prior.jitter
+            # It would be nice to return a Diagonal here, but the pytree needs
+            # to be the same for both cond branches and the other branch needs
+            # to return a Dense.
+            # They are both LinearOperators, but they inherit from that class
+            # and hence are not the same pytree anymore.
+            covariance = psd(Dense(jnp.diag(jnp.atleast_1d(covariance.squeeze()))))
+
+            return covariance
+
+        cov = jax.lax.cond(
+            return_covariance_type == "dense",
+            _return_full_covariance,
+            _return_diagonal_covariance,
+            Lx_inv_Kxt,
+            test_inputs,
+        )
+
+        return GaussianDistribution(jnp.atleast_1d(mean.squeeze()), cov)
 
 
 #######################
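The in-code comment above explains the `Dense(jnp.diag(...))` round-trip: `jax.lax.cond` traces both branches and requires them to return pytrees with identical structure, shapes, and dtypes, so the diagonal branch cannot return a `Diagonal` operator while the other returns a `Dense`. A self-contained illustration of that constraint:

```python
import jax
import jax.numpy as jnp

def full_branch(x):
    # Produces a dense (n, n) matrix.
    return jnp.outer(x, x) + jnp.eye(x.shape[0])

def diag_branch(x):
    # Must match full_branch structurally, so the O(n) diagonal is
    # embedded back into a dense (n, n) matrix before returning.
    return jnp.diag(x * x + 1.0)

x = jnp.arange(3.0)
cov = jax.lax.cond(False, full_branch, diag_branch, x)  # shape (3, 3)
```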
gpjax/kernels/__init__.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved.
+# Copyright 2022 The thomaspinder Contributors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
gpjax/kernels/base.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved.
+# Copyright 2022 The thomaspinder Contributors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -123,7 +123,7 @@ class AbstractKernel(nnx.Module):
         """
         return self.compute_engine.gram(self, x)
 
-    def diagonal(self, x: Num[Array, "N D"]) -> Float[Array, " N"]:
+    def diagonal(self, x: Num[Array, "N D"]) -> LinearOperator:
         r"""Compute the diagonal of the gram matrix of the kernel.
 
         Args:
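Call sites now read the variance vector through the operator's `.diagonal` attribute, as `gps.py` does above with `self.kernel.diagonal(t).diagonal`. A short sketch of the new pattern (the `RBF` constructor is assumed from the public API):

```python
import gpjax as gpx
import jax.numpy as jnp

kernel = gpx.kernels.RBF()  # assumed public constructor
xtest = jnp.linspace(0.0, 1.0, 10).reshape(-1, 1)

# Previously a Float[Array, " N"]; now a LinearOperator whose raw
# values live in its `.diagonal` attribute.
variances = kernel.diagonal(xtest).diagonal
```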
gpjax/kernels/computations/__init__.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved.
+# Copyright 2022 The thomaspinder Contributors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
gpjax/kernels/computations/base.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved.
+# Copyright 2022 The thomaspinder Contributors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
gpjax/kernels/computations/constant_diagonal.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved.
+# Copyright 2022 The thomaspinder Contributors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
gpjax/kernels/computations/dense.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved.
+# Copyright 2022 The thomaspinder Contributors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
gpjax/kernels/computations/diagonal.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved.
+# Copyright 2022 The thomaspinder Contributors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
gpjax/kernels/computations/eigen.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved.
+# Copyright 2022 The thomaspinder Contributors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
gpjax/kernels/non_euclidean/__init__.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved.
+# Copyright 2022 The thomaspinder Contributors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
gpjax/kernels/non_euclidean/graph.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved.
+# Copyright 2022 The thomaspinder Contributors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -17,7 +17,7 @@ import beartype.typing as tp
 import jax.numpy as jnp
 from jaxtyping import (
     Float,
-    Int,
+    Integer,
     Num,
 )
 
@@ -103,11 +103,23 @@ class GraphKernel(StationaryKernel):
 
     def __call__(
         self,
-        x: Int[Array, "N 1"],
-        y: Int[Array, "M 1"],
+        x: ScalarInt | Integer[Array, " N"] | Integer[Array, "N 1"],
+        y: ScalarInt | Integer[Array, " M"] | Integer[Array, "M 1"],
     ):
+        x_idx = self._prepare_indices(x)
+        y_idx = self._prepare_indices(y)
         S = calculate_heat_semigroup(self)
-        Kxx = (jax_gather_nd(self.eigenvectors, x) * S.squeeze()) @ jnp.transpose(
-            jax_gather_nd(self.eigenvectors, y)
+        Kxx = (jax_gather_nd(self.eigenvectors, x_idx) * S.squeeze()) @ jnp.transpose(
+            jax_gather_nd(self.eigenvectors, y_idx)
         )  # shape (n,n)
         return Kxx.squeeze()
+
+    def _prepare_indices(
+        self,
+        indices: ScalarInt | Integer[Array, " N"] | Integer[Array, "N 1"],
+    ) -> Integer[Array, "N 1"]:
+        """Ensure index arrays are integer column vectors regardless of caller shape."""
+
+        idx = jnp.asarray(indices, dtype=jnp.int32)
+        idx = jnp.atleast_1d(idx)
+        return idx.reshape(-1, 1)
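`_prepare_indices` makes the kernel tolerant of scalar, flat, and column-vector node indices. A standalone sketch of the same normalisation, mirroring the helper's body:

```python
import jax.numpy as jnp

def prepare_indices(indices):
    # Accept a scalar, an [N] vector, or an [N, 1] column and always
    # emit an int32 column vector of shape [N, 1].
    idx = jnp.asarray(indices, dtype=jnp.int32)
    idx = jnp.atleast_1d(idx)
    return idx.reshape(-1, 1)

assert prepare_indices(3).shape == (1, 1)
assert prepare_indices(jnp.array([0, 1, 2])).shape == (3, 1)
assert prepare_indices(jnp.array([[0], [1]])).shape == (2, 1)
```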
gpjax/kernels/non_euclidean/utils.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved.
+# Copyright 2022 The thomaspinder Contributors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
gpjax/kernels/nonstationary/__init__.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved.
+# Copyright 2022 The thomaspinder Contributors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
gpjax/kernels/nonstationary/arccosine.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved.
+# Copyright 2022 The thomaspinder Contributors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
gpjax/kernels/nonstationary/linear.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved.
+# Copyright 2022 The thomaspinder Contributors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
gpjax/kernels/nonstationary/polynomial.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved.
+# Copyright 2022 The thomaspinder Contributors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
gpjax/kernels/stationary/__init__.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved.
+# Copyright 2022 The thomaspinder Contributors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
gpjax/kernels/stationary/base.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved.
+# Copyright 2022 The thomaspinder Contributors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
gpjax/kernels/stationary/matern12.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved.
+# Copyright 2022 The thomaspinder Contributors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
gpjax/kernels/stationary/matern32.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved.
+# Copyright 2022 The thomaspinder Contributors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
gpjax/kernels/stationary/matern52.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved.
+# Copyright 2022 The thomaspinder Contributors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
gpjax/kernels/stationary/periodic.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved.
+# Copyright 2022 The thomaspinder Contributors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
gpjax/kernels/stationary/powered_exponential.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved.
+# Copyright 2022 The thomaspinder Contributors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
gpjax/kernels/stationary/rational_quadratic.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved.
+# Copyright 2022 The thomaspinder Contributors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
gpjax/kernels/stationary/rbf.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved.
+# Copyright 2022 The thomaspinder Contributors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
gpjax/kernels/stationary/utils.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved.
+# Copyright 2022 The thomaspinder Contributors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
gpjax/kernels/stationary/white.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved.
+# Copyright 2022 The thomaspinder Contributors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
gpjax/scan.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2023 The JaxGaussianProcesses Contributors. All Rights Reserved.
+# Copyright 2023 The thomaspinder Contributors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
{gpjax-0.13.2.dist-info → gpjax-0.13.3.dist-info}/METADATA CHANGED
@@ -1,10 +1,10 @@
 Metadata-Version: 2.4
 Name: gpjax
-Version: 0.13.2
+Version: 0.13.3
 Summary: Gaussian processes in JAX.
 Project-URL: Documentation, https://docs.jaxgaussianprocesses.com/
-Project-URL: Issues, https://github.com/JaxGaussianProcesses/GPJax/issues
-Project-URL: Source, https://github.com/JaxGaussianProcesses/GPJax
+Project-URL: Issues, https://github.com/thomaspinder/GPJax/issues
+Project-URL: Source, https://github.com/thomaspinder/GPJax
 Author-email: Thomas Pinder <tompinder@live.co.uk>
 License: MIT
 License-File: LICENSE.txt
@@ -72,11 +72,11 @@ Description-Content-Type: text/markdown
 <!-- <h1 align='center'>GPJax</h1>
 <h2 align='center'>Gaussian processes in Jax.</h2> -->
 <p align="center">
-<img width="700" height="300" src="https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/static/gpjax_logo.svg" alt="GPJax's logo">
+<img width="700" height="300" src="https://raw.githubusercontent.com/thomaspinder/GPJax/main/docs/static/gpjax_logo.svg" alt="GPJax's logo">
 </p>
 
-[![codecov](https://codecov.io/gh/JaxGaussianProcesses/GPJax/branch/master/graph/badge.svg?token=DM1DRDASU2)](https://codecov.io/gh/JaxGaussianProcesses/GPJax)
-[![CodeFactor](https://www.codefactor.io/repository/github/jaxgaussianprocesses/gpjax/badge)](https://www.codefactor.io/repository/github/jaxgaussianprocesses/gpjax)
+[![codecov](https://codecov.io/gh/thomaspinder/GPJax/branch/master/graph/badge.svg?token=DM1DRDASU2)](https://codecov.io/gh/thomaspinder/GPJax)
+[![CodeFactor](https://www.codefactor.io/repository/github/thomaspinder/GPJax/badge)](https://www.codefactor.io/repository/github/thomaspinder/GPJax)
 [![Netlify Status](https://api.netlify.com/api/v1/badges/d3950e6f-321f-4508-9e52-426b5dae2715/deploy-status)](https://app.netlify.com/sites/endearing-crepe-c2d5fe/deploys)
 [![PyPI version](https://badge.fury.io/py/GPJax.svg)](https://badge.fury.io/py/GPJax)
 [![Conda Version](https://img.shields.io/conda/vn/conda-forge/gpjax.svg)](https://anaconda.org/conda-forge/gpjax)
@@ -101,19 +101,19 @@ with GP models.
 
 We would be delighted to receive contributions from interested individuals and
 groups. To learn how you can get involved, please read our [guide for
-contributing](https://github.com/JaxGaussianProcesses/GPJax/blob/main/docs/contributing.md).
+contributing](https://github.com/thomaspinder/GPJax/blob/main/docs/contributing.md).
 If you have any questions, we encourage you to [open an
-issue](https://github.com/JaxGaussianProcesses/GPJax/issues/new/choose). For
+issue](https://github.com/thomaspinder/GPJax/issues/new/choose). For
 broader conversations, such as best GP fitting practices or questions about the
 mathematics of GPs, we invite you to [open a
-discussion](https://github.com/JaxGaussianProcesses/GPJax/discussions).
+discussion](https://github.com/thomaspinder/GPJax/discussions).
 
 Another way you can contribute to GPJax is through [issue
 triaging](https://www.codetriage.com/what). This can include reproducing bug reports,
 asking for vital information such as version numbers and reproduction instructions, or
 identifying stale issues. If you would like to begin triaging issues, an easy way to get
 started is to
-[subscribe to GPJax on CodeTriage](https://www.codetriage.com/jaxgaussianprocesses/gpjax).
+[subscribe to GPJax on CodeTriage](https://www.codetriage.com/thomaspinder/GPJax).
 
 As a contributor to GPJax, you are expected to abide by our [code of
 conduct](docs/CODE_OF_CONDUCT.md). If you feel that you have either experienced or
@@ -127,7 +127,7 @@ where we can discuss the development of GPJax and broader support for Gaussian
 process modelling.
 
 We appreciate all [the contributors to
-GPJax](https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors) who have helped to shape
+GPJax](https://github.com/thomaspinder/GPJax/graphs/contributors) who have helped to shape
 GPJax into the package it is today.
 
 # Supported methods and interfaces
@@ -215,7 +215,7 @@ conda install --channel conda-forge gpjax
 Clone a copy of the repository to your local machine and run the setup
 configuration in development mode.
 ```bash
-git clone https://github.com/JaxGaussianProcesses/GPJax.git
+git clone https://github.com/thomaspinder/GPJax.git
 cd GPJax
 uv venv
 uv sync --extra dev
gpjax-0.13.3.dist-info/RECORD ADDED
@@ -0,0 +1,52 @@
+gpjax/__init__.py,sha256=8EULPS_vtq4TeN6MjdLVHLlTVVyQADOoHtuVoRN8z5Y,1625
+gpjax/citation.py,sha256=pwFS8h1J-LE5ieRS0zDyuwhmQHNxkFHYE7iSMlVNmQc,3928
+gpjax/dataset.py,sha256=Ef5JGrl4jJS1mQmL3JdO0fdqbVmflT_Cu5VrlpYdJY4,4071
+gpjax/distributions.py,sha256=iKmeQ_NN2CIjRiuOeJlwEGASzGROi4ZCerVi1uY7zRM,7758
+gpjax/fit.py,sha256=tOXmM3l5-N3Jlnq8MOVkSpRj-0fRWOy2t9GxQRyUqxY,15318
+gpjax/gps.py,sha256=NcPXkkx0kXrSBTUje4QR6PS0DGrD7p5c-DyiOATwUz8,35338
+gpjax/integrators.py,sha256=eyJPqWNPKj6pKP5da0fEj4HW7BVyevqeGrurEuy_XPw,5694
+gpjax/likelihoods.py,sha256=xwnSQpn6Aa-FPpEoDn_3xpBdPQAmHP97jP-9iJmT4G8,9087
+gpjax/mean_functions.py,sha256=Aq09I5h5DZe1WokRxLa0Mpj8U0kJxpQ8CmdJJyCjNOc,6541
+gpjax/numpyro_extras.py,sha256=-vWJ7SpZVNhSdCjjrlxIkovMFrM1IzpsMJK3B4LioGE,3411
+gpjax/objectives.py,sha256=GvKbDIPqYjsc9FpiTccmZwRdHr2lCykgfxI9BX9I_GA,15362
+gpjax/parameters.py,sha256=hnyIKr6uIzd7Kb3KZC9WowR88ruQwUvdcto3cx2ZDv4,6756
+gpjax/scan.py,sha256=Z_V1yd76dL4B7_rnJnr_fohom6xzN_WRYxlTAAdqfa0,5357
+gpjax/typing.py,sha256=M3CvWsYtZ3PFUvBvvbRNjpwerNII0w4yGuP0I-sLeYI,1705
+gpjax/variational_families.py,sha256=x4VnUh9GiW77ijnDEwrYzH0aWFOW6NLPQXEp--I9R-g,31566
+gpjax/kernels/__init__.py,sha256=GFiku1s8KPPzvfwDYzEjENPQReywvJ29M9Kl2JjkcoU,1893
+gpjax/kernels/base.py,sha256=gov60BgXOQK0PxbKQRVgjU55kCVZYAkDZrAInODvugo,11549
+gpjax/kernels/approximations/__init__.py,sha256=bK9HlGd-PZeGrqtG5RpXxUTXNUrZTgfjH1dP626yNMA,68
+gpjax/kernels/approximations/rff.py,sha256=WJ2Bw4RRW5bLXGKfStxph65H-ozDelvt_4mxXQm-g7w,4074
+gpjax/kernels/computations/__init__.py,sha256=gSkYASLcyPwcIg12Jjv7vxW9rixMVUhnlcl07iNvN_8,1371
+gpjax/kernels/computations/base.py,sha256=Fh_ppC7JwrbqKNJj1JqDOaUUGm23Zdm7RRLnGbQtePI,3608
+gpjax/kernels/computations/basis_functions.py,sha256=Y6OS4UEvo-Mdn7yRMWPde5rOvewDHIleaUIVBOpCTd4,2466
+gpjax/kernels/computations/constant_diagonal.py,sha256=ihxR0rcwcLi_sv3a026nEXRObMvmScBLOKj09XGwuns,1976
+gpjax/kernels/computations/dense.py,sha256=TGbAXSoaqo3Erzsaj7NztrEN2a_sHW5mi3bPbA1W9nc,1384
+gpjax/kernels/computations/diagonal.py,sha256=4V8aVXqC61GdgmFTvxDhmtMw1uSPi_C-m-d9-Fo6B8U,1760
+gpjax/kernels/computations/eigen.py,sha256=koVBKeLrHZX77PLlrykM-h2FN5g7l7S5xoyGUKD69Oc,1372
+gpjax/kernels/non_euclidean/__init__.py,sha256=Jb3K9a5-H-kNOirkrJJ46e3sgJIvoQS1-LfSsNXspJg,782
+gpjax/kernels/non_euclidean/graph.py,sha256=SevOCNeVZ5eK0MPhYGLnohAvNqig0C5c-oOpKuUKxeA,4631
+gpjax/kernels/non_euclidean/utils.py,sha256=PANh8fPOXnV8x6liTCJ08YYnDIQ33PrpLD-9hzjSg9w,2334
+gpjax/kernels/nonstationary/__init__.py,sha256=X6WVSMrejtLH6k_ct-1uZ9dmdmd_IEBIneweceXgXI0,922
+gpjax/kernels/nonstationary/arccosine.py,sha256=hjh6umWkSYs2xWa0rJQDPnmmkvk-vAwbGhMrkCev70o,5400
+gpjax/kernels/nonstationary/linear.py,sha256=xjeG4DN5K3-fgRw-HTFvPfcxYfYO5vPxdZc6K5qvnW8,2563
+gpjax/kernels/nonstationary/polynomial.py,sha256=D_czngVP0VGqFlqi_uJBayKegYs1Q-sHTCWCRDUWVAo,3380
+gpjax/kernels/stationary/__init__.py,sha256=b8KuCIENnW1nVQv3KdlZmOicfgIdoZzqgvVKlkgzkUk,1398
+gpjax/kernels/stationary/base.py,sha256=zIg21SK5nTg9kW0RJy6xXIp7XF7TjZxJT9Vvpsscz4M,6030
+gpjax/kernels/stationary/matern12.py,sha256=FbhdY5xlkgoCCOizVd7ZKK2XjJOCMDmfn2bTdxTOIKw,1763
+gpjax/kernels/stationary/matern32.py,sha256=tHHMN3pNMf5sGZNyCowi1RUUB0SENdkjHgQE8ZS3kMU,1956
+gpjax/kernels/stationary/matern52.py,sha256=TjS2_R9tYga8ePpSd5UIBTKDCsGWJX8AlUVvElYAKYk,2009
+gpjax/kernels/stationary/periodic.py,sha256=sCGWFYSTM6P91AFw9qTThAtzP5spatNrcU66Lr9qQjU,3520
+gpjax/kernels/stationary/powered_exponential.py,sha256=Z0q9B6mZysL8pKwFPf_kYsJ8hUOf2E8rSlxhbl0U3f8,3671
+gpjax/kernels/stationary/rational_quadratic.py,sha256=9X6X6YEIW8021XPzkCOMUihdWS0_JnBVO1b2JToFa2w,3410
+gpjax/kernels/stationary/rbf.py,sha256=uuhDlV-kZZ3gL0ZRpKWAOh3pVTpkddEVIILsWdEsMmM,1682
+gpjax/kernels/stationary/utils.py,sha256=wcS4rRJmNIeL3Log7gcK7klWWprulF4krQJq803alhk,2172
+gpjax/kernels/stationary/white.py,sha256=rPVf2xzfv2WpPpb8E7LXGuguiyZ6QKa7M3ltzLItmpk,2204
+gpjax/linalg/__init__.py,sha256=F8mxk_9Zc2nFd7Q-unjJ50_6rXEKzZj572WsU_jUKqI,547
+gpjax/linalg/operations.py,sha256=xvhOy5P4FmUCPWjIVNdg1yDXaoFQ48anFUfR-Tnfr6k,6480
+gpjax/linalg/operators.py,sha256=arxRGwcoAy_RqUYqBpZ3XG6OXbjShUl7m8sTpg85npE,11608
+gpjax/linalg/utils.py,sha256=fKV8G_iKZVhNkNvN20D_dQEi93-8xosGbXBP-v7UEyo,2020
+gpjax-0.13.3.dist-info/METADATA,sha256=1Au0xXIXWvS64SPLP8R_mPQ6nuUMayx2qdNLD-HdhW8,10296
+gpjax-0.13.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+gpjax-0.13.3.dist-info/licenses/LICENSE.txt,sha256=3umwi0h8wmKXOZO8XwRBwSl3vJt2hpWKEqSrSXLR7-I,1084
+gpjax-0.13.3.dist-info/RECORD,,
gpjax-0.13.2.dist-info/RECORD DELETED
@@ -1,52 +0,0 @@
-gpjax/__init__.py,sha256=jpuNxARmW3gOmXFAhJvrXeMk_rvYDn49E-8wy3ATZKg,1641
-gpjax/citation.py,sha256=pwFS8h1J-LE5ieRS0zDyuwhmQHNxkFHYE7iSMlVNmQc,3928
-gpjax/dataset.py,sha256=NsToLKq4lOsHnfLfukrUIRKvhOEuoUk8aHTF0oAqRbU,4079
-gpjax/distributions.py,sha256=iKmeQ_NN2CIjRiuOeJlwEGASzGROi4ZCerVi1uY7zRM,7758
-gpjax/fit.py,sha256=I2sJVuKZii_d7MEcelHIivfM8ExYGMgdBuKKOT7Dw-A,15326
-gpjax/gps.py,sha256=fSbHjfQ1vKUZ3CnqdNgJVgFsuTkCwK3yK_c0SQsWYo0,30118
-gpjax/integrators.py,sha256=eyJPqWNPKj6pKP5da0fEj4HW7BVyevqeGrurEuy_XPw,5694
-gpjax/likelihoods.py,sha256=xwnSQpn6Aa-FPpEoDn_3xpBdPQAmHP97jP-9iJmT4G8,9087
-gpjax/mean_functions.py,sha256=Aq09I5h5DZe1WokRxLa0Mpj8U0kJxpQ8CmdJJyCjNOc,6541
-gpjax/numpyro_extras.py,sha256=-vWJ7SpZVNhSdCjjrlxIkovMFrM1IzpsMJK3B4LioGE,3411
-gpjax/objectives.py,sha256=GvKbDIPqYjsc9FpiTccmZwRdHr2lCykgfxI9BX9I_GA,15362
-gpjax/parameters.py,sha256=hnyIKr6uIzd7Kb3KZC9WowR88ruQwUvdcto3cx2ZDv4,6756
-gpjax/scan.py,sha256=jStQvwkE9MGttB89frxam1kaeXdWih7cVxkGywyaeHQ,5365
-gpjax/typing.py,sha256=M3CvWsYtZ3PFUvBvvbRNjpwerNII0w4yGuP0I-sLeYI,1705
-gpjax/variational_families.py,sha256=x4VnUh9GiW77ijnDEwrYzH0aWFOW6NLPQXEp--I9R-g,31566
-gpjax/kernels/__init__.py,sha256=WZanH0Tpdkt0f7VfMqnalm_VZAMVwBqeOVaICNj6xQU,1901
-gpjax/kernels/base.py,sha256=4vUV6_wnbV64EUaOecVWE9Vz4G5Dg6xUMNsQSis8A1o,11561
-gpjax/kernels/approximations/__init__.py,sha256=bK9HlGd-PZeGrqtG5RpXxUTXNUrZTgfjH1dP626yNMA,68
-gpjax/kernels/approximations/rff.py,sha256=WJ2Bw4RRW5bLXGKfStxph65H-ozDelvt_4mxXQm-g7w,4074
-gpjax/kernels/computations/__init__.py,sha256=uTVkqvnZVesFLDN92h0ZR0jfR69Eo2WyjOlmSYmCPJ8,1379
-gpjax/kernels/computations/base.py,sha256=L6K0roxZbrYeJKxEw-yaTiK9Mtcv0YtZfWI2Xnau7i8,3616
-gpjax/kernels/computations/basis_functions.py,sha256=Y6OS4UEvo-Mdn7yRMWPde5rOvewDHIleaUIVBOpCTd4,2466
-gpjax/kernels/computations/constant_diagonal.py,sha256=JkQhLj7cK48IhOER4ivkALNhD1oQleKe-Rr9BtUJ6es,1984
-gpjax/kernels/computations/dense.py,sha256=vnW6XKQe4_gzpXRWTctxhgMA9-9TebdtiXzAqh_-j6g,1392
-gpjax/kernels/computations/diagonal.py,sha256=k1KqW0DwWRIBvbb7jzcKktXRfhXbcos3ncWrFplJ4W0,1768
-gpjax/kernels/computations/eigen.py,sha256=LuwYVPK0AuDGNSccPmAS8wyDJ2ngkRbc6kNZzqiJEOg,1380
-gpjax/kernels/non_euclidean/__init__.py,sha256=RT7puRPqCTpyxZ16q596EuOQEQi1LK1v3J9_fWz1NlY,790
-gpjax/kernels/non_euclidean/graph.py,sha256=MukL4ZghUywkau4tVWmHsrAjzTcl9kHN6onEAvxSvjc,4109
-gpjax/kernels/non_euclidean/utils.py,sha256=vxjRX209tOL3jKsw69hQfVtBqkASDZ__4tJOGI4oLjo,2342
-gpjax/kernels/nonstationary/__init__.py,sha256=YpWQfOy_cqOKc5ezn37vqoK3Z6jznYiJz28BD_8F7AY,930
-gpjax/kernels/nonstationary/arccosine.py,sha256=cqb8sqaNwW3fEbrA7MY9OF2KJFTkxHhqwmQtABE3G8w,5408
-gpjax/kernels/nonstationary/linear.py,sha256=UIMoCq2hg6dQKr4J5UGiiPqotBleQuYfy00Ia1NaMOo,2571
-gpjax/kernels/nonstationary/polynomial.py,sha256=CKc02C7Utgo-hhcOOCcKLdln5lj4vud_8M-JY7SevJ8,3388
-gpjax/kernels/stationary/__init__.py,sha256=j4BMTaQlIx2kNAT1Dkf4iO2rm-f7_oSVWNrk1bN0tqE,1406
-gpjax/kernels/stationary/base.py,sha256=YSgur73wqAFZBzt8D6CfujvyjowXoPOK1po10gjzMJo,6038
-gpjax/kernels/stationary/matern12.py,sha256=tfrIP-pJML-kSqVwf8wdVUwC53mhFohMnf7z5OMX3xE,1771
-gpjax/kernels/stationary/matern32.py,sha256=vdrszTtH03PB9D4AKg-YC_56xCupK3tO9L8vH0GYFcc,1964
-gpjax/kernels/stationary/matern52.py,sha256=IVsIKDo9bTZnA98yBmA_Q31peFdSCTXjf0_cVlLQh6k,2017
-gpjax/kernels/stationary/periodic.py,sha256=f4PhWhKg-pJsEBGzEMK9pdbylO84GPKhzHlBC83ZVWw,3528
-gpjax/kernels/stationary/powered_exponential.py,sha256=xuFGuIK0mKNMU3iLtZMXZTHXJuMFAMoX7gAtXefCdqU,3679
-gpjax/kernels/stationary/rational_quadratic.py,sha256=zHo2LVW65T52XET4Hx9JaKO0TfxylV8WRUtP7sUUOx0,3418
-gpjax/kernels/stationary/rbf.py,sha256=euHUs6FdfRICQcabAWE4MX-7GEDr2TxgZWdFQiXr9Bw,1690
-gpjax/kernels/stationary/utils.py,sha256=6BI9EBcCzeeKx-XH-MfW1ORmtU__tPX5zyvfLhpkBsU,2180
-gpjax/kernels/stationary/white.py,sha256=TkdXXZCCjDs7JwR_gj5uvn2s1wyfRbe1vyHhUMJ8jjI,2212
-gpjax/linalg/__init__.py,sha256=F8mxk_9Zc2nFd7Q-unjJ50_6rXEKzZj572WsU_jUKqI,547
-gpjax/linalg/operations.py,sha256=xvhOy5P4FmUCPWjIVNdg1yDXaoFQ48anFUfR-Tnfr6k,6480
-gpjax/linalg/operators.py,sha256=arxRGwcoAy_RqUYqBpZ3XG6OXbjShUl7m8sTpg85npE,11608
-gpjax/linalg/utils.py,sha256=fKV8G_iKZVhNkNvN20D_dQEi93-8xosGbXBP-v7UEyo,2020
-gpjax-0.13.2.dist-info/METADATA,sha256=_dkTRApmtHexaDFVnH8FhpOf_YkIyci2Gihg9elVeII,10400
-gpjax-0.13.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-gpjax-0.13.2.dist-info/licenses/LICENSE.txt,sha256=3umwi0h8wmKXOZO8XwRBwSl3vJt2hpWKEqSrSXLR7-I,1084
-gpjax-0.13.2.dist-info/RECORD,,
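Each RECORD row pairs a path with an urlsafe-base64, unpadded SHA-256 digest and a byte size, the standard wheel RECORD format. A small sketch for checking one entry against an unpacked wheel:

```python
import base64
import hashlib
from pathlib import Path

def record_digest(path: Path) -> str:
    # RECORD stores sha256 digests urlsafe-base64 encoded, padding stripped.
    digest = hashlib.sha256(path.read_bytes()).digest()
    return base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")

# Hypothetical check against the 0.13.3 entry for gpjax/gps.py above:
# assert record_digest(Path("gpjax/gps.py")) == "NcPXkkx0kXrSBTUje4QR6PS0DGrD7p5c-DyiOATwUz8"
```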
{gpjax-0.13.2.dist-info → gpjax-0.13.3.dist-info}/WHEEL: file without changes
{gpjax-0.13.2.dist-info → gpjax-0.13.3.dist-info}/licenses/LICENSE.txt: file without changes