vivarium-public-health 2.3.2__py3-none-any.whl → 3.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vivarium_public_health/_version.py +1 -1
- vivarium_public_health/disease/model.py +23 -21
- vivarium_public_health/disease/models.py +1 -0
- vivarium_public_health/disease/special_disease.py +40 -41
- vivarium_public_health/disease/state.py +42 -125
- vivarium_public_health/disease/transition.py +70 -27
- vivarium_public_health/mslt/delay.py +1 -0
- vivarium_public_health/mslt/disease.py +1 -0
- vivarium_public_health/mslt/intervention.py +1 -0
- vivarium_public_health/mslt/magic_wand_components.py +1 -0
- vivarium_public_health/mslt/observer.py +1 -0
- vivarium_public_health/mslt/population.py +1 -0
- vivarium_public_health/plugins/parser.py +61 -31
- vivarium_public_health/population/add_new_birth_cohorts.py +2 -3
- vivarium_public_health/population/base_population.py +2 -1
- vivarium_public_health/population/mortality.py +83 -80
- vivarium_public_health/{metrics → results}/__init__.py +2 -0
- vivarium_public_health/results/columns.py +22 -0
- vivarium_public_health/results/disability.py +187 -0
- vivarium_public_health/results/disease.py +222 -0
- vivarium_public_health/results/mortality.py +186 -0
- vivarium_public_health/results/observer.py +78 -0
- vivarium_public_health/results/risk.py +138 -0
- vivarium_public_health/results/simple_cause.py +18 -0
- vivarium_public_health/{metrics → results}/stratification.py +10 -8
- vivarium_public_health/risks/__init__.py +1 -2
- vivarium_public_health/risks/base_risk.py +134 -29
- vivarium_public_health/risks/data_transformations.py +65 -326
- vivarium_public_health/risks/distributions.py +315 -145
- vivarium_public_health/risks/effect.py +376 -75
- vivarium_public_health/risks/implementations/low_birth_weight_and_short_gestation.py +61 -89
- vivarium_public_health/treatment/magic_wand.py +1 -0
- vivarium_public_health/treatment/scale_up.py +1 -0
- vivarium_public_health/treatment/therapeutic_inertia.py +1 -0
- vivarium_public_health/utilities.py +17 -2
- {vivarium_public_health-2.3.2.dist-info → vivarium_public_health-3.0.0.dist-info}/METADATA +13 -3
- vivarium_public_health-3.0.0.dist-info/RECORD +49 -0
- {vivarium_public_health-2.3.2.dist-info → vivarium_public_health-3.0.0.dist-info}/WHEEL +1 -1
- vivarium_public_health/metrics/disability.py +0 -118
- vivarium_public_health/metrics/disease.py +0 -136
- vivarium_public_health/metrics/mortality.py +0 -144
- vivarium_public_health/metrics/risk.py +0 -110
- vivarium_public_health/testing/__init__.py +0 -0
- vivarium_public_health/testing/mock_artifact.py +0 -145
- vivarium_public_health/testing/utils.py +0 -71
- vivarium_public_health-2.3.2.dist-info/RECORD +0 -49
- {vivarium_public_health-2.3.2.dist-info → vivarium_public_health-3.0.0.dist-info}/LICENSE.txt +0 -0
- {vivarium_public_health-2.3.2.dist-info → vivarium_public_health-3.0.0.dist-info}/top_level.txt +0 -0
@@ -44,25 +44,81 @@ on mortality is calculated, by subtracting the raw unmodeled csmr before adding
|
|
44
44
|
back the modified unmodeled csmr.
|
45
45
|
|
46
46
|
"""
|
47
|
-
|
47
|
+
|
48
|
+
from typing import Any, Dict, List, Optional, Union
|
48
49
|
|
49
50
|
import pandas as pd
|
50
51
|
from vivarium import Component
|
51
52
|
from vivarium.framework.engine import Builder
|
52
53
|
from vivarium.framework.event import Event
|
53
|
-
from vivarium.framework.lookup import LookupTable
|
54
54
|
from vivarium.framework.population import SimulantData
|
55
55
|
from vivarium.framework.randomness import RandomnessStream
|
56
56
|
from vivarium.framework.values import Pipeline, list_combiner, union_post_processor
|
57
57
|
|
58
|
+
from vivarium_public_health.utilities import get_lookup_columns
|
59
|
+
|
58
60
|
|
59
61
|
class Mortality(Component):
|
60
|
-
|
62
|
+
"""
|
63
|
+
This is the mortality component which models sources of mortality for a model.
|
64
|
+
The component models all-cause mortality and allows disease models to contribute
|
65
|
+
cause specific mortality. Data used by this class should be supplied in the artifact
|
66
|
+
and is configurable in the configuration to build lookup tables. For instance, let's
|
67
|
+
say we want to use sex and hair color to build a lookup table for all cause mortality.
|
68
|
+
|
69
|
+
.. code-block:: yaml
|
70
|
+
|
71
|
+
configuration:
|
72
|
+
mortality:
|
73
|
+
all_cause_mortality_rate:
|
74
|
+
categorical_columns: ["sex", "hair_color"]
|
75
|
+
|
76
|
+
Similarly, we can do the same thing for unmodeled causes. Here is an example:
|
77
|
+
|
78
|
+
.. code-block:: yaml
|
79
|
+
|
80
|
+
configuration:
|
81
|
+
mortality:
|
82
|
+
unmodeled_cause_specific_mortality_rate:
|
83
|
+
unmodeled_causes: ["maternal_disorders", "maternal_hemorrhage"]
|
84
|
+
categorical_columns: ["sex", "hair_color"]
|
85
|
+
|
86
|
+
Or if we wanted to make the data a scalar value for all cause mortality rate we could
|
87
|
+
configure that as well.
|
88
|
+
|
89
|
+
.. code-block:: yaml
|
90
|
+
|
91
|
+
configuration:
|
92
|
+
mortality:
|
93
|
+
all_cause_mortality_rate:
|
94
|
+
value: 0.01
|
95
|
+
|
96
|
+
"""
|
61
97
|
|
62
98
|
##############
|
63
99
|
# Properties #
|
64
100
|
##############
|
65
101
|
|
102
|
+
@property
|
103
|
+
def configuration_defaults(self) -> Dict[str, Any]:
|
104
|
+
return {
|
105
|
+
"mortality": {
|
106
|
+
"data_sources": {
|
107
|
+
"all_cause_mortality_rate": "cause.all_causes.cause_specific_mortality_rate",
|
108
|
+
"unmodeled_cause_specific_mortality_rate": "self::load_unmodeled_csmr",
|
109
|
+
"life_expectancy": "population.theoretical_minimum_risk_life_expectancy",
|
110
|
+
},
|
111
|
+
"unmodeled_causes": [],
|
112
|
+
},
|
113
|
+
}
|
114
|
+
|
115
|
+
@property
|
116
|
+
def standard_lookup_tables(self) -> List[str]:
|
117
|
+
return [
|
118
|
+
"all_cause_mortality_rate",
|
119
|
+
"life_expectancy",
|
120
|
+
]
|
121
|
+
|
66
122
|
@property
|
67
123
|
def columns_created(self) -> List[str]:
|
68
124
|
return [self.cause_of_death_column_name, self.years_of_life_lost_column_name]
|
@@ -97,14 +153,10 @@ class Mortality(Component):
|
|
97
153
|
self.clock = builder.time.clock()
|
98
154
|
|
99
155
|
self.cause_specific_mortality_rate = self.get_cause_specific_mortality_rate(builder)
|
100
|
-
self.mortality_rate = self.get_mortality_rate(builder)
|
101
156
|
|
102
|
-
self.all_cause_mortality_rate = self.get_all_cause_mortality_rate(builder)
|
103
|
-
self.life_expectancy = self.get_life_expectancy(builder)
|
104
|
-
|
105
|
-
self._raw_unmodeled_csmr = self.get_raw_unmodeled_csmr(builder)
|
106
157
|
self.unmodeled_csmr = self.get_unmodeled_csmr(builder)
|
107
158
|
self.unmodeled_csmr_paf = self.get_unmodeled_csmr_paf(builder)
|
159
|
+
self.mortality_rate = self.get_mortality_rate(builder)
|
108
160
|
|
109
161
|
#################
|
110
162
|
# Setup methods #
|
@@ -120,90 +172,37 @@ class Mortality(Component):
|
|
120
172
|
)
|
121
173
|
|
122
174
|
def get_mortality_rate(self, builder: Builder) -> Pipeline:
|
175
|
+
required_columns = get_lookup_columns(
|
176
|
+
[
|
177
|
+
self.lookup_tables["all_cause_mortality_rate"],
|
178
|
+
self.lookup_tables["unmodeled_cause_specific_mortality_rate"],
|
179
|
+
],
|
180
|
+
)
|
123
181
|
return builder.value.register_rate_producer(
|
124
182
|
self.mortality_rate_pipeline_name,
|
125
183
|
source=self.calculate_mortality_rate,
|
126
|
-
requires_columns=
|
127
|
-
)
|
128
|
-
|
129
|
-
# noinspection PyMethodMayBeStatic
|
130
|
-
def get_all_cause_mortality_rate(self, builder: Builder) -> Union[LookupTable, Pipeline]:
|
131
|
-
"""
|
132
|
-
Load all cause mortality rate data and build a lookup table or pipeline.
|
133
|
-
|
134
|
-
Parameters
|
135
|
-
----------
|
136
|
-
builder
|
137
|
-
Interface to access simulation managers.
|
138
|
-
|
139
|
-
Returns
|
140
|
-
-------
|
141
|
-
Union[LookupTable, Pipeline]
|
142
|
-
A lookup table or pipeline returning the all cause mortality rate.
|
143
|
-
"""
|
144
|
-
acmr_data = builder.data.load("cause.all_causes.cause_specific_mortality_rate")
|
145
|
-
return builder.lookup.build_table(
|
146
|
-
acmr_data, key_columns=["sex"], parameter_columns=["age", "year"]
|
184
|
+
requires_columns=required_columns,
|
147
185
|
)
|
148
186
|
|
149
|
-
|
150
|
-
|
151
|
-
"""
|
152
|
-
Load life expectancy data and build a lookup table or pipeline.
|
153
|
-
|
154
|
-
Parameters
|
155
|
-
----------
|
156
|
-
builder
|
157
|
-
Interface to access simulation managers.
|
158
|
-
|
159
|
-
Returns
|
160
|
-
-------
|
161
|
-
Union[LookupTable, Pipeline]
|
162
|
-
A lookup table or pipeline returning the life expectancy.
|
163
|
-
"""
|
164
|
-
life_expectancy_data = builder.data.load(
|
165
|
-
"population.theoretical_minimum_risk_life_expectancy"
|
166
|
-
)
|
167
|
-
return builder.lookup.build_table(life_expectancy_data, parameter_columns=["age"])
|
168
|
-
|
169
|
-
# noinspection PyMethodMayBeStatic
|
170
|
-
def get_raw_unmodeled_csmr(self, builder: Builder) -> Union[LookupTable, Pipeline]:
|
171
|
-
"""
|
172
|
-
Load unmodeled cause specific mortality rate data and build a lookup
|
173
|
-
table or pipeline.
|
174
|
-
|
175
|
-
Parameters
|
176
|
-
----------
|
177
|
-
builder
|
178
|
-
Interface to access simulation managers.
|
179
|
-
|
180
|
-
Returns
|
181
|
-
-------
|
182
|
-
Union[LookupTable, Pipeline]
|
183
|
-
A lookup table or pipeline returning the unmodeled csmr.
|
184
|
-
"""
|
185
|
-
unmodeled_causes = builder.configuration.unmodeled_causes
|
187
|
+
def load_unmodeled_csmr(self, builder: Builder) -> Union[float, pd.DataFrame]:
|
188
|
+
# todo validate that all data have the same columns
|
186
189
|
raw_csmr = 0.0
|
187
|
-
for idx, cause in enumerate(unmodeled_causes):
|
190
|
+
for idx, cause in enumerate(builder.configuration[self.name].unmodeled_causes):
|
188
191
|
csmr = f"cause.{cause}.cause_specific_mortality_rate"
|
189
192
|
if 0 == idx:
|
190
193
|
raw_csmr = builder.data.load(csmr)
|
191
194
|
else:
|
192
195
|
raw_csmr.loc[:, "value"] += builder.data.load(csmr).value
|
193
|
-
|
194
|
-
additional_parameters = (
|
195
|
-
{"key_columns": ["sex"], "parameter_columns": ["age", "year"]}
|
196
|
-
if unmodeled_causes
|
197
|
-
else {}
|
198
|
-
)
|
199
|
-
|
200
|
-
return builder.lookup.build_table(raw_csmr, **additional_parameters)
|
196
|
+
return raw_csmr
|
201
197
|
|
202
198
|
def get_unmodeled_csmr(self, builder: Builder) -> Pipeline:
|
199
|
+
required_columns = get_lookup_columns(
|
200
|
+
[self.lookup_tables["unmodeled_cause_specific_mortality_rate"]]
|
201
|
+
)
|
203
202
|
return builder.value.register_value_producer(
|
204
203
|
self.unmodeled_csmr_pipeline_name,
|
205
204
|
source=self.get_unmodeled_csmr_source,
|
206
|
-
requires_columns=
|
205
|
+
requires_columns=required_columns,
|
207
206
|
)
|
208
207
|
|
209
208
|
def get_unmodeled_csmr_paf(self, builder: Builder) -> Pipeline:
|
@@ -248,7 +247,9 @@ class Mortality(Component):
|
|
248
247
|
)
|
249
248
|
pop.loc[deaths, "alive"] = "dead"
|
250
249
|
pop.loc[deaths, "exit_time"] = event.time
|
251
|
-
pop.loc[deaths, "years_of_life_lost"] = self.life_expectancy(
|
250
|
+
pop.loc[deaths, "years_of_life_lost"] = self.lookup_tables["life_expectancy"](
|
251
|
+
deaths
|
252
|
+
)
|
252
253
|
pop.loc[deaths, "cause_of_death"] = cause_of_death
|
253
254
|
self.population_view.update(pop)
|
254
255
|
|
@@ -257,9 +258,11 @@ class Mortality(Component):
|
|
257
258
|
##################################
|
258
259
|
|
259
260
|
def calculate_mortality_rate(self, index: pd.Index) -> pd.DataFrame:
|
260
|
-
acmr = self.all_cause_mortality_rate(index)
|
261
|
+
acmr = self.lookup_tables["all_cause_mortality_rate"](index)
|
261
262
|
modeled_csmr = self.cause_specific_mortality_rate(index)
|
262
|
-
unmodeled_csmr_raw = self.
|
263
|
+
unmodeled_csmr_raw = self.lookup_tables["unmodeled_cause_specific_mortality_rate"](
|
264
|
+
index
|
265
|
+
)
|
263
266
|
unmodeled_csmr = self.unmodeled_csmr(index)
|
264
267
|
cause_deleted_mortality_rate = (
|
265
268
|
acmr - modeled_csmr - unmodeled_csmr_raw + unmodeled_csmr
|
@@ -267,6 +270,6 @@ class Mortality(Component):
|
|
267
270
|
return pd.DataFrame({"other_causes": cause_deleted_mortality_rate})
|
268
271
|
|
269
272
|
def get_unmodeled_csmr_source(self, index: pd.Index) -> pd.Series:
|
270
|
-
raw_csmr = self.
|
273
|
+
raw_csmr = self.lookup_tables["unmodeled_cause_specific_mortality_rate"](index)
|
271
274
|
paf = self.unmodeled_csmr_paf(index)
|
272
275
|
return raw_csmr * (1 - paf)
|
@@ -1,5 +1,7 @@
|
|
1
|
+
from .columns import COLUMNS
|
1
2
|
from .disability import DisabilityObserver
|
2
3
|
from .disease import DiseaseObserver
|
3
4
|
from .mortality import MortalityObserver
|
5
|
+
from .observer import PublicHealthObserver
|
4
6
|
from .risk import CategoricalRiskObserver
|
5
7
|
from .stratification import ResultsStratifier
|
@@ -0,0 +1,22 @@
|
|
1
|
+
from typing import NamedTuple
|
2
|
+
|
3
|
+
from vivarium.framework.results import VALUE_COLUMN
|
4
|
+
|
5
|
+
|
6
|
+
class __Columns(NamedTuple):
    """Standard column names used when formatting public health results.

    Fields are positional tuple members, so their declaration order is part of
    the tuple's structure. A single instance is exposed at module level as
    ``COLUMNS``; each field's default value is the column name string.
    """

    # Name of the column holding observed values; taken from vivarium's
    # ``VALUE_COLUMN`` so it stays in sync with the results framework.
    VALUE: str = VALUE_COLUMN
    MEASURE: str = "measure"
    TRANSITION: str = "transition"
    STATE: str = "state"
    ENTITY_TYPE: str = "entity_type"
    SUB_ENTITY: str = "sub_entity"
    ENTITY: str = "entity"

    @property
    def name(self) -> str:
        """Return the fixed identifier ``"columns"`` for this collection."""
        return "columns"


# Module-level instance; import this rather than the private class.
COLUMNS = __Columns()
|
@@ -0,0 +1,187 @@
|
|
1
|
+
"""
|
2
|
+
===================
|
3
|
+
Disability Observer
|
4
|
+
===================
|
5
|
+
|
6
|
+
This module contains tools for observing years lived with disability (YLDs)
|
7
|
+
in the simulation.
|
8
|
+
|
9
|
+
"""
|
10
|
+
|
11
|
+
from typing import Any, List, Union
|
12
|
+
|
13
|
+
import pandas as pd
|
14
|
+
from layered_config_tree import LayeredConfigTree
|
15
|
+
from loguru import logger
|
16
|
+
from pandas.api.types import CategoricalDtype
|
17
|
+
from vivarium.framework.engine import Builder
|
18
|
+
from vivarium.framework.values import Pipeline, list_combiner, union_post_processor
|
19
|
+
|
20
|
+
from vivarium_public_health.disease import DiseaseState, RiskAttributableDisease
|
21
|
+
from vivarium_public_health.results.columns import COLUMNS
|
22
|
+
from vivarium_public_health.results.observer import PublicHealthObserver
|
23
|
+
from vivarium_public_health.results.simple_cause import SimpleCause
|
24
|
+
from vivarium_public_health.utilities import to_years
|
25
|
+
|
26
|
+
|
27
|
+
class DisabilityObserver(PublicHealthObserver):
    """Counts years lived with disability.

    By default, this counts both aggregate and cause-specific years lived
    with disability over the full course of the simulation.

    In the model specification, your configuration for this component should
    be specified as, e.g.:

    .. code-block:: yaml

        configuration:
            stratification:
                disability:
                    exclude:
                        - "sex"
                    include:
                        - "sample_stratification"

    Individual causes of disability can be removed from the results by listing
    their state ids under ``stratification.excluded_categories.disability``.
    """

    ##############
    # Properties #
    ##############

    @property
    def disability_classes(self) -> List[type]:
        """The component classes whose instances are treated as causes of disability."""
        return [DiseaseState, RiskAttributableDisease]

    #####################
    # Lifecycle methods #
    #####################

    def __init__(self) -> None:
        super().__init__()
        self.disability_weight_pipeline_name = "all_causes.disability_weight"

    #################
    # Setup methods #
    #################

    def setup(self, builder: Builder) -> None:
        """Read the step size, register the all-causes disability weight
        pipeline, and determine which causes of disability to observe."""
        # The step size is read once from configuration (in days); it is used
        # by the aggregator to convert disability weights into YLDs.
        self.step_size = pd.Timedelta(days=builder.configuration.time.step_size)
        self.disability_weight = self.get_disability_weight_pipeline(builder)
        self.set_causes_of_disability(builder)

    def set_causes_of_disability(self, builder: Builder) -> None:
        """Set the causes of disability to be observed.

        Collects every component whose type is in ``disability_classes``,
        converts each to a ``SimpleCause``, and adds an ``all_causes`` entry.
        Any causes listed under
        ``stratification.excluded_categories.disability`` are then removed.
        Exclusions are handled here because disabilities are not registered
        stratifications and so cannot be excluded during the stratification
        call like other categories.

        Raises
        ------
        ValueError
            If an excluded cause does not match any known cause of disability.
        """
        causes_of_disability = builder.components.get_components_by_type(
            self.disability_classes
        )
        # Convert to SimpleCause instances and add on all_causes
        causes_of_disability = [
            SimpleCause.create_from_disease_state(cause) for cause in causes_of_disability
        ] + [SimpleCause("all_causes", "all_causes", "cause")]

        excluded_causes = (
            builder.configuration.stratification.excluded_categories.to_dict().get(
                "disability", []
            )
        )

        # Fail loudly on exclusions that don't match a known cause so that
        # typos in the model specification are not silently ignored.
        cause_names = [cause.state_id for cause in causes_of_disability]
        unknown_exclusions = set(excluded_causes) - set(cause_names)
        if unknown_exclusions:
            raise ValueError(
                f"Excluded 'disability' causes {unknown_exclusions} not found in "
                f"expected categories: {cause_names}"
            )

        # Drop excluded causes
        if excluded_causes:
            logger.debug(
                f"'disability' has category exclusion requests: {excluded_causes}\n"
                "Removing these from the allowable categories."
            )
        self.causes_of_disability = [
            cause for cause in causes_of_disability if cause.state_id not in excluded_causes
        ]

    def get_configuration(self, builder: Builder) -> LayeredConfigTree:
        """Return the ``stratification.disability`` configuration subtree."""
        return builder.configuration.stratification.disability

    def register_observations(self, builder: Builder) -> None:
        """Register the ``ylds`` observation.

        The observation reads each observed cause's ``<state_id>.disability_weight``
        pipeline during ``time_step__prepare`` for tracked, living simulants and
        aggregates them with ``disability_weight_aggregator``.
        """
        cause_pipelines = [
            f"{cause.state_id}.disability_weight" for cause in self.causes_of_disability
        ]
        self.register_adding_observation(
            builder=builder,
            name="ylds",
            pop_filter='tracked == True and alive == "alive"',
            when="time_step__prepare",
            requires_columns=["alive"],
            requires_values=cause_pipelines,
            additional_stratifications=self.configuration.include,
            excluded_stratifications=self.configuration.exclude,
            aggregator_sources=cause_pipelines,
            aggregator=self.disability_weight_aggregator,
        )

    def get_disability_weight_pipeline(self, builder: Builder) -> Pipeline:
        """Register the all-causes disability weight pipeline.

        The source contributes a single zero-valued series. Contributions are
        combined with vivarium's ``list_combiner`` and
        ``union_post_processor``.
        """
        return builder.value.register_value_producer(
            self.disability_weight_pipeline_name,
            source=lambda index: [pd.Series(0.0, index=index)],
            preferred_combiner=list_combiner,
            preferred_post_processor=union_post_processor,
        )

    ###############
    # Aggregators #
    ###############

    def disability_weight_aggregator(self, dw: pd.DataFrame) -> Union[float, pd.Series]:
        """Convert disability weights to YLDs accrued over one time step.

        Weights are scaled by the step size in years and summed over simulants.
        With a single source column the result squeezes to a scalar; otherwise
        it is a Series whose index is named ``cause_of_disability``.
        """
        aggregated_dw = (dw * to_years(self.step_size)).sum().squeeze()
        if isinstance(aggregated_dw, pd.Series):
            aggregated_dw.index.name = "cause_of_disability"
        return aggregated_dw

    ##############################
    # Results formatting methods #
    ##############################

    def format(self, measure: str, results: pd.DataFrame) -> pd.DataFrame:
        """Format results.

        YLDs are unique in that they can't be stratified by cause of
        disability (a simulant can have several at once), so the raw results
        are wide, with one column per disability weight pipeline. This method
        reshapes them to long form, with the cause in the sub-entity column.
        """
        # Drop the unused value column and rename the pipeline names to causes
        results = results.drop(columns=[COLUMNS.VALUE]).rename(
            columns={col: col.replace(".disability_weight", "") for col in results.columns},
        )
        # Get desired index names prior to stacking
        idx_names = list(results.index.names) + [COLUMNS.SUB_ENTITY]
        results = pd.DataFrame(results.stack(), columns=[COLUMNS.VALUE])
        # Name the new index level
        results.index.set_names(idx_names, inplace=True)
        results = results.reset_index()
        results[COLUMNS.SUB_ENTITY] = results[COLUMNS.SUB_ENTITY].astype(CategoricalDtype())
        return results

    def get_entity_type_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
        """Map each sub-entity (cause) to its cause type."""
        entity_type_map = {
            cause.state_id: cause.cause_type for cause in self.causes_of_disability
        }
        return results[COLUMNS.SUB_ENTITY].map(entity_type_map).astype(CategoricalDtype())

    def get_entity_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
        """Map each sub-entity (cause) to the model it belongs to."""
        entity_map = {cause.state_id: cause.model for cause in self.causes_of_disability}
        return results[COLUMNS.SUB_ENTITY].map(entity_map).astype(CategoricalDtype())

    def get_sub_entity_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
        """Return the sub-entity column."""
        # The sub-entity col was created in the 'format' method
        return results[COLUMNS.SUB_ENTITY]
|
@@ -0,0 +1,222 @@
|
|
1
|
+
"""
|
2
|
+
================
|
3
|
+
Disease Observer
|
4
|
+
================
|
5
|
+
|
6
|
+
This module contains tools for observing disease incidence and prevalence
|
7
|
+
in the simulation.
|
8
|
+
|
9
|
+
"""
|
10
|
+
|
11
|
+
from typing import Any, Dict, List
|
12
|
+
|
13
|
+
import pandas as pd
|
14
|
+
from layered_config_tree import LayeredConfigTree
|
15
|
+
from vivarium.framework.engine import Builder
|
16
|
+
from vivarium.framework.event import Event
|
17
|
+
from vivarium.framework.population import SimulantData
|
18
|
+
|
19
|
+
from vivarium_public_health.results.columns import COLUMNS
|
20
|
+
from vivarium_public_health.results.observer import PublicHealthObserver
|
21
|
+
from vivarium_public_health.utilities import to_years
|
22
|
+
|
23
|
+
|
24
|
+
class DiseaseObserver(PublicHealthObserver):
    """Observes disease counts and person time for a cause.

    By default, this observer computes aggregate disease state person time and
    counts of disease events over the full course of the simulation. It can be
    configured to add or remove stratification groups to the default groups
    defined by a ResultsStratifier.

    In the model specification, your configuration for this component should
    be specified as, e.g.:

    .. code-block:: yaml

        configuration:
            stratification:
                cause_name:
                    exclude:
                        - "sex"
                    include:
                        - "sample_stratification"
    """

    ##############
    # Properties #
    ##############

    @property
    def configuration_defaults(self) -> Dict[str, Any]:
        """Place the parent's default stratification config under this disease's name."""
        return {
            "stratification": {
                self.disease: super().configuration_defaults["stratification"][
                    self.get_configuration_name()
                ]
            }
        }

    @property
    def columns_created(self) -> List[str]:
        return [self.previous_state_column_name]

    @property
    def columns_required(self) -> List[str]:
        return [self.disease]

    @property
    def initialization_requirements(self) -> Dict[str, List[str]]:
        return {
            "requires_columns": [self.disease],
        }

    #####################
    # Lifecycle methods #
    #####################

    def __init__(self, disease: str) -> None:
        """
        Parameters
        ----------
        disease
            Name of the disease state column to observe; also used to look up
            the ``disease_model.<disease>`` component.
        """
        super().__init__()
        self.disease = disease
        self.previous_state_column_name = f"previous_{self.disease}"

    #################
    # Setup methods #
    #################

    def setup(self, builder: Builder) -> None:
        """Look up the step size and the disease model this observer reports on."""
        # step_size is kept as the callable returned by the builder and is
        # invoked at aggregation time.
        self.step_size = builder.time.step_size()
        self.disease_model = builder.components.get_component(f"disease_model.{self.disease}")
        self.entity_type = self.disease_model.cause_type
        self.entity = self.disease_model.cause
        self.transition_stratification_name = f"transition_{self.disease}"

    def get_configuration(self, builder: Builder) -> LayeredConfigTree:
        """Return the ``stratification.<disease>`` configuration subtree."""
        return builder.configuration.stratification[self.disease]

    def register_observations(self, builder: Builder) -> None:
        """Register the state and transition stratifications, then the
        person-time and transition-count observations."""
        self.register_disease_state_stratification(builder)
        self.register_transition_stratification(builder)

        pop_filter = 'alive == "alive" and tracked==True'
        self.register_person_time_observation(builder, pop_filter)
        self.register_transition_count_observation(builder, pop_filter)

    def register_disease_state_stratification(self, builder: Builder) -> None:
        """Stratify by the disease column, with one category per model state."""
        builder.results.register_stratification(
            self.disease,
            [state.state_id for state in self.disease_model.states],
            requires_columns=[self.disease],
        )

    def register_transition_stratification(self, builder: Builder) -> None:
        """Stratify by transition, computed from the previous and current state.

        ``no_transition`` is always added as a category and always excluded,
        in addition to any user-configured exclusions.
        """
        transitions = [
            str(transition) for transition in self.disease_model.transition_names
        ] + ["no_transition"]
        # manually append 'no_transition' as an excluded transition
        excluded_categories = (
            builder.configuration.stratification.excluded_categories.to_dict().get(
                self.transition_stratification_name, []
            )
        ) + ["no_transition"]
        builder.results.register_stratification(
            self.transition_stratification_name,
            categories=transitions,
            excluded_categories=excluded_categories,
            mapper=self.map_transitions,
            requires_columns=[self.disease, self.previous_state_column_name],
            is_vectorized=True,
        )

    def register_person_time_observation(self, builder: Builder, pop_filter: str) -> None:
        """Register person time in each disease state, accrued at ``time_step__prepare``."""
        self.register_adding_observation(
            builder=builder,
            name=f"person_time_{self.disease}",
            pop_filter=pop_filter,
            when="time_step__prepare",
            requires_columns=["alive", self.disease],
            additional_stratifications=self.configuration.include + [self.disease],
            excluded_stratifications=self.configuration.exclude,
            aggregator=self.aggregate_state_person_time,
        )

    def register_transition_count_observation(
        self, builder: Builder, pop_filter: str
    ) -> None:
        """Register counts of simulants by transition stratification."""
        self.register_adding_observation(
            builder=builder,
            name=f"transition_count_{self.disease}",
            pop_filter=pop_filter,
            # 'alive' is listed because pop_filter references it, matching
            # the person-time observation.
            requires_columns=[
                "alive",
                self.previous_state_column_name,
                self.disease,
            ],
            additional_stratifications=self.configuration.include
            + [self.transition_stratification_name],
            excluded_stratifications=self.configuration.exclude,
        )

    def map_transitions(self, df: pd.DataFrame) -> pd.Series:
        """Label each row ``"<previous>_to_<current>"`` if the state changed,
        otherwise ``"no_transition"``."""
        transitions = pd.Series(index=df.index, dtype=str)
        transition_mask = df[self.previous_state_column_name] != df[self.disease]
        transitions[~transition_mask] = "no_transition"
        transitions[transition_mask] = (
            df[self.previous_state_column_name].astype(str)
            + "_to_"
            + df[self.disease].astype(str)
        )
        return transitions

    ########################
    # Event-driven methods #
    ########################

    def on_initialize_simulants(self, pop_data: SimulantData) -> None:
        """Initialize the previous state column to the current state"""
        pop = self.population_view.subview([self.disease]).get(pop_data.index)
        pop[self.previous_state_column_name] = pop[self.disease]
        self.population_view.update(pop)

    def on_time_step_prepare(self, event: Event) -> None:
        """Copy the current disease state into the previous-state column so
        that transitions occurring later in the step can be detected."""
        prior_state_pop = self.population_view.get(event.index)
        prior_state_pop[self.previous_state_column_name] = prior_state_pop[self.disease]
        self.population_view.update(prior_state_pop)

    ###############
    # Aggregators #
    ###############

    def aggregate_state_person_time(self, x: pd.DataFrame) -> float:
        """Person time in years: number of simulants times the step size."""
        return len(x) * to_years(self.step_size())

    ##############################
    # Results formatting methods #
    ##############################

    def _get_measure_name(self, measure: str) -> str:
        """Map an observation name registered by this class to its measure.

        Raises
        ------
        ValueError
            If ``measure`` is not one of this observer's observations.
        """
        if "transition_count_" in measure:
            return "transition_count"
        if "person_time_" in measure:
            return "person_time"
        raise ValueError(f"Unrecognized measure '{measure}' for disease '{self.disease}'.")

    def format(self, measure: str, results: pd.DataFrame) -> pd.DataFrame:
        """Reset the index and rename the measure's stratification column
        (transition or disease state) to the sub-entity column."""
        results = results.reset_index()
        sub_entity_by_measure = {
            "transition_count": self.transition_stratification_name,
            "person_time": self.disease,
        }
        sub_entity = sub_entity_by_measure[self._get_measure_name(measure)]
        results.rename(columns={sub_entity: COLUMNS.SUB_ENTITY}, inplace=True)
        return results

    def get_measure_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
        """Constant column holding ``"transition_count"`` or ``"person_time"``."""
        return pd.Series(self._get_measure_name(measure), index=results.index)

    def get_entity_type_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
        """Constant column holding the disease model's cause type."""
        return pd.Series(self.entity_type, index=results.index)

    def get_entity_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
        """Constant column holding the disease model's cause."""
        return pd.Series(self.entity, index=results.index)

    def get_sub_entity_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
        """Return the sub-entity column."""
        # The sub-entity col was created in the 'format' method
        return results[COLUMNS.SUB_ENTITY]
|