gpjax 0.12.0__py3-none-any.whl → 0.13.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
gpjax/__init__.py CHANGED
@@ -40,10 +40,9 @@ __license__ = "MIT"
  __description__ = "Gaussian processes in JAX and Flax"
  __url__ = "https://github.com/JaxGaussianProcesses/GPJax"
  __contributors__ = "https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors"
- __version__ = "0.12.0"
+ __version__ = "0.13.0"

  __all__ = [
-     "base",
      "gps",
      "integrators",
      "kernels",
@@ -55,8 +54,6 @@ __all__ = [
      "Dataset",
      "cite",
      "fit",
-     "Module",
-     "param_field",
      "fit_lbfgs",
      "fit_scipy",
  ]
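With `base`, `Module`, and `param_field` gone from the public namespace, models are plain `flax.nnx.Module` subclasses and values that were previously wrapped in `Static` are assigned directly; only `Parameter` leaves are picked up by the optimisers. A one-line illustration of the migration (hypothetical attribute, mirroring the `fit` docstring change further down):

    self.bias = Static(bias)   # 0.12: non-trainable values needed an explicit Static wrapper
    self.bias = bias           # 0.13: plain attribute; only Parameter instances are optimised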
gpjax/fit.py CHANGED
@@ -48,6 +48,7 @@ def fit( # noqa: PLR0913
      train_data: Dataset,
      optim: ox.GradientTransformation,
      params_bijection: tp.Union[dict[Parameter, Transform], None] = DEFAULT_BIJECTION,
+     trainable: nnx.filterlib.Filter = Parameter,
      key: KeyArray = jr.PRNGKey(42),
      num_iters: int = 100,
      batch_size: int = -1,
@@ -65,7 +66,7 @@ def fit( # noqa: PLR0913
          >>> import jax.random as jr
          >>> import optax as ox
          >>> import gpjax as gpx
-         >>> from gpjax.parameters import PositiveReal, Static
+         >>> from gpjax.parameters import PositiveReal
          >>>
          >>> # (1) Create a dataset:
          >>> X = jnp.linspace(0.0, 10.0, 100)[:, None]
@@ -75,10 +76,10 @@ def fit( # noqa: PLR0913
          >>> class LinearModel(nnx.Module):
          >>>     def __init__(self, weight: float, bias: float):
          >>>         self.weight = PositiveReal(weight)
-         >>>         self.bias = Static(bias)
+         >>>         self.bias = bias
          >>>
          >>>     def __call__(self, x):
-         >>>         return self.weight.value * x + self.bias.value
+         >>>         return self.weight.value * x + self.bias
          >>>
          >>> model = LinearModel(weight=1.0, bias=1.0)
          >>>
@@ -100,6 +101,8 @@ def fit( # noqa: PLR0913
          train_data (Dataset): The training data to be used for the optimisation.
          optim (GradientTransformation): The Optax optimiser that is to be used for
              learning a parameter set.
+         trainable (nnx.filterlib.Filter): Filter to determine which parameters are trainable.
+             Defaults to nnx.Param (all Parameter instances).
          num_iters (int): The number of optimisation steps to run. Defaults
              to 100.
          batch_size (int): The size of the mini-batch to use. Defaults to -1
@@ -127,7 +130,7 @@ def fit( # noqa: PLR0913
      _check_verbose(verbose)

      # Model state filtering
-     graphdef, params, *static_state = nnx.split(model, Parameter, ...)
+     graphdef, params, *static_state = nnx.split(model, trainable, ...)

      # Parameters bijection to unconstrained space
      if params_bijection is not None:
@@ -182,6 +185,7 @@ def fit_scipy( # noqa: PLR0913
      model: Model,
      objective: Objective,
      train_data: Dataset,
+     trainable: nnx.filterlib.Filter = Parameter,
      max_iters: int = 500,
      verbose: bool = True,
      safe: bool = True,
@@ -210,7 +214,7 @@ def fit_scipy( # noqa: PLR0913
      _check_verbose(verbose)

      # Model state filtering
-     graphdef, params, *static_state = nnx.split(model, Parameter, ...)
+     graphdef, params, *static_state = nnx.split(model, trainable, ...)

      # Parameters bijection to unconstrained space
      params = transform(params, DEFAULT_BIJECTION, inverse=True)
@@ -258,6 +262,7 @@ def fit_lbfgs(
      objective: Objective,
      train_data: Dataset,
      params_bijection: tp.Union[dict[Parameter, Transform], None] = DEFAULT_BIJECTION,
+     trainable: nnx.filterlib.Filter = Parameter,
      max_iters: int = 100,
      safe: bool = True,
      max_linesearch_steps: int = 32,
@@ -290,7 +295,7 @@ def fit_lbfgs(
      _check_num_iters(max_iters)

      # Model state filtering
-     graphdef, params, *static_state = nnx.split(model, Parameter, ...)
+     graphdef, params, *static_state = nnx.split(model, trainable, ...)

      # Parameters bijection to unconstrained space
      if params_bijection is not None:
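Each of `fit`, `fit_scipy`, and `fit_lbfgs` gains a `trainable` keyword: an `nnx` filter passed straight to `nnx.split`, so anything the filter does not match is carried as static state and left untouched by the optimiser. A minimal sketch, reusing the `model`, `D`, and `mse` names from the `fit` docstring above (those bindings, and the `(model, history)` return pair, are assumptions of this sketch):

    import optax as ox
    import gpjax as gpx
    from gpjax.parameters import PositiveReal

    # Only PositiveReal variables are optimised; any other Parameter subclasses
    # (and plain attributes) are treated as static state.
    trained_model, history = gpx.fit(
        model=model,
        objective=mse,
        train_data=D,
        optim=ox.adam(learning_rate=1e-2),
        trainable=PositiveReal,
        num_iters=100,
    )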
gpjax/gps.py CHANGED
@@ -35,16 +35,15 @@ from gpjax.likelihoods import (
  )
  from gpjax.linalg import (
      Dense,
-     Identity,
      psd,
      solve,
  )
  from gpjax.linalg.operations import lower_cholesky
+ from gpjax.linalg.utils import add_jitter
  from gpjax.mean_functions import AbstractMeanFunction
  from gpjax.parameters import (
      Parameter,
      Real,
-     Static,
  )
  from gpjax.typing import (
      Array,
@@ -78,7 +77,7 @@ class AbstractPrior(nnx.Module, tp.Generic[M, K]):
          self.mean_function = mean_function
          self.jitter = jitter

-     def __call__(self, *args: tp.Any, **kwargs: tp.Any) -> GaussianDistribution:
+     def __call__(self, test_inputs: Num[Array, "N D"]) -> GaussianDistribution:
          r"""Evaluate the Gaussian process at the given points.

          The output of this function is a
@@ -91,17 +90,16 @@ class AbstractPrior(nnx.Module, tp.Generic[M, K]):
          `__call__` method and should instead define a `predict` method.

          Args:
-             *args (Any): The arguments to pass to the GP's `predict` method.
-             **kwargs (Any): The keyword arguments to pass to the GP's `predict` method.
+             test_inputs: Input locations where the GP should be evaluated.

          Returns:
              GaussianDistribution: A multivariate normal random variable representation
                  of the Gaussian process.
          """
-         return self.predict(*args, **kwargs)
+         return self.predict(test_inputs)

      @abstractmethod
-     def predict(self, *args: tp.Any, **kwargs: tp.Any) -> GaussianDistribution:
+     def predict(self, test_inputs: Num[Array, "N D"]) -> GaussianDistribution:
          r"""Evaluate the predictive distribution.

          Compute the latent function's multivariate normal distribution for a
@@ -109,8 +107,7 @@ class AbstractPrior(nnx.Module, tp.Generic[M, K]):
          this method must be implemented.

          Args:
-             *args (Any): Arguments to the predict method.
-             **kwargs (Any): Keyword arguments to the predict method.
+             test_inputs: Input locations where the GP should be evaluated.

          Returns:
              GaussianDistribution: A multivariate normal random variable representation
@@ -249,13 +246,12 @@ class Prior(AbstractPrior[M, K]):
              GaussianDistribution: A multivariate normal random variable representation
                  of the Gaussian process.
          """
-         x = test_inputs
-         mx = self.mean_function(x)
-         Kxx = self.kernel.gram(x)
-         Kxx_dense = Kxx.to_dense() + Identity(Kxx.shape).to_dense() * self.jitter
+         mean_at_test = self.mean_function(test_inputs)
+         Kxx = self.kernel.gram(test_inputs)
+         Kxx_dense = add_jitter(Kxx.to_dense(), self.jitter)
          Kxx = psd(Dense(Kxx_dense))

-         return GaussianDistribution(jnp.atleast_1d(mx.squeeze()), Kxx)
+         return GaussianDistribution(jnp.atleast_1d(mean_at_test.squeeze()), Kxx)

      def sample_approx(
          self,
@@ -359,7 +355,9 @@ class AbstractPosterior(nnx.Module, tp.Generic[P, L]):
          self.likelihood = likelihood
          self.jitter = jitter

-     def __call__(self, *args: tp.Any, **kwargs: tp.Any) -> GaussianDistribution:
+     def __call__(
+         self, test_inputs: Num[Array, "N D"], train_data: Dataset
+     ) -> GaussianDistribution:
          r"""Evaluate the Gaussian process posterior at the given points.

          The output of this function is a
@@ -368,28 +366,30 @@ class AbstractPosterior(nnx.Module, tp.Generic[P, L]):
          evaluated and the distribution can be sampled.

          Under the hood, `__call__` is calling the objects `predict` method. For this
-         reasons, classes inheriting the `AbstractPrior` class, should not overwrite the
+         reasons, classes inheriting the `AbstractPosterior` class, should not overwrite the
          `__call__` method and should instead define a `predict` method.

          Args:
-             *args (Any): The arguments to pass to the GP's `predict` method.
-             **kwargs (Any): The keyword arguments to pass to the GP's `predict` method.
+             test_inputs: Input locations where the GP should be evaluated.
+             train_data: Training dataset to condition on.

          Returns:
              GaussianDistribution: A multivariate normal random variable representation
                  of the Gaussian process.
          """
-         return self.predict(*args, **kwargs)
+         return self.predict(test_inputs, train_data)

      @abstractmethod
-     def predict(self, *args: tp.Any, **kwargs: tp.Any) -> GaussianDistribution:
+     def predict(
+         self, test_inputs: Num[Array, "N D"], train_data: Dataset
+     ) -> GaussianDistribution:
          r"""Compute the latent function's multivariate normal distribution for a
-         given set of parameters. For any class inheriting the `AbstractPrior` class,
+         given set of parameters. For any class inheriting the `AbstractPosterior` class,
          this method must be implemented.

          Args:
-             *args (Any): Arguments to the predict method.
-             **kwargs (Any): Keyword arguments to the predict method.
+             test_inputs: Input locations where the GP should be evaluated.
+             train_data: Training dataset to condition on.

          Returns:
              GaussianDistribution: A multivariate normal random variable representation
@@ -503,22 +503,24 @@ class ConjugatePosterior(AbstractPosterior[P, GL]):

          # Precompute Gram matrix, Kxx, at training inputs, x
          Kxx = self.prior.kernel.gram(x)
-         Kxx_dense = Kxx.to_dense() + Identity(Kxx.shape).to_dense() * self.jitter
+         Kxx_dense = add_jitter(Kxx.to_dense(), self.jitter)
          Kxx = Dense(Kxx_dense)

          Sigma_dense = Kxx.to_dense() + jnp.eye(Kxx.shape[0]) * obs_noise
          Sigma = psd(Dense(Sigma_dense))
+         L_sigma = lower_cholesky(Sigma)

          mean_t = self.prior.mean_function(t)
          Ktt = self.prior.kernel.gram(t)
          Kxt = self.prior.kernel.cross_covariance(x, t)
-         Sigma_inv_Kxt = solve(Sigma, Kxt)

-         mean = mean_t + jnp.matmul(Sigma_inv_Kxt.T, y - mx)
+         L_inv_Kxt = solve(L_sigma, Kxt)
+         L_inv_y_diff = solve(L_sigma, y - mx)
+
+         mean = mean_t + jnp.matmul(L_inv_Kxt.T, L_inv_y_diff)

-         # Ktt - Ktx (Kxx + Io²)⁻¹ Kxt, TODO: Take advantage of covariance structure to compute Schur complement more efficiently.
-         covariance = Ktt.to_dense() - jnp.matmul(Kxt.T, Sigma_inv_Kxt)
-         covariance += jnp.eye(covariance.shape[0]) * self.prior.jitter
+         covariance = Ktt.to_dense() - jnp.matmul(L_inv_Kxt.T, L_inv_Kxt)
+         covariance = add_jitter(covariance, self.prior.jitter)
          covariance = psd(Dense(covariance))

          return GaussianDistribution(jnp.atleast_1d(mean.squeeze()), covariance)
@@ -577,7 +579,7 @@ class ConjugatePosterior(AbstractPosterior[P, GL]):

          obs_var = self.likelihood.obs_stddev.value**2
          Kxx = self.prior.kernel.gram(train_data.X)
-         Sigma = Kxx + jnp.eye(Kxx.shape[0]) * (obs_var + self.jitter)
+         Sigma = Dense(add_jitter(Kxx.to_dense(), obs_var + self.jitter))
          eps = jnp.sqrt(obs_var) * jr.normal(key, [train_data.n, num_samples])
          y = train_data.y - self.prior.mean_function(train_data.X)
          Phi = fourier_feature_fn(train_data.X)
@@ -643,7 +645,7 @@ class NonConjugatePosterior(AbstractPosterior[P, NGL]):

          # TODO: static or intermediate?
          self.latent = latent if isinstance(latent, Parameter) else Real(latent)
-         self.key = Static(key)
+         self.key = key

      def predict(
          self, test_inputs: Num[Array, "N D"], train_data: Dataset
@@ -675,7 +677,7 @@ class NonConjugatePosterior(AbstractPosterior[P, NGL]):

          # Precompute lower triangular of Gram matrix, Lx, at training inputs, x
          Kxx = kernel.gram(x)
-         Kxx_dense = Kxx.to_dense() + jnp.eye(Kxx.shape[0]) * self.prior.jitter
+         Kxx_dense = add_jitter(Kxx.to_dense(), self.prior.jitter)
          Kxx = psd(Dense(Kxx_dense))
          Lx = lower_cholesky(Kxx)

@@ -698,7 +700,7 @@ class NonConjugatePosterior(AbstractPosterior[P, NGL]):

          # Ktt - Ktx Kxx⁻¹ Kxt, TODO: Take advantage of covariance structure to compute Schur complement more efficiently.
          covariance = Ktt.to_dense() - jnp.matmul(Lx_inv_Kxt.T, Lx_inv_Kxt)
-         covariance += jnp.eye(covariance.shape[0]) * self.prior.jitter
+         covariance = add_jitter(covariance, self.prior.jitter)
          covariance = psd(Dense(covariance))

          return GaussianDistribution(jnp.atleast_1d(mean.squeeze()), covariance)
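The conjugate predictive equations are unchanged by the move to `add_jitter` and the explicit Cholesky factor; with $\Sigma = K_{xx} + \sigma^2 I = LL^\top$, the two formulations agree because (standard GP algebra, not quoted from the package source)

$$
K_{xt}^\top \Sigma^{-1}(y - m_x) = (L^{-1}K_{xt})^\top L^{-1}(y - m_x), \qquad
K_{xt}^\top \Sigma^{-1} K_{xt} = (L^{-1}K_{xt})^\top (L^{-1}K_{xt}),
$$

so the posterior mean and covariance computed via `L_inv_Kxt` and `L_inv_y_diff` match the previous `solve(Sigma, Kxt)` path while reusing a single triangular factor. (`add_jitter(A, eps)` is assumed here to return `A + eps * I`, replacing the earlier `Identity(...)`/`jnp.eye` constructions.)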
@@ -1,13 +1,13 @@
  """Compute Random Fourier Feature (RFF) kernel approximations."""

  import beartype.typing as tp
+ from flax import nnx
  import jax.random as jr
  from jaxtyping import Float

  from gpjax.kernels.base import AbstractKernel
  from gpjax.kernels.computations import BasisFunctionComputation
  from gpjax.kernels.stationary.base import StationaryKernel
- from gpjax.parameters import Static
  from gpjax.typing import (
      Array,
      KeyArray,
@@ -55,7 +55,7 @@ class RFF(AbstractKernel):
          self._check_valid_base_kernel(base_kernel)
          self.base_kernel = base_kernel
          self.num_basis_fns = num_basis_fns
-         self.frequencies = frequencies
+         self.frequencies = nnx.data(frequencies)
          self.compute_engine = compute_engine

          if self.frequencies is None:
@@ -66,10 +66,8 @@
                      "Please specify the n_dims argument for the base kernel."
                  )

-             self.frequencies = Static(
-                 self.base_kernel.spectral_density.sample(
-                     key=key, sample_shape=(self.num_basis_fns, n_dims)
-                 )
+             self.frequencies = self.base_kernel.spectral_density.sample(
+                 key=key, sample_shape=(self.num_basis_fns, n_dims)
              )
          self.name = f"{self.base_kernel.name} (RFF)"

gpjax/kernels/base.py CHANGED
@@ -32,7 +32,6 @@ from gpjax.linalg import LinearOperator
  from gpjax.parameters import (
      Parameter,
      Real,
-     Static,
  )
  from gpjax.typing import (
      Array,
@@ -221,9 +220,7 @@ class Constant(AbstractKernel):
      def __init__(
          self,
          active_dims: tp.Union[list[int], slice, None] = None,
-         constant: tp.Union[
-             ScalarFloat, Parameter[ScalarFloat], Static[ScalarFloat]
-         ] = jnp.array(0.0),
+         constant: tp.Union[ScalarFloat, Parameter[ScalarFloat]] = jnp.array(0.0),
          compute_engine: AbstractKernelComputation = DenseKernelComputation(),
      ):
          if isinstance(constant, Parameter):
@@ -256,7 +253,7 @@ class CombinationKernel(AbstractKernel):
          compute_engine: AbstractKernelComputation = DenseKernelComputation(),
      ):
          # Add kernels to a list, flattening out instances of this class therein, as in GPFlow kernels.
-         kernels_list: list[AbstractKernel] = []
+         kernels_list: list[AbstractKernel] = nnx.List([])
          for kernel in kernels:
              if not isinstance(kernel, AbstractKernel):
                  raise TypeError("can only combine Kernel instances") # pragma: no cover
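`CombinationKernel` now collects its flattened children in an `nnx.List` instead of a plain Python list, keeping the child kernels visible to `flax.nnx`'s graph utilities (and hence to the `trainable` filter used by `fit`). Construction itself is unchanged; a small sketch of splitting a combined kernel, assuming the usual operator overloads:

    from flax import nnx
    import gpjax as gpx
    from gpjax.parameters import Parameter

    k = gpx.kernels.RBF() + gpx.kernels.Matern32()          # SumKernel with two children
    graphdef, params, rest = nnx.split(k, Parameter, ...)   # child lengthscales/variances land in `params`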
@@ -57,7 +57,7 @@ class BasisFunctionComputation(AbstractKernelComputation):
          Returns:
              A matrix of shape $N \times L$ representing the random fourier features where $L = 2M$.
          """
-         frequencies = kernel.frequencies.value
+         frequencies = kernel.frequencies
          scaling_factor = kernel.base_kernel.lengthscale.value
          z = jnp.matmul(x, (frequencies / scaling_factor).T)
          z = jnp.concatenate([jnp.cos(z), jnp.sin(z)], axis=-1)
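The hunk above drops `.value` because, as of 0.13, the RFF frequencies are ordinary arrays registered with `nnx.data` rather than `Static` parameters. A hedged sketch of the new access pattern; the constructor keywords follow the signature implied by the RFF hunks above and are otherwise assumptions:

    import jax.random as jr
    import gpjax as gpx

    base = gpx.kernels.RBF(n_dims=1)
    rff = gpx.kernels.RFF(base_kernel=base, num_basis_fns=32, key=jr.PRNGKey(0))

    omega = rff.frequencies          # 0.13: a (32, 1) array, used as-is
    # omega = rff.frequencies.value  # 0.12: Static wrapper required .value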
@@ -42,7 +42,7 @@ class EigenKernelComputation(AbstractKernelComputation):
          # Transform the eigenvalues of the graph Laplacian according to the
          # RBF kernel's SPDE form.
          S = jnp.power(
-             kernel.eigenvalues.value
+             kernel.eigenvalues
              + 2
              * kernel.smoothness.value
              / kernel.lengthscale.value
@@ -30,7 +30,6 @@ from gpjax.kernels.stationary.base import StationaryKernel
  from gpjax.parameters import (
      Parameter,
      PositiveReal,
-     Static,
  )
  from gpjax.typing import (
      Array,
@@ -55,9 +54,9 @@ class GraphKernel(StationaryKernel):
      """

      num_vertex: tp.Union[ScalarInt, None]
-     laplacian: Static[Float[Array, "N N"]]
-     eigenvalues: Static[Float[Array, "N 1"]]
-     eigenvectors: Static[Float[Array, "N N"]]
+     laplacian: Float[Array, "N N"]
+     eigenvalues: Float[Array, "N 1"]
+     eigenvectors: Float[Array, "N N"]
      name: str = "Graph Matérn"

      def __init__(
@@ -91,11 +90,11 @@ class GraphKernel(StationaryKernel):
          else:
              self.smoothness = PositiveReal(smoothness)

-         self.laplacian = Static(laplacian)
-         evals, eigenvectors = jnp.linalg.eigh(self.laplacian.value)
-         self.eigenvectors = Static(eigenvectors)
-         self.eigenvalues = Static(evals.reshape(-1, 1))
-         self.num_vertex = self.eigenvalues.value.shape[0]
+         self.laplacian = laplacian
+         evals, eigenvectors = jnp.linalg.eigh(self.laplacian)
+         self.eigenvectors = eigenvectors
+         self.eigenvalues = evals.reshape(-1, 1)
+         self.num_vertex = self.eigenvalues.shape[0]

          super().__init__(active_dims, lengthscale, variance, n_dims, compute_engine)

@@ -107,7 +106,7 @@
          S,
          **kwargs,
      ):
-         Kxx = (jax_gather_nd(self.eigenvectors.value, x) * S.squeeze()) @ jnp.transpose(
-             jax_gather_nd(self.eigenvectors.value, y)
+         Kxx = (jax_gather_nd(self.eigenvectors, x) * S.squeeze()) @ jnp.transpose(
+             jax_gather_nd(self.eigenvectors, y)
          ) # shape (n,n)
          return Kxx.squeeze()
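The graph Laplacian and its eigendecomposition are now held as raw arrays on the kernel, so downstream code reads them without `.value`. A toy sketch (illustrative Laplacian; constructor defaults assumed):

    import jax.numpy as jnp
    import gpjax as gpx

    # Laplacian of a 3-node path graph.
    L = jnp.array([[ 1.0, -1.0,  0.0],
                   [-1.0,  2.0, -1.0],
                   [ 0.0, -1.0,  1.0]])
    kern = gpx.kernels.GraphKernel(laplacian=L)
    kern.eigenvalues.shape   # (3, 1) array in 0.13; was kern.eigenvalues.value.shape in 0.12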
@@ -25,7 +25,6 @@ from gpjax.kernels.computations import (
  )
  from gpjax.parameters import (
      NonNegativeReal,
-     PositiveReal,
  )
  from gpjax.typing import (
      Array,
@@ -82,30 +81,13 @@ class ArcCosine(AbstractKernel):

          self.order = order

-         if isinstance(weight_variance, nnx.Variable):
-             self.weight_variance = weight_variance
-         else:
-             self.weight_variance = PositiveReal(weight_variance)
-             if tp.TYPE_CHECKING:
-                 self.weight_variance = tp.cast(
-                     PositiveReal[WeightVariance], self.weight_variance
-                 )
+         self.weight_variance = weight_variance

          if isinstance(variance, nnx.Variable):
              self.variance = variance
          else:
              self.variance = NonNegativeReal(variance)
-             if tp.TYPE_CHECKING:
-                 self.variance = tp.cast(NonNegativeReal[ScalarArray], self.variance)
-
-         if isinstance(bias_variance, nnx.Variable):
-             self.bias_variance = bias_variance
-         else:
-             self.bias_variance = PositiveReal(bias_variance)
-             if tp.TYPE_CHECKING:
-                 self.bias_variance = tp.cast(
-                     PositiveReal[ScalarArray], self.bias_variance
-                 )
+         self.bias_variance = bias_variance

          self.name = f"ArcCosine (order {self.order})"

@@ -141,7 +123,17 @@
          Returns:
              ScalarFloat: The value of the weighted product between the two arguments``.
          """
-         return jnp.inner(self.weight_variance.value * x, y) + self.bias_variance.value
+         weight_var = (
+             self.weight_variance.value
+             if hasattr(self.weight_variance, "value")
+             else self.weight_variance
+         )
+         bias_var = (
+             self.bias_variance.value
+             if hasattr(self.bias_variance, "value")
+             else self.bias_variance
+         )
+         return jnp.inner(weight_var * x, y) + bias_var

      def _J(self, theta: ScalarFloat) -> ScalarFloat:
          r"""Evaluate the angular dependency function corresponding to the desired order.
@@ -69,12 +69,9 @@ class Polynomial(AbstractKernel):

          self.degree = degree

-         if isinstance(shift, nnx.Variable):
-             self.shift = shift
-         else:
-             self.shift = PositiveReal(shift)
-             if tp.TYPE_CHECKING:
-                 self.shift = tp.cast(PositiveReal[ScalarArray], self.shift)
+         self.shift = shift
+         if tp.TYPE_CHECKING and not isinstance(shift, nnx.Variable):
+             self.shift = tp.cast(PositiveReal[ScalarArray], self.shift)

          if isinstance(variance, nnx.Variable):
              self.variance = variance
@@ -88,7 +85,9 @@
      def __call__(self, x: Float[Array, " D"], y: Float[Array, " D"]) -> ScalarFloat:
          x = self.slice_input(x)
          y = self.slice_input(y)
-         K = jnp.power(
-             self.shift.value + self.variance.value * jnp.dot(x, y), self.degree
+         shift_val = self.shift.value if hasattr(self.shift, "value") else self.shift
+         variance_val = (
+             self.variance.value if hasattr(self.variance, "value") else self.variance
          )
+         K = jnp.power(shift_val + variance_val * jnp.dot(x, y), self.degree)
          return K.squeeze()
@@ -127,7 +127,7 @@ def _check_lengthscale_dims_compat(
      """

      if isinstance(lengthscale, nnx.Variable):
-         return _check_lengthscale_dims_compat_old(lengthscale.value, n_dims)
+         return _check_lengthscale_dims_compat(lengthscale.value, n_dims)

      lengthscale = jnp.asarray(lengthscale)
      ls_shape = jnp.shape(lengthscale)
@@ -146,35 +146,6 @@
      return n_dims


- def _check_lengthscale_dims_compat_old(
-     lengthscale: tp.Union[LengthscaleCompatible, nnx.Variable[Lengthscale]],
-     n_dims: tp.Union[int, None],
- ):
-     r"""Check that the lengthscale is compatible with n_dims.
-
-     If possible, infer the number of input dimensions from the lengthscale.
-     """
-
-     if isinstance(lengthscale, nnx.Variable):
-         return _check_lengthscale_dims_compat_old(lengthscale.value, n_dims)
-
-     lengthscale = jnp.asarray(lengthscale)
-     ls_shape = jnp.shape(lengthscale)
-
-     if ls_shape == ():
-         return lengthscale, n_dims
-     elif ls_shape != () and n_dims is None:
-         return lengthscale, ls_shape[0]
-     elif ls_shape != () and n_dims is not None:
-         if ls_shape != (n_dims,):
-             raise ValueError(
-                 "Expected `lengthscale` to be compatible with the number "
-                 f"of input dimensions. Got `lengthscale` with shape {ls_shape}, "
-                 f"but the number of input dimensions is {n_dims}."
-             )
-         return lengthscale, n_dims
-
-
  def _check_lengthscale(lengthscale: tp.Any):
      """Check that the lengthscale is a valid value."""

@@ -35,7 +35,7 @@ class Matern12(StationaryKernel):
      lengthscale parameter $\ell$ and variance $\sigma^2$.

      $$
-     k(x, y) = \sigma^2\exp\Bigg(-\frac{\lvert x-y \rvert}{2\ell^2}\Bigg)
+     k(x, y) = \sigma^2\exp\Bigg(-\frac{\lvert x-y \rvert}{2\ell}\Bigg)
      $$
      """

@@ -32,7 +32,7 @@ class Matern32(StationaryKernel):
      lengthscale parameter $\ell$ and variance $\sigma^2$.

      $$
-     k(x, y) = \sigma^2 \exp \Bigg(1+ \frac{\sqrt{3}\lvert x-y \rvert}{\ell^2} \ \Bigg)\exp\Bigg(-\frac{\sqrt{3}\lvert x-y\rvert}{\ell^2} \Bigg)
+     k(x, y) = \sigma^2 \exp \Bigg(1+ \frac{\sqrt{3}\lvert x-y \rvert}{\ell} \ \Bigg)\exp\Bigg(-\frac{\sqrt{3}\lvert x-y\rvert}{\ell^2} \Bigg)
      $$
      """

@@ -33,7 +33,7 @@ class Matern52(StationaryKernel):
      lengthscale parameter $\ell$ and variance $\sigma^2$.

      $$
-     k(x, y) = \sigma^2 \exp \Bigg(1+ \frac{\sqrt{5}\lvert x-y \rvert}{\ell^2} + \frac{5\lvert x - y \rvert^2}{3\ell^2} \Bigg)\exp\Bigg(-\frac{\sqrt{5}\lvert x-y\rvert}{\ell^2} \Bigg)
+     k(x, y) = \sigma^2 \exp \Bigg(1+ \frac{\sqrt{5}\lvert x-y \rvert}{\ell} + \frac{5\lvert x - y \rvert^2}{3\ell^2} \Bigg)\exp\Bigg(-\frac{\sqrt{5}\lvert x-y\rvert}{\ell^2} \Bigg)
      $$
      """

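The docstring updates above drop the squared lengthscale from the first factor. For reference (standard textbook forms, not quoted from the package), with $r = \lvert x - y \rvert$:

$$
k_{1/2}(x, y) = \sigma^2 \exp\!\left(-\frac{r}{\ell}\right), \qquad
k_{3/2}(x, y) = \sigma^2\left(1 + \frac{\sqrt{3}r}{\ell}\right)\exp\!\left(-\frac{\sqrt{3}r}{\ell}\right), \qquad
k_{5/2}(x, y) = \sigma^2\left(1 + \frac{\sqrt{5}r}{\ell} + \frac{5r^2}{3\ell^2}\right)\exp\!\left(-\frac{\sqrt{5}r}{\ell}\right).
$$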
@@ -23,7 +23,6 @@ from gpjax.kernels.computations import (
      DenseKernelComputation,
  )
  from gpjax.kernels.stationary.base import StationaryKernel
- from gpjax.parameters import PositiveReal
  from gpjax.typing import (
      Array,
      ScalarArray,
@@ -72,10 +71,7 @@ class Periodic(StationaryKernel):
              covariance matrix.
          """

-         if isinstance(period, nnx.Variable):
-             self.period = period
-         else:
-             self.period = PositiveReal(period)
+         self.period = period

          super().__init__(active_dims, lengthscale, variance, n_dims, compute_engine)

@@ -84,8 +80,9 @@
      ) -> Float[Array, ""]:
          x = self.slice_input(x)
          y = self.slice_input(y)
+         period_val = self.period.value if hasattr(self.period, "value") else self.period
          sine_squared = (
-             jnp.sin(jnp.pi * (x - y) / self.period.value) / self.lengthscale.value
+             jnp.sin(jnp.pi * (x - y) / period_val) / self.lengthscale.value
          ) ** 2
          K = self.variance.value * jnp.exp(-0.5 * jnp.sum(sine_squared, axis=0))
          return K.squeeze()
@@ -24,7 +24,6 @@ from gpjax.kernels.computations import (
  )
  from gpjax.kernels.stationary.base import StationaryKernel
  from gpjax.kernels.stationary.utils import euclidean_distance
- from gpjax.parameters import SigmoidBounded
  from gpjax.typing import (
      Array,
      ScalarArray,
@@ -76,10 +75,7 @@ class PoweredExponential(StationaryKernel):
          compute_engine: the computation engine that the kernel uses to compute the
              covariance matrix.
          """
-         if isinstance(power, nnx.Variable):
-             self.power = power
-         else:
-             self.power = SigmoidBounded(power)
+         self.power = power

          super().__init__(active_dims, lengthscale, variance, n_dims, compute_engine)

@@ -88,7 +84,6 @@
      ) -> Float[Array, ""]:
          x = self.slice_input(x) / self.lengthscale.value
          y = self.slice_input(y) / self.lengthscale.value
-         K = self.variance.value * jnp.exp(
-             -(euclidean_distance(x, y) ** self.power.value)
-         )
+         power_val = self.power.value if hasattr(self.power, "value") else self.power
+         K = self.variance.value * jnp.exp(-(euclidean_distance(x, y) ** power_val))
          return K.squeeze()