Myosotis-Researches 0.1.7__py3-none-any.whl → 0.1.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- myosotis_researches/CcGAN/train/__init__.py +4 -0
- myosotis_researches/CcGAN/{train_128_output_10 → train}/train_ccgan.py +4 -4
- myosotis_researches/CcGAN/{train_128 → train}/train_cgan.py +1 -3
- myosotis_researches/CcGAN/{train_128 → train}/train_cgan_concat.py +1 -3
- myosotis_researches/CcGAN/utils/__init__.py +2 -1
- myosotis_researches/CcGAN/utils/train.py +94 -3
- {myosotis_researches-0.1.7.dist-info → myosotis_researches-0.1.9.dist-info}/METADATA +1 -1
- myosotis_researches-0.1.9.dist-info/RECORD +24 -0
- myosotis_researches/CcGAN/models_128/CcGAN_SAGAN.py +0 -301
- myosotis_researches/CcGAN/models_128/ResNet_class_eval.py +0 -141
- myosotis_researches/CcGAN/models_128/ResNet_embed.py +0 -188
- myosotis_researches/CcGAN/models_128/ResNet_regre_eval.py +0 -175
- myosotis_researches/CcGAN/models_128/__init__.py +0 -7
- myosotis_researches/CcGAN/models_128/autoencoder.py +0 -119
- myosotis_researches/CcGAN/models_128/cGAN_SAGAN.py +0 -276
- myosotis_researches/CcGAN/models_128/cGAN_concat_SAGAN.py +0 -245
- myosotis_researches/CcGAN/models_256/CcGAN_SAGAN.py +0 -303
- myosotis_researches/CcGAN/models_256/ResNet_class_eval.py +0 -142
- myosotis_researches/CcGAN/models_256/ResNet_embed.py +0 -188
- myosotis_researches/CcGAN/models_256/ResNet_regre_eval.py +0 -178
- myosotis_researches/CcGAN/models_256/__init__.py +0 -7
- myosotis_researches/CcGAN/models_256/autoencoder.py +0 -133
- myosotis_researches/CcGAN/models_256/cGAN_SAGAN.py +0 -280
- myosotis_researches/CcGAN/models_256/cGAN_concat_SAGAN.py +0 -249
- myosotis_researches/CcGAN/train_128/DiffAugment_pytorch.py +0 -76
- myosotis_researches/CcGAN/train_128/__init__.py +0 -0
- myosotis_researches/CcGAN/train_128/eval_metrics.py +0 -205
- myosotis_researches/CcGAN/train_128/opts.py +0 -87
- myosotis_researches/CcGAN/train_128/pretrain_AE.py +0 -268
- myosotis_researches/CcGAN/train_128/pretrain_CNN_class.py +0 -251
- myosotis_researches/CcGAN/train_128/pretrain_CNN_regre.py +0 -255
- myosotis_researches/CcGAN/train_128/train_ccgan.py +0 -303
- myosotis_researches/CcGAN/train_128/utils.py +0 -120
- myosotis_researches/CcGAN/train_128_output_10/DiffAugment_pytorch.py +0 -76
- myosotis_researches/CcGAN/train_128_output_10/__init__.py +0 -0
- myosotis_researches/CcGAN/train_128_output_10/eval_metrics.py +0 -205
- myosotis_researches/CcGAN/train_128_output_10/opts.py +0 -87
- myosotis_researches/CcGAN/train_128_output_10/pretrain_AE.py +0 -268
- myosotis_researches/CcGAN/train_128_output_10/pretrain_CNN_class.py +0 -251
- myosotis_researches/CcGAN/train_128_output_10/pretrain_CNN_regre.py +0 -255
- myosotis_researches/CcGAN/train_128_output_10/train_cgan.py +0 -254
- myosotis_researches/CcGAN/train_128_output_10/train_cgan_concat.py +0 -242
- myosotis_researches/CcGAN/train_128_output_10/train_net_for_label_embed.py +0 -181
- myosotis_researches/CcGAN/train_128_output_10/utils.py +0 -120
- myosotis_researches-0.1.7.dist-info/RECORD +0 -59
- /myosotis_researches/CcGAN/{train_128 → train}/train_net_for_label_embed.py +0 -0
- {myosotis_researches-0.1.7.dist-info → myosotis_researches-0.1.9.dist-info}/WHEEL +0 -0
- {myosotis_researches-0.1.7.dist-info → myosotis_researches-0.1.9.dist-info}/licenses/LICENSE +0 -0
- {myosotis_researches-0.1.7.dist-info → myosotis_researches-0.1.9.dist-info}/top_level.txt +0 -0
@@ -5,10 +5,9 @@ import timeit
|
|
5
5
|
from PIL import Image
|
6
6
|
from torchvision.utils import save_image
|
7
7
|
import torch.cuda as cutorch
|
8
|
+
import sys
|
8
9
|
|
9
|
-
from .utils import
|
10
|
-
from .opts import parse_opts
|
11
|
-
from .DiffAugment_pytorch import DiffAugment
|
10
|
+
from myosotis_researches.CcGAN.utils import *
|
12
11
|
|
13
12
|
''' Settings '''
|
14
13
|
args = parse_opts()
|
@@ -79,7 +78,8 @@ def train_ccgan(kernel_sigma, kappa, train_images, train_labels, netG, netD, net
|
|
79
78
|
# printed images with labels between the 5-th quantile and 95-th quantile of training labels
|
80
79
|
n_row=10; n_col = 1
|
81
80
|
z_fixed = torch.randn(n_row*n_col, dim_gan, dtype=torch.float).cuda()
|
82
|
-
|
81
|
+
|
82
|
+
selected_labels = np.array([0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9])
|
83
83
|
|
84
84
|
y_fixed = np.zeros(n_row*n_col)
|
85
85
|
for i in range(n_row):
|
@@ -6,9 +6,7 @@ import numpy as np
|
|
6
6
|
import os
|
7
7
|
import timeit
|
8
8
|
|
9
|
-
from .utils import
|
10
|
-
from .opts import parse_opts
|
11
|
-
from .DiffAugment_pytorch import DiffAugment
|
9
|
+
from myosotis_researches.CcGAN.utils import *
|
12
10
|
|
13
11
|
''' Settings '''
|
14
12
|
args = parse_opts()
|
@@ -6,9 +6,7 @@ import numpy as np
|
|
6
6
|
import os
|
7
7
|
import timeit
|
8
8
|
|
9
|
-
from .utils import
|
10
|
-
from .opts import parse_opts
|
11
|
-
from .DiffAugment_pytorch import DiffAugment
|
9
|
+
from myosotis_researches.CcGAN.utils import *
|
12
10
|
|
13
11
|
''' Settings '''
|
14
12
|
args = parse_opts()
|
@@ -3,7 +3,7 @@ from .concat_image import concat_image
|
|
3
3
|
from .make_h5 import make_h5
|
4
4
|
from .SimpleProgressBar import SimpleProgressBar
|
5
5
|
from .IMGs_dataset import IMGs_dataset
|
6
|
-
from .train import PlotLoss, compute_entropy, predict_class_labels
|
6
|
+
from .train import PlotLoss, compute_entropy, predict_class_labels, DiffAugment
|
7
7
|
from .opts import parse_opts
|
8
8
|
|
9
9
|
__all__ = [
|
@@ -15,5 +15,6 @@ __all__ = [
|
|
15
15
|
"PlotLoss",
|
16
16
|
"compute_entropy",
|
17
17
|
"predict_class_labels",
|
18
|
+
"DiffAugment",
|
18
19
|
"parse_opts"
|
19
20
|
]
|
@@ -1,7 +1,8 @@
|
|
1
|
+
import matplotlib as mpl
|
2
|
+
import matplotlib.pyplot as plt
|
1
3
|
import numpy as np
|
2
4
|
import torch
|
3
|
-
import
|
4
|
-
import matplotlib as mpl
|
5
|
+
import torch.nn.functional as F
|
5
6
|
|
6
7
|
|
7
8
|
def PlotLoss(loss, filename):
|
@@ -62,4 +63,94 @@ def predict_class_labels(net, images, batch_size=500, verbose=False, num_workers
|
|
62
63
|
return class_labels_pred
|
63
64
|
|
64
65
|
|
65
|
-
|
66
|
+
def DiffAugment(x, policy="", channels_first=True):
|
67
|
+
if policy:
|
68
|
+
if not channels_first:
|
69
|
+
x = x.permute(0, 3, 1, 2)
|
70
|
+
for p in policy.split(","):
|
71
|
+
for f in AUGMENT_FNS[p]:
|
72
|
+
x = f(x)
|
73
|
+
if not channels_first:
|
74
|
+
x = x.permute(0, 2, 3, 1)
|
75
|
+
x = x.contiguous()
|
76
|
+
return x
|
77
|
+
|
78
|
+
|
79
|
+
def rand_brightness(x):
|
80
|
+
x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5)
|
81
|
+
return x
|
82
|
+
|
83
|
+
|
84
|
+
def rand_saturation(x):
|
85
|
+
x_mean = x.mean(dim=1, keepdim=True)
|
86
|
+
x = (x - x_mean) * (
|
87
|
+
torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2
|
88
|
+
) + x_mean
|
89
|
+
return x
|
90
|
+
|
91
|
+
|
92
|
+
def rand_contrast(x):
|
93
|
+
x_mean = x.mean(dim=[1, 2, 3], keepdim=True)
|
94
|
+
x = (x - x_mean) * (
|
95
|
+
torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5
|
96
|
+
) + x_mean
|
97
|
+
return x
|
98
|
+
|
99
|
+
|
100
|
+
def rand_translation(x, ratio=0.125):
|
101
|
+
shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
|
102
|
+
translation_x = torch.randint(
|
103
|
+
-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device
|
104
|
+
)
|
105
|
+
translation_y = torch.randint(
|
106
|
+
-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device
|
107
|
+
)
|
108
|
+
grid_batch, grid_x, grid_y = torch.meshgrid(
|
109
|
+
torch.arange(x.size(0), dtype=torch.long, device=x.device),
|
110
|
+
torch.arange(x.size(2), dtype=torch.long, device=x.device),
|
111
|
+
torch.arange(x.size(3), dtype=torch.long, device=x.device),
|
112
|
+
)
|
113
|
+
grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
|
114
|
+
grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
|
115
|
+
x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
|
116
|
+
x = (
|
117
|
+
x_pad.permute(0, 2, 3, 1)
|
118
|
+
.contiguous()[grid_batch, grid_x, grid_y]
|
119
|
+
.permute(0, 3, 1, 2)
|
120
|
+
)
|
121
|
+
return x
|
122
|
+
|
123
|
+
|
124
|
+
def rand_cutout(x, ratio=0.5):
|
125
|
+
cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
|
126
|
+
offset_x = torch.randint(
|
127
|
+
0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device
|
128
|
+
)
|
129
|
+
offset_y = torch.randint(
|
130
|
+
0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device
|
131
|
+
)
|
132
|
+
grid_batch, grid_x, grid_y = torch.meshgrid(
|
133
|
+
torch.arange(x.size(0), dtype=torch.long, device=x.device),
|
134
|
+
torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
|
135
|
+
torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
|
136
|
+
)
|
137
|
+
grid_x = torch.clamp(
|
138
|
+
grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1
|
139
|
+
)
|
140
|
+
grid_y = torch.clamp(
|
141
|
+
grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1
|
142
|
+
)
|
143
|
+
mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
|
144
|
+
mask[grid_batch, grid_x, grid_y] = 0
|
145
|
+
x = x * mask.unsqueeze(1)
|
146
|
+
return x
|
147
|
+
|
148
|
+
|
149
|
+
AUGMENT_FNS = {
|
150
|
+
"color": [rand_brightness, rand_saturation, rand_contrast],
|
151
|
+
"translation": [rand_translation],
|
152
|
+
"cutout": [rand_cutout],
|
153
|
+
}
|
154
|
+
|
155
|
+
|
156
|
+
__all__ = ["PlotLoss", "compute_entropy", "predict_class_labels", "DiffAugment"]
|
@@ -0,0 +1,24 @@
|
|
1
|
+
myosotis_researches/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
|
+
myosotis_researches/CcGAN/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
|
+
myosotis_researches/CcGAN/internal/__init__.py,sha256=b-63yANNRQXgLF9k9yGdrm7mlULqGic1HTQTzg9wIME,209
|
4
|
+
myosotis_researches/CcGAN/internal/install_datasets.py,sha256=jJwLOZrDnHMrJSUhXxSIFobdeWK5N6eitPmjeBW9FyA,1144
|
5
|
+
myosotis_researches/CcGAN/internal/show_datasets.py,sha256=BWtQ6vdiEUOTrOs8aMBv6utuUN0IiaLKcK5iXq9y2qI,363
|
6
|
+
myosotis_researches/CcGAN/internal/uninstall_datasets.py,sha256=7pxPZcSe9RHncF0I_4rf8ZdI7eQwv-sFVfxzSVZfYHQ,297
|
7
|
+
myosotis_researches/CcGAN/train/__init__.py,sha256=-55Ccov89II6Yuaiszi8ziw9EoVQr7OJR0bQfPAE_10,127
|
8
|
+
myosotis_researches/CcGAN/train/train_ccgan.py,sha256=0Qxibgd2-WaYgbyYeeOyiMkdcwkd_M1m1gSqoHTjN0w,13268
|
9
|
+
myosotis_researches/CcGAN/train/train_cgan.py,sha256=sxMzvlmdjmqufwJFxBwatcoJecYqn2Uidedu15CL9ws,9619
|
10
|
+
myosotis_researches/CcGAN/train/train_cgan_concat.py,sha256=OrQbwdU_ujUeKFGixUUpnini6rURtbuHv9NDrP6g0X0,8861
|
11
|
+
myosotis_researches/CcGAN/train/train_net_for_label_embed.py,sha256=4j6r4_o4rXgAN4MdUQL-TXqZJpbhH7d9gWQR8YzBlXw,6976
|
12
|
+
myosotis_researches/CcGAN/utils/IMGs_dataset.py,sha256=i45PBNSCeAEB5uUG0SluYRTuHWZwH_5ldz2wm6afkYs,927
|
13
|
+
myosotis_researches/CcGAN/utils/SimpleProgressBar.py,sha256=S4eD_m6ysHRMHAmRtkTXVRNfXTR8kuHv-d3lUN0BVn4,546
|
14
|
+
myosotis_researches/CcGAN/utils/__init__.py,sha256=em3aB0C-V230NQtT64hyuHGo4CjV6p2DwIdtNM0dk4k,516
|
15
|
+
myosotis_researches/CcGAN/utils/concat_image.py,sha256=BIGKz52Inn9S7M5fBFKye2V9bLJ0DqEQILoOVWAXUiE,2165
|
16
|
+
myosotis_researches/CcGAN/utils/make_h5.py,sha256=VtFYjr_i-JktsEW_BvofpilcDmChRmyLykv0VvlMuY0,963
|
17
|
+
myosotis_researches/CcGAN/utils/opts.py,sha256=pd7-wknNPBO5hWRpO3YAPmmAsPKgZUUpKc4gWMs6Wto,5397
|
18
|
+
myosotis_researches/CcGAN/utils/print_hdf5.py,sha256=VvmNAWtMDmg6D9V6ZbSUXrQTKRh9WIJeC4BR_ORJkco,300
|
19
|
+
myosotis_researches/CcGAN/utils/train.py,sha256=5ZXgkGesuInqUooJRpLej_KHqYQtlSDq90_5wig5elQ,5152
|
20
|
+
myosotis_researches-0.1.9.dist-info/licenses/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
|
21
|
+
myosotis_researches-0.1.9.dist-info/METADATA,sha256=F0XMimBS26-MprX3UHMvW1KtXOuMF4FZQlTw9L3L0mc,2663
|
22
|
+
myosotis_researches-0.1.9.dist-info/WHEEL,sha256=ck4Vq1_RXyvS4Jt6SI0Vz6fyVs4GWg7AINwpsaGEgPE,91
|
23
|
+
myosotis_researches-0.1.9.dist-info/top_level.txt,sha256=zxAiMn5eyZNJM28MewTAkgi_RZJMbfWbzVR-KF0LdZE,20
|
24
|
+
myosotis_researches-0.1.9.dist-info/RECORD,,
|
@@ -1,301 +0,0 @@
|
|
1
|
-
'''
|
2
|
-
|
3
|
-
Adapted from https://github.com/voletiv/self-attention-GAN-pytorch/blob/master/sagan_models.py
|
4
|
-
|
5
|
-
|
6
|
-
'''
|
7
|
-
|
8
|
-
|
9
|
-
import numpy as np
|
10
|
-
import torch
|
11
|
-
import torch.nn as nn
|
12
|
-
import torch.nn.functional as F
|
13
|
-
|
14
|
-
from torch.nn.utils import spectral_norm
|
15
|
-
from torch.nn.init import xavier_uniform_
|
16
|
-
|
17
|
-
|
18
|
-
def init_weights(m):
|
19
|
-
if type(m) == nn.Linear or type(m) == nn.Conv2d:
|
20
|
-
xavier_uniform_(m.weight)
|
21
|
-
if m.bias is not None:
|
22
|
-
m.bias.data.fill_(0.)
|
23
|
-
|
24
|
-
|
25
|
-
def snconv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
|
26
|
-
return spectral_norm(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
|
27
|
-
stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias))
|
28
|
-
|
29
|
-
def snlinear(in_features, out_features, bias=True):
|
30
|
-
return spectral_norm(nn.Linear(in_features=in_features, out_features=out_features, bias=bias))
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
class Self_Attn(nn.Module):
|
35
|
-
""" Self attention Layer"""
|
36
|
-
|
37
|
-
def __init__(self, in_channels):
|
38
|
-
super(Self_Attn, self).__init__()
|
39
|
-
self.in_channels = in_channels
|
40
|
-
self.snconv1x1_theta = snconv2d(in_channels=in_channels, out_channels=in_channels//8, kernel_size=1, stride=1, padding=0)
|
41
|
-
self.snconv1x1_phi = snconv2d(in_channels=in_channels, out_channels=in_channels//8, kernel_size=1, stride=1, padding=0)
|
42
|
-
self.snconv1x1_g = snconv2d(in_channels=in_channels, out_channels=in_channels//2, kernel_size=1, stride=1, padding=0)
|
43
|
-
self.snconv1x1_attn = snconv2d(in_channels=in_channels//2, out_channels=in_channels, kernel_size=1, stride=1, padding=0)
|
44
|
-
self.maxpool = nn.MaxPool2d(2, stride=2, padding=0)
|
45
|
-
self.softmax = nn.Softmax(dim=-1)
|
46
|
-
self.sigma = nn.Parameter(torch.zeros(1))
|
47
|
-
|
48
|
-
def forward(self, x):
|
49
|
-
"""
|
50
|
-
inputs :
|
51
|
-
x : input feature maps(B X C X W X H)
|
52
|
-
returns :
|
53
|
-
out : self attention value + input feature
|
54
|
-
attention: B X N X N (N is Width*Height)
|
55
|
-
"""
|
56
|
-
_, ch, h, w = x.size()
|
57
|
-
# Theta path
|
58
|
-
theta = self.snconv1x1_theta(x)
|
59
|
-
theta = theta.view(-1, ch//8, h*w)
|
60
|
-
# Phi path
|
61
|
-
phi = self.snconv1x1_phi(x)
|
62
|
-
phi = self.maxpool(phi)
|
63
|
-
phi = phi.view(-1, ch//8, h*w//4)
|
64
|
-
# Attn map
|
65
|
-
attn = torch.bmm(theta.permute(0, 2, 1), phi)
|
66
|
-
attn = self.softmax(attn)
|
67
|
-
# g path
|
68
|
-
g = self.snconv1x1_g(x)
|
69
|
-
g = self.maxpool(g)
|
70
|
-
g = g.view(-1, ch//2, h*w//4)
|
71
|
-
# Attn_g
|
72
|
-
attn_g = torch.bmm(g, attn.permute(0, 2, 1))
|
73
|
-
attn_g = attn_g.view(-1, ch//2, h, w)
|
74
|
-
attn_g = self.snconv1x1_attn(attn_g)
|
75
|
-
# Out
|
76
|
-
out = x + self.sigma*attn_g
|
77
|
-
return out
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
'''
|
83
|
-
|
84
|
-
Generator
|
85
|
-
|
86
|
-
'''
|
87
|
-
|
88
|
-
|
89
|
-
class ConditionalBatchNorm2d(nn.Module):
|
90
|
-
def __init__(self, num_features, dim_embed):
|
91
|
-
super().__init__()
|
92
|
-
self.num_features = num_features
|
93
|
-
self.bn = nn.BatchNorm2d(num_features, momentum=0.001, affine=False)
|
94
|
-
self.embed_gamma = nn.Linear(dim_embed, num_features, bias=False)
|
95
|
-
self.embed_beta = nn.Linear(dim_embed, num_features, bias=False)
|
96
|
-
|
97
|
-
def forward(self, x, y):
|
98
|
-
out = self.bn(x)
|
99
|
-
gamma = self.embed_gamma(y).view(-1, self.num_features, 1, 1)
|
100
|
-
beta = self.embed_beta(y).view(-1, self.num_features, 1, 1)
|
101
|
-
out = out + gamma*out + beta
|
102
|
-
return out
|
103
|
-
|
104
|
-
|
105
|
-
class GenBlock(nn.Module):
|
106
|
-
def __init__(self, in_channels, out_channels, dim_embed):
|
107
|
-
super(GenBlock, self).__init__()
|
108
|
-
self.cond_bn1 = ConditionalBatchNorm2d(in_channels, dim_embed)
|
109
|
-
self.relu = nn.ReLU(inplace=True)
|
110
|
-
self.snconv2d1 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
|
111
|
-
self.cond_bn2 = ConditionalBatchNorm2d(out_channels, dim_embed)
|
112
|
-
self.snconv2d2 = snconv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
|
113
|
-
self.snconv2d0 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0)
|
114
|
-
|
115
|
-
def forward(self, x, labels):
|
116
|
-
x0 = x
|
117
|
-
|
118
|
-
x = self.cond_bn1(x, labels)
|
119
|
-
x = self.relu(x)
|
120
|
-
x = F.interpolate(x, scale_factor=2, mode='nearest') # upsample
|
121
|
-
x = self.snconv2d1(x)
|
122
|
-
x = self.cond_bn2(x, labels)
|
123
|
-
x = self.relu(x)
|
124
|
-
x = self.snconv2d2(x)
|
125
|
-
|
126
|
-
x0 = F.interpolate(x0, scale_factor=2, mode='nearest') # upsample
|
127
|
-
x0 = self.snconv2d0(x0)
|
128
|
-
|
129
|
-
out = x + x0
|
130
|
-
return out
|
131
|
-
|
132
|
-
|
133
|
-
class CcGAN_SAGAN_Generator(nn.Module):
|
134
|
-
"""Generator."""
|
135
|
-
|
136
|
-
def __init__(self, dim_z, dim_embed=128, nc=3, gene_ch=64):
|
137
|
-
super(CcGAN_SAGAN_Generator, self).__init__()
|
138
|
-
|
139
|
-
self.dim_z = dim_z
|
140
|
-
self.gene_ch = gene_ch
|
141
|
-
|
142
|
-
self.snlinear0 = snlinear(in_features=dim_z, out_features=gene_ch*16*4*4)
|
143
|
-
self.block1 = GenBlock(gene_ch*16, gene_ch*16, dim_embed)
|
144
|
-
self.block2 = GenBlock(gene_ch*16, gene_ch*8, dim_embed)
|
145
|
-
self.block3 = GenBlock(gene_ch*8, gene_ch*4, dim_embed)
|
146
|
-
self.self_attn = Self_Attn(gene_ch*4)
|
147
|
-
self.block4 = GenBlock(gene_ch*4, gene_ch*2, dim_embed)
|
148
|
-
self.block5 = GenBlock(gene_ch*2, gene_ch, dim_embed)
|
149
|
-
self.bn = nn.BatchNorm2d(gene_ch, eps=1e-5, momentum=0.0001, affine=True)
|
150
|
-
self.relu = nn.ReLU(inplace=True)
|
151
|
-
self.snconv2d1 = snconv2d(in_channels=gene_ch, out_channels=nc, kernel_size=3, stride=1, padding=1)
|
152
|
-
self.tanh = nn.Tanh()
|
153
|
-
|
154
|
-
# Weight init
|
155
|
-
self.apply(init_weights)
|
156
|
-
|
157
|
-
def forward(self, z, labels):
|
158
|
-
# n x dim_z
|
159
|
-
out = self.snlinear0(z) # 4*4
|
160
|
-
out = out.view(-1, self.gene_ch*16, 4, 4) # 4 x 4
|
161
|
-
out = self.block1(out, labels) # 8 x 8
|
162
|
-
out = self.block2(out, labels) # 16 x 16
|
163
|
-
out = self.block3(out, labels) # 32 x 32
|
164
|
-
out = self.self_attn(out) # 32 x 32
|
165
|
-
out = self.block4(out, labels) # 64 x 64
|
166
|
-
out = self.block5(out, labels) # 128 x 128
|
167
|
-
out = self.bn(out)
|
168
|
-
out = self.relu(out)
|
169
|
-
out = self.snconv2d1(out)
|
170
|
-
out = self.tanh(out)
|
171
|
-
return out
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
'''
|
176
|
-
|
177
|
-
Discriminator
|
178
|
-
|
179
|
-
'''
|
180
|
-
|
181
|
-
class DiscOptBlock(nn.Module):
|
182
|
-
def __init__(self, in_channels, out_channels):
|
183
|
-
super(DiscOptBlock, self).__init__()
|
184
|
-
self.snconv2d1 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
|
185
|
-
self.relu = nn.ReLU(inplace=True)
|
186
|
-
self.snconv2d2 = snconv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
|
187
|
-
self.downsample = nn.AvgPool2d(2)
|
188
|
-
self.snconv2d0 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0)
|
189
|
-
|
190
|
-
def forward(self, x):
|
191
|
-
x0 = x
|
192
|
-
|
193
|
-
x = self.snconv2d1(x)
|
194
|
-
x = self.relu(x)
|
195
|
-
x = self.snconv2d2(x)
|
196
|
-
x = self.downsample(x)
|
197
|
-
|
198
|
-
x0 = self.downsample(x0)
|
199
|
-
x0 = self.snconv2d0(x0)
|
200
|
-
|
201
|
-
out = x + x0
|
202
|
-
return out
|
203
|
-
|
204
|
-
|
205
|
-
class DiscBlock(nn.Module):
|
206
|
-
def __init__(self, in_channels, out_channels):
|
207
|
-
super(DiscBlock, self).__init__()
|
208
|
-
self.relu = nn.ReLU(inplace=True)
|
209
|
-
self.snconv2d1 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
|
210
|
-
self.snconv2d2 = snconv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
|
211
|
-
self.downsample = nn.AvgPool2d(2)
|
212
|
-
self.ch_mismatch = False
|
213
|
-
if in_channels != out_channels:
|
214
|
-
self.ch_mismatch = True
|
215
|
-
self.snconv2d0 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0)
|
216
|
-
|
217
|
-
def forward(self, x, downsample=True):
|
218
|
-
x0 = x
|
219
|
-
|
220
|
-
x = self.relu(x)
|
221
|
-
x = self.snconv2d1(x)
|
222
|
-
x = self.relu(x)
|
223
|
-
x = self.snconv2d2(x)
|
224
|
-
if downsample:
|
225
|
-
x = self.downsample(x)
|
226
|
-
|
227
|
-
if downsample or self.ch_mismatch:
|
228
|
-
x0 = self.snconv2d0(x0)
|
229
|
-
if downsample:
|
230
|
-
x0 = self.downsample(x0)
|
231
|
-
|
232
|
-
out = x + x0
|
233
|
-
return out
|
234
|
-
|
235
|
-
|
236
|
-
class CcGAN_SAGAN_Discriminator(nn.Module):
|
237
|
-
"""Discriminator."""
|
238
|
-
|
239
|
-
def __init__(self, dim_embed=128, nc=3, disc_ch=64):
|
240
|
-
super(CcGAN_SAGAN_Discriminator, self).__init__()
|
241
|
-
self.disc_ch = disc_ch
|
242
|
-
self.opt_block1 = DiscOptBlock(nc, disc_ch)
|
243
|
-
self.block1 = DiscBlock(disc_ch, disc_ch*2)
|
244
|
-
self.self_attn = Self_Attn(disc_ch*2)
|
245
|
-
self.block2 = DiscBlock(disc_ch*2, disc_ch*4)
|
246
|
-
self.block3 = DiscBlock(disc_ch*4, disc_ch*8)
|
247
|
-
self.block4 = DiscBlock(disc_ch*8, disc_ch*16)
|
248
|
-
self.block5 = DiscBlock(disc_ch*16, disc_ch*16)
|
249
|
-
self.relu = nn.ReLU(inplace=True)
|
250
|
-
self.snlinear1 = snlinear(in_features=disc_ch*16*4*4, out_features=1)
|
251
|
-
self.sn_embedding1 = snlinear(dim_embed, disc_ch*16*4*4, bias=False)
|
252
|
-
|
253
|
-
# Weight init
|
254
|
-
self.apply(init_weights)
|
255
|
-
xavier_uniform_(self.sn_embedding1.weight)
|
256
|
-
|
257
|
-
def forward(self, x, labels):
|
258
|
-
# 128x128
|
259
|
-
out = self.opt_block1(x) # 128x128
|
260
|
-
out = self.block1(out) # 64 x 64
|
261
|
-
out = self.self_attn(out) # 64 x 64
|
262
|
-
out = self.block2(out) # 32 x 32
|
263
|
-
out = self.block3(out) # 16 x 16
|
264
|
-
out = self.block4(out) # 8 x 8
|
265
|
-
out = self.block5(out, downsample=False) # 4 x 4
|
266
|
-
out = self.relu(out) # n x disc_ch*16 x 4 x 4
|
267
|
-
out = out.view(-1, self.disc_ch*16*4*4)
|
268
|
-
output1 = torch.squeeze(self.snlinear1(out)) # n
|
269
|
-
# Projection
|
270
|
-
h_labels = self.sn_embedding1(labels) # n x disc_ch*16
|
271
|
-
proj = torch.mul(out, h_labels) # n x disc_ch*16
|
272
|
-
output2 = torch.sum(proj, dim=[1]) # n
|
273
|
-
# Out
|
274
|
-
output = output1 + output2 # n
|
275
|
-
return output
|
276
|
-
|
277
|
-
|
278
|
-
if __name__ == "__main__":
|
279
|
-
|
280
|
-
def get_parameter_number(net):
|
281
|
-
total_num = sum(p.numel() for p in net.parameters())
|
282
|
-
trainable_num = sum(p.numel() for p in net.parameters() if p.requires_grad)
|
283
|
-
return {'Total': total_num, 'Trainable': trainable_num}
|
284
|
-
|
285
|
-
|
286
|
-
netG = CcGAN_SAGAN_Generator(dim_z=256, dim_embed=128, gene_ch=128).cuda()
|
287
|
-
netD = CcGAN_SAGAN_Discriminator(dim_embed=128, disc_ch=128).cuda()
|
288
|
-
|
289
|
-
# netG = nn.DataParallel(netG)
|
290
|
-
# netD = nn.DataParallel(netD)
|
291
|
-
|
292
|
-
N=4
|
293
|
-
z = torch.randn(N, 256).cuda()
|
294
|
-
y = torch.randn(N, 128).cuda()
|
295
|
-
x = netG(z,y)
|
296
|
-
o = netD(x,y)
|
297
|
-
print(x.size())
|
298
|
-
print(o.size())
|
299
|
-
|
300
|
-
print('G:', get_parameter_number(netG))
|
301
|
-
print('D:', get_parameter_number(netD))
|
@@ -1,141 +0,0 @@
|
|
1
|
-
'''
|
2
|
-
Regular ResNet
|
3
|
-
|
4
|
-
codes are based on
|
5
|
-
@article{
|
6
|
-
zhang2018mixup,
|
7
|
-
title={mixup: Beyond Empirical Risk Minimization},
|
8
|
-
author={Hongyi Zhang, Moustapha Cisse, Yann N. Dauphin, David Lopez-Paz},
|
9
|
-
journal={International Conference on Learning Representations},
|
10
|
-
year={2018},
|
11
|
-
url={https://openreview.net/forum?id=r1Ddp1-Rb},
|
12
|
-
}
|
13
|
-
'''
|
14
|
-
|
15
|
-
|
16
|
-
import torch
|
17
|
-
import torch.nn as nn
|
18
|
-
import torch.nn.functional as F
|
19
|
-
|
20
|
-
from torch.autograd import Variable
|
21
|
-
|
22
|
-
IMG_SIZE=128
|
23
|
-
NC=3
|
24
|
-
|
25
|
-
|
26
|
-
class BasicBlock(nn.Module):
|
27
|
-
expansion = 1
|
28
|
-
|
29
|
-
def __init__(self, in_planes, planes, stride=1):
|
30
|
-
super(BasicBlock, self).__init__()
|
31
|
-
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
|
32
|
-
self.bn1 = nn.BatchNorm2d(planes)
|
33
|
-
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
|
34
|
-
self.bn2 = nn.BatchNorm2d(planes)
|
35
|
-
|
36
|
-
self.shortcut = nn.Sequential()
|
37
|
-
if stride != 1 or in_planes != self.expansion*planes:
|
38
|
-
self.shortcut = nn.Sequential(
|
39
|
-
nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
|
40
|
-
nn.BatchNorm2d(self.expansion*planes)
|
41
|
-
)
|
42
|
-
|
43
|
-
def forward(self, x):
|
44
|
-
out = F.relu(self.bn1(self.conv1(x)))
|
45
|
-
out = self.bn2(self.conv2(out))
|
46
|
-
out += self.shortcut(x)
|
47
|
-
out = F.relu(out)
|
48
|
-
return out
|
49
|
-
|
50
|
-
|
51
|
-
class Bottleneck(nn.Module):
|
52
|
-
expansion = 4
|
53
|
-
|
54
|
-
def __init__(self, in_planes, planes, stride=1):
|
55
|
-
super(Bottleneck, self).__init__()
|
56
|
-
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
|
57
|
-
self.bn1 = nn.BatchNorm2d(planes)
|
58
|
-
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
|
59
|
-
self.bn2 = nn.BatchNorm2d(planes)
|
60
|
-
self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
|
61
|
-
self.bn3 = nn.BatchNorm2d(self.expansion*planes)
|
62
|
-
|
63
|
-
self.shortcut = nn.Sequential()
|
64
|
-
if stride != 1 or in_planes != self.expansion*planes:
|
65
|
-
self.shortcut = nn.Sequential(
|
66
|
-
nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
|
67
|
-
nn.BatchNorm2d(self.expansion*planes)
|
68
|
-
)
|
69
|
-
|
70
|
-
def forward(self, x):
|
71
|
-
out = F.relu(self.bn1(self.conv1(x)))
|
72
|
-
out = F.relu(self.bn2(self.conv2(out)))
|
73
|
-
out = self.bn3(self.conv3(out))
|
74
|
-
out += self.shortcut(x)
|
75
|
-
out = F.relu(out)
|
76
|
-
return out
|
77
|
-
|
78
|
-
|
79
|
-
class ResNet_class_eval(nn.Module):
|
80
|
-
def __init__(self, block, num_blocks, num_classes=49, nc=NC, ngpu = 1):
|
81
|
-
super(ResNet_class_eval, self).__init__()
|
82
|
-
self.in_planes = 64
|
83
|
-
self.ngpu = ngpu
|
84
|
-
|
85
|
-
self.main = nn.Sequential(
|
86
|
-
nn.Conv2d(nc, 64, kernel_size=3, stride=1, padding=1, bias=False), # h=h
|
87
|
-
nn.BatchNorm2d(64),
|
88
|
-
nn.ReLU(),
|
89
|
-
nn.MaxPool2d(2,2), #h=h/2 64
|
90
|
-
# self._make_layer(block, 64, num_blocks[0], stride=1), # h=h
|
91
|
-
self._make_layer(block, 64, num_blocks[0], stride=2), # h=h/2 32
|
92
|
-
self._make_layer(block, 128, num_blocks[1], stride=2),
|
93
|
-
self._make_layer(block, 256, num_blocks[2], stride=2),
|
94
|
-
self._make_layer(block, 512, num_blocks[3], stride=2),
|
95
|
-
nn.AvgPool2d(kernel_size=4)
|
96
|
-
)
|
97
|
-
self.classifier = nn.Linear(512*block.expansion, num_classes)
|
98
|
-
|
99
|
-
def _make_layer(self, block, planes, num_blocks, stride):
|
100
|
-
strides = [stride] + [1]*(num_blocks-1)
|
101
|
-
layers = []
|
102
|
-
for stride in strides:
|
103
|
-
layers.append(block(self.in_planes, planes, stride))
|
104
|
-
self.in_planes = planes * block.expansion
|
105
|
-
return nn.Sequential(*layers)
|
106
|
-
|
107
|
-
def forward(self, x):
|
108
|
-
|
109
|
-
if x.is_cuda and self.ngpu > 1:
|
110
|
-
features = nn.parallel.data_parallel(self.main, x, range(self.ngpu))
|
111
|
-
features = features.view(features.size(0), -1)
|
112
|
-
out = nn.parallel.data_parallel(self.classifier, features, range(self.ngpu))
|
113
|
-
else:
|
114
|
-
features = self.main(x)
|
115
|
-
features = features.view(features.size(0), -1)
|
116
|
-
out = self.classifier(features)
|
117
|
-
return out, features
|
118
|
-
|
119
|
-
|
120
|
-
def ResNet18_class_eval(num_classes=49, ngpu = 1):
|
121
|
-
return ResNet_class_eval(BasicBlock, [2,2,2,2], num_classes=num_classes, ngpu = ngpu)
|
122
|
-
|
123
|
-
def ResNet34_class_eval(num_classes=49, ngpu = 1):
|
124
|
-
return ResNet_class_eval(BasicBlock, [3,4,6,3], num_classes=num_classes, ngpu = ngpu)
|
125
|
-
|
126
|
-
def ResNet50_class_eval(num_classes=49, ngpu = 1):
|
127
|
-
return ResNet_class_eval(Bottleneck, [3,4,6,3], num_classes=num_classes, ngpu = ngpu)
|
128
|
-
|
129
|
-
def ResNet101_class_eval(num_classes=49, ngpu = 1):
|
130
|
-
return ResNet_class_eval(Bottleneck, [3,4,23,3], num_classes=num_classes, ngpu = ngpu)
|
131
|
-
|
132
|
-
def ResNet152_class_eval(num_classes=49, ngpu = 1):
|
133
|
-
return ResNet_class_eval(Bottleneck, [3,8,36,3], num_classes=num_classes, ngpu = ngpu)
|
134
|
-
|
135
|
-
|
136
|
-
if __name__ == "__main__":
|
137
|
-
net = ResNet50_class_eval(num_classes=5, ngpu = 1).cuda()
|
138
|
-
x = torch.randn(16,NC,IMG_SIZE,IMG_SIZE).cuda()
|
139
|
-
out, features = net(x)
|
140
|
-
print(out.size())
|
141
|
-
print(features.size())
|