gpjax 0.11.0__py3-none-any.whl → 0.11.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
gpjax/__init__.py CHANGED
@@ -32,14 +32,15 @@ from gpjax.citation import cite
32
32
  from gpjax.dataset import Dataset
33
33
  from gpjax.fit import (
34
34
  fit,
35
+ fit_lbfgs,
35
36
  fit_scipy,
36
37
  )
37
38
 
38
39
  __license__ = "MIT"
39
- __description__ = "Didactic Gaussian processes in JAX"
40
+ __description__ = "Gaussian processes in JAX and Flax"
40
41
  __url__ = "https://github.com/JaxGaussianProcesses/GPJax"
41
42
  __contributors__ = "https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors"
42
- __version__ = "0.11.0"
43
+ __version__ = "0.11.1"
43
44
 
44
45
  __all__ = [
45
46
  "base",
@@ -56,5 +57,6 @@ __all__ = [
56
57
  "fit",
57
58
  "Module",
58
59
  "param_field",
60
+ "fit_lbfgs",
59
61
  "fit_scipy",
60
62
  ]
gpjax/fit.py CHANGED
@@ -15,13 +15,13 @@
15
15
 
16
16
  import typing as tp
17
17
 
18
- from flax import nnx
19
18
  import jax
20
- from jax.flatten_util import ravel_pytree
21
19
  import jax.numpy as jnp
22
20
  import jax.random as jr
23
- from numpyro.distributions.transforms import Transform
24
21
  import optax as ox
22
+ from flax import nnx
23
+ from jax.flatten_util import ravel_pytree
24
+ from numpyro.distributions.transforms import Transform
25
25
  from scipy.optimize import minimize
26
26
 
27
27
  from gpjax.dataset import Dataset
@@ -127,7 +127,6 @@ def fit( # noqa: PLR0913
127
127
  _check_verbose(verbose)
128
128
 
129
129
  # Model state filtering
130
-
131
130
  graphdef, params, *static_state = nnx.split(model, Parameter, ...)
132
131
 
133
132
  # Parameters bijection to unconstrained space
@@ -253,6 +252,110 @@ def fit_scipy( # noqa: PLR0913
253
252
  return model, history
254
253
 
255
254
 
255
+ def fit_lbfgs(
256
+ *,
257
+ model: Model,
258
+ objective: Objective,
259
+ train_data: Dataset,
260
+ params_bijection: tp.Union[dict[Parameter, Transform], None] = DEFAULT_BIJECTION,
261
+ max_iters: int = 100,
262
+ safe: bool = True,
263
+ max_linesearch_steps: int = 32,
264
+ gtol: float = 1e-5,
265
+ ) -> tuple[Model, jax.Array]:
266
+ r"""Train a Module model with respect to a supplied Objective function.
267
+
268
+ Uses Optax's LBFGS implementation and a jax.lax.while loop.
269
+
270
+ Args:
271
+ model: the model Module to be optimised.
272
+ objective: The objective function that we are optimising with
273
+ respect to.
274
+ train_data (Dataset): The training data to be used for the optimisation.
275
+ max_iters (int): The maximum number of optimisation steps to run. Defaults
276
+ to 500.
277
+ safe (bool): Whether to check the types of the inputs.
278
+ max_linesearch_steps (int): The maximum number of linesearch steps to use
279
+ for finding the stepsize.
280
+ gtol (float): Terminate the optimisation if the L2 norm of the gradient is
281
+ below this threshold.
282
+
283
+ Returns:
284
+ A tuple comprising the optimised model and final loss.
285
+ """
286
+ if safe:
287
+ # Check inputs
288
+ _check_model(model)
289
+ _check_train_data(train_data)
290
+ _check_num_iters(max_iters)
291
+
292
+ # Model state filtering
293
+ graphdef, params, *static_state = nnx.split(model, Parameter, ...)
294
+
295
+ # Parameters bijection to unconstrained space
296
+ if params_bijection is not None:
297
+ params = transform(params, params_bijection, inverse=True)
298
+
299
+ # Loss definition
300
+ def loss(params: nnx.State) -> ScalarFloat:
301
+ params = transform(params, params_bijection)
302
+ model = nnx.merge(graphdef, params, *static_state)
303
+ return objective(model, train_data)
304
+
305
+ # Initialise optimiser
306
+ optim = ox.lbfgs(
307
+ linesearch=ox.scale_by_zoom_linesearch(
308
+ max_linesearch_steps=max_linesearch_steps,
309
+ initial_guess_strategy="one",
310
+ )
311
+ )
312
+ opt_state = optim.init(params)
313
+ loss_value_and_grad = ox.value_and_grad_from_state(loss)
314
+
315
+ # Optimisation step.
316
+ def step(carry):
317
+ params, opt_state = carry
318
+
319
+ # Using optax's value_and_grad_from_state is more efficient given LBFGS uses a linesearch
320
+ # See https://optax.readthedocs.io/en/latest/api/utilities.html#optax.value_and_grad_from_state
321
+ loss_val, loss_gradient = loss_value_and_grad(params, state=opt_state)
322
+ updates, opt_state = optim.update(
323
+ loss_gradient,
324
+ opt_state,
325
+ params,
326
+ value=loss_val,
327
+ grad=loss_gradient,
328
+ value_fn=loss,
329
+ )
330
+ params = ox.apply_updates(params, updates)
331
+
332
+ return params, opt_state
333
+
334
+ def continue_fn(carry):
335
+ _, opt_state = carry
336
+ n = ox.tree_utils.tree_get(opt_state, "count")
337
+ g = ox.tree_utils.tree_get(opt_state, "grad")
338
+ g_l2_norm = ox.tree_utils.tree_l2_norm(g)
339
+ return (n == 0) | ((n < max_iters) & (g_l2_norm >= gtol))
340
+
341
+ # Optimisation loop
342
+ params, opt_state = jax.lax.while_loop(
343
+ continue_fn,
344
+ step,
345
+ (params, opt_state),
346
+ )
347
+ final_loss = ox.tree_utils.tree_get(opt_state, "value")
348
+
349
+ # Parameters bijection to constrained space
350
+ if params_bijection is not None:
351
+ params = transform(params, params_bijection)
352
+
353
+ # Reconstruct model
354
+ model = nnx.merge(graphdef, params, *static_state)
355
+
356
+ return model, final_loss
357
+
358
+
256
359
  def get_batch(train_data: Dataset, batch_size: int, key: KeyArray) -> Dataset:
257
360
  """Batch the data into mini-batches. Sampling is done with replacement.
258
361
 
@@ -23,7 +23,10 @@ from gpjax.kernels.computations import (
23
23
  AbstractKernelComputation,
24
24
  DenseKernelComputation,
25
25
  )
26
- from gpjax.parameters import PositiveReal
26
+ from gpjax.parameters import (
27
+ NonNegativeReal,
28
+ PositiveReal,
29
+ )
27
30
  from gpjax.typing import (
28
31
  Array,
29
32
  ScalarArray,
@@ -91,9 +94,9 @@ class ArcCosine(AbstractKernel):
91
94
  if isinstance(variance, nnx.Variable):
92
95
  self.variance = variance
93
96
  else:
94
- self.variance = PositiveReal(variance)
97
+ self.variance = NonNegativeReal(variance)
95
98
  if tp.TYPE_CHECKING:
96
- self.variance = tp.cast(PositiveReal[ScalarArray], self.variance)
99
+ self.variance = tp.cast(NonNegativeReal[ScalarArray], self.variance)
97
100
 
98
101
  if isinstance(bias_variance, nnx.Variable):
99
102
  self.bias_variance = bias_variance
@@ -23,7 +23,7 @@ from gpjax.kernels.computations import (
23
23
  AbstractKernelComputation,
24
24
  DenseKernelComputation,
25
25
  )
26
- from gpjax.parameters import PositiveReal
26
+ from gpjax.parameters import NonNegativeReal
27
27
  from gpjax.typing import (
28
28
  Array,
29
29
  ScalarArray,
@@ -64,9 +64,9 @@ class Linear(AbstractKernel):
64
64
  if isinstance(variance, nnx.Variable):
65
65
  self.variance = variance
66
66
  else:
67
- self.variance = PositiveReal(variance)
67
+ self.variance = NonNegativeReal(variance)
68
68
  if tp.TYPE_CHECKING:
69
- self.variance = tp.cast(PositiveReal[ScalarArray], self.variance)
69
+ self.variance = tp.cast(NonNegativeReal[ScalarArray], self.variance)
70
70
 
71
71
  def __call__(
72
72
  self,
@@ -23,7 +23,10 @@ from gpjax.kernels.computations import (
23
23
  AbstractKernelComputation,
24
24
  DenseKernelComputation,
25
25
  )
26
- from gpjax.parameters import PositiveReal
26
+ from gpjax.parameters import (
27
+ NonNegativeReal,
28
+ PositiveReal,
29
+ )
27
30
  from gpjax.typing import (
28
31
  Array,
29
32
  ScalarArray,
@@ -76,9 +79,9 @@ class Polynomial(AbstractKernel):
76
79
  if isinstance(variance, nnx.Variable):
77
80
  self.variance = variance
78
81
  else:
79
- self.variance = PositiveReal(variance)
82
+ self.variance = NonNegativeReal(variance)
80
83
  if tp.TYPE_CHECKING:
81
- self.variance = tp.cast(PositiveReal[ScalarArray], self.variance)
84
+ self.variance = tp.cast(NonNegativeReal[ScalarArray], self.variance)
82
85
 
83
86
  self.name = f"Polynomial (degree {self.degree})"
84
87
 
@@ -25,7 +25,10 @@ from gpjax.kernels.computations import (
25
25
  AbstractKernelComputation,
26
26
  DenseKernelComputation,
27
27
  )
28
- from gpjax.parameters import PositiveReal
28
+ from gpjax.parameters import (
29
+ NonNegativeReal,
30
+ PositiveReal,
31
+ )
29
32
  from gpjax.typing import (
30
33
  Array,
31
34
  ScalarArray,
@@ -85,11 +88,11 @@ class StationaryKernel(AbstractKernel):
85
88
  if isinstance(variance, nnx.Variable):
86
89
  self.variance = variance
87
90
  else:
88
- self.variance = PositiveReal(variance)
91
+ self.variance = NonNegativeReal(variance)
89
92
 
90
93
  # static typing
91
94
  if tp.TYPE_CHECKING:
92
- self.variance = tp.cast(PositiveReal[ScalarFloat], self.variance)
95
+ self.variance = tp.cast(NonNegativeReal[ScalarFloat], self.variance)
93
96
 
94
97
  @property
95
98
  def spectral_density(self) -> npd.Normal | npd.StudentT:
gpjax/likelihoods.py CHANGED
@@ -28,7 +28,7 @@ from gpjax.integrators import (
28
28
  GHQuadratureIntegrator,
29
29
  )
30
30
  from gpjax.parameters import (
31
- PositiveReal,
31
+ NonNegativeReal,
32
32
  Static,
33
33
  )
34
34
  from gpjax.typing import (
@@ -134,7 +134,7 @@ class Gaussian(AbstractLikelihood):
134
134
  self,
135
135
  num_datapoints: int,
136
136
  obs_stddev: tp.Union[
137
- ScalarFloat, Float[Array, "#N"], PositiveReal, Static
137
+ ScalarFloat, Float[Array, "#N"], NonNegativeReal, Static
138
138
  ] = 1.0,
139
139
  integrator: AbstractIntegrator = AnalyticalGaussianIntegrator(),
140
140
  ):
@@ -148,8 +148,8 @@ class Gaussian(AbstractLikelihood):
148
148
  likelihoods. Must be an instance of `AbstractIntegrator`. For the Gaussian likelihood, this defaults to
149
149
  the `AnalyticalGaussianIntegrator`, as the expected log likelihood can be computed analytically.
150
150
  """
151
- if not isinstance(obs_stddev, (PositiveReal, Static)):
152
- obs_stddev = PositiveReal(jnp.asarray(obs_stddev))
151
+ if not isinstance(obs_stddev, (NonNegativeReal, Static)):
152
+ obs_stddev = NonNegativeReal(jnp.asarray(obs_stddev))
153
153
  self.obs_stddev = obs_stddev
154
154
 
155
155
  super().__init__(num_datapoints, integrator)
gpjax/mean_functions.py CHANGED
@@ -207,5 +207,5 @@ SumMeanFunction = ft.partial(
207
207
  CombinationMeanFunction, operator=ft.partial(jnp.sum, axis=0)
208
208
  )
209
209
  ProductMeanFunction = ft.partial(
210
- CombinationMeanFunction, operator=ft.partial(jnp.sum, axis=0)
210
+ CombinationMeanFunction, operator=ft.partial(jnp.prod, axis=0)
211
211
  )
gpjax/parameters.py CHANGED
@@ -82,6 +82,14 @@ class Parameter(nnx.Variable[T]):
82
82
  self._tag = tag
83
83
 
84
84
 
85
+ class NonNegativeReal(Parameter[T]):
86
+ """Parameter that is non-negative."""
87
+
88
+ def __init__(self, value: T, tag: ParameterTag = "non_negative", **kwargs):
89
+ super().__init__(value=value, tag=tag, **kwargs)
90
+ _safe_assert(_check_is_non_negative, self.value)
91
+
92
+
85
93
  class PositiveReal(Parameter[T]):
86
94
  """Parameter that is strictly positive."""
87
95
 
@@ -143,6 +151,7 @@ class LowerTriangular(Parameter[T]):
143
151
 
144
152
  DEFAULT_BIJECTION = {
145
153
  "positive": npt.SoftplusTransform(),
154
+ "non_negative": npt.SoftplusTransform(),
146
155
  "real": npt.IdentityTransform(),
147
156
  "sigmoid": npt.SigmoidTransform(),
148
157
  "lower_triangular": FillTriangularTransform(),
@@ -164,6 +173,13 @@ def _check_is_arraylike(value: T) -> None:
164
173
  )
165
174
 
166
175
 
176
+ @checkify.checkify
177
+ def _check_is_non_negative(value):
178
+ checkify.check(
179
+ jnp.all(value >= 0), "value needs to be non-negative, got {value}", value=value
180
+ )
181
+
182
+
167
183
  @checkify.checkify
168
184
  def _check_is_positive(value):
169
185
  checkify.check(
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: gpjax
3
- Version: 0.11.0
3
+ Version: 0.11.1
4
4
  Summary: Gaussian processes in JAX.
5
5
  Project-URL: Documentation, https://docs.jaxgaussianprocesses.com/
6
6
  Project-URL: Issues, https://github.com/JaxGaussianProcesses/GPJax/issues
@@ -1,16 +1,16 @@
1
- gpjax/__init__.py,sha256=wXJtQa_3W7wZEw_t1Dk0uHUzNQQDv8QzsVbnwXCMXcQ,1654
1
+ gpjax/__init__.py,sha256=TjAAfeZTCEl_zsibA8pV76M1jcHkeFhNfWk_SllfgHY,1686
2
2
  gpjax/citation.py,sha256=f2Hzj5MLyCE7l0hHAzsEQoTORZH5hgV_eis4uoBiWvE,3811
3
3
  gpjax/dataset.py,sha256=NsToLKq4lOsHnfLfukrUIRKvhOEuoUk8aHTF0oAqRbU,4079
4
4
  gpjax/distributions.py,sha256=8LWmfmRVHOX29Uy8PkKFi2UhcCiunuu-4TMI_5-krHc,9299
5
- gpjax/fit.py,sha256=STwpeqSuu2pgT6uZU7xd7koPZbAjPDzhcZ8nHfozR7Q,11538
5
+ gpjax/fit.py,sha256=7L2veA6aRNiozZD8fWa-MVDoYFUKjGJahmvjz8Wp-P0,15046
6
6
  gpjax/gps.py,sha256=97lYGrsmsufQxKEd8qz5wPNvui6FKXTF_Ps-sMFIjnY,31246
7
7
  gpjax/integrators.py,sha256=eyJPqWNPKj6pKP5da0fEj4HW7BVyevqeGrurEuy_XPw,5694
8
- gpjax/likelihoods.py,sha256=VcCibgihaskmvNJT4kuPa7ehgjlnR9LgMz_2KJJvHY0,9296
8
+ gpjax/likelihoods.py,sha256=99oTZoWld1M7vxgGM0pNY5Hnt2Ajd2lQNqawzrLmwtk,9308
9
9
  gpjax/lower_cholesky.py,sha256=3pnHaBrlGckFsrfYJ9Lsbd0pGmO7NIXdyY4aGm48MpY,1952
10
- gpjax/mean_functions.py,sha256=gIPz7exEhish3yeJQxZp5Q_jlf2-gCE-KVAnL2Rumkc,6489
10
+ gpjax/mean_functions.py,sha256=-sVYO1_LWE8f34rllUOuaT5sgGGAdxo99v5kRo2d4oM,6490
11
11
  gpjax/numpyro_extras.py,sha256=-vWJ7SpZVNhSdCjjrlxIkovMFrM1IzpsMJK3B4LioGE,3411
12
12
  gpjax/objectives.py,sha256=I_ZqnwTNYIAUAZ9KQNenIl0ish1jDOXb7KaNmjz3Su4,15340
13
- gpjax/parameters.py,sha256=Vj1xzrziSLxfBSqyc-BacyKBwkbE9Sjq4b1HV5HZiOg,6507
13
+ gpjax/parameters.py,sha256=H-DiXmotdBZCbf-GOjRaJoS_isk3GgFrpKHTq5GpnoA,6998
14
14
  gpjax/scan.py,sha256=jStQvwkE9MGttB89frxam1kaeXdWih7cVxkGywyaeHQ,5365
15
15
  gpjax/typing.py,sha256=M3CvWsYtZ3PFUvBvvbRNjpwerNII0w4yGuP0I-sLeYI,1705
16
16
  gpjax/variational_families.py,sha256=Y9J1H91tXPm_hMy3ri_PgjAxqc_3r-BqKV83HRvB_m4,28295
@@ -29,11 +29,11 @@ gpjax/kernels/non_euclidean/__init__.py,sha256=RT7puRPqCTpyxZ16q596EuOQEQi1LK1v3
29
29
  gpjax/kernels/non_euclidean/graph.py,sha256=K4WIdX-dx1SsWuNHZnNjHFw8ElKZxGcReUiA3w4aCOI,4204
30
30
  gpjax/kernels/non_euclidean/utils.py,sha256=z42aw8ga0zuREzHawemR9okttgrAUPmq-aN5HMt4SuY,1578
31
31
  gpjax/kernels/nonstationary/__init__.py,sha256=YpWQfOy_cqOKc5ezn37vqoK3Z6jznYiJz28BD_8F7AY,930
32
- gpjax/kernels/nonstationary/arccosine.py,sha256=UCTVJEhTZFQjARGFsYMImLnTDyTyxobIL5f2LiAHkPI,5822
33
- gpjax/kernels/nonstationary/linear.py,sha256=UKDHFCQzKWDMYo76qcb5-ujjnP2_iL-1tcN017xjK48,2562
34
- gpjax/kernels/nonstationary/polynomial.py,sha256=7SDMfEcBCqnRn9xyj4iGcYLNvYJZiveN3uLZ_h12p10,3257
32
+ gpjax/kernels/nonstationary/arccosine.py,sha256=2WV6aM0Z3-xXZnoPw-77n2CW62n-AZuJy-7AQ9xrMco,5858
33
+ gpjax/kernels/nonstationary/linear.py,sha256=UIMoCq2hg6dQKr4J5UGiiPqotBleQuYfy00Ia1NaMOo,2571
34
+ gpjax/kernels/nonstationary/polynomial.py,sha256=arP8DK0jnBOaayDWcFvHF0pdu9FVhwzXdqjnHUAL2VI,3293
35
35
  gpjax/kernels/stationary/__init__.py,sha256=j4BMTaQlIx2kNAT1Dkf4iO2rm-f7_oSVWNrk1bN0tqE,1406
36
- gpjax/kernels/stationary/base.py,sha256=FlsXMsXyZ5cI80jbsIo8Jv-H6gsV3C7v6plIhyCl-GI,7042
36
+ gpjax/kernels/stationary/base.py,sha256=25qDqpZP4gNtzbyzDCW-6u7rJfMqkg0dW88XUmTTupU,7078
37
37
  gpjax/kernels/stationary/matern12.py,sha256=DGjqw6VveYsyy0TrufyJJvCei7p9slnm2f0TgRGG7_U,1773
38
38
  gpjax/kernels/stationary/matern32.py,sha256=laLsJWJozJzpYHBzlkPUq0rWxz1eWEwGC36P2nPJuaQ,1966
39
39
  gpjax/kernels/stationary/matern52.py,sha256=VSByD2sb7k-DzRFjaz31P3Rtc4bPPhHvMshrxZNFnns,2019
@@ -43,7 +43,7 @@ gpjax/kernels/stationary/rational_quadratic.py,sha256=dYONp3i4rnKj3ET8UyxAKXv6UO
43
43
  gpjax/kernels/stationary/rbf.py,sha256=euHUs6FdfRICQcabAWE4MX-7GEDr2TxgZWdFQiXr9Bw,1690
44
44
  gpjax/kernels/stationary/utils.py,sha256=6BI9EBcCzeeKx-XH-MfW1ORmtU__tPX5zyvfLhpkBsU,2180
45
45
  gpjax/kernels/stationary/white.py,sha256=TkdXXZCCjDs7JwR_gj5uvn2s1wyfRbe1vyHhUMJ8jjI,2212
46
- gpjax-0.11.0.dist-info/METADATA,sha256=eSWVc5y9WNrUmKpaOVq1CcHjrKjMwlmSvwovN9h9aCk,8558
47
- gpjax-0.11.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
48
- gpjax-0.11.0.dist-info/licenses/LICENSE.txt,sha256=3umwi0h8wmKXOZO8XwRBwSl3vJt2hpWKEqSrSXLR7-I,1084
49
- gpjax-0.11.0.dist-info/RECORD,,
46
+ gpjax-0.11.1.dist-info/METADATA,sha256=02crI6D0dsht6XJ8N1ZqNj5ZktmS5NymVfY45pPmEgM,8558
47
+ gpjax-0.11.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
48
+ gpjax-0.11.1.dist-info/licenses/LICENSE.txt,sha256=3umwi0h8wmKXOZO8XwRBwSl3vJt2hpWKEqSrSXLR7-I,1084
49
+ gpjax-0.11.1.dist-info/RECORD,,
File without changes