vivarium-public-health 3.1.2__py3-none-any.whl → 3.1.4__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- vivarium_public_health/_version.py +1 -1
- vivarium_public_health/disease/exceptions.py +5 -0
- vivarium_public_health/disease/model.py +116 -112
- vivarium_public_health/disease/state.py +215 -79
- vivarium_public_health/disease/transition.py +78 -21
- vivarium_public_health/risks/effect.py +35 -23
- vivarium_public_health/risks/implementations/low_birth_weight_and_short_gestation.py +3 -4
- {vivarium_public_health-3.1.2.dist-info → vivarium_public_health-3.1.4.dist-info}/METADATA +32 -32
- {vivarium_public_health-3.1.2.dist-info → vivarium_public_health-3.1.4.dist-info}/RECORD +12 -11
- {vivarium_public_health-3.1.2.dist-info → vivarium_public_health-3.1.4.dist-info}/WHEEL +1 -1
- {vivarium_public_health-3.1.2.dist-info → vivarium_public_health-3.1.4.dist-info}/LICENSE.txt +0 -0
- {vivarium_public_health-3.1.2.dist-info → vivarium_public_health-3.1.4.dist-info}/top_level.txt +0 -0
@@ -1 +1 @@
|
|
1
|
-
__version__ = "3.1.2"
|
1
|
+
__version__ = "3.1.4"
|
@@ -10,24 +10,20 @@ transitions at simulation initialization and during transitions.
|
|
10
10
|
"""
|
11
11
|
import warnings
|
12
12
|
from collections.abc import Callable, Iterable
|
13
|
+
from functools import partial
|
13
14
|
from typing import Any
|
14
15
|
|
15
|
-
import numpy as np
|
16
16
|
import pandas as pd
|
17
|
-
from vivarium.exceptions import VivariumError
|
18
17
|
from vivarium.framework.engine import Builder
|
19
|
-
from vivarium.framework.event import Event
|
20
18
|
from vivarium.framework.population import SimulantData
|
21
19
|
from vivarium.framework.state_machine import Machine
|
20
|
+
from vivarium.types import DataInput, LookupTableData
|
22
21
|
|
22
|
+
from vivarium_public_health.disease.exceptions import DiseaseModelError
|
23
23
|
from vivarium_public_health.disease.state import BaseDiseaseState, SusceptibleState
|
24
24
|
from vivarium_public_health.disease.transition import TransitionString
|
25
25
|
|
26
26
|
|
27
|
-
class DiseaseModelError(VivariumError):
|
28
|
-
pass
|
29
|
-
|
30
|
-
|
31
27
|
class DiseaseModel(Machine):
|
32
28
|
|
33
29
|
##############
|
@@ -82,57 +78,77 @@ class DiseaseModel(Machine):
|
|
82
78
|
cause_type: str = "cause",
|
83
79
|
states: Iterable[BaseDiseaseState] = (),
|
84
80
|
residual_state: BaseDiseaseState | None = None,
|
85
|
-
|
86
|
-
):
|
87
|
-
super().__init__(cause, states=states
|
81
|
+
cause_specific_mortality_rate: DataInput | None = None,
|
82
|
+
) -> None:
|
83
|
+
super().__init__(cause, states=states)
|
88
84
|
self.cause = cause
|
89
85
|
self.cause_type = cause_type
|
86
|
+
self.residual_state = self._get_residual_state(initial_state, residual_state)
|
87
|
+
self._csmr_source = cause_specific_mortality_rate
|
88
|
+
self._get_data_functions = (
|
89
|
+
get_data_functions if get_data_functions is not None else {}
|
90
|
+
)
|
90
91
|
|
91
|
-
if
|
92
|
+
if get_data_functions is not None:
|
92
93
|
warnings.warn(
|
93
|
-
"
|
94
|
-
"
|
95
|
-
" retain the current behavior of defining a residual state, use"
|
96
|
-
" the 'residual_state' argument.",
|
94
|
+
"The argument 'get_data_functions' has been deprecated. Use"
|
95
|
+
" cause_specific_mortality_rate instead.",
|
97
96
|
DeprecationWarning,
|
98
97
|
stacklevel=2,
|
99
98
|
)
|
100
99
|
|
101
|
-
if
|
100
|
+
if cause_specific_mortality_rate is not None:
|
102
101
|
raise DiseaseModelError(
|
103
|
-
"
|
104
|
-
"
|
102
|
+
"It is not allowed to pass cause_specific_mortality_rate"
|
103
|
+
" both as a stand-alone argument and as part of"
|
104
|
+
" get_data_functions."
|
105
105
|
)
|
106
106
|
|
107
|
-
self.residual_state = initial_state.state_id
|
108
|
-
elif residual_state is not None:
|
109
|
-
self.residual_state = residual_state.state_id
|
110
|
-
else:
|
111
|
-
self.residual_state = self._get_default_residual_state()
|
112
|
-
|
113
|
-
self._get_data_functions = (
|
114
|
-
get_data_functions if get_data_functions is not None else {}
|
115
|
-
)
|
116
|
-
|
117
107
|
def setup(self, builder: Builder) -> None:
|
118
108
|
"""Perform this component's setup."""
|
119
109
|
super().setup(builder)
|
120
110
|
|
121
111
|
self.configuration_age_start = builder.configuration.population.initialization_age_min
|
122
112
|
self.configuration_age_end = builder.configuration.population.initialization_age_max
|
113
|
+
|
123
114
|
builder.value.register_value_modifier(
|
124
115
|
"cause_specific_mortality_rate",
|
125
116
|
self.adjust_cause_specific_mortality_rate,
|
126
117
|
requires_columns=["age", "sex"],
|
127
118
|
)
|
128
|
-
|
119
|
+
|
120
|
+
def on_initialize_simulants(self, pop_data: SimulantData) -> None:
|
121
|
+
"""Initialize the simulants in the population.
|
122
|
+
|
123
|
+
If all simulants are initialized at age 0, birth prevalence is used.
|
124
|
+
Otherwise, prevalence is used.
|
125
|
+
|
126
|
+
Parameters
|
127
|
+
----------
|
128
|
+
pop_data
|
129
|
+
The population data object.
|
130
|
+
"""
|
131
|
+
if pop_data.user_data.get("age_end", self.configuration_age_end) == 0:
|
132
|
+
initialization_table_name = "birth_prevalence"
|
133
|
+
else:
|
134
|
+
initialization_table_name = "prevalence"
|
135
|
+
|
136
|
+
for state in self.states:
|
137
|
+
state.lookup_tables["initialization_weights"] = state.lookup_tables[
|
138
|
+
initialization_table_name
|
139
|
+
]
|
140
|
+
|
141
|
+
super().on_initialize_simulants(pop_data)
|
129
142
|
|
130
143
|
#################
|
131
144
|
# Setup methods #
|
132
145
|
#################
|
133
146
|
|
134
147
|
def load_cause_specific_mortality_rate(self, builder: Builder) -> float | pd.DataFrame:
|
135
|
-
if
|
148
|
+
if (
|
149
|
+
"cause_specific_mortality_rate" not in self._get_data_functions
|
150
|
+
and self._csmr_source is None
|
151
|
+
):
|
136
152
|
only_morbid = builder.data.load(f"cause.{self.cause}.restrictions")["yld_only"]
|
137
153
|
if only_morbid:
|
138
154
|
csmr_data = 0.0
|
@@ -140,64 +156,14 @@ class DiseaseModel(Machine):
|
|
140
156
|
csmr_data = builder.data.load(
|
141
157
|
f"{self.cause_type}.{self.cause}.cause_specific_mortality_rate"
|
142
158
|
)
|
159
|
+
elif self._csmr_source is not None:
|
160
|
+
csmr_data = self.get_data(builder, self._csmr_source)
|
143
161
|
else:
|
144
162
|
csmr_data = self._get_data_functions["cause_specific_mortality_rate"](
|
145
163
|
self.cause, builder
|
146
164
|
)
|
147
165
|
return csmr_data
|
148
166
|
|
149
|
-
########################
|
150
|
-
# Event-driven methods #
|
151
|
-
########################
|
152
|
-
|
153
|
-
def on_initialize_simulants(self, pop_data: SimulantData) -> None:
|
154
|
-
population = self.population_view.subview(["age", "sex"]).get(pop_data.index)
|
155
|
-
|
156
|
-
assert self.residual_state in {s.state_id for s in self.states}
|
157
|
-
|
158
|
-
if pop_data.user_data["sim_state"] == "setup": # simulation start
|
159
|
-
if self.configuration_age_start != self.configuration_age_end != 0:
|
160
|
-
state_names, weights_bins = self.get_state_weights(
|
161
|
-
pop_data.index, "prevalence"
|
162
|
-
)
|
163
|
-
else:
|
164
|
-
raise NotImplementedError(
|
165
|
-
"We do not currently support an age 0 cohort. "
|
166
|
-
"configuration.population.initialization_age_min and "
|
167
|
-
"configuration.population.initialization_age_max "
|
168
|
-
"cannot both be 0."
|
169
|
-
)
|
170
|
-
|
171
|
-
else: # on time step
|
172
|
-
if pop_data.user_data["age_start"] == pop_data.user_data["age_end"] == 0:
|
173
|
-
state_names, weights_bins = self.get_state_weights(
|
174
|
-
pop_data.index, "birth_prevalence"
|
175
|
-
)
|
176
|
-
else:
|
177
|
-
state_names, weights_bins = self.get_state_weights(
|
178
|
-
pop_data.index, "prevalence"
|
179
|
-
)
|
180
|
-
|
181
|
-
if state_names and not population.empty:
|
182
|
-
# only do this if there are states in the model that supply prevalence data
|
183
|
-
population["sex_id"] = population.sex.apply({"Male": 1, "Female": 2}.get)
|
184
|
-
|
185
|
-
condition_column = self.assign_initial_status_to_simulants(
|
186
|
-
population,
|
187
|
-
state_names,
|
188
|
-
weights_bins,
|
189
|
-
self.randomness.get_draw(population.index),
|
190
|
-
)
|
191
|
-
|
192
|
-
condition_column = condition_column.rename(
|
193
|
-
columns={"condition_state": self.state_column}
|
194
|
-
)
|
195
|
-
else:
|
196
|
-
condition_column = pd.Series(
|
197
|
-
self.residual_state, index=population.index, name=self.state_column
|
198
|
-
)
|
199
|
-
self.population_view.update(condition_column)
|
200
|
-
|
201
167
|
##################################
|
202
168
|
# Pipeline sources and modifiers #
|
203
169
|
##################################
|
@@ -209,40 +175,78 @@ class DiseaseModel(Machine):
|
|
209
175
|
# Helper functions #
|
210
176
|
####################
|
211
177
|
|
212
|
-
def
|
213
|
-
|
214
|
-
|
215
|
-
|
216
|
-
return susceptible_states[0].state_id
|
217
|
-
|
218
|
-
def get_state_weights(
|
219
|
-
self, pop_index: pd.Index, prevalence_type: str
|
220
|
-
) -> tuple[list[str], np.ndarray | None]:
|
221
|
-
states = [state for state in self.states if state.lookup_tables.get(prevalence_type)]
|
222
|
-
|
223
|
-
if not states:
|
224
|
-
return states, None
|
178
|
+
def _get_residual_state(
|
179
|
+
self, initial_state: BaseDiseaseState, residual_state: BaseDiseaseState
|
180
|
+
) -> BaseDiseaseState:
|
181
|
+
"""Get the residual state for the DiseaseModel.
|
225
182
|
|
226
|
-
|
227
|
-
|
228
|
-
|
229
|
-
|
183
|
+
This will be the residual state if it is provided, otherwise it will be
|
184
|
+
the model's SusceptibleState. This method also calculates the residual
|
185
|
+
state's birth_prevalence and prevalence.
|
186
|
+
"""
|
187
|
+
if initial_state is not None:
|
188
|
+
warnings.warn(
|
189
|
+
"In the future, the 'initial_state' argument to DiseaseModel"
|
190
|
+
" will be used to initialize all simulants into that state. To"
|
191
|
+
" retain the current behavior of defining a residual state, use"
|
192
|
+
" the 'residual_state' argument.",
|
193
|
+
DeprecationWarning,
|
194
|
+
stacklevel=2,
|
195
|
+
)
|
230
196
|
|
231
|
-
|
232
|
-
|
197
|
+
if residual_state:
|
198
|
+
raise DiseaseModelError(
|
199
|
+
"A DiseaseModel cannot be initialized with both"
|
200
|
+
" 'initial_state and 'residual_state'."
|
201
|
+
)
|
233
202
|
|
234
|
-
|
203
|
+
residual_state = initial_state
|
204
|
+
elif residual_state is None:
|
205
|
+
susceptible_states = [s for s in self.states if isinstance(s, SusceptibleState)]
|
206
|
+
if len(susceptible_states) != 1:
|
207
|
+
raise DiseaseModelError(
|
208
|
+
"Disease model must have exactly one SusceptibleState."
|
209
|
+
)
|
210
|
+
residual_state = susceptible_states[0]
|
235
211
|
|
236
|
-
|
212
|
+
if residual_state not in self.states:
|
213
|
+
raise DiseaseModelError(
|
214
|
+
f"Residual state '{self.residual_state}' must be one of the"
|
215
|
+
f" states: {self.states}."
|
216
|
+
)
|
237
217
|
|
238
|
-
|
239
|
-
|
240
|
-
|
241
|
-
|
242
|
-
|
218
|
+
residual_state.birth_prevalence = partial(
|
219
|
+
self._get_residual_state_probabilities, table_name="birth_prevalence"
|
220
|
+
)
|
221
|
+
residual_state.prevalence = partial(
|
222
|
+
self._get_residual_state_probabilities, table_name="prevalence"
|
223
|
+
)
|
243
224
|
|
244
|
-
|
245
|
-
|
225
|
+
return residual_state
|
226
|
+
|
227
|
+
def _get_residual_state_probabilities(
|
228
|
+
self, builder: Builder, table_name: str
|
229
|
+
) -> LookupTableData:
|
230
|
+
"""Calculate the probabilities of the residual state based on the other states."""
|
231
|
+
non_residual_states = [s for s in self.states if s != self.residual_state]
|
232
|
+
non_residual_probabilities = 0
|
233
|
+
for state in non_residual_states:
|
234
|
+
weights_source = builder.configuration[state.name].data_sources[table_name]
|
235
|
+
weights = state.get_data(builder, weights_source)
|
236
|
+
if isinstance(weights, pd.DataFrame):
|
237
|
+
weights = weights.set_index(
|
238
|
+
[c for c in weights.columns if c != "value"]
|
239
|
+
).squeeze()
|
240
|
+
non_residual_probabilities += weights
|
241
|
+
|
242
|
+
residual_probabilities = 1 - non_residual_probabilities
|
243
|
+
|
244
|
+
if pd.Series(residual_probabilities < 0).any():
|
245
|
+
raise ValueError(
|
246
|
+
f"The {table_name} for the states in the DiseaseModel must sum"
|
247
|
+
" to less than 1."
|
248
|
+
)
|
249
|
+
if isinstance(residual_probabilities, pd.Series):
|
250
|
+
residual_probabilities = residual_probabilities.reset_index()
|
246
251
|
|
247
|
-
|
248
|
-
return simulants
|
252
|
+
return residual_probabilities
|
@@ -6,19 +6,20 @@ Disease States
|
|
6
6
|
This module contains tools to manage standard disease states.
|
7
7
|
|
8
8
|
"""
|
9
|
-
|
9
|
+
import warnings
|
10
10
|
from collections.abc import Callable
|
11
11
|
from typing import Any
|
12
12
|
|
13
13
|
import numpy as np
|
14
14
|
import pandas as pd
|
15
15
|
from vivarium.framework.engine import Builder
|
16
|
-
from vivarium.framework.lookup import LookupTableData
|
17
16
|
from vivarium.framework.population import PopulationView, SimulantData
|
18
17
|
from vivarium.framework.randomness import RandomnessStream
|
19
18
|
from vivarium.framework.state_machine import State, Transient, Transition, Trigger
|
20
19
|
from vivarium.framework.values import Pipeline, list_combiner, union_post_processor
|
20
|
+
from vivarium.types import DataInput, LookupTableData
|
21
21
|
|
22
|
+
from vivarium_public_health.disease.exceptions import DiseaseModelError
|
22
23
|
from vivarium_public_health.disease.transition import (
|
23
24
|
ProportionTransition,
|
24
25
|
RateTransition,
|
@@ -28,10 +29,25 @@ from vivarium_public_health.utilities import get_lookup_columns, is_non_zero
|
|
28
29
|
|
29
30
|
|
30
31
|
class BaseDiseaseState(State):
|
32
|
+
|
31
33
|
##############
|
32
34
|
# Properties #
|
33
35
|
##############
|
34
36
|
|
37
|
+
@property
|
38
|
+
def configuration_defaults(self) -> dict[str, Any]:
|
39
|
+
configuration_defaults = super().configuration_defaults
|
40
|
+
additional_defaults = {
|
41
|
+
"prevalence": self.prevalence,
|
42
|
+
"birth_prevalence": self.birth_prevalence,
|
43
|
+
}
|
44
|
+
data_sources = {
|
45
|
+
**configuration_defaults[self.name]["data_sources"],
|
46
|
+
**additional_defaults,
|
47
|
+
}
|
48
|
+
configuration_defaults[self.name]["data_sources"] = data_sources
|
49
|
+
return configuration_defaults
|
50
|
+
|
35
51
|
@property
|
36
52
|
def columns_created(self):
|
37
53
|
return [self.event_time_column, self.event_count_column]
|
@@ -59,13 +75,15 @@ class BaseDiseaseState(State):
|
|
59
75
|
side_effect_function: Callable | None = None,
|
60
76
|
cause_type: str = "cause",
|
61
77
|
):
|
62
|
-
super().__init__(state_id, allow_self_transition)
|
78
|
+
super().__init__(state_id, allow_self_transition)
|
63
79
|
self.cause_type = cause_type
|
64
80
|
|
65
81
|
self.side_effect_function = side_effect_function
|
66
82
|
|
67
83
|
self.event_time_column = self.state_id + "_event_time"
|
68
84
|
self.event_count_column = self.state_id + "_event_count"
|
85
|
+
self.prevalence = 0.0
|
86
|
+
self.birth_prevalence = 0.0
|
69
87
|
|
70
88
|
########################
|
71
89
|
# Event-driven methods #
|
@@ -87,10 +105,7 @@ class BaseDiseaseState(State):
|
|
87
105
|
def get_initialization_parameters(self) -> dict[str, Any]:
|
88
106
|
"""Exclude side effect function and cause type from name and __repr__."""
|
89
107
|
initialization_parameters = super().get_initialization_parameters()
|
90
|
-
|
91
|
-
if key in initialization_parameters.keys():
|
92
|
-
del initialization_parameters[key]
|
93
|
-
return initialization_parameters
|
108
|
+
return {"state_id": initialization_parameters["state_id"]}
|
94
109
|
|
95
110
|
def get_initial_event_times(self, pop_data: SimulantData) -> pd.DataFrame:
|
96
111
|
return pd.DataFrame(
|
@@ -132,6 +147,8 @@ class BaseDiseaseState(State):
|
|
132
147
|
output: "BaseDiseaseState",
|
133
148
|
get_data_functions: dict[str, Callable] = None,
|
134
149
|
triggered=Trigger.NOT_TRIGGERED,
|
150
|
+
transition_rate: DataInput | None = None,
|
151
|
+
rate_type: str = "transition_rate",
|
135
152
|
) -> RateTransition:
|
136
153
|
"""Builds a RateTransition from this state to the given state.
|
137
154
|
|
@@ -139,18 +156,29 @@ class BaseDiseaseState(State):
|
|
139
156
|
----------
|
140
157
|
output
|
141
158
|
The end state after the transition.
|
142
|
-
|
143
159
|
get_data_functions
|
144
160
|
Map from transition type to the function to pull that transition's data.
|
145
161
|
triggered
|
146
162
|
The trigger for the transition
|
147
|
-
|
163
|
+
transition_rate
|
164
|
+
The transition rate source. Can be the data itself, a function to
|
165
|
+
retrieve the data, or the artifact key containing the data.
|
166
|
+
rate_type
|
167
|
+
The type of rate. Can be "incidence_rate", "transition_rate", or
|
168
|
+
"remission_rate".
|
148
169
|
|
149
170
|
Returns
|
150
171
|
-------
|
151
172
|
The created transition object.
|
152
173
|
"""
|
153
|
-
transition = RateTransition(
|
174
|
+
transition = RateTransition(
|
175
|
+
input_state=self,
|
176
|
+
output_state=output,
|
177
|
+
get_data_functions=get_data_functions,
|
178
|
+
triggered=triggered,
|
179
|
+
transition_rate=transition_rate,
|
180
|
+
rate_type=rate_type,
|
181
|
+
)
|
154
182
|
self.add_transition(transition)
|
155
183
|
return transition
|
156
184
|
|
@@ -159,6 +187,7 @@ class BaseDiseaseState(State):
|
|
159
187
|
output: "BaseDiseaseState",
|
160
188
|
get_data_functions: dict[str, Callable] | None = None,
|
161
189
|
triggered=Trigger.NOT_TRIGGERED,
|
190
|
+
proportion: DataInput | None = None,
|
162
191
|
) -> ProportionTransition:
|
163
192
|
"""Builds a ProportionTransition from this state to the given state.
|
164
193
|
|
@@ -170,15 +199,26 @@ class BaseDiseaseState(State):
|
|
170
199
|
Map from transition type to the function to pull that transition's data.
|
171
200
|
triggered
|
172
201
|
The trigger for the transition.
|
202
|
+
proportion
|
203
|
+
The proportion source. Can be the data itself, a function to
|
204
|
+
retrieve the data, or the artifact key containing the data.
|
173
205
|
|
174
206
|
Returns
|
175
207
|
-------
|
176
208
|
The created transition object.
|
177
209
|
"""
|
178
|
-
if
|
210
|
+
if (
|
211
|
+
get_data_functions is None or "proportion" not in get_data_functions
|
212
|
+
) and proportion is None:
|
179
213
|
raise ValueError("You must supply a proportion function.")
|
180
214
|
|
181
|
-
transition = ProportionTransition(
|
215
|
+
transition = ProportionTransition(
|
216
|
+
input_state=self,
|
217
|
+
output_state=output,
|
218
|
+
get_data_functions=get_data_functions,
|
219
|
+
triggered=triggered,
|
220
|
+
proportion=proportion,
|
221
|
+
)
|
182
222
|
self.add_transition(transition)
|
183
223
|
return transition
|
184
224
|
|
@@ -220,17 +260,19 @@ class NonDiseasedState(BaseDiseaseState):
|
|
220
260
|
self,
|
221
261
|
output: BaseDiseaseState,
|
222
262
|
get_data_functions: dict[str, Callable] = None,
|
223
|
-
|
263
|
+
triggered=Trigger.NOT_TRIGGERED,
|
264
|
+
transition_rate: DataInput | None = None,
|
265
|
+
**_kwargs,
|
224
266
|
) -> RateTransition:
|
225
|
-
if get_data_functions is None:
|
226
|
-
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
|
233
|
-
|
267
|
+
if get_data_functions is None and transition_rate is None:
|
268
|
+
transition_rate = f"{self.cause_type}.{output.state_id}.incidence_rate"
|
269
|
+
return super().add_rate_transition(
|
270
|
+
output=output,
|
271
|
+
get_data_functions=get_data_functions,
|
272
|
+
triggered=triggered,
|
273
|
+
transition_rate=transition_rate,
|
274
|
+
rate_type="incidence_rate",
|
275
|
+
)
|
234
276
|
|
235
277
|
|
236
278
|
class SusceptibleState(NonDiseasedState):
|
@@ -253,6 +295,13 @@ class SusceptibleState(NonDiseasedState):
|
|
253
295
|
name_prefix="susceptible_to_",
|
254
296
|
)
|
255
297
|
|
298
|
+
##################
|
299
|
+
# Public methods #
|
300
|
+
##################
|
301
|
+
|
302
|
+
def has_initialization_weights(self) -> bool:
|
303
|
+
return True
|
304
|
+
|
256
305
|
|
257
306
|
class RecoveredState(NonDiseasedState):
|
258
307
|
def __init__(
|
@@ -280,17 +329,20 @@ class DiseaseState(BaseDiseaseState):
|
|
280
329
|
|
281
330
|
@property
|
282
331
|
def configuration_defaults(self) -> dict[str, Any]:
|
283
|
-
|
284
|
-
|
285
|
-
|
286
|
-
|
287
|
-
|
288
|
-
|
289
|
-
|
290
|
-
"excess_mortality_rate": self.load_excess_mortality_rate,
|
291
|
-
},
|
292
|
-
},
|
332
|
+
configuration_defaults = super().configuration_defaults
|
333
|
+
additional_defaults = {
|
334
|
+
"prevalence": self._prevalence_source,
|
335
|
+
"birth_prevalence": self._birth_prevalence_source,
|
336
|
+
"dwell_time": self._dwell_time_source,
|
337
|
+
"disability_weight": self._disability_weight_source,
|
338
|
+
"excess_mortality_rate": self._excess_modality_rate_source,
|
293
339
|
}
|
340
|
+
data_sources = {
|
341
|
+
**configuration_defaults[self.name]["data_sources"],
|
342
|
+
**additional_defaults,
|
343
|
+
}
|
344
|
+
configuration_defaults[self.name]["data_sources"] = data_sources
|
345
|
+
return configuration_defaults
|
294
346
|
|
295
347
|
#####################
|
296
348
|
# Lifecycle methods #
|
@@ -304,6 +356,11 @@ class DiseaseState(BaseDiseaseState):
|
|
304
356
|
cause_type: str = "cause",
|
305
357
|
get_data_functions: dict[str, Callable] | None = None,
|
306
358
|
cleanup_function: Callable | None = None,
|
359
|
+
prevalence: DataInput | None = None,
|
360
|
+
birth_prevalence: DataInput | None = None,
|
361
|
+
dwell_time: DataInput | None = None,
|
362
|
+
disability_weight: DataInput | None = None,
|
363
|
+
excess_mortality_rate: DataInput | None = None,
|
307
364
|
):
|
308
365
|
"""
|
309
366
|
Parameters
|
@@ -322,7 +379,31 @@ class DiseaseState(BaseDiseaseState):
|
|
322
379
|
various state attributes.
|
323
380
|
cleanup_function
|
324
381
|
The cleanup function.
|
382
|
+
prevalence
|
383
|
+
The prevalence source. This is used to initialize simulants. Can be
|
384
|
+
the data itself, a function to retrieve the data, or the artifact
|
385
|
+
key containing the data.
|
386
|
+
birth_prevalence
|
387
|
+
The birth prevalence source. This is used to initialize newborn
|
388
|
+
simulants. Can be the data itself, a function to retrieve the data,
|
389
|
+
or the artifact key containing the data.
|
390
|
+
dwell_time
|
391
|
+
The dwell time source. This is used to determine how long a simulant
|
392
|
+
must remain in the state before transitioning. Can be the data
|
393
|
+
itself, a function to retrieve the data, or the artifact key
|
394
|
+
containing the data.
|
395
|
+
disability_weight
|
396
|
+
The disability weight source. This is used to calculate the
|
397
|
+
disability weight for simulants in this state. Can be the data
|
398
|
+
itself, a function to retrieve the data, or the artifact key
|
399
|
+
containing the data.
|
400
|
+
excess_mortality_rate
|
401
|
+
The excess mortality rate source. This is used to calculate the
|
402
|
+
excess mortality rate for simulants in this state. Can be the data
|
403
|
+
itself, a function to retrieve the data, or the artifact key
|
404
|
+
containing the data.
|
325
405
|
"""
|
406
|
+
|
326
407
|
super().__init__(
|
327
408
|
state_id,
|
328
409
|
allow_self_transition=allow_self_transition,
|
@@ -340,6 +421,34 @@ class DiseaseState(BaseDiseaseState):
|
|
340
421
|
)
|
341
422
|
self._cleanup_function = cleanup_function
|
342
423
|
|
424
|
+
if get_data_functions is not None:
|
425
|
+
warnings.warn(
|
426
|
+
"The argument 'get_data_functions' has been deprecated. Use"
|
427
|
+
" cause_specific_mortality_rate instead.",
|
428
|
+
DeprecationWarning,
|
429
|
+
stacklevel=2,
|
430
|
+
)
|
431
|
+
|
432
|
+
for data_type in self._get_data_functions:
|
433
|
+
try:
|
434
|
+
data_source = locals()[data_type]
|
435
|
+
except KeyError:
|
436
|
+
data_source = None
|
437
|
+
|
438
|
+
if locals()[data_type] is not None:
|
439
|
+
raise DiseaseModelError(
|
440
|
+
f"It is not allowed to pass '{data_type}' both as a"
|
441
|
+
" stand-alone argument and as part of get_data_functions."
|
442
|
+
)
|
443
|
+
|
444
|
+
self._prevalence_source = self.get_prevalence_source(prevalence)
|
445
|
+
self._birth_prevalence_source = self.get_birth_prevalence_source(birth_prevalence)
|
446
|
+
self._dwell_time_source = self.get_dwell_time_source(dwell_time)
|
447
|
+
self._disability_weight_source = self.get_disability_weight_source(disability_weight)
|
448
|
+
self._excess_modality_rate_source = self.get_excess_mortality_rate_source(
|
449
|
+
excess_mortality_rate
|
450
|
+
)
|
451
|
+
|
343
452
|
# noinspection PyAttributeOutsideInit
|
344
453
|
def setup(self, builder: Builder) -> None:
|
345
454
|
"""Performs this component's simulation setup.
|
@@ -377,32 +486,45 @@ class DiseaseState(BaseDiseaseState):
|
|
377
486
|
# Setup methods #
|
378
487
|
#################
|
379
488
|
|
380
|
-
def
|
489
|
+
def _get_data_functions_source(self, data_type: str) -> DataInput:
|
490
|
+
def data_source(builder: Builder) -> LookupTableData:
|
491
|
+
return self._get_data_functions[data_type](builder, self.state_id)
|
492
|
+
|
493
|
+
return data_source
|
494
|
+
|
495
|
+
def get_prevalence_source(self, prevalence: DataInput | None) -> DataInput:
|
381
496
|
if "prevalence" in self._get_data_functions:
|
382
|
-
return self.
|
497
|
+
return self._get_data_functions_source("prevalence")
|
498
|
+
elif prevalence is not None:
|
499
|
+
return prevalence
|
383
500
|
else:
|
384
|
-
return
|
501
|
+
return f"{self.cause_type}.{self.state_id}.prevalence"
|
385
502
|
|
386
|
-
def
|
503
|
+
def get_birth_prevalence_source(self, birth_prevalence: DataInput | None) -> DataInput:
|
387
504
|
if "birth_prevalence" in self._get_data_functions:
|
388
|
-
return self.
|
505
|
+
return self._get_data_functions_source("birth_prevalence")
|
506
|
+
elif birth_prevalence is not None:
|
507
|
+
return birth_prevalence
|
389
508
|
else:
|
390
|
-
return 0
|
509
|
+
return 0.0
|
391
510
|
|
392
|
-
def
|
511
|
+
def get_dwell_time_source(self, dwell_time: DataInput | None) -> DataInput:
|
393
512
|
if "dwell_time" in self._get_data_functions:
|
394
|
-
dwell_time = self.
|
395
|
-
|
396
|
-
dwell_time = 0
|
397
|
-
|
398
|
-
|
399
|
-
|
400
|
-
|
401
|
-
|
402
|
-
|
403
|
-
|
404
|
-
|
405
|
-
|
513
|
+
dwell_time = self._get_data_functions_source("dwell_time")
|
514
|
+
elif dwell_time is None:
|
515
|
+
dwell_time = 0.0
|
516
|
+
|
517
|
+
def dwell_time_source(builder: Builder) -> LookupTableData:
|
518
|
+
dwell_time_ = self.get_data(builder, dwell_time)
|
519
|
+
if isinstance(dwell_time_, pd.Timedelta):
|
520
|
+
dwell_time_ = dwell_time_.total_seconds() / (60 * 60 * 24)
|
521
|
+
if (
|
522
|
+
isinstance(dwell_time_, pd.DataFrame) and np.any(dwell_time_.value != 0)
|
523
|
+
) or dwell_time_ > 0:
|
524
|
+
self.transition_set.allow_null_transition = True
|
525
|
+
return dwell_time_
|
526
|
+
|
527
|
+
return dwell_time_source
|
406
528
|
|
407
529
|
def get_dwell_time_pipeline(self, builder: Builder) -> Pipeline:
|
408
530
|
required_columns = get_lookup_columns([self.lookup_tables["dwell_time"]])
|
@@ -412,20 +534,22 @@ class DiseaseState(BaseDiseaseState):
|
|
412
534
|
requires_columns=required_columns,
|
413
535
|
)
|
414
536
|
|
415
|
-
def
|
537
|
+
def get_disability_weight_source(self, disability_weight: DataInput | None) -> DataInput:
|
416
538
|
if "disability_weight" in self._get_data_functions:
|
417
|
-
disability_weight = self.
|
418
|
-
|
419
|
-
|
539
|
+
disability_weight = self._get_data_functions_source("disability_weight")
|
540
|
+
elif disability_weight is not None:
|
541
|
+
disability_weight = disability_weight
|
420
542
|
else:
|
421
|
-
disability_weight =
|
422
|
-
f"{self.cause_type}.{self.state_id}.disability_weight"
|
423
|
-
)
|
543
|
+
disability_weight = f"{self.cause_type}.{self.state_id}.disability_weight"
|
424
544
|
|
425
|
-
|
426
|
-
|
545
|
+
def disability_weight_source(builder: Builder) -> LookupTableData:
|
546
|
+
disability_weight_ = self.get_data(builder, disability_weight)
|
547
|
+
if isinstance(disability_weight_, pd.DataFrame) and len(disability_weight_) == 1:
|
548
|
+
# sequela only have single value
|
549
|
+
disability_weight_ = disability_weight_.value[0]
|
550
|
+
return disability_weight_
|
427
551
|
|
428
|
-
return
|
552
|
+
return disability_weight_source
|
429
553
|
|
430
554
|
def get_disability_weight_pipeline(self, builder: Builder) -> Pipeline:
|
431
555
|
lookup_columns = get_lookup_columns([self.lookup_tables["disability_weight"]])
|
@@ -435,16 +559,25 @@ class DiseaseState(BaseDiseaseState):
|
|
435
559
|
requires_columns=lookup_columns + ["alive", self.model],
|
436
560
|
)
|
437
561
|
|
438
|
-
def
|
562
|
+
def get_excess_mortality_rate_source(
|
563
|
+
self, excess_mortality_rate: DataInput | None
|
564
|
+
) -> DataInput:
|
439
565
|
if "excess_mortality_rate" in self._get_data_functions:
|
440
|
-
|
441
|
-
elif
|
442
|
-
|
443
|
-
|
566
|
+
excess_mortality_rate = self._get_data_functions_source("excess_mortality_rate")
|
567
|
+
elif excess_mortality_rate is None:
|
568
|
+
excess_mortality_rate = f"{self.cause_type}.{self.state_id}.excess_mortality_rate"
|
569
|
+
|
570
|
+
def excess_mortality_rate_source(builder: Builder) -> LookupTableData:
|
571
|
+
if excess_mortality_rate is not None:
|
572
|
+
return self.get_data(builder, excess_mortality_rate)
|
573
|
+
elif builder.data.load(f"cause.{self.model}.restrictions")["yld_only"]:
|
574
|
+
return 0
|
444
575
|
return builder.data.load(
|
445
576
|
f"{self.cause_type}.{self.state_id}.excess_mortality_rate"
|
446
577
|
)
|
447
578
|
|
579
|
+
return excess_mortality_rate_source
|
580
|
+
|
448
581
|
def get_excess_mortality_rate_pipeline(self, builder: Builder) -> Pipeline:
|
449
582
|
lookup_columns = get_lookup_columns([self.lookup_tables["excess_mortality_rate"]])
|
450
583
|
return builder.value.register_rate_producer(
|
@@ -470,24 +603,27 @@ class DiseaseState(BaseDiseaseState):
|
|
470
603
|
# Public methods #
|
471
604
|
##################
|
472
605
|
|
606
|
+
def has_initialization_weights(self) -> bool:
|
607
|
+
return True
|
608
|
+
|
473
609
|
def add_rate_transition(
|
474
610
|
self,
|
475
611
|
output: BaseDiseaseState,
|
476
612
|
get_data_functions: dict[str, Callable] = None,
|
477
|
-
|
613
|
+
triggered=Trigger.NOT_TRIGGERED,
|
614
|
+
transition_rate: DataInput | None = None,
|
615
|
+
rate_type: str = "transition_rate",
|
478
616
|
) -> RateTransition:
|
479
|
-
if get_data_functions is None:
|
480
|
-
|
481
|
-
|
482
|
-
|
483
|
-
|
484
|
-
|
485
|
-
|
486
|
-
|
487
|
-
|
488
|
-
)
|
489
|
-
raise ValueError("You must supply a transition rate or remission rate function.")
|
490
|
-
return super().add_rate_transition(output, get_data_functions, **kwargs)
|
617
|
+
if get_data_functions is None and transition_rate is None:
|
618
|
+
transition_rate = f"{self.cause_type}.{self.state_id}.remission_rate"
|
619
|
+
rate_type = "remission_rate"
|
620
|
+
return super().add_rate_transition(
|
621
|
+
output=output,
|
622
|
+
get_data_functions=get_data_functions,
|
623
|
+
triggered=triggered,
|
624
|
+
transition_rate=transition_rate,
|
625
|
+
rate_type=rate_type,
|
626
|
+
)
|
491
627
|
|
492
628
|
def add_dwell_time_transition(
|
493
629
|
self,
|
@@ -6,7 +6,7 @@ Disease Transitions
|
|
6
6
|
This module contains tools to model transitions between disease states.
|
7
7
|
|
8
8
|
"""
|
9
|
-
|
9
|
+
import warnings
|
10
10
|
from collections.abc import Callable
|
11
11
|
from typing import TYPE_CHECKING, Any
|
12
12
|
|
@@ -15,7 +15,9 @@ from vivarium.framework.engine import Builder
|
|
15
15
|
from vivarium.framework.state_machine import Transition, Trigger
|
16
16
|
from vivarium.framework.utilities import rate_to_probability
|
17
17
|
from vivarium.framework.values import list_combiner, union_post_processor
|
18
|
+
from vivarium.types import DataInput
|
18
19
|
|
20
|
+
from vivarium_public_health.disease.exceptions import DiseaseModelError
|
19
21
|
from vivarium_public_health.utilities import get_lookup_columns
|
20
22
|
|
21
23
|
if TYPE_CHECKING:
|
@@ -44,7 +46,7 @@ class RateTransition(Transition):
|
|
44
46
|
return {
|
45
47
|
f"{self.name}": {
|
46
48
|
"data_sources": {
|
47
|
-
"transition_rate": self.
|
49
|
+
"transition_rate": self._rate_source,
|
48
50
|
},
|
49
51
|
},
|
50
52
|
}
|
@@ -55,19 +57,37 @@ class RateTransition(Transition):
|
|
55
57
|
|
56
58
|
@property
|
57
59
|
def transition_rate_pipeline_name(self) -> str:
|
58
|
-
if
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
60
|
+
if self._get_data_functions:
|
61
|
+
if "incidence_rate" in self._get_data_functions:
|
62
|
+
pipeline_name = f"{self.output_state.state_id}.incidence_rate"
|
63
|
+
elif "remission_rate" in self._get_data_functions:
|
64
|
+
pipeline_name = f"{self.input_state.state_id}.remission_rate"
|
65
|
+
elif "transition_rate" in self._get_data_functions:
|
66
|
+
pipeline_name = (
|
67
|
+
f"{self.input_state.state_id}_to_{self.output_state.state_id}"
|
68
|
+
".transition_rate"
|
69
|
+
)
|
70
|
+
else:
|
71
|
+
raise DiseaseModelError(
|
72
|
+
"Cannot determine rate_transition pipeline name: "
|
73
|
+
"no valid data functions supplied."
|
74
|
+
)
|
66
75
|
else:
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
76
|
+
if self.rate_type == "incidence_rate":
|
77
|
+
pipeline_name = f"{self.output_state.state_id}.incidence_rate"
|
78
|
+
elif self.rate_type == "remission_rate":
|
79
|
+
pipeline_name = f"{self.input_state.state_id}.remission_rate"
|
80
|
+
elif self.rate_type == "transition_rate":
|
81
|
+
pipeline_name = (
|
82
|
+
f"{self.input_state.state_id}_to_{self.output_state.state_id}"
|
83
|
+
".transition_rate"
|
84
|
+
)
|
85
|
+
else:
|
86
|
+
raise DiseaseModelError(
|
87
|
+
"Cannot determine rate_transition pipeline name: invalid"
|
88
|
+
f" rate_type '{self.rate_type} supplied."
|
89
|
+
)
|
90
|
+
|
71
91
|
return pipeline_name
|
72
92
|
|
73
93
|
#####################
|
@@ -80,6 +100,8 @@ class RateTransition(Transition):
|
|
80
100
|
output_state: "BaseDiseaseState",
|
81
101
|
get_data_functions: dict[str, Callable] = None,
|
82
102
|
triggered=Trigger.NOT_TRIGGERED,
|
103
|
+
transition_rate: DataInput | None = None,
|
104
|
+
rate_type: str = "transition_rate",
|
83
105
|
):
|
84
106
|
super().__init__(
|
85
107
|
input_state, output_state, probability_func=self._probability, triggered=triggered
|
@@ -87,6 +109,22 @@ class RateTransition(Transition):
|
|
87
109
|
self._get_data_functions = (
|
88
110
|
get_data_functions if get_data_functions is not None else {}
|
89
111
|
)
|
112
|
+
self._rate_source = self._get_rate_source(transition_rate)
|
113
|
+
self.rate_type = rate_type
|
114
|
+
|
115
|
+
if get_data_functions is not None:
|
116
|
+
warnings.warn(
|
117
|
+
"The argument 'get_data_functions' has been deprecated. Use"
|
118
|
+
" 'transition_rate' instead.",
|
119
|
+
DeprecationWarning,
|
120
|
+
stacklevel=2,
|
121
|
+
)
|
122
|
+
if transition_rate is not None:
|
123
|
+
raise DiseaseModelError(
|
124
|
+
"It is not allowed to pass a transition rate"
|
125
|
+
" both as a stand-alone argument and as part of"
|
126
|
+
" get_data_functions."
|
127
|
+
)
|
90
128
|
|
91
129
|
# noinspection PyAttributeOutsideInit
|
92
130
|
def setup(self, builder: Builder) -> None:
|
@@ -109,17 +147,19 @@ class RateTransition(Transition):
|
|
109
147
|
# Setup methods #
|
110
148
|
#################
|
111
149
|
|
112
|
-
def
|
113
|
-
if
|
114
|
-
rate_data =
|
150
|
+
def _get_rate_source(self, transition_rate: DataInput | None) -> DataInput:
|
151
|
+
if transition_rate is not None:
|
152
|
+
rate_data = transition_rate
|
153
|
+
elif "incidence_rate" in self._get_data_functions:
|
154
|
+
rate_data = lambda builder: self._get_data_functions["incidence_rate"](
|
115
155
|
builder, self.output_state.state_id
|
116
156
|
)
|
117
157
|
elif "remission_rate" in self._get_data_functions:
|
118
|
-
rate_data = self._get_data_functions["remission_rate"](
|
158
|
+
rate_data = lambda builder: self._get_data_functions["remission_rate"](
|
119
159
|
builder, self.input_state.state_id
|
120
160
|
)
|
121
161
|
elif "transition_rate" in self._get_data_functions:
|
122
|
-
rate_data = self._get_data_functions["transition_rate"](
|
162
|
+
rate_data = lambda builder: self._get_data_functions["transition_rate"](
|
123
163
|
builder, self.input_state.state_id, self.output_state.state_id
|
124
164
|
)
|
125
165
|
else:
|
@@ -172,21 +212,38 @@ class ProportionTransition(Transition):
|
|
172
212
|
output_state: "BaseDiseaseState",
|
173
213
|
get_data_functions: dict[str, Callable] = None,
|
174
214
|
triggered=Trigger.NOT_TRIGGERED,
|
215
|
+
proportion: DataInput | None = None,
|
175
216
|
):
|
176
217
|
super().__init__(
|
177
218
|
input_state, output_state, probability_func=self._probability, triggered=triggered
|
178
219
|
)
|
220
|
+
self._proportion_source = proportion
|
179
221
|
self._get_data_functions = (
|
180
222
|
get_data_functions if get_data_functions is not None else {}
|
181
223
|
)
|
182
224
|
|
225
|
+
if get_data_functions is not None:
|
226
|
+
warnings.warn(
|
227
|
+
"The argument 'get_data_functions' has been deprecated. Use"
|
228
|
+
" 'proportion' instead.",
|
229
|
+
DeprecationWarning,
|
230
|
+
stacklevel=2,
|
231
|
+
)
|
232
|
+
if proportion is not None:
|
233
|
+
raise DiseaseModelError(
|
234
|
+
"It is not allowed to pass a proportion both as a"
|
235
|
+
" stand-alone argument and as part of get_data_functions."
|
236
|
+
)
|
237
|
+
|
183
238
|
#################
|
184
239
|
# Setup methods #
|
185
240
|
#################
|
186
241
|
|
187
|
-
def load_proportion(self, builder: Builder) ->
|
242
|
+
def load_proportion(self, builder: Builder) -> DataInput:
|
243
|
+
if self._proportion_source is not None:
|
244
|
+
return self._proportion_source
|
188
245
|
if "proportion" not in self._get_data_functions:
|
189
|
-
raise
|
246
|
+
raise DiseaseModelError("Must supply a proportion function")
|
190
247
|
return self._get_data_functions["proportion"](builder, self.output_state.state_id)
|
191
248
|
|
192
249
|
def _probability(self, index):
|
@@ -17,6 +17,7 @@ import scipy
|
|
17
17
|
from layered_config_tree import ConfigurationError
|
18
18
|
from vivarium import Component
|
19
19
|
from vivarium.framework.engine import Builder
|
20
|
+
from vivarium.framework.values import Pipeline
|
20
21
|
|
21
22
|
from vivarium_public_health.risks import Risk
|
22
23
|
from vivarium_public_health.risks.data_transformations import (
|
@@ -112,7 +113,8 @@ class RiskEffect(Component):
|
|
112
113
|
def setup(self, builder: Builder) -> None:
|
113
114
|
self.exposure = self.get_risk_exposure(builder)
|
114
115
|
|
115
|
-
self.
|
116
|
+
self._relative_risk_source = self.get_relative_risk_source(builder)
|
117
|
+
self.relative_risk = self.get_relative_risk(builder)
|
116
118
|
|
117
119
|
self.register_target_modifier(builder)
|
118
120
|
self.register_paf_modifier(builder)
|
@@ -124,7 +126,7 @@ class RiskEffect(Component):
|
|
124
126
|
def build_all_lookup_tables(self, builder: Builder) -> None:
|
125
127
|
self._exposure_distribution_type = self.get_distribution_type(builder)
|
126
128
|
|
127
|
-
rr_data = self.
|
129
|
+
rr_data = self.load_relative_risk(builder)
|
128
130
|
rr_value_cols = None
|
129
131
|
if self.is_exposure_categorical:
|
130
132
|
rr_data, rr_value_cols = self.process_categorical_data(builder, rr_data)
|
@@ -146,7 +148,7 @@ class RiskEffect(Component):
|
|
146
148
|
return risk_exposure_component.distribution_type
|
147
149
|
return risk_exposure_component.get_distribution_type(builder)
|
148
150
|
|
149
|
-
def
|
151
|
+
def load_relative_risk(
|
150
152
|
self,
|
151
153
|
builder: Builder,
|
152
154
|
configuration=None,
|
@@ -260,24 +262,27 @@ class RiskEffect(Component):
|
|
260
262
|
def get_risk_exposure(self, builder: Builder) -> Callable[[pd.Index], pd.Series]:
|
261
263
|
return builder.value.get_value(self.exposure_pipeline_name)
|
262
264
|
|
263
|
-
def
|
264
|
-
|
265
|
-
|
265
|
+
def adjust_target(self, index: pd.Index, target: pd.Series) -> pd.Series:
|
266
|
+
relative_risk = self.relative_risk(index)
|
267
|
+
return target * relative_risk
|
268
|
+
|
269
|
+
def get_relative_risk_source(self, builder: Builder) -> Callable[[pd.Index], pd.Series]:
|
270
|
+
|
266
271
|
if not self.is_exposure_categorical:
|
267
272
|
tmred = builder.data.load(f"{self.risk}.tmred")
|
268
273
|
tmrel = 0.5 * (tmred["min"] + tmred["max"])
|
269
274
|
scale = builder.data.load(f"{self.risk}.relative_risk_scalar")
|
270
275
|
|
271
|
-
def
|
276
|
+
def generate_relative_risk(index: pd.Index) -> pd.Series:
|
272
277
|
rr = self.lookup_tables["relative_risk"](index)
|
273
278
|
exposure = self.exposure(index)
|
274
279
|
relative_risk = np.maximum(rr.values ** ((exposure - tmrel) / scale), 1)
|
275
|
-
return
|
280
|
+
return relative_risk
|
276
281
|
|
277
282
|
else:
|
278
283
|
index_columns = ["index", self.risk.name]
|
279
284
|
|
280
|
-
def
|
285
|
+
def generate_relative_risk(index: pd.Index) -> pd.Series:
|
281
286
|
rr = self.lookup_tables["relative_risk"](index)
|
282
287
|
exposure = self.exposure(index).reset_index()
|
283
288
|
exposure.columns = index_columns
|
@@ -288,16 +293,24 @@ class RiskEffect(Component):
|
|
288
293
|
relative_risk = relative_risk.set_index(index_columns)
|
289
294
|
|
290
295
|
effect = relative_risk.loc[exposure.index, "value"].droplevel(self.risk.name)
|
291
|
-
|
292
|
-
return affected_rates
|
296
|
+
return effect
|
293
297
|
|
294
|
-
return
|
298
|
+
return generate_relative_risk
|
299
|
+
|
300
|
+
def get_relative_risk(self, builder: Builder) -> Pipeline:
|
301
|
+
return builder.value.register_value_producer(
|
302
|
+
f"{self.risk.name}_on_{self.target.name}.relative_risk",
|
303
|
+
self._relative_risk_source,
|
304
|
+
component=self,
|
305
|
+
required_resources=[self.exposure],
|
306
|
+
)
|
295
307
|
|
296
308
|
def register_target_modifier(self, builder: Builder) -> None:
|
297
309
|
builder.value.register_value_modifier(
|
298
310
|
self.target_pipeline_name,
|
299
|
-
modifier=self.
|
300
|
-
|
311
|
+
modifier=self.adjust_target,
|
312
|
+
component=self,
|
313
|
+
required_resources=[self.relative_risk],
|
301
314
|
)
|
302
315
|
|
303
316
|
def register_paf_modifier(self, builder: Builder) -> None:
|
@@ -307,7 +320,8 @@ class RiskEffect(Component):
|
|
307
320
|
builder.value.register_value_modifier(
|
308
321
|
self.target_paf_pipeline_name,
|
309
322
|
modifier=self.lookup_tables["population_attributable_fraction"],
|
310
|
-
|
323
|
+
component=self,
|
324
|
+
required_resources=required_columns,
|
311
325
|
)
|
312
326
|
|
313
327
|
##################
|
@@ -371,7 +385,7 @@ class NonLogLinearRiskEffect(RiskEffect):
|
|
371
385
|
return f"non_log_linear_risk_effect.{risk.name}_on_{target}"
|
372
386
|
|
373
387
|
def build_all_lookup_tables(self, builder: Builder) -> None:
|
374
|
-
rr_data = self.
|
388
|
+
rr_data = self.load_relative_risk(builder)
|
375
389
|
self.validate_rr_data(rr_data)
|
376
390
|
|
377
391
|
def define_rr_intervals(df: pd.DataFrame) -> pd.DataFrame:
|
@@ -415,7 +429,7 @@ class NonLogLinearRiskEffect(RiskEffect):
|
|
415
429
|
builder, paf_data
|
416
430
|
)
|
417
431
|
|
418
|
-
def
|
432
|
+
def load_relative_risk(
|
419
433
|
self,
|
420
434
|
builder: Builder,
|
421
435
|
configuration=None,
|
@@ -472,10 +486,8 @@ class NonLogLinearRiskEffect(RiskEffect):
|
|
472
486
|
|
473
487
|
return rr_data
|
474
488
|
|
475
|
-
def
|
476
|
-
|
477
|
-
) -> Callable[[pd.Index, pd.Series], pd.Series]:
|
478
|
-
def adjust_target(index: pd.Index, target: pd.Series) -> pd.Series:
|
489
|
+
def get_relative_risk_source(self, builder: Builder) -> Callable[[pd.Index], pd.Series]:
|
490
|
+
def generate_relative_risk(index: pd.Index) -> pd.Series:
|
479
491
|
rr_intervals = self.lookup_tables["relative_risk"](index)
|
480
492
|
exposure = self.population_view.get(index)[f"{self.risk.name}_exposure"]
|
481
493
|
x1, x2 = (
|
@@ -486,9 +498,9 @@ class NonLogLinearRiskEffect(RiskEffect):
|
|
486
498
|
m = (y2 - y1) / (x2 - x1)
|
487
499
|
b = y1 - m * x1
|
488
500
|
relative_risk = b + m * exposure
|
489
|
-
return
|
501
|
+
return relative_risk
|
490
502
|
|
491
|
-
return
|
503
|
+
return generate_relative_risk
|
492
504
|
|
493
505
|
##############
|
494
506
|
# Validators #
|
@@ -336,7 +336,6 @@ class LBWSGRiskEffect(RiskEffect):
|
|
336
336
|
|
337
337
|
super().setup(builder)
|
338
338
|
self.interpolator = self.get_interpolator(builder)
|
339
|
-
self.relative_risk = self.get_relative_risk_pipeline(builder)
|
340
339
|
|
341
340
|
#################
|
342
341
|
# Setup methods #
|
@@ -393,10 +392,10 @@ class LBWSGRiskEffect(RiskEffect):
|
|
393
392
|
for age_start in exposed_age_group_starts
|
394
393
|
}
|
395
394
|
|
396
|
-
def
|
395
|
+
def get_relative_risk(self, builder: Builder) -> Pipeline:
|
397
396
|
return builder.value.register_value_producer(
|
398
397
|
self.relative_risk_pipeline_name,
|
399
|
-
source=self.
|
398
|
+
source=self.get_relative_risk_source,
|
400
399
|
requires_columns=["age"] + self.rr_column_names,
|
401
400
|
)
|
402
401
|
|
@@ -470,7 +469,7 @@ class LBWSGRiskEffect(RiskEffect):
|
|
470
469
|
# Pipeline sources and modifiers #
|
471
470
|
##################################
|
472
471
|
|
473
|
-
def
|
472
|
+
def get_relative_risk_source(self, index: pd.Index) -> pd.Series:
|
474
473
|
pop = self.population_view.get(index)
|
475
474
|
relative_risk = pd.Series(1.0, index=index, name=self.relative_risk_pipeline_name)
|
476
475
|
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: vivarium_public_health
|
3
|
-
Version: 3.1.
|
3
|
+
Version: 3.1.4
|
4
4
|
Summary: Components for modelling diseases, risks, and interventions with ``vivarium``
|
5
5
|
Home-page: https://github.com/ihmeuw/vivarium_public_health
|
6
6
|
Author: The vivarium developers
|
@@ -26,44 +26,44 @@ Classifier: Topic :: Scientific/Engineering :: Medical Science Apps.
|
|
26
26
|
Classifier: Topic :: Scientific/Engineering :: Physics
|
27
27
|
Classifier: Topic :: Software Development :: Libraries
|
28
28
|
License-File: LICENSE.txt
|
29
|
-
Requires-Dist: vivarium
|
30
|
-
Requires-Dist:
|
29
|
+
Requires-Dist: vivarium>=3.2.3
|
30
|
+
Requires-Dist: layered_config_tree>=1.0.1
|
31
31
|
Requires-Dist: loguru
|
32
|
-
Requires-Dist: numpy
|
32
|
+
Requires-Dist: numpy<2.0.0
|
33
33
|
Requires-Dist: pandas
|
34
34
|
Requires-Dist: scipy
|
35
35
|
Requires-Dist: tables
|
36
|
-
Requires-Dist:
|
36
|
+
Requires-Dist: risk_distributions>=2.0.11
|
37
37
|
Requires-Dist: pyarrow
|
38
|
-
Provides-Extra: dev
|
39
|
-
Requires-Dist: sphinx <7.0 ; extra == 'dev'
|
40
|
-
Requires-Dist: sphinx-rtd-theme ; extra == 'dev'
|
41
|
-
Requires-Dist: sphinx-click ; extra == 'dev'
|
42
|
-
Requires-Dist: sphinx-autodoc-typehints ; extra == 'dev'
|
43
|
-
Requires-Dist: IPython ; extra == 'dev'
|
44
|
-
Requires-Dist: matplotlib ; extra == 'dev'
|
45
|
-
Requires-Dist: vivarium-testing-utils ; extra == 'dev'
|
46
|
-
Requires-Dist: pytest ; extra == 'dev'
|
47
|
-
Requires-Dist: pytest-cov ; extra == 'dev'
|
48
|
-
Requires-Dist: pytest-mock ; extra == 'dev'
|
49
|
-
Requires-Dist: hypothesis ; extra == 'dev'
|
50
|
-
Requires-Dist: pyyaml ; extra == 'dev'
|
51
|
-
Requires-Dist: black ==22.3.0 ; extra == 'dev'
|
52
|
-
Requires-Dist: isort ; extra == 'dev'
|
53
38
|
Provides-Extra: docs
|
54
|
-
Requires-Dist: sphinx
|
55
|
-
Requires-Dist: sphinx-rtd-theme
|
56
|
-
Requires-Dist: sphinx-click
|
57
|
-
Requires-Dist: sphinx-autodoc-typehints
|
58
|
-
Requires-Dist: IPython
|
59
|
-
Requires-Dist: matplotlib
|
39
|
+
Requires-Dist: sphinx<7.0; extra == "docs"
|
40
|
+
Requires-Dist: sphinx-rtd-theme; extra == "docs"
|
41
|
+
Requires-Dist: sphinx-click; extra == "docs"
|
42
|
+
Requires-Dist: sphinx-autodoc-typehints; extra == "docs"
|
43
|
+
Requires-Dist: IPython; extra == "docs"
|
44
|
+
Requires-Dist: matplotlib; extra == "docs"
|
60
45
|
Provides-Extra: test
|
61
|
-
Requires-Dist:
|
62
|
-
Requires-Dist: pytest
|
63
|
-
Requires-Dist: pytest-cov
|
64
|
-
Requires-Dist: pytest-mock
|
65
|
-
Requires-Dist: hypothesis
|
66
|
-
Requires-Dist: pyyaml
|
46
|
+
Requires-Dist: vivarium_testing_utils; extra == "test"
|
47
|
+
Requires-Dist: pytest; extra == "test"
|
48
|
+
Requires-Dist: pytest-cov; extra == "test"
|
49
|
+
Requires-Dist: pytest-mock; extra == "test"
|
50
|
+
Requires-Dist: hypothesis; extra == "test"
|
51
|
+
Requires-Dist: pyyaml; extra == "test"
|
52
|
+
Provides-Extra: dev
|
53
|
+
Requires-Dist: sphinx<7.0; extra == "dev"
|
54
|
+
Requires-Dist: sphinx-rtd-theme; extra == "dev"
|
55
|
+
Requires-Dist: sphinx-click; extra == "dev"
|
56
|
+
Requires-Dist: sphinx-autodoc-typehints; extra == "dev"
|
57
|
+
Requires-Dist: IPython; extra == "dev"
|
58
|
+
Requires-Dist: matplotlib; extra == "dev"
|
59
|
+
Requires-Dist: vivarium_testing_utils; extra == "dev"
|
60
|
+
Requires-Dist: pytest; extra == "dev"
|
61
|
+
Requires-Dist: pytest-cov; extra == "dev"
|
62
|
+
Requires-Dist: pytest-mock; extra == "dev"
|
63
|
+
Requires-Dist: hypothesis; extra == "dev"
|
64
|
+
Requires-Dist: pyyaml; extra == "dev"
|
65
|
+
Requires-Dist: black==22.3.0; extra == "dev"
|
66
|
+
Requires-Dist: isort; extra == "dev"
|
67
67
|
|
68
68
|
Vivarium Public Health
|
69
69
|
======================
|
@@ -1,13 +1,14 @@
|
|
1
1
|
vivarium_public_health/__about__.py,sha256=RgWycPypKZS80TpSX7o41cREnG8PfguNHDHLuLyl820,487
|
2
2
|
vivarium_public_health/__init__.py,sha256=GDeeP-7OlCBwPuv_xQoB1wNmvCaFsqfTB7qnnYApm0w,1343
|
3
|
-
vivarium_public_health/_version.py,sha256=
|
3
|
+
vivarium_public_health/_version.py,sha256=2uixSkocUHf2KiY1oTfzz_5AQGmlrHnypVxGgr4mV9c,22
|
4
4
|
vivarium_public_health/utilities.py,sha256=QNXQ6fhAr1HcV-GwKw7wQLz6QyuNxqNvMA-XujKjTgs,3035
|
5
5
|
vivarium_public_health/disease/__init__.py,sha256=VUJHDLlE6ngo2qHNQUtZ8OWH5H_T7_ao-xsYKDkRmHw,443
|
6
|
-
vivarium_public_health/disease/
|
6
|
+
vivarium_public_health/disease/exceptions.py,sha256=vb30IIV82OiDf2cNZCs_E2rF6mdDDHbnZSND60no5CU,97
|
7
|
+
vivarium_public_health/disease/model.py,sha256=qA0mhbIL0rfPtJ1Csmvxvf757vUdrjcbpxYmL_dsniI,9175
|
7
8
|
vivarium_public_health/disease/models.py,sha256=01UK7yB2zGPFzmlIpvhd-XnGe6vSCMDza3QTidgY7Nc,3479
|
8
9
|
vivarium_public_health/disease/special_disease.py,sha256=kTVuE5rQjUK62ysComG8nB2f61aCKdca9trRB1zsDCQ,14537
|
9
|
-
vivarium_public_health/disease/state.py,sha256=
|
10
|
-
vivarium_public_health/disease/transition.py,sha256=
|
10
|
+
vivarium_public_health/disease/state.py,sha256=l1iBZl8IpNOegQjZIY2xq09SBhsn5jSoYD89CymJAUs,28309
|
11
|
+
vivarium_public_health/disease/transition.py,sha256=qRpHzq29S_I_KzyDncAahg27bnHNdrA5tsUWuJ4Zty4,8964
|
11
12
|
vivarium_public_health/mslt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
12
13
|
vivarium_public_health/mslt/delay.py,sha256=TrFp9nmv-PSu6xdQXfyeM4X8RrXTLUnQMxOuhdvD-Wk,22787
|
13
14
|
vivarium_public_health/mslt/disease.py,sha256=DONSItBiOUk1qBE6Msw0vrV0XLW4BPzWVxwVK90GK1I,16638
|
@@ -35,15 +36,15 @@ vivarium_public_health/risks/__init__.py,sha256=z8DcnZGxqNVAyFZm2WAV-IVNGvrSS4iz
|
|
35
36
|
vivarium_public_health/risks/base_risk.py,sha256=XQ_7rYJS5gh0coEKDqcc_zYdjPDBZlj6-THsIQxL3zs,10888
|
36
37
|
vivarium_public_health/risks/data_transformations.py,sha256=SgdPKc95BBqgMNUdlAQM8k6iaXcpxnjk5B2ySTES1Yg,9269
|
37
38
|
vivarium_public_health/risks/distributions.py,sha256=a63-ihg2itxqgowDZbUix8soErxs_y8TRwsdtTCIUU4,18121
|
38
|
-
vivarium_public_health/risks/effect.py,sha256=
|
39
|
+
vivarium_public_health/risks/effect.py,sha256=2DaBKxncS94cm8Ih-TQtbV1mGsEZhx6fEnB5V_ocIZM,21241
|
39
40
|
vivarium_public_health/risks/implementations/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
40
|
-
vivarium_public_health/risks/implementations/low_birth_weight_and_short_gestation.py,sha256=
|
41
|
+
vivarium_public_health/risks/implementations/low_birth_weight_and_short_gestation.py,sha256=Iw5e7sCrWplEu6dic1Wu6q9e0rdeem_6gQ_-N52G54E,17866
|
41
42
|
vivarium_public_health/treatment/__init__.py,sha256=wONElu9aJbBYwpYIovYPYaN_GYfVhPXtTeFWSdQMgA0,222
|
42
43
|
vivarium_public_health/treatment/magic_wand.py,sha256=zGIhrNgB9q6JD7fHlvbDQb3H5e_N_QsROO4Y0kl_JQM,1955
|
43
44
|
vivarium_public_health/treatment/scale_up.py,sha256=hVz0ELXDqlpcExI31rKdepxqcW_hy2hZSa6qCzv6udU,7020
|
44
45
|
vivarium_public_health/treatment/therapeutic_inertia.py,sha256=8Z97s7GfcpfLu1U1ESJSqeEk4L__a3M0GbBV21MFg2s,2346
|
45
|
-
vivarium_public_health-3.1.
|
46
|
-
vivarium_public_health-3.1.
|
47
|
-
vivarium_public_health-3.1.
|
48
|
-
vivarium_public_health-3.1.
|
49
|
-
vivarium_public_health-3.1.
|
46
|
+
vivarium_public_health-3.1.4.dist-info/LICENSE.txt,sha256=mN4bNLUQNcN9njYRc_3jCZkfPySVpmM6MRps104FxA4,1548
|
47
|
+
vivarium_public_health-3.1.4.dist-info/METADATA,sha256=7h8TxTjuThUizK3wr0mzLB31wwS4Kqr9SVgJA9Gf7gA,4028
|
48
|
+
vivarium_public_health-3.1.4.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
|
49
|
+
vivarium_public_health-3.1.4.dist-info/top_level.txt,sha256=VVInlpzCFD0UNNhjOq_j-a29odzjwUwYFTGfvqbi4dY,23
|
50
|
+
vivarium_public_health-3.1.4.dist-info/RECORD,,
|
{vivarium_public_health-3.1.2.dist-info → vivarium_public_health-3.1.4.dist-info}/LICENSE.txt
RENAMED
File without changes
|
{vivarium_public_health-3.1.2.dist-info → vivarium_public_health-3.1.4.dist-info}/top_level.txt
RENAMED
File without changes
|