plato-learn 1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- plato/__init__.py +1 -0
- plato/algorithms/__init__.py +0 -0
- plato/algorithms/base.py +45 -0
- plato/algorithms/fedavg.py +48 -0
- plato/algorithms/fedavg_gan.py +79 -0
- plato/algorithms/fedavg_personalized.py +48 -0
- plato/algorithms/mistnet.py +52 -0
- plato/algorithms/registry.py +39 -0
- plato/algorithms/split_learning.py +89 -0
- plato/callbacks/__init__.py +0 -0
- plato/callbacks/client.py +56 -0
- plato/callbacks/handler.py +78 -0
- plato/callbacks/server.py +139 -0
- plato/callbacks/trainer.py +124 -0
- plato/client.py +67 -0
- plato/clients/__init__.py +0 -0
- plato/clients/base.py +467 -0
- plato/clients/edge.py +103 -0
- plato/clients/fedavg_personalized.py +40 -0
- plato/clients/mistnet.py +49 -0
- plato/clients/registry.py +43 -0
- plato/clients/self_supervised_learning.py +51 -0
- plato/clients/simple.py +218 -0
- plato/clients/split_learning.py +150 -0
- plato/config.py +339 -0
- plato/datasources/__init__.py +0 -0
- plato/datasources/base.py +123 -0
- plato/datasources/celeba.py +150 -0
- plato/datasources/cifar10.py +87 -0
- plato/datasources/cifar100.py +61 -0
- plato/datasources/cinic10.py +62 -0
- plato/datasources/coco.py +119 -0
- plato/datasources/datalib/__init__.py +0 -0
- plato/datasources/datalib/audio_extraction_tools.py +137 -0
- plato/datasources/datalib/data_utils.py +124 -0
- plato/datasources/datalib/flickr30kE_utils.py +336 -0
- plato/datasources/datalib/frames_extraction_tools.py +254 -0
- plato/datasources/datalib/gym_utils/__init__.py +0 -0
- plato/datasources/datalib/gym_utils/gym_trim.py +189 -0
- plato/datasources/datalib/modality_data_anntation_tools.py +163 -0
- plato/datasources/datalib/modality_extraction_base.py +59 -0
- plato/datasources/datalib/parse_datasets.py +212 -0
- plato/datasources/datalib/refer_utils/__init__.py +0 -0
- plato/datasources/datalib/refer_utils/referitgame_utils.py +237 -0
- plato/datasources/datalib/tiny_data_tools.py +81 -0
- plato/datasources/datalib/video_transform.py +79 -0
- plato/datasources/emnist.py +64 -0
- plato/datasources/fashion_mnist.py +41 -0
- plato/datasources/feature.py +24 -0
- plato/datasources/feature_dataset.py +15 -0
- plato/datasources/femnist.py +141 -0
- plato/datasources/flickr30k_entities.py +362 -0
- plato/datasources/gym.py +431 -0
- plato/datasources/huggingface.py +165 -0
- plato/datasources/kinetics.py +568 -0
- plato/datasources/mnist.py +44 -0
- plato/datasources/multimodal_base.py +328 -0
- plato/datasources/pascal_voc.py +56 -0
- plato/datasources/purchase.py +94 -0
- plato/datasources/qoenflx.py +127 -0
- plato/datasources/referitgame.py +330 -0
- plato/datasources/registry.py +119 -0
- plato/datasources/self_supervised_learning.py +98 -0
- plato/datasources/stl10.py +103 -0
- plato/datasources/texas.py +94 -0
- plato/datasources/tiny_imagenet.py +64 -0
- plato/datasources/yolov8.py +85 -0
- plato/models/__init__.py +0 -0
- plato/models/cnn_encoder.py +103 -0
- plato/models/dcgan.py +116 -0
- plato/models/general_multilayer.py +254 -0
- plato/models/huggingface.py +27 -0
- plato/models/lenet5.py +113 -0
- plato/models/multilayer.py +90 -0
- plato/models/multimodal/__init__.py +0 -0
- plato/models/multimodal/base_net.py +91 -0
- plato/models/multimodal/blending.py +142 -0
- plato/models/multimodal/fc_net.py +77 -0
- plato/models/multimodal/fusion_net.py +78 -0
- plato/models/multimodal/multimodal_module.py +152 -0
- plato/models/registry.py +99 -0
- plato/models/resnet.py +190 -0
- plato/models/torch_hub.py +19 -0
- plato/models/vgg.py +113 -0
- plato/models/vit.py +166 -0
- plato/models/yolov8.py +22 -0
- plato/processors/__init__.py +0 -0
- plato/processors/base.py +35 -0
- plato/processors/compress.py +46 -0
- plato/processors/decompress.py +48 -0
- plato/processors/feature.py +51 -0
- plato/processors/feature_additive_noise.py +48 -0
- plato/processors/feature_dequantize.py +34 -0
- plato/processors/feature_gaussian.py +17 -0
- plato/processors/feature_laplace.py +15 -0
- plato/processors/feature_quantize.py +34 -0
- plato/processors/feature_randomized_response.py +50 -0
- plato/processors/feature_unbatch.py +39 -0
- plato/processors/inbound_feature_tensors.py +39 -0
- plato/processors/model.py +55 -0
- plato/processors/model_compress.py +34 -0
- plato/processors/model_decompress.py +37 -0
- plato/processors/model_decrypt.py +41 -0
- plato/processors/model_deepcopy.py +21 -0
- plato/processors/model_dequantize.py +18 -0
- plato/processors/model_dequantize_qsgd.py +61 -0
- plato/processors/model_encrypt.py +43 -0
- plato/processors/model_quantize.py +18 -0
- plato/processors/model_quantize_qsgd.py +82 -0
- plato/processors/model_randomized_response.py +34 -0
- plato/processors/outbound_feature_ndarrays.py +38 -0
- plato/processors/pipeline.py +26 -0
- plato/processors/registry.py +124 -0
- plato/processors/structured_pruning.py +57 -0
- plato/processors/unstructured_pruning.py +73 -0
- plato/samplers/__init__.py +0 -0
- plato/samplers/all_inclusive.py +41 -0
- plato/samplers/base.py +31 -0
- plato/samplers/dirichlet.py +81 -0
- plato/samplers/distribution_noniid.py +132 -0
- plato/samplers/iid.py +53 -0
- plato/samplers/label_quantity_noniid.py +119 -0
- plato/samplers/mixed.py +44 -0
- plato/samplers/mixed_label_quantity_noniid.py +128 -0
- plato/samplers/modality_iid.py +42 -0
- plato/samplers/modality_quantity_noniid.py +56 -0
- plato/samplers/orthogonal.py +99 -0
- plato/samplers/registry.py +66 -0
- plato/samplers/sample_quantity_noniid.py +123 -0
- plato/samplers/sampler_utils.py +190 -0
- plato/servers/__init__.py +0 -0
- plato/servers/base.py +1395 -0
- plato/servers/fedavg.py +281 -0
- plato/servers/fedavg_cs.py +335 -0
- plato/servers/fedavg_gan.py +74 -0
- plato/servers/fedavg_he.py +106 -0
- plato/servers/fedavg_personalized.py +57 -0
- plato/servers/mistnet.py +67 -0
- plato/servers/registry.py +52 -0
- plato/servers/split_learning.py +109 -0
- plato/trainers/__init__.py +0 -0
- plato/trainers/base.py +99 -0
- plato/trainers/basic.py +649 -0
- plato/trainers/diff_privacy.py +178 -0
- plato/trainers/gan.py +330 -0
- plato/trainers/huggingface.py +173 -0
- plato/trainers/loss_criterion.py +70 -0
- plato/trainers/lr_schedulers.py +252 -0
- plato/trainers/optimizers.py +53 -0
- plato/trainers/pascal_voc.py +80 -0
- plato/trainers/registry.py +44 -0
- plato/trainers/self_supervised_learning.py +302 -0
- plato/trainers/split_learning.py +305 -0
- plato/trainers/tracking.py +96 -0
- plato/trainers/yolov8.py +41 -0
- plato/utils/__init__.py +0 -0
- plato/utils/count_parameters.py +30 -0
- plato/utils/csv_processor.py +26 -0
- plato/utils/data_loaders.py +148 -0
- plato/utils/decorators.py +24 -0
- plato/utils/fonts.py +23 -0
- plato/utils/homo_enc.py +187 -0
- plato/utils/reinforcement_learning/__init__.py +0 -0
- plato/utils/reinforcement_learning/policies/__init__.py +0 -0
- plato/utils/reinforcement_learning/policies/base.py +161 -0
- plato/utils/reinforcement_learning/policies/ddpg.py +75 -0
- plato/utils/reinforcement_learning/policies/registry.py +32 -0
- plato/utils/reinforcement_learning/policies/sac.py +343 -0
- plato/utils/reinforcement_learning/policies/td3.py +485 -0
- plato/utils/reinforcement_learning/rl_agent.py +142 -0
- plato/utils/reinforcement_learning/rl_server.py +113 -0
- plato/utils/rl_env.py +154 -0
- plato/utils/s3.py +141 -0
- plato/utils/trainer_utils.py +21 -0
- plato/utils/unary_encoding.py +47 -0
- plato_learn-1.1.dist-info/METADATA +35 -0
- plato_learn-1.1.dist-info/RECORD +179 -0
- plato_learn-1.1.dist-info/WHEEL +4 -0
- plato_learn-1.1.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,94 @@
|
|
1
|
+
"""
|
2
|
+
The Texas100 dataset.
|
3
|
+
"""
|
4
|
+
|
5
|
+
import os
|
6
|
+
import logging
|
7
|
+
import urllib
|
8
|
+
import tarfile
|
9
|
+
import torch
|
10
|
+
import numpy as np
|
11
|
+
from torch.utils import data
|
12
|
+
from plato.config import Config
|
13
|
+
from plato.datasources import base
|
14
|
+
|
15
|
+
|
16
|
+
class DataSource(base.DataSource):
    """The Texas100 dataset.

    On first use the raw archive is downloaded into ``Config().params["data_path"]``,
    parsed, and cached as ``texas_numpy.npz``; later runs only load that cache.
    """

    # Number of samples placed in each of the training and the test split.
    _SPLIT_SIZE = 20000

    def __init__(self):
        super().__init__()
        root_path = Config().params["data_path"]
        feat_path = os.path.join(root_path, "texas/100/feats")
        label_path = os.path.join(root_path, "texas/100/labels")

        # makedirs (unlike mkdir) also creates missing parent directories.
        os.makedirs(root_path, exist_ok=True)

        # extract_data() reads the processed .npz cache, so the absence of that
        # file (not merely of the raw features file) triggers (re)building it.
        if not os.path.isfile(os.path.join(root_path, "texas_numpy.npz")):
            self.download_dataset(root_path, feat_path, label_path)

        self.trainset, self.testset = self.extract_data(root_path)

    def download_dataset(self, root_path, feat_path, label_path):
        """Download (if needed) and preprocess the Texas100 dataset.

        The archive is fetched only when the raw ``feats``/``labels`` files are
        missing. The raw files are then parsed and saved to
        ``<root_path>/texas_numpy.npz`` with float64 features ``X`` and int32
        labels ``Y`` shifted down by one (assumes raw labels start at 1).
        """
        # `import urllib` alone does not guarantee `urllib.request` is loaded.
        import urllib.request

        if not (os.path.isfile(feat_path) and os.path.isfile(label_path)):
            logging.info("Downloading the Texas100 dataset...")
            filename = "https://www.comp.nus.edu.sg/~reza/files/dataset_texas.tgz"
            archive_path = os.path.join(root_path, "tmp_texas.tgz")
            urllib.request.urlretrieve(filename, archive_path)
            logging.info("Dataset downloaded.")
            # NOTE(review): extractall() trusts the member paths stored in the
            # downloaded archive; on Python >= 3.12 consider filter="data".
            with tarfile.open(archive_path) as tar:
                tar.extractall(path=root_path)
            # The archive is only a temporary download artifact.
            os.remove(archive_path)

        logging.info("Processing the dataset...")
        data_set_feats = np.genfromtxt(feat_path, delimiter=",")
        data_set_labels = np.genfromtxt(label_path, delimiter=",")
        logging.info("Finish processing the dataset.")

        X = data_set_feats.astype(np.float64)
        Y = data_set_labels.astype(np.int32) - 1
        np.savez(os.path.join(root_path, "texas_numpy.npz"), X=X, Y=Y)

    def extract_data(self, root_path):
        """Load the cached arrays and split them into train/test datasets.

        Samples are shuffled with a fixed seed (0). The first ``_SPLIT_SIZE``
        samples form the training set and the next ``_SPLIT_SIZE`` the test set.
        """
        # Use a context manager so the underlying .npz file handle is closed.
        with np.load(os.path.join(root_path, "texas_numpy.npz")) as arrays:
            features, labels = arrays["X"], arrays["Y"]

        # A private RandomState seeded with 0 yields the same permutation as
        # np.random.seed(0) + np.random.shuffle, without resetting the
        # process-wide NumPy RNG as a side effect.
        indices = np.arange(len(features))
        np.random.RandomState(0).shuffle(indices)
        features, labels = features[indices], labels[indices]

        size = self._SPLIT_SIZE
        train_dataset = VectorDataset(features[:size], labels[:size])
        test_dataset = VectorDataset(
            features[size : 2 * size], labels[size : 2 * size]
        )

        return train_dataset, test_dataset

    def num_train_examples(self):
        return self._SPLIT_SIZE

    def num_test_examples(self):
        return self._SPLIT_SIZE
|
78
|
+
|
79
|
+
|
80
|
+
class VectorDataset(data.Dataset):
    """
    Create a Texas100 dataset based on features and labels.

    Args:
        features: 2-D array-like of per-sample feature vectors; stored as float32.
        labels: 1-D array-like of integer class labels; stored as int64.
    """

    def __init__(self, features, labels):
        # Convert whole arrays at once instead of building and stacking one
        # tensor per sample; this also handles empty inputs (torch.stack([])
        # would raise).
        self.data = torch.as_tensor(np.asarray(features), dtype=torch.float32)
        self.targets = torch.as_tensor(np.asarray(labels), dtype=torch.long)
        self.classes = [f"Procedure #{i}" for i in range(100)]

    def __getitem__(self, index):
        return self.data[index], self.targets[index]

    def __len__(self):
        return self.data.size(0)
|
@@ -0,0 +1,64 @@
|
|
1
|
+
"""
|
2
|
+
The Tiny ImageNet 200 Classification dataset.
|
3
|
+
|
4
|
+
The Tiny ImageNet training set contains 100,000 images of 200 classes (500 per class),
|
5
|
+
downsized to 64×64 colored images.
|
6
|
+
Each class has 500 training images, 50 validation images and 50 test images.
|
7
|
+
"""
|
8
|
+
|
9
|
+
import logging
|
10
|
+
import os
|
11
|
+
|
12
|
+
from torchvision import datasets, transforms
|
13
|
+
|
14
|
+
from plato.config import Config
|
15
|
+
from plato.datasources import base
|
16
|
+
|
17
|
+
|
18
|
+
class DataSource(base.DataSource):
    """The Tiny ImageNet 200 dataset.

    Keyword arguments:
        train_transform: optional transform applied to training images.
        test_transform: optional transform applied to test images. When it is
            omitted, a caller-supplied ``train_transform`` is reused for the
            test set; otherwise a deterministic resize + center-crop pipeline
            is used so evaluation does not depend on random cropping.
    """

    def __init__(self, **kwargs):
        super().__init__()
        _path = Config().params["data_path"]

        if not os.path.exists(_path):
            logging.info(
                "Downloading the Tiny ImageNet 200 dataset. This may take a while."
            )
            url = (
                Config().data.download_url
                if hasattr(Config().data, "download_url")
                else "http://cs231n.stanford.edu/tiny-imagenet-200.zip"
            )
            DataSource.download(url, _path)

        normalize = transforms.Normalize(
            [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
        )

        if "train_transform" in kwargs:
            train_transform = kwargs["train_transform"]
            # Backward compatible: a custom training transform was previously
            # also applied to the test set.
            default_test_transform = train_transform
        else:
            # A CenterCrop(299) after RandomResizedCrop(299) was a no-op.
            train_transform = transforms.Compose(
                [
                    transforms.RandomResizedCrop(299),
                    transforms.ToTensor(),
                    normalize,
                ]
            )
            # Evaluation must be deterministic: no random cropping at test time.
            default_test_transform = transforms.Compose(
                [
                    transforms.Resize(299),
                    transforms.CenterCrop(299),
                    transforms.ToTensor(),
                    normalize,
                ]
            )
        test_transform = kwargs.get("test_transform", default_test_transform)

        self.trainset = datasets.ImageFolder(
            root=os.path.join(_path, "train"), transform=train_transform
        )
        # NOTE(review): ImageFolder derives labels from sub-directory names;
        # confirm the on-disk "test" folder is laid out one directory per class.
        self.testset = datasets.ImageFolder(
            root=os.path.join(_path, "test"), transform=test_transform
        )

    def num_train_examples(self):
        return 100000

    def num_test_examples(self):
        return 10000
|
@@ -0,0 +1,85 @@
|
|
1
|
+
"""
|
2
|
+
The COCO dataset or other datasets for the YOLOv8 model.
|
3
|
+
|
4
|
+
For more information about COCO 128, which contains the first 128 images of the
|
5
|
+
COCO 2017 dataset, refer to https://www.kaggle.com/ultralytics/coco128.
|
6
|
+
|
7
|
+
For more information about the COCO 2017 dataset, refer to http://cocodataset.org.
|
8
|
+
"""
|
9
|
+
|
10
|
+
from ultralytics.data.dataset import YOLODataset
|
11
|
+
from ultralytics.cfg import DEFAULT_CFG
|
12
|
+
from ultralytics.data.utils import check_det_dataset
|
13
|
+
from plato.config import Config
|
14
|
+
from plato.datasources import base
|
15
|
+
|
16
|
+
|
17
|
+
class DataSource(base.DataSource):
    """The YOLO dataset (COCO or another dataset in YOLO format)."""

    # pylint: disable=unused-argument
    def __init__(self, **kwargs):
        super().__init__()

        self.grid_size = Config().parameters.grid_size
        self.data = check_det_dataset(Config().data.data_params)
        self.train_set = None
        self.test_set = None

    def _build_dataset(self, img_path, augment):
        """Build a YOLODataset for ``img_path`` from the shared configuration."""
        return YOLODataset(
            img_path=img_path,
            imgsz=Config().data.image_size,
            batch_size=Config().trainer.batch_size,
            augment=augment,
            hyp=DEFAULT_CFG,
            rect=False,
            cache=False,
            single_cls=Config().parameters.model.num_classes == 1,
            stride=int(self.grid_size),
            pad=0.0,
            prefix="",
            use_segments=False,
            use_keypoints=False,
            # NOTE(review): classes() below describes Config().data.classes as
            # class names; verify that matches what YOLODataset expects here.
            classes=Config().data.classes,
            data=self.data,
        )

    def get_train_set(self):
        """Return the (lazily built and cached) training set."""
        if self.train_set is None:
            self.train_set = self._build_dataset(self.data["train"], augment=False)

        return self.train_set

    def get_test_set(self):
        """Return the (lazily built and cached) validation set used for testing."""
        if self.test_set is None:
            # Evaluation data must not be randomly augmented.
            self.test_set = self._build_dataset(self.data["val"], augment=False)

        return self.test_set

    def num_train_examples(self):
        return Config().data.num_train_examples

    def num_test_examples(self):
        return Config().data.num_test_examples

    def classes(self):
        """Obtains a list of class names in the dataset."""
        return Config().data.classes
|
plato/models/__init__.py
ADDED
File without changes
|
@@ -0,0 +1,103 @@
|
|
1
|
+
"""
|
2
|
+
A factory that generates fully convolutional neural network encoders.
|
3
|
+
|
4
|
+
The fully convolutional neural network used as the encoder in this implementation
|
5
|
+
is the backbone part of a model.
|
6
|
+
|
7
|
+
This encoder is solely capable of extracting features from the input sample.
|
8
|
+
Generally, in the context of classification tasks, this encoder has to cooperate
|
9
|
+
with the classification head to make the prediction.
|
10
|
+
|
11
|
+
Besides, the 'AdaptiveAvgPool2d' layer is included to support extracting
|
12
|
+
features with fixed dimensions.
|
13
|
+
|
14
|
+
"""
|
15
|
+
|
16
|
+
from typing import Optional, Dict
|
17
|
+
|
18
|
+
from torch import nn
|
19
|
+
import torchvision
|
20
|
+
|
21
|
+
from plato.models.lenet5 import Model as lenet5_model
|
22
|
+
from plato.models.vgg import Model as vgg_model
|
23
|
+
|
24
|
+
from plato.config import Config
|
25
|
+
|
26
|
+
|
27
|
+
class TruncatedLeNetModel(nn.Module):
    """LeNet-5 with its classification head removed, used as a feature encoder.

    The wrapped model's ``fc4``, ``relu4`` and ``fc5`` layers are replaced by
    ``nn.Identity`` modules, and each forward pass runs the wrapped model's
    ``forward_to`` with its cut layer set to ``"flatten"``.
    """

    def __init__(self, defined_lenet5_model):
        super().__init__()
        # Neutralize the head layers so their parameters are no longer used.
        for head_layer in ("fc4", "relu4", "fc5"):
            setattr(defined_lenet5_model, head_layer, nn.Identity())
        self.model = defined_lenet5_model

    def forward(self, samples):
        """Run the wrapped LeNet-5 via forward_to() with cut_layer = "flatten"."""
        lenet = self.model
        lenet.cut_layer = "flatten"
        return lenet.forward_to(samples)
|
41
|
+
|
42
|
+
|
43
|
+
class Model:
    """The encoder obtained by removing the final
    fully-connected blocks of the required model.
    """

    # pylint:disable=too-few-public-methods
    @staticmethod
    def get(model_name: Optional[str] = None, **kwargs: Dict[str, str]):
        """Returns an encoder that is a fully CNN block.

        Args:
            model_name: ``"lenet5"``, a name containing ``"vgg"`` (forwarded to
                the VGG factory), or one of ``"resnet_18"`` / ``"resnet_50"``.
            kwargs: an optional ``"datasource"`` entry overriding
                ``Config().data.datasource``; it is only consulted for ResNets.

        Returns:
            The encoder module, with its output feature size stored in the
            ``encoding_dim`` attribute.

        Raises:
            ValueError: if ``model_name`` is missing or not supported.
        """

        # The final fully-connected layer is removed, so the number of classes
        # does not affect the encoder; any constant works, and 10 is used.
        num_classes = 10

        if model_name == "lenet5":
            model = lenet5_model(num_classes=num_classes)
            # Read the encoder output dim before fc4 is replaced by Identity.
            encode_output_dim = model.fc4.in_features
            encoder = TruncatedLeNetModel(model)

        elif model_name and "vgg" in model_name:
            encoder = vgg_model.get(model_name=model_name, num_classes=num_classes)
            # The encoder output dim is the input size of the removed head.
            encode_output_dim = encoder.fc.in_features
            encoder.fc = nn.Identity()

        elif model_name and "resnet" in model_name:
            resnet_factories = {
                "resnet_18": "resnet18",
                "resnet_50": "resnet50",
            }
            if model_name not in resnet_factories:
                raise ValueError(
                    f"Unsupported ResNet variant '{model_name}'; "
                    f"expected one of {sorted(resnet_factories)}."
                )
            factory = getattr(torchvision.models, resnet_factories[model_name])
            encoder = factory(num_classes=num_classes)

            datasource = (
                kwargs["datasource"]
                if "datasource" in kwargs
                else Config().data.datasource
            )
            if "CIFAR" in datasource:
                # The structure specifically for CIFAR-based datasets
                # (see Section B.9 of the SimCLR paper): replace the 7x7
                # conv with a 3x3 stride-1 conv and remove the first max
                # pooling. padding=1 keeps the spatial size unchanged
                # (32x32 in -> 32x32 out).
                encoder.conv1 = nn.Conv2d(
                    3, 64, kernel_size=3, stride=1, padding=1, bias=False
                )
                encoder.maxpool = nn.Identity()

            # The encoder output dim is the input size of the removed head.
            encode_output_dim = encoder.fc.in_features
            encoder.fc = nn.Identity()

        else:
            raise ValueError(
                f"Unsupported encoder model name: {model_name!r}; expected "
                "'lenet5', a VGG variant, or 'resnet_18'/'resnet_50'."
            )

        encoder.encoding_dim = encode_output_dim

        return encoder
|
plato/models/dcgan.py
ADDED
@@ -0,0 +1,116 @@
|
|
1
|
+
"""
|
2
|
+
The DCGAN model.
|
3
|
+
|
4
|
+
Reference:
|
5
|
+
https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html
|
6
|
+
"""
|
7
|
+
|
8
|
+
from torch import nn
|
9
|
+
|
10
|
+
# Shared DCGAN hyperparameters.
# nz: channels of the latent input Z; nc: channels of the generated image.
nz, nc = 100, 3
# ngf / ndf: base feature-map widths of the generator / discriminator.
ngf, ndf = 64, 64
|
14
|
+
|
15
|
+
|
16
|
+
class Generator(nn.Module):
    """Generator network of DCGAN.

    Upsamples a latent tensor with ``nz`` channels into an image with ``nc``
    channels (64 x 64 per the stage comments below), with outputs in [-1, 1]
    from the final Tanh.
    """

    def __init__(self):
        super().__init__()

        # Project the latent input Z to an (ngf*8) x 4 x 4 feature map.
        layers = [
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
        ]
        # Each stage doubles the spatial size and halves the channels:
        # (ngf*8) 4x4 -> (ngf*4) 8x8 -> (ngf*2) 16x16 -> (ngf) 32x32.
        for scale in (8, 4, 2):
            layers.extend(
                (
                    nn.ConvTranspose2d(
                        ngf * scale, ngf * scale // 2, 4, 2, 1, bias=False
                    ),
                    nn.BatchNorm2d(ngf * scale // 2),
                    nn.ReLU(True),
                )
            )
        # Final upsampling to (nc) x 64 x 64.
        layers.extend((nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False), nn.Tanh()))

        self.main = nn.Sequential(*layers)

    def forward(self, input_data):
        """Forward pass."""
        return self.main(input_data)
|
48
|
+
|
49
|
+
|
50
|
+
class Discriminator(nn.Module):
    """Discriminator network of DCGAN.

    Reduces an (nc) x 64 x 64 image batch to one sigmoid score in (0, 1) per
    image (a 1 x 1 x 1 map per sample).
    """

    def __init__(self):
        super().__init__()

        # (nc) 64x64 -> (ndf) 32x32; no batch norm on the first layer.
        layers = [
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Each stage halves the spatial size and doubles the channels:
        # (ndf) 32x32 -> (ndf*2) 16x16 -> (ndf*4) 8x8 -> (ndf*8) 4x4.
        for scale in (1, 2, 4):
            layers.extend(
                (
                    nn.Conv2d(ndf * scale, ndf * scale * 2, 4, 2, 1, bias=False),
                    nn.BatchNorm2d(ndf * scale * 2),
                    nn.LeakyReLU(0.2, inplace=True),
                )
            )
        # Collapse the 4x4 map into a single score squashed by a sigmoid.
        layers.extend((nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False), nn.Sigmoid()))

        self.main = nn.Sequential(*layers)

    def forward(self, input_data):
        """Score a batch of images."""
        return self.main(input_data)
|
79
|
+
|
80
|
+
|
81
|
+
class Model:
    """A wrapper class to hold the Generator and Discriminator models of DCGAN.

    The device/mode helpers mirror ``nn.Module`` and return ``self``, so calls
    such as ``model = model.to(device)`` keep a usable reference.
    """

    def __init__(self) -> None:
        self.generator = Generator()
        self.discriminator = Discriminator()
        self.loss_criterion = nn.BCELoss()

        # Expose the module-level hyperparameters on the wrapper.
        self.nz = nz
        self.nc = nc
        self.ngf = ngf
        self.ndf = ndf

    def weights_init(self, model):
        """Initialize one submodule's parameters in place, DCGAN style.

        Modules whose class name contains "Conv" get weights ~ N(0, 0.02);
        those containing "BatchNorm" get weights ~ N(1, 0.02) and zero bias.
        Other modules, and missing parameters (e.g. BatchNorm with
        ``affine=False``), are left untouched.
        """
        classname = model.__class__.__name__
        weight = getattr(model, "weight", None)
        bias = getattr(model, "bias", None)
        if "Conv" in classname:
            if weight is not None:
                nn.init.normal_(weight.data, 0.0, 0.02)
        elif "BatchNorm" in classname:
            if weight is not None:
                nn.init.normal_(weight.data, 1.0, 0.02)
            if bias is not None:
                nn.init.constant_(bias.data, 0)

    def cpu(self):
        """Move both networks to the CPU; returns self."""
        self.generator.cpu()
        self.discriminator.cpu()
        return self

    def to(self, device):
        """Move both networks to ``device``; returns self."""
        self.generator.to(device)
        self.discriminator.to(device)
        return self

    def train(self, mode=True):
        """Set both networks to training mode (or eval if ``mode`` is False)."""
        self.generator.train(mode)
        self.discriminator.train(mode)
        return self

    def eval(self):
        """Set both networks to evaluation mode; returns self."""
        self.generator.eval()
        self.discriminator.eval()
        return self
|