liesel-gam 0.0.4__py3-none-any.whl → 0.0.6a4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liesel_gam/__about__.py +1 -1
- liesel_gam/__init__.py +38 -1
- liesel_gam/builder/__init__.py +8 -0
- liesel_gam/builder/builder.py +2003 -0
- liesel_gam/builder/category_mapping.py +158 -0
- liesel_gam/builder/consolidate_bases.py +105 -0
- liesel_gam/builder/registry.py +561 -0
- liesel_gam/constraint.py +107 -0
- liesel_gam/dist.py +541 -1
- liesel_gam/kernel.py +18 -7
- liesel_gam/plots.py +946 -0
- liesel_gam/predictor.py +59 -20
- liesel_gam/var.py +1508 -126
- liesel_gam-0.0.6a4.dist-info/METADATA +559 -0
- liesel_gam-0.0.6a4.dist-info/RECORD +18 -0
- {liesel_gam-0.0.4.dist-info → liesel_gam-0.0.6a4.dist-info}/WHEEL +1 -1
- liesel_gam-0.0.4.dist-info/METADATA +0 -160
- liesel_gam-0.0.4.dist-info/RECORD +0 -11
- {liesel_gam-0.0.4.dist-info → liesel_gam-0.0.6a4.dist-info}/licenses/LICENSE +0 -0
liesel_gam/predictor.py
CHANGED
|
@@ -1,45 +1,84 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
from collections.abc import Callable
|
|
3
|
+
from collections.abc import Callable, Sequence
|
|
4
4
|
from typing import Any, Self, cast
|
|
5
5
|
|
|
6
|
+
import liesel.goose as gs
|
|
6
7
|
import liesel.model as lsl
|
|
7
8
|
|
|
9
|
+
from .var import BasisDot, Term, UserVar
|
|
10
|
+
|
|
8
11
|
Array = Any
|
|
9
12
|
|
|
13
|
+
term_types = Term | BasisDot | lsl.Var
|
|
14
|
+
|
|
10
15
|
|
|
11
|
-
class AdditivePredictor(
|
|
16
|
+
class AdditivePredictor(UserVar):
    """Variable whose value is ``inv_link(sum of terms + intercept)``.

    The predictor starts without terms. Terms are added with ``+=``,
    :meth:`append` or :meth:`extend`; each added term becomes an input of the
    underlying calculator node and is registered in :attr:`terms` under its
    name.

    Parameters
    ----------
    name
        Name of the predictor variable. It is also used (with ``$`` characters
        removed) to build the subscript of an automatically created intercept.
    inv_link
        Function applied to the summed terms. ``None`` means identity.
    intercept
        ``True`` creates a new intercept parameter initialised at ``0.0``,
        ``False`` uses the constant ``0.0``, and an ``lsl.Var`` is used as the
        intercept as-is.
    intercept_name
        Format string for the name of an automatically created intercept. Its
        ``{subscript}`` field is filled with ``_{0,<name>}``.
    """

    def __init__(
        self,
        name: str,
        inv_link: Callable[[Array], Array] | None = None,
        intercept: bool | lsl.Var = True,
        intercept_name: str = "$\\beta{subscript}$",
    ) -> None:
        if inv_link is None:

            def inv_link(x):
                return x

        def _sum(*args, intercept, **kwargs):
            # the + 0. implicitly ensures correct dtype also for empty predictors
            return inv_link(sum(args) + sum(kwargs.values()) + 0.0 + intercept)

        intercept_: lsl.Var | float
        if isinstance(intercept, lsl.Var):
            # A caller-supplied intercept variable must be wired into the sum.
            # Checking the type first keeps it from falling through to the
            # "no intercept" constant below.
            intercept_ = intercept
        elif intercept:
            name_cleaned = name.replace("$", "")
            intercept_ = lsl.Var.new_param(
                name=intercept_name.format(subscript="_{0," + name_cleaned + "}"),
                value=0.0,
                distribution=None,
                inference=gs.MCMCSpec(gs.IWLSKernel),
            )
        else:
            intercept_ = 0.0

        super().__init__(lsl.Calc(_sum, intercept=intercept_), name=name)
        self.update()
        self.terms: dict[str, term_types] = {}
        """Dictionary of terms in this predictor."""

    @property
    def intercept(self) -> lsl.Var | lsl.Node:
        """The ``intercept`` keyword input of the predictor's value node."""
        return self.value_node["intercept"]

    @intercept.setter
    def intercept(self, value: lsl.Var | lsl.Node):
        self.value_node["intercept"] = value

    def update(self) -> Self:
        """Update the variable via the parent class and return ``self``."""
        return cast(Self, super().update())

    def __iadd__(self, other: term_types | Sequence[term_types]) -> Self:
        """Add a single term, or every term of a sequence, via ``+=``."""
        if isinstance(other, term_types):
            self.append(other)
        else:
            self.extend(other)
        return self

    def append(self, term: term_types) -> None:
        """Add ``term`` as an input of the sum and register it by name.

        Raises
        ------
        TypeError
            If ``term`` is not one of the supported term types.
        RuntimeError
            If a term with the same name is already part of this predictor.
        """
        if not isinstance(term, term_types):
            raise TypeError(f"{term} is of unsupported type {type(term)}.")

        if term.name in self.terms:
            raise RuntimeError(f"{self} already contains a term of name {term.name}.")

        self.value_node.add_inputs(term)
        self.terms[term.name] = term
        self.update()

    def extend(self, terms: Sequence[term_types]) -> None:
        """Append each element of ``terms`` in order (see :meth:`append`)."""
        for term in terms:
            self.append(term)

    def __getitem__(self, name) -> lsl.Var:
        """Return the term registered under ``name``; raises ``KeyError`` if absent."""
        return self.terms[name]
|