vivarium-public-health 3.1.0__py3-none-any.whl → 3.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34) hide show
  1. vivarium_public_health/_version.py +1 -1
  2. vivarium_public_health/disease/model.py +40 -28
  3. vivarium_public_health/disease/special_disease.py +7 -8
  4. vivarium_public_health/disease/state.py +23 -22
  5. vivarium_public_health/disease/transition.py +11 -10
  6. vivarium_public_health/mslt/delay.py +5 -5
  7. vivarium_public_health/mslt/disease.py +5 -5
  8. vivarium_public_health/mslt/intervention.py +8 -8
  9. vivarium_public_health/mslt/magic_wand_components.py +3 -3
  10. vivarium_public_health/mslt/observer.py +4 -6
  11. vivarium_public_health/mslt/population.py +4 -6
  12. vivarium_public_health/plugins/parser.py +57 -30
  13. vivarium_public_health/population/add_new_birth_cohorts.py +3 -5
  14. vivarium_public_health/population/base_population.py +8 -10
  15. vivarium_public_health/population/data_transformations.py +4 -5
  16. vivarium_public_health/population/mortality.py +7 -7
  17. vivarium_public_health/results/disability.py +1 -3
  18. vivarium_public_health/results/disease.py +5 -5
  19. vivarium_public_health/results/mortality.py +3 -3
  20. vivarium_public_health/results/observer.py +7 -7
  21. vivarium_public_health/results/risk.py +3 -3
  22. vivarium_public_health/risks/base_risk.py +4 -4
  23. vivarium_public_health/risks/distributions.py +15 -15
  24. vivarium_public_health/risks/effect.py +10 -10
  25. vivarium_public_health/risks/implementations/low_birth_weight_and_short_gestation.py +17 -16
  26. vivarium_public_health/treatment/magic_wand.py +3 -3
  27. vivarium_public_health/treatment/scale_up.py +6 -5
  28. vivarium_public_health/utilities.py +4 -5
  29. {vivarium_public_health-3.1.0.dist-info → vivarium_public_health-3.1.2.dist-info}/METADATA +2 -2
  30. vivarium_public_health-3.1.2.dist-info/RECORD +49 -0
  31. {vivarium_public_health-3.1.0.dist-info → vivarium_public_health-3.1.2.dist-info}/WHEEL +1 -1
  32. vivarium_public_health-3.1.0.dist-info/RECORD +0 -49
  33. {vivarium_public_health-3.1.0.dist-info → vivarium_public_health-3.1.2.dist-info}/LICENSE.txt +0 -0
  34. {vivarium_public_health-3.1.0.dist-info → vivarium_public_health-3.1.2.dist-info}/top_level.txt +0 -0
@@ -8,8 +8,6 @@ multi-state lifetable simulations.
8
8
 
9
9
  """
10
10
 
11
- from typing import List, Optional
12
-
13
11
  import numpy as np
14
12
  import pandas as pd
15
13
  from vivarium import Component
@@ -50,7 +48,7 @@ class BasePopulation(Component):
50
48
  ##############
51
49
 
52
50
  @property
53
- def columns_created(self) -> List[str]:
51
+ def columns_created(self) -> list[str]:
54
52
  return [
55
53
  "age",
56
54
  "sex",
@@ -71,7 +69,7 @@ class BasePopulation(Component):
71
69
  ]
72
70
 
73
71
  @property
74
- def columns_required(self) -> Optional[List[str]]:
72
+ def columns_required(self) -> list[str] | None:
75
73
  return ["tracked"]
76
74
 
77
75
  #####################
@@ -122,7 +120,7 @@ class Mortality(Component):
122
120
  ##############
123
121
 
124
122
  @property
125
- def columns_required(self) -> Optional[List[str]]:
123
+ def columns_required(self) -> list[str] | None:
126
124
  return [
127
125
  "population",
128
126
  "bau_population",
@@ -190,7 +188,7 @@ class Disability(Component):
190
188
  ##############
191
189
 
192
190
  @property
193
- def columns_required(self) -> Optional[List[str]]:
191
+ def columns_required(self) -> list[str] | None:
194
192
  return [
195
193
  "bau_yld_rate",
196
194
  "yld_rate",
@@ -9,9 +9,10 @@ that can parse configurations of components specific to the Vivarium Public
9
9
  Health package.
10
10
 
11
11
  """
12
-
12
+ import warnings
13
+ from collections.abc import Callable
13
14
  from importlib import import_module
14
- from typing import Any, Callable, Dict, List, Optional, Union
15
+ from typing import Any
15
16
 
16
17
  import pandas as pd
17
18
  from layered_config_tree import LayeredConfigTree
@@ -37,7 +38,7 @@ from vivarium_public_health.utilities import TargetString
37
38
  class CausesParsingErrors(ParsingError):
38
39
  """Error raised when there are any errors parsing a cause model configuration."""
39
40
 
40
- def __init__(self, messages: List[str]):
41
+ def __init__(self, messages: list[str]):
41
42
  super().__init__("\n - " + "\n - ".join(messages))
42
43
 
43
44
 
@@ -56,11 +57,13 @@ class CausesConfigurationParser(ComponentConfigurationParser):
56
57
  DEFAULT_MODEL_CONFIG = {
57
58
  "model_type": f"{DiseaseModel.__module__}.{DiseaseModel.__name__}",
58
59
  "initial_state": None,
60
+ "residual_state": None,
59
61
  }
60
62
  """Default cause model configuration if it's not explicitly specified.
61
63
 
62
- If the initial state is not specified, the cause model must have a state
63
- named 'susceptible'.
64
+ Initial state and residual state cannot both be provided. If neither initial
65
+ state nor residual state has been specified, the cause model must have a
66
+ state named 'susceptible'.
64
67
  """
65
68
 
66
69
  DEFAULT_STATE_CONFIG = {
@@ -79,7 +82,7 @@ class CausesConfigurationParser(ComponentConfigurationParser):
79
82
  This value is used if the transition configuration does not explicity specify it.
80
83
  """
81
84
 
82
- def parse_component_config(self, component_config: LayeredConfigTree) -> List[Component]:
85
+ def parse_component_config(self, component_config: LayeredConfigTree) -> list[Component]:
83
86
  """Parses the component configuration and returns a list of components.
84
87
 
85
88
  In particular, this method looks for an `external_configuration` key
@@ -104,7 +107,7 @@ class CausesConfigurationParser(ComponentConfigurationParser):
104
107
  causes:
105
108
  cause_1:
106
109
  model_type: vivarium_public_health.disease.DiseaseModel
107
- initial_state: susceptible
110
+ residual_state: susceptible
108
111
  states:
109
112
  susceptible:
110
113
  cause_type: cause
@@ -214,7 +217,7 @@ class CausesConfigurationParser(ComponentConfigurationParser):
214
217
 
215
218
  def _get_cause_model_components(
216
219
  self, causes_config: LayeredConfigTree
217
- ) -> List[Component]:
220
+ ) -> list[Component]:
218
221
  """Parses the cause model configuration and returns the `DiseaseModel` components.
219
222
 
220
223
  Parameters
@@ -234,7 +237,7 @@ class CausesConfigurationParser(ComponentConfigurationParser):
234
237
  data_sources_config = cause_config.data_sources
235
238
  data_sources = self._get_data_sources(data_sources_config)
236
239
 
237
- states: Dict[str, BaseDiseaseState] = {
240
+ states: dict[str, BaseDiseaseState] = {
238
241
  state_name: self._get_state(state_name, state_config, cause_name)
239
242
  for state_name, state_config in cause_config.states.items()
240
243
  }
@@ -247,11 +250,13 @@ class CausesConfigurationParser(ComponentConfigurationParser):
247
250
  )
248
251
 
249
252
  model_type = import_by_path(cause_config.model_type)
250
- initial_state = states.get(cause_config.initial_state, None)
253
+ residual_state = states.get(
254
+ cause_config.residual_state, states.get(cause_config.initial_state, None)
255
+ )
251
256
  model = model_type(
252
257
  cause_name,
253
- initial_state=initial_state,
254
258
  states=list(states.values()),
259
+ residual_state=residual_state,
255
260
  get_data_functions=data_sources,
256
261
  )
257
262
  cause_models.append(model)
@@ -347,7 +352,7 @@ class CausesConfigurationParser(ComponentConfigurationParser):
347
352
 
348
353
  def _get_data_sources(
349
354
  self, config: LayeredConfigTree
350
- ) -> Dict[str, Callable[[Builder, Any], Any]]:
355
+ ) -> dict[str, Callable[[Builder, Any], Any]]:
351
356
  """Parses a data sources configuration and returns the data sources.
352
357
 
353
358
  Parameters
@@ -362,9 +367,7 @@ class CausesConfigurationParser(ComponentConfigurationParser):
362
367
  return {name: self._get_data_source(name, config[name]) for name in config.keys()}
363
368
 
364
369
  @staticmethod
365
- def _get_data_source(
366
- name: str, source: Union[str, float]
367
- ) -> Callable[[Builder, Any], Any]:
370
+ def _get_data_source(name: str, source: str | float) -> Callable[[Builder, Any], Any]:
368
371
  """Parses a data source and returns a callable that can be used to retrieve the data.
369
372
 
370
373
  Parameters
@@ -403,7 +406,14 @@ class CausesConfigurationParser(ComponentConfigurationParser):
403
406
  # Validation methods #
404
407
  ######################
405
408
 
406
- _CAUSE_KEYS = {"model_type", "initial_state", "states", "transitions", "data_sources"}
409
+ _CAUSE_KEYS = {
410
+ "model_type",
411
+ "initial_state",
412
+ "states",
413
+ "transitions",
414
+ "data_sources",
415
+ "residual_state",
416
+ }
407
417
  _STATE_KEYS = {
408
418
  "state_type",
409
419
  "cause_type",
@@ -489,7 +499,7 @@ class CausesConfigurationParser(ComponentConfigurationParser):
489
499
  if error_messages:
490
500
  raise CausesParsingErrors(error_messages)
491
501
 
492
- def _validate_cause(self, cause_name: str, cause_config: Dict[str, Any]) -> List[str]:
502
+ def _validate_cause(self, cause_name: str, cause_config: dict[str, Any]) -> list[str]:
493
503
  """Validates a cause configuration and returns a list of error messages.
494
504
 
495
505
  Parameters
@@ -535,11 +545,28 @@ class CausesConfigurationParser(ComponentConfigurationParser):
535
545
  )
536
546
  else:
537
547
  initial_state = cause_config.get("initial_state", None)
538
- if initial_state is not None and initial_state not in states_config:
548
+ residual_state = cause_config.get("residual_state", None)
549
+ if initial_state is not None:
550
+ warnings.warn(
551
+ "In the future, the 'initial_state' cause configuration will"
552
+ " be used to initialize all simulants into that state. To"
553
+ " retain the current behavior of defining a residual state,"
554
+ " use the 'residual_state' cause configuration.",
555
+ DeprecationWarning,
556
+ stacklevel=2,
557
+ )
558
+ if residual_state is None:
559
+ residual_state = initial_state
560
+ else:
561
+ error_messages.append(
562
+ "A cause may not have both 'initial_state and"
563
+ " 'residual_state' configurations."
564
+ )
565
+
566
+ if residual_state is not None and residual_state not in states_config:
539
567
  error_messages.append(
540
- f"Initial state '{cause_config['initial_state']}' for cause "
541
- f"'{cause_name}' must be present in the states for cause "
542
- f"'{cause_name}."
568
+ f"Residual state '{residual_state}' for cause '{cause_name}'"
569
+ f" must be present in the states for cause '{cause_name}."
543
570
  )
544
571
  for state_name, state_config in states_config.items():
545
572
  error_messages += self._validate_state(cause_name, state_name, state_config)
@@ -563,8 +590,8 @@ class CausesConfigurationParser(ComponentConfigurationParser):
563
590
  return error_messages
564
591
 
565
592
  def _validate_state(
566
- self, cause_name: str, state_name: str, state_config: Dict[str, Any]
567
- ) -> List[str]:
593
+ self, cause_name: str, state_name: str, state_config: dict[str, Any]
594
+ ) -> list[str]:
568
595
  """Validates a state configuration and returns a list of error messages.
569
596
 
570
597
  Parameters
@@ -655,9 +682,9 @@ class CausesConfigurationParser(ComponentConfigurationParser):
655
682
  self,
656
683
  cause_name: str,
657
684
  transition_name: str,
658
- transition_config: Dict[str, Any],
659
- states_config: Dict[str, Any],
660
- ) -> List[str]:
685
+ transition_config: dict[str, Any],
686
+ states_config: dict[str, Any],
687
+ ) -> list[str]:
661
688
  """Validates a transition configuration and returns a list of error messages.
662
689
 
663
690
  Parameters
@@ -755,8 +782,8 @@ class CausesConfigurationParser(ComponentConfigurationParser):
755
782
 
756
783
  @staticmethod
757
784
  def _validate_imported_type(
758
- import_path: str, cause_name: str, entity_type: str, entity_name: Optional[str] = None
759
- ) -> List[str]:
785
+ import_path: str, cause_name: str, entity_type: str, entity_name: str | None = None
786
+ ) -> list[str]:
760
787
  """Validates an imported type and returns a list of error messages.
761
788
 
762
789
  Parameters
@@ -797,8 +824,8 @@ class CausesConfigurationParser(ComponentConfigurationParser):
797
824
  return error_messages
798
825
 
799
826
  def _validate_data_sources(
800
- self, config: Dict[str, Any], cause_name: str, config_type: str, config_name: str
801
- ) -> List[str]:
827
+ self, config: dict[str, Any], cause_name: str, config_type: str, config_name: str
828
+ ) -> list[str]:
802
829
  """Validates the data sources in a configuration and returns any error messages.
803
830
 
804
831
  Parameters
@@ -7,8 +7,6 @@ This module contains several different models of fertility.
7
7
 
8
8
  """
9
9
 
10
- from typing import Dict, List, Optional
11
-
12
10
  import numpy as np
13
11
  import pandas as pd
14
12
  from vivarium import Component
@@ -158,15 +156,15 @@ class FertilityAgeSpecificRates(Component):
158
156
  ##############
159
157
 
160
158
  @property
161
- def columns_created(self) -> List[str]:
159
+ def columns_created(self) -> list[str]:
162
160
  return ["last_birth_time", "parent_id"]
163
161
 
164
162
  @property
165
- def columns_required(self) -> Optional[List[str]]:
163
+ def columns_required(self) -> list[str] | None:
166
164
  return ["sex"]
167
165
 
168
166
  @property
169
- def initialization_requirements(self) -> Dict[str, List[str]]:
167
+ def initialization_requirements(self) -> dict[str, list[str]]:
170
168
  return {
171
169
  "requires_columns": ["sex"],
172
170
  "requires_values": [],
@@ -8,9 +8,7 @@ characteristics to simulants.
8
8
 
9
9
  """
10
10
 
11
- from __future__ import annotations
12
-
13
- from typing import Callable, Dict, Iterable, List
11
+ from collections.abc import Callable, Iterable
14
12
 
15
13
  import numpy as np
16
14
  import pandas as pd
@@ -48,7 +46,7 @@ class BasePopulation(Component):
48
46
  ##############
49
47
 
50
48
  @property
51
- def columns_created(self) -> List[str]:
49
+ def columns_created(self) -> list[str]:
52
50
  return ["age", "sex", "alive", "location", "entrance_time", "exit_time"]
53
51
 
54
52
  @property
@@ -91,7 +89,7 @@ class BasePopulation(Component):
91
89
  #################
92
90
 
93
91
  @staticmethod
94
- def get_randomness_streams(builder: Builder) -> Dict[str, RandomnessStream]:
92
+ def get_randomness_streams(builder: Builder) -> dict[str, RandomnessStream]:
95
93
  return {
96
94
  "general_purpose": builder.randomness.get_stream("population_generation"),
97
95
  "bin_selection": builder.randomness.get_stream(
@@ -255,7 +253,7 @@ class AgeOutSimulants(Component):
255
253
  ##############
256
254
 
257
255
  @property
258
- def columns_required(self) -> List[str]:
256
+ def columns_required(self) -> list[str]:
259
257
  """A list of the columns this component requires that it did not create."""
260
258
  return self._columns_required
261
259
 
@@ -290,9 +288,9 @@ def generate_population(
290
288
  simulant_ids: pd.Index,
291
289
  creation_time: pd.Timestamp,
292
290
  step_size: pd.Timedelta,
293
- age_params: Dict[str, float],
291
+ age_params: dict[str, float],
294
292
  demographic_proportions: pd.DataFrame,
295
- randomness_streams: Dict[str, RandomnessStream],
293
+ randomness_streams: dict[str, RandomnessStream],
296
294
  register_simulants: Callable[[pd.DataFrame], None],
297
295
  key_columns: Iterable[str] = ("entrance_time", "age"),
298
296
  ) -> pd.DataFrame:
@@ -378,7 +376,7 @@ def _assign_demography_with_initial_age(
378
376
  pop_data: pd.DataFrame,
379
377
  initial_age: float,
380
378
  step_size: pd.Timedelta,
381
- randomness_streams: Dict[str, RandomnessStream],
379
+ randomness_streams: dict[str, RandomnessStream],
382
380
  register_simulants: Callable[[pd.DataFrame], None],
383
381
  ) -> pd.DataFrame:
384
382
  """Assigns age, sex, and location information to the provided simulants given a fixed age.
@@ -441,7 +439,7 @@ def _assign_demography_with_age_bounds(
441
439
  pop_data: pd.DataFrame,
442
440
  age_start: float,
443
441
  age_end: float,
444
- randomness_streams: Dict[str, RandomnessStream],
442
+ randomness_streams: dict[str, RandomnessStream],
445
443
  register_simulants: Callable[[pd.DataFrame], None],
446
444
  key_columns: Iterable[str] = ("entrance_time", "age"),
447
445
  ) -> pd.DataFrame:
@@ -9,7 +9,6 @@ it into different distributions for sampling.
9
9
  """
10
10
 
11
11
  from collections import namedtuple
12
- from typing import Tuple, Union
13
12
 
14
13
  import numpy as np
15
14
  import pandas as pd
@@ -302,7 +301,7 @@ def smooth_ages(
302
301
  def _get_bins_and_proportions(
303
302
  pop_data: pd.DataFrame,
304
303
  age: AgeValues,
305
- ) -> Tuple[EndpointValues, AgeValues]:
304
+ ) -> tuple[EndpointValues, AgeValues]:
306
305
  """Finds and returns the bin edges and the population proportions in
307
306
  the current and neighboring bins.
308
307
 
@@ -376,7 +375,7 @@ def _get_bins_and_proportions(
376
375
 
377
376
  def _construct_sampling_parameters(
378
377
  age: AgeValues, endpoint: EndpointValues, proportion: AgeValues
379
- ) -> Tuple[EndpointValues, EndpointValues, float, float]:
378
+ ) -> tuple[EndpointValues, EndpointValues, float, float]:
380
379
  """Calculates some sampling distribution parameters from known values.
381
380
 
382
381
  Parameters
@@ -442,12 +441,12 @@ def _construct_sampling_parameters(
442
441
 
443
442
 
444
443
  def _compute_ages(
445
- uniform_rv: Union[np.ndarray, float],
444
+ uniform_rv: np.ndarray | float,
446
445
  start: float,
447
446
  height: float,
448
447
  slope: float,
449
448
  normalization: float,
450
- ) -> Union[np.ndarray, float]:
449
+ ) -> np.ndarray | float:
451
450
  """Produces samples from the local age distribution.
452
451
 
453
452
  Parameters
@@ -45,7 +45,7 @@ back the modified unmodeled csmr.
45
45
 
46
46
  """
47
47
 
48
- from typing import Any, Dict, List, Optional, Union
48
+ from typing import Any
49
49
 
50
50
  import pandas as pd
51
51
  from vivarium import Component
@@ -100,12 +100,12 @@ class Mortality(Component):
100
100
  ##############
101
101
 
102
102
  @property
103
- def configuration_defaults(self) -> Dict[str, Any]:
103
+ def configuration_defaults(self) -> dict[str, Any]:
104
104
  return {
105
105
  "mortality": {
106
106
  "data_sources": {
107
107
  "all_cause_mortality_rate": "cause.all_causes.cause_specific_mortality_rate",
108
- "unmodeled_cause_specific_mortality_rate": "self::load_unmodeled_csmr",
108
+ "unmodeled_cause_specific_mortality_rate": self.load_unmodeled_csmr,
109
109
  "life_expectancy": "population.theoretical_minimum_risk_life_expectancy",
110
110
  },
111
111
  "unmodeled_causes": [],
@@ -113,18 +113,18 @@ class Mortality(Component):
113
113
  }
114
114
 
115
115
  @property
116
- def standard_lookup_tables(self) -> List[str]:
116
+ def standard_lookup_tables(self) -> list[str]:
117
117
  return [
118
118
  "all_cause_mortality_rate",
119
119
  "life_expectancy",
120
120
  ]
121
121
 
122
122
  @property
123
- def columns_created(self) -> List[str]:
123
+ def columns_created(self) -> list[str]:
124
124
  return [self.cause_of_death_column_name, self.years_of_life_lost_column_name]
125
125
 
126
126
  @property
127
- def columns_required(self) -> Optional[List[str]]:
127
+ def columns_required(self) -> list[str] | None:
128
128
  return ["alive", "exit_time", "age", "sex"]
129
129
 
130
130
  @property
@@ -184,7 +184,7 @@ class Mortality(Component):
184
184
  requires_columns=required_columns,
185
185
  )
186
186
 
187
- def load_unmodeled_csmr(self, builder: Builder) -> Union[float, pd.DataFrame]:
187
+ def load_unmodeled_csmr(self, builder: Builder) -> float | pd.DataFrame:
188
188
  # todo validate that all data have the same columns
189
189
  raw_csmr = 0.0
190
190
  for idx, cause in enumerate(builder.configuration[self.name].unmodeled_causes):
@@ -8,8 +8,6 @@ in the simulation.
8
8
 
9
9
  """
10
10
 
11
- from typing import Union
12
-
13
11
  import pandas as pd
14
12
  from layered_config_tree import LayeredConfigTree
15
13
  from loguru import logger
@@ -185,7 +183,7 @@ class DisabilityObserver(PublicHealthObserver):
185
183
  # Aggregators #
186
184
  ###############
187
185
 
188
- def disability_weight_aggregator(self, dw: pd.DataFrame) -> Union[float, pd.Series]:
186
+ def disability_weight_aggregator(self, dw: pd.DataFrame) -> float | pd.Series:
189
187
  """Aggregate disability weights for the time step.
190
188
 
191
189
  Parameters
@@ -8,7 +8,7 @@ in the simulation.
8
8
 
9
9
  """
10
10
 
11
- from typing import Any, Dict, List
11
+ from typing import Any
12
12
 
13
13
  import pandas as pd
14
14
  from layered_config_tree import LayeredConfigTree
@@ -66,7 +66,7 @@ class DiseaseObserver(PublicHealthObserver):
66
66
  ##############
67
67
 
68
68
  @property
69
- def configuration_defaults(self) -> Dict[str, Any]:
69
+ def configuration_defaults(self) -> dict[str, Any]:
70
70
  """A dictionary containing the defaults for any configurations managed by
71
71
  this component.
72
72
  """
@@ -79,17 +79,17 @@ class DiseaseObserver(PublicHealthObserver):
79
79
  }
80
80
 
81
81
  @property
82
- def columns_created(self) -> List[str]:
82
+ def columns_created(self) -> list[str]:
83
83
  """Columns created by this observer."""
84
84
  return [self.previous_state_column_name]
85
85
 
86
86
  @property
87
- def columns_required(self) -> List[str]:
87
+ def columns_required(self) -> list[str]:
88
88
  """Columns required by this observer."""
89
89
  return [self.disease]
90
90
 
91
91
  @property
92
- def initialization_requirements(self) -> Dict[str, List[str]]:
92
+ def initialization_requirements(self) -> dict[str, list[str]]:
93
93
  """Requirements for observer initialization."""
94
94
  return {
95
95
  "requires_columns": [self.disease],
@@ -8,7 +8,7 @@ excess mortality in the simulation, including "other causes".
8
8
 
9
9
  """
10
10
 
11
- from typing import Any, Dict, List
11
+ from typing import Any
12
12
 
13
13
  import pandas as pd
14
14
  from layered_config_tree import LayeredConfigTree
@@ -81,7 +81,7 @@ class MortalityObserver(PublicHealthObserver):
81
81
  return [DiseaseState, RiskAttributableDisease]
82
82
 
83
83
  @property
84
- def configuration_defaults(self) -> Dict[str, Any]:
84
+ def configuration_defaults(self) -> dict[str, Any]:
85
85
  """A dictionary containing the defaults for any configurations managed by
86
86
  this component.
87
87
  """
@@ -90,7 +90,7 @@ class MortalityObserver(PublicHealthObserver):
90
90
  return config_defaults
91
91
 
92
92
  @property
93
- def columns_required(self) -> List[str]:
93
+ def columns_required(self) -> list[str]:
94
94
  """Columns required by this observer."""
95
95
  return [
96
96
  "alive",
@@ -8,7 +8,7 @@ public health models.
8
8
 
9
9
  """
10
10
 
11
- from typing import Callable, List, Optional, Union
11
+ from collections.abc import Callable
12
12
 
13
13
  import pandas as pd
14
14
  from vivarium.framework.engine import Builder
@@ -32,12 +32,12 @@ class PublicHealthObserver(Observer):
32
32
  name: str,
33
33
  pop_filter: str,
34
34
  when: str = "collect_metrics",
35
- requires_columns: List[str] = [],
36
- requires_values: List[str] = [],
37
- additional_stratifications: List[str] = [],
38
- excluded_stratifications: List[str] = [],
39
- aggregator_sources: Optional[List[str]] = None,
40
- aggregator: Callable[[pd.DataFrame], Union[float, pd.Series]] = len,
35
+ requires_columns: list[str] = [],
36
+ requires_values: list[str] = [],
37
+ additional_stratifications: list[str] = [],
38
+ excluded_stratifications: list[str] = [],
39
+ aggregator_sources: list[str] | None = None,
40
+ aggregator: Callable[[pd.DataFrame], float | pd.Series] = len,
41
41
  ) -> None:
42
42
  """Registers an adding observation to the results system.
43
43
 
@@ -7,7 +7,7 @@ This module contains tools for observing risk exposure during the simulation.
7
7
 
8
8
  """
9
9
 
10
- from typing import Any, Dict, List, Optional
10
+ from typing import Any
11
11
 
12
12
  import pandas as pd
13
13
  from layered_config_tree import LayeredConfigTree
@@ -59,7 +59,7 @@ class CategoricalRiskObserver(PublicHealthObserver):
59
59
  ##############
60
60
 
61
61
  @property
62
- def configuration_defaults(self) -> Dict[str, Any]:
62
+ def configuration_defaults(self) -> dict[str, Any]:
63
63
  """A dictionary containing the defaults for any configurations managed by
64
64
  this component.
65
65
  """
@@ -72,7 +72,7 @@ class CategoricalRiskObserver(PublicHealthObserver):
72
72
  }
73
73
 
74
74
  @property
75
- def columns_required(self) -> Optional[List[str]]:
75
+ def columns_required(self) -> list[str] | None:
76
76
  """The columns required by this observer."""
77
77
  return ["alive"]
78
78
 
@@ -8,7 +8,7 @@ exposure.
8
8
 
9
9
  """
10
10
 
11
- from typing import Any, Dict, List
11
+ from typing import Any
12
12
 
13
13
  import pandas as pd
14
14
  from vivarium import Component
@@ -106,7 +106,7 @@ class Risk(Component):
106
106
  return self.risk
107
107
 
108
108
  @property
109
- def configuration_defaults(self) -> Dict[str, Any]:
109
+ def configuration_defaults(self) -> dict[str, Any]:
110
110
  return {
111
111
  self.name: {
112
112
  "data_sources": {
@@ -122,14 +122,14 @@ class Risk(Component):
122
122
  }
123
123
 
124
124
  @property
125
- def columns_created(self) -> List[str]:
125
+ def columns_created(self) -> list[str]:
126
126
  columns_to_create = [self.propensity_column_name]
127
127
  if self.create_exposure_column:
128
128
  columns_to_create.append(self.exposure_column_name)
129
129
  return columns_to_create
130
130
 
131
131
  @property
132
- def initialization_requirements(self) -> Dict[str, List[str]]:
132
+ def initialization_requirements(self) -> dict[str, list[str]]:
133
133
  return {
134
134
  "requires_columns": [],
135
135
  "requires_values": [],