pySEQTarget 0.10.1.tar.gz → 0.12.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/PKG-INFO +6 -1
  2. {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/SEQopts.py +17 -2
  3. {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/SEQuential.py +17 -4
  4. {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/analysis/_hazard.py +9 -4
  5. pyseqtarget-0.12.0/pySEQTarget/analysis/_risk_estimates.py +138 -0
  6. {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/analysis/_survival_pred.py +13 -12
  7. {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/expansion/_mapper.py +4 -11
  8. {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/helpers/__init__.py +1 -0
  9. {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/helpers/_bootstrap.py +39 -4
  10. pyseqtarget-0.12.0/pySEQTarget/helpers/_fix_categories.py +21 -0
  11. pyseqtarget-0.12.0/pySEQTarget/helpers/_offloader.py +82 -0
  12. {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/helpers/_output_files.py +2 -4
  13. pyseqtarget-0.12.0/pySEQTarget/helpers/_predict_model.py +57 -0
  14. {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/weighting/__init__.py +1 -0
  15. {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/weighting/_weight_bind.py +3 -0
  16. pyseqtarget-0.12.0/pySEQTarget/weighting/_weight_fit.py +137 -0
  17. pyseqtarget-0.12.0/pySEQTarget/weighting/_weight_offload.py +19 -0
  18. {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/weighting/_weight_pred.py +74 -52
  19. {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget.egg-info/PKG-INFO +6 -1
  20. {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget.egg-info/SOURCES.txt +4 -0
  21. {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget.egg-info/requires.txt +6 -0
  22. {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pyproject.toml +10 -8
  23. pyseqtarget-0.12.0/tests/test_offload.py +41 -0
  24. {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/tests/test_survival.py +13 -0
  25. pyseqtarget-0.10.1/pySEQTarget/analysis/_risk_estimates.py +0 -136
  26. pyseqtarget-0.10.1/pySEQTarget/helpers/_predict_model.py +0 -9
  27. pyseqtarget-0.10.1/pySEQTarget/weighting/_weight_fit.py +0 -99
  28. {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/LICENSE +0 -0
  29. {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/README.md +0 -0
  30. {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/SEQoutput.py +0 -0
  31. {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/__init__.py +0 -0
  32. {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/analysis/__init__.py +0 -0
  33. {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/analysis/_outcome_fit.py +0 -0
  34. {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/analysis/_subgroup_fit.py +0 -0
  35. {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/data/__init__.py +0 -0
  36. {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/error/__init__.py +0 -0
  37. {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/error/_data_checker.py +0 -0
  38. {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/error/_param_checker.py +0 -0
  39. {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/expansion/__init__.py +0 -0
  40. {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/expansion/_binder.py +0 -0
  41. {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/expansion/_diagnostics.py +0 -0
  42. {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/expansion/_dynamic.py +0 -0
  43. {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/expansion/_selection.py +0 -0
  44. {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/helpers/_col_string.py +0 -0
  45. {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/helpers/_format_time.py +0 -0
  46. {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/helpers/_pad.py +0 -0
  47. {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/helpers/_prepare_data.py +0 -0
  48. {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/initialization/__init__.py +0 -0
  49. {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/initialization/_censoring.py +0 -0
  50. {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/initialization/_denominator.py +0 -0
  51. {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/initialization/_numerator.py +0 -0
  52. {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/initialization/_outcome.py +0 -0
  53. {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/plot/__init__.py +0 -0
  54. {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/plot/_survival_plot.py +0 -0
  55. {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/weighting/_weight_data.py +0 -0
  56. {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/weighting/_weight_stats.py +0 -0
  57. {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget.egg-info/dependency_links.txt +0 -0
  58. {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget.egg-info/top_level.txt +0 -0
  59. {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/setup.cfg +0 -0
  60. {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/tests/test_accessor.py +0 -0
  61. {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/tests/test_coefficients.py +0 -0
  62. {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/tests/test_covariates.py +0 -0
  63. {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/tests/test_followup_options.py +0 -0
  64. {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/tests/test_hazard.py +0 -0
  65. {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/tests/test_parallel.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: pySEQTarget
3
- Version: 0.10.1
3
+ Version: 0.12.0
4
4
  Summary: Sequentially Nested Target Trial Emulation
5
5
  Author-email: Ryan O'Dea <ryan.odea@psi.ch>, Alejandro Szmulewicz <aszmulewicz@hsph.harvard.edu>, Tom Palmer <tom.palmer@bristol.ac.uk>, Miguel Hernan <mhernan@hsph.harvard.edu>
6
6
  Maintainer-email: Ryan O'Dea <ryan.odea@psi.ch>
@@ -33,6 +33,11 @@ Requires-Dist: statsmodels
33
33
  Requires-Dist: matplotlib
34
34
  Requires-Dist: pyarrow
35
35
  Requires-Dist: lifelines
36
+ Requires-Dist: joblib
37
+ Provides-Extra: output
38
+ Requires-Dist: markdown; extra == "output"
39
+ Requires-Dist: weasyprint; extra == "output"
40
+ Requires-Dist: tabulate; extra == "output"
36
41
  Dynamic: license-file
37
42
 
38
43
  # pySEQTarget - Sequentially Nested Target Trial Emulation
@@ -1,4 +1,5 @@
1
1
  import multiprocessing
2
+ import os
2
3
  from dataclasses import dataclass, field
3
4
  from typing import List, Literal, Optional
4
5
 
@@ -18,7 +19,7 @@ class SEQopts:
18
19
  :type bootstrap_CI_method: str
19
20
  :param cense_colname: Column name for censoring effect (LTFU, etc.)
20
21
  :type cense_colname: str
21
- :param cense_denominator: Override to specify denominator patsy formula for censoring models
22
+ :param cense_denominator: Override to specify denominator patsy formula for censoring models; "1" or "" indicate intercept only model
22
23
  :type cense_denominator: Optional[str] or None
23
24
  :param cense_numerator: Override to specify numerator patsy formula for censoring models
24
25
  :type cense_numerator: Optional[str] or None
@@ -54,8 +55,12 @@ class SEQopts:
54
55
  :type km_curves: bool
55
56
  :param ncores: Number of cores to use if running in parallel
56
57
  :type ncores: int
57
- :param numerator: Override to specify the outcome patsy formula for numerator models
58
+ :param numerator: Override to specify the outcome patsy formula for numerator models; "1" or "" indicate intercept only model
58
59
  :type numerator: str
60
+ :param offload: Boolean to offload intermediate model data to disk
61
+ :type offload: bool
62
+ :param offload_dir: Directory to offload intermediate model data
63
+ :type offload_dir: str
59
64
  :param parallel: Boolean to run model fitting in parallel
60
65
  :type parallel: bool
61
66
  :param plot_colors: List of colors for KM plots, if applicable
@@ -80,8 +85,12 @@ class SEQopts:
80
85
  :type treatment_level: List[int]
81
86
  :param trial_include: Boolean to force trial values into model covariates
82
87
  :type trial_include: bool
88
+ :param visit_colname: Column name specifying visit number
89
+ :type visit_colname: str
83
90
  :param weight_eligible_colnames: List of column names of length treatment_level to identify which rows are eligible for weight fitting
84
91
  :type weight_eligible_colnames: List[str]
92
+ :param weight_fit_method: The fitting method to be used ["newton", "bfgs", "lbfgs", "nm"], default "newton"
93
+ :type weight_fit_method: str
85
94
  :param weight_min: Minimum weight
86
95
  :type weight_min: float
87
96
  :param weight_max: Maximum weight
@@ -120,6 +129,8 @@ class SEQopts:
120
129
  km_curves: bool = False
121
130
  ncores: int = multiprocessing.cpu_count()
122
131
  numerator: Optional[str] = None
132
+ offload: bool = False
133
+ offload_dir: str = "_seq_models"
123
134
  parallel: bool = False
124
135
  plot_colors: List[str] = field(
125
136
  default_factory=lambda: ["#F8766D", "#00BFC4", "#555555"]
@@ -136,6 +147,7 @@ class SEQopts:
136
147
  trial_include: bool = True
137
148
  visit_colname: str = None
138
149
  weight_eligible_colnames: List[str] = field(default_factory=lambda: [])
150
+ weight_fit_method: Literal["newton", "bfgs", "lbfgs", "nm"] = "newton"
139
151
  weight_min: float = 0.0
140
152
  weight_max: float = None
141
153
  weight_lag_condition: bool = True
@@ -195,3 +207,6 @@ class SEQopts:
195
207
  attr = getattr(self, i)
196
208
  if attr is not None and not isinstance(attr, list):
197
209
  setattr(self, i, "".join(attr.split()))
210
+
211
+ if self.offload:
212
+ os.makedirs(self.offload_dir, exist_ok=True)
@@ -12,15 +12,15 @@ from .analysis import (_calculate_hazard, _calculate_survival, _clamp,
12
12
  _subgroup_fit)
13
13
  from .error import _data_checker, _param_checker
14
14
  from .expansion import _binder, _diagnostics, _dynamic, _random_selection
15
- from .helpers import _col_string, _format_time, bootstrap_loop
15
+ from .helpers import Offloader, _col_string, _format_time, bootstrap_loop
16
16
  from .initialization import (_cense_denominator, _cense_numerator,
17
17
  _denominator, _numerator, _outcome)
18
18
  from .plot import _survival_plot
19
19
  from .SEQopts import SEQopts
20
20
  from .SEQoutput import SEQoutput
21
21
  from .weighting import (_fit_denominator, _fit_LTFU, _fit_numerator,
22
- _fit_visit, _weight_bind, _weight_predict,
23
- _weight_setup, _weight_stats)
22
+ _fit_visit, _offload_weights, _weight_bind,
23
+ _weight_predict, _weight_setup, _weight_stats)
24
24
 
25
25
 
26
26
  class SEQuential:
@@ -84,6 +84,8 @@ class SEQuential:
84
84
  np.random.RandomState(self.seed) if self.seed is not None else np.random
85
85
  )
86
86
 
87
+ self._offloader = Offloader(enabled=self.offload, dir=self.offload_dir)
88
+
87
89
  if self.covariates is None:
88
90
  self.covariates = _outcome(self)
89
91
 
@@ -201,6 +203,9 @@ class SEQuential:
201
203
  raise ValueError(
202
204
  "Bootstrap sampling not found. Please run the 'bootstrap' method before fitting with bootstrapping."
203
205
  )
206
+ boot_idx = None
207
+ if hasattr(self, "_current_boot_idx"):
208
+ boot_idx = self._current_boot_idx
204
209
 
205
210
  if self.weighted:
206
211
  WDT = _weight_setup(self)
@@ -217,6 +222,9 @@ class SEQuential:
217
222
  _fit_numerator(self, WDT)
218
223
  _fit_denominator(self, WDT)
219
224
 
225
+ if self.offload:
226
+ _offload_weights(self, boot_idx)
227
+
220
228
  WDT = pl.from_pandas(WDT)
221
229
  WDT = _weight_predict(self, WDT)
222
230
  _weight_bind(self, WDT)
@@ -244,6 +252,11 @@ class SEQuential:
244
252
  self.weighted,
245
253
  "weight",
246
254
  )
255
+ if self.offload:
256
+ offloaded_models = {}
257
+ for key, model in models.items():
258
+ offloaded_models[key] = self._offloader.save_model(model, key, boot_idx)
259
+ return offloaded_models
247
260
  return models
248
261
 
249
262
  def survival(self, **kwargs) -> None:
@@ -342,7 +355,7 @@ class SEQuential:
342
355
  }
343
356
 
344
357
  if self.compevent_colname is not None:
345
- compevent_models = [model["compevent"] for model in self.outcome_models]
358
+ compevent_models = [model["compevent"] for model in self.outcome_model]
346
359
  else:
347
360
  compevent_models = None
348
361
 
@@ -4,6 +4,8 @@ import numpy as np
4
4
  import polars as pl
5
5
  from lifelines import CoxPHFitter
6
6
 
7
+ from ..helpers._predict_model import _safe_predict
8
+
7
9
 
8
10
  def _calculate_hazard(self):
9
11
  if self.subgroup_colname is None:
@@ -93,8 +95,10 @@ def _hazard_handler(self, data, idx, boot_idx, rng):
93
95
  else:
94
96
  model_dict = self.outcome_model[boot_idx]
95
97
 
96
- outcome_model = model_dict["outcome"]
97
- ce_model = model_dict.get("compevent", None) if self.compevent_colname else None
98
+ outcome_model = self._offloader.load_model(model_dict["outcome"])
99
+ ce_model = None
100
+ if self.compevent_colname and "compevent" in model_dict:
101
+ ce_model = self._offloader.load_model(model_dict["compevent"])
98
102
 
99
103
  all_treatments = []
100
104
  for val in self.treatment_level:
@@ -103,13 +107,14 @@ def _hazard_handler(self, data, idx, boot_idx, rng):
103
107
  )
104
108
 
105
109
  tmp_pd = tmp.to_pandas()
106
- outcome_prob = outcome_model.predict(tmp_pd)
110
+ outcome_prob = _safe_predict(outcome_model, tmp_pd)
107
111
  outcome_sim = rng.binomial(1, outcome_prob)
108
112
 
109
113
  tmp = tmp.with_columns([pl.Series("outcome", outcome_sim)])
110
114
 
111
115
  if ce_model is not None:
112
- ce_prob = ce_model.predict(tmp_pd)
116
+ ce_tmp_pd = tmp.to_pandas()
117
+ ce_prob = _safe_predict(ce_model, ce_tmp_pd)
113
118
  ce_sim = rng.binomial(1, ce_prob)
114
119
  tmp = tmp.with_columns([pl.Series("ce", ce_sim)])
115
120
 
@@ -0,0 +1,138 @@
1
+ import polars as pl
2
+ from scipy import stats
3
+
4
+
5
+ def _compute_rd_rr(comp, has_bootstrap, z=None, group_cols=None):
6
+ """
7
+ Compute Risk Difference and Risk Ratio from a comparison dataframe.
8
+ Consolidates the repeated calculation logic.
9
+ """
10
+ if group_cols is None:
11
+ group_cols = []
12
+
13
+ if has_bootstrap:
14
+ rd_se = (pl.col("se_x").pow(2) + pl.col("se_y").pow(2)).sqrt()
15
+ rd_comp = comp.with_columns(
16
+ [
17
+ (pl.col("risk_x") - pl.col("risk_y")).alias("Risk Difference"),
18
+ (pl.col("risk_x") - pl.col("risk_y") - z * rd_se).alias("RD 95% LCI"),
19
+ (pl.col("risk_x") - pl.col("risk_y") + z * rd_se).alias("RD 95% UCI"),
20
+ ]
21
+ )
22
+ rd_comp = rd_comp.drop(["risk_x", "risk_y", "se_x", "se_y"])
23
+ col_order = group_cols + [
24
+ "A_x",
25
+ "A_y",
26
+ "Risk Difference",
27
+ "RD 95% LCI",
28
+ "RD 95% UCI",
29
+ ]
30
+ rd_comp = rd_comp.select([c for c in col_order if c in rd_comp.columns])
31
+
32
+ rr_log_se = (
33
+ (pl.col("se_x") / pl.col("risk_x")).pow(2)
34
+ + (pl.col("se_y") / pl.col("risk_y")).pow(2)
35
+ ).sqrt()
36
+ rr_comp = comp.with_columns(
37
+ [
38
+ (pl.col("risk_x") / pl.col("risk_y")).alias("Risk Ratio"),
39
+ ((pl.col("risk_x") / pl.col("risk_y")) * (-z * rr_log_se).exp()).alias(
40
+ "RR 95% LCI"
41
+ ),
42
+ ((pl.col("risk_x") / pl.col("risk_y")) * (z * rr_log_se).exp()).alias(
43
+ "RR 95% UCI"
44
+ ),
45
+ ]
46
+ )
47
+ rr_comp = rr_comp.drop(["risk_x", "risk_y", "se_x", "se_y"])
48
+ col_order = group_cols + [
49
+ "A_x",
50
+ "A_y",
51
+ "Risk Ratio",
52
+ "RR 95% LCI",
53
+ "RR 95% UCI",
54
+ ]
55
+ rr_comp = rr_comp.select([c for c in col_order if c in rr_comp.columns])
56
+ else:
57
+ rd_comp = comp.with_columns(
58
+ (pl.col("risk_x") - pl.col("risk_y")).alias("Risk Difference")
59
+ )
60
+ rd_comp = rd_comp.drop(["risk_x", "risk_y"])
61
+ col_order = group_cols + ["A_x", "A_y", "Risk Difference"]
62
+ rd_comp = rd_comp.select([c for c in col_order if c in rd_comp.columns])
63
+
64
+ rr_comp = comp.with_columns(
65
+ (pl.col("risk_x") / pl.col("risk_y")).alias("Risk Ratio")
66
+ )
67
+ rr_comp = rr_comp.drop(["risk_x", "risk_y"])
68
+ col_order = group_cols + ["A_x", "A_y", "Risk Ratio"]
69
+ rr_comp = rr_comp.select([c for c in col_order if c in rr_comp.columns])
70
+
71
+ return rd_comp, rr_comp
72
+
73
+
74
+ def _risk_estimates(self):
75
+ last_followup = self.km_data["followup"].max()
76
+ risk = self.km_data.filter(
77
+ (pl.col("followup") == last_followup) & (pl.col("estimate") == "risk")
78
+ )
79
+
80
+ group_cols = [self.subgroup_colname] if self.subgroup_colname else []
81
+ has_bootstrap = self.bootstrap_nboot > 0
82
+
83
+ if has_bootstrap:
84
+ alpha = 1 - self.bootstrap_CI
85
+ z = stats.norm.ppf(1 - alpha / 2)
86
+ else:
87
+ z = None
88
+
89
+ # Pre-extract data for each treatment level once (avoid repeated filtering)
90
+ risk_by_level = {}
91
+ for tx in self.treatment_level:
92
+ level_data = risk.filter(pl.col(self.treatment_col) == tx)
93
+ risk_by_level[tx] = {
94
+ "pred": level_data.select(group_cols + ["pred"]),
95
+ }
96
+ if has_bootstrap:
97
+ risk_by_level[tx]["SE"] = level_data.select(group_cols + ["SE"])
98
+
99
+ rd_comparisons = []
100
+ rr_comparisons = []
101
+
102
+ for tx_x in self.treatment_level:
103
+ for tx_y in self.treatment_level:
104
+ if tx_x == tx_y:
105
+ continue
106
+
107
+ # Use pre-extracted data instead of filtering again
108
+ risk_x = risk_by_level[tx_x]["pred"].rename({"pred": "risk_x"})
109
+ risk_y = risk_by_level[tx_y]["pred"].rename({"pred": "risk_y"})
110
+
111
+ if group_cols:
112
+ comp = risk_x.join(risk_y, on=group_cols, how="left")
113
+ else:
114
+ comp = risk_x.join(risk_y, how="cross")
115
+
116
+ comp = comp.with_columns(
117
+ [pl.lit(tx_x).alias("A_x"), pl.lit(tx_y).alias("A_y")]
118
+ )
119
+
120
+ if has_bootstrap:
121
+ se_x = risk_by_level[tx_x]["SE"].rename({"SE": "se_x"})
122
+ se_y = risk_by_level[tx_y]["SE"].rename({"SE": "se_y"})
123
+
124
+ if group_cols:
125
+ comp = comp.join(se_x, on=group_cols, how="left")
126
+ comp = comp.join(se_y, on=group_cols, how="left")
127
+ else:
128
+ comp = comp.join(se_x, how="cross")
129
+ comp = comp.join(se_y, how="cross")
130
+
131
+ rd_comp, rr_comp = _compute_rd_rr(comp, has_bootstrap, z, group_cols)
132
+ rd_comparisons.append(rd_comp)
133
+ rr_comparisons.append(rr_comp)
134
+
135
+ risk_difference = pl.concat(rd_comparisons) if rd_comparisons else pl.DataFrame()
136
+ risk_ratio = pl.concat(rr_comparisons) if rr_comparisons else pl.DataFrame()
137
+
138
+ return {"risk_difference": risk_difference, "risk_ratio": risk_ratio}
@@ -1,5 +1,7 @@
1
1
  import polars as pl
2
2
 
3
+ from ..helpers._predict_model import _safe_predict
4
+
3
5
 
4
6
  def _get_outcome_predictions(self, TxDT, idx=None):
5
7
  data = TxDT.to_pandas()
@@ -9,9 +11,12 @@ def _get_outcome_predictions(self, TxDT, idx=None):
9
11
 
10
12
  for boot_model in self.outcome_model:
11
13
  model_dict = boot_model[idx] if idx is not None else boot_model
12
- predictions["outcome"].append(model_dict["outcome"].predict(data))
14
+ outcome_model = self._offloader.load_model(model_dict["outcome"])
15
+ predictions["outcome"].append(_safe_predict(outcome_model, data.copy()))
16
+
13
17
  if self.compevent_colname is not None:
14
- predictions["compevent"].append(model_dict["compevent"].predict(data))
18
+ compevent_model = self._offloader.load_model(model_dict["compevent"])
19
+ predictions["compevent"].append(_safe_predict(compevent_model, data.copy()))
15
20
 
16
21
  return predictions
17
22
 
@@ -41,24 +46,20 @@ def _calculate_risk(self, data, idx=None, val=None):
41
46
  lci = a / 2
42
47
  uci = 1 - lci
43
48
 
49
+ # Pre-compute the followup range once (starts at 1, not 0)
50
+ followup_range = list(range(1, self.followup_max + 1))
51
+
44
52
  SDT = (
45
53
  data.with_columns(
46
- [
47
- (
48
- pl.col(self.id_col).cast(pl.Utf8) + pl.col("trial").cast(pl.Utf8)
49
- ).alias("TID")
50
- ]
54
+ [pl.concat_str([pl.col(self.id_col), pl.col("trial")]).alias("TID")]
51
55
  )
52
56
  .group_by("TID")
53
57
  .first()
54
58
  .drop(["followup", f"followup{self.indicator_squared}"])
55
- .with_columns([pl.lit(list(range(self.followup_max))).alias("followup")])
59
+ .with_columns([pl.lit(followup_range).alias("followup")])
56
60
  .explode("followup")
57
61
  .with_columns(
58
- [
59
- (pl.col("followup") + 1).alias("followup"),
60
- (pl.col("followup") ** 2).alias(f"followup{self.indicator_squared}"),
61
- ]
62
+ [(pl.col("followup") ** 2).alias(f"followup{self.indicator_squared}")]
62
63
  )
63
64
  ).sort([self.id_col, "trial", "followup"])
64
65
 
@@ -13,17 +13,10 @@ def _mapper(data, id_col, time_col, min_followup=-math.inf, max_followup=math.in
13
13
  .with_columns([pl.col(id_col).cum_count().over(id_col).sub(1).alias("trial")])
14
14
  .with_columns(
15
15
  [
16
- pl.struct(
17
- [
18
- pl.col(time_col),
19
- pl.col(time_col).max().over(id_col).alias("max_time"),
20
- ]
21
- )
22
- .map_elements(
23
- lambda x: list(range(x[time_col], x["max_time"] + 1)),
24
- return_dtype=pl.List(pl.Int64),
25
- )
26
- .alias("period")
16
+ pl.int_ranges(
17
+ pl.col(time_col),
18
+ pl.col(time_col).max().over(id_col) + 1,
19
+ ).alias("period")
27
20
  ]
28
21
  )
29
22
  .explode("period")
@@ -1,6 +1,7 @@
1
1
  from ._bootstrap import bootstrap_loop as bootstrap_loop
2
2
  from ._col_string import _col_string as _col_string
3
3
  from ._format_time import _format_time as _format_time
4
+ from ._offloader import Offloader as Offloader
4
5
  from ._output_files import _build_md as _build_md
5
6
  from ._output_files import _build_pdf as _build_pdf
6
7
  from ._pad import _pad as _pad
@@ -35,11 +35,28 @@ def _prepare_boot_data(self, data, boot_id):
35
35
 
36
36
 
37
37
  def _bootstrap_worker(obj, method_name, original_DT, i, seed, args, kwargs):
38
- obj = copy.deepcopy(obj)
38
+ # Shallow copy the object and only deep copy mutable state that changes per-bootstrap
39
+ obj = copy.copy(obj)
40
+ # Deep copy only the mutable attributes that get modified during fitting
41
+ obj.outcome_model = []
42
+ obj.numerator_model = (
43
+ copy.copy(obj.numerator_model)
44
+ if hasattr(obj, "numerator_model") and obj.numerator_model
45
+ else []
46
+ )
47
+ obj.denominator_model = (
48
+ copy.copy(obj.denominator_model)
49
+ if hasattr(obj, "denominator_model") and obj.denominator_model
50
+ else []
51
+ )
52
+
39
53
  obj._rng = (
40
54
  np.random.RandomState(seed + i) if seed is not None else np.random.RandomState()
41
55
  )
56
+ original_DT = obj._offloader.load_dataframe(original_DT)
42
57
  obj.DT = _prepare_boot_data(obj, original_DT, i)
58
+ del original_DT
59
+ obj._current_boot_idx = i + 1
43
60
 
44
61
  # Disable bootstrapping to prevent recursion
45
62
  obj.bootstrap_nboot = 0
@@ -60,6 +77,7 @@ def bootstrap_loop(method):
60
77
  results = []
61
78
  original_DT = self.DT
62
79
 
80
+ self._current_boot_idx = None
63
81
  full = method(self, *args, **kwargs)
64
82
  results.append(full)
65
83
 
@@ -71,9 +89,12 @@ def bootstrap_loop(method):
71
89
  seed = getattr(self, "seed", None)
72
90
  method_name = method.__name__
73
91
 
92
+ original_DT_ref = self._offloader.save_dataframe(original_DT, "_DT")
93
+
74
94
  if getattr(self, "parallel", False):
75
95
  original_rng = getattr(self, "_rng", None)
76
96
  self._rng = None
97
+ self.DT = None
77
98
 
78
99
  with ProcessPoolExecutor(max_workers=ncores) as executor:
79
100
  futures = [
@@ -81,7 +102,7 @@ def bootstrap_loop(method):
81
102
  _bootstrap_worker,
82
103
  self,
83
104
  method_name,
84
- original_DT,
105
+ original_DT_ref,
85
106
  i,
86
107
  seed,
87
108
  args,
@@ -95,13 +116,27 @@ def bootstrap_loop(method):
95
116
  results.append(j.result())
96
117
 
97
118
  self._rng = original_rng
119
+ self.DT = self._offloader.load_dataframe(original_DT_ref)
98
120
  else:
121
+ # Keep original data in memory if offloading is disabled to avoid unnecessary I/O
122
+ if self._offloader.enabled:
123
+ original_DT_ref = self._offloader.save_dataframe(original_DT, "_DT")
124
+ del original_DT
125
+ else:
126
+ original_DT_ref = original_DT
127
+
99
128
  for i in tqdm(range(nboot), desc="Bootstrapping..."):
100
- self.DT = _prepare_boot_data(self, original_DT, i)
129
+ self._current_boot_idx = i + 1
130
+ tmp = self._offloader.load_dataframe(original_DT_ref)
131
+ self.DT = _prepare_boot_data(self, tmp, i)
132
+ if self._offloader.enabled:
133
+ del tmp
134
+ self.bootstrap_nboot = 0
101
135
  boot_fit = method(self, *args, **kwargs)
102
136
  results.append(boot_fit)
103
137
 
104
- self.DT = original_DT
138
+ self.bootstrap_nboot = nboot
139
+ self.DT = self._offloader.load_dataframe(original_DT_ref)
105
140
 
106
141
  end = time.perf_counter()
107
142
  self._model_time = _format_time(start, end)
@@ -0,0 +1,21 @@
1
+ def _fix_categories_for_predict(model, newdata):
2
+ """
3
+ Fix categorical column ordering in newdata to match what the model expects.
4
+ """
5
+ if (
6
+ hasattr(model, "model")
7
+ and hasattr(model.model, "data")
8
+ and hasattr(model.model.data, "design_info")
9
+ ):
10
+ design_info = model.model.data.design_info
11
+ for factor, factor_info in design_info.factor_infos.items():
12
+ if factor_info.type == "categorical":
13
+ col_name = factor.name()
14
+ if col_name in newdata.columns:
15
+ expected_categories = list(factor_info.categories)
16
+ newdata[col_name] = newdata[col_name].astype(str)
17
+ newdata[col_name] = newdata[col_name].astype("category")
18
+ newdata[col_name] = newdata[col_name].cat.set_categories(
19
+ expected_categories
20
+ )
21
+ return newdata
@@ -0,0 +1,82 @@
1
+ from functools import lru_cache
2
+ from pathlib import Path
3
+ from typing import Any, Optional, Union
4
+
5
+ import joblib
6
+ import polars as pl
7
+
8
+
9
+ class Offloader:
10
+ """Manages disk-based storage for models and intermediate data"""
11
+
12
+ def __init__(self, enabled: bool, dir: str, compression: int = 3):
13
+ self.enabled = enabled
14
+ self.dir = Path(dir)
15
+ self.compression = compression
16
+ # Create a cached loader bound to this instance
17
+ self._init_cache()
18
+
19
+ def _init_cache(self):
20
+ """Initialize the LRU cache for model loading."""
21
+ self._cached_load = lru_cache(maxsize=32)(self._load_from_disk)
22
+
23
+ def __getstate__(self):
24
+ """Prepare state for pickling - exclude the unpicklable cache."""
25
+ state = self.__dict__.copy()
26
+ # Remove the cache wrapper which can't be pickled
27
+ del state["_cached_load"]
28
+ return state
29
+
30
+ def __setstate__(self, state):
31
+ """Restore state after unpickling - recreate the cache."""
32
+ self.__dict__.update(state)
33
+ # Recreate the cache after unpickling
34
+ self._init_cache()
35
+
36
+ def save_model(
37
+ self, model: Any, name: str, boot_idx: Optional[int] = None
38
+ ) -> Union[Any, str]:
39
+ """Save a fitted model to disk and return a reference"""
40
+ if not self.enabled:
41
+ return model
42
+
43
+ filename = (
44
+ f"{name}_boot{boot_idx}.pkl" if boot_idx is not None else f"{name}.pkl"
45
+ )
46
+ filepath = self.dir / filename
47
+
48
+ joblib.dump(model, filepath, compress=self.compression)
49
+
50
+ return str(filepath)
51
+
52
+ def _load_from_disk(self, filepath: str) -> Any:
53
+ """Internal method to load a model from disk (cached)."""
54
+ return joblib.load(filepath)
55
+
56
+ def load_model(self, ref: Union[Any, str]) -> Any:
57
+ """Load a model, using cache for repeated loads of the same file."""
58
+ if not self.enabled or not isinstance(ref, str):
59
+ return ref
60
+
61
+ return self._cached_load(ref)
62
+
63
+ def clear_cache(self) -> None:
64
+ """Clear the model loading cache. Call between bootstrap iterations if needed."""
65
+ self._cached_load.cache_clear()
66
+
67
+ def save_dataframe(self, df: pl.DataFrame, name: str) -> Union[pl.DataFrame, str]:
68
+ if not self.enabled:
69
+ return df
70
+
71
+ filename = f"{name}.parquet"
72
+ filepath = self.dir / filename
73
+
74
+ df.write_parquet(filepath, compression="zstd")
75
+
76
+ return str(filepath)
77
+
78
+ def load_dataframe(self, ref: Union[pl.DataFrame, str]) -> pl.DataFrame:
79
+ if not self.enabled or not isinstance(ref, str):
80
+ return ref
81
+
82
+ return pl.read_parquet(ref)
@@ -121,8 +121,7 @@ def _build_pdf(md_content: str, filename: str, img_path: str = None) -> None:
121
121
  f'src="{img_name}"', f'src="file://{img_path}"'
122
122
  )
123
123
 
124
- css = CSS(
125
- string="""
124
+ css = CSS(string="""
126
125
  body {
127
126
  font-family: Arial, sans-serif;
128
127
  font-size: 11pt;
@@ -153,8 +152,7 @@ def _build_pdf(md_content: str, filename: str, img_path: str = None) -> None:
153
152
  }
154
153
  code { font-family: 'Courier New', monospace; }
155
154
  img { max-width: 100%; height: auto; }
156
- """
157
- )
155
+ """)
158
156
 
159
157
  full_html = f"""
160
158
  <!DOCTYPE html>