Myosotis-Researches 0.1.8__py3-none-any.whl → 0.1.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. myosotis_researches/CcGAN/train/__init__.py +4 -0
  2. myosotis_researches/CcGAN/{train_128_output_10 → train}/train_ccgan.py +4 -4
  3. myosotis_researches/CcGAN/{train_128 → train}/train_cgan.py +1 -3
  4. myosotis_researches/CcGAN/{train_128 → train}/train_cgan_concat.py +1 -3
  5. {myosotis_researches-0.1.8.dist-info → myosotis_researches-0.1.9.dist-info}/METADATA +1 -1
  6. myosotis_researches-0.1.9.dist-info/RECORD +24 -0
  7. myosotis_researches/CcGAN/models_128/CcGAN_SAGAN.py +0 -301
  8. myosotis_researches/CcGAN/models_128/ResNet_class_eval.py +0 -141
  9. myosotis_researches/CcGAN/models_128/ResNet_embed.py +0 -188
  10. myosotis_researches/CcGAN/models_128/ResNet_regre_eval.py +0 -175
  11. myosotis_researches/CcGAN/models_128/__init__.py +0 -7
  12. myosotis_researches/CcGAN/models_128/autoencoder.py +0 -119
  13. myosotis_researches/CcGAN/models_128/cGAN_SAGAN.py +0 -276
  14. myosotis_researches/CcGAN/models_128/cGAN_concat_SAGAN.py +0 -245
  15. myosotis_researches/CcGAN/models_256/CcGAN_SAGAN.py +0 -303
  16. myosotis_researches/CcGAN/models_256/ResNet_class_eval.py +0 -142
  17. myosotis_researches/CcGAN/models_256/ResNet_embed.py +0 -188
  18. myosotis_researches/CcGAN/models_256/ResNet_regre_eval.py +0 -178
  19. myosotis_researches/CcGAN/models_256/__init__.py +0 -7
  20. myosotis_researches/CcGAN/models_256/autoencoder.py +0 -133
  21. myosotis_researches/CcGAN/models_256/cGAN_SAGAN.py +0 -280
  22. myosotis_researches/CcGAN/models_256/cGAN_concat_SAGAN.py +0 -249
  23. myosotis_researches/CcGAN/train_128/DiffAugment_pytorch.py +0 -76
  24. myosotis_researches/CcGAN/train_128/__init__.py +0 -0
  25. myosotis_researches/CcGAN/train_128/eval_metrics.py +0 -205
  26. myosotis_researches/CcGAN/train_128/opts.py +0 -87
  27. myosotis_researches/CcGAN/train_128/pretrain_AE.py +0 -268
  28. myosotis_researches/CcGAN/train_128/pretrain_CNN_class.py +0 -251
  29. myosotis_researches/CcGAN/train_128/pretrain_CNN_regre.py +0 -255
  30. myosotis_researches/CcGAN/train_128/train_ccgan.py +0 -303
  31. myosotis_researches/CcGAN/train_128/utils.py +0 -120
  32. myosotis_researches/CcGAN/train_128_output_10/DiffAugment_pytorch.py +0 -76
  33. myosotis_researches/CcGAN/train_128_output_10/__init__.py +0 -0
  34. myosotis_researches/CcGAN/train_128_output_10/eval_metrics.py +0 -205
  35. myosotis_researches/CcGAN/train_128_output_10/opts.py +0 -87
  36. myosotis_researches/CcGAN/train_128_output_10/pretrain_AE.py +0 -268
  37. myosotis_researches/CcGAN/train_128_output_10/pretrain_CNN_class.py +0 -251
  38. myosotis_researches/CcGAN/train_128_output_10/pretrain_CNN_regre.py +0 -255
  39. myosotis_researches/CcGAN/train_128_output_10/train_cgan.py +0 -254
  40. myosotis_researches/CcGAN/train_128_output_10/train_cgan_concat.py +0 -242
  41. myosotis_researches/CcGAN/train_128_output_10/train_net_for_label_embed.py +0 -181
  42. myosotis_researches/CcGAN/train_128_output_10/utils.py +0 -120
  43. myosotis_researches-0.1.8.dist-info/RECORD +0 -59
  44. myosotis_researches/CcGAN/{train_128 → train}/train_net_for_label_embed.py +0 -0
  45. {myosotis_researches-0.1.8.dist-info → myosotis_researches-0.1.9.dist-info}/WHEEL +0 -0
  46. {myosotis_researches-0.1.8.dist-info → myosotis_researches-0.1.9.dist-info}/licenses/LICENSE +0 -0
  47. {myosotis_researches-0.1.8.dist-info → myosotis_researches-0.1.9.dist-info}/top_level.txt +0 -0
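Taken together, the renames and deletions above consolidate the training modules from CcGAN/train_128 and CcGAN/train_128_output_10 into a single CcGAN/train package and remove the models_128 and models_256 packages entirely. For downstream code this mostly means updating import paths. A minimal sketch of the migration; the submodule names come from the file list above, but whatever symbols the new train/__init__.py re-exports (it gained 4 lines) is an assumption:

# 0.1.8 layout -- these packages no longer exist in 0.1.9:
# from myosotis_researches.CcGAN.train_128 import train_ccgan
# from myosotis_researches.CcGAN.models_128 import cGAN_SAGAN

# 0.1.9 layout -- training code now lives under the consolidated package;
# the submodule names match the file list, any re-exported symbols are assumed.
from myosotis_researches.CcGAN.train import train_ccgan, train_cgan, train_cgan_concat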
myosotis_researches/CcGAN/models_128/cGAN_SAGAN.py (deleted)
@@ -1,276 +0,0 @@
- import numpy as np
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
-
- from torch.nn.utils import spectral_norm
- from torch.nn.init import xavier_uniform_
-
-
- def init_weights(m):
-     if type(m) == nn.Linear or type(m) == nn.Conv2d:
-         xavier_uniform_(m.weight)
-         m.bias.data.fill_(0.)
-
-
- def snconv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
-     return spectral_norm(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
-                                    stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias))
-
-
- def snlinear(in_features, out_features):
-     return spectral_norm(nn.Linear(in_features=in_features, out_features=out_features))
-
-
- def sn_embedding(num_embeddings, embedding_dim):
-     return spectral_norm(nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim))
-
-
- class Self_Attn(nn.Module):
-     """ Self attention Layer"""
-
-     def __init__(self, in_channels):
-         super(Self_Attn, self).__init__()
-         self.in_channels = in_channels
-         self.snconv1x1_theta = snconv2d(in_channels=in_channels, out_channels=in_channels//8, kernel_size=1, stride=1, padding=0)
-         self.snconv1x1_phi = snconv2d(in_channels=in_channels, out_channels=in_channels//8, kernel_size=1, stride=1, padding=0)
-         self.snconv1x1_g = snconv2d(in_channels=in_channels, out_channels=in_channels//2, kernel_size=1, stride=1, padding=0)
-         self.snconv1x1_attn = snconv2d(in_channels=in_channels//2, out_channels=in_channels, kernel_size=1, stride=1, padding=0)
-         self.maxpool = nn.MaxPool2d(2, stride=2, padding=0)
-         self.softmax = nn.Softmax(dim=-1)
-         self.sigma = nn.Parameter(torch.zeros(1))
-
-     def forward(self, x):
-         """
-             inputs :
-                 x : input feature maps(B X C X W X H)
-             returns :
-                 out : self attention value + input feature
-                 attention: B X N X N (N is Width*Height)
-         """
-         _, ch, h, w = x.size()
-         # Theta path
-         theta = self.snconv1x1_theta(x)
-         theta = theta.view(-1, ch//8, h*w)
-         # Phi path
-         phi = self.snconv1x1_phi(x)
-         phi = self.maxpool(phi)
-         phi = phi.view(-1, ch//8, h*w//4)
-         # Attn map
-         attn = torch.bmm(theta.permute(0, 2, 1), phi)
-         attn = self.softmax(attn)
-         # g path
-         g = self.snconv1x1_g(x)
-         g = self.maxpool(g)
-         g = g.view(-1, ch//2, h*w//4)
-         # Attn_g
-         attn_g = torch.bmm(g, attn.permute(0, 2, 1))
-         attn_g = attn_g.view(-1, ch//2, h, w)
-         attn_g = self.snconv1x1_attn(attn_g)
-         # Out
-         out = x + self.sigma*attn_g
-         return out
-
-
- class ConditionalBatchNorm2d(nn.Module):
-     # https://github.com/pytorch/pytorch/issues/8985#issuecomment-405080775
-     def __init__(self, num_features, num_classes):
-         super().__init__()
-         self.num_features = num_features
-         self.bn = nn.BatchNorm2d(num_features, momentum=0.001, affine=False)
-         self.embed = nn.Embedding(num_classes, num_features * 2)
-         # self.embed.weight.data[:, :num_features].normal_(1, 0.02) # Initialise scale at N(1, 0.02)
-         self.embed.weight.data[:, :num_features].fill_(1.) # Initialize scale to 1
-         self.embed.weight.data[:, num_features:].zero_() # Initialize bias at 0
-
-     def forward(self, x, y):
-         out = self.bn(x)
-         gamma, beta = self.embed(y).chunk(2, 1)
-         out = gamma.view(-1, self.num_features, 1, 1) * out + beta.view(-1, self.num_features, 1, 1)
-         return out
-
-
- class GenBlock(nn.Module):
-     def __init__(self, in_channels, out_channels, num_classes):
-         super(GenBlock, self).__init__()
-         self.cond_bn1 = ConditionalBatchNorm2d(in_channels, num_classes)
-         self.relu = nn.ReLU(inplace=True)
-         self.snconv2d1 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
-         self.cond_bn2 = ConditionalBatchNorm2d(out_channels, num_classes)
-         self.snconv2d2 = snconv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
-         self.snconv2d0 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0)
-
-     def forward(self, x, labels):
-         x0 = x
-
-         x = self.cond_bn1(x, labels)
-         x = self.relu(x)
-         x = F.interpolate(x, scale_factor=2, mode='nearest') # upsample
-         x = self.snconv2d1(x)
-         x = self.cond_bn2(x, labels)
-         x = self.relu(x)
-         x = self.snconv2d2(x)
-
-         x0 = F.interpolate(x0, scale_factor=2, mode='nearest') # upsample
-         x0 = self.snconv2d0(x0)
-
-         out = x + x0
-         return out
-
-
- class cGAN_SAGAN_Generator(nn.Module):
-     """Generator."""
-
-     def __init__(self, z_dim, num_classes, g_conv_dim=64):
-         super(cGAN_SAGAN_Generator, self).__init__()
-
-         self.z_dim = z_dim
-         self.g_conv_dim = g_conv_dim
-         self.snlinear0 = snlinear(in_features=z_dim, out_features=g_conv_dim*16*4*4)
-         self.block1 = GenBlock(g_conv_dim*16, g_conv_dim*16, num_classes)
-         self.block2 = GenBlock(g_conv_dim*16, g_conv_dim*8, num_classes)
-         self.block3 = GenBlock(g_conv_dim*8, g_conv_dim*4, num_classes)
-         self.self_attn = Self_Attn(g_conv_dim*4)
-         self.block4 = GenBlock(g_conv_dim*4, g_conv_dim*2, num_classes)
-         self.block5 = GenBlock(g_conv_dim*2, g_conv_dim, num_classes)
-         self.bn = nn.BatchNorm2d(g_conv_dim, eps=1e-5, momentum=0.0001, affine=True)
-         self.relu = nn.ReLU(inplace=True)
-         self.snconv2d1 = snconv2d(in_channels=g_conv_dim, out_channels=3, kernel_size=3, stride=1, padding=1)
-         self.tanh = nn.Tanh()
-
-         # Weight init
-         self.apply(init_weights)
-
-     def forward(self, z, labels):
-         # n x z_dim
-         act0 = self.snlinear0(z) # n x g_conv_dim*16*4*4
-         act0 = act0.view(-1, self.g_conv_dim*16, 4, 4) # n x g_conv_dim*16 x 4 x 4
-         act1 = self.block1(act0, labels) # n x g_conv_dim*16 x 8 x 8
-         act2 = self.block2(act1, labels) # n x g_conv_dim*8 x 16 x 16
-         act3 = self.block3(act2, labels) # n x g_conv_dim*4 x 32 x 32
-         act3 = self.self_attn(act3) # n x g_conv_dim*4 x 32 x 32
-         act4 = self.block4(act3, labels) # n x g_conv_dim*2 x 64 x 64
-         act5 = self.block5(act4, labels) # n x g_conv_dim x 128 x 128
-         act5 = self.bn(act5) # n x g_conv_dim x 128 x 128
-         act5 = self.relu(act5) # n x g_conv_dim x 128 x 128
-         act6 = self.snconv2d1(act5) # n x 3 x 128 x 128
-         act6 = self.tanh(act6) # n x 3 x 128 x 128
-         return act6
-
-
- class DiscOptBlock(nn.Module):
-     def __init__(self, in_channels, out_channels):
-         super(DiscOptBlock, self).__init__()
-         self.snconv2d1 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
-         self.relu = nn.ReLU(inplace=True)
-         self.snconv2d2 = snconv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
-         self.downsample = nn.AvgPool2d(2)
-         self.snconv2d0 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0)
-
-     def forward(self, x):
-         x0 = x
-
-         x = self.snconv2d1(x)
-         x = self.relu(x)
-         x = self.snconv2d2(x)
-         x = self.downsample(x)
-
-         x0 = self.downsample(x0)
-         x0 = self.snconv2d0(x0)
-
-         out = x + x0
-         return out
-
-
- class DiscBlock(nn.Module):
-     def __init__(self, in_channels, out_channels):
-         super(DiscBlock, self).__init__()
-         self.relu = nn.ReLU(inplace=True)
-         self.snconv2d1 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
-         self.snconv2d2 = snconv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
-         self.downsample = nn.AvgPool2d(2)
-         self.ch_mismatch = False
-         if in_channels != out_channels:
-             self.ch_mismatch = True
-         self.snconv2d0 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0)
-
-     def forward(self, x, downsample=True):
-         x0 = x
-
-         x = self.relu(x)
-         x = self.snconv2d1(x)
-         x = self.relu(x)
-         x = self.snconv2d2(x)
-         if downsample:
-             x = self.downsample(x)
-
-         if downsample or self.ch_mismatch:
-             x0 = self.snconv2d0(x0)
-             if downsample:
-                 x0 = self.downsample(x0)
-
-         out = x + x0
-         return out
-
-
- class cGAN_SAGAN_Discriminator(nn.Module):
-     """Discriminator."""
-
-     def __init__(self, num_classes, d_conv_dim=64):
-         super(cGAN_SAGAN_Discriminator, self).__init__()
-         self.d_conv_dim = d_conv_dim
-         self.opt_block1 = DiscOptBlock(3, d_conv_dim)
-         self.block1 = DiscBlock(d_conv_dim, d_conv_dim*2)
-         self.self_attn = Self_Attn(d_conv_dim*2)
-         self.block2 = DiscBlock(d_conv_dim*2, d_conv_dim*4)
-         self.block3 = DiscBlock(d_conv_dim*4, d_conv_dim*8)
-         self.block4 = DiscBlock(d_conv_dim*8, d_conv_dim*16)
-         self.block5 = DiscBlock(d_conv_dim*16, d_conv_dim*16)
-         self.relu = nn.ReLU(inplace=True)
-         self.snlinear1 = snlinear(in_features=d_conv_dim*16, out_features=1)
-         self.sn_embedding1 = sn_embedding(num_classes, d_conv_dim*16)
-
-         # Weight init
-         self.apply(init_weights)
-         xavier_uniform_(self.sn_embedding1.weight)
-
-     def forward(self, x, labels):
-         # n x 3 x 128 x 128
-         h0 = self.opt_block1(x) # n x d_conv_dim x 64 x 64
-         h1 = self.block1(h0) # n x d_conv_dim*2 x 32 x 32
-         h1 = self.self_attn(h1) # n x d_conv_dim*2 x 32 x 32
-         h2 = self.block2(h1) # n x d_conv_dim*4 x 16 x 16
-         h3 = self.block3(h2) # n x d_conv_dim*8 x 8 x 8
-         h4 = self.block4(h3) # n x d_conv_dim*16 x 4 x 4
-         h5 = self.block5(h4, downsample=False) # n x d_conv_dim*16 x 4 x 4
-         h5 = self.relu(h5) # n x d_conv_dim*16 x 4 x 4
-         h6 = torch.sum(h5, dim=[2,3]) # n x d_conv_dim*16
-         output1 = torch.squeeze(self.snlinear1(h6)) # n
-         # Projection
-         h_labels = self.sn_embedding1(labels) # n x d_conv_dim*16
-         proj = torch.mul(h6, h_labels) # n x d_conv_dim*16
-         output2 = torch.sum(proj, dim=[1]) # n
-         # Out
-         output = output1 + output2 # n
-         return output
-
-
-
- if __name__ == "__main__":
-
-     num_classes = 10
-
-     netG = cGAN_SAGAN_Generator(z_dim=128, num_classes=num_classes, g_conv_dim=128).cuda()
-     netD = cGAN_SAGAN_Discriminator(num_classes=num_classes, d_conv_dim=128).cuda()
-
-     n = 4
-     # target = torch.randint(high=num_classes, size=(1,n)) # set size (2,10) for MHE
-     # y = torch.zeros(n, num_classes)
-     # y[range(y.shape[0]), target]=1
-     # y = y.type(torch.long).cuda()
-     y = torch.randint(high=num_classes, size=(n,)).type(torch.long).cuda()
-     z = torch.randn(n, 128).cuda()
-     x = netG(z,y)
-     o = netD(x,y)
-     print(x.size())
-     print(o.size())
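The deleted file's own __main__ block above doubles as a quick shape check for the class-conditional SAGAN pair. A CPU-only variant of that check, assuming the cGAN_SAGAN_Generator and cGAN_SAGAN_Discriminator definitions above have been copied into scope (in 0.1.9 they can no longer be imported from myosotis_researches.CcGAN.models_128):

import torch

num_classes = 10
netG = cGAN_SAGAN_Generator(z_dim=128, num_classes=num_classes, g_conv_dim=64)  # default width, no .cuda()
netD = cGAN_SAGAN_Discriminator(num_classes=num_classes, d_conv_dim=64)

n = 4
y = torch.randint(high=num_classes, size=(n,)).long()  # integer class labels
z = torch.randn(n, 128)                                 # latent codes
x = netG(z, y)   # n x 3 x 128 x 128
o = netD(x, y)   # n  (one projection-discriminator logit per sample)
print(x.size(), o.size())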
myosotis_researches/CcGAN/models_128/cGAN_concat_SAGAN.py (deleted)
@@ -1,245 +0,0 @@
- import numpy as np
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
-
- from torch.nn.utils import spectral_norm
- from torch.nn.init import xavier_uniform_
-
-
- def init_weights(m):
-     if type(m) == nn.Linear or type(m) == nn.Conv2d:
-         xavier_uniform_(m.weight)
-         m.bias.data.fill_(0.)
-
-
- def snconv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
-     return spectral_norm(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
-                                    stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias))
-
-
- def snlinear(in_features, out_features):
-     return spectral_norm(nn.Linear(in_features=in_features, out_features=out_features))
-
-
- def sn_embedding(num_embeddings, embedding_dim):
-     return spectral_norm(nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim))
-
-
- class Self_Attn(nn.Module):
-     """ Self attention Layer"""
-
-     def __init__(self, in_channels):
-         super(Self_Attn, self).__init__()
-         self.in_channels = in_channels
-         self.snconv1x1_theta = snconv2d(in_channels=in_channels, out_channels=in_channels//8, kernel_size=1, stride=1, padding=0)
-         self.snconv1x1_phi = snconv2d(in_channels=in_channels, out_channels=in_channels//8, kernel_size=1, stride=1, padding=0)
-         self.snconv1x1_g = snconv2d(in_channels=in_channels, out_channels=in_channels//2, kernel_size=1, stride=1, padding=0)
-         self.snconv1x1_attn = snconv2d(in_channels=in_channels//2, out_channels=in_channels, kernel_size=1, stride=1, padding=0)
-         self.maxpool = nn.MaxPool2d(2, stride=2, padding=0)
-         self.softmax = nn.Softmax(dim=-1)
-         self.sigma = nn.Parameter(torch.zeros(1))
-
-     def forward(self, x):
-         """
-             inputs :
-                 x : input feature maps(B X C X W X H)
-             returns :
-                 out : self attention value + input feature
-                 attention: B X N X N (N is Width*Height)
-         """
-         _, ch, h, w = x.size()
-         # Theta path
-         theta = self.snconv1x1_theta(x)
-         theta = theta.view(-1, ch//8, h*w)
-         # Phi path
-         phi = self.snconv1x1_phi(x)
-         phi = self.maxpool(phi)
-         phi = phi.view(-1, ch//8, h*w//4)
-         # Attn map
-         attn = torch.bmm(theta.permute(0, 2, 1), phi)
-         attn = self.softmax(attn)
-         # g path
-         g = self.snconv1x1_g(x)
-         g = self.maxpool(g)
-         g = g.view(-1, ch//2, h*w//4)
-         # Attn_g
-         attn_g = torch.bmm(g, attn.permute(0, 2, 1))
-         attn_g = attn_g.view(-1, ch//2, h, w)
-         attn_g = self.snconv1x1_attn(attn_g)
-         # Out
-         out = x + self.sigma*attn_g
-         return out
-
-
- class GenBlock(nn.Module):
-     def __init__(self, in_channels, out_channels):
-         super(GenBlock, self).__init__()
-         self.cond_bn1 = nn.BatchNorm2d(in_channels)
-         self.relu = nn.ReLU(inplace=True)
-         self.snconv2d1 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
-         self.cond_bn2 = nn.BatchNorm2d(out_channels)
-         self.snconv2d2 = snconv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
-         self.snconv2d0 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0)
-
-     def forward(self, x):
-         x0 = x
-
-         x = self.cond_bn1(x)
-         x = self.relu(x)
-         x = F.interpolate(x, scale_factor=2, mode='nearest') # upsample
-         x = self.snconv2d1(x)
-         x = self.cond_bn2(x)
-         x = self.relu(x)
-         x = self.snconv2d2(x)
-
-         x0 = F.interpolate(x0, scale_factor=2, mode='nearest') # upsample
-         x0 = self.snconv2d0(x0)
-
-         out = x + x0
-         return out
-
-
- class cGAN_concat_SAGAN_Generator(nn.Module):
-     """Generator."""
-
-     def __init__(self, z_dim, dim_c=1, g_conv_dim=64):
-         super(cGAN_concat_SAGAN_Generator, self).__init__()
-
-         self.z_dim = z_dim
-         self.dim_c = dim_c
-         self.g_conv_dim = g_conv_dim
-         self.snlinear0 = snlinear(in_features=z_dim+dim_c, out_features=g_conv_dim*16*4*4)
-         self.block1 = GenBlock(g_conv_dim*16, g_conv_dim*16)
-         self.block2 = GenBlock(g_conv_dim*16, g_conv_dim*8)
-         self.block3 = GenBlock(g_conv_dim*8, g_conv_dim*4)
-         self.self_attn = Self_Attn(g_conv_dim*4)
-         self.block4 = GenBlock(g_conv_dim*4, g_conv_dim*2)
-         self.block5 = GenBlock(g_conv_dim*2, g_conv_dim)
-         self.bn = nn.BatchNorm2d(g_conv_dim, eps=1e-5, momentum=0.0001, affine=True)
-         self.relu = nn.ReLU(inplace=True)
-         self.snconv2d1 = snconv2d(in_channels=g_conv_dim, out_channels=3, kernel_size=3, stride=1, padding=1)
-         self.tanh = nn.Tanh()
-
-         # Weight init
-         self.apply(init_weights)
-
-     def forward(self, z, labels):
-         # n x z_dim
-         act0 = self.snlinear0(torch.cat((z, labels.view(-1,1)),dim=1)) # n x g_conv_dim*16*4*4
-         act0 = act0.view(-1, self.g_conv_dim*16, 4, 4) # n x g_conv_dim*16 x 4 x 4
-         act1 = self.block1(act0) # n x g_conv_dim*16 x 8 x 8
-         act2 = self.block2(act1) # n x g_conv_dim*8 x 16 x 16
-         act3 = self.block3(act2) # n x g_conv_dim*4 x 32 x 32
-         act3 = self.self_attn(act3) # n x g_conv_dim*4 x 32 x 32
-         act4 = self.block4(act3) # n x g_conv_dim*2 x 64 x 64
-         act5 = self.block5(act4) # n x g_conv_dim x 128 x 128
-         act5 = self.bn(act5) # n x g_conv_dim x 128 x 128
-         act5 = self.relu(act5) # n x g_conv_dim x 128 x 128
-         act6 = self.snconv2d1(act5) # n x 3 x 128 x 128
-         act6 = self.tanh(act6) # n x 3 x 128 x 128
-         return act6
-
-
- class DiscOptBlock(nn.Module):
-     def __init__(self, in_channels, out_channels):
-         super(DiscOptBlock, self).__init__()
-         self.snconv2d1 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
-         self.relu = nn.ReLU(inplace=True)
-         self.snconv2d2 = snconv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
-         self.downsample = nn.AvgPool2d(2)
-         self.snconv2d0 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0)
-
-     def forward(self, x):
-         x0 = x
-
-         x = self.snconv2d1(x)
-         x = self.relu(x)
-         x = self.snconv2d2(x)
-         x = self.downsample(x)
-
-         x0 = self.downsample(x0)
-         x0 = self.snconv2d0(x0)
-
-         out = x + x0
-         return out
-
-
- class DiscBlock(nn.Module):
-     def __init__(self, in_channels, out_channels):
-         super(DiscBlock, self).__init__()
-         self.relu = nn.ReLU(inplace=True)
-         self.snconv2d1 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
-         self.snconv2d2 = snconv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
-         self.downsample = nn.AvgPool2d(2)
-         self.ch_mismatch = False
-         if in_channels != out_channels:
-             self.ch_mismatch = True
-         self.snconv2d0 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0)
-
-     def forward(self, x, downsample=True):
-         x0 = x
-
-         x = self.relu(x)
-         x = self.snconv2d1(x)
-         x = self.relu(x)
-         x = self.snconv2d2(x)
-         if downsample:
-             x = self.downsample(x)
-
-         if downsample or self.ch_mismatch:
-             x0 = self.snconv2d0(x0)
-             if downsample:
-                 x0 = self.downsample(x0)
-
-         out = x + x0
-         return out
-
-
- class cGAN_concat_SAGAN_Discriminator(nn.Module):
-     """Discriminator."""
-
-     def __init__(self, dim_c=1, d_conv_dim=64):
-         super(cGAN_concat_SAGAN_Discriminator, self).__init__()
-         self.d_conv_dim = d_conv_dim
-         self.opt_block1 = DiscOptBlock(3, d_conv_dim)
-         self.block1 = DiscBlock(d_conv_dim, d_conv_dim*2)
-         self.self_attn = Self_Attn(d_conv_dim*2)
-         self.block2 = DiscBlock(d_conv_dim*2, d_conv_dim*4)
-         self.block3 = DiscBlock(d_conv_dim*4, d_conv_dim*8)
-         self.block4 = DiscBlock(d_conv_dim*8, d_conv_dim*16)
-         self.block5 = DiscBlock(d_conv_dim*16, d_conv_dim*16)
-         self.relu = nn.ReLU(inplace=True)
-         self.snlinear1 = snlinear(in_features=d_conv_dim*16*4*4+dim_c, out_features=1)
-
-     def forward(self, x, labels):
-         # n x 3 x 128 x 128
-         h0 = self.opt_block1(x) # n x d_conv_dim x 64 x 64
-         h1 = self.block1(h0) # n x d_conv_dim*2 x 32 x 32
-         h1 = self.self_attn(h1) # n x d_conv_dim*2 x 32 x 32
-         h2 = self.block2(h1) # n x d_conv_dim*4 x 16 x 16
-         h3 = self.block3(h2) # n x d_conv_dim*8 x 8 x 8
-         h4 = self.block4(h3) # n x d_conv_dim*16 x 4 x 4
-         h5 = self.block5(h4, downsample=False) # n x d_conv_dim*16 x 4 x 4
-         out = self.relu(h5) # n x d_conv_dim*16 x 4 x 4
-         # out = torch.sum(out, dim=[2,3]) # n x d_conv_dim*16
-         out = out.view(-1,self.d_conv_dim*16*4*4)
-         output = self.snlinear1(torch.cat((out, labels.view(-1,1)), dim=1))
-
-         return output
-
-
-
- if __name__ == "__main__":
-
-
-     netG = cGAN_concat_SAGAN_Generator(z_dim=128, dim_c=1, g_conv_dim=128).cuda()
-     netD = cGAN_concat_SAGAN_Discriminator(dim_c=1, d_conv_dim=128).cuda()
-
-     n = 4
-     y = torch.randn(n, 1).cuda()
-     z = torch.randn(n, 128).cuda()
-     x = netG(z,y)
-     o = netD(x,y)
-     print(x.size())
-     print(o.size())
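As with the class-conditional file, the trailing __main__ block above is effectively a shape check; here the label is a continuous value that is concatenated to z in the generator and to the flattened features in the discriminator. A CPU-only sketch, again assuming the classes above are copied into scope rather than imported from the removed models_128 package:

import torch

netG = cGAN_concat_SAGAN_Generator(z_dim=128, dim_c=1, g_conv_dim=64)
netD = cGAN_concat_SAGAN_Discriminator(dim_c=1, d_conv_dim=64)

n = 4
y = torch.randn(n, 1)    # continuous labels, one per sample
z = torch.randn(n, 128)  # latent codes
x = netG(z, y)   # n x 3 x 128 x 128
o = netD(x, y)   # n x 1
print(x.size(), o.size())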