konfai 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of konfai might be problematic.

@@ -0,0 +1,175 @@
+ import torch
+ import torch.nn.functional as F
+ from KonfAI.konfai.network import network, blocks
+ from KonfAI.konfai.utils.config import config
+ from KonfAI.konfai.data.HDF5 import ModelPatch
+
+ """
+ "convnext_tiny_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth", depths=[3, 3, 9, 3], dims=[96, 192, 384, 768]
+ "convnext_small_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth", depths=[3, 3, 27, 3], dims=[96, 192, 384, 768]
+ "convnext_base_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth", depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024]
+ "convnext_large_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth", depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536]
+ "convnext_tiny_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth",
+ "convnext_small_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth",
+ "convnext_base_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth",
+ "convnext_large_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth",
+ "convnext_xlarge_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth", depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048]
+ """
+
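+ # Illustrative note (not in the upstream file): each entry above pairs a
+ # pretrained checkpoint URL with the depths/dims needed to rebuild it, e.g.
+ # ConvNeXt-B on ImageNet-1k would correspond to ConvNeXt(dim=2, in_channels=3,
+ # depths=[3, 3, 27, 3], widths=[128, 256, 512, 1024]) before loading
+ # convnext_base_1k_224_ema.pth through KonfAI's alias mechanism.
+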
+ class LayerNorm(torch.nn.Module):
+
+     def __init__(self, normalized_shape : int, eps : float = 1e-6, data_format : str = "channels_last"):
+         super().__init__()
+         self.weight = torch.nn.parameter.Parameter(torch.ones(normalized_shape))
+         self.bias = torch.nn.parameter.Parameter(torch.zeros(normalized_shape))
+         self.eps = eps
+         self.data_format = data_format
+         if self.data_format not in ["channels_last", "channels_first"]:
+             raise NotImplementedError
+         self.normalized_shape = (normalized_shape, )
+
+     def forward(self, x):
+         if self.data_format == "channels_last":
+             return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+         elif self.data_format == "channels_first":
+             u = x.mean(1, keepdim=True)
+             s = (x - u).pow(2).mean(1, keepdim=True)
+             x = (x - u) / torch.sqrt(s + self.eps)
+             if len(x.shape) == 3:
+                 x = self.weight[:, None] * x + self.bias[:, None]
+             elif len(x.shape) == 4:
+                 x = self.weight[:, None, None] * x + self.bias[:, None, None]
+             else:
+                 x = self.weight[:, None, None, None] * x + self.bias[:, None, None, None]
+             return x
+
+     def extra_repr(self):
+         return "normalized_shape={}, eps={}, data_format={}".format(self.normalized_shape, self.eps, self.data_format)
+
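+ # In "channels_first" mode the class normalizes over dim 1 by hand, which is
+ # equivalent (up to numerics) to F.layer_norm applied after permuting the
+ # channel axis last; the per-rank branches only reshape weight and bias to
+ # (C, 1, ...) for 1D, 2D and 3D feature maps.
+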
+ class DropPath(torch.nn.Module):
+
+     def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
+         super(DropPath, self).__init__()
+         self.drop_prob = drop_prob
+         self.scale_by_keep = scale_by_keep
+
+     def forward(self, x):
+         if self.drop_prob == 0. or not self.training:
+             return x
+         keep_prob = 1 - self.drop_prob
+         shape = (x.shape[0],) + (1,) * (x.ndim - 1)
+         random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+         if keep_prob > 0.0 and self.scale_by_keep:
+             random_tensor.div_(keep_prob)
+         return x * random_tensor
+
+     def extra_repr(self):
+         return "drop_prob={}".format(round(self.drop_prob, 3))
+
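+ # Stochastic depth per sample: each item in the batch is zeroed with
+ # probability drop_prob and the survivors are rescaled by 1/keep_prob so the
+ # expected activation is unchanged; at eval time the module is the identity.
+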
+ class LayerScaler(torch.nn.Module):
+
+     def __init__(self, init_value : float, dimensions : int):
+         super().__init__()
+         self.init_value = init_value
+         self.gamma = torch.nn.Parameter(torch.ones(dimensions, 1, 1) * init_value)
+
+     def forward(self, input : torch.Tensor) -> torch.Tensor:
+         return self.gamma * input
+
+     def extra_repr(self):
+         return "init_value={}".format(self.init_value)
+
+ class BottleNeckBlock(network.ModuleArgsDict):
+
+     def __init__( self,
+                   features: int,
+                   drop_p: float,
+                   layer_scaler_init_value: float,
+                   dim: int):
+         super().__init__()
+         self.add_module("Conv_0", blocks.getTorchModule("Conv", dim)(features, features, kernel_size=7, padding=3, groups=features), alias=["dwconv"])
+         self.add_module("ToFeatures", blocks.ToFeatures(dim))
+         self.add_module("LayerNorm", LayerNorm(features, eps=1e-6), alias=["norm"])
+         self.add_module("Linear_1", torch.nn.Linear(features, features * 4), alias=["pwconv1"])
+         self.add_module("GELU", torch.nn.GELU())
+         self.add_module("Linear_2", torch.nn.Linear(features * 4, features), alias=["pwconv2"])
+         self.add_module("ToChannels", blocks.ToChannels(dim))
+         self.add_module("LayerScaler", LayerScaler(init_value=layer_scaler_init_value, dimensions=features), alias=[""])
+         self.add_module("StochasticDepth", DropPath(drop_p))
+         self.add_module("Residual", blocks.Add(), in_branch=[0, 1])
+
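+ # This mirrors the ConvNeXt block: 7x7 depthwise conv -> LayerNorm ->
+ # pointwise MLP with GELU and 4x expansion -> LayerScale -> stochastic depth,
+ # added back onto the untouched input (branch 1) as the residual.
+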
+ class DownSample(network.ModuleArgsDict):
+
+     def __init__( self,
+                   in_features: int,
+                   out_features: int,
+                   dim : int):
+         super().__init__()
+         self.add_module("LayerNorm", LayerNorm(in_features, eps=1e-6, data_format="channels_first"), alias=["0"])
+         self.add_module("Conv", blocks.getTorchModule("Conv", dim)(in_features, out_features, kernel_size=2, stride=2), alias=["1"])
+
+ class ConvNextStage(network.ModuleArgsDict):
+
+     def __init__( self,
+                   features: int,
+                   depth: int,
+                   drop_p: list[float],
+                   dim : int):
+         super().__init__()
+         for i in range(depth):
+             self.add_module("BottleNeckBlock_{}".format(i), BottleNeckBlock(features=features, drop_p=drop_p[i], layer_scaler_init_value=1e-6, dim=dim), alias=["{}".format(i)])
+
+ class ConvNextStem(network.ModuleArgsDict):
+
+     def __init__(self, in_features: int, out_features: int, dim: int):
+         super().__init__()
+         self.add_module("Conv", blocks.getTorchModule("Conv", dim)(in_features, out_features, kernel_size=4, stride=4), alias=["0"])
+         self.add_module("LayerNorm", LayerNorm(out_features, eps=1e-6, data_format="channels_first"), alias=["1"])
+
+ class ConvNextEncoder(network.ModuleArgsDict):
+
+     def __init__( self,
+                   in_channels: int,
+                   depths: list[int],
+                   widths: list[int],
+                   drop_p: float,
+                   dim : int):
+         super().__init__()
+         self.add_module("ConvNextStem", ConvNextStem(in_channels, widths[0], dim=dim), alias=["downsample_layers.0"])
+
+         drop_probs = [x.item() for x in torch.linspace(0, drop_p, sum(depths))]
+         self.add_module("ConvNextStage_0", ConvNextStage(features=widths[0], depth=depths[0], drop_p=drop_probs[:depths[0]], dim=dim), alias=["stages.0"])
+
+         for i, (in_features, out_features) in enumerate(list(zip(widths[:], widths[1:]))):
+             self.add_module("DownSample_{}".format(i+1), DownSample(in_features=in_features, out_features=out_features, dim=dim), alias=["downsample_layers.{}".format(i+1)])
+             self.add_module("ConvNextStage_{}".format(i+1), ConvNextStage(features=out_features, depth=depths[i+1], drop_p=drop_probs[sum(depths[:i+1]):sum(depths[:i+2])], dim=dim), alias=["stages.{}".format(i+1)])
+
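+ # Drop-path rates increase linearly from 0 to drop_p across all blocks, and
+ # each stage takes its own slice of that schedule, matching the reference
+ # ConvNeXt training recipe.
+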
+ class Head(network.ModuleArgsDict):
+
+     def __init__(self, in_features : int, num_classes : list[int], dim : int) -> None:
+         super().__init__()
+         self.add_module("AdaptiveAvgPool", blocks.getTorchModule("AdaptiveAvgPool", dim)(tuple([1]*dim)))
+         self.add_module("Flatten", torch.nn.Flatten(1))
+         self.add_module("LayerNorm", torch.nn.LayerNorm(in_features, eps=1e-6), alias=["norm"])
+
+         for i, nb_classe in enumerate(num_classes):
+             self.add_module("Linear_{}".format(i), torch.nn.Linear(in_features, nb_classe), pretrained=False, alias=["head"], out_branch=[i+1])
+             self.add_module("Unsqueeze_{}".format(i), blocks.Unsqueeze(2), in_branch=[i+1], out_branch=[-1])
+
+ class ConvNeXt(network.Network):
+
+     @config("ConvNeXt")
+     def __init__( self,
+                   optimizer : network.OptimizerLoader = network.OptimizerLoader(),
+                   schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
+                   outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
+                   patch : ModelPatch = ModelPatch(),
+                   dim : int = 3,
+                   in_channels: int = 1,
+                   depths: list[int] = [3, 3, 27, 3],
+                   widths: list[int] = [128, 256, 512, 1024],
+                   drop_p: float = 0.1,
+                   num_classes: list[int] = [4, 7]):
+
+         super().__init__(in_channels = in_channels, optimizer = optimizer, schedulers = schedulers, outputsCriterions = outputsCriterions, dim = dim, patch=patch, init_type = "trunc_normal", init_gain=0.02)
+         self.add_module("ConvNextEncoder", ConvNextEncoder(in_channels=in_channels, depths=depths, widths=widths, drop_p=drop_p, dim=dim))
+         self.add_module("Head", Head(in_features=widths[-1], num_classes=num_classes, dim=dim))
@@ -0,0 +1,116 @@
+ from abc import ABC
+ from typing import Type
+ import torch
+ from KonfAI.konfai.network import network, blocks
+ from KonfAI.konfai.utils.config import config
+ from KonfAI.konfai.data.HDF5 import ModelPatch
+
+ """
+ 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', dim = 2, in_channels = 3, depths=[2, 2, 2, 2], widths = [64, 64, 128, 256, 512], num_classes=1000, useBottleneck=False
+ 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', dim = 2, in_channels = 3, depths=[3, 4, 6, 3], widths = [64, 64, 128, 256, 512], num_classes=1000, useBottleneck=False
+ 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', dim = 2, in_channels = 3, depths=[3, 4, 6, 3], widths = [64, 256, 512, 1024, 2048], num_classes=1000, useBottleneck=True
+ 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', dim = 2, in_channels = 3, depths=[3, 4, 23, 3], widths = [64, 256, 512, 1024, 2048], num_classes=1000, useBottleneck=True
+ 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', dim = 2, in_channels = 3, depths=[3, 8, 36, 3], widths = [64, 256, 512, 1024, 2048], num_classes=1000, useBottleneck=True
+ 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
+ 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
+ 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
+ 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
+ """
+ class AbstractResBlock(network.ModuleArgsDict, ABC):
+
+     def __init__(self, in_channels: int, out_channels: int, downsample: bool, dim: int):
+         super().__init__()
+
+ class ResBlock(AbstractResBlock):
+
+     def __init__(self, in_channels: int, out_channels: int, downsample: bool, dim: int):
+         super().__init__(in_channels, out_channels, downsample, dim)
+         if downsample:
+             self.add_module("Shortcut", blocks.ConvBlock(in_channels, out_channels, 1, blocks.BlockConfig(kernel_size=1, stride=2, padding=0, bias=False, activation="None", normMode=blocks.NormMode.BATCH.name), dim=dim, alias=[["0"], ["1"], []]), in_branch=[1], out_branch=[1], alias=["downsample"])
+
+         self.add_module("ConvBlock_0", blocks.ConvBlock(in_channels, out_channels, 1, blocks.BlockConfig(kernel_size=3, stride=2 if downsample else 1, padding=1, bias=False, activation="ReLU", normMode=blocks.NormMode.BATCH.name), dim=dim, alias=[["conv1"], ["bn1"], []]))
+         self.add_module("ConvBlock_1", blocks.ConvBlock(out_channels, out_channels, 1, blocks.BlockConfig(kernel_size=3, stride=1, padding=1, bias=False, activation="None", normMode=blocks.NormMode.BATCH.name), dim=dim, alias=[["conv2"], ["bn2"], []]))
+         self.add_module("Residual", blocks.Add(), in_branch=[0, 1])
+         self.add_module("ReLU", torch.nn.ReLU())
+
+ class ResBottleneckBlock(AbstractResBlock):
+
+     def __init__(self, in_channels: int, out_channels: int, downsample: bool, dim: int):
+         super().__init__(in_channels, out_channels, downsample, dim)
+         self.add_module("ConvBlock_0", blocks.ConvBlock(in_channels, out_channels//4, 1, blocks.BlockConfig(kernel_size=1, stride=1, padding=0, bias=False, activation="ReLU", normMode=blocks.NormMode.BATCH.name), dim=dim, alias=[["conv1"], ["bn1"], []]))
+         self.add_module("ConvBlock_1", blocks.ConvBlock(out_channels//4, out_channels//4, 1, blocks.BlockConfig(kernel_size=3, stride=2 if downsample else 1, padding=1, bias=False, activation="ReLU", normMode=blocks.NormMode.BATCH.name), dim=dim, alias=[["conv2"], ["bn2"], []]))
+         # As in the basic block, the last convolution carries no activation: the
+         # ReLU is applied after the residual addition (torchvision convention).
+         self.add_module("ConvBlock_2", blocks.ConvBlock(out_channels//4, out_channels, 1, blocks.BlockConfig(kernel_size=1, stride=1, padding=0, bias=False, activation="None", normMode=blocks.NormMode.BATCH.name), dim=dim, alias=[["conv3"], ["bn3"], []]))
+
+         if downsample or in_channels != out_channels:
+             self.add_module("Shortcut", blocks.ConvBlock(in_channels, out_channels, 1, blocks.BlockConfig(kernel_size=1, stride=2 if downsample else 1, padding=0, bias=False, activation="None", normMode=blocks.NormMode.BATCH.name), dim=dim, alias=[["0"], ["1"], []]), in_branch=[1], out_branch=[1], alias=["downsample"])
+
+         self.add_module("Residual", blocks.Add(), in_branch=[0, 1])
+         self.add_module("ReLU", torch.nn.ReLU())
+
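+ # ResBlock is the two-conv basic block (ResNet-18/34); ResBottleneckBlock is
+ # the 1x1 -> 3x3 -> 1x1 bottleneck (ResNet-50+) with a 4x channel reduction
+ # inside. Both route the shortcut through branch 1 and add it back before the
+ # final ReLU.
+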
+ class ResNetStage(network.ModuleArgsDict):
+
+     def __init__(self,
+                  in_channels: int,
+                  out_channels: int,
+                  depth: int,
+                  block : Type[AbstractResBlock],
+                  downsample : bool,
+                  dim : int):
+         super().__init__()
+         self.add_module("BottleNeckBlock_0", block(in_channels=in_channels, out_channels=out_channels, downsample=downsample, dim=dim), alias=["0"])
+         for i in range(1, depth):
+             self.add_module("BottleNeckBlock_{}".format(i), block(in_channels=out_channels, out_channels=out_channels, downsample=False, dim=dim), alias=["{}".format(i)])
+
+ class ResNetStem(network.ModuleArgsDict):
+
+     def __init__(self, in_channels: int, out_features: int, dim: int):
+         super().__init__()
+         self.add_module("ConvBlock", blocks.ConvBlock(in_channels, out_features, 1, blocks.BlockConfig(kernel_size=7, stride=2, padding=3, bias=False, activation="ReLU", normMode=blocks.NormMode.BATCH.name), dim=dim, alias=[["conv1"], ["bn1"], []]))
+         self.add_module("MaxPool", blocks.getTorchModule("MaxPool", dim)(kernel_size=3, stride=2, padding=1))
+
+ class ResNetEncoder(network.ModuleArgsDict):
+
+     def __init__(self,
+                  in_channels: int,
+                  depths: list[int],
+                  widths: list[int],
+                  useBottleneck : bool,
+                  dim : int):
+         super().__init__()
+         self.add_module("ResNetStem", ResNetStem(in_channels, widths[0], dim=dim))
+
+         for i, (in_channels, out_channels, depth) in enumerate(list(zip(widths[:], widths[1:], depths))):
+             self.add_module("ResNetStage_{}".format(i), ResNetStage(in_channels=in_channels, out_channels=out_channels, depth=depth, block=ResBottleneckBlock if useBottleneck else ResBlock, downsample=i != 0, dim=dim), alias=["layer{}".format(i+1)])
+
+ class Head(network.ModuleArgsDict):
+
+     def __init__(self, in_features : int, num_classes : int, dim : int):
+         super().__init__()
+         self.add_module("AdaptiveAvgPool", blocks.getTorchModule("AdaptiveAvgPool", dim)(1))
+         self.add_module("Flatten", torch.nn.Flatten(1))
+         self.add_module("Linear", torch.nn.Linear(in_features, num_classes), pretrained=False, alias=["fc"])
+         self.add_module("Unsqueeze", blocks.Unsqueeze(2))
+
+ class ResNet(network.Network):
+
+     @config("ResNet")
+     def __init__( self,
+                   optimizer : network.OptimizerLoader = network.OptimizerLoader(),
+                   schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
+                   outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
+                   patch : ModelPatch = ModelPatch(),
+                   dim : int = 3,
+                   in_channels: int = 1,
+                   depths: list[int] = [2, 2, 2, 2],
+                   widths: list[int] = [64, 64, 128, 256, 512],
+                   num_classes: int = 10,
+                   useBottleneck : bool = False):
+         super().__init__(in_channels = in_channels, optimizer = optimizer, schedulers = schedulers, outputsCriterions = outputsCriterions, dim = dim, patch=patch, init_type = "trunc_normal", init_gain=0.02)
+         self.add_module("ResNetEncoder", ResNetEncoder(in_channels=in_channels, depths=depths, widths=widths, useBottleneck=useBottleneck, dim=dim))
+         self.add_module("Head", Head(in_features=widths[-1], num_classes=num_classes, dim=dim))
+
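+ # Usage sketch (illustrative, not shipped in the wheel): the defaults give a
+ # 3D ResNet-18-style network with 10 output classes; a 2D ResNet-50 would use
+ # ResNet(dim=2, in_channels=3, depths=[3, 4, 6, 3],
+ #        widths=[64, 256, 512, 1024, 2048], num_classes=1000, useBottleneck=True).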
@@ -0,0 +1,137 @@
+ import importlib
+ import torch
+ from KonfAI.konfai.network import network, blocks
+ from KonfAI.konfai.utils.config import config
+ from KonfAI.konfai.data.HDF5 import ModelPatch
+
+ class MappingNetwork(network.ModuleArgsDict):
+
+     def __init__(self, z_dim: int, c_dim: int, w_dim: int, num_layers: int, embed_features: int, layer_features: int):
+         super().__init__()
+
+         self.add_module("Concat_1", blocks.Concat(), in_branch=[0, 1])
+
+         # Parenthesized so that a conditioning-free network (c_dim == 0) still
+         # starts from z_dim features rather than 0.
+         features = [z_dim + (embed_features if c_dim > 0 else 0)] + [layer_features] * (num_layers - 1) + [w_dim]
+         if c_dim > 0:
+             self.add_module("Linear", torch.nn.Linear(c_dim, embed_features), out_branch=["Embed"])
+
+         self.add_module("Noise", blocks.NormalNoise(z_dim), in_branch=["Embed"])
+         if c_dim > 0:
+             self.add_module("Concat", blocks.Concat(), in_branch=[0, "Embed"])
+
+         for i, (in_features, out_features) in enumerate(zip(features, features[1:])):
+             self.add_module("Linear_{}".format(i), torch.nn.Linear(in_features, out_features))
+
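+ # StyleGAN-style mapping network: a latent z (fresh normal noise) is optionally
+ # concatenated with an embedded class/conditioning vector c, then pushed
+ # through num_layers fully connected layers to produce the style code w.
+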
+ class ModulatedConv(torch.nn.Module):
+
+     class _ModulatedConv(torch.nn.Module):
+
+         def __init__(self, w_dim: int, conv: torch.nn.modules.conv._ConvNd, dim: int) -> None:
+             super().__init__()
+             self.affine = torch.nn.Linear(w_dim, conv.in_channels)
+             self.isConv = True
+             self.in_channels = conv.in_channels
+             self.out_channels = conv.out_channels
+             self.padding = conv.padding
+             self.stride = conv.stride
+             if isinstance(conv, torch.nn.modules.conv._ConvTransposeNd):
+                 self.weight = torch.nn.parameter.Parameter(torch.randn((conv.in_channels, conv.out_channels, *conv.kernel_size)))
+                 self.isConv = False
+             else:
+                 self.weight = torch.nn.parameter.Parameter(torch.randn((conv.out_channels, conv.in_channels, *conv.kernel_size)))
+             conv.forward = self.forward
+             self.styles = None
+             self.dim = dim
+
+         def setStyle(self, styles: torch.Tensor) -> None:
+             self.styles = styles
+
+         def forward(self, input: torch.Tensor) -> torch.Tensor:
+             b = input.shape[0]
+             self.affine.to(input.device)
+             styles = self.affine(self.styles)
+             w1 = styles.reshape(b, -1, 1, *[1 for _ in range(self.dim)]) if not self.isConv else styles.reshape(b, 1, -1, *[1 for _ in range(self.dim)])
+             w2 = self.weight.unsqueeze(0).to(input.device)
+             weights = w2 * (w1 + 1)
+
+             d = torch.rsqrt((weights ** 2).sum(dim=tuple([i+2 for i in range(len(weights.shape)-2)]), keepdim=True) + 1e-8)
+             weights = weights * d
+
+             input = input.reshape(1, -1, *input.shape[2:])
+
+             _, _, *ws = weights.shape
+             if not self.isConv:
+                 out = getattr(importlib.import_module("torch.nn.functional"), "conv_transpose{}d".format(self.dim))(input, weights.reshape(b * self.in_channels, *ws), stride=self.stride, padding=self.padding, groups=b)
+             else:
+                 out = getattr(importlib.import_module("torch.nn.functional"), "conv{}d".format(self.dim))(input, weights.reshape(b * self.out_channels, *ws), padding=self.padding, groups=b, stride=self.stride)
+
+             out = out.reshape(-1, self.out_channels, *out.shape[2:])
+             return out
+
+     def __init__(self, w_dim: int, module: torch.nn.Module) -> None:
+         super().__init__()
+         self.w_dim = w_dim
+         self.module = module
+         self.convs = torch.nn.ModuleList()
+         self.module.apply(self.apply)
+
+     def forward(self, input: torch.Tensor, styles: torch.Tensor) -> torch.Tensor:
+         for conv in self.convs:
+             conv.setStyle(styles.clone())
+         return self.module(input)
+
+     def apply(self, module: torch.nn.Module):
+         if isinstance(module, torch.nn.modules.conv._ConvNd):
+             delattr(module, "weight")
+             module.bias = None
+
+             str_dim = module.__class__.__name__[-2:]
+             dim = 1
+             if str_dim == "2d":
+                 dim = 2
+             elif str_dim == "3d":
+                 dim = 3
+             self.convs.append(ModulatedConv._ModulatedConv(self.w_dim, module, dim=dim))
+
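+ # StyleGAN2-style weight (de)modulation: every wrapped convolution scales its
+ # kernel per sample by an affine projection of the style w (modulation), then
+ # renormalizes each output filter to unit norm (demodulation). The batch is
+ # folded into the channel axis and run as one grouped convolution with
+ # groups=b so each sample sees its own modulated kernel.
+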
+ class UNetBlock(network.ModuleArgsDict):
+
+     def __init__(self, w_dim: int, channels: list[int], nb_conv_per_stage: int, blockConfig: blocks.BlockConfig, downSampleMode: blocks.DownSampleMode, upSampleMode: blocks.UpSampleMode, attention : bool, dim: int, i : int = 0) -> None:
+         super().__init__()
+         if i > 0:
+             self.add_module(downSampleMode.name, blocks.downSample(in_channels=channels[0], out_channels=channels[1], downSampleMode=downSampleMode, dim=dim))
+         self.add_module("DownConvBlock", blocks.ConvBlock(in_channels=channels[1 if downSampleMode == blocks.DownSampleMode.CONV_STRIDE and i > 0 else 0], out_channels=channels[1], nb_conv=nb_conv_per_stage, blockConfig=blockConfig, dim=dim))
+         if len(channels) > 2:
+             self.add_module("UNetBlock_{}".format(i+1), UNetBlock(w_dim, channels[1:], nb_conv_per_stage, blockConfig, downSampleMode, upSampleMode, attention, dim, i+1), in_branch=[0, 1])
+             self.add_module("UpConvBlock", ModulatedConv(w_dim, blocks.ConvBlock((channels[1]+channels[2]) if upSampleMode != blocks.UpSampleMode.CONV_TRANSPOSE else channels[1]*2, out_channels=channels[1], nb_conv=nb_conv_per_stage, blockConfig=blockConfig, dim=dim)), in_branch=[0, 1])
+         if i > 0:
+             if attention:
+                 self.add_module("Attention", blocks.Attention(F_g=channels[1], F_l=channels[0], F_int=channels[0], dim=dim), in_branch=["Skip", 0], out_branch=["Skip"])
+             self.add_module(upSampleMode.name, ModulatedConv(w_dim, blocks.upSample(in_channels=channels[1], out_channels=channels[0], upSampleMode=upSampleMode, dim=dim)), in_branch=[0, 1])
+             self.add_module("SkipConnection", blocks.Concat(), in_branch=[0, "Skip"])
+
+ class Generator(network.Network):
+
+     class GeneratorHead(network.ModuleArgsDict):
+
+         def __init__(self, in_channels: int, out_channels: int, dim: int) -> None:
+             super().__init__()
+             self.add_module("Conv", blocks.getTorchModule("Conv", dim)(in_channels = in_channels, out_channels = out_channels, kernel_size = 3, stride = 1, padding = 1))
+             self.add_module("Tanh", torch.nn.Tanh())
+
+     @config("Generator")
+     def __init__(self,
+                  optimizer : network.OptimizerLoader = network.OptimizerLoader(),
+                  schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
+                  patch : ModelPatch = ModelPatch(),
+                  outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
+                  channels: list[int] = [1, 64, 128, 256, 512, 1024],
+                  nb_batch_per_step: int = 64,
+                  z_dim: int = 512,
+                  c_dim: int = 1,
+                  w_dim: int = 512,
+                  dim : int = 3) -> None:
+         super().__init__(optimizer=optimizer, in_channels=channels[0], schedulers=schedulers, patch=patch, outputsCriterions=outputsCriterions, dim=dim, nb_batch_per_step=nb_batch_per_step)
+
+         self.add_module("MappingNetwork", MappingNetwork(z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_layers=8, embed_features=w_dim, layer_features=w_dim), in_branch=[1, 2], out_branch=["Style"])
+         nb_conv_per_stage = 2
+         blockConfig = blocks.BlockConfig(kernel_size=3, stride=1, padding=1, bias=True, activation="ReLU", normMode="INSTANCE")
+         self.add_module("UNetBlock_0", UNetBlock(w_dim, channels, nb_conv_per_stage, blockConfig, downSampleMode=blocks.DownSampleMode.MAXPOOL, upSampleMode=blocks.UpSampleMode.CONV_TRANSPOSE, attention=False, dim=dim), in_branch=[0, "Style"])
+         self.add_module("Head", Generator.GeneratorHead(in_channels=channels[1], out_channels=1, dim=dim))
@@ -0,0 +1,218 @@
+ from typing import Union, Callable
+ import torch
+ import tqdm
+ import numpy as np
+
+ from KonfAI.konfai.network import network, blocks
+ from KonfAI.konfai.utils.config import config
+ from KonfAI.konfai.data.HDF5 import ModelPatch
+ from KonfAI.konfai.utils.utils import gpuInfo
+ from KonfAI.konfai.metric.measure import Criterion
+
+ def cosine_beta_schedule(timesteps, s=0.008):
+     steps = timesteps + 1
+     x = torch.linspace(0, timesteps, steps)
+     alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
+     alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
+     betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
+     return torch.clip(betas, 0.0001, 0.9999)
+
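+ # Cosine schedule of Nichol & Dhariwal: alpha_bar(t) = cos^2(((t/T + s) / (1 + s)) * pi/2),
+ # with betas recovered from the ratio of consecutive alpha_bar values and
+ # clipped to [1e-4, 0.9999] for numerical stability.
+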
+ class DDPM(network.Network):
+
+     class DDPM_TE(torch.nn.Module):
+
+         def __init__(self, in_channels: int, out_channels: int) -> None:
+             super().__init__()
+             self.linear_0 = torch.nn.Linear(in_channels, out_channels)
+             self.siLU = torch.nn.SiLU()
+             self.linear_1 = torch.nn.Linear(out_channels, out_channels)
+
+         def forward(self, input: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
+             return input + self.linear_1(self.siLU(self.linear_0(t))).reshape(input.shape[0], -1, *[1 for _ in range(len(input.shape)-2)])
+
+     class DDPM_UNetBlock(network.ModuleArgsDict):
+
+         def __init__(self, channels: list[int], nb_conv_per_stage: int, blockConfig: blocks.BlockConfig, downSampleMode: blocks.DownSampleMode, upSampleMode: blocks.UpSampleMode, attention : bool, time_embedding_dim: int, dim: int, i : int = 0) -> None:
+             super().__init__()
+             if i > 0:
+                 self.add_module(downSampleMode.name, blocks.downSample(in_channels=channels[0], out_channels=channels[1], downSampleMode=downSampleMode, dim=dim))
+             self.add_module("Te_down", DDPM.DDPM_TE(time_embedding_dim, channels[1 if downSampleMode == blocks.DownSampleMode.CONV_STRIDE and i > 0 else 0]), in_branch=[0, 1])
+             self.add_module("DownConvBlock", blocks.ResBlock(in_channels=channels[1 if downSampleMode == blocks.DownSampleMode.CONV_STRIDE and i > 0 else 0], out_channels=channels[1], nb_conv=nb_conv_per_stage, blockConfig=blockConfig, dim=dim))
+             if len(channels) > 2:
+                 self.add_module("UNetBlock_{}".format(i+1), DDPM.DDPM_UNetBlock(channels[1:], nb_conv_per_stage, blockConfig, downSampleMode, upSampleMode, attention, time_embedding_dim, dim, i+1), in_branch=[0, 1])
+                 self.add_module("Te_up", DDPM.DDPM_TE(time_embedding_dim, (channels[1]+channels[2]) if upSampleMode != blocks.UpSampleMode.CONV_TRANSPOSE else channels[1]*2), in_branch=[0, 1])
+                 self.add_module("UpConvBlock", blocks.ResBlock(in_channels=(channels[1]+channels[2]) if upSampleMode != blocks.UpSampleMode.CONV_TRANSPOSE else channels[1]*2, out_channels=channels[1], nb_conv=nb_conv_per_stage, blockConfig=blockConfig, dim=dim))
+             if i > 0:
+                 if attention:
+                     self.add_module("Attention", blocks.Attention(F_g=channels[1], F_l=channels[0], F_int=channels[0], dim=dim), in_branch=[2, 0], out_branch=[2])
+                 self.add_module(upSampleMode.name, blocks.upSample(in_channels=channels[1], out_channels=channels[0], upSampleMode=upSampleMode, dim=dim))
+                 self.add_module("SkipConnection", blocks.Concat(), in_branch=[0, 2])
+
+     class DDPM_UNetHead(network.ModuleArgsDict):
+
+         def __init__(self, in_channels: int, out_channels: int, dim: int) -> None:
+             super().__init__()
+             self.add_module("Conv", blocks.getTorchModule("Conv", dim)(in_channels = in_channels, out_channels = out_channels, kernel_size = 3, stride = 1, padding = 1))
+
+     class DDPM_ForwardProcess(torch.nn.Module):
+
+         def __init__(self, noise_step: int = 1000, beta_start: float = 1e-4, beta_end: float = 0.02) -> None:
+             super().__init__()
+             self.betas = torch.linspace(beta_start, beta_end, noise_step)
+             self.betas = DDPM.DDPM_ForwardProcess.enforce_zero_terminal_snr(self.betas)
+             self.alphas = 1 - self.betas
+             self.alpha_hat = torch.concat((torch.ones(1), torch.cumprod(self.alphas, dim=0)))
+
+         @staticmethod
+         def enforce_zero_terminal_snr(betas: torch.Tensor):
+             alphas = 1 - betas
+             alphas_bar = alphas.cumprod(0)
+             alphas_bar_sqrt = alphas_bar.sqrt()
+             alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
+             alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
+             alphas_bar_sqrt -= alphas_bar_sqrt_T
+             alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
+             alphas_bar = alphas_bar_sqrt ** 2
+             alphas = alphas_bar[1:] / alphas_bar[:-1]
+             alphas = torch.cat([alphas_bar[0:1], alphas])
+             betas = 1 - alphas
+             return betas
+
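+         # As in Lin et al., "Common Diffusion Noise Schedules and Sample Steps
+         # are Flawed": the sqrt(alpha_bar) sequence is shifted and rescaled so
+         # that alpha_bar at the final step is exactly 0 (zero terminal SNR)
+         # while the first step is left unchanged.
+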
+         def forward(self, input: torch.Tensor, t: torch.Tensor, eta: torch.Tensor) -> torch.Tensor:
+             alpha_hat_t = self.alpha_hat[t.cpu()].to(input.device).reshape(input.shape[0], *[1 for _ in range(len(input.shape)-1)])
+             result = alpha_hat_t.sqrt() * input + (1 - alpha_hat_t).sqrt() * eta
+             return result
+
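+         # Closed-form forward diffusion q(x_t | x_0):
+         # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eta, eta ~ N(0, I).
+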
+     class DDPM_SampleT(torch.nn.Module):
+
+         def __init__(self, noise_step: int) -> None:
+             super().__init__()
+             self.noise_step = noise_step
+
+         def forward(self, input: torch.Tensor) -> torch.Tensor:
+             return torch.randint(0, self.noise_step, (input.shape[0],)).to(input.device)
+
+     class DDPM_TimeEmbedding(torch.nn.Module):
+
+         @staticmethod
+         def sinusoidal_embedding(noise_step: int, time_embedding_dim: int):
+             noise_step += 1
+             embedding = torch.zeros(noise_step, time_embedding_dim)
+             wk = torch.tensor([1 / 10_000 ** (2 * j / time_embedding_dim) for j in range(time_embedding_dim)])
+             wk = wk.reshape((1, time_embedding_dim))
+             t = torch.arange(noise_step).reshape((noise_step, 1))
+             embedding[:, ::2] = torch.sin(t * wk[:, ::2])
+             embedding[:, 1::2] = torch.cos(t * wk[:, ::2])
+             return embedding
+
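+         # Standard transformer-style sinusoidal table: row t holds
+         # sin(t * w_k) in even columns and cos(t * w_k) in odd ones, with
+         # frequencies w_k = 10000^(-2k / time_embedding_dim).
+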
+         def __init__(self, noise_step: int = 1000, time_embedding_dim: int = 100) -> None:
+             super().__init__()
+             self.time_embed = torch.nn.Embedding(noise_step, time_embedding_dim)
+             self.time_embed.weight.data = DDPM.DDPM_TimeEmbedding.sinusoidal_embedding(noise_step, time_embedding_dim)
+             self.time_embed.requires_grad_(False)
+             self.noise_step = noise_step
+
+         def forward(self, input: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
+             return self.time_embed((p * self.noise_step).long().repeat((input.shape[0])))
+
+     class DDPM_UNet(network.ModuleArgsDict):
+
+         def __init__(self, noise_step: int, channels: list[int], blockConfig: blocks.BlockConfig, nb_conv_per_stage: int, downSampleMode: str, upSampleMode: str, attention : bool, time_embedding_dim: int, dim : int) -> None:
+             super().__init__()
+             self.add_module("t", DDPM.DDPM_TimeEmbedding(noise_step, time_embedding_dim), in_branch=[1], out_branch=["te"])
+             self.add_module("UNetBlock_0", DDPM.DDPM_UNetBlock(channels, nb_conv_per_stage, blockConfig, downSampleMode=blocks.DownSampleMode._member_map_[downSampleMode], upSampleMode=blocks.UpSampleMode._member_map_[upSampleMode], attention=attention, time_embedding_dim=time_embedding_dim, dim=dim), in_branch=[0, "te"])
+             self.add_module("Head", DDPM.DDPM_UNetHead(in_channels=channels[1], out_channels=1, dim=dim))
+
+     class DDPM_Inference(torch.nn.Module):
+
+         def __init__(self, model: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], train_noise_step: int, inference_noise_step: int, beta_start: float, beta_end: float) -> None:
+             super().__init__()
+             self.model = model
+             self.train_noise_step = train_noise_step
+             self.inference_noise_step = inference_noise_step
+             self.forwardProcess = DDPM.DDPM_ForwardProcess(train_noise_step, beta_start, beta_end)
+
+         def forward(self, input: torch.Tensor) -> torch.Tensor:
+             x = torch.randn_like(input).to(input.device)
+             result = []
+             result.append(x.unsqueeze(1))
+             description = lambda: "Inference : " + gpuInfo(input.device)
+             offset = self.train_noise_step // self.inference_noise_step
+             t_list = np.round(np.arange(self.train_noise_step - 1, 0, -offset)).astype(int).tolist()
+             with tqdm.tqdm(iterable=enumerate(t_list), desc=description(), total=len(t_list), leave=False, disable=True) as batch_iter:
+                 for i, t in batch_iter:
+                     alpha_t_hat = self.forwardProcess.alpha_hat[t]
+                     alpha_t_hat_prev = self.forwardProcess.alpha_hat[t - offset] if t - offset >= 0 else self.forwardProcess.alpha_hat[0]
+                     eta_theta = self.model(torch.concat((x, input), dim=1), (torch.ones(input.shape[0], 1) * t).to(input.device).long())
+
+                     predicted = (x - (1 - alpha_t_hat).sqrt() * eta_theta) / alpha_t_hat.sqrt() * alpha_t_hat_prev.sqrt()
+
+                     variance = ((1 - alpha_t_hat_prev) / (1 - alpha_t_hat)) * (1 - alpha_t_hat / alpha_t_hat_prev)
+
+                     direction = (1 - alpha_t_hat_prev - variance).sqrt() * eta_theta
+
+                     x = predicted + direction
+                     x += variance.sqrt() * torch.randn_like(input).to(input.device)
+
+                     result.append(x.unsqueeze(1))
+
+                     batch_iter.set_description(description())
+             result[-1] = torch.clip(result[-1], -1, 1)
+             return torch.concat(result, dim=1)
+
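+         # DDIM-style sampling over a strided subset of the training timesteps:
+         # predict x_0 from the current x_t and eta_theta, add the "direction
+         # pointing to x_t" term, then inject stochastic noise with the DDPM
+         # posterior variance. Every intermediate x_t is stacked along dim 1 so
+         # the trajectory can be inspected downstream.
+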
+     class DDPM_VLoss(torch.nn.Module):
+
+         def __init__(self, alpha_hat: torch.Tensor) -> None:
+             super().__init__()
+             self.alpha_hat = alpha_hat
+
+         def v(self, input: torch.Tensor, noise: torch.Tensor, t: torch.Tensor):
+             alpha_hat_t = self.alpha_hat[t.cpu()].to(input.device).reshape(input.shape[0], *[1 for _ in range(len(input.shape)-1)])
+             return alpha_hat_t.sqrt() * input - (1 - alpha_hat_t).sqrt() * noise
+
+         def forward(self, input: torch.Tensor, eta: torch.Tensor, eta_hat: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
+             return torch.concat((self.v(input, eta, t), self.v(input, eta_hat, t)), dim=1)
+
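+         # v-prediction target (Salimans & Ho):
+         # v = sqrt(alpha_bar_t) * x_0 - sqrt(1 - alpha_bar_t) * eta.
+         # The target v and the model's v (from eta_hat) are concatenated on
+         # dim 1 so a channel-split criterion such as MSE below can compare them.
+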
+     @config("DDPM")
+     def __init__( self,
+                   optimizer : network.OptimizerLoader = network.OptimizerLoader(),
+                   schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
+                   outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
+                   patch : Union[ModelPatch, None] = None,
+                   train_noise_step: int = 1000,
+                   inference_noise_step: int = 100,
+                   beta_start: float = 1e-4,
+                   beta_end: float = 0.02,
+                   time_embedding_dim: int = 100,
+                   channels: list[int] = [1, 64, 128, 256, 512, 1024],
+                   blockConfig: blocks.BlockConfig = blocks.BlockConfig(),
+                   nb_conv_per_stage: int = 2,
+                   downSampleMode: str = "MAXPOOL",
+                   upSampleMode: str = "CONV_TRANSPOSE",
+                   attention : bool = False,
+                   dim : int = 3) -> None:
+         super().__init__(in_channels=1, optimizer=optimizer, schedulers=schedulers, outputsCriterions=outputsCriterions, patch=patch, dim=dim)
+         self.add_module("Noise", blocks.NormalNoise(), out_branch=["eta"], training=True)
+         self.add_module("Sample", DDPM.DDPM_SampleT(train_noise_step), out_branch=["t"], training=True)
+         self.add_module("Forward", DDPM.DDPM_ForwardProcess(train_noise_step, beta_start, beta_end), in_branch=[0, "t", "eta"], out_branch=["x_t"], training=True)
+         self.add_module("Concat", blocks.Concat(), in_branch=["x_t", 1], out_branch=["xy_t"], training=True)
+         self.add_module("UNet", DDPM.DDPM_UNet(train_noise_step, channels, blockConfig, nb_conv_per_stage, downSampleMode, upSampleMode, attention, time_embedding_dim, dim), in_branch=["xy_t", "t"], out_branch=["eta_hat"], training=True)
+
+         self.add_module("Noise_optim", DDPM.DDPM_VLoss(self["Forward"].alpha_hat), in_branch=[0, "eta", "eta_hat", "t"], out_branch=["noise"], training=True)
+
+         self.add_module("Inference", DDPM.DDPM_Inference(self.inference, train_noise_step, inference_noise_step, beta_start, beta_end), in_branch=[1], training=False)
+         self.add_module("LastImage", blocks.Select([slice(None, None), slice(-1, None)]), training=False)
+
+     def inference(self, input: torch.Tensor, t: torch.Tensor):
+         return self["UNet"](input, t)
+
+ class MSE(Criterion):
+
+     def __init__(self):
+         super().__init__()
+         self.loss = torch.nn.MSELoss()
+
+     def forward(self, input: torch.Tensor) -> torch.Tensor:
+         return self.loss(input[:, 0, ...], input[:, 1, ...])
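+
+ # Training wiring sketch: branch 0 carries the clean image x_0 and branch 1
+ # the conditioning image. The graph samples t and eta, forms x_t via the
+ # forward process, and the U-Net predicts eta_hat from (x_t, condition, t);
+ # DDPM_VLoss then packs the true and predicted v side by side for MSE. At
+ # inference only branch 1 is used and LastImage selects the final denoised
+ # sample from the stored trajectory.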