guts-base 2.0.0b0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
guts_base/sim/base.py ADDED
@@ -0,0 +1,1286 @@
+ import os
+ import glob
+ import tempfile
+ import warnings
+ import numpy as np
+ import xarray as xr
+ from diffrax import Dopri5
+ from typing import Literal, Optional, List, Dict, Mapping, Sequence, Type, Hashable
+ import pandas as pd
+
+ from pymob import SimulationBase
+ from pymob.sim.config import (
+     DataVariable, Param, NumericArray, Numpyro
+ )
+
+ from pymob.solvers import JaxSolver
+ from pymob.solvers.base import rect_interpolation
+ from expyDB.intervention_model import (
+     Treatment, Timeseries, select, from_expydb
+ )
+
+
+ from guts_base.sim.transformer import NoTransform, GenericTransform
+ from guts_base.sim.utils import GutsBaseError
+ from guts_base import mod
+ from guts_base.data import (
+     to_dataset, reduce_multiindex_to_flat_index, create_artificial_data,
+     create_database_and_import_data_main, design_exposure_scenario, ExposureDataDict
+ )
+ from guts_base.data.generator import draft_laboratory_experiment
+ from guts_base.sim.report import GutsReport
+ from guts_base.sim import units as _units
+
+
+ class GutsBase(SimulationBase):
+     """
+     Base class for GUTS simulations.
+
+     Attributes:
+         unit_time (str): The unit of time used for simulation outputs and plots.
+             Can be one of "day", "hour", "minute", or "second".
+             Defaults to "day" if not specified in the simulation configuration.
+
+     """
+     solver = JaxSolver
+     Report = GutsReport
+     Transform: Type[GenericTransform] = NoTransform
+
+     transformer: GenericTransform = NoTransform()
+
+
+     _exclude_controls_after_fixing_background_mortality = False
+     # results_interpolation: Tuple[float,float,int] = (np.nan, np.nan, 100)
+     # _skip_data_processing: bool = False
+     # ecx_mode: Literal["mean", "draws"] = "mean"
+     # exposure_scenarios = {
+     #     "acute_1day": {"start": 0.0, "end": 1.0},
+     #     "chronic": {"start": 0.0, "end": None},
+     # }
+
+     def initialize(self, input: Optional[Dict] = None):
+         """Initialization goes through a couple of steps:
+
+         1. Configuration: This makes case-study specific changes to the configuration
+            file or sets state variables that are relevant for the simulation.
+            TODO: Ideally everything that is configurable ends up in the config so it
+            can be serialized
+
+         2. Import data: This method consists of submethods that can be adapted or
+            overwritten in subclass methods.
+            - .read_data
+            - .save_observations
+            - .process_data
+            process_data itself utilizes the submethods _create_indices and
+            _indices_to_dimensions, which are empty methods by default, but can be
+            used in subclasses if needed.
+
+         3. Initialize the simulation input (parameters, y0, x_in) via
+            prepare_simulation_input().
+
+         By splitting the initialize method into these three steps, modifications
+         can be made with higher granularity in subclasses.
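+
+         Example: a minimal sketch of a subclass hooking into step 2
+         (`MySimulation` and the file name are hypothetical):
+
+         >>> class MySimulation(GutsBase):
+         ...     def read_data(self) -> xr.Dataset:
+         ...         # load a prepared dataset as-is (netcdf is imported unchanged)
+         ...         return xr.load_dataset("observations.nc")
+         ...     def _create_indices(self):
+         ...         # optional hook called by process_data
+         ...         pass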
+         """
+
+         # 1. Configuration
+         self.configure_case_study()
+
+         # 2. Import data
+         self.observations = self.read_data()
+         # FIXME: Saving observations here is not intuitive. If I export a simulation,
+         # I want to use the last used state, not some obscure intermediate state
+         # self.save_observations(filename="observations.nc", directory=self.output_path, force=True)
+         if not self.config.guts_base.skip_data_processing:
+             self.process_data()
+
+         # 3. prepare y0 and x_in
+         self.prepare_simulation_input()
+
+     def configure_case_study(self):
+         """Modify configuration file or set state variables
+         """
+         if self._model_class is not None:
+             self.model = self._model_class._rhs_jax
+             self.solver_post_processing = self._model_class._solver_post_processing
+
+     def prepare_simulation_input(self):
+         x_in = self.parse_input(input="x_in", reference_data=self.observations, drop_dims=[])
+         y0 = self.parse_input(input="y0", reference_data=self.observations, drop_dims=["time"])
+
+         # add model components
+         if self.config.guts_base.forward_interpolate_exposure_data:  # type: ignore
+             self.model_parameters["x_in"] = rect_interpolation(x_in)
+         else:
+             # linear interpolation is the default assumption; it also amounts to
+             # rectangular interpolation if the exposure profile was rectified beforehand.
+             # ffill(dim="time") makes sure that no NaN values are at the end
+             self.model_parameters["x_in"] = x_in.interpolate_na(dim="time", method="linear").ffill(dim="time")
+
+         self.model_parameters["y0"] = y0
+         self.model_parameters["parameters"] = self.config.model_parameters.value_dict
+
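+     # Illustrative sketch (comments only, not executed): toggling the exposure
+     # interpolation mode on a constructed simulation `sim` before dispatch.
+     #
+     #     sim.config.guts_base.forward_interpolate_exposure_data = True
+     #     sim.prepare_simulation_input()  # x_in now uses rectangular (forward) interpolation
+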
+     def construct_database_statement_from_config(self):
+         """Returns a statement to be used on a database"""
+         substance = self.config.simulation.substance  # type:ignore
+         exposure_path = self.config.simulation.exposure_path  # type:ignore
+         return (
+             select(Timeseries, Treatment)
+             .join(Treatment)
+         ).where(
+             Timeseries.variable.in_([substance]),  # type: ignore
+             Timeseries.name == exposure_path
+         )
+     def read_data(self) -> xr.Dataset:
+         """Reads data and returns an xarray.Dataset.
+
+         GutsBase supports reading data from
+         - netcdf (.nc) files
+         - expyDB (SQLite databases)
+         - excel (directories of excel files)
+
+         expyDB and excel operate by converting data to xarrays, while netcdf directly
+         loads xarray Datasets. For the highest control over your data, you should
+         always use .nc files, because they are imported as-is.
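+
+         Example: a minimal sketch (the file name is a placeholder):
+
+         >>> sim.config.case_study.observations = "observations.nc"
+         >>> obs = sim.read_data()
+         >>> isinstance(obs, xr.Dataset)
+         True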
+         """
+         # TODO: Update to new INTERVENTION MODEL
+         dataset = str(self.config.case_study.observations)
+
+         # read from a directory
+         if os.path.isdir(os.path.join(self.config.case_study.data_path, dataset)):
+             # This looks for xlsx files in the folder, imports them as a database and
+             # then proceeds as normal
+             files = glob.glob(os.path.join(
+                 self.config.case_study.data_path,
+                 dataset, "*.xlsx"
+             ))
+
+             tempdir = tempfile.TemporaryDirectory()
+             dataset = self.read_data_from_xlsx(data=files, tempdir=tempdir)
+
+         ext = dataset.split(".")[-1]
+
+         if not os.path.exists(dataset):
+             dataset = os.path.join(self.data_path, dataset)
+
+         if ext == "db":
+             statement = self.construct_database_statement_from_config()
+             observations = self.read_data_from_expydb(dataset, statement)
+
+             # TODO: Integrate interventions in observations dataset
+
+         elif ext == "nc":
+             observations = xr.load_dataset(dataset)
+
+         else:
+             raise NotImplementedError(
+                 f"Dataset extension '.{ext}' is not recognized. " +
+                 "Please use one of '.db' (SQLite), '.nc' (netcdf)."
+             )
+
+         return observations
+
+     def read_data_from_xlsx(self, data, tempdir):
+         database = os.path.join(tempdir.name, "import.db")
+
+         create_database_and_import_data_main(
+             datasets_path=data,
+             database_path=database,
+             preprocessing=self.config.guts_base.data_preprocessing,
+             preprocessing_out=os.path.join(tempdir.name, "processed_{filename}")
+         )
+
+         return database
+
+
+     def read_data_from_expydb(self, database, statement) -> xr.Dataset:
+
+         observations_idata, interventions_idata = from_expydb(
+             database=f"sqlite:///{database}",
+             statement=statement
+         )
+
+         dataset = to_dataset(
+             observations_idata,
+             interventions_idata,
+             unit_time=self.config.guts_base.unit_time
+         )
+         dataset = reduce_multiindex_to_flat_index(dataset)
+
+         # TODO: Return multidimensional datasets for data coming from the database.
+         # The filtering method below can be implemented in any subclass.
+
+         filtered_dataset = self.filter_dataset(dataset)
+
+         return filtered_dataset
+
+     def process_data(self):
+         """
+         Currently these methods change datasets, indices, etc. in-place.
+         This is convenient, but more difficult to rearrange with other methods.
+         TODO: Make these methods static if possible
+         """
+         self._create_indices()
+         self._indices_to_dimensions()
+         # define tolerance based on the solver tolerance
+         self.observations = self.observations.assign_coords(eps=self.config.jaxsolver.atol * 10)
+
+         self._reindex_time_dim()
+
+         if "survival" in self.observations:
+             if "subject_count" not in self.observations.coords:
+                 self.observations = self.observations.assign_coords(
+                     subject_count=("id", self.observations["survival"].isel(time=0).values, )
+                 )
+
+             if self._data.is_survival_only_nan_except_start(self.observations.survival):
+                 self.observations = self.observations.assign_coords({
+                     "survivors_at_start": (("id", "time"), np.broadcast_to(
+                         self.observations.survival.isel(time=0).values.reshape(-1, 1),
+                         shape=self.observations.survival.shape
+                     ).astype(int))})
+             else:
+                 self.observations = self._data.prepare_survival_data_for_conditional_binomial(
+                     observations=self.observations
+                 )
+
+         if "exposure" not in self.observations:
+             self.observations["exposure"] = self.observations[self.config.guts_base.substance]
+
+         # mark the exposure variable as unobserved (it is a model input)
+         self.config.data_structure["exposure"].observed = False
+
+     def _convert_exposure_units(self):
+         """
+         TODO: Here I need to decide what to do. Working with rescaled units is
+         dangerous, because fitting might be complicated by weird quantities.
+         It would be better to rescale output parameters.
+         """
+         units, unit_conversion_factors = _units._convert_units(
+             self.observations.unit.reset_coords("unit", drop=True),
+             target_units=self.config.guts_base.unit_target
+         )
+
+         self.observations = self.observations.assign_coords({
+             "unit": units,
+             "unit_conversion_factors": unit_conversion_factors
+         })
+
+         self.observations[self.config.guts_base.substance] = \
+             self.observations[self.config.guts_base.substance] * unit_conversion_factors
+
+     @staticmethod
+     def _unique_unsorted(values):
+         # np.unique sorts; use the first-occurrence indices to preserve input order
+         _, index = np.unique(values, return_index=True)
+         return tuple(np.array(values)[sorted(index)])
+
+
+     def _create_indices(self):
+         """Use if indices should be added to sim.indices and sim.observations"""
+         pass
+
+     def _indices_to_dimensions(self):
+         pass
+
+     def filter_dataset(self, dataset: xr.Dataset) -> xr.Dataset:
+         return dataset
+
+     def _reindex_time_dim(self):
+         if self.config.simulation.model is not None:
+             if "_it" in self.config.simulation.model.lower():
+                 self.logger.info(msg=(
+                     "Reindexing time vector to increase resolution, because the model has " +
+                     "'_it' (individual tolerance) in its name"
+                 ))
+                 if not hasattr(self.config.simulation, "n_reindexed_x"):
+                     self.config.simulation.n_reindexed_x = 100
+
+                 new_time_index = np.unique(np.concatenate([
+                     self.coordinates["time"],
+                     np.linspace(
+                         0, np.max(self.coordinates["time"]),
+                         int(self.config.simulation.n_reindexed_x)  # type: ignore
+                     )
+                 ]))
+                 self.observations = self.observations.reindex(time=new_time_index)
+                 return
+
+         self.logger.info(msg=(
+             "Not reindexing the time vector, because the model name did not contain " +
+             "'_it' (individual tolerance), or the model was not given by name. If an IT model " +
+             "is calculated without a dense time resolution, the estimates can be biased!"
+         ))
+
+
+     def reset_observations(self):
+         """Resets the observations to the original observations after using .from_mempy(...)
+         This also resets the sim.coordinates dictionary.
+         """
+
+         self.observations = self._obs_backup
+
+
+     def recompute_posterior(self):
+         """Interpolates the posterior with a given resolution, so that
+         posterior_predictions can calculate proper survival predictions for the
+         posterior.
+
+         It also makes sure that the new interpolation does not include fewer values
+         than the original dataset.
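+
+         Example: a sketch using the results_interpolation setting (NaN bounds fall
+         back to the observed time span, as implemented below):
+
+         >>> sim.config.guts_base.results_interpolation = (np.nan, np.nan, 300)
+         >>> sim.recompute_posterior()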
+         """
+
+         ri = self.config.guts_base.results_interpolation
+
+         # generate high resolution posterior predictions
+         if ri is not None:
+             time_interpolate = np.linspace(
+                 start=float(self.observations["time"].min()) if np.isnan(ri[0]) else ri[0],
+                 stop=float(self.observations["time"].max()) if np.isnan(ri[1]) else ri[1],
+                 num=ri[2]
+             )
+
+             # combine original coordinates and interpolation. This
+             # a) helps error checking during posterior predictions.
+             # b) makes sure that the original time vector is retained, which may be
+             #    relevant for the simulation success (e.g. IT model)
+             obs = self.observations.reindex(
+                 time=np.unique(np.concatenate(
+                     [time_interpolate, self.observations["time"]]
+                 )),
+             )
+
+             obs["survivors_before_t"] = obs.survivors_before_t.ffill(dim="time").astype(int)
+             obs["survivors_at_start"] = obs.survivors_at_start.ffill(dim="time").astype(int)
+             self.observations = obs
+
+         self.dispatch_constructor()
+         _ = self._prob.posterior_predictions(self, self.inferer.idata)  # type: ignore
+
+
+     def prior_predictive_checks(self, **plot_kwargs):
+         super().prior_predictive_checks(**plot_kwargs)
+
+         self._plot.plot_prior_predictions(self, data_vars=["survival"])
+
+     def posterior_predictive_checks(self, **plot_kwargs):
+         super().posterior_predictive_checks(**plot_kwargs)
+
+         sim_copy: GutsBase = self.copy()
+         sim_copy.recompute_posterior()
+         # TODO: Include posterior_predictive group once the survival predictions are correctly working
+         sim_copy._plot.plot_posterior_predictions(
+             sim_copy, data_vars=["survival"], groups=["posterior_model_fits"]
+         )
+
+
+     def plot(self, results):
+         self._plot.plot_survival(self, results)
+
+     def predefined_scenarios(self):
+         """
+         TODO: Fix timescale to observations
+         TODO: Incorporate extra exposure patterns (constant, pulse_1day, pulse_2day)
+         """
+         # get the maximum possible time to provide exposure scenarios that are
+         # definitely long enough
+         time_max = float(max(
+             self.observations[self.config.simulation.x_dimension].max(),
+             *self.config.guts_base.ecx_estimates_times
+         ))
+
+         # this produces an exposure x_in dataset with only the dimensions ID and TIME
+         standard_dimensions = (
+             self.config.simulation.batch_dimension,
+             self.config.simulation.x_dimension,
+         )
+
+         # get dimensions different from standard dimensions
+         exposure_dimension = [
+             d for d in self.observations.exposure.dims if d not in
+             standard_dimensions
+         ]
+
+         # raise an error if the number of extra dimensions is larger than 1
+         if len(exposure_dimension) > 1:
+             raise ValueError(
+                 f"{type(self).__name__} can currently handle one additional dimension for " +
+                 f"the exposure beside {standard_dimensions}. You provided an exposure " +
+                 f"array with the dimensions: {self.observations.exposure.dims}"
+             )
+         else:
+             exposure_dimension = exposure_dimension[0]
+
+         # iterate over the coordinates of the exposure dimension to build one
+         # scenario per coordinate and exposure pattern
+         exposure_coordinates = self.observations.exposure[exposure_dimension].values
+
+         scenarios = {}
+         for coord in exposure_coordinates:
+             concentrations = np.where(coord == exposure_coordinates, 1.0, 0.0)
+
+             for _name, _expo_scenario in self.config.guts_base.ecx_exposure_scenarios.items():
+                 if _expo_scenario["start"] is None:
+                     _expo_scenario["start"] = 0.0
+
+                 exposure_dict = {
+                     _coord: ExposureDataDict(
+                         start=_expo_scenario["start"],
+                         end=_expo_scenario["end"],
+                         exposure=_conc
+                     )
+                     for _coord, _conc in zip(exposure_coordinates, concentrations)
+                 }
+
+                 scenario = design_exposure_scenario(
+                     exposures=exposure_dict,
+                     t_max=time_max,
+                     dt=1/24,
+                     exposure_dimension=exposure_dimension
+                 )
+
+                 scenarios.update({
+                     f"{_name}_{coord}": scenario
+                 })
+
+         return scenarios
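+
+     # Illustrative shape of the config consumed above (mirrors the commented
+     # `exposure_scenarios` example at the top of the class; values are placeholders):
+     #
+     #     sim.config.guts_base.ecx_exposure_scenarios = {
+     #         "acute_1day": {"start": 0.0, "end": 1.0},
+     #         "chronic": {"start": 0.0, "end": None},
+     #     }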
+
+
+
+     @staticmethod
+     def _exposure_data_to_xarray(
+         exposure_data: Dict[str, pd.DataFrame],
+         dim: str,
+         default_time_unit: str = "",
+         exposure_units: Mapping[str, str] = {"default": ""},
+     ) -> xr.Dataset:
+         """Creates a Dataset named exposure that has coordinates corresponding to the
+         keys in exposure_data and a dimension named according to dim. It also carries
+         an unused coordinate called unit, which holds the unit information of the exposure.
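+
+         Example: a minimal sketch (names and values are placeholders):
+
+         >>> df = pd.DataFrame({"T1": [1.0, 1.0], "C": [0.0, 0.0]}, index=[0.0, 1.0])
+         >>> ds = GutsBase._exposure_data_to_xarray({"A": df}, dim="substance")
+         >>> # ds.exposure now has the dimensions ("id", "time", "substance")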
+         """
+         arrays = {}
+         _unit_time = []
+         _unit_exposure = {}
+         for key, df in exposure_data.items():
+             # this override is necessary to make all dimensions work out
+             unit_time = _units._get_unit_from_dataframe_index(df)
+             unit_expo = exposure_units.get(key, exposure_units["default"])
+
+             df.index.name = "time"
+             arrays.update({
+                 key: df.to_xarray().to_dataarray(dim="id", name=key)
+             })
+             _unit_time.append(unit_time)
+             _unit_exposure.update({key: f"{_units.ureg.parse_expression(unit_expo).units:C}"})
+
+         # convert exposure units to an xr.DataArray
+         units_arr = pd.Series(_unit_exposure).to_xarray()
+         units_arr = units_arr.rename({"index": dim})
+
+         # make sure the time units of all exposures are identical
+         if len(set(_unit_time)) > 1:
+             raise GutsBaseError(
+                 "Different time units were specified in the exposure datasets " +
+                 f"{set(_unit_time)}. Make sure all exposure datasets have the " +
+                 "same time unit."
+             )
+         else:
+             unit_time = list(set(_unit_time))[0]
+
+         # if the unit_time was not defined, resort to the default time unit (dimensionless)
+         if len(unit_time) == 0:
+             unit_time = default_time_unit
+
+         # create the exposure dataset
+         exposure_array = xr.Dataset(arrays).to_array(dim=dim, name="exposure")
+         exposure_array = exposure_array.transpose("id", "time", ...)
+         ds = xr.Dataset({"exposure": exposure_array})
+
+         # add the time unit as an attribute (round-trip through ureg, to standardize)
+         ds.attrs["unit_time"] = f"{_units.ureg.parse_expression(unit_time).units:C}"
+         # add exposure units as coordinates. This is used later on by _convert_exposure_units
+         ds = ds.assign_coords({"unit": units_arr})
+
+         return ds
+
+     @staticmethod
+     def _survival_data_to_xarray(
+         survival_data: pd.DataFrame,
+         default_time_unit: str = ""
+     ) -> xr.Dataset:
+         # TODO: the survival name is currently not kept, because the raw data is not
+         #       transferred from the survival input
+
+         unit_time = _units._get_unit_from_dataframe_index(survival_data)
+         survival_data.index.name = "time"
+
+         survival_array = survival_data.to_xarray().to_dataarray(dim="id", name="survival")
+         survival_array = survival_array.transpose("id", "time", ...)
+         arrays = {"survival": survival_array}
+
+         ds = xr.Dataset(arrays)
+         if len(unit_time) == 0:
+             unit_time = default_time_unit
+
+         ds.attrs["unit_time"] = f"{_units.ureg.parse_expression(unit_time).units:C}"
+
+         return ds
+
+     @classmethod
+     def _observations_from_dataframes(
+         cls,
+         exposure_data: Dict[str, pd.DataFrame],
+         survival_data: Optional[pd.DataFrame] = None,
+         exposure_dim: str = "substance",
+         unit_input: Mapping[str, str] = {"default": ""},
+         unit_time: str = "day",
+     ):
+         # parse observations
+         # obs can be simply subset by selection obs.sel(substance="Exposure-Dime")
+         _exposure = cls._exposure_data_to_xarray(
+             exposure_data, dim=exposure_dim,
+             exposure_units=unit_input,
+             default_time_unit=unit_time,
+         )
+         arrays = [_exposure]
+         if survival_data is not None:
+             _survival = cls._survival_data_to_xarray(
+                 survival_data,
+                 default_time_unit=unit_time
+             )
+             arrays.append(_survival)
+
+         observations = xr.combine_by_coords(arrays, combine_attrs="no_conflicts")
+
+         return observations
+
+     @property
+     def _exposure_dimension(self):
+         return self._get_exposure_dimension(
+             dimensions=self.config.data_structure["exposure"].dimensions,
+             batch_dim=self.config.simulation.batch_dimension,
+             x_dim=self.config.simulation.x_dimension
+         )
+
+     @staticmethod
+     def _get_exposure_dimension(dimensions, batch_dim: str = "id", x_dim: str = "time"):
+         extra_dims = []
+         for k in dimensions:
+             if k not in (batch_dim, x_dim):
+                 extra_dims.append(k)
+             else:
+                 pass
+
+         if len(extra_dims) > 1:
+             raise GutsBaseError(
+                 "Guts Base can currently only handle one exposure dimension beside " +
+                 "the standard dimensions."
+             )
+         else:
+             return extra_dims[0]
+
+
+     def expand_batch_like_coordinate_to_new_dimension(self, coordinate, variables):
+         """Takes an existing coordinate of the dataset that lies along the batch
+         dimension and re-expresses it as a separate dimension for the given variables.
+         The N-dimensional array is duplicated once per unique value of the specified
+         coordinate to create an (N+1)-dimensional array. Along the batch dimension,
+         this array is filled with zeros, except where the specified coordinate
+         coincides with the respective (unique) coordinate of the new dimension.
+
+         This process is entirely reversible.
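+
+         Example: a sketch with a batch-like coordinate `substance` taking the
+         values ["A", "A", "B"] along `id`. The result gains a `substance`
+         dimension of size 2; `exposure` keeps its values where the id's substance
+         matches and is zero elsewhere.
+
+         >>> obs = sim.expand_batch_like_coordinate_to_new_dimension(
+         ...     coordinate="substance", variables=["exposure"]
+         ... )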
+         """
+         old_coords = self.observations[coordinate]
+         batch_dim = self.config.simulation.batch_dimension
+
+         # drop the old coordinate before turning it into a dimension
+         obs = self.observations.drop(coordinate)
+
+         # create unique coordinates of the new dimension, preserving the order of the
+         # old coordinate
+         _, index = np.unique(old_coords, return_index=True)
+         coords_new_dim = tuple(np.array(old_coords)[sorted(index)])
+
+         for v in variables:
+             # take data variable and extract dimension order
+             data_var = obs[v]
+             dim_order = data_var.dims
+
+             # expand the dimensionality, then transpose for new dim to be last
+             data_var = data_var.expand_dims(coordinate).transpose(..., batch_dim, coordinate)
+
+             # create a dummy dimension to broadcast the new array
+             # dummy_3d = np.ones((1, len(coords_new_dim)))
+             dummy_categorical = pd.get_dummies(old_coords).astype(int).values
+
+             # apply automatic broadcasting to increase the size of the new dimension
+             # data_var_np1_d = data_var * dummy_3d
+             data_var_np1_d = data_var * dummy_categorical
+             data_var_np1_d.attrs = data_var.attrs
+
+             # annotate coordinates of the new dimension
+             data_var_np1_d = data_var_np1_d.assign_coords({
+                 coordinate: list(coords_new_dim)
+             })
+
+             # transpose back to original dimension order with new dim as last dim
+             data_var_np1_d = data_var_np1_d.transpose(*dim_order, coordinate)
+             obs[v] = data_var_np1_d
+
+         return obs
+
+     def map_batch_coordinates_to_extra_dim_coordinates(
+         self,
+         observations: xr.Dataset,
+         target_dimension: str,
+         coordinates: Optional[List[Hashable]] = None
+     ) -> xr.Dataset:
+         """Iterates over coordinates and maps those coordinates onto the new dimension
+         that have the same number of unique values as the new dimension has coordinates.
+         """
+         if coordinates is None:
+             coordinates = list(observations.coords.keys())
+
+         for key, coord in observations.coords.items():
+             # skip coords that are not specified in coordinates
+             if key not in coordinates:
+                 continue
+
+             if self.config.simulation.batch_dimension in coord.dims and key not in observations.dims:
+                 if len(coord.dims) == 1:
+                     dim_coords = self._unique_unsorted(coord.values)
+                     if len(dim_coords) == len(observations[target_dimension]):
+                         observations[key] = (target_dimension, list(dim_coords))
+                     else:
+                         pass
+                 else:
+                     warnings.warn(
+                         f"Coordinate '{key}' has dimensions {coord.dims}. " +
+                         "Mapping coordinates with more than 1 dimension to the extra " +
+                         f"dimension '{target_dimension}' is not supported yet."
+                     )
+                     pass
+
+         return observations
+
+
+     def reduce_dimension_to_batch_like_coordinate(self, dimension, variables):
+         """This method takes an existing dimension of an N-D array and reduces it to an
+         (N-1)-D array, by writing a new batch-like coordinate that takes the coordinate
+         of the reducible dimension where the data of the N-D array is non-zero. After it
+         has been asserted that there is only one unique candidate for each coordinate
+         along the batch dimension (i.e. only one value is non-zero for a given
+         batch-coordinate), the dimension is reduced by summing over it.
+
+         The method is contingent on having no overlap in the batch dimension of the dataset.
+         """
+         pass
+
+     def initialize_from_script(self):
+         pass
+
+
+     @staticmethod
+     def _update_model_parameters(model_parameters, params: dict):
+         params_stash = {}
+         for name, new_values in params.items():
+             if not hasattr(model_parameters, name):
+                 print(f"{name} not in model_parameters, skipping.")
+                 continue
+
+             param = model_parameters[name]
+             stash = param.model_dump(include=list(new_values.keys()))
+             params_stash.update({name: stash})
+
+             for k, v in new_values.items():
+                 setattr(param, k, v)
+
+         return params_stash
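+
+     # Usage sketch (mirrors estimate_background_mortality below; the previous
+     # values are stashed, so the update can be reversed):
+     #
+     #     stash = GutsBase._update_model_parameters(
+     #         sim.config.model_parameters, {"b": {"value": 0.0, "free": False}}
+     #     )
+     #     ...
+     #     GutsBase._update_model_parameters(sim.config.model_parameters, stash)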
+
+     ### API methods ###
+
+     def transform(
+         self,
+         inverse=False,
+         idata=True,
+         observations=True,
+         parameters=True,
+     ):
+         """EXPERIMENTAL FEATURE
+
+         Transform with care! Transforming a simulation changes the parameter **values**,
+         the observations, and the results, observations and parameters in the idata
+         object (if existing).
+
+         Usage Example
+         -------------
+
+         A typical workflow is:
+
+         1. Set up the simulation `sim = Constructor.from_model_and_dataset(...)`
+
+         >>> from mempy.model import RED_SD
+         >>> from guts_base import Constructor
+         >>> experiment = Constructor.draft_laboratory_experiment(
+         ...     treatments={"C": 0.0, "T1": 1, "T2": 5, "T3": 50, "T4": 100},
+         ...     simulate_survival=True,
+         ... )
+         >>> survival_data = experiment.survival.to_pandas().T
+         >>> exposure_data = {"A": experiment.exposure.to_pandas().T}
+         >>> sim = Constructor.from_model_and_dataset(
+         ...     model=RED_SD(),
+         ...     exposure_data=exposure_data,
+         ...     survival_data=survival_data,
+         ...     output_directory="results/test"
+         ... )
+
+         2. Set up the transform `sim.transformer = GutsTransform(x_in_factor=..., time_factor=...)`
+
+         >>> from guts_base.sim.transformer import GutsTransform
+         >>> # define a transformation factor
+         >>> x_in_factor = float(sim.observations.exposure.max().values)
+         >>> # set transformation
+         >>> sim.transformer = GutsTransform(
+         ...     time_factor=1.0, x_in_factor=x_in_factor,
+         ...     ignore_keys=["id", "exposure_path", "w"]
+         ... )
+
+         3. Transform the simulation `GutsBase.transform(idata=False)`
+
+         >>> sim.transform(idata=False)
+
+         4. Run parameter estimation `GutsBase.estimate_parameters(...)`
+
+         >>> sim.estimate_parameters()
+
+         5. Inverse transform the simulation `GutsBase.transform(inverse=True)`
+
+         >>> sim.transform(inverse=True)
+
+         Explanations
+         ------------
+
+         `GutsBase.transform` DOES NOT change the priors. This means that estimating
+         parameters of a transformed simulation will yield different results than
+         estimating parameters of an untransformed simulation.
+
+         However, this is the whole point. Transformations are meant to make the life of
+         a modeller easier. By bringing all data onto a unit scale, the problem can be
+         solved more easily, using default priors (e.g. lognormal(scale=1, s=5)). Depending
+         on the size of the transformation, the effect can be larger or smaller. If,
+         for instance, I want to scale the exposure to the unit interval [0, 1], and my
+         largest exposure is, e.g., 500, this also shifts the relative influence of the
+         prior distributions by a factor of 500, especially for the m and b parameters.
+
+         ⚠️ Therefore a word of warning: before applying transformations to a wide set
+         of different problems, double-check your priors and make sure that they behave
+         as expected.
+
+         This means that transforming is a double-edged sword. If the priors for a
+         transformed distribution are chosen well, many problems of vastly different
+         scales can be solved with one set of priors. If the default priors are not
+         chosen well, you will make your life more difficult rather than easier.
+
+         Extending Transform classes
+         ---------------------------
+
+         **Note** that the forward transform (`inverse=False`) divides by the factors and
+         the inverse transform multiplies by the factors. This is how the functions are
+         defined in the GutsTransform.
+
+         Also, transforms for all parameters might not exist. If you want to define the
+         transforms at runtime, this is relatively easy to do:
+
+         >>> sim.transformer.data_transformer.xxx = lambda self, x: x / self.time_factor
+         >>> sim.transformer.data_transformer.xxx_inv = lambda self, x: x * self.time_factor
+
+         This will create new transforms (forward and inverse) for the data variable 'xxx'.
+         Similarly, this can be done for the `parameter_transformer` attribute of the sim
+         transformer.
+
+         If you want to provide your own transform functions, feel free to write your own
+         class. See `guts_base.sim.transformer` for inspiration. There the `GutsTransform`
+         class is defined.
+
+         Defining Priors
+         ---------------
+
+         Some considerations for defining good priors for a simulation transformed to the
+         unit scale.
+
+         If the exposure is on the interval [0, 1] (by transforming the data by the max.
+         exposure), the m parameter of GUTS-SD models will oftentimes be within that
+         interval. If, of course, the experiment was conducted in a way that no mortality
+         was observed, the m-parameter is likely outside of that interval (m > 1).
+
+         A sensible prior would be, for instance:
+
+         >>> sim.config.model_parameters.m.prior = 'lognorm(scale=0.01,s=5)'
+
+         This prior will assign large probability mass to values in the interval [0, 1],
+         but will also cover parts above 1 with quite some probability.
+
+         For the remaining parameters, a good default is 'lognorm(scale=1.0,s=5)'.
+
+         The s-parameter controls the width of the distribution.
+         """
+         self.transformer.transform(
+             sim=self,
+             inverse=inverse,
+             idata=idata,
+             observations=observations,
+             parameters=parameters,
+         )
+
+         self.prepare_simulation_input()
+         self.dispatch_constructor()
+
+     def point_estimate(
+         self,
+         estimate: Literal["mean", "map"] = "map",
+         to: Literal["xarray", "dict"] = "xarray"
+     ):
+         """Returns a point estimate of the posterior. If you want more control over the
+         posterior, use the attribute sim.inferer.idata.posterior and summarize it or
+         select from it using the arviz (https://python.arviz.org/en/stable/index.html)
+         and xarray (https://docs.xarray.dev/en/stable/index.html) packages.
+
+         Parameters
+         ----------
+
+         estimate : Literal["map", "mean"]
+             Point estimate to return.
+             - map: Maximum a posteriori. The sample that has the highest posterior
+               probability. This sample considers the correlation structure of the posterior.
+             - mean: The average of all marginal parameter distributions.
+
+         to : Literal["xarray", "dict"]
+             Specifies the representation to transform the summarized data to. dict can
+             be used to insert parameters in the .evaluate() method, while xarray is the
+             standard view. Defaults to xarray.
+
+         Example
+         -------
+
+         >>> sim.point_estimate(to='dict')
+         """
+         if estimate == "mean":
+             best_estimate = self.inferer.idata.posterior.mean(("chain", "draw"))
+
+         elif estimate == "map":
+             loglik = self.inferer.idata.log_likelihood\
+                 .sum(["id", "time"])\
+                 .to_array().sum("variable")
+
+             # argmax returns positional indices, so select with isel
+             sample_max_loglik = loglik.argmax(dim=("chain", "draw"))
+             best_estimate = self.inferer.idata.posterior.isel(sample_max_loglik)  # type: ignore
+         else:
+             raise GutsBaseError(
+                 f"Estimate '{estimate}' not implemented. Choose one of ['mean', 'map']"
+             )
+
+
+         if to == "xarray":
+             return best_estimate
+
+         elif to == "dict":
+             return {k: v.values for k, v in best_estimate.items()}
+
+         else:
+             raise GutsBaseError(
+                 "GutsBase.point_estimate() supports only return types to=['xarray', 'dict']. " +
+                 f"You used {to=}"
+             )
+
+
+     def evaluate(
+         self,
+         parameters: Mapping[str, float|NumericArray|Sequence[float]] = {},
+         y0: Mapping[str, float|NumericArray|Sequence[float]] = {},
+         x_in: Mapping[str, float|NumericArray|Sequence[float]] = {},
+     ):
+         """Evaluates the model along the coordinates of the observations with given
+         parameters, x_in, and y0. The dictionaries passed to the function arguments
+         only overwrite the existing default parameters, which makes the usage very simple.
+
+         Note that the first run of .evaluate() after calling the .dispatch_constructor()
+         takes a little longer, because the model and solver are jit-compiled with JAX for
+         highly efficient computations.
+
+         Parameters
+         ----------
+
+         parameters : Dict[float|Sequence[float]]
+             Dictionary of model parameters that should be changed for dispatch.
+             Unspecified model parameters will assume the default values,
+             specified under config.model_parameters.NAME.value
+
+         y0 : Dict[float|Sequence[float]]
+             Dictionary of initial values that should be changed for dispatch.
+
+         x_in : Dict[float|Sequence[float]]
+             Dictionary of model input values that should be changed for dispatch.
+
+
+         Example
+         -------
+
+         >>> sim.dispatch_constructor()  # necessary if the sim object has been modified
+         >>> # evaluate setting the background mortality to zero
+         >>> sim.evaluate(parameters={'hb': 0.0})
+
+         """
+         evaluator = self.dispatch(theta=parameters, x_in=x_in, y0=y0)
+         evaluator()
+         return evaluator.results
+
+     def estimate_background_mortality(
+         self,
+         control_ids: Optional[str|List[str]] = None,
+         exclude_controls_after_fixing_background_mortality: bool = True,
+         inference_numpyro: Numpyro = Numpyro(
+             kernel="map",
+             svi_iterations=1000,
+             svi_learning_rate=0.01,
+             init_strategy="init_to_median",
+             gaussian_base_distribution=True
+         ),
+     ):
+         """Separately estimates the background mortality parameters based on the control
+         treatments. Afterwards, the background mortality parameters are fixed to the
+         estimated maximum-a-posteriori values. Note that in the case of SVI and NUTS,
+         the MAP value is the sample of the posterior distribution that comes closest to
+         the true MAP value.
+
+         Parameters
+         ----------
+
+         control_ids : Optional[str | List[str]]
+             The names of the IDs to use for fitting the control mortality parameters.
+             By default, this selects all IDs that have no exposure throughout the entire
+             duration of the provided timeseries.
+         exclude_controls_after_fixing_background_mortality : bool
+             Whether the controls should be excluded from fitting after calibration.
+         inference_numpyro : Numpyro
+             inference_numpyro config section to parameterize background mortality
+             estimation. By default, the MAP kernel is used, which is sufficient
+             for a problem where the uncertainty of the estimate is not propagated to
+             the following analysis.
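+
+         Example: a minimal sketch (assumes the background-mortality parameter
+         names, e.g. 'hb', are listed in config.guts_base.background_mortality_parameters):
+
+         >>> sim.estimate_background_mortality()
+         >>> sim.config.model_parameters["hb"].free
+         False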
+         """
+
+         self._exclude_controls_after_fixing_background_mortality = \
+             exclude_controls_after_fixing_background_mortality
+         # copy the simulation in order not to mix up anything in the original sim
+         sim_control = self.copy()
+
+         if isinstance(control_ids, str):
+             control_ids = [control_ids]
+         elif control_ids is None:
+             cum_expo = sim_control.observations.exposure.sum(
+                 ("time", sim_control._exposure_dimension)
+             )
+             control_ids = cum_expo.where(cum_expo == 0, drop=True).id.values
+         else:
+             pass
+
+         # constrain the observations of the copied object to the control ids
+         sim_control.observations = sim_control.observations.sel(id=control_ids)
+
+         # fix all parameters except the background-mortality parameters at zero
+         params_fix_at_zero = {
+             k: {"value": 0.0, "free": False} for k in sim_control.model_parameter_names
+             if k not in sim_control.config.guts_base.background_mortality_parameters
+         }
+
+         # update the parameters in the model_parameters dict
+         params_backup = sim_control._update_model_parameters(
+             sim_control.config.model_parameters, params_fix_at_zero
+         )
+
+         # setup inferer
+         sim_control.prepare_simulation_input()
+         sim_control.dispatch_constructor()
+         sim_control.set_inferer("numpyro")
+
+         # run inference
+         sim_control.config.inference_numpyro = Numpyro.model_validate(inference_numpyro)
+         sim_control.inferer.run()
+
+         # plot results of background mortality
+         sim_control._plot.plot_survival_multipanel(
+             sim_control, sim_control.inferer.idata.posterior_model_fits,
+             filename="survival_multipanel_control_treatments"
+         )
+
+         # obtain the maximum a posteriori estimate from the inferer using the guts-base API
+         # and start iterating over the background mortality parameters
+         map_estimate = sim_control.point_estimate("map", to="dict")
+         for bgm_param in sim_control.config.guts_base.background_mortality_parameters:
+             bgm_param_value = map_estimate[bgm_param]
+
+             # assign the estimated parameter MAP value to the parameters of the original
+             # simulation object. Also set them as fixed parameters
+             self.config.model_parameters[bgm_param].value = bgm_param_value
+             self.config.model_parameters[bgm_param].free = False
+
+         # reverse the process from before (this is not strictly necessary)
+         _ = sim_control._update_model_parameters(
+             sim_control.config.model_parameters, params_backup
+         )
+
+         # constrain observations to non-control IDs if flag is set
+         if exclude_controls_after_fixing_background_mortality:
+             control_mask = [id for id in self.observations.id.values if id not in control_ids]
+             self.observations = self.observations.sel(id=control_mask)
+
+         # assemble simulation inputs and Evaluator with new fixed parameter values.
+         self.prepare_simulation_input()
+         self.dispatch_constructor()
+
+     @classmethod
+     def draft_laboratory_experiment(
+         cls,
+         treatments: Dict[str, float|Dict[str,float]],
+         n_test_organisms_per_treatment: int = 10,
+         experiment_end: pd.Timedelta = pd.Timedelta(10, unit="days"),
+         exposure_pattern: ExposureDataDict|Dict[str,ExposureDataDict] = ExposureDataDict(start=0.0, end=None, exposure=None),
+         exposure_interpolation: Literal["linear", "constant-forward"] = "constant-forward",
+         exposure_dimension: str = "substance",
+         observation_times: Optional[List[float]] = None,
+         dt: pd.Timedelta = pd.Timedelta("1 day"),
+         write_to_file: Optional[str] = None
+     ):
+         """Simulate a laboratory experiment according to a treatment dictionary.
+
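+         Example: a sketch mirroring the usage shown in the `transform` docstring
+         (treatment names and concentrations are placeholders):
+
+         >>> experiment = GutsBase.draft_laboratory_experiment(
+         ...     treatments={"C": 0.0, "T1": 1.0, "T2": 5.0},
+         ... )
+         >>> "survival" in experiment
+         True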
+         """
+
+         experiment = draft_laboratory_experiment(
+             treatments=treatments,
+             experiment_end=experiment_end,
+             exposure_pattern=exposure_pattern,
+             exposure_dimension=exposure_dimension,
+             dt=dt
+         )
+
+
+         survival = np.full(
+             [v for k, v in experiment.sizes.items() if k in ("time", "id")],
+             fill_value=np.nan
+         )
+         # set to the number of test organisms at time zero
+         survival[:, 0] = n_test_organisms_per_treatment
+         experiment["survival"] = xr.DataArray(survival, coords=[experiment.id, experiment.time])
+
+         if observation_times is None:
+             observation_times_safe = experiment.time
+         else:
+             observation_times_safe = np.unique(np.concatenate([experiment.time, observation_times]))
+
+         experiment = experiment.reindex(time=observation_times_safe)
+
+         # TODO: This does not yet make the exposure profiles openguts-ready, i.e.
+         # if concentration changes occur, this will not be completely explicit by
+         # making jumps.
+         # This requires a method that adds a time point before any change if constant-forward
+         if exposure_interpolation == "linear":
+             experiment["exposure"] = experiment["exposure"].interpolate_na(dim="time", method="linear")
+         else:
+             experiment["exposure"] = experiment["exposure"].ffill(dim="time")
+
+
+         # TODO: What to do with interpolation
+         return experiment
+
+     @classmethod
+     def to_openguts(cls, observations: xr.Dataset, path: str, time_unit: str):
+
+         experiment = observations.rename({
+             "time": f"time [{time_unit}]"
+         })
+
+         extra_dim = cls._get_exposure_dimension(observations.dims.keys())
+
+         os.makedirs(os.path.dirname(path), exist_ok=True)
+
+         with pd.ExcelWriter(path) as writer:
+             for coord in observations[extra_dim].values:
+                 experiment.exposure.sel({extra_dim: coord}).to_pandas().T.to_excel(writer, sheet_name=coord)
+             experiment.survival.to_pandas().T.to_excel(writer, sheet_name="survival")
+
+
+
+     def load_exposure_scenario(
+         self,
+         data: Dict[str, pd.DataFrame],
+         sheet_name_prefix: str = "",
+         rect_interpolate=False
+     ):
+
+         self._obs_backup = self.observations.copy(deep=True)
+
+         # read the exposure array from the provided data frames
+         exposure_dim = [
+             d for d in self.config.data_structure["exposure"].dimensions
+             if d not in (self.config.simulation.x_dimension, self.config.simulation.batch_dimension)
+         ]
+         exposure = self._exposure_data_to_xarray(
+             exposure_data=data,
+             dim=exposure_dim[0]
+         )
+
+         # combine with observations
+         new_obs = xr.combine_by_coords([
+             exposure,
+             self.observations.survival
+         ]).sel(id=exposure.id)
+
+         self.observations = new_obs.sel(time=[t for t in new_obs.time if t <= exposure.time.max()])  # type: ignore
+         self.config.simulation.x_in = ["exposure=exposure"]
+         self.model_parameters["x_in"] = self.parse_input("x_in", exposure).ffill("time")  # type: ignore
+         self.model_parameters["y0"] = self.parse_input("y0", drop_dims=["time"])
+
+         self.dispatch_constructor()
+
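+     # Illustrative call (file and sheet names are placeholders):
+     #
+     #     scenario = {"A": pd.read_excel("scenario.xlsx", sheet_name="A", index_col=0)}
+     #     sim.load_exposure_scenario(data=scenario)
+     #     sim.reset_observations()  # restores the original observations afterwards
+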
+     def export(self, directory: Optional[str] = None, mode: Literal["export", "copy"] = "export", skip_data_processing=True):
+         self.config.simulation.skip_data_processing = skip_data_processing
+         super().export(directory=directory, mode=mode)
+
+     def export_to_scenario(self, scenario, force=False):
+         """Exports a case study as a new scenario for running inference"""
+         self.config.case_study.scenario = scenario
+         self.config.case_study.data = None
+         self.config.case_study.output = None
+         self.config.case_study.scenario_path_override = None
+         self.config.simulation.skip_data_processing = True
+         self.save_observations(filename=f"observations_{scenario}.nc", force=force)
+         self.config.save(force=force)
+
+     @staticmethod
+     def _condition_posterior(
+         posterior: xr.Dataset,
+         parameter: str,
+         value: float,
+         exception: Literal["raise", "warn"] = "raise"
+     ):
+         """TODO: Provide this method also to SimulationBase"""
+         if parameter not in posterior:
+             keys = list(posterior.keys())
+             msg = (
+                 f"{parameter=} was not found in the posterior {keys=}. " +
+                 f"Unable to condition the posterior to {value=}. Have you " +
+                 "requested the correct parameter for conditioning?"
+             )
+
+             if exception == "raise":
+                 raise GutsBaseError(msg)
+             elif exception == "warn":
+                 warnings.warn(msg)
+                 # the parameter is absent, so there is nothing to condition on
+                 return posterior
+             else:
+                 raise GutsBaseError(
+                     "Use one of exception='raise' or exception='warn'. " +
+                     f"Currently using {exception=}"
+                 )
+
+         # broadcast value so that methods like drawing samples and hdi still work
+         broadcasted_value = np.full_like(posterior[parameter], value)
+
+         return posterior.assign({
+             parameter: (posterior[parameter].dims, broadcasted_value)
+         })
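+
+     # Usage sketch: fix a parameter to a constant across all posterior draws
+     # ("hb" is a placeholder parameter name):
+     #
+     #     post = GutsBase._condition_posterior(
+     #         sim.inferer.idata.posterior, parameter="hb", value=0.0
+     #     )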
+
+
+ class GutsSimulationConstantExposure(GutsBase):
+     t_max = 10
+     def initialize_from_script(self):
+         self.config.data_structure.B = DataVariable(dimensions=["time"], observed=False)
+         self.config.data_structure.D = DataVariable(dimensions=["time"], observed=False)
+         self.config.data_structure.H = DataVariable(dimensions=["time"], observed=False)
+         self.config.data_structure.survival = DataVariable(dimensions=["time"], observed=False)
+
+         # y0
+         self.config.simulation.y0 = ["D=Array([0])", "H=Array([0])", "survival=Array([1])"]
+         self.model_parameters["y0"] = self.parse_input(input="y0", drop_dims=["time"])
+
+         # parameters
+         self.config.model_parameters.C_0 = Param(value=10.0, free=False)
+         self.config.model_parameters.k_d = Param(value=0.9, free=True)
+         self.config.model_parameters.h_b = Param(value=0.00005, free=True)
+         self.config.model_parameters.b = Param(value=5.0, free=True)
+         self.config.model_parameters.z = Param(value=0.2, free=True)
+
+         self.model_parameters["parameters"] = self.config.model_parameters.value_dict
+         self.config.simulation.model = "guts_constant_exposure"
+
+         self.coordinates["time"] = np.linspace(0, self.t_max)
+
+     def use_jax_solver(self):
+         # =======================
+         # Define model and solver
+         # =======================
+
+         self.coordinates["time"] = np.array([0, self.t_max])
+         self.config.simulation.model = "guts_constant_exposure"
+
+         self.solver = JaxSolver
+
+         self.dispatch_constructor(diffrax_solver=Dopri5)
+
+     def use_symbolic_solver(self):
+         # =======================
+         # Define model and solver
+         # =======================
+
+         self.coordinates["time"] = np.array([0, self.t_max])
+         self.config.simulation.model = "guts_sympy"
+
+         self.solver = mod.PiecewiseSymbolicSolver
+
+         self.dispatch_constructor(diffrax_solver=Dopri5)
+
+
+ class GutsSimulationVariableExposure(GutsSimulationConstantExposure):
+     t_max = 10
+     def initialize_from_script(self):
+         super().initialize_from_script()
+         del self.coordinates["time"]
+         exposure = create_artificial_data(
+             t_max=self.t_max, dt=1,
+             exposure_paths=["topical"]
+         ).squeeze()
+         self.observations = exposure
+
+         self.config.data_structure.exposure = DataVariable(dimensions=["time"], observed=True)
+
+         self.config.simulation.x_in = ["exposure=exposure"]
+         x_in = self.parse_input(input="x_in", reference_data=exposure, drop_dims=[])
+         x_in = rect_interpolation(x_in=x_in, x_dim="time")
+         self.model_parameters["x_in"] = x_in
+
+         # parameters
+         self.config.model_parameters.remove("C_0")
+
+         self.model_parameters["parameters"] = self.config.model_parameters.value_dict
+         self.config.simulation.solver_post_processing = "red_sd_post_processing"
+         self.config.simulation.model = "guts_variable_exposure"
+
+
+     def use_jax_solver(self):
+         # =======================
+         # Define model and solver
+         # =======================
+
+         self.model = self._mod.guts_variable_exposure
+         self.solver = JaxSolver
+
+         self.dispatch_constructor(diffrax_solver=Dopri5)
+
+     def use_symbolic_solver(self, do_compile=True):
+         # =======================
+         # Define model and solver
+         # =======================
+
+         self.model = self._mod.guts_sympy
+         self.solver = self._mod.PiecewiseSymbolicSolver
+
+         self.dispatch_constructor(do_compile=do_compile, output_path=self.output_path)
+