Myosotis-Researches 0.0.12__py3-none-any.whl → 0.0.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (23):
  1. myosotis_researches/CcGAN/models_128/CcGAN_SAGAN.py +301 -0
  2. myosotis_researches/CcGAN/models_128/ResNet_class_eval.py +141 -0
  3. myosotis_researches/CcGAN/models_128/ResNet_embed.py +188 -0
  4. myosotis_researches/CcGAN/models_128/ResNet_regre_eval.py +175 -0
  5. myosotis_researches/CcGAN/models_128/__init__.py +8 -0
  6. myosotis_researches/CcGAN/models_128/autoencoder.py +119 -0
  7. myosotis_researches/CcGAN/models_128/cGAN_SAGAN.py +276 -0
  8. myosotis_researches/CcGAN/models_128/cGAN_concat_SAGAN.py +245 -0
  9. myosotis_researches/CcGAN/models_256/CcGAN_SAGAN.py +303 -0
  10. myosotis_researches/CcGAN/models_256/ResNet_class_eval.py +142 -0
  11. myosotis_researches/CcGAN/models_256/ResNet_embed.py +188 -0
  12. myosotis_researches/CcGAN/models_256/ResNet_regre_eval.py +178 -0
  13. myosotis_researches/CcGAN/models_256/__init__.py +8 -0
  14. myosotis_researches/CcGAN/models_256/autoencoder.py +133 -0
  15. myosotis_researches/CcGAN/models_256/cGAN_SAGAN.py +280 -0
  16. myosotis_researches/CcGAN/models_256/cGAN_concat_SAGAN.py +249 -0
  17. myosotis_researches/CcGAN/utils/make_h5.py +13 -9
  18. {myosotis_researches-0.0.12.dist-info → myosotis_researches-0.0.14.dist-info}/METADATA +1 -1
  19. myosotis_researches-0.0.14.dist-info/RECORD +28 -0
  20. myosotis_researches-0.0.12.dist-info/RECORD +0 -12
  21. {myosotis_researches-0.0.12.dist-info → myosotis_researches-0.0.14.dist-info}/WHEEL +0 -0
  22. {myosotis_researches-0.0.12.dist-info → myosotis_researches-0.0.14.dist-info}/licenses/LICENSE +0 -0
  23. {myosotis_researches-0.0.12.dist-info → myosotis_researches-0.0.14.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,245 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ from torch.nn.utils import spectral_norm
7
+ from torch.nn.init import xavier_uniform_
8
+
9
+
10
def init_weights(m):
    """Xavier-initialise plain ``nn.Linear`` / ``nn.Conv2d`` layers in place.

    Intended for ``nn.Module.apply``. Only modules whose exact type is
    ``nn.Linear`` or ``nn.Conv2d`` are touched (subclasses are skipped, as with
    the original ``type(m) == ...`` check). Biases, when present, are zeroed.

    Args:
        m: Any ``nn.Module``; non-matching modules are left unchanged.
    """
    if type(m) in (nn.Linear, nn.Conv2d):
        xavier_uniform_(m.weight)
        # Layers built with ``bias=False`` have ``m.bias is None``; skip them
        # instead of raising AttributeError on ``None.data``.
        if m.bias is not None:
            m.bias.data.fill_(0.)
14
+
15
+
16
def snconv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
    """Build an ``nn.Conv2d`` and wrap it with spectral normalisation.

    Every argument is forwarded unchanged to ``nn.Conv2d``.
    """
    conv_options = dict(stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
    layer = nn.Conv2d(in_channels, out_channels, kernel_size, **conv_options)
    return spectral_norm(layer)
19
+
20
+
21
def snlinear(in_features, out_features, bias=True):
    """Return an ``nn.Linear`` wrapped with spectral normalisation.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        bias: Whether the layer learns an additive bias. Defaults to ``True``,
            which is the previously hard-wired behaviour.
    """
    return spectral_norm(nn.Linear(in_features=in_features, out_features=out_features, bias=bias))
23
+
24
+
25
def sn_embedding(num_embeddings, embedding_dim):
    """Return an ``nn.Embedding`` whose lookup table is spectrally normalised."""
    table = nn.Embedding(num_embeddings, embedding_dim)
    return spectral_norm(table)
27
+
28
+
29
class Self_Attn(nn.Module):
    """SAGAN self-attention layer with a zero-initialised residual gate.

    Query (``theta``), key (``phi``) and value (``g``) maps come from 1x1
    spectral-normalised convolutions. Keys and values are computed on a 2x2
    max-pooled copy of the input, so for an ``H x W`` map the attention matrix
    is ``(H*W) x (H//2 * W//2)``.

    Args:
        in_channels: Channel count ``C`` of the input. Query/key use ``C // 8``
            channels and the value path uses ``C // 2``.
    """

    def __init__(self, in_channels):
        super(Self_Attn, self).__init__()
        self.in_channels = in_channels
        self.snconv1x1_theta = snconv2d(in_channels=in_channels, out_channels=in_channels//8, kernel_size=1, stride=1, padding=0)
        self.snconv1x1_phi = snconv2d(in_channels=in_channels, out_channels=in_channels//8, kernel_size=1, stride=1, padding=0)
        self.snconv1x1_g = snconv2d(in_channels=in_channels, out_channels=in_channels//2, kernel_size=1, stride=1, padding=0)
        self.snconv1x1_attn = snconv2d(in_channels=in_channels//2, out_channels=in_channels, kernel_size=1, stride=1, padding=0)
        self.maxpool = nn.MaxPool2d(2, stride=2, padding=0)
        self.softmax = nn.Softmax(dim=-1)
        # Residual gate: starting at 0 makes the layer an identity map at init.
        self.sigma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """Apply self-attention and add the result back onto ``x``.

        Args:
            x: Feature map of shape ``(B, C, H, W)``. Odd ``H``/``W`` are
                supported: the pooled maps are flattened from their actual
                size rather than an assumed ``H*W // 4``.

        Returns:
            Tensor shaped like ``x``, equal to ``x + sigma * attn_out``. Only
            this tensor is returned; the attention map is not exposed.
        """
        _, ch, h, w = x.size()
        # Queries: (B, C/8, H*W)
        theta = self.snconv1x1_theta(x).flatten(2)
        # Keys on the pooled map: (B, C/8, H'*W')
        phi = self.maxpool(self.snconv1x1_phi(x)).flatten(2)
        # Attention weights: (B, H*W, H'*W'), normalised over the key axis
        attn = self.softmax(torch.bmm(theta.permute(0, 2, 1), phi))
        # Values on the pooled map: (B, C/2, H'*W')
        g = self.maxpool(self.snconv1x1_g(x)).flatten(2)
        # Attend, restore spatial layout, project back to C channels
        attn_g = torch.bmm(g, attn.permute(0, 2, 1))
        attn_g = attn_g.view(-1, ch//2, h, w)
        attn_g = self.snconv1x1_attn(attn_g)
        return x + self.sigma*attn_g
73
+
74
+
75
class GenBlock(nn.Module):
    """Residual 2x up-sampling block of the generator.

    Main path: BN -> ReLU -> 2x nearest upsample -> 3x3 conv -> BN -> ReLU -> 3x3 conv.
    Shortcut: 2x nearest upsample -> 1x1 conv. The two paths are summed.

    Despite the ``cond_bn*`` attribute names (kept so checkpoints still load),
    both normalisation layers are ordinary ``nn.BatchNorm2d``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.cond_bn1 = nn.BatchNorm2d(in_channels)
        self.relu = nn.ReLU(inplace=True)
        self.snconv2d1 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
        self.cond_bn2 = nn.BatchNorm2d(out_channels)
        self.snconv2d2 = snconv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
        self.snconv2d0 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        # Shortcut: upsample, then a 1x1 projection to out_channels.
        skip = self.snconv2d0(F.interpolate(x, scale_factor=2, mode='nearest'))

        # Residual branch. The in-place ReLU acts on fresh BN outputs, never on x.
        h = self.relu(self.cond_bn1(x))
        h = F.interpolate(h, scale_factor=2, mode='nearest')
        h = self.relu(self.cond_bn2(self.snconv2d1(h)))
        h = self.snconv2d2(h)

        return h + skip
101
+
102
+
103
class cGAN_concat_SAGAN_Generator(nn.Module):
    """SAGAN generator for 128x128 RGB images, conditioned by label concatenation.

    The label vector is concatenated to the noise ``z`` before the first linear
    layer. Five up-sampling blocks take a 4x4 seed to 128x128, with
    self-attention at 32x32.

    Args:
        z_dim: Dimension of the noise vector ``z``.
        dim_c: Dimension of the conditioning label vector (default 1, i.e. a
            scalar label per sample).
        g_conv_dim: Base channel width; the seed has ``16 * g_conv_dim`` channels.
    """

    def __init__(self, z_dim, dim_c=1, g_conv_dim=64):
        super(cGAN_concat_SAGAN_Generator, self).__init__()

        self.z_dim = z_dim
        self.dim_c = dim_c
        self.g_conv_dim = g_conv_dim
        self.snlinear0 = snlinear(in_features=z_dim+dim_c, out_features=g_conv_dim*16*4*4)
        self.block1 = GenBlock(g_conv_dim*16, g_conv_dim*16)
        self.block2 = GenBlock(g_conv_dim*16, g_conv_dim*8)
        self.block3 = GenBlock(g_conv_dim*8, g_conv_dim*4)
        self.self_attn = Self_Attn(g_conv_dim*4)
        self.block4 = GenBlock(g_conv_dim*4, g_conv_dim*2)
        self.block5 = GenBlock(g_conv_dim*2, g_conv_dim)
        self.bn = nn.BatchNorm2d(g_conv_dim, eps=1e-5, momentum=0.0001, affine=True)
        self.relu = nn.ReLU(inplace=True)
        self.snconv2d1 = snconv2d(in_channels=g_conv_dim, out_channels=3, kernel_size=3, stride=1, padding=1)
        self.tanh = nn.Tanh()

        # Weight init: apply init_weights to every submodule.
        self.apply(init_weights)

    def forward(self, z, labels):
        """Generate a batch of images.

        Args:
            z: Noise of shape ``(n, z_dim)``.
            labels: Conditioning labels with ``n * dim_c`` elements, e.g. shape
                ``(n,)``/``(n, 1)`` when ``dim_c == 1``, or ``(n, dim_c)``.

        Returns:
            Tensor of shape ``(n, 3, 128, 128)``, squashed into ``[-1, 1]`` by tanh.
        """
        # Bring labels to (n, dim_c). This was previously hard-coded to (n, 1),
        # which broke dim_c > 1 although snlinear0 expects z_dim + dim_c inputs.
        labels = labels.reshape(-1, self.dim_c)
        act0 = self.snlinear0(torch.cat((z, labels), dim=1))  # n x g_conv_dim*16*4*4
        act0 = act0.view(-1, self.g_conv_dim*16, 4, 4)        # n x g_conv_dim*16 x 4 x 4
        act1 = self.block1(act0)                              # n x g_conv_dim*16 x 8 x 8
        act2 = self.block2(act1)                              # n x g_conv_dim*8 x 16 x 16
        act3 = self.block3(act2)                              # n x g_conv_dim*4 x 32 x 32
        act3 = self.self_attn(act3)                           # n x g_conv_dim*4 x 32 x 32
        act4 = self.block4(act3)                              # n x g_conv_dim*2 x 64 x 64
        act5 = self.block5(act4)                              # n x g_conv_dim x 128 x 128
        act5 = self.bn(act5)
        act5 = self.relu(act5)
        act6 = self.snconv2d1(act5)                           # n x 3 x 128 x 128
        act6 = self.tanh(act6)
        return act6
142
+
143
+
144
class DiscOptBlock(nn.Module):
    """First residual down-sampling block of the discriminator.

    Main path: 3x3 conv -> ReLU -> 3x3 conv -> 2x2 average pool.
    Shortcut: 2x2 average pool -> 1x1 conv. No activation precedes the first
    convolution. The two paths are summed.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.snconv2d1 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.snconv2d2 = snconv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
        self.downsample = nn.AvgPool2d(2)
        self.snconv2d0 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        residual = self.downsample(self.snconv2d2(self.relu(self.snconv2d1(x))))
        shortcut = self.snconv2d0(self.downsample(x))
        return residual + shortcut
166
+
167
+
168
class DiscBlock(nn.Module):
    """Pre-activation residual block of the discriminator, optionally down-sampling.

    Main path: ReLU -> 3x3 conv -> ReLU -> 3x3 conv [-> 2x2 avg pool].
    Shortcut: 1x1 conv [-> 2x2 avg pool] when down-sampling or when the channel
    counts differ; otherwise the (rectified, see forward) input is passed through.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.relu = nn.ReLU(inplace=True)
        self.snconv2d1 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
        self.snconv2d2 = snconv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
        self.downsample = nn.AvgPool2d(2)
        self.ch_mismatch = in_channels != out_channels
        # Created unconditionally: forward() also needs it when downsample=True,
        # even if the channel counts match.
        self.snconv2d0 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x, downsample=True):
        # NOTE: self.relu is in-place and x is not copied first, so the first
        # ReLU also rectifies the tensor the shortcut reads (and the caller's
        # tensor). The shortcut therefore sees relu(x). Kept as-is so trained
        # checkpoints reproduce their outputs.
        shortcut = x
        h = self.relu(x)
        h = self.snconv2d2(self.relu(self.snconv2d1(h)))

        if downsample or self.ch_mismatch:
            shortcut = self.snconv2d0(shortcut)
        if downsample:
            h = self.downsample(h)
            shortcut = self.downsample(shortcut)

        return h + shortcut
197
+
198
+
199
class cGAN_concat_SAGAN_Discriminator(nn.Module):
    """SAGAN discriminator for 128x128 RGB images, conditioned by label concatenation.

    The label vector is concatenated to the flattened ``d_conv_dim*16 x 4 x 4``
    features right before the final spectral-norm linear layer.

    Args:
        dim_c: Dimension of the conditioning label vector (default 1).
        d_conv_dim: Base channel width; the last blocks use ``16 * d_conv_dim``.
    """

    def __init__(self, dim_c=1, d_conv_dim=64):
        super(cGAN_concat_SAGAN_Discriminator, self).__init__()
        self.d_conv_dim = d_conv_dim
        self.dim_c = dim_c
        self.opt_block1 = DiscOptBlock(3, d_conv_dim)
        self.block1 = DiscBlock(d_conv_dim, d_conv_dim*2)
        self.self_attn = Self_Attn(d_conv_dim*2)
        self.block2 = DiscBlock(d_conv_dim*2, d_conv_dim*4)
        self.block3 = DiscBlock(d_conv_dim*4, d_conv_dim*8)
        self.block4 = DiscBlock(d_conv_dim*8, d_conv_dim*16)
        self.block5 = DiscBlock(d_conv_dim*16, d_conv_dim*16)
        self.relu = nn.ReLU(inplace=True)
        self.snlinear1 = snlinear(in_features=d_conv_dim*16*4*4+dim_c, out_features=1)

    def forward(self, x, labels):
        """Score a batch of images.

        Args:
            x: Images of shape ``(n, 3, 128, 128)``.
            labels: Conditioning labels with ``n * dim_c`` elements.

        Returns:
            Unnormalised scores of shape ``(n, 1)``.
        """
        h0 = self.opt_block1(x)                 # n x d_conv_dim x 64 x 64
        h1 = self.block1(h0)                    # n x d_conv_dim*2 x 32 x 32
        h1 = self.self_attn(h1)                 # n x d_conv_dim*2 x 32 x 32
        h2 = self.block2(h1)                    # n x d_conv_dim*4 x 16 x 16
        h3 = self.block3(h2)                    # n x d_conv_dim*8 x 8 x 8
        h4 = self.block4(h3)                    # n x d_conv_dim*16 x 4 x 4
        h5 = self.block5(h4, downsample=False)  # n x d_conv_dim*16 x 4 x 4
        out = self.relu(h5)
        out = out.view(-1, self.d_conv_dim*16*4*4)
        # Bring labels to (n, dim_c); previously hard-coded to (n, 1), which
        # broke dim_c > 1 although snlinear1 expects the extra dim_c inputs.
        labels = labels.reshape(-1, self.dim_c)
        output = self.snlinear1(torch.cat((out, labels), dim=1))

        return output
230
+
231
+
232
+
233
if __name__ == "__main__":
    # Smoke test: push a small random batch through G and D and print shapes.
    # Runs on the GPU when present instead of hard-requiring CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    netG = cGAN_concat_SAGAN_Generator(z_dim=128, dim_c=1, g_conv_dim=128).to(device)
    netD = cGAN_concat_SAGAN_Discriminator(dim_c=1, d_conv_dim=128).to(device)

    n = 4
    y = torch.randn(n, 1).to(device)
    z = torch.randn(n, 128).to(device)
    x = netG(z, y)
    o = netD(x, y)
    print(x.size())
    print(o.size())
@@ -0,0 +1,303 @@
1
+ '''
2
+
3
+ Adapted from https://github.com/voletiv/self-attention-GAN-pytorch/blob/master/sagan_models.py
4
+
5
+
6
+ '''
7
+
8
+
9
+ import numpy as np
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+ from torch.nn.utils import spectral_norm
15
+ from torch.nn.init import xavier_uniform_
16
+
17
+
18
def init_weights(m):
    """Initialise ``nn.Linear`` / ``nn.Conv2d`` modules in place (use with ``Module.apply``).

    Weights receive Xavier-uniform values and existing biases are set to 0.
    Only exact instances of those two classes are affected; subclasses are ignored.
    """
    if type(m) not in (nn.Linear, nn.Conv2d):
        return
    xavier_uniform_(m.weight)
    bias = m.bias
    if bias is None:
        return
    bias.data.fill_(0.)
23
+
24
+
25
def snconv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
    """Spectral-normalised ``nn.Conv2d`` factory; all arguments go straight to ``nn.Conv2d``."""
    conv_options = dict(stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
    layer = nn.Conv2d(in_channels, out_channels, kernel_size, **conv_options)
    return spectral_norm(layer)
28
+
29
def snlinear(in_features, out_features, bias=True):
    """Spectral-normalised ``nn.Linear`` factory; arguments go straight to ``nn.Linear``."""
    linear = nn.Linear(in_features, out_features, bias=bias)
    return spectral_norm(linear)
31
+
32
+
33
+
34
class Self_Attn(nn.Module):
    """SAGAN self-attention layer with a zero-initialised residual gate.

    Query (``theta``), key (``phi``) and value (``g``) maps come from 1x1
    spectral-normalised convolutions. Keys and values are computed on a 2x2
    max-pooled copy of the input, so for an ``H x W`` map the attention matrix
    is ``(H*W) x (H//2 * W//2)``.

    Args:
        in_channels: Channel count ``C`` of the input. Query/key use ``C // 8``
            channels and the value path uses ``C // 2``.
    """

    def __init__(self, in_channels):
        super(Self_Attn, self).__init__()
        self.in_channels = in_channels
        self.snconv1x1_theta = snconv2d(in_channels=in_channels, out_channels=in_channels//8, kernel_size=1, stride=1, padding=0)
        self.snconv1x1_phi = snconv2d(in_channels=in_channels, out_channels=in_channels//8, kernel_size=1, stride=1, padding=0)
        self.snconv1x1_g = snconv2d(in_channels=in_channels, out_channels=in_channels//2, kernel_size=1, stride=1, padding=0)
        self.snconv1x1_attn = snconv2d(in_channels=in_channels//2, out_channels=in_channels, kernel_size=1, stride=1, padding=0)
        self.maxpool = nn.MaxPool2d(2, stride=2, padding=0)
        self.softmax = nn.Softmax(dim=-1)
        # Residual gate: starting at 0 makes the layer an identity map at init.
        self.sigma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """Apply self-attention and add the result back onto ``x``.

        Args:
            x: Feature map of shape ``(B, C, H, W)``. Odd ``H``/``W`` are
                supported: the pooled maps are flattened from their actual
                size rather than an assumed ``H*W // 4``.

        Returns:
            Tensor shaped like ``x``, equal to ``x + sigma * attn_out``. Only
            this tensor is returned; the attention map is not exposed.
        """
        _, ch, h, w = x.size()
        # Queries: (B, C/8, H*W)
        theta = self.snconv1x1_theta(x).flatten(2)
        # Keys on the pooled map: (B, C/8, H'*W')
        phi = self.maxpool(self.snconv1x1_phi(x)).flatten(2)
        # Attention weights: (B, H*W, H'*W'), normalised over the key axis
        attn = self.softmax(torch.bmm(theta.permute(0, 2, 1), phi))
        # Values on the pooled map: (B, C/2, H'*W')
        g = self.maxpool(self.snconv1x1_g(x)).flatten(2)
        # Attend, restore spatial layout, project back to C channels
        attn_g = torch.bmm(g, attn.permute(0, 2, 1))
        attn_g = attn_g.view(-1, ch//2, h, w)
        attn_g = self.snconv1x1_attn(attn_g)
        return x + self.sigma*attn_g
78
+
79
+
80
+
81
+
82
+ '''
83
+
84
+ Generator
85
+
86
+ '''
87
+
88
+
89
class ConditionalBatchNorm2d(nn.Module):
    """Batch norm whose per-channel scale and shift are predicted from a conditioning vector.

    The input is normalised by a non-affine ``nn.BatchNorm2d`` and then
    modulated as ``out + gamma(y) * out + beta(y)``, i.e. ``out * (1 + gamma) + beta``.
    Here ``gamma`` and ``beta`` are bias-free linear maps of ``y``.

    Args:
        num_features: Number of channels ``C`` of the input.
        dim_embed: Size of the conditioning vector ``y``.
    """

    def __init__(self, num_features, dim_embed):
        super().__init__()
        self.num_features = num_features
        self.bn = nn.BatchNorm2d(num_features, momentum=0.001, affine=False)
        self.embed_gamma = nn.Linear(dim_embed, num_features, bias=False)
        self.embed_beta = nn.Linear(dim_embed, num_features, bias=False)

    def forward(self, x, y):
        normalized = self.bn(x)
        # Broadcast the per-sample, per-channel coefficients over H and W.
        per_channel = (-1, self.num_features, 1, 1)
        gamma = self.embed_gamma(y).view(*per_channel)
        beta = self.embed_beta(y).view(*per_channel)
        return normalized + gamma*normalized + beta
103
+
104
+
105
class GenBlock(nn.Module):
    """Residual 2x up-sampling generator block with label-conditioned batch norm.

    Main path: CBN(labels) -> ReLU -> 2x nearest upsample -> 3x3 conv ->
    CBN(labels) -> ReLU -> 3x3 conv. Shortcut: 2x nearest upsample -> 1x1 conv.
    The two paths are summed.

    Args:
        in_channels: Input channel count.
        out_channels: Output channel count.
        dim_embed: Size of the label embedding fed to the conditional batch norms.
    """

    def __init__(self, in_channels, out_channels, dim_embed):
        super().__init__()
        self.cond_bn1 = ConditionalBatchNorm2d(in_channels, dim_embed)
        self.relu = nn.ReLU(inplace=True)
        self.snconv2d1 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
        self.cond_bn2 = ConditionalBatchNorm2d(out_channels, dim_embed)
        self.snconv2d2 = snconv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
        self.snconv2d0 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x, labels):
        # Shortcut: upsample, then a 1x1 projection to out_channels.
        skip = self.snconv2d0(F.interpolate(x, scale_factor=2, mode='nearest'))

        # Residual branch. The in-place ReLU acts on fresh CBN outputs, never on x.
        h = self.relu(self.cond_bn1(x, labels))
        h = F.interpolate(h, scale_factor=2, mode='nearest')
        h = self.relu(self.cond_bn2(self.snconv2d1(h), labels))
        h = self.snconv2d2(h)

        return h + skip
131
+
132
+
133
class CcGAN_SAGAN_Generator(nn.Module):
    """Label-conditioned SAGAN generator mapping noise to ``nc x 256 x 256`` images.

    Six up-sampling ``GenBlock``s take a 4x4 seed to 256x256. They are
    conditioned on the label embedding through conditional batch norm, and a
    self-attention layer sits at the 64x64 resolution.

    Args:
        dim_z: Noise dimension.
        dim_embed: Size of the label embedding passed to ``forward``.
        nc: Number of output image channels.
        gene_ch: Base channel width; the seed has ``16 * gene_ch`` channels.
    """

    def __init__(self, dim_z, dim_embed=128, nc=3, gene_ch=64):
        super().__init__()

        self.dim_z = dim_z
        self.gene_ch = gene_ch

        self.snlinear0 = snlinear(in_features=dim_z, out_features=gene_ch*16*4*4)
        self.block1 = GenBlock(gene_ch*16, gene_ch*16, dim_embed)
        self.block2 = GenBlock(gene_ch*16, gene_ch*8, dim_embed)
        self.block3 = GenBlock(gene_ch*8, gene_ch*4, dim_embed)
        self.block4 = GenBlock(gene_ch*4, gene_ch*2, dim_embed)
        self.self_attn = Self_Attn(gene_ch*2)
        self.block5 = GenBlock(gene_ch*2, gene_ch*2, dim_embed)
        self.block6 = GenBlock(gene_ch*2, gene_ch, dim_embed)
        self.bn = nn.BatchNorm2d(gene_ch, eps=1e-5, momentum=0.0001, affine=True)
        self.relu = nn.ReLU(inplace=True)
        self.snconv2d1 = snconv2d(in_channels=gene_ch, out_channels=nc, kernel_size=3, stride=1, padding=1)
        self.tanh = nn.Tanh()

        # Weight init: apply init_weights to every submodule.
        self.apply(init_weights)

    def forward(self, z, labels):
        """Return images in ``[-1, 1]`` of shape ``(n, nc, 256, 256)`` for noise ``z`` of shape ``(n, dim_z)``."""
        h = self.snlinear0(z).view(-1, self.gene_ch*16, 4, 4)  # 4 x 4
        for block in (self.block1, self.block2, self.block3, self.block4):
            h = block(h, labels)                                 # 8, 16, 32, 64
        h = self.self_attn(h)                                    # 64 x 64
        h = self.block5(h, labels)                               # 128 x 128
        h = self.block6(h, labels)                               # 256 x 256
        h = self.relu(self.bn(h))
        return self.tanh(self.snconv2d1(h))
174
+
175
+
176
+
177
+ '''
178
+
179
+ Discriminator
180
+
181
+ '''
182
+
183
class DiscOptBlock(nn.Module):
    """First residual down-sampling block of the discriminator.

    Main path: 3x3 conv -> ReLU -> 3x3 conv -> 2x2 average pool.
    Shortcut: 2x2 average pool -> 1x1 conv. No activation precedes the first
    convolution. The two paths are summed.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.snconv2d1 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.snconv2d2 = snconv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
        self.downsample = nn.AvgPool2d(2)
        self.snconv2d0 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        residual = self.downsample(self.snconv2d2(self.relu(self.snconv2d1(x))))
        shortcut = self.snconv2d0(self.downsample(x))
        return residual + shortcut
205
+
206
+
207
class DiscBlock(nn.Module):
    """Pre-activation residual block of the discriminator, optionally down-sampling.

    Main path: ReLU -> 3x3 conv -> ReLU -> 3x3 conv [-> 2x2 avg pool].
    Shortcut: 1x1 conv [-> 2x2 avg pool] when down-sampling or when the channel
    counts differ; otherwise the (rectified, see forward) input is passed through.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.relu = nn.ReLU(inplace=True)
        self.snconv2d1 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
        self.snconv2d2 = snconv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
        self.downsample = nn.AvgPool2d(2)
        self.ch_mismatch = in_channels != out_channels
        # Created unconditionally: forward() also needs it when downsample=True,
        # even if the channel counts match.
        self.snconv2d0 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x, downsample=True):
        # NOTE: self.relu is in-place and x is not copied first, so the first
        # ReLU also rectifies the tensor the shortcut reads (and the caller's
        # tensor). The shortcut therefore sees relu(x). Kept as-is so trained
        # checkpoints reproduce their outputs.
        shortcut = x
        h = self.relu(x)
        h = self.snconv2d2(self.relu(self.snconv2d1(h)))

        if downsample or self.ch_mismatch:
            shortcut = self.snconv2d0(shortcut)
        if downsample:
            h = self.downsample(h)
            shortcut = self.downsample(shortcut)

        return h + shortcut
236
+
237
+
238
class CcGAN_SAGAN_Discriminator(nn.Module):
    """Projection discriminator for ``nc x 256 x 256`` images conditioned on a label embedding.

    The score is ``snlinear1(f) + sum(f * sn_embedding1(labels))``, where ``f``
    is the flattened ``disc_ch*16 x 4 x 4`` feature map.

    Args:
        dim_embed: Size of the label embedding passed to ``forward``.
        nc: Number of image channels.
        disc_ch: Base channel width.
    """

    def __init__(self, dim_embed=128, nc=3, disc_ch=64):
        super().__init__()
        self.disc_ch = disc_ch
        self.opt_block1 = DiscOptBlock(nc, disc_ch)
        self.block1 = DiscBlock(disc_ch, disc_ch*2)
        self.self_attn = Self_Attn(disc_ch*2)
        self.block2 = DiscBlock(disc_ch*2, disc_ch*4)
        self.block3 = DiscBlock(disc_ch*4, disc_ch*6)
        self.block4 = DiscBlock(disc_ch*6, disc_ch*12)
        self.block5 = DiscBlock(disc_ch*12, disc_ch*12)
        self.block6 = DiscBlock(disc_ch*12, disc_ch*16)
        self.relu = nn.ReLU(inplace=True)
        feat_dim = disc_ch*16*4*4
        self.snlinear1 = snlinear(in_features=feat_dim, out_features=1)
        # Despite its name this is a bias-free spectral-norm *linear* layer. It
        # maps the label embedding into feature space for the projection term.
        self.sn_embedding1 = snlinear(dim_embed, feat_dim, bias=False)

        # Weight init; sn_embedding1 is drawn a second time right after.
        self.apply(init_weights)
        xavier_uniform_(self.sn_embedding1.weight)

    def forward(self, x, labels):
        """Return per-sample scores of shape ``(n,)`` for images ``x`` and label embeddings ``labels``."""
        out = self.opt_block1(x)                                # 128 x 128
        out = self.self_attn(self.block1(out))                  # 64 x 64
        for block in (self.block2, self.block3, self.block4, self.block5):
            out = block(out)                                    # 32, 16, 8, 4
        out = self.relu(self.block6(out, downsample=False))     # n x disc_ch*16 x 4 x 4
        features = out.view(-1, self.disc_ch*16*4*4)            # n x disc_ch*16*4*4
        score = torch.squeeze(self.snlinear1(features))         # n
        label_proj = self.sn_embedding1(labels)                 # n x disc_ch*16*4*4
        score_proj = torch.sum(torch.mul(features, label_proj), dim=[1])  # n
        return score + score_proj
280
+
281
+
282
if __name__ == "__main__":
    def get_parameter_number(net):
        """Return the total and trainable parameter counts of ``net``."""
        total_num = sum(p.numel() for p in net.parameters())
        trainable_num = sum(p.numel() for p in net.parameters() if p.requires_grad)
        return {'Total': total_num, 'Trainable': trainable_num}

    # Smoke test. Runs on the GPU when present; the previous hard-coded
    # ``.cuda()`` crashed on CPU-only machines.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    netG = CcGAN_SAGAN_Generator(dim_z=256, dim_embed=128, gene_ch=128).to(device)
    netD = CcGAN_SAGAN_Discriminator(dim_embed=128, disc_ch=128).to(device)

    N = 4
    z = torch.randn(N, 256).to(device)
    y = torch.randn(N, 128).to(device)
    x = netG(z, y)
    o = netD(x, y)
    print(x.size())
    print(o.size())

    print('G:', get_parameter_number(netG))
    print('D:', get_parameter_number(netD))
@@ -0,0 +1,142 @@
1
+ '''
2
+ Regular ResNet
3
+
4
+ codes are based on
5
+ @article{
6
+ zhang2018mixup,
7
+ title={mixup: Beyond Empirical Risk Minimization},
8
+ author={Hongyi Zhang, Moustapha Cisse, Yann N. Dauphin, David Lopez-Paz},
9
+ journal={International Conference on Learning Representations},
10
+ year={2018},
11
+ url={https://openreview.net/forum?id=r1Ddp1-Rb},
12
+ }
13
+ '''
14
+
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ import torch.nn.functional as F
19
+
20
+ from torch.autograd import Variable
21
+
22
IMG_SIZE = 256  # spatial size used by the __main__ smoke test (the model's pooling assumes 256)
NC = 3  # default number of input channels (presumably RGB)
24
+
25
+
26
class BasicBlock(nn.Module):
    """Two-layer 3x3 residual block (ResNet-18/34 style).

    Args:
        in_planes: Input channels.
        planes: Output channels (``expansion`` is 1).
        stride: Stride of the first convolution. A 1x1 conv + BN projection
            shortcut is used whenever the stride or channel count changes;
            otherwise the shortcut is the identity (an empty ``nn.Sequential``).
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        if stride == 1 and in_planes == out_planes:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        residual = F.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))
        return F.relu(residual + self.shortcut(x))
49
+
50
+
51
class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 bottleneck residual block (ResNet-50/101/152 style).

    Args:
        in_planes: Input channels.
        planes: Width of the inner 3x3 conv; the block outputs
            ``expansion * planes`` (= ``4 * planes``) channels.
        stride: Stride of the 3x3 conv. A 1x1 conv + BN projection shortcut is
            used whenever the stride or channel count changes; otherwise the
            shortcut is the identity (an empty ``nn.Sequential``).
    """

    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        if stride == 1 and in_planes == out_planes:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        h = F.relu(self.bn1(self.conv1(x)))
        h = F.relu(self.bn2(self.conv2(h)))
        h = self.bn3(self.conv3(h))
        return F.relu(h + self.shortcut(x))
77
+
78
+
79
class ResNet_class_eval(nn.Module):
    """ResNet classifier used for evaluation, sized for ``nc x 256 x 256`` inputs.

    The stem (3x3 conv, BN, ReLU, 2x2 max-pool), the first stride-2 stage and
    an extra 2x2 max-pool reduce 256 to 32. Three more stride-2 stages reach
    4x4, which a 4x4 average pool collapses to 1x1 before the linear classifier.

    Args:
        block: Residual block class (e.g. ``BasicBlock`` or ``Bottleneck``).
        num_blocks: Number of blocks in each of the four stages.
        num_classes: Number of output classes.
        nc: Number of input channels.
        ngpu: When > 1 and the input is on CUDA, ``forward`` runs the
            submodules through ``nn.parallel.data_parallel`` on GPUs ``0..ngpu-1``.
    """

    def __init__(self, block, num_blocks, num_classes=49, nc=NC, ngpu=1):
        super().__init__()
        self.in_planes = 64
        self.ngpu = ngpu

        # Index layout of self.main is relied on by checkpoints (main.0, main.4, ...).
        layers = [
            nn.Conv2d(nc, 64, kernel_size=3, stride=1, padding=1, bias=False),  # h = 256
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),                                                 # h = 128
        ]
        layers.append(self._make_layer(block, 64, num_blocks[0], stride=2))    # h = 64
        layers.append(nn.MaxPool2d(2, 2))                                       # h = 32
        for stage, planes in enumerate((128, 256, 512), start=1):
            layers.append(self._make_layer(block, planes, num_blocks[stage], stride=2))  # h = 16, 8, 4
        layers.append(nn.AvgPool2d(kernel_size=4))                              # h = 1
        self.main = nn.Sequential(*layers)
        self.classifier = nn.Linear(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage; only its first block uses ``stride``. Advances ``self.in_planes``."""
        strides = [stride] + [1]*(num_blocks-1)
        blocks = []
        for s in strides:
            blocks.append(block(self.in_planes, planes, s))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*blocks)

    def forward(self, x):
        """Return ``(logits, features)``, where ``features`` is the flattened pooled output."""
        use_dp = x.is_cuda and self.ngpu > 1

        def run(module, inp):
            # Split the batch across GPUs 0..ngpu-1 only when requested.
            if use_dp:
                return nn.parallel.data_parallel(module, inp, range(self.ngpu))
            return module(inp)

        features = run(self.main, x)
        features = features.view(features.size(0), -1)
        logits = run(self.classifier, features)
        return logits, features
119
+
120
+
121
def ResNet18_class_eval(num_classes=49, ngpu=1):
    """ResNet-18 evaluation classifier: ``BasicBlock`` with stages ``[2, 2, 2, 2]``."""
    stages = [2, 2, 2, 2]
    return ResNet_class_eval(BasicBlock, stages, num_classes=num_classes, ngpu=ngpu)
123
+
124
def ResNet34_class_eval(num_classes=49, ngpu=1):
    """ResNet-34 evaluation classifier: ``BasicBlock`` with stages ``[3, 4, 6, 3]``."""
    stages = [3, 4, 6, 3]
    return ResNet_class_eval(BasicBlock, stages, num_classes=num_classes, ngpu=ngpu)
126
+
127
def ResNet50_class_eval(num_classes=49, ngpu=1):
    """ResNet-50 evaluation classifier: ``Bottleneck`` with stages ``[3, 4, 6, 3]``."""
    stages = [3, 4, 6, 3]
    return ResNet_class_eval(Bottleneck, stages, num_classes=num_classes, ngpu=ngpu)
129
+
130
def ResNet101_class_eval(num_classes=49, ngpu=1):
    """ResNet-101 evaluation classifier: ``Bottleneck`` with stages ``[3, 4, 23, 3]``."""
    stages = [3, 4, 23, 3]
    return ResNet_class_eval(Bottleneck, stages, num_classes=num_classes, ngpu=ngpu)
132
+
133
def ResNet152_class_eval(num_classes=49, ngpu=1):
    """ResNet-152 evaluation classifier: ``Bottleneck`` with stages ``[3, 8, 36, 3]``."""
    stages = [3, 8, 36, 3]
    return ResNet_class_eval(Bottleneck, stages, num_classes=num_classes, ngpu=ngpu)
135
+
136
+
137
if __name__ == "__main__":
    # Smoke test: run a random batch through ResNet-50 and print output shapes.
    # Uses the GPU when present instead of hard-requiring CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = ResNet50_class_eval(num_classes=5, ngpu=1).to(device)
    x = torch.randn(16, NC, IMG_SIZE, IMG_SIZE).to(device)
    out, features = net(x)
    print(out.size())
    print(features.size())