Myosotis-Researches 0.0.14__tar.gz → 0.0.16__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. {myosotis_researches-0.0.14 → myosotis_researches-0.0.16/Myosotis_Researches.egg-info}/PKG-INFO +1 -1
  2. {myosotis_researches-0.0.14 → myosotis_researches-0.0.16}/Myosotis_Researches.egg-info/SOURCES.txt +12 -0
  3. {myosotis_researches-0.0.14/Myosotis_Researches.egg-info → myosotis_researches-0.0.16}/PKG-INFO +1 -1
  4. myosotis_researches-0.0.16/myosotis_researches/CcGAN/train_128/DiffAugment_pytorch.py +76 -0
  5. myosotis_researches-0.0.16/myosotis_researches/CcGAN/train_128/eval_metrics.py +205 -0
  6. myosotis_researches-0.0.16/myosotis_researches/CcGAN/train_128/opts.py +87 -0
  7. myosotis_researches-0.0.16/myosotis_researches/CcGAN/train_128/pretrain_AE.py +268 -0
  8. myosotis_researches-0.0.16/myosotis_researches/CcGAN/train_128/pretrain_CNN_class.py +251 -0
  9. myosotis_researches-0.0.16/myosotis_researches/CcGAN/train_128/pretrain_CNN_regre.py +255 -0
  10. myosotis_researches-0.0.16/myosotis_researches/CcGAN/train_128/train_ccgan.py +303 -0
  11. myosotis_researches-0.0.16/myosotis_researches/CcGAN/train_128/train_cgan.py +254 -0
  12. myosotis_researches-0.0.16/myosotis_researches/CcGAN/train_128/train_cgan_concat.py +242 -0
  13. myosotis_researches-0.0.16/myosotis_researches/CcGAN/train_128/train_net_for_label_embed.py +181 -0
  14. myosotis_researches-0.0.16/myosotis_researches/CcGAN/train_128/utils.py +120 -0
  15. myosotis_researches-0.0.16/myosotis_researches/__init__.py +0 -0
  16. {myosotis_researches-0.0.14 → myosotis_researches-0.0.16}/setup.py +1 -1
  17. {myosotis_researches-0.0.14 → myosotis_researches-0.0.16}/LICENSE +0 -0
  18. {myosotis_researches-0.0.14 → myosotis_researches-0.0.16}/Myosotis_Researches.egg-info/dependency_links.txt +0 -0
  19. {myosotis_researches-0.0.14 → myosotis_researches-0.0.16}/Myosotis_Researches.egg-info/top_level.txt +0 -0
  20. {myosotis_researches-0.0.14 → myosotis_researches-0.0.16}/README.md +0 -0
  21. {myosotis_researches-0.0.14 → myosotis_researches-0.0.16}/myosotis_researches/CcGAN/__init__.py +0 -0
  22. {myosotis_researches-0.0.14 → myosotis_researches-0.0.16}/myosotis_researches/CcGAN/models_128/CcGAN_SAGAN.py +0 -0
  23. {myosotis_researches-0.0.14 → myosotis_researches-0.0.16}/myosotis_researches/CcGAN/models_128/ResNet_class_eval.py +0 -0
  24. {myosotis_researches-0.0.14 → myosotis_researches-0.0.16}/myosotis_researches/CcGAN/models_128/ResNet_embed.py +0 -0
  25. {myosotis_researches-0.0.14 → myosotis_researches-0.0.16}/myosotis_researches/CcGAN/models_128/ResNet_regre_eval.py +0 -0
  26. {myosotis_researches-0.0.14 → myosotis_researches-0.0.16}/myosotis_researches/CcGAN/models_128/__init__.py +0 -0
  27. {myosotis_researches-0.0.14 → myosotis_researches-0.0.16}/myosotis_researches/CcGAN/models_128/autoencoder.py +0 -0
  28. {myosotis_researches-0.0.14 → myosotis_researches-0.0.16}/myosotis_researches/CcGAN/models_128/cGAN_SAGAN.py +0 -0
  29. {myosotis_researches-0.0.14 → myosotis_researches-0.0.16}/myosotis_researches/CcGAN/models_128/cGAN_concat_SAGAN.py +0 -0
  30. {myosotis_researches-0.0.14 → myosotis_researches-0.0.16}/myosotis_researches/CcGAN/models_256/CcGAN_SAGAN.py +0 -0
  31. {myosotis_researches-0.0.14 → myosotis_researches-0.0.16}/myosotis_researches/CcGAN/models_256/ResNet_class_eval.py +0 -0
  32. {myosotis_researches-0.0.14 → myosotis_researches-0.0.16}/myosotis_researches/CcGAN/models_256/ResNet_embed.py +0 -0
  33. {myosotis_researches-0.0.14 → myosotis_researches-0.0.16}/myosotis_researches/CcGAN/models_256/ResNet_regre_eval.py +0 -0
  34. {myosotis_researches-0.0.14 → myosotis_researches-0.0.16}/myosotis_researches/CcGAN/models_256/__init__.py +0 -0
  35. {myosotis_researches-0.0.14 → myosotis_researches-0.0.16}/myosotis_researches/CcGAN/models_256/autoencoder.py +0 -0
  36. {myosotis_researches-0.0.14 → myosotis_researches-0.0.16}/myosotis_researches/CcGAN/models_256/cGAN_SAGAN.py +0 -0
  37. {myosotis_researches-0.0.14 → myosotis_researches-0.0.16}/myosotis_researches/CcGAN/models_256/cGAN_concat_SAGAN.py +0 -0
  38. {myosotis_researches-0.0.14/myosotis_researches → myosotis_researches-0.0.16/myosotis_researches/CcGAN/train_128}/__init__.py +0 -0
  39. {myosotis_researches-0.0.14 → myosotis_researches-0.0.16}/myosotis_researches/CcGAN/utils/__init__.py +0 -0
  40. {myosotis_researches-0.0.14 → myosotis_researches-0.0.16}/myosotis_researches/CcGAN/utils/concat_image_horizontal.py +0 -0
  41. {myosotis_researches-0.0.14 → myosotis_researches-0.0.16}/myosotis_researches/CcGAN/utils/concat_image_vertical.py +0 -0
  42. {myosotis_researches-0.0.14 → myosotis_researches-0.0.16}/myosotis_researches/CcGAN/utils/make_h5.py +0 -0
  43. {myosotis_researches-0.0.14 → myosotis_researches-0.0.16}/myosotis_researches/CcGAN/utils/print_hdf5_structure.py +0 -0
  44. {myosotis_researches-0.0.14 → myosotis_researches-0.0.16}/pyproject.toml +0 -0
  45. {myosotis_researches-0.0.14 → myosotis_researches-0.0.16}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: Myosotis-Researches
- Version: 0.0.14
+ Version: 0.0.16
  Summary: A repository for storing my progress of researches.
  Home-page: https://github.com/Zeyu-Xie/Myosotis-Researches
  Author: Zeyu Xie
@@ -24,6 +24,18 @@ myosotis_researches/CcGAN/models_256/__init__.py
  myosotis_researches/CcGAN/models_256/autoencoder.py
  myosotis_researches/CcGAN/models_256/cGAN_SAGAN.py
  myosotis_researches/CcGAN/models_256/cGAN_concat_SAGAN.py
+ myosotis_researches/CcGAN/train_128/DiffAugment_pytorch.py
+ myosotis_researches/CcGAN/train_128/__init__.py
+ myosotis_researches/CcGAN/train_128/eval_metrics.py
+ myosotis_researches/CcGAN/train_128/opts.py
+ myosotis_researches/CcGAN/train_128/pretrain_AE.py
+ myosotis_researches/CcGAN/train_128/pretrain_CNN_class.py
+ myosotis_researches/CcGAN/train_128/pretrain_CNN_regre.py
+ myosotis_researches/CcGAN/train_128/train_ccgan.py
+ myosotis_researches/CcGAN/train_128/train_cgan.py
+ myosotis_researches/CcGAN/train_128/train_cgan_concat.py
+ myosotis_researches/CcGAN/train_128/train_net_for_label_embed.py
+ myosotis_researches/CcGAN/train_128/utils.py
  myosotis_researches/CcGAN/utils/__init__.py
  myosotis_researches/CcGAN/utils/concat_image_horizontal.py
  myosotis_researches/CcGAN/utils/concat_image_vertical.py
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: Myosotis-Researches
- Version: 0.0.14
+ Version: 0.0.16
  Summary: A repository for storing my progress of researches.
  Home-page: https://github.com/Zeyu-Xie/Myosotis-Researches
  Author: Zeyu Xie
@@ -0,0 +1,76 @@
+ # Differentiable Augmentation for Data-Efficient GAN Training
+ # Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han
+ # https://arxiv.org/pdf/2006.10738
+
+ import torch
+ import torch.nn.functional as F
+
+
+ def DiffAugment(x, policy='', channels_first=True):
+     if policy:
+         if not channels_first:
+             x = x.permute(0, 3, 1, 2)
+         for p in policy.split(','):
+             for f in AUGMENT_FNS[p]:
+                 x = f(x)
+         if not channels_first:
+             x = x.permute(0, 2, 3, 1)
+         x = x.contiguous()
+     return x
+
+
+ def rand_brightness(x):
+     x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5)
+     return x
+
+
+ def rand_saturation(x):
+     x_mean = x.mean(dim=1, keepdim=True)
+     x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean
+     return x
+
+
+ def rand_contrast(x):
+     x_mean = x.mean(dim=[1, 2, 3], keepdim=True)
+     x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean
+     return x
+
+
+ def rand_translation(x, ratio=0.125):
+     shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
+     translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)
+     translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device)
+     grid_batch, grid_x, grid_y = torch.meshgrid(
+         torch.arange(x.size(0), dtype=torch.long, device=x.device),
+         torch.arange(x.size(2), dtype=torch.long, device=x.device),
+         torch.arange(x.size(3), dtype=torch.long, device=x.device),
+     )
+     grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
+     grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
+     x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
+     x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2)
+     return x
+
+
+ def rand_cutout(x, ratio=0.5):
+     cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
+     offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device)
+     offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device)
+     grid_batch, grid_x, grid_y = torch.meshgrid(
+         torch.arange(x.size(0), dtype=torch.long, device=x.device),
+         torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
+         torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
+     )
+     grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1)
+     grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1)
+     mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
+     mask[grid_batch, grid_x, grid_y] = 0
+     x = x * mask.unsqueeze(1)
+     return x
+
+
+ AUGMENT_FNS = {
+     'color': [rand_brightness, rand_saturation, rand_contrast],
+     'translation': [rand_translation],
+     'cutout': [rand_cutout],
+ }
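For orientation (not part of the packaged diff): a minimal usage sketch of the module above, assuming the package is installed. `DiffAugment` applies every op registered for each comma-separated policy to an NCHW batch and preserves its shape.

    import torch
    from myosotis_researches.CcGAN.train_128.DiffAugment_pytorch import DiffAugment

    x = torch.randn(8, 3, 128, 128)                        # a batch of images, NCHW
    y = DiffAugment(x, policy='color,translation,cutout')  # apply all three policies
    assert y.shape == x.shape                              # every augmentation is shape-preserving

In CcGAN training the same policy is applied to both real and fake batches before they reach the discriminator, which is what the `--gan_DiffAugment_policy` option in `opts.py` below controls.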
@@ -0,0 +1,205 @@
+ """
+ Compute
+ Inception Score (IS),
+ Frechet Inception Distance (FID), ref "https://github.com/mseitzer/pytorch-fid/blob/master/fid_score.py"
+ Maximum Mean Discrepancy (MMD)
+ for a set of fake images.
+
+ Uses numpy arrays:
+ Xr: high-level features for real images; nr by d array
+ Yr: labels for real images
+ Xg: high-level features for fake images; ng by d array
+ Yg: labels for fake images
+ IMGSr: real images
+ IMGSg: fake images
+ """
+
+ import os
+ import gc
+ import numpy as np
+ from scipy import linalg
+ import torch
+ import torch.nn as nn
+ from scipy.stats import entropy
+ from torch.nn import functional as F
+ from torchvision.utils import save_image
+
+ from .utils import SimpleProgressBar, IMGs_dataset
+
+
+ def normalize_images(batch_images):
+     batch_images = batch_images / 255.0
+     batch_images = (batch_images - 0.5) / 0.5
+     return batch_images
+
+
+ ##############################################################################
+ # FID scores
+ ##############################################################################
+
+ # compute FID based on extracted features
+ def FID(Xr, Xg, eps=1e-10):
+     '''
+     The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
+     and X_2 ~ N(mu_2, C_2) is
+     d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
+     '''
+     # sample means
+     MUr = np.mean(Xr, axis=0)
+     MUg = np.mean(Xg, axis=0)
+     mean_diff = MUr - MUg
+     # sample covariances
+     SIGMAr = np.cov(Xr.transpose())
+     SIGMAg = np.cov(Xg.transpose())
+
+     # product might be almost singular
+     covmean, _ = linalg.sqrtm(SIGMAr.dot(SIGMAg), disp=False)  # square root of a matrix
+     covmean = covmean.real
+     if not np.isfinite(covmean).all():
+         msg = ('fid calculation produces singular product; '
+                'adding %s to diagonal of cov estimates') % eps
+         print(msg)
+         offset = np.eye(SIGMAr.shape[0]) * eps
+         covmean = linalg.sqrtm((SIGMAr + offset).dot(SIGMAg + offset))
+
+     # fid score
+     fid_score = mean_diff.dot(mean_diff) + np.trace(SIGMAr + SIGMAg - 2 * covmean)
+
+     return fid_score
+
+
+ # compute FID from raw images
+ def cal_FID(PreNetFID, IMGSr, IMGSg, batch_size=500, resize=None, norm_img=False):
+     # resize: if None, do not resize; if resize = (H,W), resize images to 3 x H x W
+
+     PreNetFID.eval()
+
+     nr = IMGSr.shape[0]
+     ng = IMGSg.shape[0]
+
+     nc = IMGSr.shape[1]  # IMGSr is nr x NC x IMG_SIZE x IMG_SIZE
+     img_size = IMGSr.shape[2]
+
+     if batch_size > min(nr, ng):
+         batch_size = min(nr, ng)
+         # print("FID: reduce batch size to {}".format(batch_size))
+
+     # compute the length of the extracted features
+     with torch.no_grad():
+         test_img = torch.from_numpy(IMGSr[0].reshape((1, nc, img_size, img_size))).type(torch.float).cuda()
+         if resize is not None:
+             test_img = nn.functional.interpolate(test_img, size=resize, scale_factor=None, mode='bilinear', align_corners=False)
+         if norm_img:
+             test_img = normalize_images(test_img)
+         test_features = PreNetFID(test_img)
+         d = test_features.shape[1]  # length of extracted features
+
+     Xr = np.zeros((nr, d))
+     Xg = np.zeros((ng, d))
+
+     with torch.no_grad():
+         tmp = 0
+         pb1 = SimpleProgressBar()
+         for i in range(nr // batch_size):
+             imgr_tensor = torch.from_numpy(IMGSr[tmp:(tmp + batch_size)]).type(torch.float).cuda()
+             if resize is not None:
+                 imgr_tensor = nn.functional.interpolate(imgr_tensor, size=resize, scale_factor=None, mode='bilinear', align_corners=False)
+             if norm_img:
+                 imgr_tensor = normalize_images(imgr_tensor)
+             Xr_tmp = PreNetFID(imgr_tensor)
+             Xr[tmp:(tmp + batch_size)] = Xr_tmp.detach().cpu().numpy()
+             tmp += batch_size
+             pb1.update(min(tmp / nr * 100, 100))
+             del Xr_tmp, imgr_tensor; gc.collect()
+             torch.cuda.empty_cache()
+
+         tmp = 0
+         pb2 = SimpleProgressBar()
+         for j in range(ng // batch_size):
+             imgg_tensor = torch.from_numpy(IMGSg[tmp:(tmp + batch_size)]).type(torch.float).cuda()
+             if resize is not None:
+                 imgg_tensor = nn.functional.interpolate(imgg_tensor, size=resize, scale_factor=None, mode='bilinear', align_corners=False)
+             if norm_img:
+                 imgg_tensor = normalize_images(imgg_tensor)
+             Xg_tmp = PreNetFID(imgg_tensor)
+             Xg[tmp:(tmp + batch_size)] = Xg_tmp.detach().cpu().numpy()
+             tmp += batch_size
+             pb2.update(min(tmp / ng * 100, 100))
+             del Xg_tmp, imgg_tensor; gc.collect()
+             torch.cuda.empty_cache()
+
+     fid_score = FID(Xr, Xg, eps=1e-6)
+
+     return fid_score
+
+
+ ##############################################################################
+ # label_score: difference between assigned labels and predicted labels
+ ##############################################################################
+ def cal_labelscore(PreNet, images, labels_assi, min_label_before_shift, max_label_after_shift, batch_size=500, resize=None, norm_img=False, num_workers=0):
+     '''
+     PreNet: pre-trained CNN
+     images: fake images
+     labels_assi: assigned labels
+     resize: if None, do not resize; if resize = (H,W), resize images to 3 x H x W
+     '''
+
+     PreNet.eval()
+
+     # assume images are n x nc x img_size x img_size
+     n = images.shape[0]
+     nc = images.shape[1]  # number of channels
+     img_size = images.shape[2]
+     labels_assi = labels_assi.reshape(-1)
+
+     eval_trainset = IMGs_dataset(images, labels_assi, normalize=False)
+     eval_dataloader = torch.utils.data.DataLoader(eval_trainset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
+
+     labels_pred = np.zeros(n + batch_size)
+
+     nimgs_got = 0
+     pb = SimpleProgressBar()
+     for batch_idx, (batch_images, batch_labels) in enumerate(eval_dataloader):
+         batch_images = batch_images.type(torch.float).cuda()
+         batch_labels = batch_labels.type(torch.float).cuda()
+         batch_size_curr = len(batch_labels)
+
+         if norm_img:
+             batch_images = normalize_images(batch_images)
+
+         batch_labels_pred, _ = PreNet(batch_images)
+         labels_pred[nimgs_got:(nimgs_got + batch_size_curr)] = batch_labels_pred.detach().cpu().numpy().reshape(-1)
+
+         nimgs_got += batch_size_curr
+         pb.update((float(nimgs_got) / n) * 100)
+
+         del batch_images; gc.collect()
+         torch.cuda.empty_cache()
+     # end for batch_idx
+
+     labels_pred = labels_pred[0:n]
+
+     # map normalized labels back to the original label scale
+     labels_pred = (labels_pred * max_label_after_shift) - np.abs(min_label_before_shift)
+     labels_assi = (labels_assi * max_label_after_shift) - np.abs(min_label_before_shift)
+
+     ls_mean = np.mean(np.abs(labels_pred - labels_assi))
+     ls_std = np.std(np.abs(labels_pred - labels_assi))
+
+     return ls_mean, ls_std
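A quick sanity check for the feature-space `FID` helper, adapted from a commented-out test in the original module. `FID` takes two (n, d) NumPy feature matrices; two samples drawn from the same distribution should score near zero. Matrix sizes below are illustrative.

    import numpy as np
    from myosotis_researches.CcGAN.train_128.eval_metrics import FID

    Xr = np.random.rand(1000, 64)   # features of real images
    Xg = np.random.rand(1000, 64)   # features of fake images
    print(FID(Xr, Xg))              # small value: same underlying distribution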
@@ -0,0 +1,87 @@
+ import argparse
+
+
+ def parse_opts():
+     parser = argparse.ArgumentParser()
+
+     ''' Overall Settings '''
+     parser.add_argument('--root_path', type=str, default='')
+     parser.add_argument('--data_path', type=str, default='')
+     parser.add_argument('--eval_ckpt_path', type=str, default='')
+     parser.add_argument('--seed', type=int, default=2021, metavar='S', help='random seed (default: 2021)')
+     parser.add_argument('--num_workers', type=int, default=0)
+
+     ''' Dataset '''
+     ## Data split: RC-49 is split into a train set (the last decimal of the degree is odd) and a test set
+     ## (the last decimal of the degree is even); the unique labels in the two sets do not overlap.
+     parser.add_argument('--data_split', type=str, default='train', choices=['all', 'train'])
+     parser.add_argument('--min_label', type=float, default=0.0)
+     parser.add_argument('--max_label', type=float, default=90.0)
+     parser.add_argument('--num_channels', type=int, default=3, metavar='N')
+     parser.add_argument('--img_size', type=int, default=128, metavar='N')
+     parser.add_argument('--max_num_img_per_label', type=int, default=50, metavar='N')
+     parser.add_argument('--max_num_img_per_label_after_replica', type=int, default=0, metavar='N')
+     parser.add_argument('--show_real_imgs', action='store_true', default=False)
+     parser.add_argument('--visualize_fake_images', action='store_true', default=False)
+
+     ''' GAN settings '''
+     parser.add_argument('--GAN', type=str, default='CcGAN', choices=['cGAN', 'cGAN-concat', 'CcGAN'])
+     parser.add_argument('--GAN_arch', type=str, default='SAGAN', choices=['SAGAN'])
+
+     # label embedding setting
+     parser.add_argument('--net_embed', type=str, default='ResNet34_embed')  # ResNetXX_embed
+     parser.add_argument('--epoch_cnn_embed', type=int, default=200)  # epochs of CNN training for label embedding
+     parser.add_argument('--resumeepoch_cnn_embed', type=int, default=0)  # resume epoch of CNN training for label embedding
+     parser.add_argument('--epoch_net_y2h', type=int, default=500)
+     parser.add_argument('--dim_embed', type=int, default=128)  # dimension of the embedding space
+     parser.add_argument('--batch_size_embed', type=int, default=256, metavar='N')
+
+     parser.add_argument('--loss_type_gan', type=str, default='hinge')
+     parser.add_argument('--niters_gan', type=int, default=10000, help='number of iterations')
+     parser.add_argument('--resume_niters_gan', type=int, default=0)
+     parser.add_argument('--save_niters_freq', type=int, default=2000, help='frequency of saving checkpoints')
+     parser.add_argument('--lr_g_gan', type=float, default=1e-4, help='learning rate for generator')
+     parser.add_argument('--lr_d_gan', type=float, default=1e-4, help='learning rate for discriminator')
+     parser.add_argument('--dim_gan', type=int, default=128, help='latent dimension of GAN')
+     parser.add_argument('--batch_size_disc', type=int, default=64)
+     parser.add_argument('--batch_size_gene', type=int, default=64)
+     parser.add_argument('--num_D_steps', type=int, default=4, help='number of D updates in one iteration')
+     parser.add_argument('--cGAN_num_classes', type=int, default=20, metavar='N')  # bin labels into cGAN_num_classes classes
+     parser.add_argument('--visualize_freq', type=int, default=2000, help='frequency of visualization')
+
+     parser.add_argument('--kernel_sigma', type=float, default=-1.0,
+                         help='if kernel_sigma<0, use the rule-of-thumb formula to compute sigma')
+     parser.add_argument('--threshold_type', type=str, default='hard', choices=['soft', 'hard'])
+     parser.add_argument('--kappa', type=float, default=-1)
+     parser.add_argument('--nonzero_soft_weight_threshold', type=float, default=1e-3,
+                         help='threshold for determining nonzero weights for SVDL; images with too-small weights are neglected')
+
+     # DiffAugment setting
+     parser.add_argument('--gan_DiffAugment', action='store_true', default=False)
+     parser.add_argument('--gan_DiffAugment_policy', type=str, default='color,translation,cutout')
+
+     # evaluation setting
+     '''
+     Four evaluation modes:
+     Mode 1: eval on the unique labels used for GAN training;
+     Mode 2: eval on all unique labels in the dataset, and use all real images in the dataset when computing FID;
+     Mode 3: eval on all unique labels in the dataset, but only use the real images used for GAN training when computing FID (to test SFID's effectiveness on unseen labels);
+     Mode 4: eval on an interval [min_label, max_label] with num_eval_labels labels.
+     '''
+     parser.add_argument('--eval_mode', type=int, default=2)
+     parser.add_argument('--num_eval_labels', type=int, default=-1)
+     parser.add_argument('--samp_batch_size', type=int, default=200)
+     parser.add_argument('--nfake_per_label', type=int, default=200)
+     parser.add_argument('--nreal_per_label', type=int, default=-1)
+     parser.add_argument('--comp_FID', action='store_true', default=False)
+     parser.add_argument('--epoch_FID_CNN', type=int, default=200)
+     parser.add_argument('--FID_radius', type=float, default=0)
+     parser.add_argument('--FID_num_centers', type=int, default=-1)
+     parser.add_argument('--dump_fake_for_NIQE', action='store_true', default=False,
+                         help='dump fake images for computing NIQE')
+
+     args = parser.parse_args()
+
+     return args
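A hypothetical launch sketch (the script name and flag values are illustrative, not from the package): `parse_opts` reads `sys.argv`, so options are set on the command line and consumed as attributes of the returned namespace.

    # $ python train.py --GAN CcGAN --threshold_type soft --kappa -2.0 --gan_DiffAugment
    from myosotis_researches.CcGAN.train_128.opts import parse_opts

    args = parse_opts()
    print(args.GAN, args.threshold_type, args.kernel_sigma, args.kappa)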
@@ -0,0 +1,268 @@
+ import os
+ import argparse
+ import shutil
+ import timeit
+ import torch
+ import torchvision
+ import torchvision.transforms as transforms
+ import numpy as np
+ import torch.nn as nn
+ import torch.backends.cudnn as cudnn
+ import random
+ import matplotlib.pyplot as plt
+ import matplotlib as mpl
+ from torch import autograd
+ from torchvision.utils import save_image
+ import csv
+ from tqdm import tqdm
+ import gc
+ import h5py
+
+
+ #############################
+ # Settings
+ #############################
+
+ parser = argparse.ArgumentParser(description='Pre-train AE for computing FID')
+ parser.add_argument('--root_path', type=str, default='')
+ parser.add_argument('--data_path', type=str, default='')
+ parser.add_argument('--num_workers', type=int, default=0)
+ parser.add_argument('--dim_bottleneck', type=int, default=512)
+ parser.add_argument('--epochs', type=int, default=200, metavar='N',
+                     help='number of epochs to train CNNs (default: 200)')
+ parser.add_argument('--resume_epoch', type=int, default=0)
+ parser.add_argument('--batch_size_train', type=int, default=128, metavar='N',
+                     help='input batch size for training')
+ parser.add_argument('--batch_size_valid', type=int, default=10, metavar='N',
+                     help='input batch size for testing')
+ parser.add_argument('--base_lr', type=float, default=1e-3,
+                     help='learning rate, default=1e-3')
+ parser.add_argument('--lr_decay_epochs', type=int, default=50)  # decay lr every lr_decay_epochs epochs
+ parser.add_argument('--lr_decay_factor', type=float, default=0.1)
+ parser.add_argument('--lambda_sparsity', type=float, default=1e-4, help='penalty for sparsity')
+ parser.add_argument('--weight_dacay', type=float, default=1e-4,
+                     help='weight decay, default=1e-4')
+ parser.add_argument('--seed', type=int, default=2020, metavar='S',
+                     help='random seed (default: 2020)')
+ parser.add_argument('--CVMode', action='store_true', default=False,
+                     help='CV mode?')
+ parser.add_argument('--img_size', type=int, default=128, metavar='N')
+ parser.add_argument('--min_label', type=float, default=0.0)
+ parser.add_argument('--max_label', type=float, default=90.0)
+ args = parser.parse_args()
+
+ wd = args.root_path
+ os.chdir(wd)
+ from ..models_128 import *
+ from .utils import IMGs_dataset, SimpleProgressBar
+
+ # some parameters in the opts
+ dim_bottleneck = args.dim_bottleneck
+ epochs = args.epochs
+ base_lr = args.base_lr
+ lr_decay_epochs = args.lr_decay_epochs
+ lr_decay_factor = args.lr_decay_factor
+ resume_epoch = args.resume_epoch
+ lambda_sparsity = args.lambda_sparsity
+
+ # random seed
+ random.seed(args.seed)
+ torch.manual_seed(args.seed)
+ torch.backends.cudnn.deterministic = True
+ cudnn.benchmark = False
+ np.random.seed(args.seed)
+
+ # directories for checkpoints, images and log files
+ save_models_folder = wd + '/output/eval_models'
+ os.makedirs(save_models_folder, exist_ok=True)
+ save_AE_images_in_train_folder = save_models_folder + '/AE_lambda_{}_images_in_train'.format(lambda_sparsity)
+ os.makedirs(save_AE_images_in_train_folder, exist_ok=True)
+ save_AE_images_in_valid_folder = save_models_folder + '/AE_lambda_{}_images_in_valid'.format(lambda_sparsity)
+ os.makedirs(save_AE_images_in_valid_folder, exist_ok=True)
+
+
+ ###########################################################################################################
+ # Data
+ ###########################################################################################################
+ # data loader
+ data_filename = args.data_path + '/Ra_' + str(args.img_size) + 'x' + str(args.img_size) + '.h5'
+ hf = h5py.File(data_filename, 'r')
+ labels = hf['labels'][:]
+ labels = labels.astype(float)
+ images = hf['images'][:]
+ hf.close()
+ N_all = len(images)
+ assert len(images) == len(labels)
+
+ q1 = args.min_label
+ q2 = args.max_label
+ indx = np.where((labels > q1) * (labels < q2) == True)[0]
+ labels = labels[indx]
+ images = images[indx]
+ assert len(labels) == len(images)
+
+ # define training and validation sets
+ if args.CVMode:
+     # 90% training; 10% validation
+     valid_prop = 0.1  # proportion of validation samples
+     indx_all = np.arange(len(images))
+     np.random.shuffle(indx_all)
+     indx_valid = indx_all[0:int(valid_prop * len(images))]
+     indx_train = indx_all[int(valid_prop * len(images)):]
+
+     trainset = IMGs_dataset(images[indx_train], labels=None, normalize=True)
+     trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size_train, shuffle=True, num_workers=args.num_workers)
+     validset = IMGs_dataset(images[indx_valid], labels=None, normalize=True)
+     validloader = torch.utils.data.DataLoader(validset, batch_size=args.batch_size_valid, shuffle=False, num_workers=args.num_workers)
+ else:
+     trainset = IMGs_dataset(images, labels=None, normalize=True)
+     trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size_train, shuffle=True, num_workers=args.num_workers)
+
+
+ ###########################################################################################################
+ # Necessary functions
+ ###########################################################################################################
+
+ def adjust_learning_rate(epoch, epochs, optimizer, base_lr, lr_decay_epochs, lr_decay_factor):
+     lr = base_lr
+     for i in range(epochs // lr_decay_epochs):
+         if epoch >= (i + 1) * lr_decay_epochs:
+             lr *= lr_decay_factor
+     for param_group in optimizer.param_groups:
+         param_group['lr'] = lr
+
+
+ def train_AE():
+
+     # define optimizer
+     params = list(net_encoder.parameters()) + list(net_decoder.parameters())
+     optimizer = torch.optim.Adam(params, lr=base_lr, betas=(0.5, 0.999), weight_decay=1e-4)
+
+     # criterion
+     criterion = nn.MSELoss()
+
+     if resume_epoch > 0:
+         print("Loading ckpt to resume training AE >>>")
+         ckpt_fullpath = save_models_folder + "/AE_checkpoint_intrain/AE_checkpoint_epoch_{}_lambda_{}.pth".format(resume_epoch, lambda_sparsity)
+         checkpoint = torch.load(ckpt_fullpath)
+         net_encoder.load_state_dict(checkpoint['net_encoder_state_dict'])
+         net_decoder.load_state_dict(checkpoint['net_decoder_state_dict'])
+         optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+         torch.set_rng_state(checkpoint['rng_state'])
+         gen_iterations = checkpoint['gen_iterations']
+     else:
+         gen_iterations = 0
+
+     start_time = timeit.default_timer()
+     for epoch in range(resume_epoch, epochs):
+
+         adjust_learning_rate(epoch, epochs, optimizer, base_lr, lr_decay_epochs, lr_decay_factor)
+
+         train_loss = 0
+
+         for batch_idx, batch_real_images in enumerate(trainloader):
+
+             net_encoder.train()
+             net_decoder.train()
+
+             batch_size_curr = batch_real_images.shape[0]
+             batch_real_images = batch_real_images.type(torch.float).cuda()
+
+             batch_features = net_encoder(batch_real_images)
+             batch_recons_images = net_decoder(batch_features)
+
+             # reconstruction loss plus an L1-style sparsity penalty on the bottleneck, based on
+             # https://debuggercafe.com/sparse-autoencoders-using-l1-regularization-with-pytorch/
+             loss = criterion(batch_recons_images, batch_real_images) + lambda_sparsity * batch_features.mean()
+
+             # backward pass
+             optimizer.zero_grad()
+             loss.backward()
+             optimizer.step()
+
+             train_loss += loss.cpu().item()
+             gen_iterations += 1
+
+             if gen_iterations % 100 == 0:
+                 n_row = min(10, int(np.sqrt(batch_size_curr)))
+                 with torch.no_grad():
+                     batch_recons_images = net_decoder(net_encoder(batch_real_images[0:n_row**2]))
+                     batch_recons_images = batch_recons_images.detach().cpu()
+                 save_image(batch_recons_images.data, save_AE_images_in_train_folder + '/{}.png'.format(gen_iterations), nrow=n_row, normalize=True)
+
+             if gen_iterations % 20 == 0:
+                 print("AE+lambda{}: [step {}] [epoch {}/{}] [train loss {}] [Time {}]".format(lambda_sparsity, gen_iterations, epoch+1, epochs, train_loss/(batch_idx+1), timeit.default_timer()-start_time))
+         # end for batch_idx
+
+         if (epoch+1) % 50 == 0:
+             save_file = save_models_folder + "/AE_checkpoint_intrain/AE_checkpoint_epoch_{}_lambda_{}.pth".format(epoch+1, lambda_sparsity)
+             os.makedirs(os.path.dirname(save_file), exist_ok=True)
+             torch.save({
+                 'gen_iterations': gen_iterations,
+                 'net_encoder_state_dict': net_encoder.state_dict(),
+                 'net_decoder_state_dict': net_decoder.state_dict(),
+                 'optimizer_state_dict': optimizer.state_dict(),
+                 'rng_state': torch.get_rng_state()
+             }, save_file)
+     # end for epoch
+
+     return net_encoder, net_decoder
+
+
+ if args.CVMode:
+     def valid_AE():
+         net_encoder.eval()
+         net_decoder.eval()
+         with torch.no_grad():
+             for batch_idx, images in enumerate(validloader):
+                 images = images.type(torch.float).cuda()
+                 features = net_encoder(images)
+                 recons_images = net_decoder(features)
+                 save_image(recons_images.data, save_AE_images_in_valid_folder + '/{}_recons.png'.format(batch_idx), nrow=10, normalize=True)
+                 save_image(images.data, save_AE_images_in_valid_folder + '/{}_real.png'.format(batch_idx), nrow=10, normalize=True)
+         return None
+
+
+ ###########################################################################################################
+ # Training and validation
+ ###########################################################################################################
+
+ # model initialization
+ net_encoder = encoder(dim_bottleneck=args.dim_bottleneck).cuda()
+ net_decoder = decoder(dim_bottleneck=args.dim_bottleneck).cuda()
+ net_encoder = nn.DataParallel(net_encoder)
+ net_decoder = nn.DataParallel(net_decoder)
+
+ filename_ckpt = save_models_folder + '/ckpt_AE_epoch_{}_seed_{}_CVMode_{}.pth'.format(args.epochs, args.seed, args.CVMode)
+
+ # training
+ if not os.path.isfile(filename_ckpt):
+     print("\n Begin training AE: ")
+     start = timeit.default_timer()
+     net_encoder, net_decoder = train_AE()
+     stop = timeit.default_timer()
+     print("Time elapses: {}s".format(stop - start))
+     # save model
+     torch.save({
+         'net_encoder_state_dict': net_encoder.state_dict(),
+         'net_decoder_state_dict': net_decoder.state_dict(),
+     }, filename_ckpt)
+ else:
+     print("\n Ckpt already exists")
+     print("\n Loading...")
+     checkpoint = torch.load(filename_ckpt)
+     net_encoder.load_state_dict(checkpoint['net_encoder_state_dict'])
+     net_decoder.load_state_dict(checkpoint['net_decoder_state_dict'])
+
+ if args.CVMode:
+     # validation
+     _ = valid_AE()
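The objective optimized by `train_AE` above is a plain reconstruction loss plus an L1-style sparsity penalty on the bottleneck activations. A self-contained sketch of that loss, with hypothetical shapes (batch of 4 images, 512-d bottleneck):

    import torch
    import torch.nn as nn

    criterion = nn.MSELoss()
    recons = torch.rand(4, 3, 128, 128)   # stand-in for the decoder output
    real = torch.rand(4, 3, 128, 128)     # stand-in for the input images
    features = torch.rand(4, 512)         # stand-in for the bottleneck activations

    lambda_sparsity = 1e-4                # matches the script's default --lambda_sparsity
    loss = criterion(recons, real) + lambda_sparsity * features.mean()  # MSE + mean-activation penalty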