gpjax 0.11.2__py3-none-any.whl → 0.12.2__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- gpjax/__init__.py +1 -4
- gpjax/distributions.py +16 -56
- gpjax/fit.py +11 -6
- gpjax/gps.py +61 -73
- gpjax/kernels/approximations/rff.py +2 -5
- gpjax/kernels/base.py +2 -5
- gpjax/kernels/computations/base.py +7 -7
- gpjax/kernels/computations/basis_functions.py +7 -6
- gpjax/kernels/computations/constant_diagonal.py +10 -12
- gpjax/kernels/computations/diagonal.py +6 -6
- gpjax/kernels/computations/eigen.py +1 -1
- gpjax/kernels/non_euclidean/graph.py +10 -11
- gpjax/kernels/nonstationary/arccosine.py +13 -21
- gpjax/kernels/nonstationary/polynomial.py +7 -8
- gpjax/kernels/stationary/periodic.py +3 -6
- gpjax/kernels/stationary/powered_exponential.py +3 -8
- gpjax/kernels/stationary/rational_quadratic.py +5 -8
- gpjax/likelihoods.py +11 -14
- gpjax/linalg/__init__.py +37 -0
- gpjax/linalg/operations.py +237 -0
- gpjax/linalg/operators.py +411 -0
- gpjax/linalg/utils.py +65 -0
- gpjax/mean_functions.py +8 -7
- gpjax/objectives.py +22 -21
- gpjax/parameters.py +11 -23
- gpjax/variational_families.py +93 -67
- {gpjax-0.11.2.dist-info → gpjax-0.12.2.dist-info}/METADATA +50 -18
- gpjax-0.12.2.dist-info/RECORD +52 -0
- gpjax/lower_cholesky.py +0 -69
- gpjax-0.11.2.dist-info/RECORD +0 -49
- {gpjax-0.11.2.dist-info → gpjax-0.12.2.dist-info}/WHEEL +0 -0
- {gpjax-0.11.2.dist-info → gpjax-0.12.2.dist-info}/licenses/LICENSE.txt +0 -0
gpjax/linalg/operators.py
ADDED

@@ -0,0 +1,411 @@
+"""Linear operator abstractions for GPJax."""
+
+from abc import (
+    ABC,
+    abstractmethod,
+)
+from typing import (
+    Any,
+    List,
+    Tuple,
+    Union,
+)
+
+from jax import Array
+import jax.numpy as jnp
+import jax.tree_util as jtu
+from jaxtyping import Float
+
+
+class LinearOperator(ABC):
+    """Abstract base class for linear operators."""
+
+    def __init__(self):
+        super().__init__()
+
+    @property
+    @abstractmethod
+    def shape(self) -> Tuple[int, int]:
+        """Return the shape of the operator."""
+
+    @property
+    @abstractmethod
+    def dtype(self) -> jnp.dtype:
+        """Return the data type of the operator."""
+
+    @abstractmethod
+    def to_dense(self) -> Float[Array, "M N"]:
+        """Convert the operator to a dense JAX array."""
+
+    @property
+    def T(self) -> "LinearOperator":
+        """Return the transpose of the operator."""
+        # Default implementation: convert to dense and transpose
+        return Dense(self.to_dense().T)
+
+    def __matmul__(self, other):
+        """Matrix multiplication with another array or operator."""
+        if hasattr(other, "to_dense"):
+            # Other is a LinearOperator
+            return Dense(self.to_dense() @ other.to_dense())
+        else:
+            # Other is a JAX array
+            return self.to_dense() @ other
+
+    def __rmatmul__(self, other):
+        """Right matrix multiplication (other @ self)."""
+        if hasattr(other, "to_dense"):
+            # Other is a LinearOperator
+            return Dense(other.to_dense() @ self.to_dense())
+        else:
+            # Other is a JAX array
+            return other @ self.to_dense()
+
+    def __add__(self, other):
+        """Addition with another array or operator."""
+        if hasattr(other, "to_dense"):
+            # Other is a LinearOperator
+            return Dense(self.to_dense() + other.to_dense())
+        else:
+            # Other is a JAX array
+            return Dense(self.to_dense() + other)
+
+    def __radd__(self, other):
+        """Right addition (other + self)."""
+        if hasattr(other, "to_dense"):
+            # Other is a LinearOperator
+            return Dense(other.to_dense() + self.to_dense())
+        else:
+            # Other is a JAX array
+            return Dense(other + self.to_dense())
+
+    def __sub__(self, other):
+        """Subtraction with another array or operator."""
+        if hasattr(other, "to_dense"):
+            # Other is a LinearOperator
+            return Dense(self.to_dense() - other.to_dense())
+        else:
+            # Other is a JAX array
+            return Dense(self.to_dense() - other)
+
+    def __rsub__(self, other):
+        """Right subtraction (other - self)."""
+        if hasattr(other, "to_dense"):
+            # Other is a LinearOperator
+            return Dense(other.to_dense() - self.to_dense())
+        else:
+            # Other is a JAX array
+            return Dense(other - self.to_dense())
+
+    def __mul__(self, other):
+        """Scalar multiplication (self * scalar)."""
+        if jnp.isscalar(other):
+            return Dense(self.to_dense() * other)
+        else:
+            # Element-wise multiplication with array
+            return Dense(self.to_dense() * other)
+
+    def __rmul__(self, other):
+        """Right scalar multiplication (scalar * self)."""
+        if jnp.isscalar(other):
+            return Dense(other * self.to_dense())
+        else:
+            # Element-wise multiplication with array
+            return Dense(other * self.to_dense())
+
+
+class Dense(LinearOperator):
+    """Dense linear operator wrapping a JAX array."""
+
+    def __init__(self, array: Float[Array, "M N"]):
+        super().__init__()
+        self.array = array
+
+    @property
+    def shape(self) -> Tuple[int, int]:
+        return self.array.shape
+
+    @property
+    def dtype(self) -> jnp.dtype:
+        return self.array.dtype
+
+    def to_dense(self) -> Float[Array, "M N"]:
+        return self.array
+
+    @property
+    def T(self) -> "Dense":
+        return Dense(self.array.T)
+
+
+class Diagonal(LinearOperator):
+    """Diagonal linear operator."""
+
+    def __init__(self, diagonal: Float[Array, " N"]):
+        super().__init__()
+        self.diagonal = diagonal
+
+    @property
+    def shape(self) -> Tuple[int, int]:
+        n = self.diagonal.shape[0]
+        return (n, n)
+
+    @property
+    def dtype(self) -> jnp.dtype:
+        return self.diagonal.dtype
+
+    def to_dense(self) -> Float[Array, "N N"]:
+        return jnp.diag(self.diagonal)
+
+    @property
+    def T(self) -> "Diagonal":
+        return Diagonal(self.diagonal)
+
+
+class Identity(LinearOperator):
+    """Identity linear operator."""
+
+    def __init__(self, shape: Union[int, Tuple[int, int]], dtype=jnp.float64):
+        super().__init__()
+        if isinstance(shape, int):
+            self._shape = (shape, shape)
+        else:
+            if shape[0] != shape[1]:
+                raise ValueError(f"Identity matrix must be square, got shape {shape}")
+            self._shape = shape
+        self._dtype = dtype
+
+    @property
+    def shape(self) -> Tuple[int, int]:
+        return self._shape
+
+    @property
+    def dtype(self) -> Any:
+        return self._dtype
+
+    def to_dense(self) -> Float[Array, "N N"]:
+        n = self._shape[0]
+        return jnp.eye(n, dtype=self._dtype)
+
+    @property
+    def T(self) -> "Identity":
+        return Identity(self._shape, dtype=self._dtype)
+
+
+class Triangular(LinearOperator):
+    """Triangular linear operator."""
+
+    def __init__(self, array: Float[Array, "N N"], lower: bool = True):
+        super().__init__()
+        self.array = array
+        self.lower = lower
+
+    @property
+    def shape(self) -> Tuple[int, int]:
+        return self.array.shape
+
+    @property
+    def dtype(self) -> Any:
+        return self.array.dtype
+
+    def to_dense(self) -> Float[Array, "N N"]:
+        if self.lower:
+            return jnp.tril(self.array)
+        else:
+            return jnp.triu(self.array)
+
+    @property
+    def T(self) -> "Triangular":
+        return Triangular(self.array.T, lower=not self.lower)
+
+
+class BlockDiag(LinearOperator):
+    """Block diagonal linear operator."""
+
+    def __init__(
+        self, operators: List[LinearOperator], multiplicities: List[int] = None
+    ):
+        super().__init__()
+        self.operators = operators
+
+        # Handle multiplicities - how many times each block is repeated
+        if multiplicities is None:
+            self.multiplicities = [1] * len(operators)
+        else:
+            if len(multiplicities) != len(operators):
+                raise ValueError(
+                    f"Length of multiplicities ({len(multiplicities)}) must match operators ({len(operators)})"
+                )
+            self.multiplicities = multiplicities
+
+        # Calculate total shape with multiplicities
+        rows = sum(
+            op.shape[0] * mult
+            for op, mult in zip(operators, self.multiplicities, strict=False)
+        )
+        cols = sum(
+            op.shape[1] * mult
+            for op, mult in zip(operators, self.multiplicities, strict=False)
+        )
+        self._shape = (rows, cols)
+
+        # Use dtype of first operator (assuming all same dtype)
+        if operators:
+            self._dtype = operators[0].dtype
+        else:
+            self._dtype = jnp.float64
+
+    @property
+    def shape(self) -> Tuple[int, int]:
+        return self._shape
+
+    @property
+    def dtype(self) -> Any:
+        return self._dtype
+
+    def to_dense(self) -> Float[Array, "M N"]:
+        if not self.operators:
+            return jnp.zeros(self._shape, dtype=self._dtype)
+
+        # Convert each operator to dense and create block diagonal with multiplicities
+        expanded_blocks = []
+        for op, mult in zip(self.operators, self.multiplicities, strict=False):
+            op_dense = op.to_dense()
+            for _ in range(mult):
+                expanded_blocks.append(op_dense)
+
+        # Create the full block diagonal matrix
+        n_blocks = len(expanded_blocks)
+        if n_blocks == 0:
+            return jnp.zeros(self._shape, dtype=self._dtype)
+
+        # Build the block diagonal matrix
+        rows = []
+        for i in range(n_blocks):
+            row = []
+            for j in range(n_blocks):
+                if i == j:
+                    row.append(expanded_blocks[i])
+                else:
+                    row.append(
+                        jnp.zeros(
+                            (expanded_blocks[i].shape[0], expanded_blocks[j].shape[1]),
+                            dtype=self._dtype,
+                        )
+                    )
+            rows.append(row)
+        return jnp.block(rows)
+
+    @property
+    def T(self) -> "BlockDiag":
+        transposed_ops = [op.T for op in self.operators]
+        return BlockDiag(transposed_ops, multiplicities=self.multiplicities)
+
+
+class Kronecker(LinearOperator):
+    """Kronecker product linear operator."""
+
+    def __init__(self, operators: List[LinearOperator]):
+        super().__init__()
+        if len(operators) < 2:
+            raise ValueError("Kronecker product requires at least 2 operators")
+        self.operators = operators
+
+        # Calculate shape as product of individual shapes
+        rows = 1
+        cols = 1
+        for op in operators:
+            rows *= op.shape[0]
+            cols *= op.shape[1]
+        self._shape = (rows, cols)
+
+        # Use dtype of first operator
+        self._dtype = operators[0].dtype
+
+    @property
+    def shape(self) -> Tuple[int, int]:
+        return self._shape
+
+    @property
+    def dtype(self) -> Any:
+        return self._dtype
+
+    def to_dense(self) -> Float[Array, "M N"]:
+        # Convert to dense and compute Kronecker product
+        result = self.operators[0].to_dense()
+        for op in self.operators[1:]:
+            result = jnp.kron(result, op.to_dense())
+        return result
+
+    @property
+    def T(self) -> "Kronecker":
+        transposed_ops = [op.T for op in self.operators]
+        return Kronecker(transposed_ops)
+
+
+def _dense_tree_flatten(dense):
+    return (dense.array,), None
+
+
+def _dense_tree_unflatten(aux_data, children):
+    return Dense(children[0])
+
+
+jtu.register_pytree_node(Dense, _dense_tree_flatten, _dense_tree_unflatten)
+
+
+def _diagonal_tree_flatten(diagonal):
+    return (diagonal.diagonal,), None
+
+
+def _diagonal_tree_unflatten(aux_data, children):
+    return Diagonal(children[0])
+
+
+jtu.register_pytree_node(Diagonal, _diagonal_tree_flatten, _diagonal_tree_unflatten)
+
+
+def _identity_tree_flatten(identity):
+    return (), (identity._shape, identity._dtype)
+
+
+def _identity_tree_unflatten(aux_data, children):
+    shape, dtype = aux_data
+    return Identity(shape, dtype)
+
+
+jtu.register_pytree_node(Identity, _identity_tree_flatten, _identity_tree_unflatten)
+
+
+def _triangular_tree_flatten(triangular):
+    return (triangular.array,), triangular.lower
+
+
+def _triangular_tree_unflatten(aux_data, children):
+    return Triangular(children[0], aux_data)
+
+
+jtu.register_pytree_node(
+    Triangular, _triangular_tree_flatten, _triangular_tree_unflatten
+)
+
+
+def _blockdiag_tree_flatten(blockdiag):
+    return tuple(blockdiag.operators), blockdiag.multiplicities
+
+
+def _blockdiag_tree_unflatten(aux_data, children):
+    return BlockDiag(list(children), aux_data)
+
+
+jtu.register_pytree_node(BlockDiag, _blockdiag_tree_flatten, _blockdiag_tree_unflatten)
+
+
+def _kronecker_tree_flatten(kronecker):
+    return tuple(kronecker.operators), None
+
+
+def _kronecker_tree_unflatten(aux_data, children):
+    return Kronecker(list(children))
+
+
+jtu.register_pytree_node(Kronecker, _kronecker_tree_flatten, _kronecker_tree_unflatten)
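For orientation, here is a minimal usage sketch of the operators added above. It is illustrative only: the array values are made up, and the import path simply mirrors the new module.

import jax.numpy as jnp

from gpjax.linalg.operators import Dense, Diagonal, Kronecker

# Wrap a dense array and a diagonal; mixed arithmetic falls back to dense.
A = Dense(jnp.array([[1.0, 0.5], [0.5, 1.0]]))
D = Diagonal(jnp.array([2.0, 3.0]))

S = A + D              # a Dense operator holding A.to_dense() + D.to_dense()
K = Kronecker([A, D])  # shape (4, 4): the factor shapes multiply

print(S.shape)         # (2, 2)
print(K.T.to_dense())  # transpose maps over the factors before densifying

Because each operator class is registered as a pytree, instances can be passed through jax.jit- or jax.grad-transformed functions like ordinary arrays.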
gpjax/linalg/utils.py
ADDED

@@ -0,0 +1,65 @@
+"""Utility functions for the linear algebra module."""
+
+import jax.numpy as jnp
+from jaxtyping import Array
+
+from gpjax.linalg.operators import LinearOperator
+
+
+class PSDAnnotation:
+    """Marker class for PSD (Positive Semi-Definite) annotations."""
+
+    def __call__(self, A: LinearOperator) -> LinearOperator:
+        """Make PSD annotation callable."""
+        return psd(A)
+
+
+# Create the PSD marker similar to cola.PSD
+PSD = PSDAnnotation()
+
+
+def psd(A: LinearOperator) -> LinearOperator:
+    """Mark a linear operator as positive semi-definite.
+
+    This function acts as a marker/wrapper for positive semi-definite matrices.
+
+    Args:
+        A: A LinearOperator that is assumed to be positive semi-definite.
+
+    Returns:
+        The same LinearOperator, marked as PSD.
+    """
+    # Add annotations attribute if it doesn't exist
+    if not hasattr(A, "annotations"):
+        A.annotations = set()
+    A.annotations.add(PSD)
+    return A
+
+
+def add_jitter(matrix: Array, jitter: float | Array = 1e-6) -> Array:
+    """Add jitter to the diagonal of a matrix for numerical stability.
+
+    This function adds a small positive value (jitter) to the diagonal elements
+    of a square matrix to improve numerical stability, particularly for
+    Cholesky decompositions and matrix inversions.
+
+    Args:
+        matrix: A square matrix to which jitter will be added.
+        jitter: The jitter value to add to the diagonal. Defaults to 1e-6.
+
+    Returns:
+        The matrix with jitter added to its diagonal.
+
+    Examples:
+        >>> import jax.numpy as jnp
+        >>> from gpjax.linalg.utils import add_jitter
+        >>> matrix = jnp.array([[1.0, 0.5], [0.5, 1.0]])
+        >>> jittered_matrix = add_jitter(matrix, jitter=0.01)
+    """
+    if matrix.ndim != 2:
+        raise ValueError(f"Expected 2D matrix, got {matrix.ndim}D array")
+
+    if matrix.shape[0] != matrix.shape[1]:
+        raise ValueError(f"Expected square matrix, got shape {matrix.shape}")
+
+    return matrix + jnp.eye(matrix.shape[0]) * jitter
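A quick sketch of the two helpers working together; the Gram matrix below is a toy rank-deficient example.

import jax.numpy as jnp

from gpjax.linalg.operators import Dense
from gpjax.linalg.utils import PSD, add_jitter, psd

K = jnp.array([[1.0, 1.0], [1.0, 1.0]])  # singular 2x2 Gram matrix
K_stable = add_jitter(K, jitter=1e-3)    # nudge the diagonal off zero
op = psd(Dense(K_stable))                # record the PSD annotation

assert PSD in op.annotations             # the marker is stored on the instance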
gpjax/mean_functions.py
CHANGED

@@ -27,8 +27,6 @@ from jaxtyping import (

 from gpjax.parameters import (
     Parameter,
-    Real,
-    Static,
 )
 from gpjax.typing import (
     Array,

@@ -132,12 +130,12 @@ class Constant(AbstractMeanFunction):

     def __init__(
         self,
-        constant: tp.Union[ScalarFloat, Float[Array, " O"], Parameter
+        constant: tp.Union[ScalarFloat, Float[Array, " O"], Parameter] = 0.0,
     ):
-        if isinstance(constant, Parameter)
+        if isinstance(constant, Parameter):
             self.constant = constant
         else:
-            self.constant =
+            self.constant = jnp.array(constant)

     def __call__(self, x: Num[Array, "N D"]) -> Float[Array, "N O"]:
         r"""Evaluate the mean function at the given points.

@@ -148,7 +146,10 @@ class Constant(AbstractMeanFunction):
         Returns:
             Float[Array, "1"]: The evaluated mean function.
         """
-
+        if isinstance(self.constant, Parameter):
+            return jnp.ones((x.shape[0], 1)) * self.constant.value
+        else:
+            return jnp.ones((x.shape[0], 1)) * self.constant


 class Zero(Constant):

@@ -160,7 +161,7 @@ class Zero(Constant):
     """

     def __init__(self):
-        super().__init__(constant=
+        super().__init__(constant=0.0)


 class CombinationMeanFunction(AbstractMeanFunction):
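The repaired constructor and __call__ can be exercised directly; a short sketch with illustrative inputs:

import jax.numpy as jnp

from gpjax.mean_functions import Constant, Zero

x = jnp.zeros((3, 1))

mf = Constant(constant=2.0)  # plain floats are wrapped with jnp.array
print(mf(x))                 # [[2.] [2.] [2.]]: ones((N, 1)) scaled by the constant
print(Zero()(x))             # [[0.] [0.] [0.]]: Zero is Constant(constant=0.0)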
gpjax/objectives.py
CHANGED

@@ -1,13 +1,5 @@
 from typing import TypeVar

-from cola.annotations import PSD
-from cola.linalg.decompositions.decompositions import Cholesky
-from cola.linalg.inverse.inv import (
-    inv,
-    solve,
-)
-from cola.linalg.trace.diag_trace import diag
-from cola.ops.operators import I_like
 from flax import nnx
 from jax import vmap
 import jax.numpy as jnp

@@ -22,7 +14,13 @@ from gpjax.gps import (
     ConjugatePosterior,
     NonConjugatePosterior,
 )
-from gpjax.
+from gpjax.linalg import (
+    Dense,
+    lower_cholesky,
+    psd,
+    solve,
+)
+from gpjax.linalg.utils import add_jitter
 from gpjax.typing import (
     Array,
     ScalarFloat,

@@ -100,9 +98,9 @@ def conjugate_mll(posterior: ConjugatePosterior, data: Dataset) -> ScalarFloat:

     # Σ = (Kxx + Io²) = LLᵀ
     Kxx = posterior.prior.kernel.gram(x)
-
-
-    Sigma =
+    Kxx_dense = add_jitter(Kxx.to_dense(), posterior.prior.jitter)
+    Sigma_dense = Kxx_dense + jnp.eye(Kxx.shape[0]) * obs_noise
+    Sigma = psd(Dense(Sigma_dense))

     # p(y | x, θ), where θ are the model hyperparameters:
     mll = GaussianDistribution(jnp.atleast_1d(mx.squeeze()), Sigma)

@@ -164,11 +162,14 @@ def conjugate_loocv(posterior: ConjugatePosterior, data: Dataset) -> ScalarFloat

     # Σ = (Kxx + Io²)
     Kxx = posterior.prior.kernel.gram(x)
-
-
+    Sigma_dense = Kxx.to_dense() + jnp.eye(Kxx.shape[0]) * (
+        obs_var + posterior.prior.jitter
+    )
+    Sigma = psd(Dense(Sigma_dense))  # [N, N]

-    Sigma_inv_y = solve(Sigma, y - mx
-
+    Sigma_inv_y = solve(Sigma, y - mx)  # [N, 1]
+    Sigma_inv = jnp.linalg.inv(Sigma.to_dense())
+    Sigma_inv_diag = jnp.diag(Sigma_inv)[:, None]  # [N, 1]

     loocv_means = mx + (y - mx) - Sigma_inv_y / Sigma_inv_diag
     loocv_stds = jnp.sqrt(1.0 / Sigma_inv_diag)

@@ -213,8 +214,8 @@ def log_posterior_density(

     # Gram matrix
     Kxx = posterior.prior.kernel.gram(x)
-
-    Kxx =
+    Kxx_dense = add_jitter(Kxx.to_dense(), posterior.prior.jitter)
+    Kxx = psd(Dense(Kxx_dense))
     Lx = lower_cholesky(Kxx)

     # Compute the prior mean function

@@ -349,8 +350,8 @@ def collapsed_elbo(variational_family: VF, data: Dataset) -> ScalarFloat:
     noise = variational_family.posterior.likelihood.obs_stddev.value**2
     z = variational_family.inducing_inputs.value
     Kzz = kernel.gram(z)
-
-    Kzz =
+    Kzz_dense = add_jitter(Kzz.to_dense(), variational_family.jitter)
+    Kzz = psd(Dense(Kzz_dense))
     Kzx = kernel.cross_covariance(z, x)
     Kxx_diag = vmap(kernel, in_axes=(0, 0))(x, x)
     μx = mean_function(x)

@@ -383,7 +384,7 @@ def collapsed_elbo(variational_family: VF, data: Dataset) -> ScalarFloat:
     #
     # with A and B defined as above.

-    A = solve(Lz, Kzx
+    A = solve(Lz, Kzx) / jnp.sqrt(noise)

     # AAᵀ
     AAT = jnp.matmul(A, A.T)
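Each objective now follows the same recipe in place of the old cola calls: densify the Gram operator, stabilise its diagonal, and re-wrap the result as a PSD-annotated Dense operator. A standalone sketch of that shared step; build_sigma is a hypothetical helper name, not part of the package:

import jax.numpy as jnp

from gpjax.linalg import Dense, psd
from gpjax.linalg.utils import add_jitter

def build_sigma(Kxx_dense, obs_noise, jitter=1e-6):
    # Sigma = Kxx + jitter*I + obs_noise*I, wrapped as a PSD Dense operator
    Kxx_dense = add_jitter(Kxx_dense, jitter)
    Sigma_dense = Kxx_dense + jnp.eye(Kxx_dense.shape[0]) * obs_noise
    return psd(Dense(Sigma_dense))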
gpjax/parameters.py
CHANGED

@@ -21,23 +21,20 @@ def transform(
     r"""Transforms parameters using a bijector.

     Example:
-        ```pycon
         >>> from gpjax.parameters import PositiveReal, transform
         >>> import jax.numpy as jnp
         >>> import numpyro.distributions.transforms as npt
         >>> from flax import nnx
         >>> params = nnx.State(
-
-
-
-
-
+        ...     {
+        ...         "a": PositiveReal(jnp.array([1.0])),
+        ...         "b": PositiveReal(jnp.array([2.0])),
+        ...     }
+        ... )
         >>> params_bijection = {'positive': npt.SoftplusTransform()}
         >>> transformed_params = transform(params, params_bijection)
         >>> print(transformed_params["a"].value)
-
-        ```
-
+        [1.3132617]

     Args:
         params: A nnx.State object containing parameters to be transformed.

@@ -49,7 +46,7 @@ def transform(
     """

     def _inner(param):
-        bijector = params_bijection.get(param.
+        bijector = params_bijection.get(param.tag, npt.IdentityTransform())
         if inverse:
             transformed_value = bijector.inv(param.value)
         else:

@@ -60,10 +57,11 @@ def transform(

     gp_params, *other_params = params.split(Parameter, ...)

+    # Transform each parameter in the state
     transformed_gp_params: nnx.State = jtu.tree_map(
-        lambda x: _inner(x),
+        lambda x: _inner(x) if isinstance(x, Parameter) else x,
         gp_params,
-        is_leaf=lambda x: isinstance(x,
+        is_leaf=lambda x: isinstance(x, Parameter),
     )
     return nnx.State.merge(transformed_gp_params, *other_params)

@@ -79,7 +77,7 @@ class Parameter(nnx.Variable[T]):
         _check_is_arraylike(value)

         super().__init__(value=jnp.asarray(value), **kwargs)
-        self.
+        self.tag = tag


 class NonNegativeReal(Parameter[T]):

@@ -124,16 +122,6 @@ class SigmoidBounded(Parameter[T]):
     )


-class Static(nnx.Variable[T]):
-    """Static parameter that is not trainable."""
-
-    def __init__(self, value: T, tag: ParameterTag = "static", **kwargs):
-        _check_is_arraylike(value)
-
-        super().__init__(value=jnp.asarray(value), tag=tag, **kwargs)
-        self._tag = tag
-
-
 class LowerTriangular(Parameter[T]):
     """Parameter that is a lower triangular matrix."""

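The new tag attribute is what transform() dispatches on. A condensed, script-form version of the restored doctest (the printed value is taken from the doctest itself):

import jax.numpy as jnp
import numpyro.distributions.transforms as npt
from flax import nnx

from gpjax.parameters import PositiveReal, transform

params = nnx.State({"a": PositiveReal(jnp.array([1.0]))})
bijection = {"positive": npt.SoftplusTransform()}

transformed = transform(params, bijection)
print(transformed["a"].value)  # [1.3132617], i.e. softplus(1.0)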