guts-base 0.8.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of guts-base might be problematic. Click here for more details.

guts_base/sim/base.py ADDED
@@ -0,0 +1,464 @@
1
+ import os
2
+ import glob
3
+ import warnings
4
+ import numpy as np
5
+ import xarray as xr
6
+ from diffrax import Dopri5
7
+ from typing import Literal, Optional, List, Dict
8
+ import tempfile
9
+ import pandas as pd
10
+
11
+ from pymob import SimulationBase
12
+ from pymob.sim.config import DataVariable, Param, string_to_list
13
+
14
+ from pymob.solvers import JaxSolver
15
+ from pymob.solvers.base import rect_interpolation
16
+ from expyDB.intervention_model import (
17
+ Treatment, Timeseries, select, from_expydb
18
+ )
19
+
20
+ from guts_base import mod
21
+ from guts_base.data import (
22
+ to_dataset, reduce_multiindex_to_flat_index, create_artificial_data,
23
+ create_database_and_import_data_main, design_exposure_scenario
24
+ )
25
+ from guts_base.sim.report import GutsReport
26
+
27
+
28
class GutsBase(SimulationBase):
    """
    Initializes GUTS models from a variety of data sources

    Initialization follows a couple of steps
    1. check if necessary entries are made in the configuration, otherwise add defaults
    2. read data or take from input
    3. process data (add dimensions, or add indices)
    """
    # default solver class for this simulation
    solver = JaxSolver
    # report class associated with this simulation
    Report = GutsReport
    # time unit passed to `to_dataset` when reading from a database; replaced
    # by `config.simulation.unit_time` in `initialize` when that entry exists
    unit_time: Literal["day", "hour", "minute", "second"] = "day"
    # [start, stop, num] of the time grid built in `recompute_posterior`;
    # NaN start/stop are replaced there by the min/max observed time.
    # NOTE(review): this is a mutable class-level list, so in-place writes to
    # it are visible to every instance that has not set its own value.
    results_interpolation: Optional[List[float|int]] = [np.nan, np.nan, 100]
    # NOTE(review): not read anywhere in this part of the file — presumably
    # consumed by reporting code; confirm.
    ecx_mode: Literal["mean", "draws"] = "mean"
42
+
43
def initialize(self, input: Optional[Dict] = None):
    """Prepare observations and model inputs for the simulation.

    Steps, in order:

    1. copy ``unit_time`` and ``results_interpolation`` from
       ``config.simulation`` when those entries exist,
    2. take the observations from ``input["observations"]`` or, if absent,
       read them with ``read_data`` and run ``process_data``,
    3. attach an ``eps`` coordinate (ten times the solver ``atol``) and
       re-index the time dimension,
    4. prepare survival data for the conditional binomial and mark exposure
       as unobserved,
    5. parse ``x_in`` and ``y0`` and store them, together with the parameter
       values, in ``self.model_parameters``.

    :param input: optional mapping; an ``"observations"`` entry is used
        directly as observations. ``None`` is treated like an empty mapping.
    """
    # The signature advertises None as default, but the membership test below
    # would raise a TypeError on None. Normalise it to an empty mapping.
    if input is None:
        input = {}

    if hasattr(self.config.simulation, "unit_time"):
        self.unit_time = self.config.simulation.unit_time  # type: ignore

    if hasattr(self.config.simulation, "results_interpolation"):
        self.results_interpolation = string_to_list(self.config.simulation.results_interpolation)
        self.results_interpolation[0] = float(self.results_interpolation[0])
        self.results_interpolation[1] = float(self.results_interpolation[1])
        self.results_interpolation[2] = int(self.results_interpolation[2])

    if "observations" in input:
        self.observations = input["observations"]
    else:
        self.observations = self.read_data()
        # NOTE(review): the nesting of this call under `else` was
        # reconstructed from a rendering that lost indentation — confirm.
        self.process_data()

    # define tolerance based on the solver tolerance
    self.observations = self.observations.assign_coords(eps=self.config.jaxsolver.atol * 10)

    self._reindex_time_dim()

    if "survival" in self.observations:
        if "subject_count" not in self.observations.coords:
            # fall back to the survival values at the first time point
            self.observations = self.observations.assign_coords(
                subject_count=("id", self.observations["survival"].isel(time=0).values, )
            )
        self.observations = self._data.prepare_survival_data_for_conditional_binomial(
            observations=self.observations
        )

    if "exposure" in self.observations:
        self.config.data_structure.exposure.observed = False

    # prepare y0 and x_in
    x_in = self.parse_input(input="x_in", reference_data=self.observations, drop_dims=[])
    y0 = self.parse_input(input="y0", reference_data=self.observations, drop_dims=["time"])

    # add model components
    if self.config.simulation.forward_interpolate_exposure_data:  # type: ignore
        self.model_parameters["x_in"] = rect_interpolation(x_in)
    else:
        self.model_parameters["x_in"] = x_in

    self.model_parameters["y0"] = y0
    self.model_parameters["parameters"] = self.config.model_parameters.value_dict
89
+
90
def construct_database_statement_from_config(self):
    """Return a select statement to be used on a database.

    The statement selects ``Timeseries`` joined with ``Treatment`` and keeps
    rows whose variable is ``config.simulation.substance`` and whose name is
    ``config.simulation.exposure_path``.
    """
    substance = self.config.simulation.substance  # type:ignore
    exposure_path = self.config.simulation.exposure_path  # type:ignore
    return (
        select(Timeseries, Treatment)
        .join(Timeseries)
    ).where(
        Timeseries.variable.in_([substance]),  # type: ignore
        # Compare against the configured value itself. The original wrapped
        # it in a set literal (`{exposure_path}`), which is not a valid
        # right-hand side for a column equality.
        Timeseries.name == exposure_path,
    )
101
+
102
def read_data(self):
    """Load the observations named in ``config.case_study.observations``.

    If the entry is a directory inside ``config.case_study.data_path``, all
    ``*.xlsx`` files in it are first imported into a temporary database.
    The resulting (or configured) file is then loaded by extension:
    ``.db`` through ``read_data_from_expydb``, ``.nc`` through
    ``xr.load_dataset``.

    :raises NotImplementedError: for any other file extension.
    """
    # TODO: Update to new INTERVENTION MODEL
    dataset = str(self.config.case_study.observations)
    data_dir = os.path.join(self.config.case_study.data_path, dataset)

    tempdir = None
    try:
        # read from a directory
        if os.path.isdir(data_dir):
            # This looks for xlsx files in the folder and imports them as a
            # database and then proceeds as normal
            files = glob.glob(os.path.join(data_dir, "*.xlsx"))

            tempdir = tempfile.TemporaryDirectory()
            dataset = self.read_data_from_xlsx(data=files, tempdir=tempdir)

        ext = dataset.split(".")[-1]

        if not os.path.exists(dataset):
            dataset = os.path.join(self.data_path, dataset)

        if ext == "db":
            statement = self.construct_database_statement_from_config()
            observations = self.read_data_from_expydb(dataset, statement)

            # TODO: Integrate interventions in observations dataset

        elif ext == "nc":
            observations = xr.load_dataset(dataset)

        else:
            # the '.db' branch opens the file through a sqlite:/// URL
            raise NotImplementedError(
                f"Dataset extension '.{ext}' is not recognized. "
                "Please use one of '.db' (sqlite), '.nc' (netcdf)."
            )

        return observations
    finally:
        # Remove the temporary import database deterministically instead of
        # leaving it to the garbage collector. A failed removal must not mask
        # the result or the original exception.
        if tempdir is not None:
            try:
                tempdir.cleanup()
            except OSError as err:
                warnings.warn(
                    f"Could not remove temporary directory '{tempdir.name}': {err}"
                )
139
+
140
def read_data_from_xlsx(self, data, tempdir):
    """Import xlsx datasets into a database file inside ``tempdir``.

    :param data: passed on as ``datasets_path`` to the import routine.
    :param tempdir: object whose ``name`` attribute is the target directory.
    :return: path of the created ``import.db`` file.
    """
    workdir = tempdir.name
    database = os.path.join(workdir, "import.db")

    # optional preprocessing entry from the configuration, None when absent
    preprocessing = getattr(self.config.simulation, "data_preprocessing", None)

    create_database_and_import_data_main(
        datasets_path=data,
        database_path=database,
        preprocessing=preprocessing,
        preprocessing_out=os.path.join(workdir, "processed_{filename}"),
    )
    return database
156
+
157
+
158
def read_data_from_expydb(self, database, statement) -> xr.Dataset:
    """Query an expyDB sqlite database and return the filtered dataset.

    The query result is converted with ``to_dataset`` (using
    ``self.unit_time``), its multi-index is flattened, and the outcome is
    passed through ``filter_dataset``.
    """
    observations_idata, interventions_idata = from_expydb(
        database=f"sqlite:///{database}",
        statement=statement,
    )

    flat_dataset = reduce_multiindex_to_flat_index(
        to_dataset(
            observations_idata,
            interventions_idata,
            unit_time=self.unit_time,
        )
    )

    # Open point left by the author: return multidimensional datasets for
    # data coming from the database. The method can be implemented in any
    # class; guts base is the current candidate.
    return self.filter_dataset(flat_dataset)
179
+
180
def process_data(self):
    """Run the data processing hooks in order.

    The hooks currently change datasets, indices, etc. in-place. This is
    convenient, but makes them harder to re-arrange with other methods.
    TODO: Make these methods static if possible
    """
    for hook_name in ("_create_indices", "_indices_to_dimensions"):
        getattr(self, hook_name)()
189
+
190
+ def _create_indices(self):
191
+ """Use if indices should be added to sim.indices and sim.observations"""
192
+ pass
193
+
194
+ def _indices_to_dimensions(self):
195
+ pass
196
+
197
def filter_dataset(self, dataset: xr.Dataset) -> xr.Dataset:
    """Return ``dataset`` unchanged (no filtering is applied here)."""
    unfiltered = dataset
    return unfiltered
199
+
200
def _reindex_time_dim(self):
    """Densify the time index of the observations for '_it' models.

    When ``config.simulation.model`` is set and contains ``"_it"``
    (case-insensitive), the observations are re-indexed onto the union of
    the existing time coordinate and an evenly spaced grid from 0 to the
    largest time value. The grid has ``config.simulation.n_reindexed_x``
    points (set to 100 when missing). Otherwise only an info message is
    logged.
    """
    model_name = self.config.simulation.model

    # guard: nothing to re-index unless the model is named like an IT model
    if model_name is None or "_it" not in model_name.lower():
        self.logger.info(msg=(
            "No redindexing of time vector to, because model name did not contain "
            "'_it' (individual tolerance), or model was not given by name. If an IT model "
            "is calculated without a dense time resolution, the estimates can be biased!"
        ))
        return

    self.logger.info(msg=(
        "Redindexing time vector to increase resolution, because model has "
        "'_it' (individual tolerance) in it's name"
    ))

    if not hasattr(self.config.simulation, "n_reindexed_x"):
        self.config.simulation.n_reindexed_x = 100

    n_points = int(self.config.simulation.n_reindexed_x)  # type: ignore
    dense_grid = np.linspace(0, np.max(self.coordinates["time"]), n_points)

    # np.unique both merges duplicates and sorts the combined time points
    merged_times = np.concatenate([self.coordinates["time"], dense_grid])
    self.observations = self.observations.reindex(time=np.unique(merged_times))
225
+
226
+
227
+
228
def recompute_posterior(self):
    """Recompute posterior predictions on an interpolated time grid.

    When ``self.results_interpolation`` (``[start, stop, num]``) is not
    ``None``, the observations are re-indexed onto ``np.linspace(start,
    stop, num)``. NaN ``start``/``stop`` are replaced by the smallest/largest
    observed time. Afterwards the model is re-dispatched, posterior
    predictions are computed, and the results are stored as
    ``numpyro_posterior_interp.nc`` in ``self.output_path``.
    """
    # generate high resolution posterior predictions
    if self.results_interpolation is not None:
        # Work on a copy. The default is a class-level list, so writing the
        # resolved bounds into it in place would leak one simulation's time
        # range into every other instance.
        interpolation = list(self.results_interpolation)

        if np.isnan(interpolation[0]):
            interpolation[0] = float(self.observations["time"].min())

        if np.isnan(interpolation[1]):
            interpolation[1] = float(self.observations["time"].max())

        # keep the resolved values available on this instance only
        self.results_interpolation = interpolation

        time_interpolate = np.linspace(
            start=interpolation[0],
            stop=interpolation[1],
            num=interpolation[2]
        )
        self.observations = self.observations.reindex(
            time=time_interpolate
        )

    self.dispatch_constructor()
    _ = self._prob.posterior_predictions(self, self.inferer.idata)  # type: ignore
    self.inferer.store_results(output=f"{self.output_path}/numpyro_posterior_interp.nc")  # type: ignore
    self.logger.info("Recomputed posterior and storing in `numpyro_posterior_interp.nc`")
255
+
256
+
257
def prior_predictive_checks(self):
    """Run the parent's prior predictive checks, then plot survival."""
    super().prior_predictive_checks()

    plotted_variables = ["survival"]
    self._plot.plot_prior_predictions(self, data_vars=plotted_variables)
261
+
262
def posterior_predictive_checks(self):
    """Run the parent's posterior checks, recompute, and plot survival."""
    super().posterior_predictive_checks()

    self.recompute_posterior()

    # TODO: Include posterior_predictive group once the survival predictions are correctly working
    plotted_variables = ["survival"]
    plotted_groups = ["posterior_model_fits"]
    self._plot.plot_posterior_predictions(
        self,
        data_vars=plotted_variables,
        groups=plotted_groups,
    )
268
+
269
+
270
def plot(self, results):
    """Plot survival for ``results`` with the simulation's plot module."""
    plot_survival = self._plot.plot_survival
    plot_survival(self, results)
272
+
273
def copy(self):
    """Create a new simulation of the same type from this one.

    A fresh instance is built from ``self.config``. Observations, model
    parameters, model and solver post-processing are assigned from this
    instance without an explicit copy. If an inferer is present, an inferer
    of the same type is created and given the existing ``idata``. Finally
    ``load_modules`` is called on the new instance.

    NOTE(review): the extent of the ``with`` and ``if`` blocks below was
    reconstructed from a rendering that lost indentation — confirm.
    """
    # warnings raised while building and wiring the copy are suppressed
    with warnings.catch_warnings(action="ignore"):
        sim_copy = type(self)(self.config)
        sim_copy.observations = self.observations
        sim_copy.model_parameters = self.model_parameters
        if self.inferer is not None:
            # NOTE(review): the new inferer is constructed with `self`, not
            # with `sim_copy` — confirm this is intended.
            sim_copy.inferer = type(self.inferer)(self)
            sim_copy.inferer.idata = self.inferer.idata
        sim_copy.model = self.model
        sim_copy.solver_post_processing = self.solver_post_processing
        sim_copy.load_modules()

    return sim_copy
286
+
287
+
288
@property
def predefined_scenarios(self):
    """Mapping of scenario name to a predefined exposure scenario.

    Currently holds one entry, ``"oral_acute_1d"``: oral exposure at
    concentration 1.0 from time 0 to 1.0, designed with ``t_max=10.01`` and
    ``dt=1/24``.
    """
    # this produces a exposure x_in dataset with only the dimensions ID and TIME
    oral_pulse = {
        "oral": {"start": 0, "end": 1.0, "concentration": 1.0},
    }
    scenario = design_exposure_scenario(
        exposures=oral_pulse,
        t_max=10.01,
        dt=1/24,
        exposure_dimension="exposure_path",
    )

    return {"oral_acute_1d": scenario}
303
+
304
def expand_batch_like_coordinate_to_new_dimension(self, coordinate, variables):
    """Re-express a batch-like coordinate as a separate dimension.

    This method takes an existing coordinate of the observations that runs
    along the batch dimension. For each of the given variables, the
    N-dimensional array is expanded to an (N+1)-dimensional array whose new
    last dimension has one entry per unique label of the coordinate (in
    order of first appearance). Along the new dimension, a batch entry keeps
    its data only at the position of its own label; all other positions are
    zero.

    This process is entirely reversible

    :param coordinate: name of the coordinate to turn into a dimension.
    :param variables: names of the data variables to expand.
    :return: a dataset without the old coordinate, with the expanded
        variables.
    """
    old_coords = self.observations[coordinate]
    batch_dim = self.config.simulation.batch_dimension

    # remove the old coordinate before turning it into a dimension
    # (`drop_vars` replaces the deprecated `Dataset.drop`)
    obs = self.observations.drop_vars(coordinate)

    # create unique coordinates of the new dimension, preserving the order of
    # the old coordinate
    labels = np.asarray(old_coords)
    _, first_seen = np.unique(labels, return_index=True)
    coords_new_dim = labels[np.sort(first_seen)]

    # One-hot membership matrix of shape (n_batch, n_labels). Its columns
    # follow `coords_new_dim` exactly. `pd.get_dummies` would order columns
    # sorted instead, which mislabels the data whenever the first-appearance
    # order differs from the sorted order.
    membership = (labels[:, None] == coords_new_dim[None, :]).astype(int)

    for v in variables:
        # take data variable and extract dimension order
        data_var = obs[v]
        dim_order = data_var.dims

        # expand the dimensionality, then transpose for new dim to be last
        data_var = data_var.expand_dims(coordinate).transpose(..., batch_dim, coordinate)

        # broadcasting against the membership matrix grows the new dimension
        # and zeroes every entry that does not belong to the label
        data_var_np1_d = data_var * membership

        # annotate coordinates of the new dimension
        data_var_np1_d = data_var_np1_d.assign_coords({
            coordinate: list(coords_new_dim)
        })

        # transpose back to original dimension order with new dim as last dim
        obs[v] = data_var_np1_d.transpose(*dim_order, coordinate)

    return obs
352
+
353
+
354
def reduce_dimension_to_batch_like_coordinate(self, dimension, variables):
    """Reduce a dimension of an N-D array to a batch-like coordinate.

    Not implemented yet: the method currently does nothing and returns None.

    Intended behaviour: take an existing dimension of an N-D array and
    reduce it to an (N-1)-D array. A new batch-like coordinate is written
    that takes, for each batch entry, the coordinate of the dimension where
    the data was not zero. After asserting that there is only one candidate
    per batch coordinate (i.e. only one value is non-zero for a given
    batch-coordinate), the dimension is reduced by summing over it.

    The method is contingent on having no overlap in batch dimension in the
    dataset
    """
    return None
366
+
367
def initialize_from_script(self):
    """Hook for configuring the simulation from code.

    Does nothing here.
    """
    return None
369
+
370
class GutsSimulationConstantExposure(GutsBase):
    """GUTS simulation that is configured entirely from code.

    ``initialize_from_script`` declares the data structure, initial values
    and parameters (including a fixed ``C_0``). The ``use_*_solver`` methods
    select a model/solver pair and dispatch the constructor.
    """
    t_max = 10

    def initialize_from_script(self):
        """Declare data variables, y0, parameters, model name and time."""
        # all data variables share the same declaration
        for variable in ("B", "D", "H", "survival"):
            setattr(
                self.config.data_structure,
                variable,
                DataVariable(dimensions=["time"], observed=False),
            )

        # y0
        self.config.simulation.y0 = ["D=Array([0])", "H=Array([0])", "survival=Array([1])"]
        self.model_parameters["y0"] = self.parse_input(input="y0", drop_dims=["time"])

        # parameters: only C_0 is held fixed
        parameter_defaults = {
            "C_0": Param(value=10.0, free=False),
            "k_d": Param(value=0.9, free=True),
            "h_b": Param(value=0.00005, free=True),
            "b": Param(value=5.0, free=True),
            "z": Param(value=0.2, free=True),
        }
        for name, parameter in parameter_defaults.items():
            setattr(self.config.model_parameters, name, parameter)

        self.model_parameters["parameters"] = self.config.model_parameters.value_dict
        self.config.simulation.model = "guts_jax"

        self.coordinates["time"] = np.linspace(0, self.t_max)

    def use_jax_solver(self):
        """Select the "guts_jax" model with ``JaxSolver`` and dispatch."""
        # only the two end points of the time axis are set here
        self.coordinates["time"] = np.array([0, self.t_max])
        self.config.simulation.model = "guts_jax"

        self.solver = JaxSolver

        self.dispatch_constructor(diffrax_solver=Dopri5)

    def use_symbolic_solver(self):
        """Select the "guts_sympy" model with the symbolic solver."""
        self.coordinates["time"] = np.array([0, self.t_max])
        self.config.simulation.model = "guts_sympy"

        self.solver = mod.PiecewiseSymbolicSolver

        # NOTE(review): a diffrax solver is passed although the symbolic
        # solver is selected — confirm this keyword is intended here.
        self.dispatch_constructor(diffrax_solver=Dopri5)
417
+
418
+
419
class GutsSimulationVariableExposure(GutsSimulationConstantExposure):
    """Code-configured GUTS simulation driven by an exposure time series.

    Builds on the constant-exposure setup, replaces the fixed ``C_0``
    parameter by an artificial exposure profile passed as ``x_in``, and
    switches to the "guts_variable_exposure" model.
    """
    t_max = 10

    def initialize_from_script(self):
        """Run the parent setup, then add artificial exposure as input."""
        super().initialize_from_script()
        del self.coordinates["time"]

        artificial = create_artificial_data(
            t_max=self.t_max,
            dt=1,
            exposure_paths=["topical"],
        )
        exposure = artificial.squeeze()
        self.observations = exposure

        self.config.data_structure.exposure = DataVariable(dimensions=["time"], observed=True)

        self.config.simulation.x_in = ["exposure=exposure"]
        parsed_x_in = self.parse_input(input="x_in", reference_data=exposure, drop_dims=[])
        self.model_parameters["x_in"] = rect_interpolation(x_in=parsed_x_in, x_dim="time")

        # parameters: C_0 is superseded by the exposure input
        self.config.model_parameters.remove("C_0")

        self.model_parameters["parameters"] = self.config.model_parameters.value_dict
        self.config.simulation.solver_post_processing = "post_exposure"
        self.config.simulation.model = "guts_variable_exposure"

    def use_jax_solver(self):
        """Select the variable-exposure model with ``JaxSolver``."""
        self.model = self._mod.guts_variable_exposure
        self.solver = JaxSolver

        self.dispatch_constructor(diffrax_solver=Dopri5)

    def use_symbolic_solver(self, do_compile=True):
        """Select the sympy model with the piecewise symbolic solver.

        :param do_compile: forwarded to ``dispatch_constructor``.
        """
        self.model = self._mod.guts_sympy
        self.solver = self._mod.PiecewiseSymbolicSolver

        self.dispatch_constructor(do_compile=do_compile, output_path=self.output_path)
464
+