guts-base 0.8.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of guts-base might be problematic. Click here for more details.
- guts_base/__init__.py +14 -0
- guts_base/data/__init__.py +34 -0
- guts_base/data/expydb.py +247 -0
- guts_base/data/generator.py +96 -0
- guts_base/data/openguts.py +294 -0
- guts_base/data/preprocessing.py +55 -0
- guts_base/data/survival.py +137 -0
- guts_base/data/time_of_death.py +571 -0
- guts_base/data/utils.py +8 -0
- guts_base/mod.py +251 -0
- guts_base/plot.py +162 -0
- guts_base/prob.py +412 -0
- guts_base/sim/__init__.py +14 -0
- guts_base/sim/base.py +464 -0
- guts_base/sim/ecx.py +357 -0
- guts_base/sim/mempy.py +252 -0
- guts_base/sim/report.py +72 -0
- guts_base/sim.py +0 -0
- guts_base-0.8.2.dist-info/METADATA +836 -0
- guts_base-0.8.2.dist-info/RECORD +24 -0
- guts_base-0.8.2.dist-info/WHEEL +5 -0
- guts_base-0.8.2.dist-info/entry_points.txt +3 -0
- guts_base-0.8.2.dist-info/licenses/LICENSE +674 -0
- guts_base-0.8.2.dist-info/top_level.txt +1 -0
guts_base/sim/base.py
ADDED
|
@@ -0,0 +1,464 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import glob
|
|
3
|
+
import warnings
|
|
4
|
+
import numpy as np
|
|
5
|
+
import xarray as xr
|
|
6
|
+
from diffrax import Dopri5
|
|
7
|
+
from typing import Literal, Optional, List, Dict
|
|
8
|
+
import tempfile
|
|
9
|
+
import pandas as pd
|
|
10
|
+
|
|
11
|
+
from pymob import SimulationBase
|
|
12
|
+
from pymob.sim.config import DataVariable, Param, string_to_list
|
|
13
|
+
|
|
14
|
+
from pymob.solvers import JaxSolver
|
|
15
|
+
from pymob.solvers.base import rect_interpolation
|
|
16
|
+
from expyDB.intervention_model import (
|
|
17
|
+
Treatment, Timeseries, select, from_expydb
|
|
18
|
+
)
|
|
19
|
+
|
|
20
|
+
from guts_base import mod
|
|
21
|
+
from guts_base.data import (
|
|
22
|
+
to_dataset, reduce_multiindex_to_flat_index, create_artificial_data,
|
|
23
|
+
create_database_and_import_data_main, design_exposure_scenario
|
|
24
|
+
)
|
|
25
|
+
from guts_base.sim.report import GutsReport
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class GutsBase(SimulationBase):
    """
    Initializes GUTS models from a variety of data sources

    Initialization follows a couple of steps
    1. check if necessary entries are made in the configuration, otherwise add defaults
    2. read data or take from input
    3. process data (add dimensions, or add indices)
    """
    solver = JaxSolver
    Report = GutsReport
    unit_time: Literal["day", "hour", "minute", "second"] = "day"
    # [start, stop, num] of the time grid used by `recompute_posterior`; NaN
    # start/stop are replaced by the min/max of the observed time coordinate.
    # NOTE: this default list is a class attribute shared by all instances,
    # so instance code must rebind it rather than mutate it in place.
    results_interpolation: Optional[List[float|int]] = [np.nan, np.nan, 100]
    ecx_mode: Literal["mean", "draws"] = "mean"

    def initialize(self, input: Optional[Dict] = None):
        """Prepare observations and fill ``self.model_parameters``.

        Parameters
        ----------
        input : dict, optional
            If it contains the key ``"observations"``, that dataset is used
            directly. Otherwise (including the default ``None``) the data are
            read with :meth:`read_data` and processed with
            :meth:`process_data`.
        """
        if hasattr(self.config.simulation, "unit_time"):
            self.unit_time = self.config.simulation.unit_time  # type: ignore

        if hasattr(self.config.simulation, "results_interpolation"):
            # config stores [start, stop, num] as strings -> coerce types
            self.results_interpolation = string_to_list(self.config.simulation.results_interpolation)
            self.results_interpolation[0] = float(self.results_interpolation[0])
            self.results_interpolation[1] = float(self.results_interpolation[1])
            self.results_interpolation[2] = int(self.results_interpolation[2])

        # guard against the default `None`, for which the membership test
        # would raise a TypeError
        if input is not None and "observations" in input:
            self.observations = input["observations"]
        else:
            self.observations = self.read_data()
            self.process_data()

        # define tolerance (10x the solver's absolute tolerance)
        self.observations = self.observations.assign_coords(eps=self.config.jaxsolver.atol * 10)

        self._reindex_time_dim()

        if "survival" in self.observations:
            if "subject_count" not in self.observations.coords:
                # number of subjects per id taken from survival at the first
                # time point (assumes survival holds counts — TODO confirm)
                self.observations = self.observations.assign_coords(
                    subject_count=("id", self.observations["survival"].isel(time=0).values, )
                )
            self.observations = self._data.prepare_survival_data_for_conditional_binomial(
                observations=self.observations
            )

        if "exposure" in self.observations:
            self.config.data_structure.exposure.observed = False

        # prepare y0 and x_in
        x_in = self.parse_input(input="x_in", reference_data=self.observations, drop_dims=[])
        y0 = self.parse_input(input="y0", reference_data=self.observations, drop_dims=["time"])

        # add model components
        if self.config.simulation.forward_interpolate_exposure_data:  # type: ignore
            self.model_parameters["x_in"] = rect_interpolation(x_in)
        else:
            self.model_parameters["x_in"] = x_in

        self.model_parameters["y0"] = y0
        self.model_parameters["parameters"] = self.config.model_parameters.value_dict

    def construct_database_statement_from_config(self):
        """Return a select statement for reading exposure data from an expyDB database.

        Selects ``Timeseries`` joined with ``Treatment`` where the timeseries
        variable equals ``config.simulation.substance`` and the timeseries
        name equals ``config.simulation.exposure_path``.
        """
        substance = self.config.simulation.substance  # type:ignore
        exposure_path = self.config.simulation.exposure_path  # type:ignore
        return (
            select(Timeseries, Treatment)
            .join(Timeseries)
        ).where(
            Timeseries.variable.in_([substance]),  # type: ignore
            # compare with the plain string (a set literal cannot be bound as
            # a scalar SQL parameter)
            Timeseries.name == exposure_path,
        )

    def read_data(self):
        """Read the observations referenced by ``config.case_study.observations``.

        If the entry names a directory below ``config.case_study.data_path``,
        all ``*.xlsx`` files in it are first imported into a temporary SQLite
        database. Supported file extensions are ``.db`` (read via expyDB) and
        ``.nc`` (netCDF, read with xarray).

        Raises
        ------
        NotImplementedError
            If the dataset has any other extension.
        """
        # TODO: Update to new INTERVENTION MODEL
        dataset = str(self.config.case_study.observations)

        # read from a directory
        if os.path.isdir(os.path.join(self.config.case_study.data_path, dataset)):
            # This looks for xlsx files in the folder and imports them as a database and
            # then proceeds as normal
            files = glob.glob(os.path.join(
                self.config.case_study.data_path,
                dataset, "*.xlsx"
            ))

            # NOTE: the temporary directory lives as long as this local
            # reference, i.e. it is cleaned up once this method returns
            tempdir = tempfile.TemporaryDirectory()
            dataset = self.read_data_from_xlsx(data=files, tempdir=tempdir)

        ext = dataset.split(".")[-1]

        if not os.path.exists(dataset):
            dataset = os.path.join(self.data_path, dataset)

        if ext == "db":
            statement = self.construct_database_statement_from_config()
            observations = self.read_data_from_expydb(dataset, statement)

            # TODO: Integrate interventions in observations dataset

        elif ext == "nc":
            observations = xr.load_dataset(dataset)

        else:
            raise NotImplementedError(
                f"Dataset extension '.{ext}' is not recognized. "+
                "Please use one of '.db' (mysql), '.nc' (netcdf)."
            )

        return observations

    def read_data_from_xlsx(self, data, tempdir):
        """Import xlsx files into ``import.db`` inside ``tempdir`` and return its path.

        Parameters
        ----------
        data : list of str
            Paths of the xlsx files to import.
        tempdir : tempfile.TemporaryDirectory
            Directory receiving the database and the preprocessed files.
        """
        database = os.path.join(tempdir.name, "import.db")

        # optional preprocessing configured in the simulation section
        if hasattr(self.config.simulation, "data_preprocessing"):
            preprocessing = self.config.simulation.data_preprocessing
        else:
            preprocessing = None

        create_database_and_import_data_main(
            datasets_path=data,
            database_path=database,
            preprocessing=preprocessing,
            preprocessing_out=os.path.join(tempdir.name, "processed_{filename}")
        )

        return database

    def read_data_from_expydb(self, database, statement) -> xr.Dataset:
        """Load observations from a SQLite expyDB database into a flat-indexed dataset.

        Parameters
        ----------
        database : str
            Path of the SQLite database file.
        statement
            Select statement, e.g. from
            :meth:`construct_database_statement_from_config`.
        """
        observations_idata, interventions_idata = from_expydb(
            database=f"sqlite:///{database}",
            statement=statement
        )

        dataset = to_dataset(
            observations_idata,
            interventions_idata,
            unit_time=self.unit_time
        )
        dataset = reduce_multiindex_to_flat_index(dataset)

        # TODO: return multidimensional datasets for data coming from the
        # database. The method can be implemented in any class.

        filtered_dataset = self.filter_dataset(dataset)

        return filtered_dataset

    def process_data(self):
        """
        Currently these methods change datasets, indices, etc. in-place.
        This is convenient, but makes them more difficult to rearrange with
        other methods.
        TODO: Make these methods static if possible
        """
        self._create_indices()
        self._indices_to_dimensions()

    def _create_indices(self):
        """Use if indices should be added to sim.indices and sim.observations"""
        pass

    def _indices_to_dimensions(self):
        """Hook for subclasses; no-op in the base class."""
        pass

    def filter_dataset(self, dataset: xr.Dataset) -> xr.Dataset:
        """Hook for subclasses to filter data read from a database; identity here."""
        return dataset

    def _reindex_time_dim(self):
        """Densify the time axis of the observations for IT models.

        If the model name contains ``"_it"``, the observations are reindexed
        to the union of the existing time coordinates and
        ``config.simulation.n_reindexed_x`` (default 100) evenly spaced points
        between 0 and the maximum time. Otherwise only a message is logged.
        """
        if self.config.simulation.model is not None:
            if "_it" in self.config.simulation.model.lower():
                self.logger.info(msg=(
                    "Redindexing time vector to increase resolution, because model has "+
                    "'_it' (individual tolerance) in it's name"
                ))
                if not hasattr(self.config.simulation, "n_reindexed_x"):
                    self.config.simulation.n_reindexed_x = 100

                new_time_index = np.unique(np.concatenate([
                    self.coordinates["time"],
                    np.linspace(
                        0, np.max(self.coordinates["time"]),
                        int(self.config.simulation.n_reindexed_x)  # type: ignore
                    )
                ]))
                self.observations = self.observations.reindex(time = new_time_index)
                return

        self.logger.info(msg=(
            "No redindexing of time vector to, because model name did not contain "+
            "'_it' (individual tolerance), or model was not given by name. If an IT model "
            "is calculated without a dense time resolution, the estimates can be biased!"
        ))

    def recompute_posterior(self):
        """Recompute posterior predictions on an interpolated time grid.

        The grid is ``np.linspace(start, stop, num)`` from
        ``self.results_interpolation``. NaN ``start``/``stop`` are replaced by
        the minimum/maximum of the observed time coordinate. If
        ``results_interpolation`` is ``None``, the observations are not
        reindexed. Posterior predictions are then recalculated and stored in
        ``numpyro_posterior_interp.nc`` inside the output path.
        """
        if self.results_interpolation is not None:
            start, stop, num = self.results_interpolation[:3]
            if np.isnan(start):
                start = float(self.observations["time"].min())
            if np.isnan(stop):
                stop = float(self.observations["time"].max())

            # rebind instead of mutating in place: the default list is a
            # class attribute shared between all instances
            self.results_interpolation = [start, stop, num]

            # generate high resolution posterior predictions
            time_interpolate = np.linspace(start=start, stop=stop, num=num)
            self.observations = self.observations.reindex(
                time=time_interpolate
            )

        self.dispatch_constructor()
        _ = self._prob.posterior_predictions(self, self.inferer.idata)  # type: ignore
        self.inferer.store_results(output=f"{self.output_path}/numpyro_posterior_interp.nc")  # type: ignore
        self.logger.info("Recomputed posterior and storing in `numpyro_posterior_interp.nc`")

    def prior_predictive_checks(self):
        """Run the base prior predictive checks and plot survival predictions."""
        super().prior_predictive_checks()

        self._plot.plot_prior_predictions(self, data_vars=["survival"])

    def posterior_predictive_checks(self):
        """Run base checks, recompute the posterior on a dense grid and plot survival."""
        super().posterior_predictive_checks()

        self.recompute_posterior()
        # TODO: Include posterior_predictive group once the survival predictions are correctly working
        self._plot.plot_posterior_predictions(self, data_vars=["survival"], groups=["posterior_model_fits"])

    def plot(self, results):
        """Plot survival results."""
        self._plot.plot_survival(self, results)

    def copy(self):
        """Return a new simulation of the same type built from ``self.config``.

        Observations, model parameters, inferer idata, model and solver
        post-processing are shared by reference (not deep-copied).
        """
        with warnings.catch_warnings(action="ignore"):
            sim_copy = type(self)(self.config)
            sim_copy.observations = self.observations
            sim_copy.model_parameters = self.model_parameters
            if self.inferer is not None:
                # NOTE(review): the inferer is constructed with the original
                # simulation (`self`), not `sim_copy` — confirm intended
                sim_copy.inferer = type(self.inferer)(self)
                sim_copy.inferer.idata = self.inferer.idata
            sim_copy.model = self.model
            sim_copy.solver_post_processing = self.solver_post_processing
            sim_copy.load_modules()

        return sim_copy

    @property
    def predefined_scenarios(self):
        """Dict of predefined exposure scenarios.

        ``oral_acute_1d``: an ``"oral"`` exposure with concentration 1.0 from
        t=0 to t=1.0, on a grid with ``t_max=10.01`` and ``dt=1/24``.
        """
        # this produces a exposure x_in dataset with only the dimensions ID and TIME
        oral_acute_1d = design_exposure_scenario(
            exposures={
                "oral":dict(start=0, end=1.0, concentration=1.0),
            },
            t_max=10.01,
            dt=1/24,
            exposure_dimension="exposure_path"
        )

        return dict(
            oral_acute_1d=oral_acute_1d
        )

    def expand_batch_like_coordinate_to_new_dimension(self, coordinate, variables):
        """Re-express a batch-like coordinate as a new dimension of ``variables``.

        ``coordinate`` is an existing 1-D coordinate of ``self.observations``
        along the batch dimension (``config.simulation.batch_dimension``).
        Each variable in ``variables`` gets a new trailing dimension named
        ``coordinate``. Its labels are the unique values of the old coordinate
        in order of first appearance. A batch entry keeps its data at the
        label equal to its old coordinate value and is zero at all other
        labels.

        ``self.observations`` is not modified. A new dataset is returned in
        which ``coordinate`` has been dropped as a coordinate and re-added as
        a dimension of the listed variables. The intended inverse is
        ``reduce_dimension_to_batch_like_coordinate`` (currently a stub).
        """
        old_coords = self.observations[coordinate]
        batch_dim = self.config.simulation.batch_dimension

        # remove the old coordinate before turning it into a dimension
        obs = self.observations.drop_vars(coordinate)

        # unique labels of the new dimension, preserving first-appearance order
        old_values = np.asarray(old_coords)
        _, index = np.unique(old_values, return_index=True)
        labels = old_values[sorted(index)]
        coords_new_dim = tuple(labels)

        # one-hot indicator of shape (n_batch, n_new) whose columns follow
        # `labels`. (pd.get_dummies sorts its columns, which mislabels the new
        # dimension whenever first-appearance order differs from sorted order.)
        indicator = (old_values[:, np.newaxis] == labels[np.newaxis, :]).astype(int)

        for v in variables:
            # take data variable and extract dimension order
            data_var = obs[v]
            dim_order = data_var.dims

            # add the new (size-1) dimension and move batch + new dim last so
            # the (n_batch, n_new) indicator broadcasts over the trailing axes
            data_var = data_var.expand_dims(coordinate).transpose(..., batch_dim, coordinate)
            data_var_np1_d = data_var * indicator

            # annotate coordinates of the new dimension
            data_var_np1_d = data_var_np1_d.assign_coords({
                coordinate: list(coords_new_dim)
            })

            # transpose back to original dimension order with new dim as last dim
            data_var_np1_d = data_var_np1_d.transpose(*dim_order, coordinate)
            obs[v] = data_var_np1_d

        return obs

    def reduce_dimension_to_batch_like_coordinate(self, dimension, variables):
        """This method takes an existing dimension from a N-D array and reduces it to an
        (N-1)-D array, by writing a new coordinate from the reducible dimension in the way
        that the new batch-like coordinate takes the coordinate of the dimension, where
        the data of the N-D array was not zero. After it has been asserted that there is
        only a unique candidate for each coordinate along the batch dimension
        (i.e. only one value is non-zero for a given batch-coordinate), the dimension will
        be reduced by summing over the given dimension.

        The method is contingent on having no overlap in batch dimension in the dataset.

        NOTE: not implemented yet; currently returns None.
        """
        pass

    def initialize_from_script(self):
        """Hook for subclasses that configure the simulation in code; no-op here."""
        pass
|
|
369
|
+
|
|
370
|
+
class GutsSimulationConstantExposure(GutsBase):
    """GUTS simulation configured in code with a fixed exposure parameter ``C_0``."""
    t_max = 10

    def initialize_from_script(self):
        """Set up state variables, initial values, parameters and time axis in code."""
        # unobserved state variables, all indexed by time only
        for state_name in ("B", "D", "H", "survival"):
            setattr(
                self.config.data_structure,
                state_name,
                DataVariable(dimensions=["time"], observed=False),
            )

        # initial values of the state variables
        self.config.simulation.y0 = ["D=Array([0])", "H=Array([0])", "survival=Array([1])"]
        self.model_parameters["y0"] = self.parse_input(input="y0", drop_dims=["time"])

        # model parameters: (name, value, free)
        parameter_table = (
            ("C_0", 10.0, False),
            ("k_d", 0.9, True),
            ("h_b", 0.00005, True),
            ("b", 5.0, True),
            ("z", 0.2, True),
        )
        for par_name, par_value, is_free in parameter_table:
            setattr(
                self.config.model_parameters,
                par_name,
                Param(value=par_value, free=is_free),
            )

        self.model_parameters["parameters"] = self.config.model_parameters.value_dict
        self.config.simulation.model = "guts_jax"

        # default numpy.linspace resolution between 0 and t_max
        self.coordinates["time"] = np.linspace(0, self.t_max)

    def use_jax_solver(self):
        """Use the ``guts_jax`` model with JaxSolver (Dopri5) on the time points [0, t_max]."""
        self.coordinates["time"] = np.array([0, self.t_max])
        self.config.simulation.model = "guts_jax"

        self.solver = JaxSolver
        self.dispatch_constructor(diffrax_solver=Dopri5)

    def use_symbolic_solver(self):
        """Use the ``guts_sympy`` model with the piecewise symbolic solver on [0, t_max]."""
        self.coordinates["time"] = np.array([0, self.t_max])
        self.config.simulation.model = "guts_sympy"

        self.solver = mod.PiecewiseSymbolicSolver
        self.dispatch_constructor(diffrax_solver=Dopri5)
|
|
417
|
+
|
|
418
|
+
|
|
419
|
+
class GutsSimulationVariableExposure(GutsSimulationConstantExposure):
    """GUTS simulation whose exposure is a time-varying input instead of ``C_0``."""
    t_max = 10

    def initialize_from_script(self):
        """Extend the constant-exposure setup with an artificial ``topical`` exposure series."""
        super().initialize_from_script()

        # the time axis is taken from the exposure data instead
        del self.coordinates["time"]

        raw_exposure = create_artificial_data(
            t_max=self.t_max,
            dt=1,
            exposure_paths=["topical"],
        )
        exposure = raw_exposure.squeeze()
        self.observations = exposure

        self.config.data_structure.exposure = DataVariable(dimensions=["time"], observed=True)

        # exposure enters the model as a forward (rectangular) interpolated input
        self.config.simulation.x_in = ["exposure=exposure"]
        parsed_x_in = self.parse_input(input="x_in", reference_data=exposure, drop_dims=[])
        self.model_parameters["x_in"] = rect_interpolation(x_in=parsed_x_in, x_dim="time")

        # the constant exposure parameter is no longer needed
        self.config.model_parameters.remove("C_0")

        self.model_parameters["parameters"] = self.config.model_parameters.value_dict
        self.config.simulation.solver_post_processing = "post_exposure"
        self.config.simulation.model = "guts_variable_exposure"

    def use_jax_solver(self):
        """Use the variable-exposure model with JaxSolver (Dopri5)."""
        self.model = self._mod.guts_variable_exposure
        self.solver = JaxSolver

        self.dispatch_constructor(diffrax_solver=Dopri5)

    def use_symbolic_solver(self, do_compile=True):
        """Use the sympy model with the piecewise symbolic solver.

        ``do_compile`` is forwarded to ``dispatch_constructor`` together with
        the simulation's output path.
        """
        self.model = self._mod.guts_sympy
        self.solver = self._mod.PiecewiseSymbolicSolver

        self.dispatch_constructor(do_compile=do_compile, output_path=self.output_path)
|
|
464
|
+
|