plato_learn-1.1-py3-none-any.whl
- plato/__init__.py +1 -0
- plato/algorithms/__init__.py +0 -0
- plato/algorithms/base.py +45 -0
- plato/algorithms/fedavg.py +48 -0
- plato/algorithms/fedavg_gan.py +79 -0
- plato/algorithms/fedavg_personalized.py +48 -0
- plato/algorithms/mistnet.py +52 -0
- plato/algorithms/registry.py +39 -0
- plato/algorithms/split_learning.py +89 -0
- plato/callbacks/__init__.py +0 -0
- plato/callbacks/client.py +56 -0
- plato/callbacks/handler.py +78 -0
- plato/callbacks/server.py +139 -0
- plato/callbacks/trainer.py +124 -0
- plato/client.py +67 -0
- plato/clients/__init__.py +0 -0
- plato/clients/base.py +467 -0
- plato/clients/edge.py +103 -0
- plato/clients/fedavg_personalized.py +40 -0
- plato/clients/mistnet.py +49 -0
- plato/clients/registry.py +43 -0
- plato/clients/self_supervised_learning.py +51 -0
- plato/clients/simple.py +218 -0
- plato/clients/split_learning.py +150 -0
- plato/config.py +339 -0
- plato/datasources/__init__.py +0 -0
- plato/datasources/base.py +123 -0
- plato/datasources/celeba.py +150 -0
- plato/datasources/cifar10.py +87 -0
- plato/datasources/cifar100.py +61 -0
- plato/datasources/cinic10.py +62 -0
- plato/datasources/coco.py +119 -0
- plato/datasources/datalib/__init__.py +0 -0
- plato/datasources/datalib/audio_extraction_tools.py +137 -0
- plato/datasources/datalib/data_utils.py +124 -0
- plato/datasources/datalib/flickr30kE_utils.py +336 -0
- plato/datasources/datalib/frames_extraction_tools.py +254 -0
- plato/datasources/datalib/gym_utils/__init__.py +0 -0
- plato/datasources/datalib/gym_utils/gym_trim.py +189 -0
- plato/datasources/datalib/modality_data_anntation_tools.py +163 -0
- plato/datasources/datalib/modality_extraction_base.py +59 -0
- plato/datasources/datalib/parse_datasets.py +212 -0
- plato/datasources/datalib/refer_utils/__init__.py +0 -0
- plato/datasources/datalib/refer_utils/referitgame_utils.py +237 -0
- plato/datasources/datalib/tiny_data_tools.py +81 -0
- plato/datasources/datalib/video_transform.py +79 -0
- plato/datasources/emnist.py +64 -0
- plato/datasources/fashion_mnist.py +41 -0
- plato/datasources/feature.py +24 -0
- plato/datasources/feature_dataset.py +15 -0
- plato/datasources/femnist.py +141 -0
- plato/datasources/flickr30k_entities.py +362 -0
- plato/datasources/gym.py +431 -0
- plato/datasources/huggingface.py +165 -0
- plato/datasources/kinetics.py +568 -0
- plato/datasources/mnist.py +44 -0
- plato/datasources/multimodal_base.py +328 -0
- plato/datasources/pascal_voc.py +56 -0
- plato/datasources/purchase.py +94 -0
- plato/datasources/qoenflx.py +127 -0
- plato/datasources/referitgame.py +330 -0
- plato/datasources/registry.py +119 -0
- plato/datasources/self_supervised_learning.py +98 -0
- plato/datasources/stl10.py +103 -0
- plato/datasources/texas.py +94 -0
- plato/datasources/tiny_imagenet.py +64 -0
- plato/datasources/yolov8.py +85 -0
- plato/models/__init__.py +0 -0
- plato/models/cnn_encoder.py +103 -0
- plato/models/dcgan.py +116 -0
- plato/models/general_multilayer.py +254 -0
- plato/models/huggingface.py +27 -0
- plato/models/lenet5.py +113 -0
- plato/models/multilayer.py +90 -0
- plato/models/multimodal/__init__.py +0 -0
- plato/models/multimodal/base_net.py +91 -0
- plato/models/multimodal/blending.py +142 -0
- plato/models/multimodal/fc_net.py +77 -0
- plato/models/multimodal/fusion_net.py +78 -0
- plato/models/multimodal/multimodal_module.py +152 -0
- plato/models/registry.py +99 -0
- plato/models/resnet.py +190 -0
- plato/models/torch_hub.py +19 -0
- plato/models/vgg.py +113 -0
- plato/models/vit.py +166 -0
- plato/models/yolov8.py +22 -0
- plato/processors/__init__.py +0 -0
- plato/processors/base.py +35 -0
- plato/processors/compress.py +46 -0
- plato/processors/decompress.py +48 -0
- plato/processors/feature.py +51 -0
- plato/processors/feature_additive_noise.py +48 -0
- plato/processors/feature_dequantize.py +34 -0
- plato/processors/feature_gaussian.py +17 -0
- plato/processors/feature_laplace.py +15 -0
- plato/processors/feature_quantize.py +34 -0
- plato/processors/feature_randomized_response.py +50 -0
- plato/processors/feature_unbatch.py +39 -0
- plato/processors/inbound_feature_tensors.py +39 -0
- plato/processors/model.py +55 -0
- plato/processors/model_compress.py +34 -0
- plato/processors/model_decompress.py +37 -0
- plato/processors/model_decrypt.py +41 -0
- plato/processors/model_deepcopy.py +21 -0
- plato/processors/model_dequantize.py +18 -0
- plato/processors/model_dequantize_qsgd.py +61 -0
- plato/processors/model_encrypt.py +43 -0
- plato/processors/model_quantize.py +18 -0
- plato/processors/model_quantize_qsgd.py +82 -0
- plato/processors/model_randomized_response.py +34 -0
- plato/processors/outbound_feature_ndarrays.py +38 -0
- plato/processors/pipeline.py +26 -0
- plato/processors/registry.py +124 -0
- plato/processors/structured_pruning.py +57 -0
- plato/processors/unstructured_pruning.py +73 -0
- plato/samplers/__init__.py +0 -0
- plato/samplers/all_inclusive.py +41 -0
- plato/samplers/base.py +31 -0
- plato/samplers/dirichlet.py +81 -0
- plato/samplers/distribution_noniid.py +132 -0
- plato/samplers/iid.py +53 -0
- plato/samplers/label_quantity_noniid.py +119 -0
- plato/samplers/mixed.py +44 -0
- plato/samplers/mixed_label_quantity_noniid.py +128 -0
- plato/samplers/modality_iid.py +42 -0
- plato/samplers/modality_quantity_noniid.py +56 -0
- plato/samplers/orthogonal.py +99 -0
- plato/samplers/registry.py +66 -0
- plato/samplers/sample_quantity_noniid.py +123 -0
- plato/samplers/sampler_utils.py +190 -0
- plato/servers/__init__.py +0 -0
- plato/servers/base.py +1395 -0
- plato/servers/fedavg.py +281 -0
- plato/servers/fedavg_cs.py +335 -0
- plato/servers/fedavg_gan.py +74 -0
- plato/servers/fedavg_he.py +106 -0
- plato/servers/fedavg_personalized.py +57 -0
- plato/servers/mistnet.py +67 -0
- plato/servers/registry.py +52 -0
- plato/servers/split_learning.py +109 -0
- plato/trainers/__init__.py +0 -0
- plato/trainers/base.py +99 -0
- plato/trainers/basic.py +649 -0
- plato/trainers/diff_privacy.py +178 -0
- plato/trainers/gan.py +330 -0
- plato/trainers/huggingface.py +173 -0
- plato/trainers/loss_criterion.py +70 -0
- plato/trainers/lr_schedulers.py +252 -0
- plato/trainers/optimizers.py +53 -0
- plato/trainers/pascal_voc.py +80 -0
- plato/trainers/registry.py +44 -0
- plato/trainers/self_supervised_learning.py +302 -0
- plato/trainers/split_learning.py +305 -0
- plato/trainers/tracking.py +96 -0
- plato/trainers/yolov8.py +41 -0
- plato/utils/__init__.py +0 -0
- plato/utils/count_parameters.py +30 -0
- plato/utils/csv_processor.py +26 -0
- plato/utils/data_loaders.py +148 -0
- plato/utils/decorators.py +24 -0
- plato/utils/fonts.py +23 -0
- plato/utils/homo_enc.py +187 -0
- plato/utils/reinforcement_learning/__init__.py +0 -0
- plato/utils/reinforcement_learning/policies/__init__.py +0 -0
- plato/utils/reinforcement_learning/policies/base.py +161 -0
- plato/utils/reinforcement_learning/policies/ddpg.py +75 -0
- plato/utils/reinforcement_learning/policies/registry.py +32 -0
- plato/utils/reinforcement_learning/policies/sac.py +343 -0
- plato/utils/reinforcement_learning/policies/td3.py +485 -0
- plato/utils/reinforcement_learning/rl_agent.py +142 -0
- plato/utils/reinforcement_learning/rl_server.py +113 -0
- plato/utils/rl_env.py +154 -0
- plato/utils/s3.py +141 -0
- plato/utils/trainer_utils.py +21 -0
- plato/utils/unary_encoding.py +47 -0
- plato_learn-1.1.dist-info/METADATA +35 -0
- plato_learn-1.1.dist-info/RECORD +179 -0
- plato_learn-1.1.dist-info/WHEEL +4 -0
- plato_learn-1.1.dist-info/licenses/LICENSE +201 -0
plato/servers/fedavg.py
ADDED
@@ -0,0 +1,281 @@
"""
A simple federated learning server using federated averaging.
"""

import asyncio
import logging
import os

from plato.algorithms import registry as algorithms_registry
from plato.config import Config
from plato.datasources import registry as datasources_registry
from plato.processors import registry as processor_registry
from plato.samplers import all_inclusive
from plato.servers import base
from plato.trainers import registry as trainers_registry
from plato.utils import csv_processor, fonts


class Server(base.Server):
    """Federated learning server using federated averaging."""

    def __init__(
        self, model=None, datasource=None, algorithm=None, trainer=None, callbacks=None
    ):
        super().__init__(callbacks=callbacks)

        self.custom_model = model
        self.model = None

        self.custom_algorithm = algorithm
        self.algorithm = None

        self.custom_trainer = trainer
        self.trainer = None

        self.custom_datasource = datasource
        self.datasource = None

        self.testset = None
        self.testset_sampler = None
        self.total_samples = 0

        self.total_clients = Config().clients.total_clients
        self.clients_per_round = Config().clients.per_round

        logging.info(
            "[Server #%d] Started training on %d clients with %d per round.",
            os.getpid(),
            self.total_clients,
            self.clients_per_round,
        )

    def configure(self) -> None:
        """
        Booting the federated learning server by setting up the data, model, and
        creating the clients.
        """
        super().configure()

        total_rounds = Config().trainer.rounds
        target_accuracy = None
        target_perplexity = None

        if hasattr(Config().trainer, "target_accuracy"):
            target_accuracy = Config().trainer.target_accuracy
        elif hasattr(Config().trainer, "target_perplexity"):
            target_perplexity = Config().trainer.target_perplexity

        if target_accuracy:
            logging.info(
                "Training: %s rounds or accuracy above %.1f%%\n",
                total_rounds,
                100 * target_accuracy,
            )
        elif target_perplexity:
            logging.info(
                "Training: %s rounds or perplexity below %.1f\n",
                total_rounds,
                target_perplexity,
            )
        else:
            logging.info("Training: %s rounds\n", total_rounds)

        self.init_trainer()

        # Prepares this server for processors that process outbound and inbound
        # data payloads
        self.outbound_processor, self.inbound_processor = processor_registry.get(
            "Server", server_id=os.getpid(), trainer=self.trainer
        )

        if not (hasattr(Config().server, "do_test") and not Config().server.do_test):
            if self.datasource is None and self.custom_datasource is None:
                self.datasource = datasources_registry.get(client_id=0)
            elif self.datasource is None and self.custom_datasource is not None:
                self.datasource = self.custom_datasource()

            self.testset = self.datasource.get_test_set()
            if hasattr(Config().data, "testset_size"):
                self.testset_sampler = all_inclusive.Sampler(
                    self.datasource, testing=True
                )

        # Initialize the test accuracy csv file if clients compute locally
        if (
            hasattr(Config().clients, "do_test")
            and Config().clients.do_test
            and (
                hasattr(Config(), "results")
                and hasattr(Config().results, "record_clients_accuracy")
                and Config().results.record_clients_accuracy
            )
        ):
            accuracy_csv_file = (
                f"{Config().params['result_path']}/{os.getpid()}_accuracy.csv"
            )
            accuracy_headers = ["round", "client_id", "accuracy"]
            csv_processor.initialize_csv(
                accuracy_csv_file, accuracy_headers, Config().params["result_path"]
            )

    def init_trainer(self) -> None:
        """Setting up the global model, trainer, and algorithm."""
        if self.model is None and self.custom_model is not None:
            self.model = self.custom_model

        if self.trainer is None and self.custom_trainer is None:
            self.trainer = trainers_registry.get(model=self.model)
        elif self.trainer is None and self.custom_trainer is not None:
            self.trainer = self.custom_trainer(model=self.model)

        if self.algorithm is None and self.custom_algorithm is None:
            self.algorithm = algorithms_registry.get(trainer=self.trainer)
        elif self.algorithm is None and self.custom_algorithm is not None:
            self.algorithm = self.custom_algorithm(trainer=self.trainer)

    async def aggregate_deltas(self, updates, deltas_received):
        """Aggregate weight updates from the clients using federated averaging."""
        # Extract the total number of samples
        self.total_samples = sum(update.report.num_samples for update in updates)

        # Perform weighted averaging
        avg_update = {
            name: self.trainer.zeros(delta.shape)
            for name, delta in deltas_received[0].items()
        }

        for i, update in enumerate(deltas_received):
            report = updates[i].report
            num_samples = report.num_samples

            for name, delta in update.items():
                # Use a weighted average by the number of samples
                avg_update[name] += delta * (num_samples / self.total_samples)

            # Yield to other tasks in the server
            await asyncio.sleep(0)

        return avg_update

    async def _process_reports(self):
        """Process the client reports by aggregating their weights."""
        weights_received = [update.payload for update in self.updates]

        weights_received = self.weights_received(weights_received)
        self.callback_handler.call_event("on_weights_received", self, weights_received)

        # Extract the current model weights as the baseline
        baseline_weights = self.algorithm.extract_weights()

        if hasattr(self, "aggregate_weights"):
            # Runs a server aggregation algorithm using weights rather than deltas
            logging.info(
                "[Server #%d] Aggregating model weights directly rather than weight deltas.",
                os.getpid(),
            )
            updated_weights = await self.aggregate_weights(
                self.updates, baseline_weights, weights_received
            )

            # Loads the new model weights
            self.algorithm.load_weights(updated_weights)
        else:
            # Computes the weight deltas by comparing the weights received with
            # the current global model weights
            deltas_received = self.algorithm.compute_weight_deltas(
                baseline_weights, weights_received
            )
            # Runs a framework-agnostic server aggregation algorithm, such as
            # the federated averaging algorithm
            logging.info("[Server #%d] Aggregating model weight deltas.", os.getpid())
            deltas = await self.aggregate_deltas(self.updates, deltas_received)
            # Updates the existing model weights from the provided deltas
            updated_weights = self.algorithm.update_weights(deltas)
            # Loads the new model weights
            self.algorithm.load_weights(updated_weights)

        # The model weights have already been aggregated, now calls the
        # corresponding hook and callback
        self.weights_aggregated(self.updates)
        self.callback_handler.call_event("on_weights_aggregated", self, self.updates)

        # Testing the global model accuracy
        if hasattr(Config().server, "do_test") and not Config().server.do_test:
            # Compute the average accuracy from client reports
            self.accuracy, self.accuracy_std = self.get_accuracy_mean_std(self.updates)
            logging.info(
                "[%s] Average client accuracy: %.2f%%.", self, 100 * self.accuracy
            )
        else:
            # Testing the updated model directly at the server
            logging.info("[%s] Started model testing.", self)
            self.accuracy = self.trainer.test(self.testset, self.testset_sampler)

        if hasattr(Config().trainer, "target_perplexity"):
            logging.info(
                fonts.colourize(
                    f"[{self}] Global model perplexity: {self.accuracy:.2f}\n"
                )
            )
        else:
            logging.info(
                fonts.colourize(
                    f"[{self}] Global model accuracy: {100 * self.accuracy:.2f}%\n"
                )
            )

        self.clients_processed()
        self.callback_handler.call_event("on_clients_processed", self)

    def clients_processed(self) -> None:
        """Additional work to be performed after client reports have been processed."""

    def get_logged_items(self) -> dict:
        """Get items to be logged by the LogProgressCallback class in a .csv file."""
        return {
            "round": self.current_round,
            "accuracy": self.accuracy,
            "accuracy_std": self.accuracy_std,
            "elapsed_time": self.wall_time - self.initial_wall_time,
            "processing_time": max(
                update.report.processing_time for update in self.updates
            ),
            "comm_time": max(update.report.comm_time for update in self.updates),
            "round_time": max(
                update.report.training_time
                + update.report.processing_time
                + update.report.comm_time
                for update in self.updates
            ),
            "comm_overhead": self.comm_overhead,
        }

    @staticmethod
    def get_accuracy_mean_std(updates):
        """Compute the accuracy mean and standard deviation across clients."""
        # Get the total number of samples
        total_samples = sum(update.report.num_samples for update in updates)

        # Perform weighted averaging
        updates_accuracy = [update.report.accuracy for update in updates]
        weights = [update.report.num_samples / total_samples for update in updates]

        mean = sum(acc * weights[idx] for idx, acc in enumerate(updates_accuracy))
        variance = sum(
            (acc - mean) ** 2 * weights[idx] for idx, acc in enumerate(updates_accuracy)
        )
        std = variance**0.5

        return mean, std

    def weights_received(self, weights_received):
        """
        Method called after the updated weights have been received.
        """
        return weights_received

    def weights_aggregated(self, updates):
        """
        Method called after the updated weights have been aggregated.
        """
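
The aggregation path above is duck-typed: a subclass that defines `aggregate_weights` has it awaited with the raw client weights, while the default path computes per-client deltas against the global model and averages them in `aggregate_deltas`, scaling each delta by `num_samples / total_samples`. A minimal sketch of the hook, with hypothetical names and assuming the weight tensors behave like NumPy arrays:

# A sketch (not part of the package) of the aggregate_weights hook probed by
# _process_reports() above. "MedianServer" and its coordinate-wise median rule
# are hypothetical; only the hook's name and signature come from the code above.
import numpy as np

from plato.servers import fedavg


class MedianServer(fedavg.Server):
    """Aggregates raw client weights with a coordinate-wise median."""

    async def aggregate_weights(self, updates, baseline_weights, weights_received):
        # Because this method exists, _process_reports() awaits it with the raw
        # client weights and skips the delta-based aggregate_deltas() path.
        return {
            name: np.median(
                np.stack([weights[name] for weights in weights_received]), axis=0
            )
            for name in baseline_weights
        }
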
plato/servers/fedavg_cs.py
ADDED
@@ -0,0 +1,335 @@
"""
A cross-silo federated learning server using federated averaging, serving as either
an edge or a central server.
"""

import asyncio
import logging
import os
import numpy as np

from plato.config import Config
from plato.datasources import registry as datasources_registry
from plato.processors import registry as processor_registry
from plato.samplers import registry as samplers_registry
from plato.samplers import all_inclusive
from plato.servers import fedavg
from plato.utils import fonts


class Server(fedavg.Server):
    """Cross-silo federated learning server using federated averaging."""

    def __init__(
        self, model=None, datasource=None, algorithm=None, trainer=None, callbacks=None
    ):
        super().__init__(
            model=model,
            datasource=datasource,
            algorithm=algorithm,
            trainer=trainer,
            callbacks=callbacks,
        )

        self.current_global_round = 0
        self.average_accuracy = 0
        self.std_accuracy = 0

        if Config().is_edge_server():
            # An edge client waits for the event that a certain number of
            # aggregations are completed
            self.model_aggregated = asyncio.Event()

            # An edge client waits for the event that a new global round begins
            # before starting the first round of local aggregation
            self.new_global_round_begins = asyncio.Event()

            edge_server_id = Config().args.id - Config().clients.total_clients

            # Compute the total number of clients in each silo for edge servers
            edges_total_clients = [
                len(i)
                for i in np.array_split(
                    np.arange(Config().clients.total_clients),
                    Config().algorithm.total_silos,
                )
            ]
            self.total_clients = edges_total_clients[edge_server_id - 1]

            self.clients_per_round = [
                len(i)
                for i in np.array_split(
                    np.arange(Config().clients.per_round),
                    Config().algorithm.total_silos,
                )
            ][edge_server_id - 1]

            starting_client_id = sum(edges_total_clients[: edge_server_id - 1])
            self.clients_pool = list(
                range(
                    starting_client_id + 1, starting_client_id + 1 + self.total_clients
                )
            )

            logging.info(
                "[Edge server #%d (#%d)] Started training on %d clients with %d per round.",
                Config().args.id,
                os.getpid(),
                self.total_clients,
                self.clients_per_round,
            )

            # The training time of an edge server in one global round
            self.edge_training_time = 0

            # The communication time of an edge server with its clients in one global round
            self.edge_comm_time = 0

        # Compute the number of clients for the central server
        if Config().is_central_server():
            self.clients_per_round = Config().algorithm.total_silos
            self.total_clients = self.clients_per_round

            logging.info(
                "The central server starts training with %s edge servers.",
                self.total_clients,
            )

    def configure(self) -> None:
        """
        Booting the federated learning server by setting up the data, model, and
        creating the clients.
        """
        super().configure()

        if Config().is_edge_server():
            logging.info(
                "Configuring edge server #%d as a %s server.",
                Config().args.id,
                Config().algorithm.type,
            )
            logging.info(
                "[Edge server #%d (#%d)] Training with %s local aggregation rounds.",
                Config().args.id,
                os.getpid(),
                Config().algorithm.local_rounds,
            )

            self.init_trainer()
            self.trainer.set_client_id(Config().args.id)

            # Prepares this server for processors that process outbound and inbound
            # data payloads
            self.outbound_processor, self.inbound_processor = processor_registry.get(
                "Server", server_id=os.getpid(), trainer=self.trainer
            )

            if (
                hasattr(Config().server, "edge_do_test")
                and Config().server.edge_do_test
            ):
                self.datasource = datasources_registry.get(client_id=0)
                self.testset = self.datasource.get_test_set()

                if hasattr(Config().data, "testset_sampler"):
                    # Set the sampler for the test set
                    self.testset_sampler = samplers_registry.get(
                        self.datasource, Config().args.id, testing=True
                    )
                else:
                    if hasattr(Config().data, "testset_size"):
                        self.testset_sampler = all_inclusive.Sampler(
                            self.datasource, testing=True
                        )

    async def _select_clients(self, for_next_batch=False):
        if Config().is_edge_server() and not for_next_batch:
            if self.current_round == 0:
                # Wait until this edge server is selected by the central server,
                # so that it does not select clients (and those clients do not
                # begin training) before the edge server itself is selected
                await self.new_global_round_begins.wait()
                self.new_global_round_begins.clear()

        await super()._select_clients(for_next_batch=for_next_batch)

    def customize_server_response(self, server_response: dict, client_id) -> dict:
        """Wrap up generating the server response with any additional information."""
        if Config().is_central_server():
            server_response["current_global_round"] = self.current_round
        return server_response

    async def _process_reports(self):
        """Process the client reports by aggregating their weights."""
        # To pass the client_id == 0 assertion during aggregation
        self.trainer.set_client_id(0)

        weights_received = [update.payload for update in self.updates]

        weights_received = self.weights_received(weights_received)
        self.callback_handler.call_event("on_weights_received", self, weights_received)

        # Extract the current model weights as the baseline
        baseline_weights = self.algorithm.extract_weights()

        if hasattr(self, "aggregate_weights"):
            # Runs a server aggregation algorithm using weights rather than deltas
            logging.info(
                "[Server #%d] Aggregating model weights directly rather than weight deltas.",
                os.getpid(),
            )
            updated_weights = await self.aggregate_weights(
                self.updates, baseline_weights, weights_received
            )

            # Loads the new model weights
            self.algorithm.load_weights(updated_weights)
        else:
            # Computes the weight deltas by comparing the weights received with
            # the current global model weights
            deltas_received = self.algorithm.compute_weight_deltas(
                baseline_weights, weights_received
            )
            # Runs a framework-agnostic server aggregation algorithm, such as
            # the federated averaging algorithm
            logging.info("[Server #%d] Aggregating model weight deltas.", os.getpid())
            deltas = await self.aggregate_deltas(self.updates, deltas_received)
            # Updates the existing model weights from the provided deltas
            updated_weights = self.algorithm.update_weights(deltas)
            # Loads the new model weights
            self.algorithm.load_weights(updated_weights)

        # The model weights have already been aggregated, now calls the
        # corresponding hook and callback
        self.weights_aggregated(self.updates)
        self.callback_handler.call_event("on_weights_aggregated", self, self.updates)

        if Config().is_edge_server():
            self.trainer.set_client_id(Config().args.id)

        # Testing the model accuracy
        if (Config().is_edge_server() and Config().clients.do_test) or (
            Config().is_central_server()
            and hasattr(Config().server, "edge_do_test")
            and Config().server.edge_do_test
        ):
            # Compute the average accuracy from client reports
            (
                self.average_accuracy,
                self.std_accuracy,
            ) = self.get_accuracy_mean_std(self.updates)
            logging.info(
                "[%s] Average client accuracy: %.2f%%.",
                self,
                100 * self.average_accuracy,
            )
        elif Config().is_central_server() and Config().clients.do_test:
            # Compute the average accuracy from client reports
            total_samples = sum(update.report.num_samples for update in self.updates)
            self.average_accuracy = (
                sum(
                    update.report.average_accuracy * update.report.num_samples
                    for update in self.updates
                )
                / total_samples
            )

            logging.info(
                "[%s] Average client accuracy: %.2f%%.",
                self,
                100 * self.average_accuracy,
            )

        if (
            Config().is_central_server()
            and hasattr(Config().server, "do_test")
            and Config().server.do_test
        ):
            # Testing the updated model directly at the server
            logging.info("[%s] Started model testing.", self)
            self.accuracy = self.trainer.test(self.testset, self.testset_sampler)

            if hasattr(Config().trainer, "target_perplexity"):
                logging.info(
                    fonts.colourize(
                        f"[{self}] Global model perplexity: {self.accuracy:.2f}\n"
                    )
                )
            else:
                logging.info(
                    fonts.colourize(
                        f"[{self}] Global model accuracy: {100 * self.accuracy:.2f}%\n"
                    )
                )
        elif (
            Config().is_edge_server()
            and hasattr(Config().server, "edge_do_test")
            and Config().server.edge_do_test
        ):
            # Test the aggregated model directly at the edge server
            logging.info("[%s] Started model testing.", self)
            self.accuracy = self.trainer.test(self.testset, self.testset_sampler)

            if hasattr(Config().trainer, "target_perplexity"):
                logging.info(
                    fonts.colourize(
                        f"[{self}] Global model perplexity: {self.accuracy:.2f}\n"
                    )
                )
            else:
                logging.info(
                    fonts.colourize(
                        f"[{self}] Global model accuracy: {100 * self.accuracy:.2f}%\n"
                    )
                )
        else:
            self.accuracy = self.average_accuracy
            self.accuracy_std = self.std_accuracy

        self.clients_processed()
        self.callback_handler.call_event("on_clients_processed", self)

    def clients_processed(self):
        """Additional work to be performed after client reports have been processed."""
        # Record results into a .csv file
        if Config().is_central_server():
            super().clients_processed()

        if Config().is_edge_server():
            logged_items = self.get_logged_items()
            self.edge_training_time += logged_items["round_time"]
            self.edge_comm_time += logged_items["comm_time"]

            # When a certain number of aggregations are completed, an edge client
            # needs to be signaled to send a report to the central server
            if self.current_round == Config().algorithm.local_rounds:
                logging.info(
                    "[Server #%d] Completed %s rounds of local aggregation.",
                    os.getpid(),
                    Config().algorithm.local_rounds,
                )
                self.model_aggregated.set()

                self.current_round = 0
                self.current_global_round += 1

    def get_logged_items(self) -> dict:
        """Get items to be logged by the LogProgressCallback class in a .csv file."""
        logged_items = super().get_logged_items()

        logged_items["global_round"] = self.current_global_round
        logged_items["average_accuracy"] = self.average_accuracy
        logged_items["edge_agg_num"] = Config().algorithm.local_rounds
        logged_items["local_epoch_num"] = Config().trainer.epochs

        if Config().is_central_server():
            logged_items["comm_time"] = max(
                update.report.comm_time + update.report.edge_server_comm_time
                for update in self.updates
            )

        return logged_items

    async def wrap_up(self) -> None:
        """Wrapping up when each round of training is done."""
        if Config().is_central_server():
            await super().wrap_up()
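
The silo bookkeeping in `__init__` above is plain index arithmetic: the global client IDs are split into contiguous blocks with `np.array_split`, and edge server k takes the k-th block. A standalone sketch with hypothetical numbers (12 clients, 6 per round, 3 silos, edge server #2):

# Standalone sketch of the silo-partitioning arithmetic above; the numbers are
# illustrative stand-ins for the values normally read from Config().
import numpy as np

total_clients, per_round, total_silos = 12, 6, 3
edge_server_id = 2  # Config().args.id - Config().clients.total_clients

edges_total_clients = [
    len(i) for i in np.array_split(np.arange(total_clients), total_silos)
]
silo_total = edges_total_clients[edge_server_id - 1]  # 4 clients in silo #2
silo_per_round = [
    len(i) for i in np.array_split(np.arange(per_round), total_silos)
][edge_server_id - 1]  # 2 clients selected per local round

starting_client_id = sum(edges_total_clients[: edge_server_id - 1])
clients_pool = list(
    range(starting_client_id + 1, starting_client_id + 1 + silo_total)
)
assert clients_pool == [5, 6, 7, 8]
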
plato/servers/fedavg_gan.py
ADDED
@@ -0,0 +1,74 @@
"""
A federated learning server using federated averaging to train GAN models.
"""

import asyncio

from plato.servers import fedavg
from plato.config import Config


class Server(fedavg.Server):
    """Federated learning server using federated averaging to train GAN models."""

    async def aggregate_deltas(self, updates, deltas_received):
        """Aggregate weight updates from the clients using federated averaging."""
        # The total number of samples is the same for both the Generator
        # and the Discriminator
        self.total_samples = sum(update.report.num_samples for update in updates)

        # Perform weighted averaging for both the Generator and the Discriminator
        gen_avg_update = {
            name: self.trainer.zeros(weights.shape)
            for name, weights in deltas_received[0][0].items()
        }
        disc_avg_update = {
            name: self.trainer.zeros(weights.shape)
            for name, weights in deltas_received[0][1].items()
        }

        for i, update in enumerate(deltas_received):
            num_samples = updates[i].report.num_samples

            update_from_gen, update_from_disc = update

            for name, delta in update_from_gen.items():
                gen_avg_update[name] += delta * (num_samples / self.total_samples)

            for name, delta in update_from_disc.items():
                disc_avg_update[name] += delta * (num_samples / self.total_samples)

            # Yield to other tasks in the server
            await asyncio.sleep(0)

        return gen_avg_update, disc_avg_update

    def customize_server_payload(self, payload):
        """
        Customize the server payload before sending it to the client.

        At the end of each round, the server can choose to send only the global
        Generator or Discriminator (or both, or neither) to the clients for the
        next round.

        See this paper for more detail:
        https://deepai.org/publication/federated-generative-adversarial-learning

        By default, both models are sent to the clients.
        """
        if hasattr(Config().server, "network_to_sync"):
            network = Config().server.network_to_sync.lower()
        else:
            network = "both"

        weights_gen, weights_disc = payload
        if network == "none":
            server_payload = None, None
        elif network == "generator":
            server_payload = weights_gen, None
        elif network == "discriminator":
            server_payload = None, weights_disc
        elif network == "both":
            server_payload = payload
        else:
            raise ValueError(f"Unknown value for attribute network_to_sync: {network}")

        return server_payload
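
Both networks are averaged independently, but with the same per-client weights `num_samples / total_samples`. A numeric sketch of that rule, with plain floats standing in for tensors:

# Numeric sketch of the per-network weighted averaging in aggregate_deltas
# above. Two hypothetical clients with 100 and 300 samples contribute
# (generator, discriminator) delta pairs; each network is averaged with
# weights n_k / N, so the result is 0.25 * delta_1 + 0.75 * delta_2.
num_samples = [100, 300]
deltas_received = [
    ({"g.w": 1.0}, {"d.w": -2.0}),
    ({"g.w": 3.0}, {"d.w": 2.0}),
]

total = sum(num_samples)
gen_avg = {name: 0.0 for name in deltas_received[0][0]}
disc_avg = {name: 0.0 for name in deltas_received[0][1]}

for n, (gen_delta, disc_delta) in zip(num_samples, deltas_received):
    for name, delta in gen_delta.items():
        gen_avg[name] += delta * (n / total)
    for name, delta in disc_delta.items():
        disc_avg[name] += delta * (n / total)

assert gen_avg == {"g.w": 2.5} and disc_avg == {"d.w": 1.0}
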