vivarium-public-health 2.3.2__py3-none-any.whl → 3.0.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (48) hide show
  1. vivarium_public_health/_version.py +1 -1
  2. vivarium_public_health/disease/model.py +23 -21
  3. vivarium_public_health/disease/models.py +1 -0
  4. vivarium_public_health/disease/special_disease.py +40 -41
  5. vivarium_public_health/disease/state.py +42 -125
  6. vivarium_public_health/disease/transition.py +70 -27
  7. vivarium_public_health/mslt/delay.py +1 -0
  8. vivarium_public_health/mslt/disease.py +1 -0
  9. vivarium_public_health/mslt/intervention.py +1 -0
  10. vivarium_public_health/mslt/magic_wand_components.py +1 -0
  11. vivarium_public_health/mslt/observer.py +1 -0
  12. vivarium_public_health/mslt/population.py +1 -0
  13. vivarium_public_health/plugins/parser.py +61 -31
  14. vivarium_public_health/population/add_new_birth_cohorts.py +2 -3
  15. vivarium_public_health/population/base_population.py +2 -1
  16. vivarium_public_health/population/mortality.py +83 -80
  17. vivarium_public_health/{metrics → results}/__init__.py +2 -0
  18. vivarium_public_health/results/columns.py +22 -0
  19. vivarium_public_health/results/disability.py +187 -0
  20. vivarium_public_health/results/disease.py +222 -0
  21. vivarium_public_health/results/mortality.py +186 -0
  22. vivarium_public_health/results/observer.py +78 -0
  23. vivarium_public_health/results/risk.py +138 -0
  24. vivarium_public_health/results/simple_cause.py +18 -0
  25. vivarium_public_health/{metrics → results}/stratification.py +10 -8
  26. vivarium_public_health/risks/__init__.py +1 -2
  27. vivarium_public_health/risks/base_risk.py +134 -29
  28. vivarium_public_health/risks/data_transformations.py +65 -326
  29. vivarium_public_health/risks/distributions.py +315 -145
  30. vivarium_public_health/risks/effect.py +376 -75
  31. vivarium_public_health/risks/implementations/low_birth_weight_and_short_gestation.py +61 -89
  32. vivarium_public_health/treatment/magic_wand.py +1 -0
  33. vivarium_public_health/treatment/scale_up.py +1 -0
  34. vivarium_public_health/treatment/therapeutic_inertia.py +1 -0
  35. vivarium_public_health/utilities.py +17 -2
  36. {vivarium_public_health-2.3.2.dist-info → vivarium_public_health-3.0.0.dist-info}/METADATA +13 -3
  37. vivarium_public_health-3.0.0.dist-info/RECORD +49 -0
  38. {vivarium_public_health-2.3.2.dist-info → vivarium_public_health-3.0.0.dist-info}/WHEEL +1 -1
  39. vivarium_public_health/metrics/disability.py +0 -118
  40. vivarium_public_health/metrics/disease.py +0 -136
  41. vivarium_public_health/metrics/mortality.py +0 -144
  42. vivarium_public_health/metrics/risk.py +0 -110
  43. vivarium_public_health/testing/__init__.py +0 -0
  44. vivarium_public_health/testing/mock_artifact.py +0 -145
  45. vivarium_public_health/testing/utils.py +0 -71
  46. vivarium_public_health-2.3.2.dist-info/RECORD +0 -49
  47. {vivarium_public_health-2.3.2.dist-info → vivarium_public_health-3.0.0.dist-info}/LICENSE.txt +0 -0
  48. {vivarium_public_health-2.3.2.dist-info → vivarium_public_health-3.0.0.dist-info}/top_level.txt +0 -0
@@ -44,25 +44,81 @@ on mortality is calculated, by subtracting the raw unmodeled csmr before adding
44
44
  back the modified unmodeled csmr.
45
45
 
46
46
  """
47
- from typing import List, Optional, Union
47
+
48
+ from typing import Any, Dict, List, Optional, Union
48
49
 
49
50
  import pandas as pd
50
51
  from vivarium import Component
51
52
  from vivarium.framework.engine import Builder
52
53
  from vivarium.framework.event import Event
53
- from vivarium.framework.lookup import LookupTable
54
54
  from vivarium.framework.population import SimulantData
55
55
  from vivarium.framework.randomness import RandomnessStream
56
56
  from vivarium.framework.values import Pipeline, list_combiner, union_post_processor
57
57
 
58
+ from vivarium_public_health.utilities import get_lookup_columns
59
+
58
60
 
59
61
  class Mortality(Component):
60
- CONFIGURATION_DEFAULTS = {"unmodeled_causes": []}
62
+ """
63
+ This is the mortality component which models sources of mortality for a model.
64
+ The component models all cause mortality and allows for disease models to contribute
65
+ cause specific mortality. Data used by this class should be supplied in the artifact
66
+ and is configurable in the configuration to build lookup tables. For instance, let's
67
+ say we want to use sex and hair color to build a lookup table for all cause mortality.
68
+
69
+ .. code-block:: yaml
70
+
71
+ configuration:
72
+ mortality:
73
+ all_cause_mortality_rate:
74
+ categorical_columns: ["sex", "hair_color"]
75
+
76
+ Similarly, we can do the same thing for unmodeled causes. Here is an example:
77
+
78
+ .. code-block:: yaml
79
+
80
+ configuration:
81
+ mortality:
82
+ unmodeled_cause_specific_mortality_rate:
83
+ unmodeled_causes: ["maternal_disorders", "maternal_hemorrhage"]
84
+ categorical_columns: ["sex", "hair_color"]
85
+
86
+ Or if we wanted to make the data a scalar value for all cause mortality rate we could
87
+ configure that as well.
88
+
89
+ .. code-block:: yaml
90
+
91
+ configuration:
92
+ mortality:
93
+ all_cause_mortality_rate:
94
+ value: 0.01
95
+
96
+ """
61
97
 
62
98
  ##############
63
99
  # Properties #
64
100
  ##############
65
101
 
102
+ @property
103
+ def configuration_defaults(self) -> Dict[str, Any]:
104
+ return {
105
+ "mortality": {
106
+ "data_sources": {
107
+ "all_cause_mortality_rate": "cause.all_causes.cause_specific_mortality_rate",
108
+ "unmodeled_cause_specific_mortality_rate": "self::load_unmodeled_csmr",
109
+ "life_expectancy": "population.theoretical_minimum_risk_life_expectancy",
110
+ },
111
+ "unmodeled_causes": [],
112
+ },
113
+ }
114
+
115
+ @property
116
+ def standard_lookup_tables(self) -> List[str]:
117
+ return [
118
+ "all_cause_mortality_rate",
119
+ "life_expectancy",
120
+ ]
121
+
66
122
  @property
67
123
  def columns_created(self) -> List[str]:
68
124
  return [self.cause_of_death_column_name, self.years_of_life_lost_column_name]
@@ -97,14 +153,10 @@ class Mortality(Component):
97
153
  self.clock = builder.time.clock()
98
154
 
99
155
  self.cause_specific_mortality_rate = self.get_cause_specific_mortality_rate(builder)
100
- self.mortality_rate = self.get_mortality_rate(builder)
101
156
 
102
- self.all_cause_mortality_rate = self.get_all_cause_mortality_rate(builder)
103
- self.life_expectancy = self.get_life_expectancy(builder)
104
-
105
- self._raw_unmodeled_csmr = self.get_raw_unmodeled_csmr(builder)
106
157
  self.unmodeled_csmr = self.get_unmodeled_csmr(builder)
107
158
  self.unmodeled_csmr_paf = self.get_unmodeled_csmr_paf(builder)
159
+ self.mortality_rate = self.get_mortality_rate(builder)
108
160
 
109
161
  #################
110
162
  # Setup methods #
@@ -120,90 +172,37 @@ class Mortality(Component):
120
172
  )
121
173
 
122
174
  def get_mortality_rate(self, builder: Builder) -> Pipeline:
175
+ required_columns = get_lookup_columns(
176
+ [
177
+ self.lookup_tables["all_cause_mortality_rate"],
178
+ self.lookup_tables["unmodeled_cause_specific_mortality_rate"],
179
+ ],
180
+ )
123
181
  return builder.value.register_rate_producer(
124
182
  self.mortality_rate_pipeline_name,
125
183
  source=self.calculate_mortality_rate,
126
- requires_columns=["age", "sex"],
127
- )
128
-
129
- # noinspection PyMethodMayBeStatic
130
- def get_all_cause_mortality_rate(self, builder: Builder) -> Union[LookupTable, Pipeline]:
131
- """
132
- Load all cause mortality rate data and build a lookup table or pipeline.
133
-
134
- Parameters
135
- ----------
136
- builder
137
- Interface to access simulation managers.
138
-
139
- Returns
140
- -------
141
- Union[LookupTable, Pipeline]
142
- A lookup table or pipeline returning the all cause mortality rate.
143
- """
144
- acmr_data = builder.data.load("cause.all_causes.cause_specific_mortality_rate")
145
- return builder.lookup.build_table(
146
- acmr_data, key_columns=["sex"], parameter_columns=["age", "year"]
184
+ requires_columns=required_columns,
147
185
  )
148
186
 
149
- # noinspection PyMethodMayBeStatic
150
- def get_life_expectancy(self, builder: Builder) -> Union[LookupTable, Pipeline]:
151
- """
152
- Load life expectancy data and build a lookup table or pipeline.
153
-
154
- Parameters
155
- ----------
156
- builder
157
- Interface to access simulation managers.
158
-
159
- Returns
160
- -------
161
- Union[LookupTable, Pipeline]
162
- A lookup table or pipeline returning the life expectancy.
163
- """
164
- life_expectancy_data = builder.data.load(
165
- "population.theoretical_minimum_risk_life_expectancy"
166
- )
167
- return builder.lookup.build_table(life_expectancy_data, parameter_columns=["age"])
168
-
169
- # noinspection PyMethodMayBeStatic
170
- def get_raw_unmodeled_csmr(self, builder: Builder) -> Union[LookupTable, Pipeline]:
171
- """
172
- Load unmodeled cause specific mortality rate data and build a lookup
173
- table or pipeline.
174
-
175
- Parameters
176
- ----------
177
- builder
178
- Interface to access simulation managers.
179
-
180
- Returns
181
- -------
182
- Union[LookupTable, Pipeline]
183
- A lookup table or pipeline returning the unmodeled csmr.
184
- """
185
- unmodeled_causes = builder.configuration.unmodeled_causes
187
+ def load_unmodeled_csmr(self, builder: Builder) -> Union[float, pd.DataFrame]:
188
+ # todo validate that all data have the same columns
186
189
  raw_csmr = 0.0
187
- for idx, cause in enumerate(unmodeled_causes):
190
+ for idx, cause in enumerate(builder.configuration[self.name].unmodeled_causes):
188
191
  csmr = f"cause.{cause}.cause_specific_mortality_rate"
189
192
  if 0 == idx:
190
193
  raw_csmr = builder.data.load(csmr)
191
194
  else:
192
195
  raw_csmr.loc[:, "value"] += builder.data.load(csmr).value
193
-
194
- additional_parameters = (
195
- {"key_columns": ["sex"], "parameter_columns": ["age", "year"]}
196
- if unmodeled_causes
197
- else {}
198
- )
199
-
200
- return builder.lookup.build_table(raw_csmr, **additional_parameters)
196
+ return raw_csmr
201
197
 
202
198
  def get_unmodeled_csmr(self, builder: Builder) -> Pipeline:
199
+ required_columns = get_lookup_columns(
200
+ [self.lookup_tables["unmodeled_cause_specific_mortality_rate"]]
201
+ )
203
202
  return builder.value.register_value_producer(
204
203
  self.unmodeled_csmr_pipeline_name,
205
204
  source=self.get_unmodeled_csmr_source,
206
- requires_columns=["age", "sex"],
205
+ requires_columns=required_columns,
207
206
  )
208
207
 
209
208
  def get_unmodeled_csmr_paf(self, builder: Builder) -> Pipeline:
@@ -248,7 +247,9 @@ class Mortality(Component):
248
247
  )
249
248
  pop.loc[deaths, "alive"] = "dead"
250
249
  pop.loc[deaths, "exit_time"] = event.time
251
- pop.loc[deaths, "years_of_life_lost"] = self.life_expectancy(deaths)
250
+ pop.loc[deaths, "years_of_life_lost"] = self.lookup_tables["life_expectancy"](
251
+ deaths
252
+ )
252
253
  pop.loc[deaths, "cause_of_death"] = cause_of_death
253
254
  self.population_view.update(pop)
254
255
 
@@ -257,9 +258,11 @@ class Mortality(Component):
257
258
  ##################################
258
259
 
259
260
  def calculate_mortality_rate(self, index: pd.Index) -> pd.DataFrame:
260
- acmr = self.all_cause_mortality_rate(index)
261
+ acmr = self.lookup_tables["all_cause_mortality_rate"](index)
261
262
  modeled_csmr = self.cause_specific_mortality_rate(index)
262
- unmodeled_csmr_raw = self._raw_unmodeled_csmr(index)
263
+ unmodeled_csmr_raw = self.lookup_tables["unmodeled_cause_specific_mortality_rate"](
264
+ index
265
+ )
263
266
  unmodeled_csmr = self.unmodeled_csmr(index)
264
267
  cause_deleted_mortality_rate = (
265
268
  acmr - modeled_csmr - unmodeled_csmr_raw + unmodeled_csmr
@@ -267,6 +270,6 @@ class Mortality(Component):
267
270
  return pd.DataFrame({"other_causes": cause_deleted_mortality_rate})
268
271
 
269
272
  def get_unmodeled_csmr_source(self, index: pd.Index) -> pd.Series:
270
- raw_csmr = self._raw_unmodeled_csmr(index)
273
+ raw_csmr = self.lookup_tables["unmodeled_cause_specific_mortality_rate"](index)
271
274
  paf = self.unmodeled_csmr_paf(index)
272
275
  return raw_csmr * (1 - paf)
@@ -1,5 +1,7 @@
1
+ from .columns import COLUMNS
1
2
  from .disability import DisabilityObserver
2
3
  from .disease import DiseaseObserver
3
4
  from .mortality import MortalityObserver
5
+ from .observer import PublicHealthObserver
4
6
  from .risk import CategoricalRiskObserver
5
7
  from .stratification import ResultsStratifier
@@ -0,0 +1,22 @@
1
+ from typing import NamedTuple
2
+
3
+ from vivarium.framework.results import VALUE_COLUMN
4
+
5
+
6
+ class __Columns(NamedTuple):
7
+ """column names"""
8
+
9
+ VALUE: str = VALUE_COLUMN
10
+ MEASURE: str = "measure"
11
+ TRANSITION: str = "transition"
12
+ STATE: str = "state"
13
+ ENTITY_TYPE: str = "entity_type"
14
+ SUB_ENTITY: str = "sub_entity"
15
+ ENTITY: str = "entity"
16
+
17
+ @property
18
+ def name(self) -> str:
19
+ return "columns"
20
+
21
+
22
+ COLUMNS = __Columns()
@@ -0,0 +1,187 @@
1
+ """
2
+ ===================
3
+ Disability Observer
4
+ ===================
5
+
6
+ This module contains tools for observing years lived with disability (YLDs)
7
+ in the simulation.
8
+
9
+ """
10
+
11
+ from typing import Any, List, Union
12
+
13
+ import pandas as pd
14
+ from layered_config_tree import LayeredConfigTree
15
+ from loguru import logger
16
+ from pandas.api.types import CategoricalDtype
17
+ from vivarium.framework.engine import Builder
18
+ from vivarium.framework.values import Pipeline, list_combiner, union_post_processor
19
+
20
+ from vivarium_public_health.disease import DiseaseState, RiskAttributableDisease
21
+ from vivarium_public_health.results.columns import COLUMNS
22
+ from vivarium_public_health.results.observer import PublicHealthObserver
23
+ from vivarium_public_health.results.simple_cause import SimpleCause
24
+ from vivarium_public_health.utilities import to_years
25
+
26
+
27
+ class DisabilityObserver(PublicHealthObserver):
28
+ """Counts years lived with disability.
29
+
30
+ By default, this counts both aggregate and cause-specific years lived
31
+ with disability over the full course of the simulation.
32
+
33
+ In the model specification, your configuration for this component should
34
+ be specified as, e.g.:
35
+
36
+ .. code-block:: yaml
37
+
38
+ configuration:
39
+ stratification:
40
+ disability:
41
+ exclude:
42
+ - "sex"
43
+ include:
44
+ - "sample_stratification"
45
+ """
46
+
47
+ ##############
48
+ # Properties #
49
+ ##############
50
+
51
+ @property
52
+ def disability_classes(self) -> list[type]:
53
+ """The classes to be considered for causes of disability."""
54
+ return [DiseaseState, RiskAttributableDisease]
55
+
56
+ #####################
57
+ # Lifecycle methods #
58
+ #####################
59
+
60
+ def __init__(self) -> None:
61
+ super().__init__()
62
+ self.disability_weight_pipeline_name = "all_causes.disability_weight"
63
+
64
+ #################
65
+ # Setup methods #
66
+ #################
67
+
68
+ def setup(self, builder: Builder) -> None:
69
+ self.step_size = pd.Timedelta(days=builder.configuration.time.step_size)
70
+ self.disability_weight = self.get_disability_weight_pipeline(builder)
71
+ self.set_causes_of_disability(builder)
72
+
73
+ def set_causes_of_disability(self, builder: Builder) -> None:
74
+ """Set the causes of disability to be observed by removing any excluded
75
+ via the model spec from the list of all disability class causes. We implement
76
+ exclusions here because disabilities are unique in that they are not
77
+ registered stratifications and so cannot be excluded during the stratification
78
+ call like other categories.
79
+ """
80
+ causes_of_disability = builder.components.get_components_by_type(
81
+ self.disability_classes
82
+ )
83
+ # Convert to SimpleCause instances and add on all_causes
84
+ causes_of_disability = [
85
+ SimpleCause.create_from_disease_state(cause) for cause in causes_of_disability
86
+ ] + [SimpleCause("all_causes", "all_causes", "cause")]
87
+
88
+ excluded_causes = (
89
+ builder.configuration.stratification.excluded_categories.to_dict().get(
90
+ "disability", []
91
+ )
92
+ )
93
+
94
+ # Handle exclusions that don't exist in the list of causes
95
+ cause_names = [cause.state_id for cause in causes_of_disability]
96
+ unknown_exclusions = set(excluded_causes) - set(cause_names)
97
+ if len(unknown_exclusions) > 0:
98
+ raise ValueError(
99
+ f"Excluded 'disability' causes {unknown_exclusions} not found in "
100
+ f"expected categories categories: {cause_names}"
101
+ )
102
+
103
+ # Drop excluded causes
104
+ if excluded_causes:
105
+ logger.debug(
106
+ f"'disability' has category exclusion requests: {excluded_causes}\n"
107
+ "Removing these from the allowable categories."
108
+ )
109
+ self.causes_of_disability = [
110
+ cause for cause in causes_of_disability if cause.state_id not in excluded_causes
111
+ ]
112
+
113
+ def get_configuration(self, builder: Builder) -> LayeredConfigTree:
114
+ return builder.configuration.stratification.disability
115
+
116
+ def register_observations(self, builder: Builder) -> None:
117
+ cause_pipelines = [
118
+ f"{cause.state_id}.disability_weight" for cause in self.causes_of_disability
119
+ ]
120
+ self.register_adding_observation(
121
+ builder=builder,
122
+ name="ylds",
123
+ pop_filter='tracked == True and alive == "alive"',
124
+ when="time_step__prepare",
125
+ requires_columns=["alive"],
126
+ requires_values=cause_pipelines,
127
+ additional_stratifications=self.configuration.include,
128
+ excluded_stratifications=self.configuration.exclude,
129
+ aggregator_sources=cause_pipelines,
130
+ aggregator=self.disability_weight_aggregator,
131
+ )
132
+
133
+ def get_disability_weight_pipeline(self, builder: Builder) -> Pipeline:
134
+ return builder.value.register_value_producer(
135
+ self.disability_weight_pipeline_name,
136
+ source=lambda index: [pd.Series(0.0, index=index)],
137
+ preferred_combiner=list_combiner,
138
+ preferred_post_processor=union_post_processor,
139
+ )
140
+
141
+ ###############
142
+ # Aggregators #
143
+ ###############
144
+
145
+ def disability_weight_aggregator(self, dw: pd.DataFrame) -> Union[float, pd.Series]:
146
+ aggregated_dw = (dw * to_years(self.step_size)).sum().squeeze()
147
+ if isinstance(aggregated_dw, pd.Series):
148
+ aggregated_dw.index.name = "cause_of_disability"
149
+ return aggregated_dw
150
+
151
+ ##############################
152
+ # Results formatting methods #
153
+ ##############################
154
+
155
+ def format(self, measure: str, results: pd.DataFrame) -> pd.DataFrame:
156
+ """Format results. Note that ylds are unique in that we
157
+ can't stratify by cause of disability (because there can be multiple at
158
+ once), and so the results here are actually wide by disability weight
159
+ pipeline name.
160
+ """
161
+
162
+ # Drop the unused 'value' column and rename the pipeline names to causes
163
+ results = results.drop(columns=["value"]).rename(
164
+ columns={col: col.replace(".disability_weight", "") for col in results.columns},
165
+ )
166
+ # Get desired index names prior to stacking
167
+ idx_names = list(results.index.names) + [COLUMNS.SUB_ENTITY]
168
+ results = pd.DataFrame(results.stack(), columns=[COLUMNS.VALUE])
169
+ # Name the new index level
170
+ results.index.set_names(idx_names, inplace=True)
171
+ results = results.reset_index()
172
+ results[COLUMNS.SUB_ENTITY] = results[COLUMNS.SUB_ENTITY].astype(CategoricalDtype())
173
+ return results
174
+
175
+ def get_entity_type_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
176
+ entity_type_map = {
177
+ cause.state_id: cause.cause_type for cause in self.causes_of_disability
178
+ }
179
+ return results[COLUMNS.SUB_ENTITY].map(entity_type_map).astype(CategoricalDtype())
180
+
181
+ def get_entity_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
182
+ entity_map = {cause.state_id: cause.model for cause in self.causes_of_disability}
183
+ return results[COLUMNS.SUB_ENTITY].map(entity_map).astype(CategoricalDtype())
184
+
185
+ def get_sub_entity_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
186
+ # The sub-entity col was created in the 'format' method
187
+ return results[COLUMNS.SUB_ENTITY]
@@ -0,0 +1,222 @@
1
+ """
2
+ ================
3
+ Disease Observer
4
+ ================
5
+
6
+ This module contains tools for observing disease incidence and prevalence
7
+ in the simulation.
8
+
9
+ """
10
+
11
+ from typing import Any, Dict, List
12
+
13
+ import pandas as pd
14
+ from layered_config_tree import LayeredConfigTree
15
+ from vivarium.framework.engine import Builder
16
+ from vivarium.framework.event import Event
17
+ from vivarium.framework.population import SimulantData
18
+
19
+ from vivarium_public_health.results.columns import COLUMNS
20
+ from vivarium_public_health.results.observer import PublicHealthObserver
21
+ from vivarium_public_health.utilities import to_years
22
+
23
+
24
+ class DiseaseObserver(PublicHealthObserver):
25
+ """Observes disease counts and person time for a cause.
26
+
27
+ By default, this observer computes aggregate disease state person time and
28
+ counts of disease events over the full course of the simulation. It can be
29
+ configured to add or remove stratification groups to the default groups
30
+ defined by a ResultsStratifier.
31
+
32
+ In the model specification, your configuration for this component should
33
+ be specified as, e.g.:
34
+
35
+ .. code-block:: yaml
36
+
37
+ configuration:
38
+ stratification:
39
+ cause_name:
40
+ exclude:
41
+ - "sex"
42
+ include:
43
+ - "sample_stratification"
44
+ """
45
+
46
+ ##############
47
+ # Properties #
48
+ ##############
49
+
50
+ @property
51
+ def configuration_defaults(self) -> Dict[str, Any]:
52
+ return {
53
+ "stratification": {
54
+ self.disease: super().configuration_defaults["stratification"][
55
+ self.get_configuration_name()
56
+ ]
57
+ }
58
+ }
59
+
60
+ @property
61
+ def columns_created(self) -> List[str]:
62
+ return [self.previous_state_column_name]
63
+
64
+ @property
65
+ def columns_required(self) -> List[str]:
66
+ return [self.disease]
67
+
68
+ @property
69
+ def initialization_requirements(self) -> Dict[str, List[str]]:
70
+ return {
71
+ "requires_columns": [self.disease],
72
+ }
73
+
74
+ #####################
75
+ # Lifecycle methods #
76
+ #####################
77
+
78
+ def __init__(self, disease: str) -> None:
79
+ super().__init__()
80
+ self.disease = disease
81
+ self.previous_state_column_name = f"previous_{self.disease}"
82
+
83
+ #################
84
+ # Setup methods #
85
+ #################
86
+
87
+ def setup(self, builder: Builder) -> None:
88
+ self.step_size = builder.time.step_size()
89
+ self.disease_model = builder.components.get_component(f"disease_model.{self.disease}")
90
+ self.entity_type = self.disease_model.cause_type
91
+ self.entity = self.disease_model.cause
92
+ self.transition_stratification_name = f"transition_{self.disease}"
93
+
94
+ def get_configuration(self, builder: Builder) -> LayeredConfigTree:
95
+ return builder.configuration.stratification[self.disease]
96
+
97
+ def register_observations(self, builder: Builder) -> None:
98
+
99
+ self.register_disease_state_stratification(builder)
100
+ self.register_transition_stratification(builder)
101
+
102
+ pop_filter = 'alive == "alive" and tracked==True'
103
+ self.register_person_time_observation(builder, pop_filter)
104
+ self.register_transition_count_observation(builder, pop_filter)
105
+
106
+ def register_disease_state_stratification(self, builder: Builder) -> None:
107
+ builder.results.register_stratification(
108
+ self.disease,
109
+ [state.state_id for state in self.disease_model.states],
110
+ requires_columns=[self.disease],
111
+ )
112
+
113
+ def register_transition_stratification(self, builder: Builder) -> None:
114
+ transitions = [
115
+ str(transition) for transition in self.disease_model.transition_names
116
+ ] + ["no_transition"]
117
+ # manually append 'no_transition' as an excluded transition
118
+ excluded_categories = (
119
+ builder.configuration.stratification.excluded_categories.to_dict().get(
120
+ self.transition_stratification_name, []
121
+ )
122
+ ) + ["no_transition"]
123
+ builder.results.register_stratification(
124
+ self.transition_stratification_name,
125
+ categories=transitions,
126
+ excluded_categories=excluded_categories,
127
+ mapper=self.map_transitions,
128
+ requires_columns=[self.disease, self.previous_state_column_name],
129
+ is_vectorized=True,
130
+ )
131
+
132
+ def register_person_time_observation(self, builder: Builder, pop_filter: str) -> None:
133
+ self.register_adding_observation(
134
+ builder=builder,
135
+ name=f"person_time_{self.disease}",
136
+ pop_filter=pop_filter,
137
+ when="time_step__prepare",
138
+ requires_columns=["alive", self.disease],
139
+ additional_stratifications=self.configuration.include + [self.disease],
140
+ excluded_stratifications=self.configuration.exclude,
141
+ aggregator=self.aggregate_state_person_time,
142
+ )
143
+
144
+ def register_transition_count_observation(
145
+ self, builder: Builder, pop_filter: str
146
+ ) -> None:
147
+ self.register_adding_observation(
148
+ builder=builder,
149
+ name=f"transition_count_{self.disease}",
150
+ pop_filter=pop_filter,
151
+ requires_columns=[
152
+ self.previous_state_column_name,
153
+ self.disease,
154
+ ],
155
+ additional_stratifications=self.configuration.include
156
+ + [self.transition_stratification_name],
157
+ excluded_stratifications=self.configuration.exclude,
158
+ )
159
+
160
+ def map_transitions(self, df: pd.DataFrame) -> pd.Series:
161
+ transitions = pd.Series(index=df.index, dtype=str)
162
+ transition_mask = df[self.previous_state_column_name] != df[self.disease]
163
+ transitions[~transition_mask] = "no_transition"
164
+ transitions[transition_mask] = (
165
+ df[self.previous_state_column_name].astype(str)
166
+ + "_to_"
167
+ + df[self.disease].astype(str)
168
+ )
169
+ return transitions
170
+
171
+ ########################
172
+ # Event-driven methods #
173
+ ########################
174
+
175
+ def on_initialize_simulants(self, pop_data: SimulantData) -> None:
176
+ """Initialize the previous state column to the current state"""
177
+ pop = self.population_view.subview([self.disease]).get(pop_data.index)
178
+ pop[self.previous_state_column_name] = pop[self.disease]
179
+ self.population_view.update(pop)
180
+
181
+ def on_time_step_prepare(self, event: Event) -> None:
182
+ # This enables tracking of transitions between states
183
+ prior_state_pop = self.population_view.get(event.index)
184
+ prior_state_pop[self.previous_state_column_name] = prior_state_pop[self.disease]
185
+ self.population_view.update(prior_state_pop)
186
+
187
+ ###############
188
+ # Aggregators #
189
+ ###############
190
+
191
+ def aggregate_state_person_time(self, x: pd.DataFrame) -> float:
192
+ return len(x) * to_years(self.step_size())
193
+
194
+ ##############################
195
+ # Results formatting methods #
196
+ ##############################
197
+
198
+ def format(self, measure: str, results: pd.DataFrame) -> pd.DataFrame:
199
+ results = results.reset_index()
200
+ if "transition_count_" in measure:
201
+ sub_entity = self.transition_stratification_name
202
+ if "person_time_" in measure:
203
+ sub_entity = self.disease
204
+ results.rename(columns={sub_entity: COLUMNS.SUB_ENTITY}, inplace=True)
205
+ return results
206
+
207
+ def get_measure_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
208
+ if "transition_count_" in measure:
209
+ measure_name = "transition_count"
210
+ if "person_time_" in measure:
211
+ measure_name = "person_time"
212
+ return pd.Series(measure_name, index=results.index)
213
+
214
+ def get_entity_type_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
215
+ return pd.Series(self.entity_type, index=results.index)
216
+
217
+ def get_entity_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
218
+ return pd.Series(self.entity, index=results.index)
219
+
220
+ def get_sub_entity_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
221
+ # The sub-entity col was created in the 'format' method
222
+ return results[COLUMNS.SUB_ENTITY]