Myosotis-Researches 0.1.7__py3-none-any.whl → 0.1.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,7 +3,7 @@ from .concat_image import concat_image
3
3
  from .make_h5 import make_h5
4
4
  from .SimpleProgressBar import SimpleProgressBar
5
5
  from .IMGs_dataset import IMGs_dataset
6
- from .train import PlotLoss, compute_entropy, predict_class_labels
6
+ from .train import PlotLoss, compute_entropy, predict_class_labels, DiffAugment
7
7
  from .opts import parse_opts
8
8
 
9
9
  __all__ = [
@@ -15,5 +15,6 @@ __all__ = [
15
15
  "PlotLoss",
16
16
  "compute_entropy",
17
17
  "predict_class_labels",
18
+ "DiffAugment",
18
19
  "parse_opts"
19
20
  ]
@@ -1,7 +1,8 @@
1
+ import matplotlib as mpl
2
+ import matplotlib.pyplot as plt
1
3
  import numpy as np
2
4
  import torch
3
- import matplotlib.pyplot as plt
4
- import matplotlib as mpl
5
+ import torch.nn.functional as F
5
6
 
6
7
 
7
8
  def PlotLoss(loss, filename):
@@ -62,4 +63,94 @@ def predict_class_labels(net, images, batch_size=500, verbose=False, num_workers
62
63
  return class_labels_pred
63
64
 
64
65
 
65
- __all__ = ["PlotLoss", "compute_entropy", "predict_class_labels"]
66
def DiffAugment(x, policy="", channels_first=True):
    """Apply DiffAugment-style augmentations to a batch of images.

    The augmentations are built from ordinary tensor arithmetic and indexing,
    so autograd can propagate gradients through them back to ``x``.

    Args:
        x: 4-D image batch, laid out as (N, C, H, W) when ``channels_first``
            is true and as (N, H, W, C) otherwise.
        policy: Comma-separated policy names looked up in ``AUGMENT_FNS``
            (e.g. ``"color,translation,cutout"``). Whitespace around names
            and empty entries (such as a trailing comma) are ignored. An
            empty string disables augmentation.
        channels_first: Layout of ``x``. Channels-last input is temporarily
            permuted to channels-first because the augmentation functions
            operate on (N, C, H, W) tensors.

    Returns:
        The augmented batch in the same layout as ``x``. When ``policy`` is
        non-empty the result is made contiguous; otherwise ``x`` itself is
        returned unchanged.

    Raises:
        KeyError: If a policy name is not a key of ``AUGMENT_FNS``.
    """
    if not policy:
        return x
    if not channels_first:
        x = x.permute(0, 3, 1, 2)
    for name in policy.split(","):
        name = name.strip()
        if not name:
            # Tolerate "color," or "color, ,cutout" instead of failing on "".
            continue
        try:
            fns = AUGMENT_FNS[name]
        except KeyError:
            raise KeyError(
                f"unknown DiffAugment policy {name!r}; "
                f"expected one of {sorted(AUGMENT_FNS)}"
            ) from None
        for fn in fns:
            x = fn(x)
    if not channels_first:
        x = x.permute(0, 2, 3, 1)
    # Permutes/indexing above can leave a non-contiguous view.
    return x.contiguous()
77
+
78
+
79
def rand_brightness(x):
    """Shift the intensity of every sample by its own random offset.

    One offset per batch element is drawn uniformly from [-0.5, 0.5) and
    added to all values of that element; the (N, 1, 1, 1) offset broadcasts
    over the channel and spatial dimensions of a 4-D (N, C, H, W) batch.

    Args:
        x: Image batch, channels-first.

    Returns:
        A new tensor with the same shape, dtype and device as ``x``.
    """
    batch_size = x.size(0)
    offset = torch.rand(batch_size, 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5
    return x + offset
82
+
83
+
84
def rand_saturation(x):
    """Randomly scale each pixel's deviation from its cross-channel mean.

    For every pixel the mean over the channel dimension (dim 1) is computed;
    the deviation of each channel from that mean is multiplied by a factor
    drawn uniformly from [0, 2), one factor per batch element. A factor of 0
    collapses all channels onto their mean.

    Args:
        x: Image batch, channels-first (N, C, H, W).

    Returns:
        A new tensor with the same shape, dtype and device as ``x``.
    """
    channel_mean = x.mean(dim=1, keepdim=True)
    scale = 2 * torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device)
    return channel_mean + (x - channel_mean) * scale
90
+
91
+
92
def rand_contrast(x):
    """Randomly scale each sample's deviation from its overall mean.

    The mean of every sample is taken over dims 1-3 (channels and spatial
    positions); deviations from it are multiplied by a factor drawn
    uniformly from [0.5, 1.5), one factor per batch element.

    Args:
        x: Image batch, channels-first (N, C, H, W).

    Returns:
        A new tensor with the same shape, dtype and device as ``x``.
    """
    sample_mean = x.mean(dim=[1, 2, 3], keepdim=True)
    factor = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5
    return sample_mean + (x - sample_mean) * factor
98
+
99
+
100
def rand_translation(x, ratio=0.125):
    """Randomly translate each sample, filling uncovered pixels with zeros.

    Each batch element gets an integer shift drawn uniformly from
    [-round(H * ratio), round(H * ratio)] rows and
    [-round(W * ratio), round(W * ratio)] columns (half-up rounding). The
    output pixel (i, j) of a sample takes the input pixel
    (i + shift_rows, j + shift_cols), or 0 when that falls outside the image.
    All channels of a sample share the same shift.

    Args:
        x: Image batch, channels-first (N, C, H, W).
        ratio: Maximum shift as a fraction of the image height/width.

    Returns:
        A tensor of the same shape as ``x`` (possibly a non-contiguous view).
    """
    batch, height, width = x.size(0), x.size(2), x.size(3)
    shift_x = int(height * ratio + 0.5)
    shift_y = int(width * ratio + 0.5)
    translation_x = torch.randint(
        -shift_x, shift_x + 1, size=[batch, 1, 1], device=x.device
    )
    translation_y = torch.randint(
        -shift_y, shift_y + 1, size=[batch, 1, 1], device=x.device
    )
    # Explicit "ij" indexing: same result as the legacy default, without the
    # deprecation warning torch.meshgrid emits when `indexing` is omitted.
    grid_batch, grid_x, grid_y = torch.meshgrid(
        torch.arange(batch, dtype=torch.long, device=x.device),
        torch.arange(height, dtype=torch.long, device=x.device),
        torch.arange(width, dtype=torch.long, device=x.device),
        indexing="ij",
    )
    # +1 compensates for the one-pixel zero border added below; clamping into
    # [0, size + 1] redirects every out-of-range source index onto that border.
    grid_x = torch.clamp(grid_x + translation_x + 1, 0, height + 1)
    grid_y = torch.clamp(grid_y + translation_y + 1, 0, width + 1)
    # Pad H and W by one zero pixel on each side; N and C are left untouched.
    x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
    # Gather in channels-last layout so the (batch, row, col) index grids pick
    # whole channel vectors, then restore channels-first.
    return (
        x_pad.permute(0, 2, 3, 1)
        .contiguous()[grid_batch, grid_x, grid_y]
        .permute(0, 3, 1, 2)
    )
122
+
123
+
124
def rand_cutout(x, ratio=0.5):
    """Zero out one random rectangle in every sample (Cutout).

    The rectangle is round(H * ratio) by round(W * ratio) pixels (half-up
    rounding), placed around a random offset per batch element. Indices that
    fall outside the image are clamped to the border, so the rectangle is
    effectively clipped at the image edges. The same mask is applied to
    every channel of a sample.

    Args:
        x: Image batch, channels-first (N, C, H, W).
        ratio: Rectangle size as a fraction of the image height/width.

    Returns:
        ``x`` multiplied by a 0/1 mask of matching dtype and device.
    """
    batch, height, width = x.size(0), x.size(2), x.size(3)
    cut_h = int(height * ratio + 0.5)
    cut_w = int(width * ratio + 0.5)
    # For even sizes the offset may also equal height/width, so a partially
    # clipped rectangle is possible equally at both borders.
    offset_x = torch.randint(
        0, height + (1 - cut_h % 2), size=[batch, 1, 1], device=x.device
    )
    offset_y = torch.randint(
        0, width + (1 - cut_w % 2), size=[batch, 1, 1], device=x.device
    )
    # Explicit "ij" indexing: same result as the legacy default, without the
    # deprecation warning torch.meshgrid emits when `indexing` is omitted.
    grid_batch, grid_x, grid_y = torch.meshgrid(
        torch.arange(batch, dtype=torch.long, device=x.device),
        torch.arange(cut_h, dtype=torch.long, device=x.device),
        torch.arange(cut_w, dtype=torch.long, device=x.device),
        indexing="ij",
    )
    grid_x = torch.clamp(grid_x + offset_x - cut_h // 2, min=0, max=height - 1)
    grid_y = torch.clamp(grid_y + offset_y - cut_w // 2, min=0, max=width - 1)
    mask = torch.ones(batch, height, width, dtype=x.dtype, device=x.device)
    mask[grid_batch, grid_x, grid_y] = 0
    # unsqueeze(1) broadcasts the per-sample mask across all channels.
    return x * mask.unsqueeze(1)
147
+
148
+
149
# Registry of DiffAugment policies: each policy name maps to the ordered list
# of augmentation functions DiffAugment applies for it.
AUGMENT_FNS = dict(
    color=[rand_brightness, rand_saturation, rand_contrast],
    translation=[rand_translation],
    cutout=[rand_cutout],
)


# Public names exported by this module.
__all__ = [
    "PlotLoss",
    "compute_entropy",
    "predict_class_labels",
    "DiffAugment",
]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: Myosotis-Researches
3
- Version: 0.1.7
3
+ Version: 0.1.8
4
4
  Summary: A repository for storing my progress of researches.
5
5
  Home-page: https://github.com/Zeyu-Xie/Myosotis-Researches
6
6
  Author: Zeyu Xie
@@ -46,14 +46,14 @@ myosotis_researches/CcGAN/train_128_output_10/train_net_for_label_embed.py,sha25
46
46
  myosotis_researches/CcGAN/train_128_output_10/utils.py,sha256=B-V6ct4WDisVVCOLO0W7VIBL8StPVNJJTZZ2b2NkMFU,3766
47
47
  myosotis_researches/CcGAN/utils/IMGs_dataset.py,sha256=i45PBNSCeAEB5uUG0SluYRTuHWZwH_5ldz2wm6afkYs,927
48
48
  myosotis_researches/CcGAN/utils/SimpleProgressBar.py,sha256=S4eD_m6ysHRMHAmRtkTXVRNfXTR8kuHv-d3lUN0BVn4,546
49
- myosotis_researches/CcGAN/utils/__init__.py,sha256=6eJdO4qgHefW606C_ATXg8xhjixeTQHkOdNxBOKACwQ,484
49
+ myosotis_researches/CcGAN/utils/__init__.py,sha256=em3aB0C-V230NQtT64hyuHGo4CjV6p2DwIdtNM0dk4k,516
50
50
  myosotis_researches/CcGAN/utils/concat_image.py,sha256=BIGKz52Inn9S7M5fBFKye2V9bLJ0DqEQILoOVWAXUiE,2165
51
51
  myosotis_researches/CcGAN/utils/make_h5.py,sha256=VtFYjr_i-JktsEW_BvofpilcDmChRmyLykv0VvlMuY0,963
52
52
  myosotis_researches/CcGAN/utils/opts.py,sha256=pd7-wknNPBO5hWRpO3YAPmmAsPKgZUUpKc4gWMs6Wto,5397
53
53
  myosotis_researches/CcGAN/utils/print_hdf5.py,sha256=VvmNAWtMDmg6D9V6ZbSUXrQTKRh9WIJeC4BR_ORJkco,300
54
- myosotis_researches/CcGAN/utils/train.py,sha256=NhUee86SkFT7Cq5RG8Fhy0f6WbZNJ5jmomDlhq9FY5I,2140
55
- myosotis_researches-0.1.7.dist-info/licenses/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
56
- myosotis_researches-0.1.7.dist-info/METADATA,sha256=Gde6bmI1QC4CsNsEWxgMZ1Eip-dETkF20Z4y1BZTqTw,2663
57
- myosotis_researches-0.1.7.dist-info/WHEEL,sha256=ck4Vq1_RXyvS4Jt6SI0Vz6fyVs4GWg7AINwpsaGEgPE,91
58
- myosotis_researches-0.1.7.dist-info/top_level.txt,sha256=zxAiMn5eyZNJM28MewTAkgi_RZJMbfWbzVR-KF0LdZE,20
59
- myosotis_researches-0.1.7.dist-info/RECORD,,
54
+ myosotis_researches/CcGAN/utils/train.py,sha256=5ZXgkGesuInqUooJRpLej_KHqYQtlSDq90_5wig5elQ,5152
55
+ myosotis_researches-0.1.8.dist-info/licenses/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
56
+ myosotis_researches-0.1.8.dist-info/METADATA,sha256=BWkcdFq2IMeEH_rmMoeEWxwA4H79nPc1DypDZ0DW_cM,2663
57
+ myosotis_researches-0.1.8.dist-info/WHEEL,sha256=ck4Vq1_RXyvS4Jt6SI0Vz6fyVs4GWg7AINwpsaGEgPE,91
58
+ myosotis_researches-0.1.8.dist-info/top_level.txt,sha256=zxAiMn5eyZNJM28MewTAkgi_RZJMbfWbzVR-KF0LdZE,20
59
+ myosotis_researches-0.1.8.dist-info/RECORD,,