pySEQTarget 0.13.2__tar.gz → 0.13.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69) hide show
  1. {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/PKG-INFO +1 -1
  2. {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/SEQopts.py +8 -1
  3. {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/SEQoutput.py +1 -1
  4. {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/SEQuential.py +54 -2
  5. {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/analysis/_hazard.py +1 -1
  6. {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/analysis/_outcome_fit.py +49 -6
  7. {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/analysis/_subgroup_fit.py +10 -2
  8. {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/analysis/_survival_pred.py +64 -2
  9. {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/error/_data_checker.py +2 -2
  10. {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/error/_param_checker.py +7 -2
  11. {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/expansion/_binder.py +1 -1
  12. {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/expansion/_dynamic.py +1 -1
  13. {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget.egg-info/PKG-INFO +1 -1
  14. {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pyproject.toml +1 -1
  15. {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/tests/test_expansion.py +57 -0
  16. {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/tests/test_followup_options.py +12 -11
  17. {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/LICENSE +0 -0
  18. {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/README.md +0 -0
  19. {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/__init__.py +0 -0
  20. {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/analysis/__init__.py +0 -0
  21. {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/analysis/_risk_estimates.py +0 -0
  22. {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/data/__init__.py +0 -0
  23. {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/error/__init__.py +0 -0
  24. {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/error/_check_separation.py +0 -0
  25. {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/expansion/__init__.py +0 -0
  26. {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/expansion/_diagnostics.py +0 -0
  27. {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/expansion/_mapper.py +0 -0
  28. {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/expansion/_selection.py +0 -0
  29. {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/helpers/__init__.py +0 -0
  30. {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/helpers/_bootstrap.py +0 -0
  31. {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/helpers/_col_string.py +0 -0
  32. {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/helpers/_fix_categories.py +0 -0
  33. {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/helpers/_format_time.py +0 -0
  34. {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/helpers/_offloader.py +0 -0
  35. {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/helpers/_output_files.py +0 -0
  36. {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/helpers/_pad.py +0 -0
  37. {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/helpers/_predict_model.py +0 -0
  38. {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/helpers/_prepare_data.py +0 -0
  39. {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/initialization/__init__.py +0 -0
  40. {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/initialization/_censoring.py +0 -0
  41. {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/initialization/_denominator.py +0 -0
  42. {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/initialization/_numerator.py +0 -0
  43. {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/initialization/_outcome.py +0 -0
  44. {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/plot/__init__.py +0 -0
  45. {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/plot/_survival_plot.py +0 -0
  46. {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/weighting/__init__.py +0 -0
  47. {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/weighting/_weight_bind.py +0 -0
  48. {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/weighting/_weight_data.py +0 -0
  49. {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/weighting/_weight_fit.py +0 -0
  50. {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/weighting/_weight_offload.py +0 -0
  51. {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/weighting/_weight_pred.py +0 -0
  52. {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/weighting/_weight_stats.py +0 -0
  53. {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget.egg-info/SOURCES.txt +0 -0
  54. {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget.egg-info/dependency_links.txt +0 -0
  55. {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget.egg-info/requires.txt +0 -0
  56. {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget.egg-info/top_level.txt +0 -0
  57. {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/setup.cfg +0 -0
  58. {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/tests/test_accessor.py +0 -0
  59. {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/tests/test_check_separation.py +0 -0
  60. {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/tests/test_coefficients.py +0 -0
  61. {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/tests/test_covariates.py +0 -0
  62. {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/tests/test_fix_categories.py +0 -0
  63. {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/tests/test_hazard.py +0 -0
  64. {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/tests/test_no_variation.py +0 -0
  65. {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/tests/test_offload.py +0 -0
  66. {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/tests/test_parallel.py +0 -0
  67. {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/tests/test_plot.py +0 -0
  68. {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/tests/test_reproducibility.py +0 -0
  69. {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/tests/test_survival.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: pySEQTarget
3
- Version: 0.13.2
3
+ Version: 0.13.4
4
4
  Summary: Sequentially Nested Target Trial Emulation
5
5
  Author-email: Ryan O'Dea <ryan.odea@psi.ch>, Alejandro Szmulewicz <aszmulewicz@hsph.harvard.edu>, Tom Palmer <tom.palmer@bristol.ac.uk>, Miguel Hernán <mhernan@hsph.harvard.edu>
6
6
  Maintainer-email: Ryan O'Dea <ryan.odea@psi.ch>, Tom Palmer <remlapmot@hotmail.com>
@@ -15,7 +15,7 @@ class SEQopts:
15
15
  :type bootstrap_sample: float
16
16
  :param bootstrap_CI: If bootstrapped, confidence interval level
17
17
  :type bootstrap_CI: float
18
- :param bootstrap_CI_method: If bootstrapped, confidence method generation method ['SE' or 'percentile']
18
+ :param bootstrap_CI_method: If bootstrapped, confidence interval method ['SE' or 'percentile']
19
19
  :type bootstrap_CI_method: str
20
20
  :param cense_colname: Column name for censoring effect (LTFU, etc.)
21
21
  :type cense_colname: str
@@ -45,6 +45,8 @@ class SEQopts:
45
45
  :type followup_include: bool
46
46
  :param followup_spline: Boolean to force followup values to be fit to cubic spline
47
47
  :type followup_spline: bool
48
+ :param followup_spline_df: Degrees of freedom for the followup cubic spline, default ``4``
49
+ :type followup_spline_df: int
48
50
  :param followup_max: Maximum allowed followup in analysis
49
51
  :type followup_max: int or None
50
52
  :param followup_min: Minimum allowed followup in analysis
@@ -107,6 +109,8 @@ class SEQopts:
107
109
  :type weight_p99: bool
108
110
  :param weight_preexpansion: Boolean to fit weights on preexpanded data
109
111
  :type weight_preexpansion: bool
112
+ :param verbose: Boolean to print dataset size summaries and bootstrap information
113
+ :type verbose: bool
110
114
  :param weighted: Boolean to weight analysis
111
115
  :type weighted: bool
112
116
  """
@@ -130,6 +134,7 @@ class SEQopts:
130
134
  followup_max: int = None
131
135
  followup_min: int = 0
132
136
  followup_spline: bool = False
137
+ followup_spline_df: int = 4
133
138
  hazard_estimate: bool = False
134
139
  indicator_baseline: str = "_bas"
135
140
  indicator_squared: str = "_sq"
@@ -160,6 +165,7 @@ class SEQopts:
160
165
  weight_lag_condition: bool = True
161
166
  weight_p99: bool = False
162
167
  weight_preexpansion: bool = True
168
+ verbose: bool = False
163
169
  weighted: bool = False
164
170
 
165
171
  def _validate_bools(self):
@@ -175,6 +181,7 @@ class SEQopts:
175
181
  "selection_first_trial",
176
182
  "selection_random",
177
183
  "trial_include",
184
+ "verbose",
178
185
  "weight_lag_condition",
179
186
  "weight_p99",
180
187
  "weight_preexpansion",
@@ -62,7 +62,7 @@ class SEQoutput:
62
62
 
63
63
  def plot(self) -> None:
64
64
  """
65
- Displays the kaplan-meier graph
65
+ Displays the Kaplan-Meier graph
66
66
  """
67
67
  if self.km_graph is None:
68
68
  raise ValueError(
@@ -114,6 +114,13 @@ class SEQuential:
114
114
  :class:`polars.DataFrame` and skips all subsequent analysis steps.
115
115
  """
116
116
  start = time.perf_counter()
117
+
118
+ if self.verbose:
119
+ n, m = self.data.shape
120
+ print(f"Full dataset: {n:,} observations, {m} variables")
121
+ n_elig = self.data.filter(pl.col(self.eligible_col) == 1).shape[0]
122
+ print(f"Eligible observations: {n_elig:,}")
123
+
117
124
  kept = [
118
125
  self.cense_colname,
119
126
  self.cense_eligible_colname,
@@ -162,14 +169,25 @@ class SEQuential:
162
169
  pl.col(self.id_col).cast(pl.Utf8).alias(self.id_col)
163
170
  )
164
171
 
172
+ if self.verbose:
173
+ n, m = self.DT.shape
174
+ print(f"Expanded dataset: {n:,} observations, {m} variables")
175
+
165
176
  if self.method == "dose-response" or (
166
177
  self.method == "censoring" and not self.expand_only
167
178
  ):
168
179
  _dynamic(self)
169
180
  if self.selection_random:
170
181
  _random_selection(self)
182
+ if self.verbose:
183
+ n, m = self.DT.shape
184
+ print(f"Sampled expanded dataset: {n:,} observations, {m} variables")
171
185
  _diagnostics(self)
172
186
 
187
+ if self.verbose:
188
+ n, m = self.DT.shape
189
+ print(f"Final analysis dataset: {n:,} observations, {m} variables")
190
+
173
191
  end = time.perf_counter()
174
192
  self._expansion_time = _format_time(start, end)
175
193
 
@@ -200,6 +218,16 @@ class SEQuential:
200
218
  )
201
219
  NIDs = len(UIDs)
202
220
 
221
+ if self.verbose:
222
+ n_sample = round(self.bootstrap_sample * NIDs)
223
+ n_obs_sample = round(self.bootstrap_sample * len(self.DT))
224
+ print(
225
+ f"Bootstrapping with {self.bootstrap_sample * 100:.4g}% of "
226
+ f"{NIDs:,} subjects "
227
+ f"({n_sample:,} subjects, ~{n_obs_sample:,} observations per resample) "
228
+ f"{self.bootstrap_nboot} times"
229
+ )
230
+
203
231
  self._boot_samples = []
204
232
  for _ in range(self.bootstrap_nboot):
205
233
  sampled_IDs = self._rng.choice(
@@ -244,8 +272,23 @@ class SEQuential:
244
272
  _weight_bind(self, WDT)
245
273
  self.weight_stats = _weight_stats(self)
246
274
 
275
+ is_boot = boot_idx is not None
276
+ start = getattr(self, "_outcome_start_params", None) if is_boot else None
277
+
247
278
  if self.subgroup_colname is not None:
248
- return _subgroup_fit(self)
279
+ models_list = _subgroup_fit(self, start_params=start)
280
+ if not is_boot:
281
+ self._outcome_start_params = {
282
+ val: {
283
+ key: (m.params.values, list(m.model.exog_names))
284
+ for key, m in sg.items()
285
+ }
286
+ for val, sg in zip(self._unique_subgroups, models_list)
287
+ }
288
+ return models_list
289
+
290
+ start_outcome = (start or {}).get("outcome")
291
+ start_compevent = (start or {}).get("compevent")
249
292
 
250
293
  models = {
251
294
  "outcome": _outcome_fit(
@@ -255,6 +298,7 @@ class SEQuential:
255
298
  self.covariates,
256
299
  self.weighted,
257
300
  "weight",
301
+ start_params=start_outcome,
258
302
  )
259
303
  }
260
304
  if self.compevent_colname is not None:
@@ -265,7 +309,15 @@ class SEQuential:
265
309
  self.covariates,
266
310
  self.weighted,
267
311
  "weight",
312
+ start_params=start_compevent,
268
313
  )
314
+
315
+ if not is_boot:
316
+ self._outcome_start_params = {
317
+ k: (m.params.values, list(m.model.exog_names))
318
+ for k, m in models.items()
319
+ }
320
+
269
321
  if self.offload:
270
322
  offloaded_models = {}
271
323
  for key, model in models.items():
@@ -332,7 +384,7 @@ class SEQuential:
332
384
 
333
385
  def collect(self) -> SEQoutput:
334
386
  """
335
- Collects all results current created into ``SEQoutput`` class
387
+ Collects all results currently created into ``SEQoutput`` class
336
388
  """
337
389
  self._time_collected = datetime.datetime.now()
338
390
 
@@ -182,7 +182,7 @@ def _hazard_handler(self, data, idx, boot_idx, rng):
182
182
  sim_data_pd = sim_data.to_pandas()
183
183
 
184
184
  try:
185
- # COXPHFITER CURRENTLY HAS DEPRECATED datetime.datetime.utcnow()
185
+ # COXPHFITTER CURRENTLY HAS DEPRECATED datetime.datetime.utcnow()
186
186
  warnings.filterwarnings("ignore", message=".*datetime.datetime.utcnow.*")
187
187
  if ce_model is not None:
188
188
  cox_data = sim_data_pd[sim_data_pd["event"].isin([0, 1])].copy()
@@ -1,12 +1,30 @@
1
1
  import re
2
2
 
3
+ import numpy as np
3
4
  import polars as pl
4
5
  import statsmodels.api as sm
5
6
  import statsmodels.formula.api as smf
6
7
 
7
8
 
8
- def _apply_spline_formula(formula, indicator_squared):
9
- spline = "cr(followup, df=3)"
9
+ def _compute_spline_knots(followup_arr, df=3):
10
+ lower = float(np.min(followup_arr))
11
+ upper = float(np.max(followup_arr))
12
+ n_inner = df - 2
13
+ if n_inner == 0:
14
+ inner_knots = []
15
+ else:
16
+ # Replicate patsy's knot placement: percentiles of unique values in [lower, upper]
17
+ x = np.unique(followup_arr[(lower <= followup_arr) & (followup_arr <= upper)])
18
+ q = np.linspace(0, 100, n_inner + 2)[1:-1]
19
+ inner_knots = np.percentile(x, q.tolist()).tolist()
20
+ return inner_knots, lower, upper
21
+
22
+
23
+ def _apply_spline_formula(formula, indicator_squared, spline_knots):
24
+ inner_knots, lower, upper = spline_knots
25
+ spline = (
26
+ f"cr(followup, knots={inner_knots}, lower_bound={lower}, upper_bound={upper})"
27
+ )
10
28
 
11
29
  formula = re.sub(r"(\w+)\s*\*\s*followup\b", rf"\1*{spline}", formula)
12
30
  formula = re.sub(r"\bfollowup\s*\*\s*(\w+)", rf"{spline}*\1", formula)
@@ -18,8 +36,8 @@ def _apply_spline_formula(formula, indicator_squared):
18
36
  formula = re.sub(r"^\s*\+\s*|\s*\+\s*$", "", formula).strip()
19
37
 
20
38
  if formula:
21
- return f"{formula} + I({spline}**2)"
22
- return f"I({spline}**2)"
39
+ return f"{formula} + {spline}"
40
+ return spline
23
41
 
24
42
 
25
43
  def _cast_categories(self, df_pd):
@@ -50,6 +68,7 @@ def _outcome_fit(
50
68
  formula: str,
51
69
  weighted: bool = False,
52
70
  weight_col: str = "weight",
71
+ start_params=None,
53
72
  ):
54
73
  if weighted:
55
74
  df = df.with_columns(
@@ -64,7 +83,13 @@ def _outcome_fit(
64
83
  df_pd = _cast_categories(self, df.to_pandas())
65
84
 
66
85
  if self.followup_spline:
67
- formula = _apply_spline_formula(formula, self.indicator_squared)
86
+ if getattr(self, "_current_boot_idx", None) is None:
87
+ self._followup_spline_knots = _compute_spline_knots(
88
+ self.DT["followup"].to_numpy(), df=self.followup_spline_df
89
+ )
90
+ formula = _apply_spline_formula(
91
+ formula, self.indicator_squared, self._followup_spline_knots
92
+ )
68
93
 
69
94
  full_formula = f"{outcome} ~ {formula}"
70
95
 
@@ -78,5 +103,23 @@ def _outcome_fit(
78
103
  glm_kwargs["var_weights"] = df_pd[weight_col]
79
104
 
80
105
  model = smf.glm(**glm_kwargs)
81
- model_fit = model.fit()
106
+
107
+ # Drop warm-start coefs unless the design matrix columns match exactly
108
+ # by name — bootstrap resamples can shift categorical reference levels or
109
+ # column ordering, in which case the cached coefs are meaningless and
110
+ # IRLS can diverge into NaN/Inf and crash LAPACK.
111
+ if start_params is not None:
112
+ sp_values, sp_names = start_params
113
+ if list(model.exog_names) != list(sp_names):
114
+ start_params = None
115
+ else:
116
+ start_params = sp_values
117
+
118
+ try:
119
+ model_fit = model.fit(start_params=start_params)
120
+ except Exception:
121
+ if start_params is not None:
122
+ model_fit = model.fit()
123
+ else:
124
+ raise
82
125
  return model_fit
@@ -3,17 +3,24 @@ import polars as pl
3
3
  from ._outcome_fit import _outcome_fit
4
4
 
5
5
 
6
- def _subgroup_fit(self):
6
+ def _subgroup_fit(self, start_params=None):
7
7
  subgroups = sorted(self.DT[self.subgroup_colname].unique().to_list())
8
8
  self._unique_subgroups = subgroups
9
9
 
10
10
  models_list = []
11
11
  for val in subgroups:
12
12
  subDT = self.DT.filter(pl.col(self.subgroup_colname) == val)
13
+ sg_start = (start_params or {}).get(val, {}) or {}
13
14
 
14
15
  models = {
15
16
  "outcome": _outcome_fit(
16
- self, subDT, self.outcome_col, self.covariates, self.weighted, "weight"
17
+ self,
18
+ subDT,
19
+ self.outcome_col,
20
+ self.covariates,
21
+ self.weighted,
22
+ "weight",
23
+ start_params=sg_start.get("outcome"),
17
24
  )
18
25
  }
19
26
 
@@ -25,6 +32,7 @@ def _subgroup_fit(self):
25
32
  self.covariates,
26
33
  self.weighted,
27
34
  "weight",
35
+ start_params=sg_start.get("compevent"),
28
36
  )
29
37
  models_list.append(models)
30
38
  return models_list
@@ -1,5 +1,8 @@
1
+ import numpy as np
1
2
  import polars as pl
3
+ from patsy import PatsyError, dmatrix
2
4
 
5
+ from ..helpers._fix_categories import _fix_categories_for_predict
3
6
  from ..helpers._predict_model import _safe_predict
4
7
  from ._outcome_fit import _cast_categories
5
8
 
@@ -25,20 +28,79 @@ def _store_boot_risks(obj, treatment_val, TxDT, boot_cols, is_survival=False):
25
28
  )
26
29
 
27
30
 
31
+ def _build_design_matrix(design_info, data):
32
+ """
33
+ Build a design matrix from a cached design_info, applying the same
34
+ category-alignment fallback that _safe_predict uses on mismatch.
35
+ """
36
+ try:
37
+ return np.asarray(dmatrix(design_info, data))
38
+ except PatsyError as e:
39
+ if "mismatching levels" not in str(e):
40
+ raise
41
+
42
+ # Reuse the existing fix by wrapping design_info in a stub object
43
+ class _Stub:
44
+ class model:
45
+ class data:
46
+ pass
47
+
48
+ stub = _Stub()
49
+ stub.model.data.design_info = design_info
50
+ fixed = _fix_categories_for_predict(stub, data.copy())
51
+ return np.asarray(dmatrix(design_info, fixed))
52
+
53
+
54
+ def _cached_predict(model, X_cached, ref_column_names, data):
55
+ """
56
+ Predict using a pre-built design matrix when the model's design_info
57
+ column structure matches the reference, falling back to patsy via
58
+ _safe_predict on mismatch (e.g. a bootstrap resample that dropped a
59
+ categorical level).
60
+ """
61
+ dinfo = model.model.data.design_info
62
+ if list(dinfo.column_names) == ref_column_names:
63
+ probs = np.asarray(model.predict(X_cached, transform=False))
64
+ if not np.any(np.isnan(probs)):
65
+ return np.clip(probs, 0, 1)
66
+ return _safe_predict(model, data)
67
+
68
+
28
69
  def _get_outcome_predictions(self, TxDT, idx=None):
29
70
  data = _cast_categories(self, TxDT.to_pandas())
30
71
  predictions = {"outcome": []}
31
72
  if self.compevent_colname is not None:
32
73
  predictions["compevent"] = []
33
74
 
75
+ # Pre-build the design matrix once using the main fit's design_info.
76
+ # Each bootstrap model that shares the same column structure can then
77
+ # reuse it, skipping patsy entirely on the predict path.
78
+ main = self.outcome_model[0]
79
+ main_dict = main[idx] if idx is not None else main
80
+ main_outcome = self._offloader.load_model(main_dict["outcome"])
81
+ outcome_dinfo = main_outcome.model.data.design_info
82
+ X_outcome = _build_design_matrix(outcome_dinfo, data)
83
+ outcome_cols = list(outcome_dinfo.column_names)
84
+
85
+ X_compevent = compevent_cols = None
86
+ if self.compevent_colname is not None:
87
+ main_compevent = self._offloader.load_model(main_dict["compevent"])
88
+ compevent_dinfo = main_compevent.model.data.design_info
89
+ X_compevent = _build_design_matrix(compevent_dinfo, data)
90
+ compevent_cols = list(compevent_dinfo.column_names)
91
+
34
92
  for boot_model in self.outcome_model:
35
93
  model_dict = boot_model[idx] if idx is not None else boot_model
36
94
  outcome_model = self._offloader.load_model(model_dict["outcome"])
37
- predictions["outcome"].append(_safe_predict(outcome_model, data))
95
+ predictions["outcome"].append(
96
+ _cached_predict(outcome_model, X_outcome, outcome_cols, data)
97
+ )
38
98
 
39
99
  if self.compevent_colname is not None:
40
100
  compevent_model = self._offloader.load_model(model_dict["compevent"])
41
- predictions["compevent"].append(_safe_predict(compevent_model, data))
101
+ predictions["compevent"].append(
102
+ _cached_predict(compevent_model, X_compevent, compevent_cols, data)
103
+ )
42
104
 
43
105
  return predictions
44
106
 
@@ -31,8 +31,8 @@ def _data_checker(self):
31
31
  invalid = check.filter(pl.col("row_count") != pl.col("max_time") + 1)
32
32
  if len(invalid) > 0:
33
33
  raise ValueError(
34
- f"Data validation failed: {len(invalid)} ID(s) have mismatched "
35
- f"This suggests invalid times"
34
+ f"Data validation failed: {len(invalid)} ID(s) have mismatched row counts. "
35
+ f"This suggests invalid times. "
36
36
  f"Invalid IDs:\n{invalid}"
37
37
  )
38
38
 
@@ -1,3 +1,5 @@
1
+ import warnings
2
+
1
3
  from ..helpers import _pad
2
4
 
3
5
 
@@ -26,13 +28,13 @@ def _param_checker(self):
26
28
 
27
29
  if len(self.excused_colnames) == 0 and self.excused:
28
30
  self.excused = False
29
- raise Warning(
31
+ warnings.warn(
30
32
  "Excused column names not provided but excused is set to True. Automatically set excused to False"
31
33
  )
32
34
 
33
35
  if len(self.excused_colnames) > 0 and not self.excused:
34
36
  self.excused = True
35
- raise Warning(
37
+ warnings.warn(
36
38
  "Excused column names provided but excused is set to False. Automatically set excused to True"
37
39
  )
38
40
 
@@ -44,6 +46,9 @@ def _param_checker(self):
44
46
  "Only one of followup_class or followup_include can be set to True."
45
47
  )
46
48
 
49
+ if self.followup_spline_df < 2:
50
+ raise ValueError("followup_spline_df must be at least 2.")
51
+
47
52
  if (
48
53
  self.weighted
49
54
  and self.method == "ITT"
@@ -5,7 +5,7 @@ from ._mapper import _mapper
5
5
 
6
6
  def _binder(self, kept_cols):
7
7
  """
8
- Internal function to bind data to the map created by __mapper
8
+ Internal function to bind data to the map created by _mapper
9
9
  """
10
10
  excluded = {
11
11
  "dose",
@@ -3,7 +3,7 @@ import polars as pl
3
3
 
4
4
  def _dynamic(self):
5
5
  """
6
- Handles special cases for the data from the __mapper -> __binder pipeline
6
+ Handles special cases for the data from the _mapper -> _binder pipeline
7
7
  """
8
8
  if self.method == "dose-response":
9
9
  DT = self.DT.with_columns(
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: pySEQTarget
3
- Version: 0.13.2
3
+ Version: 0.13.4
4
4
  Summary: Sequentially Nested Target Trial Emulation
5
5
  Author-email: Ryan O'Dea <ryan.odea@psi.ch>, Alejandro Szmulewicz <aszmulewicz@hsph.harvard.edu>, Tom Palmer <tom.palmer@bristol.ac.uk>, Miguel Hernán <mhernan@hsph.harvard.edu>
6
6
  Maintainer-email: Ryan O'Dea <ryan.odea@psi.ch>, Tom Palmer <remlapmot@hotmail.com>
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "pySEQTarget"
7
- version = "0.13.2"
7
+ version = "0.13.4"
8
8
  description = "Sequentially Nested Target Trial Emulation"
9
9
  readme = "README.md"
10
10
  license = {text = "MIT"}
@@ -2,6 +2,7 @@ import polars as pl
2
2
  from polars.testing import assert_frame_equal
3
3
 
4
4
  from pySEQTarget import SEQopts, SEQuential
5
+ from pySEQTarget.data import load_data
5
6
 
6
7
 
7
8
  def _make_model(data):
@@ -126,3 +127,59 @@ def test_expand_only_returns_expanded_dataframe():
126
127
  model_full.expand()
127
128
 
128
129
  assert_frame_equal(result, model_full.DT)
130
+
131
+
132
+ def _make_verbose_model(verbose, **extra_opts):
133
+ data = load_data("SEQdata")
134
+ return SEQuential(
135
+ data,
136
+ id_col="ID",
137
+ time_col="time",
138
+ eligible_col="eligible",
139
+ treatment_col="tx_init",
140
+ outcome_col="outcome",
141
+ time_varying_cols=["N", "L", "P"],
142
+ fixed_cols=["sex"],
143
+ method="ITT",
144
+ parameters=SEQopts(verbose=verbose, **extra_opts),
145
+ )
146
+
147
+
148
+ def test_verbose_expand(capsys):
149
+ s = _make_verbose_model(verbose=True)
150
+ s.expand()
151
+ out = capsys.readouterr().out
152
+ assert "Full dataset:" in out
153
+ assert "Eligible observations:" in out
154
+ assert "Expanded dataset:" in out
155
+ assert "Final analysis dataset:" in out
156
+ assert "Sampled expanded dataset:" not in out
157
+ assert "observations" in out
158
+ assert "variables" in out
159
+
160
+
161
+ def test_verbose_expand_with_sampling(capsys):
162
+ s = _make_verbose_model(verbose=True, selection_random=True, selection_sample=0.5)
163
+ s.expand()
164
+ out = capsys.readouterr().out
165
+ assert "Sampled expanded dataset:" in out
166
+
167
+
168
+ def test_verbose_bootstrap(capsys):
169
+ s = _make_verbose_model(verbose=True, bootstrap_nboot=10)
170
+ s.expand()
171
+ capsys.readouterr()
172
+ s.bootstrap()
173
+ out = capsys.readouterr().out
174
+ assert "Bootstrapping" in out
175
+ assert "subjects" in out
176
+ assert "observations per resample" in out
177
+ assert "10 times" in out
178
+
179
+
180
+ def test_verbose_false_no_output(capsys):
181
+ s = _make_verbose_model(verbose=False, bootstrap_nboot=5)
182
+ s.expand()
183
+ s.bootstrap()
184
+ out = capsys.readouterr().out
185
+ assert out == ""
@@ -57,17 +57,18 @@ def test_followup_spline():
57
57
  s.fit()
58
58
  matrix = s.outcome_model[0]["outcome"].summary2().tables[1]["Coef."].to_list()
59
59
  expected = [
60
- -6.264817962084417,
61
- 0.20125056343026881,
62
- 0.12568743032952776,
63
- 0.03823426390103046,
64
- 0.0006607691746414019,
65
- 0.003343365539743267,
66
- -0.01319460158923785,
67
- 0.19601796921732118,
68
- -0.5186462478511427,
69
- 0.37598656666756103,
70
- 1.6553848469346044,
60
+ -4.804282252748607,
61
+ 0.19115933860001255,
62
+ 0.12717121164606823,
63
+ 0.044310717515918724,
64
+ 0.0005814999431447507,
65
+ 0.0032948355025455216,
66
+ -0.013371824500839971,
67
+ 0.19972467861548412,
68
+ -2.027245615586753,
69
+ -1.395729861856384,
70
+ -0.9397731941281695,
71
+ -0.4415335811772879,
71
72
  ]
72
73
  assert [round(x, 3) for x in matrix] == [round(x, 3) for x in expected]
73
74
 
File without changes
File without changes
File without changes