vivarium_public_health-2.3.2-py3-none-any.whl → vivarium_public_health-3.0.0-py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (48)
  1. vivarium_public_health/_version.py +1 -1
  2. vivarium_public_health/disease/model.py +23 -21
  3. vivarium_public_health/disease/models.py +1 -0
  4. vivarium_public_health/disease/special_disease.py +40 -41
  5. vivarium_public_health/disease/state.py +42 -125
  6. vivarium_public_health/disease/transition.py +70 -27
  7. vivarium_public_health/mslt/delay.py +1 -0
  8. vivarium_public_health/mslt/disease.py +1 -0
  9. vivarium_public_health/mslt/intervention.py +1 -0
  10. vivarium_public_health/mslt/magic_wand_components.py +1 -0
  11. vivarium_public_health/mslt/observer.py +1 -0
  12. vivarium_public_health/mslt/population.py +1 -0
  13. vivarium_public_health/plugins/parser.py +61 -31
  14. vivarium_public_health/population/add_new_birth_cohorts.py +2 -3
  15. vivarium_public_health/population/base_population.py +2 -1
  16. vivarium_public_health/population/mortality.py +83 -80
  17. vivarium_public_health/{metrics → results}/__init__.py +2 -0
  18. vivarium_public_health/results/columns.py +22 -0
  19. vivarium_public_health/results/disability.py +187 -0
  20. vivarium_public_health/results/disease.py +222 -0
  21. vivarium_public_health/results/mortality.py +186 -0
  22. vivarium_public_health/results/observer.py +78 -0
  23. vivarium_public_health/results/risk.py +138 -0
  24. vivarium_public_health/results/simple_cause.py +18 -0
  25. vivarium_public_health/{metrics → results}/stratification.py +10 -8
  26. vivarium_public_health/risks/__init__.py +1 -2
  27. vivarium_public_health/risks/base_risk.py +134 -29
  28. vivarium_public_health/risks/data_transformations.py +65 -326
  29. vivarium_public_health/risks/distributions.py +315 -145
  30. vivarium_public_health/risks/effect.py +376 -75
  31. vivarium_public_health/risks/implementations/low_birth_weight_and_short_gestation.py +61 -89
  32. vivarium_public_health/treatment/magic_wand.py +1 -0
  33. vivarium_public_health/treatment/scale_up.py +1 -0
  34. vivarium_public_health/treatment/therapeutic_inertia.py +1 -0
  35. vivarium_public_health/utilities.py +17 -2
  36. {vivarium_public_health-2.3.2.dist-info → vivarium_public_health-3.0.0.dist-info}/METADATA +13 -3
  37. vivarium_public_health-3.0.0.dist-info/RECORD +49 -0
  38. {vivarium_public_health-2.3.2.dist-info → vivarium_public_health-3.0.0.dist-info}/WHEEL +1 -1
  39. vivarium_public_health/metrics/disability.py +0 -118
  40. vivarium_public_health/metrics/disease.py +0 -136
  41. vivarium_public_health/metrics/mortality.py +0 -144
  42. vivarium_public_health/metrics/risk.py +0 -110
  43. vivarium_public_health/testing/__init__.py +0 -0
  44. vivarium_public_health/testing/mock_artifact.py +0 -145
  45. vivarium_public_health/testing/utils.py +0 -71
  46. vivarium_public_health-2.3.2.dist-info/RECORD +0 -49
  47. {vivarium_public_health-2.3.2.dist-info → vivarium_public_health-3.0.0.dist-info}/LICENSE.txt +0 -0
  48. {vivarium_public_health-2.3.2.dist-info → vivarium_public_health-3.0.0.dist-info}/top_level.txt +0 -0
@@ -7,52 +7,86 @@ This module contains tools for modeling several different risk
7
7
  exposure distributions.
8
8
 
9
9
  """
10
- from typing import Dict, List
10
+
11
+ from abc import ABC, abstractmethod
12
+ from typing import Callable, Dict, List, Optional, Union
11
13
 
12
14
  import numpy as np
13
15
  import pandas as pd
14
- from risk_distributions import EnsembleDistribution, LogNormal, Normal
16
+ import risk_distributions as rd
17
+ from layered_config_tree import LayeredConfigTree
15
18
  from vivarium import Component
16
19
  from vivarium.framework.engine import Builder
17
20
  from vivarium.framework.population import SimulantData
18
21
  from vivarium.framework.values import Pipeline, list_combiner, union_post_processor
19
22
 
20
- from vivarium_public_health.risks.data_transformations import get_distribution_data
21
- from vivarium_public_health.utilities import EntityString
23
+ from vivarium_public_health.risks.data_transformations import pivot_categorical
24
+ from vivarium_public_health.utilities import EntityString, get_lookup_columns
22
25
 
23
26
 
24
27
  class MissingDataError(Exception):
25
28
  pass
26
29
 
27
30
 
28
- # FIXME: This is a hack. It's wrapping up an adaptor pattern in another
29
- # adaptor pattern, which is gross, but would require some more difficult
30
- # refactoring which is thoroughly out of scope right now. -J.C. 8/25/19
31
- class SimulationDistribution(Component):
32
- """Wrapper around a variety of distribution implementations."""
31
+ class RiskExposureDistribution(Component, ABC):
33
32
 
34
33
  #####################
35
34
  # Lifecycle methods #
36
35
  #####################
37
36
 
38
- def __init__(self, risk: str):
37
+ def __init__(
38
+ self,
39
+ risk: EntityString,
40
+ distribution_type: str,
41
+ exposure_data: Optional[Union[int, float, pd.DataFrame]] = None,
42
+ ) -> None:
39
43
  super().__init__()
40
- self.risk = EntityString(risk)
44
+ self.risk = risk
45
+ self.distribution_type = distribution_type
46
+ self._exposure_data = exposure_data
47
+
48
+ self.parameters_pipeline_name = f"{self.risk}.exposure_parameters"
41
49
 
50
+ #################
51
+ # Setup methods #
52
+ #################
53
+
54
+ def get_configuration(self, builder: "Builder") -> Optional[LayeredConfigTree]:
55
+ return builder.configuration[self.risk]
56
+
57
+ @abstractmethod
58
+ def build_all_lookup_tables(self, builder: "Builder") -> None:
59
+ raise NotImplementedError
60
+
61
+ def get_exposure_data(self, builder: Builder) -> Union[int, float, pd.DataFrame]:
62
+ if self._exposure_data is not None:
63
+ return self._exposure_data
64
+ return self.get_data(builder, self.configuration["data_sources"]["exposure"])
65
+
66
+ # noinspection PyAttributeOutsideInit
42
67
  def setup(self, builder: Builder) -> None:
43
- distribution_data = get_distribution_data(builder, self.risk)
44
- self.implementation = get_distribution(self.risk, **distribution_data)
45
- self.implementation.setup_component(builder)
68
+ self.exposure_parameters = self.get_exposure_parameter_pipeline(builder)
69
+ if self.exposure_parameters.name != self.parameters_pipeline_name:
70
+ raise ValueError(
71
+ "Expected exposure parameters pipeline to be named "
72
+ f"{self.parameters_pipeline_name}, "
73
+ f"but found {self.exposure_parameters.name}."
74
+ )
75
+
76
+ @abstractmethod
77
+ def get_exposure_parameter_pipeline(self, builder: Builder) -> Pipeline:
78
+ raise NotImplementedError
46
79
 
47
80
  ##################
48
81
  # Public methods #
49
82
  ##################
50
83
 
51
- def ppf(self, q):
52
- return self.implementation.ppf(q)
84
+ @abstractmethod
85
+ def ppf(self, quantiles: pd.Series) -> pd.Series:
86
+ raise NotImplementedError
53
87
 
54
88
 
55
- class EnsembleSimulation(Component):
89
+ class EnsembleDistribution(RiskExposureDistribution):
56
90
  ##############
57
91
  # Properties #
58
92
  ##############
@@ -73,38 +107,71 @@ class EnsembleSimulation(Component):
73
107
  # Lifecycle methods #
74
108
  #####################
75
109
 
76
- def __init__(self, risk, weights, mean, sd):
77
- super().__init__()
78
- self.risk = EntityString(risk)
79
- self._weights, self._parameters = self.get_parameters(weights, mean, sd)
110
+ def __init__(self, risk: EntityString, distribution_type: str = "ensemble") -> None:
111
+ super().__init__(risk, distribution_type)
80
112
  self._propensity = f"ensemble_propensity_{self.risk}"
81
113
 
82
- def setup(self, builder: Builder) -> None:
83
- self.weights = builder.lookup.build_table(
84
- self._weights, key_columns=["sex"], parameter_columns=["age", "year"]
114
+ #################
115
+ # Setup methods #
116
+ #################
117
+
118
+ def build_all_lookup_tables(self, builder: Builder) -> None:
119
+ exposure_data = self.get_exposure_data(builder)
120
+ standard_deviation = self.get_data(
121
+ builder,
122
+ self.configuration["data_sources"]["exposure_standard_deviation"],
123
+ )
124
+ weights_source = self.configuration["data_sources"]["ensemble_distribution_weights"]
125
+ raw_weights = self.get_data(builder, weights_source)
126
+
127
+ glnorm_mask = raw_weights["parameter"] == "glnorm"
128
+ if np.any(raw_weights.loc[glnorm_mask, self.get_value_columns(weights_source)]):
129
+ raise NotImplementedError("glnorm distribution is not supported")
130
+ raw_weights = raw_weights[~glnorm_mask]
131
+
132
+ distributions = list(raw_weights["parameter"].unique())
133
+
134
+ raw_weights = pivot_categorical(
135
+ builder, self.risk, raw_weights, pivot_column="parameter", reset_index=False
85
136
  )
137
+
138
+ weights, parameters = rd.EnsembleDistribution.get_parameters(
139
+ raw_weights,
140
+ mean=get_risk_distribution_parameter(self.get_value_columns, exposure_data),
141
+ sd=get_risk_distribution_parameter(self.get_value_columns, standard_deviation),
142
+ )
143
+
144
+ distribution_weights_table = self.build_lookup_table(
145
+ builder, weights.reset_index(), distributions
146
+ )
147
+ self.lookup_tables["ensemble_distribution_weights"] = distribution_weights_table
148
+ key_columns = distribution_weights_table.key_columns
149
+ parameter_columns = distribution_weights_table.parameter_columns
150
+
86
151
  self.parameters = {
87
- k: builder.lookup.build_table(
88
- v, key_columns=["sex"], parameter_columns=["age", "year"]
152
+ parameter: builder.lookup.build_table(
153
+ data.reset_index(),
154
+ key_columns=key_columns,
155
+ parameter_columns=parameter_columns,
89
156
  )
90
- for k, v in self._parameters.items()
157
+ for parameter, data in parameters.items()
91
158
  }
92
159
 
160
+ def setup(self, builder: Builder) -> None:
161
+ super().setup(builder)
93
162
  self.randomness = builder.randomness.get_stream(self._propensity)
94
163
 
95
- ##########################
96
- # Initialization methods #
97
- ##########################
98
-
99
- def get_parameters(self, weights, mean, sd):
100
- index_cols = ["sex", "age_start", "age_end", "year_start", "year_end"]
101
- weights = weights.set_index(index_cols)
102
- mean = mean.set_index(index_cols)["value"]
103
- sd = sd.set_index(index_cols)["value"]
104
- weights, parameters = EnsembleDistribution.get_parameters(weights, mean=mean, sd=sd)
105
- return weights.reset_index(), {
106
- name: p.reset_index() for name, p in parameters.items()
107
- }
164
+ def get_exposure_parameter_pipeline(self, builder: Builder) -> Pipeline:
165
+ # This pipeline is not needed for ensemble distributions, so just
166
+ # register a dummy pipeline
167
+ def raise_not_implemented():
168
+ raise NotImplementedError(
169
+ "EnsembleDistribution does not use exposure parameters."
170
+ )
171
+
172
+ return builder.value.register_value_producer(
173
+ self.parameters_pipeline_name, lambda *_: raise_not_implemented()
174
+ )
108
175
 
109
176
  ########################
110
177
  # Event-driven methods #
@@ -120,149 +187,256 @@ class EnsembleSimulation(Component):
120
187
  # Public methods #
121
188
  ##################
122
189
 
123
- def ppf(self, q):
124
- if not q.empty:
125
- q = clip(q)
126
- weights = self.weights(q.index)
190
+ def ppf(self, quantiles: pd.Series) -> pd.Series:
191
+ if not quantiles.empty:
192
+ quantiles = clip(quantiles)
193
+ weights = self.lookup_tables["ensemble_distribution_weights"](quantiles.index)
127
194
  parameters = {
128
- name: parameter(q.index) for name, parameter in self.parameters.items()
195
+ name: param(quantiles.index) for name, param in self.parameters.items()
129
196
  }
130
- ensemble_propensity = self.population_view.get(q.index).iloc[:, 0]
131
- x = EnsembleDistribution(weights, parameters).ppf(q, ensemble_propensity)
197
+ ensemble_propensity = self.population_view.get(quantiles.index).iloc[:, 0]
198
+ x = rd.EnsembleDistribution(weights, parameters).ppf(
199
+ quantiles, ensemble_propensity
200
+ )
132
201
  x[x.isnull()] = 0
133
202
  else:
134
203
  x = pd.Series([])
135
204
  return x
136
205
 
137
206
 
138
- class ContinuousDistribution(Component):
207
+ class ContinuousDistribution(RiskExposureDistribution):
139
208
  #####################
140
209
  # Lifecycle methods #
141
210
  #####################
142
211
 
143
- def __init__(self, risk, mean, sd, distribution=None):
144
- super().__init__()
145
- self.risk = EntityString(risk)
146
- self._distribution = distribution
147
- self._parameters = self.get_parameters(mean, sd)
212
+ def __init__(self, risk: EntityString, distribution_type: str) -> None:
213
+ super().__init__(risk, distribution_type)
214
+ self.standard_deviation = None
215
+ try:
216
+ self._distribution = {
217
+ "normal": rd.Normal,
218
+ "lognormal": rd.LogNormal,
219
+ }[distribution_type]
220
+ except KeyError:
221
+ raise NotImplementedError(
222
+ f"Distribution type {distribution_type} is not supported for "
223
+ f"risk {risk.name}."
224
+ )
148
225
 
149
- def setup(self, builder: Builder) -> None:
150
- self.parameters = builder.lookup.build_table(
151
- self._parameters, key_columns=["sex"], parameter_columns=["age", "year"]
226
+ #################
227
+ # Setup methods #
228
+ #################
229
+
230
+ def build_all_lookup_tables(self, builder: "Builder") -> None:
231
+ exposure_data = self.get_exposure_data(builder)
232
+ standard_deviation = self.get_data(
233
+ builder, self.configuration["data_sources"]["exposure_standard_deviation"]
234
+ )
235
+ parameters = self._distribution.get_parameters(
236
+ mean=get_risk_distribution_parameter(self.get_value_columns, exposure_data),
237
+ sd=get_risk_distribution_parameter(self.get_value_columns, standard_deviation),
152
238
  )
153
239
 
154
- ##########################
155
- # Initialization methods #
156
- ##########################
240
+ self.lookup_tables["parameters"] = self.build_lookup_table(
241
+ builder, parameters.reset_index(), list(parameters.columns)
242
+ )
157
243
 
158
- def get_parameters(self, mean, sd):
159
- index = ["sex", "age_start", "age_end", "year_start", "year_end"]
160
- mean = mean.set_index(index)["value"]
161
- sd = sd.set_index(index)["value"]
162
- return self._distribution.get_parameters(mean=mean, sd=sd).reset_index()
244
+ def get_exposure_parameter_pipeline(self, builder: Builder) -> Pipeline:
245
+ return builder.value.register_value_producer(
246
+ self.parameters_pipeline_name,
247
+ source=self.lookup_tables["parameters"],
248
+ requires_columns=get_lookup_columns([self.lookup_tables["parameters"]]),
249
+ )
163
250
 
164
251
  ##################
165
252
  # Public methods #
166
253
  ##################
167
254
 
168
- def ppf(self, q):
169
- if not q.empty:
170
- q = clip(q)
171
- x = self._distribution(parameters=self.parameters(q.index)).ppf(q)
255
+ def ppf(self, quantiles: pd.Series) -> pd.Series:
256
+ if not quantiles.empty:
257
+ quantiles = clip(quantiles)
258
+ parameters = self.exposure_parameters(quantiles.index)
259
+ x = self._distribution(parameters=parameters).ppf(quantiles)
172
260
  x[x.isnull()] = 0
173
261
  else:
174
262
  x = pd.Series([])
175
263
  return x
176
264
 
177
265
 
178
- class PolytomousDistribution(Component):
179
- #####################
180
- # Lifecycle methods #
181
- #####################
182
-
183
- def __init__(self, risk: str, exposure_data: pd.DataFrame):
184
- super().__init__()
185
- self.risk = EntityString(risk)
186
- self._exposure_data = exposure_data
187
- self.exposure_parameters_pipeline_name = f"{self.risk}.exposure_parameters"
188
-
189
- # noinspection PyAttributeOutsideInit
190
- def setup(self, builder: Builder) -> None:
191
- self.categories = self.get_categories()
192
- self.exposure = self.get_exposure_parameters(builder)
266
+ class PolytomousDistribution(RiskExposureDistribution):
267
+ @property
268
+ def categories(self) -> List[str]:
269
+ # These need to be sorted so the cumulative sum is in the ocrrect order of categories
270
+ # and results are therefore reproducible and correct
271
+ return sorted(self.lookup_tables["exposure"].value_columns)
193
272
 
194
273
  #################
195
274
  # Setup methods #
196
275
  #################
197
276
 
198
- def get_categories(self) -> List[str]:
199
- return sorted(
200
- [column for column in self._exposure_data if "cat" in column],
201
- key=lambda column: int(column[3:]),
277
+ def build_all_lookup_tables(self, builder: "Builder") -> None:
278
+ exposure_data = self.get_exposure_data(builder)
279
+ exposure_value_columns = self.get_exposure_value_columns(exposure_data)
280
+
281
+ if isinstance(exposure_data, pd.DataFrame):
282
+ exposure_data = pivot_categorical(builder, self.risk, exposure_data, "parameter")
283
+
284
+ self.lookup_tables["exposure"] = self.build_lookup_table(
285
+ builder, exposure_data, exposure_value_columns
202
286
  )
203
287
 
204
- def get_exposure_parameters(self, builder: Builder) -> Pipeline:
288
+ def get_exposure_value_columns(
289
+ self, exposure_data: Union[int, float, pd.DataFrame]
290
+ ) -> Optional[List[str]]:
291
+ if isinstance(exposure_data, pd.DataFrame):
292
+ return list(exposure_data["parameter"].unique())
293
+ return None
294
+
295
+ def get_exposure_parameter_pipeline(self, builder: Builder) -> Pipeline:
205
296
  return builder.value.register_value_producer(
206
- self.exposure_parameters_pipeline_name,
207
- source=builder.lookup.build_table(
208
- self._exposure_data,
209
- key_columns=["sex"],
210
- parameter_columns=["age", "year"],
211
- ),
297
+ self.parameters_pipeline_name,
298
+ source=self.lookup_tables["exposure"],
299
+ requires_columns=get_lookup_columns([self.lookup_tables["exposure"]]),
212
300
  )
213
301
 
214
302
  ##################
215
303
  # Public methods #
216
304
  ##################
217
305
 
218
- def ppf(self, x: pd.Series) -> pd.Series:
219
- exposure = self.exposure(x.index)
306
+ def ppf(self, quantiles: pd.Series) -> pd.Series:
307
+ exposure = self.exposure_parameters(quantiles.index)
220
308
  sorted_exposures = exposure[self.categories]
221
309
  if not np.allclose(1, np.sum(sorted_exposures, axis=1)):
222
310
  raise MissingDataError("All exposure data returned as 0.")
223
311
  exposure_sum = sorted_exposures.cumsum(axis="columns")
224
312
  category_index = pd.concat(
225
- [exposure_sum[c] < x for c in exposure_sum.columns], axis=1
313
+ [exposure_sum[c] < quantiles for c in exposure_sum.columns], axis=1
226
314
  ).sum(axis=1)
227
315
  return pd.Series(
228
316
  np.array(self.categories)[category_index],
229
317
  name=self.risk + ".exposure",
230
- index=x.index,
318
+ index=quantiles.index,
231
319
  )
232
320
 
233
321
 
234
- class DichotomousDistribution(Component):
235
- #####################
236
- # Lifecycle methods #
237
- #####################
322
+ class DichotomousDistribution(RiskExposureDistribution):
238
323
 
239
- def __init__(self, risk: str, exposure_data: pd.DataFrame):
240
- super().__init__()
241
- self.risk = risk
242
- self._exposure_data = exposure_data.drop(columns="cat2")
324
+ #################
325
+ # Setup methods #
326
+ #################
243
327
 
244
- # noinspection PyAttributeOutsideInit
245
- def setup(self, builder: Builder) -> None:
246
- self._base_exposure = builder.lookup.build_table(
247
- self._exposure_data, key_columns=["sex"], parameter_columns=["age", "year"]
328
+ def build_all_lookup_tables(self, builder: "Builder") -> None:
329
+ exposure_data = self.get_exposure_data(builder)
330
+ exposure_value_columns = self.get_exposure_value_columns(exposure_data)
331
+
332
+ if isinstance(exposure_data, pd.DataFrame):
333
+ any_negatives = (exposure_data[exposure_value_columns] < 0).any().any()
334
+ any_over_one = (exposure_data[exposure_value_columns] > 1).any().any()
335
+ if any_negatives or any_over_one:
336
+ raise ValueError(f"All exposures must be in the range [0, 1] for {self.risk}")
337
+ elif exposure_data < 0 or exposure_data > 1:
338
+ raise ValueError(f"Exposure must be in the range [0, 1] for {self.risk}")
339
+
340
+ self.lookup_tables["exposure"] = self.build_lookup_table(
341
+ builder, exposure_data, exposure_value_columns
248
342
  )
249
- self.exposure_proportion = builder.value.register_value_producer(
250
- f"{self.risk}.exposure_parameters", source=self.exposure
343
+ self.lookup_tables["paf"] = self.build_lookup_table(builder, 0.0)
344
+
345
+ def get_exposure_data(self, builder: Builder) -> Union[int, float, pd.DataFrame]:
346
+ exposure_data = super().get_exposure_data(builder)
347
+
348
+ if isinstance(exposure_data, (int, float)):
349
+ return exposure_data
350
+
351
+ # rebin exposure categories
352
+ self.validate_rebin_source(builder, exposure_data)
353
+ rebin_exposed_categories = set(self.configuration["rebinned_exposed"])
354
+ if rebin_exposed_categories:
355
+ exposure_data = self._rebin_exposure_data(exposure_data, rebin_exposed_categories)
356
+
357
+ exposure_data = exposure_data[exposure_data["parameter"] == "cat1"]
358
+ return exposure_data.drop(columns="parameter")
359
+
360
+ @staticmethod
361
+ def _rebin_exposure_data(
362
+ exposure_data: pd.DataFrame, rebin_exposed_categories: set
363
+ ) -> pd.DataFrame:
364
+ exposure_data = exposure_data[
365
+ exposure_data["parameter"].isin(rebin_exposed_categories)
366
+ ]
367
+ exposure_data["parameter"] = "cat1"
368
+ exposure_data = (
369
+ exposure_data.groupby(list(exposure_data.columns.difference(["value"])))
370
+ .sum()
371
+ .reset_index()
251
372
  )
252
- base_paf = builder.lookup.build_table(0)
373
+ return exposure_data
374
+
375
+ def get_exposure_value_columns(
376
+ self, exposure_data: Union[int, float, pd.DataFrame]
377
+ ) -> Optional[List[str]]:
378
+ if isinstance(exposure_data, pd.DataFrame):
379
+ return self.get_value_columns(exposure_data)
380
+ return None
381
+
382
+ # noinspection PyAttributeOutsideInit
383
+ def setup(self, builder: Builder) -> None:
384
+ super().setup(builder)
253
385
  self.joint_paf = builder.value.register_value_producer(
254
386
  f"{self.risk}.exposure_parameters.paf",
255
- source=lambda index: [base_paf(index)],
387
+ source=lambda index: [self.lookup_tables["paf"](index)],
256
388
  preferred_combiner=list_combiner,
257
389
  preferred_post_processor=union_post_processor,
258
390
  )
259
391
 
392
+ def get_exposure_parameter_pipeline(self, builder: Builder) -> Pipeline:
393
+ return builder.value.register_value_producer(
394
+ f"{self.risk}.exposure_parameters",
395
+ source=self.exposure_parameter_source,
396
+ requires_columns=get_lookup_columns([self.lookup_tables["exposure"]]),
397
+ )
398
+
399
+ ##############
400
+ # Validators #
401
+ ##############
402
+
403
+ def validate_rebin_source(self, builder, data: pd.DataFrame) -> None:
404
+ if not isinstance(data, pd.DataFrame):
405
+ return
406
+
407
+ rebin_exposed_categories = set(builder.configuration[self.risk]["rebinned_exposed"])
408
+
409
+ if (
410
+ rebin_exposed_categories
411
+ and builder.configuration[self.risk]["category_thresholds"]
412
+ ):
413
+ raise ValueError(
414
+ f"Rebinning and category thresholds are mutually exclusive. "
415
+ f"You provided both for {self.risk.name}."
416
+ )
417
+
418
+ invalid_cats = rebin_exposed_categories.difference(set(data.parameter))
419
+ if invalid_cats:
420
+ raise ValueError(
421
+ f"The following provided categories for the rebinned exposed "
422
+ f"category of {self.risk.name} are not found in the exposure data: "
423
+ f"{invalid_cats}."
424
+ )
425
+
426
+ if rebin_exposed_categories == set(data.parameter):
427
+ raise ValueError(
428
+ f"The provided categories for the rebinned exposed category of "
429
+ f"{self.risk.name} comprise all categories for the exposure data. "
430
+ f"At least one category must be left out of the provided categories "
431
+ f"to be rebinned into the unexposed category."
432
+ )
433
+
260
434
  ##################################
261
435
  # Pipeline sources and modifiers #
262
436
  ##################################
263
437
 
264
- def exposure(self, index: pd.Index) -> pd.Series:
265
- base_exposure = self._base_exposure(index).values
438
+ def exposure_parameter_source(self, index: pd.Index) -> pd.Series:
439
+ base_exposure = self.lookup_tables["exposure"](index).values
266
440
  joint_paf = self.joint_paf(index).values
267
441
  return pd.Series(base_exposure * (1 - joint_paf), index=index, name="values")
268
442
 
@@ -270,42 +444,17 @@ class DichotomousDistribution(Component):
270
444
  # Public methods #
271
445
  ##################
272
446
 
273
- def ppf(self, x: pd.Series) -> pd.Series:
274
- exposed = x < self.exposure_proportion(x.index)
447
+ def ppf(self, quantiles: pd.Series) -> pd.Series:
448
+ exposed = quantiles < self.exposure_parameters(quantiles.index)
275
449
  return pd.Series(
276
450
  exposed.replace({True: "cat1", False: "cat2"}),
277
451
  name=self.risk + ".exposure",
278
- index=x.index,
279
- )
280
-
281
-
282
- def get_distribution(risk, distribution_type, exposure, exposure_standard_deviation, weights):
283
- if distribution_type == "dichotomous":
284
- distribution = DichotomousDistribution(risk, exposure)
285
- elif "polytomous" in distribution_type:
286
- distribution = PolytomousDistribution(risk, exposure)
287
- elif distribution_type == "normal":
288
- distribution = ContinuousDistribution(
289
- risk, mean=exposure, sd=exposure_standard_deviation, distribution=Normal
290
- )
291
- elif distribution_type == "lognormal":
292
- distribution = ContinuousDistribution(
293
- risk, mean=exposure, sd=exposure_standard_deviation, distribution=LogNormal
294
- )
295
- elif distribution_type == "ensemble":
296
- distribution = EnsembleSimulation(
297
- risk,
298
- weights,
299
- mean=exposure,
300
- sd=exposure_standard_deviation,
452
+ index=quantiles.index,
301
453
  )
302
- else:
303
- raise NotImplementedError(f"Unhandled distribution type {distribution_type}")
304
- return distribution
305
454
 
306
455
 
307
456
  def clip(q):
308
- """Adjust the percentile boundary casses.
457
+ """Adjust the percentile boundary cases.
309
458
 
310
459
  The risk distributions package uses the 99.9th and 0.001st percentiles
311
460
  of a log-normal distribution as the bounds of the distribution support.
@@ -319,3 +468,24 @@ def clip(q):
319
468
  q[q > Q_UPPER_BOUND] = Q_UPPER_BOUND
320
469
  q[q < Q_LOWER_BOUND] = Q_LOWER_BOUND
321
470
  return q
471
+
472
+
473
+ def get_risk_distribution_parameter(
474
+ value_columns_getter: Callable[[Union[pd.DataFrame]], List[str]],
475
+ data: Union[float, pd.DataFrame],
476
+ ) -> Union[float, pd.Series]:
477
+ if isinstance(data, pd.DataFrame):
478
+ value_columns = value_columns_getter(data)
479
+ if len(value_columns) > 1:
480
+ raise ValueError(
481
+ "Expected a single value column for risk data, but found "
482
+ f"{len(value_columns)}: {value_columns}."
483
+ )
484
+ # don't return parameter col in continuous and ensemble distribution
485
+ # means to match standard deviation index
486
+ if "parameter" in data.columns and set(data["parameter"]) == {"continuous"}:
487
+ data = data.drop("parameter", axis=1)
488
+ index = [col for col in data.columns if col not in value_columns]
489
+ data = data.set_index(index)[value_columns].squeeze(axis=1)
490
+
491
+ return data