konfai 1.1.8__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of konfai might be problematic. Click here for more details.
- konfai/__init__.py +59 -14
- konfai/data/augmentation.py +457 -286
- konfai/data/data_manager.py +533 -316
- konfai/data/patching.py +300 -183
- konfai/data/transform.py +408 -275
- konfai/evaluator.py +325 -68
- konfai/main.py +71 -22
- konfai/metric/measure.py +360 -244
- konfai/metric/schedulers.py +24 -13
- konfai/models/classification/convNeXt.py +187 -81
- konfai/models/classification/resnet.py +272 -58
- konfai/models/generation/cStyleGan.py +233 -59
- konfai/models/generation/ddpm.py +348 -121
- konfai/models/generation/diffusionGan.py +757 -358
- konfai/models/generation/gan.py +177 -53
- konfai/models/generation/vae.py +140 -40
- konfai/models/registration/registration.py +135 -52
- konfai/models/representation/representation.py +57 -23
- konfai/models/segmentation/NestedUNet.py +339 -68
- konfai/models/segmentation/UNet.py +140 -30
- konfai/network/blocks.py +331 -187
- konfai/network/network.py +795 -427
- konfai/predictor.py +644 -238
- konfai/trainer.py +509 -222
- konfai/utils/ITK.py +191 -106
- konfai/utils/config.py +152 -95
- konfai/utils/dataset.py +326 -455
- konfai/utils/utils.py +497 -249
- {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/METADATA +1 -3
- konfai-1.2.0.dist-info/RECORD +38 -0
- konfai/utils/registration.py +0 -199
- konfai-1.1.8.dist-info/RECORD +0 -39
- {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/WHEEL +0 -0
- {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/entry_points.txt +0 -0
- {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/licenses/LICENSE +0 -0
- {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/top_level.txt +0 -0
|
@@ -1,115 +1,329 @@
|
|
|
1
1
|
from abc import ABC
|
|
2
|
-
|
|
2
|
+
|
|
3
3
|
import torch
|
|
4
|
-
|
|
5
|
-
from konfai.utils.config import config
|
|
4
|
+
|
|
6
5
|
from konfai.data.patching import ModelPatch
|
|
6
|
+
from konfai.network import blocks, network
|
|
7
7
|
|
|
# Reference notes kept as an inert string literal. It follows the imports, so it is
# not the module docstring and has no runtime effect. It lists the torchvision
# pretrained-weight URLs and, for ResNet-18/34/50/101/152, the ``ResNet`` constructor
# arguments that reproduce each architecture. The ResNeXt and Wide-ResNet URLs at the
# end have no matching configuration in this module.
"""
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
dim = 2, in_channels = 3, depths=[2, 2, 2, 2], widths = [64, 64, 128, 256, 512],
num_classes=1000, use_bottleneck=False

'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
dim = 2, in_channels = 3, depths=[3, 4, 6, 3], widths = [64, 64, 128, 256, 512],
num_classes=1000, use_bottleneck=False

'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
dim = 2, in_channels = 3, depths=[3, 4, 6, 3], widths = [64, 256, 512, 1024, 2048],
num_classes=1000, use_bottleneck=True

'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
dim = 2, in_channels = 3, depths=[3, 4, 23, 3], widths = [64, 256, 512, 1024, 2048],
num_classes=1000, use_bottleneck=True

'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
dim = 2, in_channels = 3, depths=[3, 8, 36, 3], widths = [64, 256, 512, 1024, 2048],
num_classes=1000, use_bottleneck=True

'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
"""


class AbstractResBlock(network.ModuleArgsDict, ABC):
    """Common base for the residual blocks assembled by ``ResNetStage``.

    It fixes the constructor signature shared by every block type
    (``in_channels``, ``out_channels``, ``downsample``, ``dim``), so a stage can
    instantiate any subclass interchangeably. The base class registers no
    sub-modules and stores none of these arguments; subclasses consume them.
    """

    def __init__(self, in_channels: int, out_channels: int, downsample: bool, dim: int):
        # Only the underlying module container is initialised here.
        super(AbstractResBlock, self).__init__()


class ResBlock(AbstractResBlock):
    """Basic two-convolution residual block (ResNet-18/34 style).

    Main path (default branch): 3x3 conv + batch norm + ReLU, then 3x3 conv +
    batch norm with no activation. When the block downsamples or changes the
    channel count, a "Shortcut" (1x1 conv + batch norm, no activation) is applied
    to branch 1. "Residual" then sums branches 0 and 1, and a final ReLU follows.
    Branch routing is implemented by ``network.ModuleArgsDict``; branch 1
    presumably starts as a copy of the block input (confirm there).

    Args:
        in_channels: Channel count of the block input.
        out_channels: Channel count of the block output.
        downsample: If True, the first convolution and the shortcut use stride 2.
        dim: Spatial dimensionality forwarded to the convolution blocks.
    """

    def __init__(self, in_channels: int, out_channels: int, downsample: bool, dim: int):
        super().__init__(in_channels, out_channels, downsample, dim)
        stride = 2 if downsample else 1
        # The identity must be projected both when the resolution changes and when
        # the channel count changes. Otherwise "Residual" would add tensors with
        # different channel counts (e.g. custom widths such as [32, 64, ...] whose
        # first stage does not downsample). This matches ResBottleneckBlock.
        if downsample or in_channels != out_channels:
            self.add_module(
                "Shortcut",
                blocks.ConvBlock(
                    in_channels,
                    out_channels,
                    blocks.BlockConfig(
                        kernel_size=1,
                        stride=stride,
                        padding=0,
                        bias=False,
                        activation="None",
                        norm_mode=blocks.NormMode.BATCH.name,
                    ),
                    dim=dim,
                    alias=[["0"], ["1"], []],
                ),
                in_branch=[1],
                out_branch=[1],
                alias=["downsample"],
            )

        self.add_module(
            "ConvBlock_0",
            blocks.ConvBlock(
                in_channels,
                out_channels,
                blocks.BlockConfig(
                    kernel_size=3,
                    stride=stride,
                    padding=1,
                    bias=False,
                    activation="ReLU",
                    norm_mode=blocks.NormMode.BATCH.name,
                ),
                dim=dim,
                alias=[["conv1"], ["bn1"], []],
            ),
        )
        self.add_module(
            "ConvBlock_1",
            blocks.ConvBlock(
                out_channels,
                out_channels,
                blocks.BlockConfig(
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    bias=False,
                    # No activation before the residual sum; the ReLU comes after it.
                    activation="None",
                    norm_mode=blocks.NormMode.BATCH.name,
                ),
                dim=dim,
                alias=[["conv2"], ["bn2"], []],
            ),
        )
        self.add_module("Residual", blocks.Add(), in_branch=[0, 1])
        self.add_module("ReLU", torch.nn.ReLU())


class ResBottleneckBlock(AbstractResBlock):
    """Bottleneck residual block (ResNet-50/101/152 style).

    Main path (default branch):
      1. 1x1 conv reducing to ``out_channels // 4``, + batch norm + ReLU.
      2. 3x3 conv (stride 2 when ``downsample``), + batch norm + ReLU.
      3. 1x1 conv expanding back to ``out_channels``, + batch norm, with no
         activation.

    When the block downsamples or changes the channel count, a "Shortcut"
    (1x1 conv + batch norm) is applied to branch 1. "Residual" sums branches 0
    and 1, followed by a final ReLU. Branch routing is implemented by
    ``network.ModuleArgsDict``; branch 1 presumably starts as a copy of the block
    input (confirm there).

    Args:
        in_channels: Channel count of the block input.
        out_channels: Channel count of the block output (4x the inner width).
        downsample: If True, the 3x3 convolution and the shortcut use stride 2.
        dim: Spatial dimensionality forwarded to the convolution blocks.
    """

    def __init__(self, in_channels: int, out_channels: int, downsample: bool, dim: int):
        super().__init__(in_channels, out_channels, downsample, dim)
        inner_channels = out_channels // 4  # bottleneck expansion factor of 4
        stride = 2 if downsample else 1
        self.add_module(
            "ConvBlock_0",
            blocks.ConvBlock(
                in_channels,
                inner_channels,
                blocks.BlockConfig(
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    bias=False,
                    activation="ReLU",
                    norm_mode=blocks.NormMode.BATCH.name,
                ),
                dim=dim,
                alias=[["conv1"], ["bn1"], []],
            ),
        )
        self.add_module(
            "ConvBlock_1",
            blocks.ConvBlock(
                inner_channels,
                inner_channels,
                blocks.BlockConfig(
                    kernel_size=3,
                    stride=stride,
                    padding=1,
                    bias=False,
                    activation="ReLU",
                    norm_mode=blocks.NormMode.BATCH.name,
                ),
                dim=dim,
                alias=[["conv2"], ["bn2"], []],
            ),
        )
        self.add_module(
            "ConvBlock_2",
            blocks.ConvBlock(
                inner_channels,
                out_channels,
                blocks.BlockConfig(
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    bias=False,
                    # No activation before the residual sum; the ReLU is applied after
                    # "Residual". This matches ResBlock and the standard bottleneck
                    # (torchvision Bottleneck: bn3 -> += identity -> relu).
                    activation="None",
                    norm_mode=blocks.NormMode.BATCH.name,
                ),
                dim=dim,
                alias=[["conv3"], ["bn3"], []],
            ),
        )

        # Project the identity whenever its shape differs from the main path output.
        if downsample or in_channels != out_channels:
            self.add_module(
                "Shortcut",
                blocks.ConvBlock(
                    in_channels,
                    out_channels,
                    blocks.BlockConfig(
                        kernel_size=1,
                        stride=stride,
                        padding=0,
                        bias=False,
                        activation="None",
                        norm_mode=blocks.NormMode.BATCH.name,
                    ),
                    dim=dim,
                    alias=[["0"], ["1"], []],
                ),
                in_branch=[1],
                out_branch=[1],
                alias=["downsample"],
            )

        self.add_module("Residual", blocks.Add(), in_branch=[0, 1])
        self.add_module("ReLU", torch.nn.ReLU())


class ResNetStage(network.ModuleArgsDict):
    """A run of residual blocks of a single type.

    The first block receives ``in_channels`` and ``downsample``. Every following
    block maps ``out_channels`` to ``out_channels`` without downsampling. At least
    one block is always created, even when ``depth`` is below 1. Sub-modules are
    named ``BottleNeckBlock_<k>`` (aliased ``"<k>"``) regardless of block type.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        depth: int,
        block: type[AbstractResBlock],
        downsample: bool,
        dim: int,
    ):
        super().__init__()
        for index in range(max(depth, 1)):
            is_first = index == 0
            self.add_module(
                f"BottleNeckBlock_{index}",
                block(
                    in_channels=in_channels if is_first else out_channels,
                    out_channels=out_channels,
                    downsample=downsample if is_first else False,
                    dim=dim,
                ),
                alias=[str(index)],
            )


class ResNetStem(network.ModuleArgsDict):
    """Input stem: 7x7 stride-2 conv + batch norm + ReLU, then a 3x3 stride-2 max-pool.

    Args:
        in_channels: Channel count of the network input.
        out_features: Channel count produced by the stem convolution.
        dim: Spatial dimensionality used to select the conv/pool implementations.
    """

    def __init__(self, in_channels: int, out_features: int, dim: int):
        super().__init__()
        stem_config = blocks.BlockConfig(
            kernel_size=7,
            stride=2,
            padding=3,
            bias=False,
            activation="ReLU",
            norm_mode=blocks.NormMode.BATCH.name,
        )
        stem_conv = blocks.ConvBlock(
            in_channels,
            out_features,
            stem_config,
            dim=dim,
            alias=[["conv1"], ["bn1"], []],
        )
        self.add_module("ConvBlock", stem_conv)

        max_pool_cls = blocks.get_torch_module("MaxPool", dim)
        self.add_module("MaxPool", max_pool_cls(kernel_size=3, stride=2, padding=1))


class ResNetEncoder(network.ModuleArgsDict):
    """ResNet feature extractor: a stem followed by one ``ResNetStage`` per depth.

    ``widths[0]`` is the stem output width. Stage ``k`` maps ``widths[k]`` to
    ``widths[k + 1]`` using ``depths[k]`` blocks, and every stage except the first
    downsamples. Pairing stops at the shorter of ``depths`` and ``widths[1:]``.

    Args:
        in_channels: Channel count of the network input.
        depths: Number of blocks in each stage.
        widths: Stem width followed by each stage's output width.
        use_bottleneck: Use ``ResBottleneckBlock`` instead of ``ResBlock``.
        dim: Spatial dimensionality forwarded to every sub-module.
    """

    def __init__(
        self,
        in_channels: int,
        depths: list[int],
        widths: list[int],
        use_bottleneck: bool,
        dim: int,
    ):
        super().__init__()
        self.add_module("ResNetStem", ResNetStem(in_channels, widths[0], dim=dim))

        block_type = ResBottleneckBlock if use_bottleneck else ResBlock
        stage_specs = zip(widths, widths[1:], depths)
        for stage_index, (stage_in, stage_out, stage_depth) in enumerate(stage_specs):
            self.add_module(
                f"ResNetStage_{stage_index}",
                ResNetStage(
                    in_channels=stage_in,
                    out_channels=stage_out,
                    depth=stage_depth,
                    block=block_type,
                    downsample=stage_index > 0,
                    dim=dim,
                ),
                alias=[f"layer{stage_index + 1}"],
            )


class Head(network.ModuleArgsDict):
    """Classification head: global average pool, flatten, linear layer, unsqueeze(2).

    Args:
        in_features: Channel count of the incoming feature map.
        num_classes: Number of output logits.
        dim: Spatial dimensionality used to select the adaptive pooling module.
    """

    def __init__(self, in_features: int, num_classes: int, dim: int):
        super().__init__()
        pool_cls = blocks.get_torch_module("AdaptiveAvgPool", dim)
        self.add_module("AdaptiveAvgPool", pool_cls(1))
        self.add_module("Flatten", torch.nn.Flatten(1))

        classifier = torch.nn.Linear(in_features, num_classes)
        # NOTE(review): pretrained=False presumably keeps this layer out of
        # pretrained-weight loading (num_classes usually differs from the source
        # checkpoint) - confirm in network.ModuleArgsDict.add_module.
        self.add_module("Linear", classifier, pretrained=False, alias=["fc"])
        self.add_module("Unsqueeze", blocks.Unsqueeze(2))


class ResNet(network.Network):
    """ResNet classifier composed of a ``ResNetEncoder`` and a ``Head``.

    The defaults describe a ResNet-18-style layout (basic blocks, depths
    [2, 2, 2, 2], widths [64, 64, 128, 256, 512]) for single-channel 3D inputs
    with 10 classes. Weights are initialised with ``init_type="trunc_normal"``
    and ``init_gain=0.02``.

    Args:
        optimizer: Optimizer loader passed to ``network.Network``.
        schedulers: Learning-rate scheduler loaders, keyed by name.
        outputs_criterions: Target criterion loaders, keyed by output name.
        patch: Patching configuration passed to ``network.Network``.
        dim: Spatial dimensionality (2 or 3 expected).
        in_channels: Channel count of the network input.
        depths: Number of residual blocks in each stage.
        widths: Stem width followed by each stage's output width.
        num_classes: Number of output logits.
        use_bottleneck: Use bottleneck blocks (ResNet-50+) instead of basic blocks.
    """

    def __init__(
        self,
        optimizer: network.OptimizerLoader = network.OptimizerLoader(),
        schedulers: dict[str, network.LRSchedulersLoader] = {
            "default:ReduceLROnPlateau": network.LRSchedulersLoader(0)
        },
        outputs_criterions: dict[str, network.TargetCriterionsLoader] = {"default": network.TargetCriterionsLoader()},
        patch: ModelPatch = ModelPatch(),
        dim: int = 3,
        in_channels: int = 1,
        depths: list[int] = [2, 2, 2, 2],
        widths: list[int] = [64, 64, 128, 256, 512],
        num_classes: int = 10,
        use_bottleneck=False,
    ):
        # NOTE(review): the loader/dict/list defaults above are evaluated once at class
        # definition and shared by every instance, so they must not be mutated. They
        # are left unchanged because konfai's configuration layer may read these
        # signature defaults - confirm before replacing them with None sentinels.
        super().__init__(
            in_channels=in_channels,
            optimizer=optimizer,
            schedulers=schedulers,
            outputs_criterions=outputs_criterions,
            dim=dim,
            patch=patch,
            init_type="trunc_normal",
            init_gain=0.02,
        )

        encoder = ResNetEncoder(
            in_channels=in_channels,
            depths=depths,
            widths=widths,
            use_bottleneck=use_bottleneck,
            dim=dim,
        )
        self.add_module("ResNetEncoder", encoder)

        # The head consumes the width of the last stage.
        head = Head(in_features=widths[-1], num_classes=num_classes, dim=dim)
        self.add_module("Head", head)