vivarium-public-health 3.1.1__py3-none-any.whl → 3.1.3__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- vivarium_public_health/_version.py +1 -1
- vivarium_public_health/disease/model.py +12 -20
- vivarium_public_health/disease/special_disease.py +5 -5
- vivarium_public_health/disease/state.py +18 -17
- vivarium_public_health/disease/transition.py +9 -8
- vivarium_public_health/mslt/delay.py +5 -5
- vivarium_public_health/mslt/disease.py +5 -5
- vivarium_public_health/mslt/intervention.py +8 -8
- vivarium_public_health/mslt/magic_wand_components.py +3 -3
- vivarium_public_health/mslt/observer.py +4 -6
- vivarium_public_health/mslt/population.py +4 -6
- vivarium_public_health/plugins/parser.py +18 -19
- vivarium_public_health/population/add_new_birth_cohorts.py +3 -5
- vivarium_public_health/population/base_population.py +8 -10
- vivarium_public_health/population/data_transformations.py +4 -5
- vivarium_public_health/population/mortality.py +6 -6
- vivarium_public_health/results/disability.py +1 -3
- vivarium_public_health/results/disease.py +5 -5
- vivarium_public_health/results/mortality.py +3 -3
- vivarium_public_health/results/observer.py +7 -7
- vivarium_public_health/results/risk.py +3 -3
- vivarium_public_health/risks/base_risk.py +4 -4
- vivarium_public_health/risks/distributions.py +15 -15
- vivarium_public_health/risks/effect.py +45 -33
- vivarium_public_health/risks/implementations/low_birth_weight_and_short_gestation.py +20 -20
- vivarium_public_health/treatment/magic_wand.py +3 -3
- vivarium_public_health/treatment/scale_up.py +6 -5
- vivarium_public_health/utilities.py +4 -5
- {vivarium_public_health-3.1.1.dist-info → vivarium_public_health-3.1.3.dist-info}/METADATA +32 -32
- vivarium_public_health-3.1.3.dist-info/RECORD +49 -0
- {vivarium_public_health-3.1.1.dist-info → vivarium_public_health-3.1.3.dist-info}/WHEEL +1 -1
- vivarium_public_health-3.1.1.dist-info/RECORD +0 -49
- {vivarium_public_health-3.1.1.dist-info → vivarium_public_health-3.1.3.dist-info}/LICENSE.txt +0 -0
- {vivarium_public_health-3.1.1.dist-info → vivarium_public_health-3.1.3.dist-info}/top_level.txt +0 -0
@@ -10,8 +10,9 @@ Health package.
|
|
10
10
|
|
11
11
|
"""
|
12
12
|
import warnings
|
13
|
+
from collections.abc import Callable
|
13
14
|
from importlib import import_module
|
14
|
-
from typing import Any
|
15
|
+
from typing import Any
|
15
16
|
|
16
17
|
import pandas as pd
|
17
18
|
from layered_config_tree import LayeredConfigTree
|
@@ -37,7 +38,7 @@ from vivarium_public_health.utilities import TargetString
|
|
37
38
|
class CausesParsingErrors(ParsingError):
|
38
39
|
"""Error raised when there are any errors parsing a cause model configuration."""
|
39
40
|
|
40
|
-
def __init__(self, messages:
|
41
|
+
def __init__(self, messages: list[str]):
|
41
42
|
super().__init__("\n - " + "\n - ".join(messages))
|
42
43
|
|
43
44
|
|
@@ -81,7 +82,7 @@ class CausesConfigurationParser(ComponentConfigurationParser):
|
|
81
82
|
This value is used if the transition configuration does not explicity specify it.
|
82
83
|
"""
|
83
84
|
|
84
|
-
def parse_component_config(self, component_config: LayeredConfigTree) ->
|
85
|
+
def parse_component_config(self, component_config: LayeredConfigTree) -> list[Component]:
|
85
86
|
"""Parses the component configuration and returns a list of components.
|
86
87
|
|
87
88
|
In particular, this method looks for an `external_configuration` key
|
@@ -216,7 +217,7 @@ class CausesConfigurationParser(ComponentConfigurationParser):
|
|
216
217
|
|
217
218
|
def _get_cause_model_components(
|
218
219
|
self, causes_config: LayeredConfigTree
|
219
|
-
) ->
|
220
|
+
) -> list[Component]:
|
220
221
|
"""Parses the cause model configuration and returns the `DiseaseModel` components.
|
221
222
|
|
222
223
|
Parameters
|
@@ -236,7 +237,7 @@ class CausesConfigurationParser(ComponentConfigurationParser):
|
|
236
237
|
data_sources_config = cause_config.data_sources
|
237
238
|
data_sources = self._get_data_sources(data_sources_config)
|
238
239
|
|
239
|
-
states:
|
240
|
+
states: dict[str, BaseDiseaseState] = {
|
240
241
|
state_name: self._get_state(state_name, state_config, cause_name)
|
241
242
|
for state_name, state_config in cause_config.states.items()
|
242
243
|
}
|
@@ -351,7 +352,7 @@ class CausesConfigurationParser(ComponentConfigurationParser):
|
|
351
352
|
|
352
353
|
def _get_data_sources(
|
353
354
|
self, config: LayeredConfigTree
|
354
|
-
) ->
|
355
|
+
) -> dict[str, Callable[[Builder, Any], Any]]:
|
355
356
|
"""Parses a data sources configuration and returns the data sources.
|
356
357
|
|
357
358
|
Parameters
|
@@ -366,9 +367,7 @@ class CausesConfigurationParser(ComponentConfigurationParser):
|
|
366
367
|
return {name: self._get_data_source(name, config[name]) for name in config.keys()}
|
367
368
|
|
368
369
|
@staticmethod
|
369
|
-
def _get_data_source(
|
370
|
-
name: str, source: Union[str, float]
|
371
|
-
) -> Callable[[Builder, Any], Any]:
|
370
|
+
def _get_data_source(name: str, source: str | float) -> Callable[[Builder, Any], Any]:
|
372
371
|
"""Parses a data source and returns a callable that can be used to retrieve the data.
|
373
372
|
|
374
373
|
Parameters
|
@@ -500,7 +499,7 @@ class CausesConfigurationParser(ComponentConfigurationParser):
|
|
500
499
|
if error_messages:
|
501
500
|
raise CausesParsingErrors(error_messages)
|
502
501
|
|
503
|
-
def _validate_cause(self, cause_name: str, cause_config:
|
502
|
+
def _validate_cause(self, cause_name: str, cause_config: dict[str, Any]) -> list[str]:
|
504
503
|
"""Validates a cause configuration and returns a list of error messages.
|
505
504
|
|
506
505
|
Parameters
|
@@ -591,8 +590,8 @@ class CausesConfigurationParser(ComponentConfigurationParser):
|
|
591
590
|
return error_messages
|
592
591
|
|
593
592
|
def _validate_state(
|
594
|
-
self, cause_name: str, state_name: str, state_config:
|
595
|
-
) ->
|
593
|
+
self, cause_name: str, state_name: str, state_config: dict[str, Any]
|
594
|
+
) -> list[str]:
|
596
595
|
"""Validates a state configuration and returns a list of error messages.
|
597
596
|
|
598
597
|
Parameters
|
@@ -683,9 +682,9 @@ class CausesConfigurationParser(ComponentConfigurationParser):
|
|
683
682
|
self,
|
684
683
|
cause_name: str,
|
685
684
|
transition_name: str,
|
686
|
-
transition_config:
|
687
|
-
states_config:
|
688
|
-
) ->
|
685
|
+
transition_config: dict[str, Any],
|
686
|
+
states_config: dict[str, Any],
|
687
|
+
) -> list[str]:
|
689
688
|
"""Validates a transition configuration and returns a list of error messages.
|
690
689
|
|
691
690
|
Parameters
|
@@ -783,8 +782,8 @@ class CausesConfigurationParser(ComponentConfigurationParser):
|
|
783
782
|
|
784
783
|
@staticmethod
|
785
784
|
def _validate_imported_type(
|
786
|
-
import_path: str, cause_name: str, entity_type: str, entity_name:
|
787
|
-
) ->
|
785
|
+
import_path: str, cause_name: str, entity_type: str, entity_name: str | None = None
|
786
|
+
) -> list[str]:
|
788
787
|
"""Validates an imported type and returns a list of error messages.
|
789
788
|
|
790
789
|
Parameters
|
@@ -825,8 +824,8 @@ class CausesConfigurationParser(ComponentConfigurationParser):
|
|
825
824
|
return error_messages
|
826
825
|
|
827
826
|
def _validate_data_sources(
|
828
|
-
self, config:
|
829
|
-
) ->
|
827
|
+
self, config: dict[str, Any], cause_name: str, config_type: str, config_name: str
|
828
|
+
) -> list[str]:
|
830
829
|
"""Validates the data sources in a configuration and returns any error messages.
|
831
830
|
|
832
831
|
Parameters
|
@@ -7,8 +7,6 @@ This module contains several different models of fertility.
|
|
7
7
|
|
8
8
|
"""
|
9
9
|
|
10
|
-
from typing import Dict, List, Optional
|
11
|
-
|
12
10
|
import numpy as np
|
13
11
|
import pandas as pd
|
14
12
|
from vivarium import Component
|
@@ -158,15 +156,15 @@ class FertilityAgeSpecificRates(Component):
|
|
158
156
|
##############
|
159
157
|
|
160
158
|
@property
|
161
|
-
def columns_created(self) ->
|
159
|
+
def columns_created(self) -> list[str]:
|
162
160
|
return ["last_birth_time", "parent_id"]
|
163
161
|
|
164
162
|
@property
|
165
|
-
def columns_required(self) ->
|
163
|
+
def columns_required(self) -> list[str] | None:
|
166
164
|
return ["sex"]
|
167
165
|
|
168
166
|
@property
|
169
|
-
def initialization_requirements(self) ->
|
167
|
+
def initialization_requirements(self) -> dict[str, list[str]]:
|
170
168
|
return {
|
171
169
|
"requires_columns": ["sex"],
|
172
170
|
"requires_values": [],
|
@@ -8,9 +8,7 @@ characteristics to simulants.
|
|
8
8
|
|
9
9
|
"""
|
10
10
|
|
11
|
-
from
|
12
|
-
|
13
|
-
from typing import Callable, Dict, Iterable, List
|
11
|
+
from collections.abc import Callable, Iterable
|
14
12
|
|
15
13
|
import numpy as np
|
16
14
|
import pandas as pd
|
@@ -48,7 +46,7 @@ class BasePopulation(Component):
|
|
48
46
|
##############
|
49
47
|
|
50
48
|
@property
|
51
|
-
def columns_created(self) ->
|
49
|
+
def columns_created(self) -> list[str]:
|
52
50
|
return ["age", "sex", "alive", "location", "entrance_time", "exit_time"]
|
53
51
|
|
54
52
|
@property
|
@@ -91,7 +89,7 @@ class BasePopulation(Component):
|
|
91
89
|
#################
|
92
90
|
|
93
91
|
@staticmethod
|
94
|
-
def get_randomness_streams(builder: Builder) ->
|
92
|
+
def get_randomness_streams(builder: Builder) -> dict[str, RandomnessStream]:
|
95
93
|
return {
|
96
94
|
"general_purpose": builder.randomness.get_stream("population_generation"),
|
97
95
|
"bin_selection": builder.randomness.get_stream(
|
@@ -255,7 +253,7 @@ class AgeOutSimulants(Component):
|
|
255
253
|
##############
|
256
254
|
|
257
255
|
@property
|
258
|
-
def columns_required(self) ->
|
256
|
+
def columns_required(self) -> list[str]:
|
259
257
|
"""A list of the columns this component requires that it did not create."""
|
260
258
|
return self._columns_required
|
261
259
|
|
@@ -290,9 +288,9 @@ def generate_population(
|
|
290
288
|
simulant_ids: pd.Index,
|
291
289
|
creation_time: pd.Timestamp,
|
292
290
|
step_size: pd.Timedelta,
|
293
|
-
age_params:
|
291
|
+
age_params: dict[str, float],
|
294
292
|
demographic_proportions: pd.DataFrame,
|
295
|
-
randomness_streams:
|
293
|
+
randomness_streams: dict[str, RandomnessStream],
|
296
294
|
register_simulants: Callable[[pd.DataFrame], None],
|
297
295
|
key_columns: Iterable[str] = ("entrance_time", "age"),
|
298
296
|
) -> pd.DataFrame:
|
@@ -378,7 +376,7 @@ def _assign_demography_with_initial_age(
|
|
378
376
|
pop_data: pd.DataFrame,
|
379
377
|
initial_age: float,
|
380
378
|
step_size: pd.Timedelta,
|
381
|
-
randomness_streams:
|
379
|
+
randomness_streams: dict[str, RandomnessStream],
|
382
380
|
register_simulants: Callable[[pd.DataFrame], None],
|
383
381
|
) -> pd.DataFrame:
|
384
382
|
"""Assigns age, sex, and location information to the provided simulants given a fixed age.
|
@@ -441,7 +439,7 @@ def _assign_demography_with_age_bounds(
|
|
441
439
|
pop_data: pd.DataFrame,
|
442
440
|
age_start: float,
|
443
441
|
age_end: float,
|
444
|
-
randomness_streams:
|
442
|
+
randomness_streams: dict[str, RandomnessStream],
|
445
443
|
register_simulants: Callable[[pd.DataFrame], None],
|
446
444
|
key_columns: Iterable[str] = ("entrance_time", "age"),
|
447
445
|
) -> pd.DataFrame:
|
@@ -9,7 +9,6 @@ it into different distributions for sampling.
|
|
9
9
|
"""
|
10
10
|
|
11
11
|
from collections import namedtuple
|
12
|
-
from typing import Tuple, Union
|
13
12
|
|
14
13
|
import numpy as np
|
15
14
|
import pandas as pd
|
@@ -302,7 +301,7 @@ def smooth_ages(
|
|
302
301
|
def _get_bins_and_proportions(
|
303
302
|
pop_data: pd.DataFrame,
|
304
303
|
age: AgeValues,
|
305
|
-
) ->
|
304
|
+
) -> tuple[EndpointValues, AgeValues]:
|
306
305
|
"""Finds and returns the bin edges and the population proportions in
|
307
306
|
the current and neighboring bins.
|
308
307
|
|
@@ -376,7 +375,7 @@ def _get_bins_and_proportions(
|
|
376
375
|
|
377
376
|
def _construct_sampling_parameters(
|
378
377
|
age: AgeValues, endpoint: EndpointValues, proportion: AgeValues
|
379
|
-
) ->
|
378
|
+
) -> tuple[EndpointValues, EndpointValues, float, float]:
|
380
379
|
"""Calculates some sampling distribution parameters from known values.
|
381
380
|
|
382
381
|
Parameters
|
@@ -442,12 +441,12 @@ def _construct_sampling_parameters(
|
|
442
441
|
|
443
442
|
|
444
443
|
def _compute_ages(
|
445
|
-
uniform_rv:
|
444
|
+
uniform_rv: np.ndarray | float,
|
446
445
|
start: float,
|
447
446
|
height: float,
|
448
447
|
slope: float,
|
449
448
|
normalization: float,
|
450
|
-
) ->
|
449
|
+
) -> np.ndarray | float:
|
451
450
|
"""Produces samples from the local age distribution.
|
452
451
|
|
453
452
|
Parameters
|
@@ -45,7 +45,7 @@ back the modified unmodeled csmr.
|
|
45
45
|
|
46
46
|
"""
|
47
47
|
|
48
|
-
from typing import Any
|
48
|
+
from typing import Any
|
49
49
|
|
50
50
|
import pandas as pd
|
51
51
|
from vivarium import Component
|
@@ -100,7 +100,7 @@ class Mortality(Component):
|
|
100
100
|
##############
|
101
101
|
|
102
102
|
@property
|
103
|
-
def configuration_defaults(self) ->
|
103
|
+
def configuration_defaults(self) -> dict[str, Any]:
|
104
104
|
return {
|
105
105
|
"mortality": {
|
106
106
|
"data_sources": {
|
@@ -113,18 +113,18 @@ class Mortality(Component):
|
|
113
113
|
}
|
114
114
|
|
115
115
|
@property
|
116
|
-
def standard_lookup_tables(self) ->
|
116
|
+
def standard_lookup_tables(self) -> list[str]:
|
117
117
|
return [
|
118
118
|
"all_cause_mortality_rate",
|
119
119
|
"life_expectancy",
|
120
120
|
]
|
121
121
|
|
122
122
|
@property
|
123
|
-
def columns_created(self) ->
|
123
|
+
def columns_created(self) -> list[str]:
|
124
124
|
return [self.cause_of_death_column_name, self.years_of_life_lost_column_name]
|
125
125
|
|
126
126
|
@property
|
127
|
-
def columns_required(self) ->
|
127
|
+
def columns_required(self) -> list[str] | None:
|
128
128
|
return ["alive", "exit_time", "age", "sex"]
|
129
129
|
|
130
130
|
@property
|
@@ -184,7 +184,7 @@ class Mortality(Component):
|
|
184
184
|
requires_columns=required_columns,
|
185
185
|
)
|
186
186
|
|
187
|
-
def load_unmodeled_csmr(self, builder: Builder) ->
|
187
|
+
def load_unmodeled_csmr(self, builder: Builder) -> float | pd.DataFrame:
|
188
188
|
# todo validate that all data have the same columns
|
189
189
|
raw_csmr = 0.0
|
190
190
|
for idx, cause in enumerate(builder.configuration[self.name].unmodeled_causes):
|
@@ -8,8 +8,6 @@ in the simulation.
|
|
8
8
|
|
9
9
|
"""
|
10
10
|
|
11
|
-
from typing import Union
|
12
|
-
|
13
11
|
import pandas as pd
|
14
12
|
from layered_config_tree import LayeredConfigTree
|
15
13
|
from loguru import logger
|
@@ -185,7 +183,7 @@ class DisabilityObserver(PublicHealthObserver):
|
|
185
183
|
# Aggregators #
|
186
184
|
###############
|
187
185
|
|
188
|
-
def disability_weight_aggregator(self, dw: pd.DataFrame) ->
|
186
|
+
def disability_weight_aggregator(self, dw: pd.DataFrame) -> float | pd.Series:
|
189
187
|
"""Aggregate disability weights for the time step.
|
190
188
|
|
191
189
|
Parameters
|
@@ -8,7 +8,7 @@ in the simulation.
|
|
8
8
|
|
9
9
|
"""
|
10
10
|
|
11
|
-
from typing import Any
|
11
|
+
from typing import Any
|
12
12
|
|
13
13
|
import pandas as pd
|
14
14
|
from layered_config_tree import LayeredConfigTree
|
@@ -66,7 +66,7 @@ class DiseaseObserver(PublicHealthObserver):
|
|
66
66
|
##############
|
67
67
|
|
68
68
|
@property
|
69
|
-
def configuration_defaults(self) ->
|
69
|
+
def configuration_defaults(self) -> dict[str, Any]:
|
70
70
|
"""A dictionary containing the defaults for any configurations managed by
|
71
71
|
this component.
|
72
72
|
"""
|
@@ -79,17 +79,17 @@ class DiseaseObserver(PublicHealthObserver):
|
|
79
79
|
}
|
80
80
|
|
81
81
|
@property
|
82
|
-
def columns_created(self) ->
|
82
|
+
def columns_created(self) -> list[str]:
|
83
83
|
"""Columns created by this observer."""
|
84
84
|
return [self.previous_state_column_name]
|
85
85
|
|
86
86
|
@property
|
87
|
-
def columns_required(self) ->
|
87
|
+
def columns_required(self) -> list[str]:
|
88
88
|
"""Columns required by this observer."""
|
89
89
|
return [self.disease]
|
90
90
|
|
91
91
|
@property
|
92
|
-
def initialization_requirements(self) ->
|
92
|
+
def initialization_requirements(self) -> dict[str, list[str]]:
|
93
93
|
"""Requirements for observer initialization."""
|
94
94
|
return {
|
95
95
|
"requires_columns": [self.disease],
|
@@ -8,7 +8,7 @@ excess mortality in the simulation, including "other causes".
|
|
8
8
|
|
9
9
|
"""
|
10
10
|
|
11
|
-
from typing import Any
|
11
|
+
from typing import Any
|
12
12
|
|
13
13
|
import pandas as pd
|
14
14
|
from layered_config_tree import LayeredConfigTree
|
@@ -81,7 +81,7 @@ class MortalityObserver(PublicHealthObserver):
|
|
81
81
|
return [DiseaseState, RiskAttributableDisease]
|
82
82
|
|
83
83
|
@property
|
84
|
-
def configuration_defaults(self) ->
|
84
|
+
def configuration_defaults(self) -> dict[str, Any]:
|
85
85
|
"""A dictionary containing the defaults for any configurations managed by
|
86
86
|
this component.
|
87
87
|
"""
|
@@ -90,7 +90,7 @@ class MortalityObserver(PublicHealthObserver):
|
|
90
90
|
return config_defaults
|
91
91
|
|
92
92
|
@property
|
93
|
-
def columns_required(self) ->
|
93
|
+
def columns_required(self) -> list[str]:
|
94
94
|
"""Columns required by this observer."""
|
95
95
|
return [
|
96
96
|
"alive",
|
@@ -8,7 +8,7 @@ public health models.
|
|
8
8
|
|
9
9
|
"""
|
10
10
|
|
11
|
-
from
|
11
|
+
from collections.abc import Callable
|
12
12
|
|
13
13
|
import pandas as pd
|
14
14
|
from vivarium.framework.engine import Builder
|
@@ -32,12 +32,12 @@ class PublicHealthObserver(Observer):
|
|
32
32
|
name: str,
|
33
33
|
pop_filter: str,
|
34
34
|
when: str = "collect_metrics",
|
35
|
-
requires_columns:
|
36
|
-
requires_values:
|
37
|
-
additional_stratifications:
|
38
|
-
excluded_stratifications:
|
39
|
-
aggregator_sources:
|
40
|
-
aggregator: Callable[[pd.DataFrame],
|
35
|
+
requires_columns: list[str] = [],
|
36
|
+
requires_values: list[str] = [],
|
37
|
+
additional_stratifications: list[str] = [],
|
38
|
+
excluded_stratifications: list[str] = [],
|
39
|
+
aggregator_sources: list[str] | None = None,
|
40
|
+
aggregator: Callable[[pd.DataFrame], float | pd.Series] = len,
|
41
41
|
) -> None:
|
42
42
|
"""Registers an adding observation to the results system.
|
43
43
|
|
@@ -7,7 +7,7 @@ This module contains tools for observing risk exposure during the simulation.
|
|
7
7
|
|
8
8
|
"""
|
9
9
|
|
10
|
-
from typing import Any
|
10
|
+
from typing import Any
|
11
11
|
|
12
12
|
import pandas as pd
|
13
13
|
from layered_config_tree import LayeredConfigTree
|
@@ -59,7 +59,7 @@ class CategoricalRiskObserver(PublicHealthObserver):
|
|
59
59
|
##############
|
60
60
|
|
61
61
|
@property
|
62
|
-
def configuration_defaults(self) ->
|
62
|
+
def configuration_defaults(self) -> dict[str, Any]:
|
63
63
|
"""A dictionary containing the defaults for any configurations managed by
|
64
64
|
this component.
|
65
65
|
"""
|
@@ -72,7 +72,7 @@ class CategoricalRiskObserver(PublicHealthObserver):
|
|
72
72
|
}
|
73
73
|
|
74
74
|
@property
|
75
|
-
def columns_required(self) ->
|
75
|
+
def columns_required(self) -> list[str] | None:
|
76
76
|
"""The columns required by this observer."""
|
77
77
|
return ["alive"]
|
78
78
|
|
@@ -8,7 +8,7 @@ exposure.
|
|
8
8
|
|
9
9
|
"""
|
10
10
|
|
11
|
-
from typing import Any
|
11
|
+
from typing import Any
|
12
12
|
|
13
13
|
import pandas as pd
|
14
14
|
from vivarium import Component
|
@@ -106,7 +106,7 @@ class Risk(Component):
|
|
106
106
|
return self.risk
|
107
107
|
|
108
108
|
@property
|
109
|
-
def configuration_defaults(self) ->
|
109
|
+
def configuration_defaults(self) -> dict[str, Any]:
|
110
110
|
return {
|
111
111
|
self.name: {
|
112
112
|
"data_sources": {
|
@@ -122,14 +122,14 @@ class Risk(Component):
|
|
122
122
|
}
|
123
123
|
|
124
124
|
@property
|
125
|
-
def columns_created(self) ->
|
125
|
+
def columns_created(self) -> list[str]:
|
126
126
|
columns_to_create = [self.propensity_column_name]
|
127
127
|
if self.create_exposure_column:
|
128
128
|
columns_to_create.append(self.exposure_column_name)
|
129
129
|
return columns_to_create
|
130
130
|
|
131
131
|
@property
|
132
|
-
def initialization_requirements(self) ->
|
132
|
+
def initialization_requirements(self) -> dict[str, list[str]]:
|
133
133
|
return {
|
134
134
|
"requires_columns": [],
|
135
135
|
"requires_values": [],
|
@@ -9,7 +9,7 @@ exposure distributions.
|
|
9
9
|
"""
|
10
10
|
|
11
11
|
from abc import ABC, abstractmethod
|
12
|
-
from
|
12
|
+
from collections.abc import Callable
|
13
13
|
|
14
14
|
import numpy as np
|
15
15
|
import pandas as pd
|
@@ -38,7 +38,7 @@ class RiskExposureDistribution(Component, ABC):
|
|
38
38
|
self,
|
39
39
|
risk: EntityString,
|
40
40
|
distribution_type: str,
|
41
|
-
exposure_data:
|
41
|
+
exposure_data: int | float | pd.DataFrame | None = None,
|
42
42
|
) -> None:
|
43
43
|
super().__init__()
|
44
44
|
self.risk = risk
|
@@ -51,14 +51,14 @@ class RiskExposureDistribution(Component, ABC):
|
|
51
51
|
# Setup methods #
|
52
52
|
#################
|
53
53
|
|
54
|
-
def get_configuration(self, builder: "Builder") ->
|
54
|
+
def get_configuration(self, builder: "Builder") -> LayeredConfigTree | None:
|
55
55
|
return builder.configuration[self.risk]
|
56
56
|
|
57
57
|
@abstractmethod
|
58
58
|
def build_all_lookup_tables(self, builder: "Builder") -> None:
|
59
59
|
raise NotImplementedError
|
60
60
|
|
61
|
-
def get_exposure_data(self, builder: Builder) ->
|
61
|
+
def get_exposure_data(self, builder: Builder) -> int | float | pd.DataFrame:
|
62
62
|
if self._exposure_data is not None:
|
63
63
|
return self._exposure_data
|
64
64
|
return self.get_data(builder, self.configuration["data_sources"]["exposure"])
|
@@ -92,11 +92,11 @@ class EnsembleDistribution(RiskExposureDistribution):
|
|
92
92
|
##############
|
93
93
|
|
94
94
|
@property
|
95
|
-
def columns_created(self) ->
|
95
|
+
def columns_created(self) -> list[str]:
|
96
96
|
return [self._propensity]
|
97
97
|
|
98
98
|
@property
|
99
|
-
def initialization_requirements(self) ->
|
99
|
+
def initialization_requirements(self) -> dict[str, list[str]]:
|
100
100
|
return {
|
101
101
|
"requires_columns": [],
|
102
102
|
"requires_values": [],
|
@@ -265,7 +265,7 @@ class ContinuousDistribution(RiskExposureDistribution):
|
|
265
265
|
|
266
266
|
class PolytomousDistribution(RiskExposureDistribution):
|
267
267
|
@property
|
268
|
-
def categories(self) ->
|
268
|
+
def categories(self) -> list[str]:
|
269
269
|
# These need to be sorted so the cumulative sum is in the ocrrect order of categories
|
270
270
|
# and results are therefore reproducible and correct
|
271
271
|
return sorted(self.lookup_tables["exposure"].value_columns)
|
@@ -286,8 +286,8 @@ class PolytomousDistribution(RiskExposureDistribution):
|
|
286
286
|
)
|
287
287
|
|
288
288
|
def get_exposure_value_columns(
|
289
|
-
self, exposure_data:
|
290
|
-
) ->
|
289
|
+
self, exposure_data: int | float | pd.DataFrame
|
290
|
+
) -> list[str] | None:
|
291
291
|
if isinstance(exposure_data, pd.DataFrame):
|
292
292
|
return list(exposure_data["parameter"].unique())
|
293
293
|
return None
|
@@ -342,7 +342,7 @@ class DichotomousDistribution(RiskExposureDistribution):
|
|
342
342
|
)
|
343
343
|
self.lookup_tables["paf"] = self.build_lookup_table(builder, 0.0)
|
344
344
|
|
345
|
-
def get_exposure_data(self, builder: Builder) ->
|
345
|
+
def get_exposure_data(self, builder: Builder) -> int | float | pd.DataFrame:
|
346
346
|
exposure_data = super().get_exposure_data(builder)
|
347
347
|
|
348
348
|
if isinstance(exposure_data, (int, float)):
|
@@ -373,8 +373,8 @@ class DichotomousDistribution(RiskExposureDistribution):
|
|
373
373
|
return exposure_data
|
374
374
|
|
375
375
|
def get_exposure_value_columns(
|
376
|
-
self, exposure_data:
|
377
|
-
) ->
|
376
|
+
self, exposure_data: int | float | pd.DataFrame
|
377
|
+
) -> list[str] | None:
|
378
378
|
if isinstance(exposure_data, pd.DataFrame):
|
379
379
|
return self.get_value_columns(exposure_data)
|
380
380
|
return None
|
@@ -470,9 +470,9 @@ def clip(q):
|
|
470
470
|
|
471
471
|
|
472
472
|
def get_risk_distribution_parameter(
|
473
|
-
value_columns_getter: Callable[[
|
474
|
-
data:
|
475
|
-
) ->
|
473
|
+
value_columns_getter: Callable[[pd.DataFrame], list[str]],
|
474
|
+
data: float | pd.DataFrame,
|
475
|
+
) -> float | pd.Series:
|
476
476
|
if isinstance(data, pd.DataFrame):
|
477
477
|
value_columns = value_columns_getter(data)
|
478
478
|
if len(value_columns) > 1:
|