gpjax 0.11.1__py3-none-any.whl → 0.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,237 @@
1
+ """Linear algebra operations for GPJax LinearOperators."""
2
+
3
+ from typing import Union
4
+
5
+ from jax import Array
6
+ import jax.numpy as jnp
7
+ import jax.scipy as jsp
8
+ from jaxtyping import Float
9
+
10
+ from gpjax.linalg.operators import (
11
+ BlockDiag,
12
+ Dense,
13
+ Diagonal,
14
+ Identity,
15
+ Kronecker,
16
+ LinearOperator,
17
+ Triangular,
18
+ )
19
+ from gpjax.typing import ScalarFloat
20
+
21
+
22
def lower_cholesky(A: LinearOperator) -> LinearOperator:
    """Return a lower-triangular Cholesky factor of a positive semi-definite operator.

    Matching is on the *exact* type of ``A`` (subclasses are not matched) so that
    structure can be exploited; unmatched operators are densified and factorised
    with ``jnp.linalg.cholesky``.

    Args:
        A: A positive semi-definite LinearOperator.

    Returns:
        A LinearOperator ``L`` intended to satisfy ``A = L @ L.T``:

        * ``Identity``  -> ``A`` itself.
        * ``Diagonal``  -> ``Diagonal`` of the element-wise square roots.
        * ``Kronecker`` -> ``Kronecker`` of the factors of each operand.
        * ``BlockDiag`` -> ``BlockDiag`` of the factors of each block, with the
          same multiplicities.
        * lower ``Triangular`` -> ``A`` itself (see note in the body).
        * anything else -> ``Triangular(jnp.linalg.cholesky(dense), lower=True)``.
    """
    kind = type(A)

    if kind is Identity:
        return A

    if kind is Diagonal:
        return Diagonal(jnp.sqrt(A.diagonal))

    if kind is Kronecker:
        # chol(A kron B) = chol(A) kron chol(B): factor each operand separately.
        return Kronecker([lower_cholesky(factor) for factor in A.operators])

    if kind is BlockDiag:
        # Each block is factorised independently; repetition counts carry over.
        return BlockDiag(
            [lower_cholesky(block) for block in A.operators],
            multiplicities=A.multiplicities,
        )

    if kind is Triangular and A.lower:
        # NOTE(review): a lower Triangular input is returned unchanged, i.e. it is
        # treated as an already-computed Cholesky factor rather than factorised
        # itself -- confirm this is the intended contract.
        return A

    # Dense exposes its raw array directly; every other operator (including an
    # upper Triangular) is densified first.
    matrix = A.array if kind is Dense else A.to_dense()
    return Triangular(jnp.linalg.cholesky(matrix), lower=True)
71
+
72
+
73
def solve(
    A: LinearOperator,
    b: Union[Float[Array, " N"], Float[Array, " N M"]],
) -> Union[Float[Array, " N"], Float[Array, " N M"]]:
    """Solve the linear system ``A @ x = b`` for ``x``.

    A vector right-hand side is temporarily lifted to a single-column matrix and
    the solution is squeezed back to a vector. Dispatch uses ``isinstance``, so
    subclasses of the structured operators take the structured path.

    Args:
        A: The LinearOperator on the left-hand side.
        b: Right-hand side of shape ``(N,)`` or ``(N, M)``.

    Returns:
        The solution ``x``, with the same rank as ``b``.
    """
    is_vector = b.ndim == 1
    rhs = b[:, None] if is_vector else b

    if isinstance(A, Identity):
        x = rhs
    elif isinstance(A, Diagonal):
        # Divide every column of the right-hand side by the diagonal.
        x = rhs / A.diagonal[:, None]
    elif isinstance(A, Triangular):
        # solve_triangular only reads the triangle selected by ``lower``.
        x = jsp.linalg.solve_triangular(A.array, rhs, lower=A.lower)
    else:
        # Dense uses its stored array; anything else is materialised first.
        dense = A.array if isinstance(A, Dense) else A.to_dense()
        x = jnp.linalg.solve(dense, rhs)

    return x.squeeze(-1) if is_vector else x
121
+
122
+
123
def logdet(A: LinearOperator) -> ScalarFloat:
    """Compute the log-determinant of a linear operator.

    Dispatch is on the exact type of ``A``. Structured cases use closed forms:

    * ``Identity``: ``0``.
    * ``Diagonal`` / ``Triangular``: sum of the logs of the diagonal entries.
    * ``Kronecker`` of factors ``A_1, ..., A_k`` with row counts ``n_1, ..., n_k``:
      ``sum_i (prod_{j != i} n_j) * logdet(A_i)``.
    * ``BlockDiag``: ``sum_i m_i * logdet(B_i)``, where ``m_i`` is the multiplicity
      of block ``B_i``.

    ``Dense`` and all other operators go through ``jnp.linalg.slogdet``. Only its
    ``log|det|`` component is returned, so the sign is discarded. The closed forms
    above apply ``log`` to diagonal entries directly, so non-positive entries give
    ``-inf``/``nan`` there.

    Args:
        A: A LinearOperator.

    Returns:
        The log-determinant of ``A``.
    """
    kind = type(A)

    if kind is Identity:
        return jnp.array(0.0)

    if kind is Diagonal:
        return jnp.sum(jnp.log(A.diagonal))

    if kind is Triangular:
        return jnp.sum(jnp.log(jnp.diag(A.array)))

    if kind is Kronecker:
        factors = A.operators
        total = 0.0
        for idx, factor in enumerate(factors):
            factor_logdet = logdet(factor)
            # Weight = product of the sizes of all *other* factors.
            weight = 1
            for other_idx, other in enumerate(factors):
                if other_idx != idx:
                    weight *= other.shape[0]
            total += weight * factor_logdet
        return total

    if kind is BlockDiag:
        total = 0.0
        for block, count in zip(A.operators, A.multiplicities, strict=False):
            total += count * logdet(block)
        return total

    matrix = A.array if kind is Dense else A.to_dense()
    _sign, log_abs_det = jnp.linalg.slogdet(matrix)
    return log_abs_det
182
+
183
+
184
def diag(A: LinearOperator) -> Float[Array, " N"]:
    """Extract the main diagonal of a linear operator.

    Dispatches on the exact type of ``A`` (subclasses fall through to the dense
    default) and exploits structure where that is valid:

    * ``Identity``: ones of ``A.dtype``.
    * ``Diagonal``: the stored diagonal.
    * ``Triangular`` / ``Dense``: ``jnp.diag`` of the stored array.
    * ``Kronecker`` of square factors: Kronecker product of the factor diagonals.
    * ``BlockDiag`` of square blocks: concatenation of the block diagonals, each
      repeated according to its multiplicity.

    ``Kronecker``/``BlockDiag`` operators with any non-square component fall back
    to ``jnp.diag(A.to_dense())``. For non-square components the structured
    formulas do not describe the diagonal of the assembled matrix.

    Args:
        A: A LinearOperator.

    Returns:
        The diagonal elements of ``A`` as a 1D array. This is empty for a
        ``BlockDiag`` with no (repeated) blocks.
    """

    def _is_square(op):
        return op.shape[0] == op.shape[1]

    def _handle_default(A):
        return jnp.diag(A.to_dense())

    def _handle_identity(A):
        return jnp.ones(A.shape[0], dtype=A.dtype)

    def _handle_diagonal(A):
        return A.diagonal

    def _handle_stored_array(A):
        # Triangular and Dense keep their matrix in ``.array``. The main diagonal
        # lies in both triangles, so no masking is needed.
        return jnp.diag(A.array)

    def _handle_kronecker(A):
        # diag(A kron B) == kron(diag(A), diag(B)) only holds for square factors.
        if not all(_is_square(op) for op in A.operators):
            return _handle_default(A)
        result = diag(A.operators[0])
        for op in A.operators[1:]:
            result = jnp.kron(result, diag(op))
        return result

    def _handle_blockdiag(A):
        # Concatenating block diagonals is only valid when every block is square.
        if not all(_is_square(op) for op in A.operators):
            return _handle_default(A)
        diags = []
        for op, mult in zip(A.operators, A.multiplicities, strict=False):
            diags.extend([diag(op)] * mult)
        if not diags:
            # No blocks: jnp.concatenate raises on an empty sequence.
            return jnp.zeros((0,), dtype=A.dtype)
        return jnp.concatenate(diags)

    dispatch_table = {
        Identity: _handle_identity,
        Diagonal: _handle_diagonal,
        Triangular: _handle_stored_array,
        Kronecker: _handle_kronecker,
        BlockDiag: _handle_blockdiag,
        Dense: _handle_stored_array,
    }

    handler = dispatch_table.get(type(A), _handle_default)
    return handler(A)
@@ -0,0 +1,411 @@
1
+ """Linear operator abstractions for GPJax."""
2
+
3
+ from abc import (
4
+ ABC,
5
+ abstractmethod,
6
+ )
7
+ from typing import (
8
+ Any,
9
+ List,
10
+ Tuple,
11
+ Union,
12
+ )
13
+
14
+ from jax import Array
15
+ import jax.numpy as jnp
16
+ import jax.tree_util as jtu
17
+ from jaxtyping import Float
18
+
19
+
20
class LinearOperator(ABC):
    """Abstract base class for linear operators.

    Subclasses provide ``shape``, ``dtype`` and ``to_dense``. The arithmetic
    dunders defined here are generic fall-backs that materialise the dense
    matrix:

    * Any operand exposing ``to_dense`` (i.e. another operator) is densified as
      well; arrays and scalars are used as-is.
    * ``+``, ``-`` and ``*`` always return a ``Dense`` operator.
    * ``@`` returns a ``Dense`` operator when the other operand is an operator,
      and a raw array otherwise.
    """

    def __init__(self):
        super().__init__()

    @property
    @abstractmethod
    def shape(self) -> Tuple[int, int]:
        """Return the shape of the operator."""

    @property
    @abstractmethod
    def dtype(self) -> jnp.dtype:
        """Return the data type of the operator."""

    @abstractmethod
    def to_dense(self) -> Float[Array, "M N"]:
        """Convert the operator to a dense JAX array."""

    @property
    def T(self) -> "LinearOperator":
        """Return the transpose of the operator.

        The default densifies and transposes; structured subclasses override it.
        """
        return Dense(self.to_dense().T)

    @staticmethod
    def _densify(other):
        """Return ``other.to_dense()`` for operator-like operands, else ``other``."""
        return other.to_dense() if hasattr(other, "to_dense") else other

    def __matmul__(self, other):
        """Matrix multiplication ``self @ other``."""
        product = self.to_dense() @ self._densify(other)
        # operator @ operator stays an operator; operator @ array is an array.
        return Dense(product) if hasattr(other, "to_dense") else product

    def __rmatmul__(self, other):
        """Right matrix multiplication ``other @ self``."""
        product = self._densify(other) @ self.to_dense()
        return Dense(product) if hasattr(other, "to_dense") else product

    def __add__(self, other):
        """Addition ``self + other`` with an array, scalar or operator."""
        return Dense(self.to_dense() + self._densify(other))

    def __radd__(self, other):
        """Right addition ``other + self``."""
        return Dense(self._densify(other) + self.to_dense())

    def __sub__(self, other):
        """Subtraction ``self - other`` with an array, scalar or operator."""
        return Dense(self.to_dense() - self._densify(other))

    def __rsub__(self, other):
        """Right subtraction ``other - self``."""
        return Dense(self._densify(other) - self.to_dense())

    def __mul__(self, other):
        """Scalar or element-wise multiplication ``self * other``.

        Operator operands are densified first, so ``op1 * op2`` is a single
        ``Dense`` Hadamard product. Without this, JAX arrays do not accept the
        operator operand and defer to ``other.__rmul__``, producing a ``Dense``
        that wraps another ``Dense``.
        """
        return Dense(self.to_dense() * self._densify(other))

    def __rmul__(self, other):
        """Right scalar or element-wise multiplication ``other * self``."""
        return Dense(self._densify(other) * self.to_dense())
115
+
116
+
117
class Dense(LinearOperator):
    """Linear operator backed by an explicit JAX array."""

    def __init__(self, array: Float[Array, "M N"]):
        super().__init__()
        # The wrapped matrix; ``to_dense`` hands it back without copying.
        self.array = array

    @property
    def shape(self) -> Tuple[int, int]:
        """Shape of the wrapped array."""
        matrix = self.array
        return matrix.shape

    @property
    def dtype(self) -> jnp.dtype:
        """Dtype of the wrapped array."""
        matrix = self.array
        return matrix.dtype

    def to_dense(self) -> Float[Array, "M N"]:
        """Return the wrapped array itself."""
        matrix = self.array
        return matrix

    @property
    def T(self) -> "Dense":
        """Transpose, as a new ``Dense`` around the transposed array."""
        flipped = self.array.T
        return Dense(flipped)
138
+
139
+
140
class Diagonal(LinearOperator):
    """Square diagonal linear operator stored as its 1D diagonal."""

    def __init__(self, diagonal: Float[Array, " N"]):
        super().__init__()
        # Only the diagonal entries are stored; the full matrix is built on demand.
        self.diagonal = diagonal

    @property
    def shape(self) -> Tuple[int, int]:
        """``(n, n)``, where ``n`` is the length of the stored diagonal."""
        size = self.diagonal.shape[0]
        return size, size

    @property
    def dtype(self) -> jnp.dtype:
        """Dtype of the stored diagonal."""
        entries = self.diagonal
        return entries.dtype

    def to_dense(self) -> Float[Array, "N N"]:
        """Materialise the ``n x n`` matrix with ``diagonal`` on its main diagonal."""
        entries = self.diagonal
        return jnp.diag(entries)

    @property
    def T(self) -> "Diagonal":
        """Diagonal matrices are symmetric: a new operator over the same diagonal."""
        entries = self.diagonal
        return Diagonal(entries)
162
+
163
+
164
class Identity(LinearOperator):
    """Identity linear operator of a fixed size and dtype.

    No array data is stored; ``to_dense`` materialises ``jnp.eye``.
    """

    def __init__(self, shape: Union[int, Tuple[int, int]], dtype=jnp.float64):
        """Create an identity operator.

        Args:
            shape: Either the size ``n`` as any integer-like value (Python,
                NumPy or concrete JAX integer), or a square 2D shape ``(n, n)``
                given as any sequence.
            dtype: Dtype used when materialising the dense matrix.

        Raises:
            ValueError: If ``shape`` is a sequence that is not a square 2D shape.
        """
        super().__init__()
        if hasattr(shape, "__index__"):
            # Integer-like size (``int`` and NumPy/JAX integer scalars).
            n = shape.__index__()
            self._shape = (n, n)
        else:
            # Always store a tuple: ``shape`` is annotated as a tuple, and the
            # shape is used as pytree aux data, which must be hashable.
            shape = tuple(shape)
            if len(shape) != 2 or shape[0] != shape[1]:
                raise ValueError(f"Identity matrix must be square, got shape {shape}")
            self._shape = shape
        self._dtype = dtype

    @property
    def shape(self) -> Tuple[int, int]:
        """The ``(n, n)`` shape given at construction."""
        return self._shape

    @property
    def dtype(self) -> Any:
        """The dtype given at construction."""
        return self._dtype

    def to_dense(self) -> Float[Array, "N N"]:
        """Materialise the ``n x n`` identity matrix."""
        n = self._shape[0]
        return jnp.eye(n, dtype=self._dtype)

    @property
    def T(self) -> "Identity":
        """The identity is symmetric: a new identity of the same shape and dtype."""
        return Identity(self._shape, dtype=self._dtype)
192
+
193
+
194
class Triangular(LinearOperator):
    """Lower- or upper-triangular linear operator.

    ``array`` may contain arbitrary values outside the selected triangle;
    ``to_dense`` masks them out.
    """

    def __init__(self, array: Float[Array, "N N"], lower: bool = True):
        super().__init__()
        self.array = array
        # True: the lower triangle (with the diagonal) is meaningful; False: upper.
        self.lower = lower

    @property
    def shape(self) -> Tuple[int, int]:
        """Shape of the stored array."""
        matrix = self.array
        return matrix.shape

    @property
    def dtype(self) -> Any:
        """Dtype of the stored array."""
        matrix = self.array
        return matrix.dtype

    def to_dense(self) -> Float[Array, "N N"]:
        """Return the stored array with the unused triangle zeroed."""
        keep_triangle = jnp.tril if self.lower else jnp.triu
        return keep_triangle(self.array)

    @property
    def T(self) -> "Triangular":
        """Transpose; transposing swaps which triangle is populated."""
        flipped = self.array.T
        return Triangular(flipped, lower=not self.lower)
219
+
220
+
221
class BlockDiag(LinearOperator):
    """Block-diagonal linear operator with optionally repeated blocks.

    ``BlockDiag([A, B], multiplicities=[2, 1])`` represents the block-diagonal
    matrix ``diag(A, A, B)``.

    Shape and dtype are derived from the blocks on access rather than cached in
    ``__init__``, so construction does no work on the blocks. This matters
    because JAX may rebuild this pytree node with placeholder leaves (e.g. while
    ``jax.vmap`` matches ``in_axes``), and then the blocks have no usable
    ``shape``/``dtype``.
    """

    def __init__(
        self,
        operators: List[LinearOperator],
        multiplicities: Union[List[int], None] = None,
    ):
        """Create a block-diagonal operator.

        Args:
            operators: The distinct diagonal blocks, in order.
            multiplicities: How many consecutive times each block is repeated.
                Defaults to ``1`` for every block.

        Raises:
            ValueError: If ``multiplicities`` is given with a different length
                than ``operators``.
        """
        super().__init__()
        self.operators = list(operators)

        if multiplicities is None:
            self.multiplicities = [1] * len(self.operators)
        else:
            if len(multiplicities) != len(operators):
                raise ValueError(
                    f"Length of multiplicities ({len(multiplicities)}) must match operators ({len(operators)})"
                )
            # Copy so this operator (and transposes built from it) do not share
            # one mutable list with the caller.
            self.multiplicities = list(multiplicities)

    @property
    def shape(self) -> Tuple[int, int]:
        """Sum of the block shapes, each weighted by its multiplicity."""
        rows = cols = 0
        for op, mult in zip(self.operators, self.multiplicities, strict=False):
            rows += op.shape[0] * mult
            cols += op.shape[1] * mult
        return (rows, cols)

    @property
    def dtype(self) -> Any:
        """Dtype of the first block (blocks are assumed to share a dtype).

        Returns ``jnp.float64`` when there are no blocks.
        """
        return self.operators[0].dtype if self.operators else jnp.float64

    def to_dense(self) -> Float[Array, "M N"]:
        """Materialise the full block-diagonal matrix."""
        dtype = self.dtype

        # Expand repeated blocks, densifying each distinct block only once.
        blocks = []
        for op, mult in zip(self.operators, self.multiplicities, strict=False):
            blocks.extend([op.to_dense()] * mult)

        if not blocks:
            return jnp.zeros(self.shape, dtype=dtype)

        # Block i sits on the diagonal; off-diagonal cells are zero blocks whose
        # height matches block i and whose width matches block j.
        grid = [
            [
                block_i
                if i == j
                else jnp.zeros((block_i.shape[0], block_j.shape[1]), dtype=dtype)
                for j, block_j in enumerate(blocks)
            ]
            for i, block_i in enumerate(blocks)
        ]
        return jnp.block(grid)

    @property
    def T(self) -> "BlockDiag":
        """Transpose: transpose each block, keeping the multiplicities."""
        transposed_ops = [op.T for op in self.operators]
        return BlockDiag(transposed_ops, multiplicities=self.multiplicities)
302
+
303
+
304
class Kronecker(LinearOperator):
    """Kronecker product of two or more linear operators.

    ``Kronecker([A, B, C])`` represents ``A kron B kron C``, with
    ``operators[0]`` outermost.

    Shape and dtype are derived from the factors on access rather than cached
    in ``__init__``, so construction does no work on the factors. This matters
    because JAX may rebuild this pytree node with placeholder leaves (e.g. while
    ``jax.vmap`` matches ``in_axes``), and then the factors have no usable
    ``shape``/``dtype``.
    """

    def __init__(self, operators: List[LinearOperator]):
        """Create a Kronecker product operator.

        Args:
            operators: The factors, outermost first.

        Raises:
            ValueError: If fewer than two factors are given.
        """
        super().__init__()
        # Only the number of factors is validated; it survives pytree unflattening.
        if len(operators) < 2:
            raise ValueError("Kronecker product requires at least 2 operators")
        self.operators = list(operators)

    @property
    def shape(self) -> Tuple[int, int]:
        """Products of the factors' row counts and of their column counts."""
        rows = cols = 1
        for op in self.operators:
            rows *= op.shape[0]
            cols *= op.shape[1]
        return (rows, cols)

    @property
    def dtype(self) -> Any:
        """Dtype of the first factor (factors are assumed to share a dtype)."""
        return self.operators[0].dtype

    def to_dense(self) -> Float[Array, "M N"]:
        """Materialise the full Kronecker product, folding from the left."""
        result = self.operators[0].to_dense()
        for op in self.operators[1:]:
            result = jnp.kron(result, op.to_dense())
        return result

    @property
    def T(self) -> "Kronecker":
        """Transpose via (A kron B)^T = A^T kron B^T."""
        transposed_ops = [op.T for op in self.operators]
        return Kronecker(transposed_ops)
343
+
344
+
345
def _dense_tree_flatten(dense):
    """Flatten a ``Dense`` operator: its array is the single leaf, with no aux data."""
    leaves = (dense.array,)
    return leaves, None


def _dense_tree_unflatten(aux_data, children):
    """Rebuild a ``Dense`` operator around its single leaf."""
    array = children[0]
    return Dense(array)


# Let JAX transformations treat Dense as a pytree node.
jtu.register_pytree_node(Dense, _dense_tree_flatten, _dense_tree_unflatten)
354
+
355
+
356
def _diagonal_tree_flatten(diagonal):
    """Flatten a ``Diagonal`` operator: its diagonal vector is the single leaf."""
    leaves = (diagonal.diagonal,)
    return leaves, None


def _diagonal_tree_unflatten(aux_data, children):
    """Rebuild a ``Diagonal`` operator from its single leaf."""
    entries = children[0]
    return Diagonal(entries)


# Let JAX transformations treat Diagonal as a pytree node.
jtu.register_pytree_node(Diagonal, _diagonal_tree_flatten, _diagonal_tree_unflatten)
365
+
366
+
367
def _identity_tree_flatten(identity):
    """Flatten an ``Identity`` operator.

    It holds no array data, so it has no leaves; shape and dtype are static aux
    data. JAX hashes and compares aux data, so the shape is normalised to a
    tuple in case a non-tuple sequence (e.g. a list) was passed to the
    constructor.
    """
    return (), (tuple(identity._shape), identity._dtype)


def _identity_tree_unflatten(aux_data, children):
    """Rebuild an ``Identity`` operator from its static shape and dtype."""
    shape, dtype = aux_data
    return Identity(shape, dtype)


# Let JAX transformations treat Identity as a (leafless) pytree node.
jtu.register_pytree_node(Identity, _identity_tree_flatten, _identity_tree_unflatten)
377
+
378
+
379
def _triangular_tree_flatten(triangular):
    """Flatten a ``Triangular`` operator: the array is the leaf, ``lower`` is aux data."""
    leaves = (triangular.array,)
    return leaves, triangular.lower


def _triangular_tree_unflatten(aux_data, children):
    """Rebuild a ``Triangular`` operator; ``aux_data`` is the ``lower`` flag."""
    array = children[0]
    return Triangular(array, lower=aux_data)


# Let JAX transformations treat Triangular as a pytree node.
jtu.register_pytree_node(
    Triangular, _triangular_tree_flatten, _triangular_tree_unflatten
)
390
+
391
+
392
def _blockdiag_tree_flatten(blockdiag):
    """Flatten a ``BlockDiag`` operator.

    The blocks are the pytree children; the multiplicities are static aux data.
    JAX requires aux data to be hashable, because it takes part in treedef
    hashing and equality (e.g. for ``jit`` caching). The multiplicities are
    therefore frozen into a tuple rather than passed as a list.
    """
    return tuple(blockdiag.operators), tuple(blockdiag.multiplicities)


def _blockdiag_tree_unflatten(aux_data, children):
    """Rebuild a ``BlockDiag`` operator, passing the multiplicities back as a list."""
    return BlockDiag(list(children), list(aux_data))


# Let JAX transformations treat BlockDiag as a pytree node.
jtu.register_pytree_node(BlockDiag, _blockdiag_tree_flatten, _blockdiag_tree_unflatten)
401
+
402
+
403
def _kronecker_tree_flatten(kronecker):
    """Flatten a ``Kronecker`` operator: the factors are the children, with no aux data."""
    factors = tuple(kronecker.operators)
    return factors, None


def _kronecker_tree_unflatten(aux_data, children):
    """Rebuild a ``Kronecker`` operator from its factors."""
    factors = list(children)
    return Kronecker(factors)


# Let JAX transformations treat Kronecker as a pytree node.
jtu.register_pytree_node(Kronecker, _kronecker_tree_flatten, _kronecker_tree_unflatten)
gpjax/linalg/utils.py ADDED
@@ -0,0 +1,33 @@
1
+ """Utility functions for the linear algebra module."""
2
+
3
+ from gpjax.linalg.operators import LinearOperator
4
+
5
+
6
class PSDAnnotation:
    """Callable marker type for positive semi-definite (PSD) operators.

    The module-level instance ``PSD`` is the tag stored by ``psd``. Calling it
    is shorthand for ``psd(A)``.
    """

    def __call__(self, A: LinearOperator) -> LinearOperator:
        """Tag ``A`` as PSD via ``psd`` and return it."""
        tagged = psd(A)
        return tagged


# Module-level marker instance, mirroring ``cola.PSD``.
PSD = PSDAnnotation()
16
+
17
+
18
def psd(A: LinearOperator) -> LinearOperator:
    """Tag a linear operator as positive semi-definite, in place.

    The operator gets an ``annotations`` set (created on first use), and the
    module-level ``PSD`` marker is added to it. The PSD property itself is not
    checked.

    Args:
        A: The LinearOperator to tag. It is mutated.

    Returns:
        ``A`` itself, not a copy.
    """
    try:
        tags = A.annotations
    except AttributeError:
        tags = set()
        A.annotations = tags
    tags.add(PSD)
    return A
33
+ return A