gpjax 0.13.3__py3-none-any.whl → 0.13.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gpjax/__init__.py +1 -1
- gpjax/citation.py +13 -0
- gpjax/gps.py +77 -0
- gpjax/likelihoods.py +234 -0
- gpjax/mean_functions.py +2 -2
- gpjax/objectives.py +56 -1
- gpjax/parameters.py +10 -2
- gpjax/variational_families.py +129 -0
- {gpjax-0.13.3.dist-info → gpjax-0.13.5.dist-info}/METADATA +4 -3
- {gpjax-0.13.3.dist-info → gpjax-0.13.5.dist-info}/RECORD +12 -12
- {gpjax-0.13.3.dist-info → gpjax-0.13.5.dist-info}/WHEEL +1 -1
- {gpjax-0.13.3.dist-info → gpjax-0.13.5.dist-info}/licenses/LICENSE.txt +0 -0
gpjax/__init__.py
CHANGED
|
@@ -40,7 +40,7 @@ __license__ = "MIT"
|
|
|
40
40
|
__description__ = "Gaussian processes in JAX and Flax"
|
|
41
41
|
__url__ = "https://github.com/thomaspinder/GPJax"
|
|
42
42
|
__contributors__ = "https://github.com/thomaspinder/GPJax/graphs/contributors"
|
|
43
|
-
__version__ = "0.13.3"
|
|
43
|
+
__version__ = "0.13.5"
|
|
44
44
|
|
|
45
45
|
__all__ = [
|
|
46
46
|
"gps",
|
gpjax/citation.py
CHANGED
|
@@ -23,6 +23,7 @@ from gpjax.kernels import (
|
|
|
23
23
|
Matern32,
|
|
24
24
|
Matern52,
|
|
25
25
|
)
|
|
26
|
+
from gpjax.likelihoods import HeteroscedasticGaussian
|
|
26
27
|
|
|
27
28
|
CitationType = Union[None, str, Dict[str, str]]
|
|
28
29
|
|
|
@@ -149,3 +150,15 @@ def _(tree) -> PaperCitation:
|
|
|
149
150
|
booktitle="Advances in neural information processing systems",
|
|
150
151
|
citation_type="article",
|
|
151
152
|
)
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
@cite.register(HeteroscedasticGaussian)
|
|
156
|
+
def _(tree) -> PaperCitation:
|
|
157
|
+
return PaperCitation(
|
|
158
|
+
citation_key="lazaro2011variational",
|
|
159
|
+
authors="Lázaro-Gredilla, Miguel and Titsias, Michalis",
|
|
160
|
+
title="Variational heteroscedastic Gaussian process regression",
|
|
161
|
+
year="2011",
|
|
162
|
+
booktitle="Proceedings of the 28th International Conference on Machine Learning (ICML)",
|
|
163
|
+
citation_type="inproceedings",
|
|
164
|
+
)
|
gpjax/gps.py
CHANGED
|
@@ -32,8 +32,10 @@ from gpjax.distributions import GaussianDistribution
|
|
|
32
32
|
from gpjax.kernels import RFF
|
|
33
33
|
from gpjax.kernels.base import AbstractKernel
|
|
34
34
|
from gpjax.likelihoods import (
|
|
35
|
+
AbstractHeteroscedasticLikelihood,
|
|
35
36
|
AbstractLikelihood,
|
|
36
37
|
Gaussian,
|
|
38
|
+
HeteroscedasticGaussian,
|
|
37
39
|
NonGaussian,
|
|
38
40
|
)
|
|
39
41
|
from gpjax.linalg import (
|
|
@@ -62,6 +64,7 @@ M = tp.TypeVar("M", bound=AbstractMeanFunction)
|
|
|
62
64
|
L = tp.TypeVar("L", bound=AbstractLikelihood)
|
|
63
65
|
NGL = tp.TypeVar("NGL", bound=NonGaussian)
|
|
64
66
|
GL = tp.TypeVar("GL", bound=Gaussian)
|
|
67
|
+
HL = tp.TypeVar("HL", bound=AbstractHeteroscedasticLikelihood)
|
|
65
68
|
|
|
66
69
|
|
|
67
70
|
class AbstractPrior(nnx.Module, tp.Generic[M, K]):
|
|
@@ -476,6 +479,22 @@ class AbstractPosterior(nnx.Module, tp.Generic[P, L]):
|
|
|
476
479
|
raise NotImplementedError
|
|
477
480
|
|
|
478
481
|
|
|
482
|
+
class LatentPosterior(AbstractPosterior[P, L]):
|
|
483
|
+
r"""A posterior shell used to expose prior structure without inference."""
|
|
484
|
+
|
|
485
|
+
def predict(
|
|
486
|
+
self,
|
|
487
|
+
test_inputs: Num[Array, "N D"],
|
|
488
|
+
train_data: Dataset,
|
|
489
|
+
*,
|
|
490
|
+
return_covariance_type: Literal["dense", "diagonal"] = "dense",
|
|
491
|
+
) -> GaussianDistribution:
|
|
492
|
+
raise NotImplementedError(
|
|
493
|
+
"LatentPosteriors are a lightweight wrapper for priors and do not "
|
|
494
|
+
"implement predictive distributions. Use a variational family for inference."
|
|
495
|
+
)
|
|
496
|
+
|
|
497
|
+
|
|
479
498
|
class ConjugatePosterior(AbstractPosterior[P, GL]):
|
|
480
499
|
r"""A Conjuate Gaussian process posterior object.
|
|
481
500
|
|
|
@@ -839,6 +858,40 @@ class NonConjugatePosterior(AbstractPosterior[P, NGL]):
|
|
|
839
858
|
return GaussianDistribution(jnp.atleast_1d(mean.squeeze()), cov)
|
|
840
859
|
|
|
841
860
|
|
|
861
|
+
class HeteroscedasticPosterior(LatentPosterior[P, HL]):
|
|
862
|
+
r"""Posterior shell for heteroscedastic likelihoods.
|
|
863
|
+
|
|
864
|
+
The posterior retains both the signal and noise priors; inference is delegated
|
|
865
|
+
to variational families and specialised objectives.
|
|
866
|
+
"""
|
|
867
|
+
|
|
868
|
+
def __init__(
|
|
869
|
+
self,
|
|
870
|
+
prior: AbstractPrior[M, K],
|
|
871
|
+
likelihood: HL,
|
|
872
|
+
jitter: float = 1e-6,
|
|
873
|
+
):
|
|
874
|
+
if likelihood.noise_prior is None:
|
|
875
|
+
raise ValueError("Heteroscedastic likelihoods require a noise_prior.")
|
|
876
|
+
super().__init__(prior=prior, likelihood=likelihood, jitter=jitter)
|
|
877
|
+
self.noise_prior = likelihood.noise_prior
|
|
878
|
+
self.noise_posterior = LatentPosterior(
|
|
879
|
+
prior=self.noise_prior, likelihood=likelihood, jitter=jitter
|
|
880
|
+
)
|
|
881
|
+
|
|
882
|
+
|
|
883
|
+
class ChainedPosterior(HeteroscedasticPosterior[P, HL]):
|
|
884
|
+
r"""Posterior routed for heteroscedastic likelihoods using chained bounds."""
|
|
885
|
+
|
|
886
|
+
def __init__(
|
|
887
|
+
self,
|
|
888
|
+
prior: AbstractPrior[M, K],
|
|
889
|
+
likelihood: HL,
|
|
890
|
+
jitter: float = 1e-6,
|
|
891
|
+
):
|
|
892
|
+
super().__init__(prior=prior, likelihood=likelihood, jitter=jitter)
|
|
893
|
+
|
|
894
|
+
|
|
842
895
|
#######################
|
|
843
896
|
# Utils
|
|
844
897
|
#######################
|
|
@@ -854,6 +907,18 @@ def construct_posterior( # noqa: F811
|
|
|
854
907
|
) -> NonConjugatePosterior[P, NGL]: ...
|
|
855
908
|
|
|
856
909
|
|
|
910
|
+
@tp.overload
|
|
911
|
+
def construct_posterior( # noqa: F811
|
|
912
|
+
prior: P, likelihood: HeteroscedasticGaussian
|
|
913
|
+
) -> HeteroscedasticPosterior[P, HeteroscedasticGaussian]: ...
|
|
914
|
+
|
|
915
|
+
|
|
916
|
+
@tp.overload
|
|
917
|
+
def construct_posterior( # noqa: F811
|
|
918
|
+
prior: P, likelihood: AbstractHeteroscedasticLikelihood
|
|
919
|
+
) -> ChainedPosterior[P, AbstractHeteroscedasticLikelihood]: ...
|
|
920
|
+
|
|
921
|
+
|
|
857
922
|
def construct_posterior(prior, likelihood): # noqa: F811
|
|
858
923
|
r"""Utility function for constructing a posterior object from a prior and
|
|
859
924
|
likelihood. The function will automatically select the correct posterior
|
|
@@ -873,6 +938,15 @@ def construct_posterior(prior, likelihood): # noqa: F811
|
|
|
873
938
|
if isinstance(likelihood, Gaussian):
|
|
874
939
|
return ConjugatePosterior(prior=prior, likelihood=likelihood)
|
|
875
940
|
|
|
941
|
+
if (
|
|
942
|
+
isinstance(likelihood, HeteroscedasticGaussian)
|
|
943
|
+
and likelihood.supports_tight_bound()
|
|
944
|
+
):
|
|
945
|
+
return HeteroscedasticPosterior(prior=prior, likelihood=likelihood)
|
|
946
|
+
|
|
947
|
+
if isinstance(likelihood, AbstractHeteroscedasticLikelihood):
|
|
948
|
+
return ChainedPosterior(prior=prior, likelihood=likelihood)
|
|
949
|
+
|
|
876
950
|
return NonConjugatePosterior(prior=prior, likelihood=likelihood)
|
|
877
951
|
|
|
878
952
|
|
|
@@ -911,7 +985,10 @@ __all__ = [
|
|
|
911
985
|
"AbstractPrior",
|
|
912
986
|
"Prior",
|
|
913
987
|
"AbstractPosterior",
|
|
988
|
+
"LatentPosterior",
|
|
914
989
|
"ConjugatePosterior",
|
|
915
990
|
"NonConjugatePosterior",
|
|
991
|
+
"HeteroscedasticPosterior",
|
|
992
|
+
"ChainedPosterior",
|
|
916
993
|
"construct_posterior",
|
|
917
994
|
]
|
gpjax/likelihoods.py
CHANGED
|
@@ -10,15 +10,20 @@
|
|
|
10
10
|
# See the License for the specific language governing permissions and
|
|
11
11
|
# limitations under the License.
|
|
12
12
|
# ==============================================================================
|
|
13
|
+
from __future__ import annotations
|
|
13
14
|
|
|
14
15
|
import abc
|
|
16
|
+
from dataclasses import dataclass
|
|
15
17
|
|
|
16
18
|
import beartype.typing as tp
|
|
17
19
|
from flax import nnx
|
|
20
|
+
import jax
|
|
18
21
|
from jax import vmap
|
|
22
|
+
import jax.nn as jnn
|
|
19
23
|
import jax.numpy as jnp
|
|
20
24
|
import jax.scipy as jsp
|
|
21
25
|
from jaxtyping import Float
|
|
26
|
+
import numpy as np
|
|
22
27
|
import numpyro.distributions as npd
|
|
23
28
|
|
|
24
29
|
from gpjax.distributions import GaussianDistribution
|
|
@@ -36,6 +41,20 @@ from gpjax.typing import (
|
|
|
36
41
|
)
|
|
37
42
|
|
|
38
43
|
|
|
44
|
+
@dataclass
|
|
45
|
+
class NoiseMoments:
|
|
46
|
+
log_variance: Array
|
|
47
|
+
inv_variance: Array
|
|
48
|
+
variance: Array
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
jax.tree_util.register_pytree_node(
|
|
52
|
+
NoiseMoments,
|
|
53
|
+
lambda x: ((x.log_variance, x.inv_variance, x.variance), None),
|
|
54
|
+
lambda _, x: NoiseMoments(*x),
|
|
55
|
+
)
|
|
56
|
+
|
|
57
|
+
|
|
39
58
|
class AbstractLikelihood(nnx.Module):
|
|
40
59
|
r"""Abstract base class for likelihoods.
|
|
41
60
|
|
|
@@ -103,6 +122,9 @@ class AbstractLikelihood(nnx.Module):
|
|
|
103
122
|
y: Float[Array, "N D"],
|
|
104
123
|
mean: Float[Array, "N D"],
|
|
105
124
|
variance: Float[Array, "N D"],
|
|
125
|
+
mean_g: tp.Optional[Float[Array, "N D"]] = None,
|
|
126
|
+
variance_g: tp.Optional[Float[Array, "N D"]] = None,
|
|
127
|
+
**_: tp.Any,
|
|
106
128
|
) -> Float[Array, " N"]:
|
|
107
129
|
r"""Compute the expected log likelihood.
|
|
108
130
|
|
|
@@ -116,6 +138,12 @@ class AbstractLikelihood(nnx.Module):
|
|
|
116
138
|
y (Float[Array, 'N D']): The observed response variable.
|
|
117
139
|
mean (Float[Array, 'N D']): The variational mean.
|
|
118
140
|
variance (Float[Array, 'N D']): The variational variance.
|
|
141
|
+
mean_g (Float[Array, 'N D']): Optional moments of the latent noise
|
|
142
|
+
process for heteroscedastic likelihoods.
|
|
143
|
+
variance_g (Float[Array, 'N D']): Optional moments of the latent noise
|
|
144
|
+
process for heteroscedastic likelihoods.
|
|
145
|
+
**_: Unused extra arguments for compatibility with specialised
|
|
146
|
+
likelihoods.
|
|
119
147
|
|
|
120
148
|
Returns:
|
|
121
149
|
ScalarFloat: The expected log likelihood.
|
|
@@ -126,6 +154,143 @@ class AbstractLikelihood(nnx.Module):
|
|
|
126
154
|
)
|
|
127
155
|
|
|
128
156
|
|
|
157
|
+
class AbstractNoiseTransform(nnx.Module):
|
|
158
|
+
"""Abstract base class for noise transformations."""
|
|
159
|
+
|
|
160
|
+
@abc.abstractmethod
|
|
161
|
+
def __call__(self, x: Float[Array, "..."]) -> Float[Array, "..."]:
|
|
162
|
+
"""Transform the input noise signal."""
|
|
163
|
+
raise NotImplementedError
|
|
164
|
+
|
|
165
|
+
@abc.abstractmethod
|
|
166
|
+
def moments(
|
|
167
|
+
self, mean: Float[Array, "..."], variance: Float[Array, "..."]
|
|
168
|
+
) -> NoiseMoments:
|
|
169
|
+
"""Compute the moments of the transformed noise signal."""
|
|
170
|
+
raise NotImplementedError
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
class LogNormalTransform(AbstractNoiseTransform):
|
|
174
|
+
"""Log-normal noise transformation."""
|
|
175
|
+
|
|
176
|
+
def __call__(self, x: Float[Array, "..."]) -> Float[Array, "..."]:
|
|
177
|
+
return jnp.exp(x)
|
|
178
|
+
|
|
179
|
+
def moments(
|
|
180
|
+
self, mean: Float[Array, "..."], variance: Float[Array, "..."]
|
|
181
|
+
) -> NoiseMoments:
|
|
182
|
+
expected_variance = jnp.exp(mean + 0.5 * variance)
|
|
183
|
+
expected_log_variance = mean
|
|
184
|
+
expected_inv_variance = jnp.exp(-mean + 0.5 * variance)
|
|
185
|
+
return NoiseMoments(
|
|
186
|
+
log_variance=expected_log_variance,
|
|
187
|
+
inv_variance=expected_inv_variance,
|
|
188
|
+
variance=expected_variance,
|
|
189
|
+
)
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
class SoftplusTransform(AbstractNoiseTransform):
|
|
193
|
+
"""Softplus noise transformation."""
|
|
194
|
+
|
|
195
|
+
def __init__(self, num_points: int = 20):
|
|
196
|
+
self.num_points = num_points
|
|
197
|
+
|
|
198
|
+
def __call__(self, x: Float[Array, "..."]) -> Float[Array, "..."]:
|
|
199
|
+
return jnn.softplus(x)
|
|
200
|
+
|
|
201
|
+
def moments(
|
|
202
|
+
self, mean: Float[Array, "..."], variance: Float[Array, "..."]
|
|
203
|
+
) -> NoiseMoments:
|
|
204
|
+
quad_x, quad_w = np.polynomial.hermite.hermgauss(self.num_points)
|
|
205
|
+
quad_w = jnp.asarray(quad_w / jnp.sqrt(jnp.pi))
|
|
206
|
+
quad_x = jnp.asarray(quad_x)
|
|
207
|
+
|
|
208
|
+
std = jnp.sqrt(variance)
|
|
209
|
+
samples = mean[..., None] + jnp.sqrt(2.0) * std[..., None] * quad_x
|
|
210
|
+
sigma2 = self(samples)
|
|
211
|
+
log_sigma2 = jnp.log(sigma2)
|
|
212
|
+
inv_sigma2 = 1.0 / sigma2
|
|
213
|
+
|
|
214
|
+
expected_variance = jnp.sum(sigma2 * quad_w, axis=-1)
|
|
215
|
+
expected_log_variance = jnp.sum(log_sigma2 * quad_w, axis=-1)
|
|
216
|
+
expected_inv_variance = jnp.sum(inv_sigma2 * quad_w, axis=-1)
|
|
217
|
+
|
|
218
|
+
return NoiseMoments(
|
|
219
|
+
log_variance=expected_log_variance,
|
|
220
|
+
inv_variance=expected_inv_variance,
|
|
221
|
+
variance=expected_variance,
|
|
222
|
+
)
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
class AbstractHeteroscedasticLikelihood(AbstractLikelihood):
|
|
226
|
+
r"""Base class for heteroscedastic likelihoods with latent noise processes."""
|
|
227
|
+
|
|
228
|
+
def __init__(
|
|
229
|
+
self,
|
|
230
|
+
num_datapoints: int,
|
|
231
|
+
noise_prior,
|
|
232
|
+
noise_transform: tp.Union[
|
|
233
|
+
AbstractNoiseTransform,
|
|
234
|
+
tp.Callable[[Float[Array, "..."]], Float[Array, "..."]],
|
|
235
|
+
] = SoftplusTransform(),
|
|
236
|
+
integrator: AbstractIntegrator = GHQuadratureIntegrator(),
|
|
237
|
+
):
|
|
238
|
+
self.noise_prior = noise_prior
|
|
239
|
+
|
|
240
|
+
if isinstance(noise_transform, AbstractNoiseTransform):
|
|
241
|
+
self.noise_transform = noise_transform
|
|
242
|
+
else:
|
|
243
|
+
transform_name = getattr(noise_transform, "__name__", "")
|
|
244
|
+
if noise_transform is jnp.exp or transform_name == "exp":
|
|
245
|
+
self.noise_transform = LogNormalTransform()
|
|
246
|
+
else:
|
|
247
|
+
# Default to SoftplusTransform for softplus or unknown callables (legacy behavior used quadrature)
|
|
248
|
+
# Note: If an unknown callable is passed, we technically use SoftplusTransform which applies softplus.
|
|
249
|
+
# Users should implement AbstractNoiseTransform for custom transforms.
|
|
250
|
+
self.noise_transform = SoftplusTransform()
|
|
251
|
+
|
|
252
|
+
super().__init__(num_datapoints=num_datapoints, integrator=integrator)
|
|
253
|
+
|
|
254
|
+
def __call__(
|
|
255
|
+
self,
|
|
256
|
+
dist: tp.Union[npd.MultivariateNormal, GaussianDistribution],
|
|
257
|
+
noise_dist: tp.Optional[
|
|
258
|
+
tp.Union[npd.MultivariateNormal, GaussianDistribution]
|
|
259
|
+
] = None,
|
|
260
|
+
) -> npd.Distribution:
|
|
261
|
+
return self.predict(dist, noise_dist)
|
|
262
|
+
|
|
263
|
+
def supports_tight_bound(self) -> bool:
|
|
264
|
+
"""Return whether the tighter bound from Lázaro-Gredilla & Titsias (2011)
|
|
265
|
+
is applicable."""
|
|
266
|
+
return False
|
|
267
|
+
|
|
268
|
+
def noise_statistics(
|
|
269
|
+
self, mean: Float[Array, "N D"], variance: Float[Array, "N D"]
|
|
270
|
+
) -> NoiseMoments:
|
|
271
|
+
r"""Moment matching of the transformed noise process.
|
|
272
|
+
|
|
273
|
+
Args:
|
|
274
|
+
mean: Mean of the latent noise GP.
|
|
275
|
+
variance: Variance of the latent noise GP.
|
|
276
|
+
|
|
277
|
+
Returns:
|
|
278
|
+
NoiseMoments: Expected log variance, inverse variance, and variance.
|
|
279
|
+
"""
|
|
280
|
+
return self.noise_transform.moments(mean, variance)
|
|
281
|
+
|
|
282
|
+
def expected_log_likelihood(
|
|
283
|
+
self,
|
|
284
|
+
y: Float[Array, "N D"],
|
|
285
|
+
mean: Float[Array, "N D"],
|
|
286
|
+
variance: Float[Array, "N D"],
|
|
287
|
+
mean_g: tp.Optional[Float[Array, "N D"]] = None,
|
|
288
|
+
variance_g: tp.Optional[Float[Array, "N D"]] = None,
|
|
289
|
+
**kwargs: tp.Any,
|
|
290
|
+
) -> Float[Array, " N"]:
|
|
291
|
+
raise NotImplementedError
|
|
292
|
+
|
|
293
|
+
|
|
129
294
|
class Gaussian(AbstractLikelihood):
|
|
130
295
|
r"""Gaussian likelihood object."""
|
|
131
296
|
|
|
@@ -186,6 +351,69 @@ class Gaussian(AbstractLikelihood):
|
|
|
186
351
|
return npd.MultivariateNormal(dist.mean, noisy_cov)
|
|
187
352
|
|
|
188
353
|
|
|
354
|
+
class HeteroscedasticGaussian(AbstractHeteroscedasticLikelihood):
|
|
355
|
+
def predict(
|
|
356
|
+
self,
|
|
357
|
+
dist: tp.Union[npd.MultivariateNormal, GaussianDistribution],
|
|
358
|
+
noise_dist: tp.Optional[
|
|
359
|
+
tp.Union[npd.MultivariateNormal, GaussianDistribution]
|
|
360
|
+
] = None,
|
|
361
|
+
) -> npd.MultivariateNormal:
|
|
362
|
+
if noise_dist is None:
|
|
363
|
+
raise ValueError(
|
|
364
|
+
"noise_dist must be provided for heteroscedastic prediction."
|
|
365
|
+
)
|
|
366
|
+
|
|
367
|
+
n_data = dist.event_shape[0]
|
|
368
|
+
noise_mean = noise_dist.mean
|
|
369
|
+
noise_variance = jnp.diag(noise_dist.covariance_matrix)
|
|
370
|
+
noise_stats = self.noise_statistics(
|
|
371
|
+
noise_mean[..., None], noise_variance[..., None]
|
|
372
|
+
)
|
|
373
|
+
|
|
374
|
+
cov = dist.covariance_matrix
|
|
375
|
+
noisy_cov = cov.at[jnp.diag_indices(n_data)].add(noise_stats.variance.squeeze())
|
|
376
|
+
|
|
377
|
+
return npd.MultivariateNormal(dist.mean, noisy_cov)
|
|
378
|
+
|
|
379
|
+
def link_function(self, f: Float[Array, "..."]) -> npd.Normal:
|
|
380
|
+
sigma2 = self.noise_transform(jnp.zeros_like(f))
|
|
381
|
+
return npd.Normal(loc=f, scale=jnp.sqrt(sigma2))
|
|
382
|
+
|
|
383
|
+
def expected_log_likelihood(
|
|
384
|
+
self,
|
|
385
|
+
y: Float[Array, "N D"],
|
|
386
|
+
mean: Float[Array, "N D"],
|
|
387
|
+
variance: Float[Array, "N D"],
|
|
388
|
+
mean_g: tp.Optional[Float[Array, "N D"]] = None,
|
|
389
|
+
variance_g: tp.Optional[Float[Array, "N D"]] = None,
|
|
390
|
+
noise_stats: tp.Optional[NoiseMoments] = None,
|
|
391
|
+
return_parts: bool = False,
|
|
392
|
+
**_: tp.Any,
|
|
393
|
+
) -> tp.Union[Float[Array, " N"], tuple[Float[Array, " N"], NoiseMoments]]:
|
|
394
|
+
if mean_g is None or variance_g is None:
|
|
395
|
+
raise ValueError(
|
|
396
|
+
"mean_g and variance_g must be provided for heteroscedastic models."
|
|
397
|
+
)
|
|
398
|
+
|
|
399
|
+
if noise_stats is None:
|
|
400
|
+
noise_stats = self.noise_statistics(mean_g, variance_g)
|
|
401
|
+
sq_error = jnp.square(y - mean)
|
|
402
|
+
log2pi = jnp.log(2.0 * jnp.pi)
|
|
403
|
+
expected = -0.5 * (
|
|
404
|
+
log2pi
|
|
405
|
+
+ noise_stats.log_variance
|
|
406
|
+
+ (sq_error + variance) * noise_stats.inv_variance
|
|
407
|
+
)
|
|
408
|
+
expected_sum = jnp.sum(expected, axis=1)
|
|
409
|
+
if return_parts:
|
|
410
|
+
return expected_sum, noise_stats
|
|
411
|
+
return expected_sum
|
|
412
|
+
|
|
413
|
+
def supports_tight_bound(self) -> bool:
|
|
414
|
+
return True
|
|
415
|
+
|
|
416
|
+
|
|
189
417
|
class Bernoulli(AbstractLikelihood):
|
|
190
418
|
def link_function(self, f: Float[Array, "..."]) -> npd.BernoulliProbs:
|
|
191
419
|
r"""The probit link function of the Bernoulli likelihood.
|
|
@@ -268,7 +496,13 @@ __all__ = [
|
|
|
268
496
|
"AbstractLikelihood",
|
|
269
497
|
"NonGaussian",
|
|
270
498
|
"Gaussian",
|
|
499
|
+
"AbstractHeteroscedasticLikelihood",
|
|
500
|
+
"HeteroscedasticGaussian",
|
|
271
501
|
"Bernoulli",
|
|
272
502
|
"Poisson",
|
|
273
503
|
"inv_probit",
|
|
504
|
+
"NoiseMoments",
|
|
505
|
+
"AbstractNoiseTransform",
|
|
506
|
+
"LogNormalTransform",
|
|
507
|
+
"SoftplusTransform",
|
|
274
508
|
]
|
gpjax/mean_functions.py
CHANGED
|
@@ -147,9 +147,9 @@ class Constant(AbstractMeanFunction):
|
|
|
147
147
|
Float[Array, "1"]: The evaluated mean function.
|
|
148
148
|
"""
|
|
149
149
|
if isinstance(self.constant, Parameter):
|
|
150
|
-
return jnp.ones((x.shape[0], 1)) * self.constant.value
|
|
150
|
+
return jnp.ones((x.shape[0], 1), dtype=x.dtype) * self.constant.value
|
|
151
151
|
else:
|
|
152
|
-
return jnp.ones((x.shape[0], 1)) * self.constant
|
|
152
|
+
return jnp.ones((x.shape[0], 1), dtype=x.dtype) * self.constant
|
|
153
153
|
|
|
154
154
|
|
|
155
155
|
class Zero(Constant):
|
gpjax/objectives.py
CHANGED
|
@@ -14,6 +14,9 @@ from gpjax.gps import (
|
|
|
14
14
|
ConjugatePosterior,
|
|
15
15
|
NonConjugatePosterior,
|
|
16
16
|
)
|
|
17
|
+
from gpjax.likelihoods import (
|
|
18
|
+
AbstractHeteroscedasticLikelihood,
|
|
19
|
+
)
|
|
17
20
|
from gpjax.linalg import (
|
|
18
21
|
Dense,
|
|
19
22
|
lower_cholesky,
|
|
@@ -25,9 +28,13 @@ from gpjax.typing import (
|
|
|
25
28
|
Array,
|
|
26
29
|
ScalarFloat,
|
|
27
30
|
)
|
|
28
|
-
from gpjax.variational_families import AbstractVariationalFamily
|
|
31
|
+
from gpjax.variational_families import (
|
|
32
|
+
AbstractVariationalFamily,
|
|
33
|
+
HeteroscedasticVariationalFamily,
|
|
34
|
+
)
|
|
29
35
|
|
|
30
36
|
VF = TypeVar("VF", bound=AbstractVariationalFamily)
|
|
37
|
+
HVF = TypeVar("HVF", bound=HeteroscedasticVariationalFamily)
|
|
31
38
|
|
|
32
39
|
|
|
33
40
|
Objective = tpe.Callable[[nnx.Module, Dataset], ScalarFloat]
|
|
@@ -414,3 +421,51 @@ def collapsed_elbo(variational_family: VF, data: Dataset) -> ScalarFloat:
|
|
|
414
421
|
|
|
415
422
|
# log N(y; μx, Io² + KxzKzz⁻¹Kzx) - 1/2o² tr(Kxx - KxzKzz⁻¹Kzx)
|
|
416
423
|
return (two_log_prob - two_trace).squeeze() / 2.0
|
|
424
|
+
|
|
425
|
+
|
|
426
|
+
def heteroscedastic_elbo_conjugate(
|
|
427
|
+
variational_family: HVF, data: Dataset
|
|
428
|
+
) -> ScalarFloat:
|
|
429
|
+
r"""Tight bound from Lázaro-Gredilla & Titsias (2011) for heteroscedastic Gaussian likelihoods."""
|
|
430
|
+
likelihood = variational_family.posterior.likelihood
|
|
431
|
+
mean_f, var_f, mean_g, var_g = variational_family.predict(data.X)
|
|
432
|
+
|
|
433
|
+
expected_ll, _ = likelihood.expected_log_likelihood(
|
|
434
|
+
data.y,
|
|
435
|
+
mean_f,
|
|
436
|
+
var_f,
|
|
437
|
+
mean_g=mean_g,
|
|
438
|
+
variance_g=var_g,
|
|
439
|
+
return_parts=True,
|
|
440
|
+
)
|
|
441
|
+
|
|
442
|
+
scale = likelihood.num_datapoints / data.n
|
|
443
|
+
return scale * jnp.sum(expected_ll) - variational_family.prior_kl()
|
|
444
|
+
|
|
445
|
+
|
|
446
|
+
def heteroscedastic_elbo_chained(variational_family: HVF, data: Dataset) -> ScalarFloat:
|
|
447
|
+
r"""Generic chained bound for heteroscedastic likelihoods."""
|
|
448
|
+
likelihood: AbstractHeteroscedasticLikelihood = (
|
|
449
|
+
variational_family.posterior.likelihood
|
|
450
|
+
)
|
|
451
|
+
mean_f, var_f, mean_g, var_g = variational_family.predict(data.X)
|
|
452
|
+
noise_stats = likelihood.noise_statistics(mean_g, var_g)
|
|
453
|
+
|
|
454
|
+
expected_ll = likelihood.expected_log_likelihood(
|
|
455
|
+
data.y,
|
|
456
|
+
mean_f,
|
|
457
|
+
var_f,
|
|
458
|
+
mean_g=mean_g,
|
|
459
|
+
variance_g=var_g,
|
|
460
|
+
noise_stats=noise_stats,
|
|
461
|
+
)
|
|
462
|
+
|
|
463
|
+
scale = likelihood.num_datapoints / data.n
|
|
464
|
+
return scale * jnp.sum(expected_ll) - variational_family.prior_kl()
|
|
465
|
+
|
|
466
|
+
|
|
467
|
+
def heteroscedastic_elbo(variational_family: HVF, data: Dataset) -> ScalarFloat:
|
|
468
|
+
likelihood = variational_family.posterior.likelihood
|
|
469
|
+
if likelihood.supports_tight_bound():
|
|
470
|
+
return heteroscedastic_elbo_conjugate(variational_family, data)
|
|
471
|
+
return heteroscedastic_elbo_chained(variational_family, data)
|
gpjax/parameters.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
import typing as tp
|
|
2
2
|
|
|
3
3
|
from flax import nnx
|
|
4
|
+
import jax
|
|
4
5
|
from jax.experimental import checkify
|
|
5
6
|
import jax.numpy as jnp
|
|
6
7
|
import jax.tree_util as jtu
|
|
@@ -77,7 +78,14 @@ class Parameter(nnx.Variable[T]):
|
|
|
77
78
|
_check_is_arraylike(value)
|
|
78
79
|
|
|
79
80
|
super().__init__(value=jnp.asarray(value), **kwargs)
|
|
80
|
-
|
|
81
|
+
|
|
82
|
+
# nnx.Variable metadata must be set via set_metadata (direct setattr is disallowed).
|
|
83
|
+
self.set_metadata(tag=tag)
|
|
84
|
+
|
|
85
|
+
@property
|
|
86
|
+
def tag(self) -> ParameterTag:
|
|
87
|
+
"""Return the parameter's constraint tag."""
|
|
88
|
+
return self.metadata.get("tag", "real")
|
|
81
89
|
|
|
82
90
|
|
|
83
91
|
class NonNegativeReal(Parameter[T]):
|
|
@@ -155,7 +163,7 @@ def _check_is_arraylike(value: T) -> None:
|
|
|
155
163
|
Raises:
|
|
156
164
|
TypeError: If the value is not array-like.
|
|
157
165
|
"""
|
|
158
|
-
if not isinstance(value, (ArrayLike, list)):
|
|
166
|
+
if not isinstance(value, (jax.Array, ArrayLike, list)):
|
|
159
167
|
raise TypeError(
|
|
160
168
|
f"Expected parameter value to be an array-like type. Got {value}."
|
|
161
169
|
)
|
gpjax/variational_families.py
CHANGED
|
@@ -14,6 +14,7 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
|
|
16
16
|
import abc
|
|
17
|
+
from dataclasses import dataclass
|
|
17
18
|
|
|
18
19
|
import beartype.typing as tp
|
|
19
20
|
from flax import nnx
|
|
@@ -29,9 +30,12 @@ from gpjax.distributions import GaussianDistribution
|
|
|
29
30
|
from gpjax.gps import (
|
|
30
31
|
AbstractPosterior,
|
|
31
32
|
AbstractPrior,
|
|
33
|
+
ChainedPosterior,
|
|
34
|
+
HeteroscedasticPosterior,
|
|
32
35
|
)
|
|
33
36
|
from gpjax.kernels.base import AbstractKernel
|
|
34
37
|
from gpjax.likelihoods import (
|
|
38
|
+
AbstractHeteroscedasticLikelihood,
|
|
35
39
|
Gaussian,
|
|
36
40
|
NonGaussian,
|
|
37
41
|
)
|
|
@@ -59,8 +63,10 @@ M = tp.TypeVar("M", bound=AbstractMeanFunction)
|
|
|
59
63
|
L = tp.TypeVar("L", Gaussian, NonGaussian)
|
|
60
64
|
NGL = tp.TypeVar("NGL", bound=NonGaussian)
|
|
61
65
|
GL = tp.TypeVar("GL", bound=Gaussian)
|
|
66
|
+
HL = tp.TypeVar("HL", bound=AbstractHeteroscedasticLikelihood)
|
|
62
67
|
P = tp.TypeVar("P", bound=AbstractPrior)
|
|
63
68
|
PP = tp.TypeVar("PP", bound=AbstractPosterior)
|
|
69
|
+
HP = tp.TypeVar("HP", HeteroscedasticPosterior, ChainedPosterior)
|
|
64
70
|
|
|
65
71
|
|
|
66
72
|
class AbstractVariationalFamily(nnx.Module, tp.Generic[L]):
|
|
@@ -870,6 +876,126 @@ class CollapsedVariationalGaussian(AbstractVariationalGaussian[GL]):
|
|
|
870
876
|
)
|
|
871
877
|
|
|
872
878
|
|
|
879
|
+
@dataclass
|
|
880
|
+
class VariationalGaussianInit:
|
|
881
|
+
"""Initialization parameters for a variational Gaussian distribution."""
|
|
882
|
+
|
|
883
|
+
inducing_inputs: tp.Union[Int[Array, "N D"], Float[Array, "N D"]]
|
|
884
|
+
variational_mean: tp.Union[Float[Array, "N 1"], None] = None
|
|
885
|
+
variational_root_covariance: tp.Union[Float[Array, "N N"], None] = None
|
|
886
|
+
|
|
887
|
+
|
|
888
|
+
class HeteroscedasticPrediction(tp.NamedTuple):
|
|
889
|
+
"""Mean and variance of the signal and noise latent processes."""
|
|
890
|
+
|
|
891
|
+
mean_f: Float[Array, "N 1"]
|
|
892
|
+
variance_f: Float[Array, "N 1"]
|
|
893
|
+
mean_g: Float[Array, "N 1"]
|
|
894
|
+
variance_g: Float[Array, "N 1"]
|
|
895
|
+
|
|
896
|
+
|
|
897
|
+
class HeteroscedasticVariationalFamily(AbstractVariationalFamily[HL]):
|
|
898
|
+
r"""Variational family for two independent latent processes f and g."""
|
|
899
|
+
|
|
900
|
+
def __init__(
|
|
901
|
+
self,
|
|
902
|
+
posterior: HP,
|
|
903
|
+
inducing_inputs: tp.Union[Int[Array, "N D"], Float[Array, "N D"]] = None,
|
|
904
|
+
inducing_inputs_g: tp.Union[
|
|
905
|
+
Int[Array, "M D"], Float[Array, "M D"], None
|
|
906
|
+
] = None,
|
|
907
|
+
variational_mean_f: tp.Union[Float[Array, "N 1"], None] = None,
|
|
908
|
+
variational_root_covariance_f: tp.Union[Float[Array, "N N"], None] = None,
|
|
909
|
+
variational_mean_g: tp.Union[Float[Array, "M 1"], None] = None,
|
|
910
|
+
variational_root_covariance_g: tp.Union[Float[Array, "M M"], None] = None,
|
|
911
|
+
jitter: ScalarFloat = 1e-6,
|
|
912
|
+
signal_init: tp.Optional[VariationalGaussianInit] = None,
|
|
913
|
+
noise_init: tp.Optional[VariationalGaussianInit] = None,
|
|
914
|
+
):
|
|
915
|
+
self.jitter = jitter
|
|
916
|
+
|
|
917
|
+
if signal_init is not None:
|
|
918
|
+
self.signal_variational = VariationalGaussian(
|
|
919
|
+
posterior=posterior,
|
|
920
|
+
inducing_inputs=signal_init.inducing_inputs,
|
|
921
|
+
variational_mean=signal_init.variational_mean,
|
|
922
|
+
variational_root_covariance=signal_init.variational_root_covariance,
|
|
923
|
+
jitter=jitter,
|
|
924
|
+
)
|
|
925
|
+
elif inducing_inputs is not None:
|
|
926
|
+
self.signal_variational = VariationalGaussian(
|
|
927
|
+
posterior=posterior,
|
|
928
|
+
inducing_inputs=inducing_inputs,
|
|
929
|
+
variational_mean=variational_mean_f,
|
|
930
|
+
variational_root_covariance=variational_root_covariance_f,
|
|
931
|
+
jitter=jitter,
|
|
932
|
+
)
|
|
933
|
+
else:
|
|
934
|
+
raise ValueError("Either signal_init or inducing_inputs must be provided.")
|
|
935
|
+
|
|
936
|
+
if noise_init is not None:
|
|
937
|
+
self.noise_variational = VariationalGaussian(
|
|
938
|
+
posterior=posterior.noise_posterior,
|
|
939
|
+
inducing_inputs=noise_init.inducing_inputs,
|
|
940
|
+
variational_mean=noise_init.variational_mean,
|
|
941
|
+
variational_root_covariance=noise_init.variational_root_covariance,
|
|
942
|
+
jitter=jitter,
|
|
943
|
+
)
|
|
944
|
+
else:
|
|
945
|
+
noise_inducing = (
|
|
946
|
+
inducing_inputs if inducing_inputs_g is None else inducing_inputs_g
|
|
947
|
+
)
|
|
948
|
+
if noise_inducing is None and signal_init is not None:
|
|
949
|
+
noise_inducing = signal_init.inducing_inputs
|
|
950
|
+
|
|
951
|
+
if noise_inducing is None:
|
|
952
|
+
raise ValueError(
|
|
953
|
+
"Could not determine inducing inputs for noise process."
|
|
954
|
+
)
|
|
955
|
+
|
|
956
|
+
self.noise_variational = VariationalGaussian(
|
|
957
|
+
posterior=posterior.noise_posterior,
|
|
958
|
+
inducing_inputs=noise_inducing,
|
|
959
|
+
variational_mean=variational_mean_g,
|
|
960
|
+
variational_root_covariance=variational_root_covariance_g,
|
|
961
|
+
jitter=jitter,
|
|
962
|
+
)
|
|
963
|
+
super().__init__(posterior)
|
|
964
|
+
|
|
965
|
+
def prior_kl(self) -> ScalarFloat:
|
|
966
|
+
return self.signal_variational.prior_kl() + self.noise_variational.prior_kl()
|
|
967
|
+
|
|
968
|
+
def predict(
|
|
969
|
+
self, test_inputs: tp.Union[Int[Array, "N D"], Float[Array, "N D"]]
|
|
970
|
+
) -> HeteroscedasticPrediction:
|
|
971
|
+
dist_f = self.signal_variational.predict(test_inputs)
|
|
972
|
+
dist_g = self.noise_variational.predict(test_inputs)
|
|
973
|
+
|
|
974
|
+
mean_f = dist_f.mean[:, None] if dist_f.mean.ndim == 1 else dist_f.mean
|
|
975
|
+
var_f = (
|
|
976
|
+
dist_f.variance[:, None] if dist_f.variance.ndim == 1 else dist_f.variance
|
|
977
|
+
)
|
|
978
|
+
mean_g = dist_g.mean[:, None] if dist_g.mean.ndim == 1 else dist_g.mean
|
|
979
|
+
var_g = (
|
|
980
|
+
dist_g.variance[:, None] if dist_g.variance.ndim == 1 else dist_g.variance
|
|
981
|
+
)
|
|
982
|
+
|
|
983
|
+
return HeteroscedasticPrediction(
|
|
984
|
+
mean_f=mean_f,
|
|
985
|
+
variance_f=var_f,
|
|
986
|
+
mean_g=mean_g,
|
|
987
|
+
variance_g=var_g,
|
|
988
|
+
)
|
|
989
|
+
|
|
990
|
+
def predict_latents(
|
|
991
|
+
self, test_inputs: tp.Union[Int[Array, "N D"], Float[Array, "N D"]]
|
|
992
|
+
) -> tuple[GaussianDistribution, GaussianDistribution]:
|
|
993
|
+
return (
|
|
994
|
+
self.signal_variational.predict(test_inputs),
|
|
995
|
+
self.noise_variational.predict(test_inputs),
|
|
996
|
+
)
|
|
997
|
+
|
|
998
|
+
|
|
873
999
|
__all__ = [
|
|
874
1000
|
"AbstractVariationalFamily",
|
|
875
1001
|
"AbstractVariationalGaussian",
|
|
@@ -879,4 +1005,7 @@ __all__ = [
|
|
|
879
1005
|
"NaturalVariationalGaussian",
|
|
880
1006
|
"ExpectationVariationalGaussian",
|
|
881
1007
|
"CollapsedVariationalGaussian",
|
|
1008
|
+
"HeteroscedasticVariationalFamily",
|
|
1009
|
+
"VariationalGaussianInit",
|
|
1010
|
+
"HeteroscedasticPrediction",
|
|
882
1011
|
]
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: gpjax
|
|
3
|
-
Version: 0.13.3
|
|
3
|
+
Version: 0.13.5
|
|
4
4
|
Summary: Gaussian processes in JAX.
|
|
5
5
|
Project-URL: Documentation, https://docs.jaxgaussianprocesses.com/
|
|
6
6
|
Project-URL: Issues, https://github.com/thomaspinder/GPJax/issues
|
|
@@ -25,6 +25,7 @@ Requires-Dist: jaxtyping>0.2.10
|
|
|
25
25
|
Requires-Dist: numpy>=2.0.0
|
|
26
26
|
Requires-Dist: numpyro
|
|
27
27
|
Requires-Dist: optax>0.2.1
|
|
28
|
+
Requires-Dist: tensorstore!=0.1.76; sys_platform == 'darwin'
|
|
28
29
|
Requires-Dist: tqdm>4.66.2
|
|
29
30
|
Provides-Extra: dev
|
|
30
31
|
Requires-Dist: absolufy-imports>=0.3.1; extra == 'dev'
|
|
@@ -59,7 +60,7 @@ Requires-Dist: mkdocs-jupyter>=0.24.3; extra == 'docs'
|
|
|
59
60
|
Requires-Dist: mkdocs-literate-nav>=0.6.0; extra == 'docs'
|
|
60
61
|
Requires-Dist: mkdocs-material>=9.5.12; extra == 'docs'
|
|
61
62
|
Requires-Dist: mkdocs>=1.5.3; extra == 'docs'
|
|
62
|
-
Requires-Dist: mkdocstrings[python]<
|
|
63
|
+
Requires-Dist: mkdocstrings[python]<1.1.0; extra == 'docs'
|
|
63
64
|
Requires-Dist: nbconvert>=7.16.2; extra == 'docs'
|
|
64
65
|
Requires-Dist: networkx>=3.0; extra == 'docs'
|
|
65
66
|
Requires-Dist: pandas>=1.5.3; extra == 'docs'
|
|
@@ -141,7 +142,7 @@ GPJax into the package it is today.
|
|
|
141
142
|
> - [**Laplace Approximation**](https://docs.jaxgaussianprocesses.com/_examples/classification/#laplace-approximation)
|
|
142
143
|
> - [**Inference on Non-Euclidean Spaces**](https://docs.jaxgaussianprocesses.com/_examples/constructing_new_kernels/#custom-kernel)
|
|
143
144
|
> - [**Inference on Graphs**](https://docs.jaxgaussianprocesses.com/_examples/graph_kernels/)
|
|
144
|
-
> - [**
|
|
145
|
+
> - [**Heteroscedastic Inference**](https://docs.jaxgaussianprocesses.com/_examples/heteroscedastic_inference/)
|
|
145
146
|
> - [**Learning Gaussian Process Barycentres**](https://docs.jaxgaussianprocesses.com/_examples/barycentres/)
|
|
146
147
|
> - [**Deep Kernel Regression**](https://docs.jaxgaussianprocesses.com/_examples/deep_kernels/)
|
|
147
148
|
> - [**Poisson Regression**](https://docs.jaxgaussianprocesses.com/_examples/poisson/)
|
|
@@ -1,18 +1,18 @@
|
|
|
1
|
-
gpjax/__init__.py,sha256=
|
|
2
|
-
gpjax/citation.py,sha256=
|
|
1
|
+
gpjax/__init__.py,sha256=8guaWVHiNXlZ9m_XvjPo8nherGpCpGWCCxB5fenq_Gw,1625
|
|
2
|
+
gpjax/citation.py,sha256=9L-OVIDCrwjM3TeWinC5UHS3Lxn_9sVIsj_H9i08BT8,4428
|
|
3
3
|
gpjax/dataset.py,sha256=Ef5JGrl4jJS1mQmL3JdO0fdqbVmflT_Cu5VrlpYdJY4,4071
|
|
4
4
|
gpjax/distributions.py,sha256=iKmeQ_NN2CIjRiuOeJlwEGASzGROi4ZCerVi1uY7zRM,7758
|
|
5
5
|
gpjax/fit.py,sha256=tOXmM3l5-N3Jlnq8MOVkSpRj-0fRWOy2t9GxQRyUqxY,15318
|
|
6
|
-
gpjax/gps.py,sha256=
|
|
6
|
+
gpjax/gps.py,sha256=83rki591ma3B_yr1Bq1OUlAH75sThpg-x00ZZXeqwN0,37906
|
|
7
7
|
gpjax/integrators.py,sha256=eyJPqWNPKj6pKP5da0fEj4HW7BVyevqeGrurEuy_XPw,5694
|
|
8
|
-
gpjax/likelihoods.py,sha256=
|
|
9
|
-
gpjax/mean_functions.py,sha256=
|
|
8
|
+
gpjax/likelihoods.py,sha256=SzGh3uAkULVMfAVA3QpKQwWp-XAcx46B7QhhgKymn6A,17381
|
|
9
|
+
gpjax/mean_functions.py,sha256=cDv3X5E4-to8pgzHpbKhoMKDLdWg3WhEq1W21kvbDDg,6571
|
|
10
10
|
gpjax/numpyro_extras.py,sha256=-vWJ7SpZVNhSdCjjrlxIkovMFrM1IzpsMJK3B4LioGE,3411
|
|
11
|
-
gpjax/objectives.py,sha256=
|
|
12
|
-
gpjax/parameters.py,sha256=
|
|
11
|
+
gpjax/objectives.py,sha256=QScq5e77-jugtt2wKZNLgk3VtkqSDddJyYBs-hQ7dZc,17206
|
|
12
|
+
gpjax/parameters.py,sha256=iGab7IMAANuhFm8ojcK5x5fJNzUOdflr2h2pBaVLe9Q,7035
|
|
13
13
|
gpjax/scan.py,sha256=Z_V1yd76dL4B7_rnJnr_fohom6xzN_WRYxlTAAdqfa0,5357
|
|
14
14
|
gpjax/typing.py,sha256=M3CvWsYtZ3PFUvBvvbRNjpwerNII0w4yGuP0I-sLeYI,1705
|
|
15
|
-
gpjax/variational_families.py,sha256=
|
|
15
|
+
gpjax/variational_families.py,sha256=J7L1lh0McYSgr1BKq5vsVjtfu2XeyOCn-I_AHQym-xA,36723
|
|
16
16
|
gpjax/kernels/__init__.py,sha256=GFiku1s8KPPzvfwDYzEjENPQReywvJ29M9Kl2JjkcoU,1893
|
|
17
17
|
gpjax/kernels/base.py,sha256=gov60BgXOQK0PxbKQRVgjU55kCVZYAkDZrAInODvugo,11549
|
|
18
18
|
gpjax/kernels/approximations/__init__.py,sha256=bK9HlGd-PZeGrqtG5RpXxUTXNUrZTgfjH1dP626yNMA,68
|
|
@@ -46,7 +46,7 @@ gpjax/linalg/__init__.py,sha256=F8mxk_9Zc2nFd7Q-unjJ50_6rXEKzZj572WsU_jUKqI,547
|
|
|
46
46
|
gpjax/linalg/operations.py,sha256=xvhOy5P4FmUCPWjIVNdg1yDXaoFQ48anFUfR-Tnfr6k,6480
|
|
47
47
|
gpjax/linalg/operators.py,sha256=arxRGwcoAy_RqUYqBpZ3XG6OXbjShUl7m8sTpg85npE,11608
|
|
48
48
|
gpjax/linalg/utils.py,sha256=fKV8G_iKZVhNkNvN20D_dQEi93-8xosGbXBP-v7UEyo,2020
|
|
49
|
-
gpjax-0.13.
|
|
50
|
-
gpjax-0.13.
|
|
51
|
-
gpjax-0.13.3.dist-info/licenses/LICENSE.txt,sha256=3umwi0h8wmKXOZO8XwRBwSl3vJt2hpWKEqSrSXLR7-I,1084
|
|
52
|
-
gpjax-0.13.3.dist-info/RECORD,,
|
|
49
|
+
gpjax-0.13.5.dist-info/METADATA,sha256=qJ8ZeG1TGMX6a_2WQnNSwTl2JBXVp8107QqY9f1XWyE,10382
|
|
50
|
+
gpjax-0.13.5.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
51
|
+
gpjax-0.13.5.dist-info/licenses/LICENSE.txt,sha256=3umwi0h8wmKXOZO8XwRBwSl3vJt2hpWKEqSrSXLR7-I,1084
|
|
52
|
+
gpjax-0.13.5.dist-info/RECORD,,
|
|
File without changes
|