vivarium-public-health 2.3.2__py3-none-any.whl → 3.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. vivarium_public_health/_version.py +1 -1
  2. vivarium_public_health/disease/model.py +23 -21
  3. vivarium_public_health/disease/models.py +1 -0
  4. vivarium_public_health/disease/special_disease.py +40 -41
  5. vivarium_public_health/disease/state.py +42 -125
  6. vivarium_public_health/disease/transition.py +70 -27
  7. vivarium_public_health/mslt/delay.py +1 -0
  8. vivarium_public_health/mslt/disease.py +1 -0
  9. vivarium_public_health/mslt/intervention.py +1 -0
  10. vivarium_public_health/mslt/magic_wand_components.py +1 -0
  11. vivarium_public_health/mslt/observer.py +1 -0
  12. vivarium_public_health/mslt/population.py +1 -0
  13. vivarium_public_health/plugins/parser.py +61 -31
  14. vivarium_public_health/population/add_new_birth_cohorts.py +2 -3
  15. vivarium_public_health/population/base_population.py +2 -1
  16. vivarium_public_health/population/mortality.py +83 -80
  17. vivarium_public_health/{metrics → results}/__init__.py +2 -0
  18. vivarium_public_health/results/columns.py +22 -0
  19. vivarium_public_health/results/disability.py +187 -0
  20. vivarium_public_health/results/disease.py +222 -0
  21. vivarium_public_health/results/mortality.py +186 -0
  22. vivarium_public_health/results/observer.py +78 -0
  23. vivarium_public_health/results/risk.py +138 -0
  24. vivarium_public_health/results/simple_cause.py +18 -0
  25. vivarium_public_health/{metrics → results}/stratification.py +10 -8
  26. vivarium_public_health/risks/__init__.py +1 -2
  27. vivarium_public_health/risks/base_risk.py +134 -29
  28. vivarium_public_health/risks/data_transformations.py +65 -326
  29. vivarium_public_health/risks/distributions.py +315 -145
  30. vivarium_public_health/risks/effect.py +376 -75
  31. vivarium_public_health/risks/implementations/low_birth_weight_and_short_gestation.py +61 -89
  32. vivarium_public_health/treatment/magic_wand.py +1 -0
  33. vivarium_public_health/treatment/scale_up.py +1 -0
  34. vivarium_public_health/treatment/therapeutic_inertia.py +1 -0
  35. vivarium_public_health/utilities.py +17 -2
  36. {vivarium_public_health-2.3.2.dist-info → vivarium_public_health-3.0.0.dist-info}/METADATA +13 -3
  37. vivarium_public_health-3.0.0.dist-info/RECORD +49 -0
  38. {vivarium_public_health-2.3.2.dist-info → vivarium_public_health-3.0.0.dist-info}/WHEEL +1 -1
  39. vivarium_public_health/metrics/disability.py +0 -118
  40. vivarium_public_health/metrics/disease.py +0 -136
  41. vivarium_public_health/metrics/mortality.py +0 -144
  42. vivarium_public_health/metrics/risk.py +0 -110
  43. vivarium_public_health/testing/__init__.py +0 -0
  44. vivarium_public_health/testing/mock_artifact.py +0 -145
  45. vivarium_public_health/testing/utils.py +0 -71
  46. vivarium_public_health-2.3.2.dist-info/RECORD +0 -49
  47. {vivarium_public_health-2.3.2.dist-info → vivarium_public_health-3.0.0.dist-info}/LICENSE.txt +0 -0
  48. {vivarium_public_health-2.3.2.dist-info → vivarium_public_health-3.0.0.dist-info}/top_level.txt +0 -0
@@ -8,20 +8,25 @@ exposure models and disease models.
8
8
 
9
9
  """
10
10
 
11
- from typing import Any, Callable, Dict
11
+ from importlib import import_module
12
+ from typing import Any, Callable, Dict, List, Tuple, Union
12
13
 
13
14
  import numpy as np
14
15
  import pandas as pd
16
+ import scipy
17
+ from layered_config_tree import ConfigurationError
15
18
  from vivarium import Component
16
19
  from vivarium.framework.engine import Builder
17
- from vivarium.framework.lookup import LookupTable
20
+ from vivarium.framework.event import Event
21
+ from vivarium.framework.population import SimulantData
18
22
 
23
+ from vivarium_public_health.risks import Risk
19
24
  from vivarium_public_health.risks.data_transformations import (
20
- get_distribution_type,
21
- get_population_attributable_fraction_data,
22
- get_relative_risk_data,
25
+ load_exposure_data,
26
+ pivot_categorical,
23
27
  )
24
- from vivarium_public_health.utilities import EntityString, TargetString
28
+ from vivarium_public_health.risks.distributions import MissingDataError
29
+ from vivarium_public_health.utilities import EntityString, TargetString, get_lookup_columns
25
30
 
26
31
 
27
32
  class RiskEffect(Component):
@@ -34,29 +39,24 @@ class RiskEffect(Component):
34
39
  .. code-block:: yaml
35
40
 
36
41
  configuration:
37
- effect_of_risk_on_affected_risk:
42
+ risk_effect.risk_name_on_affected_target:
38
43
  exposure_parameters: 2
39
44
  incidence_rate: 10
40
45
 
41
46
  """
42
47
 
43
- CONFIGURATION_DEFAULTS = {
44
- "effect_of_risk_on_target": {
45
- "measure": {
46
- "relative_risk": None,
47
- "mean": None,
48
- "se": None,
49
- "log_mean": None,
50
- "log_se": None,
51
- "tau_squared": None,
52
- }
53
- }
54
- }
55
-
56
- ##############
48
+ ###############
57
49
  # Properties #
58
50
  ##############
59
51
 
52
+ @property
53
+ def name(self) -> str:
54
+ return self.get_name(self.risk, self.target)
55
+
56
+ @staticmethod
57
+ def get_name(risk: EntityString, target: TargetString) -> str:
58
+ return f"risk_effect.{risk.name}_on_{target}"
59
+
60
60
  @property
61
61
  def configuration_defaults(self) -> Dict[str, Any]:
62
62
  """
@@ -64,13 +64,25 @@ class RiskEffect(Component):
64
64
  this component.
65
65
  """
66
66
  return {
67
- f"effect_of_{self.risk.name}_on_{self.target.name}": {
68
- self.target.measure: self.CONFIGURATION_DEFAULTS["effect_of_risk_on_target"][
69
- "measure"
70
- ]
67
+ self.name: {
68
+ "data_sources": {
69
+ "relative_risk": f"{self.risk}.relative_risk",
70
+ "population_attributable_fraction": f"{self.risk}.population_attributable_fraction",
71
+ },
72
+ "data_source_parameters": {
73
+ "relative_risk": {},
74
+ },
71
75
  }
72
76
  }
73
77
 
78
+ @property
79
+ def is_exposure_categorical(self) -> bool:
80
+ return self._exposure_distribution_type in [
81
+ "dichotomous",
82
+ "ordered_polytomous",
83
+ "unordered_polytomous",
84
+ ]
85
+
74
86
  #####################
75
87
  # Lifecycle methods #
76
88
  #####################
@@ -92,18 +104,15 @@ class RiskEffect(Component):
92
104
  self.risk = EntityString(risk)
93
105
  self.target = TargetString(target)
94
106
 
107
+ self._exposure_distribution_type = None
108
+
95
109
  self.exposure_pipeline_name = f"{self.risk.name}.exposure"
96
110
  self.target_pipeline_name = f"{self.target.name}.{self.target.measure}"
97
111
  self.target_paf_pipeline_name = f"{self.target_pipeline_name}.paf"
98
112
 
99
113
  # noinspection PyAttributeOutsideInit
100
114
  def setup(self, builder: Builder) -> None:
101
- self.exposure_distribution_type = self.get_distribution_type(builder)
102
115
  self.exposure = self.get_risk_exposure(builder)
103
- self.relative_risk = self.get_relative_risk_source(builder)
104
- self.population_attributable_fraction = (
105
- self.get_population_attributable_fraction_source(builder)
106
- )
107
116
 
108
117
  self.target_modifier = self.get_target_modifier(builder)
109
118
 
@@ -114,62 +123,148 @@ class RiskEffect(Component):
114
123
  # Setup methods #
115
124
  #################
116
125
 
117
- def get_distribution_type(self, builder: Builder) -> str:
118
- return get_distribution_type(builder, self.risk)
119
-
120
- def get_risk_exposure(self, builder: Builder) -> Callable[[pd.Index], pd.Series]:
121
- return builder.value.get_value(self.exposure_pipeline_name)
126
+ def build_all_lookup_tables(self, builder: Builder) -> None:
127
+ self._exposure_distribution_type = self.get_distribution_type(builder)
122
128
 
123
- def get_relative_risk_source(self, builder: Builder) -> LookupTable:
124
- """
125
- Get the relative risk source for this risk effect model.
126
-
127
- Parameters
128
- ----------
129
- builder
130
- Interface to access simulation managers.
131
-
132
- Returns
133
- -------
134
- LookupTable
135
- A lookup table containing the relative risk data for this risk
136
- effect model.
137
- """
138
- relative_risk_data = get_relative_risk_data(builder, self.risk, self.target)
139
- return builder.lookup.build_table(
140
- relative_risk_data, key_columns=["sex"], parameter_columns=["age", "year"]
129
+ rr_data = self.get_relative_risk_data(builder)
130
+ rr_value_cols = None
131
+ if self.is_exposure_categorical:
132
+ rr_data, rr_value_cols = self.process_categorical_data(builder, rr_data)
133
+ self.lookup_tables["relative_risk"] = self.build_lookup_table(
134
+ builder, rr_data, rr_value_cols
141
135
  )
142
136
 
143
- def get_population_attributable_fraction_source(self, builder: Builder) -> LookupTable:
144
- """
145
- Get the population attributable fraction source for this risk effect model.
137
+ paf_data = self.get_filtered_data(
138
+ builder, self.configuration.data_sources.population_attributable_fraction
139
+ )
140
+ self.lookup_tables["population_attributable_fraction"] = self.build_lookup_table(
141
+ builder, paf_data
142
+ )
146
143
 
147
- Parameters
148
- ----------
149
- builder
150
- Interface to access simulation managers.
151
-
152
- Returns
153
- -------
154
- LookupTable
155
- A lookup table containing the population attributable fraction data
156
- for this risk effect model.
144
+ def get_distribution_type(self, builder: Builder) -> str:
145
+ """Get the distribution type for the risk from the configuration."""
146
+ risk_exposure_component = self._get_risk_exposure_class(builder)
147
+ if risk_exposure_component.distribution_type:
148
+ return risk_exposure_component.distribution_type
149
+ return risk_exposure_component.get_distribution_type(builder)
150
+
151
+ def get_relative_risk_data(
152
+ self,
153
+ builder: Builder,
154
+ configuration=None,
155
+ ) -> Union[str, float, pd.DataFrame]:
156
+ if configuration is None:
157
+ configuration = self.configuration
158
+
159
+ rr_source = configuration.data_sources.relative_risk
160
+ rr_dist_parameters = configuration.data_source_parameters.relative_risk.to_dict()
161
+
162
+ try:
163
+ distribution = getattr(import_module("scipy.stats"), rr_source)
164
+ rng = np.random.default_rng(builder.randomness.get_seed(self.name))
165
+ rr_data = distribution(**rr_dist_parameters).ppf(rng.random())
166
+ except AttributeError:
167
+ rr_data = self.get_filtered_data(builder, rr_source)
168
+ except TypeError:
169
+ raise ConfigurationError(
170
+ f"Parameters {rr_dist_parameters} are not valid for distribution {rr_source}."
171
+ )
172
+ return rr_data
173
+
174
+ def get_filtered_data(
175
+ self, builder: "Builder", data_source: Union[str, float, pd.DataFrame]
176
+ ) -> Union[float, pd.DataFrame]:
177
+ data = super().get_data(builder, data_source)
178
+
179
+ if isinstance(data, pd.DataFrame):
180
+ # filter data to only include the target entity and measure
181
+ correct_target_mask = True
182
+ columns_to_drop = []
183
+ if "affected_entity" in data.columns:
184
+ correct_target_mask &= data["affected_entity"] == self.target.name
185
+ columns_to_drop.append("affected_entity")
186
+ if "affected_measure" in data.columns:
187
+ correct_target_mask &= data["affected_measure"] == self.target.measure
188
+ columns_to_drop.append("affected_measure")
189
+ data = data[correct_target_mask].drop(columns=columns_to_drop)
190
+ return data
191
+
192
+ def process_categorical_data(
193
+ self, builder: Builder, rr_data: Union[str, float, pd.DataFrame]
194
+ ) -> Tuple[Union[str, float, pd.DataFrame], List[str]]:
195
+ if not isinstance(rr_data, pd.DataFrame):
196
+ cat1 = builder.data.load("population.demographic_dimensions")
197
+ cat1["parameter"] = "cat1"
198
+ cat1["value"] = rr_data
199
+ cat2 = cat1.copy()
200
+ cat2["parameter"] = "cat2"
201
+ cat2["value"] = 1
202
+ rr_data = pd.concat([cat1, cat2], ignore_index=True)
203
+
204
+ rr_value_cols = list(rr_data["parameter"].unique())
205
+ rr_data = pivot_categorical(builder, self.risk, rr_data, "parameter")
206
+ return rr_data, rr_value_cols
207
+
208
+ # todo currently this isn't being called. we need to properly set rrs if
209
+ # the exposure has been rebinned
210
+ def rebin_relative_risk_data(
211
+ self, builder, relative_risk_data: pd.DataFrame
212
+ ) -> pd.DataFrame:
213
+ """When the polytomous risk is rebinned, matching relative risk needs to be rebinned.
214
+ After rebinning, rr for both exposed and unexposed categories should be the weighted sum of relative risk
215
+ of the component categories where weights are relative proportions of exposure of those categories.
216
+ For example, if cat1, cat2, cat3 are exposed categories and cat4 is unexposed with exposure [0.1,0.2,0.3,0.4],
217
+ for the matching rr = [rr1, rr2, rr3, 1], rebinned rr for the rebinned cat1 should be:
218
+ (0.1 *rr1 + 0.2 * rr2 + 0.3* rr3) / (0.1+0.2+0.3)
157
219
  """
158
- paf_data = get_population_attributable_fraction_data(builder, self.risk, self.target)
159
- return builder.lookup.build_table(
160
- paf_data, key_columns=["sex"], parameter_columns=["age", "year"]
220
+ if not self.risk in builder.configuration.to_dict():
221
+ return relative_risk_data
222
+
223
+ rebin_exposed_categories = set(builder.configuration[self.risk]["rebinned_exposed"])
224
+
225
+ if rebin_exposed_categories:
226
+ # todo make sure this works
227
+ exposure_data = load_exposure_data(builder, self.risk)
228
+ relative_risk_data = self._rebin_relative_risk_data(
229
+ relative_risk_data, exposure_data, rebin_exposed_categories
230
+ )
231
+
232
+ return relative_risk_data
233
+
234
+ def _rebin_relative_risk_data(
235
+ self,
236
+ relative_risk_data: pd.DataFrame,
237
+ exposure_data: pd.DataFrame,
238
+ rebin_exposed_categories: set,
239
+ ) -> pd.DataFrame:
240
+ cols = list(exposure_data.columns.difference(["value"]))
241
+
242
+ relative_risk_data = relative_risk_data.merge(exposure_data, on=cols)
243
+ relative_risk_data["value_x"] = relative_risk_data.value_x.multiply(
244
+ relative_risk_data.value_y
161
245
  )
246
+ relative_risk_data.parameter = relative_risk_data["parameter"].map(
247
+ lambda p: "cat1" if p in rebin_exposed_categories else "cat2"
248
+ )
249
+ relative_risk_data = relative_risk_data.groupby(cols).sum().reset_index()
250
+ relative_risk_data["value"] = relative_risk_data.value_x.divide(
251
+ relative_risk_data.value_y
252
+ ).fillna(0)
253
+ return relative_risk_data.drop(columns=["value_x", "value_y"])
254
+
255
+ def get_risk_exposure(self, builder: Builder) -> Callable[[pd.Index], pd.Series]:
256
+ return builder.value.get_value(self.exposure_pipeline_name)
162
257
 
163
258
  def get_target_modifier(
164
259
  self, builder: Builder
165
260
  ) -> Callable[[pd.Index, pd.Series], pd.Series]:
166
- if self.exposure_distribution_type in ["normal", "lognormal", "ensemble"]:
261
+ if not self.is_exposure_categorical:
167
262
  tmred = builder.data.load(f"{self.risk}.tmred")
168
263
  tmrel = 0.5 * (tmred["min"] + tmred["max"])
169
264
  scale = builder.data.load(f"{self.risk}.relative_risk_scalar")
170
265
 
171
266
  def adjust_target(index: pd.Index, target: pd.Series) -> pd.Series:
172
- rr = self.relative_risk(index)
267
+ rr = self.lookup_tables["relative_risk"](index)
173
268
  exposure = self.exposure(index)
174
269
  relative_risk = np.maximum(rr.values ** ((exposure - tmrel) / scale), 1)
175
270
  return target * relative_risk
@@ -178,7 +273,7 @@ class RiskEffect(Component):
178
273
  index_columns = ["index", self.risk.name]
179
274
 
180
275
  def adjust_target(index: pd.Index, target: pd.Series) -> pd.Series:
181
- rr = self.relative_risk(index)
276
+ rr = self.lookup_tables["relative_risk"](index)
182
277
  exposure = self.exposure(index).reset_index()
183
278
  exposure.columns = index_columns
184
279
  exposure = exposure.set_index(index_columns)
@@ -198,12 +293,218 @@ class RiskEffect(Component):
198
293
  self.target_pipeline_name,
199
294
  modifier=self.target_modifier,
200
295
  requires_values=[f"{self.risk.name}.exposure"],
201
- requires_columns=["age", "sex"],
202
296
  )
203
297
 
204
298
  def register_paf_modifier(self, builder: Builder) -> None:
299
+ required_columns = get_lookup_columns(
300
+ [self.lookup_tables["population_attributable_fraction"]]
301
+ )
205
302
  builder.value.register_value_modifier(
206
303
  self.target_paf_pipeline_name,
207
- modifier=self.population_attributable_fraction,
208
- requires_columns=["age", "sex"],
304
+ modifier=self.lookup_tables["population_attributable_fraction"],
305
+ requires_columns=required_columns,
306
+ )
307
+
308
+ ##################
309
+ # Helper methods #
310
+ ##################
311
+
312
+ def _get_risk_exposure_class(self, builder: Builder) -> Risk:
313
+ risk_exposure_component = builder.components.get_component(self.risk)
314
+ if not isinstance(risk_exposure_component, Risk):
315
+ raise ValueError(
316
+ f"Risk effect model {self.name} requires a Risk component named {self.risk}"
317
+ )
318
+ return risk_exposure_component
319
+
320
+
321
+ class NonLogLinearRiskEffect(RiskEffect):
322
+ """A component to model the impact of an exposure-parametrized risk factor on
323
+ the target rate of some affected entity. This component will
324
+
325
+ 1) read TMRED data from the artifact and define the TMREL
326
+ 2) calculate the relative risk at TMREL by linearly interpolating over
327
+ relative risk data defined in the configuration
328
+ 3) divide relative risk data from configuration by RR at TMREL
329
+ and clip to be greater than 1
330
+ 4) build a LookupTable which returns the exposure and RR of the left and right edges
331
+ of the RR bin containing a simulant's exposure
332
+ 5) use this LookupTable to modify the target pipeline by linearly interpolating
333
+ a simulant's RR value and multiplying it by the intended target rate
334
+ """
335
+
336
+ ##############
337
+ # Properties #
338
+ ##############
339
+
340
+ @property
341
+ def configuration_defaults(self) -> Dict[str, Any]:
342
+ """
343
+ A dictionary containing the defaults for any configurations managed by
344
+ this component.
345
+ """
346
+ return {
347
+ self.name: {
348
+ "data_sources": {
349
+ "relative_risk": f"{self.risk}.relative_risk",
350
+ "population_attributable_fraction": f"{self.risk}.population_attributable_fraction",
351
+ },
352
+ }
353
+ }
354
+
355
+ @property
356
+ def columns_required(self) -> list[str]:
357
+ return [f"{self.risk.name}_exposure"]
358
+
359
+ #################
360
+ # Setup methods #
361
+ #################
362
+
363
+ def build_all_lookup_tables(self, builder: Builder) -> None:
364
+ rr_data = self.get_relative_risk_data(builder)
365
+ self.validate_rr_data(rr_data)
366
+
367
+ def define_rr_intervals(df: pd.DataFrame) -> pd.DataFrame:
368
+ # create new row for right-most exposure bin (RR is same as max RR)
369
+ max_exposure_row = df.tail(1).copy()
370
+ max_exposure_row["parameter"] = np.inf
371
+ rr_data = pd.concat([df, max_exposure_row]).reset_index()
372
+
373
+ rr_data["left_exposure"] = [0] + rr_data["parameter"][:-1].tolist()
374
+ rr_data["left_rr"] = [rr_data["value"].min()] + rr_data["value"][:-1].tolist()
375
+ rr_data["right_exposure"] = rr_data["parameter"]
376
+ rr_data["right_rr"] = rr_data["value"]
377
+
378
+ return rr_data[
379
+ ["parameter", "left_exposure", "left_rr", "right_exposure", "right_rr"]
380
+ ]
381
+
382
+ # define exposure and rr interval columns
383
+ demographic_cols = [
384
+ col for col in rr_data.columns if col != "parameter" and col != "value"
385
+ ]
386
+ rr_data = (
387
+ rr_data.groupby(demographic_cols)
388
+ .apply(define_rr_intervals)
389
+ .reset_index(level=-1, drop=True)
390
+ .reset_index()
391
+ )
392
+ rr_data = rr_data.drop("parameter", axis=1)
393
+ rr_data[f"{self.risk.name}_exposure_start"] = rr_data["left_exposure"]
394
+ rr_data[f"{self.risk.name}_exposure_end"] = rr_data["right_exposure"]
395
+ # build lookup table
396
+ rr_value_cols = ["left_exposure", "left_rr", "right_exposure", "right_rr"]
397
+ self.lookup_tables["relative_risk"] = self.build_lookup_table(
398
+ builder, rr_data, rr_value_cols
399
+ )
400
+
401
+ paf_data = self.get_filtered_data(
402
+ builder, self.configuration.data_sources.population_attributable_fraction
403
+ )
404
+ self.lookup_tables["population_attributable_fraction"] = self.build_lookup_table(
405
+ builder, paf_data
406
+ )
407
+
408
+ def get_relative_risk_data(
409
+ self,
410
+ builder: Builder,
411
+ configuration=None,
412
+ ) -> Union[str, float, pd.DataFrame]:
413
+ if configuration is None:
414
+ configuration = self.configuration
415
+
416
+ # get TMREL
417
+ tmred = builder.data.load(f"{self.risk}.tmred")
418
+ if tmred["distribution"] == "uniform":
419
+ draw = builder.configuration.input_data.input_draw_number
420
+ rng = np.random.default_rng(builder.randomness.get_seed(self.name + str(draw)))
421
+ self.tmrel = rng.uniform(tmred["min"], tmred["max"])
422
+ elif tmred["distribution"] == "draws": # currently only for iron deficiency
423
+ raise MissingDataError(
424
+ f"This data has draw-level TMRELs. You will need to contact the research team that models {self.risk.name} to get this data."
425
+ )
426
+ else:
427
+ raise MissingDataError(f"No TMRED found in gbd_mapping for risk {self.risk.name}")
428
+
429
+ # calculate RR at TMREL
430
+ rr_source = configuration.data_sources.relative_risk
431
+ original_rrs = self.get_filtered_data(builder, rr_source)
432
+
433
+ self.validate_rr_data(original_rrs)
434
+
435
+ demographic_cols = [
436
+ col for col in original_rrs.columns if col != "parameter" and col != "value"
437
+ ]
438
+
439
+ def get_rr_at_tmrel(rr_data: pd.DataFrame) -> float:
440
+ interpolated_rr_function = scipy.interpolate.interp1d(
441
+ rr_data["parameter"],
442
+ rr_data["value"],
443
+ kind="linear",
444
+ bounds_error=False,
445
+ fill_value=(
446
+ rr_data["value"].min(),
447
+ rr_data["value"].max(),
448
+ ),
449
+ )
450
+ rr_at_tmrel = interpolated_rr_function(self.tmrel).item()
451
+ return rr_at_tmrel
452
+
453
+ rrs_at_tmrel = (
454
+ original_rrs.groupby(demographic_cols)
455
+ .apply(get_rr_at_tmrel)
456
+ .rename("rr_at_tmrel")
457
+ )
458
+ rr_data = original_rrs.merge(rrs_at_tmrel.reset_index())
459
+ rr_data["value"] = rr_data["value"] / rr_data["rr_at_tmrel"]
460
+ rr_data["value"] = np.clip(rr_data["value"], 1.0, np.inf)
461
+ rr_data = rr_data.drop("rr_at_tmrel", axis=1)
462
+
463
+ return rr_data
464
+
465
+ def get_target_modifier(
466
+ self, builder: Builder
467
+ ) -> Callable[[pd.Index, pd.Series], pd.Series]:
468
+ def adjust_target(index: pd.Index, target: pd.Series) -> pd.Series:
469
+ rr_intervals = self.lookup_tables["relative_risk"](index)
470
+ exposure = self.population_view.get(index)[f"{self.risk.name}_exposure"]
471
+ x1, x2 = (
472
+ rr_intervals["left_exposure"].values,
473
+ rr_intervals["right_exposure"].values,
474
+ )
475
+ y1, y2 = rr_intervals["left_rr"].values, rr_intervals["right_rr"].values
476
+ m = (y2 - y1) / (x2 - x1)
477
+ b = y1 - m * x1
478
+ relative_risk = b + m * exposure
479
+ return target * relative_risk
480
+
481
+ return adjust_target
482
+
483
+ ##############
484
+ # Validators #
485
+ ##############
486
+
487
+ def validate_rr_data(self, rr_data: pd.DataFrame) -> None:
488
+ # check that rr_data has numeric parameter data
489
+ parameter_data_is_numeric = rr_data["parameter"].dtype.kind in "biufc"
490
+ if not parameter_data_is_numeric:
491
+ raise ValueError(
492
+ f"The parameter column in your {self.risk.name} relative risk data must contain numeric data. Its dtype is {rr_data['parameter'].dtype} instead."
493
+ )
494
+
495
+ # and that these RR values are monotonically increasing within each demographic group
496
+ # so that each simulant's exposure will assign them to either one bin or one RR value
497
+ demographic_cols = [
498
+ col for col in rr_data.columns if col != "parameter" and col != "value"
499
+ ]
500
+
501
+ def values_are_monotonically_increasing(df: pd.DataFrame) -> bool:
502
+ return np.all(df["parameter"].values[1:] >= df["parameter"].values[:-1])
503
+
504
+ group_is_increasing = rr_data.groupby(demographic_cols).apply(
505
+ values_are_monotonically_increasing
209
506
  )
507
+ if not group_is_increasing.all():
508
+ raise ValueError(
509
+ "The parameter column in your relative risk data must be monotonically increasing to be used in NonLogLinearRiskEffect."
510
+ )