gpjax 0.9.2__py3-none-any.whl → 0.9.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
gpjax/__init__.py CHANGED
@@ -40,7 +40,7 @@ __license__ = "MIT"
40
40
  __description__ = "Didactic Gaussian processes in JAX"
41
41
  __url__ = "https://github.com/JaxGaussianProcesses/GPJax"
42
42
  __contributors__ = "https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors"
43
- __version__ = "0.9.2"
43
+ __version__ = "0.9.3"
44
44
 
45
45
  __all__ = [
46
46
  "base",
gpjax/likelihoods.py CHANGED
@@ -28,7 +28,6 @@ from gpjax.integrators import (
28
28
  GHQuadratureIntegrator,
29
29
  )
30
30
  from gpjax.parameters import (
31
- Parameter,
32
31
  PositiveReal,
33
32
  Static,
34
33
  )
@@ -152,10 +151,9 @@ class Gaussian(AbstractLikelihood):
152
151
  likelihoods. Must be an instance of `AbstractIntegrator`. For the Gaussian likelihood, this defaults to
153
152
  the `AnalyticalGaussianIntegrator`, as the expected log likelihood can be computed analytically.
154
153
  """
155
- if isinstance(obs_stddev, Parameter):
156
- self.obs_stddev = obs_stddev
157
- else:
158
- self.obs_stddev = PositiveReal(jnp.asarray(obs_stddev))
154
+ if not isinstance(obs_stddev, (PositiveReal, Static)):
155
+ obs_stddev = PositiveReal(jnp.asarray(obs_stddev))
156
+ self.obs_stddev = obs_stddev
159
157
 
160
158
  super().__init__(num_datapoints, integrator)
161
159
 
gpjax/scan.py CHANGED
@@ -22,7 +22,6 @@ from beartype.typing import (
22
22
  )
23
23
  import jax
24
24
  from jax import lax
25
- from jax.experimental import host_callback as hcb
26
25
  import jax.numpy as jnp
27
26
  import jax.tree_util as jtu
28
27
  from jaxtyping import (
@@ -54,7 +53,8 @@ def _callback(cond: ScalarBool, func: Callable, *args: Any) -> None:
54
53
 
55
54
  def _do_callback(_) -> int:
56
55
  """Perform the callback."""
57
- return hcb.id_tap(func, *args, result=_dummy_result)
56
+ jax.debug.callback(func, *args)
57
+ return _dummy_result
58
58
 
59
59
  def _not_callback(_) -> int:
60
60
  """Do nothing."""
@@ -113,19 +113,19 @@ def vscan(
113
113
  _progress_bar = trange(_length)
114
114
  _progress_bar.set_description("Compiling...", refresh=True)
115
115
 
116
- def _set_running(args: Any, transform: Any) -> None:
116
+ def _set_running(*args: Any) -> None:
117
117
  """Set the tqdm progress bar to running."""
118
118
  _progress_bar.set_description("Running", refresh=False)
119
119
 
120
- def _update_tqdm(args: Any, transform: Any) -> None:
120
+ def _update_tqdm(*args: Any) -> None:
121
121
  """Update the tqdm progress bar with the latest objective value."""
122
122
  _value, _iter_num = args
123
- _progress_bar.update(_iter_num)
123
+ _progress_bar.update(_iter_num.item())
124
124
 
125
125
  if log_value and _value is not None:
126
126
  _progress_bar.set_postfix({"Value": f"{_value: .2f}"})
127
127
 
128
- def _close_tqdm(args: Any, transform: Any) -> None:
128
+ def _close_tqdm(*args: Any) -> None:
129
129
  """Close the tqdm progress bar."""
130
130
  _progress_bar.close()
131
131
 
@@ -145,16 +145,16 @@ def vscan(
145
145
  _is_last: bool = iter_num == _length - 1
146
146
 
147
147
  # Update progress bar, if first of log_rate.
148
- _callback(_is_first, _set_running, (y, log_rate))
148
+ _callback(_is_first, _set_running)
149
149
 
150
150
  # Update progress bar, if multiple of log_rate.
151
- _callback(_is_multiple, _update_tqdm, (y, log_rate))
151
+ _callback(_is_multiple, _update_tqdm, y, log_rate)
152
152
 
153
153
  # Update progress bar, if remainder.
154
- _callback(_is_remainder, _update_tqdm, (y, _remainder))
154
+ _callback(_is_remainder, _update_tqdm, y, _remainder)
155
155
 
156
156
  # Close progress bar, if last iteration.
157
- _callback(_is_last, _close_tqdm, (y, None))
157
+ _callback(_is_last, _close_tqdm)
158
158
 
159
159
  return carry, y
160
160
 
@@ -108,10 +108,17 @@ class AbstractVariationalGaussian(AbstractVariationalFamily[L]):
108
108
  def __init__(
109
109
  self,
110
110
  posterior: AbstractPosterior[P, L],
111
- inducing_inputs: Float[Array, "N D"],
111
+ inducing_inputs: tp.Union[
112
+ Float[Array, "N D"],
113
+ Real,
114
+ Static,
115
+ ],
112
116
  jitter: ScalarFloat = 1e-6,
113
117
  ):
114
- self.inducing_inputs = Static(inducing_inputs)
118
+ if not isinstance(inducing_inputs, (Real, Static)):
119
+ inducing_inputs = Real(inducing_inputs)
120
+
121
+ self.inducing_inputs = inducing_inputs
115
122
  self.jitter = jitter
116
123
 
117
124
  super().__init__(posterior)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: gpjax
3
- Version: 0.9.2
3
+ Version: 0.9.3
4
4
  Summary: Gaussian processes in JAX.
5
5
  Project-URL: Documentation, https://docs.jaxgaussianprocesses.com/
6
6
  Project-URL: Issues, https://github.com/JaxGaussianProcesses/GPJax/issues
@@ -1,18 +1,18 @@
1
- gpjax/__init__.py,sha256=Bx5JFaveeVk3qJMTzbmrKOFy0U7fNcQ_JnVo5m0ACGA,1697
1
+ gpjax/__init__.py,sha256=UNfvnpRhJEZfz9qYjzWYNSSh3xg1FqNFst_A6xl_nfE,1697
2
2
  gpjax/citation.py,sha256=R4Pmvjt0ndA0avEDSvIbxDxKapkRRYXWX7RRWBvZCRQ,5306
3
3
  gpjax/dataset.py,sha256=NsToLKq4lOsHnfLfukrUIRKvhOEuoUk8aHTF0oAqRbU,4079
4
4
  gpjax/distributions.py,sha256=zxkSEZIlTg0PHvvgj0BQuIFEg-ugx6_NkEwSsbqWUM0,9325
5
5
  gpjax/fit.py,sha256=OHv8jUHxa1ndpqMERSDRtYtUDzubk9rMPVIhfCiIH5Q,11551
6
6
  gpjax/gps.py,sha256=NO18geRfcjo4mA3PGkuGont_Mj_yRqfvWzJqYmoKwiY,31225
7
7
  gpjax/integrators.py,sha256=eyJPqWNPKj6pKP5da0fEj4HW7BVyevqeGrurEuy_XPw,5694
8
- gpjax/likelihoods.py,sha256=Uh4kgLTod8ODw178L--G3w4olpm9XvCdcAZ8l7FwkF4,9255
8
+ gpjax/likelihoods.py,sha256=DOyV1L0ompkpeImMTiOOiWLJfqSqvDX_acOumuFqPEc,9234
9
9
  gpjax/lower_cholesky.py,sha256=3pnHaBrlGckFsrfYJ9Lsbd0pGmO7NIXdyY4aGm48MpY,1952
10
10
  gpjax/mean_functions.py,sha256=et2HzlsYJNViBvTohF2wZYgCWQfDX4KboYeO7egMR1c,6420
11
11
  gpjax/objectives.py,sha256=XwkPyL_iovTNKpKGVNt0Lt2_OMTJitSPhuyCtUrJpbc,15383
12
12
  gpjax/parameters.py,sha256=Z4Wy3gEzPZG23-dtqC437_ZWnd_sPe9LcLCKn21ZBvA,4886
13
- gpjax/scan.py,sha256=mtMsg8yLdkVuOYeTHLnATPGfGDnCMAQNdUA-FJlpfLs,5475
13
+ gpjax/scan.py,sha256=jStQvwkE9MGttB89frxam1kaeXdWih7cVxkGywyaeHQ,5365
14
14
  gpjax/typing.py,sha256=M3CvWsYtZ3PFUvBvvbRNjpwerNII0w4yGuP0I-sLeYI,1705
15
- gpjax/variational_families.py,sha256=Eik5CCU7qH7_7cacpZ-1lIXm4tElELwSYfVw-n0rI20,27742
15
+ gpjax/variational_families.py,sha256=JO78dywHNH9__hjJkrP2ASb1L3C9aEBOW7fd0run-e4,27918
16
16
  gpjax/decision_making/__init__.py,sha256=SDuPQl80lJ7nhfRsiB_7c22wCMiQO5ehSNohxUGnB7w,2170
17
17
  gpjax/decision_making/decision_maker.py,sha256=S4pOXrWcEHy0NDA0gfWzhk7pG0NJfaPpMXvq03yTy0g,13915
18
18
  gpjax/decision_making/posterior_handler.py,sha256=UgXf1Gu7GMh2YDSmiSWJIzmWlFW06KTS44HYz3mazZQ,5905
@@ -56,7 +56,7 @@ gpjax/kernels/stationary/rational_quadratic.py,sha256=dYONp3i4rnKj3ET8UyxAKXv6UO
56
56
  gpjax/kernels/stationary/rbf.py,sha256=G13gg5phO7ite7D9QgoCy7gB2_y0FM6GZhgFW4RL6Xw,1734
57
57
  gpjax/kernels/stationary/utils.py,sha256=Xa9EEnxgFqEi08ZSFAZYYHhJ85_3Ac-ZUyUk18B63M4,2225
58
58
  gpjax/kernels/stationary/white.py,sha256=TkdXXZCCjDs7JwR_gj5uvn2s1wyfRbe1vyHhUMJ8jjI,2212
59
- gpjax-0.9.2.dist-info/METADATA,sha256=JWT3cDW7onuKnTYUGqa15WxG4L7oEboJKPHYyAggYZ0,9976
60
- gpjax-0.9.2.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
61
- gpjax-0.9.2.dist-info/licenses/LICENSE,sha256=tAkwu8-AdEyGxGoSvJ2gVmQdcicWw3j1ZZueVV74M-E,11357
62
- gpjax-0.9.2.dist-info/RECORD,,
59
+ gpjax-0.9.3.dist-info/METADATA,sha256=A-tRR4wxz_YizCra4tZdOfRD1aUGJ_5sKgzL7Ax81B0,9976
60
+ gpjax-0.9.3.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
61
+ gpjax-0.9.3.dist-info/licenses/LICENSE,sha256=tAkwu8-AdEyGxGoSvJ2gVmQdcicWw3j1ZZueVV74M-E,11357
62
+ gpjax-0.9.3.dist-info/RECORD,,
File without changes