vivarium-public-health 3.1.1__py3-none-any.whl → 3.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. vivarium_public_health/_version.py +1 -1
  2. vivarium_public_health/disease/model.py +12 -20
  3. vivarium_public_health/disease/special_disease.py +5 -5
  4. vivarium_public_health/disease/state.py +18 -17
  5. vivarium_public_health/disease/transition.py +9 -8
  6. vivarium_public_health/mslt/delay.py +5 -5
  7. vivarium_public_health/mslt/disease.py +5 -5
  8. vivarium_public_health/mslt/intervention.py +8 -8
  9. vivarium_public_health/mslt/magic_wand_components.py +3 -3
  10. vivarium_public_health/mslt/observer.py +4 -6
  11. vivarium_public_health/mslt/population.py +4 -6
  12. vivarium_public_health/plugins/parser.py +18 -19
  13. vivarium_public_health/population/add_new_birth_cohorts.py +3 -5
  14. vivarium_public_health/population/base_population.py +8 -10
  15. vivarium_public_health/population/data_transformations.py +4 -5
  16. vivarium_public_health/population/mortality.py +6 -6
  17. vivarium_public_health/results/disability.py +1 -3
  18. vivarium_public_health/results/disease.py +5 -5
  19. vivarium_public_health/results/mortality.py +3 -3
  20. vivarium_public_health/results/observer.py +7 -7
  21. vivarium_public_health/results/risk.py +3 -3
  22. vivarium_public_health/risks/base_risk.py +4 -4
  23. vivarium_public_health/risks/distributions.py +15 -15
  24. vivarium_public_health/risks/effect.py +45 -33
  25. vivarium_public_health/risks/implementations/low_birth_weight_and_short_gestation.py +20 -20
  26. vivarium_public_health/treatment/magic_wand.py +3 -3
  27. vivarium_public_health/treatment/scale_up.py +6 -5
  28. vivarium_public_health/utilities.py +4 -5
  29. {vivarium_public_health-3.1.1.dist-info → vivarium_public_health-3.1.3.dist-info}/METADATA +32 -32
  30. vivarium_public_health-3.1.3.dist-info/RECORD +49 -0
  31. {vivarium_public_health-3.1.1.dist-info → vivarium_public_health-3.1.3.dist-info}/WHEEL +1 -1
  32. vivarium_public_health-3.1.1.dist-info/RECORD +0 -49
  33. {vivarium_public_health-3.1.1.dist-info → vivarium_public_health-3.1.3.dist-info}/LICENSE.txt +0 -0
  34. {vivarium_public_health-3.1.1.dist-info → vivarium_public_health-3.1.3.dist-info}/top_level.txt +0 -0
@@ -10,8 +10,9 @@ Health package.
10
10
 
11
11
  """
12
12
  import warnings
13
+ from collections.abc import Callable
13
14
  from importlib import import_module
14
- from typing import Any, Callable, Dict, List, Optional, Union
15
+ from typing import Any
15
16
 
16
17
  import pandas as pd
17
18
  from layered_config_tree import LayeredConfigTree
@@ -37,7 +38,7 @@ from vivarium_public_health.utilities import TargetString
37
38
  class CausesParsingErrors(ParsingError):
38
39
  """Error raised when there are any errors parsing a cause model configuration."""
39
40
 
40
- def __init__(self, messages: List[str]):
41
+ def __init__(self, messages: list[str]):
41
42
  super().__init__("\n - " + "\n - ".join(messages))
42
43
 
43
44
 
@@ -81,7 +82,7 @@ class CausesConfigurationParser(ComponentConfigurationParser):
81
82
  This value is used if the transition configuration does not explicity specify it.
82
83
  """
83
84
 
84
- def parse_component_config(self, component_config: LayeredConfigTree) -> List[Component]:
85
+ def parse_component_config(self, component_config: LayeredConfigTree) -> list[Component]:
85
86
  """Parses the component configuration and returns a list of components.
86
87
 
87
88
  In particular, this method looks for an `external_configuration` key
@@ -216,7 +217,7 @@ class CausesConfigurationParser(ComponentConfigurationParser):
216
217
 
217
218
  def _get_cause_model_components(
218
219
  self, causes_config: LayeredConfigTree
219
- ) -> List[Component]:
220
+ ) -> list[Component]:
220
221
  """Parses the cause model configuration and returns the `DiseaseModel` components.
221
222
 
222
223
  Parameters
@@ -236,7 +237,7 @@ class CausesConfigurationParser(ComponentConfigurationParser):
236
237
  data_sources_config = cause_config.data_sources
237
238
  data_sources = self._get_data_sources(data_sources_config)
238
239
 
239
- states: Dict[str, BaseDiseaseState] = {
240
+ states: dict[str, BaseDiseaseState] = {
240
241
  state_name: self._get_state(state_name, state_config, cause_name)
241
242
  for state_name, state_config in cause_config.states.items()
242
243
  }
@@ -351,7 +352,7 @@ class CausesConfigurationParser(ComponentConfigurationParser):
351
352
 
352
353
  def _get_data_sources(
353
354
  self, config: LayeredConfigTree
354
- ) -> Dict[str, Callable[[Builder, Any], Any]]:
355
+ ) -> dict[str, Callable[[Builder, Any], Any]]:
355
356
  """Parses a data sources configuration and returns the data sources.
356
357
 
357
358
  Parameters
@@ -366,9 +367,7 @@ class CausesConfigurationParser(ComponentConfigurationParser):
366
367
  return {name: self._get_data_source(name, config[name]) for name in config.keys()}
367
368
 
368
369
  @staticmethod
369
- def _get_data_source(
370
- name: str, source: Union[str, float]
371
- ) -> Callable[[Builder, Any], Any]:
370
+ def _get_data_source(name: str, source: str | float) -> Callable[[Builder, Any], Any]:
372
371
  """Parses a data source and returns a callable that can be used to retrieve the data.
373
372
 
374
373
  Parameters
@@ -500,7 +499,7 @@ class CausesConfigurationParser(ComponentConfigurationParser):
500
499
  if error_messages:
501
500
  raise CausesParsingErrors(error_messages)
502
501
 
503
- def _validate_cause(self, cause_name: str, cause_config: Dict[str, Any]) -> List[str]:
502
+ def _validate_cause(self, cause_name: str, cause_config: dict[str, Any]) -> list[str]:
504
503
  """Validates a cause configuration and returns a list of error messages.
505
504
 
506
505
  Parameters
@@ -591,8 +590,8 @@ class CausesConfigurationParser(ComponentConfigurationParser):
591
590
  return error_messages
592
591
 
593
592
  def _validate_state(
594
- self, cause_name: str, state_name: str, state_config: Dict[str, Any]
595
- ) -> List[str]:
593
+ self, cause_name: str, state_name: str, state_config: dict[str, Any]
594
+ ) -> list[str]:
596
595
  """Validates a state configuration and returns a list of error messages.
597
596
 
598
597
  Parameters
@@ -683,9 +682,9 @@ class CausesConfigurationParser(ComponentConfigurationParser):
683
682
  self,
684
683
  cause_name: str,
685
684
  transition_name: str,
686
- transition_config: Dict[str, Any],
687
- states_config: Dict[str, Any],
688
- ) -> List[str]:
685
+ transition_config: dict[str, Any],
686
+ states_config: dict[str, Any],
687
+ ) -> list[str]:
689
688
  """Validates a transition configuration and returns a list of error messages.
690
689
 
691
690
  Parameters
@@ -783,8 +782,8 @@ class CausesConfigurationParser(ComponentConfigurationParser):
783
782
 
784
783
  @staticmethod
785
784
  def _validate_imported_type(
786
- import_path: str, cause_name: str, entity_type: str, entity_name: Optional[str] = None
787
- ) -> List[str]:
785
+ import_path: str, cause_name: str, entity_type: str, entity_name: str | None = None
786
+ ) -> list[str]:
788
787
  """Validates an imported type and returns a list of error messages.
789
788
 
790
789
  Parameters
@@ -825,8 +824,8 @@ class CausesConfigurationParser(ComponentConfigurationParser):
825
824
  return error_messages
826
825
 
827
826
  def _validate_data_sources(
828
- self, config: Dict[str, Any], cause_name: str, config_type: str, config_name: str
829
- ) -> List[str]:
827
+ self, config: dict[str, Any], cause_name: str, config_type: str, config_name: str
828
+ ) -> list[str]:
830
829
  """Validates the data sources in a configuration and returns any error messages.
831
830
 
832
831
  Parameters
@@ -7,8 +7,6 @@ This module contains several different models of fertility.
7
7
 
8
8
  """
9
9
 
10
- from typing import Dict, List, Optional
11
-
12
10
  import numpy as np
13
11
  import pandas as pd
14
12
  from vivarium import Component
@@ -158,15 +156,15 @@ class FertilityAgeSpecificRates(Component):
158
156
  ##############
159
157
 
160
158
  @property
161
- def columns_created(self) -> List[str]:
159
+ def columns_created(self) -> list[str]:
162
160
  return ["last_birth_time", "parent_id"]
163
161
 
164
162
  @property
165
- def columns_required(self) -> Optional[List[str]]:
163
+ def columns_required(self) -> list[str] | None:
166
164
  return ["sex"]
167
165
 
168
166
  @property
169
- def initialization_requirements(self) -> Dict[str, List[str]]:
167
+ def initialization_requirements(self) -> dict[str, list[str]]:
170
168
  return {
171
169
  "requires_columns": ["sex"],
172
170
  "requires_values": [],
@@ -8,9 +8,7 @@ characteristics to simulants.
8
8
 
9
9
  """
10
10
 
11
- from __future__ import annotations
12
-
13
- from typing import Callable, Dict, Iterable, List
11
+ from collections.abc import Callable, Iterable
14
12
 
15
13
  import numpy as np
16
14
  import pandas as pd
@@ -48,7 +46,7 @@ class BasePopulation(Component):
48
46
  ##############
49
47
 
50
48
  @property
51
- def columns_created(self) -> List[str]:
49
+ def columns_created(self) -> list[str]:
52
50
  return ["age", "sex", "alive", "location", "entrance_time", "exit_time"]
53
51
 
54
52
  @property
@@ -91,7 +89,7 @@ class BasePopulation(Component):
91
89
  #################
92
90
 
93
91
  @staticmethod
94
- def get_randomness_streams(builder: Builder) -> Dict[str, RandomnessStream]:
92
+ def get_randomness_streams(builder: Builder) -> dict[str, RandomnessStream]:
95
93
  return {
96
94
  "general_purpose": builder.randomness.get_stream("population_generation"),
97
95
  "bin_selection": builder.randomness.get_stream(
@@ -255,7 +253,7 @@ class AgeOutSimulants(Component):
255
253
  ##############
256
254
 
257
255
  @property
258
- def columns_required(self) -> List[str]:
256
+ def columns_required(self) -> list[str]:
259
257
  """A list of the columns this component requires that it did not create."""
260
258
  return self._columns_required
261
259
 
@@ -290,9 +288,9 @@ def generate_population(
290
288
  simulant_ids: pd.Index,
291
289
  creation_time: pd.Timestamp,
292
290
  step_size: pd.Timedelta,
293
- age_params: Dict[str, float],
291
+ age_params: dict[str, float],
294
292
  demographic_proportions: pd.DataFrame,
295
- randomness_streams: Dict[str, RandomnessStream],
293
+ randomness_streams: dict[str, RandomnessStream],
296
294
  register_simulants: Callable[[pd.DataFrame], None],
297
295
  key_columns: Iterable[str] = ("entrance_time", "age"),
298
296
  ) -> pd.DataFrame:
@@ -378,7 +376,7 @@ def _assign_demography_with_initial_age(
378
376
  pop_data: pd.DataFrame,
379
377
  initial_age: float,
380
378
  step_size: pd.Timedelta,
381
- randomness_streams: Dict[str, RandomnessStream],
379
+ randomness_streams: dict[str, RandomnessStream],
382
380
  register_simulants: Callable[[pd.DataFrame], None],
383
381
  ) -> pd.DataFrame:
384
382
  """Assigns age, sex, and location information to the provided simulants given a fixed age.
@@ -441,7 +439,7 @@ def _assign_demography_with_age_bounds(
441
439
  pop_data: pd.DataFrame,
442
440
  age_start: float,
443
441
  age_end: float,
444
- randomness_streams: Dict[str, RandomnessStream],
442
+ randomness_streams: dict[str, RandomnessStream],
445
443
  register_simulants: Callable[[pd.DataFrame], None],
446
444
  key_columns: Iterable[str] = ("entrance_time", "age"),
447
445
  ) -> pd.DataFrame:
@@ -9,7 +9,6 @@ it into different distributions for sampling.
9
9
  """
10
10
 
11
11
  from collections import namedtuple
12
- from typing import Tuple, Union
13
12
 
14
13
  import numpy as np
15
14
  import pandas as pd
@@ -302,7 +301,7 @@ def smooth_ages(
302
301
  def _get_bins_and_proportions(
303
302
  pop_data: pd.DataFrame,
304
303
  age: AgeValues,
305
- ) -> Tuple[EndpointValues, AgeValues]:
304
+ ) -> tuple[EndpointValues, AgeValues]:
306
305
  """Finds and returns the bin edges and the population proportions in
307
306
  the current and neighboring bins.
308
307
 
@@ -376,7 +375,7 @@ def _get_bins_and_proportions(
376
375
 
377
376
  def _construct_sampling_parameters(
378
377
  age: AgeValues, endpoint: EndpointValues, proportion: AgeValues
379
- ) -> Tuple[EndpointValues, EndpointValues, float, float]:
378
+ ) -> tuple[EndpointValues, EndpointValues, float, float]:
380
379
  """Calculates some sampling distribution parameters from known values.
381
380
 
382
381
  Parameters
@@ -442,12 +441,12 @@ def _construct_sampling_parameters(
442
441
 
443
442
 
444
443
  def _compute_ages(
445
- uniform_rv: Union[np.ndarray, float],
444
+ uniform_rv: np.ndarray | float,
446
445
  start: float,
447
446
  height: float,
448
447
  slope: float,
449
448
  normalization: float,
450
- ) -> Union[np.ndarray, float]:
449
+ ) -> np.ndarray | float:
451
450
  """Produces samples from the local age distribution.
452
451
 
453
452
  Parameters
@@ -45,7 +45,7 @@ back the modified unmodeled csmr.
45
45
 
46
46
  """
47
47
 
48
- from typing import Any, Dict, List, Optional, Union
48
+ from typing import Any
49
49
 
50
50
  import pandas as pd
51
51
  from vivarium import Component
@@ -100,7 +100,7 @@ class Mortality(Component):
100
100
  ##############
101
101
 
102
102
  @property
103
- def configuration_defaults(self) -> Dict[str, Any]:
103
+ def configuration_defaults(self) -> dict[str, Any]:
104
104
  return {
105
105
  "mortality": {
106
106
  "data_sources": {
@@ -113,18 +113,18 @@ class Mortality(Component):
113
113
  }
114
114
 
115
115
  @property
116
- def standard_lookup_tables(self) -> List[str]:
116
+ def standard_lookup_tables(self) -> list[str]:
117
117
  return [
118
118
  "all_cause_mortality_rate",
119
119
  "life_expectancy",
120
120
  ]
121
121
 
122
122
  @property
123
- def columns_created(self) -> List[str]:
123
+ def columns_created(self) -> list[str]:
124
124
  return [self.cause_of_death_column_name, self.years_of_life_lost_column_name]
125
125
 
126
126
  @property
127
- def columns_required(self) -> Optional[List[str]]:
127
+ def columns_required(self) -> list[str] | None:
128
128
  return ["alive", "exit_time", "age", "sex"]
129
129
 
130
130
  @property
@@ -184,7 +184,7 @@ class Mortality(Component):
184
184
  requires_columns=required_columns,
185
185
  )
186
186
 
187
- def load_unmodeled_csmr(self, builder: Builder) -> Union[float, pd.DataFrame]:
187
+ def load_unmodeled_csmr(self, builder: Builder) -> float | pd.DataFrame:
188
188
  # todo validate that all data have the same columns
189
189
  raw_csmr = 0.0
190
190
  for idx, cause in enumerate(builder.configuration[self.name].unmodeled_causes):
@@ -8,8 +8,6 @@ in the simulation.
8
8
 
9
9
  """
10
10
 
11
- from typing import Union
12
-
13
11
  import pandas as pd
14
12
  from layered_config_tree import LayeredConfigTree
15
13
  from loguru import logger
@@ -185,7 +183,7 @@ class DisabilityObserver(PublicHealthObserver):
185
183
  # Aggregators #
186
184
  ###############
187
185
 
188
- def disability_weight_aggregator(self, dw: pd.DataFrame) -> Union[float, pd.Series]:
186
+ def disability_weight_aggregator(self, dw: pd.DataFrame) -> float | pd.Series:
189
187
  """Aggregate disability weights for the time step.
190
188
 
191
189
  Parameters
@@ -8,7 +8,7 @@ in the simulation.
8
8
 
9
9
  """
10
10
 
11
- from typing import Any, Dict, List
11
+ from typing import Any
12
12
 
13
13
  import pandas as pd
14
14
  from layered_config_tree import LayeredConfigTree
@@ -66,7 +66,7 @@ class DiseaseObserver(PublicHealthObserver):
66
66
  ##############
67
67
 
68
68
  @property
69
- def configuration_defaults(self) -> Dict[str, Any]:
69
+ def configuration_defaults(self) -> dict[str, Any]:
70
70
  """A dictionary containing the defaults for any configurations managed by
71
71
  this component.
72
72
  """
@@ -79,17 +79,17 @@ class DiseaseObserver(PublicHealthObserver):
79
79
  }
80
80
 
81
81
  @property
82
- def columns_created(self) -> List[str]:
82
+ def columns_created(self) -> list[str]:
83
83
  """Columns created by this observer."""
84
84
  return [self.previous_state_column_name]
85
85
 
86
86
  @property
87
- def columns_required(self) -> List[str]:
87
+ def columns_required(self) -> list[str]:
88
88
  """Columns required by this observer."""
89
89
  return [self.disease]
90
90
 
91
91
  @property
92
- def initialization_requirements(self) -> Dict[str, List[str]]:
92
+ def initialization_requirements(self) -> dict[str, list[str]]:
93
93
  """Requirements for observer initialization."""
94
94
  return {
95
95
  "requires_columns": [self.disease],
@@ -8,7 +8,7 @@ excess mortality in the simulation, including "other causes".
8
8
 
9
9
  """
10
10
 
11
- from typing import Any, Dict, List
11
+ from typing import Any
12
12
 
13
13
  import pandas as pd
14
14
  from layered_config_tree import LayeredConfigTree
@@ -81,7 +81,7 @@ class MortalityObserver(PublicHealthObserver):
81
81
  return [DiseaseState, RiskAttributableDisease]
82
82
 
83
83
  @property
84
- def configuration_defaults(self) -> Dict[str, Any]:
84
+ def configuration_defaults(self) -> dict[str, Any]:
85
85
  """A dictionary containing the defaults for any configurations managed by
86
86
  this component.
87
87
  """
@@ -90,7 +90,7 @@ class MortalityObserver(PublicHealthObserver):
90
90
  return config_defaults
91
91
 
92
92
  @property
93
- def columns_required(self) -> List[str]:
93
+ def columns_required(self) -> list[str]:
94
94
  """Columns required by this observer."""
95
95
  return [
96
96
  "alive",
@@ -8,7 +8,7 @@ public health models.
8
8
 
9
9
  """
10
10
 
11
- from typing import Callable, List, Optional, Union
11
+ from collections.abc import Callable
12
12
 
13
13
  import pandas as pd
14
14
  from vivarium.framework.engine import Builder
@@ -32,12 +32,12 @@ class PublicHealthObserver(Observer):
32
32
  name: str,
33
33
  pop_filter: str,
34
34
  when: str = "collect_metrics",
35
- requires_columns: List[str] = [],
36
- requires_values: List[str] = [],
37
- additional_stratifications: List[str] = [],
38
- excluded_stratifications: List[str] = [],
39
- aggregator_sources: Optional[List[str]] = None,
40
- aggregator: Callable[[pd.DataFrame], Union[float, pd.Series]] = len,
35
+ requires_columns: list[str] = [],
36
+ requires_values: list[str] = [],
37
+ additional_stratifications: list[str] = [],
38
+ excluded_stratifications: list[str] = [],
39
+ aggregator_sources: list[str] | None = None,
40
+ aggregator: Callable[[pd.DataFrame], float | pd.Series] = len,
41
41
  ) -> None:
42
42
  """Registers an adding observation to the results system.
43
43
 
@@ -7,7 +7,7 @@ This module contains tools for observing risk exposure during the simulation.
7
7
 
8
8
  """
9
9
 
10
- from typing import Any, Dict, List, Optional
10
+ from typing import Any
11
11
 
12
12
  import pandas as pd
13
13
  from layered_config_tree import LayeredConfigTree
@@ -59,7 +59,7 @@ class CategoricalRiskObserver(PublicHealthObserver):
59
59
  ##############
60
60
 
61
61
  @property
62
- def configuration_defaults(self) -> Dict[str, Any]:
62
+ def configuration_defaults(self) -> dict[str, Any]:
63
63
  """A dictionary containing the defaults for any configurations managed by
64
64
  this component.
65
65
  """
@@ -72,7 +72,7 @@ class CategoricalRiskObserver(PublicHealthObserver):
72
72
  }
73
73
 
74
74
  @property
75
- def columns_required(self) -> Optional[List[str]]:
75
+ def columns_required(self) -> list[str] | None:
76
76
  """The columns required by this observer."""
77
77
  return ["alive"]
78
78
 
@@ -8,7 +8,7 @@ exposure.
8
8
 
9
9
  """
10
10
 
11
- from typing import Any, Dict, List
11
+ from typing import Any
12
12
 
13
13
  import pandas as pd
14
14
  from vivarium import Component
@@ -106,7 +106,7 @@ class Risk(Component):
106
106
  return self.risk
107
107
 
108
108
  @property
109
- def configuration_defaults(self) -> Dict[str, Any]:
109
+ def configuration_defaults(self) -> dict[str, Any]:
110
110
  return {
111
111
  self.name: {
112
112
  "data_sources": {
@@ -122,14 +122,14 @@ class Risk(Component):
122
122
  }
123
123
 
124
124
  @property
125
- def columns_created(self) -> List[str]:
125
+ def columns_created(self) -> list[str]:
126
126
  columns_to_create = [self.propensity_column_name]
127
127
  if self.create_exposure_column:
128
128
  columns_to_create.append(self.exposure_column_name)
129
129
  return columns_to_create
130
130
 
131
131
  @property
132
- def initialization_requirements(self) -> Dict[str, List[str]]:
132
+ def initialization_requirements(self) -> dict[str, list[str]]:
133
133
  return {
134
134
  "requires_columns": [],
135
135
  "requires_values": [],
@@ -9,7 +9,7 @@ exposure distributions.
9
9
  """
10
10
 
11
11
  from abc import ABC, abstractmethod
12
- from typing import Callable, Dict, List, Optional, Union
12
+ from collections.abc import Callable
13
13
 
14
14
  import numpy as np
15
15
  import pandas as pd
@@ -38,7 +38,7 @@ class RiskExposureDistribution(Component, ABC):
38
38
  self,
39
39
  risk: EntityString,
40
40
  distribution_type: str,
41
- exposure_data: Optional[Union[int, float, pd.DataFrame]] = None,
41
+ exposure_data: int | float | pd.DataFrame | None = None,
42
42
  ) -> None:
43
43
  super().__init__()
44
44
  self.risk = risk
@@ -51,14 +51,14 @@ class RiskExposureDistribution(Component, ABC):
51
51
  # Setup methods #
52
52
  #################
53
53
 
54
- def get_configuration(self, builder: "Builder") -> Optional[LayeredConfigTree]:
54
+ def get_configuration(self, builder: "Builder") -> LayeredConfigTree | None:
55
55
  return builder.configuration[self.risk]
56
56
 
57
57
  @abstractmethod
58
58
  def build_all_lookup_tables(self, builder: "Builder") -> None:
59
59
  raise NotImplementedError
60
60
 
61
- def get_exposure_data(self, builder: Builder) -> Union[int, float, pd.DataFrame]:
61
+ def get_exposure_data(self, builder: Builder) -> int | float | pd.DataFrame:
62
62
  if self._exposure_data is not None:
63
63
  return self._exposure_data
64
64
  return self.get_data(builder, self.configuration["data_sources"]["exposure"])
@@ -92,11 +92,11 @@ class EnsembleDistribution(RiskExposureDistribution):
92
92
  ##############
93
93
 
94
94
  @property
95
- def columns_created(self) -> List[str]:
95
+ def columns_created(self) -> list[str]:
96
96
  return [self._propensity]
97
97
 
98
98
  @property
99
- def initialization_requirements(self) -> Dict[str, List[str]]:
99
+ def initialization_requirements(self) -> dict[str, list[str]]:
100
100
  return {
101
101
  "requires_columns": [],
102
102
  "requires_values": [],
@@ -265,7 +265,7 @@ class ContinuousDistribution(RiskExposureDistribution):
265
265
 
266
266
  class PolytomousDistribution(RiskExposureDistribution):
267
267
  @property
268
- def categories(self) -> List[str]:
268
+ def categories(self) -> list[str]:
269
269
  # These need to be sorted so the cumulative sum is in the ocrrect order of categories
270
270
  # and results are therefore reproducible and correct
271
271
  return sorted(self.lookup_tables["exposure"].value_columns)
@@ -286,8 +286,8 @@ class PolytomousDistribution(RiskExposureDistribution):
286
286
  )
287
287
 
288
288
  def get_exposure_value_columns(
289
- self, exposure_data: Union[int, float, pd.DataFrame]
290
- ) -> Optional[List[str]]:
289
+ self, exposure_data: int | float | pd.DataFrame
290
+ ) -> list[str] | None:
291
291
  if isinstance(exposure_data, pd.DataFrame):
292
292
  return list(exposure_data["parameter"].unique())
293
293
  return None
@@ -342,7 +342,7 @@ class DichotomousDistribution(RiskExposureDistribution):
342
342
  )
343
343
  self.lookup_tables["paf"] = self.build_lookup_table(builder, 0.0)
344
344
 
345
- def get_exposure_data(self, builder: Builder) -> Union[int, float, pd.DataFrame]:
345
+ def get_exposure_data(self, builder: Builder) -> int | float | pd.DataFrame:
346
346
  exposure_data = super().get_exposure_data(builder)
347
347
 
348
348
  if isinstance(exposure_data, (int, float)):
@@ -373,8 +373,8 @@ class DichotomousDistribution(RiskExposureDistribution):
373
373
  return exposure_data
374
374
 
375
375
  def get_exposure_value_columns(
376
- self, exposure_data: Union[int, float, pd.DataFrame]
377
- ) -> Optional[List[str]]:
376
+ self, exposure_data: int | float | pd.DataFrame
377
+ ) -> list[str] | None:
378
378
  if isinstance(exposure_data, pd.DataFrame):
379
379
  return self.get_value_columns(exposure_data)
380
380
  return None
@@ -470,9 +470,9 @@ def clip(q):
470
470
 
471
471
 
472
472
  def get_risk_distribution_parameter(
473
- value_columns_getter: Callable[[Union[pd.DataFrame]], List[str]],
474
- data: Union[float, pd.DataFrame],
475
- ) -> Union[float, pd.Series]:
473
+ value_columns_getter: Callable[[pd.DataFrame], list[str]],
474
+ data: float | pd.DataFrame,
475
+ ) -> float | pd.Series:
476
476
  if isinstance(data, pd.DataFrame):
477
477
  value_columns = value_columns_getter(data)
478
478
  if len(value_columns) > 1: