Myosotis-Researches 0.1.7__py3-none-any.whl → 0.1.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- myosotis_researches/CcGAN/utils/__init__.py +2 -1
- myosotis_researches/CcGAN/utils/train.py +94 -3
- {myosotis_researches-0.1.7.dist-info → myosotis_researches-0.1.8.dist-info}/METADATA +1 -1
- {myosotis_researches-0.1.7.dist-info → myosotis_researches-0.1.8.dist-info}/RECORD +7 -7
- {myosotis_researches-0.1.7.dist-info → myosotis_researches-0.1.8.dist-info}/WHEEL +0 -0
- {myosotis_researches-0.1.7.dist-info → myosotis_researches-0.1.8.dist-info}/licenses/LICENSE +0 -0
- {myosotis_researches-0.1.7.dist-info → myosotis_researches-0.1.8.dist-info}/top_level.txt +0 -0
@@ -3,7 +3,7 @@ from .concat_image import concat_image
|
|
3
3
|
from .make_h5 import make_h5
|
4
4
|
from .SimpleProgressBar import SimpleProgressBar
|
5
5
|
from .IMGs_dataset import IMGs_dataset
|
6
|
-
from .train import PlotLoss, compute_entropy, predict_class_labels
|
6
|
+
from .train import PlotLoss, compute_entropy, predict_class_labels, DiffAugment
|
7
7
|
from .opts import parse_opts
|
8
8
|
|
9
9
|
__all__ = [
|
@@ -15,5 +15,6 @@ __all__ = [
|
|
15
15
|
"PlotLoss",
|
16
16
|
"compute_entropy",
|
17
17
|
"predict_class_labels",
|
18
|
+
"DiffAugment",
|
18
19
|
"parse_opts"
|
19
20
|
]
|
@@ -1,7 +1,8 @@
|
|
1
|
+
import matplotlib as mpl
|
2
|
+
import matplotlib.pyplot as plt
|
1
3
|
import numpy as np
|
2
4
|
import torch
|
3
|
-
import
|
4
|
-
import matplotlib as mpl
|
5
|
+
import torch.nn.functional as F
|
5
6
|
|
6
7
|
|
7
8
|
def PlotLoss(loss, filename):
|
@@ -62,4 +63,94 @@ def predict_class_labels(net, images, batch_size=500, verbose=False, num_workers
|
|
62
63
|
return class_labels_pred
|
63
64
|
|
64
65
|
|
65
|
-
|
66
|
+
def DiffAugment(x, policy="", channels_first=True):
    """Apply differentiable augmentations (DiffAugment) to a batch of images.

    Args:
        x: 4-D image tensor. Indexed as ``(N, C, H, W)`` when
            ``channels_first`` is True, otherwise as ``(N, H, W, C)``; in the
            latter case it is permuted to channels-first for the augmentations
            and permuted back afterwards.
        policy: Comma-separated names of augmentation groups looked up in
            ``AUGMENT_FNS`` (e.g. ``"color,translation,cutout"``). Whitespace
            around names and empty entries are ignored. A falsy value disables
            augmentation and returns ``x`` itself.
        channels_first: Layout of ``x`` (see above).

    Returns:
        The augmented tensor in the same layout as ``x``; made contiguous
        whenever a non-empty ``policy`` is given.

    Raises:
        KeyError: If ``policy`` names a group that is not in ``AUGMENT_FNS``.
    """
    if not policy:
        return x

    # Normalise "color, translation" / trailing commas into clean names.
    names = [name.strip() for name in policy.split(",")]
    names = [name for name in names if name]

    # Validate before touching the tensor (or consuming any RNG state) so a
    # typo fails fast with a readable message.
    unknown = [name for name in names if name not in AUGMENT_FNS]
    if unknown:
        raise KeyError(
            f"Unknown DiffAugment policy {unknown}; "
            f"expected names from {sorted(AUGMENT_FNS)}"
        )

    if not channels_first:
        x = x.permute(0, 3, 1, 2)
    for name in names:
        for fn in AUGMENT_FNS[name]:
            x = fn(x)
    if not channels_first:
        x = x.permute(0, 2, 3, 1)
    return x.contiguous()
|
77
|
+
|
78
|
+
|
79
|
+
def rand_brightness(x):
    """Shift every sample of a batch by its own random brightness offset.

    One offset per sample (dim 0), drawn uniformly from [-0.5, 0.5), is added
    to all of that sample's elements. The offset uses ``x``'s dtype and device.
    """
    offset = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5
    return x + offset
|
82
|
+
|
83
|
+
|
84
|
+
def rand_saturation(x):
    """Randomly rescale each sample's deviation from its channel mean.

    The mean is taken over dim 1 (kept for broadcasting); a per-sample factor
    drawn uniformly from [0, 2) multiplies the distance of each element from
    that mean.
    """
    channel_mean = x.mean(dim=1, keepdim=True)
    factor = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2
    return channel_mean + (x - channel_mean) * factor
|
90
|
+
|
91
|
+
|
92
|
+
def rand_contrast(x):
    """Randomly scale each sample's contrast about its overall mean.

    The mean is taken over dims 1-3 per sample; a per-sample gain drawn
    uniformly from [0.5, 1.5) multiplies each element's deviation from it.
    """
    per_sample_mean = x.mean(dim=[1, 2, 3], keepdim=True)
    gain = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5
    return per_sample_mean + (x - per_sample_mean) * gain
|
98
|
+
|
99
|
+
|
100
|
+
def rand_translation(x, ratio=0.125):
    """Randomly translate each image of a channels-first batch, zero-filling.

    Each sample gets its own integer shift along dims 2 and 3, drawn uniformly
    from ``[-s, s]`` with ``s = int(size * ratio + 0.5)``. Source positions
    that fall outside the image are clamped onto a 1-pixel zero border, so
    vacated pixels become 0.

    Args:
        x: Tensor indexed as ``(N, C, H, W)``.
        ratio: Maximum shift as a fraction of each spatial size.

    Returns:
        Tensor with the same shape as ``x``; all channels of one sample share
        the same shift.
    """
    shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
    translation_x = torch.randint(
        -shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device
    )
    translation_y = torch.randint(
        -shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device
    )
    # indexing="ij" is the historical implicit default; passing it explicitly
    # keeps identical results and avoids torch.meshgrid's deprecation warning.
    grid_batch, grid_x, grid_y = torch.meshgrid(
        torch.arange(x.size(0), dtype=torch.long, device=x.device),
        torch.arange(x.size(2), dtype=torch.long, device=x.device),
        torch.arange(x.size(3), dtype=torch.long, device=x.device),
        indexing="ij",
    )
    # +1 compensates for the 1-pixel pad; clamping sends out-of-range sources
    # to the zero border (padded indices 0 and size + 1).
    grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
    grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
    x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
    # Gather in channels-last layout so one advanced index picks whole pixels.
    x = (
        x_pad.permute(0, 2, 3, 1)
        .contiguous()[grid_batch, grid_x, grid_y]
        .permute(0, 3, 1, 2)
    )
    return x
|
122
|
+
|
123
|
+
|
124
|
+
def rand_cutout(x, ratio=0.5):
    """Zero out one random rectangle per image of a channels-first batch.

    The rectangle spans ``int(H * ratio + 0.5)`` by ``int(W * ratio + 0.5)``
    positions, roughly centred at a per-sample random offset, and is clipped
    at the image border. The same mask is applied to every channel of a sample.

    Args:
        x: Tensor indexed as ``(N, C, H, W)``.
        ratio: Rectangle size as a fraction of each spatial size.

    Returns:
        ``x`` multiplied by a 0/1 mask of the same dtype and device.
    """
    cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
    offset_x = torch.randint(
        0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device
    )
    offset_y = torch.randint(
        0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device
    )
    # indexing="ij" is the historical implicit default; passing it explicitly
    # keeps identical results and avoids torch.meshgrid's deprecation warning.
    grid_batch, grid_x, grid_y = torch.meshgrid(
        torch.arange(x.size(0), dtype=torch.long, device=x.device),
        torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
        torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
        indexing="ij",
    )
    # Shift the rectangle so the offset sits near its centre, then clip to
    # valid pixel indices (clipped rows/cols simply overlap the border).
    grid_x = torch.clamp(
        grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1
    )
    grid_y = torch.clamp(
        grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1
    )
    mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
    mask[grid_batch, grid_x, grid_y] = 0
    # Broadcast the per-sample mask across the channel dimension.
    x = x * mask.unsqueeze(1)
    return x
|
147
|
+
|
148
|
+
|
149
|
+
# Maps each DiffAugment policy name to the ordered list of augmentation
# callables applied for it; each callable takes and returns a batch tensor.
AUGMENT_FNS = dict(
    color=[rand_brightness, rand_saturation, rand_contrast],
    translation=[rand_translation],
    cutout=[rand_cutout],
)
|
154
|
+
|
155
|
+
|
156
|
+
# Names re-exported by ``from ... import *``; the rand_* helpers and
# AUGMENT_FNS remain importable explicitly but are not listed here.
__all__ = [
    "PlotLoss",
    "compute_entropy",
    "predict_class_labels",
    "DiffAugment",
]
|
@@ -46,14 +46,14 @@ myosotis_researches/CcGAN/train_128_output_10/train_net_for_label_embed.py,sha25
|
|
46
46
|
myosotis_researches/CcGAN/train_128_output_10/utils.py,sha256=B-V6ct4WDisVVCOLO0W7VIBL8StPVNJJTZZ2b2NkMFU,3766
|
47
47
|
myosotis_researches/CcGAN/utils/IMGs_dataset.py,sha256=i45PBNSCeAEB5uUG0SluYRTuHWZwH_5ldz2wm6afkYs,927
|
48
48
|
myosotis_researches/CcGAN/utils/SimpleProgressBar.py,sha256=S4eD_m6ysHRMHAmRtkTXVRNfXTR8kuHv-d3lUN0BVn4,546
|
49
|
-
myosotis_researches/CcGAN/utils/__init__.py,sha256=
|
49
|
+
myosotis_researches/CcGAN/utils/__init__.py,sha256=em3aB0C-V230NQtT64hyuHGo4CjV6p2DwIdtNM0dk4k,516
|
50
50
|
myosotis_researches/CcGAN/utils/concat_image.py,sha256=BIGKz52Inn9S7M5fBFKye2V9bLJ0DqEQILoOVWAXUiE,2165
|
51
51
|
myosotis_researches/CcGAN/utils/make_h5.py,sha256=VtFYjr_i-JktsEW_BvofpilcDmChRmyLykv0VvlMuY0,963
|
52
52
|
myosotis_researches/CcGAN/utils/opts.py,sha256=pd7-wknNPBO5hWRpO3YAPmmAsPKgZUUpKc4gWMs6Wto,5397
|
53
53
|
myosotis_researches/CcGAN/utils/print_hdf5.py,sha256=VvmNAWtMDmg6D9V6ZbSUXrQTKRh9WIJeC4BR_ORJkco,300
|
54
|
-
myosotis_researches/CcGAN/utils/train.py,sha256=
|
55
|
-
myosotis_researches-0.1.
|
56
|
-
myosotis_researches-0.1.
|
57
|
-
myosotis_researches-0.1.
|
58
|
-
myosotis_researches-0.1.
|
59
|
-
myosotis_researches-0.1.
|
54
|
+
myosotis_researches/CcGAN/utils/train.py,sha256=5ZXgkGesuInqUooJRpLej_KHqYQtlSDq90_5wig5elQ,5152
|
55
|
+
myosotis_researches-0.1.8.dist-info/licenses/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
|
56
|
+
myosotis_researches-0.1.8.dist-info/METADATA,sha256=BWkcdFq2IMeEH_rmMoeEWxwA4H79nPc1DypDZ0DW_cM,2663
|
57
|
+
myosotis_researches-0.1.8.dist-info/WHEEL,sha256=ck4Vq1_RXyvS4Jt6SI0Vz6fyVs4GWg7AINwpsaGEgPE,91
|
58
|
+
myosotis_researches-0.1.8.dist-info/top_level.txt,sha256=zxAiMn5eyZNJM28MewTAkgi_RZJMbfWbzVR-KF0LdZE,20
|
59
|
+
myosotis_researches-0.1.8.dist-info/RECORD,,
|
File without changes
|
{myosotis_researches-0.1.7.dist-info → myosotis_researches-0.1.8.dist-info}/licenses/LICENSE
RENAMED
File without changes
|
File without changes
|