vivarium-profiling 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vivarium/profiling/__init__.py +5 -0
- vivarium/profiling/_version.py +1 -0
- vivarium/profiling/components/__init__.py +0 -0
- vivarium/profiling/components/risks/effect.py +45 -0
- vivarium/profiling/constants/__init__.py +0 -0
- vivarium/profiling/constants/data_keys.py +240 -0
- vivarium/profiling/constants/data_values.py +34 -0
- vivarium/profiling/constants/metadata.py +42 -0
- vivarium/profiling/constants/models.py +11 -0
- vivarium/profiling/constants/paths.py +10 -0
- vivarium/profiling/constants/scenarios.py +29 -0
- vivarium/profiling/data/__init__.py +0 -0
- vivarium/profiling/data/builder.py +128 -0
- vivarium/profiling/data/loader.py +257 -0
- vivarium/profiling/model_specifications/branches/scenarios.yaml +7 -0
- vivarium/profiling/model_specifications/model_spec_scaling.yaml +81 -0
- vivarium/profiling/plugins/artifact.py +24 -0
- vivarium/profiling/plugins/parser.py +419 -0
- vivarium/profiling/templates/__init__.py +6 -0
- vivarium/profiling/templates/analysis_template.ipynb +207 -0
- vivarium/profiling/tools/__init__.py +4 -0
- vivarium/profiling/tools/app_logging.py +89 -0
- vivarium/profiling/tools/cli.py +387 -0
- vivarium/profiling/tools/extraction.py +438 -0
- vivarium/profiling/tools/make_artifacts.py +243 -0
- vivarium/profiling/tools/notebook_generator.py +56 -0
- vivarium/profiling/tools/plotting.py +448 -0
- vivarium/profiling/tools/run_benchmark.py +223 -0
- vivarium/profiling/tools/run_profile.py +65 -0
- vivarium/profiling/tools/summarize.py +190 -0
- vivarium/profiling/utilities.py +151 -0
- vivarium_profiling-0.4.0.dist-info/METADATA +233 -0
- vivarium_profiling-0.4.0.dist-info/RECORD +37 -0
- vivarium_profiling-0.4.0.dist-info/WHEEL +5 -0
- vivarium_profiling-0.4.0.dist-info/entry_points.txt +5 -0
- vivarium_profiling-0.4.0.dist-info/licenses/LICENSE +29 -0
- vivarium_profiling-0.4.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1 @@
|
|
|
1
|
+
__version__ = "0.4.0"\n
|
|
File without changes
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
import re
|
|
2
|
+
|
|
3
|
+
import pandas as pd
|
|
4
|
+
from vivarium.framework.engine import Builder
|
|
5
|
+
from vivarium_public_health.risks.effect import (
|
|
6
|
+
NonLogLinearRiskEffect as NonLogLinearRiskEffect_,
|
|
7
|
+
)
|
|
8
|
+
from vivarium_public_health.risks.effect import RiskEffect as RiskEffect_
|
|
9
|
+
|
|
10
|
+
"""
|
|
11
|
+
Basic Risk Effect Wrapper for use in the MultiComponentParser
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class RiskEffect(RiskEffect_):
    """Risk effect whose data filter matches on the non-suffixed target name.

    The affected-entity comparison strips a trailing ``_<digits>`` suffix from
    ``self.target.name`` before matching it against the ``affected_entity``
    column of the loaded data.
    """

    def get_filtered_data(
        self, builder: Builder, data_source: str | float | pd.DataFrame
    ) -> float | pd.DataFrame:
        """Load the effect data and keep only the rows for this effect's target.

        Parameters
        ----------
        builder
            Passed through to the parent ``get_data`` call.
        data_source
            Passed through to the parent ``get_data`` call.

        Returns
        -------
            Whatever ``get_data`` returned when it is not a DataFrame. Otherwise
            the DataFrame restricted to rows matching the target entity and
            measure, with the ``affected_entity`` / ``affected_measure`` columns
            dropped when present.
        """
        data = super().get_data(builder, data_source)

        if isinstance(data, pd.DataFrame):
            # filter data to only include the target entity and measure.
            # Start from an all-True boolean Series rather than the scalar
            # ``True``: if neither filter column is present, indexing with a
            # scalar (``data[True]``) is a column lookup and raises KeyError.
            correct_target_mask = pd.Series(True, index=data.index, dtype=bool)
            columns_to_drop = []
            if "affected_entity" in data.columns:
                # THIS IS THE ONLY CHANGE! We need to filter to the non-suffixed name
                unsuffixed_name = re.sub(r"(_\d+)$", "", self.target.name)
                correct_target_mask &= data["affected_entity"] == unsuffixed_name
                columns_to_drop.append("affected_entity")
            if "affected_measure" in data.columns:
                correct_target_mask &= data["affected_measure"] == self.target.measure
                columns_to_drop.append("affected_measure")
            data = data[correct_target_mask].drop(columns=columns_to_drop)
        return data
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class NonLogLinearRiskEffect(NonLogLinearRiskEffect_, RiskEffect):
    """Combine the upstream non-log-linear effect with this module's ``RiskEffect``.

    No members are added; behavior comes entirely from the two base classes.
    NOTE(review): whether ``RiskEffect.get_filtered_data`` is the one used
    depends on whether the upstream class overrides it — confirm.
    """
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
from vivarium_public_health.risks.effect import (
|
|
43
|
+
NonLogLinearRiskEffect as NonLogLinearRiskEffect_,
|
|
44
|
+
)
|
|
45
|
+
from vivarium_public_health.risks.effect import RiskEffect as RiskEffect_
|
|
File without changes
|
|
@@ -0,0 +1,240 @@
|
|
|
1
|
+
from typing import NamedTuple
|
|
2
|
+
|
|
3
|
+
from vivarium_public_health.utilities import TargetString
|
|
4
|
+
|
|
5
|
+
#############
|
|
6
|
+
# Data Keys #
|
|
7
|
+
#############
|
|
8
|
+
|
|
9
|
+
METADATA_LOCATIONS = "metadata.locations"
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class __Population(NamedTuple):
    """String keys for the ``population`` key group."""

    LOCATION: str = "population.location"
    STRUCTURE: str = "population.structure"
    AGE_BINS: str = "population.age_bins"
    DEMOGRAPHY: str = "population.demographic_dimensions"
    TMRLE: str = "population.theoretical_minimum_risk_life_expectancy"
    ACMR: str = "cause.all_causes.cause_specific_mortality_rate"
    LIVE_BIRTH_RATE: str = "covariate.live_births_by_sex.estimate"

    @property
    def name(self) -> str:
        """Name of this key group."""
        return "population"

    @property
    def log_name(self) -> str:
        """Name of this key group as used in log output."""
        return "population"


# Single shared instance carrying the default key values declared above.
POPULATION = __Population()
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class __LRI(NamedTuple):
    """Artifact keys for the ``lower_respiratory_infections`` cause."""

    # Keys that will be loaded into the artifact. Each must have a colon type declaration.
    PREVALENCE: TargetString = TargetString("cause.lower_respiratory_infections.prevalence")
    INCIDENCE_RATE: TargetString = TargetString(
        "cause.lower_respiratory_infections.incidence_rate"
    )
    REMISSION_RATE: TargetString = TargetString(
        "cause.lower_respiratory_infections.remission_rate"
    )
    DISABILITY_WEIGHT: TargetString = TargetString(
        "cause.lower_respiratory_infections.disability_weight"
    )
    EMR: TargetString = TargetString(
        "cause.lower_respiratory_infections.excess_mortality_rate"
    )
    CSMR: TargetString = TargetString(
        "cause.lower_respiratory_infections.cause_specific_mortality_rate"
    )
    RESTRICTIONS: TargetString = TargetString(
        "cause.lower_respiratory_infections.restrictions"
    )

    @property
    def name(self) -> str:
        """Entity name that appears in every key of this group."""
        return "lower_respiratory_infections"

    @property
    def log_name(self) -> str:
        """``name`` with underscores replaced by spaces."""
        entity = self.name
        return entity.replace("_", " ")


# Single shared instance carrying the default key values declared above.
LRI = __LRI()
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class __LRI2(NamedTuple):
    """Artifact keys for the ``lower_respiratory_infections_2`` cause."""

    # Keys that will be loaded into the artifact. Each must have a colon type declaration.
    PREVALENCE: TargetString = TargetString("cause.lower_respiratory_infections_2.prevalence")
    INCIDENCE_RATE: TargetString = TargetString(
        "cause.lower_respiratory_infections_2.incidence_rate"
    )
    REMISSION_RATE: TargetString = TargetString(
        "cause.lower_respiratory_infections_2.remission_rate"
    )
    DISABILITY_WEIGHT: TargetString = TargetString(
        "cause.lower_respiratory_infections_2.disability_weight"
    )
    EMR: TargetString = TargetString(
        "cause.lower_respiratory_infections_2.excess_mortality_rate"
    )
    CSMR: TargetString = TargetString(
        "cause.lower_respiratory_infections_2.cause_specific_mortality_rate"
    )
    RESTRICTIONS: TargetString = TargetString(
        "cause.lower_respiratory_infections_2.restrictions"
    )

    @property
    def name(self) -> str:
        """Entity name that appears in every key of this group."""
        return "lower_respiratory_infections_2"

    @property
    def log_name(self) -> str:
        """``name`` with underscores replaced by spaces."""
        entity = self.name
        return entity.replace("_", " ")


# Single shared instance carrying the default key values declared above.
LRI2 = __LRI2()
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
class __LRI3(NamedTuple):
    """Artifact keys for the ``lower_respiratory_infections_3`` cause."""

    # Keys that will be loaded into the artifact. Each must have a colon type declaration.
    PREVALENCE: TargetString = TargetString("cause.lower_respiratory_infections_3.prevalence")
    INCIDENCE_RATE: TargetString = TargetString(
        "cause.lower_respiratory_infections_3.incidence_rate"
    )
    REMISSION_RATE: TargetString = TargetString(
        "cause.lower_respiratory_infections_3.remission_rate"
    )
    DISABILITY_WEIGHT: TargetString = TargetString(
        "cause.lower_respiratory_infections_3.disability_weight"
    )
    EMR: TargetString = TargetString(
        "cause.lower_respiratory_infections_3.excess_mortality_rate"
    )
    CSMR: TargetString = TargetString(
        "cause.lower_respiratory_infections_3.cause_specific_mortality_rate"
    )
    RESTRICTIONS: TargetString = TargetString(
        "cause.lower_respiratory_infections_3.restrictions"
    )

    @property
    def name(self) -> str:
        """Entity name that appears in every key of this group."""
        return "lower_respiratory_infections_3"

    @property
    def log_name(self) -> str:
        """``name`` with underscores replaced by spaces."""
        entity = self.name
        return entity.replace("_", " ")


# Single shared instance carrying the default key values declared above.
LRI3 = __LRI3()
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
class __LRI4(NamedTuple):
    """Artifact keys for the ``lower_respiratory_infections_4`` cause."""

    # Keys that will be loaded into the artifact. Each must have a colon type declaration.
    PREVALENCE: TargetString = TargetString("cause.lower_respiratory_infections_4.prevalence")
    INCIDENCE_RATE: TargetString = TargetString(
        "cause.lower_respiratory_infections_4.incidence_rate"
    )
    REMISSION_RATE: TargetString = TargetString(
        "cause.lower_respiratory_infections_4.remission_rate"
    )
    DISABILITY_WEIGHT: TargetString = TargetString(
        "cause.lower_respiratory_infections_4.disability_weight"
    )
    EMR: TargetString = TargetString(
        "cause.lower_respiratory_infections_4.excess_mortality_rate"
    )
    CSMR: TargetString = TargetString(
        "cause.lower_respiratory_infections_4.cause_specific_mortality_rate"
    )
    RESTRICTIONS: TargetString = TargetString(
        "cause.lower_respiratory_infections_4.restrictions"
    )

    @property
    def name(self) -> str:
        """Entity name that appears in every key of this group."""
        return "lower_respiratory_infections_4"

    @property
    def log_name(self) -> str:
        """``name`` with underscores replaced by spaces."""
        entity = self.name
        return entity.replace("_", " ")


# Single shared instance carrying the default key values declared above.
LRI4 = __LRI4()
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
class __SBP(NamedTuple):
    """Keys for the ``high_systolic_blood_pressure`` risk factor."""

    DISTRIBUTION: TargetString = TargetString(
        "risk_factor.high_systolic_blood_pressure.distribution"
    )
    EXPOSURE_MEAN: TargetString = TargetString(
        "risk_factor.high_systolic_blood_pressure.exposure"
    )
    EXPOSURE_SD: TargetString = TargetString(
        "risk_factor.high_systolic_blood_pressure.exposure_standard_deviation"
    )
    EXPOSURE_WEIGHTS: TargetString = TargetString(
        "risk_factor.high_systolic_blood_pressure.exposure_distribution_weights"
    )
    RELATIVE_RISK: TargetString = TargetString(
        "risk_factor.high_systolic_blood_pressure.relative_risk"
    )
    PAF: TargetString = TargetString(
        "risk_factor.high_systolic_blood_pressure.population_attributable_fraction"
    )
    TMRED: TargetString = TargetString("risk_factor.high_systolic_blood_pressure.tmred")
    RELATIVE_RISK_SCALAR: TargetString = TargetString(
        "risk_factor.high_systolic_blood_pressure.relative_risk_scalar"
    )

    @property
    def name(self) -> str:
        """Entity name that appears in every key of this group."""
        return "high_systolic_blood_pressure"

    @property
    def log_name(self) -> str:
        """``name`` with underscores replaced by spaces."""
        entity = self.name
        return entity.replace("_", " ")


# Single shared instance carrying the default key values declared above.
SBP = __SBP()
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
class __Water(NamedTuple):
    """Keys for the ``unsafe_water_source`` risk factor."""

    DISTRIBUTION: TargetString = TargetString("risk_factor.unsafe_water_source.distribution")
    EXPOSURE: TargetString = TargetString("risk_factor.unsafe_water_source.exposure")
    CATEGORIES: TargetString = TargetString("risk_factor.unsafe_water_source.categories")
    RELATIVE_RISK: TargetString = TargetString(
        "risk_factor.unsafe_water_source.relative_risk"
    )
    PAF: TargetString = TargetString(
        "risk_factor.unsafe_water_source.population_attributable_fraction"
    )

    @property
    def name(self) -> str:
        """Entity name that appears in every key of this group."""
        return "unsafe_water_source"

    @property
    def log_name(self) -> str:
        """``name`` with underscores replaced by spaces."""
        entity = self.name
        return entity.replace("_", " ")


# Single shared instance carrying the default key values declared above.
WATER = __Water()
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
# Every key-group instance defined in this module, in declaration order.
# NOTE(review): presumably iterated by the artifact-building tooling — confirm
# against the consumer before reordering.
MAKE_ARTIFACT_KEY_GROUPS = [POPULATION, LRI, LRI2, LRI3, LRI4, SBP, WATER]
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
from datetime import datetime
|
|
2
|
+
|
|
3
|
+
############################
|
|
4
|
+
# Disease Model Parameters #
|
|
5
|
+
############################
|
|
6
|
+
|
|
7
|
+
REMISSION_RATE = 0.1
MEAN_SOJOURN_TIME = 10


##############################
# Screening Model Parameters #
##############################

PROBABILITY_ATTENDING_SCREENING_KEY = "probability_attending_screening"
# Mean / standard deviation pairs for the start and end attendance probabilities.
PROBABILITY_ATTENDING_SCREENING_START_MEAN = 0.25
PROBABILITY_ATTENDING_SCREENING_START_STDDEV = 0.0025
PROBABILITY_ATTENDING_SCREENING_END_MEAN = 0.5
PROBABILITY_ATTENDING_SCREENING_END_STDDEV = 0.005

FIRST_SCREENING_AGE = 21
MID_SCREENING_AGE = 30
LAST_SCREENING_AGE = 65


###################################
# Scale-up Intervention Constants #
###################################
SCALE_UP_START_DT = datetime(year=2021, month=1, day=1)
SCALE_UP_END_DT = datetime(year=2030, month=1, day=1)
SCREENING_SCALE_UP_GOAL_COVERAGE = 0.50
# Gap between the goal coverage and the starting mean attendance probability.
SCREENING_SCALE_UP_DIFFERENCE = (
    SCREENING_SCALE_UP_GOAL_COVERAGE - PROBABILITY_ATTENDING_SCREENING_START_MEAN
)
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
from typing import NamedTuple
|
|
2
|
+
|
|
3
|
+
import pandas as pd
|
|
4
|
+
|
|
5
|
+
####################
|
|
6
|
+
# Project metadata #
|
|
7
|
+
####################
|
|
8
|
+
|
|
9
|
+
# Underscore form is intentional: this string is used to build paths on the
|
|
10
|
+
# shared cluster filesystem (e.g. /mnt/team/.../{PROJECT_NAME}/artifacts/).
|
|
11
|
+
# Changing it would break access to existing artifacts.
|
|
12
|
+
PROJECT_NAME = "vivarium_profiling"
CLUSTER_PROJECT = "proj_simscience_prod"

# Settings whose names tie them to the make-artifact cluster jobs.
CLUSTER_QUEUE = "all.q"
MAKE_ARTIFACT_MEM = 10  # GB
MAKE_ARTIFACT_CPU = 1
MAKE_ARTIFACT_RUNTIME = "3:00:00"
MAKE_ARTIFACT_SLEEP = 10

LOCATIONS = ["Pakistan"]

ARTIFACT_INDEX_COLUMNS = ["sex", "age_start", "age_end", "year_start", "year_end"]

DRAW_COUNT = 1000
# One column label per draw: "draw_0" ... "draw_{DRAW_COUNT - 1}".
ARTIFACT_COLUMNS = pd.Index([f"draw_{i}" for i in range(DRAW_COUNT)])
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class __Scenarios(NamedTuple):
    """Scenario names, one field per scenario."""

    baseline: str = "baseline"
    # TODO - add scenarios here


# Single shared instance carrying the default scenario names declared above.
SCENARIOS = __Scenarios()
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
from vivarium.profiling.constants import data_keys
|
|
2
|
+
|
|
3
|
+
###########################
|
|
4
|
+
# Disease Model variables #
|
|
5
|
+
###########################
|
|
6
|
+
|
|
7
|
+
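# TODO: input details of model states. NOTE(review): data_keys.py in this package defines no SOME_DISEASE attribute, so the assignment below looks like it raises AttributeError on import — confirm and point it at a real key group.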
|
|
8
|
+
SOME_MODEL_NAME = data_keys.SOME_DISEASE.name
|
|
9
|
+
SUSCEPTIBLE_STATE_NAME = f"susceptible_to_{SOME_MODEL_NAME}"
|
|
10
|
+
FIRST_STATE_NAME = "first_state"
|
|
11
|
+
SECOND_STATE_NAME = "second_state"
|
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
from pathlib import Path
|
|
2
|
+
|
|
3
|
+
import vivarium.profiling
|
|
4
|
+
from vivarium.profiling.constants import metadata
|
|
5
|
+
|
|
6
|
+
# Directory that contains the ``vivarium.profiling`` package's ``__file__``.
BASE_DIR = Path(vivarium.profiling.__file__).resolve().parent

# Artifact directory for this project; the project name is interpolated from metadata.
ARTIFACT_ROOT = Path(
    f"/mnt/team/simulation_science/pub/models/{metadata.PROJECT_NAME}/artifacts/"
)
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
from typing import NamedTuple
|
|
2
|
+
|
|
3
|
+
#############
|
|
4
|
+
# Scenarios #
|
|
5
|
+
#############
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class InterventionScenario:
    """A named intervention scenario.

    Parameters
    ----------
    name
        The scenario's name; stored unchanged on ``self.name``.
    """

    def __init__(
        self,
        name: str,
        # todo add additional interventions
        # has_treatment_one: bool = False,
        # has_treatment_two: bool = False,
    ):
        self.name = name
        # self.has_treatment_one = has_treatment_one
        # self.has_treatment_two = has_treatment_two


class __InterventionScenarios(NamedTuple):
    """The available intervention scenarios, one field per scenario.

    Besides attribute access and ordinary tuple indexing, scenarios can be
    looked up by field name with ``INTERVENTION_SCENARIOS["BASELINE"]``.
    """

    BASELINE: InterventionScenario = InterventionScenario("baseline")
    # todo add additional intervention scenarios

    def __getitem__(self, item):
        """Return a scenario by field name, or by tuple index / slice.

        The original ``__get_item__`` spelling is not the subscript hook, so a
        string subscript fell through to tuple indexing and raised TypeError.

        Raises
        ------
        KeyError
            If ``item`` is a string that is not a field name.
        """
        if isinstance(item, str):
            return self._asdict()[item]
        # Integers and slices keep plain tuple semantics. ``tuple`` is named
        # explicitly rather than relying on ``super()`` inside a NamedTuple.
        return tuple.__getitem__(self, item)

    def __get_item__(self, item):
        """Look up a scenario by field name (kept for existing callers)."""
        return self._asdict()[item]


INTERVENTION_SCENARIOS = __InterventionScenarios()
|
|
File without changes
|
|
@@ -0,0 +1,128 @@
|
|
|
1
|
+
"""Modularized functions for building project data artifacts.
|
|
2
|
+
|
|
3
|
+
This module is an abstraction around the load portion of our artifact building ETL pipeline.
|
|
4
|
+
The intent is to be declarative so it's easy to see what is put into the artifact and how.
|
|
5
|
+
Some degree of verbosity/boilerplate is fine in the interest of transparency.
|
|
6
|
+
|
|
7
|
+
.. admonition::
|
|
8
|
+
|
|
9
|
+
Logging in this module should be done at the ``debug`` level.
|
|
10
|
+
|
|
11
|
+
"""
|
|
12
|
+
from pathlib import Path
|
|
13
|
+
|
|
14
|
+
import pandas as pd
|
|
15
|
+
from loguru import logger
|
|
16
|
+
from vivarium.framework.artifact import Artifact, EntityKey
|
|
17
|
+
|
|
18
|
+
from vivarium.profiling.constants import data_keys
|
|
19
|
+
from vivarium.profiling.data import loader
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def open_artifact(output_path: Path, location: str) -> Artifact:
    """Open the artifact at ``output_path``, creating it when it does not exist.

    Parameters
    ----------
    output_path
        Fully resolved path to the artifact file.
    location
        Proper GBD location name represented by the artifact.

    Returns
    -------
        The artifact, with the locations metadata key written if it was absent.

    """
    if output_path.exists():
        logger.debug(f"Opening artifact at {str(output_path)} for appending.")
    else:
        logger.debug(f"Creating artifact at {str(output_path)}.")

    artifact = Artifact(output_path)

    # Record the location once; an existing entry is left untouched.
    locations_key = data_keys.METADATA_LOCATIONS
    if locations_key not in artifact:
        artifact.write(locations_key, [location])

    return artifact
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def load_and_write_data(
    artifact: Artifact, key: str, location: str, years: str | None, replace: bool
):
    """Load data for ``key`` and store it in the artifact, then read it back.

    Nothing is loaded when the key is already present and ``replace`` is false.

    Parameters
    ----------
    artifact
        The artifact to write to.
    key
        The entity key associated with the data to write.
    location
        The location associated with the data to load and the artifact to
        write to.
    years
        ``None``, the string ``"all"``, or a string that is converted to
        ``int`` before being passed to ``loader.get_data``.
    replace
        Flag which determines whether to overwrite existing data.

    Returns
    -------
        The result of ``artifact.load(key)``.

    """
    if key in artifact and not replace:
        logger.debug(f"Data for {key} already in artifact. Skipping...")
        return artifact.load(key)

    logger.debug(f"Loading data for {key} for location {location}.")
    # years is either a string we want to convert to an int, 'all', or None
    if years and years != "all":
        years = int(years)
    data = loader.get_data(key, location, years)

    if key in artifact:
        # key is in artifact, but should be replaced
        logger.debug(f"Replacing data for {key} in artifact.")
        artifact.replace(key, data)
    else:
        logger.debug(f"Writing data for {key} to artifact.")
        artifact.write(key, data)
    return artifact.load(key)
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def write_data(artifact: Artifact, key: str, data: pd.DataFrame):
    """Store ``data`` under ``key`` unless the key exists, then read it back.

    Parameters
    ----------
    artifact
        The artifact to write to.
    key
        The entity key associated with the data to write.
    data
        The data to write.

    Returns
    -------
        The result of ``artifact.load(key)``.

    """
    if key not in artifact:
        logger.debug(f"Writing data for {key} to artifact.")
        artifact.write(key, data)
    else:
        logger.debug(f"Data for {key} already in artifact. Skipping...")
    return artifact.load(key)
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
# TODO - writing and reading by draw is necessary if you are using
|
|
107
|
+
# LBWSG data. Find the read function in utilities.py
|
|
108
|
+
def write_data_by_draw(artifact: Artifact, key: str, data: pd.DataFrame):
    """Write ``data`` into the artifact's HDF file one column at a time.

    This is useful for large datasets like Low Birthweight Short Gestation
    (LBWSG). The index is stored as a frame under ``<key path>/index`` and each
    column of the index-reset data under ``<key path>/<column>``.

    Parameters
    ----------
    artifact
        The artifact to write to.
    key
        The entity key associated with the data to write.
    data
        The data to write.

    """
    with pd.HDFStore(artifact.path, complevel=9, mode="a") as store:
        entity_key = EntityKey(key)
        # The store is written directly, so the key is also appended to the
        # artifact's own key list here.
        artifact._keys.append(entity_key)

        index_frame = data.index.to_frame(index=False)
        store.put(f"{entity_key.path}/index", index_frame)

        values = data.reset_index(drop=True)
        for column in values.columns:
            store.put(f"{entity_key.path}/{column}", values[column])
|