guts-base 2.0.0b0__py3-none-any.whl
- guts_base/__init__.py +15 -0
- guts_base/data/__init__.py +35 -0
- guts_base/data/expydb.py +248 -0
- guts_base/data/generator.py +191 -0
- guts_base/data/openguts.py +296 -0
- guts_base/data/preprocessing.py +55 -0
- guts_base/data/survival.py +148 -0
- guts_base/data/time_of_death.py +595 -0
- guts_base/data/utils.py +8 -0
- guts_base/mod.py +332 -0
- guts_base/plot.py +201 -0
- guts_base/prob/__init__.py +13 -0
- guts_base/prob/binom.py +18 -0
- guts_base/prob/conditional_binom.py +118 -0
- guts_base/prob/conditional_binom_mv.py +233 -0
- guts_base/prob/predictions.py +164 -0
- guts_base/sim/__init__.py +28 -0
- guts_base/sim/base.py +1286 -0
- guts_base/sim/config.py +170 -0
- guts_base/sim/constructors.py +31 -0
- guts_base/sim/ecx.py +585 -0
- guts_base/sim/mempy.py +290 -0
- guts_base/sim/report.py +405 -0
- guts_base/sim/transformer.py +548 -0
- guts_base/sim/units.py +313 -0
- guts_base/sim/utils.py +10 -0
- guts_base-2.0.0b0.dist-info/METADATA +853 -0
- guts_base-2.0.0b0.dist-info/RECORD +32 -0
- guts_base-2.0.0b0.dist-info/WHEEL +5 -0
- guts_base-2.0.0b0.dist-info/entry_points.txt +3 -0
- guts_base-2.0.0b0.dist-info/licenses/LICENSE +674 -0
- guts_base-2.0.0b0.dist-info/top_level.txt +1 -0
guts_base/sim/base.py
ADDED
@@ -0,0 +1,1286 @@
import os
import glob
import tempfile
import warnings
import numpy as np
import xarray as xr
from diffrax import Dopri5
from typing import Literal, Optional, List, Dict, Mapping, Sequence, Type, Hashable
import pandas as pd

from pymob import SimulationBase
from pymob.sim.config import (
    DataVariable, Param, NumericArray, Numpyro
)

from pymob.solvers import JaxSolver
from pymob.solvers.base import rect_interpolation
from expyDB.intervention_model import (
    Treatment, Timeseries, select, from_expydb
)


from guts_base.sim.transformer import NoTransform, GenericTransform
from guts_base.sim.utils import GutsBaseError
from guts_base import mod
from guts_base.data import (
    to_dataset, reduce_multiindex_to_flat_index, create_artificial_data,
    create_database_and_import_data_main, design_exposure_scenario, ExposureDataDict
)
from guts_base.data.generator import draft_laboratory_experiment
from guts_base.sim.report import GutsReport
from guts_base.sim import units as _units


class GutsBase(SimulationBase):
    """
    Base class for GUTS simulations.

    Attributes:
        unit_time (str): The unit of time used for simulation outputs and plots.
            Can be one of "day", "hour", "minute", or "second".
            Defaults to "day" if not specified in the simulation configuration.
    """
    solver = JaxSolver
    Report = GutsReport
    Transform: Type[GenericTransform] = NoTransform

    transformer: GenericTransform = NoTransform()

    _exclude_controls_after_fixing_background_mortality = False
    # results_interpolation: Tuple[float,float,int] = (np.nan, np.nan, 100)
    # _skip_data_processing: bool = False
    # ecx_mode: Literal["mean", "draws"] = "mean"
    # exposure_scenarios = {
    #     "acute_1day": {"start": 0.0, "end": 1.0},
    #     "chronic": {"start": 0.0, "end": None},
    # }

    def initialize(self, input: Optional[Dict] = None):
        """Initialization goes through a couple of steps:

        1. Configuration: This makes case-study specific changes to the configuration
           file or sets state variables that are relevant for the simulation.
           TODO: Ideally everything that is configurable ends up in the config so it
           can be serialized

        2. Import data: This method consists of submethods that can be adapted or
           overwritten in subclasses:
           - .read_data
           - .save_observations
           - .process_data
           process_data itself utilizes the submethods _create_indices and
           _indices_to_dimensions, which are empty methods by default, but can be
           used in subclasses if needed.

        3. Initialize the simulation input (parameters, y0, x_in). This can
           likewise be adapted in subclasses.

        By splitting up the initialize method into these three steps, modifications
        of the initialize method allow for higher granularity in subclasses.
        """

        # 1. Configuration
        self.configure_case_study()

        # 2. Import data
        self.observations = self.read_data()
        # FIXME: Saving observations here is not intuitive. If I export a simulation,
        # I want to use the last used state, not some obscure intermediate state
        # self.save_observations(filename="observations.nc", directory=self.output_path, force=True)
        if not self.config.guts_base.skip_data_processing:
            self.process_data()

        # 3. prepare y0 and x_in
        self.prepare_simulation_input()

    def configure_case_study(self):
        """Modify configuration file or set state variables"""
        if self._model_class is not None:
            self.model = self._model_class._rhs_jax
            self.solver_post_processing = self._model_class._solver_post_processing

    def prepare_simulation_input(self):
        x_in = self.parse_input(input="x_in", reference_data=self.observations, drop_dims=[])
        y0 = self.parse_input(input="y0", reference_data=self.observations, drop_dims=["time"])

        # add model components
        if self.config.guts_base.forward_interpolate_exposure_data:  # type: ignore
            self.model_parameters["x_in"] = rect_interpolation(x_in)
        else:
            # linear interpolation will be the default assumption; this also leads
            # to rect_interpolation if the exposure profile was rectified before.
            # ffill(dim="time") makes sure that no NaN values are at the end
            self.model_parameters["x_in"] = x_in.interpolate_na(dim="time", method="linear").ffill(dim="time")

        self.model_parameters["y0"] = y0
        self.model_parameters["parameters"] = self.config.model_parameters.value_dict

    def construct_database_statement_from_config(self):
        """returns a statement to be used on a database"""
        substance = self.config.simulation.substance  # type: ignore
        exposure_path = self.config.simulation.exposure_path  # type: ignore
        return (
            select(Timeseries, Treatment)
            .join(Timeseries)
        ).where(
            Timeseries.variable.in_([substance]),  # type: ignore
            Timeseries.name == exposure_path
        )

    def read_data(self) -> xr.Dataset:
        """Reads data and returns an xarray.Dataset.

        GutsBase supports reading data from
        - netcdf (.nc) files
        - expyDB (SQLite databases)
        - excel (directories of excel files)

        expyDB and excel operate by converting data to xarrays, while netcdf directly
        loads xarray Datasets. For highest control over your data, you should always
        use .nc files, because they are imported as-is.
        """
        # TODO: Update to new INTERVENTION MODEL
        dataset = str(self.config.case_study.observations)

        # read from a directory
        if os.path.isdir(os.path.join(self.config.case_study.data_path, dataset)):
            # This looks for xlsx files in the folder, imports them as a database and
            # then proceeds as normal
            files = glob.glob(os.path.join(
                self.config.case_study.data_path,
                dataset, "*.xlsx"
            ))

            tempdir = tempfile.TemporaryDirectory()
            dataset = self.read_data_from_xlsx(data=files, tempdir=tempdir)

        ext = dataset.split(".")[-1]

        if not os.path.exists(dataset):
            dataset = os.path.join(self.data_path, dataset)

        if ext == "db":
            statement = self.construct_database_statement_from_config()
            observations = self.read_data_from_expydb(dataset, statement)

            # TODO: Integrate interventions in observations dataset

        elif ext == "nc":
            observations = xr.load_dataset(dataset)

        else:
            raise NotImplementedError(
                f"Dataset extension '.{ext}' is not recognized. " +
                "Please use one of '.db' (sqlite), '.nc' (netcdf)."
            )

        return observations

    def read_data_from_xlsx(self, data, tempdir):
        database = os.path.join(tempdir.name, "import.db")

        create_database_and_import_data_main(
            datasets_path=data,
            database_path=database,
            preprocessing=self.config.guts_base.data_preprocessing,
            preprocessing_out=os.path.join(tempdir.name, "processed_{filename}")
        )

        return database

    def read_data_from_expydb(self, database, statement) -> xr.Dataset:
        observations_idata, interventions_idata = from_expydb(
            database=f"sqlite:///{database}",
            statement=statement
        )

        dataset = to_dataset(
            observations_idata,
            interventions_idata,
            unit_time=self.config.guts_base.unit_time
        )
        dataset = reduce_multiindex_to_flat_index(dataset)

        # "Continue here. I want to return multidimensional datasets for data coming "+
        # "from the database. The method can be implemented in any class. Currently I'm looking "+
        # "at guts base"

        filtered_dataset = self.filter_dataset(dataset)

        return filtered_dataset

    def process_data(self):
        """
        Currently these methods change datasets, indices, etc. in-place.
        This is convenient, but more difficult to re-arrange with other methods.
        TODO: Make these methods static if possible
        """
        self._create_indices()
        self._indices_to_dimensions()
        # define tolerance based on the solver tolerance
        self.observations = self.observations.assign_coords(eps=self.config.jaxsolver.atol * 10)

        self._reindex_time_dim()

        if "survival" in self.observations:
            if "subject_count" not in self.observations.coords:
                self.observations = self.observations.assign_coords(
                    subject_count=("id", self.observations["survival"].isel(time=0).values, )
                )

            if self._data.is_survival_only_nan_except_start(self.observations.survival):
                self.observations = self.observations.assign_coords({
                    "survivors_at_start": (("id", "time"), np.broadcast_to(
                        self.observations.survival.isel(time=0).values.reshape(-1, 1),
                        shape=self.observations.survival.shape
                    ).astype(int))})
            else:
                self.observations = self._data.prepare_survival_data_for_conditional_binomial(
                    observations=self.observations
                )

        if "exposure" not in self.observations:
            self.observations["exposure"] = self.observations[self.config.guts_base.substance]

            # set
            self.config.data_structure["exposure"].observed = False

    def _convert_exposure_units(self):
        """
        TODO: Here I need to decide what to do. Working with rescaled units is
        dangerous, because fitting might be complicated with weird quantities.
        It would be better to rescale output parameters
        """
        units, unit_conversion_factors = _units._convert_units(
            self.observations.unit.reset_coords("unit", drop=True),
            target_units=self.config.guts_base.unit_target
        )

        self.observations = self.observations.assign_coords({
            "unit": units,
            "unit_conversion_factors": unit_conversion_factors
        })

        self.observations[self.config.guts_base.substance] =\
            self.observations[self.config.guts_base.substance] * unit_conversion_factors

    @staticmethod
    def _unique_unsorted(values):
        _, index = np.unique(values, return_index=True)
        return tuple(np.array(values)[sorted(index)])
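
    # A quick sketch of _unique_unsorted (hypothetical values): unlike np.unique,
    # it keeps the first occurrence of each value and preserves input order:
    #
    #   GutsBase._unique_unsorted(["b", "a", "b", "c"])  # -> ("b", "a", "c")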

    def _create_indices(self):
        """Use if indices should be added to sim.indices and sim.observations"""
        pass

    def _indices_to_dimensions(self):
        pass

    def filter_dataset(self, dataset: xr.Dataset) -> xr.Dataset:
        return dataset

    def _reindex_time_dim(self):
        if self.config.simulation.model is not None:
            if "_it" in self.config.simulation.model.lower():
                self.logger.info(msg=(
                    "Reindexing time vector to increase resolution, because the model "+
                    "has '_it' (individual tolerance) in its name"
                ))
                if not hasattr(self.config.simulation, "n_reindexed_x"):
                    self.config.simulation.n_reindexed_x = 100

                new_time_index = np.unique(np.concatenate([
                    self.coordinates["time"],
                    np.linspace(
                        0, np.max(self.coordinates["time"]),
                        int(self.config.simulation.n_reindexed_x)  # type: ignore
                    )
                ]))
                self.observations = self.observations.reindex(time=new_time_index)
                return

        self.logger.info(msg=(
            "No reindexing of time vector, because the model name did not contain "+
            "'_it' (individual tolerance), or the model was not given by name. If an IT model "+
            "is calculated without a dense time resolution, the estimates can be biased!"
        ))
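
    # For example (hypothetical values): with observed times [0, 1, 2] and
    # n_reindexed_x=5, the reindexed time vector is
    # np.unique(np.concatenate([[0, 1, 2], np.linspace(0, 2, 5)]))
    # -> array([0., 0.5, 1., 1.5, 2.]); observations at the new points are NaN.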

    def reset_observations(self):
        """Resets the observations to the original observations after using
        .from_mempy(...). This also resets the sim.coordinates dictionary.
        """
        self.observations = self._obs_backup

    def recompute_posterior(self):
        """This function interpolates the posterior with a given resolution.
        posterior_predictions then calculates proper survival predictions for the
        posterior.

        It also makes sure that the new interpolation does not include fewer values
        than the original dataset.
        """
        ri = self.config.guts_base.results_interpolation

        # generate high resolution posterior predictions
        if ri is not None:
            time_interpolate = np.linspace(
                start=float(self.observations["time"].min()) if np.isnan(ri[0]) else ri[0],
                stop=float(self.observations["time"].max()) if np.isnan(ri[1]) else ri[1],
                num=ri[2]
            )

            # combine original coordinates and interpolation. This
            # a) helps error checking during posterior predictions.
            # b) makes sure that the original time vector is retained, which may be
            #    relevant for the simulation success (e.g. IT model)
            obs = self.observations.reindex(
                time=np.unique(np.concatenate(
                    [time_interpolate, self.observations["time"]]
                )),
            )

            obs["survivors_before_t"] = obs.survivors_before_t.ffill(dim="time").astype(int)
            obs["survivors_at_start"] = obs.survivors_at_start.ffill(dim="time").astype(int)
            self.observations = obs

        self.dispatch_constructor()
        _ = self._prob.posterior_predictions(self, self.inferer.idata)  # type: ignore

    def prior_predictive_checks(self, **plot_kwargs):
        super().prior_predictive_checks(**plot_kwargs)

        self._plot.plot_prior_predictions(self, data_vars=["survival"])

    def posterior_predictive_checks(self, **plot_kwargs):
        super().posterior_predictive_checks(**plot_kwargs)

        sim_copy: GutsBase = self.copy()
        sim_copy.recompute_posterior()
        # TODO: Include posterior_predictive group once the survival predictions are
        # correctly working
        sim_copy._plot.plot_posterior_predictions(
            sim_copy, data_vars=["survival"], groups=["posterior_model_fits"]
        )

    def plot(self, results):
        self._plot.plot_survival(self, results)

    def predefined_scenarios(self):
        """
        TODO: Fix timescale to observations
        TODO: Incorporate extra exposure patterns (constant, pulse_1day, pulse_2day)
        """
        # get the maximum possible time to provide exposure scenarios that are
        # definitely long enough
        time_max = float(max(
            self.observations[self.config.simulation.x_dimension].max(),
            *self.config.guts_base.ecx_estimates_times
        ))

        # this produces an exposure x_in dataset with only the dimensions ID and TIME
        standard_dimensions = (
            self.config.simulation.batch_dimension,
            self.config.simulation.x_dimension,
        )

        # get dimensions different from standard dimensions
        exposure_dimension = [
            d for d in self.observations.exposure.dims if d not in
            standard_dimensions
        ]

        # raise an error if the number of extra dimensions is larger than 1
        if len(exposure_dimension) > 1:
            raise ValueError(
                f"{type(self).__name__} can currently handle one additional dimension for "+
                f"the exposure beside {standard_dimensions}. You provided an exposure "+
                f"array with the dimensions: {self.observations.exposure.dims}"
            )
        else:
            exposure_dimension = exposure_dimension[0]

        # iterate over the coordinates of the exposure dimension
        exposure_coordinates = self.observations.exposure[exposure_dimension].values

        scenarios = {}
        for coord in exposure_coordinates:
            concentrations = np.where(coord == exposure_coordinates, 1.0, 0.0)

            for _name, _expo_scenario in self.config.guts_base.ecx_exposure_scenarios.items():
                if _expo_scenario["start"] is None:
                    _expo_scenario["start"] = 0.0

                exposure_dict = {
                    coord: ExposureDataDict(
                        start=_expo_scenario["start"],
                        end=_expo_scenario["end"],
                        exposure=conc
                    )
                    for coord, conc in zip(exposure_coordinates, concentrations)
                }

                scenario = design_exposure_scenario(
                    exposures=exposure_dict,
                    t_max=time_max,
                    dt=1/24,
                    exposure_dimension=exposure_dimension
                )

                scenarios.update({
                    f"{_name}_{coord}": scenario
                })

        return scenarios
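
    # Sketch of the resulting scenario dictionary (names are hypothetical): with
    # ecx_exposure_scenarios = {"acute_1day": {"start": 0.0, "end": 1.0}} and exposure
    # coordinates ["substance_a", "substance_b"], predefined_scenarios() returns the
    # keys "acute_1day_substance_a" and "acute_1day_substance_b", where each scenario
    # applies a unit exposure (1.0) only to the named coordinate.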

    @staticmethod
    def _exposure_data_to_xarray(
        exposure_data: Dict[str, pd.DataFrame],
        dim: str,
        default_time_unit: str = "",
        exposure_units: Mapping[str, str] = {"default": ""},
    ) -> xr.Dataset:
        """Creates a Dataset named exposure that has coordinates corresponding to the
        keys in exposure_data and a dimension named according to dim. It also carries
        an unused coordinate called unit, which holds the unit information of the
        exposure.
        """
        arrays = {}
        _unit_time = []
        _unit_exposure = {}
        for key, df in exposure_data.items():
            # this override is necessary to make all dimensions work out
            unit_time = _units._get_unit_from_dataframe_index(df)
            unit_expo = exposure_units.get(key, exposure_units["default"])

            df.index.name = "time"
            arrays.update({
                key: df.to_xarray().to_dataarray(dim="id", name=key)
            })
            _unit_time.append(unit_time)
            _unit_exposure.update({key: f"{_units.ureg.parse_expression(unit_expo).units:C}"})

        # convert exposure units to an xr.DataArray
        units_arr = pd.Series(_unit_exposure).to_xarray()
        units_arr = units_arr.rename({"index": dim})

        # make sure the time units of all exposures are identical
        if len(set(_unit_time)) > 1:
            raise GutsBaseError(
                "Different time units were specified in the exposure datasets " +
                f"{set(_unit_time)}. Make sure all exposure datasets have the " +
                "same time unit."
            )
        else:
            unit_time = list(set(_unit_time))[0]

        # if the unit_time was not defined, resort to the default time unit (dimensionless)
        if len(unit_time) == 0:
            unit_time = default_time_unit

        # create the exposure dataset
        exposure_array = xr.Dataset(arrays).to_array(dim=dim, name="exposure")
        exposure_array = exposure_array.transpose("id", "time", ...)
        ds = xr.Dataset({"exposure": exposure_array})

        # add the time unit as an attribute (roundtrip through ureg to standardize)
        ds.attrs["unit_time"] = f"{_units.ureg.parse_expression(unit_time).units:C}"
        # add exposure units as coordinates. This is used later on by _convert_units
        ds = ds.assign_coords({"unit": units_arr})

        return ds

    @staticmethod
    def _survival_data_to_xarray(
        survival_data: pd.DataFrame,
        default_time_unit: str = ""
    ) -> xr.Dataset:
        # TODO: survival name is currently not kept because the raw data is not
        # transferred from the survival
        unit_time = _units._get_unit_from_dataframe_index(survival_data)
        survival_data.index.name = "time"

        survival_array = survival_data.to_xarray().to_dataarray(dim="id", name="survival")
        survival_array = survival_array.transpose("id", "time", ...)
        arrays = {"survival": survival_array}

        ds = xr.Dataset(arrays)
        if len(unit_time) == 0:
            unit_time = default_time_unit

        ds.attrs["unit_time"] = f"{_units.ureg.parse_expression(unit_time).units:C}"

        return ds

    @classmethod
    def _observations_from_dataframes(
        cls,
        exposure_data: Dict[str, pd.DataFrame],
        survival_data: Optional[pd.DataFrame] = None,
        exposure_dim: str = "substance",
        unit_input: Mapping[str, str] = {"default": ""},
        unit_time: str = "day",
    ):
        # parse observations
        # obs can be simply subset by selection obs.sel(substance="Exposure-Dime")
        _exposure = cls._exposure_data_to_xarray(
            exposure_data, dim=exposure_dim,
            exposure_units=unit_input,
            default_time_unit=unit_time,
        )
        arrays = [_exposure]
        if survival_data is not None:
            _survival = cls._survival_data_to_xarray(
                survival_data,
                default_time_unit=unit_time
            )
            arrays.append(_survival)

        observations = xr.combine_by_coords(arrays, combine_attrs="no_conflicts")

        return observations
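
    # A minimal usage sketch (treatment names, values, and units are hypothetical):
    #
    #   exposure = pd.DataFrame({"C": [0.0, 0.0], "T1": [1.0, 1.0]}, index=[0.0, 2.0])
    #   survival = pd.DataFrame({"C": [10, 10], "T1": [10, 8]}, index=[0.0, 2.0])
    #   obs = GutsBase._observations_from_dataframes(
    #       exposure_data={"substance_a": exposure},
    #       survival_data=survival,
    #       unit_input={"default": "mg/L"},
    #       unit_time="day",
    #   )
    #   # obs.exposure has dims (id, time, substance); obs.survival has dims (id, time)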

    @property
    def _exposure_dimension(self):
        return self._get_exposure_dimension(
            dimensions=self.config.data_structure["exposure"].dimensions,
            batch_dim=self.config.simulation.batch_dimension,
            x_dim=self.config.simulation.x_dimension
        )

    @staticmethod
    def _get_exposure_dimension(dimensions, batch_dim: str = "id", x_dim: str = "time"):
        extra_dims = []
        for k in dimensions:
            if k not in (batch_dim, x_dim):
                extra_dims.append(k)
            else:
                pass

        if len(extra_dims) > 1:
            raise GutsBaseError(
                "Guts Base can currently only handle one exposure dimension beside " +
                "the standard dimensions."
            )
        else:
            return extra_dims[0]
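
    # For example (hypothetical dimensions): with dimensions ("id", "time", "substance"),
    # _get_exposure_dimension returns "substance"; with two extra dimensions a
    # GutsBaseError is raised. Note that, as written, an IndexError occurs when there
    # is no extra dimension at all.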

    def expand_batch_like_coordinate_to_new_dimension(self, coordinate, variables):
        """This method takes an existing coordinate of a dataset that has the same
        shape as the batch dimension. It then re-expresses the coordinate as a
        separate dimension for the given variables, by duplicating the N-dimensional
        array once for each unique name in the specified coordinate, creating an
        (N+1)-dimensional array. This array is filled with zeros along the batch
        dimension wherever the specified coordinate along the ID dimension does not
        coincide with the new (unique) coordinate of the new dimension.

        This process is entirely reversible.
        """
        old_coords = self.observations[coordinate]
        batch_dim = self.config.simulation.batch_dimension

        # old coordinate before turning it into a dimension
        obs = self.observations.drop(coordinate)

        # create unique coordinates of the new dimension, preserving the order of the
        # old coordinate
        _, index = np.unique(old_coords, return_index=True)
        coords_new_dim = tuple(np.array(old_coords)[sorted(index)])

        for v in variables:
            # take data variable and extract dimension order
            data_var = obs[v]
            dim_order = data_var.dims

            # expand the dimensionality, then transpose for new dim to be last
            data_var = data_var.expand_dims(coordinate).transpose(..., batch_dim, coordinate)

            # create a dummy dimension to broadcast the new array
            # dummy_3d = np.ones((1, len(coords_new_dim)))
            dummy_categorical = pd.get_dummies(old_coords).astype(int).values

            # apply automatic broadcasting to increase the size of the new dimension
            # data_var_np1_d = data_var * dummy_3d
            data_var_np1_d = data_var * dummy_categorical
            data_var_np1_d.attrs = data_var.attrs

            # annotate coordinates of the new dimension
            data_var_np1_d = data_var_np1_d.assign_coords({
                coordinate: list(coords_new_dim)
            })

            # transpose back to original dimension order with new dim as last dim
            data_var_np1_d = data_var_np1_d.transpose(*dim_order, coordinate)
            obs[v] = data_var_np1_d

        return obs

    def map_batch_coordinates_to_extra_dim_coordinates(
        self,
        observations: xr.Dataset,
        target_dimension: str,
        coordinates: Optional[List[Hashable]] = None
    ) -> xr.Dataset:
        """Iterates over coordinates and maps those coordinates onto the new dimension
        that have the same number of unique values as the new dimension has
        coordinates.
        """
        if coordinates is None:
            coordinates = list(observations.coords.keys())

        for key, coord in observations.coords.items():
            # skip coords that are not specified in coordinates
            if key not in coordinates:
                continue

            if self.config.simulation.batch_dimension in coord.dims and key not in observations.dims:
                if len(coord.dims) == 1:
                    dim_coords = self._unique_unsorted(coord.values)
                    if len(dim_coords) == len(observations[target_dimension]):
                        observations[key] = (target_dimension, list(dim_coords))
                    else:
                        pass
                else:
                    warnings.warn(
                        f"Coordinate '{key}' has dimensions {coord.dims}. " +
                        "Mapping coordinates with more than 1 dimension to the extra " +
                        f"dimension '{target_dimension}' is not supported yet."
                    )
                    pass

        return observations

    def reduce_dimension_to_batch_like_coordinate(self, dimension, variables):
        """This method takes an existing dimension from an N-D array and reduces it
        to an (N-1)-D array, by writing a new batch-like coordinate from the reducible
        dimension: the new coordinate takes the coordinate of the dimension where the
        data of the N-D array is not zero. After it has been asserted that there is
        only a unique candidate for each coordinate along the batch dimension
        (i.e. only one value is non-zero for a given batch coordinate), the dimension
        is reduced by summing over it.

        The method is contingent on having no overlap in the batch dimension in the
        dataset.
        """
        pass

    def initialize_from_script(self):
        pass

    @staticmethod
    def _update_model_parameters(model_parameters, params: dict):
        params_stash = {}
        for name, new_values in params.items():
            if not hasattr(model_parameters, name):
                print(f"{name} not in model_parameters, skipping.")
                continue

            param = model_parameters[name]
            stash = param.model_dump(include=list(new_values.keys()))
            params_stash.update({name: stash})

            for k, v in new_values.items():
                setattr(param, k, v)

        return params_stash
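
    # Usage sketch (the parameter name is hypothetical): fix a parameter and later
    # restore it from the returned stash, mirroring estimate_background_mortality:
    #
    #   stash = GutsBase._update_model_parameters(
    #       sim.config.model_parameters, {"hb": {"value": 0.0, "free": False}}
    #   )
    #   ...
    #   _ = GutsBase._update_model_parameters(sim.config.model_parameters, stash)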

    ### API methods ###

    def transform(
        self,
        inverse=False,
        idata=True,
        observations=True,
        parameters=True,
    ):
        """EXPERIMENTAL FEATURE

        Transform with care! Transforming a simulation changes the parameter
        **values**, observations, and results, as well as the observations and
        parameters in the idata object (if it exists).

        Usage Example
        -------------

        A typical workflow is:

        1. Set up the simulation `sim = Constructor.from_model_and_dataset(...)`

        >>> from mempy.model import RED_SD
        >>> from guts_base import Constructor
        >>> experiment = Constructor.draft_laboratory_experiment(
        ...     treatments={"C": 0.0, "T1": 1, "T2": 5, "T3": 50, "T4": 100},
        ...     simulate_survival=True,
        ... )
        >>> survival_data = experiment.survival.to_pandas().T
        >>> exposure_data = {"A": experiment.exposure.to_pandas().T}
        >>> sim = Constructor.from_model_and_dataset(
        ...     model=RED_SD(),
        ...     exposure_data=exposure_data,
        ...     survival_data=survival_data,
        ...     output_directory="results/test"
        ... )

        2. Set up the transform `sim.transformer = GutsTransform(x_in_factor=..., time_factor=...)`

        >>> from guts_base.sim.transformer import GutsTransform
        >>> # define a transformation factor
        >>> x_in_factor = float(sim.observations.exposure.max().values)
        >>> # set transformation
        >>> sim.transformer = GutsTransform(
        ...     time_factor=1.0, x_in_factor=x_in_factor,
        ...     ignore_keys=["id", "exposure_path", "w"]
        ... )

        3. Transform the simulation `GutsBase.transform(idata=False)`

        >>> sim.transform(idata=False)

        4. Run parameter estimation `GutsBase.estimate_parameters(...)`

        >>> sim.estimate_parameters()

        5. Inverse transform the simulation `GutsBase.transform(inverse=True)`

        >>> sim.transform(inverse=True)

        Explanations
        ------------

        `GutsBase.transform` DOES NOT change the priors. This means estimating
        parameters of a transformed simulation will yield different results than
        estimating parameters of an untransformed simulation.

        However, this is the whole point. Transformations are meant to make the life
        of a modeller easier. By bringing all data onto a unit scale, the problem can
        be solved more easily, using default priors (e.g. lognormal(scale=1, s=5)).
        Depending on the size of the transformation, the effect can be larger or
        smaller. If, for instance, I want to scale the exposure to a unit interval
        [0, 1], and my largest exposure is, e.g., 500, this shifts the relative
        influence of the prior distributions also by a factor of 500, especially for
        the m and b parameters.

        ⚠️ Therefore a word of warning: before applying transformations to a wide set
        of different problems, double check your priors and make sure that they behave
        as expected.

        This means transforming is a double-edged sword. If the priors for a
        transformed distribution are chosen well, many problems of vastly different
        scales can be solved with one set of priors. If the default priors are not
        chosen well, you will make your life more difficult rather than easier.

        Extending Transform classes
        ---------------------------

        **Note** that the forward transform (`inverse=False`) divides by the factors
        and the inverse transform multiplies by the factors. This is how the functions
        are defined in the GutsTransform.

        Also, transforms might not exist for all parameters. If you want to define
        transforms at runtime, this is relatively easy to do:

        >>> sim.transformer.data_transformer.xxx = lambda self, x: x / self.time_factor
        >>> sim.transformer.data_transformer.xxx_inv = lambda self, x: x * self.time_factor

        This will create new transforms (forward and inverse) for the data variable
        'xxx'. Similarly this can be done for the `parameter_transformer` attribute of
        the sim transformer.

        If you want to provide your own transform functions, feel free to write your
        own class. See `guts_base.sim.transformer` for inspiration. There the
        `GutsTransform` class is defined.

        Defining Priors
        ---------------

        Some considerations for defining good priors for a simulation transformed to
        the unit scale.

        If the exposure is on the interval [0, 1] (by transforming the data by the
        max. exposure), for GUTS-SD models the m parameter will oftentimes be within
        that interval. If, of course, the experiment was conducted in a way that no
        mortality was observed, the m parameter is likely outside of that interval
        (m > 1).

        A sensible prior would be, for instance:

        >>> sim.config.model_parameters.m.prior = 'lognorm(scale=0.01,s=5)'

        This prior will assign large probability mass to values in the interval
        [0, 1], but will also cover parts above 1 with quite some probability.

        For the remaining parameters, a good default is 'lognorm(scale=1.0,s=5)'.

        The s-parameter controls the width of the distribution.
        """
        self.transformer.transform(
            sim=self,
            inverse=inverse,
            idata=idata,
            observations=observations,
            parameters=parameters,
        )

        self.prepare_simulation_input()
        self.dispatch_constructor()

    def point_estimate(
        self,
        estimate: Literal["mean", "map"] = "map",
        to: Literal["xarray", "dict"] = "xarray"
    ):
        """Returns a point estimate of the posterior. If you want more control over
        the posterior, use the attribute sim.inferer.idata.posterior and summarize it
        or select from it using the arviz
        (https://python.arviz.org/en/stable/index.html) and xarray
        (https://docs.xarray.dev/en/stable/index.html) packages.

        Parameters
        ----------

        estimate : Literal["map", "mean"]
            Point estimate to return.
            - map: Maximum a posteriori. The sample that has the highest posterior
              probability. This sample considers the correlation structure of the
              posterior.
            - mean: The average of all marginal parameter distributions.

        to : Literal["xarray", "dict"]
            Specifies the representation to transform the summarized data to. dict
            can be used to insert parameters in the .evaluate() method, while xarray
            is the standard view. Defaults to xarray.

        Example
        -------

        >>> sim.point_estimate(to='dict')
        """
        if estimate == "mean":
            best_estimate = self.inferer.idata.posterior.mean(("chain", "draw"))

        elif estimate == "map":
            loglik = self.inferer.idata.log_likelihood\
                .sum(["id", "time"])\
                .to_array().sum("variable")

            sample_max_loglik = loglik.argmax(dim=("chain", "draw"))
            best_estimate = self.inferer.idata.posterior.sel(sample_max_loglik)  # type: ignore
        else:
            raise GutsBaseError(
                f"Estimate '{estimate}' not implemented. Choose one of ['mean', 'map']"
            )

        if to == "xarray":
            return best_estimate

        elif to == "dict":
            return {k: v.values for k, v in best_estimate.items()}

        else:
            raise GutsBaseError(
                "GutsBase.point_estimate() supports only return types to=['xarray', 'dict']. " +
                f"You used {to=}"
            )

    def evaluate(
        self,
        parameters: Mapping[str, float|NumericArray|Sequence[float]] = {},
        y0: Mapping[str, float|NumericArray|Sequence[float]] = {},
        x_in: Mapping[str, float|NumericArray|Sequence[float]] = {},
    ):
        """Evaluates the model along the coordinates of the observations with given
        parameters, x_in, and y0. The dictionaries passed to the function arguments
        only overwrite the existing default parameters, which makes the usage very
        simple.

        Note that the first run of .evaluate() after calling .dispatch_constructor()
        takes a little longer, because the model and solver are jit-compiled with JAX
        for highly efficient computations.

        Parameters
        ----------

        parameters : Dict[float|Sequence[float]]
            Dictionary of model parameters that should be changed for dispatch.
            Unspecified model parameters will assume the default values,
            specified under config.model_parameters.NAME.value

        y0 : Dict[float|Sequence[float]]
            Dictionary of initial values that should be changed for dispatch.

        x_in : Dict[float|Sequence[float]]
            Dictionary of model input values that should be changed for dispatch.


        Example
        -------

        >>> sim.dispatch_constructor()  # necessary if the sim object has been modified
        >>> # evaluate setting the background mortality to zero
        >>> sim.evaluate(parameters={'hb': 0.0})

        """
        evaluator = self.dispatch(theta=parameters, x_in=x_in, y0=y0)
        evaluator()
        return evaluator.results

    def estimate_background_mortality(
        self,
        control_ids: Optional[str|List[str]] = None,
        exclude_controls_after_fixing_background_mortality: bool = True,
        inference_numpyro: Numpyro = Numpyro(
            kernel="map",
            svi_iterations=1000,
            svi_learning_rate=0.01,
            init_strategy="init_to_median",
            gaussian_base_distribution=True
        ),
    ):
        """Separately estimates the background mortality parameters based on the
        control treatments. Afterwards, the background mortality parameters are fixed
        to the estimated maximum-a-posteriori values. Note that in the case of SVI
        and NUTS, the MAP value is the sample of the posterior distribution that
        comes closest to the true MAP value.

        Parameters
        ----------

        control_ids : Optional[str | List[str]]
            The names of the IDs to use for fitting the control mortality parameters.
            By default, this selects all IDs that have no exposure throughout the
            entire duration of the provided timeseries.
        exclude_controls_after_fixing_background_mortality : bool
            Whether the controls should be excluded from fitting after calibration.
        inference_numpyro : Numpyro
            inference_numpyro config section to parameterize background mortality
            estimation. By default, the MAP kernel is used, which is sufficient
            for a problem where the uncertainty of the estimate is not propagated to
            the following analysis.
        """

        self._exclude_controls_after_fixing_background_mortality =\
            exclude_controls_after_fixing_background_mortality
        # copy the simulation in order not to mix up anything in the original sim
        sim_control = self.copy()

        if isinstance(control_ids, str):
            control_ids = [control_ids]
        elif control_ids is None:
            cum_expo = sim_control.observations.exposure.sum(
                ("time", sim_control._exposure_dimension)
            )
            control_ids = cum_expo.where(cum_expo == 0, drop=True).id.values
        else:
            pass

        # constrain the observations of the copied object to the control ids
        sim_control.observations = sim_control.observations.sel(id=control_ids)

        # Fix all parameters except those of the background-mortality module at zero
        params_fix_at_zero = {
            k: {"value": 0.0, "free": False} for k in sim_control.model_parameter_names
            if k not in sim_control.config.guts_base.background_mortality_parameters
        }

        # update the parameters in the model_parameters dict
        params_backup = sim_control._update_model_parameters(
            sim_control.config.model_parameters, params_fix_at_zero
        )

        # setup inferer
        sim_control.prepare_simulation_input()
        sim_control.dispatch_constructor()
        sim_control.set_inferer("numpyro")

        # run inference
        sim_control.config.inference_numpyro = Numpyro.model_validate(inference_numpyro)
        sim_control.inferer.run()

        # plot results of background mortality
        sim_control._plot.plot_survival_multipanel(
            sim_control, sim_control.inferer.idata.posterior_model_fits,
            filename="survival_multipanel_control_treatments"
        )

        # obtain the maximum a posteriori estimate from the inferer using the
        # guts-base API and start iterating over the background mortality parameters
        map_estimate = sim_control.point_estimate("map", to="dict")
        for bgm_param in sim_control.config.guts_base.background_mortality_parameters:
            bgm_param_value = map_estimate[bgm_param]

            # assign the estimated parameter MAP value to the parameters of the
            # original simulation object. Also set them as fixed parameters
            self.config.model_parameters[bgm_param].value = bgm_param_value
            self.config.model_parameters[bgm_param].free = False

        # reverse the process from before (this is strictly not necessary)
        _ = sim_control._update_model_parameters(
            sim_control.config.model_parameters, params_backup
        )

        # constrain observations to non-control IDs if flag is set
        if exclude_controls_after_fixing_background_mortality:
            control_mask = [id for id in self.observations.id.values if id not in control_ids]
            self.observations = self.observations.sel(id=control_mask)

        # assemble simulation inputs and Evaluator with new fixed parameter values.
        self.prepare_simulation_input()
        self.dispatch_constructor()
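
    # A hypothetical usage sketch: fix background mortality from the controls first,
    # then fit the remaining (toxicity) parameters on the exposed treatments:
    #
    #   sim.estimate_background_mortality()
    #   sim.estimate_parameters()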

    @classmethod
    def draft_laboratory_experiment(
        cls,
        treatments: Dict[str, float|Dict[str, float]],
        n_test_organisms_per_treatment: int = 10,
        experiment_end: pd.Timedelta = pd.Timedelta(10, unit="days"),
        exposure_pattern: ExposureDataDict|Dict[str, ExposureDataDict] = ExposureDataDict(start=0.0, end=None, exposure=None),
        exposure_interpolation: Literal["linear", "constant-forward"] = "constant-forward",
        exposure_dimension: str = "substance",
        observation_times: Optional[List[float]] = None,
        dt: pd.Timedelta = pd.Timedelta("1 day"),
        write_to_file: Optional[str] = None
    ):
        """Simulate a laboratory experiment according to a treatment dictionary"""

        experiment = draft_laboratory_experiment(
            treatments=treatments,
            experiment_end=experiment_end,
            exposure_pattern=exposure_pattern,
            exposure_dimension=exposure_dimension,
            dt=dt
        )

        survival = np.full(
            [v for k, v in experiment.sizes.items() if k in ("time", "id")],
            fill_value=np.nan
        )
        # set with the number of test organisms at time zero
        survival[:, 0] = n_test_organisms_per_treatment
        experiment["survival"] = xr.DataArray(survival, coords=[experiment.id, experiment.time])

        if observation_times is None:
            observation_times_safe = experiment.time
        else:
            observation_times_safe = np.unique(np.concatenate([experiment.time, observation_times]))

        experiment = experiment.reindex(time=observation_times_safe)

        # TODO: This does not yet make the exposure profiles openguts-ready. I.e.
        # if concentration changes occur, this will not be completely explicit by
        # making jumps
        # this requires a method that adds a time point before any change if constant-forward
        if exposure_interpolation == "linear":
            experiment["exposure"] = experiment["exposure"].interpolate_na(dim="time", method="linear")
        else:
            experiment["exposure"] = experiment["exposure"].ffill(dim="time")

        # TODO: What to do with interpolation
        return experiment

    @classmethod
    def to_openguts(cls, observations: xr.Dataset, path: str, time_unit: str):
        experiment = observations.rename({
            "time": f"time [{time_unit}]"
        })

        extra_dim = cls._get_exposure_dimension(observations.dims.keys())

        os.makedirs(os.path.dirname(path), exist_ok=True)

        with pd.ExcelWriter(path) as writer:
            for coord in observations[extra_dim].values:
                experiment.exposure.sel({extra_dim: coord}).to_pandas().T.to_excel(writer, sheet_name=coord)
            experiment.survival.to_pandas().T.to_excel(writer, sheet_name="survival")
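
    # A hypothetical usage sketch: write the observations to an openGUTS-style excel
    # file with one sheet per exposure coordinate plus a survival sheet:
    #
    #   GutsBase.to_openguts(sim.observations, "results/openguts.xlsx", time_unit="day")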

    def load_exposure_scenario(
        self,
        data: Dict[str, pd.DataFrame],
        sheet_name_prefix: str = "",
        rect_interpolate=False
    ):
        self._obs_backup = self.observations.copy(deep=True)

        # read exposure array from file
        exposure_dim = [
            d for d in self.config.data_structure["exposure"].dimensions
            if d not in (self.config.simulation.x_dimension, self.config.simulation.batch_dimension)
        ]
        exposure = self._exposure_data_to_xarray(
            exposure_data=data,
            dim=exposure_dim[0]
        )

        # combine with observations
        new_obs = xr.combine_by_coords([
            exposure,
            self.observations.survival
        ]).sel(id=exposure.id)

        self.observations = new_obs.sel(time=[t for t in new_obs.time if t <= exposure.time.max()])  # type: ignore
        self.config.simulation.x_in = ["exposure=exposure"]
        self.model_parameters["x_in"] = self.parse_input("x_in", exposure).ffill("time")  # type: ignore
        self.model_parameters["y0"] = self.parse_input("y0", drop_dims=["time"])

        self.dispatch_constructor()
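
    # A hypothetical usage sketch: swap in a new exposure profile, evaluate, and
    # restore the original observations afterwards:
    #
    #   profile = pd.read_excel("scenario.xlsx", sheet_name="substance_a", index_col=0)
    #   sim.load_exposure_scenario(data={"substance_a": profile})
    #   results = sim.evaluate()
    #   sim.reset_observations()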

    def export(self, directory: Optional[str] = None, mode: Literal["export", "copy"] = "export", skip_data_processing=True):
        self.config.simulation.skip_data_processing = skip_data_processing
        super().export(directory=directory, mode=mode)

    def export_to_scenario(self, scenario, force=False):
        """Exports a case study as a new scenario for running inference"""
        self.config.case_study.scenario = scenario
        self.config.case_study.data = None
        self.config.case_study.output = None
        self.config.case_study.scenario_path_override = None
        self.config.simulation.skip_data_processing = True
        self.save_observations(filename=f"observations_{scenario}.nc", force=force)
        self.config.save(force=force)

    @staticmethod
    def _condition_posterior(
        posterior: xr.Dataset,
        parameter: str,
        value: float,
        exception: Literal["raise", "warn"] = "raise"
    ):
        """TODO: Provide this method also to SimulationBase"""
        if parameter not in posterior:
            keys = list(posterior.keys())
            msg = (
                f"{parameter=} was not found in the posterior {keys=}. " +
                f"Unable to condition the posterior to {value=}. Have you " +
                "requested the correct parameter for conditioning?"
            )

            if exception == "raise":
                raise GutsBaseError(msg)
            elif exception == "warn":
                warnings.warn(msg)
                # the parameter cannot be conditioned; return the posterior unchanged
                return posterior
            else:
                raise GutsBaseError(
                    "Use one of exception='raise' or exception='warn'. " +
                    f"Currently using {exception=}"
                )

        # broadcast value so that methods like drawing samples and hdi still work
        broadcasted_value = np.full_like(posterior[parameter], value)

        return posterior.assign({
            parameter: (posterior[parameter].dims, broadcasted_value)
        })
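
    # A hypothetical usage sketch: condition the posterior on zero background
    # mortality (e.g. before computing ECx values):
    #
    #   posterior = GutsBase._condition_posterior(
    #       sim.inferer.idata.posterior, parameter="hb", value=0.0
    #   )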


class GutsSimulationConstantExposure(GutsBase):
    t_max = 10

    def initialize_from_script(self):
        self.config.data_structure.B = DataVariable(dimensions=["time"], observed=False)
        self.config.data_structure.D = DataVariable(dimensions=["time"], observed=False)
        self.config.data_structure.H = DataVariable(dimensions=["time"], observed=False)
        self.config.data_structure.survival = DataVariable(dimensions=["time"], observed=False)

        # y0
        self.config.simulation.y0 = ["D=Array([0])", "H=Array([0])", "survival=Array([1])"]
        self.model_parameters["y0"] = self.parse_input(input="y0", drop_dims=["time"])

        # parameters
        self.config.model_parameters.C_0 = Param(value=10.0, free=False)
        self.config.model_parameters.k_d = Param(value=0.9, free=True)
        self.config.model_parameters.h_b = Param(value=0.00005, free=True)
        self.config.model_parameters.b = Param(value=5.0, free=True)
        self.config.model_parameters.z = Param(value=0.2, free=True)

        self.model_parameters["parameters"] = self.config.model_parameters.value_dict
        self.config.simulation.model = "guts_constant_exposure"

        self.coordinates["time"] = np.linspace(0, self.t_max)

    def use_jax_solver(self):
        # =======================
        # Define model and solver
        # =======================

        self.coordinates["time"] = np.array([0, self.t_max])
        self.config.simulation.model = "guts_constant_exposure"

        self.solver = JaxSolver

        self.dispatch_constructor(diffrax_solver=Dopri5)

    def use_symbolic_solver(self):
        # =======================
        # Define model and solver
        # =======================

        self.coordinates["time"] = np.array([0, self.t_max])
        self.config.simulation.model = "guts_sympy"

        self.solver = mod.PiecewiseSymbolicSolver

        self.dispatch_constructor(diffrax_solver=Dopri5)


class GutsSimulationVariableExposure(GutsSimulationConstantExposure):
    t_max = 10

    def initialize_from_script(self):
        super().initialize_from_script()
        del self.coordinates["time"]
        exposure = create_artificial_data(
            t_max=self.t_max, dt=1,
            exposure_paths=["topical"]
        ).squeeze()
        self.observations = exposure

        self.config.data_structure.exposure = DataVariable(dimensions=["time"], observed=True)

        self.config.simulation.x_in = ["exposure=exposure"]
        x_in = self.parse_input(input="x_in", reference_data=exposure, drop_dims=[])
        x_in = rect_interpolation(x_in=x_in, x_dim="time")
        self.model_parameters["x_in"] = x_in

        # parameters
        self.config.model_parameters.remove("C_0")

        self.model_parameters["parameters"] = self.config.model_parameters.value_dict
        self.config.simulation.solver_post_processing = "red_sd_post_processing"
        self.config.simulation.model = "guts_variable_exposure"

    def use_jax_solver(self):
        # =======================
        # Define model and solver
        # =======================

        self.model = self._mod.guts_variable_exposure
        self.solver = JaxSolver

        self.dispatch_constructor(diffrax_solver=Dopri5)

    def use_symbolic_solver(self, do_compile=True):
        # =======================
        # Define model and solver
        # =======================

        self.model = self._mod.guts_sympy
        self.solver = self._mod.PiecewiseSymbolicSolver

        self.dispatch_constructor(do_compile=do_compile, output_path=self.output_path)