vivarium-public-health 2.3.3__py3-none-any.whl → 3.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vivarium_public_health/_version.py +1 -1
 - vivarium_public_health/disease/model.py +23 -21
 - vivarium_public_health/disease/models.py +1 -0
 - vivarium_public_health/disease/special_disease.py +40 -41
 - vivarium_public_health/disease/state.py +42 -125
 - vivarium_public_health/disease/transition.py +70 -27
 - vivarium_public_health/mslt/delay.py +1 -0
 - vivarium_public_health/mslt/disease.py +1 -0
 - vivarium_public_health/mslt/intervention.py +1 -0
 - vivarium_public_health/mslt/magic_wand_components.py +1 -0
 - vivarium_public_health/mslt/observer.py +1 -0
 - vivarium_public_health/mslt/population.py +1 -0
 - vivarium_public_health/plugins/parser.py +61 -31
 - vivarium_public_health/population/add_new_birth_cohorts.py +2 -3
 - vivarium_public_health/population/base_population.py +2 -1
 - vivarium_public_health/population/mortality.py +83 -80
 - vivarium_public_health/{metrics → results}/__init__.py +2 -0
 - vivarium_public_health/results/columns.py +22 -0
 - vivarium_public_health/results/disability.py +187 -0
 - vivarium_public_health/results/disease.py +222 -0
 - vivarium_public_health/results/mortality.py +186 -0
 - vivarium_public_health/results/observer.py +78 -0
 - vivarium_public_health/results/risk.py +138 -0
 - vivarium_public_health/results/simple_cause.py +18 -0
 - vivarium_public_health/{metrics → results}/stratification.py +10 -8
 - vivarium_public_health/risks/__init__.py +1 -2
 - vivarium_public_health/risks/base_risk.py +134 -29
 - vivarium_public_health/risks/data_transformations.py +65 -326
 - vivarium_public_health/risks/distributions.py +315 -145
 - vivarium_public_health/risks/effect.py +376 -75
 - vivarium_public_health/risks/implementations/low_birth_weight_and_short_gestation.py +61 -89
 - vivarium_public_health/treatment/magic_wand.py +1 -0
 - vivarium_public_health/treatment/scale_up.py +1 -0
 - vivarium_public_health/treatment/therapeutic_inertia.py +1 -0
 - vivarium_public_health/utilities.py +17 -2
 - {vivarium_public_health-2.3.3.dist-info → vivarium_public_health-3.0.1.dist-info}/METADATA +12 -2
 - vivarium_public_health-3.0.1.dist-info/RECORD +49 -0
 - {vivarium_public_health-2.3.3.dist-info → vivarium_public_health-3.0.1.dist-info}/WHEEL +1 -1
 - vivarium_public_health/metrics/disability.py +0 -118
 - vivarium_public_health/metrics/disease.py +0 -136
 - vivarium_public_health/metrics/mortality.py +0 -144
 - vivarium_public_health/metrics/risk.py +0 -110
 - vivarium_public_health/testing/__init__.py +0 -0
 - vivarium_public_health/testing/mock_artifact.py +0 -145
 - vivarium_public_health/testing/utils.py +0 -71
 - vivarium_public_health-2.3.3.dist-info/RECORD +0 -49
 - {vivarium_public_health-2.3.3.dist-info → vivarium_public_health-3.0.1.dist-info}/LICENSE.txt +0 -0
 - {vivarium_public_health-2.3.3.dist-info → vivarium_public_health-3.0.1.dist-info}/top_level.txt +0 -0
 
| 
         @@ -6,7 +6,8 @@ Disease Transitions 
     | 
|
| 
       6 
6 
     | 
    
         
             
            This module contains tools to model transitions between disease states.
         
     | 
| 
       7 
7 
     | 
    
         | 
| 
       8 
8 
     | 
    
         
             
            """
         
     | 
| 
       9 
     | 
    
         
            -
             
     | 
| 
      
 9 
     | 
    
         
            +
             
     | 
| 
      
 10 
     | 
    
         
            +
            from typing import TYPE_CHECKING, Any, Callable, Dict, Union
         
     | 
| 
       10 
11 
     | 
    
         | 
| 
       11 
12 
     | 
    
         
             
            import pandas as pd
         
     | 
| 
       12 
13 
     | 
    
         
             
            from vivarium.framework.engine import Builder
         
     | 
| 
         @@ -14,6 +15,8 @@ from vivarium.framework.state_machine import Transition, Trigger 
     | 
|
| 
       14 
15 
     | 
    
         
             
            from vivarium.framework.utilities import rate_to_probability
         
     | 
| 
       15 
16 
     | 
    
         
             
            from vivarium.framework.values import list_combiner, union_post_processor
         
     | 
| 
       16 
17 
     | 
    
         | 
| 
      
 18 
     | 
    
         
            +
            from vivarium_public_health.utilities import get_lookup_columns
         
     | 
| 
      
 19 
     | 
    
         
            +
             
     | 
| 
       17 
20 
     | 
    
         
             
            if TYPE_CHECKING:
         
     | 
| 
       18 
21 
     | 
    
         
             
                from vivarium_public_health.disease import BaseDiseaseState
         
     | 
| 
       19 
22 
     | 
    
         | 
| 
         @@ -25,8 +28,43 @@ class TransitionString(str): 
     | 
|
| 
       25 
28 
     | 
    
         
             
                    obj.from_state, obj.to_state = value.split("_TO_")
         
     | 
| 
       26 
29 
     | 
    
         
             
                    return obj
         
     | 
| 
       27 
30 
     | 
    
         | 
| 
      
 31 
     | 
    
         
            +
                def __getnewargs__(self):
         
     | 
| 
      
 32 
     | 
    
         
            +
                    return (self.from_state + "_TO_" + self.to_state,)
         
     | 
| 
      
 33 
     | 
    
         
            +
             
     | 
| 
       28 
34 
     | 
    
         | 
| 
       29 
35 
     | 
    
         
             
            class RateTransition(Transition):
         
     | 
| 
      
 36 
     | 
    
         
            +
             
     | 
| 
      
 37 
     | 
    
         
            +
                ##############
         
     | 
| 
      
 38 
     | 
    
         
            +
                # Properties #
         
     | 
| 
      
 39 
     | 
    
         
            +
                ##############
         
     | 
| 
      
 40 
     | 
    
         
            +
             
     | 
| 
      
 41 
     | 
    
         
            +
                @property
         
     | 
| 
      
 42 
     | 
    
         
            +
                def configuration_defaults(self) -> Dict[str, Any]:
         
     | 
| 
      
 43 
     | 
    
         
            +
                    return {
         
     | 
| 
      
 44 
     | 
    
         
            +
                        f"{self.name}": {
         
     | 
| 
      
 45 
     | 
    
         
            +
                            "data_sources": {
         
     | 
| 
      
 46 
     | 
    
         
            +
                                "transition_rate": "self::load_transition_rate",
         
     | 
| 
      
 47 
     | 
    
         
            +
                            },
         
     | 
| 
      
 48 
     | 
    
         
            +
                        },
         
     | 
| 
      
 49 
     | 
    
         
            +
                    }
         
     | 
| 
      
 50 
     | 
    
         
            +
             
     | 
| 
      
 51 
     | 
    
         
            +
                @property
         
     | 
| 
      
 52 
     | 
    
         
            +
                def transition_rate_pipeline_name(self) -> str:
         
     | 
| 
      
 53 
     | 
    
         
            +
                    if "incidence_rate" in self._get_data_functions:
         
     | 
| 
      
 54 
     | 
    
         
            +
                        pipeline_name = f"{self.output_state.state_id}.incidence_rate"
         
     | 
| 
      
 55 
     | 
    
         
            +
                    elif "remission_rate" in self._get_data_functions:
         
     | 
| 
      
 56 
     | 
    
         
            +
                        pipeline_name = f"{self.input_state.state_id}.remission_rate"
         
     | 
| 
      
 57 
     | 
    
         
            +
                    elif "transition_rate" in self._get_data_functions:
         
     | 
| 
      
 58 
     | 
    
         
            +
                        pipeline_name = (
         
     | 
| 
      
 59 
     | 
    
         
            +
                            f"{self.input_state.state_id}_to_{self.output_state.state_id}.transition_rate"
         
     | 
| 
      
 60 
     | 
    
         
            +
                        )
         
     | 
| 
      
 61 
     | 
    
         
            +
                    else:
         
     | 
| 
      
 62 
     | 
    
         
            +
                        raise ValueError(
         
     | 
| 
      
 63 
     | 
    
         
            +
                            "Cannot determine rate_transition pipeline name: "
         
     | 
| 
      
 64 
     | 
    
         
            +
                            "no valid data functions supplied."
         
     | 
| 
      
 65 
     | 
    
         
            +
                        )
         
     | 
| 
      
 66 
     | 
    
         
            +
                    return pipeline_name
         
     | 
| 
      
 67 
     | 
    
         
            +
             
     | 
| 
       30 
68 
     | 
    
         
             
                #####################
         
     | 
| 
       31 
69 
     | 
    
         
             
                # Lifecycle methods #
         
     | 
| 
       32 
70 
     | 
    
         
             
                #####################
         
     | 
| 
         @@ -47,19 +85,16 @@ class RateTransition(Transition): 
     | 
|
| 
       47 
85 
     | 
    
         | 
| 
       48 
86 
     | 
    
         
             
                # noinspection PyAttributeOutsideInit
         
     | 
| 
       49 
87 
     | 
    
         
             
                def setup(self, builder: Builder) -> None:
         
     | 
| 
       50 
     | 
    
         
            -
                     
     | 
| 
       51 
     | 
    
         
            -
                    self.base_rate = builder.lookup.build_table(
         
     | 
| 
       52 
     | 
    
         
            -
                        rate_data, key_columns=["sex"], parameter_columns=["age", "year"]
         
     | 
| 
       53 
     | 
    
         
            -
                    )
         
     | 
| 
      
 88 
     | 
    
         
            +
                    lookup_columns = get_lookup_columns([self.lookup_tables["transition_rate"]])
         
     | 
| 
       54 
89 
     | 
    
         
             
                    self.transition_rate = builder.value.register_rate_producer(
         
     | 
| 
       55 
     | 
    
         
            -
                         
     | 
| 
      
 90 
     | 
    
         
            +
                        self.transition_rate_pipeline_name,
         
     | 
| 
       56 
91 
     | 
    
         
             
                        source=self.compute_transition_rate,
         
     | 
| 
       57 
     | 
    
         
            -
                        requires_columns= 
     | 
| 
       58 
     | 
    
         
            -
                        requires_values=[f"{ 
     | 
| 
      
 92 
     | 
    
         
            +
                        requires_columns=lookup_columns + ["alive"],
         
     | 
| 
      
 93 
     | 
    
         
            +
                        requires_values=[f"{self.transition_rate_pipeline_name}.paf"],
         
     | 
| 
       59 
94 
     | 
    
         
             
                    )
         
     | 
| 
       60 
95 
     | 
    
         
             
                    paf = builder.lookup.build_table(0)
         
     | 
| 
       61 
96 
     | 
    
         
             
                    self.joint_paf = builder.value.register_value_producer(
         
     | 
| 
       62 
     | 
    
         
            -
                        f"{ 
     | 
| 
      
 97 
     | 
    
         
            +
                        f"{self.transition_rate_pipeline_name}.paf",
         
     | 
| 
       63 
98 
     | 
    
         
             
                        source=lambda index: [paf(index)],
         
     | 
| 
       64 
99 
     | 
    
         
             
                        preferred_combiner=list_combiner,
         
     | 
| 
       65 
100 
     | 
    
         
             
                        preferred_post_processor=union_post_processor,
         
     | 
| 
         @@ -71,27 +106,22 @@ class RateTransition(Transition): 
     | 
|
| 
       71 
106 
     | 
    
         
             
                # Setup methods #
         
     | 
| 
       72 
107 
     | 
    
         
             
                #################
         
     | 
| 
       73 
108 
     | 
    
         | 
| 
       74 
     | 
    
         
            -
                def  
     | 
| 
      
 109 
     | 
    
         
            +
                def load_transition_rate(self, builder: Builder) -> Union[float, pd.DataFrame]:
         
     | 
| 
       75 
110 
     | 
    
         
             
                    if "incidence_rate" in self._get_data_functions:
         
     | 
| 
       76 
111 
     | 
    
         
             
                        rate_data = self._get_data_functions["incidence_rate"](
         
     | 
| 
       77 
112 
     | 
    
         
             
                            builder, self.output_state.state_id
         
     | 
| 
       78 
113 
     | 
    
         
             
                        )
         
     | 
| 
       79 
     | 
    
         
            -
                        pipeline_name = f"{self.output_state.state_id}.incidence_rate"
         
     | 
| 
       80 
114 
     | 
    
         
             
                    elif "remission_rate" in self._get_data_functions:
         
     | 
| 
       81 
115 
     | 
    
         
             
                        rate_data = self._get_data_functions["remission_rate"](
         
     | 
| 
       82 
116 
     | 
    
         
             
                            builder, self.input_state.state_id
         
     | 
| 
       83 
117 
     | 
    
         
             
                        )
         
     | 
| 
       84 
     | 
    
         
            -
                        pipeline_name = f"{self.input_state.state_id}.remission_rate"
         
     | 
| 
       85 
118 
     | 
    
         
             
                    elif "transition_rate" in self._get_data_functions:
         
     | 
| 
       86 
119 
     | 
    
         
             
                        rate_data = self._get_data_functions["transition_rate"](
         
     | 
| 
       87 
120 
     | 
    
         
             
                            builder, self.input_state.state_id, self.output_state.state_id
         
     | 
| 
       88 
121 
     | 
    
         
             
                        )
         
     | 
| 
       89 
     | 
    
         
            -
                        pipeline_name = (
         
     | 
| 
       90 
     | 
    
         
            -
                            f"{self.input_state.state_id}_to_{self.output_state.state_id}.transition_rate"
         
     | 
| 
       91 
     | 
    
         
            -
                        )
         
     | 
| 
       92 
122 
     | 
    
         
             
                    else:
         
     | 
| 
       93 
123 
     | 
    
         
             
                        raise ValueError("No valid data functions supplied.")
         
     | 
| 
       94 
     | 
    
         
            -
                    return rate_data 
     | 
| 
      
 124 
     | 
    
         
            +
                    return rate_data
         
     | 
| 
       95 
125 
     | 
    
         | 
| 
       96 
126 
     | 
    
         
             
                ##################################
         
     | 
| 
       97 
127 
     | 
    
         
             
                # Pipeline sources and modifiers #
         
     | 
| 
         @@ -100,7 +130,7 @@ class RateTransition(Transition): 
     | 
|
| 
       100 
130 
     | 
    
         
             
                def compute_transition_rate(self, index: pd.Index) -> pd.Series:
         
     | 
| 
       101 
131 
     | 
    
         
             
                    transition_rate = pd.Series(0.0, index=index)
         
     | 
| 
       102 
132 
     | 
    
         
             
                    living = self.population_view.get(index, query='alive == "alive"').index
         
     | 
| 
       103 
     | 
    
         
            -
                    base_rates = self. 
     | 
| 
      
 133 
     | 
    
         
            +
                    base_rates = self.lookup_tables["transition_rate"](living)
         
     | 
| 
       104 
134 
     | 
    
         
             
                    joint_paf = self.joint_paf(living)
         
     | 
| 
       105 
135 
     | 
    
         
             
                    transition_rate.loc[living] = base_rates * (1 - joint_paf)
         
     | 
| 
       106 
136 
     | 
    
         
             
                    return transition_rate
         
     | 
| 
         @@ -114,6 +144,21 @@ class RateTransition(Transition): 
     | 
|
| 
       114 
144 
     | 
    
         | 
| 
       115 
145 
     | 
    
         | 
| 
       116 
146 
     | 
    
         
             
            class ProportionTransition(Transition):
         
     | 
| 
      
 147 
     | 
    
         
            +
             
     | 
| 
      
 148 
     | 
    
         
            +
                ##############
         
     | 
| 
      
 149 
     | 
    
         
            +
                # Properties #
         
     | 
| 
      
 150 
     | 
    
         
            +
                ##############
         
     | 
| 
      
 151 
     | 
    
         
            +
             
     | 
| 
      
 152 
     | 
    
         
            +
                @property
         
     | 
| 
      
 153 
     | 
    
         
            +
                def configuration_defaults(self) -> Dict[str, Any]:
         
     | 
| 
      
 154 
     | 
    
         
            +
                    return {
         
     | 
| 
      
 155 
     | 
    
         
            +
                        f"{self.name}": {
         
     | 
| 
      
 156 
     | 
    
         
            +
                            "data_sources": {
         
     | 
| 
      
 157 
     | 
    
         
            +
                                "proportion": "self::load_proportion",
         
     | 
| 
      
 158 
     | 
    
         
            +
                            },
         
     | 
| 
      
 159 
     | 
    
         
            +
                        },
         
     | 
| 
      
 160 
     | 
    
         
            +
                    }
         
     | 
| 
      
 161 
     | 
    
         
            +
             
     | 
| 
       117 
162 
     | 
    
         
             
                #####################
         
     | 
| 
       118 
163 
     | 
    
         
             
                # Lifecycle methods #
         
     | 
| 
       119 
164 
     | 
    
         
             
                #####################
         
     | 
| 
         @@ -132,16 +177,14 @@ class ProportionTransition(Transition): 
     | 
|
| 
       132 
177 
     | 
    
         
             
                        get_data_functions if get_data_functions is not None else {}
         
     | 
| 
       133 
178 
     | 
    
         
             
                    )
         
     | 
| 
       134 
179 
     | 
    
         | 
| 
       135 
     | 
    
         
            -
                 
     | 
| 
       136 
     | 
    
         
            -
                 
     | 
| 
       137 
     | 
    
         
            -
             
     | 
| 
       138 
     | 
    
         
            -
             
     | 
| 
       139 
     | 
    
         
            -
             
     | 
| 
      
 180 
     | 
    
         
            +
                #################
         
     | 
| 
      
 181 
     | 
    
         
            +
                # Setup methods #
         
     | 
| 
      
 182 
     | 
    
         
            +
                #################
         
     | 
| 
      
 183 
     | 
    
         
            +
             
     | 
| 
      
 184 
     | 
    
         
            +
                def load_proportion(self, builder: Builder) -> Union[float, pd.DataFrame]:
         
     | 
| 
      
 185 
     | 
    
         
            +
                    if "proportion" not in self._get_data_functions:
         
     | 
| 
       140 
186 
     | 
    
         
             
                        raise ValueError("Must supply a proportion function")
         
     | 
| 
       141 
     | 
    
         
            -
                    self. 
     | 
| 
       142 
     | 
    
         
            -
                    self.proportion = builder.lookup.build_table(
         
     | 
| 
       143 
     | 
    
         
            -
                        self._proportion_data, key_columns=["sex"], parameter_columns=["age", "year"]
         
     | 
| 
       144 
     | 
    
         
            -
                    )
         
     | 
| 
      
 187 
     | 
    
         
            +
                    return self._get_data_functions["proportion"](builder, self.output_state.state_id)
         
     | 
| 
       145 
188 
     | 
    
         | 
| 
       146 
189 
     | 
    
         
             
                def _probability(self, index):
         
     | 
| 
       147 
     | 
    
         
            -
                    return self.proportion(index)
         
     | 
| 
      
 190 
     | 
    
         
            +
                    return self.lookup_tables["proportion"](index)
         
     | 
| 
         @@ -8,12 +8,14 @@ Component Configuration Parsers in this module are specialized implementations o 
     | 
|
| 
       8 
8 
     | 
    
         
             
            that can parse configurations of components specific to the Vivarium Public
         
     | 
| 
       9 
9 
     | 
    
         
             
            Health package.
         
     | 
| 
       10 
10 
     | 
    
         
             
            """
         
     | 
| 
      
 11 
     | 
    
         
            +
             
     | 
| 
       11 
12 
     | 
    
         
             
            from importlib import import_module
         
     | 
| 
       12 
13 
     | 
    
         
             
            from typing import Any, Callable, Dict, List, Optional, Union
         
     | 
| 
       13 
14 
     | 
    
         | 
| 
       14 
15 
     | 
    
         
             
            import pandas as pd
         
     | 
| 
      
 16 
     | 
    
         
            +
            from layered_config_tree import LayeredConfigTree
         
     | 
| 
       15 
17 
     | 
    
         
             
            from pkg_resources import resource_filename
         
     | 
| 
       16 
     | 
    
         
            -
            from vivarium import Component 
     | 
| 
      
 18 
     | 
    
         
            +
            from vivarium import Component
         
     | 
| 
       17 
19 
     | 
    
         
             
            from vivarium.framework.components import ComponentConfigurationParser
         
     | 
| 
       18 
20 
     | 
    
         
             
            from vivarium.framework.components.parser import ParsingError
         
     | 
| 
       19 
21 
     | 
    
         
             
            from vivarium.framework.engine import Builder
         
     | 
| 
         @@ -84,7 +86,7 @@ class CausesConfigurationParser(ComponentConfigurationParser): 
     | 
|
| 
       84 
86 
     | 
    
         
             
                default value will be used. The default triggered value is 'NOT_TRIGGERED'.
         
     | 
| 
       85 
87 
     | 
    
         
             
                """
         
     | 
| 
       86 
88 
     | 
    
         | 
| 
       87 
     | 
    
         
            -
                def parse_component_config(self, component_config:  
     | 
| 
      
 89 
     | 
    
         
            +
                def parse_component_config(self, component_config: LayeredConfigTree) -> List[Component]:
         
     | 
| 
       88 
90 
     | 
    
         
             
                    """
         
     | 
| 
       89 
91 
     | 
    
         
             
                    Parses the component configuration and returns a list of components.
         
     | 
| 
       90 
92 
     | 
    
         | 
| 
         @@ -137,7 +139,7 @@ class CausesConfigurationParser(ComponentConfigurationParser): 
     | 
|
| 
       137 
139 
     | 
    
         
             
                    Parameters
         
     | 
| 
       138 
140 
     | 
    
         
             
                    ----------
         
     | 
| 
       139 
141 
     | 
    
         
             
                    component_config
         
     | 
| 
       140 
     | 
    
         
            -
                        A  
     | 
| 
      
 142 
     | 
    
         
            +
                        A LayeredConfigTree defining the components to initialize.
         
     | 
| 
       141 
143 
     | 
    
         | 
| 
       142 
144 
     | 
    
         
             
                    Returns
         
     | 
| 
       143 
145 
     | 
    
         
             
                    -------
         
     | 
| 
         @@ -158,7 +160,7 @@ class CausesConfigurationParser(ComponentConfigurationParser): 
     | 
|
| 
       158 
160 
     | 
    
         
             
                                source = f"{package}::{config_file}"
         
     | 
| 
       159 
161 
     | 
    
         
             
                                config_file = resource_filename(package, config_file)
         
     | 
| 
       160 
162 
     | 
    
         | 
| 
       161 
     | 
    
         
            -
                                external_config =  
     | 
| 
      
 163 
     | 
    
         
            +
                                external_config = LayeredConfigTree(config_file)
         
     | 
| 
       162 
164 
     | 
    
         
             
                                component_config.update(
         
     | 
| 
       163 
165 
     | 
    
         
             
                                    external_config, layer="model_override", source=source
         
     | 
| 
       164 
166 
     | 
    
         
             
                                )
         
     | 
| 
         @@ -185,7 +187,7 @@ class CausesConfigurationParser(ComponentConfigurationParser): 
     | 
|
| 
       185 
187 
     | 
    
         
             
                # Configuration methods #
         
     | 
| 
       186 
188 
     | 
    
         
             
                #########################
         
     | 
| 
       187 
189 
     | 
    
         | 
| 
       188 
     | 
    
         
            -
                def _add_default_config_layer(self, causes_config:  
     | 
| 
      
 190 
     | 
    
         
            +
                def _add_default_config_layer(self, causes_config: LayeredConfigTree) -> None:
         
     | 
| 
       189 
191 
     | 
    
         
             
                    """
         
     | 
| 
       190 
192 
     | 
    
         
             
                    Adds a default layer to the provided configuration that specifies
         
     | 
| 
       191 
193 
     | 
    
         
             
                    default values for the cause model configuration.
         
     | 
| 
         @@ -193,7 +195,7 @@ class CausesConfigurationParser(ComponentConfigurationParser): 
     | 
|
| 
       193 
195 
     | 
    
         
             
                    Parameters
         
     | 
| 
       194 
196 
     | 
    
         
             
                    ----------
         
     | 
| 
       195 
197 
     | 
    
         
             
                    causes_config
         
     | 
| 
       196 
     | 
    
         
            -
                        A  
     | 
| 
      
 198 
     | 
    
         
            +
                        A LayeredConfigTree defining the cause model configurations
         
     | 
| 
       197 
199 
     | 
    
         | 
| 
       198 
200 
     | 
    
         
             
                    Returns
         
     | 
| 
       199 
201 
     | 
    
         
             
                    -------
         
     | 
| 
         @@ -223,7 +225,9 @@ class CausesConfigurationParser(ComponentConfigurationParser): 
     | 
|
| 
       223 
225 
     | 
    
         
             
                # Cause model creation methods #
         
     | 
| 
       224 
226 
     | 
    
         
             
                ################################
         
     | 
| 
       225 
227 
     | 
    
         | 
| 
       226 
     | 
    
         
            -
                def _get_cause_model_components( 
     | 
| 
      
 228 
     | 
    
         
            +
                def _get_cause_model_components(
         
     | 
| 
      
 229 
     | 
    
         
            +
                    self, causes_config: LayeredConfigTree
         
     | 
| 
      
 230 
     | 
    
         
            +
                ) -> List[Component]:
         
     | 
| 
       227 
231 
     | 
    
         
             
                    """
         
     | 
| 
       228 
232 
     | 
    
         
             
                    Parses the cause model configuration and returns a list of
         
     | 
| 
       229 
233 
     | 
    
         
             
                    `DiseaseModel` components.
         
     | 
| 
         @@ -231,7 +235,7 @@ class CausesConfigurationParser(ComponentConfigurationParser): 
     | 
|
| 
       231 
235 
     | 
    
         
             
                    Parameters
         
     | 
| 
       232 
236 
     | 
    
         
             
                    ----------
         
     | 
| 
       233 
237 
     | 
    
         
             
                    causes_config
         
     | 
| 
       234 
     | 
    
         
            -
                        A  
     | 
| 
      
 238 
     | 
    
         
            +
                        A LayeredConfigTree defining the cause model components to initialize
         
     | 
| 
       235 
239 
     | 
    
         | 
| 
       236 
240 
     | 
    
         
             
                    Returns
         
     | 
| 
       237 
241 
     | 
    
         
             
                    -------
         
     | 
| 
         @@ -241,6 +245,11 @@ class CausesConfigurationParser(ComponentConfigurationParser): 
     | 
|
| 
       241 
245 
     | 
    
         
             
                    cause_models = []
         
     | 
| 
       242 
246 
     | 
    
         | 
| 
       243 
247 
     | 
    
         
             
                    for cause_name, cause_config in causes_config.items():
         
     | 
| 
      
 248 
     | 
    
         
            +
                        data_sources = None
         
     | 
| 
      
 249 
     | 
    
         
            +
                        if "data_sources" in cause_config:
         
     | 
| 
      
 250 
     | 
    
         
            +
                            data_sources_config = cause_config.data_sources
         
     | 
| 
      
 251 
     | 
    
         
            +
                            data_sources = self._get_data_sources(data_sources_config)
         
     | 
| 
      
 252 
     | 
    
         
            +
             
     | 
| 
       244 
253 
     | 
    
         
             
                        states: Dict[str, BaseDiseaseState] = {
         
     | 
| 
       245 
254 
     | 
    
         
             
                            state_name: self._get_state(state_name, state_config, cause_name)
         
     | 
| 
       246 
255 
     | 
    
         
             
                            for state_name, state_config in cause_config.states.items()
         
     | 
| 
         @@ -256,14 +265,17 @@ class CausesConfigurationParser(ComponentConfigurationParser): 
     | 
|
| 
       256 
265 
     | 
    
         
             
                        model_type = import_by_path(cause_config.model_type)
         
     | 
| 
       257 
266 
     | 
    
         
             
                        initial_state = states.get(cause_config.initial_state, None)
         
     | 
| 
       258 
267 
     | 
    
         
             
                        model = model_type(
         
     | 
| 
       259 
     | 
    
         
            -
                            cause_name, 
     | 
| 
      
 268 
     | 
    
         
            +
                            cause_name,
         
     | 
| 
      
 269 
     | 
    
         
            +
                            initial_state=initial_state,
         
     | 
| 
      
 270 
     | 
    
         
            +
                            states=list(states.values()),
         
     | 
| 
      
 271 
     | 
    
         
            +
                            get_data_functions=data_sources,
         
     | 
| 
       260 
272 
     | 
    
         
             
                        )
         
     | 
| 
       261 
273 
     | 
    
         
             
                        cause_models.append(model)
         
     | 
| 
       262 
274 
     | 
    
         | 
| 
       263 
275 
     | 
    
         
             
                    return cause_models
         
     | 
| 
       264 
276 
     | 
    
         | 
| 
       265 
277 
     | 
    
         
             
                def _get_state(
         
     | 
| 
       266 
     | 
    
         
            -
                    self, state_name: str, state_config:  
     | 
| 
      
 278 
     | 
    
         
            +
                    self, state_name: str, state_config: LayeredConfigTree, cause_name: str
         
     | 
| 
       267 
279 
     | 
    
         
             
                ) -> BaseDiseaseState:
         
     | 
| 
       268 
280 
     | 
    
         
             
                    """
         
     | 
| 
       269 
281 
     | 
    
         
             
                    Parses a state configuration and returns an initialized `BaseDiseaseState`
         
     | 
| 
         @@ -274,7 +286,7 @@ class CausesConfigurationParser(ComponentConfigurationParser): 
     | 
|
| 
       274 
286 
     | 
    
         
             
                    state_name
         
     | 
| 
       275 
287 
     | 
    
         
             
                        The name of the state to initialize
         
     | 
| 
       276 
288 
     | 
    
         
             
                    state_config
         
     | 
| 
       277 
     | 
    
         
            -
                        A  
     | 
| 
      
 289 
     | 
    
         
            +
                        A LayeredConfigTree defining the state to initialize
         
     | 
| 
       278 
290 
     | 
    
         
             
                    cause_name
         
     | 
| 
       279 
291 
     | 
    
         
             
                        The name of the cause to which the state belongs
         
     | 
| 
       280 
292 
     | 
    
         | 
| 
         @@ -296,10 +308,7 @@ class CausesConfigurationParser(ComponentConfigurationParser): 
     | 
|
| 
       296 
308 
     | 
    
         
             
                        state_kwargs["cleanup_function"] = lambda *x: x
         
     | 
| 
       297 
309 
     | 
    
         
             
                    if "data_sources" in state_config:
         
     | 
| 
       298 
310 
     | 
    
         
             
                        data_sources_config = state_config.data_sources
         
     | 
| 
       299 
     | 
    
         
            -
                        state_kwargs["get_data_functions"] =  
     | 
| 
       300 
     | 
    
         
            -
                            name: self._get_data_source(name, data_sources_config[name])
         
     | 
| 
       301 
     | 
    
         
            -
                            for name in data_sources_config.keys()
         
     | 
| 
       302 
     | 
    
         
            -
                        }
         
     | 
| 
      
 311 
     | 
    
         
            +
                        state_kwargs["get_data_functions"] = self._get_data_sources(data_sources_config)
         
     | 
| 
       303 
312 
     | 
    
         | 
| 
       304 
313 
     | 
    
         
             
                    if state_config.state_type is not None:
         
     | 
| 
       305 
314 
     | 
    
         
             
                        state_type = import_by_path(state_config.state_type)
         
     | 
| 
         @@ -319,7 +328,7 @@ class CausesConfigurationParser(ComponentConfigurationParser): 
     | 
|
| 
       319 
328 
     | 
    
         
             
                    self,
         
     | 
| 
       320 
329 
     | 
    
         
             
                    source_state: BaseDiseaseState,
         
     | 
| 
       321 
330 
     | 
    
         
             
                    sink_state: BaseDiseaseState,
         
     | 
| 
       322 
     | 
    
         
            -
                    transition_config:  
     | 
| 
      
 331 
     | 
    
         
            +
                    transition_config: LayeredConfigTree,
         
     | 
| 
       323 
332 
     | 
    
         
             
                ) -> None:
         
     | 
| 
       324 
333 
     | 
    
         
             
                    """
         
     | 
| 
       325 
334 
     | 
    
         
             
                    Adds a transition between two states.
         
     | 
| 
         @@ -331,7 +340,7 @@ class CausesConfigurationParser(ComponentConfigurationParser): 
     | 
|
| 
       331 
340 
     | 
    
         
             
                    sink_state
         
     | 
| 
       332 
341 
     | 
    
         
             
                        The state the transition ends at
         
     | 
| 
       333 
342 
     | 
    
         
             
                    transition_config
         
     | 
| 
       334 
     | 
    
         
            -
                        A ` 
     | 
| 
      
 343 
     | 
    
         
            +
                        A `LayeredConfigTree` defining the transition to add
         
     | 
| 
       335 
344 
     | 
    
         | 
| 
       336 
345 
     | 
    
         
             
                    Returns
         
     | 
| 
       337 
346 
     | 
    
         
             
                    -------
         
     | 
| 
         @@ -340,10 +349,7 @@ class CausesConfigurationParser(ComponentConfigurationParser): 
     | 
|
| 
       340 
349 
     | 
    
         
             
                    triggered = Trigger[transition_config.triggered]
         
     | 
| 
       341 
350 
     | 
    
         
             
                    if "data_sources" in transition_config:
         
     | 
| 
       342 
351 
     | 
    
         
             
                        data_sources_config = transition_config.data_sources
         
     | 
| 
       343 
     | 
    
         
            -
                        data_sources =  
     | 
| 
       344 
     | 
    
         
            -
                            name: self._get_data_source(name, data_sources_config[name])
         
     | 
| 
       345 
     | 
    
         
            -
                            for name in data_sources_config.keys()
         
     | 
| 
       346 
     | 
    
         
            -
                        }
         
     | 
| 
      
 352 
     | 
    
         
            +
                        data_sources = self._get_data_sources(data_sources_config)
         
     | 
| 
       347 
353 
     | 
    
         
             
                    else:
         
     | 
| 
       348 
354 
     | 
    
         
             
                        data_sources = None
         
     | 
| 
       349 
355 
     | 
    
         | 
| 
         @@ -363,6 +369,25 @@ class CausesConfigurationParser(ComponentConfigurationParser): 
     | 
|
| 
       363 
369 
     | 
    
         
             
                            f" provided for transition '{transition_config}'."
         
     | 
| 
       364 
370 
     | 
    
         
             
                        )
         
     | 
| 
       365 
371 
     | 
    
         | 
| 
      
 372 
     | 
    
         
            +
                def _get_data_sources(
         
     | 
| 
      
 373 
     | 
    
         
            +
                    self, config: LayeredConfigTree
         
     | 
| 
      
 374 
     | 
    
         
            +
                ) -> Dict[str, Callable[[Builder, Any], Any]]:
         
     | 
| 
      
 375 
     | 
    
         
            +
                    """
         
     | 
| 
      
 376 
     | 
    
         
            +
                    Parses a data sources configuration and returns a dictionary of data
         
     | 
| 
      
 377 
     | 
    
         
            +
                    sources.
         
     | 
| 
      
 378 
     | 
    
         
            +
             
     | 
| 
      
 379 
     | 
    
         
            +
                    Parameters
         
     | 
| 
      
 380 
     | 
    
         
            +
                    ----------
         
     | 
| 
      
 381 
     | 
    
         
            +
                    config
         
     | 
| 
      
 382 
     | 
    
         
            +
                        A LayeredConfigTree defining the data sources to initialize
         
     | 
| 
      
 383 
     | 
    
         
            +
             
     | 
| 
      
 384 
     | 
    
         
            +
                    Returns
         
     | 
| 
      
 385 
     | 
    
         
            +
                    -------
         
     | 
| 
      
 386 
     | 
    
         
            +
                    Dict[str, Callable[[Builder, Any], Any]]
         
     | 
| 
      
 387 
     | 
    
         
            +
                        A dictionary of data source getters
         
     | 
| 
      
 388 
     | 
    
         
            +
                    """
         
     | 
| 
      
 389 
     | 
    
         
            +
                    return {name: self._get_data_source(name, config[name]) for name in config.keys()}
         
     | 
| 
      
 390 
     | 
    
         
            +
             
     | 
| 
       366 
391 
     | 
    
         
             
                @staticmethod
         
     | 
| 
       367 
392 
     | 
    
         
             
                def _get_data_source(
         
     | 
| 
       368 
393 
     | 
    
         
             
                    name: str, source: Union[str, float]
         
     | 
| 
         @@ -408,7 +433,7 @@ class CausesConfigurationParser(ComponentConfigurationParser): 
     | 
|
| 
       408 
433 
     | 
    
         
             
                # Validation methods #
         
     | 
| 
       409 
434 
     | 
    
         
             
                ######################
         
     | 
| 
       410 
435 
     | 
    
         | 
| 
       411 
     | 
    
         
            -
                _CAUSE_KEYS = {"model_type", "initial_state", "states", "transitions"}
         
     | 
| 
      
 436 
     | 
    
         
            +
                _CAUSE_KEYS = {"model_type", "initial_state", "states", "transitions", "data_sources"}
         
     | 
| 
       412 
437 
     | 
    
         
             
                _STATE_KEYS = {
         
     | 
| 
       413 
438 
     | 
    
         
             
                    "state_type",
         
     | 
| 
       414 
439 
     | 
    
         
             
                    "cause_type",
         
     | 
| 
         @@ -420,6 +445,7 @@ class CausesConfigurationParser(ComponentConfigurationParser): 
     | 
|
| 
       420 
445 
     | 
    
         
             
                }
         
     | 
| 
       421 
446 
     | 
    
         | 
| 
       422 
447 
     | 
    
         
             
                _DATA_SOURCE_KEYS = {
         
     | 
| 
      
 448 
     | 
    
         
            +
                    "cause": {"cause_specific_mortality_rate"},
         
     | 
| 
       423 
449 
     | 
    
         
             
                    "state": {
         
     | 
| 
       424 
450 
     | 
    
         
             
                        "prevalence",
         
     | 
| 
       425 
451 
     | 
    
         
             
                        "birth_prevalence",
         
     | 
| 
         @@ -438,14 +464,14 @@ class CausesConfigurationParser(ComponentConfigurationParser): 
     | 
|
| 
       438 
464 
     | 
    
         
             
                _TRANSITION_TYPE_KEYS = {"rate", "proportion", "dwell_time"}
         
     | 
| 
       439 
465 
     | 
    
         | 
| 
       440 
466 
     | 
    
         
             
                @staticmethod
         
     | 
| 
       441 
     | 
    
         
            -
                def _validate_external_configuration(external_configuration:  
     | 
| 
      
 467 
     | 
    
         
            +
                def _validate_external_configuration(external_configuration: LayeredConfigTree) -> None:
         
     | 
| 
       442 
468 
     | 
    
         
             
                    """
         
     | 
| 
       443 
469 
     | 
    
         
             
                    Validates the external configuration.
         
     | 
| 
       444 
470 
     | 
    
         | 
| 
       445 
471 
     | 
    
         
             
                    Parameters
         
     | 
| 
       446 
472 
     | 
    
         
             
                    ----------
         
     | 
| 
       447 
473 
     | 
    
         
             
                    external_configuration
         
     | 
| 
       448 
     | 
    
         
            -
                        A  
     | 
| 
      
 474 
     | 
    
         
            +
                        A LayeredConfigTree defining the external configuration
         
     | 
| 
       449 
475 
     | 
    
         | 
| 
       450 
476 
     | 
    
         
             
                    Returns
         
     | 
| 
       451 
477 
     | 
    
         
             
                    -------
         
     | 
| 
         @@ -477,14 +503,14 @@ class CausesConfigurationParser(ComponentConfigurationParser): 
     | 
|
| 
       477 
503 
     | 
    
         
             
                    if error_messages:
         
     | 
| 
       478 
504 
     | 
    
         
             
                        raise CausesParsingErrors(error_messages)
         
     | 
| 
       479 
505 
     | 
    
         | 
| 
       480 
     | 
    
         
            -
                def _validate_causes_config(self, causes_config:  
     | 
| 
      
 506 
     | 
    
         
            +
                def _validate_causes_config(self, causes_config: LayeredConfigTree) -> None:
         
     | 
| 
       481 
507 
     | 
    
         
             
                    """
         
     | 
| 
       482 
508 
     | 
    
         
             
                    Validates the cause model configuration.
         
     | 
| 
       483 
509 
     | 
    
         | 
| 
       484 
510 
     | 
    
         
             
                    Parameters
         
     | 
| 
       485 
511 
     | 
    
         
             
                    ----------
         
     | 
| 
       486 
512 
     | 
    
         
             
                    causes_config
         
     | 
| 
       487 
     | 
    
         
            -
                        A  
     | 
| 
      
 513 
     | 
    
         
            +
                        A LayeredConfigTree defining the cause model configurations
         
     | 
| 
       488 
514 
     | 
    
         | 
| 
       489 
515 
     | 
    
         
             
                    Returns
         
     | 
| 
       490 
516 
     | 
    
         
             
                    -------
         
     | 
| 
         @@ -512,7 +538,7 @@ class CausesConfigurationParser(ComponentConfigurationParser): 
     | 
|
| 
       512 
538 
     | 
    
         
             
                    cause_name
         
     | 
| 
       513 
539 
     | 
    
         
             
                        The name of the cause to validate
         
     | 
| 
       514 
540 
     | 
    
         
             
                    cause_config
         
     | 
| 
       515 
     | 
    
         
            -
                        A  
     | 
| 
      
 541 
     | 
    
         
            +
                        A LayeredConfigTree defining the cause to validate
         
     | 
| 
       516 
542 
     | 
    
         | 
| 
       517 
543 
     | 
    
         
             
                    Returns
         
     | 
| 
       518 
544 
     | 
    
         
             
                    -------
         
     | 
| 
         @@ -572,6 +598,10 @@ class CausesConfigurationParser(ComponentConfigurationParser): 
     | 
|
| 
       572 
598 
     | 
    
         
             
                                cause_name, transition_name, transition_config, states_config
         
     | 
| 
       573 
599 
     | 
    
         
             
                            )
         
     | 
| 
       574 
600 
     | 
    
         | 
| 
      
 601 
     | 
    
         
            +
                    error_messages += self._validate_data_sources(
         
     | 
| 
      
 602 
     | 
    
         
            +
                        cause_config, cause_name, "cause", cause_name
         
     | 
| 
      
 603 
     | 
    
         
            +
                    )
         
     | 
| 
      
 604 
     | 
    
         
            +
             
     | 
| 
       575 
605 
     | 
    
         
             
                    return error_messages
         
     | 
| 
       576 
606 
     | 
    
         | 
| 
       577 
607 
     | 
    
         
             
                def _validate_state(
         
     | 
| 
         @@ -587,7 +617,7 @@ class CausesConfigurationParser(ComponentConfigurationParser): 
     | 
|
| 
       587 
617 
     | 
    
         
             
                    state_name
         
     | 
| 
       588 
618 
     | 
    
         
             
                        The name of the state to validate
         
     | 
| 
       589 
619 
     | 
    
         
             
                    state_config
         
     | 
| 
       590 
     | 
    
         
            -
                        A  
     | 
| 
      
 620 
     | 
    
         
            +
                        A LayeredConfigTree defining the state to validate
         
     | 
| 
       591 
621 
     | 
    
         | 
| 
       592 
622 
     | 
    
         
             
                    Returns
         
     | 
| 
       593 
623 
     | 
    
         
             
                    -------
         
     | 
| 
         @@ -682,9 +712,9 @@ class CausesConfigurationParser(ComponentConfigurationParser): 
     | 
|
| 
       682 
712 
     | 
    
         
             
                    transition_name
         
     | 
| 
       683 
713 
     | 
    
         
             
                        The name of the transition to validate
         
     | 
| 
       684 
714 
     | 
    
         
             
                    transition_config
         
     | 
| 
       685 
     | 
    
         
            -
                        A  
     | 
| 
      
 715 
     | 
    
         
            +
                        A LayeredConfigTree defining the transition to validate
         
     | 
| 
       686 
716 
     | 
    
         
             
                    states_config
         
     | 
| 
       687 
     | 
    
         
            -
                        A  
     | 
| 
      
 717 
     | 
    
         
            +
                        A LayeredConfigTree defining the states for the cause
         
     | 
| 
       688 
718 
     | 
    
         | 
| 
       689 
719 
     | 
    
         
             
                    Returns
         
     | 
| 
       690 
720 
     | 
    
         
             
                    -------
         
     | 
| 
         @@ -824,7 +854,7 @@ class CausesConfigurationParser(ComponentConfigurationParser): 
     | 
|
| 
       824 
854 
     | 
    
         
             
                    Parameters
         
     | 
| 
       825 
855 
     | 
    
         
             
                    ----------
         
     | 
| 
       826 
856 
     | 
    
         
             
                    config
         
     | 
| 
       827 
     | 
    
         
            -
                        A  
     | 
| 
      
 857 
     | 
    
         
            +
                        A LayeredConfigTree defining the configuration to validate
         
     | 
| 
       828 
858 
     | 
    
         
             
                    cause_name
         
     | 
| 
       829 
859 
     | 
    
         
             
                        The name of the cause to which the configuration belongs
         
     | 
| 
       830 
860 
     | 
    
         
             
                    config_type
         
     | 
| 
         @@ -6,6 +6,7 @@ Fertility Models 
     | 
|
| 
       6 
6 
     | 
    
         
             
            This module contains several different models of fertility.
         
     | 
| 
       7 
7 
     | 
    
         | 
| 
       8 
8 
     | 
    
         
             
            """
         
     | 
| 
      
 9 
     | 
    
         
            +
             
     | 
| 
       9 
10 
     | 
    
         
             
            from typing import Dict, List, Optional
         
     | 
| 
       10 
11 
     | 
    
         | 
| 
       11 
12 
     | 
    
         
             
            import numpy as np
         
     | 
| 
         @@ -16,9 +17,7 @@ from vivarium.framework.event import Event 
     | 
|
| 
       16 
17 
     | 
    
         
             
            from vivarium.framework.population import SimulantData
         
     | 
| 
       17 
18 
     | 
    
         | 
| 
       18 
19 
     | 
    
         
             
            from vivarium_public_health import utilities
         
     | 
| 
       19 
     | 
    
         
            -
            from vivarium_public_health.population.data_transformations import  
     | 
| 
       20 
     | 
    
         
            -
                get_live_births_per_year,
         
     | 
| 
       21 
     | 
    
         
            -
            )
         
     | 
| 
      
 20 
     | 
    
         
            +
            from vivarium_public_health.population.data_transformations import get_live_births_per_year
         
     | 
| 
       22 
21 
     | 
    
         | 
| 
       23 
22 
     | 
    
         
             
            # TODO: Incorporate better data into gestational model (probably as a separate component)
         
     | 
| 
       24 
23 
     | 
    
         
             
            PREGNANCY_DURATION = pd.Timedelta(days=9 * utilities.DAYS_PER_MONTH)
         
     | 
| 
         @@ -7,13 +7,14 @@ This module contains tools for sampling and assigning core demographic 
     | 
|
| 
       7 
7 
     | 
    
         
             
            characteristics to simulants.
         
     | 
| 
       8 
8 
     | 
    
         | 
| 
       9 
9 
     | 
    
         
             
            """
         
     | 
| 
      
 10 
     | 
    
         
            +
             
     | 
| 
       10 
11 
     | 
    
         
             
            from typing import Callable, Dict, Iterable, List
         
     | 
| 
       11 
12 
     | 
    
         | 
| 
       12 
13 
     | 
    
         
             
            import numpy as np
         
     | 
| 
       13 
14 
     | 
    
         
             
            import pandas as pd
         
     | 
| 
      
 15 
     | 
    
         
            +
            from layered_config_tree.exceptions import ConfigurationKeyError
         
     | 
| 
       14 
16 
     | 
    
         
             
            from loguru import logger
         
     | 
| 
       15 
17 
     | 
    
         
             
            from vivarium import Component
         
     | 
| 
       16 
     | 
    
         
            -
            from vivarium.config_tree import ConfigurationKeyError
         
     | 
| 
       17 
18 
     | 
    
         
             
            from vivarium.framework.engine import Builder
         
     | 
| 
       18 
19 
     | 
    
         
             
            from vivarium.framework.event import Event
         
     | 
| 
       19 
20 
     | 
    
         
             
            from vivarium.framework.population import SimulantData
         
     |