gpjax 0.10.1__py3-none-any.whl → 0.10.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gpjax/__init__.py +1 -1
- gpjax/kernels/nonstationary/polynomial.py +1 -1
- gpjax/parameters.py +88 -26
- {gpjax-0.10.1.dist-info → gpjax-0.10.2.dist-info}/METADATA +1 -1
- {gpjax-0.10.1.dist-info → gpjax-0.10.2.dist-info}/RECORD +7 -8
- gpjax/kernels/nonstationary/oak.py +0 -406
- {gpjax-0.10.1.dist-info → gpjax-0.10.2.dist-info}/WHEEL +0 -0
- {gpjax-0.10.1.dist-info → gpjax-0.10.2.dist-info}/licenses/LICENSE.txt +0 -0
gpjax/__init__.py
CHANGED
|
@@ -39,7 +39,7 @@ __license__ = "MIT"
|
|
|
39
39
|
__description__ = "Didactic Gaussian processes in JAX"
|
|
40
40
|
__url__ = "https://github.com/JaxGaussianProcesses/GPJax"
|
|
41
41
|
__contributors__ = "https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors"
|
|
42
|
-
__version__ = "0.10.1"
|
|
42
|
+
__version__ = "0.10.2"
|
|
43
43
|
|
|
44
44
|
__all__ = [
|
|
45
45
|
"base",
|
|
gpjax/kernels/nonstationary/polynomial.py
CHANGED
|
@@ -46,7 +46,7 @@ class Polynomial(AbstractKernel):
|
|
|
46
46
|
self,
|
|
47
47
|
active_dims: tp.Union[list[int], slice, None] = None,
|
|
48
48
|
degree: int = 2,
|
|
49
|
-
shift: tp.Union[ScalarFloat, nnx.Variable[ScalarArray]] = 0.0,
|
|
49
|
+
shift: tp.Union[ScalarFloat, nnx.Variable[ScalarArray]] = 1.0,
|
|
50
50
|
variance: tp.Union[ScalarFloat, nnx.Variable[ScalarArray]] = 1.0,
|
|
51
51
|
n_dims: tp.Union[int, None] = None,
|
|
52
52
|
compute_engine: AbstractKernelComputation = DenseKernelComputation(),
|
gpjax/parameters.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
import typing as tp
|
|
2
2
|
|
|
3
3
|
from flax import nnx
|
|
4
|
+
from jax.experimental import checkify
|
|
4
5
|
import jax.numpy as jnp
|
|
5
6
|
import jax.tree_util as jtu
|
|
6
7
|
from jax.typing import ArrayLike
|
|
@@ -84,8 +85,7 @@ class PositiveReal(Parameter[T]):
|
|
|
84
85
|
|
|
85
86
|
def __init__(self, value: T, tag: ParameterTag = "positive", **kwargs):
|
|
86
87
|
super().__init__(value=value, tag=tag, **kwargs)
|
|
87
|
-
|
|
88
|
-
_check_is_positive(self.value)
|
|
88
|
+
_safe_assert(_check_is_positive, self.value)
|
|
89
89
|
|
|
90
90
|
|
|
91
91
|
class Real(Parameter[T]):
|
|
@@ -101,7 +101,17 @@ class SigmoidBounded(Parameter[T]):
|
|
|
101
101
|
def __init__(self, value: T, tag: ParameterTag = "sigmoid", **kwargs):
|
|
102
102
|
super().__init__(value=value, tag=tag, **kwargs)
|
|
103
103
|
|
|
104
|
-
|
|
104
|
+
# Only perform validation in non-JIT contexts
|
|
105
|
+
if (
|
|
106
|
+
not isinstance(value, jnp.ndarray)
|
|
107
|
+
or not getattr(value, "aval", None) is None
|
|
108
|
+
):
|
|
109
|
+
_safe_assert(
|
|
110
|
+
_check_in_bounds,
|
|
111
|
+
self.value,
|
|
112
|
+
low=jnp.array(0.0),
|
|
113
|
+
high=jnp.array(1.0),
|
|
114
|
+
)
|
|
105
115
|
|
|
106
116
|
|
|
107
117
|
class Static(nnx.Variable[T]):
|
|
@@ -120,8 +130,13 @@ class LowerTriangular(Parameter[T]):
|
|
|
120
130
|
def __init__(self, value: T, tag: ParameterTag = "lower_triangular", **kwargs):
|
|
121
131
|
super().__init__(value=value, tag=tag, **kwargs)
|
|
122
132
|
|
|
123
|
-
|
|
124
|
-
|
|
133
|
+
# Only perform validation in non-JIT contexts
|
|
134
|
+
if (
|
|
135
|
+
not isinstance(value, jnp.ndarray)
|
|
136
|
+
or not getattr(value, "aval", None) is None
|
|
137
|
+
):
|
|
138
|
+
_safe_assert(_check_is_square, self.value)
|
|
139
|
+
_safe_assert(_check_is_lower_triangular, self.value)
|
|
125
140
|
|
|
126
141
|
|
|
127
142
|
DEFAULT_BIJECTION = {
|
|
@@ -132,36 +147,83 @@ DEFAULT_BIJECTION = {
|
|
|
132
147
|
}
|
|
133
148
|
|
|
134
149
|
|
|
135
|
-
def _check_is_arraylike(value: T):
|
|
150
|
+
def _check_is_arraylike(value: T) -> None:
|
|
151
|
+
"""Check if a value is array-like.
|
|
152
|
+
|
|
153
|
+
Args:
|
|
154
|
+
value: The value to check.
|
|
155
|
+
|
|
156
|
+
Raises:
|
|
157
|
+
TypeError: If the value is not array-like.
|
|
158
|
+
"""
|
|
136
159
|
if not isinstance(value, (ArrayLike, list)):
|
|
137
160
|
raise TypeError(
|
|
138
161
|
f"Expected parameter value to be an array-like type. Got {value}."
|
|
139
162
|
)
|
|
140
163
|
|
|
141
164
|
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
165
|
+
@checkify.checkify
|
|
166
|
+
def _check_is_positive(value):
|
|
167
|
+
checkify.check(
|
|
168
|
+
jnp.all(value > 0), "value needs to be positive, got {value}", value=value
|
|
169
|
+
)
|
|
147
170
|
|
|
148
171
|
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
f"Expected parameter value to be a square matrix. Got {value}."
|
|
153
|
-
)
|
|
172
|
+
@checkify.checkify
|
|
173
|
+
def _check_is_square(value: T) -> None:
|
|
174
|
+
"""Check if a value is a square matrix.
|
|
154
175
|
|
|
176
|
+
Args:
|
|
177
|
+
value: The value to check.
|
|
155
178
|
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
179
|
+
Raises:
|
|
180
|
+
ValueError: If the value is not a square matrix.
|
|
181
|
+
"""
|
|
182
|
+
checkify.check(
|
|
183
|
+
value.shape[0] == value.shape[1],
|
|
184
|
+
"value needs to be a square matrix, got {value}",
|
|
185
|
+
value=value,
|
|
186
|
+
)
|
|
161
187
|
|
|
162
188
|
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
189
|
+
@checkify.checkify
|
|
190
|
+
def _check_is_lower_triangular(value: T) -> None:
|
|
191
|
+
"""Check if a value is a lower triangular matrix.
|
|
192
|
+
|
|
193
|
+
Args:
|
|
194
|
+
value: The value to check.
|
|
195
|
+
|
|
196
|
+
Raises:
|
|
197
|
+
ValueError: If the value is not a lower triangular matrix.
|
|
198
|
+
"""
|
|
199
|
+
checkify.check(
|
|
200
|
+
jnp.all(jnp.tril(value) == value),
|
|
201
|
+
"value needs to be a lower triangular matrix, got {value}",
|
|
202
|
+
value=value,
|
|
203
|
+
)
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
@checkify.checkify
|
|
207
|
+
def _check_in_bounds(value: T, low: T, high: T) -> None:
|
|
208
|
+
"""Check if a value is bounded between low and high.
|
|
209
|
+
|
|
210
|
+
Args:
|
|
211
|
+
value: The value to check.
|
|
212
|
+
low: The lower bound.
|
|
213
|
+
high: The upper bound.
|
|
214
|
+
|
|
215
|
+
Raises:
|
|
216
|
+
ValueError: If any element of value is outside the bounds.
|
|
217
|
+
"""
|
|
218
|
+
checkify.check(
|
|
219
|
+
jnp.all((value >= low) & (value <= high)),
|
|
220
|
+
"value needs to be bounded between {low} and {high}, got {value}",
|
|
221
|
+
value=value,
|
|
222
|
+
low=low,
|
|
223
|
+
high=high,
|
|
224
|
+
)
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
def _safe_assert(fn: tp.Callable[[tp.Any], None], value: T, **kwargs) -> None:
|
|
228
|
+
error, _ = fn(value, **kwargs)
|
|
229
|
+
checkify.check_error(error)
|
|
{gpjax-0.10.1.dist-info → gpjax-0.10.2.dist-info}/RECORD
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
gpjax/__init__.py,sha256=
|
|
1
|
+
gpjax/__init__.py,sha256=F9GVk18tdmvwiDEHZNo_4Wr0TkmPhWIEwl3KzEWQcaQ,1654
|
|
2
2
|
gpjax/citation.py,sha256=f2Hzj5MLyCE7l0hHAzsEQoTORZH5hgV_eis4uoBiWvE,3811
|
|
3
3
|
gpjax/dataset.py,sha256=NsToLKq4lOsHnfLfukrUIRKvhOEuoUk8aHTF0oAqRbU,4079
|
|
4
4
|
gpjax/distributions.py,sha256=X48FJr3reop9maherdMVt7-XZOm2f26T8AJt_IKM_oE,9339
|
|
@@ -9,7 +9,7 @@ gpjax/likelihoods.py,sha256=DOyV1L0ompkpeImMTiOOiWLJfqSqvDX_acOumuFqPEc,9234
|
|
|
9
9
|
gpjax/lower_cholesky.py,sha256=3pnHaBrlGckFsrfYJ9Lsbd0pGmO7NIXdyY4aGm48MpY,1952
|
|
10
10
|
gpjax/mean_functions.py,sha256=BpeFkR3Eqa3O_FGp9BtSu9HKNSYZ8M08VtyfPfWbwRg,6479
|
|
11
11
|
gpjax/objectives.py,sha256=XwkPyL_iovTNKpKGVNt0Lt2_OMTJitSPhuyCtUrJpbc,15383
|
|
12
|
-
gpjax/parameters.py,sha256=
|
|
12
|
+
gpjax/parameters.py,sha256=6VKq6wBzEUtx-GXniC8fEqjTNrTC1YwIOw66QguW6UM,6457
|
|
13
13
|
gpjax/scan.py,sha256=jStQvwkE9MGttB89frxam1kaeXdWih7cVxkGywyaeHQ,5365
|
|
14
14
|
gpjax/typing.py,sha256=M3CvWsYtZ3PFUvBvvbRNjpwerNII0w4yGuP0I-sLeYI,1705
|
|
15
15
|
gpjax/variational_families.py,sha256=s1rk7PtNTjQPabmVu-jBsuJBoqsxAAXwKFZJOEswkNQ,28161
|
|
@@ -30,8 +30,7 @@ gpjax/kernels/non_euclidean/utils.py,sha256=z42aw8ga0zuREzHawemR9okttgrAUPmq-aN5
|
|
|
30
30
|
gpjax/kernels/nonstationary/__init__.py,sha256=YpWQfOy_cqOKc5ezn37vqoK3Z6jznYiJz28BD_8F7AY,930
|
|
31
31
|
gpjax/kernels/nonstationary/arccosine.py,sha256=UCTVJEhTZFQjARGFsYMImLnTDyTyxobIL5f2LiAHkPI,5822
|
|
32
32
|
gpjax/kernels/nonstationary/linear.py,sha256=UKDHFCQzKWDMYo76qcb5-ujjnP2_iL-1tcN017xjK48,2562
|
|
33
|
-
gpjax/kernels/nonstationary/
|
|
34
|
-
gpjax/kernels/nonstationary/polynomial.py,sha256=yTGobMPbCnKlj4PiQPSXEkWNrj2sjg_x9zFsnFa_j4E,3257
|
|
33
|
+
gpjax/kernels/nonstationary/polynomial.py,sha256=7SDMfEcBCqnRn9xyj4iGcYLNvYJZiveN3uLZ_h12p10,3257
|
|
35
34
|
gpjax/kernels/stationary/__init__.py,sha256=j4BMTaQlIx2kNAT1Dkf4iO2rm-f7_oSVWNrk1bN0tqE,1406
|
|
36
35
|
gpjax/kernels/stationary/base.py,sha256=pQNkMo-E4bIT4tNfb7JvFJZC6fIIXNErsT1iQopFlAA,7063
|
|
37
36
|
gpjax/kernels/stationary/matern12.py,sha256=b2vQCUqhd9NJK84L2RYjpI597uxy_7xgwsjS35Gc958,1807
|
|
@@ -43,7 +42,7 @@ gpjax/kernels/stationary/rational_quadratic.py,sha256=dYONp3i4rnKj3ET8UyxAKXv6UO
|
|
|
43
42
|
gpjax/kernels/stationary/rbf.py,sha256=G13gg5phO7ite7D9QgoCy7gB2_y0FM6GZhgFW4RL6Xw,1734
|
|
44
43
|
gpjax/kernels/stationary/utils.py,sha256=Xa9EEnxgFqEi08ZSFAZYYHhJ85_3Ac-ZUyUk18B63M4,2225
|
|
45
44
|
gpjax/kernels/stationary/white.py,sha256=TkdXXZCCjDs7JwR_gj5uvn2s1wyfRbe1vyHhUMJ8jjI,2212
|
|
46
|
-
gpjax-0.10.
|
|
47
|
-
gpjax-0.10.
|
|
48
|
-
gpjax-0.10.
|
|
49
|
-
gpjax-0.10.
|
|
45
|
+
gpjax-0.10.2.dist-info/METADATA,sha256=mqIBMOMKKiI9qkM_uFHSuPEXY17Jd6bOL5EM2hPiaok,9970
|
|
46
|
+
gpjax-0.10.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
47
|
+
gpjax-0.10.2.dist-info/licenses/LICENSE.txt,sha256=3umwi0h8wmKXOZO8XwRBwSl3vJt2hpWKEqSrSXLR7-I,1084
|
|
48
|
+
gpjax-0.10.2.dist-info/RECORD,,
|
|
gpjax/kernels/nonstationary/oak.py
DELETED
|
@@ -1,406 +0,0 @@
|
|
|
1
|
-
# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ==============================================================================
|
|
15
|
-
|
|
16
|
-
from typing import Optional, Callable, Union, Tuple
|
|
17
|
-
|
|
18
|
-
import beartype.typing as tp
|
|
19
|
-
from flax import nnx
|
|
20
|
-
import jax
|
|
21
|
-
import jax.numpy as jnp
|
|
22
|
-
from jaxtyping import Float, Array, Scalar
|
|
23
|
-
from gpjax.typing import ScalarFloat
|
|
24
|
-
|
|
25
|
-
from gpjax.kernels.base import AbstractKernel
|
|
26
|
-
from gpjax.kernels.computations import AbstractKernelComputation, DenseKernelComputation
|
|
27
|
-
from gpjax.parameters import PositiveReal, Parameter
|
|
28
|
-
from jax.typing import ArrayLike
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
def legendre_polynomial(n: int) -> Callable:
|
|
32
|
-
"""Compute the Legendre polynomial of degree n.
|
|
33
|
-
|
|
34
|
-
Args:
|
|
35
|
-
n: Degree of the Legendre polynomial.
|
|
36
|
-
|
|
37
|
-
Returns:
|
|
38
|
-
Function that evaluates the nth Legendre polynomial at given points.
|
|
39
|
-
"""
|
|
40
|
-
if n == 0:
|
|
41
|
-
return lambda x: jnp.ones_like(x)
|
|
42
|
-
elif n == 1:
|
|
43
|
-
return lambda x: x
|
|
44
|
-
else:
|
|
45
|
-
p_n_minus_2 = legendre_polynomial(n - 2)
|
|
46
|
-
p_n_minus_1 = legendre_polynomial(n - 1)
|
|
47
|
-
return (
|
|
48
|
-
lambda x: ((2 * n - 1) * x * p_n_minus_1(x) - (n - 1) * p_n_minus_2(x)) / n
|
|
49
|
-
)
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
def legendre_polynomial_derivative(n: int) -> Callable:
|
|
53
|
-
"""Compute the derivative of the Legendre polynomial of degree n.
|
|
54
|
-
|
|
55
|
-
Args:
|
|
56
|
-
n: Degree of the Legendre polynomial.
|
|
57
|
-
|
|
58
|
-
Returns:
|
|
59
|
-
Function that evaluates the derivative of the nth Legendre polynomial.
|
|
60
|
-
"""
|
|
61
|
-
if n == 0:
|
|
62
|
-
return lambda x: jnp.zeros_like(x)
|
|
63
|
-
else:
|
|
64
|
-
p_n_minus_1 = legendre_polynomial(n - 1)
|
|
65
|
-
return (
|
|
66
|
-
lambda x: n
|
|
67
|
-
* (p_n_minus_1(x) - x * legendre_polynomial_derivative(n - 1)(x))
|
|
68
|
-
/ (1 - x**2 + 1e-10)
|
|
69
|
-
)
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
def gauss_legendre_quadrature(
|
|
73
|
-
deg: int, a: float = -1.0, b: float = 1.0
|
|
74
|
-
) -> Tuple[jnp.ndarray, jnp.ndarray]:
|
|
75
|
-
"""Generate Gauss-Legendre quadrature points and weights.
|
|
76
|
-
|
|
77
|
-
Args:
|
|
78
|
-
deg: Number of quadrature points.
|
|
79
|
-
a: Lower limit of integration (default: -1.0).
|
|
80
|
-
b: Upper limit of integration (default: 1.0).
|
|
81
|
-
|
|
82
|
-
Returns:
|
|
83
|
-
Tuple of (points, weights) for Gauss-Legendre quadrature.
|
|
84
|
-
"""
|
|
85
|
-
# For computational efficiency, use a simpler approach for small degrees
|
|
86
|
-
if deg <= 4:
|
|
87
|
-
# Hardcoded points and weights for small degrees
|
|
88
|
-
if deg == 1:
|
|
89
|
-
x = jnp.array([0.0])
|
|
90
|
-
w = jnp.array([2.0])
|
|
91
|
-
elif deg == 2:
|
|
92
|
-
x = jnp.array([-0.5773502691896257, 0.5773502691896257])
|
|
93
|
-
w = jnp.array([1.0, 1.0])
|
|
94
|
-
elif deg == 3:
|
|
95
|
-
x = jnp.array([-0.7745966692414834, 0.0, 0.7745966692414834])
|
|
96
|
-
w = jnp.array([0.5555555555555556, 0.8888888888888888, 0.5555555555555556])
|
|
97
|
-
elif deg == 4:
|
|
98
|
-
x = jnp.array(
|
|
99
|
-
[
|
|
100
|
-
-0.8611363115940526,
|
|
101
|
-
-0.3399810435848563,
|
|
102
|
-
0.3399810435848563,
|
|
103
|
-
0.8611363115940526,
|
|
104
|
-
]
|
|
105
|
-
)
|
|
106
|
-
w = jnp.array(
|
|
107
|
-
[
|
|
108
|
-
0.3478548451374538,
|
|
109
|
-
0.6521451548625461,
|
|
110
|
-
0.6521451548625461,
|
|
111
|
-
0.3478548451374538,
|
|
112
|
-
]
|
|
113
|
-
)
|
|
114
|
-
else:
|
|
115
|
-
# Initial guess for roots (Chebyshev nodes)
|
|
116
|
-
k = jnp.arange(1, deg + 1)
|
|
117
|
-
x0 = jnp.cos(jnp.pi * (k - 0.25) / (deg + 0.5))
|
|
118
|
-
|
|
119
|
-
# Newton iteration to find roots more accurately
|
|
120
|
-
# In practice, we would use more iterations, but for simplicity we use a fixed number
|
|
121
|
-
P_n = legendre_polynomial(deg)
|
|
122
|
-
dP_n = legendre_polynomial_derivative(deg)
|
|
123
|
-
|
|
124
|
-
# Single Newton step for demonstration (would use a loop in practice)
|
|
125
|
-
x = x0 - P_n(x0) / (dP_n(x0) + 1e-10)
|
|
126
|
-
|
|
127
|
-
# Compute weights
|
|
128
|
-
w = 2.0 / ((1.0 - x**2) * dP_n(x) ** 2 + 1e-10)
|
|
129
|
-
|
|
130
|
-
# Scale from [-1,1] to [a,b]
|
|
131
|
-
x = 0.5 * (b - a) * x + 0.5 * (b + a)
|
|
132
|
-
w = 0.5 * (b - a) * w
|
|
133
|
-
|
|
134
|
-
return x, w
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
class OrthogonalAdditiveKernel(AbstractKernel):
|
|
138
|
-
"""Orthogonal Additive Kernels (OAKs) generalize additive kernels by orthogonalizing
|
|
139
|
-
the feature space to create uncorrelated kernel components.
|
|
140
|
-
|
|
141
|
-
This implementation uses a Gauss-Legendre quadrature approximation for the required
|
|
142
|
-
one-dimensional integrals involving the base kernels, allowing for arbitrary base kernels.
|
|
143
|
-
|
|
144
|
-
References:
|
|
145
|
-
- X. Lu, A. Boukouvalas, and J. Hensman. Additive Gaussian processes revisited.
|
|
146
|
-
Proceedings of the 39th International Conference on Machine Learning. Jul 2022.
|
|
147
|
-
"""
|
|
148
|
-
|
|
149
|
-
base_kernel: AbstractKernel
|
|
150
|
-
quad_deg: int
|
|
151
|
-
dim: int
|
|
152
|
-
offset: nnx.Variable
|
|
153
|
-
coeffs_1: nnx.Variable
|
|
154
|
-
coeffs_2: Optional[nnx.Variable]
|
|
155
|
-
z: jnp.ndarray
|
|
156
|
-
w: jnp.ndarray
|
|
157
|
-
name: str = "OrthogonalAdditiveKernel"
|
|
158
|
-
|
|
159
|
-
def __init__(
|
|
160
|
-
self,
|
|
161
|
-
base_kernel: AbstractKernel,
|
|
162
|
-
dim: int,
|
|
163
|
-
quad_deg: int = 32,
|
|
164
|
-
second_order: bool = False,
|
|
165
|
-
active_dims: tp.Union[list[int], slice, None] = None,
|
|
166
|
-
n_dims: tp.Union[int, None] = None,
|
|
167
|
-
offset: tp.Union[float, Parameter[ScalarFloat]] = 1.0,
|
|
168
|
-
coeffs_1: tp.Union[ArrayLike, Parameter[ArrayLike]] = None,
|
|
169
|
-
coeffs_2: tp.Union[ArrayLike, Parameter[ArrayLike]] = None,
|
|
170
|
-
compute_engine: AbstractKernelComputation = DenseKernelComputation(),
|
|
171
|
-
):
|
|
172
|
-
"""Initialise the OrthogonalAdditiveKernel.
|
|
173
|
-
|
|
174
|
-
Args:
|
|
175
|
-
base_kernel: The kernel which to orthogonalize and evaluate.
|
|
176
|
-
dim: Input dimensionality of the kernel.
|
|
177
|
-
quad_deg: Number of integration nodes for orthogonalization.
|
|
178
|
-
second_order: Toggles second order interactions. If true, both the time and
|
|
179
|
-
space complexity of evaluating the kernel are quadratic in `dim`.
|
|
180
|
-
active_dims: The indices of the input dimensions that the kernel operates on.
|
|
181
|
-
n_dims: The number of input dimensions. If not provided, it will be inferred.
|
|
182
|
-
offset: The zeroth-order coefficient.
|
|
183
|
-
coeffs_1: The first-order coefficients. Should be a 1D array of length dim.
|
|
184
|
-
coeffs_2: The second-order coefficients. Should be a 2D array of shape (dim, dim).
|
|
185
|
-
compute_engine: The computation engine to use for kernel evaluations.
|
|
186
|
-
"""
|
|
187
|
-
super().__init__(
|
|
188
|
-
active_dims=active_dims, n_dims=n_dims, compute_engine=compute_engine
|
|
189
|
-
)
|
|
190
|
-
|
|
191
|
-
self.base_kernel = base_kernel
|
|
192
|
-
self.quad_deg = quad_deg
|
|
193
|
-
self.dim = dim
|
|
194
|
-
|
|
195
|
-
# Integration nodes and weights for [0, 1]
|
|
196
|
-
self.z, self.w = gauss_legendre_quadrature(quad_deg, a=0.0, b=1.0)
|
|
197
|
-
|
|
198
|
-
# Create expandable axes
|
|
199
|
-
z_expanded = jnp.expand_dims(self.z, axis=-1)
|
|
200
|
-
self.z = jnp.broadcast_to(z_expanded, (quad_deg, dim))
|
|
201
|
-
self.w = jnp.expand_dims(self.w, axis=-1)
|
|
202
|
-
|
|
203
|
-
# Default coefficients if not provided
|
|
204
|
-
if isinstance(offset, Parameter):
|
|
205
|
-
self.offset = offset
|
|
206
|
-
else:
|
|
207
|
-
self.offset = PositiveReal(jnp.array(offset))
|
|
208
|
-
|
|
209
|
-
if coeffs_1 is None:
|
|
210
|
-
log_d = jnp.log(dim)
|
|
211
|
-
default_coeffs_1 = jnp.exp(-log_d) * jnp.ones(dim)
|
|
212
|
-
self.coeffs_1 = PositiveReal(default_coeffs_1)
|
|
213
|
-
elif isinstance(coeffs_1, Parameter):
|
|
214
|
-
self.coeffs_1 = coeffs_1
|
|
215
|
-
else:
|
|
216
|
-
self.coeffs_1 = PositiveReal(jnp.array(coeffs_1))
|
|
217
|
-
|
|
218
|
-
self.second_order = second_order
|
|
219
|
-
if second_order:
|
|
220
|
-
if coeffs_2 is None:
|
|
221
|
-
log_d = jnp.log(dim)
|
|
222
|
-
# Initialize with zeros for upper triangular part (excluding diagonal)
|
|
223
|
-
n_entries = dim * (dim - 1) // 2
|
|
224
|
-
default_coeffs_2 = jnp.exp(-2 * log_d) * jnp.ones(n_entries)
|
|
225
|
-
self.coeffs_2_raw = PositiveReal(default_coeffs_2)
|
|
226
|
-
elif isinstance(coeffs_2, Parameter):
|
|
227
|
-
self.coeffs_2_raw = coeffs_2
|
|
228
|
-
else:
|
|
229
|
-
self.coeffs_2_raw = PositiveReal(jnp.array(coeffs_2))
|
|
230
|
-
|
|
231
|
-
# Pre-compute indices for efficient triu operations
|
|
232
|
-
self.triu_indices = jnp.triu_indices(dim, k=1)
|
|
233
|
-
else:
|
|
234
|
-
self.coeffs_2_raw = None
|
|
235
|
-
|
|
236
|
-
# Compute normalizer (in __call__)
|
|
237
|
-
self._normalizer = None
|
|
238
|
-
|
|
239
|
-
@property
|
|
240
|
-
def coeffs_2(self) -> Optional[jnp.ndarray]:
|
|
241
|
-
"""Returns a full matrix of second-order coefficients.
|
|
242
|
-
|
|
243
|
-
Returns:
|
|
244
|
-
A dim x dim array of second-order coefficients or None if second_order is False.
|
|
245
|
-
"""
|
|
246
|
-
if not self.second_order or self.coeffs_2_raw is None:
|
|
247
|
-
return None
|
|
248
|
-
|
|
249
|
-
# Create a full matrix from the raw coefficients
|
|
250
|
-
coeffs_2_flat = self.coeffs_2_raw.value
|
|
251
|
-
coeffs_2_full = jnp.zeros((self.dim, self.dim))
|
|
252
|
-
|
|
253
|
-
# Fill the upper triangular part
|
|
254
|
-
i, j = self.triu_indices
|
|
255
|
-
coeffs_2_full = coeffs_2_full.at[i, j].set(coeffs_2_flat)
|
|
256
|
-
|
|
257
|
-
# Make it symmetric
|
|
258
|
-
coeffs_2_full = coeffs_2_full + jnp.transpose(coeffs_2_full)
|
|
259
|
-
|
|
260
|
-
return coeffs_2_full
|
|
261
|
-
|
|
262
|
-
def normalizer(self, eps: float = 1e-6) -> jnp.ndarray:
|
|
263
|
-
"""Integrates the orthogonalized base kernels over [0, 1] x [0, 1].
|
|
264
|
-
|
|
265
|
-
Args:
|
|
266
|
-
eps: Minimum value constraint on the normalizers to avoid division by zero.
|
|
267
|
-
|
|
268
|
-
Returns:
|
|
269
|
-
A d-dim tensor of normalization constants.
|
|
270
|
-
"""
|
|
271
|
-
if self._normalizer is None or self.training:
|
|
272
|
-
# Compute K(z, z) - base kernel gram matrix on integration points
|
|
273
|
-
K_zz = self.base_kernel.cross_covariance(self.z, self.z)
|
|
274
|
-
|
|
275
|
-
# Integrate: w^T * K * w
|
|
276
|
-
self._normalizer = jnp.matmul(
|
|
277
|
-
jnp.matmul(jnp.transpose(self.w), K_zz), self.w
|
|
278
|
-
)
|
|
279
|
-
|
|
280
|
-
# Ensure positive values
|
|
281
|
-
self._normalizer = jnp.maximum(self._normalizer, eps)
|
|
282
|
-
|
|
283
|
-
return self._normalizer
|
|
284
|
-
|
|
285
|
-
def _orthogonal_base_kernels(
|
|
286
|
-
self, x1: Float[Array, "N D"], x2: Float[Array, "M D"]
|
|
287
|
-
) -> Float[Array, "D N M"]:
|
|
288
|
-
"""Evaluates the set of d orthogonalized base kernels.
|
|
289
|
-
|
|
290
|
-
Args:
|
|
291
|
-
x1: Input array of shape [N, D]
|
|
292
|
-
x2: Input array of shape [M, D]
|
|
293
|
-
|
|
294
|
-
Returns:
|
|
295
|
-
Array of shape [D, N, M] with orthogonalized kernel evaluations
|
|
296
|
-
"""
|
|
297
|
-
# Compute base kernel between inputs
|
|
298
|
-
K_x1x2 = self.base_kernel.cross_covariance(x1, x2) # [N, M]
|
|
299
|
-
|
|
300
|
-
# Compute normalizer
|
|
301
|
-
norm = jnp.sqrt(self.normalizer())
|
|
302
|
-
w_normalized = self.w / norm
|
|
303
|
-
|
|
304
|
-
# Compute base kernel between x1 and integration points z
|
|
305
|
-
K_x1z = self.base_kernel.cross_covariance(x1, self.z) # [N, quad_deg]
|
|
306
|
-
S_x1 = jnp.matmul(K_x1z, w_normalized) # [N, 1]
|
|
307
|
-
|
|
308
|
-
# Compute base kernel between x2 and integration points z
|
|
309
|
-
if x1 is x2:
|
|
310
|
-
S_x2 = S_x1
|
|
311
|
-
else:
|
|
312
|
-
K_x2z = self.base_kernel.cross_covariance(x2, self.z) # [M, quad_deg]
|
|
313
|
-
S_x2 = jnp.matmul(K_x2z, w_normalized) # [M, 1]
|
|
314
|
-
|
|
315
|
-
# Compute orthogonal kernel: K_x1x2 - S_x1 * S_x2^T
|
|
316
|
-
K_ortho = K_x1x2 - jnp.outer(S_x1, S_x2)
|
|
317
|
-
|
|
318
|
-
return K_ortho
|
|
319
|
-
|
|
320
|
-
def __call__(
|
|
321
|
-
self,
|
|
322
|
-
x: Float[Array, "N D"],
|
|
323
|
-
y: Float[Array, "M D"],
|
|
324
|
-
) -> ScalarFloat:
|
|
325
|
-
"""Evaluate the kernel at a single pair of inputs.
|
|
326
|
-
|
|
327
|
-
Args:
|
|
328
|
-
x: First input.
|
|
329
|
-
y: Second input.
|
|
330
|
-
|
|
331
|
-
Returns:
|
|
332
|
-
The kernel value at (x, y).
|
|
333
|
-
"""
|
|
334
|
-
# Slice inputs to relevant dimensions
|
|
335
|
-
x_sliced = self.slice_input(x)
|
|
336
|
-
y_sliced = self.slice_input(y)
|
|
337
|
-
|
|
338
|
-
# Get orthogonalized kernels
|
|
339
|
-
K_ortho = self._orthogonal_base_kernels(
|
|
340
|
-
jnp.expand_dims(x_sliced, 0), jnp.expand_dims(y_sliced, 0)
|
|
341
|
-
) # [1, 1]
|
|
342
|
-
|
|
343
|
-
# Apply first-order effects
|
|
344
|
-
first_order = jnp.sum(self.coeffs_1.value * K_ortho)
|
|
345
|
-
|
|
346
|
-
# Add offset
|
|
347
|
-
result = self.offset.value + first_order
|
|
348
|
-
|
|
349
|
-
# Add second-order effects if enabled
|
|
350
|
-
if self.second_order and self.coeffs_2 is not None:
|
|
351
|
-
# For a single point evaluation, we use a simpler approach
|
|
352
|
-
# Computing the tensor of second order interactions
|
|
353
|
-
second_order = 0.0
|
|
354
|
-
for i in range(self.dim):
|
|
355
|
-
for j in range(i + 1, self.dim):
|
|
356
|
-
coef = self.coeffs_2[i, j]
|
|
357
|
-
if coef > 0:
|
|
358
|
-
second_order += coef * K_ortho[i] * K_ortho[j]
|
|
359
|
-
|
|
360
|
-
result = result + second_order
|
|
361
|
-
|
|
362
|
-
return result
|
|
363
|
-
|
|
364
|
-
def cross_covariance(
|
|
365
|
-
self, x1: Float[Array, "N D"], x2: Float[Array, "M D"]
|
|
366
|
-
) -> Float[Array, "N M"]:
|
|
367
|
-
"""Compute the cross-covariance matrix of the kernel.
|
|
368
|
-
|
|
369
|
-
Args:
|
|
370
|
-
x1: First input matrix of shape [N, D].
|
|
371
|
-
x2: Second input matrix of shape [M, D].
|
|
372
|
-
|
|
373
|
-
Returns:
|
|
374
|
-
Cross-covariance matrix of shape [N, M].
|
|
375
|
-
"""
|
|
376
|
-
# Slice inputs to relevant dimensions
|
|
377
|
-
x1 = self.slice_input(x1)
|
|
378
|
-
x2 = self.slice_input(x2)
|
|
379
|
-
|
|
380
|
-
# Get orthogonalized kernels for all dimensions
|
|
381
|
-
K_ortho = self._orthogonal_base_kernels(x1, x2) # [D, N, M]
|
|
382
|
-
|
|
383
|
-
# Apply first-order effects (sum over dimensions)
|
|
384
|
-
coeffs_1 = self.coeffs_1.value
|
|
385
|
-
first_order = jnp.tensordot(coeffs_1, K_ortho, axes=([0], [0])) # [N, M]
|
|
386
|
-
|
|
387
|
-
# Add offset (broadcast to match output shape)
|
|
388
|
-
result = jnp.broadcast_to(self.offset.value, first_order.shape) + first_order
|
|
389
|
-
|
|
390
|
-
# Add second-order effects if enabled
|
|
391
|
-
if self.second_order and self.coeffs_2 is not None:
|
|
392
|
-
# Compute second-order interactions using einsum
|
|
393
|
-
coeffs_2_full = self.coeffs_2
|
|
394
|
-
second_order = jnp.einsum(
|
|
395
|
-
"ij,ink,jml->nml",
|
|
396
|
-
coeffs_2_full,
|
|
397
|
-
jnp.expand_dims(K_ortho, 1),
|
|
398
|
-
jnp.expand_dims(K_ortho, 0),
|
|
399
|
-
)
|
|
400
|
-
|
|
401
|
-
# Sum over dimensions i, j
|
|
402
|
-
second_order = jnp.sum(second_order, axis=(0, 1))
|
|
403
|
-
|
|
404
|
-
result = result + second_order
|
|
405
|
-
|
|
406
|
-
return result
|
|
File without changes
|
|
File without changes
|