vivarium-public-health 3.1.2__py3-none-any.whl → 3.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vivarium_public_health/_version.py +1 -1
- vivarium_public_health/disease/exceptions.py +5 -0
- vivarium_public_health/disease/model.py +116 -112
- vivarium_public_health/disease/state.py +215 -79
- vivarium_public_health/disease/transition.py +78 -21
- vivarium_public_health/risks/effect.py +35 -23
- vivarium_public_health/risks/implementations/low_birth_weight_and_short_gestation.py +3 -4
- {vivarium_public_health-3.1.2.dist-info → vivarium_public_health-3.1.4.dist-info}/METADATA +32 -32
- {vivarium_public_health-3.1.2.dist-info → vivarium_public_health-3.1.4.dist-info}/RECORD +12 -11
- {vivarium_public_health-3.1.2.dist-info → vivarium_public_health-3.1.4.dist-info}/WHEEL +1 -1
- {vivarium_public_health-3.1.2.dist-info → vivarium_public_health-3.1.4.dist-info}/LICENSE.txt +0 -0
- {vivarium_public_health-3.1.2.dist-info → vivarium_public_health-3.1.4.dist-info}/top_level.txt +0 -0
@@ -1 +1 @@
|
|
1
|
-
__version__ = "3.1.
|
1
|
+
__version__ = "3.1.4"
|
@@ -10,24 +10,20 @@ transitions at simulation initialization and during transitions.
|
|
10
10
|
"""
|
11
11
|
import warnings
|
12
12
|
from collections.abc import Callable, Iterable
|
13
|
+
from functools import partial
|
13
14
|
from typing import Any
|
14
15
|
|
15
|
-
import numpy as np
|
16
16
|
import pandas as pd
|
17
|
-
from vivarium.exceptions import VivariumError
|
18
17
|
from vivarium.framework.engine import Builder
|
19
|
-
from vivarium.framework.event import Event
|
20
18
|
from vivarium.framework.population import SimulantData
|
21
19
|
from vivarium.framework.state_machine import Machine
|
20
|
+
from vivarium.types import DataInput, LookupTableData
|
22
21
|
|
22
|
+
from vivarium_public_health.disease.exceptions import DiseaseModelError
|
23
23
|
from vivarium_public_health.disease.state import BaseDiseaseState, SusceptibleState
|
24
24
|
from vivarium_public_health.disease.transition import TransitionString
|
25
25
|
|
26
26
|
|
27
|
-
class DiseaseModelError(VivariumError):
|
28
|
-
pass
|
29
|
-
|
30
|
-
|
31
27
|
class DiseaseModel(Machine):
|
32
28
|
|
33
29
|
##############
|
@@ -82,57 +78,77 @@ class DiseaseModel(Machine):
|
|
82
78
|
cause_type: str = "cause",
|
83
79
|
states: Iterable[BaseDiseaseState] = (),
|
84
80
|
residual_state: BaseDiseaseState | None = None,
|
85
|
-
|
86
|
-
):
|
87
|
-
super().__init__(cause, states=states
|
81
|
+
cause_specific_mortality_rate: DataInput | None = None,
|
82
|
+
) -> None:
|
83
|
+
super().__init__(cause, states=states)
|
88
84
|
self.cause = cause
|
89
85
|
self.cause_type = cause_type
|
86
|
+
self.residual_state = self._get_residual_state(initial_state, residual_state)
|
87
|
+
self._csmr_source = cause_specific_mortality_rate
|
88
|
+
self._get_data_functions = (
|
89
|
+
get_data_functions if get_data_functions is not None else {}
|
90
|
+
)
|
90
91
|
|
91
|
-
if
|
92
|
+
if get_data_functions is not None:
|
92
93
|
warnings.warn(
|
93
|
-
"
|
94
|
-
"
|
95
|
-
" retain the current behavior of defining a residual state, use"
|
96
|
-
" the 'residual_state' argument.",
|
94
|
+
"The argument 'get_data_functions' has been deprecated. Use"
|
95
|
+
" cause_specific_mortality_rate instead.",
|
97
96
|
DeprecationWarning,
|
98
97
|
stacklevel=2,
|
99
98
|
)
|
100
99
|
|
101
|
-
if
|
100
|
+
if cause_specific_mortality_rate is not None:
|
102
101
|
raise DiseaseModelError(
|
103
|
-
"
|
104
|
-
"
|
102
|
+
"It is not allowed to pass cause_specific_mortality_rate"
|
103
|
+
" both as a stand-alone argument and as part of"
|
104
|
+
" get_data_functions."
|
105
105
|
)
|
106
106
|
|
107
|
-
self.residual_state = initial_state.state_id
|
108
|
-
elif residual_state is not None:
|
109
|
-
self.residual_state = residual_state.state_id
|
110
|
-
else:
|
111
|
-
self.residual_state = self._get_default_residual_state()
|
112
|
-
|
113
|
-
self._get_data_functions = (
|
114
|
-
get_data_functions if get_data_functions is not None else {}
|
115
|
-
)
|
116
|
-
|
117
107
|
def setup(self, builder: Builder) -> None:
|
118
108
|
"""Perform this component's setup."""
|
119
109
|
super().setup(builder)
|
120
110
|
|
121
111
|
self.configuration_age_start = builder.configuration.population.initialization_age_min
|
122
112
|
self.configuration_age_end = builder.configuration.population.initialization_age_max
|
113
|
+
|
123
114
|
builder.value.register_value_modifier(
|
124
115
|
"cause_specific_mortality_rate",
|
125
116
|
self.adjust_cause_specific_mortality_rate,
|
126
117
|
requires_columns=["age", "sex"],
|
127
118
|
)
|
128
|
-
|
119
|
+
|
120
|
+
def on_initialize_simulants(self, pop_data: SimulantData) -> None:
|
121
|
+
"""Initialize the simulants in the population.
|
122
|
+
|
123
|
+
If all simulants are initialized at age 0, birth prevalence is used.
|
124
|
+
Otherwise, prevalence is used.
|
125
|
+
|
126
|
+
Parameters
|
127
|
+
----------
|
128
|
+
pop_data
|
129
|
+
The population data object.
|
130
|
+
"""
|
131
|
+
if pop_data.user_data.get("age_end", self.configuration_age_end) == 0:
|
132
|
+
initialization_table_name = "birth_prevalence"
|
133
|
+
else:
|
134
|
+
initialization_table_name = "prevalence"
|
135
|
+
|
136
|
+
for state in self.states:
|
137
|
+
state.lookup_tables["initialization_weights"] = state.lookup_tables[
|
138
|
+
initialization_table_name
|
139
|
+
]
|
140
|
+
|
141
|
+
super().on_initialize_simulants(pop_data)
|
129
142
|
|
130
143
|
#################
|
131
144
|
# Setup methods #
|
132
145
|
#################
|
133
146
|
|
134
147
|
def load_cause_specific_mortality_rate(self, builder: Builder) -> float | pd.DataFrame:
|
135
|
-
if
|
148
|
+
if (
|
149
|
+
"cause_specific_mortality_rate" not in self._get_data_functions
|
150
|
+
and self._csmr_source is None
|
151
|
+
):
|
136
152
|
only_morbid = builder.data.load(f"cause.{self.cause}.restrictions")["yld_only"]
|
137
153
|
if only_morbid:
|
138
154
|
csmr_data = 0.0
|
@@ -140,64 +156,14 @@ class DiseaseModel(Machine):
|
|
140
156
|
csmr_data = builder.data.load(
|
141
157
|
f"{self.cause_type}.{self.cause}.cause_specific_mortality_rate"
|
142
158
|
)
|
159
|
+
elif self._csmr_source is not None:
|
160
|
+
csmr_data = self.get_data(builder, self._csmr_source)
|
143
161
|
else:
|
144
162
|
csmr_data = self._get_data_functions["cause_specific_mortality_rate"](
|
145
163
|
self.cause, builder
|
146
164
|
)
|
147
165
|
return csmr_data
|
148
166
|
|
149
|
-
########################
|
150
|
-
# Event-driven methods #
|
151
|
-
########################
|
152
|
-
|
153
|
-
def on_initialize_simulants(self, pop_data: SimulantData) -> None:
|
154
|
-
population = self.population_view.subview(["age", "sex"]).get(pop_data.index)
|
155
|
-
|
156
|
-
assert self.residual_state in {s.state_id for s in self.states}
|
157
|
-
|
158
|
-
if pop_data.user_data["sim_state"] == "setup": # simulation start
|
159
|
-
if self.configuration_age_start != self.configuration_age_end != 0:
|
160
|
-
state_names, weights_bins = self.get_state_weights(
|
161
|
-
pop_data.index, "prevalence"
|
162
|
-
)
|
163
|
-
else:
|
164
|
-
raise NotImplementedError(
|
165
|
-
"We do not currently support an age 0 cohort. "
|
166
|
-
"configuration.population.initialization_age_min and "
|
167
|
-
"configuration.population.initialization_age_max "
|
168
|
-
"cannot both be 0."
|
169
|
-
)
|
170
|
-
|
171
|
-
else: # on time step
|
172
|
-
if pop_data.user_data["age_start"] == pop_data.user_data["age_end"] == 0:
|
173
|
-
state_names, weights_bins = self.get_state_weights(
|
174
|
-
pop_data.index, "birth_prevalence"
|
175
|
-
)
|
176
|
-
else:
|
177
|
-
state_names, weights_bins = self.get_state_weights(
|
178
|
-
pop_data.index, "prevalence"
|
179
|
-
)
|
180
|
-
|
181
|
-
if state_names and not population.empty:
|
182
|
-
# only do this if there are states in the model that supply prevalence data
|
183
|
-
population["sex_id"] = population.sex.apply({"Male": 1, "Female": 2}.get)
|
184
|
-
|
185
|
-
condition_column = self.assign_initial_status_to_simulants(
|
186
|
-
population,
|
187
|
-
state_names,
|
188
|
-
weights_bins,
|
189
|
-
self.randomness.get_draw(population.index),
|
190
|
-
)
|
191
|
-
|
192
|
-
condition_column = condition_column.rename(
|
193
|
-
columns={"condition_state": self.state_column}
|
194
|
-
)
|
195
|
-
else:
|
196
|
-
condition_column = pd.Series(
|
197
|
-
self.residual_state, index=population.index, name=self.state_column
|
198
|
-
)
|
199
|
-
self.population_view.update(condition_column)
|
200
|
-
|
201
167
|
##################################
|
202
168
|
# Pipeline sources and modifiers #
|
203
169
|
##################################
|
@@ -209,40 +175,78 @@ class DiseaseModel(Machine):
|
|
209
175
|
# Helper functions #
|
210
176
|
####################
|
211
177
|
|
212
|
-
def
|
213
|
-
|
214
|
-
|
215
|
-
|
216
|
-
return susceptible_states[0].state_id
|
217
|
-
|
218
|
-
def get_state_weights(
|
219
|
-
self, pop_index: pd.Index, prevalence_type: str
|
220
|
-
) -> tuple[list[str], np.ndarray | None]:
|
221
|
-
states = [state for state in self.states if state.lookup_tables.get(prevalence_type)]
|
222
|
-
|
223
|
-
if not states:
|
224
|
-
return states, None
|
178
|
+
def _get_residual_state(
|
179
|
+
self, initial_state: BaseDiseaseState, residual_state: BaseDiseaseState
|
180
|
+
) -> BaseDiseaseState:
|
181
|
+
"""Get the residual state for the DiseaseModel.
|
225
182
|
|
226
|
-
|
227
|
-
|
228
|
-
|
229
|
-
|
183
|
+
This will be the residual state if it is provided, otherwise it will be
|
184
|
+
the model's SusceptibleState. This method also calculates the residual
|
185
|
+
state's birth_prevalence and prevalence.
|
186
|
+
"""
|
187
|
+
if initial_state is not None:
|
188
|
+
warnings.warn(
|
189
|
+
"In the future, the 'initial_state' argument to DiseaseModel"
|
190
|
+
" will be used to initialize all simulants into that state. To"
|
191
|
+
" retain the current behavior of defining a residual state, use"
|
192
|
+
" the 'residual_state' argument.",
|
193
|
+
DeprecationWarning,
|
194
|
+
stacklevel=2,
|
195
|
+
)
|
230
196
|
|
231
|
-
|
232
|
-
|
197
|
+
if residual_state:
|
198
|
+
raise DiseaseModelError(
|
199
|
+
"A DiseaseModel cannot be initialized with both"
|
200
|
+
" 'initial_state and 'residual_state'."
|
201
|
+
)
|
233
202
|
|
234
|
-
|
203
|
+
residual_state = initial_state
|
204
|
+
elif residual_state is None:
|
205
|
+
susceptible_states = [s for s in self.states if isinstance(s, SusceptibleState)]
|
206
|
+
if len(susceptible_states) != 1:
|
207
|
+
raise DiseaseModelError(
|
208
|
+
"Disease model must have exactly one SusceptibleState."
|
209
|
+
)
|
210
|
+
residual_state = susceptible_states[0]
|
235
211
|
|
236
|
-
|
212
|
+
if residual_state not in self.states:
|
213
|
+
raise DiseaseModelError(
|
214
|
+
f"Residual state '{self.residual_state}' must be one of the"
|
215
|
+
f" states: {self.states}."
|
216
|
+
)
|
237
217
|
|
238
|
-
|
239
|
-
|
240
|
-
|
241
|
-
|
242
|
-
|
218
|
+
residual_state.birth_prevalence = partial(
|
219
|
+
self._get_residual_state_probabilities, table_name="birth_prevalence"
|
220
|
+
)
|
221
|
+
residual_state.prevalence = partial(
|
222
|
+
self._get_residual_state_probabilities, table_name="prevalence"
|
223
|
+
)
|
243
224
|
|
244
|
-
|
245
|
-
|
225
|
+
return residual_state
|
226
|
+
|
227
|
+
def _get_residual_state_probabilities(
|
228
|
+
self, builder: Builder, table_name: str
|
229
|
+
) -> LookupTableData:
|
230
|
+
"""Calculate the probabilities of the residual state based on the other states."""
|
231
|
+
non_residual_states = [s for s in self.states if s != self.residual_state]
|
232
|
+
non_residual_probabilities = 0
|
233
|
+
for state in non_residual_states:
|
234
|
+
weights_source = builder.configuration[state.name].data_sources[table_name]
|
235
|
+
weights = state.get_data(builder, weights_source)
|
236
|
+
if isinstance(weights, pd.DataFrame):
|
237
|
+
weights = weights.set_index(
|
238
|
+
[c for c in weights.columns if c != "value"]
|
239
|
+
).squeeze()
|
240
|
+
non_residual_probabilities += weights
|
241
|
+
|
242
|
+
residual_probabilities = 1 - non_residual_probabilities
|
243
|
+
|
244
|
+
if pd.Series(residual_probabilities < 0).any():
|
245
|
+
raise ValueError(
|
246
|
+
f"The {table_name} for the states in the DiseaseModel must sum"
|
247
|
+
" to less than 1."
|
248
|
+
)
|
249
|
+
if isinstance(residual_probabilities, pd.Series):
|
250
|
+
residual_probabilities = residual_probabilities.reset_index()
|
246
251
|
|
247
|
-
|
248
|
-
return simulants
|
252
|
+
return residual_probabilities
|
@@ -6,19 +6,20 @@ Disease States
|
|
6
6
|
This module contains tools to manage standard disease states.
|
7
7
|
|
8
8
|
"""
|
9
|
-
|
9
|
+
import warnings
|
10
10
|
from collections.abc import Callable
|
11
11
|
from typing import Any
|
12
12
|
|
13
13
|
import numpy as np
|
14
14
|
import pandas as pd
|
15
15
|
from vivarium.framework.engine import Builder
|
16
|
-
from vivarium.framework.lookup import LookupTableData
|
17
16
|
from vivarium.framework.population import PopulationView, SimulantData
|
18
17
|
from vivarium.framework.randomness import RandomnessStream
|
19
18
|
from vivarium.framework.state_machine import State, Transient, Transition, Trigger
|
20
19
|
from vivarium.framework.values import Pipeline, list_combiner, union_post_processor
|
20
|
+
from vivarium.types import DataInput, LookupTableData
|
21
21
|
|
22
|
+
from vivarium_public_health.disease.exceptions import DiseaseModelError
|
22
23
|
from vivarium_public_health.disease.transition import (
|
23
24
|
ProportionTransition,
|
24
25
|
RateTransition,
|
@@ -28,10 +29,25 @@ from vivarium_public_health.utilities import get_lookup_columns, is_non_zero
|
|
28
29
|
|
29
30
|
|
30
31
|
class BaseDiseaseState(State):
|
32
|
+
|
31
33
|
##############
|
32
34
|
# Properties #
|
33
35
|
##############
|
34
36
|
|
37
|
+
@property
|
38
|
+
def configuration_defaults(self) -> dict[str, Any]:
|
39
|
+
configuration_defaults = super().configuration_defaults
|
40
|
+
additional_defaults = {
|
41
|
+
"prevalence": self.prevalence,
|
42
|
+
"birth_prevalence": self.birth_prevalence,
|
43
|
+
}
|
44
|
+
data_sources = {
|
45
|
+
**configuration_defaults[self.name]["data_sources"],
|
46
|
+
**additional_defaults,
|
47
|
+
}
|
48
|
+
configuration_defaults[self.name]["data_sources"] = data_sources
|
49
|
+
return configuration_defaults
|
50
|
+
|
35
51
|
@property
|
36
52
|
def columns_created(self):
|
37
53
|
return [self.event_time_column, self.event_count_column]
|
@@ -59,13 +75,15 @@ class BaseDiseaseState(State):
|
|
59
75
|
side_effect_function: Callable | None = None,
|
60
76
|
cause_type: str = "cause",
|
61
77
|
):
|
62
|
-
super().__init__(state_id, allow_self_transition)
|
78
|
+
super().__init__(state_id, allow_self_transition)
|
63
79
|
self.cause_type = cause_type
|
64
80
|
|
65
81
|
self.side_effect_function = side_effect_function
|
66
82
|
|
67
83
|
self.event_time_column = self.state_id + "_event_time"
|
68
84
|
self.event_count_column = self.state_id + "_event_count"
|
85
|
+
self.prevalence = 0.0
|
86
|
+
self.birth_prevalence = 0.0
|
69
87
|
|
70
88
|
########################
|
71
89
|
# Event-driven methods #
|
@@ -87,10 +105,7 @@ class BaseDiseaseState(State):
|
|
87
105
|
def get_initialization_parameters(self) -> dict[str, Any]:
|
88
106
|
"""Exclude side effect function and cause type from name and __repr__."""
|
89
107
|
initialization_parameters = super().get_initialization_parameters()
|
90
|
-
|
91
|
-
if key in initialization_parameters.keys():
|
92
|
-
del initialization_parameters[key]
|
93
|
-
return initialization_parameters
|
108
|
+
return {"state_id": initialization_parameters["state_id"]}
|
94
109
|
|
95
110
|
def get_initial_event_times(self, pop_data: SimulantData) -> pd.DataFrame:
|
96
111
|
return pd.DataFrame(
|
@@ -132,6 +147,8 @@ class BaseDiseaseState(State):
|
|
132
147
|
output: "BaseDiseaseState",
|
133
148
|
get_data_functions: dict[str, Callable] = None,
|
134
149
|
triggered=Trigger.NOT_TRIGGERED,
|
150
|
+
transition_rate: DataInput | None = None,
|
151
|
+
rate_type: str = "transition_rate",
|
135
152
|
) -> RateTransition:
|
136
153
|
"""Builds a RateTransition from this state to the given state.
|
137
154
|
|
@@ -139,18 +156,29 @@ class BaseDiseaseState(State):
|
|
139
156
|
----------
|
140
157
|
output
|
141
158
|
The end state after the transition.
|
142
|
-
|
143
159
|
get_data_functions
|
144
160
|
Map from transition type to the function to pull that transition's data.
|
145
161
|
triggered
|
146
162
|
The trigger for the transition
|
147
|
-
|
163
|
+
transition_rate
|
164
|
+
The transition rate source. Can be the data itself, a function to
|
165
|
+
retrieve the data, or the artifact key containing the data.
|
166
|
+
rate_type
|
167
|
+
The type of rate. Can be "incidence_rate", "transition_rate", or
|
168
|
+
"remission_rate".
|
148
169
|
|
149
170
|
Returns
|
150
171
|
-------
|
151
172
|
The created transition object.
|
152
173
|
"""
|
153
|
-
transition = RateTransition(
|
174
|
+
transition = RateTransition(
|
175
|
+
input_state=self,
|
176
|
+
output_state=output,
|
177
|
+
get_data_functions=get_data_functions,
|
178
|
+
triggered=triggered,
|
179
|
+
transition_rate=transition_rate,
|
180
|
+
rate_type=rate_type,
|
181
|
+
)
|
154
182
|
self.add_transition(transition)
|
155
183
|
return transition
|
156
184
|
|
@@ -159,6 +187,7 @@ class BaseDiseaseState(State):
|
|
159
187
|
output: "BaseDiseaseState",
|
160
188
|
get_data_functions: dict[str, Callable] | None = None,
|
161
189
|
triggered=Trigger.NOT_TRIGGERED,
|
190
|
+
proportion: DataInput | None = None,
|
162
191
|
) -> ProportionTransition:
|
163
192
|
"""Builds a ProportionTransition from this state to the given state.
|
164
193
|
|
@@ -170,15 +199,26 @@ class BaseDiseaseState(State):
|
|
170
199
|
Map from transition type to the function to pull that transition's data.
|
171
200
|
triggered
|
172
201
|
The trigger for the transition.
|
202
|
+
proportion
|
203
|
+
The proportion source. Can be the data itself, a function to
|
204
|
+
retrieve the data, or the artifact key containing the data.
|
173
205
|
|
174
206
|
Returns
|
175
207
|
-------
|
176
208
|
The created transition object.
|
177
209
|
"""
|
178
|
-
if
|
210
|
+
if (
|
211
|
+
get_data_functions is None or "proportion" not in get_data_functions
|
212
|
+
) and proportion is None:
|
179
213
|
raise ValueError("You must supply a proportion function.")
|
180
214
|
|
181
|
-
transition = ProportionTransition(
|
215
|
+
transition = ProportionTransition(
|
216
|
+
input_state=self,
|
217
|
+
output_state=output,
|
218
|
+
get_data_functions=get_data_functions,
|
219
|
+
triggered=triggered,
|
220
|
+
proportion=proportion,
|
221
|
+
)
|
182
222
|
self.add_transition(transition)
|
183
223
|
return transition
|
184
224
|
|
@@ -220,17 +260,19 @@ class NonDiseasedState(BaseDiseaseState):
|
|
220
260
|
self,
|
221
261
|
output: BaseDiseaseState,
|
222
262
|
get_data_functions: dict[str, Callable] = None,
|
223
|
-
|
263
|
+
triggered=Trigger.NOT_TRIGGERED,
|
264
|
+
transition_rate: DataInput | None = None,
|
265
|
+
**_kwargs,
|
224
266
|
) -> RateTransition:
|
225
|
-
if get_data_functions is None:
|
226
|
-
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
|
233
|
-
|
267
|
+
if get_data_functions is None and transition_rate is None:
|
268
|
+
transition_rate = f"{self.cause_type}.{output.state_id}.incidence_rate"
|
269
|
+
return super().add_rate_transition(
|
270
|
+
output=output,
|
271
|
+
get_data_functions=get_data_functions,
|
272
|
+
triggered=triggered,
|
273
|
+
transition_rate=transition_rate,
|
274
|
+
rate_type="incidence_rate",
|
275
|
+
)
|
234
276
|
|
235
277
|
|
236
278
|
class SusceptibleState(NonDiseasedState):
|
@@ -253,6 +295,13 @@ class SusceptibleState(NonDiseasedState):
|
|
253
295
|
name_prefix="susceptible_to_",
|
254
296
|
)
|
255
297
|
|
298
|
+
##################
|
299
|
+
# Public methods #
|
300
|
+
##################
|
301
|
+
|
302
|
+
def has_initialization_weights(self) -> bool:
|
303
|
+
return True
|
304
|
+
|
256
305
|
|
257
306
|
class RecoveredState(NonDiseasedState):
|
258
307
|
def __init__(
|
@@ -280,17 +329,20 @@ class DiseaseState(BaseDiseaseState):
|
|
280
329
|
|
281
330
|
@property
|
282
331
|
def configuration_defaults(self) -> dict[str, Any]:
|
283
|
-
|
284
|
-
|
285
|
-
|
286
|
-
|
287
|
-
|
288
|
-
|
289
|
-
|
290
|
-
"excess_mortality_rate": self.load_excess_mortality_rate,
|
291
|
-
},
|
292
|
-
},
|
332
|
+
configuration_defaults = super().configuration_defaults
|
333
|
+
additional_defaults = {
|
334
|
+
"prevalence": self._prevalence_source,
|
335
|
+
"birth_prevalence": self._birth_prevalence_source,
|
336
|
+
"dwell_time": self._dwell_time_source,
|
337
|
+
"disability_weight": self._disability_weight_source,
|
338
|
+
"excess_mortality_rate": self._excess_modality_rate_source,
|
293
339
|
}
|
340
|
+
data_sources = {
|
341
|
+
**configuration_defaults[self.name]["data_sources"],
|
342
|
+
**additional_defaults,
|
343
|
+
}
|
344
|
+
configuration_defaults[self.name]["data_sources"] = data_sources
|
345
|
+
return configuration_defaults
|
294
346
|
|
295
347
|
#####################
|
296
348
|
# Lifecycle methods #
|
@@ -304,6 +356,11 @@ class DiseaseState(BaseDiseaseState):
|
|
304
356
|
cause_type: str = "cause",
|
305
357
|
get_data_functions: dict[str, Callable] | None = None,
|
306
358
|
cleanup_function: Callable | None = None,
|
359
|
+
prevalence: DataInput | None = None,
|
360
|
+
birth_prevalence: DataInput | None = None,
|
361
|
+
dwell_time: DataInput | None = None,
|
362
|
+
disability_weight: DataInput | None = None,
|
363
|
+
excess_mortality_rate: DataInput | None = None,
|
307
364
|
):
|
308
365
|
"""
|
309
366
|
Parameters
|
@@ -322,7 +379,31 @@ class DiseaseState(BaseDiseaseState):
|
|
322
379
|
various state attributes.
|
323
380
|
cleanup_function
|
324
381
|
The cleanup function.
|
382
|
+
prevalence
|
383
|
+
The prevalence source. This is used to initialize simulants. Can be
|
384
|
+
the data itself, a function to retrieve the data, or the artifact
|
385
|
+
key containing the data.
|
386
|
+
birth_prevalence
|
387
|
+
The birth prevalence source. This is used to initialize newborn
|
388
|
+
simulants. Can be the data itself, a function to retrieve the data,
|
389
|
+
or the artifact key containing the data.
|
390
|
+
dwell_time
|
391
|
+
The dwell time source. This is used to determine how long a simulant
|
392
|
+
must remain in the state before transitioning. Can be the data
|
393
|
+
itself, a function to retrieve the data, or the artifact key
|
394
|
+
containing the data.
|
395
|
+
disability_weight
|
396
|
+
The disability weight source. This is used to calculate the
|
397
|
+
disability weight for simulants in this state. Can be the data
|
398
|
+
itself, a function to retrieve the data, or the artifact key
|
399
|
+
containing the data.
|
400
|
+
excess_mortality_rate
|
401
|
+
The excess mortality rate source. This is used to calculate the
|
402
|
+
excess mortality rate for simulants in this state. Can be the data
|
403
|
+
itself, a function to retrieve the data, or the artifact key
|
404
|
+
containing the data.
|
325
405
|
"""
|
406
|
+
|
326
407
|
super().__init__(
|
327
408
|
state_id,
|
328
409
|
allow_self_transition=allow_self_transition,
|
@@ -340,6 +421,34 @@ class DiseaseState(BaseDiseaseState):
|
|
340
421
|
)
|
341
422
|
self._cleanup_function = cleanup_function
|
342
423
|
|
424
|
+
if get_data_functions is not None:
|
425
|
+
warnings.warn(
|
426
|
+
"The argument 'get_data_functions' has been deprecated. Use"
|
427
|
+
" cause_specific_mortality_rate instead.",
|
428
|
+
DeprecationWarning,
|
429
|
+
stacklevel=2,
|
430
|
+
)
|
431
|
+
|
432
|
+
for data_type in self._get_data_functions:
|
433
|
+
try:
|
434
|
+
data_source = locals()[data_type]
|
435
|
+
except KeyError:
|
436
|
+
data_source = None
|
437
|
+
|
438
|
+
if locals()[data_type] is not None:
|
439
|
+
raise DiseaseModelError(
|
440
|
+
f"It is not allowed to pass '{data_type}' both as a"
|
441
|
+
" stand-alone argument and as part of get_data_functions."
|
442
|
+
)
|
443
|
+
|
444
|
+
self._prevalence_source = self.get_prevalence_source(prevalence)
|
445
|
+
self._birth_prevalence_source = self.get_birth_prevalence_source(birth_prevalence)
|
446
|
+
self._dwell_time_source = self.get_dwell_time_source(dwell_time)
|
447
|
+
self._disability_weight_source = self.get_disability_weight_source(disability_weight)
|
448
|
+
self._excess_modality_rate_source = self.get_excess_mortality_rate_source(
|
449
|
+
excess_mortality_rate
|
450
|
+
)
|
451
|
+
|
343
452
|
# noinspection PyAttributeOutsideInit
|
344
453
|
def setup(self, builder: Builder) -> None:
|
345
454
|
"""Performs this component's simulation setup.
|
@@ -377,32 +486,45 @@ class DiseaseState(BaseDiseaseState):
|
|
377
486
|
# Setup methods #
|
378
487
|
#################
|
379
488
|
|
380
|
-
def
|
489
|
+
def _get_data_functions_source(self, data_type: str) -> DataInput:
|
490
|
+
def data_source(builder: Builder) -> LookupTableData:
|
491
|
+
return self._get_data_functions[data_type](builder, self.state_id)
|
492
|
+
|
493
|
+
return data_source
|
494
|
+
|
495
|
+
def get_prevalence_source(self, prevalence: DataInput | None) -> DataInput:
|
381
496
|
if "prevalence" in self._get_data_functions:
|
382
|
-
return self.
|
497
|
+
return self._get_data_functions_source("prevalence")
|
498
|
+
elif prevalence is not None:
|
499
|
+
return prevalence
|
383
500
|
else:
|
384
|
-
return
|
501
|
+
return f"{self.cause_type}.{self.state_id}.prevalence"
|
385
502
|
|
386
|
-
def
|
503
|
+
def get_birth_prevalence_source(self, birth_prevalence: DataInput | None) -> DataInput:
|
387
504
|
if "birth_prevalence" in self._get_data_functions:
|
388
|
-
return self.
|
505
|
+
return self._get_data_functions_source("birth_prevalence")
|
506
|
+
elif birth_prevalence is not None:
|
507
|
+
return birth_prevalence
|
389
508
|
else:
|
390
|
-
return 0
|
509
|
+
return 0.0
|
391
510
|
|
392
|
-
def
|
511
|
+
def get_dwell_time_source(self, dwell_time: DataInput | None) -> DataInput:
|
393
512
|
if "dwell_time" in self._get_data_functions:
|
394
|
-
dwell_time = self.
|
395
|
-
|
396
|
-
dwell_time = 0
|
397
|
-
|
398
|
-
|
399
|
-
|
400
|
-
|
401
|
-
|
402
|
-
|
403
|
-
|
404
|
-
|
405
|
-
|
513
|
+
dwell_time = self._get_data_functions_source("dwell_time")
|
514
|
+
elif dwell_time is None:
|
515
|
+
dwell_time = 0.0
|
516
|
+
|
517
|
+
def dwell_time_source(builder: Builder) -> LookupTableData:
|
518
|
+
dwell_time_ = self.get_data(builder, dwell_time)
|
519
|
+
if isinstance(dwell_time_, pd.Timedelta):
|
520
|
+
dwell_time_ = dwell_time_.total_seconds() / (60 * 60 * 24)
|
521
|
+
if (
|
522
|
+
isinstance(dwell_time_, pd.DataFrame) and np.any(dwell_time_.value != 0)
|
523
|
+
) or dwell_time_ > 0:
|
524
|
+
self.transition_set.allow_null_transition = True
|
525
|
+
return dwell_time_
|
526
|
+
|
527
|
+
return dwell_time_source
|
406
528
|
|
407
529
|
def get_dwell_time_pipeline(self, builder: Builder) -> Pipeline:
|
408
530
|
required_columns = get_lookup_columns([self.lookup_tables["dwell_time"]])
|
@@ -412,20 +534,22 @@ class DiseaseState(BaseDiseaseState):
|
|
412
534
|
requires_columns=required_columns,
|
413
535
|
)
|
414
536
|
|
415
|
-
def
|
537
|
+
def get_disability_weight_source(self, disability_weight: DataInput | None) -> DataInput:
|
416
538
|
if "disability_weight" in self._get_data_functions:
|
417
|
-
disability_weight = self.
|
418
|
-
|
419
|
-
|
539
|
+
disability_weight = self._get_data_functions_source("disability_weight")
|
540
|
+
elif disability_weight is not None:
|
541
|
+
disability_weight = disability_weight
|
420
542
|
else:
|
421
|
-
disability_weight =
|
422
|
-
f"{self.cause_type}.{self.state_id}.disability_weight"
|
423
|
-
)
|
543
|
+
disability_weight = f"{self.cause_type}.{self.state_id}.disability_weight"
|
424
544
|
|
425
|
-
|
426
|
-
|
545
|
+
def disability_weight_source(builder: Builder) -> LookupTableData:
|
546
|
+
disability_weight_ = self.get_data(builder, disability_weight)
|
547
|
+
if isinstance(disability_weight_, pd.DataFrame) and len(disability_weight_) == 1:
|
548
|
+
# sequela only have single value
|
549
|
+
disability_weight_ = disability_weight_.value[0]
|
550
|
+
return disability_weight_
|
427
551
|
|
428
|
-
return
|
552
|
+
return disability_weight_source
|
429
553
|
|
430
554
|
def get_disability_weight_pipeline(self, builder: Builder) -> Pipeline:
|
431
555
|
lookup_columns = get_lookup_columns([self.lookup_tables["disability_weight"]])
|
@@ -435,16 +559,25 @@ class DiseaseState(BaseDiseaseState):
|
|
435
559
|
requires_columns=lookup_columns + ["alive", self.model],
|
436
560
|
)
|
437
561
|
|
438
|
-
def
|
562
|
+
def get_excess_mortality_rate_source(
|
563
|
+
self, excess_mortality_rate: DataInput | None
|
564
|
+
) -> DataInput:
|
439
565
|
if "excess_mortality_rate" in self._get_data_functions:
|
440
|
-
|
441
|
-
elif
|
442
|
-
|
443
|
-
|
566
|
+
excess_mortality_rate = self._get_data_functions_source("excess_mortality_rate")
|
567
|
+
elif excess_mortality_rate is None:
|
568
|
+
excess_mortality_rate = f"{self.cause_type}.{self.state_id}.excess_mortality_rate"
|
569
|
+
|
570
|
+
def excess_mortality_rate_source(builder: Builder) -> LookupTableData:
|
571
|
+
if excess_mortality_rate is not None:
|
572
|
+
return self.get_data(builder, excess_mortality_rate)
|
573
|
+
elif builder.data.load(f"cause.{self.model}.restrictions")["yld_only"]:
|
574
|
+
return 0
|
444
575
|
return builder.data.load(
|
445
576
|
f"{self.cause_type}.{self.state_id}.excess_mortality_rate"
|
446
577
|
)
|
447
578
|
|
579
|
+
return excess_mortality_rate_source
|
580
|
+
|
448
581
|
def get_excess_mortality_rate_pipeline(self, builder: Builder) -> Pipeline:
|
449
582
|
lookup_columns = get_lookup_columns([self.lookup_tables["excess_mortality_rate"]])
|
450
583
|
return builder.value.register_rate_producer(
|
@@ -470,24 +603,27 @@ class DiseaseState(BaseDiseaseState):
|
|
470
603
|
# Public methods #
|
471
604
|
##################
|
472
605
|
|
606
|
+
def has_initialization_weights(self) -> bool:
|
607
|
+
return True
|
608
|
+
|
473
609
|
def add_rate_transition(
|
474
610
|
self,
|
475
611
|
output: BaseDiseaseState,
|
476
612
|
get_data_functions: dict[str, Callable] = None,
|
477
|
-
|
613
|
+
triggered=Trigger.NOT_TRIGGERED,
|
614
|
+
transition_rate: DataInput | None = None,
|
615
|
+
rate_type: str = "transition_rate",
|
478
616
|
) -> RateTransition:
|
479
|
-
if get_data_functions is None:
|
480
|
-
|
481
|
-
|
482
|
-
|
483
|
-
|
484
|
-
|
485
|
-
|
486
|
-
|
487
|
-
|
488
|
-
)
|
489
|
-
raise ValueError("You must supply a transition rate or remission rate function.")
|
490
|
-
return super().add_rate_transition(output, get_data_functions, **kwargs)
|
617
|
+
if get_data_functions is None and transition_rate is None:
|
618
|
+
transition_rate = f"{self.cause_type}.{self.state_id}.remission_rate"
|
619
|
+
rate_type = "remission_rate"
|
620
|
+
return super().add_rate_transition(
|
621
|
+
output=output,
|
622
|
+
get_data_functions=get_data_functions,
|
623
|
+
triggered=triggered,
|
624
|
+
transition_rate=transition_rate,
|
625
|
+
rate_type=rate_type,
|
626
|
+
)
|
491
627
|
|
492
628
|
def add_dwell_time_transition(
|
493
629
|
self,
|
@@ -6,7 +6,7 @@ Disease Transitions
|
|
6
6
|
This module contains tools to model transitions between disease states.
|
7
7
|
|
8
8
|
"""
|
9
|
-
|
9
|
+
import warnings
|
10
10
|
from collections.abc import Callable
|
11
11
|
from typing import TYPE_CHECKING, Any
|
12
12
|
|
@@ -15,7 +15,9 @@ from vivarium.framework.engine import Builder
|
|
15
15
|
from vivarium.framework.state_machine import Transition, Trigger
|
16
16
|
from vivarium.framework.utilities import rate_to_probability
|
17
17
|
from vivarium.framework.values import list_combiner, union_post_processor
|
18
|
+
from vivarium.types import DataInput
|
18
19
|
|
20
|
+
from vivarium_public_health.disease.exceptions import DiseaseModelError
|
19
21
|
from vivarium_public_health.utilities import get_lookup_columns
|
20
22
|
|
21
23
|
if TYPE_CHECKING:
|
@@ -44,7 +46,7 @@ class RateTransition(Transition):
|
|
44
46
|
return {
|
45
47
|
f"{self.name}": {
|
46
48
|
"data_sources": {
|
47
|
-
"transition_rate": self.
|
49
|
+
"transition_rate": self._rate_source,
|
48
50
|
},
|
49
51
|
},
|
50
52
|
}
|
@@ -55,19 +57,37 @@ class RateTransition(Transition):
|
|
55
57
|
|
56
58
|
@property
|
57
59
|
def transition_rate_pipeline_name(self) -> str:
|
58
|
-
if
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
60
|
+
if self._get_data_functions:
|
61
|
+
if "incidence_rate" in self._get_data_functions:
|
62
|
+
pipeline_name = f"{self.output_state.state_id}.incidence_rate"
|
63
|
+
elif "remission_rate" in self._get_data_functions:
|
64
|
+
pipeline_name = f"{self.input_state.state_id}.remission_rate"
|
65
|
+
elif "transition_rate" in self._get_data_functions:
|
66
|
+
pipeline_name = (
|
67
|
+
f"{self.input_state.state_id}_to_{self.output_state.state_id}"
|
68
|
+
".transition_rate"
|
69
|
+
)
|
70
|
+
else:
|
71
|
+
raise DiseaseModelError(
|
72
|
+
"Cannot determine rate_transition pipeline name: "
|
73
|
+
"no valid data functions supplied."
|
74
|
+
)
|
66
75
|
else:
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
76
|
+
if self.rate_type == "incidence_rate":
|
77
|
+
pipeline_name = f"{self.output_state.state_id}.incidence_rate"
|
78
|
+
elif self.rate_type == "remission_rate":
|
79
|
+
pipeline_name = f"{self.input_state.state_id}.remission_rate"
|
80
|
+
elif self.rate_type == "transition_rate":
|
81
|
+
pipeline_name = (
|
82
|
+
f"{self.input_state.state_id}_to_{self.output_state.state_id}"
|
83
|
+
".transition_rate"
|
84
|
+
)
|
85
|
+
else:
|
86
|
+
raise DiseaseModelError(
|
87
|
+
"Cannot determine rate_transition pipeline name: invalid"
|
88
|
+
f" rate_type '{self.rate_type} supplied."
|
89
|
+
)
|
90
|
+
|
71
91
|
return pipeline_name
|
72
92
|
|
73
93
|
#####################
|
@@ -80,6 +100,8 @@ class RateTransition(Transition):
|
|
80
100
|
output_state: "BaseDiseaseState",
|
81
101
|
get_data_functions: dict[str, Callable] = None,
|
82
102
|
triggered=Trigger.NOT_TRIGGERED,
|
103
|
+
transition_rate: DataInput | None = None,
|
104
|
+
rate_type: str = "transition_rate",
|
83
105
|
):
|
84
106
|
super().__init__(
|
85
107
|
input_state, output_state, probability_func=self._probability, triggered=triggered
|
@@ -87,6 +109,22 @@ class RateTransition(Transition):
|
|
87
109
|
self._get_data_functions = (
|
88
110
|
get_data_functions if get_data_functions is not None else {}
|
89
111
|
)
|
112
|
+
self._rate_source = self._get_rate_source(transition_rate)
|
113
|
+
self.rate_type = rate_type
|
114
|
+
|
115
|
+
if get_data_functions is not None:
|
116
|
+
warnings.warn(
|
117
|
+
"The argument 'get_data_functions' has been deprecated. Use"
|
118
|
+
" 'transition_rate' instead.",
|
119
|
+
DeprecationWarning,
|
120
|
+
stacklevel=2,
|
121
|
+
)
|
122
|
+
if transition_rate is not None:
|
123
|
+
raise DiseaseModelError(
|
124
|
+
"It is not allowed to pass a transition rate"
|
125
|
+
" both as a stand-alone argument and as part of"
|
126
|
+
" get_data_functions."
|
127
|
+
)
|
90
128
|
|
91
129
|
# noinspection PyAttributeOutsideInit
|
92
130
|
def setup(self, builder: Builder) -> None:
|
@@ -109,17 +147,19 @@ class RateTransition(Transition):
|
|
109
147
|
# Setup methods #
|
110
148
|
#################
|
111
149
|
|
112
|
-
def
|
113
|
-
if
|
114
|
-
rate_data =
|
150
|
+
def _get_rate_source(self, transition_rate: DataInput | None) -> DataInput:
|
151
|
+
if transition_rate is not None:
|
152
|
+
rate_data = transition_rate
|
153
|
+
elif "incidence_rate" in self._get_data_functions:
|
154
|
+
rate_data = lambda builder: self._get_data_functions["incidence_rate"](
|
115
155
|
builder, self.output_state.state_id
|
116
156
|
)
|
117
157
|
elif "remission_rate" in self._get_data_functions:
|
118
|
-
rate_data = self._get_data_functions["remission_rate"](
|
158
|
+
rate_data = lambda builder: self._get_data_functions["remission_rate"](
|
119
159
|
builder, self.input_state.state_id
|
120
160
|
)
|
121
161
|
elif "transition_rate" in self._get_data_functions:
|
122
|
-
rate_data = self._get_data_functions["transition_rate"](
|
162
|
+
rate_data = lambda builder: self._get_data_functions["transition_rate"](
|
123
163
|
builder, self.input_state.state_id, self.output_state.state_id
|
124
164
|
)
|
125
165
|
else:
|
@@ -172,21 +212,38 @@ class ProportionTransition(Transition):
|
|
172
212
|
output_state: "BaseDiseaseState",
|
173
213
|
get_data_functions: dict[str, Callable] = None,
|
174
214
|
triggered=Trigger.NOT_TRIGGERED,
|
215
|
+
proportion: DataInput | None = None,
|
175
216
|
):
|
176
217
|
super().__init__(
|
177
218
|
input_state, output_state, probability_func=self._probability, triggered=triggered
|
178
219
|
)
|
220
|
+
self._proportion_source = proportion
|
179
221
|
self._get_data_functions = (
|
180
222
|
get_data_functions if get_data_functions is not None else {}
|
181
223
|
)
|
182
224
|
|
225
|
+
if get_data_functions is not None:
|
226
|
+
warnings.warn(
|
227
|
+
"The argument 'get_data_functions' has been deprecated. Use"
|
228
|
+
" 'proportion' instead.",
|
229
|
+
DeprecationWarning,
|
230
|
+
stacklevel=2,
|
231
|
+
)
|
232
|
+
if proportion is not None:
|
233
|
+
raise DiseaseModelError(
|
234
|
+
"It is not allowed to pass a proportion both as a"
|
235
|
+
" stand-alone argument and as part of get_data_functions."
|
236
|
+
)
|
237
|
+
|
183
238
|
#################
|
184
239
|
# Setup methods #
|
185
240
|
#################
|
186
241
|
|
187
|
-
def load_proportion(self, builder: Builder) ->
|
242
|
+
def load_proportion(self, builder: Builder) -> DataInput:
|
243
|
+
if self._proportion_source is not None:
|
244
|
+
return self._proportion_source
|
188
245
|
if "proportion" not in self._get_data_functions:
|
189
|
-
raise
|
246
|
+
raise DiseaseModelError("Must supply a proportion function")
|
190
247
|
return self._get_data_functions["proportion"](builder, self.output_state.state_id)
|
191
248
|
|
192
249
|
def _probability(self, index):
|
@@ -17,6 +17,7 @@ import scipy
|
|
17
17
|
from layered_config_tree import ConfigurationError
|
18
18
|
from vivarium import Component
|
19
19
|
from vivarium.framework.engine import Builder
|
20
|
+
from vivarium.framework.values import Pipeline
|
20
21
|
|
21
22
|
from vivarium_public_health.risks import Risk
|
22
23
|
from vivarium_public_health.risks.data_transformations import (
|
@@ -112,7 +113,8 @@ class RiskEffect(Component):
|
|
112
113
|
def setup(self, builder: Builder) -> None:
|
113
114
|
self.exposure = self.get_risk_exposure(builder)
|
114
115
|
|
115
|
-
self.
|
116
|
+
self._relative_risk_source = self.get_relative_risk_source(builder)
|
117
|
+
self.relative_risk = self.get_relative_risk(builder)
|
116
118
|
|
117
119
|
self.register_target_modifier(builder)
|
118
120
|
self.register_paf_modifier(builder)
|
@@ -124,7 +126,7 @@ class RiskEffect(Component):
|
|
124
126
|
def build_all_lookup_tables(self, builder: Builder) -> None:
|
125
127
|
self._exposure_distribution_type = self.get_distribution_type(builder)
|
126
128
|
|
127
|
-
rr_data = self.
|
129
|
+
rr_data = self.load_relative_risk(builder)
|
128
130
|
rr_value_cols = None
|
129
131
|
if self.is_exposure_categorical:
|
130
132
|
rr_data, rr_value_cols = self.process_categorical_data(builder, rr_data)
|
@@ -146,7 +148,7 @@ class RiskEffect(Component):
|
|
146
148
|
return risk_exposure_component.distribution_type
|
147
149
|
return risk_exposure_component.get_distribution_type(builder)
|
148
150
|
|
149
|
-
def
|
151
|
+
def load_relative_risk(
|
150
152
|
self,
|
151
153
|
builder: Builder,
|
152
154
|
configuration=None,
|
@@ -260,24 +262,27 @@ class RiskEffect(Component):
|
|
260
262
|
def get_risk_exposure(self, builder: Builder) -> Callable[[pd.Index], pd.Series]:
|
261
263
|
return builder.value.get_value(self.exposure_pipeline_name)
|
262
264
|
|
263
|
-
def
|
264
|
-
|
265
|
-
|
265
|
+
def adjust_target(self, index: pd.Index, target: pd.Series) -> pd.Series:
|
266
|
+
relative_risk = self.relative_risk(index)
|
267
|
+
return target * relative_risk
|
268
|
+
|
269
|
+
def get_relative_risk_source(self, builder: Builder) -> Callable[[pd.Index], pd.Series]:
|
270
|
+
|
266
271
|
if not self.is_exposure_categorical:
|
267
272
|
tmred = builder.data.load(f"{self.risk}.tmred")
|
268
273
|
tmrel = 0.5 * (tmred["min"] + tmred["max"])
|
269
274
|
scale = builder.data.load(f"{self.risk}.relative_risk_scalar")
|
270
275
|
|
271
|
-
def
|
276
|
+
def generate_relative_risk(index: pd.Index) -> pd.Series:
|
272
277
|
rr = self.lookup_tables["relative_risk"](index)
|
273
278
|
exposure = self.exposure(index)
|
274
279
|
relative_risk = np.maximum(rr.values ** ((exposure - tmrel) / scale), 1)
|
275
|
-
return
|
280
|
+
return relative_risk
|
276
281
|
|
277
282
|
else:
|
278
283
|
index_columns = ["index", self.risk.name]
|
279
284
|
|
280
|
-
def
|
285
|
+
def generate_relative_risk(index: pd.Index) -> pd.Series:
|
281
286
|
rr = self.lookup_tables["relative_risk"](index)
|
282
287
|
exposure = self.exposure(index).reset_index()
|
283
288
|
exposure.columns = index_columns
|
@@ -288,16 +293,24 @@ class RiskEffect(Component):
|
|
288
293
|
relative_risk = relative_risk.set_index(index_columns)
|
289
294
|
|
290
295
|
effect = relative_risk.loc[exposure.index, "value"].droplevel(self.risk.name)
|
291
|
-
|
292
|
-
return affected_rates
|
296
|
+
return effect
|
293
297
|
|
294
|
-
return
|
298
|
+
return generate_relative_risk
|
299
|
+
|
300
|
+
def get_relative_risk(self, builder: Builder) -> Pipeline:
|
301
|
+
return builder.value.register_value_producer(
|
302
|
+
f"{self.risk.name}_on_{self.target.name}.relative_risk",
|
303
|
+
self._relative_risk_source,
|
304
|
+
component=self,
|
305
|
+
required_resources=[self.exposure],
|
306
|
+
)
|
295
307
|
|
296
308
|
def register_target_modifier(self, builder: Builder) -> None:
|
297
309
|
builder.value.register_value_modifier(
|
298
310
|
self.target_pipeline_name,
|
299
|
-
modifier=self.
|
300
|
-
|
311
|
+
modifier=self.adjust_target,
|
312
|
+
component=self,
|
313
|
+
required_resources=[self.relative_risk],
|
301
314
|
)
|
302
315
|
|
303
316
|
def register_paf_modifier(self, builder: Builder) -> None:
|
@@ -307,7 +320,8 @@ class RiskEffect(Component):
|
|
307
320
|
builder.value.register_value_modifier(
|
308
321
|
self.target_paf_pipeline_name,
|
309
322
|
modifier=self.lookup_tables["population_attributable_fraction"],
|
310
|
-
|
323
|
+
component=self,
|
324
|
+
required_resources=required_columns,
|
311
325
|
)
|
312
326
|
|
313
327
|
##################
|
@@ -371,7 +385,7 @@ class NonLogLinearRiskEffect(RiskEffect):
|
|
371
385
|
return f"non_log_linear_risk_effect.{risk.name}_on_{target}"
|
372
386
|
|
373
387
|
def build_all_lookup_tables(self, builder: Builder) -> None:
|
374
|
-
rr_data = self.
|
388
|
+
rr_data = self.load_relative_risk(builder)
|
375
389
|
self.validate_rr_data(rr_data)
|
376
390
|
|
377
391
|
def define_rr_intervals(df: pd.DataFrame) -> pd.DataFrame:
|
@@ -415,7 +429,7 @@ class NonLogLinearRiskEffect(RiskEffect):
|
|
415
429
|
builder, paf_data
|
416
430
|
)
|
417
431
|
|
418
|
-
def
|
432
|
+
def load_relative_risk(
|
419
433
|
self,
|
420
434
|
builder: Builder,
|
421
435
|
configuration=None,
|
@@ -472,10 +486,8 @@ class NonLogLinearRiskEffect(RiskEffect):
|
|
472
486
|
|
473
487
|
return rr_data
|
474
488
|
|
475
|
-
def
|
476
|
-
|
477
|
-
) -> Callable[[pd.Index, pd.Series], pd.Series]:
|
478
|
-
def adjust_target(index: pd.Index, target: pd.Series) -> pd.Series:
|
489
|
+
def get_relative_risk_source(self, builder: Builder) -> Callable[[pd.Index], pd.Series]:
|
490
|
+
def generate_relative_risk(index: pd.Index) -> pd.Series:
|
479
491
|
rr_intervals = self.lookup_tables["relative_risk"](index)
|
480
492
|
exposure = self.population_view.get(index)[f"{self.risk.name}_exposure"]
|
481
493
|
x1, x2 = (
|
@@ -486,9 +498,9 @@ class NonLogLinearRiskEffect(RiskEffect):
|
|
486
498
|
m = (y2 - y1) / (x2 - x1)
|
487
499
|
b = y1 - m * x1
|
488
500
|
relative_risk = b + m * exposure
|
489
|
-
return
|
501
|
+
return relative_risk
|
490
502
|
|
491
|
-
return
|
503
|
+
return generate_relative_risk
|
492
504
|
|
493
505
|
##############
|
494
506
|
# Validators #
|
@@ -336,7 +336,6 @@ class LBWSGRiskEffect(RiskEffect):
|
|
336
336
|
|
337
337
|
super().setup(builder)
|
338
338
|
self.interpolator = self.get_interpolator(builder)
|
339
|
-
self.relative_risk = self.get_relative_risk_pipeline(builder)
|
340
339
|
|
341
340
|
#################
|
342
341
|
# Setup methods #
|
@@ -393,10 +392,10 @@ class LBWSGRiskEffect(RiskEffect):
|
|
393
392
|
for age_start in exposed_age_group_starts
|
394
393
|
}
|
395
394
|
|
396
|
-
def
|
395
|
+
def get_relative_risk(self, builder: Builder) -> Pipeline:
|
397
396
|
return builder.value.register_value_producer(
|
398
397
|
self.relative_risk_pipeline_name,
|
399
|
-
source=self.
|
398
|
+
source=self.get_relative_risk_source,
|
400
399
|
requires_columns=["age"] + self.rr_column_names,
|
401
400
|
)
|
402
401
|
|
@@ -470,7 +469,7 @@ class LBWSGRiskEffect(RiskEffect):
|
|
470
469
|
# Pipeline sources and modifiers #
|
471
470
|
##################################
|
472
471
|
|
473
|
-
def
|
472
|
+
def get_relative_risk_source(self, index: pd.Index) -> pd.Series:
|
474
473
|
pop = self.population_view.get(index)
|
475
474
|
relative_risk = pd.Series(1.0, index=index, name=self.relative_risk_pipeline_name)
|
476
475
|
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: vivarium_public_health
|
3
|
-
Version: 3.1.
|
3
|
+
Version: 3.1.4
|
4
4
|
Summary: Components for modelling diseases, risks, and interventions with ``vivarium``
|
5
5
|
Home-page: https://github.com/ihmeuw/vivarium_public_health
|
6
6
|
Author: The vivarium developers
|
@@ -26,44 +26,44 @@ Classifier: Topic :: Scientific/Engineering :: Medical Science Apps.
|
|
26
26
|
Classifier: Topic :: Scientific/Engineering :: Physics
|
27
27
|
Classifier: Topic :: Software Development :: Libraries
|
28
28
|
License-File: LICENSE.txt
|
29
|
-
Requires-Dist: vivarium
|
30
|
-
Requires-Dist:
|
29
|
+
Requires-Dist: vivarium>=3.2.3
|
30
|
+
Requires-Dist: layered_config_tree>=1.0.1
|
31
31
|
Requires-Dist: loguru
|
32
|
-
Requires-Dist: numpy
|
32
|
+
Requires-Dist: numpy<2.0.0
|
33
33
|
Requires-Dist: pandas
|
34
34
|
Requires-Dist: scipy
|
35
35
|
Requires-Dist: tables
|
36
|
-
Requires-Dist:
|
36
|
+
Requires-Dist: risk_distributions>=2.0.11
|
37
37
|
Requires-Dist: pyarrow
|
38
|
-
Provides-Extra: dev
|
39
|
-
Requires-Dist: sphinx <7.0 ; extra == 'dev'
|
40
|
-
Requires-Dist: sphinx-rtd-theme ; extra == 'dev'
|
41
|
-
Requires-Dist: sphinx-click ; extra == 'dev'
|
42
|
-
Requires-Dist: sphinx-autodoc-typehints ; extra == 'dev'
|
43
|
-
Requires-Dist: IPython ; extra == 'dev'
|
44
|
-
Requires-Dist: matplotlib ; extra == 'dev'
|
45
|
-
Requires-Dist: vivarium-testing-utils ; extra == 'dev'
|
46
|
-
Requires-Dist: pytest ; extra == 'dev'
|
47
|
-
Requires-Dist: pytest-cov ; extra == 'dev'
|
48
|
-
Requires-Dist: pytest-mock ; extra == 'dev'
|
49
|
-
Requires-Dist: hypothesis ; extra == 'dev'
|
50
|
-
Requires-Dist: pyyaml ; extra == 'dev'
|
51
|
-
Requires-Dist: black ==22.3.0 ; extra == 'dev'
|
52
|
-
Requires-Dist: isort ; extra == 'dev'
|
53
38
|
Provides-Extra: docs
|
54
|
-
Requires-Dist: sphinx
|
55
|
-
Requires-Dist: sphinx-rtd-theme
|
56
|
-
Requires-Dist: sphinx-click
|
57
|
-
Requires-Dist: sphinx-autodoc-typehints
|
58
|
-
Requires-Dist: IPython
|
59
|
-
Requires-Dist: matplotlib
|
39
|
+
Requires-Dist: sphinx<7.0; extra == "docs"
|
40
|
+
Requires-Dist: sphinx-rtd-theme; extra == "docs"
|
41
|
+
Requires-Dist: sphinx-click; extra == "docs"
|
42
|
+
Requires-Dist: sphinx-autodoc-typehints; extra == "docs"
|
43
|
+
Requires-Dist: IPython; extra == "docs"
|
44
|
+
Requires-Dist: matplotlib; extra == "docs"
|
60
45
|
Provides-Extra: test
|
61
|
-
Requires-Dist:
|
62
|
-
Requires-Dist: pytest
|
63
|
-
Requires-Dist: pytest-cov
|
64
|
-
Requires-Dist: pytest-mock
|
65
|
-
Requires-Dist: hypothesis
|
66
|
-
Requires-Dist: pyyaml
|
46
|
+
Requires-Dist: vivarium_testing_utils; extra == "test"
|
47
|
+
Requires-Dist: pytest; extra == "test"
|
48
|
+
Requires-Dist: pytest-cov; extra == "test"
|
49
|
+
Requires-Dist: pytest-mock; extra == "test"
|
50
|
+
Requires-Dist: hypothesis; extra == "test"
|
51
|
+
Requires-Dist: pyyaml; extra == "test"
|
52
|
+
Provides-Extra: dev
|
53
|
+
Requires-Dist: sphinx<7.0; extra == "dev"
|
54
|
+
Requires-Dist: sphinx-rtd-theme; extra == "dev"
|
55
|
+
Requires-Dist: sphinx-click; extra == "dev"
|
56
|
+
Requires-Dist: sphinx-autodoc-typehints; extra == "dev"
|
57
|
+
Requires-Dist: IPython; extra == "dev"
|
58
|
+
Requires-Dist: matplotlib; extra == "dev"
|
59
|
+
Requires-Dist: vivarium_testing_utils; extra == "dev"
|
60
|
+
Requires-Dist: pytest; extra == "dev"
|
61
|
+
Requires-Dist: pytest-cov; extra == "dev"
|
62
|
+
Requires-Dist: pytest-mock; extra == "dev"
|
63
|
+
Requires-Dist: hypothesis; extra == "dev"
|
64
|
+
Requires-Dist: pyyaml; extra == "dev"
|
65
|
+
Requires-Dist: black==22.3.0; extra == "dev"
|
66
|
+
Requires-Dist: isort; extra == "dev"
|
67
67
|
|
68
68
|
Vivarium Public Health
|
69
69
|
======================
|
@@ -1,13 +1,14 @@
|
|
1
1
|
vivarium_public_health/__about__.py,sha256=RgWycPypKZS80TpSX7o41cREnG8PfguNHDHLuLyl820,487
|
2
2
|
vivarium_public_health/__init__.py,sha256=GDeeP-7OlCBwPuv_xQoB1wNmvCaFsqfTB7qnnYApm0w,1343
|
3
|
-
vivarium_public_health/_version.py,sha256=
|
3
|
+
vivarium_public_health/_version.py,sha256=2uixSkocUHf2KiY1oTfzz_5AQGmlrHnypVxGgr4mV9c,22
|
4
4
|
vivarium_public_health/utilities.py,sha256=QNXQ6fhAr1HcV-GwKw7wQLz6QyuNxqNvMA-XujKjTgs,3035
|
5
5
|
vivarium_public_health/disease/__init__.py,sha256=VUJHDLlE6ngo2qHNQUtZ8OWH5H_T7_ao-xsYKDkRmHw,443
|
6
|
-
vivarium_public_health/disease/
|
6
|
+
vivarium_public_health/disease/exceptions.py,sha256=vb30IIV82OiDf2cNZCs_E2rF6mdDDHbnZSND60no5CU,97
|
7
|
+
vivarium_public_health/disease/model.py,sha256=qA0mhbIL0rfPtJ1Csmvxvf757vUdrjcbpxYmL_dsniI,9175
|
7
8
|
vivarium_public_health/disease/models.py,sha256=01UK7yB2zGPFzmlIpvhd-XnGe6vSCMDza3QTidgY7Nc,3479
|
8
9
|
vivarium_public_health/disease/special_disease.py,sha256=kTVuE5rQjUK62ysComG8nB2f61aCKdca9trRB1zsDCQ,14537
|
9
|
-
vivarium_public_health/disease/state.py,sha256=
|
10
|
-
vivarium_public_health/disease/transition.py,sha256=
|
10
|
+
vivarium_public_health/disease/state.py,sha256=l1iBZl8IpNOegQjZIY2xq09SBhsn5jSoYD89CymJAUs,28309
|
11
|
+
vivarium_public_health/disease/transition.py,sha256=qRpHzq29S_I_KzyDncAahg27bnHNdrA5tsUWuJ4Zty4,8964
|
11
12
|
vivarium_public_health/mslt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
12
13
|
vivarium_public_health/mslt/delay.py,sha256=TrFp9nmv-PSu6xdQXfyeM4X8RrXTLUnQMxOuhdvD-Wk,22787
|
13
14
|
vivarium_public_health/mslt/disease.py,sha256=DONSItBiOUk1qBE6Msw0vrV0XLW4BPzWVxwVK90GK1I,16638
|
@@ -35,15 +36,15 @@ vivarium_public_health/risks/__init__.py,sha256=z8DcnZGxqNVAyFZm2WAV-IVNGvrSS4iz
|
|
35
36
|
vivarium_public_health/risks/base_risk.py,sha256=XQ_7rYJS5gh0coEKDqcc_zYdjPDBZlj6-THsIQxL3zs,10888
|
36
37
|
vivarium_public_health/risks/data_transformations.py,sha256=SgdPKc95BBqgMNUdlAQM8k6iaXcpxnjk5B2ySTES1Yg,9269
|
37
38
|
vivarium_public_health/risks/distributions.py,sha256=a63-ihg2itxqgowDZbUix8soErxs_y8TRwsdtTCIUU4,18121
|
38
|
-
vivarium_public_health/risks/effect.py,sha256=
|
39
|
+
vivarium_public_health/risks/effect.py,sha256=2DaBKxncS94cm8Ih-TQtbV1mGsEZhx6fEnB5V_ocIZM,21241
|
39
40
|
vivarium_public_health/risks/implementations/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
40
|
-
vivarium_public_health/risks/implementations/low_birth_weight_and_short_gestation.py,sha256=
|
41
|
+
vivarium_public_health/risks/implementations/low_birth_weight_and_short_gestation.py,sha256=Iw5e7sCrWplEu6dic1Wu6q9e0rdeem_6gQ_-N52G54E,17866
|
41
42
|
vivarium_public_health/treatment/__init__.py,sha256=wONElu9aJbBYwpYIovYPYaN_GYfVhPXtTeFWSdQMgA0,222
|
42
43
|
vivarium_public_health/treatment/magic_wand.py,sha256=zGIhrNgB9q6JD7fHlvbDQb3H5e_N_QsROO4Y0kl_JQM,1955
|
43
44
|
vivarium_public_health/treatment/scale_up.py,sha256=hVz0ELXDqlpcExI31rKdepxqcW_hy2hZSa6qCzv6udU,7020
|
44
45
|
vivarium_public_health/treatment/therapeutic_inertia.py,sha256=8Z97s7GfcpfLu1U1ESJSqeEk4L__a3M0GbBV21MFg2s,2346
|
45
|
-
vivarium_public_health-3.1.
|
46
|
-
vivarium_public_health-3.1.
|
47
|
-
vivarium_public_health-3.1.
|
48
|
-
vivarium_public_health-3.1.
|
49
|
-
vivarium_public_health-3.1.
|
46
|
+
vivarium_public_health-3.1.4.dist-info/LICENSE.txt,sha256=mN4bNLUQNcN9njYRc_3jCZkfPySVpmM6MRps104FxA4,1548
|
47
|
+
vivarium_public_health-3.1.4.dist-info/METADATA,sha256=7h8TxTjuThUizK3wr0mzLB31wwS4Kqr9SVgJA9Gf7gA,4028
|
48
|
+
vivarium_public_health-3.1.4.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
|
49
|
+
vivarium_public_health-3.1.4.dist-info/top_level.txt,sha256=VVInlpzCFD0UNNhjOq_j-a29odzjwUwYFTGfvqbi4dY,23
|
50
|
+
vivarium_public_health-3.1.4.dist-info/RECORD,,
|
{vivarium_public_health-3.1.2.dist-info → vivarium_public_health-3.1.4.dist-info}/LICENSE.txt
RENAMED
File without changes
|
{vivarium_public_health-3.1.2.dist-info → vivarium_public_health-3.1.4.dist-info}/top_level.txt
RENAMED
File without changes
|