plato-learn 1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (179) hide show
  1. plato/__init__.py +1 -0
  2. plato/algorithms/__init__.py +0 -0
  3. plato/algorithms/base.py +45 -0
  4. plato/algorithms/fedavg.py +48 -0
  5. plato/algorithms/fedavg_gan.py +79 -0
  6. plato/algorithms/fedavg_personalized.py +48 -0
  7. plato/algorithms/mistnet.py +52 -0
  8. plato/algorithms/registry.py +39 -0
  9. plato/algorithms/split_learning.py +89 -0
  10. plato/callbacks/__init__.py +0 -0
  11. plato/callbacks/client.py +56 -0
  12. plato/callbacks/handler.py +78 -0
  13. plato/callbacks/server.py +139 -0
  14. plato/callbacks/trainer.py +124 -0
  15. plato/client.py +67 -0
  16. plato/clients/__init__.py +0 -0
  17. plato/clients/base.py +467 -0
  18. plato/clients/edge.py +103 -0
  19. plato/clients/fedavg_personalized.py +40 -0
  20. plato/clients/mistnet.py +49 -0
  21. plato/clients/registry.py +43 -0
  22. plato/clients/self_supervised_learning.py +51 -0
  23. plato/clients/simple.py +218 -0
  24. plato/clients/split_learning.py +150 -0
  25. plato/config.py +339 -0
  26. plato/datasources/__init__.py +0 -0
  27. plato/datasources/base.py +123 -0
  28. plato/datasources/celeba.py +150 -0
  29. plato/datasources/cifar10.py +87 -0
  30. plato/datasources/cifar100.py +61 -0
  31. plato/datasources/cinic10.py +62 -0
  32. plato/datasources/coco.py +119 -0
  33. plato/datasources/datalib/__init__.py +0 -0
  34. plato/datasources/datalib/audio_extraction_tools.py +137 -0
  35. plato/datasources/datalib/data_utils.py +124 -0
  36. plato/datasources/datalib/flickr30kE_utils.py +336 -0
  37. plato/datasources/datalib/frames_extraction_tools.py +254 -0
  38. plato/datasources/datalib/gym_utils/__init__.py +0 -0
  39. plato/datasources/datalib/gym_utils/gym_trim.py +189 -0
  40. plato/datasources/datalib/modality_data_anntation_tools.py +163 -0
  41. plato/datasources/datalib/modality_extraction_base.py +59 -0
  42. plato/datasources/datalib/parse_datasets.py +212 -0
  43. plato/datasources/datalib/refer_utils/__init__.py +0 -0
  44. plato/datasources/datalib/refer_utils/referitgame_utils.py +237 -0
  45. plato/datasources/datalib/tiny_data_tools.py +81 -0
  46. plato/datasources/datalib/video_transform.py +79 -0
  47. plato/datasources/emnist.py +64 -0
  48. plato/datasources/fashion_mnist.py +41 -0
  49. plato/datasources/feature.py +24 -0
  50. plato/datasources/feature_dataset.py +15 -0
  51. plato/datasources/femnist.py +141 -0
  52. plato/datasources/flickr30k_entities.py +362 -0
  53. plato/datasources/gym.py +431 -0
  54. plato/datasources/huggingface.py +165 -0
  55. plato/datasources/kinetics.py +568 -0
  56. plato/datasources/mnist.py +44 -0
  57. plato/datasources/multimodal_base.py +328 -0
  58. plato/datasources/pascal_voc.py +56 -0
  59. plato/datasources/purchase.py +94 -0
  60. plato/datasources/qoenflx.py +127 -0
  61. plato/datasources/referitgame.py +330 -0
  62. plato/datasources/registry.py +119 -0
  63. plato/datasources/self_supervised_learning.py +98 -0
  64. plato/datasources/stl10.py +103 -0
  65. plato/datasources/texas.py +94 -0
  66. plato/datasources/tiny_imagenet.py +64 -0
  67. plato/datasources/yolov8.py +85 -0
  68. plato/models/__init__.py +0 -0
  69. plato/models/cnn_encoder.py +103 -0
  70. plato/models/dcgan.py +116 -0
  71. plato/models/general_multilayer.py +254 -0
  72. plato/models/huggingface.py +27 -0
  73. plato/models/lenet5.py +113 -0
  74. plato/models/multilayer.py +90 -0
  75. plato/models/multimodal/__init__.py +0 -0
  76. plato/models/multimodal/base_net.py +91 -0
  77. plato/models/multimodal/blending.py +142 -0
  78. plato/models/multimodal/fc_net.py +77 -0
  79. plato/models/multimodal/fusion_net.py +78 -0
  80. plato/models/multimodal/multimodal_module.py +152 -0
  81. plato/models/registry.py +99 -0
  82. plato/models/resnet.py +190 -0
  83. plato/models/torch_hub.py +19 -0
  84. plato/models/vgg.py +113 -0
  85. plato/models/vit.py +166 -0
  86. plato/models/yolov8.py +22 -0
  87. plato/processors/__init__.py +0 -0
  88. plato/processors/base.py +35 -0
  89. plato/processors/compress.py +46 -0
  90. plato/processors/decompress.py +48 -0
  91. plato/processors/feature.py +51 -0
  92. plato/processors/feature_additive_noise.py +48 -0
  93. plato/processors/feature_dequantize.py +34 -0
  94. plato/processors/feature_gaussian.py +17 -0
  95. plato/processors/feature_laplace.py +15 -0
  96. plato/processors/feature_quantize.py +34 -0
  97. plato/processors/feature_randomized_response.py +50 -0
  98. plato/processors/feature_unbatch.py +39 -0
  99. plato/processors/inbound_feature_tensors.py +39 -0
  100. plato/processors/model.py +55 -0
  101. plato/processors/model_compress.py +34 -0
  102. plato/processors/model_decompress.py +37 -0
  103. plato/processors/model_decrypt.py +41 -0
  104. plato/processors/model_deepcopy.py +21 -0
  105. plato/processors/model_dequantize.py +18 -0
  106. plato/processors/model_dequantize_qsgd.py +61 -0
  107. plato/processors/model_encrypt.py +43 -0
  108. plato/processors/model_quantize.py +18 -0
  109. plato/processors/model_quantize_qsgd.py +82 -0
  110. plato/processors/model_randomized_response.py +34 -0
  111. plato/processors/outbound_feature_ndarrays.py +38 -0
  112. plato/processors/pipeline.py +26 -0
  113. plato/processors/registry.py +124 -0
  114. plato/processors/structured_pruning.py +57 -0
  115. plato/processors/unstructured_pruning.py +73 -0
  116. plato/samplers/__init__.py +0 -0
  117. plato/samplers/all_inclusive.py +41 -0
  118. plato/samplers/base.py +31 -0
  119. plato/samplers/dirichlet.py +81 -0
  120. plato/samplers/distribution_noniid.py +132 -0
  121. plato/samplers/iid.py +53 -0
  122. plato/samplers/label_quantity_noniid.py +119 -0
  123. plato/samplers/mixed.py +44 -0
  124. plato/samplers/mixed_label_quantity_noniid.py +128 -0
  125. plato/samplers/modality_iid.py +42 -0
  126. plato/samplers/modality_quantity_noniid.py +56 -0
  127. plato/samplers/orthogonal.py +99 -0
  128. plato/samplers/registry.py +66 -0
  129. plato/samplers/sample_quantity_noniid.py +123 -0
  130. plato/samplers/sampler_utils.py +190 -0
  131. plato/servers/__init__.py +0 -0
  132. plato/servers/base.py +1395 -0
  133. plato/servers/fedavg.py +281 -0
  134. plato/servers/fedavg_cs.py +335 -0
  135. plato/servers/fedavg_gan.py +74 -0
  136. plato/servers/fedavg_he.py +106 -0
  137. plato/servers/fedavg_personalized.py +57 -0
  138. plato/servers/mistnet.py +67 -0
  139. plato/servers/registry.py +52 -0
  140. plato/servers/split_learning.py +109 -0
  141. plato/trainers/__init__.py +0 -0
  142. plato/trainers/base.py +99 -0
  143. plato/trainers/basic.py +649 -0
  144. plato/trainers/diff_privacy.py +178 -0
  145. plato/trainers/gan.py +330 -0
  146. plato/trainers/huggingface.py +173 -0
  147. plato/trainers/loss_criterion.py +70 -0
  148. plato/trainers/lr_schedulers.py +252 -0
  149. plato/trainers/optimizers.py +53 -0
  150. plato/trainers/pascal_voc.py +80 -0
  151. plato/trainers/registry.py +44 -0
  152. plato/trainers/self_supervised_learning.py +302 -0
  153. plato/trainers/split_learning.py +305 -0
  154. plato/trainers/tracking.py +96 -0
  155. plato/trainers/yolov8.py +41 -0
  156. plato/utils/__init__.py +0 -0
  157. plato/utils/count_parameters.py +30 -0
  158. plato/utils/csv_processor.py +26 -0
  159. plato/utils/data_loaders.py +148 -0
  160. plato/utils/decorators.py +24 -0
  161. plato/utils/fonts.py +23 -0
  162. plato/utils/homo_enc.py +187 -0
  163. plato/utils/reinforcement_learning/__init__.py +0 -0
  164. plato/utils/reinforcement_learning/policies/__init__.py +0 -0
  165. plato/utils/reinforcement_learning/policies/base.py +161 -0
  166. plato/utils/reinforcement_learning/policies/ddpg.py +75 -0
  167. plato/utils/reinforcement_learning/policies/registry.py +32 -0
  168. plato/utils/reinforcement_learning/policies/sac.py +343 -0
  169. plato/utils/reinforcement_learning/policies/td3.py +485 -0
  170. plato/utils/reinforcement_learning/rl_agent.py +142 -0
  171. plato/utils/reinforcement_learning/rl_server.py +113 -0
  172. plato/utils/rl_env.py +154 -0
  173. plato/utils/s3.py +141 -0
  174. plato/utils/trainer_utils.py +21 -0
  175. plato/utils/unary_encoding.py +47 -0
  176. plato_learn-1.1.dist-info/METADATA +35 -0
  177. plato_learn-1.1.dist-info/RECORD +179 -0
  178. plato_learn-1.1.dist-info/WHEEL +4 -0
  179. plato_learn-1.1.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,94 @@
1
+ """
2
+ The Texas100 dataset.
3
+ """
4
+
5
+ import os
6
+ import logging
7
+ import urllib
8
+ import tarfile
9
+ import torch
10
+ import numpy as np
11
+ from torch.utils import data
12
+ from plato.config import Config
13
+ from plato.datasources import base
14
+
15
+
16
class DataSource(base.DataSource):
    """The Texas100 dataset.

    On first use the raw feature/label files are downloaded and parsed, then
    cached as ``texas_numpy.npz`` under ``Config().params["data_path"]``. The
    cached samples are shuffled with a fixed seed and split into 20,000
    training and 20,000 test samples.
    """

    # Number of samples in each of the training and test splits.
    _SPLIT_SIZE = 20000

    def __init__(self):
        super().__init__()
        root_path = Config().params["data_path"]
        feat_path = os.path.join(root_path, "texas/100/feats")
        label_path = os.path.join(root_path, "texas/100/labels")
        npz_path = os.path.join(root_path, "texas_numpy.npz")

        # makedirs (unlike mkdir) also creates any missing parent directories.
        os.makedirs(root_path, exist_ok=True)

        # extract_data() reads the cached .npz, so (re)build it whenever it is
        # missing, not only when the raw feature file is missing.
        if not (os.path.isfile(feat_path) and os.path.isfile(npz_path)):
            self.download_dataset(root_path, feat_path, label_path)

        self.trainset, self.testset = self.extract_data(root_path)

    def download_dataset(self, root_path, feat_path, label_path):
        """Fetch the raw Texas100 files if needed and cache them as NumPy arrays.

        The archive is downloaded and extracted only when ``feat_path`` or
        ``label_path`` is missing. The comma-separated raw files are then parsed
        and saved to ``<root_path>/texas_numpy.npz`` with features ``X``
        (float64) and labels ``Y`` (int32, shifted down by one).
        """
        # `import urllib` alone does not guarantee that the `request`
        # submodule is loaded, so import it explicitly.
        import urllib.request  # pylint: disable=import-outside-toplevel

        if not (os.path.isfile(feat_path) and os.path.isfile(label_path)):
            logging.info("Downloading the Texas100 dataset...")
            filename = "https://www.comp.nus.edu.sg/~reza/files/dataset_texas.tgz"
            archive_path = os.path.join(root_path, "tmp_texas.tgz")
            urllib.request.urlretrieve(filename, archive_path)
            logging.info("Dataset downloaded.")
            # The context manager closes the archive, which was previously leaked.
            with tarfile.open(archive_path) as tar:
                if hasattr(tarfile, "data_filter"):
                    # Where supported, refuse members that would escape
                    # root_path (absolute paths, "..", unsafe links).
                    tar.extractall(path=root_path, filter="data")
                else:
                    tar.extractall(path=root_path)

        logging.info("Processing the dataset...")
        data_set_feats = np.genfromtxt(feat_path, delimiter=",")
        data_set_labels = np.genfromtxt(label_path, delimiter=",")
        logging.info("Finish processing the dataset.")

        features = data_set_feats.astype(np.float64)
        # NOTE(review): raw labels presumably start at 1; shifted to start at 0.
        labels = data_set_labels.astype(np.int32) - 1
        np.savez(os.path.join(root_path, "texas_numpy.npz"), X=features, Y=labels)

    def extract_data(self, root_path):
        """Load ``texas_numpy.npz`` and return ``(train_dataset, test_dataset)``.

        The samples are shuffled with a fixed seed (0). The first 20,000
        shuffled samples form the training set and the next 20,000 the test set.
        """
        # Close the NpzFile once the arrays have been read into memory.
        with np.load(os.path.join(root_path, "texas_numpy.npz")) as arrays:
            features, labels = arrays["X"], arrays["Y"]

        # A private RandomState(0) yields the same permutation as the former
        # np.random.seed(0) + np.random.shuffle, without reseeding NumPy's
        # global generator as a side effect.
        indices = np.random.RandomState(0).permutation(len(features))
        features, labels = features[indices], labels[indices]

        split = self._SPLIT_SIZE
        train_dataset = VectorDataset(features[:split], labels[:split])
        test_dataset = VectorDataset(
            features[split : split * 2], labels[split : split * 2]
        )

        return train_dataset, test_dataset

    def num_train_examples(self):
        """Return the number of training samples (20,000)."""
        return self._SPLIT_SIZE

    def num_test_examples(self):
        """Return the number of test samples (20,000)."""
        return self._SPLIT_SIZE
78
+
79
+
80
class VectorDataset(data.Dataset):
    """
    A Texas100 dataset built from feature vectors and integer labels.

    ``data`` holds the features as a float32 tensor of shape (N, D), and
    ``targets`` holds the labels as an int64 tensor of shape (N,).
    """

    def __init__(self, features, labels):
        # Convert in one vectorized call instead of building a tensor per row.
        # torch.tensor always copies, so the dataset never aliases the
        # caller's arrays. This also works for an empty input, where
        # torch.stack([]) used to fail.
        self.data = torch.tensor(features, dtype=torch.float32)
        self.targets = torch.tensor(labels, dtype=torch.long)
        self.classes = [f"Procedure #{i}" for i in range(100)]

    def __getitem__(self, index):
        """Return the ``(features, label)`` pair at ``index``."""
        return self.data[index], self.targets[index]

    def __len__(self):
        """Return the number of samples."""
        return self.data.size(0)
@@ -0,0 +1,64 @@
1
+ """
2
+ The Tiny ImageNet 200 Classification dataset.
3
+
4
+ The Tiny ImageNet training set contains 100000 images of 200 classes (500 for each class)
5
+ downsized to 64×64 colored images.
6
+ Each class has 500 training images, 50 validation images and 50 test images.
7
+ """
8
+
9
+ import logging
10
+ import os
11
+
12
+ from torchvision import datasets, transforms
13
+
14
+ from plato.config import Config
15
+ from plato.datasources import base
16
+
17
+
18
class DataSource(base.DataSource):
    """The Tiny ImageNet 200 dataset.

    Keyword arguments:
        train_transform: transform applied to training images. Defaults to a
            random 299x299 resized crop followed by ImageNet normalization.
        test_transform: transform applied to test images. If omitted, a
            caller-supplied ``train_transform`` is reused (the historical
            behavior). Otherwise a deterministic resize + center crop to
            299x299 with the same normalization is used.
    """

    def __init__(self, **kwargs):
        super().__init__()
        _path = Config().params["data_path"]

        if not os.path.exists(_path):
            logging.info(
                "Downloading the Tiny ImageNet 200 dataset. This may take a while."
            )
            url = (
                Config().data.download_url
                if hasattr(Config().data, "download_url")
                else "http://cs231n.stanford.edu/tiny-imagenet-200.zip"
            )
            DataSource.download(url, _path)

        normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

        if "train_transform" in kwargs:
            train_transform = kwargs["train_transform"]
        else:
            # RandomResizedCrop already outputs 299x299 images, so the former
            # trailing CenterCrop(299) was a no-op and is omitted.
            train_transform = transforms.Compose(
                [
                    transforms.RandomResizedCrop(299),
                    transforms.ToTensor(),
                    normalize,
                ]
            )

        if "test_transform" in kwargs:
            test_transform = kwargs["test_transform"]
        elif "train_transform" in kwargs:
            # Keep the historical behavior for callers that pass only a train transform.
            test_transform = train_transform
        else:
            # Evaluation must be deterministic, so the default avoids the
            # random cropping used for training.
            test_transform = transforms.Compose(
                [
                    transforms.Resize(299),
                    transforms.CenterCrop(299),
                    transforms.ToTensor(),
                    normalize,
                ]
            )

        self.trainset = datasets.ImageFolder(
            root=os.path.join(_path, "train"), transform=train_transform
        )
        # NOTE(review): ImageFolder derives labels from sub-directory names;
        # confirm the local "test" directory is organized by class.
        self.testset = datasets.ImageFolder(
            root=os.path.join(_path, "test"), transform=test_transform
        )

    def num_train_examples(self):
        """Return the number of training images."""
        return 100000

    def num_test_examples(self):
        """Return the number of test images."""
        return 10000
@@ -0,0 +1,85 @@
1
+ """
2
+ The COCO dataset or other datasets for the YOLOv8 model.
3
+
4
+ For more information about COCO 128, which contains the first 128 images of the
5
+ COCO 2017 dataset, refer to https://www.kaggle.com/ultralytics/coco128.
6
+
7
+ For more information about the COCO 2017 dataset, refer to http://cocodataset.org.
8
+ """
9
+
10
+ from ultralytics.data.dataset import YOLODataset
11
+ from ultralytics.cfg import DEFAULT_CFG
12
+ from ultralytics.data.utils import check_det_dataset
13
+ from plato.config import Config
14
+ from plato.datasources import base
15
+
16
+
17
class DataSource(base.DataSource):
    """The YOLO dataset.

    Train and validation splits are read from the dataset description returned
    by ``check_det_dataset(Config().data.data_params)``. They are built lazily
    on first access and cached.
    """

    # pylint: disable=unused-argument
    def __init__(self, **kwargs):
        super().__init__()

        self.grid_size = Config().parameters.grid_size
        self.data = check_det_dataset(Config().data.data_params)
        self.train_set = None
        self.test_set = None

    def _build_dataset(self, img_path, augment):
        """Build a YOLODataset for ``img_path`` using the shared configuration."""
        return YOLODataset(
            img_path=img_path,
            imgsz=Config().data.image_size,
            batch_size=Config().trainer.batch_size,
            augment=augment,
            hyp=DEFAULT_CFG,
            rect=False,
            cache=False,
            single_cls=Config().parameters.model.num_classes == 1,
            stride=int(self.grid_size),
            pad=0.0,
            prefix="",
            use_segments=False,
            use_keypoints=False,
            classes=Config().data.classes,
            data=self.data,
        )

    def get_train_set(self):
        """Return the (cached) training split."""
        if self.train_set is None:
            self.train_set = self._build_dataset(self.data["train"], augment=False)

        return self.train_set

    def get_test_set(self):
        """Return the (cached) validation split used for testing."""
        if self.test_set is None:
            # Evaluation should see unaugmented images; the former augment=True
            # applied training-time augmentation to the validation split.
            self.test_set = self._build_dataset(self.data["val"], augment=False)

        return self.test_set

    def num_train_examples(self):
        """Return the configured number of training examples."""
        return Config().data.num_train_examples

    def num_test_examples(self):
        """Return the configured number of test examples."""
        return Config().data.num_test_examples

    def classes(self):
        """Obtains a list of class names in the dataset."""
        return Config().data.classes
File without changes
@@ -0,0 +1,103 @@
1
+ """
2
+ A factory that generates fully convolutional neural network encoders.
3
+
4
+ The fully convolutional neural network used as the encoder in this implementation
5
+ is the backbone part of a model.
6
+
7
+ This encoder is solely capable of extracting features from the input sample.
8
+ Generally, in the context of classification tasks, this encoder has to cooperate
9
+ with the classification head to make the prediction.
10
+
11
+ Besides, the 'AdaptiveAvgPool2d' layer is included to support extracting
12
+ features with fixed dimensions.
13
+
14
+ """
15
+
16
+ from typing import Optional, Dict
17
+
18
+ from torch import nn
19
+ import torchvision
20
+
21
+ from plato.models.lenet5 import Model as lenet5_model
22
+ from plato.models.vgg import Model as vgg_model
23
+
24
+ from plato.config import Config
25
+
26
+
27
class TruncatedLeNetModel(nn.Module):
    """LeNet-5 with its classification head stripped off.

    The wrapped model's ``fc4``, ``relu4`` and ``fc5`` layers are replaced by
    ``nn.Identity``. Each forward pass calls the model's ``forward_to`` with
    ``cut_layer`` set to ``"flatten"``.
    """

    def __init__(self, defined_lenet5_model):
        super().__init__()
        self.model = defined_lenet5_model
        # Neutralize the head layers in their original order.
        for head_layer in ("fc4", "relu4", "fc5"):
            setattr(self.model, head_layer, nn.Identity())

    def forward(self, samples):
        """Run ``samples`` through LeNet-5's ``forward_to`` with cut layer ``flatten``."""
        # NOTE(review): presumably forward_to() stops at the layer named by
        # cut_layer; it is (re)set on every call before delegating.
        self.model.cut_layer = "flatten"
        return self.model.forward_to(samples)
41
+
42
+
43
class Model:
    """The encoder obtained by removing the final
    fully-connected blocks of the required model.
    """

    # pylint:disable=too-few-public-methods
    @staticmethod
    def get(model_name: Optional[str] = None, **kwargs: Dict[str, str]):  # pylint: disable=unused-argument
        """Return a fully convolutional encoder built from ``model_name``.

        Supported names are ``"lenet5"``, any name containing ``"vgg"``
        (built via ``plato.models.vgg``), and any name containing ``"resnet"``
        that is a key of the local ResNet table (``"resnet_18"``,
        ``"resnet_50"``).

        Keyword arguments:
            datasource: overrides ``Config().data.datasource`` when deciding
                whether to apply the CIFAR-specific ResNet stem.

        Returns:
            The encoder module. Its ``encoding_dim`` attribute is set to the
            input width of the removed final fully-connected layer.

        Raises:
            ValueError: if ``model_name`` is ``None`` or not a supported family.
            KeyError: for a ResNet name not in the local ResNet table.
        """
        # The final fully-connected layer is removed below, so the number of
        # classes it is built with does not matter; use a fixed value of 10.
        num_classes = 10

        if model_name is None:
            raise ValueError("A model name is required to build a CNN encoder.")

        if model_name == "lenet5":
            model = lenet5_model(num_classes=num_classes)
            # The encoder's output dimension is the input width of the
            # removed fully-connected layer.
            encode_output_dim = model.fc4.in_features
            encoder = TruncatedLeNetModel(model)

        elif "vgg" in model_name:
            encoder = vgg_model.get(model_name=model_name, num_classes=num_classes)
            encode_output_dim = encoder.fc.in_features
            encoder.fc = nn.Identity()

        elif "resnet" in model_name:
            resnets = {
                "resnet_18": torchvision.models.resnet18,
                "resnet_50": torchvision.models.resnet50,
            }

            encoder = resnets[model_name](num_classes=num_classes)

            datasource = (
                kwargs["datasource"]
                if "datasource" in kwargs
                else Config().data.datasource
            )
            if "CIFAR" in datasource:
                # The structure specifically for CIFAR-based datasets:
                # replace the 7x7 conv with a 3x3 conv and remove the first
                # max pooling (see, e.g., Section B.9 of the SimCLR paper).
                # NOTE(review): with kernel 3 / stride 1, padding=2 turns a
                # 32x32 input into a 34x34 map (padding=1 would keep 32x32);
                # confirm this is intended.
                encoder.conv1 = nn.Conv2d(
                    3, 64, kernel_size=3, stride=1, padding=2, bias=False
                )
                encoder.maxpool = nn.Identity()

            encode_output_dim = encoder.fc.in_features
            encoder.fc = nn.Identity()

        else:
            # Previously this fell through and failed with UnboundLocalError.
            raise ValueError(f"Unsupported model for the CNN encoder: {model_name}")

        encoder.encoding_dim = encode_output_dim

        return encoder
plato/models/dcgan.py ADDED
@@ -0,0 +1,116 @@
1
+ """
2
+ The DCGAN model.
3
+
4
+ Reference:
5
+ https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html
6
+ """
7
+
8
+ from torch import nn
9
+
10
# Number of input channels of the generator's first layer, i.e. the size of the latent input Z.
nz = 100
# Number of image channels produced by the generator and consumed by the discriminator.
nc = 3
# Base number of generator feature maps; generator layers use multiples of it.
ngf = 64
# Base number of discriminator feature maps; discriminator layers use multiples of it.
ndf = 64
14
+
15
+
16
class Generator(nn.Module):
    """DCGAN generator mapping a latent input Z to an ``nc`` x 64 x 64 image.

    All layers live in one ``nn.Sequential`` named ``main``. The layer order
    (and therefore the state-dict keys) follows the PyTorch DCGAN tutorial.
    """

    def __init__(self):
        super().__init__()

        # Project Z to (ngf*8) x 4 x 4.
        stages = [
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
        ]
        # Upsample (ngf*8)x4x4 -> (ngf*4)x8x8 -> (ngf*2)x16x16 -> (ngf)x32x32,
        # halving the channel count at each stage.
        for width in (8, 4, 2):
            stages.extend(
                [
                    nn.ConvTranspose2d(
                        ngf * width, ngf * width // 2, 4, 2, 1, bias=False
                    ),
                    nn.BatchNorm2d(ngf * width // 2),
                    nn.ReLU(True),
                ]
            )
        # Final upsampling to (nc) x 64 x 64, squashed into [-1, 1] by Tanh.
        stages.extend([nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False), nn.Tanh()])

        self.main = nn.Sequential(*stages)

    def forward(self, input_data):
        """Apply the generator network to ``input_data``."""
        # NOTE(review): assumes input_data is (batch, nz, 1, 1) — confirm with callers.
        return self.main(input_data)
48
+
49
+
50
class Discriminator(nn.Module):
    """DCGAN discriminator scoring an ``nc`` x 64 x 64 image with a value in (0, 1).

    All layers live in one ``nn.Sequential`` named ``main``. The layer order
    (and therefore the state-dict keys) follows the PyTorch DCGAN tutorial.
    """

    def __init__(self):
        super().__init__()

        # (nc) x 64 x 64 -> (ndf) x 32 x 32; no batch norm on the first layer.
        stages = [
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Downsample (ndf)x32x32 -> (ndf*2)x16x16 -> (ndf*4)x8x8 -> (ndf*8)x4x4,
        # doubling the channel count at each stage.
        for width in (1, 2, 4):
            stages.extend(
                [
                    nn.Conv2d(ndf * width, ndf * width * 2, 4, 2, 1, bias=False),
                    nn.BatchNorm2d(ndf * width * 2),
                    nn.LeakyReLU(0.2, inplace=True),
                ]
            )
        # Collapse (ndf*8) x 4 x 4 to a single channel, mapped into (0, 1) by Sigmoid.
        stages.extend([nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False), nn.Sigmoid()])

        self.main = nn.Sequential(*stages)

    def forward(self, input_data):
        """Apply the discriminator network to ``input_data``."""
        return self.main(input_data)
79
+
80
+
81
class Model:
    """A wrapper class to hold the Generator and Discriminator models of DCGAN.

    The device and mode helpers mirror ``nn.Module`` and return ``self``, so
    chained calls such as ``model = model.to(device)`` keep working.
    """

    def __init__(self) -> None:
        self.generator = Generator()
        self.discriminator = Discriminator()
        self.loss_criterion = nn.BCELoss()

        self.nz = nz
        self.nc = nc
        self.ngf = ngf
        self.ndf = ndf

    def weights_init(self, model):
        """Apply DCGAN-style initialization to a single layer ``model``.

        Layers whose class name contains ``Conv`` get weights drawn from
        N(0, 0.02). Layers whose class name contains ``BatchNorm`` get weights
        from N(1, 0.02) and zero biases. Other layers are left untouched.
        """
        # NOTE(review): presumably used via nn.Module.apply(); confirm with callers.
        classname = model.__class__.__name__
        if "Conv" in classname:
            nn.init.normal_(model.weight.data, 0.0, 0.02)
        elif "BatchNorm" in classname:
            nn.init.normal_(model.weight.data, 1.0, 0.02)
            nn.init.constant_(model.bias.data, 0)

    def cpu(self):
        """Move both networks to the CPU and return ``self``."""
        self.generator.cpu()
        self.discriminator.cpu()
        return self

    def to(self, device):
        """Move both networks to ``device`` and return ``self``."""
        self.generator.to(device)
        self.discriminator.to(device)
        return self

    def train(self, mode: bool = True):
        """Set both networks to training mode (or eval mode if ``mode`` is False); return ``self``."""
        self.generator.train(mode)
        self.discriminator.train(mode)
        return self

    def eval(self):
        """Set both networks to evaluation mode and return ``self``."""
        self.generator.eval()
        self.discriminator.eval()
        return self