liesel-gam 0.0.4__py3-none-any.whl → 0.0.6a4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
liesel_gam/dist.py CHANGED
@@ -1,4 +1,7 @@
- from functools import cached_property
+ from collections.abc import Callable, Sequence
+ from functools import cached_property, reduce
+ from math import prod
+ from typing import Self
 
  import jax
  import jax.numpy as jnp
@@ -98,3 +101,540 @@ class MultivariateNormalSingular(tfd.Distribution):
          r = tuple(range(event_shape))
          diags = jnp.zeros(shape).at[..., r, r].set(sqrt_eval)
          return evecs @ diags
+
+
+ def _diag_of_kron_of_diag_with_identities(in_diags: Sequence[jax.Array]) -> jax.Array:
+     sizes = [v.shape[-1] for v in in_diags]
+     diag = jnp.zeros(prod(sizes))
+
+     for j in range(len(in_diags)):
+         left_size = prod(sizes[:j])
+         right_size = prod(sizes[(j + 1) :])
+         d = in_diags[j]
+
+         # First handle identities to the right: repeat
+         d_rep = jnp.repeat(d, right_size)
+
+         # Then handle identities to the left: tile
+         d_tile = jnp.tile(d_rep, left_size)
+
+         diag = diag + d_tile
+
+     return diag
+
+
+ def _materialize_precision(penalties: Sequence[jax.Array]) -> jax.Array:
+     """
+     Build the unscaled precision K = sum_{j=1}^p K_j
+     with K_j = I_{d1} ⊗ ... ⊗ I_{dj-1} ⊗ Ktilde_j ⊗ I_{dj+1} ⊗ ... ⊗ I_{dp}.
+
+     Parameters
+     ----------
+     penalties : sequence of arrays
+         List/tuple of p matrices, penalties[j] has shape (d_j, d_j).
+
+     Returns
+     -------
+     K : jnp.ndarray, shape (∏ d_j, ∏ d_j)
+     """
+     p = len(penalties)
+
+     dims = [penalties[j].shape[-1] for j in range(p)]
+
+     # Build K_1 as initial value
+     factors = [penalties[0] if i == 0 else jnp.eye(dims[i]) for i in range(p)]
+
+     def kron_all(mats):
+         """Kronecker product of a list of matrices."""
+         return reduce(jnp.kron, mats)
+
+     K = kron_all(factors)
+
+     # Add remaining K_j
+     for j in range(1, p):
+         factors = [penalties[j] if i == j else jnp.eye(dims[i]) for i in range(p)]
+         Kj = kron_all(factors)
+         K = K + Kj
+
+     return K
+
+
+ def _compute_masks(
+     penalties: Sequence[jax.Array],
+     penalties_eigvalues: Sequence[jax.Array],
+     eps: float = 1e-6,
+ ) -> jax.Array:
+     diag = _diag_of_kron_of_diag_with_identities
+
+     B = penalties_eigvalues[0].shape[:-1]
+     B_flat = int(jnp.prod(jnp.array(B))) if B else 1
+     flat_evs = [ev.reshape(B_flat, ev.shape[-1]) for ev in penalties_eigvalues]
+
+     diags = jax.vmap(diag)(flat_evs)  # (B_flat, N)
+     K = _materialize_precision(penalties)  # (B, N, N)
+     K = K.reshape((B_flat,) + K.shape[-2:])
+
+     ranks = jax.vmap(jnp.linalg.matrix_rank)(K)  # (B_flat,)
+     masks = (diags > eps).sum(-1)  # (B_flat,)
+
+     if not jnp.allclose(masks, ranks):
+         raise ValueError(
+             f"Number of non-zero eigenvalues ({masks}) does not "
+             f"correspond to penalty rank ({ranks}). Maybe a different value for "
+             f"{eps=} can help."
+         )
+     mask = diags > eps
+
+     return mask.reshape(B + (mask.shape[-1],))
+
+
+ def _apply_Ki_along_axis(B, K, axis, dims, total_size):
+     """
+     Apply K (Di x Di) along a given axis of B (shape dims),
+     returning an array with the same shape as B.
+     """
+     # Move the axis we want to the front: (Di, rest...)
+     B_perm = jnp.moveaxis(B, axis, 0)
+     Di = dims[axis]
+     rest = total_size // Di
+
+     # Flatten everything except that axis: (Di, rest)
+     B_flat = B_perm.reshape(Di, rest)
+
+     # Matrix multiply: K @ B_flat -> (Di, rest)
+     C_flat = K @ B_flat
+
+     # Restore original shape/order
+     C_perm = C_flat.reshape(B_perm.shape)
+     C = jnp.moveaxis(C_perm, 0, axis)  # back to shape = dims
+     return C
+
+
+ def _kron_sum_quadratic(x: jax.Array, Ks: Sequence[jax.Array]) -> jax.Array:
+     dims = [K.shape[0] for K in Ks]
+     # Basic sanity checks (cheap, can remove if you like)
+     for K, d in zip(Ks, dims):
+         assert K.shape == (d, d)
+     total_size = prod(dims)
+     assert x.size == total_size
+
+     # Reshape x into m-dimensional tensor
+     B = x.reshape(dims)
+
+     total = jnp.array(0.0, dtype=x.dtype)
+     for axis, K in enumerate(Ks):
+         C = _apply_Ki_along_axis(B, K, axis, dims, total_size)
+         total = total + jnp.vdot(B, C)  # scalar
+
+     return total
+
+
+ class StructuredPenaltyOperator:
+     """
+     - scales is an array with shape (B,K), where B is the batch shape and K is the
+       number of penalties. Each scale parameter corresponds to one penalty.
+     - penalties is a sequence of length K, containing arrays with shape (B, Di, Di).
+       B is the batch shape.
+       (Di, Di) is the block size of the individual penalty and can differ between
+       elements of the penalties sequence.
+       N = prod([p.shape[-1] for p in penalties]).
+     - penalties_eigvalues is a sequence of length K, containing arrays of shape (B, Di).
+     - penalties_eigvectors is a sequence of length K, containing arrays of shape
+       (B, Di, Di).
+     """
+
+     def __init__(
+         self,
+         scales: jax.Array,
+         penalties: Sequence[jax.Array],
+         penalties_eigvalues: Sequence[jax.Array],
+         masks: jax.Array | None = None,
+         validate_args: bool = False,
+         tol: float = 1e-6,
+     ) -> None:
+         self._scales = jnp.asarray(scales)
+         self._penalties = tuple([jnp.asarray(p) for p in penalties])
+         self._penalties_eigvalues = tuple(
+             [jnp.asarray(ev) for ev in penalties_eigvalues]
+         )
+
+         if validate_args:
+             self._validate_penalties()
+
+         self._sizes = [K.shape[-1] for K in self._penalties]
+
+         self._masks = masks
+         self._tol = tol
+
+     @classmethod
+     def from_penalties(
+         cls, scales: jax.Array, penalties: Sequence[jax.Array], eps: float = 1e-6
+     ) -> Self:
+         evs = [jnp.linalg.eigh(K) for K in penalties]
+         evals = [ev.eigenvalues for ev in evs]
+
+         masks = _compute_masks(penalties=penalties, penalties_eigvalues=evals, eps=eps)
+
+         return cls(
+             scales=scales,
+             penalties=penalties,
+             penalties_eigvalues=evals,
+             masks=masks,
+         )
+
+     @cached_property
+     def variances(self) -> jax.Array:
+         return jnp.square(self._scales)
+
+     def materialize_precision(self) -> jax.Array:
+         return self._materialize_precision(self.variances)
+
+     def materialize_penalty(self) -> jax.Array:
+         return self._materialize_precision(jnp.ones_like(self.variances))
+
+     def _materialize_precision(self, variances: jax.Array) -> jax.Array:
+         """This is inefficient, should be used for testing only."""
+         p = len(self._penalties)
+
+         dims = [self._penalties[j].shape[-1] for j in range(p)]
+
+         def kron_all(mats):
+             """Kronecker product of a list of matrices."""
+             return reduce(jnp.kron, mats)
+
+         def one_batch(variances, *penalties):
+             # Build K_1 / tau1^2 as initial value
+             factors = [penalties[0] if i == 0 else jnp.eye(dims[i]) for i in range(p)]
+
+             K = kron_all(factors) / variances[0]
+
+             # Add remaining K_j / tau_j^2
+             for j in range(1, p):
+                 factors = [
+                     penalties[j] if i == j else jnp.eye(dims[i]) for i in range(p)
+                 ]
+                 Kj = kron_all(factors)
+                 K = K + Kj / variances[j]
+
+             return K
+
+         batch_shape = variances.shape[:-1]
+         K = variances.shape[-1]
+
+         # flatten batch dims so we can vmap over a single leading dim
+         B_flat = int(jnp.prod(jnp.array(batch_shape))) if batch_shape else 1
+         tau2_flat = variances.reshape(B_flat, K)  # (B_flat, K)
+         pens_flat = [p.reshape((B_flat,) + p.shape[-2:]) for p in self._penalties]
+
+         big_K_fun = jax.vmap(one_batch, in_axes=(0,) + (0,) * K)
+         big_K = big_K_fun(tau2_flat, *pens_flat)
+
+         N = prod(dims)
+         big_K = jnp.reshape(big_K, batch_shape + (N, N))
+         return big_K
+
+     def _sum_of_scaled_eigenvalues(
+         self, variances: jax.Array, eigenvalues: Sequence[jax.Array]
+     ) -> jax.Array:
+         """
+         Expects
+         - variances (p,)
+         - eigenvalues (p, Di)
+
+         Returns (N,) where N = prod(Di)
+         """
+         diag = jnp.zeros(prod(self._sizes))
+
+         for j in range(len(self._penalties)):
+             left_size = prod(self._sizes[:j])
+             right_size = prod(self._sizes[(j + 1) :])
+             d = eigenvalues[j] / variances[j]
+
+             # First handle identities to the right: repeat
+             d_rep = jnp.repeat(d, right_size)
+
+             # Then handle identities to the left: tile
+             d_tile = jnp.tile(d_rep, left_size)
+
+             diag = diag + d_tile
+
+         return diag
+
+     def log_pdet(self) -> jax.Array:
+         variances = self.variances  # shape (B..., K)
+         batch_shape = variances.shape[:-1]
+         K = variances.shape[-1]
+
+         # flatten batch dims so we can vmap over a single leading dim
+         B_flat = int(jnp.prod(jnp.array(batch_shape))) if batch_shape else 1
+         tau2_flat = variances.reshape(B_flat, K)  # (B_flat, K)
+
+         # eigenvalues per penalty, flattened over batch
+         eigvals_flat = [
+             ev.reshape(B_flat, ev.shape[-1]) for ev in self._penalties_eigvalues
+         ]  # list of K arrays (B_flat, Di)
+
+         def _single_diag(variances, *eigenvalues):
+             diag = self._sum_of_scaled_eigenvalues(variances, eigenvalues)
+             return diag
+
+         # vmap over flattened batch dimension
+         diag_flat = jax.vmap(_single_diag, in_axes=(0,) + (0,) * K)(
+             tau2_flat,
+             *eigvals_flat,
+         )
+         diag = jnp.reshape(diag_flat, batch_shape + (diag_flat.shape[-1],))
+
+         if self._masks is None:
+             mask = diag > self._tol
+         else:
+             mask = self._masks
+
+         logdet = jnp.log(jnp.where(mask, diag, 1.0)).sum(-1)
+
+         return logdet
+
+     def quad_form(self, x: jax.Array) -> jax.Array:
+         variances = self.variances  # shape (B..., K)
+         batch_shape = variances.shape[:-1]
+         batch_shape_x = x.shape[:-1]
+         if batch_shape_x and (batch_shape_x != batch_shape):
+             raise ValueError(
+                 f"x has batch shape {batch_shape_x}, but batch size is {batch_shape}."
+             )
+
+         K = variances.shape[-1]
+         N = x.shape[-1]
+
+         # flatten batch dims so we can vmap over a single leading dim
+         B_flat = int(jnp.prod(jnp.array(batch_shape))) if batch_shape else 1
+         tau2_flat = variances.reshape(B_flat, K)  # (B_flat, K)
+         if batch_shape_x:
+             x_flat = x.reshape(B_flat, N)  # (B_flat, N)
+             in_axis_x = 0
+         else:
+             x_flat = x
+             in_axis_x = None
+
+         pens_flat = [p.reshape((B_flat,) + p.shape[-2:]) for p in self._penalties]
+
+         def kron_sum_quadratic(x, variances, *penalties):
+             p = penalties
+             v = variances
+             scaled_penalties = [p[i] / v[i] for i in range(len(penalties))]
+             return _kron_sum_quadratic(x, scaled_penalties)
+
+         quad_form_vec = jax.vmap(
+             kron_sum_quadratic, in_axes=(in_axis_x,) + (0,) + (0,) * K
+         )
+         quad_form_out = quad_form_vec(x_flat, tau2_flat, *pens_flat)
+
+         quad_form_out = jnp.reshape(quad_form_out, batch_shape)
+         return quad_form_out
+
+     def _validate_penalties(self) -> None:
+         # validate number of penalty matrices
+         n_penalties1 = self._scales.shape[-1]
+         n_penalties2 = len(self._penalties)
+         n_penalties3 = len(self._penalties_eigvalues)
+
+         if not len({n_penalties1, n_penalties2, n_penalties3}) == 1:
+             msg1 = "Got inconsistent numbers of penalties. "
+             msg2 = f"Number of scale parameters: {n_penalties1}. "
+             msg3 = f"Number of penalty matrices: {n_penalties2}. "
+             msg4 = f"Number of eigenvalue vectors: {n_penalties3}. "
+             raise ValueError(msg1 + msg2 + msg3 + msg4)
+
+
+ class MultivariateNormalStructured(tfd.Distribution):
+     """
+     - loc is an array with shape (B, N), where B is the batch shape and N is the
+       event shape.
+     - scales is an array with shape (B,K), where B is the batch shape and K is the
+       number of penalties. Each scale parameter corresponds to one penalty.
+     - penalties is a sequence of length K, containing arrays with shape (B, Di, Di).
+       B is the batch shape.
+       (Di, Di) is the block size of the individual penalty and can differ between
+       elements of the penalties sequence.
+       N = prod([p.shape[-1] for p in penalties]).
+     - penalties_eigvalues is a sequence of length K, containing arrays of shape (B, Di).
+     - penalties_eigvectors is a sequence of length K, containing arrays of shape
+       (B, Di, Di).
+     """
+
+     def __init__(
+         self,
+         loc: Array,
+         op: StructuredPenaltyOperator,
+         validate_args: bool = False,
+         allow_nan_stats: bool = True,
+         name: str = "MultivariateNormalStructuredSingular",
+         include_normalizing_constant: bool = True,
+     ):
+         parameters = dict(locals())
+
+         self._loc = jnp.asarray(loc)
+         self._op = op
+         self._n = self._loc.shape[-1]
+         self._include_normalizing_constant = include_normalizing_constant
+
+         if validate_args:
+             self._validate_penalties()
+             self._validate_event_dim()
+
+         super().__init__(
+             dtype=self._loc.dtype,
+             reparameterization_type=tfd.NOT_REPARAMETERIZED,
+             validate_args=validate_args,
+             allow_nan_stats=allow_nan_stats,
+             parameters=parameters,
+             name=name,
+         )
+
+     def _validate_event_dim(self) -> None:
+         # validate sample size
+         n_loc = self._loc.shape[-1]
+         ndim_penalties = [p.shape[-1] for p in self._op._penalties]
+         n_penalties = prod(ndim_penalties)
+
+         if not n_loc == n_penalties:
+             msg1 = "Got inconsistent event dimensions. "
+             msg2 = f"Event dimension implied by loc: {n_loc}. "
+             msg3 = f"Event dimension implied by penalties: {n_penalties}"
+             raise ValueError(msg1 + msg2 + msg3)
+
+     def _batch_shape(self):
+         variances = self._op.variances  # shape (B..., K)
+         batch_shape = tuple(variances.shape[:-1])
+         return tf.TensorShape(batch_shape)
+
+     def _batch_shape_tensor(self):
+         variances = self._op.variances  # shape (B..., K)
+         batch_shape = tuple(variances.shape[:-1])
+         return jnp.array(batch_shape, dtype=self._loc.dtype)
+
+     def _event_shape(self):
+         return tf.TensorShape((jnp.shape(self._loc)[-1],))
+
+     def _event_shape_tensor(self):
+         return jnp.array((jnp.shape(self._loc)[-1],), dtype=self._loc.dtype)
+
+     def _log_prob(self, x: Array) -> jax.Array:
+         x = jnp.asarray(x)
+         x_centered = x - self._loc
+
+         log_pdet = self._op.log_pdet()
+         quad_form = self._op.quad_form(x_centered)
+
+         # early returns, minimally more efficient
+         if not self._include_normalizing_constant:
+             return 0.5 * (log_pdet - quad_form)
+
+         const = -(self._n / 2) * jnp.log(2 * jnp.pi)
+
+         return 0.5 * (log_pdet - quad_form) + const
+
+     @classmethod
+     def from_penalties(
+         cls,
+         loc: Array,
+         scales: Array,
+         penalties: Sequence[Array],
+         tol: float = 1e-6,
+         validate_args: bool = False,
+         allow_nan_stats: bool = True,
+         include_normalizing_constant: bool = True,
+     ) -> Self:
+         """
+         This is expensive, because it computes eigenvalue decompositions of all
+         penalty matrices. Should only be used when performance is irrelevant.
+         """
+         constructor = cls.get_locscale_constructor(
+             penalties=penalties,
+             tol=tol,
+             precompute_masks=False,
+             validate_args=validate_args,
+             allow_nan_stats=allow_nan_stats,
+             include_normalizing_constant=include_normalizing_constant,
+         )
+
+         return constructor(loc, scales)
+
+     @classmethod
+     def get_locscale_constructor(
+         cls,
+         penalties: Sequence[Array],
+         tol: float = 1e-6,
+         precompute_masks: bool = True,
+         validate_args: bool = False,
+         allow_nan_stats: bool = True,
+         include_normalizing_constant: bool = True,
+     ) -> Callable[[Array, Array], "MultivariateNormalStructured"]:
+         penalties_ = [jnp.asarray(p) for p in penalties]
+         evs = [jnp.linalg.eigh(K) for K in penalties]
+         evals = [ev.eigenvalues for ev in evs]
+
+         if precompute_masks:
+             masks = _compute_masks(
+                 penalties=penalties_, penalties_eigvalues=evals, eps=tol
+             )
+         else:
+             masks = None
+
+         def construct_dist(loc: Array, scales: Array) -> "MultivariateNormalStructured":
+             loc = jnp.asarray(loc)
+             scales = jnp.asarray(scales)
+             op = StructuredPenaltyOperator(
+                 scales=scales,
+                 penalties=penalties_,
+                 penalties_eigvalues=evals,
+                 masks=masks,
+                 tol=tol,
+             )
+
+             dist = cls(
+                 loc=loc,
+                 op=op,
+                 validate_args=validate_args,
+                 allow_nan_stats=allow_nan_stats,
+                 include_normalizing_constant=include_normalizing_constant,
+             )
+
+             return dist
+
+         return construct_dist
+
+     @cached_property
+     def _sqrt_cov(self) -> Array:
+         prec = self._op.materialize_precision()
+         eigenvalues, evecs = jnp.linalg.eigh(prec)
+         sqrt_eval = jnp.sqrt(1 / eigenvalues)
+         assert self._op._masks is not None
+         sqrt_eval = sqrt_eval.at[..., ~self._op._masks].set(0.0)
+
+         event_shape = sqrt_eval.shape[-1]
+         shape = sqrt_eval.shape + (event_shape,)
+
+         r = tuple(range(event_shape))
+         diags = jnp.zeros(shape).at[..., r, r].set(sqrt_eval)
+         return evecs @ diags
+
+     def _sample_n(self, n, seed=None) -> Array:
+         shape = [n] + self.batch_shape + self.event_shape
+
+         # The added dimension at the end here makes sure that matrix multiplication
+         # with the "sqrt pcov" matrices works out correctly.
+         z = jax.random.normal(key=seed, shape=shape + [1])
+
+         # Add a dimension at 0 for the sample size.
+         sqrt_cov = jnp.expand_dims(self._sqrt_cov, 0)
+         centered_samples = jnp.reshape(sqrt_cov @ z, shape)
+
+         # Add a dimension at 0 for the sample size.
+         loc = jnp.expand_dims(self._loc, 0)
+
+         return centered_samples + loc
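
The dist.py additions express a structured multivariate normal prior through per-dimension penalty blocks, working with their eigenvalues instead of the full Kronecker-sum precision (which is only materialized in the explicitly inefficient test helpers). A minimal sketch of how the new pieces appear to fit together; the penalty matrices, shapes, and scale values below are made up purely for illustration, and the import path assumes the package layout shown in this diff:

```python
import jax
import jax.numpy as jnp

from liesel_gam.dist import MultivariateNormalStructured

# Two illustrative penalty blocks: a second-order random-walk penalty (5 x 5,
# rank 3) and an identity penalty (3 x 3), giving an event dimension of 15.
D = jnp.diff(jnp.eye(5), n=2, axis=0)
penalties = [D.T @ D, jnp.eye(3)]

# The constructor precomputes eigendecompositions and rank masks once, so the
# returned callable can be reused cheaply, e.g. inside a log-posterior.
construct = MultivariateNormalStructured.get_locscale_constructor(penalties=penalties)

loc = jnp.zeros(15)
scales = jnp.array([0.5, 1.0])  # one scale parameter per penalty block
dist = construct(loc, scales)

x = jax.random.normal(jax.random.PRNGKey(0), (15,))
print(dist.log_prob(x))  # log density under the structured prior
```

In this sketch, `log_prob` only touches the precomputed eigenvalues (for the pseudo-log-determinant) and axis-wise matrix products (for the quadratic form), which is exactly what `log_pdet` and `quad_form` above implement.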
liesel_gam/kernel.py CHANGED
@@ -6,12 +6,18 @@ import liesel.goose as gs
  import liesel.model as lsl
 
 
- def star_ig_gibbs(coef: lsl.Var) -> gs.GibbsKernel:
-     variance_var = coef.dist_node["scale"].value_node[0]  # type: ignore
+ def star_ig_gibbs(
+     coef: lsl.Var, scale: lsl.Var, penalty: jax.typing.ArrayLike | None = None
+ ) -> gs.GibbsKernel:
+     variance_var = scale.value_node[0]  # type: ignore
      a_value = variance_var.dist_node["concentration"].value  # type: ignore
      b_value = variance_var.dist_node["scale"].value  # type: ignore
 
-     penalty_value = coef.dist_node["penalty"].value  # type: ignore
+     if coef.dist_node is None:
+         penalty_value = jnp.asarray(penalty)
+     else:
+         penalty_value = coef.dist_node["penalty"].value  # type: ignore
+
      rank_value = jnp.linalg.matrix_rank(penalty_value)
 
      model = coef.model
@@ -23,7 +29,7 @@ def star_ig_gibbs(coef: lsl.Var) -> gs.GibbsKernel:
      def transition(prng_key, model_state):
          pos = model.extract_position([coef.name], model_state)
 
-         coef_value = pos[coef.name].squeeze()
+         coef_value = pos[coef.name]
          a_gibbs = jnp.squeeze(a_value + 0.5 * rank_value)
          b_gibbs = jnp.squeeze(b_value + 0.5 * (coef_value @ penalty_value @ coef_value))
 
@@ -35,14 +41,19 @@ def star_ig_gibbs(coef: lsl.Var) -> gs.GibbsKernel:
      return gs.GibbsKernel([name], transition)
 
 
- def init_star_ig_gibbs(position_keys: Sequence[str], coef: lsl.Var) -> gs.GibbsKernel:
+ def init_star_ig_gibbs(
+     position_keys: Sequence[str],
+     coef: lsl.Var,
+     scale: lsl.Var,
+     penalty: jax.typing.ArrayLike | None = None,
+ ) -> gs.GibbsKernel:
      if len(position_keys) != 1:
          raise ValueError("The position keys must be a single key.")
 
-     variance_var = coef.dist_node["scale"].value_node[0]  # type: ignore
+     variance_var = scale.value_node[0]  # type: ignore
      name = variance_var.name
      if position_keys[0] != name:
          raise ValueError(f"The position key must be {name}.")
 
 
-     return star_ig_gibbs(coef)  # type: ignore
+     return star_ig_gibbs(coef, scale, penalty)  # type: ignore
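
The kernel.py change decouples the Gibbs kernel from the coefficient's distribution node: the scale variable is now passed explicitly, and the penalty matrix can be supplied directly when `coef` has no distribution node; `init_star_ig_gibbs` simply forwards these new arguments. For orientation, a standalone sketch of the inverse-gamma update whose parameters the changed lines compute; the prior values, penalty, and coefficients here are illustrative, the actual kernel reads them from the liesel model state, and the final draw is not part of the changed lines (it is just the standard inverse-gamma sampling identity):

```python
import jax
import jax.numpy as jnp

# Illustrative pieces of one structured additive term
a, b = 1.0, 0.005                      # inverse-gamma prior: concentration, scale
D = jnp.diff(jnp.eye(9), n=2, axis=0)
penalty = D.T @ D                      # second-order random-walk penalty, rank 7
coef = jax.random.normal(jax.random.PRNGKey(0), (9,))

# Parameters of the inverse-gamma full conditional, as in the transition above
rank = jnp.linalg.matrix_rank(penalty)
a_gibbs = jnp.squeeze(a + 0.5 * rank)
b_gibbs = jnp.squeeze(b + 0.5 * (coef @ penalty @ coef))

# New smoothing variance: 1 / Gamma(a_gibbs, rate=b_gibbs),
# i.e. b_gibbs divided by a Gamma(a_gibbs, rate=1) draw
variance = b_gibbs / jax.random.gamma(jax.random.PRNGKey(1), a_gibbs)
```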