konfai-1.1.7-py3-none-any.whl → konfai-1.1.9-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of konfai might be problematic; see the registry's advisory page for more details.

Files changed (36)
  1. konfai/__init__.py +59 -14
  2. konfai/data/augmentation.py +457 -286
  3. konfai/data/data_manager.py +509 -290
  4. konfai/data/patching.py +300 -183
  5. konfai/data/transform.py +384 -277
  6. konfai/evaluator.py +309 -68
  7. konfai/main.py +71 -22
  8. konfai/metric/measure.py +341 -222
  9. konfai/metric/schedulers.py +24 -13
  10. konfai/models/classification/convNeXt.py +187 -81
  11. konfai/models/classification/resnet.py +272 -58
  12. konfai/models/generation/cStyleGan.py +233 -59
  13. konfai/models/generation/ddpm.py +348 -121
  14. konfai/models/generation/diffusionGan.py +757 -358
  15. konfai/models/generation/gan.py +177 -53
  16. konfai/models/generation/vae.py +140 -40
  17. konfai/models/registration/registration.py +135 -52
  18. konfai/models/representation/representation.py +57 -23
  19. konfai/models/segmentation/NestedUNet.py +339 -68
  20. konfai/models/segmentation/UNet.py +140 -30
  21. konfai/network/blocks.py +331 -187
  22. konfai/network/network.py +781 -423
  23. konfai/predictor.py +645 -240
  24. konfai/trainer.py +527 -216
  25. konfai/utils/ITK.py +191 -106
  26. konfai/utils/config.py +152 -95
  27. konfai/utils/dataset.py +326 -455
  28. konfai/utils/utils.py +495 -249
  29. {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/METADATA +1 -3
  30. konfai-1.1.9.dist-info/RECORD +38 -0
  31. konfai/utils/registration.py +0 -199
  32. konfai-1.1.7.dist-info/RECORD +0 -39
  33. {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/WHEEL +0 -0
  34. {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/entry_points.txt +0 -0
  35. {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/licenses/LICENSE +0 -0
  36. {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/top_level.txt +0 -0
konfai/network/blocks.py CHANGED
@@ -1,17 +1,19 @@
1
- from enum import Enum
1
+ import ast
2
2
  import importlib
3
- from typing import Callable
3
+ from collections.abc import Callable
4
+ from enum import Enum
5
+
6
+ import numpy as np
7
+ import SimpleITK as sitk # noqa: N813
4
8
  import torch
5
9
  from scipy.interpolate import interp1d
6
- import numpy as np
7
- import ast
8
- from typing import Union
9
10
 
10
- from konfai.utils.config import config
11
11
  from konfai.network import network
12
+ from konfai.utils.config import config
13
+
12
14
 
13
15
  class NormMode(Enum):
14
- NONE = 0,
16
+ NONE = (0,)
15
17
  BATCH = 1
16
18
  INSTANCE = 2
17
19
  GROUP = 3
@@ -19,227 +21,327 @@ class NormMode(Enum):
19
21
  SYNCBATCH = 5
20
22
  INSTANCE_AFFINE = 6
21
23
 
22
- def getNorm(normMode: Enum, channels : int, dim: int) -> Union[torch.nn.Module, None]:
23
- if normMode == NormMode.BATCH:
24
- return getTorchModule("BatchNorm", dim = dim)(channels, affine=True, track_running_stats=True)
25
- if normMode == NormMode.INSTANCE:
26
- return getTorchModule("InstanceNorm", dim = dim)(channels, affine=False, track_running_stats=False)
27
- if normMode == NormMode.INSTANCE_AFFINE:
28
- return getTorchModule("InstanceNorm", dim = dim)(channels, affine=True, track_running_stats=False)
29
- if normMode == NormMode.SYNCBATCH:
24
+
25
+ def get_norm(norm_mode: Enum, channels: int, dim: int) -> torch.nn.Module | None:
26
+ if norm_mode == NormMode.BATCH:
27
+ return get_torch_module("BatchNorm", dim=dim)(channels, affine=True, track_running_stats=True)
28
+ if norm_mode == NormMode.INSTANCE:
29
+ return get_torch_module("InstanceNorm", dim=dim)(channels, affine=False, track_running_stats=False)
30
+ if norm_mode == NormMode.INSTANCE_AFFINE:
31
+ return get_torch_module("InstanceNorm", dim=dim)(channels, affine=True, track_running_stats=False)
32
+ if norm_mode == NormMode.SYNCBATCH:
30
33
  return torch.nn.SyncBatchNorm(channels, affine=True, track_running_stats=True)
31
- if normMode == NormMode.GROUP:
34
+ if norm_mode == NormMode.GROUP:
32
35
  return torch.nn.GroupNorm(num_groups=32, num_channels=channels)
33
- if normMode == NormMode.LAYER:
36
+ if norm_mode == NormMode.LAYER:
34
37
  return torch.nn.GroupNorm(num_groups=1, num_channels=channels)
35
38
  return None
36
39
 
37
class UpsampleMode(Enum):
    """Strategy used by ``upsample`` to increase spatial resolution.

    NOTE: member values are 1-tuples (``(0,)``, ``(1,)``), not ints; compare
    members by identity or look them up by name rather than by ``.value``.
    """

    # Learned upsampling with a ConvTransposeNd layer.
    CONV_TRANSPOSE = (0,)
    # Fixed interpolation with torch.nn.Upsample.
    UPSAMPLE = (1,)
class DownsampleMode(Enum):
    """Strategy used by ``downsample`` to reduce spatial resolution.

    NOTE: ``MAXPOOL`` and ``AVGPOOL`` carry 1-tuple values while
    ``CONV_STRIDE`` is the int ``2``; compare members by identity or name,
    not by ``.value``.
    """

    # Max pooling with a 2-wide window.
    MAXPOOL = (0,)
    # Average pooling with a 2-wide window.
    AVGPOOL = (1,)
    # Learned downsampling with a stride-2 convolution.
    CONV_STRIDE = 2
- def getTorchModule(name_fonction : str, dim : Union[int, None] = None) -> torch.nn.Module:
47
- return getattr(importlib.import_module("torch.nn"), "{}".format(name_fonction) + ("{}d".format(dim) if dim is not None else ""))
48
51
 
49
- class BlockConfig():
52
+ def get_torch_module(name_fonction: str, dim: int | None = None) -> torch.nn.Module:
53
+ return getattr(
54
+ importlib.import_module("torch.nn"),
55
+ f"{name_fonction}" + (f"{dim}d" if dim is not None else ""),
56
+ )
57
+
58
+
59
class BlockConfig:
    """Settings for one conv -> norm -> activation unit, plus factories for each layer.

    Constructor arguments:
        kernel_size, stride, padding, bias: forwarded to the ``ConvNd`` built
            by ``get_conv``.
        activation: ``None`` for no activation layer; the string ``"None"``
            for ``torch.nn.Identity``; a string ``"Name;arg1;arg2"`` naming a
            ``torch.nn`` class followed by literal constructor arguments
            (e.g. ``"LeakyReLU;0.2"``); or a zero-argument factory.
        norm_mode: a ``NormMode`` member name (e.g. ``"BATCH"``), a
            ``NormMode`` member, a ``channels -> Module`` factory, or ``None``.
    """

    @config("BlockConfig")
    def __init__(
        self,
        kernel_size: int = 3,
        stride: int = 1,
        padding: int = 1,
        bias=True,
        activation: str | Callable[[], torch.nn.Module] = "ReLU",
        norm_mode: str | NormMode | Callable[[int], torch.nn.Module] = "NONE",
    ) -> None:
        self.kernel_size = kernel_size
        self.bias = bias
        self.stride = stride
        self.padding = padding
        self.activation = activation
        # Raw value as configured; ``self.norm`` holds the resolved form.
        self.norm_mode = norm_mode
        # Strings are resolved by member name (NormMode[...] raises KeyError on
        # unknown names); anything else (member, factory, None) is kept as-is.
        self.norm: NormMode | Callable[[int], torch.nn.Module] | None = (
            NormMode[norm_mode] if isinstance(norm_mode, str) else norm_mode
        )

    def get_conv(self, in_channels: int, out_channels: int, dim: int) -> torch.nn.Module:
        """Build a ``Conv{dim}d`` layer from the stored kernel/stride/padding/bias."""
        return get_torch_module("Conv", dim=dim)(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
            bias=self.bias,
        )

    def get_norm(self, channels: int, dim: int) -> torch.nn.Module | None:
        """Build the configured normalization layer, or return ``None`` if there is none."""
        if self.norm is None:
            return None
        if isinstance(self.norm, NormMode):
            # Module-level helper (method names are not in scope here).
            return get_norm(self.norm, channels, dim)
        return self.norm(channels)

    def get_activation(self) -> torch.nn.Module | None:
        """Build the configured activation layer, or return ``None`` if there is none."""
        if self.activation is None:
            return None
        if isinstance(self.activation, str):
            if self.activation == "None":
                return torch.nn.Identity()
            name, *raw_args = self.activation.split(";")
            # literal_eval only accepts Python literals, so config strings
            # cannot execute code.
            return get_torch_module(name)(*[ast.literal_eval(value) for value in raw_args])
        return self.activation()
class ConvBlock(network.ModuleArgsDict):
    """Sequential stack of (Conv, optional Norm, optional Activation) units.

    One unit is added per ``BlockConfig``. The first conv maps
    ``in_channels -> out_channels``; every later conv maps
    ``out_channels -> out_channels``.

    Args:
        in_channels: Channels entering the first conv.
        out_channels: Channels produced by every conv.
        block_configs: One ``BlockConfig`` per unit.
        dim: Spatial dimensionality (selects ``Conv1d/2d/3d`` etc.).
        alias: Three alias lists passed to ``add_module`` for the conv, norm
            and activation layers respectively; defaults to three empty lists.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        block_configs: list[BlockConfig],
        dim: int,
        alias: list[list[str]] | None = None,
    ) -> None:
        super().__init__()
        # Fresh lists per instance instead of a shared mutable default.
        if alias is None:
            alias = [[], [], []]
        for i, block_config in enumerate(block_configs):
            self.add_module(
                f"Conv_{i}",
                block_config.get_conv(in_channels, out_channels, dim),
                alias=alias[0],
            )
            norm = block_config.get_norm(out_channels, dim)
            if norm is not None:
                self.add_module(f"Norm_{i}", norm, alias=alias[1])
            activation = block_config.get_activation()
            if activation is not None:
                self.add_module(f"Activation_{i}", activation, alias=alias[2])
            in_channels = out_channels
class ResBlock(network.ModuleArgsDict):
    """Residual block: conv units on branch 0, a skip path on branch 1, summed then ReLU.

    The skip path is projected with a 1x1 conv (using the first unit's stride
    and bias) plus that unit's norm when the channel count changes. Branch 1's
    starting value comes from ``network.ModuleArgsDict`` (presumably the block
    input; confirm there).

    Args:
        in_channels: Channels entering the block.
        out_channels: Channels produced by every conv.
        block_configs: One ``BlockConfig`` per main-path unit; must be non-empty.
        dim: Spatial dimensionality.
        alias: Five alias lists for conv, norm, activation, skip conv and skip
            norm; defaults to five empty lists.

    Raises:
        ValueError: if ``block_configs`` is empty.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        block_configs: list[BlockConfig],
        dim: int,
        alias: list[list[str]] | None = None,
    ) -> None:
        super().__init__()
        if not block_configs:
            # Otherwise the final ReLU's name (uses the loop index) fails with NameError.
            raise ValueError("ResBlock requires at least one BlockConfig")
        # Fresh lists per instance instead of a shared mutable default.
        if alias is None:
            alias = [[], [], [], [], []]
        for i, block_config in enumerate(block_configs):
            self.add_module(
                f"Conv_{i}",
                block_config.get_conv(in_channels, out_channels, dim),
                alias=alias[0],
            )
            norm = block_config.get_norm(out_channels, dim)
            if norm is not None:
                self.add_module(f"Norm_{i}", norm, alias=alias[1])
            activation = block_config.get_activation()
            if activation is not None:
                self.add_module(f"Activation_{i}", activation, alias=alias[2])

            # Only fires on the first unit: in_channels is reset just below.
            if in_channels != out_channels:
                self.add_module(
                    "Conv_skip",
                    get_torch_module("Conv", dim)(
                        in_channels,
                        out_channels,
                        1,
                        block_config.stride,
                        bias=block_config.bias,
                    ),
                    alias=alias[3],
                    in_branch=[1],
                    out_branch=[1],
                )
                skip_norm = block_config.get_norm(out_channels, dim)
                # Same guard as the main path: e.g. NormMode.NONE yields no layer.
                if skip_norm is not None:
                    self.add_module(
                        "Norm_skip",
                        skip_norm,
                        alias=alias[4],
                        in_branch=[1],
                        out_branch=[1],
                    )
            in_channels = out_channels

        self.add_module("Add", Add(), in_branch=[0, 1])
        # Kept under its historical "Norm_<n>" name even though it is a ReLU.
        self.add_module(f"Norm_{i + 1}", torch.nn.ReLU(inplace=True))
def downsample(in_channels: int, out_channels: int, downsample_mode: DownsampleMode, dim: int) -> torch.nn.Module:
    """Build a layer that reduces spatial resolution with a 2-wide, stride-2 window.

    Args:
        in_channels: Input channels (used only by ``CONV_STRIDE``).
        out_channels: Output channels (used only by ``CONV_STRIDE``; pooling
            keeps the channel count).
        downsample_mode: Which ``DownsampleMode`` strategy to use.
        dim: Spatial dimensionality (selects the ``1d``/``2d``/``3d`` class).

    Raises:
        ValueError: if ``downsample_mode`` is not one of the handled members.
    """
    if downsample_mode == DownsampleMode.MAXPOOL:
        return get_torch_module("MaxPool", dim=dim)(2)
    if downsample_mode == DownsampleMode.AVGPOOL:
        return get_torch_module("AvgPool", dim=dim)(2)
    if downsample_mode == DownsampleMode.CONV_STRIDE:
        return get_torch_module("Conv", dim)(in_channels, out_channels, kernel_size=2, stride=2, padding=0)
    # Previously fell through and returned None, which fails later and far from the cause.
    raise ValueError(f"Unsupported downsample mode: {downsample_mode!r}")
def upsample(
    in_channels: int,
    out_channels: int,
    upsample_mode: UpsampleMode,
    dim: int,
    kernel_size: int | list[int] = 2,
    stride: int | list[int] = 2,
) -> torch.nn.Module:
    """Build a layer that increases spatial resolution.

    ``UpsampleMode.CONV_TRANSPOSE`` builds a ``ConvTranspose{dim}d`` with the
    given channels, ``kernel_size`` and ``stride``. Any other mode builds a
    ``torch.nn.Upsample`` with a fixed scale factor of 2 and (bi/tri)linear
    interpolation; channels, ``kernel_size`` and ``stride`` are then ignored.

    Raises:
        ValueError: if interpolation is requested for ``dim`` not in {1, 2, 3}.
    """
    if upsample_mode == UpsampleMode.CONV_TRANSPOSE:
        return get_torch_module("ConvTranspose", dim=dim)(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=0,
        )
    interpolation_modes = {1: "linear", 2: "bilinear", 3: "trilinear"}
    if dim not in interpolation_modes:
        # Previously raised UnboundLocalError for unsupported dims.
        raise ValueError(f"Interpolating upsample supports dim 1, 2 or 3, got {dim!r}")
    return torch.nn.Upsample(scale_factor=2, mode=interpolation_modes[dim], align_corners=False)
class Unsqueeze(torch.nn.Module):
    """Insert a size-1 axis at position ``dim`` (wraps ``torch.unsqueeze``)."""

    def __init__(self, dim: int = 0):
        super().__init__()
        self.dim = dim

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        # Takes a single tensor: packing inputs with ``*tensor`` hands
        # torch.unsqueeze a tuple, which it rejects.
        return torch.unsqueeze(tensor, self.dim)

    def extra_repr(self):
        return f"dim={self.dim}"
class Permute(torch.nn.Module):
    """Reorder tensor axes with a fixed permutation.

    Output axis ``k`` is input axis ``dims[k]``.
    """

    def __init__(self, dims: list[int]):
        super().__init__()
        self.dims = dims

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        order = self.dims
        return tensor.permute(order)

    def extra_repr(self):
        return f"dims={self.dims}"
class ToChannels(Permute):
    """Move the last axis to position 1: ``(B, *S, C) -> (B, C, *S)``.

    ``dim`` is the number of middle axes ``S``; the input is expected to have
    rank ``dim + 2``.
    """

    def __init__(self, dim: int):
        middle_axes = range(1, dim + 1)
        super().__init__([0, dim + 1, *middle_axes])
class ToFeatures(Permute):
    """Move axis 1 to the end: ``(B, C, *S) -> (B, *S, C)``.

    ``dim`` is the number of trailing axes ``S``; the input is expected to
    have rank ``dim + 2``.
    """

    def __init__(self, dim: int):
        trailing_axes = range(2, dim + 2)
        super().__init__([0, *trailing_axes, 1])
class Add(torch.nn.Module):
    """Element-wise sum of any number of identically shaped tensors."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, *tensor: torch.Tensor) -> torch.Tensor:
        # torch.stack needs equal shapes (no broadcasting); summing over the
        # new leading axis adds the inputs together.
        stacked = torch.stack(tensor)
        return stacked.sum(dim=0)
class Multiply(torch.nn.Module):
    """Element-wise product of one or more tensors (with broadcasting)."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, *tensor: torch.Tensor) -> torch.Tensor:
        """Multiply all inputs left to right.

        ``torch.mul(*tensor)`` only accepts exactly two inputs; folding
        generalizes this to any count while giving the same result for two.

        Raises:
            ValueError: if called with no inputs.
        """
        if not tensor:
            raise ValueError("Multiply needs at least one input tensor")
        result = tensor[0]
        for factor in tensor[1:]:
            result = torch.mul(result, factor)
        return result
class Concat(torch.nn.Module):
    """Concatenate all inputs along axis 1."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, *tensor: torch.Tensor) -> torch.Tensor:
        pieces = tensor
        return torch.cat(pieces, dim=1)
class Print(torch.nn.Module):
    """Debug pass-through that prints the input's shape and returns the input unchanged."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        shape = tensor.shape
        print(shape)
        return tensor
class Write(torch.nn.Module):
    """Debug pass-through that writes ``tensor[0][0]`` to an image file.

    Args:
        path: Destination passed to ``SimpleITK.WriteImage``; defaults to
            the historical ``"./Data.mha"``.
    """

    def __init__(self, path: str = "./Data.mha") -> None:
        super().__init__()
        self.path = path

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        # detach() first: Tensor.numpy() raises on tensors that require grad,
        # and clone() alone would keep requires_grad.
        array = tensor.detach()[0][0].cpu().numpy()
        sitk.WriteImage(sitk.GetImageFromArray(array), self.path)
        return tensor
class Exit(torch.nn.Module):
    """Debug module that stops the process (exit status 0) when it is reached."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        # sys.exit instead of the site-provided ``exit`` builtin, which is
        # missing under ``python -S`` and in some embedded/frozen interpreters.
        import sys

        sys.exit(0)
class Detach(torch.nn.Module):
    """Cut the autograd graph: return the input detached from gradient tracking."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        detached = tensor.detach()
        return detached
class Negative(torch.nn.Module):
    """Return the element-wise negation of the input."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        negated = -tensor
        return negated
class GetShape(torch.nn.Module):
    """Return the input's shape as a 1-D tensor."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        shape = tensor.shape
        return torch.tensor(shape)
  class ArgMax(torch.nn.Module):
245
347
 
@@ -247,42 +349,46 @@ class ArgMax(torch.nn.Module):
247
349
  super().__init__()
248
350
  self.dim = dim
249
351
 
250
- def forward(self, input: torch.Tensor) -> torch.Tensor:
251
- return torch.argmax(input, dim=self.dim).unsqueeze(self.dim)
252
-
352
+ def forward(self, tensor: torch.Tensor) -> torch.Tensor:
353
+ return torch.argmax(tensor, dim=self.dim).unsqueeze(self.dim)
354
+
355
+
253
356
class Select(torch.nn.Module):
    """Index with fixed slices, then squeeze axis 1 if it has size 1.

    NOTE(review): the historical loop ``for i, s in enumerate(range(rank)):
    if s == 1: squeeze(dim=i)`` compares the axis *index* to 1, so only axis
    1 is ever squeezed (and only when the result has rank >= 2). That may not
    be the intended behavior (perhaps squeezing every singleton axis). It is
    preserved exactly here; confirm with callers before changing it.
    """

    def __init__(self, slices: list[slice]) -> None:
        super().__init__()
        self.slices = tuple(slices)

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        result = tensor[self.slices]
        if result.dim() > 1:
            # No-op when axis 1 is not of size 1.
            result = result.squeeze(dim=1)
        return result
  class NormalNoise(torch.nn.Module):
267
371
 
268
- def __init__(self, dim: Union[int, None] = None) -> None:
372
+ def __init__(self, dim: int | None = None) -> None:
269
373
  super().__init__()
270
374
  self.dim = dim
271
-
272
- def forward(self, input: torch.Tensor) -> torch.Tensor:
375
+
376
+ def forward(self, tensor: torch.Tensor) -> torch.Tensor:
273
377
  if self.dim is not None:
274
- return torch.randn(self.dim).to(input.device)
378
+ return torch.randn(self.dim).to(tensor.device)
275
379
  else:
276
- return torch.randn_like(input).to(input.device)
277
-
380
+ return torch.randn_like(tensor).to(tensor.device)
381
+
382
+
278
383
class Const(torch.nn.Module):
    """Learnable constant of fixed shape, initialized as ``std * randn(shape)``.

    The forward input is used only to pick the output device.
    """

    def __init__(self, shape: list[int], std: float) -> None:
        super().__init__()
        initial = torch.randn(shape) * std
        self.noise = torch.nn.parameter.Parameter(initial)

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        return self.noise.to(tensor.device)
  class HistogramNoise(torch.nn.Module):
288
394
 
@@ -290,80 +396,118 @@ class HistogramNoise(torch.nn.Module):
290
396
  super().__init__()
291
397
  self.x = np.linspace(0, 1, num=n, endpoint=True)
292
398
  self.sigma = sigma
293
-
294
- def forward(self, input: torch.Tensor) -> torch.Tensor:
295
- self.function = interp1d(self.x, self.x+np.random.normal(0, self.sigma, self.x.shape[0]), kind='cubic')
296
- result = torch.empty_like(input)
297
399
 
298
- for value in torch.unique(input):
400
+ def forward(self, tensor: torch.Tensor) -> torch.Tensor:
401
+ self.function = interp1d(
402
+ self.x,
403
+ self.x + np.random.normal(0, self.sigma, self.x.shape[0]),
404
+ kind="cubic",
405
+ )
406
+ result = torch.empty_like(tensor)
407
+
408
+ for value in torch.unique(tensor):
299
409
  x = self.function(value.cpu())
300
- result[torch.where(input == value)] = torch.tensor(x, device=input.device).float()
410
+ result[torch.where(tensor == value)] = torch.tensor(x, device=tensor.device).float()
301
411
  return result
302
412
 
413
+
303
414
class Subset(torch.nn.Module):
    """Crop every axis after the first two with fixed slices.

    Axes 0 and 1 (presumably batch and channel) are always kept whole;
    ``slices[k]`` applies to axis ``k + 2``.
    """

    def __init__(self, slices: list[slice]):
        super().__init__()
        # list(...) also accepts a tuple of slices.
        self.slices = [slice(None, None), slice(None, None)] + list(slices)

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        # Index with a tuple: a *list* of slices relies on the legacy
        # "sequence treated as tuple" indexing rule (already removed in NumPy).
        return tensor[tuple(self.slices)]
class View(torch.nn.Module):
    """Reshape the input with ``Tensor.view`` to a fixed target size."""

    def __init__(self, size: list[int]):
        super().__init__()
        self.size = size

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        target = self.size
        return tensor.view(target)
class LatentDistribution(network.ModuleArgsDict):
    """VAE-style latent bottleneck built from ``ModuleArgsDict`` branches.

    Layout, per the ``in_branch``/``out_branch`` arguments below:
      * input is flattened from axis 1;
      * branch 1: ``mu`` = Linear(flattened input);
      * branch 2: ``log_std`` = Linear(flattened input);
      * branch 3: ``z`` sampled from branches 1 and 2;
      * ``Concat`` joins branches 1-3 (presumably so a loss can read mu,
        log_std and z; confirm against ModuleArgsDict/loss code);
      * ``DecoderInput`` maps ``z`` back to ``shape``.

    Args:
        shape: Per-sample shape of the input (without the batch axis).
        latent_dim: Size of the latent vector.
    """

    class LatentDistributionLinear(torch.nn.Module):
        """Map flattened features to a latent vector and insert a size-1 axis at 1."""

        def __init__(self, shape: list[int], latent_dim: int) -> None:
            super().__init__()
            # Linear expects plain int feature sizes, not a 0-d tensor.
            self.linear = torch.nn.Linear(int(torch.prod(torch.tensor(shape))), latent_dim)

        def forward(self, tensor: torch.Tensor) -> torch.Tensor:
            return torch.unsqueeze(self.linear(tensor), 1)

    class LatentDistributionDecoder(torch.nn.Module):
        """Map a latent vector back to a ``(-1, *shape)`` tensor."""

        def __init__(self, shape: list[int], latent_dim: int) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(latent_dim, int(torch.prod(torch.tensor(shape))))
            self.shape = shape

        def forward(self, tensor: torch.Tensor) -> torch.Tensor:
            return self.linear(tensor).view(-1, *[int(i) for i in self.shape])

    class LatentDistributionZ(torch.nn.Module):
        """Draw ``z`` with the reparameterization trick."""

        def __init__(self) -> None:
            super().__init__()

        def forward(self, mu: torch.Tensor, log_std: torch.Tensor) -> torch.Tensor:
            # z = mu + sigma * eps with eps ~ N(0, I). torch.rand_like (uniform
            # on [0, 1)) was used previously, giving a biased, non-Gaussian sample.
            # NOTE(review): exp(x / 2) equals sigma only if x is a log-variance,
            # despite the ``log_std`` name; confirm against the KL loss.
            return torch.exp(log_std / 2) * torch.randn_like(mu) + mu

    def __init__(self, shape: list[int], latent_dim: int) -> None:
        super().__init__()
        self.add_module("Flatten", torch.nn.Flatten(1))
        self.add_module(
            "mu",
            LatentDistribution.LatentDistributionLinear(shape, latent_dim),
            out_branch=[1],
        )
        self.add_module(
            "log_std",
            LatentDistribution.LatentDistributionLinear(shape, latent_dim),
            out_branch=[2],
        )

        self.add_module(
            "z",
            LatentDistribution.LatentDistributionZ(),
            in_branch=[1, 2],
            out_branch=[3],
        )
        self.add_module("Concat", Concat(), in_branch=[1, 2, 3])
        self.add_module(
            "DecoderInput",
            LatentDistribution.LatentDistributionDecoder(shape, latent_dim),
            in_branch=[3],
        )
  class Attention(network.ModuleArgsDict):
359
490
 
360
- def __init__(self, F_g : int, F_l : int, F_int : int, dim : int):
491
+ def __init__(self, f_g: int, f_l: int, f_int: int, dim: int):
361
492
  super().__init__()
362
- self.add_module("W_x", getTorchModule("Conv", dim = dim)(in_channels = F_l, out_channels = F_int, kernel_size=1, stride=2, padding=0), in_branch=[0], out_branch=[0])
363
- self.add_module("W_g", getTorchModule("Conv", dim = dim)(in_channels = F_g, out_channels = F_int, kernel_size=1, stride=1, padding=0), in_branch=[1], out_branch=[1])
364
- self.add_module("Add", Add(), in_branch=[0,1])
493
+ self.add_module(
494
+ "W_x",
495
+ get_torch_module("Conv", dim=dim)(in_channels=f_l, out_channels=f_int, kernel_size=1, stride=2, padding=0),
496
+ in_branch=[0],
497
+ out_branch=[0],
498
+ )
499
+ self.add_module(
500
+ "W_g",
501
+ get_torch_module("Conv", dim=dim)(in_channels=f_g, out_channels=f_int, kernel_size=1, stride=1, padding=0),
502
+ in_branch=[1],
503
+ out_branch=[1],
504
+ )
505
+ self.add_module("Add", Add(), in_branch=[0, 1])
365
506
  self.add_module("ReLU", torch.nn.ReLU(inplace=True))
366
- self.add_module("Conv", getTorchModule("Conv", dim = dim)(in_channels = F_int, out_channels = 1, kernel_size=1,stride=1, padding=0))
507
+ self.add_module(
508
+ "Conv",
509
+ get_torch_module("Conv", dim=dim)(in_channels=f_int, out_channels=1, kernel_size=1, stride=1, padding=0),
510
+ )
367
511
  self.add_module("Sigmoid", torch.nn.Sigmoid())
368
512
  self.add_module("Upsample", torch.nn.Upsample(scale_factor=2))
369
- self.add_module("Multiply", Multiply(), in_branch=[2,0])
513
+ self.add_module("Multiply", Multiply(), in_branch=[2, 0])