plato_learn-1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (179)
  1. plato/__init__.py +1 -0
  2. plato/algorithms/__init__.py +0 -0
  3. plato/algorithms/base.py +45 -0
  4. plato/algorithms/fedavg.py +48 -0
  5. plato/algorithms/fedavg_gan.py +79 -0
  6. plato/algorithms/fedavg_personalized.py +48 -0
  7. plato/algorithms/mistnet.py +52 -0
  8. plato/algorithms/registry.py +39 -0
  9. plato/algorithms/split_learning.py +89 -0
  10. plato/callbacks/__init__.py +0 -0
  11. plato/callbacks/client.py +56 -0
  12. plato/callbacks/handler.py +78 -0
  13. plato/callbacks/server.py +139 -0
  14. plato/callbacks/trainer.py +124 -0
  15. plato/client.py +67 -0
  16. plato/clients/__init__.py +0 -0
  17. plato/clients/base.py +467 -0
  18. plato/clients/edge.py +103 -0
  19. plato/clients/fedavg_personalized.py +40 -0
  20. plato/clients/mistnet.py +49 -0
  21. plato/clients/registry.py +43 -0
  22. plato/clients/self_supervised_learning.py +51 -0
  23. plato/clients/simple.py +218 -0
  24. plato/clients/split_learning.py +150 -0
  25. plato/config.py +339 -0
  26. plato/datasources/__init__.py +0 -0
  27. plato/datasources/base.py +123 -0
  28. plato/datasources/celeba.py +150 -0
  29. plato/datasources/cifar10.py +87 -0
  30. plato/datasources/cifar100.py +61 -0
  31. plato/datasources/cinic10.py +62 -0
  32. plato/datasources/coco.py +119 -0
  33. plato/datasources/datalib/__init__.py +0 -0
  34. plato/datasources/datalib/audio_extraction_tools.py +137 -0
  35. plato/datasources/datalib/data_utils.py +124 -0
  36. plato/datasources/datalib/flickr30kE_utils.py +336 -0
  37. plato/datasources/datalib/frames_extraction_tools.py +254 -0
  38. plato/datasources/datalib/gym_utils/__init__.py +0 -0
  39. plato/datasources/datalib/gym_utils/gym_trim.py +189 -0
  40. plato/datasources/datalib/modality_data_anntation_tools.py +163 -0
  41. plato/datasources/datalib/modality_extraction_base.py +59 -0
  42. plato/datasources/datalib/parse_datasets.py +212 -0
  43. plato/datasources/datalib/refer_utils/__init__.py +0 -0
  44. plato/datasources/datalib/refer_utils/referitgame_utils.py +237 -0
  45. plato/datasources/datalib/tiny_data_tools.py +81 -0
  46. plato/datasources/datalib/video_transform.py +79 -0
  47. plato/datasources/emnist.py +64 -0
  48. plato/datasources/fashion_mnist.py +41 -0
  49. plato/datasources/feature.py +24 -0
  50. plato/datasources/feature_dataset.py +15 -0
  51. plato/datasources/femnist.py +141 -0
  52. plato/datasources/flickr30k_entities.py +362 -0
  53. plato/datasources/gym.py +431 -0
  54. plato/datasources/huggingface.py +165 -0
  55. plato/datasources/kinetics.py +568 -0
  56. plato/datasources/mnist.py +44 -0
  57. plato/datasources/multimodal_base.py +328 -0
  58. plato/datasources/pascal_voc.py +56 -0
  59. plato/datasources/purchase.py +94 -0
  60. plato/datasources/qoenflx.py +127 -0
  61. plato/datasources/referitgame.py +330 -0
  62. plato/datasources/registry.py +119 -0
  63. plato/datasources/self_supervised_learning.py +98 -0
  64. plato/datasources/stl10.py +103 -0
  65. plato/datasources/texas.py +94 -0
  66. plato/datasources/tiny_imagenet.py +64 -0
  67. plato/datasources/yolov8.py +85 -0
  68. plato/models/__init__.py +0 -0
  69. plato/models/cnn_encoder.py +103 -0
  70. plato/models/dcgan.py +116 -0
  71. plato/models/general_multilayer.py +254 -0
  72. plato/models/huggingface.py +27 -0
  73. plato/models/lenet5.py +113 -0
  74. plato/models/multilayer.py +90 -0
  75. plato/models/multimodal/__init__.py +0 -0
  76. plato/models/multimodal/base_net.py +91 -0
  77. plato/models/multimodal/blending.py +142 -0
  78. plato/models/multimodal/fc_net.py +77 -0
  79. plato/models/multimodal/fusion_net.py +78 -0
  80. plato/models/multimodal/multimodal_module.py +152 -0
  81. plato/models/registry.py +99 -0
  82. plato/models/resnet.py +190 -0
  83. plato/models/torch_hub.py +19 -0
  84. plato/models/vgg.py +113 -0
  85. plato/models/vit.py +166 -0
  86. plato/models/yolov8.py +22 -0
  87. plato/processors/__init__.py +0 -0
  88. plato/processors/base.py +35 -0
  89. plato/processors/compress.py +46 -0
  90. plato/processors/decompress.py +48 -0
  91. plato/processors/feature.py +51 -0
  92. plato/processors/feature_additive_noise.py +48 -0
  93. plato/processors/feature_dequantize.py +34 -0
  94. plato/processors/feature_gaussian.py +17 -0
  95. plato/processors/feature_laplace.py +15 -0
  96. plato/processors/feature_quantize.py +34 -0
  97. plato/processors/feature_randomized_response.py +50 -0
  98. plato/processors/feature_unbatch.py +39 -0
  99. plato/processors/inbound_feature_tensors.py +39 -0
  100. plato/processors/model.py +55 -0
  101. plato/processors/model_compress.py +34 -0
  102. plato/processors/model_decompress.py +37 -0
  103. plato/processors/model_decrypt.py +41 -0
  104. plato/processors/model_deepcopy.py +21 -0
  105. plato/processors/model_dequantize.py +18 -0
  106. plato/processors/model_dequantize_qsgd.py +61 -0
  107. plato/processors/model_encrypt.py +43 -0
  108. plato/processors/model_quantize.py +18 -0
  109. plato/processors/model_quantize_qsgd.py +82 -0
  110. plato/processors/model_randomized_response.py +34 -0
  111. plato/processors/outbound_feature_ndarrays.py +38 -0
  112. plato/processors/pipeline.py +26 -0
  113. plato/processors/registry.py +124 -0
  114. plato/processors/structured_pruning.py +57 -0
  115. plato/processors/unstructured_pruning.py +73 -0
  116. plato/samplers/__init__.py +0 -0
  117. plato/samplers/all_inclusive.py +41 -0
  118. plato/samplers/base.py +31 -0
  119. plato/samplers/dirichlet.py +81 -0
  120. plato/samplers/distribution_noniid.py +132 -0
  121. plato/samplers/iid.py +53 -0
  122. plato/samplers/label_quantity_noniid.py +119 -0
  123. plato/samplers/mixed.py +44 -0
  124. plato/samplers/mixed_label_quantity_noniid.py +128 -0
  125. plato/samplers/modality_iid.py +42 -0
  126. plato/samplers/modality_quantity_noniid.py +56 -0
  127. plato/samplers/orthogonal.py +99 -0
  128. plato/samplers/registry.py +66 -0
  129. plato/samplers/sample_quantity_noniid.py +123 -0
  130. plato/samplers/sampler_utils.py +190 -0
  131. plato/servers/__init__.py +0 -0
  132. plato/servers/base.py +1395 -0
  133. plato/servers/fedavg.py +281 -0
  134. plato/servers/fedavg_cs.py +335 -0
  135. plato/servers/fedavg_gan.py +74 -0
  136. plato/servers/fedavg_he.py +106 -0
  137. plato/servers/fedavg_personalized.py +57 -0
  138. plato/servers/mistnet.py +67 -0
  139. plato/servers/registry.py +52 -0
  140. plato/servers/split_learning.py +109 -0
  141. plato/trainers/__init__.py +0 -0
  142. plato/trainers/base.py +99 -0
  143. plato/trainers/basic.py +649 -0
  144. plato/trainers/diff_privacy.py +178 -0
  145. plato/trainers/gan.py +330 -0
  146. plato/trainers/huggingface.py +173 -0
  147. plato/trainers/loss_criterion.py +70 -0
  148. plato/trainers/lr_schedulers.py +252 -0
  149. plato/trainers/optimizers.py +53 -0
  150. plato/trainers/pascal_voc.py +80 -0
  151. plato/trainers/registry.py +44 -0
  152. plato/trainers/self_supervised_learning.py +302 -0
  153. plato/trainers/split_learning.py +305 -0
  154. plato/trainers/tracking.py +96 -0
  155. plato/trainers/yolov8.py +41 -0
  156. plato/utils/__init__.py +0 -0
  157. plato/utils/count_parameters.py +30 -0
  158. plato/utils/csv_processor.py +26 -0
  159. plato/utils/data_loaders.py +148 -0
  160. plato/utils/decorators.py +24 -0
  161. plato/utils/fonts.py +23 -0
  162. plato/utils/homo_enc.py +187 -0
  163. plato/utils/reinforcement_learning/__init__.py +0 -0
  164. plato/utils/reinforcement_learning/policies/__init__.py +0 -0
  165. plato/utils/reinforcement_learning/policies/base.py +161 -0
  166. plato/utils/reinforcement_learning/policies/ddpg.py +75 -0
  167. plato/utils/reinforcement_learning/policies/registry.py +32 -0
  168. plato/utils/reinforcement_learning/policies/sac.py +343 -0
  169. plato/utils/reinforcement_learning/policies/td3.py +485 -0
  170. plato/utils/reinforcement_learning/rl_agent.py +142 -0
  171. plato/utils/reinforcement_learning/rl_server.py +113 -0
  172. plato/utils/rl_env.py +154 -0
  173. plato/utils/s3.py +141 -0
  174. plato/utils/trainer_utils.py +21 -0
  175. plato/utils/unary_encoding.py +47 -0
  176. plato_learn-1.1.dist-info/METADATA +35 -0
  177. plato_learn-1.1.dist-info/RECORD +179 -0
  178. plato_learn-1.1.dist-info/WHEEL +4 -0
  179. plato_learn-1.1.dist-info/licenses/LICENSE +201 -0
plato/config.py ADDED
@@ -0,0 +1,339 @@
+ """
+ Reading runtime parameters from a standard configuration file (which is easier
+ to work on than JSON).
+ """
+
+ import argparse
+ import json
+ import logging
+ import os
+ from collections import OrderedDict, namedtuple
+ from pathlib import Path
+ from typing import IO, Any
+
+ import numpy as np
+ import yaml
+
+
+ class Loader(yaml.SafeLoader):
+     """YAML Loader with `!include` constructor."""
+
+     def __init__(self, stream: IO) -> None:
+         """Initialise Loader."""
+
+         try:
+             self.root_path = os.path.split(stream.name)[0]
+         except AttributeError:
+             self.root_path = os.path.curdir
+
+         super().__init__(stream)
+
+
+ class Config:
+     """
+     Retrieving configuration parameters by parsing a configuration file
+     using the YAML configuration file parser.
+     """
+
+     _instance = None
+
+     @staticmethod
+     def construct_include(loader: Loader, node: yaml.Node) -> Any:
+         """Include file referenced at node."""
+         with open(
+             Path(loader.name)
+             .parent.joinpath(loader.construct_yaml_str(node))
+             .resolve(),
+             "r",
+         ) as f:
+             return yaml.load(f, type(loader))
+
+     def __new__(cls):
+         if cls._instance is None:
+             parser = argparse.ArgumentParser()
+             parser.add_argument("-i", "--id", type=str, help="Unique client ID.")
+             parser.add_argument(
+                 "-p", "--port", type=str, help="The port number for running a server."
+             )
+             parser.add_argument(
+                 "-c",
+                 "--config",
+                 type=str,
+                 default="./config.yml",
+                 help="Federated learning configuration file.",
+             )
+             parser.add_argument(
+                 "-b",
+                 "--base",
+                 type=str,
+                 default="./",
+                 help="The base path for datasets and models.",
+             )
+             parser.add_argument(
+                 "-s",
+                 "--server",
+                 type=str,
+                 default=None,
+                 help="The server hostname and port number.",
+             )
+             parser.add_argument(
+                 "-u", "--cpu", action="store_true", help="Use CPU as the device."
+             )
+             parser.add_argument(
+                 "-m", "--mps", action="store_true", help="Use MPS as the device."
+             )
+             parser.add_argument(
+                 "-d",
+                 "--download",
+                 action="store_true",
+                 help="Download the dataset to prepare for a training session.",
+             )
+             parser.add_argument(
+                 "-r",
+                 "--resume",
+                 action="store_true",
+                 help="Resume a previously interrupted training session.",
+             )
+             parser.add_argument(
+                 "-l", "--log", type=str, default="info", help="Log messages level."
+             )
+
+             args = parser.parse_args()
+             Config.args = args
+
+             if Config.args.id is not None:
+                 Config.args.id = int(args.id)
+             if Config.args.port is not None:
+                 Config.args.port = int(args.port)
+
+             numeric_level = getattr(logging, args.log.upper(), None)
+
+             if not isinstance(numeric_level, int):
+                 raise ValueError(f"Invalid log level: {args.log}")
+
+             logging.basicConfig(
+                 format="[%(levelname)s][%(asctime)s]: %(message)s", datefmt="%H:%M:%S"
+             )
+
+             root_logger = logging.getLogger()
+             root_logger.setLevel(numeric_level)
+
+             cls._instance = super(Config, cls).__new__(cls)
+
+             if "config_file" in os.environ:
+                 filename = os.environ["config_file"]
+             else:
+                 filename = args.config
+
+             yaml.add_constructor("!include", Config.construct_include, Loader)
+
+             if os.path.isfile(filename):
+                 with open(filename, "r", encoding="utf-8") as config_file:
+                     config = yaml.load(config_file, Loader)
+             else:
+                 # If the configuration file does not exist, raise an error
+                 raise ValueError("A configuration file must be supplied.")
+
+             Config.clients = Config.namedtuple_from_dict(config["clients"])
+             Config.server = Config.namedtuple_from_dict(config["server"])
+             Config.data = Config.namedtuple_from_dict(config["data"])
+             Config.trainer = Config.namedtuple_from_dict(config["trainer"])
+             Config.algorithm = Config.namedtuple_from_dict(config["algorithm"])
+
+             if Config.args.server is not None:
+                 Config.server = Config.server._replace(
+                     address=args.server.split(":")[0]
+                 )
+                 Config.server = Config.server._replace(port=args.server.split(":")[1])
+
+             if Config.args.download:
+                 Config.clients = Config.clients._replace(total_clients=1)
+                 Config.clients = Config.clients._replace(per_round=1)
+
+             if (
+                 hasattr(Config.clients, "speed_simulation")
+                 and Config.clients.speed_simulation
+             ):
+                 Config.simulate_client_speed()
+
+             # Customizable dictionary of global parameters
+             Config.params: dict = {}
+
+             # A run ID is unique to each client in an experiment
+             Config.params["run_id"] = os.getpid()
+
+             # The base path used for all datasets, models, checkpoints, and results
+             Config.params["base_path"] = Config.args.base
+
+             if "general" in config:
+                 Config.general = Config.namedtuple_from_dict(config["general"])
+
+                 if hasattr(Config.general, "base_path"):
+                     Config.params["base_path"] = Config().general.base_path
+
+             # Directory of the dataset
+             if hasattr(Config().data, "data_path"):
+                 Config.params["data_path"] = os.path.join(
+                     Config.params["base_path"], Config().data.data_path
+                 )
+             else:
+                 Config.params["data_path"] = os.path.join(
+                     Config.params["base_path"], "data"
+                 )
+
+             # Pretrained models
+             if hasattr(Config().server, "model_path"):
+                 Config.params["model_path"] = os.path.join(
+                     Config.params["base_path"], Config().server.model_path
+                 )
+             else:
+                 Config.params["model_path"] = os.path.join(
+                     Config.params["base_path"], "models/pretrained"
+                 )
+             os.makedirs(Config.params["model_path"], exist_ok=True)
+
+             # Resume checkpoint
+             if hasattr(Config().server, "checkpoint_path"):
+                 Config.params["checkpoint_path"] = os.path.join(
+                     Config.params["base_path"], Config().server.checkpoint_path
+                 )
+             else:
+                 Config.params["checkpoint_path"] = os.path.join(
+                     Config.params["base_path"], "checkpoints"
+                 )
+             os.makedirs(Config.params["checkpoint_path"], exist_ok=True)
+
+             if "results" in config:
+                 Config.results = Config.namedtuple_from_dict(config["results"])
+
+             # Directory of the .csv file containing results
+             if hasattr(Config, "results") and hasattr(Config.results, "result_path"):
+                 Config.params["result_path"] = os.path.join(
+                     Config.params["base_path"], Config.results.result_path
+                 )
+             else:
+                 Config.params["result_path"] = os.path.join(
+                     Config.params["base_path"], "results"
+                 )
+             os.makedirs(Config.params["result_path"], exist_ok=True)
+
+             # The set of columns in the .csv file
+             if hasattr(Config, "results") and hasattr(Config.results, "types"):
+                 Config.params["result_types"] = Config.results.types
+             else:
+                 Config.params["result_types"] = "round, accuracy, elapsed_time"
+
+             # The set of pairs to be plotted
+             if hasattr(Config, "results") and hasattr(Config.results, "plot"):
+                 Config.params["plot_pairs"] = Config().results.plot
+             else:
+                 Config.params["plot_pairs"] = "round-accuracy, elapsed_time-accuracy"
+
+             if "parameters" in config:
+                 Config.parameters = Config.namedtuple_from_dict(config["parameters"])
+
+         return cls._instance
+
+     @staticmethod
+     def namedtuple_from_dict(obj):
+         """Creates a named tuple from a dictionary."""
+         if isinstance(obj, dict):
+             fields = sorted(obj.keys())
+             namedtuple_type = namedtuple(
+                 typename="Config", field_names=fields, rename=True
+             )
+             field_value_pairs = OrderedDict(
+                 (str(field), Config.namedtuple_from_dict(obj[field]))
+                 for field in fields
+             )
+             try:
+                 return namedtuple_type(**field_value_pairs)
+             except TypeError:
+                 # Cannot create a namedtuple instance (invalid attribute names),
+                 # so fall back to a dict
+                 return dict(**field_value_pairs)
+         elif isinstance(obj, (list, set, tuple, frozenset)):
+             return [Config.namedtuple_from_dict(item) for item in obj]
+         else:
+             return obj
+
+     @staticmethod
+     def simulate_client_speed() -> None:
+         """Randomly generate a sleep time (in seconds per epoch) for each of the clients."""
+         # A random seed must be supplied to make sure that all the clients generate
+         # the same set of sleep times per epoch across the board
+         if hasattr(Config.clients, "random_seed"):
+             np.random.seed(Config.clients.random_seed)
+         else:
+             np.random.seed(1)
+
+         # Limit the simulated sleep time by the threshold 'max_sleep_time'
+         max_sleep_time = 60
+         if hasattr(Config.clients, "max_sleep_time"):
+             max_sleep_time = Config.clients.max_sleep_time
+
+         total_clients = Config.clients.total_clients
+         sleep_times = []
+
+         if hasattr(Config.clients, "simulation_distribution"):
+             # Only read the distribution settings after confirming they exist
+             dist = Config.clients.simulation_distribution
+             if dist.distribution.lower() == "normal":
+                 sleep_times = np.random.normal(dist.mean, dist.sd, size=total_clients)
+             if dist.distribution.lower() == "pareto":
+                 sleep_times = np.random.pareto(dist.alpha, size=total_clients)
+             if dist.distribution.lower() == "zipf":
+                 sleep_times = np.random.zipf(dist.s, size=total_clients)
+             if dist.distribution.lower() == "uniform":
+                 sleep_times = np.random.uniform(dist.low, dist.high, size=total_clients)
+         else:
+             # By default, use the Pareto distribution with a parameter of 1.0
+             sleep_times = np.random.pareto(1.0, size=total_clients)
+
+         Config.client_sleep_times = np.minimum(
+             sleep_times, np.repeat(max_sleep_time, total_clients)
+         )
+
+     @staticmethod
+     def is_edge_server() -> bool:
+         """Returns whether the current instance is an edge server in cross-silo FL."""
+         return Config().args.port is not None
+
+     @staticmethod
+     def is_central_server() -> bool:
+         """Returns whether the current instance is a central server in cross-silo FL."""
+         return hasattr(Config().algorithm, "cross_silo") and Config().args.port is None
+
+     @staticmethod
+     def gpu_count() -> int:
+         """Returns the number of GPUs available for training."""
+
+         import torch
+
+         if torch.cuda.is_available():
+             return torch.cuda.device_count()
+         elif Config.args.mps and torch.backends.mps.is_built():
+             return 1
+         else:
+             return 0
+
+     @staticmethod
+     def device() -> str:
+         """Returns the device to be used for training."""
+         device = "cpu"
+
+         if Config.args.cpu:
+             return device
+
+         import torch
+
+         if torch.cuda.is_available() and torch.cuda.device_count() > 0:
+             if Config.gpu_count() > 1 and isinstance(Config.args.id, int):
+                 # A client will always run on the same GPU
+                 gpu_id = Config.args.id % torch.cuda.device_count()
+                 device = f"cuda:{gpu_id}"
+             else:
+                 device = "cuda:0"
+
+         if Config.args.mps and torch.backends.mps.is_built():
+             device = "mps"
+
+         return device
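
Taken together, `__new__` above makes every Plato process read one YAML file whose five top-level sections (`clients`, `server`, `data`, `trainer`, `algorithm`) are mandatory, since they are indexed unconditionally; the `Loader` also registers an `!include` tag so one YAML file can pull in another. What follows is a minimal usage sketch, not part of the package; the keys inside each section (`rounds`, `datasource`, and so on) are illustrative placeholders rather than names enforced by this file.

    # sketch.py -- run as: python sketch.py -c ./config.yml -b ./workspace
    #
    # Assumed config.yml (hypothetical keys inside the five required sections):
    #
    #     clients:
    #         total_clients: 2
    #         per_round: 2
    #     server:
    #         address: 127.0.0.1
    #         port: 8000
    #     data:
    #         datasource: CIFAR10
    #     trainer:
    #         rounds: 5
    #     algorithm:
    #         type: fedavg

    from plato.config import Config

    Config()  # parses sys.argv and the YAML file once; later calls reuse the singleton
    print(Config.clients.total_clients)  # sections become namedtuples: attribute access
    print(Config.params["data_path"])    # "./workspace/data" unless data.data_path is set
    print(Config.device())               # "cuda:0", "mps", or "cpu", per flags and hardware
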
plato/datasources/base.py ADDED
@@ -0,0 +1,123 @@
+ """
+ Base class for data sources, encapsulating training and testing datasets with
+ custom augmentations and transforms already accommodated.
+ """
+
+ import gzip
+ import logging
+ import os
+ import sys
+ import tarfile
+ import zipfile
+ from urllib.parse import urlparse
+
+ import requests
+
+ from plato.config import Config
+
+
+ class DataSource:
+     """
+     Training and testing datasets with custom augmentations and transforms
+     already accommodated.
+     """
+
+     def __init__(self):
+         self.trainset = None
+         self.testset = None
+
+     @staticmethod
+     def download(url, data_path):
+         """Downloads a dataset from a URL."""
+         if not os.path.exists(data_path):
+             if Config().clients.total_clients > 1:
+                 if (
+                     not hasattr(Config().data, "concurrent_download")
+                     or not Config().data.concurrent_download
+                 ):
+                     raise ValueError(
+                         "The dataset has not yet been downloaded from the Internet. "
+                         "Please re-run with '-d' or '--download' first. "
+                     )
+
+         os.makedirs(data_path, exist_ok=True)
+
+         url_parse = urlparse(url)
+         file_name = os.path.join(data_path, url_parse.path.split("/")[-1])
+
+         if not os.path.exists(file_name.replace(".gz", "")):
+             logging.info("Downloading %s.", url)
+
+             res = requests.get(url, verify=False, stream=True)
+             total_size = int(res.headers["Content-Length"])
+             downloaded_size = 0
+
+             with open(file_name, "wb+") as file:
+                 for chunk in res.iter_content(chunk_size=1024):
+                     downloaded_size += len(chunk)
+                     file.write(chunk)
+                     file.flush()
+                     sys.stdout.write(
+                         "\r{:.1f}%".format(100 * downloaded_size / total_size)
+                     )
+                     sys.stdout.flush()
+                 sys.stdout.write("\n")
+
+             # Decompress the file just downloaded
+             logging.info("Decompressing the dataset downloaded.")
+             name, suffix = os.path.splitext(file_name)
+
+             if file_name.endswith("tar.gz"):
+                 tar = tarfile.open(file_name, "r:gz")
+                 tar.extractall(data_path)
+                 tar.close()
+                 os.remove(file_name)
+             elif suffix == ".zip":
+                 logging.info("Extracting %s to %s.", file_name, data_path)
+                 with zipfile.ZipFile(file_name, "r") as zip_ref:
+                     zip_ref.extractall(data_path)
+             elif suffix == ".gz":
+                 with open(name, "wb") as unzipped_file:
+                     with gzip.GzipFile(file_name) as zipped_file:
+                         unzipped_file.write(zipped_file.read())
+                 os.remove(file_name)
+             else:
+                 logging.info("Unknown compressed file type.")
+                 sys.exit()
+
+         if Config().args.download:
+             logging.info(
+                 "The dataset has been successfully downloaded. "
+                 "Re-run the experiment without '-d' or '--download'."
+             )
+             sys.exit()
+
+     @staticmethod
+     def input_shape():
+         """Obtains the input shape of this data source."""
+         raise NotImplementedError("Input shape not specified for this data source.")
+
+     def num_train_examples(self) -> int:
+         """Obtains the number of training examples."""
+         return len(self.trainset)
+
+     def num_test_examples(self) -> int:
+         """Obtains the number of testing examples."""
+         return len(self.testset)
+
+     def classes(self):
+         """Obtains a list of class names in the dataset."""
+         return list(self.trainset.classes)
+
+     def targets(self):
+         """Obtains a list of targets (labels) for all the examples
+         in the dataset."""
+         return self.trainset.targets
+
+     def get_train_set(self):
+         """Obtains the training dataset."""
+         return self.trainset
+
+     def get_test_set(self):
+         """Obtains the testing dataset."""
+         return self.testset
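
Since the base class leaves `trainset` and `testset` for subclasses to populate, a new data source only needs to fill in those two attributes (and, where meaningful, override `input_shape`). A minimal sketch follows, using random stand-in tensors that are not part of plato; note that `classes()` and `targets()` in the base class assume torchvision-style datasets exposing `.classes` and `.targets`, which `TensorDataset` does not provide.

    import torch
    from torch.utils.data import TensorDataset

    from plato.datasources import base


    class RandomDataSource(base.DataSource):
        """A toy data source with 100 training and 20 testing examples."""

        def __init__(self):
            super().__init__()
            # Random images and integer labels stand in for a real dataset
            self.trainset = TensorDataset(
                torch.randn(100, 3, 32, 32), torch.randint(0, 10, (100,))
            )
            self.testset = TensorDataset(
                torch.randn(20, 3, 32, 32), torch.randint(0, 10, (20,))
            )

        @staticmethod
        def input_shape():
            return [100, 3, 32, 32]

    datasource = RandomDataSource()
    print(datasource.num_train_examples())  # 100, via len(self.trainset)
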
plato/datasources/celeba.py ADDED
@@ -0,0 +1,150 @@
+ """
+ The CelebA dataset from the torchvision package.
+ """
+
+ import logging
+ import os
+ from typing import Callable, List, Optional, Union
+
+ import torch
+ from torchvision import datasets, transforms
+
+ from plato.config import Config
+ from plato.datasources import base
+
+
+ class CelebA(datasets.CelebA):
+     """
+     A wrapper class of torchvision's CelebA dataset class
+     that adds <targets> and <classes> attributes based on celebrity
+     identity, which is used by the non-IID samplers.
+     """
+
+     def __init__(
+         self,
+         root: str,
+         split: str = "train",
+         target_type: Union[List[str], str] = "attr",
+         transform: Optional[Callable] = None,
+         target_transform: Optional[Callable] = None,
+         download: bool = False,
+     ) -> None:
+         super().__init__(
+             root, split, target_type, transform, target_transform, download
+         )
+         self.targets = self.identity.flatten().tolist()
+         self.classes = [f"Celebrity #{i}" for i in range(10177 + 1)]
+
+
+ class DataSource(base.DataSource):
+     """The CelebA dataset."""
+
+     def __init__(self, **kwargs):
+         super().__init__()
+         _path = Config().params["data_path"]
+
+         if not os.path.exists(os.path.join(_path, "celeba")):
+             celeba_url = "http://iqua.ece.toronto.edu/baochun/celeba.tar.gz"
+             DataSource.download(celeba_url, _path)
+         else:
+             logging.info(
+                 "CelebA data already decompressed under %s",
+                 os.path.join(_path, "celeba"),
+             )
+
+         target_types = []
+         if hasattr(Config().data, "celeba_targets"):
+             targets = Config().data.celeba_targets
+             if hasattr(targets, "attr") and targets.attr:
+                 target_types.append("attr")
+             if hasattr(targets, "identity") and targets.identity:
+                 target_types.append("identity")
+         else:
+             target_types = ["attr", "identity"]
+
+         image_size = 64
+         if hasattr(Config().data, "celeba_img_size"):
+             image_size = Config().data.celeba_img_size
+
+         train_transform = (
+             kwargs["train_transform"]
+             if "train_transform" in kwargs
+             else transforms.Compose(
+                 [
+                     transforms.Resize(image_size),
+                     transforms.CenterCrop(image_size),
+                     transforms.ToTensor(),
+                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+                 ]
+             )
+         )
+
+         test_transform = train_transform
+
+         target_transform = (
+             kwargs["target_transform"]
+             if "target_transform" in kwargs
+             else (DataSource._target_transform if target_types else None)
+         )
+
+         self.trainset = CelebA(
+             root=_path,
+             split="train",
+             target_type=target_types,
+             download=False,
+             transform=train_transform,
+             target_transform=target_transform,
+         )
+         self.testset = CelebA(
+             root=_path,
+             split="test",
+             target_type=target_types,
+             download=False,
+             transform=test_transform,
+             target_transform=target_transform,
+         )
+
+     @staticmethod
+     def _target_transform(label):
+         """
+         Output labels come as a tuple of tensors when more than one target
+         type is specified, so we need to convert the tuple into a single
+         tensor. Here, we merge the two tensors by appending the identity
+         as the 41st attribute.
+         """
+         if isinstance(label, tuple):
+             if len(label) == 1:
+                 return label[0]
+             elif len(label) == 2:
+                 attr, identity = label
+                 return torch.cat((attr.reshape([-1]), identity.reshape([-1])))
+         else:
+             return label
+
+     @staticmethod
+     def input_shape():
+         image_size = 64
+         if hasattr(Config().data, "celeba_img_size"):
+             image_size = Config().data.celeba_img_size
+         return [162770, 3, image_size, image_size]
+
+     def num_train_examples(self):
+         return 162770
+
+     def num_test_examples(self):
+         return 19962
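
Since `_target_transform` is the least obvious part of the wrapper above, here is a quick worked sketch (with made-up values) of the label merge it performs when both `attr` and `identity` are requested:

    import torch

    attr = torch.zeros(40, dtype=torch.int64)  # the 40 binary CelebA attributes
    identity = torch.tensor(7)                 # a hypothetical celebrity identity

    # What DataSource._target_transform((attr, identity)) computes:
    merged = torch.cat((attr.reshape([-1]), identity.reshape([-1])))
    assert merged.shape == (41,)               # identity becomes the 41st entry
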
plato/datasources/cifar10.py ADDED
@@ -0,0 +1,87 @@
+ """
+ The CIFAR-10 dataset from the torchvision package.
+ """
+
+ import logging
+ import os
+ import sys
+
+ from torchvision import datasets, transforms
+
+ from plato.config import Config
+ from plato.datasources import base
+
+
+ class DataSource(base.DataSource):
+     """The CIFAR-10 dataset."""
+
+     def __init__(self, **kwargs):
+         super().__init__()
+
+         train_transform = (
+             kwargs["train_transform"]
+             if "train_transform" in kwargs
+             else transforms.Compose(
+                 [
+                     transforms.RandomHorizontalFlip(),
+                     transforms.RandomCrop(32, 4),
+                     transforms.ToTensor(),
+                     transforms.Normalize(
+                         [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
+                     ),
+                 ]
+             )
+         )
+
+         test_transform = (
+             kwargs["test_transform"]
+             if "test_transform" in kwargs
+             else transforms.Compose(
+                 [
+                     transforms.ToTensor(),
+                     transforms.Normalize(
+                         [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
+                     ),
+                 ]
+             )
+         )
+
+         _path = Config().params["data_path"]
+
+         if not os.path.exists(_path):
+             if hasattr(Config().server, "do_test") and not Config().server.do_test:
+                 # If the server is not performing local tests for accuracy, concurrent
+                 # downloading on the clients may lead to PyTorch errors
+                 if Config().clients.total_clients > 1:
+                     if (
+                         not hasattr(Config().data, "concurrent_download")
+                         or not Config().data.concurrent_download
+                     ):
+                         raise ValueError(
+                             "The dataset has not yet been downloaded from the Internet. "
+                             "Please re-run with '-d' or '--download' first. "
+                         )
+
+         self.trainset = datasets.CIFAR10(
+             root=_path, train=True, download=True, transform=train_transform
+         )
+         self.testset = datasets.CIFAR10(
+             root=_path, train=False, download=True, transform=test_transform
+         )
+
+         if Config().args.download:
+             logging.info(
+                 "The dataset has been successfully downloaded. "
+                 "Re-run the experiment without '-d' or '--download'."
+             )
+             sys.exit()
+
+     def num_train_examples(self):
+         return 50000
+
+     def num_test_examples(self):
+         return 10000
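
As a closing usage note, the `**kwargs` hook in `__init__` above lets callers swap in their own transforms. A minimal sketch, assuming a valid configuration file is supplied on the command line and the dataset is already in place:

    # sketch.py -- run as: python sketch.py -c ./config.yml
    from torchvision import transforms

    from plato.datasources import cifar10

    # Override only the training transform; the default test transform still applies.
    datasource = cifar10.DataSource(
        train_transform=transforms.Compose([transforms.ToTensor()])
    )
    print(datasource.num_train_examples())  # 50000
    print(datasource.num_test_examples())   # 10000
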