survival 1.1.36__cp314-cp314-macosx_10_12_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

# ruff: noqa: N803, N806, UP037
from __future__ import annotations

from collections.abc import Iterator
from typing import TYPE_CHECKING

import numpy as np

from survival import _survival as _surv

if TYPE_CHECKING:
    from numpy.typing import ArrayLike, NDArray

try:
    from sklearn.base import BaseEstimator, RegressorMixin
    from sklearn.utils.validation import check_array, check_is_fitted

    _HAS_SKLEARN = True
except ImportError:
    _HAS_SKLEARN = False

    class BaseEstimator:
        def get_params(self, deep: bool = True) -> dict:
            return {
                k: getattr(self, k)
                for k in self.__init__.__code__.co_varnames[1 : self.__init__.__code__.co_argcount]
            }

        def set_params(self, **params) -> "BaseEstimator":
            for key, value in params.items():
                setattr(self, key, value)
            return self

    class RegressorMixin:
        pass

    def check_array(X, **kwargs):
        # Minimal stand-in for sklearn's check_array that honors the dtype
        # and ensure_2d options the estimators below rely on.
        X = np.asarray(X, dtype=kwargs.get("dtype"))
        if kwargs.get("ensure_2d", False) and X.ndim != 2:
            raise ValueError(f"Expected a 2D array, got {X.ndim}D instead")
        return X

    def check_is_fitted(estimator, attributes=None):
        if not hasattr(estimator, "is_fitted_") or not estimator.is_fitted_:
            raise ValueError(f"{type(estimator).__name__} is not fitted yet.")


def _validate_survival_data(
    X: ArrayLike, y: ArrayLike
) -> tuple[NDArray[np.float64], NDArray[np.float64], NDArray[np.int32]]:
    X = check_array(X, dtype=np.float64, ensure_2d=True)
    y = np.asarray(y)

    if y.ndim != 2:
        raise ValueError("y must be a 2D array with columns [time, status]")
    if y.shape[1] != 2:
        raise ValueError("y must have exactly 2 columns: [time, status]")

    time = y[:, 0].astype(np.float64)
    status = y[:, 1].astype(np.int32)

    if X.shape[0] != len(time):
        raise ValueError(f"X has {X.shape[0]} samples, but y has {len(time)} samples")

    return X, time, status
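
# Hypothetical construction of the target array expected throughout this
# module: column 0 is follow-up time, column 1 is event status (1 = event
# observed, 0 = censored).
#
#     times = np.array([5.0, 8.2, 3.1])
#     events = np.array([1, 0, 1])
#     y = np.column_stack([times, events])  # shape (n_samples, 2)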


def _compute_concordance_index(
    time: NDArray[np.float64],
    status: NDArray[np.int32],
    risk_scores: NDArray[np.float64],
) -> float:
    """Compute concordance index (C-index) for survival predictions."""
    n = len(time)
    concordant = 0.0
    comparable = 0.0

    for i in range(n):
        if status[i] == 0:
            continue
        for j in range(n):
            if i == j:
                continue
            if time[i] < time[j]:
                comparable += 1.0
                if risk_scores[i] > risk_scores[j]:
                    concordant += 1.0
                elif risk_scores[i] == risk_scores[j]:
                    concordant += 0.5

    return concordant / comparable if comparable > 0 else 0.5
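
# A minimal sanity check (hypothetical values): with all events observed and
# risk scores ranking subjects in exactly the order they fail, every
# comparable pair is concordant and the C-index is 1.0.
#
#     time = np.array([1.0, 2.0, 3.0])
#     status = np.array([1, 1, 1], dtype=np.int32)
#     risk = np.array([3.0, 2.0, 1.0])
#     _compute_concordance_index(time, status, risk)  # -> 1.0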


class SurvivalScoreMixin:
    """Mixin providing concordance index scoring for survival models."""

    def score(self, X: ArrayLike, y: ArrayLike) -> float:
        """Return the concordance index on the given test data.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Test samples.
        y : array-like of shape (n_samples, 2)
            True target values.

        Returns
        -------
        score : float
            Concordance index (C-index), between 0 and 1.
        """
        check_is_fitted(self)
        X, time, status = _validate_survival_data(X, y)
        risk_scores = self.predict(X)
        return _compute_concordance_index(time, status, risk_scores)


class CoxPHEstimator(SurvivalScoreMixin, BaseEstimator, RegressorMixin):
    """Scikit-learn compatible Cox Proportional Hazards model.

    Parameters
    ----------
    n_iters : int, default=20
        Maximum number of iterations for the Newton-Raphson optimization.

    Attributes
    ----------
    model_ : CoxPHModel
        The underlying fitted Cox model.
    coef_ : ndarray of shape (n_features,)
        Estimated coefficients.
    n_features_in_ : int
        Number of features seen during fit.

    Examples
    --------
    >>> from survival.sklearn_compat import CoxPHEstimator
    >>> import numpy as np
    >>> X = np.random.randn(100, 3)
    >>> y = np.column_stack([np.random.exponential(10, 100), np.random.binomial(1, 0.7, 100)])
    >>> model = CoxPHEstimator()
    >>> model.fit(X, y)
    >>> risk_scores = model.predict(X)
    """

    def __init__(self, n_iters: int = 20):
        self.n_iters = n_iters

    def fit(self, X: ArrayLike, y: ArrayLike) -> "CoxPHEstimator":
        """Fit the Cox PH model.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data.
        y : array-like of shape (n_samples, 2)
            Target values where y[:, 0] is survival time and y[:, 1] is event status.

        Returns
        -------
        self : CoxPHEstimator
            Fitted estimator.
        """
        X, time, status = _validate_survival_data(X, y)
        self.n_features_in_ = X.shape[1]

        covariates = X.tolist()
        self.model_ = _surv.CoxPHModel.new_with_data(covariates, time.tolist(), status.tolist())
        self.model_.fit(self.n_iters)

        self.coef_ = np.array(self.model_.coefficients)
        self.is_fitted_ = True
        return self

    def predict(self, X: ArrayLike) -> NDArray[np.float64]:
        """Predict risk scores for samples.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Samples to predict.

        Returns
        -------
        risk_scores : ndarray of shape (n_samples,)
            Predicted risk scores (higher = higher risk).
        """
        check_is_fitted(self)
        X = check_array(X, dtype=np.float64, ensure_2d=True)

        if X.shape[1] != self.n_features_in_:
            raise ValueError(
                f"X has {X.shape[1]} features, but model expects {self.n_features_in_}"
            )

        return np.array(self.model_.predict(X.tolist()))

    def predict_survival_function(
        self, X: ArrayLike, times: ArrayLike | None = None
    ) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
        """Predict survival function for samples.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Samples to predict.
        times : array-like of shape (n_times,), optional
            Time points at which to evaluate the survival function.

        Returns
        -------
        times : ndarray of shape (n_times,)
            Time points.
        survival : ndarray of shape (n_samples, n_times)
            Survival probabilities.
        """
        check_is_fitted(self)
        X = check_array(X, dtype=np.float64, ensure_2d=True)

        times_list = np.asarray(times, dtype=np.float64).tolist() if times is not None else None
        t, surv = self.model_.survival_curve(X.tolist(), times_list)
        return np.array(t), np.array(surv)

    def predict_median_survival_time(self, X: ArrayLike) -> NDArray[np.float64]:
        """Predict median survival time for samples.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Samples to predict.

        Returns
        -------
        median_times : ndarray of shape (n_samples,)
            Predicted median survival times (NaN if survival never drops below 0.5).
        """
        check_is_fitted(self)
        X = check_array(X, dtype=np.float64, ensure_2d=True)

        result = self.model_.predicted_survival_time(X.tolist(), 0.5)
        return np.array([t if t is not None else np.nan for t in result])
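
# Hypothetical usage of the survival-curve API on a fitted Cox model (data
# shapes are illustrative, not part of the API):
#
#     model = CoxPHEstimator().fit(X, y)
#     times, surv = model.predict_survival_function(X[:5])
#     # surv[i, j] is the estimated P(T > times[j]) for sample i
#     medians = model.predict_median_survival_time(X[:5])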


class GradientBoostSurvivalEstimator(SurvivalScoreMixin, BaseEstimator, RegressorMixin):
    """Scikit-learn compatible Gradient Boosting Survival model.

    Parameters
    ----------
    n_estimators : int, default=100
        Number of boosting iterations.
    learning_rate : float, default=0.1
        Learning rate shrinks the contribution of each tree.
    max_depth : int, default=3
        Maximum depth of the individual regression trees.
    min_samples_split : int, default=10
        Minimum number of samples required to split an internal node.
    min_samples_leaf : int, default=5
        Minimum number of samples required at each leaf node.
    subsample : float, default=1.0
        Fraction of samples used for fitting individual trees.
    max_features : int or None, default=None
        Number of features to consider for splits.
    seed : int or None, default=None
        Random seed for reproducibility.

    Attributes
    ----------
    model_ : GradientBoostSurvival
        The underlying fitted model.
    feature_importances_ : ndarray of shape (n_features,)
        Feature importances.
    n_features_in_ : int
        Number of features seen during fit.
    """

    def __init__(
        self,
        n_estimators: int = 100,
        learning_rate: float = 0.1,
        max_depth: int = 3,
        min_samples_split: int = 10,
        min_samples_leaf: int = 5,
        subsample: float = 1.0,
        max_features: int | None = None,
        seed: int | None = None,
    ):
        self.n_estimators = n_estimators
        self.learning_rate = learning_rate
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.min_samples_leaf = min_samples_leaf
        self.subsample = subsample
        self.max_features = max_features
        self.seed = seed

    def fit(self, X: ArrayLike, y: ArrayLike) -> "GradientBoostSurvivalEstimator":
        """Fit the gradient boosting survival model.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data.
        y : array-like of shape (n_samples, 2)
            Target values where y[:, 0] is survival time and y[:, 1] is event status.

        Returns
        -------
        self : GradientBoostSurvivalEstimator
            Fitted estimator.
        """
        X, time, status = _validate_survival_data(X, y)
        self.n_features_in_ = X.shape[1]
        n_obs = X.shape[0]

        config = _surv.GradientBoostSurvivalConfig(
            n_estimators=self.n_estimators,
            learning_rate=self.learning_rate,
            max_depth=self.max_depth,
            min_samples_split=self.min_samples_split,
            min_samples_leaf=self.min_samples_leaf,
            subsample=self.subsample,
            max_features=self.max_features,
            seed=self.seed,
        )

        x_flat = X.flatten().tolist()
        self.model_ = _surv.GradientBoostSurvival.fit(
            x_flat, n_obs, self.n_features_in_, time.tolist(), status.tolist(), config
        )

        self.feature_importances_ = np.array(self.model_.feature_importance)
        self.is_fitted_ = True
        return self

    def predict(self, X: ArrayLike) -> NDArray[np.float64]:
        """Predict risk scores for samples.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Samples to predict.

        Returns
        -------
        risk_scores : ndarray of shape (n_samples,)
            Predicted risk scores (higher = higher risk).
        """
        check_is_fitted(self)
        X = check_array(X, dtype=np.float64, ensure_2d=True)

        if X.shape[1] != self.n_features_in_:
            raise ValueError(
                f"X has {X.shape[1]} features, but model expects {self.n_features_in_}"
            )

        x_flat = X.flatten().tolist()
        return np.array(self.model_.predict_risk(x_flat, X.shape[0]))

    def predict_survival_function(
        self, X: ArrayLike
    ) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
        """Predict survival function for samples.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Samples to predict.

        Returns
        -------
        times : ndarray of shape (n_times,)
            Time points.
        survival : ndarray of shape (n_samples, n_times)
            Survival probabilities.
        """
        check_is_fitted(self)
        X = check_array(X, dtype=np.float64, ensure_2d=True)

        x_flat = X.flatten().tolist()
        survival = self.model_.predict_survival(x_flat, X.shape[0])
        return np.array(self.model_.unique_times), np.array(survival)

    def predict_median_survival_time(self, X: ArrayLike) -> NDArray[np.float64]:
        """Predict median survival time for samples.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Samples to predict.

        Returns
        -------
        median_times : ndarray of shape (n_samples,)
            Predicted median survival times (NaN if survival never drops below 0.5).
        """
        check_is_fitted(self)
        X = check_array(X, dtype=np.float64, ensure_2d=True)

        x_flat = X.flatten().tolist()
        result = self.model_.predict_median_survival_time(x_flat, X.shape[0])
        return np.array([t if t is not None else np.nan for t in result])
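
# Hypothetical usage sketch for the boosted model (shapes illustrative):
#
#     gbm = GradientBoostSurvivalEstimator(n_estimators=200, seed=42).fit(X, y)
#     ranked = np.argsort(gbm.feature_importances_)[::-1]  # most important first
#     cindex = gbm.score(X_test, y_test)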


class SurvivalForestEstimator(SurvivalScoreMixin, BaseEstimator, RegressorMixin):
    """Scikit-learn compatible Random Survival Forest model.

    Parameters
    ----------
    n_trees : int, default=500
        Number of trees in the forest.
    max_depth : int or None, default=None
        Maximum depth of trees.
    min_node_size : int, default=15
        Minimum number of samples at each leaf node.
    mtry : int or None, default=None
        Number of features to consider at each split (default: sqrt(n_features)).
    sample_fraction : float, default=0.632
        Fraction of samples used for each tree.
    seed : int or None, default=None
        Random seed for reproducibility.
    oob_error : bool, default=True
        Whether to compute out-of-bag error.

    Attributes
    ----------
    model_ : SurvivalForest
        The underlying fitted model.
    variable_importance_ : ndarray of shape (n_features,)
        Variable importances.
    oob_error_ : float or None
        Out-of-bag error (if computed).
    n_features_in_ : int
        Number of features seen during fit.
    """

    def __init__(
        self,
        n_trees: int = 500,
        max_depth: int | None = None,
        min_node_size: int = 15,
        mtry: int | None = None,
        sample_fraction: float = 0.632,
        seed: int | None = None,
        oob_error: bool = True,
    ):
        self.n_trees = n_trees
        self.max_depth = max_depth
        self.min_node_size = min_node_size
        self.mtry = mtry
        self.sample_fraction = sample_fraction
        self.seed = seed
        self.oob_error = oob_error

    def fit(self, X: ArrayLike, y: ArrayLike) -> "SurvivalForestEstimator":
        """Fit the random survival forest model.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data.
        y : array-like of shape (n_samples, 2)
            Target values where y[:, 0] is survival time and y[:, 1] is event status.

        Returns
        -------
        self : SurvivalForestEstimator
            Fitted estimator.
        """
        X, time, status = _validate_survival_data(X, y)
        self.n_features_in_ = X.shape[1]
        n_obs = X.shape[0]

        config = _surv.SurvivalForestConfig(
            n_trees=self.n_trees,
            max_depth=self.max_depth,
            min_node_size=self.min_node_size,
            mtry=self.mtry,
            sample_fraction=self.sample_fraction,
            seed=self.seed,
            oob_error=self.oob_error,
        )

        x_flat = X.flatten().tolist()
        self.model_ = _surv.SurvivalForest.fit(
            x_flat, n_obs, self.n_features_in_, time.tolist(), status.tolist(), config
        )

        self.variable_importance_ = np.array(self.model_.variable_importance)
        self.oob_error_ = self.model_.oob_error
        self.is_fitted_ = True
        return self

    def predict(self, X: ArrayLike) -> NDArray[np.float64]:
        """Predict risk scores for samples.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Samples to predict.

        Returns
        -------
        risk_scores : ndarray of shape (n_samples,)
            Predicted risk scores (cumulative hazard at last time point).
        """
        check_is_fitted(self)
        X = check_array(X, dtype=np.float64, ensure_2d=True)

        if X.shape[1] != self.n_features_in_:
            raise ValueError(
                f"X has {X.shape[1]} features, but model expects {self.n_features_in_}"
            )

        x_flat = X.flatten().tolist()
        return np.array(self.model_.predict_risk(x_flat, X.shape[0]))

    def predict_survival_function(
        self, X: ArrayLike
    ) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
        """Predict survival function for samples.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Samples to predict.

        Returns
        -------
        times : ndarray of shape (n_times,)
            Time points.
        survival : ndarray of shape (n_samples, n_times)
            Survival probabilities.
        """
        check_is_fitted(self)
        X = check_array(X, dtype=np.float64, ensure_2d=True)

        x_flat = X.flatten().tolist()
        survival = self.model_.predict_survival(x_flat, X.shape[0])
        return np.array(self.model_.unique_times), np.array(survival)

    def predict_median_survival_time(self, X: ArrayLike) -> NDArray[np.float64]:
        """Predict median survival time for samples.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Samples to predict.

        Returns
        -------
        median_times : ndarray of shape (n_samples,)
            Predicted median survival times (NaN if survival never drops below 0.5).
        """
        check_is_fitted(self)
        X = check_array(X, dtype=np.float64, ensure_2d=True)

        x_flat = X.flatten().tolist()
        result = self.model_.predict_median_survival_time(x_flat, X.shape[0])
        return np.array([t if t is not None else np.nan for t in result])
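
# Hypothetical usage sketch; with oob_error=True the forest reports an
# out-of-bag error estimate without needing a separate validation split:
#
#     rsf = SurvivalForestEstimator(n_trees=500, seed=0).fit(X, y)
#     rsf.oob_error_            # OOB error estimate (lower is better)
#     rsf.variable_importance_  # per-feature importance scores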


class AFTEstimator(BaseEstimator, RegressorMixin):
    """Scikit-learn compatible Accelerated Failure Time (AFT) model.

    AFT models assume that covariates act multiplicatively on the survival time,
    i.e., log(T) = X @ beta + sigma * epsilon, where epsilon follows a specified
    error distribution.

    Parameters
    ----------
    distribution : str, default="weibull"
        Error distribution. One of:
        - "weibull": Weibull distribution (extreme value errors)
        - "lognormal": Log-normal distribution (Gaussian errors)
        - "loglogistic": Log-logistic distribution (logistic errors)
        - "exponential": Exponential distribution (special case of Weibull)
        - "gaussian": Gaussian distribution (for linear models)
        - "logistic": Logistic distribution (for linear models)
    max_iter : int, default=200
        Maximum number of iterations for optimization.
    tol : float, default=1e-9
        Convergence tolerance.

    Attributes
    ----------
    model_ : SurvivalFit
        The underlying fitted AFT model.
    coef_ : ndarray of shape (n_features,)
        Estimated coefficients (acceleration factors in log scale).
    scale_ : float
        Estimated scale parameter (sigma).
    n_features_in_ : int
        Number of features seen during fit.

    Examples
    --------
    >>> from survival.sklearn_compat import AFTEstimator
    >>> import numpy as np
    >>> X = np.random.randn(100, 3)
    >>> y = np.column_stack([np.random.exponential(10, 100), np.random.binomial(1, 0.7, 100)])
    >>> model = AFTEstimator(distribution="weibull")
    >>> model.fit(X, y)
    >>> predicted_times = model.predict(X)

    Notes
    -----
    The AFT model interprets coefficients as acceleration factors:
    - Positive coefficients increase expected survival time
    - Negative coefficients decrease expected survival time
    - exp(coef) gives the multiplicative effect on survival time
    """

    def __init__(
        self,
        distribution: str = "weibull",
        max_iter: int = 200,
        tol: float = 1e-9,
    ):
        self.distribution = distribution
        self.max_iter = max_iter
        self.tol = tol

    def fit(self, X: ArrayLike, y: ArrayLike) -> "AFTEstimator":
        """Fit the AFT model using maximum likelihood estimation.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data.
        y : array-like of shape (n_samples, 2)
            Target values where y[:, 0] is survival time and y[:, 1] is event status.

        Returns
        -------
        self : AFTEstimator
            Fitted estimator.
        """
        X, time, status = _validate_survival_data(X, y)
        self.n_features_in_ = X.shape[1]
        n = len(time)

        events = status == 1
        n_events = events.sum()

        if n_events < X.shape[1] + 1:
            raise ValueError(
                f"Not enough events ({n_events}) to fit model with {X.shape[1]} features"
            )

        X_with_intercept = np.column_stack([np.ones(n), X])

        self.model_ = _surv.survreg(
            time=time.tolist(),
            status=status.tolist(),
            covariates=X_with_intercept.tolist(),
            distribution=self.distribution,
            max_iter=self.max_iter,
            eps=self.tol,
        )

        coefs = np.array(self.model_.coefficients)
        self.intercept_ = coefs[0]
        self.coef_ = coefs[1:-1]
        self.scale_ = np.exp(coefs[-1])
        self.converged_ = self.model_.convergence_flag == 0

        self.is_fitted_ = True
        return self

    def predict(self, X: ArrayLike) -> NDArray[np.float64]:
        """Predict expected survival time for samples.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Samples to predict.

        Returns
        -------
        survival_times : ndarray of shape (n_samples,)
            Predicted survival times (exponential of the linear predictor).
        """
        check_is_fitted(self)
        X = check_array(X, dtype=np.float64, ensure_2d=True)

        if X.shape[1] != self.n_features_in_:
            raise ValueError(
                f"X has {X.shape[1]} features, but model expects {self.n_features_in_}"
            )

        linear_pred = self.intercept_ + X @ self.coef_
        return np.exp(linear_pred)

    def predict_median(self, X: ArrayLike) -> NDArray[np.float64]:
        """Predict median survival time for samples.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Samples to predict.

        Returns
        -------
        median_times : ndarray of shape (n_samples,)
            Predicted median survival times.
        """
        check_is_fitted(self)
        X = check_array(X, dtype=np.float64, ensure_2d=True)

        linear_pred = self.intercept_ + X @ self.coef_

        if self.distribution in ("weibull", "exponential", "extreme_value"):
            median_z = np.log(np.log(2))
        elif self.distribution in ("lognormal", "gaussian", "loglogistic", "logistic"):
            median_z = 0.0
        else:
            median_z = 0.0

        return np.exp(linear_pred + self.scale_ * median_z)

    def predict_quantile(self, X: ArrayLike, q: float = 0.5) -> NDArray[np.float64]:
        """Predict survival time quantile for samples.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Samples to predict.
        q : float, default=0.5
            Quantile to predict (0 < q < 1). Default is median (0.5).

        Returns
        -------
        quantile_times : ndarray of shape (n_samples,)
            Predicted survival times at the given quantile.
        """
        check_is_fitted(self)
        X = check_array(X, dtype=np.float64, ensure_2d=True)

        if not 0 < q < 1:
            raise ValueError("q must be between 0 and 1")

        linear_pred = self.intercept_ + X @ self.coef_

        if self.distribution in ("weibull", "exponential", "extreme_value"):
            z_q = np.log(-np.log(1 - q))
        elif self.distribution in ("lognormal", "gaussian"):
            from scipy.stats import norm

            z_q = norm.ppf(q)
        elif self.distribution in ("loglogistic", "logistic"):
            z_q = np.log(q / (1 - q))
        else:
            z_q = 0.0

        return np.exp(linear_pred + self.scale_ * z_q)

    def score(self, X: ArrayLike, y: ArrayLike) -> float:
        """Return the concordance index on the given test data.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Test samples.
        y : array-like of shape (n_samples, 2)
            True target values.

        Returns
        -------
        score : float
            Concordance index (C-index), between 0 and 1.
        """
        check_is_fitted(self)
        X, time, status = _validate_survival_data(X, y)
        predicted_times = self.predict(X)
        return _compute_concordance_index(time, status, -predicted_times)

    @property
    def acceleration_factors(self) -> NDArray[np.float64]:
        """Return acceleration factors (exp of coefficients).

        Returns
        -------
        af : ndarray of shape (n_features,)
            Acceleration factors. Values > 1 increase survival time,
            values < 1 decrease survival time.
        """
        check_is_fitted(self)
        return np.exp(self.coef_)
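
# Quantile arithmetic under the Weibull error model (a sketch; data are
# illustrative): predict_quantile computes
#     t_q = exp(intercept_ + X @ coef_ + scale_ * log(-log(1 - q)))
# so for q = 0.5 the shift term is scale_ * log(log(2)) ~= -0.3665 * scale_,
# which is exactly what predict_median uses.
#
#     aft = AFTEstimator(distribution="weibull").fit(X, y)
#     t50 = aft.predict_quantile(X, q=0.5)  # equals aft.predict_median(X)
#     t90 = aft.predict_quantile(X, q=0.9)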


def iter_chunks(X: ArrayLike, batch_size: int = 1000) -> Iterator[tuple[int, NDArray[np.float64]]]:
    """Iterate over an array in chunks.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Input array.
    batch_size : int, default=1000
        Number of samples per chunk.

    Yields
    ------
    start_idx : int
        Starting index of the chunk.
    chunk : ndarray
        Chunk of the input array.

    Examples
    --------
    >>> import numpy as np
    >>> X = np.random.randn(10000, 5)
    >>> for start_idx, chunk in iter_chunks(X, batch_size=1000):
    ...     print(f"Processing samples {start_idx} to {start_idx + len(chunk)}")
    """
    X = np.asarray(X)
    n_samples = X.shape[0]
    for start_idx in range(0, n_samples, batch_size):
        end_idx = min(start_idx + batch_size, n_samples)
        yield start_idx, X[start_idx:end_idx]


class StreamingMixin:
    """Mixin class providing streaming/batched prediction methods."""

    def predict_batched(
        self, X: ArrayLike, batch_size: int = 1000
    ) -> Iterator[NDArray[np.float64]]:
        """Predict risk scores in batches to handle large datasets.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Samples to predict.
        batch_size : int, default=1000
            Number of samples per batch.

        Yields
        ------
        risk_scores : ndarray of shape (batch_size,) or smaller for last batch
            Predicted risk scores for each batch.

        Examples
        --------
        >>> model = GradientBoostSurvivalEstimator()
        >>> model.fit(X_train, y_train)
        >>> all_predictions = []
        >>> for batch_preds in model.predict_batched(X_large, batch_size=5000):
        ...     all_predictions.append(batch_preds)
        >>> predictions = np.concatenate(all_predictions)
        """
        for _, chunk in iter_chunks(X, batch_size):
            yield self.predict(chunk)

    def predict_survival_batched(
        self, X: ArrayLike, batch_size: int = 1000
    ) -> Iterator[tuple[NDArray[np.float64], NDArray[np.float64]]]:
        """Predict survival functions in batches.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Samples to predict.
        batch_size : int, default=1000
            Number of samples per batch.

        Yields
        ------
        times : ndarray of shape (n_times,)
            Time points (same for all batches).
        survival : ndarray of shape (batch_size, n_times)
            Survival probabilities for each batch.
        """
        for _, chunk in iter_chunks(X, batch_size):
            yield self.predict_survival_function(chunk)

    def predict_to_array(
        self, X: ArrayLike, batch_size: int = 1000, out: NDArray | None = None
    ) -> NDArray[np.float64]:
        """Predict risk scores with optional pre-allocated output array.

        This method is memory-efficient for large datasets as it can write
        directly to a pre-allocated array or memory-mapped file.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Samples to predict.
        batch_size : int, default=1000
            Number of samples per batch.
        out : ndarray of shape (n_samples,), optional
            Pre-allocated output array. If None, a new array is created.

        Returns
        -------
        risk_scores : ndarray of shape (n_samples,)
            Predicted risk scores.

        Examples
        --------
        >>> # Using with memory-mapped array for very large datasets
        >>> import numpy as np
        >>> out = np.memmap('predictions.dat', dtype='float64', mode='w+', shape=(1000000,))
        >>> model.predict_to_array(X_large, batch_size=10000, out=out)
        >>> out.flush()  # Write to disk
        """
        X = np.asarray(X)
        n_samples = X.shape[0]

        if out is None:
            out = np.empty(n_samples, dtype=np.float64)
        elif out.shape[0] != n_samples:
            raise ValueError(f"out has shape {out.shape}, expected ({n_samples},)")

        for start_idx, chunk in iter_chunks(X, batch_size):
            end_idx = start_idx + chunk.shape[0]
            out[start_idx:end_idx] = self.predict(chunk)

        return out


class StreamingCoxPHEstimator(CoxPHEstimator, StreamingMixin):
    """Cox PH Estimator with streaming/batched prediction support.

    This class extends CoxPHEstimator with methods for processing large
    datasets that don't fit in memory.

    See CoxPHEstimator for full documentation.
    """

    pass


class StreamingGradientBoostSurvivalEstimator(GradientBoostSurvivalEstimator, StreamingMixin):
    """Gradient Boosting Survival Estimator with streaming support.

    This class extends GradientBoostSurvivalEstimator with methods for
    processing large datasets that don't fit in memory.

    See GradientBoostSurvivalEstimator for full documentation.
    """

    pass


class StreamingSurvivalForestEstimator(SurvivalForestEstimator, StreamingMixin):
    """Survival Forest Estimator with streaming support.

    This class extends SurvivalForestEstimator with methods for processing
    large datasets that don't fit in memory.

    See SurvivalForestEstimator for full documentation.
    """

    pass


class StreamingAFTEstimator(AFTEstimator, StreamingMixin):
    """AFT Estimator with streaming/batched prediction support.

    This class extends AFTEstimator with methods for processing large
    datasets that don't fit in memory.

    See AFTEstimator for full documentation.
    """

    pass
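
# Hypothetical batched-prediction flow using one of the streaming variants:
#
#     model = StreamingCoxPHEstimator().fit(X_train, y_train)
#     preds = model.predict_to_array(X_large, batch_size=10_000)
#     # or, fully incremental:
#     for batch in model.predict_batched(X_large, batch_size=10_000):
#         ...  # consume each batch without materializing all predictions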


class DeepSurvEstimator(BaseEstimator, RegressorMixin):
    """Scikit-learn compatible DeepSurv model.

    DeepSurv is a deep feedforward neural network for survival analysis
    using Cox partial likelihood loss.

    Parameters
    ----------
    hidden_layers : list of int or None, default=None
        Number of neurons in each hidden layer. None means [64, 32].
    activation : str, default="selu"
        Activation function. One of "relu", "selu", or "tanh".
    dropout_rate : float, default=0.2
        Dropout rate applied after each hidden layer.
    learning_rate : float, default=0.001
        Learning rate for the Adam optimizer.
    batch_size : int, default=256
        Mini-batch size for training.
    n_epochs : int, default=100
        Number of training epochs.
    l2_reg : float, default=0.0001
        L2 regularization (weight decay) coefficient.
    seed : int or None, default=None
        Random seed for reproducibility.
    early_stopping_patience : int or None, default=10
        Number of epochs without improvement before early stopping.
        Set to None to disable early stopping.
    validation_fraction : float, default=0.1
        Fraction of training data to use for validation.

    Attributes
    ----------
    model_ : DeepSurv
        The underlying fitted model.
    n_features_in_ : int
        Number of features seen during fit.
    """

    def __init__(
        self,
        hidden_layers: list[int] | None = None,
        activation: str = "selu",
        dropout_rate: float = 0.2,
        learning_rate: float = 0.001,
        batch_size: int = 256,
        n_epochs: int = 100,
        l2_reg: float = 0.0001,
        seed: int | None = None,
        early_stopping_patience: int | None = 10,
        validation_fraction: float = 0.1,
    ):
        # Store the parameter untouched (sklearn convention); the [64, 32]
        # default is resolved in fit so get_params round-trips cleanly.
        self.hidden_layers = hidden_layers
        self.activation = activation
        self.dropout_rate = dropout_rate
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.n_epochs = n_epochs
        self.l2_reg = l2_reg
        self.seed = seed
        self.early_stopping_patience = early_stopping_patience
        self.validation_fraction = validation_fraction

    def fit(self, X: ArrayLike, y: ArrayLike) -> "DeepSurvEstimator":
        """Fit the DeepSurv model.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data.
        y : array-like of shape (n_samples, 2)
            Target values where y[:, 0] is survival time and y[:, 1] is event status.

        Returns
        -------
        self : DeepSurvEstimator
            Fitted estimator.
        """
        X, time, status = _validate_survival_data(X, y)
        self.n_features_in_ = X.shape[1]
        n_obs = X.shape[0]

        hidden_layers = self.hidden_layers if self.hidden_layers is not None else [64, 32]
        activation = _surv.Activation(self.activation)
        config = _surv.DeepSurvConfig(
            hidden_layers=hidden_layers,
            activation=activation,
            dropout_rate=self.dropout_rate,
            learning_rate=self.learning_rate,
            batch_size=self.batch_size,
            n_epochs=self.n_epochs,
            l2_reg=self.l2_reg,
            seed=self.seed,
            early_stopping_patience=self.early_stopping_patience,
            validation_fraction=self.validation_fraction,
        )

        x_flat = X.flatten().tolist()
        self.model_ = _surv.DeepSurv.fit(
            x_flat, n_obs, self.n_features_in_, time.tolist(), status.tolist(), config
        )

        self.is_fitted_ = True
        return self

    def predict(self, X: ArrayLike) -> NDArray[np.float64]:
        """Predict risk scores for samples.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Samples to predict.

        Returns
        -------
        risk_scores : ndarray of shape (n_samples,)
            Predicted risk scores (higher = higher risk).
        """
        check_is_fitted(self)
        X = check_array(X, dtype=np.float64, ensure_2d=True)

        if X.shape[1] != self.n_features_in_:
            raise ValueError(
                f"X has {X.shape[1]} features, but model expects {self.n_features_in_}"
            )

        x_flat = X.flatten().tolist()
        return np.array(self.model_.predict_risk(x_flat, X.shape[0]))

    def predict_survival_function(
        self, X: ArrayLike
    ) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
        """Predict survival function for samples.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Samples to predict.

        Returns
        -------
        times : ndarray of shape (n_times,)
            Time points.
        survival : ndarray of shape (n_samples, n_times)
            Survival probabilities.
        """
        check_is_fitted(self)
        X = check_array(X, dtype=np.float64, ensure_2d=True)

        x_flat = X.flatten().tolist()
        survival = self.model_.predict_survival(x_flat, X.shape[0])
        return np.array(self.model_.unique_times), np.array(survival)

    def predict_median_survival_time(self, X: ArrayLike) -> NDArray[np.float64]:
        """Predict median survival time for samples.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Samples to predict.

        Returns
        -------
        median_times : ndarray of shape (n_samples,)
            Predicted median survival times (NaN if survival never drops below 0.5).
        """
        check_is_fitted(self)
        X = check_array(X, dtype=np.float64, ensure_2d=True)

        x_flat = X.flatten().tolist()
        result = self.model_.predict_median_survival_time(x_flat, X.shape[0])
        return np.array([t if t is not None else np.nan for t in result])

    def score(self, X: ArrayLike, y: ArrayLike) -> float:
        """Return the concordance index on the given test data.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Test samples.
        y : array-like of shape (n_samples, 2)
            True target values.

        Returns
        -------
        score : float
            Concordance index (C-index), between 0 and 1.
        """
        check_is_fitted(self)
        X, time, status = _validate_survival_data(X, y)
        risk_scores = self.predict(X)
        return _compute_concordance_index(time, status, risk_scores)

    @property
    def train_loss(self) -> NDArray[np.float64]:
        """Training loss history."""
        check_is_fitted(self)
        return np.array(self.model_.train_loss)

    @property
    def val_loss(self) -> NDArray[np.float64]:
        """Validation loss history."""
        check_is_fitted(self)
        return np.array(self.model_.val_loss)
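
# Hypothetical training-diagnostics sketch (assumes a non-empty validation
# split so val_loss has entries): the loss histories exposed above can be
# used to inspect convergence and early-stopping behavior.
#
#     net = DeepSurvEstimator(n_epochs=200, seed=1).fit(X, y)
#     best_epoch = int(np.argmin(net.val_loss))
#     print(f"best validation loss {net.val_loss[best_epoch]:.4f} at epoch {best_epoch}")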


class StreamingDeepSurvEstimator(DeepSurvEstimator, StreamingMixin):
    """DeepSurv Estimator with streaming/batched prediction support.

    This class extends DeepSurvEstimator with methods for processing large
    datasets that don't fit in memory.

    See DeepSurvEstimator for full documentation.
    """

    pass


def predict_large_dataset(
    estimator,
    X: ArrayLike,
    batch_size: int = 1000,
    output_file: str | None = None,
    verbose: bool = False,
) -> NDArray[np.float64]:
    """Predict on a large dataset using batched processing.

    This is a utility function for making predictions on datasets that may
    not fit in memory. It processes the data in batches and optionally
    writes results to a memory-mapped file.

    Parameters
    ----------
    estimator : fitted estimator
        A fitted survival estimator with a predict method.
    X : array-like of shape (n_samples, n_features)
        Samples to predict. Can be a numpy array or memory-mapped array.
    batch_size : int, default=1000
        Number of samples to process at once.
    output_file : str, optional
        Path to output file for memory-mapped results. If provided, results
        are written to this file and can exceed available RAM.
    verbose : bool, default=False
        If True, print progress information.

    Returns
    -------
    predictions : ndarray of shape (n_samples,)
        Predicted risk scores. If output_file is provided, this is a
        memory-mapped array.

    Examples
    --------
    >>> # Process a very large dataset
    >>> predictions = predict_large_dataset(
    ...     model, X_huge, batch_size=10000,
    ...     output_file='predictions.mmap', verbose=True
    ... )
    """
    X = np.asarray(X)
    n_samples = X.shape[0]

    if output_file is not None:
        predictions = np.memmap(output_file, dtype=np.float64, mode="w+", shape=(n_samples,))
    else:
        predictions = np.empty(n_samples, dtype=np.float64)

    n_batches = (n_samples + batch_size - 1) // batch_size

    for batch_idx, (start_idx, chunk) in enumerate(iter_chunks(X, batch_size)):
        end_idx = start_idx + chunk.shape[0]
        predictions[start_idx:end_idx] = estimator.predict(chunk)

        if verbose:
            print(f"Processed batch {batch_idx + 1}/{n_batches} (samples {start_idx}-{end_idx})")

    if output_file is not None:
        predictions.flush()

    return predictions


def survival_curves_to_disk(
    estimator,
    X: ArrayLike,
    output_file: str,
    batch_size: int = 100,
    verbose: bool = False,
) -> tuple[NDArray[np.float64], np.memmap]:
    """Compute survival curves and write to disk for large datasets.

    This function computes survival curves in batches and stores them in
    a memory-mapped file, allowing processing of datasets larger than RAM.

    Parameters
    ----------
    estimator : fitted estimator
        A fitted survival estimator with predict_survival_function method.
    X : array-like of shape (n_samples, n_features)
        Samples to predict.
    output_file : str
        Path to output file for memory-mapped survival curves.
    batch_size : int, default=100
        Number of samples to process at once. Smaller values use less
        memory but are slower.
    verbose : bool, default=False
        If True, print progress information.

    Returns
    -------
    times : ndarray of shape (n_times,)
        Time points for the survival curves.
    survival : memmap of shape (n_samples, n_times)
        Memory-mapped array of survival probabilities.

    Examples
    --------
    >>> times, survival_curves = survival_curves_to_disk(
    ...     model, X_huge, 'survival_curves.mmap',
    ...     batch_size=100, verbose=True
    ... )
    >>> # Access individual survival curves without loading all into memory
    >>> curve_0 = survival_curves[0]  # Loads only first curve
    """
    X = np.asarray(X)
    n_samples = X.shape[0]

    # Probe one sample to learn the time grid before allocating the memmap.
    first_times, _ = estimator.predict_survival_function(X[:1])
    n_times = len(first_times)
    times = first_times

    survival = np.memmap(output_file, dtype=np.float64, mode="w+", shape=(n_samples, n_times))

    n_batches = (n_samples + batch_size - 1) // batch_size

    for batch_idx, (start_idx, chunk) in enumerate(iter_chunks(X, batch_size)):
        end_idx = start_idx + chunk.shape[0]
        _, batch_surv = estimator.predict_survival_function(chunk)
        survival[start_idx:end_idx] = batch_surv

        if verbose:
            print(f"Processed batch {batch_idx + 1}/{n_batches} (samples {start_idx}-{end_idx})")

    survival.flush()
    return times, survival
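
# End-to-end sketch for larger-than-memory scoring (paths and shapes are
# illustrative, not part of the API):
#
#     model = StreamingSurvivalForestEstimator(n_trees=300, seed=7).fit(X, y)
#     risks = predict_large_dataset(model, X_huge, batch_size=10_000,
#                                   output_file="risks.mmap", verbose=True)
#     times, curves = survival_curves_to_disk(model, X_huge, "curves.mmap")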