vivarium-public-health 3.0.3__py3-none-any.whl → 3.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33) hide show
  1. vivarium_public_health/_version.py +1 -1
  2. vivarium_public_health/disease/state.py +24 -19
  3. vivarium_public_health/mslt/delay.py +13 -5
  4. vivarium_public_health/mslt/disease.py +35 -14
  5. vivarium_public_health/mslt/intervention.py +12 -9
  6. vivarium_public_health/mslt/observer.py +56 -17
  7. vivarium_public_health/mslt/population.py +7 -10
  8. vivarium_public_health/plugins/parser.py +29 -80
  9. vivarium_public_health/population/add_new_birth_cohorts.py +8 -9
  10. vivarium_public_health/population/base_population.py +0 -5
  11. vivarium_public_health/population/data_transformations.py +1 -8
  12. vivarium_public_health/population/mortality.py +3 -3
  13. vivarium_public_health/results/columns.py +1 -1
  14. vivarium_public_health/results/disability.py +89 -15
  15. vivarium_public_health/results/disease.py +128 -5
  16. vivarium_public_health/results/mortality.py +82 -6
  17. vivarium_public_health/results/observer.py +151 -6
  18. vivarium_public_health/results/risk.py +66 -5
  19. vivarium_public_health/results/simple_cause.py +30 -5
  20. vivarium_public_health/results/stratification.py +39 -14
  21. vivarium_public_health/risks/base_risk.py +14 -16
  22. vivarium_public_health/risks/data_transformations.py +3 -1
  23. vivarium_public_health/risks/distributions.py +0 -1
  24. vivarium_public_health/risks/effect.py +31 -29
  25. vivarium_public_health/risks/implementations/low_birth_weight_and_short_gestation.py +51 -30
  26. vivarium_public_health/treatment/scale_up.py +6 -10
  27. vivarium_public_health/treatment/therapeutic_inertia.py +3 -1
  28. {vivarium_public_health-3.0.3.dist-info → vivarium_public_health-3.0.5.dist-info}/METADATA +1 -1
  29. vivarium_public_health-3.0.5.dist-info/RECORD +49 -0
  30. {vivarium_public_health-3.0.3.dist-info → vivarium_public_health-3.0.5.dist-info}/WHEEL +1 -1
  31. vivarium_public_health-3.0.3.dist-info/RECORD +0 -49
  32. {vivarium_public_health-3.0.3.dist-info → vivarium_public_health-3.0.5.dist-info}/LICENSE.txt +0 -0
  33. {vivarium_public_health-3.0.3.dist-info → vivarium_public_health-3.0.5.dist-info}/top_level.txt +0 -0
@@ -1,14 +1,14 @@
1
1
  """
2
- ===================
3
- Disability Observer
4
- ===================
2
+ ====================
3
+ Disability Observers
4
+ ====================
5
5
 
6
6
  This module contains tools for observing years lived with disability (YLDs)
7
7
  in the simulation.
8
8
 
9
9
  """
10
10
 
11
- from typing import Any, List, Union
11
+ from typing import Union
12
12
 
13
13
  import pandas as pd
14
14
  from layered_config_tree import LayeredConfigTree
@@ -42,6 +42,17 @@ class DisabilityObserver(PublicHealthObserver):
42
42
  - "sex"
43
43
  include:
44
44
  - "sample_stratification"
45
+ Attributes
46
+ ----------
47
+ disability_weight_pipeline_name
48
+ The name of the pipeline that produces disability weights.
49
+ step_size
50
+ The time step size of the simulation.
51
+ disability_weight
52
+ The pipeline that produces disability weights.
53
+ causes_of_disability
54
+ The causes of disability to be observed.
55
+
45
56
  """
46
57
 
47
58
  ##############
@@ -50,7 +61,7 @@ class DisabilityObserver(PublicHealthObserver):
50
61
 
51
62
  @property
52
63
  def disability_classes(self) -> list[type]:
53
- """The classes to be considered for causes of disability."""
64
+ """The classes to be considered as causes of disability."""
54
65
  return [DiseaseState, RiskAttributableDisease]
55
66
 
56
67
  #####################
@@ -66,23 +77,32 @@ class DisabilityObserver(PublicHealthObserver):
66
77
  #################
67
78
 
68
79
  def setup(self, builder: Builder) -> None:
80
+ """Set up the observer."""
69
81
  self.step_size = pd.Timedelta(days=builder.configuration.time.step_size)
70
82
  self.disability_weight = self.get_disability_weight_pipeline(builder)
71
83
  self.set_causes_of_disability(builder)
72
84
 
73
85
  def set_causes_of_disability(self, builder: Builder) -> None:
74
- """Set the causes of disability to be observed by removing any excluded
75
- via the model spec from the list of all disability class causes. We implement
76
- exclusions here because disabilities are unique in that they are not
77
- registered stratifications and so cannot be excluded during the stratification
78
- call like other categories.
86
+ """Set the causes of disability to be observed.
87
+
88
+ The causes to be observed are any registered components of class types
89
+ found in the ``disability_classes`` property *excluding* any listed in
90
+ the model spec as ``excluded_categories``.
91
+
92
+ Notes
93
+ -----
94
+ We implement exclusions here instead of during the stratification call
95
+ like most other categories because disabilities are unique in that they are
96
+ *not* actually registered stratifications.
97
+
98
+ Also note that we add an 'all_causes' category here.
79
99
  """
80
100
  causes_of_disability = builder.components.get_components_by_type(
81
101
  self.disability_classes
82
102
  )
83
103
  # Convert to SimpleCause instances and add on all_causes
84
104
  causes_of_disability = [
85
- SimpleCause.create_from_disease_state(cause) for cause in causes_of_disability
105
+ SimpleCause.create_from_specific_cause(cause) for cause in causes_of_disability
86
106
  ] + [SimpleCause("all_causes", "all_causes", "cause")]
87
107
 
88
108
  excluded_causes = (
@@ -111,9 +131,21 @@ class DisabilityObserver(PublicHealthObserver):
111
131
  ]
112
132
 
113
133
  def get_configuration(self, builder: Builder) -> LayeredConfigTree:
134
+ """Get the stratification configuration for this observer.
135
+
136
+ Parameters
137
+ ----------
138
+ builder
139
+ The builder object for the simulation.
140
+
141
+ Returns
142
+ -------
143
+ The stratification configuration for this observer.
144
+ """
114
145
  return builder.configuration.stratification.disability
115
146
 
116
147
  def register_observations(self, builder: Builder) -> None:
148
+ """Register an observation for years lived with disability."""
117
149
  cause_pipelines = [
118
150
  f"{cause.state_id}.disability_weight" for cause in self.causes_of_disability
119
151
  ]
@@ -131,6 +163,17 @@ class DisabilityObserver(PublicHealthObserver):
131
163
  )
132
164
 
133
165
  def get_disability_weight_pipeline(self, builder: Builder) -> Pipeline:
166
+ """Register (and return) the pipeline that produces disability weights.
167
+
168
+ Parameters
169
+ ----------
170
+ builder
171
+ The builder object for the simulation.
172
+
173
+ Returns
174
+ -------
175
+ The pipeline that produces disability weights.
176
+ """
134
177
  return builder.value.register_value_producer(
135
178
  self.disability_weight_pipeline_name,
136
179
  source=lambda index: [pd.Series(0.0, index=index)],
@@ -143,6 +186,17 @@ class DisabilityObserver(PublicHealthObserver):
143
186
  ###############
144
187
 
145
188
  def disability_weight_aggregator(self, dw: pd.DataFrame) -> Union[float, pd.Series]:
189
+ """Aggregate disability weights for the time step.
190
+
191
+ Parameters
192
+ ----------
193
+ dw
194
+ The disability weights to aggregate.
195
+
196
+ Returns
197
+ -------
198
+ The aggregated disability weights.
199
+ """
146
200
  aggregated_dw = (dw * to_years(self.step_size)).sum().squeeze()
147
201
  if isinstance(aggregated_dw, pd.Series):
148
202
  aggregated_dw.index.name = "cause_of_disability"
@@ -153,10 +207,27 @@ class DisabilityObserver(PublicHealthObserver):
153
207
  ##############################
154
208
 
155
209
  def format(self, measure: str, results: pd.DataFrame) -> pd.DataFrame:
156
- """Format results. Note that ylds are unique in that we
157
- can't stratify by cause of disability (because there can be multiple at
158
- once), and so the results here are actually wide by disability weight
159
- pipeline name.
210
+ """Format wide YLD results to match typical/long stratified results.
211
+
212
+ YLDs are unique in that we can't stratify by cause of disability (because
213
+ there can be multiple at once), and so the results here are actually wide
214
+ by disability weight pipeline name. This method formats the results to be
215
+ long by cause of disability.
216
+
217
+ Parameters
218
+ ----------
219
+ measure
220
+ The measure.
221
+ results
222
+ The wide results to format.
223
+
224
+ Returns
225
+ -------
226
+ The results stacked by causes of disability.
227
+
228
+ Notes
229
+ -----
230
+ This method also adds the 'sub_entity' column to the results.
160
231
  """
161
232
  if len(self.causes_of_disability) > 1:
162
233
  # Drop the unused 'value' column and rename the remaining pipeline names to cause names
@@ -180,15 +251,18 @@ class DisabilityObserver(PublicHealthObserver):
180
251
  return results
181
252
 
182
253
  def get_entity_type_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
254
+ """Get the 'entity_type' column values."""
183
255
  entity_type_map = {
184
256
  cause.state_id: cause.cause_type for cause in self.causes_of_disability
185
257
  }
186
258
  return results[COLUMNS.SUB_ENTITY].map(entity_type_map).astype(CategoricalDtype())
187
259
 
188
260
  def get_entity_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
261
+ """Get the 'entity' column values."""
189
262
  entity_map = {cause.state_id: cause.model for cause in self.causes_of_disability}
190
263
  return results[COLUMNS.SUB_ENTITY].map(entity_map).astype(CategoricalDtype())
191
264
 
192
265
  def get_sub_entity_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
266
+ """Get the 'sub_entity' column values."""
193
267
  # The sub-entity col was created in the 'format' method
194
268
  return results[COLUMNS.SUB_ENTITY]
@@ -1,7 +1,7 @@
1
1
  """
2
- ================
3
- Disease Observer
4
- ================
2
+ =================
3
+ Disease Observers
4
+ =================
5
5
 
6
6
  This module contains tools for observing disease incidence and prevalence
7
7
  in the simulation.
@@ -41,6 +41,24 @@ class DiseaseObserver(PublicHealthObserver):
41
41
  - "sex"
42
42
  include:
43
43
  - "sample_stratification"
44
+
45
+ Attributes
46
+ ----------
47
+ disease
48
+ The name of the disease being observed.
49
+ previous_state_column_name
50
+ The name of the column that stores the previous state of the disease.
51
+ step_size
52
+ The time step size of the simulation.
53
+ disease_model
54
+ The disease model for the disease being observed.
55
+ entity_type
56
+ The type of entity being observed.
57
+ entity
58
+ The entity being observed.
59
+ transition_stratification_name
60
+ The stratification name for transitions between disease states.
61
+
44
62
  """
45
63
 
46
64
  ##############
@@ -49,6 +67,9 @@ class DiseaseObserver(PublicHealthObserver):
49
67
 
50
68
  @property
51
69
  def configuration_defaults(self) -> Dict[str, Any]:
70
+ """A dictionary containing the defaults for any configurations managed by
71
+ this component.
72
+ """
52
73
  return {
53
74
  "stratification": {
54
75
  self.disease: super().configuration_defaults["stratification"][
@@ -59,14 +80,17 @@ class DiseaseObserver(PublicHealthObserver):
59
80
 
60
81
  @property
61
82
  def columns_created(self) -> List[str]:
83
+ """Columns created by this observer."""
62
84
  return [self.previous_state_column_name]
63
85
 
64
86
  @property
65
87
  def columns_required(self) -> List[str]:
88
+ """Columns required by this observer."""
66
89
  return [self.disease]
67
90
 
68
91
  @property
69
92
  def initialization_requirements(self) -> Dict[str, List[str]]:
93
+ """Requirements for observer initialization."""
70
94
  return {
71
95
  "requires_columns": [self.disease],
72
96
  }
@@ -76,6 +100,13 @@ class DiseaseObserver(PublicHealthObserver):
76
100
  #####################
77
101
 
78
102
  def __init__(self, disease: str) -> None:
103
+ """Constructor for this observer.
104
+
105
+ Parameters
106
+ ----------
107
+ disease
108
+ The name of the disease being observed.
109
+ """
79
110
  super().__init__()
80
111
  self.disease = disease
81
112
  self.previous_state_column_name = f"previous_{self.disease}"
@@ -85,6 +116,7 @@ class DiseaseObserver(PublicHealthObserver):
85
116
  #################
86
117
 
87
118
  def setup(self, builder: Builder) -> None:
119
+ """Set up the observer."""
88
120
  self.step_size = builder.time.step_size()
89
121
  self.disease_model = builder.components.get_component(f"disease_model.{self.disease}")
90
122
  self.entity_type = self.disease_model.cause_type
@@ -92,10 +124,35 @@ class DiseaseObserver(PublicHealthObserver):
92
124
  self.transition_stratification_name = f"transition_{self.disease}"
93
125
 
94
126
  def get_configuration(self, builder: Builder) -> LayeredConfigTree:
127
+ """Get the stratification configuration for this observer.
128
+
129
+ Parameters
130
+ ----------
131
+ builder
132
+ The builder object for the simulation.
133
+
134
+ Returns
135
+ -------
136
+ The stratification configuration for this observer.
137
+ """
95
138
  return builder.configuration.stratification[self.disease]
96
139
 
97
140
  def register_observations(self, builder: Builder) -> None:
98
-
141
+ """Register stratifications and observations.
142
+
143
+ Notes
144
+ -----
145
+ Ideally, each observer registers a single observation. This one, however,
146
+ registers two.
147
+
148
+ While it's typical for all stratification registrations to be encapsulated
149
+ in a single class (i.e. the
150
+ :class:`ResultsStratifier <vivarium_public_health.results.stratification.ResultsStratifier>`),
151
+ this observer registers two additional stratifications. While they could
152
+ be registered in the ``ResultsStratifier`` as well, they are specific to
153
+ this observer and so they are registered here while we have easy access
154
+ to the required names and categories.
155
+ """
99
156
  self.register_disease_state_stratification(builder)
100
157
  self.register_transition_stratification(builder)
101
158
 
@@ -104,6 +161,7 @@ class DiseaseObserver(PublicHealthObserver):
104
161
  self.register_transition_count_observation(builder, pop_filter)
105
162
 
106
163
  def register_disease_state_stratification(self, builder: Builder) -> None:
164
+ """Register the disease state stratification."""
107
165
  builder.results.register_stratification(
108
166
  self.disease,
109
167
  [state.state_id for state in self.disease_model.states],
@@ -111,6 +169,20 @@ class DiseaseObserver(PublicHealthObserver):
111
169
  )
112
170
 
113
171
  def register_transition_stratification(self, builder: Builder) -> None:
172
+ """Register the transition stratification.
173
+
174
+ This stratification is used to track transitions between disease states.
175
+ It appends 'no_transition' to the list of transition categories and also
176
+ includes it as an excluded category.
177
+
178
+ Notes
179
+ -----
180
+ It is important to include 'no_transition' in both the list of transition
181
+ categories as well as the list of excluded categories. This is because
182
+ it must exist as a category for the transition mapping to work correctly,
183
+ but then we don't want to include it later during the actual stratification
184
+ process.
185
+ """
114
186
  transitions = [
115
187
  str(transition) for transition in self.disease_model.transition_names
116
188
  ] + ["no_transition"]
@@ -130,6 +202,7 @@ class DiseaseObserver(PublicHealthObserver):
130
202
  )
131
203
 
132
204
  def register_person_time_observation(self, builder: Builder, pop_filter: str) -> None:
205
+ """Register a person time observation."""
133
206
  self.register_adding_observation(
134
207
  builder=builder,
135
208
  name=f"person_time_{self.disease}",
@@ -144,6 +217,7 @@ class DiseaseObserver(PublicHealthObserver):
144
217
  def register_transition_count_observation(
145
218
  self, builder: Builder, pop_filter: str
146
219
  ) -> None:
220
+ """Register a transition count observation."""
147
221
  self.register_adding_observation(
148
222
  builder=builder,
149
223
  name=f"transition_count_{self.disease}",
@@ -158,6 +232,17 @@ class DiseaseObserver(PublicHealthObserver):
158
232
  )
159
233
 
160
234
  def map_transitions(self, df: pd.DataFrame) -> pd.Series:
235
+ """Map previous and current disease states to transition string.
236
+
237
+ Parameters
238
+ ----------
239
+ df
240
+ The DataFrame containing the disease states.
241
+
242
+ Returns
243
+ -------
244
+ The transitions between disease states.
245
+ """
161
246
  transitions = pd.Series(index=df.index, dtype=str)
162
247
  transition_mask = df[self.previous_state_column_name] != df[self.disease]
163
248
  transitions[~transition_mask] = "no_transition"
@@ -179,7 +264,10 @@ class DiseaseObserver(PublicHealthObserver):
179
264
  self.population_view.update(pop)
180
265
 
181
266
  def on_time_step_prepare(self, event: Event) -> None:
182
- # This enables tracking of transitions between states
267
+ """Update the previous state column to the current state.
268
+
269
+ This enables tracking of transitions between states.
270
+ """
183
271
  prior_state_pop = self.population_view.get(event.index)
184
272
  prior_state_pop[self.previous_state_column_name] = prior_state_pop[self.disease]
185
273
  self.population_view.update(prior_state_pop)
@@ -189,6 +277,17 @@ class DiseaseObserver(PublicHealthObserver):
189
277
  ###############
190
278
 
191
279
  def aggregate_state_person_time(self, x: pd.DataFrame) -> float:
280
+ """Aggregate person time for the time step.
281
+
282
+ Parameters
283
+ ----------
284
+ x
285
+ The DataFrame containing the population.
286
+
287
+ Returns
288
+ -------
289
+ The aggregated person time.
290
+ """
192
291
  return len(x) * to_years(self.step_size())
193
292
 
194
293
  ##############################
@@ -196,6 +295,26 @@ class DiseaseObserver(PublicHealthObserver):
196
295
  ##############################
197
296
 
198
297
  def format(self, measure: str, results: pd.DataFrame) -> pd.DataFrame:
298
+ """Rename the appropriate column to 'sub_entity'.
299
+
300
+ The primary thing this method does is rename the appropriate column
301
+ (either the transition stratification name or the disease name, depending
302
+ on the measure) to 'sub_entity'. We do this here instead of the
303
+ 'get_sub_entity_column' method simply because we do not want the original
304
+ column at all. If we keep it here and then return it as the sub-entity
305
+ column later, the final results would have both.
306
+
307
+ Parameters
308
+ ----------
309
+ measure
310
+ The measure.
311
+ results
312
+ The results to format.
313
+
314
+ Returns
315
+ -------
316
+ The formatted results.
317
+ """
199
318
  results = results.reset_index()
200
319
  if "transition_count_" in measure:
201
320
  sub_entity = self.transition_stratification_name
@@ -205,6 +324,7 @@ class DiseaseObserver(PublicHealthObserver):
205
324
  return results
206
325
 
207
326
  def get_measure_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
327
+ """Get the 'measure' column values."""
208
328
  if "transition_count_" in measure:
209
329
  measure_name = "transition_count"
210
330
  if "person_time_" in measure:
@@ -212,11 +332,14 @@ class DiseaseObserver(PublicHealthObserver):
212
332
  return pd.Series(measure_name, index=results.index)
213
333
 
214
334
  def get_entity_type_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
335
+ """Get the 'entity_type' column values."""
215
336
  return pd.Series(self.entity_type, index=results.index)
216
337
 
217
338
  def get_entity_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
339
+ """Get the 'entity' column values."""
218
340
  return pd.Series(self.entity, index=results.index)
219
341
 
220
342
  def get_sub_entity_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
343
+ """Get the 'sub_entity' column values."""
221
344
  # The sub-entity col was created in the 'format' method
222
345
  return results[COLUMNS.SUB_ENTITY]
@@ -1,7 +1,7 @@
1
1
  """
2
- ==================
3
- Mortality Observer
4
- ==================
2
+ ===================
3
+ Mortality Observers
4
+ ===================
5
5
 
6
6
  This module contains tools for observing cause-specific and
7
7
  excess mortality in the simulation, including "other causes".
@@ -47,6 +47,18 @@ class MortalityObserver(PublicHealthObserver):
47
47
  This observer needs to access the has_excess_mortality attribute of the causes
48
48
  we're observing, but this attribute gets defined in the setup of the cause models.
49
49
  As a result, the model specification should list this observer after causes.
50
+
51
+ Attributes
52
+ ----------
53
+ required_death_columns
54
+ Columns required by the deaths observation.
55
+ required_yll_columns
56
+ Columns required by the ylls observation.
57
+ clock
58
+ The simulation clock.
59
+ causes_of_death
60
+ Causes of death to be observed.
61
+
50
62
  """
51
63
 
52
64
  def __init__(self) -> None:
@@ -65,12 +77,12 @@ class MortalityObserver(PublicHealthObserver):
65
77
 
66
78
  @property
67
79
  def mortality_classes(self) -> list[type]:
80
+ """The classes to be considered as causes of death."""
68
81
  return [DiseaseState, RiskAttributableDisease]
69
82
 
70
83
  @property
71
84
  def configuration_defaults(self) -> Dict[str, Any]:
72
- """
73
- A dictionary containing the defaults for any configurations managed by
85
+ """A dictionary containing the defaults for any configurations managed by
74
86
  this component.
75
87
  """
76
88
  config_defaults = super().configuration_defaults
@@ -79,6 +91,7 @@ class MortalityObserver(PublicHealthObserver):
79
91
 
80
92
  @property
81
93
  def columns_required(self) -> List[str]:
94
+ """Columns required by this observer."""
82
95
  return [
83
96
  "alive",
84
97
  "years_of_life_lost",
@@ -91,10 +104,22 @@ class MortalityObserver(PublicHealthObserver):
91
104
  #################
92
105
 
93
106
  def setup(self, builder: Builder) -> None:
107
+ """Set up the observer."""
94
108
  self.clock = builder.time.clock()
95
109
  self.set_causes_of_death(builder)
96
110
 
97
111
  def set_causes_of_death(self, builder: Builder) -> None:
112
+ """Set the causes of death to be observed.
113
+
114
+ The causes to be observed are any registered components of class types
115
+ found in the ``mortality_classes`` property.
116
+
117
+ Notes
118
+ -----
119
+ We do not actually exclude any categories in this method.
120
+
121
+ Also note that we add 'not_dead' and 'other_causes' categories here.
122
+ """
98
123
  causes_of_death = [
99
124
  cause
100
125
  for cause in builder.components.get_components_by_type(
@@ -105,16 +130,42 @@ class MortalityObserver(PublicHealthObserver):
105
130
 
106
131
  # Convert to SimpleCauses and add on other_causes and not_dead
107
132
  self.causes_of_death = [
108
- SimpleCause.create_from_disease_state(cause) for cause in causes_of_death
133
+ SimpleCause.create_from_specific_cause(cause) for cause in causes_of_death
109
134
  ] + [
110
135
  SimpleCause("not_dead", "not_dead", "cause"),
111
136
  SimpleCause("other_causes", "other_causes", "cause"),
112
137
  ]
113
138
 
114
139
  def get_configuration(self, builder: Builder) -> LayeredConfigTree:
140
+ """Get the stratification configuration for this observer.
141
+
142
+ Parameters
143
+ ----------
144
+ builder
145
+ The builder object for the simulation.
146
+
147
+ Returns
148
+ -------
149
+ The stratification configuration for this observer.
150
+ """
115
151
  return builder.configuration.stratification[self.get_configuration_name()]
116
152
 
117
153
  def register_observations(self, builder: Builder) -> None:
154
+ """Register stratifications and observations.
155
+
156
+ Notes
157
+ -----
158
+ Ideally, each observer registers a single observation. This one, however,
159
+ registers two.
160
+
161
+ While it's typical for all stratification registrations to be encapsulated
162
+ in a single class (i.e. the
163
+ :class:`ResultsStratifier <vivarium_public_health.results.stratification.ResultsStratifier>`),
164
+ this observer potentially registers an additional one. While it could
165
+ be registered in the ``ResultsStratifier`` as well, it is specific to
166
+ this observer and so it is registered here while we have easy access
167
+ to the required categories.
168
+ """
118
169
  pop_filter = 'alive == "dead" and tracked == True'
119
170
  additional_stratifications = self.configuration.include
120
171
  if not self.configuration.aggregate:
@@ -155,10 +206,12 @@ class MortalityObserver(PublicHealthObserver):
155
206
  ###############
156
207
 
157
208
  def count_deaths(self, x: pd.DataFrame) -> float:
209
+ """Count the number of deaths that occurred during this time step."""
158
210
  died_of_cause = x["exit_time"] > self.clock()
159
211
  return sum(died_of_cause)
160
212
 
161
213
  def calculate_ylls(self, x: pd.DataFrame) -> float:
214
+ """Calculate the years of life lost during this time step."""
162
215
  died_of_cause = x["exit_time"] > self.clock()
163
216
  return x.loc[died_of_cause, "years_of_life_lost"].sum()
164
217
 
@@ -167,6 +220,26 @@ class MortalityObserver(PublicHealthObserver):
167
220
  ##############################
168
221
 
169
222
  def format(self, measure: str, results: pd.DataFrame) -> pd.DataFrame:
223
+ """Rename the appropriate column to 'entity'.
224
+
225
+ The primary thing this method does is rename the 'cause_of_death' column
226
+ to 'entity' (or, if we are aggregating and there is no 'cause_of_death'
227
+ column, we simply create a new 'entity' column). We do this here instead
228
+ of the 'get_entity_column' method simply because we do not want the
229
+ 'cause_of_death' column at all. If we keep it here and then return it as the
230
+ entity column later, the final results would have both.
231
+
232
+ Parameters
233
+ ----------
234
+ measure
235
+ The measure.
236
+ results
237
+ The results to format.
238
+
239
+ Returns
240
+ -------
241
+ The formatted results.
242
+ """
170
243
  results = results.reset_index()
171
244
  if self.configuration.aggregate:
172
245
  results[COLUMNS.ENTITY] = "all_causes"
@@ -175,12 +248,15 @@ class MortalityObserver(PublicHealthObserver):
175
248
  return results
176
249
 
177
250
  def get_entity_type_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
251
+ """Get the 'entity_type' column values."""
178
252
  entity_type_map = {cause.state_id: cause.cause_type for cause in self.causes_of_death}
179
253
  return results[COLUMNS.ENTITY].map(entity_type_map).astype(CategoricalDtype())
180
254
 
181
255
  def get_entity_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
256
+ """Get the 'entity' column values."""
182
257
  # The entity col was created in the 'format' method
183
258
  return results[COLUMNS.ENTITY]
184
259
 
185
260
  def get_sub_entity_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
261
+ """Get the 'sub_entity' column values."""
186
262
  return results[COLUMNS.ENTITY]