liesel-gam 0.0.4__py3-none-any.whl → 0.0.6a4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,561 @@
1
+ """Variable registry for managing data variables and transformations."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import hashlib
6
+ import inspect
7
+ import logging
8
+ import warnings
9
+ from collections.abc import Callable
10
+ from dataclasses import dataclass
11
+ from typing import Any, Literal, assert_never
12
+
13
+ import jax.numpy as jnp
14
+ import liesel.model as lsl
15
+ import numpy as np
16
+ import pandas as pd
17
+
18
+ from .category_mapping import CategoryMapping, series_is_categorical
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+ Array = Any
23
+
24
+
25
class CannotHashValueError(Exception):
    """Raised when a value captured by a closure cannot be hashed.

    Used by the function-hashing machinery to signal that caching must be
    skipped for a transform. The offending value is kept on ``self.value``.
    """

    def __init__(self, value: Any):
        type_name = type(value).__name__
        message = f"Cannot hash value of type '{type_name}'"
        super().__init__(message)
        self.value = value
31
+
32
+
33
@dataclass
class VarAndMapping:
    """Bundle of a liesel variable and an optional category mapping.

    Attributes:
        var: The observed liesel variable.
        mapping: Mapping between category labels and integer codes; ``None``
            for non-categorical variables.
    """

    var: lsl.Var
    mapping: CategoryMapping | None = None

    @property
    def is_categorical(self) -> bool:
        """Whether a category mapping is attached (i.e. the var is categorical)."""
        return self.mapping is not None
41
+
42
+
43
+ class PandasRegistry:
44
+ """Registry for managing variables and their transformations.
45
+
46
+ Handles conversion from `pandas.DataFrame` to `liesel.Var` objects,
47
+ applies transformations, and caches results for efficiency.
48
+ """
49
+
50
+ def __init__(
51
+ self,
52
+ data: pd.DataFrame,
53
+ na_action: Literal["error", "drop", "ignore"] = "error",
54
+ prefix_names_by: str = "",
55
+ ):
56
+ """Initialize the variable registry.
57
+
58
+ Args:
59
+ data: pandas DataFrame containing model variables
60
+ na_action: How to handle NaN values. Either "error", "drop", or "ignore"
61
+ """
62
+ if na_action not in ["error", "drop", "ignore"]:
63
+ raise ValueError("na_action must be 'error', 'drop', or 'ignore'")
64
+
65
+ self.original_data = data.copy()
66
+ self.na_action = na_action
67
+ self.data = self._validate_data(data)
68
+ self._var_cache: dict[str, lsl.Var] = {}
69
+ self._derived_cache: dict[str, lsl.Var] = {}
70
+ self.prefix = prefix_names_by
71
+
72
+ def _validate_data(self, data: pd.DataFrame) -> pd.DataFrame:
73
+ """Validate data and handle NaN values according to policy."""
74
+ if data.isna().any().any():
75
+ if self.na_action == "error":
76
+ na_cols = data.columns[data.isna().any()].tolist()
77
+ raise ValueError(
78
+ f"Data contains NaN values in columns: {na_cols}. "
79
+ "Use na_action='drop' to automatically remove rows with NaN values."
80
+ )
81
+ elif self.na_action == "drop":
82
+ clean_data = data.dropna()
83
+ if len(clean_data) == 0:
84
+ raise ValueError("No rows remaining after dropping NaN values")
85
+ return clean_data
86
+ elif self.na_action == "ignore":
87
+ pass
88
+ else:
89
+ assert_never()
90
+
91
+ return data.copy()
92
+
93
+ @property
94
+ def columns(self) -> list[str]:
95
+ """Get list of available column names."""
96
+ return list(self.data.columns)
97
+
98
+ @property
99
+ def shape(self) -> tuple[int, int]:
100
+ """Get shape of the data after NA handling."""
101
+ return self.data.shape
102
+
103
+ def _to_jax(self, values: Any, var_name: str) -> Array:
104
+ """Check if values are compatible with JAX."""
105
+ try:
106
+ array = jnp.asarray(values)
107
+ except Exception as e:
108
+ raise TypeError(
109
+ f"Variable '{var_name}' could not convert to JAX array"
110
+ ) from e
111
+
112
+ return array
113
+
114
+ def _get_cache_key(
115
+ self, name: str, transform: Callable | None, var_name: str | None
116
+ ) -> str:
117
+ """Generate cache key for variable with optional transform."""
118
+ if transform is None:
119
+ return name
120
+
121
+ transform_id = getattr(transform, "__name__", str(transform))
122
+ cache_name = var_name or f"{name}_{transform_id}"
123
+ return cache_name
124
+
125
+ def _is_closure(self, func: Callable) -> bool:
126
+ """Check if function is a closure (captures variables from outer scope)."""
127
+ return func.__closure__ is not None
128
+
129
+ def _hash_closure_value(self, value: Any) -> str:
130
+ """Create hash for closure values, specifically supporting JAX arrays."""
131
+ try:
132
+ # try direct hashing first
133
+ return str(hash(value))
134
+ except TypeError:
135
+ # handle unhashable types
136
+ if isinstance(value, jnp.ndarray):
137
+ # JAX arrays: hash shape, dtype, and content
138
+ return f"jax_array_{value.shape}_{value.dtype}_{hash(value.tobytes())}"
139
+ else:
140
+ # unsupported type - signal to skip caching
141
+ raise CannotHashValueError(value)
142
+
143
+ def _hash_function(self, func: Callable) -> str | None:
144
+ """Create hash for function, or use object ID for methods/callable objects."""
145
+ if inspect.isfunction(func):
146
+ # Regular functions: hash source code and closures
147
+ source = inspect.getsource(func)
148
+
149
+ if self._is_closure(func):
150
+ # for mypy
151
+ assert func.__closure__ is not None, "Closure should have a closure"
152
+ # hash closure variables
153
+ closure_names = func.__code__.co_freevars
154
+ closure_values = [cell.cell_contents for cell in func.__closure__]
155
+
156
+ closure_hashes = []
157
+ for name, value in zip(closure_names, closure_values):
158
+ try:
159
+ value_hash = self._hash_closure_value(value)
160
+ closure_hashes.append(f"{name}:{value_hash}")
161
+ except CannotHashValueError:
162
+ # unsupported closure variable, skip caching
163
+ warnings.warn(
164
+ f"Function uses unsupported closure variable type "
165
+ f"'{type(value).__name__}'. Provide explicit cache_key "
166
+ f"for caching.",
167
+ UserWarning,
168
+ stacklevel=3,
169
+ )
170
+ return None
171
+
172
+ closure_signature = ",".join(sorted(closure_hashes))
173
+ else:
174
+ closure_signature = ""
175
+
176
+ # combine source and closure state
177
+ combined = f"{source}|{closure_signature}"
178
+ return hashlib.md5(combined.encode()).hexdigest()
179
+
180
+ elif inspect.ismethod(func):
181
+ # Bound method: use object ID + method name for consistent caching
182
+ obj_id = id(func.__self__)
183
+ method_name = func.__name__
184
+ return f"method_{obj_id}_{method_name}"
185
+
186
+ elif hasattr(func, "__call__"):
187
+ # Callable objects, lambdas, etc.: use object ID
188
+ return f"obj_id_{id(func)}"
189
+ else:
190
+ raise TypeError(f"Unsupported function type: {type(func)}")
191
+
192
+ def get_obs(
193
+ self,
194
+ name: str,
195
+ ) -> lsl.Var:
196
+ """Get or create a liesel Var for a data column.
197
+
198
+ Args:
199
+ name: Column name in the data
200
+
201
+ Returns:
202
+ liesel.Var object
203
+ """
204
+ if name not in self.data.columns:
205
+ available = list(self.data.columns)
206
+ raise KeyError(
207
+ f"Variable '{name}' not found in data. "
208
+ f"Available variables: {sorted(available)}"
209
+ )
210
+
211
+ varname = self.prefix + name
212
+
213
+ # check if already cached
214
+ if name in self._var_cache:
215
+ var = self._var_cache[name]
216
+ else:
217
+ # get raw values
218
+ values = self._to_jax(self.data[name].to_numpy(), name)
219
+ var = lsl.Var.new_obs(values, name=varname)
220
+ self._var_cache[name] = var
221
+
222
+ return var
223
+
224
+ def _make_derived_var(
225
+ self, base_var: lsl.Var, transform: Callable, var_name: str | None
226
+ ) -> lsl.Var:
227
+ """Apply a transformation to a base variable and return a new Var."""
228
+ if var_name is None:
229
+ var_name = (
230
+ f"{base_var.name}_{getattr(transform, '__name__', str(transform))}"
231
+ )
232
+
233
+ try:
234
+ derived_var = lsl.Var.new_calc(transform, base_var, name=var_name)
235
+ except Exception as e:
236
+ transformation_name = getattr(transform, "__name__", str(transform))
237
+ raise ValueError(
238
+ f"Failed to apply transformation '{transformation_name}' "
239
+ f"to variable '{base_var.name}': {str(e)}"
240
+ )
241
+
242
+ return derived_var
243
+
244
+ def get_calc(
245
+ self,
246
+ name: str,
247
+ transform: Callable,
248
+ var_name: str | None = None,
249
+ cache_key: str | None = None,
250
+ ) -> lsl.Var:
251
+ """Get a derived version of the variable.
252
+
253
+ Derived variables are cached when possible. Creates a lsl.new_obs for the
254
+ base variable and a lsl.new_calc for the derived variable.
255
+
256
+ Args:
257
+ name: Column name in the data frame
258
+ transform: Callable transformation function to apply
259
+ var_name: Custom name for the resulting variable
260
+ cache_key: Explicit cache key. If provided, skips function hashing.
261
+ Returns:
262
+ liesel.Var object with transformed values
263
+ """
264
+
265
+ # get base var
266
+ base_var = self.get_obs(name)
267
+
268
+ # generate cache key
269
+ if cache_key is not None:
270
+ # explicit cache key provided
271
+ full_cache_key = f"{name}_{cache_key}_{var_name or 'default'}"
272
+ else:
273
+ # try to hash the function
274
+ func_hash = self._hash_function(transform)
275
+ if func_hash is None:
276
+ # caching not possible, return derived var without caching
277
+ return self._make_derived_var(base_var, transform, var_name)
278
+
279
+ full_cache_key = f"{name}_{func_hash}_{var_name or 'default'}"
280
+
281
+ # check cache first
282
+ if full_cache_key in self._derived_cache:
283
+ return self._derived_cache[full_cache_key]
284
+
285
+ # cache miss
286
+ var = self._make_derived_var(base_var, transform, var_name)
287
+ self._derived_cache[full_cache_key] = var
288
+
289
+ return var
290
+
291
+ def get_calc_centered(self, name: str, var_name: str | None = None) -> lsl.Var:
292
+ """Get a centered version of the variable: x - mean(x).
293
+
294
+ note, mean(x) is computed from the original data and cached.
295
+
296
+ Args:
297
+ name: Column name in the data
298
+ var_name: Custom name for the resulting variable
299
+
300
+ Returns:
301
+ liesel.Var object with centered values
302
+ """
303
+ base_var = self.get_obs(name)
304
+ values = base_var.value
305
+
306
+ mean_val = float(np.mean(values))
307
+
308
+ def center_transform(x):
309
+ return x - mean_val
310
+
311
+ center_transform.__name__ = "centered"
312
+
313
+ return self._make_derived_var(
314
+ base_var, center_transform, var_name or f"{name}_centered"
315
+ )
316
+
317
+ def get_calc_standardized(self, name: str, var_name: str | None = None) -> lsl.Var:
318
+ """Get a standardized version of the variable: (x - mean(x)) / std(x).
319
+
320
+ note, mean(x) and std(x) are computed from the original data and cached.
321
+
322
+ Args:
323
+ name: Column name in the data
324
+ var_name: Custom name for the resulting variable
325
+
326
+ Returns:
327
+ liesel.Var object with standardized values
328
+ """
329
+ base_var = self.get_obs(name)
330
+ values = base_var.value
331
+
332
+ mean_val = float(np.mean(values))
333
+ std_val = float(np.std(values))
334
+
335
+ if std_val == 0:
336
+ raise ValueError(
337
+ f"Failed to apply transformation 'standardization' to variable "
338
+ f"'{name}': standard deviation is zero (constant variable)"
339
+ )
340
+
341
+ def std_transform(x):
342
+ return (x - mean_val) / std_val
343
+
344
+ std_transform.__name__ = "std"
345
+
346
+ return self._make_derived_var(
347
+ base_var, std_transform, var_name or f"{name}_std"
348
+ )
349
+
350
+ def get_calc_dummymatrix(
351
+ self, name: str, var_name_prefix: str | None = None
352
+ ) -> lsl.Var:
353
+ """Get dummy matrix for a categorical column using standard dummy coding.
354
+
355
+ Drops the column of the first category.
356
+
357
+ Args:
358
+ name: Column name in the data
359
+ var_name_prefix: Prefix for dummy variable names
360
+
361
+ Returns:
362
+ Dictionary mapping category names to liesel.Var objects
363
+ """
364
+
365
+ base_var, mapping = self.get_categorical_obs(name)
366
+ base_var.name = base_var.name = f"{name}_codes"
367
+
368
+ codebook = mapping.labels_to_integers_map
369
+
370
+ if len(codebook) < 2:
371
+ raise ValueError(
372
+ f"Failed to apply transformation 'dummy encoding' to variable "
373
+ f"'{name}': only {len(codebook)} unique value(s) found"
374
+ )
375
+
376
+ # jax-compatible dummy coding transformation
377
+ n_categories = len(codebook)
378
+
379
+ def dummy_transform(codes):
380
+ # create dummy matrix with standard dummy coding (drop first category)
381
+ # use float32 to support NaN for unknown codes
382
+ dummy_matrix = jnp.zeros(
383
+ (codes.shape[0], n_categories - 1), dtype=jnp.float32
384
+ )
385
+ for i in range(1, n_categories): # only a few cat, so for loop is fine
386
+ dummy_matrix = dummy_matrix.at[:, i - 1].set(codes == i)
387
+
388
+ # set rows with unknown codes (>= n_categories or < 0) to NaN
389
+ unknown_mask = (codes >= n_categories) | (codes < 0)
390
+ dummy_matrix = jnp.where(unknown_mask[:, None], jnp.nan, dummy_matrix)
391
+
392
+ return dummy_matrix
393
+
394
+ dummy_transform.__name__ = f"{name}_dummy"
395
+
396
+ # create dummy matrix variable
397
+ prefix = var_name_prefix or f"{name}_"
398
+ dummy_matrix_name = f"{prefix}matrix"
399
+ dummy_matrix_var = lsl.Var.new_calc(
400
+ dummy_transform, base_var, name=dummy_matrix_name
401
+ )
402
+
403
+ return dummy_matrix_var
404
+
405
+ def is_numeric(self, name: str) -> bool:
406
+ """Check if a variable is numeric.
407
+
408
+ Args:
409
+ name: Column name in the data
410
+
411
+ Returns:
412
+ True if variable is numeric, False otherwise
413
+ """
414
+ if name not in self.data.columns:
415
+ available = list(self.data.columns)
416
+ raise KeyError(
417
+ f"Variable '{name}' not found in data. "
418
+ f"Available variables: {sorted(available)}"
419
+ )
420
+
421
+ return pd.api.types.is_numeric_dtype(self.data[name])
422
+
423
+ def is_categorical(self, name: str) -> bool:
424
+ """Check if a variable is categorical.
425
+
426
+ Args:
427
+ name: Column name in the data
428
+
429
+ Returns:
430
+ True if variable is categorical, False otherwise
431
+ """
432
+ if name not in self.data.columns:
433
+ available = list(self.data.columns)
434
+ raise KeyError(
435
+ f"Variable '{name}' not found in data. "
436
+ f"Available variables: {sorted(available)}"
437
+ )
438
+
439
+ return series_is_categorical(self.data[name])
440
+
441
+ def is_boolean(self, name: str) -> bool:
442
+ """Check if a variable is boolean.
443
+
444
+ Args:
445
+ name: Column name in the data
446
+
447
+ Returns:
448
+ True if variable is boolean, False otherwise
449
+ """
450
+ if name not in self.data.columns:
451
+ available = list(self.data.columns)
452
+ raise KeyError(
453
+ f"Variable '{name}' not found in data. "
454
+ f"Available variables: {sorted(available)}"
455
+ )
456
+
457
+ return pd.api.types.is_bool_dtype(self.data[name])
458
+
459
+ def get_numeric_obs(self, name: str) -> lsl.Var:
460
+ """Get a variable and ensure it is numeric.
461
+
462
+ Args:
463
+ name: Variable name to retrieve
464
+
465
+ Returns:
466
+ liesel.Var object for the numeric variable
467
+
468
+ Raises:
469
+ TypeError: If the variable is not numeric
470
+ """
471
+ if not self.is_numeric(name):
472
+ raise TypeError(
473
+ f"Type mismatch for variable '{name}': expected numeric, "
474
+ f"got {str(self.data[name].dtype)}"
475
+ )
476
+ return self.get_obs(name)
477
+
478
+ def get_categorical_obs(self, name: str) -> tuple[lsl.Var, CategoryMapping]:
479
+ """Get a variable and ensure it is categorical.
480
+
481
+ Each variable is converted to integer codes.
482
+
483
+ Args:
484
+ name: Variable name to retrieve
485
+
486
+ Returns:
487
+ liesel.Var object for the categorical variable and a CategoryMapping.
488
+
489
+ Raises:
490
+ TypeError: If any variable is not categorical
491
+ """
492
+ series = self.data[name]
493
+ if not self.is_categorical(name):
494
+ raise TypeError(
495
+ f"Type mismatch for variable '{name}': expected categorical, "
496
+ f"got {str(series.dtype)}"
497
+ )
498
+
499
+ mapping = CategoryMapping.from_series(series)
500
+ if name in self._var_cache:
501
+ var = self._var_cache[name]
502
+ else:
503
+ # convert categorical variables to integer codes
504
+ category_codes = mapping.labels_to_integers(series)
505
+ jax_codes = self._to_jax(category_codes, name)
506
+ varname = self.prefix + name
507
+ var = lsl.Var.new_obs(jax_codes, name=varname)
508
+ self._var_cache[name] = var
509
+
510
+ # now some exception handling
511
+ # only emitted once
512
+ nparams = len(mapping.labels_to_integers_map)
513
+ n_observed_clusters = jnp.unique(var.value).size
514
+ observed_clusters = np.unique(var.value).tolist()
515
+ clusters = list(mapping.integers_to_labels_map)
516
+ clusters_not_in_data = [c for c in clusters if c not in observed_clusters]
517
+
518
+ if n_observed_clusters != nparams:
519
+ logger.info(
520
+ f"For {name}, there are {nparams} categories, but the "
521
+ f"data contain observations for only {n_observed_clusters}. The "
522
+ f"categories without observations are: {clusters_not_in_data}. "
523
+ "If this is intended, you can ignore this warning. "
524
+ "Be aware, that parameters for the unobserved categories may be "
525
+ "included in the model."
526
+ )
527
+
528
+ return var, mapping
529
+
530
+ def get_boolean_obs(self, name: str) -> lsl.Var:
531
+ """Get a variable and ensure it is boolean.
532
+
533
+ Args:
534
+ name: Variable name to retrieve
535
+
536
+ Returns:
537
+ liesel.Var object for the boolean variable
538
+
539
+ Raises:
540
+ TypeError: If the variable is not boolean
541
+ """
542
+ if not self.is_boolean(name):
543
+ raise TypeError(
544
+ f"Type mismatch for variable '{name}': expected boolean, "
545
+ f"got {str(self.data[name].dtype)}"
546
+ )
547
+ return self.get_obs(name)
548
+
549
+ def get_obs_and_mapping(self, name: str) -> VarAndMapping:
550
+ """
551
+ Get an observed variable. Returns a wrapper that holds the variable and,
552
+ if the variable is categorical, the :class:`.CategoryMapping` between
553
+ labels and integer codes.
554
+ """
555
+ if self.is_categorical(name):
556
+ var, mapping = self.get_categorical_obs(name)
557
+ else:
558
+ var = self.get_obs(name)
559
+ mapping = None
560
+
561
+ return VarAndMapping(var, mapping)
@@ -0,0 +1,107 @@
1
+ import jax.numpy as jnp
2
+ from jax import Array
3
+
4
+
5
+ def penalty_to_unit_design(penalty: Array, rank: Array | int | None = None) -> Array:
6
+ """
7
+ Convert a (semi-)definite penalty matrix into the design matrix
8
+ projector used by mixed-model reparameterizations.
9
+
10
+ The routine performs an eigenvalue decomposition of `penalty`, keeps the
11
+ first `rank` eigenvectors (default: numerical rank of `penalty`), rescales
12
+ them to have unit marginal variance (1 / sqrt(lambda)), and returns the
13
+ resulting loading matrix.
14
+
15
+ Parameters
16
+ ----------
17
+ penalty
18
+ Positive semi-definite penalty matrix.
19
+ rank
20
+ Optional target rank. Defaults to the matrix rank inferred from
21
+ ``penalty``.
22
+
23
+ Returns
24
+ -------
25
+ A matrix whose columns span the penalized subspace and are scaled for
26
+ mixed-model formulations.
27
+ """
28
+ if rank is None:
29
+ rank = jnp.linalg.matrix_rank(penalty)
30
+
31
+ evalues, evectors = jnp.linalg.eigh(penalty)
32
+ evalues = evalues[::-1] # put in decreasing order
33
+ evectors = evectors[:, ::-1] # make order correspond to eigenvalues
34
+ rank = jnp.linalg.matrix_rank(penalty)
35
+
36
+ if evectors[0, 0] < 0:
37
+ evectors = -evectors
38
+
39
+ U = evectors
40
+ D = 1 / jnp.sqrt(jnp.ones_like(evalues).at[:rank].set(evalues[:rank]))
41
+ Z = (U.T * jnp.expand_dims(D, 1)).T
42
+ return Z
43
+
44
+
45
class LinearConstraintEVD:
    """Build reparameterization matrices that absorb linear constraints.

    Given a constraint matrix ``A`` (so that valid coefficients satisfy
    ``A @ coef = 0``), :meth:`general` returns a matrix ``Cbar`` such that
    ``coef = Cbar @ coef_tilde`` satisfies the constraint for unconstrained
    ``coef_tilde``. The other classmethods construct ``A`` for common cases
    and delegate to :meth:`general`.
    """

    @staticmethod
    def general(constraint: Array) -> Array:
        # constraint: (nconstraints, ncoef) matrix; returns an
        # (ncoef, ncoef - nconstraints) absorbing matrix.
        A = constraint
        nconstraints, _ = A.shape

        AtA = A.T @ A
        evals, evecs = jnp.linalg.eigh(AtA)

        # fix eigenvector sign for deterministic output
        if evecs[0, 0] < 0:
            evecs = -evecs

        rank = jnp.linalg.matrix_rank(AtA)
        # NOTE(review): this takes the first (ncoef - rank) ROWS of evecs;
        # a basis of the orthogonal complement is usually built from
        # eigenvector COLUMNS (e.g. evecs[:, :-rank].T) — confirm intended.
        Abar = evecs[:-rank]

        # stack the constraint on top of the complement basis and invert;
        # the trailing columns of the inverse absorb the constraint
        A_stacked = jnp.r_[A, Abar]
        C_stacked = jnp.linalg.inv(A_stacked)
        Cbar = C_stacked[:, nconstraints:]
        return Cbar

    @classmethod
    def _nullspace(cls, penalty: Array, rank: float | Array | None = None) -> Array:
        # Returns the leading eigenvectors of ``penalty`` rescaled by
        # 1/sqrt(eigenvalue).
        # NOTE(review): the ``rank`` parameter is clobbered below by
        # ``jnp.sum(evals > 1e-6)``, so passing it currently has no effect.
        if rank is None:
            rank = jnp.linalg.matrix_rank(penalty)
        evals, evecs = jnp.linalg.eigh(penalty)
        evals = evals[::-1]  # put in decreasing order
        evecs = evecs[:, ::-1]  # make order correspond to eigenvalues
        # overwrites the rank computed/passed above — see NOTE
        rank = jnp.sum(evals > 1e-6)

        # fix eigenvector sign for deterministic output
        if evecs[0, 0] < 0:
            evecs = -evecs

        U = evecs
        # rescale the leading ``rank`` eigenvectors; trailing ones keep scale 1
        D = 1 / jnp.sqrt(jnp.ones_like(evals).at[:rank].set(evals[:rank]))
        Z = (U.T * jnp.expand_dims(D, 1)).T
        Abar = Z[:, :rank]

        return Abar

    @classmethod
    def constant_and_linear(cls, x: Array, basis: Array) -> Array:
        # Absorb constraints that remove constant and linear (in ``x``)
        # components from the span of ``basis``.
        nobs = jnp.shape(x)[0]
        j = jnp.ones(shape=nobs)
        X = jnp.c_[j, x]
        # least-squares projection coefficients of the basis onto [1, x]
        A = jnp.linalg.inv(X.T @ X) @ X.T @ basis
        return cls.general(constraint=A)

    @classmethod
    def sumzero_coef(cls, ncoef: int) -> Array:
        # Absorb a sum-to-zero constraint on the coefficients themselves.
        j = jnp.ones(shape=(1, ncoef))
        return cls.general(constraint=j)

    @classmethod
    def sumzero_term(cls, basis: Array) -> Array:
        # Absorb a sum-to-zero constraint on the evaluated term:
        # the constraint row is the vector of column sums of ``basis``.
        nobs = jnp.shape(basis)[0]
        j = jnp.ones(shape=nobs)
        A = jnp.expand_dims(j @ basis, 0)
        return cls.general(constraint=A)

    @classmethod
    def sumzero_term2(cls, basis: Array) -> Array:
        # Same constraint as sumzero_term but built from column MEANS
        # (a scalar multiple of the column sums, so the same nullspace).
        A = jnp.mean(basis, axis=0, keepdims=True)
        return cls.general(constraint=A)