vivarium-public-health 3.0.3__py3-none-any.whl → 3.0.4__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (33) hide show
  1. vivarium_public_health/_version.py +1 -1
  2. vivarium_public_health/disease/state.py +24 -19
  3. vivarium_public_health/mslt/delay.py +13 -5
  4. vivarium_public_health/mslt/disease.py +35 -14
  5. vivarium_public_health/mslt/intervention.py +12 -9
  6. vivarium_public_health/mslt/observer.py +56 -17
  7. vivarium_public_health/mslt/population.py +7 -10
  8. vivarium_public_health/plugins/parser.py +29 -80
  9. vivarium_public_health/population/add_new_birth_cohorts.py +8 -9
  10. vivarium_public_health/population/base_population.py +0 -5
  11. vivarium_public_health/population/data_transformations.py +1 -8
  12. vivarium_public_health/population/mortality.py +3 -3
  13. vivarium_public_health/results/columns.py +1 -1
  14. vivarium_public_health/results/disability.py +85 -11
  15. vivarium_public_health/results/disease.py +125 -2
  16. vivarium_public_health/results/mortality.py +78 -2
  17. vivarium_public_health/results/observer.py +141 -6
  18. vivarium_public_health/results/risk.py +66 -5
  19. vivarium_public_health/results/simple_cause.py +8 -2
  20. vivarium_public_health/results/stratification.py +39 -14
  21. vivarium_public_health/risks/base_risk.py +14 -16
  22. vivarium_public_health/risks/data_transformations.py +3 -1
  23. vivarium_public_health/risks/distributions.py +0 -1
  24. vivarium_public_health/risks/effect.py +31 -29
  25. vivarium_public_health/risks/implementations/low_birth_weight_and_short_gestation.py +51 -30
  26. vivarium_public_health/treatment/scale_up.py +6 -10
  27. vivarium_public_health/treatment/therapeutic_inertia.py +3 -1
  28. {vivarium_public_health-3.0.3.dist-info → vivarium_public_health-3.0.4.dist-info}/METADATA +1 -1
  29. vivarium_public_health-3.0.4.dist-info/RECORD +49 -0
  30. {vivarium_public_health-3.0.3.dist-info → vivarium_public_health-3.0.4.dist-info}/WHEEL +1 -1
  31. vivarium_public_health-3.0.3.dist-info/RECORD +0 -49
  32. {vivarium_public_health-3.0.3.dist-info → vivarium_public_health-3.0.4.dist-info}/LICENSE.txt +0 -0
  33. {vivarium_public_health-3.0.3.dist-info → vivarium_public_health-3.0.4.dist-info}/top_level.txt +0 -0
@@ -7,6 +7,7 @@ Component Configuration Parsers in this module are specialized implementations o
7
7
  :class:`ComponentConfigurationParser <vivarium.framework.components.parser.ComponentConfigurationParser>`
8
8
  that can parse configurations of components specific to the Vivarium Public
9
9
  Health package.
10
+
10
11
  """
11
12
 
12
13
  from importlib import import_module
@@ -34,34 +35,32 @@ from vivarium_public_health.utilities import TargetString
34
35
 
35
36
 
36
37
  class CausesParsingErrors(ParsingError):
37
- """
38
- Error raised when there are any errors parsing a cause model configuration.
39
- """
38
+ """Error raised when there are any errors parsing a cause model configuration."""
40
39
 
41
40
  def __init__(self, messages: List[str]):
42
41
  super().__init__("\n - " + "\n - ".join(messages))
43
42
 
44
43
 
45
44
  class CausesConfigurationParser(ComponentConfigurationParser):
46
- """
45
+ """Parser for cause model configurations.
46
+
47
47
  Component configuration parser that acts the same as the standard vivarium
48
48
  `ComponentConfigurationParser` but adds the additional ability to parse a
49
49
  configuration to create `DiseaseModel` components. These DiseaseModel
50
50
  configurations can either be specified directly in the configuration in a
51
51
  `causes` key or in external configuration files that are specified in the
52
52
  `external_configuration` key.
53
+
53
54
  """
54
55
 
55
56
  DEFAULT_MODEL_CONFIG = {
56
57
  "model_type": f"{DiseaseModel.__module__}.{DiseaseModel.__name__}",
57
58
  "initial_state": None,
58
59
  }
59
- """
60
- If a cause model configuration does not specify a model type or initial
61
- state, these default values will be used. The default model type is
62
- `DiseaseModel` and the
63
- default initial state is `None`. If the initial state is not specified,
64
- the cause model must have a state named 'susceptible'.
60
+ """Default cause model configuration if it's not explicitly specified.
61
+
62
+ If the initial state is not specified, the cause model must have a state
63
+ named 'susceptible'.
65
64
  """
66
65
 
67
66
  DEFAULT_STATE_CONFIG = {
@@ -72,23 +71,16 @@ class CausesConfigurationParser(ComponentConfigurationParser):
72
71
  "cleanup_function": None,
73
72
  "state_type": None,
74
73
  }
75
- """
76
- If a state configuration does not specify cause_type, transient,
77
- allow_self_transition, side_effect, cleanup_function, or state_type,
78
- these default values will be used. The default cause type is 'cause', the
79
- default transient value is False, and the default allow_self_transition
80
- value is True.
81
- """
74
+ """Default state configuration if it's not explicitly specified."""
82
75
 
83
76
  DEFAULT_TRANSITION_CONFIG = {"triggered": "NOT_TRIGGERED"}
84
- """
85
- If a transition configuration does not specify a triggered value, this
86
- default value will be used. The default triggered value is 'NOT_TRIGGERED'.
77
+ """Default triggered value.
78
+
79
+ This value is used if the transition configuration does not explicitly specify it.
87
80
  """
88
81
 
89
82
  def parse_component_config(self, component_config: LayeredConfigTree) -> List[Component]:
90
- """
91
- Parses the component configuration and returns a list of components.
83
+ """Parses the component configuration and returns a list of components.
92
84
 
93
85
  In particular, this method looks for an `external_configuration` key
94
86
  and/or a `causes` key.
@@ -143,7 +135,6 @@ class CausesConfigurationParser(ComponentConfigurationParser):
143
135
 
144
136
  Returns
145
137
  -------
146
- List
147
138
  A list of initialized components.
148
139
 
149
140
  Raises
@@ -188,18 +179,14 @@ class CausesConfigurationParser(ComponentConfigurationParser):
188
179
  #########################
189
180
 
190
181
  def _add_default_config_layer(self, causes_config: LayeredConfigTree) -> None:
191
- """
192
- Adds a default layer to the provided configuration that specifies
193
- default values for the cause model configuration.
182
+ """Adds a default layer to the provided configuration.
183
+
184
+ This default layer specifies values for the cause model configuration.
194
185
 
195
186
  Parameters
196
187
  ----------
197
188
  causes_config
198
189
  A LayeredConfigTree defining the cause model configurations
199
-
200
- Returns
201
- -------
202
- None
203
190
  """
204
191
  default_config = {}
205
192
  for cause_name, cause_config in causes_config.items():
@@ -228,9 +215,7 @@ class CausesConfigurationParser(ComponentConfigurationParser):
228
215
  def _get_cause_model_components(
229
216
  self, causes_config: LayeredConfigTree
230
217
  ) -> List[Component]:
231
- """
232
- Parses the cause model configuration and returns a list of
233
- `DiseaseModel` components.
218
+ """Parses the cause model configuration and returns the `DiseaseModel` components.
234
219
 
235
220
  Parameters
236
221
  ----------
@@ -239,7 +224,6 @@ class CausesConfigurationParser(ComponentConfigurationParser):
239
224
 
240
225
  Returns
241
226
  -------
242
- List[Component]
243
227
  A list of initialized `DiseaseModel` components
244
228
  """
245
229
  cause_models = []
@@ -277,9 +261,7 @@ class CausesConfigurationParser(ComponentConfigurationParser):
277
261
  def _get_state(
278
262
  self, state_name: str, state_config: LayeredConfigTree, cause_name: str
279
263
  ) -> BaseDiseaseState:
280
- """
281
- Parses a state configuration and returns an initialized `BaseDiseaseState`
282
- object.
264
+ """Parses a state configuration and returns an initialized `BaseDiseaseState` object.
283
265
 
284
266
  Parameters
285
267
  ----------
@@ -292,7 +274,6 @@ class CausesConfigurationParser(ComponentConfigurationParser):
292
274
 
293
275
  Returns
294
276
  -------
295
- BaseDiseaseState
296
277
  An initialized `BaseDiseaseState` object
297
278
  """
298
279
  state_id = cause_name if state_name in ["susceptible", "recovered"] else state_name
@@ -330,8 +311,7 @@ class CausesConfigurationParser(ComponentConfigurationParser):
330
311
  sink_state: BaseDiseaseState,
331
312
  transition_config: LayeredConfigTree,
332
313
  ) -> None:
333
- """
334
- Adds a transition between two states.
314
+ """Adds a transition between two states.
335
315
 
336
316
  Parameters
337
317
  ----------
@@ -341,10 +321,6 @@ class CausesConfigurationParser(ComponentConfigurationParser):
341
321
  The state the transition ends at
342
322
  transition_config
343
323
  A `LayeredConfigTree` defining the transition to add
344
-
345
- Returns
346
- -------
347
- None
348
324
  """
349
325
  triggered = Trigger[transition_config.triggered]
350
326
  if "data_sources" in transition_config:
@@ -372,9 +348,7 @@ class CausesConfigurationParser(ComponentConfigurationParser):
372
348
  def _get_data_sources(
373
349
  self, config: LayeredConfigTree
374
350
  ) -> Dict[str, Callable[[Builder, Any], Any]]:
375
- """
376
- Parses a data sources configuration and returns a dictionary of data
377
- sources.
351
+ """Parses a data sources configuration and returns the data sources.
378
352
 
379
353
  Parameters
380
354
  ----------
@@ -383,7 +357,6 @@ class CausesConfigurationParser(ComponentConfigurationParser):
383
357
 
384
358
  Returns
385
359
  -------
386
- Dict[str, Callable[[Builder, Any], Any]]
387
360
  A dictionary of data source getters
388
361
  """
389
362
  return {name: self._get_data_source(name, config[name]) for name in config.keys()}
@@ -392,9 +365,7 @@ class CausesConfigurationParser(ComponentConfigurationParser):
392
365
  def _get_data_source(
393
366
  name: str, source: Union[str, float]
394
367
  ) -> Callable[[Builder, Any], Any]:
395
- """
396
- Parses a data source and returns a callable that can be used to retrieve
397
- the data.
368
+ """Parses a data source and returns a callable that can be used to retrieve the data.
398
369
 
399
370
  Parameters
400
371
  ----------
@@ -405,7 +376,6 @@ class CausesConfigurationParser(ComponentConfigurationParser):
405
376
 
406
377
  Returns
407
378
  -------
408
- Callable[[Builder, Any], Any]
409
379
  A callable that can be used to retrieve the data
410
380
  """
411
381
  if isinstance(source, float):
@@ -465,18 +435,13 @@ class CausesConfigurationParser(ComponentConfigurationParser):
465
435
 
466
436
  @staticmethod
467
437
  def _validate_external_configuration(external_configuration: LayeredConfigTree) -> None:
468
- """
469
- Validates the external configuration.
438
+ """Validates the external configuration.
470
439
 
471
440
  Parameters
472
441
  ----------
473
442
  external_configuration
474
443
  A LayeredConfigTree defining the external configuration
475
444
 
476
- Returns
477
- -------
478
- None
479
-
480
445
  Raises
481
446
  ------
482
447
  CausesParsingErrors
@@ -504,18 +469,13 @@ class CausesConfigurationParser(ComponentConfigurationParser):
504
469
  raise CausesParsingErrors(error_messages)
505
470
 
506
471
  def _validate_causes_config(self, causes_config: LayeredConfigTree) -> None:
507
- """
508
- Validates the cause model configuration.
472
+ """Validates the cause model configuration.
509
473
 
510
474
  Parameters
511
475
  ----------
512
476
  causes_config
513
477
  A LayeredConfigTree defining the cause model configurations
514
478
 
515
- Returns
516
- -------
517
- None
518
-
519
479
  Raises
520
480
  ------
521
481
  CausesParsingErrors
@@ -530,8 +490,7 @@ class CausesConfigurationParser(ComponentConfigurationParser):
530
490
  raise CausesParsingErrors(error_messages)
531
491
 
532
492
  def _validate_cause(self, cause_name: str, cause_config: Dict[str, Any]) -> List[str]:
533
- """
534
- Validates a cause configuration and returns a list of error messages.
493
+ """Validates a cause configuration and returns a list of error messages.
535
494
 
536
495
  Parameters
537
496
  ----------
@@ -542,7 +501,6 @@ class CausesConfigurationParser(ComponentConfigurationParser):
542
501
 
543
502
  Returns
544
503
  -------
545
- List[str]
546
504
  A list of error messages
547
505
  """
548
506
  error_messages = []
@@ -607,8 +565,7 @@ class CausesConfigurationParser(ComponentConfigurationParser):
607
565
  def _validate_state(
608
566
  self, cause_name: str, state_name: str, state_config: Dict[str, Any]
609
567
  ) -> List[str]:
610
- """
611
- Validates a state configuration and returns a list of error messages.
568
+ """Validates a state configuration and returns a list of error messages.
612
569
 
613
570
  Parameters
614
571
  ----------
@@ -621,7 +578,6 @@ class CausesConfigurationParser(ComponentConfigurationParser):
621
578
 
622
579
  Returns
623
580
  -------
624
- List[str]
625
581
  A list of error messages
626
582
  """
627
583
  error_messages = []
@@ -702,8 +658,7 @@ class CausesConfigurationParser(ComponentConfigurationParser):
702
658
  transition_config: Dict[str, Any],
703
659
  states_config: Dict[str, Any],
704
660
  ) -> List[str]:
705
- """
706
- Validates a transition configuration and returns a list of error messages.
661
+ """Validates a transition configuration and returns a list of error messages.
707
662
 
708
663
  Parameters
709
664
  ----------
@@ -718,7 +673,6 @@ class CausesConfigurationParser(ComponentConfigurationParser):
718
673
 
719
674
  Returns
720
675
  -------
721
- List[str]
722
676
  A list of error messages
723
677
  """
724
678
  error_messages = []
@@ -803,8 +757,7 @@ class CausesConfigurationParser(ComponentConfigurationParser):
803
757
  def _validate_imported_type(
804
758
  import_path: str, cause_name: str, entity_type: str, entity_name: Optional[str] = None
805
759
  ) -> List[str]:
806
- """
807
- Validates an imported type and returns a list of error messages.
760
+ """Validates an imported type and returns a list of error messages.
808
761
 
809
762
  Parameters
810
763
  ----------
@@ -820,7 +773,6 @@ class CausesConfigurationParser(ComponentConfigurationParser):
820
773
 
821
774
  Returns
822
775
  -------
823
- List[str]
824
776
  A list of error messages
825
777
  """
826
778
  expected_type = {"model": DiseaseModel, "state": BaseDiseaseState}[entity_type]
@@ -847,9 +799,7 @@ class CausesConfigurationParser(ComponentConfigurationParser):
847
799
  def _validate_data_sources(
848
800
  self, config: Dict[str, Any], cause_name: str, config_type: str, config_name: str
849
801
  ) -> List[str]:
850
- """
851
- Validates the data sources in a configuration and returns a list of
852
- error messages.
802
+ """Validates the data sources in a configuration and returns any error messages.
853
803
 
854
804
  Parameters
855
805
  ----------
@@ -864,7 +814,6 @@ class CausesConfigurationParser(ComponentConfigurationParser):
864
814
 
865
815
  Returns
866
816
  -------
867
- List[str]
868
817
  A list of error messages
869
818
  """
870
819
  error_messages = []
@@ -81,8 +81,7 @@ class FertilityCrudeBirthRate(Component):
81
81
 
82
82
  new_births = sim_pop_size_t0 * live_births / true_pop_size * step_size
83
83
 
84
- Where
85
-
84
+ Where:
86
85
  sim_pop_size_t0 = the initial simulation population size
87
86
  live_births = annual number of live births in the true population
88
87
  true_pop_size = the true population size
@@ -126,6 +125,7 @@ class FertilityCrudeBirthRate(Component):
126
125
  def on_time_step(self, event: Event) -> None:
127
126
  """Adds new simulants every time step based on the Crude Birth Rate
128
127
  and an assumption that birth is a Poisson process
128
+
129
129
  Parameters
130
130
  ----------
131
131
  event
@@ -151,9 +151,7 @@ class FertilityCrudeBirthRate(Component):
151
151
 
152
152
 
153
153
  class FertilityAgeSpecificRates(Component):
154
- """
155
- A simulant-specific model for fertility and pregnancies.
156
- """
154
+ """A simulant-specific model for fertility and pregnancies."""
157
155
 
158
156
  ##############
159
157
  # Properties #
@@ -180,11 +178,11 @@ class FertilityAgeSpecificRates(Component):
180
178
  #####################
181
179
 
182
180
  def setup(self, builder: Builder) -> None:
183
- """Setup the common randomness stream and
184
- age-specific fertility lookup tables.
181
+ """Setup the common randomness stream and age-specific fertility lookup tables.
182
+
185
183
  Parameters
186
184
  ----------
187
- builder : vivarium.engine.Builder
185
+ builder
188
186
  Framework coordination object.
189
187
  """
190
188
  age_specific_fertility_rate = self.load_age_specific_fertility_rate_data(builder)
@@ -238,9 +236,10 @@ class FertilityAgeSpecificRates(Component):
238
236
 
239
237
  def on_time_step(self, event: Event) -> None:
240
238
  """Produces new children and updates parent status on time steps.
239
+
241
240
  Parameters
242
241
  ----------
243
- event : vivarium.population.PopulationEvent
242
+ event
244
243
  The event that triggered the function call.
245
244
  """
246
245
  # Get a view on all living women who haven't had a child in at least nine months.
@@ -126,7 +126,6 @@ class BasePopulation(Component):
126
126
  respectively. Here we are agnostic to the methods of entrance and exit (e.g., birth,
127
127
  migration, death, etc.) as these characteristics can be inferred from this column and
128
128
  other information about the simulant and the simulation parameters.
129
-
130
129
  """
131
130
 
132
131
  age_params = {
@@ -282,7 +281,6 @@ def generate_population(
282
281
 
283
282
  Returns
284
283
  -------
285
- pandas.DataFrame
286
284
  Table with columns
287
285
  'entrance_time'
288
286
  The `pandas.Timestamp` describing when the simulant entered
@@ -299,7 +297,6 @@ def generate_population(
299
297
  The location indicating where the simulant resides.
300
298
  'sex'
301
299
  Either 'Male' or 'Female'. The sex of the simulant.
302
-
303
300
  """
304
301
  simulants = pd.DataFrame(
305
302
  {
@@ -361,7 +358,6 @@ def _assign_demography_with_initial_age(
361
358
 
362
359
  Returns
363
360
  -------
364
- pandas.DataFrame
365
361
  Table with same columns as `simulants` and with the additional
366
362
  columns 'age', 'sex', and 'location'.
367
363
  """
@@ -426,7 +422,6 @@ def _assign_demography_with_age_bounds(
426
422
 
427
423
  Returns
428
424
  -------
429
- pandas.DataFrame
430
425
  Table with same columns as `simulants` and with the additional columns
431
426
  'age', 'sex', and 'location'.
432
427
 
@@ -33,7 +33,6 @@ def assign_demographic_proportions(
33
33
 
34
34
  Returns
35
35
  -------
36
- pandas.DataFrame
37
36
  Table with columns
38
37
  'age' : Midpoint of the age group,
39
38
  'age_start' : Lower bound of the age group,
@@ -101,7 +100,6 @@ def rescale_binned_proportions(
101
100
 
102
101
  Returns
103
102
  -------
104
- pandas.DataFrame
105
103
  Table with the same columns as `pop_data` where all bins outside the range
106
104
  (age_start, age_end) have been discarded. If age_start and age_end
107
105
  don't fall cleanly on age boundaries, the bins in which they lie are clipped and
@@ -174,7 +172,6 @@ def rescale_binned_proportions(
174
172
  def _add_edge_age_groups(pop_data: pd.DataFrame) -> pd.DataFrame:
175
173
  """Pads the population data with age groups that enforce constant
176
174
  left interpolation and interpolation to zero on the right.
177
-
178
175
  """
179
176
  index_cols = ["location", "year_start", "year_end", "sex"]
180
177
  age_cols = ["age", "age_start", "age_end"]
@@ -250,7 +247,6 @@ def smooth_ages(
250
247
 
251
248
  Returns
252
249
  -------
253
- pandas.DataFrame
254
250
  Table with same columns as `simulants` with ages smoothed out within the age bins.
255
251
  """
256
252
  simulants = simulants.copy()
@@ -324,7 +320,7 @@ def _get_bins_and_proportions(
324
320
 
325
321
  Returns
326
322
  -------
327
- Tuple[EndpointValues, AgeValues]
323
+ A tuple consisting of an `EndpointValues` tuple and an `AgeValues` tuple.
328
324
  The `EndpointValues` tuple has values (
329
325
  age at left edge of bin,
330
326
  age at right edge of bin,
@@ -334,7 +330,6 @@ def _get_bins_and_proportions(
334
330
  proportion of pop in previous bin,
335
331
  proportion of pop in next bin,
336
332
  )
337
-
338
333
  """
339
334
  left = float(pop_data.loc[pop_data["age"] == age.current, "age_start"].iloc[0])
340
335
  right = float(pop_data.loc[pop_data["age"] == age.current, "age_end"].iloc[0])
@@ -405,7 +400,6 @@ def _construct_sampling_parameters(
405
400
 
406
401
  Returns
407
402
  -------
408
- Tuple[EndpointValues, EndpointValues, float, float]
409
403
  A tuple of (pdf, slope, area, cdf_inflection_point) where
410
404
  pdf is a tuple with values (
411
405
  pdf evaluated at left bin edge,
@@ -474,7 +468,6 @@ def _compute_ages(
474
468
 
475
469
  Returns
476
470
  -------
477
- Union[np.ndarray, float]
478
471
  Smoothed ages from one half of the age bin distribution.
479
472
  """
480
473
  if abs(slope) < np.finfo(np.float32).eps:
@@ -59,9 +59,9 @@ from vivarium_public_health.utilities import get_lookup_columns
59
59
 
60
60
 
61
61
  class Mortality(Component):
62
- """
63
- This is the mortality component which models sources of mortality for a model.
64
- THe component models all cause mortality and allows for disease models to contribute
62
+ """This is the mortality component which models sources of mortality in a population.
63
+
64
+ The component models all cause mortality and allows for disease models to contribute
65
65
  cause specific mortality. Data used by this class should be supplied in the artifact
66
66
  and is configurable in the configuration to build lookup tables. For instance, let's
67
67
  say we want to use sex and hair color to build a lookup table for all cause mortality.
@@ -4,7 +4,7 @@ from vivarium.framework.results import VALUE_COLUMN
4
4
 
5
5
 
6
6
  class __Columns(NamedTuple):
7
- """column names"""
7
+ """Container class for column names used in results dataframes."""
8
8
 
9
9
  VALUE: str = VALUE_COLUMN
10
10
  MEASURE: str = "measure"
@@ -8,7 +8,7 @@ in the simulation.
8
8
 
9
9
  """
10
10
 
11
- from typing import Any, List, Union
11
+ from typing import Union
12
12
 
13
13
  import pandas as pd
14
14
  from layered_config_tree import LayeredConfigTree
@@ -42,6 +42,17 @@ class DisabilityObserver(PublicHealthObserver):
42
42
  - "sex"
43
43
  include:
44
44
  - "sample_stratification"
45
+ Attributes
46
+ ----------
47
+ disability_weight_pipeline_name
48
+ The name of the pipeline that produces disability weights.
49
+ step_size
50
+ The time step size of the simulation.
51
+ disability_weight
52
+ The pipeline that produces disability weights.
53
+ causes_of_disability
54
+ The causes of disability to be observed.
55
+
45
56
  """
46
57
 
47
58
  ##############
@@ -50,7 +61,7 @@ class DisabilityObserver(PublicHealthObserver):
50
61
 
51
62
  @property
52
63
  def disability_classes(self) -> list[type]:
53
- """The classes to be considered for causes of disability."""
64
+ """The classes to be considered as causes of disability."""
54
65
  return [DiseaseState, RiskAttributableDisease]
55
66
 
56
67
  #####################
@@ -66,16 +77,25 @@ class DisabilityObserver(PublicHealthObserver):
66
77
  #################
67
78
 
68
79
  def setup(self, builder: Builder) -> None:
80
+ """Set up the observer."""
69
81
  self.step_size = pd.Timedelta(days=builder.configuration.time.step_size)
70
82
  self.disability_weight = self.get_disability_weight_pipeline(builder)
71
83
  self.set_causes_of_disability(builder)
72
84
 
73
85
  def set_causes_of_disability(self, builder: Builder) -> None:
74
- """Set the causes of disability to be observed by removing any excluded
75
- via the model spec from the list of all disability class causes. We implement
76
- exclusions here because disabilities are unique in that they are not
77
- registered stratifications and so cannot be excluded during the stratification
78
- call like other categories.
86
+ """Set the causes of disability to be observed.
87
+
88
+ The causes to be observed are any registered components of class types
89
+ found in the ``disability_classes`` property *excluding* any listed in
90
+ the model spec as ``excluded_categories``.
91
+
92
+ Notes
93
+ -----
94
+ We implement exclusions here instead of during the stratification call
95
+ like most other categories because disabilities are unique in that they are
96
+ *not* actually registered stratifications.
97
+
98
+ Also note that we add an 'all_causes' category here.
79
99
  """
80
100
  causes_of_disability = builder.components.get_components_by_type(
81
101
  self.disability_classes
@@ -111,9 +131,21 @@ class DisabilityObserver(PublicHealthObserver):
111
131
  ]
112
132
 
113
133
  def get_configuration(self, builder: Builder) -> LayeredConfigTree:
134
+ """Get the stratification configuration for this observer.
135
+
136
+ Parameters
137
+ ----------
138
+ builder
139
+ The builder object for the simulation.
140
+
141
+ Returns
142
+ -------
143
+ The stratification configuration for this observer.
144
+ """
114
145
  return builder.configuration.stratification.disability
115
146
 
116
147
  def register_observations(self, builder: Builder) -> None:
148
+ """Register an observation for years lived with disability."""
117
149
  cause_pipelines = [
118
150
  f"{cause.state_id}.disability_weight" for cause in self.causes_of_disability
119
151
  ]
@@ -131,6 +163,17 @@ class DisabilityObserver(PublicHealthObserver):
131
163
  )
132
164
 
133
165
  def get_disability_weight_pipeline(self, builder: Builder) -> Pipeline:
166
+ """Register (and return) the pipeline that produces disability weights.
167
+
168
+ Parameters
169
+ ----------
170
+ builder
171
+ The builder object for the simulation.
172
+
173
+ Returns
174
+ -------
175
+ The pipeline that produces disability weights.
176
+ """
134
177
  return builder.value.register_value_producer(
135
178
  self.disability_weight_pipeline_name,
136
179
  source=lambda index: [pd.Series(0.0, index=index)],
@@ -143,6 +186,17 @@ class DisabilityObserver(PublicHealthObserver):
143
186
  ###############
144
187
 
145
188
  def disability_weight_aggregator(self, dw: pd.DataFrame) -> Union[float, pd.Series]:
189
+ """Aggregate disability weights for the time step.
190
+
191
+ Parameters
192
+ ----------
193
+ dw
194
+ The disability weights to aggregate.
195
+
196
+ Returns
197
+ -------
198
+ The aggregated disability weights.
199
+ """
146
200
  aggregated_dw = (dw * to_years(self.step_size)).sum().squeeze()
147
201
  if isinstance(aggregated_dw, pd.Series):
148
202
  aggregated_dw.index.name = "cause_of_disability"
@@ -153,10 +207,27 @@ class DisabilityObserver(PublicHealthObserver):
153
207
  ##############################
154
208
 
155
209
  def format(self, measure: str, results: pd.DataFrame) -> pd.DataFrame:
156
- """Format results. Note that ylds are unique in that we
157
- can't stratify by cause of disability (because there can be multiple at
158
- once), and so the results here are actually wide by disability weight
159
- pipeline name.
210
+ """Format wide YLD results to match typical/long stratified results.
211
+
212
+ YLDs are unique in that we can't stratify by cause of disability (because
213
+ there can be multiple at once), and so the results here are actually wide
214
+ by disability weight pipeline name. This method formats the results to be
215
+ long by cause of disability.
216
+
217
+ Parameters
218
+ ----------
219
+ measure
220
+ The measure.
221
+ results
222
+ The wide results to format.
223
+
224
+ Returns
225
+ -------
226
+ The results stacked by causes of disability.
227
+
228
+ Notes
229
+ -----
230
+ This method also adds the 'sub_entity' column to the results.
160
231
  """
161
232
  if len(self.causes_of_disability) > 1:
162
233
  # Drop the unused 'value' column and rename the remaining pipeline names to cause names
@@ -180,15 +251,18 @@ class DisabilityObserver(PublicHealthObserver):
180
251
  return results
181
252
 
182
253
  def get_entity_type_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
254
+ """Get the 'entity_type' column values."""
183
255
  entity_type_map = {
184
256
  cause.state_id: cause.cause_type for cause in self.causes_of_disability
185
257
  }
186
258
  return results[COLUMNS.SUB_ENTITY].map(entity_type_map).astype(CategoricalDtype())
187
259
 
188
260
  def get_entity_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
261
+ """Get the 'entity' column values."""
189
262
  entity_map = {cause.state_id: cause.model for cause in self.causes_of_disability}
190
263
  return results[COLUMNS.SUB_ENTITY].map(entity_map).astype(CategoricalDtype())
191
264
 
192
265
  def get_sub_entity_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
266
+ """Get the 'sub_entity' column values."""
193
267
  # The sub-entity col was created in the 'format' method
194
268
  return results[COLUMNS.SUB_ENTITY]