gpjax 0.9.2__py3-none-any.whl → 0.9.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gpjax/__init__.py +1 -1
- gpjax/likelihoods.py +3 -5
- gpjax/scan.py +10 -10
- gpjax/variational_families.py +9 -2
- {gpjax-0.9.2.dist-info → gpjax-0.9.3.dist-info}/METADATA +1 -1
- {gpjax-0.9.2.dist-info → gpjax-0.9.3.dist-info}/RECORD +8 -8
- {gpjax-0.9.2.dist-info → gpjax-0.9.3.dist-info}/WHEEL +0 -0
- {gpjax-0.9.2.dist-info → gpjax-0.9.3.dist-info}/licenses/LICENSE +0 -0
gpjax/__init__.py
CHANGED
|
@@ -40,7 +40,7 @@ __license__ = "MIT"
|
|
|
40
40
|
__description__ = "Didactic Gaussian processes in JAX"
|
|
41
41
|
__url__ = "https://github.com/JaxGaussianProcesses/GPJax"
|
|
42
42
|
__contributors__ = "https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors"
|
|
43
|
-
__version__ = "0.9.
|
|
43
|
+
__version__ = "0.9.3"
|
|
44
44
|
|
|
45
45
|
__all__ = [
|
|
46
46
|
"base",
|
gpjax/likelihoods.py
CHANGED
|
@@ -28,7 +28,6 @@ from gpjax.integrators import (
|
|
|
28
28
|
GHQuadratureIntegrator,
|
|
29
29
|
)
|
|
30
30
|
from gpjax.parameters import (
|
|
31
|
-
Parameter,
|
|
32
31
|
PositiveReal,
|
|
33
32
|
Static,
|
|
34
33
|
)
|
|
@@ -152,10 +151,9 @@ class Gaussian(AbstractLikelihood):
|
|
|
152
151
|
likelihoods. Must be an instance of `AbstractIntegrator`. For the Gaussian likelihood, this defaults to
|
|
153
152
|
the `AnalyticalGaussianIntegrator`, as the expected log likelihood can be computed analytically.
|
|
154
153
|
"""
|
|
155
|
-
if isinstance(obs_stddev,
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
self.obs_stddev = PositiveReal(jnp.asarray(obs_stddev))
|
|
154
|
+
if not isinstance(obs_stddev, (PositiveReal, Static)):
|
|
155
|
+
obs_stddev = PositiveReal(jnp.asarray(obs_stddev))
|
|
156
|
+
self.obs_stddev = obs_stddev
|
|
159
157
|
|
|
160
158
|
super().__init__(num_datapoints, integrator)
|
|
161
159
|
|
gpjax/scan.py
CHANGED
|
@@ -22,7 +22,6 @@ from beartype.typing import (
|
|
|
22
22
|
)
|
|
23
23
|
import jax
|
|
24
24
|
from jax import lax
|
|
25
|
-
from jax.experimental import host_callback as hcb
|
|
26
25
|
import jax.numpy as jnp
|
|
27
26
|
import jax.tree_util as jtu
|
|
28
27
|
from jaxtyping import (
|
|
@@ -54,7 +53,8 @@ def _callback(cond: ScalarBool, func: Callable, *args: Any) -> None:
|
|
|
54
53
|
|
|
55
54
|
def _do_callback(_) -> int:
|
|
56
55
|
"""Perform the callback."""
|
|
57
|
-
|
|
56
|
+
jax.debug.callback(func, *args)
|
|
57
|
+
return _dummy_result
|
|
58
58
|
|
|
59
59
|
def _not_callback(_) -> int:
|
|
60
60
|
"""Do nothing."""
|
|
@@ -113,19 +113,19 @@ def vscan(
|
|
|
113
113
|
_progress_bar = trange(_length)
|
|
114
114
|
_progress_bar.set_description("Compiling...", refresh=True)
|
|
115
115
|
|
|
116
|
-
def _set_running(args: Any
|
|
116
|
+
def _set_running(*args: Any) -> None:
|
|
117
117
|
"""Set the tqdm progress bar to running."""
|
|
118
118
|
_progress_bar.set_description("Running", refresh=False)
|
|
119
119
|
|
|
120
|
-
def _update_tqdm(args: Any
|
|
120
|
+
def _update_tqdm(*args: Any) -> None:
|
|
121
121
|
"""Update the tqdm progress bar with the latest objective value."""
|
|
122
122
|
_value, _iter_num = args
|
|
123
|
-
_progress_bar.update(_iter_num)
|
|
123
|
+
_progress_bar.update(_iter_num.item())
|
|
124
124
|
|
|
125
125
|
if log_value and _value is not None:
|
|
126
126
|
_progress_bar.set_postfix({"Value": f"{_value: .2f}"})
|
|
127
127
|
|
|
128
|
-
def _close_tqdm(args: Any
|
|
128
|
+
def _close_tqdm(*args: Any) -> None:
|
|
129
129
|
"""Close the tqdm progress bar."""
|
|
130
130
|
_progress_bar.close()
|
|
131
131
|
|
|
@@ -145,16 +145,16 @@ def vscan(
|
|
|
145
145
|
_is_last: bool = iter_num == _length - 1
|
|
146
146
|
|
|
147
147
|
# Update progress bar, if first of log_rate.
|
|
148
|
-
_callback(_is_first, _set_running
|
|
148
|
+
_callback(_is_first, _set_running)
|
|
149
149
|
|
|
150
150
|
# Update progress bar, if multiple of log_rate.
|
|
151
|
-
_callback(_is_multiple, _update_tqdm,
|
|
151
|
+
_callback(_is_multiple, _update_tqdm, y, log_rate)
|
|
152
152
|
|
|
153
153
|
# Update progress bar, if remainder.
|
|
154
|
-
_callback(_is_remainder, _update_tqdm,
|
|
154
|
+
_callback(_is_remainder, _update_tqdm, y, _remainder)
|
|
155
155
|
|
|
156
156
|
# Close progress bar, if last iteration.
|
|
157
|
-
_callback(_is_last, _close_tqdm
|
|
157
|
+
_callback(_is_last, _close_tqdm)
|
|
158
158
|
|
|
159
159
|
return carry, y
|
|
160
160
|
|
gpjax/variational_families.py
CHANGED
|
@@ -108,10 +108,17 @@ class AbstractVariationalGaussian(AbstractVariationalFamily[L]):
|
|
|
108
108
|
def __init__(
|
|
109
109
|
self,
|
|
110
110
|
posterior: AbstractPosterior[P, L],
|
|
111
|
-
inducing_inputs:
|
|
111
|
+
inducing_inputs: tp.Union[
|
|
112
|
+
Float[Array, "N D"],
|
|
113
|
+
Real,
|
|
114
|
+
Static,
|
|
115
|
+
],
|
|
112
116
|
jitter: ScalarFloat = 1e-6,
|
|
113
117
|
):
|
|
114
|
-
|
|
118
|
+
if not isinstance(inducing_inputs, (Real, Static)):
|
|
119
|
+
inducing_inputs = Real(inducing_inputs)
|
|
120
|
+
|
|
121
|
+
self.inducing_inputs = inducing_inputs
|
|
115
122
|
self.jitter = jitter
|
|
116
123
|
|
|
117
124
|
super().__init__(posterior)
|
|
@@ -1,18 +1,18 @@
|
|
|
1
|
-
gpjax/__init__.py,sha256=
|
|
1
|
+
gpjax/__init__.py,sha256=UNfvnpRhJEZfz9qYjzWYNSSh3xg1FqNFst_A6xl_nfE,1697
|
|
2
2
|
gpjax/citation.py,sha256=R4Pmvjt0ndA0avEDSvIbxDxKapkRRYXWX7RRWBvZCRQ,5306
|
|
3
3
|
gpjax/dataset.py,sha256=NsToLKq4lOsHnfLfukrUIRKvhOEuoUk8aHTF0oAqRbU,4079
|
|
4
4
|
gpjax/distributions.py,sha256=zxkSEZIlTg0PHvvgj0BQuIFEg-ugx6_NkEwSsbqWUM0,9325
|
|
5
5
|
gpjax/fit.py,sha256=OHv8jUHxa1ndpqMERSDRtYtUDzubk9rMPVIhfCiIH5Q,11551
|
|
6
6
|
gpjax/gps.py,sha256=NO18geRfcjo4mA3PGkuGont_Mj_yRqfvWzJqYmoKwiY,31225
|
|
7
7
|
gpjax/integrators.py,sha256=eyJPqWNPKj6pKP5da0fEj4HW7BVyevqeGrurEuy_XPw,5694
|
|
8
|
-
gpjax/likelihoods.py,sha256=
|
|
8
|
+
gpjax/likelihoods.py,sha256=DOyV1L0ompkpeImMTiOOiWLJfqSqvDX_acOumuFqPEc,9234
|
|
9
9
|
gpjax/lower_cholesky.py,sha256=3pnHaBrlGckFsrfYJ9Lsbd0pGmO7NIXdyY4aGm48MpY,1952
|
|
10
10
|
gpjax/mean_functions.py,sha256=et2HzlsYJNViBvTohF2wZYgCWQfDX4KboYeO7egMR1c,6420
|
|
11
11
|
gpjax/objectives.py,sha256=XwkPyL_iovTNKpKGVNt0Lt2_OMTJitSPhuyCtUrJpbc,15383
|
|
12
12
|
gpjax/parameters.py,sha256=Z4Wy3gEzPZG23-dtqC437_ZWnd_sPe9LcLCKn21ZBvA,4886
|
|
13
|
-
gpjax/scan.py,sha256=
|
|
13
|
+
gpjax/scan.py,sha256=jStQvwkE9MGttB89frxam1kaeXdWih7cVxkGywyaeHQ,5365
|
|
14
14
|
gpjax/typing.py,sha256=M3CvWsYtZ3PFUvBvvbRNjpwerNII0w4yGuP0I-sLeYI,1705
|
|
15
|
-
gpjax/variational_families.py,sha256=
|
|
15
|
+
gpjax/variational_families.py,sha256=JO78dywHNH9__hjJkrP2ASb1L3C9aEBOW7fd0run-e4,27918
|
|
16
16
|
gpjax/decision_making/__init__.py,sha256=SDuPQl80lJ7nhfRsiB_7c22wCMiQO5ehSNohxUGnB7w,2170
|
|
17
17
|
gpjax/decision_making/decision_maker.py,sha256=S4pOXrWcEHy0NDA0gfWzhk7pG0NJfaPpMXvq03yTy0g,13915
|
|
18
18
|
gpjax/decision_making/posterior_handler.py,sha256=UgXf1Gu7GMh2YDSmiSWJIzmWlFW06KTS44HYz3mazZQ,5905
|
|
@@ -56,7 +56,7 @@ gpjax/kernels/stationary/rational_quadratic.py,sha256=dYONp3i4rnKj3ET8UyxAKXv6UO
|
|
|
56
56
|
gpjax/kernels/stationary/rbf.py,sha256=G13gg5phO7ite7D9QgoCy7gB2_y0FM6GZhgFW4RL6Xw,1734
|
|
57
57
|
gpjax/kernels/stationary/utils.py,sha256=Xa9EEnxgFqEi08ZSFAZYYHhJ85_3Ac-ZUyUk18B63M4,2225
|
|
58
58
|
gpjax/kernels/stationary/white.py,sha256=TkdXXZCCjDs7JwR_gj5uvn2s1wyfRbe1vyHhUMJ8jjI,2212
|
|
59
|
-
gpjax-0.9.
|
|
60
|
-
gpjax-0.9.
|
|
61
|
-
gpjax-0.9.
|
|
62
|
-
gpjax-0.9.
|
|
59
|
+
gpjax-0.9.3.dist-info/METADATA,sha256=A-tRR4wxz_YizCra4tZdOfRD1aUGJ_5sKgzL7Ax81B0,9976
|
|
60
|
+
gpjax-0.9.3.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
|
|
61
|
+
gpjax-0.9.3.dist-info/licenses/LICENSE,sha256=tAkwu8-AdEyGxGoSvJ2gVmQdcicWw3j1ZZueVV74M-E,11357
|
|
62
|
+
gpjax-0.9.3.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|