guts-base 0.8.5__tar.gz → 0.8.6__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of guts-base might be problematic; see the original diff page for more details.
- {guts_base-0.8.5 → guts_base-0.8.6}/PKG-INFO +3 -3
- {guts_base-0.8.5 → guts_base-0.8.6}/guts_base/__init__.py +1 -1
- {guts_base-0.8.5 → guts_base-0.8.6}/guts_base/data/generator.py +4 -4
- {guts_base-0.8.5 → guts_base-0.8.6}/guts_base/mod.py +6 -0
- {guts_base-0.8.5 → guts_base-0.8.6}/guts_base/sim/base.py +86 -24
- {guts_base-0.8.5 → guts_base-0.8.6}/guts_base/sim/ecx.py +62 -14
- {guts_base-0.8.5 → guts_base-0.8.6}/guts_base/sim/report.py +9 -8
- {guts_base-0.8.5 → guts_base-0.8.6}/guts_base.egg-info/PKG-INFO +3 -3
- {guts_base-0.8.5 → guts_base-0.8.6}/guts_base.egg-info/SOURCES.txt +2 -0
- {guts_base-0.8.5 → guts_base-0.8.6}/guts_base.egg-info/requires.txt +1 -1
- {guts_base-0.8.5 → guts_base-0.8.6}/pyproject.toml +4 -4
- guts_base-0.8.6/tests/test_ecx.py +62 -0
- guts_base-0.8.6/tests/test_simulations.py +87 -0
- guts_base-0.8.6/tests/test_simulations_from_mempy.py +119 -0
- guts_base-0.8.5/tests/test_simulations.py +0 -109
- {guts_base-0.8.5 → guts_base-0.8.6}/LICENSE +0 -0
- {guts_base-0.8.5 → guts_base-0.8.6}/README.md +0 -0
- {guts_base-0.8.5 → guts_base-0.8.6}/guts_base/data/__init__.py +0 -0
- {guts_base-0.8.5 → guts_base-0.8.6}/guts_base/data/expydb.py +0 -0
- {guts_base-0.8.5 → guts_base-0.8.6}/guts_base/data/openguts.py +0 -0
- {guts_base-0.8.5 → guts_base-0.8.6}/guts_base/data/preprocessing.py +0 -0
- {guts_base-0.8.5 → guts_base-0.8.6}/guts_base/data/survival.py +0 -0
- {guts_base-0.8.5 → guts_base-0.8.6}/guts_base/data/time_of_death.py +0 -0
- {guts_base-0.8.5 → guts_base-0.8.6}/guts_base/data/utils.py +0 -0
- {guts_base-0.8.5 → guts_base-0.8.6}/guts_base/plot.py +0 -0
- {guts_base-0.8.5 → guts_base-0.8.6}/guts_base/prob.py +0 -0
- {guts_base-0.8.5 → guts_base-0.8.6}/guts_base/sim/__init__.py +0 -0
- {guts_base-0.8.5 → guts_base-0.8.6}/guts_base/sim/mempy.py +0 -0
- {guts_base-0.8.5 → guts_base-0.8.6}/guts_base/sim.py +0 -0
- {guts_base-0.8.5 → guts_base-0.8.6}/guts_base.egg-info/dependency_links.txt +0 -0
- {guts_base-0.8.5 → guts_base-0.8.6}/guts_base.egg-info/entry_points.txt +0 -0
- {guts_base-0.8.5 → guts_base-0.8.6}/guts_base.egg-info/top_level.txt +0 -0
- {guts_base-0.8.5 → guts_base-0.8.6}/setup.cfg +0 -0
- {guts_base-0.8.5 → guts_base-0.8.6}/tests/test_data_import.py +0 -0
- {guts_base-0.8.5 → guts_base-0.8.6}/tests/test_from_pymob.py +0 -0
- {guts_base-0.8.5 → guts_base-0.8.6}/tests/test_scripted_simulations.py +0 -0
- {guts_base-0.8.5 → guts_base-0.8.6}/tests/test_symbolic_solve.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: guts_base
|
|
3
|
-
Version: 0.8.5
|
|
3
|
+
Version: 0.8.6
|
|
4
4
|
Summary: Basic GUTS model implementation in pymob
|
|
5
5
|
Author-email: Florian Schunck <fluncki@protonmail.com>
|
|
6
6
|
License: GNU GENERAL PUBLIC LICENSE
|
|
@@ -686,14 +686,14 @@ Classifier: Natural Language :: English
|
|
|
686
686
|
Classifier: License :: OSI Approved :: GNU General Public License v3 (GPLv3)
|
|
687
687
|
Classifier: Operating System :: OS Independent
|
|
688
688
|
Classifier: Topic :: Scientific/Engineering :: Bio-Informatics
|
|
689
|
-
Requires-Python: >=3.10
|
|
689
|
+
Requires-Python: <3.12,>=3.10
|
|
690
690
|
Description-Content-Type: text/markdown
|
|
691
691
|
License-File: LICENSE
|
|
692
692
|
Requires-Dist: openpyxl>=3.1.3
|
|
693
693
|
Requires-Dist: Bottleneck>=1.5.0
|
|
694
694
|
Requires-Dist: expydb>=0.6.0
|
|
695
695
|
Requires-Dist: mempyguts>=1.5.0
|
|
696
|
-
Requires-Dist: pymob[numpyro]<
|
|
696
|
+
Requires-Dist: pymob[interactive,numpyro]<0.6.0,>=0.4.1
|
|
697
697
|
Provides-Extra: dev
|
|
698
698
|
Requires-Dist: pytest>=7.3; extra == "dev"
|
|
699
699
|
Requires-Dist: bumpver; extra == "dev"
|
|
@@ -67,7 +67,8 @@ def design_exposure_scenario(
|
|
|
67
67
|
"""
|
|
68
68
|
TODO: tmax, dt and eps are probably not necessary
|
|
69
69
|
"""
|
|
70
|
-
|
|
70
|
+
# add dt so that tmax is definitely inclded
|
|
71
|
+
time = np.arange(0, t_max+dt, step=dt) # daily time resolution
|
|
71
72
|
time = np.unique(np.concatenate([time] + [
|
|
72
73
|
np.array([time[-1] if vals["end"] is None else vals["end"]])
|
|
73
74
|
for key, vals in exposures.items()
|
|
@@ -79,13 +80,12 @@ def design_exposure_scenario(
|
|
|
79
80
|
treat = design_exposure_timeseries(time, expo, eps)
|
|
80
81
|
treatments.update({key: treat})
|
|
81
82
|
|
|
82
|
-
data = np.column_stack(list(treatments.values()))
|
|
83
|
+
data = np.column_stack(list(treatments.values()))
|
|
83
84
|
data = np.expand_dims(data, axis=0)
|
|
84
85
|
|
|
85
86
|
coords = {"id": [0], "time": time}
|
|
86
87
|
|
|
87
|
-
|
|
88
|
-
coords.update({exposure_dimension: list(treatments.keys())})
|
|
88
|
+
coords.update({exposure_dimension: list(treatments.keys())})
|
|
89
89
|
|
|
90
90
|
exposures_dataset = xr.Dataset(
|
|
91
91
|
data_vars={"exposure": (tuple(coords.keys()), data)},
|
|
@@ -32,7 +32,13 @@ from mempy.model import (
|
|
|
32
32
|
)
|
|
33
33
|
|
|
34
34
|
red_sd = RED_SD._rhs_jax
|
|
35
|
+
red_sd_post_processing = RED_SD._solver_post_processing
|
|
36
|
+
|
|
35
37
|
red_it = RED_IT._rhs_jax
|
|
38
|
+
red_it_post_processing = RED_IT._solver_post_processing
|
|
39
|
+
|
|
40
|
+
red_sd_ia = RED_SD_IA._rhs_jax
|
|
41
|
+
red_sd_ia_post_processing = RED_SD_IA._solver_post_processing
|
|
36
42
|
|
|
37
43
|
|
|
38
44
|
def p_survival(results, t, interpolation, z, k_k, h_b):
|
|
@@ -1,5 +1,7 @@
|
|
|
1
1
|
import os
|
|
2
2
|
import glob
|
|
3
|
+
from functools import partial
|
|
4
|
+
from copy import deepcopy
|
|
3
5
|
import warnings
|
|
4
6
|
import numpy as np
|
|
5
7
|
import xarray as xr
|
|
@@ -13,6 +15,7 @@ from pymob.sim.config import DataVariable, Param, string_to_list
|
|
|
13
15
|
|
|
14
16
|
from pymob.solvers import JaxSolver
|
|
15
17
|
from pymob.solvers.base import rect_interpolation
|
|
18
|
+
from pymob.sim.config import ParameterDict
|
|
16
19
|
from expyDB.intervention_model import (
|
|
17
20
|
Treatment, Timeseries, select, from_expydb
|
|
18
21
|
)
|
|
@@ -24,7 +27,6 @@ from guts_base.data import (
|
|
|
24
27
|
)
|
|
25
28
|
from guts_base.sim.report import GutsReport
|
|
26
29
|
|
|
27
|
-
|
|
28
30
|
class GutsBase(SimulationBase):
|
|
29
31
|
"""
|
|
30
32
|
Initializes GUTS models from a variety of data sources
|
|
@@ -36,15 +38,15 @@ class GutsBase(SimulationBase):
|
|
|
36
38
|
"""
|
|
37
39
|
solver = JaxSolver
|
|
38
40
|
Report = GutsReport
|
|
39
|
-
unit_time: Literal["day", "hour", "minute", "second"] = "day"
|
|
40
|
-
results_interpolation: Optional[List[float|int]] = [np.nan, np.nan, 100]
|
|
41
|
-
ecx_mode: Literal["mean", "draws"] = "mean"
|
|
42
41
|
|
|
43
42
|
def initialize(self, input: Dict = None):
|
|
43
|
+
self.ecx_mode: Literal["mean", "draws"] = "mean"
|
|
44
44
|
|
|
45
|
+
self.unit_time: Literal["day", "hour", "minute", "second"] = "day"
|
|
45
46
|
if hasattr(self.config.simulation, "unit_time"):
|
|
46
47
|
self.unit_time = self.config.simulation.unit_time # type: ignore
|
|
47
48
|
|
|
49
|
+
self.results_interpolation: Optional[List[float|int]] = [np.nan, np.nan, 100]
|
|
48
50
|
if hasattr(self.config.simulation, "results_interpolation"):
|
|
49
51
|
self.results_interpolation = string_to_list(self.config.simulation.results_interpolation)
|
|
50
52
|
self.results_interpolation[0] = float(self.results_interpolation[0])
|
|
@@ -229,6 +231,9 @@ class GutsBase(SimulationBase):
|
|
|
229
231
|
"""This function interpolates the posterior with a given resolution
|
|
230
232
|
posterior_predictions calculate proper survival predictions for the
|
|
231
233
|
posterior.
|
|
234
|
+
|
|
235
|
+
It also makes sure that the new interpolation does not include fewer values
|
|
236
|
+
than the original dataset
|
|
232
237
|
"""
|
|
233
238
|
|
|
234
239
|
if np.isnan(self.results_interpolation[0]):
|
|
@@ -244,8 +249,15 @@ class GutsBase(SimulationBase):
|
|
|
244
249
|
stop=self.results_interpolation[1],
|
|
245
250
|
num=self.results_interpolation[2]
|
|
246
251
|
)
|
|
252
|
+
|
|
253
|
+
# combine original coordinates and interpolation. This
|
|
254
|
+
# a) helps error checking during posterior predictions.
|
|
255
|
+
# b) makes sure that the original time vector is retained, which may be
|
|
256
|
+
# relevant for the simulation success (e.g. IT model)
|
|
247
257
|
self.observations = self.observations.reindex(
|
|
248
|
-
time=
|
|
258
|
+
time=np.unique(np.concatenate(
|
|
259
|
+
[time_interpolate, self.observations["time"]]
|
|
260
|
+
))
|
|
249
261
|
)
|
|
250
262
|
|
|
251
263
|
self.dispatch_constructor()
|
|
@@ -272,35 +284,85 @@ class GutsBase(SimulationBase):
|
|
|
272
284
|
|
|
273
285
|
def copy(self):
|
|
274
286
|
with warnings.catch_warnings(action="ignore"):
|
|
275
|
-
sim_copy = type(self)(self.config)
|
|
276
|
-
sim_copy.observations = self.observations
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
287
|
+
sim_copy = type(self)(self.config.copy(deep=True))
|
|
288
|
+
sim_copy.observations = self.observations.copy(deep=True)
|
|
289
|
+
|
|
290
|
+
# must copy individual parts of the dict due to the on_update method
|
|
291
|
+
model_parameters = {k: deepcopy(v) for k, v in self.model_parameters.items()}
|
|
292
|
+
|
|
293
|
+
# TODO: Refactor this once the parameterize method is removed.
|
|
294
|
+
sim_copy.parameterize = partial(sim_copy.parameterize, model_parameters=model_parameters)
|
|
295
|
+
sim_copy._model_parameters = ParameterDict(model_parameters, callback=sim_copy._on_params_updated)
|
|
296
|
+
|
|
297
|
+
sim_copy.load_modules()
|
|
298
|
+
if hasattr(self, "inferer"):
|
|
299
|
+
sim_copy.inferer = type(self.inferer)(sim_copy)
|
|
300
|
+
# idata uses deepcopy by default
|
|
301
|
+
sim_copy.inferer.idata = self.inferer.idata.copy()
|
|
281
302
|
sim_copy.model = self.model
|
|
282
303
|
sim_copy.solver_post_processing = self.solver_post_processing
|
|
283
|
-
sim_copy.load_modules()
|
|
284
304
|
|
|
285
305
|
return sim_copy
|
|
286
|
-
|
|
287
306
|
|
|
288
|
-
@property
|
|
289
307
|
def predefined_scenarios(self):
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
308
|
+
"""
|
|
309
|
+
TODO: Fix timescale to observations
|
|
310
|
+
TODO: Incorporate extra exposure patterns (constant, pulse_1day, pulse_2day)
|
|
311
|
+
"""
|
|
312
|
+
# get the maximum possible time to provide exposure scenarios that are definitely
|
|
313
|
+
# long enough
|
|
314
|
+
time_max = max(
|
|
315
|
+
self.observations[self.config.simulation.x_dimension].max(),
|
|
316
|
+
*self.Report.ecx_estimates_times
|
|
298
317
|
)
|
|
299
318
|
|
|
300
|
-
|
|
301
|
-
|
|
319
|
+
# this produces a exposure x_in dataset with only the dimensions ID and TIME
|
|
320
|
+
standard_dimensions = (
|
|
321
|
+
self.config.simulation.batch_dimension,
|
|
322
|
+
self.config.simulation.x_dimension,
|
|
302
323
|
)
|
|
303
324
|
|
|
325
|
+
# get dimensions different from standard dimensions
|
|
326
|
+
exposure_dimension = [
|
|
327
|
+
d for d in self.observations.exposure.dims if d not in
|
|
328
|
+
standard_dimensions
|
|
329
|
+
]
|
|
330
|
+
|
|
331
|
+
# raise an error if the number of extra dimensions is larger than 1
|
|
332
|
+
if len(exposure_dimension) > 1:
|
|
333
|
+
raise ValueError(
|
|
334
|
+
f"{type(self).__name__} can currently handle one additional dimension for "+
|
|
335
|
+
f"the exposure beside {standard_dimensions}. You provided an exposure "+
|
|
336
|
+
f"array with the dimensions: {self.observations.exposure.dims}"
|
|
337
|
+
)
|
|
338
|
+
else:
|
|
339
|
+
exposure_dimension = exposure_dimension[0]
|
|
340
|
+
|
|
341
|
+
# iterate over the coordinates of the exposure dimensions to
|
|
342
|
+
exposure_coordinates = self.observations.exposure[exposure_dimension].values
|
|
343
|
+
|
|
344
|
+
scenarios = {}
|
|
345
|
+
for coord in exposure_coordinates:
|
|
346
|
+
concentrations = np.where(coord == exposure_coordinates, 1, 0)
|
|
347
|
+
|
|
348
|
+
exposure_dict = {
|
|
349
|
+
coord: dict(start=0.0, end=1.0, concentration=conc)
|
|
350
|
+
for coord, conc in zip(exposure_coordinates, concentrations)
|
|
351
|
+
}
|
|
352
|
+
|
|
353
|
+
scenario = design_exposure_scenario(
|
|
354
|
+
exposures=exposure_dict,
|
|
355
|
+
t_max=time_max,
|
|
356
|
+
dt=1/24,
|
|
357
|
+
exposure_dimension=exposure_dimension
|
|
358
|
+
)
|
|
359
|
+
|
|
360
|
+
scenarios.update({
|
|
361
|
+
f"1day_exposure_{coord}": scenario
|
|
362
|
+
})
|
|
363
|
+
|
|
364
|
+
return scenarios
|
|
365
|
+
|
|
304
366
|
def expand_batch_like_coordinate_to_new_dimension(self, coordinate, variables):
|
|
305
367
|
"""This method will take an existing coordinate of a dataset that has the same
|
|
306
368
|
coordinate has the batch dimension. It will then re-express the coordinate as a
|
|
@@ -37,17 +37,32 @@ class ECxEstimator:
|
|
|
37
37
|
else:
|
|
38
38
|
self.sim.coordinates["id"] = [id]
|
|
39
39
|
|
|
40
|
-
|
|
40
|
+
# ensure correct coordinate order with x_in and raise errors early
|
|
41
|
+
self.sim.model_parameters["x_in"] = self.sim.parse_input("x_in", x_in)
|
|
41
42
|
|
|
42
43
|
self.sim.config.data_structure.survival.observed = False
|
|
43
44
|
self.sim.observations = self.sim.observations.sel(id=self.sim.coordinates["id"])
|
|
44
45
|
|
|
45
|
-
# fix time after observations have been set
|
|
46
|
-
|
|
46
|
+
# fix time after observations have been set. The outcome of the simulation
|
|
47
|
+
# can dependend on the time vector, because in e.g. IT models, the time resolution
|
|
48
|
+
# is important. Therefore the time at which the ECx is computed is just inserted
|
|
49
|
+
# into the time vector at the right position.
|
|
50
|
+
self.sim.coordinates["time"] = np.unique(np.concatenate([
|
|
51
|
+
self.sim.coordinates["time"], np.array(time, ndmin=1)
|
|
52
|
+
]))
|
|
47
53
|
|
|
48
54
|
self.sim.model_parameters["y0"] = self.sim.parse_input("y0", drop_dims="time")
|
|
49
55
|
self.sim.dispatch_constructor()
|
|
50
56
|
|
|
57
|
+
self.results = pd.Series({
|
|
58
|
+
"mean": np.nan,
|
|
59
|
+
"q05": np.nan,
|
|
60
|
+
"q95": np.nan,
|
|
61
|
+
"std": np.nan,
|
|
62
|
+
"cv": np.nan,
|
|
63
|
+
"msg": np.nan
|
|
64
|
+
})
|
|
65
|
+
|
|
51
66
|
|
|
52
67
|
|
|
53
68
|
def _evaluate(self, factor, theta):
|
|
@@ -108,6 +123,7 @@ class ECxEstimator:
|
|
|
108
123
|
accept_tol: float = 1e-5,
|
|
109
124
|
optimizer_tol: float = 1e-5,
|
|
110
125
|
method: str = "cobyla",
|
|
126
|
+
show_plot: bool = True,
|
|
111
127
|
**optimizer_kwargs
|
|
112
128
|
):
|
|
113
129
|
"""The minimizer for the EC_x operates on the unbounded linear scale, estimating
|
|
@@ -157,6 +173,9 @@ class ECxEstimator:
|
|
|
157
173
|
method : str
|
|
158
174
|
Minization algorithm. See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize.html
|
|
159
175
|
Default: 'cobyla'
|
|
176
|
+
|
|
177
|
+
show_plot : bool
|
|
178
|
+
Show the results plot of the lpx. Default: True
|
|
160
179
|
|
|
161
180
|
optimizer_kwargs :
|
|
162
181
|
Additional arguments to pass to the optimizer
|
|
@@ -178,9 +197,18 @@ class ECxEstimator:
|
|
|
178
197
|
)
|
|
179
198
|
else:
|
|
180
199
|
pass
|
|
200
|
+
|
|
201
|
+
warnings.warn(
|
|
202
|
+
"Values passed to 'parameters' don't have an effect in mode='draws'"
|
|
203
|
+
)
|
|
181
204
|
|
|
182
205
|
elif mode == "mean":
|
|
183
206
|
draws = 1
|
|
207
|
+
|
|
208
|
+
warnings.warn(
|
|
209
|
+
"Values passed to 'parameters' don't have an effect in mode='draws'"
|
|
210
|
+
)
|
|
211
|
+
|
|
184
212
|
elif mode == "manual":
|
|
185
213
|
draws = 1
|
|
186
214
|
if parameters is None:
|
|
@@ -220,6 +248,7 @@ class ECxEstimator:
|
|
|
220
248
|
)
|
|
221
249
|
|
|
222
250
|
success = opt_res.fun < accept_tol
|
|
251
|
+
iteration += 1
|
|
223
252
|
|
|
224
253
|
# convert to linear scale from log scale
|
|
225
254
|
factor = np.exp(opt_res.x)
|
|
@@ -229,24 +258,43 @@ class ECxEstimator:
|
|
|
229
258
|
loss.append(opt_res.fun)
|
|
230
259
|
|
|
231
260
|
res_full = pd.DataFrame(dict(factor = mult_factor, loss=loss, retries=iterations))
|
|
232
|
-
|
|
261
|
+
self.results_full = res_full
|
|
262
|
+
|
|
263
|
+
metric = "{name}_{x}".format(name=self._name, x=int(self.x * 100),)
|
|
264
|
+
|
|
265
|
+
successes = sum(res_full.loss < accept_tol)
|
|
266
|
+
if successes < draws:
|
|
233
267
|
warnings.warn(
|
|
234
|
-
f"Not all optimizations converged on the {
|
|
268
|
+
f"Not all optimizations converged on the {metric} ({successes/draws*100}%). " +
|
|
235
269
|
"Adjust starting values and method")
|
|
236
270
|
print(res_full)
|
|
271
|
+
|
|
272
|
+
short_msg = f"Estimation success rate: {successes/draws*100}%"
|
|
273
|
+
self.results["msg"] = short_msg
|
|
237
274
|
|
|
238
275
|
res = res_full.loc[res_full.loss < accept_tol,:]
|
|
239
276
|
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
277
|
+
if len(res) == 0:
|
|
278
|
+
self.msg = (
|
|
279
|
+
f"{metric} could not be found. Two reasons typically cause this problem: "+
|
|
280
|
+
f"1) no expoure before the time at which the {metric} is calculated. "+
|
|
281
|
+
"Check the the exposure profile. " +
|
|
282
|
+
f"2) Too high background mortality. If the time at which the {metric} is "+
|
|
283
|
+
f"calculated is large and background mortality is high, the {metric}, " +
|
|
284
|
+
"may be reached independent of the effect and optimization cannot succeed."
|
|
285
|
+
)
|
|
247
286
|
|
|
248
|
-
|
|
249
|
-
|
|
287
|
+
print(self.msg)
|
|
288
|
+
return
|
|
289
|
+
|
|
290
|
+
self.results["mean"] = np.round(np.mean(res.factor.values), 4)
|
|
291
|
+
self.results["q05"] = np.round(np.quantile(res.factor.values, 0.05), 4)
|
|
292
|
+
self.results["q95"] = np.round(np.quantile(res.factor.values, 0.95), 4)
|
|
293
|
+
self.results["std"] = np.round(np.std(res.factor.values), 4)
|
|
294
|
+
self.results["cv"] = np.round(np.std(res.factor.values)/np.mean(res.factor.values), 2)
|
|
295
|
+
|
|
296
|
+
if show_plot:
|
|
297
|
+
self.plot_profile_and_effect(parameters=parameters)
|
|
250
298
|
|
|
251
299
|
print("{name}_{x}".format(name=self._name, x=int(self.x * 100),))
|
|
252
300
|
print(self.results)
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import os
|
|
2
2
|
import itertools as it
|
|
3
|
+
from typing import List
|
|
3
4
|
import pandas as pd
|
|
4
5
|
|
|
5
6
|
from pymob import SimulationBase
|
|
@@ -9,6 +10,8 @@ from guts_base.plot import plot_survival_multipanel
|
|
|
9
10
|
from guts_base.sim.ecx import ECxEstimator
|
|
10
11
|
|
|
11
12
|
class GutsReport(Report):
|
|
13
|
+
ecx_estimates_times: List = [1, 2, 4, 10]
|
|
14
|
+
ecx_estimates_x: List = [0.1, 0.25, 0.5, 0.75, 0.9]
|
|
12
15
|
|
|
13
16
|
def additional_reports(self, sim: "SimulationBase"):
|
|
14
17
|
super().additional_reports(sim=sim)
|
|
@@ -33,9 +36,9 @@ class GutsReport(Report):
|
|
|
33
36
|
|
|
34
37
|
@reporting
|
|
35
38
|
def LCx_estimates(self, sim):
|
|
36
|
-
X =
|
|
37
|
-
T =
|
|
38
|
-
P = sim.predefined_scenarios
|
|
39
|
+
X = self.ecx_estimates_x
|
|
40
|
+
T = self.ecx_estimates_times
|
|
41
|
+
P = sim.predefined_scenarios()
|
|
39
42
|
|
|
40
43
|
estimates = pd.DataFrame(
|
|
41
44
|
it.product(X, T, P.keys()),
|
|
@@ -57,6 +60,7 @@ class GutsReport(Report):
|
|
|
57
60
|
ecx_estimator.estimate(
|
|
58
61
|
mode=sim.ecx_mode,
|
|
59
62
|
draws=250,
|
|
63
|
+
show_plot=False
|
|
60
64
|
)
|
|
61
65
|
|
|
62
66
|
ecx.append(ecx_estimator.results)
|
|
@@ -64,9 +68,6 @@ class GutsReport(Report):
|
|
|
64
68
|
results = pd.DataFrame(ecx)
|
|
65
69
|
estimates[results.columns] = results
|
|
66
70
|
|
|
67
|
-
|
|
68
|
-
estimates.to_csv()
|
|
69
|
-
file = os.path.join(sim.output_path, "lcx_estimates.csv")
|
|
70
|
-
lab = self._label.format(placeholder='$LC_x$ estimates')
|
|
71
|
-
self._write_table(tab=estimates, label_insert=f"$LC_x$ estimates \label{{{lab}}}]({os.path.basename(file)})")
|
|
71
|
+
out = self._write_table(tab=estimates, label_insert="$LC_x$ estimates")
|
|
72
72
|
|
|
73
|
+
return out
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: guts_base
|
|
3
|
-
Version: 0.8.5
|
|
3
|
+
Version: 0.8.6
|
|
4
4
|
Summary: Basic GUTS model implementation in pymob
|
|
5
5
|
Author-email: Florian Schunck <fluncki@protonmail.com>
|
|
6
6
|
License: GNU GENERAL PUBLIC LICENSE
|
|
@@ -686,14 +686,14 @@ Classifier: Natural Language :: English
|
|
|
686
686
|
Classifier: License :: OSI Approved :: GNU General Public License v3 (GPLv3)
|
|
687
687
|
Classifier: Operating System :: OS Independent
|
|
688
688
|
Classifier: Topic :: Scientific/Engineering :: Bio-Informatics
|
|
689
|
-
Requires-Python: >=3.10
|
|
689
|
+
Requires-Python: <3.12,>=3.10
|
|
690
690
|
Description-Content-Type: text/markdown
|
|
691
691
|
License-File: LICENSE
|
|
692
692
|
Requires-Dist: openpyxl>=3.1.3
|
|
693
693
|
Requires-Dist: Bottleneck>=1.5.0
|
|
694
694
|
Requires-Dist: expydb>=0.6.0
|
|
695
695
|
Requires-Dist: mempyguts>=1.5.0
|
|
696
|
-
Requires-Dist: pymob[numpyro]<
|
|
696
|
+
Requires-Dist: pymob[interactive,numpyro]<0.6.0,>=0.4.1
|
|
697
697
|
Provides-Extra: dev
|
|
698
698
|
Requires-Dist: pytest>=7.3; extra == "dev"
|
|
699
699
|
Requires-Dist: bumpver; extra == "dev"
|
|
@@ -26,7 +26,9 @@ guts_base/sim/ecx.py
|
|
|
26
26
|
guts_base/sim/mempy.py
|
|
27
27
|
guts_base/sim/report.py
|
|
28
28
|
tests/test_data_import.py
|
|
29
|
+
tests/test_ecx.py
|
|
29
30
|
tests/test_from_pymob.py
|
|
30
31
|
tests/test_scripted_simulations.py
|
|
31
32
|
tests/test_simulations.py
|
|
33
|
+
tests/test_simulations_from_mempy.py
|
|
32
34
|
tests/test_symbolic_solve.py
|
|
@@ -4,19 +4,19 @@ build-backend = "setuptools.build_meta"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "guts_base"
|
|
7
|
-
version = "0.8.5"
|
|
7
|
+
version = "0.8.6"
|
|
8
8
|
authors = [
|
|
9
9
|
{ name="Florian Schunck", email="fluncki@protonmail.com" },
|
|
10
10
|
]
|
|
11
11
|
description = "Basic GUTS model implementation in pymob"
|
|
12
12
|
readme = "README.md"
|
|
13
|
-
requires-python = ">=3.10"
|
|
13
|
+
requires-python = ">=3.10, <3.12"
|
|
14
14
|
dependencies=[
|
|
15
15
|
"openpyxl >= 3.1.3",
|
|
16
16
|
"Bottleneck >= 1.5.0",
|
|
17
17
|
"expydb >= 0.6.0",
|
|
18
18
|
"mempyguts >= 1.5.0",
|
|
19
|
-
"pymob[numpyro] >= 0.4.1, <
|
|
19
|
+
"pymob[numpyro,interactive] >= 0.4.1, < 0.6.0"
|
|
20
20
|
]
|
|
21
21
|
license = {file = "LICENSE"}
|
|
22
22
|
classifiers = [
|
|
@@ -49,7 +49,7 @@ import-openguts = "guts_base.data.openguts:create_database_and_import_data"
|
|
|
49
49
|
convert-time-of-death-to-openguts = "guts_base.data.time_of_death:time_of_death_to_openguts"
|
|
50
50
|
|
|
51
51
|
[tool.bumpver]
|
|
52
|
-
current_version = "0.8.5"
|
|
52
|
+
current_version = "0.8.6"
|
|
53
53
|
version_pattern = "MAJOR.MINOR.PATCH[PYTAGNUM]"
|
|
54
54
|
commit_message = "bump version {old_version} -> {new_version}"
|
|
55
55
|
tag_message = "{new_version}"
|
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
import pytest
|
|
2
|
+
import arviz as az
|
|
3
|
+
|
|
4
|
+
from guts_base import LPxEstimator, GutsBase
|
|
5
|
+
|
|
6
|
+
from tests.conftest import construct_sim, idfunc
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
# Derive simulations for testing from fixtures
|
|
10
|
+
@pytest.fixture(params=[
|
|
11
|
+
(GutsBase, "red_sd_ia", "ecx/idata_red_sd_ia.nc", "FLUA.5"),
|
|
12
|
+
(GutsBase, "red_it", "ecx/idata_red_it.nc", "T 1"),
|
|
13
|
+
], ids=idfunc,)
|
|
14
|
+
def lpx_estimator(request, tmp_path):
|
|
15
|
+
simulation_class, scenario, idata, id = request.param
|
|
16
|
+
yield construct_estimator(
|
|
17
|
+
simulation_class=simulation_class,
|
|
18
|
+
scenario=scenario,
|
|
19
|
+
idata=idata,
|
|
20
|
+
id=id,
|
|
21
|
+
output_path=tmp_path
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
def construct_estimator(simulation_class, scenario, idata, id, output_path=None):
|
|
25
|
+
sim = construct_sim(
|
|
26
|
+
simulation_class=simulation_class,
|
|
27
|
+
scenario=scenario,
|
|
28
|
+
output_path=output_path
|
|
29
|
+
)
|
|
30
|
+
sim.set_inferer("numpyro")
|
|
31
|
+
sim.inferer.idata = az.from_netcdf(f"data/testing/{idata}")
|
|
32
|
+
|
|
33
|
+
return LPxEstimator(sim=sim, id=id)
|
|
34
|
+
|
|
35
|
+
@pytest.mark.slow
|
|
36
|
+
def test_lp50(lpx_estimator):
|
|
37
|
+
# pytest.skip()
|
|
38
|
+
|
|
39
|
+
theta_mean = lpx_estimator.sim.inferer.idata.posterior.mean(("chain", "draw"))
|
|
40
|
+
theta_mean = {k: v["data"] for k, v in theta_mean.to_dict()["data_vars"].items()}
|
|
41
|
+
|
|
42
|
+
lpx_estimator._loss(log_factor=0.0, theta=theta_mean)
|
|
43
|
+
|
|
44
|
+
lpx_estimator.plot_loss_curve()
|
|
45
|
+
|
|
46
|
+
lpx_estimator.estimate(mode="mean")
|
|
47
|
+
lpx_estimator.estimate(mode="manual", parameters=lpx_estimator._posterior_mean())
|
|
48
|
+
lpx_estimator.estimate(mode="draws")
|
|
49
|
+
|
|
50
|
+
lpx_estimator.results
|
|
51
|
+
lpx_estimator.results_full
|
|
52
|
+
|
|
53
|
+
def test_copy(lpx_estimator):
|
|
54
|
+
e = lpx_estimator.sim.dispatch()
|
|
55
|
+
e()
|
|
56
|
+
e.results
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
if __name__ == "__main__":
|
|
60
|
+
# test_inference(sim=construct_sim("test_scenario_v2", Simulation_v2), backend="numpyro")
|
|
61
|
+
# test_lp50(simulation_class=GutsBase, scenario="red_sd_ia", idata="ecx/idata_red_sd_ia.nc", id="FLUA.5")
|
|
62
|
+
test_copy(construct_estimator(GutsBase, "red_sd_ia", "ecx/idata_red_sd_ia.nc", "FLUA.5"))
|
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
import pytest
|
|
2
|
+
|
|
3
|
+
from guts_base import GutsBase
|
|
4
|
+
|
|
5
|
+
from tests.conftest import construct_sim, idfunc
|
|
6
|
+
|
|
7
|
+
# List test scenarios and simulations
|
|
8
|
+
@pytest.fixture(params=[
|
|
9
|
+
(GutsBase, "red_sd"),
|
|
10
|
+
(GutsBase, "red_it"),
|
|
11
|
+
(GutsBase, "red_sd_ia"),
|
|
12
|
+
], ids=idfunc)
|
|
13
|
+
def sim_and_scenario(request):
|
|
14
|
+
return request.param
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
# Derive simulations for testing from fixtures
|
|
18
|
+
@pytest.fixture
|
|
19
|
+
def sim(sim_and_scenario, tmp_path):
|
|
20
|
+
simulation_class, scenario = sim_and_scenario
|
|
21
|
+
yield construct_sim(
|
|
22
|
+
scenario=scenario,
|
|
23
|
+
simulation_class=simulation_class,
|
|
24
|
+
output_path=tmp_path
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
# run tests with the Simulation fixtures
|
|
29
|
+
def test_setup(sim):
|
|
30
|
+
"""Tests the construction method"""
|
|
31
|
+
assert True
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def test_simulation(sim):
|
|
35
|
+
"""Tests if a forward simulation pass can be computed"""
|
|
36
|
+
sim.dispatch_constructor()
|
|
37
|
+
evaluator = sim.dispatch()
|
|
38
|
+
evaluator()
|
|
39
|
+
evaluator.results
|
|
40
|
+
|
|
41
|
+
assert True
|
|
42
|
+
|
|
43
|
+
def test_copy(sim):
|
|
44
|
+
sim.dispatch_constructor()
|
|
45
|
+
e_orig = sim.dispatch()
|
|
46
|
+
e_orig()
|
|
47
|
+
e_orig.results
|
|
48
|
+
|
|
49
|
+
sim_copy = sim.copy()
|
|
50
|
+
|
|
51
|
+
sim_copy.dispatch_constructor()
|
|
52
|
+
e_copy = sim_copy.dispatch()
|
|
53
|
+
e_copy()
|
|
54
|
+
|
|
55
|
+
assert (e_copy.results == e_orig.results).all().to_array().all().values
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
@pytest.mark.slow
|
|
59
|
+
@pytest.mark.parametrize("backend", ["numpyro"])
|
|
60
|
+
def test_inference(sim: GutsBase, backend):
|
|
61
|
+
"""Tests if prior predictions can be computed for arbitrary backends"""
|
|
62
|
+
sim.dispatch_constructor()
|
|
63
|
+
sim.set_inferer(backend)
|
|
64
|
+
|
|
65
|
+
sim.config.inference.n_predictions = 2
|
|
66
|
+
sim.prior_predictive_checks()
|
|
67
|
+
|
|
68
|
+
sim.config.inference_numpyro.kernel = "svi"
|
|
69
|
+
sim.config.inference_numpyro.svi_iterations = 10
|
|
70
|
+
sim.config.inference_numpyro.svi_learning_rate = 0.05
|
|
71
|
+
sim.config.inference_numpyro.draws = 10
|
|
72
|
+
sim.config.inference.n_predictions = 10
|
|
73
|
+
|
|
74
|
+
sim.inferer.run()
|
|
75
|
+
|
|
76
|
+
sim.inferer.idata
|
|
77
|
+
sim.inferer.store_results()
|
|
78
|
+
|
|
79
|
+
sim.posterior_predictive_checks()
|
|
80
|
+
|
|
81
|
+
sim.inferer.load_results()
|
|
82
|
+
sim.config.report.debug_report = True
|
|
83
|
+
sim.report()
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
if __name__ == "__main__":
|
|
87
|
+
test_inference(sim=construct_sim("red_sd", GutsBase), backend="numpyro")
|
|
@@ -0,0 +1,119 @@
|
|
|
1
|
+
from typing import Tuple
|
|
2
|
+
import pytest
|
|
3
|
+
import numpy as np
|
|
4
|
+
import xarray as xr
|
|
5
|
+
from mempy.model import (
|
|
6
|
+
RED_IT, RED_SD, RED_SD_DA, Reduced,
|
|
7
|
+
BufferGUTS_SD, BufferGUTS_IT, BufferGUTS_SD_CA
|
|
8
|
+
)
|
|
9
|
+
from mempy.input_data import read_exposure_survival
|
|
10
|
+
from guts_base import PymobSimulator
|
|
11
|
+
|
|
12
|
+
from tests.conftest import idfunc
|
|
13
|
+
|
|
14
|
+
# results are from Bürger and Focks 2025 (https://doi.org/10.1093/etojnl/vgae058)
|
|
15
|
+
# supplementary material (Tab. 5.3)
|
|
16
|
+
OPENGUTS_ESTIMATES = dict(
|
|
17
|
+
red_sd = xr.Dataset(dict(kd=0.712, m=2.89, b=0.619, hb=0.008)).to_array().sortby("variable"),
|
|
18
|
+
red_sd_da = None,
|
|
19
|
+
bufferguts_sd_ca = None,
|
|
20
|
+
)
|
|
21
|
+
|
|
22
|
+
def read_data(file):
|
|
23
|
+
data = read_exposure_survival(
|
|
24
|
+
"data/testing", file,
|
|
25
|
+
survival_name="Survival",
|
|
26
|
+
exposure_name="Exposure",
|
|
27
|
+
visualize=False,
|
|
28
|
+
with_raw_exposure_data=True
|
|
29
|
+
)
|
|
30
|
+
|
|
31
|
+
exposure_funcs, survival_data, num_expos, exposure_data = data
|
|
32
|
+
info_dict = {}
|
|
33
|
+
|
|
34
|
+
return exposure_funcs, survival_data, num_expos, info_dict, exposure_data
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def construct_sim(dataset: Tuple, model: type, output_path="results/testing"):
|
|
38
|
+
"""Helper function to construct simulations for debugging"""
|
|
39
|
+
_, survival_data, num_expos, _, exposure_data = read_data(file=dataset)
|
|
40
|
+
|
|
41
|
+
if model in (RED_IT, RED_SD, BufferGUTS_SD, BufferGUTS_IT):
|
|
42
|
+
_model = model()
|
|
43
|
+
else:
|
|
44
|
+
_model = model(num_expos=num_expos)
|
|
45
|
+
|
|
46
|
+
sim = PymobSimulator.from_mempy(
|
|
47
|
+
exposure_data=exposure_data,
|
|
48
|
+
survival_data=survival_data,
|
|
49
|
+
model=_model,
|
|
50
|
+
output_directory=output_path
|
|
51
|
+
)
|
|
52
|
+
|
|
53
|
+
return sim
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
@pytest.fixture(params=[
|
|
57
|
+
("ringtest_A_SD.xlsx", RED_SD),
|
|
58
|
+
("Fit_Data_Cloeon_final.xlsx", RED_SD_DA,),
|
|
59
|
+
("osmia_multiexpo_synthetic.xlsx", BufferGUTS_SD_CA,),
|
|
60
|
+
], ids=idfunc)
|
|
61
|
+
def dataset_and_model(request) -> Reduced:
|
|
62
|
+
yield request.param
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
# Derive simulations for testing from fixtures
|
|
66
|
+
@pytest.fixture
|
|
67
|
+
def sim(dataset_and_model, tmp_path):
|
|
68
|
+
dataset, model = dataset_and_model
|
|
69
|
+
yield construct_sim(dataset=dataset, model=model, output_path=tmp_path)
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
# run tests with the Simulation fixtures
|
|
73
|
+
def test_setup(sim):
|
|
74
|
+
"""Tests the construction method"""
|
|
75
|
+
assert True
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def test_simulation(sim):
|
|
79
|
+
"""Tests if a forward simulation pass can be computed"""
|
|
80
|
+
# sim.dispatch_constructor()
|
|
81
|
+
evaluator = sim.dispatch()
|
|
82
|
+
evaluator()
|
|
83
|
+
evaluator.results
|
|
84
|
+
|
|
85
|
+
assert True
|
|
86
|
+
|
|
87
|
+
@pytest.mark.slow
@pytest.mark.parametrize("backend", ["numpyro"])
def test_inference(sim: PymobSimulator, backend):
    """Tests if inference runs end to end for the given backend.

    The test runs the following steps in order:

    1. Prior predictive checks.
    2. The inferer.
    3. Posterior predictive checks.
    4. A debug report.

    If reference OpenGUTS estimates exist for the simulated model, the
    posterior means are compared against them. Otherwise that comparison is
    skipped.
    """
    sim.set_inferer(backend)

    sim.prior_predictive_checks()
    sim.inferer.run()

    sim.posterior_predictive_checks()

    sim.config.report.debug_report = True
    sim.report()

    # test if the inferer converged on the reference (OpenGUTS) estimates
    model_name = sim.config.simulation.model.lower()
    openguts_estimates = OPENGUTS_ESTIMATES[model_name]

    if openguts_estimates is None:
        # reference values are not available: explicitly skip the comparison
        # rather than failing the test
        pytest.skip(f"no OpenGUTS reference estimates available for model '{model_name}'")

    pymob_estimates = (
        sim.inferer.idata.posterior
        .mean(("chain", "draw"))
        .to_array()
        .sortby("variable")
    )
    np.testing.assert_allclose(pymob_estimates, openguts_estimates, rtol=0.05, atol=0.1)
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
if __name__ == "__main__":
    # Debug entry point: run a single inference case directly, without pytest.
    # Other available cases:
    #   construct_sim("ringtest_A_SD.xlsx", RED_SD)
    #   construct_sim("Fit_Data_Cloeon_final.xlsx", RED_SD_DA)
    debug_sim = construct_sim("osmia_multiexpo_synthetic.xlsx", BufferGUTS_SD_CA)
    test_inference(sim=debug_sim, backend="numpyro")
|
|
@@ -1,109 +0,0 @@
|
|
|
1
|
-
import pytest
|
|
2
|
-
import arviz as az
|
|
3
|
-
|
|
4
|
-
from guts_base import LPxEstimator, GutsBase
|
|
5
|
-
from mempy.model import RED_SD_DA, RED_IT
|
|
6
|
-
|
|
7
|
-
def construct_sim(scenario, simulation_class):
    """Build and set up a simulation from ``scenarios/<scenario>/settings.cfg``.

    Before ``setup()`` is called, the config's scenario name is replaced with
    ``"testing"``, which (per the original note) selects a different output
    directory.
    """
    config_file = f"scenarios/{scenario}/settings.cfg"
    simulation = simulation_class(config_file)

    # redirect outputs to the "testing" scenario
    simulation.config.case_study.scenario = "testing"
    simulation.setup()

    return simulation
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
# List test scenarios and simulations
|
|
18
|
-
@pytest.fixture(scope="session", params=[
    # (GutsBase, "red_sd"),
])
def sim_and_scenario(request):
    """Provide one ``(simulation class, scenario name)`` pair per parametrization.

    NOTE(review): the parameter list is empty, so pytest should not run (it
    skips) the tests depending on this fixture — confirm this is intended.
    """
    pair = request.param
    return pair
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
# Derive simulations for testing from fixtures
|
|
26
|
-
@pytest.fixture(scope="session")
def sim(sim_and_scenario):
    """Session-wide simulation built from the current ``sim_and_scenario`` pair."""
    sim_cls, scenario_name = sim_and_scenario
    simulation = construct_sim(scenario_name, sim_cls)
    yield simulation
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
# run tests with the Simulation fixtures
|
|
33
|
-
def test_setup(sim):
    """Tests the construction method.

    The simulation is built by the ``sim`` fixture; this checks that the
    fixture actually produced an object instead of asserting nothing.
    """
    assert sim is not None
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
def test_simulation(sim):
    """Check that a single forward pass of the model can be computed."""
    sim.dispatch_constructor()

    evaluator = sim.dispatch()
    evaluator()

    # access the results once to make sure they are retrievable
    _ = evaluator.results

    assert True
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
@pytest.mark.slow
@pytest.mark.parametrize("backend", ["numpyro"])
def test_inference(sim: GutsBase, backend):
    """Run prior checks, a short SVI fit, posterior checks and a debug report."""
    sim.dispatch_constructor()
    sim.set_inferer(backend)

    # few predictions are enough for the prior checks
    sim.config.inference.n_predictions = 2
    sim.prior_predictive_checks()

    # a very short SVI run keeps the test fast
    numpyro_cfg = sim.config.inference_numpyro
    numpyro_cfg.kernel = "svi"
    numpyro_cfg.svi_iterations = 10
    numpyro_cfg.svi_learning_rate = 0.05
    numpyro_cfg.draws = 10
    sim.config.inference.n_predictions = 10

    inferer = sim.inferer
    inferer.run()

    _ = inferer.idata
    inferer.store_results()

    sim.posterior_predictive_checks()

    sim.inferer.load_results()
    sim.config.report.debug_report = True
    sim.report()
|
|
74
|
-
|
|
75
|
-
@pytest.mark.slow
@pytest.mark.parametrize("model,dataset,idata,id", [
    (RED_SD_DA, "Fit_Data_Cloeon_final.xlsx", "idata_red_sd_da.nc", "FLUA.5"),
    (RED_IT, "ringtest_A_IT.xlsx", "idata_red_it.nc", "T 1")
])
def test_lp50(model, dataset, idata, id):
    """Exercise ``LPxEstimator`` on stored posterior samples (currently skipped).

    The following are exercised against posterior draws loaded from
    ``data/testing/<idata>``:

    - the loss function
    - the loss curve
    - the ``"mean"``, ``"manual"`` and ``"draws"`` estimation modes
    - the profile/effect plots
    """
    # NOTE(review): everything after the skip is unreachable. ``construct_sim``
    # in this module has the signature ``(scenario, simulation_class)``, so the
    # ``dataset=``/``model=`` call below would raise ``TypeError`` if the skip
    # were removed.
    pytest.skip(
        "construct_sim(scenario, simulation_class) cannot build simulations "
        "from dataset files"
    )
    sim = construct_sim(dataset=dataset, model=model)
    sim.set_inferer("numpyro")
    sim.inferer.idata = az.from_netcdf(f"data/testing/{idata}")

    lpx_estimator = LPxEstimator(sim=sim, id=id)

    theta_mean = lpx_estimator.sim.inferer.idata.posterior.mean(("chain", "draw"))
    theta_mean = {k: v["data"] for k, v in theta_mean.to_dict()["data_vars"].items()}

    lpx_estimator._loss(log_factor=0.0, theta=theta_mean)

    lpx_estimator.plot_loss_curve()

    lpx_estimator.estimate(mode="mean")
    lpx_estimator.plot_profile_and_effect()
    lpx_estimator.estimate(mode="manual", parameters=lpx_estimator._posterior_mean())
    lpx_estimator.plot_profile_and_effect(parameters=lpx_estimator._posterior_mean())

    lpx_estimator.estimate(mode="draws")
    lpx_estimator.plot_profile_and_effect()

    lpx_estimator.results
    lpx_estimator.results_full
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
if __name__ == "__main__":
    # Debug entry point: run the inference test directly, without pytest.
    # Alternative: construct_sim("test_scenario_v2", Simulation_v2)
    red_it_sim = construct_sim("red_it", GutsBase)
    test_inference(sim=red_it_sim, backend="numpyro")
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|