vivarium-profiling 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. vivarium/profiling/__init__.py +5 -0
  2. vivarium/profiling/_version.py +1 -0
  3. vivarium/profiling/components/__init__.py +0 -0
  4. vivarium/profiling/components/risks/effect.py +45 -0
  5. vivarium/profiling/constants/__init__.py +0 -0
  6. vivarium/profiling/constants/data_keys.py +240 -0
  7. vivarium/profiling/constants/data_values.py +34 -0
  8. vivarium/profiling/constants/metadata.py +42 -0
  9. vivarium/profiling/constants/models.py +11 -0
  10. vivarium/profiling/constants/paths.py +10 -0
  11. vivarium/profiling/constants/scenarios.py +29 -0
  12. vivarium/profiling/data/__init__.py +0 -0
  13. vivarium/profiling/data/builder.py +128 -0
  14. vivarium/profiling/data/loader.py +257 -0
  15. vivarium/profiling/model_specifications/branches/scenarios.yaml +7 -0
  16. vivarium/profiling/model_specifications/model_spec_scaling.yaml +81 -0
  17. vivarium/profiling/plugins/artifact.py +24 -0
  18. vivarium/profiling/plugins/parser.py +419 -0
  19. vivarium/profiling/templates/__init__.py +6 -0
  20. vivarium/profiling/templates/analysis_template.ipynb +207 -0
  21. vivarium/profiling/tools/__init__.py +4 -0
  22. vivarium/profiling/tools/app_logging.py +89 -0
  23. vivarium/profiling/tools/cli.py +387 -0
  24. vivarium/profiling/tools/extraction.py +438 -0
  25. vivarium/profiling/tools/make_artifacts.py +243 -0
  26. vivarium/profiling/tools/notebook_generator.py +56 -0
  27. vivarium/profiling/tools/plotting.py +448 -0
  28. vivarium/profiling/tools/run_benchmark.py +223 -0
  29. vivarium/profiling/tools/run_profile.py +65 -0
  30. vivarium/profiling/tools/summarize.py +190 -0
  31. vivarium/profiling/utilities.py +151 -0
  32. vivarium_profiling-0.4.0.dist-info/METADATA +233 -0
  33. vivarium_profiling-0.4.0.dist-info/RECORD +37 -0
  34. vivarium_profiling-0.4.0.dist-info/WHEEL +5 -0
  35. vivarium_profiling-0.4.0.dist-info/entry_points.txt +5 -0
  36. vivarium_profiling-0.4.0.dist-info/licenses/LICENSE +29 -0
  37. vivarium_profiling-0.4.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,5 @@
1
+ """vivarium.profiling
2
+
3
+ Profiling and benchmarking tools for Vivarium simulations.
4
+
5
+ """
@@ -0,0 +1 @@
1
# Package version string. (The original line had a stray literal "\n" after the
# string, which Python parses as an invalid line continuation.)
__version__ = "0.4.0"
File without changes
@@ -0,0 +1,45 @@
1
+ import re
2
+
3
+ import pandas as pd
4
+ from vivarium.framework.engine import Builder
5
+ from vivarium_public_health.risks.effect import (
6
+ NonLogLinearRiskEffect as NonLogLinearRiskEffect_,
7
+ )
8
+ from vivarium_public_health.risks.effect import RiskEffect as RiskEffect_
9
+
10
+ """
11
+ Basic Risk Effect Wrapper for use in the MultiComponentParser
12
+ """
13
+
14
+
15
class RiskEffect(RiskEffect_):
    """Basic RiskEffect wrapper for use in the MultiComponentParser.

    The only behavioral difference from the base class is how the target
    entity is matched. A trailing ``_<digits>`` suffix (e.g. the ``_2`` in
    ``lower_respiratory_infections_2``) is stripped from the target name
    before the data is compared against the ``affected_entity`` column.
    """

    def get_filtered_data(
        self, builder: Builder, data_source: str | float | pd.DataFrame
    ) -> float | pd.DataFrame:
        """Load ``data_source`` and keep only the rows for this effect's target.

        Parameters
        ----------
        builder
            The simulation builder, passed through to ``get_data``.
        data_source
            The data, or a source that ``get_data`` knows how to load.

        Returns
        -------
            Non-DataFrame data unchanged. For a DataFrame, the rows whose
            ``affected_entity`` (if present) equals the un-suffixed target name
            and whose ``affected_measure`` (if present) equals the target
            measure. The filter columns that were present are dropped.
        """
        data = super().get_data(builder, data_source)

        if not isinstance(data, pd.DataFrame):
            return data

        # Start from an explicit all-True boolean Series rather than the scalar
        # True. If neither filter column exists, ``data[True]`` would be
        # treated as a column lookup and raise KeyError.
        keep = pd.Series(True, index=data.index)
        columns_to_drop = []
        if "affected_entity" in data.columns:
            # THIS IS THE ONLY CHANGE from the base class. Filter to the
            # non-suffixed entity name.
            base_entity = re.sub(r"(_\d+)$", "", self.target.name)
            keep &= data["affected_entity"] == base_entity
            columns_to_drop.append("affected_entity")
        if "affected_measure" in data.columns:
            keep &= data["affected_measure"] == self.target.measure
            columns_to_drop.append("affected_measure")
        return data[keep].drop(columns=columns_to_drop)
36
+
37
+
38
class NonLogLinearRiskEffect(NonLogLinearRiskEffect_, RiskEffect):
    """Non-log-linear risk effect that also mixes in this module's ``RiskEffect``.

    NOTE(review): ``NonLogLinearRiskEffect_`` is listed first, so it precedes
    ``RiskEffect`` in the MRO. The suffix-stripping ``get_filtered_data`` is
    therefore only used if ``NonLogLinearRiskEffect_`` does not define its own
    version. Confirm this against the installed vivarium_public_health.
    """
40
+
41
+
42
+ from vivarium_public_health.risks.effect import (
43
+ NonLogLinearRiskEffect as NonLogLinearRiskEffect_,
44
+ )
45
+ from vivarium_public_health.risks.effect import RiskEffect as RiskEffect_
File without changes
@@ -0,0 +1,240 @@
1
+ from typing import NamedTuple
2
+
3
+ from vivarium_public_health.utilities import TargetString
4
+
5
#############
# Data Keys #
#############

METADATA_LOCATIONS = "metadata.locations"


class __Population(NamedTuple):
    """Artifact keys for population-level data."""

    LOCATION: str = "population.location"
    STRUCTURE: str = "population.structure"
    AGE_BINS: str = "population.age_bins"
    DEMOGRAPHY: str = "population.demographic_dimensions"
    TMRLE: str = "population.theoretical_minimum_risk_life_expectancy"
    ACMR: str = "cause.all_causes.cause_specific_mortality_rate"
    LIVE_BIRTH_RATE: str = "covariate.live_births_by_sex.estimate"

    @property
    def name(self):
        """Group name for these keys."""
        return "population"

    @property
    def log_name(self):
        """Name used in log output; identical to ``name`` for this group."""
        return self.name


POPULATION = __Population()
31
+
32
+
33
class __LRI(NamedTuple):
    """Artifact keys for the ``lower_respiratory_infections`` cause."""

    # Keys loaded into the artifact. Each one needs a type annotation so that
    # NamedTuple registers it as a field.
    PREVALENCE: TargetString = TargetString("cause.lower_respiratory_infections.prevalence")
    INCIDENCE_RATE: TargetString = TargetString(
        "cause.lower_respiratory_infections.incidence_rate"
    )
    REMISSION_RATE: TargetString = TargetString(
        "cause.lower_respiratory_infections.remission_rate"
    )
    DISABILITY_WEIGHT: TargetString = TargetString(
        "cause.lower_respiratory_infections.disability_weight"
    )
    EMR: TargetString = TargetString(
        "cause.lower_respiratory_infections.excess_mortality_rate"
    )
    CSMR: TargetString = TargetString(
        "cause.lower_respiratory_infections.cause_specific_mortality_rate"
    )
    RESTRICTIONS: TargetString = TargetString(
        "cause.lower_respiratory_infections.restrictions"
    )

    @property
    def name(self):
        """Cause name that appears in the keys above."""
        return "lower_respiratory_infections"

    @property
    def log_name(self):
        """Cause name with underscores shown as spaces."""
        return " ".join(self.name.split("_"))


LRI = __LRI()
66
+
67
+
68
class __LRI2(NamedTuple):
    """Artifact keys for the ``lower_respiratory_infections_2`` cause."""

    # Keys loaded into the artifact. Each one needs a type annotation so that
    # NamedTuple registers it as a field.
    PREVALENCE: TargetString = TargetString("cause.lower_respiratory_infections_2.prevalence")
    INCIDENCE_RATE: TargetString = TargetString(
        "cause.lower_respiratory_infections_2.incidence_rate"
    )
    REMISSION_RATE: TargetString = TargetString(
        "cause.lower_respiratory_infections_2.remission_rate"
    )
    DISABILITY_WEIGHT: TargetString = TargetString(
        "cause.lower_respiratory_infections_2.disability_weight"
    )
    EMR: TargetString = TargetString(
        "cause.lower_respiratory_infections_2.excess_mortality_rate"
    )
    CSMR: TargetString = TargetString(
        "cause.lower_respiratory_infections_2.cause_specific_mortality_rate"
    )
    RESTRICTIONS: TargetString = TargetString(
        "cause.lower_respiratory_infections_2.restrictions"
    )

    @property
    def name(self):
        """Cause name that appears in the keys above."""
        return "lower_respiratory_infections_2"

    @property
    def log_name(self):
        """Cause name with underscores shown as spaces."""
        return " ".join(self.name.split("_"))


LRI2 = __LRI2()
101
+
102
+
103
class __LRI3(NamedTuple):
    """Artifact keys for the ``lower_respiratory_infections_3`` cause."""

    # Keys loaded into the artifact. Each one needs a type annotation so that
    # NamedTuple registers it as a field.
    PREVALENCE: TargetString = TargetString("cause.lower_respiratory_infections_3.prevalence")
    INCIDENCE_RATE: TargetString = TargetString(
        "cause.lower_respiratory_infections_3.incidence_rate"
    )
    REMISSION_RATE: TargetString = TargetString(
        "cause.lower_respiratory_infections_3.remission_rate"
    )
    DISABILITY_WEIGHT: TargetString = TargetString(
        "cause.lower_respiratory_infections_3.disability_weight"
    )
    EMR: TargetString = TargetString(
        "cause.lower_respiratory_infections_3.excess_mortality_rate"
    )
    CSMR: TargetString = TargetString(
        "cause.lower_respiratory_infections_3.cause_specific_mortality_rate"
    )
    RESTRICTIONS: TargetString = TargetString(
        "cause.lower_respiratory_infections_3.restrictions"
    )

    @property
    def name(self):
        """Cause name that appears in the keys above."""
        return "lower_respiratory_infections_3"

    @property
    def log_name(self):
        """Cause name with underscores shown as spaces."""
        return " ".join(self.name.split("_"))


LRI3 = __LRI3()
136
+
137
+
138
class __LRI4(NamedTuple):
    """Artifact keys for the ``lower_respiratory_infections_4`` cause."""

    # Keys loaded into the artifact. Each one needs a type annotation so that
    # NamedTuple registers it as a field.
    PREVALENCE: TargetString = TargetString("cause.lower_respiratory_infections_4.prevalence")
    INCIDENCE_RATE: TargetString = TargetString(
        "cause.lower_respiratory_infections_4.incidence_rate"
    )
    REMISSION_RATE: TargetString = TargetString(
        "cause.lower_respiratory_infections_4.remission_rate"
    )
    DISABILITY_WEIGHT: TargetString = TargetString(
        "cause.lower_respiratory_infections_4.disability_weight"
    )
    EMR: TargetString = TargetString(
        "cause.lower_respiratory_infections_4.excess_mortality_rate"
    )
    CSMR: TargetString = TargetString(
        "cause.lower_respiratory_infections_4.cause_specific_mortality_rate"
    )
    RESTRICTIONS: TargetString = TargetString(
        "cause.lower_respiratory_infections_4.restrictions"
    )

    @property
    def name(self):
        """Cause name that appears in the keys above."""
        return "lower_respiratory_infections_4"

    @property
    def log_name(self):
        """Cause name with underscores shown as spaces."""
        return " ".join(self.name.split("_"))


LRI4 = __LRI4()
171
+
172
+
173
class __SBP(NamedTuple):
    """Artifact keys for the ``high_systolic_blood_pressure`` risk factor."""

    DISTRIBUTION: TargetString = TargetString(
        "risk_factor.high_systolic_blood_pressure.distribution"
    )
    EXPOSURE_MEAN: TargetString = TargetString(
        "risk_factor.high_systolic_blood_pressure.exposure"
    )
    EXPOSURE_SD: TargetString = TargetString(
        "risk_factor.high_systolic_blood_pressure.exposure_standard_deviation"
    )
    EXPOSURE_WEIGHTS: TargetString = TargetString(
        "risk_factor.high_systolic_blood_pressure.exposure_distribution_weights"
    )
    RELATIVE_RISK: TargetString = TargetString(
        "risk_factor.high_systolic_blood_pressure.relative_risk"
    )
    PAF: TargetString = TargetString(
        "risk_factor.high_systolic_blood_pressure.population_attributable_fraction"
    )
    TMRED: TargetString = TargetString("risk_factor.high_systolic_blood_pressure.tmred")
    RELATIVE_RISK_SCALAR: TargetString = TargetString(
        "risk_factor.high_systolic_blood_pressure.relative_risk_scalar"
    )

    @property
    def name(self):
        """Risk factor name that appears in the keys above."""
        return "high_systolic_blood_pressure"

    @property
    def log_name(self):
        """Risk factor name with underscores shown as spaces."""
        return " ".join(self.name.split("_"))


SBP = __SBP()
207
+
208
+
209
class __Water(NamedTuple):
    """Artifact keys for the ``unsafe_water_source`` risk factor."""

    DISTRIBUTION: TargetString = TargetString("risk_factor.unsafe_water_source.distribution")
    EXPOSURE: TargetString = TargetString("risk_factor.unsafe_water_source.exposure")
    CATEGORIES: TargetString = TargetString("risk_factor.unsafe_water_source.categories")
    RELATIVE_RISK: TargetString = TargetString(
        "risk_factor.unsafe_water_source.relative_risk"
    )
    PAF: TargetString = TargetString(
        "risk_factor.unsafe_water_source.population_attributable_fraction"
    )

    @property
    def name(self):
        """Risk factor name that appears in the keys above."""
        return "unsafe_water_source"

    @property
    def log_name(self):
        """Risk factor name with underscores shown as spaces."""
        return " ".join(self.name.split("_"))


WATER = __Water()
230
+
231
+
232
# Every key group defined in this module, in a fixed order.
MAKE_ARTIFACT_KEY_GROUPS = [POPULATION, LRI, LRI2, LRI3, LRI4, SBP, WATER]
@@ -0,0 +1,34 @@
1
+ from datetime import datetime
2
+
3
############################
# Disease Model Parameters #
############################

REMISSION_RATE = 0.1
MEAN_SOJOURN_TIME = 10


##############################
# Screening Model Parameters #
##############################

PROBABILITY_ATTENDING_SCREENING_KEY = "probability_attending_screening"
# Start/end values given as (mean, standard deviation) pairs.
PROBABILITY_ATTENDING_SCREENING_START_MEAN = 0.25
PROBABILITY_ATTENDING_SCREENING_START_STDDEV = 0.0025
PROBABILITY_ATTENDING_SCREENING_END_MEAN = 0.5
PROBABILITY_ATTENDING_SCREENING_END_STDDEV = 0.005

FIRST_SCREENING_AGE = 21
MID_SCREENING_AGE = 30
LAST_SCREENING_AGE = 65


###################################
# Scale-up Intervention Constants #
###################################
SCALE_UP_START_DT = datetime(year=2021, month=1, day=1)
SCALE_UP_END_DT = datetime(year=2030, month=1, day=1)
SCREENING_SCALE_UP_GOAL_COVERAGE = 0.50
# Coverage gained between the starting mean attendance and the goal.
SCREENING_SCALE_UP_DIFFERENCE = (
    -PROBABILITY_ATTENDING_SCREENING_START_MEAN + SCREENING_SCALE_UP_GOAL_COVERAGE
)
@@ -0,0 +1,42 @@
1
+ from typing import NamedTuple
2
+
3
+ import pandas as pd
4
+
5
####################
# Project metadata #
####################

# Underscore form is intentional: this string is used to build paths on the
# shared cluster filesystem (e.g. /mnt/team/.../{PROJECT_NAME}/artifacts/).
# Changing it would break access to existing artifacts.
PROJECT_NAME = "vivarium_profiling"
CLUSTER_PROJECT = "proj_simscience_prod"

# Cluster / artifact-building settings.
CLUSTER_QUEUE = "all.q"
MAKE_ARTIFACT_MEM = 10  # GB
MAKE_ARTIFACT_CPU = 1
MAKE_ARTIFACT_RUNTIME = "3:00:00"
MAKE_ARTIFACT_SLEEP = 10

LOCATIONS = ["Pakistan"]

ARTIFACT_INDEX_COLUMNS = ["sex", "age_start", "age_end", "year_start", "year_end"]

DRAW_COUNT = 1000
# One column label per draw: draw_0 ... draw_{DRAW_COUNT - 1}.
ARTIFACT_COLUMNS = pd.Index(["draw_" + str(draw) for draw in range(DRAW_COUNT)])
35
+
36
+
37
class __Scenarios(NamedTuple):
    """Names of the simulation scenarios."""

    baseline: str = "baseline"
    # TODO - add scenarios here


SCENARIOS = __Scenarios()
@@ -0,0 +1,11 @@
1
+ from vivarium.profiling.constants import data_keys
2
+
3
###########################
# Disease Model variables #
###########################

# TODO input details of model states
# ``data_keys`` defines no ``SOME_DISEASE`` group (a template placeholder), so
# referencing it raised AttributeError as soon as this module was imported.
# NOTE(review): LRI is used here as the base cause model; confirm this is the
# intended disease model.
SOME_MODEL_NAME = data_keys.LRI.name
SUSCEPTIBLE_STATE_NAME = f"susceptible_to_{SOME_MODEL_NAME}"
FIRST_STATE_NAME = "first_state"
SECOND_STATE_NAME = "second_state"
@@ -0,0 +1,10 @@
1
+ from pathlib import Path
2
+
3
+ import vivarium.profiling
4
+ from vivarium.profiling.constants import metadata
5
+
6
# Directory that contains the installed ``vivarium.profiling`` package.
BASE_DIR = Path(vivarium.profiling.__file__).resolve().parent

# Shared-filesystem directory holding this project's artifacts.
ARTIFACT_ROOT = (
    Path("/mnt/team/simulation_science/pub/models") / metadata.PROJECT_NAME / "artifacts"
)
@@ -0,0 +1,29 @@
1
+ from typing import NamedTuple
2
+
3
+ #############
4
+ # Scenarios #
5
+ #############
6
+
7
+
8
class InterventionScenario:
    """A named intervention scenario.

    Parameters
    ----------
    name
        Identifier of the scenario (e.g. ``"baseline"``).
    """

    def __init__(
        self,
        name: str,
        # TODO: add additional interventions, e.g.
        # has_treatment_one: bool = False,
        # has_treatment_two: bool = False,
    ):
        self.name = name
        # self.has_treatment_one = has_treatment_one
        # self.has_treatment_two = has_treatment_two


class __InterventionScenarios(NamedTuple):
    """All intervention scenarios, accessible by attribute, position, or field name."""

    BASELINE: InterventionScenario = InterventionScenario("baseline")
    # TODO: add additional intervention scenarios

    def __getitem__(self, item):
        """Return a scenario by field name (``str``) or by tuple index/slice.

        Raises
        ------
        KeyError
            If ``item`` is a string that is not a field name.
        """
        # Python only dispatches ``obj[...]`` to ``__getitem__``. The previous
        # spelling ``__get_item__`` was never called, so string lookups raised
        # TypeError. Non-string keys keep normal tuple indexing semantics.
        if isinstance(item, str):
            return self._asdict()[item]
        return tuple.__getitem__(self, item)

    # Backward-compatible alias for the previous (misspelled) method name.
    __get_item__ = __getitem__


INTERVENTION_SCENARIOS = __InterventionScenarios()
File without changes
@@ -0,0 +1,128 @@
1
+ """Modularized functions for building project data artifacts.
2
+
3
+ This module is an abstraction around the load portion of our artifact building ETL pipeline.
4
+ The intent is to be declarative so it's easy to see what is put into the artifact and how.
5
+ Some degree of verbosity/boilerplate is fine in the interest of transparency.
6
+
7
+ .. note::
8
+
9
+ Logging in this module should be done at the ``debug`` level.
10
+
11
+ """
12
+ from pathlib import Path
13
+
14
+ import pandas as pd
15
+ from loguru import logger
16
+ from vivarium.framework.artifact import Artifact, EntityKey
17
+
18
+ from vivarium.profiling.constants import data_keys
19
+ from vivarium.profiling.data import loader
20
+
21
+
22
def open_artifact(output_path: Path, location: str) -> Artifact:
    """Open the artifact at ``output_path``, creating the file if it is missing.

    Parameters
    ----------
    output_path
        Fully resolved path to the artifact file.
    location
        Proper GBD location name represented by the artifact. It is written
        under the locations metadata key only if that key is not already
        present.

    Returns
    -------
        The opened artifact (newly created or pre-existing).

    """
    if output_path.exists():
        logger.debug(f"Opening artifact at {output_path} for appending.")
    else:
        logger.debug(f"Creating artifact at {output_path}.")

    artifact = Artifact(output_path)

    locations_key = data_keys.METADATA_LOCATIONS
    if locations_key not in artifact:
        artifact.write(locations_key, [location])

    return artifact
49
+
50
+
51
def load_and_write_data(
    artifact: Artifact, key: str, location: str, years: str | None, replace: bool
):
    """Make sure data for ``key`` is stored in ``artifact``, loading it if needed.

    Parameters
    ----------
    artifact
        The artifact to write to.
    key
        The entity key associated with the data to write.
    location
        The location associated with the data to load and the artifact to
        write to.
    years
        Year selection forwarded to ``loader.get_data``. A truthy value other
        than ``"all"`` is converted with ``int``. ``"all"`` and falsy values
        (e.g. ``None``) are passed through unchanged.
    replace
        If true, data already stored under ``key`` is reloaded and
        overwritten. If false, existing data is left as-is.

    Returns
    -------
        The result of ``artifact.load(key)`` after any write.

    """
    if key in artifact and not replace:
        logger.debug(f"Data for {key} already in artifact. Skipping...")
        return artifact.load(key)

    logger.debug(f"Loading data for {key} for location {location}.")
    # A year string becomes an int; "all" and None are left untouched.
    if years and years != "all":
        years = int(years)
    data = loader.get_data(key, location, years)

    if key in artifact:
        # Present key here means a replacement was requested.
        logger.debug(f"Replacing data for {key} in artifact.")
        artifact.replace(key, data)
    else:
        logger.debug(f"Writing data for {key} to artifact.")
        artifact.write(key, data)
    return artifact.load(key)
83
+
84
+
85
def write_data(artifact: Artifact, key: str, data: pd.DataFrame):
    """Store ``data`` under ``key`` unless the artifact already has that key.

    Parameters
    ----------
    artifact
        The artifact to write to.
    key
        The entity key associated with the data to write.
    data
        The data to write.

    Returns
    -------
        The data loaded back from the artifact under ``key``. If the key
        already existed, this is the pre-existing data, not ``data``.

    """
    if key not in artifact:
        logger.debug(f"Writing data for {key} to artifact.")
        artifact.write(key, data)
    else:
        logger.debug(f"Data for {key} already in artifact. Skipping...")
    return artifact.load(key)
104
+
105
+
106
# TODO - writing and reading by draw is necessary if you are using
# LBWSG data. Find the read function in utilities.py
def write_data_by_draw(artifact: Artifact, key: str, data: pd.DataFrame):
    """Write ``data`` into the artifact's HDF file one column at a time.

    Useful for large datasets like Low Birthweight Short Gestation (LBWSG).
    The index is saved as a frame under ``<key path>/index``, and each column
    is saved as a separate node under ``<key path>/<column name>``.

    Parameters
    ----------
    artifact
        The artifact to write to.
    key
        The entity key associated with the data to write.
    data
        The data to write.

    """
    with pd.HDFStore(artifact.path, complevel=9, mode="a") as store:
        entity_key = EntityKey(key)
        # NOTE(review): this registers the key through the private ``_keys``
        # attribute while the store is still open. Confirm that is safe with
        # the installed vivarium Artifact implementation.
        artifact._keys.append(entity_key)
        node = entity_key.path
        store.put(f"{node}/index", data.index.to_frame(index=False))
        flat = data.reset_index(drop=True)
        for column in flat.columns:
            store.put(f"{node}/{column}", flat[column])