vivarium-public-health 3.1.3__py3-none-any.whl → 3.1.4__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -1 +1 @@
1
- __version__ = "3.1.3"
1
+ __version__ = "3.1.4"
@@ -0,0 +1,5 @@
1
+ from vivarium.exceptions import VivariumError
2
+
3
+
4
+ class DiseaseModelError(VivariumError):
5
+ pass
@@ -10,24 +10,20 @@ transitions at simulation initialization and during transitions.
10
10
  """
11
11
  import warnings
12
12
  from collections.abc import Callable, Iterable
13
+ from functools import partial
13
14
  from typing import Any
14
15
 
15
- import numpy as np
16
16
  import pandas as pd
17
- from vivarium.exceptions import VivariumError
18
17
  from vivarium.framework.engine import Builder
19
- from vivarium.framework.event import Event
20
18
  from vivarium.framework.population import SimulantData
21
19
  from vivarium.framework.state_machine import Machine
20
+ from vivarium.types import DataInput, LookupTableData
22
21
 
22
+ from vivarium_public_health.disease.exceptions import DiseaseModelError
23
23
  from vivarium_public_health.disease.state import BaseDiseaseState, SusceptibleState
24
24
  from vivarium_public_health.disease.transition import TransitionString
25
25
 
26
26
 
27
- class DiseaseModelError(VivariumError):
28
- pass
29
-
30
-
31
27
  class DiseaseModel(Machine):
32
28
 
33
29
  ##############
@@ -82,57 +78,77 @@ class DiseaseModel(Machine):
82
78
  cause_type: str = "cause",
83
79
  states: Iterable[BaseDiseaseState] = (),
84
80
  residual_state: BaseDiseaseState | None = None,
85
- **kwargs,
86
- ):
87
- super().__init__(cause, states=states, **kwargs)
81
+ cause_specific_mortality_rate: DataInput | None = None,
82
+ ) -> None:
83
+ super().__init__(cause, states=states)
88
84
  self.cause = cause
89
85
  self.cause_type = cause_type
86
+ self.residual_state = self._get_residual_state(initial_state, residual_state)
87
+ self._csmr_source = cause_specific_mortality_rate
88
+ self._get_data_functions = (
89
+ get_data_functions if get_data_functions is not None else {}
90
+ )
90
91
 
91
- if initial_state is not None:
92
+ if get_data_functions is not None:
92
93
  warnings.warn(
93
- "In the future, the 'initial_state' argument to DiseaseModel"
94
- " will be used to initialize all simulants into that state. To"
95
- " retain the current behavior of defining a residual state, use"
96
- " the 'residual_state' argument.",
94
+ "The argument 'get_data_functions' has been deprecated. Use"
95
+ " cause_specific_mortality_rate instead.",
97
96
  DeprecationWarning,
98
97
  stacklevel=2,
99
98
  )
100
99
 
101
- if residual_state:
100
+ if cause_specific_mortality_rate is not None:
102
101
  raise DiseaseModelError(
103
- "A DiseaseModel cannot be initialized with both"
104
- " 'initial_state and 'residual_state'."
102
+ "It is not allowed to pass cause_specific_mortality_rate"
103
+ " both as a stand-alone argument and as part of"
104
+ " get_data_functions."
105
105
  )
106
106
 
107
- self.residual_state = initial_state.state_id
108
- elif residual_state is not None:
109
- self.residual_state = residual_state.state_id
110
- else:
111
- self.residual_state = self._get_default_residual_state()
112
-
113
- self._get_data_functions = (
114
- get_data_functions if get_data_functions is not None else {}
115
- )
116
-
117
107
  def setup(self, builder: Builder) -> None:
118
108
  """Perform this component's setup."""
119
109
  super().setup(builder)
120
110
 
121
111
  self.configuration_age_start = builder.configuration.population.initialization_age_min
122
112
  self.configuration_age_end = builder.configuration.population.initialization_age_max
113
+
123
114
  builder.value.register_value_modifier(
124
115
  "cause_specific_mortality_rate",
125
116
  self.adjust_cause_specific_mortality_rate,
126
117
  requires_columns=["age", "sex"],
127
118
  )
128
- self.randomness = builder.randomness.get_stream(f"{self.state_column}_initial_states")
119
+
120
+ def on_initialize_simulants(self, pop_data: SimulantData) -> None:
121
+ """Initialize the simulants in the population.
122
+
123
+ If all simulants are initialized at age 0, birth prevalence is used.
124
+ Otherwise, prevalence is used.
125
+
126
+ Parameters
127
+ ----------
128
+ pop_data
129
+ The population data object.
130
+ """
131
+ if pop_data.user_data.get("age_end", self.configuration_age_end) == 0:
132
+ initialization_table_name = "birth_prevalence"
133
+ else:
134
+ initialization_table_name = "prevalence"
135
+
136
+ for state in self.states:
137
+ state.lookup_tables["initialization_weights"] = state.lookup_tables[
138
+ initialization_table_name
139
+ ]
140
+
141
+ super().on_initialize_simulants(pop_data)
129
142
 
130
143
  #################
131
144
  # Setup methods #
132
145
  #################
133
146
 
134
147
  def load_cause_specific_mortality_rate(self, builder: Builder) -> float | pd.DataFrame:
135
- if "cause_specific_mortality_rate" not in self._get_data_functions:
148
+ if (
149
+ "cause_specific_mortality_rate" not in self._get_data_functions
150
+ and self._csmr_source is None
151
+ ):
136
152
  only_morbid = builder.data.load(f"cause.{self.cause}.restrictions")["yld_only"]
137
153
  if only_morbid:
138
154
  csmr_data = 0.0
@@ -140,64 +156,14 @@ class DiseaseModel(Machine):
140
156
  csmr_data = builder.data.load(
141
157
  f"{self.cause_type}.{self.cause}.cause_specific_mortality_rate"
142
158
  )
159
+ elif self._csmr_source is not None:
160
+ csmr_data = self.get_data(builder, self._csmr_source)
143
161
  else:
144
162
  csmr_data = self._get_data_functions["cause_specific_mortality_rate"](
145
163
  self.cause, builder
146
164
  )
147
165
  return csmr_data
148
166
 
149
- ########################
150
- # Event-driven methods #
151
- ########################
152
-
153
- def on_initialize_simulants(self, pop_data: SimulantData) -> None:
154
- population = self.population_view.subview(["age", "sex"]).get(pop_data.index)
155
-
156
- assert self.residual_state in {s.state_id for s in self.states}
157
-
158
- if pop_data.user_data["sim_state"] == "setup": # simulation start
159
- if self.configuration_age_start != self.configuration_age_end != 0:
160
- state_names, weights_bins = self.get_state_weights(
161
- pop_data.index, "prevalence"
162
- )
163
- else:
164
- raise NotImplementedError(
165
- "We do not currently support an age 0 cohort. "
166
- "configuration.population.initialization_age_min and "
167
- "configuration.population.initialization_age_max "
168
- "cannot both be 0."
169
- )
170
-
171
- else: # on time step
172
- if pop_data.user_data["age_start"] == pop_data.user_data["age_end"] == 0:
173
- state_names, weights_bins = self.get_state_weights(
174
- pop_data.index, "birth_prevalence"
175
- )
176
- else:
177
- state_names, weights_bins = self.get_state_weights(
178
- pop_data.index, "prevalence"
179
- )
180
-
181
- if state_names and not population.empty:
182
- # only do this if there are states in the model that supply prevalence data
183
- population["sex_id"] = population.sex.apply({"Male": 1, "Female": 2}.get)
184
-
185
- condition_column = self.assign_initial_status_to_simulants(
186
- population,
187
- state_names,
188
- weights_bins,
189
- self.randomness.get_draw(population.index),
190
- )
191
-
192
- condition_column = condition_column.rename(
193
- columns={"condition_state": self.state_column}
194
- )
195
- else:
196
- condition_column = pd.Series(
197
- self.residual_state, index=population.index, name=self.state_column
198
- )
199
- self.population_view.update(condition_column)
200
-
201
167
  ##################################
202
168
  # Pipeline sources and modifiers #
203
169
  ##################################
@@ -209,40 +175,78 @@ class DiseaseModel(Machine):
209
175
  # Helper functions #
210
176
  ####################
211
177
 
212
- def _get_default_residual_state(self):
213
- susceptible_states = [s for s in self.states if isinstance(s, SusceptibleState)]
214
- if len(susceptible_states) != 1:
215
- raise DiseaseModelError("Disease model must have exactly one SusceptibleState.")
216
- return susceptible_states[0].state_id
217
-
218
- def get_state_weights(
219
- self, pop_index: pd.Index, prevalence_type: str
220
- ) -> tuple[list[str], np.ndarray | None]:
221
- states = [state for state in self.states if state.lookup_tables.get(prevalence_type)]
222
-
223
- if not states:
224
- return states, None
178
+ def _get_residual_state(
179
+ self, initial_state: BaseDiseaseState, residual_state: BaseDiseaseState
180
+ ) -> BaseDiseaseState:
181
+ """Get the residual state for the DiseaseModel.
225
182
 
226
- weights = [state.lookup_tables.get(prevalence_type)(pop_index) for state in states]
227
- for w in weights:
228
- w.reset_index(inplace=True, drop=True)
229
- weights += ((1 - np.sum(weights, axis=0)),)
183
+ This will be the residual state if it is provided, otherwise it will be
184
+ the model's SusceptibleState. This method also calculates the residual
185
+ state's birth_prevalence and prevalence.
186
+ """
187
+ if initial_state is not None:
188
+ warnings.warn(
189
+ "In the future, the 'initial_state' argument to DiseaseModel"
190
+ " will be used to initialize all simulants into that state. To"
191
+ " retain the current behavior of defining a residual state, use"
192
+ " the 'residual_state' argument.",
193
+ DeprecationWarning,
194
+ stacklevel=2,
195
+ )
230
196
 
231
- weights = np.array(weights).T
232
- weights_bins = np.cumsum(weights, axis=1)
197
+ if residual_state:
198
+ raise DiseaseModelError(
199
+ "A DiseaseModel cannot be initialized with both"
200
+ " 'initial_state and 'residual_state'."
201
+ )
233
202
 
234
- state_names = [s.state_id for s in states] + [self.residual_state]
203
+ residual_state = initial_state
204
+ elif residual_state is None:
205
+ susceptible_states = [s for s in self.states if isinstance(s, SusceptibleState)]
206
+ if len(susceptible_states) != 1:
207
+ raise DiseaseModelError(
208
+ "Disease model must have exactly one SusceptibleState."
209
+ )
210
+ residual_state = susceptible_states[0]
235
211
 
236
- return state_names, weights_bins
212
+ if residual_state not in self.states:
213
+ raise DiseaseModelError(
214
+ f"Residual state '{self.residual_state}' must be one of the"
215
+ f" states: {self.states}."
216
+ )
237
217
 
238
- @staticmethod
239
- def assign_initial_status_to_simulants(
240
- simulants_df, state_names, weights_bins, propensities
241
- ):
242
- simulants = simulants_df[["age", "sex"]].copy()
218
+ residual_state.birth_prevalence = partial(
219
+ self._get_residual_state_probabilities, table_name="birth_prevalence"
220
+ )
221
+ residual_state.prevalence = partial(
222
+ self._get_residual_state_probabilities, table_name="prevalence"
223
+ )
243
224
 
244
- choice_index = (propensities.values[np.newaxis].T > weights_bins).sum(axis=1)
245
- initial_states = pd.Series(np.array(state_names)[choice_index], index=simulants.index)
225
+ return residual_state
226
+
227
+ def _get_residual_state_probabilities(
228
+ self, builder: Builder, table_name: str
229
+ ) -> LookupTableData:
230
+ """Calculate the probabilities of the residual state based on the other states."""
231
+ non_residual_states = [s for s in self.states if s != self.residual_state]
232
+ non_residual_probabilities = 0
233
+ for state in non_residual_states:
234
+ weights_source = builder.configuration[state.name].data_sources[table_name]
235
+ weights = state.get_data(builder, weights_source)
236
+ if isinstance(weights, pd.DataFrame):
237
+ weights = weights.set_index(
238
+ [c for c in weights.columns if c != "value"]
239
+ ).squeeze()
240
+ non_residual_probabilities += weights
241
+
242
+ residual_probabilities = 1 - non_residual_probabilities
243
+
244
+ if pd.Series(residual_probabilities < 0).any():
245
+ raise ValueError(
246
+ f"The {table_name} for the states in the DiseaseModel must sum"
247
+ " to less than 1."
248
+ )
249
+ if isinstance(residual_probabilities, pd.Series):
250
+ residual_probabilities = residual_probabilities.reset_index()
246
251
 
247
- simulants.loc[:, "condition_state"] = initial_states
248
- return simulants
252
+ return residual_probabilities
@@ -6,19 +6,20 @@ Disease States
6
6
  This module contains tools to manage standard disease states.
7
7
 
8
8
  """
9
-
9
+ import warnings
10
10
  from collections.abc import Callable
11
11
  from typing import Any
12
12
 
13
13
  import numpy as np
14
14
  import pandas as pd
15
15
  from vivarium.framework.engine import Builder
16
- from vivarium.framework.lookup import LookupTableData
17
16
  from vivarium.framework.population import PopulationView, SimulantData
18
17
  from vivarium.framework.randomness import RandomnessStream
19
18
  from vivarium.framework.state_machine import State, Transient, Transition, Trigger
20
19
  from vivarium.framework.values import Pipeline, list_combiner, union_post_processor
20
+ from vivarium.types import DataInput, LookupTableData
21
21
 
22
+ from vivarium_public_health.disease.exceptions import DiseaseModelError
22
23
  from vivarium_public_health.disease.transition import (
23
24
  ProportionTransition,
24
25
  RateTransition,
@@ -28,10 +29,25 @@ from vivarium_public_health.utilities import get_lookup_columns, is_non_zero
28
29
 
29
30
 
30
31
  class BaseDiseaseState(State):
32
+
31
33
  ##############
32
34
  # Properties #
33
35
  ##############
34
36
 
37
+ @property
38
+ def configuration_defaults(self) -> dict[str, Any]:
39
+ configuration_defaults = super().configuration_defaults
40
+ additional_defaults = {
41
+ "prevalence": self.prevalence,
42
+ "birth_prevalence": self.birth_prevalence,
43
+ }
44
+ data_sources = {
45
+ **configuration_defaults[self.name]["data_sources"],
46
+ **additional_defaults,
47
+ }
48
+ configuration_defaults[self.name]["data_sources"] = data_sources
49
+ return configuration_defaults
50
+
35
51
  @property
36
52
  def columns_created(self):
37
53
  return [self.event_time_column, self.event_count_column]
@@ -59,13 +75,15 @@ class BaseDiseaseState(State):
59
75
  side_effect_function: Callable | None = None,
60
76
  cause_type: str = "cause",
61
77
  ):
62
- super().__init__(state_id, allow_self_transition) # becomes state_id
78
+ super().__init__(state_id, allow_self_transition)
63
79
  self.cause_type = cause_type
64
80
 
65
81
  self.side_effect_function = side_effect_function
66
82
 
67
83
  self.event_time_column = self.state_id + "_event_time"
68
84
  self.event_count_column = self.state_id + "_event_count"
85
+ self.prevalence = 0.0
86
+ self.birth_prevalence = 0.0
69
87
 
70
88
  ########################
71
89
  # Event-driven methods #
@@ -87,10 +105,7 @@ class BaseDiseaseState(State):
87
105
  def get_initialization_parameters(self) -> dict[str, Any]:
88
106
  """Exclude side effect function and cause type from name and __repr__."""
89
107
  initialization_parameters = super().get_initialization_parameters()
90
- for key in ["side_effect_function", "cause_type"]:
91
- if key in initialization_parameters.keys():
92
- del initialization_parameters[key]
93
- return initialization_parameters
108
+ return {"state_id": initialization_parameters["state_id"]}
94
109
 
95
110
  def get_initial_event_times(self, pop_data: SimulantData) -> pd.DataFrame:
96
111
  return pd.DataFrame(
@@ -132,6 +147,8 @@ class BaseDiseaseState(State):
132
147
  output: "BaseDiseaseState",
133
148
  get_data_functions: dict[str, Callable] = None,
134
149
  triggered=Trigger.NOT_TRIGGERED,
150
+ transition_rate: DataInput | None = None,
151
+ rate_type: str = "transition_rate",
135
152
  ) -> RateTransition:
136
153
  """Builds a RateTransition from this state to the given state.
137
154
 
@@ -139,18 +156,29 @@ class BaseDiseaseState(State):
139
156
  ----------
140
157
  output
141
158
  The end state after the transition.
142
-
143
159
  get_data_functions
144
160
  Map from transition type to the function to pull that transition's data.
145
161
  triggered
146
162
  The trigger for the transition
147
-
163
+ transition_rate
164
+ The transition rate source. Can be the data itself, a function to
165
+ retrieve the data, or the artifact key containing the data.
166
+ rate_type
167
+ The type of rate. Can be "incidence_rate", "transition_rate", or
168
+ "remission_rate".
148
169
 
149
170
  Returns
150
171
  -------
151
172
  The created transition object.
152
173
  """
153
- transition = RateTransition(self, output, get_data_functions, triggered)
174
+ transition = RateTransition(
175
+ input_state=self,
176
+ output_state=output,
177
+ get_data_functions=get_data_functions,
178
+ triggered=triggered,
179
+ transition_rate=transition_rate,
180
+ rate_type=rate_type,
181
+ )
154
182
  self.add_transition(transition)
155
183
  return transition
156
184
 
@@ -159,6 +187,7 @@ class BaseDiseaseState(State):
159
187
  output: "BaseDiseaseState",
160
188
  get_data_functions: dict[str, Callable] | None = None,
161
189
  triggered=Trigger.NOT_TRIGGERED,
190
+ proportion: DataInput | None = None,
162
191
  ) -> ProportionTransition:
163
192
  """Builds a ProportionTransition from this state to the given state.
164
193
 
@@ -170,15 +199,26 @@ class BaseDiseaseState(State):
170
199
  Map from transition type to the function to pull that transition's data.
171
200
  triggered
172
201
  The trigger for the transition.
202
+ proportion
203
+ The proportion source. Can be the data itself, a function to
204
+ retrieve the data, or the artifact key containing the data.
173
205
 
174
206
  Returns
175
207
  -------
176
208
  The created transition object.
177
209
  """
178
- if "proportion" not in get_data_functions:
210
+ if (
211
+ get_data_functions is None or "proportion" not in get_data_functions
212
+ ) and proportion is None:
179
213
  raise ValueError("You must supply a proportion function.")
180
214
 
181
- transition = ProportionTransition(self, output, get_data_functions, triggered)
215
+ transition = ProportionTransition(
216
+ input_state=self,
217
+ output_state=output,
218
+ get_data_functions=get_data_functions,
219
+ triggered=triggered,
220
+ proportion=proportion,
221
+ )
182
222
  self.add_transition(transition)
183
223
  return transition
184
224
 
@@ -220,17 +260,19 @@ class NonDiseasedState(BaseDiseaseState):
220
260
  self,
221
261
  output: BaseDiseaseState,
222
262
  get_data_functions: dict[str, Callable] = None,
223
- **kwargs,
263
+ triggered=Trigger.NOT_TRIGGERED,
264
+ transition_rate: DataInput | None = None,
265
+ **_kwargs,
224
266
  ) -> RateTransition:
225
- if get_data_functions is None:
226
- get_data_functions = {
227
- "incidence_rate": lambda builder, cause: builder.data.load(
228
- f"{self.cause_type}.{cause}.incidence_rate"
229
- )
230
- }
231
- elif "incidence_rate" not in get_data_functions:
232
- raise ValueError("You must supply an incidence rate function.")
233
- return super().add_rate_transition(output, get_data_functions, **kwargs)
267
+ if get_data_functions is None and transition_rate is None:
268
+ transition_rate = f"{self.cause_type}.{output.state_id}.incidence_rate"
269
+ return super().add_rate_transition(
270
+ output=output,
271
+ get_data_functions=get_data_functions,
272
+ triggered=triggered,
273
+ transition_rate=transition_rate,
274
+ rate_type="incidence_rate",
275
+ )
234
276
 
235
277
 
236
278
  class SusceptibleState(NonDiseasedState):
@@ -253,6 +295,13 @@ class SusceptibleState(NonDiseasedState):
253
295
  name_prefix="susceptible_to_",
254
296
  )
255
297
 
298
+ ##################
299
+ # Public methods #
300
+ ##################
301
+
302
+ def has_initialization_weights(self) -> bool:
303
+ return True
304
+
256
305
 
257
306
  class RecoveredState(NonDiseasedState):
258
307
  def __init__(
@@ -280,17 +329,20 @@ class DiseaseState(BaseDiseaseState):
280
329
 
281
330
  @property
282
331
  def configuration_defaults(self) -> dict[str, Any]:
283
- return {
284
- f"{self.name}": {
285
- "data_sources": {
286
- "prevalence": self.load_prevalence,
287
- "birth_prevalence": self.load_birth_prevalence,
288
- "dwell_time": self.load_dwell_time,
289
- "disability_weight": self.load_disability_weight,
290
- "excess_mortality_rate": self.load_excess_mortality_rate,
291
- },
292
- },
332
+ configuration_defaults = super().configuration_defaults
333
+ additional_defaults = {
334
+ "prevalence": self._prevalence_source,
335
+ "birth_prevalence": self._birth_prevalence_source,
336
+ "dwell_time": self._dwell_time_source,
337
+ "disability_weight": self._disability_weight_source,
338
+ "excess_mortality_rate": self._excess_modality_rate_source,
293
339
  }
340
+ data_sources = {
341
+ **configuration_defaults[self.name]["data_sources"],
342
+ **additional_defaults,
343
+ }
344
+ configuration_defaults[self.name]["data_sources"] = data_sources
345
+ return configuration_defaults
294
346
 
295
347
  #####################
296
348
  # Lifecycle methods #
@@ -304,6 +356,11 @@ class DiseaseState(BaseDiseaseState):
304
356
  cause_type: str = "cause",
305
357
  get_data_functions: dict[str, Callable] | None = None,
306
358
  cleanup_function: Callable | None = None,
359
+ prevalence: DataInput | None = None,
360
+ birth_prevalence: DataInput | None = None,
361
+ dwell_time: DataInput | None = None,
362
+ disability_weight: DataInput | None = None,
363
+ excess_mortality_rate: DataInput | None = None,
307
364
  ):
308
365
  """
309
366
  Parameters
@@ -322,7 +379,31 @@ class DiseaseState(BaseDiseaseState):
322
379
  various state attributes.
323
380
  cleanup_function
324
381
  The cleanup function.
382
+ prevalence
383
+ The prevalence source. This is used to initialize simulants. Can be
384
+ the data itself, a function to retrieve the data, or the artifact
385
+ key containing the data.
386
+ birth_prevalence
387
+ The birth prevalence source. This is used to initialize newborn
388
+ simulants. Can be the data itself, a function to retrieve the data,
389
+ or the artifact key containing the data.
390
+ dwell_time
391
+ The dwell time source. This is used to determine how long a simulant
392
+ must remain in the state before transitioning. Can be the data
393
+ itself, a function to retrieve the data, or the artifact key
394
+ containing the data.
395
+ disability_weight
396
+ The disability weight source. This is used to calculate the
397
+ disability weight for simulants in this state. Can be the data
398
+ itself, a function to retrieve the data, or the artifact key
399
+ containing the data.
400
+ excess_mortality_rate
401
+ The excess mortality rate source. This is used to calculate the
402
+ excess mortality rate for simulants in this state. Can be the data
403
+ itself, a function to retrieve the data, or the artifact key
404
+ containing the data.
325
405
  """
406
+
326
407
  super().__init__(
327
408
  state_id,
328
409
  allow_self_transition=allow_self_transition,
@@ -340,6 +421,34 @@ class DiseaseState(BaseDiseaseState):
340
421
  )
341
422
  self._cleanup_function = cleanup_function
342
423
 
424
+ if get_data_functions is not None:
425
+ warnings.warn(
426
+ "The argument 'get_data_functions' has been deprecated. Use"
427
+ " cause_specific_mortality_rate instead.",
428
+ DeprecationWarning,
429
+ stacklevel=2,
430
+ )
431
+
432
+ for data_type in self._get_data_functions:
433
+ try:
434
+ data_source = locals()[data_type]
435
+ except KeyError:
436
+ data_source = None
437
+
438
+ if locals()[data_type] is not None:
439
+ raise DiseaseModelError(
440
+ f"It is not allowed to pass '{data_type}' both as a"
441
+ " stand-alone argument and as part of get_data_functions."
442
+ )
443
+
444
+ self._prevalence_source = self.get_prevalence_source(prevalence)
445
+ self._birth_prevalence_source = self.get_birth_prevalence_source(birth_prevalence)
446
+ self._dwell_time_source = self.get_dwell_time_source(dwell_time)
447
+ self._disability_weight_source = self.get_disability_weight_source(disability_weight)
448
+ self._excess_modality_rate_source = self.get_excess_mortality_rate_source(
449
+ excess_mortality_rate
450
+ )
451
+
343
452
  # noinspection PyAttributeOutsideInit
344
453
  def setup(self, builder: Builder) -> None:
345
454
  """Performs this component's simulation setup.
@@ -377,32 +486,45 @@ class DiseaseState(BaseDiseaseState):
377
486
  # Setup methods #
378
487
  #################
379
488
 
380
- def load_prevalence(self, builder: Builder) -> LookupTableData:
489
+ def _get_data_functions_source(self, data_type: str) -> DataInput:
490
+ def data_source(builder: Builder) -> LookupTableData:
491
+ return self._get_data_functions[data_type](builder, self.state_id)
492
+
493
+ return data_source
494
+
495
+ def get_prevalence_source(self, prevalence: DataInput | None) -> DataInput:
381
496
  if "prevalence" in self._get_data_functions:
382
- return self._get_data_functions["prevalence"](builder, self.state_id)
497
+ return self._get_data_functions_source("prevalence")
498
+ elif prevalence is not None:
499
+ return prevalence
383
500
  else:
384
- return builder.data.load(f"{self.cause_type}.{self.state_id}.prevalence")
501
+ return f"{self.cause_type}.{self.state_id}.prevalence"
385
502
 
386
- def load_birth_prevalence(self, builder: Builder) -> LookupTableData:
503
+ def get_birth_prevalence_source(self, birth_prevalence: DataInput | None) -> DataInput:
387
504
  if "birth_prevalence" in self._get_data_functions:
388
- return self._get_data_functions["birth_prevalence"](builder, self.state_id)
505
+ return self._get_data_functions_source("birth_prevalence")
506
+ elif birth_prevalence is not None:
507
+ return birth_prevalence
389
508
  else:
390
- return 0
509
+ return 0.0
391
510
 
392
- def load_dwell_time(self, builder: Builder) -> LookupTableData:
511
+ def get_dwell_time_source(self, dwell_time: DataInput | None) -> DataInput:
393
512
  if "dwell_time" in self._get_data_functions:
394
- dwell_time = self._get_data_functions["dwell_time"](builder, self.state_id)
395
- else:
396
- dwell_time = 0
397
-
398
- if isinstance(dwell_time, pd.Timedelta):
399
- dwell_time = dwell_time.total_seconds() / (60 * 60 * 24)
400
- if (
401
- isinstance(dwell_time, pd.DataFrame) and np.any(dwell_time.value != 0)
402
- ) or dwell_time > 0:
403
- self.transition_set.allow_null_transition = True
404
-
405
- return dwell_time
513
+ dwell_time = self._get_data_functions_source("dwell_time")
514
+ elif dwell_time is None:
515
+ dwell_time = 0.0
516
+
517
+ def dwell_time_source(builder: Builder) -> LookupTableData:
518
+ dwell_time_ = self.get_data(builder, dwell_time)
519
+ if isinstance(dwell_time_, pd.Timedelta):
520
+ dwell_time_ = dwell_time_.total_seconds() / (60 * 60 * 24)
521
+ if (
522
+ isinstance(dwell_time_, pd.DataFrame) and np.any(dwell_time_.value != 0)
523
+ ) or dwell_time_ > 0:
524
+ self.transition_set.allow_null_transition = True
525
+ return dwell_time_
526
+
527
+ return dwell_time_source
406
528
 
407
529
  def get_dwell_time_pipeline(self, builder: Builder) -> Pipeline:
408
530
  required_columns = get_lookup_columns([self.lookup_tables["dwell_time"]])
@@ -412,20 +534,22 @@ class DiseaseState(BaseDiseaseState):
412
534
  requires_columns=required_columns,
413
535
  )
414
536
 
415
- def load_disability_weight(self, builder: Builder) -> LookupTableData:
537
+ def get_disability_weight_source(self, disability_weight: DataInput | None) -> DataInput:
416
538
  if "disability_weight" in self._get_data_functions:
417
- disability_weight = self._get_data_functions["disability_weight"](
418
- builder, self.state_id
419
- )
539
+ disability_weight = self._get_data_functions_source("disability_weight")
540
+ elif disability_weight is not None:
541
+ disability_weight = disability_weight
420
542
  else:
421
- disability_weight = builder.data.load(
422
- f"{self.cause_type}.{self.state_id}.disability_weight"
423
- )
543
+ disability_weight = f"{self.cause_type}.{self.state_id}.disability_weight"
424
544
 
425
- if isinstance(disability_weight, pd.DataFrame) and len(disability_weight) == 1:
426
- disability_weight = disability_weight.value[0] # sequela only have single value
545
+ def disability_weight_source(builder: Builder) -> LookupTableData:
546
+ disability_weight_ = self.get_data(builder, disability_weight)
547
+ if isinstance(disability_weight_, pd.DataFrame) and len(disability_weight_) == 1:
548
+ # sequela only have single value
549
+ disability_weight_ = disability_weight_.value[0]
550
+ return disability_weight_
427
551
 
428
- return disability_weight
552
+ return disability_weight_source
429
553
 
430
554
  def get_disability_weight_pipeline(self, builder: Builder) -> Pipeline:
431
555
  lookup_columns = get_lookup_columns([self.lookup_tables["disability_weight"]])
@@ -435,16 +559,25 @@ class DiseaseState(BaseDiseaseState):
435
559
  requires_columns=lookup_columns + ["alive", self.model],
436
560
  )
437
561
 
438
- def load_excess_mortality_rate(self, builder: Builder) -> LookupTableData:
562
+ def get_excess_mortality_rate_source(
563
+ self, excess_mortality_rate: DataInput | None
564
+ ) -> DataInput:
439
565
  if "excess_mortality_rate" in self._get_data_functions:
440
- return self._get_data_functions["excess_mortality_rate"](builder, self.state_id)
441
- elif builder.data.load(f"cause.{self.model}.restrictions")["yld_only"]:
442
- return 0
443
- else:
566
+ excess_mortality_rate = self._get_data_functions_source("excess_mortality_rate")
567
+ elif excess_mortality_rate is None:
568
+ excess_mortality_rate = f"{self.cause_type}.{self.state_id}.excess_mortality_rate"
569
+
570
+ def excess_mortality_rate_source(builder: Builder) -> LookupTableData:
571
+ if excess_mortality_rate is not None:
572
+ return self.get_data(builder, excess_mortality_rate)
573
+ elif builder.data.load(f"cause.{self.model}.restrictions")["yld_only"]:
574
+ return 0
444
575
  return builder.data.load(
445
576
  f"{self.cause_type}.{self.state_id}.excess_mortality_rate"
446
577
  )
447
578
 
579
+ return excess_mortality_rate_source
580
+
448
581
  def get_excess_mortality_rate_pipeline(self, builder: Builder) -> Pipeline:
449
582
  lookup_columns = get_lookup_columns([self.lookup_tables["excess_mortality_rate"]])
450
583
  return builder.value.register_rate_producer(
@@ -470,24 +603,27 @@ class DiseaseState(BaseDiseaseState):
470
603
  # Public methods #
471
604
  ##################
472
605
 
606
+ def has_initialization_weights(self) -> bool:
607
+ return True
608
+
473
609
  def add_rate_transition(
474
610
  self,
475
611
  output: BaseDiseaseState,
476
612
  get_data_functions: dict[str, Callable] = None,
477
- **kwargs,
613
+ triggered=Trigger.NOT_TRIGGERED,
614
+ transition_rate: DataInput | None = None,
615
+ rate_type: str = "transition_rate",
478
616
  ) -> RateTransition:
479
- if get_data_functions is None:
480
- get_data_functions = {
481
- "remission_rate": lambda builder, cause: builder.data.load(
482
- f"{self.cause_type}.{cause}.remission_rate"
483
- )
484
- }
485
- elif (
486
- "remission_rate" not in get_data_functions
487
- and "transition_rate" not in get_data_functions
488
- ):
489
- raise ValueError("You must supply a transition rate or remission rate function.")
490
- return super().add_rate_transition(output, get_data_functions, **kwargs)
617
+ if get_data_functions is None and transition_rate is None:
618
+ transition_rate = f"{self.cause_type}.{self.state_id}.remission_rate"
619
+ rate_type = "remission_rate"
620
+ return super().add_rate_transition(
621
+ output=output,
622
+ get_data_functions=get_data_functions,
623
+ triggered=triggered,
624
+ transition_rate=transition_rate,
625
+ rate_type=rate_type,
626
+ )
491
627
 
492
628
  def add_dwell_time_transition(
493
629
  self,
@@ -6,7 +6,7 @@ Disease Transitions
6
6
  This module contains tools to model transitions between disease states.
7
7
 
8
8
  """
9
-
9
+ import warnings
10
10
  from collections.abc import Callable
11
11
  from typing import TYPE_CHECKING, Any
12
12
 
@@ -15,7 +15,9 @@ from vivarium.framework.engine import Builder
15
15
  from vivarium.framework.state_machine import Transition, Trigger
16
16
  from vivarium.framework.utilities import rate_to_probability
17
17
  from vivarium.framework.values import list_combiner, union_post_processor
18
+ from vivarium.types import DataInput
18
19
 
20
+ from vivarium_public_health.disease.exceptions import DiseaseModelError
19
21
  from vivarium_public_health.utilities import get_lookup_columns
20
22
 
21
23
  if TYPE_CHECKING:
@@ -44,7 +46,7 @@ class RateTransition(Transition):
44
46
  return {
45
47
  f"{self.name}": {
46
48
  "data_sources": {
47
- "transition_rate": self.load_transition_rate,
49
+ "transition_rate": self._rate_source,
48
50
  },
49
51
  },
50
52
  }
@@ -55,19 +57,37 @@ class RateTransition(Transition):
55
57
 
56
58
  @property
57
59
  def transition_rate_pipeline_name(self) -> str:
58
- if "incidence_rate" in self._get_data_functions:
59
- pipeline_name = f"{self.output_state.state_id}.incidence_rate"
60
- elif "remission_rate" in self._get_data_functions:
61
- pipeline_name = f"{self.input_state.state_id}.remission_rate"
62
- elif "transition_rate" in self._get_data_functions:
63
- pipeline_name = (
64
- f"{self.input_state.state_id}_to_{self.output_state.state_id}.transition_rate"
65
- )
60
+ if self._get_data_functions:
61
+ if "incidence_rate" in self._get_data_functions:
62
+ pipeline_name = f"{self.output_state.state_id}.incidence_rate"
63
+ elif "remission_rate" in self._get_data_functions:
64
+ pipeline_name = f"{self.input_state.state_id}.remission_rate"
65
+ elif "transition_rate" in self._get_data_functions:
66
+ pipeline_name = (
67
+ f"{self.input_state.state_id}_to_{self.output_state.state_id}"
68
+ ".transition_rate"
69
+ )
70
+ else:
71
+ raise DiseaseModelError(
72
+ "Cannot determine rate_transition pipeline name: "
73
+ "no valid data functions supplied."
74
+ )
66
75
  else:
67
- raise ValueError(
68
- "Cannot determine rate_transition pipeline name: "
69
- "no valid data functions supplied."
70
- )
76
+ if self.rate_type == "incidence_rate":
77
+ pipeline_name = f"{self.output_state.state_id}.incidence_rate"
78
+ elif self.rate_type == "remission_rate":
79
+ pipeline_name = f"{self.input_state.state_id}.remission_rate"
80
+ elif self.rate_type == "transition_rate":
81
+ pipeline_name = (
82
+ f"{self.input_state.state_id}_to_{self.output_state.state_id}"
83
+ ".transition_rate"
84
+ )
85
+ else:
86
+ raise DiseaseModelError(
87
+ "Cannot determine rate_transition pipeline name: invalid"
88
+ f" rate_type '{self.rate_type} supplied."
89
+ )
90
+
71
91
  return pipeline_name
72
92
 
73
93
  #####################
@@ -80,6 +100,8 @@ class RateTransition(Transition):
80
100
  output_state: "BaseDiseaseState",
81
101
  get_data_functions: dict[str, Callable] = None,
82
102
  triggered=Trigger.NOT_TRIGGERED,
103
+ transition_rate: DataInput | None = None,
104
+ rate_type: str = "transition_rate",
83
105
  ):
84
106
  super().__init__(
85
107
  input_state, output_state, probability_func=self._probability, triggered=triggered
@@ -87,6 +109,22 @@ class RateTransition(Transition):
87
109
  self._get_data_functions = (
88
110
  get_data_functions if get_data_functions is not None else {}
89
111
  )
112
+ self._rate_source = self._get_rate_source(transition_rate)
113
+ self.rate_type = rate_type
114
+
115
+ if get_data_functions is not None:
116
+ warnings.warn(
117
+ "The argument 'get_data_functions' has been deprecated. Use"
118
+ " 'transition_rate' instead.",
119
+ DeprecationWarning,
120
+ stacklevel=2,
121
+ )
122
+ if transition_rate is not None:
123
+ raise DiseaseModelError(
124
+ "It is not allowed to pass a transition rate"
125
+ " both as a stand-alone argument and as part of"
126
+ " get_data_functions."
127
+ )
90
128
 
91
129
  # noinspection PyAttributeOutsideInit
92
130
  def setup(self, builder: Builder) -> None:
@@ -109,17 +147,19 @@ class RateTransition(Transition):
109
147
  # Setup methods #
110
148
  #################
111
149
 
112
- def load_transition_rate(self, builder: Builder) -> float | pd.DataFrame:
113
- if "incidence_rate" in self._get_data_functions:
114
- rate_data = self._get_data_functions["incidence_rate"](
150
+ def _get_rate_source(self, transition_rate: DataInput | None) -> DataInput:
151
+ if transition_rate is not None:
152
+ rate_data = transition_rate
153
+ elif "incidence_rate" in self._get_data_functions:
154
+ rate_data = lambda builder: self._get_data_functions["incidence_rate"](
115
155
  builder, self.output_state.state_id
116
156
  )
117
157
  elif "remission_rate" in self._get_data_functions:
118
- rate_data = self._get_data_functions["remission_rate"](
158
+ rate_data = lambda builder: self._get_data_functions["remission_rate"](
119
159
  builder, self.input_state.state_id
120
160
  )
121
161
  elif "transition_rate" in self._get_data_functions:
122
- rate_data = self._get_data_functions["transition_rate"](
162
+ rate_data = lambda builder: self._get_data_functions["transition_rate"](
123
163
  builder, self.input_state.state_id, self.output_state.state_id
124
164
  )
125
165
  else:
@@ -172,21 +212,38 @@ class ProportionTransition(Transition):
172
212
  output_state: "BaseDiseaseState",
173
213
  get_data_functions: dict[str, Callable] = None,
174
214
  triggered=Trigger.NOT_TRIGGERED,
215
+ proportion: DataInput | None = None,
175
216
  ):
176
217
  super().__init__(
177
218
  input_state, output_state, probability_func=self._probability, triggered=triggered
178
219
  )
220
+ self._proportion_source = proportion
179
221
  self._get_data_functions = (
180
222
  get_data_functions if get_data_functions is not None else {}
181
223
  )
182
224
 
225
+ if get_data_functions is not None:
226
+ warnings.warn(
227
+ "The argument 'get_data_functions' has been deprecated. Use"
228
+ " 'proportion' instead.",
229
+ DeprecationWarning,
230
+ stacklevel=2,
231
+ )
232
+ if proportion is not None:
233
+ raise DiseaseModelError(
234
+ "It is not allowed to pass a proportion both as a"
235
+ " stand-alone argument and as part of get_data_functions."
236
+ )
237
+
183
238
  #################
184
239
  # Setup methods #
185
240
  #################
186
241
 
187
- def load_proportion(self, builder: Builder) -> float | pd.DataFrame:
242
+ def load_proportion(self, builder: Builder) -> DataInput:
243
+ if self._proportion_source is not None:
244
+ return self._proportion_source
188
245
  if "proportion" not in self._get_data_functions:
189
- raise ValueError("Must supply a proportion function")
246
+ raise DiseaseModelError("Must supply a proportion function")
190
247
  return self._get_data_functions["proportion"](builder, self.output_state.state_id)
191
248
 
192
249
  def _probability(self, index):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: vivarium_public_health
3
- Version: 3.1.3
3
+ Version: 3.1.4
4
4
  Summary: Components for modelling diseases, risks, and interventions with ``vivarium``
5
5
  Home-page: https://github.com/ihmeuw/vivarium_public_health
6
6
  Author: The vivarium developers
@@ -26,7 +26,7 @@ Classifier: Topic :: Scientific/Engineering :: Medical Science Apps.
26
26
  Classifier: Topic :: Scientific/Engineering :: Physics
27
27
  Classifier: Topic :: Software Development :: Libraries
28
28
  License-File: LICENSE.txt
29
- Requires-Dist: vivarium>=3.2.0
29
+ Requires-Dist: vivarium>=3.2.3
30
30
  Requires-Dist: layered_config_tree>=1.0.1
31
31
  Requires-Dist: loguru
32
32
  Requires-Dist: numpy<2.0.0
@@ -1,13 +1,14 @@
1
1
  vivarium_public_health/__about__.py,sha256=RgWycPypKZS80TpSX7o41cREnG8PfguNHDHLuLyl820,487
2
2
  vivarium_public_health/__init__.py,sha256=GDeeP-7OlCBwPuv_xQoB1wNmvCaFsqfTB7qnnYApm0w,1343
3
- vivarium_public_health/_version.py,sha256=xoSqNkNOCK0xzFXnsO80epHc1vbiRC8Nbn4Cy-2wBX4,22
3
+ vivarium_public_health/_version.py,sha256=2uixSkocUHf2KiY1oTfzz_5AQGmlrHnypVxGgr4mV9c,22
4
4
  vivarium_public_health/utilities.py,sha256=QNXQ6fhAr1HcV-GwKw7wQLz6QyuNxqNvMA-XujKjTgs,3035
5
5
  vivarium_public_health/disease/__init__.py,sha256=VUJHDLlE6ngo2qHNQUtZ8OWH5H_T7_ao-xsYKDkRmHw,443
6
- vivarium_public_health/disease/model.py,sha256=ZwhhQCc8jj_QeJZO2zLtp_yWzqRxvLjuzW7iDUmmBGA,8852
6
+ vivarium_public_health/disease/exceptions.py,sha256=vb30IIV82OiDf2cNZCs_E2rF6mdDDHbnZSND60no5CU,97
7
+ vivarium_public_health/disease/model.py,sha256=qA0mhbIL0rfPtJ1Csmvxvf757vUdrjcbpxYmL_dsniI,9175
7
8
  vivarium_public_health/disease/models.py,sha256=01UK7yB2zGPFzmlIpvhd-XnGe6vSCMDza3QTidgY7Nc,3479
8
9
  vivarium_public_health/disease/special_disease.py,sha256=kTVuE5rQjUK62ysComG8nB2f61aCKdca9trRB1zsDCQ,14537
9
- vivarium_public_health/disease/state.py,sha256=NwTnxB_i05magsJuEc1_Z_0xMMKsqSPXIpibUdxPrWY,22333
10
- vivarium_public_health/disease/transition.py,sha256=4g_F8L3Godb4yQjHRr42n95QXo6m0ldI6c8_vu6cTfo,6429
10
+ vivarium_public_health/disease/state.py,sha256=l1iBZl8IpNOegQjZIY2xq09SBhsn5jSoYD89CymJAUs,28309
11
+ vivarium_public_health/disease/transition.py,sha256=qRpHzq29S_I_KzyDncAahg27bnHNdrA5tsUWuJ4Zty4,8964
11
12
  vivarium_public_health/mslt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
12
13
  vivarium_public_health/mslt/delay.py,sha256=TrFp9nmv-PSu6xdQXfyeM4X8RrXTLUnQMxOuhdvD-Wk,22787
13
14
  vivarium_public_health/mslt/disease.py,sha256=DONSItBiOUk1qBE6Msw0vrV0XLW4BPzWVxwVK90GK1I,16638
@@ -42,8 +43,8 @@ vivarium_public_health/treatment/__init__.py,sha256=wONElu9aJbBYwpYIovYPYaN_GYfV
42
43
  vivarium_public_health/treatment/magic_wand.py,sha256=zGIhrNgB9q6JD7fHlvbDQb3H5e_N_QsROO4Y0kl_JQM,1955
43
44
  vivarium_public_health/treatment/scale_up.py,sha256=hVz0ELXDqlpcExI31rKdepxqcW_hy2hZSa6qCzv6udU,7020
44
45
  vivarium_public_health/treatment/therapeutic_inertia.py,sha256=8Z97s7GfcpfLu1U1ESJSqeEk4L__a3M0GbBV21MFg2s,2346
45
- vivarium_public_health-3.1.3.dist-info/LICENSE.txt,sha256=mN4bNLUQNcN9njYRc_3jCZkfPySVpmM6MRps104FxA4,1548
46
- vivarium_public_health-3.1.3.dist-info/METADATA,sha256=csuhuXUl84A5QlA7S10sjMvafQXonRvAlwm2XSj-ck4,4028
47
- vivarium_public_health-3.1.3.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
48
- vivarium_public_health-3.1.3.dist-info/top_level.txt,sha256=VVInlpzCFD0UNNhjOq_j-a29odzjwUwYFTGfvqbi4dY,23
49
- vivarium_public_health-3.1.3.dist-info/RECORD,,
46
+ vivarium_public_health-3.1.4.dist-info/LICENSE.txt,sha256=mN4bNLUQNcN9njYRc_3jCZkfPySVpmM6MRps104FxA4,1548
47
+ vivarium_public_health-3.1.4.dist-info/METADATA,sha256=7h8TxTjuThUizK3wr0mzLB31wwS4Kqr9SVgJA9Gf7gA,4028
48
+ vivarium_public_health-3.1.4.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
49
+ vivarium_public_health-3.1.4.dist-info/top_level.txt,sha256=VVInlpzCFD0UNNhjOq_j-a29odzjwUwYFTGfvqbi4dY,23
50
+ vivarium_public_health-3.1.4.dist-info/RECORD,,