vivarium-public-health 3.1.1__py3-none-any.whl → 3.1.2__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (34) hide show
  1. vivarium_public_health/_version.py +1 -1
  2. vivarium_public_health/disease/model.py +12 -20
  3. vivarium_public_health/disease/special_disease.py +5 -5
  4. vivarium_public_health/disease/state.py +18 -17
  5. vivarium_public_health/disease/transition.py +9 -8
  6. vivarium_public_health/mslt/delay.py +5 -5
  7. vivarium_public_health/mslt/disease.py +5 -5
  8. vivarium_public_health/mslt/intervention.py +8 -8
  9. vivarium_public_health/mslt/magic_wand_components.py +3 -3
  10. vivarium_public_health/mslt/observer.py +4 -6
  11. vivarium_public_health/mslt/population.py +4 -6
  12. vivarium_public_health/plugins/parser.py +18 -19
  13. vivarium_public_health/population/add_new_birth_cohorts.py +3 -5
  14. vivarium_public_health/population/base_population.py +8 -10
  15. vivarium_public_health/population/data_transformations.py +4 -5
  16. vivarium_public_health/population/mortality.py +6 -6
  17. vivarium_public_health/results/disability.py +1 -3
  18. vivarium_public_health/results/disease.py +5 -5
  19. vivarium_public_health/results/mortality.py +3 -3
  20. vivarium_public_health/results/observer.py +7 -7
  21. vivarium_public_health/results/risk.py +3 -3
  22. vivarium_public_health/risks/base_risk.py +4 -4
  23. vivarium_public_health/risks/distributions.py +15 -15
  24. vivarium_public_health/risks/effect.py +10 -10
  25. vivarium_public_health/risks/implementations/low_birth_weight_and_short_gestation.py +17 -16
  26. vivarium_public_health/treatment/magic_wand.py +3 -3
  27. vivarium_public_health/treatment/scale_up.py +6 -5
  28. vivarium_public_health/utilities.py +4 -5
  29. {vivarium_public_health-3.1.1.dist-info → vivarium_public_health-3.1.2.dist-info}/METADATA +1 -1
  30. vivarium_public_health-3.1.2.dist-info/RECORD +49 -0
  31. {vivarium_public_health-3.1.1.dist-info → vivarium_public_health-3.1.2.dist-info}/WHEEL +1 -1
  32. vivarium_public_health-3.1.1.dist-info/RECORD +0 -49
  33. {vivarium_public_health-3.1.1.dist-info → vivarium_public_health-3.1.2.dist-info}/LICENSE.txt +0 -0
  34. {vivarium_public_health-3.1.1.dist-info → vivarium_public_health-3.1.2.dist-info}/top_level.txt +0 -0
@@ -10,8 +10,9 @@ Health package.
10
10
 
11
11
  """
12
12
  import warnings
13
+ from collections.abc import Callable
13
14
  from importlib import import_module
14
- from typing import Any, Callable, Dict, List, Optional, Union
15
+ from typing import Any
15
16
 
16
17
  import pandas as pd
17
18
  from layered_config_tree import LayeredConfigTree
@@ -37,7 +38,7 @@ from vivarium_public_health.utilities import TargetString
37
38
  class CausesParsingErrors(ParsingError):
38
39
  """Error raised when there are any errors parsing a cause model configuration."""
39
40
 
40
- def __init__(self, messages: List[str]):
41
+ def __init__(self, messages: list[str]):
41
42
  super().__init__("\n - " + "\n - ".join(messages))
42
43
 
43
44
 
@@ -81,7 +82,7 @@ class CausesConfigurationParser(ComponentConfigurationParser):
81
82
  This value is used if the transition configuration does not explicity specify it.
82
83
  """
83
84
 
84
- def parse_component_config(self, component_config: LayeredConfigTree) -> List[Component]:
85
+ def parse_component_config(self, component_config: LayeredConfigTree) -> list[Component]:
85
86
  """Parses the component configuration and returns a list of components.
86
87
 
87
88
  In particular, this method looks for an `external_configuration` key
@@ -216,7 +217,7 @@ class CausesConfigurationParser(ComponentConfigurationParser):
216
217
 
217
218
  def _get_cause_model_components(
218
219
  self, causes_config: LayeredConfigTree
219
- ) -> List[Component]:
220
+ ) -> list[Component]:
220
221
  """Parses the cause model configuration and returns the `DiseaseModel` components.
221
222
 
222
223
  Parameters
@@ -236,7 +237,7 @@ class CausesConfigurationParser(ComponentConfigurationParser):
236
237
  data_sources_config = cause_config.data_sources
237
238
  data_sources = self._get_data_sources(data_sources_config)
238
239
 
239
- states: Dict[str, BaseDiseaseState] = {
240
+ states: dict[str, BaseDiseaseState] = {
240
241
  state_name: self._get_state(state_name, state_config, cause_name)
241
242
  for state_name, state_config in cause_config.states.items()
242
243
  }
@@ -351,7 +352,7 @@ class CausesConfigurationParser(ComponentConfigurationParser):
351
352
 
352
353
  def _get_data_sources(
353
354
  self, config: LayeredConfigTree
354
- ) -> Dict[str, Callable[[Builder, Any], Any]]:
355
+ ) -> dict[str, Callable[[Builder, Any], Any]]:
355
356
  """Parses a data sources configuration and returns the data sources.
356
357
 
357
358
  Parameters
@@ -366,9 +367,7 @@ class CausesConfigurationParser(ComponentConfigurationParser):
366
367
  return {name: self._get_data_source(name, config[name]) for name in config.keys()}
367
368
 
368
369
  @staticmethod
369
- def _get_data_source(
370
- name: str, source: Union[str, float]
371
- ) -> Callable[[Builder, Any], Any]:
370
+ def _get_data_source(name: str, source: str | float) -> Callable[[Builder, Any], Any]:
372
371
  """Parses a data source and returns a callable that can be used to retrieve the data.
373
372
 
374
373
  Parameters
@@ -500,7 +499,7 @@ class CausesConfigurationParser(ComponentConfigurationParser):
500
499
  if error_messages:
501
500
  raise CausesParsingErrors(error_messages)
502
501
 
503
- def _validate_cause(self, cause_name: str, cause_config: Dict[str, Any]) -> List[str]:
502
+ def _validate_cause(self, cause_name: str, cause_config: dict[str, Any]) -> list[str]:
504
503
  """Validates a cause configuration and returns a list of error messages.
505
504
 
506
505
  Parameters
@@ -591,8 +590,8 @@ class CausesConfigurationParser(ComponentConfigurationParser):
591
590
  return error_messages
592
591
 
593
592
  def _validate_state(
594
- self, cause_name: str, state_name: str, state_config: Dict[str, Any]
595
- ) -> List[str]:
593
+ self, cause_name: str, state_name: str, state_config: dict[str, Any]
594
+ ) -> list[str]:
596
595
  """Validates a state configuration and returns a list of error messages.
597
596
 
598
597
  Parameters
@@ -683,9 +682,9 @@ class CausesConfigurationParser(ComponentConfigurationParser):
683
682
  self,
684
683
  cause_name: str,
685
684
  transition_name: str,
686
- transition_config: Dict[str, Any],
687
- states_config: Dict[str, Any],
688
- ) -> List[str]:
685
+ transition_config: dict[str, Any],
686
+ states_config: dict[str, Any],
687
+ ) -> list[str]:
689
688
  """Validates a transition configuration and returns a list of error messages.
690
689
 
691
690
  Parameters
@@ -783,8 +782,8 @@ class CausesConfigurationParser(ComponentConfigurationParser):
783
782
 
784
783
  @staticmethod
785
784
  def _validate_imported_type(
786
- import_path: str, cause_name: str, entity_type: str, entity_name: Optional[str] = None
787
- ) -> List[str]:
785
+ import_path: str, cause_name: str, entity_type: str, entity_name: str | None = None
786
+ ) -> list[str]:
788
787
  """Validates an imported type and returns a list of error messages.
789
788
 
790
789
  Parameters
@@ -825,8 +824,8 @@ class CausesConfigurationParser(ComponentConfigurationParser):
825
824
  return error_messages
826
825
 
827
826
  def _validate_data_sources(
828
- self, config: Dict[str, Any], cause_name: str, config_type: str, config_name: str
829
- ) -> List[str]:
827
+ self, config: dict[str, Any], cause_name: str, config_type: str, config_name: str
828
+ ) -> list[str]:
830
829
  """Validates the data sources in a configuration and returns any error messages.
831
830
 
832
831
  Parameters
@@ -7,8 +7,6 @@ This module contains several different models of fertility.
7
7
 
8
8
  """
9
9
 
10
- from typing import Dict, List, Optional
11
-
12
10
  import numpy as np
13
11
  import pandas as pd
14
12
  from vivarium import Component
@@ -158,15 +156,15 @@ class FertilityAgeSpecificRates(Component):
158
156
  ##############
159
157
 
160
158
  @property
161
- def columns_created(self) -> List[str]:
159
+ def columns_created(self) -> list[str]:
162
160
  return ["last_birth_time", "parent_id"]
163
161
 
164
162
  @property
165
- def columns_required(self) -> Optional[List[str]]:
163
+ def columns_required(self) -> list[str] | None:
166
164
  return ["sex"]
167
165
 
168
166
  @property
169
- def initialization_requirements(self) -> Dict[str, List[str]]:
167
+ def initialization_requirements(self) -> dict[str, list[str]]:
170
168
  return {
171
169
  "requires_columns": ["sex"],
172
170
  "requires_values": [],
@@ -8,9 +8,7 @@ characteristics to simulants.
8
8
 
9
9
  """
10
10
 
11
- from __future__ import annotations
12
-
13
- from typing import Callable, Dict, Iterable, List
11
+ from collections.abc import Callable, Iterable
14
12
 
15
13
  import numpy as np
16
14
  import pandas as pd
@@ -48,7 +46,7 @@ class BasePopulation(Component):
48
46
  ##############
49
47
 
50
48
  @property
51
- def columns_created(self) -> List[str]:
49
+ def columns_created(self) -> list[str]:
52
50
  return ["age", "sex", "alive", "location", "entrance_time", "exit_time"]
53
51
 
54
52
  @property
@@ -91,7 +89,7 @@ class BasePopulation(Component):
91
89
  #################
92
90
 
93
91
  @staticmethod
94
- def get_randomness_streams(builder: Builder) -> Dict[str, RandomnessStream]:
92
+ def get_randomness_streams(builder: Builder) -> dict[str, RandomnessStream]:
95
93
  return {
96
94
  "general_purpose": builder.randomness.get_stream("population_generation"),
97
95
  "bin_selection": builder.randomness.get_stream(
@@ -255,7 +253,7 @@ class AgeOutSimulants(Component):
255
253
  ##############
256
254
 
257
255
  @property
258
- def columns_required(self) -> List[str]:
256
+ def columns_required(self) -> list[str]:
259
257
  """A list of the columns this component requires that it did not create."""
260
258
  return self._columns_required
261
259
 
@@ -290,9 +288,9 @@ def generate_population(
290
288
  simulant_ids: pd.Index,
291
289
  creation_time: pd.Timestamp,
292
290
  step_size: pd.Timedelta,
293
- age_params: Dict[str, float],
291
+ age_params: dict[str, float],
294
292
  demographic_proportions: pd.DataFrame,
295
- randomness_streams: Dict[str, RandomnessStream],
293
+ randomness_streams: dict[str, RandomnessStream],
296
294
  register_simulants: Callable[[pd.DataFrame], None],
297
295
  key_columns: Iterable[str] = ("entrance_time", "age"),
298
296
  ) -> pd.DataFrame:
@@ -378,7 +376,7 @@ def _assign_demography_with_initial_age(
378
376
  pop_data: pd.DataFrame,
379
377
  initial_age: float,
380
378
  step_size: pd.Timedelta,
381
- randomness_streams: Dict[str, RandomnessStream],
379
+ randomness_streams: dict[str, RandomnessStream],
382
380
  register_simulants: Callable[[pd.DataFrame], None],
383
381
  ) -> pd.DataFrame:
384
382
  """Assigns age, sex, and location information to the provided simulants given a fixed age.
@@ -441,7 +439,7 @@ def _assign_demography_with_age_bounds(
441
439
  pop_data: pd.DataFrame,
442
440
  age_start: float,
443
441
  age_end: float,
444
- randomness_streams: Dict[str, RandomnessStream],
442
+ randomness_streams: dict[str, RandomnessStream],
445
443
  register_simulants: Callable[[pd.DataFrame], None],
446
444
  key_columns: Iterable[str] = ("entrance_time", "age"),
447
445
  ) -> pd.DataFrame:
@@ -9,7 +9,6 @@ it into different distributions for sampling.
9
9
  """
10
10
 
11
11
  from collections import namedtuple
12
- from typing import Tuple, Union
13
12
 
14
13
  import numpy as np
15
14
  import pandas as pd
@@ -302,7 +301,7 @@ def smooth_ages(
302
301
  def _get_bins_and_proportions(
303
302
  pop_data: pd.DataFrame,
304
303
  age: AgeValues,
305
- ) -> Tuple[EndpointValues, AgeValues]:
304
+ ) -> tuple[EndpointValues, AgeValues]:
306
305
  """Finds and returns the bin edges and the population proportions in
307
306
  the current and neighboring bins.
308
307
 
@@ -376,7 +375,7 @@ def _get_bins_and_proportions(
376
375
 
377
376
  def _construct_sampling_parameters(
378
377
  age: AgeValues, endpoint: EndpointValues, proportion: AgeValues
379
- ) -> Tuple[EndpointValues, EndpointValues, float, float]:
378
+ ) -> tuple[EndpointValues, EndpointValues, float, float]:
380
379
  """Calculates some sampling distribution parameters from known values.
381
380
 
382
381
  Parameters
@@ -442,12 +441,12 @@ def _construct_sampling_parameters(
442
441
 
443
442
 
444
443
  def _compute_ages(
445
- uniform_rv: Union[np.ndarray, float],
444
+ uniform_rv: np.ndarray | float,
446
445
  start: float,
447
446
  height: float,
448
447
  slope: float,
449
448
  normalization: float,
450
- ) -> Union[np.ndarray, float]:
449
+ ) -> np.ndarray | float:
451
450
  """Produces samples from the local age distribution.
452
451
 
453
452
  Parameters
@@ -45,7 +45,7 @@ back the modified unmodeled csmr.
45
45
 
46
46
  """
47
47
 
48
- from typing import Any, Dict, List, Optional, Union
48
+ from typing import Any
49
49
 
50
50
  import pandas as pd
51
51
  from vivarium import Component
@@ -100,7 +100,7 @@ class Mortality(Component):
100
100
  ##############
101
101
 
102
102
  @property
103
- def configuration_defaults(self) -> Dict[str, Any]:
103
+ def configuration_defaults(self) -> dict[str, Any]:
104
104
  return {
105
105
  "mortality": {
106
106
  "data_sources": {
@@ -113,18 +113,18 @@ class Mortality(Component):
113
113
  }
114
114
 
115
115
  @property
116
- def standard_lookup_tables(self) -> List[str]:
116
+ def standard_lookup_tables(self) -> list[str]:
117
117
  return [
118
118
  "all_cause_mortality_rate",
119
119
  "life_expectancy",
120
120
  ]
121
121
 
122
122
  @property
123
- def columns_created(self) -> List[str]:
123
+ def columns_created(self) -> list[str]:
124
124
  return [self.cause_of_death_column_name, self.years_of_life_lost_column_name]
125
125
 
126
126
  @property
127
- def columns_required(self) -> Optional[List[str]]:
127
+ def columns_required(self) -> list[str] | None:
128
128
  return ["alive", "exit_time", "age", "sex"]
129
129
 
130
130
  @property
@@ -184,7 +184,7 @@ class Mortality(Component):
184
184
  requires_columns=required_columns,
185
185
  )
186
186
 
187
- def load_unmodeled_csmr(self, builder: Builder) -> Union[float, pd.DataFrame]:
187
+ def load_unmodeled_csmr(self, builder: Builder) -> float | pd.DataFrame:
188
188
  # todo validate that all data have the same columns
189
189
  raw_csmr = 0.0
190
190
  for idx, cause in enumerate(builder.configuration[self.name].unmodeled_causes):
@@ -8,8 +8,6 @@ in the simulation.
8
8
 
9
9
  """
10
10
 
11
- from typing import Union
12
-
13
11
  import pandas as pd
14
12
  from layered_config_tree import LayeredConfigTree
15
13
  from loguru import logger
@@ -185,7 +183,7 @@ class DisabilityObserver(PublicHealthObserver):
185
183
  # Aggregators #
186
184
  ###############
187
185
 
188
- def disability_weight_aggregator(self, dw: pd.DataFrame) -> Union[float, pd.Series]:
186
+ def disability_weight_aggregator(self, dw: pd.DataFrame) -> float | pd.Series:
189
187
  """Aggregate disability weights for the time step.
190
188
 
191
189
  Parameters
@@ -8,7 +8,7 @@ in the simulation.
8
8
 
9
9
  """
10
10
 
11
- from typing import Any, Dict, List
11
+ from typing import Any
12
12
 
13
13
  import pandas as pd
14
14
  from layered_config_tree import LayeredConfigTree
@@ -66,7 +66,7 @@ class DiseaseObserver(PublicHealthObserver):
66
66
  ##############
67
67
 
68
68
  @property
69
- def configuration_defaults(self) -> Dict[str, Any]:
69
+ def configuration_defaults(self) -> dict[str, Any]:
70
70
  """A dictionary containing the defaults for any configurations managed by
71
71
  this component.
72
72
  """
@@ -79,17 +79,17 @@ class DiseaseObserver(PublicHealthObserver):
79
79
  }
80
80
 
81
81
  @property
82
- def columns_created(self) -> List[str]:
82
+ def columns_created(self) -> list[str]:
83
83
  """Columns created by this observer."""
84
84
  return [self.previous_state_column_name]
85
85
 
86
86
  @property
87
- def columns_required(self) -> List[str]:
87
+ def columns_required(self) -> list[str]:
88
88
  """Columns required by this observer."""
89
89
  return [self.disease]
90
90
 
91
91
  @property
92
- def initialization_requirements(self) -> Dict[str, List[str]]:
92
+ def initialization_requirements(self) -> dict[str, list[str]]:
93
93
  """Requirements for observer initialization."""
94
94
  return {
95
95
  "requires_columns": [self.disease],
@@ -8,7 +8,7 @@ excess mortality in the simulation, including "other causes".
8
8
 
9
9
  """
10
10
 
11
- from typing import Any, Dict, List
11
+ from typing import Any
12
12
 
13
13
  import pandas as pd
14
14
  from layered_config_tree import LayeredConfigTree
@@ -81,7 +81,7 @@ class MortalityObserver(PublicHealthObserver):
81
81
  return [DiseaseState, RiskAttributableDisease]
82
82
 
83
83
  @property
84
- def configuration_defaults(self) -> Dict[str, Any]:
84
+ def configuration_defaults(self) -> dict[str, Any]:
85
85
  """A dictionary containing the defaults for any configurations managed by
86
86
  this component.
87
87
  """
@@ -90,7 +90,7 @@ class MortalityObserver(PublicHealthObserver):
90
90
  return config_defaults
91
91
 
92
92
  @property
93
- def columns_required(self) -> List[str]:
93
+ def columns_required(self) -> list[str]:
94
94
  """Columns required by this observer."""
95
95
  return [
96
96
  "alive",
@@ -8,7 +8,7 @@ public health models.
8
8
 
9
9
  """
10
10
 
11
- from typing import Callable, List, Optional, Union
11
+ from collections.abc import Callable
12
12
 
13
13
  import pandas as pd
14
14
  from vivarium.framework.engine import Builder
@@ -32,12 +32,12 @@ class PublicHealthObserver(Observer):
32
32
  name: str,
33
33
  pop_filter: str,
34
34
  when: str = "collect_metrics",
35
- requires_columns: List[str] = [],
36
- requires_values: List[str] = [],
37
- additional_stratifications: List[str] = [],
38
- excluded_stratifications: List[str] = [],
39
- aggregator_sources: Optional[List[str]] = None,
40
- aggregator: Callable[[pd.DataFrame], Union[float, pd.Series]] = len,
35
+ requires_columns: list[str] = [],
36
+ requires_values: list[str] = [],
37
+ additional_stratifications: list[str] = [],
38
+ excluded_stratifications: list[str] = [],
39
+ aggregator_sources: list[str] | None = None,
40
+ aggregator: Callable[[pd.DataFrame], float | pd.Series] = len,
41
41
  ) -> None:
42
42
  """Registers an adding observation to the results system.
43
43
 
@@ -7,7 +7,7 @@ This module contains tools for observing risk exposure during the simulation.
7
7
 
8
8
  """
9
9
 
10
- from typing import Any, Dict, List, Optional
10
+ from typing import Any
11
11
 
12
12
  import pandas as pd
13
13
  from layered_config_tree import LayeredConfigTree
@@ -59,7 +59,7 @@ class CategoricalRiskObserver(PublicHealthObserver):
59
59
  ##############
60
60
 
61
61
  @property
62
- def configuration_defaults(self) -> Dict[str, Any]:
62
+ def configuration_defaults(self) -> dict[str, Any]:
63
63
  """A dictionary containing the defaults for any configurations managed by
64
64
  this component.
65
65
  """
@@ -72,7 +72,7 @@ class CategoricalRiskObserver(PublicHealthObserver):
72
72
  }
73
73
 
74
74
  @property
75
- def columns_required(self) -> Optional[List[str]]:
75
+ def columns_required(self) -> list[str] | None:
76
76
  """The columns required by this observer."""
77
77
  return ["alive"]
78
78
 
@@ -8,7 +8,7 @@ exposure.
8
8
 
9
9
  """
10
10
 
11
- from typing import Any, Dict, List
11
+ from typing import Any
12
12
 
13
13
  import pandas as pd
14
14
  from vivarium import Component
@@ -106,7 +106,7 @@ class Risk(Component):
106
106
  return self.risk
107
107
 
108
108
  @property
109
- def configuration_defaults(self) -> Dict[str, Any]:
109
+ def configuration_defaults(self) -> dict[str, Any]:
110
110
  return {
111
111
  self.name: {
112
112
  "data_sources": {
@@ -122,14 +122,14 @@ class Risk(Component):
122
122
  }
123
123
 
124
124
  @property
125
- def columns_created(self) -> List[str]:
125
+ def columns_created(self) -> list[str]:
126
126
  columns_to_create = [self.propensity_column_name]
127
127
  if self.create_exposure_column:
128
128
  columns_to_create.append(self.exposure_column_name)
129
129
  return columns_to_create
130
130
 
131
131
  @property
132
- def initialization_requirements(self) -> Dict[str, List[str]]:
132
+ def initialization_requirements(self) -> dict[str, list[str]]:
133
133
  return {
134
134
  "requires_columns": [],
135
135
  "requires_values": [],
@@ -9,7 +9,7 @@ exposure distributions.
9
9
  """
10
10
 
11
11
  from abc import ABC, abstractmethod
12
- from typing import Callable, Dict, List, Optional, Union
12
+ from collections.abc import Callable
13
13
 
14
14
  import numpy as np
15
15
  import pandas as pd
@@ -38,7 +38,7 @@ class RiskExposureDistribution(Component, ABC):
38
38
  self,
39
39
  risk: EntityString,
40
40
  distribution_type: str,
41
- exposure_data: Optional[Union[int, float, pd.DataFrame]] = None,
41
+ exposure_data: int | float | pd.DataFrame | None = None,
42
42
  ) -> None:
43
43
  super().__init__()
44
44
  self.risk = risk
@@ -51,14 +51,14 @@ class RiskExposureDistribution(Component, ABC):
51
51
  # Setup methods #
52
52
  #################
53
53
 
54
- def get_configuration(self, builder: "Builder") -> Optional[LayeredConfigTree]:
54
+ def get_configuration(self, builder: "Builder") -> LayeredConfigTree | None:
55
55
  return builder.configuration[self.risk]
56
56
 
57
57
  @abstractmethod
58
58
  def build_all_lookup_tables(self, builder: "Builder") -> None:
59
59
  raise NotImplementedError
60
60
 
61
- def get_exposure_data(self, builder: Builder) -> Union[int, float, pd.DataFrame]:
61
+ def get_exposure_data(self, builder: Builder) -> int | float | pd.DataFrame:
62
62
  if self._exposure_data is not None:
63
63
  return self._exposure_data
64
64
  return self.get_data(builder, self.configuration["data_sources"]["exposure"])
@@ -92,11 +92,11 @@ class EnsembleDistribution(RiskExposureDistribution):
92
92
  ##############
93
93
 
94
94
  @property
95
- def columns_created(self) -> List[str]:
95
+ def columns_created(self) -> list[str]:
96
96
  return [self._propensity]
97
97
 
98
98
  @property
99
- def initialization_requirements(self) -> Dict[str, List[str]]:
99
+ def initialization_requirements(self) -> dict[str, list[str]]:
100
100
  return {
101
101
  "requires_columns": [],
102
102
  "requires_values": [],
@@ -265,7 +265,7 @@ class ContinuousDistribution(RiskExposureDistribution):
265
265
 
266
266
  class PolytomousDistribution(RiskExposureDistribution):
267
267
  @property
268
- def categories(self) -> List[str]:
268
+ def categories(self) -> list[str]:
269
269
  # These need to be sorted so the cumulative sum is in the ocrrect order of categories
270
270
  # and results are therefore reproducible and correct
271
271
  return sorted(self.lookup_tables["exposure"].value_columns)
@@ -286,8 +286,8 @@ class PolytomousDistribution(RiskExposureDistribution):
286
286
  )
287
287
 
288
288
  def get_exposure_value_columns(
289
- self, exposure_data: Union[int, float, pd.DataFrame]
290
- ) -> Optional[List[str]]:
289
+ self, exposure_data: int | float | pd.DataFrame
290
+ ) -> list[str] | None:
291
291
  if isinstance(exposure_data, pd.DataFrame):
292
292
  return list(exposure_data["parameter"].unique())
293
293
  return None
@@ -342,7 +342,7 @@ class DichotomousDistribution(RiskExposureDistribution):
342
342
  )
343
343
  self.lookup_tables["paf"] = self.build_lookup_table(builder, 0.0)
344
344
 
345
- def get_exposure_data(self, builder: Builder) -> Union[int, float, pd.DataFrame]:
345
+ def get_exposure_data(self, builder: Builder) -> int | float | pd.DataFrame:
346
346
  exposure_data = super().get_exposure_data(builder)
347
347
 
348
348
  if isinstance(exposure_data, (int, float)):
@@ -373,8 +373,8 @@ class DichotomousDistribution(RiskExposureDistribution):
373
373
  return exposure_data
374
374
 
375
375
  def get_exposure_value_columns(
376
- self, exposure_data: Union[int, float, pd.DataFrame]
377
- ) -> Optional[List[str]]:
376
+ self, exposure_data: int | float | pd.DataFrame
377
+ ) -> list[str] | None:
378
378
  if isinstance(exposure_data, pd.DataFrame):
379
379
  return self.get_value_columns(exposure_data)
380
380
  return None
@@ -470,9 +470,9 @@ def clip(q):
470
470
 
471
471
 
472
472
  def get_risk_distribution_parameter(
473
- value_columns_getter: Callable[[Union[pd.DataFrame]], List[str]],
474
- data: Union[float, pd.DataFrame],
475
- ) -> Union[float, pd.Series]:
473
+ value_columns_getter: Callable[[pd.DataFrame], list[str]],
474
+ data: float | pd.DataFrame,
475
+ ) -> float | pd.Series:
476
476
  if isinstance(data, pd.DataFrame):
477
477
  value_columns = value_columns_getter(data)
478
478
  if len(value_columns) > 1:
@@ -7,9 +7,9 @@ This module contains tools for modeling the relationship between risk
7
7
  exposure models and disease models.
8
8
 
9
9
  """
10
-
10
+ from collections.abc import Callable
11
11
  from importlib import import_module
12
- from typing import Any, Callable, Dict, List, Tuple, Union
12
+ from typing import Any
13
13
 
14
14
  import numpy as np
15
15
  import pandas as pd
@@ -58,7 +58,7 @@ class RiskEffect(Component):
58
58
  return f"risk_effect.{risk.name}_on_{target}"
59
59
 
60
60
  @property
61
- def configuration_defaults(self) -> Dict[str, Any]:
61
+ def configuration_defaults(self) -> dict[str, Any]:
62
62
  """Default values for any configurations managed by this component."""
63
63
  return {
64
64
  self.name: {
@@ -150,7 +150,7 @@ class RiskEffect(Component):
150
150
  self,
151
151
  builder: Builder,
152
152
  configuration=None,
153
- ) -> Union[str, float, pd.DataFrame]:
153
+ ) -> str | float | pd.DataFrame:
154
154
  if configuration is None:
155
155
  configuration = self.configuration
156
156
 
@@ -173,8 +173,8 @@ class RiskEffect(Component):
173
173
  return rr_data
174
174
 
175
175
  def get_filtered_data(
176
- self, builder: "Builder", data_source: Union[str, float, pd.DataFrame]
177
- ) -> Union[float, pd.DataFrame]:
176
+ self, builder: "Builder", data_source: str | float | pd.DataFrame
177
+ ) -> float | pd.DataFrame:
178
178
  data = super().get_data(builder, data_source)
179
179
 
180
180
  if isinstance(data, pd.DataFrame):
@@ -191,8 +191,8 @@ class RiskEffect(Component):
191
191
  return data
192
192
 
193
193
  def process_categorical_data(
194
- self, builder: Builder, rr_data: Union[str, float, pd.DataFrame]
195
- ) -> Tuple[Union[str, float, pd.DataFrame], List[str]]:
194
+ self, builder: Builder, rr_data: str | float | pd.DataFrame
195
+ ) -> tuple[str | float | pd.DataFrame, list[str]]:
196
196
  if not isinstance(rr_data, pd.DataFrame):
197
197
  cat1 = builder.data.load("population.demographic_dimensions")
198
198
  cat1["parameter"] = "cat1"
@@ -347,7 +347,7 @@ class NonLogLinearRiskEffect(RiskEffect):
347
347
  ##############
348
348
 
349
349
  @property
350
- def configuration_defaults(self) -> Dict[str, Any]:
350
+ def configuration_defaults(self) -> dict[str, Any]:
351
351
  """Default values for any configurations managed by this component."""
352
352
  return {
353
353
  self.name: {
@@ -419,7 +419,7 @@ class NonLogLinearRiskEffect(RiskEffect):
419
419
  self,
420
420
  builder: Builder,
421
421
  configuration=None,
422
- ) -> Union[str, float, pd.DataFrame]:
422
+ ) -> str | float | pd.DataFrame:
423
423
  if configuration is None:
424
424
  configuration = self.configuration
425
425