vivarium-public-health 2.1.4__py3-none-any.whl → 2.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1 +1 @@
1
- __version__ = "2.1.4"
1
+ __version__ = "2.2.0"
@@ -0,0 +1 @@
1
+ from vivarium_public_health.plugins.parser import CausesConfigurationParser
@@ -0,0 +1,864 @@
1
+ """
2
+ ===============================
3
+ Component Configuration Parsers
4
+ ===============================
5
+
6
+ Component Configuration Parsers in this module are specialized implementations of
7
+ :class:`ComponentConfigurationParser <vivarium.framework.components.parser.ComponentConfigurationParser>`
8
+ that can parse configurations of components specific to the Vivarium Public
9
+ Health package.
10
+ """
11
+ from importlib import import_module
12
+ from typing import Any, Callable, Dict, List, Optional, Union
13
+
14
+ import pandas as pd
15
+ from pkg_resources import resource_filename
16
+ from vivarium import Component, ConfigTree
17
+ from vivarium.framework.components import ComponentConfigurationParser
18
+ from vivarium.framework.components.parser import ParsingError
19
+ from vivarium.framework.engine import Builder
20
+ from vivarium.framework.state_machine import Trigger
21
+ from vivarium.framework.utilities import import_by_path
22
+
23
+ from vivarium_public_health.disease import (
24
+ BaseDiseaseState,
25
+ DiseaseModel,
26
+ DiseaseState,
27
+ RecoveredState,
28
+ SusceptibleState,
29
+ TransientDiseaseState,
30
+ )
31
+ from vivarium_public_health.utilities import TargetString
32
+
33
+
34
class CausesParsingErrors(ParsingError):
    """
    Raised when one or more problems are found while parsing a cause model
    configuration. Every problem message is rendered as its own bullet in
    the resulting error text.
    """

    def __init__(self, messages: List[str]):
        bullet = "\n - "
        # Leading bullet plus one bullet per message, e.g. "\n - a\n - b".
        super().__init__(bullet + bullet.join(messages))
+
42
+
43
class CausesConfigurationParser(ComponentConfigurationParser):
    """
    Component configuration parser that acts the same as the standard vivarium
    `ComponentConfigurationParser` but adds the additional ability to parse a
    configuration to create `DiseaseModel` components. These DiseaseModel
    configurations can either be specified directly in the configuration in a
    `causes` key or in external configuration files that are specified in the
    `external_configuration` key.
    """

    DEFAULT_MODEL_CONFIG = {
        "model_type": f"{DiseaseModel.__module__}.{DiseaseModel.__name__}",
        "initial_state": None,
    }
    """
    If a cause model configuration does not specify a model type or initial
    state, these default values will be used. The default model type is
    `DiseaseModel` and the default initial state is `None`. If the initial
    state is not specified, the cause model must have a state named
    'susceptible'.
    """

    DEFAULT_STATE_CONFIG = {
        "cause_type": "cause",
        "transient": False,
        "allow_self_transition": True,
        "side_effect": None,
        "cleanup_function": None,
        "state_type": None,
    }
    """
    If a state configuration does not specify cause_type, transient,
    allow_self_transition, side_effect, cleanup_function, or state_type,
    these default values will be used. The default cause type is 'cause', the
    default transient value is False, the default allow_self_transition
    value is True, and side_effect, cleanup_function and state_type default
    to None.
    """

    DEFAULT_TRANSITION_CONFIG = {"triggered": "NOT_TRIGGERED"}
    """
    If a transition configuration does not specify a triggered value, this
    default value will be used. The default triggered value is 'NOT_TRIGGERED'.
    """

    def parse_component_config(self, component_config: ConfigTree) -> List[Component]:
        """
        Parses the component configuration and returns a list of components.

        In particular, this method looks for an `external_configuration` key
        and/or a `causes` key.

        The `external_configuration` key should have names of packages that
        contain cause model configuration files. Within that key should be a list
        of paths to cause model configuration files relative to the package.

        .. code-block:: yaml

            external_configuration:
                some_package:
                    - some/path/cause_model_1.yaml
                    - some/path/cause_model_2.yaml

        The `causes` key should contain configuration information for cause
        models.

        .. code-block:: yaml

            causes:
                cause_1:
                    model_type: vivarium_public_health.disease.DiseaseModel
                    initial_state: susceptible
                    states:
                        susceptible:
                            cause_type: cause
                            data_sources: {}
                        infected:
                            cause_type: cause
                            transient: false
                            allow_self_transition: true
                            data_sources: {}
                    transitions:
                        transition_1:
                            source: susceptible
                            sink: infected
                            transition_type: rate
                            data_sources: {}

        # todo add information about the data_sources configuration

        Note that this method modifies the simulation's component configuration
        by adding the contents of external configuration files to the
        `model_override` layer and adding default cause model configuration
        values for all cause models to the `component_configs` layer.

        Parameters
        ----------
        component_config
            A ConfigTree defining the components to initialize.

        Returns
        -------
        List
            A list of initialized components.

        Raises
        ------
        CausesParsingErrors
            If the cause model configuration is invalid
        """
        components = []

        if "external_configuration" in component_config:
            self._validate_external_configuration(component_config["external_configuration"])
            for package, config_files in component_config["external_configuration"].items():
                for config_file in config_files.get_value():
                    source = f"{package}::{config_file}"
                    config_path = resource_filename(package, config_file)

                    external_config = ConfigTree(config_path)
                    component_config.update(
                        external_config, layer="model_override", source=source
                    )

        if "causes" in component_config:
            causes_config = component_config["causes"]
            self._validate_causes_config(causes_config)
            self._add_default_config_layer(causes_config)
            components += self._get_cause_model_components(causes_config)

        # Parse standard components (i.e. not cause models)
        standard_component_config = component_config.to_dict()
        standard_component_config.pop("external_configuration", None)
        standard_component_config.pop("causes", None)
        standard_components = (
            self.process_level(standard_component_config, [])
            if standard_component_config
            else []
        )

        return components + standard_components

    #########################
    # Configuration methods #
    #########################

    def _add_default_config_layer(self, causes_config: ConfigTree) -> None:
        """
        Adds a default layer to the provided configuration that specifies
        default values for the cause model configuration.

        Parameters
        ----------
        causes_config
            A ConfigTree defining the cause model configurations

        Returns
        -------
        None
        """
        default_config = {}
        for cause_name, cause_config in causes_config.items():
            default_states_config = {
                state_name: dict(self.DEFAULT_STATE_CONFIG)
                for state_name in cause_config.states.keys()
            }
            # `transitions` is optional (validation only checks it "if it is
            # present"), so do not read it unconditionally.
            default_transitions_config = (
                {
                    transition_name: dict(self.DEFAULT_TRANSITION_CONFIG)
                    for transition_name in cause_config.transitions.keys()
                }
                if "transitions" in cause_config
                else {}
            )
            default_config[cause_name] = {
                **self.DEFAULT_MODEL_CONFIG,
                "states": default_states_config,
                "transitions": default_transitions_config,
            }

        causes_config.update(
            default_config, layer="component_configs", source="causes_configuration_parser"
        )

    ################################
    # Cause model creation methods #
    ################################

    def _get_cause_model_components(self, causes_config: ConfigTree) -> List[Component]:
        """
        Parses the cause model configuration and returns a list of
        `DiseaseModel` components.

        Parameters
        ----------
        causes_config
            A ConfigTree defining the cause model components to initialize

        Returns
        -------
        List[Component]
            A list of initialized `DiseaseModel` components
        """
        cause_models = []

        for cause_name, cause_config in causes_config.items():
            states: Dict[str, BaseDiseaseState] = {
                state_name: self._get_state(state_name, state_config, cause_name)
                for state_name, state_config in cause_config.states.items()
            }

            # Guard against a cause that defines no transitions at all.
            transition_configs = (
                cause_config.transitions.values() if "transitions" in cause_config else []
            )
            for transition_config in transition_configs:
                self._add_transition(
                    states[transition_config.source],
                    states[transition_config.sink],
                    transition_config,
                )

            model_type = import_by_path(cause_config.model_type)
            initial_state = states.get(cause_config.initial_state, None)
            model = model_type(
                cause_name, initial_state=initial_state, states=list(states.values())
            )
            cause_models.append(model)

        return cause_models

    def _get_state(
        self, state_name: str, state_config: ConfigTree, cause_name: str
    ) -> BaseDiseaseState:
        """
        Parses a state configuration and returns an initialized `BaseDiseaseState`
        object.

        Parameters
        ----------
        state_name
            The name of the state to initialize
        state_config
            A ConfigTree defining the state to initialize
        cause_name
            The name of the cause to which the state belongs

        Returns
        -------
        BaseDiseaseState
            An initialized `BaseDiseaseState` object
        """
        # Susceptible and recovered states are identified by the cause name.
        state_id = cause_name if state_name in ["susceptible", "recovered"] else state_name
        state_kwargs = {
            "cause_type": state_config.cause_type,
            "allow_self_transition": state_config.allow_self_transition,
        }
        if state_config.side_effect:
            # todo handle side effects properly
            state_kwargs["side_effect"] = lambda *x: x
        if state_config.cleanup_function:
            # todo handle cleanup functions properly
            state_kwargs["cleanup_function"] = lambda *x: x
        if "data_sources" in state_config:
            data_sources_config = state_config.data_sources
            state_kwargs["get_data_functions"] = {
                name: self._get_data_source(name, data_sources_config[name])
                for name in data_sources_config.keys()
            }

        # An explicit state_type wins; otherwise infer the type from the
        # transient flag or the reserved state names.
        if state_config.state_type is not None:
            state_type = import_by_path(state_config.state_type)
        elif state_config.transient:
            state_type = TransientDiseaseState
        elif state_name == "susceptible":
            state_type = SusceptibleState
        elif state_name == "recovered":
            state_type = RecoveredState
        else:
            state_type = DiseaseState

        return state_type(state_id, **state_kwargs)

    def _add_transition(
        self,
        source_state: BaseDiseaseState,
        sink_state: BaseDiseaseState,
        transition_config: ConfigTree,
    ) -> None:
        """
        Adds a transition between two states.

        Parameters
        ----------
        source_state
            The state the transition starts from
        sink_state
            The state the transition ends at
        transition_config
            A `ConfigTree` defining the transition to add

        Returns
        -------
        None

        Raises
        ------
        ValueError
            If the transition type is not one of 'rate', 'proportion' or
            'dwell_time'.
        """
        triggered = Trigger[transition_config.triggered]
        if "data_sources" in transition_config:
            data_sources_config = transition_config.data_sources
            data_sources = {
                name: self._get_data_source(name, data_sources_config[name])
                for name in data_sources_config.keys()
            }
        else:
            data_sources = None

        transition_type = transition_config["transition_type"]
        if transition_type == "rate":
            source_state.add_rate_transition(
                sink_state, get_data_functions=data_sources, triggered=triggered
            )
        elif transition_type == "proportion":
            source_state.add_proportion_transition(
                sink_state, get_data_functions=data_sources, triggered=triggered
            )
        elif transition_type == "dwell_time":
            source_state.add_dwell_time_transition(sink_state, triggered=triggered)
        else:
            # The key is `transition_type`; the previous `.type` lookup would
            # itself fail before this ValueError could be raised.
            raise ValueError(
                f"Invalid transition data type '{transition_type}'"
                f" provided for transition '{transition_config}'."
            )

    @staticmethod
    def _get_data_source(
        name: str, source: Union[str, int, float]
    ) -> Callable[[Builder, Any], Any]:
        """
        Parses a data source and returns a callable that can be used to retrieve
        the data.

        Sources are interpreted, in order, as: a numeric constant, a string
        parseable by ``pd.Timedelta``, a ``module::attribute`` import path, or
        an artifact key accepted by ``TargetString``.

        Parameters
        ----------
        name
            The name of the data getter
        source
            The data source to parse

        Returns
        -------
        Callable[[Builder, Any], Any]
            A callable that can be used to retrieve the data

        Raises
        ------
        ValueError
            If the source cannot be interpreted as any supported kind.
        """
        # Integers are numeric constants too. Without this check,
        # `pd.Timedelta(int)` would read them as nanoseconds. `bool` is an
        # `int` subclass and is deliberately not treated as a number.
        if isinstance(source, (int, float)) and not isinstance(source, bool):
            return lambda builder, *_: source

        try:
            timedelta = pd.Timedelta(source)
            return lambda builder, *_: timedelta
        except ValueError:
            pass

        if "::" in source:
            module, method = source.split("::")
            try:
                return getattr(import_module(module), method)
            except (ImportError, AttributeError) as error:
                # Report as ValueError so validation collects it as a message.
                raise ValueError(f"Invalid data source '{source}' for '{name}'.") from error

        try:
            target_string = TargetString(source)
            return lambda builder, *_: builder.data.load(target_string)
        except ValueError:
            pass

        raise ValueError(f"Invalid data source '{source}' for '{name}'.")

    ######################
    # Validation methods #
    ######################

    _CAUSE_KEYS = {"model_type", "initial_state", "states", "transitions"}
    _STATE_KEYS = {
        "state_type",
        "cause_type",
        "transient",
        "allow_self_transition",
        "side_effect",
        "data_sources",
        "cleanup_function",
    }

    _DATA_SOURCE_KEYS = {
        "state": {
            "prevalence",
            "birth_prevalence",
            "dwell_time",
            "disability_weight",
            "excess_mortality_rate",
        },
        "rate_transition": {
            "incidence_rate",
            "transition_rate",
            "remission_rate",
        },
        "proportion_transition": {"proportion"},
    }
    _TRANSITION_KEYS = {"source", "sink", "transition_type", "triggered", "data_sources"}
    _TRANSITION_TYPE_KEYS = {"rate", "proportion", "dwell_time"}

    @staticmethod
    def _validate_external_configuration(external_configuration: ConfigTree) -> None:
        """
        Validates the external configuration.

        Parameters
        ----------
        external_configuration
            A ConfigTree defining the external configuration

        Returns
        -------
        None

        Raises
        ------
        CausesParsingErrors
            If the external configuration is invalid
        """
        external_configuration = external_configuration.to_dict()
        error_messages = []
        for package, config_files in external_configuration.items():
            if not isinstance(package, str):
                error_messages.append(
                    f"Key '{package}' must be a string definition of a package."
                )
            if not isinstance(config_files, list):
                error_messages.append(
                    f"External configuration for package '{package}' must be a list."
                )
            else:
                for config_file in config_files:
                    if not isinstance(config_file, str) or not config_file.endswith(".yaml"):
                        error_messages.append(
                            f"External configuration for package '{package}' must "
                            "be a list of paths to yaml files."
                        )
        if error_messages:
            raise CausesParsingErrors(error_messages)

    def _validate_causes_config(self, causes_config: ConfigTree) -> None:
        """
        Validates the cause model configuration.

        Parameters
        ----------
        causes_config
            A ConfigTree defining the cause model configurations

        Returns
        -------
        None

        Raises
        ------
        CausesParsingErrors
            If the cause model configuration is invalid
        """
        causes_config = causes_config.to_dict()
        error_messages = []
        for cause_name, cause_config in causes_config.items():
            error_messages += self._validate_cause(cause_name, cause_config)

        if error_messages:
            raise CausesParsingErrors(error_messages)

    def _validate_cause(self, cause_name: str, cause_config: Dict[str, Any]) -> List[str]:
        """
        Validates a cause configuration and returns a list of error messages.

        Parameters
        ----------
        cause_name
            The name of the cause to validate
        cause_config
            A dictionary defining the cause to validate

        Returns
        -------
        List[str]
            A list of error messages
        """
        error_messages = []
        if not isinstance(cause_config, dict):
            error_messages.append(
                f"Cause configuration for cause '{cause_name}' must be a dictionary."
            )
            return error_messages

        if not set(cause_config.keys()).issubset(self._CAUSE_KEYS):
            error_messages.append(
                f"Cause configuration for cause '{cause_name}' may only"
                " contain the following keys: "
                f"{self._CAUSE_KEYS}."
            )

        if "model_type" in cause_config:
            error_messages += self._validate_imported_type(
                cause_config["model_type"], cause_name, "model"
            )

        states_config = cause_config.get("states", {})
        if not states_config:
            error_messages.append(
                f"Cause configuration for cause '{cause_name}' must define "
                "at least one state."
            )

        states_is_dict = isinstance(states_config, dict)
        if not states_is_dict:
            error_messages.append(
                f"States configuration for cause '{cause_name}' must be a dictionary."
            )
        else:
            initial_state = cause_config.get("initial_state", None)
            if initial_state is not None and initial_state not in states_config:
                error_messages.append(
                    f"Initial state '{initial_state}' for cause "
                    f"'{cause_name}' must be present in the states for cause "
                    f"'{cause_name}'."
                )
            for state_name, state_config in states_config.items():
                error_messages += self._validate_state(cause_name, state_name, state_config)

        transitions_config = cause_config.get("transitions", {})
        if not isinstance(transitions_config, dict):
            error_messages.append(
                f"Transitions configuration for cause '{cause_name}' must be "
                "a dictionary if it is present."
            )
        else:
            # A malformed (non-dict) states entry has already been reported;
            # check transitions against no states rather than an arbitrary
            # object that may not support membership tests.
            known_states = states_config if states_is_dict else {}
            for transition_name, transition_config in transitions_config.items():
                error_messages += self._validate_transition(
                    cause_name, transition_name, transition_config, known_states
                )

        return error_messages

    def _validate_state(
        self, cause_name: str, state_name: str, state_config: Dict[str, Any]
    ) -> List[str]:
        """
        Validates a state configuration and returns a list of error messages.

        Parameters
        ----------
        cause_name
            The name of the cause to which the state belongs
        state_name
            The name of the state to validate
        state_config
            A dictionary defining the state to validate

        Returns
        -------
        List[str]
            A list of error messages
        """
        error_messages = []

        if not isinstance(state_config, dict):
            error_messages.append(
                f"State configuration for in cause '{cause_name}' and "
                f"state '{state_name}' must be a dictionary."
            )
            return error_messages

        # Susceptible and recovered states take no data sources.
        allowable_keys = set(self._STATE_KEYS)
        if state_name in ["susceptible", "recovered"]:
            allowable_keys.remove("data_sources")

        if not set(state_config.keys()).issubset(allowable_keys):
            error_messages.append(
                f"State configuration for in cause '{cause_name}' and "
                f"state '{state_name}' may only contain the following "
                f"keys: {allowable_keys}."
            )

        state_type = state_config.get("state_type", "")
        error_messages += self._validate_imported_type(
            state_type, cause_name, "state", state_name
        )

        if not isinstance(state_config.get("cause_type", ""), str):
            error_messages.append(
                f"Cause type for state '{state_name}' in cause '{cause_name}' "
                f"must be a string. Provided {state_config['cause_type']}."
            )
        is_transient = state_config.get("transient", False)
        if not isinstance(is_transient, bool):
            error_messages.append(
                f"Transient flag for state '{state_name}' in cause '{cause_name}' "
                f"must be a boolean. Provided {state_config['transient']}."
            )

        if state_name in ["susceptible", "recovered"] and state_type:
            error_messages.append(
                f"The name '{state_name}' in cause '{cause_name}' concretely "
                f"specifies the state type, so state_type is not an allowed "
                "configuration."
            )

        if state_name in ["susceptible", "recovered"] and is_transient:
            error_messages.append(
                f"The name '{state_name}' in cause '{cause_name}' concretely "
                f"specifies the state type, so transient is not an allowed "
                "configuration."
            )

        if is_transient and state_type:
            error_messages.append(
                f"Specifying transient as True for state '{state_name}' in cause "
                f"'{cause_name}' concretely specifies the state type, so "
                "state_type is not an allowed configuration."
            )

        if not isinstance(state_config.get("allow_self_transition", True), bool):
            error_messages.append(
                f"Allow self transition flag for state '{state_name}' in cause "
                f"'{cause_name}' must be a boolean. Provided "
                f"'{state_config['allow_self_transition']}'."
            )

        error_messages += self._validate_data_sources(
            state_config, cause_name, "state", state_name
        )

        return error_messages

    def _validate_transition(
        self,
        cause_name: str,
        transition_name: str,
        transition_config: Dict[str, Any],
        states_config: Dict[str, Any],
    ) -> List[str]:
        """
        Validates a transition configuration and returns a list of error messages.

        Parameters
        ----------
        cause_name
            The name of the cause to which the transition belongs
        transition_name
            The name of the transition to validate
        transition_config
            A dictionary defining the transition to validate
        states_config
            A dictionary defining the states for the cause

        Returns
        -------
        List[str]
            A list of error messages
        """
        error_messages = []

        if not isinstance(transition_config, dict):
            error_messages.append(
                f"Transition configuration for in cause '{cause_name}' and "
                f"transition '{transition_name}' must be a dictionary."
            )
            return error_messages

        if not set(transition_config.keys()).issubset(self._TRANSITION_KEYS):
            error_messages.append(
                f"Transition configuration for in cause '{cause_name}' and "
                f"transition '{transition_name}' may only contain the "
                f"following keys: {self._TRANSITION_KEYS}."
            )
        source = transition_config.get("source", None)
        sink = transition_config.get("sink", None)
        if sink is None or source is None:
            error_messages.append(
                f"Transition configuration for in cause '{cause_name}' and "
                f"transition '{transition_name}' must contain both a source "
                f"and a sink."
            )

        if source is not None and source not in states_config:
            error_messages.append(
                f"Transition configuration for in cause '{cause_name}' and "
                f"transition '{transition_name}' must contain a source that "
                f"is present in the states."
            )

        if sink is not None and sink not in states_config:
            error_messages.append(
                f"Transition configuration for in cause '{cause_name}' and "
                f"transition '{transition_name}' must contain a sink that "
                f"is present in the states."
            )

        if (
            "triggered" in transition_config
            and transition_config["triggered"] not in Trigger.__members__
        ):
            error_messages.append(
                f"Transition configuration for in cause '{cause_name}' and "
                f"transition '{transition_name}' may only have one of the following "
                f"values: {Trigger.__members__}."
            )

        if "transition_type" not in transition_config:
            error_messages.append(
                f"Transition configuration for in cause '{cause_name}' and "
                f"transition '{transition_name}' must contain a transition type."
            )
        else:
            transition_type = transition_config["transition_type"]
            if transition_type not in self._TRANSITION_TYPE_KEYS:
                error_messages.append(
                    f"Transition configuration for in cause '{cause_name}' and "
                    f"transition '{transition_name}' may only contain the "
                    f"following values: {self._TRANSITION_TYPE_KEYS}."
                )
            if transition_type == "dwell_time" and "data_sources" in transition_config:
                error_messages.append(
                    f"Transition configuration for in cause '{cause_name}' and "
                    f"transition '{transition_name}' is a dwell-time transition and "
                    f"may not have data sources as dwell-time is configured on the state."
                )
            elif transition_type in self._TRANSITION_TYPE_KEYS.difference({"dwell_time"}):
                error_messages += self._validate_data_sources(
                    transition_config,
                    cause_name,
                    f"{transition_type}_transition",
                    transition_name,
                )
        return error_messages

    @staticmethod
    def _validate_imported_type(
        import_path: str, cause_name: str, entity_type: str, entity_name: Optional[str] = None
    ) -> List[str]:
        """
        Validates an imported type and returns a list of error messages.

        An empty/None ``import_path`` is treated as "not provided" and is valid.

        Parameters
        ----------
        import_path
            The import path to validate
        cause_name
            The name of the cause to which the imported type belongs
        entity_type
            The type of the entity to which the imported type belongs
            ('model' or 'state')
        entity_name
            The name of the entity to which the imported type belongs, if it is
            not a cause

        Returns
        -------
        List[str]
            A list of error messages
        """
        expected_type = {"model": DiseaseModel, "state": BaseDiseaseState}[entity_type]

        error_messages = []
        if not import_path:
            return error_messages

        try:
            imported_type = import_by_path(import_path)
            if not (
                isinstance(imported_type, type) and issubclass(imported_type, expected_type)
            ):
                raise TypeError
        except (ImportError, AttributeError, TypeError, ValueError):
            error_messages.append(
                f"If '{entity_type}_type' is provided for cause '{cause_name}' "
                f"{f'and {entity_type} {entity_name} ' if entity_name else ''}it "
                f"must be the fully qualified import path to a {expected_type} "
                f"implementation. Provided '{import_path}'."
            )
        return error_messages

    def _validate_data_sources(
        self, config: Dict[str, Any], cause_name: str, config_type: str, config_name: str
    ) -> List[str]:
        """
        Validates the data sources in a configuration and returns a list of
        error messages.

        Parameters
        ----------
        config
            A dictionary defining the configuration to validate
        cause_name
            The name of the cause to which the configuration belongs
        config_type
            The type of the configuration to validate (a key of
            `_DATA_SOURCE_KEYS`)
        config_name
            The name of the configuration being validated

        Returns
        -------
        List[str]
            A list of error messages
        """
        error_messages = []
        data_sources_config = config.get("data_sources", {})
        if not isinstance(data_sources_config, dict):
            error_messages.append(
                f"Data sources configuration for {config_type} '{config_name}' in "
                f"cause '{cause_name}' must be a dictionary if it is present."
            )
            return error_messages

        if not set(data_sources_config.keys()).issubset(self._DATA_SOURCE_KEYS[config_type]):
            error_messages.append(
                f"Data sources configuration for {config_type} '{config_name}' "
                f"in cause '{cause_name}' may only contain the following keys: "
                f"{self._DATA_SOURCE_KEYS[config_type]}."
            )

        # Use a distinct loop name so `config_name` still refers to the
        # state/transition being validated when reporting errors.
        for data_source_name, source in data_sources_config.items():
            try:
                self._get_data_source(data_source_name, source)
            except ValueError:
                error_messages.append(
                    f"Configuration for {config_type} '{config_name}' in cause "
                    f"'{cause_name}' has an invalid data source at '{source}'."
                )
        return error_messages
@@ -7,6 +7,7 @@ This module contains tools for handling raw demographic data and transforming
7
7
  it into different distributions for sampling.
8
8
 
9
9
  """
10
+
10
11
  from collections import namedtuple
11
12
  from typing import Tuple, Union
12
13
 
@@ -54,23 +55,23 @@ def assign_demographic_proportions(
54
55
 
55
56
  population_data["P(sex, location, age| year)"] = (
56
57
  population_data.groupby("year_start", as_index=False)
57
- .apply(lambda sub_pop: sub_pop.value / sub_pop.value.sum())
58
- .reset_index(level=0)
59
- .value.fillna(0.0)
58
+ .apply(lambda sub_pop: sub_pop[["value"]] / sub_pop["value"].sum())
59
+ .reset_index(level=0)["value"]
60
+ .fillna(0.0)
60
61
  )
61
62
 
62
63
  population_data["P(sex, location | age, year)"] = (
63
64
  population_data.groupby(["age", "year_start"], as_index=False)
64
- .apply(lambda sub_pop: sub_pop.value / sub_pop.value.sum())
65
- .reset_index(level=0)
66
- .value.fillna(0.0)
65
+ .apply(lambda sub_pop: sub_pop[["value"]] / sub_pop["value"].sum())
66
+ .reset_index(level=0)["value"]
67
+ .fillna(0.0)
67
68
  )
68
69
 
69
70
  population_data["P(age | year, sex, location)"] = (
70
71
  population_data.groupby(["year_start", "sex", "location"], as_index=False)
71
- .apply(lambda sub_pop: sub_pop.value / sub_pop.value.sum())
72
- .reset_index(level=0)
73
- .value.fillna(0.0)
72
+ .apply(lambda sub_pop: sub_pop[["value"]] / sub_pop["value"].sum())
73
+ .reset_index(level=0)["value"]
74
+ .fillna(0.0)
74
75
  )
75
76
 
76
77
  return population_data.sort_values(_SORT_ORDER).reset_index(drop=True)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: vivarium_public_health
3
- Version: 2.1.4
3
+ Version: 2.2.0
4
4
  Summary: Components for modelling diseases, risks, and interventions with ``vivarium``
5
5
  Home-page: https://github.com/ihmeuw/vivarium_public_health
6
6
  Author: The vivarium developers
@@ -27,6 +27,7 @@ Classifier: Topic :: Scientific/Engineering :: Physics
27
27
  Classifier: Topic :: Software Development :: Libraries
28
28
  License-File: LICENSE.txt
29
29
  Requires-Dist: vivarium >=2.0.0
30
+ Requires-Dist: loguru
30
31
  Requires-Dist: numpy
31
32
  Requires-Dist: pandas
32
33
  Requires-Dist: scipy
@@ -41,6 +42,7 @@ Requires-Dist: matplotlib ; extra == 'dev'
41
42
  Requires-Dist: pytest ; extra == 'dev'
42
43
  Requires-Dist: pytest-mock ; extra == 'dev'
43
44
  Requires-Dist: hypothesis ; extra == 'dev'
45
+ Requires-Dist: pyyaml ; extra == 'dev'
44
46
  Provides-Extra: docs
45
47
  Requires-Dist: sphinx <7.0 ; extra == 'docs'
46
48
  Requires-Dist: sphinx-rtd-theme ; extra == 'docs'
@@ -51,6 +53,7 @@ Provides-Extra: test
51
53
  Requires-Dist: pytest ; extra == 'test'
52
54
  Requires-Dist: pytest-mock ; extra == 'test'
53
55
  Requires-Dist: hypothesis ; extra == 'test'
56
+ Requires-Dist: pyyaml ; extra == 'test'
54
57
 
55
58
  Vivarium Public Health
56
59
  ======================
@@ -1,6 +1,6 @@
1
1
  vivarium_public_health/__about__.py,sha256=RgWycPypKZS80TpSX7o41cREnG8PfguNHDHLuLyl820,487
2
2
  vivarium_public_health/__init__.py,sha256=tomMOl3PI7O8GdxDWGBiBjT0Bwd31GpyQTYTzwIv108,361
3
- vivarium_public_health/_version.py,sha256=1yR20YsjyDpnFQgDIQmHfutaSsaW0F7mDqjloRVRIG8,22
3
+ vivarium_public_health/_version.py,sha256=DKk-1b-rZsJFxFi1JoJ7TmEvIEQ0rf-C9HAZWwvjuM0,22
4
4
  vivarium_public_health/utilities.py,sha256=_X9sZQ7flsi2sVWQ9zrf8GJw8QwZsPZm3NUjx1gu7bM,2555
5
5
  vivarium_public_health/disease/__init__.py,sha256=RuuiRcvAJfX9WQGt_WZZjxN7Cu3E5rMTmuaRS-UaFPM,419
6
6
  vivarium_public_health/disease/model.py,sha256=7xFGo6JqPjQjm2VUZ3u3ThXWSzmitTOCZ8N9PTTL6MU,8253
@@ -21,10 +21,12 @@ vivarium_public_health/mslt/intervention.py,sha256=q6-ZHYFuV9o7WKlHMpiFAsEH2fYzw
21
21
  vivarium_public_health/mslt/magic_wand_components.py,sha256=7sy_7fa0R5pk5XPdt9-AfK005JV3em4Tvu4L4xeg4g0,3980
22
22
  vivarium_public_health/mslt/observer.py,sha256=Z0aWLrlSjxLuYqzXarNunJ2Xqei4R9nX03dnE6uvr4o,13940
23
23
  vivarium_public_health/mslt/population.py,sha256=R5Z7aBf63LDbasPVMMI0HTh41FIL_OAhYk0qQWf_-lU,7285
24
+ vivarium_public_health/plugins/__init__.py,sha256=oBW_zfgG_LbwfgTDjUe0btfy9FaDvAbtXho1zQFnz0Y,76
25
+ vivarium_public_health/plugins/parser.py,sha256=uhBw5t-Lmb8YDN2GvVG93l50ZuCIsg4VocSA5T_wz3w,31789
24
26
  vivarium_public_health/population/__init__.py,sha256=17rtbcNVK5LtCCxAex7P7Q_vYpwbeTepyf3nazS90Yc,225
25
27
  vivarium_public_health/population/add_new_birth_cohorts.py,sha256=qNsZjvaJ7Et8_Kw7JNyRshHHRA3pEQMM4TSqCp48Gr4,9092
26
28
  vivarium_public_health/population/base_population.py,sha256=U9FibBoPuYWvUqPFCUIVwjBxQVuhTb58cUy-2EJSzio,15345
27
- vivarium_public_health/population/data_transformations.py,sha256=M1qZlDyWGQsVR33LtrKW2ts1OIXOT2TYoQEtjzXQ_1k,21804
29
+ vivarium_public_health/population/data_transformations.py,sha256=3MNDrfaZMMGM3crNzucGfYWZOTGHoavIcu8wxq1xIU8,21838
28
30
  vivarium_public_health/population/mortality.py,sha256=w1Oxb958LjUkNwxJ0vdA3TZndpeNiaH3d7RukLas_oQ,10085
29
31
  vivarium_public_health/risks/__init__.py,sha256=XvX12RgD0iF5PBoc2StsOhxJmK1FP-RaAYrjIT9MfDs,232
30
32
  vivarium_public_health/risks/base_risk.py,sha256=6D7YlxQOdQm-Kw5_vjpQmFqU7spF-lTy14WEEefRQlA,6494
@@ -40,8 +42,8 @@ vivarium_public_health/treatment/__init__.py,sha256=wONElu9aJbBYwpYIovYPYaN_GYfV
40
42
  vivarium_public_health/treatment/magic_wand.py,sha256=iPKFN3VjfiMy_XvN94UqM-FUrGuI0ULwmOdAGdOepYQ,1979
41
43
  vivarium_public_health/treatment/scale_up.py,sha256=7QKBgAII4dwkds9gdbQ5d6oDaD02iwcQCVcYRN-B4Mg,7573
42
44
  vivarium_public_health/treatment/therapeutic_inertia.py,sha256=VwZ7t90zzfGoBusduIvcE4lDe5zTvzmHiUNB3u2I52Y,2339
43
- vivarium_public_health-2.1.4.dist-info/LICENSE.txt,sha256=mN4bNLUQNcN9njYRc_3jCZkfPySVpmM6MRps104FxA4,1548
44
- vivarium_public_health-2.1.4.dist-info/METADATA,sha256=4sHNnYPO59LcNUgAFKdzBn2AQU1aq_cppS9RiZJ0GQw,3430
45
- vivarium_public_health-2.1.4.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
46
- vivarium_public_health-2.1.4.dist-info/top_level.txt,sha256=VVInlpzCFD0UNNhjOq_j-a29odzjwUwYFTGfvqbi4dY,23
47
- vivarium_public_health-2.1.4.dist-info/RECORD,,
45
+ vivarium_public_health-2.2.0.dist-info/LICENSE.txt,sha256=mN4bNLUQNcN9njYRc_3jCZkfPySVpmM6MRps104FxA4,1548
46
+ vivarium_public_health-2.2.0.dist-info/METADATA,sha256=AyFXCJ4aiPYreEVtAkAFpOODBMkUCIs0pXiTp5W-dy8,3531
47
+ vivarium_public_health-2.2.0.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
48
+ vivarium_public_health-2.2.0.dist-info/top_level.txt,sha256=VVInlpzCFD0UNNhjOq_j-a29odzjwUwYFTGfvqbi4dY,23
49
+ vivarium_public_health-2.2.0.dist-info/RECORD,,