liesel-gam 0.0.4__py3-none-any.whl → 0.0.6a4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
liesel_gam/var.py CHANGED
@@ -1,96 +1,581 @@
  from __future__ import annotations

- from collections.abc import Callable
- from typing import Any, Self
+ import copy
+ from collections.abc import Callable, Sequence
+ from functools import reduce
+ from typing import Any, Literal, NamedTuple, Self

  import jax
  import jax.numpy as jnp
  import liesel.goose as gs
  import liesel.model as lsl
  import tensorflow_probability.substrates.jax.distributions as tfd
+ from formulaic import ModelSpec

- from .dist import MultivariateNormalSingular
+ from liesel_gam.builder.category_mapping import CategoryMapping
+
+ from .constraint import LinearConstraintEVD, penalty_to_unit_design
+ from .dist import MultivariateNormalSingular, MultivariateNormalStructured
  from .kernel import init_star_ig_gibbs
- from .roles import Roles

  InferenceTypes = Any
- Array = Any
+ Array = jax.Array
+ ArrayLike = jax.typing.ArrayLike
+
+
+ class VarIGPrior(NamedTuple):
+     concentration: float
+     scale: float
+
+
+ def _append_name(name: str, append: str) -> str:
+     if name == "":
+         return ""
+     else:
+         return name + append
+
+
+ def _ensure_var_or_node(
+     x: lsl.Var | lsl.Node | ArrayLike,
+     name: str | None,
+ ) -> lsl.Var | lsl.Node:
+     """
+     If x is an array, creates a new observed variable.
+     """
+     if isinstance(x, lsl.Var | lsl.Node):
+         x_var = x
+     else:
+         name = name if name is not None else ""
+         x_var = lsl.Var.new_obs(jnp.asarray(x), name=name)
+
+     if name is not None and x_var.name != name:
+         raise ValueError(f"{x_var.name=} and {name=} are incompatible.")
+
+     return x_var
+
+
+ def _ensure_value(
+     x: lsl.Var | lsl.Node | ArrayLike,
+     name: str | None,
+ ) -> lsl.Var | lsl.Node:
+     """
+     If x is an array, creates a new value node.
+     """
+     if isinstance(x, lsl.Var | lsl.Node):
+         x_var = x
+     else:
+         name = name if name is not None else ""
+         x_var = lsl.Value(jnp.asarray(x), _name=name)
+
+     if name is not None and x_var.name != name:
+         raise ValueError(f"{x_var.name=} and {name=} are incompatible.")
+
+     return x_var
+
+
+ class UserVar(lsl.Var):
+     @classmethod
+     def new_calc(cls, *args, **kwargs) -> None:  # type: ignore
+         raise NotImplementedError(
+             f"This constructor is not implemented on {cls.__name__}."
+         )
+
+     @classmethod
+     def new_obs(cls, *args, **kwargs) -> None:  # type: ignore
+         raise NotImplementedError(
+             f"This constructor is not implemented on {cls.__name__}."
+         )
+
+     @classmethod
+     def new_param(cls, *args, **kwargs) -> None:  # type: ignore
+         raise NotImplementedError(
+             f"This constructor is not implemented on {cls.__name__}."
+         )
+
+     @classmethod
+     def new_value(cls, *args, **kwargs) -> None:  # type: ignore
+         raise NotImplementedError(
+             f"This constructor is not implemented on {cls.__name__}."
+         )
+
+
+ def mvn_diag_prior(scale: lsl.Var) -> lsl.Dist:
+     return lsl.Dist(tfd.Normal, loc=0.0, scale=scale)
+
+
+ def mvn_structured_prior(scale: lsl.Var, penalty: lsl.Var | lsl.Value) -> lsl.Dist:
+     if isinstance(penalty, lsl.Var) and not penalty.strong:
+         raise NotImplementedError(
+             "Varying penalties are currently not supported by this function."
+         )
+     prior = lsl.Dist(
+         MultivariateNormalSingular,
+         loc=0.0,
+         scale=scale,
+         penalty=penalty,
+         penalty_rank=jnp.linalg.matrix_rank(penalty.value),
+     )
+     return prior
+
+
+ def term_prior(
+     scale: lsl.Var | Array | None,
+     penalty: lsl.Var | lsl.Value | Array | None,
+ ) -> lsl.Dist | None:
+     """
+     Returns
+     - None if scale=None
+     - A simple Normal prior with loc=0.0 and scale=scale if penalty=None
+     - A potentially rank-deficient structured multivariate normal prior otherwise
+     """
+     if scale is None:
+         if penalty is not None:
+             raise ValueError(f"If {scale=}, then penalty must also be None.")
+         return None
+
+     if not isinstance(scale, lsl.Var | lsl.Value):
+         scale = lsl.Var(scale)
+
+     if penalty is None:
+         return mvn_diag_prior(scale)
+
+     if not isinstance(penalty, lsl.Var | lsl.Value):
+         penalty = lsl.Value(penalty)
+
+     return mvn_structured_prior(scale, penalty)
+

+ class ScaleIG(UserVar):
+     """
+     A variable with an Inverse Gamma prior on its square.
+
+     The variance parameter (i.e. the squared scale) is flagged as a parameter.
+
+     Parameters
+     ----------
+     value
+         Initial value of the variable.
+     concentration
+         Concentration parameter of the inverse gamma distribution.\
+         In some parameterizations, this parameter is called ``a``.
+     scale
+         Scale parameter of the inverse gamma distribution.\
+         In some parameterizations, this parameter is called ``b``.
+     name
+         Name of the variable.
+     inference
+         Inference type.
+     """

- class SmoothTerm(lsl.Var):
      def __init__(
          self,
-         basis: Basis | lsl.Var,
-         penalty: lsl.Var | Array,
-         scale: lsl.Var,
-         name: str,
+         value: float | Array,
+         concentration: float | lsl.Var | lsl.Node | ArrayLike,
+         scale: float | lsl.Var | lsl.Node | ArrayLike,
+         name: str = "",
+         variance_name: str = "",
          inference: InferenceTypes = None,
-         coef_name: str | None = None,
      ):
-         coef_name = f"{name}_coef" if coef_name is None else coef_name
-
-         if not jnp.asarray(basis.value).ndim == 2:
-             raise ValueError(f"basis must have 2 dimensions, got {basis.value.ndim}.")
+         value = jnp.asarray(value)
+         if value.size != 1:
+             raise ValueError(
+                 f"Expected scalar value for ScaleIG, got size {value.size}."
+             )

-         nbases = jnp.shape(basis.value)[-1]
+         concentration_node = _ensure_value(
+             concentration, name=_append_name(name, "_concentration")
+         )
+         scale_node = _ensure_value(scale, name=_append_name(name, "_scale"))

          prior = lsl.Dist(
-             MultivariateNormalSingular,
-             loc=0.0,
-             scale=scale,
-             penalty=penalty,
-             penalty_rank=jnp.linalg.matrix_rank(penalty),
+             tfd.InverseGamma, concentration=concentration_node, scale=scale_node
          )

-         self.scale = scale
-         self.nbases = nbases
+         variance_name = variance_name or _append_name(name, "_square")
+
+         self._variance_param = lsl.Var.new_param(
+             value, prior, inference=inference, name=variance_name
+         )
+         super().__init__(lsl.Calc(jnp.sqrt, self._variance_param), name=name)
+
+     def setup_gibbs_inference(
+         self, coef: lsl.Var, penalty: jax.typing.ArrayLike | None = None
+     ) -> ScaleIG:
+         init_gibbs = copy.copy(init_star_ig_gibbs)
+         init_gibbs.__name__ = "StarVarianceGibbs"
+
+         self._variance_param.inference = gs.MCMCSpec(
+             init_star_ig_gibbs,
+             kernel_kwargs={"coef": coef, "scale": self, "penalty": penalty},
+         )
+         return self
+
+
+ def _init_scale_ig(
+     x: ScaleIG | VarIGPrior | lsl.Var | ArrayLike | None,
+     validate_scalar: bool = False,
+ ) -> ScaleIG | lsl.Var | None:
+     if isinstance(x, VarIGPrior):
+         concentration = jnp.asarray(x.concentration)
+         scale_ = jnp.asarray(x.scale)
+
+         if validate_scalar:
+             if not concentration.size == 1:
+                 raise ValueError(
+                     "Expected scalar hyperparameter 'concentration', "
+                     f"got size {concentration.size}"
+                 )
+
+             if not scale_.size == 1:
+                 raise ValueError(
+                     f"Expected scalar hyperparameter 'scale', got size {scale_.size}"
+                 )
+
+         scale_var: ScaleIG | lsl.Var | None = ScaleIG(
+             value=jnp.array(1.0),
+             concentration=concentration,
+             scale=scale_,
+         )
+     elif isinstance(x, ScaleIG | lsl.Var):
+         if isinstance(x, ScaleIG):
+             if x._variance_param.strong:
+                 x._variance_param.value = jnp.asarray(x._variance_param.value)
+                 x.update()
+         elif x.strong:
+             x.value = jnp.asarray(x.value)
+
+         scale_var = x
+         if validate_scalar:
+             size = jnp.asarray(scale_var.value).size
+             if not size == 1:
+                 raise ValueError(f"Expected scalar scale, got size {size}")
+     elif x is not None:
+         scale_var = lsl.Var.new_value(jnp.asarray(x))
+         if validate_scalar:
+             size = scale_var.value.size
+             if not size == 1:
+                 raise ValueError(f"Expected scalar scale, got size {size}")
+     elif x is None:
+         scale_var = x
+     else:
+         raise TypeError(f"Unexpected type for scale: {type(x)}")
+
+     return scale_var
+
+
+ def _validate_scalar_or_p_scale(scale_value: Array, p):
+     try:
+         is_scalar = scale_value.size == 1
+     except AttributeError:
+         raise TypeError(
+             f"Expected scale value to be an array, got type {type(scale_value)}"
+         )
+     is_p = scale_value.size == p
+     if not (is_scalar or is_p):
+         raise ValueError(
+             f"Expected scale to have size 1 or {p}, got size {scale_value.size}"
+         )
+
+
+ class Term(UserVar):
+     """
+     General structured additive term.
+
+     A structured additive term represents a smooth or structured effect in a
+     generalized additive model. The term wraps a design/basis matrix together
+     with a prior/penalty and a set of coefficients. The object exposes the
+     coefficient variable and evaluates the term as the matrix-vector product
+     of the basis and the coefficients.
+     The term evaluates to ``basis @ coef``.
+
+     Parameters
+     ----------
+     basis
+         A :class:`.Basis` instance that produces the design matrix for the \
+         term. The basis must evaluate to a 2-D array with shape ``(n_obs, n_bases)``.
+     penalty
+         Penalty matrix or a variable/value wrapping the penalty \
+         used to construct the multivariate normal prior for the coefficients.
+     scale
+         Scale parameter for the prior on the coefficients. This \
+         is typically either a scalar or a per-coefficient scale variable.
+     name
+         Human-readable name for the term. Used for labelling variables and \
+         building sensible default names for internal nodes.
+     inference
+         :class:`liesel.goose.MCMCSpec` inference specification forwarded to coefficient\
+         creation.
+     coef_name
+         Name for the coefficient variable. If ``None``, a default name based \
+         on ``name`` will be used.
+     _update_on_init
+         If ``True`` (default) the internal calculation/graph nodes are \
+         evaluated during initialization. Set to ``False`` to delay \
+         initial evaluation.
+
+     Raises
+     ------
+     ValueError
+         If ``basis.value`` does not have two dimensions.
+
+     Attributes
+     ----------
+     scale
+         The scale variable used by the prior on the coefficients.
+     nbases
+         Number of basis functions (number of columns in the basis matrix).
+     basis
+         The basis object provided to the constructor.
+     coef
+         The coefficient variable created for this term. It holds the prior
+         (multivariate normal singular) and is used in the evaluation of the
+         term.
+     is_noncentered
+         Whether the term has been reparameterized to the non-centered form.
+
+     """
+
+     def __init__(
+         self,
+         basis: Basis,
+         penalty: lsl.Var | lsl.Value | Array | None,
+         scale: ScaleIG | VarIGPrior | lsl.Var | ArrayLike | None,
+         name: str = "",
+         inference: InferenceTypes = None,
+         coef_name: str | None = None,
+         _update_on_init: bool = True,
+         validate_scalar_scale: bool = True,
+     ):
+         scale = _init_scale_ig(scale, validate_scalar=validate_scalar_scale)
+
+         coef_name = _append_name(name, "_coef") if coef_name is None else coef_name
+
+         prior = term_prior(scale, penalty)
+
          self.basis = basis
+
+         if isinstance(penalty, lsl.Var | lsl.Value):
+             nparam = jnp.shape(penalty.value)[-1]
+         elif penalty is not None:
+             nparam = jnp.shape(penalty)[-1]
+         else:
+             nparam = self.nbases
+
+         if scale is not None:
+             _validate_scalar_or_p_scale(scale.value, nparam)
          self.coef = lsl.Var.new_param(
-             jnp.zeros(nbases), prior, inference=inference, name=coef_name
+             jnp.zeros(nparam), prior, inference=inference, name=coef_name
+         )
+         calc = lsl.Calc(
+             lambda basis, coef: jnp.dot(basis, coef),
+             basis=basis,
+             coef=self.coef,
+             _update_on_init=_update_on_init,
          )
-         calc = lsl.Calc(jnp.dot, basis, self.coef)
+         self._scale = scale

          super().__init__(calc, name=name)
+         if _update_on_init:
+             self.coef.update()
+
+         self.is_noncentered = False
+
+         if hasattr(self.scale, "setup_gibbs_inference"):
+             try:
+                 self.scale.setup_gibbs_inference(self.coef)  # type: ignore
+             except Exception as e:
+                 raise RuntimeError(f"Failed to setup Gibbs kernel for {self}") from e
+
+     @property
+     def nbases(self) -> int:
+         return jnp.shape(self.basis.value)[-1]
+
+     @property
+     def scale(self) -> lsl.Var | lsl.Node | None:
+         return self._scale
+
+     def reparam_noncentered(self) -> Self:
+         """
+         Turns this term into noncentered form, which means the prior for
+         the coefficient will be turned from ``coef ~ N(0, scale^2 * inv(penalty))`` into
+         ``latent_coef ~ N(0, inv(penalty)); coef = scale * latent_coef``.
+         This can sometimes be helpful when sampling with the No-U-Turn Sampler.
+         """
+         if self.scale is None:
+             raise ValueError(
+                 f"Noncentering reparameterization of {self} fails, "
+                 f"because {self.scale=}."
+             )
+         if self.is_noncentered:
+             return self
+
+         assert self.coef.dist_node is not None
+
+         self.coef.dist_node["scale"] = lsl.Value(jnp.array(1.0))
+
+         if self.scale.name and self.coef.name:
+             scaled_name = self.scale.name + "*" + self.coef.name
+         else:
+             scaled_name = _append_name(self.coef.name, "_scaled")
+
+         scaled_coef = lsl.Var.new_calc(
+             lambda scale, coef: scale * coef,
+             self.scale,
+             self.coef,
+             name=scaled_name,
+         )
+
+         self.value_node["coef"] = scaled_coef
          self.coef.update()
          self.update()
-         self.coef.role = Roles.coef_smooth
-         self.role = Roles.term_smooth
+         self.is_noncentered = True
+
+         if hasattr(self.scale, "setup_gibbs_inference"):
+             try:
+                 pen = self.coef.dist_node["penalty"].value
+                 self.scale.setup_gibbs_inference(scaled_coef, penalty=pen)  # type: ignore
+             except Exception as e:
+                 raise RuntimeError(f"Failed to setup Gibbs kernel for {self}") from e
+
+         return self

      @classmethod
-     def new_ig(
+     def f(
          cls,
-         basis: Basis | lsl.Var,
-         penalty: Array,
-         name: str,
-         ig_concentration: float = 0.01,
-         ig_scale: float = 0.01,
+         basis: Basis,
+         fname: str = "f",
+         scale: ScaleIG | lsl.Var | ArrayLike | VarIGPrior | None = None,
          inference: InferenceTypes = None,
-         variance_value: float | None = None,
-         variance_name: str | None = None,
-         variance_jitter_dist: tfd.Distribution | None = None,
          coef_name: str | None = None,
+         noncentered: bool = False,
      ) -> Self:
-         variance_name = f"{name}_variance" if variance_name is None else variance_name
+         """
+         Construct a smooth term from a :class:`.Basis`.
+
+         This convenience constructor builds a named ``term`` using the
+         provided basis. The penalty matrix is taken from ``basis.penalty`` and
+         a coefficient variable with an appropriate multivariate-normal prior
+         is created. The returned term evaluates to ``basis @ coef``.
+
+         Parameters
+         ----------
+         basis
+             Basis object that provides the design matrix and penalty for the \
+             smooth term. The basis must have an associated input variable with \
+             a meaningful name (used to compose the term name).
+         fname
+             Function-name prefix used when constructing the term name. Default \
+             is ``'f'`` which results in names like ``f(x)`` when the basis \
+             input is named ``x``.
+         scale
+             Scale parameter passed to the coefficient prior.
+         inference
+             Inference specification forwarded to the coefficient variable \
+             creation, a :class:`liesel.goose.MCMCSpec`.
+         noncentered
+             If ``True``, the term is reparameterized to the non-centered \
+             form via :meth:`.reparam_noncentered` before being returned.
+         coef_name
+             Coefficient name. The default coefficient name is a LaTeX-like string \
+             ``"$\\beta_{f(x)}$"`` to improve readability in printed summaries.
+
+         Returns
+         -------
+         A :class:`.Term` instance configured with the given basis and prior settings.
+         """
+         if not basis.x.name:
+             raise ValueError("basis.x must be named.")
+
+         if not basis.name:
+             raise ValueError("basis must be named.")
+
+         name = f"{fname}({basis.x.name})"
+         coef_name = coef_name or "$\\beta_{" + f"{name}" + "}$"

-         variance = lsl.Var.new_param(
-             value=1.0,
-             distribution=lsl.Dist(
-                 tfd.InverseGamma,
-                 concentration=ig_concentration,
-                 scale=ig_scale,
-             ),
-             name=variance_name,
+         term = cls(
+             basis=basis,
+             penalty=basis.penalty if scale is not None else None,
+             scale=scale,
+             inference=inference,
+             coef_name=coef_name,
+             name=name,
+             validate_scalar_scale=not noncentered,
          )
-         variance.role = Roles.variance_smooth

-         scale = lsl.Var.new_calc(jnp.sqrt, variance, name=f"{variance_name}_root")
-         scale.role = Roles.scale_smooth
+         if noncentered:
+             term.reparam_noncentered()

-         if variance_value is None:
-             ig_median = variance.dist_node.init_dist().quantile(0.5)  # type: ignore
-             variance.value = min(ig_median, 10.0)
-         else:
-             variance.value = variance_value
+         return term
+
+     @classmethod
+     def new_ig(
+         cls,
+         basis: Basis,
+         penalty: lsl.Var | lsl.Value | Array | None,
+         name: str,
+         ig_concentration: float = 1.0,
+         ig_scale: float = 0.005,
+         inference: InferenceTypes = None,
+         scale_value: float = 100.0,
+         scale_name: str | None = None,
+         coef_name: str | None = None,
+         noncentered: bool = False,
+     ) -> Term:
+         """
+         Construct a smooth term with an inverse-gamma prior on the variance.
+
+         This convenience constructor creates a term similar to :meth:`.f` but
+         sets up an explicit variance parameter with an Inverse-Gamma prior.
+         A scale variable is set up by taking the square-root, and the
+         coefficient prior uses the derived ``scale`` together with the basis
+         penalty. By default a Gibbs-style initialization is attached to the
+         variance inference via an internal kernel.
+
+         Parameters
+         ----------
+         basis
+             Basis object providing the design matrix and penalty.
+         name
+             Term name.
+         penalty
+             Penalty matrix or a variable/value wrapping the penalty \
+             used to construct the multivariate normal prior for the coefficients.
+         ig_concentration
+             Concentration (shape) parameter of the Inverse-Gamma prior for the \
+             variance.
+         ig_scale
+             Scale parameter of the Inverse-Gamma prior for the variance.
+         inference
+             Inference specification forwarded to the coefficient variable \
+             creation, a :class:`liesel.goose.MCMCSpec`.
+         scale_value
+             Initial value for the scale variable.
+         scale_name
+             Scale parameter name. The default is a LaTeX-like representation \
+             ``"$\\tau$"`` for readability in summaries.
+         coef_name
+             Coefficient name. The default coefficient name is a LaTeX-like string \
+             ``"$\\beta_{f(x)}$"`` to improve readability in printed summaries.
+         noncentered
+             If ``True``, reparameterize the term to non-centered form \
+             (see :meth:`.reparam_noncentered`).
+
+         Returns
+         -------
+         A :class:`.Term` instance configured with an inverse-gamma prior on
+         the variance and an appropriate inference specification for
+         variance updates.
+
+         """
+         coef_name = coef_name or "$\\beta_{" + f"{name}" + "}$"
+         scale_name = scale_name or "$\\tau$"
+         scale = ScaleIG(
+             jnp.asarray(scale_value),
+             concentration=ig_concentration,
+             scale=ig_scale,
+             name=scale_name,
+         )

          term = cls(
              basis=basis,
@@ -101,118 +586,1015 @@ class SmoothTerm(lsl.Var):
              coef_name=coef_name,
          )

-         variance.inference = gs.MCMCSpec(
-             init_star_ig_gibbs,
-             kernel_kwargs={"coef": term.coef},
-             jitter_dist=variance_jitter_dist,
-         )
+         if noncentered:
+             term.reparam_noncentered()

          return term

+     def diagonalize_penalty(self, atol: float = 1e-6) -> Self:
+         """
+         Diagonalize the penalty via an eigenvalue decomposition.
+
+         This method computes a transformation that diagonalizes
+         the penalty matrix and updates the internal basis function such that
+         subsequent evaluations use the accordingly transformed basis. The penalty is
+         updated to the diagonalized version.
+
+         Returns
+         -------
+         The modified term instance (self).
+         """
+         self.basis.diagonalize_penalty(atol)
+         return self
+
+     def scale_penalty(self) -> Self:
+         """
+         Scale the penalty matrix by its infinity norm.
+
+         The penalty matrix is divided by its infinity norm (max absolute row
+         sum) so that its values are numerically well-conditioned for
+         downstream use. The updated penalty replaces the previous one.
+
+         Returns
+         -------
+         The modified term instance (self).
+         """
+         self.basis.scale_penalty()
+         return self
+
+     def constrain(
+         self,
+         constraint: ArrayLike
+         | Literal["sumzero_term", "sumzero_coef", "constant_and_linear"],
+     ) -> Self:
+         """
+         Apply a linear constraint to the term's basis and corresponding penalty.
+
+         Parameters
+         ----------
+         constraint
+             Type of constraint or custom linear constraint matrix to apply. \
+             If an array is supplied, the constraint will be \
+             ``A @ coef == 0``, where ``A`` is the supplied constraint matrix.
+
+         Returns
+         -------
+         The modified term instance (self).
+         """
+         self.basis.constrain(constraint)
+         self.coef.value = jnp.zeros(self.nbases)
+         return self
+
+
+ SmoothTerm = Term
+
+
+ class MRFTerm(Term):
+     _neighbors = None
+     _polygons = None
+     _ordered_labels = None
+     _labels = None
+     _mapping = None

- class LinearTerm(lsl.Var):
+     @property
+     def neighbors(self) -> dict[str, list[str]] | None:
+         return self._neighbors
+
+     @neighbors.setter
+     def neighbors(self, value: dict[str, list[str]] | None) -> None:
+         self._neighbors = value
+
+     @property
+     def polygons(self) -> dict[str, ArrayLike] | None:
+         return self._polygons
+
+     @polygons.setter
+     def polygons(self, value: dict[str, ArrayLike] | None) -> None:
+         self._polygons = value
+
+     @property
+     def labels(self) -> list[str] | None:
+         return self._labels
+
+     @labels.setter
+     def labels(self, value: list[str]) -> None:
+         self._labels = value
+
+     @property
+     def mapping(self) -> CategoryMapping:
+         if self._mapping is None:
+             raise ValueError("No mapping defined.")
+         return self._mapping
+
+     @mapping.setter
+     def mapping(self, value: CategoryMapping) -> None:
+         self._mapping = value
+
+     @property
+     def ordered_labels(self) -> list[str] | None:
+         return self._ordered_labels
+
+     @ordered_labels.setter
+     def ordered_labels(self, value: list[str]) -> None:
+         self._ordered_labels = value
+
+
+ class IndexingTerm(Term):
      def __init__(
          self,
-         x: lsl.Var | Array,
-         name: str,
-         distribution: lsl.Dist | None = None,
+         basis: Basis,
+         penalty: lsl.Var | lsl.Value | Array | None,
+         scale: ScaleIG | VarIGPrior | lsl.Var | ArrayLike | None,
+         name: str = "",
          inference: InferenceTypes = None,
-         add_intercept: bool = False,
          coef_name: str | None = None,
-         basis_name: str | None = None,
+         _update_on_init: bool = True,
+         validate_scalar_scale: bool = True,
      ):
-         coef_name = f"{name}_coef" if coef_name is None else coef_name
-         basis_name = f"B({name})" if basis_name is None else basis_name
+         if not basis.value.ndim == 1:
+             raise ValueError(f"IndexingTerm requires 1d basis, got {basis.value.ndim=}")

-         def _matrix(x):
-             x = jnp.atleast_1d(x)
-             if len(jnp.shape(x)) == 1:
-                 x = jnp.expand_dims(x, -1)
-             if add_intercept:
-                 ones = jnp.ones(x.shape[0])
-                 x = jnp.c_[ones, x]
-             return x
+         if not jnp.issubdtype(jnp.dtype(basis.value), jnp.integer):
+             raise TypeError(
+                 f"IndexingTerm requires integer basis, got {jnp.dtype(basis.value)=}."
+             )

-         if not isinstance(x, lsl.Var):
-             x = lsl.Var.new_obs(x, name=f"{name}_input")
+         super().__init__(
+             basis=basis,
+             penalty=penalty,
+             scale=scale,
+             name=name,
+             inference=inference,
+             coef_name=coef_name,
+             _update_on_init=False,
+             validate_scalar_scale=validate_scalar_scale,
+         )

-         basis = lsl.Var(lsl.TransientCalc(_matrix, x=x), name=basis_name)
-         basis.role = Roles.basis
+         # mypy warns that self.value_node might be a lsl.Node, which does not have the
+         # attribute "function".
+         # But we can safely assume that self.value_node is a lsl.Calc, which does have
+         # one.
+         self.value_node.function = lambda basis, coef: jnp.take(coef, basis)  # type: ignore
+         if _update_on_init:
+             self.coef.update()
+             self.update()

-         nbases = jnp.shape(basis.value)[-1]
+     @property
+     def full_basis(self) -> Basis:
+         nclusters = jnp.unique(self.basis.value).size
+         full_basis = Basis(
+             self.basis.x, basis_fn=jax.nn.one_hot, num_classes=nclusters, name=""
+         )
+         return full_basis

-         self.nbases = nbases
+
+ class RITerm(IndexingTerm):
+     _labels = None
+     _mapping = None
+
+     @property
+     def full_basis(self) -> Basis:
+         try:
+             nclusters = len(self.mapping.labels_to_integers_map)
+         except ValueError:
+             nclusters = jnp.unique(self.basis.value).size
+
+         full_basis = Basis(
+             self.basis.x, basis_fn=jax.nn.one_hot, num_classes=nclusters, name=""
+         )
+         return full_basis
+
+     @property
+     def labels(self) -> list[str]:
+         if self._labels is None:
+             raise ValueError("No labels defined.")
+         return self._labels
+
+     @labels.setter
+     def labels(self, value: list[str]) -> None:
+         self._labels = value
+
+     @property
+     def mapping(self) -> CategoryMapping:
+         if self._mapping is None:
+             raise ValueError("No mapping defined.")
+         return self._mapping
+
+     @mapping.setter
+     def mapping(self, value: CategoryMapping) -> None:
+         self._mapping = value
+
+
+ class BasisDot(UserVar):
+     def __init__(
+         self,
+         basis: Basis,
+         prior: lsl.Dist | None = None,
+         name: str = "",
+         inference: InferenceTypes = None,
+         coef_name: str | None = None,
+         _update_on_init: bool = True,
+     ):
          self.basis = basis
+         self.nbases = self.basis.nbases
+         coef_name = _append_name(name, "_coef") if coef_name is None else coef_name
+
          self.coef = lsl.Var.new_param(
-             jnp.zeros(nbases), distribution, inference=inference, name=coef_name
+             jnp.zeros(self.basis.nbases), prior, inference=inference, name=coef_name
+         )
+         calc = lsl.Calc(
+             lambda basis, coef: jnp.dot(basis, coef),
+             basis=self.basis,
+             coef=self.coef,
+             _update_on_init=_update_on_init,
          )
-         calc = lsl.Calc(jnp.dot, basis, self.coef)

          super().__init__(calc, name=name)
-         self.coef.role = Roles.coef_linear
-         self.role = Roles.term_linear


- class Intercept(lsl.Var):
+ class Intercept(UserVar):
      def __init__(
          self,
          name: str,
-         value: Array | float = 0.0,
+         value: ArrayLike | float = 0.0,
          distribution: lsl.Dist | None = None,
          inference: InferenceTypes = None,
      ) -> None:
          super().__init__(
-             value=value, distribution=distribution, name=name, inference=inference
+             value=jnp.asarray(value),
+             distribution=distribution,
+             name=name,
+             inference=inference,
          )
          self.parameter = True
-         self.role = Roles.intercept


- class Basis(lsl.Var):
+ def make_callback(function, output_shape, dtype, m: int = 0):
+     if len(output_shape):
+         k = output_shape[-1]
+
+     def fn(x, **basis_kwargs):
+         n = jnp.shape(jnp.atleast_1d(x))[0]
+         if len(output_shape) == 2:
+             shape = (n - m, k)
+         elif len(output_shape) == 1:
+             shape = (n - m,)
+         elif not len(output_shape):
+             shape = ()
+         else:
+             raise RuntimeError(
+                 "Return shape of 'basis_fn(value)' must"
+                 f" have <= 2 dimensions, got {output_shape}"
+             )
+         result_shape = jax.ShapeDtypeStruct(shape, dtype)
+         result = jax.pure_callback(
+             function, result_shape, x, vmap_method="sequential", **basis_kwargs
+         )
+         return result
+
+     return fn
+
+
+ def is_diagonal(M, atol=1e-12):
+     # mask for off-diagonal elements
+     off_diag_mask = ~jnp.eye(M.shape[-1], dtype=bool)
+     off_diag_values = M[off_diag_mask]
+     return jnp.all(jnp.abs(off_diag_values) < atol)
+
+
+ class Basis(UserVar):
+     """
+     General basis for a structured additive term.
+
+     The ``Basis`` class wraps either a provided observation variable or a raw
+     array and a basis-generation function. It constructs an internal
+     calculation node that produces the basis (design) matrix used by
+     smooth terms. The basis function may either be executed via a
+     callback that does not need to be jax-compatible (the default, potentially
+     slow), or used directly as a jax-compatible function that is included in
+     just-in-time compilation (when ``use_callback=False``).
+
+     Parameters
+     ----------
+     value
+         If a :class:`liesel.model.Var` or node is provided it is used as \
+         the input variable for the basis. Otherwise a raw array-like \
+         object may be supplied together with ``xname`` to create an \
+         observed variable internally.
+     basis_fn
+         Function mapping the input variable's values to a basis matrix or \
+         vector. It must accept the input array and any ``basis_kwargs`` \
+         and return an array of shape ``(n_obs, n_bases)`` (or a scalar/1-d \
+         array for simpler bases). By default this is the identity \
+         function (``lambda x: x``).
+     name
+         Optional name for the basis object. If omitted, a sensible name \
+         is constructed from the input variable's name (``B(<xname>)``).
+     xname
+         Required when ``value`` is a raw array: provides a name for the \
+         observation variable that will be created.
+     use_callback
+         If ``True`` (default) the basis_fn is wrapped in a JAX \
+         ``pure_callback`` via :func:`make_callback` to allow arbitrary \
+         Python basis functions while preserving JAX tracing. If ``False`` \
+         the function is used directly and must be jittable via JAX.
+     cache_basis
+         If ``True`` the computed basis is cached in a persistent \
+         calculation node (``lsl.Calc``), which avoids re-computation \
+         when not required, but uses memory. If ``False`` a transient \
+         calculation node (``lsl.TransientCalc``) is used and the basis \
+         will be recomputed with each evaluation of ``Basis.value``, \
+         but not stored in memory.
+     penalty
+         Penalty matrix associated with the basis. If omitted, \
+         a default identity penalty is created based on the number \
+         of basis functions.
+     **basis_kwargs
+         Additional keyword arguments forwarded to ``basis_fn``.
+
+     Raises
+     ------
+     ValueError
+         If ``value`` is an array and ``xname`` is not provided, or if
+         the created input variable has no name.
+
+     Notes
+     -----
+     The basis is evaluated once during initialization (via
+     ``self.update()``) to determine its shape and dtype. The internal
+     callback wrapper inspects the return shape to build a compatible
+     JAX ShapeDtypeStruct for the pure callback.
+
+     Attributes
+     ----------
+     role
+         The role assigned to this variable.
+     observed
+         Whether the basis is derived from an observed variable (always \
+         ``True`` for bases created from input data).
+     x
+         The input variable (observations) used to construct the basis.
+     nbases
+         Number of basis functions (number of columns in the basis matrix).
+     penalty
+         Penalty matrix (wrapped as a :class:`liesel.model.Value`) associated \
+         with the basis.
+
+     Examples
+     --------
+     Identity basis from a named variable::
+
+         import liesel.model as lsl
+         import jax.numpy as jnp
+         xvar = lsl.Var.new_obs(jnp.array([1., 2., 3.]), name='x')
+         b = Basis(value=xvar)
+     """
+
      def __init__(
          self,
-         value: lsl.Var | lsl.Node,
-         basis_fn: Callable[[Array], Array] | Callable[..., Array],
-         *args,
+         value: lsl.Var | lsl.Node | ArrayLike,
+         basis_fn: Callable[[Array], Array] | Callable[..., Array] = lambda x: x,
          name: str | None = None,
-         **kwargs,
+         xname: str | None = None,
+         use_callback: bool = True,
+         cache_basis: bool = True,
+         penalty: ArrayLike | lsl.Value | None = None,
+         **basis_kwargs,
      ) -> None:
-         try:
-             value_ar = jnp.asarray(value.value)
-         except AttributeError:
-             raise TypeError(f"{value=} should be a liesel.model.Var instance.")
-
-         dtype = value_ar.dtype
-
-         input_shape = jnp.shape(basis_fn(value_ar, *args, **kwargs))
-         if len(input_shape):
-             k = input_shape[-1]
-
-             def fn(x):
-                 n = jnp.shape(jnp.atleast_1d(x))[0]
-                 if len(input_shape) == 2:
-                     shape = (n, k)
-                 elif len(input_shape) == 1:
-                     shape = (n,)
-                 elif not len(input_shape):
-                     shape = ()
+         self._validate_xname(value, xname)
+         value_var = _ensure_var_or_node(value, xname)
+
+         if use_callback:
+             value_ar = jnp.asarray(value_var.value)
+             basis_kwargs_arr = {}
+             for key, val in basis_kwargs.items():
+                 if isinstance(val, lsl.Var | lsl.Node):
+                     basis_kwargs_arr[key] = val.value
+                 else:
+                     basis_kwargs_arr[key] = val
+             basis_ar = basis_fn(value_ar, **basis_kwargs_arr)
+             dtype = basis_ar.dtype
+             input_shape = jnp.shape(basis_ar)
+
+             # This is special-case handling for compatibility with
+             # basis functions that remove cases. For example, if you have a formulaic
+             # formula "x + lag(x)", then the resulting basis will have one case less
+             # than the original x, because the first case is dropped.
+             if value_ar.shape:
+                 p = value_ar.shape[0] if value_ar.shape else 0
+                 k = input_shape[0] if input_shape else 0
+                 m = p - k
              else:
-                 raise RuntimeError(
-                     "Return shape of 'basis_fn(value)' must"
-                     " have <= dimensions, got {input_shape}"
-                 )
-             result_shape = jax.ShapeDtypeStruct(shape, dtype)
-             result = jax.pure_callback(
-                 basis_fn, result_shape, x, *args, vmap_method="sequential", **kwargs
+                 m = 0
+
+             fn = make_callback(basis_fn, input_shape, dtype, m)
+         else:
+             fn = basis_fn
+
+         name_ = self._basis_name(value_var, name)
+
+         if cache_basis:
+             calc = lsl.Calc(
+                 fn, value_var, **basis_kwargs, _name=_append_name(name_, "_calc")
+             )
+         else:
+             calc = lsl.TransientCalc(
+                 fn, value_var, **basis_kwargs, _name=_append_name(name_, "_calc")
+             )
+
+         super().__init__(calc, name=name_)
+         self.update()
+         self.observed = True
+
+         self.x: lsl.Var | lsl.Node = value_var
+         basis_shape = jnp.shape(self.value)
+         if len(basis_shape) >= 1:
+             self.nbases: int = basis_shape[-1]
+         else:
+             self.nbases = 1  # scalar case
+
+         if isinstance(penalty, lsl.Value):
+             penalty_var = penalty
+         elif penalty is None:
+             penalty_arr = jnp.eye(self.nbases)
+             penalty_var = lsl.Value(penalty_arr)
+         else:
+             penalty_arr = jnp.asarray(penalty)
+             penalty_var = lsl.Value(penalty_arr)
+
+         self._penalty = penalty_var
+
+         self._constraint: str | None = None
+         self._reparam_matrix: Array | None = None
+
+     @property
+     def constraint(self) -> str | None:
+         return self._constraint
+
+     @property
+     def reparam_matrix(self) -> Array | None:
+         return self._reparam_matrix
+
+     def _validate_xname(self, value: lsl.Var | lsl.Node | ArrayLike, xname: str | None):
+         if isinstance(value, lsl.Var | lsl.Node) and xname is not None:
+             raise ValueError(
+                 "When supplying a variable or node to `value`, `xname` must not be "
+                 "used. Name the variable instead."
              )
-             return result

-         if not value.name:
-             raise ValueError(f"{value=} must be named.")
+     def _basis_name(self, value: lsl.Var | lsl.Node | ArrayLike, name: str | None):
+         if name is not None and name != "":
+             return name
+
+         if isinstance(value, lsl.Var | lsl.Node) and value.name == "":
+             return ""
+
+         if hasattr(value, "name"):
+             return f"B({value.name})"
+         return ""
+
+     @property
+     def penalty(self) -> lsl.Value:
+         """
+         Return the penalty matrix wrapped as a :class:`liesel.model.Value`.
+
+         Returns
+         -------
+         lsl.Value
+             Value wrapper holding the penalty (precision) matrix for this
+             basis.
+         """
+         return self._penalty
+
+     def update_penalty(self, value: ArrayLike | lsl.Value):
+         """
+         Update the penalty matrix for this basis.

-         if name is None:
-             name_ = f"B({value.name})"
+         Parameters
+         ----------
+         value
+             New penalty matrix or an already-wrapped :class:`liesel.model.Value`.
+         """
+         if isinstance(value, lsl.Value):
+             self._penalty.value = value.value
+         else:
+             penalty_arr = jnp.asarray(value)
+             self._penalty.value = penalty_arr
+
+     @classmethod
+     def new_linear(
+         cls,
+         value: lsl.Var | lsl.Node | Array,
+         name: str | None = None,
+         xname: str | None = None,
+         add_intercept: bool = False,
+     ):
+         """
+         Create a linear basis (design matrix) from input values.
+
+         Parameters
+         ----------
+         value
+             Input variable or raw array used to construct the design matrix.
+         name
+             Optional name for the basis.
+         xname
+             Name for the observation variable when ``value`` is \
+             a raw array.
+         add_intercept
+             If ``True``, adds an intercept column of ones as the first \
+             column of the design matrix.
+
+         Returns
+         -------
+         A :class:`.Basis` instance that produces a (n_obs, n_features)
+         design matrix.
+         """
+
+         def as_matrix(x):
+             x = jnp.atleast_1d(x)
+             if len(jnp.shape(x)) == 1:
+                 x = jnp.expand_dims(x, -1)
+             if add_intercept:
+                 ones = jnp.ones(x.shape[0])
+                 x = jnp.c_[ones, x]
+             return x
+
+         basis = cls(
+             value=value,
+             basis_fn=as_matrix,
+             name=name,
+             xname=xname,
+             use_callback=False,
+             cache_basis=False,
+         )
+
+         return basis
+
+     def diagonalize_penalty(self, atol: float = 1e-6) -> Self:
+         """
+         Diagonalize the penalty via an eigenvalue decomposition.
+
+         This method computes a transformation that diagonalizes
+         the penalty matrix and updates the internal basis function such that
+         subsequent evaluations use the accordingly transformed basis. The penalty is
+         updated to the diagonalized version.
+
+         Returns
+         -------
+         The modified basis instance (self).
+         """
+         assert isinstance(self.value_node, lsl.Calc)
+         basis_fn = self.value_node.function
+
+         K = self.penalty.value
+         if is_diagonal(K, atol=atol):
+             return self

-         super().__init__(lsl.Calc(fn, value, _name=name_ + "_calc"), name=name_)
+         Z = penalty_to_unit_design(K)
+
+         def reparam_basis(*args, **kwargs):
+             return basis_fn(*args, **kwargs) @ Z
+
+         self.value_node.function = reparam_basis
+         self.update()
+         penalty = Z.T @ K @ Z
+         self.update_penalty(penalty)
+
+         return self
+
+     def scale_penalty(self) -> Self:
+         """
+         Scale the penalty matrix by its infinity norm.
+
+         The penalty matrix is divided by its infinity norm (max absolute row
+         sum) so that its values are numerically well-conditioned for
+         downstream use. The updated penalty replaces the previous one.
+
+         Returns
+         -------
+         The modified basis instance (self).
+         """
+         K = self.penalty.value
+         scale = jnp.linalg.norm(K, ord=jnp.inf)
+         penalty = K / scale
+         self.update_penalty(penalty)
+         return self
+
+     def _apply_constraint(self, Z: Array) -> Self:
+         """
+         Apply a linear reparameterization to the basis using matrix Z.
+
+         This internal helper multiplies the basis functions by ``Z`` (i.e.
+         right-multiplies the design matrix) and updates the penalty to
+         reflect the change of basis: ``K_new = Z.T @ K @ Z``.
+
+         Parameters
+         ----------
+         Z
+             Transformation matrix applied to the basis functions.
+
+         Returns
+         -------
+         The modified basis instance (self).
+         """
+
+         assert isinstance(self.value_node, lsl.Calc)
+         basis_fn = self.value_node.function
+
+         K = self.penalty.value
+
+         def reparam_basis(*args, **kwargs):
+             return basis_fn(*args, **kwargs) @ Z
+
+         self.value_node.function = reparam_basis
          self.update()
-         self.role = Roles.basis
+         penalty = Z.T @ K @ Z
+         self.update_penalty(penalty)
+         return self
+
+     def constrain(
+         self,
+         constraint: ArrayLike
+         | Literal["sumzero_term", "sumzero_coef", "constant_and_linear"],
+     ) -> Self:
+         """
+         Apply a linear constraint to the basis and corresponding penalty.
+
+         Parameters
+         ----------
+         constraint
+             Type of constraint or custom linear constraint matrix to apply.
+             If an array is supplied, the constraint will be \
+             ``A @ coef == 0``, where ``A`` is the supplied constraint matrix.
+
+         Returns
+         -------
+         The modified basis instance (self).
+         """
+         if not self.value.ndim == 2:
+             raise ValueError(
+                 "Constraints can only be applied to matrix-valued bases. "
+                 f"{self} has shape {self.value.shape}"
+             )
+
+         if self.constraint is not None:
+             raise ValueError(
+                 f"A '{self.constraint}' constraint has already been applied."
+             )
+
+         if isinstance(constraint, str):
+             type_: str = constraint
+         else:
+             constraint_matrix = jnp.asarray(constraint)
+             type_ = "custom"
+
+         match type_:
+             case "sumzero_coef":
+                 Z = LinearConstraintEVD.sumzero_coef(self.nbases)
+             case "sumzero_term":
+                 Z = LinearConstraintEVD.sumzero_term(self.value)
+             case "constant_and_linear":
+                 Z = LinearConstraintEVD.constant_and_linear(self.x.value, self.value)
+             case "custom":
+                 Z = LinearConstraintEVD.general(constraint_matrix)
+
+         self._apply_constraint(Z)
+         self._constraint = type_
+         self._reparam_matrix = Z
+
+         return self
+
+
+ class MRFSpec(NamedTuple):
+     mapping: CategoryMapping
+     nb: dict[str, list[str]] | None
+     ordered_labels: list[str] | None
+
+
+ class MRFBasis(Basis):
+     _mrf_spec: MRFSpec | None = None
+
+     @property
+     def mrf_spec(self) -> MRFSpec:
+         if self._mrf_spec is None:
+             raise ValueError("No MRF spec defined.")
+         return self._mrf_spec
+
+     @mrf_spec.setter
+     def mrf_spec(self, value: MRFSpec):
+         if not isinstance(value, MRFSpec):
+             raise TypeError(
+                 f"Replacement must be of type {MRFSpec}, got {type(value)}."
+             )
+         self._mrf_spec = value
+
+
+ class LinBasis(Basis):
+     _model_spec: ModelSpec | None = None
+     _mappings: dict[str, CategoryMapping] | None = None
+     _column_names: list[str] | None = None
+
+     @property
+     def model_spec(self) -> ModelSpec:
+         if self._model_spec is None:
+             raise ValueError("No model spec defined.")
+         return self._model_spec
+
+     @model_spec.setter
+     def model_spec(self, value: ModelSpec):
+         if not isinstance(value, ModelSpec):
+             raise TypeError(
+                 f"Replacement must be of type {ModelSpec}, got {type(value)}."
+             )
+         self._model_spec = value
+
+     @property
+     def mappings(self) -> dict[str, CategoryMapping]:
+         if self._mappings is None:
+             raise ValueError("No model spec defined.")
+         return self._mappings
+
+     @mappings.setter
+     def mappings(self, value: dict[str, CategoryMapping]):
+         if not isinstance(value, dict):
+             raise TypeError(f"Replacement must be of type dict, got {type(value)}.")
+
+         for val in value.values():
+             if not isinstance(val, CategoryMapping):
+                 raise TypeError(
+                     f"The values in the replacement must be of type {CategoryMapping}, "
+                     f"got {type(val)}."
+                 )
+         self._mappings = value
+
+     @property
+     def column_names(self) -> list[str]:
+         if self._column_names is None:
+             raise ValueError("No model spec defined.")
+         return self._column_names
+
+     @column_names.setter
+     def column_names(self, value: Sequence[str]):
+         if not isinstance(value, Sequence):
+             raise TypeError(f"Replacement must be a sequence, got {type(value)}.")
+
+         for val in value:
+             if not isinstance(val, str):
+                 raise TypeError(
+                     f"The values in the replacement must be of type str, "
+                     f"got {type(val)}."
+                 )
+         self._column_names = list(value)
+
+
+ class LinTerm(BasisDot):
+     _model_spec: ModelSpec | None = None
+     _mappings: dict[str, CategoryMapping] | None = None
+     _column_names: list[str] | None = None
+
+     @property
+     def model_spec(self) -> ModelSpec | None:
+         return self._model_spec
+
+     @model_spec.setter
+     def model_spec(self, value: ModelSpec):
+         if not isinstance(value, ModelSpec):
+             raise TypeError(
+                 f"Replacement must be of type {ModelSpec}, got {type(value)}."
+             )
+         self._model_spec = value
+
+     @property
+     def mappings(self) -> dict[str, CategoryMapping]:
+         if self._mappings is None:
+             raise ValueError("No model spec defined.")
+         return self._mappings
+
+     @mappings.setter
+     def mappings(self, value: dict[str, CategoryMapping]):
+         if not isinstance(value, dict):
+             raise TypeError(f"Replacement must be of type dict, got {type(value)}.")
+
+         for val in value.values():
+             if not isinstance(val, CategoryMapping):
+                 raise TypeError(
+                     f"The values in the replacement must be of type {CategoryMapping}, "
+                     f"got {type(val)}."
+                 )
+         self._mappings = value
+
+     @property
+     def column_names(self) -> list[str]:
+         if self._column_names is None:
+             raise ValueError("No model spec defined.")
+         return self._column_names
+
+     @column_names.setter
+     def column_names(self, value: Sequence[str]):
+         if not isinstance(value, Sequence):
+             raise TypeError(f"Replacement must be a sequence, got {type(value)}.")
+
+         for val in value:
+             if not isinstance(val, str):
+                 raise TypeError(
+                     f"The values in the replacement must be of type str, "
+                     f"got {type(val)}."
+                 )
+         self._column_names = list(value)
+
+
+ class TPTerm(UserVar):
+     """
+     General anisotropic structured additive tensor product term.
+
+     A bivariate tensor product can have:
+
+     1. One scale parameter (when using ita)
+     2. Two scale parameters (when using include_main_effects)
+     3. Three scale parameters (when using common_scale and include_main_effects,
+        or adding main effects separately)
+     4. Four scale parameters (when adding main effects separately)
+
+     Option four is the most flexible one, since it also allows you to use different
+     basis dimensions for the main effects and the interaction.
+     """
+
+     def __init__(
+         self,
+         *marginals: Term | IndexingTerm | RITerm | MRFTerm,
+         common_scale: ScaleIG | lsl.Var | ArrayLike | VarIGPrior | None = None,
+         name: str = "",
+         inference: InferenceTypes = None,
+         coef_name: str | None = None,
+         basis_name: str | None = None,
+         include_main_effects: bool = False,
+         _update_on_init: bool = True,
+     ):
+         self._validate_marginals(marginals)
+         coef_name = _append_name(name, "_coef") if coef_name is None else coef_name
+         bases = self._get_bases(marginals)
+         penalties = [b.penalty.value for b in bases]
+
+         if common_scale is None:
+             scales = [t.scale for t in marginals]
+         else:
+             scales = [_init_scale_ig(common_scale) for _ in bases]
+
+         _rowwise_kron = jax.vmap(jnp.kron)
+
+         def rowwise_kron(*bases):
+             return reduce(_rowwise_kron, bases)
+
+         if basis_name is None:
+             basis_name = "B(" + ",".join(list(self._input_obs(bases))) + ")"
+
+         assert basis_name is not None
+         basis = lsl.Var.new_calc(rowwise_kron, *bases, name=basis_name)
+         nbases = jnp.shape(basis.value)[-1]
+
+         mvnds = MultivariateNormalStructured.get_locscale_constructor(
+             penalties=penalties
+         )
+
+         scales_var = lsl.Calc(lambda *x: jnp.stack(x, axis=-1), *scales)
+
+         prior = lsl.Dist(distribution=mvnds, loc=jnp.zeros(nbases), scales=scales_var)
+
+         coef = lsl.Var.new_param(
+             jnp.zeros(nbases),
+             distribution=prior,
+             inference=inference,
+             name=coef_name,
+         )
+
+         self.basis = basis
+         self.marginals = marginals
+         self.bases = bases
+         self.penalties = penalties
+         self.scales = scales
+
+         self.nbases = nbases
+         self.basis = basis
+         self.coef = coef
+         self.scale = scales_var
+         self.include_main_effects = include_main_effects
+
+         if include_main_effects:
+             calc = lsl.Calc(
+                 lambda *marginals, basis, coef: sum(marginals) + jnp.dot(basis, coef),
+                 *marginals,
+                 basis=basis,
+                 coef=self.coef,
+                 _update_on_init=_update_on_init,
+             )
+         else:
+             calc = lsl.Calc(
+                 lambda basis, coef: jnp.dot(basis, coef),
+                 basis=basis,
+                 coef=self.coef,
+                 _update_on_init=_update_on_init,
+             )
+
+         super().__init__(calc, name=name)
+         if _update_on_init:
+             self.coef.update()
+
+     @staticmethod
+     def _get_bases(
+         marginals: Sequence[Term | RITerm | MRFTerm | IndexingTerm],
+     ) -> list[Basis]:
+         bases = []
+         for t in marginals:
+             if hasattr(t, "full_basis"):
+                 bases.append(t.full_basis)
+             else:
+                 bases.append(t.basis)
+         return bases
+
+     @staticmethod
+     def _validate_marginals(marginals: Sequence[Term]):
+         for t in marginals:
+             if t.scale is None:
+                 raise ValueError(f"Invalid scale for {t}: {t.scale}")
+             try:
+                 # ignoring type here because potential errors are handled
+                 t.coef.dist_node["penalty"]  # type: ignore
+             except Exception as e:
+                 raise ValueError(f"Invalid penalty for {t}") from e
+
+         for i, b in enumerate(TPTerm._get_bases(marginals)):
+             if b.value.ndim != 2:
+                 raise ValueError(
+                     "Expected 2-dimensional basis, but the basis "
+                     f"of {marginals[i]} has shape {b.value.shape}"
+                 )
+
+     @property
+     def input_obs(self) -> dict[str, lsl.Var]:
+         return self._input_obs(self.bases)
+
+     @staticmethod
+     def _input_obs(bases: Sequence[Basis]) -> dict[str, lsl.Var]:
+         # this method includes assumptions about how the individual bases are
+         # structured: Basis.x can be a strong observed variable directly, or a
+         # calculator variable that depends on strong observed variables.
+         # If these assumptions are violated, this method may produce unexpected results.
+         # The bases created by BasisBuilder fit these assumptions.
+         _input_x = {}
+         for b in bases:
+             if isinstance(b.x, lsl.Var):
+                 if b.x.strong and b.x.observed:
+                     # case: ordinary univariate marginal basis, like ps
+                     if not b.x.name:
+                         raise ValueError(f"Observed name not found for {b}")
+                     _input_x[b.x.name] = b.x
+                 elif b.x.weak:
+                     # currently, I don't expect this case to be present
+                     # but it would make sense
+                     for xi in b.x.all_input_vars():
+                         if xi.observed:
+                             if not xi.name:
+                                 raise ValueError(f"Observed name not found for {b}")
+                             _input_x[xi.name] = xi
+
+             else:
+                 # case: potentially multivariate marginal, possibly tp,
+                 # where basis.x is a calculator that collects the strong inputs.
+                 for xj in b.x.all_input_nodes():
+                     if not isinstance(xj, lsl.Var):
+                         if xj.var is not None:
+                             if xj.var.observed:
+                                 _input_x[xj.var.name] = xj.var
+                     elif xj.observed:
+                         if not xj.name:
+                             raise ValueError(f"Observed name not found for {b}")
+                         _input_x[xj.name] = xj
+
+         return _input_x
+
+     @classmethod
+     def f(
+         cls,
+         *marginals: Term,
+         common_scale: ScaleIG | lsl.Var | ArrayLike | VarIGPrior | None = None,
+         fname: str = "ta",
+         inference: InferenceTypes = None,
+         _update_on_init: bool = True,
+     ) -> Self:
+         xnames = list(cls._input_obs(cls._get_bases(marginals)))
+         name = fname + "(" + ",".join(xnames) + ")"
+
+         coef_name = "$\\beta_{" + name + "}$"
+
+         term = cls(
+             *marginals,
+             common_scale=common_scale,
+             inference=inference,
+             coef_name=coef_name,
+             name=name,
+             basis_name=None,
+             _update_on_init=_update_on_init,
+         )
+
+         return term
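
To make the new 0.0.6a4 API concrete, here is a minimal usage sketch that only exercises classes and methods added in this diff (``Basis``, ``Term``, ``VarIGPrior``). The import path ``liesel_gam.var`` is inferred from the file name shown above, and the cubic design matrix is an illustrative stand-in for any jit-compatible function returning an ``(n_obs, n_bases)`` array; neither is prescribed by the package itself::

    import jax.numpy as jnp
    import liesel.model as lsl

    from liesel_gam.var import Basis, Term, VarIGPrior

    # Named observed input; the basis derives its default name "B(x)" from it.
    x = lsl.Var.new_obs(jnp.linspace(0.0, 1.0, 50), name="x")

    # A jit-compatible basis function (hence use_callback=False) returning an
    # (n_obs, 3) design matrix; with no penalty given, an identity penalty
    # of matching size is created.
    basis = Basis(
        x,
        basis_fn=lambda x: jnp.column_stack([x, x**2, x**3]),
        use_callback=False,
    )

    # Optional preprocessing from this diff: rescale the penalty by its
    # infinity norm, then apply a sum-to-zero constraint on the evaluated term.
    basis.scale_penalty().constrain("sumzero_term")

    # Term.f names the term "f(x)", gives the coefficients a structured
    # multivariate normal prior built from basis.penalty, and (via VarIGPrior)
    # attaches an inverse-gamma prior with a Gibbs update to the squared scale.
    term = Term.f(basis, scale=VarIGPrior(concentration=1.0, scale=0.005))

The resulting ``term`` evaluates to ``basis @ coef`` and behaves like any other ``lsl.Var`` when assembling a model; passing ``noncentered=True`` to ``Term.f`` would additionally trigger the ``reparam_noncentered`` transformation documented above.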