konfai 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of konfai might be problematic. Click here for more details.
- konfai/__init__.py +16 -0
- konfai/data/HDF5.py +326 -0
- konfai/data/__init__.py +0 -0
- konfai/data/augmentation.py +597 -0
- konfai/data/dataset.py +470 -0
- konfai/data/transform.py +536 -0
- konfai/evaluator.py +146 -0
- konfai/main.py +43 -0
- konfai/metric/__init__.py +0 -0
- konfai/metric/measure.py +488 -0
- konfai/metric/schedulers.py +49 -0
- konfai/models/classification/convNeXt.py +175 -0
- konfai/models/classification/resnet.py +116 -0
- konfai/models/generation/cStyleGan.py +137 -0
- konfai/models/generation/ddpm.py +218 -0
- konfai/models/generation/diffusionGan.py +557 -0
- konfai/models/generation/gan.py +134 -0
- konfai/models/generation/vae.py +72 -0
- konfai/models/registration/registration.py +136 -0
- konfai/models/representation/representation.py +57 -0
- konfai/models/segmentation/NestedUNet.py +53 -0
- konfai/models/segmentation/UNet.py +58 -0
- konfai/network/__init__.py +0 -0
- konfai/network/blocks.py +348 -0
- konfai/network/network.py +950 -0
- konfai/predictor.py +366 -0
- konfai/trainer.py +330 -0
- konfai/utils/ITK.py +269 -0
- konfai/utils/Registration.py +199 -0
- konfai/utils/__init__.py +0 -0
- konfai/utils/config.py +218 -0
- konfai/utils/dataset.py +764 -0
- konfai/utils/utils.py +493 -0
- konfai-1.0.0.dist-info/METADATA +68 -0
- konfai-1.0.0.dist-info/RECORD +39 -0
- konfai-1.0.0.dist-info/WHEEL +5 -0
- konfai-1.0.0.dist-info/entry_points.txt +3 -0
- konfai-1.0.0.dist-info/licenses/LICENSE +201 -0
- konfai-1.0.0.dist-info/top_level.txt +1 -0
konfai/network/blocks.py
ADDED
|
@@ -0,0 +1,348 @@
|
|
|
1
|
+
from enum import Enum
|
|
2
|
+
import importlib
|
|
3
|
+
from typing import Callable
|
|
4
|
+
import torch
|
|
5
|
+
from scipy.interpolate import interp1d
|
|
6
|
+
import numpy as np
|
|
7
|
+
import ast
|
|
8
|
+
from typing import Union
|
|
9
|
+
from KonfAI.konfai.utils.config import config
|
|
10
|
+
from KonfAI.konfai.network import network
|
|
11
|
+
|
|
12
|
+
class NormMode(Enum):
    """Normalization layer selector consumed by ``getNorm`` / ``BlockConfig``.

    Every member has a plain integer value (no stray trailing commas, which would
    turn a value into a 1-tuple such as ``(0,)``).
    """

    NONE = 0
    BATCH = 1
    INSTANCE = 2
    GROUP = 3
    LAYER = 4
    SYNCBATCH = 5
    INSTANCE_AFFINE = 6
|
|
20
|
+
|
|
21
|
+
def getNorm(normMode: Enum, channels : int, dim: int) -> torch.nn.Module:
    """Instantiate the normalization layer selected by ``normMode``.

    Args:
        normMode: a ``NormMode`` member. Anything that matches none of the known
            modes (including ``NormMode.NONE``) yields ``torch.nn.Identity``.
        channels: number of feature channels to normalize.
        dim: spatial dimensionality; selects ``BatchNorm{dim}d`` / ``InstanceNorm{dim}d``.

    Returns:
        A freshly constructed ``torch.nn`` module.
    """
    if normMode == NormMode.BATCH:
        return getTorchModule("BatchNorm", dim=dim)(channels, affine=True, track_running_stats=True)
    if normMode == NormMode.SYNCBATCH:
        return torch.nn.SyncBatchNorm(channels, affine=True, track_running_stats=True)
    if normMode in (NormMode.INSTANCE, NormMode.INSTANCE_AFFINE):
        # Instance norm never tracks running statistics; only the *_AFFINE variant
        # learns a per-channel scale and shift.
        instance_norm = getTorchModule("InstanceNorm", dim=dim)
        return instance_norm(channels, affine=(normMode == NormMode.INSTANCE_AFFINE), track_running_stats=False)
    if normMode in (NormMode.GROUP, NormMode.LAYER):
        # LAYER is expressed as a single-group GroupNorm. GROUP uses 32 groups, so
        # ``channels`` must be divisible by 32 for that mode.
        group_count = 32 if normMode == NormMode.GROUP else 1
        return torch.nn.GroupNorm(num_groups=group_count, num_channels=channels)
    return torch.nn.Identity()
|
|
35
|
+
|
|
36
|
+
class UpSampleMode(Enum):
    """Upsampling strategy used by ``upSample``.

    ``CONV_TRANSPOSE`` selects a transposed convolution. Every ``UPSAMPLE_*`` member
    maps to ``torch.nn.Upsample`` whose ``mode`` is the lower-cased suffix of the
    member name (e.g. ``UPSAMPLE_TRILINEAR`` -> ``"trilinear"``).
    Values are plain integers.
    """

    CONV_TRANSPOSE = 0
    UPSAMPLE_NEAREST = 1
    UPSAMPLE_LINEAR = 2
    UPSAMPLE_BILINEAR = 3
    UPSAMPLE_BICUBIC = 4
    UPSAMPLE_TRILINEAR = 5
|
|
43
|
+
|
|
44
|
+
class DownSampleMode(Enum):
    """Downsampling strategy used by ``downSample``.

    ``MAXPOOL`` / ``AVGPOOL`` use a pooling window of 2; ``CONV_STRIDE`` uses a
    convolution with kernel 2 and stride 2. Values are plain integers.
    """

    MAXPOOL = 0
    AVGPOOL = 1
    CONV_STRIDE = 2
|
|
48
|
+
|
|
49
|
+
def getTorchModule(name_fonction : str, dim : Union[int, None] = None) -> type[torch.nn.Module]:
    """Look up a layer *class* in ``torch.nn`` by name.

    Args:
        name_fonction: base class name, e.g. ``"Conv"``, ``"BatchNorm"``, ``"ReLU"``.
        dim: when given, ``"{dim}d"`` is appended (``"Conv"`` with ``dim=3`` -> ``Conv3d``).

    Returns:
        The uninstantiated ``torch.nn`` class; callers construct it themselves.

    Raises:
        AttributeError: if ``torch.nn`` has no attribute with the resolved name.
    """
    suffix = f"{dim}d" if dim is not None else ""
    return getattr(importlib.import_module("torch.nn"), f"{name_fonction}{suffix}")
|
|
51
|
+
|
|
52
|
+
class BlockConfig():
    """Per-convolution settings (conv hyper-parameters, normalization, activation)."""

    @config("BlockConfig")
    def __init__(self, kernel_size : int = 3, stride : int = 1, padding : int = 1, bias = True, activation : Union[str, Callable[[], torch.nn.Module]] = "ReLU", normMode : Union[str, NormMode, Callable[[int], torch.nn.Module]] = "NONE") -> None:
        """Store the block settings.

        Args:
            kernel_size, stride, padding, bias: forwarded to the ``Conv{dim}d`` built by ``getConv``.
            activation: either a ``torch.nn`` class name optionally followed by
                ``;``-separated literal arguments (e.g. ``"LeakyReLU;0.2"``), the string
                ``"None"`` for no activation, or a zero-argument factory returning a module.
            normMode: a ``NormMode`` name, a ``NormMode`` member, or a factory
                ``channels -> torch.nn.Module``.

        Raises:
            KeyError: if ``normMode`` is a string that is not a ``NormMode`` member name.
        """
        self.kernel_size = kernel_size
        self.bias = bias
        self.stride = stride
        self.padding = padding
        self.activation = activation
        self.normMode = normMode

        if isinstance(normMode, str):
            self.norm = NormMode[normMode]
        else:
            # NormMode members are used as-is; anything else is treated as a
            # ``channels -> module`` factory (see getNorm). It must be stored too,
            # otherwise getNorm would fail with AttributeError on ``self.norm``.
            self.norm = normMode

    def getConv(self, in_channels : int, out_channels : int, dim : int) -> torch.nn.Module:
        """Build a ``Conv{dim}d`` layer using this block's kernel/stride/padding/bias."""
        return getTorchModule("Conv", dim = dim)(in_channels = in_channels, out_channels = out_channels, kernel_size = self.kernel_size, stride = self.stride, padding = self.padding, bias=self.bias)

    def getNorm(self, channels : int, dim: int) -> torch.nn.Module:
        """Build the normalization layer for ``channels`` features."""
        # Name resolution inside a method skips the class body, so this calls the
        # module-level getNorm function, not this method.
        return getNorm(self.norm, channels, dim) if isinstance(self.norm, NormMode) else self.norm(channels)

    def getActivation(self) -> torch.nn.Module:
        """Build the activation layer described by ``self.activation``."""
        if not isinstance(self.activation, str):
            return self.activation()
        if self.activation == "None":
            return torch.nn.Identity()
        # "Name;arg1;arg2" -> torch.nn.Name(arg1, arg2), each arg parsed as a Python literal.
        name, *raw_args = self.activation.split(";")
        return getTorchModule(name)(*[ast.literal_eval(value) for value in raw_args])
|
|
78
|
+
|
|
79
|
+
class ConvBlock(network.ModuleArgsDict):
    """Stack of ``Conv -> Norm -> Activation`` stages, one per entry of ``blockConfigs``.

    The first convolution maps ``in_channels`` to ``out_channels``; every later one
    maps ``out_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels : int, out_channels : int, blockConfigs : list[BlockConfig], dim : int, alias : Union[list[list[str]], None] = None) -> None:
        """Args:
            in_channels: channels of the block input.
            out_channels: channels produced by every convolution.
            blockConfigs: one ``BlockConfig`` per Conv/Norm/Activation stage.
            dim: spatial dimensionality (selects ``Conv{dim}d`` etc.).
            alias: three lists of alias names, forwarded to ``add_module`` for the
                Conv, Norm and Activation layers respectively. Defaults to three empty lists.
        """
        super().__init__()
        if alias is None:
            # Fresh lists per call: a list default would be shared by every ConvBlock
            # built without an explicit alias.
            alias = [[], [], []]
        for i, blockConfig in enumerate(blockConfigs):
            self.add_module("Conv_{}".format(i), blockConfig.getConv(in_channels, out_channels, dim), alias=alias[0])
            self.add_module("Norm_{}".format(i), blockConfig.getNorm(out_channels, dim), alias=alias[1])
            self.add_module("Activation_{}".format(i), blockConfig.getActivation(), alias=alias[2])
            in_channels = out_channels
|
|
88
|
+
|
|
89
|
+
class ResBlock(network.ModuleArgsDict):
    """``nb_conv`` residual units, each ``ConvBlock`` (branch 0) + skip (branch 1) -> Add.

    NOTE(review): assumes branch 1 holds the unit input when the block starts; how
    branches are initialised is defined by network.ModuleArgsDict (not visible here).
    """

    def __init__(self, in_channels : int, out_channels : int, nb_conv: int, blockConfig : BlockConfig, dim : int, alias : Union[list[list[str]], None] = None) -> None:
        """Args:
            in_channels: channels of the block input.
            out_channels: channels of every residual unit's output.
            nb_conv: number of residual units.
            blockConfig: settings shared by every convolution in the block.
            dim: spatial dimensionality.
            alias: four alias lists: Conv, Norm and Activation (forwarded to each
                ConvBlock), then the skip convolution. Defaults to four empty lists.
        """
        super().__init__()
        if alias is None:
            alias = [[], [], [], []]
        for i in range(nb_conv):
            # ConvBlock expects a *list* of BlockConfig (one per conv stage); each
            # residual unit holds a single stage.
            self.add_module("ConvBlock_{}".format(i), ConvBlock(in_channels, out_channels, [blockConfig], dim=dim, alias=alias[:3]))
            if in_channels != out_channels:
                # Project the skip branch to out_channels so it can be summed with branch 0.
                self.add_module("Conv_skip_{}".format(i), blockConfig.getConv(in_channels, out_channels, dim), alias=alias[3], in_branch=[1], out_branch=[1])
            in_channels = out_channels
            # Sum both branches and write the result back to both, so the next unit's
            # skip branch starts from this unit's output.
            self.add_module("Add_{}".format(i), Add(), in_branch=[0,1], out_branch=[0,1])
|
|
99
|
+
|
|
100
|
+
def downSample(in_channels: int, out_channels: int, downSampleMode: DownSampleMode, dim: int) -> torch.nn.Module:
    """Build a layer that downsamples every spatial axis by a factor of 2.

    Args:
        in_channels, out_channels: only used by ``CONV_STRIDE``; the pooling modes
            keep the channel count unchanged.
        downSampleMode: which strategy to use.
        dim: spatial dimensionality.

    Returns:
        ``MaxPool{dim}d(2)``, ``AvgPool{dim}d(2)`` or a kernel-2 / stride-2 ``Conv{dim}d``.

    Raises:
        ValueError: if ``downSampleMode`` is not a ``DownSampleMode`` member
            (instead of silently returning None).
    """
    if downSampleMode == DownSampleMode.MAXPOOL:
        return getTorchModule("MaxPool", dim = dim)(2)
    if downSampleMode == DownSampleMode.AVGPOOL:
        return getTorchModule("AvgPool", dim = dim)(2)
    if downSampleMode == DownSampleMode.CONV_STRIDE:
        return getTorchModule("Conv", dim = dim)(in_channels, out_channels, kernel_size=2, stride=2, padding=0)
    raise ValueError("Unsupported downSampleMode: {!r}".format(downSampleMode))
|
|
107
|
+
|
|
108
|
+
def upSample(in_channels: int, out_channels: int, upSampleMode: UpSampleMode, dim: int, kernel_size: Union[int, list[int]] = 2, stride: Union[int, list[int]] = 2) -> torch.nn.Module:
    """Build an upsampling layer.

    Args:
        in_channels, out_channels: channel counts for ``CONV_TRANSPOSE``.
        upSampleMode: ``CONV_TRANSPOSE`` or one of the ``UPSAMPLE_*`` interpolation modes.
        dim: spatial dimensionality (for the transposed convolution).
        kernel_size, stride: transposed-convolution geometry; ignored by the
            interpolation modes, which always use ``scale_factor=2``.

    Returns:
        ``ConvTranspose{dim}d`` or ``torch.nn.Upsample``.
    """
    if upSampleMode == UpSampleMode.CONV_TRANSPOSE:
        return getTorchModule("ConvTranspose", dim = dim)(in_channels = in_channels, out_channels = out_channels, kernel_size = kernel_size, stride = stride, padding = 0)
    mode = upSampleMode.name.replace("UPSAMPLE_", "").lower()
    # torch raises at forward time if align_corners is set (even to False) for
    # "nearest"; it is only meaningful for the interpolating modes.
    align_corners = None if mode == "nearest" else False
    return torch.nn.Upsample(scale_factor=2, mode=mode, align_corners=align_corners)
|
|
113
|
+
|
|
114
|
+
class Unsqueeze(torch.nn.Module):
    """Insert a size-1 axis at position ``dim``."""

    def __init__(self, dim: int = 0):
        super().__init__()
        self.dim = dim

    def forward(self, input : torch.Tensor) -> torch.Tensor:
        # Takes a single tensor: a ``*input`` signature would pack it into a tuple,
        # which torch.unsqueeze rejects.
        return torch.unsqueeze(input, self.dim)

    def extra_repr(self):
        return "dim={}".format(self.dim)
|
|
125
|
+
|
|
126
|
+
class Permute(torch.nn.Module):
    """Reorder tensor axes following ``dims`` (same contract as ``torch.permute``)."""

    def __init__(self, dims : list[int]):
        super().__init__()
        self.dims = dims

    def forward(self, input : torch.Tensor) -> torch.Tensor:
        return input.permute(self.dims)

    def extra_repr(self):
        return f"dims={self.dims}"

class ToChannels(Permute):
    """Move the last axis to position 1 for a tensor of rank ``dim + 2``.

    E.g. ``dim=2``: ``(N, H, W, C) -> (N, C, H, W)``.
    """

    def __init__(self, dim):
        super().__init__([0, dim + 1, *range(1, dim + 1)])

class ToFeatures(Permute):
    """Move axis 1 to the last position for a tensor of rank ``dim + 2``.

    E.g. ``dim=2``: ``(N, C, H, W) -> (N, H, W, C)``; the inverse of ``ToChannels``.
    """

    def __init__(self, dim):
        super().__init__([0, *range(2, dim + 2), 1])
|
|
147
|
+
|
|
148
|
+
class Add(torch.nn.Module):
    """Element-wise sum of every positional tensor input.

    Inputs are stacked first, so they must all share exactly the same shape.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, *input : torch.Tensor) -> torch.Tensor:
        stacked = torch.stack(input)
        return stacked.sum(dim=0)
|
|
155
|
+
|
|
156
|
+
class Multiply(torch.nn.Module):
    """Element-wise product of every positional tensor input (broadcasting like ``torch.mul``).

    Accepts any number (>= 1) of inputs, mirroring ``Add``; with two inputs the
    result is exactly ``torch.mul(a, b)``.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, *input : torch.Tensor) -> torch.Tensor:
        product = input[0]
        for factor in input[1:]:
            product = torch.mul(product, factor)
        return product
|
|
163
|
+
|
|
164
|
+
class Concat(torch.nn.Module):
    """Concatenate every positional tensor input along axis 1."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, *input : torch.Tensor) -> torch.Tensor:
        axis = 1
        return torch.cat(input, axis)
|
|
171
|
+
|
|
172
|
+
class Print(torch.nn.Module):
    """Debug pass-through: print the input's shape and return the input unchanged."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        shape = input.shape
        print(shape)
        return input
|
|
180
|
+
|
|
181
|
+
class Write(torch.nn.Module):
    """Debug pass-through: write ``input[0][0]`` to ``./Data.mha`` and return the input unchanged."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        import SimpleITK as sitk
        # detach() is required: Tensor.numpy() refuses tensors that require grad,
        # which is the case for activations inside a model being trained.
        volume = input.detach()[0][0].cpu().numpy()
        sitk.WriteImage(sitk.GetImageFromArray(volume), "./Data.mha")
        return input
|
|
190
|
+
|
|
191
|
+
class Exit(torch.nn.Module):
    """Debug helper: terminate the process (exit status 0) when the forward pass reaches it."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        import sys
        # sys.exit rather than the ``exit`` builtin: the latter is injected by the
        # ``site`` module for interactive use and is absent under ``python -S``.
        sys.exit(0)
|
|
198
|
+
|
|
199
|
+
class Detach(torch.nn.Module):
    """Cut the autograd graph: return a view of the input that does not require grad."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        detached = input.detach()
        return detached
|
|
206
|
+
|
|
207
|
+
class Negative(torch.nn.Module):
    """Element-wise negation of the input."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return torch.neg(input)
|
|
214
|
+
|
|
215
|
+
class GetShape(torch.nn.Module):
    """Return the input's shape as a 1-D tensor."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        sizes = list(input.shape)
        return torch.tensor(sizes)
|
|
222
|
+
|
|
223
|
+
class ArgMax(torch.nn.Module):
    """Index of the maximum along ``dim``, keeping ``dim`` as a size-1 axis."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.dim = dim

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return input.argmax(dim=self.dim, keepdim=True)
|
|
231
|
+
|
|
232
|
+
class Select(torch.nn.Module):
    """Index the input with ``slices`` (one per leading axis), then squeeze axis 1 if it has size 1."""

    def __init__(self, slices: list[slice]) -> None:
        super().__init__()
        self.slices = tuple(slices)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        selected = input[self.slices]
        # NOTE(review): only axis 1 is ever squeezed (a no-op unless that axis has
        # size 1), and only when the result has at least 2 dims. The intent may have
        # been to squeeze every size-1 axis; confirm with callers before changing.
        if selected.dim() >= 2:
            selected = selected.squeeze(dim=1)
        return selected
|
|
244
|
+
|
|
245
|
+
class NormalNoise(torch.nn.Module):
    """Emit standard-normal noise.

    With ``dim`` set, returns a fresh ``torch.randn(dim)`` moved to the input's
    device; otherwise returns noise shaped (and typed) like the input.
    """

    def __init__(self, dim: Union[int, None] = None) -> None:
        super().__init__()
        self.dim = dim

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if self.dim is None:
            # randn_like already allocates on input.device with input.dtype.
            return torch.randn_like(input)
        noise = torch.randn(self.dim)
        return noise.to(input.device)
|
|
256
|
+
|
|
257
|
+
class Const(torch.nn.Module):
    """Learnable tensor of shape ``shape``, initialised as ``N(0, 1) * std``; the input only supplies the device."""

    def __init__(self, shape: list[int], std: float) -> None:
        super().__init__()
        initial = torch.randn(shape) * std
        self.noise = torch.nn.parameter.Parameter(initial)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        target_device = input.device
        return self.noise.to(target_device)
|
|
265
|
+
|
|
266
|
+
class HistogramNoise(torch.nn.Module):
    """Remap intensities through a randomly perturbed cubic curve on [0, 1].

    ``n`` evenly spaced control points on [0, 1] are each jittered by ``N(0, sigma)``
    (numpy global RNG) on every forward call, and the input is passed through the
    cubic interpolant of those points.
    """

    def __init__(self, n: int, sigma: float) -> None:
        super().__init__()
        # Control-point abscissae; scipy's cubic interp1d needs at least 4 of them.
        self.x = np.linspace(0, 1, num=n, endpoint=True)
        self.sigma = sigma

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply a freshly drawn curve to every element of ``input``.

        Returns:
            A tensor with the input's shape, dtype and device.

        Raises:
            ValueError: (from interp1d) if any value lies outside [0, 1].
        """
        self.function = interp1d(self.x, self.x+np.random.normal(0, self.sigma, self.x.shape[0]), kind='cubic')
        # interp1d is vectorised: one call over the whole array instead of one call
        # plus a full-tensor mask per unique value. detach() lets this run on
        # tensors that require grad.
        mapped = self.function(input.detach().cpu().numpy())
        # Same cast chain as element-wise assignment: float64 -> float32 -> input dtype.
        return torch.as_tensor(mapped, device=input.device).float().to(input.dtype)
|
|
281
|
+
|
|
282
|
+
class Subset(torch.nn.Module):
    """Crop the trailing axes with ``slices`` while keeping the first two axes whole.

    (The first two axes are presumably batch and channel; TODO confirm with callers.)
    """

    def __init__(self, slices: list[slice]):
        super().__init__()
        # A tuple is the standard multi-axis index. Indexing with a *list* of slices
        # relies on a legacy sequence-as-tuple heuristic. list(...) also accepts a
        # tuple argument, which plain list concatenation would reject.
        self.slices = (slice(None, None), slice(None, None)) + tuple(slices)

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        return tensor[self.slices]
|
|
289
|
+
|
|
290
|
+
class View(torch.nn.Module):
    """Reshape the input with ``Tensor.view(size)`` (no copy; input must be view-compatible)."""

    def __init__(self, size: list[int]):
        super().__init__()
        self.size = size

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        reshaped = tensor.view(self.size)
        return reshaped
|
|
297
|
+
|
|
298
|
+
class LatentDistribution(network.ModuleArgsDict):
    """Gaussian latent bottleneck with the reparameterisation trick.

    Flattened features feed two linear heads, ``mu`` (written to branch 1) and
    ``log_std`` (branch 2). ``z`` (branch 3) is sampled from them. ``Concat`` joins
    (mu, log_std, z) along axis 1, and ``DecoderInput`` maps ``z`` back to a tensor of
    ``shape``. NOTE(review): the Concat output is presumably consumed by name
    elsewhere (e.g. a KL loss); default in/out branches come from
    network.ModuleArgsDict. TODO confirm.
    """

    class LatentDistribution_Linear(torch.nn.Module):
        """Linear head: flattened features -> latent vector, with an extra axis inserted at dim 1."""

        def __init__(self, shape: list[int], latentDim: int) -> None:
            super().__init__()
            # torch.nn.Linear expects an int feature count, not a 0-d tensor.
            self.linear = torch.nn.Linear(int(np.prod([int(s) for s in shape])), latentDim)

        def forward(self, input: torch.Tensor) -> torch.Tensor:
            return torch.unsqueeze(self.linear(input), 1)

    class LatentDistribution_Decoder(torch.nn.Module):
        """Linear map from the latent vector back to a tensor of shape ``(-1, *shape)``."""

        def __init__(self, shape: list[int], latentDim: int) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(latentDim, int(np.prod([int(s) for s in shape])))
            self.shape = shape

        def forward(self, input: torch.Tensor) -> torch.Tensor:
            return self.linear(input).view(-1, *[int(i) for i in self.shape])

    class LatentDistribution_Z(torch.nn.Module):
        """Draw ``z = mu + exp(log_std / 2) * eps`` with ``eps ~ N(0, I)``."""

        def __init__(self) -> None:
            super().__init__()

        def forward(self, mu: torch.Tensor, log_std: torch.Tensor) -> torch.Tensor:
            # The reparameterisation trick needs *standard normal* noise (randn_like),
            # not uniform noise. exp(log_std / 2) is the standard deviation when the
            # second input holds a log-variance.
            return torch.exp(log_std/2)*torch.randn_like(mu)+mu

    def __init__(self, shape: list[int], latentDim: int) -> None:
        """Args:
            shape: shape of one encoder feature map (without batch axis); its product
                is the flattened feature count.
            latentDim: size of the latent vector.
        """
        super().__init__()
        self.add_module("Flatten", torch.nn.Flatten(1))
        self.add_module("mu", LatentDistribution.LatentDistribution_Linear(shape, latentDim), out_branch = [1])
        self.add_module("log_std", LatentDistribution.LatentDistribution_Linear(shape, latentDim), out_branch = [2])

        self.add_module("z", LatentDistribution.LatentDistribution_Z(), in_branch=[1,2], out_branch=[3])
        self.add_module("Concat", Concat(), in_branch=[1,2,3])
        self.add_module("DecoderInput", LatentDistribution.LatentDistribution_Decoder(shape, latentDim), in_branch=[3])
|
|
336
|
+
|
|
337
|
+
class Attention(network.ModuleArgsDict):
    """Additive attention gate.

    The branch-0 input (``F_l`` channels, projected with stride 2) and the branch-1
    input (``F_g`` channels) are projected to ``F_int`` channels and summed. The sum
    goes through ReLU, a 1-channel 1x1 conv and a sigmoid, and is upsampled x2. The
    result is multiplied with branch 2.
    NOTE(review): branch 2 is presumably filled by the caller with the original
    branch-0 features; TODO confirm.
    """

    def __init__(self, F_g : int, F_l : int, F_int : int, dim : int):
        super().__init__()
        conv = getTorchModule("Conv", dim = dim)
        self.add_module("W_x", conv(in_channels = F_l, out_channels = F_int, kernel_size=1, stride=2, padding=0), in_branch=[0], out_branch=[0])
        self.add_module("W_g", conv(in_channels = F_g, out_channels = F_int, kernel_size=1, stride=1, padding=0), in_branch=[1], out_branch=[1])
        self.add_module("Add", Add(), in_branch=[0,1])
        self.add_module("ReLU", torch.nn.ReLU(inplace=True))
        # Collapse to a single-channel map of coefficients in (0, 1).
        self.add_module("Conv", conv(in_channels = F_int, out_channels = 1, kernel_size=1, stride=1, padding=0))
        self.add_module("Sigmoid", torch.nn.Sigmoid())
        # Undo the stride-2 of W_x so the coefficients match branch 2's resolution.
        self.add_module("Upsample", torch.nn.Upsample(scale_factor=2))
        self.add_module("Multiply", Multiply(), in_branch=[2,0])
|