pySEQTarget 0.13.2__tar.gz → 0.13.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/PKG-INFO +1 -1
- {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/SEQopts.py +8 -1
- {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/SEQoutput.py +1 -1
- {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/SEQuential.py +54 -2
- {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/analysis/_hazard.py +1 -1
- {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/analysis/_outcome_fit.py +49 -6
- {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/analysis/_subgroup_fit.py +10 -2
- {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/analysis/_survival_pred.py +64 -2
- {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/error/_data_checker.py +2 -2
- {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/error/_param_checker.py +7 -2
- {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/expansion/_binder.py +1 -1
- {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/expansion/_dynamic.py +1 -1
- {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget.egg-info/PKG-INFO +1 -1
- {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pyproject.toml +1 -1
- {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/tests/test_expansion.py +57 -0
- {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/tests/test_followup_options.py +12 -11
- {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/LICENSE +0 -0
- {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/README.md +0 -0
- {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/__init__.py +0 -0
- {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/analysis/__init__.py +0 -0
- {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/analysis/_risk_estimates.py +0 -0
- {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/data/__init__.py +0 -0
- {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/error/__init__.py +0 -0
- {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/error/_check_separation.py +0 -0
- {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/expansion/__init__.py +0 -0
- {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/expansion/_diagnostics.py +0 -0
- {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/expansion/_mapper.py +0 -0
- {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/expansion/_selection.py +0 -0
- {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/helpers/__init__.py +0 -0
- {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/helpers/_bootstrap.py +0 -0
- {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/helpers/_col_string.py +0 -0
- {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/helpers/_fix_categories.py +0 -0
- {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/helpers/_format_time.py +0 -0
- {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/helpers/_offloader.py +0 -0
- {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/helpers/_output_files.py +0 -0
- {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/helpers/_pad.py +0 -0
- {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/helpers/_predict_model.py +0 -0
- {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/helpers/_prepare_data.py +0 -0
- {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/initialization/__init__.py +0 -0
- {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/initialization/_censoring.py +0 -0
- {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/initialization/_denominator.py +0 -0
- {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/initialization/_numerator.py +0 -0
- {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/initialization/_outcome.py +0 -0
- {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/plot/__init__.py +0 -0
- {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/plot/_survival_plot.py +0 -0
- {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/weighting/__init__.py +0 -0
- {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/weighting/_weight_bind.py +0 -0
- {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/weighting/_weight_data.py +0 -0
- {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/weighting/_weight_fit.py +0 -0
- {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/weighting/_weight_offload.py +0 -0
- {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/weighting/_weight_pred.py +0 -0
- {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget/weighting/_weight_stats.py +0 -0
- {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget.egg-info/SOURCES.txt +0 -0
- {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget.egg-info/dependency_links.txt +0 -0
- {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget.egg-info/requires.txt +0 -0
- {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/pySEQTarget.egg-info/top_level.txt +0 -0
- {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/setup.cfg +0 -0
- {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/tests/test_accessor.py +0 -0
- {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/tests/test_check_separation.py +0 -0
- {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/tests/test_coefficients.py +0 -0
- {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/tests/test_covariates.py +0 -0
- {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/tests/test_fix_categories.py +0 -0
- {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/tests/test_hazard.py +0 -0
- {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/tests/test_no_variation.py +0 -0
- {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/tests/test_offload.py +0 -0
- {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/tests/test_parallel.py +0 -0
- {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/tests/test_plot.py +0 -0
- {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/tests/test_reproducibility.py +0 -0
- {pyseqtarget-0.13.2 → pyseqtarget-0.13.4}/tests/test_survival.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: pySEQTarget
|
|
3
|
-
Version: 0.13.2
|
|
3
|
+
Version: 0.13.4
|
|
4
4
|
Summary: Sequentially Nested Target Trial Emulation
|
|
5
5
|
Author-email: Ryan O'Dea <ryan.odea@psi.ch>, Alejandro Szmulewicz <aszmulewicz@hsph.harvard.edu>, Tom Palmer <tom.palmer@bristol.ac.uk>, Miguel Hernán <mhernan@hsph.harvard.edu>
|
|
6
6
|
Maintainer-email: Ryan O'Dea <ryan.odea@psi.ch>, Tom Palmer <remlapmot@hotmail.com>
|
|
@@ -15,7 +15,7 @@ class SEQopts:
|
|
|
15
15
|
:type bootstrap_sample: float
|
|
16
16
|
:param bootstrap_CI: If bootstrapped, confidence interval level
|
|
17
17
|
:type bootstrap_CI: float
|
|
18
|
-
:param bootstrap_CI_method: If bootstrapped, confidence
|
|
18
|
+
:param bootstrap_CI_method: If bootstrapped, confidence interval method ['SE' or 'percentile']
|
|
19
19
|
:type bootstrap_CI_method: str
|
|
20
20
|
:param cense_colname: Column name for censoring effect (LTFU, etc.)
|
|
21
21
|
:type cense_colname: str
|
|
@@ -45,6 +45,8 @@ class SEQopts:
|
|
|
45
45
|
:type followup_include: bool
|
|
46
46
|
:param followup_spline: Boolean to force followup values to be fit to cubic spline
|
|
47
47
|
:type followup_spline: bool
|
|
48
|
+
:param followup_spline_df: Degrees of freedom for the followup cubic spline, default ``4``
|
|
49
|
+
:type followup_spline_df: int
|
|
48
50
|
:param followup_max: Maximum allowed followup in analysis
|
|
49
51
|
:type followup_max: int or None
|
|
50
52
|
:param followup_min: Minimum allowed followup in analysis
|
|
@@ -107,6 +109,8 @@ class SEQopts:
|
|
|
107
109
|
:type weight_p99: bool
|
|
108
110
|
:param weight_preexpansion: Boolean to fit weights on preexpanded data
|
|
109
111
|
:type weight_preexpansion: bool
|
|
112
|
+
:param verbose: Boolean to print dataset size summaries and bootstrap information
|
|
113
|
+
:type verbose: bool
|
|
110
114
|
:param weighted: Boolean to weight analysis
|
|
111
115
|
:type weighted: bool
|
|
112
116
|
"""
|
|
@@ -130,6 +134,7 @@ class SEQopts:
|
|
|
130
134
|
followup_max: int = None
|
|
131
135
|
followup_min: int = 0
|
|
132
136
|
followup_spline: bool = False
|
|
137
|
+
followup_spline_df: int = 4
|
|
133
138
|
hazard_estimate: bool = False
|
|
134
139
|
indicator_baseline: str = "_bas"
|
|
135
140
|
indicator_squared: str = "_sq"
|
|
@@ -160,6 +165,7 @@ class SEQopts:
|
|
|
160
165
|
weight_lag_condition: bool = True
|
|
161
166
|
weight_p99: bool = False
|
|
162
167
|
weight_preexpansion: bool = True
|
|
168
|
+
verbose: bool = False
|
|
163
169
|
weighted: bool = False
|
|
164
170
|
|
|
165
171
|
def _validate_bools(self):
|
|
@@ -175,6 +181,7 @@ class SEQopts:
|
|
|
175
181
|
"selection_first_trial",
|
|
176
182
|
"selection_random",
|
|
177
183
|
"trial_include",
|
|
184
|
+
"verbose",
|
|
178
185
|
"weight_lag_condition",
|
|
179
186
|
"weight_p99",
|
|
180
187
|
"weight_preexpansion",
|
|
@@ -114,6 +114,13 @@ class SEQuential:
|
|
|
114
114
|
:class:`polars.DataFrame` and skips all subsequent analysis steps.
|
|
115
115
|
"""
|
|
116
116
|
start = time.perf_counter()
|
|
117
|
+
|
|
118
|
+
if self.verbose:
|
|
119
|
+
n, m = self.data.shape
|
|
120
|
+
print(f"Full dataset: {n:,} observations, {m} variables")
|
|
121
|
+
n_elig = self.data.filter(pl.col(self.eligible_col) == 1).shape[0]
|
|
122
|
+
print(f"Eligible observations: {n_elig:,}")
|
|
123
|
+
|
|
117
124
|
kept = [
|
|
118
125
|
self.cense_colname,
|
|
119
126
|
self.cense_eligible_colname,
|
|
@@ -162,14 +169,25 @@ class SEQuential:
|
|
|
162
169
|
pl.col(self.id_col).cast(pl.Utf8).alias(self.id_col)
|
|
163
170
|
)
|
|
164
171
|
|
|
172
|
+
if self.verbose:
|
|
173
|
+
n, m = self.DT.shape
|
|
174
|
+
print(f"Expanded dataset: {n:,} observations, {m} variables")
|
|
175
|
+
|
|
165
176
|
if self.method == "dose-response" or (
|
|
166
177
|
self.method == "censoring" and not self.expand_only
|
|
167
178
|
):
|
|
168
179
|
_dynamic(self)
|
|
169
180
|
if self.selection_random:
|
|
170
181
|
_random_selection(self)
|
|
182
|
+
if self.verbose:
|
|
183
|
+
n, m = self.DT.shape
|
|
184
|
+
print(f"Sampled expanded dataset: {n:,} observations, {m} variables")
|
|
171
185
|
_diagnostics(self)
|
|
172
186
|
|
|
187
|
+
if self.verbose:
|
|
188
|
+
n, m = self.DT.shape
|
|
189
|
+
print(f"Final analysis dataset: {n:,} observations, {m} variables")
|
|
190
|
+
|
|
173
191
|
end = time.perf_counter()
|
|
174
192
|
self._expansion_time = _format_time(start, end)
|
|
175
193
|
|
|
@@ -200,6 +218,16 @@ class SEQuential:
|
|
|
200
218
|
)
|
|
201
219
|
NIDs = len(UIDs)
|
|
202
220
|
|
|
221
|
+
if self.verbose:
|
|
222
|
+
n_sample = round(self.bootstrap_sample * NIDs)
|
|
223
|
+
n_obs_sample = round(self.bootstrap_sample * len(self.DT))
|
|
224
|
+
print(
|
|
225
|
+
f"Bootstrapping with {self.bootstrap_sample * 100:.4g}% of "
|
|
226
|
+
f"{NIDs:,} subjects "
|
|
227
|
+
f"({n_sample:,} subjects, ~{n_obs_sample:,} observations per resample) "
|
|
228
|
+
f"{self.bootstrap_nboot} times"
|
|
229
|
+
)
|
|
230
|
+
|
|
203
231
|
self._boot_samples = []
|
|
204
232
|
for _ in range(self.bootstrap_nboot):
|
|
205
233
|
sampled_IDs = self._rng.choice(
|
|
@@ -244,8 +272,23 @@ class SEQuential:
|
|
|
244
272
|
_weight_bind(self, WDT)
|
|
245
273
|
self.weight_stats = _weight_stats(self)
|
|
246
274
|
|
|
275
|
+
is_boot = boot_idx is not None
|
|
276
|
+
start = getattr(self, "_outcome_start_params", None) if is_boot else None
|
|
277
|
+
|
|
247
278
|
if self.subgroup_colname is not None:
|
|
248
|
-
|
|
279
|
+
models_list = _subgroup_fit(self, start_params=start)
|
|
280
|
+
if not is_boot:
|
|
281
|
+
self._outcome_start_params = {
|
|
282
|
+
val: {
|
|
283
|
+
key: (m.params.values, list(m.model.exog_names))
|
|
284
|
+
for key, m in sg.items()
|
|
285
|
+
}
|
|
286
|
+
for val, sg in zip(self._unique_subgroups, models_list)
|
|
287
|
+
}
|
|
288
|
+
return models_list
|
|
289
|
+
|
|
290
|
+
start_outcome = (start or {}).get("outcome")
|
|
291
|
+
start_compevent = (start or {}).get("compevent")
|
|
249
292
|
|
|
250
293
|
models = {
|
|
251
294
|
"outcome": _outcome_fit(
|
|
@@ -255,6 +298,7 @@ class SEQuential:
|
|
|
255
298
|
self.covariates,
|
|
256
299
|
self.weighted,
|
|
257
300
|
"weight",
|
|
301
|
+
start_params=start_outcome,
|
|
258
302
|
)
|
|
259
303
|
}
|
|
260
304
|
if self.compevent_colname is not None:
|
|
@@ -265,7 +309,15 @@ class SEQuential:
|
|
|
265
309
|
self.covariates,
|
|
266
310
|
self.weighted,
|
|
267
311
|
"weight",
|
|
312
|
+
start_params=start_compevent,
|
|
268
313
|
)
|
|
314
|
+
|
|
315
|
+
if not is_boot:
|
|
316
|
+
self._outcome_start_params = {
|
|
317
|
+
k: (m.params.values, list(m.model.exog_names))
|
|
318
|
+
for k, m in models.items()
|
|
319
|
+
}
|
|
320
|
+
|
|
269
321
|
if self.offload:
|
|
270
322
|
offloaded_models = {}
|
|
271
323
|
for key, model in models.items():
|
|
@@ -332,7 +384,7 @@ class SEQuential:
|
|
|
332
384
|
|
|
333
385
|
def collect(self) -> SEQoutput:
|
|
334
386
|
"""
|
|
335
|
-
Collects all results
|
|
387
|
+
Collects all results currently created into ``SEQoutput`` class
|
|
336
388
|
"""
|
|
337
389
|
self._time_collected = datetime.datetime.now()
|
|
338
390
|
|
|
@@ -182,7 +182,7 @@ def _hazard_handler(self, data, idx, boot_idx, rng):
|
|
|
182
182
|
sim_data_pd = sim_data.to_pandas()
|
|
183
183
|
|
|
184
184
|
try:
|
|
185
|
-
#
|
|
185
|
+
# COXPHFITTER CURRENTLY HAS DEPRECATED datetime.datetime.utcnow()
|
|
186
186
|
warnings.filterwarnings("ignore", message=".*datetime.datetime.utcnow.*")
|
|
187
187
|
if ce_model is not None:
|
|
188
188
|
cox_data = sim_data_pd[sim_data_pd["event"].isin([0, 1])].copy()
|
|
@@ -1,12 +1,30 @@
|
|
|
1
1
|
import re
|
|
2
2
|
|
|
3
|
+
import numpy as np
|
|
3
4
|
import polars as pl
|
|
4
5
|
import statsmodels.api as sm
|
|
5
6
|
import statsmodels.formula.api as smf
|
|
6
7
|
|
|
7
8
|
|
|
8
|
-
def _apply_spline_formula(formula, indicator_squared):
|
|
9
|
-
|
|
9
|
+
def _compute_spline_knots(followup_arr, df=3):
|
|
10
|
+
lower = float(np.min(followup_arr))
|
|
11
|
+
upper = float(np.max(followup_arr))
|
|
12
|
+
n_inner = df - 2
|
|
13
|
+
if n_inner == 0:
|
|
14
|
+
inner_knots = []
|
|
15
|
+
else:
|
|
16
|
+
# Replicate patsy's knot placement: percentiles of unique values in [lower, upper]
|
|
17
|
+
x = np.unique(followup_arr[(lower <= followup_arr) & (followup_arr <= upper)])
|
|
18
|
+
q = np.linspace(0, 100, n_inner + 2)[1:-1]
|
|
19
|
+
inner_knots = np.percentile(x, q.tolist()).tolist()
|
|
20
|
+
return inner_knots, lower, upper
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def _apply_spline_formula(formula, indicator_squared, spline_knots):
|
|
24
|
+
inner_knots, lower, upper = spline_knots
|
|
25
|
+
spline = (
|
|
26
|
+
f"cr(followup, knots={inner_knots}, lower_bound={lower}, upper_bound={upper})"
|
|
27
|
+
)
|
|
10
28
|
|
|
11
29
|
formula = re.sub(r"(\w+)\s*\*\s*followup\b", rf"\1*{spline}", formula)
|
|
12
30
|
formula = re.sub(r"\bfollowup\s*\*\s*(\w+)", rf"{spline}*\1", formula)
|
|
@@ -18,8 +36,8 @@ def _apply_spline_formula(formula, indicator_squared):
|
|
|
18
36
|
formula = re.sub(r"^\s*\+\s*|\s*\+\s*$", "", formula).strip()
|
|
19
37
|
|
|
20
38
|
if formula:
|
|
21
|
-
return f"{formula} +
|
|
22
|
-
return
|
|
39
|
+
return f"{formula} + {spline}"
|
|
40
|
+
return spline
|
|
23
41
|
|
|
24
42
|
|
|
25
43
|
def _cast_categories(self, df_pd):
|
|
@@ -50,6 +68,7 @@ def _outcome_fit(
|
|
|
50
68
|
formula: str,
|
|
51
69
|
weighted: bool = False,
|
|
52
70
|
weight_col: str = "weight",
|
|
71
|
+
start_params=None,
|
|
53
72
|
):
|
|
54
73
|
if weighted:
|
|
55
74
|
df = df.with_columns(
|
|
@@ -64,7 +83,13 @@ def _outcome_fit(
|
|
|
64
83
|
df_pd = _cast_categories(self, df.to_pandas())
|
|
65
84
|
|
|
66
85
|
if self.followup_spline:
|
|
67
|
-
|
|
86
|
+
if getattr(self, "_current_boot_idx", None) is None:
|
|
87
|
+
self._followup_spline_knots = _compute_spline_knots(
|
|
88
|
+
self.DT["followup"].to_numpy(), df=self.followup_spline_df
|
|
89
|
+
)
|
|
90
|
+
formula = _apply_spline_formula(
|
|
91
|
+
formula, self.indicator_squared, self._followup_spline_knots
|
|
92
|
+
)
|
|
68
93
|
|
|
69
94
|
full_formula = f"{outcome} ~ {formula}"
|
|
70
95
|
|
|
@@ -78,5 +103,23 @@ def _outcome_fit(
|
|
|
78
103
|
glm_kwargs["var_weights"] = df_pd[weight_col]
|
|
79
104
|
|
|
80
105
|
model = smf.glm(**glm_kwargs)
|
|
81
|
-
|
|
106
|
+
|
|
107
|
+
# Drop warm-start coefs unless the design matrix columns match exactly
|
|
108
|
+
# by name — bootstrap resamples can shift categorical reference levels or
|
|
109
|
+
# column ordering, in which case the cached coefs are meaningless and
|
|
110
|
+
# IRLS can diverge into NaN/Inf and crash LAPACK.
|
|
111
|
+
if start_params is not None:
|
|
112
|
+
sp_values, sp_names = start_params
|
|
113
|
+
if list(model.exog_names) != list(sp_names):
|
|
114
|
+
start_params = None
|
|
115
|
+
else:
|
|
116
|
+
start_params = sp_values
|
|
117
|
+
|
|
118
|
+
try:
|
|
119
|
+
model_fit = model.fit(start_params=start_params)
|
|
120
|
+
except Exception:
|
|
121
|
+
if start_params is not None:
|
|
122
|
+
model_fit = model.fit()
|
|
123
|
+
else:
|
|
124
|
+
raise
|
|
82
125
|
return model_fit
|
|
@@ -3,17 +3,24 @@ import polars as pl
|
|
|
3
3
|
from ._outcome_fit import _outcome_fit
|
|
4
4
|
|
|
5
5
|
|
|
6
|
-
def _subgroup_fit(self):
|
|
6
|
+
def _subgroup_fit(self, start_params=None):
|
|
7
7
|
subgroups = sorted(self.DT[self.subgroup_colname].unique().to_list())
|
|
8
8
|
self._unique_subgroups = subgroups
|
|
9
9
|
|
|
10
10
|
models_list = []
|
|
11
11
|
for val in subgroups:
|
|
12
12
|
subDT = self.DT.filter(pl.col(self.subgroup_colname) == val)
|
|
13
|
+
sg_start = (start_params or {}).get(val, {}) or {}
|
|
13
14
|
|
|
14
15
|
models = {
|
|
15
16
|
"outcome": _outcome_fit(
|
|
16
|
-
self,
|
|
17
|
+
self,
|
|
18
|
+
subDT,
|
|
19
|
+
self.outcome_col,
|
|
20
|
+
self.covariates,
|
|
21
|
+
self.weighted,
|
|
22
|
+
"weight",
|
|
23
|
+
start_params=sg_start.get("outcome"),
|
|
17
24
|
)
|
|
18
25
|
}
|
|
19
26
|
|
|
@@ -25,6 +32,7 @@ def _subgroup_fit(self):
|
|
|
25
32
|
self.covariates,
|
|
26
33
|
self.weighted,
|
|
27
34
|
"weight",
|
|
35
|
+
start_params=sg_start.get("compevent"),
|
|
28
36
|
)
|
|
29
37
|
models_list.append(models)
|
|
30
38
|
return models_list
|
|
@@ -1,5 +1,8 @@
|
|
|
1
|
+
import numpy as np
|
|
1
2
|
import polars as pl
|
|
3
|
+
from patsy import PatsyError, dmatrix
|
|
2
4
|
|
|
5
|
+
from ..helpers._fix_categories import _fix_categories_for_predict
|
|
3
6
|
from ..helpers._predict_model import _safe_predict
|
|
4
7
|
from ._outcome_fit import _cast_categories
|
|
5
8
|
|
|
@@ -25,20 +28,79 @@ def _store_boot_risks(obj, treatment_val, TxDT, boot_cols, is_survival=False):
|
|
|
25
28
|
)
|
|
26
29
|
|
|
27
30
|
|
|
31
|
+
def _build_design_matrix(design_info, data):
|
|
32
|
+
"""
|
|
33
|
+
Build a design matrix from a cached design_info, applying the same
|
|
34
|
+
category-alignment fallback that _safe_predict uses on mismatch.
|
|
35
|
+
"""
|
|
36
|
+
try:
|
|
37
|
+
return np.asarray(dmatrix(design_info, data))
|
|
38
|
+
except PatsyError as e:
|
|
39
|
+
if "mismatching levels" not in str(e):
|
|
40
|
+
raise
|
|
41
|
+
|
|
42
|
+
# Reuse the existing fix by wrapping design_info in a stub object
|
|
43
|
+
class _Stub:
|
|
44
|
+
class model:
|
|
45
|
+
class data:
|
|
46
|
+
pass
|
|
47
|
+
|
|
48
|
+
stub = _Stub()
|
|
49
|
+
stub.model.data.design_info = design_info
|
|
50
|
+
fixed = _fix_categories_for_predict(stub, data.copy())
|
|
51
|
+
return np.asarray(dmatrix(design_info, fixed))
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def _cached_predict(model, X_cached, ref_column_names, data):
|
|
55
|
+
"""
|
|
56
|
+
Predict using a pre-built design matrix when the model's design_info
|
|
57
|
+
column structure matches the reference, falling back to patsy via
|
|
58
|
+
_safe_predict on mismatch (e.g. a bootstrap resample that dropped a
|
|
59
|
+
categorical level).
|
|
60
|
+
"""
|
|
61
|
+
dinfo = model.model.data.design_info
|
|
62
|
+
if list(dinfo.column_names) == ref_column_names:
|
|
63
|
+
probs = np.asarray(model.predict(X_cached, transform=False))
|
|
64
|
+
if not np.any(np.isnan(probs)):
|
|
65
|
+
return np.clip(probs, 0, 1)
|
|
66
|
+
return _safe_predict(model, data)
|
|
67
|
+
|
|
68
|
+
|
|
28
69
|
def _get_outcome_predictions(self, TxDT, idx=None):
|
|
29
70
|
data = _cast_categories(self, TxDT.to_pandas())
|
|
30
71
|
predictions = {"outcome": []}
|
|
31
72
|
if self.compevent_colname is not None:
|
|
32
73
|
predictions["compevent"] = []
|
|
33
74
|
|
|
75
|
+
# Pre-build the design matrix once using the main fit's design_info.
|
|
76
|
+
# Each bootstrap model that shares the same column structure can then
|
|
77
|
+
# reuse it, skipping patsy entirely on the predict path.
|
|
78
|
+
main = self.outcome_model[0]
|
|
79
|
+
main_dict = main[idx] if idx is not None else main
|
|
80
|
+
main_outcome = self._offloader.load_model(main_dict["outcome"])
|
|
81
|
+
outcome_dinfo = main_outcome.model.data.design_info
|
|
82
|
+
X_outcome = _build_design_matrix(outcome_dinfo, data)
|
|
83
|
+
outcome_cols = list(outcome_dinfo.column_names)
|
|
84
|
+
|
|
85
|
+
X_compevent = compevent_cols = None
|
|
86
|
+
if self.compevent_colname is not None:
|
|
87
|
+
main_compevent = self._offloader.load_model(main_dict["compevent"])
|
|
88
|
+
compevent_dinfo = main_compevent.model.data.design_info
|
|
89
|
+
X_compevent = _build_design_matrix(compevent_dinfo, data)
|
|
90
|
+
compevent_cols = list(compevent_dinfo.column_names)
|
|
91
|
+
|
|
34
92
|
for boot_model in self.outcome_model:
|
|
35
93
|
model_dict = boot_model[idx] if idx is not None else boot_model
|
|
36
94
|
outcome_model = self._offloader.load_model(model_dict["outcome"])
|
|
37
|
-
predictions["outcome"].append(
|
|
95
|
+
predictions["outcome"].append(
|
|
96
|
+
_cached_predict(outcome_model, X_outcome, outcome_cols, data)
|
|
97
|
+
)
|
|
38
98
|
|
|
39
99
|
if self.compevent_colname is not None:
|
|
40
100
|
compevent_model = self._offloader.load_model(model_dict["compevent"])
|
|
41
|
-
predictions["compevent"].append(
|
|
101
|
+
predictions["compevent"].append(
|
|
102
|
+
_cached_predict(compevent_model, X_compevent, compevent_cols, data)
|
|
103
|
+
)
|
|
42
104
|
|
|
43
105
|
return predictions
|
|
44
106
|
|
|
@@ -31,8 +31,8 @@ def _data_checker(self):
|
|
|
31
31
|
invalid = check.filter(pl.col("row_count") != pl.col("max_time") + 1)
|
|
32
32
|
if len(invalid) > 0:
|
|
33
33
|
raise ValueError(
|
|
34
|
-
f"Data validation failed: {len(invalid)} ID(s) have mismatched "
|
|
35
|
-
f"This suggests invalid times"
|
|
34
|
+
f"Data validation failed: {len(invalid)} ID(s) have mismatched row counts. "
|
|
35
|
+
f"This suggests invalid times. "
|
|
36
36
|
f"Invalid IDs:\n{invalid}"
|
|
37
37
|
)
|
|
38
38
|
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
import warnings
|
|
2
|
+
|
|
1
3
|
from ..helpers import _pad
|
|
2
4
|
|
|
3
5
|
|
|
@@ -26,13 +28,13 @@ def _param_checker(self):
|
|
|
26
28
|
|
|
27
29
|
if len(self.excused_colnames) == 0 and self.excused:
|
|
28
30
|
self.excused = False
|
|
29
|
-
|
|
31
|
+
warnings.warn(
|
|
30
32
|
"Excused column names not provided but excused is set to True. Automatically set excused to False"
|
|
31
33
|
)
|
|
32
34
|
|
|
33
35
|
if len(self.excused_colnames) > 0 and not self.excused:
|
|
34
36
|
self.excused = True
|
|
35
|
-
|
|
37
|
+
warnings.warn(
|
|
36
38
|
"Excused column names provided but excused is set to False. Automatically set excused to True"
|
|
37
39
|
)
|
|
38
40
|
|
|
@@ -44,6 +46,9 @@ def _param_checker(self):
|
|
|
44
46
|
"Only one of followup_class or followup_include can be set to True."
|
|
45
47
|
)
|
|
46
48
|
|
|
49
|
+
if self.followup_spline_df < 2:
|
|
50
|
+
raise ValueError("followup_spline_df must be at least 2.")
|
|
51
|
+
|
|
47
52
|
if (
|
|
48
53
|
self.weighted
|
|
49
54
|
and self.method == "ITT"
|
|
@@ -3,7 +3,7 @@ import polars as pl
|
|
|
3
3
|
|
|
4
4
|
def _dynamic(self):
|
|
5
5
|
"""
|
|
6
|
-
Handles special cases for the data from the
|
|
6
|
+
Handles special cases for the data from the _mapper -> _binder pipeline
|
|
7
7
|
"""
|
|
8
8
|
if self.method == "dose-response":
|
|
9
9
|
DT = self.DT.with_columns(
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: pySEQTarget
|
|
3
|
-
Version: 0.13.2
|
|
3
|
+
Version: 0.13.4
|
|
4
4
|
Summary: Sequentially Nested Target Trial Emulation
|
|
5
5
|
Author-email: Ryan O'Dea <ryan.odea@psi.ch>, Alejandro Szmulewicz <aszmulewicz@hsph.harvard.edu>, Tom Palmer <tom.palmer@bristol.ac.uk>, Miguel Hernán <mhernan@hsph.harvard.edu>
|
|
6
6
|
Maintainer-email: Ryan O'Dea <ryan.odea@psi.ch>, Tom Palmer <remlapmot@hotmail.com>
|
|
@@ -2,6 +2,7 @@ import polars as pl
|
|
|
2
2
|
from polars.testing import assert_frame_equal
|
|
3
3
|
|
|
4
4
|
from pySEQTarget import SEQopts, SEQuential
|
|
5
|
+
from pySEQTarget.data import load_data
|
|
5
6
|
|
|
6
7
|
|
|
7
8
|
def _make_model(data):
|
|
@@ -126,3 +127,59 @@ def test_expand_only_returns_expanded_dataframe():
|
|
|
126
127
|
model_full.expand()
|
|
127
128
|
|
|
128
129
|
assert_frame_equal(result, model_full.DT)
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def _make_verbose_model(verbose, **extra_opts):
|
|
133
|
+
data = load_data("SEQdata")
|
|
134
|
+
return SEQuential(
|
|
135
|
+
data,
|
|
136
|
+
id_col="ID",
|
|
137
|
+
time_col="time",
|
|
138
|
+
eligible_col="eligible",
|
|
139
|
+
treatment_col="tx_init",
|
|
140
|
+
outcome_col="outcome",
|
|
141
|
+
time_varying_cols=["N", "L", "P"],
|
|
142
|
+
fixed_cols=["sex"],
|
|
143
|
+
method="ITT",
|
|
144
|
+
parameters=SEQopts(verbose=verbose, **extra_opts),
|
|
145
|
+
)
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def test_verbose_expand(capsys):
|
|
149
|
+
s = _make_verbose_model(verbose=True)
|
|
150
|
+
s.expand()
|
|
151
|
+
out = capsys.readouterr().out
|
|
152
|
+
assert "Full dataset:" in out
|
|
153
|
+
assert "Eligible observations:" in out
|
|
154
|
+
assert "Expanded dataset:" in out
|
|
155
|
+
assert "Final analysis dataset:" in out
|
|
156
|
+
assert "Sampled expanded dataset:" not in out
|
|
157
|
+
assert "observations" in out
|
|
158
|
+
assert "variables" in out
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
def test_verbose_expand_with_sampling(capsys):
|
|
162
|
+
s = _make_verbose_model(verbose=True, selection_random=True, selection_sample=0.5)
|
|
163
|
+
s.expand()
|
|
164
|
+
out = capsys.readouterr().out
|
|
165
|
+
assert "Sampled expanded dataset:" in out
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
def test_verbose_bootstrap(capsys):
|
|
169
|
+
s = _make_verbose_model(verbose=True, bootstrap_nboot=10)
|
|
170
|
+
s.expand()
|
|
171
|
+
capsys.readouterr()
|
|
172
|
+
s.bootstrap()
|
|
173
|
+
out = capsys.readouterr().out
|
|
174
|
+
assert "Bootstrapping" in out
|
|
175
|
+
assert "subjects" in out
|
|
176
|
+
assert "observations per resample" in out
|
|
177
|
+
assert "10 times" in out
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
def test_verbose_false_no_output(capsys):
|
|
181
|
+
s = _make_verbose_model(verbose=False, bootstrap_nboot=5)
|
|
182
|
+
s.expand()
|
|
183
|
+
s.bootstrap()
|
|
184
|
+
out = capsys.readouterr().out
|
|
185
|
+
assert out == ""
|
|
@@ -57,17 +57,18 @@ def test_followup_spline():
|
|
|
57
57
|
s.fit()
|
|
58
58
|
matrix = s.outcome_model[0]["outcome"].summary2().tables[1]["Coef."].to_list()
|
|
59
59
|
expected = [
|
|
60
|
-
-
|
|
61
|
-
0.
|
|
62
|
-
0.
|
|
63
|
-
0.
|
|
64
|
-
0.
|
|
65
|
-
0.
|
|
66
|
-
-0.
|
|
67
|
-
0.
|
|
68
|
-
-
|
|
69
|
-
|
|
70
|
-
|
|
60
|
+
-4.804282252748607,
|
|
61
|
+
0.19115933860001255,
|
|
62
|
+
0.12717121164606823,
|
|
63
|
+
0.044310717515918724,
|
|
64
|
+
0.0005814999431447507,
|
|
65
|
+
0.0032948355025455216,
|
|
66
|
+
-0.013371824500839971,
|
|
67
|
+
0.19972467861548412,
|
|
68
|
+
-2.027245615586753,
|
|
69
|
+
-1.395729861856384,
|
|
70
|
+
-0.9397731941281695,
|
|
71
|
+
-0.4415335811772879,
|
|
71
72
|
]
|
|
72
73
|
assert [round(x, 3) for x in matrix] == [round(x, 3) for x in expected]
|
|
73
74
|
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|