liesel-gam 0.0.4__py3-none-any.whl → 0.0.6a4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liesel_gam/__about__.py +1 -1
- liesel_gam/__init__.py +38 -1
- liesel_gam/builder/__init__.py +8 -0
- liesel_gam/builder/builder.py +2003 -0
- liesel_gam/builder/category_mapping.py +158 -0
- liesel_gam/builder/consolidate_bases.py +105 -0
- liesel_gam/builder/registry.py +561 -0
- liesel_gam/constraint.py +107 -0
- liesel_gam/dist.py +541 -1
- liesel_gam/kernel.py +18 -7
- liesel_gam/plots.py +946 -0
- liesel_gam/predictor.py +59 -20
- liesel_gam/var.py +1508 -126
- liesel_gam-0.0.6a4.dist-info/METADATA +559 -0
- liesel_gam-0.0.6a4.dist-info/RECORD +18 -0
- {liesel_gam-0.0.4.dist-info → liesel_gam-0.0.6a4.dist-info}/WHEEL +1 -1
- liesel_gam-0.0.4.dist-info/METADATA +0 -160
- liesel_gam-0.0.4.dist-info/RECORD +0 -11
- {liesel_gam-0.0.4.dist-info → liesel_gam-0.0.6a4.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,561 @@
|
|
|
1
|
+
"""Variable registry for managing data variables and transformations."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import hashlib
|
|
6
|
+
import inspect
|
|
7
|
+
import logging
|
|
8
|
+
import warnings
|
|
9
|
+
from collections.abc import Callable
|
|
10
|
+
from dataclasses import dataclass
|
|
11
|
+
from typing import Any, Literal, assert_never
|
|
12
|
+
|
|
13
|
+
import jax.numpy as jnp
|
|
14
|
+
import liesel.model as lsl
|
|
15
|
+
import numpy as np
|
|
16
|
+
import pandas as pd
|
|
17
|
+
|
|
18
|
+
from .category_mapping import CategoryMapping, series_is_categorical
|
|
19
|
+
|
|
20
|
+
logger = logging.getLogger(__name__)
|
|
21
|
+
|
|
22
|
+
Array = Any
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class CannotHashValueError(Exception):
    """Raised when a value cannot be reduced to a stable hash.

    The offending value is kept on the instance (``value``) so callers
    can inspect it after catching the exception.
    """

    def __init__(self, value: Any):
        type_name = type(value).__name__
        message = f"Cannot hash value of type '{type_name}'"
        super().__init__(message)
        self.value = value
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@dataclass
class VarAndMapping:
    """Container pairing an observed variable with an optional category mapping."""

    # the liesel variable holding the (possibly integer-coded) observed values
    var: lsl.Var
    # label <-> integer-code mapping; None when the variable is not categorical
    mapping: CategoryMapping | None = None

    @property
    def is_categorical(self) -> bool:
        """True if a category mapping is attached, i.e. the variable is categorical."""
        return self.mapping is not None
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class PandasRegistry:
|
|
44
|
+
"""Registry for managing variables and their transformations.
|
|
45
|
+
|
|
46
|
+
Handles conversion from `pandas.DataFrame` to `liesel.Var` objects,
|
|
47
|
+
applies transformations, and caches results for efficiency.
|
|
48
|
+
"""
|
|
49
|
+
|
|
50
|
+
def __init__(
|
|
51
|
+
self,
|
|
52
|
+
data: pd.DataFrame,
|
|
53
|
+
na_action: Literal["error", "drop", "ignore"] = "error",
|
|
54
|
+
prefix_names_by: str = "",
|
|
55
|
+
):
|
|
56
|
+
"""Initialize the variable registry.
|
|
57
|
+
|
|
58
|
+
Args:
|
|
59
|
+
data: pandas DataFrame containing model variables
|
|
60
|
+
na_action: How to handle NaN values. Either "error", "drop", or "ignore"
|
|
61
|
+
"""
|
|
62
|
+
if na_action not in ["error", "drop", "ignore"]:
|
|
63
|
+
raise ValueError("na_action must be 'error', 'drop', or 'ignore'")
|
|
64
|
+
|
|
65
|
+
self.original_data = data.copy()
|
|
66
|
+
self.na_action = na_action
|
|
67
|
+
self.data = self._validate_data(data)
|
|
68
|
+
self._var_cache: dict[str, lsl.Var] = {}
|
|
69
|
+
self._derived_cache: dict[str, lsl.Var] = {}
|
|
70
|
+
self.prefix = prefix_names_by
|
|
71
|
+
|
|
72
|
+
def _validate_data(self, data: pd.DataFrame) -> pd.DataFrame:
|
|
73
|
+
"""Validate data and handle NaN values according to policy."""
|
|
74
|
+
if data.isna().any().any():
|
|
75
|
+
if self.na_action == "error":
|
|
76
|
+
na_cols = data.columns[data.isna().any()].tolist()
|
|
77
|
+
raise ValueError(
|
|
78
|
+
f"Data contains NaN values in columns: {na_cols}. "
|
|
79
|
+
"Use na_action='drop' to automatically remove rows with NaN values."
|
|
80
|
+
)
|
|
81
|
+
elif self.na_action == "drop":
|
|
82
|
+
clean_data = data.dropna()
|
|
83
|
+
if len(clean_data) == 0:
|
|
84
|
+
raise ValueError("No rows remaining after dropping NaN values")
|
|
85
|
+
return clean_data
|
|
86
|
+
elif self.na_action == "ignore":
|
|
87
|
+
pass
|
|
88
|
+
else:
|
|
89
|
+
assert_never()
|
|
90
|
+
|
|
91
|
+
return data.copy()
|
|
92
|
+
|
|
93
|
+
@property
|
|
94
|
+
def columns(self) -> list[str]:
|
|
95
|
+
"""Get list of available column names."""
|
|
96
|
+
return list(self.data.columns)
|
|
97
|
+
|
|
98
|
+
@property
|
|
99
|
+
def shape(self) -> tuple[int, int]:
|
|
100
|
+
"""Get shape of the data after NA handling."""
|
|
101
|
+
return self.data.shape
|
|
102
|
+
|
|
103
|
+
def _to_jax(self, values: Any, var_name: str) -> Array:
|
|
104
|
+
"""Check if values are compatible with JAX."""
|
|
105
|
+
try:
|
|
106
|
+
array = jnp.asarray(values)
|
|
107
|
+
except Exception as e:
|
|
108
|
+
raise TypeError(
|
|
109
|
+
f"Variable '{var_name}' could not convert to JAX array"
|
|
110
|
+
) from e
|
|
111
|
+
|
|
112
|
+
return array
|
|
113
|
+
|
|
114
|
+
def _get_cache_key(
|
|
115
|
+
self, name: str, transform: Callable | None, var_name: str | None
|
|
116
|
+
) -> str:
|
|
117
|
+
"""Generate cache key for variable with optional transform."""
|
|
118
|
+
if transform is None:
|
|
119
|
+
return name
|
|
120
|
+
|
|
121
|
+
transform_id = getattr(transform, "__name__", str(transform))
|
|
122
|
+
cache_name = var_name or f"{name}_{transform_id}"
|
|
123
|
+
return cache_name
|
|
124
|
+
|
|
125
|
+
def _is_closure(self, func: Callable) -> bool:
|
|
126
|
+
"""Check if function is a closure (captures variables from outer scope)."""
|
|
127
|
+
return func.__closure__ is not None
|
|
128
|
+
|
|
129
|
+
def _hash_closure_value(self, value: Any) -> str:
|
|
130
|
+
"""Create hash for closure values, specifically supporting JAX arrays."""
|
|
131
|
+
try:
|
|
132
|
+
# try direct hashing first
|
|
133
|
+
return str(hash(value))
|
|
134
|
+
except TypeError:
|
|
135
|
+
# handle unhashable types
|
|
136
|
+
if isinstance(value, jnp.ndarray):
|
|
137
|
+
# JAX arrays: hash shape, dtype, and content
|
|
138
|
+
return f"jax_array_{value.shape}_{value.dtype}_{hash(value.tobytes())}"
|
|
139
|
+
else:
|
|
140
|
+
# unsupported type - signal to skip caching
|
|
141
|
+
raise CannotHashValueError(value)
|
|
142
|
+
|
|
143
|
+
def _hash_function(self, func: Callable) -> str | None:
|
|
144
|
+
"""Create hash for function, or use object ID for methods/callable objects."""
|
|
145
|
+
if inspect.isfunction(func):
|
|
146
|
+
# Regular functions: hash source code and closures
|
|
147
|
+
source = inspect.getsource(func)
|
|
148
|
+
|
|
149
|
+
if self._is_closure(func):
|
|
150
|
+
# for mypy
|
|
151
|
+
assert func.__closure__ is not None, "Closure should have a closure"
|
|
152
|
+
# hash closure variables
|
|
153
|
+
closure_names = func.__code__.co_freevars
|
|
154
|
+
closure_values = [cell.cell_contents for cell in func.__closure__]
|
|
155
|
+
|
|
156
|
+
closure_hashes = []
|
|
157
|
+
for name, value in zip(closure_names, closure_values):
|
|
158
|
+
try:
|
|
159
|
+
value_hash = self._hash_closure_value(value)
|
|
160
|
+
closure_hashes.append(f"{name}:{value_hash}")
|
|
161
|
+
except CannotHashValueError:
|
|
162
|
+
# unsupported closure variable, skip caching
|
|
163
|
+
warnings.warn(
|
|
164
|
+
f"Function uses unsupported closure variable type "
|
|
165
|
+
f"'{type(value).__name__}'. Provide explicit cache_key "
|
|
166
|
+
f"for caching.",
|
|
167
|
+
UserWarning,
|
|
168
|
+
stacklevel=3,
|
|
169
|
+
)
|
|
170
|
+
return None
|
|
171
|
+
|
|
172
|
+
closure_signature = ",".join(sorted(closure_hashes))
|
|
173
|
+
else:
|
|
174
|
+
closure_signature = ""
|
|
175
|
+
|
|
176
|
+
# combine source and closure state
|
|
177
|
+
combined = f"{source}|{closure_signature}"
|
|
178
|
+
return hashlib.md5(combined.encode()).hexdigest()
|
|
179
|
+
|
|
180
|
+
elif inspect.ismethod(func):
|
|
181
|
+
# Bound method: use object ID + method name for consistent caching
|
|
182
|
+
obj_id = id(func.__self__)
|
|
183
|
+
method_name = func.__name__
|
|
184
|
+
return f"method_{obj_id}_{method_name}"
|
|
185
|
+
|
|
186
|
+
elif hasattr(func, "__call__"):
|
|
187
|
+
# Callable objects, lambdas, etc.: use object ID
|
|
188
|
+
return f"obj_id_{id(func)}"
|
|
189
|
+
else:
|
|
190
|
+
raise TypeError(f"Unsupported function type: {type(func)}")
|
|
191
|
+
|
|
192
|
+
def get_obs(
|
|
193
|
+
self,
|
|
194
|
+
name: str,
|
|
195
|
+
) -> lsl.Var:
|
|
196
|
+
"""Get or create a liesel Var for a data column.
|
|
197
|
+
|
|
198
|
+
Args:
|
|
199
|
+
name: Column name in the data
|
|
200
|
+
|
|
201
|
+
Returns:
|
|
202
|
+
liesel.Var object
|
|
203
|
+
"""
|
|
204
|
+
if name not in self.data.columns:
|
|
205
|
+
available = list(self.data.columns)
|
|
206
|
+
raise KeyError(
|
|
207
|
+
f"Variable '{name}' not found in data. "
|
|
208
|
+
f"Available variables: {sorted(available)}"
|
|
209
|
+
)
|
|
210
|
+
|
|
211
|
+
varname = self.prefix + name
|
|
212
|
+
|
|
213
|
+
# check if already cached
|
|
214
|
+
if name in self._var_cache:
|
|
215
|
+
var = self._var_cache[name]
|
|
216
|
+
else:
|
|
217
|
+
# get raw values
|
|
218
|
+
values = self._to_jax(self.data[name].to_numpy(), name)
|
|
219
|
+
var = lsl.Var.new_obs(values, name=varname)
|
|
220
|
+
self._var_cache[name] = var
|
|
221
|
+
|
|
222
|
+
return var
|
|
223
|
+
|
|
224
|
+
def _make_derived_var(
|
|
225
|
+
self, base_var: lsl.Var, transform: Callable, var_name: str | None
|
|
226
|
+
) -> lsl.Var:
|
|
227
|
+
"""Apply a transformation to a base variable and return a new Var."""
|
|
228
|
+
if var_name is None:
|
|
229
|
+
var_name = (
|
|
230
|
+
f"{base_var.name}_{getattr(transform, '__name__', str(transform))}"
|
|
231
|
+
)
|
|
232
|
+
|
|
233
|
+
try:
|
|
234
|
+
derived_var = lsl.Var.new_calc(transform, base_var, name=var_name)
|
|
235
|
+
except Exception as e:
|
|
236
|
+
transformation_name = getattr(transform, "__name__", str(transform))
|
|
237
|
+
raise ValueError(
|
|
238
|
+
f"Failed to apply transformation '{transformation_name}' "
|
|
239
|
+
f"to variable '{base_var.name}': {str(e)}"
|
|
240
|
+
)
|
|
241
|
+
|
|
242
|
+
return derived_var
|
|
243
|
+
|
|
244
|
+
def get_calc(
|
|
245
|
+
self,
|
|
246
|
+
name: str,
|
|
247
|
+
transform: Callable,
|
|
248
|
+
var_name: str | None = None,
|
|
249
|
+
cache_key: str | None = None,
|
|
250
|
+
) -> lsl.Var:
|
|
251
|
+
"""Get a derived version of the variable.
|
|
252
|
+
|
|
253
|
+
Derived variables are cached when possible. Creates a lsl.new_obs for the
|
|
254
|
+
base variable and a lsl.new_calc for the derived variable.
|
|
255
|
+
|
|
256
|
+
Args:
|
|
257
|
+
name: Column name in the data frame
|
|
258
|
+
transform: Callable transformation function to apply
|
|
259
|
+
var_name: Custom name for the resulting variable
|
|
260
|
+
cache_key: Explicit cache key. If provided, skips function hashing.
|
|
261
|
+
Returns:
|
|
262
|
+
liesel.Var object with transformed values
|
|
263
|
+
"""
|
|
264
|
+
|
|
265
|
+
# get base var
|
|
266
|
+
base_var = self.get_obs(name)
|
|
267
|
+
|
|
268
|
+
# generate cache key
|
|
269
|
+
if cache_key is not None:
|
|
270
|
+
# explicit cache key provided
|
|
271
|
+
full_cache_key = f"{name}_{cache_key}_{var_name or 'default'}"
|
|
272
|
+
else:
|
|
273
|
+
# try to hash the function
|
|
274
|
+
func_hash = self._hash_function(transform)
|
|
275
|
+
if func_hash is None:
|
|
276
|
+
# caching not possible, return derived var without caching
|
|
277
|
+
return self._make_derived_var(base_var, transform, var_name)
|
|
278
|
+
|
|
279
|
+
full_cache_key = f"{name}_{func_hash}_{var_name or 'default'}"
|
|
280
|
+
|
|
281
|
+
# check cache first
|
|
282
|
+
if full_cache_key in self._derived_cache:
|
|
283
|
+
return self._derived_cache[full_cache_key]
|
|
284
|
+
|
|
285
|
+
# cache miss
|
|
286
|
+
var = self._make_derived_var(base_var, transform, var_name)
|
|
287
|
+
self._derived_cache[full_cache_key] = var
|
|
288
|
+
|
|
289
|
+
return var
|
|
290
|
+
|
|
291
|
+
def get_calc_centered(self, name: str, var_name: str | None = None) -> lsl.Var:
|
|
292
|
+
"""Get a centered version of the variable: x - mean(x).
|
|
293
|
+
|
|
294
|
+
note, mean(x) is computed from the original data and cached.
|
|
295
|
+
|
|
296
|
+
Args:
|
|
297
|
+
name: Column name in the data
|
|
298
|
+
var_name: Custom name for the resulting variable
|
|
299
|
+
|
|
300
|
+
Returns:
|
|
301
|
+
liesel.Var object with centered values
|
|
302
|
+
"""
|
|
303
|
+
base_var = self.get_obs(name)
|
|
304
|
+
values = base_var.value
|
|
305
|
+
|
|
306
|
+
mean_val = float(np.mean(values))
|
|
307
|
+
|
|
308
|
+
def center_transform(x):
|
|
309
|
+
return x - mean_val
|
|
310
|
+
|
|
311
|
+
center_transform.__name__ = "centered"
|
|
312
|
+
|
|
313
|
+
return self._make_derived_var(
|
|
314
|
+
base_var, center_transform, var_name or f"{name}_centered"
|
|
315
|
+
)
|
|
316
|
+
|
|
317
|
+
def get_calc_standardized(self, name: str, var_name: str | None = None) -> lsl.Var:
|
|
318
|
+
"""Get a standardized version of the variable: (x - mean(x)) / std(x).
|
|
319
|
+
|
|
320
|
+
note, mean(x) and std(x) are computed from the original data and cached.
|
|
321
|
+
|
|
322
|
+
Args:
|
|
323
|
+
name: Column name in the data
|
|
324
|
+
var_name: Custom name for the resulting variable
|
|
325
|
+
|
|
326
|
+
Returns:
|
|
327
|
+
liesel.Var object with standardized values
|
|
328
|
+
"""
|
|
329
|
+
base_var = self.get_obs(name)
|
|
330
|
+
values = base_var.value
|
|
331
|
+
|
|
332
|
+
mean_val = float(np.mean(values))
|
|
333
|
+
std_val = float(np.std(values))
|
|
334
|
+
|
|
335
|
+
if std_val == 0:
|
|
336
|
+
raise ValueError(
|
|
337
|
+
f"Failed to apply transformation 'standardization' to variable "
|
|
338
|
+
f"'{name}': standard deviation is zero (constant variable)"
|
|
339
|
+
)
|
|
340
|
+
|
|
341
|
+
def std_transform(x):
|
|
342
|
+
return (x - mean_val) / std_val
|
|
343
|
+
|
|
344
|
+
std_transform.__name__ = "std"
|
|
345
|
+
|
|
346
|
+
return self._make_derived_var(
|
|
347
|
+
base_var, std_transform, var_name or f"{name}_std"
|
|
348
|
+
)
|
|
349
|
+
|
|
350
|
+
def get_calc_dummymatrix(
|
|
351
|
+
self, name: str, var_name_prefix: str | None = None
|
|
352
|
+
) -> lsl.Var:
|
|
353
|
+
"""Get dummy matrix for a categorical column using standard dummy coding.
|
|
354
|
+
|
|
355
|
+
Drops the column of the first category.
|
|
356
|
+
|
|
357
|
+
Args:
|
|
358
|
+
name: Column name in the data
|
|
359
|
+
var_name_prefix: Prefix for dummy variable names
|
|
360
|
+
|
|
361
|
+
Returns:
|
|
362
|
+
Dictionary mapping category names to liesel.Var objects
|
|
363
|
+
"""
|
|
364
|
+
|
|
365
|
+
base_var, mapping = self.get_categorical_obs(name)
|
|
366
|
+
base_var.name = base_var.name = f"{name}_codes"
|
|
367
|
+
|
|
368
|
+
codebook = mapping.labels_to_integers_map
|
|
369
|
+
|
|
370
|
+
if len(codebook) < 2:
|
|
371
|
+
raise ValueError(
|
|
372
|
+
f"Failed to apply transformation 'dummy encoding' to variable "
|
|
373
|
+
f"'{name}': only {len(codebook)} unique value(s) found"
|
|
374
|
+
)
|
|
375
|
+
|
|
376
|
+
# jax-compatible dummy coding transformation
|
|
377
|
+
n_categories = len(codebook)
|
|
378
|
+
|
|
379
|
+
def dummy_transform(codes):
|
|
380
|
+
# create dummy matrix with standard dummy coding (drop first category)
|
|
381
|
+
# use float32 to support NaN for unknown codes
|
|
382
|
+
dummy_matrix = jnp.zeros(
|
|
383
|
+
(codes.shape[0], n_categories - 1), dtype=jnp.float32
|
|
384
|
+
)
|
|
385
|
+
for i in range(1, n_categories): # only a few cat, so for loop is fine
|
|
386
|
+
dummy_matrix = dummy_matrix.at[:, i - 1].set(codes == i)
|
|
387
|
+
|
|
388
|
+
# set rows with unknown codes (>= n_categories or < 0) to NaN
|
|
389
|
+
unknown_mask = (codes >= n_categories) | (codes < 0)
|
|
390
|
+
dummy_matrix = jnp.where(unknown_mask[:, None], jnp.nan, dummy_matrix)
|
|
391
|
+
|
|
392
|
+
return dummy_matrix
|
|
393
|
+
|
|
394
|
+
dummy_transform.__name__ = f"{name}_dummy"
|
|
395
|
+
|
|
396
|
+
# create dummy matrix variable
|
|
397
|
+
prefix = var_name_prefix or f"{name}_"
|
|
398
|
+
dummy_matrix_name = f"{prefix}matrix"
|
|
399
|
+
dummy_matrix_var = lsl.Var.new_calc(
|
|
400
|
+
dummy_transform, base_var, name=dummy_matrix_name
|
|
401
|
+
)
|
|
402
|
+
|
|
403
|
+
return dummy_matrix_var
|
|
404
|
+
|
|
405
|
+
def is_numeric(self, name: str) -> bool:
|
|
406
|
+
"""Check if a variable is numeric.
|
|
407
|
+
|
|
408
|
+
Args:
|
|
409
|
+
name: Column name in the data
|
|
410
|
+
|
|
411
|
+
Returns:
|
|
412
|
+
True if variable is numeric, False otherwise
|
|
413
|
+
"""
|
|
414
|
+
if name not in self.data.columns:
|
|
415
|
+
available = list(self.data.columns)
|
|
416
|
+
raise KeyError(
|
|
417
|
+
f"Variable '{name}' not found in data. "
|
|
418
|
+
f"Available variables: {sorted(available)}"
|
|
419
|
+
)
|
|
420
|
+
|
|
421
|
+
return pd.api.types.is_numeric_dtype(self.data[name])
|
|
422
|
+
|
|
423
|
+
def is_categorical(self, name: str) -> bool:
|
|
424
|
+
"""Check if a variable is categorical.
|
|
425
|
+
|
|
426
|
+
Args:
|
|
427
|
+
name: Column name in the data
|
|
428
|
+
|
|
429
|
+
Returns:
|
|
430
|
+
True if variable is categorical, False otherwise
|
|
431
|
+
"""
|
|
432
|
+
if name not in self.data.columns:
|
|
433
|
+
available = list(self.data.columns)
|
|
434
|
+
raise KeyError(
|
|
435
|
+
f"Variable '{name}' not found in data. "
|
|
436
|
+
f"Available variables: {sorted(available)}"
|
|
437
|
+
)
|
|
438
|
+
|
|
439
|
+
return series_is_categorical(self.data[name])
|
|
440
|
+
|
|
441
|
+
def is_boolean(self, name: str) -> bool:
|
|
442
|
+
"""Check if a variable is boolean.
|
|
443
|
+
|
|
444
|
+
Args:
|
|
445
|
+
name: Column name in the data
|
|
446
|
+
|
|
447
|
+
Returns:
|
|
448
|
+
True if variable is boolean, False otherwise
|
|
449
|
+
"""
|
|
450
|
+
if name not in self.data.columns:
|
|
451
|
+
available = list(self.data.columns)
|
|
452
|
+
raise KeyError(
|
|
453
|
+
f"Variable '{name}' not found in data. "
|
|
454
|
+
f"Available variables: {sorted(available)}"
|
|
455
|
+
)
|
|
456
|
+
|
|
457
|
+
return pd.api.types.is_bool_dtype(self.data[name])
|
|
458
|
+
|
|
459
|
+
def get_numeric_obs(self, name: str) -> lsl.Var:
|
|
460
|
+
"""Get a variable and ensure it is numeric.
|
|
461
|
+
|
|
462
|
+
Args:
|
|
463
|
+
name: Variable name to retrieve
|
|
464
|
+
|
|
465
|
+
Returns:
|
|
466
|
+
liesel.Var object for the numeric variable
|
|
467
|
+
|
|
468
|
+
Raises:
|
|
469
|
+
TypeError: If the variable is not numeric
|
|
470
|
+
"""
|
|
471
|
+
if not self.is_numeric(name):
|
|
472
|
+
raise TypeError(
|
|
473
|
+
f"Type mismatch for variable '{name}': expected numeric, "
|
|
474
|
+
f"got {str(self.data[name].dtype)}"
|
|
475
|
+
)
|
|
476
|
+
return self.get_obs(name)
|
|
477
|
+
|
|
478
|
+
def get_categorical_obs(self, name: str) -> tuple[lsl.Var, CategoryMapping]:
|
|
479
|
+
"""Get a variable and ensure it is categorical.
|
|
480
|
+
|
|
481
|
+
Each variable is converted to integer codes.
|
|
482
|
+
|
|
483
|
+
Args:
|
|
484
|
+
name: Variable name to retrieve
|
|
485
|
+
|
|
486
|
+
Returns:
|
|
487
|
+
liesel.Var object for the categorical variable and a CategoryMapping.
|
|
488
|
+
|
|
489
|
+
Raises:
|
|
490
|
+
TypeError: If any variable is not categorical
|
|
491
|
+
"""
|
|
492
|
+
series = self.data[name]
|
|
493
|
+
if not self.is_categorical(name):
|
|
494
|
+
raise TypeError(
|
|
495
|
+
f"Type mismatch for variable '{name}': expected categorical, "
|
|
496
|
+
f"got {str(series.dtype)}"
|
|
497
|
+
)
|
|
498
|
+
|
|
499
|
+
mapping = CategoryMapping.from_series(series)
|
|
500
|
+
if name in self._var_cache:
|
|
501
|
+
var = self._var_cache[name]
|
|
502
|
+
else:
|
|
503
|
+
# convert categorical variables to integer codes
|
|
504
|
+
category_codes = mapping.labels_to_integers(series)
|
|
505
|
+
jax_codes = self._to_jax(category_codes, name)
|
|
506
|
+
varname = self.prefix + name
|
|
507
|
+
var = lsl.Var.new_obs(jax_codes, name=varname)
|
|
508
|
+
self._var_cache[name] = var
|
|
509
|
+
|
|
510
|
+
# now some exception handling
|
|
511
|
+
# only emitted once
|
|
512
|
+
nparams = len(mapping.labels_to_integers_map)
|
|
513
|
+
n_observed_clusters = jnp.unique(var.value).size
|
|
514
|
+
observed_clusters = np.unique(var.value).tolist()
|
|
515
|
+
clusters = list(mapping.integers_to_labels_map)
|
|
516
|
+
clusters_not_in_data = [c for c in clusters if c not in observed_clusters]
|
|
517
|
+
|
|
518
|
+
if n_observed_clusters != nparams:
|
|
519
|
+
logger.info(
|
|
520
|
+
f"For {name}, there are {nparams} categories, but the "
|
|
521
|
+
f"data contain observations for only {n_observed_clusters}. The "
|
|
522
|
+
f"categories without observations are: {clusters_not_in_data}. "
|
|
523
|
+
"If this is intended, you can ignore this warning. "
|
|
524
|
+
"Be aware, that parameters for the unobserved categories may be "
|
|
525
|
+
"included in the model."
|
|
526
|
+
)
|
|
527
|
+
|
|
528
|
+
return var, mapping
|
|
529
|
+
|
|
530
|
+
def get_boolean_obs(self, name: str) -> lsl.Var:
|
|
531
|
+
"""Get a variable and ensure it is boolean.
|
|
532
|
+
|
|
533
|
+
Args:
|
|
534
|
+
name: Variable name to retrieve
|
|
535
|
+
|
|
536
|
+
Returns:
|
|
537
|
+
liesel.Var object for the boolean variable
|
|
538
|
+
|
|
539
|
+
Raises:
|
|
540
|
+
TypeError: If the variable is not boolean
|
|
541
|
+
"""
|
|
542
|
+
if not self.is_boolean(name):
|
|
543
|
+
raise TypeError(
|
|
544
|
+
f"Type mismatch for variable '{name}': expected boolean, "
|
|
545
|
+
f"got {str(self.data[name].dtype)}"
|
|
546
|
+
)
|
|
547
|
+
return self.get_obs(name)
|
|
548
|
+
|
|
549
|
+
def get_obs_and_mapping(self, name: str) -> VarAndMapping:
|
|
550
|
+
"""
|
|
551
|
+
Get an observed variable. Returns a wrapper that holds the variable and,
|
|
552
|
+
if the variable is categorical, the :class:`.CategoryMapping` between
|
|
553
|
+
labels and integer codes.
|
|
554
|
+
"""
|
|
555
|
+
if self.is_categorical(name):
|
|
556
|
+
var, mapping = self.get_categorical_obs(name)
|
|
557
|
+
else:
|
|
558
|
+
var = self.get_obs(name)
|
|
559
|
+
mapping = None
|
|
560
|
+
|
|
561
|
+
return VarAndMapping(var, mapping)
|
liesel_gam/constraint.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
1
|
+
import jax.numpy as jnp
|
|
2
|
+
from jax import Array
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def penalty_to_unit_design(penalty: Array, rank: Array | int | None = None) -> Array:
|
|
6
|
+
"""
|
|
7
|
+
Convert a (semi-)definite penalty matrix into the design matrix
|
|
8
|
+
projector used by mixed-model reparameterizations.
|
|
9
|
+
|
|
10
|
+
The routine performs an eigenvalue decomposition of `penalty`, keeps the
|
|
11
|
+
first `rank` eigenvectors (default: numerical rank of `penalty`), rescales
|
|
12
|
+
them to have unit marginal variance (1 / sqrt(lambda)), and returns the
|
|
13
|
+
resulting loading matrix.
|
|
14
|
+
|
|
15
|
+
Parameters
|
|
16
|
+
----------
|
|
17
|
+
penalty
|
|
18
|
+
Positive semi-definite penalty matrix.
|
|
19
|
+
rank
|
|
20
|
+
Optional target rank. Defaults to the matrix rank inferred from
|
|
21
|
+
``penalty``.
|
|
22
|
+
|
|
23
|
+
Returns
|
|
24
|
+
-------
|
|
25
|
+
A matrix whose columns span the penalized subspace and are scaled for
|
|
26
|
+
mixed-model formulations.
|
|
27
|
+
"""
|
|
28
|
+
if rank is None:
|
|
29
|
+
rank = jnp.linalg.matrix_rank(penalty)
|
|
30
|
+
|
|
31
|
+
evalues, evectors = jnp.linalg.eigh(penalty)
|
|
32
|
+
evalues = evalues[::-1] # put in decreasing order
|
|
33
|
+
evectors = evectors[:, ::-1] # make order correspond to eigenvalues
|
|
34
|
+
rank = jnp.linalg.matrix_rank(penalty)
|
|
35
|
+
|
|
36
|
+
if evectors[0, 0] < 0:
|
|
37
|
+
evectors = -evectors
|
|
38
|
+
|
|
39
|
+
U = evectors
|
|
40
|
+
D = 1 / jnp.sqrt(jnp.ones_like(evalues).at[:rank].set(evalues[:rank]))
|
|
41
|
+
Z = (U.T * jnp.expand_dims(D, 1)).T
|
|
42
|
+
return Z
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class LinearConstraintEVD:
    """Build reparameterization matrices that enforce linear constraints.

    Each method returns a matrix ``Cbar`` such that coefficients expressed
    as ``coef = Cbar @ theta`` satisfy the corresponding constraint by
    construction.
    """

    @staticmethod
    def general(constraint: Array) -> Array:
        """Return a reparameterization matrix for the constraint ``A @ coef = 0``.

        Args:
            constraint: Constraint matrix ``A`` of shape (nconstraints, ncoef).

        Returns:
            Matrix ``Cbar`` of shape (ncoef, ncoef - nconstraints), the
            columns of the inverse of the completed constraint matrix that
            correspond to the unconstrained directions.
        """
        A = constraint
        nconstraints, _ = A.shape

        AtA = A.T @ A
        evals, evecs = jnp.linalg.eigh(AtA)

        # fix the overall sign of the decomposition for reproducibility
        if evecs[0, 0] < 0:
            evecs = -evecs

        rank = jnp.linalg.matrix_rank(AtA)
        # NOTE(review): this selects the first (ncoef - rank) ROWS of evecs;
        # eigenvectors are the COLUMNS of evecs (jnp.linalg.eigh orders
        # eigenvalues ascending), so a null-space basis would usually be
        # evecs[:, :-rank].T — confirm intent before changing.
        Abar = evecs[:-rank]

        # complete A to a square matrix, invert, and keep the columns
        # corresponding to the free (unconstrained) directions
        A_stacked = jnp.r_[A, Abar]
        C_stacked = jnp.linalg.inv(A_stacked)
        Cbar = C_stacked[:, nconstraints:]
        return Cbar

    @classmethod
    def _nullspace(cls, penalty: Array, rank: float | Array | None = None) -> Array:
        # NOTE(review): the `rank` argument is unconditionally overwritten
        # below (rank = jnp.sum(evals > 1e-6)), so callers cannot influence
        # it — confirm whether the parameter should take effect.
        if rank is None:
            rank = jnp.linalg.matrix_rank(penalty)
        evals, evecs = jnp.linalg.eigh(penalty)
        evals = evals[::-1]  # put in decreasing order
        evecs = evecs[:, ::-1]  # make order correspond to eigenvalues
        rank = jnp.sum(evals > 1e-6)

        # fix the overall sign of the decomposition for reproducibility
        if evecs[0, 0] < 0:
            evecs = -evecs

        U = evecs
        # rescale the leading `rank` eigen-directions by 1/sqrt(eigenvalue);
        # remaining directions keep scale 1
        D = 1 / jnp.sqrt(jnp.ones_like(evals).at[:rank].set(evals[:rank]))
        Z = (U.T * jnp.expand_dims(D, 1)).T
        # NOTE(review): despite the method name, this keeps the first `rank`
        # (penalized, non-null) columns rather than the null space — confirm
        # intent.
        Abar = Z[:, :rank]

        return Abar

    @classmethod
    def constant_and_linear(cls, x: Array, basis: Array) -> Array:
        """Constrain a basis term to be orthogonal to constant and linear effects of ``x``."""
        nobs = jnp.shape(x)[0]
        j = jnp.ones(shape=nobs)
        X = jnp.c_[j, x]
        # least-squares projection coefficients of the basis onto [1, x]
        A = jnp.linalg.inv(X.T @ X) @ X.T @ basis
        return cls.general(constraint=A)

    @classmethod
    def sumzero_coef(cls, ncoef: int) -> Array:
        """Constrain the coefficients themselves to sum to zero."""
        j = jnp.ones(shape=(1, ncoef))
        return cls.general(constraint=j)

    @classmethod
    def sumzero_term(cls, basis: Array) -> Array:
        """Constrain the evaluated term ``basis @ coef`` to sum to zero over observations."""
        nobs = jnp.shape(basis)[0]
        j = jnp.ones(shape=nobs)
        # column sums of the basis form the single constraint row
        A = jnp.expand_dims(j @ basis, 0)
        return cls.general(constraint=A)

    @classmethod
    def sumzero_term2(cls, basis: Array) -> Array:
        """Like :meth:`sumzero_term`, but uses column means as the constraint row."""
        A = jnp.mean(basis, axis=0, keepdims=True)
        return cls.general(constraint=A)
|