vivarium-public-health 2.3.3__py3-none-any.whl → 3.0.1__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (48) hide show
  1. vivarium_public_health/_version.py +1 -1
  2. vivarium_public_health/disease/model.py +23 -21
  3. vivarium_public_health/disease/models.py +1 -0
  4. vivarium_public_health/disease/special_disease.py +40 -41
  5. vivarium_public_health/disease/state.py +42 -125
  6. vivarium_public_health/disease/transition.py +70 -27
  7. vivarium_public_health/mslt/delay.py +1 -0
  8. vivarium_public_health/mslt/disease.py +1 -0
  9. vivarium_public_health/mslt/intervention.py +1 -0
  10. vivarium_public_health/mslt/magic_wand_components.py +1 -0
  11. vivarium_public_health/mslt/observer.py +1 -0
  12. vivarium_public_health/mslt/population.py +1 -0
  13. vivarium_public_health/plugins/parser.py +61 -31
  14. vivarium_public_health/population/add_new_birth_cohorts.py +2 -3
  15. vivarium_public_health/population/base_population.py +2 -1
  16. vivarium_public_health/population/mortality.py +83 -80
  17. vivarium_public_health/{metrics → results}/__init__.py +2 -0
  18. vivarium_public_health/results/columns.py +22 -0
  19. vivarium_public_health/results/disability.py +187 -0
  20. vivarium_public_health/results/disease.py +222 -0
  21. vivarium_public_health/results/mortality.py +186 -0
  22. vivarium_public_health/results/observer.py +78 -0
  23. vivarium_public_health/results/risk.py +138 -0
  24. vivarium_public_health/results/simple_cause.py +18 -0
  25. vivarium_public_health/{metrics → results}/stratification.py +10 -8
  26. vivarium_public_health/risks/__init__.py +1 -2
  27. vivarium_public_health/risks/base_risk.py +134 -29
  28. vivarium_public_health/risks/data_transformations.py +65 -326
  29. vivarium_public_health/risks/distributions.py +315 -145
  30. vivarium_public_health/risks/effect.py +376 -75
  31. vivarium_public_health/risks/implementations/low_birth_weight_and_short_gestation.py +61 -89
  32. vivarium_public_health/treatment/magic_wand.py +1 -0
  33. vivarium_public_health/treatment/scale_up.py +1 -0
  34. vivarium_public_health/treatment/therapeutic_inertia.py +1 -0
  35. vivarium_public_health/utilities.py +17 -2
  36. {vivarium_public_health-2.3.3.dist-info → vivarium_public_health-3.0.1.dist-info}/METADATA +12 -2
  37. vivarium_public_health-3.0.1.dist-info/RECORD +49 -0
  38. {vivarium_public_health-2.3.3.dist-info → vivarium_public_health-3.0.1.dist-info}/WHEEL +1 -1
  39. vivarium_public_health/metrics/disability.py +0 -118
  40. vivarium_public_health/metrics/disease.py +0 -136
  41. vivarium_public_health/metrics/mortality.py +0 -144
  42. vivarium_public_health/metrics/risk.py +0 -110
  43. vivarium_public_health/testing/__init__.py +0 -0
  44. vivarium_public_health/testing/mock_artifact.py +0 -145
  45. vivarium_public_health/testing/utils.py +0 -71
  46. vivarium_public_health-2.3.3.dist-info/RECORD +0 -49
  47. {vivarium_public_health-2.3.3.dist-info → vivarium_public_health-3.0.1.dist-info}/LICENSE.txt +0 -0
  48. {vivarium_public_health-2.3.3.dist-info → vivarium_public_health-3.0.1.dist-info}/top_level.txt +0 -0
@@ -8,20 +8,25 @@ exposure models and disease models.
8
8
 
9
9
  """
10
10
 
11
- from typing import Any, Callable, Dict
11
+ from importlib import import_module
12
+ from typing import Any, Callable, Dict, List, Tuple, Union
12
13
 
13
14
  import numpy as np
14
15
  import pandas as pd
16
+ import scipy
17
+ from layered_config_tree import ConfigurationError
15
18
  from vivarium import Component
16
19
  from vivarium.framework.engine import Builder
17
- from vivarium.framework.lookup import LookupTable
20
+ from vivarium.framework.event import Event
21
+ from vivarium.framework.population import SimulantData
18
22
 
23
+ from vivarium_public_health.risks import Risk
19
24
  from vivarium_public_health.risks.data_transformations import (
20
- get_distribution_type,
21
- get_population_attributable_fraction_data,
22
- get_relative_risk_data,
25
+ load_exposure_data,
26
+ pivot_categorical,
23
27
  )
24
- from vivarium_public_health.utilities import EntityString, TargetString
28
+ from vivarium_public_health.risks.distributions import MissingDataError
29
+ from vivarium_public_health.utilities import EntityString, TargetString, get_lookup_columns
25
30
 
26
31
 
27
32
  class RiskEffect(Component):
@@ -34,29 +39,24 @@ class RiskEffect(Component):
34
39
  .. code-block:: yaml
35
40
 
36
41
  configuration:
37
- effect_of_risk_on_affected_risk:
42
+ risk_effect.risk_name_on_affected_target:
38
43
  exposure_parameters: 2
39
44
  incidence_rate: 10
40
45
 
41
46
  """
42
47
 
43
- CONFIGURATION_DEFAULTS = {
44
- "effect_of_risk_on_target": {
45
- "measure": {
46
- "relative_risk": None,
47
- "mean": None,
48
- "se": None,
49
- "log_mean": None,
50
- "log_se": None,
51
- "tau_squared": None,
52
- }
53
- }
54
- }
55
-
56
- ##############
48
+ ###############
57
49
  # Properties #
58
50
  ##############
59
51
 
52
+ @property
53
+ def name(self) -> str:
54
+ return self.get_name(self.risk, self.target)
55
+
56
+ @staticmethod
57
+ def get_name(risk: EntityString, target: TargetString) -> str:
58
+ return f"risk_effect.{risk.name}_on_{target}"
59
+
60
60
  @property
61
61
  def configuration_defaults(self) -> Dict[str, Any]:
62
62
  """
@@ -64,13 +64,25 @@ class RiskEffect(Component):
64
64
  this component.
65
65
  """
66
66
  return {
67
- f"effect_of_{self.risk.name}_on_{self.target.name}": {
68
- self.target.measure: self.CONFIGURATION_DEFAULTS["effect_of_risk_on_target"][
69
- "measure"
70
- ]
67
+ self.name: {
68
+ "data_sources": {
69
+ "relative_risk": f"{self.risk}.relative_risk",
70
+ "population_attributable_fraction": f"{self.risk}.population_attributable_fraction",
71
+ },
72
+ "data_source_parameters": {
73
+ "relative_risk": {},
74
+ },
71
75
  }
72
76
  }
73
77
 
78
+ @property
79
+ def is_exposure_categorical(self) -> bool:
80
+ return self._exposure_distribution_type in [
81
+ "dichotomous",
82
+ "ordered_polytomous",
83
+ "unordered_polytomous",
84
+ ]
85
+
74
86
  #####################
75
87
  # Lifecycle methods #
76
88
  #####################
@@ -92,18 +104,15 @@ class RiskEffect(Component):
92
104
  self.risk = EntityString(risk)
93
105
  self.target = TargetString(target)
94
106
 
107
+ self._exposure_distribution_type = None
108
+
95
109
  self.exposure_pipeline_name = f"{self.risk.name}.exposure"
96
110
  self.target_pipeline_name = f"{self.target.name}.{self.target.measure}"
97
111
  self.target_paf_pipeline_name = f"{self.target_pipeline_name}.paf"
98
112
 
99
113
  # noinspection PyAttributeOutsideInit
100
114
  def setup(self, builder: Builder) -> None:
101
- self.exposure_distribution_type = self.get_distribution_type(builder)
102
115
  self.exposure = self.get_risk_exposure(builder)
103
- self.relative_risk = self.get_relative_risk_source(builder)
104
- self.population_attributable_fraction = (
105
- self.get_population_attributable_fraction_source(builder)
106
- )
107
116
 
108
117
  self.target_modifier = self.get_target_modifier(builder)
109
118
 
@@ -114,62 +123,148 @@ class RiskEffect(Component):
114
123
  # Setup methods #
115
124
  #################
116
125
 
117
- def get_distribution_type(self, builder: Builder) -> str:
118
- return get_distribution_type(builder, self.risk)
119
-
120
- def get_risk_exposure(self, builder: Builder) -> Callable[[pd.Index], pd.Series]:
121
- return builder.value.get_value(self.exposure_pipeline_name)
126
+ def build_all_lookup_tables(self, builder: Builder) -> None:
127
+ self._exposure_distribution_type = self.get_distribution_type(builder)
122
128
 
123
- def get_relative_risk_source(self, builder: Builder) -> LookupTable:
124
- """
125
- Get the relative risk source for this risk effect model.
126
-
127
- Parameters
128
- ----------
129
- builder
130
- Interface to access simulation managers.
131
-
132
- Returns
133
- -------
134
- LookupTable
135
- A lookup table containing the relative risk data for this risk
136
- effect model.
137
- """
138
- relative_risk_data = get_relative_risk_data(builder, self.risk, self.target)
139
- return builder.lookup.build_table(
140
- relative_risk_data, key_columns=["sex"], parameter_columns=["age", "year"]
129
+ rr_data = self.get_relative_risk_data(builder)
130
+ rr_value_cols = None
131
+ if self.is_exposure_categorical:
132
+ rr_data, rr_value_cols = self.process_categorical_data(builder, rr_data)
133
+ self.lookup_tables["relative_risk"] = self.build_lookup_table(
134
+ builder, rr_data, rr_value_cols
141
135
  )
142
136
 
143
- def get_population_attributable_fraction_source(self, builder: Builder) -> LookupTable:
144
- """
145
- Get the population attributable fraction source for this risk effect model.
137
+ paf_data = self.get_filtered_data(
138
+ builder, self.configuration.data_sources.population_attributable_fraction
139
+ )
140
+ self.lookup_tables["population_attributable_fraction"] = self.build_lookup_table(
141
+ builder, paf_data
142
+ )
146
143
 
147
- Parameters
148
- ----------
149
- builder
150
- Interface to access simulation managers.
151
-
152
- Returns
153
- -------
154
- LookupTable
155
- A lookup table containing the population attributable fraction data
156
- for this risk effect model.
144
+ def get_distribution_type(self, builder: Builder) -> str:
145
+ """Get the distribution type for the risk from the configuration."""
146
+ risk_exposure_component = self._get_risk_exposure_class(builder)
147
+ if risk_exposure_component.distribution_type:
148
+ return risk_exposure_component.distribution_type
149
+ return risk_exposure_component.get_distribution_type(builder)
150
+
151
+ def get_relative_risk_data(
152
+ self,
153
+ builder: Builder,
154
+ configuration=None,
155
+ ) -> Union[str, float, pd.DataFrame]:
156
+ if configuration is None:
157
+ configuration = self.configuration
158
+
159
+ rr_source = configuration.data_sources.relative_risk
160
+ rr_dist_parameters = configuration.data_source_parameters.relative_risk.to_dict()
161
+
162
+ try:
163
+ distribution = getattr(import_module("scipy.stats"), rr_source)
164
+ rng = np.random.default_rng(builder.randomness.get_seed(self.name))
165
+ rr_data = distribution(**rr_dist_parameters).ppf(rng.random())
166
+ except AttributeError:
167
+ rr_data = self.get_filtered_data(builder, rr_source)
168
+ except TypeError:
169
+ raise ConfigurationError(
170
+ f"Parameters {rr_dist_parameters} are not valid for distribution {rr_source}."
171
+ )
172
+ return rr_data
173
+
174
+ def get_filtered_data(
175
+ self, builder: "Builder", data_source: Union[str, float, pd.DataFrame]
176
+ ) -> Union[float, pd.DataFrame]:
177
+ data = super().get_data(builder, data_source)
178
+
179
+ if isinstance(data, pd.DataFrame):
180
+ # filter data to only include the target entity and measure
181
+ correct_target_mask = True
182
+ columns_to_drop = []
183
+ if "affected_entity" in data.columns:
184
+ correct_target_mask &= data["affected_entity"] == self.target.name
185
+ columns_to_drop.append("affected_entity")
186
+ if "affected_measure" in data.columns:
187
+ correct_target_mask &= data["affected_measure"] == self.target.measure
188
+ columns_to_drop.append("affected_measure")
189
+ data = data[correct_target_mask].drop(columns=columns_to_drop)
190
+ return data
191
+
192
+ def process_categorical_data(
193
+ self, builder: Builder, rr_data: Union[str, float, pd.DataFrame]
194
+ ) -> Tuple[Union[str, float, pd.DataFrame], List[str]]:
195
+ if not isinstance(rr_data, pd.DataFrame):
196
+ cat1 = builder.data.load("population.demographic_dimensions")
197
+ cat1["parameter"] = "cat1"
198
+ cat1["value"] = rr_data
199
+ cat2 = cat1.copy()
200
+ cat2["parameter"] = "cat2"
201
+ cat2["value"] = 1
202
+ rr_data = pd.concat([cat1, cat2], ignore_index=True)
203
+
204
+ rr_value_cols = list(rr_data["parameter"].unique())
205
+ rr_data = pivot_categorical(builder, self.risk, rr_data, "parameter")
206
+ return rr_data, rr_value_cols
207
+
208
+ # todo currently this isn't being called. we need to properly set rrs if
209
+ # the exposure has been rebinned
210
+ def rebin_relative_risk_data(
211
+ self, builder, relative_risk_data: pd.DataFrame
212
+ ) -> pd.DataFrame:
213
+ """When the polytomous risk is rebinned, matching relative risk needs to be rebinned.
214
+ After rebinning, rr for both exposed and unexposed categories should be the weighted sum of relative risk
215
+ of the component categories where weights are relative proportions of exposure of those categories.
216
+ For example, if cat1, cat2, cat3 are exposed categories and cat4 is unexposed with exposure [0.1,0.2,0.3,0.4],
217
+ for the matching rr = [rr1, rr2, rr3, 1], rebinned rr for the rebinned cat1 should be:
218
+ (0.1 *rr1 + 0.2 * rr2 + 0.3* rr3) / (0.1+0.2+0.3)
157
219
  """
158
- paf_data = get_population_attributable_fraction_data(builder, self.risk, self.target)
159
- return builder.lookup.build_table(
160
- paf_data, key_columns=["sex"], parameter_columns=["age", "year"]
220
+ if not self.risk in builder.configuration.to_dict():
221
+ return relative_risk_data
222
+
223
+ rebin_exposed_categories = set(builder.configuration[self.risk]["rebinned_exposed"])
224
+
225
+ if rebin_exposed_categories:
226
+ # todo make sure this works
227
+ exposure_data = load_exposure_data(builder, self.risk)
228
+ relative_risk_data = self._rebin_relative_risk_data(
229
+ relative_risk_data, exposure_data, rebin_exposed_categories
230
+ )
231
+
232
+ return relative_risk_data
233
+
234
+ def _rebin_relative_risk_data(
235
+ self,
236
+ relative_risk_data: pd.DataFrame,
237
+ exposure_data: pd.DataFrame,
238
+ rebin_exposed_categories: set,
239
+ ) -> pd.DataFrame:
240
+ cols = list(exposure_data.columns.difference(["value"]))
241
+
242
+ relative_risk_data = relative_risk_data.merge(exposure_data, on=cols)
243
+ relative_risk_data["value_x"] = relative_risk_data.value_x.multiply(
244
+ relative_risk_data.value_y
161
245
  )
246
+ relative_risk_data.parameter = relative_risk_data["parameter"].map(
247
+ lambda p: "cat1" if p in rebin_exposed_categories else "cat2"
248
+ )
249
+ relative_risk_data = relative_risk_data.groupby(cols).sum().reset_index()
250
+ relative_risk_data["value"] = relative_risk_data.value_x.divide(
251
+ relative_risk_data.value_y
252
+ ).fillna(0)
253
+ return relative_risk_data.drop(columns=["value_x", "value_y"])
254
+
255
+ def get_risk_exposure(self, builder: Builder) -> Callable[[pd.Index], pd.Series]:
256
+ return builder.value.get_value(self.exposure_pipeline_name)
162
257
 
163
258
  def get_target_modifier(
164
259
  self, builder: Builder
165
260
  ) -> Callable[[pd.Index, pd.Series], pd.Series]:
166
- if self.exposure_distribution_type in ["normal", "lognormal", "ensemble"]:
261
+ if not self.is_exposure_categorical:
167
262
  tmred = builder.data.load(f"{self.risk}.tmred")
168
263
  tmrel = 0.5 * (tmred["min"] + tmred["max"])
169
264
  scale = builder.data.load(f"{self.risk}.relative_risk_scalar")
170
265
 
171
266
  def adjust_target(index: pd.Index, target: pd.Series) -> pd.Series:
172
- rr = self.relative_risk(index)
267
+ rr = self.lookup_tables["relative_risk"](index)
173
268
  exposure = self.exposure(index)
174
269
  relative_risk = np.maximum(rr.values ** ((exposure - tmrel) / scale), 1)
175
270
  return target * relative_risk
@@ -178,7 +273,7 @@ class RiskEffect(Component):
178
273
  index_columns = ["index", self.risk.name]
179
274
 
180
275
  def adjust_target(index: pd.Index, target: pd.Series) -> pd.Series:
181
- rr = self.relative_risk(index)
276
+ rr = self.lookup_tables["relative_risk"](index)
182
277
  exposure = self.exposure(index).reset_index()
183
278
  exposure.columns = index_columns
184
279
  exposure = exposure.set_index(index_columns)
@@ -198,12 +293,218 @@ class RiskEffect(Component):
198
293
  self.target_pipeline_name,
199
294
  modifier=self.target_modifier,
200
295
  requires_values=[f"{self.risk.name}.exposure"],
201
- requires_columns=["age", "sex"],
202
296
  )
203
297
 
204
298
  def register_paf_modifier(self, builder: Builder) -> None:
299
+ required_columns = get_lookup_columns(
300
+ [self.lookup_tables["population_attributable_fraction"]]
301
+ )
205
302
  builder.value.register_value_modifier(
206
303
  self.target_paf_pipeline_name,
207
- modifier=self.population_attributable_fraction,
208
- requires_columns=["age", "sex"],
304
+ modifier=self.lookup_tables["population_attributable_fraction"],
305
+ requires_columns=required_columns,
306
+ )
307
+
308
+ ##################
309
+ # Helper methods #
310
+ ##################
311
+
312
+ def _get_risk_exposure_class(self, builder: Builder) -> Risk:
313
+ risk_exposure_component = builder.components.get_component(self.risk)
314
+ if not isinstance(risk_exposure_component, Risk):
315
+ raise ValueError(
316
+ f"Risk effect model {self.name} requires a Risk component named {self.risk}"
317
+ )
318
+ return risk_exposure_component
319
+
320
+
321
+ class NonLogLinearRiskEffect(RiskEffect):
322
+ """A component to model the impact of an exposure-parametrized risk factor on
323
+ the target rate of some affected entity. This component will
324
+
325
+ 1) read TMRED data from the artifact and define the TMREL
326
+ 2) calculate the relative risk at TMREL by linearly interpolating over
327
+ relative risk data defined in the configuration
328
+ 3) divide relative risk data from configuration by RR at TMREL
329
+ and clip to be greater than 1
330
+ 4) build a LookupTable which returns the exposure and RR of the left and right edges
331
+ of the RR bin containing a simulant's exposure
332
+ 5) use this LookupTable to modify the target pipeline by linearly interpolating
333
+ a simulant's RR value and multiplying it by the intended target rate
334
+ """
335
+
336
+ ##############
337
+ # Properties #
338
+ ##############
339
+
340
+ @property
341
+ def configuration_defaults(self) -> Dict[str, Any]:
342
+ """
343
+ A dictionary containing the defaults for any configurations managed by
344
+ this component.
345
+ """
346
+ return {
347
+ self.name: {
348
+ "data_sources": {
349
+ "relative_risk": f"{self.risk}.relative_risk",
350
+ "population_attributable_fraction": f"{self.risk}.population_attributable_fraction",
351
+ },
352
+ }
353
+ }
354
+
355
+ @property
356
+ def columns_required(self) -> list[str]:
357
+ return [f"{self.risk.name}_exposure"]
358
+
359
+ #################
360
+ # Setup methods #
361
+ #################
362
+
363
+ def build_all_lookup_tables(self, builder: Builder) -> None:
364
+ rr_data = self.get_relative_risk_data(builder)
365
+ self.validate_rr_data(rr_data)
366
+
367
+ def define_rr_intervals(df: pd.DataFrame) -> pd.DataFrame:
368
+ # create new row for right-most exposure bin (RR is same as max RR)
369
+ max_exposure_row = df.tail(1).copy()
370
+ max_exposure_row["parameter"] = np.inf
371
+ rr_data = pd.concat([df, max_exposure_row]).reset_index()
372
+
373
+ rr_data["left_exposure"] = [0] + rr_data["parameter"][:-1].tolist()
374
+ rr_data["left_rr"] = [rr_data["value"].min()] + rr_data["value"][:-1].tolist()
375
+ rr_data["right_exposure"] = rr_data["parameter"]
376
+ rr_data["right_rr"] = rr_data["value"]
377
+
378
+ return rr_data[
379
+ ["parameter", "left_exposure", "left_rr", "right_exposure", "right_rr"]
380
+ ]
381
+
382
+ # define exposure and rr interval columns
383
+ demographic_cols = [
384
+ col for col in rr_data.columns if col != "parameter" and col != "value"
385
+ ]
386
+ rr_data = (
387
+ rr_data.groupby(demographic_cols)
388
+ .apply(define_rr_intervals)
389
+ .reset_index(level=-1, drop=True)
390
+ .reset_index()
391
+ )
392
+ rr_data = rr_data.drop("parameter", axis=1)
393
+ rr_data[f"{self.risk.name}_exposure_start"] = rr_data["left_exposure"]
394
+ rr_data[f"{self.risk.name}_exposure_end"] = rr_data["right_exposure"]
395
+ # build lookup table
396
+ rr_value_cols = ["left_exposure", "left_rr", "right_exposure", "right_rr"]
397
+ self.lookup_tables["relative_risk"] = self.build_lookup_table(
398
+ builder, rr_data, rr_value_cols
399
+ )
400
+
401
+ paf_data = self.get_filtered_data(
402
+ builder, self.configuration.data_sources.population_attributable_fraction
403
+ )
404
+ self.lookup_tables["population_attributable_fraction"] = self.build_lookup_table(
405
+ builder, paf_data
406
+ )
407
+
408
+ def get_relative_risk_data(
409
+ self,
410
+ builder: Builder,
411
+ configuration=None,
412
+ ) -> Union[str, float, pd.DataFrame]:
413
+ if configuration is None:
414
+ configuration = self.configuration
415
+
416
+ # get TMREL
417
+ tmred = builder.data.load(f"{self.risk}.tmred")
418
+ if tmred["distribution"] == "uniform":
419
+ draw = builder.configuration.input_data.input_draw_number
420
+ rng = np.random.default_rng(builder.randomness.get_seed(self.name + str(draw)))
421
+ self.tmrel = rng.uniform(tmred["min"], tmred["max"])
422
+ elif tmred["distribution"] == "draws": # currently only for iron deficiency
423
+ raise MissingDataError(
424
+ f"This data has draw-level TMRELs. You will need to contact the research team that models {self.risk.name} to get this data."
425
+ )
426
+ else:
427
+ raise MissingDataError(f"No TMRED found in gbd_mapping for risk {self.risk.name}")
428
+
429
+ # calculate RR at TMREL
430
+ rr_source = configuration.data_sources.relative_risk
431
+ original_rrs = self.get_filtered_data(builder, rr_source)
432
+
433
+ self.validate_rr_data(original_rrs)
434
+
435
+ demographic_cols = [
436
+ col for col in original_rrs.columns if col != "parameter" and col != "value"
437
+ ]
438
+
439
+ def get_rr_at_tmrel(rr_data: pd.DataFrame) -> float:
440
+ interpolated_rr_function = scipy.interpolate.interp1d(
441
+ rr_data["parameter"],
442
+ rr_data["value"],
443
+ kind="linear",
444
+ bounds_error=False,
445
+ fill_value=(
446
+ rr_data["value"].min(),
447
+ rr_data["value"].max(),
448
+ ),
449
+ )
450
+ rr_at_tmrel = interpolated_rr_function(self.tmrel).item()
451
+ return rr_at_tmrel
452
+
453
+ rrs_at_tmrel = (
454
+ original_rrs.groupby(demographic_cols)
455
+ .apply(get_rr_at_tmrel)
456
+ .rename("rr_at_tmrel")
457
+ )
458
+ rr_data = original_rrs.merge(rrs_at_tmrel.reset_index())
459
+ rr_data["value"] = rr_data["value"] / rr_data["rr_at_tmrel"]
460
+ rr_data["value"] = np.clip(rr_data["value"], 1.0, np.inf)
461
+ rr_data = rr_data.drop("rr_at_tmrel", axis=1)
462
+
463
+ return rr_data
464
+
465
+ def get_target_modifier(
466
+ self, builder: Builder
467
+ ) -> Callable[[pd.Index, pd.Series], pd.Series]:
468
+ def adjust_target(index: pd.Index, target: pd.Series) -> pd.Series:
469
+ rr_intervals = self.lookup_tables["relative_risk"](index)
470
+ exposure = self.population_view.get(index)[f"{self.risk.name}_exposure"]
471
+ x1, x2 = (
472
+ rr_intervals["left_exposure"].values,
473
+ rr_intervals["right_exposure"].values,
474
+ )
475
+ y1, y2 = rr_intervals["left_rr"].values, rr_intervals["right_rr"].values
476
+ m = (y2 - y1) / (x2 - x1)
477
+ b = y1 - m * x1
478
+ relative_risk = b + m * exposure
479
+ return target * relative_risk
480
+
481
+ return adjust_target
482
+
483
+ ##############
484
+ # Validators #
485
+ ##############
486
+
487
+ def validate_rr_data(self, rr_data: pd.DataFrame) -> None:
488
+ # check that rr_data has numeric parameter data
489
+ parameter_data_is_numeric = rr_data["parameter"].dtype.kind in "biufc"
490
+ if not parameter_data_is_numeric:
491
+ raise ValueError(
492
+ f"The parameter column in your {self.risk.name} relative risk data must contain numeric data. Its dtype is {rr_data['parameter'].dtype} instead."
493
+ )
494
+
495
+ # and that these RR values are monotonically increasing within each demographic group
496
+ # so that each simulant's exposure will assign them to either one bin or one RR value
497
+ demographic_cols = [
498
+ col for col in rr_data.columns if col != "parameter" and col != "value"
499
+ ]
500
+
501
+ def values_are_monotonically_increasing(df: pd.DataFrame) -> bool:
502
+ return np.all(df["parameter"].values[1:] >= df["parameter"].values[:-1])
503
+
504
+ group_is_increasing = rr_data.groupby(demographic_cols).apply(
505
+ values_are_monotonically_increasing
209
506
  )
507
+ if not group_is_increasing.all():
508
+ raise ValueError(
509
+ "The parameter column in your relative risk data must be monotonically increasing to be used in NonLogLinearRiskEffect."
510
+ )