konfai-1.1.7-py3-none-any.whl → konfai-1.1.9-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of konfai might be problematic.
- konfai/__init__.py +59 -14
- konfai/data/augmentation.py +457 -286
- konfai/data/data_manager.py +509 -290
- konfai/data/patching.py +300 -183
- konfai/data/transform.py +384 -277
- konfai/evaluator.py +309 -68
- konfai/main.py +71 -22
- konfai/metric/measure.py +341 -222
- konfai/metric/schedulers.py +24 -13
- konfai/models/classification/convNeXt.py +187 -81
- konfai/models/classification/resnet.py +272 -58
- konfai/models/generation/cStyleGan.py +233 -59
- konfai/models/generation/ddpm.py +348 -121
- konfai/models/generation/diffusionGan.py +757 -358
- konfai/models/generation/gan.py +177 -53
- konfai/models/generation/vae.py +140 -40
- konfai/models/registration/registration.py +135 -52
- konfai/models/representation/representation.py +57 -23
- konfai/models/segmentation/NestedUNet.py +339 -68
- konfai/models/segmentation/UNet.py +140 -30
- konfai/network/blocks.py +331 -187
- konfai/network/network.py +781 -423
- konfai/predictor.py +645 -240
- konfai/trainer.py +527 -216
- konfai/utils/ITK.py +191 -106
- konfai/utils/config.py +152 -95
- konfai/utils/dataset.py +326 -455
- konfai/utils/utils.py +495 -249
- {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/METADATA +1 -3
- konfai-1.1.9.dist-info/RECORD +38 -0
- konfai/utils/registration.py +0 -199
- konfai-1.1.7.dist-info/RECORD +0 -39
- {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/WHEEL +0 -0
- {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/entry_points.txt +0 -0
- {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/licenses/LICENSE +0 -0
- {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/top_level.txt +0 -0
konfai/network/blocks.py
CHANGED
@@ -1,17 +1,19 @@
+import ast
 import importlib
+from collections.abc import Callable
+from enum import Enum
+
+import numpy as np
+import SimpleITK as sitk  # noqa: N813
 import torch
 from scipy.interpolate import interp1d
-import numpy as np
-import ast
-from typing import Union

-from konfai.utils.config import config
 from konfai.network import network
+from konfai.utils.config import config
+

 class NormMode(Enum):
-    NONE = 0,
+    NONE = (0,)
     BATCH = 1
     INSTANCE = 2
     GROUP = 3
@@ -19,227 +21,327 @@ class NormMode(Enum):
     SYNCBATCH = 5
     INSTANCE_AFFINE = 6

+
+def get_norm(norm_mode: Enum, channels: int, dim: int) -> torch.nn.Module | None:
+    if norm_mode == NormMode.BATCH:
+        return get_torch_module("BatchNorm", dim=dim)(channels, affine=True, track_running_stats=True)
+    if norm_mode == NormMode.INSTANCE:
+        return get_torch_module("InstanceNorm", dim=dim)(channels, affine=False, track_running_stats=False)
+    if norm_mode == NormMode.INSTANCE_AFFINE:
+        return get_torch_module("InstanceNorm", dim=dim)(channels, affine=True, track_running_stats=False)
+    if norm_mode == NormMode.SYNCBATCH:
         return torch.nn.SyncBatchNorm(channels, affine=True, track_running_stats=True)
+    if norm_mode == NormMode.GROUP:
         return torch.nn.GroupNorm(num_groups=32, num_channels=channels)
+    if norm_mode == NormMode.LAYER:
         return torch.nn.GroupNorm(num_groups=1, num_channels=channels)
     return None

-class UpSampleMode(Enum):
-    CONV_TRANSPOSE = 0,
-    UPSAMPLE = 1,

+class UpsampleMode(Enum):
+    CONV_TRANSPOSE = (0,)
+    UPSAMPLE = (1,)
+
+
+class DownsampleMode(Enum):
+    MAXPOOL = (0,)
+    AVGPOOL = (1,)
     CONV_STRIDE = 2

-def getTorchModule(name_fonction : str, dim : Union[int, None] = None) -> torch.nn.Module:
-    return getattr(importlib.import_module("torch.nn"), "{}".format(name_fonction) + ("{}d".format(dim) if dim is not None else ""))

+def get_torch_module(name_fonction: str, dim: int | None = None) -> torch.nn.Module:
+    return getattr(
+        importlib.import_module("torch.nn"),
+        f"{name_fonction}" + (f"{dim}d" if dim is not None else ""),
+    )
+
+
+class BlockConfig:

     @config("BlockConfig")
+    def __init__(
+        self,
+        kernel_size: int = 3,
+        stride: int = 1,
+        padding: int = 1,
+        bias=True,
+        activation: str | Callable[[], torch.nn.Module] = "ReLU",
+        norm_mode: str | NormMode | Callable[[int], torch.nn.Module] = "NONE",
+    ) -> None:
         self.kernel_size = kernel_size
         self.bias = bias
         self.stride = stride
         self.padding = padding
         self.activation = activation
+        self.norm_mode = norm_mode
+        self.norm: NormMode | Callable[[int], torch.nn.Module] | None = None
+        if isinstance(norm_mode, str):
+            self.norm = NormMode[norm_mode]
+        else:
+            self.norm = norm_mode
+
+    def get_conv(self, in_channels: int, out_channels: int, dim: int) -> torch.nn.Conv3d:
+        return get_torch_module("Conv", dim=dim)(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=self.kernel_size,
+            stride=self.stride,
+            padding=self.padding,
+            bias=self.bias,
+        )
+
+    def get_norm(self, channels: int, dim: int) -> torch.nn.Module:
         if self.norm is None:
             return None
+        return get_norm(self.norm, channels, dim) if isinstance(self.norm, NormMode) else self.norm(channels)

+    def get_activation(self) -> torch.nn.Module:
         if self.activation is None:
             return None
         if isinstance(self.activation, str):
+            return (
+                get_torch_module(self.activation.split(";")[0])(
+                    *[ast.literal_eval(value) for value in self.activation.split(";")[1:]]
+                )
+                if self.activation != "None"
+                else torch.nn.Identity()
+            )
         return self.activation()
+
+
 class ConvBlock(network.ModuleArgsDict):
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        block_configs: list[BlockConfig],
+        dim: int,
+        alias: list[list[str]] = [[], [], []],
+    ) -> None:
         super().__init__()
+        for i, block_config in enumerate(block_configs):
+            self.add_module(
+                f"Conv_{i}",
+                block_config.get_conv(in_channels, out_channels, dim),
+                alias=alias[0],
+            )
+            norm = block_config.get_norm(out_channels, dim)
             if norm is not None:
+                self.add_module(f"Norm_{i}", norm, alias=alias[1])
+            activation = block_config.get_activation()
             if activation is not None:
+                self.add_module(f"Activation_{i}", activation, alias=alias[2])
             in_channels = out_channels

+
 class ResBlock(network.ModuleArgsDict):
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        block_configs: list[BlockConfig],
+        dim: int,
+        alias: list[list[str]] = [[], [], [], [], []],
+    ) -> None:
         super().__init__()
+        for i, block_config in enumerate(block_configs):
+            self.add_module(
+                f"Conv_{i}",
+                block_config.get_conv(in_channels, out_channels, dim),
+                alias=alias[0],
+            )
+            norm = block_config.get_norm(out_channels, dim)
             if norm is not None:
+                self.add_module(f"Norm_{i}", norm, alias=alias[1])
+            activation = block_config.get_activation()
             if activation is not None:
+                self.add_module(f"Activation_{i}", activation, alias=alias[2])

             if in_channels != out_channels:
+                self.add_module(
+                    "Conv_skip",
+                    get_torch_module("Conv", dim)(
+                        in_channels,
+                        out_channels,
+                        1,
+                        block_config.stride,
+                        bias=block_config.bias,
+                    ),
+                    alias=alias[3],
+                    in_branch=[1],
+                    out_branch=[1],
+                )
+                self.add_module(
+                    "Norm_skip",
+                    block_config.get_norm(out_channels, dim),
+                    alias=alias[4],
+                    in_branch=[1],
+                    out_branch=[1],
+                )
             in_channels = out_channels
-        self.add_module("Add", Add(), in_branch=[0,1])
+
+        self.add_module("Add", Add(), in_branch=[0, 1])
+        self.add_module(f"Norm_{i + 1}", torch.nn.ReLU(inplace=True))
+
+
+def downsample(in_channels: int, out_channels: int, downsample_mode: DownsampleMode, dim: int) -> torch.nn.Module:
+    if downsample_mode == DownsampleMode.MAXPOOL:
+        return get_torch_module("MaxPool", dim=dim)(2)
+    if downsample_mode == DownsampleMode.AVGPOOL:
+        return get_torch_module("AvgPool", dim=dim)(2)
+    if downsample_mode == DownsampleMode.CONV_STRIDE:
+        return get_torch_module("Conv", dim)(in_channels, out_channels, kernel_size=2, stride=2, padding=0)
+
+
+def upsample(
+    in_channels: int,
+    out_channels: int,
+    upsample_mode: UpsampleMode,
+    dim: int,
+    kernel_size: int | list[int] = 2,
+    stride: int | list[int] = 2,
+):
+    if upsample_mode == UpsampleMode.CONV_TRANSPOSE:
+        return get_torch_module("ConvTranspose", dim=dim)(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=0,
+        )
     else:
         if dim == 3:
+            upsample_method = "trilinear"
         if dim == 2:
+            upsample_method = "bilinear"
         if dim == 1:
+            upsample_method = "linear"
+        return torch.nn.Upsample(scale_factor=2, mode=upsample_method.lower(), align_corners=False)
+

 class Unsqueeze(torch.nn.Module):

     def __init__(self, dim: int = 0):
         super().__init__()
         self.dim = dim
+
+    def forward(self, *tensor: torch.Tensor) -> torch.Tensor:
+        return torch.unsqueeze(tensor, self.dim)
+
     def extra_repr(self):
+        return f"dim={self.dim}"
+

 class Permute(torch.nn.Module):

+    def __init__(self, dims: list[int]):
         super().__init__()
         self.dims = dims

+    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
+        return torch.permute(tensor, self.dims)
+
     def extra_repr(self):
+        return f"dims={self.dims}"
+

 class ToChannels(Permute):

-    def __init__(self, dim):
-        super().__init__([0, dim+1, *[i+1 for i in range(dim)]])
+    def __init__(self, dim: int):
+        super().__init__([0, dim + 1, *[i + 1 for i in range(dim)]])
+
+
 class ToFeatures(Permute):

-    def __init__(self, dim):
-        super().__init__([0, *[i+2 for i in range(dim)], 1])
+    def __init__(self, dim: int):
+        super().__init__([0, *[i + 2 for i in range(dim)], 1])
+

 class Add(torch.nn.Module):

     def __init__(self) -> None:
         super().__init__()
+
+    def forward(self, *tensor: torch.Tensor) -> torch.Tensor:
+        return torch.sum(torch.stack(tensor), dim=0)
+
+
 class Multiply(torch.nn.Module):

     def __init__(self) -> None:
         super().__init__()

+    def forward(self, *tensor: torch.Tensor) -> torch.Tensor:
+        return torch.mul(*tensor)
+

 class Concat(torch.nn.Module):

     def __init__(self) -> None:
         super().__init__()

+    def forward(self, *tensor: torch.Tensor) -> torch.Tensor:
+        return torch.cat(tensor, dim=1)
+

 class Print(torch.nn.Module):

     def __init__(self) -> None:
         super().__init__()
+
+    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
+        print(tensor.shape)
+        return tensor
+

 class Write(torch.nn.Module):

     def __init__(self) -> None:
         super().__init__()
+
+    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
+
+        sitk.WriteImage(sitk.GetImageFromArray(tensor.clone()[0][0].cpu().numpy()), "./Data.mha")
+        return tensor
+
+
 class Exit(torch.nn.Module):

     def __init__(self) -> None:
         super().__init__()
+
+    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
         exit(0)
+
+
 class Detach(torch.nn.Module):

     def __init__(self) -> None:
         super().__init__()
+
+    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
+        return tensor.detach()
+
+
 class Negative(torch.nn.Module):

     def __init__(self) -> None:
         super().__init__()
+
+    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
+        return -tensor
+

 class GetShape(torch.nn.Module):

     def __init__(self) -> None:
         super().__init__()

+    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
+        return torch.tensor(tensor.shape)
+

 class ArgMax(torch.nn.Module):

@@ -247,42 +349,46 @@ class ArgMax(torch.nn.Module):
         super().__init__()
         self.dim = dim

+    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
+        return torch.argmax(tensor, dim=self.dim).unsqueeze(self.dim)
+
+
 class Select(torch.nn.Module):

     def __init__(self, slices: list[slice]) -> None:
         super().__init__()
         self.slices = tuple(slices)

+    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
+        result = tensor[self.slices]
         for i, s in enumerate(range(len(result.shape))):
             if s == 1:
+                result = result.squeeze(dim=i)
         return result

+
 class NormalNoise(torch.nn.Module):

+    def __init__(self, dim: int | None = None) -> None:
         super().__init__()
         self.dim = dim
+
+    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
         if self.dim is not None:
+            return torch.randn(self.dim).to(tensor.device)
         else:
+            return torch.randn_like(tensor).to(tensor.device)
+
+
 class Const(torch.nn.Module):

     def __init__(self, shape: list[int], std: float) -> None:
         super().__init__()
-        self.noise = torch.nn.parameter.Parameter(torch.randn(shape)*std)
+        self.noise = torch.nn.parameter.Parameter(torch.randn(shape) * std)
+
+    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
+        return self.noise.to(tensor.device)
+

 class HistogramNoise(torch.nn.Module):

@@ -290,80 +396,118 @@ class HistogramNoise(torch.nn.Module):
         super().__init__()
         self.x = np.linspace(0, 1, num=n, endpoint=True)
         self.sigma = sigma
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        self.function = interp1d(self.x, self.x+np.random.normal(0, self.sigma, self.x.shape[0]), kind='cubic')
-        result = torch.empty_like(input)

+    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
+        self.function = interp1d(
+            self.x,
+            self.x + np.random.normal(0, self.sigma, self.x.shape[0]),
+            kind="cubic",
+        )
+        result = torch.empty_like(tensor)
+
+        for value in torch.unique(tensor):
             x = self.function(value.cpu())
+            result[torch.where(tensor == value)] = torch.tensor(x, device=tensor.device).float()
         return result

+
 class Subset(torch.nn.Module):
+    def __init__(self, slices: list[slice]):
+        super().__init__()
+        self.slices = [slice(None, None), slice(None, None)] + slices
+
+    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
+        return tensor[self.slices]

-    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
-        return tensor[self.slices]

 class View(torch.nn.Module):
+    def __init__(self, size: list[int]):
+        super().__init__()
+        self.size = size
+
+    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
+        return tensor.view(self.size)
+

-    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
-        return tensor.view(self.size)
 class LatentDistribution(network.ModuleArgsDict):

+    class LatentDistributionLinear(torch.nn.Module):

+        def __init__(self, shape: list[int], latent_dim: int) -> None:
             super().__init__()
+            self.linear = torch.nn.Linear(torch.prod(torch.tensor(shape)), latent_dim)
+
+        def forward(self, tensor: torch.Tensor) -> torch.Tensor:
+            return torch.unsqueeze(self.linear(tensor), 1)
+
+    class LatentDistributionDecoder(torch.nn.Module):
+
+        def __init__(self, shape: list[int], latent_dim: int) -> None:
             super().__init__()
+            self.linear = torch.nn.Linear(latent_dim, torch.prod(torch.tensor(shape)))
             self.shape = shape

+        def forward(self, tensor: torch.Tensor) -> torch.Tensor:
+            return self.linear(tensor).view(-1, *[int(i) for i in self.shape])
+
+    class LatentDistributionZ(torch.nn.Module):

         def __init__(self) -> None:
             super().__init__()

         def forward(self, mu: torch.Tensor, log_std: torch.Tensor) -> torch.Tensor:
-            return torch.exp(log_std/2)*torch.rand_like(mu)+mu
-        super().__init__()
+            return torch.exp(log_std / 2) * torch.rand_like(mu) + mu
+
+    def __init__(self, shape: list[int], latent_dim: int) -> None:
+        super().__init__()
         self.add_module("Flatten", torch.nn.Flatten(1))
+        self.add_module(
+            "mu",
+            LatentDistribution.LatentDistributionLinear(shape, latent_dim),
+            out_branch=[1],
+        )
+        self.add_module(
+            "log_std",
+            LatentDistribution.LatentDistributionLinear(shape, latent_dim),
+            out_branch=[2],
+        )
+
+        self.add_module(
+            "z",
+            LatentDistribution.LatentDistributionZ(),
+            in_branch=[1, 2],
+            out_branch=[3],
+        )
+        self.add_module("Concat", Concat(), in_branch=[1, 2, 3])
+        self.add_module(
+            "DecoderInput",
+            LatentDistribution.LatentDistributionDecoder(shape, latent_dim),
+            in_branch=[3],
+        )

-        self.add_module("z", LatentDistribution.LatentDistribution_Z(), in_branch=[1,2], out_branch=[3])
-        self.add_module("Concat", Concat(), in_branch=[1,2,3])
-        self.add_module("DecoderInput", LatentDistribution.LatentDistribution_Decoder(shape, latentDim), in_branch=[3])

 class Attention(network.ModuleArgsDict):

+    def __init__(self, f_g: int, f_l: int, f_int: int, dim: int):
         super().__init__()
+        self.add_module(
+            "W_x",
+            get_torch_module("Conv", dim=dim)(in_channels=f_l, out_channels=f_int, kernel_size=1, stride=2, padding=0),
+            in_branch=[0],
+            out_branch=[0],
+        )
+        self.add_module(
+            "W_g",
+            get_torch_module("Conv", dim=dim)(in_channels=f_g, out_channels=f_int, kernel_size=1, stride=1, padding=0),
+            in_branch=[1],
+            out_branch=[1],
+        )
+        self.add_module("Add", Add(), in_branch=[0, 1])
         self.add_module("ReLU", torch.nn.ReLU(inplace=True))
+        self.add_module(
+            "Conv",
+            get_torch_module("Conv", dim=dim)(in_channels=f_int, out_channels=1, kernel_size=1, stride=1, padding=0),
+        )
         self.add_module("Sigmoid", torch.nn.Sigmoid())
         self.add_module("Upsample", torch.nn.Upsample(scale_factor=2))
-        self.add_module("Multiply", Multiply(), in_branch=[2,0])
+        self.add_module("Multiply", Multiply(), in_branch=[2, 0])