konfai-1.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of konfai might be problematic.
- konfai/__init__.py +16 -0
- konfai/data/HDF5.py +326 -0
- konfai/data/__init__.py +0 -0
- konfai/data/augmentation.py +597 -0
- konfai/data/dataset.py +470 -0
- konfai/data/transform.py +536 -0
- konfai/evaluator.py +146 -0
- konfai/main.py +43 -0
- konfai/metric/__init__.py +0 -0
- konfai/metric/measure.py +488 -0
- konfai/metric/schedulers.py +49 -0
- konfai/models/classification/convNeXt.py +175 -0
- konfai/models/classification/resnet.py +116 -0
- konfai/models/generation/cStyleGan.py +137 -0
- konfai/models/generation/ddpm.py +218 -0
- konfai/models/generation/diffusionGan.py +557 -0
- konfai/models/generation/gan.py +134 -0
- konfai/models/generation/vae.py +72 -0
- konfai/models/registration/registration.py +136 -0
- konfai/models/representation/representation.py +57 -0
- konfai/models/segmentation/NestedUNet.py +53 -0
- konfai/models/segmentation/UNet.py +58 -0
- konfai/network/__init__.py +0 -0
- konfai/network/blocks.py +348 -0
- konfai/network/network.py +950 -0
- konfai/predictor.py +366 -0
- konfai/trainer.py +330 -0
- konfai/utils/ITK.py +269 -0
- konfai/utils/Registration.py +199 -0
- konfai/utils/__init__.py +0 -0
- konfai/utils/config.py +218 -0
- konfai/utils/dataset.py +764 -0
- konfai/utils/utils.py +493 -0
- konfai-1.0.0.dist-info/METADATA +68 -0
- konfai-1.0.0.dist-info/RECORD +39 -0
- konfai-1.0.0.dist-info/WHEEL +5 -0
- konfai-1.0.0.dist-info/entry_points.txt +3 -0
- konfai-1.0.0.dist-info/licenses/LICENSE +201 -0
- konfai-1.0.0.dist-info/top_level.txt +1 -0
konfai/models/classification/convNeXt.py
@@ -0,0 +1,175 @@
+import torch
+import torch.nn.functional as F
+from KonfAI.konfai.network import network, blocks
+from KonfAI.konfai.utils.config import config
+from KonfAI.konfai.data.HDF5 import ModelPatch
+
+"""
+"convnext_tiny_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth", depths=[3, 3, 9, 3], dims=[96, 192, 384, 768]
+"convnext_small_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth", depths=[3, 3, 27, 3], dims=[96, 192, 384, 768]
+"convnext_base_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth", depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024]
+"convnext_large_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth", depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536]
+"convnext_tiny_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth",
+"convnext_small_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth",
+"convnext_base_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth",
+"convnext_large_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth",
+"convnext_xlarge_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth", depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048]
+"""
+
+class LayerNorm(torch.nn.Module):
+
+    def __init__(self, normalized_shape : int, eps : float = 1e-6, data_format : str="channels_last"):
+        super().__init__()
+        self.weight = torch.nn.parameter.Parameter(torch.ones(normalized_shape))
+        self.bias = torch.nn.parameter.Parameter(torch.zeros(normalized_shape))
+        self.eps = eps
+        self.data_format = data_format
+        if self.data_format not in ["channels_last", "channels_first"]:
+            raise NotImplementedError
+        self.normalized_shape = (normalized_shape, )
+
+    def forward(self, x):
+        if self.data_format == "channels_last":
+            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+        elif self.data_format == "channels_first":
+            u = x.mean(1, keepdim=True)
+            s = (x - u).pow(2).mean(1, keepdim=True)
+            x = (x - u) / torch.sqrt(s + self.eps)
+            if len(x.shape) == 3:
+                x = self.weight[:, None] * x + self.bias[:, None]
+            elif len(x.shape) == 4:
+                x = self.weight[:, None, None] * x + self.bias[:, None, None]
+            else:
+                x = self.weight[:, None, None, None] * x + self.bias[:, None, None, None]
+            return x
+
+    def extra_repr(self):
+        return "normalized_shape={}, eps={}, data_format={}".format(self.normalized_shape, self.eps, self.data_format)
+
+class DropPath(torch.nn.Module):
+
+    def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+        self.scale_by_keep = scale_by_keep
+
+    def forward(self, x):
+        if self.drop_prob == 0. or not self.training:
+            return x
+        keep_prob = 1 - self.drop_prob
+        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
+        random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+        if keep_prob > 0.0 and self.scale_by_keep:
+            random_tensor.div_(keep_prob)
+        return x * random_tensor
+
+    def extra_repr(self):
+        return "drop_prob={}".format(round(self.drop_prob,3))
+
+class LayerScaler(torch.nn.Module):
+
+    def __init__(self, init_value : float, dimensions : int):
+        super().__init__()
+        self.init_value = init_value
+        self.gamma = torch.nn.Parameter(torch.ones(dimensions, 1, 1) * init_value)
+
+    def forward(self, input : torch.Tensor) -> torch.Tensor:
+        return self.gamma * input
+
+    def extra_repr(self):
+        return "init_value={}".format(self.init_value)
+
+class BottleNeckBlock(network.ModuleArgsDict):
+    def __init__( self,
+                  features: int,
+                  drop_p: float,
+                  layer_scaler_init_value: float,
+                  dim: int):
+        super().__init__()
+        self.add_module("Conv_0", blocks.getTorchModule("Conv", dim)(features, features, kernel_size=7, padding=3, groups=features), alias=["dwconv"])
+        self.add_module("ToFeatures", blocks.ToFeatures(dim))
+        self.add_module("LayerNorm", LayerNorm(features, eps=1e-6), alias=["norm"])
+        self.add_module("Linear_1", torch.nn.Linear(features, features * 4), alias=["pwconv1"])
+        self.add_module("GELU", torch.nn.GELU())
+        self.add_module("Linear_2", torch.nn.Linear(features * 4, features), alias=["pwconv2"])
+        self.add_module("ToChannels", blocks.ToChannels(dim))
+        self.add_module("LayerScaler", LayerScaler(init_value=layer_scaler_init_value, dimensions=features), alias=[""])
+        self.add_module("StochasticDepth", DropPath(drop_p))
+        self.add_module("Residual", blocks.Add(), in_branch=[0,1])
+
+class DownSample(network.ModuleArgsDict):
+
+    def __init__( self,
+                  in_features: int,
+                  out_features: int,
+                  dim : int):
+        super().__init__()
+        self.add_module("LayerNorm", LayerNorm(in_features, eps=1e-6, data_format="channels_first"), alias=["0"])
+        self.add_module("Conv", blocks.getTorchModule("Conv", dim)(in_features, out_features, kernel_size=2, stride=2), alias=["1"])
+
+class ConvNexStage(network.ModuleArgsDict):
+
+    def __init__( self,
+                  features: int,
+                  depth: int,
+                  drop_p: list[float],
+                  dim : int):
+        super().__init__()
+        for i in range(depth):
+            self.add_module("BottleNeckBlock_{}".format(i), BottleNeckBlock(features=features, drop_p=drop_p[i], layer_scaler_init_value=1e-6, dim=dim), alias=["{}".format(i)])
+
+class ConvNextStem(network.ModuleArgsDict):
+
+    def __init__(self, in_features: int, out_features: int, dim: int):
+        super().__init__()
+        self.add_module("Conv", blocks.getTorchModule("Conv", dim)(in_features, out_features, kernel_size=4, stride=4), alias=["0"])
+        self.add_module("LayerNorm", LayerNorm(out_features, eps=1e-6, data_format="channels_first"), alias=["1"])
+
+class ConvNextEncoder(network.ModuleArgsDict):
+
+    def __init__( self,
+                  in_channels: int,
+                  depths: list[int],
+                  widths: list[int],
+                  drop_p: float,
+                  dim : int):
+        super().__init__()
+        self.add_module("ConvNextStem", ConvNextStem(in_channels, widths[0], dim=dim), alias=["downsample_layers.0"])
+
+        drop_probs = [x.item() for x in torch.linspace(0, drop_p, sum(depths))]
+        self.add_module("ConvNexStage_0", ConvNexStage(features=widths[0], depth=depths[0], drop_p=drop_probs[:depths[0]], dim=dim), alias=["stages.0"])
+
+        for i, (in_features, out_features) in enumerate(list(zip(widths[:], widths[1:]))):
+            self.add_module("DownSample_{}".format(i+1), DownSample(in_features=in_features, out_features=out_features, dim=dim), alias=["downsample_layers.{}".format(i+1)])
+            self.add_module("ConvNexStage_{}".format(i+1), ConvNexStage(features=out_features, depth=depths[i+1], drop_p=drop_probs[sum(depths[:i+1]):sum(depths[:i+2])], dim=dim), alias=["stages.{}".format(i+1)])
+
+class Head(network.ModuleArgsDict):
+
+    def __init__(self, in_features : int, num_classes : list[int], dim : int) -> None:
+        super().__init__()
+        self.add_module("AdaptiveAvgPool", blocks.getTorchModule("AdaptiveAvgPool", dim)(tuple([1]*dim)))
+        self.add_module("Flatten", torch.nn.Flatten(1))
+        self.add_module("LayerNorm", torch.nn.LayerNorm(in_features, eps=1e-6), alias=["norm"])
+
+        for i, nb_classe in enumerate(num_classes):
+            self.add_module("Linear_{}".format(i), torch.nn.Linear(in_features, nb_classe), pretrained=False, alias=["head"], out_branch=[i+1])
+            self.add_module("Unsqueeze_{}".format(i), blocks.Unsqueeze(2), in_branch=[i+1], out_branch=[-1])
+
+class ConvNeXt(network.Network):
+
+    @config("ConvNeXt")
+    def __init__( self,
+                  optimizer : network.OptimizerLoader = network.OptimizerLoader(),
+                  schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
+                  outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
+                  patch : ModelPatch = ModelPatch(),
+                  dim : int = 3,
+                  in_channels: int = 1,
+                  depths: list[int] = [3,3,27,3],
+                  widths: list[int] = [128, 256, 512, 1024],
+                  drop_p: float = 0.1,
+                  num_classes: list[int] = [4, 7]):
+
+        super().__init__(in_channels = in_channels, optimizer = optimizer, schedulers = schedulers, outputsCriterions = outputsCriterions, dim = dim, patch=patch, init_type = "trunc_normal", init_gain=0.02)
+        self.add_module("ConvNextEncoder", ConvNextEncoder(in_channels=in_channels, depths=depths, widths=widths, drop_p=drop_p, dim=dim))
+        self.add_module("Head", Head(in_features=widths[-1], num_classes=num_classes, dim=dim))
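For context on the DropPath module in the diff above: it implements stochastic depth, zeroing each sample's residual branch with probability drop_prob during training and rescaling the survivors by 1/keep_prob so the expected activation is unchanged. A minimal standalone sketch of that behavior, using only plain PyTorch (the variable names are illustrative, not konfai API):

import torch

# Illustrative sketch, not part of the konfai package.
p = 0.25                                    # drop probability, as in DropPath(drop_p)
x = torch.ones(100_000, 8)                  # large batch so the mean estimate is stable

# One mask entry per sample, kept with probability 1 - p, then rescaled by
# 1/(1 - p), exactly what DropPath.forward does in training mode.
mask = torch.empty(x.shape[0], 1).bernoulli_(1 - p).div_(1 - p)

out = x * mask
print(out.mean().item())                            # ~1.0: expectation preserved
print((out == 0).all(dim=1).float().mean().item())  # ~0.25: dropped-sample fraction

At evaluation time the module is the identity, which is why forward returns x unchanged when not self.training.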
konfai/models/classification/resnet.py
@@ -0,0 +1,116 @@
+from abc import ABC
+from typing import Type
+import torch
+from KonfAI.konfai.network import network, blocks
+from KonfAI.konfai.utils.config import config
+from KonfAI.konfai.data.HDF5 import ModelPatch
+
+"""
+'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', dim = 2, in_channels = 3, depths=[2, 2, 2, 2], widths = [64, 64, 128, 256, 512], num_classes=1000, useBottleneck=False
+
+'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', dim = 2, in_channels = 3, depths=[3, 4, 6, 3], widths = [64, 64, 128, 256, 512], num_classes=1000, useBottleneck=False
+
+'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', dim = 2, in_channels = 3, depths=[3, 4, 6, 3], widths = [64, 256, 512, 1024, 2048], num_classes=1000, useBottleneck=True
+
+'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', dim = 2, in_channels = 3, depths=[3, 4, 23, 3], widths = [64, 256, 512, 1024, 2048], num_classes=1000, useBottleneck=True
+
+'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', dim = 2, in_channels = 3, depths=[3, 8, 36, 3], widths = [64, 256, 512, 1024, 2048], num_classes=1000, useBottleneck=True
+
+'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
+'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
+'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
+'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
+"""
+class AbstractResBlock(network.ModuleArgsDict, ABC):
+
+    def __init__(self, in_channels: int, out_channels: int, downsample: bool, dim: int):
+        super().__init__()
+
+class ResBlock(AbstractResBlock):
+
+    def __init__(self, in_channels: int, out_channels: int, downsample: bool, dim: int):
+        super().__init__(in_channels, out_channels, downsample, dim)
+        if downsample:
+            self.add_module("Shortcut", blocks.ConvBlock(in_channels, out_channels, 1, blocks.BlockConfig(kernel_size=1, stride=2, padding=0, bias=False, activation="None", normMode=blocks.NormMode.BATCH.name), dim=dim, alias=[["0"], ["1"], []]), in_branch=[1], out_branch=[1], alias=["downsample"])
+
+        self.add_module("ConvBlock_0", blocks.ConvBlock(in_channels, out_channels, 1, blocks.BlockConfig(kernel_size=3, stride=2 if downsample else 1, padding=1, bias=False, activation="ReLU", normMode=blocks.NormMode.BATCH.name), dim=dim, alias=[["conv1"], ["bn1"], []]))
+        self.add_module("ConvBlock_1", blocks.ConvBlock(out_channels, out_channels, 1, blocks.BlockConfig(kernel_size=3, stride=1, padding=1, bias=False, activation="None", normMode=blocks.NormMode.BATCH.name), dim=dim, alias=[["conv2"], ["bn2"], []]))
+        self.add_module("Residual", blocks.Add(), in_branch=[0,1])
+        self.add_module("ReLU", torch.nn.ReLU())
+
+class ResBottleneckBlock(AbstractResBlock):
+    def __init__(self, in_channels: int, out_channels: int, downsample: bool, dim: int):
+        super().__init__(in_channels, out_channels, downsample, dim)
+        self.add_module("ConvBlock_0", blocks.ConvBlock(in_channels, out_channels//4, 1, blocks.BlockConfig(kernel_size=1, stride=1, padding=0, bias=False, activation="ReLU", normMode=blocks.NormMode.BATCH.name), dim=dim, alias=[["conv1"], ["bn1"], []]))
+        self.add_module("ConvBlock_1", blocks.ConvBlock(out_channels//4, out_channels//4, 1, blocks.BlockConfig(kernel_size=3, stride=2 if downsample else 1, padding=1, bias=False, activation="ReLU", normMode=blocks.NormMode.BATCH.name), dim=dim, alias=[["conv2"], ["bn2"], []]))
+        self.add_module("ConvBlock_2", blocks.ConvBlock(out_channels//4, out_channels, 1, blocks.BlockConfig(kernel_size=1, stride=1, padding=0, bias=False, activation="ReLU", normMode=blocks.NormMode.BATCH.name), dim=dim, alias=[["conv3"], ["bn3"], []]))
+
+        if downsample or in_channels != out_channels:
+            self.add_module("Shortcut", blocks.ConvBlock(in_channels, out_channels, 1, blocks.BlockConfig(kernel_size=1, stride=2 if downsample else 1, padding=0, bias=False, activation="None", normMode=blocks.NormMode.BATCH.name), dim=dim, alias=[["0"], ["1"], []]), in_branch=[1], out_branch=[1], alias=["downsample"])
+
+        self.add_module("Residual", blocks.Add(), in_branch=[0,1])
+        self.add_module("ReLU", torch.nn.ReLU())
+
+class ResNetStage(network.ModuleArgsDict):
+
+    def __init__(self,
+                 in_channels: int,
+                 out_channels: int,
+                 depth: int,
+                 block : Type[AbstractResBlock],
+                 downsample : bool,
+                 dim : int):
+        super().__init__()
+        self.add_module("BottleNeckBlock_0", block(in_channels=in_channels, out_channels=out_channels, downsample=downsample, dim=dim), alias=["0"])
+        for i in range(1, depth):
+            self.add_module("BottleNeckBlock_{}".format(i), block(in_channels=out_channels, out_channels=out_channels, downsample=False, dim=dim), alias=["{}".format(i)])
+
+class ResNetStem(network.ModuleArgsDict):
+
+    def __init__(self, in_channels: int, out_features: int, dim: int):
+        super().__init__()
+        self.add_module("ConvBlock", blocks.ConvBlock(in_channels, out_features, 1, blocks.BlockConfig(kernel_size=7, stride=2, padding=3, bias=False, activation="ReLU", normMode=blocks.NormMode.BATCH.name), dim=dim, alias=[["conv1"], ["bn1"], []]))
+        self.add_module("MaxPool", blocks.getTorchModule("MaxPool", dim)(kernel_size=3, stride=2, padding=1))
+
+class ResNetEncoder(network.ModuleArgsDict):
+
+    def __init__(self,
+                 in_channels: int,
+                 depths: list[int],
+                 widths: list[int],
+                 useBottleneck : bool,
+                 dim : int):
+        super().__init__()
+        self.add_module("ResNetStem", ResNetStem(in_channels, widths[0], dim=dim))
+
+        for i, (in_channels, out_channels, depth) in enumerate(list(zip(widths[:], widths[1:], depths))):
+            self.add_module("ResNetStage_{}".format(i), ResNetStage(in_channels=in_channels, out_channels=out_channels, depth=depth, block=ResBottleneckBlock if useBottleneck else ResBlock, downsample=i!=0, dim=dim), alias=["layer{}".format(i+1)])
+
+class Head(network.ModuleArgsDict):
+
+    def __init__(self, in_features : int, num_classes : int, dim : int):
+        super().__init__()
+        self.add_module("AdaptiveAvgPool", blocks.getTorchModule("AdaptiveAvgPool", dim)(1))
+        self.add_module("Flatten", torch.nn.Flatten(1))
+        self.add_module("Linear", torch.nn.Linear(in_features, num_classes), pretrained=False, alias=["fc"])
+        self.add_module("Unsqueeze", blocks.Unsqueeze(2))
+
+
+class ResNet(network.Network):
+
+    @config("ResNet")
+    def __init__( self,
+                  optimizer : network.OptimizerLoader = network.OptimizerLoader(),
+                  schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
+                  outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
+                  patch : ModelPatch = ModelPatch(),
+                  dim : int = 3,
+                  in_channels: int = 1,
+                  depths: list[int] = [2, 2, 2, 2],
+                  widths: list[int] = [64, 64, 128, 256, 512],
+                  num_classes: int = 10,
+                  useBottleneck=False):
+        super().__init__(in_channels = in_channels, optimizer = optimizer, schedulers = schedulers, outputsCriterions = outputsCriterions, dim = dim, patch=patch, init_type = "trunc_normal", init_gain=0.02)
+        self.add_module("ResNetEncoder", ResNetEncoder(in_channels=in_channels, depths=depths, widths=widths, useBottleneck=useBottleneck, dim=dim))
+        self.add_module("Head", Head(in_features=widths[-1], num_classes=num_classes, dim=dim))
+
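One detail worth noting in ResBottleneckBlock above: the identity path gets a projection ("Shortcut") whenever the block downsamples or changes channel count, the same rule torchvision's ResNet uses. A plain-PyTorch sketch of that rule for the 2D case (the helper name make_shortcut is hypothetical, not konfai API):

import torch

# Illustrative sketch, not part of the konfai package.
def make_shortcut(in_ch: int, out_ch: int, downsample: bool) -> torch.nn.Module:
    # Mirror of the condition in ResBottleneckBlock.__init__: project the
    # identity with a strided 1x1 conv + BatchNorm whenever the spatial size
    # (downsample) or the channel count changes.
    if downsample or in_ch != out_ch:
        return torch.nn.Sequential(
            torch.nn.Conv2d(in_ch, out_ch, kernel_size=1,
                            stride=2 if downsample else 1, bias=False),
            torch.nn.BatchNorm2d(out_ch),
        )
    return torch.nn.Identity()  # shapes already match: plain residual add

x = torch.randn(2, 64, 56, 56)
print(make_shortcut(64, 256, True)(x).shape)   # torch.Size([2, 256, 28, 28])
print(make_shortcut(64, 64, False)(x).shape)   # torch.Size([2, 64, 56, 56])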
konfai/models/generation/cStyleGan.py
@@ -0,0 +1,137 @@
+import importlib
+import torch
+from KonfAI.konfai.network import network, blocks
+from KonfAI.konfai.utils.config import config
+from KonfAI.konfai.data.HDF5 import ModelPatch
+class MappingNetwork(network.ModuleArgsDict):
+    def __init__(self, z_dim: int, c_dim: int, w_dim: int, num_layers: int, embed_features: int, layer_features: int):
+        super().__init__()
+
+        self.add_module("Concat_1", blocks.Concat(), in_branch=[0,1])
+
+        features = [z_dim + embed_features if c_dim > 0 else 0] + [layer_features] * (num_layers - 1) + [w_dim]
+        if c_dim > 0:
+            self.add_module("Linear", torch.nn.Linear(c_dim, embed_features), out_branch=["Embed"])
+
+        self.add_module("Noise", blocks.NormalNoise(z_dim), in_branch=["Embed"])
+        if c_dim > 0:
+            self.add_module("Concat", blocks.Concat(), in_branch=[0,"Embed"])
+
+        for i, (in_features, out_features) in enumerate(zip(features, features[1:])):
+            self.add_module("Linear_{}".format(i), torch.nn.Linear(in_features, out_features))
+
+class ModulatedConv(torch.nn.Module):
+
+    class _ModulatedConv(torch.nn.Module):
+
+        def __init__(self, w_dim: int, conv: torch.nn.modules.conv._ConvNd, dim: int) -> None:
+            super().__init__()
+            self.affine = torch.nn.Linear(w_dim, conv.in_channels)
+            self.isConv = True
+            self.in_channels = conv.in_channels
+            self.out_channels = conv.out_channels
+            self.padding = conv.padding
+            self.stride = conv.stride
+            if isinstance(conv, torch.nn.modules.conv._ConvTransposeNd):
+                self.weight = torch.nn.parameter.Parameter(torch.randn((conv.in_channels, conv.out_channels, *conv.kernel_size)))
+                self.isConv = False
+            else:
+                self.weight = torch.nn.parameter.Parameter(torch.randn((conv.out_channels, conv.in_channels, *conv.kernel_size)))
+            conv.forward = self.forward
+            self.styles = None
+            self.dim = dim
+
+        def setStyle(self, styles: torch.Tensor) -> None:
+            self.styles = styles
+
+        def forward(self, input: torch.Tensor) -> torch.Tensor:
+            b = input.shape[0]
+            self.affine.to(input.device)
+            styles = self.affine(self.styles)
+            w1 = styles.reshape(b, -1, 1, *[1 for _ in range(self.dim)]) if not self.isConv else styles.reshape(b, 1, -1, *[1 for _ in range(self.dim)])
+            w2 = self.weight.unsqueeze(0).to(input.device)
+            weights = w2 * (w1 + 1)
+
+            d = torch.rsqrt((weights ** 2).sum(dim=tuple([i+2 for i in range(len(weights.shape)-2)]), keepdim=True) + 1e-8)
+            weights = weights * d
+
+            input = input.reshape(1, -1, *input.shape[2:])
+
+            _, _, *ws = weights.shape
+            if not self.isConv:
+                out = getattr(importlib.import_module("torch.nn.functional"), "conv_transpose{}d".format(self.dim))(input, weights.reshape(b * self.in_channels, *ws), stride=self.stride, padding=self.padding, groups=b)
+            else:
+                out = getattr(importlib.import_module("torch.nn.functional"), "conv{}d".format(self.dim))(input, weights.reshape(b * self.out_channels, *ws), padding=self.padding, groups=b, stride=self.stride)
+
+            out = out.reshape(-1, self.out_channels, *out.shape[2:])
+            return out
+
+    def __init__(self, w_dim: int, module: torch.nn.Module) -> None:
+        super().__init__()
+        self.w_dim = w_dim
+        self.module = module
+        self.convs = torch.nn.ModuleList()
+        self.module.apply(self.apply)
+
+    def forward(self, input: torch.Tensor, styles: torch.Tensor) -> torch.Tensor:
+        for conv in self.convs:
+            conv.setStyle(styles.clone())
+        return self.module(input)
+
+    def apply(self, module: torch.nn.Module):
+        if isinstance(module, torch.nn.modules.conv._ConvNd):
+            delattr(module, "weight")
+            module.bias = None
+
+            str_dim = module.__class__.__name__[-2:]
+            dim = 1
+            if str_dim == "2d":
+                dim = 2
+            elif str_dim == "3d":
+                dim = 3
+            self.convs.append(ModulatedConv._ModulatedConv(self.w_dim, module, dim=dim))
+
+class UNetBlock(network.ModuleArgsDict):
+
+    def __init__(self, w_dim: int, channels: list[int], nb_conv_per_stage: int, blockConfig: blocks.BlockConfig, downSampleMode: blocks.DownSampleMode, upSampleMode: blocks.UpSampleMode, attention : bool, dim: int, i : int = 0) -> None:
+        super().__init__()
+        if i > 0:
+            self.add_module(downSampleMode.name, blocks.downSample(in_channels=channels[0], out_channels=channels[1], downSampleMode=downSampleMode, dim=dim))
+        self.add_module("DownConvBlock", blocks.ConvBlock(in_channels=channels[1 if downSampleMode == blocks.DownSampleMode.CONV_STRIDE and i > 0 else 0], out_channels=channels[1], nb_conv=nb_conv_per_stage, blockConfig=blockConfig, dim=dim))
+        if len(channels) > 2:
+            self.add_module("UNetBlock_{}".format(i+1), UNetBlock(w_dim, channels[1:], nb_conv_per_stage, blockConfig, downSampleMode, upSampleMode, attention, dim, i+1), in_branch=[0,1])
+            self.add_module("UpConvBlock", ModulatedConv(w_dim, blocks.ConvBlock((channels[1]+channels[2]) if upSampleMode != blocks.UpSampleMode.CONV_TRANSPOSE else channels[1]*2, out_channels=channels[1], nb_conv=nb_conv_per_stage, blockConfig=blockConfig, dim=dim)), in_branch=[0,1])
+        if i > 0:
+            if attention:
+                self.add_module("Attention", blocks.Attention(F_g=channels[1], F_l=channels[0], F_int=channels[0], dim=dim), in_branch=["Skip", 0], out_branch=["Skip"])
+            self.add_module(upSampleMode.name, ModulatedConv(w_dim, blocks.upSample(in_channels=channels[1], out_channels=channels[0], upSampleMode=upSampleMode, dim=dim)), in_branch=[0, 1])
+            self.add_module("SkipConnection", blocks.Concat(), in_branch=[0, "Skip"])
+
+class Generator(network.Network):
+
+    class GeneratorHead(network.ModuleArgsDict):
+
+        def __init__(self, in_channels: int, out_channels: int, dim: int) -> None:
+            super().__init__()
+            self.add_module("Conv", blocks.getTorchModule("Conv", dim)(in_channels = in_channels, out_channels = out_channels, kernel_size = 3, stride = 1, padding = 1))
+            self.add_module("Tanh", torch.nn.Tanh())
+
+    @config("Generator")
+    def __init__(self,
+                 optimizer : network.OptimizerLoader = network.OptimizerLoader(),
+                 schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
+                 patch : ModelPatch = ModelPatch(),
+                 outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
+                 channels: list[int]=[1, 64, 128, 256, 512, 1024],
+                 nb_batch_per_step: int = 64,
+                 z_dim: int = 512,
+                 c_dim: int = 1,
+                 w_dim: int = 512,
+                 dim : int = 3) -> None:
+        super().__init__(optimizer=optimizer, in_channels=channels[0], schedulers=schedulers, patch=patch, outputsCriterions=outputsCriterions, dim=dim, nb_batch_per_step=nb_batch_per_step)
+
+        self.add_module("MappingNetwork", MappingNetwork(z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_layers=8, embed_features=w_dim, layer_features=w_dim), in_branch=[1,2], out_branch=["Style"])
+        nb_conv_per_stage = 2
+        blockConfig = blocks.BlockConfig(kernel_size=3, stride=1, padding=1, bias=True, activation="ReLU", normMode="INSTANCE")
+        self.add_module("UNetBlock_0", UNetBlock(w_dim, channels, nb_conv_per_stage, blockConfig, downSampleMode=blocks.DownSampleMode.MAXPOOL, upSampleMode=blocks.UpSampleMode.CONV_TRANSPOSE, attention=False, dim=dim), in_branch=[0, "Style"])
+        self.add_module("Head", Generator.GeneratorHead(in_channels=channels[1], out_channels=1, dim=dim))
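The core trick in _ModulatedConv.forward above is StyleGAN2-style weight modulation and demodulation, with the batch folded into the conv's groups argument so each sample is convolved with its own style-scaled kernel in a single call. A self-contained 2D sketch of the same computation (plain PyTorch; all names and shapes here are illustrative, not konfai API):

import torch
import torch.nn.functional as F

# Illustrative sketch, not part of the konfai package.
b, cin, cout, k = 4, 8, 16, 3
x = torch.randn(b, cin, 32, 32)
w = torch.randn(cout, cin, k, k)          # shared conv weight
s = torch.randn(b, cin)                   # per-sample style, i.e. affine(styles)

# Modulate: scale each sample's kernels per input channel (the "+ 1" matches w1 + 1).
weights = w.unsqueeze(0) * (s.reshape(b, 1, cin, 1, 1) + 1)
# Demodulate: normalize each output filter to roughly unit norm.
d = torch.rsqrt((weights ** 2).sum(dim=(2, 3, 4), keepdim=True) + 1e-8)
weights = weights * d

# Fold the batch into channels and run one grouped conv (groups=b), as forward does.
out = F.conv2d(x.reshape(1, b * cin, 32, 32),
               weights.reshape(b * cout, cin, k, k),
               padding=1, groups=b)
out = out.reshape(b, cout, 32, 32)
print(out.shape)  # torch.Size([4, 16, 32, 32])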
konfai/models/generation/ddpm.py
@@ -0,0 +1,218 @@
+
+
+
+from typing import Union, Callable
+import torch
+import tqdm
+import numpy as np
+
+from KonfAI.konfai.network import network, blocks
+from KonfAI.konfai.utils.config import config
+from KonfAI.konfai.data.HDF5 import ModelPatch
+from KonfAI.konfai.utils.utils import gpuInfo
+from KonfAI.konfai.metric.measure import Criterion
+
+def cosine_beta_schedule(timesteps, s=0.008):
+    steps = timesteps + 1
+    x = torch.linspace(0, timesteps, steps)
+    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
+    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
+    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
+    return torch.clip(betas, 0.0001, 0.9999)
+
+class DDPM(network.Network):
+
+    class DDPM_TE(torch.nn.Module):
+
+        def __init__(self, in_channels: int, out_channels: int) -> None:
+            super().__init__()
+            self.linear_0 = torch.nn.Linear(in_channels, out_channels)
+            self.siLU = torch.nn.SiLU()
+            self.linear_1 = torch.nn.Linear(out_channels, out_channels)
+
+        def forward(self, input: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
+            return input + self.linear_1(self.siLU(self.linear_0(t))).reshape(input.shape[0], -1, *[1 for _ in range(len(input.shape)-2)])
+
+    class DDPM_UNetBlock(network.ModuleArgsDict):
+
+        def __init__(self, channels: list[int], nb_conv_per_stage: int, blockConfig: blocks.BlockConfig, downSampleMode: blocks.DownSampleMode, upSampleMode: blocks.UpSampleMode, attention : bool, time_embedding_dim: int, dim: int, i : int = 0) -> None:
+            super().__init__()
+            if i > 0:
+                self.add_module(downSampleMode.name, blocks.downSample(in_channels=channels[0], out_channels=channels[1], downSampleMode=downSampleMode, dim=dim))
+            self.add_module("Te_down", DDPM.DDPM_TE(time_embedding_dim, channels[1 if downSampleMode == blocks.DownSampleMode.CONV_STRIDE and i > 0 else 0]), in_branch=[0, 1])
+            self.add_module("DownConvBlock", blocks.ResBlock(in_channels=channels[1 if downSampleMode == blocks.DownSampleMode.CONV_STRIDE and i > 0 else 0], out_channels=channels[1], nb_conv=nb_conv_per_stage, blockConfig=blockConfig, dim=dim))
+            if len(channels) > 2:
+                self.add_module("UNetBlock_{}".format(i+1), DDPM.DDPM_UNetBlock(channels[1:], nb_conv_per_stage, blockConfig, downSampleMode, upSampleMode, attention, time_embedding_dim, dim, i+1), in_branch=[0,1])
+                self.add_module("Te_up", DDPM.DDPM_TE(time_embedding_dim, (channels[1]+channels[2]) if upSampleMode != blocks.UpSampleMode.CONV_TRANSPOSE else channels[1]*2), in_branch=[0, 1])
+                self.add_module("UpConvBlock", blocks.ResBlock(in_channels=(channels[1]+channels[2]) if upSampleMode != blocks.UpSampleMode.CONV_TRANSPOSE else channels[1]*2, out_channels=channels[1], nb_conv=nb_conv_per_stage, blockConfig=blockConfig, dim=dim))
+            if i > 0:
+                if attention:
+                    self.add_module("Attention", blocks.Attention(F_g=channels[1], F_l=channels[0], F_int=channels[0], dim=dim), in_branch=[2, 0], out_branch=[2])
+                self.add_module(upSampleMode.name, blocks.upSample(in_channels=channels[1], out_channels=channels[0], upSampleMode=upSampleMode, dim=dim))
+                self.add_module("SkipConnection", blocks.Concat(), in_branch=[0, 2])
+
+    class DDPM_UNetHead(network.ModuleArgsDict):
+
+        def __init__(self, in_channels: int, out_channels: int, dim: int) -> None:
+            super().__init__()
+            self.add_module("Conv", blocks.getTorchModule("Conv", dim)(in_channels = in_channels, out_channels = out_channels, kernel_size = 3, stride = 1, padding = 1))
+
+    class DDPM_ForwardProcess(torch.nn.Module):
+
+        def __init__(self, noise_step: int=1000, beta_start: float = 1e-4, beta_end: float = 0.02) -> None:
+            super().__init__()
+            self.betas = torch.linspace(beta_start, beta_end, noise_step)
+            self.betas = DDPM.DDPM_ForwardProcess.enforce_zero_terminal_snr(self.betas)
+            self.alphas = 1 - self.betas
+            self.alpha_hat = torch.concat((torch.ones(1), torch.cumprod(self.alphas, dim=0)))
+
+        def enforce_zero_terminal_snr(betas: torch.Tensor):
+            alphas = 1 - betas
+            alphas_bar = alphas.cumprod(0)
+            alphas_bar_sqrt = alphas_bar.sqrt()
+            alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
+            alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
+            alphas_bar_sqrt -= alphas_bar_sqrt_T
+            alphas_bar_sqrt *= alphas_bar_sqrt_0 / (
+                alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
+            alphas_bar = alphas_bar_sqrt ** 2
+            alphas = alphas_bar[1:] / alphas_bar[:-1]
+            alphas = torch.cat([alphas_bar[0:1], alphas])
+            betas = 1 - alphas
+            return betas
+
+        def forward(self, input: torch.Tensor, t: torch.Tensor, eta: torch.Tensor) -> torch.Tensor:
+            alpha_hat_t = self.alpha_hat[t.cpu()].to(input.device).reshape(input.shape[0], *[1 for _ in range(len(input.shape)-1)])
+            result = alpha_hat_t.sqrt() * input + (1 - alpha_hat_t).sqrt() * eta
+            return result
+
+    class DDPM_SampleT(torch.nn.Module):
+
+        def __init__(self, noise_step: int) -> None:
+            super().__init__()
+            self.noise_step = noise_step
+
+        def forward(self, input: torch.Tensor) -> torch.Tensor:
+            return torch.randint(0, self.noise_step, (input.shape[0],)).to(input.device)
+
+    class DDPM_TimeEmbedding(torch.nn.Module):
+
+        def sinusoidal_embedding(noise_step: int, time_embedding_dim: int):
+            noise_step += 1
+            embedding = torch.zeros(noise_step, time_embedding_dim)
+            wk = torch.tensor([1 / 10_000 ** (2 * j / time_embedding_dim) for j in range(time_embedding_dim)])
+            wk = wk.reshape((1, time_embedding_dim))
+            t = torch.arange(noise_step).reshape((noise_step, 1))
+            embedding[:,::2] = torch.sin(t * wk[:,::2])
+            embedding[:,1::2] = torch.cos(t * wk[:,::2])
+            return embedding
+
+        def __init__(self, noise_step: int=1000, time_embedding_dim: int=100) -> None:
+            super().__init__()
+            self.time_embed = torch.nn.Embedding(noise_step, time_embedding_dim)
+            self.time_embed.weight.data = DDPM.DDPM_TimeEmbedding.sinusoidal_embedding(noise_step, time_embedding_dim)
+            self.time_embed.requires_grad_(False)
+            self.noise_step = noise_step
+
+        def forward(self, input: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
+            return self.time_embed((p*self.noise_step).long().repeat((input.shape[0])))
+
+    class DDPM_UNet(network.ModuleArgsDict):
+
+        def __init__(self, noise_step: int, channels: list[int], blockConfig: blocks.BlockConfig, nb_conv_per_stage: int, downSampleMode: str, upSampleMode: str, attention : bool, time_embedding_dim: int, dim : int) -> None:
+            super().__init__()
+            self.add_module("t", DDPM.DDPM_TimeEmbedding(noise_step, time_embedding_dim), in_branch=[1], out_branch=["te"])
+            self.add_module("UNetBlock_0", DDPM.DDPM_UNetBlock(channels, nb_conv_per_stage, blockConfig, downSampleMode=blocks.DownSampleMode._member_map_[downSampleMode], upSampleMode=blocks.UpSampleMode._member_map_[upSampleMode], attention=attention, time_embedding_dim=time_embedding_dim, dim=dim), in_branch=[0,"te"])
+            self.add_module("Head", DDPM.DDPM_UNetHead(in_channels=channels[1], out_channels=1, dim=dim))
+
+    class DDPM_Inference(torch.nn.Module):
+
+        def __init__(self, model: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], train_noise_step: int, inference_noise_step: int, beta_start: float, beta_end: float) -> None:
+            super().__init__()
+            self.model = model
+            self.train_noise_step = train_noise_step
+            self.inference_noise_step = inference_noise_step
+            self.forwardProcess = DDPM.DDPM_ForwardProcess(train_noise_step, beta_start, beta_end)
+
+        def forward(self, input: torch.Tensor) -> torch.Tensor:
+            x = torch.randn_like(input).to(input.device)
+            result = []
+            result.append(x.unsqueeze(1))
+            description = lambda : "Inference : "+gpuInfo(input.device)
+            offset = self.train_noise_step // self.inference_noise_step
+            t_list = np.round(np.arange(self.train_noise_step-1, 0, -offset)).astype(int).tolist()
+            with tqdm.tqdm(iterable = enumerate(t_list), desc = description(), total=len(t_list), leave=False, disable=True) as batch_iter:
+                for i, t in batch_iter:
+                    alpha_t_hat = self.forwardProcess.alpha_hat[t]
+                    alpha_t_hat_prev = self.forwardProcess.alpha_hat[t - offset] if t - offset >= 0 else self.forwardProcess.alpha_hat[0]
+                    eta_theta = self.model(torch.concat((x, input), dim=1), (torch.ones(input.shape[0], 1) * t).to(input.device).long())
+
+                    predicted = (x - (1 - alpha_t_hat).sqrt() * eta_theta)/alpha_t_hat.sqrt()*alpha_t_hat_prev.sqrt()
+
+                    variance = ((1 - alpha_t_hat_prev) / (1 - alpha_t_hat)) * (1 - alpha_t_hat / alpha_t_hat_prev)
+
+                    direction = (1-alpha_t_hat_prev-variance).sqrt()*eta_theta
+
+                    x = predicted+direction
+                    x += variance.sqrt()*torch.randn_like(input).to(input.device)
+
+                    result.append(x.unsqueeze(1))
+
+                    batch_iter.set_description(description())
+            result[-1] = torch.clip(result[-1], -1, 1)
+            return torch.concat(result, dim=1)
+
+    class DDPM_VLoss(torch.nn.Module):
+
+        def __init__(self, alpha_hat: torch.Tensor) -> None:
+            super().__init__()
+            self.alpha_hat = alpha_hat
+
+        def v(self, input: torch.Tensor, noise: torch.Tensor, t: torch.Tensor):
+            alpha_hat_t = self.alpha_hat[t.cpu()].to(input.device).reshape(input.shape[0], *[1 for _ in range(len(input.shape)-1)])
+            return alpha_hat_t.sqrt()*input-(1-alpha_hat_t).sqrt()*noise
+
+        def forward(self, input: torch.Tensor, eta: torch.Tensor, eta_hat: torch.Tensor, t: int) -> torch.Tensor:
+            return torch.concat((self.v(input, eta, t), self.v(input, eta_hat, t)), dim=1)
+
+    @config("DDPM")
+    def __init__( self,
+                  optimizer : network.OptimizerLoader = network.OptimizerLoader(),
+                  schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
+                  outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
+                  patch : Union[ModelPatch, None] = None,
+                  train_noise_step: int = 1000,
+                  inference_noise_step: int = 100,
+                  beta_start: float = 1e-4,
+                  beta_end: float = 0.02,
+                  time_embedding_dim: int = 100,
+                  channels: list[int]=[1, 64, 128, 256, 512, 1024],
+                  blockConfig: blocks.BlockConfig = blocks.BlockConfig(),
+                  nb_conv_per_stage: int = 2,
+                  downSampleMode: str = "MAXPOOL",
+                  upSampleMode: str = "CONV_TRANSPOSE",
+                  attention : bool = False,
+                  dim : int = 3) -> None:
+        super().__init__(in_channels=1, optimizer=optimizer, schedulers=schedulers, outputsCriterions=outputsCriterions, patch=patch, dim=dim)
+        self.add_module("Noise", blocks.NormalNoise(), out_branch=["eta"], training=True)
+        self.add_module("Sample", DDPM.DDPM_SampleT(train_noise_step), out_branch=["t"], training=True)
+        self.add_module("Forward", DDPM.DDPM_ForwardProcess(train_noise_step, beta_start, beta_end), in_branch=[0, "t", "eta"], out_branch=["x_t"], training=True)
+        self.add_module("Concat", blocks.Concat(), in_branch=["x_t",1], out_branch=["xy_t"], training=True)
+        self.add_module("UNet", DDPM.DDPM_UNet(train_noise_step, channels, blockConfig, nb_conv_per_stage, downSampleMode, upSampleMode, attention, time_embedding_dim, dim), in_branch=["xy_t", "t"], out_branch=["eta_hat"], training=True)
+
+        self.add_module("Noise_optim", DDPM.DDPM_VLoss(self["Forward"].alpha_hat), in_branch=[0, "eta", "eta_hat", "t"], out_branch=["noise"], training=True)
+
+        self.add_module("Inference", DDPM.DDPM_Inference(self.inference, train_noise_step, inference_noise_step, beta_start, beta_end), in_branch=[1], training=False)
+        self.add_module("LastImage", blocks.Select([slice(None, None), slice(-1, None)]), training=False)
+
+    def inference(self, input: torch.Tensor, t: torch.Tensor):
+        return self["UNet"](input, t)
+
+class MSE(Criterion):
+
+    def __init__(self):
+        super().__init__()
+        self.loss = torch.nn.MSELoss()
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        return self.loss(input[:, 0, ...], input[:, 1, ...])
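For orientation on the diffusion code above: DDPM_ForwardProcess.forward is the closed-form noising step x_t = sqrt(alpha_hat_t) * x_0 + sqrt(1 - alpha_hat_t) * eta, and DDPM_Inference then walks it backwards on a strided subset of timesteps (a DDIM-style schedule of train_noise_step // inference_noise_step). A small sketch of the forward half under the same linear beta schedule, before the enforce_zero_terminal_snr correction (plain PyTorch; names are illustrative, not konfai API):

import torch

# Illustrative sketch, not part of the konfai package.
# Linear beta schedule, as constructed in DDPM_ForwardProcess.__init__.
betas = torch.linspace(1e-4, 0.02, 1000)
alphas = 1 - betas
alpha_hat = torch.cat((torch.ones(1), torch.cumprod(alphas, dim=0)))

x0 = torch.randn(2, 1, 16, 16)             # clean samples
eta = torch.randn_like(x0)                 # the noise the UNet learns to predict
t = torch.tensor([10, 900])                # one timestep per batch element
a = alpha_hat[t].reshape(-1, 1, 1, 1)

x_t = a.sqrt() * x0 + (1 - a).sqrt() * eta # closed-form q(x_t | x_0)
print(a.flatten())                         # near 1 at t=10 (mostly signal),
                                           # near 0 at t=900 (mostly noise)

The zero-terminal-SNR rescaling applied in the package pushes alpha_hat at the final step all the way to 0, so that the last training timestep is pure noise, matching the pure-noise start of DDPM_Inference.forward.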