torch-survival 0.1.0a2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,21 @@
1
+ Metadata-Version: 2.3
2
+ Name: torch-survival
3
+ Version: 0.1.0a2
4
+ Summary: Production-ready deep survival analysis with one line of code
5
+ Requires-Dist: matplotlib>=3.10.8
6
+ Requires-Dist: optuna>=4.5.0
7
+ Requires-Dist: optuna-integration[sklearn]>=4.6.0
8
+ Requires-Dist: pycox>=0.3.0
9
+ Requires-Dist: rich>=14.2.0
10
+ Requires-Dist: scikit-survival>=0.24.1
11
+ Requires-Dist: torch>=2.8.0
12
+ Requires-Dist: typing-extensions>=4.15.0
13
+ Requires-Dist: furo>=2025.7.19 ; extra == 'docs'
14
+ Requires-Dist: sphinx>=8.2.3 ; extra == 'docs'
15
+ Requires-Dist: h5py>=3.14.0 ; extra == 'test'
16
+ Requires-Dist: torchsurv>=0.1.5 ; extra == 'test'
17
+ Requires-Python: >=3.12
18
+ Provides-Extra: docs
19
+ Provides-Extra: test
20
+ Description-Content-Type: text/markdown
21
+
File without changes
@@ -0,0 +1,40 @@
1
+ [project]
2
+ name = "torch-survival"
3
+ version = "0.1.0-alpha.2"
4
+ description = "Production-ready deep survival analysis with one line of code"
5
+ readme = "README.md"
6
+ requires-python = ">=3.12"
7
+ dependencies = [
8
+ "matplotlib>=3.10.8",
9
+ "optuna>=4.5.0",
10
+ "optuna-integration[sklearn]>=4.6.0",
11
+ "pycox>=0.3.0",
12
+ "rich>=14.2.0",
13
+ "scikit-survival>=0.24.1",
14
+ "torch>=2.8.0",
15
+ "typing-extensions>=4.15.0",
16
+ ]
17
+
18
+ [project.optional-dependencies]
19
+ docs = [
20
+ "furo>=2025.7.19",
21
+ "sphinx>=8.2.3",
22
+ ]
23
+ test = [
24
+ "h5py>=3.14.0",
25
+ "torchsurv>=0.1.5",
26
+ ]
27
+
28
+ [[tool.uv.index]]
29
+ name = "pytorch-cu128"
30
+ url = "https://download.pytorch.org/whl/cu128"
31
+ explicit = true
32
+
33
+ [tool.uv.sources]
34
+ torch = [
35
+ { index = "pytorch-cu128", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
36
+ ]
37
+
38
+ [build-system]
39
+ requires = ["uv_build>=0.8.11,<0.9.0"]
40
+ build-backend = "uv_build"
File without changes
@@ -0,0 +1,30 @@
1
+ from typing_extensions import TypedDict
2
+
3
+
4
+ class LayerConfig(TypedDict, closed=True):
5
+ # Maximum number of hidden layers
6
+ max_layers: int
7
+ # Maximum number of nodes per hidden layer
8
+ max_nodes_per_layer: int
9
+
10
+
11
+ class NetworkConfig(TypedDict, closed=True):
12
+ #: Layers of the neural network
13
+ layers: LayerConfig | list[int]
14
+ #: Activation function used in hidden layers
15
+ activation: list[str] | str
16
+ #: Probability of dropout for dropout layers
17
+ dropout: tuple[float, float] | float
18
+
19
+
20
+ class OptimizerConfig(TypedDict, closed=True):
21
+ #: Optimizer used to train neural network
22
+ name: list[str] | str
23
+ #: Learning rate used to train neural network
24
+ lr: tuple[float, float] | float
25
+ #: Learning rate momentum used to train neural network
26
+ momentum: tuple[float, float] | float
27
+ #: Scheduler applied to learning rate
28
+ scheduler: list[str] | str
29
+ #: Decay used by learning rate scheduler
30
+ decay: tuple[float, float] | float
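A hedged illustration of how these configuration dictionaries are meant to be filled in (the concrete numbers below are placeholders, not recommended defaults): scalars and strings pin a hyperparameter, lists give categorical choices, and (low, high) tuples define ranges left to the tuner.

from torch_survival.config import NetworkConfig, OptimizerConfig

# Search over up to 3 hidden layers with at most 32 nodes each, choose the activation,
# and keep dropout fixed at 0.1 instead of tuning it
network: NetworkConfig = {
    'layers': {'max_layers': 3, 'max_nodes_per_layer': 32},
    'activation': ['relu', 'selu'],
    'dropout': 0.1,
}

# Fix the optimizer and scheduler, tune the learning rate and the scheduler decay
optimizer: OptimizerConfig = {
    'name': 'adam',
    'lr': (1e-5, 1e-2),
    'momentum': 0.9,
    'scheduler': 'inverse_time',
    'decay': (0.0, 0.01),
}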
@@ -0,0 +1,219 @@
1
+ import torch
2
+
3
+
4
+ def cox_neg_log_likelihood(risk: torch.Tensor, event: torch.Tensor, time: torch.Tensor,
5
+ sort: bool = True) -> torch.Tensor:
6
+ """ Cox negative log partial likelihood.
7
+ From DeepSurv: https://doi.org/10.1186/s12874-018-0482-1
8
+
9
+ Parameters
10
+ ----------
11
+ risk: torch.Tensor, of shape (..., n_samples)
12
+ Risks estimated by model
13
+
14
+ event: torch.Tensor, of shape (..., n_samples)
15
+ Event indicator denoting whether time is of observed event or dropout
16
+
17
+ time: torch.Tensor, of shape (..., n_samples)
18
+ Time of either observed event or dropout
19
+
20
+ sort: bool
21
+ Whether to sort by time, otherwise assumes pre-sorted ground truth times and events
22
+
23
+ """
24
+ if sort:
25
+ # Sort risk and events by time
26
+ sort_idx = torch.argsort(time, descending=True)
27
+ risk = torch.gather(risk, dim=-1, index=sort_idx)
28
+ event = torch.gather(event, dim=-1, index=sort_idx)
29
+ # Due to the sorting above, log_risk[i] = log(sum(e^risk[j]) for j = 0..i), i.e., the sum over samples with time[j] >= time[i]
30
+ log_risk = torch.logcumsumexp(risk, dim=-1)
31
+ likelihood = (risk - log_risk) * event.float()
32
+ return - likelihood.sum(dim=-1) / event.sum(dim=-1)
33
+
34
+
35
+ def weibull_neg_log_likelihood(params: torch.Tensor, event: torch.Tensor, time: torch.Tensor,
36
+ activations: bool = False) -> torch.Tensor:
37
+ """ Weibull negative log likelihood.
38
+ From DeepWeiSurv: https://doi.org/10.1007/978-3-030-47426-3_53
39
+
40
+ Parameters
41
+ ----------
42
+ params: torch.Tensor, of shape (..., n_samples, 2 | p * 3)
43
+ Parameters of the mixture of p Weibull distributions, namely weights, shapes, and scales
44
+
45
+ event: torch.Tensor, of shape (..., n_samples)
46
+ Event indicator denoting whether time is of observed event or dropout
47
+
48
+ time: torch.Tensor, of shape (..., n_samples)
49
+ Time of either observed event or dropout
50
+
51
+ activations: bool
52
+ Whether to apply ELU activations
53
+ """
54
+ # Extract mixture alpha, shape, and scale from params
55
+ if params.shape[-1] == 2 or params.shape[-1] == 3:
56
+ n_dists = 1
57
+ weight, shape, scale = None, params[..., -2], params[..., -1]
58
+ elif params.shape[-1] % 3 == 0:
59
+ n_dists = params.shape[-1] // 3
60
+ weight, shape, scale = params[..., 0:n_dists], params[..., n_dists:2 * n_dists], params[..., 2 * n_dists:]
61
+ time = time.unsqueeze(-1) # for proper broadcasting with mixture params
62
+ else:
63
+ raise ValueError('Unexpected number of Weibull parameters: ' + str(params.shape[-1]))
64
+
65
+ # Apply ELU activations if requested
66
+ if activations:
67
+ shape = torch.nn.functional.elu(shape) + 2
68
+ scale = torch.nn.functional.elu(scale) + 1 + 1e-5
69
+
70
+ # Guard against time=0, where log(time) = -inf, since a * event_true = nan which is later disregarded in nansum
71
+ event_true = event & (time.squeeze(-1) != 0)
72
+ event_false = ~event
73
+
74
+ a = torch.log(shape) - torch.log(scale) + (shape - 1) * (torch.log(time) - torch.log(scale))
75
+ b = -torch.pow(time / scale, shape)
76
+ if n_dists == 1:
77
+ # Simplified equation for single distribution
78
+ return -(torch.nansum(a * event_true.float(), dim=-1) + torch.sum(b, dim=-1)) / event.shape[-1]
79
+ else:
80
+ # Extended equation for multiple distributions
81
+ b += weight
82
+ return -(torch.nansum(torch.logsumexp(a + b, dim=-1) * event_true.float(), dim=-1)
83
+ + torch.sum(torch.logsumexp(b, dim=-1) * event_false.float(), dim=-1)
84
+ - torch.sum(torch.logsumexp(weight, dim=-1), dim=-1)) / event.shape[-1]
85
+
86
+
87
+ def weibull_neg_log_likelihood_original(params: torch.Tensor, event: torch.Tensor, time: torch.Tensor) -> torch.Tensor:
88
+ """ Weibull negative log likelihood.
89
+ From DeepWeiSurv: https://doi.org/10.1007/978-3-030-47426-3_53
90
+
91
+ Parameters
92
+ ----------
93
+ params: torch.Tensor, of shape (..., n_samples, 2 | p * 3)
94
+ Parameters of the mixture of p Weibull distributions, namely weights, shapes, and scales
95
+
96
+ event: torch.Tensor, of shape (..., n_samples)
97
+ Event indicator denoting whether time is of observed event or dropout
98
+
99
+ time: torch.Tensor, of shape (..., n_samples)
100
+ Time of either observed event or dropout
101
+ """
102
+ # Extract mixture alpha, shape, and scale from params
103
+ if params.shape[-1] % 3 == 0:
104
+ n_dists = params.shape[-1] // 3
105
+ alphas, shape, scale = params[..., 0:n_dists], params[..., n_dists:2 * n_dists], params[..., 2 * n_dists:]
106
+ alphas, shape, scale = alphas.t(), shape.t(), scale.t()
107
+ else:
108
+ raise ValueError('Unexpected number of Weibull parameters: ' + str(params.shape[-1]))
109
+
110
+ # https://github.com/AchrafB2015/pydpwte/blob/133aeb1004adf6bc0fd1e6985b42fc5986a77e02/pydpwte/utils/loss.py#L37-L44
111
+ t_over_eta = torch.div(time, scale)
112
+ h1 = torch.exp(-torch.pow(t_over_eta, shape))
113
+ h1_bis = torch.pow(t_over_eta, shape - 1)
114
+ params_aux = torch.div(torch.mul(alphas, shape), scale)
115
+ return -torch.mean(event * torch.log(torch.sum(torch.mul(torch.mul(params_aux, h1_bis), h1), 0))
116
+ + (~event) * torch.log(torch.sum(alphas * h1, 0)))
117
+
118
+
119
+ def weibull_survival_time(params: torch.Tensor, activations: bool = False):
120
+ """ Survival time from mean of mixture of Weibull distributions.
121
+ From DeepWeiSurv: https://doi.org/10.1007/978-3-030-47426-3_53
122
+
123
+ Parameters
124
+ ----------
125
+ params: torch.Tensor, of shape (..., n_samples, 2 | p * 3)
126
+ Parameters of the mixture of p Weibull distributions, namely weights, shapes, and scales
127
+
128
+ activations: bool
129
+ Whether to apply ELU activations
130
+ """
131
+ # Extract mixture alpha, shape, and scale from params
132
+ if params.shape[-1] == 2 or params.shape[-1] == 3:
133
+ n_dists = 1
134
+ weight, shape, scale = 1, params[..., [-2]], params[..., [-1]]
135
+ elif params.shape[-1] % 3 == 0:
136
+ n_dists = params.shape[-1] // 3
137
+ weight, shape, scale = params[..., 0:n_dists], params[..., n_dists:2 * n_dists], params[..., 2 * n_dists:]
138
+ else:
139
+ raise ValueError('Unexpected number of Weibull parameters: ' + str(params.shape[-1]))
140
+
141
+ # Apply softmax if needed
142
+ if activations and n_dists > 1:
143
+ weight = torch.softmax(weight, dim=-1)
144
+
145
+ # Apply ELU activations if requested
146
+ if activations:
147
+ shape = torch.nn.functional.elu(shape) + 2
148
+ scale = torch.nn.functional.elu(scale) + 1 + 1e-5
149
+
150
+ # Return weighted mean as survival time estimate
151
+ return torch.sum(weight * scale * torch.exp(torch.lgamma(1 + 1 / shape)), dim=-1)
152
+
153
+
154
+ def mse_with_pairwise_rank(estimate: torch.Tensor, event: torch.Tensor, time: torch.Tensor) -> torch.Tensor:
155
+ """ Extended mean squared error and pairwise ranking loss.
156
+ From RankDeepSurv: https://doi.org/10.1016/j.artmed.2019.06.001
157
+
158
+ """
159
+ # First part of the loss function is a simple mean squared error
160
+ error = time - estimate
161
+ loss1 = torch.mean(torch.square(error) * (event | (estimate <= time)).float(), dim=-1)
162
+
163
+ # Here, we add extra dimensions to enable pairwise comparisons of (i,j)
164
+ event_i, event_j = event.unsqueeze(-2), event.unsqueeze(-1)
165
+ time_i, time_j = time.unsqueeze(-2), time.unsqueeze(-1)
166
+ # The compatibility matrix specifies which pairs (i,j) can be compared accounting for censoring
167
+ comp = event_i & (event_j | (time_i <= time_j)) # matrix C in report
168
+
169
+ # Second part of the loss function encourages correct ranking among compatible pairs based on relative distance
170
+ # To save computations, we can reuse the errors needed for loss1 as
171
+ # (time_j - time_i) - (estimate_j - estimate_i) = (time_j - estimate_j) - (time_i - estimate_i)
172
+ error_i, error_j = error.unsqueeze(-2), error.unsqueeze(-1)
173
+ diff = torch.clamp(error_j - error_i, min=0) # matrix D in report, fused with condition
174
+ loss2 = torch.sum(diff * comp, dim=(-2, -1)) / event.shape[-1]
175
+
176
+ return loss1 + loss2
177
+
178
+
179
+ def discrete_with_pairwise_rank(estimate: torch.Tensor, event: torch.Tensor, time: torch.Tensor,
180
+ alpha: float, sigma: float) -> torch.Tensor:
181
+ """ Discrete negative log likelihood and pairwise ranking loss.
182
+ From DeepHit: https://doi.org/10.1609/aaai.v32i1.11842
183
+
184
+ Parameters
185
+ ----------
186
+ estimate: torch.Tensor, of shape (..., n_samples, n_times)
187
+ Discrete time probabilities estimated by model
188
+
189
+ event: torch.Tensor, of shape (..., n_samples)
190
+ Event indicator denoting whether time is of observed event or dropout
191
+
192
+ time: torch.Tensor, of shape (..., n_samples)
193
+ Time of either observed event or dropout
194
+
195
+ alpha: float
196
+ Weight for the ranking loss component
197
+
198
+ sigma: float
199
+ Bandwidth of the radial basis function in the ranking loss component
200
+ """
201
+ cum_incidence = torch.cumsum(estimate, dim=-1)
202
+ estimate_t = torch.gather(estimate, dim=-1, index=time.unsqueeze(-1)).squeeze(-1)
203
+ cum_incidence_t = torch.gather(cum_incidence, dim=-1, index=time.unsqueeze(-1)).squeeze(-1)
204
+ # -100 is also used in PyTorch's binary cross entropy as a cut-off
205
+ eps = torch.exp(torch.tensor(-100, device=estimate.device))
206
+ loss1 = (- torch.sum(event * torch.log(torch.clamp(estimate_t, min=eps)), dim=-1)
207
+ - torch.sum(~event * torch.log(torch.clamp(1 - cum_incidence_t, min=eps)), dim=-1))
208
+
209
+ # Here, we add extra dimensions to enable pairwise comparisons of (i,j)
210
+ event_i, event_j = event.unsqueeze(-2), event.unsqueeze(-1)
211
+ time_i, time_j = time.unsqueeze(-2), time.unsqueeze(-1)
212
+ # The acceptability matrix specifies which (i, j) can be compared
213
+ acc = event_i & (time_i < time_j)
214
+
215
+ cum_incidence_ti = cum_incidence_t.unsqueeze(-2)
216
+ cum_incidence_tij = cum_incidence[:, time]
217
+ loss2 = torch.sum(torch.exp(-(cum_incidence_ti - cum_incidence_tij) / sigma) * acc, dim=(-2, -1))
218
+
219
+ return loss1 + alpha * loss2
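As a quick orientation, a minimal sketch of calling the Cox loss and the Weibull survival-time helper above on toy tensors; shapes follow the docstrings and the values are placeholders.

import torch
from torch_survival.losses import cox_neg_log_likelihood, weibull_survival_time

risk = torch.tensor([0.3, -1.2, 0.8, 0.1])        # model outputs, shape (n_samples,)
event = torch.tensor([True, False, True, True])   # True = observed event, False = censored
time = torch.tensor([5.0, 3.0, 8.0, 2.0])         # event or censoring times
loss = cox_neg_log_likelihood(risk, event, time)  # scalar negative log partial likelihood

# Survival time estimate from a single (shape, scale) Weibull head, with ELU activations applied
params = torch.randn(4, 2)
t_hat = weibull_survival_time(params, activations=True)  # shape (4,)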
@@ -0,0 +1,57 @@
1
+ from typing import Literal
2
+
3
+ import torch
4
+
5
+
6
+ def concordance_index(estimate: torch.Tensor, event: torch.Tensor, time: torch.Tensor,
7
+ mode: Literal['risk', 'time'] = 'risk') -> torch.Tensor:
8
+ r""" Compute Harrell's concordance index or c-index.
9
+
10
+ This essentially computes the ratio of correctly ordered pairs while accounting for censoring. Given estimates
11
+ :math:`\eta`, event indicators :math:`\delta`, and times :math:`t` it is described by
12
+
13
+ .. math::
14
+
15
+ C = \frac{\sum_{i,j} \mathbb{1}_{\eta_i < \eta_j} \mathbb{1}_{t_i > t_j} \delta_j}
16
+ {\sum_{i,j} \mathbb{1}_{t_i > t_j} \delta_j}.
17
+
18
+ If your model directly predicts survival time, rather than risk, you need to negate the estimates as this function
19
+ assumes an inverse relationship between estimates and times, i.e., if one increases the other decreases.
20
+
21
+ Note
22
+ ----
23
+ While Harrell's concordance index is easy to interpret, it is known to be biased in the presence of higher amounts
24
+ of censoring [1]_. An alternative is Uno's concordance index, as implemented in
25
+ `sksurv.metrics.concordance_index_ipcw`.
26
+
27
+ .. [1] H. Uno, T. Cai, M. J. Pencina, R. B. D’Agostino, and L. J. Wei, “On the C‐statistics for evaluating overall
28
+ adequacy of risk prediction procedures with censored survival data,” Statistics in Medicine, vol. 30, no. 10, pp.
29
+ 1105–1117, Jan. 2011, doi: 10.1002/sim.4154. Available: http://dx.doi.org/10.1002/sim.4154
30
+
31
+
32
+ Parameters
33
+ ----------
34
+ estimate: torch.Tensor, of shape (..., n_samples)
35
+ Risks or times estimated by model
36
+
37
+ event: torch.Tensor, of shape (..., n_samples)
38
+ Event indicator denoting whether time is of observed event or dropout
39
+
40
+ time: torch.Tensor, of shape (..., n_samples)
41
+ Time of either observed event or dropout
42
+
43
+ mode: Literal['risk', 'time']
44
+ Whether the passed estimates are risks or survival times
45
+
46
+ Returns
47
+ -------
48
+ torch.Tensor, of shape (...,)
49
+ Concordance index score
50
+ """
51
+ if mode == 'risk':
52
+ estimate_comp = estimate.unsqueeze(-1) < estimate.unsqueeze(-2)
53
+ else:
54
+ estimate_comp = estimate.unsqueeze(-1) > estimate.unsqueeze(-2)
55
+ time_comp = time.unsqueeze(-1) > time.unsqueeze(-2)
56
+ event = event.unsqueeze(-2)
57
+ return torch.sum(estimate_comp & time_comp & event, (-2, -1)) / torch.sum(time_comp & event, (-2, -1))
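A small toy-data sketch of the metric above; the default mode='risk' expects higher estimates for earlier events, while mode='time' flips the comparison for survival-time predictions.

import torch
from torch_survival.metrics import concordance_index

risk = torch.tensor([2.0, 0.5, 1.5, 0.1])
event = torch.tensor([True, True, False, True])
time = torch.tensor([2.0, 6.0, 4.0, 9.0])
c_risk = concordance_index(risk, event, time)                # estimates interpreted as risks
c_time = concordance_index(-risk, event, time, mode='time')  # the same data read as time predictions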
@@ -0,0 +1,3 @@
1
+ # alphabetical imports to make models more easily accessible
2
+ from .deephit import DeepHit
3
+ from .deepsurv import DeepSurv, DeepSurvSearchSpace
@@ -0,0 +1,192 @@
1
+ import functools
2
+ import warnings
3
+ from copy import deepcopy
4
+ from typing import TypedDict
5
+
6
+ import numpy as np
7
+ import optuna
8
+ import torch
9
+ import torchtuples as tt
10
+ from optuna.samplers import TPESampler
11
+ from pycox.evaluation import EvalSurv
12
+ from pycox.models import DeepHitSingle
13
+ from pycox.preprocessing.label_transforms import LabTransDiscreteTime
14
+ from sklearn.base import BaseEstimator
15
+ from sklearn.model_selection import StratifiedKFold, train_test_split
16
+ from sklearn.utils.validation import check_is_fitted, validate_data
17
+ from sksurv.base import SurvivalAnalysisMixin
18
+ from sksurv.util import check_array_survival
19
+ from torch import nn
20
+
21
+ from torch_survival.progress import OptunaProgressCallback
22
+ from torch_survival.utils import merge_configs
23
+
24
+
25
+ class DeepHitNetwork(nn.Module):
26
+ def __init__(self, n_inputs, n_times):
27
+ super().__init__()
28
+ self.shared_network = nn.Sequential(
29
+ nn.Linear(n_inputs, 3 * n_inputs),
30
+ nn.ReLU(),
31
+ nn.Dropout(p=0.6),
32
+ )
33
+ self.cause_network = nn.Sequential(
34
+ # input consists of shared network output + features
35
+ nn.Linear(4 * n_inputs, 5 * n_inputs),
36
+ nn.ReLU(),
37
+ nn.Dropout(p=0.6),
38
+ nn.Linear(5 * n_inputs, 3 * n_inputs),
39
+ nn.ReLU(),
40
+ nn.Dropout(p=0.6),
41
+ nn.Linear(3 * n_inputs, n_times),
42
+ )
43
+
44
+ def forward(self, x):
45
+ x = torch.concat((self.shared_network(x), x), dim=-1)
46
+ x = self.cause_network(x)
47
+ return x
48
+
49
+
50
+ class DeepHitSearchSpace(TypedDict):
51
+ #: Weight for the ranking loss component
52
+ alpha: tuple[float, float] | float
53
+ #: Bandwidth of the radial basis function in the ranking loss component
54
+ sigma: tuple[float, float] | float
+ #: Number of discrete time intervals the event times are binned into
+ n_times: tuple[int, int] | int
55
+
56
+
57
+ class DeepHit(SurvivalAnalysisMixin, BaseEstimator):
58
+ r""" Implements the DeepHit model presented by Lee et al. [1]_.
59
+
60
+ Uses a deep neural network trained with the negative log likelihood of the survival time probability distribution,
61
+ as estimated over discrete time intervals, combined with a ranking loss.
62
+
63
+ .. [1] C. Lee, W. Zame, J. Yoon, and M. Van der Schaar, “DeepHit: A Deep Learning Approach to Survival Analysis
64
+ With Competing Risks,” AAAI, vol. 32, no. 1, Apr. 2018, doi: 10.1609/aaai.v32i1.11842. Available:
65
+ http://dx.doi.org/10.1609/aaai.v32i1.11842
66
+ """
67
+
68
+ #: Default hyperparameter search space
69
+ default_search_space: DeepHitSearchSpace = {
70
+ 'n_times': (10, 500),
71
+ 'alpha': (0, 1),
72
+ 'sigma': (1e-3, 1e1),
73
+ }
74
+
75
+ def __init__(self, search_space: DeepHitSearchSpace | None = None, n_epochs=100, random_state=None, device=None):
76
+ self.search_space = deepcopy(self.default_search_space)
77
+ if search_space:
78
+ self.search_space = merge_configs(self.search_space, search_space)
79
+ self.n_epochs = n_epochs
80
+ self.random_state = random_state
81
+ self.device = device
82
+ if self.device is None:
83
+ self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
84
+
85
+ def _optimize(self, trial: optuna.Trial, X, y_event, y_time):
86
+ # Do 5-fold cross validation
87
+ scores = []
88
+ for train_idx, test_idx in StratifiedKFold(n_splits=5).split(X, y_event):
89
+ model, _ = self._train(trial, X[train_idx], y_event[train_idx], y_time[train_idx])
90
+ surv = model.predict_surv_df(X[test_idx])
91
+ c_index = EvalSurv(surv, y_time[test_idx], y_event[test_idx], censor_surv='km').concordance_td('antolini')
92
+ scores.append(c_index)
93
+ return - sum(scores) / len(scores)
94
+
95
+ def _train(self, trial: optuna.Trial, X, y_event, y_time):
96
+ n_inputs, n_outputs = X.shape[-1], 1
97
+
98
+ # Discretize event times
99
+ n_times = self.search_space['n_times']
100
+ if not isinstance(n_times, int):
101
+ n_times = trial.suggest_int('n_times', *n_times)
102
+ y_trans = LabTransDiscreteTime(n_times)
103
+ y = y_trans.fit_transform(y_time, y_event)
104
+
105
+ # Split into training and validation for early stopping
106
+ X_train, X_val, y_time_train, y_time_val, y_event_train, y_event_val = \
107
+ train_test_split(X, *y, test_size=0.2, stratify=y_event, random_state=self.random_state)
108
+ y_train = (y_time_train, y_event_train)
109
+ y_val = (y_time_val, y_event_val)
110
+
111
+ # Sample loss parameters
112
+ alpha = self.search_space['alpha']
113
+ if not isinstance(alpha, float):
114
+ alpha = trial.suggest_float('alpha', *alpha)
115
+ sigma = self.search_space['sigma']
116
+ if not isinstance(sigma, float):
117
+ sigma = trial.suggest_float('sigma', *sigma, log=True)
118
+
119
+ # Train and return model
120
+ net = DeepHitNetwork(n_inputs, y_trans.out_features)
121
+ model = DeepHitSingle(net, tt.optim.Adam(lr=1e-4), duration_index=y_trans.cuts,
122
+ alpha=alpha, sigma=sigma, device=self.device)
123
+ callbacks = [tt.callbacks.EarlyStopping(patience=10)]
124
+ model.fit(X_train, y_train, batch_size=50, epochs=self.n_epochs, callbacks=callbacks,
125
+ val_data=(X_val, y_val), verbose=False)
126
+ return model, y_trans.cuts
127
+
128
+ def fit(self, X, y):
129
+ """ Fit the model to the given survival data.
130
+
131
+ Parameters
132
+ ----------
133
+ X: array-like, shape = (n_samples, n_features)
134
+ Data matrix.
135
+ y: structured array, shape = (n_samples,)
136
+ A structured array with two fields. The first field is a boolean where ``True`` indicates an event and
137
+ ``False`` indicates right-censoring. The second field is a float with the time of event or time of
138
+ censoring.
139
+
140
+ Returns
141
+ -------
142
+ self
143
+ """
144
+ # Validate and extract data
145
+ X, y = validate_data(self, X, y)
146
+ X = X.astype(np.float32)
147
+ y_event, y_time = check_array_survival(X, y)
148
+
149
+ # Seed PyTorch random number generator
150
+ if self.random_state:
151
+ torch.manual_seed(self.random_state)
152
+
153
+ # Optimize hyperparameters
154
+ optuna.logging.disable_default_handler()
155
+ warnings.filterwarnings('ignore', category=optuna.exceptions.ExperimentalWarning)
156
+ with OptunaProgressCallback(model_name='DeepHit', n_trials=50) as callback:
157
+ study = optuna.create_study(sampler=TPESampler(seed=self.random_state) if self.random_state else None)
158
+ objective = functools.partial(self._optimize, X=X, y_event=y_event, y_time=y_time)
159
+ study.optimize(objective, n_trials=50, callbacks=[callback])
160
+
161
+ # Train model
162
+ self.optuna_params_ = study.best_params
163
+ self.model_, self.disc_times_ = self._train(study.best_trial, X, y_event, y_time)
164
+
165
+ return self
166
+
167
+ @torch.no_grad()
168
+ def predict(self, X):
169
+ """ Predict survival times.
170
+
171
+ The survival time is estimated as the expectation over the discrete-time probability mass function predicted by
172
+ the neural network, i.e., the probability-weighted sum of the discretized event times.
173
+
174
+ Parameters
175
+ ----------
176
+ X: array-like, shape = (n_samples, n_features)
177
+ Data matrix.
178
+
179
+ Returns
180
+ -------
181
+ survival_time: array, shape = (n_samples,)
182
+ Predicted survival times.
183
+ """
184
+ check_is_fitted(self)
185
+ X = validate_data(self, X)
186
+ X = X.astype(np.float32)
187
+ pmf = self.model_.predict_pmf(X)
188
+ times = pmf @ self.disc_times_
189
+ return times
190
+
191
+ def get_optuna_params(self):
192
+ return self.optuna_params_
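An end-to-end sketch of the estimator above, assuming it is exposed as torch_survival.models.DeepHit (as the package __init__ above suggests) and given a scikit-survival style structured target; the synthetic data is purely illustrative, and the fit is slow because of the internal 50-trial Optuna search.

import numpy as np
from torch_survival.models import DeepHit

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 10)).astype(np.float32)
y = np.empty(200, dtype=[('event', '?'), ('time', '<f8')])  # event indicator first, time second
y['event'] = rng.random(200) < 0.7
y['time'] = rng.exponential(10.0, size=200)

model = DeepHit(random_state=0, device='cpu')
model.fit(X, y)               # runs hyperparameter search and trains the final network
times = model.predict(X[:5])  # expected survival times derived from the discrete PMF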
@@ -0,0 +1,180 @@
1
+ import functools
2
+ import warnings
3
+ from copy import deepcopy
4
+ from typing import TypedDict
5
+
6
+ import optuna
7
+ import torch
8
+ from optuna.samplers import TPESampler
9
+ from sklearn.base import BaseEstimator
10
+ from sklearn.model_selection import StratifiedKFold
11
+ from sklearn.utils.validation import check_is_fitted, validate_data
12
+ from sksurv.base import SurvivalAnalysisMixin
13
+ from sksurv.util import check_array_survival
14
+
15
+ from torch_survival.config import NetworkConfig, OptimizerConfig
16
+ from torch_survival.losses import cox_neg_log_likelihood
17
+ from torch_survival.metrics import concordance_index
18
+ from torch_survival.progress import OptunaProgressCallback
19
+ from torch_survival.sample import sample_network, sample_optimizer
20
+ from torch_survival.utils import merge_configs
21
+
22
+
23
+ class DeepSurvSearchSpace(TypedDict):
24
+ #: Neural network configuration
25
+ network: NetworkConfig
26
+ #: Optimizer configuration
27
+ optimizer: OptimizerConfig
28
+
29
+
30
+ class DeepSurv(SurvivalAnalysisMixin, BaseEstimator):
31
+ r""" Implements the DeepSurv model presented by Katzman et al. [1]_.
32
+
33
+ Uses a deep neural network trained with the Cox negative log partial likelihood to estimate the risk of each
34
+ individual. In the original paper, the network's configuration is tuned using the Sobol solver. This implementation tries to stay faithful
35
+ to the original paper, with the following deviations:
36
+
37
+ * Optuna's default TPE sampler is used instead of the Sobol sampler, and 5-fold internal cross-validation is used
38
+ instead of 3-fold internal cross-validation.
39
+ * The hyperparameter search space is not detailed in the original paper, and the reference implementation is
40
+ underspecified. We thus define our own shared search space in `DeepSurvSearchSpace`.
41
+ * Our implementation does not support or tune :math:`\ell_2` regularization. We found this to be detrimental to
42
+ performance and were unable to fully replicate the described weight regularization.
43
+
44
+ .. [1] J. L. Katzman, U. Shaham, A. Cloninger, J. Bates, T. Jiang, and Y. Kluger, “DeepSurv: personalized treatment
45
+ recommender system using a Cox proportional hazards deep neural network,” BMC Med Res Methodol, vol. 18, no. 1,
46
+ Feb. 2018, doi: 10.1186/s12874-018-0482-1. Available: http://dx.doi.org/10.1186/s12874-018-0482-1
47
+ """
48
+
49
+ #: Default hyperparameter search space
50
+ default_search_space: DeepSurvSearchSpace = {
51
+ 'network': {
52
+ 'layers': {
53
+ 'max_layers': 4,
54
+ 'max_nodes_per_layer': 50,
55
+ },
56
+ 'activation': ['relu', 'selu'],
57
+ 'dropout': (0.0, 0.5),
58
+ },
59
+ 'optimizer': {
60
+ 'name': ['sgd', 'adam'],
61
+ 'lr': (1e-7, 1e-3),
62
+ 'scheduler': 'inverse_time',
63
+ 'decay': (0.0, 0.001),
64
+ 'momentum': (0.8, 0.95),
65
+ },
66
+ }
67
+
68
+ def __init__(self, search_space: DeepSurvSearchSpace | None = None, n_epochs=500, random_state=None, device=None):
69
+ self.search_space = deepcopy(self.default_search_space)
70
+ if search_space:
71
+ self.search_space = merge_configs(self.search_space, search_space)
72
+ self.n_epochs = n_epochs
73
+ self.random_state = random_state
74
+ self.device = device
75
+ if self.device is None:
76
+ self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
77
+
78
+ def _optimize(self, trial: optuna.Trial, X, y_event, y_time):
79
+ # Do 5-fold cross validation
80
+ scores = []
81
+ for train_idx, test_idx in StratifiedKFold(n_splits=5).split(X, y_event.cpu()):
82
+ model = self._train(trial, X[train_idx], y_event[train_idx], y_time[train_idx])
83
+ model.eval()
84
+ risks = model(X[test_idx]).squeeze(dim=-1)
85
+ c_index = concordance_index(risks, y_event[test_idx], y_time[test_idx])
86
+ scores.append(c_index)
87
+ return - sum(scores) / len(scores)
88
+
89
+ def _train(self, trial: optuna.Trial, X, y_event, y_time):
90
+ # Set up model, optimizer, and scheduler
91
+ n_inputs, n_outputs = X.shape[-1], 1
92
+ model = sample_network(trial, self.search_space['network'], n_inputs, n_outputs)
93
+ optimizer, scheduler = sample_optimizer(trial, self.search_space['optimizer'], model)
94
+
95
+ # Pre-sort dataset based on time
96
+ sort_idx = torch.argsort(y_time, descending=True)
97
+ X = X[sort_idx]
98
+ y_event = y_event[sort_idx]
99
+ y_time = y_time[sort_idx]
100
+
101
+ # Train and return model
102
+ model.to(self.device)
103
+ for i in range(self.n_epochs):
104
+ optimizer.zero_grad()
105
+ risk = model(X).squeeze(-1)
106
+ loss = cox_neg_log_likelihood(risk, y_event, y_time, sort=False)
107
+ loss.backward()
108
+ optimizer.step()
109
+ if scheduler:
110
+ scheduler.step()
111
+ return model
112
+
113
+ def fit(self, X, y):
114
+ """ Fit the model to the given survival data.
115
+
116
+ Parameters
117
+ ----------
118
+ X: array-like, shape = (n_samples, n_features)
119
+ Data matrix.
120
+ y: structured array, shape = (n_samples,)
121
+ A structured array with two fields. The first field is a boolean where ``True`` indicates an event and
122
+ ``False`` indicates right-censoring. The second field is a float with the time of event or time of
123
+ censoring.
124
+
125
+ Returns
126
+ -------
127
+ self
128
+ """
129
+ # Validate and extract data
130
+ X, y = validate_data(self, X, y)
131
+ y_event, y_time = check_array_survival(X, y)
132
+
133
+ # Convert data to tensors
134
+ X = torch.as_tensor(X, dtype=torch.float32, device=self.device)
135
+ y_event = torch.as_tensor(y_event, dtype=torch.bool, device=self.device)
136
+ y_time = torch.as_tensor(y_time.copy(), dtype=torch.float32, device=self.device)
137
+
138
+ # Seed PyTorch random number generator
139
+ if self.random_state:
140
+ torch.manual_seed(self.random_state)
141
+
142
+ # Optimize hyperparameters
143
+ optuna.logging.disable_default_handler()
144
+ warnings.filterwarnings('ignore', category=optuna.exceptions.ExperimentalWarning)
145
+ with OptunaProgressCallback(model_name='DeepSurv', n_trials=50) as callback:
146
+ study = optuna.create_study(sampler=TPESampler(seed=self.random_state) if self.random_state else None)
147
+ objective = functools.partial(self._optimize, X=X, y_event=y_event, y_time=y_time)
148
+ study.optimize(objective, n_trials=50, callbacks=[callback])
149
+
150
+ # Train final model with best hyperparameters
151
+ self.optuna_params_ = study.best_params
152
+ self.model_ = self._train(study.best_trial, X, y_event, y_time)
153
+ self.model_.eval()
154
+
155
+ return self
156
+
157
+ @torch.no_grad()
158
+ def predict(self, X):
159
+ """ Predict risk scores.
160
+
161
+ The risk score is predicted directly by a neural network. A higher score indicates a higher risk of experiencing
162
+ the event.
163
+
164
+ Parameters
165
+ ----------
166
+ X: array-like, shape = (n_samples, n_features)
167
+ Data matrix.
168
+
169
+ Returns
170
+ -------
171
+ risk_score: array, shape = (n_samples,)
172
+ Predicted risk scores.
173
+ """
174
+ check_is_fitted(self)
175
+ X = validate_data(self, X)
176
+ X = torch.as_tensor(X, dtype=torch.float32, device=self.device)
177
+ return self.model_(X).detach().squeeze(dim=-1).cpu().numpy()
178
+
179
+ def get_optuna_params(self):
180
+ return self.optuna_params_
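A matching sketch for DeepSurv, under the same assumptions as the DeepHit example, showing how a partial search space overrides the defaults through merge_configs; only the supplied leaves are replaced, everything else keeps its default range.

import numpy as np
from torch_survival.models import DeepSurv

rng = np.random.default_rng(1)
X = rng.normal(size=(200, 10)).astype(np.float32)
y = np.empty(200, dtype=[('event', '?'), ('time', '<f8')])
y['event'] = rng.random(200) < 0.7
y['time'] = rng.exponential(10.0, size=200)

# Pin the activation and dropout; all other hyperparameters keep their default search ranges
search_space = {'network': {'activation': 'relu', 'dropout': 0.2}}
model = DeepSurv(search_space=search_space, random_state=1, device='cpu')
model.fit(X, y)
risk = model.predict(X[:5])  # higher scores indicate higher estimated risk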
@@ -0,0 +1,29 @@
1
+ import optuna
2
+ from rich.progress import Progress, TextColumn, BarColumn, MofNCompleteColumn, TimeRemainingColumn
3
+
4
+
5
+ class OptunaProgressCallback:
6
+ def __init__(self, model_name, n_trials, verbose=1):
7
+ self.verbose = verbose
8
+ if self.verbose > 0:
9
+ # At verbosity = 1 we show a progress bar for trials
10
+ self.progress = Progress(
11
+ TextColumn('{task.description}'),
12
+ BarColumn(),
13
+ TextColumn('{task.fields[score]}'),
14
+ MofNCompleteColumn(),
15
+ TimeRemainingColumn(),
16
+ )
17
+ self.tune_task = self.progress.add_task('[green]Optimizing ' + model_name, total=n_trials, score='-.--')
18
+
19
+ def __call__(self, study: optuna.study.Study, trial: optuna.trial.FrozenTrial) -> None:
20
+ if self.verbose > 0:
21
+ score = '{:.2f}'.format(abs(study.best_value))
22
+ self.progress.update(self.tune_task, advance=1, score=score)
23
+
24
+ def __enter__(self):
25
+ self.progress.start()
26
+ return self
27
+
28
+ def __exit__(self, exc_type, exc_value, traceback):
29
+ self.progress.stop()
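For context, a minimal sketch of how this callback is wired into an Optuna run, mirroring its use in the model classes; the quadratic objective is a stand-in, not part of the package.

import optuna
from torch_survival.progress import OptunaProgressCallback

def objective(trial):
    x = trial.suggest_float('x', -10, 10)  # placeholder objective
    return (x - 2) ** 2

with OptunaProgressCallback(model_name='Demo', n_trials=20) as callback:
    study = optuna.create_study()
    study.optimize(objective, n_trials=20, callbacks=[callback])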
File without changes
@@ -0,0 +1,161 @@
1
+ from typing import Mapping, Any
2
+
3
+ import optuna
4
+ import torch.optim as optim
5
+ import torch.optim.lr_scheduler as sched
6
+ from torch import nn
7
+
8
+ from torch_survival.config import NetworkConfig, OptimizerConfig
9
+ from torch_survival.utils import make_activation
10
+
11
+
12
+ class SimpleNeuralNetwork(nn.Module):
13
+ def __init__(self, n_inputs, n_outputs, layers, activation, dropout):
14
+ super().__init__()
15
+ hidden = []
16
+ n_nodes = n_inputs
17
+ for nodes in layers:
18
+ hidden.append(nn.Linear(n_nodes, nodes))
19
+ hidden.append(make_activation(activation))
20
+ hidden.append(nn.Dropout(p=dropout))
21
+ n_nodes = nodes
22
+ self.hidden = nn.Sequential(*hidden)
23
+ self.output = nn.Linear(n_nodes, n_outputs)
24
+
25
+ def forward(self, x):
26
+ x = self.hidden(x)
27
+ return self.output(x)
28
+
29
+
30
+ def sample_network(trial: optuna.Trial, config: NetworkConfig, n_inputs: int, n_outputs: int):
31
+ """
32
+ Sample a simple neural network using the provided configuration.
33
+
34
+ Parameters
35
+ ----------
36
+ trial: optuna.Trial
37
+ Active or fixed trial used to sample hyperparameters. Final configuration may be obtained by `study.best_trial`.
38
+ config: config.NetworkConfig
39
+ Network configuration specifying architecture and search constraints.
40
+ n_inputs: int
41
+ Number of inputs of the neural network. Typically depends on the data.
42
+ n_outputs: int
43
+ Number of outputs of the neural network. Typically depends on the modeling technique.
44
+
45
+ Returns
46
+ -------
47
+ nn.Module
48
+ PyTorch model consisting of linear layers.
49
+ """
50
+ layers = config['layers']
51
+ if not isinstance(layers, list):
52
+ n_layers = trial.suggest_int('layers', 0, layers['max_layers'])
53
+ layers = [trial.suggest_int('nodes_' + str(i + 1), 1, layers['max_nodes_per_layer']) for i in range(n_layers)]
54
+ activation = config['activation']
55
+ if not isinstance(activation, str):
56
+ activation = trial.suggest_categorical('activation', activation)
57
+ dropout = config['dropout']
58
+ if not isinstance(dropout, float):
59
+ dropout = trial.suggest_float('dropout', low=dropout[0], high=dropout[1])
60
+ return SimpleNeuralNetwork(n_inputs, n_outputs, layers, activation, dropout)
61
+
62
+
63
+ def sample_optimizer(trial: optuna.Trial, config: OptimizerConfig, model: nn.Module):
64
+ """
65
+ Sample optimizer and scheduler for training using the provided configuration.
66
+
67
+ Parameters
68
+ ----------
69
+ trial: optuna.Trial
70
+ Active or fixed trial used to sample hyperparameters. Final configuration may be obtained by `study.best_trial`.
71
+ config: config.OptimizerConfig
72
+ Optimizer configuration specifying how the neural network should be trained.
73
+ model: nn.Module
74
+ Neural network whose parameters should be optimized.
75
+
76
+ Returns
77
+ -------
78
+ torch.optim.Optimizer
79
+ PyTorch optimizer that can be used for training.
80
+ torch.optim.lr_scheduler.LRScheduler or None
81
+ PyTorch scheduler that can be used to update learning rates.
82
+ """
83
+ # Sample optimizer parameters
84
+ lr = config['lr']
85
+ if not isinstance(lr, float):
86
+ lr = trial.suggest_float('lr', *lr, log=True)
87
+ momentum = config['momentum']
88
+ if not isinstance(momentum, float):
89
+ momentum = trial.suggest_float('momentum', *momentum)
90
+ # Initialize optimizer
91
+ optimizer = None
92
+ optimizer_name = config['name']
93
+ if not isinstance(optimizer_name, str):
94
+ optimizer_name = trial.suggest_categorical('optimizer', optimizer_name)
95
+ if optimizer_name == 'sgd':
96
+ optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
97
+ if optimizer_name == 'adam':
98
+ optimizer = optim.Adam(model.parameters(), lr=lr, betas=(momentum, 0.999))
99
+ if optimizer is None:
100
+ raise ValueError('Optimizer with name `{}` is not supported'.format(config['name']))
101
+ # Sample optimizer parameters
102
+ decay = config['decay']
103
+ if not isinstance(decay, float):
104
+ decay = trial.suggest_float('decay', *decay)
105
+ # Initialize scheduler
106
+ scheduler = None
107
+ scheduler_name = config['scheduler']
108
+ if not isinstance(scheduler_name, str):
109
+ scheduler_name = trial.suggest_categorical('scheduler', scheduler_name)
110
+ if scheduler_name == 'inverse_time':
111
+ scheduler = sched.LambdaLR(optimizer, lr_lambda=lambda epoch: 1 / (1 + epoch * decay))
112
+ # Return both
113
+ return optimizer, scheduler
114
+
115
+
116
+ def sample_int(trial: optuna.Trial, config: Mapping[str, Any], key: str) -> int:
117
+ """
118
+ Helper to sample an integer value from a search space configuration.
119
+
120
+ Parameters
121
+ ----------
122
+ trial: optuna.Trial
123
+ Active or fixed trial used to sample hyperparameters. Final configuration may be obtained by `study.best_trial`.
124
+ config: Mapping[str, Any]
125
+ Search space configuration dictionary.
126
+ key: str
127
+ Name of configuration parameter that should be sampled.
128
+
129
+ Returns
130
+ -------
131
+ int:
132
+ The sampled value.
133
+ """
134
+ value = config[key]
135
+ if not isinstance(value, int):
136
+ value = trial.suggest_int(key, *value)
137
+ return value
138
+
139
+
140
+ def sample_float(trial: optuna.Trial, config: Mapping[str, Any], key: str) -> float:
141
+ """
142
+ Helper to sample a floating point value from a search space configuration.
143
+
144
+ Parameters
145
+ ----------
146
+ trial: optuna.Trial
147
+ Active or fixed trial used to sample hyperparameters. Final configuration may be obtained by `study.best_trial`.
148
+ config: Mapping[str, Any]
149
+ Search space configuration dictionary.
150
+ key: str
151
+ Name of configuration parameter that should be sampled.
152
+
153
+ Returns
154
+ -------
155
+ float:
156
+ The sampled value.
157
+ """
158
+ value = config[key]
159
+ if not isinstance(value, float):
160
+ value = trial.suggest_float(key, *value)
161
+ return value
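A hedged sketch of driving these samplers outside of a full study via Optuna's FixedTrial, so the suggested values are deterministic; the parameter names passed to FixedTrial mirror the names suggested inside sample_network and sample_optimizer, and the concrete values are placeholders.

import optuna
from torch_survival.config import NetworkConfig, OptimizerConfig
from torch_survival.sample import sample_network, sample_optimizer

network_config: NetworkConfig = {
    'layers': {'max_layers': 4, 'max_nodes_per_layer': 50},
    'activation': ['relu', 'selu'],
    'dropout': (0.0, 0.5),
}
optimizer_config: OptimizerConfig = {
    'name': ['sgd', 'adam'],
    'lr': (1e-7, 1e-3),
    'momentum': (0.8, 0.95),
    'scheduler': 'inverse_time',
    'decay': (0.0, 0.001),
}

trial = optuna.trial.FixedTrial({
    'layers': 2, 'nodes_1': 16, 'nodes_2': 8, 'activation': 'relu', 'dropout': 0.1,
    'optimizer': 'adam', 'lr': 1e-4, 'momentum': 0.9, 'decay': 0.0005,
})
model = sample_network(trial, network_config, n_inputs=10, n_outputs=1)
optimizer, scheduler = sample_optimizer(trial, optimizer_config, model)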
@@ -0,0 +1,87 @@
1
+ import inspect
2
+ from typing import TypeVar
3
+
4
+ import torch.nn as nn
5
+
6
+ _TypedDict = TypeVar('_TypedDict')
7
+
8
+
9
+ def make_activation(query: str):
10
+ """
11
+ Resolves the name of an activation function to its PyTorch layer. This function uses inspection and should thus be
12
+ able to resolve any current PyTorch activation function, e.g., `relu` is mapped to torch.nn.ReLU.
13
+
14
+ Parameters
15
+ ----------
16
+ query: str
17
+ Name of activation function that should be created
18
+
19
+ Returns
20
+ -------
21
+ torch.nn.Module
22
+ PyTorch layer corresponding to query, initialized with default values
23
+ """
24
+ query = query.lower()
25
+ matches = [obj for name, obj in inspect.getmembers(nn) if query == name.lower()]
26
+ if len(matches) == 0:
27
+ raise ValueError('Found no candidate for `{}` activation'.format(query))
28
+ if len(matches) > 1:
29
+ raise ValueError('Found multiple candidates for `{}` activation'.format(query))
30
+ return matches[0]()
31
+
32
+
33
+ def merge_configs(default_config: _TypedDict, user_config: _TypedDict) -> _TypedDict:
34
+ """
35
+ Recursively merges default configuration with user configuration.
36
+
37
+ Parameters
38
+ ----------
39
+ default_config: subclass of TypedDict
40
+ Default model configuration
41
+ user_config: subclass of TypedDict
42
+ User-provided model configuration overriding defaults
43
+
44
+ Returns
45
+ -------
46
+ subclass of TypedDict
47
+ Merged model configuration
48
+ """
49
+ for k, v in user_config.items():
50
+ if k in default_config:
51
+ if isinstance(default_config[k], dict) and isinstance(v, dict):
52
+ default_config[k] = merge_configs(default_config[k], v)
53
+ else:
54
+ default_config[k] = v
55
+ return default_config
56
+
57
+ # def sample_params(search_space, trial: optuna.Trial):
58
+ # """
59
+ # Samples hyperparameter search space, allowing for fixed
60
+ #
61
+ # Parameters
62
+ # ----------
63
+ # search_space: Mapping[str, ParamConfig]
64
+ # Dictionary of parameter configurations.
65
+ # trial: optuna.Trial
66
+ # Active trial for sampling parameters.
67
+ #
68
+ # Returns
69
+ # -------
70
+ # params: dict[str, int | float | str]
71
+ # Sampled parameters.
72
+ # """
73
+ # params = {}
74
+ # for key, value in search_space.items():
75
+ # if type(value) in [int, float, str]:
76
+ # # The value is already fixed and will be passed on as-is
77
+ # params[key] = value
78
+ # elif type(value) is list:
79
+ # # The value is a list of possible categories
80
+ # params[key] = trial.suggest_categorical(key, value)
81
+ # elif type(value) is tuple and type(value[0]) is int:
82
+ # # The value are lower and upper bounds for an integer
83
+ # params[key] = trial.suggest_int(key, *value[:2], log=value[2] if len(value) > 2 else False)
84
+ # elif type(value) is tuple and type(value[0]) is float:
85
+ # # The value are lower and upper bounds for an integer
86
+ # params[key] = trial.suggest_float(key, *value[:2], log=value[2] if len(value) > 2 else False)
87
+ # return params
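To round off, a small sketch of the two helpers above on placeholder dictionaries; merge_configs only replaces the leaves supplied by the user (and mutates and returns the default dictionary, which is why the estimators deepcopy their defaults first).

from torch_survival.utils import make_activation, merge_configs

act = make_activation('leakyrelu')  # resolves to torch.nn.LeakyReLU()

defaults = {'network': {'activation': ['relu', 'selu'], 'dropout': (0.0, 0.5)},
            'optimizer': {'lr': (1e-7, 1e-3)}}
override = {'network': {'dropout': 0.2}}
merged = merge_configs(defaults, override)
# merged['network']['dropout'] == 0.2; 'activation' and the whole 'optimizer' block keep their defaults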