vivarium-public-health 3.0.3__py3-none-any.whl → 3.0.5__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (33) hide show
  1. vivarium_public_health/_version.py +1 -1
  2. vivarium_public_health/disease/state.py +24 -19
  3. vivarium_public_health/mslt/delay.py +13 -5
  4. vivarium_public_health/mslt/disease.py +35 -14
  5. vivarium_public_health/mslt/intervention.py +12 -9
  6. vivarium_public_health/mslt/observer.py +56 -17
  7. vivarium_public_health/mslt/population.py +7 -10
  8. vivarium_public_health/plugins/parser.py +29 -80
  9. vivarium_public_health/population/add_new_birth_cohorts.py +8 -9
  10. vivarium_public_health/population/base_population.py +0 -5
  11. vivarium_public_health/population/data_transformations.py +1 -8
  12. vivarium_public_health/population/mortality.py +3 -3
  13. vivarium_public_health/results/columns.py +1 -1
  14. vivarium_public_health/results/disability.py +89 -15
  15. vivarium_public_health/results/disease.py +128 -5
  16. vivarium_public_health/results/mortality.py +82 -6
  17. vivarium_public_health/results/observer.py +151 -6
  18. vivarium_public_health/results/risk.py +66 -5
  19. vivarium_public_health/results/simple_cause.py +30 -5
  20. vivarium_public_health/results/stratification.py +39 -14
  21. vivarium_public_health/risks/base_risk.py +14 -16
  22. vivarium_public_health/risks/data_transformations.py +3 -1
  23. vivarium_public_health/risks/distributions.py +0 -1
  24. vivarium_public_health/risks/effect.py +31 -29
  25. vivarium_public_health/risks/implementations/low_birth_weight_and_short_gestation.py +51 -30
  26. vivarium_public_health/treatment/scale_up.py +6 -10
  27. vivarium_public_health/treatment/therapeutic_inertia.py +3 -1
  28. {vivarium_public_health-3.0.3.dist-info → vivarium_public_health-3.0.5.dist-info}/METADATA +1 -1
  29. vivarium_public_health-3.0.5.dist-info/RECORD +49 -0
  30. {vivarium_public_health-3.0.3.dist-info → vivarium_public_health-3.0.5.dist-info}/WHEEL +1 -1
  31. vivarium_public_health-3.0.3.dist-info/RECORD +0 -49
  32. {vivarium_public_health-3.0.3.dist-info → vivarium_public_health-3.0.5.dist-info}/LICENSE.txt +0 -0
  33. {vivarium_public_health-3.0.3.dist-info → vivarium_public_health-3.0.5.dist-info}/top_level.txt +0 -0
@@ -1,14 +1,14 @@
1
1
  """
2
- ===================
3
- Disability Observer
4
- ===================
2
+ ====================
3
+ Disability Observers
4
+ ====================
5
5
 
6
6
  This module contains tools for observing years lived with disability (YLDs)
7
7
  in the simulation.
8
8
 
9
9
  """
10
10
 
11
- from typing import Any, List, Union
11
+ from typing import Union
12
12
 
13
13
  import pandas as pd
14
14
  from layered_config_tree import LayeredConfigTree
@@ -42,6 +42,17 @@ class DisabilityObserver(PublicHealthObserver):
42
42
  - "sex"
43
43
  include:
44
44
  - "sample_stratification"
45
+ Attributes
46
+ ----------
47
+ disability_weight_pipeline_name
48
+ The name of the pipeline that produces disability weights.
49
+ step_size
50
+ The time step size of the simulation.
51
+ disability_weight
52
+ The pipeline that produces disability weights.
53
+ causes_of_disability
54
+ The causes of disability to be observed.
55
+
45
56
  """
46
57
 
47
58
  ##############
@@ -50,7 +61,7 @@ class DisabilityObserver(PublicHealthObserver):
50
61
 
51
62
  @property
52
63
  def disability_classes(self) -> list[type]:
53
- """The classes to be considered for causes of disability."""
64
+ """The classes to be considered as causes of disability."""
54
65
  return [DiseaseState, RiskAttributableDisease]
55
66
 
56
67
  #####################
@@ -66,23 +77,32 @@ class DisabilityObserver(PublicHealthObserver):
66
77
  #################
67
78
 
68
79
  def setup(self, builder: Builder) -> None:
80
+ """Set up the observer."""
69
81
  self.step_size = pd.Timedelta(days=builder.configuration.time.step_size)
70
82
  self.disability_weight = self.get_disability_weight_pipeline(builder)
71
83
  self.set_causes_of_disability(builder)
72
84
 
73
85
  def set_causes_of_disability(self, builder: Builder) -> None:
74
- """Set the causes of disability to be observed by removing any excluded
75
- via the model spec from the list of all disability class causes. We implement
76
- exclusions here because disabilities are unique in that they are not
77
- registered stratifications and so cannot be excluded during the stratification
78
- call like other categories.
86
+ """Set the causes of disability to be observed.
87
+
88
+ The causes to be observed are any registered components of class types
89
+ found in the ``disability_classes`` property *excluding* any listed in
90
+ the model spec as ``excluded_categories``.
91
+
92
+ Notes
93
+ -----
94
+ We implement exclusions here instead of during the stratification call
95
+ like most other categories because disabilities are unique in that they are
96
+ *not* actually registered stratifications.
97
+
98
+ Also note that we add an 'all_causes' category here.
79
99
  """
80
100
  causes_of_disability = builder.components.get_components_by_type(
81
101
  self.disability_classes
82
102
  )
83
103
  # Convert to SimpleCause instances and add on all_causes
84
104
  causes_of_disability = [
85
- SimpleCause.create_from_disease_state(cause) for cause in causes_of_disability
105
+ SimpleCause.create_from_specific_cause(cause) for cause in causes_of_disability
86
106
  ] + [SimpleCause("all_causes", "all_causes", "cause")]
87
107
 
88
108
  excluded_causes = (
@@ -111,9 +131,21 @@ class DisabilityObserver(PublicHealthObserver):
111
131
  ]
112
132
 
113
133
  def get_configuration(self, builder: Builder) -> LayeredConfigTree:
134
+ """Get the stratification configuration for this observer.
135
+
136
+ Parameters
137
+ ----------
138
+ builder
139
+ The builder object for the simulation.
140
+
141
+ Returns
142
+ -------
143
+ The stratification configuration for this observer.
144
+ """
114
145
  return builder.configuration.stratification.disability
115
146
 
116
147
  def register_observations(self, builder: Builder) -> None:
148
+ """Register an observation for years lived with disability."""
117
149
  cause_pipelines = [
118
150
  f"{cause.state_id}.disability_weight" for cause in self.causes_of_disability
119
151
  ]
@@ -131,6 +163,17 @@ class DisabilityObserver(PublicHealthObserver):
131
163
  )
132
164
 
133
165
  def get_disability_weight_pipeline(self, builder: Builder) -> Pipeline:
166
+ """Register (and return) the pipeline that produces disability weights.
167
+
168
+ Parameters
169
+ ----------
170
+ builder
171
+ The builder object for the simulation.
172
+
173
+ Returns
174
+ -------
175
+ The pipeline that produces disability weights.
176
+ """
134
177
  return builder.value.register_value_producer(
135
178
  self.disability_weight_pipeline_name,
136
179
  source=lambda index: [pd.Series(0.0, index=index)],
@@ -143,6 +186,17 @@ class DisabilityObserver(PublicHealthObserver):
143
186
  ###############
144
187
 
145
188
  def disability_weight_aggregator(self, dw: pd.DataFrame) -> Union[float, pd.Series]:
189
+ """Aggregate disability weights for the time step.
190
+
191
+ Parameters
192
+ ----------
193
+ dw
194
+ The disability weights to aggregate.
195
+
196
+ Returns
197
+ -------
198
+ The aggregated disability weights.
199
+ """
146
200
  aggregated_dw = (dw * to_years(self.step_size)).sum().squeeze()
147
201
  if isinstance(aggregated_dw, pd.Series):
148
202
  aggregated_dw.index.name = "cause_of_disability"
@@ -153,10 +207,27 @@ class DisabilityObserver(PublicHealthObserver):
153
207
  ##############################
154
208
 
155
209
  def format(self, measure: str, results: pd.DataFrame) -> pd.DataFrame:
156
- """Format results. Note that ylds are unique in that we
157
- can't stratify by cause of disability (because there can be multiple at
158
- once), and so the results here are actually wide by disability weight
159
- pipeline name.
210
+ """Format wide YLD results to match typical/long stratified results.
211
+
212
+ YLDs are unique in that we can't stratify by cause of disability (because
213
+ there can be multiple at once), and so the results here are actually wide
214
+ by disability weight pipeline name. This method formats the results to be
215
+ long by cause of disability.
216
+
217
+ Parameters
218
+ ----------
219
+ measure
220
+ The measure.
221
+ results
222
+ The wide results to format.
223
+
224
+ Returns
225
+ -------
226
+ The results stacked by causes of disability.
227
+
228
+ Notes
229
+ -----
230
+ This method also adds the 'sub_entity' column to the results.
160
231
  """
161
232
  if len(self.causes_of_disability) > 1:
162
233
  # Drop the unused 'value' column and rename the remaining pipeline names to cause names
@@ -180,15 +251,18 @@ class DisabilityObserver(PublicHealthObserver):
180
251
  return results
181
252
 
182
253
  def get_entity_type_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
254
+ """Get the 'entity_type' column values."""
183
255
  entity_type_map = {
184
256
  cause.state_id: cause.cause_type for cause in self.causes_of_disability
185
257
  }
186
258
  return results[COLUMNS.SUB_ENTITY].map(entity_type_map).astype(CategoricalDtype())
187
259
 
188
260
  def get_entity_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
261
+ """Get the 'entity' column values."""
189
262
  entity_map = {cause.state_id: cause.model for cause in self.causes_of_disability}
190
263
  return results[COLUMNS.SUB_ENTITY].map(entity_map).astype(CategoricalDtype())
191
264
 
192
265
  def get_sub_entity_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
266
+ """Get the 'sub_entity' column values."""
193
267
  # The sub-entity col was created in the 'format' method
194
268
  return results[COLUMNS.SUB_ENTITY]
@@ -1,7 +1,7 @@
1
1
  """
2
- ================
3
- Disease Observer
4
- ================
2
+ =================
3
+ Disease Observers
4
+ =================
5
5
 
6
6
  This module contains tools for observing disease incidence and prevalence
7
7
  in the simulation.
@@ -41,6 +41,24 @@ class DiseaseObserver(PublicHealthObserver):
41
41
  - "sex"
42
42
  include:
43
43
  - "sample_stratification"
44
+
45
+ Attributes
46
+ ----------
47
+ disease
48
+ The name of the disease being observed.
49
+ previous_state_column_name
50
+ The name of the column that stores the previous state of the disease.
51
+ step_size
52
+ The time step size of the simulation.
53
+ disease_model
54
+ The disease model for the disease being observed.
55
+ entity_type
56
+ The type of entity being observed.
57
+ entity
58
+ The entity being observed.
59
+ transition_stratification_name
60
+ The stratification name for transitions between disease states.
61
+
44
62
  """
45
63
 
46
64
  ##############
@@ -49,6 +67,9 @@ class DiseaseObserver(PublicHealthObserver):
49
67
 
50
68
  @property
51
69
  def configuration_defaults(self) -> Dict[str, Any]:
70
+ """A dictionary containing the defaults for any configurations managed by
71
+ this component.
72
+ """
52
73
  return {
53
74
  "stratification": {
54
75
  self.disease: super().configuration_defaults["stratification"][
@@ -59,14 +80,17 @@ class DiseaseObserver(PublicHealthObserver):
59
80
 
60
81
  @property
61
82
  def columns_created(self) -> List[str]:
83
+ """Columns created by this observer."""
62
84
  return [self.previous_state_column_name]
63
85
 
64
86
  @property
65
87
  def columns_required(self) -> List[str]:
88
+ """Columns required by this observer."""
66
89
  return [self.disease]
67
90
 
68
91
  @property
69
92
  def initialization_requirements(self) -> Dict[str, List[str]]:
93
+ """Requirements for observer initialization."""
70
94
  return {
71
95
  "requires_columns": [self.disease],
72
96
  }
@@ -76,6 +100,13 @@ class DiseaseObserver(PublicHealthObserver):
76
100
  #####################
77
101
 
78
102
  def __init__(self, disease: str) -> None:
103
+ """Constructor for this observer.
104
+
105
+ Parameters
106
+ ----------
107
+ disease
108
+ The name of the disease being observed.
109
+ """
79
110
  super().__init__()
80
111
  self.disease = disease
81
112
  self.previous_state_column_name = f"previous_{self.disease}"
@@ -85,6 +116,7 @@ class DiseaseObserver(PublicHealthObserver):
85
116
  #################
86
117
 
87
118
  def setup(self, builder: Builder) -> None:
119
+ """Set up the observer."""
88
120
  self.step_size = builder.time.step_size()
89
121
  self.disease_model = builder.components.get_component(f"disease_model.{self.disease}")
90
122
  self.entity_type = self.disease_model.cause_type
@@ -92,10 +124,35 @@ class DiseaseObserver(PublicHealthObserver):
92
124
  self.transition_stratification_name = f"transition_{self.disease}"
93
125
 
94
126
  def get_configuration(self, builder: Builder) -> LayeredConfigTree:
127
+ """Get the stratification configuration for this observer.
128
+
129
+ Parameters
130
+ ----------
131
+ builder
132
+ The builder object for the simulation.
133
+
134
+ Returns
135
+ -------
136
+ The stratification configuration for this observer.
137
+ """
95
138
  return builder.configuration.stratification[self.disease]
96
139
 
97
140
  def register_observations(self, builder: Builder) -> None:
98
-
141
+ """Register stratifications and observations.
142
+
143
+ Notes
144
+ -----
145
+ Ideally, each observer registers a single observation. This one, however,
146
+ registers two.
147
+
148
+ While it's typical for all stratification registrations to be encapsulated
149
+ in a single class (i.e. the
150
+ :class:`ResultsStratifier <vivarium_public_health.results.stratification.ResultsStratifier>`),
151
+ this observer registers two additional stratifications. While they could
152
+ be registered in the ``ResultsStratifier`` as well, they are specific to
153
+ this observer and so they are registered here while we have easy access
154
+ to the required names and categories.
155
+ """
99
156
  self.register_disease_state_stratification(builder)
100
157
  self.register_transition_stratification(builder)
101
158
 
@@ -104,6 +161,7 @@ class DiseaseObserver(PublicHealthObserver):
104
161
  self.register_transition_count_observation(builder, pop_filter)
105
162
 
106
163
  def register_disease_state_stratification(self, builder: Builder) -> None:
164
+ """Register the disease state stratification."""
107
165
  builder.results.register_stratification(
108
166
  self.disease,
109
167
  [state.state_id for state in self.disease_model.states],
@@ -111,6 +169,20 @@ class DiseaseObserver(PublicHealthObserver):
111
169
  )
112
170
 
113
171
  def register_transition_stratification(self, builder: Builder) -> None:
172
+ """Register the transition stratification.
173
+
174
+ This stratification is used to track transitions between disease states.
175
+ It appends 'no_transition' to the list of transition categories and also
176
+ includes it as an excluded category.
177
+
178
+ Notes
179
+ -----
180
+ It is important to include 'no_transition' in both the list of transition
181
+ categories as well as the list of excluded categories. This is because
182
+ it must exist as a category for the transition mapping to work correctly,
183
+ but then we don't want to include it later during the actual stratification
184
+ process.
185
+ """
114
186
  transitions = [
115
187
  str(transition) for transition in self.disease_model.transition_names
116
188
  ] + ["no_transition"]
@@ -130,6 +202,7 @@ class DiseaseObserver(PublicHealthObserver):
130
202
  )
131
203
 
132
204
  def register_person_time_observation(self, builder: Builder, pop_filter: str) -> None:
205
+ """Register a person time observation."""
133
206
  self.register_adding_observation(
134
207
  builder=builder,
135
208
  name=f"person_time_{self.disease}",
@@ -144,6 +217,7 @@ class DiseaseObserver(PublicHealthObserver):
144
217
  def register_transition_count_observation(
145
218
  self, builder: Builder, pop_filter: str
146
219
  ) -> None:
220
+ """Register a transition count observation."""
147
221
  self.register_adding_observation(
148
222
  builder=builder,
149
223
  name=f"transition_count_{self.disease}",
@@ -158,6 +232,17 @@ class DiseaseObserver(PublicHealthObserver):
158
232
  )
159
233
 
160
234
  def map_transitions(self, df: pd.DataFrame) -> pd.Series:
235
+ """Map previous and current disease states to transition string.
236
+
237
+ Parameters
238
+ ----------
239
+ df
240
+ The DataFrame containing the disease states.
241
+
242
+ Returns
243
+ -------
244
+ The transitions between disease states.
245
+ """
161
246
  transitions = pd.Series(index=df.index, dtype=str)
162
247
  transition_mask = df[self.previous_state_column_name] != df[self.disease]
163
248
  transitions[~transition_mask] = "no_transition"
@@ -179,7 +264,10 @@ class DiseaseObserver(PublicHealthObserver):
179
264
  self.population_view.update(pop)
180
265
 
181
266
  def on_time_step_prepare(self, event: Event) -> None:
182
- # This enables tracking of transitions between states
267
+ """Update the previous state column to the current state.
268
+
269
+ This enables tracking of transitions between states.
270
+ """
183
271
  prior_state_pop = self.population_view.get(event.index)
184
272
  prior_state_pop[self.previous_state_column_name] = prior_state_pop[self.disease]
185
273
  self.population_view.update(prior_state_pop)
@@ -189,6 +277,17 @@ class DiseaseObserver(PublicHealthObserver):
189
277
  ###############
190
278
 
191
279
  def aggregate_state_person_time(self, x: pd.DataFrame) -> float:
280
+ """Aggregate person time for the time step.
281
+
282
+ Parameters
283
+ ----------
284
+ x
285
+ The DataFrame containing the population.
286
+
287
+ Returns
288
+ -------
289
+ The aggregated person time.
290
+ """
192
291
  return len(x) * to_years(self.step_size())
193
292
 
194
293
  ##############################
@@ -196,6 +295,26 @@ class DiseaseObserver(PublicHealthObserver):
196
295
  ##############################
197
296
 
198
297
  def format(self, measure: str, results: pd.DataFrame) -> pd.DataFrame:
298
+ """Rename the appropriate column to 'sub_entity'.
299
+
300
+ The primary thing this method does is rename the appropriate column
301
+ (either the transition stratification name or the disease name, depending
302
+ on the measure) to 'sub_entity'. We do this here instead of the
303
+ 'get_sub_entity_column' method simply because we do not want the original
304
+ column at all. If we keep it here and then return it as the sub-entity
305
+ column later, the final results would have both.
306
+
307
+ Parameters
308
+ ----------
309
+ measure
310
+ The measure.
311
+ results
312
+ The results to format.
313
+
314
+ Returns
315
+ -------
316
+ The formatted results.
317
+ """
199
318
  results = results.reset_index()
200
319
  if "transition_count_" in measure:
201
320
  sub_entity = self.transition_stratification_name
@@ -205,6 +324,7 @@ class DiseaseObserver(PublicHealthObserver):
205
324
  return results
206
325
 
207
326
  def get_measure_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
327
+ """Get the 'measure' column values."""
208
328
  if "transition_count_" in measure:
209
329
  measure_name = "transition_count"
210
330
  if "person_time_" in measure:
@@ -212,11 +332,14 @@ class DiseaseObserver(PublicHealthObserver):
212
332
  return pd.Series(measure_name, index=results.index)
213
333
 
214
334
  def get_entity_type_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
335
+ """Get the 'entity_type' column values."""
215
336
  return pd.Series(self.entity_type, index=results.index)
216
337
 
217
338
  def get_entity_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
339
+ """Get the 'entity' column values."""
218
340
  return pd.Series(self.entity, index=results.index)
219
341
 
220
342
  def get_sub_entity_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
343
+ """Get the 'sub_entity' column values."""
221
344
  # The sub-entity col was created in the 'format' method
222
345
  return results[COLUMNS.SUB_ENTITY]
@@ -1,7 +1,7 @@
1
1
  """
2
- ==================
3
- Mortality Observer
4
- ==================
2
+ ===================
3
+ Mortality Observers
4
+ ===================
5
5
 
6
6
  This module contains tools for observing cause-specific and
7
7
  excess mortality in the simulation, including "other causes".
@@ -47,6 +47,18 @@ class MortalityObserver(PublicHealthObserver):
47
47
  This observer needs to access the has_excess_mortality attribute of the causes
48
48
  we're observing, but this attribute gets defined in the setup of the cause models.
49
49
  As a result, the model specification should list this observer after causes.
50
+
51
+ Attributes
52
+ ----------
53
+ required_death_columns
54
+ Columns required by the deaths observation.
55
+ required_yll_columns
56
+ Columns required by the ylls observation.
57
+ clock
58
+ The simulation clock.
59
+ causes_of_death
60
+ Causes of death to be observed.
61
+
50
62
  """
51
63
 
52
64
  def __init__(self) -> None:
@@ -65,12 +77,12 @@ class MortalityObserver(PublicHealthObserver):
65
77
 
66
78
  @property
67
79
  def mortality_classes(self) -> list[type]:
80
+ """The classes to be considered as causes of death."""
68
81
  return [DiseaseState, RiskAttributableDisease]
69
82
 
70
83
  @property
71
84
  def configuration_defaults(self) -> Dict[str, Any]:
72
- """
73
- A dictionary containing the defaults for any configurations managed by
85
+ """A dictionary containing the defaults for any configurations managed by
74
86
  this component.
75
87
  """
76
88
  config_defaults = super().configuration_defaults
@@ -79,6 +91,7 @@ class MortalityObserver(PublicHealthObserver):
79
91
 
80
92
  @property
81
93
  def columns_required(self) -> List[str]:
94
+ """Columns required by this observer."""
82
95
  return [
83
96
  "alive",
84
97
  "years_of_life_lost",
@@ -91,10 +104,22 @@ class MortalityObserver(PublicHealthObserver):
91
104
  #################
92
105
 
93
106
  def setup(self, builder: Builder) -> None:
107
+ """Set up the observer."""
94
108
  self.clock = builder.time.clock()
95
109
  self.set_causes_of_death(builder)
96
110
 
97
111
  def set_causes_of_death(self, builder: Builder) -> None:
112
+ """Set the causes of death to be observed.
113
+
114
+ The causes to be observed are any registered components of class types
115
+ found in the ``mortality_classes`` property.
116
+
117
+ Notes
118
+ -----
119
+ We do not actually exclude any categories in this method.
120
+
121
+ Also note that we add 'not_dead' and 'other_causes' categories here.
122
+ """
98
123
  causes_of_death = [
99
124
  cause
100
125
  for cause in builder.components.get_components_by_type(
@@ -105,16 +130,42 @@ class MortalityObserver(PublicHealthObserver):
105
130
 
106
131
  # Convert to SimpleCauses and add on other_causes and not_dead
107
132
  self.causes_of_death = [
108
- SimpleCause.create_from_disease_state(cause) for cause in causes_of_death
133
+ SimpleCause.create_from_specific_cause(cause) for cause in causes_of_death
109
134
  ] + [
110
135
  SimpleCause("not_dead", "not_dead", "cause"),
111
136
  SimpleCause("other_causes", "other_causes", "cause"),
112
137
  ]
113
138
 
114
139
  def get_configuration(self, builder: Builder) -> LayeredConfigTree:
140
+ """Get the stratification configuration for this observer.
141
+
142
+ Parameters
143
+ ----------
144
+ builder
145
+ The builder object for the simulation.
146
+
147
+ Returns
148
+ -------
149
+ The stratification configuration for this observer.
150
+ """
115
151
  return builder.configuration.stratification[self.get_configuration_name()]
116
152
 
117
153
  def register_observations(self, builder: Builder) -> None:
154
+ """Register stratifications and observations.
155
+
156
+ Notes
157
+ -----
158
+ Ideally, each observer registers a single observation. This one, however,
159
+ registers two.
160
+
161
+ While it's typical for all stratification registrations to be encapsulated
162
+ in a single class (i.e. the
163
+ :class:`ResultsStratifier <vivarium_public_health.results.stratification.ResultsStratifier>`),
164
+ this observer potentially registers an additional one. While it could
165
+ be registered in the ``ResultsStratifier`` as well, it is specific to
166
+ this observer and so it is registered here while we have easy access
167
+ to the required categories.
168
+ """
118
169
  pop_filter = 'alive == "dead" and tracked == True'
119
170
  additional_stratifications = self.configuration.include
120
171
  if not self.configuration.aggregate:
@@ -155,10 +206,12 @@ class MortalityObserver(PublicHealthObserver):
155
206
  ###############
156
207
 
157
208
  def count_deaths(self, x: pd.DataFrame) -> float:
209
+ """Count the number of deaths that occurred during this time step."""
158
210
  died_of_cause = x["exit_time"] > self.clock()
159
211
  return sum(died_of_cause)
160
212
 
161
213
  def calculate_ylls(self, x: pd.DataFrame) -> float:
214
+ """Calculate the years of life lost during this time step."""
162
215
  died_of_cause = x["exit_time"] > self.clock()
163
216
  return x.loc[died_of_cause, "years_of_life_lost"].sum()
164
217
 
@@ -167,6 +220,26 @@ class MortalityObserver(PublicHealthObserver):
167
220
  ##############################
168
221
 
169
222
  def format(self, measure: str, results: pd.DataFrame) -> pd.DataFrame:
223
+ """Rename the appropriate column to 'entity'.
224
+
225
+ The primary thing this method does is rename the 'cause_of_death' column
226
+ to 'entity' (or, if we are aggregating and there is no 'cause_of_death'
227
+ column, we simply create a new 'entity' column). We do this here instead
228
+ of the 'get_entity_column' method simply because we do not want the
229
+ 'cause_of_death' column at all. If we kept it here and then returned it as the
230
+ entity column later, the final results would have both.
231
+
232
+ Parameters
233
+ ----------
234
+ measure
235
+ The measure.
236
+ results
237
+ The results to format.
238
+
239
+ Returns
240
+ -------
241
+ The formatted results.
242
+ """
170
243
  results = results.reset_index()
171
244
  if self.configuration.aggregate:
172
245
  results[COLUMNS.ENTITY] = "all_causes"
@@ -175,12 +248,15 @@ class MortalityObserver(PublicHealthObserver):
175
248
  return results
176
249
 
177
250
  def get_entity_type_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
251
+ """Get the 'entity_type' column values."""
178
252
  entity_type_map = {cause.state_id: cause.cause_type for cause in self.causes_of_death}
179
253
  return results[COLUMNS.ENTITY].map(entity_type_map).astype(CategoricalDtype())
180
254
 
181
255
  def get_entity_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
256
+ """Get the 'entity' column values."""
182
257
  # The entity col was created in the 'format' method
183
258
  return results[COLUMNS.ENTITY]
184
259
 
185
260
  def get_sub_entity_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
261
+ """Get the 'sub_entity' column values."""
186
262
  return results[COLUMNS.ENTITY]