guts-base 2.0.0b0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,548 @@
1
+ """Transformer utilities for GUTS base simulations.
2
+
3
+ Provides classes to transform parameter and result units, and functions to apply these
4
+ transformations to InferenceData objects.
5
+ """
6
+ from typing import Dict, Tuple, List, Any, Type
7
+ from dataclasses import dataclass, field
8
+ from guts_base.sim.utils import GutsBaseError
9
+ import numpy as np
10
+ import xarray as xr
11
+ from pymob.sim.config import Modelparameters
12
+ from pymob import SimulationBase
13
+
14
@dataclass
class TransformBase:
    """Base class for transformation utilities.

    Subclasses define one method per key that should be transformed. The
    forward transform of a value stored under ``key`` is the method named
    ``key``; the inverse transform is the method named ``key_inv``.

    Attributes
    ----------
    ignore_keys : Tuple
        Keys that are passed through unchanged during transformation.

    Methods
    -------
    _transform_value(key, value, func_template)
        Apply a specific transformation function to a given value.
    _transform_dataset(x, func_template)
        Transform all data variables in an xarray.Dataset.
    _transform_dict(x, func_template)
        Transform all items in a dictionary.
    _transform_modelparameters(x, func_template)
        Transform model parameter values in a Modelparameters instance.
    _transform(x, func_template)
        Dispatch transformation based on input type.
    transform(x)
        Transform using the default function template.
    transform_inv(x)
        Inverse transformation using the ``*_inv`` function template.
    """
    ignore_keys: Tuple = ()

    def _transform_value(self, key, value, func_template):
        """Transform a single value with the method resolved from ``func_template``.

        Parameters
        ----------
        key : str
            Name of the value (data variable, dict key or parameter name).
        value : Any
            The value to transform.
        func_template : str
            Format string with a ``{key}`` placeholder (``"{key}"`` or
            ``"{key}_inv"``) that yields the name of the transform method.

        Returns
        -------
        Any
            The transformed value, or ``value`` unchanged if ``key`` is listed
            in ``ignore_keys``.

        Raises
        ------
        GutsBaseError
            If no callable transform method with the resolved name exists.
        """
        # explicitly excluded keys are returned unchanged and raise no error
        if key in self.ignore_keys:
            return value

        func_name = func_template.format(key=key)
        # Look up the method that will actually be called (e.g. ``D_inv`` for the
        # inverse transform) and require it to be callable. Then a missing
        # ``*_inv`` method, or a plain data attribute such as ``time_factor``,
        # produces the informative error below instead of AttributeError/TypeError.
        func = getattr(self, func_name, None)
        if callable(func):
            return func(value)

        cls_name = type(self).__name__
        raise GutsBaseError(
            f"'{key}' was not found in '{cls_name}'. All values " +
            "must have an associated transform function or be explicitly excluded. " +
            f"If necessary, define a transform method in '{cls_name}' " +
            f"named '{func_name}' to transform it:\n" +
            f">>> {cls_name}.{func_name} = lambda self, x: ... \nAlternatively, " +
            f"use `{cls_name}(..., ignore_keys=[..., '{key}'])` to " +
            "suppress this error"
        )

    def _transform_dataset(self, x: xr.Dataset, func_template: str) -> xr.Dataset:
        """Return a new Dataset with every data variable transformed.

        Dataset-level ``attrs`` are carried over to the result.
        """
        x_transformed = xr.Dataset({
            key: self._transform_value(key, value, func_template)
            for key, value in x.data_vars.items()
        })
        x_transformed.attrs = x.attrs
        return x_transformed

    def _transform_dict(self, x: Dict, func_template: str) -> Dict:
        """Return a new dict with every value transformed by its key's method."""
        x_transformed = {
            key: self._transform_value(key, value, func_template)
            for key, value in x.items()
        }
        return x_transformed

    def _transform_modelparameters(self, x: Modelparameters, func_template: str) -> Modelparameters:
        """Transform model parameter values of a ``Modelparameters`` instance.

        The method dumps the Modelparameters object, updates each parameter's
        ``value`` in the dumped dict using the appropriate transformation
        function, and returns a newly validated ``Modelparameters`` object.

        Parameters
        ----------
        x : Modelparameters
            The model parameters container whose values will be transformed.
        func_template : str
            Template string used to locate the correct transformation method.

        Returns
        -------
        Modelparameters
            A new ``Modelparameters`` instance after transformation and validation.
        """
        model_parameters = x.model_dump(mode="python")

        for key, param_dict in model_parameters.items():
            transformed_value = self._transform_value(
                key, param_dict["value"], func_template
            )
            param_dict.update({"value": transformed_value})

        # return a validated model parameters object with updated values;
        # ``x`` itself is not modified (this is not an in-place transform)
        return Modelparameters.model_validate(model_parameters)

    def _transform(self, x: xr.Dataset|Dict|Modelparameters, func_template: str) -> xr.Dataset|Dict|Modelparameters:
        """Dispatch to the type-specific transform method.

        Raises
        ------
        NotImplementedError
            If ``x`` is not a dict, ``Modelparameters`` or ``xr.Dataset``.
        """
        if isinstance(x, dict):
            x_transformed = self._transform_dict(x, func_template)

        elif isinstance(x, Modelparameters):
            x_transformed = self._transform_modelparameters(x, func_template)

        elif isinstance(x, xr.Dataset):
            x_transformed = self._transform_dataset(x, func_template)
        else:
            raise NotImplementedError(
                f"Cannot transform object of type '{type(x).__name__}'. "
                "Use one of dict, Modelparameters or xr.Dataset"
            )

        return x_transformed

    def transform(self, x: Any) -> Any:
        """Transform the provided object using the appropriate per-key function.

        The method accepts an ``xarray.Dataset``, a ``dict`` of values, or a
        ``Modelparameters`` instance. The actual transformation is delegated to
        :meth:`_transform`, which dispatches based on the object's type.
        """
        return self._transform(x, func_template="{key}")

    def transform_inv(self, x: Any) -> Any:
        """Inverse transform the provided object using the appropriate per-key
        function.

        Mirrors :meth:`transform` but uses the ``*_inv`` variants of the
        transformation methods.
        """
        return self._transform(x, func_template="{key}_inv")
+
144
+
145
@dataclass
class ParameterTransform(TransformBase):
    """Parameter transformer base class.

    Subclasses add one method per model parameter: ``<name>`` for the forward
    transform and ``<name>_inv`` for its inverse. Coefficients the methods
    need are declared as dataclass fields on the subclass.
    """

    def _test_transform_consistency(self, x: Modelparameters):
        """Assert that ``transform(transform_inv(x))`` reproduces ``x``.

        Every parameter in ``x.all`` is compared with
        ``np.testing.assert_array_almost_equal``. The first mismatch raises
        ``AssertionError``.
        """
        roundtrip = self.transform(self.transform_inv(x))
        for name in x.all.keys():
            expected = np.array(x[name].value)
            actual = np.array(roundtrip[name].value)
            np.testing.assert_array_almost_equal(expected, actual)
158
+
159
+
160
+
161
@dataclass
class DataTransform(TransformBase):
    """Base class for transformers of observation/result datasets.

    Besides the data variables handled by :class:`TransformBase`, the ``time``
    coordinate of a dataset is transformed with the ``time`` (forward) or
    ``time_inv`` (inverse) method.
    """

    def _transform_dataset(self, x: xr.Dataset, func_template):
        """Transform all data variables and, if present, the ``time`` coordinate.

        Parameters
        ----------
        x : xr.Dataset
            Dataset to transform.
        func_template : str
            ``"{key}"`` for the forward or ``"{key}_inv"`` for the inverse
            transform.

        Returns
        -------
        xr.Dataset
            A new dataset with transformed data variables and time coordinate.
        """
        x_transformed = super()._transform_dataset(x, func_template)
        if "time" in x.coords:
            # Route the coordinate through ``_transform_value``, as the
            # dict-based ``{"time": ...}`` transform does. It then honours
            # ``ignore_keys`` and raises an informative GutsBaseError if no
            # ``time``/``time_inv`` method is defined.
            x_transformed = x_transformed.assign_coords({
                "time": self._transform_value("time", x.time, func_template)
            })

        return x_transformed

    def _test_transform_consistency(self, arr: xr.Dataset):
        """Assert that ``transform(transform_inv(arr))`` reproduces ``arr``
        to 4 decimals."""
        np.testing.assert_array_almost_equal(
            self.transform(self.transform_inv(arr)).to_array(),
            arr.to_array(),
            decimal=4
        )
178
+
179
+
180
+
181
@dataclass
class GenericTransform:
    """High-level interface to transform simulation objects and associated InferenceData.

    The class owns a parameter transformer and a data transformer, built from
    ``parameter_transformer_class`` and ``data_transformer_class``. It records
    which components of a simulation are currently transformed, so that a
    transform (or inverse transform) is not applied twice in a row.

    Parameters
    ----------
    ignore_keys : List[str], optional
        Keys to ignore during transformation.

    Attributes
    ----------
    parameter_transformer : ParameterTransform
        Transformer for model parameters.
    data_transformer : DataTransform
        Transformer for observations/results.
    is_transformed : dict
        Tracks which components ('idata', 'observations', 'parameters') have
        been transformed.
    apply_transform : bool
        If ``False``, :meth:`transform` returns without changing anything.
    """
    # USER ATTRIBUTES
    # Transformer classes can be injected by subclasses. The base transformer
    # classes define no per-key transform methods.
    parameter_transformer_class: Type[ParameterTransform] = field(init=False, repr=False, default=ParameterTransform)
    data_transformer_class: Type[DataTransform] = field(init=False, repr=False, default=DataTransform)

    # keys to ignore when transforming
    ignore_keys: List[str] = field(default_factory=list)

    # INTERNAL ATTRIBUTES
    # internal state – created per instance
    # the fields below are initialised in __post_init__

    parameter_transformer: ParameterTransform = field(init=False)
    data_transformer: DataTransform = field(init=False)
    is_transformed: Dict[str, bool] = field(init=False)
    apply_transform: bool = field(init=False, default=True)

    def __post_init__(self) -> None:
        """Instantiate the transformers from the matching attributes of ``self``.

        Every init-able dataclass field of a transformer class (e.g.
        ``ignore_keys`` or ``time_factor``) is filled with the attribute of the
        same name on this instance.
        """
        self.parameter_transformer = self._build_transformer(self.parameter_transformer_class)
        self.data_transformer = self._build_transformer(self.data_transformer_class)

        # set status
        self.is_transformed = {"idata": False, "observations": False, "parameters": False}

    def _build_transformer(self, transformer_class):
        """Instantiate ``transformer_class`` with its init fields taken from ``self``."""
        from dataclasses import fields

        # ``fields`` excludes ClassVar/InitVar pseudo-fields, and ``f.init``
        # excludes fields the constructor does not accept. Both would otherwise
        # break the call below.
        init_kwargs = {
            f.name: getattr(self, f.name)
            for f in fields(transformer_class)
            if f.init
        }
        return transformer_class(**init_kwargs)

    def __repr__(self) -> str:
        """Represent the transform with its current transformation state."""
        _it = [f'{k}={v}' for k, v in self.is_transformed.items()]
        return (
            f"{type(self).__name__}("
            f"\n {', '.join(_it)}, "+
            f"\n data_transformer={self.data_transformer}, "+
            f"\n parameter_transformer={self.parameter_transformer}"+
            "\n)"
        )

    def _check_transform_state(self, target: str, transform: bool, inverse: bool):
        """Determine whether a transformation should be performed.

        Parameters
        ----------
        target : str
            Component name ('idata', 'observations', 'parameters').
        transform : bool
            Whether a transformation is requested.
        inverse : bool
            Whether the inverse transformation is requested.

        Returns
        -------
        tuple (bool, str)
            ``(flip_transform_status, msg)``. ``flip_transform_status`` says
            whether the transformation should be executed; ``msg`` is a
            diagnostic message.
        """
        if not transform:
            msg = f"No transform requested for '{target}'."
            flip_transform_status = False
        elif self.is_transformed[target] and transform and inverse:
            msg = f"'{target}' is transformed and inverse-transform was requested. OK: executed"
            flip_transform_status = True
        elif self.is_transformed[target] and transform and not inverse:
            msg = f"'{target}' is transformed and transform was requested. Invalid: skipped."
            flip_transform_status = False
        elif not self.is_transformed[target] and transform and inverse:
            msg = f"'{target}' is not transformed and inverse-transform was requested. Invalid: skipped."
            flip_transform_status = False
        elif not self.is_transformed[target] and transform and not inverse:
            msg = f"'{target}' is not transformed and transform was requested. OK: executed."
            flip_transform_status = True

        return flip_transform_status, msg

    def _update_transform_state(self, target, fts, msg):
        """Update internal transformation state and optionally print a message.

        Parameters
        ----------
        target : str
            Component name.
        fts : bool
            Whether the transformation was performed; if ``True`` the status of
            ``target`` is toggled.
        msg : str
            Message to print (skipped if empty).
        """
        if fts:
            self.is_transformed[target] = not self.is_transformed[target]
        if msg:
            print(msg)

    def _transform_idata(self, idata, inverse=False):
        """Apply parameter and data transformations to an InferenceData object.

        This lives here rather than in a single transformer because it needs both
        the parameter transformer and the data transformer. The groups of
        ``idata`` are replaced by transformed copies. Groups that are not present
        on ``idata`` are skipped. Only the ``time`` coordinate of
        ``log_likelihood`` is transformed; its values are kept.

        Parameters
        ----------
        idata : InferenceData
            The arviz InferenceData to transform.
        inverse : bool, optional
            If ``True``, apply the inverse transformation. Default is ``False``.
        """
        if inverse:
            transform_params = self.parameter_transformer.transform_inv
            transform_data = self.data_transformer.transform_inv
        else:
            transform_params = self.parameter_transformer.transform
            transform_data = self.data_transformer.transform

        group_transforms = {
            "posterior": transform_params,
            "posterior_model_fits": transform_data,
            "posterior_predictive": transform_data,
            "observed_data": transform_data,
        }

        # Compute every transformed group before assigning any of them. Then an
        # error while transforming one group leaves ``idata`` unchanged.
        transformed = {
            name: func(getattr(idata, name))
            for name, func in group_transforms.items()
            if hasattr(idata, name)
        }

        if hasattr(idata, "log_likelihood") and "time" in idata.log_likelihood.coords:
            transformed["log_likelihood"] = idata.log_likelihood.assign_coords(
                transform_data({"time": idata.log_likelihood.coords["time"]})
            )

        for name, group in transformed.items():
            setattr(idata, name, group)

    def transform(
        self,
        sim: SimulationBase,
        inverse: bool = False,
        idata: bool = True,
        observations: bool = True,
        parameters: bool = True
    ) -> None:
        """Transform simulation data and/or parameters.

        Parameters
        ----------
        sim : SimulationBase
            Simulation instance containing ``inferer``, ``config``, and ``observations``.
        inverse : bool, optional
            Apply inverse transformation if ``True``.
        idata : bool, optional
            Transform the InferenceData in ``sim.inferer``.
        observations : bool, optional
            Transform ``sim.observations``.
        parameters : bool, optional
            Transform model parameters in ``sim.config``.
        """

        # simply exit without applying any changes if apply transform is False
        if not self.apply_transform:
            return

        # transform idata
        fts, msg = self._check_transform_state(target="idata", transform=idata, inverse=inverse)
        if fts:
            inferer = getattr(sim, "inferer", None)
            if hasattr(inferer, "idata"):
                self._transform_idata(idata=inferer.idata, inverse=inverse)
        # The status is toggled even if ``sim`` holds no InferenceData yet.
        # NOTE(review): presumably so that idata created later from the
        # transformed simulation is tracked as transformed — confirm.
        self._update_transform_state(target="idata", fts=fts, msg=msg)

        # transform model parameters
        fts, msg = self._check_transform_state(target="parameters", transform=parameters, inverse=inverse)
        if fts:
            transform_params = (
                self.parameter_transformer.transform_inv if inverse
                else self.parameter_transformer.transform
            )
            sim.config.model_parameters = transform_params(sim.config.model_parameters)
        self._update_transform_state(target="parameters", fts=fts, msg=msg)

        # transform observations
        fts, msg = self._check_transform_state(target="observations", transform=observations, inverse=inverse)
        if fts:
            transform_data = (
                self.data_transformer.transform_inv if inverse
                else self.data_transformer.transform
            )
            sim.observations = transform_data(sim.observations)
        self._update_transform_state(target="observations", fts=fts, msg=msg)
411
+
412
+
413
+
414
@dataclass(repr=False)
class NoTransform(GenericTransform):
    """Transform that leaves simulations untouched.

    ``apply_transform`` is ``False``, so ``GenericTransform.transform`` returns
    immediately. Neither the simulation nor the transform status is modified.
    """
    # Critical for this class: switch transforms off. Other subclasses of
    # GenericTransform keep the inherited default of ``True``.
    apply_transform: bool = field(init=False, default=False)

    # plain base transformers (no per-key transform methods)
    parameter_transformer_class: Type[ParameterTransform] = ParameterTransform
    data_transformer_class: Type[DataTransform] = DataTransform
422
+
423
+
424
+
425
+
426
@dataclass
class GutsDataTransform(DataTransform):
    """Transformer for GUTS model datasets.

    Concentration-like variables (``exposure``, ``D``, ``B``) are divided by
    ``x_in_factor``, and the ``time`` coordinate is divided by ``time_factor``.
    ``H`` and ``survival`` are passed through unchanged; the ``*_inv`` methods
    undo the scaling. Additional data variables in the observations must get
    transform methods in a subclass of ``GutsDataTransform`` (or be listed in
    ``ignore_keys``); otherwise transforming them raises a GutsBaseError.

    Can transform
        RED_SD, RED_IT
        BufferGUTS_SD, BufferGUTS_IT
    """
    # Uses the generated dataclass repr (instead of ``repr=False``) so that
    # ``x_in_factor`` and ``time_factor`` show up in ``GenericTransform.__repr__``.
    x_in_factor: float = 1.0
    time_factor: float = 1.0

    def exposure(self, x):
        return x / self.x_in_factor

    def exposure_inv(self, x):
        return x * self.x_in_factor

    def D(self, x):
        return x / self.x_in_factor

    def D_inv(self, x):
        return x * self.x_in_factor

    def B(self, x):
        return x / self.x_in_factor

    def B_inv(self, x):
        return x * self.x_in_factor

    def H(self, x):
        return x

    def H_inv(self, x):
        return x

    def survival(self, x):
        return x

    def survival_inv(self, x):
        return x

    def time(self, x):
        return x / self.time_factor

    def time_inv(self, x):
        return x * self.time_factor
475
+
476
+
477
+
478
@dataclass
class GutsParameterTransform(ParameterTransform):
    """Transformer for GUTS model parameters.

    Parameters are rescaled with ``time_factor`` and/or ``x_in_factor``. Every
    forward method has an ``*_inv`` counterpart that undoes it.
    """
    time_factor: float = 1.0
    x_in_factor: float = 1.0

    # Ideally these transformations would be derived automatically from the ODE
    # system and its input/output quantities (parameter units can be parsed).

    def hb(self, x):
        """Multiply by ``time_factor`` (scaled like a per-time quantity)."""
        return x * self.time_factor

    def hb_inv(self, x):
        """Undo :meth:`hb`."""
        return x / self.time_factor

    def kd(self, x):
        """Multiply by ``time_factor`` (scaled like a per-time quantity)."""
        return x * self.time_factor

    def kd_inv(self, x):
        """Undo :meth:`kd`."""
        return x / self.time_factor

    def m(self, x):
        """Divide by ``x_in_factor`` (scaled like a concentration)."""
        return x / self.x_in_factor

    def m_inv(self, x):
        """Undo :meth:`m`."""
        return x * self.x_in_factor

    def b(self, x):
        """Multiply by ``x_in_factor`` and then by ``time_factor``."""
        scaled_by_conc = x * self.x_in_factor
        return scaled_by_conc * self.time_factor

    def b_inv(self, x):
        """Undo :meth:`b`: divide by ``x_in_factor`` and then by ``time_factor``."""
        unscaled_conc = x / self.x_in_factor
        return unscaled_conc / self.time_factor

    def beta(self, x):
        """beta is scale invariant and returned unchanged."""
        return x

    def beta_inv(self, x):
        """beta is scale invariant and returned unchanged."""
        return x

    def w(self, x):
        """Returned unchanged."""
        return x

    def w_inv(self, x):
        """Returned unchanged."""
        return x

    def eps(self, x):
        """Divide by ``x_in_factor``.

        eps is a small value added to D in the computation of the IT model. It
        must be scaled so that it does not become disproportionately large
        relative to the (small) damages produced by very small exposures.
        """
        return x / self.x_in_factor

    def eps_inv(self, x):
        """Undo :meth:`eps`."""
        return x * self.x_in_factor
537
+
538
@dataclass(repr=False)
class GutsTransform(GenericTransform):
    """Unit transform for GUTS simulations.

    Observations/results are handled by :class:`GutsDataTransform` and model
    parameters by :class:`GutsParameterTransform`. ``time_factor`` and
    ``x_in_factor`` are forwarded to both transformers when they are created.
    """
    # transformer classes for parameters and data
    parameter_transformer_class: Type[ParameterTransform] = GutsParameterTransform
    data_transformer_class: Type[DataTransform] = GutsDataTransform

    # Scaling coefficients of the transformation. With the defaults of 1.0 the
    # values are left numerically unchanged; the real values are passed at
    # initialization (example.py is referenced for this — presumably a usage
    # example elsewhere in the project).
    time_factor: float = 1.0
    x_in_factor: float = 1.0