Myosotis-Researches 0.1.7__tar.gz → 0.1.9__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67) hide show
  1. {myosotis_researches-0.1.7 → myosotis_researches-0.1.9/Myosotis_Researches.egg-info}/PKG-INFO +1 -1
  2. myosotis_researches-0.1.9/Myosotis_Researches.egg-info/SOURCES.txt +27 -0
  3. {myosotis_researches-0.1.7/Myosotis_Researches.egg-info → myosotis_researches-0.1.9}/PKG-INFO +1 -1
  4. myosotis_researches-0.1.9/myosotis_researches/CcGAN/train/__init__.py +4 -0
  5. {myosotis_researches-0.1.7/myosotis_researches/CcGAN/train_128_output_10 → myosotis_researches-0.1.9/myosotis_researches/CcGAN/train}/train_ccgan.py +4 -4
  6. {myosotis_researches-0.1.7/myosotis_researches/CcGAN/train_128 → myosotis_researches-0.1.9/myosotis_researches/CcGAN/train}/train_cgan.py +1 -3
  7. {myosotis_researches-0.1.7/myosotis_researches/CcGAN/train_128 → myosotis_researches-0.1.9/myosotis_researches/CcGAN/train}/train_cgan_concat.py +1 -3
  8. {myosotis_researches-0.1.7 → myosotis_researches-0.1.9}/myosotis_researches/CcGAN/utils/__init__.py +2 -1
  9. myosotis_researches-0.1.9/myosotis_researches/CcGAN/utils/train.py +156 -0
  10. {myosotis_researches-0.1.7 → myosotis_researches-0.1.9}/setup.py +1 -1
  11. myosotis_researches-0.1.7/Myosotis_Researches.egg-info/SOURCES.txt +0 -62
  12. myosotis_researches-0.1.7/myosotis_researches/CcGAN/models_128/CcGAN_SAGAN.py +0 -301
  13. myosotis_researches-0.1.7/myosotis_researches/CcGAN/models_128/ResNet_class_eval.py +0 -141
  14. myosotis_researches-0.1.7/myosotis_researches/CcGAN/models_128/ResNet_embed.py +0 -188
  15. myosotis_researches-0.1.7/myosotis_researches/CcGAN/models_128/ResNet_regre_eval.py +0 -175
  16. myosotis_researches-0.1.7/myosotis_researches/CcGAN/models_128/__init__.py +0 -7
  17. myosotis_researches-0.1.7/myosotis_researches/CcGAN/models_128/autoencoder.py +0 -119
  18. myosotis_researches-0.1.7/myosotis_researches/CcGAN/models_128/cGAN_SAGAN.py +0 -276
  19. myosotis_researches-0.1.7/myosotis_researches/CcGAN/models_128/cGAN_concat_SAGAN.py +0 -245
  20. myosotis_researches-0.1.7/myosotis_researches/CcGAN/models_256/CcGAN_SAGAN.py +0 -303
  21. myosotis_researches-0.1.7/myosotis_researches/CcGAN/models_256/ResNet_class_eval.py +0 -142
  22. myosotis_researches-0.1.7/myosotis_researches/CcGAN/models_256/ResNet_embed.py +0 -188
  23. myosotis_researches-0.1.7/myosotis_researches/CcGAN/models_256/ResNet_regre_eval.py +0 -178
  24. myosotis_researches-0.1.7/myosotis_researches/CcGAN/models_256/__init__.py +0 -7
  25. myosotis_researches-0.1.7/myosotis_researches/CcGAN/models_256/autoencoder.py +0 -133
  26. myosotis_researches-0.1.7/myosotis_researches/CcGAN/models_256/cGAN_SAGAN.py +0 -280
  27. myosotis_researches-0.1.7/myosotis_researches/CcGAN/models_256/cGAN_concat_SAGAN.py +0 -249
  28. myosotis_researches-0.1.7/myosotis_researches/CcGAN/train_128/DiffAugment_pytorch.py +0 -76
  29. myosotis_researches-0.1.7/myosotis_researches/CcGAN/train_128/eval_metrics.py +0 -205
  30. myosotis_researches-0.1.7/myosotis_researches/CcGAN/train_128/opts.py +0 -87
  31. myosotis_researches-0.1.7/myosotis_researches/CcGAN/train_128/pretrain_AE.py +0 -268
  32. myosotis_researches-0.1.7/myosotis_researches/CcGAN/train_128/pretrain_CNN_class.py +0 -251
  33. myosotis_researches-0.1.7/myosotis_researches/CcGAN/train_128/pretrain_CNN_regre.py +0 -255
  34. myosotis_researches-0.1.7/myosotis_researches/CcGAN/train_128/train_ccgan.py +0 -303
  35. myosotis_researches-0.1.7/myosotis_researches/CcGAN/train_128/utils.py +0 -120
  36. myosotis_researches-0.1.7/myosotis_researches/CcGAN/train_128_output_10/DiffAugment_pytorch.py +0 -76
  37. myosotis_researches-0.1.7/myosotis_researches/CcGAN/train_128_output_10/__init__.py +0 -0
  38. myosotis_researches-0.1.7/myosotis_researches/CcGAN/train_128_output_10/eval_metrics.py +0 -205
  39. myosotis_researches-0.1.7/myosotis_researches/CcGAN/train_128_output_10/opts.py +0 -87
  40. myosotis_researches-0.1.7/myosotis_researches/CcGAN/train_128_output_10/pretrain_AE.py +0 -268
  41. myosotis_researches-0.1.7/myosotis_researches/CcGAN/train_128_output_10/pretrain_CNN_class.py +0 -251
  42. myosotis_researches-0.1.7/myosotis_researches/CcGAN/train_128_output_10/pretrain_CNN_regre.py +0 -255
  43. myosotis_researches-0.1.7/myosotis_researches/CcGAN/train_128_output_10/train_cgan.py +0 -254
  44. myosotis_researches-0.1.7/myosotis_researches/CcGAN/train_128_output_10/train_cgan_concat.py +0 -242
  45. myosotis_researches-0.1.7/myosotis_researches/CcGAN/train_128_output_10/train_net_for_label_embed.py +0 -181
  46. myosotis_researches-0.1.7/myosotis_researches/CcGAN/train_128_output_10/utils.py +0 -120
  47. myosotis_researches-0.1.7/myosotis_researches/CcGAN/utils/train.py +0 -65
  48. myosotis_researches-0.1.7/myosotis_researches/__init__.py +0 -0
  49. {myosotis_researches-0.1.7 → myosotis_researches-0.1.9}/LICENSE +0 -0
  50. {myosotis_researches-0.1.7 → myosotis_researches-0.1.9}/Myosotis_Researches.egg-info/dependency_links.txt +0 -0
  51. {myosotis_researches-0.1.7 → myosotis_researches-0.1.9}/Myosotis_Researches.egg-info/top_level.txt +0 -0
  52. {myosotis_researches-0.1.7 → myosotis_researches-0.1.9}/README.md +0 -0
  53. {myosotis_researches-0.1.7 → myosotis_researches-0.1.9}/myosotis_researches/CcGAN/__init__.py +0 -0
  54. {myosotis_researches-0.1.7 → myosotis_researches-0.1.9}/myosotis_researches/CcGAN/internal/__init__.py +0 -0
  55. {myosotis_researches-0.1.7 → myosotis_researches-0.1.9}/myosotis_researches/CcGAN/internal/install_datasets.py +0 -0
  56. {myosotis_researches-0.1.7 → myosotis_researches-0.1.9}/myosotis_researches/CcGAN/internal/show_datasets.py +0 -0
  57. {myosotis_researches-0.1.7 → myosotis_researches-0.1.9}/myosotis_researches/CcGAN/internal/uninstall_datasets.py +0 -0
  58. {myosotis_researches-0.1.7/myosotis_researches/CcGAN/train_128 → myosotis_researches-0.1.9/myosotis_researches/CcGAN/train}/train_net_for_label_embed.py +0 -0
  59. {myosotis_researches-0.1.7 → myosotis_researches-0.1.9}/myosotis_researches/CcGAN/utils/IMGs_dataset.py +0 -0
  60. {myosotis_researches-0.1.7 → myosotis_researches-0.1.9}/myosotis_researches/CcGAN/utils/SimpleProgressBar.py +0 -0
  61. {myosotis_researches-0.1.7 → myosotis_researches-0.1.9}/myosotis_researches/CcGAN/utils/concat_image.py +0 -0
  62. {myosotis_researches-0.1.7 → myosotis_researches-0.1.9}/myosotis_researches/CcGAN/utils/make_h5.py +0 -0
  63. {myosotis_researches-0.1.7 → myosotis_researches-0.1.9}/myosotis_researches/CcGAN/utils/opts.py +0 -0
  64. {myosotis_researches-0.1.7 → myosotis_researches-0.1.9}/myosotis_researches/CcGAN/utils/print_hdf5.py +0 -0
  65. {myosotis_researches-0.1.7/myosotis_researches/CcGAN/train_128 → myosotis_researches-0.1.9/myosotis_researches}/__init__.py +0 -0
  66. {myosotis_researches-0.1.7 → myosotis_researches-0.1.9}/pyproject.toml +0 -0
  67. {myosotis_researches-0.1.7 → myosotis_researches-0.1.9}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: Myosotis-Researches
3
- Version: 0.1.7
3
+ Version: 0.1.9
4
4
  Summary: A repository for storing my progress of researches.
5
5
  Home-page: https://github.com/Zeyu-Xie/Myosotis-Researches
6
6
  Author: Zeyu Xie
@@ -0,0 +1,27 @@
1
+ LICENSE
2
+ README.md
3
+ pyproject.toml
4
+ setup.py
5
+ Myosotis_Researches.egg-info/PKG-INFO
6
+ Myosotis_Researches.egg-info/SOURCES.txt
7
+ Myosotis_Researches.egg-info/dependency_links.txt
8
+ Myosotis_Researches.egg-info/top_level.txt
9
+ myosotis_researches/__init__.py
10
+ myosotis_researches/CcGAN/__init__.py
11
+ myosotis_researches/CcGAN/internal/__init__.py
12
+ myosotis_researches/CcGAN/internal/install_datasets.py
13
+ myosotis_researches/CcGAN/internal/show_datasets.py
14
+ myosotis_researches/CcGAN/internal/uninstall_datasets.py
15
+ myosotis_researches/CcGAN/train/__init__.py
16
+ myosotis_researches/CcGAN/train/train_ccgan.py
17
+ myosotis_researches/CcGAN/train/train_cgan.py
18
+ myosotis_researches/CcGAN/train/train_cgan_concat.py
19
+ myosotis_researches/CcGAN/train/train_net_for_label_embed.py
20
+ myosotis_researches/CcGAN/utils/IMGs_dataset.py
21
+ myosotis_researches/CcGAN/utils/SimpleProgressBar.py
22
+ myosotis_researches/CcGAN/utils/__init__.py
23
+ myosotis_researches/CcGAN/utils/concat_image.py
24
+ myosotis_researches/CcGAN/utils/make_h5.py
25
+ myosotis_researches/CcGAN/utils/opts.py
26
+ myosotis_researches/CcGAN/utils/print_hdf5.py
27
+ myosotis_researches/CcGAN/utils/train.py
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: Myosotis-Researches
3
- Version: 0.1.7
3
+ Version: 0.1.9
4
4
  Summary: A repository for storing my progress of researches.
5
5
  Home-page: https://github.com/Zeyu-Xie/Myosotis-Researches
6
6
  Author: Zeyu Xie
@@ -0,0 +1,4 @@
1
+ from .train_ccgan import *
2
+ from .train_cgan import *
3
+ from .train_cgan_concat import *
4
+ from .train_net_for_label_embed import *
@@ -5,10 +5,9 @@ import timeit
5
5
  from PIL import Image
6
6
  from torchvision.utils import save_image
7
7
  import torch.cuda as cutorch
8
+ import sys
8
9
 
9
- from .utils import SimpleProgressBar, IMGs_dataset
10
- from .opts import parse_opts
11
- from .DiffAugment_pytorch import DiffAugment
10
+ from myosotis_researches.CcGAN.utils import *
12
11
 
13
12
  ''' Settings '''
14
13
  args = parse_opts()
@@ -79,7 +78,8 @@ def train_ccgan(kernel_sigma, kappa, train_images, train_labels, netG, netD, net
79
78
  # printed images with labels between the 5-th quantile and 95-th quantile of training labels
80
79
  n_row=10; n_col = 1
81
80
  z_fixed = torch.randn(n_row*n_col, dim_gan, dtype=torch.float).cuda()
82
- selected_labels = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
81
+
82
+ selected_labels = np.array([0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9])
83
83
 
84
84
  y_fixed = np.zeros(n_row*n_col)
85
85
  for i in range(n_row):
@@ -6,9 +6,7 @@ import numpy as np
6
6
  import os
7
7
  import timeit
8
8
 
9
- from .utils import IMGs_dataset, SimpleProgressBar
10
- from .opts import parse_opts
11
- from .DiffAugment_pytorch import DiffAugment
9
+ from myosotis_researches.CcGAN.utils import *
12
10
 
13
11
  ''' Settings '''
14
12
  args = parse_opts()
@@ -6,9 +6,7 @@ import numpy as np
6
6
  import os
7
7
  import timeit
8
8
 
9
- from .utils import IMGs_dataset, SimpleProgressBar
10
- from .opts import parse_opts
11
- from .DiffAugment_pytorch import DiffAugment
9
+ from myosotis_researches.CcGAN.utils import *
12
10
 
13
11
  ''' Settings '''
14
12
  args = parse_opts()
@@ -3,7 +3,7 @@ from .concat_image import concat_image
3
3
  from .make_h5 import make_h5
4
4
  from .SimpleProgressBar import SimpleProgressBar
5
5
  from .IMGs_dataset import IMGs_dataset
6
- from .train import PlotLoss, compute_entropy, predict_class_labels
6
+ from .train import PlotLoss, compute_entropy, predict_class_labels, DiffAugment
7
7
  from .opts import parse_opts
8
8
 
9
9
  __all__ = [
@@ -15,5 +15,6 @@ __all__ = [
15
15
  "PlotLoss",
16
16
  "compute_entropy",
17
17
  "predict_class_labels",
18
+ "DiffAugment",
18
19
  "parse_opts"
19
20
  ]
@@ -0,0 +1,156 @@
1
+ import matplotlib as mpl
2
+ import matplotlib.pyplot as plt
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn.functional as F
6
+
7
+
8
+ def PlotLoss(loss, filename):
9
+ x_axis = np.arange(start=1, stop=len(loss) + 1)
10
+ plt.switch_backend("agg")
11
+ mpl.style.use("seaborn")
12
+ fig = plt.figure()
13
+ ax = plt.subplot(111)
14
+ ax.plot(x_axis, np.array(loss))
15
+ plt.xlabel("epoch")
16
+ plt.ylabel("training loss")
17
+ plt.legend()
18
+ # ax.legend(loc='upper center', bbox_to_anchor=(0.5, 1.15), shadow=True, ncol=3)
19
+ # plt.title('Training Loss')
20
+ plt.savefig(filename)
21
+
22
+
23
+ # compute entropy of class labels; labels is a numpy array
24
+ def compute_entropy(labels, base=None):
25
+ value, counts = np.unique(labels, return_counts=True)
26
+ norm_counts = counts / counts.sum()
27
+ base = np.e if base is None else base
28
+ return -(norm_counts * np.log(norm_counts) / np.log(base)).sum()
29
+
30
+
31
+ def predict_class_labels(net, images, batch_size=500, verbose=False, num_workers=0):
32
+ net = net.cuda()
33
+ net.eval()
34
+
35
+ n = len(images)
36
+ if batch_size > n:
37
+ batch_size = n
38
+ dataset_pred = IMGs_dataset(images, normalize=False)
39
+ dataloader_pred = torch.utils.data.DataLoader(
40
+ dataset_pred, batch_size=batch_size, shuffle=False, num_workers=num_workers
41
+ )
42
+
43
+ class_labels_pred = np.zeros(n + batch_size)
44
+ with torch.no_grad():
45
+ nimgs_got = 0
46
+ if verbose:
47
+ pb = SimpleProgressBar()
48
+ for batch_idx, batch_images in enumerate(dataloader_pred):
49
+ batch_images = batch_images.type(torch.float).cuda()
50
+ batch_size_curr = len(batch_images)
51
+
52
+ outputs, _ = net(batch_images)
53
+ _, batch_class_labels_pred = torch.max(outputs.data, 1)
54
+ class_labels_pred[nimgs_got : (nimgs_got + batch_size_curr)] = (
55
+ batch_class_labels_pred.detach().cpu().numpy().reshape(-1)
56
+ )
57
+
58
+ nimgs_got += batch_size_curr
59
+ if verbose:
60
+ pb.update((float(nimgs_got) / n) * 100)
61
+ # end for batch_idx
62
+ class_labels_pred = class_labels_pred[0:n]
63
+ return class_labels_pred
64
+
65
+
66
+ def DiffAugment(x, policy="", channels_first=True):
67
+ if policy:
68
+ if not channels_first:
69
+ x = x.permute(0, 3, 1, 2)
70
+ for p in policy.split(","):
71
+ for f in AUGMENT_FNS[p]:
72
+ x = f(x)
73
+ if not channels_first:
74
+ x = x.permute(0, 2, 3, 1)
75
+ x = x.contiguous()
76
+ return x
77
+
78
+
79
+ def rand_brightness(x):
80
+ x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5)
81
+ return x
82
+
83
+
84
+ def rand_saturation(x):
85
+ x_mean = x.mean(dim=1, keepdim=True)
86
+ x = (x - x_mean) * (
87
+ torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2
88
+ ) + x_mean
89
+ return x
90
+
91
+
92
+ def rand_contrast(x):
93
+ x_mean = x.mean(dim=[1, 2, 3], keepdim=True)
94
+ x = (x - x_mean) * (
95
+ torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5
96
+ ) + x_mean
97
+ return x
98
+
99
+
100
+ def rand_translation(x, ratio=0.125):
101
+ shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
102
+ translation_x = torch.randint(
103
+ -shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device
104
+ )
105
+ translation_y = torch.randint(
106
+ -shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device
107
+ )
108
+ grid_batch, grid_x, grid_y = torch.meshgrid(
109
+ torch.arange(x.size(0), dtype=torch.long, device=x.device),
110
+ torch.arange(x.size(2), dtype=torch.long, device=x.device),
111
+ torch.arange(x.size(3), dtype=torch.long, device=x.device),
112
+ )
113
+ grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
114
+ grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
115
+ x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
116
+ x = (
117
+ x_pad.permute(0, 2, 3, 1)
118
+ .contiguous()[grid_batch, grid_x, grid_y]
119
+ .permute(0, 3, 1, 2)
120
+ )
121
+ return x
122
+
123
+
124
+ def rand_cutout(x, ratio=0.5):
125
+ cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
126
+ offset_x = torch.randint(
127
+ 0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device
128
+ )
129
+ offset_y = torch.randint(
130
+ 0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device
131
+ )
132
+ grid_batch, grid_x, grid_y = torch.meshgrid(
133
+ torch.arange(x.size(0), dtype=torch.long, device=x.device),
134
+ torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
135
+ torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
136
+ )
137
+ grid_x = torch.clamp(
138
+ grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1
139
+ )
140
+ grid_y = torch.clamp(
141
+ grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1
142
+ )
143
+ mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
144
+ mask[grid_batch, grid_x, grid_y] = 0
145
+ x = x * mask.unsqueeze(1)
146
+ return x
147
+
148
+
149
+ AUGMENT_FNS = {
150
+ "color": [rand_brightness, rand_saturation, rand_contrast],
151
+ "translation": [rand_translation],
152
+ "cutout": [rand_cutout],
153
+ }
154
+
155
+
156
+ __all__ = ["PlotLoss", "compute_entropy", "predict_class_labels", "DiffAugment"]
@@ -2,7 +2,7 @@ from setuptools import setup, find_packages
2
2
 
3
3
  setup(
4
4
  name="Myosotis-Researches",
5
- version="0.1.7",
5
+ version="0.1.9",
6
6
  description="A repository for storing my progress of researches.",
7
7
  long_description=open("README.md").read(),
8
8
  long_description_content_type="text/markdown",
@@ -1,62 +0,0 @@
1
- LICENSE
2
- README.md
3
- pyproject.toml
4
- setup.py
5
- Myosotis_Researches.egg-info/PKG-INFO
6
- Myosotis_Researches.egg-info/SOURCES.txt
7
- Myosotis_Researches.egg-info/dependency_links.txt
8
- Myosotis_Researches.egg-info/top_level.txt
9
- myosotis_researches/__init__.py
10
- myosotis_researches/CcGAN/__init__.py
11
- myosotis_researches/CcGAN/internal/__init__.py
12
- myosotis_researches/CcGAN/internal/install_datasets.py
13
- myosotis_researches/CcGAN/internal/show_datasets.py
14
- myosotis_researches/CcGAN/internal/uninstall_datasets.py
15
- myosotis_researches/CcGAN/models_128/CcGAN_SAGAN.py
16
- myosotis_researches/CcGAN/models_128/ResNet_class_eval.py
17
- myosotis_researches/CcGAN/models_128/ResNet_embed.py
18
- myosotis_researches/CcGAN/models_128/ResNet_regre_eval.py
19
- myosotis_researches/CcGAN/models_128/__init__.py
20
- myosotis_researches/CcGAN/models_128/autoencoder.py
21
- myosotis_researches/CcGAN/models_128/cGAN_SAGAN.py
22
- myosotis_researches/CcGAN/models_128/cGAN_concat_SAGAN.py
23
- myosotis_researches/CcGAN/models_256/CcGAN_SAGAN.py
24
- myosotis_researches/CcGAN/models_256/ResNet_class_eval.py
25
- myosotis_researches/CcGAN/models_256/ResNet_embed.py
26
- myosotis_researches/CcGAN/models_256/ResNet_regre_eval.py
27
- myosotis_researches/CcGAN/models_256/__init__.py
28
- myosotis_researches/CcGAN/models_256/autoencoder.py
29
- myosotis_researches/CcGAN/models_256/cGAN_SAGAN.py
30
- myosotis_researches/CcGAN/models_256/cGAN_concat_SAGAN.py
31
- myosotis_researches/CcGAN/train_128/DiffAugment_pytorch.py
32
- myosotis_researches/CcGAN/train_128/__init__.py
33
- myosotis_researches/CcGAN/train_128/eval_metrics.py
34
- myosotis_researches/CcGAN/train_128/opts.py
35
- myosotis_researches/CcGAN/train_128/pretrain_AE.py
36
- myosotis_researches/CcGAN/train_128/pretrain_CNN_class.py
37
- myosotis_researches/CcGAN/train_128/pretrain_CNN_regre.py
38
- myosotis_researches/CcGAN/train_128/train_ccgan.py
39
- myosotis_researches/CcGAN/train_128/train_cgan.py
40
- myosotis_researches/CcGAN/train_128/train_cgan_concat.py
41
- myosotis_researches/CcGAN/train_128/train_net_for_label_embed.py
42
- myosotis_researches/CcGAN/train_128/utils.py
43
- myosotis_researches/CcGAN/train_128_output_10/DiffAugment_pytorch.py
44
- myosotis_researches/CcGAN/train_128_output_10/__init__.py
45
- myosotis_researches/CcGAN/train_128_output_10/eval_metrics.py
46
- myosotis_researches/CcGAN/train_128_output_10/opts.py
47
- myosotis_researches/CcGAN/train_128_output_10/pretrain_AE.py
48
- myosotis_researches/CcGAN/train_128_output_10/pretrain_CNN_class.py
49
- myosotis_researches/CcGAN/train_128_output_10/pretrain_CNN_regre.py
50
- myosotis_researches/CcGAN/train_128_output_10/train_ccgan.py
51
- myosotis_researches/CcGAN/train_128_output_10/train_cgan.py
52
- myosotis_researches/CcGAN/train_128_output_10/train_cgan_concat.py
53
- myosotis_researches/CcGAN/train_128_output_10/train_net_for_label_embed.py
54
- myosotis_researches/CcGAN/train_128_output_10/utils.py
55
- myosotis_researches/CcGAN/utils/IMGs_dataset.py
56
- myosotis_researches/CcGAN/utils/SimpleProgressBar.py
57
- myosotis_researches/CcGAN/utils/__init__.py
58
- myosotis_researches/CcGAN/utils/concat_image.py
59
- myosotis_researches/CcGAN/utils/make_h5.py
60
- myosotis_researches/CcGAN/utils/opts.py
61
- myosotis_researches/CcGAN/utils/print_hdf5.py
62
- myosotis_researches/CcGAN/utils/train.py
@@ -1,301 +0,0 @@
1
- '''
2
-
3
- Adapted from https://github.com/voletiv/self-attention-GAN-pytorch/blob/master/sagan_models.py
4
-
5
-
6
- '''
7
-
8
-
9
- import numpy as np
10
- import torch
11
- import torch.nn as nn
12
- import torch.nn.functional as F
13
-
14
- from torch.nn.utils import spectral_norm
15
- from torch.nn.init import xavier_uniform_
16
-
17
-
18
- def init_weights(m):
19
- if type(m) == nn.Linear or type(m) == nn.Conv2d:
20
- xavier_uniform_(m.weight)
21
- if m.bias is not None:
22
- m.bias.data.fill_(0.)
23
-
24
-
25
- def snconv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
26
- return spectral_norm(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
27
- stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias))
28
-
29
- def snlinear(in_features, out_features, bias=True):
30
- return spectral_norm(nn.Linear(in_features=in_features, out_features=out_features, bias=bias))
31
-
32
-
33
-
34
- class Self_Attn(nn.Module):
35
- """ Self attention Layer"""
36
-
37
- def __init__(self, in_channels):
38
- super(Self_Attn, self).__init__()
39
- self.in_channels = in_channels
40
- self.snconv1x1_theta = snconv2d(in_channels=in_channels, out_channels=in_channels//8, kernel_size=1, stride=1, padding=0)
41
- self.snconv1x1_phi = snconv2d(in_channels=in_channels, out_channels=in_channels//8, kernel_size=1, stride=1, padding=0)
42
- self.snconv1x1_g = snconv2d(in_channels=in_channels, out_channels=in_channels//2, kernel_size=1, stride=1, padding=0)
43
- self.snconv1x1_attn = snconv2d(in_channels=in_channels//2, out_channels=in_channels, kernel_size=1, stride=1, padding=0)
44
- self.maxpool = nn.MaxPool2d(2, stride=2, padding=0)
45
- self.softmax = nn.Softmax(dim=-1)
46
- self.sigma = nn.Parameter(torch.zeros(1))
47
-
48
- def forward(self, x):
49
- """
50
- inputs :
51
- x : input feature maps(B X C X W X H)
52
- returns :
53
- out : self attention value + input feature
54
- attention: B X N X N (N is Width*Height)
55
- """
56
- _, ch, h, w = x.size()
57
- # Theta path
58
- theta = self.snconv1x1_theta(x)
59
- theta = theta.view(-1, ch//8, h*w)
60
- # Phi path
61
- phi = self.snconv1x1_phi(x)
62
- phi = self.maxpool(phi)
63
- phi = phi.view(-1, ch//8, h*w//4)
64
- # Attn map
65
- attn = torch.bmm(theta.permute(0, 2, 1), phi)
66
- attn = self.softmax(attn)
67
- # g path
68
- g = self.snconv1x1_g(x)
69
- g = self.maxpool(g)
70
- g = g.view(-1, ch//2, h*w//4)
71
- # Attn_g
72
- attn_g = torch.bmm(g, attn.permute(0, 2, 1))
73
- attn_g = attn_g.view(-1, ch//2, h, w)
74
- attn_g = self.snconv1x1_attn(attn_g)
75
- # Out
76
- out = x + self.sigma*attn_g
77
- return out
78
-
79
-
80
-
81
-
82
- '''
83
-
84
- Generator
85
-
86
- '''
87
-
88
-
89
- class ConditionalBatchNorm2d(nn.Module):
90
- def __init__(self, num_features, dim_embed):
91
- super().__init__()
92
- self.num_features = num_features
93
- self.bn = nn.BatchNorm2d(num_features, momentum=0.001, affine=False)
94
- self.embed_gamma = nn.Linear(dim_embed, num_features, bias=False)
95
- self.embed_beta = nn.Linear(dim_embed, num_features, bias=False)
96
-
97
- def forward(self, x, y):
98
- out = self.bn(x)
99
- gamma = self.embed_gamma(y).view(-1, self.num_features, 1, 1)
100
- beta = self.embed_beta(y).view(-1, self.num_features, 1, 1)
101
- out = out + gamma*out + beta
102
- return out
103
-
104
-
105
- class GenBlock(nn.Module):
106
- def __init__(self, in_channels, out_channels, dim_embed):
107
- super(GenBlock, self).__init__()
108
- self.cond_bn1 = ConditionalBatchNorm2d(in_channels, dim_embed)
109
- self.relu = nn.ReLU(inplace=True)
110
- self.snconv2d1 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
111
- self.cond_bn2 = ConditionalBatchNorm2d(out_channels, dim_embed)
112
- self.snconv2d2 = snconv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
113
- self.snconv2d0 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0)
114
-
115
- def forward(self, x, labels):
116
- x0 = x
117
-
118
- x = self.cond_bn1(x, labels)
119
- x = self.relu(x)
120
- x = F.interpolate(x, scale_factor=2, mode='nearest') # upsample
121
- x = self.snconv2d1(x)
122
- x = self.cond_bn2(x, labels)
123
- x = self.relu(x)
124
- x = self.snconv2d2(x)
125
-
126
- x0 = F.interpolate(x0, scale_factor=2, mode='nearest') # upsample
127
- x0 = self.snconv2d0(x0)
128
-
129
- out = x + x0
130
- return out
131
-
132
-
133
- class CcGAN_SAGAN_Generator(nn.Module):
134
- """Generator."""
135
-
136
- def __init__(self, dim_z, dim_embed=128, nc=3, gene_ch=64):
137
- super(CcGAN_SAGAN_Generator, self).__init__()
138
-
139
- self.dim_z = dim_z
140
- self.gene_ch = gene_ch
141
-
142
- self.snlinear0 = snlinear(in_features=dim_z, out_features=gene_ch*16*4*4)
143
- self.block1 = GenBlock(gene_ch*16, gene_ch*16, dim_embed)
144
- self.block2 = GenBlock(gene_ch*16, gene_ch*8, dim_embed)
145
- self.block3 = GenBlock(gene_ch*8, gene_ch*4, dim_embed)
146
- self.self_attn = Self_Attn(gene_ch*4)
147
- self.block4 = GenBlock(gene_ch*4, gene_ch*2, dim_embed)
148
- self.block5 = GenBlock(gene_ch*2, gene_ch, dim_embed)
149
- self.bn = nn.BatchNorm2d(gene_ch, eps=1e-5, momentum=0.0001, affine=True)
150
- self.relu = nn.ReLU(inplace=True)
151
- self.snconv2d1 = snconv2d(in_channels=gene_ch, out_channels=nc, kernel_size=3, stride=1, padding=1)
152
- self.tanh = nn.Tanh()
153
-
154
- # Weight init
155
- self.apply(init_weights)
156
-
157
- def forward(self, z, labels):
158
- # n x dim_z
159
- out = self.snlinear0(z) # 4*4
160
- out = out.view(-1, self.gene_ch*16, 4, 4) # 4 x 4
161
- out = self.block1(out, labels) # 8 x 8
162
- out = self.block2(out, labels) # 16 x 16
163
- out = self.block3(out, labels) # 32 x 32
164
- out = self.self_attn(out) # 32 x 32
165
- out = self.block4(out, labels) # 64 x 64
166
- out = self.block5(out, labels) # 128 x 128
167
- out = self.bn(out)
168
- out = self.relu(out)
169
- out = self.snconv2d1(out)
170
- out = self.tanh(out)
171
- return out
172
-
173
-
174
-
175
- '''
176
-
177
- Discriminator
178
-
179
- '''
180
-
181
- class DiscOptBlock(nn.Module):
182
- def __init__(self, in_channels, out_channels):
183
- super(DiscOptBlock, self).__init__()
184
- self.snconv2d1 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
185
- self.relu = nn.ReLU(inplace=True)
186
- self.snconv2d2 = snconv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
187
- self.downsample = nn.AvgPool2d(2)
188
- self.snconv2d0 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0)
189
-
190
- def forward(self, x):
191
- x0 = x
192
-
193
- x = self.snconv2d1(x)
194
- x = self.relu(x)
195
- x = self.snconv2d2(x)
196
- x = self.downsample(x)
197
-
198
- x0 = self.downsample(x0)
199
- x0 = self.snconv2d0(x0)
200
-
201
- out = x + x0
202
- return out
203
-
204
-
205
- class DiscBlock(nn.Module):
206
- def __init__(self, in_channels, out_channels):
207
- super(DiscBlock, self).__init__()
208
- self.relu = nn.ReLU(inplace=True)
209
- self.snconv2d1 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
210
- self.snconv2d2 = snconv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
211
- self.downsample = nn.AvgPool2d(2)
212
- self.ch_mismatch = False
213
- if in_channels != out_channels:
214
- self.ch_mismatch = True
215
- self.snconv2d0 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0)
216
-
217
- def forward(self, x, downsample=True):
218
- x0 = x
219
-
220
- x = self.relu(x)
221
- x = self.snconv2d1(x)
222
- x = self.relu(x)
223
- x = self.snconv2d2(x)
224
- if downsample:
225
- x = self.downsample(x)
226
-
227
- if downsample or self.ch_mismatch:
228
- x0 = self.snconv2d0(x0)
229
- if downsample:
230
- x0 = self.downsample(x0)
231
-
232
- out = x + x0
233
- return out
234
-
235
-
236
- class CcGAN_SAGAN_Discriminator(nn.Module):
237
- """Discriminator."""
238
-
239
- def __init__(self, dim_embed=128, nc=3, disc_ch=64):
240
- super(CcGAN_SAGAN_Discriminator, self).__init__()
241
- self.disc_ch = disc_ch
242
- self.opt_block1 = DiscOptBlock(nc, disc_ch)
243
- self.block1 = DiscBlock(disc_ch, disc_ch*2)
244
- self.self_attn = Self_Attn(disc_ch*2)
245
- self.block2 = DiscBlock(disc_ch*2, disc_ch*4)
246
- self.block3 = DiscBlock(disc_ch*4, disc_ch*8)
247
- self.block4 = DiscBlock(disc_ch*8, disc_ch*16)
248
- self.block5 = DiscBlock(disc_ch*16, disc_ch*16)
249
- self.relu = nn.ReLU(inplace=True)
250
- self.snlinear1 = snlinear(in_features=disc_ch*16*4*4, out_features=1)
251
- self.sn_embedding1 = snlinear(dim_embed, disc_ch*16*4*4, bias=False)
252
-
253
- # Weight init
254
- self.apply(init_weights)
255
- xavier_uniform_(self.sn_embedding1.weight)
256
-
257
- def forward(self, x, labels):
258
- # 128x128
259
- out = self.opt_block1(x) # 128x128
260
- out = self.block1(out) # 64 x 64
261
- out = self.self_attn(out) # 64 x 64
262
- out = self.block2(out) # 32 x 32
263
- out = self.block3(out) # 16 x 16
264
- out = self.block4(out) # 8 x 8
265
- out = self.block5(out, downsample=False) # 4 x 4
266
- out = self.relu(out) # n x disc_ch*16 x 4 x 4
267
- out = out.view(-1, self.disc_ch*16*4*4)
268
- output1 = torch.squeeze(self.snlinear1(out)) # n
269
- # Projection
270
- h_labels = self.sn_embedding1(labels) # n x disc_ch*16
271
- proj = torch.mul(out, h_labels) # n x disc_ch*16
272
- output2 = torch.sum(proj, dim=[1]) # n
273
- # Out
274
- output = output1 + output2 # n
275
- return output
276
-
277
-
278
- if __name__ == "__main__":
279
-
280
- def get_parameter_number(net):
281
- total_num = sum(p.numel() for p in net.parameters())
282
- trainable_num = sum(p.numel() for p in net.parameters() if p.requires_grad)
283
- return {'Total': total_num, 'Trainable': trainable_num}
284
-
285
-
286
- netG = CcGAN_SAGAN_Generator(dim_z=256, dim_embed=128, gene_ch=128).cuda()
287
- netD = CcGAN_SAGAN_Discriminator(dim_embed=128, disc_ch=128).cuda()
288
-
289
- # netG = nn.DataParallel(netG)
290
- # netD = nn.DataParallel(netD)
291
-
292
- N=4
293
- z = torch.randn(N, 256).cuda()
294
- y = torch.randn(N, 128).cuda()
295
- x = netG(z,y)
296
- o = netD(x,y)
297
- print(x.size())
298
- print(o.size())
299
-
300
- print('G:', get_parameter_number(netG))
301
- print('D:', get_parameter_number(netD))