scikit-survival 0.24.1__cp311-cp311-win_amd64.whl → 0.25.0__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. scikit_survival-0.25.0.dist-info/METADATA +185 -0
  2. scikit_survival-0.25.0.dist-info/RECORD +58 -0
  3. {scikit_survival-0.24.1.dist-info → scikit_survival-0.25.0.dist-info}/WHEEL +1 -1
  4. sksurv/__init__.py +51 -6
  5. sksurv/base.py +12 -2
  6. sksurv/bintrees/_binarytrees.cp311-win_amd64.pyd +0 -0
  7. sksurv/column.py +33 -29
  8. sksurv/compare.py +22 -22
  9. sksurv/datasets/base.py +45 -20
  10. sksurv/docstrings.py +99 -0
  11. sksurv/ensemble/_coxph_loss.cp311-win_amd64.pyd +0 -0
  12. sksurv/ensemble/boosting.py +116 -168
  13. sksurv/ensemble/forest.py +94 -151
  14. sksurv/functions.py +29 -29
  15. sksurv/io/arffread.py +34 -3
  16. sksurv/io/arffwrite.py +38 -2
  17. sksurv/kernels/_clinical_kernel.cp311-win_amd64.pyd +0 -0
  18. sksurv/kernels/clinical.py +33 -13
  19. sksurv/linear_model/_coxnet.cp311-win_amd64.pyd +0 -0
  20. sksurv/linear_model/aft.py +14 -11
  21. sksurv/linear_model/coxnet.py +138 -89
  22. sksurv/linear_model/coxph.py +102 -83
  23. sksurv/meta/ensemble_selection.py +91 -9
  24. sksurv/meta/stacking.py +47 -26
  25. sksurv/metrics.py +257 -224
  26. sksurv/nonparametric.py +150 -81
  27. sksurv/preprocessing.py +55 -27
  28. sksurv/svm/_minlip.cp311-win_amd64.pyd +0 -0
  29. sksurv/svm/_prsvm.cp311-win_amd64.pyd +0 -0
  30. sksurv/svm/minlip.py +160 -79
  31. sksurv/svm/naive_survival_svm.py +63 -34
  32. sksurv/svm/survival_svm.py +103 -103
  33. sksurv/tree/_criterion.cp311-win_amd64.pyd +0 -0
  34. sksurv/tree/tree.py +170 -84
  35. sksurv/util.py +80 -26
  36. scikit_survival-0.24.1.dist-info/METADATA +0 -889
  37. scikit_survival-0.24.1.dist-info/RECORD +0 -57
  38. {scikit_survival-0.24.1.dist-info → scikit_survival-0.25.0.dist-info}/licenses/COPYING +0 -0
  39. {scikit_survival-0.24.1.dist-info → scikit_survival-0.25.0.dist-info}/top_level.txt +0 -0
@@ -62,10 +62,11 @@ def _ordinal_as_numeric(x, ordinal_columns):
62
62
 
63
63
 
64
64
  def clinical_kernel(x, y=None):
65
- """Computes clinical kernel
65
+ """Computes clinical kernel.
66
66
 
67
67
  The clinical kernel distinguishes between continuous,
68
- ordinal,and nominal variables.
68
+ ordinal, and nominal variables.
69
+ Kernel values are normalized to lie within [0, 1].
69
70
 
70
71
  See [1]_ for further description.
71
72
 
@@ -80,13 +81,30 @@ def clinical_kernel(x, y=None):
80
81
  Returns
81
82
  -------
82
83
  kernel : array, shape = (n_samples_x, n_samples_y)
83
- Kernel matrix. Values are normalized to lie within [0, 1].
84
+ Kernel matrix.
84
85
 
85
86
  References
86
87
  ----------
87
88
  .. [1] Daemen, A., De Moor, B.,
88
89
  "Development of a kernel function for clinical data".
89
90
  Annual International Conference of the IEEE Engineering in Medicine and Biology Society, 5913-7, 2009
91
+
92
+ Examples
93
+ --------
94
+ >>> import pandas as pd
95
+ >>> from sksurv.kernels import clinical_kernel
96
+ >>>
97
+ >>> data = pd.DataFrame({
98
+ ... 'feature_num': [1.0, 2.0, 3.0],
99
+ ... 'feature_ord': pd.Categorical(['low', 'medium', 'high'], ordered=True),
100
+ ... 'feature_nom': pd.Categorical(['A', 'B', 'A'])
101
+ ... })
102
+ >>>
103
+ >>> kernel_matrix = clinical_kernel(data)
104
+ >>> print(kernel_matrix)
105
+ [[1. 0.33333333 0.5 ]
106
+ [0.33333333 1. 0.16666667]
107
+ [0.5 0.16666667 1. ]]
90
108
  """
91
109
  if y is not None:
92
110
  if x.shape[1] != y.shape[1]:
@@ -114,7 +132,7 @@ class ClinicalKernelTransform(BaseEstimator, TransformerMixin):
114
132
  """Transform data using a clinical Kernel
115
133
 
116
134
  The clinical kernel distinguishes between continuous,
117
- ordinal,and nominal variables.
135
+ ordinal, and nominal variables.
118
136
 
119
137
  See [1]_ for further description.
120
138
 
@@ -131,7 +149,7 @@ class ClinicalKernelTransform(BaseEstimator, TransformerMixin):
131
149
  n_features_in_ : int
132
150
  Number of features seen during ``fit``.
133
151
 
134
- feature_names_in_ : ndarray of shape (`n_features_in_`,)
152
+ feature_names_in_ : ndarray, shape = (`n_features_in_`,)
135
153
  Names of features seen during ``fit``. Defined only when `X`
136
154
  has feature names that are all strings.
137
155
 
@@ -214,10 +232,12 @@ class ClinicalKernelTransform(BaseEstimator, TransformerMixin):
214
232
  Data to estimate parameters from.
215
233
 
216
234
  y : None
217
- Argument is ignored (included for compatibility reasons).
235
+ Ignored. This parameter exists only for compatibility with
236
+ :class:`sklearn.pipeline.Pipeline`.
218
237
 
219
238
  kwargs : dict
220
- Argument is ignored (included for compatibility reasons).
239
+ Ignored. This parameter exists only for compatibility with
240
+ :class:`sklearn.pipeline.Pipeline`.
221
241
 
222
242
  Returns
223
243
  -------
@@ -281,10 +301,10 @@ class ClinicalKernelTransform(BaseEstimator, TransformerMixin):
281
301
 
282
302
  Parameters
283
303
  ----------
284
- x : array-like, shape = (n_samples_x, n_features)
304
+ x : pandas.DataFrame, shape = (n_samples_x, n_features)
285
305
  Training data
286
306
 
287
- y : array-like, shape = (n_samples_y, n_features)
307
+ y : pandas.DataFrame, shape = (n_samples_y, n_features)
288
308
  Testing data
289
309
 
290
310
  Returns
@@ -295,18 +315,18 @@ class ClinicalKernelTransform(BaseEstimator, TransformerMixin):
295
315
  return self.fit(X).transform(Y).T
296
316
 
297
317
  def pairwise_kernel(self, X, Y):
298
- """Function to use with :func:`sklearn.metrics.pairwise.pairwise_kernels`
318
+ """Function to use with :func:`sklearn.metrics.pairwise.pairwise_kernels`.
299
319
 
300
320
  Parameters
301
321
  ----------
302
- X : array, shape = (n_features,)
322
+ X : ndarray, shape = (n_features,)
303
323
 
304
- Y : array, shape = (n_features,)
324
+ Y : ndarray, shape = (n_features,)
305
325
 
306
326
  Returns
307
327
  -------
308
328
  similarity : float
309
- Similarities are normalized to be within [0, 1]
329
+ Similarities are normalized to be within [0, 1].
310
330
  """
311
331
  check_is_fitted(self, "X_fit_")
312
332
  if X.shape[0] != Y.shape[0]:
@@ -19,15 +19,15 @@ from ..util import check_array_survival
19
19
 
20
20
 
21
21
  class IPCRidge(Ridge, SurvivalAnalysisMixin):
22
- """Accelerated failure time model with inverse probability of censoring weights.
22
+ r"""Accelerated failure time model with inverse probability of censoring weights.
23
23
 
24
24
  This model assumes a regression model of the form
25
25
 
26
26
  .. math::
27
27
 
28
- \\log y = \\beta_0 + \\mathbf{X} \\beta + \\epsilon
28
+ \log y = \beta_0 + \mathbf{X} \beta + \epsilon
29
29
 
30
- L2-shrinkage is applied to the coefficients :math:`\\beta` and
30
+ L2-shrinkage is applied to the coefficients :math:`\beta` and
31
31
  each sample is weighted by the inverse probability of censoring
32
32
  to account for right censoring (under the assumption that
33
33
  censoring is independent of the features, i.e., random censoring).
@@ -57,7 +57,7 @@ class IPCRidge(Ridge, SurvivalAnalysisMixin):
57
57
  by scipy.sparse.linalg. For 'sag' solver, the default value is 1000.
58
58
  For 'lbfgs' solver, the default value is 15000.
59
59
 
60
- tol : float, default: 1e-4
60
+ tol : float, default: 1e-3
61
61
  Precision of the solution. Note that `tol` has no effect for solvers 'svd' and
62
62
  'cholesky'.
63
63
 
@@ -111,18 +111,18 @@ class IPCRidge(Ridge, SurvivalAnalysisMixin):
111
111
  coef_ : ndarray, shape = (n_features,)
112
112
  Weight vector.
113
113
 
114
- intercept_ : float or ndarray of shape (n_targets,)
114
+ intercept_ : float or ndarray, shape = (n_targets,)
115
115
  Independent term in decision function. Set to 0.0 if
116
116
  ``fit_intercept = False``.
117
117
 
118
- n_iter_ : None or ndarray of shape (n_targets,)
118
+ n_iter_ : None or ndarray, shape = (n_targets,)
119
119
  Actual number of iterations for each target. Available only for
120
120
  sag and lsqr solvers. Other solvers will return None.
121
121
 
122
122
  n_features_in_ : int
123
123
  Number of features seen during ``fit``.
124
124
 
125
- feature_names_in_ : ndarray of shape (`n_features_in_`,)
125
+ feature_names_in_ : ndarray, shape = (`n_features_in_`,)
126
126
  Names of features seen during ``fit``. Defined only when `X`
127
127
  has feature names that are all strings.
128
128
 
@@ -171,9 +171,9 @@ class IPCRidge(Ridge, SurvivalAnalysisMixin):
171
171
  Data matrix.
172
172
 
173
173
  y : structured array, shape = (n_samples,)
174
- A structured array containing the binary event indicator
175
- as first field, and time of event or time of censoring as
176
- second field.
174
+ A structured array with two fields. The first field is a boolean
175
+ where ``True`` indicates an event and ``False`` indicates right-censoring.
176
+ The second field is a float with the time of event or time of censoring.
177
177
 
178
178
  Returns
179
179
  -------
@@ -196,10 +196,13 @@ class IPCRidge(Ridge, SurvivalAnalysisMixin):
196
196
 
197
197
  Returns
198
198
  -------
199
- C : array, shape = (n_samples,)
199
+ y_pred : array, shape = (n_samples,)
200
200
  Returns predicted values on original scale (NOT log scale).
201
201
  """
202
202
  return np.exp(super().predict(X))
203
203
 
204
204
  def score(self, X, y, sample_weight=None):
205
205
  return SurvivalAnalysisMixin.score(self, X, y)
206
+
207
+
208
+ IPCRidge.score.__doc__ = SurvivalAnalysisMixin.score.__doc__
@@ -35,7 +35,7 @@ __all__ = ["CoxnetSurvivalAnalysis"]
35
35
 
36
36
 
37
37
  class CoxnetSurvivalAnalysis(BaseEstimator, SurvivalAnalysisMixin):
38
- """Cox's proportional hazard's model with elastic net penalty.
38
+ r"""Cox's proportional hazards model with elastic net penalty.
39
39
 
40
40
  See the :ref:`User Guide </user_guide/coxnet.ipynb>` and [1]_ for further description.
41
41
 
@@ -46,19 +46,29 @@ class CoxnetSurvivalAnalysis(BaseEstimator, SurvivalAnalysisMixin):
46
46
 
47
47
  alphas : array-like or None, optional
48
48
  List of alphas where to compute the models.
49
- If ``None`` alphas are set automatically.
49
+ If ``None``, alphas are set automatically.
50
+
51
+ In this case, the ``alphas`` sequence is determined by :math:`\alpha_\max`
52
+ and ``alpha_min_ratio``. The latter determines the smallest alpha value
53
+ :math:`\alpha_\min` in the generated alphas sequence such that
54
+ ``alpha_min_ratio`` equals the ratio :math:`\frac{\alpha_\min}{\alpha_\max}`.
55
+ The generated ``alphas`` sequence contains ``n_alphas`` values linear
56
+ on the log scale from :math:`\alpha_\max` down to :math:`\alpha_\min`.
57
+ :math:`\alpha_\max` is not user-specified but is computed from the
58
+ input data.
50
59
 
51
60
  alpha_min_ratio : float or { "auto" }, optional, default: "auto"
52
- Determines minimum alpha of the regularization path
61
+ Determines the minimum alpha of the regularization path
53
62
  if ``alphas`` is ``None``. The smallest value for alpha
54
- is computed as the fraction of the data derived maximum
63
+ is computed as the fraction of the maximum
55
64
  alpha (i.e. the smallest value for which all
56
- coefficients are zero).
65
+ coefficients are zero), which is derived from the input data.
57
66
 
58
67
  If set to "auto", the value will depend on the
59
- sample size relative to the number of features.
60
- If ``n_samples > n_features``, the default value is 0.0001
61
- If ``n_samples <= n_features``, 0.01 is the default value.
68
+ sample size relative to the number of features:
69
+
70
+ - If ``n_samples > n_features``, the default value is 0.0001.
71
+ - If ``n_samples <= n_features``, the default value is 0.01.
62
72
 
63
73
  l1_ratio : float, optional, default: 0.5
64
74
  The ElasticNet mixing parameter, with ``0 < l1_ratio <= 1``.
@@ -76,7 +86,7 @@ class CoxnetSurvivalAnalysis(BaseEstimator, SurvivalAnalysisMixin):
76
86
  Note: the penalty factors are internally rescaled to sum to
77
87
  `n_features`, and the alphas sequence will reflect this change.
78
88
 
79
- normalize : boolean, optional, default: False
89
+ normalize : bool, optional, default: False
80
90
  If True, the features X will be normalized before optimization by
81
91
  subtracting the mean and dividing by the l2-norm.
82
92
  If you wish to standardize, please use
@@ -91,7 +101,7 @@ class CoxnetSurvivalAnalysis(BaseEstimator, SurvivalAnalysisMixin):
91
101
  until all updates are smaller than ``tol``.
92
102
 
93
103
  max_iter : int, optional, default: 100000
94
- The maximum number of iterations.
104
+ The maximum number of iterations taken for the solver to converge.
95
105
 
96
106
  verbose : bool, optional, default: False
97
107
  Whether to print additional information during optimization.
@@ -122,15 +132,21 @@ class CoxnetSurvivalAnalysis(BaseEstimator, SurvivalAnalysisMixin):
122
132
 
123
133
  deviance_ratio_ : ndarray, shape=(n_alphas,)
124
134
  The fraction of (null) deviance explained.
135
+ The deviance is defined as :math:`2 \cdot (\text{loglike_sat} - \text{loglike})`,
136
+ where `loglike_sat` is the log-likelihood for the saturated model
137
+ (a model with a free parameter per observation). Null deviance is defined as
138
+ :math:`2 \cdot (\text{loglike_sat} - \text{loglike(Null)})`.
139
+ The null model is the model with all coefficients set to zero.
140
+ Hence, ``deviance_ratio_`` is :math:`1 - \frac{\text{deviance}}{\text{null_deviance}}`.
125
141
 
126
142
  n_features_in_ : int
127
143
  Number of features seen during ``fit``.
128
144
 
129
- feature_names_in_ : ndarray of shape (`n_features_in_`,)
145
+ feature_names_in_ : ndarray, shape = (`n_features_in_`,)
130
146
  Names of features seen during ``fit``. Defined only when `X`
131
147
  has feature names that are all strings.
132
148
 
133
- unique_times_ : array of shape = (n_unique_times,)
149
+ unique_times_ : ndarray, shape = (n_unique_times,)
134
150
  Unique time points.
135
151
 
136
152
  References
@@ -255,9 +271,9 @@ class CoxnetSurvivalAnalysis(BaseEstimator, SurvivalAnalysisMixin):
255
271
  Data matrix
256
272
 
257
273
  y : structured array, shape = (n_samples,)
258
- A structured array containing the binary event indicator
259
- as first field, and time of event or time of censoring as
260
- second field.
274
+ A structured array with two fields. The first field is a boolean
275
+ where ``True`` indicates an event and ``False`` indicates right-censoring.
276
+ The second field is a float with the time of event or time of censoring.
261
277
 
262
278
  Returns
263
279
  -------
@@ -349,7 +365,12 @@ class CoxnetSurvivalAnalysis(BaseEstimator, SurvivalAnalysisMixin):
349
365
  return coef, offset
350
366
 
351
367
  def predict(self, X, alpha=None):
352
- """The linear predictor of the model.
368
+ """Predict risk scores.
369
+
370
+ The risk score is the linear predictor of the model,
371
+ computed as the dot product of the input features `X` and the
372
+ estimated coefficients `coef_`. A higher score indicates a
373
+ higher risk of experiencing the event.
353
374
 
354
375
  Parameters
355
376
  ----------
@@ -363,8 +384,8 @@ class CoxnetSurvivalAnalysis(BaseEstimator, SurvivalAnalysisMixin):
363
384
 
364
385
  Returns
365
386
  -------
366
- T : array, shape = (n_samples,)
367
- The predicted decision function
387
+ risk_score : array, shape = (n_samples,)
388
+ Predicted risk scores.
368
389
  """
369
390
  X = validate_data(self, X, reset=False)
370
391
  coef, offset = self._get_coef(alpha)
@@ -388,16 +409,16 @@ class CoxnetSurvivalAnalysis(BaseEstimator, SurvivalAnalysisMixin):
388
409
  return baseline_model
389
410
 
390
411
  def predict_cumulative_hazard_function(self, X, alpha=None, return_array=False):
391
- """Predict cumulative hazard function.
412
+ r"""Predict cumulative hazard function.
392
413
 
393
414
  Only available if :meth:`fit` has been called with `fit_baseline_model = True`.
394
415
 
395
416
  The cumulative hazard function for an individual
396
- with feature vector :math:`x_\\alpha` is defined as
417
+ with feature vector :math:`x_\alpha` is defined as
397
418
 
398
419
  .. math::
399
420
 
400
- H(t \\mid x_\\alpha) = \\exp(x_\\alpha^\\top \\beta) H_0(t) ,
421
+ H(t \mid x_\alpha) = \exp(x_\alpha^\top \beta) H_0(t) ,
401
422
 
402
423
  where :math:`H_0(t)` is the baseline hazard function,
403
424
  estimated by Breslow's estimator.
@@ -411,68 +432,81 @@ class CoxnetSurvivalAnalysis(BaseEstimator, SurvivalAnalysisMixin):
411
432
  Constant that multiplies the penalty terms. The same alpha as used during training
412
433
  must be specified. If set to ``None``, the last alpha in the solution path is used.
413
434
 
414
- return_array : boolean, default: False
415
- If set, return an array with the cumulative hazard rate
416
- for each `self.unique_times_`, otherwise an array of
417
- :class:`sksurv.functions.StepFunction`.
435
+ return_array : bool, default: False
436
+ Whether to return a single array of cumulative hazard values
437
+ or a list of step functions.
438
+
439
+ If `False`, a list of :class:`sksurv.functions.StepFunction`
440
+ objects is returned.
441
+
442
+ If `True`, a 2d-array of shape `(n_samples, n_unique_times)` is
443
+ returned, where `n_unique_times` is the number of unique
444
+ event times in the training data. Each row represents the cumulative
445
+ hazard function of an individual evaluated at `unique_times_`.
418
446
 
419
447
  Returns
420
448
  -------
421
449
  cum_hazard : ndarray
422
- If `return_array` is set, an array with the cumulative hazard rate
423
- for each `self.unique_times_`, otherwise an array of length `n_samples`
424
- of :class:`sksurv.functions.StepFunction` instances will be returned.
450
+ If `return_array` is `False`, an array of `n_samples`
451
+ :class:`sksurv.functions.StepFunction` instances is returned.
452
+
453
+ If `return_array` is `True`, a numeric array of shape
454
+ `(n_samples, n_unique_times)` is returned.
425
455
 
426
456
  Examples
427
457
  --------
428
- >>> import matplotlib.pyplot as plt
429
- >>> from sksurv.datasets import load_breast_cancer
430
- >>> from sksurv.preprocessing import OneHotEncoder
431
- >>> from sksurv.linear_model import CoxnetSurvivalAnalysis
458
+ .. plot::
432
459
 
433
- Load and prepare the data.
460
+ >>> import matplotlib.pyplot as plt
461
+ >>> from sksurv.datasets import load_breast_cancer
462
+ >>> from sksurv.preprocessing import OneHotEncoder
463
+ >>> from sksurv.linear_model import CoxnetSurvivalAnalysis
434
464
 
435
- >>> X, y = load_breast_cancer()
436
- >>> X = OneHotEncoder().fit_transform(X)
465
+ Load and prepare the data.
437
466
 
438
- Fit the model.
467
+ >>> X, y = load_breast_cancer()
468
+ >>> X = OneHotEncoder().fit_transform(X)
439
469
 
440
- >>> estimator = CoxnetSurvivalAnalysis(l1_ratio=0.99, fit_baseline_model=True)
441
- >>> estimator.fit(X, y)
470
+ Fit the model.
442
471
 
443
- Estimate the cumulative hazard function for one sample and the five highest alpha.
472
+ >>> estimator = CoxnetSurvivalAnalysis(
473
+ ... l1_ratio=0.99, fit_baseline_model=True
474
+ ... ).fit(X, y)
444
475
 
445
- >>> chf_funcs = {}
446
- >>> for alpha in estimator.alphas_[:5]:
447
- ... chf_funcs[alpha] = estimator.predict_cumulative_hazard_function(
448
- ... X.iloc[:1], alpha=alpha)
449
- ...
476
+ Estimate the cumulative hazard function for one sample and the five highest alphas.
450
477
 
451
- Plot the estimated cumulative hazard functions.
478
+ >>> chf_funcs = {}
479
+ >>> for alpha in estimator.alphas_[:5]:
480
+ ... chf_funcs[alpha] = estimator.predict_cumulative_hazard_function(
481
+ ... X.iloc[:1], alpha=alpha)
482
+ ...
452
483
 
453
- >>> for alpha, chf_alpha in chf_funcs.items():
454
- ... for fn in chf_alpha:
455
- ... plt.step(fn.x, fn(fn.x), where="post",
456
- ... label=f"alpha = {alpha:.3f}")
457
- ...
458
- >>> plt.ylim(0, 1)
459
- >>> plt.legend()
460
- >>> plt.show()
484
+ Plot the estimated cumulative hazard functions.
485
+
486
+ >>> for alpha, chf_alpha in chf_funcs.items():
487
+ ... for fn in chf_alpha:
488
+ ... plt.step(fn.x, fn(fn.x), where="post",
489
+ ... label=f"alpha = {alpha:.3f}")
490
+ ...
491
+ [...]
492
+ >>> plt.legend()
493
+ <matplotlib.legend.Legend object at 0x...>
494
+ >>> plt.show() # doctest: +SKIP
461
495
  """
462
496
  baseline_model = self._get_baseline_model(alpha)
463
497
  return self._predict_cumulative_hazard_function(baseline_model, self.predict(X, alpha=alpha), return_array)
464
498
 
465
499
  def predict_survival_function(self, X, alpha=None, return_array=False):
466
- """Predict survival function.
500
+ r"""Predict survival function.
467
501
 
468
502
  Only available if :meth:`fit` has been called with `fit_baseline_model = True`.
469
503
 
470
504
  The survival function for an individual
471
- with feature vector :math:`x_\\alpha` is defined as
505
+ with feature vector :math:`x_\alpha` is defined as
472
506
 
473
507
  .. math::
474
508
 
475
- S(t \\mid x_\\alpha) = S_0(t)^{\\exp(x_\\alpha^\\top \\beta)} ,
509
+ S(t \mid x_\alpha) = S_0(t)^{\exp(x_\alpha^\top \beta)} ,
476
510
 
477
511
  where :math:`S_0(t)` is the baseline survival function,
478
512
  estimated by Breslow's estimator.
@@ -486,54 +520,69 @@ class CoxnetSurvivalAnalysis(BaseEstimator, SurvivalAnalysisMixin):
486
520
  Constant that multiplies the penalty terms. The same alpha as used during training
487
521
  must be specified. If set to ``None``, the last alpha in the solution path is used.
488
522
 
489
- return_array : boolean, default: False
490
- If set, return an array with the probability
491
- of survival for each `self.unique_times_`,
492
- otherwise an array of :class:`sksurv.functions.StepFunction`.
523
+ return_array : bool, default: False
524
+ Whether to return a single array of survival probabilities
525
+ or a list of step functions.
526
+
527
+ If `False`, a list of :class:`sksurv.functions.StepFunction`
528
+ objects is returned.
529
+
530
+ If `True`, a 2d-array of shape `(n_samples, n_unique_times)` is
531
+ returned, where `n_unique_times` is the number of unique
532
+ event times in the training data. Each row represents the survival
533
+ function of an individual evaluated at `unique_times_`.
493
534
 
494
535
  Returns
495
536
  -------
496
537
  survival : ndarray
497
- If `return_array` is set, an array with the probability of
498
- survival for each `self.unique_times_`, otherwise an array of
499
- length `n_samples` of :class:`sksurv.functions.StepFunction`
500
- instances will be returned.
538
+ If `return_array` is `False`, an array of `n_samples`
539
+ :class:`sksurv.functions.StepFunction` instances is returned.
540
+
541
+ If `return_array` is `True`, a numeric array of shape
542
+ `(n_samples, n_unique_times)` is returned.
543
+
501
544
 
502
545
  Examples
503
546
  --------
504
- >>> import matplotlib.pyplot as plt
505
- >>> from sksurv.datasets import load_breast_cancer
506
- >>> from sksurv.preprocessing import OneHotEncoder
507
- >>> from sksurv.linear_model import CoxnetSurvivalAnalysis
547
+ .. plot::
548
+
549
+ >>> import matplotlib.pyplot as plt
550
+ >>> from sksurv.datasets import load_breast_cancer
551
+ >>> from sksurv.preprocessing import OneHotEncoder
552
+ >>> from sksurv.linear_model import CoxnetSurvivalAnalysis
508
553
 
509
- Load and prepare the data.
554
+ Load and prepare the data.
510
555
 
511
- >>> X, y = load_breast_cancer()
512
- >>> X = OneHotEncoder().fit_transform(X)
556
+ >>> X, y = load_breast_cancer()
557
+ >>> X = OneHotEncoder().fit_transform(X)
513
558
 
514
- Fit the model.
559
+ Fit the model.
515
560
 
516
- >>> estimator = CoxnetSurvivalAnalysis(l1_ratio=0.99, fit_baseline_model=True)
517
- >>> estimator.fit(X, y)
561
+ >>> estimator = CoxnetSurvivalAnalysis(
562
+ ... l1_ratio=0.99, fit_baseline_model=True
563
+ ... ).fit(X, y)
518
564
 
519
- Estimate the survival function for one sample and the five highest alpha.
565
+ Estimate the survival function for one sample and the five highest alphas.
520
566
 
521
- >>> surv_funcs = {}
522
- >>> for alpha in estimator.alphas_[:5]:
523
- ... surv_funcs[alpha] = estimator.predict_survival_function(
524
- ... X.iloc[:1], alpha=alpha)
525
- ...
567
+ >>> surv_funcs = {}
568
+ >>> for alpha in estimator.alphas_[:5]:
569
+ ... surv_funcs[alpha] = estimator.predict_survival_function(
570
+ ... X.iloc[:1], alpha=alpha)
571
+ ...
526
572
 
527
- Plot the estimated survival functions.
573
+ Plot the estimated survival functions.
528
574
 
529
- >>> for alpha, surv_alpha in surv_funcs.items():
530
- ... for fn in surv_alpha:
531
- ... plt.step(fn.x, fn(fn.x), where="post",
532
- ... label=f"alpha = {alpha:.3f}")
533
- ...
534
- >>> plt.ylim(0, 1)
535
- >>> plt.legend()
536
- >>> plt.show()
575
+ >>> for alpha, surv_alpha in surv_funcs.items():
576
+ ... for fn in surv_alpha:
577
+ ... plt.step(fn.x, fn(fn.x), where="post",
578
+ ... label=f"alpha = {alpha:.3f}")
579
+ ...
580
+ [...]
581
+ >>> plt.ylim(0, 1)
582
+ (0.0, 1.0)
583
+ >>> plt.legend()
584
+ <matplotlib.legend.Legend object at 0x...>
585
+ >>> plt.show() # doctest: +SKIP
537
586
  """
538
587
  baseline_model = self._get_baseline_model(alpha)
539
588
  return self._predict_survival_function(baseline_model, self.predict(X, alpha=alpha), return_array)