vivarium-public-health 2.3.2__py3-none-any.whl → 3.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48) hide show
  1. vivarium_public_health/_version.py +1 -1
  2. vivarium_public_health/disease/model.py +23 -21
  3. vivarium_public_health/disease/models.py +1 -0
  4. vivarium_public_health/disease/special_disease.py +40 -41
  5. vivarium_public_health/disease/state.py +42 -125
  6. vivarium_public_health/disease/transition.py +70 -27
  7. vivarium_public_health/mslt/delay.py +1 -0
  8. vivarium_public_health/mslt/disease.py +1 -0
  9. vivarium_public_health/mslt/intervention.py +1 -0
  10. vivarium_public_health/mslt/magic_wand_components.py +1 -0
  11. vivarium_public_health/mslt/observer.py +1 -0
  12. vivarium_public_health/mslt/population.py +1 -0
  13. vivarium_public_health/plugins/parser.py +61 -31
  14. vivarium_public_health/population/add_new_birth_cohorts.py +2 -3
  15. vivarium_public_health/population/base_population.py +2 -1
  16. vivarium_public_health/population/mortality.py +83 -80
  17. vivarium_public_health/{metrics → results}/__init__.py +2 -0
  18. vivarium_public_health/results/columns.py +22 -0
  19. vivarium_public_health/results/disability.py +187 -0
  20. vivarium_public_health/results/disease.py +222 -0
  21. vivarium_public_health/results/mortality.py +186 -0
  22. vivarium_public_health/results/observer.py +78 -0
  23. vivarium_public_health/results/risk.py +138 -0
  24. vivarium_public_health/results/simple_cause.py +18 -0
  25. vivarium_public_health/{metrics → results}/stratification.py +10 -8
  26. vivarium_public_health/risks/__init__.py +1 -2
  27. vivarium_public_health/risks/base_risk.py +134 -29
  28. vivarium_public_health/risks/data_transformations.py +65 -326
  29. vivarium_public_health/risks/distributions.py +315 -145
  30. vivarium_public_health/risks/effect.py +376 -75
  31. vivarium_public_health/risks/implementations/low_birth_weight_and_short_gestation.py +61 -89
  32. vivarium_public_health/treatment/magic_wand.py +1 -0
  33. vivarium_public_health/treatment/scale_up.py +1 -0
  34. vivarium_public_health/treatment/therapeutic_inertia.py +1 -0
  35. vivarium_public_health/utilities.py +17 -2
  36. {vivarium_public_health-2.3.2.dist-info → vivarium_public_health-3.0.0.dist-info}/METADATA +13 -3
  37. vivarium_public_health-3.0.0.dist-info/RECORD +49 -0
  38. {vivarium_public_health-2.3.2.dist-info → vivarium_public_health-3.0.0.dist-info}/WHEEL +1 -1
  39. vivarium_public_health/metrics/disability.py +0 -118
  40. vivarium_public_health/metrics/disease.py +0 -136
  41. vivarium_public_health/metrics/mortality.py +0 -144
  42. vivarium_public_health/metrics/risk.py +0 -110
  43. vivarium_public_health/testing/__init__.py +0 -0
  44. vivarium_public_health/testing/mock_artifact.py +0 -145
  45. vivarium_public_health/testing/utils.py +0 -71
  46. vivarium_public_health-2.3.2.dist-info/RECORD +0 -49
  47. {vivarium_public_health-2.3.2.dist-info → vivarium_public_health-3.0.0.dist-info}/LICENSE.txt +0 -0
  48. {vivarium_public_health-2.3.2.dist-info → vivarium_public_health-3.0.0.dist-info}/top_level.txt +0 -0
@@ -1 +1 @@
1
- __version__ = "2.3.2"
1
+ __version__ = "3.0.0"
@@ -8,7 +8,8 @@ function is to provide coordination across a set of disease states and
8
8
  transitions at simulation initialization and during transitions.
9
9
 
10
10
  """
11
- from typing import Callable, Dict, List, Optional
11
+
12
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
12
13
 
13
14
  import numpy as np
14
15
  import pandas as pd
@@ -27,10 +28,21 @@ class DiseaseModelError(VivariumError):
27
28
 
28
29
 
29
30
  class DiseaseModel(Machine):
31
+
30
32
  ##############
31
33
  # Properties #
32
34
  ##############
33
35
 
36
+ @property
37
+ def configuration_defaults(self) -> Dict[str, Any]:
38
+ return {
39
+ f"{self.name}": {
40
+ "data_sources": {
41
+ "cause_specific_mortality_rate": "self::load_cause_specific_mortality_rate"
42
+ },
43
+ },
44
+ }
45
+
34
46
  @property
35
47
  def columns_created(self) -> List[str]:
36
48
  return [self.state_column]
@@ -88,13 +100,6 @@ class DiseaseModel(Machine):
88
100
 
89
101
  self.configuration_age_start = builder.configuration.population.initialization_age_min
90
102
  self.configuration_age_end = builder.configuration.population.initialization_age_max
91
-
92
- cause_specific_mortality_rate = self.load_cause_specific_mortality_rate_data(builder)
93
- self.cause_specific_mortality_rate = builder.lookup.build_table(
94
- cause_specific_mortality_rate,
95
- key_columns=["sex"],
96
- parameter_columns=["age", "year"],
97
- )
98
103
  builder.value.register_value_modifier(
99
104
  "cause_specific_mortality_rate",
100
105
  self.adjust_cause_specific_mortality_rate,
@@ -106,11 +111,13 @@ class DiseaseModel(Machine):
106
111
  # Setup methods #
107
112
  #################
108
113
 
109
- def load_cause_specific_mortality_rate_data(self, builder):
114
+ def load_cause_specific_mortality_rate(
115
+ self, builder: Builder
116
+ ) -> Union[float, pd.DataFrame]:
110
117
  if "cause_specific_mortality_rate" not in self._get_data_functions:
111
118
  only_morbid = builder.data.load(f"cause.{self.cause}.restrictions")["yld_only"]
112
119
  if only_morbid:
113
- csmr_data = 0
120
+ csmr_data = 0.0
114
121
  else:
115
122
  csmr_data = builder.data.load(
116
123
  f"{self.cause_type}.{self.cause}.cause_specific_mortality_rate"
@@ -130,8 +137,6 @@ class DiseaseModel(Machine):
130
137
 
131
138
  assert self.initial_state in {s.state_id for s in self.states}
132
139
 
133
- # FIXME: this is a hack to figure out whether or not we're at the simulation start based on the fact that the
134
- # fertility components create this user data
135
140
  if pop_data.user_data["sim_state"] == "setup": # simulation start
136
141
  if self.configuration_age_start != self.configuration_age_end != 0:
137
142
  state_names, weights_bins = self.get_state_weights(
@@ -186,7 +191,7 @@ class DiseaseModel(Machine):
186
191
  ##################################
187
192
 
188
193
  def adjust_cause_specific_mortality_rate(self, index, rate):
189
- return rate + self.cause_specific_mortality_rate(index)
194
+ return rate + self.lookup_tables["cause_specific_mortality_rate"](index)
190
195
 
191
196
  ####################
192
197
  # Helper functions #
@@ -198,18 +203,15 @@ class DiseaseModel(Machine):
198
203
  raise DiseaseModelError("Disease model must have exactly one SusceptibleState.")
199
204
  return susceptible_states[0].state_id
200
205
 
201
- def get_state_weights(self, pop_index, prevalence_type):
202
- states = [
203
- s
204
- for s in self.states
205
- if hasattr(s, f"{prevalence_type}")
206
- and getattr(s, f"{prevalence_type}") is not None
207
- ]
206
+ def get_state_weights(
207
+ self, pop_index: pd.Index, prevalence_type: str
208
+ ) -> Tuple[List[str], Union[np.ndarray, None]]:
209
+ states = [state for state in self.states if state.lookup_tables.get(prevalence_type)]
208
210
 
209
211
  if not states:
210
212
  return states, None
211
213
 
212
- weights = [getattr(s, f"{prevalence_type}")(pop_index) for s in states]
214
+ weights = [state.lookup_tables.get(prevalence_type)(pop_index) for state in states]
213
215
  for w in weights:
214
216
  w.reset_index(inplace=True, drop=True)
215
217
  weights += ((1 - np.sum(weights, axis=0)),)
@@ -7,6 +7,7 @@ This module contains a collection of frequently used parameterizations of
7
7
  disease models.
8
8
 
9
9
  """
10
+
10
11
  import pandas as pd
11
12
 
12
13
  from vivarium_public_health.disease.model import DiseaseModel
@@ -6,6 +6,7 @@
6
6
  This module contains frequently used, but non-standard disease models.
7
7
 
8
8
  """
9
+
9
10
  import re
10
11
  from collections import namedtuple
11
12
  from operator import gt, lt
@@ -13,12 +14,13 @@ from typing import Any, Dict, List, Optional
13
14
 
14
15
  import pandas as pd
15
16
  from vivarium import Component
17
+ from vivarium.framework.engine import Builder
16
18
  from vivarium.framework.event import Event
17
19
  from vivarium.framework.population import SimulantData
18
20
  from vivarium.framework.values import list_combiner, union_post_processor
19
21
 
20
22
  from vivarium_public_health.disease.transition import TransitionString
21
- from vivarium_public_health.utilities import EntityString, is_non_zero
23
+ from vivarium_public_health.utilities import EntityString, get_lookup_columns, is_non_zero
22
24
 
23
25
 
24
26
  class RiskAttributableDisease(Component):
@@ -85,25 +87,29 @@ class RiskAttributableDisease(Component):
85
87
  recoverable : True
86
88
  """
87
89
 
88
- CONFIGURATION_DEFAULTS = {
89
- "risk_attributable_disease": {
90
- "threshold": None,
91
- "mortality": True,
92
- "recoverable": True,
93
- }
94
- }
95
-
96
90
  ##############
97
91
  # Properties #
98
92
  ##############
99
93
 
100
94
  @property
101
95
  def name(self):
102
- return f"disease_model.{self.cause.name}"
96
+ return f"risk_attributable_disease.{self.cause.name}"
103
97
 
104
98
  @property
105
99
  def configuration_defaults(self) -> Dict[str, Any]:
106
- return {self.cause.name: self.CONFIGURATION_DEFAULTS["risk_attributable_disease"]}
100
+ return {
101
+ self.name: {
102
+ "data_sources": {
103
+ "raw_disability_weight": f"{self.cause}.disability_weight",
104
+ "cause_specific_mortality_rate": "self::load_cause_specific_mortality_rate_data",
105
+ "excess_mortality_rate": "self::load_excess_mortality_rate_data",
106
+ "population_attributable_fraction": 0,
107
+ },
108
+ "threshold": None,
109
+ "mortality": True,
110
+ "recoverable": True,
111
+ }
112
+ }
107
113
 
108
114
  @property
109
115
  def columns_created(self) -> List[str]:
@@ -142,6 +148,8 @@ class RiskAttributableDisease(Component):
142
148
  self.cause = EntityString(cause)
143
149
  self.risk = EntityString(risk)
144
150
  self.state_column = self.cause.name
151
+ self.cause_type = "risk_attributable_disease"
152
+ self.model = self.risk.name
145
153
  self.state_id = self.cause.name
146
154
  self.diseased_event_time_column = f"{self.cause.name}_event_time"
147
155
  self.susceptible_event_time_column = f"susceptible_to_{self.cause.name}_event_time"
@@ -157,51 +165,40 @@ class RiskAttributableDisease(Component):
157
165
 
158
166
  # noinspection PyAttributeOutsideInit
159
167
  def setup(self, builder):
160
- self.recoverable = builder.configuration[self.cause.name].recoverable
168
+ self.recoverable = builder.configuration[self.name].recoverable
161
169
  self.adjust_state_and_transitions()
162
170
  self.clock = builder.time.clock()
163
171
 
164
- disability_weight_data = builder.data.load(f"{self.cause}.disability_weight")
165
- self.has_disability = is_non_zero(disability_weight_data)
166
- self.base_disability_weight = builder.lookup.build_table(
167
- disability_weight_data, key_columns=["sex"], parameter_columns=["age", "year"]
168
- )
169
172
  self.disability_weight = builder.value.register_value_producer(
170
173
  f"{self.cause.name}.disability_weight",
171
174
  source=self.compute_disability_weight,
172
- requires_columns=["age", "sex", "alive", self.cause.name],
175
+ requires_columns=get_lookup_columns(
176
+ [self.lookup_tables["raw_disability_weight"]]
177
+ ),
173
178
  )
174
179
  builder.value.register_value_modifier(
175
- "disability_weight", modifier=self.disability_weight
176
- )
177
-
178
- cause_specific_mortality_rate = self.load_cause_specific_mortality_rate_data(builder)
179
- self.cause_specific_mortality_rate = builder.lookup.build_table(
180
- cause_specific_mortality_rate,
181
- key_columns=["sex"],
182
- parameter_columns=["age", "year"],
180
+ "all_causes.disability_weight", modifier=self.disability_weight
183
181
  )
184
182
  builder.value.register_value_modifier(
185
183
  "cause_specific_mortality_rate",
186
184
  self.adjust_cause_specific_mortality_rate,
187
- requires_columns=["age", "sex"],
185
+ requires_columns=get_lookup_columns(
186
+ [self.lookup_tables["cause_specific_mortality_rate"]]
187
+ ),
188
188
  )
189
+ self.has_excess_mortality = is_non_zero(self.lookup_tables["excess_mortality_rate"])
189
190
 
190
- excess_mortality_data = self.load_excess_mortality_rate_data(builder)
191
- self.has_excess_mortality = is_non_zero(excess_mortality_data)
192
- self.base_excess_mortality_rate = builder.lookup.build_table(
193
- excess_mortality_data, key_columns=["sex"], parameter_columns=["age", "year"]
194
- )
195
191
  self.excess_mortality_rate = builder.value.register_value_producer(
196
192
  self.excess_mortality_rate_pipeline_name,
197
193
  source=self.compute_excess_mortality_rate,
198
- requires_columns=["age", "sex", "alive", self.cause.name],
194
+ requires_columns=get_lookup_columns(
195
+ [self.lookup_tables["excess_mortality_rate"]]
196
+ ),
199
197
  requires_values=[self.excess_mortality_rate_paf_pipeline_name],
200
198
  )
201
- paf = builder.lookup.build_table(0)
202
199
  self.joint_paf = builder.value.register_value_producer(
203
200
  self.excess_mortality_rate_paf_pipeline_name,
204
- source=lambda idx: [paf(idx)],
201
+ source=lambda idx: [self.lookup_tables["population_attributable_fraction"](idx)],
205
202
  preferred_combiner=list_combiner,
206
203
  preferred_post_processor=union_post_processor,
207
204
  )
@@ -213,7 +210,7 @@ class RiskAttributableDisease(Component):
213
210
 
214
211
  distribution = builder.data.load(f"{self.risk}.distribution")
215
212
  exposure_pipeline = builder.value.get_value(f"{self.risk.name}.exposure")
216
- threshold = builder.configuration[self.cause.name].threshold
213
+ threshold = builder.configuration[self.name].threshold
217
214
 
218
215
  self.filter_by_exposure = self.get_exposure_filter(
219
216
  distribution, exposure_pipeline, threshold
@@ -230,7 +227,7 @@ class RiskAttributableDisease(Component):
230
227
  )
231
228
 
232
229
  def load_cause_specific_mortality_rate_data(self, builder):
233
- if builder.configuration[self.cause.name].mortality:
230
+ if builder.configuration[self.name].mortality:
234
231
  csmr_data = builder.data.load(
235
232
  f"cause.{self.cause.name}.cause_specific_mortality_rate"
236
233
  )
@@ -239,7 +236,7 @@ class RiskAttributableDisease(Component):
239
236
  return csmr_data
240
237
 
241
238
  def load_excess_mortality_rate_data(self, builder):
242
- if builder.configuration[self.cause.name].mortality:
239
+ if builder.configuration[self.name].mortality:
243
240
  emr_data = builder.data.load(f"cause.{self.cause.name}.excess_mortality_rate")
244
241
  else:
245
242
  emr_data = 0
@@ -331,13 +328,15 @@ class RiskAttributableDisease(Component):
331
328
  def compute_disability_weight(self, index):
332
329
  disability_weight = pd.Series(0.0, index=index)
333
330
  with_condition = self.with_condition(index)
334
- disability_weight.loc[with_condition] = self.base_disability_weight(with_condition)
331
+ disability_weight.loc[with_condition] = self.lookup_tables["raw_disability_weight"](
332
+ with_condition
333
+ )
335
334
  return disability_weight
336
335
 
337
336
  def compute_excess_mortality_rate(self, index):
338
337
  excess_mortality_rate = pd.Series(0.0, index=index)
339
338
  with_condition = self.with_condition(index)
340
- base_excess_mort = self.base_excess_mortality_rate(with_condition)
339
+ base_excess_mort = self.lookup_tables["excess_mortality_rate"](with_condition)
341
340
  joint_mediated_paf = self.joint_paf(with_condition)
342
341
  excess_mortality_rate.loc[with_condition] = base_excess_mort * (
343
342
  1 - joint_mediated_paf.values
@@ -345,7 +344,7 @@ class RiskAttributableDisease(Component):
345
344
  return excess_mortality_rate
346
345
 
347
346
  def adjust_cause_specific_mortality_rate(self, index, rate):
348
- return rate + self.cause_specific_mortality_rate(index)
347
+ return rate + self.lookup_tables["cause_specific_mortality_rate"](index)
349
348
 
350
349
  def adjust_mortality_rate(self, index, rates_df):
351
350
  """Modifies the baseline mortality rate for a simulant if they are in this state.
@@ -6,12 +6,13 @@ Disease States
6
6
  This module contains tools to manage standard disease states.
7
7
 
8
8
  """
9
+
9
10
  from typing import Any, Callable, Dict, List, Optional
10
11
 
11
12
  import numpy as np
12
13
  import pandas as pd
13
14
  from vivarium.framework.engine import Builder
14
- from vivarium.framework.lookup import LookupTable, LookupTableData
15
+ from vivarium.framework.lookup import LookupTableData
15
16
  from vivarium.framework.population import PopulationView, SimulantData
16
17
  from vivarium.framework.randomness import RandomnessStream
17
18
  from vivarium.framework.state_machine import State, Transient, Transition, Trigger
@@ -22,7 +23,7 @@ from vivarium_public_health.disease.transition import (
22
23
  RateTransition,
23
24
  TransitionString,
24
25
  )
25
- from vivarium_public_health.utilities import is_non_zero
26
+ from vivarium_public_health.utilities import get_lookup_columns, is_non_zero
26
27
 
27
28
 
28
29
  class BaseDiseaseState(State):
@@ -273,6 +274,24 @@ class RecoveredState(NonDiseasedState):
273
274
  class DiseaseState(BaseDiseaseState):
274
275
  """State representing a disease in a state machine model."""
275
276
 
277
+ ##############
278
+ # Properties #
279
+ ##############
280
+
281
+ @property
282
+ def configuration_defaults(self) -> Dict[str, Any]:
283
+ return {
284
+ f"{self.name}": {
285
+ "data_sources": {
286
+ "prevalence": "self::load_prevalence",
287
+ "birth_prevalence": "self::load_birth_prevalence",
288
+ "dwell_time": "self::load_dwell_time",
289
+ "disability_weight": "self::load_disability_weight",
290
+ "excess_mortality_rate": "self::load_excess_mortality_rate",
291
+ },
292
+ },
293
+ }
294
+
276
295
  #####################
277
296
  # Lifecycle methods #
278
297
  #####################
@@ -331,31 +350,15 @@ class DiseaseState(BaseDiseaseState):
331
350
  super().setup(builder)
332
351
  self.clock = builder.time.clock()
333
352
 
334
- prevalence_data = self.load_prevalence_data(builder)
335
- self.prevalence = self.get_prevalence(builder, prevalence_data)
336
-
337
- birth_prevalence_data = self.load_birth_prevalence_data(builder)
338
- self.birth_prevalence = self.get_birth_prevalence(builder, birth_prevalence_data)
339
-
340
- dwell_time_data = self.load_dwell_time_data(builder)
341
- self.dwell_time = self.get_dwell_time_pipeline(builder, dwell_time_data)
342
-
343
- disability_weight_data = self.load_disability_weight_data(builder)
344
- self.has_disability = is_non_zero(disability_weight_data)
345
- self.base_disability_weight = self.get_base_disability_weight(
346
- builder, disability_weight_data
347
- )
348
-
353
+ self.dwell_time = self.get_dwell_time_pipeline(builder)
349
354
  self.disability_weight = self.get_disability_weight_pipeline(builder)
350
355
 
351
356
  builder.value.register_value_modifier(
352
- "disability_weight", modifier=self.disability_weight
357
+ "all_causes.disability_weight", modifier=self.disability_weight
353
358
  )
354
359
 
355
- excess_mortality_data = self.load_excess_mortality_rate_data(builder)
356
- self.has_excess_mortality = is_non_zero(excess_mortality_data)
357
- self.base_excess_mortality_rate = self.get_base_excess_mortality_rate(
358
- builder, excess_mortality_data
360
+ self.has_excess_mortality = is_non_zero(
361
+ self.lookup_tables["excess_mortality_rate"].data
359
362
  )
360
363
  self.excess_mortality_rate = self.get_excess_mortality_rate_pipeline(builder)
361
364
  self.joint_paf = self.get_joint_paf(builder)
@@ -372,62 +375,19 @@ class DiseaseState(BaseDiseaseState):
372
375
  # Setup methods #
373
376
  #################
374
377
 
375
- def load_prevalence_data(self, builder: Builder) -> LookupTableData:
378
+ def load_prevalence(self, builder: Builder) -> LookupTableData:
376
379
  if "prevalence" in self._get_data_functions:
377
380
  return self._get_data_functions["prevalence"](builder, self.state_id)
378
381
  else:
379
382
  return builder.data.load(f"{self.cause_type}.{self.state_id}.prevalence")
380
383
 
381
- def get_prevalence(
382
- self, builder: Builder, prevalence_data: LookupTableData
383
- ) -> LookupTable:
384
- """Builds a LookupTable for the prevalence of this state.
385
-
386
- Parameters
387
- ----------
388
- builder
389
- Interface to access simulation managers.
390
- prevalence_data
391
- The data to use to build the LookupTable.
392
-
393
- Returns
394
- -------
395
- LookupTable
396
- The LookupTable for the prevalence of this state.
397
- """
398
- return builder.lookup.build_table(
399
- prevalence_data, key_columns=["sex"], parameter_columns=["age", "year"]
400
- )
401
-
402
- def load_birth_prevalence_data(self, builder: Builder) -> LookupTableData:
384
+ def load_birth_prevalence(self, builder: Builder) -> LookupTableData:
403
385
  if "birth_prevalence" in self._get_data_functions:
404
386
  return self._get_data_functions["birth_prevalence"](builder, self.state_id)
405
387
  else:
406
388
  return 0
407
389
 
408
- def get_birth_prevalence(
409
- self, builder: Builder, birth_prevalence_data: LookupTableData
410
- ) -> LookupTable:
411
- """
412
- Builds a LookupTable for the birth prevalence of this state.
413
-
414
- Parameters
415
- ----------
416
- builder
417
- Interface to access simulation managers.
418
- birth_prevalence_data
419
- The data to use to build the LookupTable.
420
-
421
- Returns
422
- -------
423
- LookupTable
424
- The LookupTable for the birth prevalence of this state.
425
- """
426
- return builder.lookup.build_table(
427
- birth_prevalence_data, key_columns=["sex"], parameter_columns=["year"]
428
- )
429
-
430
- def load_dwell_time_data(self, builder: Builder) -> LookupTableData:
390
+ def load_dwell_time(self, builder: Builder) -> LookupTableData:
431
391
  if "dwell_time" in self._get_data_functions:
432
392
  dwell_time = self._get_data_functions["dwell_time"](builder, self.state_id)
433
393
  else:
@@ -442,18 +402,15 @@ class DiseaseState(BaseDiseaseState):
442
402
 
443
403
  return dwell_time
444
404
 
445
- def get_dwell_time_pipeline(
446
- self, builder: Builder, dwell_time_data: LookupTableData
447
- ) -> Pipeline:
405
+ def get_dwell_time_pipeline(self, builder: Builder) -> Pipeline:
406
+ required_columns = get_lookup_columns([self.lookup_tables["dwell_time"]])
448
407
  return builder.value.register_value_producer(
449
408
  f"{self.state_id}.dwell_time",
450
- source=builder.lookup.build_table(
451
- dwell_time_data, key_columns=["sex"], parameter_columns=["age", "year"]
452
- ),
453
- requires_columns=["age", "sex"],
409
+ source=self.lookup_tables["dwell_time"],
410
+ requires_columns=required_columns,
454
411
  )
455
412
 
456
- def load_disability_weight_data(self, builder: Builder) -> LookupTableData:
413
+ def load_disability_weight(self, builder: Builder) -> LookupTableData:
457
414
  if "disability_weight" in self._get_data_functions:
458
415
  disability_weight = self._get_data_functions["disability_weight"](
459
416
  builder, self.state_id
@@ -468,36 +425,15 @@ class DiseaseState(BaseDiseaseState):
468
425
 
469
426
  return disability_weight
470
427
 
471
- def get_base_disability_weight(
472
- self, builder: Builder, disability_weight_data: LookupTableData
473
- ) -> LookupTable:
474
- """
475
- Builds a LookupTable for the base disability weight of this state.
476
-
477
- Parameters
478
- ----------
479
- builder
480
- Interface to access simulation managers.
481
- disability_weight_data
482
- The data to use to build the LookupTable.
483
-
484
- Returns
485
- -------
486
- LookupTable
487
- The LookupTable for the disability weight of this state.
488
- """
489
- return builder.lookup.build_table(
490
- disability_weight_data, key_columns=["sex"], parameter_columns=["age", "year"]
491
- )
492
-
493
428
  def get_disability_weight_pipeline(self, builder: Builder) -> Pipeline:
429
+ lookup_columns = get_lookup_columns([self.lookup_tables["disability_weight"]])
494
430
  return builder.value.register_value_producer(
495
431
  f"{self.state_id}.disability_weight",
496
432
  source=self.compute_disability_weight,
497
- requires_columns=["age", "sex", "alive", self.model],
433
+ requires_columns=lookup_columns + ["alive", self.model],
498
434
  )
499
435
 
500
- def load_excess_mortality_rate_data(self, builder: Builder) -> LookupTableData:
436
+ def load_excess_mortality_rate(self, builder: Builder) -> LookupTableData:
501
437
  if "excess_mortality_rate" in self._get_data_functions:
502
438
  return self._get_data_functions["excess_mortality_rate"](builder, self.state_id)
503
439
  elif builder.data.load(f"cause.{self.model}.restrictions")["yld_only"]:
@@ -507,33 +443,12 @@ class DiseaseState(BaseDiseaseState):
507
443
  f"{self.cause_type}.{self.state_id}.excess_mortality_rate"
508
444
  )
509
445
 
510
- def get_base_excess_mortality_rate(
511
- self, builder: Builder, excess_mortality_data: LookupTableData
512
- ) -> LookupTable:
513
- """
514
- Builds a LookupTable for the base excess mortality rate of this state.
515
-
516
- Parameters
517
- ----------
518
- builder
519
- Interface to access simulation managers.
520
- excess_mortality_data
521
- The data to use to build the LookupTable.
522
-
523
- Returns
524
- -------
525
- LookupTable
526
- The LookupTable for the base excess mortality rate of this state.
527
- """
528
- return builder.lookup.build_table(
529
- excess_mortality_data, key_columns=["sex"], parameter_columns=["age", "year"]
530
- )
531
-
532
446
  def get_excess_mortality_rate_pipeline(self, builder: Builder) -> Pipeline:
447
+ lookup_columns = get_lookup_columns([self.lookup_tables["excess_mortality_rate"]])
533
448
  return builder.value.register_rate_producer(
534
449
  self.excess_mortality_rate_pipeline_name,
535
450
  source=self.compute_excess_mortality_rate,
536
- requires_columns=["age", "sex", "alive", self.model],
451
+ requires_columns=lookup_columns + ["alive", self.model],
537
452
  requires_values=[self.excess_mortality_rate_paf_pipeline_name],
538
453
  )
539
454
 
@@ -618,13 +533,15 @@ class DiseaseState(BaseDiseaseState):
618
533
  """
619
534
  disability_weight = pd.Series(0.0, index=index)
620
535
  with_condition = self.with_condition(index)
621
- disability_weight.loc[with_condition] = self.base_disability_weight(with_condition)
536
+ disability_weight.loc[with_condition] = self.lookup_tables["disability_weight"](
537
+ with_condition
538
+ )
622
539
  return disability_weight
623
540
 
624
541
  def compute_excess_mortality_rate(self, index: pd.Index) -> pd.Series:
625
542
  excess_mortality_rate = pd.Series(0.0, index=index)
626
543
  with_condition = self.with_condition(index)
627
- base_excess_mort = self.base_excess_mortality_rate(with_condition)
544
+ base_excess_mort = self.lookup_tables["excess_mortality_rate"](with_condition)
628
545
  joint_mediated_paf = self.joint_paf(with_condition)
629
546
  excess_mortality_rate.loc[with_condition] = base_excess_mort * (
630
547
  1 - joint_mediated_paf.values