vivarium-public-health 2.3.2__py3-none-any.whl → 3.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48) hide show
  1. vivarium_public_health/_version.py +1 -1
  2. vivarium_public_health/disease/model.py +23 -21
  3. vivarium_public_health/disease/models.py +1 -0
  4. vivarium_public_health/disease/special_disease.py +40 -41
  5. vivarium_public_health/disease/state.py +42 -125
  6. vivarium_public_health/disease/transition.py +70 -27
  7. vivarium_public_health/mslt/delay.py +1 -0
  8. vivarium_public_health/mslt/disease.py +1 -0
  9. vivarium_public_health/mslt/intervention.py +1 -0
  10. vivarium_public_health/mslt/magic_wand_components.py +1 -0
  11. vivarium_public_health/mslt/observer.py +1 -0
  12. vivarium_public_health/mslt/population.py +1 -0
  13. vivarium_public_health/plugins/parser.py +61 -31
  14. vivarium_public_health/population/add_new_birth_cohorts.py +2 -3
  15. vivarium_public_health/population/base_population.py +2 -1
  16. vivarium_public_health/population/mortality.py +83 -80
  17. vivarium_public_health/{metrics → results}/__init__.py +2 -0
  18. vivarium_public_health/results/columns.py +22 -0
  19. vivarium_public_health/results/disability.py +187 -0
  20. vivarium_public_health/results/disease.py +222 -0
  21. vivarium_public_health/results/mortality.py +186 -0
  22. vivarium_public_health/results/observer.py +78 -0
  23. vivarium_public_health/results/risk.py +138 -0
  24. vivarium_public_health/results/simple_cause.py +18 -0
  25. vivarium_public_health/{metrics → results}/stratification.py +10 -8
  26. vivarium_public_health/risks/__init__.py +1 -2
  27. vivarium_public_health/risks/base_risk.py +134 -29
  28. vivarium_public_health/risks/data_transformations.py +65 -326
  29. vivarium_public_health/risks/distributions.py +315 -145
  30. vivarium_public_health/risks/effect.py +376 -75
  31. vivarium_public_health/risks/implementations/low_birth_weight_and_short_gestation.py +61 -89
  32. vivarium_public_health/treatment/magic_wand.py +1 -0
  33. vivarium_public_health/treatment/scale_up.py +1 -0
  34. vivarium_public_health/treatment/therapeutic_inertia.py +1 -0
  35. vivarium_public_health/utilities.py +17 -2
  36. {vivarium_public_health-2.3.2.dist-info → vivarium_public_health-3.0.0.dist-info}/METADATA +13 -3
  37. vivarium_public_health-3.0.0.dist-info/RECORD +49 -0
  38. {vivarium_public_health-2.3.2.dist-info → vivarium_public_health-3.0.0.dist-info}/WHEEL +1 -1
  39. vivarium_public_health/metrics/disability.py +0 -118
  40. vivarium_public_health/metrics/disease.py +0 -136
  41. vivarium_public_health/metrics/mortality.py +0 -144
  42. vivarium_public_health/metrics/risk.py +0 -110
  43. vivarium_public_health/testing/__init__.py +0 -0
  44. vivarium_public_health/testing/mock_artifact.py +0 -145
  45. vivarium_public_health/testing/utils.py +0 -71
  46. vivarium_public_health-2.3.2.dist-info/RECORD +0 -49
  47. {vivarium_public_health-2.3.2.dist-info → vivarium_public_health-3.0.0.dist-info}/LICENSE.txt +0 -0
  48. {vivarium_public_health-2.3.2.dist-info → vivarium_public_health-3.0.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,186 @@
1
+ """
2
+ ==================
3
+ Mortality Observer
4
+ ==================
5
+
6
+ This module contains tools for observing cause-specific and
7
+ excess mortality in the simulation, including "other causes".
8
+
9
+ """
10
+
11
+ from typing import Any, Dict, List
12
+
13
+ import pandas as pd
14
+ from layered_config_tree import LayeredConfigTree
15
+ from pandas.api.types import CategoricalDtype
16
+ from vivarium.framework.engine import Builder
17
+
18
+ from vivarium_public_health.disease import DiseaseState, RiskAttributableDisease
19
+ from vivarium_public_health.results.columns import COLUMNS
20
+ from vivarium_public_health.results.observer import PublicHealthObserver
21
+ from vivarium_public_health.results.simple_cause import SimpleCause
22
+
23
+
24
class MortalityObserver(PublicHealthObserver):
    """An observer for cause-specific deaths and ylls (including "other causes").

    By default, this counts cause-specific deaths and years of life lost over
    the full course of the simulation. It can be configured to add or remove
    stratification groups to the default groups defined by a
    :class:`ResultsStratifier`. The aggregate configuration key can be set to
    True to aggregate all deaths and ylls into a single observation and remove
    the stratification by cause of death to improve runtime.

    In the model specification, your configuration for this component should
    be specified as, e.g.:

    .. code-block:: yaml

        configuration:
            stratification:
                mortality:
                    exclude:
                        - "sex"
                    include:
                        - "sample_stratification"

    This observer needs to access the has_excess_mortality attribute of the causes
    we're observing, but this attribute gets defined in the setup of the cause models.
    As a result, the model specification should list this observer after causes.
    """

    def __init__(self) -> None:
        super().__init__()
        # Population columns requested by the "deaths" observation.
        self.required_death_columns = ["alive", "exit_time", "cause_of_death"]
        # Population columns requested by the "ylls" observation.
        self.required_yll_columns = [
            "alive",
            "cause_of_death",
            "exit_time",
            "years_of_life_lost",
        ]

    ##############
    # Properties #
    ##############

    @property
    def mortality_classes(self) -> List[type]:
        """Component classes that are searched for causes of death."""
        return [DiseaseState, RiskAttributableDisease]

    @property
    def configuration_defaults(self) -> Dict[str, Any]:
        """
        A dictionary containing the defaults for any configurations managed by
        this component.

        Extends the inherited defaults with ``aggregate = False`` under this
        observer's stratification key.
        """
        config_defaults = super().configuration_defaults
        config_defaults["stratification"][self.get_configuration_name()]["aggregate"] = False
        return config_defaults

    @property
    def columns_required(self) -> List[str]:
        """Population columns this observer declares as required."""
        return [
            "alive",
            "years_of_life_lost",
            "cause_of_death",
            "exit_time",
        ]

    #################
    # Setup methods #
    #################

    def setup(self, builder: Builder) -> None:
        """Store the simulation clock and collect the causes of death."""
        self.clock = builder.time.clock()
        self.set_causes_of_death(builder)

    def set_causes_of_death(self, builder: Builder) -> None:
        """Populate ``self.causes_of_death``.

        Collects every component of a type in ``mortality_classes`` whose
        ``has_excess_mortality`` attribute is truthy, converts each to a
        ``SimpleCause`` and appends the ``not_dead`` and ``other_causes``
        placeholders.
        """
        causes_of_death = [
            cause
            for cause in builder.components.get_components_by_type(
                tuple(self.mortality_classes)
            )
            if cause.has_excess_mortality
        ]

        # Convert to SimpleCauses and add on other_causes and not_dead
        self.causes_of_death = [
            SimpleCause.create_from_disease_state(cause) for cause in causes_of_death
        ] + [
            SimpleCause("not_dead", "not_dead", "cause"),
            SimpleCause("other_causes", "other_causes", "cause"),
        ]

    def get_configuration(self, builder: Builder) -> LayeredConfigTree:
        """Return this observer's sub-tree of the stratification configuration."""
        return builder.configuration.stratification[self.get_configuration_name()]

    def register_observations(self, builder: Builder) -> None:
        """Register the ``deaths`` and ``ylls`` adding observations.

        Both observations are filtered to dead, tracked simulants. Unless the
        ``aggregate`` configuration flag is set, a ``cause_of_death``
        stratification is registered (with ``not_dead`` always excluded) and
        added to the stratifications of both observations.
        """
        pop_filter = 'alive == "dead" and tracked == True'
        # Work on a copy: extending the list returned by the configuration
        # in place would silently add "cause_of_death" to the configured
        # ``include`` value itself.
        additional_stratifications = list(self.configuration.include)
        if not self.configuration.aggregate:
            # manually append 'not_dead' as an excluded cause
            excluded_categories = (
                builder.configuration.stratification.excluded_categories.to_dict().get(
                    "cause_of_death", []
                )
            ) + ["not_dead"]
            builder.results.register_stratification(
                "cause_of_death",
                [cause.state_id for cause in self.causes_of_death],
                excluded_categories=excluded_categories,
                requires_columns=["cause_of_death"],
            )
            additional_stratifications.append("cause_of_death")
        self.register_adding_observation(
            builder=builder,
            name="deaths",
            pop_filter=pop_filter,
            requires_columns=self.required_death_columns,
            additional_stratifications=additional_stratifications,
            excluded_stratifications=self.configuration.exclude,
            aggregator=self.count_deaths,
        )
        self.register_adding_observation(
            builder=builder,
            name="ylls",
            pop_filter=pop_filter,
            requires_columns=self.required_yll_columns,
            additional_stratifications=additional_stratifications,
            excluded_stratifications=self.configuration.exclude,
            aggregator=self.calculate_ylls,
        )

    ###############
    # Aggregators #
    ###############

    def count_deaths(self, x: pd.DataFrame) -> float:
        """Count the rows of ``x`` whose ``exit_time`` is later than the clock time.

        NOTE(review): presumably this selects simulants who died during the
        current time step -- confirm against the mortality component.
        """
        died_of_cause = x["exit_time"] > self.clock()
        # Vectorised reduction instead of iterating the Series with builtin sum.
        return died_of_cause.sum()

    def calculate_ylls(self, x: pd.DataFrame) -> float:
        """Sum ``years_of_life_lost`` over rows whose ``exit_time`` is later than the clock time."""
        died_of_cause = x["exit_time"] > self.clock()
        return x.loc[died_of_cause, "years_of_life_lost"].sum()

    ##############################
    # Results formatting methods #
    ##############################

    def format(self, measure: str, results: pd.DataFrame) -> pd.DataFrame:
        """Reset the index and create the entity column.

        When aggregating, every row is labelled ``"all_causes"``; otherwise
        the ``cause_of_death`` column is renamed to the entity column.
        """
        results = results.reset_index()
        if self.configuration.aggregate:
            results[COLUMNS.ENTITY] = "all_causes"
        else:
            results.rename(columns={"cause_of_death": COLUMNS.ENTITY}, inplace=True)
        return results

    def get_entity_type_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
        """Map each entity (a cause ``state_id``) to its ``cause_type`` as a categorical."""
        entity_type_map = {cause.state_id: cause.cause_type for cause in self.causes_of_death}
        return results[COLUMNS.ENTITY].map(entity_type_map).astype(CategoricalDtype())

    def get_entity_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
        """Return the entity column unchanged."""
        # The entity col was created in the 'format' method
        return results[COLUMNS.ENTITY]

    def get_sub_entity_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
        """Return the entity column; the sub-entity is the same as the entity here."""
        return results[COLUMNS.ENTITY]
@@ -0,0 +1,78 @@
1
+ from typing import Callable, List, Optional, Union
2
+
3
+ import pandas as pd
4
+ from vivarium.framework.engine import Builder
5
+ from vivarium.framework.results import Observer
6
+
7
+ from vivarium_public_health.results.columns import COLUMNS
8
+
9
+
10
class PublicHealthObserver(Observer):
    """A convenience class for typical public health observers. It provides
    an entry point for registering the most common observation type
    as well as standardized results formatting methods to overwrite as necessary.
    """

    def register_adding_observation(
        self,
        builder: Builder,
        name: str,
        pop_filter: str,
        when: str = "collect_metrics",
        requires_columns: Optional[List[str]] = None,
        requires_values: Optional[List[str]] = None,
        additional_stratifications: Optional[List[str]] = None,
        excluded_stratifications: Optional[List[str]] = None,
        aggregator_sources: Optional[List[str]] = None,
        aggregator: Callable[[pd.DataFrame], Union[float, pd.Series]] = len,
    ) -> None:
        """Register an adding observation that is formatted by ``format_results``.

        All arguments are forwarded to
        ``builder.results.register_adding_observation``; the only value this
        wrapper supplies itself is ``results_formatter=self.format_results``.

        Parameters
        ----------
        builder
            The simulation builder whose results interface is used.
        name
            Name of the observation.
        pop_filter
            Population filter string for the observation.
        when
            Name of the event at which the observation is made.
        requires_columns
            Forwarded as ``requires_columns``. ``None`` means an empty list.
        requires_values
            Forwarded as ``requires_values``. ``None`` means an empty list.
        additional_stratifications
            Forwarded as ``additional_stratifications``. ``None`` means an
            empty list.
        excluded_stratifications
            Forwarded as ``excluded_stratifications``. ``None`` means an
            empty list.
        aggregator_sources
            Forwarded unchanged as ``aggregator_sources``.
        aggregator
            Function applied to each group of the filtered population.
        """
        # ``None`` sentinels replace mutable ``[]`` defaults so that every
        # call gets its own fresh list instead of sharing one across calls.
        builder.results.register_adding_observation(
            name=name,
            pop_filter=pop_filter,
            when=when,
            requires_columns=[] if requires_columns is None else requires_columns,
            requires_values=[] if requires_values is None else requires_values,
            results_formatter=self.format_results,
            additional_stratifications=(
                [] if additional_stratifications is None else additional_stratifications
            ),
            excluded_stratifications=(
                [] if excluded_stratifications is None else excluded_stratifications
            ),
            aggregator_sources=aggregator_sources,
            aggregator=aggregator,
        )

    def format_results(self, measure: str, results: pd.DataFrame) -> pd.DataFrame:
        """Top-level results formatter that calls standard sub-methods to be
        overwritten as necessary.

        The returned frame has the measure, entity type, entity and sub-entity
        columns first, then every other column in its existing order, and the
        value column last.
        """

        results = self.format(measure, results)
        results[COLUMNS.MEASURE] = self.get_measure_column(measure, results)
        results[COLUMNS.ENTITY_TYPE] = self.get_entity_type_column(measure, results)
        results[COLUMNS.ENTITY] = self.get_entity_column(measure, results)
        results[COLUMNS.SUB_ENTITY] = self.get_sub_entity_column(measure, results)

        leading_columns = [
            COLUMNS.MEASURE,
            COLUMNS.ENTITY_TYPE,
            COLUMNS.ENTITY,
            COLUMNS.SUB_ENTITY,
        ]
        # Built once rather than on every iteration of the comprehension.
        positioned_columns = leading_columns + [COLUMNS.VALUE]
        middle_columns = [c for c in results.columns if c not in positioned_columns]
        ordered_columns = leading_columns + middle_columns + [COLUMNS.VALUE]
        return results[ordered_columns]

    def format(self, measure: str, results: pd.DataFrame) -> pd.DataFrame:
        """Hook for observer-specific reshaping; returns ``results`` unchanged by default."""
        return results

    def get_measure_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
        """Return a column holding ``measure`` for every row of ``results``."""
        return pd.Series(measure, index=results.index)

    def get_entity_type_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
        """Return an entity type column built from ``None``; override to supply values."""
        return pd.Series(None, index=results.index)

    def get_entity_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
        """Return an entity column built from ``None``; override to supply values."""
        return pd.Series(None, index=results.index)

    def get_sub_entity_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
        """Return a sub-entity column built from ``None``; override to supply values."""
        return pd.Series(None, index=results.index)
@@ -0,0 +1,138 @@
1
+ """
2
+ ==============
3
+ Risk Observers
4
+ ==============
5
+
6
+ This module contains tools for observing risk exposure during the simulation.
7
+
8
+ """
9
+
10
+ from typing import Any, Dict, List, Optional
11
+
12
+ import pandas as pd
13
+ from layered_config_tree import LayeredConfigTree
14
+ from vivarium.framework.engine import Builder
15
+
16
+ from vivarium_public_health.results.columns import COLUMNS
17
+ from vivarium_public_health.results.observer import PublicHealthObserver
18
+ from vivarium_public_health.utilities import to_years
19
+
20
+
21
class CategoricalRiskObserver(PublicHealthObserver):
    """An observer for a categorical risk factor.

    Observes category person time for a risk factor.

    By default, this observer computes aggregate categorical person time
    over the full course of the simulation. It can be configured to add or
    remove stratification groups to the default groups defined by a
    ResultsStratifier.

    In the model specification, your configuration for this component should
    be specified as, e.g.:

    .. code-block:: yaml

        configuration:
            stratification:
                risk_name:
                    exclude:
                        - "sex"
                    include:
                        - "sample_stratification"
    """

    #####################
    # Lifecycle methods #
    #####################

    def __init__(self, risk: str) -> None:
        """
        Parameters
        ----------
        risk: name of a risk

        """
        super().__init__()
        self.risk = risk
        # Name of the value pipeline used to stratify by exposure category.
        self.exposure_pipeline_name = f"{self.risk}.exposure"

    ##############
    # Properties #
    ##############

    @property
    def configuration_defaults(self) -> Dict[str, Any]:
        """
        A dictionary containing the defaults for any configurations managed by
        this component.

        The inherited stratification defaults are re-keyed under the risk name.
        """
        inherited_defaults = super().configuration_defaults
        observer_defaults = inherited_defaults["stratification"][
            self.get_configuration_name()
        ]
        return {"stratification": {f"{self.risk}": observer_defaults}}

    @property
    def columns_required(self) -> Optional[List[str]]:
        """Population columns this observer declares as required."""
        return ["alive"]

    #################
    # Setup methods #
    #################

    def setup(self, builder: Builder) -> None:
        """Store the step-size callable and load the risk's categories."""
        self.step_size = builder.time.step_size()
        categories_key = f"risk_factor.{self.risk}.categories"
        self.categories = builder.data.load(categories_key)

    def get_configuration(self, builder: Builder) -> LayeredConfigTree:
        """Return the stratification configuration sub-tree keyed by the risk name."""
        stratification_config = builder.configuration.stratification
        return stratification_config[self.risk]

    def register_observations(self, builder: Builder) -> None:
        """Register the risk-category stratification and the person-time observation."""
        category_names = list(self.categories.keys())
        builder.results.register_stratification(
            f"{self.risk}",
            category_names,
            requires_values=[self.exposure_pipeline_name],
        )

        living_and_tracked = 'alive == "alive" and tracked==True'
        self.register_adding_observation(
            builder=builder,
            name=f"person_time_{self.risk}",
            pop_filter=living_and_tracked,
            when="time_step__prepare",
            requires_columns=["alive"],
            requires_values=[self.exposure_pipeline_name],
            additional_stratifications=self.configuration.include + [self.risk],
            excluded_stratifications=self.configuration.exclude,
            aggregator=self.aggregate_risk_category_person_time,
        )

    ###############
    # Aggregators #
    ###############

    def aggregate_risk_category_person_time(self, x: pd.DataFrame) -> float:
        """Return the number of rows in ``x`` times the step size converted by ``to_years``."""
        simulant_count = len(x)
        years_per_step = to_years(self.step_size())
        return simulant_count * years_per_step

    ##############################
    # Results formatting methods #
    ##############################

    def format(self, measure: str, results: pd.DataFrame) -> pd.DataFrame:
        """Reset the index and rename the risk column to the sub-entity column."""
        flattened = results.reset_index()
        return flattened.rename(columns={self.risk: COLUMNS.SUB_ENTITY})

    def get_measure_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
        """Label every row with the measure ``"person_time"``."""
        row_index = results.index
        return pd.Series("person_time", index=row_index)

    def get_entity_type_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
        """Label every row with the entity type ``"rei"``."""
        row_index = results.index
        return pd.Series("rei", index=row_index)

    def get_entity_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
        """Label every row with the risk name."""
        row_index = results.index
        return pd.Series(self.risk, index=row_index)

    def get_sub_entity_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
        """Return the sub-entity column unchanged."""
        # The sub-entity col was created in the 'format' method
        return results[COLUMNS.SUB_ENTITY]
@@ -0,0 +1,18 @@
1
+ from dataclasses import dataclass
2
+
3
+
4
@dataclass
class SimpleCause:
    """Minimal description of a cause, carrying only what observers need
    (e.g. 'all_causes' as a cause of disability).

    Instances can be built directly or derived from a disease state via
    :meth:`create_from_disease_state`.
    """

    # Identifier of the cause/state.
    state_id: str
    # Value copied from the source object's ``model`` attribute.
    model: str
    # Value copied from the source object's ``cause_type`` attribute.
    cause_type: str

    @classmethod
    def create_from_disease_state(cls, disease_state: type) -> "SimpleCause":
        """Build a ``SimpleCause`` from an object exposing ``state_id``,
        ``model`` and ``cause_type`` attributes.

        NOTE(review): the parameter is annotated ``type``, but only attribute
        access is performed, so an instance works as well -- confirm the
        intended annotation.
        """
        return cls(
            state_id=disease_state.state_id,
            model=disease_state.model,
            cause_type=disease_state.cause_type,
        )
@@ -7,6 +7,8 @@ This module contains tools for stratifying observed quantities
7
7
  by specified characteristics through the vivarium results interface.
8
8
  """
9
9
 
10
+ from __future__ import annotations
11
+
10
12
  import pandas as pd
11
13
  from vivarium import Component
12
14
  from vivarium.framework.engine import Builder
@@ -33,14 +35,14 @@ class ResultsStratifier(Component):
33
35
  builder.results.register_stratification(
34
36
  "age_group",
35
37
  self.age_bins["age_group_name"].to_list(),
36
- self.map_age_groups,
38
+ mapper=self.map_age_groups,
37
39
  is_vectorized=True,
38
40
  requires_columns=["age"],
39
41
  )
40
42
  builder.results.register_stratification(
41
43
  "current_year",
42
44
  [str(year) for year in range(self.start_year, self.end_year + 1)],
43
- self.map_year,
45
+ mapper=self.map_year,
44
46
  is_vectorized=True,
45
47
  requires_columns=["current_time"],
46
48
  )
@@ -49,7 +51,7 @@ class ResultsStratifier(Component):
49
51
  # builder.results.register_stratification(
50
52
  # "event_year",
51
53
  # [str(year) for year in range(self.start_year, self.end_year + 1)],
52
- # self.map_year,
54
+ # mapper=self.map_year,
53
55
  # is_vectorized=True,
54
56
  # requires_columns=["event_time"],
55
57
  # )
@@ -66,7 +68,7 @@ class ResultsStratifier(Component):
66
68
  # builder.results.register_stratification(
67
69
  # "exit_year",
68
70
  # [str(year) for year in range(self.start_year, self.end_year + 1)] + ["nan"],
69
- # self.map_year,
71
+ # mapper=self.map_year,
70
72
  # is_vectorized=True,
71
73
  # requires_columns=["exit_time"],
72
74
  # )
@@ -78,13 +80,13 @@ class ResultsStratifier(Component):
78
80
  # Mappers #
79
81
  ###########
80
82
 
81
- def map_age_groups(self, pop: pd.DataFrame) -> pd.Series:
83
+ def map_age_groups(self, pop: pd.DataFrame) -> pd.Series[str]:
82
84
  """Map age with age group name strings
83
85
 
84
86
  Parameters
85
87
  ----------
86
88
  pop
87
- A DataFrame with one column, an age to be mapped to an age group name string
89
+ A pd.DataFrame with one column, an age to be mapped to an age group name string
88
90
 
89
91
  Returns
90
92
  ------
@@ -97,13 +99,13 @@ class ResultsStratifier(Component):
97
99
  return age_group
98
100
 
99
101
  @staticmethod
100
- def map_year(pop: pd.DataFrame) -> pd.Series:
102
+ def map_year(pop: pd.DataFrame) -> pd.Series[str]:
101
103
  """Map datetime with year
102
104
 
103
105
  Parameters
104
106
  ----------
105
107
  pop
106
- A DataFrame with one column, a datetime to be mapped to year
108
+ A pd.DataFrame with one column, a datetime to be mapped to year
107
109
 
108
110
  Returns
109
111
  ------
@@ -1,6 +1,5 @@
1
1
  from .base_risk import Risk
2
- from .distributions import get_distribution
3
- from .effect import RiskEffect
2
+ from .effect import NonLogLinearRiskEffect, RiskEffect
4
3
  from .implementations.low_birth_weight_and_short_gestation import (
5
4
  LBWSGDistribution,
6
5
  LBWSGRisk,