konfai 1.1.7__py3-none-any.whl → 1.1.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- konfai/__init__.py +59 -14
- konfai/data/augmentation.py +457 -286
- konfai/data/data_manager.py +509 -290
- konfai/data/patching.py +300 -183
- konfai/data/transform.py +384 -277
- konfai/evaluator.py +309 -68
- konfai/main.py +71 -22
- konfai/metric/measure.py +341 -222
- konfai/metric/schedulers.py +24 -13
- konfai/models/classification/convNeXt.py +187 -81
- konfai/models/classification/resnet.py +272 -58
- konfai/models/generation/cStyleGan.py +233 -59
- konfai/models/generation/ddpm.py +348 -121
- konfai/models/generation/diffusionGan.py +757 -358
- konfai/models/generation/gan.py +177 -53
- konfai/models/generation/vae.py +140 -40
- konfai/models/registration/registration.py +135 -52
- konfai/models/representation/representation.py +57 -23
- konfai/models/segmentation/NestedUNet.py +339 -68
- konfai/models/segmentation/UNet.py +140 -30
- konfai/network/blocks.py +331 -187
- konfai/network/network.py +781 -423
- konfai/predictor.py +645 -240
- konfai/trainer.py +527 -216
- konfai/utils/ITK.py +191 -106
- konfai/utils/config.py +152 -95
- konfai/utils/dataset.py +326 -455
- konfai/utils/utils.py +495 -249
- {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/METADATA +1 -3
- konfai-1.1.9.dist-info/RECORD +38 -0
- konfai/utils/registration.py +0 -199
- konfai-1.1.7.dist-info/RECORD +0 -39
- {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/WHEEL +0 -0
- {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/entry_points.txt +0 -0
- {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/licenses/LICENSE +0 -0
- {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/top_level.txt +0 -0
@@ -1,27 +1,37 @@
 import importlib
+
 import torch
 
-from konfai.
-from konfai.
-
+from konfai.data.patching import ModelPatch
+from konfai.network import blocks, network
+
 
 class MappingNetwork(network.ModuleArgsDict):
-    def __init__(
+    def __init__(
+        self,
+        z_dim: int,
+        c_dim: int,
+        w_dim: int,
+        num_layers: int,
+        embed_features: int,
+        layer_features: int,
+    ):
         super().__init__()
-
-        self.add_module("Concat_1", blocks.Concat(), in_branch=[0,1])
-
-        features = [z_dim + embed_features if c_dim > 0 else 0] + [layer_features] * (num_layers - 1) + [w_dim]
+
+        self.add_module("Concat_1", blocks.Concat(), in_branch=[0, 1])
+
+        features = [z_dim + embed_features if c_dim > 0 else 0] + [layer_features] * (num_layers - 1) + [w_dim]
         if c_dim > 0:
             self.add_module("Linear", torch.nn.Linear(c_dim, embed_features), out_branch=["Embed"])
 
         self.add_module("Noise", blocks.NormalNoise(z_dim), in_branch=["Embed"])
         if c_dim > 0:
-            self.add_module("Concat", blocks.Concat(), in_branch=[0,"Embed"])
-
+            self.add_module("Concat", blocks.Concat(), in_branch=[0, "Embed"])
+
         for i, (in_features, out_features) in enumerate(zip(features, features[1:])):
-            self.add_module("Linear_{}"
-
+            self.add_module(f"Linear_{i}", torch.nn.Linear(in_features, out_features))
+
+
 class ModulatedConv(torch.nn.Module):
 
     class _ModulatedConv(torch.nn.Module):
@@ -35,56 +45,88 @@ class ModulatedConv(torch.nn.Module):
             self.padding = conv.padding
             self.stride = conv.stride
             if isinstance(conv, torch.nn.modules.conv._ConvTransposeNd):
-                self.weight = torch.nn.parameter.Parameter(
+                self.weight = torch.nn.parameter.Parameter(
+                    torch.randn((conv.in_channels, conv.out_channels, *conv.kernel_size))
+                )
                 self.isConv = False
             else:
-                self.weight = torch.nn.parameter.Parameter(
+                self.weight = torch.nn.parameter.Parameter(
+                    torch.randn((conv.out_channels, conv.in_channels, *conv.kernel_size))
+                )
             conv.forward = self.forward
             self.styles = None
             self.dim = dim
 
-        def
+        def set_style(self, styles: torch.Tensor) -> None:
             self.styles = styles
 
-        def forward(self,
-            b =
-            self.affine.to(
+        def forward(self, tensor: torch.Tensor) -> torch.Tensor:
+            b = tensor.shape[0]
+            self.affine.to(tensor.device)
             styles = self.affine(self.styles)
-            w1 =
-
+            w1 = (
+                styles.reshape(b, -1, 1, *[1 for _ in range(self.dim)])
+                if not self.isConv
+                else styles.reshape(b, 1, -1, *[1 for _ in range(self.dim)])
+            )
+            w2 = self.weight.unsqueeze(0).to(tensor.device)
             weights = w2 * (w1 + 1)
 
-            d = torch.rsqrt(
+            d = torch.rsqrt(
+                (weights**2).sum(
+                    dim=tuple([i + 2 for i in range(len(weights.shape) - 2)]),
+                    keepdim=True,
+                )
+                + 1e-8
+            )
             weights = weights * d
-
-
+
+            tensor = tensor.reshape(1, -1, *tensor.shape[2:])
 
             _, _, *ws = weights.shape
             if not self.isConv:
-                out = getattr(
+                out = getattr(
+                    importlib.import_module("torch.nn.functional"),
+                    f"conv_transpose{self.dim}d",
+                )(
+                    tensor,
+                    weights.reshape(b * self.in_channels, *ws),
+                    stride=self.stride,
+                    padding=self.padding,
+                    groups=b,
+                )
             else:
-                out = getattr(
-
+                out = getattr(
+                    importlib.import_module("torch.nn.functional"),
+                    f"conv{self.dim}d",
+                )(
+                    tensor,
+                    weights.reshape(b * self.out_channels, *ws),
+                    padding=self.padding,
+                    groups=b,
+                    stride=self.stride,
+                )
+
             out = out.reshape(-1, self.out_channels, *out.shape[2:])
             return out
-
+
     def __init__(self, w_dim: int, module: torch.nn.Module) -> None:
         super().__init__()
         self.w_dim = w_dim
         self.module = module
-        self.convs = torch.nn.ModuleList()
+        self.convs = torch.nn.ModuleList()
         self.module.apply(self.apply)
-
-    def forward(self,
+
+    def forward(self, tensor: torch.Tensor, styles: torch.Tensor) -> torch.Tensor:
         for conv in self.convs:
-            conv.
-        return self.module(
+            conv.set_style(styles.clone())
+        return self.module(tensor)
 
     def apply(self, module: torch.nn.Module):
         if isinstance(module, torch.nn.modules.conv._ConvNd):
             delattr(module, "weight")
             module.bias = None
-
+
             str_dim = module.__class__.__name__[-2:]
             dim = 1
             if str_dim == "2d":
@@ -92,47 +134,179 @@ class ModulatedConv(torch.nn.Module):
             elif str_dim == "3d":
                 dim = 3
             self.convs.append(ModulatedConv._ModulatedConv(self.w_dim, module, dim=dim))
-
+
+
 class UNetBlock(network.ModuleArgsDict):
 
-    def __init__(
+    def __init__(
+        self,
+        w_dim: int,
+        channels: list[int],
+        nb_conv_per_stage: int,
+        block_config: blocks.BlockConfig,
+        downsample_mode: blocks.DownsampleMode,
+        upsample_mode: blocks.UpsampleMode,
+        attention: bool,
+        dim: int,
+        i: int = 0,
+    ) -> None:
         super().__init__()
         if i > 0:
-            self.add_module(
-
+            self.add_module(
+                downsample_mode.name,
+                blocks.downsample(
+                    in_channels=channels[0],
+                    out_channels=channels[1],
+                    downsample_mode=downsample_mode,
+                    dim=dim,
+                ),
+            )
+        self.add_module(
+            "DownConvBlock",
+            blocks.ConvBlock(
+                in_channels=channels[(1 if downsample_mode == blocks.DownsampleMode.CONV_STRIDE and i > 0 else 0)],
+                out_channels=channels[1],
+                block_configs=[block_config] * nb_conv_per_stage,
+                dim=dim,
+            ),
+        )
         if len(channels) > 2:
-            self.add_module(
-
+            self.add_module(
+                f"UNetBlock_{i + 1}",
+                UNetBlock(
+                    w_dim,
+                    channels[1:],
+                    nb_conv_per_stage,
+                    block_config,
+                    downsample_mode,
+                    upsample_mode,
+                    attention,
+                    dim,
+                    i + 1,
+                ),
+                in_branch=[0, 1],
+            )
+            self.add_module(
+                "UpConvBlock",
+                ModulatedConv(
+                    w_dim,
+                    blocks.ConvBlock(
+                        (
+                            (channels[1] + channels[2])
+                            if upsample_mode != blocks.UpsampleMode.CONV_TRANSPOSE
+                            else channels[1] * 2
+                        ),
+                        out_channels=channels[1],
+                        block_configs=[block_config] * nb_conv_per_stage,
+                        dim=dim,
+                    ),
+                ),
+                in_branch=[0, 1],
+            )
         if i > 0:
             if attention:
-                self.add_module(
-
+                self.add_module(
+                    "Attention",
+                    blocks.Attention(f_g=channels[1], f_l=channels[0], f_int=channels[0], dim=dim),
+                    in_branch=["Skip", 0],
+                    out_branch=["Skip"],
+                )
+            self.add_module(
+                upsample_mode.name,
+                ModulatedConv(
+                    w_dim,
+                    blocks.upsample(
+                        in_channels=channels[1],
+                        out_channels=channels[0],
+                        upsample_mode=upsample_mode,
+                        dim=dim,
+                    ),
+                ),
+                in_branch=[0, 1],
+            )
             self.add_module("SkipConnection", blocks.Concat(), in_branch=[0, "Skip"])
 
+
 class Generator(network.Network):
 
     class GeneratorHead(network.ModuleArgsDict):
 
         def __init__(self, in_channels: int, out_channels: int, dim: int) -> None:
             super().__init__()
-            self.add_module(
+            self.add_module(
+                "Conv",
+                blocks.get_torch_module("Conv", dim)(
+                    in_channels=in_channels,
+                    out_channels=out_channels,
+                    kernel_size=3,
+                    stride=1,
+                    padding=1,
+                ),
+            )
             self.add_module("Tanh", torch.nn.Tanh())
 
-    def __init__(
-
-
-
-
-
-
-
-
-
-
-
-
-
+    def __init__(
+        self,
+        optimizer: network.OptimizerLoader = network.OptimizerLoader(),
+        schedulers: dict[str, network.LRSchedulersLoader] = {
+            "default:ReduceLROnPlateau": network.LRSchedulersLoader(0)
+        },
+        patch: ModelPatch = ModelPatch(),
+        outputs_criterions: dict[str, network.TargetCriterionsLoader] = {"default": network.TargetCriterionsLoader()},
+        channels: list[int] = [1, 64, 128, 256, 512, 1024],
+        nb_batch_per_step: int = 64,
+        z_dim: int = 512,
+        c_dim: int = 1,
+        w_dim: int = 512,
+        dim: int = 3,
+    ) -> None:
+        super().__init__(
+            optimizer=optimizer,
+            in_channels=channels[0],
+            schedulers=schedulers,
+            patch=patch,
+            outputs_criterions=outputs_criterions,
+            dim=dim,
+            nb_batch_per_step=nb_batch_per_step,
+        )
+
+        self.add_module(
+            "MappingNetwork",
+            MappingNetwork(
+                z_dim=z_dim,
+                c_dim=c_dim,
+                w_dim=w_dim,
+                num_layers=8,
+                embed_features=w_dim,
+                layer_features=w_dim,
+            ),
+            in_branch=[1, 2],
+            out_branch=["Style"],
+        )
         nb_conv_per_stage = 2
-
-
-
+        block_config = blocks.BlockConfig(
+            kernel_size=3,
+            stride=1,
+            padding=1,
+            bias=True,
+            activation="ReLU",
+            norm_mode="INSTANCE",
+        )
+        self.add_module(
+            "UNetBlock_0",
+            UNetBlock(
+                w_dim,
+                channels,
+                nb_conv_per_stage,
+                block_config,
+                downsample_mode=blocks.DownsampleMode.MAXPOOL,
+                upsample_mode=blocks.UpsampleMode.CONV_TRANSPOSE,
+                attention=False,
+                dim=dim,
+            ),
+            in_branch=[0, "Style"],
+        )
+        self.add_module(
+            "Head",
+            Generator.GeneratorHead(in_channels=channels[1], out_channels=1, dim=dim),
+        )