liesel-gam 0.0.4__py3-none-any.whl → 0.0.6a4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liesel_gam/__about__.py +1 -1
- liesel_gam/__init__.py +38 -1
- liesel_gam/builder/__init__.py +8 -0
- liesel_gam/builder/builder.py +2003 -0
- liesel_gam/builder/category_mapping.py +158 -0
- liesel_gam/builder/consolidate_bases.py +105 -0
- liesel_gam/builder/registry.py +561 -0
- liesel_gam/constraint.py +107 -0
- liesel_gam/dist.py +541 -1
- liesel_gam/kernel.py +18 -7
- liesel_gam/plots.py +946 -0
- liesel_gam/predictor.py +59 -20
- liesel_gam/var.py +1508 -126
- liesel_gam-0.0.6a4.dist-info/METADATA +559 -0
- liesel_gam-0.0.6a4.dist-info/RECORD +18 -0
- {liesel_gam-0.0.4.dist-info → liesel_gam-0.0.6a4.dist-info}/WHEEL +1 -1
- liesel_gam-0.0.4.dist-info/METADATA +0 -160
- liesel_gam-0.0.4.dist-info/RECORD +0 -11
- {liesel_gam-0.0.4.dist-info → liesel_gam-0.0.6a4.dist-info}/licenses/LICENSE +0 -0
liesel_gam/dist.py
CHANGED
|
@@ -1,4 +1,7 @@
|
|
|
1
|
-
from
|
|
1
|
+
from collections.abc import Callable, Sequence
|
|
2
|
+
from functools import cached_property, reduce
|
|
3
|
+
from math import prod
|
|
4
|
+
from typing import Self
|
|
2
5
|
|
|
3
6
|
import jax
|
|
4
7
|
import jax.numpy as jnp
|
|
@@ -98,3 +101,540 @@ class MultivariateNormalSingular(tfd.Distribution):
|
|
|
98
101
|
r = tuple(range(event_shape))
|
|
99
102
|
diags = jnp.zeros(shape).at[..., r, r].set(sqrt_eval)
|
|
100
103
|
return evecs @ diags
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def _diag_of_kron_of_diag_with_identities(in_diags: Sequence[jax.Array]) -> jax.Array:
    """
    Diagonal of a Kronecker sum of diagonal matrices.

    For vectors d_1, ..., d_p (d_j of length D_j) this returns the diagonal of

        sum_j I_{D_1} ⊗ ... ⊗ diag(d_j) ⊗ ... ⊗ I_{D_p},

    a vector of length prod(D_j). Each d_j is laid along its own axis of a
    p-dimensional grid and broadcast-added; flattening the grid in row-major order
    then reproduces the Kronecker ordering (the first factor varies slowest).
    """
    sizes = tuple(d.shape[-1] for d in in_diags)
    n_factors = len(sizes)

    grid = jnp.zeros(sizes)
    for axis, d in enumerate(in_diags):
        # (1, ..., D_axis, ..., 1): d varies only along `axis` of the grid.
        broadcast_shape = (
            (1,) * axis + (sizes[axis],) + (1,) * (n_factors - axis - 1)
        )
        grid = grid + jnp.reshape(d, broadcast_shape)

    return jnp.reshape(grid, -1)
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def _materialize_precision(penalties: Sequence[jax.Array]) -> jax.Array:
    """
    Materialize the (unscaled) Kronecker-sum penalty matrix.

    Builds

        K = sum_{j=1}^p K_j,
        K_j = I_{d_1} ⊗ ... ⊗ I_{d_{j-1}} ⊗ Ktilde_j ⊗ I_{d_{j+1}} ⊗ ... ⊗ I_{d_p},

    with Ktilde_j = penalties[j]. No smoothing parameters / variances are applied
    here.

    Parameters
    ----------
    penalties
        Sequence of p square matrices; penalties[j] has shape (..., d_j, d_j).
        Leading batch dimensions are carried through by ``jnp.kron`` because each
        term contains exactly one non-identity factor.

    Returns
    -------
    Array of shape (..., N, N) with N = prod(d_j). The full matrix is materialized,
    so this is only suitable for small N.
    """
    dims = [pen.shape[-1] for pen in penalties]

    def kron_term(j: int) -> jax.Array:
        """Return K_j: penalties[j] at position j, identities everywhere else."""
        factors = [
            pen if i == j else jnp.eye(d)
            for i, (pen, d) in enumerate(zip(penalties, dims))
        ]
        return reduce(jnp.kron, factors)

    K = kron_term(0)
    for j in range(1, len(penalties)):
        K = K + kron_term(j)

    return K
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def _compute_masks(
    penalties: Sequence[jax.Array],
    penalties_eigvalues: Sequence[jax.Array],
    eps: float = 1e-6,
) -> jax.Array:
    """
    Boolean mask of the non-zero eigenvalues of the Kronecker-sum penalty.

    The eigenvalues of sum_j I ⊗ ... ⊗ K_j ⊗ ... ⊗ I are sums of one eigenvalue
    of each K_j, so they are obtained from ``penalties_eigvalues`` without an
    eigendecomposition of the full matrix. Values greater than ``eps`` count as
    non-zero.

    As a consistency check, the full penalty is materialized once and its rank
    (``jnp.linalg.matrix_rank``) is compared with the number of non-zero
    eigenvalues. This is expensive and uses a Python-level ``if`` on array
    values, so it is meant for precomputation outside of traced code.

    Parameters
    ----------
    penalties
        Sequence of p penalty matrices, each of shape (B..., d_j, d_j).
    penalties_eigvalues
        Eigenvalues of each penalty, each of shape (B..., d_j), same order.
    eps
        Threshold above which an eigenvalue is treated as non-zero.

    Returns
    -------
    Boolean array of shape (B..., N), N = prod(d_j), in Kronecker order
    (the first factor varies slowest).

    Raises
    ------
    ValueError
        If, for any batch element, the number of eigenvalues above ``eps`` differs
        from the rank of the materialized penalty.
    """
    batch_shape = tuple(penalties_eigvalues[0].shape[:-1])
    # math.prod on the static shape tuple; 1 for an empty batch shape.
    n_batch = prod(batch_shape)

    flat_evs = [ev.reshape(n_batch, ev.shape[-1]) for ev in penalties_eigvalues]
    diags = jax.vmap(_diag_of_kron_of_diag_with_identities)(flat_evs)  # (n_batch, N)

    K = _materialize_precision(penalties)  # (B..., N, N)
    K = K.reshape((n_batch,) + K.shape[-2:])

    ranks = jax.vmap(jnp.linalg.matrix_rank)(K)  # (n_batch,)
    mask = diags > eps  # (n_batch, N)
    n_nonzero = mask.sum(-1)  # (n_batch,)

    if not jnp.array_equal(n_nonzero, ranks):
        raise ValueError(
            f"Number of nonzero eigenvalues ({n_nonzero}) does not "
            f"correspond to penalty rank ({ranks}). Maybe a different value for "
            f"{eps=} can help."
        )

    return mask.reshape(batch_shape + (mask.shape[-1],))
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
def _apply_Ki_along_axis(B, K, axis, dims, total_size):
    """
    Multiply tensor ``B`` by the square matrix ``K`` along a single axis.

    ``B`` has shape ``dims`` and ``K`` has shape (dims[axis], dims[axis]). Every
    fibre of ``B`` along ``axis`` is replaced by ``K @ fibre``; the result has the
    same shape as ``B``. ``total_size`` is expected to equal the number of
    elements of ``B`` (prod(dims)).
    """
    n_axis = dims[axis]
    n_other = total_size // n_axis

    # Bring the target axis to the front and fold all remaining axes into one,
    # so that a single matmul applies K to every fibre at once.
    front = jnp.moveaxis(B, axis, 0)
    product = K @ front.reshape(n_axis, n_other)

    # Unfold and move the axis back to its original position.
    return jnp.moveaxis(product.reshape(front.shape), 0, axis)
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
def _kron_sum_quadratic(x: jax.Array, Ks: Sequence[jax.Array]) -> jax.Array:
    """
    Quadratic form x^T (sum_j I ⊗ ... ⊗ K_j ⊗ ... ⊗ I) x without materializing it.

    ``x`` is reshaped into a tensor of shape (d_1, ..., d_p); each Kronecker term
    then reduces to applying K_j along axis j and taking the inner product with
    that tensor.

    Parameters
    ----------
    x
        Array with prod(d_j) elements.
    Ks
        Sequence of square matrices, Ks[j] of shape (d_j, d_j).

    Returns
    -------
    Scalar array holding the value of the quadratic form.

    Raises
    ------
    ValueError
        If a matrix is not square or ``x`` has the wrong number of elements.
        (Explicit checks instead of ``assert`` so they survive ``python -O``; they
        only inspect static shapes and are therefore safe under jit/vmap.)
    """
    dims = [K.shape[0] for K in Ks]
    for j, (K, d) in enumerate(zip(Ks, dims)):
        if K.shape != (d, d):
            raise ValueError(
                f"Penalty {j} must be a square matrix, got shape {K.shape}."
            )

    total_size = prod(dims)
    if x.size != total_size:
        raise ValueError(
            f"x has {x.size} elements, but the penalties imply {total_size}."
        )

    # Reshape x into a p-dimensional tensor.
    B = x.reshape(dims)

    total = jnp.array(0.0, dtype=x.dtype)
    for axis, K in enumerate(Ks):
        C = _apply_Ki_along_axis(B, K, axis, dims, total_size)
        total = total + jnp.vdot(B, C)  # scalar

    return total
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
class StructuredPenaltyOperator:
|
|
237
|
+
"""
|
|
238
|
+
- scales is an array with shape (B,K), where B is the batch shape and K is the
|
|
239
|
+
number of penalties. Each scale parameter corresponds to one penalty.
|
|
240
|
+
- penalties is a sequence of length K, containing arrays with shape (B, Di, Di).
|
|
241
|
+
B is the batch shape.
|
|
242
|
+
(Di, Di) is the block size of the individual penalty and can differ between
|
|
243
|
+
elements of the penalties sequence.
|
|
244
|
+
N = prod([p.shape[-1] for p in penalties]).
|
|
245
|
+
- penalties_eigvalues is a sequence of length K, containing arrays of shape (B, Di).
|
|
246
|
+
- penalties_eigvectors is a sequence of length K, containing arrays of shape
|
|
247
|
+
(B, Di, Di).
|
|
248
|
+
"""
|
|
249
|
+
|
|
250
|
+
def __init__(
|
|
251
|
+
self,
|
|
252
|
+
scales: jax.Array,
|
|
253
|
+
penalties: Sequence[jax.Array],
|
|
254
|
+
penalties_eigvalues: Sequence[jax.Array],
|
|
255
|
+
masks: jax.Array | None = None,
|
|
256
|
+
validate_args: bool = False,
|
|
257
|
+
tol: float = 1e-6,
|
|
258
|
+
) -> None:
|
|
259
|
+
self._scales = jnp.asarray(scales)
|
|
260
|
+
self._penalties = tuple([jnp.asarray(p) for p in penalties])
|
|
261
|
+
self._penalties_eigvalues = tuple(
|
|
262
|
+
[jnp.asarray(ev) for ev in penalties_eigvalues]
|
|
263
|
+
)
|
|
264
|
+
|
|
265
|
+
if validate_args:
|
|
266
|
+
self._validate_penalties()
|
|
267
|
+
|
|
268
|
+
self._sizes = [K.shape[-1] for K in self._penalties]
|
|
269
|
+
|
|
270
|
+
self._masks = masks
|
|
271
|
+
self._tol = tol
|
|
272
|
+
|
|
273
|
+
@classmethod
|
|
274
|
+
def from_penalties(
|
|
275
|
+
cls, scales: jax.Array, penalties: Sequence[jax.Array], eps: float = 1e-6
|
|
276
|
+
) -> Self:
|
|
277
|
+
evs = [jnp.linalg.eigh(K) for K in penalties]
|
|
278
|
+
evals = [ev.eigenvalues for ev in evs]
|
|
279
|
+
|
|
280
|
+
masks = _compute_masks(penalties=penalties, penalties_eigvalues=evals, eps=eps)
|
|
281
|
+
|
|
282
|
+
return cls(
|
|
283
|
+
scales=scales,
|
|
284
|
+
penalties=penalties,
|
|
285
|
+
penalties_eigvalues=evals,
|
|
286
|
+
masks=masks,
|
|
287
|
+
)
|
|
288
|
+
|
|
289
|
+
@cached_property
|
|
290
|
+
def variances(self) -> jax.Array:
|
|
291
|
+
return jnp.square(self._scales)
|
|
292
|
+
|
|
293
|
+
def materialize_precision(self) -> jax.Array:
|
|
294
|
+
return self._materialize_precision(self.variances)
|
|
295
|
+
|
|
296
|
+
def materialize_penalty(self) -> jax.Array:
|
|
297
|
+
return self._materialize_precision(jnp.ones_like(self.variances))
|
|
298
|
+
|
|
299
|
+
def _materialize_precision(self, variances: jax.Array) -> jax.Array:
|
|
300
|
+
"""This is inefficient, should be used for testing only."""
|
|
301
|
+
p = len(self._penalties)
|
|
302
|
+
|
|
303
|
+
dims = [self._penalties[j].shape[-1] for j in range(p)]
|
|
304
|
+
|
|
305
|
+
def kron_all(mats):
|
|
306
|
+
"""Kronecker product of a list of matrices."""
|
|
307
|
+
return reduce(jnp.kron, mats)
|
|
308
|
+
|
|
309
|
+
def one_batch(variances, *penalties):
|
|
310
|
+
# Build K_1 / tau1^2 as initial value
|
|
311
|
+
factors = [penalties[0] if i == 0 else jnp.eye(dims[i]) for i in range(p)]
|
|
312
|
+
|
|
313
|
+
K = kron_all(factors) / variances[0]
|
|
314
|
+
|
|
315
|
+
# Add remaining K_j / tau_j^2
|
|
316
|
+
for j in range(1, p):
|
|
317
|
+
factors = [
|
|
318
|
+
penalties[j] if i == j else jnp.eye(dims[i]) for i in range(p)
|
|
319
|
+
]
|
|
320
|
+
Kj = kron_all(factors)
|
|
321
|
+
K = K + Kj / variances[j]
|
|
322
|
+
|
|
323
|
+
return K
|
|
324
|
+
|
|
325
|
+
batch_shape = variances.shape[:-1]
|
|
326
|
+
K = variances.shape[-1]
|
|
327
|
+
|
|
328
|
+
# flatten batch dims so we can vmap over a single leading dim
|
|
329
|
+
B_flat = int(jnp.prod(jnp.array(batch_shape))) if batch_shape else 1
|
|
330
|
+
tau2_flat = variances.reshape(B_flat, K) # (B_flat, K)
|
|
331
|
+
pens_flat = [p.reshape((B_flat,) + p.shape[-2:]) for p in self._penalties]
|
|
332
|
+
|
|
333
|
+
big_K_fun = jax.vmap(one_batch, in_axes=(0,) + (0,) * K)
|
|
334
|
+
big_K = big_K_fun(tau2_flat, *pens_flat)
|
|
335
|
+
|
|
336
|
+
N = prod(dims)
|
|
337
|
+
big_K = jnp.reshape(big_K, batch_shape + (N, N))
|
|
338
|
+
return big_K
|
|
339
|
+
|
|
340
|
+
def _sum_of_scaled_eigenvalues(
|
|
341
|
+
self, variances: jax.Array, eigenvalues: Sequence[jax.Array]
|
|
342
|
+
) -> jax.Array:
|
|
343
|
+
"""
|
|
344
|
+
Expects
|
|
345
|
+
- variances (p,)
|
|
346
|
+
- eigenvalues (p, Di)
|
|
347
|
+
|
|
348
|
+
Returns (N,) where N = prod(Di)
|
|
349
|
+
"""
|
|
350
|
+
diag = jnp.zeros(prod(self._sizes))
|
|
351
|
+
|
|
352
|
+
for j in range(len(self._penalties)):
|
|
353
|
+
left_size = prod(self._sizes[:j])
|
|
354
|
+
right_size = prod(self._sizes[(j + 1) :])
|
|
355
|
+
d = eigenvalues[j] / variances[j]
|
|
356
|
+
|
|
357
|
+
# First handle identities to the right: repeat
|
|
358
|
+
d_rep = jnp.repeat(d, right_size)
|
|
359
|
+
|
|
360
|
+
# Then handle identities to the left: tile
|
|
361
|
+
d_tile = jnp.tile(d_rep, left_size)
|
|
362
|
+
|
|
363
|
+
diag = diag + d_tile
|
|
364
|
+
|
|
365
|
+
return diag
|
|
366
|
+
|
|
367
|
+
def log_pdet(self) -> jax.Array:
|
|
368
|
+
variances = self.variances # shape (B..., K)
|
|
369
|
+
batch_shape = variances.shape[:-1]
|
|
370
|
+
K = variances.shape[-1]
|
|
371
|
+
|
|
372
|
+
# flatten batch dims so we can vmap over a single leading dim
|
|
373
|
+
B_flat = int(jnp.prod(jnp.array(batch_shape))) if batch_shape else 1
|
|
374
|
+
tau2_flat = variances.reshape(B_flat, K) # (B_flat, K)
|
|
375
|
+
|
|
376
|
+
# eigenvalues per penalty, flattened over batch
|
|
377
|
+
eigvals_flat = [
|
|
378
|
+
ev.reshape(B_flat, ev.shape[-1]) for ev in self._penalties_eigvalues
|
|
379
|
+
] # list of K arrays (B_flat, Di)
|
|
380
|
+
|
|
381
|
+
def _single_diag(variances, *eigenvalues):
|
|
382
|
+
diag = self._sum_of_scaled_eigenvalues(variances, eigenvalues)
|
|
383
|
+
return diag
|
|
384
|
+
|
|
385
|
+
# vmap over flattened batch dimension
|
|
386
|
+
diag_flat = jax.vmap(_single_diag, in_axes=(0,) + (0,) * K)(
|
|
387
|
+
tau2_flat,
|
|
388
|
+
*eigvals_flat,
|
|
389
|
+
)
|
|
390
|
+
diag = jnp.reshape(diag_flat, batch_shape + (diag_flat.shape[-1],))
|
|
391
|
+
|
|
392
|
+
if self._masks is None:
|
|
393
|
+
mask = diag > self._tol
|
|
394
|
+
else:
|
|
395
|
+
mask = self._masks
|
|
396
|
+
|
|
397
|
+
logdet = jnp.log(jnp.where(mask, diag, 1.0)).sum(-1)
|
|
398
|
+
|
|
399
|
+
return logdet
|
|
400
|
+
|
|
401
|
+
def quad_form(self, x: jax.Array) -> jax.Array:
|
|
402
|
+
variances = self.variances # shape (B..., K)
|
|
403
|
+
batch_shape = variances.shape[:-1]
|
|
404
|
+
batch_shape_x = x.shape[:-1]
|
|
405
|
+
if batch_shape_x and (batch_shape_x != batch_shape):
|
|
406
|
+
raise ValueError(
|
|
407
|
+
f"x has batch shape {batch_shape_x}, but batch size is {batch_shape}."
|
|
408
|
+
)
|
|
409
|
+
|
|
410
|
+
K = variances.shape[-1]
|
|
411
|
+
N = x.shape[-1]
|
|
412
|
+
|
|
413
|
+
# flatten batch dims so we can vmap over a single leading dim
|
|
414
|
+
B_flat = int(jnp.prod(jnp.array(batch_shape))) if batch_shape else 1
|
|
415
|
+
tau2_flat = variances.reshape(B_flat, K) # (B_flat, K)
|
|
416
|
+
if batch_shape_x:
|
|
417
|
+
x_flat = x.reshape(B_flat, N) # (B_flat, N)
|
|
418
|
+
in_axis_x = 0
|
|
419
|
+
else:
|
|
420
|
+
x_flat = x
|
|
421
|
+
in_axis_x = None
|
|
422
|
+
|
|
423
|
+
pens_flat = [p.reshape((B_flat,) + p.shape[-2:]) for p in self._penalties]
|
|
424
|
+
|
|
425
|
+
def kron_sum_quadratic(x, variances, *penalties):
|
|
426
|
+
p = penalties
|
|
427
|
+
v = variances
|
|
428
|
+
scaled_penalties = [p[i] / v[i] for i in range(len(penalties))]
|
|
429
|
+
return _kron_sum_quadratic(x, scaled_penalties)
|
|
430
|
+
|
|
431
|
+
quad_form_vec = jax.vmap(
|
|
432
|
+
kron_sum_quadratic, in_axes=(in_axis_x,) + (0,) + (0,) * K
|
|
433
|
+
)
|
|
434
|
+
quad_form_out = quad_form_vec(x_flat, tau2_flat, *pens_flat)
|
|
435
|
+
|
|
436
|
+
quad_form_out = jnp.reshape(quad_form_out, batch_shape)
|
|
437
|
+
return quad_form_out
|
|
438
|
+
|
|
439
|
+
def _validate_penalties(self) -> None:
|
|
440
|
+
# validate number of penalty matrices
|
|
441
|
+
n_penalties1 = self._scales.shape[-1]
|
|
442
|
+
n_penalties2 = len(self._penalties)
|
|
443
|
+
n_penalties3 = len(self._penalties_eigvalues)
|
|
444
|
+
|
|
445
|
+
if not len({n_penalties1, n_penalties2, n_penalties3}) == 1:
|
|
446
|
+
msg1 = "Got inconsistent numbers of penalties. "
|
|
447
|
+
msg2 = f"Number of scale parameters: {n_penalties1}"
|
|
448
|
+
msg3 = f"Number of penalty matrices: {n_penalties2}. "
|
|
449
|
+
msg4 = f"Number of eigenvalue vectors: {n_penalties3}. "
|
|
450
|
+
raise ValueError(msg1 + msg2 + msg3 + msg4)
|
|
451
|
+
|
|
452
|
+
|
|
453
|
+
class MultivariateNormalStructured(tfd.Distribution):
    """
    (Possibly singular) multivariate normal with Kronecker-sum structured precision.

    The precision is the operator's K(tau^2) = sum_j K_j / tau_j^2. ``log_prob``
    evaluates the log pseudo-determinant and quadratic form through the structured
    operator without materializing the (N, N) precision; sampling does
    materialize it.

    - loc is an array with shape (B, N), where B is the batch shape and N is the
      event shape.
    - op is a StructuredPenaltyOperator holding the scales (shape (B, K)), the K
      penalty matrices (shapes (B, Di, Di)) and their eigenvalues (shapes (B, Di)).
      N = prod([p.shape[-1] for p in penalties]).

    ``from_penalties`` and ``get_locscale_constructor`` build the operator from
    penalty matrices.
    """

    def __init__(
        self,
        loc: Array,
        op: StructuredPenaltyOperator,
        validate_args: bool = False,
        allow_nan_stats: bool = True,
        name: str = "MultivariateNormalStructuredSingular",
        include_normalizing_constant: bool = True,
    ):
        parameters = dict(locals())

        self._loc = jnp.asarray(loc)
        self._op = op
        self._n = self._loc.shape[-1]
        self._include_normalizing_constant = include_normalizing_constant

        if validate_args:
            # The operator owns scales/penalties and validates their consistency.
            self._op._validate_penalties()
            self._validate_event_dim()

        super().__init__(
            dtype=self._loc.dtype,
            reparameterization_type=tfd.NOT_REPARAMETERIZED,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            parameters=parameters,
            name=name,
        )

    def _validate_event_dim(self) -> None:
        """Raise ValueError if loc's event size differs from prod of penalty sizes."""
        n_loc = self._loc.shape[-1]
        ndim_penalties = [p.shape[-1] for p in self._op._penalties]
        n_penalties = prod(ndim_penalties)

        if n_loc != n_penalties:
            msg1 = "Got inconsistent event dimensions. "
            msg2 = f"Event dimension implied by loc: {n_loc}. "
            msg3 = f"Event dimension implied by penalties: {n_penalties}"
            raise ValueError(msg1 + msg2 + msg3)

    def _batch_shape(self):
        variances = self._op.variances  # shape (B..., K)
        return tf.TensorShape(tuple(variances.shape[:-1]))

    def _batch_shape_tensor(self):
        # Shape tensors are integer-valued.
        variances = self._op.variances  # shape (B..., K)
        return jnp.array(tuple(variances.shape[:-1]), dtype=jnp.int32)

    def _event_shape(self):
        return tf.TensorShape((jnp.shape(self._loc)[-1],))

    def _event_shape_tensor(self):
        return jnp.array((jnp.shape(self._loc)[-1],), dtype=jnp.int32)

    def _log_prob(self, x: Array) -> jax.Array:
        x = jnp.asarray(x)
        x_centered = x - self._loc

        log_pdet = self._op.log_pdet()
        quad_form = self._op.quad_form(x_centered)

        # early returns, minimally more efficient
        if not self._include_normalizing_constant:
            return 0.5 * (log_pdet - quad_form)

        const = -(self._n / 2) * jnp.log(2 * jnp.pi)

        return 0.5 * (log_pdet - quad_form) + const

    @classmethod
    def from_penalties(
        cls,
        loc: Array,
        scales: Array,
        penalties: Sequence[Array],
        tol: float = 1e-6,
        validate_args: bool = False,
        allow_nan_stats: bool = True,
        include_normalizing_constant: bool = True,
    ) -> Self:
        """
        This is expensive, because it computes eigenvalue decompositions of all
        penalty matrices. Should only be used when performance is irrelevant.

        Masks are not precomputed; eigenvalues greater than ``tol`` are treated
        as non-zero.
        """
        constructor = cls.get_locscale_constructor(
            penalties=penalties,
            tol=tol,
            precompute_masks=False,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            include_normalizing_constant=include_normalizing_constant,
        )

        return constructor(loc, scales)

    @classmethod
    def get_locscale_constructor(
        cls,
        penalties: Sequence[Array],
        tol: float = 1e-6,
        precompute_masks: bool = True,
        validate_args: bool = False,
        allow_nan_stats: bool = True,
        include_normalizing_constant: bool = True,
    ) -> Callable[[Array, Array], "MultivariateNormalStructured"]:
        """
        Return a function ``(loc, scales) -> MultivariateNormalStructured``.

        The eigenvalues of the fixed ``penalties`` (and, if ``precompute_masks``,
        the non-zero-eigenvalue masks) are computed once here and reused by every
        distribution the returned function creates.
        """
        penalties_ = [jnp.asarray(p) for p in penalties]
        evals = [jnp.linalg.eigh(K).eigenvalues for K in penalties_]

        if precompute_masks:
            masks = _compute_masks(
                penalties=penalties_, penalties_eigvalues=evals, eps=tol
            )
        else:
            masks = None

        def construct_dist(loc: Array, scales: Array) -> "MultivariateNormalStructured":
            op = StructuredPenaltyOperator(
                scales=jnp.asarray(scales),
                penalties=penalties_,
                penalties_eigvalues=evals,
                masks=masks,
                tol=tol,
            )

            return cls(
                loc=jnp.asarray(loc),
                op=op,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                include_normalizing_constant=include_normalizing_constant,
            )

        return construct_dist

    @cached_property
    def _sqrt_cov(self) -> Array:
        """
        Matrix L with L @ L.T equal to the pseudo-inverse of the precision.

        Materializes and eigendecomposes the full (N, N) precision, so this is
        expensive.
        """
        prec = self._op.materialize_precision()
        eigenvalues, evecs = jnp.linalg.eigh(prec)  # eigenvalues in ascending order

        if self._op._masks is None:
            # Same criterion log_pdet uses when no masks are precomputed.
            keep = eigenvalues > self._op._tol
        else:
            # The precomputed masks are in Kronecker order, whereas eigh sorts the
            # eigenvalues ascending, so the mask cannot be applied elementwise.
            # Assuming positive semi-definite penalties, the zero eigenvalues are
            # the smallest ones: keep everything after the first N - n_nonzero.
            n = eigenvalues.shape[-1]
            n_nonzero = jnp.sum(self._op._masks, axis=-1, keepdims=True)
            keep = jnp.arange(n) >= (n - n_nonzero)

        # Substitute 1.0 before inverting so dropped entries never produce inf/nan.
        safe_eigenvalues = jnp.where(keep, eigenvalues, 1.0)
        sqrt_eval = jnp.where(keep, jnp.sqrt(1 / safe_eigenvalues), 0.0)

        # evecs @ diag(sqrt_eval): scale each eigenvector column.
        return evecs * sqrt_eval[..., None, :]

    def _sample_n(self, n, seed=None) -> Array:
        shape = [n] + self.batch_shape + self.event_shape

        # The added dimension at the end here makes sure that matrix multiplication
        # with the "sqrt pcov" matrices works out correctly.
        z = jax.random.normal(key=seed, shape=shape + [1])

        # Add a dimension at 0 for the sample size.
        sqrt_cov = jnp.expand_dims(self._sqrt_cov, 0)
        centered_samples = jnp.reshape(sqrt_cov @ z, shape)

        # Add a dimension at 0 for the sample size.
        loc = jnp.expand_dims(self._loc, 0)

        return centered_samples + loc
|
liesel_gam/kernel.py
CHANGED
|
@@ -6,12 +6,18 @@ import liesel.goose as gs
|
|
|
6
6
|
import liesel.model as lsl
|
|
7
7
|
|
|
8
8
|
|
|
9
|
-
def star_ig_gibbs(
|
|
10
|
-
|
|
9
|
+
def star_ig_gibbs(
|
|
10
|
+
coef: lsl.Var, scale: lsl.Var, penalty: jax.typing.ArrayLike | None = None
|
|
11
|
+
) -> gs.GibbsKernel:
|
|
12
|
+
variance_var = scale.value_node[0] # type: ignore
|
|
11
13
|
a_value = variance_var.dist_node["concentration"].value # type: ignore
|
|
12
14
|
b_value = variance_var.dist_node["scale"].value # type: ignore
|
|
13
15
|
|
|
14
|
-
|
|
16
|
+
if coef.dist_node is None:
|
|
17
|
+
penalty_value = jnp.asarray(penalty)
|
|
18
|
+
else:
|
|
19
|
+
penalty_value = coef.dist_node["penalty"].value # type: ignore
|
|
20
|
+
|
|
15
21
|
rank_value = jnp.linalg.matrix_rank(penalty_value)
|
|
16
22
|
|
|
17
23
|
model = coef.model
|
|
@@ -23,7 +29,7 @@ def star_ig_gibbs(coef: lsl.Var) -> gs.GibbsKernel:
|
|
|
23
29
|
def transition(prng_key, model_state):
|
|
24
30
|
pos = model.extract_position([coef.name], model_state)
|
|
25
31
|
|
|
26
|
-
coef_value = pos[coef.name]
|
|
32
|
+
coef_value = pos[coef.name]
|
|
27
33
|
|
|
28
34
|
a_gibbs = jnp.squeeze(a_value + 0.5 * rank_value)
|
|
29
35
|
b_gibbs = jnp.squeeze(b_value + 0.5 * (coef_value @ penalty_value @ coef_value))
|
|
@@ -35,14 +41,19 @@ def star_ig_gibbs(coef: lsl.Var) -> gs.GibbsKernel:
|
|
|
35
41
|
return gs.GibbsKernel([name], transition)
|
|
36
42
|
|
|
37
43
|
|
|
38
|
-
def init_star_ig_gibbs(
|
|
44
|
+
def init_star_ig_gibbs(
    position_keys: Sequence[str],
    coef: lsl.Var,
    scale: lsl.Var,
    penalty: jax.typing.ArrayLike | None = None,
) -> gs.GibbsKernel:
    """
    Create the Gibbs kernel for the variable behind ``scale``.

    ``position_keys`` must hold exactly one key, equal to the name of
    ``scale.value_node[0]``. The kernel is created by
    ``star_ig_gibbs(coef, scale, penalty)``.

    Raises
    ------
    ValueError
        If ``position_keys`` does not consist of exactly that single name.
    """
    if len(position_keys) != 1:
        raise ValueError("The position keys must be a single key.")

    expected_key = scale.value_node[0].name  # type: ignore
    if position_keys[0] != expected_key:
        raise ValueError(f"The position key must be {expected_key}.")

    return star_ig_gibbs(coef, scale, penalty)  # type: ignore
|