konfai 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of konfai might be problematic. Click here for more details.

@@ -0,0 +1,348 @@
1
+ from enum import Enum
2
+ import importlib
3
+ from typing import Callable
4
+ import torch
5
+ from scipy.interpolate import interp1d
6
+ import numpy as np
7
+ import ast
8
+ from typing import Union
9
+ from KonfAI.konfai.utils.config import config
10
+ from KonfAI.konfai.network import network
11
+
12
class NormMode(Enum):
    """Names of the normalisation layers that ``getNorm`` can build."""

    # The original declaration is written ``NONE = 0,``; the trailing comma
    # makes the member value the one-element tuple ``(0,)`` rather than the
    # integer 0. It is spelled explicitly here so the value stays unchanged.
    NONE = (0,)
    BATCH = 1
    INSTANCE = 2
    GROUP = 3
    LAYER = 4
    SYNCBATCH = 5
    INSTANCE_AFFINE = 6
20
+
21
def getNorm(normMode: Enum, channels : int, dim: int) -> torch.nn.Module:
    """Instantiate the normalisation layer selected by ``normMode``.

    Args:
        normMode: a ``NormMode`` member.
        channels: number of channels/features handed to the layer.
        dim: spatial dimensionality, used to pick the ``<N>d`` variant of
            BatchNorm / InstanceNorm.

    Returns:
        The matching layer, or ``torch.nn.Identity()`` when ``normMode``
        matches none of the handled members (e.g. ``NormMode.NONE``).
    """
    # (mode, factory) pairs, tested in order; factories are lazy so only the
    # selected layer is ever constructed.
    builders = (
        (NormMode.BATCH,
         lambda: getTorchModule("BatchNorm", dim = dim)(channels, affine=True, track_running_stats=True)),
        (NormMode.INSTANCE,
         lambda: getTorchModule("InstanceNorm", dim = dim)(channels, affine=False, track_running_stats=False)),
        (NormMode.INSTANCE_AFFINE,
         lambda: getTorchModule("InstanceNorm", dim = dim)(channels, affine=True, track_running_stats=False)),
        (NormMode.SYNCBATCH,
         lambda: torch.nn.SyncBatchNorm(channels, affine=True, track_running_stats=True)),
        # GROUP uses a fixed 32 groups.
        (NormMode.GROUP,
         lambda: torch.nn.GroupNorm(num_groups=32, num_channels=channels)),
        # LAYER is expressed as a GroupNorm with a single group.
        (NormMode.LAYER,
         lambda: torch.nn.GroupNorm(num_groups=1, num_channels=channels)),
    )
    for mode, build in builders:
        if normMode == mode:
            return build()
    return torch.nn.Identity()
35
+
36
class UpSampleMode(Enum):
    """Up-sampling strategies understood by ``upSample``."""

    # All members but the last were declared with a trailing comma, so their
    # values are one-element tuples. They are written explicitly here to keep
    # every ``.value`` exactly as before.
    CONV_TRANSPOSE = (0,)
    UPSAMPLE_NEAREST = (1,)
    UPSAMPLE_LINEAR = (2,)
    UPSAMPLE_BILINEAR = (3,)
    UPSAMPLE_BICUBIC = (4,)
    UPSAMPLE_TRILINEAR = 5
43
+
44
class DownSampleMode(Enum):
    """Down-sampling strategies understood by ``downSample``."""

    # The first two members were declared with a trailing comma, which makes
    # their values one-element tuples; kept as-is, only written explicitly.
    MAXPOOL = (0,)
    AVGPOOL = (1,)
    CONV_STRIDE = 2
48
+
49
def getTorchModule(name_fonction : str, dim : Union[int, None] = None) -> torch.nn.Module:
    """Fetch an attribute of ``torch.nn`` by name.

    Args:
        name_fonction: base name, e.g. ``"Conv"`` or ``"ReLU"``.
        dim: when given, ``"<dim>d"`` is appended to the name
            (``"Conv"`` + ``3`` -> ``"Conv3d"``).

    Returns:
        Whatever ``torch.nn`` exposes under that name; the callers in this
        file invoke the result as a constructor.

    Raises:
        AttributeError: if ``torch.nn`` has no attribute with that name.
    """
    suffix = "" if dim is None else "{}d".format(dim)
    attribute_name = "{}".format(name_fonction) + suffix
    return getattr(torch.nn, attribute_name)
51
+
52
class BlockConfig():
    """Hyper-parameters for one convolution / normalisation / activation step.

    The instance keeps the raw constructor arguments and offers factory
    methods that build the corresponding torch layers.
    """

    # NOTE(review): ``config`` is a project decorator not visible here;
    # presumably it fills the arguments from a configuration file — confirm.
    @config("BlockConfig")
    def __init__(self, kernel_size : int = 3, stride : int = 1, padding : int = 1, bias = True, activation : Union[str, Callable[[], torch.nn.Module]] = "ReLU", normMode : Union[str, NormMode, Callable[[int], torch.nn.Module]] = "NONE") -> None:
        """Store the block settings and resolve ``normMode``.

        Args:
            kernel_size, stride, padding, bias: forwarded to the convolution
                built by ``getConv``.
            activation: either a string ``"Name;arg1;arg2"`` naming a
                ``torch.nn`` class and its positional arguments (``"None"``
                selects ``Identity``), or a zero-argument factory.
            normMode: a ``NormMode`` member, the name of one, or a callable
                taking the channel count and returning a module.

        Raises:
            KeyError: if ``normMode`` is a string that names no ``NormMode``.
        """
        self.kernel_size = kernel_size
        self.bias = bias
        self.stride = stride
        self.padding = padding
        self.activation = activation
        self.normMode = normMode

        if isinstance(normMode, str):
            self.norm = NormMode[normMode]
        else:
            # Covers both NormMode members and user-supplied callables.
            # Previously a callable left ``self.norm`` undefined, so
            # ``getNorm`` failed with AttributeError.
            self.norm = normMode

    def getConv(self, in_channels : int, out_channels : int, dim : int) -> torch.nn.Module:
        """Build the ``Conv<dim>d`` layer described by this configuration."""
        return getTorchModule("Conv", dim = dim)(in_channels = in_channels, out_channels = out_channels, kernel_size = self.kernel_size, stride = self.stride, padding = self.padding, bias=self.bias)

    def getNorm(self, channels : int, dim: int) -> torch.nn.Module:
        """Build the normalisation layer for ``channels`` channels."""
        if isinstance(self.norm, NormMode):
            return getNorm(self.norm, channels, dim)
        # Custom factory: only the channel count is passed.
        return self.norm(channels)

    def getActivation(self) -> torch.nn.Module:
        """Build the activation module.

        String form: the part before the first ``;`` is the ``torch.nn``
        class name, and each further ``;``-separated field is parsed with
        ``ast.literal_eval`` and passed positionally.
        """
        if not isinstance(self.activation, str):
            return self.activation()
        if self.activation == "None":
            return torch.nn.Identity()
        name, *raw_args = self.activation.split(";")
        return getTorchModule(name)(*[ast.literal_eval(value) for value in raw_args])
78
+
79
class ConvBlock(network.ModuleArgsDict):
    """Chain of convolution / normalisation / activation triplets.

    One triplet is registered per entry of ``blockConfigs``. The first
    convolution maps ``in_channels`` to ``out_channels``; every later one maps
    ``out_channels`` to ``out_channels``.

    ``alias[0]``, ``alias[1]`` and ``alias[2]`` are forwarded to
    ``add_module`` for the convolution, normalisation and activation layers
    respectively.
    """

    def __init__(self, in_channels : int, out_channels : int, blockConfigs : list[BlockConfig], dim : int, alias : list[list[str]]=[[], [], []]) -> None:
        super().__init__()
        current_channels = in_channels
        for index, cfg in enumerate(blockConfigs):
            conv_layer = cfg.getConv(current_channels, out_channels, dim)
            self.add_module("Conv_{}".format(index), conv_layer, alias=alias[0])

            norm_layer = cfg.getNorm(out_channels, dim)
            self.add_module("Norm_{}".format(index), norm_layer, alias=alias[1])

            activation_layer = cfg.getActivation()
            self.add_module("Activation_{}".format(index), activation_layer, alias=alias[2])

            # Every triplet after the first consumes the previous output.
            current_channels = out_channels
88
+
89
class ResBlock(network.ModuleArgsDict):
    """Stack of ``nb_conv`` single-step ``ConvBlock`` modules, each followed by an ``Add``.

    When the channel count changes, an extra convolution (``Conv_skip_<i>``)
    is registered on branch 1 before the ``Add``.

    NOTE(review): the exact routing implied by ``in_branch`` / ``out_branch``
    is defined by ``network.ModuleArgsDict``, which is not visible here.

    Args:
        in_channels: channels entering the first step.
        out_channels: channels produced by every step.
        nb_conv: number of steps.
        blockConfig: configuration shared by every step.
        dim: spatial dimensionality of the layers.
        alias: four alias lists; the first three go to the ``ConvBlock``
            layers, the fourth to the skip convolution.
    """

    def __init__(self, in_channels : int, out_channels : int, nb_conv: int, blockConfig : BlockConfig, dim : int, alias : list[list[str]]=[[], [], [], []]) -> None:
        super().__init__()
        for i in range(nb_conv):
            # ConvBlock expects a *list* of configurations (``blockConfigs``).
            # The former call passed ``nb_conv=1, blockConfig=...``, which
            # ConvBlock.__init__ does not accept and raised TypeError.
            self.add_module("ConvBlock_{}".format(i), ConvBlock(in_channels, out_channels, blockConfigs=[blockConfig], dim=dim, alias=alias[:3]))
            if in_channels != out_channels:
                self.add_module("Conv_skip_{}".format(i), blockConfig.getConv(in_channels, out_channels, dim), alias=alias[3], in_branch=[1], out_branch=[1])
            in_channels = out_channels
            self.add_module("Add_{}".format(i), Add(), in_branch=[0,1], out_branch=[0,1])
99
+
100
def downSample(in_channels: int, out_channels: int, downSampleMode: DownSampleMode, dim: int) -> torch.nn.Module:
    """Build a layer that halves the spatial resolution.

    MAXPOOL / AVGPOOL give a ``<Max|Avg>Pool<dim>d`` with kernel size 2 (the
    channel arguments are unused); CONV_STRIDE gives a ``Conv<dim>d`` with
    kernel size 2, stride 2 and no padding.

    Returns:
        The layer, or ``None`` when ``downSampleMode`` is none of the three
        handled members.
    """
    pooling_layers = (
        (DownSampleMode.MAXPOOL, "MaxPool"),
        (DownSampleMode.AVGPOOL, "AvgPool"),
    )
    for mode, layer_name in pooling_layers:
        if downSampleMode == mode:
            return getTorchModule(layer_name, dim = dim)(2)

    if downSampleMode == DownSampleMode.CONV_STRIDE:
        conv_cls = getTorchModule("Conv", dim)
        return conv_cls(in_channels, out_channels, kernel_size=2, stride=2, padding=0)

    return None
107
+
108
def upSample(in_channels: int, out_channels: int, upSampleMode: UpSampleMode, dim: int, kernel_size: Union[int, list[int]] = 2, stride: Union[int, list[int]] = 2):
    """Build an up-sampling layer.

    Args:
        in_channels, out_channels: channel counts; only used by
            ``CONV_TRANSPOSE``.
        upSampleMode: ``CONV_TRANSPOSE`` builds a ``ConvTranspose<dim>d``;
            any ``UPSAMPLE_*`` member builds a ``torch.nn.Upsample`` whose
            interpolation mode is the lower-cased member name without the
            ``UPSAMPLE_`` prefix.
        dim: spatial dimensionality; only used by ``CONV_TRANSPOSE``.
        kernel_size, stride: only used by ``CONV_TRANSPOSE``; the
            interpolation path always uses ``scale_factor=2``.

    Returns:
        The constructed ``torch.nn.Module``.
    """
    if upSampleMode == UpSampleMode.CONV_TRANSPOSE:
        return getTorchModule("ConvTranspose", dim = dim)(in_channels = in_channels, out_channels = out_channels, kernel_size = kernel_size, stride = stride, padding = 0)

    mode = upSampleMode.name.replace("UPSAMPLE_", "").lower()
    # ``align_corners`` may only be set for the interpolating modes; passing
    # False together with "nearest" makes Upsample raise ValueError on the
    # first forward pass, so it is left unset for that mode.
    align_corners = None if mode == "nearest" else False
    return torch.nn.Upsample(scale_factor=2, mode=mode, align_corners=align_corners)
113
+
114
class Unsqueeze(torch.nn.Module):
    """Module wrapper around ``torch.unsqueeze``.

    Args:
        dim: position at which the size-1 dimension is inserted.
    """

    def __init__(self, dim: int = 0):
        super().__init__()
        self.dim = dim

    def forward(self, input : torch.Tensor) -> torch.Tensor:
        """Return ``input`` with a size-1 dimension inserted at ``self.dim``."""
        # The parameter used to be declared ``*input``, which handed a tuple
        # to torch.unsqueeze and failed with TypeError on every call.
        return torch.unsqueeze(input, self.dim)

    def extra_repr(self):
        return "dim={}".format(self.dim)
125
+
126
class Permute(torch.nn.Module):
    """Module that reorders the dimensions of its input.

    Args:
        dims: the desired ordering of dimensions, as accepted by
            ``torch.permute``.
    """

    def __init__(self, dims : list[int]):
        super().__init__()
        self.dims = dims

    def forward(self, input : torch.Tensor) -> torch.Tensor:
        """Return a view of ``input`` with dimensions ordered as ``self.dims``."""
        permuted = input.permute(self.dims)
        return permuted

    def extra_repr(self):
        return "dims={}".format(self.dims)
137
+
138
class ToChannels(Permute):
    """Permutation that moves the last of ``dim + 2`` axes to position 1.

    The resulting order is ``[0, dim + 1, 1, 2, ..., dim]``.
    """

    def __init__(self, dim):
        order = [0, dim + 1, *range(1, dim + 1)]
        super().__init__(order)
142
+
143
class ToFeatures(Permute):
    """Permutation that moves axis 1 of ``dim + 2`` axes to the last position.

    The resulting order is ``[0, 2, 3, ..., dim + 1, 1]``.
    """

    def __init__(self, dim):
        order = [0, *range(2, dim + 2), 1]
        super().__init__(order)
147
+
148
class Add(torch.nn.Module):
    """Element-wise sum of any number of same-shaped tensors."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, *input : torch.Tensor) -> torch.Tensor:
        """Stack the operands along a new leading axis and sum over it."""
        stacked = torch.stack(input)
        return stacked.sum(dim=0)
155
+
156
class Multiply(torch.nn.Module):
    """Element-wise (broadcasting) product of one or more tensors."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, *input : torch.Tensor) -> torch.Tensor:
        """Multiply all operands together, left to right.

        The former ``torch.mul(*input)`` only worked for exactly two
        operands although the signature is variadic; two-operand calls give
        the same result as before.

        Raises:
            TypeError: if called without any operand.
        """
        if not input:
            raise TypeError("Multiply.forward() requires at least one tensor")
        product = input[0]
        for factor in input[1:]:
            product = torch.mul(product, factor)
        return product
163
+
164
class Concat(torch.nn.Module):
    """Concatenate the given tensors along dimension 1."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, *input : torch.Tensor) -> torch.Tensor:
        """Return the operands joined along axis 1, in call order."""
        tensors = input
        joined = torch.cat(tensors, dim=1)
        return joined
171
+
172
class Print(torch.nn.Module):
    """Debug pass-through that prints the shape of whatever flows through it."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Print ``input.shape`` to stdout and return ``input`` unchanged."""
        shape = input.shape
        print(shape)
        return input
180
+
181
class Write(torch.nn.Module):
    """Debug pass-through that dumps ``input[0][0]`` to ``./Data.mha``."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Write the first element of the first two axes to disk and return ``input``.

        The file ``./Data.mha`` is overwritten on every call. SimpleITK is
        imported lazily so it is only required when this module is used.
        """
        import SimpleITK as sitk
        # detach(): ``.numpy()`` raises RuntimeError on a tensor that requires
        # grad, which is the normal case inside a network being trained.
        # The former ``clone()`` was unnecessary since the data is only read.
        array = input[0][0].detach().cpu().numpy()
        sitk.WriteImage(sitk.GetImageFromArray(array), "./Data.mha")
        return input
190
+
191
class Exit(torch.nn.Module):
    """Debug module that terminates the interpreter when it is reached."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Raise ``SystemExit(0)``; never returns."""
        # The builtin ``exit`` is injected by the ``site`` module and is not
        # guaranteed to exist (``python -S``, embedded or frozen
        # interpreters). Raising SystemExit directly has the same exit status.
        raise SystemExit(0)
198
+
199
class Detach(torch.nn.Module):
    """Cut the autograd graph: outputs share data with the input but carry no gradient."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return ``input.detach()``."""
        detached = input.detach()
        return detached
206
+
207
class Negative(torch.nn.Module):
    """Element-wise negation."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return ``-input``."""
        return torch.neg(input)
214
+
215
class GetShape(torch.nn.Module):
    """Replace a tensor by a 1-D tensor holding its shape."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return ``torch.tensor(input.shape)``."""
        dimensions = input.size()
        return torch.tensor(dimensions)
222
+
223
class ArgMax(torch.nn.Module):
    """Index of the maximum along one dimension, keeping that dimension with size 1.

    Args:
        dim: dimension that is reduced.
    """

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.dim = dim

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return the arg-max indices along ``self.dim`` (that axis has size 1 in the result)."""
        return torch.argmax(input, dim=self.dim, keepdim=True)
231
+
232
class Select(torch.nn.Module):
    """Index the input with a fixed tuple of slices.

    Args:
        slices: one index per leading dimension; stored as a tuple.
    """

    def __init__(self, slices: list[slice]) -> None:
        super().__init__()
        self.slices = tuple(slices)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return ``input[self.slices]``, with dimension 1 squeezed when it has size 1.

        NOTE(review): the previous loop iterated
        ``enumerate(range(len(result.shape)))`` and therefore compared the
        dimension *index* (not its size) to 1, so the only effect it ever had
        was ``squeeze(dim=1)`` on results with at least two dimensions. That
        exact behaviour is kept here; whether every singleton axis was meant
        to be removed should be confirmed with the author.
        """
        selected = input[self.slices]
        if selected.dim() >= 2:
            # squeeze(dim=1) is a no-op unless that axis has size 1.
            selected = selected.squeeze(dim=1)
        return selected
244
+
245
class NormalNoise(torch.nn.Module):
    """Replace the input by standard-normal noise.

    Args:
        dim: when given, passed to ``torch.randn`` as the size of the noise;
            when ``None`` the noise has the same shape as the input.
    """

    def __init__(self, dim: Union[int, None] = None) -> None:
        super().__init__()
        self.dim = dim

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return noise located on ``input.device``; the input values are ignored."""
        target_device = input.device
        if self.dim is None:
            return torch.randn_like(input).to(target_device)
        return torch.randn(self.dim).to(target_device)
256
+
257
class Const(torch.nn.Module):
    """Learnable constant that ignores the values of its input.

    Args:
        shape: shape of the constant.
        std: scale applied to the standard-normal initial values.
    """

    def __init__(self, shape: list[int], std: float) -> None:
        super().__init__()
        initial_values = torch.randn(shape) * std
        self.noise = torch.nn.Parameter(initial_values)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return the parameter moved to ``input.device``."""
        return self.noise.to(input.device)
265
+
266
class HistogramNoise(torch.nn.Module):
    """Remap intensities through a randomly perturbed identity curve.

    Args:
        n: number of control points, evenly spaced on [0, 1].
        sigma: standard deviation of the Gaussian offset added to each
            control point.
    """

    def __init__(self, n: int, sigma: float) -> None:
        super().__init__()
        self.x = np.linspace(0, 1, num=n, endpoint=True)
        self.sigma = sigma

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Draw a fresh cubic curve and apply it to every element of ``input``.

        The curve interpolates ``(x, x + N(0, sigma))`` at the control
        points; it is stored in ``self.function``. The interpolator is built
        without extrapolation, so values outside [0, 1] make it raise
        ``ValueError``.

        Returns:
            Tensor with the shape, dtype and device of ``input``.
        """
        self.function = interp1d(self.x, self.x+np.random.normal(0, self.sigma, self.x.shape[0]), kind='cubic')

        # Evaluate the curve once on the distinct values and scatter the
        # results back with the inverse index, instead of building one
        # boolean mask (plus one host transfer) per distinct value.
        values, inverse = torch.unique(input, return_inverse=True)
        mapped = self.function(values.detach().cpu().numpy())
        # Same float64 -> float32 conversion as the per-value version.
        mapped = torch.as_tensor(mapped, device=input.device).float()
        return mapped[inverse].to(input.dtype)
281
+
282
class Subset(torch.nn.Module):
    """Slice a tensor on the axes that follow its first two.

    Args:
        slices: indices applied from axis 2 onwards; the first two axes are
            always taken in full.
    """

    def __init__(self, slices: list[slice]):
        super().__init__()
        # list(...) so that a tuple of slices is accepted as well.
        self.slices = [slice(None, None), slice(None, None)] + list(slices)

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """Return ``tensor[:, :, *slices]``."""
        # Index with a tuple: indexing with a *list* of slices relies on a
        # legacy rule that PyTorch deprecates (a list index is meant to be
        # interpreted as a tensor index).
        return tensor[tuple(self.slices)]
289
+
290
class View(torch.nn.Module):
    """Module wrapper around ``Tensor.view``.

    Args:
        size: target shape handed to ``view``.
    """

    def __init__(self, size: list[int]):
        super().__init__()
        self.size = size

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """Return ``tensor`` viewed with shape ``self.size``."""
        target_shape = self.size
        return tensor.view(target_shape)
297
+
298
class LatentDistribution(network.ModuleArgsDict):
    """Latent bottleneck: flatten, predict ``mu`` / ``log_std``, sample ``z``, decode.

    Registered modules, in order: ``Flatten``, ``mu``, ``log_std``, ``z``,
    ``Concat`` (of branches 1, 2, 3) and ``DecoderInput`` (fed from branch 3).

    NOTE(review): the routing implied by ``in_branch`` / ``out_branch`` is
    defined by ``network.ModuleArgsDict``, which is not visible here.

    Args:
        shape: per-sample shape of the incoming features; their product is
            the size of the flattened vector.
        latentDim: size of the latent vector.
    """

    class LatentDistribution_Linear(torch.nn.Module):
        """Linear projection from the flattened features to the latent space."""

        def __init__(self, shape: list[int], latentDim: int) -> None:
            super().__init__()
            # int(...): give Linear a plain Python int rather than a 0-d tensor.
            self.linear = torch.nn.Linear(int(torch.prod(torch.tensor(shape))), latentDim)

        def forward(self, input: torch.Tensor) -> torch.Tensor:
            """Project ``input`` and insert a size-1 axis at position 1."""
            return torch.unsqueeze(self.linear(input), 1)

    class LatentDistribution_Decoder(torch.nn.Module):
        """Linear projection from the latent space back to ``shape``."""

        def __init__(self, shape: list[int], latentDim: int) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(latentDim, int(torch.prod(torch.tensor(shape))))
            self.shape = shape

        def forward(self, input: torch.Tensor) -> torch.Tensor:
            """Project ``input`` and reshape it to ``(-1, *shape)``."""
            return self.linear(input).view(-1, *[int(i) for i in self.shape])

    class LatentDistribution_Z(torch.nn.Module):
        """Reparameterised sample ``mu + exp(log_std / 2) * eps``."""

        def __init__(self) -> None:
            super().__init__()

        def forward(self, mu: torch.Tensor, log_std: torch.Tensor) -> torch.Tensor:
            """Draw one sample per element of ``mu``."""
            # eps must be standard normal. The former ``torch.rand_like``
            # drew from U[0, 1), which has mean 0.5 and never goes negative,
            # so the samples were shifted away from ``mu``.
            return torch.exp(log_std/2)*torch.randn_like(mu)+mu

    def __init__(self, shape: list[int], latentDim: int) -> None:
        super().__init__()
        self.add_module("Flatten", torch.nn.Flatten(1))
        self.add_module("mu", LatentDistribution.LatentDistribution_Linear(shape, latentDim), out_branch = [1])
        self.add_module("log_std", LatentDistribution.LatentDistribution_Linear(shape, latentDim), out_branch = [2])

        self.add_module("z", LatentDistribution.LatentDistribution_Z(), in_branch=[1,2], out_branch=[3])
        self.add_module("Concat", Concat(), in_branch=[1,2,3])
        self.add_module("DecoderInput", LatentDistribution.LatentDistribution_Decoder(shape, latentDim), in_branch=[3])
336
+
337
class Attention(network.ModuleArgsDict):
    """Gate built from 1x1 convolutions, ReLU, sigmoid, up-sampling and a product.

    Registered modules, in order: ``W_x`` (stride-2 convolution, branch 0),
    ``W_g`` (stride-1 convolution, branch 1), ``Add`` of branches 0 and 1,
    ``ReLU``, a convolution to one channel, ``Sigmoid``, ``Upsample`` by 2 and
    finally ``Multiply`` reading branches 2 and 0.

    NOTE(review): the routing implied by ``in_branch`` / ``out_branch`` (and
    where branch 2 is filled) is defined by ``network.ModuleArgsDict``, which
    is not visible here.

    Args:
        F_g: input channels of ``W_g``.
        F_l: input channels of ``W_x``.
        F_int: output channels of ``W_x`` and ``W_g``.
        dim: spatial dimensionality of the convolutions.
    """

    def __init__(self, F_g : int, F_l : int, F_int : int, dim : int):
        super().__init__()
        conv_cls = getTorchModule("Conv", dim = dim)

        w_x = conv_cls(in_channels = F_l, out_channels = F_int, kernel_size=1, stride=2, padding=0)
        self.add_module("W_x", w_x, in_branch=[0], out_branch=[0])

        w_g = conv_cls(in_channels = F_g, out_channels = F_int, kernel_size=1, stride=1, padding=0)
        self.add_module("W_g", w_g, in_branch=[1], out_branch=[1])

        self.add_module("Add", Add(), in_branch=[0,1])
        self.add_module("ReLU", torch.nn.ReLU(inplace=True))

        to_single_channel = conv_cls(in_channels = F_int, out_channels = 1, kernel_size=1,stride=1, padding=0)
        self.add_module("Conv", to_single_channel)

        self.add_module("Sigmoid", torch.nn.Sigmoid())
        self.add_module("Upsample", torch.nn.Upsample(scale_factor=2))
        self.add_module("Multiply", Multiply(), in_branch=[2,0])