pytmle 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pytmle-0.1.0/PKG-INFO +24 -0
- pytmle-0.1.0/README.md +5 -0
- pytmle-0.1.0/pyproject.toml +35 -0
- pytmle-0.1.0/pytmle/__init__.py +2 -0
- pytmle-0.1.0/pytmle/bootstrap.py +187 -0
- pytmle-0.1.0/pytmle/estimates.py +254 -0
- pytmle-0.1.0/pytmle/evalues_benchmark.py +382 -0
- pytmle-0.1.0/pytmle/g_computation.py +66 -0
- pytmle-0.1.0/pytmle/get_influence_curve.py +257 -0
- pytmle-0.1.0/pytmle/get_initial_estimates.py +542 -0
- pytmle-0.1.0/pytmle/initial_estimates_default_models.py +117 -0
- pytmle-0.1.0/pytmle/plotting.py +218 -0
- pytmle-0.1.0/pytmle/predict_ate.py +318 -0
- pytmle-0.1.0/pytmle/pycox_wrapper.py +307 -0
- pytmle-0.1.0/pytmle/pytmle.py +705 -0
- pytmle-0.1.0/pytmle/tmle_update.py +371 -0
- pytmle-0.1.0/pytmle.egg-info/PKG-INFO +24 -0
- pytmle-0.1.0/pytmle.egg-info/SOURCES.txt +25 -0
- pytmle-0.1.0/pytmle.egg-info/dependency_links.txt +1 -0
- pytmle-0.1.0/pytmle.egg-info/requires.txt +13 -0
- pytmle-0.1.0/pytmle.egg-info/top_level.txt +1 -0
- pytmle-0.1.0/setup.cfg +4 -0
- pytmle-0.1.0/tests/test_main_class.py +52 -0
- pytmle-0.1.0/tests/test_mock_initial_estimates.py +25 -0
- pytmle-0.1.0/tests/test_predict_methods.py +130 -0
- pytmle-0.1.0/tests/test_pycox_wrapper.py +102 -0
- pytmle-0.1.0/tests/test_tmle_update.py +48 -0
pytmle-0.1.0/PKG-INFO
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: pytmle
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: A Flexible Python Implementation of Targeted Estimation for Survival and Competing Risks Analysis
|
|
5
|
+
Requires-Python: >=3.9
|
|
6
|
+
Description-Content-Type: text/markdown
|
|
7
|
+
Requires-Dist: matplotlib>=3.5.0
|
|
8
|
+
Requires-Dist: numpy>=1.22.3
|
|
9
|
+
Requires-Dist: pandas>=1.3.4
|
|
10
|
+
Requires-Dist: pycox
|
|
11
|
+
Requires-Dist: scikit-learn>=1.2.2
|
|
12
|
+
Requires-Dist: scikit-survival>=0.21.0
|
|
13
|
+
Requires-Dist: seaborn>=0.11.2
|
|
14
|
+
Requires-Dist: tqdm>=4.67.1
|
|
15
|
+
Provides-Extra: dev
|
|
16
|
+
Requires-Dist: ipykernel>=6.29.5; extra == "dev"
|
|
17
|
+
Requires-Dist: pytest>=8.3.5; extra == "dev"
|
|
18
|
+
Requires-Dist: torch>=2.6.0; extra == "dev"
|
|
19
|
+
|
|
20
|
+
# PyTMLE
|
|
21
|
+
|
|
22
|
+
A flexible Python implementation of the Targeted Maximum Likelihood Estimator (TMLE) for the cause-specific absolute risk of time-to-event outcomes measured in continuous time.
|
|
23
|
+
|
|
24
|
+
Additional information and documentation will be added in the next minor update.
|
pytmle-0.1.0/README.md
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
1
|
+
# PyTMLE
|
|
2
|
+
|
|
3
|
+
A flexible Python implementation of the Targeted Maximum Likelihood Estimator (TMLE) for the cause-specific absolute risk of time-to-event outcomes measured in continuous time.
|
|
4
|
+
|
|
5
|
+
Additional information and documentation will be added in the next minor update.
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "pytmle"
|
|
3
|
+
version = "0.1.0"
|
|
4
|
+
description = "A Flexible Python Implementation of Targeted Estimation for Survival and Competing Risks Analysis"
|
|
5
|
+
readme = "README.md"
|
|
6
|
+
requires-python = ">=3.9"
|
|
7
|
+
dependencies = [
|
|
8
|
+
"matplotlib>=3.5.0",
|
|
9
|
+
"numpy>=1.22.3",
|
|
10
|
+
"pandas>=1.3.4",
|
|
11
|
+
"pycox",
|
|
12
|
+
"scikit-learn>=1.2.2",
|
|
13
|
+
"scikit-survival>=0.21.0",
|
|
14
|
+
"seaborn>=0.11.2",
|
|
15
|
+
"tqdm>=4.67.1",
|
|
16
|
+
]
|
|
17
|
+
|
|
18
|
+
[project.optional-dependencies]
|
|
19
|
+
dev = [
|
|
20
|
+
"ipykernel>=6.29.5",
|
|
21
|
+
"pytest>=8.3.5",
|
|
22
|
+
"torch>=2.6.0",
|
|
23
|
+
]
|
|
24
|
+
|
|
25
|
+
[tool.uv.sources]
|
|
26
|
+
pycox = { git = "https://github.com/pooya-mohammadi/pycox" }
|
|
27
|
+
torch = { index = "pytorch-cpu" }
|
|
28
|
+
|
|
29
|
+
[tool.setuptools]
|
|
30
|
+
packages = ["pytmle"]
|
|
31
|
+
|
|
32
|
+
[[tool.uv.index]]
|
|
33
|
+
name = "pytorch-cpu"
|
|
34
|
+
url = "https://download.pytorch.org/whl/cpu"
|
|
35
|
+
explicit = true
|
|
@@ -0,0 +1,187 @@
|
|
|
1
|
+
from typing import Dict, List, Optional
|
|
2
|
+
import warnings
|
|
3
|
+
import numpy as np
|
|
4
|
+
import pandas as pd
|
|
5
|
+
from tqdm import tqdm
|
|
6
|
+
from concurrent.futures import ProcessPoolExecutor, as_completed
|
|
7
|
+
|
|
8
|
+
from pytmle.tmle_update import tmle_update
|
|
9
|
+
from pytmle.predict_ate import get_counterfactual_risks, ate_ratio, ate_diff
|
|
10
|
+
from pytmle.estimates import InitialEstimates
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def standard_bootstrap(event_indicator, rng=None):
    """
    Draw a standard (unstratified) bootstrap sample of indices.

    Parameters
    ----------
    event_indicator: np.ndarray
        Array of event indicators; only its length is used here.
    rng: Optional[np.random.Generator]
        Random generator to use. If None, a fresh generator seeded from OS
        entropy is created. This is deliberate: the legacy global
        ``np.random`` state is inherited by forked worker processes
        (ProcessPoolExecutor on Linux), which would make different workers
        draw identical "bootstrap" samples.

    Returns
    -------
    np.ndarray
        Resampled indices (with replacement), same length as event_indicator.
    """
    if rng is None:
        rng = np.random.default_rng()
    return rng.choice(len(event_indicator),
                      size=len(event_indicator),
                      replace=True)
|
|
17
|
+
|
|
18
|
+
def stratified_bootstrap(event_indicator, rng=None):
    """
    Generate bootstrap samples stratified by event indicator.

    Each unique value of ``event_indicator`` is resampled separately (with
    replacement), so the per-stratum counts of the original sample are
    preserved exactly.

    Parameters
    ----------
    event_indicator: np.ndarray
        Array of event indicators defining the strata.
    rng: Optional[np.random.Generator]
        Random generator to use. If None, a fresh generator seeded from OS
        entropy is created — fork-safe, unlike the legacy global
        ``np.random`` state which forked ProcessPoolExecutor workers inherit
        and would reuse identically.

    Returns
    -------
    np.ndarray
        Concatenated resampled indices, one contiguous block per stratum.
    """
    if rng is None:
        rng = np.random.default_rng()
    sample_indices_all = []
    for ev in np.unique(event_indicator):
        indices = np.where(event_indicator == ev)[0]
        sample_indices = rng.choice(indices, size=len(indices), replace=True)
        sample_indices_all.append(sample_indices)
    return np.concatenate(sample_indices_all)
|
|
28
|
+
|
|
29
|
+
def single_boot(initial_estimates,
                event_times,
                event_indicator,
                target_times,
                target_events,
                key_1,
                key_0,
                stratify_by_event,
                **kwargs):
    """
    Run the second-stage TMLE update on a single bootstrap resample.

    As pointed out by Coyle & van der Laan (2018; https://link.springer.com/chapter/10.1007/978-3-319-65304-4_28)
    and Tran et al. (2023; https://www.degruyter.com/document/doi/10.1515/jci-2021-0067/html?srsltid=AfmBOopT0k3YNof6ON7IWkEv49nuaK_bqgd_bCL8GSyYvmUNBDoGavDG),
    only the second stage of TMLE should be bootstrapped, not the first stage.

    Returns a DataFrame of point estimates tagged by "type" ("risks", "rr",
    "rd"), or None when the TMLE update did not converge on this resample.
    """
    # Draw resampling indices, optionally stratified by event type
    sampler = stratified_bootstrap if stratify_by_event else standard_bootstrap
    idx = sampler(event_indicator)

    # Subset the first-stage estimates and the observed data to the resample
    resampled_estimates = {grp: est[idx] for grp, est in initial_estimates.items()}

    updated_estimates, _, converged, _ = tmle_update(
        initial_estimates=resampled_estimates,
        event_times=event_times[idx],
        event_indicator=event_indicator[idx],
        target_times=target_times,
        target_events=target_events,
        verbose=0,
        **kwargs,
    )
    if not converged:
        # non-converged replicates contribute nothing to the bootstrap CIs
        return None

    frames = []
    cf_risks = get_counterfactual_risks(
        updated_estimates, key_1=key_1, key_0=key_0
    )[["Event", "Time", "Group", "Pt Est"]]
    cf_risks["type"] = "risks"
    frames.append(cf_risks)
    # The two ATE summaries share shape; Group = -1 marks "not group-specific"
    for estimator, label in ((ate_ratio, "rr"), (ate_diff, "rd")):
        part = estimator(updated_estimates, key_1=key_1, key_0=key_0)[
            ["Event", "Time", "Pt Est"]
        ]
        part["type"] = label
        part["Group"] = -1
        frames.append(part)
    return pd.concat(frames)
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def bootstrap_tmle_loop(
    initial_estimates: Dict[int, InitialEstimates],
    event_times: np.ndarray,
    event_indicator: np.ndarray,
    target_times: List[float],
    target_events: List[int],
    n_bootstrap: int = 100,
    n_jobs: int = -1,
    alpha: float = 0.05,
    key_1: int = 1,
    key_0: int = 0,
    stratify_by_event: bool = False,
    verbose: int = 2,
    **kwargs,
) -> Optional[pd.DataFrame]:
    """
    Run ``single_boot`` on ``n_bootstrap`` resamples in parallel and summarize.

    Parameters
    ----------
    initial_estimates: Dict[int, InitialEstimates]
        Initial estimates for each group.
    event_times: np.ndarray
        Array of event times.
    event_indicator: np.ndarray
        Array of event indicators.
    target_times: List[float]
        List of target times.
    target_events: List[int]
        List of target events.
    n_bootstrap: int
        Number of bootstrap samples.
    n_jobs: int
        Number of parallel jobs; non-positive means "let the executor decide".
    alpha: float
        Significance level for the percentile confidence intervals.
    key_1: int
        Key for group 1.
    key_0: int
        Key for group 0.
    stratify_by_event: bool
        Stratify bootstrapping by event indicator.
    verbose: int
        Verbosity level (>=1 warns on total failure, >=2 shows progress).
    kwargs
        Additional arguments passed through to tmle_update.

    Returns
    -------
    Optional[pd.DataFrame]
        Percentile CIs per (type, Event, Time, Group), or None if no
        bootstrap replicate converged.
    """
    collected = []
    with ProcessPoolExecutor(max_workers=n_jobs if n_jobs > 0 else None) as executor:
        jobs = []
        for _ in range(n_bootstrap):
            jobs.append(
                executor.submit(
                    single_boot,
                    initial_estimates,
                    event_times,
                    event_indicator,
                    target_times,
                    target_events,
                    key_1,
                    key_0,
                    stratify_by_event,
                    **kwargs,
                )
            )
        completed = as_completed(jobs)
        if verbose >= 2:
            completed = tqdm(completed, total=n_bootstrap, desc="Bootstrapping")
        for job in completed:
            # single_boot returns None for non-converged replicates
            replicate = job.result()
            if replicate is not None:
                collected.append(replicate)
    if not collected:
        if verbose >= 1:
            warnings.warn(
                "Not a single bootstrap samples converged. Bootstrapped CIs will not be available.",
                RuntimeWarning,
            )
        return None
    if verbose >= 2:
        print(
            f"TMLE converged for {len(collected)} out of {n_bootstrap} bootstrap samples."
        )
    pooled = pd.concat(collected)
    grouped = pooled.groupby(["type", "Event", "Time", "Group"])["Pt Est"]
    summary_df = grouped.agg(
        mean_bootstrap="mean",
        CI_lower=lambda x: x.quantile(alpha / 2),
        CI_upper=lambda x: x.quantile(1 - alpha / 2),
    ).reset_index()
    return summary_df
|
|
@@ -0,0 +1,254 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import pandas as pd
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
from typing import Optional, List, Union
|
|
5
|
+
|
|
6
|
+
from pytmle.g_computation import get_g_comp
|
|
7
|
+
|
|
8
|
+
@dataclass
class InitialEstimates:
    """Container for the stage-one (initial) nuisance estimates used by TMLE.

    Holds per-observation intervention values, propensity scores,
    cause-specific hazards, and event-free / censoring survival curves
    evaluated on a common time grid. Shape consistency is validated on every
    assignment via ``__setattr__``, so estimates filled in after construction
    are checked as well.
    """

    # these fields must be filled on instantiation
    times: np.ndarray
    g_star_obs: np.ndarray
    # these fields are optional and can be filled later
    propensity_scores: Optional[np.ndarray] = field(default=None)
    hazards: Optional[np.ndarray] = field(default=None)
    event_free_survival_function: Optional[np.ndarray] = field(default=None)
    censoring_survival_function: Optional[np.ndarray] = field(default=None)
    # cached number of observations; set lazily by the first validated assignment
    _length: Optional[int] = field(default=None, init=False)
    # gate for the consistency checks; the class-level default (True) is what
    # __setattr__ reads while the dataclass __init__ is still assigning fields
    _run_checks: bool = field(default=True, init=False)

    def __setattr__(self, name, value):
        """Validate estimate arrays on every assignment, then store the value."""
        # None is always accepted (optional estimates may be cleared/unset)
        if value is not None and self._run_checks:
            if name in ["propensity_scores",
                        "g_star_obs"]:
                # per-observation vectors: only the first dimension is checked
                self._check_compatibility(value, check_width=False)
            elif name in ["hazards",
                          "event_free_survival_function",
                          "censoring_survival_function"]:
                # time-indexed arrays: second dimension must match the time grid
                self._check_compatibility(value, check_width=True)
        super().__setattr__(name, value)

    def _check_compatibility(self, new_element, check_width):
        """Raise ValueError if ``new_element``'s shape is inconsistent.

        The first validated array fixes the number of observations
        (``_length``); all later arrays must match it. With
        ``check_width=True`` the second dimension must equal ``len(self.times)``.
        """
        # check that all given estimates have the same length (first dimension size)
        if self._length is None:
            self._length = len(new_element)
        elif self._length != len(new_element):
            raise ValueError(
                f"All initial estimates must have the same first dimension, got elements with sizes {self._length} and {len(new_element)}."
            )
        if check_width and ((len(new_element.shape) < 2) or (new_element.shape[1] != len(self.times))):
            raise ValueError(
                f"The second dimension of all initial estimates must be in line with the given times, got {len(self.times)} times and element of shape {new_element.shape}."
            )

    def __getitem__(self, key: Union[np.ndarray, List[int]]) -> "InitialEstimates":
        """
        Enable subsetting of an InitialEstimates object (needed for bootstrapping)

        Args:
            key (Union[np.ndarray, List[int]]): The indices of the subset.

        Returns:
            InitialEstimates: A new InitialEstimates object containing the subset.
        """
        # times is the shared grid, not per-observation, so it is not subset
        return InitialEstimates(
            times=self.times,
            g_star_obs=self.g_star_obs[key],
            propensity_scores=(
                self.propensity_scores[key]
                if self.propensity_scores is not None
                else None
            ),
            hazards=self.hazards[key] if self.hazards is not None else None,
            event_free_survival_function=(
                self.event_free_survival_function[key]
                if self.event_free_survival_function is not None
                else None
            ),
            censoring_survival_function=(
                self.censoring_survival_function[key]
                if self.censoring_survival_function is not None
                else None
            ),
        )

    def __len__(self):
        # number of observations, as fixed by the first validated assignment
        return self._length
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
@dataclass
class UpdatedEstimates(InitialEstimates):
    """Nuisance estimates prepared for / produced by the TMLE targeting step.

    Unlike InitialEstimates, all four estimate arrays are mandatory. On
    initialization the time grid is extended to include 0 and the target
    times, trimmed at the last target time, and the nuisance weights
    ``1 / max(propensity * lagged censoring survival, min_nuisance)`` are
    precomputed.
    """

    # all have to be given
    propensity_scores: np.ndarray  # type: ignore
    hazards: np.ndarray  # type: ignore  # 3-D: sliced [:, time, :] below
    event_free_survival_function: np.ndarray  # type: ignore
    censoring_survival_function: np.ndarray  # type: ignore

    # is set on initialization
    nuisance_weight: Optional[np.ndarray] = field(default=None, init=False)

    # truncation floor for the nuisance denominator; defaults to 5/(sqrt(n)*log(n))
    min_nuisance: Optional[float] = field(default=None)
    target_events: Optional[List[int]] = field(default=None)
    target_times: Optional[List[float]] = field(default=None)
    # G-computation estimate from BEFORE the TMLE update (no SE available)
    g_comp_est: Optional[pd.DataFrame] = field(default=None)
    # per-observation influence curve values
    ic: Optional[pd.DataFrame] = field(default=None)
    # summarized efficient influence curve (provides the "seEIC" column)
    summ_eic: Optional[pd.DataFrame] = field(default=None)

    def __post_init__(self):
        """Fill defaults, align the time grid with target times, set weights."""
        if self.min_nuisance is None:
            # default truncation level: 5 / (sqrt(n) * log(n))
            self.min_nuisance = (
                5
                / (len(self.propensity_scores) ** 0.5)
                / (np.log(len(self.propensity_scores)))
            )
        if self.target_times is None:
            # default if no target_times are given: only target the last time point
            self.target_times = [self.times[-1]]
        else:
            self._update_for_target_times()
        self._set_nuisance_weight()

    def _set_nuisance_weight(self):
        """Compute 1 / max(propensity * lagged censoring survival, min_nuisance)."""
        # lag the censoring survival by one grid step: S_c(t-) with S_c(0-) = 1
        lagged_censoring_survival_function = np.column_stack(
            [
                np.ones((self.censoring_survival_function.shape[0], 1)),
                self.censoring_survival_function[:, :-1],
            ],
        )
        nuisance_denominator = (
            self.propensity_scores[:, np.newaxis] * lagged_censoring_survival_function
        )
        # TODO: Add positivity check as in https://github.com/imbroglio-dc/concrete/blob/main/R/getInitialEstimate.R#L64?
        # truncate the denominator at min_nuisance to bound the weights
        self.nuisance_weight = 1 / np.maximum(nuisance_denominator, self.min_nuisance)  # type: ignore
        self._check_compatibility(self.nuisance_weight, check_width=True)

    @classmethod
    def from_initial_estimates(
        cls,
        initial_estimates: InitialEstimates,
        target_events: Optional[List[int]] = None,
        target_times: Optional[List[float]] = None,
        min_nuisance: Optional[float] = None,
    ) -> "UpdatedEstimates":
        """Alternate constructor: promote fully-populated InitialEstimates.

        Raises (via assert) if any of the four estimate arrays is still None.
        """
        assert (initial_estimates.propensity_scores is not None and
                initial_estimates.hazards is not None and
                initial_estimates.event_free_survival_function is not None and
                initial_estimates.censoring_survival_function is not None), "All initial estimates have to be provided prior to an instatiation of UpdatedEstimates."
        return cls(
            propensity_scores=initial_estimates.propensity_scores,
            hazards=initial_estimates.hazards,
            event_free_survival_function=initial_estimates.event_free_survival_function,
            censoring_survival_function=initial_estimates.censoring_survival_function,
            min_nuisance=min_nuisance,
            target_events=target_events,
            target_times=target_times,
            g_star_obs=initial_estimates.g_star_obs,
            times=initial_estimates.times,
        )

    def _update_for_target_times(self):
        """
        Updates the time-related attributes of the object to include target times (plus 0).
        This method performs the following steps:
        1. Combines and sorts the existing times and target times.
        2. Finds the indices where the target times should be inserted.
        3. Updates the `hazards`, `event_free_survival_function`, and `censoring_survival_function`
           attributes to account for the new target times by inserting appropriate values.
        4. Trims the `hazards`, `event_free_survival_function`, and `censoring_survival_function`
           attributes to only include times up to the maximum target time.
        5. Updates the `times` attribute to include the target times up to the maximum target time.
        Attributes:
            times (np.ndarray): Array of existing times.
            target_times (np.ndarray): Array of target times to be included.
            hazards (np.ndarray): Array of hazard values.
            event_free_survival_function (np.ndarray): Array of event-free survival function values.
            censoring_survival_function (np.ndarray): Array of censoring survival function values.
        """

        # Combine and sort the times
        all_times = np.sort(np.unique(np.concatenate((self.times, [0] + self.target_times))))  # type: ignore

        if len(all_times) > len(self.times):

            # Update hazards, event_free_survival_function, and censoring_survival_function
            if 0 not in self.times:
                # prepend t=0 with hazard 0 and both survival functions at 1
                self.times = np.insert(self.times, 0, 0)
                self.hazards = np.insert(self.hazards, 0, 0, axis=1)
                self.event_free_survival_function = np.insert(self.event_free_survival_function, 0, 1, axis=1)
                self.censoring_survival_function = np.insert(self.censoring_survival_function, 0, 1, axis=1)

            # Find the indices where the new times should be inserted
            insert_times = [t for t in self.target_times if t not in self.times]
            insert_indices = np.searchsorted(all_times, insert_times)

            self.times = all_times

            # NOTE(review): insert_indices are positions in the FINAL grid, so
            # the sequential np.insert below is only correct if insert_times is
            # processed in ascending order — this holds when target_times is
            # sorted; TODO confirm callers pass sorted target_times.
            hazards_new = self.hazards
            event_free_survival_function_new = self.event_free_survival_function
            censoring_survival_function_new = self.censoring_survival_function
            for idx in insert_indices:
                # inserted grid points carry hazard 0 and step-function
                # (carried-forward) survival values from the previous column
                hazards_new = np.insert(hazards_new, idx, 0, axis=1)
                event_free_survival_function_new = np.insert(
                    event_free_survival_function_new,
                    idx,
                    event_free_survival_function_new[:, idx - 1],
                    axis=1,
                )
                censoring_survival_function_new = np.insert(
                    censoring_survival_function_new,
                    idx,
                    censoring_survival_function_new[:, idx - 1],
                    axis=1,
                )
            self.hazards = hazards_new
            self.event_free_survival_function = event_free_survival_function_new
            self.censoring_survival_function = censoring_survival_function_new

        # Find the index of the maximum target time
        max_target_time = max(self.target_times)  # type: ignore
        max_index = np.searchsorted(all_times, max_target_time)
        # Keep only times up to the maximum index
        self.times = all_times[: max_index + 1]
        self.hazards = self.hazards[:, : max_index + 1, :]
        self.event_free_survival_function = self.event_free_survival_function[
            :, : max_index + 1
        ]
        self.censoring_survival_function = self.censoring_survival_function[
            :, : max_index + 1
        ]

    def predict_mean_risks(self, g_comp: bool = False) -> pd.DataFrame:
        """
        Predict the mean risks for the target events and times.
        Args:
            g_comp (bool): Flag to return the G-computation estimate instead of the TMLE estimate.
        Returns:
            pd.DataFrame: DataFrame with columns 'Event', 'Time', 'Pt Est', and 'SE' containing the mean counterfactual risks.
        """
        if g_comp:
            if self.g_comp_est is None:
                raise ValueError(
                    "g_comp_est is not available."
                )
            # return g_comp_estimate from BEFORE the TMLE update loop (standard error not available)
            pred_risk = self.g_comp_est
            pred_risk["SE"] = np.nan
        else:
            # return g_comp_estimate from AFTER the TMLE update loop
            if self.summ_eic is None or self.ic is None:
                raise ValueError(
                    "ic or summ_eic is not available."
                )
            pred_risk = get_g_comp(
                eval_times=self.times,
                hazards=self.hazards,
                total_surv=self.event_free_survival_function,
                target_time=self.target_times,  # type: ignore
                target_events=self.target_events,  # type: ignore
            )
            pred_risk = pred_risk.merge(self.summ_eic, on=["Event", "Time"])
            # standard error of the plug-in estimate: se(EIC) / sqrt(n)
            pred_risk["SE"] = pred_risk["seEIC"] / len(self)**0.5
        pred_risk = pred_risk[["Event", "Time", "Risk", "SE"]]
        pred_risk.rename(columns={"Risk": "Pt Est"}, inplace=True)
        return pred_risk
|