liesel-gam 0.0.4__py3-none-any.whl → 0.0.6a4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liesel_gam/__about__.py +1 -1
- liesel_gam/__init__.py +38 -1
- liesel_gam/builder/__init__.py +8 -0
- liesel_gam/builder/builder.py +2003 -0
- liesel_gam/builder/category_mapping.py +158 -0
- liesel_gam/builder/consolidate_bases.py +105 -0
- liesel_gam/builder/registry.py +561 -0
- liesel_gam/constraint.py +107 -0
- liesel_gam/dist.py +541 -1
- liesel_gam/kernel.py +18 -7
- liesel_gam/plots.py +946 -0
- liesel_gam/predictor.py +59 -20
- liesel_gam/var.py +1508 -126
- liesel_gam-0.0.6a4.dist-info/METADATA +559 -0
- liesel_gam-0.0.6a4.dist-info/RECORD +18 -0
- {liesel_gam-0.0.4.dist-info → liesel_gam-0.0.6a4.dist-info}/WHEEL +1 -1
- liesel_gam-0.0.4.dist-info/METADATA +0 -160
- liesel_gam-0.0.4.dist-info/RECORD +0 -11
- {liesel_gam-0.0.4.dist-info → liesel_gam-0.0.6a4.dist-info}/licenses/LICENSE +0 -0
liesel_gam/var.py
CHANGED
|
@@ -1,96 +1,581 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
|
|
4
|
-
from
|
|
3
|
+
import copy
|
|
4
|
+
from collections.abc import Callable, Sequence
|
|
5
|
+
from functools import reduce
|
|
6
|
+
from typing import Any, Literal, NamedTuple, Self
|
|
5
7
|
|
|
6
8
|
import jax
|
|
7
9
|
import jax.numpy as jnp
|
|
8
10
|
import liesel.goose as gs
|
|
9
11
|
import liesel.model as lsl
|
|
10
12
|
import tensorflow_probability.substrates.jax.distributions as tfd
|
|
13
|
+
from formulaic import ModelSpec
|
|
11
14
|
|
|
12
|
-
from .
|
|
15
|
+
from liesel_gam.builder.category_mapping import CategoryMapping
|
|
16
|
+
|
|
17
|
+
from .constraint import LinearConstraintEVD, penalty_to_unit_design
|
|
18
|
+
from .dist import MultivariateNormalSingular, MultivariateNormalStructured
|
|
13
19
|
from .kernel import init_star_ig_gibbs
|
|
14
|
-
from .roles import Roles
|
|
15
20
|
|
|
16
21
|
InferenceTypes = Any
|
|
17
|
-
Array =
|
|
22
|
+
Array = jax.Array
|
|
23
|
+
ArrayLike = jax.typing.ArrayLike
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class VarIGPrior(NamedTuple):
    """
    Hyperparameters of an inverse gamma prior, bundled as a named tuple.

    When an instance is passed as ``scale`` to :func:`_init_scale_ig`, the two
    fields are forwarded as ``concentration`` and ``scale`` to :class:`ScaleIG`.
    """

    # Concentration parameter of the inverse gamma distribution.
    concentration: float
    # Scale parameter of the inverse gamma distribution.
    scale: float
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def _append_name(name: str, append: str) -> str:
|
|
32
|
+
if name == "":
|
|
33
|
+
return ""
|
|
34
|
+
else:
|
|
35
|
+
return name + append
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def _ensure_var_or_node(
    x: lsl.Var | lsl.Node | ArrayLike,
    name: str | None,
) -> lsl.Var | lsl.Node:
    """
    Return ``x`` as a liesel variable or node.

    Variables and nodes are passed through unchanged. Anything else is converted
    with ``jnp.asarray`` and wrapped in a new observed variable; if ``name`` is
    ``None`` in that case, the empty string is used as the name.

    Raises
    ------
    ValueError
        If ``name`` is given and differs from the name of the resulting object.
    """
    already_wrapped = isinstance(x, lsl.Var | lsl.Node)
    if not already_wrapped:
        name = "" if name is None else name

    x_var = x if already_wrapped else lsl.Var.new_obs(jnp.asarray(x), name=name)

    # A requested name must agree with the name the object actually carries.
    if name is None or x_var.name == name:
        return x_var

    raise ValueError(f"{x_var.name=} and {name=} are incompatible.")
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def _ensure_value(
    x: lsl.Var | lsl.Node | ArrayLike,
    name: str | None,
) -> lsl.Var | lsl.Node:
    """
    Return ``x`` as a liesel variable or node.

    Variables and nodes are passed through unchanged. Anything else is converted
    with ``jnp.asarray`` and wrapped in a new ``lsl.Value`` node; if ``name`` is
    ``None`` in that case, the empty string is used as the name.

    Raises
    ------
    ValueError
        If ``name`` is given and differs from the name of the resulting object.
    """
    already_wrapped = isinstance(x, lsl.Var | lsl.Node)
    if not already_wrapped:
        name = "" if name is None else name

    x_var = x if already_wrapped else lsl.Value(jnp.asarray(x), _name=name)

    # A requested name must agree with the name the object actually carries.
    if name is None or x_var.name == name:
        return x_var

    raise ValueError(f"{x_var.name=} and {name=} are incompatible.")
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
class UserVar(lsl.Var):
    """
    Base class for the variable types defined in this module.

    The alternative constructors ``new_calc``, ``new_obs``, ``new_param`` and
    ``new_value`` are overridden here so that calling them on a subclass raises
    ``NotImplementedError`` instead of building a variable.
    """

    @classmethod
    def new_calc(cls, *args, **kwargs) -> None:  # type: ignore
        """Always raises ``NotImplementedError``."""
        message = f"This constructor is not implemented on {cls.__name__}."
        raise NotImplementedError(message)

    @classmethod
    def new_obs(cls, *args, **kwargs) -> None:  # type: ignore
        """Always raises ``NotImplementedError``."""
        message = f"This constructor is not implemented on {cls.__name__}."
        raise NotImplementedError(message)

    @classmethod
    def new_param(cls, *args, **kwargs) -> None:  # type: ignore
        """Always raises ``NotImplementedError``."""
        message = f"This constructor is not implemented on {cls.__name__}."
        raise NotImplementedError(message)

    @classmethod
    def new_value(cls, *args, **kwargs) -> None:  # type: ignore
        """Always raises ``NotImplementedError``."""
        message = f"This constructor is not implemented on {cls.__name__}."
        raise NotImplementedError(message)
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def mvn_diag_prior(scale: lsl.Var) -> lsl.Dist:
    """
    Build a normal prior with ``loc=0.0`` and the given ``scale``.

    The distribution is a ``tfd.Normal`` wrapped in a ``lsl.Dist``.
    """
    zero_mean_normal = lsl.Dist(tfd.Normal, loc=0.0, scale=scale)
    return zero_mean_normal
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def mvn_structured_prior(scale: lsl.Var, penalty: lsl.Var | lsl.Value) -> lsl.Dist:
    """
    Build a zero-mean structured multivariate normal prior.

    The prior is a :class:`MultivariateNormalSingular` wrapped in a ``lsl.Dist``.

    Parameters
    ----------
    scale
        Scale variable of the prior.
    penalty
        Penalty matrix, wrapped in a variable or value node. If it is a
        ``lsl.Var``, it must be strong.

    Raises
    ------
    NotImplementedError
        If ``penalty`` is a weak ``lsl.Var``.
    """
    if isinstance(penalty, lsl.Var) and not penalty.strong:
        raise NotImplementedError(
            "Varying penalties are currently not supported by this function."
        )

    # The rank is computed once, from the penalty value available right now.
    penalty_rank = jnp.linalg.matrix_rank(penalty.value)

    return lsl.Dist(
        MultivariateNormalSingular,
        loc=0.0,
        scale=scale,
        penalty=penalty,
        penalty_rank=penalty_rank,
    )
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def term_prior(
    scale: lsl.Var | Array | None,
    penalty: lsl.Var | lsl.Value | Array | None,
) -> lsl.Dist | None:
    """
    Select the coefficient prior that matches ``scale`` and ``penalty``.

    Returns
    -------
    - ``None`` if ``scale`` is ``None`` (``penalty`` must then be ``None``, too).
    - The result of :func:`mvn_diag_prior` if ``penalty`` is ``None``.
    - The result of :func:`mvn_structured_prior` otherwise.

    Raises
    ------
    ValueError
        If ``scale`` is ``None`` while a penalty is supplied.
    """
    if scale is None and penalty is not None:
        raise ValueError(f"If {scale=}, then penalty must also be None.")
    if scale is None:
        return None

    # Raw arrays are wrapped so that the prior can be built from liesel objects.
    scale_var = scale if isinstance(scale, lsl.Var | lsl.Value) else lsl.Var(scale)

    if penalty is None:
        return mvn_diag_prior(scale_var)

    if isinstance(penalty, lsl.Var | lsl.Value):
        penalty_node = penalty
    else:
        penalty_node = lsl.Value(penalty)

    return mvn_structured_prior(scale_var, penalty_node)
|
|
146
|
+
|
|
18
147
|
|
|
148
|
+
class ScaleIG(UserVar):
    """
    A variable with an Inverse Gamma prior on its square.

    The variance parameter (i.e. the squared scale) is flagged as a parameter.
    This variable itself is computed as the square root of that parameter.

    Parameters
    ----------
    value
        Initial value. It must have size 1. Note that it is assigned to the \
        variance parameter (the square), so this variable evaluates to \
        ``sqrt(value)``.
    concentration
        Concentration parameter of the inverse gamma distribution.\
        In some parameterizations, this parameter is called ``a``.
    scale
        Scale parameter of the inverse gamma distribution.\
        In some parameterizations, this parameter is called ``b``.
    name
        Name of the variable.
    variance_name
        Name of the variance parameter. If empty, the name is derived from \
        ``name`` by appending ``"_square"``.
    inference
        Inference type, passed on to the variance parameter.

    Raises
    ------
    ValueError
        If ``value`` does not have size 1.
    """

    def __init__(
        self,
        value: float | Array,
        concentration: float | lsl.Var | lsl.Node | ArrayLike,
        scale: float | lsl.Var | lsl.Node | ArrayLike,
        name: str = "",
        variance_name: str = "",
        inference: InferenceTypes = None,
    ):
        value = jnp.asarray(value)
        if value.size != 1:
            raise ValueError(
                f"Expected scalar value for ScaleIG, got size {value.size}."
            )

        # Hyperparameters given as plain numbers/arrays become value nodes whose
        # names are derived from ``name`` (and stay empty if ``name`` is empty).
        concentration_node = _ensure_value(
            concentration, name=_append_name(name, "_concentration")
        )
        scale_node = _ensure_value(scale, name=_append_name(name, "_scale"))

        prior = lsl.Dist(
            tfd.InverseGamma, concentration=concentration_node, scale=scale_node
        )

        variance_name = variance_name or _append_name(name, "_square")

        # The inverse gamma prior sits on the variance; the scale is its root.
        self._variance_param = lsl.Var.new_param(
            value, prior, inference=inference, name=variance_name
        )
        super().__init__(lsl.Calc(jnp.sqrt, self._variance_param), name=name)

    def setup_gibbs_inference(
        self, coef: lsl.Var, penalty: jax.typing.ArrayLike | None = None
    ) -> ScaleIG:
        """
        Attach a Gibbs kernel specification to the variance parameter.

        Parameters
        ----------
        coef
            Coefficient variable passed to the kernel initializer.
        penalty
            Optional penalty passed to the kernel initializer.

        Returns
        -------
        This variable (self), with ``inference`` of the variance parameter \
        replaced by a ``gs.MCMCSpec``.
        """
        # NOTE: ``copy.copy`` returns plain functions unchanged, so if
        # ``init_star_ig_gibbs`` is a plain function, the rename below also
        # applies to the module-level object.
        init_gibbs = copy.copy(init_star_ig_gibbs)
        init_gibbs.__name__ = "StarVarianceGibbs"

        # Pass the renamed callable prepared above instead of leaving it unused.
        self._variance_param.inference = gs.MCMCSpec(
            init_gibbs,
            kernel_kwargs={"coef": coef, "scale": self, "penalty": penalty},
        )
        return self
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
def _init_scale_ig(
    x: ScaleIG | VarIGPrior | lsl.Var | ArrayLike | None,
    validate_scalar: bool = False,
) -> ScaleIG | lsl.Var | None:
    """
    Normalize a user-supplied scale specification.

    - A :class:`VarIGPrior` becomes a new :class:`ScaleIG` with initial value 1.0.
    - A :class:`ScaleIG` or ``lsl.Var`` is returned as is, after converting its
      (strong) value to a jax array.
    - Any other non-``None`` input is wrapped with ``lsl.Var.new_value``.
    - ``None`` is returned as ``None``.

    Parameters
    ----------
    x
        The scale specification.
    validate_scalar
        If ``True``, raise a ``ValueError`` unless the relevant values have size 1.
    """
    if isinstance(x, VarIGPrior):
        concentration = jnp.asarray(x.concentration)
        scale_ = jnp.asarray(x.scale)

        if validate_scalar:
            if not concentration.size == 1:
                raise ValueError(
                    "Expected scalar hyperparameter 'concentration', "
                    f"got size {concentration.size}"
                )

            if not scale_.size == 1:
                raise ValueError(
                    f"Expected scalar hyperparameter 'scale', got size {scale_.size}"
                )

        scale_var: ScaleIG | lsl.Var | None = ScaleIG(
            value=jnp.array(1.0),
            concentration=concentration,
            scale=scale_,
        )
    elif isinstance(x, ScaleIG | lsl.Var):
        # Make sure strong values are stored as jax arrays.
        if isinstance(x, ScaleIG):
            if x._variance_param.strong:
                x._variance_param.value = jnp.asarray(x._variance_param.value)
                x.update()
        elif x.strong:
            x.value = jnp.asarray(x.value)

        scale_var = x
        if validate_scalar:
            size = jnp.asarray(scale_var.value).size
            if not size == 1:
                raise ValueError(f"Expected scalar scale, got size {size}")
    elif x is not None:
        scale_var = lsl.Var.new_value(jnp.asarray(x))
        if validate_scalar:
            size = scale_var.value.size
            if not size == 1:
                raise ValueError(f"Expected scalar scale, got size {size}")
    elif x is None:
        scale_var = x
    else:
        # NOTE(review): unreachable, because the two preceding branches cover
        # both ``x is not None`` and ``x is None``.
        raise TypeError(f"Unexpected type for scale: {type(x)}")

    return scale_var
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
def _validate_scalar_or_p_scale(scale_value: Array, p):
|
|
267
|
+
try:
|
|
268
|
+
is_scalar = scale_value.size == 1
|
|
269
|
+
except AttributeError:
|
|
270
|
+
raise TypeError(
|
|
271
|
+
f"Expected scale value to be an array, got type {type(scale_value)}"
|
|
272
|
+
)
|
|
273
|
+
is_p = scale_value.size == p
|
|
274
|
+
if not (is_scalar or is_p):
|
|
275
|
+
raise ValueError(
|
|
276
|
+
f"Expected scale to have size 1 or {p}, got size {scale_value.size}"
|
|
277
|
+
)
|
|
278
|
+
|
|
279
|
+
|
|
280
|
+
class Term(UserVar):
|
|
281
|
+
"""
|
|
282
|
+
General structured additive term.
|
|
283
|
+
|
|
284
|
+
A structured additive term represents a smooth or structured effect in a
|
|
285
|
+
generalized additive model. The term wraps a design/basis matrix together
|
|
286
|
+
with a prior/penalty and a set of coefficients. The object exposes the
|
|
287
|
+
coefficient variable and evaluates the term as the matrix-vector product
|
|
288
|
+
of the basis and the coefficients.
|
|
289
|
+
The term evaluates to ``basis @ coef``.
|
|
290
|
+
|
|
291
|
+
Parameters
|
|
292
|
+
----------
|
|
293
|
+
basis
|
|
294
|
+
A :class:`.Basis` instance that produces the design matrix for the \
|
|
295
|
+
term. The basis must evaluate to a 2-D array with shape ``(n_obs, n_bases)``.
|
|
296
|
+
penalty
|
|
297
|
+
Penalty matrix or a variable/value wrapping the penalty \
|
|
298
|
+
used to construct the multivariate normal prior for the coefficients.
|
|
299
|
+
scale
|
|
300
|
+
Scale parameter for the prior on the coefficients. This \
|
|
301
|
+
is typically either a scalar or a per-coefficient scale variable.
|
|
302
|
+
name
|
|
303
|
+
Human-readable name for the term. Used for labelling variables and \
|
|
304
|
+
building sensible default names for internal nodes.
|
|
305
|
+
inference
|
|
306
|
+
:class:`liesel.goose.MCMCSpec` inference specification forwarded to coefficient\
|
|
307
|
+
creation.
|
|
308
|
+
coef_name
|
|
309
|
+
Name for the coefficient variable. If ``None``, a default name based \
|
|
310
|
+
on ``name`` will be used.
|
|
311
|
+
_update_on_init
|
|
312
|
+
If ``True`` (default) the internal calculation/graph nodes are \
|
|
313
|
+
evaluated during initialization. Set to ``False`` to delay \
|
|
314
|
+
initial evaluation.
|
|
315
|
+
|
|
316
|
+
Raises
|
|
317
|
+
------
|
|
318
|
+
ValueError
|
|
319
|
+
If ``basis.value`` does not have two dimensions.
|
|
320
|
+
|
|
321
|
+
Attributes
|
|
322
|
+
----------
|
|
323
|
+
scale
|
|
324
|
+
The scale variable used by the prior on the coefficients.
|
|
325
|
+
nbases
|
|
326
|
+
Number of basis functions (number of columns in the basis matrix).
|
|
327
|
+
basis
|
|
328
|
+
The basis object provided to the constructor.
|
|
329
|
+
coef
|
|
330
|
+
The coefficient variable created for this term. It holds the prior
|
|
331
|
+
(multivariate normal singular) and is used in the evaluation of the
|
|
332
|
+
term.
|
|
333
|
+
is_noncentered
|
|
334
|
+
Whether the term has been reparameterized to the non-centered form.
|
|
335
|
+
|
|
336
|
+
"""
|
|
337
|
+
|
|
338
|
+
def __init__(
    self,
    basis: Basis,
    penalty: lsl.Var | lsl.Value | Array | None,
    scale: ScaleIG | VarIGPrior | lsl.Var | ArrayLike | None,
    name: str = "",
    inference: InferenceTypes = None,
    coef_name: str | None = None,
    _update_on_init: bool = True,
    validate_scalar_scale: bool = True,
):
    """
    Set up the coefficient variable, its prior, and the ``basis @ coef`` node.

    ``validate_scalar_scale`` is forwarded to ``_init_scale_ig`` as
    ``validate_scalar``; the remaining parameters are described in the class
    docstring.
    """
    # Normalize the scale specification into a variable (or None).
    scale = _init_scale_ig(scale, validate_scalar=validate_scalar_scale)

    coef_name = _append_name(name, "_coef") if coef_name is None else coef_name

    prior = term_prior(scale, penalty)

    # Must be assigned before ``self.nbases`` is used below.
    self.basis = basis

    # The number of coefficients follows the last penalty dimension if a
    # penalty is given, and the number of basis columns otherwise.
    if isinstance(penalty, lsl.Var | lsl.Value):
        nparam = jnp.shape(penalty.value)[-1]
    elif penalty is not None:
        nparam = jnp.shape(penalty)[-1]
    else:
        nparam = self.nbases

    if scale is not None:
        _validate_scalar_or_p_scale(scale.value, nparam)
    self.coef = lsl.Var.new_param(
        jnp.zeros(nparam), prior, inference=inference, name=coef_name
    )
    calc = lsl.Calc(
        lambda basis, coef: jnp.dot(basis, coef),
        basis=basis,
        coef=self.coef,
        _update_on_init=_update_on_init,
    )
    self._scale = scale

    super().__init__(calc, name=name)
    if _update_on_init:
        self.coef.update()

    self.is_noncentered = False

    # Scale objects offering ``setup_gibbs_inference`` (such as ScaleIG) get a
    # Gibbs kernel specification that refers to this term's coefficients.
    if hasattr(self.scale, "setup_gibbs_inference"):
        try:
            self.scale.setup_gibbs_inference(self.coef)  # type: ignore
        except Exception as e:
            raise RuntimeError(f"Failed to setup Gibbs kernel for {self}") from e
|
|
388
|
+
|
|
389
|
+
@property
def nbases(self) -> int:
    """Size of the last axis of ``self.basis.value``."""
    basis_shape = jnp.shape(self.basis.value)
    return basis_shape[-1]
|
|
392
|
+
|
|
393
|
+
@property
def scale(self) -> lsl.Var | lsl.Node | None:
    """The scale object stored in ``self._scale``; may be ``None``."""
    return self._scale
|
|
396
|
+
|
|
397
|
+
def reparam_noncentered(self) -> Self:
    """
    Turns this term into noncentered form, which means the prior for
    the coefficient will be turned from ``coef ~ N(0, scale^2 * inv(penalty))`` into
    ``latent_coef ~ N(0, inv(penalty)); coef = scale * latent_coef``.
    This can sometimes be helpful when sampling with the No-U-Turn Sampler.

    Calling the method on a term that is already noncentered returns the term
    unchanged.

    Returns
    -------
    The modified term instance (self).

    Raises
    ------
    ValueError
        If the term has no scale.
    RuntimeError
        If setting up the Gibbs kernel on the scale fails.
    """
    if self.scale is None:
        raise ValueError(
            f"Noncentering reparameterization of {self} fails, "
            f"because {self.scale=}."
        )
    if self.is_noncentered:
        return self

    assert self.coef.dist_node is not None

    # The prior of the (now latent) coefficient gets a fixed scale of 1.
    self.coef.dist_node["scale"] = lsl.Value(jnp.array(1.0))

    if self.scale.name and self.coef.name:
        scaled_name = self.scale.name + "*" + self.coef.name
    else:
        scaled_name = _append_name(self.coef.name, "_scaled")

    scaled_coef = lsl.Var.new_calc(
        lambda scale, coef: scale * coef,
        self.scale,
        self.coef,
        name=scaled_name,
    )

    # The term now multiplies the basis with the scaled coefficient.
    self.value_node["coef"] = scaled_coef
    self.coef.update()
    self.update()
    self.is_noncentered = True

    # Re-register the Gibbs kernel so that it refers to the scaled coefficient
    # and receives the penalty value explicitly.
    if hasattr(self.scale, "setup_gibbs_inference"):
        try:
            pen = self.coef.dist_node["penalty"].value
            self.scale.setup_gibbs_inference(scaled_coef, penalty=pen)  # type: ignore
        except Exception as e:
            raise RuntimeError(f"Failed to setup Gibbs kernel for {self}") from e

    return self
|
|
58
441
|
|
|
59
442
|
@classmethod
def f(
    cls,
    basis: Basis,
    fname: str = "f",
    scale: ScaleIG | lsl.Var | ArrayLike | VarIGPrior | None = None,
    inference: InferenceTypes = None,
    coef_name: str | None = None,
    noncentered: bool = False,
) -> Self:
    """
    Construct a term from a :class:`.Basis`, naming it after the basis input.

    The term is named ``fname(basis.x.name)``, for example ``f(x)``. If a
    ``scale`` is supplied, ``basis.penalty`` is used as the penalty; without a
    scale, the term is created with neither penalty nor scale.

    Parameters
    ----------
    basis
        Basis object for the term. Both the basis and its input variable \
        ``basis.x`` must be named.
    fname
        Function-name prefix used when constructing the term name.
    scale
        Scale specification passed on to the term constructor.
    inference
        Inference specification passed on to the term constructor.
    coef_name
        Coefficient name. If it is ``None`` or empty, the LaTeX-like string \
        ``"$\\beta_{f(x)}$"`` (with the actual term name) is used.
    noncentered
        If ``True``, :meth:`.reparam_noncentered` is called on the term before \
        it is returned, and scalar validation of the scale is switched off.

    Returns
    -------
    The newly created term.

    Raises
    ------
    ValueError
        If ``basis.x`` or ``basis`` has an empty name.
    """
    if not basis.x.name:
        raise ValueError("basis.x must be named.")
    if not basis.name:
        raise ValueError("basis must be named.")

    term_name = f"{fname}({basis.x.name})"
    if not coef_name:
        coef_name = "$\\beta_{" + term_name + "}$"

    # Without a scale there is no prior, so no penalty is passed either.
    penalty = None if scale is None else basis.penalty

    term = cls(
        basis=basis,
        penalty=penalty,
        scale=scale,
        inference=inference,
        coef_name=coef_name,
        name=term_name,
        validate_scalar_scale=not noncentered,
    )
    if noncentered:
        term.reparam_noncentered()
    return term
|
|
509
|
+
|
|
510
|
+
@classmethod
|
|
511
|
+
def new_ig(
|
|
512
|
+
cls,
|
|
513
|
+
basis: Basis,
|
|
514
|
+
penalty: lsl.Var | lsl.Value | Array | None,
|
|
515
|
+
name: str,
|
|
516
|
+
ig_concentration: float = 1.0,
|
|
517
|
+
ig_scale: float = 0.005,
|
|
518
|
+
inference: InferenceTypes = None,
|
|
519
|
+
scale_value: float = 100.0,
|
|
520
|
+
scale_name: str | None = None,
|
|
521
|
+
coef_name: str | None = None,
|
|
522
|
+
noncentered: bool = False,
|
|
523
|
+
) -> Term:
|
|
524
|
+
"""
|
|
525
|
+
Construct a smooth term with an inverse-gamma prior on the variance.
|
|
526
|
+
|
|
527
|
+
This convenience constructor creates a term similar to :meth:`.f` but
|
|
528
|
+
sets up an explicit variance parameter with an Inverse-Gamma prior.
|
|
529
|
+
A scale variable is set up by taking the square-root, and the
|
|
530
|
+
coefficient prior uses the derived ``scale`` together with the basis
|
|
531
|
+
penalty. By default a Gibbs-style initialization is attached to the
|
|
532
|
+
variance inference via an internal kernel; an optional jitter
|
|
533
|
+
distribution can be provided for MCMC initialization.
|
|
534
|
+
|
|
535
|
+
Parameters
|
|
536
|
+
----------
|
|
537
|
+
basis
|
|
538
|
+
Basis object providing the design matrix and penalty.
|
|
539
|
+
name
|
|
540
|
+
Term name.
|
|
541
|
+
penalty
|
|
542
|
+
Penalty matrix or a variable/value wrapping the penalty \
|
|
543
|
+
used to construct the multivariate normal prior for the coefficients.
|
|
544
|
+
ig_concentration
|
|
545
|
+
Concentration (shape) parameter of the Inverse-Gamma prior for the \
|
|
546
|
+
variance.
|
|
547
|
+
ig_scale
|
|
548
|
+
Scale parameter of the Inverse-Gamma prior for the variance.
|
|
549
|
+
inference
|
|
550
|
+
Inference specification forwarded to the coefficient variable \
|
|
551
|
+
creation, a :class:`liesel.goose.MCMCSpec`.
|
|
552
|
+
variance_value
|
|
553
|
+
Initial value for the variance parameter.
|
|
554
|
+
variance_name
|
|
555
|
+
Variance parameter name. The default is a LaTeX-like representation \
|
|
556
|
+
``"$\\tau^2_{...}$"`` for readability in summaries.
|
|
557
|
+
coef_name
|
|
558
|
+
Coefficient name. The default coefficient name is a LaTeX-like string \
|
|
559
|
+
``"$\\beta_{f(x)}$"`` to improve readability in printed summaries.
|
|
560
|
+
noncentered
|
|
561
|
+
If ``True``, reparameterize the term to non-centered form \
|
|
562
|
+
(see :meth:`.reparam_noncentered`).
|
|
563
|
+
|
|
564
|
+
Returns
|
|
565
|
+
-------
|
|
566
|
+
A :class:`.Term` instance configured with an inverse-gamma prior on
|
|
567
|
+
the variance and an appropriate inference specification for
|
|
568
|
+
variance updates.
|
|
569
|
+
|
|
570
|
+
"""
|
|
571
|
+
coef_name = coef_name or "$\\beta_{" + f"{name}" + "}$"
|
|
572
|
+
scale_name = scale_name or "$\\tau$"
|
|
573
|
+
scale = ScaleIG(
|
|
574
|
+
jnp.asarray(scale_value),
|
|
575
|
+
concentration=ig_concentration,
|
|
576
|
+
scale=ig_scale,
|
|
577
|
+
name=scale_name,
|
|
578
|
+
)
|
|
94
579
|
|
|
95
580
|
term = cls(
|
|
96
581
|
basis=basis,
|
|
@@ -101,118 +586,1015 @@ class SmoothTerm(lsl.Var):
|
|
|
101
586
|
coef_name=coef_name,
|
|
102
587
|
)
|
|
103
588
|
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
kernel_kwargs={"coef": term.coef},
|
|
107
|
-
jitter_dist=variance_jitter_dist,
|
|
108
|
-
)
|
|
589
|
+
if noncentered:
|
|
590
|
+
term.reparam_noncentered()
|
|
109
591
|
|
|
110
592
|
return term
|
|
111
593
|
|
|
594
|
+
def diagonalize_penalty(self, atol: float = 1e-6) -> Self:
    """
    Diagonalize the penalty via an eigenvalue decomposition.

    This method computes a transformation that diagonalizes
    the penalty matrix and updates the internal basis function such that
    subsequent evaluations use the accordingly transformed basis. The penalty is
    updated to the diagonalized version.

    The work is delegated to ``self.basis.diagonalize_penalty``.

    Parameters
    ----------
    atol
        Tolerance, passed on unchanged to ``self.basis.diagonalize_penalty``.

    Returns
    -------
    The modified term instance (self).
    """
    self.basis.diagonalize_penalty(atol)
    return self
|
|
609
|
+
|
|
610
|
+
def scale_penalty(self) -> Self:
    """
    Scale the penalty matrix by its infinity norm.

    The penalty matrix is divided by its infinity norm (max absolute row
    sum) so that its values are numerically well-conditioned for
    downstream use. The updated penalty replaces the previous one.

    The work is delegated to ``self.basis.scale_penalty``.

    Returns
    -------
    The modified term instance (self).
    """
    self.basis.scale_penalty()
    return self
|
|
624
|
+
|
|
625
|
+
def constrain(
    self,
    constraint: ArrayLike
    | Literal["sumzero_term", "sumzero_coef", "constant_and_linear"],
) -> Self:
    """
    Apply a linear constraint to the term's basis and corresponding penalty.

    The constraint itself is applied by ``self.basis.constrain``; afterwards the
    coefficient value is reset to zeros of length ``self.nbases``.

    Parameters
    ----------
    constraint
        Type of constraint or custom linear constraint matrix to apply. \
        If an array is supplied, the constraint will be \
        ``A @ coef == 0``, where ``A`` is the supplied constraint matrix.

    Returns
    -------
    The modified term instance (self).
    """
    self.basis.constrain(constraint)
    # Presumably the constraint changes the number of basis columns, so the
    # coefficient vector is re-created with the current size — TODO confirm.
    self.coef.value = jnp.zeros(self.nbases)
    return self
|
|
647
|
+
|
|
648
|
+
|
|
649
|
+
# Alias for ``Term``. NOTE(review): presumably kept so that code written
# against the earlier ``SmoothTerm`` class name keeps working — confirm.
SmoothTerm = Term
|
|
650
|
+
|
|
651
|
+
|
|
652
|
+
class MRFTerm(Term):
    """
    A :class:`Term` that can additionally carry neighbors, polygons, labels,
    ordered labels and a category mapping as optional metadata.

    All metadata attributes default to ``None`` and are exposed through
    read/write properties.
    """

    # Class-level defaults; the property setters below override them per instance.
    _neighbors = None
    _polygons = None
    _ordered_labels = None
    _labels = None
    _mapping = None

    @property
    def neighbors(self) -> dict[str, list[str]] | None:
        """Stored neighbors dictionary, or ``None`` if not set."""
        return self._neighbors

    @neighbors.setter
    def neighbors(self, value: dict[str, list[str]] | None) -> None:
        self._neighbors = value

    @property
    def polygons(self) -> dict[str, ArrayLike] | None:
        """Stored polygons dictionary, or ``None`` if not set."""
        return self._polygons

    @polygons.setter
    def polygons(self, value: dict[str, ArrayLike] | None) -> None:
        self._polygons = value

    @property
    def labels(self) -> list[str] | None:
        """Stored labels, or ``None`` if not set."""
        return self._labels

    @labels.setter
    def labels(self, value: list[str]) -> None:
        self._labels = value

    @property
    def mapping(self) -> CategoryMapping:
        """
        Stored category mapping.

        Raises
        ------
        ValueError
            If no mapping has been set.
        """
        if self._mapping is None:
            raise ValueError("No mapping defined.")
        return self._mapping

    @mapping.setter
    def mapping(self, value: CategoryMapping) -> None:
        self._mapping = value

    @property
    def ordered_labels(self) -> list[str] | None:
        """Stored ordered labels, or ``None`` if not set."""
        return self._ordered_labels

    @ordered_labels.setter
    def ordered_labels(self, value: list[str]) -> None:
        self._ordered_labels = value
|
|
700
|
+
|
|
701
|
+
|
|
702
|
+
class IndexingTerm(Term):
    """
    A :class:`Term` whose value is obtained by indexing into the coefficients.

    The basis must be a one-dimensional integer array. Instead of
    ``basis @ coef``, the term evaluates to ``jnp.take(coef, basis)``.

    Raises
    ------
    ValueError
        If ``basis.value`` is not one-dimensional.
    TypeError
        If ``basis.value`` does not have an integer dtype.
    """

    def __init__(
        self,
        basis: Basis,
        penalty: lsl.Var | lsl.Value | Array | None,
        scale: ScaleIG | VarIGPrior | lsl.Var | ArrayLike | None,
        name: str = "",
        inference: InferenceTypes = None,
        coef_name: str | None = None,
        _update_on_init: bool = True,
        validate_scalar_scale: bool = True,
    ):
        if not basis.value.ndim == 1:
            raise ValueError(f"IndexingTerm requires 1d basis, got {basis.value.ndim=}")

        if not jnp.issubdtype(jnp.dtype(basis.value), jnp.integer):
            raise TypeError(
                f"IndexingTerm requires integer basis, got {jnp.dtype(basis.value)=}."
            )

        # The parent is initialized without evaluating its nodes, because the
        # calculation function is swapped out right afterwards.
        super().__init__(
            basis=basis,
            penalty=penalty,
            scale=scale,
            name=name,
            inference=inference,
            coef_name=coef_name,
            _update_on_init=False,
            validate_scalar_scale=validate_scalar_scale,
        )

        # mypy warns that self.value_node might be a lsl.Node, which does not have the
        # attribute "function".
        # But we can assume safely that self.value_node is a lsl.Calc, which does have
        # one.
        self.value_node.function = lambda basis, coef: jnp.take(coef, basis)  # type: ignore
        if _update_on_init:
            self.coef.update()
            self.update()

    @property
    def full_basis(self) -> Basis:
        """
        A new, unnamed :class:`Basis` on ``self.basis.x`` that uses
        ``jax.nn.one_hot`` as basis function. The number of classes is the number
        of unique values in ``self.basis.value``.
        """
        nclusters = jnp.unique(self.basis.value).size
        full_basis = Basis(
            self.basis.x, basis_fn=jax.nn.one_hot, num_classes=nclusters, name=""
        )
        return full_basis
|
|
143
749
|
|
|
144
|
-
|
|
750
|
+
|
|
751
|
+
class RITerm(IndexingTerm):
    """
    An :class:`IndexingTerm` that can carry labels and a category mapping.

    Reading ``labels`` or ``mapping`` before they have been set raises a
    ``ValueError``.
    """

    # Class-level defaults; the property setters below override them per instance.
    _labels = None
    _mapping = None

    @property
    def full_basis(self) -> Basis:
        """
        A new, unnamed :class:`Basis` on ``self.basis.x`` that uses
        ``jax.nn.one_hot`` as basis function.

        The number of classes is taken from the category mapping if one is set;
        otherwise it is the number of unique values in ``self.basis.value``.
        """
        try:
            nclusters = len(self.mapping.labels_to_integers_map)
        except ValueError:
            # No mapping defined: fall back to counting the observed values.
            nclusters = jnp.unique(self.basis.value).size

        full_basis = Basis(
            self.basis.x, basis_fn=jax.nn.one_hot, num_classes=nclusters, name=""
        )
        return full_basis

    @property
    def labels(self) -> list[str]:
        """
        Stored labels.

        Raises
        ------
        ValueError
            If no labels have been set.
        """
        if self._labels is None:
            raise ValueError("No labels defined.")
        return self._labels

    @labels.setter
    def labels(self, value: list[str]) -> None:
        self._labels = value

    @property
    def mapping(self) -> CategoryMapping:
        """
        Stored category mapping.

        Raises
        ------
        ValueError
            If no mapping has been set.
        """
        if self._mapping is None:
            raise ValueError("No mapping defined.")
        return self._mapping

    @mapping.setter
    def mapping(self, value: CategoryMapping) -> None:
        self._mapping = value
|
|
786
|
+
|
|
787
|
+
|
|
788
|
+
class BasisDot(UserVar):
    """
    Variable whose value is ``jnp.dot(basis, coef)``.

    Parameters
    ----------
    basis
        Basis providing the design matrix; its ``nbases`` attribute fixes the
        length of the coefficient vector.
    prior
        Optional prior passed on to the coefficient parameter.
    name
        Name of the resulting variable.
    inference
        Inference specification passed on to the coefficient parameter.
    coef_name
        Name of the coefficient parameter. If ``None``, it is derived from
        ``name`` with the suffix ``"_coef"``.
    _update_on_init
        Forwarded to the internal :class:`liesel.model.Calc`.
    """

    def __init__(
        self,
        basis: Basis,
        prior: lsl.Dist | None = None,
        name: str = "",
        inference: InferenceTypes = None,
        coef_name: str | None = None,
        _update_on_init: bool = True,
    ):
        if coef_name is None:
            coef_name = _append_name(name, "_coef")

        self.basis = basis
        self.nbases = basis.nbases

        # The coefficient vector starts at zero, one entry per basis column.
        start_value = jnp.zeros(basis.nbases)
        self.coef = lsl.Var.new_param(
            start_value, prior, inference=inference, name=coef_name
        )

        def _dot(basis, coef):
            return jnp.dot(basis, coef)

        value_calc = lsl.Calc(
            _dot,
            basis=self.basis,
            coef=self.coef,
            _update_on_init=_update_on_init,
        )

        super().__init__(value_calc, name=name)
|
|
152
|
-
self.coef.role = Roles.coef_linear
|
|
153
|
-
self.role = Roles.term_linear
|
|
154
813
|
|
|
155
814
|
|
|
156
|
-
class Intercept(
|
|
815
|
+
class Intercept(UserVar):
    """
    Intercept variable flagged as a model parameter.

    Parameters
    ----------
    name
        Name of the variable.
    value
        Initial value; converted with ``jnp.asarray`` before it is stored.
    distribution
        Optional prior distribution.
    inference
        Inference specification for the variable.
    """

    def __init__(
        self,
        name: str,
        value: ArrayLike | float = 0.0,
        distribution: lsl.Dist | None = None,
        inference: InferenceTypes = None,
    ) -> None:
        start_value = jnp.asarray(value)
        super().__init__(
            value=start_value,
            distribution=distribution,
            name=name,
            inference=inference,
        )
        self.parameter = True
|
|
168
|
-
self.role = Roles.intercept
|
|
169
830
|
|
|
170
831
|
|
|
171
|
-
|
|
832
|
+
def make_callback(function, output_shape, dtype, m: int = 0):
|
|
833
|
+
if len(output_shape):
|
|
834
|
+
k = output_shape[-1]
|
|
835
|
+
|
|
836
|
+
def fn(x, **basis_kwargs):
|
|
837
|
+
n = jnp.shape(jnp.atleast_1d(x))[0]
|
|
838
|
+
if len(output_shape) == 2:
|
|
839
|
+
shape = (n - m, k)
|
|
840
|
+
elif len(output_shape) == 1:
|
|
841
|
+
shape = (n - m,)
|
|
842
|
+
elif not len(output_shape):
|
|
843
|
+
shape = ()
|
|
844
|
+
else:
|
|
845
|
+
raise RuntimeError(
|
|
846
|
+
"Return shape of 'basis_fn(value)' must"
|
|
847
|
+
f" have <= 2 dimensions, got {output_shape}"
|
|
848
|
+
)
|
|
849
|
+
result_shape = jax.ShapeDtypeStruct(shape, dtype)
|
|
850
|
+
result = jax.pure_callback(
|
|
851
|
+
function, result_shape, x, vmap_method="sequential", **basis_kwargs
|
|
852
|
+
)
|
|
853
|
+
return result
|
|
854
|
+
|
|
855
|
+
return fn
|
|
856
|
+
|
|
857
|
+
|
|
858
|
+
def is_diagonal(M, atol=1e-12):
    """
    Check whether all off-diagonal entries of ``M`` are smaller than ``atol``.

    Parameters
    ----------
    M
        Matrix (or array-like) to check. Leading batch dimensions are allowed;
        the result is then a single boolean covering all matrices.
    atol
        Absolute tolerance; an off-diagonal entry counts as zero if its
        absolute value is strictly smaller than ``atol``.

    Returns
    -------
    Scalar boolean jax array.
    """
    M = jnp.asarray(M)
    # Compare every entry and let the diagonal always pass. Unlike boolean-mask
    # indexing this keeps static shapes, so the function can be jit-compiled.
    on_diagonal = jnp.eye(M.shape[-2], M.shape[-1], dtype=bool)
    return jnp.all(on_diagonal | (jnp.abs(M) < atol))
|
|
863
|
+
|
|
864
|
+
|
|
865
|
+
class Basis(UserVar):
|
|
866
|
+
"""
|
|
867
|
+
General basis for a structured additive term.
|
|
868
|
+
|
|
869
|
+
The ``Basis`` class wraps either a provided observation variable or a raw
|
|
870
|
+
array and a basis-generation function. It constructs an internal
|
|
871
|
+
calculation node that produces the basis (design) matrix used by
|
|
872
|
+
smooth terms. The basis function may be executed via a
|
|
873
|
+
callback that does not need to be jax-compatible (the default, potentially slow)
|
|
874
|
+
with a jax-compatible function that is included in just-in-time-compilation
|
|
875
|
+
(when ``use_callback=False``).
|
|
876
|
+
|
|
877
|
+
Parameters
|
|
878
|
+
----------
|
|
879
|
+
value
|
|
880
|
+
If a :class:`liesel.model.Var` or node is provided it is used as \
|
|
881
|
+
the input variable for the basis. Otherwise a raw array-like \
|
|
882
|
+
object may be supplied together with ``xname`` to create an \
|
|
883
|
+
observed variable internally.
|
|
884
|
+
basis_fn
|
|
885
|
+
Function mapping the input variable's values to a basis matrix or \
|
|
886
|
+
vector. It must accept the input array and any ``basis_kwargs`` \
|
|
887
|
+
and return an array of shape ``(n_obs, n_bases)`` (or a scalar/1-d \
|
|
888
|
+
array for simpler bases). By default this is the identity \
|
|
889
|
+
function (``lambda x: x``).
|
|
890
|
+
name
|
|
891
|
+
Optional name for the basis object. If omitted, a sensible name \
|
|
892
|
+
is constructed from the input variable's name (``B(<xname>)``).
|
|
893
|
+
xname
|
|
894
|
+
Required when ``value`` is a raw array: provides a name for the \
|
|
895
|
+
observation variable that will be created.
|
|
896
|
+
use_callback
|
|
897
|
+
If ``True`` (default) the basis_fn is wrapped in a JAX \
|
|
898
|
+
``pure_callback`` via :func:`make_callback` to allow arbitrary \
|
|
899
|
+
Python basis functions while preserving JAX tracing. If ``False`` \
|
|
900
|
+
the function is used directly and must be jittable via JAX.
|
|
901
|
+
cache_basis
|
|
902
|
+
If ``True`` the computed basis is cached in a persistent \
|
|
903
|
+
calculation node (``lsl.Calc``), which avoids re-computation \
|
|
904
|
+
when not required, but uses memory. If ``False`` a transient \
|
|
905
|
+
calculation node (``lsl.TransientCalc``) is used and the basis \
|
|
906
|
+
will be recomputed with each evaluation of ``Basis.value``, \
|
|
907
|
+
but not stored in memory.
|
|
908
|
+
penalty
|
|
909
|
+
Penalty matrix associated with the basis. If omitted, \
|
|
910
|
+
a default identity penalty is created based on the number \
|
|
911
|
+
of basis functions.
|
|
912
|
+
**basis_kwargs
|
|
913
|
+
Additional keyword arguments forwarded to ``basis_fn``.
|
|
914
|
+
|
|
915
|
+
Raises
|
|
916
|
+
------
|
|
917
|
+
ValueError
|
|
918
|
+
If ``value`` is an array and ``xname`` is not provided, or if
|
|
919
|
+
the created input variable has no name.
|
|
920
|
+
|
|
921
|
+
Notes
|
|
922
|
+
-----
|
|
923
|
+
The basis is evaluated once during initialization (via
|
|
924
|
+
``self.update()``) to determine its shape and dtype. The internal
|
|
925
|
+
callback wrapper inspects the return shape to build a compatible
|
|
926
|
+
JAX ShapeDtypeStruct for the pure callback.
|
|
927
|
+
|
|
928
|
+
Attributes
|
|
929
|
+
----------
|
|
930
|
+
role
|
|
931
|
+
The role assigned to this variable.
|
|
932
|
+
observed
|
|
933
|
+
Whether the basis is derived from an observed variable (always \
|
|
934
|
+
``True`` for bases created from input data).
|
|
935
|
+
x
|
|
936
|
+
The input variable (observations) used to construct the basis.
|
|
937
|
+
nbases
|
|
938
|
+
Number of basis functions (number of columns in the basis matrix).
|
|
939
|
+
penalty
|
|
940
|
+
Penalty matrix (wrapped as a :class:`liesel.model.Value`) associated \
|
|
941
|
+
with the basis.
|
|
942
|
+
|
|
943
|
+
Examples
|
|
944
|
+
--------
|
|
945
|
+
Identity basis from a named variable::
|
|
946
|
+
|
|
947
|
+
import liesel.model as lsl
|
|
948
|
+
import jax.numpy as jnp
|
|
949
|
+
xvar = lsl.Var.new_obs(jnp.array([1.,2.,3.]), name='x')
|
|
950
|
+
b = Basis(value=xvar)
|
|
951
|
+
"""
|
|
952
|
+
|
|
172
953
|
def __init__(
|
|
173
954
|
self,
|
|
174
|
-
value: lsl.Var | lsl.Node,
|
|
175
|
-
basis_fn: Callable[[Array], Array] | Callable[..., Array],
|
|
176
|
-
*args,
|
|
955
|
+
value: lsl.Var | lsl.Node | ArrayLike,
|
|
956
|
+
basis_fn: Callable[[Array], Array] | Callable[..., Array] = lambda x: x,
|
|
177
957
|
name: str | None = None,
|
|
178
|
-
|
|
958
|
+
xname: str | None = None,
|
|
959
|
+
use_callback: bool = True,
|
|
960
|
+
cache_basis: bool = True,
|
|
961
|
+
penalty: ArrayLike | lsl.Value | None = None,
|
|
962
|
+
**basis_kwargs,
|
|
179
963
|
) -> None:
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
964
|
+
self._validate_xname(value, xname)
|
|
965
|
+
value_var = _ensure_var_or_node(value, xname)
|
|
966
|
+
|
|
967
|
+
if use_callback:
|
|
968
|
+
value_ar = jnp.asarray(value_var.value)
|
|
969
|
+
basis_kwargs_arr = {}
|
|
970
|
+
for key, val in basis_kwargs.items():
|
|
971
|
+
if isinstance(val, lsl.Var | lsl.Node):
|
|
972
|
+
basis_kwargs_arr[key] = val.value
|
|
973
|
+
else:
|
|
974
|
+
basis_kwargs_arr[key] = val
|
|
975
|
+
basis_ar = basis_fn(value_ar, **basis_kwargs_arr)
|
|
976
|
+
dtype = basis_ar.dtype
|
|
977
|
+
input_shape = jnp.shape(basis_ar)
|
|
978
|
+
|
|
979
|
+
# This is special-case handling for compatibility with
|
|
980
|
+
# basis functions that remove cases. For example, if you have a formulaic
|
|
981
|
+
# formula "x + lag(x)", then the resulting basis will have one case less
|
|
982
|
+
# than the original x, because the first case is dropped.
|
|
983
|
+
if value_ar.shape:
|
|
984
|
+
p = value_ar.shape[0] if value_ar.shape else 0
|
|
985
|
+
k = input_shape[0] if input_shape else 0
|
|
986
|
+
m = p - k
|
|
199
987
|
else:
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
988
|
+
m = 0
|
|
989
|
+
|
|
990
|
+
fn = make_callback(basis_fn, input_shape, dtype, m)
|
|
991
|
+
else:
|
|
992
|
+
fn = basis_fn
|
|
993
|
+
|
|
994
|
+
name_ = self._basis_name(value_var, name)
|
|
995
|
+
|
|
996
|
+
if cache_basis:
|
|
997
|
+
calc = lsl.Calc(
|
|
998
|
+
fn, value_var, **basis_kwargs, _name=_append_name(name_, "_calc")
|
|
999
|
+
)
|
|
1000
|
+
else:
|
|
1001
|
+
calc = lsl.TransientCalc(
|
|
1002
|
+
fn, value_var, **basis_kwargs, _name=_append_name(name_, "_calc")
|
|
1003
|
+
)
|
|
1004
|
+
|
|
1005
|
+
super().__init__(calc, name=name_)
|
|
1006
|
+
self.update()
|
|
1007
|
+
self.observed = True
|
|
1008
|
+
|
|
1009
|
+
self.x: lsl.Var | lsl.Node = value_var
|
|
1010
|
+
basis_shape = jnp.shape(self.value)
|
|
1011
|
+
if len(basis_shape) >= 1:
|
|
1012
|
+
self.nbases: int = basis_shape[-1]
|
|
1013
|
+
else:
|
|
1014
|
+
self.nbases = 1 # scalar case
|
|
1015
|
+
|
|
1016
|
+
if isinstance(penalty, lsl.Value):
|
|
1017
|
+
penalty_var = penalty
|
|
1018
|
+
elif penalty is None:
|
|
1019
|
+
penalty_arr = jnp.eye(self.nbases)
|
|
1020
|
+
penalty_var = lsl.Value(penalty_arr)
|
|
1021
|
+
else:
|
|
1022
|
+
penalty_arr = jnp.asarray(penalty)
|
|
1023
|
+
penalty_var = lsl.Value(penalty_arr)
|
|
1024
|
+
|
|
1025
|
+
self._penalty = penalty_var
|
|
1026
|
+
|
|
1027
|
+
self._constraint: str | None = None
|
|
1028
|
+
self._reparam_matrix: Array | None = None
|
|
1029
|
+
|
|
1030
|
+
@property
|
|
1031
|
+
def constraint(self) -> str | None:
|
|
1032
|
+
return self._constraint
|
|
1033
|
+
|
|
1034
|
+
@property
|
|
1035
|
+
def reparam_matrix(self) -> Array | None:
|
|
1036
|
+
return self._reparam_matrix
|
|
1037
|
+
|
|
1038
|
+
def _validate_xname(self, value: lsl.Var | lsl.Node | ArrayLike, xname: str | None):
|
|
1039
|
+
if isinstance(value, lsl.Var | lsl.Node) and xname is not None:
|
|
1040
|
+
raise ValueError(
|
|
1041
|
+
"When supplying a variable or node to `value`, `xname` must not be "
|
|
1042
|
+
"used. Name the variable instead."
|
|
207
1043
|
)
|
|
208
|
-
return result
|
|
209
1044
|
|
|
210
|
-
|
|
211
|
-
|
|
1045
|
+
def _basis_name(self, value: lsl.Var | lsl.Node | ArrayLike, name: str | None):
|
|
1046
|
+
if name is not None and name != "":
|
|
1047
|
+
return name
|
|
1048
|
+
|
|
1049
|
+
if isinstance(value, lsl.Var | lsl.Node) and value.name == "":
|
|
1050
|
+
return ""
|
|
1051
|
+
|
|
1052
|
+
if hasattr(value, "name"):
|
|
1053
|
+
return f"B({value.name})"
|
|
1054
|
+
return ""
|
|
1055
|
+
|
|
1056
|
+
@property
|
|
1057
|
+
def penalty(self) -> lsl.Value:
|
|
1058
|
+
"""
|
|
1059
|
+
Return the penalty matrix wrapped as a :class:`liesel.model.Value`.
|
|
1060
|
+
|
|
1061
|
+
Returns
|
|
1062
|
+
-------
|
|
1063
|
+
lsl.Value
|
|
1064
|
+
Value wrapper holding the penalty (precision) matrix for this
|
|
1065
|
+
basis.
|
|
1066
|
+
"""
|
|
1067
|
+
return self._penalty
|
|
1068
|
+
|
|
1069
|
+
def update_penalty(self, value: ArrayLike | lsl.Value):
|
|
1070
|
+
"""
|
|
1071
|
+
Update the penalty matrix for this basis.
|
|
212
1072
|
|
|
213
|
-
|
|
214
|
-
|
|
1073
|
+
Parameters
|
|
1074
|
+
----------
|
|
1075
|
+
value
|
|
1076
|
+
New penalty matrix or an already-wrapped :class:`liesel.model.Value`.
|
|
1077
|
+
"""
|
|
1078
|
+
if isinstance(value, lsl.Value):
|
|
1079
|
+
self._penalty.value = value.value
|
|
1080
|
+
else:
|
|
1081
|
+
penalty_arr = jnp.asarray(value)
|
|
1082
|
+
self._penalty.value = penalty_arr
|
|
1083
|
+
|
|
1084
|
+
@classmethod
|
|
1085
|
+
def new_linear(
|
|
1086
|
+
cls,
|
|
1087
|
+
value: lsl.Var | lsl.Node | Array,
|
|
1088
|
+
name: str | None = None,
|
|
1089
|
+
xname: str | None = None,
|
|
1090
|
+
add_intercept: bool = False,
|
|
1091
|
+
):
|
|
1092
|
+
"""
|
|
1093
|
+
Create a linear basis (design matrix) from input values.
|
|
1094
|
+
|
|
1095
|
+
Parameters
|
|
1096
|
+
----------
|
|
1097
|
+
value
|
|
1098
|
+
Input variable or raw array used to construct the design matrix.
|
|
1099
|
+
name
|
|
1100
|
+
Optional name for the basis.
|
|
1101
|
+
xname
|
|
1102
|
+
Name for the observation variable when ``value`` is \
|
|
1103
|
+
a raw array.
|
|
1104
|
+
add_intercept
|
|
1105
|
+
If ``True``, adds an intercept column of ones as the first \
|
|
1106
|
+
column of the design matrix.
|
|
1107
|
+
|
|
1108
|
+
Returns
|
|
1109
|
+
-------
|
|
1110
|
+
A :class:`.Basis` instance that produces a (n_obs, n_features)
|
|
1111
|
+
design matrix.
|
|
1112
|
+
"""
|
|
1113
|
+
|
|
1114
|
+
def as_matrix(x):
|
|
1115
|
+
x = jnp.atleast_1d(x)
|
|
1116
|
+
if len(jnp.shape(x)) == 1:
|
|
1117
|
+
x = jnp.expand_dims(x, -1)
|
|
1118
|
+
if add_intercept:
|
|
1119
|
+
ones = jnp.ones(x.shape[0])
|
|
1120
|
+
x = jnp.c_[ones, x]
|
|
1121
|
+
return x
|
|
1122
|
+
|
|
1123
|
+
basis = cls(
|
|
1124
|
+
value=value,
|
|
1125
|
+
basis_fn=as_matrix,
|
|
1126
|
+
name=name,
|
|
1127
|
+
xname=xname,
|
|
1128
|
+
use_callback=False,
|
|
1129
|
+
cache_basis=False,
|
|
1130
|
+
)
|
|
1131
|
+
|
|
1132
|
+
return basis
|
|
1133
|
+
|
|
1134
|
+
def diagonalize_penalty(self, atol: float = 1e-6) -> Self:
|
|
1135
|
+
"""
|
|
1136
|
+
Diagonalize the penalty via an eigenvalue decomposition.
|
|
1137
|
+
|
|
1138
|
+
This method computes a transformation that diagonalizes
|
|
1139
|
+
the penalty matrix and updates the internal basis function such that
|
|
1140
|
+
subsequent evaluations use the accordingly transformed basis. The penalty is
|
|
1141
|
+
updated to the diagonalized version.
|
|
1142
|
+
|
|
1143
|
+
Returns
|
|
1144
|
+
-------
|
|
1145
|
+
The modified basis instance (self).
|
|
1146
|
+
"""
|
|
1147
|
+
assert isinstance(self.value_node, lsl.Calc)
|
|
1148
|
+
basis_fn = self.value_node.function
|
|
1149
|
+
|
|
1150
|
+
K = self.penalty.value
|
|
1151
|
+
if is_diagonal(K, atol=atol):
|
|
1152
|
+
return self
|
|
215
1153
|
|
|
216
|
-
|
|
1154
|
+
Z = penalty_to_unit_design(K)
|
|
1155
|
+
|
|
1156
|
+
def reparam_basis(*args, **kwargs):
|
|
1157
|
+
return basis_fn(*args, **kwargs) @ Z
|
|
1158
|
+
|
|
1159
|
+
self.value_node.function = reparam_basis
|
|
1160
|
+
self.update()
|
|
1161
|
+
penalty = Z.T @ K @ Z
|
|
1162
|
+
self.update_penalty(penalty)
|
|
1163
|
+
|
|
1164
|
+
return self
|
|
1165
|
+
|
|
1166
|
+
def scale_penalty(self) -> Self:
|
|
1167
|
+
"""
|
|
1168
|
+
Scale the penalty matrix by its infinite norm.
|
|
1169
|
+
|
|
1170
|
+
The penalty matrix is divided by its infinity norm (max absolute row
|
|
1171
|
+
sum) so that its values are numerically well-conditioned for
|
|
1172
|
+
downstream use. The updated penalty replaces the previous one.
|
|
1173
|
+
|
|
1174
|
+
Returns
|
|
1175
|
+
-------
|
|
1176
|
+
The modified basis instance (self).
|
|
1177
|
+
"""
|
|
1178
|
+
K = self.penalty.value
|
|
1179
|
+
scale = jnp.linalg.norm(K, ord=jnp.inf)
|
|
1180
|
+
penalty = K / scale
|
|
1181
|
+
self.update_penalty(penalty)
|
|
1182
|
+
return self
|
|
1183
|
+
|
|
1184
|
+
def _apply_constraint(self, Z: Array) -> Self:
|
|
1185
|
+
"""
|
|
1186
|
+
Apply a linear reparameterisation to the basis using matrix Z.
|
|
1187
|
+
|
|
1188
|
+
This internal helper multiplies the basis functions by ``Z`` (i.e.
|
|
1189
|
+
right-multiplies the design matrix) and updates the penalty to
|
|
1190
|
+
reflect the change of basis: ``K_new = Z.T @ K @ Z``.
|
|
1191
|
+
|
|
1192
|
+
Parameters
|
|
1193
|
+
----------
|
|
1194
|
+
Z
|
|
1195
|
+
Transformation matrix applied to the basis functions.
|
|
1196
|
+
|
|
1197
|
+
Returns
|
|
1198
|
+
-------
|
|
1199
|
+
The modified basis instance (self).
|
|
1200
|
+
"""
|
|
1201
|
+
|
|
1202
|
+
assert isinstance(self.value_node, lsl.Calc)
|
|
1203
|
+
basis_fn = self.value_node.function
|
|
1204
|
+
|
|
1205
|
+
K = self.penalty.value
|
|
1206
|
+
|
|
1207
|
+
def reparam_basis(*args, **kwargs):
|
|
1208
|
+
return basis_fn(*args, **kwargs) @ Z
|
|
1209
|
+
|
|
1210
|
+
self.value_node.function = reparam_basis
|
|
217
1211
|
self.update()
|
|
218
|
-
|
|
1212
|
+
penalty = Z.T @ K @ Z
|
|
1213
|
+
self.update_penalty(penalty)
|
|
1214
|
+
return self
|
|
1215
|
+
|
|
1216
|
+
def constrain(
|
|
1217
|
+
self,
|
|
1218
|
+
constraint: ArrayLike
|
|
1219
|
+
| Literal["sumzero_term", "sumzero_coef", "constant_and_linear"],
|
|
1220
|
+
) -> Self:
|
|
1221
|
+
"""
|
|
1222
|
+
Apply a linear constraint to the basis and corresponding penalty.
|
|
1223
|
+
|
|
1224
|
+
Parameters
|
|
1225
|
+
----------
|
|
1226
|
+
constraint
|
|
1227
|
+
Type of constraint or custom linear constraint matrix to apply.
|
|
1228
|
+
If an array is supplied, the constraint will be \
|
|
1229
|
+
``A @ coef == 0``, where ``A`` is the supplied constraint matrix.
|
|
1230
|
+
|
|
1231
|
+
Returns
|
|
1232
|
+
-------
|
|
1233
|
+
The modified basis instance (self).
|
|
1234
|
+
"""
|
|
1235
|
+
if not self.value.ndim == 2:
|
|
1236
|
+
raise ValueError(
|
|
1237
|
+
"Constraints can only be applied to matrix-valued bases. "
|
|
1238
|
+
f"{self} has shape {self.value.shape}"
|
|
1239
|
+
)
|
|
1240
|
+
|
|
1241
|
+
if self.constraint is not None:
|
|
1242
|
+
raise ValueError(
|
|
1243
|
+
f"A '{self.constraint}' constraint has already been applied."
|
|
1244
|
+
)
|
|
1245
|
+
|
|
1246
|
+
if isinstance(constraint, str):
|
|
1247
|
+
type_: str = constraint
|
|
1248
|
+
else:
|
|
1249
|
+
constraint_matrix = jnp.asarray(constraint)
|
|
1250
|
+
type_ = "custom"
|
|
1251
|
+
|
|
1252
|
+
match type_:
|
|
1253
|
+
case "sumzero_coef":
|
|
1254
|
+
Z = LinearConstraintEVD.sumzero_coef(self.nbases)
|
|
1255
|
+
case "sumzero_term":
|
|
1256
|
+
Z = LinearConstraintEVD.sumzero_term(self.value)
|
|
1257
|
+
case "constant_and_linear":
|
|
1258
|
+
Z = LinearConstraintEVD.constant_and_linear(self.x.value, self.value)
|
|
1259
|
+
case "custom":
|
|
1260
|
+
Z = LinearConstraintEVD.general(constraint_matrix)
|
|
1261
|
+
|
|
1262
|
+
self._apply_constraint(Z)
|
|
1263
|
+
self._constraint = type_
|
|
1264
|
+
self._reparam_matrix = Z
|
|
1265
|
+
|
|
1266
|
+
return self
|
|
1267
|
+
|
|
1268
|
+
|
|
1269
|
+
class MRFSpec(NamedTuple):
    """Immutable record of the specification stored on an :class:`MRFBasis`."""

    # Mapping between category labels and their integer codes.
    mapping: CategoryMapping
    # Label -> list of labels; presumably the neighbourhood structure (confirm
    # against the code that builds the spec). May be ``None``.
    nb: dict[str, list[str]] | None
    # Labels in a fixed order, or ``None``.
    ordered_labels: list[str] | None
|
|
1273
|
+
|
|
1274
|
+
|
|
1275
|
+
class MRFBasis(Basis):
    """Basis that can additionally carry an :class:`MRFSpec`."""

    # Stays ``None`` until a spec is assigned through the setter.
    _mrf_spec: MRFSpec | None = None

    @property
    def mrf_spec(self) -> MRFSpec:
        """The attached spec; raises ``ValueError`` if none has been set."""
        spec = self._mrf_spec
        if spec is not None:
            return spec
        raise ValueError("No MRF spec defined.")

    @mrf_spec.setter
    def mrf_spec(self, value: MRFSpec):
        if isinstance(value, MRFSpec):
            self._mrf_spec = value
            return
        raise TypeError(f"Replacement must be of type {MRFSpec}, got {type(value)}.")
|
|
1291
|
+
|
|
1292
|
+
|
|
1293
|
+
class LinBasis(Basis):
    """
    Basis that can additionally carry formula metadata.

    The metadata consists of a :class:`formulaic.ModelSpec`, a dictionary of
    category mappings and the list of column names. Each item is optional;
    reading an item that has not been set raises a :class:`ValueError`.
    """

    _model_spec: ModelSpec | None = None
    _mappings: dict[str, CategoryMapping] | None = None
    _column_names: list[str] | None = None

    @property
    def model_spec(self) -> ModelSpec:
        """The attached model spec; raises ``ValueError`` if unset."""
        if self._model_spec is None:
            raise ValueError("No model spec defined.")
        return self._model_spec

    @model_spec.setter
    def model_spec(self, value: ModelSpec):
        if not isinstance(value, ModelSpec):
            raise TypeError(
                f"Replacement must be of type {ModelSpec}, got {type(value)}."
            )
        self._model_spec = value

    @property
    def mappings(self) -> dict[str, CategoryMapping]:
        """The attached category mappings; raises ``ValueError`` if unset."""
        if self._mappings is None:
            raise ValueError("No mappings defined.")
        return self._mappings

    @mappings.setter
    def mappings(self, value: dict[str, CategoryMapping]):
        if not isinstance(value, dict):
            raise TypeError(f"Replacement must be of type dict, got {type(value)}.")

        for val in value.values():
            if not isinstance(val, CategoryMapping):
                raise TypeError(
                    f"The values in the replacement must be of type {CategoryMapping}, "
                    f"got {type(val)}."
                )
        self._mappings = value

    @property
    def column_names(self) -> list[str]:
        """The attached column names; raises ``ValueError`` if unset."""
        if self._column_names is None:
            raise ValueError("No column names defined.")
        return self._column_names

    @column_names.setter
    def column_names(self, value: Sequence[str]):
        # A bare string is itself a sequence of strings and would otherwise be
        # split into single characters.
        if isinstance(value, str):
            raise TypeError(
                "Replacement must be a sequence of strings, got a single str."
            )
        if not isinstance(value, Sequence):
            raise TypeError(f"Replacement must be a sequence, got {type(value)}.")

        for val in value:
            if not isinstance(val, str):
                raise TypeError(
                    f"The values in the replacement must be of type str, "
                    f"got {type(val)}."
                )
        self._column_names = list(value)
|
|
1349
|
+
|
|
1350
|
+
|
|
1351
|
+
class LinTerm(BasisDot):
    """
    Basis-dot term that can additionally carry formula metadata.

    The metadata consists of a :class:`formulaic.ModelSpec`, a dictionary of
    category mappings and the list of column names. ``model_spec`` returns
    ``None`` while unset; ``mappings`` and ``column_names`` raise a
    :class:`ValueError` while unset.
    """

    _model_spec: ModelSpec | None = None
    _mappings: dict[str, CategoryMapping] | None = None
    _column_names: list[str] | None = None

    @property
    def model_spec(self) -> ModelSpec | None:
        """The attached model spec, or ``None`` if unset."""
        return self._model_spec

    @model_spec.setter
    def model_spec(self, value: ModelSpec):
        if not isinstance(value, ModelSpec):
            raise TypeError(
                f"Replacement must be of type {ModelSpec}, got {type(value)}."
            )
        self._model_spec = value

    @property
    def mappings(self) -> dict[str, CategoryMapping]:
        """The attached category mappings; raises ``ValueError`` if unset."""
        if self._mappings is None:
            raise ValueError("No mappings defined.")
        return self._mappings

    @mappings.setter
    def mappings(self, value: dict[str, CategoryMapping]):
        if not isinstance(value, dict):
            raise TypeError(f"Replacement must be of type dict, got {type(value)}.")

        for val in value.values():
            if not isinstance(val, CategoryMapping):
                raise TypeError(
                    f"The values in the replacement must be of type {CategoryMapping}, "
                    f"got {type(val)}."
                )
        self._mappings = value

    @property
    def column_names(self) -> list[str]:
        """The attached column names; raises ``ValueError`` if unset."""
        if self._column_names is None:
            raise ValueError("No column names defined.")
        return self._column_names

    @column_names.setter
    def column_names(self, value: Sequence[str]):
        # A bare string is itself a sequence of strings and would otherwise be
        # split into single characters.
        if isinstance(value, str):
            raise TypeError(
                "Replacement must be a sequence of strings, got a single str."
            )
        if not isinstance(value, Sequence):
            raise TypeError(f"Replacement must be a sequence, got {type(value)}.")

        for val in value:
            if not isinstance(val, str):
                raise TypeError(
                    f"The values in the replacement must be of type str, "
                    f"got {type(val)}."
                )
        self._column_names = list(value)
|
|
1405
|
+
|
|
1406
|
+
|
|
1407
|
+
class TPTerm(UserVar):
|
|
1408
|
+
"""
|
|
1409
|
+
General anisotropic structured additive tensor product term.
|
|
1410
|
+
|
|
1411
|
+
A bivariate tensor product can have:
|
|
1412
|
+
|
|
1413
|
+
1. One scale parameter (when using ita)
|
|
1414
|
+
2. Two scale parameters (when using include_main_effects)
|
|
1415
|
+
3. Three scale parameters (when using common_scale and include_main_effects,
|
|
1416
|
+
or adding main effects separately)
|
|
1417
|
+
4. Four scale parameters (when adding main effects separately)
|
|
1418
|
+
|
|
1419
|
+
Option four is the most flexible one, since it also allows you to use different
|
|
1420
|
+
basis dimensions for the main effects and the interaction.
|
|
1421
|
+
"""
|
|
1422
|
+
|
|
1423
|
+
def __init__(
    self,
    *marginals: Term | IndexingTerm | RITerm | MRFTerm,
    common_scale: ScaleIG | lsl.Var | ArrayLike | VarIGPrior | None = None,
    name: str = "",
    inference: InferenceTypes = None,
    coef_name: str | None = None,
    basis_name: str | None = None,
    include_main_effects: bool = False,
    _update_on_init: bool = True,
):
    """
    Build the tensor product term from its marginal terms.

    Parameters
    ----------
    *marginals
        Marginal terms. Their bases (``full_basis`` where available) are
        combined by a row-wise Kronecker product.
    common_scale
        If ``None``, the scales of the marginals are reused. Otherwise one
        scale per marginal basis is initialized from this value.
    name
        Name of the term.
    inference
        Inference specification for the coefficient parameter.
    coef_name
        Name of the coefficient; defaults to ``name`` with suffix ``"_coef"``.
    basis_name
        Name of the tensor product basis; defaults to ``B(<x1>,<x2>,...)``
        built from the names of the observed inputs.
    include_main_effects
        If ``True``, the values of the marginals are added to the value of
        the interaction.
    _update_on_init
        Forwarded to the value calculator; if ``True`` the coefficient is
        also updated after initialization.
    """
    self._validate_marginals(marginals)
    if coef_name is None:
        coef_name = _append_name(name, "_coef")

    bases = self._get_bases(marginals)
    penalties = [b.penalty.value for b in bases]
    scales = (
        [t.scale for t in marginals]
        if common_scale is None
        else [_init_scale_ig(common_scale) for _ in bases]
    )

    pairwise_kron = jax.vmap(jnp.kron)

    def rowwise_kron(*bases):
        # Fold the row-wise Kronecker product over all marginal bases.
        return reduce(pairwise_kron, bases)

    if basis_name is None:
        basis_name = f"B({','.join(self._input_obs(bases))})"

    basis = lsl.Var.new_calc(rowwise_kron, *bases, name=basis_name)
    nbases = jnp.shape(basis.value)[-1]

    mvnds = MultivariateNormalStructured.get_locscale_constructor(
        penalties=penalties
    )

    def _stack_scales(*x):
        return jnp.stack(x, axis=-1)

    scales_var = lsl.Calc(_stack_scales, *scales)
    prior = lsl.Dist(distribution=mvnds, loc=jnp.zeros(nbases), scales=scales_var)
    coef = lsl.Var.new_param(
        jnp.zeros(nbases),
        distribution=prior,
        inference=inference,
        name=coef_name,
    )

    self.basis = basis
    self.marginals = marginals
    self.bases = bases
    self.penalties = penalties
    self.scales = scales
    self.nbases = nbases
    self.coef = coef
    self.scale = scales_var
    self.include_main_effects = include_main_effects

    if include_main_effects:

        def _value(*marginals, basis, coef):
            return sum(marginals) + jnp.dot(basis, coef)

        positional = marginals
    else:

        def _value(basis, coef):
            return jnp.dot(basis, coef)

        positional = ()

    calc = lsl.Calc(
        _value,
        *positional,
        basis=basis,
        coef=self.coef,
        _update_on_init=_update_on_init,
    )

    super().__init__(calc, name=name)
    if _update_on_init:
        self.coef.update()
|
|
1502
|
+
|
|
1503
|
+
@staticmethod
|
|
1504
|
+
def _get_bases(
|
|
1505
|
+
marginals: Sequence[Term | RITerm | MRFTerm | IndexingTerm],
|
|
1506
|
+
) -> list[Basis]:
|
|
1507
|
+
bases = []
|
|
1508
|
+
for t in marginals:
|
|
1509
|
+
if hasattr(t, "full_basis"):
|
|
1510
|
+
bases.append(t.full_basis)
|
|
1511
|
+
else:
|
|
1512
|
+
bases.append(t.basis)
|
|
1513
|
+
return bases
|
|
1514
|
+
|
|
1515
|
+
@staticmethod
|
|
1516
|
+
def _validate_marginals(marginals: Sequence[Term]):
|
|
1517
|
+
for t in marginals:
|
|
1518
|
+
if t.scale is None:
|
|
1519
|
+
raise ValueError(f"Invalid scale for {t}: {t.scale}")
|
|
1520
|
+
try:
|
|
1521
|
+
# ignoring type here because potential errors are handlded
|
|
1522
|
+
t.coef.dist_node["penalty"] # type: ignore
|
|
1523
|
+
except Exception as e:
|
|
1524
|
+
raise ValueError(f"Invalid penalty for {t}") from e
|
|
1525
|
+
|
|
1526
|
+
for i, b in enumerate(TPTerm._get_bases(marginals)):
|
|
1527
|
+
if b.value.ndim != 2:
|
|
1528
|
+
raise ValueError(
|
|
1529
|
+
"Expected 2-dimensional basis, but the basis "
|
|
1530
|
+
f"of {marginals[i]} has shape {b.value.shape}"
|
|
1531
|
+
)
|
|
1532
|
+
|
|
1533
|
+
@property
|
|
1534
|
+
def input_obs(self) -> dict[str, lsl.Var]:
|
|
1535
|
+
return self._input_obs(self.bases)
|
|
1536
|
+
|
|
1537
|
+
@staticmethod
|
|
1538
|
+
def _input_obs(bases: Sequence[Basis]) -> dict[str, lsl.Var]:
|
|
1539
|
+
# this method includes assumptions about how the individual bases are
|
|
1540
|
+
# structured: Basis.x can be a strong observed variable directly, or a
|
|
1541
|
+
# calculator variable that depends on strong observed variables.
|
|
1542
|
+
# If these assumptions are violated, this method may produce unexpected results.
|
|
1543
|
+
# The bases created by BasisBuilder fit theses assumptions.
|
|
1544
|
+
_input_x = {}
|
|
1545
|
+
for b in bases:
|
|
1546
|
+
if isinstance(b.x, lsl.Var):
|
|
1547
|
+
if b.x.strong and b.x.observed:
|
|
1548
|
+
# case: ordinary univariate marginal basis, like ps
|
|
1549
|
+
if not b.x.name:
|
|
1550
|
+
raise ValueError(f"Observed name not found for {b}")
|
|
1551
|
+
_input_x[b.x.name] = b.x
|
|
1552
|
+
elif b.x.weak:
|
|
1553
|
+
# currently, I don't expect this case to be present
|
|
1554
|
+
# but it would make sense
|
|
1555
|
+
for xi in b.x.all_input_vars():
|
|
1556
|
+
if xi.observed:
|
|
1557
|
+
if not xi.name:
|
|
1558
|
+
raise ValueError(f"Observed name not found for {b}")
|
|
1559
|
+
_input_x[xi.name] = xi
|
|
1560
|
+
|
|
1561
|
+
else:
|
|
1562
|
+
# case: potentially multivariate marginal, possibly tp,
|
|
1563
|
+
# where basis.x is a calculator that collects the strong inputs.
|
|
1564
|
+
for xj in b.x.all_input_nodes():
|
|
1565
|
+
if not isinstance(xj, lsl.Var):
|
|
1566
|
+
if xj.var is not None:
|
|
1567
|
+
if xj.var.observed:
|
|
1568
|
+
_input_x[xj.var.name] = xj.var
|
|
1569
|
+
elif xj.observed:
|
|
1570
|
+
if not xj.name:
|
|
1571
|
+
raise ValueError(f"Observed name not found for {b}")
|
|
1572
|
+
_input_x[xj.name] = xj
|
|
1573
|
+
|
|
1574
|
+
return _input_x
|
|
1575
|
+
|
|
1576
|
+
@classmethod
|
|
1577
|
+
def f(
|
|
1578
|
+
cls,
|
|
1579
|
+
*marginals: Term,
|
|
1580
|
+
common_scale: ScaleIG | lsl.Var | ArrayLike | VarIGPrior | None = None,
|
|
1581
|
+
fname: str = "ta",
|
|
1582
|
+
inference: InferenceTypes = None,
|
|
1583
|
+
_update_on_init: bool = True,
|
|
1584
|
+
) -> Self:
|
|
1585
|
+
xnames = list(cls._input_obs(cls._get_bases(marginals)))
|
|
1586
|
+
name = fname + "(" + ",".join(xnames) + ")"
|
|
1587
|
+
|
|
1588
|
+
coef_name = "$\\beta_{" + name + "}$"
|
|
1589
|
+
|
|
1590
|
+
term = cls(
|
|
1591
|
+
*marginals,
|
|
1592
|
+
common_scale=common_scale,
|
|
1593
|
+
inference=inference,
|
|
1594
|
+
coef_name=coef_name,
|
|
1595
|
+
name=name,
|
|
1596
|
+
basis_name=None,
|
|
1597
|
+
_update_on_init=_update_on_init,
|
|
1598
|
+
)
|
|
1599
|
+
|
|
1600
|
+
return term
|