gpjax 0.10.1__py3-none-any.whl → 0.10.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
gpjax/__init__.py CHANGED
@@ -39,7 +39,7 @@ __license__ = "MIT"
39
39
  __description__ = "Didactic Gaussian processes in JAX"
40
40
  __url__ = "https://github.com/JaxGaussianProcesses/GPJax"
41
41
  __contributors__ = "https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors"
42
- __version__ = "0.10.1"
42
+ __version__ = "0.10.2"
43
43
 
44
44
  __all__ = [
45
45
  "base",
@@ -46,7 +46,7 @@ class Polynomial(AbstractKernel):
46
46
  self,
47
47
  active_dims: tp.Union[list[int], slice, None] = None,
48
48
  degree: int = 2,
49
- shift: tp.Union[ScalarFloat, nnx.Variable[ScalarArray]] = 0.0,
49
+ shift: tp.Union[ScalarFloat, nnx.Variable[ScalarArray]] = 1.0,
50
50
  variance: tp.Union[ScalarFloat, nnx.Variable[ScalarArray]] = 1.0,
51
51
  n_dims: tp.Union[int, None] = None,
52
52
  compute_engine: AbstractKernelComputation = DenseKernelComputation(),
gpjax/parameters.py CHANGED
@@ -1,6 +1,7 @@
1
1
  import typing as tp
2
2
 
3
3
  from flax import nnx
4
+ from jax.experimental import checkify
4
5
  import jax.numpy as jnp
5
6
  import jax.tree_util as jtu
6
7
  from jax.typing import ArrayLike
@@ -84,8 +85,7 @@ class PositiveReal(Parameter[T]):
84
85
 
85
86
  def __init__(self, value: T, tag: ParameterTag = "positive", **kwargs):
86
87
  super().__init__(value=value, tag=tag, **kwargs)
87
-
88
- _check_is_positive(self.value)
88
+ _safe_assert(_check_is_positive, self.value)
89
89
 
90
90
 
91
91
  class Real(Parameter[T]):
@@ -101,7 +101,17 @@ class SigmoidBounded(Parameter[T]):
101
101
  def __init__(self, value: T, tag: ParameterTag = "sigmoid", **kwargs):
102
102
  super().__init__(value=value, tag=tag, **kwargs)
103
103
 
104
- _check_in_bounds(self.value, 0.0, 1.0)
104
+ # Only perform validation in non-JIT contexts
105
+ if (
106
+ not isinstance(value, jnp.ndarray)
107
+ or not getattr(value, "aval", None) is None
108
+ ):
109
+ _safe_assert(
110
+ _check_in_bounds,
111
+ self.value,
112
+ low=jnp.array(0.0),
113
+ high=jnp.array(1.0),
114
+ )
105
115
 
106
116
 
107
117
  class Static(nnx.Variable[T]):
@@ -120,8 +130,13 @@ class LowerTriangular(Parameter[T]):
120
130
  def __init__(self, value: T, tag: ParameterTag = "lower_triangular", **kwargs):
121
131
  super().__init__(value=value, tag=tag, **kwargs)
122
132
 
123
- _check_is_square(self.value)
124
- _check_is_lower_triangular(self.value)
133
+ # Only perform validation in non-JIT contexts
134
+ if (
135
+ not isinstance(value, jnp.ndarray)
136
+ or not getattr(value, "aval", None) is None
137
+ ):
138
+ _safe_assert(_check_is_square, self.value)
139
+ _safe_assert(_check_is_lower_triangular, self.value)
125
140
 
126
141
 
127
142
  DEFAULT_BIJECTION = {
@@ -132,36 +147,83 @@ DEFAULT_BIJECTION = {
132
147
  }
133
148
 
134
149
 
135
- def _check_is_arraylike(value: T):
150
+ def _check_is_arraylike(value: T) -> None:
151
+ """Check if a value is array-like.
152
+
153
+ Args:
154
+ value: The value to check.
155
+
156
+ Raises:
157
+ TypeError: If the value is not array-like.
158
+ """
136
159
  if not isinstance(value, (ArrayLike, list)):
137
160
  raise TypeError(
138
161
  f"Expected parameter value to be an array-like type. Got {value}."
139
162
  )
140
163
 
141
164
 
142
- def _check_is_positive(value: T):
143
- if jnp.any(value < 0):
144
- raise ValueError(
145
- f"Expected parameter value to be strictly positive. Got {value}."
146
- )
165
+ @checkify.checkify
166
+ def _check_is_positive(value):
167
+ checkify.check(
168
+ jnp.all(value > 0), "value needs to be positive, got {value}", value=value
169
+ )
147
170
 
148
171
 
149
- def _check_is_square(value: T):
150
- if value.shape[0] != value.shape[1]:
151
- raise ValueError(
152
- f"Expected parameter value to be a square matrix. Got {value}."
153
- )
172
+ @checkify.checkify
173
+ def _check_is_square(value: T) -> None:
174
+ """Check if a value is a square matrix.
154
175
 
176
+ Args:
177
+ value: The value to check.
155
178
 
156
- def _check_is_lower_triangular(value: T):
157
- if not jnp.all(jnp.tril(value) == value):
158
- raise ValueError(
159
- f"Expected parameter value to be a lower triangular matrix. Got {value}."
160
- )
179
+ Raises:
180
+ ValueError: If the value is not a square matrix.
181
+ """
182
+ checkify.check(
183
+ value.shape[0] == value.shape[1],
184
+ "value needs to be a square matrix, got {value}",
185
+ value=value,
186
+ )
161
187
 
162
188
 
163
- def _check_in_bounds(value: T, low: float, high: float):
164
- if jnp.any((value < low) | (value > high)):
165
- raise ValueError(
166
- f"Expected parameter value to be bounded between {low} and {high}. Got {value}."
167
- )
189
+ @checkify.checkify
190
+ def _check_is_lower_triangular(value: T) -> None:
191
+ """Check if a value is a lower triangular matrix.
192
+
193
+ Args:
194
+ value: The value to check.
195
+
196
+ Raises:
197
+ ValueError: If the value is not a lower triangular matrix.
198
+ """
199
+ checkify.check(
200
+ jnp.all(jnp.tril(value) == value),
201
+ "value needs to be a lower triangular matrix, got {value}",
202
+ value=value,
203
+ )
204
+
205
+
206
+ @checkify.checkify
207
+ def _check_in_bounds(value: T, low: T, high: T) -> None:
208
+ """Check if a value is bounded between low and high.
209
+
210
+ Args:
211
+ value: The value to check.
212
+ low: The lower bound.
213
+ high: The upper bound.
214
+
215
+ Raises:
216
+ ValueError: If any element of value is outside the bounds.
217
+ """
218
+ checkify.check(
219
+ jnp.all((value >= low) & (value <= high)),
220
+ "value needs to be bounded between {low} and {high}, got {value}",
221
+ value=value,
222
+ low=low,
223
+ high=high,
224
+ )
225
+
226
+
227
+ def _safe_assert(fn: tp.Callable[[tp.Any], None], value: T, **kwargs) -> None:
228
+ error, _ = fn(value, **kwargs)
229
+ checkify.check_error(error)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: gpjax
3
- Version: 0.10.1
3
+ Version: 0.10.2
4
4
  Summary: Gaussian processes in JAX.
5
5
  Project-URL: Documentation, https://docs.jaxgaussianprocesses.com/
6
6
  Project-URL: Issues, https://github.com/JaxGaussianProcesses/GPJax/issues
@@ -1,4 +1,4 @@
1
- gpjax/__init__.py,sha256=bd56ijur1Pxm_Ww4iSPn-CtiHXAvIR-FFl-d2HvbiTE,1654
1
+ gpjax/__init__.py,sha256=F9GVk18tdmvwiDEHZNo_4Wr0TkmPhWIEwl3KzEWQcaQ,1654
2
2
  gpjax/citation.py,sha256=f2Hzj5MLyCE7l0hHAzsEQoTORZH5hgV_eis4uoBiWvE,3811
3
3
  gpjax/dataset.py,sha256=NsToLKq4lOsHnfLfukrUIRKvhOEuoUk8aHTF0oAqRbU,4079
4
4
  gpjax/distributions.py,sha256=X48FJr3reop9maherdMVt7-XZOm2f26T8AJt_IKM_oE,9339
@@ -9,7 +9,7 @@ gpjax/likelihoods.py,sha256=DOyV1L0ompkpeImMTiOOiWLJfqSqvDX_acOumuFqPEc,9234
9
9
  gpjax/lower_cholesky.py,sha256=3pnHaBrlGckFsrfYJ9Lsbd0pGmO7NIXdyY4aGm48MpY,1952
10
10
  gpjax/mean_functions.py,sha256=BpeFkR3Eqa3O_FGp9BtSu9HKNSYZ8M08VtyfPfWbwRg,6479
11
11
  gpjax/objectives.py,sha256=XwkPyL_iovTNKpKGVNt0Lt2_OMTJitSPhuyCtUrJpbc,15383
12
- gpjax/parameters.py,sha256=Z4Wy3gEzPZG23-dtqC437_ZWnd_sPe9LcLCKn21ZBvA,4886
12
+ gpjax/parameters.py,sha256=6VKq6wBzEUtx-GXniC8fEqjTNrTC1YwIOw66QguW6UM,6457
13
13
  gpjax/scan.py,sha256=jStQvwkE9MGttB89frxam1kaeXdWih7cVxkGywyaeHQ,5365
14
14
  gpjax/typing.py,sha256=M3CvWsYtZ3PFUvBvvbRNjpwerNII0w4yGuP0I-sLeYI,1705
15
15
  gpjax/variational_families.py,sha256=s1rk7PtNTjQPabmVu-jBsuJBoqsxAAXwKFZJOEswkNQ,28161
@@ -30,8 +30,7 @@ gpjax/kernels/non_euclidean/utils.py,sha256=z42aw8ga0zuREzHawemR9okttgrAUPmq-aN5
30
30
  gpjax/kernels/nonstationary/__init__.py,sha256=YpWQfOy_cqOKc5ezn37vqoK3Z6jznYiJz28BD_8F7AY,930
31
31
  gpjax/kernels/nonstationary/arccosine.py,sha256=UCTVJEhTZFQjARGFsYMImLnTDyTyxobIL5f2LiAHkPI,5822
32
32
  gpjax/kernels/nonstationary/linear.py,sha256=UKDHFCQzKWDMYo76qcb5-ujjnP2_iL-1tcN017xjK48,2562
33
- gpjax/kernels/nonstationary/oak.py,sha256=yKweXX6ptx9aQuKyDzXuGAk64hjswRMZcpIPifDGBu0,14460
34
- gpjax/kernels/nonstationary/polynomial.py,sha256=yTGobMPbCnKlj4PiQPSXEkWNrj2sjg_x9zFsnFa_j4E,3257
33
+ gpjax/kernels/nonstationary/polynomial.py,sha256=7SDMfEcBCqnRn9xyj4iGcYLNvYJZiveN3uLZ_h12p10,3257
35
34
  gpjax/kernels/stationary/__init__.py,sha256=j4BMTaQlIx2kNAT1Dkf4iO2rm-f7_oSVWNrk1bN0tqE,1406
36
35
  gpjax/kernels/stationary/base.py,sha256=pQNkMo-E4bIT4tNfb7JvFJZC6fIIXNErsT1iQopFlAA,7063
37
36
  gpjax/kernels/stationary/matern12.py,sha256=b2vQCUqhd9NJK84L2RYjpI597uxy_7xgwsjS35Gc958,1807
@@ -43,7 +42,7 @@ gpjax/kernels/stationary/rational_quadratic.py,sha256=dYONp3i4rnKj3ET8UyxAKXv6UO
43
42
  gpjax/kernels/stationary/rbf.py,sha256=G13gg5phO7ite7D9QgoCy7gB2_y0FM6GZhgFW4RL6Xw,1734
44
43
  gpjax/kernels/stationary/utils.py,sha256=Xa9EEnxgFqEi08ZSFAZYYHhJ85_3Ac-ZUyUk18B63M4,2225
45
44
  gpjax/kernels/stationary/white.py,sha256=TkdXXZCCjDs7JwR_gj5uvn2s1wyfRbe1vyHhUMJ8jjI,2212
46
- gpjax-0.10.1.dist-info/METADATA,sha256=UfBcK541_MbD5Rbl9ocozq8DTOE_O3l8eBNvQZ7LQTs,9970
47
- gpjax-0.10.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
48
- gpjax-0.10.1.dist-info/licenses/LICENSE.txt,sha256=3umwi0h8wmKXOZO8XwRBwSl3vJt2hpWKEqSrSXLR7-I,1084
49
- gpjax-0.10.1.dist-info/RECORD,,
45
+ gpjax-0.10.2.dist-info/METADATA,sha256=mqIBMOMKKiI9qkM_uFHSuPEXY17Jd6bOL5EM2hPiaok,9970
46
+ gpjax-0.10.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
47
+ gpjax-0.10.2.dist-info/licenses/LICENSE.txt,sha256=3umwi0h8wmKXOZO8XwRBwSl3vJt2hpWKEqSrSXLR7-I,1084
48
+ gpjax-0.10.2.dist-info/RECORD,,
@@ -1,406 +0,0 @@
1
- # Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- from typing import Optional, Callable, Union, Tuple
17
-
18
- import beartype.typing as tp
19
- from flax import nnx
20
- import jax
21
- import jax.numpy as jnp
22
- from jaxtyping import Float, Array, Scalar
23
- from gpjax.typing import ScalarFloat
24
-
25
- from gpjax.kernels.base import AbstractKernel
26
- from gpjax.kernels.computations import AbstractKernelComputation, DenseKernelComputation
27
- from gpjax.parameters import PositiveReal, Parameter
28
- from jax.typing import ArrayLike
29
-
30
-
31
- def legendre_polynomial(n: int) -> Callable:
32
- """Compute the Legendre polynomial of degree n.
33
-
34
- Args:
35
- n: Degree of the Legendre polynomial.
36
-
37
- Returns:
38
- Function that evaluates the nth Legendre polynomial at given points.
39
- """
40
- if n == 0:
41
- return lambda x: jnp.ones_like(x)
42
- elif n == 1:
43
- return lambda x: x
44
- else:
45
- p_n_minus_2 = legendre_polynomial(n - 2)
46
- p_n_minus_1 = legendre_polynomial(n - 1)
47
- return (
48
- lambda x: ((2 * n - 1) * x * p_n_minus_1(x) - (n - 1) * p_n_minus_2(x)) / n
49
- )
50
-
51
-
52
- def legendre_polynomial_derivative(n: int) -> Callable:
53
- """Compute the derivative of the Legendre polynomial of degree n.
54
-
55
- Args:
56
- n: Degree of the Legendre polynomial.
57
-
58
- Returns:
59
- Function that evaluates the derivative of the nth Legendre polynomial.
60
- """
61
- if n == 0:
62
- return lambda x: jnp.zeros_like(x)
63
- else:
64
- p_n_minus_1 = legendre_polynomial(n - 1)
65
- return (
66
- lambda x: n
67
- * (p_n_minus_1(x) - x * legendre_polynomial_derivative(n - 1)(x))
68
- / (1 - x**2 + 1e-10)
69
- )
70
-
71
-
72
- def gauss_legendre_quadrature(
73
- deg: int, a: float = -1.0, b: float = 1.0
74
- ) -> Tuple[jnp.ndarray, jnp.ndarray]:
75
- """Generate Gauss-Legendre quadrature points and weights.
76
-
77
- Args:
78
- deg: Number of quadrature points.
79
- a: Lower limit of integration (default: -1.0).
80
- b: Upper limit of integration (default: 1.0).
81
-
82
- Returns:
83
- Tuple of (points, weights) for Gauss-Legendre quadrature.
84
- """
85
- # For computational efficiency, use a simpler approach for small degrees
86
- if deg <= 4:
87
- # Hardcoded points and weights for small degrees
88
- if deg == 1:
89
- x = jnp.array([0.0])
90
- w = jnp.array([2.0])
91
- elif deg == 2:
92
- x = jnp.array([-0.5773502691896257, 0.5773502691896257])
93
- w = jnp.array([1.0, 1.0])
94
- elif deg == 3:
95
- x = jnp.array([-0.7745966692414834, 0.0, 0.7745966692414834])
96
- w = jnp.array([0.5555555555555556, 0.8888888888888888, 0.5555555555555556])
97
- elif deg == 4:
98
- x = jnp.array(
99
- [
100
- -0.8611363115940526,
101
- -0.3399810435848563,
102
- 0.3399810435848563,
103
- 0.8611363115940526,
104
- ]
105
- )
106
- w = jnp.array(
107
- [
108
- 0.3478548451374538,
109
- 0.6521451548625461,
110
- 0.6521451548625461,
111
- 0.3478548451374538,
112
- ]
113
- )
114
- else:
115
- # Initial guess for roots (Chebyshev nodes)
116
- k = jnp.arange(1, deg + 1)
117
- x0 = jnp.cos(jnp.pi * (k - 0.25) / (deg + 0.5))
118
-
119
- # Newton iteration to find roots more accurately
120
- # In practice, we would use more iterations, but for simplicity we use a fixed number
121
- P_n = legendre_polynomial(deg)
122
- dP_n = legendre_polynomial_derivative(deg)
123
-
124
- # Single Newton step for demonstration (would use a loop in practice)
125
- x = x0 - P_n(x0) / (dP_n(x0) + 1e-10)
126
-
127
- # Compute weights
128
- w = 2.0 / ((1.0 - x**2) * dP_n(x) ** 2 + 1e-10)
129
-
130
- # Scale from [-1,1] to [a,b]
131
- x = 0.5 * (b - a) * x + 0.5 * (b + a)
132
- w = 0.5 * (b - a) * w
133
-
134
- return x, w
135
-
136
-
137
- class OrthogonalAdditiveKernel(AbstractKernel):
138
- """Orthogonal Additive Kernels (OAKs) generalize additive kernels by orthogonalizing
139
- the feature space to create uncorrelated kernel components.
140
-
141
- This implementation uses a Gauss-Legendre quadrature approximation for the required
142
- one-dimensional integrals involving the base kernels, allowing for arbitrary base kernels.
143
-
144
- References:
145
- - X. Lu, A. Boukouvalas, and J. Hensman. Additive Gaussian processes revisited.
146
- Proceedings of the 39th International Conference on Machine Learning. Jul 2022.
147
- """
148
-
149
- base_kernel: AbstractKernel
150
- quad_deg: int
151
- dim: int
152
- offset: nnx.Variable
153
- coeffs_1: nnx.Variable
154
- coeffs_2: Optional[nnx.Variable]
155
- z: jnp.ndarray
156
- w: jnp.ndarray
157
- name: str = "OrthogonalAdditiveKernel"
158
-
159
- def __init__(
160
- self,
161
- base_kernel: AbstractKernel,
162
- dim: int,
163
- quad_deg: int = 32,
164
- second_order: bool = False,
165
- active_dims: tp.Union[list[int], slice, None] = None,
166
- n_dims: tp.Union[int, None] = None,
167
- offset: tp.Union[float, Parameter[ScalarFloat]] = 1.0,
168
- coeffs_1: tp.Union[ArrayLike, Parameter[ArrayLike]] = None,
169
- coeffs_2: tp.Union[ArrayLike, Parameter[ArrayLike]] = None,
170
- compute_engine: AbstractKernelComputation = DenseKernelComputation(),
171
- ):
172
- """Initialise the OrthogonalAdditiveKernel.
173
-
174
- Args:
175
- base_kernel: The kernel which to orthogonalize and evaluate.
176
- dim: Input dimensionality of the kernel.
177
- quad_deg: Number of integration nodes for orthogonalization.
178
- second_order: Toggles second order interactions. If true, both the time and
179
- space complexity of evaluating the kernel are quadratic in `dim`.
180
- active_dims: The indices of the input dimensions that the kernel operates on.
181
- n_dims: The number of input dimensions. If not provided, it will be inferred.
182
- offset: The zeroth-order coefficient.
183
- coeffs_1: The first-order coefficients. Should be a 1D array of length dim.
184
- coeffs_2: The second-order coefficients. Should be a 2D array of shape (dim, dim).
185
- compute_engine: The computation engine to use for kernel evaluations.
186
- """
187
- super().__init__(
188
- active_dims=active_dims, n_dims=n_dims, compute_engine=compute_engine
189
- )
190
-
191
- self.base_kernel = base_kernel
192
- self.quad_deg = quad_deg
193
- self.dim = dim
194
-
195
- # Integration nodes and weights for [0, 1]
196
- self.z, self.w = gauss_legendre_quadrature(quad_deg, a=0.0, b=1.0)
197
-
198
- # Create expandable axes
199
- z_expanded = jnp.expand_dims(self.z, axis=-1)
200
- self.z = jnp.broadcast_to(z_expanded, (quad_deg, dim))
201
- self.w = jnp.expand_dims(self.w, axis=-1)
202
-
203
- # Default coefficients if not provided
204
- if isinstance(offset, Parameter):
205
- self.offset = offset
206
- else:
207
- self.offset = PositiveReal(jnp.array(offset))
208
-
209
- if coeffs_1 is None:
210
- log_d = jnp.log(dim)
211
- default_coeffs_1 = jnp.exp(-log_d) * jnp.ones(dim)
212
- self.coeffs_1 = PositiveReal(default_coeffs_1)
213
- elif isinstance(coeffs_1, Parameter):
214
- self.coeffs_1 = coeffs_1
215
- else:
216
- self.coeffs_1 = PositiveReal(jnp.array(coeffs_1))
217
-
218
- self.second_order = second_order
219
- if second_order:
220
- if coeffs_2 is None:
221
- log_d = jnp.log(dim)
222
- # Initialize with zeros for upper triangular part (excluding diagonal)
223
- n_entries = dim * (dim - 1) // 2
224
- default_coeffs_2 = jnp.exp(-2 * log_d) * jnp.ones(n_entries)
225
- self.coeffs_2_raw = PositiveReal(default_coeffs_2)
226
- elif isinstance(coeffs_2, Parameter):
227
- self.coeffs_2_raw = coeffs_2
228
- else:
229
- self.coeffs_2_raw = PositiveReal(jnp.array(coeffs_2))
230
-
231
- # Pre-compute indices for efficient triu operations
232
- self.triu_indices = jnp.triu_indices(dim, k=1)
233
- else:
234
- self.coeffs_2_raw = None
235
-
236
- # Compute normalizer (in __call__)
237
- self._normalizer = None
238
-
239
- @property
240
- def coeffs_2(self) -> Optional[jnp.ndarray]:
241
- """Returns a full matrix of second-order coefficients.
242
-
243
- Returns:
244
- A dim x dim array of second-order coefficients or None if second_order is False.
245
- """
246
- if not self.second_order or self.coeffs_2_raw is None:
247
- return None
248
-
249
- # Create a full matrix from the raw coefficients
250
- coeffs_2_flat = self.coeffs_2_raw.value
251
- coeffs_2_full = jnp.zeros((self.dim, self.dim))
252
-
253
- # Fill the upper triangular part
254
- i, j = self.triu_indices
255
- coeffs_2_full = coeffs_2_full.at[i, j].set(coeffs_2_flat)
256
-
257
- # Make it symmetric
258
- coeffs_2_full = coeffs_2_full + jnp.transpose(coeffs_2_full)
259
-
260
- return coeffs_2_full
261
-
262
- def normalizer(self, eps: float = 1e-6) -> jnp.ndarray:
263
- """Integrates the orthogonalized base kernels over [0, 1] x [0, 1].
264
-
265
- Args:
266
- eps: Minimum value constraint on the normalizers to avoid division by zero.
267
-
268
- Returns:
269
- A d-dim tensor of normalization constants.
270
- """
271
- if self._normalizer is None or self.training:
272
- # Compute K(z, z) - base kernel gram matrix on integration points
273
- K_zz = self.base_kernel.cross_covariance(self.z, self.z)
274
-
275
- # Integrate: w^T * K * w
276
- self._normalizer = jnp.matmul(
277
- jnp.matmul(jnp.transpose(self.w), K_zz), self.w
278
- )
279
-
280
- # Ensure positive values
281
- self._normalizer = jnp.maximum(self._normalizer, eps)
282
-
283
- return self._normalizer
284
-
285
- def _orthogonal_base_kernels(
286
- self, x1: Float[Array, "N D"], x2: Float[Array, "M D"]
287
- ) -> Float[Array, "D N M"]:
288
- """Evaluates the set of d orthogonalized base kernels.
289
-
290
- Args:
291
- x1: Input array of shape [N, D]
292
- x2: Input array of shape [M, D]
293
-
294
- Returns:
295
- Array of shape [D, N, M] with orthogonalized kernel evaluations
296
- """
297
- # Compute base kernel between inputs
298
- K_x1x2 = self.base_kernel.cross_covariance(x1, x2) # [N, M]
299
-
300
- # Compute normalizer
301
- norm = jnp.sqrt(self.normalizer())
302
- w_normalized = self.w / norm
303
-
304
- # Compute base kernel between x1 and integration points z
305
- K_x1z = self.base_kernel.cross_covariance(x1, self.z) # [N, quad_deg]
306
- S_x1 = jnp.matmul(K_x1z, w_normalized) # [N, 1]
307
-
308
- # Compute base kernel between x2 and integration points z
309
- if x1 is x2:
310
- S_x2 = S_x1
311
- else:
312
- K_x2z = self.base_kernel.cross_covariance(x2, self.z) # [M, quad_deg]
313
- S_x2 = jnp.matmul(K_x2z, w_normalized) # [M, 1]
314
-
315
- # Compute orthogonal kernel: K_x1x2 - S_x1 * S_x2^T
316
- K_ortho = K_x1x2 - jnp.outer(S_x1, S_x2)
317
-
318
- return K_ortho
319
-
320
- def __call__(
321
- self,
322
- x: Float[Array, "N D"],
323
- y: Float[Array, "M D"],
324
- ) -> ScalarFloat:
325
- """Evaluate the kernel at a single pair of inputs.
326
-
327
- Args:
328
- x: First input.
329
- y: Second input.
330
-
331
- Returns:
332
- The kernel value at (x, y).
333
- """
334
- # Slice inputs to relevant dimensions
335
- x_sliced = self.slice_input(x)
336
- y_sliced = self.slice_input(y)
337
-
338
- # Get orthogonalized kernels
339
- K_ortho = self._orthogonal_base_kernels(
340
- jnp.expand_dims(x_sliced, 0), jnp.expand_dims(y_sliced, 0)
341
- ) # [1, 1]
342
-
343
- # Apply first-order effects
344
- first_order = jnp.sum(self.coeffs_1.value * K_ortho)
345
-
346
- # Add offset
347
- result = self.offset.value + first_order
348
-
349
- # Add second-order effects if enabled
350
- if self.second_order and self.coeffs_2 is not None:
351
- # For a single point evaluation, we use a simpler approach
352
- # Computing the tensor of second order interactions
353
- second_order = 0.0
354
- for i in range(self.dim):
355
- for j in range(i + 1, self.dim):
356
- coef = self.coeffs_2[i, j]
357
- if coef > 0:
358
- second_order += coef * K_ortho[i] * K_ortho[j]
359
-
360
- result = result + second_order
361
-
362
- return result
363
-
364
- def cross_covariance(
365
- self, x1: Float[Array, "N D"], x2: Float[Array, "M D"]
366
- ) -> Float[Array, "N M"]:
367
- """Compute the cross-covariance matrix of the kernel.
368
-
369
- Args:
370
- x1: First input matrix of shape [N, D].
371
- x2: Second input matrix of shape [M, D].
372
-
373
- Returns:
374
- Cross-covariance matrix of shape [N, M].
375
- """
376
- # Slice inputs to relevant dimensions
377
- x1 = self.slice_input(x1)
378
- x2 = self.slice_input(x2)
379
-
380
- # Get orthogonalized kernels for all dimensions
381
- K_ortho = self._orthogonal_base_kernels(x1, x2) # [D, N, M]
382
-
383
- # Apply first-order effects (sum over dimensions)
384
- coeffs_1 = self.coeffs_1.value
385
- first_order = jnp.tensordot(coeffs_1, K_ortho, axes=([0], [0])) # [N, M]
386
-
387
- # Add offset (broadcast to match output shape)
388
- result = jnp.broadcast_to(self.offset.value, first_order.shape) + first_order
389
-
390
- # Add second-order effects if enabled
391
- if self.second_order and self.coeffs_2 is not None:
392
- # Compute second-order interactions using einsum
393
- coeffs_2_full = self.coeffs_2
394
- second_order = jnp.einsum(
395
- "ij,ink,jml->nml",
396
- coeffs_2_full,
397
- jnp.expand_dims(K_ortho, 1),
398
- jnp.expand_dims(K_ortho, 0),
399
- )
400
-
401
- # Sum over dimensions i, j
402
- second_order = jnp.sum(second_order, axis=(0, 1))
403
-
404
- result = result + second_order
405
-
406
- return result
File without changes