Myosotis-Researches 0.1.4__py3-none-any.whl → 0.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,33 @@
1
+ import torch
2
+
3
class IMGs_dataset(torch.utils.data.Dataset):
    """Map-style dataset over an indexable collection of images.

    Args:
        images: Any object supporting ``len()`` and integer indexing
            (for example a NumPy array or a list of arrays).
        labels: Optional indexable collection aligned with ``images``.
            When given, items are returned as ``(image, label)`` pairs;
            otherwise only the image is returned.
        normalize: When True, each image is transformed as
            ``(image / 255 - 0.5) / 0.5``, which maps 0 -> -1 and 255 -> 1
            (assumes raw pixel values in ``[0, 255]`` -- confirm with caller).

    Raises:
        ValueError: If ``labels`` is given and its length differs from the
            length of ``images``.
    """

    def __init__(self, images, labels=None, normalize=False):
        super().__init__()
        self.images = images
        self.n_images = len(images)
        self.labels = labels
        if labels is not None and len(labels) != self.n_images:
            # ValueError subclasses Exception, so callers that caught the old
            # bare Exception keep working.
            raise ValueError(
                'images (' + str(self.n_images) + ') and labels ('
                + str(len(labels)) + ') do not have the same length!!!'
            )
        self.normalize = normalize

    def __getitem__(self, index):
        """Return the image at ``index`` (and its label, if labels were given)."""
        image = self.images[index]
        if self.normalize:
            image = image / 255.0
            image = (image - 0.5) / 0.5
        if self.labels is None:
            return image
        return (image, self.labels[index])

    def __len__(self):
        """Return the number of images captured at construction time."""
        return self.n_images


__all__ = ["IMGs_dataset"]
@@ -0,0 +1,18 @@
1
+ import sys
2
+
3
class SimpleProgressBar:
    """Single-line text progress bar written to standard output.

    The bar is redrawn in place (via a carriage return) only when the integer
    percentage changes; a newline is printed once progress reaches 100.
    """

    def __init__(self, width=50):
        # Number of '#'/'.' cells in the drawn bar.
        self.width = width
        # Integer percentage drawn most recently; -1 means nothing drawn yet.
        self.last_x = -1

    def update(self, x):
        """Redraw the bar for progress ``x`` given as a percentage in [0, 100]."""
        assert 0 <= x <= 100  # `x`: progress in percent ( between 0 and 100)
        percent = int(x)
        if percent == self.last_x:
            # Same integer percentage as last time: nothing visible changes.
            return
        self.last_x = percent
        filled = int(self.width * (x / 100.0))
        bar = '#' * filled + '.' * (self.width - filled)
        sys.stdout.write('\r%d%% [%s]' % (percent, bar))
        sys.stdout.flush()
        if x == 100:
            print('')


__all__ = ["SimpleProgressBar"]
@@ -1,5 +1,17 @@
1
1
  from .print_hdf5 import print_hdf5
2
2
  from .concat_image import concat_image
3
3
  from .make_h5 import make_h5
4
+ from .SimpleProgressBar import SimpleProgressBar
5
+ from .IMGs_dataset import IMGs_dataset
6
+ from .train import PlotLoss, compute_entropy, predict_class_labels
4
7
 
5
- __all__ = ["print_hdf5", "concat_image", "make_h5"]
8
# Public names re-exported by ``from <this package> import *``:
# the HDF5/image helpers, the progress bar, the dataset wrapper and the
# training utilities imported above.
__all__ = [
    "print_hdf5", "concat_image", "make_h5",
    "SimpleProgressBar", "IMGs_dataset",
    "PlotLoss", "compute_entropy", "predict_class_labels",
]
@@ -0,0 +1,65 @@
1
+ import numpy as np
2
+ import torch
3
+ import matplotlib.pyplot as plt
4
+ import matplotlib as mpl
5
+
6
+
7
def PlotLoss(loss, filename):
    """Plot a training-loss curve against epoch number and save it to a file.

    Args:
        loss: Sequence of loss values; entry ``i`` is plotted at epoch ``i + 1``.
        filename: Output path handed to ``Figure.savefig``.

    Side effects:
        Switches the global matplotlib backend to ``"agg"`` and sets the
        global matplotlib style to the bundled seaborn style.
    """
    x_axis = np.arange(start=1, stop=len(loss) + 1)
    plt.switch_backend("agg")
    # matplotlib 3.6 renamed the bundled "seaborn" style to "seaborn-v0_8"
    # and 3.8 removed the old name; use whichever this install provides.
    style = "seaborn-v0_8" if "seaborn-v0_8" in plt.style.available else "seaborn"
    mpl.style.use(style)
    fig, ax = plt.subplots()
    try:
        ax.plot(x_axis, np.asarray(loss))
        ax.set_xlabel("epoch")
        ax.set_ylabel("training loss")
        # No legend: the single unlabelled line would only trigger a
        # "No artists with labels found" warning.
        fig.savefig(filename)
    finally:
        # pyplot keeps every figure alive until it is closed; without this,
        # calling PlotLoss once per epoch accumulates figures in memory.
        plt.close(fig)
20
+
21
+
22
# Shannon entropy of a collection of class labels.
def compute_entropy(labels, base=None):
    """Return the Shannon entropy of the empirical distribution of ``labels``.

    Args:
        labels: Array-like of class labels (anything ``np.unique`` accepts).
        base: Logarithm base; ``None`` means the natural log (nats).
            Use 2 for bits.

    Returns:
        ``-sum(p * log_base(p))`` where ``p`` ranges over the relative
        frequencies of the distinct labels.

    Raises:
        ValueError: If ``base`` is not positive or equals 1, for which a
            logarithm base is undefined.
    """
    _, counts = np.unique(labels, return_counts=True)
    probs = counts / counts.sum()
    if base is None:
        base = np.e
    elif base <= 0 or base == 1:
        raise ValueError(f"base must be positive and different from 1, got {base!r}")
    # Change of base applied once to the natural-log entropy.
    return -(probs * np.log(probs)).sum() / np.log(base)
28
+
29
+
30
def predict_class_labels(net, images, batch_size=500, verbose=False, num_workers=0, device="cuda"):
    """Predict a class label for every image with a classifier network.

    Args:
        net: ``torch.nn.Module`` whose forward pass returns a 2-element
            result; the first element holds per-class scores with classes
            along dimension 1 (the second element is ignored).
        images: Indexable collection of images with ``len()`` (anything a
            ``torch.utils.data.DataLoader`` accepts as a map-style dataset,
            e.g. a NumPy array). Each batch is cast to ``torch.float``.
        batch_size: Maximum number of images per forward pass; clipped to
            ``len(images)``.
        verbose: Show a ``SimpleProgressBar`` while predicting.
        num_workers: Passed through to the ``DataLoader``.
        device: Device the network and batches are moved to. Defaults to
            ``"cuda"``, matching the previously hard-coded ``.cuda()`` call.

    Returns:
        ``numpy.ndarray`` of shape ``(len(images),)`` and dtype float64 with
        the arg-max class index for each image, in input order.

    Note:
        ``net`` is moved to ``device`` and switched to eval mode in place.
    """
    net = net.to(device)
    net.eval()

    n = len(images)
    class_labels_pred = np.zeros(n)
    if n == 0:
        # DataLoader rejects batch_size=0 and the progress update would
        # divide by zero, so short-circuit with an empty result.
        return class_labels_pred
    batch_size = min(batch_size, n)

    # A map-style dataset only needs len() and indexing, which `images`
    # already provides; this is exactly what IMGs_dataset(images,
    # normalize=False) forwarded (and that class is not imported here).
    dataloader_pred = torch.utils.data.DataLoader(
        images, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )

    pb = None
    if verbose:
        # Local import: this module does not import SimpleProgressBar at the top.
        from .SimpleProgressBar import SimpleProgressBar
        pb = SimpleProgressBar()

    nimgs_got = 0
    with torch.no_grad():
        for batch_images in dataloader_pred:
            batch_images = batch_images.type(torch.float).to(device)
            batch_size_curr = len(batch_images)

            outputs, _ = net(batch_images)
            _, batch_class_labels_pred = torch.max(outputs, 1)
            class_labels_pred[nimgs_got : nimgs_got + batch_size_curr] = (
                batch_class_labels_pred.cpu().numpy().reshape(-1)
            )

            nimgs_got += batch_size_curr
            if pb is not None:
                pb.update(nimgs_got / n * 100)
    return class_labels_pred
63
+
64
+
65
# Public API of the training-utilities module.
__all__ = [
    "PlotLoss",
    "compute_entropy",
    "predict_class_labels",
]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: Myosotis-Researches
3
- Version: 0.1.4
3
+ Version: 0.1.5
4
4
  Summary: A repository for storing my progress of researches.
5
5
  Home-page: https://github.com/Zeyu-Xie/Myosotis-Researches
6
6
  Author: Zeyu Xie
@@ -44,12 +44,15 @@ myosotis_researches/CcGAN/train_128_output_10/train_cgan.py,sha256=bYJbBskTpESfC
44
44
  myosotis_researches/CcGAN/train_128_output_10/train_cgan_concat.py,sha256=PYctY3IZiHGh4TshXx3mUZBf9su_8NuV_D8InkxKQZ4,8940
45
45
  myosotis_researches/CcGAN/train_128_output_10/train_net_for_label_embed.py,sha256=4j6r4_o4rXgAN4MdUQL-TXqZJpbhH7d9gWQR8YzBlXw,6976
46
46
  myosotis_researches/CcGAN/train_128_output_10/utils.py,sha256=B-V6ct4WDisVVCOLO0W7VIBL8StPVNJJTZZ2b2NkMFU,3766
47
- myosotis_researches/CcGAN/utils/__init__.py,sha256=azZ2ZSSmWREoptI_5oQ180HojMoCqv2oleveRswq40w,155
47
+ myosotis_researches/CcGAN/utils/IMGs_dataset.py,sha256=i45PBNSCeAEB5uUG0SluYRTuHWZwH_5ldz2wm6afkYs,927
48
+ myosotis_researches/CcGAN/utils/SimpleProgressBar.py,sha256=S4eD_m6ysHRMHAmRtkTXVRNfXTR8kuHv-d3lUN0BVn4,546
49
+ myosotis_researches/CcGAN/utils/__init__.py,sha256=shSmo-zunolt8zSZ-Cjgv__N2kyflBfrR8UfxvKJqGg,438
48
50
  myosotis_researches/CcGAN/utils/concat_image.py,sha256=BIGKz52Inn9S7M5fBFKye2V9bLJ0DqEQILoOVWAXUiE,2165
49
51
  myosotis_researches/CcGAN/utils/make_h5.py,sha256=VtFYjr_i-JktsEW_BvofpilcDmChRmyLykv0VvlMuY0,963
50
52
  myosotis_researches/CcGAN/utils/print_hdf5.py,sha256=VvmNAWtMDmg6D9V6ZbSUXrQTKRh9WIJeC4BR_ORJkco,300
51
- myosotis_researches-0.1.4.dist-info/licenses/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
52
- myosotis_researches-0.1.4.dist-info/METADATA,sha256=O8xsjrUOntGmWZQLl5f7_glRnoCn01smEfF9-B7GA8g,2663
53
- myosotis_researches-0.1.4.dist-info/WHEEL,sha256=ck4Vq1_RXyvS4Jt6SI0Vz6fyVs4GWg7AINwpsaGEgPE,91
54
- myosotis_researches-0.1.4.dist-info/top_level.txt,sha256=zxAiMn5eyZNJM28MewTAkgi_RZJMbfWbzVR-KF0LdZE,20
55
- myosotis_researches-0.1.4.dist-info/RECORD,,
53
+ myosotis_researches/CcGAN/utils/train.py,sha256=NhUee86SkFT7Cq5RG8Fhy0f6WbZNJ5jmomDlhq9FY5I,2140
54
+ myosotis_researches-0.1.5.dist-info/licenses/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
55
+ myosotis_researches-0.1.5.dist-info/METADATA,sha256=G8_S3LwrCNALwaZEvMoCWFIIwJRRz1mOoJcmEkN9JUc,2663
56
+ myosotis_researches-0.1.5.dist-info/WHEEL,sha256=ck4Vq1_RXyvS4Jt6SI0Vz6fyVs4GWg7AINwpsaGEgPE,91
57
+ myosotis_researches-0.1.5.dist-info/top_level.txt,sha256=zxAiMn5eyZNJM28MewTAkgi_RZJMbfWbzVR-KF0LdZE,20
58
+ myosotis_researches-0.1.5.dist-info/RECORD,,