konfai 1.1.8 → 1.2.0 (konfai-1.1.8-py3-none-any.whl → konfai-1.2.0-py3-none-any.whl)

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions exactly as they appear in their respective public registries.

Potentially problematic release.


This version of konfai might be problematic; see the registry's advisory page for more details.

Files changed (36):
  1. konfai/__init__.py +59 -14
  2. konfai/data/augmentation.py +457 -286
  3. konfai/data/data_manager.py +533 -316
  4. konfai/data/patching.py +300 -183
  5. konfai/data/transform.py +408 -275
  6. konfai/evaluator.py +325 -68
  7. konfai/main.py +71 -22
  8. konfai/metric/measure.py +360 -244
  9. konfai/metric/schedulers.py +24 -13
  10. konfai/models/classification/convNeXt.py +187 -81
  11. konfai/models/classification/resnet.py +272 -58
  12. konfai/models/generation/cStyleGan.py +233 -59
  13. konfai/models/generation/ddpm.py +348 -121
  14. konfai/models/generation/diffusionGan.py +757 -358
  15. konfai/models/generation/gan.py +177 -53
  16. konfai/models/generation/vae.py +140 -40
  17. konfai/models/registration/registration.py +135 -52
  18. konfai/models/representation/representation.py +57 -23
  19. konfai/models/segmentation/NestedUNet.py +339 -68
  20. konfai/models/segmentation/UNet.py +140 -30
  21. konfai/network/blocks.py +331 -187
  22. konfai/network/network.py +795 -427
  23. konfai/predictor.py +644 -238
  24. konfai/trainer.py +509 -222
  25. konfai/utils/ITK.py +191 -106
  26. konfai/utils/config.py +152 -95
  27. konfai/utils/dataset.py +326 -455
  28. konfai/utils/utils.py +497 -249
  29. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/METADATA +1 -3
  30. konfai-1.2.0.dist-info/RECORD +38 -0
  31. konfai/utils/registration.py +0 -199
  32. konfai-1.1.8.dist-info/RECORD +0 -39
  33. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/WHEEL +0 -0
  34. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/entry_points.txt +0 -0
  35. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/licenses/LICENSE +0 -0
  36. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/top_level.txt +0 -0
Diff of konfai/models/generation/cStyleGan.py (+233 -59):
@@ -1,27 +1,37 @@
1
1
  import importlib
2
+
2
3
  import torch
3
4
 
4
- from konfai.network import network, blocks
5
- from konfai.utils.config import config
6
- from konfai.data.patching import ModelPatch
5
+ from konfai.data.patching import ModelPatch
6
+ from konfai.network import blocks, network
7
+
7
8
 
8
9
  class MappingNetwork(network.ModuleArgsDict):
9
- def __init__(self, z_dim: int, c_dim: int, w_dim: int, num_layers: int, embed_features: int, layer_features: int):
10
+ def __init__(
11
+ self,
12
+ z_dim: int,
13
+ c_dim: int,
14
+ w_dim: int,
15
+ num_layers: int,
16
+ embed_features: int,
17
+ layer_features: int,
18
+ ):
10
19
  super().__init__()
11
-
12
- self.add_module("Concat_1", blocks.Concat(), in_branch=[0,1])
13
-
14
- features = [z_dim + embed_features if c_dim > 0 else 0] + [layer_features] * (num_layers - 1) + [w_dim]
20
+
21
+ self.add_module("Concat_1", blocks.Concat(), in_branch=[0, 1])
22
+
23
+ features = [z_dim + embed_features if c_dim > 0 else 0] + [layer_features] * (num_layers - 1) + [w_dim]
15
24
  if c_dim > 0:
16
25
  self.add_module("Linear", torch.nn.Linear(c_dim, embed_features), out_branch=["Embed"])
17
26
 
18
27
  self.add_module("Noise", blocks.NormalNoise(z_dim), in_branch=["Embed"])
19
28
  if c_dim > 0:
20
- self.add_module("Concat", blocks.Concat(), in_branch=[0,"Embed"])
21
-
29
+ self.add_module("Concat", blocks.Concat(), in_branch=[0, "Embed"])
30
+
22
31
  for i, (in_features, out_features) in enumerate(zip(features, features[1:])):
23
- self.add_module("Linear_{}".format(i), torch.nn.Linear(in_features, out_features))
24
-
32
+ self.add_module(f"Linear_{i}", torch.nn.Linear(in_features, out_features))
33
+
34
+
25
35
  class ModulatedConv(torch.nn.Module):
26
36
 
27
37
  class _ModulatedConv(torch.nn.Module):
@@ -35,56 +45,88 @@ class ModulatedConv(torch.nn.Module):
35
45
  self.padding = conv.padding
36
46
  self.stride = conv.stride
37
47
  if isinstance(conv, torch.nn.modules.conv._ConvTransposeNd):
38
- self.weight = torch.nn.parameter.Parameter(torch.randn((conv.in_channels, conv.out_channels, *conv.kernel_size)))
48
+ self.weight = torch.nn.parameter.Parameter(
49
+ torch.randn((conv.in_channels, conv.out_channels, *conv.kernel_size))
50
+ )
39
51
  self.isConv = False
40
52
  else:
41
- self.weight = torch.nn.parameter.Parameter(torch.randn((conv.out_channels, conv.in_channels, *conv.kernel_size)))
53
+ self.weight = torch.nn.parameter.Parameter(
54
+ torch.randn((conv.out_channels, conv.in_channels, *conv.kernel_size))
55
+ )
42
56
  conv.forward = self.forward
43
57
  self.styles = None
44
58
  self.dim = dim
45
59
 
46
- def setStyle(self, styles: torch.Tensor) -> None:
60
+ def set_style(self, styles: torch.Tensor) -> None:
47
61
  self.styles = styles
48
62
 
49
- def forward(self, input: torch.Tensor) -> torch.Tensor:
50
- b = input.shape[0]
51
- self.affine.to(input.device)
63
+ def forward(self, tensor: torch.Tensor) -> torch.Tensor:
64
+ b = tensor.shape[0]
65
+ self.affine.to(tensor.device)
52
66
  styles = self.affine(self.styles)
53
- w1 = styles.reshape(b, -1, 1, *[1 for _ in range(self.dim)]) if not self.isConv else styles.reshape(b, 1, -1, *[1 for _ in range(self.dim)])
54
- w2 = self.weight.unsqueeze(0).to(input.device)
67
+ w1 = (
68
+ styles.reshape(b, -1, 1, *[1 for _ in range(self.dim)])
69
+ if not self.isConv
70
+ else styles.reshape(b, 1, -1, *[1 for _ in range(self.dim)])
71
+ )
72
+ w2 = self.weight.unsqueeze(0).to(tensor.device)
55
73
  weights = w2 * (w1 + 1)
56
74
 
57
- d = torch.rsqrt((weights ** 2).sum(dim=tuple([i+2 for i in range(len(weights.shape)-2)]), keepdim=True) + 1e-8)
75
+ d = torch.rsqrt(
76
+ (weights**2).sum(
77
+ dim=tuple([i + 2 for i in range(len(weights.shape) - 2)]),
78
+ keepdim=True,
79
+ )
80
+ + 1e-8
81
+ )
58
82
  weights = weights * d
59
-
60
- input = input.reshape(1, -1, *input.shape[2:])
83
+
84
+ tensor = tensor.reshape(1, -1, *tensor.shape[2:])
61
85
 
62
86
  _, _, *ws = weights.shape
63
87
  if not self.isConv:
64
- out = getattr(importlib.import_module("torch.nn.functional"), "conv_transpose{}d".format(self.dim))(input, weights.reshape(b * self.in_channels, *ws), stride=self.stride, padding=self.padding, groups=b)
88
+ out = getattr(
89
+ importlib.import_module("torch.nn.functional"),
90
+ f"conv_transpose{self.dim}d",
91
+ )(
92
+ tensor,
93
+ weights.reshape(b * self.in_channels, *ws),
94
+ stride=self.stride,
95
+ padding=self.padding,
96
+ groups=b,
97
+ )
65
98
  else:
66
- out = getattr(importlib.import_module("torch.nn.functional"), "conv{}d".format(self.dim))(input, weights.reshape(b * self.out_channels, *ws), padding=self.padding, groups=b, stride=self.stride)
67
-
99
+ out = getattr(
100
+ importlib.import_module("torch.nn.functional"),
101
+ f"conv{self.dim}d",
102
+ )(
103
+ tensor,
104
+ weights.reshape(b * self.out_channels, *ws),
105
+ padding=self.padding,
106
+ groups=b,
107
+ stride=self.stride,
108
+ )
109
+
68
110
  out = out.reshape(-1, self.out_channels, *out.shape[2:])
69
111
  return out
70
-
112
+
71
113
  def __init__(self, w_dim: int, module: torch.nn.Module) -> None:
72
114
  super().__init__()
73
115
  self.w_dim = w_dim
74
116
  self.module = module
75
- self.convs = torch.nn.ModuleList()
117
+ self.convs = torch.nn.ModuleList()
76
118
  self.module.apply(self.apply)
77
-
78
- def forward(self, input: torch.Tensor, styles: torch.Tensor) -> torch.Tensor:
119
+
120
+ def forward(self, tensor: torch.Tensor, styles: torch.Tensor) -> torch.Tensor:
79
121
  for conv in self.convs:
80
- conv.setStyle(styles.clone())
81
- return self.module(input)
122
+ conv.set_style(styles.clone())
123
+ return self.module(tensor)
82
124
 
83
125
  def apply(self, module: torch.nn.Module):
84
126
  if isinstance(module, torch.nn.modules.conv._ConvNd):
85
127
  delattr(module, "weight")
86
128
  module.bias = None
87
-
129
+
88
130
  str_dim = module.__class__.__name__[-2:]
89
131
  dim = 1
90
132
  if str_dim == "2d":
@@ -92,47 +134,179 @@ class ModulatedConv(torch.nn.Module):
92
134
  elif str_dim == "3d":
93
135
  dim = 3
94
136
  self.convs.append(ModulatedConv._ModulatedConv(self.w_dim, module, dim=dim))
95
-
137
+
138
+
96
139
  class UNetBlock(network.ModuleArgsDict):
97
140
 
98
- def __init__(self, w_dim: int, channels: list[int], nb_conv_per_stage: int, blockConfig: blocks.BlockConfig, downSampleMode: blocks.DownSampleMode, upSampleMode: blocks.UpSampleMode, attention : bool, dim: int, i : int = 0) -> None:
141
+ def __init__(
142
+ self,
143
+ w_dim: int,
144
+ channels: list[int],
145
+ nb_conv_per_stage: int,
146
+ block_config: blocks.BlockConfig,
147
+ downsample_mode: blocks.DownsampleMode,
148
+ upsample_mode: blocks.UpsampleMode,
149
+ attention: bool,
150
+ dim: int,
151
+ i: int = 0,
152
+ ) -> None:
99
153
  super().__init__()
100
154
  if i > 0:
101
- self.add_module(downSampleMode.name, blocks.downSample(in_channels=channels[0], out_channels=channels[1], downSampleMode=downSampleMode, dim=dim))
102
- self.add_module("DownConvBlock", blocks.ConvBlock(in_channels=channels[1 if downSampleMode == blocks.DownSampleMode.CONV_STRIDE and i > 0 else 0], out_channels=channels[1], nb_conv=nb_conv_per_stage, blockConfig=blockConfig, dim=dim))
155
+ self.add_module(
156
+ downsample_mode.name,
157
+ blocks.downsample(
158
+ in_channels=channels[0],
159
+ out_channels=channels[1],
160
+ downsample_mode=downsample_mode,
161
+ dim=dim,
162
+ ),
163
+ )
164
+ self.add_module(
165
+ "DownConvBlock",
166
+ blocks.ConvBlock(
167
+ in_channels=channels[(1 if downsample_mode == blocks.DownsampleMode.CONV_STRIDE and i > 0 else 0)],
168
+ out_channels=channels[1],
169
+ block_configs=[block_config] * nb_conv_per_stage,
170
+ dim=dim,
171
+ ),
172
+ )
103
173
  if len(channels) > 2:
104
- self.add_module("UNetBlock_{}".format(i+1), UNetBlock(w_dim, channels[1:], nb_conv_per_stage, blockConfig, downSampleMode, upSampleMode, attention, dim, i+1), in_branch=[0,1])
105
- self.add_module("UpConvBlock", ModulatedConv(w_dim, blocks.ConvBlock((channels[1]+channels[2]) if upSampleMode != blocks.UpSampleMode.CONV_TRANSPOSE else channels[1]*2, out_channels=channels[1], nb_conv=nb_conv_per_stage, blockConfig=blockConfig, dim=dim)), in_branch=[0,1])
174
+ self.add_module(
175
+ f"UNetBlock_{i + 1}",
176
+ UNetBlock(
177
+ w_dim,
178
+ channels[1:],
179
+ nb_conv_per_stage,
180
+ block_config,
181
+ downsample_mode,
182
+ upsample_mode,
183
+ attention,
184
+ dim,
185
+ i + 1,
186
+ ),
187
+ in_branch=[0, 1],
188
+ )
189
+ self.add_module(
190
+ "UpConvBlock",
191
+ ModulatedConv(
192
+ w_dim,
193
+ blocks.ConvBlock(
194
+ (
195
+ (channels[1] + channels[2])
196
+ if upsample_mode != blocks.UpsampleMode.CONV_TRANSPOSE
197
+ else channels[1] * 2
198
+ ),
199
+ out_channels=channels[1],
200
+ block_configs=[block_config] * nb_conv_per_stage,
201
+ dim=dim,
202
+ ),
203
+ ),
204
+ in_branch=[0, 1],
205
+ )
106
206
  if i > 0:
107
207
  if attention:
108
- self.add_module("Attention", blocks.Attention(F_g=channels[1], F_l=channels[0], F_int=channels[0], dim=dim), in_branch=["Skip", 0], out_branch=["Skip"])
109
- self.add_module(upSampleMode.name, ModulatedConv(w_dim, blocks.upSample(in_channels=channels[1], out_channels=channels[0], upSampleMode=upSampleMode, dim=dim)), in_branch=[0, 1])
208
+ self.add_module(
209
+ "Attention",
210
+ blocks.Attention(f_g=channels[1], f_l=channels[0], f_int=channels[0], dim=dim),
211
+ in_branch=["Skip", 0],
212
+ out_branch=["Skip"],
213
+ )
214
+ self.add_module(
215
+ upsample_mode.name,
216
+ ModulatedConv(
217
+ w_dim,
218
+ blocks.upsample(
219
+ in_channels=channels[1],
220
+ out_channels=channels[0],
221
+ upsample_mode=upsample_mode,
222
+ dim=dim,
223
+ ),
224
+ ),
225
+ in_branch=[0, 1],
226
+ )
110
227
  self.add_module("SkipConnection", blocks.Concat(), in_branch=[0, "Skip"])
111
228
 
229
+
112
230
  class Generator(network.Network):
113
231
 
114
232
  class GeneratorHead(network.ModuleArgsDict):
115
233
 
116
234
  def __init__(self, in_channels: int, out_channels: int, dim: int) -> None:
117
235
  super().__init__()
118
- self.add_module("Conv", blocks.getTorchModule("Conv", dim)(in_channels = in_channels, out_channels = out_channels, kernel_size = 3, stride = 1, padding = 1))
236
+ self.add_module(
237
+ "Conv",
238
+ blocks.get_torch_module("Conv", dim)(
239
+ in_channels=in_channels,
240
+ out_channels=out_channels,
241
+ kernel_size=3,
242
+ stride=1,
243
+ padding=1,
244
+ ),
245
+ )
119
246
  self.add_module("Tanh", torch.nn.Tanh())
120
247
 
121
- def __init__(self,
122
- optimizer : network.OptimizerLoader = network.OptimizerLoader(),
123
- schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
124
- patch : ModelPatch = ModelPatch(),
125
- outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
126
- channels: list[int]=[1, 64, 128, 256, 512, 1024],
127
- nb_batch_per_step: int = 64,
128
- z_dim: int = 512,
129
- c_dim: int = 1,
130
- w_dim: int = 512,
131
- dim : int = 3) -> None:
132
- super().__init__(optimizer=optimizer, in_channels=channels[0], schedulers=schedulers, patch=patch, outputsCriterions=outputsCriterions, dim=dim, nb_batch_per_step=nb_batch_per_step)
133
-
134
- self.add_module("MappingNetwork", MappingNetwork(z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_layers=8, embed_features=w_dim, layer_features=w_dim), in_branch=[1,2], out_branch=["Style"])
248
+ def __init__(
249
+ self,
250
+ optimizer: network.OptimizerLoader = network.OptimizerLoader(),
251
+ schedulers: dict[str, network.LRSchedulersLoader] = {
252
+ "default:ReduceLROnPlateau": network.LRSchedulersLoader(0)
253
+ },
254
+ patch: ModelPatch = ModelPatch(),
255
+ outputs_criterions: dict[str, network.TargetCriterionsLoader] = {"default": network.TargetCriterionsLoader()},
256
+ channels: list[int] = [1, 64, 128, 256, 512, 1024],
257
+ nb_batch_per_step: int = 64,
258
+ z_dim: int = 512,
259
+ c_dim: int = 1,
260
+ w_dim: int = 512,
261
+ dim: int = 3,
262
+ ) -> None:
263
+ super().__init__(
264
+ optimizer=optimizer,
265
+ in_channels=channels[0],
266
+ schedulers=schedulers,
267
+ patch=patch,
268
+ outputs_criterions=outputs_criterions,
269
+ dim=dim,
270
+ nb_batch_per_step=nb_batch_per_step,
271
+ )
272
+
273
+ self.add_module(
274
+ "MappingNetwork",
275
+ MappingNetwork(
276
+ z_dim=z_dim,
277
+ c_dim=c_dim,
278
+ w_dim=w_dim,
279
+ num_layers=8,
280
+ embed_features=w_dim,
281
+ layer_features=w_dim,
282
+ ),
283
+ in_branch=[1, 2],
284
+ out_branch=["Style"],
285
+ )
135
286
  nb_conv_per_stage = 2
136
- blockConfig = blocks.BlockConfig(kernel_size=3, stride=1, padding=1, bias=True, activation="ReLU", normMode="INSTANCE")
137
- self.add_module("UNetBlock_0", UNetBlock(w_dim, channels, nb_conv_per_stage, blockConfig, downSampleMode=blocks.DownSampleMode.MAXPOOL, upSampleMode=blocks.UpSampleMode.CONV_TRANSPOSE, attention=False, dim=dim), in_branch=[0, "Style"])
138
- self.add_module("Head", Generator.GeneratorHead(in_channels=channels[1], out_channels=1, dim=dim))
287
+ block_config = blocks.BlockConfig(
288
+ kernel_size=3,
289
+ stride=1,
290
+ padding=1,
291
+ bias=True,
292
+ activation="ReLU",
293
+ norm_mode="INSTANCE",
294
+ )
295
+ self.add_module(
296
+ "UNetBlock_0",
297
+ UNetBlock(
298
+ w_dim,
299
+ channels,
300
+ nb_conv_per_stage,
301
+ block_config,
302
+ downsample_mode=blocks.DownsampleMode.MAXPOOL,
303
+ upsample_mode=blocks.UpsampleMode.CONV_TRANSPOSE,
304
+ attention=False,
305
+ dim=dim,
306
+ ),
307
+ in_branch=[0, "Style"],
308
+ )
309
+ self.add_module(
310
+ "Head",
311
+ Generator.GeneratorHead(in_channels=channels[1], out_channels=1, dim=dim),
312
+ )