pytmle 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pytmle-0.1.0/PKG-INFO ADDED
@@ -0,0 +1,24 @@
1
+ Metadata-Version: 2.4
2
+ Name: pytmle
3
+ Version: 0.1.0
4
+ Summary: A Flexible Python Implementation of Targeted Estimation for Survival and Competing Risks Analysis
5
+ Requires-Python: >=3.9
6
+ Description-Content-Type: text/markdown
7
+ Requires-Dist: matplotlib>=3.5.0
8
+ Requires-Dist: numpy>=1.22.3
9
+ Requires-Dist: pandas>=1.3.4
10
+ Requires-Dist: pycox
11
+ Requires-Dist: scikit-learn>=1.2.2
12
+ Requires-Dist: scikit-survival>=0.21.0
13
+ Requires-Dist: seaborn>=0.11.2
14
+ Requires-Dist: tqdm>=4.67.1
15
+ Provides-Extra: dev
16
+ Requires-Dist: ipykernel>=6.29.5; extra == "dev"
17
+ Requires-Dist: pytest>=8.3.5; extra == "dev"
18
+ Requires-Dist: torch>=2.6.0; extra == "dev"
19
+
20
+ # PyTMLE
21
+
22
+ A flexible Python implementation of the Targeted Maximum Likelihood Estimator (TMLE) for the cause-specific absolute risk of time-to-event outcomes measured in continuous time.
23
+
24
+ Additional information and documentation will be added in the next minor update.
pytmle-0.1.0/README.md ADDED
@@ -0,0 +1,5 @@
1
+ # PyTMLE
2
+
3
+ A flexible Python implementation of the Targeted Maximum Likelihood Estimator (TMLE) for the cause-specific absolute risk of time-to-event outcomes measured in continuous time.
4
+
5
+ Additional information and documentation will be added in the next minor update.
@@ -0,0 +1,35 @@
1
+ [project]
2
+ name = "pytmle"
3
+ version = "0.1.0"
4
+ description = "A Flexible Python Implementation of Targeted Estimation for Survival and Competing Risks Analysis"
5
+ readme = "README.md"
6
+ requires-python = ">=3.9"
7
+ dependencies = [
8
+ "matplotlib>=3.5.0",
9
+ "numpy>=1.22.3",
10
+ "pandas>=1.3.4",
11
+ "pycox",
12
+ "scikit-learn>=1.2.2",
13
+ "scikit-survival>=0.21.0",
14
+ "seaborn>=0.11.2",
15
+ "tqdm>=4.67.1",
16
+ ]
17
+
18
+ [project.optional-dependencies]
19
+ dev = [
20
+ "ipykernel>=6.29.5",
21
+ "pytest>=8.3.5",
22
+ "torch>=2.6.0",
23
+ ]
24
+
25
+ [tool.uv.sources]
26
+ pycox = { git = "https://github.com/pooya-mohammadi/pycox" }
27
+ torch = { index = "pytorch-cpu" }
28
+
29
+ [tool.setuptools]
30
+ packages = ["pytmle"]
31
+
32
+ [[tool.uv.index]]
33
+ name = "pytorch-cpu"
34
+ url = "https://download.pytorch.org/whl/cpu"
35
+ explicit = true
@@ -0,0 +1,2 @@
1
+ from .pytmle import PyTMLE
2
+ from .estimates import InitialEstimates
@@ -0,0 +1,187 @@
1
+ from typing import Dict, List, Optional
2
+ import warnings
3
+ import numpy as np
4
+ import pandas as pd
5
+ from tqdm import tqdm
6
+ from concurrent.futures import ProcessPoolExecutor, as_completed
7
+
8
+ from pytmle.tmle_update import tmle_update
9
+ from pytmle.predict_ate import get_counterfactual_risks, ate_ratio, ate_diff
10
+ from pytmle.estimates import InitialEstimates
11
+
12
+
13
def standard_bootstrap(event_indicator):
    """Draw a plain (unstratified) bootstrap sample of row indices.

    Args:
        event_indicator: 1-D array-like; only its length is used here.

    Returns:
        np.ndarray of `len(event_indicator)` indices drawn with replacement.
    """
    n_obs = len(event_indicator)
    return np.random.choice(n_obs, size=n_obs, replace=True)
17
+
18
def stratified_bootstrap(event_indicator):
    """
    Generate bootstrap samples stratified by event indicator.

    Each unique event type is resampled with replacement within its own
    stratum, so the event-type frequencies of the original sample are
    preserved exactly.

    Args:
        event_indicator: 1-D array of event labels.

    Returns:
        np.ndarray of resampled indices, concatenated stratum by stratum in
        the sorted order of the unique event labels.
    """
    strata = [np.where(event_indicator == ev)[0] for ev in np.unique(event_indicator)]
    resampled = [np.random.choice(idx, size=len(idx), replace=True) for idx in strata]
    return np.concatenate(resampled)
28
+
29
def single_boot(initial_estimates,
                event_times,
                event_indicator,
                target_times,
                target_events,
                key_1,
                key_0,
                stratify_by_event,
                **kwargs):
    """
    Run the TMLE update on one bootstrap resample and collect point estimates.

    As pointed out by Coyle & van der Laan (2018; https://link.springer.com/chapter/10.1007/978-3-319-65304-4_28)
    and Tran et al. (2023; https://www.degruyter.com/document/doi/10.1515/jci-2021-0067/html?srsltid=AfmBOopT0k3YNof6ON7IWkEv49nuaK_bqgd_bCL8GSyYvmUNBDoGavDG),
    only the second stage of TMLE should be bootstrapped, not the first stage.

    Returns:
        pd.DataFrame with columns "Event", "Time", "Group", "Pt Est", "type"
        (type is "risks", "rr", or "rd"; Group is -1 for the effect rows),
        or None when the TMLE update did not converge on this resample.
    """
    # Draw the bootstrap indices, optionally stratified by event type
    sampler = stratified_bootstrap if stratify_by_event else standard_bootstrap
    sample_indices = sampler(event_indicator)

    # Resample the initial estimates of every group along with times/indicator
    boot_initial_estimates = {
        group: est[sample_indices] for group, est in initial_estimates.items()
    }
    # Second-stage TMLE on the resampled data (verbose output suppressed)
    updated_estimates, _, converged, _ = tmle_update(
        initial_estimates=boot_initial_estimates,
        event_times=event_times[sample_indices],
        event_indicator=event_indicator[sample_indices],
        target_times=target_times,
        target_events=target_events,
        verbose=0,
        **kwargs,
    )
    if not converged:
        # non-converged replicates are dropped by the caller
        return None

    frames = []
    counterfactual = get_counterfactual_risks(updated_estimates,
                                              key_1=key_1,
                                              key_0=key_0)[["Event", "Time", "Group", "Pt Est"]]
    counterfactual["type"] = "risks"
    frames.append(counterfactual)

    risk_ratio = ate_ratio(updated_estimates,
                           key_1=key_1,
                           key_0=key_0)[["Event", "Time", "Pt Est"]]
    risk_ratio["type"] = "rr"
    risk_ratio["Group"] = -1
    frames.append(risk_ratio)

    risk_diff = ate_diff(updated_estimates,
                         key_1=key_1,
                         key_0=key_0)[["Event", "Time", "Pt Est"]]
    risk_diff["type"] = "rd"
    risk_diff["Group"] = -1
    frames.append(risk_diff)

    return pd.concat(frames)
86
+
87
+
88
def bootstrap_tmle_loop(
    initial_estimates: Dict[int, InitialEstimates],
    event_times: np.ndarray,
    event_indicator: np.ndarray,
    target_times: List[float],
    target_events: List[int],
    n_bootstrap: int = 100,
    n_jobs: int = -1,
    alpha: float = 0.05,
    key_1: int = 1,
    key_0: int = 0,
    stratify_by_event: bool = False,
    verbose: int = 2,
    **kwargs,
) -> Optional[pd.DataFrame]:
    """
    Perform parallel bootstrapping and call tmle_update on each sample.

    Parameters
    ----------
    initial_estimates: Dict[int, InitialEstimates]
        Initial estimates for each group.
    event_times: np.ndarray
        Array of event times.
    event_indicator: np.ndarray
        Array of event indicators.
    target_times: List[float]
        List of target times.
    target_events: List[int]
        List of target events.
    n_bootstrap: int
        Number of bootstrap samples.
    n_jobs: int
        Number of parallel jobs for bootstrapping; non-positive values use
        the executor's default (number of processors).
    alpha: float
        Significance level for confidence intervals.
    key_1: int
        Key for group 1.
    key_0: int
        Key for group 0.
    stratify_by_event: bool
        Stratify bootstrapping by event indicator.
    verbose: int
        Verbosity level (>=1 warns on total non-convergence, >=2 shows a
        progress bar and a convergence summary).
    kwargs
        Additional arguments to pass to tmle_update.

    Returns
    -------
    Optional[pd.DataFrame]
        DataFrame with the bootstrap mean and percentile confidence bounds
        per (type, Event, Time, Group), or None if no replicate converged.
    """
    # Fan the replicates out over a process pool; each worker resamples and
    # re-runs only the second (targeting) stage of TMLE.
    with ProcessPoolExecutor(max_workers=n_jobs if n_jobs > 0 else None) as executor:
        futures = [
            executor.submit(
                single_boot,
                initial_estimates,
                event_times,
                event_indicator,
                target_times,
                target_events,
                key_1,
                key_0,
                stratify_by_event,
                **kwargs,
            )
            for _ in range(n_bootstrap)
        ]
        results = []
        if verbose >= 2:
            futures_iter = tqdm(
                as_completed(futures), total=n_bootstrap, desc="Bootstrapping"
            )
        else:
            futures_iter = as_completed(futures)
        # single_boot returns None for replicates where TMLE did not converge
        for f in futures_iter:
            result = f.result()
            if result is not None:
                results.append(result)
    if len(results) == 0:
        if verbose >= 1:
            warnings.warn(
                "Not a single bootstrap sample converged. Bootstrapped CIs will not be available.",
                RuntimeWarning,
            )
        return None
    if verbose >= 2:
        print(
            f"TMLE converged for {len(results)} out of {n_bootstrap} bootstrap samples."
        )
    results_df = pd.concat(results)
    # Percentile bootstrap: mean plus the alpha/2 and 1-alpha/2 quantiles
    summary_df = (
        results_df.groupby(["type", "Event", "Time", "Group"])["Pt Est"]
        .agg(
            mean_bootstrap="mean",
            CI_lower=lambda x: x.quantile(alpha / 2),
            CI_upper=lambda x: x.quantile(1 - alpha / 2),
        )
    ).reset_index()
    return summary_df
@@ -0,0 +1,254 @@
1
+ import numpy as np
2
+ import pandas as pd
3
+ from dataclasses import dataclass, field
4
+ from typing import Optional, List, Union
5
+
6
+ from pytmle.g_computation import get_g_comp
7
+
8
@dataclass
class InitialEstimates:
    """Initial (pre-TMLE-update) nuisance estimates for one intervention group.

    Shape consistency is enforced incrementally: every time an estimate
    attribute is (re)assigned, ``__setattr__`` validates it against the
    dimensions recorded so far (first dimension = number of subjects, second
    dimension = number of time points where applicable).
    """

    # these fields must be filled on instantiation
    times: np.ndarray  # grid of evaluation time points
    g_star_obs: np.ndarray  # per-subject array (first dim = n subjects); presumably the g* intervention values — TODO confirm
    # these fields are optional and can be filled later
    propensity_scores: Optional[np.ndarray] = field(default=None)  # per-subject; width not checked
    hazards: Optional[np.ndarray] = field(default=None)  # second dim must equal len(times)
    event_free_survival_function: Optional[np.ndarray] = field(default=None)  # second dim must equal len(times)
    censoring_survival_function: Optional[np.ndarray] = field(default=None)  # second dim must equal len(times)
    # number of subjects; set lazily by the first per-subject array checked
    _length: Optional[int] = field(default=None, init=False)
    # read through the class-level default (True) while the dataclass
    # __init__ assigns the fields above, before the instance attribute exists
    _run_checks: bool = field(default=True, init=False)

    def __setattr__(self, name, value):
        # Validate shapes on every assignment of a non-None estimate field.
        if value is not None and self._run_checks:
            if name in ["propensity_scores",
                        "g_star_obs"]:
                # per-subject arrays: only the first dimension is checked
                self._check_compatibility(value, check_width=False)
            elif name in ["hazards",
                          "event_free_survival_function",
                          "censoring_survival_function"]:
                # time-indexed arrays: second dimension must match len(times)
                self._check_compatibility(value, check_width=True)
        super().__setattr__(name, value)

    def _check_compatibility(self, new_element, check_width):
        """Validate `new_element` against the dimensions seen so far.

        Raises:
            ValueError: if the first dimension differs from previously checked
                estimates, or (when `check_width`) the second dimension does
                not equal ``len(self.times)``.
        """
        # check that all given estimates have the same length (first dimension size)
        if self._length is None:
            self._length = len(new_element)
        elif self._length != len(new_element):
            raise ValueError(
                f"All initial estimates must have the same first dimension, got elements with sizes {self._length} and {len(new_element)}."
            )
        if check_width and ((len(new_element.shape) < 2) or (new_element.shape[1] != len(self.times))):
            raise ValueError(
                f"The second dimension of all initial estimates must be in line with the given times, got {len(self.times)} times and element of shape {new_element.shape}."
            )

    def __getitem__(self, key: Union[np.ndarray, List[int]]) -> "InitialEstimates":
        """
        Enable subsetting of an InitialEstimates object (needed for bootstrapping)

        Args:
            key (Union[np.ndarray, List[int]]): The indices of the subset.

        Returns:
            InitialEstimates: A new InitialEstimates object containing the subset.
        """
        # `times` is shared unchanged; only per-subject arrays are indexed.
        return InitialEstimates(
            times=self.times,
            g_star_obs=self.g_star_obs[key],
            propensity_scores=(
                self.propensity_scores[key]
                if self.propensity_scores is not None
                else None
            ),
            hazards=self.hazards[key] if self.hazards is not None else None,
            event_free_survival_function=(
                self.event_free_survival_function[key]
                if self.event_free_survival_function is not None
                else None
            ),
            censoring_survival_function=(
                self.censoring_survival_function[key]
                if self.censoring_survival_function is not None
                else None
            ),
        )

    def __len__(self):
        # Number of subjects.
        # NOTE(review): this is None (and len() raises TypeError) until at
        # least one per-subject estimate has been checked — confirm callers
        # never hit that state.
        return self._length
78
+
79
+
80
@dataclass
class UpdatedEstimates(InitialEstimates):
    """Nuisance estimates ready for / produced by the TMLE targeting step.

    Unlike ``InitialEstimates``, all four nuisance components are mandatory.
    On initialization the time grid is augmented with the target times (and
    time 0), and the truncated inverse-probability factor ``nuisance_weight``
    is precomputed.
    """

    # all have to be given
    propensity_scores: np.ndarray  # type: ignore
    hazards: np.ndarray  # type: ignore
    event_free_survival_function: np.ndarray  # type: ignore
    censoring_survival_function: np.ndarray  # type: ignore

    # is set on initialization (see _set_nuisance_weight)
    nuisance_weight: Optional[np.ndarray] = field(default=None, init=False)

    # truncation level for the weight denominator; defaults to
    # 5 / (sqrt(n) * log(n)) in __post_init__
    min_nuisance: Optional[float] = field(default=None)
    target_events: Optional[List[int]] = field(default=None)
    target_times: Optional[List[float]] = field(default=None)
    # G-computation estimate from BEFORE the TMLE update loop, if stored
    g_comp_est: Optional[pd.DataFrame] = field(default=None)
    # influence-curve values and their summary (used for standard errors)
    ic: Optional[pd.DataFrame] = field(default=None)
    summ_eic: Optional[pd.DataFrame] = field(default=None)

    def __post_init__(self):
        """Fill defaults, align the grid with target times, set the weights."""
        if self.min_nuisance is None:
            self.min_nuisance = (
                5
                / (len(self.propensity_scores) ** 0.5)
                / (np.log(len(self.propensity_scores)))
            )
        if self.target_times is None:
            # default if no target_times are given: only target the last time point
            self.target_times = [self.times[-1]]
        else:
            self._update_for_target_times()
        self._set_nuisance_weight()

    def _set_nuisance_weight(self):
        """Set nuisance_weight = 1 / max(g * S_c(t-), min_nuisance)."""
        # censoring survival lagged one step: value 1 before the first time
        lagged_censoring_survival_function = np.column_stack(
            [
                np.ones((self.censoring_survival_function.shape[0], 1)),
                self.censoring_survival_function[:, :-1],
            ],
        )
        nuisance_denominator = (
            self.propensity_scores[:, np.newaxis] * lagged_censoring_survival_function
        )
        # TODO: Add positivity check as in https://github.com/imbroglio-dc/concrete/blob/main/R/getInitialEstimate.R#L64?
        # truncate the denominator at min_nuisance to keep weights bounded
        self.nuisance_weight = 1 / np.maximum(nuisance_denominator, self.min_nuisance)  # type: ignore
        self._check_compatibility(self.nuisance_weight, check_width=True)

    @classmethod
    def from_initial_estimates(
        cls,
        initial_estimates: InitialEstimates,
        target_events: Optional[List[int]] = None,
        target_times: Optional[List[float]] = None,
        min_nuisance: Optional[float] = None,
    ) -> "UpdatedEstimates":
        """Alternate constructor promoting a fully populated InitialEstimates.

        Raises:
            AssertionError: if any of the four nuisance components is None.
                NOTE(review): assert is stripped under ``python -O``;
                consider raising ValueError instead.
        """
        assert (initial_estimates.propensity_scores is not None and
                initial_estimates.hazards is not None and
                initial_estimates.event_free_survival_function is not None and
                initial_estimates.censoring_survival_function is not None), "All initial estimates have to be provided prior to an instatiation of UpdatedEstimates."
        return cls(
            propensity_scores=initial_estimates.propensity_scores,
            hazards=initial_estimates.hazards,
            event_free_survival_function=initial_estimates.event_free_survival_function,
            censoring_survival_function=initial_estimates.censoring_survival_function,
            min_nuisance=min_nuisance,
            target_events=target_events,
            target_times=target_times,
            g_star_obs=initial_estimates.g_star_obs,
            times=initial_estimates.times,
        )

    def _update_for_target_times(self):
        """
        Updates the time-related attributes of the object to include target times (plus 0).
        This method performs the following steps:
        1. Combines and sorts the existing times and target times.
        2. Finds the indices where the target times should be inserted.
        3. Updates the `hazards`, `event_free_survival_function`, and `censoring_survival_function`
           attributes to account for the new target times by inserting appropriate values.
        4. Trims the `hazards`, `event_free_survival_function`, and `censoring_survival_function`
           attributes to only include times up to the maximum target time.
        5. Updates the `times` attribute to include the target times up to the maximum target time.
        Attributes:
            times (np.ndarray): Array of existing times.
            target_times (np.ndarray): Array of target times to be included.
            hazards (np.ndarray): Array of hazard values.
            event_free_survival_function (np.ndarray): Array of event-free survival function values.
            censoring_survival_function (np.ndarray): Array of censoring survival function values.
        """

        # Combine and sort the times
        all_times = np.sort(np.unique(np.concatenate((self.times, [0] + self.target_times))))  # type: ignore

        if len(all_times) > len(self.times):

            # Update hazards, event_free_survival_function, and censoring_survival_function
            if 0 not in self.times:
                # prepend time 0 with hazard 0 and survival 1
                self.times = np.insert(self.times, 0, 0)
                self.hazards = np.insert(self.hazards, 0, 0, axis=1)
                self.event_free_survival_function = np.insert(self.event_free_survival_function, 0, 1, axis=1)
                self.censoring_survival_function = np.insert(self.censoring_survival_function, 0, 1, axis=1)

            # Find the indices where the new times should be inserted
            insert_times = [t for t in self.target_times if t not in self.times]
            insert_indices = np.searchsorted(all_times, insert_times)

            self.times = all_times

            # Insert columns one by one; insert_indices are positions in the
            # final grid and ascend, so each insertion lands at its final slot
            # and copies the value of the preceding (earlier-time) column
            # (step functions are carried forward; inserted hazard is 0).
            hazards_new = self.hazards
            event_free_survival_function_new = self.event_free_survival_function
            censoring_survival_function_new = self.censoring_survival_function
            for idx in insert_indices:
                hazards_new = np.insert(hazards_new, idx, 0, axis=1)
                event_free_survival_function_new = np.insert(
                    event_free_survival_function_new,
                    idx,
                    event_free_survival_function_new[:, idx - 1],
                    axis=1,
                )
                censoring_survival_function_new = np.insert(
                    censoring_survival_function_new,
                    idx,
                    censoring_survival_function_new[:, idx - 1],
                    axis=1,
                )
            self.hazards = hazards_new
            self.event_free_survival_function = event_free_survival_function_new
            self.censoring_survival_function = censoring_survival_function_new

        # Find the index of the maximum target time
        max_target_time = max(self.target_times)  # type: ignore
        max_index = np.searchsorted(all_times, max_target_time)
        # Keep only times up to the maximum index
        self.times = all_times[: max_index + 1]
        self.hazards = self.hazards[:, : max_index + 1, :]
        self.event_free_survival_function = self.event_free_survival_function[
            :, : max_index + 1
        ]
        self.censoring_survival_function = self.censoring_survival_function[
            :, : max_index + 1
        ]

    def predict_mean_risks(self, g_comp: bool = False) -> pd.DataFrame:
        """
        Predict the mean risks for the target events and times.
        Args:
            g_comp (bool): Flag to return the G-computation estimate instead of the TMLE estimate.
        Returns:
            pd.DataFrame: DataFrame with columns 'Event', 'Time', 'Pt Est', and 'SE' containing the mean counterfactual risks.
        Raises:
            ValueError: if the required stored estimates (g_comp_est, or
                ic/summ_eic) are not available.
        """
        if g_comp:
            if self.g_comp_est is None:
                raise ValueError(
                    "g_comp_est is not available."
                )
            # return g_comp_estimate from BEFORE the TMLE update loop (standard error not available)
            # NOTE(review): pred_risk aliases self.g_comp_est, so the next
            # line adds an "SE" column to the stored DataFrame in place —
            # confirm this side effect is intended.
            pred_risk = self.g_comp_est
            pred_risk["SE"] = np.nan
        else:
            # return g_comp_estimate from AFTER the TMLE update loop
            if self.summ_eic is None or self.ic is None:
                raise ValueError(
                    "ic or summ_eic is not available."
                )
            pred_risk = get_g_comp(
                eval_times=self.times,
                hazards=self.hazards,
                total_surv=self.event_free_survival_function,
                target_time=self.target_times,  # type: ignore
                target_events=self.target_events,  # type: ignore
            )
            # SE = seEIC / sqrt(n), with n the number of subjects
            pred_risk = pred_risk.merge(self.summ_eic, on=["Event", "Time"])
            pred_risk["SE"] = pred_risk["seEIC"] / len(self)**0.5
        pred_risk = pred_risk[["Event", "Time", "Risk", "SE"]]
        pred_risk.rename(columns={"Risk": "Pt Est"}, inplace=True)
        return pred_risk