survivalpredict 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- survivalpredict/__init__.py +0 -0
- survivalpredict/_allen_additive.py +68 -0
- survivalpredict/_base_hazard.py +24 -0
- survivalpredict/_cox_net_ph.py +254 -0
- survivalpredict/_cox_ph_elastic_net.py +338 -0
- survivalpredict/_cox_ph_estimation.py +461 -0
- survivalpredict/_cox_ph_estimation_left_censorship.py +593 -0
- survivalpredict/_data_validation.py +64 -0
- survivalpredict/_discrete_time_ph_estimation.py +338 -0
- survivalpredict/_estimator_utils.py +48 -0
- survivalpredict/_multi_task_logistic_regression.py +138 -0
- survivalpredict/_neighbors.py +75 -0
- survivalpredict/_nonparametric.py +93 -0
- survivalpredict/_stratification.py +301 -0
- survivalpredict/datasets/__init__.py +93 -0
- survivalpredict/datasets/iranian_churn_X.txt +3150 -0
- survivalpredict/datasets/iranian_churn_col_names.txt +12 -0
- survivalpredict/datasets/iranian_churn_events.txt +3150 -0
- survivalpredict/datasets/iranian_churn_times.txt +3150 -0
- survivalpredict/datasets/kickstarter_X.txt +18093 -0
- survivalpredict/datasets/kickstarter_col_names.txt +54 -0
- survivalpredict/datasets/kickstarter_events.txt +18093 -0
- survivalpredict/datasets/kickstarter_times.txt +18093 -0
- survivalpredict/estimators.py +2123 -0
- survivalpredict/metrics.py +539 -0
- survivalpredict/model_selection.py +525 -0
- survivalpredict/pipeline.py +398 -0
- survivalpredict/strata_preprocessing.py +908 -0
- survivalpredict/survivalpredict.egg-info/PKG-INFO +210 -0
- survivalpredict/survivalpredict.egg-info/SOURCES.txt +23 -0
- survivalpredict/survivalpredict.egg-info/dependency_links.txt +1 -0
- survivalpredict/survivalpredict.egg-info/requires.txt +14 -0
- survivalpredict/survivalpredict.egg-info/top_level.txt +2 -0
- survivalpredict/tests/__init__.py +0 -0
- survivalpredict/tests/test_estimators.py +49 -0
- survivalpredict/tests/test_left_censorship.py +58 -0
- survivalpredict/tests/test_model_selection.py +36 -0
- survivalpredict/tests/test_parametric_ph_stratification.py +63 -0
- survivalpredict/tests/test_pipeline.py +152 -0
- survivalpredict/tests/test_strata_preprocessing.py +252 -0
- survivalpredict/validation.py +518 -0
- survivalpredict-0.0.1.dist-info/METADATA +253 -0
- survivalpredict-0.0.1.dist-info/RECORD +45 -0
- survivalpredict-0.0.1.dist-info/WHEEL +4 -0
- survivalpredict-0.0.1.dist-info/licenses/LICENSE +201 -0
|
File without changes
|
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
import numba as nb
|
|
2
|
+
import numpy as np
|
|
3
|
+
|
|
4
|
+
# Explicit numba signature for _estimate_allen_additive_hazard_time_weights.
# Returns a tuple (float64 2-D C-contiguous array, int64 1-D C-contiguous array)
# and takes, in order:
#   X            float64 2-D C-contiguous array
#   times        int64   1-D C-contiguous array
#   events       bool    1-D C-contiguous array
#   times_start  int64   1-D C-contiguous array
#   alpha        float64 scalar
# Callers must pass exactly these dtypes/layouts; the eager signature rejects others.
_estimate_allen_additive_hazard_time_weights_sig = nb.types.Tuple(
    (
        nb.types.Array(nb.types.float64, 2, "C", False, aligned=True),
        nb.types.Array(nb.types.int64, 1, "C", False, aligned=True),
    )
)(
    nb.types.Array(nb.types.float64, 2, "C", False, aligned=True),
    nb.types.Array(nb.types.int64, 1, "C", False, aligned=True),
    nb.types.Array(nb.types.bool_, 1, "C", False, aligned=True),
    nb.types.Array(nb.types.int64, 1, "C", False, aligned=True),
    nb.types.float64,
)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@nb.njit(_estimate_allen_additive_hazard_time_weights_sig, cache=True)
def _estimate_allen_additive_hazard_time_weights(X, times, events, times_start, alpha):
    """Estimate time-varying additive-hazard regression weights.

    For every distinct event time ``t`` a ridge-regularised least-squares fit
    regresses the indicator "this row has an event at ``t``" on ``X``, using
    only the rows at risk at ``t``: rows that entered strictly before ``t``
    (``times_start < t``) and have not exited before ``t`` (``times >= t``).

    Parameters
    ----------
    X : float64 2-D array, one row per observation.
    times : int64 1-D array of exit times.
    events : bool 1-D array, True where the exit was an observed event.
    times_start : int64 1-D array of entry times.
    alpha : ridge penalty added to the diagonal of ``X_r.T @ X_r``.

    Returns
    -------
    hazard_weights : float64 array of shape (n_event_times, n_features);
        row ``i`` holds the weights for ``hazard_weights_times[i]``. Rows for
        times with at most one at-risk observation, or whose linear system
        could not be solved, are zero.
    hazard_weights_times : sorted distinct event times.
    """
    hazard_weights_times = np.unique(times[events])
    n_features = X.shape[1]
    # zeros (not empty): rows skipped below must not contain garbage memory.
    hazard_weights = np.zeros((hazard_weights_times.shape[0], n_features))
    # Ridge regularisation belongs on the diagonal only; adding a scalar to
    # every entry of X.T @ X is not a ridge penalty.
    ridge = alpha * np.eye(n_features)

    for i, t in enumerate(hazard_weights_times):
        # At risk at t: entered strictly before t and still present at t.
        at_risk = np.logical_and(times >= t, times_start < t)
        if at_risk.sum() <= 1:
            continue

        deaths_mask = np.logical_and(times == t, events)
        death_as_target = deaths_mask[at_risk].astype(np.float64)
        X_at_risk = X[at_risk, :]

        a = np.dot(X_at_risk.T, X_at_risk) + ridge
        b = np.dot(X_at_risk.T, death_as_target)

        try:
            w = np.linalg.solve(a, b)
        except Exception:
            # Singular system: leave this time's weights at zero.
            w = np.zeros(n_features)

        hazard_weights[i, :] = w

    return hazard_weights, hazard_weights_times
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
# Explicit numba signature for _generate_hazards_at_times_from_allen_additive_hazard_weights.
# Returns a float64 2-D C-contiguous array and takes, in order:
#   X                     float64 2-D C-contiguous array
#   hazard_weights        float64 2-D C-contiguous array
#   hazard_weights_times  int64   1-D C-contiguous array
#   max_time              int64 scalar
_generate_hazards_at_times_from_allen_additive_hazard_weights_sig = nb.types.Array(
    nb.types.float64, 2, "C", False, aligned=True
)(
    nb.types.Array(nb.types.float64, 2, "C", False, aligned=True),
    nb.types.Array(nb.types.float64, 2, "C", False, aligned=True),
    nb.types.Array(nb.types.int64, 1, "C", False, aligned=True),
    nb.types.int64,
)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
@nb.njit(_generate_hazards_at_times_from_allen_additive_hazard_weights_sig, cache=True)
def _generate_hazards_at_times_from_allen_additive_hazard_weights(
    X, hazard_weights, hazard_weights_times, max_time
):
    """Expand per-event-time additive-hazard weights into a dense hazard matrix.

    Returns an array of shape ``(X.shape[0], max_time)`` that is zero
    everywhere except column ``hazard_weights_times[k]``, which holds
    ``X @ hazard_weights[k]``. Only ``min(len(hazard_weights_times),
    len(hazard_weights))`` pairs are used.
    """
    # NOTE(review): the time value itself is used as the column index, so a
    # time >= max_time would index past the end (numba does not bounds-check
    # by default). Confirm callers guarantee times < max_time.
    hazards = np.zeros((X.shape[0], max_time))
    n_pairs = min(hazard_weights_times.shape[0], hazard_weights.shape[0])
    for k in range(n_pairs):
        column = hazard_weights_times[k]
        hazards[:, column] = np.dot(X, hazard_weights[k])
    return hazards
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def _get_breslow_base_hazard(
|
|
5
|
+
risk,
|
|
6
|
+
times,
|
|
7
|
+
events,
|
|
8
|
+
max_time,
|
|
9
|
+
):
|
|
10
|
+
unique_times = np.arange(1, max_time + 1)
|
|
11
|
+
rows_at_risk_at_time = times[:, np.newaxis] > unique_times
|
|
12
|
+
|
|
13
|
+
failure_per_unique_time = np.bincount(
|
|
14
|
+
times.astype(np.int64), events, minlength=max_time + 1
|
|
15
|
+
)[1:]
|
|
16
|
+
|
|
17
|
+
risk_per_time = np.dot(risk, rows_at_risk_at_time)
|
|
18
|
+
base_hazard = np.divide(
|
|
19
|
+
failure_per_unique_time,
|
|
20
|
+
risk_per_time,
|
|
21
|
+
out=np.zeros(max_time),
|
|
22
|
+
where=risk_per_time != 0,
|
|
23
|
+
)
|
|
24
|
+
return base_hazard
|
|
@@ -0,0 +1,254 @@
|
|
|
1
|
+
from functools import reduce
|
|
2
|
+
from itertools import pairwise
|
|
3
|
+
from typing import Literal, Optional
|
|
4
|
+
|
|
5
|
+
import jax
|
|
6
|
+
import jax.numpy as jnp
|
|
7
|
+
import numpy as np
|
|
8
|
+
import optax
|
|
9
|
+
from optax import GradientTransformationExtraArgs
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def get_gradient_updater(
    gradient_updater: Literal[
        "adadelta",
        "adagrad",
        "adam",
        "adamax",
        "rmsprop",
    ] = "adam",
    learning_rate: float = 0.01,
    beta1: float = 0.9,
    beta2: float = 0.999,
    epsilon: float = 0.0000001,
    rho: float = 0.95,
    decay: float = 0.9,
) -> GradientTransformationExtraArgs:
    """Build the optax optimizer named by ``gradient_updater``.

    Parameters
    ----------
    gradient_updater : optimizer name.
    learning_rate : step size for every optimizer.
    beta1, beta2 : first/second moment decay rates (adam, adamax).
    epsilon : numerical-stability constant passed as ``eps`` to every optimizer.
    rho : decay rate for adadelta.
    decay : decay rate for rmsprop.

    Returns
    -------
    The configured optax gradient transformation.

    Raises
    ------
    ValueError
        If ``gradient_updater`` is not one of the supported names.
    """
    if gradient_updater == "adadelta":
        # optax names this keyword ``eps`` (``esp`` raised TypeError).
        return optax.adadelta(learning_rate=learning_rate, rho=rho, eps=epsilon)

    if gradient_updater == "adagrad":
        return optax.adagrad(learning_rate=learning_rate, eps=epsilon)

    if gradient_updater == "adam":
        return optax.adam(learning_rate=learning_rate, b1=beta1, b2=beta2, eps=epsilon)

    if gradient_updater == "adamax":
        # Second-moment rate is beta2 (was mistakenly beta1).
        return optax.adamax(
            learning_rate=learning_rate, b1=beta1, b2=beta2, eps=epsilon
        )

    if gradient_updater == "rmsprop":
        return optax.rmsprop(learning_rate=learning_rate, decay=decay, eps=epsilon)

    raise ValueError(
        f"Unknown gradient_updater {gradient_updater!r}; expected one of "
        "'adadelta', 'adagrad', 'adam', 'adamax', 'rmsprop'."
    )
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def relu_jax(x):
    """Rectified linear unit for JAX arrays: negative entries become 0, others pass through unchanged."""
    is_negative = x < 0
    return jnp.where(is_negative, 0, x)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def _reverse_cumsum_jax(a):
|
|
50
|
+
return jnp.flip(jnp.cumsum(jnp.flip(a)))
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def get_init_weights(
|
|
54
|
+
input_n_cols: int,
|
|
55
|
+
hidden_layers: list[int],
|
|
56
|
+
init_dis: Literal["uniform", "normal"] = "uniform",
|
|
57
|
+
):
|
|
58
|
+
weight_matrix_shapes = list(pairwise([input_n_cols] + hidden_layers))
|
|
59
|
+
if init_dis == "uniform":
|
|
60
|
+
initializer = jax.nn.initializers.he_uniform()
|
|
61
|
+
else:
|
|
62
|
+
initializer = jax.nn.initializers.he_normal()
|
|
63
|
+
|
|
64
|
+
jax_key = jax.random.key(np.random.randint(-10000, 10000))
|
|
65
|
+
|
|
66
|
+
weight_matrix_shapes.append((hidden_layers[-1], 1))
|
|
67
|
+
weights = [initializer(jax_key, shape=ws) for ws in weight_matrix_shapes]
|
|
68
|
+
weights[-1] = weights[-1].flatten()
|
|
69
|
+
return weights
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def _get_cox_net_ph_loss(
    weights,
    X_strata: list[np.ndarray[tuple[int, int], np.dtype[np.floating]]],
    n_strata: int,
    events_strata: list[np.ndarray[tuple[int], np.dtype[np.bool_]]],
    time_end_return_inverse_strata: list[np.ndarray[tuple[int], np.dtype[np.integer]]],
    n_unique_times_strata: list[int],
    alpha=0.0,
    l1_ratio=0.5,
    time_start_return_inverse_strata: (
        list[np.ndarray[tuple[int], np.dtype[np.integer]]] | None
    ) = None,
):
    """Penalised negative Breslow partial log-likelihood of the Cox network.

    The network applies ``relu(a @ W)`` for each weight in ``weights[:-1]`` and
    then a linear output ``o = h @ weights[-1]``. The per-stratum losses
    ``-sum(events * (o - log(risk_set)))`` are summed and an elastic-net
    penalty over all weights is added:
    ``alpha * l1_ratio * sum|w| + 0.5 * alpha * (1 - l1_ratio) * sum w**2``.

    Parameters
    ----------
    weights : list of weight arrays (must support ``weights[:-1]`` list slicing).
    X_strata, events_strata : per-stratum feature matrices and event flags.
    n_strata : number of strata to iterate over.
    time_end_return_inverse_strata : per-stratum index of each row's exit time
        into that stratum's unique-time grid.
    n_unique_times_strata : per-stratum length of the unique-time grid.
    alpha, l1_ratio : elastic-net strength and L1 share.
    time_start_return_inverse_strata : per-stratum index of each row's entry
        time; when given, risk sets account for delayed entry.

    Returns
    -------
    Scalar JAX value (differentiable w.r.t. ``weights``).
    """

    uses_left_censorship = time_start_return_inverse_strata is not None

    partial_log_likelihood_per_strata = []

    # Penalty terms are shared by all strata, so compute them once.
    abs_weights_sum = reduce(lambda a, b: a + b, [jnp.sum(jnp.abs(w)) for w in weights])
    square_weights_sum = reduce(
        lambda a, b: a + b, [jnp.sum(jnp.square(w)) for w in weights]
    )

    for s_i in range(n_strata):
        X = X_strata[s_i]
        time_end_return_inverse = time_end_return_inverse_strata[s_i]
        n_unique_times = n_unique_times_strata[s_i]
        events = events_strata[s_i]
        if uses_left_censorship:
            time_start_return_inverse = time_start_return_inverse_strata[s_i]

        # Forward pass through the hidden layers; with no hidden layers the
        # reduce returns X unchanged.
        matrixs_for_reduce_dot = [X] + weights[:-1]

        second_to_last_layer = reduce(
            lambda a, b: relu_jax(
                jnp.dot(
                    a,
                    b,
                )
            ),
            matrixs_for_reduce_dot,
        )

        # Log relative risk per row and its exponential.
        o = jnp.dot(second_to_last_layer, weights[-1])
        o_exp = jnp.exp(o)

        if uses_left_censorship:
            # Net risk after each unique time = cumulative (entered - exited).
            # Reading index (end - 1) gives the risk just before a row's exit:
            # rows that entered before it and have not yet exited.
            # NOTE(review): an end index of 0 would wrap to the last element;
            # this assumes every row's start precedes its end — confirm.
            risk_removed_at_time = jnp.bincount(
                time_end_return_inverse, weights=o_exp, minlength=n_unique_times
            )
            risk_added_at_time = jnp.bincount(
                time_start_return_inverse, weights=o_exp, minlength=n_unique_times
            )
            risk_at_time = jnp.cumsum(risk_added_at_time - risk_removed_at_time)
            risk_set = risk_at_time[time_end_return_inverse - 1]
        else:
            # Risk set at a row's exit time: all rows exiting at or after it.
            risk_set = _reverse_cumsum_jax(
                jnp.bincount(
                    time_end_return_inverse, weights=o_exp, minlength=n_unique_times
                )
            )[time_end_return_inverse]

        # Only rows with an observed event contribute to the partial likelihood.
        loss = -jnp.sum(events * (o - jnp.log(risk_set)))

        partial_log_likelihood_per_strata.append(loss)

    l1 = alpha * l1_ratio * abs_weights_sum
    l2 = 0.5 * alpha * (1.0 - l1_ratio) * square_weights_sum

    return reduce(lambda a, b: a + b, partial_log_likelihood_per_strata + [l1, l2])
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
def train_cox_net_ph(
    X_strata: list[np.ndarray[tuple[int, int], np.dtype[np.floating]]],
    n_strata: int,
    events_strata: list[np.ndarray[tuple[int], np.dtype[np.bool_]]],
    time_end_return_inverse_strata: list[np.ndarray[tuple[int], np.dtype[np.integer]]],
    n_unique_times_strata: list[int],
    hidden_layers: list[int],
    weights: Optional[list[np.ndarray]] = None,
    alpha: float = 0.0,
    l1_ratio: float = 0.5,
    init_dis: Literal["uniform", "normal"] = "uniform",
    track_loss=True,
    max_iter=100,
    gradient_updater: Literal[
        "adadelta",
        "adagrad",
        "adam",
        "adamax",
        "rmsprop",
    ] = "adam",
    learning_rate: float = 0.01,
    beta1: float = 0.9,
    beta2: float = 0.999,
    epsilon: float = 0.0000001,
    rho: float = 0.95,
    decay: float = 0.9,
    time_start_return_inverse_strata: (
        list[np.ndarray[tuple[int], np.dtype[np.integer]]] | None
    ) = None,
) -> tuple[list[np.ndarray], float, list[float]]:
    """Train the Cox network for ``max_iter`` optax steps.

    Parameters
    ----------
    weights : starting weights; when None they are created by
        ``get_init_weights`` from the first stratum's column count,
        ``hidden_layers`` and ``init_dis``.
    track_loss : when True, record the loss after every step.
    Remaining parameters are forwarded to ``_get_cox_net_ph_loss`` (data and
    penalty) or ``get_gradient_updater`` (optimizer settings).

    Returns
    -------
    (weights as NumPy arrays, loss after the final step, per-step losses —
    empty when ``track_loss`` is False).
    """

    def objective(current_weights):
        # The data is closed over so jax.grad differentiates w.r.t. the weights only.
        return _get_cox_net_ph_loss(
            current_weights,
            X_strata,
            n_strata,
            events_strata,
            time_end_return_inverse_strata,
            n_unique_times_strata,
            alpha,
            l1_ratio,
            time_start_return_inverse_strata,
        )

    params = (
        weights
        if weights is not None
        else get_init_weights(X_strata[0].shape[1], hidden_layers, init_dis)
    )

    optimizer = get_gradient_updater(
        gradient_updater, learning_rate, beta1, beta2, epsilon, rho, decay
    )
    optimizer_state = optimizer.init(params)
    objective_grad = jax.grad(objective)

    loss_history = []
    latest_loss = None

    for _ in range(max_iter):
        grads = objective_grad(params)
        updates, optimizer_state = optimizer.update(grads, optimizer_state, params)
        params = optax.apply_updates(params, updates)

        if track_loss:
            latest_loss = objective(params).item()
            loss_history.append(latest_loss)

    if latest_loss is None:
        latest_loss = objective(params).item()

    return [np.array(p) for p in params], latest_loss, loss_history
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
def relu_np(x):
    """Rectified linear unit for NumPy arrays: negative entries become 0, others pass through unchanged."""
    negative_mask = x < 0
    return np.where(negative_mask, 0, x)
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
def get_relative_risk_from_cox_net_ph_weights(
    X: np.ndarray[tuple[int, int], np.dtype[np.floating]], weights: list[np.ndarray]
) -> np.ndarray[tuple[int], np.dtype[np.floating]]:
    """Relative risk ``exp(network(X))`` for trained Cox-network weights.

    Every weight except the last is applied as ``relu(h @ W)``; the last weight
    vector produces the log-risk, which is exponentiated.
    """
    hidden_weights = weights[:-1]
    activations = X
    for layer in hidden_weights:
        pre_activation = np.dot(activations, layer)
        activations = np.where(pre_activation < 0, 0, pre_activation)

    output_weights = weights[-1]
    return np.exp(np.dot(activations, output_weights))
|
|
@@ -0,0 +1,338 @@
|
|
|
1
|
+
import numba as nb
|
|
2
|
+
import numpy as np
|
|
3
|
+
|
|
4
|
+
from ._stratification import _unique_with_return_inverse
|
|
5
|
+
|
|
6
|
+
# Explicit numba signature for get_breslow_neg_log_likelihood_with_elasticnet_penalty.
# Returns a float64 scalar and takes, in order:
#   weights              float64 1-D C-contiguous array
#   X                    float64 2-D C-contiguous array
#   event                bool    1-D C-contiguous array
#   time_return_inverse  int64   1-D C-contiguous array
#   n_unique_times       int64
#   alpha, l1_ratio      float64
#   scaled               bool
get_breslow_neg_log_likelihood_with_elasticnet_penalty_signature = nb.types.float64(
    nb.types.Array(nb.types.float64, 1, "C", False, aligned=True),
    nb.types.Array(nb.types.float64, 2, "C", False, aligned=True),
    nb.types.Array(nb.types.bool_, 1, "C", False, aligned=True),
    nb.types.Array(nb.types.int64, 1, "C", False, aligned=True),
    nb.types.int64,
    nb.types.float64,
    nb.types.float64,
    nb.types.bool_,
)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@nb.njit(get_breslow_neg_log_likelihood_with_elasticnet_penalty_signature, cache=True)
def get_breslow_neg_log_likelihood_with_elasticnet_penalty(
    weights,
    X,
    event,
    time_return_inverse,
    n_unique_times,
    alpha,
    l1_ratio,
    scaled,
):
    """Negative Breslow partial log-likelihood plus an elastic-net penalty.

    The penalty is ``alpha * l1_ratio * sum|w| + 0.5 * alpha * (1 - l1_ratio) * sum w**2``.
    When ``scaled`` is True the likelihood term is multiplied by ``2 / n_rows``
    before the penalty is added; otherwise it is used as is.
    """
    linear_predictor = np.dot(X, weights)
    hazard_ratio = np.exp(linear_predictor)

    # Hazard ratio leaving at each unique time, then a suffix sum so entry u is
    # the total hazard ratio of rows still at risk at time u.
    exits_per_time = np.bincount(
        time_return_inverse, weights=hazard_ratio, minlength=n_unique_times
    )
    at_risk_per_time = np.flip(np.cumsum(np.flip(exits_per_time)))
    denominators = at_risk_per_time[time_return_inverse]

    neg_log_likelihood = -np.sum(event * (linear_predictor - np.log(denominators)))

    penalty = (alpha * l1_ratio * np.abs(weights).sum()) + (
        0.5 * alpha * (1.0 - l1_ratio) * np.square(weights).sum()
    )

    if not scaled:
        return neg_log_likelihood + penalty

    n_rows = float(X.shape[0])
    return (2 / n_rows * neg_log_likelihood) + penalty
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
@nb.njit
def soft_threasholding_operator(z, t):
    """Soft-thresholding operator ``sign(z) * max(|z| - t, 0)``.

    Shrinks ``z`` toward zero by ``t`` and returns 0 when ``|z| <= t``; used
    for the L1 part of the elastic-net coordinate update.
    """
    shrunk_magnitude = np.fmax(np.abs(z) - t, 0)
    return shrunk_magnitude * np.sign(z)
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
# Explicit numba signature for train_cox_elastic_net_regularization_paths.
# Returns (weights: float64 1-D C-contiguous array, loss: float64) and takes:
#   X                      float64 2-D C-contiguous array
#   times                  int64   1-D C-contiguous array
#   events                 bool    1-D C-contiguous array
#   alpha, l1_ratio, tol   float64
#   n_iter                 int64
train_cox_elastic_net_signature = nb.types.Tuple(
    (
        nb.types.Array(nb.types.float64, 1, "C", False, aligned=True),
        nb.types.float64,
    )
)(
    nb.types.Array(nb.types.float64, 2, "C", False, aligned=True),
    nb.types.Array(nb.types.int64, 1, "C", False, aligned=True),
    nb.types.Array(nb.types.bool_, 1, "C", False, aligned=True),
    nb.types.float64,
    nb.types.float64,
    nb.types.float64,
    nb.types.int64,
)
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
@nb.njit(train_cox_elastic_net_signature, cache=True)
def train_cox_elastic_net_regularization_paths(
    X: np.ndarray,
    times: np.ndarray,
    events: np.ndarray,
    alpha: float,
    l1_ratio: float,
    tol: float,
    n_iter: int,
) -> tuple[np.ndarray, float]:
    """Fit an elastic-net penalised Cox model by coordinate descent, following
    Simon N, Friedman J, Hastie T, Tibshirani R. Regularization Paths for Cox's Proportional Hazards Model via Coordinate Descent. J Stat Softw. 2011 Mar;39(5):1-13. doi: 10.18637/jss.v039.i05. PMID: 27065756; PMCID: PMC4824408.

    Each outer iteration builds the per-observation IRLS weights ``w_k`` and
    working responses ``z_k`` of the paper's quadratic approximation at the
    current coefficients. It then runs one cycle of soft-thresholded
    coordinate updates over all features.

    Parameters
    ----------
    X : float64 2-D feature matrix.
    times : int64 exit times.
    events : bool event indicators.
    alpha : overall penalty strength.
    l1_ratio : share of the penalty that is L1.
    tol : tolerance on the scaled objective. An update that increases it by
        more than ``tol`` is rejected and iteration stops. An accepted update
        that changes it by at most ``tol`` is treated as convergence.
    n_iter : maximum number of outer iterations.

    Returns
    -------
    (weights, unscaled penalised negative log partial likelihood at weights)
    """
    unique_times, time_return_inverse = _unique_with_return_inverse(times)
    n_unique_times = len(unique_times)

    n = X.shape[0]
    n_features = X.shape[1]

    weights = np.zeros(n_features)
    last_loss = get_breslow_neg_log_likelihood_with_elasticnet_penalty(
        weights, X, events, time_return_inverse, n_unique_times, alpha, l1_ratio, True
    )

    l1_penalty = alpha * l1_ratio
    l2_penalty = alpha * (1.0 - l1_ratio)

    for _ in range(n_iter):
        new_weights = weights.copy()
        eta = np.dot(X, new_weights)
        eta_exp = np.exp(eta)

        # Risk-set sum at each row's time: rows whose time is >= it.
        risk_set = np.flip(
            np.cumsum(
                np.flip(
                    np.bincount(
                        time_return_inverse, weights=eta_exp, minlength=n_unique_times
                    )
                )
            )
        )[time_return_inverse]

        # For each row k, sum over events i with t_i <= t_k (the event times at
        # which k is at risk) of 1/R_i and 1/R_i**2.
        cum_inv_risk = np.cumsum(
            np.bincount(
                time_return_inverse,
                weights=events * (1.0 / risk_set),
                minlength=n_unique_times,
            )
        )[time_return_inverse]
        cum_inv_risk_sq = np.cumsum(
            np.bincount(
                time_return_inverse,
                weights=events * (1.0 / risk_set**2),
                minlength=n_unique_times,
            )
        )[time_return_inverse]

        # Per-observation diagonal Hessian w_k and working response z_k.
        w = eta_exp * cum_inv_risk - eta_exp**2 * cum_inv_risk_sq
        z = eta.copy()
        for k in range(n):
            if w[k] > 0.0:
                delta_k = 1.0 if events[k] else 0.0
                z[k] = eta[k] + (delta_k - eta_exp[k] * cum_inv_risk[k]) / w[k]
            else:
                # Row is never at risk at an event time (or rounding): no weight.
                w[k] = 0.0

        # One coordinate-descent cycle. `fitted` tracks X @ new_weights so the
        # partial residual excluding feature j costs O(n) per feature.
        fitted = eta.copy()
        for j in range(n_features):
            x_j = X[:, j]
            partial_residual = z - fitted + x_j * new_weights[j]
            numerator = soft_threasholding_operator(
                np.sum(w * x_j * partial_residual) / n, l1_penalty
            )
            denominator = np.sum(w * x_j**2) / n + l2_penalty
            updated = numerator / denominator if denominator > 0.0 else 0.0
            fitted += x_j * (updated - new_weights[j])
            new_weights[j] = updated

        # Evaluate the candidate (not the previous) coefficients.
        loss = get_breslow_neg_log_likelihood_with_elasticnet_penalty(
            new_weights,
            X,
            events,
            time_return_inverse,
            n_unique_times,
            alpha,
            l1_ratio,
            True,
        )

        if not np.isfinite(loss) or (loss - last_loss) > tol:
            # Objective got worse (or diverged): keep the previous coefficients.
            break

        weights = new_weights
        converged = abs(last_loss - loss) <= tol
        last_loss = loss
        if converged:
            break

    final_loss = get_breslow_neg_log_likelihood_with_elasticnet_penalty(
        weights, X, events, time_return_inverse, n_unique_times, alpha, l1_ratio, False
    )

    return weights, final_loss
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
# Explicit numba signature for
# get_breslow_neg_log_likelihood_with_elasticnet_penalty_with_left_censorship.
# Returns a float64 scalar and takes, in order:
#   weights                    float64 1-D C-contiguous array
#   X                          float64 2-D C-contiguous array
#   event                      bool    1-D C-contiguous array
#   time_end_return_inverse    int64   1-D C-contiguous array
#   time_start_return_inverse  int64   1-D C-contiguous array
#   n_unique_times             int64
#   alpha, l1_ratio            float64
#   scaled                     bool
get_breslow_neg_log_likelihood_with_elasticnet_penalty_with_left_censorship_signature = nb.types.float64(
    nb.types.Array(nb.types.float64, 1, "C", False, aligned=True),
    nb.types.Array(nb.types.float64, 2, "C", False, aligned=True),
    nb.types.Array(nb.types.bool_, 1, "C", False, aligned=True),
    nb.types.Array(nb.types.int64, 1, "C", False, aligned=True),
    nb.types.Array(nb.types.int64, 1, "C", False, aligned=True),
    nb.types.int64,
    nb.types.float64,
    nb.types.float64,
    nb.types.bool_,
)
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
@nb.njit(
    get_breslow_neg_log_likelihood_with_elasticnet_penalty_with_left_censorship_signature,
    cache=True,
)
def get_breslow_neg_log_likelihood_with_elasticnet_penalty_with_left_censorship(
    weights,
    X,
    event,
    time_end_return_inverse,
    time_start_return_inverse,
    n_unique_times,
    alpha,
    l1_ratio,
    scaled,
):
    """Penalised negative Breslow partial log-likelihood with delayed entry.

    The risk set of a row is built from a running balance of hazard ratio that
    entered (start index) minus hazard ratio that exited (end index). The value
    just before the row's own end index is used. The elastic-net penalty and
    the optional ``2 / n_rows`` scaling match the right-censored version.
    """
    linear_predictor = np.dot(X, weights)
    hazard_ratio = np.exp(linear_predictor)

    leaving = np.bincount(
        time_end_return_inverse, weights=hazard_ratio, minlength=n_unique_times
    )
    entering = np.bincount(
        time_start_return_inverse, weights=hazard_ratio, minlength=n_unique_times
    )
    population = np.cumsum(entering - leaving)
    # NOTE(review): an end index of 0 would wrap to the last element; this
    # presumes each row's start precedes its end — confirm with callers.
    denominators = population[time_end_return_inverse - 1]

    neg_log_likelihood = -np.sum(event * (linear_predictor - np.log(denominators)))

    penalty = (alpha * l1_ratio * np.abs(weights).sum()) + (
        0.5 * alpha * (1.0 - l1_ratio) * np.square(weights).sum()
    )

    if not scaled:
        return neg_log_likelihood + penalty

    n_rows = float(X.shape[0])
    return (2 / n_rows * neg_log_likelihood) + penalty
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
# Explicit numba signature for train_cox_elastic_net_with_left_censorship.
# Returns (weights: float64 1-D C-contiguous array, loss: float64) and takes:
#   X                      float64 2-D C-contiguous array
#   times                  int64   1-D C-contiguous array (exit times)
#   times_start            int64   1-D C-contiguous array (entry times)
#   events                 bool    1-D C-contiguous array
#   alpha, l1_ratio, tol   float64
#   n_iter                 int64
train_cox_elastic_net_with_left_censorship_signature = nb.types.Tuple(
    (
        nb.types.Array(nb.types.float64, 1, "C", False, aligned=True),
        nb.types.float64,
    )
)(
    nb.types.Array(nb.types.float64, 2, "C", False, aligned=True),
    nb.types.Array(nb.types.int64, 1, "C", False, aligned=True),
    nb.types.Array(nb.types.int64, 1, "C", False, aligned=True),
    nb.types.Array(nb.types.bool_, 1, "C", False, aligned=True),
    nb.types.float64,
    nb.types.float64,
    nb.types.float64,
    nb.types.int64,
)
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
@nb.njit(train_cox_elastic_net_with_left_censorship_signature, cache=True)
def train_cox_elastic_net_with_left_censorship(
    X: np.ndarray,
    times: np.ndarray,
    times_start: np.ndarray,
    events: np.ndarray,
    alpha: float,
    l1_ratio: float,
    tol: float,
    n_iter: int,
) -> tuple[np.ndarray, float]:
    """Elastic-net Cox coordinate descent with delayed entry, following
    Simon N, Friedman J, Hastie T, Tibshirani R. Regularization Paths for Cox's Proportional Hazards Model via Coordinate Descent. J Stat Softw. 2011 Mar;39(5):1-13. doi: 10.18637/jss.v039.i05. PMID: 27065756; PMCID: PMC4824408.

    Row k is at risk at event time t when ``times_start[k] < t <= times[k]``.
    Each outer iteration builds per-observation IRLS weights ``w_k`` and working
    responses ``z_k`` from those risk sets. It then runs one cycle of
    soft-thresholded coordinate updates.

    Parameters
    ----------
    X : float64 2-D feature matrix.
    times : int64 exit times.
    times_start : int64 entry times (same length as ``times``).
    events : bool event indicators.
    alpha, l1_ratio : elastic-net strength and L1 share.
    tol : tolerance on the scaled objective. An update that increases it by
        more than ``tol`` is rejected and iteration stops. An accepted update
        that changes it by at most ``tol`` is treated as convergence.
    n_iter : maximum number of outer iterations.

    Returns
    -------
    (weights, unscaled penalised negative log partial likelihood at weights)
    """
    # Index start and end times on one shared unique-time grid.
    all_times = np.concatenate((times_start, times))
    unique_times, unique_times_return_inverse = _unique_with_return_inverse(all_times)

    n_start = len(times_start)
    time_start_return_inverse = unique_times_return_inverse[:n_start]
    time_end_return_inverse = unique_times_return_inverse[n_start:]

    n_unique_times = len(unique_times)

    n = X.shape[0]
    n_features = X.shape[1]

    weights = np.zeros(n_features)
    last_loss = (
        get_breslow_neg_log_likelihood_with_elasticnet_penalty_with_left_censorship(
            weights,
            X,
            events,
            time_end_return_inverse,
            time_start_return_inverse,
            n_unique_times,
            alpha,
            l1_ratio,
            True,
        )
    )

    l1_penalty = alpha * l1_ratio
    l2_penalty = alpha * (1.0 - l1_ratio)

    for _ in range(n_iter):
        new_weights = weights.copy()
        eta = np.dot(X, new_weights)
        eta_exp = np.exp(eta)

        # Risk set just before each row's exit: entered earlier, not yet exited.
        risk_removed_at_time = np.bincount(
            time_end_return_inverse, weights=eta_exp, minlength=n_unique_times
        )
        risk_added_at_time = np.bincount(
            time_start_return_inverse, weights=eta_exp, minlength=n_unique_times
        )
        risk_at_time = np.cumsum(risk_added_at_time - risk_removed_at_time)
        risk_set = risk_at_time[time_end_return_inverse - 1]

        # Cumulative (over the time grid) sums of 1/R and 1/R**2 at event times.
        # For row k, events i at risk-sharing times satisfy start_k < t_i <= end_k,
        # i.e. C[end_k] - C[start_k].
        cum_inv_risk = np.cumsum(
            np.bincount(
                time_end_return_inverse,
                weights=events * (1.0 / risk_set),
                minlength=n_unique_times,
            )
        )
        cum_inv_risk_sq = np.cumsum(
            np.bincount(
                time_end_return_inverse,
                weights=events * (1.0 / risk_set**2),
                minlength=n_unique_times,
            )
        )
        a_k = (
            cum_inv_risk[time_end_return_inverse]
            - cum_inv_risk[time_start_return_inverse]
        )
        b_k = (
            cum_inv_risk_sq[time_end_return_inverse]
            - cum_inv_risk_sq[time_start_return_inverse]
        )

        # Per-observation diagonal Hessian w_k and working response z_k.
        w = eta_exp * a_k - eta_exp**2 * b_k
        z = eta.copy()
        for k in range(n):
            if w[k] > 0.0:
                delta_k = 1.0 if events[k] else 0.0
                z[k] = eta[k] + (delta_k - eta_exp[k] * a_k[k]) / w[k]
            else:
                # Row never at risk at an event time (or rounding): no weight.
                w[k] = 0.0

        # One coordinate-descent cycle; `fitted` tracks X @ new_weights.
        fitted = eta.copy()
        for j in range(n_features):
            x_j = X[:, j]
            partial_residual = z - fitted + x_j * new_weights[j]
            numerator = soft_threasholding_operator(
                np.sum(w * x_j * partial_residual) / n, l1_penalty
            )
            denominator = np.sum(w * x_j**2) / n + l2_penalty
            updated = numerator / denominator if denominator > 0.0 else 0.0
            fitted += x_j * (updated - new_weights[j])
            new_weights[j] = updated

        # Evaluate the candidate (not the previous) coefficients.
        loss = (
            get_breslow_neg_log_likelihood_with_elasticnet_penalty_with_left_censorship(
                new_weights,
                X,
                events,
                time_end_return_inverse,
                time_start_return_inverse,
                n_unique_times,
                alpha,
                l1_ratio,
                True,
            )
        )

        if not np.isfinite(loss) or (loss - last_loss) > tol:
            # Objective got worse (or diverged): keep the previous coefficients.
            break

        weights = new_weights
        converged = abs(last_loss - loss) <= tol
        last_loss = loss
        if converged:
            break

    final_loss = (
        get_breslow_neg_log_likelihood_with_elasticnet_penalty_with_left_censorship(
            weights,
            X,
            events,
            time_end_return_inverse,
            time_start_return_inverse,
            n_unique_times,
            alpha,
            l1_ratio,
            False,
        )
    )

    return weights, final_loss
|