torch-survival 0.1.0a2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torch_survival-0.1.0a2/PKG-INFO +21 -0
- torch_survival-0.1.0a2/README.md +0 -0
- torch_survival-0.1.0a2/pyproject.toml +40 -0
- torch_survival-0.1.0a2/src/torch_survival/__init__.py +0 -0
- torch_survival-0.1.0a2/src/torch_survival/config.py +30 -0
- torch_survival-0.1.0a2/src/torch_survival/losses.py +219 -0
- torch_survival-0.1.0a2/src/torch_survival/metrics.py +57 -0
- torch_survival-0.1.0a2/src/torch_survival/models/__init__.py +3 -0
- torch_survival-0.1.0a2/src/torch_survival/models/deephit.py +192 -0
- torch_survival-0.1.0a2/src/torch_survival/models/deepsurv.py +180 -0
- torch_survival-0.1.0a2/src/torch_survival/progress.py +29 -0
- torch_survival-0.1.0a2/src/torch_survival/py.typed +0 -0
- torch_survival-0.1.0a2/src/torch_survival/sample.py +161 -0
- torch_survival-0.1.0a2/src/torch_survival/utils.py +87 -0
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
Metadata-Version: 2.3
|
|
2
|
+
Name: torch-survival
|
|
3
|
+
Version: 0.1.0a2
|
|
4
|
+
Summary: Production-ready deep survival analysis with one line of code
|
|
5
|
+
Requires-Dist: matplotlib>=3.10.8
|
|
6
|
+
Requires-Dist: optuna>=4.5.0
|
|
7
|
+
Requires-Dist: optuna-integration[sklearn]>=4.6.0
|
|
8
|
+
Requires-Dist: pycox>=0.3.0
|
|
9
|
+
Requires-Dist: rich>=14.2.0
|
|
10
|
+
Requires-Dist: scikit-survival>=0.24.1
|
|
11
|
+
Requires-Dist: torch>=2.8.0
|
|
12
|
+
Requires-Dist: typing-extensions>=4.15.0
|
|
13
|
+
Requires-Dist: furo>=2025.7.19 ; extra == 'docs'
|
|
14
|
+
Requires-Dist: sphinx>=8.2.3 ; extra == 'docs'
|
|
15
|
+
Requires-Dist: h5py>=3.14.0 ; extra == 'test'
|
|
16
|
+
Requires-Dist: torchsurv>=0.1.5 ; extra == 'test'
|
|
17
|
+
Requires-Python: >=3.12
|
|
18
|
+
Provides-Extra: docs
|
|
19
|
+
Provides-Extra: test
|
|
20
|
+
Description-Content-Type: text/markdown
|
|
21
|
+
|
|
File without changes
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "torch-survival"
|
|
3
|
+
version = "0.1.0-alpha.2"
|
|
4
|
+
description = "Production-ready deep survival analysis with one line of code"
|
|
5
|
+
readme = "README.md"
|
|
6
|
+
requires-python = ">=3.12"
|
|
7
|
+
dependencies = [
|
|
8
|
+
"matplotlib>=3.10.8",
|
|
9
|
+
"optuna>=4.5.0",
|
|
10
|
+
"optuna-integration[sklearn]>=4.6.0",
|
|
11
|
+
"pycox>=0.3.0",
|
|
12
|
+
"rich>=14.2.0",
|
|
13
|
+
"scikit-survival>=0.24.1",
|
|
14
|
+
"torch>=2.8.0",
|
|
15
|
+
"typing-extensions>=4.15.0",
|
|
16
|
+
]
|
|
17
|
+
|
|
18
|
+
[project.optional-dependencies]
|
|
19
|
+
docs = [
|
|
20
|
+
"furo>=2025.7.19",
|
|
21
|
+
"sphinx>=8.2.3",
|
|
22
|
+
]
|
|
23
|
+
test = [
|
|
24
|
+
"h5py>=3.14.0",
|
|
25
|
+
"torchsurv>=0.1.5",
|
|
26
|
+
]
|
|
27
|
+
|
|
28
|
+
[[tool.uv.index]]
|
|
29
|
+
name = "pytorch-cu128"
|
|
30
|
+
url = "https://download.pytorch.org/whl/cu128"
|
|
31
|
+
explicit = true
|
|
32
|
+
|
|
33
|
+
[tool.uv.sources]
|
|
34
|
+
torch = [
|
|
35
|
+
{ index = "pytorch-cu128", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
|
|
36
|
+
]
|
|
37
|
+
|
|
38
|
+
[build-system]
|
|
39
|
+
requires = ["uv_build>=0.8.11,<0.9.0"]
|
|
40
|
+
build-backend = "uv_build"
|
|
File without changes
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
from typing_extensions import TypedDict
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class LayerConfig(TypedDict, closed=True):
    """ Upper bounds on the hidden-layer layout of a neural network. """
    #: Maximum number of hidden layers
    max_layers: int
    #: Maximum number of nodes per hidden layer
    max_nodes_per_layer: int
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class NetworkConfig(TypedDict, closed=True):
    """ Description of a neural network's architecture, used as part of a hyperparameter search space. """
    #: Layers of the neural network, given either as bounds (``LayerConfig``) or as a list of ints
    #: (presumably explicit hidden-layer widths -- TODO confirm against the sampling code)
    layers: LayerConfig | list[int]
    #: Activation function used in hidden layers, given either as one name or as a list of names
    #: (presumably the candidates to choose from -- TODO confirm against the sampling code)
    activation: list[str] | str
    #: Probability of dropout for dropout layers, given either as one value or as a pair of floats
    #: (presumably a (low, high) range to sample from -- TODO confirm against the sampling code)
    dropout: tuple[float, float] | float
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class OptimizerConfig(TypedDict, closed=True):
    """ Description of the optimizer setup, used as part of a hyperparameter search space.

    NOTE(review): for every numeric field a pair of floats is presumably a (low, high) range to sample from and a
    single float a fixed value; for every string field a list is presumably a set of candidates. Confirm against the
    sampling code.
    """
    #: Optimizer used to train neural network (one name or a list of names)
    name: list[str] | str
    #: Learning rate used to train neural network
    lr: tuple[float, float] | float
    #: Learning rate momentum used to train neural network
    momentum: tuple[float, float] | float
    #: Scheduler applied to learning rate (one name or a list of names)
    scheduler: list[str] | str
    #: Decay used by learning rate scheduler
    decay: tuple[float, float] | float
|
|
@@ -0,0 +1,219 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def cox_neg_log_likelihood(risk: torch.Tensor, event: torch.Tensor, time: torch.Tensor,
                           sort: bool = True) -> torch.Tensor:
    """ Cox negative log partial likelihood.
    From DeepSurv: https://doi.org/10.1186/s12874-018-0482-1

    Parameters
    ----------
    risk: torch.Tensor, of shape (..., n_samples)
        Risks estimated by model

    event: torch.Tensor, of shape (..., n_samples)
        Event indicator denoting whether time is of observed event or dropout

    time: torch.Tensor, of shape (..., n_samples)
        Time of either observed event or dropout

    sort: bool
        Whether to sort by time, otherwise assumes pre-sorted ground truth times and events (descending in time).
        When ``False``, ``time`` is not read at all.

    Returns
    -------
    torch.Tensor, of shape (...,)
        Negative log partial likelihood averaged over the observed events. If there is no observed event along the
        last dimension the result is zero rather than NaN.
    """
    if sort:
        # Sort risk and events by descending time, so that every risk set becomes a prefix of the sequence
        order = torch.argsort(time, dim=-1, descending=True)
        risk = torch.gather(risk, dim=-1, index=order)
        event = torch.gather(event, dim=-1, index=order)

    # Due to the descending order, log_risk[i] = log(sum_{j <= i} exp(risk[j])), i.e., the log-sum over all samples
    # placed at or before position i. Tied times get no special treatment: their risk sets follow the sort order.
    log_risk = torch.logcumsumexp(risk, dim=-1)
    likelihood = (risk - log_risk) * event.float()

    # Without any observed event the numerator is exactly zero; clamping the count avoids the 0 / 0 = NaN that would
    # otherwise poison the loss (and every gradient) for such a batch.
    n_events = event.sum(dim=-1).clamp(min=1)
    return - likelihood.sum(dim=-1) / n_events
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def weibull_neg_log_likelihood(params: torch.Tensor, event: torch.Tensor, time: torch.Tensor,
                               activations: bool = False) -> torch.Tensor:
    """ Weibull negative log likelihood.
    From DeepWeiSurv: https://doi.org/10.1007/978-3-030-47426-3_53

    Parameters
    ----------
    params: torch.Tensor, of shape (..., n_samples, 2 | p * 3)
        Parameters of the mixture of p Weibull distributions, namely weights, shapes, and scales. With 2 or 3
        parameters a single distribution is assumed and only the last two entries (shape, scale) are used. Mixture
        weights are treated as unnormalized log-weights (they are normalized with a log-sum-exp).

    event: torch.Tensor, of shape (..., n_samples)
        Boolean event indicator denoting whether time is of observed event or dropout

    time: torch.Tensor, of shape (..., n_samples)
        Time of either observed event or dropout

    activations: bool
        Whether to apply ELU activations to the shapes and scales

    Returns
    -------
    torch.Tensor, of shape (...,)
        Negative log likelihood averaged over the samples

    Raises
    ------
    ValueError
        If the size of the last dimension of ``params`` is neither 2 nor a multiple of 3.
    """
    # Extract mixture alpha, shape, and scale from params
    n_params = params.shape[-1]
    if n_params == 2 or n_params == 3:
        is_mixture = False
        weight, shape, scale = None, params[..., -2], params[..., -1]
    elif n_params % 3 == 0:
        is_mixture = True
        n_dists = n_params // 3
        weight, shape, scale = params[..., 0:n_dists], params[..., n_dists:2 * n_dists], params[..., 2 * n_dists:]
    else:
        raise ValueError('Unexpected number of Weibull parameters: ' + str(params.shape[-1]))

    # Apply ELU activations if requested
    if activations:
        shape = torch.nn.functional.elu(shape) + 2
        scale = torch.nn.functional.elu(scale) + 1 + 1e-5

    # Samples with time=0 do not contribute to the density term, since log(0) = -inf. The mask is computed on the
    # un-broadcast time so that the sample axis is never squeezed away (which happened for n_samples == 1).
    nonzero_time = time != 0
    event_true = event & nonzero_time
    event_false = ~event

    # Replace zero times by one *before* taking the logarithm. Masking an infinite value afterwards keeps the loss
    # finite but still back-propagates 0 * inf = nan into the shape parameters.
    log_time = torch.log(torch.where(nonzero_time, time, torch.ones_like(time)))
    if is_mixture:
        # for proper broadcasting with mixture params
        time = time.unsqueeze(-1)
        log_time = log_time.unsqueeze(-1)

    a = torch.log(shape) - torch.log(scale) + (shape - 1) * (log_time - torch.log(scale))  # log hazard
    b = -torch.pow(time / scale, shape)  # log survival
    if not is_mixture:
        # Simplified equation for single distribution
        return -(torch.nansum(a * event_true.float(), dim=-1) + torch.sum(b, dim=-1)) / event.shape[-1]

    # Extended equation for multiple distributions
    b = b + weight
    return -(torch.nansum(torch.logsumexp(a + b, dim=-1) * event_true.float(), dim=-1)
             + torch.sum(torch.logsumexp(b, dim=-1) * event_false.float(), dim=-1)
             - torch.sum(torch.logsumexp(weight, dim=-1), dim=-1)) / event.shape[-1]
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def weibull_neg_log_likelihood_original(params: torch.Tensor, event: torch.Tensor, time: torch.Tensor) -> torch.Tensor:
    """ Weibull mixture negative log likelihood, following the reference implementation.
    From DeepWeiSurv: https://doi.org/10.1007/978-3-030-47426-3_53

    Unlike `weibull_neg_log_likelihood`, the computation is carried out on probabilities rather than in log-space
    and the mixture weights are used as given. Because the parameters are transposed with ``.t()``, ``params`` may
    have at most two dimensions.

    Parameters
    ----------
    params: torch.Tensor, of shape (n_samples, p * 3)
        Parameters of the mixture of p Weibull distributions, namely weights, shapes, and scales

    event: torch.Tensor, of shape (n_samples,)
        Boolean event indicator denoting whether time is of observed event or dropout

    time: torch.Tensor, of shape (n_samples,)
        Time of either observed event or dropout

    Returns
    -------
    torch.Tensor
        Scalar negative log likelihood averaged over the samples
    """
    n_params = params.shape[-1]
    if n_params % 3 != 0:
        raise ValueError('Unexpected number of Weibull parameters: ' + str(params.shape[-1]))

    # Split params into weights, shapes, and scales and move the mixture axis to the front
    n_dists = n_params // 3
    alphas = params[..., 0:n_dists].t()
    shape = params[..., n_dists:2 * n_dists].t()
    scale = params[..., 2 * n_dists:].t()

    # https://github.com/AchrafB2015/pydpwte/blob/133aeb1004adf6bc0fd1e6985b42fc5986a77e02/pydpwte/utils/loss.py#L37-L44
    ratio = time / scale
    survival = torch.exp(-torch.pow(ratio, shape))
    ratio_pow = torch.pow(ratio, shape - 1)
    coefficient = alphas * shape / scale

    # Mixture density for observed events, mixture survival for censored samples (both summed over components)
    log_density = torch.log(torch.sum(coefficient * ratio_pow * survival, 0))
    log_survival = torch.log(torch.sum(alphas * survival, 0))
    return -torch.mean(event * log_density + (~event) * log_survival)
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def weibull_survival_time(params: torch.Tensor, activations: bool = False):
    """ Expected survival time under a (mixture of) Weibull distribution(s).
    From DeepWeiSurv: https://doi.org/10.1007/978-3-030-47426-3_53

    Every component contributes its mean ``scale * Gamma(1 + 1 / shape)``, weighted by its mixture weight.

    Parameters
    ----------
    params: torch.Tensor, of shape (..., n_samples, 2 | p * 3)
        Parameters of the mixture of p Weibull distributions, namely weights, shapes, and scales. With 2 or 3
        parameters a single distribution with weight one is assumed and only the last two entries are used.

    activations: bool
        Whether to apply ELU activations to shapes and scales, and a softmax to the weights of a true mixture.
        Without it, the weights are used as given.

    Returns
    -------
    torch.Tensor, of shape (..., n_samples)
        Estimated survival times
    """
    width = params.shape[-1]
    if width == 2 or width == 3:
        # Single distribution; list indices keep the trailing axis for the final reduction
        is_mixture = False
        weight, shape, scale = 1, params[..., [-2]], params[..., [-1]]
    elif width % 3 == 0:
        k = width // 3
        is_mixture = k > 1
        weight = params[..., 0:k]
        shape = params[..., k:2 * k]
        scale = params[..., 2 * k:]
    else:
        raise ValueError('Unexpected number of Weibull parameters: ' + str(params.shape[-1]))

    if activations:
        if is_mixture:
            # Turn raw weights into mixture probabilities
            weight = torch.softmax(weight, dim=-1)
        shape = torch.nn.functional.elu(shape) + 2
        scale = torch.nn.functional.elu(scale) + 1 + 1e-5

    # Gamma(1 + 1 / shape), evaluated through the log-gamma function
    gamma_term = torch.exp(torch.lgamma(1 + 1 / shape))
    return torch.sum(weight * scale * gamma_term, dim=-1)
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def mse_with_pairwise_rank(estimate: torch.Tensor, event: torch.Tensor, time: torch.Tensor) -> torch.Tensor:
    """ Extended mean squared error and pairwise ranking loss.
    From RankDeepSurv: https://doi.org/10.1016/j.artmed.2019.06.001

    Parameters
    ----------
    estimate: torch.Tensor, of shape (..., n_samples)
        Survival times estimated by model

    event: torch.Tensor, of shape (..., n_samples)
        Boolean event indicator denoting whether time is of observed event or dropout

    time: torch.Tensor, of shape (..., n_samples)
        Time of either observed event or dropout

    Returns
    -------
    torch.Tensor, of shape (...,)
        Sum of the masked mean squared error and the pairwise ranking term
    """
    residual = time - estimate

    # Squared error part: censored samples only count if the estimate does not exceed the censoring time
    counted = event | (estimate <= time)
    mse_term = torch.mean(torch.square(residual) * counted.float(), dim=-1)

    # Pairwise part, entry [j, i] compares sample i (last axis) with sample j (second to last axis).
    # A pair is comparable if i is an observed event and j is either observed too or not earlier than i.
    comparable = event.unsqueeze(-2) & (event.unsqueeze(-1) | (time.unsqueeze(-2) <= time.unsqueeze(-1)))

    # (time_j - time_i) - (estimate_j - estimate_i) equals residual_j - residual_i, so the residuals are reused;
    # clamping at zero keeps only pairs whose estimated distance falls short of the true one.
    shortfall = torch.clamp(residual.unsqueeze(-1) - residual.unsqueeze(-2), min=0)
    rank_term = torch.sum(shortfall * comparable, dim=(-2, -1)) / event.shape[-1]

    return mse_term + rank_term
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
def discrete_with_pairwise_rank(estimate: torch.Tensor, event: torch.Tensor, time: torch.Tensor,
                                alpha: float, sigma: float) -> torch.Tensor:
    """ Discrete negative log likelihood and pairwise ranking loss.
    From DeepHit: https://doi.org/10.1609/aaai.v32i1.11842

    Parameters
    ----------
    estimate: torch.Tensor, of shape (..., n_samples, n_times)
        Discrete time probabilities estimated by model

    event: torch.Tensor, of shape (..., n_samples)
        Boolean event indicator denoting whether time is of observed event or dropout

    time: torch.Tensor, of shape (..., n_samples)
        Index (integer tensor) of the discrete time of either observed event or dropout

    alpha: float
        Weight for the ranking loss component

    sigma: float
        Bandwidth of the radial basis function in the ranking loss component

    Returns
    -------
    torch.Tensor, of shape (...,)
        Negative log likelihood plus ``alpha`` times the ranking loss, both summed (not averaged) over the samples
    """
    cum_incidence = torch.cumsum(estimate, dim=-1)

    # Probability and cumulative incidence of every sample at its own time
    own_time = time.unsqueeze(-1)
    estimate_t = torch.gather(estimate, dim=-1, index=own_time).squeeze(-1)
    cum_incidence_t = torch.gather(cum_incidence, dim=-1, index=own_time).squeeze(-1)

    eps = torch.exp(
        torch.tensor(-100, device=estimate.device))  # -100 is also used in PyTorch's binary cross entropy as a cut-off
    # Observed events contribute the probability at their time, censored samples the probability of surviving it
    loss1 = (- torch.sum(event * torch.log(torch.clamp(estimate_t, min=eps)), dim=-1)
             - torch.sum(~event * torch.log(torch.clamp(1 - cum_incidence_t, min=eps)), dim=-1))

    # Here, we add extra dimensions to enable pairwise comparisons of (i,j): i on the last, j on the second to last axis
    event_i = event.unsqueeze(-2)
    time_i, time_j = time.unsqueeze(-2), time.unsqueeze(-1)
    # The acceptability matrix specifies which (i, j) can be compared: i is observed and strictly earlier than j
    acc = event_i & (time_i < time_j)

    cum_incidence_ti = cum_incidence_t.unsqueeze(-2)
    # Entry [..., j, i] is the cumulative incidence of sample j at the time of sample i. Gathering along the time
    # axis (instead of indexing with ``[:, time]``) keeps this correct for any number of leading batch dimensions.
    pair_index = time_i.expand(*cum_incidence.shape[:-1], time.shape[-1])
    cum_incidence_tij = torch.gather(cum_incidence, dim=-1, index=pair_index)
    loss2 = torch.sum(torch.exp(-(cum_incidence_ti - cum_incidence_tij) / sigma) * acc, dim=(-2, -1))

    return loss1 + alpha * loss2
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
from typing import Literal
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def concordance_index(estimate: torch.Tensor, event: torch.Tensor, time: torch.Tensor,
                      mode: Literal['risk', 'time'] = 'risk') -> torch.Tensor:
    r""" Compute Harrell's concordance index or c-index.

    This essentially computes the ratio of correctly ordered pairs while accounting for censoring. Given estimates
    :math:`\eta`, event indicators :math:`\delta`, and times :math:`t` it is described by

    .. math::

        C = \frac{\sum_{i,j} \mathbb{1}_{\eta_i < \eta_j} \mathbb{1}_{t_i > t_j} \delta_j}
        {\sum_{i,j} \mathbb{1}_{t_i > t_j} \delta_j}.

    The formula above applies to risk estimates, which are assumed to have an inverse relationship with times, i.e.,
    if one increases the other decreases. If your model directly predicts survival time, rather than risk, pass
    ``mode='time'``, which flips the comparison of the estimates to :math:`\eta_i > \eta_j`. Pairs with tied
    estimates are never counted as concordant.

    Note
    ----
    While Harrell's concordance index is easy to interpret, it is known to be biased in the presence of higher amounts
    of censoring [1]_. An alternative is Uno's concordance index, as implemented in
    `sksurv.metrics.concordance_index_ipcw`.

    .. [1] H. Uno, T. Cai, M. J. Pencina, R. B. D’Agostino, and L. J. Wei, “On the C‐statistics for evaluating overall
        adequacy of risk prediction procedures with censored survival data,” Statistics in Medicine, vol. 30, no. 10, pp.
        1105–1117, Jan. 2011, doi: 10.1002/sim.4154. Available: http://dx.doi.org/10.1002/sim.4154


    Parameters
    ----------
    estimate: torch.Tensor, of shape (..., n_samples)
        Risks or times estimated by model

    event: torch.Tensor, of shape (..., n_samples)
        Boolean event indicator denoting whether time is of observed event or dropout

    time: torch.Tensor, of shape (..., n_samples)
        Time of either observed event or dropout

    mode: Literal['risk', 'time']
        Whether the passed estimates are risks or survival times

    Returns
    -------
    torch.Tensor, of shape (...,)
        Concordance index score; NaN if there is no comparable pair

    Raises
    ------
    ValueError
        If ``mode`` is neither ``'risk'`` nor ``'time'``.
    """
    # Entry [i, j] of every matrix below compares sample i (second to last axis) with sample j (last axis)
    if mode == 'risk':
        estimate_comp = estimate.unsqueeze(-1) < estimate.unsqueeze(-2)
    elif mode == 'time':
        estimate_comp = estimate.unsqueeze(-1) > estimate.unsqueeze(-2)
    else:
        # A misspelled mode used to be silently treated as 'time', inverting the score for risk estimates
        raise ValueError("Unexpected mode: " + repr(mode) + ", expected 'risk' or 'time'")
    time_comp = time.unsqueeze(-1) > time.unsqueeze(-2)
    # A pair is only comparable if the earlier sample j is an observed event
    event = event.unsqueeze(-2)
    return torch.sum(estimate_comp & time_comp & event, (-2, -1)) / torch.sum(time_comp & event, (-2, -1))
|
|
@@ -0,0 +1,192 @@
|
|
|
1
|
+
import functools
|
|
2
|
+
import warnings
|
|
3
|
+
from copy import deepcopy
|
|
4
|
+
from typing import TypedDict
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
import optuna
|
|
8
|
+
import torch
|
|
9
|
+
import torchtuples as tt
|
|
10
|
+
from optuna.samplers import TPESampler
|
|
11
|
+
from pycox.evaluation import EvalSurv
|
|
12
|
+
from pycox.models import DeepHitSingle
|
|
13
|
+
from pycox.preprocessing.label_transforms import LabTransDiscreteTime
|
|
14
|
+
from sklearn.base import BaseEstimator
|
|
15
|
+
from sklearn.model_selection import StratifiedKFold, train_test_split
|
|
16
|
+
from sklearn.utils.validation import check_is_fitted, validate_data
|
|
17
|
+
from sksurv.base import SurvivalAnalysisMixin
|
|
18
|
+
from sksurv.util import check_array_survival
|
|
19
|
+
from torch import nn
|
|
20
|
+
|
|
21
|
+
from torch_survival.progress import OptunaProgressCallback
|
|
22
|
+
from torch_survival.utils import merge_configs
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class DeepHitNetwork(nn.Module):
    """ Feed-forward network with a shared sub-network followed by a single cause-specific sub-network.

    The cause-specific sub-network receives the output of the shared sub-network concatenated with the raw input
    features and produces one output per discrete time. All layer widths are multiples of the number of inputs.
    """

    def __init__(self, n_inputs, n_times):
        super().__init__()
        shared_width = 3 * n_inputs
        hidden_width = 5 * n_inputs
        dropout = 0.6

        self.shared_network = nn.Sequential(
            nn.Linear(n_inputs, shared_width),
            nn.ReLU(),
            nn.Dropout(p=dropout),
        )

        # input consists of shared network output + features
        cause_layers = [
            nn.Linear(shared_width + n_inputs, hidden_width),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(hidden_width, shared_width),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(shared_width, n_times),
        ]
        self.cause_network = nn.Sequential(*cause_layers)

    def forward(self, x):
        shared = self.shared_network(x)
        # Residual-style connection: the raw features are appended to the shared representation
        return self.cause_network(torch.concat((shared, x), dim=-1))
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class DeepHitSearchSpace(TypedDict):
|
|
51
|
+
#: Weight for the ranking loss component
|
|
52
|
+
alpha: tuple[float, float] | float
|
|
53
|
+
#: Bandwidth of the radial basis function in the ranking loss component
|
|
54
|
+
sigma: tuple[float, float] | float
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class DeepHit(SurvivalAnalysisMixin, BaseEstimator):
    r""" Implements the DeepHit model presented by Lee et al. [1]_.

    Uses a deep neural network trained with the negative log likelihood of the survival time probability distribution,
    as estimated over discrete time intervals, combined with a ranking loss.

    Parameters
    ----------
    search_space: DeepHitSearchSpace, optional
        Entries overriding those of `default_search_space`.
    n_epochs: int
        Maximum number of training epochs per model (training may stop earlier due to early stopping).
    random_state: int, optional
        Seed for PyTorch, the hyperparameter sampler, and the training/validation split.
    device: str, optional
        Device to train on; defaults to ``'cuda'`` if available and ``'cpu'`` otherwise.

    .. [1] C. Lee, W. Zame, J. Yoon, and M. Van der Schaar, “DeepHit: A Deep Learning Approach to Survival Analysis
        With Competing Risks,” AAAI, vol. 32, no. 1, Apr. 2018, doi: 10.1609/aaai.v32i1.11842. Available:
        http://dx.doi.org/10.1609/aaai.v32i1.11842
    """

    #: Default hyperparameter search space
    default_search_space: DeepHitSearchSpace = {
        'n_times': (10, 500),
        'alpha': (0, 1),
        'sigma': (1e-3, 1e1),
    }

    def __init__(self, search_space: DeepHitSearchSpace | None = None, n_epochs=100, random_state=None, device=None):
        self.search_space = deepcopy(self.default_search_space)
        if search_space:
            self.search_space = merge_configs(self.search_space, search_space)
        self.n_epochs = n_epochs
        self.random_state = random_state
        self.device = device
        if self.device is None:
            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    def _optimize(self, trial: optuna.Trial, X, y_event, y_time):
        """ Optuna objective: negated mean time-dependent concordance over a 5-fold cross validation. """
        # Do 5-fold cross validation
        scores = []
        for train_idx, test_idx in StratifiedKFold(n_splits=5).split(X, y_event):
            model, _ = self._train(trial, X[train_idx], y_event[train_idx], y_time[train_idx])
            surv = model.predict_surv_df(X[test_idx])
            c_index = EvalSurv(surv, y_time[test_idx], y_event[test_idx], censor_surv='km').concordance_td('antolini')
            scores.append(c_index)
        # Negated, since the study minimizes its objective
        return - sum(scores) / len(scores)

    def _train(self, trial: optuna.Trial, X, y_event, y_time):
        """ Train one model with hyperparameters taken from ``trial``.

        Returns the trained model together with the discrete times (cuts) the event times were mapped to.
        """
        n_inputs = X.shape[-1]

        # Discretize event times
        n_times = self.search_space['n_times']
        if not isinstance(n_times, int):
            n_times = trial.suggest_int('n_times', *n_times)
        y_trans = LabTransDiscreteTime(n_times)
        y = y_trans.fit_transform(y_time, y_event)

        # Split into training and validation for early stopping
        X_train, X_val, y_time_train, y_time_val, y_event_train, y_event_val = \
            train_test_split(X, *y, test_size=0.2, stratify=y_event, random_state=self.random_state)
        y_train = (y_time_train, y_event_train)
        y_val = (y_time_val, y_event_val)

        # Sample loss parameters; any plain number (including ints such as alpha=1) is a fixed value, not a range
        alpha = self.search_space['alpha']
        if not isinstance(alpha, (int, float)):
            alpha = trial.suggest_float('alpha', *alpha)
        sigma = self.search_space['sigma']
        if not isinstance(sigma, (int, float)):
            sigma = trial.suggest_float('sigma', *sigma, log=True)

        # Train and return model
        net = DeepHitNetwork(n_inputs, y_trans.out_features)
        model = DeepHitSingle(net, tt.optim.Adam(lr=1e-4), duration_index=y_trans.cuts,
                              alpha=alpha, sigma=sigma, device=self.device)
        callbacks = [tt.callbacks.EarlyStopping(patience=10)]
        # Honor the configured number of epochs instead of a hard-coded value
        model.fit(X_train, y_train, batch_size=50, epochs=self.n_epochs, callbacks=callbacks,
                  val_data=(X_val, y_val), verbose=False)
        return model, y_trans.cuts

    def fit(self, X, y):
        """ Fit the model to the given survival data.

        Hyperparameters are tuned with Optuna (50 trials), after which a final model is trained on all data using
        the best trial.

        Parameters
        ----------
        X: array-like, shape = (n_samples, n_features)
            Data matrix.
        y: structured array, shape = (n_samples,)
            A structured array with two fields. The first field is a boolean where ``True`` indicates an event and
            ``False`` indicates right-censoring. The second field is a float with the time of event or time of
            censoring.

        Returns
        -------
        self
        """
        # Validate and extract data
        X, y = validate_data(self, X, y)
        X = X.astype(np.float32)
        y_event, y_time = check_array_survival(X, y)

        # Seed PyTorch random number generator; compare against None so that a seed of 0 is honored as well
        seeded = self.random_state is not None
        if seeded:
            torch.manual_seed(self.random_state)

        # Optimize hyperparameters
        n_trials = 50
        optuna.logging.disable_default_handler()
        warnings.filterwarnings('ignore', category=optuna.exceptions.ExperimentalWarning)
        with OptunaProgressCallback(model_name='DeepHit', n_trials=n_trials) as callback:
            study = optuna.create_study(sampler=TPESampler(seed=self.random_state) if seeded else None)
            objective = functools.partial(self._optimize, X=X, y_event=y_event, y_time=y_time)
            study.optimize(objective, n_trials=n_trials, callbacks=[callback])

        # Train model
        self.optuna_params_ = study.best_params
        self.model_, self.disc_times_ = self._train(study.best_trial, X, y_event, y_time)

        return self

    @torch.no_grad()
    def predict(self, X):
        """ Predict survival times.

        The survival time is estimated as the expectation of the discrete times under the probability mass function
        predicted by the neural network.

        Parameters
        ----------
        X: array-like, shape = (n_samples, n_features)
            Data matrix.

        Returns
        -------
        survival_time: array, shape = (n_samples,)
            Predicted survival times.
        """
        check_is_fitted(self)
        X = validate_data(self, X)
        X = X.astype(np.float32)
        pmf = self.model_.predict_pmf(X)
        # Probability-weighted sum over the discrete times
        times = pmf @ self.disc_times_
        return times

    def get_optuna_params(self):
        """ Return the best hyperparameters found by Optuna during `fit`. """
        return self.optuna_params_
|
|
@@ -0,0 +1,180 @@
|
|
|
1
|
+
import functools
|
|
2
|
+
import warnings
|
|
3
|
+
from copy import deepcopy
|
|
4
|
+
from typing import TypedDict
|
|
5
|
+
|
|
6
|
+
import optuna
|
|
7
|
+
import torch
|
|
8
|
+
from optuna.samplers import TPESampler
|
|
9
|
+
from sklearn.base import BaseEstimator
|
|
10
|
+
from sklearn.model_selection import StratifiedKFold
|
|
11
|
+
from sklearn.utils.validation import check_is_fitted, validate_data
|
|
12
|
+
from sksurv.base import SurvivalAnalysisMixin
|
|
13
|
+
from sksurv.util import check_array_survival
|
|
14
|
+
|
|
15
|
+
from torch_survival.config import NetworkConfig, OptimizerConfig
|
|
16
|
+
from torch_survival.losses import cox_neg_log_likelihood
|
|
17
|
+
from torch_survival.metrics import concordance_index
|
|
18
|
+
from torch_survival.progress import OptunaProgressCallback
|
|
19
|
+
from torch_survival.sample import sample_network, sample_optimizer
|
|
20
|
+
from torch_survival.utils import merge_configs
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class DeepSurvSearchSpace(TypedDict):
    """ Hyperparameter search space of the DeepSurv model. """
    #: Neural network configuration (layers, activation, and dropout)
    network: NetworkConfig
    #: Optimizer configuration (optimizer name, learning rate, momentum, scheduler, and decay)
    optimizer: OptimizerConfig
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class DeepSurv(SurvivalAnalysisMixin, BaseEstimator):
|
|
31
|
+
r""" Implements the DeepSurv model presented by Katzman et al. [1]_.
|
|
32
|
+
|
|
33
|
+
Uses a deep neural network trained with the Cox negative log partial likelihood to estimate the risk of each
|
|
34
|
+
individual. The network's configuration is tuned using the Sobol solver. This implementation tries to stay faithful
|
|
35
|
+
to the original paper, with the following deviations:
|
|
36
|
+
|
|
37
|
+
* Optuna's default TPE sampler is used in favor of the Sobol sampler with 5-fold internal cross-validation instead
|
|
38
|
+
of 3-fold internal cross-validation.
|
|
39
|
+
* The hyperparameter search space is not detailed in the original paper, and the reference implementation is
|
|
40
|
+
underspecified. We thus define our own shared search space in `DeepSurvSearchSpace`.
|
|
41
|
+
* Our implementation does not support or tune :math:`\ell_2` regularization. We found this to be detrimental to
|
|
42
|
+
performance and were unable to fully replicate the described weight regularization.
|
|
43
|
+
|
|
44
|
+
.. [1] J. L. Katzman, U. Shaham, A. Cloninger, J. Bates, T. Jiang, and Y. Kluger, “DeepSurv: personalized treatment
|
|
45
|
+
recommender system using a Cox proportional hazards deep neural network,” BMC Med Res Methodol, vol. 18, no. 1,
|
|
46
|
+
Feb. 2018, doi: 10.1186/s12874-018-0482-1. Available: http://dx.doi.org/10.1186/s12874-018-0482-1
|
|
47
|
+
"""
|
|
48
|
+
|
|
49
|
+
#: Default hyperparameter search space
|
|
50
|
+
default_search_space: DeepSurvSearchSpace = {
|
|
51
|
+
'network': {
|
|
52
|
+
'layers': {
|
|
53
|
+
'max_layers': 4,
|
|
54
|
+
'max_nodes_per_layer': 50,
|
|
55
|
+
},
|
|
56
|
+
'activation': ['relu', 'selu'],
|
|
57
|
+
'dropout': (0.0, 0.5),
|
|
58
|
+
},
|
|
59
|
+
'optimizer': {
|
|
60
|
+
'name': ['sgd', 'adam'],
|
|
61
|
+
'lr': (1e-7, 1e-3),
|
|
62
|
+
'scheduler': 'inverse_time',
|
|
63
|
+
'decay': (0.0, 0.001),
|
|
64
|
+
'momentum': (0.8, 0.95),
|
|
65
|
+
},
|
|
66
|
+
}
|
|
67
|
+
|
|
68
|
+
def __init__(self, search_space: DeepSurvSearchSpace | None = None, n_epochs=500, random_state=None, device=None):
|
|
69
|
+
self.search_space = deepcopy(self.default_search_space)
|
|
70
|
+
if search_space:
|
|
71
|
+
self.search_space = merge_configs(self.search_space, search_space)
|
|
72
|
+
self.n_epochs = n_epochs
|
|
73
|
+
self.random_state = random_state
|
|
74
|
+
self.device = device
|
|
75
|
+
if self.device is None:
|
|
76
|
+
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
|
77
|
+
|
|
78
|
+
def _optimize(self, trial: optuna.Trial, X, y_event, y_time):
|
|
79
|
+
# Do 5-fold cross validation
|
|
80
|
+
scores = []
|
|
81
|
+
for train_idx, test_idx in StratifiedKFold(n_splits=5).split(X, y_event.cpu()):
|
|
82
|
+
model = self._train(trial, X[train_idx], y_event[train_idx], y_time[train_idx])
|
|
83
|
+
model.eval()
|
|
84
|
+
risks = model(X[test_idx]).squeeze(dim=-1)
|
|
85
|
+
c_index = concordance_index(risks, y_event[test_idx], y_time[test_idx])
|
|
86
|
+
scores.append(c_index)
|
|
87
|
+
return - sum(scores) / len(scores)
|
|
88
|
+
|
|
89
|
+
def _train(self, trial: optuna.Trial, X, y_event, y_time):
    """ Build a network from the trial's hyperparameters and train it on the given data.

    Every epoch is a single full-batch step on the Cox negative log-likelihood.

    Parameters
    ----------
    trial: optuna.Trial
        Trial used to sample the network, optimizer, and scheduler.
    X: torch.Tensor
        Feature matrix; its last dimension determines the number of network inputs.
    y_event: torch.Tensor
        Event indicators.
    y_time: torch.Tensor
        Event or censoring times.

    Returns
    -------
    nn.Module
        The trained network (still in training mode).
    """
    # Sample the network first, then the optimizer and scheduler that act on its parameters
    network = sample_network(trial, self.search_space['network'], X.shape[-1], 1)
    optimizer, scheduler = sample_optimizer(trial, self.search_space['optimizer'], network)

    # Order the samples by descending time once, up front; the loss is then called with sort=False
    order = torch.argsort(y_time, descending=True)
    X, y_event, y_time = X[order], y_event[order], y_time[order]

    network.to(self.device)
    for _ in range(self.n_epochs):
        optimizer.zero_grad()
        predicted_risk = network(X).squeeze(-1)
        cox_neg_log_likelihood(predicted_risk, y_event, y_time, sort=False).backward()
        optimizer.step()
        if scheduler:
            scheduler.step()
    return network
|
|
112
|
+
|
|
113
|
+
def fit(self, X, y):
    """ Fit the model to the given survival data.

    Hyperparameters are tuned with Optuna (50 trials), after which a final model is trained on all data using
    the best trial.

    Parameters
    ----------
    X: array-like, shape = (n_samples, n_features)
        Data matrix.
    y: structured array, shape = (n_samples,)
        A structured array with two fields. The first field is a boolean where ``True`` indicates an event and
        ``False`` indicates right-censoring. The second field is a float with the time of event or time of
        censoring.

    Returns
    -------
    self
    """
    # Validate and extract data
    X, y = validate_data(self, X, y)
    y_event, y_time = check_array_survival(X, y)

    # Convert data to tensors
    X = torch.as_tensor(X, dtype=torch.float32, device=self.device)
    y_event = torch.as_tensor(y_event, dtype=torch.bool, device=self.device)
    y_time = torch.as_tensor(y_time.copy(), dtype=torch.float32, device=self.device)

    # Seed PyTorch random number generator. Compare against None: a seed of 0 is valid but falsy.
    seeded = self.random_state is not None
    if seeded:
        torch.manual_seed(self.random_state)

    # Optimize hyperparameters
    n_trials = 50
    optuna.logging.disable_default_handler()
    warnings.filterwarnings('ignore', category=optuna.exceptions.ExperimentalWarning)
    with OptunaProgressCallback(model_name='DeepSurv', n_trials=n_trials) as callback:
        sampler = TPESampler(seed=self.random_state) if seeded else None
        study = optuna.create_study(sampler=sampler)
        objective = functools.partial(self._optimize, X=X, y_event=y_event, y_time=y_time)
        study.optimize(objective, n_trials=n_trials, callbacks=[callback])

    # Train final model with best hyperparameters
    self.optuna_params_ = study.best_params
    self.model_ = self._train(study.best_trial, X, y_event, y_time)
    self.model_.eval()

    return self
|
|
156
|
+
|
|
157
|
+
@torch.no_grad()
def predict(self, X):
    """ Predict risk scores.

    The risk score is predicted directly by a neural network. A higher score indicates a higher risk of experiencing
    the event.

    Parameters
    ----------
    X: array-like, shape = (n_samples, n_features)
        Data matrix.

    Returns
    -------
    risk_score: array, shape = (n_samples,)
        Predicted risk scores.
    """
    check_is_fitted(self)
    # reset=False checks X against the features seen during fit instead of overwriting that record
    X = validate_data(self, X, reset=False)
    X = torch.as_tensor(X, dtype=torch.float32, device=self.device)
    return self.model_(X).detach().squeeze(dim=-1).cpu().numpy()
|
|
178
|
+
|
|
179
|
+
def get_optuna_params(self):
    """ Return the hyperparameters stored in ``optuna_params_``.

    Returns
    -------
    The object stored in ``optuna_params_``. Accessing it before the attribute has been set raises
    ``AttributeError``.
    """
    best_params = self.optuna_params_
    return best_params
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
import optuna
|
|
2
|
+
from rich.progress import Progress, TextColumn, BarColumn, MofNCompleteColumn, TimeRemainingColumn
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class OptunaProgressCallback:
    """ Optuna callback and context manager that shows a progress bar for a study.

    Use it in a ``with`` statement to start and stop the bar, and pass the instance in ``callbacks`` of
    ``study.optimize`` so that the bar advances once per finished trial.

    Parameters
    ----------
    model_name: str
        Name shown in the task description.
    n_trials: int
        Total number of trials, used as the length of the bar.
    verbose: int
        With a value of 0 nothing is displayed; any positive value shows the progress bar.
    """

    def __init__(self, model_name, n_trials, verbose=1):
        self.verbose = verbose
        # Always define these so that the context manager also works when nothing is displayed
        self.progress = None
        self.tune_task = None
        if self.verbose > 0:
            # At verbosity = 1 we show a progress bar for trials
            self.progress = Progress(
                TextColumn('{task.description}'),
                BarColumn(),
                TextColumn('{task.fields[score]}'),
                MofNCompleteColumn(),
                TimeRemainingColumn(),
            )
            self.tune_task = self.progress.add_task('[green]Optimizing ' + model_name, total=n_trials, score='-.--')

    def __call__(self, study: optuna.study.Study, trial: optuna.trial.FrozenTrial) -> None:
        """ Advance the bar by one trial and display the absolute value of the best objective so far. """
        if self.progress is None:
            return
        try:
            score = '{:.2f}'.format(abs(study.best_value))
        except ValueError:
            # No trial has completed yet (e.g. the first trials failed), so there is no best value to show
            score = '-.--'
        self.progress.update(self.tune_task, advance=1, score=score)

    def __enter__(self):
        if self.progress is not None:
            self.progress.start()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if self.progress is not None:
            self.progress.stop()
|
|
File without changes
|
|
@@ -0,0 +1,161 @@
|
|
|
1
|
+
from typing import Mapping, Any
|
|
2
|
+
|
|
3
|
+
import optuna
|
|
4
|
+
import torch.optim as optim
|
|
5
|
+
import torch.optim.lr_scheduler as sched
|
|
6
|
+
from torch import nn
|
|
7
|
+
|
|
8
|
+
from torch_survival.config import NetworkConfig, OptimizerConfig
|
|
9
|
+
from torch_survival.utils import make_activation
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class SimpleNeuralNetwork(nn.Module):
    """ Fully connected network made of (Linear, activation, Dropout) blocks followed by a linear output layer.

    Parameters
    ----------
    n_inputs: int
        Number of input features.
    n_outputs: int
        Number of output units.
    layers: sequence of int
        Width of every hidden layer; an empty sequence yields a purely linear model.
    activation: str
        Name of the activation, resolved through ``make_activation``.
    dropout: float
        Dropout probability applied after every activation.
    """

    def __init__(self, n_inputs, n_outputs, layers, activation, dropout):
        super().__init__()
        # Consecutive pairs of widths give the (in, out) sizes of each hidden linear layer
        widths = [n_inputs, *layers]
        blocks = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            blocks.extend((
                nn.Linear(fan_in, fan_out),
                make_activation(activation),
                nn.Dropout(p=dropout),
            ))
        self.hidden = nn.Sequential(*blocks)
        self.output = nn.Linear(widths[-1], n_outputs)

    def forward(self, x):
        """ Apply the hidden blocks and then the output layer. """
        return self.output(self.hidden(x))
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def sample_network(trial: optuna.Trial, config: NetworkConfig, n_inputs: int, n_outputs: int):
    """
    Sample a simple neural network using the provided configuration.

    Every entry of the configuration is either a fixed value, which is used as-is, or a description of the
    search space from which the trial draws a value.

    Parameters
    ----------
    trial: optuna.Trial
        Active or fixed trial used to sample hyperparameters. Final configuration may be obtained by `study.best_trial`.
    config: config.NetworkConfig
        Network configuration specifying architecture and search constraints.
    n_inputs: int
        Number of inputs of the neural network. Typically depends on the data.
    n_outputs: int
        Number of outputs of the neural network. Typically depends on the modeling technique.

    Returns
    -------
    nn.Module
        PyTorch model consisting of linear layers.
    """
    # Hidden layers: a list fixes the widths, otherwise depth and widths are sampled
    layers = config['layers']
    if not isinstance(layers, list):
        n_layers = trial.suggest_int('layers', 0, layers['max_layers'])
        max_nodes = layers['max_nodes_per_layer']
        layers = [trial.suggest_int('nodes_' + str(i), 1, max_nodes) for i in range(1, n_layers + 1)]

    # Activation: a string fixes the choice, otherwise one of the candidates is sampled
    activation = config['activation']
    if not isinstance(activation, str):
        activation = trial.suggest_categorical('activation', activation)

    # Dropout: any plain number (including an int such as 0) is a fixed value, otherwise (low, high) bounds
    dropout = config['dropout']
    if not isinstance(dropout, (int, float)):
        dropout = trial.suggest_float('dropout', low=dropout[0], high=dropout[1])

    return SimpleNeuralNetwork(n_inputs, n_outputs, layers, activation, dropout)
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def sample_optimizer(trial: optuna.Trial, config: OptimizerConfig, model: nn.Module):
    """
    Sample optimizer and scheduler for training using the provided configuration.

    Every entry of the configuration is either a fixed value, which is used as-is, or a description of the
    search space from which the trial draws a value. The momentum value is used as ``momentum`` for SGD and as
    the first beta coefficient for Adam.

    Parameters
    ----------
    trial: optuna.Trial
        Active or fixed trial used to sample hyperparameters. Final configuration may be obtained by `study.best_trial`.
    config: config.OptimizerConfig
        Optimizer configuration specifying how the neural network should be trained.
    model: nn.Module
        Neural network whose parameters should be optimized.

    Returns
    -------
    torch.optim.Optimizer
        PyTorch optimizer that can be used for training.
    torch.optim.lr_scheduler.LRScheduler or None
        PyTorch scheduler that can be used to update learning rates. ``None`` when the scheduler name is not
        ``'inverse_time'``.

    Raises
    ------
    ValueError
        If the (fixed or sampled) optimizer name is neither ``'sgd'`` nor ``'adam'``.
    """
    # Sample optimizer parameters; plain numbers (ints included) are treated as fixed values
    lr = config['lr']
    if not isinstance(lr, (int, float)):
        lr = trial.suggest_float('lr', *lr, log=True)
    momentum = config['momentum']
    if not isinstance(momentum, (int, float)):
        momentum = trial.suggest_float('momentum', *momentum)

    # Initialize optimizer
    optimizer_name = config['name']
    if not isinstance(optimizer_name, str):
        optimizer_name = trial.suggest_categorical('optimizer', optimizer_name)
    if optimizer_name == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
    elif optimizer_name == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=lr, betas=(momentum, 0.999))
    else:
        # Report the name that was actually selected, not the raw (possibly list-valued) config entry
        raise ValueError('Optimizer with name `{}` is not supported'.format(optimizer_name))

    # Sample scheduler parameters
    decay = config['decay']
    if not isinstance(decay, (int, float)):
        decay = trial.suggest_float('decay', *decay)

    # Initialize scheduler; unrecognized names leave the scheduler disabled
    scheduler = None
    scheduler_name = config['scheduler']
    if not isinstance(scheduler_name, str):
        scheduler_name = trial.suggest_categorical('scheduler', scheduler_name)
    if scheduler_name == 'inverse_time':
        scheduler = sched.LambdaLR(optimizer, lr_lambda=lambda epoch: 1 / (1 + epoch * decay))

    # Return both
    return optimizer, scheduler
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def sample_int(trial: optuna.Trial, config: Mapping[str, Any], key: str) -> int:
    """
    Resolve an integer hyperparameter from a search space configuration.

    Parameters
    ----------
    trial: optuna.Trial
        Active or fixed trial used to sample hyperparameters. Final configuration may be obtained by `study.best_trial`.
    config: Mapping[str, Any]
        Search space configuration dictionary.
    key: str
        Name of the configuration entry to resolve; also used as the name of the sampled parameter.

    Returns
    -------
    int:
        The configured value itself when it is an integer, otherwise a value sampled by the trial with the
        configured entry unpacked as the arguments following the name.
    """
    fixed_or_bounds = config[key]
    if isinstance(fixed_or_bounds, int):
        return fixed_or_bounds
    return trial.suggest_int(key, *fixed_or_bounds)
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def sample_float(trial: optuna.Trial, config: Mapping[str, Any], key: str) -> float:
    """
    Helper to sample a floating point value from a search space configuration.

    Parameters
    ----------
    trial: optuna.Trial
        Active or fixed trial used to sample hyperparameters. Final configuration may be obtained by `study.best_trial`.
    config: Mapping[str, Any]
        Search space configuration dictionary.
    key: str
        Name of configuration parameter that should be sampled.

    Returns
    -------
    float:
        The configured value when it is a plain number, otherwise a value sampled by the trial with the
        configured entry unpacked as the arguments following the name.
    """
    value = config[key]
    # A fixed value may be written as an int (e.g. 0 or 1); treat it as a constant rather than as bounds
    if isinstance(value, (int, float)):
        return float(value)
    return trial.suggest_float(key, *value)
|
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
import inspect
|
|
2
|
+
from typing import TypeVar
|
|
3
|
+
|
|
4
|
+
import torch.nn as nn
|
|
5
|
+
|
|
6
|
+
_TypedDict = TypeVar('_TypedDict')
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def make_activation(query: str):
    """
    Resolves the name of an activation function to its PyTorch layer. This method uses inspection and should thus be
    able to resolve any current PyTorch activation function, e.g., `relu` gets mapped to torch.nn.ReLU, and so forth.

    The lookup is case-insensitive and only considers classes in ``torch.nn`` that derive from ``torch.nn.Module``,
    so submodules and helper functions of ``torch.nn`` are never returned.

    Parameters
    ----------
    query: str
        Name of activation function that should be created

    Returns
    -------
    torch.nn.Module
        PyTorch layer corresponding to query, initialized with default values

    Raises
    ------
    ValueError
        If no layer, or more than one layer, matches the query.
    """
    query = query.lower()
    matches = [
        obj
        for name, obj in inspect.getmembers(nn, inspect.isclass)
        if name.lower() == query and issubclass(obj, nn.Module)
    ]
    if not matches:
        raise ValueError('Found no candidate for `{}` activation'.format(query))
    if len(matches) > 1:
        raise ValueError('Found multiple candidates for `{}` activation'.format(query))
    return matches[0]()
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def merge_configs(default_config: _TypedDict, user_config: _TypedDict) -> _TypedDict:
    """
    Recursively overlays the user configuration onto the default configuration.

    The default configuration is updated in place and returned. Keys of the user configuration that do not exist
    in the default configuration are ignored. When both sides hold a dictionary for a key, the dictionaries are
    merged recursively; otherwise the user value replaces the default.

    Parameters
    ----------
    default_config: subclass of TypedDict
        Default model configuration (modified in place)
    user_config: subclass of TypedDict
        User-provided model configuration overriding defaults

    Returns
    -------
    subclass of TypedDict
        Merged model configuration; the same object as ``default_config``
    """
    for key, override in user_config.items():
        if key not in default_config:
            continue
        current = default_config[key]
        both_nested = isinstance(current, dict) and isinstance(override, dict)
        default_config[key] = merge_configs(current, override) if both_nested else override
    return default_config
|
|
56
|
+
|
|
57
|
+
# def sample_params(search_space, trial: optuna.Trial):
|
|
58
|
+
# """
|
|
59
|
+
# Samples a hyperparameter search space, allowing for fixed values
|
|
60
|
+
#
|
|
61
|
+
# Parameters
|
|
62
|
+
# ----------
|
|
63
|
+
# search_space: Mapping[str, ParamConfig]
|
|
64
|
+
# Dictionary of parameter configurations.
|
|
65
|
+
# trial: optuna.Trial
|
|
66
|
+
# Active trial for sampling parameters.
|
|
67
|
+
#
|
|
68
|
+
# Returns
|
|
69
|
+
# -------
|
|
70
|
+
# params: dict[str, int | float | str]
|
|
71
|
+
# Sampled parameters.
|
|
72
|
+
# """
|
|
73
|
+
# params = {}
|
|
74
|
+
# for key, value in search_space.items():
|
|
75
|
+
# if type(value) in [int, float, str]:
|
|
76
|
+
# # The value is already fixed and will be passed on as-is
|
|
77
|
+
# params[key] = value
|
|
78
|
+
# elif type(value) is list:
|
|
79
|
+
# # The value is a list of possible categories
|
|
80
|
+
# params[key] = trial.suggest_categorical(key, value)
|
|
81
|
+
# elif type(value) is tuple and type(value[0]) is int:
|
|
82
|
+
# # The values are lower and upper bounds for an integer
|
|
83
|
+
# params[key] = trial.suggest_int(key, *value[:2], log=value[2] if len(value) > 2 else False)
|
|
84
|
+
# elif type(value) is tuple and type(value[0]) is float:
|
|
85
|
+
# # The values are lower and upper bounds for a float
|
|
86
|
+
# params[key] = trial.suggest_float(key, *value[:2], log=value[2] if len(value) > 2 else False)
|
|
87
|
+
# return params
|