gpjax 0.9.2__py3-none-any.whl → 0.9.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
gpjax/__init__.py CHANGED
@@ -40,7 +40,7 @@ __license__ = "MIT"
40
40
  __description__ = "Didactic Gaussian processes in JAX"
41
41
  __url__ = "https://github.com/JaxGaussianProcesses/GPJax"
42
42
  __contributors__ = "https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors"
43
- __version__ = "0.9.2"
43
+ __version__ = "0.9.4"
44
44
 
45
45
  __all__ = [
46
46
  "base",
@@ -32,7 +32,7 @@ from gpjax.typing import (
32
32
  class PoissonTestFunction:
33
33
  """
34
34
  Test function for GPs utilising the Poisson likelihood. Function taken from
35
- https://docs.jaxgaussianprocesses.com/examples/poisson/#dataset.
35
+ https://docs.jaxgaussianprocesses.com/_examples/poisson/#dataset.
36
36
 
37
37
  Attributes:
38
38
  search_space (ContinuousSearchSpace): Search space for the function.
@@ -77,7 +77,7 @@ class PoissonTestFunction:
77
77
  def evaluate(self, x: Float[Array, "N 1"]) -> Int[Array, "N 1"]:
78
78
  """
79
79
  Evaluate the test function at a set of points. Function taken from
80
- https://docs.jaxgaussianprocesses.com/examples/poisson/#dataset.
80
+ https://docs.jaxgaussianprocesses.com/_examples/poisson/#dataset.
81
81
 
82
82
  Args:
83
83
  x (Float[Array, 'N D']): Points to evaluate the test function at.
gpjax/gps.py CHANGED
@@ -652,7 +652,8 @@ class NonConjugatePosterior(AbstractPosterior[P, NGL]):
652
652
  """
653
653
  super().__init__(prior=prior, likelihood=likelihood, jitter=jitter)
654
654
 
655
- latent = latent or jr.normal(key, shape=(self.likelihood.num_datapoints, 1))
655
+ if latent is None:
656
+ latent = jr.normal(key, shape=(self.likelihood.num_datapoints, 1))
656
657
 
657
658
  # TODO: static or intermediate?
658
659
  self.latent = latent if isinstance(latent, Parameter) else Real(latent)
gpjax/likelihoods.py CHANGED
@@ -28,7 +28,6 @@ from gpjax.integrators import (
28
28
  GHQuadratureIntegrator,
29
29
  )
30
30
  from gpjax.parameters import (
31
- Parameter,
32
31
  PositiveReal,
33
32
  Static,
34
33
  )
@@ -152,10 +151,9 @@ class Gaussian(AbstractLikelihood):
152
151
  likelihoods. Must be an instance of `AbstractIntegrator`. For the Gaussian likelihood, this defaults to
153
152
  the `AnalyticalGaussianIntegrator`, as the expected log likelihood can be computed analytically.
154
153
  """
155
- if isinstance(obs_stddev, Parameter):
156
- self.obs_stddev = obs_stddev
157
- else:
158
- self.obs_stddev = PositiveReal(jnp.asarray(obs_stddev))
154
+ if not isinstance(obs_stddev, (PositiveReal, Static)):
155
+ obs_stddev = PositiveReal(jnp.asarray(obs_stddev))
156
+ self.obs_stddev = obs_stddev
159
157
 
160
158
  super().__init__(num_datapoints, integrator)
161
159
 
gpjax/scan.py CHANGED
@@ -22,7 +22,6 @@ from beartype.typing import (
22
22
  )
23
23
  import jax
24
24
  from jax import lax
25
- from jax.experimental import host_callback as hcb
26
25
  import jax.numpy as jnp
27
26
  import jax.tree_util as jtu
28
27
  from jaxtyping import (
@@ -54,7 +53,8 @@ def _callback(cond: ScalarBool, func: Callable, *args: Any) -> None:
54
53
 
55
54
  def _do_callback(_) -> int:
56
55
  """Perform the callback."""
57
- return hcb.id_tap(func, *args, result=_dummy_result)
56
+ jax.debug.callback(func, *args)
57
+ return _dummy_result
58
58
 
59
59
  def _not_callback(_) -> int:
60
60
  """Do nothing."""
@@ -113,19 +113,19 @@ def vscan(
113
113
  _progress_bar = trange(_length)
114
114
  _progress_bar.set_description("Compiling...", refresh=True)
115
115
 
116
- def _set_running(args: Any, transform: Any) -> None:
116
+ def _set_running(*args: Any) -> None:
117
117
  """Set the tqdm progress bar to running."""
118
118
  _progress_bar.set_description("Running", refresh=False)
119
119
 
120
- def _update_tqdm(args: Any, transform: Any) -> None:
120
+ def _update_tqdm(*args: Any) -> None:
121
121
  """Update the tqdm progress bar with the latest objective value."""
122
122
  _value, _iter_num = args
123
- _progress_bar.update(_iter_num)
123
+ _progress_bar.update(_iter_num.item())
124
124
 
125
125
  if log_value and _value is not None:
126
126
  _progress_bar.set_postfix({"Value": f"{_value: .2f}"})
127
127
 
128
- def _close_tqdm(args: Any, transform: Any) -> None:
128
+ def _close_tqdm(*args: Any) -> None:
129
129
  """Close the tqdm progress bar."""
130
130
  _progress_bar.close()
131
131
 
@@ -145,16 +145,16 @@ def vscan(
145
145
  _is_last: bool = iter_num == _length - 1
146
146
 
147
147
  # Update progress bar, if first of log_rate.
148
- _callback(_is_first, _set_running, (y, log_rate))
148
+ _callback(_is_first, _set_running)
149
149
 
150
150
  # Update progress bar, if multiple of log_rate.
151
- _callback(_is_multiple, _update_tqdm, (y, log_rate))
151
+ _callback(_is_multiple, _update_tqdm, y, log_rate)
152
152
 
153
153
  # Update progress bar, if remainder.
154
- _callback(_is_remainder, _update_tqdm, (y, _remainder))
154
+ _callback(_is_remainder, _update_tqdm, y, _remainder)
155
155
 
156
156
  # Close progress bar, if last iteration.
157
- _callback(_is_last, _close_tqdm, (y, None))
157
+ _callback(_is_last, _close_tqdm)
158
158
 
159
159
  return carry, y
160
160
 
@@ -108,10 +108,17 @@ class AbstractVariationalGaussian(AbstractVariationalFamily[L]):
108
108
  def __init__(
109
109
  self,
110
110
  posterior: AbstractPosterior[P, L],
111
- inducing_inputs: Float[Array, "N D"],
111
+ inducing_inputs: tp.Union[
112
+ Float[Array, "N D"],
113
+ Real,
114
+ Static,
115
+ ],
112
116
  jitter: ScalarFloat = 1e-6,
113
117
  ):
114
- self.inducing_inputs = Static(inducing_inputs)
118
+ if not isinstance(inducing_inputs, (Real, Static)):
119
+ inducing_inputs = Real(inducing_inputs)
120
+
121
+ self.inducing_inputs = inducing_inputs
115
122
  self.jitter = jitter
116
123
 
117
124
  super().__init__(posterior)
@@ -142,12 +149,14 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):
142
149
  ):
143
150
  super().__init__(posterior, inducing_inputs, jitter)
144
151
 
145
- self.variational_mean = Real(
146
- variational_mean or jnp.zeros((self.num_inducing, 1))
147
- )
148
- self.variational_root_covariance = LowerTriangular(
149
- variational_root_covariance or jnp.eye(self.num_inducing)
150
- )
152
+ if variational_mean is None:
153
+ variational_mean = jnp.zeros((self.num_inducing, 1))
154
+
155
+ if variational_root_covariance is None:
156
+ variational_root_covariance = jnp.eye(self.num_inducing)
157
+
158
+ self.variational_mean = Real(variational_mean)
159
+ self.variational_root_covariance = LowerTriangular(variational_root_covariance)
151
160
 
152
161
  def prior_kl(self) -> ScalarFloat:
153
162
  r"""Compute the prior KL divergence.
@@ -371,12 +380,14 @@ class NaturalVariationalGaussian(AbstractVariationalGaussian[L]):
371
380
  ):
372
381
  super().__init__(posterior, inducing_inputs, jitter)
373
382
 
374
- self.natural_vector = Static(
375
- natural_vector or jnp.zeros((self.num_inducing, 1))
376
- )
377
- self.natural_matrix = Static(
378
- natural_matrix or -0.5 * jnp.eye(self.num_inducing)
379
- )
383
+ if natural_vector is None:
384
+ natural_vector = jnp.zeros((self.num_inducing, 1))
385
+
386
+ if natural_matrix is None:
387
+ natural_matrix = -0.5 * jnp.eye(self.num_inducing)
388
+
389
+ self.natural_vector = Static(natural_vector)
390
+ self.natural_matrix = Static(natural_matrix)
380
391
 
381
392
  def prior_kl(self) -> ScalarFloat:
382
393
  r"""Compute the KL-divergence between our current variational approximation
@@ -533,13 +544,14 @@ class ExpectationVariationalGaussian(AbstractVariationalGaussian[L]):
533
544
  ):
534
545
  super().__init__(posterior, inducing_inputs, jitter)
535
546
 
536
- # must come after super().__init__
537
- self.expectation_vector = Static(
538
- expectation_vector or jnp.zeros((self.num_inducing, 1))
539
- )
540
- self.expectation_matrix = Static(
541
- expectation_matrix or jnp.eye(self.num_inducing)
542
- )
547
+ if expectation_vector is None:
548
+ expectation_vector = jnp.zeros((self.num_inducing, 1))
549
+
550
+ if expectation_matrix is None:
551
+ expectation_matrix = jnp.eye(self.num_inducing)
552
+
553
+ self.expectation_vector = Static(expectation_vector)
554
+ self.expectation_matrix = Static(expectation_matrix)
543
555
 
544
556
  def prior_kl(self) -> ScalarFloat:
545
557
  r"""Evaluate the prior KL-divergence.
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.3
1
+ Metadata-Version: 2.4
2
2
  Name: gpjax
3
- Version: 0.9.2
3
+ Version: 0.9.4
4
4
  Summary: Gaussian processes in JAX.
5
5
  Project-URL: Documentation, https://docs.jaxgaussianprocesses.com/
6
6
  Project-URL: Issues, https://github.com/JaxGaussianProcesses/GPJax/issues
@@ -19,7 +19,7 @@ Classifier: Programming Language :: Python :: Implementation :: PyPy
19
19
  Requires-Python: <3.13,>=3.10
20
20
  Requires-Dist: beartype>0.16.1
21
21
  Requires-Dist: cola-ml==0.0.5
22
- Requires-Dist: flax>=0.8.5
22
+ Requires-Dist: flax<0.10.0
23
23
  Requires-Dist: jax<0.4.28
24
24
  Requires-Dist: jaxlib<0.4.28
25
25
  Requires-Dist: jaxopt==0.8.2
@@ -103,23 +103,23 @@ helped to shape GPJax into the package it is today.
103
103
 
104
104
  ## Notebook examples
105
105
 
106
- > - [**Conjugate Inference**](https://docs.jaxgaussianprocesses.com/examples/regression/)
107
- > - [**Classification**](https://docs.jaxgaussianprocesses.com/examples/classification/)
108
- > - [**Sparse Variational Inference**](https://docs.jaxgaussianprocesses.com/examples/collapsed_vi/)
109
- > - [**Stochastic Variational Inference**](https://docs.jaxgaussianprocesses.com/examples/uncollapsed_vi/)
110
- > - [**Laplace Approximation**](https://docs.jaxgaussianprocesses.com/examples/classification/#laplace-approximation)
111
- > - [**Inference on Non-Euclidean Spaces**](https://docs.jaxgaussianprocesses.com/examples/constructing_new_kernels/#custom-kernel)
112
- > - [**Inference on Graphs**](https://docs.jaxgaussianprocesses.com/examples/graph_kernels/)
113
- > - [**Pathwise Sampling**](https://docs.jaxgaussianprocesses.com/examples/spatial/)
114
- > - [**Learning Gaussian Process Barycentres**](https://docs.jaxgaussianprocesses.com/examples/barycentres/)
115
- > - [**Deep Kernel Regression**](https://docs.jaxgaussianprocesses.com/examples/deep_kernels/)
116
- > - [**Poisson Regression**](https://docs.jaxgaussianprocesses.com/examples/poisson/)
117
- > - [**Bayesian Optimisation**](https://docs.jaxgaussianprocesses.com/examples/bayesian_optimisation/)
106
+ > - [**Conjugate Inference**](https://docs.jaxgaussianprocesses.com/_examples/regression/)
107
+ > - [**Classification**](https://docs.jaxgaussianprocesses.com/_examples/classification/)
108
+ > - [**Sparse Variational Inference**](https://docs.jaxgaussianprocesses.com/_examples/collapsed_vi/)
109
+ > - [**Stochastic Variational Inference**](https://docs.jaxgaussianprocesses.com/_examples/uncollapsed_vi/)
110
+ > - [**Laplace Approximation**](https://docs.jaxgaussianprocesses.com/_examples/classification/#laplace-approximation)
111
+ > - [**Inference on Non-Euclidean Spaces**](https://docs.jaxgaussianprocesses.com/_examples/constructing_new_kernels/#custom-kernel)
112
+ > - [**Inference on Graphs**](https://docs.jaxgaussianprocesses.com/_examples/graph_kernels/)
113
+ > - [**Pathwise Sampling**](https://docs.jaxgaussianprocesses.com/_examples/spatial/)
114
+ > - [**Learning Gaussian Process Barycentres**](https://docs.jaxgaussianprocesses.com/_examples/barycentres/)
115
+ > - [**Deep Kernel Regression**](https://docs.jaxgaussianprocesses.com/_examples/deep_kernels/)
116
+ > - [**Poisson Regression**](https://docs.jaxgaussianprocesses.com/_examples/poisson/)
117
+ > - [**Bayesian Optimisation**](https://docs.jaxgaussianprocesses.com/_examples/bayesian_optimisation/)
118
118
 
119
119
  ## Guides for customisation
120
120
  >
121
- > - [**Custom kernels**](https://docs.jaxgaussianprocesses.com/examples/constructing_new_kernels/#custom-kernel)
122
- > - [**UCI regression**](https://docs.jaxgaussianprocesses.com/examples/yacht/)
121
+ > - [**Custom kernels**](https://docs.jaxgaussianprocesses.com/_examples/constructing_new_kernels/#custom-kernel)
122
+ > - [**UCI regression**](https://docs.jaxgaussianprocesses.com/_examples/yacht/)
123
123
 
124
124
  ## Conversion between `.ipynb` and `.py`
125
125
  Above examples are stored in [examples](docs/examples) directory in the double
@@ -180,7 +180,7 @@ optimiser = ox.adam(learning_rate=1e-2)
180
180
  # Obtain Type 2 MLEs of the hyperparameters
181
181
  opt_posterior, history = gpx.fit(
182
182
  model=posterior,
183
- objective=gpx.objectives.conjugate_mll,
183
+ objective=lambda p, d: -gpx.objectives.conjugate_mll(p, d),
184
184
  train_data=D,
185
185
  optim=optimiser,
186
186
  num_iters=500,
@@ -1,18 +1,18 @@
1
- gpjax/__init__.py,sha256=Bx5JFaveeVk3qJMTzbmrKOFy0U7fNcQ_JnVo5m0ACGA,1697
1
+ gpjax/__init__.py,sha256=f1Sl-8Oz6YuEueKxvzIAL0iH_9b9xGzQv07tddS5wto,1697
2
2
  gpjax/citation.py,sha256=R4Pmvjt0ndA0avEDSvIbxDxKapkRRYXWX7RRWBvZCRQ,5306
3
3
  gpjax/dataset.py,sha256=NsToLKq4lOsHnfLfukrUIRKvhOEuoUk8aHTF0oAqRbU,4079
4
4
  gpjax/distributions.py,sha256=zxkSEZIlTg0PHvvgj0BQuIFEg-ugx6_NkEwSsbqWUM0,9325
5
5
  gpjax/fit.py,sha256=OHv8jUHxa1ndpqMERSDRtYtUDzubk9rMPVIhfCiIH5Q,11551
6
- gpjax/gps.py,sha256=NO18geRfcjo4mA3PGkuGont_Mj_yRqfvWzJqYmoKwiY,31225
6
+ gpjax/gps.py,sha256=97lYGrsmsufQxKEd8qz5wPNvui6FKXTF_Ps-sMFIjnY,31246
7
7
  gpjax/integrators.py,sha256=eyJPqWNPKj6pKP5da0fEj4HW7BVyevqeGrurEuy_XPw,5694
8
- gpjax/likelihoods.py,sha256=Uh4kgLTod8ODw178L--G3w4olpm9XvCdcAZ8l7FwkF4,9255
8
+ gpjax/likelihoods.py,sha256=DOyV1L0ompkpeImMTiOOiWLJfqSqvDX_acOumuFqPEc,9234
9
9
  gpjax/lower_cholesky.py,sha256=3pnHaBrlGckFsrfYJ9Lsbd0pGmO7NIXdyY4aGm48MpY,1952
10
10
  gpjax/mean_functions.py,sha256=et2HzlsYJNViBvTohF2wZYgCWQfDX4KboYeO7egMR1c,6420
11
11
  gpjax/objectives.py,sha256=XwkPyL_iovTNKpKGVNt0Lt2_OMTJitSPhuyCtUrJpbc,15383
12
12
  gpjax/parameters.py,sha256=Z4Wy3gEzPZG23-dtqC437_ZWnd_sPe9LcLCKn21ZBvA,4886
13
- gpjax/scan.py,sha256=mtMsg8yLdkVuOYeTHLnATPGfGDnCMAQNdUA-FJlpfLs,5475
13
+ gpjax/scan.py,sha256=jStQvwkE9MGttB89frxam1kaeXdWih7cVxkGywyaeHQ,5365
14
14
  gpjax/typing.py,sha256=M3CvWsYtZ3PFUvBvvbRNjpwerNII0w4yGuP0I-sLeYI,1705
15
- gpjax/variational_families.py,sha256=Eik5CCU7qH7_7cacpZ-1lIXm4tElELwSYfVw-n0rI20,27742
15
+ gpjax/variational_families.py,sha256=s1rk7PtNTjQPabmVu-jBsuJBoqsxAAXwKFZJOEswkNQ,28161
16
16
  gpjax/decision_making/__init__.py,sha256=SDuPQl80lJ7nhfRsiB_7c22wCMiQO5ehSNohxUGnB7w,2170
17
17
  gpjax/decision_making/decision_maker.py,sha256=S4pOXrWcEHy0NDA0gfWzhk7pG0NJfaPpMXvq03yTy0g,13915
18
18
  gpjax/decision_making/posterior_handler.py,sha256=UgXf1Gu7GMh2YDSmiSWJIzmWlFW06KTS44HYz3mazZQ,5905
@@ -21,7 +21,7 @@ gpjax/decision_making/utility_maximizer.py,sha256=VT2amwSJbB64IL_MiWNl9ZgjcqO757
21
21
  gpjax/decision_making/utils.py,sha256=5j1GO5kcmG2laZR39NjhqgEjRekAWWzrnREv_5Zct_Y,2367
22
22
  gpjax/decision_making/test_functions/__init__.py,sha256=GDCY9_kaAnxDWwzo1FkdxnDx-80MErAHchbGybT9xYs,1109
23
23
  gpjax/decision_making/test_functions/continuous_functions.py,sha256=oL5ZQkvmbC3u9rEvSYI2DRAN3r7Ynf7wRZQlUWjKjt0,5612
24
- gpjax/decision_making/test_functions/non_conjugate_functions.py,sha256=cfo3xQOzB5ajMjjl0YFfNlJClkAcY7ZbT23UyBYEofQ,2955
24
+ gpjax/decision_making/test_functions/non_conjugate_functions.py,sha256=eJpCnTS9dRieLxpjH4L6OTsP-w9JM3XhjnzCfk2Xqn8,2957
25
25
  gpjax/decision_making/utility_functions/__init__.py,sha256=xXI-4JKWAfTJ7XZ1vRDpqtb91MNzSPD0lP6xo0tOc7o,1445
26
26
  gpjax/decision_making/utility_functions/base.py,sha256=FOqrsRDmtHiCVl6IHr12-AEYBLStzMT5EBs-F92e1Og,3882
27
27
  gpjax/decision_making/utility_functions/expected_improvement.py,sha256=H6hjC-lj1oiHf2BomeQqroORQ7vtcOngiDAWxRwkNbg,4481
@@ -56,7 +56,7 @@ gpjax/kernels/stationary/rational_quadratic.py,sha256=dYONp3i4rnKj3ET8UyxAKXv6UO
56
56
  gpjax/kernels/stationary/rbf.py,sha256=G13gg5phO7ite7D9QgoCy7gB2_y0FM6GZhgFW4RL6Xw,1734
57
57
  gpjax/kernels/stationary/utils.py,sha256=Xa9EEnxgFqEi08ZSFAZYYHhJ85_3Ac-ZUyUk18B63M4,2225
58
58
  gpjax/kernels/stationary/white.py,sha256=TkdXXZCCjDs7JwR_gj5uvn2s1wyfRbe1vyHhUMJ8jjI,2212
59
- gpjax-0.9.2.dist-info/METADATA,sha256=JWT3cDW7onuKnTYUGqa15WxG4L7oEboJKPHYyAggYZ0,9976
60
- gpjax-0.9.2.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
61
- gpjax-0.9.2.dist-info/licenses/LICENSE,sha256=tAkwu8-AdEyGxGoSvJ2gVmQdcicWw3j1ZZueVV74M-E,11357
62
- gpjax-0.9.2.dist-info/RECORD,,
59
+ gpjax-0.9.4.dist-info/METADATA,sha256=Qx_Qv91sE7_Y-c9CGuF40QBFk2FjLW0Fo2SHqFAgQFQ,10010
60
+ gpjax-0.9.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
61
+ gpjax-0.9.4.dist-info/licenses/LICENSE,sha256=tAkwu8-AdEyGxGoSvJ2gVmQdcicWw3j1ZZueVV74M-E,11357
62
+ gpjax-0.9.4.dist-info/RECORD,,
@@ -1,4 +1,4 @@
1
1
  Wheel-Version: 1.0
2
- Generator: hatchling 1.25.0
2
+ Generator: hatchling 1.27.0
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any