gpjax 0.10.0__py3-none-any.whl → 0.10.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
gpjax/__init__.py CHANGED
@@ -39,7 +39,7 @@ __license__ = "MIT"
39
39
  __description__ = "Didactic Gaussian processes in JAX"
40
40
  __url__ = "https://github.com/JaxGaussianProcesses/GPJax"
41
41
  __contributors__ = "https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors"
42
- __version__ = "0.10.0"
42
+ __version__ = "0.10.2"
43
43
 
44
44
  __all__ = [
45
45
  "base",
gpjax/kernels/base.py CHANGED
@@ -32,6 +32,7 @@ from gpjax.kernels.computations import (
32
32
  from gpjax.parameters import (
33
33
  Parameter,
34
34
  Real,
35
+ Static,
35
36
  )
36
37
  from gpjax.typing import (
37
38
  Array,
@@ -220,7 +221,9 @@ class Constant(AbstractKernel):
220
221
  def __init__(
221
222
  self,
222
223
  active_dims: tp.Union[list[int], slice, None] = None,
223
- constant: tp.Union[ScalarFloat, Parameter[ScalarFloat]] = jnp.array(0.0),
224
+ constant: tp.Union[
225
+ ScalarFloat, Parameter[ScalarFloat], Static[ScalarFloat]
226
+ ] = jnp.array(0.0),
224
227
  compute_engine: AbstractKernelComputation = DenseKernelComputation(),
225
228
  ):
226
229
  if isinstance(constant, Parameter):
gpjax/kernels/nonstationary/polynomial.py CHANGED
@@ -46,7 +46,7 @@ class Polynomial(AbstractKernel):
46
46
  self,
47
47
  active_dims: tp.Union[list[int], slice, None] = None,
48
48
  degree: int = 2,
49
- shift: tp.Union[ScalarFloat, nnx.Variable[ScalarArray]] = 0.0,
49
+ shift: tp.Union[ScalarFloat, nnx.Variable[ScalarArray]] = 1.0,
50
50
  variance: tp.Union[ScalarFloat, nnx.Variable[ScalarArray]] = 1.0,
51
51
  n_dims: tp.Union[int, None] = None,
52
52
  compute_engine: AbstractKernelComputation = DenseKernelComputation(),
gpjax/mean_functions.py CHANGED
@@ -28,6 +28,7 @@ from jaxtyping import (
28
28
  from gpjax.parameters import (
29
29
  Parameter,
30
30
  Real,
31
+ Static
31
32
  )
32
33
  from gpjax.typing import (
33
34
  Array,
@@ -130,9 +131,9 @@ class Constant(AbstractMeanFunction):
130
131
  """
131
132
 
132
133
  def __init__(
133
- self, constant: tp.Union[ScalarFloat, Float[Array, " O"], Parameter] = 0.0
134
+ self, constant: tp.Union[ScalarFloat, Float[Array, " O"], Parameter, Static] = 0.0
134
135
  ):
135
- if isinstance(constant, Parameter):
136
+ if isinstance(constant, Parameter) or isinstance(constant, Static):
136
137
  self.constant = constant
137
138
  else:
138
139
  self.constant = Real(jnp.array(constant))
@@ -158,7 +159,7 @@ class Zero(Constant):
158
159
  """
159
160
 
160
161
  def __init__(self):
161
- super().__init__(constant=jnp.array(0.0))
162
+ super().__init__(constant=Static(jnp.array(0.0)))
162
163
 
163
164
 
164
165
  class CombinationMeanFunction(AbstractMeanFunction):
gpjax/parameters.py CHANGED
@@ -1,6 +1,7 @@
1
1
  import typing as tp
2
2
 
3
3
  from flax import nnx
4
+ from jax.experimental import checkify
4
5
  import jax.numpy as jnp
5
6
  import jax.tree_util as jtu
6
7
  from jax.typing import ArrayLike
@@ -84,8 +85,7 @@ class PositiveReal(Parameter[T]):
84
85
 
85
86
  def __init__(self, value: T, tag: ParameterTag = "positive", **kwargs):
86
87
  super().__init__(value=value, tag=tag, **kwargs)
87
-
88
- _check_is_positive(self.value)
88
+ _safe_assert(_check_is_positive, self.value)
89
89
 
90
90
 
91
91
  class Real(Parameter[T]):
@@ -101,7 +101,17 @@ class SigmoidBounded(Parameter[T]):
101
101
  def __init__(self, value: T, tag: ParameterTag = "sigmoid", **kwargs):
102
102
  super().__init__(value=value, tag=tag, **kwargs)
103
103
 
104
- _check_in_bounds(self.value, 0.0, 1.0)
104
+ # Only perform validation in non-JIT contexts
105
+ if (
106
+ not isinstance(value, jnp.ndarray)
107
+ or not getattr(value, "aval", None) is None
108
+ ):
109
+ _safe_assert(
110
+ _check_in_bounds,
111
+ self.value,
112
+ low=jnp.array(0.0),
113
+ high=jnp.array(1.0),
114
+ )
105
115
 
106
116
 
107
117
  class Static(nnx.Variable[T]):
@@ -120,8 +130,13 @@ class LowerTriangular(Parameter[T]):
120
130
  def __init__(self, value: T, tag: ParameterTag = "lower_triangular", **kwargs):
121
131
  super().__init__(value=value, tag=tag, **kwargs)
122
132
 
123
- _check_is_square(self.value)
124
- _check_is_lower_triangular(self.value)
133
+ # Only perform validation in non-JIT contexts
134
+ if (
135
+ not isinstance(value, jnp.ndarray)
136
+ or not getattr(value, "aval", None) is None
137
+ ):
138
+ _safe_assert(_check_is_square, self.value)
139
+ _safe_assert(_check_is_lower_triangular, self.value)
125
140
 
126
141
 
127
142
  DEFAULT_BIJECTION = {
@@ -132,36 +147,83 @@ DEFAULT_BIJECTION = {
132
147
  }
133
148
 
134
149
 
135
- def _check_is_arraylike(value: T):
150
+ def _check_is_arraylike(value: T) -> None:
151
+ """Check if a value is array-like.
152
+
153
+ Args:
154
+ value: The value to check.
155
+
156
+ Raises:
157
+ TypeError: If the value is not array-like.
158
+ """
136
159
  if not isinstance(value, (ArrayLike, list)):
137
160
  raise TypeError(
138
161
  f"Expected parameter value to be an array-like type. Got {value}."
139
162
  )
140
163
 
141
164
 
142
- def _check_is_positive(value: T):
143
- if jnp.any(value < 0):
144
- raise ValueError(
145
- f"Expected parameter value to be strictly positive. Got {value}."
146
- )
165
+ @checkify.checkify
166
+ def _check_is_positive(value):
167
+ checkify.check(
168
+ jnp.all(value > 0), "value needs to be positive, got {value}", value=value
169
+ )
147
170
 
148
171
 
149
- def _check_is_square(value: T):
150
- if value.shape[0] != value.shape[1]:
151
- raise ValueError(
152
- f"Expected parameter value to be a square matrix. Got {value}."
153
- )
172
+ @checkify.checkify
173
+ def _check_is_square(value: T) -> None:
174
+ """Check if a value is a square matrix.
154
175
 
176
+ Args:
177
+ value: The value to check.
155
178
 
156
- def _check_is_lower_triangular(value: T):
157
- if not jnp.all(jnp.tril(value) == value):
158
- raise ValueError(
159
- f"Expected parameter value to be a lower triangular matrix. Got {value}."
160
- )
179
+ Raises:
180
+ ValueError: If the value is not a square matrix.
181
+ """
182
+ checkify.check(
183
+ value.shape[0] == value.shape[1],
184
+ "value needs to be a square matrix, got {value}",
185
+ value=value,
186
+ )
161
187
 
162
188
 
163
- def _check_in_bounds(value: T, low: float, high: float):
164
- if jnp.any((value < low) | (value > high)):
165
- raise ValueError(
166
- f"Expected parameter value to be bounded between {low} and {high}. Got {value}."
167
- )
189
+ @checkify.checkify
190
+ def _check_is_lower_triangular(value: T) -> None:
191
+ """Check if a value is a lower triangular matrix.
192
+
193
+ Args:
194
+ value: The value to check.
195
+
196
+ Raises:
197
+ ValueError: If the value is not a lower triangular matrix.
198
+ """
199
+ checkify.check(
200
+ jnp.all(jnp.tril(value) == value),
201
+ "value needs to be a lower triangular matrix, got {value}",
202
+ value=value,
203
+ )
204
+
205
+
206
+ @checkify.checkify
207
+ def _check_in_bounds(value: T, low: T, high: T) -> None:
208
+ """Check if a value is bounded between low and high.
209
+
210
+ Args:
211
+ value: The value to check.
212
+ low: The lower bound.
213
+ high: The upper bound.
214
+
215
+ Raises:
216
+ ValueError: If any element of value is outside the bounds.
217
+ """
218
+ checkify.check(
219
+ jnp.all((value >= low) & (value <= high)),
220
+ "value needs to be bounded between {low} and {high}, got {value}",
221
+ value=value,
222
+ low=low,
223
+ high=high,
224
+ )
225
+
226
+
227
+ def _safe_assert(fn: tp.Callable[[tp.Any], None], value: T, **kwargs) -> None:
228
+ error, _ = fn(value, **kwargs)
229
+ checkify.check_error(error)
gpjax-0.10.0.dist-info/METADATA → gpjax-0.10.2.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: gpjax
3
- Version: 0.10.0
3
+ Version: 0.10.2
4
4
  Summary: Gaussian processes in JAX.
5
5
  Project-URL: Documentation, https://docs.jaxgaussianprocesses.com/
6
6
  Project-URL: Issues, https://github.com/JaxGaussianProcesses/GPJax/issues
gpjax-0.10.0.dist-info/RECORD → gpjax-0.10.2.dist-info/RECORD CHANGED
@@ -1,4 +1,4 @@
1
- gpjax/__init__.py,sha256=LeAdMRx9XYvLf6csLhCIv6IHnDbAFB9rP--TYXECgz0,1654
1
+ gpjax/__init__.py,sha256=F9GVk18tdmvwiDEHZNo_4Wr0TkmPhWIEwl3KzEWQcaQ,1654
2
2
  gpjax/citation.py,sha256=f2Hzj5MLyCE7l0hHAzsEQoTORZH5hgV_eis4uoBiWvE,3811
3
3
  gpjax/dataset.py,sha256=NsToLKq4lOsHnfLfukrUIRKvhOEuoUk8aHTF0oAqRbU,4079
4
4
  gpjax/distributions.py,sha256=X48FJr3reop9maherdMVt7-XZOm2f26T8AJt_IKM_oE,9339
@@ -7,14 +7,14 @@ gpjax/gps.py,sha256=97lYGrsmsufQxKEd8qz5wPNvui6FKXTF_Ps-sMFIjnY,31246
7
7
  gpjax/integrators.py,sha256=eyJPqWNPKj6pKP5da0fEj4HW7BVyevqeGrurEuy_XPw,5694
8
8
  gpjax/likelihoods.py,sha256=DOyV1L0ompkpeImMTiOOiWLJfqSqvDX_acOumuFqPEc,9234
9
9
  gpjax/lower_cholesky.py,sha256=3pnHaBrlGckFsrfYJ9Lsbd0pGmO7NIXdyY4aGm48MpY,1952
10
- gpjax/mean_functions.py,sha256=et2HzlsYJNViBvTohF2wZYgCWQfDX4KboYeO7egMR1c,6420
10
+ gpjax/mean_functions.py,sha256=BpeFkR3Eqa3O_FGp9BtSu9HKNSYZ8M08VtyfPfWbwRg,6479
11
11
  gpjax/objectives.py,sha256=XwkPyL_iovTNKpKGVNt0Lt2_OMTJitSPhuyCtUrJpbc,15383
12
- gpjax/parameters.py,sha256=Z4Wy3gEzPZG23-dtqC437_ZWnd_sPe9LcLCKn21ZBvA,4886
12
+ gpjax/parameters.py,sha256=6VKq6wBzEUtx-GXniC8fEqjTNrTC1YwIOw66QguW6UM,6457
13
13
  gpjax/scan.py,sha256=jStQvwkE9MGttB89frxam1kaeXdWih7cVxkGywyaeHQ,5365
14
14
  gpjax/typing.py,sha256=M3CvWsYtZ3PFUvBvvbRNjpwerNII0w4yGuP0I-sLeYI,1705
15
15
  gpjax/variational_families.py,sha256=s1rk7PtNTjQPabmVu-jBsuJBoqsxAAXwKFZJOEswkNQ,28161
16
16
  gpjax/kernels/__init__.py,sha256=WZanH0Tpdkt0f7VfMqnalm_VZAMVwBqeOVaICNj6xQU,1901
17
- gpjax/kernels/base.py,sha256=abkj3zidsBs7YSkYEfjeJ5jTs1YyDCPoBM2ZzqaqrgI,11561
17
+ gpjax/kernels/base.py,sha256=wXsrpm5ofy9S5MNgUkJk4lx2umcIJL6dDNhXY7cmTGk,11616
18
18
  gpjax/kernels/approximations/__init__.py,sha256=bK9HlGd-PZeGrqtG5RpXxUTXNUrZTgfjH1dP626yNMA,68
19
19
  gpjax/kernels/approximations/rff.py,sha256=4kD1uocjHmxkLgvf4DxB4_Gy7iefdPgnWiZB3jDiExI,4126
20
20
  gpjax/kernels/computations/__init__.py,sha256=uTVkqvnZVesFLDN92h0ZR0jfR69Eo2WyjOlmSYmCPJ8,1379
@@ -30,7 +30,7 @@ gpjax/kernels/non_euclidean/utils.py,sha256=z42aw8ga0zuREzHawemR9okttgrAUPmq-aN5
30
30
  gpjax/kernels/nonstationary/__init__.py,sha256=YpWQfOy_cqOKc5ezn37vqoK3Z6jznYiJz28BD_8F7AY,930
31
31
  gpjax/kernels/nonstationary/arccosine.py,sha256=UCTVJEhTZFQjARGFsYMImLnTDyTyxobIL5f2LiAHkPI,5822
32
32
  gpjax/kernels/nonstationary/linear.py,sha256=UKDHFCQzKWDMYo76qcb5-ujjnP2_iL-1tcN017xjK48,2562
33
- gpjax/kernels/nonstationary/polynomial.py,sha256=yTGobMPbCnKlj4PiQPSXEkWNrj2sjg_x9zFsnFa_j4E,3257
33
+ gpjax/kernels/nonstationary/polynomial.py,sha256=7SDMfEcBCqnRn9xyj4iGcYLNvYJZiveN3uLZ_h12p10,3257
34
34
  gpjax/kernels/stationary/__init__.py,sha256=j4BMTaQlIx2kNAT1Dkf4iO2rm-f7_oSVWNrk1bN0tqE,1406
35
35
  gpjax/kernels/stationary/base.py,sha256=pQNkMo-E4bIT4tNfb7JvFJZC6fIIXNErsT1iQopFlAA,7063
36
36
  gpjax/kernels/stationary/matern12.py,sha256=b2vQCUqhd9NJK84L2RYjpI597uxy_7xgwsjS35Gc958,1807
@@ -42,7 +42,7 @@ gpjax/kernels/stationary/rational_quadratic.py,sha256=dYONp3i4rnKj3ET8UyxAKXv6UO
42
42
  gpjax/kernels/stationary/rbf.py,sha256=G13gg5phO7ite7D9QgoCy7gB2_y0FM6GZhgFW4RL6Xw,1734
43
43
  gpjax/kernels/stationary/utils.py,sha256=Xa9EEnxgFqEi08ZSFAZYYHhJ85_3Ac-ZUyUk18B63M4,2225
44
44
  gpjax/kernels/stationary/white.py,sha256=TkdXXZCCjDs7JwR_gj5uvn2s1wyfRbe1vyHhUMJ8jjI,2212
45
- gpjax-0.10.0.dist-info/METADATA,sha256=wZyZSD1p2t_K5m25TrvGJr6lTlfaUxoB12F-0f1d9Co,9970
46
- gpjax-0.10.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
47
- gpjax-0.10.0.dist-info/licenses/LICENSE.txt,sha256=3umwi0h8wmKXOZO8XwRBwSl3vJt2hpWKEqSrSXLR7-I,1084
48
- gpjax-0.10.0.dist-info/RECORD,,
45
+ gpjax-0.10.2.dist-info/METADATA,sha256=mqIBMOMKKiI9qkM_uFHSuPEXY17Jd6bOL5EM2hPiaok,9970
46
+ gpjax-0.10.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
47
+ gpjax-0.10.2.dist-info/licenses/LICENSE.txt,sha256=3umwi0h8wmKXOZO8XwRBwSl3vJt2hpWKEqSrSXLR7-I,1084
48
+ gpjax-0.10.2.dist-info/RECORD,,
File without changes