pySEQTarget 0.10.0__tar.gz → 0.12.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/PKG-INFO +11 -5
  2. {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/README.md +5 -4
  3. {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget/SEQopts.py +17 -2
  4. {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget/SEQuential.py +23 -10
  5. {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget/analysis/__init__.py +1 -0
  6. {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget/analysis/_hazard.py +9 -4
  7. pyseqtarget-0.12.0/pySEQTarget/analysis/_risk_estimates.py +138 -0
  8. {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget/analysis/_survival_pred.py +21 -12
  9. {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget/error/__init__.py +1 -1
  10. pyseqtarget-0.10.0/pySEQTarget/error/_datachecker.py → pyseqtarget-0.12.0/pySEQTarget/error/_data_checker.py +1 -1
  11. {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget/expansion/_mapper.py +4 -11
  12. pyseqtarget-0.12.0/pySEQTarget/expansion/_selection.py +44 -0
  13. {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget/helpers/__init__.py +1 -0
  14. {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget/helpers/_bootstrap.py +41 -5
  15. pyseqtarget-0.12.0/pySEQTarget/helpers/_fix_categories.py +21 -0
  16. pyseqtarget-0.12.0/pySEQTarget/helpers/_offloader.py +82 -0
  17. {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget/helpers/_output_files.py +2 -4
  18. pyseqtarget-0.12.0/pySEQTarget/helpers/_predict_model.py +57 -0
  19. {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget/weighting/__init__.py +1 -0
  20. {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget/weighting/_weight_bind.py +3 -0
  21. pyseqtarget-0.12.0/pySEQTarget/weighting/_weight_fit.py +137 -0
  22. pyseqtarget-0.12.0/pySEQTarget/weighting/_weight_offload.py +19 -0
  23. {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget/weighting/_weight_pred.py +74 -52
  24. {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget.egg-info/PKG-INFO +11 -5
  25. {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget.egg-info/SOURCES.txt +5 -1
  26. {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget.egg-info/requires.txt +6 -0
  27. {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pyproject.toml +10 -8
  28. pyseqtarget-0.12.0/tests/test_offload.py +41 -0
  29. {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/tests/test_survival.py +13 -0
  30. pyseqtarget-0.10.0/pySEQTarget/analysis/_risk_estimates.py +0 -136
  31. pyseqtarget-0.10.0/pySEQTarget/expansion/_selection.py +0 -31
  32. pyseqtarget-0.10.0/pySEQTarget/helpers/_predict_model.py +0 -9
  33. pyseqtarget-0.10.0/pySEQTarget/weighting/_weight_fit.py +0 -99
  34. {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/LICENSE +0 -0
  35. {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget/SEQoutput.py +0 -0
  36. {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget/__init__.py +0 -0
  37. {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget/analysis/_outcome_fit.py +0 -0
  38. {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget/analysis/_subgroup_fit.py +0 -0
  39. {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget/data/__init__.py +0 -0
  40. {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget/error/_param_checker.py +0 -0
  41. {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget/expansion/__init__.py +0 -0
  42. {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget/expansion/_binder.py +0 -0
  43. {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget/expansion/_diagnostics.py +0 -0
  44. {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget/expansion/_dynamic.py +0 -0
  45. {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget/helpers/_col_string.py +0 -0
  46. {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget/helpers/_format_time.py +0 -0
  47. {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget/helpers/_pad.py +0 -0
  48. {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget/helpers/_prepare_data.py +0 -0
  49. {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget/initialization/__init__.py +0 -0
  50. {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget/initialization/_censoring.py +0 -0
  51. {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget/initialization/_denominator.py +0 -0
  52. {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget/initialization/_numerator.py +0 -0
  53. {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget/initialization/_outcome.py +0 -0
  54. {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget/plot/__init__.py +0 -0
  55. {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget/plot/_survival_plot.py +0 -0
  56. {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget/weighting/_weight_data.py +0 -0
  57. {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget/weighting/_weight_stats.py +0 -0
  58. {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget.egg-info/dependency_links.txt +0 -0
  59. {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget.egg-info/top_level.txt +0 -0
  60. {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/setup.cfg +0 -0
  61. {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/tests/test_accessor.py +0 -0
  62. {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/tests/test_coefficients.py +0 -0
  63. {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/tests/test_covariates.py +0 -0
  64. {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/tests/test_followup_options.py +0 -0
  65. {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/tests/test_hazard.py +0 -0
  66. {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/tests/test_parallel.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: pySEQTarget
3
- Version: 0.10.0
3
+ Version: 0.12.0
4
4
  Summary: Sequentially Nested Target Trial Emulation
5
5
  Author-email: Ryan O'Dea <ryan.odea@psi.ch>, Alejandro Szmulewicz <aszmulewicz@hsph.harvard.edu>, Tom Palmer <tom.palmer@bristol.ac.uk>, Miguel Hernan <mhernan@hsph.harvard.edu>
6
6
  Maintainer-email: Ryan O'Dea <ryan.odea@psi.ch>
@@ -33,6 +33,11 @@ Requires-Dist: statsmodels
33
33
  Requires-Dist: matplotlib
34
34
  Requires-Dist: pyarrow
35
35
  Requires-Dist: lifelines
36
+ Requires-Dist: joblib
37
+ Provides-Extra: output
38
+ Requires-Dist: markdown; extra == "output"
39
+ Requires-Dist: weasyprint; extra == "output"
40
+ Requires-Dist: tabulate; extra == "output"
36
41
  Dynamic: license-file
37
42
 
38
43
  # pySEQTarget - Sequentially Nested Target Trial Emulation
@@ -68,8 +73,9 @@ From the user side, this amounts to creating a dataclass, `SEQopts`, and then fe
68
73
  ```python
69
74
  import polars as pl
70
75
  from pySEQTarget import SEQuential, SEQopts
76
+ from pySEQTarget.data import load_data
71
77
 
72
- data = pl.from_pandas(SEQdata)
78
+ data = load_data("SEQdata")
73
79
  options = SEQopts(km_curves = True)
74
80
 
75
81
  # Initiate the class
@@ -77,17 +83,18 @@ model = SEQuential(data,
77
83
  id_col = "ID",
78
84
  time_col = "time",
79
85
  eligible_col = "eligible",
86
+ treatment_col = "tx_init",
87
+ outcome_col = "outcome",
80
88
  time_varying_cols = ["N", "L", "P"],
81
89
  fixed_cols = ["sex"],
82
90
  method = "ITT",
83
- options = options)
91
+ parameters = options)
84
92
  model.expand() # Construct the nested structure
85
93
  model.bootstrap(bootstrap_nboot = 20) # Run 20 bootstrap samples
86
94
  model.fit() # Fit the model
87
95
  model.survival() # Create survival curves
88
96
  model.plot() # Create and show a plot of the survival curves
89
97
  model.collect() # Collection of important information
90
-
91
98
  ```
92
99
 
93
100
  ## Assumptions
@@ -95,4 +102,3 @@ There are several key assumptions in this package -
95
102
  1. User provided `time_col` begins at 0 per unique `id_col`, we also assume this column contains only integers and continues by 1 for every time step, e.g. (0, 1, 2, 3, 4, ...) is allowed and (0, 1, 2, 2.5, ...) or (0, 1, 4, 5) are not
96
103
  1. Provided `time_col` entries may be out of order at intake as a sort is enforced at expansion.
97
104
  2. `eligible_col` and elements of `excused_colnames` are once 1, only 1 (with respect to `time_col`) flag variables.
98
-
@@ -31,8 +31,9 @@ From the user side, this amounts to creating a dataclass, `SEQopts`, and then fe
31
31
  ```python
32
32
  import polars as pl
33
33
  from pySEQTarget import SEQuential, SEQopts
34
+ from pySEQTarget.data import load_data
34
35
 
35
- data = pl.from_pandas(SEQdata)
36
+ data = load_data("SEQdata")
36
37
  options = SEQopts(km_curves = True)
37
38
 
38
39
  # Initiate the class
@@ -40,17 +41,18 @@ model = SEQuential(data,
40
41
  id_col = "ID",
41
42
  time_col = "time",
42
43
  eligible_col = "eligible",
44
+ treatment_col = "tx_init",
45
+ outcome_col = "outcome",
43
46
  time_varying_cols = ["N", "L", "P"],
44
47
  fixed_cols = ["sex"],
45
48
  method = "ITT",
46
- options = options)
49
+ parameters = options)
47
50
  model.expand() # Construct the nested structure
48
51
  model.bootstrap(bootstrap_nboot = 20) # Run 20 bootstrap samples
49
52
  model.fit() # Fit the model
50
53
  model.survival() # Create survival curves
51
54
  model.plot() # Create and show a plot of the survival curves
52
55
  model.collect() # Collection of important information
53
-
54
56
  ```
55
57
 
56
58
  ## Assumptions
@@ -58,4 +60,3 @@ There are several key assumptions in this package -
58
60
  1. User provided `time_col` begins at 0 per unique `id_col`, we also assume this column contains only integers and continues by 1 for every time step, e.g. (0, 1, 2, 3, 4, ...) is allowed and (0, 1, 2, 2.5, ...) or (0, 1, 4, 5) are not
59
61
  1. Provided `time_col` entries may be out of order at intake as a sort is enforced at expansion.
60
62
  2. `eligible_col` and elements of `excused_colnames` are once 1, only 1 (with respect to `time_col`) flag variables.
61
-
@@ -1,4 +1,5 @@
1
1
  import multiprocessing
2
+ import os
2
3
  from dataclasses import dataclass, field
3
4
  from typing import List, Literal, Optional
4
5
 
@@ -18,7 +19,7 @@ class SEQopts:
18
19
  :type bootstrap_CI_method: str
19
20
  :param cense_colname: Column name for censoring effect (LTFU, etc.)
20
21
  :type cense_colname: str
21
- :param cense_denominator: Override to specify denominator patsy formula for censoring models
22
+ :param cense_denominator: Override to specify denominator patsy formula for censoring models; "1" or "" indicate intercept only model
22
23
  :type cense_denominator: Optional[str] or None
23
24
  :param cense_numerator: Override to specify numerator patsy formula for censoring models
24
25
  :type cense_numerator: Optional[str] or None
@@ -54,8 +55,12 @@ class SEQopts:
54
55
  :type km_curves: bool
55
56
  :param ncores: Number of cores to use if running in parallel
56
57
  :type ncores: int
57
- :param numerator: Override to specify the outcome patsy formula for numerator models
58
+ :param numerator: Override to specify the outcome patsy formula for numerator models; "1" or "" indicate intercept only model
58
59
  :type numerator: str
60
+ :param offload: Boolean to offload intermediate model data to disk
61
+ :type offload: bool
62
+ :param offload_dir: Directory to offload intermediate model data
63
+ :type offload_dir: str
59
64
  :param parallel: Boolean to run model fitting in parallel
60
65
  :type parallel: bool
61
66
  :param plot_colors: List of colors for KM plots, if applicable
@@ -80,8 +85,12 @@ class SEQopts:
80
85
  :type treatment_level: List[int]
81
86
  :param trial_include: Boolean to force trial values into model covariates
82
87
  :type trial_include: bool
88
+ :param visit_colname: Column name specifying visit number
89
+ :type visit_colname: str
83
90
  :param weight_eligible_colnames: List of column names of length treatment_level to identify which rows are eligible for weight fitting
84
91
  :type weight_eligible_colnames: List[str]
92
+ :param weight_fit_method: The fitting method to be used ["newton", "bfgs", "lbfgs", "nm"], default "newton"
93
+ :type weight_fit_method: str
85
94
  :param weight_min: Minimum weight
86
95
  :type weight_min: float
87
96
  :param weight_max: Maximum weight
@@ -120,6 +129,8 @@ class SEQopts:
120
129
  km_curves: bool = False
121
130
  ncores: int = multiprocessing.cpu_count()
122
131
  numerator: Optional[str] = None
132
+ offload: bool = False
133
+ offload_dir: str = "_seq_models"
123
134
  parallel: bool = False
124
135
  plot_colors: List[str] = field(
125
136
  default_factory=lambda: ["#F8766D", "#00BFC4", "#555555"]
@@ -136,6 +147,7 @@ class SEQopts:
136
147
  trial_include: bool = True
137
148
  visit_colname: str = None
138
149
  weight_eligible_colnames: List[str] = field(default_factory=lambda: [])
150
+ weight_fit_method: Literal["newton", "bfgs", "lbfgs", "nm"] = "newton"
139
151
  weight_min: float = 0.0
140
152
  weight_max: float = None
141
153
  weight_lag_condition: bool = True
@@ -195,3 +207,6 @@ class SEQopts:
195
207
  attr = getattr(self, i)
196
208
  if attr is not None and not isinstance(attr, list):
197
209
  setattr(self, i, "".join(attr.split()))
210
+
211
+ if self.offload:
212
+ os.makedirs(self.offload_dir, exist_ok=True)
@@ -7,19 +7,20 @@ from typing import List, Literal, Optional
7
7
  import numpy as np
8
8
  import polars as pl
9
9
 
10
- from .analysis import (_calculate_hazard, _calculate_survival, _outcome_fit,
11
- _pred_risk, _risk_estimates, _subgroup_fit)
12
- from .error import _datachecker, _param_checker
10
+ from .analysis import (_calculate_hazard, _calculate_survival, _clamp,
11
+ _outcome_fit, _pred_risk, _risk_estimates,
12
+ _subgroup_fit)
13
+ from .error import _data_checker, _param_checker
13
14
  from .expansion import _binder, _diagnostics, _dynamic, _random_selection
14
- from .helpers import _col_string, _format_time, bootstrap_loop
15
+ from .helpers import Offloader, _col_string, _format_time, bootstrap_loop
15
16
  from .initialization import (_cense_denominator, _cense_numerator,
16
17
  _denominator, _numerator, _outcome)
17
18
  from .plot import _survival_plot
18
19
  from .SEQopts import SEQopts
19
20
  from .SEQoutput import SEQoutput
20
21
  from .weighting import (_fit_denominator, _fit_LTFU, _fit_numerator,
21
- _fit_visit, _weight_bind, _weight_predict,
22
- _weight_setup, _weight_stats)
22
+ _fit_visit, _offload_weights, _weight_bind,
23
+ _weight_predict, _weight_setup, _weight_stats)
23
24
 
24
25
 
25
26
  class SEQuential:
@@ -83,6 +84,8 @@ class SEQuential:
83
84
  np.random.RandomState(self.seed) if self.seed is not None else np.random
84
85
  )
85
86
 
87
+ self._offloader = Offloader(enabled=self.offload, dir=self.offload_dir)
88
+
86
89
  if self.covariates is None:
87
90
  self.covariates = _outcome(self)
88
91
 
@@ -101,7 +104,7 @@ class SEQuential:
101
104
  self.cense_denominator = _cense_denominator(self)
102
105
 
103
106
  _param_checker(self)
104
- _datachecker(self)
107
+ _data_checker(self)
105
108
 
106
109
  def expand(self) -> None:
107
110
  """
@@ -190,7 +193,6 @@ class SEQuential:
190
193
  )
191
194
  id_counts = Counter(sampled_IDs)
192
195
  self._boot_samples.append(id_counts)
193
- return self
194
196
 
195
197
  @bootstrap_loop
196
198
  def fit(self) -> None:
@@ -201,6 +203,9 @@ class SEQuential:
201
203
  raise ValueError(
202
204
  "Bootstrap sampling not found. Please run the 'bootstrap' method before fitting with bootstrapping."
203
205
  )
206
+ boot_idx = None
207
+ if hasattr(self, "_current_boot_idx"):
208
+ boot_idx = self._current_boot_idx
204
209
 
205
210
  if self.weighted:
206
211
  WDT = _weight_setup(self)
@@ -217,6 +222,9 @@ class SEQuential:
217
222
  _fit_numerator(self, WDT)
218
223
  _fit_denominator(self, WDT)
219
224
 
225
+ if self.offload:
226
+ _offload_weights(self, boot_idx)
227
+
220
228
  WDT = pl.from_pandas(WDT)
221
229
  WDT = _weight_predict(self, WDT)
222
230
  _weight_bind(self, WDT)
@@ -244,6 +252,11 @@ class SEQuential:
244
252
  self.weighted,
245
253
  "weight",
246
254
  )
255
+ if self.offload:
256
+ offloaded_models = {}
257
+ for key, model in models.items():
258
+ offloaded_models[key] = self._offloader.save_model(model, key, boot_idx)
259
+ return offloaded_models
247
260
  return models
248
261
 
249
262
  def survival(self, **kwargs) -> None:
@@ -266,7 +279,7 @@ class SEQuential:
266
279
 
267
280
  risk_data = _pred_risk(self)
268
281
  surv_data = _calculate_survival(self, risk_data)
269
- self.km_data = pl.concat([risk_data, surv_data])
282
+ self.km_data = _clamp(pl.concat([risk_data, surv_data]))
270
283
  self.risk_estimates = _risk_estimates(self)
271
284
 
272
285
  end = time.perf_counter()
@@ -342,7 +355,7 @@ class SEQuential:
342
355
  }
343
356
 
344
357
  if self.compevent_colname is not None:
345
- compevent_models = [model["compevent"] for model in self.outcome_models]
358
+ compevent_models = [model["compevent"] for model in self.outcome_model]
346
359
  else:
347
360
  compevent_models = None
348
361
 
@@ -3,6 +3,7 @@ from ._outcome_fit import _outcome_fit as _outcome_fit
3
3
  from ._risk_estimates import _risk_estimates as _risk_estimates
4
4
  from ._subgroup_fit import _subgroup_fit as _subgroup_fit
5
5
  from ._survival_pred import _calculate_survival as _calculate_survival
6
+ from ._survival_pred import _clamp as _clamp
6
7
  from ._survival_pred import \
7
8
  _get_outcome_predictions as _get_outcome_predictions
8
9
  from ._survival_pred import _pred_risk as _pred_risk
@@ -4,6 +4,8 @@ import numpy as np
4
4
  import polars as pl
5
5
  from lifelines import CoxPHFitter
6
6
 
7
+ from ..helpers._predict_model import _safe_predict
8
+
7
9
 
8
10
  def _calculate_hazard(self):
9
11
  if self.subgroup_colname is None:
@@ -93,8 +95,10 @@ def _hazard_handler(self, data, idx, boot_idx, rng):
93
95
  else:
94
96
  model_dict = self.outcome_model[boot_idx]
95
97
 
96
- outcome_model = model_dict["outcome"]
97
- ce_model = model_dict.get("compevent", None) if self.compevent_colname else None
98
+ outcome_model = self._offloader.load_model(model_dict["outcome"])
99
+ ce_model = None
100
+ if self.compevent_colname and "compevent" in model_dict:
101
+ ce_model = self._offloader.load_model(model_dict["compevent"])
98
102
 
99
103
  all_treatments = []
100
104
  for val in self.treatment_level:
@@ -103,13 +107,14 @@ def _hazard_handler(self, data, idx, boot_idx, rng):
103
107
  )
104
108
 
105
109
  tmp_pd = tmp.to_pandas()
106
- outcome_prob = outcome_model.predict(tmp_pd)
110
+ outcome_prob = _safe_predict(outcome_model, tmp_pd)
107
111
  outcome_sim = rng.binomial(1, outcome_prob)
108
112
 
109
113
  tmp = tmp.with_columns([pl.Series("outcome", outcome_sim)])
110
114
 
111
115
  if ce_model is not None:
112
- ce_prob = ce_model.predict(tmp_pd)
116
+ ce_tmp_pd = tmp.to_pandas()
117
+ ce_prob = _safe_predict(ce_model, ce_tmp_pd)
113
118
  ce_sim = rng.binomial(1, ce_prob)
114
119
  tmp = tmp.with_columns([pl.Series("ce", ce_sim)])
115
120
 
@@ -0,0 +1,138 @@
1
+ import polars as pl
2
+ from scipy import stats
3
+
4
+
5
+ def _compute_rd_rr(comp, has_bootstrap, z=None, group_cols=None):
6
+ """
7
+ Compute Risk Difference and Risk Ratio from a comparison dataframe.
8
+ Consolidates the repeated calculation logic.
9
+ """
10
+ if group_cols is None:
11
+ group_cols = []
12
+
13
+ if has_bootstrap:
14
+ rd_se = (pl.col("se_x").pow(2) + pl.col("se_y").pow(2)).sqrt()
15
+ rd_comp = comp.with_columns(
16
+ [
17
+ (pl.col("risk_x") - pl.col("risk_y")).alias("Risk Difference"),
18
+ (pl.col("risk_x") - pl.col("risk_y") - z * rd_se).alias("RD 95% LCI"),
19
+ (pl.col("risk_x") - pl.col("risk_y") + z * rd_se).alias("RD 95% UCI"),
20
+ ]
21
+ )
22
+ rd_comp = rd_comp.drop(["risk_x", "risk_y", "se_x", "se_y"])
23
+ col_order = group_cols + [
24
+ "A_x",
25
+ "A_y",
26
+ "Risk Difference",
27
+ "RD 95% LCI",
28
+ "RD 95% UCI",
29
+ ]
30
+ rd_comp = rd_comp.select([c for c in col_order if c in rd_comp.columns])
31
+
32
+ rr_log_se = (
33
+ (pl.col("se_x") / pl.col("risk_x")).pow(2)
34
+ + (pl.col("se_y") / pl.col("risk_y")).pow(2)
35
+ ).sqrt()
36
+ rr_comp = comp.with_columns(
37
+ [
38
+ (pl.col("risk_x") / pl.col("risk_y")).alias("Risk Ratio"),
39
+ ((pl.col("risk_x") / pl.col("risk_y")) * (-z * rr_log_se).exp()).alias(
40
+ "RR 95% LCI"
41
+ ),
42
+ ((pl.col("risk_x") / pl.col("risk_y")) * (z * rr_log_se).exp()).alias(
43
+ "RR 95% UCI"
44
+ ),
45
+ ]
46
+ )
47
+ rr_comp = rr_comp.drop(["risk_x", "risk_y", "se_x", "se_y"])
48
+ col_order = group_cols + [
49
+ "A_x",
50
+ "A_y",
51
+ "Risk Ratio",
52
+ "RR 95% LCI",
53
+ "RR 95% UCI",
54
+ ]
55
+ rr_comp = rr_comp.select([c for c in col_order if c in rr_comp.columns])
56
+ else:
57
+ rd_comp = comp.with_columns(
58
+ (pl.col("risk_x") - pl.col("risk_y")).alias("Risk Difference")
59
+ )
60
+ rd_comp = rd_comp.drop(["risk_x", "risk_y"])
61
+ col_order = group_cols + ["A_x", "A_y", "Risk Difference"]
62
+ rd_comp = rd_comp.select([c for c in col_order if c in rd_comp.columns])
63
+
64
+ rr_comp = comp.with_columns(
65
+ (pl.col("risk_x") / pl.col("risk_y")).alias("Risk Ratio")
66
+ )
67
+ rr_comp = rr_comp.drop(["risk_x", "risk_y"])
68
+ col_order = group_cols + ["A_x", "A_y", "Risk Ratio"]
69
+ rr_comp = rr_comp.select([c for c in col_order if c in rr_comp.columns])
70
+
71
+ return rd_comp, rr_comp
72
+
73
+
74
+ def _risk_estimates(self):
75
+ last_followup = self.km_data["followup"].max()
76
+ risk = self.km_data.filter(
77
+ (pl.col("followup") == last_followup) & (pl.col("estimate") == "risk")
78
+ )
79
+
80
+ group_cols = [self.subgroup_colname] if self.subgroup_colname else []
81
+ has_bootstrap = self.bootstrap_nboot > 0
82
+
83
+ if has_bootstrap:
84
+ alpha = 1 - self.bootstrap_CI
85
+ z = stats.norm.ppf(1 - alpha / 2)
86
+ else:
87
+ z = None
88
+
89
+ # Pre-extract data for each treatment level once (avoid repeated filtering)
90
+ risk_by_level = {}
91
+ for tx in self.treatment_level:
92
+ level_data = risk.filter(pl.col(self.treatment_col) == tx)
93
+ risk_by_level[tx] = {
94
+ "pred": level_data.select(group_cols + ["pred"]),
95
+ }
96
+ if has_bootstrap:
97
+ risk_by_level[tx]["SE"] = level_data.select(group_cols + ["SE"])
98
+
99
+ rd_comparisons = []
100
+ rr_comparisons = []
101
+
102
+ for tx_x in self.treatment_level:
103
+ for tx_y in self.treatment_level:
104
+ if tx_x == tx_y:
105
+ continue
106
+
107
+ # Use pre-extracted data instead of filtering again
108
+ risk_x = risk_by_level[tx_x]["pred"].rename({"pred": "risk_x"})
109
+ risk_y = risk_by_level[tx_y]["pred"].rename({"pred": "risk_y"})
110
+
111
+ if group_cols:
112
+ comp = risk_x.join(risk_y, on=group_cols, how="left")
113
+ else:
114
+ comp = risk_x.join(risk_y, how="cross")
115
+
116
+ comp = comp.with_columns(
117
+ [pl.lit(tx_x).alias("A_x"), pl.lit(tx_y).alias("A_y")]
118
+ )
119
+
120
+ if has_bootstrap:
121
+ se_x = risk_by_level[tx_x]["SE"].rename({"SE": "se_x"})
122
+ se_y = risk_by_level[tx_y]["SE"].rename({"SE": "se_y"})
123
+
124
+ if group_cols:
125
+ comp = comp.join(se_x, on=group_cols, how="left")
126
+ comp = comp.join(se_y, on=group_cols, how="left")
127
+ else:
128
+ comp = comp.join(se_x, how="cross")
129
+ comp = comp.join(se_y, how="cross")
130
+
131
+ rd_comp, rr_comp = _compute_rd_rr(comp, has_bootstrap, z, group_cols)
132
+ rd_comparisons.append(rd_comp)
133
+ rr_comparisons.append(rr_comp)
134
+
135
+ risk_difference = pl.concat(rd_comparisons) if rd_comparisons else pl.DataFrame()
136
+ risk_ratio = pl.concat(rr_comparisons) if rr_comparisons else pl.DataFrame()
137
+
138
+ return {"risk_difference": risk_difference, "risk_ratio": risk_ratio}
@@ -1,5 +1,7 @@
1
1
  import polars as pl
2
2
 
3
+ from ..helpers._predict_model import _safe_predict
4
+
3
5
 
4
6
  def _get_outcome_predictions(self, TxDT, idx=None):
5
7
  data = TxDT.to_pandas()
@@ -9,9 +11,12 @@ def _get_outcome_predictions(self, TxDT, idx=None):
9
11
 
10
12
  for boot_model in self.outcome_model:
11
13
  model_dict = boot_model[idx] if idx is not None else boot_model
12
- predictions["outcome"].append(model_dict["outcome"].predict(data))
14
+ outcome_model = self._offloader.load_model(model_dict["outcome"])
15
+ predictions["outcome"].append(_safe_predict(outcome_model, data.copy()))
16
+
13
17
  if self.compevent_colname is not None:
14
- predictions["compevent"].append(model_dict["compevent"].predict(data))
18
+ compevent_model = self._offloader.load_model(model_dict["compevent"])
19
+ predictions["compevent"].append(_safe_predict(compevent_model, data.copy()))
15
20
 
16
21
  return predictions
17
22
 
@@ -41,24 +46,20 @@ def _calculate_risk(self, data, idx=None, val=None):
41
46
  lci = a / 2
42
47
  uci = 1 - lci
43
48
 
49
+ # Pre-compute the followup range once (starts at 1, not 0)
50
+ followup_range = list(range(1, self.followup_max + 1))
51
+
44
52
  SDT = (
45
53
  data.with_columns(
46
- [
47
- (
48
- pl.col(self.id_col).cast(pl.Utf8) + pl.col("trial").cast(pl.Utf8)
49
- ).alias("TID")
50
- ]
54
+ [pl.concat_str([pl.col(self.id_col), pl.col("trial")]).alias("TID")]
51
55
  )
52
56
  .group_by("TID")
53
57
  .first()
54
58
  .drop(["followup", f"followup{self.indicator_squared}"])
55
- .with_columns([pl.lit(list(range(self.followup_max))).alias("followup")])
59
+ .with_columns([pl.lit(followup_range).alias("followup")])
56
60
  .explode("followup")
57
61
  .with_columns(
58
- [
59
- (pl.col("followup") + 1).alias("followup"),
60
- (pl.col("followup") ** 2).alias(f"followup{self.indicator_squared}"),
61
- ]
62
+ [(pl.col("followup") ** 2).alias(f"followup{self.indicator_squared}")]
62
63
  )
63
64
  ).sort([self.id_col, "trial", "followup"])
64
65
 
@@ -370,3 +371,11 @@ def _calculate_survival(self, risk_data):
370
371
  [(1 - pl.col("pred")).alias("pred"), pl.lit("survival").alias("estimate")]
371
372
  )
372
373
  return surv
374
+
375
+
376
+ def _clamp(data):
377
+ """Clamp prediction and CI columns to [0, 1] bounds."""
378
+ cols = ["pred", "LCI", "UCI"]
379
+ exists = [c for c in cols if c in data.columns]
380
+
381
+ return data.with_columns([pl.col(col).clip(0.0, 1.0) for col in exists])
@@ -1,2 +1,2 @@
1
- from ._datachecker import _datachecker as _datachecker
1
+ from ._data_checker import _data_checker as _data_checker
2
2
  from ._param_checker import _param_checker as _param_checker
@@ -1,7 +1,7 @@
1
1
  import polars as pl
2
2
 
3
3
 
4
- def _datachecker(self):
4
+ def _data_checker(self):
5
5
  check = self.data.group_by(self.id_col).agg(
6
6
  [pl.len().alias("row_count"), pl.col(self.time_col).max().alias("max_time")]
7
7
  )
@@ -13,17 +13,10 @@ def _mapper(data, id_col, time_col, min_followup=-math.inf, max_followup=math.in
13
13
  .with_columns([pl.col(id_col).cum_count().over(id_col).sub(1).alias("trial")])
14
14
  .with_columns(
15
15
  [
16
- pl.struct(
17
- [
18
- pl.col(time_col),
19
- pl.col(time_col).max().over(id_col).alias("max_time"),
20
- ]
21
- )
22
- .map_elements(
23
- lambda x: list(range(x[time_col], x["max_time"] + 1)),
24
- return_dtype=pl.List(pl.Int64),
25
- )
26
- .alias("period")
16
+ pl.int_ranges(
17
+ pl.col(time_col),
18
+ pl.col(time_col).max().over(id_col) + 1,
19
+ ).alias("period")
27
20
  ]
28
21
  )
29
22
  .explode("period")
@@ -0,0 +1,44 @@
1
+ import polars as pl
2
+
3
+
4
+ def _random_selection(self):
5
+ """
6
+ Handles the case where random selection is applied for data from
7
+ the __mapper -> __binder -> optionally __dynamic pipeline
8
+ """
9
+ UIDs = (
10
+ self.DT.select(
11
+ [self.id_col, "trial", f"{self.treatment_col}{self.indicator_baseline}"]
12
+ )
13
+ .with_columns(
14
+ (
15
+ pl.col(self.id_col).cast(pl.Utf8) + "_" + pl.col("trial").cast(pl.Utf8)
16
+ ).alias("trialID")
17
+ )
18
+ .filter(
19
+ pl.col(f"{self.treatment_col}{self.indicator_baseline}")
20
+ == self.treatment_level[0]
21
+ )
22
+ .unique("trialID")
23
+ .get_column("trialID")
24
+ .to_list()
25
+ )
26
+
27
+ NIDs = len(UIDs)
28
+ sample = self._rng.choice(
29
+ UIDs, size=int(self.selection_sample * NIDs), replace=False
30
+ )
31
+
32
+ self.DT = (
33
+ self.DT.with_columns(
34
+ (
35
+ pl.col(self.id_col).cast(pl.Utf8) + "_" + pl.col("trial").cast(pl.Utf8)
36
+ ).alias("trialID")
37
+ )
38
+ .filter(
39
+ pl.col("trialID").is_in(sample)
40
+ | pl.col(f"{self.treatment_col}{self.indicator_baseline}")
41
+ != self.treatment_level[0]
42
+ )
43
+ .drop("trialID")
44
+ )
@@ -1,6 +1,7 @@
1
1
  from ._bootstrap import bootstrap_loop as bootstrap_loop
2
2
  from ._col_string import _col_string as _col_string
3
3
  from ._format_time import _format_time as _format_time
4
+ from ._offloader import Offloader as Offloader
4
5
  from ._output_files import _build_md as _build_md
5
6
  from ._output_files import _build_pdf as _build_pdf
6
7
  from ._pad import _pad as _pad