liesel-gam 0.0.4__py3-none-any.whl → 0.0.6a4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,2003 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ from collections.abc import Callable, Mapping, Sequence
5
+ from dataclasses import dataclass, field
6
+ from math import ceil
7
+ from typing import Any, Literal, get_args
8
+
9
+ import formulaic as fo
10
+ import jax
11
+ import jax.numpy as jnp
12
+ import liesel.goose as gs
13
+ import liesel.model as lsl
14
+ import numpy as np
15
+ import pandas as pd
16
+ import smoothcon as scon
17
+ from liesel.model.model import TemporaryModel
18
+ from ryp import r, to_py, to_r
19
+
20
+ from ..var import (
21
+ Basis,
22
+ LinBasis,
23
+ LinTerm,
24
+ MRFBasis,
25
+ MRFSpec,
26
+ MRFTerm,
27
+ RITerm,
28
+ ScaleIG,
29
+ Term,
30
+ TPTerm,
31
+ VarIGPrior,
32
+ )
33
+ from .registry import CategoryMapping, PandasRegistry
34
+
35
# Loose alias: goose inference objects are accepted anywhere, unconstrained here.
InferenceTypes = Any

# Convenience aliases for jax array types used throughout this module.
Array = jax.Array
ArrayLike = jax.typing.ArrayLike

# Basis-type codes accepted by BasisBuilder.s(); these mirror mgcv's 'bs' codes.
BasisTypes = Literal["tp", "ts", "cr", "cs", "cc", "bs", "ps", "cp", "gp"]


logger = logging.getLogger(__name__)
44
+
45
+
46
def _validate_bs(bs):
    """Checks that every requested basis-type code is a supported one.

    Accepts a single code or a sequence of codes; raises ``ValueError`` on
    the first code that is not listed in :data:`BasisTypes`.
    """
    if isinstance(bs, str):
        bs = [bs]
    allowed = get_args(BasisTypes)
    invalid = [candidate for candidate in bs if candidate not in allowed]
    if invalid:
        raise ValueError(f"Allowed values for 'bs' are: {allowed}; got {bs=}.")
53
+
54
+
55
def _margin_penalties(smooth: scon.SmoothCon):
    """Extracts the marginal penalty matrices from a ti() smooth.

    Runs R code in the ryp-managed session: it looks up the fitted smooth by
    its R-side object name and collects the first penalty matrix (``S[[1]]``)
    of each marginal smooth. Returns a list of numpy arrays.
    """
    # this should go into smoothcon, but it works here for now
    r(
        f"penalties_list <- lapply({smooth._smooth_r_name}"
        "[[1]]$margin, function(x) x$S[[1]])"
    )
    pens = to_py("penalties_list")
    # each element comes back as a pandas DataFrame; convert to plain arrays
    return [pen.to_numpy() for pen in pens]
64
+
65
+
66
+ def _tp_penalty(K1, K2) -> Array:
67
+ """Computes the full tensor product penalty from the marginals."""
68
+ # this should go into smoothcon, but it works here for now
69
+ D1 = np.shape(K1)[1]
70
+ D2 = np.shape(K2)[1]
71
+ I1 = np.eye(D1)
72
+ I2 = np.eye(D2)
73
+
74
+ return jnp.asarray(jnp.kron(K1, I2) + jnp.kron(I1, K2))
75
+
76
+
77
def labels_to_integers(newdata: dict, mappings: dict[str, CategoryMapping]) -> dict:
    """Returns a copy of ``newdata`` with categorical labels replaced by their
    integer codes.

    Only keys present in both ``newdata`` and ``mappings`` are converted; all
    other entries pass through unchanged. The input dict is not mutated.
    """
    converted = newdata.copy()

    for key, mapping in mappings.items():
        if key in converted:
            converted[key] = mapping.labels_to_integers(converted[key])

    return converted
88
+
89
+
90
def assert_intercept_in_spec(spec: fo.ModelSpec) -> fo.ModelSpec:
    """
    Checks that the spec's formula contains exactly one intercept term.

    A term's degree counts how many data columns it references, so a term of
    degree zero identifies an intercept. Returns ``spec`` unchanged when
    exactly one intercept is present; raises ``RuntimeError`` otherwise.
    """
    intercepts = [term for term in spec.formula if term.degree == 0]
    n_intercepts = len(intercepts)

    if n_intercepts > 1:
        raise RuntimeError(f"Too many intercepts: {n_intercepts}.")
    if n_intercepts == 0:
        raise RuntimeError(
            "No intercept found in formula. Did you explicitly remove an "
            "intercept by including '0' or '-1'? This breaks model matrix setup."
        )

    return spec
108
+
109
+
110
def validate_formula(formula: str) -> None:
    """
    Rejects formula features that this builder does not support.

    Raises ``ValueError`` for two-sided formulas ('~') and for explicit
    intercept manipulation via '1', '0', or '-1' terms — intercepts are
    controlled through the 'include_intercept' argument instead.
    """
    if "~" in formula:
        raise ValueError("'~' in formulas is not supported.")

    for raw_term in formula.split("+"):
        term = "".join(raw_term.split())  # strip all whitespace inside the term
        if term == "1":
            raise ValueError(
                "Using '1 +' is not supported. To add an intercept, use the "
                "argument 'include_intercept'."
            )
        if term in ("0", "-1"):
            raise ValueError(
                "Using '0 +' or '-1' is not supported. Intercepts are not included "
                "by default and can be added manually with the argument "
                "'include_intercept'."
            )
127
+
128
+
129
def validate_penalty_order(penalty_order: int) -> None:
    """
    Validates that ``penalty_order`` is a strictly positive integer.

    Raises
    ------
    TypeError
        If ``penalty_order`` is not an int. Booleans are rejected explicitly:
        ``bool`` is a subclass of ``int``, so ``True`` would otherwise pass
        silently as a penalty order of 1.
    ValueError
        If ``penalty_order`` is not strictly positive.
    """
    if isinstance(penalty_order, bool) or not isinstance(penalty_order, int):
        raise TypeError(
            f"'penalty_order' must be int or None, got {type(penalty_order)}"
        )
    if not penalty_order > 0:
        raise ValueError(f"'penalty_order' must be >0, got {penalty_order}")
136
+
137
+
138
class BasisBuilder:
    """Constructs basis objects (splines, random intercepts, MRFs, ...) from
    observed data held in a :class:`PandasRegistry`."""

    def __init__(
        self, registry: PandasRegistry, names: NameManager | None = None
    ) -> None:
        self.registry = registry
        # category-label mappings collected as categorical bases are built
        self.mappings: dict[str, CategoryMapping] = {}
        # name manager ensures unique variable names; a fresh one by default
        self.names = names if names is not None else NameManager()

    @property
    def data(self) -> pd.DataFrame:
        """The observed data frame held by the registry."""
        return self.registry.data
149
+
150
+ def basis(
151
+ self,
152
+ *x: str,
153
+ basis_fn: Callable[[Array], Array] = lambda x: x,
154
+ use_callback: bool = True,
155
+ cache_basis: bool = True,
156
+ penalty: ArrayLike | lsl.Value | None = None,
157
+ basis_name: str = "B",
158
+ ) -> Basis:
159
+ if isinstance(penalty, lsl.Value):
160
+ penalty.value = jnp.asarray(penalty.value)
161
+ elif penalty is not None:
162
+ penalty = jnp.asarray(penalty)
163
+
164
+ x_vars = []
165
+ for x_name in x:
166
+ x_var = self.registry.get_numeric_obs(x_name)
167
+ x_vars.append(x_var)
168
+
169
+ Xname = self.registry.prefix + ",".join(x)
170
+
171
+ Xvar = lsl.TransientCalc(
172
+ lambda *x: jnp.column_stack(x),
173
+ *x_vars,
174
+ _name=Xname,
175
+ )
176
+
177
+ basis = Basis(
178
+ value=Xvar,
179
+ basis_fn=basis_fn,
180
+ name=self.names.create_lazily(basis_name + "(" + Xname + ")"),
181
+ use_callback=use_callback,
182
+ cache_basis=cache_basis,
183
+ penalty=jnp.asarray(penalty),
184
+ )
185
+
186
+ return basis
187
+
188
+ def ps(
189
+ self,
190
+ x: str,
191
+ *,
192
+ k: int,
193
+ basis_degree: int = 3,
194
+ penalty_order: int = 2,
195
+ knots: ArrayLike | None = None,
196
+ absorb_cons: bool = True,
197
+ diagonal_penalty: bool = True,
198
+ scale_penalty: bool = True,
199
+ basis_name: str = "B",
200
+ ) -> Basis:
201
+ validate_penalty_order(penalty_order)
202
+ if knots is not None:
203
+ knots = np.asarray(knots)
204
+
205
+ spec = f"s({x}, bs='ps', k={k}, m=c({basis_degree - 1}, {penalty_order}))"
206
+ x_array = jnp.asarray(self.registry.data[x].to_numpy())
207
+ smooth = scon.SmoothCon(
208
+ spec,
209
+ data={x: x_array},
210
+ knots=knots,
211
+ absorb_cons=absorb_cons,
212
+ diagonal_penalty=diagonal_penalty,
213
+ scale_penalty=scale_penalty,
214
+ )
215
+
216
+ x_var = self.registry.get_numeric_obs(x)
217
+ basis = Basis(
218
+ x_var,
219
+ name=self.names.create_lazily(basis_name + "(" + x_var.name + ")"),
220
+ basis_fn=lambda x_: jnp.asarray(smooth.predict({x: x_})),
221
+ penalty=smooth.penalty,
222
+ use_callback=True,
223
+ cache_basis=True,
224
+ )
225
+
226
+ if absorb_cons:
227
+ basis._constraint = "absorbed_via_mgcv"
228
+ return basis
229
+
230
+ def cr(
231
+ self,
232
+ x: str,
233
+ *,
234
+ k: int,
235
+ penalty_order: int = 2,
236
+ knots: ArrayLike | None = None,
237
+ absorb_cons: bool = True,
238
+ diagonal_penalty: bool = True,
239
+ scale_penalty: bool = True,
240
+ basis_name: str = "B",
241
+ ) -> Basis:
242
+ validate_penalty_order(penalty_order)
243
+ if knots is not None:
244
+ knots = np.asarray(knots)
245
+ spec = f"s({x}, bs='cr', k={k}, m=c({penalty_order}))"
246
+ x_array = jnp.asarray(self.registry.data[x].to_numpy())
247
+ smooth = scon.SmoothCon(
248
+ spec,
249
+ data={x: x_array},
250
+ knots=knots,
251
+ absorb_cons=absorb_cons,
252
+ diagonal_penalty=diagonal_penalty,
253
+ scale_penalty=scale_penalty,
254
+ )
255
+
256
+ x_var = self.registry.get_numeric_obs(x)
257
+ basis = Basis(
258
+ x_var,
259
+ name=self.names.create_lazily(basis_name + "(" + x_var.name + ")"),
260
+ basis_fn=lambda x_: jnp.asarray(smooth.predict({x: x_})),
261
+ penalty=smooth.penalty,
262
+ use_callback=True,
263
+ cache_basis=True,
264
+ )
265
+
266
+ if absorb_cons:
267
+ basis._constraint = "absorbed_via_mgcv"
268
+ return basis
269
+
270
+ def cs(
271
+ self,
272
+ x: str,
273
+ *,
274
+ k: int,
275
+ penalty_order: int = 2,
276
+ knots: ArrayLike | None = None,
277
+ absorb_cons: bool = True,
278
+ diagonal_penalty: bool = True,
279
+ scale_penalty: bool = True,
280
+ basis_name: str = "B",
281
+ ) -> Basis:
282
+ """
283
+ s(x,bs="cs") specifies a penalized cubic regression spline which has had its
284
+ penalty modified to shrink towards zero at high enough smoothing parameters (as
285
+ the smoothing parameter goes to infinity a normal cubic spline tends to a
286
+ straight line.)
287
+ """
288
+ validate_penalty_order(penalty_order)
289
+ if knots is not None:
290
+ knots = np.asarray(knots)
291
+ spec = f"s({x}, bs='cs', k={k}, m=c({penalty_order}))"
292
+ x_array = jnp.asarray(self.registry.data[x].to_numpy())
293
+ smooth = scon.SmoothCon(
294
+ spec,
295
+ data={x: x_array},
296
+ knots=knots,
297
+ absorb_cons=absorb_cons,
298
+ diagonal_penalty=diagonal_penalty,
299
+ scale_penalty=scale_penalty,
300
+ )
301
+
302
+ x_var = self.registry.get_numeric_obs(x)
303
+ basis = Basis(
304
+ x_var,
305
+ name=self.names.create_lazily(basis_name + "(" + x_var.name + ")"),
306
+ basis_fn=lambda x_: jnp.asarray(smooth.predict({x: x_})),
307
+ penalty=smooth.penalty,
308
+ use_callback=True,
309
+ cache_basis=True,
310
+ )
311
+
312
+ if absorb_cons:
313
+ basis._constraint = "absorbed_via_mgcv"
314
+ return basis
315
+
316
+ def cc(
317
+ self,
318
+ x: str,
319
+ *,
320
+ k: int,
321
+ penalty_order: int = 2,
322
+ knots: ArrayLike | None = None,
323
+ absorb_cons: bool = True,
324
+ diagonal_penalty: bool = True,
325
+ scale_penalty: bool = True,
326
+ basis_name: str = "B",
327
+ ) -> Basis:
328
+ validate_penalty_order(penalty_order)
329
+ if knots is not None:
330
+ knots = np.asarray(knots)
331
+ spec = f"s({x}, bs='cc', k={k}, m=c({penalty_order}))"
332
+ x_array = jnp.asarray(self.registry.data[x].to_numpy())
333
+ smooth = scon.SmoothCon(
334
+ spec,
335
+ data={x: x_array},
336
+ knots=knots,
337
+ absorb_cons=absorb_cons,
338
+ diagonal_penalty=diagonal_penalty,
339
+ scale_penalty=scale_penalty,
340
+ )
341
+
342
+ x_var = self.registry.get_numeric_obs(x)
343
+ basis = Basis(
344
+ x_var,
345
+ name=self.names.create_lazily(basis_name + "(" + x_var.name + ")"),
346
+ basis_fn=lambda x_: jnp.asarray(smooth.predict({x: x_})),
347
+ penalty=smooth.penalty,
348
+ use_callback=True,
349
+ cache_basis=True,
350
+ )
351
+
352
+ if absorb_cons:
353
+ basis._constraint = "absorbed_via_mgcv"
354
+ return basis
355
+
356
+ def bs(
357
+ self,
358
+ x: str,
359
+ *,
360
+ k: int,
361
+ basis_degree: int = 3,
362
+ penalty_order: int | Sequence[int] = 2,
363
+ knots: ArrayLike | None = None,
364
+ absorb_cons: bool = True,
365
+ diagonal_penalty: bool = True,
366
+ scale_penalty: bool = True,
367
+ basis_name: str = "B",
368
+ ) -> Basis:
369
+ """
370
+ The integrated square of the m[2]th derivative is used as the penalty. So
371
+ m=c(3,2) is a conventional cubic spline. Any further elements of m, after the
372
+ first 2, define the order of derivative in further penalties. If m is supplied
373
+ as a single number, then it is taken to be m[1] and m[2]=m[1]-1, which is only a
374
+ conventional smoothing spline in the m=3, cubic spline case.
375
+ """
376
+ if knots is not None:
377
+ knots = np.asarray(knots)
378
+ if isinstance(penalty_order, int):
379
+ validate_penalty_order(penalty_order)
380
+ penalty_order_seq: Sequence[str] = [str(penalty_order)]
381
+ else:
382
+ [validate_penalty_order(p) for p in penalty_order]
383
+ penalty_order_seq = [str(p) for p in penalty_order]
384
+
385
+ spec = (
386
+ f"s({x}, bs='bs', k={k}, "
387
+ f"m=c({basis_degree}, {', '.join(penalty_order_seq)}))"
388
+ )
389
+ x_array = jnp.asarray(self.registry.data[x].to_numpy())
390
+ smooth = scon.SmoothCon(
391
+ spec,
392
+ data={x: x_array},
393
+ knots=knots,
394
+ absorb_cons=absorb_cons,
395
+ diagonal_penalty=diagonal_penalty,
396
+ scale_penalty=scale_penalty,
397
+ )
398
+
399
+ x_var = self.registry.get_numeric_obs(x)
400
+ basis = Basis(
401
+ x_var,
402
+ name=self.names.create_lazily(basis_name + "(" + x_var.name + ")"),
403
+ basis_fn=lambda x_: jnp.asarray(smooth.predict({x: x_})),
404
+ penalty=smooth.penalty,
405
+ use_callback=True,
406
+ cache_basis=True,
407
+ )
408
+
409
+ if absorb_cons:
410
+ basis._constraint = "absorbed_via_mgcv"
411
+ return basis
412
+
413
+ def cp(
414
+ self,
415
+ x: str,
416
+ *,
417
+ k: int,
418
+ basis_degree: int = 3,
419
+ penalty_order: int = 2,
420
+ knots: ArrayLike | None = None,
421
+ absorb_cons: bool = True,
422
+ diagonal_penalty: bool = True,
423
+ scale_penalty: bool = True,
424
+ basis_name: str = "B",
425
+ ) -> Basis:
426
+ validate_penalty_order(penalty_order)
427
+ if knots is not None:
428
+ knots = np.asarray(knots)
429
+ spec = f"s({x}, bs='cp', k={k}, m=c({basis_degree - 1}, {penalty_order}))"
430
+ x_array = jnp.asarray(self.registry.data[x].to_numpy())
431
+ smooth = scon.SmoothCon(
432
+ spec,
433
+ data={x: x_array},
434
+ knots=knots,
435
+ absorb_cons=absorb_cons,
436
+ diagonal_penalty=diagonal_penalty,
437
+ scale_penalty=scale_penalty,
438
+ )
439
+
440
+ x_var = self.registry.get_numeric_obs(x)
441
+ basis = Basis(
442
+ x_var,
443
+ name=self.names.create_lazily(basis_name + "(" + x_var.name + ")"),
444
+ basis_fn=lambda x_: jnp.asarray(smooth.predict({x: x_})),
445
+ penalty=smooth.penalty,
446
+ use_callback=True,
447
+ cache_basis=True,
448
+ )
449
+
450
+ if absorb_cons:
451
+ basis._constraint = "absorbed_via_mgcv"
452
+ return basis
453
+
454
+ def s(
455
+ self,
456
+ *x: str,
457
+ k: int,
458
+ bs: BasisTypes,
459
+ m: str = "NA",
460
+ knots: ArrayLike | None = None,
461
+ absorb_cons: bool = True,
462
+ diagonal_penalty: bool = True,
463
+ scale_penalty: bool = True,
464
+ basis_name: str = "B",
465
+ ) -> Basis:
466
+ if knots is not None:
467
+ knots = np.asarray(knots)
468
+ _validate_bs(bs)
469
+ bs_arg = f"'{bs}'"
470
+ spec = f"s({','.join(x)}, bs={bs_arg}, k={k}, m={m})"
471
+
472
+ obs_vars = {}
473
+ for xname in x:
474
+ obs_vars[xname] = self.registry.get_numeric_obs(xname)
475
+ obs_values = {k: np.asarray(v.value) for k, v in obs_vars.items()}
476
+
477
+ smooth = scon.SmoothCon(
478
+ spec,
479
+ data=pd.DataFrame.from_dict(obs_values),
480
+ knots=knots,
481
+ absorb_cons=absorb_cons,
482
+ diagonal_penalty=diagonal_penalty,
483
+ scale_penalty=scale_penalty,
484
+ )
485
+
486
+ xname = ",".join([v.name for v in obs_vars.values()])
487
+
488
+ if len(obs_vars) > 1:
489
+ xvar: lsl.Var | lsl.TransientCalc = (
490
+ lsl.TransientCalc( # for memory-efficiency
491
+ lambda *args: jnp.vstack(args).T,
492
+ *list(obs_vars.values()),
493
+ _name=self.names.create_lazily(xname),
494
+ )
495
+ )
496
+ else:
497
+ xvar = obs_vars[xname]
498
+
499
+ def basis_fn(x):
500
+ df = pd.DataFrame(x, columns=list(obs_vars))
501
+ return jnp.asarray(smooth.predict(df))
502
+
503
+ basis = Basis(
504
+ xvar,
505
+ name=self.names.create_lazily(basis_name + "(" + xname + ")"),
506
+ basis_fn=basis_fn,
507
+ penalty=smooth.penalty,
508
+ use_callback=True,
509
+ cache_basis=True,
510
+ )
511
+ if absorb_cons:
512
+ basis._constraint = "absorbed_via_mgcv"
513
+ return basis
514
+
515
+ def tp(
516
+ self,
517
+ *x: str,
518
+ k: int,
519
+ penalty_order: int | None = None,
520
+ knots: ArrayLike | None = None,
521
+ absorb_cons: bool = True,
522
+ diagonal_penalty: bool = True,
523
+ scale_penalty: bool = True,
524
+ basis_name: str = "B",
525
+ remove_null_space_completely: bool = False,
526
+ ) -> Basis:
527
+ """
528
+ For penalty_order:
529
+ m = penalty_order
530
+ Quote from MGCV docs
531
+ The default is to set m (the order of derivative in the thin plate spline
532
+ penalty) to the smallest value satisfying 2m > d+1 where d is the number of
533
+ covariates of the term: this yields ‘visually smooth’ functions.
534
+ In any case 2m>d must be satisfied.
535
+ """
536
+ d = len(x)
537
+ m_args = []
538
+ if penalty_order is None:
539
+ penalty_order_default = ceil((d + 1) / 2)
540
+ i = 0
541
+ while not 2 * penalty_order_default > (d + 1) and i < 20:
542
+ penalty_order_default += 1
543
+ i += 1
544
+
545
+ m_args.append(str(penalty_order_default))
546
+ else:
547
+ validate_penalty_order(penalty_order)
548
+ m_args.append(str(penalty_order))
549
+
550
+ if remove_null_space_completely:
551
+ m_args.append("0")
552
+ m_str = "c(" + ", ".join(m_args) + ")"
553
+
554
+ basis = self.s(
555
+ *x,
556
+ k=k,
557
+ bs="tp",
558
+ m=m_str,
559
+ knots=knots,
560
+ absorb_cons=absorb_cons,
561
+ diagonal_penalty=diagonal_penalty,
562
+ scale_penalty=scale_penalty,
563
+ basis_name=basis_name,
564
+ )
565
+ return basis
566
+
567
+ def ts(
568
+ self,
569
+ *x: str,
570
+ k: int,
571
+ penalty_order: int | None = None,
572
+ knots: ArrayLike | None = None,
573
+ absorb_cons: bool = True,
574
+ diagonal_penalty: bool = True,
575
+ scale_penalty: bool = True,
576
+ basis_name: str = "B",
577
+ ) -> Basis:
578
+ """
579
+ For penalty_order:
580
+ m = penalty_order
581
+ Quote from MGCV docs
582
+ The default is to set m (the order of derivative in the thin plate spline
583
+ penalty) to the smallest value satisfying 2m > d+1 where d is the number of
584
+ covariates of the term: this yields ‘visually smooth’ functions.
585
+ In any case 2m>d must be satisfied.
586
+ """
587
+ d = len(x)
588
+ m_args = []
589
+ if not penalty_order:
590
+ m_args.append(str(ceil((d + 1) / 2)))
591
+ else:
592
+ validate_penalty_order(penalty_order)
593
+ m_args.append(str(penalty_order))
594
+
595
+ m_str = "c(" + ", ".join(m_args) + ")"
596
+
597
+ basis = self.s(
598
+ *x,
599
+ k=k,
600
+ bs="ts",
601
+ m=m_str,
602
+ knots=knots,
603
+ absorb_cons=absorb_cons,
604
+ diagonal_penalty=diagonal_penalty,
605
+ scale_penalty=scale_penalty,
606
+ basis_name=basis_name,
607
+ )
608
+ return basis
609
+
610
+ def kriging(
611
+ self,
612
+ *x: str,
613
+ k: int,
614
+ kernel_name: Literal[
615
+ "spherical",
616
+ "power_exponential",
617
+ "matern1.5",
618
+ "matern2.5",
619
+ "matern3.5",
620
+ ] = "matern1.5",
621
+ linear_trend: bool = True,
622
+ range: float | None = None,
623
+ power_exponential_power: float = 1.0,
624
+ knots: ArrayLike | None = None,
625
+ absorb_cons: bool = True,
626
+ diagonal_penalty: bool = True,
627
+ scale_penalty: bool = True,
628
+ basis_name: str = "B",
629
+ ) -> Basis:
630
+ """
631
+
632
+ - If range=None, the range parameter will be estimated as in Kammann and \
633
+ Wand (2003)
634
+ """
635
+ m_kernel_dict = {
636
+ "spherical": 1,
637
+ "power_exponential": 2,
638
+ "matern1.5": 3,
639
+ "matern2.5": 4,
640
+ "matern3.5": 5,
641
+ }
642
+ m_linear = 1.0 if linear_trend else -1.0
643
+
644
+ m_args = []
645
+ m_kernel = str(int(m_linear * m_kernel_dict[kernel_name]))
646
+ m_args.append(m_kernel)
647
+ if range:
648
+ m_range = str(range)
649
+ m_args.append(m_range)
650
+ if power_exponential_power:
651
+ if not range:
652
+ m_args.append(str(-1.0))
653
+ if not 0.0 < power_exponential_power <= 2.0:
654
+ raise ValueError(
655
+ "'power_exponential_power' must be in (0, 2.0], "
656
+ f"got {power_exponential_power}"
657
+ )
658
+ m_args.append(str(power_exponential_power))
659
+
660
+ m_str = "c(" + ", ".join(m_args) + ")"
661
+
662
+ basis = self.s(
663
+ *x,
664
+ k=k,
665
+ bs="gp",
666
+ m=m_str,
667
+ knots=knots,
668
+ absorb_cons=absorb_cons,
669
+ diagonal_penalty=diagonal_penalty,
670
+ scale_penalty=scale_penalty,
671
+ basis_name=basis_name,
672
+ )
673
+
674
+ return basis
675
+
676
    def lin(
        self,
        formula: str,
        xname: str = "",
        basis_name: str = "X",
        include_intercept: bool = False,
        context: dict[str, Any] | None = None,
    ) -> LinBasis:
        """
        Creates a linear (parametric) design-matrix basis from a one-sided
        formulaic formula evaluated against the registry's data.

        Categorical inputs are stored in their integer representation and
        converted back to labels inside the basis function before the model
        matrix is rebuilt, so the basis can be re-evaluated on new data.
        ``context`` is forwarded to formulaic for resolving names used in the
        formula.
        """
        validate_formula(formula)
        spec = fo.ModelSpec(formula, output="numpy")

        if not include_intercept:
            # because we do our own intercept handling with the full model matrix
            # it may be surprising to assert that there is an intercept only if
            # the plan is to remove it.
            # But in order to safely remove it, we first have to ensure that it is
            # present.
            assert_intercept_in_spec(spec)

        # evaluate model matrix once to get a spec with structure information
        # also necessary to populate spec with the correct information for
        # transformations like center, scale, standardize
        spec = spec.get_model_matrix(self.data, context=context).model_spec

        # get column names. There may be a more efficient way to do it
        # that does not require building the model matrix a second time, but this
        # works robustly for now: we take the names that formulaic creates
        # (dropping the intercept column name at index 0)
        column_names = list(
            fo.ModelSpec(formula, output="pandas")
            .get_model_matrix(self.data, context=context)
            .columns
        )[1:]

        # sorted for a deterministic column order of the stacked input matrix
        required = sorted(str(var) for var in spec.required_variables)
        df_subset = self.data.loc[:, required]
        df_colnames = df_subset.columns

        variables = dict()

        mappings = {}
        for col in df_colnames:
            result = self.registry.get_obs_and_mapping(col)
            variables[col] = result.var

            if result.mapping is not None:
                # remember category mappings both on the builder and locally
                self.mappings[col] = result.mapping
                mappings[col] = result.mapping

        xvar = lsl.TransientCalc(  # for memory-efficiency
            lambda *args: jnp.vstack(args).T,
            *list(variables.values()),
            _name=self.names.create_lazily(xname) if xname else xname,
        )

        def basis_fn(x):
            df = pd.DataFrame(x, columns=df_colnames)

            # for categorical variables: convert integer representation back to
            # labels
            for col in df_colnames:
                if col in self.mappings:
                    integers = df[col].to_numpy()
                    df[col] = self.mappings[col].integers_to_labels(integers)

            basis = np.asarray(spec.get_model_matrix(df, context=context))
            if not include_intercept:
                # drop the intercept column asserted/created above
                basis = basis[:, 1:]
            return jnp.asarray(basis, dtype=float)

        if xname:
            bname = self.names.create_lazily(basis_name + "(" + xvar.name + ")")
        else:
            bname = self.names.create_lazily(basis_name)

        basis = LinBasis(
            xvar,
            basis_fn=basis_fn,
            use_callback=True,
            cache_basis=True,
            name=bname,
        )

        # attach metadata needed for prediction on new data
        basis.model_spec = spec
        basis.mappings = mappings
        basis.column_names = column_names

        return basis
763
+
764
+ def ri(
765
+ self,
766
+ cluster: str,
767
+ basis_name: str = "B",
768
+ penalty: ArrayLike | None = None,
769
+ ) -> Basis:
770
+ if penalty is not None:
771
+ penalty = jnp.asarray(penalty)
772
+ result = self.registry.get_obs_and_mapping(cluster)
773
+ if result.mapping is None:
774
+ raise TypeError(f"{cluster=} must be categorical.")
775
+
776
+ self.mappings[cluster] = result.mapping
777
+ nparams = len(result.mapping.labels_to_integers_map)
778
+
779
+ if penalty is None:
780
+ penalty = jnp.eye(nparams)
781
+
782
+ basis = Basis(
783
+ value=result.var,
784
+ basis_fn=lambda x: x,
785
+ name=self.names.create_lazily(basis_name + "(" + cluster + ")"),
786
+ use_callback=False,
787
+ cache_basis=False,
788
+ penalty=jnp.asarray(penalty) if penalty is not None else penalty,
789
+ )
790
+
791
+ return basis
792
+
793
    def mrf(
        self,
        x: str,
        k: int = -1,
        polys: dict[str, ArrayLike] | None = None,
        nb: Mapping[str, ArrayLike | list[str] | list[int]] | None = None,
        penalty: ArrayLike | None = None,
        penalty_labels: Sequence[str] | None = None,
        absorb_cons: bool = False,
        diagonal_penalty: bool = False,
        scale_penalty: bool = False,
        basis_name: str = "B",
    ) -> MRFBasis:
        """
        Markov random field basis for a categorical (regional) variable ``x``.

        Polys: Dictionary of arrays. The keys of the dict are the region labels.
            The corresponding values define the region by defining polygons.
        nb: Dictionary of arrays. The keys of the dict are the region labels.
            The corresponding values indicate the neighbors of the region.
            If it is a list or array of strings, the values are the labels of the
            neighbors.
            If it is a list or array of integers, the values are the indices of the
            neighbors.

        mgcv does not concern itself with your category ordering. It *will* order
        categories alphabetically. Penalty columns have to take this into account.

        Comments on return value:

        - If either polys or nb are supplied, the returned container will contain nb.
        - If only a penalty matrix is supplied, the returned container will *not*
          contain nb.
        - Returning the label order only makes sense if the basis is *not*
          reparameterized, because only then we have a clear correspondence of
          parameters to labels.
          If the basis is reparameterized, there's no such correspondence in a clear
          way, so the returned label order is None.
        """

        if not isinstance(k, int):
            raise TypeError(f"'k' must be int, got {type(k)}.")
        if k < -1:
            # -1 is mgcv's "use all levels" sentinel
            raise ValueError(f"'k' cannot be smaller than -1, got {k=}.")

        if polys is None and nb is None and penalty is None:
            raise ValueError("At least one of polys, nb, or penalty must be provided.")

        var, mapping = self.registry.get_categorical_obs(x)
        self.mappings[x] = mapping

        # the set of unique region labels observed for x
        labels = set(list(mapping.labels_to_integers_map))

        if penalty is not None:
            if penalty_labels is None:
                raise ValueError(
                    "If 'penalty' is supplied, 'penalty_labels' must also be supplied."
                )
            if len(penalty_labels) != len(labels):
                raise ValueError(
                    f"Variable {x} has {len(labels)} unique entries, but "
                    f"'penalty_labels' has {len(penalty_labels)}. Both must match."
                )

        # xt_args collects the entries of mgcv's xt=list(...) argument;
        # pass_to_r holds the corresponding R-side objects by name
        xt_args = []
        pass_to_r: dict[str, np.typing.NDArray | dict[str, np.typing.NDArray]] = {}
        if polys is not None:
            xt_args.append("polys=polys")
            if not labels == set(list(polys)):
                raise ValueError(
                    "Names in 'poly' must correspond to the levels of 'x'."
                )
            pass_to_r["polys"] = {key: np.asarray(val) for key, val in polys.items()}

        if nb is not None:
            xt_args.append("nb=nb")
            if not labels == set(list(nb)):
                raise ValueError("Names in 'nb' must correspond to the levels of 'x'.")

            nb_processed = {}
            for key, val in nb.items():
                val_arr = np.asarray(val)
                if np.isdtype(val_arr.dtype, np.dtype("int")):
                    # add one to convert to 1-based indexing for R
                    # and cast to float for R
                    val_arr = np.astype(val_arr + 1, float)
                    # val_arr = np.astype(val_arr, float)
                elif np.isdtype(val_arr.dtype, np.dtype("float")):
                    # add one to convert to 1-based indexing for R
                    val_arr = np.astype(np.astype(val_arr, int) + 1, float)
                elif val_arr.dtype.kind == "U":  # must be unicode strings then
                    pass
                else:
                    raise TypeError(f"Unsupported dtype: {val_arr.dtype!r}")

                nb_processed[key] = val_arr

            pass_to_r["nb"] = nb_processed

        if penalty is not None:
            penalty = np.asarray(penalty)
            # an MRF penalty on a connected neighborhood graph has rank
            # deficiency exactly 1; anything else hints at disconnected regions
            pen_rank = np.linalg.matrix_rank(penalty)
            pen_dim = penalty.shape[-1]
            if (pen_dim - pen_rank) != 1:
                logger.warning(
                    f"Supplied penalty has dimension {penalty.shape} and rank "
                    f"{pen_rank}. The expected rank deficiency is 1. "
                    "This may indicate a problem. There might be disconnected sets "
                    "of regions in the data represented by this penalty. "
                    "In this case, you probably need more elaborate constraints "
                    "than the ones provided here. You might consider splitting the "
                    "disconnected regions into several mrf terms. "
                    "Otherwise, please only continue if you are certain that you "
                    "know what is happening."
                )

            xt_args.append("penalty=penalty")
            if not np.shape(penalty)[0] == np.shape(penalty)[1]:
                raise ValueError(f"Penalty must be square, got {np.shape(penalty)=}")

            if not np.shape(penalty)[1] == len(labels):
                raise ValueError(
                    "Dimensions of 'penalty' must correspond to the levels of 'x'."
                )
            pass_to_r["penalty"] = penalty

        xt = "list("
        xt += ",".join(xt_args)
        xt += ")"

        if penalty is not None:
            # removing penalty from the pass_to_r dict, because we are giving it
            # special treatment here.
            # specifically, we have to equip it with row and column names to make
            # sure that penalty entries get correctly matched to clusters by mgcv
            penalty_prelim_arr = np.asarray(pass_to_r.pop("penalty"))
            to_r(penalty_prelim_arr, "penalty")
            to_r(np.array(penalty_labels), "penalty_labels")
            r("colnames(penalty) <- penalty_labels")
            r("rownames(penalty) <- penalty_labels")

        spec = f"s({x}, k={k}, bs='mrf', xt={xt})"

        # disabling warnings about "mrf should be a factor"
        # since even turning data into a pandas df and x_array into
        # a categorical series did not satisfy mgcv in that regard.
        # Things still seem to work, and we ensure further above
        # that we are actually dealing with a categorical variable
        # so I think turning the warnings off temporarily here is fine
        # r("old_warn <- getOption('warn')")
        # r("options(warn = -1)")
        observed = mapping.integers_to_labels(var.value)
        regions = list(mapping.labels_to_integers_map)
        df = pd.DataFrame({x: pd.Categorical(observed, categories=regions)})

        smooth = scon.SmoothCon(
            spec,
            data=df,
            diagonal_penalty=diagonal_penalty,
            absorb_cons=absorb_cons,
            scale_penalty=scale_penalty,
            pass_to_r=pass_to_r,
        )
        # r("options(warn = old_warn)")

        # bind the covariate name for use inside the closure below
        x_name = x

        def basis_fun(x):
            """
            The array outputted by this smooth contains column names.
            Here, we remove these column names and convert to jax.
            """
            # disabling warnings about "mrf should be a factor"
            r("old_warn <- getOption('warn')")
            r("options(warn = -1)")
            labels = mapping.integers_to_labels(x)
            df = pd.DataFrame({x_name: pd.Categorical(labels, categories=regions)})
            # drop the first (label) column, keep the numeric design columns
            basis = jnp.asarray(np.astype(smooth.predict(df)[:, 1:], "float"))
            r("options(warn = old_warn)")
            return basis

        smooth_penalty = smooth.penalty
        if np.shape(smooth_penalty)[1] > len(labels):
            # penalty also carries a leading label column; drop it
            smooth_penalty = smooth_penalty[:, 1:]

        penalty_arr = jnp.asarray(np.astype(smooth_penalty, "float"))

        basis = MRFBasis(
            value=var,
            basis_fn=basis_fun,
            name=self.names.create_lazily(basis_name + "(" + x + ")"),
            cache_basis=True,
            use_callback=True,
            penalty=penalty_arr,
        )
        if absorb_cons:
            basis._constraint = "absorbed_via_mgcv"

        # recover the neighborhood structure mgcv stored on the smooth, if any
        try:
            nb_out = to_py(f"{smooth._smooth_r_name}[[1]]$xt$nb", format="numpy")
        except TypeError:
            nb_out = None
        # nb_out = {key: np.astype(val, "int") for key, val in nb_out.items()}

        if absorb_cons:
            # reparameterized basis: no one-to-one parameter/label correspondence
            label_order = None
        else:
            label_order = list(
                to_py(f"{smooth._smooth_r_name}[[1]]$X", format="pandas").columns
            )
            label_order = [lab[1:] for lab in label_order]  # removes leading x from R

        if nb_out is not None:

            def to_label(code):
                # R returns 1-based numeric neighbor codes; map back to labels.
                # If the values are already labels, integers_to_labels raises
                # TypeError and we keep them as-is.
                try:
                    label_array = mapping.integers_to_labels(code - 1)
                except TypeError:
                    label_array = code
                return np.atleast_1d(label_array).tolist()

            nb_out = {k: to_label(v) for k, v in nb_out.items()}

        basis.mrf_spec = MRFSpec(mapping, nb_out, label_order)

        return basis
1019
+
1020
+
1021
@dataclass
class NameManager:
    """Produces unique names by keeping an individual counter per base name.

    Attributes
    ----------
    prefix
        Optional prefix prepended to names before counting.
    created_names
        Maps each (possibly prefixed) base name to the number of times it
        has been requested so far.
    """

    prefix: str = ""
    created_names: dict[str, int] = field(default_factory=dict)

    def create(self, name: str, apply_prefix: bool = True) -> str:
        """Return ``name`` with its per-name counter appended.

        The counter starts at 0, so even the first request yields an indexed
        name (e.g. ``"x0"``). If ``apply_prefix`` is true, the builder prefix
        is applied before counting.
        """
        if apply_prefix:
            name = self.prefix + name
        count = self.created_names.get(name, 0)
        self.created_names[name] = count + 1
        return f"{name}{count}"

    def create_lazily(self, name: str, apply_prefix: bool = True) -> str:
        """Like :meth:`create`, but the first request returns the plain name.

        Only repeated requests for the same name receive a numeric suffix.
        """
        if apply_prefix:
            name = self.prefix + name
        count = self.created_names.get(name, 0)
        self.created_names[name] = count + 1
        return f"{name}{count}" if count else name

    def fname(self, f: str, basis: Basis) -> str:
        """Return a lazily-counted function-style name like ``f(x)``."""
        return self.create_lazily(f"{f}({basis.x.name})")

    def create_param_name(self, term_name: str, param_name: str) -> str:
        """Return a LaTeX-style parameter name, subscripted by the term name.

        With a term name, the result looks like ``$\\beta_{term}$`` (no
        prefix); without one, it is just ``$\\beta$`` with the prefix applied.
        """
        if term_name:
            decorated = "$" + param_name + "_{" + term_name + "}$"
            return self.create_lazily(decorated, apply_prefix=False)
        decorated = "$" + param_name + "$"
        return self.create_lazily(decorated, apply_prefix=True)

    def create_beta_name(self, term_name: str) -> str:
        """LaTeX name for a coefficient vector of the given term."""
        return self.create_param_name(term_name=term_name, param_name="\\beta")

    def create_tau_name(self, term_name: str) -> str:
        """LaTeX name for a scale parameter of the given term."""
        return self.create_param_name(term_name=term_name, param_name="\\tau")

    def create_tau2_name(self, term_name: str) -> str:
        """LaTeX name for a variance parameter of the given term."""
        return self.create_param_name(term_name=term_name, param_name="\\tau^2")
1079
+
1080
+
1081
class TermBuilder:
    """Builds structured additive model terms (linear, splines, MRFs, ...).

    Wraps a :class:`PandasRegistry` for observed data, a :class:`NameManager`
    for unique variable names, and a ``BasisBuilder`` for basis matrices.
    """

    def __init__(self, registry: PandasRegistry, prefix_names_by: str = "") -> None:
        """Initialize the builder over ``registry``.

        ``prefix_names_by`` is prepended to all generated names via the
        shared :class:`NameManager`.
        """
        self.registry = registry
        self.names = NameManager(prefix=prefix_names_by)
        self.bases = BasisBuilder(registry, names=self.names)
1086
+
1087
+ def _init_default_scale(
1088
+ self,
1089
+ concentration: float | Array,
1090
+ scale: float | Array,
1091
+ value: float | Array = 1.0,
1092
+ term_name: str = "",
1093
+ ) -> ScaleIG:
1094
+ scale_name = self.names.create_tau_name(term_name)
1095
+ variance_name = self.names.create_tau2_name(term_name)
1096
+ scale_var = ScaleIG(
1097
+ value=value,
1098
+ concentration=concentration,
1099
+ scale=scale,
1100
+ name=scale_name,
1101
+ variance_name=variance_name,
1102
+ )
1103
+ return scale_var
1104
+
1105
    @classmethod
    def from_dict(
        cls, data: dict[str, ArrayLike], prefix_names_by: str = ""
    ) -> TermBuilder:
        """Build a :class:`TermBuilder` from a dict of columns.

        Convenience wrapper that converts ``data`` to a DataFrame and
        delegates to :meth:`from_df`.
        """
        return cls.from_df(pd.DataFrame(data), prefix_names_by=prefix_names_by)
1110
+
1111
    @classmethod
    def from_df(cls, data: pd.DataFrame, prefix_names_by: str = "") -> TermBuilder:
        """Build a :class:`TermBuilder` from a DataFrame.

        Rows with missing values are dropped by the registry
        (``na_action="drop"``).
        """
        registry = PandasRegistry(
            data, na_action="drop", prefix_names_by=prefix_names_by
        )
        return cls(registry, prefix_names_by=prefix_names_by)
1117
+
1118
    def labels_to_integers(self, newdata: dict) -> dict:
        """Translate categorical labels in ``newdata`` to integer codes.

        Delegates to the module-level helper of the same name, using the
        category mappings collected by the basis builder.
        """
        return labels_to_integers(newdata, self.bases.mappings)
1120
+
1121
+ # formula
1122
+ def lin(
1123
+ self,
1124
+ formula: str,
1125
+ prior: lsl.Dist | None = None,
1126
+ inference: InferenceTypes | None = gs.MCMCSpec(gs.IWLSKernel),
1127
+ include_intercept: bool = False,
1128
+ context: dict[str, Any] | None = None,
1129
+ ) -> LinTerm:
1130
+ r"""
1131
+ Supported:
1132
+ - {a+1} for quoted Python
1133
+ - `weird name` backtick-strings for weird names
1134
+ - (a + b)**n for n-th order interactions
1135
+ - a:b for simple interactions
1136
+ - a*b for expanding to a + b + a:b
1137
+ - a / b for nesting
1138
+ - b %in% a for inverted nesting
1139
+ - Python functions
1140
+ - bs
1141
+ - cr
1142
+ - cs
1143
+ - cc
1144
+ - hashed
1145
+
1146
+ .. warning:: If you use bs, cr, cs, or cc, be aware that these will not
1147
+ lead to terms that include a penalty. In most cases, you probably want
1148
+ to use :meth:`~.TermBuilder.s`, :meth:`~.TermBuilder.ps`, and so on
1149
+ instead.
1150
+
1151
+ Not supported:
1152
+
1153
+ - String literals
1154
+ - Numeric literals
1155
+ - Wildcard "."
1156
+ - \| for splitting a formula
1157
+ - "te" tensor products
1158
+
1159
+ - "~" in formula
1160
+ - 1 + in formula
1161
+ - 0 + in formula
1162
+ - -1 in formula
1163
+
1164
+ """
1165
+
1166
+ basis = self.bases.lin(
1167
+ formula,
1168
+ xname="",
1169
+ basis_name="X",
1170
+ include_intercept=include_intercept,
1171
+ context=context,
1172
+ )
1173
+
1174
+ if basis.x.name:
1175
+ term_name = self.names.create_lazily("lin" + "(" + basis.x.name + ")")
1176
+ else:
1177
+ term_name = self.names.create_lazily("lin" + "(" + basis.name + ")")
1178
+
1179
+ coef_name = self.names.create_beta_name(term_name)
1180
+
1181
+ term = LinTerm(
1182
+ basis, prior=prior, name=term_name, inference=inference, coef_name=coef_name
1183
+ )
1184
+
1185
+ term.model_spec = basis.model_spec
1186
+ term.mappings = basis.mappings
1187
+ term.column_names = basis.column_names
1188
+
1189
+ return term
1190
+
1191
+ def cr(
1192
+ self,
1193
+ x: str,
1194
+ *,
1195
+ k: int,
1196
+ scale: ScaleIG | lsl.Var | float | VarIGPrior = VarIGPrior(1.0, 0.005),
1197
+ inference: InferenceTypes | None = gs.MCMCSpec(gs.IWLSKernel),
1198
+ penalty_order: int = 2,
1199
+ knots: ArrayLike | None = None,
1200
+ absorb_cons: bool = True,
1201
+ diagonal_penalty: bool = True,
1202
+ scale_penalty: bool = True,
1203
+ noncentered: bool = False,
1204
+ ) -> Term:
1205
+ basis = self.bases.cr(
1206
+ x=x,
1207
+ k=k,
1208
+ penalty_order=penalty_order,
1209
+ knots=knots,
1210
+ absorb_cons=absorb_cons,
1211
+ diagonal_penalty=diagonal_penalty,
1212
+ scale_penalty=scale_penalty,
1213
+ basis_name="B",
1214
+ )
1215
+
1216
+ fname = self.names.fname("cr", basis)
1217
+
1218
+ if isinstance(scale, VarIGPrior):
1219
+ scale = self._init_default_scale(
1220
+ concentration=scale.concentration, scale=scale.scale, term_name=fname
1221
+ )
1222
+
1223
+ coef_name = self.names.create_beta_name(fname)
1224
+ term = Term(
1225
+ basis=basis,
1226
+ penalty=basis.penalty,
1227
+ scale=scale,
1228
+ name=fname,
1229
+ inference=inference,
1230
+ coef_name=coef_name,
1231
+ )
1232
+ if noncentered:
1233
+ term.reparam_noncentered()
1234
+ return term
1235
+
1236
+ def cs(
1237
+ self,
1238
+ x: str,
1239
+ *,
1240
+ k: int,
1241
+ scale: ScaleIG | lsl.Var | float | VarIGPrior = VarIGPrior(1.0, 0.005),
1242
+ inference: InferenceTypes | None = gs.MCMCSpec(gs.IWLSKernel),
1243
+ penalty_order: int = 2,
1244
+ knots: ArrayLike | None = None,
1245
+ absorb_cons: bool = True,
1246
+ diagonal_penalty: bool = True,
1247
+ scale_penalty: bool = True,
1248
+ noncentered: bool = False,
1249
+ ) -> Term:
1250
+ basis = self.bases.cs(
1251
+ x=x,
1252
+ k=k,
1253
+ penalty_order=penalty_order,
1254
+ knots=knots,
1255
+ absorb_cons=absorb_cons,
1256
+ diagonal_penalty=diagonal_penalty,
1257
+ scale_penalty=scale_penalty,
1258
+ basis_name="B",
1259
+ )
1260
+
1261
+ fname = self.names.fname("cs", basis)
1262
+
1263
+ if isinstance(scale, VarIGPrior):
1264
+ scale = self._init_default_scale(
1265
+ concentration=scale.concentration, scale=scale.scale, term_name=fname
1266
+ )
1267
+
1268
+ coef_name = self.names.create_beta_name(fname)
1269
+ term = Term(
1270
+ basis=basis,
1271
+ penalty=basis.penalty,
1272
+ scale=scale,
1273
+ name=fname,
1274
+ inference=inference,
1275
+ coef_name=coef_name,
1276
+ )
1277
+ if noncentered:
1278
+ term.reparam_noncentered()
1279
+ return term
1280
+
1281
+ def cc(
1282
+ self,
1283
+ x: str,
1284
+ *,
1285
+ k: int,
1286
+ scale: ScaleIG | lsl.Var | float | VarIGPrior = VarIGPrior(1.0, 0.005),
1287
+ inference: InferenceTypes | None = gs.MCMCSpec(gs.IWLSKernel),
1288
+ penalty_order: int = 2,
1289
+ knots: ArrayLike | None = None,
1290
+ absorb_cons: bool = True,
1291
+ diagonal_penalty: bool = True,
1292
+ scale_penalty: bool = True,
1293
+ noncentered: bool = False,
1294
+ ) -> Term:
1295
+ basis = self.bases.cc(
1296
+ x=x,
1297
+ k=k,
1298
+ penalty_order=penalty_order,
1299
+ knots=knots,
1300
+ absorb_cons=absorb_cons,
1301
+ diagonal_penalty=diagonal_penalty,
1302
+ scale_penalty=scale_penalty,
1303
+ basis_name="B",
1304
+ )
1305
+
1306
+ fname = self.names.fname("cc", basis)
1307
+
1308
+ if isinstance(scale, VarIGPrior):
1309
+ scale = self._init_default_scale(
1310
+ concentration=scale.concentration, scale=scale.scale, term_name=fname
1311
+ )
1312
+
1313
+ coef_name = self.names.create_beta_name(fname)
1314
+ term = Term(
1315
+ basis=basis,
1316
+ penalty=basis.penalty,
1317
+ scale=scale,
1318
+ name=fname,
1319
+ inference=inference,
1320
+ coef_name=coef_name,
1321
+ )
1322
+ if noncentered:
1323
+ term.reparam_noncentered()
1324
+ return term
1325
+
1326
+ def bs(
1327
+ self,
1328
+ x: str,
1329
+ *,
1330
+ k: int,
1331
+ scale: ScaleIG | lsl.Var | float | VarIGPrior = VarIGPrior(1.0, 0.005),
1332
+ inference: InferenceTypes | None = gs.MCMCSpec(gs.IWLSKernel),
1333
+ basis_degree: int = 3,
1334
+ penalty_order: int | Sequence[int] = 2,
1335
+ knots: ArrayLike | None = None,
1336
+ absorb_cons: bool = True,
1337
+ diagonal_penalty: bool = True,
1338
+ scale_penalty: bool = True,
1339
+ noncentered: bool = False,
1340
+ ) -> Term:
1341
+ basis = self.bases.bs(
1342
+ x=x,
1343
+ k=k,
1344
+ basis_degree=basis_degree,
1345
+ penalty_order=penalty_order,
1346
+ knots=knots,
1347
+ absorb_cons=absorb_cons,
1348
+ diagonal_penalty=diagonal_penalty,
1349
+ scale_penalty=scale_penalty,
1350
+ basis_name="B",
1351
+ )
1352
+
1353
+ fname = self.names.fname("bs", basis)
1354
+
1355
+ if isinstance(scale, VarIGPrior):
1356
+ scale = self._init_default_scale(
1357
+ concentration=scale.concentration, scale=scale.scale, term_name=fname
1358
+ )
1359
+
1360
+ coef_name = self.names.create_beta_name(fname)
1361
+ term = Term(
1362
+ basis=basis,
1363
+ penalty=basis.penalty,
1364
+ scale=scale,
1365
+ name=fname,
1366
+ inference=inference,
1367
+ coef_name=coef_name,
1368
+ )
1369
+ if noncentered:
1370
+ term.reparam_noncentered()
1371
+ return term
1372
+
1373
+ # P-spline
1374
+ def ps(
1375
+ self,
1376
+ x: str,
1377
+ *,
1378
+ k: int,
1379
+ scale: ScaleIG | lsl.Var | float | VarIGPrior = VarIGPrior(1.0, 0.005),
1380
+ inference: InferenceTypes | None = gs.MCMCSpec(gs.IWLSKernel),
1381
+ basis_degree: int = 3,
1382
+ penalty_order: int = 2,
1383
+ knots: ArrayLike | None = None,
1384
+ absorb_cons: bool = True,
1385
+ diagonal_penalty: bool = True,
1386
+ scale_penalty: bool = True,
1387
+ noncentered: bool = False,
1388
+ ) -> Term:
1389
+ basis = self.bases.ps(
1390
+ x=x,
1391
+ k=k,
1392
+ basis_degree=basis_degree,
1393
+ penalty_order=penalty_order,
1394
+ knots=knots,
1395
+ absorb_cons=absorb_cons,
1396
+ diagonal_penalty=diagonal_penalty,
1397
+ scale_penalty=scale_penalty,
1398
+ basis_name="B",
1399
+ )
1400
+
1401
+ fname = self.names.fname("ps", basis)
1402
+
1403
+ if isinstance(scale, VarIGPrior):
1404
+ scale = self._init_default_scale(
1405
+ concentration=scale.concentration, scale=scale.scale, term_name=fname
1406
+ )
1407
+
1408
+ coef_name = self.names.create_beta_name(fname)
1409
+ term = Term(
1410
+ basis=basis,
1411
+ penalty=basis.penalty,
1412
+ scale=scale,
1413
+ name=fname,
1414
+ inference=inference,
1415
+ coef_name=coef_name,
1416
+ )
1417
+ if noncentered:
1418
+ term.reparam_noncentered()
1419
+ return term
1420
+
1421
+ def cp(
1422
+ self,
1423
+ x: str,
1424
+ *,
1425
+ k: int,
1426
+ scale: ScaleIG | lsl.Var | float | VarIGPrior = VarIGPrior(1.0, 0.005),
1427
+ inference: InferenceTypes | None = gs.MCMCSpec(gs.IWLSKernel),
1428
+ basis_degree: int = 3,
1429
+ penalty_order: int = 2,
1430
+ knots: ArrayLike | None = None,
1431
+ absorb_cons: bool = True,
1432
+ diagonal_penalty: bool = True,
1433
+ scale_penalty: bool = True,
1434
+ noncentered: bool = False,
1435
+ ) -> Term:
1436
+ basis = self.bases.cp(
1437
+ x=x,
1438
+ k=k,
1439
+ basis_degree=basis_degree,
1440
+ penalty_order=penalty_order,
1441
+ knots=knots,
1442
+ absorb_cons=absorb_cons,
1443
+ diagonal_penalty=diagonal_penalty,
1444
+ scale_penalty=scale_penalty,
1445
+ basis_name="B",
1446
+ )
1447
+
1448
+ fname = self.names.fname("cp", basis)
1449
+
1450
+ if isinstance(scale, VarIGPrior):
1451
+ scale = self._init_default_scale(
1452
+ concentration=scale.concentration, scale=scale.scale, term_name=fname
1453
+ )
1454
+
1455
+ coef_name = self.names.create_beta_name(fname)
1456
+ term = Term(
1457
+ basis=basis,
1458
+ penalty=basis.penalty,
1459
+ scale=scale,
1460
+ name=fname,
1461
+ inference=inference,
1462
+ coef_name=coef_name,
1463
+ )
1464
+ if noncentered:
1465
+ term.reparam_noncentered()
1466
+ return term
1467
+
1468
+ # random intercept
1469
+ def ri(
1470
+ self,
1471
+ cluster: str,
1472
+ scale: ScaleIG | lsl.Var | float | VarIGPrior = VarIGPrior(1.0, 0.005),
1473
+ inference: InferenceTypes | None = gs.MCMCSpec(gs.IWLSKernel),
1474
+ penalty: ArrayLike | None = None,
1475
+ noncentered: bool = False,
1476
+ ) -> RITerm:
1477
+ basis = self.bases.ri(cluster=cluster, basis_name="B", penalty=penalty)
1478
+
1479
+ fname = self.names.fname("ri", basis)
1480
+ if isinstance(scale, VarIGPrior):
1481
+ scale = self._init_default_scale(
1482
+ concentration=scale.concentration, scale=scale.scale, term_name=fname
1483
+ )
1484
+
1485
+ coef_name = self.names.create_beta_name(fname)
1486
+
1487
+ term = RITerm(
1488
+ basis=basis,
1489
+ penalty=basis.penalty,
1490
+ coef_name=coef_name,
1491
+ inference=inference,
1492
+ scale=scale,
1493
+ name=fname,
1494
+ )
1495
+
1496
+ if noncentered:
1497
+ term.reparam_noncentered()
1498
+
1499
+ mapping = self.bases.mappings[cluster]
1500
+ term.mapping = mapping
1501
+ term.labels = list(mapping.labels_to_integers_map)
1502
+
1503
+ return term
1504
+
1505
+ # random scaling
1506
+ def rs(
1507
+ self,
1508
+ x: str | Term | LinTerm,
1509
+ cluster: str,
1510
+ scale: ScaleIG | lsl.Var | float | VarIGPrior = VarIGPrior(1.0, 0.005),
1511
+ inference: InferenceTypes | None = gs.MCMCSpec(gs.IWLSKernel),
1512
+ penalty: ArrayLike | None = None,
1513
+ noncentered: bool = False,
1514
+ ) -> lsl.Var:
1515
+ ri = self.ri(
1516
+ cluster=cluster,
1517
+ scale=scale,
1518
+ inference=inference,
1519
+ penalty=penalty,
1520
+ noncentered=noncentered,
1521
+ )
1522
+
1523
+ if isinstance(x, str):
1524
+ x_var = self.registry.get_numeric_obs(x)
1525
+ xname = x
1526
+ else:
1527
+ x_var = x
1528
+ xname = x_var.basis.x.name
1529
+
1530
+ fname = self.names.create_lazily("rs(" + xname + "|" + cluster + ")")
1531
+ term = lsl.Var.new_calc(
1532
+ lambda x, cluster: x * cluster,
1533
+ x=x_var,
1534
+ cluster=ri,
1535
+ name=fname,
1536
+ )
1537
+ return term
1538
+
1539
+ # varying coefficient
1540
+ def vc(
1541
+ self,
1542
+ x: str,
1543
+ by: Term,
1544
+ ) -> lsl.Var:
1545
+ fname = self.names.create_lazily(x + "*" + by.name)
1546
+ x_var = self.registry.get_obs(x)
1547
+
1548
+ term = lsl.Var.new_calc(
1549
+ lambda x, by: x * by,
1550
+ x=x_var,
1551
+ by=by,
1552
+ name=fname,
1553
+ )
1554
+ return term
1555
+
1556
+ # general smooth with MGCV bases
1557
+ def s(
1558
+ self,
1559
+ *x: str,
1560
+ k: int,
1561
+ bs: BasisTypes,
1562
+ scale: ScaleIG | lsl.Var | float | VarIGPrior = VarIGPrior(1.0, 0.005),
1563
+ inference: InferenceTypes | None = gs.MCMCSpec(gs.IWLSKernel),
1564
+ m: str = "NA",
1565
+ knots: ArrayLike | None = None,
1566
+ absorb_cons: bool = True,
1567
+ diagonal_penalty: bool = True,
1568
+ scale_penalty: bool = True,
1569
+ noncentered: bool = False,
1570
+ ) -> Term:
1571
+ """
1572
+ Works:
1573
+ - tp (thin plate splines)
1574
+ - ts (thin plate splines with slight null space penalty)
1575
+
1576
+ - cr (cubic regression splines)
1577
+ - cs (shrinked cubic regression splines)
1578
+ - cc (cyclic cubic regression splines)
1579
+
1580
+ - bs (B-splines)
1581
+ - ps (P-splines)
1582
+ - cp (cyclic P-splines)
1583
+
1584
+ Works, but not here:
1585
+ - re (use .ri instead)
1586
+ - mrf (used .mrf instead)
1587
+ - te (use .te instead) (with the bases above)
1588
+ - ti (use .ti instead) (with the bases above)
1589
+
1590
+ Does not work:
1591
+ - ds (Duchon splines)
1592
+ - sos (splines on the sphere)
1593
+ - gp (gaussian process)
1594
+ - so (soap film smooths)
1595
+ - ad (adaptive smooths)
1596
+
1597
+ Probably disallow manually:
1598
+ - fz (factor smooth interaction)
1599
+ - fs (random factor smooth interaction)
1600
+ """
1601
+ basis = self.bases.s(
1602
+ *x,
1603
+ k=k,
1604
+ bs=bs,
1605
+ m=m,
1606
+ knots=knots,
1607
+ absorb_cons=absorb_cons,
1608
+ diagonal_penalty=diagonal_penalty,
1609
+ scale_penalty=scale_penalty,
1610
+ basis_name="B",
1611
+ )
1612
+
1613
+ fname = self.names.fname(bs, basis=basis)
1614
+
1615
+ if isinstance(scale, VarIGPrior):
1616
+ scale = self._init_default_scale(
1617
+ concentration=scale.concentration, scale=scale.scale, term_name=fname
1618
+ )
1619
+
1620
+ coef_name = self.names.create_beta_name(fname)
1621
+ term = Term(
1622
+ basis,
1623
+ penalty=basis.penalty,
1624
+ name=fname,
1625
+ coef_name=coef_name,
1626
+ scale=scale,
1627
+ inference=inference,
1628
+ )
1629
+ if noncentered:
1630
+ term.reparam_noncentered()
1631
+ return term
1632
+
1633
+ # markov random field
1634
+ def mrf(
1635
+ self,
1636
+ x: str,
1637
+ scale: ScaleIG | lsl.Var | float | VarIGPrior = VarIGPrior(1.0, 0.005),
1638
+ inference: InferenceTypes | None = gs.MCMCSpec(gs.IWLSKernel),
1639
+ k: int = -1,
1640
+ polys: dict[str, ArrayLike] | None = None,
1641
+ nb: Mapping[str, ArrayLike | list[str] | list[int]] | None = None,
1642
+ penalty: ArrayLike | None = None,
1643
+ absorb_cons: bool = True,
1644
+ diagonal_penalty: bool = True,
1645
+ scale_penalty: bool = True,
1646
+ noncentered: bool = False,
1647
+ ) -> MRFTerm:
1648
+ """
1649
+ Polys: Dictionary of arrays. The keys of the dict are the region labels.
1650
+ The corresponding values define the region by defining polygons.
1651
+ nb: Dictionary of array. The keys of the dict are the region labels.
1652
+ The corresponding values indicate the neighbors of the region.
1653
+ If it is a list or array of strings, the values are the labels of the
1654
+ neighbors.
1655
+ If it is a list or array of integers, the values are the indices of the
1656
+ neighbors.
1657
+
1658
+
1659
+ mgcv does not concern itself with your category ordering. It *will* order
1660
+ categories alphabetically. Penalty columns have to take this into account.
1661
+ """
1662
+ basis = self.bases.mrf(
1663
+ x=x,
1664
+ k=k,
1665
+ polys=polys,
1666
+ nb=nb,
1667
+ penalty=penalty,
1668
+ absorb_cons=absorb_cons,
1669
+ diagonal_penalty=diagonal_penalty,
1670
+ scale_penalty=scale_penalty,
1671
+ basis_name="B",
1672
+ )
1673
+
1674
+ fname = self.names.fname("mrf", basis)
1675
+ if isinstance(scale, VarIGPrior):
1676
+ scale = self._init_default_scale(
1677
+ concentration=scale.concentration, scale=scale.scale, term_name=fname
1678
+ )
1679
+ coef_name = self.names.create_beta_name(fname)
1680
+ term = MRFTerm(
1681
+ basis,
1682
+ penalty=basis.penalty,
1683
+ name=fname,
1684
+ scale=scale,
1685
+ inference=inference,
1686
+ coef_name=coef_name,
1687
+ )
1688
+ if noncentered:
1689
+ term.reparam_noncentered()
1690
+
1691
+ term.polygons = polys
1692
+ term.neighbors = basis.mrf_spec.nb
1693
+ if basis.mrf_spec.ordered_labels is not None:
1694
+ term.ordered_labels = basis.mrf_spec.ordered_labels
1695
+
1696
+ term.labels = list(basis.mrf_spec.mapping.labels_to_integers_map)
1697
+ term.mapping = basis.mrf_spec.mapping
1698
+
1699
+ return term
1700
+
1701
+ # general basis function + penalty smooth
1702
+ def f(
1703
+ self,
1704
+ *x: str,
1705
+ basis_fn: Callable[[Array], Array],
1706
+ scale: ScaleIG | lsl.Var | float | VarIGPrior = VarIGPrior(1.0, 0.005),
1707
+ inference: InferenceTypes | None = gs.MCMCSpec(gs.IWLSKernel),
1708
+ use_callback: bool = True,
1709
+ cache_basis: bool = True,
1710
+ penalty: ArrayLike | None = None,
1711
+ noncentered: bool = False,
1712
+ ) -> Term:
1713
+ basis = self.bases.basis(
1714
+ *x,
1715
+ basis_fn=basis_fn,
1716
+ use_callback=use_callback,
1717
+ cache_basis=cache_basis,
1718
+ penalty=penalty,
1719
+ basis_name="B",
1720
+ )
1721
+
1722
+ fname = self.names.fname("f", basis)
1723
+ if isinstance(scale, VarIGPrior):
1724
+ scale = self._init_default_scale(
1725
+ concentration=scale.concentration, scale=scale.scale, term_name=fname
1726
+ )
1727
+
1728
+ coef_name = self.names.create_beta_name(fname)
1729
+ term = Term(
1730
+ basis,
1731
+ penalty=basis.penalty,
1732
+ name=fname,
1733
+ scale=scale,
1734
+ inference=inference,
1735
+ coef_name=coef_name,
1736
+ )
1737
+ if noncentered:
1738
+ term.reparam_noncentered()
1739
+ return term
1740
+
1741
+ def kriging(
1742
+ self,
1743
+ *x: str,
1744
+ k: int,
1745
+ scale: ScaleIG | lsl.Var | float | VarIGPrior = VarIGPrior(1.0, 0.005),
1746
+ inference: InferenceTypes | None = gs.MCMCSpec(gs.IWLSKernel),
1747
+ kernel_name: Literal[
1748
+ "spherical",
1749
+ "power_exponential",
1750
+ "matern1.5",
1751
+ "matern2.5",
1752
+ "matern3.5",
1753
+ ] = "matern1.5",
1754
+ linear_trend: bool = True,
1755
+ range: float | None = None,
1756
+ power_exponential_power: float = 1.0,
1757
+ knots: ArrayLike | None = None,
1758
+ absorb_cons: bool = True,
1759
+ diagonal_penalty: bool = True,
1760
+ scale_penalty: bool = True,
1761
+ noncentered: bool = False,
1762
+ ) -> Term:
1763
+ basis = self.bases.kriging(
1764
+ *x,
1765
+ k=k,
1766
+ kernel_name=kernel_name,
1767
+ linear_trend=linear_trend,
1768
+ range=range,
1769
+ power_exponential_power=power_exponential_power,
1770
+ knots=knots,
1771
+ absorb_cons=absorb_cons,
1772
+ diagonal_penalty=diagonal_penalty,
1773
+ scale_penalty=scale_penalty,
1774
+ basis_name="B",
1775
+ )
1776
+
1777
+ fname = self.names.fname("kriging", basis)
1778
+ if isinstance(scale, VarIGPrior):
1779
+ scale = self._init_default_scale(
1780
+ concentration=scale.concentration, scale=scale.scale, term_name=fname
1781
+ )
1782
+ coef_name = self.names.create_beta_name(fname)
1783
+ term = Term(
1784
+ basis,
1785
+ penalty=basis.penalty,
1786
+ name=fname,
1787
+ scale=scale,
1788
+ inference=inference,
1789
+ coef_name=coef_name,
1790
+ )
1791
+ if noncentered:
1792
+ term.reparam_noncentered()
1793
+ return term
1794
+
1795
+ def tp(
1796
+ self,
1797
+ *x: str,
1798
+ k: int,
1799
+ scale: ScaleIG | lsl.Var | float | VarIGPrior = VarIGPrior(1.0, 0.005),
1800
+ inference: InferenceTypes | None = gs.MCMCSpec(gs.IWLSKernel),
1801
+ penalty_order: int | None = None,
1802
+ knots: ArrayLike | None = None,
1803
+ absorb_cons: bool = True,
1804
+ diagonal_penalty: bool = True,
1805
+ scale_penalty: bool = True,
1806
+ noncentered: bool = False,
1807
+ remove_null_space_completely: bool = False,
1808
+ ) -> Term:
1809
+ basis = self.bases.tp(
1810
+ *x,
1811
+ k=k,
1812
+ penalty_order=penalty_order,
1813
+ knots=knots,
1814
+ absorb_cons=absorb_cons,
1815
+ diagonal_penalty=diagonal_penalty,
1816
+ scale_penalty=scale_penalty,
1817
+ basis_name="B",
1818
+ remove_null_space_completely=remove_null_space_completely,
1819
+ )
1820
+
1821
+ fname = self.names.fname("tp", basis)
1822
+ if isinstance(scale, VarIGPrior):
1823
+ scale = self._init_default_scale(
1824
+ concentration=scale.concentration, scale=scale.scale, term_name=fname
1825
+ )
1826
+
1827
+ coef_name = self.names.create_beta_name(fname)
1828
+ term = Term(
1829
+ basis,
1830
+ penalty=basis.penalty,
1831
+ name=fname,
1832
+ scale=scale,
1833
+ inference=inference,
1834
+ coef_name=coef_name,
1835
+ )
1836
+ if noncentered:
1837
+ term.reparam_noncentered()
1838
+ return term
1839
+
1840
+ def ts(
1841
+ self,
1842
+ *x: str,
1843
+ k: int,
1844
+ scale: ScaleIG | lsl.Var | float | VarIGPrior = VarIGPrior(1.0, 0.005),
1845
+ inference: InferenceTypes | None = gs.MCMCSpec(gs.IWLSKernel),
1846
+ penalty_order: int | None = None,
1847
+ knots: ArrayLike | None = None,
1848
+ absorb_cons: bool = True,
1849
+ diagonal_penalty: bool = True,
1850
+ scale_penalty: bool = True,
1851
+ noncentered: bool = False,
1852
+ ) -> Term:
1853
+ basis = self.bases.ts(
1854
+ *x,
1855
+ k=k,
1856
+ penalty_order=penalty_order,
1857
+ knots=knots,
1858
+ absorb_cons=absorb_cons,
1859
+ diagonal_penalty=diagonal_penalty,
1860
+ scale_penalty=scale_penalty,
1861
+ basis_name="B",
1862
+ )
1863
+
1864
+ fname = self.names.fname("ts", basis)
1865
+ if isinstance(scale, VarIGPrior):
1866
+ scale = self._init_default_scale(
1867
+ concentration=scale.concentration, scale=scale.scale, term_name=fname
1868
+ )
1869
+
1870
+ coef_name = self.names.create_beta_name(fname)
1871
+ term = Term(
1872
+ basis,
1873
+ penalty=basis.penalty,
1874
+ name=fname,
1875
+ scale=scale,
1876
+ inference=inference,
1877
+ coef_name=coef_name,
1878
+ )
1879
+ if noncentered:
1880
+ term.reparam_noncentered()
1881
+ return term
1882
+
1883
+ def ta(
1884
+ self,
1885
+ *marginals: Term,
1886
+ common_scale: ScaleIG | lsl.Var | float | VarIGPrior | None = None,
1887
+ inference: InferenceTypes | None = gs.MCMCSpec(gs.IWLSKernel),
1888
+ include_main_effects: bool = False,
1889
+ scales_inference: InferenceTypes | None = gs.MCMCSpec(gs.HMCKernel),
1890
+ _fname: str = "ta",
1891
+ ) -> TPTerm:
1892
+ """
1893
+ Will remove any default gibbs samplers and replace them with scales_inferece
1894
+ on a transformed version.
1895
+ """
1896
+ inputs = ",".join(list(TPTerm._input_obs([t.basis for t in marginals])))
1897
+ fname = self.names.create_lazily(f"{_fname}(" + inputs + ")")
1898
+ coef_name = self.names.create_beta_name(fname)
1899
+
1900
+ if isinstance(common_scale, VarIGPrior):
1901
+ common_scale = self._init_default_scale(
1902
+ concentration=common_scale.concentration,
1903
+ scale=common_scale.scale,
1904
+ term_name=fname,
1905
+ )
1906
+
1907
+ if common_scale is not None and not isinstance(common_scale, float):
1908
+ _replace_star_gibbs_with(common_scale, scales_inference)
1909
+
1910
+ term = TPTerm(
1911
+ *marginals,
1912
+ common_scale=common_scale,
1913
+ name=fname,
1914
+ inference=inference,
1915
+ coef_name=coef_name,
1916
+ include_main_effects=include_main_effects,
1917
+ )
1918
+
1919
+ for scale in term.scales:
1920
+ if not isinstance(scale, lsl.Var):
1921
+ raise TypeError(
1922
+ f"Expected scale to be a liesel.model.Var, got {type(scale)}"
1923
+ )
1924
+ _replace_star_gibbs_with(scale, scales_inference)
1925
+
1926
+ return term
1927
+
1928
    def tx(
        self,
        *marginals: Term,
        common_scale: ScaleIG | lsl.Var | float | VarIGPrior | None = None,
        inference: InferenceTypes | None = gs.MCMCSpec(gs.IWLSKernel),
        scales_inference: InferenceTypes | None = gs.MCMCSpec(gs.HMCKernel),
    ) -> TPTerm:
        """Tensor-product interaction without main effects.

        Delegates to :meth:`ta` with ``include_main_effects=False``.
        """
        return self.ta(
            *marginals,
            common_scale=common_scale,
            inference=inference,
            scales_inference=scales_inference,
            include_main_effects=False,
            _fname="tx",
        )
1943
+
1944
    def tf(
        self,
        *marginals: Term,
        common_scale: ScaleIG | lsl.Var | float | VarIGPrior | None = None,
        inference: InferenceTypes | None = gs.MCMCSpec(gs.IWLSKernel),
        scales_inference: InferenceTypes | None = gs.MCMCSpec(gs.HMCKernel),
    ) -> TPTerm:
        """Tensor-product interaction including main effects.

        Delegates to :meth:`ta` with ``include_main_effects=True``.
        """
        return self.ta(
            *marginals,
            common_scale=common_scale,
            inference=inference,
            scales_inference=scales_inference,
            include_main_effects=True,
            _fname="tf",
        )
1959
+
1960
+
1961
def _get_parameter(var: lsl.Var) -> lsl.Var:
    """Return the unique parameter variable in the graph of ``var``.

    If ``var`` itself is strong, it must be a parameter and is returned
    directly. Otherwise a temporary model is built to collect the parameters
    upstream of ``var``; exactly one must exist.

    Raises
    ------
    ValueError
        If ``var`` is strong but not a parameter, if no parameter is found,
        or if more than one parameter is found.
    """
    if var.strong:
        if var.parameter:
            return var
        raise ValueError(f"{var} is strong, but not a parameter.")

    with TemporaryModel(var, to_float32=False) as model:
        params = model.parameters
        if not params:
            raise ValueError(f"No parameter found in the graph of {var}.")
        if len(params) > 1:
            raise ValueError(
                f"In the graph of {var}, there are {len(params)} parameters, "
                "so we cannot return a unique parameter."
            )
        # Reuse the mapping we already fetched (the original re-read
        # model.parameters) and take the single value without building a list.
        param = next(iter(params.values()))

    return param
1980
+
1981
+
1982
def _replace_star_gibbs_with(var: lsl.Var, inference: InferenceTypes | None) -> lsl.Var:
    """Replace a default ``StarVarianceGibbs`` kernel with ``inference``.

    Any explicitly configured inference (non-MCMC spec, non-default kernel,
    or a kernel without a ``__name__``) is assumed to be intentional and left
    untouched. Otherwise the underlying parameter is transformed and the new
    inference is attached to the transformed version.
    """
    param = _get_parameter(var)
    spec = param.inference
    if spec is not None:
        if not isinstance(spec, gs.MCMCSpec):
            # Non-MCMC inference was set intentionally; keep it.
            return var
        try:
            kernel_name = spec.kernel.__name__  # type: ignore
        except AttributeError:
            # Kernel without a __name__: assume it was set intentionally.
            return var
        if kernel_name != "StarVarianceGibbs":
            return var
    trafo_name = "h(" + param.name + ")" if param.name else None
    param.transform(bijector=None, inference=inference, name=trafo_name)
    return var