plato_learn-1.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- plato/__init__.py +1 -0
- plato/algorithms/__init__.py +0 -0
- plato/algorithms/base.py +45 -0
- plato/algorithms/fedavg.py +48 -0
- plato/algorithms/fedavg_gan.py +79 -0
- plato/algorithms/fedavg_personalized.py +48 -0
- plato/algorithms/mistnet.py +52 -0
- plato/algorithms/registry.py +39 -0
- plato/algorithms/split_learning.py +89 -0
- plato/callbacks/__init__.py +0 -0
- plato/callbacks/client.py +56 -0
- plato/callbacks/handler.py +78 -0
- plato/callbacks/server.py +139 -0
- plato/callbacks/trainer.py +124 -0
- plato/client.py +67 -0
- plato/clients/__init__.py +0 -0
- plato/clients/base.py +467 -0
- plato/clients/edge.py +103 -0
- plato/clients/fedavg_personalized.py +40 -0
- plato/clients/mistnet.py +49 -0
- plato/clients/registry.py +43 -0
- plato/clients/self_supervised_learning.py +51 -0
- plato/clients/simple.py +218 -0
- plato/clients/split_learning.py +150 -0
- plato/config.py +339 -0
- plato/datasources/__init__.py +0 -0
- plato/datasources/base.py +123 -0
- plato/datasources/celeba.py +150 -0
- plato/datasources/cifar10.py +87 -0
- plato/datasources/cifar100.py +61 -0
- plato/datasources/cinic10.py +62 -0
- plato/datasources/coco.py +119 -0
- plato/datasources/datalib/__init__.py +0 -0
- plato/datasources/datalib/audio_extraction_tools.py +137 -0
- plato/datasources/datalib/data_utils.py +124 -0
- plato/datasources/datalib/flickr30kE_utils.py +336 -0
- plato/datasources/datalib/frames_extraction_tools.py +254 -0
- plato/datasources/datalib/gym_utils/__init__.py +0 -0
- plato/datasources/datalib/gym_utils/gym_trim.py +189 -0
- plato/datasources/datalib/modality_data_anntation_tools.py +163 -0
- plato/datasources/datalib/modality_extraction_base.py +59 -0
- plato/datasources/datalib/parse_datasets.py +212 -0
- plato/datasources/datalib/refer_utils/__init__.py +0 -0
- plato/datasources/datalib/refer_utils/referitgame_utils.py +237 -0
- plato/datasources/datalib/tiny_data_tools.py +81 -0
- plato/datasources/datalib/video_transform.py +79 -0
- plato/datasources/emnist.py +64 -0
- plato/datasources/fashion_mnist.py +41 -0
- plato/datasources/feature.py +24 -0
- plato/datasources/feature_dataset.py +15 -0
- plato/datasources/femnist.py +141 -0
- plato/datasources/flickr30k_entities.py +362 -0
- plato/datasources/gym.py +431 -0
- plato/datasources/huggingface.py +165 -0
- plato/datasources/kinetics.py +568 -0
- plato/datasources/mnist.py +44 -0
- plato/datasources/multimodal_base.py +328 -0
- plato/datasources/pascal_voc.py +56 -0
- plato/datasources/purchase.py +94 -0
- plato/datasources/qoenflx.py +127 -0
- plato/datasources/referitgame.py +330 -0
- plato/datasources/registry.py +119 -0
- plato/datasources/self_supervised_learning.py +98 -0
- plato/datasources/stl10.py +103 -0
- plato/datasources/texas.py +94 -0
- plato/datasources/tiny_imagenet.py +64 -0
- plato/datasources/yolov8.py +85 -0
- plato/models/__init__.py +0 -0
- plato/models/cnn_encoder.py +103 -0
- plato/models/dcgan.py +116 -0
- plato/models/general_multilayer.py +254 -0
- plato/models/huggingface.py +27 -0
- plato/models/lenet5.py +113 -0
- plato/models/multilayer.py +90 -0
- plato/models/multimodal/__init__.py +0 -0
- plato/models/multimodal/base_net.py +91 -0
- plato/models/multimodal/blending.py +142 -0
- plato/models/multimodal/fc_net.py +77 -0
- plato/models/multimodal/fusion_net.py +78 -0
- plato/models/multimodal/multimodal_module.py +152 -0
- plato/models/registry.py +99 -0
- plato/models/resnet.py +190 -0
- plato/models/torch_hub.py +19 -0
- plato/models/vgg.py +113 -0
- plato/models/vit.py +166 -0
- plato/models/yolov8.py +22 -0
- plato/processors/__init__.py +0 -0
- plato/processors/base.py +35 -0
- plato/processors/compress.py +46 -0
- plato/processors/decompress.py +48 -0
- plato/processors/feature.py +51 -0
- plato/processors/feature_additive_noise.py +48 -0
- plato/processors/feature_dequantize.py +34 -0
- plato/processors/feature_gaussian.py +17 -0
- plato/processors/feature_laplace.py +15 -0
- plato/processors/feature_quantize.py +34 -0
- plato/processors/feature_randomized_response.py +50 -0
- plato/processors/feature_unbatch.py +39 -0
- plato/processors/inbound_feature_tensors.py +39 -0
- plato/processors/model.py +55 -0
- plato/processors/model_compress.py +34 -0
- plato/processors/model_decompress.py +37 -0
- plato/processors/model_decrypt.py +41 -0
- plato/processors/model_deepcopy.py +21 -0
- plato/processors/model_dequantize.py +18 -0
- plato/processors/model_dequantize_qsgd.py +61 -0
- plato/processors/model_encrypt.py +43 -0
- plato/processors/model_quantize.py +18 -0
- plato/processors/model_quantize_qsgd.py +82 -0
- plato/processors/model_randomized_response.py +34 -0
- plato/processors/outbound_feature_ndarrays.py +38 -0
- plato/processors/pipeline.py +26 -0
- plato/processors/registry.py +124 -0
- plato/processors/structured_pruning.py +57 -0
- plato/processors/unstructured_pruning.py +73 -0
- plato/samplers/__init__.py +0 -0
- plato/samplers/all_inclusive.py +41 -0
- plato/samplers/base.py +31 -0
- plato/samplers/dirichlet.py +81 -0
- plato/samplers/distribution_noniid.py +132 -0
- plato/samplers/iid.py +53 -0
- plato/samplers/label_quantity_noniid.py +119 -0
- plato/samplers/mixed.py +44 -0
- plato/samplers/mixed_label_quantity_noniid.py +128 -0
- plato/samplers/modality_iid.py +42 -0
- plato/samplers/modality_quantity_noniid.py +56 -0
- plato/samplers/orthogonal.py +99 -0
- plato/samplers/registry.py +66 -0
- plato/samplers/sample_quantity_noniid.py +123 -0
- plato/samplers/sampler_utils.py +190 -0
- plato/servers/__init__.py +0 -0
- plato/servers/base.py +1395 -0
- plato/servers/fedavg.py +281 -0
- plato/servers/fedavg_cs.py +335 -0
- plato/servers/fedavg_gan.py +74 -0
- plato/servers/fedavg_he.py +106 -0
- plato/servers/fedavg_personalized.py +57 -0
- plato/servers/mistnet.py +67 -0
- plato/servers/registry.py +52 -0
- plato/servers/split_learning.py +109 -0
- plato/trainers/__init__.py +0 -0
- plato/trainers/base.py +99 -0
- plato/trainers/basic.py +649 -0
- plato/trainers/diff_privacy.py +178 -0
- plato/trainers/gan.py +330 -0
- plato/trainers/huggingface.py +173 -0
- plato/trainers/loss_criterion.py +70 -0
- plato/trainers/lr_schedulers.py +252 -0
- plato/trainers/optimizers.py +53 -0
- plato/trainers/pascal_voc.py +80 -0
- plato/trainers/registry.py +44 -0
- plato/trainers/self_supervised_learning.py +302 -0
- plato/trainers/split_learning.py +305 -0
- plato/trainers/tracking.py +96 -0
- plato/trainers/yolov8.py +41 -0
- plato/utils/__init__.py +0 -0
- plato/utils/count_parameters.py +30 -0
- plato/utils/csv_processor.py +26 -0
- plato/utils/data_loaders.py +148 -0
- plato/utils/decorators.py +24 -0
- plato/utils/fonts.py +23 -0
- plato/utils/homo_enc.py +187 -0
- plato/utils/reinforcement_learning/__init__.py +0 -0
- plato/utils/reinforcement_learning/policies/__init__.py +0 -0
- plato/utils/reinforcement_learning/policies/base.py +161 -0
- plato/utils/reinforcement_learning/policies/ddpg.py +75 -0
- plato/utils/reinforcement_learning/policies/registry.py +32 -0
- plato/utils/reinforcement_learning/policies/sac.py +343 -0
- plato/utils/reinforcement_learning/policies/td3.py +485 -0
- plato/utils/reinforcement_learning/rl_agent.py +142 -0
- plato/utils/reinforcement_learning/rl_server.py +113 -0
- plato/utils/rl_env.py +154 -0
- plato/utils/s3.py +141 -0
- plato/utils/trainer_utils.py +21 -0
- plato/utils/unary_encoding.py +47 -0
- plato_learn-1.1.dist-info/METADATA +35 -0
- plato_learn-1.1.dist-info/RECORD +179 -0
- plato_learn-1.1.dist-info/WHEEL +4 -0
- plato_learn-1.1.dist-info/licenses/LICENSE +201 -0
plato/config.py
ADDED
@@ -0,0 +1,339 @@
"""
Reading runtime parameters from a standard configuration file (which is easier
to work on than JSON).
"""

import argparse
import json
import logging
import os
from collections import OrderedDict, namedtuple
from pathlib import Path
from typing import IO, Any

import numpy as np
import yaml


class Loader(yaml.SafeLoader):
    """YAML Loader with `!include` constructor."""

    def __init__(self, stream: IO) -> None:
        """Initialise Loader."""

        try:
            self.root_path = os.path.split(stream.name)[0]
        except AttributeError:
            self.root_path = os.path.curdir

        super().__init__(stream)


class Config:
    """
    Retrieving configuration parameters by parsing a configuration file
    using the YAML configuration file parser.
    """

    _instance = None

    @staticmethod
    def construct_include(loader: Loader, node: yaml.Node) -> Any:
        """Include the file referenced at the node."""
        with open(
            Path(loader.name)
            .parent.joinpath(loader.construct_yaml_str(node))
            .resolve(),
            "r",
        ) as f:
            return yaml.load(f, type(loader))

    def __new__(cls):
        if cls._instance is None:
            parser = argparse.ArgumentParser()
            parser.add_argument("-i", "--id", type=str, help="Unique client ID.")
            parser.add_argument(
                "-p", "--port", type=str, help="The port number for running a server."
            )
            parser.add_argument(
                "-c",
                "--config",
                type=str,
                default="./config.yml",
                help="Federated learning configuration file.",
            )
            parser.add_argument(
                "-b",
                "--base",
                type=str,
                default="./",
                help="The base path for datasets and models.",
            )
            parser.add_argument(
                "-s",
                "--server",
                type=str,
                default=None,
                help="The server hostname and port number.",
            )
            parser.add_argument(
                "-u", "--cpu", action="store_true", help="Use CPU as the device."
            )
            parser.add_argument(
                "-m", "--mps", action="store_true", help="Use MPS as the device."
            )
            parser.add_argument(
                "-d",
                "--download",
                action="store_true",
                help="Download the dataset to prepare for a training session.",
            )
            parser.add_argument(
                "-r",
                "--resume",
                action="store_true",
                help="Resume a previously interrupted training session.",
            )
            parser.add_argument(
                "-l", "--log", type=str, default="info", help="Log messages level."
            )

            args = parser.parse_args()
            Config.args = args

            if Config.args.id is not None:
                Config.args.id = int(args.id)
            if Config.args.port is not None:
                Config.args.port = int(args.port)

            numeric_level = getattr(logging, args.log.upper(), None)

            if not isinstance(numeric_level, int):
                raise ValueError(f"Invalid log level: {args.log}")

            logging.basicConfig(
                format="[%(levelname)s][%(asctime)s]: %(message)s", datefmt="%H:%M:%S"
            )

            root_logger = logging.getLogger()
            root_logger.setLevel(numeric_level)

            cls._instance = super(Config, cls).__new__(cls)

            if "config_file" in os.environ:
                filename = os.environ["config_file"]
            else:
                filename = args.config

            yaml.add_constructor("!include", Config.construct_include, Loader)

            if os.path.isfile(filename):
                with open(filename, "r", encoding="utf-8") as config_file:
                    config = yaml.load(config_file, Loader)
            else:
                # If the configuration file does not exist, raise an error
                raise ValueError("A configuration file must be supplied.")

            Config.clients = Config.namedtuple_from_dict(config["clients"])
            Config.server = Config.namedtuple_from_dict(config["server"])
            Config.data = Config.namedtuple_from_dict(config["data"])
            Config.trainer = Config.namedtuple_from_dict(config["trainer"])
            Config.algorithm = Config.namedtuple_from_dict(config["algorithm"])

            if Config.args.server is not None:
                Config.server = Config.server._replace(
                    address=args.server.split(":")[0]
                )
                Config.server = Config.server._replace(port=args.server.split(":")[1])

            if Config.args.download:
                Config.clients = Config.clients._replace(total_clients=1)
                Config.clients = Config.clients._replace(per_round=1)

            if (
                hasattr(Config.clients, "speed_simulation")
                and Config.clients.speed_simulation
            ):
                Config.simulate_client_speed()

            # Customizable dictionary of global parameters
            Config.params: dict = {}

            # A run ID is unique to each client in an experiment
            Config.params["run_id"] = os.getpid()

            # The base path used for all datasets, models, checkpoints, and results
            Config.params["base_path"] = Config.args.base

            if "general" in config:
                Config.general = Config.namedtuple_from_dict(config["general"])

                if hasattr(Config.general, "base_path"):
                    Config.params["base_path"] = Config().general.base_path

            # Directory of the dataset
            if hasattr(Config().data, "data_path"):
                Config.params["data_path"] = os.path.join(
                    Config.params["base_path"], Config().data.data_path
                )
            else:
                Config.params["data_path"] = os.path.join(
                    Config.params["base_path"], "data"
                )

            # Pretrained models
            if hasattr(Config().server, "model_path"):
                Config.params["model_path"] = os.path.join(
                    Config.params["base_path"], Config().server.model_path
                )
            else:
                Config.params["model_path"] = os.path.join(
                    Config.params["base_path"], "models/pretrained"
                )
            os.makedirs(Config.params["model_path"], exist_ok=True)

            # Resume checkpoint
            if hasattr(Config().server, "checkpoint_path"):
                Config.params["checkpoint_path"] = os.path.join(
                    Config.params["base_path"], Config().server.checkpoint_path
                )
            else:
                Config.params["checkpoint_path"] = os.path.join(
                    Config.params["base_path"], "checkpoints"
                )
            os.makedirs(Config.params["checkpoint_path"], exist_ok=True)

            if "results" in config:
                Config.results = Config.namedtuple_from_dict(config["results"])

            # Directory of the .csv file containing results
            if hasattr(Config, "results") and hasattr(Config.results, "result_path"):
                Config.params["result_path"] = os.path.join(
                    Config.params["base_path"], Config.results.result_path
                )
            else:
                Config.params["result_path"] = os.path.join(
                    Config.params["base_path"], "results"
                )
            os.makedirs(Config.params["result_path"], exist_ok=True)

            # The set of columns in the .csv file
            if hasattr(Config, "results") and hasattr(Config.results, "types"):
                Config.params["result_types"] = Config.results.types
            else:
                Config.params["result_types"] = "round, accuracy, elapsed_time"

            # The set of pairs to be plotted
            if hasattr(Config, "results") and hasattr(Config.results, "plot"):
                Config.params["plot_pairs"] = Config().results.plot
            else:
                Config.params["plot_pairs"] = "round-accuracy, elapsed_time-accuracy"

            if "parameters" in config:
                Config.parameters = Config.namedtuple_from_dict(config["parameters"])

        return cls._instance

    @staticmethod
    def namedtuple_from_dict(obj):
        """Creates a named tuple from a dictionary."""
        if isinstance(obj, dict):
            fields = sorted(obj.keys())
            namedtuple_type = namedtuple(
                typename="Config", field_names=fields, rename=True
            )
            field_value_pairs = OrderedDict(
                (str(field), Config.namedtuple_from_dict(obj[field]))
                for field in fields
            )
            try:
                return namedtuple_type(**field_value_pairs)
            except TypeError:
                # Cannot create a namedtuple instance (invalid attribute names),
                # so fall back to a dict
                return dict(**field_value_pairs)
        elif isinstance(obj, (list, set, tuple, frozenset)):
            return [Config.namedtuple_from_dict(item) for item in obj]
        else:
            return obj

    @staticmethod
    def simulate_client_speed() -> None:
        """Randomly generates a sleep time (in seconds per epoch) for each of the clients."""
        # A random seed must be supplied to make sure that all the clients generate
        # the same set of sleep times per epoch across the board
        if hasattr(Config.clients, "random_seed"):
            np.random.seed(Config.clients.random_seed)
        else:
            np.random.seed(1)

        # Limit the simulated sleep time by the threshold 'max_sleep_time'
        max_sleep_time = 60
        if hasattr(Config.clients, "max_sleep_time"):
            max_sleep_time = Config.clients.max_sleep_time

        # Guard with getattr so that a missing 'simulation_distribution' falls
        # through to the default distribution instead of raising AttributeError
        dist = getattr(Config.clients, "simulation_distribution", None)
        total_clients = Config.clients.total_clients
        sleep_times = []

        if dist is not None:
            if dist.distribution.lower() == "normal":
                sleep_times = np.random.normal(dist.mean, dist.sd, size=total_clients)
            if dist.distribution.lower() == "pareto":
                sleep_times = np.random.pareto(dist.alpha, size=total_clients)
            if dist.distribution.lower() == "zipf":
                sleep_times = np.random.zipf(dist.s, size=total_clients)
            if dist.distribution.lower() == "uniform":
                sleep_times = np.random.uniform(dist.low, dist.high, size=total_clients)
        else:
            # By default, use the Pareto distribution with a parameter of 1.0
            sleep_times = np.random.pareto(1.0, size=total_clients)

        Config.client_sleep_times = np.minimum(
            sleep_times, np.repeat(max_sleep_time, total_clients)
        )

    @staticmethod
    def is_edge_server() -> bool:
        """Returns whether the current instance is an edge server in cross-silo FL."""
        return Config().args.port is not None

    @staticmethod
    def is_central_server() -> bool:
        """Returns whether the current instance is a central server in cross-silo FL."""
        return hasattr(Config().algorithm, "cross_silo") and Config().args.port is None

    @staticmethod
    def gpu_count() -> int:
        """Returns the number of GPUs available for training."""

        import torch

        if torch.cuda.is_available():
            return torch.cuda.device_count()
        elif Config.args.mps and torch.backends.mps.is_built():
            return 1
        else:
            return 0

    @staticmethod
    def device() -> str:
        """Returns the device to be used for training."""
        device = "cpu"

        if Config.args.cpu:
            return device

        import torch

        if torch.cuda.is_available() and torch.cuda.device_count() > 0:
            if Config.gpu_count() > 1 and isinstance(Config.args.id, int):
                # A client will always run on the same GPU
                gpu_id = Config.args.id % torch.cuda.device_count()
                device = f"cuda:{gpu_id}"
            else:
                device = "cuda:0"

        if Config.args.mps and torch.backends.mps.is_built():
            device = "mps"

        return device
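For orientation, a minimal usage sketch follows. It is an illustration, not part of the package: it assumes plato-learn and PyTorch are installed and that a config.yml is present with the clients, server, data, trainer, and algorithm sections the parser above requires (the keys shown in the comments are ones the code above actually reads; the concrete values are made up).

# Sketch: driving the Config singleton from a Plato entry point.
# A config.yml along these lines is assumed (values are illustrative):
#
#   clients:
#       total_clients: 2
#       per_round: 2
#   server:
#       address: 127.0.0.1
#       port: 8000
#   data: {}         # plus whatever keys the session needs
#   trainer: {}
#   algorithm: {}
#
from plato.config import Config

Config()                                 # the first call parses argv and the YAML file
print(Config.clients.total_clients)      # each top-level section becomes a namedtuple
print(Config.params["checkpoint_path"])  # derived paths are created during __new__
print(Config.device())                   # "cuda:<id>", "mps", or "cpu"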
plato/datasources/__init__.py
File without changes

plato/datasources/base.py
ADDED
@@ -0,0 +1,123 @@
"""
Base class for data sources, encapsulating training and testing datasets with
custom augmentations and transforms already accommodated.
"""

import gzip
import logging
import os
import sys
import tarfile
import zipfile
from urllib.parse import urlparse

import requests

from plato.config import Config


class DataSource:
    """
    Training and testing datasets with custom augmentations and transforms
    already accommodated.
    """

    def __init__(self):
        self.trainset = None
        self.testset = None

    @staticmethod
    def download(url, data_path):
        """Downloads a dataset from a URL."""
        if not os.path.exists(data_path):
            if Config().clients.total_clients > 1:
                if (
                    not hasattr(Config().data, "concurrent_download")
                    or not Config().data.concurrent_download
                ):
                    raise ValueError(
                        "The dataset has not yet been downloaded from the Internet. "
                        "Please re-run with '-d' or '--download' first. "
                    )

            os.makedirs(data_path, exist_ok=True)

        url_parse = urlparse(url)
        file_name = os.path.join(data_path, url_parse.path.split("/")[-1])

        if not os.path.exists(file_name.replace(".gz", "")):
            logging.info("Downloading %s.", url)

            res = requests.get(url, verify=False, stream=True)
            total_size = int(res.headers["Content-Length"])
            downloaded_size = 0

            with open(file_name, "wb+") as file:
                for chunk in res.iter_content(chunk_size=1024):
                    downloaded_size += len(chunk)
                    file.write(chunk)
                    file.flush()
                    sys.stdout.write(
                        "\r{:.1f}%".format(100 * downloaded_size / total_size)
                    )
                    sys.stdout.flush()
                sys.stdout.write("\n")

            # Decompress the file just downloaded
            logging.info("Decompressing the dataset downloaded.")
            name, suffix = os.path.splitext(file_name)

            if file_name.endswith("tar.gz"):
                tar = tarfile.open(file_name, "r:gz")
                tar.extractall(data_path)
                tar.close()
                os.remove(file_name)
            elif suffix == ".zip":
                logging.info("Extracting %s to %s.", file_name, data_path)
                with zipfile.ZipFile(file_name, "r") as zip_ref:
                    zip_ref.extractall(data_path)
            elif suffix == ".gz":
                # Use context managers so both file handles are closed reliably
                with gzip.GzipFile(file_name) as zipped_file:
                    with open(name, "wb") as unzipped_file:
                        unzipped_file.write(zipped_file.read())
                os.remove(file_name)
            else:
                logging.info("Unknown compressed file type.")
                sys.exit()

        if Config().args.download:
            logging.info(
                "The dataset has been successfully downloaded. "
                "Re-run the experiment without '-d' or '--download'."
            )
            sys.exit()

    @staticmethod
    def input_shape():
        """Obtains the input shape of this data source."""
        raise NotImplementedError("Input shape not specified for this data source.")

    def num_train_examples(self) -> int:
        """Obtains the number of training examples."""
        return len(self.trainset)

    def num_test_examples(self) -> int:
        """Obtains the number of testing examples."""
        return len(self.testset)

    def classes(self):
        """Obtains a list of class names in the dataset."""
        return list(self.trainset.classes)

    def targets(self):
        """Obtains a list of targets (labels) for all the examples
        in the dataset."""
        return self.trainset.targets

    def get_train_set(self):
        """Obtains the training dataset."""
        return self.trainset

    def get_test_set(self):
        """Obtains the testing dataset."""
        return self.testset
plato/datasources/celeba.py
ADDED
@@ -0,0 +1,150 @@
"""
The CelebA dataset from the torchvision package.
"""

import logging
import os
from typing import Callable, List, Optional, Union

import torch
from torchvision import datasets, transforms

from plato.config import Config
from plato.datasources import base


class CelebA(datasets.CelebA):
    """
    A wrapper class of torchvision's CelebA dataset class that adds
    <targets> and <classes> attributes based on celebrity identity,
    which are used by the non-IID samplers.
    """

    def __init__(
        self,
        root: str,
        split: str = "train",
        target_type: Union[List[str], str] = "attr",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super().__init__(
            root, split, target_type, transform, target_transform, download
        )
        self.targets = self.identity.flatten().tolist()
        self.classes = [f"Celebrity #{i}" for i in range(10177 + 1)]


class DataSource(base.DataSource):
    """The CelebA dataset."""

    def __init__(self, **kwargs):
        super().__init__()
        _path = Config().params["data_path"]

        if not os.path.exists(os.path.join(_path, "celeba")):
            celeba_url = "http://iqua.ece.toronto.edu/baochun/celeba.tar.gz"
            DataSource.download(celeba_url, _path)
        else:
            logging.info(
                "CelebA data already decompressed under %s",
                os.path.join(_path, "celeba"),
            )

        target_types = []
        if hasattr(Config().data, "celeba_targets"):
            targets = Config().data.celeba_targets
            if hasattr(targets, "attr") and targets.attr:
                target_types.append("attr")
            if hasattr(targets, "identity") and targets.identity:
                target_types.append("identity")
        else:
            target_types = ["attr", "identity"]

        image_size = 64
        if hasattr(Config().data, "celeba_img_size"):
            image_size = Config().data.celeba_img_size

        train_transform = (
            kwargs["train_transform"]
            if "train_transform" in kwargs
            else transforms.Compose(
                [
                    transforms.Resize(image_size),
                    transforms.CenterCrop(image_size),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ]
            )
        )

        test_transform = train_transform

        target_transform = (
            kwargs["target_transform"]
            if "target_transform" in kwargs
            else (DataSource._target_transform if target_types else None)
        )

        self.trainset = CelebA(
            root=_path,
            split="train",
            target_type=target_types,
            download=False,
            transform=train_transform,
            target_transform=target_transform,
        )
        self.testset = CelebA(
            root=_path,
            split="test",
            target_type=target_types,
            download=False,
            transform=test_transform,
            target_transform=target_transform,
        )

    @staticmethod
    def _target_transform(label):
        """
        Output labels arrive as a tuple of tensors when more than one target
        type is specified, so we need to convert the tuple into a single
        tensor. Here, we simply merge the two tensors by appending identity
        as the 41st attribute.
        """
        if isinstance(label, tuple):
            if len(label) == 1:
                return label[0]
            elif len(label) == 2:
                attr, identity = label
                return torch.cat((attr.reshape([-1]), identity.reshape([-1])))
        else:
            return label

    @staticmethod
    def input_shape():
        image_size = 64
        if hasattr(Config().data, "celeba_img_size"):
            image_size = Config().data.celeba_img_size
        return [162770, 3, image_size, image_size]

    def num_train_examples(self):
        return 162770

    def num_test_examples(self):
        return 19962
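A quick toy check of _target_transform (a sketch, not package code; the example values are made up and only the tensor shapes matter): with both "attr" and "identity" requested, the 40 binary attributes and the identity label are merged into one 41-element tensor.

import torch

attr = torch.zeros(40, dtype=torch.int64)        # the 40 binary CelebA attributes
identity = torch.tensor(123, dtype=torch.int64)  # a 0-dim celebrity-identity label

# Mirrors the len(label) == 2 branch above
merged = torch.cat((attr.reshape([-1]), identity.reshape([-1])))
print(merged.shape)  # torch.Size([41]): identity appended as the 41st entry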
plato/datasources/cifar10.py
ADDED
@@ -0,0 +1,87 @@
"""
The CIFAR-10 dataset from the torchvision package.
"""

import logging
import os
import sys

from torchvision import datasets, transforms

from plato.config import Config
from plato.datasources import base


class DataSource(base.DataSource):
    """The CIFAR-10 dataset."""

    def __init__(self, **kwargs):
        super().__init__()

        train_transform = (
            kwargs["train_transform"]
            if "train_transform" in kwargs
            else transforms.Compose(
                [
                    transforms.RandomHorizontalFlip(),
                    transforms.RandomCrop(32, 4),
                    transforms.ToTensor(),
                    transforms.Normalize(
                        [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
                    ),
                ]
            )
        )

        test_transform = (
            kwargs["test_transform"]
            if "test_transform" in kwargs
            else transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Normalize(
                        [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
                    ),
                ]
            )
        )

        _path = Config().params["data_path"]

        if not os.path.exists(_path):
            if hasattr(Config().server, "do_test") and not Config().server.do_test:
                # If the server is not performing local tests for accuracy, concurrent
                # downloading on the clients may lead to PyTorch errors
                if Config().clients.total_clients > 1:
                    if (
                        not hasattr(Config().data, "concurrent_download")
                        or not Config().data.concurrent_download
                    ):
                        raise ValueError(
                            "The dataset has not yet been downloaded from the Internet. "
                            "Please re-run with '-d' or '--download' first. "
                        )

        self.trainset = datasets.CIFAR10(
            root=_path, train=True, download=True, transform=train_transform
        )
        self.testset = datasets.CIFAR10(
            root=_path, train=False, download=True, transform=test_transform
        )

        if Config().args.download:
            logging.info(
                "The dataset has been successfully downloaded. "
                "Re-run the experiment without '-d' or '--download'."
            )
            sys.exit()

    def num_train_examples(self):
        return 50000

    def num_test_examples(self):
        return 10000
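A hedged usage sketch (not package code): because the constructor reads Config().params["data_path"] and Config().args.download, it only works in a process where the Config singleton has already been initialized, e.g. inside a Plato entry point; the snippet below assumes that context.

from plato.datasources import cifar10

datasource = cifar10.DataSource()       # downloads via torchvision on first use
print(datasource.num_train_examples())  # 50000
print(datasource.num_test_examples())   # 10000
print(datasource.classes())             # the ten CIFAR-10 class names, via the base class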