liesel-gam 0.0.4__py3-none-any.whl → 0.0.6a4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,158 @@
1
+ from __future__ import annotations
2
+
3
+ from collections.abc import Sequence
4
+ from typing import Any
5
+
6
+ import numpy as np
7
+ import pandas as pd
8
+
9
# Loose alias for array-valued objects. Because it is ``Any``, anything annotated
# with it is not statically type-checked. NOTE(review): not referenced in the
# visible part of this module; presumably used in annotations elsewhere.
Array = Any
10
+
11
+
12
class CategoryError(KeyError):
    """Common base class for failed label/code lookups in a category mapping."""
14
+
15
+
16
class UnknownLabelError(CategoryError):
    """Raised when a category label has no integer code in a ``CategoryMapping``."""
18
+
19
+
20
class UnknownCodeError(CategoryError):
    """Raised when an integer category code has no label in a ``CategoryMapping``."""
22
+
23
+
24
class CategoryMapping:
    """
    Bidirectional mapping between category labels and integer codes.

    ``labels_to_integers_map`` maps every label to its integer code. The reverse
    mapping ``integers_to_labels_map`` is derived from it on construction. If two
    labels share a code, the one appearing last in the input dict wins in the
    reverse mapping.
    """

    def __init__(self, labels_to_integers_map: dict[Any, int]) -> None:
        # Legacy sentinel values, kept for backward compatibility. The lookup
        # methods detect unknown labels/codes via KeyError instead, so a genuine
        # code ``-1`` or a genuine label ``None`` no longer collides with them.
        self._code_for_unknown_label = -1
        self._label_for_unknown_code = None

        self.labels_to_integers_map = labels_to_integers_map
        self.integers_to_labels_map = {
            code: label for label, code in labels_to_integers_map.items()
        }

    @classmethod
    def from_series(cls, series: pd.Series | pd.Categorical) -> CategoryMapping:
        """
        Build a mapping from the categories found in ``series``.

        Codes are assigned as 0, 1, 2, ... in category order:

        - For a pd.Categorical, or a pd.Series with categorical dtype, the existing
          category order is kept.
        - For any other pd.Series (e.g. dtype str or object), the categories are
          sorted (alphabetically for strings).

        Raises ``TypeError`` if ``series`` is neither a pd.Series nor a
        pd.Categorical.
        """
        # Check the container type before touching ``.dtype``, so that inputs
        # without a dtype (e.g. a list) get the intended TypeError rather than
        # an AttributeError.
        if isinstance(series, pd.Categorical):
            unique_labels = np.asarray(series.categories)
        elif isinstance(series, pd.Series):
            if isinstance(series.dtype, pd.CategoricalDtype):
                unique_labels = np.asarray(series.cat.categories)
            else:
                categories = pd.Categorical(series).categories
                unique_labels = np.sort(np.asarray(categories))
        else:
            raise TypeError(
                f"series must be a pd.Series or pd.Categorical, got {type(series)}."
            )

        mapping = {val: i for i, val in enumerate(unique_labels)}
        return cls(mapping)

    def to_integers(
        self, labels_or_integers: np.typing.ArrayLike | Sequence[int] | Sequence[str]
    ) -> np.typing.NDArray[np.int_]:
        """
        Convert labels or integer codes to integer codes.

        Input with an integer dtype is validated against the known codes and
        returned as ``int``. Any other input is treated as labels and converted
        with :meth:`labels_to_integers`.

        Raises ``ValueError`` for unknown integer codes and
        :class:`UnknownLabelError` for unknown labels.
        """
        arr = np.asarray(labels_or_integers)

        # Case 1: Already an integer array
        if np.issubdtype(arr.dtype, np.integer):
            valid_integers = np.array(list(self.integers_to_labels_map))
            known = np.isin(arr, valid_integers)
            if not known.all():
                invalid = arr[~known]
                raise ValueError(
                    f"Unknown integer codes: {invalid.tolist()} "
                    f"(valid integers: {valid_integers.tolist()})"
                )
            return arr.astype(int, copy=False)

        # Case 2: Otherwise treat as labels
        return self.labels_to_integers(arr)

    def to_labels(
        self, labels_or_integers: np.typing.ArrayLike | Sequence[int] | Sequence[str]
    ) -> np.typing.NDArray[Any]:
        """
        Convert labels or integer codes to labels.

        Input with an integer dtype is converted with :meth:`integers_to_labels`
        (raises :class:`UnknownCodeError` for unknown codes). Any other input is
        validated against the known labels and returned unchanged (raises
        ``ValueError`` for unknown labels).
        """
        arr = np.asarray(labels_or_integers)

        # Case 1: It is an integer array
        if np.issubdtype(arr.dtype, np.integer):
            return self.integers_to_labels(arr)

        # Case 2: Otherwise treat as labels
        valid_labels = np.array(list(self.labels_to_integers_map))
        known = np.isin(arr, valid_labels)
        if not known.all():
            invalid = arr[~known]
            raise ValueError(
                f"Unknown labels: {invalid.tolist()} "
                f"(valid labels: {valid_labels.tolist()})"
            )
        return arr

    def labels_to_integers(
        self, labels: np.typing.ArrayLike | Sequence[str]
    ) -> np.typing.NDArray[np.int_]:
        """
        Map labels to integer codes element-wise.

        The result has the same shape as ``labels``.

        Raises :class:`UnknownLabelError` for the first label that is not in the
        mapping.
        """
        labels = np.asarray(labels)
        lookup = self.labels_to_integers_map
        codes_flat = np.empty(labels.size, dtype=np.int_)

        for i, xi in enumerate(labels.flat):
            try:
                codes_flat[i] = lookup[xi]
            except KeyError:
                raise UnknownLabelError(f"Category label {xi} is unknown.") from None

        # Method-form reshape works on every NumPy version; the ``shape=`` keyword
        # of np.reshape and the np.astype function require NumPy >= 2.1 / 2.0.
        return codes_flat.reshape(labels.shape)

    def integers_to_labels(
        self, integers: np.typing.ArrayLike | Sequence[int]
    ) -> np.typing.NDArray[Any]:
        """
        Map integer codes to labels element-wise.

        The result has the same shape as ``integers``.

        Raises :class:`UnknownCodeError` for the first code that has no label.
        """
        integers = np.asarray(integers)
        lookup = self.integers_to_labels_map
        labels_flat_list = []

        for xi in integers.flat:
            try:
                labels_flat_list.append(lookup[xi])
            except KeyError:
                raise UnknownCodeError(f"Category code {xi} is unknown.") from None

        return np.asarray(labels_flat_list).reshape(integers.shape)
143
+
144
+
145
+ def series_is_categorical(series: pd.Series | pd.Categorical) -> bool:
146
+ """
147
+ Provides a liberal interpretation of when a series is categorical. The following
148
+ are treated as categorical:
149
+
150
+ - Series with dtype str
151
+ - Series with dtype object
152
+ - Series with dtype CategoricalDtype
153
+ """
154
+ # This corresponds to how formulaic determines categorical columns.
155
+ # See formulaic.materializers.pandas.PandasMaterializer._is_categorical
156
+ is_cat1 = series.dtype in ("str", "object")
157
+ is_cat2 = isinstance(series.dtype, pd.CategoricalDtype)
158
+ return is_cat1 or is_cat2
@@ -0,0 +1,105 @@
1
+ """
2
+ Instances of :class:`.Basis` may use non-jittable basis functions.
3
+ In batched optimization, this may lead to inefficient repeated basis evaluation.
4
+ If the basis functions depend on ryp for interfacing to R, which many do, this will
5
+ not only be inefficient but fail completely, because R is not thread-safe.
6
+
7
+ To solve these issues, this module provides utility functions to create models that
8
+ can be safely and efficiently used in batched operations:
9
+
10
+ - :func:`.consolidate_bases` splits a model into two models: one in which
11
+ all bases are turned into strong, observed variables, and one for the bases.
12
+ The former can be used in batched optimization, the latter can be used to still
13
+ conveniently evaluate all relevant bases based on their original inputs.
14
+
15
+ - :func:`.evaluate_bases` takes a position of input data, evaluates the corresponding
16
+ bases in the provided model, and returns a position of the evaluated bases.
17
+ """
18
+
19
+ import liesel.model as lsl
20
+ from liesel.goose.types import Position
21
+
22
+ from ..var import Basis
23
+
24
+
25
def _remove_singleton_vars(gb: lsl.GraphBuilder) -> lsl.GraphBuilder:
    """
    Return a new GraphBuilder with all singleton variables of ``gb`` removed.

    A singleton is a variable with neither incoming nor outgoing edges in the
    variable graph of the model built from ``gb``, i.e. it does not depend on, and
    is not used by, any other variable.

    A model is built from ``gb`` and all nodes and variables are then popped out of
    it again, so ``gb`` should be considered used up afterwards.
    NOTE(review): confirm how ``GraphBuilder.build_model`` leaves ``gb``.
    """
    model = gb.build_model()

    # For a directed graph, total degree 0 already means in- and out-degree 0,
    # so a single pass over the degrees finds every isolated variable.
    singleton_vars = {var for var, degree in model.var_graph.degree() if degree == 0}

    # The popped nodes are intentionally not re-added: only the kept variables are
    # handed to the new builder. (Presumably GraphBuilder.add pulls in the nodes
    # each variable depends on; verify.)
    _, vars_ = model.pop_nodes_and_vars()

    for var in singleton_vars:
        vars_.pop(var.name, None)

    new_gb = lsl.GraphBuilder(to_float32=model._to_float32)
    new_gb.add(*vars_.values())
    return new_gb
52
+
53
+
54
def consolidate_bases(
    model: lsl.Model, copy: bool = True
) -> tuple[lsl.Model, lsl.Model]:
    """
    Turns all :class:`.Basis` variables in the provided model into strong,
    observed :class:`liesel.model.Var` variables.

    Returns a tuple ``(consolidated_model, bases_model)``:

    - ``consolidated_model`` depends only on the strong bases. Each strong basis
      holds the basis value evaluated once at consolidation time and keeps the
      original basis' name. Variables that become disconnected by this
      replacement are dropped.
    - ``bases_model`` holds the original bases and their input data. Both models
      use the same ``to_float32`` setting as the input model.

    If ``copy=False``, all data will be extracted from the original model, instead
    of creating copies. This saves memory, but renders the original model empty.
    """
    # Read the precision setting up front; with copy=False the model is emptied
    # below.
    to_float32 = model._to_float32

    if copy:
        nodes, vars_ = model.copy_nodes_and_vars()
    else:
        nodes, vars_ = model.pop_nodes_and_vars()

    gb = lsl.GraphBuilder(to_float32=to_float32)
    gb.add(*nodes.values(), *vars_.values())

    # Collect the bases before replacing any of them: replace_var may modify
    # gb.vars, and mutating a list while iterating over it can skip elements.
    weak_bases = [var for var in gb.vars if isinstance(var, Basis)]

    for weak_basis in weak_bases:
        # Evaluate the basis once and freeze the result as observed data.
        strong_basis = lsl.Var.new_obs(
            weak_basis.update().value, name=weak_basis.name
        )
        gb.replace_var(old=weak_basis, new=strong_basis)

    gb = _remove_singleton_vars(gb)

    # Use the original precision setting so bases evaluated through bases_model
    # match the precision of the consolidated model.
    bases_model = (
        lsl.GraphBuilder(to_float32=to_float32).add(*weak_bases).build_model()
    )
    consolidated_model = gb.build_model()

    return consolidated_model, bases_model
91
+
92
+
93
def evaluate_bases(newdata: Position, model: lsl.Model) -> Position:
    """
    Evaluate every :class:`.Basis` variable of ``model`` at ``newdata``.

    A model state updated with ``newdata`` is computed via ``model.update_state``.
    The values of all basis variables are then extracted from that state and
    returned as a :class:`Position`.
    """
    names = [var.name for var in model.vars.values() if isinstance(var, Basis)]
    updated_state = model.update_state(newdata)
    extracted = model.extract_position(names, updated_state)
    return Position(extracted)