gpjax 0.11.1__py3-none-any.whl → 0.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gpjax/__init__.py +1 -1
- gpjax/citation.py +7 -2
- gpjax/distributions.py +16 -56
- gpjax/fit.py +3 -3
- gpjax/gps.py +34 -48
- gpjax/kernels/base.py +1 -1
- gpjax/kernels/computations/base.py +7 -7
- gpjax/kernels/computations/basis_functions.py +6 -5
- gpjax/kernels/computations/constant_diagonal.py +10 -12
- gpjax/kernels/computations/diagonal.py +6 -6
- gpjax/linalg/__init__.py +37 -0
- gpjax/linalg/operations.py +237 -0
- gpjax/linalg/operators.py +411 -0
- gpjax/linalg/utils.py +33 -0
- gpjax/objectives.py +21 -21
- gpjax/parameters.py +11 -13
- gpjax/variational_families.py +43 -37
- {gpjax-0.11.1.dist-info → gpjax-0.12.0.dist-info}/METADATA +49 -9
- {gpjax-0.11.1.dist-info → gpjax-0.12.0.dist-info}/RECORD +21 -18
- gpjax/lower_cholesky.py +0 -69
- {gpjax-0.11.1.dist-info → gpjax-0.12.0.dist-info}/WHEEL +0 -0
- {gpjax-0.11.1.dist-info → gpjax-0.12.0.dist-info}/licenses/LICENSE.txt +0 -0
|
@@ -0,0 +1,237 @@
|
|
|
1
|
+
"""Linear algebra operations for GPJax LinearOperators."""
|
|
2
|
+
|
|
3
|
+
from typing import Union
|
|
4
|
+
|
|
5
|
+
from jax import Array
|
|
6
|
+
import jax.numpy as jnp
|
|
7
|
+
import jax.scipy as jsp
|
|
8
|
+
from jaxtyping import Float
|
|
9
|
+
|
|
10
|
+
from gpjax.linalg.operators import (
|
|
11
|
+
BlockDiag,
|
|
12
|
+
Dense,
|
|
13
|
+
Diagonal,
|
|
14
|
+
Identity,
|
|
15
|
+
Kronecker,
|
|
16
|
+
LinearOperator,
|
|
17
|
+
Triangular,
|
|
18
|
+
)
|
|
19
|
+
from gpjax.typing import ScalarFloat
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def lower_cholesky(A: LinearOperator) -> LinearOperator:
    """Compute the lower Cholesky factor of a positive semi-definite operator.

    Dispatch is on the exact type of ``A`` (``type(A)``), so subclasses of the
    operators below take the dense fallback. Structured cases:

    * ``Identity``  -> returned unchanged (``I = I @ I.T``).
    * ``Diagonal``  -> ``Diagonal`` of the element-wise square root.
    * ``Kronecker`` -> ``Kronecker`` of the factors' Cholesky factors.
    * ``BlockDiag`` -> ``BlockDiag`` of the blocks' Cholesky factors, keeping
      the same multiplicities.
    * ``Dense``     -> ``Triangular`` built from ``jnp.linalg.cholesky`` of the
      wrapped array.
    * anything else, including ``Triangular`` -> ``Triangular`` built from
      ``jnp.linalg.cholesky(A.to_dense())``.

    ``Triangular`` inputs deliberately use the dense fallback. A triangular
    matrix is only symmetric (hence only PSD) when it is diagonal, and even
    then it is not its own Cholesky factor, because ``A @ A.T`` squares the
    diagonal. Returning ``A`` unchanged would therefore violate
    ``A = L @ L.T``.

    Args:
        A: A positive semi-definite LinearOperator.

    Returns:
        The lower triangular Cholesky factor L such that A = L @ L.T.
    """

    def _handle_identity(A):
        return A

    def _handle_diagonal(A):
        return Diagonal(jnp.sqrt(A.diagonal))

    def _handle_kronecker(A):
        # chol(A1 kron A2 kron ...) = chol(A1) kron chol(A2) kron ...
        cholesky_ops = [lower_cholesky(op) for op in A.operators]
        return Kronecker(cholesky_ops)

    def _handle_blockdiag(A):
        # Each diagonal block factorises independently.
        cholesky_ops = [lower_cholesky(op) for op in A.operators]
        return BlockDiag(cholesky_ops, multiplicities=A.multiplicities)

    def _handle_dense(A):
        return Triangular(jnp.linalg.cholesky(A.array), lower=True)

    def _handle_default(A):
        return Triangular(jnp.linalg.cholesky(A.to_dense()), lower=True)

    dispatch_table = {
        Identity: _handle_identity,
        Diagonal: _handle_diagonal,
        Kronecker: _handle_kronecker,
        BlockDiag: _handle_blockdiag,
        Dense: _handle_dense,
    }

    handler = dispatch_table.get(type(A), _handle_default)
    return handler(A)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def solve(
    A: LinearOperator,
    b: Union[Float[Array, " N"], Float[Array, " N M"]],
) -> Union[Float[Array, " N"], Float[Array, " N M"]]:
    """Solve the linear system ``A @ x = b`` for ``x``.

    Unlike the type-keyed dispatch tables used elsewhere in this module, this
    function matches with ``isinstance`` in a fixed order: Identity, Diagonal,
    Triangular, Dense. Subclasses of those operators therefore also take the
    structured path. Any other operator is materialised with ``to_dense`` and
    solved with ``jnp.linalg.solve``.

    Args:
        A: A LinearOperator representing the matrix A.
        b: The right-hand side, either a vector ``(N,)`` or a matrix
            ``(N, M)`` holding one system per column.

    Returns:
        The solution x to the linear system, with the same rank as ``b``.
    """
    is_vector = b.ndim == 1
    # Work with a 2-D right-hand side throughout; the extra axis is removed
    # again before returning.
    rhs = b[:, None] if is_vector else b

    strategies = (
        # Identity: x = b.
        (Identity, lambda op, y: y),
        # Diagonal: element-wise division, row by row.
        (Diagonal, lambda op, y: y / op.diagonal[:, None]),
        # Triangular: substitution that reads only the relevant triangle.
        (
            Triangular,
            lambda op, y: jsp.linalg.solve_triangular(op.array, y, lower=op.lower),
        ),
        # Dense: general LU-based solve on the wrapped array.
        (Dense, lambda op, y: jnp.linalg.solve(op.array, y)),
    )

    for op_type, strategy in strategies:
        if isinstance(A, op_type):
            x = strategy(A, rhs)
            break
    else:
        # No structured match: materialise and solve densely.
        x = jnp.linalg.solve(A.to_dense(), rhs)

    return x.squeeze(-1) if is_vector else x
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def logdet(A: LinearOperator) -> ScalarFloat:
    """Compute the log of the absolute determinant of a linear operator.

    Every branch returns ``log|det(A)|``. The dense branches use
    ``jnp.linalg.slogdet`` and discard the sign. The structured branches
    therefore also take absolute values, so that wrapping the same matrix in
    ``Diagonal``/``Triangular`` or in ``Dense`` gives the same answer. For
    positive (semi-)definite inputs this is the ordinary log-determinant.

    Dispatch is on the exact type of ``A`` (``type(A)``); subclasses use the
    dense fallback.

    Args:
        A: A LinearOperator.

    Returns:
        ``log|det(A)|``.
    """

    def _handle_identity(A):
        return jnp.array(0.0)

    def _handle_diagonal(A):
        # det of a diagonal matrix is the product of its diagonal entries.
        return jnp.sum(jnp.log(jnp.abs(A.diagonal)))

    def _handle_triangular(A):
        # det of a triangular matrix is the product of its diagonal; only the
        # diagonal of the stored array is read.
        diag_elements = jnp.diag(A.array)
        return jnp.sum(jnp.log(jnp.abs(diag_elements)))

    def _handle_kronecker(A):
        # For square factors A_i of size n_i:
        #   log|det(A_1 kron ... kron A_k)| = sum_i (prod_{j != i} n_j) * log|det(A_i)|
        sizes = [op.shape[0] for op in A.operators]
        logdet_val = 0.0
        for i, op in enumerate(A.operators):
            power = 1
            for j, n in enumerate(sizes):
                if i != j:
                    power *= n
            logdet_val += power * logdet(op)
        return logdet_val

    def _handle_blockdiag(A):
        # Each repeated block contributes its log|det| once per repetition.
        logdet_val = 0.0
        for op, mult in zip(A.operators, A.multiplicities, strict=False):
            logdet_val += mult * logdet(op)
        return logdet_val

    def _handle_dense(A):
        _, logdet_val = jnp.linalg.slogdet(A.array)
        return logdet_val

    def _handle_default(A):
        _, logdet_val = jnp.linalg.slogdet(A.to_dense())
        return logdet_val

    dispatch_table = {
        Identity: _handle_identity,
        Diagonal: _handle_diagonal,
        Triangular: _handle_triangular,
        Kronecker: _handle_kronecker,
        BlockDiag: _handle_blockdiag,
        Dense: _handle_dense,
    }

    handler = dispatch_table.get(type(A), _handle_default)
    return handler(A)
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
def diag(A: LinearOperator) -> Float[Array, " N"]:
    """Extract the main diagonal of a linear operator.

    Dispatch is on the exact type of ``A`` (``type(A)``); subclasses use the
    dense fallback (``jnp.diag(A.to_dense())``).

    Args:
        A: A LinearOperator.

    Returns:
        The diagonal elements of A as a 1D array. An empty ``BlockDiag`` (no
        operators, or all multiplicities zero) yields an empty array.
    """

    def _handle_identity(A):
        n = A.shape[0]
        return jnp.ones(n, dtype=A.dtype)

    def _handle_diagonal(A):
        return A.diagonal

    def _handle_triangular(A):
        return jnp.diag(A.array)

    def _handle_kronecker(A):
        # diag(A1 kron A2) = diag(A1) kron diag(A2) for square factors.
        result = diag(A.operators[0])
        for op in A.operators[1:]:
            result = jnp.kron(result, diag(op))
        return result

    def _handle_blockdiag(A):
        diags = []
        for op, mult in zip(A.operators, A.multiplicities, strict=False):
            # Compute each block's diagonal once, then repeat it.
            diags.extend([diag(op)] * mult)
        if not diags:
            # jnp.concatenate rejects an empty list; the operator is 0 x 0 here.
            return jnp.zeros((0,), dtype=A.dtype)
        return jnp.concatenate(diags)

    def _handle_dense(A):
        return jnp.diag(A.array)

    def _handle_default(A):
        return jnp.diag(A.to_dense())

    dispatch_table = {
        Identity: _handle_identity,
        Diagonal: _handle_diagonal,
        Triangular: _handle_triangular,
        Kronecker: _handle_kronecker,
        BlockDiag: _handle_blockdiag,
        Dense: _handle_dense,
    }

    handler = dispatch_table.get(type(A), _handle_default)
    return handler(A)
|
|
@@ -0,0 +1,411 @@
|
|
|
1
|
+
"""Linear operator abstractions for GPJax."""
|
|
2
|
+
|
|
3
|
+
from abc import (
|
|
4
|
+
ABC,
|
|
5
|
+
abstractmethod,
|
|
6
|
+
)
|
|
7
|
+
from typing import (
|
|
8
|
+
Any,
|
|
9
|
+
List,
|
|
10
|
+
Tuple,
|
|
11
|
+
Union,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
from jax import Array
|
|
15
|
+
import jax.numpy as jnp
|
|
16
|
+
import jax.tree_util as jtu
|
|
17
|
+
from jaxtyping import Float
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def _operand_to_dense(operand):
    """Return ``operand.to_dense()`` if it is operator-like, else ``operand``.

    "Operator-like" is duck-typed as "has a ``to_dense`` attribute", matching
    the checks used by ``LinearOperator``'s arithmetic.
    """
    return operand.to_dense() if hasattr(operand, "to_dense") else operand


class LinearOperator(ABC):
    """Abstract base class for linear operators.

    Subclasses provide ``shape``, ``dtype`` and ``to_dense``. The default
    transpose and arithmetic defined here are generic. They materialise the
    operator, and any operator operand, with ``to_dense`` and wrap the result
    in ``Dense``. Structure is therefore lost unless a subclass overrides them.
    """

    def __init__(self):
        super().__init__()

    @property
    @abstractmethod
    def shape(self) -> Tuple[int, int]:
        """Return the shape of the operator."""

    @property
    @abstractmethod
    def dtype(self) -> jnp.dtype:
        """Return the data type of the operator."""

    @abstractmethod
    def to_dense(self) -> Float[Array, "M N"]:
        """Convert the operator to a dense JAX array."""

    @property
    def T(self) -> "LinearOperator":
        """Return the transpose of the operator."""
        # Default implementation: convert to dense and transpose
        return Dense(self.to_dense().T)

    def __matmul__(self, other):
        """Matrix multiplication ``self @ other``.

        Returns a ``Dense`` operator when ``other`` is operator-like, and a
        plain array when ``other`` is an array.
        """
        if hasattr(other, "to_dense"):
            return Dense(self.to_dense() @ other.to_dense())
        return self.to_dense() @ other

    def __rmatmul__(self, other):
        """Right matrix multiplication ``other @ self``.

        Returns a ``Dense`` operator when ``other`` is operator-like, and a
        plain array when ``other`` is an array.
        """
        if hasattr(other, "to_dense"):
            return Dense(other.to_dense() @ self.to_dense())
        return other @ self.to_dense()

    def __add__(self, other):
        """Addition ``self + other`` with an array or operator; returns ``Dense``."""
        return Dense(self.to_dense() + _operand_to_dense(other))

    def __radd__(self, other):
        """Right addition ``other + self``; returns ``Dense``."""
        return Dense(_operand_to_dense(other) + self.to_dense())

    def __sub__(self, other):
        """Subtraction ``self - other`` with an array or operator; returns ``Dense``."""
        return Dense(self.to_dense() - _operand_to_dense(other))

    def __rsub__(self, other):
        """Right subtraction ``other - self``; returns ``Dense``."""
        return Dense(_operand_to_dense(other) - self.to_dense())

    def __mul__(self, other):
        """Element-wise multiplication ``self * other``; returns ``Dense``.

        ``other`` may be a scalar, an array (broadcast as usual) or another
        operator. An operator operand is materialised first. Without that
        step, the array's ``__mul__`` would defer to ``other.__rmul__``, whose
        ``Dense`` result would then be wrapped a second time.
        """
        return Dense(self.to_dense() * _operand_to_dense(other))

    def __rmul__(self, other):
        """Element-wise multiplication ``other * self``; returns ``Dense``."""
        return Dense(_operand_to_dense(other) * self.to_dense())
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
class Dense(LinearOperator):
    """Linear operator backed by an explicit JAX array.

    Args:
        array: The matrix this operator represents; stored as given.
    """

    def __init__(self, array: Float[Array, "M N"]):
        super().__init__()
        self.array = array

    @property
    def shape(self) -> Tuple[int, int]:
        # Read directly from the wrapped array.
        backing = self.array
        return backing.shape

    @property
    def dtype(self) -> jnp.dtype:
        backing = self.array
        return backing.dtype

    def to_dense(self) -> Float[Array, "M N"]:
        # Already dense: hand back the stored array itself (no copy).
        return self.array

    @property
    def T(self) -> "Dense":
        transposed = self.array.T
        return Dense(transposed)
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
class Diagonal(LinearOperator):
    """Square operator whose only non-zero entries are on the main diagonal.

    Args:
        diagonal: The main-diagonal entries. Its leading dimension fixes the
            operator's size.
    """

    def __init__(self, diagonal: Float[Array, " N"]):
        super().__init__()
        self.diagonal = diagonal

    @property
    def shape(self) -> Tuple[int, int]:
        size = self.diagonal.shape[0]
        return size, size

    @property
    def dtype(self) -> jnp.dtype:
        entries = self.diagonal
        return entries.dtype

    def to_dense(self) -> Float[Array, "N N"]:
        # Place the stored entries on the diagonal of an otherwise-zero matrix.
        entries = self.diagonal
        return jnp.diag(entries)

    @property
    def T(self) -> "Diagonal":
        # A diagonal matrix is symmetric; return a fresh operator on the same data.
        entries = self.diagonal
        return Diagonal(entries)
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
class Identity(LinearOperator):
    """Identity linear operator.

    Args:
        shape: Either a single integer ``n``, giving an ``(n, n)`` operator,
            or a pair ``(n, n)``. Any ``numbers.Integral`` (for example
            ``numpy.int64``) is accepted as the single-integer form. The shape
            is always stored as a tuple. It becomes pytree aux_data (see
            ``_identity_tree_flatten``), which JAX requires to be hashable, so
            a list would not work there.
        dtype: The dtype reported by ``dtype`` and used by ``to_dense``.

    Raises:
        ValueError: If a pair is given whose two entries differ.
    """

    def __init__(self, shape: Union[int, Tuple[int, int]], dtype=jnp.float64):
        import numbers  # local import keeps the module's import block unchanged

        super().__init__()
        if isinstance(shape, numbers.Integral):
            n = int(shape)
            self._shape = (n, n)
        else:
            if shape[0] != shape[1]:
                raise ValueError(f"Identity matrix must be square, got shape {shape}")
            # Normalise list/array-like pairs to a hashable tuple.
            self._shape = tuple(shape)
        self._dtype = dtype

    @property
    def shape(self) -> Tuple[int, int]:
        return self._shape

    @property
    def dtype(self) -> Any:
        return self._dtype

    def to_dense(self) -> Float[Array, "N N"]:
        n = self._shape[0]
        return jnp.eye(n, dtype=self._dtype)

    @property
    def T(self) -> "Identity":
        return Identity(self._shape, dtype=self._dtype)
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
class Triangular(LinearOperator):
    """Triangular linear operator.

    Only one triangle of ``array`` is meaningful: the lower triangle when
    ``lower`` is True, otherwise the upper. ``to_dense`` zeroes the other
    triangle, while ``array`` itself is stored unmodified.

    Args:
        array: Square matrix holding the triangular entries.
        lower: Whether the operator is lower (True) or upper (False)
            triangular.
    """

    def __init__(self, array: Float[Array, "N N"], lower: bool = True):
        super().__init__()
        self.array = array
        self.lower = lower

    @property
    def shape(self) -> Tuple[int, int]:
        stored = self.array
        return stored.shape

    @property
    def dtype(self) -> Any:
        stored = self.array
        return stored.dtype

    def to_dense(self) -> Float[Array, "N N"]:
        # Mask away the triangle that is not part of the operator.
        keep_triangle = jnp.tril if self.lower else jnp.triu
        return keep_triangle(self.array)

    @property
    def T(self) -> "Triangular":
        # Transposing swaps which triangle holds the data.
        flipped = not self.lower
        return Triangular(self.array.T, lower=flipped)
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
class BlockDiag(LinearOperator):
    """Block diagonal linear operator.

    Block ``operators[i]`` appears ``multiplicities[i]`` consecutive times
    along the diagonal; every off-diagonal block is zero.

    Args:
        operators: The distinct diagonal blocks, in order.
        multiplicities: How many times each block is repeated. Defaults to one
            repetition per block. It is copied into a new list, so later
            mutation of the caller's list cannot desynchronise it from the
            cached shape.

    Raises:
        ValueError: If ``multiplicities`` and ``operators`` differ in length.
    """

    def __init__(
        self,
        operators: List[LinearOperator],
        multiplicities: Union[List[int], None] = None,
    ):
        super().__init__()
        self.operators = operators

        # Handle multiplicities - how many times each block is repeated
        if multiplicities is None:
            self.multiplicities = [1] * len(operators)
        else:
            if len(multiplicities) != len(operators):
                raise ValueError(
                    f"Length of multiplicities ({len(multiplicities)}) must match operators ({len(operators)})"
                )
            self.multiplicities = list(multiplicities)

        # Calculate total shape with multiplicities
        rows = sum(
            op.shape[0] * mult
            for op, mult in zip(operators, self.multiplicities, strict=False)
        )
        cols = sum(
            op.shape[1] * mult
            for op, mult in zip(operators, self.multiplicities, strict=False)
        )
        self._shape = (rows, cols)

        # Use dtype of first operator (assuming all same dtype)
        if operators:
            self._dtype = operators[0].dtype
        else:
            self._dtype = jnp.float64

    @property
    def shape(self) -> Tuple[int, int]:
        return self._shape

    @property
    def dtype(self) -> Any:
        return self._dtype

    def to_dense(self) -> Float[Array, "M N"]:
        import jax.scipy.linalg as jsl  # local import keeps the module's import block unchanged

        # Densify each distinct block once, then repeat it `mult` times.
        expanded_blocks = []
        for op, mult in zip(self.operators, self.multiplicities, strict=False):
            expanded_blocks.extend([op.to_dense()] * mult)

        if not expanded_blocks:
            return jnp.zeros(self._shape, dtype=self._dtype)

        # block_diag places (possibly rectangular) blocks along the diagonal
        # with work linear in the number of blocks. The previous approach built
        # an n_blocks x n_blocks grid of zero blocks for jnp.block.
        return jsl.block_diag(*expanded_blocks)

    @property
    def T(self) -> "BlockDiag":
        transposed_ops = [op.T for op in self.operators]
        return BlockDiag(transposed_ops, multiplicities=self.multiplicities)
|
|
302
|
+
|
|
303
|
+
|
|
304
|
+
class Kronecker(LinearOperator):
    """Kronecker product ``operators[0] kron operators[1] kron ...``.

    Args:
        operators: The factors, in order; at least two are required.

    Raises:
        ValueError: If fewer than two factors are given.
    """

    def __init__(self, operators: List[LinearOperator]):
        super().__init__()
        if len(operators) < 2:
            raise ValueError("Kronecker product requires at least 2 operators")
        self.operators = operators

        # Each dimension of a Kronecker product is the product of the factors'
        # corresponding dimensions.
        n_rows, n_cols = 1, 1
        for factor in operators:
            factor_shape = factor.shape
            n_rows *= factor_shape[0]
            n_cols *= factor_shape[1]
        self._shape = (n_rows, n_cols)

        # The first factor's dtype is reported for the whole product.
        self._dtype = operators[0].dtype

    @property
    def shape(self) -> Tuple[int, int]:
        return self._shape

    @property
    def dtype(self) -> Any:
        return self._dtype

    def to_dense(self) -> Float[Array, "M N"]:
        # Fold the dense factors left to right with jnp.kron.
        dense_factors = [factor.to_dense() for factor in self.operators]
        product = dense_factors[0]
        for dense_factor in dense_factors[1:]:
            product = jnp.kron(product, dense_factor)
        return product

    @property
    def T(self) -> "Kronecker":
        # (A kron B)^T = A^T kron B^T
        return Kronecker([factor.T for factor in self.operators])
|
|
343
|
+
|
|
344
|
+
|
|
345
|
+
def _dense_tree_flatten(dense):
    """Flatten a ``Dense``: the wrapped array is the only leaf, with no aux data."""
    leaves = (dense.array,)
    return leaves, None


def _dense_tree_unflatten(aux_data, children):
    """Rebuild a ``Dense`` from its single leaf; ``aux_data`` is unused."""
    array = children[0]
    return Dense(array)


jtu.register_pytree_node(Dense, _dense_tree_flatten, _dense_tree_unflatten)
|
|
354
|
+
|
|
355
|
+
|
|
356
|
+
def _diagonal_tree_flatten(diagonal):
    """Flatten a ``Diagonal``: its diagonal vector is the only leaf."""
    leaves = (diagonal.diagonal,)
    return leaves, None


def _diagonal_tree_unflatten(aux_data, children):
    """Rebuild a ``Diagonal`` from its single leaf; ``aux_data`` is unused."""
    entries = children[0]
    return Diagonal(entries)


jtu.register_pytree_node(Diagonal, _diagonal_tree_flatten, _diagonal_tree_unflatten)
|
|
365
|
+
|
|
366
|
+
|
|
367
|
+
def _identity_tree_flatten(identity):
    """Flatten an ``Identity``: it has no leaves; shape and dtype are static aux data."""
    static = (identity._shape, identity._dtype)
    return (), static


def _identity_tree_unflatten(aux_data, children):
    """Rebuild an ``Identity`` from its ``(shape, dtype)`` aux data."""
    return Identity(*aux_data)


jtu.register_pytree_node(Identity, _identity_tree_flatten, _identity_tree_unflatten)
|
|
377
|
+
|
|
378
|
+
|
|
379
|
+
def _triangular_tree_flatten(triangular):
    """Flatten a ``Triangular``: the array is the leaf, the ``lower`` flag is aux data."""
    leaves = (triangular.array,)
    is_lower = triangular.lower
    return leaves, is_lower


def _triangular_tree_unflatten(aux_data, children):
    """Rebuild a ``Triangular``; ``aux_data`` is the ``lower`` flag."""
    array = children[0]
    return Triangular(array, aux_data)


jtu.register_pytree_node(
    Triangular, _triangular_tree_flatten, _triangular_tree_unflatten
)
|
|
390
|
+
|
|
391
|
+
|
|
392
|
+
def _blockdiag_tree_flatten(blockdiag):
    """Flatten a ``BlockDiag``.

    The child operators are sub-pytrees, and the multiplicities are static aux
    data. JAX hashes and compares aux data (for example as part of a treedef
    used as a ``jit`` cache key), so it must be hashable. The multiplicities
    list is therefore emitted as a tuple.
    """
    return tuple(blockdiag.operators), tuple(blockdiag.multiplicities)


def _blockdiag_tree_unflatten(aux_data, children):
    """Rebuild a ``BlockDiag``, restoring multiplicities to a list as the constructor stores them."""
    return BlockDiag(list(children), list(aux_data))


jtu.register_pytree_node(BlockDiag, _blockdiag_tree_flatten, _blockdiag_tree_unflatten)
|
|
401
|
+
|
|
402
|
+
|
|
403
|
+
def _kronecker_tree_flatten(kronecker):
    """Flatten a ``Kronecker``: the factor operators are the children; no aux data."""
    factors = tuple(factor for factor in kronecker.operators)
    return factors, None


def _kronecker_tree_unflatten(aux_data, children):
    """Rebuild a ``Kronecker`` from its factors; ``aux_data`` is unused."""
    return Kronecker([*children])


jtu.register_pytree_node(Kronecker, _kronecker_tree_flatten, _kronecker_tree_unflatten)
|
gpjax/linalg/utils.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
"""Utility functions for the linear algebra module."""
|
|
2
|
+
|
|
3
|
+
from gpjax.linalg.operators import LinearOperator
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class PSDAnnotation:
    """Marker type for positive semi-definite (PSD) annotations.

    Calling an instance is shorthand for ``psd(A)``.
    """

    def __call__(self, A: LinearOperator) -> LinearOperator:
        """Tag ``A`` as PSD via ``psd`` and return it."""
        tagged = psd(A)
        return tagged


# Module-level marker instance, analogous to ``cola.PSD``. ``psd`` stores this
# object in an operator's ``annotations`` set.
PSD = PSDAnnotation()
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def psd(A: LinearOperator) -> LinearOperator:
    """Tag a linear operator as positive semi-definite, in place.

    The operator gains (or reuses) an ``annotations`` set, and the module-level
    ``PSD`` marker is added to it. No check is made that ``A`` really is PSD.

    NOTE(review): the pytree flatten functions in ``gpjax.linalg.operators``
    do not carry an ``annotations`` attribute. The tag is therefore presumably
    dropped when an operator is rebuilt by a JAX transformation; confirm
    against callers.

    Args:
        A: A LinearOperator that is assumed to be positive semi-definite.

    Returns:
        ``A`` itself (the same object), now tagged with ``PSD``.
    """
    if hasattr(A, "annotations"):
        tags = A.annotations
    else:
        # First annotation on this operator: attach a fresh set.
        tags = A.annotations = set()
    tags.add(PSD)
    return A
|