vivarium-public-health 2.3.2__py3-none-any.whl → 3.0.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- vivarium_public_health/_version.py +1 -1
- vivarium_public_health/disease/model.py +23 -21
- vivarium_public_health/disease/models.py +1 -0
- vivarium_public_health/disease/special_disease.py +40 -41
- vivarium_public_health/disease/state.py +42 -125
- vivarium_public_health/disease/transition.py +70 -27
- vivarium_public_health/mslt/delay.py +1 -0
- vivarium_public_health/mslt/disease.py +1 -0
- vivarium_public_health/mslt/intervention.py +1 -0
- vivarium_public_health/mslt/magic_wand_components.py +1 -0
- vivarium_public_health/mslt/observer.py +1 -0
- vivarium_public_health/mslt/population.py +1 -0
- vivarium_public_health/plugins/parser.py +61 -31
- vivarium_public_health/population/add_new_birth_cohorts.py +2 -3
- vivarium_public_health/population/base_population.py +2 -1
- vivarium_public_health/population/mortality.py +83 -80
- vivarium_public_health/{metrics → results}/__init__.py +2 -0
- vivarium_public_health/results/columns.py +22 -0
- vivarium_public_health/results/disability.py +187 -0
- vivarium_public_health/results/disease.py +222 -0
- vivarium_public_health/results/mortality.py +186 -0
- vivarium_public_health/results/observer.py +78 -0
- vivarium_public_health/results/risk.py +138 -0
- vivarium_public_health/results/simple_cause.py +18 -0
- vivarium_public_health/{metrics → results}/stratification.py +10 -8
- vivarium_public_health/risks/__init__.py +1 -2
- vivarium_public_health/risks/base_risk.py +134 -29
- vivarium_public_health/risks/data_transformations.py +65 -326
- vivarium_public_health/risks/distributions.py +315 -145
- vivarium_public_health/risks/effect.py +376 -75
- vivarium_public_health/risks/implementations/low_birth_weight_and_short_gestation.py +61 -89
- vivarium_public_health/treatment/magic_wand.py +1 -0
- vivarium_public_health/treatment/scale_up.py +1 -0
- vivarium_public_health/treatment/therapeutic_inertia.py +1 -0
- vivarium_public_health/utilities.py +17 -2
- {vivarium_public_health-2.3.2.dist-info → vivarium_public_health-3.0.0.dist-info}/METADATA +13 -3
- vivarium_public_health-3.0.0.dist-info/RECORD +49 -0
- {vivarium_public_health-2.3.2.dist-info → vivarium_public_health-3.0.0.dist-info}/WHEEL +1 -1
- vivarium_public_health/metrics/disability.py +0 -118
- vivarium_public_health/metrics/disease.py +0 -136
- vivarium_public_health/metrics/mortality.py +0 -144
- vivarium_public_health/metrics/risk.py +0 -110
- vivarium_public_health/testing/__init__.py +0 -0
- vivarium_public_health/testing/mock_artifact.py +0 -145
- vivarium_public_health/testing/utils.py +0 -71
- vivarium_public_health-2.3.2.dist-info/RECORD +0 -49
- {vivarium_public_health-2.3.2.dist-info → vivarium_public_health-3.0.0.dist-info}/LICENSE.txt +0 -0
- {vivarium_public_health-2.3.2.dist-info → vivarium_public_health-3.0.0.dist-info}/top_level.txt +0 -0
@@ -44,25 +44,81 @@ on mortality is calculated, by subtracting the raw unmodeled csmr before adding
|
|
44
44
|
back the modified unmodeled csmr.
|
45
45
|
|
46
46
|
"""
|
47
|
-
|
47
|
+
|
48
|
+
from typing import Any, Dict, List, Optional, Union
|
48
49
|
|
49
50
|
import pandas as pd
|
50
51
|
from vivarium import Component
|
51
52
|
from vivarium.framework.engine import Builder
|
52
53
|
from vivarium.framework.event import Event
|
53
|
-
from vivarium.framework.lookup import LookupTable
|
54
54
|
from vivarium.framework.population import SimulantData
|
55
55
|
from vivarium.framework.randomness import RandomnessStream
|
56
56
|
from vivarium.framework.values import Pipeline, list_combiner, union_post_processor
|
57
57
|
|
58
|
+
from vivarium_public_health.utilities import get_lookup_columns
|
59
|
+
|
58
60
|
|
59
61
|
class Mortality(Component):
|
60
|
-
|
62
|
+
"""
|
63
|
+
This is the mortality component which models sources of mortality for a model.
|
64
|
+
The component models all cause mortality and allows for disease models to contribute
|
65
|
+
cause specific mortality. Data used by this class should be supplied in the artifact
|
66
|
+
and is configurable in the configuration to build lookup tables. For instance, let's
|
67
|
+
say we want to use sex and hair color to build a lookup table for all cause mortality.
|
68
|
+
|
69
|
+
.. code-block:: yaml
|
70
|
+
|
71
|
+
configuration:
|
72
|
+
mortality:
|
73
|
+
all_cause_mortality_rate:
|
74
|
+
categorical_columns: ["sex", "hair_color"]
|
75
|
+
|
76
|
+
Similarly, we can do the same thing for unmodeled causes. Here is an example:
|
77
|
+
|
78
|
+
.. code-block:: yaml
|
79
|
+
|
80
|
+
configuration:
|
81
|
+
mortality:
|
82
|
+
unmodeled_cause_specific_mortality_rate:
|
83
|
+
unmodeled_causes: ["maternal_disorders", "maternal_hemorrhage"]
|
84
|
+
categorical_columns: ["sex", "hair_color"]
|
85
|
+
|
86
|
+
Or if we wanted to make the data a scalar value for all cause mortality rate we could
|
87
|
+
configure that as well.
|
88
|
+
|
89
|
+
.. code-block:: yaml
|
90
|
+
|
91
|
+
configuration:
|
92
|
+
mortality:
|
93
|
+
all_cause_mortality_rate:
|
94
|
+
value: 0.01
|
95
|
+
|
96
|
+
"""
|
61
97
|
|
62
98
|
##############
|
63
99
|
# Properties #
|
64
100
|
##############
|
65
101
|
|
102
|
+
@property
|
103
|
+
def configuration_defaults(self) -> Dict[str, Any]:
|
104
|
+
return {
|
105
|
+
"mortality": {
|
106
|
+
"data_sources": {
|
107
|
+
"all_cause_mortality_rate": "cause.all_causes.cause_specific_mortality_rate",
|
108
|
+
"unmodeled_cause_specific_mortality_rate": "self::load_unmodeled_csmr",
|
109
|
+
"life_expectancy": "population.theoretical_minimum_risk_life_expectancy",
|
110
|
+
},
|
111
|
+
"unmodeled_causes": [],
|
112
|
+
},
|
113
|
+
}
|
114
|
+
|
115
|
+
@property
|
116
|
+
def standard_lookup_tables(self) -> List[str]:
|
117
|
+
return [
|
118
|
+
"all_cause_mortality_rate",
|
119
|
+
"life_expectancy",
|
120
|
+
]
|
121
|
+
|
66
122
|
@property
|
67
123
|
def columns_created(self) -> List[str]:
|
68
124
|
return [self.cause_of_death_column_name, self.years_of_life_lost_column_name]
|
@@ -97,14 +153,10 @@ class Mortality(Component):
|
|
97
153
|
self.clock = builder.time.clock()
|
98
154
|
|
99
155
|
self.cause_specific_mortality_rate = self.get_cause_specific_mortality_rate(builder)
|
100
|
-
self.mortality_rate = self.get_mortality_rate(builder)
|
101
156
|
|
102
|
-
self.all_cause_mortality_rate = self.get_all_cause_mortality_rate(builder)
|
103
|
-
self.life_expectancy = self.get_life_expectancy(builder)
|
104
|
-
|
105
|
-
self._raw_unmodeled_csmr = self.get_raw_unmodeled_csmr(builder)
|
106
157
|
self.unmodeled_csmr = self.get_unmodeled_csmr(builder)
|
107
158
|
self.unmodeled_csmr_paf = self.get_unmodeled_csmr_paf(builder)
|
159
|
+
self.mortality_rate = self.get_mortality_rate(builder)
|
108
160
|
|
109
161
|
#################
|
110
162
|
# Setup methods #
|
@@ -120,90 +172,37 @@ class Mortality(Component):
|
|
120
172
|
)
|
121
173
|
|
122
174
|
def get_mortality_rate(self, builder: Builder) -> Pipeline:
|
175
|
+
required_columns = get_lookup_columns(
|
176
|
+
[
|
177
|
+
self.lookup_tables["all_cause_mortality_rate"],
|
178
|
+
self.lookup_tables["unmodeled_cause_specific_mortality_rate"],
|
179
|
+
],
|
180
|
+
)
|
123
181
|
return builder.value.register_rate_producer(
|
124
182
|
self.mortality_rate_pipeline_name,
|
125
183
|
source=self.calculate_mortality_rate,
|
126
|
-
requires_columns=
|
127
|
-
)
|
128
|
-
|
129
|
-
# noinspection PyMethodMayBeStatic
|
130
|
-
def get_all_cause_mortality_rate(self, builder: Builder) -> Union[LookupTable, Pipeline]:
|
131
|
-
"""
|
132
|
-
Load all cause mortality rate data and build a lookup table or pipeline.
|
133
|
-
|
134
|
-
Parameters
|
135
|
-
----------
|
136
|
-
builder
|
137
|
-
Interface to access simulation managers.
|
138
|
-
|
139
|
-
Returns
|
140
|
-
-------
|
141
|
-
Union[LookupTable, Pipeline]
|
142
|
-
A lookup table or pipeline returning the all cause mortality rate.
|
143
|
-
"""
|
144
|
-
acmr_data = builder.data.load("cause.all_causes.cause_specific_mortality_rate")
|
145
|
-
return builder.lookup.build_table(
|
146
|
-
acmr_data, key_columns=["sex"], parameter_columns=["age", "year"]
|
184
|
+
requires_columns=required_columns,
|
147
185
|
)
|
148
186
|
|
149
|
-
|
150
|
-
|
151
|
-
"""
|
152
|
-
Load life expectancy data and build a lookup table or pipeline.
|
153
|
-
|
154
|
-
Parameters
|
155
|
-
----------
|
156
|
-
builder
|
157
|
-
Interface to access simulation managers.
|
158
|
-
|
159
|
-
Returns
|
160
|
-
-------
|
161
|
-
Union[LookupTable, Pipeline]
|
162
|
-
A lookup table or pipeline returning the life expectancy.
|
163
|
-
"""
|
164
|
-
life_expectancy_data = builder.data.load(
|
165
|
-
"population.theoretical_minimum_risk_life_expectancy"
|
166
|
-
)
|
167
|
-
return builder.lookup.build_table(life_expectancy_data, parameter_columns=["age"])
|
168
|
-
|
169
|
-
# noinspection PyMethodMayBeStatic
|
170
|
-
def get_raw_unmodeled_csmr(self, builder: Builder) -> Union[LookupTable, Pipeline]:
|
171
|
-
"""
|
172
|
-
Load unmodeled cause specific mortality rate data and build a lookup
|
173
|
-
table or pipeline.
|
174
|
-
|
175
|
-
Parameters
|
176
|
-
----------
|
177
|
-
builder
|
178
|
-
Interface to access simulation managers.
|
179
|
-
|
180
|
-
Returns
|
181
|
-
-------
|
182
|
-
Union[LookupTable, Pipeline]
|
183
|
-
A lookup table or pipeline returning the unmodeled csmr.
|
184
|
-
"""
|
185
|
-
unmodeled_causes = builder.configuration.unmodeled_causes
|
187
|
+
def load_unmodeled_csmr(self, builder: Builder) -> Union[float, pd.DataFrame]:
|
188
|
+
# todo validate that all data have the same columns
|
186
189
|
raw_csmr = 0.0
|
187
|
-
for idx, cause in enumerate(unmodeled_causes):
|
190
|
+
for idx, cause in enumerate(builder.configuration[self.name].unmodeled_causes):
|
188
191
|
csmr = f"cause.{cause}.cause_specific_mortality_rate"
|
189
192
|
if 0 == idx:
|
190
193
|
raw_csmr = builder.data.load(csmr)
|
191
194
|
else:
|
192
195
|
raw_csmr.loc[:, "value"] += builder.data.load(csmr).value
|
193
|
-
|
194
|
-
additional_parameters = (
|
195
|
-
{"key_columns": ["sex"], "parameter_columns": ["age", "year"]}
|
196
|
-
if unmodeled_causes
|
197
|
-
else {}
|
198
|
-
)
|
199
|
-
|
200
|
-
return builder.lookup.build_table(raw_csmr, **additional_parameters)
|
196
|
+
return raw_csmr
|
201
197
|
|
202
198
|
def get_unmodeled_csmr(self, builder: Builder) -> Pipeline:
|
199
|
+
required_columns = get_lookup_columns(
|
200
|
+
[self.lookup_tables["unmodeled_cause_specific_mortality_rate"]]
|
201
|
+
)
|
203
202
|
return builder.value.register_value_producer(
|
204
203
|
self.unmodeled_csmr_pipeline_name,
|
205
204
|
source=self.get_unmodeled_csmr_source,
|
206
|
-
requires_columns=
|
205
|
+
requires_columns=required_columns,
|
207
206
|
)
|
208
207
|
|
209
208
|
def get_unmodeled_csmr_paf(self, builder: Builder) -> Pipeline:
|
@@ -248,7 +247,9 @@ class Mortality(Component):
|
|
248
247
|
)
|
249
248
|
pop.loc[deaths, "alive"] = "dead"
|
250
249
|
pop.loc[deaths, "exit_time"] = event.time
|
251
|
-
pop.loc[deaths, "years_of_life_lost"] = self.life_expectancy(
|
250
|
+
pop.loc[deaths, "years_of_life_lost"] = self.lookup_tables["life_expectancy"](
|
251
|
+
deaths
|
252
|
+
)
|
252
253
|
pop.loc[deaths, "cause_of_death"] = cause_of_death
|
253
254
|
self.population_view.update(pop)
|
254
255
|
|
@@ -257,9 +258,11 @@ class Mortality(Component):
|
|
257
258
|
##################################
|
258
259
|
|
259
260
|
def calculate_mortality_rate(self, index: pd.Index) -> pd.DataFrame:
|
260
|
-
acmr = self.all_cause_mortality_rate(index)
|
261
|
+
acmr = self.lookup_tables["all_cause_mortality_rate"](index)
|
261
262
|
modeled_csmr = self.cause_specific_mortality_rate(index)
|
262
|
-
unmodeled_csmr_raw = self._raw_unmodeled_csmr(index)
|
263
|
+
unmodeled_csmr_raw = self.lookup_tables["unmodeled_cause_specific_mortality_rate"](
|
264
|
+
index
|
265
|
+
)
|
263
266
|
unmodeled_csmr = self.unmodeled_csmr(index)
|
264
267
|
cause_deleted_mortality_rate = (
|
265
268
|
acmr - modeled_csmr - unmodeled_csmr_raw + unmodeled_csmr
|
@@ -267,6 +270,6 @@ class Mortality(Component):
|
|
267
270
|
return pd.DataFrame({"other_causes": cause_deleted_mortality_rate})
|
268
271
|
|
269
272
|
def get_unmodeled_csmr_source(self, index: pd.Index) -> pd.Series:
|
270
|
-
raw_csmr = self._raw_unmodeled_csmr(index)
|
273
|
+
raw_csmr = self.lookup_tables["unmodeled_cause_specific_mortality_rate"](index)
|
271
274
|
paf = self.unmodeled_csmr_paf(index)
|
272
275
|
return raw_csmr * (1 - paf)
|
@@ -1,5 +1,7 @@
|
|
1
|
+
from .columns import COLUMNS
|
1
2
|
from .disability import DisabilityObserver
|
2
3
|
from .disease import DiseaseObserver
|
3
4
|
from .mortality import MortalityObserver
|
5
|
+
from .observer import PublicHealthObserver
|
4
6
|
from .risk import CategoricalRiskObserver
|
5
7
|
from .stratification import ResultsStratifier
|
@@ -0,0 +1,22 @@
|
|
1
|
+
from typing import NamedTuple
|
2
|
+
|
3
|
+
from vivarium.framework.results import VALUE_COLUMN
|
4
|
+
|
5
|
+
|
6
|
+
class __Columns(NamedTuple):
|
7
|
+
"""column names"""
|
8
|
+
|
9
|
+
VALUE: str = VALUE_COLUMN
|
10
|
+
MEASURE: str = "measure"
|
11
|
+
TRANSITION: str = "transition"
|
12
|
+
STATE: str = "state"
|
13
|
+
ENTITY_TYPE: str = "entity_type"
|
14
|
+
SUB_ENTITY: str = "sub_entity"
|
15
|
+
ENTITY: str = "entity"
|
16
|
+
|
17
|
+
@property
|
18
|
+
def name(self) -> str:
|
19
|
+
return "columns"
|
20
|
+
|
21
|
+
|
22
|
+
COLUMNS = __Columns()
|
@@ -0,0 +1,187 @@
|
|
1
|
+
"""
|
2
|
+
===================
|
3
|
+
Disability Observer
|
4
|
+
===================
|
5
|
+
|
6
|
+
This module contains tools for observing years lived with disability (YLDs)
|
7
|
+
in the simulation.
|
8
|
+
|
9
|
+
"""
|
10
|
+
|
11
|
+
from typing import Any, List, Union
|
12
|
+
|
13
|
+
import pandas as pd
|
14
|
+
from layered_config_tree import LayeredConfigTree
|
15
|
+
from loguru import logger
|
16
|
+
from pandas.api.types import CategoricalDtype
|
17
|
+
from vivarium.framework.engine import Builder
|
18
|
+
from vivarium.framework.values import Pipeline, list_combiner, union_post_processor
|
19
|
+
|
20
|
+
from vivarium_public_health.disease import DiseaseState, RiskAttributableDisease
|
21
|
+
from vivarium_public_health.results.columns import COLUMNS
|
22
|
+
from vivarium_public_health.results.observer import PublicHealthObserver
|
23
|
+
from vivarium_public_health.results.simple_cause import SimpleCause
|
24
|
+
from vivarium_public_health.utilities import to_years
|
25
|
+
|
26
|
+
|
27
|
+
class DisabilityObserver(PublicHealthObserver):
|
28
|
+
"""Counts years lived with disability.
|
29
|
+
|
30
|
+
By default, this counts both aggregate and cause-specific years lived
|
31
|
+
with disability over the full course of the simulation.
|
32
|
+
|
33
|
+
In the model specification, your configuration for this component should
|
34
|
+
be specified as, e.g.:
|
35
|
+
|
36
|
+
.. code-block:: yaml
|
37
|
+
|
38
|
+
configuration:
|
39
|
+
stratification:
|
40
|
+
disability:
|
41
|
+
exclude:
|
42
|
+
- "sex"
|
43
|
+
include:
|
44
|
+
- "sample_stratification"
|
45
|
+
"""
|
46
|
+
|
47
|
+
##############
|
48
|
+
# Properties #
|
49
|
+
##############
|
50
|
+
|
51
|
+
@property
|
52
|
+
def disability_classes(self) -> list[type]:
|
53
|
+
"""The classes to be considered for causes of disability."""
|
54
|
+
return [DiseaseState, RiskAttributableDisease]
|
55
|
+
|
56
|
+
#####################
|
57
|
+
# Lifecycle methods #
|
58
|
+
#####################
|
59
|
+
|
60
|
+
def __init__(self) -> None:
|
61
|
+
super().__init__()
|
62
|
+
self.disability_weight_pipeline_name = "all_causes.disability_weight"
|
63
|
+
|
64
|
+
#################
|
65
|
+
# Setup methods #
|
66
|
+
#################
|
67
|
+
|
68
|
+
def setup(self, builder: Builder) -> None:
|
69
|
+
self.step_size = pd.Timedelta(days=builder.configuration.time.step_size)
|
70
|
+
self.disability_weight = self.get_disability_weight_pipeline(builder)
|
71
|
+
self.set_causes_of_disability(builder)
|
72
|
+
|
73
|
+
def set_causes_of_disability(self, builder: Builder) -> None:
|
74
|
+
"""Set the causes of disability to be observed by removing any excluded
|
75
|
+
via the model spec from the list of all disability class causes. We implement
|
76
|
+
exclusions here because disabilities are unique in that they are not
|
77
|
+
registered stratifications and so cannot be excluded during the stratification
|
78
|
+
call like other categories.
|
79
|
+
"""
|
80
|
+
causes_of_disability = builder.components.get_components_by_type(
|
81
|
+
self.disability_classes
|
82
|
+
)
|
83
|
+
# Convert to SimpleCause instances and add on all_causes
|
84
|
+
causes_of_disability = [
|
85
|
+
SimpleCause.create_from_disease_state(cause) for cause in causes_of_disability
|
86
|
+
] + [SimpleCause("all_causes", "all_causes", "cause")]
|
87
|
+
|
88
|
+
excluded_causes = (
|
89
|
+
builder.configuration.stratification.excluded_categories.to_dict().get(
|
90
|
+
"disability", []
|
91
|
+
)
|
92
|
+
)
|
93
|
+
|
94
|
+
# Handle exclusions that don't exist in the list of causes
|
95
|
+
cause_names = [cause.state_id for cause in causes_of_disability]
|
96
|
+
unknown_exclusions = set(excluded_causes) - set(cause_names)
|
97
|
+
if len(unknown_exclusions) > 0:
|
98
|
+
raise ValueError(
|
99
|
+
f"Excluded 'disability' causes {unknown_exclusions} not found in "
|
100
|
+
f"expected categories categories: {cause_names}"
|
101
|
+
)
|
102
|
+
|
103
|
+
# Drop excluded causes
|
104
|
+
if excluded_causes:
|
105
|
+
logger.debug(
|
106
|
+
f"'disability' has category exclusion requests: {excluded_causes}\n"
|
107
|
+
"Removing these from the allowable categories."
|
108
|
+
)
|
109
|
+
self.causes_of_disability = [
|
110
|
+
cause for cause in causes_of_disability if cause.state_id not in excluded_causes
|
111
|
+
]
|
112
|
+
|
113
|
+
def get_configuration(self, builder: Builder) -> LayeredConfigTree:
|
114
|
+
return builder.configuration.stratification.disability
|
115
|
+
|
116
|
+
def register_observations(self, builder: Builder) -> None:
|
117
|
+
cause_pipelines = [
|
118
|
+
f"{cause.state_id}.disability_weight" for cause in self.causes_of_disability
|
119
|
+
]
|
120
|
+
self.register_adding_observation(
|
121
|
+
builder=builder,
|
122
|
+
name="ylds",
|
123
|
+
pop_filter='tracked == True and alive == "alive"',
|
124
|
+
when="time_step__prepare",
|
125
|
+
requires_columns=["alive"],
|
126
|
+
requires_values=cause_pipelines,
|
127
|
+
additional_stratifications=self.configuration.include,
|
128
|
+
excluded_stratifications=self.configuration.exclude,
|
129
|
+
aggregator_sources=cause_pipelines,
|
130
|
+
aggregator=self.disability_weight_aggregator,
|
131
|
+
)
|
132
|
+
|
133
|
+
def get_disability_weight_pipeline(self, builder: Builder) -> Pipeline:
|
134
|
+
return builder.value.register_value_producer(
|
135
|
+
self.disability_weight_pipeline_name,
|
136
|
+
source=lambda index: [pd.Series(0.0, index=index)],
|
137
|
+
preferred_combiner=list_combiner,
|
138
|
+
preferred_post_processor=union_post_processor,
|
139
|
+
)
|
140
|
+
|
141
|
+
###############
|
142
|
+
# Aggregators #
|
143
|
+
###############
|
144
|
+
|
145
|
+
def disability_weight_aggregator(self, dw: pd.DataFrame) -> Union[float, pd.Series]:
|
146
|
+
aggregated_dw = (dw * to_years(self.step_size)).sum().squeeze()
|
147
|
+
if isinstance(aggregated_dw, pd.Series):
|
148
|
+
aggregated_dw.index.name = "cause_of_disability"
|
149
|
+
return aggregated_dw
|
150
|
+
|
151
|
+
##############################
|
152
|
+
# Results formatting methods #
|
153
|
+
##############################
|
154
|
+
|
155
|
+
def format(self, measure: str, results: pd.DataFrame) -> pd.DataFrame:
|
156
|
+
"""Format results. Note that ylds are unique in that we
|
157
|
+
can't stratify by cause of disability (because there can be multiple at
|
158
|
+
once), and so the results here are actually wide by disability weight
|
159
|
+
pipeline name.
|
160
|
+
"""
|
161
|
+
|
162
|
+
# Drop the unused 'value' column and rename the pipeline names to causes
|
163
|
+
results = results.drop(columns=["value"]).rename(
|
164
|
+
columns={col: col.replace(".disability_weight", "") for col in results.columns},
|
165
|
+
)
|
166
|
+
# Get desired index names prior to stacking
|
167
|
+
idx_names = list(results.index.names) + [COLUMNS.SUB_ENTITY]
|
168
|
+
results = pd.DataFrame(results.stack(), columns=[COLUMNS.VALUE])
|
169
|
+
# Name the new index level
|
170
|
+
results.index.set_names(idx_names, inplace=True)
|
171
|
+
results = results.reset_index()
|
172
|
+
results[COLUMNS.SUB_ENTITY] = results[COLUMNS.SUB_ENTITY].astype(CategoricalDtype())
|
173
|
+
return results
|
174
|
+
|
175
|
+
def get_entity_type_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
|
176
|
+
entity_type_map = {
|
177
|
+
cause.state_id: cause.cause_type for cause in self.causes_of_disability
|
178
|
+
}
|
179
|
+
return results[COLUMNS.SUB_ENTITY].map(entity_type_map).astype(CategoricalDtype())
|
180
|
+
|
181
|
+
def get_entity_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
|
182
|
+
entity_map = {cause.state_id: cause.model for cause in self.causes_of_disability}
|
183
|
+
return results[COLUMNS.SUB_ENTITY].map(entity_map).astype(CategoricalDtype())
|
184
|
+
|
185
|
+
def get_sub_entity_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
|
186
|
+
# The sub-entity col was created in the 'format' method
|
187
|
+
return results[COLUMNS.SUB_ENTITY]
|
@@ -0,0 +1,222 @@
|
|
1
|
+
"""
|
2
|
+
================
|
3
|
+
Disease Observer
|
4
|
+
================
|
5
|
+
|
6
|
+
This module contains tools for observing disease incidence and prevalence
|
7
|
+
in the simulation.
|
8
|
+
|
9
|
+
"""
|
10
|
+
|
11
|
+
from typing import Any, Dict, List
|
12
|
+
|
13
|
+
import pandas as pd
|
14
|
+
from layered_config_tree import LayeredConfigTree
|
15
|
+
from vivarium.framework.engine import Builder
|
16
|
+
from vivarium.framework.event import Event
|
17
|
+
from vivarium.framework.population import SimulantData
|
18
|
+
|
19
|
+
from vivarium_public_health.results.columns import COLUMNS
|
20
|
+
from vivarium_public_health.results.observer import PublicHealthObserver
|
21
|
+
from vivarium_public_health.utilities import to_years
|
22
|
+
|
23
|
+
|
24
|
+
class DiseaseObserver(PublicHealthObserver):
|
25
|
+
"""Observes disease counts and person time for a cause.
|
26
|
+
|
27
|
+
By default, this observer computes aggregate disease state person time and
|
28
|
+
counts of disease events over the full course of the simulation. It can be
|
29
|
+
configured to add or remove stratification groups to the default groups
|
30
|
+
defined by a ResultsStratifier.
|
31
|
+
|
32
|
+
In the model specification, your configuration for this component should
|
33
|
+
be specified as, e.g.:
|
34
|
+
|
35
|
+
.. code-block:: yaml
|
36
|
+
|
37
|
+
configuration:
|
38
|
+
stratification:
|
39
|
+
cause_name:
|
40
|
+
exclude:
|
41
|
+
- "sex"
|
42
|
+
include:
|
43
|
+
- "sample_stratification"
|
44
|
+
"""
|
45
|
+
|
46
|
+
##############
|
47
|
+
# Properties #
|
48
|
+
##############
|
49
|
+
|
50
|
+
@property
|
51
|
+
def configuration_defaults(self) -> Dict[str, Any]:
|
52
|
+
return {
|
53
|
+
"stratification": {
|
54
|
+
self.disease: super().configuration_defaults["stratification"][
|
55
|
+
self.get_configuration_name()
|
56
|
+
]
|
57
|
+
}
|
58
|
+
}
|
59
|
+
|
60
|
+
@property
|
61
|
+
def columns_created(self) -> List[str]:
|
62
|
+
return [self.previous_state_column_name]
|
63
|
+
|
64
|
+
@property
|
65
|
+
def columns_required(self) -> List[str]:
|
66
|
+
return [self.disease]
|
67
|
+
|
68
|
+
@property
|
69
|
+
def initialization_requirements(self) -> Dict[str, List[str]]:
|
70
|
+
return {
|
71
|
+
"requires_columns": [self.disease],
|
72
|
+
}
|
73
|
+
|
74
|
+
#####################
|
75
|
+
# Lifecycle methods #
|
76
|
+
#####################
|
77
|
+
|
78
|
+
def __init__(self, disease: str) -> None:
|
79
|
+
super().__init__()
|
80
|
+
self.disease = disease
|
81
|
+
self.previous_state_column_name = f"previous_{self.disease}"
|
82
|
+
|
83
|
+
#################
|
84
|
+
# Setup methods #
|
85
|
+
#################
|
86
|
+
|
87
|
+
def setup(self, builder: Builder) -> None:
|
88
|
+
self.step_size = builder.time.step_size()
|
89
|
+
self.disease_model = builder.components.get_component(f"disease_model.{self.disease}")
|
90
|
+
self.entity_type = self.disease_model.cause_type
|
91
|
+
self.entity = self.disease_model.cause
|
92
|
+
self.transition_stratification_name = f"transition_{self.disease}"
|
93
|
+
|
94
|
+
def get_configuration(self, builder: Builder) -> LayeredConfigTree:
|
95
|
+
return builder.configuration.stratification[self.disease]
|
96
|
+
|
97
|
+
def register_observations(self, builder: Builder) -> None:
|
98
|
+
|
99
|
+
self.register_disease_state_stratification(builder)
|
100
|
+
self.register_transition_stratification(builder)
|
101
|
+
|
102
|
+
pop_filter = 'alive == "alive" and tracked==True'
|
103
|
+
self.register_person_time_observation(builder, pop_filter)
|
104
|
+
self.register_transition_count_observation(builder, pop_filter)
|
105
|
+
|
106
|
+
def register_disease_state_stratification(self, builder: Builder) -> None:
|
107
|
+
builder.results.register_stratification(
|
108
|
+
self.disease,
|
109
|
+
[state.state_id for state in self.disease_model.states],
|
110
|
+
requires_columns=[self.disease],
|
111
|
+
)
|
112
|
+
|
113
|
+
def register_transition_stratification(self, builder: Builder) -> None:
|
114
|
+
transitions = [
|
115
|
+
str(transition) for transition in self.disease_model.transition_names
|
116
|
+
] + ["no_transition"]
|
117
|
+
# manually append 'no_transition' as an excluded transition
|
118
|
+
excluded_categories = (
|
119
|
+
builder.configuration.stratification.excluded_categories.to_dict().get(
|
120
|
+
self.transition_stratification_name, []
|
121
|
+
)
|
122
|
+
) + ["no_transition"]
|
123
|
+
builder.results.register_stratification(
|
124
|
+
self.transition_stratification_name,
|
125
|
+
categories=transitions,
|
126
|
+
excluded_categories=excluded_categories,
|
127
|
+
mapper=self.map_transitions,
|
128
|
+
requires_columns=[self.disease, self.previous_state_column_name],
|
129
|
+
is_vectorized=True,
|
130
|
+
)
|
131
|
+
|
132
|
+
def register_person_time_observation(self, builder: Builder, pop_filter: str) -> None:
|
133
|
+
self.register_adding_observation(
|
134
|
+
builder=builder,
|
135
|
+
name=f"person_time_{self.disease}",
|
136
|
+
pop_filter=pop_filter,
|
137
|
+
when="time_step__prepare",
|
138
|
+
requires_columns=["alive", self.disease],
|
139
|
+
additional_stratifications=self.configuration.include + [self.disease],
|
140
|
+
excluded_stratifications=self.configuration.exclude,
|
141
|
+
aggregator=self.aggregate_state_person_time,
|
142
|
+
)
|
143
|
+
|
144
|
+
def register_transition_count_observation(
|
145
|
+
self, builder: Builder, pop_filter: str
|
146
|
+
) -> None:
|
147
|
+
self.register_adding_observation(
|
148
|
+
builder=builder,
|
149
|
+
name=f"transition_count_{self.disease}",
|
150
|
+
pop_filter=pop_filter,
|
151
|
+
requires_columns=[
|
152
|
+
self.previous_state_column_name,
|
153
|
+
self.disease,
|
154
|
+
],
|
155
|
+
additional_stratifications=self.configuration.include
|
156
|
+
+ [self.transition_stratification_name],
|
157
|
+
excluded_stratifications=self.configuration.exclude,
|
158
|
+
)
|
159
|
+
|
160
|
+
def map_transitions(self, df: pd.DataFrame) -> pd.Series:
|
161
|
+
transitions = pd.Series(index=df.index, dtype=str)
|
162
|
+
transition_mask = df[self.previous_state_column_name] != df[self.disease]
|
163
|
+
transitions[~transition_mask] = "no_transition"
|
164
|
+
transitions[transition_mask] = (
|
165
|
+
df[self.previous_state_column_name].astype(str)
|
166
|
+
+ "_to_"
|
167
|
+
+ df[self.disease].astype(str)
|
168
|
+
)
|
169
|
+
return transitions
|
170
|
+
|
171
|
+
########################
|
172
|
+
# Event-driven methods #
|
173
|
+
########################
|
174
|
+
|
175
|
+
def on_initialize_simulants(self, pop_data: SimulantData) -> None:
|
176
|
+
"""Initialize the previous state column to the current state"""
|
177
|
+
pop = self.population_view.subview([self.disease]).get(pop_data.index)
|
178
|
+
pop[self.previous_state_column_name] = pop[self.disease]
|
179
|
+
self.population_view.update(pop)
|
180
|
+
|
181
|
+
def on_time_step_prepare(self, event: Event) -> None:
|
182
|
+
# This enables tracking of transitions between states
|
183
|
+
prior_state_pop = self.population_view.get(event.index)
|
184
|
+
prior_state_pop[self.previous_state_column_name] = prior_state_pop[self.disease]
|
185
|
+
self.population_view.update(prior_state_pop)
|
186
|
+
|
187
|
+
###############
|
188
|
+
# Aggregators #
|
189
|
+
###############
|
190
|
+
|
191
|
+
def aggregate_state_person_time(self, x: pd.DataFrame) -> float:
|
192
|
+
return len(x) * to_years(self.step_size())
|
193
|
+
|
194
|
+
##############################
|
195
|
+
# Results formatting methods #
|
196
|
+
##############################
|
197
|
+
|
198
|
+
def format(self, measure: str, results: pd.DataFrame) -> pd.DataFrame:
|
199
|
+
results = results.reset_index()
|
200
|
+
if "transition_count_" in measure:
|
201
|
+
sub_entity = self.transition_stratification_name
|
202
|
+
if "person_time_" in measure:
|
203
|
+
sub_entity = self.disease
|
204
|
+
results.rename(columns={sub_entity: COLUMNS.SUB_ENTITY}, inplace=True)
|
205
|
+
return results
|
206
|
+
|
207
|
+
def get_measure_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
|
208
|
+
if "transition_count_" in measure:
|
209
|
+
measure_name = "transition_count"
|
210
|
+
if "person_time_" in measure:
|
211
|
+
measure_name = "person_time"
|
212
|
+
return pd.Series(measure_name, index=results.index)
|
213
|
+
|
214
|
+
def get_entity_type_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
|
215
|
+
return pd.Series(self.entity_type, index=results.index)
|
216
|
+
|
217
|
+
def get_entity_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
|
218
|
+
return pd.Series(self.entity, index=results.index)
|
219
|
+
|
220
|
+
def get_sub_entity_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
|
221
|
+
# The sub-entity col was created in the 'format' method
|
222
|
+
return results[COLUMNS.SUB_ENTITY]
|