gpjax 0.13.3__py3-none-any.whl → 0.13.5__py3-none-any.whl

This diff shows the contents of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
gpjax/__init__.py CHANGED
@@ -40,7 +40,7 @@ __license__ = "MIT"
 __description__ = "Gaussian processes in JAX and Flax"
 __url__ = "https://github.com/thomaspinder/GPJax"
 __contributors__ = "https://github.com/thomaspinder/GPJax/graphs/contributors"
-__version__ = "0.13.3"
+__version__ = "0.13.5"

 __all__ = [
     "gps",
gpjax/citation.py CHANGED
@@ -23,6 +23,7 @@ from gpjax.kernels import (
     Matern32,
     Matern52,
 )
+from gpjax.likelihoods import HeteroscedasticGaussian

 CitationType = Union[None, str, Dict[str, str]]

@@ -149,3 +150,15 @@ def _(tree) -> PaperCitation:
         booktitle="Advances in neural information processing systems",
         citation_type="article",
     )
+
+
+@cite.register(HeteroscedasticGaussian)
+def _(tree) -> PaperCitation:
+    return PaperCitation(
+        citation_key="lazaro2011variational",
+        authors="Lázaro-Gredilla, Miguel and Titsias, Michalis",
+        title="Variational heteroscedastic Gaussian process regression",
+        year="2011",
+        booktitle="Proceedings of the 28th International Conference on Machine Learning (ICML)",
+        citation_type="inproceedings",
+    )
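For orientation, a minimal sketch of what this registration enables: `cite` dispatches on the argument's type, so a `HeteroscedasticGaussian` instance now resolves to the Lázaro-Gredilla & Titsias (2011) entry. The `Prior`/`RBF`/`Zero` construction here is illustrative only.

```python
from gpjax.citation import cite
from gpjax.gps import Prior
from gpjax.kernels import RBF
from gpjax.likelihoods import HeteroscedasticGaussian
from gpjax.mean_functions import Zero

# A latent GP prior over the noise process, required by the new likelihood.
noise_prior = Prior(mean_function=Zero(), kernel=RBF())
likelihood = HeteroscedasticGaussian(num_datapoints=100, noise_prior=noise_prior)

print(cite(likelihood).citation_key)  # lazaro2011variational
```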
gpjax/gps.py CHANGED
@@ -32,8 +32,10 @@ from gpjax.distributions import GaussianDistribution
 from gpjax.kernels import RFF
 from gpjax.kernels.base import AbstractKernel
 from gpjax.likelihoods import (
+    AbstractHeteroscedasticLikelihood,
     AbstractLikelihood,
     Gaussian,
+    HeteroscedasticGaussian,
     NonGaussian,
 )
 from gpjax.linalg import (
@@ -62,6 +64,7 @@ M = tp.TypeVar("M", bound=AbstractMeanFunction)
 L = tp.TypeVar("L", bound=AbstractLikelihood)
 NGL = tp.TypeVar("NGL", bound=NonGaussian)
 GL = tp.TypeVar("GL", bound=Gaussian)
+HL = tp.TypeVar("HL", bound=AbstractHeteroscedasticLikelihood)


 class AbstractPrior(nnx.Module, tp.Generic[M, K]):
@@ -476,6 +479,22 @@ class AbstractPosterior(nnx.Module, tp.Generic[P, L]):
         raise NotImplementedError


+class LatentPosterior(AbstractPosterior[P, L]):
+    r"""A posterior shell used to expose prior structure without inference."""
+
+    def predict(
+        self,
+        test_inputs: Num[Array, "N D"],
+        train_data: Dataset,
+        *,
+        return_covariance_type: Literal["dense", "diagonal"] = "dense",
+    ) -> GaussianDistribution:
+        raise NotImplementedError(
+            "LatentPosteriors are a lightweight wrapper for priors and do not "
+            "implement predictive distributions. Use a variational family for inference."
+        )
+
+
 class ConjugatePosterior(AbstractPosterior[P, GL]):
     r"""A Conjugate Gaussian process posterior object.

@@ -839,6 +858,40 @@ class NonConjugatePosterior(AbstractPosterior[P, NGL]):
         return GaussianDistribution(jnp.atleast_1d(mean.squeeze()), cov)


+class HeteroscedasticPosterior(LatentPosterior[P, HL]):
+    r"""Posterior shell for heteroscedastic likelihoods.
+
+    The posterior retains both the signal and noise priors; inference is delegated
+    to variational families and specialised objectives.
+    """
+
+    def __init__(
+        self,
+        prior: AbstractPrior[M, K],
+        likelihood: HL,
+        jitter: float = 1e-6,
+    ):
+        if likelihood.noise_prior is None:
+            raise ValueError("Heteroscedastic likelihoods require a noise_prior.")
+        super().__init__(prior=prior, likelihood=likelihood, jitter=jitter)
+        self.noise_prior = likelihood.noise_prior
+        self.noise_posterior = LatentPosterior(
+            prior=self.noise_prior, likelihood=likelihood, jitter=jitter
+        )
+
+
+class ChainedPosterior(HeteroscedasticPosterior[P, HL]):
+    r"""Posterior routed for heteroscedastic likelihoods using chained bounds."""
+
+    def __init__(
+        self,
+        prior: AbstractPrior[M, K],
+        likelihood: HL,
+        jitter: float = 1e-6,
+    ):
+        super().__init__(prior=prior, likelihood=likelihood, jitter=jitter)
+
+
 #######################
 # Utils
 #######################
@@ -854,6 +907,18 @@ def construct_posterior(  # noqa: F811
 ) -> NonConjugatePosterior[P, NGL]: ...


+@tp.overload
+def construct_posterior(  # noqa: F811
+    prior: P, likelihood: HeteroscedasticGaussian
+) -> HeteroscedasticPosterior[P, HeteroscedasticGaussian]: ...
+
+
+@tp.overload
+def construct_posterior(  # noqa: F811
+    prior: P, likelihood: AbstractHeteroscedasticLikelihood
+) -> ChainedPosterior[P, AbstractHeteroscedasticLikelihood]: ...
+
+
 def construct_posterior(prior, likelihood):  # noqa: F811
     r"""Utility function for constructing a posterior object from a prior and
     likelihood. The function will automatically select the correct posterior
@@ -873,6 +938,15 @@ def construct_posterior(prior, likelihood):  # noqa: F811
     if isinstance(likelihood, Gaussian):
         return ConjugatePosterior(prior=prior, likelihood=likelihood)

+    if (
+        isinstance(likelihood, HeteroscedasticGaussian)
+        and likelihood.supports_tight_bound()
+    ):
+        return HeteroscedasticPosterior(prior=prior, likelihood=likelihood)
+
+    if isinstance(likelihood, AbstractHeteroscedasticLikelihood):
+        return ChainedPosterior(prior=prior, likelihood=likelihood)
+
     return NonConjugatePosterior(prior=prior, likelihood=likelihood)


@@ -911,7 +985,10 @@ __all__ = [
     "AbstractPrior",
     "Prior",
     "AbstractPosterior",
+    "LatentPosterior",
     "ConjugatePosterior",
     "NonConjugatePosterior",
+    "HeteroscedasticPosterior",
+    "ChainedPosterior",
     "construct_posterior",
 ]
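Taken together, `construct_posterior` now dispatches in order: conjugate Gaussian, tight heteroscedastic bound, generic chained bound, and finally the non-conjugate fallback. A sketch of the routing (the prior construction is illustrative):

```python
from gpjax.gps import ChainedPosterior, HeteroscedasticPosterior, Prior, construct_posterior
from gpjax.kernels import RBF
from gpjax.likelihoods import HeteroscedasticGaussian
from gpjax.mean_functions import Zero

prior = Prior(mean_function=Zero(), kernel=RBF())
noise_prior = Prior(mean_function=Zero(), kernel=RBF())
likelihood = HeteroscedasticGaussian(num_datapoints=50, noise_prior=noise_prior)

# HeteroscedasticGaussian.supports_tight_bound() is True, so the tight route wins.
posterior = construct_posterior(prior, likelihood)
assert isinstance(posterior, HeteroscedasticPosterior)
assert not isinstance(posterior, ChainedPosterior)  # Chained is the generic fallback
```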
gpjax/likelihoods.py CHANGED
@@ -10,15 +10,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+from __future__ import annotations

 import abc
+from dataclasses import dataclass

 import beartype.typing as tp
 from flax import nnx
+import jax
 from jax import vmap
+import jax.nn as jnn
 import jax.numpy as jnp
 import jax.scipy as jsp
 from jaxtyping import Float
+import numpy as np
 import numpyro.distributions as npd

 from gpjax.distributions import GaussianDistribution
@@ -36,6 +41,20 @@ from gpjax.typing import (
 )


+@dataclass
+class NoiseMoments:
+    log_variance: Array
+    inv_variance: Array
+    variance: Array
+
+
+jax.tree_util.register_pytree_node(
+    NoiseMoments,
+    lambda x: ((x.log_variance, x.inv_variance, x.variance), None),
+    lambda _, x: NoiseMoments(*x),
+)
+
+
 class AbstractLikelihood(nnx.Module):
     r"""Abstract base class for likelihoods.

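Registering `NoiseMoments` as a pytree lets the three moment arrays flow through JAX transformations as ordinary leaves; a quick sketch:

```python
import jax
import jax.numpy as jnp
from gpjax.likelihoods import NoiseMoments

stats = NoiseMoments(
    log_variance=jnp.zeros(3), inv_variance=jnp.ones(3), variance=jnp.ones(3)
)
# tree_map recurses into the three registered leaves.
doubled = jax.tree_util.tree_map(lambda leaf: 2.0 * leaf, stats)
print(doubled.variance)  # [2. 2. 2.]
```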
@@ -103,6 +122,9 @@ class AbstractLikelihood(nnx.Module):
         y: Float[Array, "N D"],
         mean: Float[Array, "N D"],
         variance: Float[Array, "N D"],
+        mean_g: tp.Optional[Float[Array, "N D"]] = None,
+        variance_g: tp.Optional[Float[Array, "N D"]] = None,
+        **_: tp.Any,
     ) -> Float[Array, " N"]:
         r"""Compute the expected log likelihood.

@@ -116,6 +138,12 @@
             y (Float[Array, 'N D']): The observed response variable.
             mean (Float[Array, 'N D']): The variational mean.
             variance (Float[Array, 'N D']): The variational variance.
+            mean_g (Float[Array, 'N D']): Optional moments of the latent noise
+                process for heteroscedastic likelihoods.
+            variance_g (Float[Array, 'N D']): Optional moments of the latent noise
+                process for heteroscedastic likelihoods.
+            **_: Unused extra arguments for compatibility with specialised
+                likelihoods.

         Returns:
             ScalarFloat: The expected log likelihood.
@@ -126,6 +154,143 @@
         )


+class AbstractNoiseTransform(nnx.Module):
+    """Abstract base class for noise transformations."""
+
+    @abc.abstractmethod
+    def __call__(self, x: Float[Array, "..."]) -> Float[Array, "..."]:
+        """Transform the input noise signal."""
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    def moments(
+        self, mean: Float[Array, "..."], variance: Float[Array, "..."]
+    ) -> NoiseMoments:
+        """Compute the moments of the transformed noise signal."""
+        raise NotImplementedError
+
+
+class LogNormalTransform(AbstractNoiseTransform):
+    """Log-normal noise transformation."""
+
+    def __call__(self, x: Float[Array, "..."]) -> Float[Array, "..."]:
+        return jnp.exp(x)
+
+    def moments(
+        self, mean: Float[Array, "..."], variance: Float[Array, "..."]
+    ) -> NoiseMoments:
+        expected_variance = jnp.exp(mean + 0.5 * variance)
+        expected_log_variance = mean
+        expected_inv_variance = jnp.exp(-mean + 0.5 * variance)
+        return NoiseMoments(
+            log_variance=expected_log_variance,
+            inv_variance=expected_inv_variance,
+            variance=expected_variance,
+        )
+
+
+class SoftplusTransform(AbstractNoiseTransform):
+    """Softplus noise transformation."""
+
+    def __init__(self, num_points: int = 20):
+        self.num_points = num_points
+
+    def __call__(self, x: Float[Array, "..."]) -> Float[Array, "..."]:
+        return jnn.softplus(x)
+
+    def moments(
+        self, mean: Float[Array, "..."], variance: Float[Array, "..."]
+    ) -> NoiseMoments:
+        quad_x, quad_w = np.polynomial.hermite.hermgauss(self.num_points)
+        quad_w = jnp.asarray(quad_w / jnp.sqrt(jnp.pi))
+        quad_x = jnp.asarray(quad_x)
+
+        std = jnp.sqrt(variance)
+        samples = mean[..., None] + jnp.sqrt(2.0) * std[..., None] * quad_x
+        sigma2 = self(samples)
+        log_sigma2 = jnp.log(sigma2)
+        inv_sigma2 = 1.0 / sigma2
+
+        expected_variance = jnp.sum(sigma2 * quad_w, axis=-1)
+        expected_log_variance = jnp.sum(log_sigma2 * quad_w, axis=-1)
+        expected_inv_variance = jnp.sum(inv_sigma2 * quad_w, axis=-1)
+
+        return NoiseMoments(
+            log_variance=expected_log_variance,
+            inv_variance=expected_inv_variance,
+            variance=expected_variance,
+        )
+
+
+class AbstractHeteroscedasticLikelihood(AbstractLikelihood):
+    r"""Base class for heteroscedastic likelihoods with latent noise processes."""
+
+    def __init__(
+        self,
+        num_datapoints: int,
+        noise_prior,
+        noise_transform: tp.Union[
+            AbstractNoiseTransform,
+            tp.Callable[[Float[Array, "..."]], Float[Array, "..."]],
+        ] = SoftplusTransform(),
+        integrator: AbstractIntegrator = GHQuadratureIntegrator(),
+    ):
+        self.noise_prior = noise_prior
+
+        if isinstance(noise_transform, AbstractNoiseTransform):
+            self.noise_transform = noise_transform
+        else:
+            transform_name = getattr(noise_transform, "__name__", "")
+            if noise_transform is jnp.exp or transform_name == "exp":
+                self.noise_transform = LogNormalTransform()
+            else:
+                # Fall back to SoftplusTransform for softplus or unknown callables
+                # (legacy behaviour used quadrature); custom transforms should
+                # implement AbstractNoiseTransform instead.
+                self.noise_transform = SoftplusTransform()
+
+        super().__init__(num_datapoints=num_datapoints, integrator=integrator)
+
+    def __call__(
+        self,
+        dist: tp.Union[npd.MultivariateNormal, GaussianDistribution],
+        noise_dist: tp.Optional[
+            tp.Union[npd.MultivariateNormal, GaussianDistribution]
+        ] = None,
+    ) -> npd.Distribution:
+        return self.predict(dist, noise_dist)
+
+    def supports_tight_bound(self) -> bool:
+        """Return whether the tighter bound from Lázaro-Gredilla & Titsias (2011)
+        is applicable."""
+        return False
+
+    def noise_statistics(
+        self, mean: Float[Array, "N D"], variance: Float[Array, "N D"]
+    ) -> NoiseMoments:
+        r"""Moment matching of the transformed noise process.
+
+        Args:
+            mean: Mean of the latent noise GP.
+            variance: Variance of the latent noise GP.
+
+        Returns:
+            NoiseMoments: Expected log variance, inverse variance, and variance.
+        """
+        return self.noise_transform.moments(mean, variance)
+
+    def expected_log_likelihood(
+        self,
+        y: Float[Array, "N D"],
+        mean: Float[Array, "N D"],
+        variance: Float[Array, "N D"],
+        mean_g: tp.Optional[Float[Array, "N D"]] = None,
+        variance_g: tp.Optional[Float[Array, "N D"]] = None,
+        **kwargs: tp.Any,
+    ) -> Float[Array, " N"]:
+        raise NotImplementedError
+
+
 class Gaussian(AbstractLikelihood):
     r"""Gaussian likelihood object."""

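Two properties of these transforms are worth noting. `LogNormalTransform.moments` is exact: for g ~ N(μ, v), E[eᵍ] = exp(μ + v/2) and E[e⁻ᵍ] = exp(−μ + v/2). `SoftplusTransform.moments` approximates the same expectations with Gauss–Hermite quadrature after the change of variables g = μ + √(2v)·x. A sanity check of both, with Monte Carlo as the reference (a sketch):

```python
import jax.numpy as jnp
import numpy as np
from gpjax.likelihoods import LogNormalTransform, SoftplusTransform

mu, v = jnp.array([0.3]), jnp.array([0.5])
g = np.random.default_rng(0).normal(0.3, np.sqrt(0.5), 500_000)

exact = LogNormalTransform().moments(mu, v)
print(exact.variance[0], np.exp(g).mean())           # E[exp(g)] = exp(mu + v/2)

quad = SoftplusTransform(num_points=20).moments(mu, v)
print(quad.variance[0], np.log1p(np.exp(g)).mean())  # E[softplus(g)]
```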
@@ -186,6 +351,69 @@ class Gaussian(AbstractLikelihood):
         return npd.MultivariateNormal(dist.mean, noisy_cov)


+class HeteroscedasticGaussian(AbstractHeteroscedasticLikelihood):
+    def predict(
+        self,
+        dist: tp.Union[npd.MultivariateNormal, GaussianDistribution],
+        noise_dist: tp.Optional[
+            tp.Union[npd.MultivariateNormal, GaussianDistribution]
+        ] = None,
+    ) -> npd.MultivariateNormal:
+        if noise_dist is None:
+            raise ValueError(
+                "noise_dist must be provided for heteroscedastic prediction."
+            )
+
+        n_data = dist.event_shape[0]
+        noise_mean = noise_dist.mean
+        noise_variance = jnp.diag(noise_dist.covariance_matrix)
+        noise_stats = self.noise_statistics(
+            noise_mean[..., None], noise_variance[..., None]
+        )
+
+        cov = dist.covariance_matrix
+        noisy_cov = cov.at[jnp.diag_indices(n_data)].add(noise_stats.variance.squeeze())
+
+        return npd.MultivariateNormal(dist.mean, noisy_cov)
+
+    def link_function(self, f: Float[Array, "..."]) -> npd.Normal:
+        sigma2 = self.noise_transform(jnp.zeros_like(f))
+        return npd.Normal(loc=f, scale=jnp.sqrt(sigma2))
+
+    def expected_log_likelihood(
+        self,
+        y: Float[Array, "N D"],
+        mean: Float[Array, "N D"],
+        variance: Float[Array, "N D"],
+        mean_g: tp.Optional[Float[Array, "N D"]] = None,
+        variance_g: tp.Optional[Float[Array, "N D"]] = None,
+        noise_stats: tp.Optional[NoiseMoments] = None,
+        return_parts: bool = False,
+        **_: tp.Any,
+    ) -> tp.Union[Float[Array, " N"], tuple[Float[Array, " N"], NoiseMoments]]:
+        if mean_g is None or variance_g is None:
+            raise ValueError(
+                "mean_g and variance_g must be provided for heteroscedastic models."
+            )
+
+        if noise_stats is None:
+            noise_stats = self.noise_statistics(mean_g, variance_g)
+        sq_error = jnp.square(y - mean)
+        log2pi = jnp.log(2.0 * jnp.pi)
+        expected = -0.5 * (
+            log2pi
+            + noise_stats.log_variance
+            + (sq_error + variance) * noise_stats.inv_variance
+        )
+        expected_sum = jnp.sum(expected, axis=1)
+        if return_parts:
+            return expected_sum, noise_stats
+        return expected_sum
+
+    def supports_tight_bound(self) -> bool:
+        return True
+
+
 class Bernoulli(AbstractLikelihood):
     def link_function(self, f: Float[Array, "..."]) -> npd.BernoulliProbs:
         r"""The probit link function of the Bernoulli likelihood.
@@ -268,7 +496,13 @@ __all__ = [
     "AbstractLikelihood",
     "NonGaussian",
     "Gaussian",
+    "AbstractHeteroscedasticLikelihood",
+    "HeteroscedasticGaussian",
     "Bernoulli",
     "Poisson",
     "inv_probit",
+    "NoiseMoments",
+    "AbstractNoiseTransform",
+    "LogNormalTransform",
+    "SoftplusTransform",
 ]
gpjax/mean_functions.py CHANGED
@@ -147,9 +147,9 @@ class Constant(AbstractMeanFunction):
             Float[Array, "1"]: The evaluated mean function.
         """
         if isinstance(self.constant, Parameter):
-            return jnp.ones((x.shape[0], 1)) * self.constant.value
+            return jnp.ones((x.shape[0], 1), dtype=x.dtype) * self.constant.value
         else:
-            return jnp.ones((x.shape[0], 1)) * self.constant
+            return jnp.ones((x.shape[0], 1), dtype=x.dtype) * self.constant


 class Zero(Constant):
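The `dtype=x.dtype` change makes the constant mean function follow the input's dtype instead of JAX's default, which matters when mixing precisions under x64. A sketch of the behaviour (the `constant` keyword is taken from this class's constructor):

```python
import jax

jax.config.update("jax_enable_x64", True)  # default dtype becomes float64

import jax.numpy as jnp
from gpjax.mean_functions import Constant

x32 = jnp.zeros((5, 2), dtype=jnp.float32)
mean_fn = Constant(constant=jnp.array([1.0], dtype=jnp.float32))
# The ones-array is now created with x.dtype, avoiding silent promotion to float64.
print(mean_fn(x32).dtype)  # float32
```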
gpjax/objectives.py CHANGED
@@ -14,6 +14,9 @@ from gpjax.gps import (
     ConjugatePosterior,
     NonConjugatePosterior,
 )
+from gpjax.likelihoods import (
+    AbstractHeteroscedasticLikelihood,
+)
 from gpjax.linalg import (
     Dense,
     lower_cholesky,
@@ -25,9 +28,13 @@ from gpjax.typing import (
     Array,
     ScalarFloat,
 )
-from gpjax.variational_families import AbstractVariationalFamily
+from gpjax.variational_families import (
+    AbstractVariationalFamily,
+    HeteroscedasticVariationalFamily,
+)

 VF = TypeVar("VF", bound=AbstractVariationalFamily)
+HVF = TypeVar("HVF", bound=HeteroscedasticVariationalFamily)


 Objective = tpe.Callable[[nnx.Module, Dataset], ScalarFloat]
@@ -414,3 +421,51 @@ def collapsed_elbo(variational_family: VF, data: Dataset) -> ScalarFloat:

     # log N(y; μx, Iσ² + KxzKzz⁻¹Kzx) - 1/2σ² tr(Kxx - KxzKzz⁻¹Kzx)
     return (two_log_prob - two_trace).squeeze() / 2.0
+
+
+def heteroscedastic_elbo_conjugate(
+    variational_family: HVF, data: Dataset
+) -> ScalarFloat:
+    r"""Tight bound from Lázaro-Gredilla & Titsias (2011) for heteroscedastic
+    Gaussian likelihoods."""
+    likelihood = variational_family.posterior.likelihood
+    mean_f, var_f, mean_g, var_g = variational_family.predict(data.X)
+
+    expected_ll, _ = likelihood.expected_log_likelihood(
+        data.y,
+        mean_f,
+        var_f,
+        mean_g=mean_g,
+        variance_g=var_g,
+        return_parts=True,
+    )
+
+    scale = likelihood.num_datapoints / data.n
+    return scale * jnp.sum(expected_ll) - variational_family.prior_kl()
+
+
+def heteroscedastic_elbo_chained(variational_family: HVF, data: Dataset) -> ScalarFloat:
+    r"""Generic chained bound for heteroscedastic likelihoods."""
+    likelihood: AbstractHeteroscedasticLikelihood = (
+        variational_family.posterior.likelihood
+    )
+    mean_f, var_f, mean_g, var_g = variational_family.predict(data.X)
+    noise_stats = likelihood.noise_statistics(mean_g, var_g)
+
+    expected_ll = likelihood.expected_log_likelihood(
+        data.y,
+        mean_f,
+        var_f,
+        mean_g=mean_g,
+        variance_g=var_g,
+        noise_stats=noise_stats,
+    )
+
+    scale = likelihood.num_datapoints / data.n
+    return scale * jnp.sum(expected_ll) - variational_family.prior_kl()
+
+
+def heteroscedastic_elbo(variational_family: HVF, data: Dataset) -> ScalarFloat:
+    likelihood = variational_family.posterior.likelihood
+    if likelihood.supports_tight_bound():
+        return heteroscedastic_elbo_conjugate(variational_family, data)
+    return heteroscedastic_elbo_chained(variational_family, data)
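In practice `heteroscedastic_elbo` is the single entry point: it evaluates the tight conjugate bound whenever the likelihood reports `supports_tight_bound()`, and the chained bound otherwise. A training sketch, assuming `q` is a `HeteroscedasticVariationalFamily` and `D` a `Dataset` (see the variational-families section below), and that `gpjax.fit.fit` keeps its usual signature:

```python
import optax
from gpjax.fit import fit
from gpjax.objectives import heteroscedastic_elbo

# fit() minimises, so we negate the bound.
q_opt, history = fit(
    model=q,
    objective=lambda model, data: -heteroscedastic_elbo(model, data),
    train_data=D,
    optim=optax.adam(1e-2),
    num_iters=1_000,
)
```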
gpjax/parameters.py CHANGED
@@ -1,6 +1,7 @@
 import typing as tp

 from flax import nnx
+import jax
 from jax.experimental import checkify
 import jax.numpy as jnp
 import jax.tree_util as jtu
@@ -77,7 +78,14 @@ class Parameter(nnx.Variable[T]):
         _check_is_arraylike(value)

         super().__init__(value=jnp.asarray(value), **kwargs)
-        self.tag = tag
+
+        # nnx.Variable metadata must be set via set_metadata (direct setattr is disallowed).
+        self.set_metadata(tag=tag)
+
+    @property
+    def tag(self) -> ParameterTag:
+        """Return the parameter's constraint tag."""
+        return self.metadata.get("tag", "real")


 class NonNegativeReal(Parameter[T]):
@@ -155,7 +163,7 @@ def _check_is_arraylike(value: T) -> None:
     Raises:
         TypeError: If the value is not array-like.
     """
-    if not isinstance(value, (ArrayLike, list)):
+    if not isinstance(value, (jax.Array, ArrayLike, list)):
         raise TypeError(
             f"Expected parameter value to be an array-like type. Got {value}."
         )
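A small behavioural note on the `tag` change: the value now lives in the variable's nnx metadata and is surfaced through a read-only property, so existing call sites keep working. A sketch, assuming the two-argument `Parameter(value, tag)` constructor shown above:

```python
import jax.numpy as jnp
from gpjax.parameters import Parameter

p = Parameter(jnp.array(1.0), tag="positive")
print(p.tag)  # "positive", read back from nnx metadata via the new property
```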
gpjax/variational_families.py CHANGED
@@ -14,6 +14,7 @@
 # ==============================================================================

 import abc
+from dataclasses import dataclass

 import beartype.typing as tp
 from flax import nnx
@@ -29,9 +30,12 @@ from gpjax.distributions import GaussianDistribution
 from gpjax.gps import (
     AbstractPosterior,
     AbstractPrior,
+    ChainedPosterior,
+    HeteroscedasticPosterior,
 )
 from gpjax.kernels.base import AbstractKernel
 from gpjax.likelihoods import (
+    AbstractHeteroscedasticLikelihood,
     Gaussian,
     NonGaussian,
 )
@@ -59,8 +63,10 @@ M = tp.TypeVar("M", bound=AbstractMeanFunction)
 L = tp.TypeVar("L", Gaussian, NonGaussian)
 NGL = tp.TypeVar("NGL", bound=NonGaussian)
 GL = tp.TypeVar("GL", bound=Gaussian)
+HL = tp.TypeVar("HL", bound=AbstractHeteroscedasticLikelihood)
 P = tp.TypeVar("P", bound=AbstractPrior)
 PP = tp.TypeVar("PP", bound=AbstractPosterior)
+HP = tp.TypeVar("HP", HeteroscedasticPosterior, ChainedPosterior)


 class AbstractVariationalFamily(nnx.Module, tp.Generic[L]):
@@ -870,6 +876,126 @@ class CollapsedVariationalGaussian(AbstractVariationalGaussian[GL]):
         )


+@dataclass
+class VariationalGaussianInit:
+    """Initialization parameters for a variational Gaussian distribution."""
+
+    inducing_inputs: tp.Union[Int[Array, "N D"], Float[Array, "N D"]]
+    variational_mean: tp.Union[Float[Array, "N 1"], None] = None
+    variational_root_covariance: tp.Union[Float[Array, "N N"], None] = None
+
+
+class HeteroscedasticPrediction(tp.NamedTuple):
+    """Mean and variance of the signal and noise latent processes."""
+
+    mean_f: Float[Array, "N 1"]
+    variance_f: Float[Array, "N 1"]
+    mean_g: Float[Array, "N 1"]
+    variance_g: Float[Array, "N 1"]
+
+
+class HeteroscedasticVariationalFamily(AbstractVariationalFamily[HL]):
+    r"""Variational family for two independent latent processes f and g."""
+
+    def __init__(
+        self,
+        posterior: HP,
+        inducing_inputs: tp.Union[Int[Array, "N D"], Float[Array, "N D"]] = None,
+        inducing_inputs_g: tp.Union[
+            Int[Array, "M D"], Float[Array, "M D"], None
+        ] = None,
+        variational_mean_f: tp.Union[Float[Array, "N 1"], None] = None,
+        variational_root_covariance_f: tp.Union[Float[Array, "N N"], None] = None,
+        variational_mean_g: tp.Union[Float[Array, "M 1"], None] = None,
+        variational_root_covariance_g: tp.Union[Float[Array, "M M"], None] = None,
+        jitter: ScalarFloat = 1e-6,
+        signal_init: tp.Optional[VariationalGaussianInit] = None,
+        noise_init: tp.Optional[VariationalGaussianInit] = None,
+    ):
+        self.jitter = jitter
+
+        if signal_init is not None:
+            self.signal_variational = VariationalGaussian(
+                posterior=posterior,
+                inducing_inputs=signal_init.inducing_inputs,
+                variational_mean=signal_init.variational_mean,
+                variational_root_covariance=signal_init.variational_root_covariance,
+                jitter=jitter,
+            )
+        elif inducing_inputs is not None:
+            self.signal_variational = VariationalGaussian(
+                posterior=posterior,
+                inducing_inputs=inducing_inputs,
+                variational_mean=variational_mean_f,
+                variational_root_covariance=variational_root_covariance_f,
+                jitter=jitter,
+            )
+        else:
+            raise ValueError("Either signal_init or inducing_inputs must be provided.")
+
+        if noise_init is not None:
+            self.noise_variational = VariationalGaussian(
+                posterior=posterior.noise_posterior,
+                inducing_inputs=noise_init.inducing_inputs,
+                variational_mean=noise_init.variational_mean,
+                variational_root_covariance=noise_init.variational_root_covariance,
+                jitter=jitter,
+            )
+        else:
+            noise_inducing = (
+                inducing_inputs if inducing_inputs_g is None else inducing_inputs_g
+            )
+            if noise_inducing is None and signal_init is not None:
+                noise_inducing = signal_init.inducing_inputs
+
+            if noise_inducing is None:
+                raise ValueError(
+                    "Could not determine inducing inputs for noise process."
+                )
+
+            self.noise_variational = VariationalGaussian(
+                posterior=posterior.noise_posterior,
+                inducing_inputs=noise_inducing,
+                variational_mean=variational_mean_g,
+                variational_root_covariance=variational_root_covariance_g,
+                jitter=jitter,
+            )
+        super().__init__(posterior)
+
+    def prior_kl(self) -> ScalarFloat:
+        return self.signal_variational.prior_kl() + self.noise_variational.prior_kl()
+
+    def predict(
+        self, test_inputs: tp.Union[Int[Array, "N D"], Float[Array, "N D"]]
+    ) -> HeteroscedasticPrediction:
+        dist_f = self.signal_variational.predict(test_inputs)
+        dist_g = self.noise_variational.predict(test_inputs)
+
+        mean_f = dist_f.mean[:, None] if dist_f.mean.ndim == 1 else dist_f.mean
+        var_f = (
+            dist_f.variance[:, None] if dist_f.variance.ndim == 1 else dist_f.variance
+        )
+        mean_g = dist_g.mean[:, None] if dist_g.mean.ndim == 1 else dist_g.mean
+        var_g = (
+            dist_g.variance[:, None] if dist_g.variance.ndim == 1 else dist_g.variance
+        )
+
+        return HeteroscedasticPrediction(
+            mean_f=mean_f,
+            variance_f=var_f,
+            mean_g=mean_g,
+            variance_g=var_g,
+        )
+
+    def predict_latents(
+        self, test_inputs: tp.Union[Int[Array, "N D"], Float[Array, "N D"]]
+    ) -> tuple[GaussianDistribution, GaussianDistribution]:
+        return (
+            self.signal_variational.predict(test_inputs),
+            self.noise_variational.predict(test_inputs),
+        )
+
+
 __all__ = [
     "AbstractVariationalFamily",
     "AbstractVariationalGaussian",
@@ -879,4 +1005,7 @@ __all__ = [
     "NaturalVariationalGaussian",
     "ExpectationVariationalGaussian",
     "CollapsedVariationalGaussian",
+    "HeteroscedasticVariationalFamily",
+    "VariationalGaussianInit",
+    "HeteroscedasticPrediction",
 ]
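Putting the pieces together, an end-to-end sketch of the new heteroscedastic workflow (data, kernels, and inducing-point choices are illustrative):

```python
import jax
import jax.numpy as jnp
from gpjax.dataset import Dataset
from gpjax.gps import Prior, construct_posterior
from gpjax.kernels import RBF
from gpjax.likelihoods import HeteroscedasticGaussian
from gpjax.mean_functions import Zero
from gpjax.objectives import heteroscedastic_elbo
from gpjax.variational_families import HeteroscedasticVariationalFamily

x = jnp.linspace(0.0, 1.0, 50)[:, None]
noise = 0.1 * (1.0 + x) * jax.random.normal(jax.random.PRNGKey(0), x.shape)
D = Dataset(X=x, y=jnp.sin(6.0 * x) + noise)  # noise scale grows with x

signal_prior = Prior(mean_function=Zero(), kernel=RBF())
noise_prior = Prior(mean_function=Zero(), kernel=RBF())
likelihood = HeteroscedasticGaussian(num_datapoints=D.n, noise_prior=noise_prior)
posterior = construct_posterior(signal_prior, likelihood)  # HeteroscedasticPosterior

# One variational Gaussian per latent process; here they share inducing points.
q = HeteroscedasticVariationalFamily(posterior=posterior, inducing_inputs=x[::5])
elbo = heteroscedastic_elbo(q, D)  # tight bound, since supports_tight_bound() is True
```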
gpjax-0.13.3.dist-info/METADATA → gpjax-0.13.5.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: gpjax
-Version: 0.13.3
+Version: 0.13.5
 Summary: Gaussian processes in JAX.
 Project-URL: Documentation, https://docs.jaxgaussianprocesses.com/
 Project-URL: Issues, https://github.com/thomaspinder/GPJax/issues
@@ -25,6 +25,7 @@ Requires-Dist: jaxtyping>0.2.10
 Requires-Dist: numpy>=2.0.0
 Requires-Dist: numpyro
 Requires-Dist: optax>0.2.1
+Requires-Dist: tensorstore!=0.1.76; sys_platform == 'darwin'
 Requires-Dist: tqdm>4.66.2
 Provides-Extra: dev
 Requires-Dist: absolufy-imports>=0.3.1; extra == 'dev'
@@ -59,7 +60,7 @@ Requires-Dist: mkdocs-jupyter>=0.24.3; extra == 'docs'
 Requires-Dist: mkdocs-literate-nav>=0.6.0; extra == 'docs'
 Requires-Dist: mkdocs-material>=9.5.12; extra == 'docs'
 Requires-Dist: mkdocs>=1.5.3; extra == 'docs'
-Requires-Dist: mkdocstrings[python]<0.31.0; extra == 'docs'
+Requires-Dist: mkdocstrings[python]<1.1.0; extra == 'docs'
 Requires-Dist: nbconvert>=7.16.2; extra == 'docs'
 Requires-Dist: networkx>=3.0; extra == 'docs'
 Requires-Dist: pandas>=1.5.3; extra == 'docs'
@@ -141,7 +142,7 @@ GPJax into the package it is today.
 > - [**Laplace Approximation**](https://docs.jaxgaussianprocesses.com/_examples/classification/#laplace-approximation)
 > - [**Inference on Non-Euclidean Spaces**](https://docs.jaxgaussianprocesses.com/_examples/constructing_new_kernels/#custom-kernel)
 > - [**Inference on Graphs**](https://docs.jaxgaussianprocesses.com/_examples/graph_kernels/)
-> - [**Pathwise Sampling**](https://docs.jaxgaussianprocesses.com/_examples/spatial/)
+> - [**Heteroscedastic Inference**](https://docs.jaxgaussianprocesses.com/_examples/heteroscedastic_inference/)
 > - [**Learning Gaussian Process Barycentres**](https://docs.jaxgaussianprocesses.com/_examples/barycentres/)
 > - [**Deep Kernel Regression**](https://docs.jaxgaussianprocesses.com/_examples/deep_kernels/)
 > - [**Poisson Regression**](https://docs.jaxgaussianprocesses.com/_examples/poisson/)
gpjax-0.13.3.dist-info/RECORD → gpjax-0.13.5.dist-info/RECORD CHANGED
@@ -1,18 +1,18 @@
-gpjax/__init__.py,sha256=8EULPS_vtq4TeN6MjdLVHLlTVVyQADOoHtuVoRN8z5Y,1625
-gpjax/citation.py,sha256=pwFS8h1J-LE5ieRS0zDyuwhmQHNxkFHYE7iSMlVNmQc,3928
+gpjax/__init__.py,sha256=8guaWVHiNXlZ9m_XvjPo8nherGpCpGWCCxB5fenq_Gw,1625
+gpjax/citation.py,sha256=9L-OVIDCrwjM3TeWinC5UHS3Lxn_9sVIsj_H9i08BT8,4428
 gpjax/dataset.py,sha256=Ef5JGrl4jJS1mQmL3JdO0fdqbVmflT_Cu5VrlpYdJY4,4071
 gpjax/distributions.py,sha256=iKmeQ_NN2CIjRiuOeJlwEGASzGROi4ZCerVi1uY7zRM,7758
 gpjax/fit.py,sha256=tOXmM3l5-N3Jlnq8MOVkSpRj-0fRWOy2t9GxQRyUqxY,15318
-gpjax/gps.py,sha256=NcPXkkx0kXrSBTUje4QR6PS0DGrD7p5c-DyiOATwUz8,35338
+gpjax/gps.py,sha256=83rki591ma3B_yr1Bq1OUlAH75sThpg-x00ZZXeqwN0,37906
 gpjax/integrators.py,sha256=eyJPqWNPKj6pKP5da0fEj4HW7BVyevqeGrurEuy_XPw,5694
-gpjax/likelihoods.py,sha256=xwnSQpn6Aa-FPpEoDn_3xpBdPQAmHP97jP-9iJmT4G8,9087
-gpjax/mean_functions.py,sha256=Aq09I5h5DZe1WokRxLa0Mpj8U0kJxpQ8CmdJJyCjNOc,6541
+gpjax/likelihoods.py,sha256=SzGh3uAkULVMfAVA3QpKQwWp-XAcx46B7QhhgKymn6A,17381
+gpjax/mean_functions.py,sha256=cDv3X5E4-to8pgzHpbKhoMKDLdWg3WhEq1W21kvbDDg,6571
 gpjax/numpyro_extras.py,sha256=-vWJ7SpZVNhSdCjjrlxIkovMFrM1IzpsMJK3B4LioGE,3411
-gpjax/objectives.py,sha256=GvKbDIPqYjsc9FpiTccmZwRdHr2lCykgfxI9BX9I_GA,15362
-gpjax/parameters.py,sha256=hnyIKr6uIzd7Kb3KZC9WowR88ruQwUvdcto3cx2ZDv4,6756
+gpjax/objectives.py,sha256=QScq5e77-jugtt2wKZNLgk3VtkqSDddJyYBs-hQ7dZc,17206
+gpjax/parameters.py,sha256=iGab7IMAANuhFm8ojcK5x5fJNzUOdflr2h2pBaVLe9Q,7035
 gpjax/scan.py,sha256=Z_V1yd76dL4B7_rnJnr_fohom6xzN_WRYxlTAAdqfa0,5357
 gpjax/typing.py,sha256=M3CvWsYtZ3PFUvBvvbRNjpwerNII0w4yGuP0I-sLeYI,1705
-gpjax/variational_families.py,sha256=x4VnUh9GiW77ijnDEwrYzH0aWFOW6NLPQXEp--I9R-g,31566
+gpjax/variational_families.py,sha256=J7L1lh0McYSgr1BKq5vsVjtfu2XeyOCn-I_AHQym-xA,36723
 gpjax/kernels/__init__.py,sha256=GFiku1s8KPPzvfwDYzEjENPQReywvJ29M9Kl2JjkcoU,1893
 gpjax/kernels/base.py,sha256=gov60BgXOQK0PxbKQRVgjU55kCVZYAkDZrAInODvugo,11549
 gpjax/kernels/approximations/__init__.py,sha256=bK9HlGd-PZeGrqtG5RpXxUTXNUrZTgfjH1dP626yNMA,68
@@ -46,7 +46,7 @@ gpjax/linalg/__init__.py,sha256=F8mxk_9Zc2nFd7Q-unjJ50_6rXEKzZj572WsU_jUKqI,547
 gpjax/linalg/operations.py,sha256=xvhOy5P4FmUCPWjIVNdg1yDXaoFQ48anFUfR-Tnfr6k,6480
 gpjax/linalg/operators.py,sha256=arxRGwcoAy_RqUYqBpZ3XG6OXbjShUl7m8sTpg85npE,11608
 gpjax/linalg/utils.py,sha256=fKV8G_iKZVhNkNvN20D_dQEi93-8xosGbXBP-v7UEyo,2020
-gpjax-0.13.3.dist-info/METADATA,sha256=1Au0xXIXWvS64SPLP8R_mPQ6nuUMayx2qdNLD-HdhW8,10296
-gpjax-0.13.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-gpjax-0.13.3.dist-info/licenses/LICENSE.txt,sha256=3umwi0h8wmKXOZO8XwRBwSl3vJt2hpWKEqSrSXLR7-I,1084
-gpjax-0.13.3.dist-info/RECORD,,
+gpjax-0.13.5.dist-info/METADATA,sha256=qJ8ZeG1TGMX6a_2WQnNSwTl2JBXVp8107QqY9f1XWyE,10382
+gpjax-0.13.5.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+gpjax-0.13.5.dist-info/licenses/LICENSE.txt,sha256=3umwi0h8wmKXOZO8XwRBwSl3vJt2hpWKEqSrSXLR7-I,1084
+gpjax-0.13.5.dist-info/RECORD,,
gpjax-0.13.3.dist-info/WHEEL → gpjax-0.13.5.dist-info/WHEEL CHANGED
@@ -1,4 +1,4 @@
 Wheel-Version: 1.0
-Generator: hatchling 1.27.0
+Generator: hatchling 1.28.0
 Root-Is-Purelib: true
 Tag: py3-none-any