konfai-1.1.8-py3-none-any.whl → konfai-1.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of konfai might be problematic (the original registry page linked to further details).

Files changed (36):
  1. konfai/__init__.py +59 -14
  2. konfai/data/augmentation.py +457 -286
  3. konfai/data/data_manager.py +533 -316
  4. konfai/data/patching.py +300 -183
  5. konfai/data/transform.py +408 -275
  6. konfai/evaluator.py +325 -68
  7. konfai/main.py +71 -22
  8. konfai/metric/measure.py +360 -244
  9. konfai/metric/schedulers.py +24 -13
  10. konfai/models/classification/convNeXt.py +187 -81
  11. konfai/models/classification/resnet.py +272 -58
  12. konfai/models/generation/cStyleGan.py +233 -59
  13. konfai/models/generation/ddpm.py +348 -121
  14. konfai/models/generation/diffusionGan.py +757 -358
  15. konfai/models/generation/gan.py +177 -53
  16. konfai/models/generation/vae.py +140 -40
  17. konfai/models/registration/registration.py +135 -52
  18. konfai/models/representation/representation.py +57 -23
  19. konfai/models/segmentation/NestedUNet.py +339 -68
  20. konfai/models/segmentation/UNet.py +140 -30
  21. konfai/network/blocks.py +331 -187
  22. konfai/network/network.py +795 -427
  23. konfai/predictor.py +644 -238
  24. konfai/trainer.py +509 -222
  25. konfai/utils/ITK.py +191 -106
  26. konfai/utils/config.py +152 -95
  27. konfai/utils/dataset.py +326 -455
  28. konfai/utils/utils.py +497 -249
  29. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/METADATA +1 -3
  30. konfai-1.2.0.dist-info/RECORD +38 -0
  31. konfai/utils/registration.py +0 -199
  32. konfai-1.1.8.dist-info/RECORD +0 -39
  33. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/WHEEL +0 -0
  34. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/entry_points.txt +0 -0
  35. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/licenses/LICENSE +0 -0
  36. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/top_level.txt +0 -0
@@ -1,115 +1,329 @@
1
1
  from abc import ABC
2
- from typing import Type
2
+
3
3
  import torch
4
- from konfai.network import network, blocks
5
- from konfai.utils.config import config
4
+
6
5
  from konfai.data.patching import ModelPatch
6
+ from konfai.network import blocks, network
7
7
 
8
8
  """
9
- 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', dim = 2, in_channels = 3, depths=[2, 2, 2, 2], widths = [64, 64, 128, 256, 512], num_classes=1000, useBottleneck=False
9
+ 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
10
+ dim = 2, in_channels = 3, depths=[2, 2, 2, 2], widths = [64, 64, 128, 256, 512],
11
+ num_classes=1000, use_bottleneck=False
10
12
 
11
- 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', dim = 2, in_channels = 3, depths=[3, 4, 6, 3], widths = [64, 64, 128, 256, 512], num_classes=1000, useBottleneck=False
13
+ 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
14
+ dim = 2, in_channels = 3, depths=[3, 4, 6, 3], widths = [64, 64, 128, 256, 512],
15
+ num_classes=1000, use_bottleneck=False
12
16
 
13
- 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', dim = 2, in_channels = 3, depths=[3, 4, 6, 3], widths = [64, 256, 512, 1024, 2048], num_classes=1000, useBottleneck=True
17
+ 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
18
+ dim = 2, in_channels = 3, depths=[3, 4, 6, 3], widths = [64, 256, 512, 1024, 2048],
19
+ num_classes=1000, use_bottleneck=True
14
20
 
15
- 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', dim = 2, in_channels = 3, depths=[3, 4, 23, 3], widths = [64, 256, 512, 1024, 2048], num_classes=1000, useBottleneck=True
21
+ 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
22
+ dim = 2, in_channels = 3, depths=[3, 4, 23, 3], widths = [64, 256, 512, 1024, 2048],
23
+ num_classes=1000, use_bottleneck=True
16
24
 
17
- 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', dim = 2, in_channels = 3, depths=[3, 8, 36, 3], widths = [64, 256, 512, 1024, 2048], num_classes=1000, useBottleneck=True
25
+ 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
26
+ dim = 2, in_channels = 3, depths=[3, 8, 36, 3], widths = [64, 256, 512, 1024, 2048],
27
+ num_classes=1000, use_bottleneck=True
18
28
 
19
29
  'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
20
30
  'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
21
31
  'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
22
32
  'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
23
33
  """
34
+
35
+
24
36
  class AbstractResBlock(network.ModuleArgsDict, ABC):
25
37
 
26
38
  def __init__(self, in_channels: int, out_channels: int, downsample: bool, dim: int):
27
39
  super().__init__()
28
40
 
41
+
29
42
  class ResBlock(AbstractResBlock):
30
43
 
31
44
  def __init__(self, in_channels: int, out_channels: int, downsample: bool, dim: int):
32
45
  super().__init__(in_channels, out_channels, downsample, dim)
33
46
  if downsample:
34
- self.add_module("Shortcut", blocks.ConvBlock(in_channels, out_channels, 1, blocks.BlockConfig(kernel_size=1, stride=2, padding=0, bias=False, activation="None", normMode=blocks.NormMode.BATCH.name), dim=dim, alias=[["0"], ["1"], []]), in_branch=[1], out_branch=[1], alias=["downsample"])
35
-
36
- self.add_module("ConvBlock_0", blocks.ConvBlock(in_channels, out_channels, 1, blocks.BlockConfig(kernel_size=3, stride=2 if downsample else 1, padding=1, bias=False, activation="ReLU", normMode=blocks.NormMode.BATCH.name), dim=dim, alias=[["conv1"], ["bn1"], []]))
37
- self.add_module("ConvBlock_1", blocks.ConvBlock(out_channels, out_channels, 1, blocks.BlockConfig(kernel_size=3, stride=1, padding=1, bias=False, activation="None", normMode=blocks.NormMode.BATCH.name), dim=dim, alias=[["conv2"], ["bn2"], []]))
38
- self.add_module("Residual", blocks.Add(), in_branch=[0,1])
47
+ self.add_module(
48
+ "Shortcut",
49
+ blocks.ConvBlock(
50
+ in_channels,
51
+ out_channels,
52
+ blocks.BlockConfig(
53
+ kernel_size=1,
54
+ stride=2,
55
+ padding=0,
56
+ bias=False,
57
+ activation="None",
58
+ norm_mode=blocks.NormMode.BATCH.name,
59
+ ),
60
+ dim=dim,
61
+ alias=[["0"], ["1"], []],
62
+ ),
63
+ in_branch=[1],
64
+ out_branch=[1],
65
+ alias=["downsample"],
66
+ )
67
+
68
+ self.add_module(
69
+ "ConvBlock_0",
70
+ blocks.ConvBlock(
71
+ in_channels,
72
+ out_channels,
73
+ blocks.BlockConfig(
74
+ kernel_size=3,
75
+ stride=2 if downsample else 1,
76
+ padding=1,
77
+ bias=False,
78
+ activation="ReLU",
79
+ norm_mode=blocks.NormMode.BATCH.name,
80
+ ),
81
+ dim=dim,
82
+ alias=[["conv1"], ["bn1"], []],
83
+ ),
84
+ )
85
+ self.add_module(
86
+ "ConvBlock_1",
87
+ blocks.ConvBlock(
88
+ out_channels,
89
+ out_channels,
90
+ blocks.BlockConfig(
91
+ kernel_size=3,
92
+ stride=1,
93
+ padding=1,
94
+ bias=False,
95
+ activation="None",
96
+ norm_mode=blocks.NormMode.BATCH.name,
97
+ ),
98
+ dim=dim,
99
+ alias=[["conv2"], ["bn2"], []],
100
+ ),
101
+ )
102
+ self.add_module("Residual", blocks.Add(), in_branch=[0, 1])
39
103
  self.add_module("ReLU", torch.nn.ReLU())
40
104
 
105
+
41
106
  class ResBottleneckBlock(AbstractResBlock):
42
107
  def __init__(self, in_channels: int, out_channels: int, downsample: bool, dim: int):
43
108
  super().__init__(in_channels, out_channels, downsample, dim)
44
- self.add_module("ConvBlock_0", blocks.ConvBlock(in_channels, out_channels//4, 1, blocks.BlockConfig(kernel_size=1, stride=1, padding=0, bias=False, activation="ReLU", normMode=blocks.NormMode.BATCH.name), dim=dim, alias=[["conv1"], ["bn1"], []]))
45
- self.add_module("ConvBlock_1", blocks.ConvBlock(out_channels//4, out_channels//4, 1, blocks.BlockConfig(kernel_size=3, stride=2 if downsample else 1, padding=1, bias=False, activation="ReLU", normMode=blocks.NormMode.BATCH.name), dim=dim, alias=[["conv2"], ["bn2"], []]))
46
- self.add_module("ConvBlock_2", blocks.ConvBlock(out_channels//4, out_channels, 1, blocks.BlockConfig(kernel_size=1, stride=1, padding=0, bias=False, activation="ReLU", normMode=blocks.NormMode.BATCH.name), dim=dim, alias=[["conv3"], ["bn3"], []]))
47
-
109
+ self.add_module(
110
+ "ConvBlock_0",
111
+ blocks.ConvBlock(
112
+ in_channels,
113
+ out_channels // 4,
114
+ blocks.BlockConfig(
115
+ kernel_size=1,
116
+ stride=1,
117
+ padding=0,
118
+ bias=False,
119
+ activation="ReLU",
120
+ norm_mode=blocks.NormMode.BATCH.name,
121
+ ),
122
+ dim=dim,
123
+ alias=[["conv1"], ["bn1"], []],
124
+ ),
125
+ )
126
+ self.add_module(
127
+ "ConvBlock_1",
128
+ blocks.ConvBlock(
129
+ out_channels // 4,
130
+ out_channels // 4,
131
+ blocks.BlockConfig(
132
+ kernel_size=3,
133
+ stride=2 if downsample else 1,
134
+ padding=1,
135
+ bias=False,
136
+ activation="ReLU",
137
+ norm_mode=blocks.NormMode.BATCH.name,
138
+ ),
139
+ dim=dim,
140
+ alias=[["conv2"], ["bn2"], []],
141
+ ),
142
+ )
143
+ self.add_module(
144
+ "ConvBlock_2",
145
+ blocks.ConvBlock(
146
+ out_channels // 4,
147
+ out_channels,
148
+ blocks.BlockConfig(
149
+ kernel_size=1,
150
+ stride=1,
151
+ padding=0,
152
+ bias=False,
153
+ activation="ReLU",
154
+ norm_mode=blocks.NormMode.BATCH.name,
155
+ ),
156
+ dim=dim,
157
+ alias=[["conv3"], ["bn3"], []],
158
+ ),
159
+ )
160
+
48
161
  if downsample or in_channels != out_channels:
49
- self.add_module("Shortcut", blocks.ConvBlock(in_channels, out_channels, 1, blocks.BlockConfig(kernel_size=1, stride=2 if downsample else 1, padding=0, bias=False, activation="None", normMode=blocks.NormMode.BATCH.name), dim=dim, alias=[["0"], ["1"], []]), in_branch=[1], out_branch=[1], alias=["downsample"])
50
-
51
- self.add_module("Residual", blocks.Add(), in_branch=[0,1])
162
+ self.add_module(
163
+ "Shortcut",
164
+ blocks.ConvBlock(
165
+ in_channels,
166
+ out_channels,
167
+ blocks.BlockConfig(
168
+ kernel_size=1,
169
+ stride=2 if downsample else 1,
170
+ padding=0,
171
+ bias=False,
172
+ activation="None",
173
+ norm_mode=blocks.NormMode.BATCH.name,
174
+ ),
175
+ dim=dim,
176
+ alias=[["0"], ["1"], []],
177
+ ),
178
+ in_branch=[1],
179
+ out_branch=[1],
180
+ alias=["downsample"],
181
+ )
182
+
183
+ self.add_module("Residual", blocks.Add(), in_branch=[0, 1])
52
184
  self.add_module("ReLU", torch.nn.ReLU())
53
185
 
186
+
54
187
  class ResNetStage(network.ModuleArgsDict):
55
188
 
56
- def __init__(self,
57
- in_channels: int,
58
- out_channels: int,
59
- depth: int,
60
- block : Type[AbstractResBlock],
61
- downsample : bool,
62
- dim : int):
189
+ def __init__(
190
+ self,
191
+ in_channels: int,
192
+ out_channels: int,
193
+ depth: int,
194
+ block: type[AbstractResBlock],
195
+ downsample: bool,
196
+ dim: int,
197
+ ):
63
198
  super().__init__()
64
- self.add_module("BottleNeckBlock_0", block(in_channels=in_channels, out_channels=out_channels, downsample=downsample, dim=dim), alias=["0"])
199
+ self.add_module(
200
+ "BottleNeckBlock_0",
201
+ block(
202
+ in_channels=in_channels,
203
+ out_channels=out_channels,
204
+ downsample=downsample,
205
+ dim=dim,
206
+ ),
207
+ alias=["0"],
208
+ )
65
209
  for i in range(1, depth):
66
- self.add_module("BottleNeckBlock_{}".format(i), block(in_channels=out_channels, out_channels=out_channels, downsample=False, dim=dim), alias=["{}".format(i)])
210
+ self.add_module(
211
+ f"BottleNeckBlock_{i}",
212
+ block(
213
+ in_channels=out_channels,
214
+ out_channels=out_channels,
215
+ downsample=False,
216
+ dim=dim,
217
+ ),
218
+ alias=[f"{i}"],
219
+ )
220
+
67
221
 
68
222
  class ResNetStem(network.ModuleArgsDict):
69
223
 
70
224
  def __init__(self, in_channels: int, out_features: int, dim: int):
71
225
  super().__init__()
72
- self.add_module("ConvBlock", blocks.ConvBlock(in_channels, out_features, 1, blocks.BlockConfig(kernel_size=7, stride=2, padding=3, bias=False, activation="ReLU", normMode=blocks.NormMode.BATCH.name), dim=dim, alias=[["conv1"], ["bn1"], []]))
73
- self.add_module("MaxPool", blocks.getTorchModule("MaxPool", dim)(kernel_size=3, stride=2, padding=1))
226
+ self.add_module(
227
+ "ConvBlock",
228
+ blocks.ConvBlock(
229
+ in_channels,
230
+ out_features,
231
+ blocks.BlockConfig(
232
+ kernel_size=7,
233
+ stride=2,
234
+ padding=3,
235
+ bias=False,
236
+ activation="ReLU",
237
+ norm_mode=blocks.NormMode.BATCH.name,
238
+ ),
239
+ dim=dim,
240
+ alias=[["conv1"], ["bn1"], []],
241
+ ),
242
+ )
243
+ self.add_module(
244
+ "MaxPool",
245
+ blocks.get_torch_module("MaxPool", dim)(kernel_size=3, stride=2, padding=1),
246
+ )
247
+
74
248
 
75
249
  class ResNetEncoder(network.ModuleArgsDict):
76
-
77
- def __init__(self,
78
- in_channels: int,
79
- depths: list[int],
80
- widths: list[int],
81
- useBottleneck : bool,
82
- dim : int):
250
+
251
+ def __init__(
252
+ self,
253
+ in_channels: int,
254
+ depths: list[int],
255
+ widths: list[int],
256
+ use_bottleneck: bool,
257
+ dim: int,
258
+ ):
83
259
  super().__init__()
84
260
  self.add_module("ResNetStem", ResNetStem(in_channels, widths[0], dim=dim))
85
261
 
86
262
  for i, (in_channels, out_channels, depth) in enumerate(list(zip(widths[:], widths[1:], depths))):
87
- self.add_module("ResNetStage_{}".format(i), ResNetStage(in_channels=in_channels, out_channels=out_channels, depth=depth, block=ResBottleneckBlock if useBottleneck else ResBlock, downsample=i!=0, dim=dim), alias=["layer{}".format(i+1)])
263
+ self.add_module(
264
+ f"ResNetStage_{i}",
265
+ ResNetStage(
266
+ in_channels=in_channels,
267
+ out_channels=out_channels,
268
+ depth=depth,
269
+ block=ResBottleneckBlock if use_bottleneck else ResBlock,
270
+ downsample=i != 0,
271
+ dim=dim,
272
+ ),
273
+ alias=[f"layer{i + 1}"],
274
+ )
275
+
88
276
 
89
277
  class Head(network.ModuleArgsDict):
90
-
91
- def __init__(self, in_features : int, num_classes : int, dim : int):
278
+
279
+ def __init__(self, in_features: int, num_classes: int, dim: int):
92
280
  super().__init__()
93
- self.add_module("AdaptiveAvgPool", blocks.getTorchModule("AdaptiveAvgPool", dim)(1))
281
+ self.add_module("AdaptiveAvgPool", blocks.get_torch_module("AdaptiveAvgPool", dim)(1))
94
282
  self.add_module("Flatten", torch.nn.Flatten(1))
95
- self.add_module("Linear", torch.nn.Linear(in_features, num_classes), pretrained=False, alias=["fc"])
283
+ self.add_module(
284
+ "Linear",
285
+ torch.nn.Linear(in_features, num_classes),
286
+ pretrained=False,
287
+ alias=["fc"],
288
+ )
96
289
  self.add_module("Unsqueeze", blocks.Unsqueeze(2))
97
290
 
98
-
291
+
99
292
  class ResNet(network.Network):
100
293
 
101
- def __init__( self,
102
- optimizer : network.OptimizerLoader = network.OptimizerLoader(),
103
- schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
104
- outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
105
- patch : ModelPatch = ModelPatch(),
106
- dim : int = 3,
107
- in_channels: int = 1,
108
- depths: list[int] = [2, 2, 2, 2],
109
- widths: list[int] = [64, 64, 128, 256, 512],
110
- num_classes: int = 10,
111
- useBottleneck=False):
112
- super().__init__(in_channels = in_channels, optimizer = optimizer, schedulers = schedulers, outputsCriterions = outputsCriterions, dim = dim, patch=patch, init_type = "trunc_normal", init_gain=0.02)
113
- self.add_module("ResNetEncoder", ResNetEncoder(in_channels=in_channels, depths=depths, widths=widths, useBottleneck=useBottleneck, dim=dim))
294
+ def __init__(
295
+ self,
296
+ optimizer: network.OptimizerLoader = network.OptimizerLoader(),
297
+ schedulers: dict[str, network.LRSchedulersLoader] = {
298
+ "default:ReduceLROnPlateau": network.LRSchedulersLoader(0)
299
+ },
300
+ outputs_criterions: dict[str, network.TargetCriterionsLoader] = {"default": network.TargetCriterionsLoader()},
301
+ patch: ModelPatch = ModelPatch(),
302
+ dim: int = 3,
303
+ in_channels: int = 1,
304
+ depths: list[int] = [2, 2, 2, 2],
305
+ widths: list[int] = [64, 64, 128, 256, 512],
306
+ num_classes: int = 10,
307
+ use_bottleneck=False,
308
+ ):
309
+ super().__init__(
310
+ in_channels=in_channels,
311
+ optimizer=optimizer,
312
+ schedulers=schedulers,
313
+ outputs_criterions=outputs_criterions,
314
+ dim=dim,
315
+ patch=patch,
316
+ init_type="trunc_normal",
317
+ init_gain=0.02,
318
+ )
319
+ self.add_module(
320
+ "ResNetEncoder",
321
+ ResNetEncoder(
322
+ in_channels=in_channels,
323
+ depths=depths,
324
+ widths=widths,
325
+ use_bottleneck=use_bottleneck,
326
+ dim=dim,
327
+ ),
328
+ )
114
329
  self.add_module("Head", Head(in_features=widths[-1], num_classes=num_classes, dim=dim))
115
-