vivarium-public-health 2.3.2__py3-none-any.whl → 3.0.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (48):
  1. vivarium_public_health/_version.py +1 -1
  2. vivarium_public_health/disease/model.py +23 -21
  3. vivarium_public_health/disease/models.py +1 -0
  4. vivarium_public_health/disease/special_disease.py +40 -41
  5. vivarium_public_health/disease/state.py +42 -125
  6. vivarium_public_health/disease/transition.py +70 -27
  7. vivarium_public_health/mslt/delay.py +1 -0
  8. vivarium_public_health/mslt/disease.py +1 -0
  9. vivarium_public_health/mslt/intervention.py +1 -0
  10. vivarium_public_health/mslt/magic_wand_components.py +1 -0
  11. vivarium_public_health/mslt/observer.py +1 -0
  12. vivarium_public_health/mslt/population.py +1 -0
  13. vivarium_public_health/plugins/parser.py +61 -31
  14. vivarium_public_health/population/add_new_birth_cohorts.py +2 -3
  15. vivarium_public_health/population/base_population.py +2 -1
  16. vivarium_public_health/population/mortality.py +83 -80
  17. vivarium_public_health/{metrics → results}/__init__.py +2 -0
  18. vivarium_public_health/results/columns.py +22 -0
  19. vivarium_public_health/results/disability.py +187 -0
  20. vivarium_public_health/results/disease.py +222 -0
  21. vivarium_public_health/results/mortality.py +186 -0
  22. vivarium_public_health/results/observer.py +78 -0
  23. vivarium_public_health/results/risk.py +138 -0
  24. vivarium_public_health/results/simple_cause.py +18 -0
  25. vivarium_public_health/{metrics → results}/stratification.py +10 -8
  26. vivarium_public_health/risks/__init__.py +1 -2
  27. vivarium_public_health/risks/base_risk.py +134 -29
  28. vivarium_public_health/risks/data_transformations.py +65 -326
  29. vivarium_public_health/risks/distributions.py +315 -145
  30. vivarium_public_health/risks/effect.py +376 -75
  31. vivarium_public_health/risks/implementations/low_birth_weight_and_short_gestation.py +61 -89
  32. vivarium_public_health/treatment/magic_wand.py +1 -0
  33. vivarium_public_health/treatment/scale_up.py +1 -0
  34. vivarium_public_health/treatment/therapeutic_inertia.py +1 -0
  35. vivarium_public_health/utilities.py +17 -2
  36. {vivarium_public_health-2.3.2.dist-info → vivarium_public_health-3.0.0.dist-info}/METADATA +13 -3
  37. vivarium_public_health-3.0.0.dist-info/RECORD +49 -0
  38. {vivarium_public_health-2.3.2.dist-info → vivarium_public_health-3.0.0.dist-info}/WHEEL +1 -1
  39. vivarium_public_health/metrics/disability.py +0 -118
  40. vivarium_public_health/metrics/disease.py +0 -136
  41. vivarium_public_health/metrics/mortality.py +0 -144
  42. vivarium_public_health/metrics/risk.py +0 -110
  43. vivarium_public_health/testing/__init__.py +0 -0
  44. vivarium_public_health/testing/mock_artifact.py +0 -145
  45. vivarium_public_health/testing/utils.py +0 -71
  46. vivarium_public_health-2.3.2.dist-info/RECORD +0 -49
  47. {vivarium_public_health-2.3.2.dist-info → vivarium_public_health-3.0.0.dist-info}/LICENSE.txt +0 -0
  48. {vivarium_public_health-2.3.2.dist-info → vivarium_public_health-3.0.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,186 @@
1
+ """
2
+ ==================
3
+ Mortality Observer
4
+ ==================
5
+
6
+ This module contains tools for observing cause-specific and
7
+ excess mortality in the simulation, including "other causes".
8
+
9
+ """
10
+
11
+ from typing import Any, Dict, List
12
+
13
+ import pandas as pd
14
+ from layered_config_tree import LayeredConfigTree
15
+ from pandas.api.types import CategoricalDtype
16
+ from vivarium.framework.engine import Builder
17
+
18
+ from vivarium_public_health.disease import DiseaseState, RiskAttributableDisease
19
+ from vivarium_public_health.results.columns import COLUMNS
20
+ from vivarium_public_health.results.observer import PublicHealthObserver
21
+ from vivarium_public_health.results.simple_cause import SimpleCause
22
+
23
+
24
class MortalityObserver(PublicHealthObserver):
    """An observer for cause-specific deaths and ylls (including "other causes").

    By default, this counts cause-specific deaths and years of life lost over
    the full course of the simulation. It can be configured to add or remove
    stratification groups to the default groups defined by a
    :class:`ResultsStratifier`. The aggregate configuration key can be set to
    True to aggregate all deaths and ylls into a single observation and remove
    the stratification by cause of death to improve runtime.

    In the model specification, your configuration for this component should
    be specified as, e.g.:

    .. code-block:: yaml

        configuration:
            stratification:
                mortality:
                    exclude:
                        - "sex"
                    include:
                        - "sample_stratification"

    This observer needs to access the has_excess_mortality attribute of the causes
    we're observing, but this attribute gets defined in the setup of the cause models.
    As a result, the model specification should list this observer after causes.
    """

    def __init__(self) -> None:
        super().__init__()
        # Columns the "deaths" observation needs from the population table.
        self.required_death_columns = ["alive", "exit_time", "cause_of_death"]
        # Columns the "ylls" observation needs from the population table.
        self.required_yll_columns = [
            "alive",
            "cause_of_death",
            "exit_time",
            "years_of_life_lost",
        ]

    ##############
    # Properties #
    ##############

    @property
    def mortality_classes(self) -> list[type]:
        """Component types whose instances are candidate causes of death."""
        return [DiseaseState, RiskAttributableDisease]

    @property
    def configuration_defaults(self) -> Dict[str, Any]:
        """
        A dictionary containing the defaults for any configurations managed by
        this component.

        Extends the parent defaults with an ``aggregate`` flag (default
        ``False``) under this observer's stratification configuration.
        """
        config_defaults = super().configuration_defaults
        config_defaults["stratification"][self.get_configuration_name()]["aggregate"] = False
        return config_defaults

    @property
    def columns_required(self) -> List[str]:
        """Population columns this component reads."""
        return [
            "alive",
            "years_of_life_lost",
            "cause_of_death",
            "exit_time",
        ]

    #################
    # Setup methods #
    #################

    def setup(self, builder: Builder) -> None:
        """Grab the simulation clock and collect the causes of death."""
        self.clock = builder.time.clock()
        self.set_causes_of_death(builder)

    def set_causes_of_death(self, builder: Builder) -> None:
        """Populate ``self.causes_of_death``.

        Collects every component that is an instance of one of
        ``mortality_classes`` and has a truthy ``has_excess_mortality``
        attribute. Each one is converted to a :class:`SimpleCause`. The
        placeholder causes ``not_dead`` and ``other_causes`` are appended.
        """
        causes_of_death = [
            cause
            for cause in builder.components.get_components_by_type(
                tuple(self.mortality_classes)
            )
            if cause.has_excess_mortality
        ]

        # Convert to SimpleCauses and add on other_causes and not_dead
        self.causes_of_death = [
            SimpleCause.create_from_disease_state(cause) for cause in causes_of_death
        ] + [
            SimpleCause("not_dead", "not_dead", "cause"),
            SimpleCause("other_causes", "other_causes", "cause"),
        ]

    def get_configuration(self, builder: Builder) -> LayeredConfigTree:
        """Return this observer's subtree of the stratification configuration."""
        return builder.configuration.stratification[self.get_configuration_name()]

    def register_observations(self, builder: Builder) -> None:
        """Register the "deaths" and "ylls" adding observations.

        Unless ``aggregate`` is set, a ``cause_of_death`` stratification is
        also registered and added to both observations, with ``not_dead``
        always excluded.
        """
        pop_filter = 'alive == "dead" and tracked == True'
        # Copy the configured list: appending "cause_of_death" below must not
        # mutate the list object owned by the configuration tree.
        additional_stratifications = list(self.configuration.include)
        if not self.configuration.aggregate:
            # manually append 'not_dead' as an excluded cause
            excluded_categories = (
                builder.configuration.stratification.excluded_categories.to_dict().get(
                    "cause_of_death", []
                )
            ) + ["not_dead"]
            builder.results.register_stratification(
                "cause_of_death",
                [cause.state_id for cause in self.causes_of_death],
                excluded_categories=excluded_categories,
                requires_columns=["cause_of_death"],
            )
            additional_stratifications.append("cause_of_death")
        self.register_adding_observation(
            builder=builder,
            name="deaths",
            pop_filter=pop_filter,
            requires_columns=self.required_death_columns,
            additional_stratifications=additional_stratifications,
            excluded_stratifications=self.configuration.exclude,
            aggregator=self.count_deaths,
        )
        self.register_adding_observation(
            builder=builder,
            name="ylls",
            pop_filter=pop_filter,
            requires_columns=self.required_yll_columns,
            additional_stratifications=additional_stratifications,
            excluded_stratifications=self.configuration.exclude,
            aggregator=self.calculate_ylls,
        )

    ###############
    # Aggregators #
    ###############

    def count_deaths(self, x: pd.DataFrame) -> float:
        """Count rows of ``x`` whose ``exit_time`` is after the current clock time."""
        # NOTE(review): presumably this selects simulants who died during the
        # step being collected; confirm against how exit_time is assigned.
        died_of_cause = x["exit_time"] > self.clock()
        # Vectorised count instead of iterating the boolean Series in Python.
        return died_of_cause.sum()

    def calculate_ylls(self, x: pd.DataFrame) -> float:
        """Sum ``years_of_life_lost`` over rows whose ``exit_time`` is after the clock."""
        died_of_cause = x["exit_time"] > self.clock()
        return x.loc[died_of_cause, "years_of_life_lost"].sum()

    ##############################
    # Results formatting methods #
    ##############################

    def format(self, measure: str, results: pd.DataFrame) -> pd.DataFrame:
        """Flatten the index and build the entity column.

        In aggregate mode, every row's entity is "all_causes". Otherwise the
        ``cause_of_death`` column is renamed to the entity column.
        """
        results = results.reset_index()
        if self.configuration.aggregate:
            results[COLUMNS.ENTITY] = "all_causes"
        else:
            results.rename(columns={"cause_of_death": COLUMNS.ENTITY}, inplace=True)
        return results

    def get_entity_type_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
        """Map each entity to its cause type, as a categorical Series."""
        entity_type_map = {cause.state_id: cause.cause_type for cause in self.causes_of_death}
        # In aggregate mode, 'format' labels every row "all_causes", which is
        # not one of the collected causes. Without this entry every row would
        # map to NaN.
        entity_type_map.setdefault("all_causes", "cause")
        return results[COLUMNS.ENTITY].map(entity_type_map).astype(CategoricalDtype())

    def get_entity_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
        """Return the entity column."""
        # The entity col was created in the 'format' method
        return results[COLUMNS.ENTITY]

    def get_sub_entity_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
        """Return a sub-entity column identical to the entity column."""
        return results[COLUMNS.ENTITY]
@@ -0,0 +1,78 @@
1
+ from typing import Callable, List, Optional, Union
2
+
3
+ import pandas as pd
4
+ from vivarium.framework.engine import Builder
5
+ from vivarium.framework.results import Observer
6
+
7
+ from vivarium_public_health.results.columns import COLUMNS
8
+
9
+
10
class PublicHealthObserver(Observer):
    """A convenience class for typical public health observers. It provides
    an entry point for registering the most common observation type
    as well as standardized results formatting methods to overwrite as necessary.
    """

    def register_adding_observation(
        self,
        builder: Builder,
        name: str,
        pop_filter: str,
        when: str = "collect_metrics",
        requires_columns: Optional[List[str]] = None,
        requires_values: Optional[List[str]] = None,
        additional_stratifications: Optional[List[str]] = None,
        excluded_stratifications: Optional[List[str]] = None,
        aggregator_sources: Optional[List[str]] = None,
        aggregator: Callable[[pd.DataFrame], Union[float, pd.Series]] = len,
    ) -> None:
        """Register an adding observation whose results go through
        :meth:`format_results`.

        Every argument is forwarded to ``builder.results.register_adding_observation``.
        Omitted list arguments are passed as fresh empty lists. (Defaults are
        ``None`` rather than ``[]`` so no list object is shared between calls.)

        Parameters
        ----------
        builder
            The simulation builder.
        name
            Name of the observation / measure.
        pop_filter
            Query string selecting the simulants to observe.
        when
            Phase at which the observation is recorded (forwarded unchanged).
        requires_columns
            Population columns the observation needs.
        requires_values
            Value pipelines the observation needs.
        additional_stratifications
            Stratifications to add to the defaults.
        excluded_stratifications
            Default stratifications to drop.
        aggregator_sources
            Columns passed to ``aggregator`` (forwarded unchanged).
        aggregator
            Function reducing each stratified group to a value; ``len`` by default.
        """
        builder.results.register_adding_observation(
            name=name,
            pop_filter=pop_filter,
            when=when,
            requires_columns=[] if requires_columns is None else requires_columns,
            requires_values=[] if requires_values is None else requires_values,
            results_formatter=self.format_results,
            additional_stratifications=(
                [] if additional_stratifications is None else additional_stratifications
            ),
            excluded_stratifications=(
                [] if excluded_stratifications is None else excluded_stratifications
            ),
            aggregator_sources=aggregator_sources,
            aggregator=aggregator,
        )

    def format_results(self, measure: str, results: pd.DataFrame) -> pd.DataFrame:
        """Top-level results formatter that calls standard sub-methods to be
        overwritten as necessary.

        The returned frame puts the measure, entity_type, entity and
        sub_entity columns first, then any remaining columns in their existing
        order, and the value column last.
        """
        results = self.format(measure, results)
        results[COLUMNS.MEASURE] = self.get_measure_column(measure, results)
        results[COLUMNS.ENTITY_TYPE] = self.get_entity_type_column(measure, results)
        results[COLUMNS.ENTITY] = self.get_entity_column(measure, results)
        results[COLUMNS.SUB_ENTITY] = self.get_sub_entity_column(measure, results)

        leading_columns = [
            COLUMNS.MEASURE,
            COLUMNS.ENTITY_TYPE,
            COLUMNS.ENTITY,
            COLUMNS.SUB_ENTITY,
        ]
        # Build the exclusion set once instead of concatenating lists per column.
        reserved_columns = set(leading_columns) | {COLUMNS.VALUE}
        middle_columns = [c for c in results.columns if c not in reserved_columns]
        return results[leading_columns + middle_columns + [COLUMNS.VALUE]]

    def format(self, measure: str, results: pd.DataFrame) -> pd.DataFrame:
        """Hook for subclass-specific reshaping; returns ``results`` unchanged."""
        return results

    def get_measure_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
        """Return a Series filled with ``measure``, aligned to ``results``."""
        return pd.Series(measure, index=results.index)

    def get_entity_type_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
        """Return an all-``None`` entity type column by default."""
        return pd.Series(None, index=results.index)

    def get_entity_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
        """Return an all-``None`` entity column by default."""
        return pd.Series(None, index=results.index)

    def get_sub_entity_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
        """Return an all-``None`` sub-entity column by default."""
        return pd.Series(None, index=results.index)
@@ -0,0 +1,138 @@
1
+ """
2
+ ==============
3
+ Risk Observers
4
+ ==============
5
+
6
+ This module contains tools for observing risk exposure during the simulation.
7
+
8
+ """
9
+
10
+ from typing import Any, Dict, List, Optional
11
+
12
+ import pandas as pd
13
+ from layered_config_tree import LayeredConfigTree
14
+ from vivarium.framework.engine import Builder
15
+
16
+ from vivarium_public_health.results.columns import COLUMNS
17
+ from vivarium_public_health.results.observer import PublicHealthObserver
18
+ from vivarium_public_health.utilities import to_years
19
+
20
+
21
class CategoricalRiskObserver(PublicHealthObserver):
    """An observer for a categorical risk factor.

    Observes category person time for a risk factor.

    By default, this observer computes aggregate categorical person time
    over the full course of the simulation. It can be configured to add or
    remove stratification groups to the default groups defined by a
    :class:`ResultsStratifier`.

    In the model specification, your configuration for this component should
    be specified as, e.g.:

    .. code-block:: yaml

        configuration:
            stratification:
                risk_name:
                    exclude:
                        - "sex"
                    include:
                        - "sample_stratification"
    """

    ##############
    # Properties #
    ##############

    @property
    def configuration_defaults(self) -> Dict[str, Any]:
        """
        A dictionary containing the defaults for any configurations managed by
        this component.

        The parent's stratification defaults are placed under the risk name.
        That is where ``get_configuration`` reads them from.
        """
        return {
            "stratification": {
                f"{self.risk}": super().configuration_defaults["stratification"][
                    self.get_configuration_name()
                ]
            }
        }

    @property
    def columns_required(self) -> Optional[List[str]]:
        """Population columns this component reads."""
        return ["alive"]

    #####################
    # Lifecycle methods #
    #####################

    def __init__(self, risk: str) -> None:
        """
        Parameters
        ----------
        risk
            Name of a risk. The exposure pipeline is read from
            ``"<risk>.exposure"``.
        """
        super().__init__()
        self.risk = risk
        self.exposure_pipeline_name = f"{self.risk}.exposure"

    #################
    # Setup methods #
    #################

    def setup(self, builder: Builder) -> None:
        """Store the step-size getter and load the risk's exposure categories."""
        self.step_size = builder.time.step_size()
        self.categories = builder.data.load(f"risk_factor.{self.risk}.categories")

    def get_configuration(self, builder: Builder) -> LayeredConfigTree:
        """Return the stratification configuration stored under the risk name."""
        return builder.configuration.stratification[self.risk]

    def register_observations(self, builder: Builder) -> None:
        """Register a stratification by exposure category and a person-time observation."""
        builder.results.register_stratification(
            f"{self.risk}",
            list(self.categories.keys()),
            requires_values=[self.exposure_pipeline_name],
        )
        self.register_adding_observation(
            builder=builder,
            name=f"person_time_{self.risk}",
            # Plain string: the original f-prefix had no placeholders (F541).
            pop_filter='alive == "alive" and tracked==True',
            when="time_step__prepare",
            requires_columns=["alive"],
            requires_values=[self.exposure_pipeline_name],
            additional_stratifications=self.configuration.include + [self.risk],
            excluded_stratifications=self.configuration.exclude,
            aggregator=self.aggregate_risk_category_person_time,
        )

    ###############
    # Aggregators #
    ###############

    def aggregate_risk_category_person_time(self, x: pd.DataFrame) -> float:
        """Person time, in years, contributed by the rows of ``x`` over one time step."""
        return len(x) * to_years(self.step_size())

    ##############################
    # Results formatting methods #
    ##############################

    def format(self, measure: str, results: pd.DataFrame) -> pd.DataFrame:
        """Flatten the index and rename the risk-category column to the sub-entity column."""
        results = results.reset_index()
        results.rename(columns={self.risk: COLUMNS.SUB_ENTITY}, inplace=True)
        return results

    def get_measure_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
        """Label every row with the measure "person_time"."""
        return pd.Series("person_time", index=results.index)

    def get_entity_type_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
        """Label every row with the entity type "rei"."""
        return pd.Series("rei", index=results.index)

    def get_entity_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
        """Label every row with the risk name."""
        return pd.Series(self.risk, index=results.index)

    def get_sub_entity_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
        """Return the sub-entity column."""
        # The sub-entity col was created in the 'format' method
        return results[COLUMNS.SUB_ENTITY]
@@ -0,0 +1,18 @@
1
from dataclasses import dataclass
from typing import Any
2
+
3
+
4
@dataclass
class SimpleCause:
    """A simple dataclass to represent the bare minimum information needed
    for observers, e.g. 'all_causes' as a cause of disability. It also
    includes a class method to convert a provided disease state into a
    ``SimpleCause`` instance.
    """

    # Identifier used for the cause (e.g. "other_causes").
    state_id: str
    # Model the cause belongs to (for disease states, their ``model`` attribute).
    model: str
    # Type label for the cause (e.g. "cause").
    cause_type: str

    @classmethod
    def create_from_disease_state(cls, disease_state: Any) -> "SimpleCause":
        """Build a ``SimpleCause`` from a disease-state-like object.

        Parameters
        ----------
        disease_state
            An *instance* exposing ``state_id``, ``model`` and ``cause_type``
            attributes. (The original annotation said ``type``, but the method
            reads instance attributes.)

        Returns
        -------
        SimpleCause
            A new instance copying those three attributes.
        """
        return cls(disease_state.state_id, disease_state.model, disease_state.cause_type)
@@ -7,6 +7,8 @@ This module contains tools for stratifying observed quantities
7
7
  by specified characteristics through the vivarium results interface.
8
8
  """
9
9
 
10
+ from __future__ import annotations
11
+
10
12
  import pandas as pd
11
13
  from vivarium import Component
12
14
  from vivarium.framework.engine import Builder
@@ -33,14 +35,14 @@ class ResultsStratifier(Component):
33
35
  builder.results.register_stratification(
34
36
  "age_group",
35
37
  self.age_bins["age_group_name"].to_list(),
36
- self.map_age_groups,
38
+ mapper=self.map_age_groups,
37
39
  is_vectorized=True,
38
40
  requires_columns=["age"],
39
41
  )
40
42
  builder.results.register_stratification(
41
43
  "current_year",
42
44
  [str(year) for year in range(self.start_year, self.end_year + 1)],
43
- self.map_year,
45
+ mapper=self.map_year,
44
46
  is_vectorized=True,
45
47
  requires_columns=["current_time"],
46
48
  )
@@ -49,7 +51,7 @@ class ResultsStratifier(Component):
49
51
  # builder.results.register_stratification(
50
52
  # "event_year",
51
53
  # [str(year) for year in range(self.start_year, self.end_year + 1)],
52
- # self.map_year,
54
+ # mapper=self.map_year,
53
55
  # is_vectorized=True,
54
56
  # requires_columns=["event_time"],
55
57
  # )
@@ -66,7 +68,7 @@ class ResultsStratifier(Component):
66
68
  # builder.results.register_stratification(
67
69
  # "exit_year",
68
70
  # [str(year) for year in range(self.start_year, self.end_year + 1)] + ["nan"],
69
- # self.map_year,
71
+ # mapper=self.map_year,
70
72
  # is_vectorized=True,
71
73
  # requires_columns=["exit_time"],
72
74
  # )
@@ -78,13 +80,13 @@ class ResultsStratifier(Component):
78
80
  # Mappers #
79
81
  ###########
80
82
 
81
- def map_age_groups(self, pop: pd.DataFrame) -> pd.Series:
83
+ def map_age_groups(self, pop: pd.DataFrame) -> pd.Series[str]:
82
84
  """Map age with age group name strings
83
85
 
84
86
  Parameters
85
87
  ----------
86
88
  pop
87
- A DataFrame with one column, an age to be mapped to an age group name string
89
+ A pd.DataFrame with one column, an age to be mapped to an age group name string
88
90
 
89
91
  Returns
90
92
  ------
@@ -97,13 +99,13 @@ class ResultsStratifier(Component):
97
99
  return age_group
98
100
 
99
101
  @staticmethod
100
- def map_year(pop: pd.DataFrame) -> pd.Series:
102
+ def map_year(pop: pd.DataFrame) -> pd.Series[str]:
101
103
  """Map datetime with year
102
104
 
103
105
  Parameters
104
106
  ----------
105
107
  pop
106
- A DataFrame with one column, a datetime to be mapped to year
108
+ A pd.DataFrame with one column, a datetime to be mapped to year
107
109
 
108
110
  Returns
109
111
  ------
@@ -1,6 +1,5 @@
1
1
  from .base_risk import Risk
2
- from .distributions import get_distribution
3
- from .effect import RiskEffect
2
+ from .effect import NonLogLinearRiskEffect, RiskEffect
4
3
  from .implementations.low_birth_weight_and_short_gestation import (
5
4
  LBWSGDistribution,
6
5
  LBWSGRisk,