scikit-survival 0.25.0__cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58) hide show
  1. scikit_survival-0.25.0.dist-info/METADATA +185 -0
  2. scikit_survival-0.25.0.dist-info/RECORD +58 -0
  3. scikit_survival-0.25.0.dist-info/WHEEL +6 -0
  4. scikit_survival-0.25.0.dist-info/licenses/COPYING +674 -0
  5. scikit_survival-0.25.0.dist-info/top_level.txt +1 -0
  6. sksurv/__init__.py +183 -0
  7. sksurv/base.py +115 -0
  8. sksurv/bintrees/__init__.py +15 -0
  9. sksurv/bintrees/_binarytrees.cpython-312-x86_64-linux-gnu.so +0 -0
  10. sksurv/column.py +205 -0
  11. sksurv/compare.py +123 -0
  12. sksurv/datasets/__init__.py +12 -0
  13. sksurv/datasets/base.py +614 -0
  14. sksurv/datasets/data/GBSG2.arff +700 -0
  15. sksurv/datasets/data/actg320.arff +1169 -0
  16. sksurv/datasets/data/bmt.arff +46 -0
  17. sksurv/datasets/data/breast_cancer_GSE7390-metastasis.arff +283 -0
  18. sksurv/datasets/data/cgvhd.arff +118 -0
  19. sksurv/datasets/data/flchain.arff +7887 -0
  20. sksurv/datasets/data/veteran.arff +148 -0
  21. sksurv/datasets/data/whas500.arff +520 -0
  22. sksurv/docstrings.py +99 -0
  23. sksurv/ensemble/__init__.py +2 -0
  24. sksurv/ensemble/_coxph_loss.cpython-312-x86_64-linux-gnu.so +0 -0
  25. sksurv/ensemble/boosting.py +1564 -0
  26. sksurv/ensemble/forest.py +902 -0
  27. sksurv/ensemble/survival_loss.py +151 -0
  28. sksurv/exceptions.py +18 -0
  29. sksurv/functions.py +114 -0
  30. sksurv/io/__init__.py +2 -0
  31. sksurv/io/arffread.py +89 -0
  32. sksurv/io/arffwrite.py +181 -0
  33. sksurv/kernels/__init__.py +1 -0
  34. sksurv/kernels/_clinical_kernel.cpython-312-x86_64-linux-gnu.so +0 -0
  35. sksurv/kernels/clinical.py +348 -0
  36. sksurv/linear_model/__init__.py +3 -0
  37. sksurv/linear_model/_coxnet.cpython-312-x86_64-linux-gnu.so +0 -0
  38. sksurv/linear_model/aft.py +208 -0
  39. sksurv/linear_model/coxnet.py +592 -0
  40. sksurv/linear_model/coxph.py +637 -0
  41. sksurv/meta/__init__.py +4 -0
  42. sksurv/meta/base.py +35 -0
  43. sksurv/meta/ensemble_selection.py +724 -0
  44. sksurv/meta/stacking.py +370 -0
  45. sksurv/metrics.py +1028 -0
  46. sksurv/nonparametric.py +911 -0
  47. sksurv/preprocessing.py +183 -0
  48. sksurv/svm/__init__.py +11 -0
  49. sksurv/svm/_minlip.cpython-312-x86_64-linux-gnu.so +0 -0
  50. sksurv/svm/_prsvm.cpython-312-x86_64-linux-gnu.so +0 -0
  51. sksurv/svm/minlip.py +690 -0
  52. sksurv/svm/naive_survival_svm.py +249 -0
  53. sksurv/svm/survival_svm.py +1236 -0
  54. sksurv/testing.py +108 -0
  55. sksurv/tree/__init__.py +1 -0
  56. sksurv/tree/_criterion.cpython-312-x86_64-linux-gnu.so +0 -0
  57. sksurv/tree/tree.py +790 -0
  58. sksurv/util.py +415 -0
@@ -0,0 +1,592 @@
1
+ # This program is free software: you can redistribute it and/or modify
2
+ # it under the terms of the GNU General Public License as published by
3
+ # the Free Software Foundation, either version 3 of the License, or
4
+ # (at your option) any later version.
5
+ #
6
+ # This program is distributed in the hope that it will be useful,
7
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
8
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
9
+ # GNU General Public License for more details.
10
+ #
11
+ # You should have received a copy of the GNU General Public License
12
+ # along with this program. If not, see <http://www.gnu.org/licenses/>.
13
+ import numbers
14
+ import warnings
15
+
16
+ import numpy as np
17
+ from sklearn.base import BaseEstimator
18
+ from sklearn.exceptions import ConvergenceWarning
19
+ from sklearn.preprocessing import normalize as f_normalize
20
+ from sklearn.utils._param_validation import Interval, StrOptions
21
+ from sklearn.utils.validation import (
22
+ assert_all_finite,
23
+ check_is_fitted,
24
+ check_non_negative,
25
+ column_or_1d,
26
+ validate_data,
27
+ )
28
+
29
+ from ..base import SurvivalAnalysisMixin
30
+ from ..util import check_array_survival
31
+ from ._coxnet import call_fit_coxnet
32
+ from .coxph import BreslowEstimator
33
+
34
+ __all__ = ["CoxnetSurvivalAnalysis"]
35
+
36
+
37
class CoxnetSurvivalAnalysis(BaseEstimator, SurvivalAnalysisMixin):
    r"""Cox's proportional hazard's model with elastic net penalty.

    See the :ref:`User Guide </user_guide/coxnet.ipynb>` and [1]_ for further description.

    Parameters
    ----------
    n_alphas : int, optional, default: 100
        Number of alphas along the regularization path.

    alphas : array-like or None, optional
        List of alphas where to compute the models.
        If ``None``, alphas are set automatically.

        In this case, the ``alphas`` sequence is determined by :math:`\alpha_\max`
        and ``alpha_min_ratio``. The latter determines the smallest alpha value
        :math:`\alpha_\min` in the generated alphas sequence such that
        ``alpha_min_ratio`` equals the ratio :math:`\frac{\alpha_\min}{\alpha_\max}`.
        The generated ``alphas`` sequence contains ``n_alphas`` values linear
        on the log scale from :math:`\alpha_\max` down to :math:`\alpha_\min`.
        :math:`\alpha_\max` is not user-specified but is computed from the
        input data.

    alpha_min_ratio : float or { "auto" }, optional, default: "auto"
        Determines the minimum alpha of the regularization path
        if ``alphas`` is ``None``. The smallest value for alpha
        is computed as the fraction of the maximum
        alpha (i.e. the smallest value for which all
        coefficients are zero), which is derived from the input data.

        If set to "auto", the value will depend on the
        sample size relative to the number of features:

        - If ``n_samples > n_features``, the default value is 0.0001.
        - If ``n_samples <= n_features``, the default value is 0.01.

    l1_ratio : float, optional, default: 0.5
        The ElasticNet mixing parameter, with ``0 < l1_ratio <= 1``.
        For ``l1_ratio = 0`` the penalty is an L2 penalty.
        For ``l1_ratio = 1`` it is an L1 penalty.
        For ``0 < l1_ratio < 1``, the penalty is a combination of L1 and L2.

    penalty_factor : array-like or None, optional
        Separate penalty factors can be applied to each coefficient.
        This is a number that multiplies alpha to allow differential
        shrinkage. Can be 0 for some variables, which implies no shrinkage,
        and that variable is always included in the model.
        Default is 1 for all variables.

        Note: the penalty factors are internally rescaled to sum to
        `n_features`, and the alphas sequence will reflect this change.

    normalize : bool, optional, default: False
        If True, the features X will be normalized before optimization by
        subtracting the mean and dividing by the l2-norm.
        If you wish to standardize, please use
        :class:`sklearn.preprocessing.StandardScaler` before calling ``fit``
        on an estimator with ``normalize=False``.

    copy_X : boolean, optional, default: True
        If ``True``, X will be copied; else, it may be overwritten.

    tol : float, optional, default: 1e-7
        The tolerance for the optimization: optimization continues
        until all updates are smaller than ``tol``.

    max_iter : int, optional, default: 100000
        The maximum number of iterations taken for the solver to converge.

    verbose : bool, optional, default: False
        Whether to print additional information during optimization.

    fit_baseline_model : bool, optional, default: False
        Whether to estimate baseline survival function
        and baseline cumulative hazard function for each alpha.
        If enabled, :meth:`predict_cumulative_hazard_function` and
        :meth:`predict_survival_function` can be used to obtain
        predicted cumulative hazard function and survival function.

    Attributes
    ----------
    alphas_ : ndarray, shape=(n_alphas,)
        The actual sequence of alpha values used.

    alpha_min_ratio_ : float
        The inferred value of alpha_min_ratio.

    penalty_factor_ : ndarray, shape=(n_features,)
        The actual penalty factors used.

    coef_ : ndarray, shape=(n_features, n_alphas)
        Matrix of coefficients.

    offset_ : ndarray, shape=(n_alphas,)
        Bias term to account for non-centered features.

    deviance_ratio_ : ndarray, shape=(n_alphas,)
        The fraction of (null) deviance explained.
        The deviance is defined as :math:`2 \cdot (\text{loglike_sat} - \text{loglike})`,
        where `loglike_sat` is the log-likelihood for the saturated model
        (a model with a free parameter per observation). Null deviance is defined as
        :math:`2 \cdot (\text{loglike_sat} - \text{loglike(Null)})`;
        The NULL model is the model with all zero coefficients.
        Hence, ``deviance_ratio_`` is :math:`1 - \frac{\text{deviance}}{\text{null_deviance}}`.

    n_features_in_ : int
        Number of features seen during ``fit``.

    feature_names_in_ : ndarray, shape = (`n_features_in_`,)
        Names of features seen during ``fit``. Defined only when `X`
        has feature names that are all strings.

    unique_times_ : ndarray, shape = (n_unique_times,)
        Unique time points.

    References
    ----------
    .. [1] Simon N, Friedman J, Hastie T, Tibshirani R.
        Regularization paths for Cox’s proportional hazards model via coordinate descent.
        Journal of statistical software. 2011 Mar;39(5):1.
    """

    # Declarative hyper-parameter constraints consumed by scikit-learn's
    # `_validate_params` (called from `_check_params` during `fit`).
    _parameter_constraints: dict = {
        "n_alphas": [Interval(numbers.Integral, 1, None, closed="left")],
        "alphas": ["array-like", None],
        "alpha_min_ratio": [Interval(numbers.Real, 0, None, closed="neither"), StrOptions({"auto"})],
        "l1_ratio": [Interval(numbers.Real, 0.0, 1.0, closed="right")],
        "penalty_factor": ["array-like", None],
        "normalize": ["boolean"],
        "copy_X": ["boolean"],
        "tol": [Interval(numbers.Real, 0, None, closed="left")],
        "max_iter": [Interval(numbers.Integral, 1, None, closed="left")],
        "verbose": ["verbose"],
        "fit_baseline_model": ["boolean"],
    }
172
+
173
    def __init__(
        self,
        *,
        n_alphas=100,
        alphas=None,
        alpha_min_ratio="auto",
        l1_ratio=0.5,
        penalty_factor=None,
        normalize=False,
        copy_X=True,
        tol=1e-7,
        max_iter=100000,
        verbose=False,
        fit_baseline_model=False,
    ):
        # Per scikit-learn convention, __init__ stores hyper-parameters
        # verbatim; validation is deferred to `fit` (see `_check_params`).
        self.n_alphas = n_alphas
        self.alphas = alphas
        self.alpha_min_ratio = alpha_min_ratio
        self.l1_ratio = l1_ratio
        self.penalty_factor = penalty_factor
        self.normalize = normalize
        self.copy_X = copy_X
        self.tol = tol
        self.max_iter = max_iter
        self.verbose = verbose
        self.fit_baseline_model = fit_baseline_model

        # Baseline Breslow estimators per alpha; populated in `fit` only when
        # `fit_baseline_model=True`.
        self._baseline_models = None
201
+
202
+ def _pre_fit(self, X, y):
203
+ X = validate_data(self, X, ensure_min_samples=2, dtype=np.float64, copy=self.copy_X)
204
+ event, time = check_array_survival(X, y)
205
+ # center feature matrix
206
+ X_offset = np.average(X, axis=0)
207
+ X -= X_offset
208
+ if self.normalize:
209
+ X, X_scale = f_normalize(X, copy=False, axis=0, return_norm=True)
210
+ else:
211
+ X_scale = np.ones(X.shape[1], dtype=X.dtype)
212
+
213
+ # sort descending
214
+ o = np.argsort(-time, kind="mergesort")
215
+ X = np.asfortranarray(X[o, :])
216
+ event_num = event[o].astype(np.uint8)
217
+ time = time[o].astype(np.float64)
218
+ return X, event_num, time, X_offset, X_scale
219
+
220
+ def _check_penalty_factor(self, n_features):
221
+ if self.penalty_factor is None:
222
+ penalty_factor = np.ones(n_features, dtype=np.float64)
223
+ else:
224
+ pf = column_or_1d(self.penalty_factor, warn=True)
225
+ if pf.shape[0] != n_features:
226
+ raise ValueError(
227
+ f"penalty_factor must be array of length n_features ({n_features}), but got {pf.shape[0]}"
228
+ )
229
+ assert_all_finite(pf, input_name="penalty_factor")
230
+ check_non_negative(pf, "penalty_factor")
231
+ penalty_factor = pf * n_features / pf.sum()
232
+ return penalty_factor
233
+
234
+ def _check_alphas(self):
235
+ create_path = self.alphas is None
236
+ if create_path:
237
+ alphas = np.empty(int(self.n_alphas), dtype=np.float64)
238
+ else:
239
+ alphas = column_or_1d(self.alphas, warn=True)
240
+ assert_all_finite(alphas, input_name="alphas")
241
+ check_non_negative(alphas, "alphas")
242
+ return alphas, create_path
243
+
244
+ def _check_alpha_min_ratio(self, n_samples, n_features):
245
+ alpha_min_ratio = self.alpha_min_ratio
246
+ if isinstance(alpha_min_ratio, str) and self.alpha_min_ratio == "auto":
247
+ if n_samples > n_features:
248
+ alpha_min_ratio = 0.0001
249
+ else:
250
+ alpha_min_ratio = 0.01
251
+
252
+ return alpha_min_ratio
253
+
254
+ def _check_params(self, n_samples, n_features):
255
+ self._validate_params()
256
+
257
+ penalty_factor = self._check_penalty_factor(n_features)
258
+
259
+ alphas, create_path = self._check_alphas()
260
+
261
+ alpha_min_ratio = self._check_alpha_min_ratio(n_samples, n_features)
262
+
263
+ return create_path, alphas.astype(np.float64), penalty_factor.astype(np.float64), alpha_min_ratio
264
+
265
    def fit(self, X, y):
        """Fit estimator.

        Parameters
        ----------
        X : array-like, shape = (n_samples, n_features)
            Data matrix

        y : structured array, shape = (n_samples,)
            A structured array with two fields. The first field is a boolean
            where ``True`` indicates an event and ``False`` indicates right-censoring.
            The second field is a float with the time of event or time of censoring.

        Returns
        -------
        self
        """
        # Center (and optionally normalize) features and sort by descending time.
        X, event_num, time, X_offset, X_scale = self._pre_fit(X, y)
        create_path, alphas, penalty, alpha_min_ratio = self._check_params(*X.shape)

        # Solve the elastic-net penalized Cox model along the whole alpha path
        # in the compiled extension.
        coef, alphas, deviance_ratio, n_iter = call_fit_coxnet(
            X,
            time,
            event_num,
            penalty,
            alphas,
            create_path,
            alpha_min_ratio,
            self.l1_ratio,
            int(self.max_iter),
            self.tol,
            self.verbose,
        )
        assert np.isfinite(coef).all()

        if np.all(np.absolute(coef) < np.finfo(float).eps):
            warnings.warn("all coefficients are zero, consider decreasing alpha.", stacklevel=2)

        if n_iter >= self.max_iter:
            warnings.warn(
                "Optimization terminated early, you might want"
                f" to increase the number of iterations (max_iter={self.max_iter}).",
                category=ConvergenceWarning,
                stacklevel=2,
            )

        # Undo the feature scaling so coef_ refers to the original feature scale.
        coef /= X_scale[:, np.newaxis]

        if self.fit_baseline_model:
            # NOTE(review): these linear predictors use the (possibly
            # normalized) X together with the rescaled coefficients; for
            # normalize=True they are not on the same scale as the scores
            # `predict` produces on original-scale data — confirm intended.
            predictions = np.dot(X, coef)
            self._baseline_models = tuple(
                BreslowEstimator().fit(predictions[:, i], event_num, time) for i in range(coef.shape[1])
            )
        else:
            self._baseline_models = None

        self.alphas_ = alphas
        self.alpha_min_ratio_ = alpha_min_ratio
        self.penalty_factor_ = penalty
        self.coef_ = coef
        self.deviance_ratio_ = deviance_ratio
        # offset_ compensates for the column-mean centering done in _pre_fit
        self.offset_ = np.dot(X_offset, coef)
        return self
328
+
329
+ def _get_coef(self, alpha):
330
+ check_is_fitted(self, "coef_")
331
+
332
+ if alpha is None:
333
+ coef = self.coef_[:, -1], self.offset_[-1]
334
+ else:
335
+ coef = self._interpolate_coefficients(alpha)
336
+ return coef
337
+
338
+ def _interpolate_coefficients(self, alpha):
339
+ """Interpolate coefficients by calculating the weighted average of coefficient vectors corresponding to
340
+ neighbors of alpha in the list of alphas constructed during training."""
341
+ exact = False
342
+ coef_idx = None
343
+ for i, val in enumerate(self.alphas_):
344
+ if val > alpha:
345
+ coef_idx = i
346
+ elif alpha - val < np.finfo(float).eps:
347
+ coef_idx = i
348
+ exact = True
349
+ break
350
+
351
+ if coef_idx is None:
352
+ coef = self.coef_[:, 0]
353
+ offset = self.offset_[0]
354
+ elif exact or coef_idx == len(self.alphas_) - 1:
355
+ coef = self.coef_[:, coef_idx]
356
+ offset = self.offset_[coef_idx]
357
+ else:
358
+ # interpolate between coefficients
359
+ a1 = self.alphas_[coef_idx + 1]
360
+ a2 = self.alphas_[coef_idx]
361
+ frac = (alpha - a1) / (a2 - a1)
362
+ coef = frac * self.coef_[:, coef_idx] + (1.0 - frac) * self.coef_[:, coef_idx + 1]
363
+ offset = frac * self.offset_[coef_idx] + (1.0 - frac) * self.offset_[coef_idx + 1]
364
+
365
+ return coef, offset
366
+
367
+ def predict(self, X, alpha=None):
368
+ """Predict risk scores.
369
+
370
+ The risk score is the linear predictor of the model,
371
+ computed as the dot product of the input features `X` and the
372
+ estimated coefficients `coef_`. A higher score indicates a
373
+ higher risk of experiencing the event.
374
+
375
+ Parameters
376
+ ----------
377
+ X : array-like, shape = (n_samples, n_features)
378
+ Test data of which to calculate log-likelihood from
379
+
380
+ alpha : float, optional
381
+ Constant that multiplies the penalty terms. If the same alpha was used during training, exact
382
+ coefficients are used, otherwise coefficients are interpolated from the closest alpha values that
383
+ were used during training. If set to ``None``, the last alpha in the solution path is used.
384
+
385
+ Returns
386
+ -------
387
+ risk_score : array, shape = (n_samples,)
388
+ Predicted risk scores.
389
+ """
390
+ X = validate_data(self, X, reset=False)
391
+ coef, offset = self._get_coef(alpha)
392
+ return np.dot(X, coef) - offset
393
+
394
+ def _get_baseline_model(self, alpha):
395
+ check_is_fitted(self, "coef_")
396
+ if self._baseline_models is None:
397
+ raise ValueError("`fit` must be called with the fit_baseline_model option set to True.")
398
+
399
+ if alpha is None:
400
+ baseline_model = self._baseline_models[-1]
401
+ else:
402
+ is_close = np.isclose(alpha, self.alphas_)
403
+ if is_close.any():
404
+ idx = np.flatnonzero(is_close)[0]
405
+ baseline_model = self._baseline_models[idx]
406
+ else:
407
+ raise ValueError(f"alpha must be one value of alphas_: {self.alphas_}")
408
+
409
+ return baseline_model
410
+
411
+ def predict_cumulative_hazard_function(self, X, alpha=None, return_array=False):
412
+ r"""Predict cumulative hazard function.
413
+
414
+ Only available if :meth:`fit` has been called with `fit_baseline_model = True`.
415
+
416
+ The cumulative hazard function for an individual
417
+ with feature vector :math:`x_\alpha` is defined as
418
+
419
+ .. math::
420
+
421
+ H(t \mid x_\alpha) = \exp(x_\alpha^\top \beta) H_0(t) ,
422
+
423
+ where :math:`H_0(t)` is the baseline hazard function,
424
+ estimated by Breslow's estimator.
425
+
426
+ Parameters
427
+ ----------
428
+ X : array-like, shape = (n_samples, n_features)
429
+ Data matrix.
430
+
431
+ alpha : float, optional
432
+ Constant that multiplies the penalty terms. The same alpha as used during training
433
+ must be specified. If set to ``None``, the last alpha in the solution path is used.
434
+
435
+ return_array : bool, default: False
436
+ Whether to return a single array of cumulative hazard values
437
+ or a list of step functions.
438
+
439
+ If `False`, a list of :class:`sksurv.functions.StepFunction`
440
+ objects is returned.
441
+
442
+ If `True`, a 2d-array of shape `(n_samples, n_unique_times)` is
443
+ returned, where `n_unique_times` is the number of unique
444
+ event times in the training data. Each row represents the cumulative
445
+ hazard function of an individual evaluated at `unique_times_`.
446
+
447
+ Returns
448
+ -------
449
+ cum_hazard : ndarray
450
+ If `return_array` is `False`, an array of `n_samples`
451
+ :class:`sksurv.functions.StepFunction` instances is returned.
452
+
453
+ If `return_array` is `True`, a numeric array of shape
454
+ `(n_samples, n_unique_times_)` is returned.
455
+
456
+ Examples
457
+ --------
458
+ .. plot::
459
+
460
+ >>> import matplotlib.pyplot as plt
461
+ >>> from sksurv.datasets import load_breast_cancer
462
+ >>> from sksurv.preprocessing import OneHotEncoder
463
+ >>> from sksurv.linear_model import CoxnetSurvivalAnalysis
464
+
465
+ Load and prepare the data.
466
+
467
+ >>> X, y = load_breast_cancer()
468
+ >>> X = OneHotEncoder().fit_transform(X)
469
+
470
+ Fit the model.
471
+
472
+ >>> estimator = CoxnetSurvivalAnalysis(
473
+ ... l1_ratio=0.99, fit_baseline_model=True
474
+ ... ).fit(X, y)
475
+
476
+ Estimate the cumulative hazard function for one sample and the five highest alpha.
477
+
478
+ >>> chf_funcs = {}
479
+ >>> for alpha in estimator.alphas_[:5]:
480
+ ... chf_funcs[alpha] = estimator.predict_cumulative_hazard_function(
481
+ ... X.iloc[:1], alpha=alpha)
482
+ ...
483
+
484
+ Plot the estimated cumulative hazard functions.
485
+
486
+ >>> for alpha, chf_alpha in chf_funcs.items():
487
+ ... for fn in chf_alpha:
488
+ ... plt.step(fn.x, fn(fn.x), where="post",
489
+ ... label=f"alpha = {alpha:.3f}")
490
+ ...
491
+ [...]
492
+ >>> plt.legend()
493
+ <matplotlib.legend.Legend object at 0x...>
494
+ >>> plt.show() # doctest: +SKIP
495
+ """
496
+ baseline_model = self._get_baseline_model(alpha)
497
+ return self._predict_cumulative_hazard_function(baseline_model, self.predict(X, alpha=alpha), return_array)
498
+
499
+ def predict_survival_function(self, X, alpha=None, return_array=False):
500
+ r"""Predict survival function.
501
+
502
+ Only available if :meth:`fit` has been called with `fit_baseline_model = True`.
503
+
504
+ The survival function for an individual
505
+ with feature vector :math:`x_\alpha` is defined as
506
+
507
+ .. math::
508
+
509
+ S(t \mid x_\alpha) = S_0(t)^{\exp(x_\alpha^\top \beta)} ,
510
+
511
+ where :math:`S_0(t)` is the baseline survival function,
512
+ estimated by Breslow's estimator.
513
+
514
+ Parameters
515
+ ----------
516
+ X : array-like, shape = (n_samples, n_features)
517
+ Data matrix.
518
+
519
+ alpha : float, optional
520
+ Constant that multiplies the penalty terms. The same alpha as used during training
521
+ must be specified. If set to ``None``, the last alpha in the solution path is used.
522
+
523
+ return_array : bool, default: False
524
+ Whether to return a single array of survival probabilities
525
+ or a list of step functions.
526
+
527
+ If `False`, a list of :class:`sksurv.functions.StepFunction`
528
+ objects is returned.
529
+
530
+ If `True`, a 2d-array of shape `(n_samples, n_unique_times)` is
531
+ returned, where `n_unique_times` is the number of unique
532
+ event times in the training data. Each row represents the survival
533
+ function of an individual evaluated at `unique_times_`.
534
+
535
+ Returns
536
+ -------
537
+ survival : ndarray
538
+ If `return_array` is `False`, an array of `n_samples`
539
+ :class:`sksurv.functions.StepFunction` instances is returned.
540
+
541
+ If `return_array` is `True`, a numeric array of shape
542
+ `(n_samples, n_unique_times_)` is returned.
543
+
544
+
545
+ Examples
546
+ --------
547
+ .. plot::
548
+
549
+ >>> import matplotlib.pyplot as plt
550
+ >>> from sksurv.datasets import load_breast_cancer
551
+ >>> from sksurv.preprocessing import OneHotEncoder
552
+ >>> from sksurv.linear_model import CoxnetSurvivalAnalysis
553
+
554
+ Load and prepare the data.
555
+
556
+ >>> X, y = load_breast_cancer()
557
+ >>> X = OneHotEncoder().fit_transform(X)
558
+
559
+ Fit the model.
560
+
561
+ >>> estimator = CoxnetSurvivalAnalysis(
562
+ ... l1_ratio=0.99, fit_baseline_model=True
563
+ ... ).fit(X, y)
564
+
565
+ Estimate the survival function for one sample and the five highest alpha.
566
+
567
+ >>> surv_funcs = {}
568
+ >>> for alpha in estimator.alphas_[:5]:
569
+ ... surv_funcs[alpha] = estimator.predict_survival_function(
570
+ ... X.iloc[:1], alpha=alpha)
571
+ ...
572
+
573
+ Plot the estimated survival functions.
574
+
575
+ >>> for alpha, surv_alpha in surv_funcs.items():
576
+ ... for fn in surv_alpha:
577
+ ... plt.step(fn.x, fn(fn.x), where="post",
578
+ ... label=f"alpha = {alpha:.3f}")
579
+ ...
580
+ [...]
581
+ >>> plt.ylim(0, 1)
582
+ (0.0, 1.0)
583
+ >>> plt.legend()
584
+ <matplotlib.legend.Legend object at 0x...>
585
+ >>> plt.show() # doctest: +SKIP
586
+ """
587
+ baseline_model = self._get_baseline_model(alpha)
588
+ return self._predict_survival_function(baseline_model, self.predict(X, alpha=alpha), return_array)
589
+
590
    @property
    def unique_times_(self):
        # Unique time points of the training data, taken from the baseline
        # model of the last alpha in the path; `_get_baseline_model` raises
        # if `fit_baseline_model` was not enabled during `fit`.
        return self._get_baseline_model(None).unique_times_