liesel-gam 0.0.4__py3-none-any.whl → 0.0.6a4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liesel_gam/__about__.py +1 -1
- liesel_gam/__init__.py +38 -1
- liesel_gam/builder/__init__.py +8 -0
- liesel_gam/builder/builder.py +2003 -0
- liesel_gam/builder/category_mapping.py +158 -0
- liesel_gam/builder/consolidate_bases.py +105 -0
- liesel_gam/builder/registry.py +561 -0
- liesel_gam/constraint.py +107 -0
- liesel_gam/dist.py +541 -1
- liesel_gam/kernel.py +18 -7
- liesel_gam/plots.py +946 -0
- liesel_gam/predictor.py +59 -20
- liesel_gam/var.py +1508 -126
- liesel_gam-0.0.6a4.dist-info/METADATA +559 -0
- liesel_gam-0.0.6a4.dist-info/RECORD +18 -0
- {liesel_gam-0.0.4.dist-info → liesel_gam-0.0.6a4.dist-info}/WHEEL +1 -1
- liesel_gam-0.0.4.dist-info/METADATA +0 -160
- liesel_gam-0.0.4.dist-info/RECORD +0 -11
- {liesel_gam-0.0.4.dist-info → liesel_gam-0.0.6a4.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,2003 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from collections.abc import Callable, Mapping, Sequence
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from math import ceil
|
|
7
|
+
from typing import Any, Literal, get_args
|
|
8
|
+
|
|
9
|
+
import formulaic as fo
|
|
10
|
+
import jax
|
|
11
|
+
import jax.numpy as jnp
|
|
12
|
+
import liesel.goose as gs
|
|
13
|
+
import liesel.model as lsl
|
|
14
|
+
import numpy as np
|
|
15
|
+
import pandas as pd
|
|
16
|
+
import smoothcon as scon
|
|
17
|
+
from liesel.model.model import TemporaryModel
|
|
18
|
+
from ryp import r, to_py, to_r
|
|
19
|
+
|
|
20
|
+
from ..var import (
|
|
21
|
+
Basis,
|
|
22
|
+
LinBasis,
|
|
23
|
+
LinTerm,
|
|
24
|
+
MRFBasis,
|
|
25
|
+
MRFSpec,
|
|
26
|
+
MRFTerm,
|
|
27
|
+
RITerm,
|
|
28
|
+
ScaleIG,
|
|
29
|
+
Term,
|
|
30
|
+
TPTerm,
|
|
31
|
+
VarIGPrior,
|
|
32
|
+
)
|
|
33
|
+
from .registry import CategoryMapping, PandasRegistry
|
|
34
|
+
|
|
35
|
+
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Array type aliases re-exported from JAX for use in annotations.
Array = jax.Array
ArrayLike = jax.typing.ArrayLike

# Loose alias for inference-related objects; currently no narrower than Any.
InferenceTypes = Any

# mgcv basis codes accepted by ``BasisBuilder.s`` (checked by ``_validate_bs``).
BasisTypes = Literal["tp", "ts", "cr", "cs", "cc", "bs", "ps", "cp", "gp"]
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def _validate_bs(bs):
    """
    Check that every requested mgcv basis code is one of ``BasisTypes``.

    ``bs`` may be a single string or an iterable of strings; a single string is
    wrapped in a list first. Raises ``ValueError`` if any code is not allowed.
    """
    allowed = get_args(BasisTypes)
    bs = [bs] if isinstance(bs, str) else bs
    if any(code not in allowed for code in bs):
        raise ValueError(f"Allowed values for 'bs' are: {allowed}; got {bs=}.")
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def _margin_penalties(smooth: scon.SmoothCon):
    """
    Return the marginal penalty matrices of a ti() smooth as numpy arrays.

    Runs R code (via ``ryp.r``) that collects ``S[[1]]`` -- the first penalty
    matrix -- of every element of ``<smooth>[[1]]$margin`` into the R variable
    ``penalties_list``. It then fetches that list with ``to_py`` and converts
    each entry with ``.to_numpy()``.
    """
    # TODO: this belongs in smoothcon; kept here for now.
    r(
        f"penalties_list <- lapply({smooth._smooth_r_name}"
        "[[1]]$margin, function(x) x$S[[1]])"
    )
    return [penalty.to_numpy() for penalty in to_py("penalties_list")]
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def _tp_penalty(K1, K2) -> Array:
|
|
67
|
+
"""Computes the full tensor product penalty from the marginals."""
|
|
68
|
+
# this should go into smoothcon, but it works here for now
|
|
69
|
+
D1 = np.shape(K1)[1]
|
|
70
|
+
D2 = np.shape(K2)[1]
|
|
71
|
+
I1 = np.eye(D1)
|
|
72
|
+
I2 = np.eye(D2)
|
|
73
|
+
|
|
74
|
+
return jnp.asarray(jnp.kron(K1, I2) + jnp.kron(I1, K2))
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def labels_to_integers(newdata: dict, mappings: dict[str, CategoryMapping]) -> dict:
    """
    Return a shallow copy of ``newdata`` with categorical entries integer-encoded.

    For every ``name -> mapping`` pair in ``mappings`` whose name is also a key of
    ``newdata``, the entry is replaced by ``mapping.labels_to_integers(entry)``.
    Entries without a mapping pass through unchanged. ``newdata`` itself is not
    modified, because the replacement happens on a ``.copy()``.
    """
    encoded = newdata.copy()
    for name in (key for key in mappings if key in encoded):
        encoded[name] = mappings[name].labels_to_integers(encoded[name])
    return encoded
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def assert_intercept_in_spec(spec: fo.ModelSpec) -> fo.ModelSpec:
    """
    Verify that the formula of ``spec`` contains exactly one intercept term.

    A term's degree is the number of input-data columns it references, so a term
    of degree zero is an intercept.

    Returns ``spec`` unchanged. Raises ``RuntimeError`` if the formula has no
    degree-zero term or more than one.
    """
    n_intercepts = sum(1 for term in spec.formula if term.degree == 0)

    if n_intercepts > 1:
        raise RuntimeError(f"Too many intercepts: {n_intercepts}.")
    if n_intercepts == 0:
        raise RuntimeError(
            "No intercept found in formula. Did you explicitly remove an "
            "intercept by including '0' or '-1'? This breaks model matrix setup."
        )
    return spec
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def _split_formula_terms(formula: str) -> list[tuple[str, str]]:
    """
    Split a formula string into ``(sign, term)`` pairs at top-level ``+``/``-``.

    ``sign`` is the operator preceding the term (``"+"`` for a leading term
    without an operator). Operators nested inside brackets or quotes -- e.g. in
    ``I(x - 1)``, ``{x + 1}`` or a backtick-quoted name -- do not split.
    Whitespace is removed from each term, and empty terms are dropped.
    """
    pieces: list[tuple[str, list[str]]] = []
    sign = "+"
    chars: list[str] = []
    depth = 0
    quote = None
    for char in formula:
        if quote is None and depth == 0 and char in "+-":
            pieces.append((sign, chars))
            sign, chars = char, []
            continue
        chars.append(char)
        if quote is not None:
            if char == quote:
                quote = None
        elif char in "'\"`":
            quote = char
        elif char in "([{":
            depth += 1
        elif char in ")]}" and depth > 0:
            depth -= 1
    pieces.append((sign, chars))

    terms = []
    for term_sign, term_chars in pieces:
        term = "".join("".join(term_chars).split())
        if term:
            terms.append((term_sign, term))
    return terms


def validate_formula(formula: str) -> None:
    """
    Reject formula constructs that conflict with the builder's intercept handling.

    Intercepts are controlled through the ``include_intercept`` argument of the
    builder, so the formula itself must neither add nor remove one.

    Raises
    ------
    ValueError
        If the formula contains ``~``, adds an explicit ``1`` term, or removes
        the intercept via a ``0`` term or ``- 1`` (written ``x - 1``,
        ``-1 + x`` or ``x + -1``).
    """
    if "~" in formula:
        raise ValueError("'~' in formulas is not supported.")

    for sign, term in _split_formula_terms(formula):
        if sign == "+" and term == "1":
            raise ValueError(
                "Using '1 +' is not supported. To add an intercept, use the "
                "argument 'include_intercept'."
            )
        # Splitting at top-level '-' as well as '+' is what catches the common
        # 'x - 1' spelling; splitting on '+' alone would miss it.
        removes_intercept = (sign == "+" and term == "0") or (
            sign == "-" and term == "1"
        )
        if removes_intercept:
            raise ValueError(
                "Using '0 +' or '-1' is not supported. Intercepts are not included "
                "by default and can be added manually with the argument "
                "'include_intercept'."
            )
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def validate_penalty_order(penalty_order: int):
    """
    Check that ``penalty_order`` is a positive integer.

    Raises
    ------
    TypeError
        If ``penalty_order`` is not an ``int``. ``bool`` is rejected explicitly:
        it subclasses ``int`` and would otherwise pass. Callers then format it
        into the R ``m=c(...)`` argument as ``True``/``False``, which is not
        valid R.
    ValueError
        If ``penalty_order`` is not greater than zero.
    """
    if isinstance(penalty_order, bool) or not isinstance(penalty_order, int):
        raise TypeError(
            f"'penalty_order' must be int or None, got {type(penalty_order)}"
        )
    if not penalty_order > 0:
        raise ValueError(f"'penalty_order' must be >0, got {penalty_order}")
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
class BasisBuilder:
|
|
139
|
+
def __init__(
    self, registry: PandasRegistry, names: NameManager | None = None
) -> None:
    """
    Set up a builder that creates bases from the variables in ``registry``.

    ``mappings`` starts empty and is filled with category mappings by the
    methods that handle categorical inputs. ``names`` provides
    ``create_lazily`` for naming variables; a fresh ``NameManager`` is created
    when none is given.
    """
    if names is None:
        names = NameManager()
    self.registry = registry
    self.mappings: dict[str, CategoryMapping] = dict()
    self.names = names
|
|
145
|
+
|
|
146
|
+
@property
def data(self) -> pd.DataFrame:
    """The data frame held by the underlying registry (read-only shortcut)."""
    registry = self.registry
    return registry.data
|
|
149
|
+
|
|
150
|
+
def basis(
    self,
    *x: str,
    basis_fn: Callable[[Array], Array] = lambda x: x,
    use_callback: bool = True,
    cache_basis: bool = True,
    penalty: ArrayLike | lsl.Value | None = None,
    basis_name: str = "B",
) -> Basis:
    """
    Build a generic :class:`Basis` from one or more registered numeric covariates.

    The covariates are column-stacked into a transient input variable named
    ``registry.prefix + ",".join(x)``, and ``basis_fn`` is applied to it.

    Parameters
    ----------
    *x
        Names of numeric covariates, fetched via ``registry.get_numeric_obs``.
    basis_fn
        Function mapping the stacked covariate matrix to the basis matrix.
        Defaults to the identity.
    use_callback, cache_basis
        Forwarded to :class:`Basis`.
    penalty
        Optional penalty matrix.
        - An array-like is converted with ``jnp.asarray``.
        - For an ``lsl.Value``, its ``.value`` is converted in place and the
          ``Value`` itself is passed on.
        - ``None`` is passed on unchanged.
    basis_name
        Prefix of the basis variable's name, ``basis_name(<input name>)``.
    """
    if isinstance(penalty, lsl.Value):
        penalty.value = jnp.asarray(penalty.value)
    elif penalty is not None:
        penalty = jnp.asarray(penalty)

    x_vars = [self.registry.get_numeric_obs(x_name) for x_name in x]

    Xname = self.registry.prefix + ",".join(x)
    Xvar = lsl.TransientCalc(
        lambda *columns: jnp.column_stack(columns),
        *x_vars,
        _name=Xname,
    )

    return Basis(
        value=Xvar,
        basis_fn=basis_fn,
        name=self.names.create_lazily(basis_name + "(" + Xname + ")"),
        use_callback=use_callback,
        cache_basis=cache_basis,
        # Pass the already-normalized penalty through. Re-wrapping it in
        # jnp.asarray fails for the default None (not an array input) and
        # defeats the lsl.Value handling above.
        penalty=penalty,
    )
|
|
187
|
+
|
|
188
|
+
def ps(
    self,
    x: str,
    *,
    k: int,
    basis_degree: int = 3,
    penalty_order: int = 2,
    knots: ArrayLike | None = None,
    absorb_cons: bool = True,
    diagonal_penalty: bool = True,
    scale_penalty: bool = True,
    basis_name: str = "B",
) -> Basis:
    """
    Build a P-spline basis (mgcv ``bs='ps'``) for the numeric covariate ``x``.

    The mgcv ``m`` argument is ``c(basis_degree - 1, penalty_order)``; with the
    default ``basis_degree=3`` its first element is ``2``. The smooth is set up
    by ``smoothcon`` on the registry column ``x``. The returned basis
    re-evaluates ``smooth.predict`` on new values of ``x`` and carries
    ``smooth.penalty``.

    If ``absorb_cons`` is true, the basis is tagged with
    ``_constraint = "absorbed_via_mgcv"``.
    """
    validate_penalty_order(penalty_order)
    if knots is not None:
        knots = np.asarray(knots)

    spec = f"s({x}, bs='ps', k={k}, m=c({basis_degree - 1}, {penalty_order}))"
    covariate = jnp.asarray(self.registry.data[x].to_numpy())
    smooth = scon.SmoothCon(
        spec,
        data={x: covariate},
        knots=knots,
        absorb_cons=absorb_cons,
        diagonal_penalty=diagonal_penalty,
        scale_penalty=scale_penalty,
    )

    def evaluate(values):
        # Evaluate the mgcv basis at (possibly new) covariate values.
        return jnp.asarray(smooth.predict({x: values}))

    obs = self.registry.get_numeric_obs(x)
    label = self.names.create_lazily(basis_name + "(" + obs.name + ")")
    result = Basis(
        obs,
        name=label,
        basis_fn=evaluate,
        penalty=smooth.penalty,
        use_callback=True,
        cache_basis=True,
    )

    if absorb_cons:
        result._constraint = "absorbed_via_mgcv"
    return result
|
|
229
|
+
|
|
230
|
+
def cr(
    self,
    x: str,
    *,
    k: int,
    penalty_order: int = 2,
    knots: ArrayLike | None = None,
    absorb_cons: bool = True,
    diagonal_penalty: bool = True,
    scale_penalty: bool = True,
    basis_name: str = "B",
) -> Basis:
    """
    Build a cubic regression spline basis (mgcv ``bs='cr'``) for ``x``.

    The mgcv ``m`` argument is ``c(penalty_order)``. The smooth is set up by
    ``smoothcon`` on the registry column ``x``. The returned basis re-evaluates
    ``smooth.predict`` on new values and carries ``smooth.penalty``.

    If ``absorb_cons`` is true, the basis is tagged with
    ``_constraint = "absorbed_via_mgcv"``.
    """
    validate_penalty_order(penalty_order)
    knots = None if knots is None else np.asarray(knots)

    spec = f"s({x}, bs='cr', k={k}, m=c({penalty_order}))"
    column = self.registry.data[x]
    smooth = scon.SmoothCon(
        spec,
        data={x: jnp.asarray(column.to_numpy())},
        knots=knots,
        absorb_cons=absorb_cons,
        diagonal_penalty=diagonal_penalty,
        scale_penalty=scale_penalty,
    )

    def design(values):
        # Evaluate the mgcv basis at (possibly new) covariate values.
        return jnp.asarray(smooth.predict({x: values}))

    observed = self.registry.get_numeric_obs(x)
    cr_basis = Basis(
        observed,
        name=self.names.create_lazily(basis_name + "(" + observed.name + ")"),
        basis_fn=design,
        penalty=smooth.penalty,
        use_callback=True,
        cache_basis=True,
    )

    if absorb_cons:
        cr_basis._constraint = "absorbed_via_mgcv"
    return cr_basis
|
|
269
|
+
|
|
270
|
+
def cs(
    self,
    x: str,
    *,
    k: int,
    penalty_order: int = 2,
    knots: ArrayLike | None = None,
    absorb_cons: bool = True,
    diagonal_penalty: bool = True,
    scale_penalty: bool = True,
    basis_name: str = "B",
) -> Basis:
    """
    Build a shrinkage cubic regression spline basis (mgcv ``bs='cs'``) for ``x``.

    From the mgcv documentation: s(x, bs="cs") is a penalized cubic regression
    spline whose penalty has been modified to shrink towards zero at high
    enough smoothing parameters. (As the smoothing parameter goes to infinity,
    a normal cubic spline tends to a straight line.)

    The mgcv ``m`` argument is ``c(penalty_order)``. If ``absorb_cons`` is true,
    the basis is tagged with ``_constraint = "absorbed_via_mgcv"``.
    """
    validate_penalty_order(penalty_order)
    knots = None if knots is None else np.asarray(knots)

    spec = f"s({x}, bs='cs', k={k}, m=c({penalty_order}))"
    values_at_setup = jnp.asarray(self.registry.data[x].to_numpy())
    smooth = scon.SmoothCon(
        spec,
        data={x: values_at_setup},
        knots=knots,
        absorb_cons=absorb_cons,
        diagonal_penalty=diagonal_penalty,
        scale_penalty=scale_penalty,
    )

    def design(values):
        # Evaluate the mgcv basis at (possibly new) covariate values.
        return jnp.asarray(smooth.predict({x: values}))

    observed = self.registry.get_numeric_obs(x)
    cs_basis = Basis(
        observed,
        name=self.names.create_lazily(basis_name + "(" + observed.name + ")"),
        basis_fn=design,
        penalty=smooth.penalty,
        use_callback=True,
        cache_basis=True,
    )

    if absorb_cons:
        cs_basis._constraint = "absorbed_via_mgcv"
    return cs_basis
|
|
315
|
+
|
|
316
|
+
def cc(
    self,
    x: str,
    *,
    k: int,
    penalty_order: int = 2,
    knots: ArrayLike | None = None,
    absorb_cons: bool = True,
    diagonal_penalty: bool = True,
    scale_penalty: bool = True,
    basis_name: str = "B",
) -> Basis:
    """
    Build a cyclic cubic regression spline basis (mgcv ``bs='cc'``) for ``x``.

    The mgcv ``m`` argument is ``c(penalty_order)``. The smooth is set up by
    ``smoothcon`` on the registry column ``x``. The returned basis re-evaluates
    ``smooth.predict`` on new values and carries ``smooth.penalty``.

    If ``absorb_cons`` is true, the basis is tagged with
    ``_constraint = "absorbed_via_mgcv"``.
    """
    validate_penalty_order(penalty_order)
    if knots is not None:
        knots = np.asarray(knots)

    spec = f"s({x}, bs='cc', k={k}, m=c({penalty_order}))"
    setup_data = {x: jnp.asarray(self.registry.data[x].to_numpy())}
    smooth = scon.SmoothCon(
        spec,
        data=setup_data,
        knots=knots,
        absorb_cons=absorb_cons,
        diagonal_penalty=diagonal_penalty,
        scale_penalty=scale_penalty,
    )

    def design(values):
        # Evaluate the mgcv basis at (possibly new) covariate values.
        return jnp.asarray(smooth.predict({x: values}))

    observed = self.registry.get_numeric_obs(x)
    label = self.names.create_lazily(basis_name + "(" + observed.name + ")")
    cc_basis = Basis(
        observed,
        name=label,
        basis_fn=design,
        penalty=smooth.penalty,
        use_callback=True,
        cache_basis=True,
    )

    if absorb_cons:
        cc_basis._constraint = "absorbed_via_mgcv"
    return cc_basis
|
|
355
|
+
|
|
356
|
+
def bs(
    self,
    x: str,
    *,
    k: int,
    basis_degree: int = 3,
    penalty_order: int | Sequence[int] = 2,
    knots: ArrayLike | None = None,
    absorb_cons: bool = True,
    diagonal_penalty: bool = True,
    scale_penalty: bool = True,
    basis_name: str = "B",
) -> Basis:
    """
    Build a B-spline basis (mgcv ``bs='bs'``) for the numeric covariate ``x``.

    The mgcv ``m`` argument is ``c(basis_degree, *penalty_order)``. From the
    mgcv documentation: the integrated square of the m[2]th derivative is used
    as the penalty, so m=c(3,2) is a conventional cubic spline. Any further
    elements of m, after the first 2, define the order of derivative in further
    penalties.

    Parameters
    ----------
    penalty_order
        A single positive int, or a non-empty sequence of positive ints (one
        per penalty).

    Raises
    ------
    TypeError, ValueError
        If a penalty order is invalid (see ``validate_penalty_order``).
    ValueError
        If ``penalty_order`` is an empty sequence, which would produce an
        invalid R expression ``m=c(<degree>, )``.
    """
    if knots is not None:
        knots = np.asarray(knots)

    if isinstance(penalty_order, int):
        validate_penalty_order(penalty_order)
        penalty_orders = [penalty_order]
    else:
        # Materialize once so that a one-shot iterable is not exhausted by the
        # validation before the spec string is built.
        penalty_orders = list(penalty_order)
        if not penalty_orders:
            raise ValueError("'penalty_order' must not be an empty sequence.")
        for order in penalty_orders:
            validate_penalty_order(order)
    penalty_order_seq = [str(p) for p in penalty_orders]

    spec = (
        f"s({x}, bs='bs', k={k}, "
        f"m=c({basis_degree}, {', '.join(penalty_order_seq)}))"
    )
    x_array = jnp.asarray(self.registry.data[x].to_numpy())
    smooth = scon.SmoothCon(
        spec,
        data={x: x_array},
        knots=knots,
        absorb_cons=absorb_cons,
        diagonal_penalty=diagonal_penalty,
        scale_penalty=scale_penalty,
    )

    x_var = self.registry.get_numeric_obs(x)
    basis = Basis(
        x_var,
        name=self.names.create_lazily(basis_name + "(" + x_var.name + ")"),
        basis_fn=lambda x_: jnp.asarray(smooth.predict({x: x_})),
        penalty=smooth.penalty,
        use_callback=True,
        cache_basis=True,
    )

    if absorb_cons:
        basis._constraint = "absorbed_via_mgcv"
    return basis
|
|
412
|
+
|
|
413
|
+
def cp(
    self,
    x: str,
    *,
    k: int,
    basis_degree: int = 3,
    penalty_order: int = 2,
    knots: ArrayLike | None = None,
    absorb_cons: bool = True,
    diagonal_penalty: bool = True,
    scale_penalty: bool = True,
    basis_name: str = "B",
) -> Basis:
    """
    Build a cyclic P-spline basis (mgcv ``bs='cp'``) for the covariate ``x``.

    The mgcv ``m`` argument is ``c(basis_degree - 1, penalty_order)``, the
    same convention as ``ps``. The returned basis re-evaluates
    ``smooth.predict`` on new values and carries ``smooth.penalty``.

    If ``absorb_cons`` is true, the basis is tagged with
    ``_constraint = "absorbed_via_mgcv"``.
    """
    validate_penalty_order(penalty_order)
    knots = None if knots is None else np.asarray(knots)

    spec = f"s({x}, bs='cp', k={k}, m=c({basis_degree - 1}, {penalty_order}))"
    smooth = scon.SmoothCon(
        spec,
        data={x: jnp.asarray(self.registry.data[x].to_numpy())},
        knots=knots,
        absorb_cons=absorb_cons,
        diagonal_penalty=diagonal_penalty,
        scale_penalty=scale_penalty,
    )

    def evaluate(values):
        # Evaluate the mgcv basis at (possibly new) covariate values.
        return jnp.asarray(smooth.predict({x: values}))

    obs = self.registry.get_numeric_obs(x)
    cyclic_basis = Basis(
        obs,
        name=self.names.create_lazily(basis_name + "(" + obs.name + ")"),
        basis_fn=evaluate,
        penalty=smooth.penalty,
        use_callback=True,
        cache_basis=True,
    )

    if absorb_cons:
        cyclic_basis._constraint = "absorbed_via_mgcv"
    return cyclic_basis
|
|
453
|
+
|
|
454
|
+
def s(
    self,
    *x: str,
    k: int,
    bs: BasisTypes,
    m: str = "NA",
    knots: ArrayLike | None = None,
    absorb_cons: bool = True,
    diagonal_penalty: bool = True,
    scale_penalty: bool = True,
    basis_name: str = "B",
) -> Basis:
    """
    Build a basis from a generic mgcv smooth ``s(x1, ..., bs=bs, k=k, m=m)``.

    Parameters
    ----------
    *x
        Names of the numeric covariates entering the smooth. At least one is
        required.
    k, bs, m
        Inserted into the mgcv smooth specification. ``bs`` must be one of
        ``BasisTypes``. ``m`` is an R expression string (default ``"NA"``).
    knots, absorb_cons, diagonal_penalty, scale_penalty
        Forwarded to ``smoothcon.SmoothCon``.
    basis_name
        Prefix of the basis variable's name, ``basis_name(<covariate names>)``.

    Raises
    ------
    ValueError
        If no covariate is given, or ``bs`` is not an allowed basis type.
    """
    if not x:
        raise ValueError("At least one covariate name must be given.")
    if knots is not None:
        knots = np.asarray(knots)
    _validate_bs(bs)
    bs_arg = f"'{bs}'"
    spec = f"s({','.join(x)}, bs={bs_arg}, k={k}, m={m})"

    obs_vars = {}
    for xname in x:
        obs_vars[xname] = self.registry.get_numeric_obs(xname)
    # Use `name`/`var` here; `k` is the basis dimension and should not be reused.
    obs_values = {name: np.asarray(var.value) for name, var in obs_vars.items()}

    smooth = scon.SmoothCon(
        spec,
        data=pd.DataFrame.from_dict(obs_values),
        knots=knots,
        absorb_cons=absorb_cons,
        diagonal_penalty=diagonal_penalty,
        scale_penalty=scale_penalty,
    )

    xname = ",".join([v.name for v in obs_vars.values()])

    xvar: lsl.Var | lsl.TransientCalc
    if len(obs_vars) > 1:
        xvar = lsl.TransientCalc(  # for memory-efficiency
            lambda *args: jnp.vstack(args).T,
            *list(obs_vars.values()),
            _name=self.names.create_lazily(xname),
        )
    else:
        # `obs_vars` is keyed by the names passed in, while `xname` is built
        # from the variables' own `.name`. These need not coincide (e.g. if the
        # registry prefixes variable names), so take the single variable
        # directly instead of looking it up by `xname`.
        (xvar,) = obs_vars.values()

    def basis_fn(values):
        # Columns are named like the covariates passed to SmoothCon above.
        df = pd.DataFrame(values, columns=list(obs_vars))
        return jnp.asarray(smooth.predict(df))

    basis = Basis(
        xvar,
        name=self.names.create_lazily(basis_name + "(" + xname + ")"),
        basis_fn=basis_fn,
        penalty=smooth.penalty,
        use_callback=True,
        cache_basis=True,
    )
    if absorb_cons:
        basis._constraint = "absorbed_via_mgcv"
    return basis
|
|
514
|
+
|
|
515
|
+
def tp(
    self,
    *x: str,
    k: int,
    penalty_order: int | None = None,
    knots: ArrayLike | None = None,
    absorb_cons: bool = True,
    diagonal_penalty: bool = True,
    scale_penalty: bool = True,
    basis_name: str = "B",
    remove_null_space_completely: bool = False,
) -> Basis:
    """
    Build a thin plate regression spline basis (mgcv ``bs='tp'``).

    ``penalty_order`` is mgcv's ``m``. Quote from the mgcv docs:
    The default is to set m (the order of derivative in the thin plate spline
    penalty) to the smallest value satisfying 2m > d+1 where d is the number of
    covariates of the term: this yields 'visually smooth' functions.
    In any case 2m>d must be satisfied.

    If ``remove_null_space_completely`` is true, ``0`` is appended as a second
    element of ``m``. All other arguments are forwarded to ``s``.
    """
    n_covariates = len(x)
    if penalty_order is None:
        # Closed form of the smallest integer m with 2m > d + 1.
        order = (n_covariates + 1) // 2 + 1
    else:
        validate_penalty_order(penalty_order)
        order = penalty_order

    m_parts = [str(order)]
    if remove_null_space_completely:
        m_parts.append("0")

    return self.s(
        *x,
        k=k,
        bs="tp",
        m="c(" + ", ".join(m_parts) + ")",
        knots=knots,
        absorb_cons=absorb_cons,
        diagonal_penalty=diagonal_penalty,
        scale_penalty=scale_penalty,
        basis_name=basis_name,
    )
|
|
566
|
+
|
|
567
|
+
def ts(
    self,
    *x: str,
    k: int,
    penalty_order: int | None = None,
    knots: ArrayLike | None = None,
    absorb_cons: bool = True,
    diagonal_penalty: bool = True,
    scale_penalty: bool = True,
    basis_name: str = "B",
) -> Basis:
    """
    Build a thin plate regression spline with shrinkage (mgcv ``bs='ts'``).

    ``penalty_order`` is mgcv's ``m``. Quote from the mgcv docs:
    The default is to set m (the order of derivative in the thin plate spline
    penalty) to the smallest value satisfying 2m > d+1 where d is the number of
    covariates of the term: this yields 'visually smooth' functions.
    In any case 2m>d must be satisfied.

    If ``penalty_order`` is ``None``, that default is used (same as ``tp``).
    Otherwise it is validated with ``validate_penalty_order``, so e.g. ``0``
    raises ``ValueError``. All other arguments are forwarded to ``s``.
    """
    d = len(x)
    if penalty_order is None:
        # Smallest integer m with 2m > d + 1. Note that ceil((d + 1) / 2) is not
        # enough: for odd d it gives 2m == d + 1 (e.g. m=1 for one covariate).
        order = (d + 1) // 2 + 1
    else:
        validate_penalty_order(penalty_order)
        order = penalty_order

    m_str = "c(" + str(order) + ")"

    return self.s(
        *x,
        k=k,
        bs="ts",
        m=m_str,
        knots=knots,
        absorb_cons=absorb_cons,
        diagonal_penalty=diagonal_penalty,
        scale_penalty=scale_penalty,
        basis_name=basis_name,
    )
|
|
609
|
+
|
|
610
|
+
def kriging(
    self,
    *x: str,
    k: int,
    kernel_name: Literal[
        "spherical",
        "power_exponential",
        "matern1.5",
        "matern2.5",
        "matern3.5",
    ] = "matern1.5",
    linear_trend: bool = True,
    range: float | None = None,
    power_exponential_power: float = 1.0,
    knots: ArrayLike | None = None,
    absorb_cons: bool = True,
    diagonal_penalty: bool = True,
    scale_penalty: bool = True,
    basis_name: str = "B",
) -> Basis:
    """
    Build a Gaussian-process (kriging) smooth basis via mgcv ``bs='gp'``.

    The mgcv ``m`` vector is assembled as ``c(kernel, range, power)``:

    - ``kernel`` is the integer code of ``kernel_name`` (spherical=1,
      power_exponential=2, matern1.5=3, matern2.5=4, matern3.5=5). It is
      negated if ``linear_trend`` is false.
    - ``range`` is passed when given (and truthy). Otherwise ``-1.0`` is passed
      as a placeholder, and the range parameter will be estimated as in
      Kammann and Wand (2003).
    - ``power`` is ``power_exponential_power``, which must lie in ``(0, 2]``.

    All other arguments are forwarded to ``s``.

    Raises
    ------
    ValueError
        If ``power_exponential_power`` is outside ``(0, 2]``.
    """
    m_kernel_dict = {
        "spherical": 1,
        "power_exponential": 2,
        "matern1.5": 3,
        "matern2.5": 4,
        "matern3.5": 5,
    }
    m_linear = 1.0 if linear_trend else -1.0

    m_args = [str(int(m_linear * m_kernel_dict[kernel_name]))]
    if range:
        m_args.append(str(range))
    if power_exponential_power is not None:
        # Validate before anything else: a truthiness check here would let 0.0
        # slip through silently even though it lies outside (0, 2].
        if not 0.0 < power_exponential_power <= 2.0:
            raise ValueError(
                "'power_exponential_power' must be in (0, 2.0], "
                f"got {power_exponential_power}"
            )
        if not range:
            # Placeholder so that the power ends up in the third position of m.
            m_args.append(str(-1.0))
        m_args.append(str(power_exponential_power))

    m_str = "c(" + ", ".join(m_args) + ")"

    return self.s(
        *x,
        k=k,
        bs="gp",
        m=m_str,
        knots=knots,
        absorb_cons=absorb_cons,
        diagonal_penalty=diagonal_penalty,
        scale_penalty=scale_penalty,
        basis_name=basis_name,
    )
|
|
675
|
+
|
|
676
|
+
def lin(
    self,
    formula: str,
    xname: str = "",
    basis_name: str = "X",
    include_intercept: bool = False,
    context: dict[str, Any] | None = None,
) -> LinBasis:
    """
    Build a linear basis (design matrix) from a formulaic right-hand-side formula.

    The formula must not contain ``~`` or explicit intercept terms (checked by
    ``validate_formula``); use ``include_intercept`` instead. The model matrix
    is rebuilt from the required input columns whenever the basis is
    evaluated. The spec is materialized on ``self.data`` first, so that
    transformations like center/scale/standardize keep their setup state.

    Parameters
    ----------
    formula
        Right-hand side of a formulaic formula, e.g. ``"x + C(group)"``.
    xname
        Optional name for the input variable. When given, the basis is named
        ``basis_name(<input name>)``; otherwise it is named ``basis_name``.
    basis_name
        Name (prefix) of the basis variable.
    include_intercept
        If false, the first (intercept) column of the model matrix is dropped
        from the basis and from ``column_names``.
    context
        Forwarded as ``context`` to formulaic's ``get_model_matrix``.

    Returns
    -------
    LinBasis
        With ``model_spec``, ``mappings`` (categorical column -> mapping) and
        ``column_names`` (aligned with the basis columns) attached.
    """
    validate_formula(formula)
    spec = fo.ModelSpec(formula, output="numpy")

    if not include_intercept:
        # We do our own intercept handling with the full model matrix, so it
        # may be surprising to assert an intercept only when we plan to remove
        # it. But to remove it safely, we first have to ensure it is present.
        assert_intercept_in_spec(spec)

    # Evaluate the model matrix once to get a spec with structure information.
    # This is also necessary to populate the spec with the correct state for
    # transformations like center, scale, standardize.
    spec = spec.get_model_matrix(self.data, context=context).model_spec

    # Column names as created by formulaic. Building the model matrix a second
    # time is not the most efficient route, but it works robustly.
    column_names = list(
        fo.ModelSpec(formula, output="pandas")
        .get_model_matrix(self.data, context=context)
        .columns
    )
    if not include_intercept:
        # The basis drops the intercept column only in this case (see
        # basis_fn below), so drop its name only here too; otherwise names and
        # basis columns would be misaligned.
        column_names = column_names[1:]

    required = sorted(str(var) for var in spec.required_variables)
    df_colnames = self.data.loc[:, required].columns

    variables = {}
    mappings = {}
    for col in df_colnames:
        result = self.registry.get_obs_and_mapping(col)
        variables[col] = result.var

        if result.mapping is not None:
            self.mappings[col] = result.mapping
            mappings[col] = result.mapping

    xvar = lsl.TransientCalc(  # for memory-efficiency
        lambda *args: jnp.vstack(args).T,
        *list(variables.values()),
        _name=self.names.create_lazily(xname) if xname else xname,
    )

    def basis_fn(x):
        df = pd.DataFrame(x, columns=df_colnames)

        # Categorical inputs arrive as integer codes; convert them back to
        # labels using the mappings captured by this call. The shared
        # self.mappings dict may be changed later by other builder calls.
        for col, mapping in mappings.items():
            df[col] = mapping.integers_to_labels(df[col].to_numpy())

        basis = np.asarray(spec.get_model_matrix(df, context=context))
        if not include_intercept:
            basis = basis[:, 1:]
        return jnp.asarray(basis, dtype=float)

    if xname:
        bname = self.names.create_lazily(basis_name + "(" + xvar.name + ")")
    else:
        bname = self.names.create_lazily(basis_name)

    basis = LinBasis(
        xvar,
        basis_fn=basis_fn,
        use_callback=True,
        cache_basis=True,
        name=bname,
    )

    basis.model_spec = spec
    basis.mappings = mappings
    basis.column_names = column_names

    return basis
|
|
763
|
+
|
|
764
|
+
def ri(
    self,
    cluster: str,
    basis_name: str = "B",
    penalty: ArrayLike | None = None,
) -> Basis:
    """
    Build a random-intercept basis for the categorical variable ``cluster``.

    The basis passes the registry's observation variable for ``cluster``
    (``result.var``) through unchanged: identity ``basis_fn``, no callback,
    no caching. The category mapping is stored in ``self.mappings[cluster]``.

    Parameters
    ----------
    cluster
        Name of a categorical variable in the registry.
    basis_name
        Prefix of the basis name, ``basis_name(cluster)``.
    penalty
        Optional penalty matrix, converted with ``jnp.asarray``. Defaults to an
        identity matrix with one row per entry of the mapping's
        ``labels_to_integers_map``.

    Raises
    ------
    TypeError
        If the registry returns no mapping for ``cluster``, i.e. it is not
        categorical.
    """
    if penalty is not None:
        penalty = jnp.asarray(penalty)

    result = self.registry.get_obs_and_mapping(cluster)
    if result.mapping is None:
        raise TypeError(f"{cluster=} must be categorical.")

    self.mappings[cluster] = result.mapping
    nparams = len(result.mapping.labels_to_integers_map)

    if penalty is None:
        penalty = jnp.eye(nparams)

    # `penalty` is a JAX array in both branches at this point, so it needs no
    # further conversion or None check.
    return Basis(
        value=result.var,
        basis_fn=lambda x: x,
        name=self.names.create_lazily(basis_name + "(" + cluster + ")"),
        use_callback=False,
        cache_basis=False,
        penalty=penalty,
    )
|
|
792
|
+
|
|
793
|
+
def mrf(
    self,
    x: str,
    k: int = -1,
    polys: dict[str, ArrayLike] | None = None,
    nb: Mapping[str, ArrayLike | list[str] | list[int]] | None = None,
    penalty: ArrayLike | None = None,
    penalty_labels: Sequence[str] | None = None,
    absorb_cons: bool = False,
    diagonal_penalty: bool = False,
    scale_penalty: bool = False,
    basis_name: str = "B",
) -> MRFBasis:
    """
    Create a Markov random field basis for the categorical variable ``x`` via mgcv.

    Parameters
    ----------
    x
        Name of a categorical variable in the registry. Its levels are the
        regions.
    k
        Basis dimension passed to mgcv's ``s(..., bs='mrf')``. Must be an int
        ``>= -1``. ``-1`` is mgcv's convention for "use the default".
    polys
        Dictionary of arrays. The keys of the dict are the region labels. The
        corresponding values define the region by defining polygons.
    nb
        Dictionary of neighbor specifications. The keys of the dict are the
        region labels, and the values indicate the neighbors of the region.
        If a value is a list or array of strings, it holds the labels of the
        neighbors. If it is a list or array of integers (or floats), it holds
        the 0-based indices of the neighbors; these are shifted to 1-based
        indices before being passed to R.
    penalty
        Square penalty matrix with one row/column per level of ``x``.
    penalty_labels
        Region labels of the rows/columns of ``penalty``, in order. Required
        if ``penalty`` is supplied.
    absorb_cons, diagonal_penalty, scale_penalty
        Passed on to :class:`smoothcon.SmoothCon`.
    basis_name
        Prefix of the basis name, which becomes ``basis_name(x)``.

    Notes
    -----
    mgcv does not concern itself with your category ordering. It *will* order
    categories alphabetically. Penalty columns have to take this into account.

    Comments on the return value:

    - If either polys or nb are supplied, the returned container will contain
      nb.
    - If only a penalty matrix is supplied, the returned container will *not*
      contain nb.
    - Returning the label order only makes sense if the basis is *not*
      reparameterized, because only then is there a clear correspondence of
      parameters to labels. If the basis is reparameterized, there is no such
      clear correspondence, so the returned label order is None.

    Raises
    ------
    TypeError
        If ``k`` is not an int, or an ``nb`` entry has an unsupported dtype.
    ValueError
        If none of polys, nb, or penalty is given, or if their labels or
        dimensions do not match the levels of ``x``.
    """
    if not isinstance(k, int):
        raise TypeError(f"'k' must be int, got {type(k)}.")
    if k < -1:
        raise ValueError(f"'k' cannot be smaller than -1, got {k=}.")

    if polys is None and nb is None and penalty is None:
        raise ValueError("At least one of polys, nb, or penalty must be provided.")

    var, mapping = self.registry.get_categorical_obs(x)
    self.mappings[x] = mapping

    labels = set(mapping.labels_to_integers_map)

    if penalty is not None:
        if penalty_labels is None:
            raise ValueError(
                "If 'penalty' is supplied, 'penalty_labels' must also be supplied."
            )
        if len(penalty_labels) != len(labels):
            raise ValueError(
                f"Variable {x} has {len(labels)} unique entries, but "
                f"'penalty_labels' has {len(penalty_labels)}. Both must match."
            )
        # Equal lengths are not enough: unknown labels would make mgcv match
        # penalty rows/columns to the wrong (or no) regions.
        if set(penalty_labels) != labels:
            raise ValueError(
                "Names in 'penalty_labels' must correspond to the levels of 'x'."
            )

    xt_args = []
    pass_to_r: dict[str, np.typing.NDArray | dict[str, np.typing.NDArray]] = {}
    if polys is not None:
        xt_args.append("polys=polys")
        if labels != set(polys):
            raise ValueError("Names in 'poly' must correspond to the levels of 'x'.")
        pass_to_r["polys"] = {key: np.asarray(val) for key, val in polys.items()}

    if nb is not None:
        xt_args.append("nb=nb")
        if labels != set(nb):
            raise ValueError("Names in 'nb' must correspond to the levels of 'x'.")

        nb_processed = {}
        for key, val in nb.items():
            val_arr = np.asarray(val)
            # Compare against dtype *kinds* so that e.g. int32, uint or
            # float32 inputs are accepted, not only the default int64/float64.
            if np.isdtype(val_arr.dtype, "integral"):
                # add one to convert to 1-based indexing for R
                # and cast to float for R
                val_arr = np.astype(val_arr + 1, float)
            elif np.isdtype(val_arr.dtype, "real floating"):
                # add one to convert to 1-based indexing for R
                val_arr = np.astype(np.astype(val_arr, int) + 1, float)
            elif val_arr.dtype.kind == "U":  # must be unicode strings then
                pass
            else:
                raise TypeError(f"Unsupported dtype: {val_arr.dtype!r}")

            nb_processed[key] = val_arr

        pass_to_r["nb"] = nb_processed

    if penalty is not None:
        penalty = np.asarray(penalty)
        # Validate the shape before computing the rank, so malformed input
        # gets a clear error instead of failing inside matrix_rank/indexing.
        if penalty.ndim != 2 or penalty.shape[0] != penalty.shape[1]:
            raise ValueError(f"Penalty must be square, got {np.shape(penalty)=}")
        if penalty.shape[1] != len(labels):
            raise ValueError(
                "Dimensions of 'penalty' must correspond to the levels of 'x'."
            )

        pen_rank = np.linalg.matrix_rank(penalty)
        pen_dim = penalty.shape[-1]
        if (pen_dim - pen_rank) != 1:
            logger.warning(
                f"Supplied penalty has dimension {penalty.shape} and rank "
                f"{pen_rank}. The expected rank deficiency is 1. "
                "This may indicate a problem. There might be disconnected sets "
                "of regions in the data represented by this penalty. "
                "In this case, you probably need more elaborate constraints "
                "than the ones provided here. You might consider splitting the "
                "disconnected regions into several mrf terms. "
                "Otherwise, please only continue if you are certain that you "
                "know what is happening."
            )

        xt_args.append("penalty=penalty")

    xt = "list(" + ",".join(xt_args) + ")"

    if penalty is not None:
        # The penalty is sent to R directly rather than via pass_to_r, because
        # it must carry row and column names so that mgcv matches penalty
        # entries to the correct clusters.
        to_r(penalty, "penalty")
        to_r(np.array(penalty_labels), "penalty_labels")
        r("colnames(penalty) <- penalty_labels")
        r("rownames(penalty) <- penalty_labels")

    spec = f"s({x}, k={k}, bs='mrf', xt={xt})"

    observed = mapping.integers_to_labels(var.value)
    regions = list(mapping.labels_to_integers_map)
    df = pd.DataFrame({x: pd.Categorical(observed, categories=regions)})

    smooth = scon.SmoothCon(
        spec,
        data=df,
        diagonal_penalty=diagonal_penalty,
        absorb_cons=absorb_cons,
        scale_penalty=scale_penalty,
        pass_to_r=pass_to_r,
    )

    x_name = x

    def basis_fun(x):
        """
        Evaluate the MRF basis at integer-coded observations ``x``.

        The array returned by ``smooth.predict`` contains column names. Here
        its first column is dropped and the rest is converted to a float jax
        array.
        """
        # mgcv warns that "mrf should be a factor", even for categorical
        # input, so R warnings are silenced temporarily. ``finally`` makes sure
        # the previous setting is restored even if prediction fails.
        r("old_warn <- getOption('warn')")
        r("options(warn = -1)")
        try:
            new_labels = mapping.integers_to_labels(x)
            new_df = pd.DataFrame(
                {x_name: pd.Categorical(new_labels, categories=regions)}
            )
            return jnp.asarray(np.astype(smooth.predict(new_df)[:, 1:], "float"))
        finally:
            r("options(warn = old_warn)")

    smooth_penalty = smooth.penalty
    pen_shape = np.shape(smooth_penalty)
    if pen_shape[1] > len(labels):
        if pen_shape[0] == pen_shape[1]:
            # Drop the first row as well as the first column, so a square
            # penalty stays square.
            smooth_penalty = smooth_penalty[1:, 1:]
        else:
            smooth_penalty = smooth_penalty[:, 1:]

    penalty_arr = jnp.asarray(np.astype(smooth_penalty, "float"))

    basis = MRFBasis(
        value=var,
        basis_fn=basis_fun,
        name=self.names.create_lazily(basis_name + "(" + x + ")"),
        cache_basis=True,
        use_callback=True,
        penalty=penalty_arr,
    )
    if absorb_cons:
        basis._constraint = "absorbed_via_mgcv"

    try:
        nb_out = to_py(f"{smooth._smooth_r_name}[[1]]$xt$nb", format="numpy")
    except TypeError:
        # Presumably raised when R holds no neighbourhood structure (e.g. when
        # only a penalty was supplied); see the docstring notes.
        nb_out = None

    if absorb_cons:
        label_order = None
    else:
        x_columns = to_py(f"{smooth._smooth_r_name}[[1]]$X", format="pandas").columns
        label_order = [lab[1:] for lab in x_columns]  # removes leading x from R

    if nb_out is not None:

        def to_label(code):
            # Integer codes from R are 1-based; map them back to labels.
            # String entries fail on ``code - 1`` with a TypeError and are
            # kept as they are.
            try:
                label_array = mapping.integers_to_labels(code - 1)
            except TypeError:
                label_array = code
            return np.atleast_1d(label_array).tolist()

        nb_out = {region: to_label(codes) for region, codes in nb_out.items()}

    basis.mrf_spec = MRFSpec(mapping, nb_out, label_order)

    return basis
|
|
1019
|
+
|
|
1020
|
+
|
|
1021
|
+
@dataclass
class NameManager:
    """
    Hand out unique names, optionally prefixed with ``prefix``.

    ``created_names`` maps every (prefixed) base name to the number of times
    it has been requested so far; each base name has its own counter.
    """

    prefix: str = ""
    created_names: dict[str, int] = field(default_factory=dict)

    def _next_index(self, base: str) -> int:
        """Return the current counter of ``base`` and count this request."""
        count = self.created_names.get(base, 0)
        self.created_names[base] = count + 1
        return count

    def create(self, name: str, apply_prefix: bool = True) -> str:
        """
        Return ``name`` with its counter appended, e.g. ``name0``, ``name1``.

        If ``apply_prefix`` is true, the manager's prefix is prepended before
        counting, so prefixed and unprefixed names are counted separately.
        """
        base = f"{self.prefix}{name}" if apply_prefix else name
        return f"{base}{self._next_index(base)}"

    def create_lazily(self, name: str, apply_prefix: bool = True) -> str:
        """
        Like :meth:`create`, but the first request returns the name unchanged.

        Later requests for the same base name get ``1``, ``2``, ... appended.
        """
        base = f"{self.prefix}{name}" if apply_prefix else name
        index = self._next_index(base)
        return f"{base}{index}" if index else base

    def fname(self, f: str, basis: Basis) -> str:
        """Return a lazily-unique term name ``f(<name of basis.x>)``."""
        return self.create_lazily(f"{f}({basis.x.name})")

    def create_param_name(self, term_name: str, param_name: str) -> str:
        """
        Return a unique LaTeX-style parameter name.

        Without ``term_name`` the result is ``$param_name$`` with the prefix
        applied. With ``term_name`` it is ``$param_name_{term_name}$`` without
        applying the prefix again (presumably because the term name already
        carries it).
        """
        if not term_name:
            return self.create_lazily(f"${param_name}$", apply_prefix=True)
        subscripted = "$" + param_name + "_{" + term_name + "}$"
        return self.create_lazily(subscripted, apply_prefix=False)

    def create_beta_name(self, term_name: str) -> str:
        """Return a unique name for the coefficients (beta) of ``term_name``."""
        return self.create_param_name(term_name=term_name, param_name="\\beta")

    def create_tau_name(self, term_name: str) -> str:
        """Return a unique name for the scale (tau) of ``term_name``."""
        return self.create_param_name(term_name=term_name, param_name="\\tau")

    def create_tau2_name(self, term_name: str) -> str:
        """Return a unique name for the variance (tau squared) of ``term_name``."""
        return self.create_param_name(term_name=term_name, param_name="\\tau^2")
|
|
1079
|
+
|
|
1080
|
+
|
|
1081
|
+
class TermBuilder:
|
|
1082
|
+
def __init__(self, registry: PandasRegistry, prefix_names_by: str = "") -> None:
    """
    Set up a term builder on top of ``registry``.

    Parameters
    ----------
    registry
        Registry that provides the observed variables.
    prefix_names_by
        Prefix for all generated names. It is shared by the name manager and,
        through it, by the basis builder.
    """
    name_manager = NameManager(prefix=prefix_names_by)
    self.registry = registry
    self.names = name_manager
    self.bases = BasisBuilder(registry, names=name_manager)
|
|
1086
|
+
|
|
1087
|
+
def _init_default_scale(
    self,
    concentration: float | Array,
    scale: float | Array,
    value: float | Array = 1.0,
    term_name: str = "",
) -> ScaleIG:
    """
    Create a :class:`ScaleIG` scale variable for the term ``term_name``.

    The scale and its variance get unique tau / tau-squared names derived
    from ``term_name``. ``concentration`` and ``scale`` are the
    hyperparameters passed to :class:`ScaleIG`, and ``value`` is the initial
    value.
    """
    # Keyword arguments are evaluated left to right, so the tau name is still
    # created before the tau-squared name.
    return ScaleIG(
        value=value,
        concentration=concentration,
        scale=scale,
        name=self.names.create_tau_name(term_name),
        variance_name=self.names.create_tau2_name(term_name),
    )
|
|
1104
|
+
|
|
1105
|
+
@classmethod
def from_dict(
    cls, data: dict[str, ArrayLike], prefix_names_by: str = ""
) -> TermBuilder:
    """
    Alternative constructor from a dictionary of columns.

    ``data`` is converted to a :class:`pandas.DataFrame` and passed on to
    :meth:`from_df`.
    """
    frame = pd.DataFrame(data)
    return cls.from_df(frame, prefix_names_by=prefix_names_by)
|
|
1110
|
+
|
|
1111
|
+
@classmethod
def from_df(cls, data: pd.DataFrame, prefix_names_by: str = "") -> TermBuilder:
    """
    Alternative constructor from a data frame.

    ``data`` is wrapped in a :class:`PandasRegistry` with ``na_action="drop"``.
    The registry and the builder use the same name prefix.
    """
    return cls(
        PandasRegistry(data, na_action="drop", prefix_names_by=prefix_names_by),
        prefix_names_by=prefix_names_by,
    )
|
|
1117
|
+
|
|
1118
|
+
def labels_to_integers(self, newdata: dict) -> dict:
    """
    Convert categorical labels in ``newdata`` to their integer codes.

    Uses the category mappings collected so far by the basis builder
    (``self.bases.mappings``).
    """
    known_mappings = self.bases.mappings
    return labels_to_integers(newdata, known_mappings)
|
|
1120
|
+
|
|
1121
|
+
# formula
|
|
1122
|
+
def lin(
    self,
    formula: str,
    prior: lsl.Dist | None = None,
    inference: InferenceTypes | None = gs.MCMCSpec(gs.IWLSKernel),
    include_intercept: bool = False,
    context: dict[str, Any] | None = None,
) -> LinTerm:
    r"""
    Build a linear term from a formula (right-hand side only).

    Parameters
    ----------
    formula
        Formula describing the design matrix, see below for the syntax.
    prior
        Optional prior for the coefficients.
    inference
        Inference specification for the coefficients.
    include_intercept
        If false (the default), the first column of the model matrix (the
        intercept) is dropped.
    context
        Extra names made available when the formula is evaluated.

    Supported:

    - {a+1} for quoted Python
    - `weird name` backtick-strings for weird names
    - (a + b)**n for n-th order interactions
    - a:b for simple interactions
    - a*b for expanding to a + b + a:b
    - a / b for nesting
    - b %in% a for inverted nesting
    - Python functions
    - bs
    - cr
    - cs
    - cc
    - hashed

    .. warning:: If you use bs, cr, cs, or cc, be aware that these will not
        lead to terms that include a penalty. In most cases, you probably want
        to use :meth:`~.TermBuilder.s`, :meth:`~.TermBuilder.ps`, and so on
        instead.

    Not supported:

    - String literals
    - Numeric literals
    - Wildcard "."
    - \| for splitting a formula
    - "te" tensor products
    - "~" in formula
    - 1 + in formula
    - 0 + in formula
    - -1 in formula
    """
    basis = self.bases.lin(
        formula,
        xname="",
        basis_name="X",
        include_intercept=include_intercept,
        context=context,
    )

    # Use the covariate name when there is one, otherwise the basis name.
    label = basis.x.name or basis.name
    term_name = self.names.create_lazily(f"lin({label})")

    term = LinTerm(
        basis,
        prior=prior,
        name=term_name,
        inference=inference,
        coef_name=self.names.create_beta_name(term_name),
    )

    # Copy the formula metadata from the basis onto the term.
    for attr in ("model_spec", "mappings", "column_names"):
        setattr(term, attr, getattr(basis, attr))

    return term
|
|
1190
|
+
|
|
1191
|
+
def cr(
    self,
    x: str,
    *,
    k: int,
    scale: ScaleIG | lsl.Var | float | VarIGPrior = VarIGPrior(1.0, 0.005),
    inference: InferenceTypes | None = gs.MCMCSpec(gs.IWLSKernel),
    penalty_order: int = 2,
    knots: ArrayLike | None = None,
    absorb_cons: bool = True,
    diagonal_penalty: bool = True,
    scale_penalty: bool = True,
    noncentered: bool = False,
) -> Term:
    """
    Build a penalized cubic regression spline term for the covariate ``x``.

    Parameters
    ----------
    x
        Name of the covariate in the registry.
    k
        Basis dimension.
    scale
        Scale of the coefficient prior. A :class:`VarIGPrior` is turned into a
        new scale variable; any other value is used as given.
    inference
        Inference specification for the coefficients.
    penalty_order, knots, absorb_cons, diagonal_penalty, scale_penalty
        Passed on to the basis builder.
    noncentered
        If true, the term is reparameterized in the noncentered form.

    Returns
    -------
    The term, named ``cr(<covariate>)`` (plus prefix/counter as needed).
    """
    basis = self.bases.cr(
        x=x,
        k=k,
        penalty_order=penalty_order,
        knots=knots,
        absorb_cons=absorb_cons,
        diagonal_penalty=diagonal_penalty,
        scale_penalty=scale_penalty,
        basis_name="B",
    )
    term_name = self.names.fname("cr", basis)

    if isinstance(scale, VarIGPrior):
        hyper = scale
        scale = self._init_default_scale(
            concentration=hyper.concentration,
            scale=hyper.scale,
            term_name=term_name,
        )

    term = Term(
        basis=basis,
        penalty=basis.penalty,
        scale=scale,
        name=term_name,
        inference=inference,
        coef_name=self.names.create_beta_name(term_name),
    )
    if noncentered:
        term.reparam_noncentered()
    return term
|
|
1235
|
+
|
|
1236
|
+
def cs(
    self,
    x: str,
    *,
    k: int,
    scale: ScaleIG | lsl.Var | float | VarIGPrior = VarIGPrior(1.0, 0.005),
    inference: InferenceTypes | None = gs.MCMCSpec(gs.IWLSKernel),
    penalty_order: int = 2,
    knots: ArrayLike | None = None,
    absorb_cons: bool = True,
    diagonal_penalty: bool = True,
    scale_penalty: bool = True,
    noncentered: bool = False,
) -> Term:
    """
    Build a shrinkage cubic regression spline term for the covariate ``x``.

    Parameters
    ----------
    x
        Name of the covariate in the registry.
    k
        Basis dimension.
    scale
        Scale of the coefficient prior. A :class:`VarIGPrior` is turned into a
        new scale variable; any other value is used as given.
    inference
        Inference specification for the coefficients.
    penalty_order, knots, absorb_cons, diagonal_penalty, scale_penalty
        Passed on to the basis builder.
    noncentered
        If true, the term is reparameterized in the noncentered form.

    Returns
    -------
    The term, named ``cs(<covariate>)`` (plus prefix/counter as needed).
    """
    basis = self.bases.cs(
        x=x,
        k=k,
        penalty_order=penalty_order,
        knots=knots,
        absorb_cons=absorb_cons,
        diagonal_penalty=diagonal_penalty,
        scale_penalty=scale_penalty,
        basis_name="B",
    )
    term_name = self.names.fname("cs", basis)

    if isinstance(scale, VarIGPrior):
        hyper = scale
        scale = self._init_default_scale(
            concentration=hyper.concentration,
            scale=hyper.scale,
            term_name=term_name,
        )

    term = Term(
        basis=basis,
        penalty=basis.penalty,
        scale=scale,
        name=term_name,
        inference=inference,
        coef_name=self.names.create_beta_name(term_name),
    )
    if noncentered:
        term.reparam_noncentered()
    return term
|
|
1280
|
+
|
|
1281
|
+
def cc(
    self,
    x: str,
    *,
    k: int,
    scale: ScaleIG | lsl.Var | float | VarIGPrior = VarIGPrior(1.0, 0.005),
    inference: InferenceTypes | None = gs.MCMCSpec(gs.IWLSKernel),
    penalty_order: int = 2,
    knots: ArrayLike | None = None,
    absorb_cons: bool = True,
    diagonal_penalty: bool = True,
    scale_penalty: bool = True,
    noncentered: bool = False,
) -> Term:
    """
    Build a cyclic cubic regression spline term for the covariate ``x``.

    Parameters
    ----------
    x
        Name of the covariate in the registry.
    k
        Basis dimension.
    scale
        Scale of the coefficient prior. A :class:`VarIGPrior` is turned into a
        new scale variable; any other value is used as given.
    inference
        Inference specification for the coefficients.
    penalty_order, knots, absorb_cons, diagonal_penalty, scale_penalty
        Passed on to the basis builder.
    noncentered
        If true, the term is reparameterized in the noncentered form.

    Returns
    -------
    The term, named ``cc(<covariate>)`` (plus prefix/counter as needed).
    """
    basis = self.bases.cc(
        x=x,
        k=k,
        penalty_order=penalty_order,
        knots=knots,
        absorb_cons=absorb_cons,
        diagonal_penalty=diagonal_penalty,
        scale_penalty=scale_penalty,
        basis_name="B",
    )
    term_name = self.names.fname("cc", basis)

    if isinstance(scale, VarIGPrior):
        hyper = scale
        scale = self._init_default_scale(
            concentration=hyper.concentration,
            scale=hyper.scale,
            term_name=term_name,
        )

    term = Term(
        basis=basis,
        penalty=basis.penalty,
        scale=scale,
        name=term_name,
        inference=inference,
        coef_name=self.names.create_beta_name(term_name),
    )
    if noncentered:
        term.reparam_noncentered()
    return term
|
|
1325
|
+
|
|
1326
|
+
def bs(
    self,
    x: str,
    *,
    k: int,
    scale: ScaleIG | lsl.Var | float | VarIGPrior = VarIGPrior(1.0, 0.005),
    inference: InferenceTypes | None = gs.MCMCSpec(gs.IWLSKernel),
    basis_degree: int = 3,
    penalty_order: int | Sequence[int] = 2,
    knots: ArrayLike | None = None,
    absorb_cons: bool = True,
    diagonal_penalty: bool = True,
    scale_penalty: bool = True,
    noncentered: bool = False,
) -> Term:
    """
    Build a penalized B-spline term for the covariate ``x``.

    Parameters
    ----------
    x
        Name of the covariate in the registry.
    k
        Basis dimension.
    scale
        Scale of the coefficient prior. A :class:`VarIGPrior` is turned into a
        new scale variable; any other value is used as given.
    inference
        Inference specification for the coefficients.
    basis_degree
        Degree of the B-spline basis.
    penalty_order
        Penalty order(s), passed on to the basis builder.
    knots, absorb_cons, diagonal_penalty, scale_penalty
        Passed on to the basis builder.
    noncentered
        If true, the term is reparameterized in the noncentered form.

    Returns
    -------
    The term, named ``bs(<covariate>)`` (plus prefix/counter as needed).
    """
    basis = self.bases.bs(
        x=x,
        k=k,
        basis_degree=basis_degree,
        penalty_order=penalty_order,
        knots=knots,
        absorb_cons=absorb_cons,
        diagonal_penalty=diagonal_penalty,
        scale_penalty=scale_penalty,
        basis_name="B",
    )
    term_name = self.names.fname("bs", basis)

    if isinstance(scale, VarIGPrior):
        hyper = scale
        scale = self._init_default_scale(
            concentration=hyper.concentration,
            scale=hyper.scale,
            term_name=term_name,
        )

    term = Term(
        basis=basis,
        penalty=basis.penalty,
        scale=scale,
        name=term_name,
        inference=inference,
        coef_name=self.names.create_beta_name(term_name),
    )
    if noncentered:
        term.reparam_noncentered()
    return term
|
|
1372
|
+
|
|
1373
|
+
# P-spline
|
|
1374
|
+
def ps(
    self,
    x: str,
    *,
    k: int,
    scale: ScaleIG | lsl.Var | float | VarIGPrior = VarIGPrior(1.0, 0.005),
    inference: InferenceTypes | None = gs.MCMCSpec(gs.IWLSKernel),
    basis_degree: int = 3,
    penalty_order: int = 2,
    knots: ArrayLike | None = None,
    absorb_cons: bool = True,
    diagonal_penalty: bool = True,
    scale_penalty: bool = True,
    noncentered: bool = False,
) -> Term:
    """
    Build a P-spline term for the covariate ``x``.

    Parameters
    ----------
    x
        Name of the covariate in the registry.
    k
        Basis dimension.
    scale
        Scale of the coefficient prior. A :class:`VarIGPrior` is turned into a
        new scale variable; any other value is used as given.
    inference
        Inference specification for the coefficients.
    basis_degree
        Degree of the B-spline basis.
    penalty_order, knots, absorb_cons, diagonal_penalty, scale_penalty
        Passed on to the basis builder.
    noncentered
        If true, the term is reparameterized in the noncentered form.

    Returns
    -------
    The term, named ``ps(<covariate>)`` (plus prefix/counter as needed).
    """
    basis = self.bases.ps(
        x=x,
        k=k,
        basis_degree=basis_degree,
        penalty_order=penalty_order,
        knots=knots,
        absorb_cons=absorb_cons,
        diagonal_penalty=diagonal_penalty,
        scale_penalty=scale_penalty,
        basis_name="B",
    )
    term_name = self.names.fname("ps", basis)

    if isinstance(scale, VarIGPrior):
        hyper = scale
        scale = self._init_default_scale(
            concentration=hyper.concentration,
            scale=hyper.scale,
            term_name=term_name,
        )

    term = Term(
        basis=basis,
        penalty=basis.penalty,
        scale=scale,
        name=term_name,
        inference=inference,
        coef_name=self.names.create_beta_name(term_name),
    )
    if noncentered:
        term.reparam_noncentered()
    return term
|
|
1420
|
+
|
|
1421
|
+
def cp(
    self,
    x: str,
    *,
    k: int,
    scale: ScaleIG | lsl.Var | float | VarIGPrior = VarIGPrior(1.0, 0.005),
    inference: InferenceTypes | None = gs.MCMCSpec(gs.IWLSKernel),
    basis_degree: int = 3,
    penalty_order: int = 2,
    knots: ArrayLike | None = None,
    absorb_cons: bool = True,
    diagonal_penalty: bool = True,
    scale_penalty: bool = True,
    noncentered: bool = False,
) -> Term:
    """
    Build a cyclic P-spline term for the covariate ``x``.

    Parameters
    ----------
    x
        Name of the covariate in the registry.
    k
        Basis dimension.
    scale
        Scale of the coefficient prior. A :class:`VarIGPrior` is turned into a
        new scale variable; any other value is used as given.
    inference
        Inference specification for the coefficients.
    basis_degree
        Degree of the B-spline basis.
    penalty_order, knots, absorb_cons, diagonal_penalty, scale_penalty
        Passed on to the basis builder.
    noncentered
        If true, the term is reparameterized in the noncentered form.

    Returns
    -------
    The term, named ``cp(<covariate>)`` (plus prefix/counter as needed).
    """
    basis = self.bases.cp(
        x=x,
        k=k,
        basis_degree=basis_degree,
        penalty_order=penalty_order,
        knots=knots,
        absorb_cons=absorb_cons,
        diagonal_penalty=diagonal_penalty,
        scale_penalty=scale_penalty,
        basis_name="B",
    )
    term_name = self.names.fname("cp", basis)

    if isinstance(scale, VarIGPrior):
        hyper = scale
        scale = self._init_default_scale(
            concentration=hyper.concentration,
            scale=hyper.scale,
            term_name=term_name,
        )

    term = Term(
        basis=basis,
        penalty=basis.penalty,
        scale=scale,
        name=term_name,
        inference=inference,
        coef_name=self.names.create_beta_name(term_name),
    )
    if noncentered:
        term.reparam_noncentered()
    return term
|
|
1467
|
+
|
|
1468
|
+
# random intercept
|
|
1469
|
+
def ri(
    self,
    cluster: str,
    scale: ScaleIG | lsl.Var | float | VarIGPrior = VarIGPrior(1.0, 0.005),
    inference: InferenceTypes | None = gs.MCMCSpec(gs.IWLSKernel),
    penalty: ArrayLike | None = None,
    noncentered: bool = False,
) -> RITerm:
    """
    Build a random intercept term for the categorical variable ``cluster``.

    Parameters
    ----------
    cluster
        Name of the categorical grouping variable.
    scale
        Scale of the coefficient prior. A :class:`VarIGPrior` is turned into a
        new scale variable; any other value is used as given.
    inference
        Inference specification for the coefficients.
    penalty
        Optional penalty matrix, passed on to the basis builder.
    noncentered
        If true, the term is reparameterized in the noncentered form.

    Returns
    -------
    The term. Its ``mapping`` and ``labels`` attributes hold the cluster's
    label/integer mapping and its labels.
    """
    basis = self.bases.ri(cluster=cluster, basis_name="B", penalty=penalty)
    term_name = self.names.fname("ri", basis)

    if isinstance(scale, VarIGPrior):
        hyper = scale
        scale = self._init_default_scale(
            concentration=hyper.concentration,
            scale=hyper.scale,
            term_name=term_name,
        )

    term = RITerm(
        basis=basis,
        penalty=basis.penalty,
        coef_name=self.names.create_beta_name(term_name),
        inference=inference,
        scale=scale,
        name=term_name,
    )

    if noncentered:
        term.reparam_noncentered()

    cluster_mapping = self.bases.mappings[cluster]
    term.mapping = cluster_mapping
    term.labels = [*cluster_mapping.labels_to_integers_map]

    return term
|
|
1504
|
+
|
|
1505
|
+
# random scaling
|
|
1506
|
+
def rs(
    self,
    x: str | Term | LinTerm,
    cluster: str,
    scale: ScaleIG | lsl.Var | float | VarIGPrior = VarIGPrior(1.0, 0.005),
    inference: InferenceTypes | None = gs.MCMCSpec(gs.IWLSKernel),
    penalty: ArrayLike | None = None,
    noncentered: bool = False,
) -> lsl.Var:
    """
    Build a random scaling term: ``x`` multiplied by a random intercept.

    Parameters
    ----------
    x
        Either the name of a numeric variable in the registry, or an existing
        term whose value is scaled.
    cluster
        Name of the categorical grouping variable.
    scale, inference, penalty, noncentered
        Passed on to :meth:`ri` for the random intercept.

    Returns
    -------
    A calculated variable named ``rs(<x>|<cluster>)`` whose value is the
    product of ``x`` and the random intercept term.
    """
    ri = self.ri(
        cluster=cluster,
        scale=scale,
        inference=inference,
        penalty=penalty,
        noncentered=noncentered,
    )

    if isinstance(x, str):
        x_var = self.registry.get_numeric_obs(x)
        xname = x
    else:
        x_var = x
        # A term's basis covariate may be unnamed (lin() checks basis.x.name
        # for emptiness for this reason). Fall back to the term's own name
        # instead of producing "rs(|cluster)".
        xname = x_var.basis.x.name or x_var.name

    fname = self.names.create_lazily("rs(" + xname + "|" + cluster + ")")
    term = lsl.Var.new_calc(
        lambda x, cluster: x * cluster,
        x=x_var,
        cluster=ri,
        name=fname,
    )
    return term
|
|
1538
|
+
|
|
1539
|
+
# varying coefficient
|
|
1540
|
+
def vc(
    self,
    x: str,
    by: Term,
) -> lsl.Var:
    """
    Build a varying coefficient term: observed variable ``x`` times ``by``.

    Returns a calculated variable named ``<x>*<by.name>`` (plus
    prefix/counter as needed).
    """
    term_name = self.names.create_lazily(f"{x}*{by.name}")
    observed = self.registry.get_obs(x)

    def multiply(x, by):
        return x * by

    return lsl.Var.new_calc(multiply, x=observed, by=by, name=term_name)
|
|
1555
|
+
|
|
1556
|
+
# general smooth with MGCV bases
|
|
1557
|
+
def s(
    self,
    *x: str,
    k: int,
    bs: BasisTypes,
    scale: ScaleIG | lsl.Var | float | VarIGPrior = VarIGPrior(1.0, 0.005),
    inference: InferenceTypes | None = gs.MCMCSpec(gs.IWLSKernel),
    m: str = "NA",
    knots: ArrayLike | None = None,
    absorb_cons: bool = True,
    diagonal_penalty: bool = True,
    scale_penalty: bool = True,
    noncentered: bool = False,
) -> Term:
    """
    Build a general smooth term from an mgcv basis type ``bs``.

    Works:

    - tp (thin plate splines)
    - ts (thin plate splines with slight null space penalty)

    - cr (cubic regression splines)
    - cs (shrinkage cubic regression splines)
    - cc (cyclic cubic regression splines)

    - bs (B-splines)
    - ps (P-splines)
    - cp (cyclic P-splines)

    Works, but not here:

    - re (use .ri instead)
    - mrf (use .mrf instead)
    - te (use .te instead) (with the bases above)
    - ti (use .ti instead) (with the bases above)

    Does not work:

    - ds (Duchon splines)
    - sos (splines on the sphere)
    - gp (gaussian process)
    - so (soap film smooths)
    - ad (adaptive smooths)

    Probably disallow manually:

    - fz (factor smooth interaction)
    - fs (random factor smooth interaction)
    """
    basis = self.bases.s(
        *x,
        k=k,
        bs=bs,
        m=m,
        knots=knots,
        absorb_cons=absorb_cons,
        diagonal_penalty=diagonal_penalty,
        scale_penalty=scale_penalty,
        basis_name="B",
    )
    term_name = self.names.fname(bs, basis=basis)

    if isinstance(scale, VarIGPrior):
        hyper = scale
        scale = self._init_default_scale(
            concentration=hyper.concentration,
            scale=hyper.scale,
            term_name=term_name,
        )

    term = Term(
        basis,
        penalty=basis.penalty,
        name=term_name,
        coef_name=self.names.create_beta_name(term_name),
        scale=scale,
        inference=inference,
    )
    if noncentered:
        term.reparam_noncentered()
    return term
|
|
1632
|
+
|
|
1633
|
+
# markov random field
def mrf(
    self,
    x: str,
    scale: ScaleIG | lsl.Var | float | VarIGPrior = VarIGPrior(1.0, 0.005),
    inference: InferenceTypes | None = gs.MCMCSpec(gs.IWLSKernel),
    k: int = -1,
    polys: dict[str, ArrayLike] | None = None,
    nb: Mapping[str, ArrayLike | list[str] | list[int]] | None = None,
    penalty: ArrayLike | None = None,
    absorb_cons: bool = True,
    diagonal_penalty: bool = True,
    scale_penalty: bool = True,
    noncentered: bool = False,
) -> MRFTerm:
    """
    Build a Markov random field term for the region variable ``x``.

    The basis is created by ``self.bases.mrf``. The resulting ``MRFTerm``
    is annotated with ``polygons``, ``neighbors``, ``labels``, ``mapping``
    and, when available, ``ordered_labels`` taken from the basis' MRF spec.

    Parameters
    ----------
    polys
        Optional dictionary whose keys are the region labels. Each value is
        an array of polygons that defines the region.
    nb
        Optional mapping whose keys are the region labels. Each value lists
        the neighbours of the region:
        - strings are interpreted as the labels of the neighbours;
        - integers are interpreted as the indices of the neighbours.
    scale
        A ``VarIGPrior`` is converted into a scale variable via
        ``self._init_default_scale``; other values are passed on unchanged.
    noncentered
        If True, ``reparam_noncentered()`` is called on the new term.

    Notes
    -----
    mgcv ignores your category ordering and always sorts categories
    alphabetically. Columns of a user-supplied ``penalty`` must follow that
    alphabetical ordering.
    """
    mrf_basis = self.bases.mrf(
        x=x,
        k=k,
        polys=polys,
        nb=nb,
        penalty=penalty,
        absorb_cons=absorb_cons,
        diagonal_penalty=diagonal_penalty,
        scale_penalty=scale_penalty,
        basis_name="B",
    )
    term_name = self.names.fname("mrf", mrf_basis)

    # A VarIGPrior is only a prior specification; turn it into a scale variable.
    if isinstance(scale, VarIGPrior):
        scale = self._init_default_scale(
            concentration=scale.concentration,
            scale=scale.scale,
            term_name=term_name,
        )

    beta_name = self.names.create_beta_name(term_name)
    mrf_term = MRFTerm(
        mrf_basis,
        penalty=mrf_basis.penalty,
        name=term_name,
        scale=scale,
        inference=inference,
        coef_name=beta_name,
    )
    if noncentered:
        mrf_term.reparam_noncentered()

    # Attach region metadata from the basis' MRF specification.
    spec = mrf_basis.mrf_spec
    mrf_term.polygons = polys
    mrf_term.neighbors = spec.nb
    if spec.ordered_labels is not None:
        mrf_term.ordered_labels = spec.ordered_labels
    mrf_term.labels = list(spec.mapping.labels_to_integers_map)
    mrf_term.mapping = spec.mapping

    return mrf_term
|
|
1700
|
+
|
|
1701
|
+
# general basis function + penalty smooth
def f(
    self,
    *x: str,
    basis_fn: Callable[[Array], Array],
    scale: ScaleIG | lsl.Var | float | VarIGPrior = VarIGPrior(1.0, 0.005),
    inference: InferenceTypes | None = gs.MCMCSpec(gs.IWLSKernel),
    use_callback: bool = True,
    cache_basis: bool = True,
    penalty: ArrayLike | None = None,
    noncentered: bool = False,
) -> Term:
    """
    Build a penalized smooth term from a user-supplied basis function.

    Parameters
    ----------
    *x
        Names of the input variables, forwarded to ``self.bases.basis``.
    basis_fn
        Callable that produces the basis from the input array. It is
        forwarded to ``self.bases.basis`` together with ``use_callback``,
        ``cache_basis`` and ``penalty``.
    scale
        A ``VarIGPrior`` is converted into a scale variable via
        ``self._init_default_scale``; other values are passed on unchanged.
    inference
        Inference specification passed to the resulting ``Term``.
    noncentered
        If True, ``reparam_noncentered()`` is called on the new term.

    Returns
    -------
    The constructed ``Term``. Its penalty is taken from the created basis.
    """
    custom_basis = self.bases.basis(
        *x,
        basis_fn=basis_fn,
        use_callback=use_callback,
        cache_basis=cache_basis,
        penalty=penalty,
        basis_name="B",
    )
    term_name = self.names.fname("f", custom_basis)

    # A VarIGPrior is only a prior specification; turn it into a scale variable.
    if isinstance(scale, VarIGPrior):
        scale = self._init_default_scale(
            concentration=scale.concentration,
            scale=scale.scale,
            term_name=term_name,
        )

    beta_name = self.names.create_beta_name(term_name)
    smooth = Term(
        custom_basis,
        penalty=custom_basis.penalty,
        name=term_name,
        scale=scale,
        inference=inference,
        coef_name=beta_name,
    )
    if noncentered:
        smooth.reparam_noncentered()
    return smooth
|
|
1740
|
+
|
|
1741
|
+
def kriging(
    self,
    *x: str,
    k: int,
    scale: ScaleIG | lsl.Var | float | VarIGPrior = VarIGPrior(1.0, 0.005),
    inference: InferenceTypes | None = gs.MCMCSpec(gs.IWLSKernel),
    kernel_name: Literal[
        "spherical",
        "power_exponential",
        "matern1.5",
        "matern2.5",
        "matern3.5",
    ] = "matern1.5",
    linear_trend: bool = True,
    range: float | None = None,
    power_exponential_power: float = 1.0,
    knots: ArrayLike | None = None,
    absorb_cons: bool = True,
    diagonal_penalty: bool = True,
    scale_penalty: bool = True,
    noncentered: bool = False,
) -> Term:
    """
    Build a kriging smooth term for the input variables ``x``.

    All basis-related arguments (``k``, ``kernel_name``, ``linear_trend``,
    ``range``, ``power_exponential_power``, ``knots``, ``absorb_cons``,
    ``diagonal_penalty``, ``scale_penalty``) are forwarded to
    ``self.bases.kriging``.

    Parameters
    ----------
    scale
        A ``VarIGPrior`` is converted into a scale variable via
        ``self._init_default_scale``; other values are passed on unchanged.
    inference
        Inference specification passed to the resulting ``Term``.
    noncentered
        If True, ``reparam_noncentered()`` is called on the new term.

    Returns
    -------
    The constructed ``Term``. Its penalty is taken from the created basis.
    """
    kriging_basis = self.bases.kriging(
        *x,
        k=k,
        kernel_name=kernel_name,
        linear_trend=linear_trend,
        range=range,
        power_exponential_power=power_exponential_power,
        knots=knots,
        absorb_cons=absorb_cons,
        diagonal_penalty=diagonal_penalty,
        scale_penalty=scale_penalty,
        basis_name="B",
    )
    term_name = self.names.fname("kriging", kriging_basis)

    # A VarIGPrior is only a prior specification; turn it into a scale variable.
    if isinstance(scale, VarIGPrior):
        scale = self._init_default_scale(
            concentration=scale.concentration,
            scale=scale.scale,
            term_name=term_name,
        )

    beta_name = self.names.create_beta_name(term_name)
    kriging_term = Term(
        kriging_basis,
        penalty=kriging_basis.penalty,
        name=term_name,
        scale=scale,
        inference=inference,
        coef_name=beta_name,
    )
    if noncentered:
        kriging_term.reparam_noncentered()
    return kriging_term
|
|
1794
|
+
|
|
1795
|
+
def tp(
    self,
    *x: str,
    k: int,
    scale: ScaleIG | lsl.Var | float | VarIGPrior = VarIGPrior(1.0, 0.005),
    inference: InferenceTypes | None = gs.MCMCSpec(gs.IWLSKernel),
    penalty_order: int | None = None,
    knots: ArrayLike | None = None,
    absorb_cons: bool = True,
    diagonal_penalty: bool = True,
    scale_penalty: bool = True,
    noncentered: bool = False,
    remove_null_space_completely: bool = False,
) -> Term:
    """
    Build a thin plate spline term for the input variables ``x``.

    The basis arguments (``k``, ``penalty_order``, ``knots``,
    ``absorb_cons``, ``diagonal_penalty``, ``scale_penalty``,
    ``remove_null_space_completely``) are forwarded to ``self.bases.tp``.

    Parameters
    ----------
    scale
        A ``VarIGPrior`` is converted into a scale variable via
        ``self._init_default_scale``; other values are passed on unchanged.
    inference
        Inference specification passed to the resulting ``Term``.
    noncentered
        If True, ``reparam_noncentered()`` is called on the new term.

    Returns
    -------
    The constructed ``Term``. Its penalty is taken from the created basis.
    """
    tp_basis = self.bases.tp(
        *x,
        k=k,
        penalty_order=penalty_order,
        knots=knots,
        absorb_cons=absorb_cons,
        diagonal_penalty=diagonal_penalty,
        scale_penalty=scale_penalty,
        basis_name="B",
        remove_null_space_completely=remove_null_space_completely,
    )
    term_name = self.names.fname("tp", tp_basis)

    # A VarIGPrior is only a prior specification; turn it into a scale variable.
    if isinstance(scale, VarIGPrior):
        scale = self._init_default_scale(
            concentration=scale.concentration,
            scale=scale.scale,
            term_name=term_name,
        )

    beta_name = self.names.create_beta_name(term_name)
    tp_term = Term(
        tp_basis,
        penalty=tp_basis.penalty,
        name=term_name,
        scale=scale,
        inference=inference,
        coef_name=beta_name,
    )
    if noncentered:
        tp_term.reparam_noncentered()
    return tp_term
|
|
1839
|
+
|
|
1840
|
+
def ts(
    self,
    *x: str,
    k: int,
    scale: ScaleIG | lsl.Var | float | VarIGPrior = VarIGPrior(1.0, 0.005),
    inference: InferenceTypes | None = gs.MCMCSpec(gs.IWLSKernel),
    penalty_order: int | None = None,
    knots: ArrayLike | None = None,
    absorb_cons: bool = True,
    diagonal_penalty: bool = True,
    scale_penalty: bool = True,
    noncentered: bool = False,
) -> Term:
    """
    Build a ``ts`` thin plate spline term for the input variables ``x``.

    The basis arguments (``k``, ``penalty_order``, ``knots``,
    ``absorb_cons``, ``diagonal_penalty``, ``scale_penalty``) are forwarded
    to ``self.bases.ts``.

    Parameters
    ----------
    scale
        A ``VarIGPrior`` is converted into a scale variable via
        ``self._init_default_scale``; other values are passed on unchanged.
    inference
        Inference specification passed to the resulting ``Term``.
    noncentered
        If True, ``reparam_noncentered()`` is called on the new term.

    Returns
    -------
    The constructed ``Term``. Its penalty is taken from the created basis.
    """
    ts_basis = self.bases.ts(
        *x,
        k=k,
        penalty_order=penalty_order,
        knots=knots,
        absorb_cons=absorb_cons,
        diagonal_penalty=diagonal_penalty,
        scale_penalty=scale_penalty,
        basis_name="B",
    )
    term_name = self.names.fname("ts", ts_basis)

    # A VarIGPrior is only a prior specification; turn it into a scale variable.
    if isinstance(scale, VarIGPrior):
        scale = self._init_default_scale(
            concentration=scale.concentration,
            scale=scale.scale,
            term_name=term_name,
        )

    beta_name = self.names.create_beta_name(term_name)
    ts_term = Term(
        ts_basis,
        penalty=ts_basis.penalty,
        name=term_name,
        scale=scale,
        inference=inference,
        coef_name=beta_name,
    )
    if noncentered:
        ts_term.reparam_noncentered()
    return ts_term
|
|
1882
|
+
|
|
1883
|
+
def ta(
    self,
    *marginals: Term,
    common_scale: ScaleIG | lsl.Var | float | VarIGPrior | None = None,
    inference: InferenceTypes | None = gs.MCMCSpec(gs.IWLSKernel),
    include_main_effects: bool = False,
    scales_inference: InferenceTypes | None = gs.MCMCSpec(gs.HMCKernel),
    _fname: str = "ta",
) -> TPTerm:
    """
    Build a tensor-product term (``TPTerm``) from existing marginal terms.

    The scale parameters of the term are checked in two places:
    ``common_scale``, if it is a variable, and every entry of
    ``term.scales``. A parameter is transformed and given
    ``scales_inference`` if either of these holds:

    - its inference is a ``gs.MCMCSpec`` whose kernel is named
      ``StarVarianceGibbs``;
    - it has no inference at all.

    Any other inference is assumed to be intentional and is left unchanged.

    Parameters
    ----------
    *marginals
        Marginal terms. Their bases determine the inputs used in the term
        name, and the terms are passed on to ``TPTerm``.
    common_scale
        Optional common scale passed to ``TPTerm``. A ``VarIGPrior`` is first
        converted into a scale variable via ``self._init_default_scale``.
        Plain numbers (``int`` or ``float``) are passed on without any
        sampler replacement.
    inference
        Inference specification passed to ``TPTerm``.
    include_main_effects
        Passed to ``TPTerm``.
    scales_inference
        Inference assigned to the transformed scale parameters.
    _fname
        Private prefix for the generated term name (used by ``tx``/``tf``).

    Returns
    -------
    The constructed ``TPTerm``.

    Raises
    ------
    TypeError
        If an entry of ``term.scales`` is not a ``liesel.model.Var``.
    """
    inputs = ",".join(TPTerm._input_obs([t.basis for t in marginals]))
    fname = self.names.create_lazily(f"{_fname}({inputs})")
    coef_name = self.names.create_beta_name(fname)

    if isinstance(common_scale, VarIGPrior):
        common_scale = self._init_default_scale(
            concentration=common_scale.concentration,
            scale=common_scale.scale,
            term_name=fname,
        )

    # Plain numbers have no parameter graph to inspect. Previously only
    # floats were excluded, so an int such as ``common_scale=1`` crashed
    # inside _replace_star_gibbs_with with an AttributeError.
    if common_scale is not None and not isinstance(common_scale, (int, float)):
        _replace_star_gibbs_with(common_scale, scales_inference)

    term = TPTerm(
        *marginals,
        common_scale=common_scale,
        name=fname,
        inference=inference,
        coef_name=coef_name,
        include_main_effects=include_main_effects,
    )

    for scale in term.scales:
        if not isinstance(scale, lsl.Var):
            raise TypeError(
                f"Expected scale to be a liesel.model.Var, got {type(scale)}"
            )
        _replace_star_gibbs_with(scale, scales_inference)

    return term
|
|
1927
|
+
|
|
1928
|
+
def tx(
    self,
    *marginals: Term,
    common_scale: ScaleIG | lsl.Var | float | VarIGPrior | None = None,
    inference: InferenceTypes | None = gs.MCMCSpec(gs.IWLSKernel),
    scales_inference: InferenceTypes | None = gs.MCMCSpec(gs.HMCKernel),
) -> TPTerm:
    """
    Build a tensor-product term with ``include_main_effects=False``.

    This is a thin wrapper around ``self.ta``. The generated term name uses
    the ``tx`` prefix. All remaining arguments are forwarded unchanged.
    """
    forwarded = dict(
        common_scale=common_scale,
        inference=inference,
        scales_inference=scales_inference,
    )
    return self.ta(*marginals, include_main_effects=False, _fname="tx", **forwarded)
|
|
1943
|
+
|
|
1944
|
+
def tf(
    self,
    *marginals: Term,
    common_scale: ScaleIG | lsl.Var | float | VarIGPrior | None = None,
    inference: InferenceTypes | None = gs.MCMCSpec(gs.IWLSKernel),
    scales_inference: InferenceTypes | None = gs.MCMCSpec(gs.HMCKernel),
) -> TPTerm:
    """
    Build a tensor-product term with ``include_main_effects=True``.

    This is a thin wrapper around ``self.ta``. The generated term name uses
    the ``tf`` prefix. All remaining arguments are forwarded unchanged.
    """
    forwarded = dict(
        common_scale=common_scale,
        inference=inference,
        scales_inference=scales_inference,
    )
    return self.ta(*marginals, include_main_effects=True, _fname="tf", **forwarded)
|
|
1959
|
+
|
|
1960
|
+
|
|
1961
|
+
def _get_parameter(var: lsl.Var) -> lsl.Var:
    """
    Return the unique parameter variable behind ``var``.

    If ``var`` is strong, it must itself be a parameter and is returned
    as-is. Otherwise ``var`` is placed in a ``TemporaryModel`` (without
    float32 conversion), and the single parameter in its graph is returned.

    Raises
    ------
    ValueError
        Raised in any of these cases:
        - ``var`` is strong but not a parameter;
        - the graph of ``var`` contains no parameter;
        - the graph of ``var`` contains more than one parameter.
    """
    if var.strong:
        if not var.parameter:
            raise ValueError(f"{var} is strong, but not a parameter.")
        return var

    with TemporaryModel(var, to_float32=False) as model:
        found = model.parameters
        n_found = len(found)
        if n_found == 0:
            raise ValueError(f"No parameter found in the graph of {var}.")
        if n_found > 1:
            raise ValueError(
                f"In the graph of {var}, there are {n_found} parameters, "
                "so we cannot return a unique parameter."
            )
        # Exactly one parameter remains at this point.
        (param,) = found.values()

    return param
|
|
1980
|
+
|
|
1981
|
+
|
|
1982
|
+
def _replace_star_gibbs_with(var: lsl.Var, inference: InferenceTypes | None) -> lsl.Var:
    """
    Replace a ``StarVarianceGibbs`` sampler on ``var``'s parameter.

    The parameter behind ``var`` is located with ``_get_parameter``. Its
    inference is left untouched when it is set to anything other than a
    ``gs.MCMCSpec`` whose kernel is named ``"StarVarianceGibbs"``. This
    includes specs whose kernel name cannot be determined.

    Otherwise, the parameter is transformed via
    ``param.transform(bijector=None, inference=inference, name=...)``. This
    also happens when the parameter has no inference. The transformed
    variable is named ``h(<param name>)`` if the parameter has a name.

    Returns
    -------
    ``var`` itself, in every case.
    """
    param = _get_parameter(var)
    current = param.inference

    if current is not None:
        if not isinstance(current, gs.MCMCSpec):
            # A non-MCMC inference was set intentionally; leave it alone.
            return var
        kernel = getattr(current, "kernel", None)
        if getattr(kernel, "__name__", None) != "StarVarianceGibbs":
            # Either a different kernel or one we cannot identify; in both
            # cases we assume the choice was intentional.
            return var

    trafo_name = f"h({param.name})" if param.name else None
    param.transform(bijector=None, inference=inference, name=trafo_name)
    return var
|