vivarium-public-health 2.3.2__py3-none-any.whl → 3.0.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (48) hide show
  1. vivarium_public_health/_version.py +1 -1
  2. vivarium_public_health/disease/model.py +23 -21
  3. vivarium_public_health/disease/models.py +1 -0
  4. vivarium_public_health/disease/special_disease.py +40 -41
  5. vivarium_public_health/disease/state.py +42 -125
  6. vivarium_public_health/disease/transition.py +70 -27
  7. vivarium_public_health/mslt/delay.py +1 -0
  8. vivarium_public_health/mslt/disease.py +1 -0
  9. vivarium_public_health/mslt/intervention.py +1 -0
  10. vivarium_public_health/mslt/magic_wand_components.py +1 -0
  11. vivarium_public_health/mslt/observer.py +1 -0
  12. vivarium_public_health/mslt/population.py +1 -0
  13. vivarium_public_health/plugins/parser.py +61 -31
  14. vivarium_public_health/population/add_new_birth_cohorts.py +2 -3
  15. vivarium_public_health/population/base_population.py +2 -1
  16. vivarium_public_health/population/mortality.py +83 -80
  17. vivarium_public_health/{metrics → results}/__init__.py +2 -0
  18. vivarium_public_health/results/columns.py +22 -0
  19. vivarium_public_health/results/disability.py +187 -0
  20. vivarium_public_health/results/disease.py +222 -0
  21. vivarium_public_health/results/mortality.py +186 -0
  22. vivarium_public_health/results/observer.py +78 -0
  23. vivarium_public_health/results/risk.py +138 -0
  24. vivarium_public_health/results/simple_cause.py +18 -0
  25. vivarium_public_health/{metrics → results}/stratification.py +10 -8
  26. vivarium_public_health/risks/__init__.py +1 -2
  27. vivarium_public_health/risks/base_risk.py +134 -29
  28. vivarium_public_health/risks/data_transformations.py +65 -326
  29. vivarium_public_health/risks/distributions.py +315 -145
  30. vivarium_public_health/risks/effect.py +376 -75
  31. vivarium_public_health/risks/implementations/low_birth_weight_and_short_gestation.py +61 -89
  32. vivarium_public_health/treatment/magic_wand.py +1 -0
  33. vivarium_public_health/treatment/scale_up.py +1 -0
  34. vivarium_public_health/treatment/therapeutic_inertia.py +1 -0
  35. vivarium_public_health/utilities.py +17 -2
  36. {vivarium_public_health-2.3.2.dist-info → vivarium_public_health-3.0.0.dist-info}/METADATA +13 -3
  37. vivarium_public_health-3.0.0.dist-info/RECORD +49 -0
  38. {vivarium_public_health-2.3.2.dist-info → vivarium_public_health-3.0.0.dist-info}/WHEEL +1 -1
  39. vivarium_public_health/metrics/disability.py +0 -118
  40. vivarium_public_health/metrics/disease.py +0 -136
  41. vivarium_public_health/metrics/mortality.py +0 -144
  42. vivarium_public_health/metrics/risk.py +0 -110
  43. vivarium_public_health/testing/__init__.py +0 -0
  44. vivarium_public_health/testing/mock_artifact.py +0 -145
  45. vivarium_public_health/testing/utils.py +0 -71
  46. vivarium_public_health-2.3.2.dist-info/RECORD +0 -49
  47. {vivarium_public_health-2.3.2.dist-info → vivarium_public_health-3.0.0.dist-info}/LICENSE.txt +0 -0
  48. {vivarium_public_health-2.3.2.dist-info → vivarium_public_health-3.0.0.dist-info}/top_level.txt +0 -0
@@ -1 +1 @@
1
- __version__ = "2.3.2"
1
+ __version__ = "3.0.0"
@@ -8,7 +8,8 @@ function is to provide coordination across a set of disease states and
8
8
  transitions at simulation initialization and during transitions.
9
9
 
10
10
  """
11
- from typing import Callable, Dict, List, Optional
11
+
12
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
12
13
 
13
14
  import numpy as np
14
15
  import pandas as pd
@@ -27,10 +28,21 @@ class DiseaseModelError(VivariumError):
27
28
 
28
29
 
29
30
  class DiseaseModel(Machine):
31
+
30
32
  ##############
31
33
  # Properties #
32
34
  ##############
33
35
 
36
+ @property
37
+ def configuration_defaults(self) -> Dict[str, Any]:
38
+ return {
39
+ f"{self.name}": {
40
+ "data_sources": {
41
+ "cause_specific_mortality_rate": "self::load_cause_specific_mortality_rate"
42
+ },
43
+ },
44
+ }
45
+
34
46
  @property
35
47
  def columns_created(self) -> List[str]:
36
48
  return [self.state_column]
@@ -88,13 +100,6 @@ class DiseaseModel(Machine):
88
100
 
89
101
  self.configuration_age_start = builder.configuration.population.initialization_age_min
90
102
  self.configuration_age_end = builder.configuration.population.initialization_age_max
91
-
92
- cause_specific_mortality_rate = self.load_cause_specific_mortality_rate_data(builder)
93
- self.cause_specific_mortality_rate = builder.lookup.build_table(
94
- cause_specific_mortality_rate,
95
- key_columns=["sex"],
96
- parameter_columns=["age", "year"],
97
- )
98
103
  builder.value.register_value_modifier(
99
104
  "cause_specific_mortality_rate",
100
105
  self.adjust_cause_specific_mortality_rate,
@@ -106,11 +111,13 @@ class DiseaseModel(Machine):
106
111
  # Setup methods #
107
112
  #################
108
113
 
109
- def load_cause_specific_mortality_rate_data(self, builder):
114
+ def load_cause_specific_mortality_rate(
115
+ self, builder: Builder
116
+ ) -> Union[float, pd.DataFrame]:
110
117
  if "cause_specific_mortality_rate" not in self._get_data_functions:
111
118
  only_morbid = builder.data.load(f"cause.{self.cause}.restrictions")["yld_only"]
112
119
  if only_morbid:
113
- csmr_data = 0
120
+ csmr_data = 0.0
114
121
  else:
115
122
  csmr_data = builder.data.load(
116
123
  f"{self.cause_type}.{self.cause}.cause_specific_mortality_rate"
@@ -130,8 +137,6 @@ class DiseaseModel(Machine):
130
137
 
131
138
  assert self.initial_state in {s.state_id for s in self.states}
132
139
 
133
- # FIXME: this is a hack to figure out whether or not we're at the simulation start based on the fact that the
134
- # fertility components create this user data
135
140
  if pop_data.user_data["sim_state"] == "setup": # simulation start
136
141
  if self.configuration_age_start != self.configuration_age_end != 0:
137
142
  state_names, weights_bins = self.get_state_weights(
@@ -186,7 +191,7 @@ class DiseaseModel(Machine):
186
191
  ##################################
187
192
 
188
193
  def adjust_cause_specific_mortality_rate(self, index, rate):
189
- return rate + self.cause_specific_mortality_rate(index)
194
+ return rate + self.lookup_tables["cause_specific_mortality_rate"](index)
190
195
 
191
196
  ####################
192
197
  # Helper functions #
@@ -198,18 +203,15 @@ class DiseaseModel(Machine):
198
203
  raise DiseaseModelError("Disease model must have exactly one SusceptibleState.")
199
204
  return susceptible_states[0].state_id
200
205
 
201
- def get_state_weights(self, pop_index, prevalence_type):
202
- states = [
203
- s
204
- for s in self.states
205
- if hasattr(s, f"{prevalence_type}")
206
- and getattr(s, f"{prevalence_type}") is not None
207
- ]
206
+ def get_state_weights(
207
+ self, pop_index: pd.Index, prevalence_type: str
208
+ ) -> Tuple[List[str], Union[np.ndarray, None]]:
209
+ states = [state for state in self.states if state.lookup_tables.get(prevalence_type)]
208
210
 
209
211
  if not states:
210
212
  return states, None
211
213
 
212
- weights = [getattr(s, f"{prevalence_type}")(pop_index) for s in states]
214
+ weights = [state.lookup_tables.get(prevalence_type)(pop_index) for state in states]
213
215
  for w in weights:
214
216
  w.reset_index(inplace=True, drop=True)
215
217
  weights += ((1 - np.sum(weights, axis=0)),)
@@ -7,6 +7,7 @@ This module contains a collection of frequently used parameterizations of
7
7
  disease models.
8
8
 
9
9
  """
10
+
10
11
  import pandas as pd
11
12
 
12
13
  from vivarium_public_health.disease.model import DiseaseModel
@@ -6,6 +6,7 @@
6
6
  This module contains frequently used, but non-standard disease models.
7
7
 
8
8
  """
9
+
9
10
  import re
10
11
  from collections import namedtuple
11
12
  from operator import gt, lt
@@ -13,12 +14,13 @@ from typing import Any, Dict, List, Optional
13
14
 
14
15
  import pandas as pd
15
16
  from vivarium import Component
17
+ from vivarium.framework.engine import Builder
16
18
  from vivarium.framework.event import Event
17
19
  from vivarium.framework.population import SimulantData
18
20
  from vivarium.framework.values import list_combiner, union_post_processor
19
21
 
20
22
  from vivarium_public_health.disease.transition import TransitionString
21
- from vivarium_public_health.utilities import EntityString, is_non_zero
23
+ from vivarium_public_health.utilities import EntityString, get_lookup_columns, is_non_zero
22
24
 
23
25
 
24
26
  class RiskAttributableDisease(Component):
@@ -85,25 +87,29 @@ class RiskAttributableDisease(Component):
85
87
  recoverable : True
86
88
  """
87
89
 
88
- CONFIGURATION_DEFAULTS = {
89
- "risk_attributable_disease": {
90
- "threshold": None,
91
- "mortality": True,
92
- "recoverable": True,
93
- }
94
- }
95
-
96
90
  ##############
97
91
  # Properties #
98
92
  ##############
99
93
 
100
94
  @property
101
95
  def name(self):
102
- return f"disease_model.{self.cause.name}"
96
+ return f"risk_attributable_disease.{self.cause.name}"
103
97
 
104
98
  @property
105
99
  def configuration_defaults(self) -> Dict[str, Any]:
106
- return {self.cause.name: self.CONFIGURATION_DEFAULTS["risk_attributable_disease"]}
100
+ return {
101
+ self.name: {
102
+ "data_sources": {
103
+ "raw_disability_weight": f"{self.cause}.disability_weight",
104
+ "cause_specific_mortality_rate": "self::load_cause_specific_mortality_rate_data",
105
+ "excess_mortality_rate": "self::load_excess_mortality_rate_data",
106
+ "population_attributable_fraction": 0,
107
+ },
108
+ "threshold": None,
109
+ "mortality": True,
110
+ "recoverable": True,
111
+ }
112
+ }
107
113
 
108
114
  @property
109
115
  def columns_created(self) -> List[str]:
@@ -142,6 +148,8 @@ class RiskAttributableDisease(Component):
142
148
  self.cause = EntityString(cause)
143
149
  self.risk = EntityString(risk)
144
150
  self.state_column = self.cause.name
151
+ self.cause_type = "risk_attributable_disease"
152
+ self.model = self.risk.name
145
153
  self.state_id = self.cause.name
146
154
  self.diseased_event_time_column = f"{self.cause.name}_event_time"
147
155
  self.susceptible_event_time_column = f"susceptible_to_{self.cause.name}_event_time"
@@ -157,51 +165,40 @@ class RiskAttributableDisease(Component):
157
165
 
158
166
  # noinspection PyAttributeOutsideInit
159
167
  def setup(self, builder):
160
- self.recoverable = builder.configuration[self.cause.name].recoverable
168
+ self.recoverable = builder.configuration[self.name].recoverable
161
169
  self.adjust_state_and_transitions()
162
170
  self.clock = builder.time.clock()
163
171
 
164
- disability_weight_data = builder.data.load(f"{self.cause}.disability_weight")
165
- self.has_disability = is_non_zero(disability_weight_data)
166
- self.base_disability_weight = builder.lookup.build_table(
167
- disability_weight_data, key_columns=["sex"], parameter_columns=["age", "year"]
168
- )
169
172
  self.disability_weight = builder.value.register_value_producer(
170
173
  f"{self.cause.name}.disability_weight",
171
174
  source=self.compute_disability_weight,
172
- requires_columns=["age", "sex", "alive", self.cause.name],
175
+ requires_columns=get_lookup_columns(
176
+ [self.lookup_tables["raw_disability_weight"]]
177
+ ),
173
178
  )
174
179
  builder.value.register_value_modifier(
175
- "disability_weight", modifier=self.disability_weight
176
- )
177
-
178
- cause_specific_mortality_rate = self.load_cause_specific_mortality_rate_data(builder)
179
- self.cause_specific_mortality_rate = builder.lookup.build_table(
180
- cause_specific_mortality_rate,
181
- key_columns=["sex"],
182
- parameter_columns=["age", "year"],
180
+ "all_causes.disability_weight", modifier=self.disability_weight
183
181
  )
184
182
  builder.value.register_value_modifier(
185
183
  "cause_specific_mortality_rate",
186
184
  self.adjust_cause_specific_mortality_rate,
187
- requires_columns=["age", "sex"],
185
+ requires_columns=get_lookup_columns(
186
+ [self.lookup_tables["cause_specific_mortality_rate"]]
187
+ ),
188
188
  )
189
+ self.has_excess_mortality = is_non_zero(self.lookup_tables["excess_mortality_rate"])
189
190
 
190
- excess_mortality_data = self.load_excess_mortality_rate_data(builder)
191
- self.has_excess_mortality = is_non_zero(excess_mortality_data)
192
- self.base_excess_mortality_rate = builder.lookup.build_table(
193
- excess_mortality_data, key_columns=["sex"], parameter_columns=["age", "year"]
194
- )
195
191
  self.excess_mortality_rate = builder.value.register_value_producer(
196
192
  self.excess_mortality_rate_pipeline_name,
197
193
  source=self.compute_excess_mortality_rate,
198
- requires_columns=["age", "sex", "alive", self.cause.name],
194
+ requires_columns=get_lookup_columns(
195
+ [self.lookup_tables["excess_mortality_rate"]]
196
+ ),
199
197
  requires_values=[self.excess_mortality_rate_paf_pipeline_name],
200
198
  )
201
- paf = builder.lookup.build_table(0)
202
199
  self.joint_paf = builder.value.register_value_producer(
203
200
  self.excess_mortality_rate_paf_pipeline_name,
204
- source=lambda idx: [paf(idx)],
201
+ source=lambda idx: [self.lookup_tables["population_attributable_fraction"](idx)],
205
202
  preferred_combiner=list_combiner,
206
203
  preferred_post_processor=union_post_processor,
207
204
  )
@@ -213,7 +210,7 @@ class RiskAttributableDisease(Component):
213
210
 
214
211
  distribution = builder.data.load(f"{self.risk}.distribution")
215
212
  exposure_pipeline = builder.value.get_value(f"{self.risk.name}.exposure")
216
- threshold = builder.configuration[self.cause.name].threshold
213
+ threshold = builder.configuration[self.name].threshold
217
214
 
218
215
  self.filter_by_exposure = self.get_exposure_filter(
219
216
  distribution, exposure_pipeline, threshold
@@ -230,7 +227,7 @@ class RiskAttributableDisease(Component):
230
227
  )
231
228
 
232
229
  def load_cause_specific_mortality_rate_data(self, builder):
233
- if builder.configuration[self.cause.name].mortality:
230
+ if builder.configuration[self.name].mortality:
234
231
  csmr_data = builder.data.load(
235
232
  f"cause.{self.cause.name}.cause_specific_mortality_rate"
236
233
  )
@@ -239,7 +236,7 @@ class RiskAttributableDisease(Component):
239
236
  return csmr_data
240
237
 
241
238
  def load_excess_mortality_rate_data(self, builder):
242
- if builder.configuration[self.cause.name].mortality:
239
+ if builder.configuration[self.name].mortality:
243
240
  emr_data = builder.data.load(f"cause.{self.cause.name}.excess_mortality_rate")
244
241
  else:
245
242
  emr_data = 0
@@ -331,13 +328,15 @@ class RiskAttributableDisease(Component):
331
328
  def compute_disability_weight(self, index):
332
329
  disability_weight = pd.Series(0.0, index=index)
333
330
  with_condition = self.with_condition(index)
334
- disability_weight.loc[with_condition] = self.base_disability_weight(with_condition)
331
+ disability_weight.loc[with_condition] = self.lookup_tables["raw_disability_weight"](
332
+ with_condition
333
+ )
335
334
  return disability_weight
336
335
 
337
336
  def compute_excess_mortality_rate(self, index):
338
337
  excess_mortality_rate = pd.Series(0.0, index=index)
339
338
  with_condition = self.with_condition(index)
340
- base_excess_mort = self.base_excess_mortality_rate(with_condition)
339
+ base_excess_mort = self.lookup_tables["excess_mortality_rate"](with_condition)
341
340
  joint_mediated_paf = self.joint_paf(with_condition)
342
341
  excess_mortality_rate.loc[with_condition] = base_excess_mort * (
343
342
  1 - joint_mediated_paf.values
@@ -345,7 +344,7 @@ class RiskAttributableDisease(Component):
345
344
  return excess_mortality_rate
346
345
 
347
346
  def adjust_cause_specific_mortality_rate(self, index, rate):
348
- return rate + self.cause_specific_mortality_rate(index)
347
+ return rate + self.lookup_tables["cause_specific_mortality_rate"](index)
349
348
 
350
349
  def adjust_mortality_rate(self, index, rates_df):
351
350
  """Modifies the baseline mortality rate for a simulant if they are in this state.
@@ -6,12 +6,13 @@ Disease States
6
6
  This module contains tools to manage standard disease states.
7
7
 
8
8
  """
9
+
9
10
  from typing import Any, Callable, Dict, List, Optional
10
11
 
11
12
  import numpy as np
12
13
  import pandas as pd
13
14
  from vivarium.framework.engine import Builder
14
- from vivarium.framework.lookup import LookupTable, LookupTableData
15
+ from vivarium.framework.lookup import LookupTableData
15
16
  from vivarium.framework.population import PopulationView, SimulantData
16
17
  from vivarium.framework.randomness import RandomnessStream
17
18
  from vivarium.framework.state_machine import State, Transient, Transition, Trigger
@@ -22,7 +23,7 @@ from vivarium_public_health.disease.transition import (
22
23
  RateTransition,
23
24
  TransitionString,
24
25
  )
25
- from vivarium_public_health.utilities import is_non_zero
26
+ from vivarium_public_health.utilities import get_lookup_columns, is_non_zero
26
27
 
27
28
 
28
29
  class BaseDiseaseState(State):
@@ -273,6 +274,24 @@ class RecoveredState(NonDiseasedState):
273
274
  class DiseaseState(BaseDiseaseState):
274
275
  """State representing a disease in a state machine model."""
275
276
 
277
+ ##############
278
+ # Properties #
279
+ ##############
280
+
281
+ @property
282
+ def configuration_defaults(self) -> Dict[str, Any]:
283
+ return {
284
+ f"{self.name}": {
285
+ "data_sources": {
286
+ "prevalence": "self::load_prevalence",
287
+ "birth_prevalence": "self::load_birth_prevalence",
288
+ "dwell_time": "self::load_dwell_time",
289
+ "disability_weight": "self::load_disability_weight",
290
+ "excess_mortality_rate": "self::load_excess_mortality_rate",
291
+ },
292
+ },
293
+ }
294
+
276
295
  #####################
277
296
  # Lifecycle methods #
278
297
  #####################
@@ -331,31 +350,15 @@ class DiseaseState(BaseDiseaseState):
331
350
  super().setup(builder)
332
351
  self.clock = builder.time.clock()
333
352
 
334
- prevalence_data = self.load_prevalence_data(builder)
335
- self.prevalence = self.get_prevalence(builder, prevalence_data)
336
-
337
- birth_prevalence_data = self.load_birth_prevalence_data(builder)
338
- self.birth_prevalence = self.get_birth_prevalence(builder, birth_prevalence_data)
339
-
340
- dwell_time_data = self.load_dwell_time_data(builder)
341
- self.dwell_time = self.get_dwell_time_pipeline(builder, dwell_time_data)
342
-
343
- disability_weight_data = self.load_disability_weight_data(builder)
344
- self.has_disability = is_non_zero(disability_weight_data)
345
- self.base_disability_weight = self.get_base_disability_weight(
346
- builder, disability_weight_data
347
- )
348
-
353
+ self.dwell_time = self.get_dwell_time_pipeline(builder)
349
354
  self.disability_weight = self.get_disability_weight_pipeline(builder)
350
355
 
351
356
  builder.value.register_value_modifier(
352
- "disability_weight", modifier=self.disability_weight
357
+ "all_causes.disability_weight", modifier=self.disability_weight
353
358
  )
354
359
 
355
- excess_mortality_data = self.load_excess_mortality_rate_data(builder)
356
- self.has_excess_mortality = is_non_zero(excess_mortality_data)
357
- self.base_excess_mortality_rate = self.get_base_excess_mortality_rate(
358
- builder, excess_mortality_data
360
+ self.has_excess_mortality = is_non_zero(
361
+ self.lookup_tables["excess_mortality_rate"].data
359
362
  )
360
363
  self.excess_mortality_rate = self.get_excess_mortality_rate_pipeline(builder)
361
364
  self.joint_paf = self.get_joint_paf(builder)
@@ -372,62 +375,19 @@ class DiseaseState(BaseDiseaseState):
372
375
  # Setup methods #
373
376
  #################
374
377
 
375
- def load_prevalence_data(self, builder: Builder) -> LookupTableData:
378
+ def load_prevalence(self, builder: Builder) -> LookupTableData:
376
379
  if "prevalence" in self._get_data_functions:
377
380
  return self._get_data_functions["prevalence"](builder, self.state_id)
378
381
  else:
379
382
  return builder.data.load(f"{self.cause_type}.{self.state_id}.prevalence")
380
383
 
381
- def get_prevalence(
382
- self, builder: Builder, prevalence_data: LookupTableData
383
- ) -> LookupTable:
384
- """Builds a LookupTable for the prevalence of this state.
385
-
386
- Parameters
387
- ----------
388
- builder
389
- Interface to access simulation managers.
390
- prevalence_data
391
- The data to use to build the LookupTable.
392
-
393
- Returns
394
- -------
395
- LookupTable
396
- The LookupTable for the prevalence of this state.
397
- """
398
- return builder.lookup.build_table(
399
- prevalence_data, key_columns=["sex"], parameter_columns=["age", "year"]
400
- )
401
-
402
- def load_birth_prevalence_data(self, builder: Builder) -> LookupTableData:
384
+ def load_birth_prevalence(self, builder: Builder) -> LookupTableData:
403
385
  if "birth_prevalence" in self._get_data_functions:
404
386
  return self._get_data_functions["birth_prevalence"](builder, self.state_id)
405
387
  else:
406
388
  return 0
407
389
 
408
- def get_birth_prevalence(
409
- self, builder: Builder, birth_prevalence_data: LookupTableData
410
- ) -> LookupTable:
411
- """
412
- Builds a LookupTable for the birth prevalence of this state.
413
-
414
- Parameters
415
- ----------
416
- builder
417
- Interface to access simulation managers.
418
- birth_prevalence_data
419
- The data to use to build the LookupTable.
420
-
421
- Returns
422
- -------
423
- LookupTable
424
- The LookupTable for the birth prevalence of this state.
425
- """
426
- return builder.lookup.build_table(
427
- birth_prevalence_data, key_columns=["sex"], parameter_columns=["year"]
428
- )
429
-
430
- def load_dwell_time_data(self, builder: Builder) -> LookupTableData:
390
+ def load_dwell_time(self, builder: Builder) -> LookupTableData:
431
391
  if "dwell_time" in self._get_data_functions:
432
392
  dwell_time = self._get_data_functions["dwell_time"](builder, self.state_id)
433
393
  else:
@@ -442,18 +402,15 @@ class DiseaseState(BaseDiseaseState):
442
402
 
443
403
  return dwell_time
444
404
 
445
- def get_dwell_time_pipeline(
446
- self, builder: Builder, dwell_time_data: LookupTableData
447
- ) -> Pipeline:
405
+ def get_dwell_time_pipeline(self, builder: Builder) -> Pipeline:
406
+ required_columns = get_lookup_columns([self.lookup_tables["dwell_time"]])
448
407
  return builder.value.register_value_producer(
449
408
  f"{self.state_id}.dwell_time",
450
- source=builder.lookup.build_table(
451
- dwell_time_data, key_columns=["sex"], parameter_columns=["age", "year"]
452
- ),
453
- requires_columns=["age", "sex"],
409
+ source=self.lookup_tables["dwell_time"],
410
+ requires_columns=required_columns,
454
411
  )
455
412
 
456
- def load_disability_weight_data(self, builder: Builder) -> LookupTableData:
413
+ def load_disability_weight(self, builder: Builder) -> LookupTableData:
457
414
  if "disability_weight" in self._get_data_functions:
458
415
  disability_weight = self._get_data_functions["disability_weight"](
459
416
  builder, self.state_id
@@ -468,36 +425,15 @@ class DiseaseState(BaseDiseaseState):
468
425
 
469
426
  return disability_weight
470
427
 
471
- def get_base_disability_weight(
472
- self, builder: Builder, disability_weight_data: LookupTableData
473
- ) -> LookupTable:
474
- """
475
- Builds a LookupTable for the base disability weight of this state.
476
-
477
- Parameters
478
- ----------
479
- builder
480
- Interface to access simulation managers.
481
- disability_weight_data
482
- The data to use to build the LookupTable.
483
-
484
- Returns
485
- -------
486
- LookupTable
487
- The LookupTable for the disability weight of this state.
488
- """
489
- return builder.lookup.build_table(
490
- disability_weight_data, key_columns=["sex"], parameter_columns=["age", "year"]
491
- )
492
-
493
428
  def get_disability_weight_pipeline(self, builder: Builder) -> Pipeline:
429
+ lookup_columns = get_lookup_columns([self.lookup_tables["disability_weight"]])
494
430
  return builder.value.register_value_producer(
495
431
  f"{self.state_id}.disability_weight",
496
432
  source=self.compute_disability_weight,
497
- requires_columns=["age", "sex", "alive", self.model],
433
+ requires_columns=lookup_columns + ["alive", self.model],
498
434
  )
499
435
 
500
- def load_excess_mortality_rate_data(self, builder: Builder) -> LookupTableData:
436
+ def load_excess_mortality_rate(self, builder: Builder) -> LookupTableData:
501
437
  if "excess_mortality_rate" in self._get_data_functions:
502
438
  return self._get_data_functions["excess_mortality_rate"](builder, self.state_id)
503
439
  elif builder.data.load(f"cause.{self.model}.restrictions")["yld_only"]:
@@ -507,33 +443,12 @@ class DiseaseState(BaseDiseaseState):
507
443
  f"{self.cause_type}.{self.state_id}.excess_mortality_rate"
508
444
  )
509
445
 
510
- def get_base_excess_mortality_rate(
511
- self, builder: Builder, excess_mortality_data: LookupTableData
512
- ) -> LookupTable:
513
- """
514
- Builds a LookupTable for the base excess mortality rate of this state.
515
-
516
- Parameters
517
- ----------
518
- builder
519
- Interface to access simulation managers.
520
- excess_mortality_data
521
- The data to use to build the LookupTable.
522
-
523
- Returns
524
- -------
525
- LookupTable
526
- The LookupTable for the base excess mortality rate of this state.
527
- """
528
- return builder.lookup.build_table(
529
- excess_mortality_data, key_columns=["sex"], parameter_columns=["age", "year"]
530
- )
531
-
532
446
  def get_excess_mortality_rate_pipeline(self, builder: Builder) -> Pipeline:
447
+ lookup_columns = get_lookup_columns([self.lookup_tables["excess_mortality_rate"]])
533
448
  return builder.value.register_rate_producer(
534
449
  self.excess_mortality_rate_pipeline_name,
535
450
  source=self.compute_excess_mortality_rate,
536
- requires_columns=["age", "sex", "alive", self.model],
451
+ requires_columns=lookup_columns + ["alive", self.model],
537
452
  requires_values=[self.excess_mortality_rate_paf_pipeline_name],
538
453
  )
539
454
 
@@ -618,13 +533,15 @@ class DiseaseState(BaseDiseaseState):
618
533
  """
619
534
  disability_weight = pd.Series(0.0, index=index)
620
535
  with_condition = self.with_condition(index)
621
- disability_weight.loc[with_condition] = self.base_disability_weight(with_condition)
536
+ disability_weight.loc[with_condition] = self.lookup_tables["disability_weight"](
537
+ with_condition
538
+ )
622
539
  return disability_weight
623
540
 
624
541
  def compute_excess_mortality_rate(self, index: pd.Index) -> pd.Series:
625
542
  excess_mortality_rate = pd.Series(0.0, index=index)
626
543
  with_condition = self.with_condition(index)
627
- base_excess_mort = self.base_excess_mortality_rate(with_condition)
544
+ base_excess_mort = self.lookup_tables["excess_mortality_rate"](with_condition)
628
545
  joint_mediated_paf = self.joint_paf(with_condition)
629
546
  excess_mortality_rate.loc[with_condition] = base_excess_mort * (
630
547
  1 - joint_mediated_paf.values