vivarium-public-health 2.3.2__py3-none-any.whl → 3.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48) hide show
  1. vivarium_public_health/_version.py +1 -1
  2. vivarium_public_health/disease/model.py +23 -21
  3. vivarium_public_health/disease/models.py +1 -0
  4. vivarium_public_health/disease/special_disease.py +40 -41
  5. vivarium_public_health/disease/state.py +42 -125
  6. vivarium_public_health/disease/transition.py +70 -27
  7. vivarium_public_health/mslt/delay.py +1 -0
  8. vivarium_public_health/mslt/disease.py +1 -0
  9. vivarium_public_health/mslt/intervention.py +1 -0
  10. vivarium_public_health/mslt/magic_wand_components.py +1 -0
  11. vivarium_public_health/mslt/observer.py +1 -0
  12. vivarium_public_health/mslt/population.py +1 -0
  13. vivarium_public_health/plugins/parser.py +61 -31
  14. vivarium_public_health/population/add_new_birth_cohorts.py +2 -3
  15. vivarium_public_health/population/base_population.py +2 -1
  16. vivarium_public_health/population/mortality.py +83 -80
  17. vivarium_public_health/{metrics → results}/__init__.py +2 -0
  18. vivarium_public_health/results/columns.py +22 -0
  19. vivarium_public_health/results/disability.py +187 -0
  20. vivarium_public_health/results/disease.py +222 -0
  21. vivarium_public_health/results/mortality.py +186 -0
  22. vivarium_public_health/results/observer.py +78 -0
  23. vivarium_public_health/results/risk.py +138 -0
  24. vivarium_public_health/results/simple_cause.py +18 -0
  25. vivarium_public_health/{metrics → results}/stratification.py +10 -8
  26. vivarium_public_health/risks/__init__.py +1 -2
  27. vivarium_public_health/risks/base_risk.py +134 -29
  28. vivarium_public_health/risks/data_transformations.py +65 -326
  29. vivarium_public_health/risks/distributions.py +315 -145
  30. vivarium_public_health/risks/effect.py +376 -75
  31. vivarium_public_health/risks/implementations/low_birth_weight_and_short_gestation.py +61 -89
  32. vivarium_public_health/treatment/magic_wand.py +1 -0
  33. vivarium_public_health/treatment/scale_up.py +1 -0
  34. vivarium_public_health/treatment/therapeutic_inertia.py +1 -0
  35. vivarium_public_health/utilities.py +17 -2
  36. {vivarium_public_health-2.3.2.dist-info → vivarium_public_health-3.0.0.dist-info}/METADATA +13 -3
  37. vivarium_public_health-3.0.0.dist-info/RECORD +49 -0
  38. {vivarium_public_health-2.3.2.dist-info → vivarium_public_health-3.0.0.dist-info}/WHEEL +1 -1
  39. vivarium_public_health/metrics/disability.py +0 -118
  40. vivarium_public_health/metrics/disease.py +0 -136
  41. vivarium_public_health/metrics/mortality.py +0 -144
  42. vivarium_public_health/metrics/risk.py +0 -110
  43. vivarium_public_health/testing/__init__.py +0 -0
  44. vivarium_public_health/testing/mock_artifact.py +0 -145
  45. vivarium_public_health/testing/utils.py +0 -71
  46. vivarium_public_health-2.3.2.dist-info/RECORD +0 -49
  47. {vivarium_public_health-2.3.2.dist-info → vivarium_public_health-3.0.0.dist-info}/LICENSE.txt +0 -0
  48. {vivarium_public_health-2.3.2.dist-info → vivarium_public_health-3.0.0.dist-info}/top_level.txt +0 -0
@@ -7,52 +7,86 @@ This module contains tools for modeling several different risk
7
7
  exposure distributions.
8
8
 
9
9
  """
10
- from typing import Dict, List
10
+
11
+ from abc import ABC, abstractmethod
12
+ from typing import Callable, Dict, List, Optional, Union
11
13
 
12
14
  import numpy as np
13
15
  import pandas as pd
14
- from risk_distributions import EnsembleDistribution, LogNormal, Normal
16
+ import risk_distributions as rd
17
+ from layered_config_tree import LayeredConfigTree
15
18
  from vivarium import Component
16
19
  from vivarium.framework.engine import Builder
17
20
  from vivarium.framework.population import SimulantData
18
21
  from vivarium.framework.values import Pipeline, list_combiner, union_post_processor
19
22
 
20
- from vivarium_public_health.risks.data_transformations import get_distribution_data
21
- from vivarium_public_health.utilities import EntityString
23
+ from vivarium_public_health.risks.data_transformations import pivot_categorical
24
+ from vivarium_public_health.utilities import EntityString, get_lookup_columns
22
25
 
23
26
 
24
27
  class MissingDataError(Exception):
25
28
  pass
26
29
 
27
30
 
28
- # FIXME: This is a hack. It's wrapping up an adaptor pattern in another
29
- # adaptor pattern, which is gross, but would require some more difficult
30
- # refactoring which is thoroughly out of scope right now. -J.C. 8/25/19
31
- class SimulationDistribution(Component):
32
- """Wrapper around a variety of distribution implementations."""
31
+ class RiskExposureDistribution(Component, ABC):
33
32
 
34
33
  #####################
35
34
  # Lifecycle methods #
36
35
  #####################
37
36
 
38
- def __init__(self, risk: str):
37
+ def __init__(
38
+ self,
39
+ risk: EntityString,
40
+ distribution_type: str,
41
+ exposure_data: Optional[Union[int, float, pd.DataFrame]] = None,
42
+ ) -> None:
39
43
  super().__init__()
40
- self.risk = EntityString(risk)
44
+ self.risk = risk
45
+ self.distribution_type = distribution_type
46
+ self._exposure_data = exposure_data
47
+
48
+ self.parameters_pipeline_name = f"{self.risk}.exposure_parameters"
41
49
 
50
+ #################
51
+ # Setup methods #
52
+ #################
53
+
54
+ def get_configuration(self, builder: "Builder") -> Optional[LayeredConfigTree]:
55
+ return builder.configuration[self.risk]
56
+
57
+ @abstractmethod
58
+ def build_all_lookup_tables(self, builder: "Builder") -> None:
59
+ raise NotImplementedError
60
+
61
+ def get_exposure_data(self, builder: Builder) -> Union[int, float, pd.DataFrame]:
62
+ if self._exposure_data is not None:
63
+ return self._exposure_data
64
+ return self.get_data(builder, self.configuration["data_sources"]["exposure"])
65
+
66
+ # noinspection PyAttributeOutsideInit
42
67
  def setup(self, builder: Builder) -> None:
43
- distribution_data = get_distribution_data(builder, self.risk)
44
- self.implementation = get_distribution(self.risk, **distribution_data)
45
- self.implementation.setup_component(builder)
68
+ self.exposure_parameters = self.get_exposure_parameter_pipeline(builder)
69
+ if self.exposure_parameters.name != self.parameters_pipeline_name:
70
+ raise ValueError(
71
+ "Expected exposure parameters pipeline to be named "
72
+ f"{self.parameters_pipeline_name}, "
73
+ f"but found {self.exposure_parameters.name}."
74
+ )
75
+
76
+ @abstractmethod
77
+ def get_exposure_parameter_pipeline(self, builder: Builder) -> Pipeline:
78
+ raise NotImplementedError
46
79
 
47
80
  ##################
48
81
  # Public methods #
49
82
  ##################
50
83
 
51
- def ppf(self, q):
52
- return self.implementation.ppf(q)
84
+ @abstractmethod
85
+ def ppf(self, quantiles: pd.Series) -> pd.Series:
86
+ raise NotImplementedError
53
87
 
54
88
 
55
- class EnsembleSimulation(Component):
89
+ class EnsembleDistribution(RiskExposureDistribution):
56
90
  ##############
57
91
  # Properties #
58
92
  ##############
@@ -73,38 +107,71 @@ class EnsembleSimulation(Component):
73
107
  # Lifecycle methods #
74
108
  #####################
75
109
 
76
- def __init__(self, risk, weights, mean, sd):
77
- super().__init__()
78
- self.risk = EntityString(risk)
79
- self._weights, self._parameters = self.get_parameters(weights, mean, sd)
110
+ def __init__(self, risk: EntityString, distribution_type: str = "ensemble") -> None:
111
+ super().__init__(risk, distribution_type)
80
112
  self._propensity = f"ensemble_propensity_{self.risk}"
81
113
 
82
- def setup(self, builder: Builder) -> None:
83
- self.weights = builder.lookup.build_table(
84
- self._weights, key_columns=["sex"], parameter_columns=["age", "year"]
114
+ #################
115
+ # Setup methods #
116
+ #################
117
+
118
+ def build_all_lookup_tables(self, builder: Builder) -> None:
119
+ exposure_data = self.get_exposure_data(builder)
120
+ standard_deviation = self.get_data(
121
+ builder,
122
+ self.configuration["data_sources"]["exposure_standard_deviation"],
123
+ )
124
+ weights_source = self.configuration["data_sources"]["ensemble_distribution_weights"]
125
+ raw_weights = self.get_data(builder, weights_source)
126
+
127
+ glnorm_mask = raw_weights["parameter"] == "glnorm"
128
+ if np.any(raw_weights.loc[glnorm_mask, self.get_value_columns(weights_source)]):
129
+ raise NotImplementedError("glnorm distribution is not supported")
130
+ raw_weights = raw_weights[~glnorm_mask]
131
+
132
+ distributions = list(raw_weights["parameter"].unique())
133
+
134
+ raw_weights = pivot_categorical(
135
+ builder, self.risk, raw_weights, pivot_column="parameter", reset_index=False
85
136
  )
137
+
138
+ weights, parameters = rd.EnsembleDistribution.get_parameters(
139
+ raw_weights,
140
+ mean=get_risk_distribution_parameter(self.get_value_columns, exposure_data),
141
+ sd=get_risk_distribution_parameter(self.get_value_columns, standard_deviation),
142
+ )
143
+
144
+ distribution_weights_table = self.build_lookup_table(
145
+ builder, weights.reset_index(), distributions
146
+ )
147
+ self.lookup_tables["ensemble_distribution_weights"] = distribution_weights_table
148
+ key_columns = distribution_weights_table.key_columns
149
+ parameter_columns = distribution_weights_table.parameter_columns
150
+
86
151
  self.parameters = {
87
- k: builder.lookup.build_table(
88
- v, key_columns=["sex"], parameter_columns=["age", "year"]
152
+ parameter: builder.lookup.build_table(
153
+ data.reset_index(),
154
+ key_columns=key_columns,
155
+ parameter_columns=parameter_columns,
89
156
  )
90
- for k, v in self._parameters.items()
157
+ for parameter, data in parameters.items()
91
158
  }
92
159
 
160
+ def setup(self, builder: Builder) -> None:
161
+ super().setup(builder)
93
162
  self.randomness = builder.randomness.get_stream(self._propensity)
94
163
 
95
- ##########################
96
- # Initialization methods #
97
- ##########################
98
-
99
- def get_parameters(self, weights, mean, sd):
100
- index_cols = ["sex", "age_start", "age_end", "year_start", "year_end"]
101
- weights = weights.set_index(index_cols)
102
- mean = mean.set_index(index_cols)["value"]
103
- sd = sd.set_index(index_cols)["value"]
104
- weights, parameters = EnsembleDistribution.get_parameters(weights, mean=mean, sd=sd)
105
- return weights.reset_index(), {
106
- name: p.reset_index() for name, p in parameters.items()
107
- }
164
+ def get_exposure_parameter_pipeline(self, builder: Builder) -> Pipeline:
165
+ # This pipeline is not needed for ensemble distributions, so just
166
+ # register a dummy pipeline
167
+ def raise_not_implemented():
168
+ raise NotImplementedError(
169
+ "EnsembleDistribution does not use exposure parameters."
170
+ )
171
+
172
+ return builder.value.register_value_producer(
173
+ self.parameters_pipeline_name, lambda *_: raise_not_implemented()
174
+ )
108
175
 
109
176
  ########################
110
177
  # Event-driven methods #
@@ -120,149 +187,256 @@ class EnsembleSimulation(Component):
120
187
  # Public methods #
121
188
  ##################
122
189
 
123
- def ppf(self, q):
124
- if not q.empty:
125
- q = clip(q)
126
- weights = self.weights(q.index)
190
+ def ppf(self, quantiles: pd.Series) -> pd.Series:
191
+ if not quantiles.empty:
192
+ quantiles = clip(quantiles)
193
+ weights = self.lookup_tables["ensemble_distribution_weights"](quantiles.index)
127
194
  parameters = {
128
- name: parameter(q.index) for name, parameter in self.parameters.items()
195
+ name: param(quantiles.index) for name, param in self.parameters.items()
129
196
  }
130
- ensemble_propensity = self.population_view.get(q.index).iloc[:, 0]
131
- x = EnsembleDistribution(weights, parameters).ppf(q, ensemble_propensity)
197
+ ensemble_propensity = self.population_view.get(quantiles.index).iloc[:, 0]
198
+ x = rd.EnsembleDistribution(weights, parameters).ppf(
199
+ quantiles, ensemble_propensity
200
+ )
132
201
  x[x.isnull()] = 0
133
202
  else:
134
203
  x = pd.Series([])
135
204
  return x
136
205
 
137
206
 
138
- class ContinuousDistribution(Component):
207
+ class ContinuousDistribution(RiskExposureDistribution):
139
208
  #####################
140
209
  # Lifecycle methods #
141
210
  #####################
142
211
 
143
- def __init__(self, risk, mean, sd, distribution=None):
144
- super().__init__()
145
- self.risk = EntityString(risk)
146
- self._distribution = distribution
147
- self._parameters = self.get_parameters(mean, sd)
212
+ def __init__(self, risk: EntityString, distribution_type: str) -> None:
213
+ super().__init__(risk, distribution_type)
214
+ self.standard_deviation = None
215
+ try:
216
+ self._distribution = {
217
+ "normal": rd.Normal,
218
+ "lognormal": rd.LogNormal,
219
+ }[distribution_type]
220
+ except KeyError:
221
+ raise NotImplementedError(
222
+ f"Distribution type {distribution_type} is not supported for "
223
+ f"risk {risk.name}."
224
+ )
148
225
 
149
- def setup(self, builder: Builder) -> None:
150
- self.parameters = builder.lookup.build_table(
151
- self._parameters, key_columns=["sex"], parameter_columns=["age", "year"]
226
+ #################
227
+ # Setup methods #
228
+ #################
229
+
230
+ def build_all_lookup_tables(self, builder: "Builder") -> None:
231
+ exposure_data = self.get_exposure_data(builder)
232
+ standard_deviation = self.get_data(
233
+ builder, self.configuration["data_sources"]["exposure_standard_deviation"]
234
+ )
235
+ parameters = self._distribution.get_parameters(
236
+ mean=get_risk_distribution_parameter(self.get_value_columns, exposure_data),
237
+ sd=get_risk_distribution_parameter(self.get_value_columns, standard_deviation),
152
238
  )
153
239
 
154
- ##########################
155
- # Initialization methods #
156
- ##########################
240
+ self.lookup_tables["parameters"] = self.build_lookup_table(
241
+ builder, parameters.reset_index(), list(parameters.columns)
242
+ )
157
243
 
158
- def get_parameters(self, mean, sd):
159
- index = ["sex", "age_start", "age_end", "year_start", "year_end"]
160
- mean = mean.set_index(index)["value"]
161
- sd = sd.set_index(index)["value"]
162
- return self._distribution.get_parameters(mean=mean, sd=sd).reset_index()
244
+ def get_exposure_parameter_pipeline(self, builder: Builder) -> Pipeline:
245
+ return builder.value.register_value_producer(
246
+ self.parameters_pipeline_name,
247
+ source=self.lookup_tables["parameters"],
248
+ requires_columns=get_lookup_columns([self.lookup_tables["parameters"]]),
249
+ )
163
250
 
164
251
  ##################
165
252
  # Public methods #
166
253
  ##################
167
254
 
168
- def ppf(self, q):
169
- if not q.empty:
170
- q = clip(q)
171
- x = self._distribution(parameters=self.parameters(q.index)).ppf(q)
255
+ def ppf(self, quantiles: pd.Series) -> pd.Series:
256
+ if not quantiles.empty:
257
+ quantiles = clip(quantiles)
258
+ parameters = self.exposure_parameters(quantiles.index)
259
+ x = self._distribution(parameters=parameters).ppf(quantiles)
172
260
  x[x.isnull()] = 0
173
261
  else:
174
262
  x = pd.Series([])
175
263
  return x
176
264
 
177
265
 
178
- class PolytomousDistribution(Component):
179
- #####################
180
- # Lifecycle methods #
181
- #####################
182
-
183
- def __init__(self, risk: str, exposure_data: pd.DataFrame):
184
- super().__init__()
185
- self.risk = EntityString(risk)
186
- self._exposure_data = exposure_data
187
- self.exposure_parameters_pipeline_name = f"{self.risk}.exposure_parameters"
188
-
189
- # noinspection PyAttributeOutsideInit
190
- def setup(self, builder: Builder) -> None:
191
- self.categories = self.get_categories()
192
- self.exposure = self.get_exposure_parameters(builder)
266
+ class PolytomousDistribution(RiskExposureDistribution):
267
+ @property
268
+ def categories(self) -> List[str]:
269
+ # These need to be sorted so the cumulative sum is in the ocrrect order of categories
270
+ # and results are therefore reproducible and correct
271
+ return sorted(self.lookup_tables["exposure"].value_columns)
193
272
 
194
273
  #################
195
274
  # Setup methods #
196
275
  #################
197
276
 
198
- def get_categories(self) -> List[str]:
199
- return sorted(
200
- [column for column in self._exposure_data if "cat" in column],
201
- key=lambda column: int(column[3:]),
277
+ def build_all_lookup_tables(self, builder: "Builder") -> None:
278
+ exposure_data = self.get_exposure_data(builder)
279
+ exposure_value_columns = self.get_exposure_value_columns(exposure_data)
280
+
281
+ if isinstance(exposure_data, pd.DataFrame):
282
+ exposure_data = pivot_categorical(builder, self.risk, exposure_data, "parameter")
283
+
284
+ self.lookup_tables["exposure"] = self.build_lookup_table(
285
+ builder, exposure_data, exposure_value_columns
202
286
  )
203
287
 
204
- def get_exposure_parameters(self, builder: Builder) -> Pipeline:
288
+ def get_exposure_value_columns(
289
+ self, exposure_data: Union[int, float, pd.DataFrame]
290
+ ) -> Optional[List[str]]:
291
+ if isinstance(exposure_data, pd.DataFrame):
292
+ return list(exposure_data["parameter"].unique())
293
+ return None
294
+
295
+ def get_exposure_parameter_pipeline(self, builder: Builder) -> Pipeline:
205
296
  return builder.value.register_value_producer(
206
- self.exposure_parameters_pipeline_name,
207
- source=builder.lookup.build_table(
208
- self._exposure_data,
209
- key_columns=["sex"],
210
- parameter_columns=["age", "year"],
211
- ),
297
+ self.parameters_pipeline_name,
298
+ source=self.lookup_tables["exposure"],
299
+ requires_columns=get_lookup_columns([self.lookup_tables["exposure"]]),
212
300
  )
213
301
 
214
302
  ##################
215
303
  # Public methods #
216
304
  ##################
217
305
 
218
- def ppf(self, x: pd.Series) -> pd.Series:
219
- exposure = self.exposure(x.index)
306
+ def ppf(self, quantiles: pd.Series) -> pd.Series:
307
+ exposure = self.exposure_parameters(quantiles.index)
220
308
  sorted_exposures = exposure[self.categories]
221
309
  if not np.allclose(1, np.sum(sorted_exposures, axis=1)):
222
310
  raise MissingDataError("All exposure data returned as 0.")
223
311
  exposure_sum = sorted_exposures.cumsum(axis="columns")
224
312
  category_index = pd.concat(
225
- [exposure_sum[c] < x for c in exposure_sum.columns], axis=1
313
+ [exposure_sum[c] < quantiles for c in exposure_sum.columns], axis=1
226
314
  ).sum(axis=1)
227
315
  return pd.Series(
228
316
  np.array(self.categories)[category_index],
229
317
  name=self.risk + ".exposure",
230
- index=x.index,
318
+ index=quantiles.index,
231
319
  )
232
320
 
233
321
 
234
- class DichotomousDistribution(Component):
235
- #####################
236
- # Lifecycle methods #
237
- #####################
322
+ class DichotomousDistribution(RiskExposureDistribution):
238
323
 
239
- def __init__(self, risk: str, exposure_data: pd.DataFrame):
240
- super().__init__()
241
- self.risk = risk
242
- self._exposure_data = exposure_data.drop(columns="cat2")
324
+ #################
325
+ # Setup methods #
326
+ #################
243
327
 
244
- # noinspection PyAttributeOutsideInit
245
- def setup(self, builder: Builder) -> None:
246
- self._base_exposure = builder.lookup.build_table(
247
- self._exposure_data, key_columns=["sex"], parameter_columns=["age", "year"]
328
+ def build_all_lookup_tables(self, builder: "Builder") -> None:
329
+ exposure_data = self.get_exposure_data(builder)
330
+ exposure_value_columns = self.get_exposure_value_columns(exposure_data)
331
+
332
+ if isinstance(exposure_data, pd.DataFrame):
333
+ any_negatives = (exposure_data[exposure_value_columns] < 0).any().any()
334
+ any_over_one = (exposure_data[exposure_value_columns] > 1).any().any()
335
+ if any_negatives or any_over_one:
336
+ raise ValueError(f"All exposures must be in the range [0, 1] for {self.risk}")
337
+ elif exposure_data < 0 or exposure_data > 1:
338
+ raise ValueError(f"Exposure must be in the range [0, 1] for {self.risk}")
339
+
340
+ self.lookup_tables["exposure"] = self.build_lookup_table(
341
+ builder, exposure_data, exposure_value_columns
248
342
  )
249
- self.exposure_proportion = builder.value.register_value_producer(
250
- f"{self.risk}.exposure_parameters", source=self.exposure
343
+ self.lookup_tables["paf"] = self.build_lookup_table(builder, 0.0)
344
+
345
+ def get_exposure_data(self, builder: Builder) -> Union[int, float, pd.DataFrame]:
346
+ exposure_data = super().get_exposure_data(builder)
347
+
348
+ if isinstance(exposure_data, (int, float)):
349
+ return exposure_data
350
+
351
+ # rebin exposure categories
352
+ self.validate_rebin_source(builder, exposure_data)
353
+ rebin_exposed_categories = set(self.configuration["rebinned_exposed"])
354
+ if rebin_exposed_categories:
355
+ exposure_data = self._rebin_exposure_data(exposure_data, rebin_exposed_categories)
356
+
357
+ exposure_data = exposure_data[exposure_data["parameter"] == "cat1"]
358
+ return exposure_data.drop(columns="parameter")
359
+
360
+ @staticmethod
361
+ def _rebin_exposure_data(
362
+ exposure_data: pd.DataFrame, rebin_exposed_categories: set
363
+ ) -> pd.DataFrame:
364
+ exposure_data = exposure_data[
365
+ exposure_data["parameter"].isin(rebin_exposed_categories)
366
+ ]
367
+ exposure_data["parameter"] = "cat1"
368
+ exposure_data = (
369
+ exposure_data.groupby(list(exposure_data.columns.difference(["value"])))
370
+ .sum()
371
+ .reset_index()
251
372
  )
252
- base_paf = builder.lookup.build_table(0)
373
+ return exposure_data
374
+
375
+ def get_exposure_value_columns(
376
+ self, exposure_data: Union[int, float, pd.DataFrame]
377
+ ) -> Optional[List[str]]:
378
+ if isinstance(exposure_data, pd.DataFrame):
379
+ return self.get_value_columns(exposure_data)
380
+ return None
381
+
382
+ # noinspection PyAttributeOutsideInit
383
+ def setup(self, builder: Builder) -> None:
384
+ super().setup(builder)
253
385
  self.joint_paf = builder.value.register_value_producer(
254
386
  f"{self.risk}.exposure_parameters.paf",
255
- source=lambda index: [base_paf(index)],
387
+ source=lambda index: [self.lookup_tables["paf"](index)],
256
388
  preferred_combiner=list_combiner,
257
389
  preferred_post_processor=union_post_processor,
258
390
  )
259
391
 
392
+ def get_exposure_parameter_pipeline(self, builder: Builder) -> Pipeline:
393
+ return builder.value.register_value_producer(
394
+ f"{self.risk}.exposure_parameters",
395
+ source=self.exposure_parameter_source,
396
+ requires_columns=get_lookup_columns([self.lookup_tables["exposure"]]),
397
+ )
398
+
399
+ ##############
400
+ # Validators #
401
+ ##############
402
+
403
+ def validate_rebin_source(self, builder, data: pd.DataFrame) -> None:
404
+ if not isinstance(data, pd.DataFrame):
405
+ return
406
+
407
+ rebin_exposed_categories = set(builder.configuration[self.risk]["rebinned_exposed"])
408
+
409
+ if (
410
+ rebin_exposed_categories
411
+ and builder.configuration[self.risk]["category_thresholds"]
412
+ ):
413
+ raise ValueError(
414
+ f"Rebinning and category thresholds are mutually exclusive. "
415
+ f"You provided both for {self.risk.name}."
416
+ )
417
+
418
+ invalid_cats = rebin_exposed_categories.difference(set(data.parameter))
419
+ if invalid_cats:
420
+ raise ValueError(
421
+ f"The following provided categories for the rebinned exposed "
422
+ f"category of {self.risk.name} are not found in the exposure data: "
423
+ f"{invalid_cats}."
424
+ )
425
+
426
+ if rebin_exposed_categories == set(data.parameter):
427
+ raise ValueError(
428
+ f"The provided categories for the rebinned exposed category of "
429
+ f"{self.risk.name} comprise all categories for the exposure data. "
430
+ f"At least one category must be left out of the provided categories "
431
+ f"to be rebinned into the unexposed category."
432
+ )
433
+
260
434
  ##################################
261
435
  # Pipeline sources and modifiers #
262
436
  ##################################
263
437
 
264
- def exposure(self, index: pd.Index) -> pd.Series:
265
- base_exposure = self._base_exposure(index).values
438
+ def exposure_parameter_source(self, index: pd.Index) -> pd.Series:
439
+ base_exposure = self.lookup_tables["exposure"](index).values
266
440
  joint_paf = self.joint_paf(index).values
267
441
  return pd.Series(base_exposure * (1 - joint_paf), index=index, name="values")
268
442
 
@@ -270,42 +444,17 @@ class DichotomousDistribution(Component):
270
444
  # Public methods #
271
445
  ##################
272
446
 
273
- def ppf(self, x: pd.Series) -> pd.Series:
274
- exposed = x < self.exposure_proportion(x.index)
447
+ def ppf(self, quantiles: pd.Series) -> pd.Series:
448
+ exposed = quantiles < self.exposure_parameters(quantiles.index)
275
449
  return pd.Series(
276
450
  exposed.replace({True: "cat1", False: "cat2"}),
277
451
  name=self.risk + ".exposure",
278
- index=x.index,
279
- )
280
-
281
-
282
- def get_distribution(risk, distribution_type, exposure, exposure_standard_deviation, weights):
283
- if distribution_type == "dichotomous":
284
- distribution = DichotomousDistribution(risk, exposure)
285
- elif "polytomous" in distribution_type:
286
- distribution = PolytomousDistribution(risk, exposure)
287
- elif distribution_type == "normal":
288
- distribution = ContinuousDistribution(
289
- risk, mean=exposure, sd=exposure_standard_deviation, distribution=Normal
290
- )
291
- elif distribution_type == "lognormal":
292
- distribution = ContinuousDistribution(
293
- risk, mean=exposure, sd=exposure_standard_deviation, distribution=LogNormal
294
- )
295
- elif distribution_type == "ensemble":
296
- distribution = EnsembleSimulation(
297
- risk,
298
- weights,
299
- mean=exposure,
300
- sd=exposure_standard_deviation,
452
+ index=quantiles.index,
301
453
  )
302
- else:
303
- raise NotImplementedError(f"Unhandled distribution type {distribution_type}")
304
- return distribution
305
454
 
306
455
 
307
456
  def clip(q):
308
- """Adjust the percentile boundary casses.
457
+ """Adjust the percentile boundary cases.
309
458
 
310
459
  The risk distributions package uses the 99.9th and 0.001st percentiles
311
460
  of a log-normal distribution as the bounds of the distribution support.
@@ -319,3 +468,24 @@ def clip(q):
319
468
  q[q > Q_UPPER_BOUND] = Q_UPPER_BOUND
320
469
  q[q < Q_LOWER_BOUND] = Q_LOWER_BOUND
321
470
  return q
471
+
472
+
473
+ def get_risk_distribution_parameter(
474
+ value_columns_getter: Callable[[Union[pd.DataFrame]], List[str]],
475
+ data: Union[float, pd.DataFrame],
476
+ ) -> Union[float, pd.Series]:
477
+ if isinstance(data, pd.DataFrame):
478
+ value_columns = value_columns_getter(data)
479
+ if len(value_columns) > 1:
480
+ raise ValueError(
481
+ "Expected a single value column for risk data, but found "
482
+ f"{len(value_columns)}: {value_columns}."
483
+ )
484
+ # don't return parameter col in continuous and ensemble distribution
485
+ # means to match standard deviation index
486
+ if "parameter" in data.columns and set(data["parameter"]) == {"continuous"}:
487
+ data = data.drop("parameter", axis=1)
488
+ index = [col for col in data.columns if col not in value_columns]
489
+ data = data.set_index(index)[value_columns].squeeze(axis=1)
490
+
491
+ return data