liesel-gam 0.0.4__py3-none-any.whl → 0.0.6a4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
liesel_gam/predictor.py CHANGED
@@ -1,45 +1,84 @@
1
1
  from __future__ import annotations
2
2
 
3
- from collections.abc import Callable
3
+ from collections.abc import Callable, Sequence
4
4
  from typing import Any, Self, cast
5
5
 
6
+ import liesel.goose as gs
6
7
  import liesel.model as lsl
7
8
 
9
+ from .var import BasisDot, Term, UserVar
10
+
# Annotation alias for array values; intentionally loose (``Any``).
Array = Any

# Union of the term classes AdditivePredictor accepts. It is used directly in
# ``isinstance`` checks, so it must stay a runtime union rather than a string
# annotation.
term_types = Term | BasisDot | lsl.Var


class AdditivePredictor(UserVar):
    """Variable whose value is an inverse link applied to a sum of terms.

    The value is computed by an ``lsl.Calc`` as
    ``inv_link(intercept + sum(terms) + 0.0)``. Terms are added with ``+=``,
    :meth:`append` or :meth:`extend` and are recorded in :attr:`terms`.

    Parameters
    ----------
    name
        Name of the predictor variable.
    inv_link
        Inverse link function applied to the summed predictor. ``None`` means
        the identity function.
    intercept
        ``True`` (default) creates a new intercept parameter with initial
        value ``0.0``, no distribution, and an ``MCMCSpec(IWLSKernel)``
        inference spec. ``False`` uses a constant intercept of ``0.0``. An
        ``lsl.Var`` is used as the intercept as-is.
    intercept_name
        Format string for the name of an automatically created intercept.
        ``{subscript}`` is replaced by ``_{0,<name>}``, where ``<name>`` is
        ``name`` with all ``$`` characters removed.
    """

    def __init__(
        self,
        name: str,
        inv_link: Callable[[Array], Array] | None = None,
        intercept: bool | lsl.Var = True,
        intercept_name: str = "$\\beta{subscript}$",
    ) -> None:
        if inv_link is None:

            def inv_link(x):
                return x

        def _sum(*args, intercept, **kwargs):
            # the + 0. implicitly ensures correct dtype also for empty predictors
            return inv_link(sum(args) + sum(kwargs.values()) + 0.0 + intercept)

        intercept_: lsl.Var | float
        if isinstance(intercept, lsl.Var):
            # Honour a user-supplied intercept variable. The previous condition
            # (`intercept and not isinstance(...)`) sent this case into the
            # constant-0.0 branch and silently dropped the variable.
            intercept_ = intercept
        elif intercept:
            name_cleaned = name.replace("$", "")
            intercept_ = lsl.Var.new_param(
                name=intercept_name.format(subscript="_{0," + name_cleaned + "}"),
                value=0.0,
                distribution=None,
                inference=gs.MCMCSpec(gs.IWLSKernel),
            )
        else:
            intercept_ = 0.0

        super().__init__(lsl.Calc(_sum, intercept=intercept_), name=name)
        self.update()
        self.terms: dict[str, term_types] = {}
        """Dictionary of terms in this predictor."""

    @property
    def intercept(self) -> lsl.Var | lsl.Node:
        """The ``intercept`` input of this predictor's value node."""
        return self.value_node["intercept"]

    @intercept.setter
    def intercept(self, value: lsl.Var | lsl.Node):
        # Replaces the ``intercept`` input of the value node.
        self.value_node["intercept"] = value

    def update(self) -> Self:
        """Recompute the value; narrows the parent's return type to ``Self``."""
        return cast(Self, super().update())

    def __iadd__(self, other: term_types | Sequence[term_types]) -> Self:
        """Add a single term, or every term of an iterable, in place."""
        if isinstance(other, term_types):
            self.append(other)
        else:
            self.extend(other)
        return self

    def append(self, term: term_types) -> None:
        """Add ``term`` as an input of the sum and record it in :attr:`terms`.

        Raises
        ------
        TypeError
            If ``term`` is not one of the supported term types.
        RuntimeError
            If a term with the same name has already been added.
        """
        if not isinstance(term, term_types):
            raise TypeError(f"{term} is of unsupported type {type(term)}.")

        if term.name in self.terms:
            raise RuntimeError(f"{self} already contains a term of name {term.name}.")

        self.value_node.add_inputs(term)
        self.terms[term.name] = term
        self.update()

    def extend(self, terms: Sequence[term_types]) -> None:
        """Add several terms, validating all of them before adding any.

        Validating the whole batch first means an invalid element leaves the
        predictor unchanged, instead of partially extended.

        Raises
        ------
        TypeError
            If any element is not one of the supported term types.
        RuntimeError
            If any element's name clashes with an existing term or with an
            earlier element of ``terms``.
        """
        pending = list(terms)
        seen = set(self.terms)
        for term in pending:
            if not isinstance(term, term_types):
                raise TypeError(f"{term} is of unsupported type {type(term)}.")
            if term.name in seen:
                raise RuntimeError(
                    f"{self} already contains a term of name {term.name}."
                )
            seen.add(term.name)

        for term in pending:
            self.append(term)

    def __getitem__(self, name) -> term_types:
        """Return the term registered under ``name``."""
        return self.terms[name]