Myosotis-Researches 0.1.8__py3-none-any.whl → 0.1.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31) hide show
  1. myosotis_researches/CcGAN/train/__init__.py +4 -0
  2. myosotis_researches/CcGAN/{train_128_output_10 → train}/train_ccgan.py +4 -4
  3. myosotis_researches/CcGAN/{train_128_output_10 → train}/train_cgan.py +1 -3
  4. myosotis_researches/CcGAN/{train_128_output_10 → train}/train_cgan_concat.py +1 -3
  5. {myosotis_researches-0.1.8.dist-info → myosotis_researches-0.1.10.dist-info}/METADATA +1 -1
  6. myosotis_researches-0.1.10.dist-info/RECORD +40 -0
  7. myosotis_researches/CcGAN/train_128/DiffAugment_pytorch.py +0 -76
  8. myosotis_researches/CcGAN/train_128/__init__.py +0 -0
  9. myosotis_researches/CcGAN/train_128/eval_metrics.py +0 -205
  10. myosotis_researches/CcGAN/train_128/opts.py +0 -87
  11. myosotis_researches/CcGAN/train_128/pretrain_AE.py +0 -268
  12. myosotis_researches/CcGAN/train_128/pretrain_CNN_class.py +0 -251
  13. myosotis_researches/CcGAN/train_128/pretrain_CNN_regre.py +0 -255
  14. myosotis_researches/CcGAN/train_128/train_ccgan.py +0 -303
  15. myosotis_researches/CcGAN/train_128/train_cgan.py +0 -254
  16. myosotis_researches/CcGAN/train_128/train_cgan_concat.py +0 -242
  17. myosotis_researches/CcGAN/train_128/utils.py +0 -120
  18. myosotis_researches/CcGAN/train_128_output_10/DiffAugment_pytorch.py +0 -76
  19. myosotis_researches/CcGAN/train_128_output_10/__init__.py +0 -0
  20. myosotis_researches/CcGAN/train_128_output_10/eval_metrics.py +0 -205
  21. myosotis_researches/CcGAN/train_128_output_10/opts.py +0 -87
  22. myosotis_researches/CcGAN/train_128_output_10/pretrain_AE.py +0 -268
  23. myosotis_researches/CcGAN/train_128_output_10/pretrain_CNN_class.py +0 -251
  24. myosotis_researches/CcGAN/train_128_output_10/pretrain_CNN_regre.py +0 -255
  25. myosotis_researches/CcGAN/train_128_output_10/train_net_for_label_embed.py +0 -181
  26. myosotis_researches/CcGAN/train_128_output_10/utils.py +0 -120
  27. myosotis_researches-0.1.8.dist-info/RECORD +0 -59
  28. /myosotis_researches/CcGAN/{train_128 → train}/train_net_for_label_embed.py +0 -0
  29. {myosotis_researches-0.1.8.dist-info → myosotis_researches-0.1.10.dist-info}/WHEEL +0 -0
  30. {myosotis_researches-0.1.8.dist-info → myosotis_researches-0.1.10.dist-info}/licenses/LICENSE +0 -0
  31. {myosotis_researches-0.1.8.dist-info → myosotis_researches-0.1.10.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,4 @@
1
+ from .train_ccgan import *
2
+ from .train_cgan import *
3
+ from .train_cgan_concat import *
4
+ from .train_net_for_label_embed import *
@@ -5,10 +5,9 @@ import timeit
5
5
  from PIL import Image
6
6
  from torchvision.utils import save_image
7
7
  import torch.cuda as cutorch
8
+ import sys
8
9
 
9
- from .utils import SimpleProgressBar, IMGs_dataset
10
- from .opts import parse_opts
11
- from .DiffAugment_pytorch import DiffAugment
10
+ from myosotis_researches.CcGAN.utils import *
12
11
 
13
12
  ''' Settings '''
14
13
  args = parse_opts()
@@ -79,7 +78,8 @@ def train_ccgan(kernel_sigma, kappa, train_images, train_labels, netG, netD, net
79
78
  # printed images with labels between the 5-th quantile and 95-th quantile of training labels
80
79
  n_row=10; n_col = 1
81
80
  z_fixed = torch.randn(n_row*n_col, dim_gan, dtype=torch.float).cuda()
82
- selected_labels = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
81
+
82
+ selected_labels = np.array([0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9])
83
83
 
84
84
  y_fixed = np.zeros(n_row*n_col)
85
85
  for i in range(n_row):
@@ -6,9 +6,7 @@ import numpy as np
6
6
  import os
7
7
  import timeit
8
8
 
9
- from .utils import IMGs_dataset, SimpleProgressBar
10
- from .opts import parse_opts
11
- from .DiffAugment_pytorch import DiffAugment
9
+ from myosotis_researches.CcGAN.utils import *
12
10
 
13
11
  ''' Settings '''
14
12
  args = parse_opts()
@@ -6,9 +6,7 @@ import numpy as np
6
6
  import os
7
7
  import timeit
8
8
 
9
- from .utils import IMGs_dataset, SimpleProgressBar
10
- from .opts import parse_opts
11
- from .DiffAugment_pytorch import DiffAugment
9
+ from myosotis_researches.CcGAN.utils import *
12
10
 
13
11
  ''' Settings '''
14
12
  args = parse_opts()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: Myosotis-Researches
3
- Version: 0.1.8
3
+ Version: 0.1.10
4
4
  Summary: A repository for storing my progress of researches.
5
5
  Home-page: https://github.com/Zeyu-Xie/Myosotis-Researches
6
6
  Author: Zeyu Xie
@@ -0,0 +1,40 @@
1
+ myosotis_researches/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
+ myosotis_researches/CcGAN/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
+ myosotis_researches/CcGAN/internal/__init__.py,sha256=b-63yANNRQXgLF9k9yGdrm7mlULqGic1HTQTzg9wIME,209
4
+ myosotis_researches/CcGAN/internal/install_datasets.py,sha256=jJwLOZrDnHMrJSUhXxSIFobdeWK5N6eitPmjeBW9FyA,1144
5
+ myosotis_researches/CcGAN/internal/show_datasets.py,sha256=BWtQ6vdiEUOTrOs8aMBv6utuUN0IiaLKcK5iXq9y2qI,363
6
+ myosotis_researches/CcGAN/internal/uninstall_datasets.py,sha256=7pxPZcSe9RHncF0I_4rf8ZdI7eQwv-sFVfxzSVZfYHQ,297
7
+ myosotis_researches/CcGAN/models_128/CcGAN_SAGAN.py,sha256=uYDngtHoB7frPg2Vs7YCFXeUh7Y7MjaAXbRWHXO_xvw,10629
8
+ myosotis_researches/CcGAN/models_128/ResNet_class_eval.py,sha256=wa5CPkYzrS0X6kZ6pGHM-GxcGNkSpBdTTqgy5dKVKkU,5131
9
+ myosotis_researches/CcGAN/models_128/ResNet_embed.py,sha256=HKSY-5WWa9jGniOgRoR1WOTfWhR1Dcj6cq2sgznZEbE,6344
10
+ myosotis_researches/CcGAN/models_128/ResNet_regre_eval.py,sha256=VJYJiiwrjf9DvfZrlwOMJJAPu3PlwgFgIddDaRlGsac,6190
11
+ myosotis_researches/CcGAN/models_128/__init__.py,sha256=PJQP7ozE9vY23k01he5qvEuGndPZKqxiWWxvgbLDhqg,449
12
+ myosotis_researches/CcGAN/models_128/autoencoder.py,sha256=ugOwBNoSNP4-WiATVkhC4-igRjj6yEY91qU0egpX744,3827
13
+ myosotis_researches/CcGAN/models_128/cGAN_SAGAN.py,sha256=JDr0Ss5osf9m-u34bVN_PvMsvMXkmi2jwPOAnls6EOA,11240
14
+ myosotis_researches/CcGAN/models_128/cGAN_concat_SAGAN.py,sha256=GHAmrNjORXKu-8UqAdP-A5WG-_3BdQUmWsrWD1NX5-w,9634
15
+ myosotis_researches/CcGAN/models_256/CcGAN_SAGAN.py,sha256=ju1dBYhqxl722_eeUGc2mKwf1AV_qsv1PlBL3tyOu48,10861
16
+ myosotis_researches/CcGAN/models_256/ResNet_class_eval.py,sha256=tS5YxIpiFS9tDCNe2IDv1hTZNn40_JBD_nn97MfQJNI,5178
17
+ myosotis_researches/CcGAN/models_256/ResNet_embed.py,sha256=9OcMQ-8nuWEbEbWc9tGaWQtfV1hdnkl0PrTphoGX77c,6295
18
+ myosotis_researches/CcGAN/models_256/ResNet_regre_eval.py,sha256=tHAbRNM9XodyfPsu00ac5KMjcgRH8qdx8AtCN9QGXKc,6269
19
+ myosotis_researches/CcGAN/models_256/__init__.py,sha256=PJQP7ozE9vY23k01he5qvEuGndPZKqxiWWxvgbLDhqg,449
20
+ myosotis_researches/CcGAN/models_256/autoencoder.py,sha256=Nv3eSWJVrWaOufoVGe04sZ_KiXFLtu3Y0asZcAdyyj0,4382
21
+ myosotis_researches/CcGAN/models_256/cGAN_SAGAN.py,sha256=wTHVkUcAp07n3lgweKFo6cqd91E_rEqgJrBDbBe6qrg,11510
22
+ myosotis_researches/CcGAN/models_256/cGAN_concat_SAGAN.py,sha256=ZmGEpprDDlFR3dG32LT3NH5yiA1WR8Hg26rcbz42aCQ,9807
23
+ myosotis_researches/CcGAN/train/__init__.py,sha256=-55Ccov89II6Yuaiszi8ziw9EoVQr7OJR0bQfPAE_10,127
24
+ myosotis_researches/CcGAN/train/train_ccgan.py,sha256=0Qxibgd2-WaYgbyYeeOyiMkdcwkd_M1m1gSqoHTjN0w,13268
25
+ myosotis_researches/CcGAN/train/train_cgan.py,sha256=sxMzvlmdjmqufwJFxBwatcoJecYqn2Uidedu15CL9ws,9619
26
+ myosotis_researches/CcGAN/train/train_cgan_concat.py,sha256=OrQbwdU_ujUeKFGixUUpnini6rURtbuHv9NDrP6g0X0,8861
27
+ myosotis_researches/CcGAN/train/train_net_for_label_embed.py,sha256=4j6r4_o4rXgAN4MdUQL-TXqZJpbhH7d9gWQR8YzBlXw,6976
28
+ myosotis_researches/CcGAN/utils/IMGs_dataset.py,sha256=i45PBNSCeAEB5uUG0SluYRTuHWZwH_5ldz2wm6afkYs,927
29
+ myosotis_researches/CcGAN/utils/SimpleProgressBar.py,sha256=S4eD_m6ysHRMHAmRtkTXVRNfXTR8kuHv-d3lUN0BVn4,546
30
+ myosotis_researches/CcGAN/utils/__init__.py,sha256=em3aB0C-V230NQtT64hyuHGo4CjV6p2DwIdtNM0dk4k,516
31
+ myosotis_researches/CcGAN/utils/concat_image.py,sha256=BIGKz52Inn9S7M5fBFKye2V9bLJ0DqEQILoOVWAXUiE,2165
32
+ myosotis_researches/CcGAN/utils/make_h5.py,sha256=VtFYjr_i-JktsEW_BvofpilcDmChRmyLykv0VvlMuY0,963
33
+ myosotis_researches/CcGAN/utils/opts.py,sha256=pd7-wknNPBO5hWRpO3YAPmmAsPKgZUUpKc4gWMs6Wto,5397
34
+ myosotis_researches/CcGAN/utils/print_hdf5.py,sha256=VvmNAWtMDmg6D9V6ZbSUXrQTKRh9WIJeC4BR_ORJkco,300
35
+ myosotis_researches/CcGAN/utils/train.py,sha256=5ZXgkGesuInqUooJRpLej_KHqYQtlSDq90_5wig5elQ,5152
36
+ myosotis_researches-0.1.10.dist-info/licenses/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
37
+ myosotis_researches-0.1.10.dist-info/METADATA,sha256=tCHcXYDZ_af1keBiLUExIsWWoBIKxFdrY6wTWWm9L8c,2664
38
+ myosotis_researches-0.1.10.dist-info/WHEEL,sha256=ck4Vq1_RXyvS4Jt6SI0Vz6fyVs4GWg7AINwpsaGEgPE,91
39
+ myosotis_researches-0.1.10.dist-info/top_level.txt,sha256=zxAiMn5eyZNJM28MewTAkgi_RZJMbfWbzVR-KF0LdZE,20
40
+ myosotis_researches-0.1.10.dist-info/RECORD,,
@@ -1,76 +0,0 @@
1
- # Differentiable Augmentation for Data-Efficient GAN Training
2
- # Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han
3
- # https://arxiv.org/pdf/2006.10738
4
-
5
- import torch
6
- import torch.nn.functional as F
7
-
8
-
9
- def DiffAugment(x, policy='', channels_first=True):
10
- if policy:
11
- if not channels_first:
12
- x = x.permute(0, 3, 1, 2)
13
- for p in policy.split(','):
14
- for f in AUGMENT_FNS[p]:
15
- x = f(x)
16
- if not channels_first:
17
- x = x.permute(0, 2, 3, 1)
18
- x = x.contiguous()
19
- return x
20
-
21
-
22
- def rand_brightness(x):
23
- x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5)
24
- return x
25
-
26
-
27
- def rand_saturation(x):
28
- x_mean = x.mean(dim=1, keepdim=True)
29
- x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean
30
- return x
31
-
32
-
33
- def rand_contrast(x):
34
- x_mean = x.mean(dim=[1, 2, 3], keepdim=True)
35
- x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean
36
- return x
37
-
38
-
39
- def rand_translation(x, ratio=0.125):
40
- shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
41
- translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)
42
- translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device)
43
- grid_batch, grid_x, grid_y = torch.meshgrid(
44
- torch.arange(x.size(0), dtype=torch.long, device=x.device),
45
- torch.arange(x.size(2), dtype=torch.long, device=x.device),
46
- torch.arange(x.size(3), dtype=torch.long, device=x.device),
47
- )
48
- grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
49
- grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
50
- x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
51
- x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2)
52
- return x
53
-
54
-
55
- def rand_cutout(x, ratio=0.5):
56
- cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
57
- offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device)
58
- offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device)
59
- grid_batch, grid_x, grid_y = torch.meshgrid(
60
- torch.arange(x.size(0), dtype=torch.long, device=x.device),
61
- torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
62
- torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
63
- )
64
- grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1)
65
- grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1)
66
- mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
67
- mask[grid_batch, grid_x, grid_y] = 0
68
- x = x * mask.unsqueeze(1)
69
- return x
70
-
71
-
72
- AUGMENT_FNS = {
73
- 'color': [rand_brightness, rand_saturation, rand_contrast],
74
- 'translation': [rand_translation],
75
- 'cutout': [rand_cutout],
76
- }
File without changes
@@ -1,205 +0,0 @@
1
- """
2
- Compute
3
- Inception Score (IS),
4
- Frechet Inception Discrepency (FID), ref "https://github.com/mseitzer/pytorch-fid/blob/master/fid_score.py"
5
- Maximum Mean Discrepancy (MMD)
6
- for a set of fake images
7
-
8
- use numpy array
9
- Xr: high-level features for real images; nr by d array
10
- Yr: labels for real images
11
- Xg: high-level features for fake images; ng by d array
12
- Yg: labels for fake images
13
- IMGSr: real images
14
- IMGSg: fake images
15
-
16
- """
17
-
18
- import os
19
- import gc
20
- import numpy as np
21
- # from numpy import linalg as LA
22
- from scipy import linalg
23
- import torch
24
- import torch.nn as nn
25
- from scipy.stats import entropy
26
- from torch.nn import functional as F
27
- from torchvision.utils import save_image
28
-
29
- from .utils import SimpleProgressBar, IMGs_dataset
30
-
31
-
32
- def normalize_images(batch_images):
33
- batch_images = batch_images/255.0
34
- batch_images = (batch_images - 0.5)/0.5
35
- return batch_images
36
-
37
- ##############################################################################
38
- # FID scores
39
- ##############################################################################
40
- # compute FID based on extracted features
41
- def FID(Xr, Xg, eps=1e-10):
42
- '''
43
- The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
44
- and X_2 ~ N(mu_2, C_2) is
45
- d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
46
- '''
47
- #sample mean
48
- MUr = np.mean(Xr, axis = 0)
49
- MUg = np.mean(Xg, axis = 0)
50
- mean_diff = MUr - MUg
51
- #sample covariance
52
- SIGMAr = np.cov(Xr.transpose())
53
- SIGMAg = np.cov(Xg.transpose())
54
-
55
- # Product might be almost singular
56
- covmean, _ = linalg.sqrtm(SIGMAr.dot(SIGMAg), disp=False)#square root of a matrix
57
- covmean = covmean.real
58
- if not np.isfinite(covmean).all():
59
- msg = ('fid calculation produces singular product; '
60
- 'adding %s to diagonal of cov estimates') % eps
61
- print(msg)
62
- offset = np.eye(SIGMAr.shape[0]) * eps
63
- covmean = linalg.sqrtm((SIGMAr + offset).dot(SIGMAg + offset))
64
-
65
- #fid score
66
- fid_score = mean_diff.dot(mean_diff) + np.trace(SIGMAr + SIGMAg - 2*covmean)
67
-
68
- return fid_score
69
-
70
- ##test
71
- #Xr = np.random.rand(10000,1000)
72
- #Xg = np.random.rand(10000,1000)
73
- #print(FID(Xr, Xg))
74
-
75
- # compute FID from raw images
76
- def cal_FID(PreNetFID, IMGSr, IMGSg, batch_size = 500, resize = None, norm_img = False):
77
- #resize: if None, do not resize; if resize = (H,W), resize images to 3 x H x W
78
-
79
- PreNetFID.eval()
80
-
81
- nr = IMGSr.shape[0]
82
- ng = IMGSg.shape[0]
83
-
84
- nc = IMGSr.shape[1] #IMGSr is nrxNCxIMG_SIExIMG_SIZE
85
- img_size = IMGSr.shape[2]
86
-
87
- if batch_size > min(nr, ng):
88
- batch_size = min(nr, ng)
89
- # print("FID: recude batch size to {}".format(batch_size))
90
-
91
- #compute the length of extracted features
92
- with torch.no_grad():
93
- test_img = torch.from_numpy(IMGSr[0].reshape((1,nc,img_size,img_size))).type(torch.float).cuda()
94
- if resize is not None:
95
- test_img = nn.functional.interpolate(test_img, size = resize, scale_factor=None, mode='bilinear', align_corners=False)
96
- if norm_img:
97
- test_img = normalize_images(test_img)
98
- # _, test_features = PreNetFID(test_img)
99
- test_features = PreNetFID(test_img)
100
- d = test_features.shape[1] #length of extracted features
101
-
102
- Xr = np.zeros((nr, d))
103
- Xg = np.zeros((ng, d))
104
-
105
- #batch_size = 500
106
- with torch.no_grad():
107
- tmp = 0
108
- pb1 = SimpleProgressBar()
109
- for i in range(nr//batch_size):
110
- imgr_tensor = torch.from_numpy(IMGSr[tmp:(tmp+batch_size)]).type(torch.float).cuda()
111
- if resize is not None:
112
- imgr_tensor = nn.functional.interpolate(imgr_tensor, size = resize, scale_factor=None, mode='bilinear', align_corners=False)
113
- if norm_img:
114
- imgr_tensor = normalize_images(imgr_tensor)
115
- # _, Xr_tmp = PreNetFID(imgr_tensor)
116
- Xr_tmp = PreNetFID(imgr_tensor)
117
- Xr[tmp:(tmp+batch_size)] = Xr_tmp.detach().cpu().numpy()
118
- tmp+=batch_size
119
- # pb1.update(min(float(i)*100/(nr//batch_size), 100))
120
- pb1.update(min(max(tmp/nr*100,100), 100))
121
- del Xr_tmp,imgr_tensor; gc.collect()
122
- torch.cuda.empty_cache()
123
-
124
- tmp = 0
125
- pb2 = SimpleProgressBar()
126
- for j in range(ng//batch_size):
127
- imgg_tensor = torch.from_numpy(IMGSg[tmp:(tmp+batch_size)]).type(torch.float).cuda()
128
- if resize is not None:
129
- imgg_tensor = nn.functional.interpolate(imgg_tensor, size = resize, scale_factor=None, mode='bilinear', align_corners=False)
130
- if norm_img:
131
- imgg_tensor = normalize_images(imgg_tensor)
132
- # _, Xg_tmp = PreNetFID(imgg_tensor)
133
- Xg_tmp = PreNetFID(imgg_tensor)
134
- Xg[tmp:(tmp+batch_size)] = Xg_tmp.detach().cpu().numpy()
135
- tmp+=batch_size
136
- # pb2.update(min(float(j)*100/(ng//batch_size), 100))
137
- pb2.update(min(max(tmp/ng*100, 100), 100))
138
- del Xg_tmp,imgg_tensor; gc.collect()
139
- torch.cuda.empty_cache()
140
-
141
-
142
- fid_score = FID(Xr, Xg, eps=1e-6)
143
-
144
- return fid_score
145
-
146
-
147
-
148
-
149
-
150
-
151
- ##############################################################################
152
- # label_score
153
- # difference between assigned label and predicted label
154
- ##############################################################################
155
- def cal_labelscore(PreNet, images, labels_assi, min_label_before_shift, max_label_after_shift, batch_size = 500, resize = None, norm_img = False, num_workers=0):
156
- '''
157
- PreNet: pre-trained CNN
158
- images: fake images
159
- labels_assi: assigned labels
160
- resize: if None, do not resize; if resize = (H,W), resize images to 3 x H x W
161
- '''
162
-
163
- PreNet.eval()
164
-
165
- # assume images are nxncximg_sizeximg_size
166
- n = images.shape[0]
167
- nc = images.shape[1] #number of channels
168
- img_size = images.shape[2]
169
- labels_assi = labels_assi.reshape(-1)
170
-
171
- eval_trainset = IMGs_dataset(images, labels_assi, normalize=False)
172
- eval_dataloader = torch.utils.data.DataLoader(eval_trainset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
173
-
174
- labels_pred = np.zeros(n+batch_size)
175
-
176
- nimgs_got = 0
177
- pb = SimpleProgressBar()
178
- for batch_idx, (batch_images, batch_labels) in enumerate(eval_dataloader):
179
- batch_images = batch_images.type(torch.float).cuda()
180
- batch_labels = batch_labels.type(torch.float).cuda()
181
- batch_size_curr = len(batch_labels)
182
-
183
- if norm_img:
184
- batch_images = normalize_images(batch_images)
185
-
186
- batch_labels_pred, _ = PreNet(batch_images)
187
- labels_pred[nimgs_got:(nimgs_got+batch_size_curr)] = batch_labels_pred.detach().cpu().numpy().reshape(-1)
188
-
189
- nimgs_got += batch_size_curr
190
- pb.update((float(nimgs_got)/n)*100)
191
-
192
- del batch_images; gc.collect()
193
- torch.cuda.empty_cache()
194
- #end for batch_idx
195
-
196
- labels_pred = labels_pred[0:n]
197
-
198
-
199
- labels_pred = (labels_pred*max_label_after_shift)-np.abs(min_label_before_shift)
200
- labels_assi = (labels_assi*max_label_after_shift)-np.abs(min_label_before_shift)
201
-
202
- ls_mean = np.mean(np.abs(labels_pred-labels_assi))
203
- ls_std = np.std(np.abs(labels_pred-labels_assi))
204
-
205
- return ls_mean, ls_std
@@ -1,87 +0,0 @@
1
- import argparse
2
-
3
- def parse_opts():
4
- parser = argparse.ArgumentParser()
5
-
6
- ''' Overall Settings '''
7
- parser.add_argument('--root_path', type=str, default='')
8
- parser.add_argument('--data_path', type=str, default='')
9
- parser.add_argument('--eval_ckpt_path', type=str, default='')
10
- parser.add_argument('--seed', type=int, default=2021, metavar='S', help='random seed (default: 2020)')
11
- parser.add_argument('--num_workers', type=int, default=0)
12
-
13
-
14
- ''' Dataset '''
15
- ## Data split: RC-49 is split into a train set (the last decimal of the degree is odd) and a test set (the last decimal of the degree is even); the unique labels in two sets do not overlap.
16
- parser.add_argument('--data_split', type=str, default='train',
17
- choices=['all', 'train'])
18
- parser.add_argument('--min_label', type=float, default=0.0)
19
- parser.add_argument('--max_label', type=float, default=90.0)
20
- parser.add_argument('--num_channels', type=int, default=3, metavar='N')
21
- parser.add_argument('--img_size', type=int, default=128, metavar='N')
22
- parser.add_argument('--max_num_img_per_label', type=int, default=50, metavar='N')
23
- parser.add_argument('--max_num_img_per_label_after_replica', type=int, default=0, metavar='N')
24
- parser.add_argument('--show_real_imgs', action='store_true', default=False)
25
- parser.add_argument('--visualize_fake_images', action='store_true', default=False)
26
-
27
-
28
- ''' GAN settings '''
29
- parser.add_argument('--GAN', type=str, default='CcGAN', choices=['cGAN', 'cGAN-concat', 'CcGAN'])
30
- parser.add_argument('--GAN_arch', type=str, default='SAGAN', choices=['SAGAN'])
31
-
32
- # label embedding setting
33
- parser.add_argument('--net_embed', type=str, default='ResNet34_embed') #ResNetXX_emebed
34
- parser.add_argument('--epoch_cnn_embed', type=int, default=200) #epoch of cnn training for label embedding
35
- parser.add_argument('--resumeepoch_cnn_embed', type=int, default=0) #epoch of cnn training for label embedding
36
- parser.add_argument('--epoch_net_y2h', type=int, default=500)
37
- parser.add_argument('--dim_embed', type=int, default=128) #dimension of the embedding space
38
- parser.add_argument('--batch_size_embed', type=int, default=256, metavar='N')
39
-
40
- parser.add_argument('--loss_type_gan', type=str, default='hinge')
41
- parser.add_argument('--niters_gan', type=int, default=10000, help='number of iterations')
42
- parser.add_argument('--resume_niters_gan', type=int, default=0)
43
- parser.add_argument('--save_niters_freq', type=int, default=2000, help='frequency of saving checkpoints')
44
- parser.add_argument('--lr_g_gan', type=float, default=1e-4, help='learning rate for generator')
45
- parser.add_argument('--lr_d_gan', type=float, default=1e-4, help='learning rate for discriminator')
46
- parser.add_argument('--dim_gan', type=int, default=128, help='Latent dimension of GAN')
47
- parser.add_argument('--batch_size_disc', type=int, default=64)
48
- parser.add_argument('--batch_size_gene', type=int, default=64)
49
- parser.add_argument('--num_D_steps', type=int, default=4, help='number of Ds updates in one iteration')
50
- parser.add_argument('--cGAN_num_classes', type=int, default=20, metavar='N') #bin label into cGAN_num_classes
51
- parser.add_argument('--visualize_freq', type=int, default=2000, help='frequency of visualization')
52
-
53
- parser.add_argument('--kernel_sigma', type=float, default=-1.0,
54
- help='If kernel_sigma<0, then use rule-of-thumb formula to compute the sigma.')
55
- parser.add_argument('--threshold_type', type=str, default='hard', choices=['soft', 'hard'])
56
- parser.add_argument('--kappa', type=float, default=-1)
57
- parser.add_argument('--nonzero_soft_weight_threshold', type=float, default=1e-3,
58
- help='threshold for determining nonzero weights for SVDL; we neglect images with too small weights')
59
-
60
- # DiffAugment setting
61
- parser.add_argument('--gan_DiffAugment', action='store_true', default=False)
62
- parser.add_argument('--gan_DiffAugment_policy', type=str, default='color,translation,cutout')
63
-
64
-
65
- # evaluation setting
66
- '''
67
- Four evaluation modes:
68
- Mode 1: eval on unique labels used for GAN training;
69
- Mode 2. eval on all unique labels in the dataset and when computing FID use all real images in the dataset;
70
- Mode 3. eval on all unique labels in the dataset and when computing FID only use real images for GAN training in the dataset (to test SFID's effectiveness on unseen labels);
71
- Mode 4. eval on a interval [min_label, max_label] with num_eval_labels labels.
72
- '''
73
- parser.add_argument('--eval_mode', type=int, default=2)
74
- parser.add_argument('--num_eval_labels', type=int, default=-1)
75
- parser.add_argument('--samp_batch_size', type=int, default=200)
76
- parser.add_argument('--nfake_per_label', type=int, default=200)
77
- parser.add_argument('--nreal_per_label', type=int, default=-1)
78
- parser.add_argument('--comp_FID', action='store_true', default=False)
79
- parser.add_argument('--epoch_FID_CNN', type=int, default=200)
80
- parser.add_argument('--FID_radius', type=float, default=0)
81
- parser.add_argument('--FID_num_centers', type=int, default=-1)
82
- parser.add_argument('--dump_fake_for_NIQE', action='store_true', default=False,
83
- help='Dump fake images for computing NIQE')
84
-
85
- args = parser.parse_args()
86
-
87
- return args