survivalpredict 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45) hide show
  1. survivalpredict/__init__.py +0 -0
  2. survivalpredict/_allen_additive.py +68 -0
  3. survivalpredict/_base_hazard.py +24 -0
  4. survivalpredict/_cox_net_ph.py +254 -0
  5. survivalpredict/_cox_ph_elastic_net.py +338 -0
  6. survivalpredict/_cox_ph_estimation.py +461 -0
  7. survivalpredict/_cox_ph_estimation_left_censorship.py +593 -0
  8. survivalpredict/_data_validation.py +64 -0
  9. survivalpredict/_discrete_time_ph_estimation.py +338 -0
  10. survivalpredict/_estimator_utils.py +48 -0
  11. survivalpredict/_multi_task_logistic_regression.py +138 -0
  12. survivalpredict/_neighbors.py +75 -0
  13. survivalpredict/_nonparametric.py +93 -0
  14. survivalpredict/_stratification.py +301 -0
  15. survivalpredict/datasets/__init__.py +93 -0
  16. survivalpredict/datasets/iranian_churn_X.txt +3150 -0
  17. survivalpredict/datasets/iranian_churn_col_names.txt +12 -0
  18. survivalpredict/datasets/iranian_churn_events.txt +3150 -0
  19. survivalpredict/datasets/iranian_churn_times.txt +3150 -0
  20. survivalpredict/datasets/kickstarter_X.txt +18093 -0
  21. survivalpredict/datasets/kickstarter_col_names.txt +54 -0
  22. survivalpredict/datasets/kickstarter_events.txt +18093 -0
  23. survivalpredict/datasets/kickstarter_times.txt +18093 -0
  24. survivalpredict/estimators.py +2123 -0
  25. survivalpredict/metrics.py +539 -0
  26. survivalpredict/model_selection.py +525 -0
  27. survivalpredict/pipeline.py +398 -0
  28. survivalpredict/strata_preprocessing.py +908 -0
  29. survivalpredict/survivalpredict.egg-info/PKG-INFO +210 -0
  30. survivalpredict/survivalpredict.egg-info/SOURCES.txt +23 -0
  31. survivalpredict/survivalpredict.egg-info/dependency_links.txt +1 -0
  32. survivalpredict/survivalpredict.egg-info/requires.txt +14 -0
  33. survivalpredict/survivalpredict.egg-info/top_level.txt +2 -0
  34. survivalpredict/tests/__init__.py +0 -0
  35. survivalpredict/tests/test_estimators.py +49 -0
  36. survivalpredict/tests/test_left_censorship.py +58 -0
  37. survivalpredict/tests/test_model_selection.py +36 -0
  38. survivalpredict/tests/test_parametric_ph_stratification.py +63 -0
  39. survivalpredict/tests/test_pipeline.py +152 -0
  40. survivalpredict/tests/test_strata_preprocessing.py +252 -0
  41. survivalpredict/validation.py +518 -0
  42. survivalpredict-0.0.1.dist-info/METADATA +253 -0
  43. survivalpredict-0.0.1.dist-info/RECORD +45 -0
  44. survivalpredict-0.0.1.dist-info/WHEEL +4 -0
  45. survivalpredict-0.0.1.dist-info/licenses/LICENSE +201 -0
File without changes
@@ -0,0 +1,68 @@
1
+ import numba as nb
2
+ import numpy as np
3
+
4
# Explicit numba signature for _estimate_allen_additive_hazard_time_weights.
# Every array is C-contiguous; the subscript form below builds the same
# numba array types as spelling out nb.types.Array(...) by hand.
_estimate_allen_additive_hazard_time_weights_sig = nb.types.Tuple(
    (
        nb.types.float64[:, ::1],  # hazard weights, one row per event time
        nb.types.int64[::1],  # the event times those rows belong to
    )
)(
    nb.types.float64[:, ::1],  # X
    nb.types.int64[::1],  # times
    nb.types.bool_[::1],  # events
    nb.types.int64[::1],  # times_start
    nb.types.float64,  # alpha
)
16
+
17
+
18
@nb.njit(_estimate_allen_additive_hazard_time_weights_sig, cache=True)
def _estimate_allen_additive_hazard_time_weights(X, times, events, times_start, alpha):
    """Fit one linear weight vector per distinct observed event time.

    For every distinct time ``t`` found in ``times[events]`` a row mask is
    built, and the system ``(X_m.T @ X_m + alpha) w = X_m.T @ d`` is solved,
    where ``X_m`` are the selected rows and ``d`` flags the rows whose event
    happens exactly at ``t``.

    Parameters
    ----------
    X : float64 array, 2-D
        Feature matrix.
    times : int64 array, 1-D
        Exit time of each row.
    events : bool array, 1-D
        True where the exit is an observed event.
    times_start : int64 array, 1-D
        Entry time of each row.
    alpha : float
        Constant added to the normal-equation matrix.

    Returns
    -------
    hazard_weights : float64 array, shape (n_event_times, n_features)
        Weight vector for each event time.  Rows stay at zero when fewer
        than two rows are selected or when the linear system cannot be
        solved.
    hazard_weights_times : int64 array
        The sorted distinct event times matching the rows above.
    """
    hazard_weights_times = np.unique(times[events])
    n_features = X.shape[1]

    # Zero-initialised (rather than np.empty) so that time points skipped by
    # the guard below carry a defined "no contribution" weight instead of
    # whatever bytes happened to be in the freshly allocated buffer.
    hazard_weights = np.zeros((hazard_weights_times.shape[0], n_features))

    for i, t in enumerate(hazard_weights_times):
        survived_at_time = times >= t
        exits = times == t
        deaths_mask = np.logical_and(exits, events)
        not_right_censored = np.logical_and(survived_at_time, deaths_mask)
        not_left_censored = times_start < t
        not_censored = np.logical_or(not_right_censored, not_left_censored)

        # At least two rows are required for the fit.
        if not_censored.sum() > 1:
            death_as_target = deaths_mask[not_censored].astype(np.float64)

            X_mask = X[not_censored, :]

            # NOTE(review): alpha is added to every entry of X_m.T @ X_m, not
            # only to its diagonal as a ridge penalty would be -- confirm
            # that this is the intended regularisation.
            a = np.dot(X_mask.T, X_mask) + alpha
            b = np.dot(X_mask.T, death_as_target)

            try:
                w = np.linalg.solve(a, b)
            except Exception:
                # Singular / unsolvable system: fall back to zero weights.
                w = np.zeros(n_features)

            hazard_weights[i, :] = w

    return hazard_weights, hazard_weights_times
47
+
48
+
49
# Explicit numba signature for
# _generate_hazards_at_times_from_allen_additive_hazard_weights: it returns a
# C-contiguous 2-D float64 array.  The subscript form builds the same numba
# array types as spelling out nb.types.Array(...) by hand.
_generate_hazards_at_times_from_allen_additive_hazard_weights_sig = (
    nb.types.float64[:, ::1]
)(
    nb.types.float64[:, ::1],  # X
    nb.types.float64[:, ::1],  # hazard_weights
    nb.types.int64[::1],  # hazard_weights_times
    nb.types.int64,  # max_time
)
57
+
58
+
59
@nb.njit(_generate_hazards_at_times_from_allen_additive_hazard_weights_sig, cache=True)
def _generate_hazards_at_times_from_allen_additive_hazard_weights(
    X, hazard_weights, hazard_weights_times, max_time
):
    """Evaluate per-time linear hazards for every row of ``X``.

    Parameters
    ----------
    X : float64 array, shape (n_rows, n_features)
        Feature matrix to score.
    hazard_weights : float64 array, shape (n_times, n_features)
        One weight vector per entry of ``hazard_weights_times``.
    hazard_weights_times : int64 array, shape (n_times,)
        Column index in the output that each weight vector fills.
    max_time : int
        Number of columns of the output.

    Returns
    -------
    float64 array, shape (n_rows, max_time)
        ``hazards[:, t] = X @ w`` for each ``(t, w)`` pair; all other columns
        are zero.

    Raises
    ------
    IndexError
        If a time index does not address a column of the output.
    """
    hazards = np.zeros((X.shape[0], max_time))

    for t, w in zip(hazard_weights_times, hazard_weights):
        # Compiled code does not bounds-check array writes by default, so an
        # index past the last column would silently corrupt memory.  Fail
        # loudly instead; indices that were valid before behave as before.
        if t >= max_time or t < -max_time:
            raise IndexError("hazard weight time is outside [0, max_time)")
        hazards[:, t] = np.dot(X, w)

    return hazards
@@ -0,0 +1,24 @@
1
+ import numpy as np
2
+
3
+
4
+ def _get_breslow_base_hazard(
5
+ risk,
6
+ times,
7
+ events,
8
+ max_time,
9
+ ):
10
+ unique_times = np.arange(1, max_time + 1)
11
+ rows_at_risk_at_time = times[:, np.newaxis] > unique_times
12
+
13
+ failure_per_unique_time = np.bincount(
14
+ times.astype(np.int64), events, minlength=max_time + 1
15
+ )[1:]
16
+
17
+ risk_per_time = np.dot(risk, rows_at_risk_at_time)
18
+ base_hazard = np.divide(
19
+ failure_per_unique_time,
20
+ risk_per_time,
21
+ out=np.zeros(max_time),
22
+ where=risk_per_time != 0,
23
+ )
24
+ return base_hazard
@@ -0,0 +1,254 @@
1
+ from functools import reduce
2
+ from itertools import pairwise
3
+ from typing import Literal, Optional
4
+
5
+ import jax
6
+ import jax.numpy as jnp
7
+ import numpy as np
8
+ import optax
9
+ from optax import GradientTransformationExtraArgs
10
+
11
+
12
def get_gradient_updater(
    gradient_updater: Literal[
        "adadelta",
        "adagrad",
        "adam",
        "adamax",
        "rmsprop",
    ] = "adam",
    learning_rate: float = 0.01,
    beta1: float = 0.9,
    beta2: float = 0.999,
    epsilon: float = 0.0000001,
    rho: float = 0.95,
    decay: float = 0.9,
) -> GradientTransformationExtraArgs:
    """Build the optax optimiser selected by ``gradient_updater``.

    Parameters
    ----------
    gradient_updater : str
        Optimiser name.  Any value other than "adadelta", "adagrad", "adam"
        or "adamax" yields rmsprop.
    learning_rate : float
        Step size handed to every optimiser.
    beta1, beta2 : float
        First / second moment decay rates (adam and adamax only).
    epsilon : float
        Numerical-stability constant, passed as ``eps`` to every optimiser.
    rho : float
        Decay rate used by adadelta only.
    decay : float
        Decay rate used by rmsprop only.

    Returns
    -------
    GradientTransformationExtraArgs
        The configured optax gradient transformation.
    """
    if gradient_updater == "adadelta":
        # optax spells the keyword "eps"; the former "esp" raised TypeError.
        return optax.adadelta(learning_rate=learning_rate, rho=rho, eps=epsilon)

    elif gradient_updater == "adagrad":
        return optax.adagrad(learning_rate=learning_rate, eps=epsilon)

    elif gradient_updater == "adam":
        return optax.adam(learning_rate=learning_rate, b1=beta1, b2=beta2, eps=epsilon)

    elif gradient_updater == "adamax":
        # b2 must receive beta2; it used to be fed beta1 by mistake.
        return optax.adamax(
            learning_rate=learning_rate, b1=beta1, b2=beta2, eps=epsilon
        )

    else:  # gradient_updater == 'rmsprop':
        return optax.rmsprop(learning_rate=learning_rate, decay=decay, eps=epsilon)
43
+
44
+
45
def relu_jax(x):
    """Element-wise ReLU for JAX arrays: negative entries are replaced by 0."""
    is_negative = x < 0
    return jnp.where(is_negative, 0, x)
47
+
48
+
49
+ def _reverse_cumsum_jax(a):
50
+ return jnp.flip(jnp.cumsum(jnp.flip(a)))
51
+
52
+
53
+ def get_init_weights(
54
+ input_n_cols: int,
55
+ hidden_layers: list[int],
56
+ init_dis: Literal["uniform", "normal"] = "uniform",
57
+ ):
58
+ weight_matrix_shapes = list(pairwise([input_n_cols] + hidden_layers))
59
+ if init_dis == "uniform":
60
+ initializer = jax.nn.initializers.he_uniform()
61
+ else:
62
+ initializer = jax.nn.initializers.he_normal()
63
+
64
+ jax_key = jax.random.key(np.random.randint(-10000, 10000))
65
+
66
+ weight_matrix_shapes.append((hidden_layers[-1], 1))
67
+ weights = [initializer(jax_key, shape=ws) for ws in weight_matrix_shapes]
68
+ weights[-1] = weights[-1].flatten()
69
+ return weights
70
+
71
+
72
+ def _get_cox_net_ph_loss(
73
+ weights,
74
+ X_strata: list[np.ndarray[tuple[int, int], np.dtype[np.floating]]],
75
+ n_strata: int,
76
+ events_strata: list[np.ndarray[tuple[int], np.dtype[np.bool_]]],
77
+ time_end_return_inverse_strata: list[np.ndarray[tuple[int], np.dtype[np.integer]]],
78
+ n_unique_times_strata: list[int],
79
+ alpha=0.0,
80
+ l1_ratio=0.5,
81
+ time_start_return_inverse_strata: (
82
+ list[np.ndarray[tuple[int], np.dtype[np.integer]]] | None
83
+ ) = None,
84
+ ):
85
+
86
+ uses_left_censorship = time_start_return_inverse_strata is not None
87
+
88
+ partial_log_likelihood_per_strata = []
89
+
90
+ abs_weights_sum = reduce(lambda a, b: a + b, [jnp.sum(jnp.abs(w)) for w in weights])
91
+ square_weights_sum = reduce(
92
+ lambda a, b: a + b, [jnp.sum(jnp.square(w)) for w in weights]
93
+ )
94
+
95
+ for s_i in range(n_strata):
96
+ X = X_strata[s_i]
97
+ time_end_return_inverse = time_end_return_inverse_strata[s_i]
98
+ n_unique_times = n_unique_times_strata[s_i]
99
+ events = events_strata[s_i]
100
+ if uses_left_censorship:
101
+ time_start_return_inverse = time_start_return_inverse_strata[s_i]
102
+
103
+ matrixs_for_reduce_dot = [X] + weights[:-1]
104
+
105
+ second_to_last_layer = reduce(
106
+ lambda a, b: relu_jax(
107
+ jnp.dot(
108
+ a,
109
+ b,
110
+ )
111
+ ),
112
+ matrixs_for_reduce_dot,
113
+ )
114
+
115
+ o = jnp.dot(second_to_last_layer, weights[-1])
116
+ o_exp = jnp.exp(o)
117
+
118
+ if uses_left_censorship:
119
+ risk_removed_at_time = jnp.bincount(
120
+ time_end_return_inverse, weights=o_exp, minlength=n_unique_times
121
+ )
122
+ risk_added_at_time = jnp.bincount(
123
+ time_start_return_inverse, weights=o_exp, minlength=n_unique_times
124
+ )
125
+ risk_at_time = jnp.cumsum(risk_added_at_time - risk_removed_at_time)
126
+ risk_set = risk_at_time[time_end_return_inverse - 1]
127
+ else:
128
+ risk_set = _reverse_cumsum_jax(
129
+ jnp.bincount(
130
+ time_end_return_inverse, weights=o_exp, minlength=n_unique_times
131
+ )
132
+ )[time_end_return_inverse]
133
+
134
+ loss = -jnp.sum(events * (o - jnp.log(risk_set)))
135
+
136
+ partial_log_likelihood_per_strata.append(loss)
137
+
138
+ l1 = alpha * l1_ratio * abs_weights_sum
139
+ l2 = 0.5 * alpha * (1.0 - l1_ratio) * square_weights_sum
140
+
141
+ return reduce(lambda a, b: a + b, partial_log_likelihood_per_strata + [l1, l2])
142
+
143
+
144
def train_cox_net_ph(
    X_strata: list[np.ndarray[tuple[int, int], np.dtype[np.floating]]],
    n_strata: int,
    events_strata: list[np.ndarray[tuple[int], np.dtype[np.bool_]]],
    time_end_return_inverse_strata: list[np.ndarray[tuple[int], np.dtype[np.integer]]],
    n_unique_times_strata: list[int],
    hidden_layers: list[int],
    weights: Optional[list[np.ndarray]] = None,
    alpha: float = 0.0,
    l1_ratio: float = 0.5,
    init_dis: Literal["uniform", "normal"] = "uniform",
    track_loss=True,
    max_iter=100,
    gradient_updater: Literal[
        "adadelta",
        "adagrad",
        "adam",
        "adamax",
        "rmsprop",
    ] = "adam",
    learning_rate: float = 0.01,
    beta1: float = 0.9,
    beta2: float = 0.999,
    epsilon: float = 0.0000001,
    rho: float = 0.95,
    decay: float = 0.9,
    time_start_return_inverse_strata: (
        list[np.ndarray[tuple[int], np.dtype[np.integer]]] | None
    ) = None,
) -> tuple[list[np.ndarray], float, list[float]]:
    """Train the Cox network for ``max_iter`` full-batch gradient steps.

    Parameters
    ----------
    X_strata, n_strata, events_strata, time_end_return_inverse_strata,
    n_unique_times_strata, alpha, l1_ratio, time_start_return_inverse_strata
        Forwarded unchanged to ``_get_cox_net_ph_loss``.
    hidden_layers, init_dis
        Used to draw initial weights when ``weights`` is None.
    weights : list or None
        Starting weights; freshly initialised when None.
    track_loss : bool
        When true the loss is re-evaluated after every update and recorded.
    max_iter : int
        Number of optimiser steps.
    gradient_updater, learning_rate, beta1, beta2, epsilon, rho, decay
        Optimiser selection and hyper-parameters, forwarded to
        ``get_gradient_updater``.

    Returns
    -------
    weights : list[np.ndarray]
        Final weights converted to NumPy arrays.
    loss : float
        Loss at the final weights.
    losses_per_steps : list[float]
        Loss after each step; empty when ``track_loss`` is false.
    """
    if weights is None:
        weights = get_init_weights(X_strata[0].shape[1], hidden_layers, init_dis)

    grad_updater = get_gradient_updater(
        gradient_updater, learning_rate, beta1, beta2, epsilon, rho, decay
    )
    opt_state = grad_updater.init(weights)

    # Everything the loss needs besides the weights themselves; identical for
    # every gradient and loss evaluation below.
    loss_args = (
        X_strata,
        n_strata,
        events_strata,
        time_end_return_inverse_strata,
        n_unique_times_strata,
        alpha,
        l1_ratio,
        time_start_return_inverse_strata,
    )
    loss_gradient = jax.grad(_get_cox_net_ph_loss)

    losses_per_steps = []
    loss = None

    for _ in range(max_iter):
        gradients = loss_gradient(weights, *loss_args)
        updates, opt_state = grad_updater.update(gradients, opt_state, weights)
        weights = optax.apply_updates(weights, updates)

        if not track_loss:
            continue

        loss = _get_cox_net_ph_loss(weights, *loss_args).item()
        losses_per_steps.append(loss)

    # No loss was tracked (tracking disabled or zero iterations): evaluate it
    # once at the final weights.
    if loss is None:
        loss = _get_cox_net_ph_loss(weights, *loss_args).item()

    weights_np = [np.array(w) for w in weights]

    return weights_np, loss, losses_per_steps
233
+
234
+
235
def relu_np(x):
    """Element-wise ReLU for NumPy arrays: negative entries are replaced by 0."""
    is_negative = x < 0
    return np.where(is_negative, 0, x)
237
+
238
+
239
def get_relative_risk_from_cox_net_ph_weights(
    X: np.ndarray[tuple[int, int], np.dtype[np.floating]], weights: list[np.ndarray]
) -> np.ndarray[tuple[int], np.dtype[np.floating]]:
    """Relative risk ``exp(score)`` of every row under trained network weights.

    Each matrix in ``weights[:-1]`` is applied as ``relu(a @ W)``; the final
    entry of ``weights`` maps the last hidden layer to the score that is
    exponentiated.
    """
    activations = X
    for layer_weights in weights[:-1]:
        activations = relu_np(np.dot(activations, layer_weights))

    score = np.dot(activations, weights[-1])
    return np.exp(score)
@@ -0,0 +1,338 @@
1
+ import numba as nb
2
+ import numpy as np
3
+
4
+ from ._stratification import _unique_with_return_inverse
5
+
6
# Explicit numba signature for
# get_breslow_neg_log_likelihood_with_elasticnet_penalty (returns a float64).
# The subscript form builds the same C-contiguous numba array types as
# spelling out nb.types.Array(...) by hand.
get_breslow_neg_log_likelihood_with_elasticnet_penalty_signature = nb.types.float64(
    nb.types.float64[::1],  # weights
    nb.types.float64[:, ::1],  # X
    nb.types.bool_[::1],  # event
    nb.types.int64[::1],  # time_return_inverse
    nb.types.int64,  # n_unique_times
    nb.types.float64,  # alpha
    nb.types.float64,  # l1_ratio
    nb.types.bool_,  # scaled
)
16
+
17
+
18
@nb.njit(get_breslow_neg_log_likelihood_with_elasticnet_penalty_signature, cache=True)
def get_breslow_neg_log_likelihood_with_elasticnet_penalty(
    weights,
    X,
    event,
    time_return_inverse,
    n_unique_times,
    alpha,
    l1_ratio,
    scaled,
):
    """Breslow negative partial log-likelihood plus an elastic-net penalty.

    Parameters
    ----------
    weights : float64 array, 1-D
        Coefficient vector.
    X : float64 array, 2-D
        Feature matrix.
    event : bool array, 1-D
        Event indicator per row.
    time_return_inverse : int64 array, 1-D
        Index of every row's time among the distinct times.
    n_unique_times : int
        Number of distinct times.
    alpha : float
        Overall penalty strength.
    l1_ratio : float
        Share of the penalty assigned to the L1 term.
    scaled : bool
        When true, the likelihood term is multiplied by ``2 / n_rows`` before
        the penalty is added.

    Returns
    -------
    float
        The penalised loss.
    """
    n = float(X.shape[0])

    linear_predictor = np.dot(X, weights)
    relative_risk = np.exp(linear_predictor)

    # Summed relative risk per distinct time, then accumulated from the last
    # time backwards so that entry k covers every row whose time index >= k.
    risk_per_time = np.bincount(
        time_return_inverse, weights=relative_risk, minlength=n_unique_times
    )
    risk_set_at_time = np.flip(np.cumsum(np.flip(risk_per_time)))
    risk_set = risk_set_at_time[time_return_inverse]

    breslow_neg_log_likelihood = -np.sum(
        event * (linear_predictor - np.log(risk_set))
    )

    l1 = alpha * l1_ratio * np.abs(weights).sum()
    l2 = 0.5 * alpha * (1.0 - l1_ratio) * np.square(weights).sum()
    elasticnet_loss = l1 + l2

    if not scaled:
        return breslow_neg_log_likelihood + elasticnet_loss
    return (2 / n * breslow_neg_log_likelihood) + elasticnet_loss
56
+
57
+
58
@nb.njit
def soft_threasholding_operator(z, t):
    """Soft-thresholding: shrink ``|z|`` by ``t``, floor at 0, keep the sign of ``z``."""
    shrunk_magnitude = np.fmax((np.abs(z) - t), 0)
    return shrunk_magnitude * np.sign(z)
61
+
62
+
63
# Explicit numba signature for train_cox_elastic_net_regularization_paths.
# It returns (coefficient vector, final loss).  The subscript form builds the
# same C-contiguous numba array types as spelling out nb.types.Array(...).
train_cox_elastic_net_signature = nb.types.Tuple(
    (
        nb.types.float64[::1],  # fitted weights
        nb.types.float64,  # final loss
    )
)(
    nb.types.float64[:, ::1],  # X
    nb.types.int64[::1],  # times
    nb.types.bool_[::1],  # events
    nb.types.float64,  # alpha
    nb.types.float64,  # l1_ratio
    nb.types.float64,  # tol
    nb.types.int64,  # n_iter
)
77
+
78
+
79
@nb.njit(train_cox_elastic_net_signature, cache=True)
def train_cox_elastic_net_regularization_paths(
    X: np.ndarray,
    times: np.ndarray,
    events: np.ndarray,
    alpha: float,
    l1_ratio: float,
    tol: float,
    n_iter: int,
) -> tuple[np.ndarray, float]:
    """Fit elastic-net penalised Cox coefficients by coordinate descent.

    Based on
    Simon N, Friedman J, Hastie T, Tibshirani R. Regularization Paths for Cox's Proportional Hazards Model via Coordinate Descent. J Stat Softw. 2011 Mar;39(5):1-13. doi: 10.18637/jss.v039.i05. PMID: 27065756; PMCID: PMC4824408.

    Parameters
    ----------
    X : float64 array, 2-D
        Feature matrix.
    times : int64 array, 1-D
        Exit time per row.
    events : bool array, 1-D
        Event indicator per row.
    alpha : float
        Overall penalty strength.
    l1_ratio : float
        Share of the penalty assigned to the L1 term.
    tol : float
        A sweep is rejected, and iteration stops, when it raises the scaled
        penalised loss by more than this amount.
    n_iter : int
        Maximum number of full coordinate sweeps.

    Returns
    -------
    weights : float64 array
        Coefficients of the last accepted sweep (zeros if none was accepted).
    final_loss : float
        Unscaled penalised loss at ``weights``.
    """
    unique_times, time_return_inverse = _unique_with_return_inverse(times)
    n_unique_times = len(unique_times)

    n = X.shape[0]
    n_cols = X.shape[1]
    col_indexes = np.arange(n_cols)

    # Penalty terms do not change between sweeps or coordinates.
    l1_threshold = alpha * l1_ratio
    l2_term = alpha * (1 - l1_ratio)

    weights = np.zeros(n_cols)
    last_loss = np.inf

    for _ in range(n_iter):

        new_weights = weights.copy()
        n_hat = np.dot(X, new_weights)
        n_hat_exp = np.exp(n_hat)
        n_hat_exp_risk_set = np.flip(
            np.cumsum(
                np.flip(
                    np.bincount(
                        time_return_inverse, n_hat_exp, minlength=n_unique_times
                    )
                )
            )
        )[time_return_inverse]

        # NOTE(review): wn is collapsed into a single scalar by the sum; the
        # referenced paper appears to define one weight per observation --
        # confirm this simplification is intended.
        wn = np.sum(
            events
            * (
                (n_hat_exp * n_hat_exp_risk_set - n_hat_exp**2)
                / (n_hat_exp_risk_set**2)
            )
        )
        zn = n_hat + 1 / wn * (
            events - np.sum(events * (n_hat_exp / n_hat_exp_risk_set))
        )
        for j in range(n_cols):
            not_j = col_indexes[col_indexes != j]
            not_j_n = np.dot(X[:, not_j], new_weights[not_j])
            left = 1 / n * np.sum(wn * X[:, j] * (zn - not_j_n))
            top = soft_threasholding_operator(left, l1_threshold)
            bottom = 1 / n * np.sum(wn * (X[:, j]) ** 2 + l2_term)
            new_weights[j] = top / bottom

        # Judge the candidate produced by this sweep, not the weights it
        # started from: otherwise a loss-increasing iterate is accepted first
        # and only detected (and returned) one sweep later.
        loss = get_breslow_neg_log_likelihood_with_elasticnet_penalty(
            new_weights,
            X,
            events,
            time_return_inverse,
            n_unique_times,
            alpha,
            l1_ratio,
            True,
        )

        if (loss - last_loss) > tol:
            break
        last_loss = loss
        weights = new_weights

    final_loss = get_breslow_neg_log_likelihood_with_elasticnet_penalty(
        weights, X, events, time_return_inverse, n_unique_times, alpha, l1_ratio, False
    )

    return weights, final_loss
157
+
158
+
159
# Explicit numba signature for
# get_breslow_neg_log_likelihood_with_elasticnet_penalty_with_left_censorship
# (returns a float64).  The subscript form builds the same C-contiguous numba
# array types as spelling out nb.types.Array(...) by hand.
get_breslow_neg_log_likelihood_with_elasticnet_penalty_with_left_censorship_signature = nb.types.float64(
    nb.types.float64[::1],  # weights
    nb.types.float64[:, ::1],  # X
    nb.types.bool_[::1],  # event
    nb.types.int64[::1],  # time_end_return_inverse
    nb.types.int64[::1],  # time_start_return_inverse
    nb.types.int64,  # n_unique_times
    nb.types.float64,  # alpha
    nb.types.float64,  # l1_ratio
    nb.types.bool_,  # scaled
)
170
+
171
+
172
@nb.njit(
    get_breslow_neg_log_likelihood_with_elasticnet_penalty_with_left_censorship_signature,
    cache=True,
)
def get_breslow_neg_log_likelihood_with_elasticnet_penalty_with_left_censorship(
    weights,
    X,
    event,
    time_end_return_inverse,
    time_start_return_inverse,
    n_unique_times,
    alpha,
    l1_ratio,
    scaled,
):
    """Penalised Breslow negative log-likelihood with per-row entry times.

    Parameters
    ----------
    weights : float64 array, 1-D
        Coefficient vector.
    X : float64 array, 2-D
        Feature matrix.
    event : bool array, 1-D
        Event indicator per row.
    time_end_return_inverse : int64 array, 1-D
        Index of every row's exit time among the distinct times.
    time_start_return_inverse : int64 array, 1-D
        Index of every row's entry time among the distinct times.
    n_unique_times : int
        Number of distinct times.
    alpha : float
        Overall penalty strength.
    l1_ratio : float
        Share of the penalty assigned to the L1 term.
    scaled : bool
        When true, the likelihood term is multiplied by ``2 / n_rows`` before
        the penalty is added.

    Returns
    -------
    float
        The penalised loss.
    """
    n = float(X.shape[0])

    linear_predictor = np.dot(X, weights)
    relative_risk = np.exp(linear_predictor)

    risk_removed_at_time = np.bincount(
        time_end_return_inverse, weights=relative_risk, minlength=n_unique_times
    )
    risk_added_at_time = np.bincount(
        time_start_return_inverse, weights=relative_risk, minlength=n_unique_times
    )
    # Running total of the risk that has entered minus the risk that has
    # left, read at the time index just before each row's exit.
    risk_at_time = np.cumsum(risk_added_at_time - risk_removed_at_time)
    risk_set = risk_at_time[time_end_return_inverse - 1]

    breslow_neg_log_likelihood = -np.sum(
        event * (linear_predictor - np.log(risk_set))
    )

    l1 = alpha * l1_ratio * np.abs(weights).sum()
    l2 = 0.5 * alpha * (1.0 - l1_ratio) * np.square(weights).sum()
    elasticnet_loss = l1 + l2

    if not scaled:
        return breslow_neg_log_likelihood + elasticnet_loss
    return (2 / n * breslow_neg_log_likelihood) + elasticnet_loss
211
+
212
+
213
# Explicit numba signature for train_cox_elastic_net_with_left_censorship.
# It returns (coefficient vector, final loss).  The subscript form builds the
# same C-contiguous numba array types as spelling out nb.types.Array(...).
train_cox_elastic_net_with_left_censorship_signature = nb.types.Tuple(
    (
        nb.types.float64[::1],  # fitted weights
        nb.types.float64,  # final loss
    )
)(
    nb.types.float64[:, ::1],  # X
    nb.types.int64[::1],  # times
    nb.types.int64[::1],  # times_start
    nb.types.bool_[::1],  # events
    nb.types.float64,  # alpha
    nb.types.float64,  # l1_ratio
    nb.types.float64,  # tol
    nb.types.int64,  # n_iter
)
228
+
229
+
230
@nb.njit(train_cox_elastic_net_with_left_censorship_signature, cache=True)
def train_cox_elastic_net_with_left_censorship(
    X: np.ndarray,
    times: np.ndarray,
    times_start: np.ndarray,
    events: np.ndarray,
    alpha: float,
    l1_ratio: float,
    tol: float,
    n_iter: int,
) -> tuple[np.ndarray, float]:
    """Coordinate-descent elastic-net Cox fit with per-row entry times.

    Based on
    Simon N, Friedman J, Hastie T, Tibshirani R. Regularization Paths for Cox's Proportional Hazards Model via Coordinate Descent. J Stat Softw. 2011 Mar;39(5):1-13. doi: 10.18637/jss.v039.i05. PMID: 27065756; PMCID: PMC4824408.

    Parameters
    ----------
    X : float64 array, 2-D
        Feature matrix.
    times : int64 array, 1-D
        Exit time per row.
    times_start : int64 array, 1-D
        Entry time per row.
    events : bool array, 1-D
        Event indicator per row.
    alpha : float
        Overall penalty strength.
    l1_ratio : float
        Share of the penalty assigned to the L1 term.
    tol : float
        A sweep is rejected, and iteration stops, when it raises the scaled
        penalised loss by more than this amount.
    n_iter : int
        Maximum number of full coordinate sweeps.

    Returns
    -------
    weights : float64 array
        Coefficients of the last accepted sweep (zeros if none was accepted).
    final_loss : float
        Unscaled penalised loss at ``weights``.
    """
    # Entry and exit times share one index space: entries occupy the first
    # len(times_start) positions of the concatenation, exits the remainder.
    all_times = np.concatenate((times_start, times))
    unique_times, unique_times_return_inverse = _unique_with_return_inverse(all_times)

    n_start = len(times_start)
    time_start_return_inverse = unique_times_return_inverse[:n_start]
    time_end_return_inverse = unique_times_return_inverse[n_start:]

    n_unique_times = len(unique_times)

    n = X.shape[0]
    n_cols = X.shape[1]
    col_indexes = np.arange(n_cols)

    # Penalty terms do not change between sweeps or coordinates.
    l1_threshold = alpha * l1_ratio
    l2_term = alpha * (1 - l1_ratio)

    weights = np.zeros(n_cols)
    last_loss = np.inf

    for _ in range(n_iter):

        new_weights = weights.copy()
        n_hat = np.dot(X, new_weights)

        n_hat_exp = np.exp(n_hat)

        # Risk set: running total of risk that has entered minus risk that
        # has left, read at the time index just before each row's exit.
        risk_removed_at_time = np.bincount(
            time_end_return_inverse, weights=n_hat_exp, minlength=n_unique_times
        )
        risk_added_at_time = np.bincount(
            time_start_return_inverse, weights=n_hat_exp, minlength=n_unique_times
        )
        risk_at_time = np.cumsum(risk_added_at_time - risk_removed_at_time)
        n_hat_exp_risk_set = risk_at_time[time_end_return_inverse - 1]

        # NOTE(review): wn is collapsed into a single scalar by the sum; the
        # referenced paper appears to define one weight per observation --
        # confirm this simplification is intended.
        wn = np.sum(
            events
            * (
                (n_hat_exp * n_hat_exp_risk_set - n_hat_exp**2)
                / (n_hat_exp_risk_set**2)
            )
        )
        zn = n_hat + 1 / wn * (
            events - np.sum(events * (n_hat_exp / n_hat_exp_risk_set))
        )
        for j in range(n_cols):
            not_j = col_indexes[col_indexes != j]
            not_j_n = np.dot(X[:, not_j], new_weights[not_j])
            left = 1 / n * np.sum(wn * X[:, j] * (zn - not_j_n))
            top = soft_threasholding_operator(left, l1_threshold)
            bottom = 1 / n * np.sum(wn * (X[:, j]) ** 2 + l2_term)
            new_weights[j] = top / bottom

        # Judge the candidate produced by this sweep, not the weights it
        # started from: otherwise a loss-increasing iterate is accepted first
        # and only detected (and returned) one sweep later.
        loss = (
            get_breslow_neg_log_likelihood_with_elasticnet_penalty_with_left_censorship(
                new_weights,
                X,
                events,
                time_end_return_inverse,
                time_start_return_inverse,
                n_unique_times,
                alpha,
                l1_ratio,
                True,
            )
        )

        if (loss - last_loss) > tol:
            break
        last_loss = loss
        weights = new_weights

    final_loss = (
        get_breslow_neg_log_likelihood_with_elasticnet_penalty_with_left_censorship(
            weights,
            X,
            events,
            time_end_return_inverse,
            time_start_return_inverse,
            n_unique_times,
            alpha,
            l1_ratio,
            False,
        )
    )

    return weights, final_loss