konfai-1.1.8-py3-none-any.whl → konfai-1.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of konfai has been flagged as possibly problematic.
- konfai/__init__.py +59 -14
- konfai/data/augmentation.py +457 -286
- konfai/data/data_manager.py +533 -316
- konfai/data/patching.py +300 -183
- konfai/data/transform.py +408 -275
- konfai/evaluator.py +325 -68
- konfai/main.py +71 -22
- konfai/metric/measure.py +360 -244
- konfai/metric/schedulers.py +24 -13
- konfai/models/classification/convNeXt.py +187 -81
- konfai/models/classification/resnet.py +272 -58
- konfai/models/generation/cStyleGan.py +233 -59
- konfai/models/generation/ddpm.py +348 -121
- konfai/models/generation/diffusionGan.py +757 -358
- konfai/models/generation/gan.py +177 -53
- konfai/models/generation/vae.py +140 -40
- konfai/models/registration/registration.py +135 -52
- konfai/models/representation/representation.py +57 -23
- konfai/models/segmentation/NestedUNet.py +339 -68
- konfai/models/segmentation/UNet.py +140 -30
- konfai/network/blocks.py +331 -187
- konfai/network/network.py +795 -427
- konfai/predictor.py +644 -238
- konfai/trainer.py +509 -222
- konfai/utils/ITK.py +191 -106
- konfai/utils/config.py +152 -95
- konfai/utils/dataset.py +326 -455
- konfai/utils/utils.py +497 -249
- {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/METADATA +1 -3
- konfai-1.2.0.dist-info/RECORD +38 -0
- konfai/utils/registration.py +0 -199
- konfai-1.1.8.dist-info/RECORD +0 -39
- {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/WHEEL +0 -0
- {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/entry_points.txt +0 -0
- {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/licenses/LICENSE +0 -0
- {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/top_level.txt +0 -0
konfai/models/segmentation/NestedUNet.py (removed lines whose content the diff source did not render are collapsed into bracketed markers):

```diff
@@ -1,114 +1,385 @@
-from konfai.network import network, blocks
-from typing import Union
 import torch
+
 from konfai.data.patching import ModelPatch
+from konfai.network import blocks, network
+
 
 class NestedUNet(network.Network):
 
     class NestedUNetBlock(network.ModuleArgsDict):
 
-        def __init__(
+        def __init__(
+            self,
+            channels: list[int],
+            nb_conv_per_stage: int,
+            block_config: blocks.BlockConfig,
+            downsample_mode: blocks.DownsampleMode,
+            upsample_mode: blocks.UpsampleMode,
+            attention: bool,
+            block: type,
+            dim: int,
+            i: int = 0,
+        ) -> None:
             super().__init__()
             if i > 0:
-                self.add_module(
-[old lines 14-21 not rendered in the diff source]
+                self.add_module(
+                    downsample_mode.name,
+                    blocks.downsample(
+                        in_channels=channels[0],
+                        out_channels=channels[1],
+                        downsample_mode=downsample_mode,
+                        dim=dim,
+                    ),
+                )
 
+            self.add_module(
+                f"X_{i}_{0}",
+                block(
+                    in_channels=channels[(1 if downsample_mode == blocks.DownsampleMode.CONV_STRIDE and i > 0 else 0)],
+                    out_channels=channels[1],
+                    block_configs=[block_config] * nb_conv_per_stage,
+                    dim=dim,
+                ),
+                out_branch=[f"X_{i}_{0}"],
+            )
+            if len(channels) > 2:
+                self.add_module(
+                    f"UNetBlock_{i + 1}",
+                    NestedUNet.NestedUNetBlock(
+                        channels[1:],
+                        nb_conv_per_stage,
+                        block_config,
+                        downsample_mode,
+                        upsample_mode,
+                        attention,
+                        block,
+                        dim,
+                        i + 1,
+                    ),
+                    in_branch=[f"X_{i}_{0}"],
+                    out_branch=[f"X_{i + 1}_{j}" for j in range(len(channels) - 2)],
+                )
+                for j in range(len(channels) - 2):
+                    self.add_module(
+                        f"X_{i}_{j + 1}_{upsample_mode.name}",
+                        blocks.upsample(
+                            in_channels=channels[2],
+                            out_channels=channels[1],
+                            upsample_mode=upsample_mode,
+                            dim=dim,
+                        ),
+                        in_branch=[f"X_{i + 1}_{j}"],
+                        out_branch=[f"X_{i + 1}_{j}"],
+                    )
+                    self.add_module(
+                        f"SkipConnection_{i}_{j + 1}",
+                        blocks.Concat(),
+                        in_branch=[f"X_{i + 1}_{j}"] + [f"X_{i}_{r}" for r in range(j + 1)],
+                        out_branch=[f"X_{i}_{j + 1}"],
+                    )
+                    self.add_module(
+                        f"X_{i}_{j + 1}",
+                        block(
+                            in_channels=(
+                                (channels[1] * (j + 1) + channels[2])
+                                if upsample_mode != blocks.UpsampleMode.CONV_TRANSPOSE
+                                else channels[1] * (j + 2)
+                            ),
+                            out_channels=channels[1],
+                            block_configs=[block_config] * nb_conv_per_stage,
+                            dim=dim,
+                        ),
+                        in_branch=[f"X_{i}_{j + 1}"],
+                        out_branch=[f"X_{i}_{j + 1}"],
+                    )
 
     class NestedUNetHead(network.ModuleArgsDict):
 
         def __init__(self, in_channels: int, nb_class: int, activation: str, dim: int) -> None:
             super().__init__()
-            self.add_module(
+            self.add_module(
+                "Conv",
+                blocks.get_torch_module("Conv", dim)(
+                    in_channels=in_channels,
+                    out_channels=nb_class,
+                    kernel_size=1,
+                    stride=1,
+                    padding=0,
+                ),
+            )
             if activation == "Softmax":
                 self.add_module("Softmax", torch.nn.Softmax(dim=1))
                 self.add_module("Argmax", blocks.ArgMax(dim=1))
             elif activation == "Tanh":
                 self.add_module("Tanh", torch.nn.Tanh())
-
-    def __init__(
-[old lines 36-54 not rendered in the diff source]
+
+    def __init__(
+        self,
+        optimizer: network.OptimizerLoader = network.OptimizerLoader(),
+        schedulers: dict[str, network.LRSchedulersLoader] = {
+            "default:ReduceLROnPlateau": network.LRSchedulersLoader(0)
+        },
+        outputs_criterions: dict[str, network.TargetCriterionsLoader] = {"default": network.TargetCriterionsLoader()},
+        patch: ModelPatch | None = None,
+        dim: int = 3,
+        channels: list[int] = [1, 64, 128, 256, 512, 1024],
+        nb_class: int = 2,
+        block_config: blocks.BlockConfig = blocks.BlockConfig(),
+        nb_conv_per_stage: int = 2,
+        downsample_mode: str = "MAXPOOL",
+        upsample_mode: str = "CONV_TRANSPOSE",
+        attention: bool = False,
+        block_type: str = "Conv",
+        activation: str = "Softmax",
+    ) -> None:
+        super().__init__(
+            in_channels=channels[0],
+            optimizer=optimizer,
+            schedulers=schedulers,
+            outputs_criterions=outputs_criterions,
+            patch=patch,
+            dim=dim,
+        )
+
+        self.add_module(
+            "UNetBlock_0",
+            NestedUNet.NestedUNetBlock(
+                channels,
+                nb_conv_per_stage,
+                block_config,
+                downsample_mode=blocks.DownsampleMode[downsample_mode],
+                upsample_mode=blocks.UpsampleMode[upsample_mode],
+                attention=attention,
+                block=blocks.ConvBlock if block_type == "Conv" else blocks.ResBlock,
+                dim=dim,
+            ),
+            out_branch=[f"X_0_{j + 1}" for j in range(len(channels) - 2)],
+        )
+        for j in range(len(channels) - 2):
+            self.add_module(
+                f"Head_{j}",
+                NestedUNet.NestedUNetHead(
+                    in_channels=channels[1],
+                    nb_class=nb_class,
+                    activation=activation,
+                    dim=dim,
+                ),
+                in_branch=[f"X_0_{j + 1}"],
+                out_branch=[-1],
+            )
 
 
 class UNetpp(network.Network):
 
     class ResNetEncoderLayer(network.ModuleArgsDict):
 
-        def __init__(
+        def __init__(
+            self,
+            in_channel: int,
+            out_channel: int,
+            nb_block: int,
+            dim: int,
+            downsample_mode: blocks.DownsampleMode,
+        ):
             super().__init__()
             for i in range(nb_block):
-                if
-                self.add_module(
-[old line 66 not rendered in the diff source]
+                if downsample_mode == blocks.DownsampleMode.MAXPOOL and i == 0:
+                    self.add_module(
+                        "DownSample",
+                        blocks.get_torch_module("MaxPool", dim)(
+                            kernel_size=3,
+                            stride=2,
+                            padding=1,
+                            dilation=1,
+                            ceil_mode=False,
+                        ),
+                    )
+                self.add_module(
+                    f"ResBlock_{i}",
+                    blocks.ResBlock(
+                        in_channel,
+                        out_channel,
+                        [
+                            blocks.BlockConfig(
+                                3,
+                                (2 if downsample_mode == blocks.DownsampleMode.CONV_STRIDE and i == 0 else 1),
+                                1,
+                                False,
+                                "ReLU;True",
+                                blocks.NormMode.BATCH,
+                            ),
+                            blocks.BlockConfig(3, 1, 1, False, None, blocks.NormMode.BATCH),
+                        ],
+                        dim=dim,
+                    ),
+                )
                 in_channel = out_channel
-
-[old lines 68-69 not rendered in the diff source]
+
+    @staticmethod
+    def resnet_encoder(channels: list[int], layers: list[int], dim: int) -> list[torch.nn.Module]:
         modules = []
-        modules.append(
+        modules.append(
+            blocks.ConvBlock(
+                channels[0],
+                channels[1],
+                [blocks.BlockConfig(7, 2, 3, False, "ReLU", blocks.NormMode.BATCH)],
+                dim=dim,
+            )
+        )
         for i, (in_channel, out_channel, layer) in enumerate(zip(channels[1:], channels[2:], layers)):
-            modules.append(
+            modules.append(
+                UNetpp.ResNetEncoderLayer(
+                    in_channel,
+                    out_channel,
+                    layer,
+                    dim,
+                    (blocks.DownsampleMode.MAXPOOL if i == 0 else blocks.DownsampleMode.CONV_STRIDE),
+                )
+            )
         return modules
 
     class UNetPPBlock(network.ModuleArgsDict):
 
-        def __init__(
+        def __init__(
+            self,
+            encoder_channels: list[int],
+            decoder_channels: list[int],
+            encoders: list[torch.nn.Module],
+            upsample_mode: blocks.UpsampleMode,
+            dim: int,
+            i: int = 0,
+        ) -> None:
             super().__init__()
-            self.add_module("X_{}_{}"
+            self.add_module(f"X_{i}_{0}", encoders[0], out_branch=[f"X_{i}_{0}"])
             if len(encoder_channels) > 2:
-                self.add_module(
-[old lines 83-89 not rendered in the diff source]
+                self.add_module(
+                    f"UNetBlock_{i + 1}",
+                    UNetpp.UNetPPBlock(
+                        encoder_channels[1:],
+                        decoder_channels[1:],
+                        encoders[1:],
+                        upsample_mode,
+                        dim,
+                        i + 1,
+                    ),
+                    in_branch=[f"X_{i}_{0}"],
+                    out_branch=[f"X_{i + 1}_{j}" for j in range(len(encoder_channels) - 2)],
+                )
+                for j in range(len(encoder_channels) - 2):
+                    in_channels = (
+                        decoder_channels[3]
+                        if j == len(encoder_channels) - 3 and len(encoder_channels) > 3
+                        else encoder_channels[2]
+                    )
+                    out_channel = decoder_channels[2] if j == len(encoder_channels) - 3 else encoder_channels[1]
+                    self.add_module(
+                        f"X_{i}_{j + 1}_{upsample_mode.name}",
+                        blocks.upsample(
+                            in_channels=in_channels,
+                            out_channels=out_channel,
+                            upsample_mode=upsample_mode,
+                            dim=dim,
+                        ),
+                        in_branch=[f"X_{i + 1}_{j}"],
+                        out_branch=[f"X_{i + 1}_{j}"],
+                    )
+                    self.add_module(
+                        f"SkipConnection_{i}_{j + 1}",
+                        blocks.Concat(),
+                        in_branch=[f"X_{i + 1}_{j}"] + [f"X_{i}_{r}" for r in range(j + 1)],
+                        out_branch=[f"X_{i}_{j + 1}"],
+                    )
+                    self.add_module(
+                        f"X_{i}_{j + 1}",
+                        blocks.ConvBlock(
+                            in_channels=encoder_channels[1] * (j + 1)
+                            + (in_channels if upsample_mode != blocks.UpsampleMode.CONV_TRANSPOSE else out_channel),
+                            out_channels=out_channel,
+                            block_configs=[blocks.BlockConfig(3, 1, 1, False, "ReLU;True", blocks.NormMode.BATCH)] * 2,
+                            dim=dim,
+                        ),
+                        in_branch=[f"X_{i}_{j + 1}"],
+                        out_branch=[f"X_{i}_{j + 1}"],
+                    )
+
     class UNetPPHead(network.ModuleArgsDict):
 
         def __init__(self, in_channels: int, out_channels: int, nb_class: int, dim: int) -> None:
             super().__init__()
-            self.add_module(
-[old lines 95-96 not rendered in the diff source]
+            self.add_module(
+                "Upsample",
+                blocks.upsample(
+                    in_channels=in_channels,
+                    out_channels=out_channels,
+                    upsample_mode=blocks.UpsampleMode.UPSAMPLE,
+                    dim=dim,
+                ),
+            )
+            self.add_module(
+                "ConvBlock",
+                blocks.ConvBlock(
+                    in_channels=in_channels,
+                    out_channels=out_channels,
+                    block_configs=[blocks.BlockConfig(3, 1, 1, False, "ReLU;True", blocks.NormMode.BATCH)] * 2,
+                    dim=dim,
+                ),
+            )
+            self.add_module(
+                "Conv",
+                blocks.get_torch_module("Conv", dim)(
+                    in_channels=out_channels,
+                    out_channels=nb_class,
+                    kernel_size=3,
+                    stride=1,
+                    padding=1,
+                ),
+            )
             if nb_class > 1:
                 self.add_module("Softmax", torch.nn.Softmax(dim=1))
                 self.add_module("Argmax", blocks.ArgMax(dim=1))
             else:
                 self.add_module("Tanh", torch.nn.Tanh())
-
-    def __init__(
-[old lines 104-114 not rendered in the diff source]
+
+    def __init__(
+        self,
+        optimizer: network.OptimizerLoader = network.OptimizerLoader(),
+        schedulers: dict[str, network.LRSchedulersLoader] = {
+            "default:ReduceLROnPlateau": network.LRSchedulersLoader(0)
+        },
+        outputs_criterions: dict[str, network.TargetCriterionsLoader] = {"default": network.TargetCriterionsLoader()},
+        patch: ModelPatch | None = None,
+        encoder_channels: list[int] = [1, 64, 64, 128, 256, 512],
+        decoder_channels: list[int] = [256, 128, 64, 32, 16, 1],
+        layers: list[int] = [3, 4, 6, 3],
+        dim: int = 2,
+    ) -> None:
+        super().__init__(
+            in_channels=encoder_channels[0],
+            optimizer=optimizer,
+            schedulers=schedulers,
+            outputs_criterions=outputs_criterions,
+            patch=patch,
+            dim=dim,
+        )
+        self.add_module(
+            "Block_0",
+            UNetpp.UNetPPBlock(
+                encoder_channels,
+                decoder_channels[::-1],
+                UNetpp.resnet_encoder(encoder_channels, layers, dim),
+                blocks.UpsampleMode.UPSAMPLE,
+                dim=dim,
+            ),
+            out_branch=[f"X_0_{j + 1}" for j in range(len(encoder_channels) - 2)],
+        )
+        self.add_module(
+            "Head",
+            UNetpp.UNetPPHead(
+                in_channels=decoder_channels[-3],
+                out_channels=decoder_channels[-2],
+                nb_class=decoder_channels[-1],
+                dim=dim,
+            ),
+            in_branch=[f"X_0_{len(encoder_channels) - 2}"],
+            out_branch=[-1],
+        )
```
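This hunk rewrites `NestedUNet` (a UNet++ with deep supervision) and its ResNet-encoder companion `UNetpp`. For orientation, here is a minimal construction sketch; it assumes only what the diff shows, with the import path inferred from the file list and every argument value taken from the new `__init__` defaults, so treat it as a sketch rather than documented usage.

```python
# Sketch based solely on the signatures visible in the hunk above; the import
# path is an assumption inferred from konfai/models/segmentation/NestedUNet.py.
from konfai.models.segmentation.NestedUNet import NestedUNet, UNetpp

# Five-level 3D UNet++: the constructor registers one NestedUNetHead per dense
# decoder column X_0_{j+1}, i.e. len(channels) - 2 heads (deep supervision).
model = NestedUNet(
    dim=3,
    channels=[1, 64, 128, 256, 512, 1024],  # channels[0] is the input channel count
    nb_class=2,
    downsample_mode="MAXPOOL",       # resolved via blocks.DownsampleMode[...]
    upsample_mode="CONV_TRANSPOSE",  # resolved via blocks.UpsampleMode[...]
    block_type="Conv",               # "Conv" -> blocks.ConvBlock, otherwise blocks.ResBlock
    activation="Softmax",            # each head ends with Softmax + ArgMax
)

# The ResNet-encoded variant from the same file, with its defaults
# (2D, ResNet-34-style [3, 4, 6, 3] layer counts).
model_2d = UNetpp(encoder_channels=[1, 64, 64, 128, 256, 512], layers=[3, 4, 6, 3], dim=2)
```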
konfai/models/segmentation/UNet.py (same bracket convention for unrendered removed lines):

```diff
@@ -1,55 +1,165 @@
 import torch
-from typing import Union
 
-from konfai.network import network, blocks
-from konfai.utils.config import config
 from konfai.data.patching import ModelPatch
+from konfai.network import blocks, network
+
 
 class UNetHead(network.ModuleArgsDict):
 
     def __init__(self, in_channels: int, nb_class: int, dim: int, level: int) -> None:
         super().__init__()
-        self.add_module(
+        self.add_module(
+            "Conv",
+            blocks.get_torch_module("Conv", dim)(
+                in_channels=in_channels,
+                out_channels=nb_class,
+                kernel_size=1,
+                stride=1,
+                padding=0,
+            ),
+        )
         self.add_module("Softmax", torch.nn.Softmax(dim=1))
         self.add_module("Argmax", blocks.ArgMax(dim=1))
 
+
 class UNetBlock(network.ModuleArgsDict):
 
-    def __init__(
+    def __init__(
+        self,
+        channels: list[int],
+        nb_conv_per_stage: int,
+        block_config: blocks.BlockConfig,
+        downsample_mode: blocks.DownsampleMode,
+        upsample_mode: blocks.UpsampleMode,
+        attention: bool,
+        block: type,
+        nb_class: int,
+        dim: int,
+        i: int = 0,
+    ) -> None:
         super().__init__()
-
+        block_config_stride = block_config
         if i > 0:
-            if
-                self.add_module(
+            if downsample_mode != blocks.DownsampleMode.CONV_STRIDE:
+                self.add_module(
+                    downsample_mode.name,
+                    blocks.downsample(
+                        in_channels=channels[0],
+                        out_channels=channels[1],
+                        downsample_mode=downsample_mode,
+                        dim=dim,
+                    ),
+                )
             else:
-[old lines 25-26 not rendered in the diff source]
+                block_config_stride = blocks.BlockConfig(
+                    block_config.kernel_size,
+                    2,
+                    block_config.padding,
+                    block_config.bias,
+                    block_config.activation,
+                    block_config.norm_mode,
+                )
+        self.add_module(
+            "DownConvBlock",
+            block(
+                in_channels=channels[0],
+                out_channels=channels[1],
+                block_configs=[block_config_stride] + [block_config] * (nb_conv_per_stage - 1),
+                dim=dim,
+            ),
+        )
         if len(channels) > 2:
-            self.add_module(
-
+            self.add_module(
+                f"UNetBlock_{i + 1}",
+                UNetBlock(
+                    channels[1:],
+                    nb_conv_per_stage,
+                    block_config,
+                    downsample_mode,
+                    upsample_mode,
+                    attention,
+                    block,
+                    nb_class,
+                    dim,
+                    i + 1,
+                ),
+            )
+            self.add_module(
+                "UpConvBlock",
+                block(
+                    in_channels=(
+                        (channels[1] + channels[2])
+                        if upsample_mode != blocks.UpsampleMode.CONV_TRANSPOSE
+                        else channels[1] * 2
+                    ),
+                    out_channels=channels[1],
+                    block_configs=[block_config] * nb_conv_per_stage,
+                    dim=dim,
+                ),
+            )
         if nb_class > 0:
             self.add_module("Head", UNetHead(channels[1], nb_class, dim, i), out_branch=[-1])
         if i > 0:
             if attention:
-                self.add_module(
-
+                self.add_module(
+                    "Attention",
+                    blocks.Attention(f_g=channels[1], f_l=channels[0], f_int=channels[0], dim=dim),
+                    in_branch=[1, 0],
+                    out_branch=[1],
+                )
+            self.add_module(
+                upsample_mode.name,
+                blocks.upsample(
+                    in_channels=channels[1],
+                    out_channels=channels[0],
+                    upsample_mode=upsample_mode,
+                    dim=dim,
+                    kernel_size=2,
+                    stride=2,
+                ),
+            )
             self.add_module("SkipConnection", blocks.Concat(), in_branch=[0, 1])
 
+
 class UNet(network.Network):
 
-    def __init__(
-[old lines 41-55 not rendered in the diff source]
+    def __init__(
+        self,
+        optimizer: network.OptimizerLoader = network.OptimizerLoader(),
+        schedulers: dict[str, network.LRSchedulersLoader] = {
+            "default:ReduceLROnPlateau": network.LRSchedulersLoader(0)
+        },
+        outputs_criterions: dict[str, network.TargetCriterionsLoader] = {"default": network.TargetCriterionsLoader()},
+        patch: ModelPatch | None = None,
+        dim: int = 3,
+        channels: list[int] = [1, 64, 128, 256, 512, 1024],
+        nb_class: int = 2,
+        block_config: blocks.BlockConfig = blocks.BlockConfig(),
+        nb_conv_per_stage: int = 2,
+        downsample_mode: str = "MAXPOOL",
+        upsample_mode: str = "CONV_TRANSPOSE",
+        attention: bool = False,
+        block_type: str = "Conv",
+    ) -> None:
+        super().__init__(
+            in_channels=channels[0],
+            optimizer=optimizer,
+            schedulers=schedulers,
+            outputs_criterions=outputs_criterions,
+            patch=patch,
+            dim=dim,
+        )
+        self.add_module(
+            "UNetBlock_0",
+            UNetBlock(
+                channels,
+                nb_conv_per_stage,
+                block_config,
+                downsample_mode=blocks.DownsampleMode[downsample_mode],
+                upsample_mode=blocks.UpsampleMode[upsample_mode],
+                attention=attention,
+                block=blocks.ConvBlock if block_type == "Conv" else blocks.ResBlock,
+                nb_class=nb_class,
+                dim=dim,
+            ),
+        )
```
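The same caveats apply to this sketch of the reworked `UNet`: names and defaults come from the hunk above, and the import path is again an assumption from the file list. The notable behavioral change is the `block_config_stride` branch: with `downsample_mode="CONV_STRIDE"` no separate downsample module is registered, and the first convolution of each stage is instead rebuilt with stride 2.

```python
# Minimal sketch under the stated assumptions (import path per the file list,
# defaults per the new __init__ signature in the hunk above).
from konfai.models.segmentation.UNet import UNet

model = UNet(
    dim=3,
    channels=[1, 64, 128, 256, 512, 1024],
    nb_class=2,
    downsample_mode="CONV_STRIDE",   # stride-2 first conv instead of a pooling module
    upsample_mode="CONV_TRANSPOSE",
    attention=True,  # adds blocks.Attention ahead of each SkipConnection Concat
)
```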