vivarium-public-health 3.1.1__py3-none-any.whl → 3.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vivarium_public_health/_version.py +1 -1
- vivarium_public_health/disease/model.py +12 -20
- vivarium_public_health/disease/special_disease.py +5 -5
- vivarium_public_health/disease/state.py +18 -17
- vivarium_public_health/disease/transition.py +9 -8
- vivarium_public_health/mslt/delay.py +5 -5
- vivarium_public_health/mslt/disease.py +5 -5
- vivarium_public_health/mslt/intervention.py +8 -8
- vivarium_public_health/mslt/magic_wand_components.py +3 -3
- vivarium_public_health/mslt/observer.py +4 -6
- vivarium_public_health/mslt/population.py +4 -6
- vivarium_public_health/plugins/parser.py +18 -19
- vivarium_public_health/population/add_new_birth_cohorts.py +3 -5
- vivarium_public_health/population/base_population.py +8 -10
- vivarium_public_health/population/data_transformations.py +4 -5
- vivarium_public_health/population/mortality.py +6 -6
- vivarium_public_health/results/disability.py +1 -3
- vivarium_public_health/results/disease.py +5 -5
- vivarium_public_health/results/mortality.py +3 -3
- vivarium_public_health/results/observer.py +7 -7
- vivarium_public_health/results/risk.py +3 -3
- vivarium_public_health/risks/base_risk.py +4 -4
- vivarium_public_health/risks/distributions.py +15 -15
- vivarium_public_health/risks/effect.py +45 -33
- vivarium_public_health/risks/implementations/low_birth_weight_and_short_gestation.py +20 -20
- vivarium_public_health/treatment/magic_wand.py +3 -3
- vivarium_public_health/treatment/scale_up.py +6 -5
- vivarium_public_health/utilities.py +4 -5
- {vivarium_public_health-3.1.1.dist-info → vivarium_public_health-3.1.3.dist-info}/METADATA +32 -32
- vivarium_public_health-3.1.3.dist-info/RECORD +49 -0
- {vivarium_public_health-3.1.1.dist-info → vivarium_public_health-3.1.3.dist-info}/WHEEL +1 -1
- vivarium_public_health-3.1.1.dist-info/RECORD +0 -49
- {vivarium_public_health-3.1.1.dist-info → vivarium_public_health-3.1.3.dist-info}/LICENSE.txt +0 -0
- {vivarium_public_health-3.1.1.dist-info → vivarium_public_health-3.1.3.dist-info}/top_level.txt +0 -0
@@ -10,8 +10,9 @@ Health package.
|
|
10
10
|
|
11
11
|
"""
|
12
12
|
import warnings
|
13
|
+
from collections.abc import Callable
|
13
14
|
from importlib import import_module
|
14
|
-
from typing import Any
|
15
|
+
from typing import Any
|
15
16
|
|
16
17
|
import pandas as pd
|
17
18
|
from layered_config_tree import LayeredConfigTree
|
@@ -37,7 +38,7 @@ from vivarium_public_health.utilities import TargetString
|
|
37
38
|
class CausesParsingErrors(ParsingError):
|
38
39
|
"""Error raised when there are any errors parsing a cause model configuration."""
|
39
40
|
|
40
|
-
def __init__(self, messages:
|
41
|
+
def __init__(self, messages: list[str]):
|
41
42
|
super().__init__("\n - " + "\n - ".join(messages))
|
42
43
|
|
43
44
|
|
@@ -81,7 +82,7 @@ class CausesConfigurationParser(ComponentConfigurationParser):
|
|
81
82
|
This value is used if the transition configuration does not explicity specify it.
|
82
83
|
"""
|
83
84
|
|
84
|
-
def parse_component_config(self, component_config: LayeredConfigTree) ->
|
85
|
+
def parse_component_config(self, component_config: LayeredConfigTree) -> list[Component]:
|
85
86
|
"""Parses the component configuration and returns a list of components.
|
86
87
|
|
87
88
|
In particular, this method looks for an `external_configuration` key
|
@@ -216,7 +217,7 @@ class CausesConfigurationParser(ComponentConfigurationParser):
|
|
216
217
|
|
217
218
|
def _get_cause_model_components(
|
218
219
|
self, causes_config: LayeredConfigTree
|
219
|
-
) ->
|
220
|
+
) -> list[Component]:
|
220
221
|
"""Parses the cause model configuration and returns the `DiseaseModel` components.
|
221
222
|
|
222
223
|
Parameters
|
@@ -236,7 +237,7 @@ class CausesConfigurationParser(ComponentConfigurationParser):
|
|
236
237
|
data_sources_config = cause_config.data_sources
|
237
238
|
data_sources = self._get_data_sources(data_sources_config)
|
238
239
|
|
239
|
-
states:
|
240
|
+
states: dict[str, BaseDiseaseState] = {
|
240
241
|
state_name: self._get_state(state_name, state_config, cause_name)
|
241
242
|
for state_name, state_config in cause_config.states.items()
|
242
243
|
}
|
@@ -351,7 +352,7 @@ class CausesConfigurationParser(ComponentConfigurationParser):
|
|
351
352
|
|
352
353
|
def _get_data_sources(
|
353
354
|
self, config: LayeredConfigTree
|
354
|
-
) ->
|
355
|
+
) -> dict[str, Callable[[Builder, Any], Any]]:
|
355
356
|
"""Parses a data sources configuration and returns the data sources.
|
356
357
|
|
357
358
|
Parameters
|
@@ -366,9 +367,7 @@ class CausesConfigurationParser(ComponentConfigurationParser):
|
|
366
367
|
return {name: self._get_data_source(name, config[name]) for name in config.keys()}
|
367
368
|
|
368
369
|
@staticmethod
|
369
|
-
def _get_data_source(
|
370
|
-
name: str, source: Union[str, float]
|
371
|
-
) -> Callable[[Builder, Any], Any]:
|
370
|
+
def _get_data_source(name: str, source: str | float) -> Callable[[Builder, Any], Any]:
|
372
371
|
"""Parses a data source and returns a callable that can be used to retrieve the data.
|
373
372
|
|
374
373
|
Parameters
|
@@ -500,7 +499,7 @@ class CausesConfigurationParser(ComponentConfigurationParser):
|
|
500
499
|
if error_messages:
|
501
500
|
raise CausesParsingErrors(error_messages)
|
502
501
|
|
503
|
-
def _validate_cause(self, cause_name: str, cause_config:
|
502
|
+
def _validate_cause(self, cause_name: str, cause_config: dict[str, Any]) -> list[str]:
|
504
503
|
"""Validates a cause configuration and returns a list of error messages.
|
505
504
|
|
506
505
|
Parameters
|
@@ -591,8 +590,8 @@ class CausesConfigurationParser(ComponentConfigurationParser):
|
|
591
590
|
return error_messages
|
592
591
|
|
593
592
|
def _validate_state(
|
594
|
-
self, cause_name: str, state_name: str, state_config:
|
595
|
-
) ->
|
593
|
+
self, cause_name: str, state_name: str, state_config: dict[str, Any]
|
594
|
+
) -> list[str]:
|
596
595
|
"""Validates a state configuration and returns a list of error messages.
|
597
596
|
|
598
597
|
Parameters
|
@@ -683,9 +682,9 @@ class CausesConfigurationParser(ComponentConfigurationParser):
|
|
683
682
|
self,
|
684
683
|
cause_name: str,
|
685
684
|
transition_name: str,
|
686
|
-
transition_config:
|
687
|
-
states_config:
|
688
|
-
) ->
|
685
|
+
transition_config: dict[str, Any],
|
686
|
+
states_config: dict[str, Any],
|
687
|
+
) -> list[str]:
|
689
688
|
"""Validates a transition configuration and returns a list of error messages.
|
690
689
|
|
691
690
|
Parameters
|
@@ -783,8 +782,8 @@ class CausesConfigurationParser(ComponentConfigurationParser):
|
|
783
782
|
|
784
783
|
@staticmethod
|
785
784
|
def _validate_imported_type(
|
786
|
-
import_path: str, cause_name: str, entity_type: str, entity_name:
|
787
|
-
) ->
|
785
|
+
import_path: str, cause_name: str, entity_type: str, entity_name: str | None = None
|
786
|
+
) -> list[str]:
|
788
787
|
"""Validates an imported type and returns a list of error messages.
|
789
788
|
|
790
789
|
Parameters
|
@@ -825,8 +824,8 @@ class CausesConfigurationParser(ComponentConfigurationParser):
|
|
825
824
|
return error_messages
|
826
825
|
|
827
826
|
def _validate_data_sources(
|
828
|
-
self, config:
|
829
|
-
) ->
|
827
|
+
self, config: dict[str, Any], cause_name: str, config_type: str, config_name: str
|
828
|
+
) -> list[str]:
|
830
829
|
"""Validates the data sources in a configuration and returns any error messages.
|
831
830
|
|
832
831
|
Parameters
|
@@ -7,8 +7,6 @@ This module contains several different models of fertility.
|
|
7
7
|
|
8
8
|
"""
|
9
9
|
|
10
|
-
from typing import Dict, List, Optional
|
11
|
-
|
12
10
|
import numpy as np
|
13
11
|
import pandas as pd
|
14
12
|
from vivarium import Component
|
@@ -158,15 +156,15 @@ class FertilityAgeSpecificRates(Component):
|
|
158
156
|
##############
|
159
157
|
|
160
158
|
@property
|
161
|
-
def columns_created(self) ->
|
159
|
+
def columns_created(self) -> list[str]:
|
162
160
|
return ["last_birth_time", "parent_id"]
|
163
161
|
|
164
162
|
@property
|
165
|
-
def columns_required(self) ->
|
163
|
+
def columns_required(self) -> list[str] | None:
|
166
164
|
return ["sex"]
|
167
165
|
|
168
166
|
@property
|
169
|
-
def initialization_requirements(self) ->
|
167
|
+
def initialization_requirements(self) -> dict[str, list[str]]:
|
170
168
|
return {
|
171
169
|
"requires_columns": ["sex"],
|
172
170
|
"requires_values": [],
|
@@ -8,9 +8,7 @@ characteristics to simulants.
|
|
8
8
|
|
9
9
|
"""
|
10
10
|
|
11
|
-
from
|
12
|
-
|
13
|
-
from typing import Callable, Dict, Iterable, List
|
11
|
+
from collections.abc import Callable, Iterable
|
14
12
|
|
15
13
|
import numpy as np
|
16
14
|
import pandas as pd
|
@@ -48,7 +46,7 @@ class BasePopulation(Component):
|
|
48
46
|
##############
|
49
47
|
|
50
48
|
@property
|
51
|
-
def columns_created(self) ->
|
49
|
+
def columns_created(self) -> list[str]:
|
52
50
|
return ["age", "sex", "alive", "location", "entrance_time", "exit_time"]
|
53
51
|
|
54
52
|
@property
|
@@ -91,7 +89,7 @@ class BasePopulation(Component):
|
|
91
89
|
#################
|
92
90
|
|
93
91
|
@staticmethod
|
94
|
-
def get_randomness_streams(builder: Builder) ->
|
92
|
+
def get_randomness_streams(builder: Builder) -> dict[str, RandomnessStream]:
|
95
93
|
return {
|
96
94
|
"general_purpose": builder.randomness.get_stream("population_generation"),
|
97
95
|
"bin_selection": builder.randomness.get_stream(
|
@@ -255,7 +253,7 @@ class AgeOutSimulants(Component):
|
|
255
253
|
##############
|
256
254
|
|
257
255
|
@property
|
258
|
-
def columns_required(self) ->
|
256
|
+
def columns_required(self) -> list[str]:
|
259
257
|
"""A list of the columns this component requires that it did not create."""
|
260
258
|
return self._columns_required
|
261
259
|
|
@@ -290,9 +288,9 @@ def generate_population(
|
|
290
288
|
simulant_ids: pd.Index,
|
291
289
|
creation_time: pd.Timestamp,
|
292
290
|
step_size: pd.Timedelta,
|
293
|
-
age_params:
|
291
|
+
age_params: dict[str, float],
|
294
292
|
demographic_proportions: pd.DataFrame,
|
295
|
-
randomness_streams:
|
293
|
+
randomness_streams: dict[str, RandomnessStream],
|
296
294
|
register_simulants: Callable[[pd.DataFrame], None],
|
297
295
|
key_columns: Iterable[str] = ("entrance_time", "age"),
|
298
296
|
) -> pd.DataFrame:
|
@@ -378,7 +376,7 @@ def _assign_demography_with_initial_age(
|
|
378
376
|
pop_data: pd.DataFrame,
|
379
377
|
initial_age: float,
|
380
378
|
step_size: pd.Timedelta,
|
381
|
-
randomness_streams:
|
379
|
+
randomness_streams: dict[str, RandomnessStream],
|
382
380
|
register_simulants: Callable[[pd.DataFrame], None],
|
383
381
|
) -> pd.DataFrame:
|
384
382
|
"""Assigns age, sex, and location information to the provided simulants given a fixed age.
|
@@ -441,7 +439,7 @@ def _assign_demography_with_age_bounds(
|
|
441
439
|
pop_data: pd.DataFrame,
|
442
440
|
age_start: float,
|
443
441
|
age_end: float,
|
444
|
-
randomness_streams:
|
442
|
+
randomness_streams: dict[str, RandomnessStream],
|
445
443
|
register_simulants: Callable[[pd.DataFrame], None],
|
446
444
|
key_columns: Iterable[str] = ("entrance_time", "age"),
|
447
445
|
) -> pd.DataFrame:
|
@@ -9,7 +9,6 @@ it into different distributions for sampling.
|
|
9
9
|
"""
|
10
10
|
|
11
11
|
from collections import namedtuple
|
12
|
-
from typing import Tuple, Union
|
13
12
|
|
14
13
|
import numpy as np
|
15
14
|
import pandas as pd
|
@@ -302,7 +301,7 @@ def smooth_ages(
|
|
302
301
|
def _get_bins_and_proportions(
|
303
302
|
pop_data: pd.DataFrame,
|
304
303
|
age: AgeValues,
|
305
|
-
) ->
|
304
|
+
) -> tuple[EndpointValues, AgeValues]:
|
306
305
|
"""Finds and returns the bin edges and the population proportions in
|
307
306
|
the current and neighboring bins.
|
308
307
|
|
@@ -376,7 +375,7 @@ def _get_bins_and_proportions(
|
|
376
375
|
|
377
376
|
def _construct_sampling_parameters(
|
378
377
|
age: AgeValues, endpoint: EndpointValues, proportion: AgeValues
|
379
|
-
) ->
|
378
|
+
) -> tuple[EndpointValues, EndpointValues, float, float]:
|
380
379
|
"""Calculates some sampling distribution parameters from known values.
|
381
380
|
|
382
381
|
Parameters
|
@@ -442,12 +441,12 @@ def _construct_sampling_parameters(
|
|
442
441
|
|
443
442
|
|
444
443
|
def _compute_ages(
|
445
|
-
uniform_rv:
|
444
|
+
uniform_rv: np.ndarray | float,
|
446
445
|
start: float,
|
447
446
|
height: float,
|
448
447
|
slope: float,
|
449
448
|
normalization: float,
|
450
|
-
) ->
|
449
|
+
) -> np.ndarray | float:
|
451
450
|
"""Produces samples from the local age distribution.
|
452
451
|
|
453
452
|
Parameters
|
@@ -45,7 +45,7 @@ back the modified unmodeled csmr.
|
|
45
45
|
|
46
46
|
"""
|
47
47
|
|
48
|
-
from typing import Any
|
48
|
+
from typing import Any
|
49
49
|
|
50
50
|
import pandas as pd
|
51
51
|
from vivarium import Component
|
@@ -100,7 +100,7 @@ class Mortality(Component):
|
|
100
100
|
##############
|
101
101
|
|
102
102
|
@property
|
103
|
-
def configuration_defaults(self) ->
|
103
|
+
def configuration_defaults(self) -> dict[str, Any]:
|
104
104
|
return {
|
105
105
|
"mortality": {
|
106
106
|
"data_sources": {
|
@@ -113,18 +113,18 @@ class Mortality(Component):
|
|
113
113
|
}
|
114
114
|
|
115
115
|
@property
|
116
|
-
def standard_lookup_tables(self) ->
|
116
|
+
def standard_lookup_tables(self) -> list[str]:
|
117
117
|
return [
|
118
118
|
"all_cause_mortality_rate",
|
119
119
|
"life_expectancy",
|
120
120
|
]
|
121
121
|
|
122
122
|
@property
|
123
|
-
def columns_created(self) ->
|
123
|
+
def columns_created(self) -> list[str]:
|
124
124
|
return [self.cause_of_death_column_name, self.years_of_life_lost_column_name]
|
125
125
|
|
126
126
|
@property
|
127
|
-
def columns_required(self) ->
|
127
|
+
def columns_required(self) -> list[str] | None:
|
128
128
|
return ["alive", "exit_time", "age", "sex"]
|
129
129
|
|
130
130
|
@property
|
@@ -184,7 +184,7 @@ class Mortality(Component):
|
|
184
184
|
requires_columns=required_columns,
|
185
185
|
)
|
186
186
|
|
187
|
-
def load_unmodeled_csmr(self, builder: Builder) ->
|
187
|
+
def load_unmodeled_csmr(self, builder: Builder) -> float | pd.DataFrame:
|
188
188
|
# todo validate that all data have the same columns
|
189
189
|
raw_csmr = 0.0
|
190
190
|
for idx, cause in enumerate(builder.configuration[self.name].unmodeled_causes):
|
@@ -8,8 +8,6 @@ in the simulation.
|
|
8
8
|
|
9
9
|
"""
|
10
10
|
|
11
|
-
from typing import Union
|
12
|
-
|
13
11
|
import pandas as pd
|
14
12
|
from layered_config_tree import LayeredConfigTree
|
15
13
|
from loguru import logger
|
@@ -185,7 +183,7 @@ class DisabilityObserver(PublicHealthObserver):
|
|
185
183
|
# Aggregators #
|
186
184
|
###############
|
187
185
|
|
188
|
-
def disability_weight_aggregator(self, dw: pd.DataFrame) ->
|
186
|
+
def disability_weight_aggregator(self, dw: pd.DataFrame) -> float | pd.Series:
|
189
187
|
"""Aggregate disability weights for the time step.
|
190
188
|
|
191
189
|
Parameters
|
@@ -8,7 +8,7 @@ in the simulation.
|
|
8
8
|
|
9
9
|
"""
|
10
10
|
|
11
|
-
from typing import Any
|
11
|
+
from typing import Any
|
12
12
|
|
13
13
|
import pandas as pd
|
14
14
|
from layered_config_tree import LayeredConfigTree
|
@@ -66,7 +66,7 @@ class DiseaseObserver(PublicHealthObserver):
|
|
66
66
|
##############
|
67
67
|
|
68
68
|
@property
|
69
|
-
def configuration_defaults(self) ->
|
69
|
+
def configuration_defaults(self) -> dict[str, Any]:
|
70
70
|
"""A dictionary containing the defaults for any configurations managed by
|
71
71
|
this component.
|
72
72
|
"""
|
@@ -79,17 +79,17 @@ class DiseaseObserver(PublicHealthObserver):
|
|
79
79
|
}
|
80
80
|
|
81
81
|
@property
|
82
|
-
def columns_created(self) ->
|
82
|
+
def columns_created(self) -> list[str]:
|
83
83
|
"""Columns created by this observer."""
|
84
84
|
return [self.previous_state_column_name]
|
85
85
|
|
86
86
|
@property
|
87
|
-
def columns_required(self) ->
|
87
|
+
def columns_required(self) -> list[str]:
|
88
88
|
"""Columns required by this observer."""
|
89
89
|
return [self.disease]
|
90
90
|
|
91
91
|
@property
|
92
|
-
def initialization_requirements(self) ->
|
92
|
+
def initialization_requirements(self) -> dict[str, list[str]]:
|
93
93
|
"""Requirements for observer initialization."""
|
94
94
|
return {
|
95
95
|
"requires_columns": [self.disease],
|
@@ -8,7 +8,7 @@ excess mortality in the simulation, including "other causes".
|
|
8
8
|
|
9
9
|
"""
|
10
10
|
|
11
|
-
from typing import Any
|
11
|
+
from typing import Any
|
12
12
|
|
13
13
|
import pandas as pd
|
14
14
|
from layered_config_tree import LayeredConfigTree
|
@@ -81,7 +81,7 @@ class MortalityObserver(PublicHealthObserver):
|
|
81
81
|
return [DiseaseState, RiskAttributableDisease]
|
82
82
|
|
83
83
|
@property
|
84
|
-
def configuration_defaults(self) ->
|
84
|
+
def configuration_defaults(self) -> dict[str, Any]:
|
85
85
|
"""A dictionary containing the defaults for any configurations managed by
|
86
86
|
this component.
|
87
87
|
"""
|
@@ -90,7 +90,7 @@ class MortalityObserver(PublicHealthObserver):
|
|
90
90
|
return config_defaults
|
91
91
|
|
92
92
|
@property
|
93
|
-
def columns_required(self) ->
|
93
|
+
def columns_required(self) -> list[str]:
|
94
94
|
"""Columns required by this observer."""
|
95
95
|
return [
|
96
96
|
"alive",
|
@@ -8,7 +8,7 @@ public health models.
|
|
8
8
|
|
9
9
|
"""
|
10
10
|
|
11
|
-
from
|
11
|
+
from collections.abc import Callable
|
12
12
|
|
13
13
|
import pandas as pd
|
14
14
|
from vivarium.framework.engine import Builder
|
@@ -32,12 +32,12 @@ class PublicHealthObserver(Observer):
|
|
32
32
|
name: str,
|
33
33
|
pop_filter: str,
|
34
34
|
when: str = "collect_metrics",
|
35
|
-
requires_columns:
|
36
|
-
requires_values:
|
37
|
-
additional_stratifications:
|
38
|
-
excluded_stratifications:
|
39
|
-
aggregator_sources:
|
40
|
-
aggregator: Callable[[pd.DataFrame],
|
35
|
+
requires_columns: list[str] = [],
|
36
|
+
requires_values: list[str] = [],
|
37
|
+
additional_stratifications: list[str] = [],
|
38
|
+
excluded_stratifications: list[str] = [],
|
39
|
+
aggregator_sources: list[str] | None = None,
|
40
|
+
aggregator: Callable[[pd.DataFrame], float | pd.Series] = len,
|
41
41
|
) -> None:
|
42
42
|
"""Registers an adding observation to the results system.
|
43
43
|
|
@@ -7,7 +7,7 @@ This module contains tools for observing risk exposure during the simulation.
|
|
7
7
|
|
8
8
|
"""
|
9
9
|
|
10
|
-
from typing import Any
|
10
|
+
from typing import Any
|
11
11
|
|
12
12
|
import pandas as pd
|
13
13
|
from layered_config_tree import LayeredConfigTree
|
@@ -59,7 +59,7 @@ class CategoricalRiskObserver(PublicHealthObserver):
|
|
59
59
|
##############
|
60
60
|
|
61
61
|
@property
|
62
|
-
def configuration_defaults(self) ->
|
62
|
+
def configuration_defaults(self) -> dict[str, Any]:
|
63
63
|
"""A dictionary containing the defaults for any configurations managed by
|
64
64
|
this component.
|
65
65
|
"""
|
@@ -72,7 +72,7 @@ class CategoricalRiskObserver(PublicHealthObserver):
|
|
72
72
|
}
|
73
73
|
|
74
74
|
@property
|
75
|
-
def columns_required(self) ->
|
75
|
+
def columns_required(self) -> list[str] | None:
|
76
76
|
"""The columns required by this observer."""
|
77
77
|
return ["alive"]
|
78
78
|
|
@@ -8,7 +8,7 @@ exposure.
|
|
8
8
|
|
9
9
|
"""
|
10
10
|
|
11
|
-
from typing import Any
|
11
|
+
from typing import Any
|
12
12
|
|
13
13
|
import pandas as pd
|
14
14
|
from vivarium import Component
|
@@ -106,7 +106,7 @@ class Risk(Component):
|
|
106
106
|
return self.risk
|
107
107
|
|
108
108
|
@property
|
109
|
-
def configuration_defaults(self) ->
|
109
|
+
def configuration_defaults(self) -> dict[str, Any]:
|
110
110
|
return {
|
111
111
|
self.name: {
|
112
112
|
"data_sources": {
|
@@ -122,14 +122,14 @@ class Risk(Component):
|
|
122
122
|
}
|
123
123
|
|
124
124
|
@property
|
125
|
-
def columns_created(self) ->
|
125
|
+
def columns_created(self) -> list[str]:
|
126
126
|
columns_to_create = [self.propensity_column_name]
|
127
127
|
if self.create_exposure_column:
|
128
128
|
columns_to_create.append(self.exposure_column_name)
|
129
129
|
return columns_to_create
|
130
130
|
|
131
131
|
@property
|
132
|
-
def initialization_requirements(self) ->
|
132
|
+
def initialization_requirements(self) -> dict[str, list[str]]:
|
133
133
|
return {
|
134
134
|
"requires_columns": [],
|
135
135
|
"requires_values": [],
|
@@ -9,7 +9,7 @@ exposure distributions.
|
|
9
9
|
"""
|
10
10
|
|
11
11
|
from abc import ABC, abstractmethod
|
12
|
-
from
|
12
|
+
from collections.abc import Callable
|
13
13
|
|
14
14
|
import numpy as np
|
15
15
|
import pandas as pd
|
@@ -38,7 +38,7 @@ class RiskExposureDistribution(Component, ABC):
|
|
38
38
|
self,
|
39
39
|
risk: EntityString,
|
40
40
|
distribution_type: str,
|
41
|
-
exposure_data:
|
41
|
+
exposure_data: int | float | pd.DataFrame | None = None,
|
42
42
|
) -> None:
|
43
43
|
super().__init__()
|
44
44
|
self.risk = risk
|
@@ -51,14 +51,14 @@ class RiskExposureDistribution(Component, ABC):
|
|
51
51
|
# Setup methods #
|
52
52
|
#################
|
53
53
|
|
54
|
-
def get_configuration(self, builder: "Builder") ->
|
54
|
+
def get_configuration(self, builder: "Builder") -> LayeredConfigTree | None:
|
55
55
|
return builder.configuration[self.risk]
|
56
56
|
|
57
57
|
@abstractmethod
|
58
58
|
def build_all_lookup_tables(self, builder: "Builder") -> None:
|
59
59
|
raise NotImplementedError
|
60
60
|
|
61
|
-
def get_exposure_data(self, builder: Builder) ->
|
61
|
+
def get_exposure_data(self, builder: Builder) -> int | float | pd.DataFrame:
|
62
62
|
if self._exposure_data is not None:
|
63
63
|
return self._exposure_data
|
64
64
|
return self.get_data(builder, self.configuration["data_sources"]["exposure"])
|
@@ -92,11 +92,11 @@ class EnsembleDistribution(RiskExposureDistribution):
|
|
92
92
|
##############
|
93
93
|
|
94
94
|
@property
|
95
|
-
def columns_created(self) ->
|
95
|
+
def columns_created(self) -> list[str]:
|
96
96
|
return [self._propensity]
|
97
97
|
|
98
98
|
@property
|
99
|
-
def initialization_requirements(self) ->
|
99
|
+
def initialization_requirements(self) -> dict[str, list[str]]:
|
100
100
|
return {
|
101
101
|
"requires_columns": [],
|
102
102
|
"requires_values": [],
|
@@ -265,7 +265,7 @@ class ContinuousDistribution(RiskExposureDistribution):
|
|
265
265
|
|
266
266
|
class PolytomousDistribution(RiskExposureDistribution):
|
267
267
|
@property
|
268
|
-
def categories(self) ->
|
268
|
+
def categories(self) -> list[str]:
|
269
269
|
# These need to be sorted so the cumulative sum is in the ocrrect order of categories
|
270
270
|
# and results are therefore reproducible and correct
|
271
271
|
return sorted(self.lookup_tables["exposure"].value_columns)
|
@@ -286,8 +286,8 @@ class PolytomousDistribution(RiskExposureDistribution):
|
|
286
286
|
)
|
287
287
|
|
288
288
|
def get_exposure_value_columns(
|
289
|
-
self, exposure_data:
|
290
|
-
) ->
|
289
|
+
self, exposure_data: int | float | pd.DataFrame
|
290
|
+
) -> list[str] | None:
|
291
291
|
if isinstance(exposure_data, pd.DataFrame):
|
292
292
|
return list(exposure_data["parameter"].unique())
|
293
293
|
return None
|
@@ -342,7 +342,7 @@ class DichotomousDistribution(RiskExposureDistribution):
|
|
342
342
|
)
|
343
343
|
self.lookup_tables["paf"] = self.build_lookup_table(builder, 0.0)
|
344
344
|
|
345
|
-
def get_exposure_data(self, builder: Builder) ->
|
345
|
+
def get_exposure_data(self, builder: Builder) -> int | float | pd.DataFrame:
|
346
346
|
exposure_data = super().get_exposure_data(builder)
|
347
347
|
|
348
348
|
if isinstance(exposure_data, (int, float)):
|
@@ -373,8 +373,8 @@ class DichotomousDistribution(RiskExposureDistribution):
|
|
373
373
|
return exposure_data
|
374
374
|
|
375
375
|
def get_exposure_value_columns(
|
376
|
-
self, exposure_data:
|
377
|
-
) ->
|
376
|
+
self, exposure_data: int | float | pd.DataFrame
|
377
|
+
) -> list[str] | None:
|
378
378
|
if isinstance(exposure_data, pd.DataFrame):
|
379
379
|
return self.get_value_columns(exposure_data)
|
380
380
|
return None
|
@@ -470,9 +470,9 @@ def clip(q):
|
|
470
470
|
|
471
471
|
|
472
472
|
def get_risk_distribution_parameter(
|
473
|
-
value_columns_getter: Callable[[
|
474
|
-
data:
|
475
|
-
) ->
|
473
|
+
value_columns_getter: Callable[[pd.DataFrame], list[str]],
|
474
|
+
data: float | pd.DataFrame,
|
475
|
+
) -> float | pd.Series:
|
476
476
|
if isinstance(data, pd.DataFrame):
|
477
477
|
value_columns = value_columns_getter(data)
|
478
478
|
if len(value_columns) > 1:
|