gpjax-0.10.0-py3-none-any.whl → gpjax-0.10.2-py3-none-any.whl
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- gpjax/__init__.py +1 -1
- gpjax/kernels/base.py +4 -1
- gpjax/kernels/nonstationary/polynomial.py +1 -1
- gpjax/mean_functions.py +4 -3
- gpjax/parameters.py +88 -26
- {gpjax-0.10.0.dist-info → gpjax-0.10.2.dist-info}/METADATA +1 -1
- {gpjax-0.10.0.dist-info → gpjax-0.10.2.dist-info}/RECORD +9 -9
- {gpjax-0.10.0.dist-info → gpjax-0.10.2.dist-info}/WHEEL +0 -0
- {gpjax-0.10.0.dist-info → gpjax-0.10.2.dist-info}/licenses/LICENSE.txt +0 -0
gpjax/__init__.py
CHANGED
|
@@ -39,7 +39,7 @@ __license__ = "MIT"
|
|
|
39
39
|
__description__ = "Didactic Gaussian processes in JAX"
|
|
40
40
|
__url__ = "https://github.com/JaxGaussianProcesses/GPJax"
|
|
41
41
|
__contributors__ = "https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors"
|
|
42
|
-
__version__ = "0.10.0"
|
|
42
|
+
__version__ = "0.10.2"
|
|
43
43
|
|
|
44
44
|
__all__ = [
|
|
45
45
|
"base",
|
gpjax/kernels/base.py
CHANGED
|
@@ -32,6 +32,7 @@ from gpjax.kernels.computations import (
|
|
|
32
32
|
from gpjax.parameters import (
|
|
33
33
|
Parameter,
|
|
34
34
|
Real,
|
|
35
|
+
Static,
|
|
35
36
|
)
|
|
36
37
|
from gpjax.typing import (
|
|
37
38
|
Array,
|
|
@@ -220,7 +221,9 @@ class Constant(AbstractKernel):
|
|
|
220
221
|
def __init__(
|
|
221
222
|
self,
|
|
222
223
|
active_dims: tp.Union[list[int], slice, None] = None,
|
|
223
|
-
constant: tp.Union[
|
|
224
|
+
constant: tp.Union[
|
|
225
|
+
ScalarFloat, Parameter[ScalarFloat], Static[ScalarFloat]
|
|
226
|
+
] = jnp.array(0.0),
|
|
224
227
|
compute_engine: AbstractKernelComputation = DenseKernelComputation(),
|
|
225
228
|
):
|
|
226
229
|
if isinstance(constant, Parameter):
|
|
gpjax/kernels/nonstationary/polynomial.py
CHANGED
|
@@ -46,7 +46,7 @@ class Polynomial(AbstractKernel):
|
|
|
46
46
|
self,
|
|
47
47
|
active_dims: tp.Union[list[int], slice, None] = None,
|
|
48
48
|
degree: int = 2,
|
|
49
|
-
shift: tp.Union[ScalarFloat, nnx.Variable[ScalarArray]] =
|
|
49
|
+
shift: tp.Union[ScalarFloat, nnx.Variable[ScalarArray]] = 1.0,
|
|
50
50
|
variance: tp.Union[ScalarFloat, nnx.Variable[ScalarArray]] = 1.0,
|
|
51
51
|
n_dims: tp.Union[int, None] = None,
|
|
52
52
|
compute_engine: AbstractKernelComputation = DenseKernelComputation(),
|
gpjax/mean_functions.py
CHANGED
|
@@ -28,6 +28,7 @@ from jaxtyping import (
|
|
|
28
28
|
from gpjax.parameters import (
|
|
29
29
|
Parameter,
|
|
30
30
|
Real,
|
|
31
|
+
Static
|
|
31
32
|
)
|
|
32
33
|
from gpjax.typing import (
|
|
33
34
|
Array,
|
|
@@ -130,9 +131,9 @@ class Constant(AbstractMeanFunction):
|
|
|
130
131
|
"""
|
|
131
132
|
|
|
132
133
|
def __init__(
|
|
133
|
-
self, constant: tp.Union[ScalarFloat, Float[Array, " O"], Parameter] = 0.0
|
|
134
|
+
self, constant: tp.Union[ScalarFloat, Float[Array, " O"], Parameter, Static] = 0.0
|
|
134
135
|
):
|
|
135
|
-
if isinstance(constant, Parameter):
|
|
136
|
+
if isinstance(constant, Parameter) or isinstance(constant, Static):
|
|
136
137
|
self.constant = constant
|
|
137
138
|
else:
|
|
138
139
|
self.constant = Real(jnp.array(constant))
|
|
@@ -158,7 +159,7 @@ class Zero(Constant):
|
|
|
158
159
|
"""
|
|
159
160
|
|
|
160
161
|
def __init__(self):
|
|
161
|
-
super().__init__(constant=jnp.array(0.0))
|
|
162
|
+
super().__init__(constant=Static(jnp.array(0.0)))
|
|
162
163
|
|
|
163
164
|
|
|
164
165
|
class CombinationMeanFunction(AbstractMeanFunction):
|
gpjax/parameters.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
import typing as tp
|
|
2
2
|
|
|
3
3
|
from flax import nnx
|
|
4
|
+
from jax.experimental import checkify
|
|
4
5
|
import jax.numpy as jnp
|
|
5
6
|
import jax.tree_util as jtu
|
|
6
7
|
from jax.typing import ArrayLike
|
|
@@ -84,8 +85,7 @@ class PositiveReal(Parameter[T]):
|
|
|
84
85
|
|
|
85
86
|
def __init__(self, value: T, tag: ParameterTag = "positive", **kwargs):
|
|
86
87
|
super().__init__(value=value, tag=tag, **kwargs)
|
|
87
|
-
|
|
88
|
-
_check_is_positive(self.value)
|
|
88
|
+
_safe_assert(_check_is_positive, self.value)
|
|
89
89
|
|
|
90
90
|
|
|
91
91
|
class Real(Parameter[T]):
|
|
@@ -101,7 +101,17 @@ class SigmoidBounded(Parameter[T]):
|
|
|
101
101
|
def __init__(self, value: T, tag: ParameterTag = "sigmoid", **kwargs):
|
|
102
102
|
super().__init__(value=value, tag=tag, **kwargs)
|
|
103
103
|
|
|
104
|
-
|
|
104
|
+
# Only perform validation in non-JIT contexts
|
|
105
|
+
if (
|
|
106
|
+
not isinstance(value, jnp.ndarray)
|
|
107
|
+
or not getattr(value, "aval", None) is None
|
|
108
|
+
):
|
|
109
|
+
_safe_assert(
|
|
110
|
+
_check_in_bounds,
|
|
111
|
+
self.value,
|
|
112
|
+
low=jnp.array(0.0),
|
|
113
|
+
high=jnp.array(1.0),
|
|
114
|
+
)
|
|
105
115
|
|
|
106
116
|
|
|
107
117
|
class Static(nnx.Variable[T]):
|
|
@@ -120,8 +130,13 @@ class LowerTriangular(Parameter[T]):
|
|
|
120
130
|
def __init__(self, value: T, tag: ParameterTag = "lower_triangular", **kwargs):
|
|
121
131
|
super().__init__(value=value, tag=tag, **kwargs)
|
|
122
132
|
|
|
123
|
-
|
|
124
|
-
|
|
133
|
+
# Only perform validation in non-JIT contexts
|
|
134
|
+
if (
|
|
135
|
+
not isinstance(value, jnp.ndarray)
|
|
136
|
+
or not getattr(value, "aval", None) is None
|
|
137
|
+
):
|
|
138
|
+
_safe_assert(_check_is_square, self.value)
|
|
139
|
+
_safe_assert(_check_is_lower_triangular, self.value)
|
|
125
140
|
|
|
126
141
|
|
|
127
142
|
DEFAULT_BIJECTION = {
|
|
@@ -132,36 +147,83 @@ DEFAULT_BIJECTION = {
|
|
|
132
147
|
}
|
|
133
148
|
|
|
134
149
|
|
|
135
|
-
def _check_is_arraylike(value: T):
|
|
150
|
+
def _check_is_arraylike(value: T) -> None:
|
|
151
|
+
"""Check if a value is array-like.
|
|
152
|
+
|
|
153
|
+
Args:
|
|
154
|
+
value: The value to check.
|
|
155
|
+
|
|
156
|
+
Raises:
|
|
157
|
+
TypeError: If the value is not array-like.
|
|
158
|
+
"""
|
|
136
159
|
if not isinstance(value, (ArrayLike, list)):
|
|
137
160
|
raise TypeError(
|
|
138
161
|
f"Expected parameter value to be an array-like type. Got {value}."
|
|
139
162
|
)
|
|
140
163
|
|
|
141
164
|
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
165
|
+
@checkify.checkify
|
|
166
|
+
def _check_is_positive(value):
|
|
167
|
+
checkify.check(
|
|
168
|
+
jnp.all(value > 0), "value needs to be positive, got {value}", value=value
|
|
169
|
+
)
|
|
147
170
|
|
|
148
171
|
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
f"Expected parameter value to be a square matrix. Got {value}."
|
|
153
|
-
)
|
|
172
|
+
@checkify.checkify
|
|
173
|
+
def _check_is_square(value: T) -> None:
|
|
174
|
+
"""Check if a value is a square matrix.
|
|
154
175
|
|
|
176
|
+
Args:
|
|
177
|
+
value: The value to check.
|
|
155
178
|
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
179
|
+
Raises:
|
|
180
|
+
ValueError: If the value is not a square matrix.
|
|
181
|
+
"""
|
|
182
|
+
checkify.check(
|
|
183
|
+
value.shape[0] == value.shape[1],
|
|
184
|
+
"value needs to be a square matrix, got {value}",
|
|
185
|
+
value=value,
|
|
186
|
+
)
|
|
161
187
|
|
|
162
188
|
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
189
|
+
@checkify.checkify
|
|
190
|
+
def _check_is_lower_triangular(value: T) -> None:
|
|
191
|
+
"""Check if a value is a lower triangular matrix.
|
|
192
|
+
|
|
193
|
+
Args:
|
|
194
|
+
value: The value to check.
|
|
195
|
+
|
|
196
|
+
Raises:
|
|
197
|
+
ValueError: If the value is not a lower triangular matrix.
|
|
198
|
+
"""
|
|
199
|
+
checkify.check(
|
|
200
|
+
jnp.all(jnp.tril(value) == value),
|
|
201
|
+
"value needs to be a lower triangular matrix, got {value}",
|
|
202
|
+
value=value,
|
|
203
|
+
)
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
@checkify.checkify
|
|
207
|
+
def _check_in_bounds(value: T, low: T, high: T) -> None:
|
|
208
|
+
"""Check if a value is bounded between low and high.
|
|
209
|
+
|
|
210
|
+
Args:
|
|
211
|
+
value: The value to check.
|
|
212
|
+
low: The lower bound.
|
|
213
|
+
high: The upper bound.
|
|
214
|
+
|
|
215
|
+
Raises:
|
|
216
|
+
ValueError: If any element of value is outside the bounds.
|
|
217
|
+
"""
|
|
218
|
+
checkify.check(
|
|
219
|
+
jnp.all((value >= low) & (value <= high)),
|
|
220
|
+
"value needs to be bounded between {low} and {high}, got {value}",
|
|
221
|
+
value=value,
|
|
222
|
+
low=low,
|
|
223
|
+
high=high,
|
|
224
|
+
)
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
def _safe_assert(fn: tp.Callable[[tp.Any], None], value: T, **kwargs) -> None:
|
|
228
|
+
error, _ = fn(value, **kwargs)
|
|
229
|
+
checkify.check_error(error)
|
|
{gpjax-0.10.0.dist-info → gpjax-0.10.2.dist-info}/RECORD
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
gpjax/__init__.py,sha256=
|
|
1
|
+
gpjax/__init__.py,sha256=F9GVk18tdmvwiDEHZNo_4Wr0TkmPhWIEwl3KzEWQcaQ,1654
|
|
2
2
|
gpjax/citation.py,sha256=f2Hzj5MLyCE7l0hHAzsEQoTORZH5hgV_eis4uoBiWvE,3811
|
|
3
3
|
gpjax/dataset.py,sha256=NsToLKq4lOsHnfLfukrUIRKvhOEuoUk8aHTF0oAqRbU,4079
|
|
4
4
|
gpjax/distributions.py,sha256=X48FJr3reop9maherdMVt7-XZOm2f26T8AJt_IKM_oE,9339
|
|
@@ -7,14 +7,14 @@ gpjax/gps.py,sha256=97lYGrsmsufQxKEd8qz5wPNvui6FKXTF_Ps-sMFIjnY,31246
|
|
|
7
7
|
gpjax/integrators.py,sha256=eyJPqWNPKj6pKP5da0fEj4HW7BVyevqeGrurEuy_XPw,5694
|
|
8
8
|
gpjax/likelihoods.py,sha256=DOyV1L0ompkpeImMTiOOiWLJfqSqvDX_acOumuFqPEc,9234
|
|
9
9
|
gpjax/lower_cholesky.py,sha256=3pnHaBrlGckFsrfYJ9Lsbd0pGmO7NIXdyY4aGm48MpY,1952
|
|
10
|
-
gpjax/mean_functions.py,sha256=
|
|
10
|
+
gpjax/mean_functions.py,sha256=BpeFkR3Eqa3O_FGp9BtSu9HKNSYZ8M08VtyfPfWbwRg,6479
|
|
11
11
|
gpjax/objectives.py,sha256=XwkPyL_iovTNKpKGVNt0Lt2_OMTJitSPhuyCtUrJpbc,15383
|
|
12
|
-
gpjax/parameters.py,sha256=
|
|
12
|
+
gpjax/parameters.py,sha256=6VKq6wBzEUtx-GXniC8fEqjTNrTC1YwIOw66QguW6UM,6457
|
|
13
13
|
gpjax/scan.py,sha256=jStQvwkE9MGttB89frxam1kaeXdWih7cVxkGywyaeHQ,5365
|
|
14
14
|
gpjax/typing.py,sha256=M3CvWsYtZ3PFUvBvvbRNjpwerNII0w4yGuP0I-sLeYI,1705
|
|
15
15
|
gpjax/variational_families.py,sha256=s1rk7PtNTjQPabmVu-jBsuJBoqsxAAXwKFZJOEswkNQ,28161
|
|
16
16
|
gpjax/kernels/__init__.py,sha256=WZanH0Tpdkt0f7VfMqnalm_VZAMVwBqeOVaICNj6xQU,1901
|
|
17
|
-
gpjax/kernels/base.py,sha256=
|
|
17
|
+
gpjax/kernels/base.py,sha256=wXsrpm5ofy9S5MNgUkJk4lx2umcIJL6dDNhXY7cmTGk,11616
|
|
18
18
|
gpjax/kernels/approximations/__init__.py,sha256=bK9HlGd-PZeGrqtG5RpXxUTXNUrZTgfjH1dP626yNMA,68
|
|
19
19
|
gpjax/kernels/approximations/rff.py,sha256=4kD1uocjHmxkLgvf4DxB4_Gy7iefdPgnWiZB3jDiExI,4126
|
|
20
20
|
gpjax/kernels/computations/__init__.py,sha256=uTVkqvnZVesFLDN92h0ZR0jfR69Eo2WyjOlmSYmCPJ8,1379
|
|
@@ -30,7 +30,7 @@ gpjax/kernels/non_euclidean/utils.py,sha256=z42aw8ga0zuREzHawemR9okttgrAUPmq-aN5
|
|
|
30
30
|
gpjax/kernels/nonstationary/__init__.py,sha256=YpWQfOy_cqOKc5ezn37vqoK3Z6jznYiJz28BD_8F7AY,930
|
|
31
31
|
gpjax/kernels/nonstationary/arccosine.py,sha256=UCTVJEhTZFQjARGFsYMImLnTDyTyxobIL5f2LiAHkPI,5822
|
|
32
32
|
gpjax/kernels/nonstationary/linear.py,sha256=UKDHFCQzKWDMYo76qcb5-ujjnP2_iL-1tcN017xjK48,2562
|
|
33
|
-
gpjax/kernels/nonstationary/polynomial.py,sha256=
|
|
33
|
+
gpjax/kernels/nonstationary/polynomial.py,sha256=7SDMfEcBCqnRn9xyj4iGcYLNvYJZiveN3uLZ_h12p10,3257
|
|
34
34
|
gpjax/kernels/stationary/__init__.py,sha256=j4BMTaQlIx2kNAT1Dkf4iO2rm-f7_oSVWNrk1bN0tqE,1406
|
|
35
35
|
gpjax/kernels/stationary/base.py,sha256=pQNkMo-E4bIT4tNfb7JvFJZC6fIIXNErsT1iQopFlAA,7063
|
|
36
36
|
gpjax/kernels/stationary/matern12.py,sha256=b2vQCUqhd9NJK84L2RYjpI597uxy_7xgwsjS35Gc958,1807
|
|
@@ -42,7 +42,7 @@ gpjax/kernels/stationary/rational_quadratic.py,sha256=dYONp3i4rnKj3ET8UyxAKXv6UO
|
|
|
42
42
|
gpjax/kernels/stationary/rbf.py,sha256=G13gg5phO7ite7D9QgoCy7gB2_y0FM6GZhgFW4RL6Xw,1734
|
|
43
43
|
gpjax/kernels/stationary/utils.py,sha256=Xa9EEnxgFqEi08ZSFAZYYHhJ85_3Ac-ZUyUk18B63M4,2225
|
|
44
44
|
gpjax/kernels/stationary/white.py,sha256=TkdXXZCCjDs7JwR_gj5uvn2s1wyfRbe1vyHhUMJ8jjI,2212
|
|
45
|
-
gpjax-0.10.
|
|
46
|
-
gpjax-0.10.
|
|
47
|
-
gpjax-0.10.
|
|
48
|
-
gpjax-0.10.
|
|
45
|
+
gpjax-0.10.2.dist-info/METADATA,sha256=mqIBMOMKKiI9qkM_uFHSuPEXY17Jd6bOL5EM2hPiaok,9970
|
|
46
|
+
gpjax-0.10.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
47
|
+
gpjax-0.10.2.dist-info/licenses/LICENSE.txt,sha256=3umwi0h8wmKXOZO8XwRBwSl3vJt2hpWKEqSrSXLR7-I,1084
|
|
48
|
+
gpjax-0.10.2.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|