plato-learn 1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (179) hide show
  1. plato/__init__.py +1 -0
  2. plato/algorithms/__init__.py +0 -0
  3. plato/algorithms/base.py +45 -0
  4. plato/algorithms/fedavg.py +48 -0
  5. plato/algorithms/fedavg_gan.py +79 -0
  6. plato/algorithms/fedavg_personalized.py +48 -0
  7. plato/algorithms/mistnet.py +52 -0
  8. plato/algorithms/registry.py +39 -0
  9. plato/algorithms/split_learning.py +89 -0
  10. plato/callbacks/__init__.py +0 -0
  11. plato/callbacks/client.py +56 -0
  12. plato/callbacks/handler.py +78 -0
  13. plato/callbacks/server.py +139 -0
  14. plato/callbacks/trainer.py +124 -0
  15. plato/client.py +67 -0
  16. plato/clients/__init__.py +0 -0
  17. plato/clients/base.py +467 -0
  18. plato/clients/edge.py +103 -0
  19. plato/clients/fedavg_personalized.py +40 -0
  20. plato/clients/mistnet.py +49 -0
  21. plato/clients/registry.py +43 -0
  22. plato/clients/self_supervised_learning.py +51 -0
  23. plato/clients/simple.py +218 -0
  24. plato/clients/split_learning.py +150 -0
  25. plato/config.py +339 -0
  26. plato/datasources/__init__.py +0 -0
  27. plato/datasources/base.py +123 -0
  28. plato/datasources/celeba.py +150 -0
  29. plato/datasources/cifar10.py +87 -0
  30. plato/datasources/cifar100.py +61 -0
  31. plato/datasources/cinic10.py +62 -0
  32. plato/datasources/coco.py +119 -0
  33. plato/datasources/datalib/__init__.py +0 -0
  34. plato/datasources/datalib/audio_extraction_tools.py +137 -0
  35. plato/datasources/datalib/data_utils.py +124 -0
  36. plato/datasources/datalib/flickr30kE_utils.py +336 -0
  37. plato/datasources/datalib/frames_extraction_tools.py +254 -0
  38. plato/datasources/datalib/gym_utils/__init__.py +0 -0
  39. plato/datasources/datalib/gym_utils/gym_trim.py +189 -0
  40. plato/datasources/datalib/modality_data_anntation_tools.py +163 -0
  41. plato/datasources/datalib/modality_extraction_base.py +59 -0
  42. plato/datasources/datalib/parse_datasets.py +212 -0
  43. plato/datasources/datalib/refer_utils/__init__.py +0 -0
  44. plato/datasources/datalib/refer_utils/referitgame_utils.py +237 -0
  45. plato/datasources/datalib/tiny_data_tools.py +81 -0
  46. plato/datasources/datalib/video_transform.py +79 -0
  47. plato/datasources/emnist.py +64 -0
  48. plato/datasources/fashion_mnist.py +41 -0
  49. plato/datasources/feature.py +24 -0
  50. plato/datasources/feature_dataset.py +15 -0
  51. plato/datasources/femnist.py +141 -0
  52. plato/datasources/flickr30k_entities.py +362 -0
  53. plato/datasources/gym.py +431 -0
  54. plato/datasources/huggingface.py +165 -0
  55. plato/datasources/kinetics.py +568 -0
  56. plato/datasources/mnist.py +44 -0
  57. plato/datasources/multimodal_base.py +328 -0
  58. plato/datasources/pascal_voc.py +56 -0
  59. plato/datasources/purchase.py +94 -0
  60. plato/datasources/qoenflx.py +127 -0
  61. plato/datasources/referitgame.py +330 -0
  62. plato/datasources/registry.py +119 -0
  63. plato/datasources/self_supervised_learning.py +98 -0
  64. plato/datasources/stl10.py +103 -0
  65. plato/datasources/texas.py +94 -0
  66. plato/datasources/tiny_imagenet.py +64 -0
  67. plato/datasources/yolov8.py +85 -0
  68. plato/models/__init__.py +0 -0
  69. plato/models/cnn_encoder.py +103 -0
  70. plato/models/dcgan.py +116 -0
  71. plato/models/general_multilayer.py +254 -0
  72. plato/models/huggingface.py +27 -0
  73. plato/models/lenet5.py +113 -0
  74. plato/models/multilayer.py +90 -0
  75. plato/models/multimodal/__init__.py +0 -0
  76. plato/models/multimodal/base_net.py +91 -0
  77. plato/models/multimodal/blending.py +142 -0
  78. plato/models/multimodal/fc_net.py +77 -0
  79. plato/models/multimodal/fusion_net.py +78 -0
  80. plato/models/multimodal/multimodal_module.py +152 -0
  81. plato/models/registry.py +99 -0
  82. plato/models/resnet.py +190 -0
  83. plato/models/torch_hub.py +19 -0
  84. plato/models/vgg.py +113 -0
  85. plato/models/vit.py +166 -0
  86. plato/models/yolov8.py +22 -0
  87. plato/processors/__init__.py +0 -0
  88. plato/processors/base.py +35 -0
  89. plato/processors/compress.py +46 -0
  90. plato/processors/decompress.py +48 -0
  91. plato/processors/feature.py +51 -0
  92. plato/processors/feature_additive_noise.py +48 -0
  93. plato/processors/feature_dequantize.py +34 -0
  94. plato/processors/feature_gaussian.py +17 -0
  95. plato/processors/feature_laplace.py +15 -0
  96. plato/processors/feature_quantize.py +34 -0
  97. plato/processors/feature_randomized_response.py +50 -0
  98. plato/processors/feature_unbatch.py +39 -0
  99. plato/processors/inbound_feature_tensors.py +39 -0
  100. plato/processors/model.py +55 -0
  101. plato/processors/model_compress.py +34 -0
  102. plato/processors/model_decompress.py +37 -0
  103. plato/processors/model_decrypt.py +41 -0
  104. plato/processors/model_deepcopy.py +21 -0
  105. plato/processors/model_dequantize.py +18 -0
  106. plato/processors/model_dequantize_qsgd.py +61 -0
  107. plato/processors/model_encrypt.py +43 -0
  108. plato/processors/model_quantize.py +18 -0
  109. plato/processors/model_quantize_qsgd.py +82 -0
  110. plato/processors/model_randomized_response.py +34 -0
  111. plato/processors/outbound_feature_ndarrays.py +38 -0
  112. plato/processors/pipeline.py +26 -0
  113. plato/processors/registry.py +124 -0
  114. plato/processors/structured_pruning.py +57 -0
  115. plato/processors/unstructured_pruning.py +73 -0
  116. plato/samplers/__init__.py +0 -0
  117. plato/samplers/all_inclusive.py +41 -0
  118. plato/samplers/base.py +31 -0
  119. plato/samplers/dirichlet.py +81 -0
  120. plato/samplers/distribution_noniid.py +132 -0
  121. plato/samplers/iid.py +53 -0
  122. plato/samplers/label_quantity_noniid.py +119 -0
  123. plato/samplers/mixed.py +44 -0
  124. plato/samplers/mixed_label_quantity_noniid.py +128 -0
  125. plato/samplers/modality_iid.py +42 -0
  126. plato/samplers/modality_quantity_noniid.py +56 -0
  127. plato/samplers/orthogonal.py +99 -0
  128. plato/samplers/registry.py +66 -0
  129. plato/samplers/sample_quantity_noniid.py +123 -0
  130. plato/samplers/sampler_utils.py +190 -0
  131. plato/servers/__init__.py +0 -0
  132. plato/servers/base.py +1395 -0
  133. plato/servers/fedavg.py +281 -0
  134. plato/servers/fedavg_cs.py +335 -0
  135. plato/servers/fedavg_gan.py +74 -0
  136. plato/servers/fedavg_he.py +106 -0
  137. plato/servers/fedavg_personalized.py +57 -0
  138. plato/servers/mistnet.py +67 -0
  139. plato/servers/registry.py +52 -0
  140. plato/servers/split_learning.py +109 -0
  141. plato/trainers/__init__.py +0 -0
  142. plato/trainers/base.py +99 -0
  143. plato/trainers/basic.py +649 -0
  144. plato/trainers/diff_privacy.py +178 -0
  145. plato/trainers/gan.py +330 -0
  146. plato/trainers/huggingface.py +173 -0
  147. plato/trainers/loss_criterion.py +70 -0
  148. plato/trainers/lr_schedulers.py +252 -0
  149. plato/trainers/optimizers.py +53 -0
  150. plato/trainers/pascal_voc.py +80 -0
  151. plato/trainers/registry.py +44 -0
  152. plato/trainers/self_supervised_learning.py +302 -0
  153. plato/trainers/split_learning.py +305 -0
  154. plato/trainers/tracking.py +96 -0
  155. plato/trainers/yolov8.py +41 -0
  156. plato/utils/__init__.py +0 -0
  157. plato/utils/count_parameters.py +30 -0
  158. plato/utils/csv_processor.py +26 -0
  159. plato/utils/data_loaders.py +148 -0
  160. plato/utils/decorators.py +24 -0
  161. plato/utils/fonts.py +23 -0
  162. plato/utils/homo_enc.py +187 -0
  163. plato/utils/reinforcement_learning/__init__.py +0 -0
  164. plato/utils/reinforcement_learning/policies/__init__.py +0 -0
  165. plato/utils/reinforcement_learning/policies/base.py +161 -0
  166. plato/utils/reinforcement_learning/policies/ddpg.py +75 -0
  167. plato/utils/reinforcement_learning/policies/registry.py +32 -0
  168. plato/utils/reinforcement_learning/policies/sac.py +343 -0
  169. plato/utils/reinforcement_learning/policies/td3.py +485 -0
  170. plato/utils/reinforcement_learning/rl_agent.py +142 -0
  171. plato/utils/reinforcement_learning/rl_server.py +113 -0
  172. plato/utils/rl_env.py +154 -0
  173. plato/utils/s3.py +141 -0
  174. plato/utils/trainer_utils.py +21 -0
  175. plato/utils/unary_encoding.py +47 -0
  176. plato_learn-1.1.dist-info/METADATA +35 -0
  177. plato_learn-1.1.dist-info/RECORD +179 -0
  178. plato_learn-1.1.dist-info/WHEEL +4 -0
  179. plato_learn-1.1.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,281 @@
1
+ """
2
+ A simple federated learning server using federated averaging.
3
+ """
4
+
5
+ import asyncio
6
+ import logging
7
+ import os
8
+
9
+ from plato.algorithms import registry as algorithms_registry
10
+ from plato.config import Config
11
+ from plato.datasources import registry as datasources_registry
12
+ from plato.processors import registry as processor_registry
13
+ from plato.samplers import all_inclusive
14
+ from plato.servers import base
15
+ from plato.trainers import registry as trainers_registry
16
+ from plato.utils import csv_processor, fonts
17
+
18
+
19
class Server(base.Server):
    """Federated learning server using federated averaging."""

    def __init__(
        self, model=None, datasource=None, algorithm=None, trainer=None, callbacks=None
    ):
        """
        Create the server.

        `model` is handed to the trainer as-is, while `datasource`, `trainer` and
        `algorithm` are callables (typically classes) that are instantiated lazily
        in configure() / init_trainer(). When `datasource`, `trainer` or
        `algorithm` is None, the corresponding registry supplies the default.
        """
        super().__init__(callbacks=callbacks)

        # Custom components supplied by the caller; the active instances are only
        # created when the server is configured
        self.custom_model = model
        self.model = None

        self.custom_algorithm = algorithm
        self.algorithm = None

        self.custom_trainer = trainer
        self.trainer = None

        self.custom_datasource = datasource
        self.datasource = None

        self.testset = None
        self.testset_sampler = None
        self.total_samples = 0

        self.total_clients = Config().clients.total_clients
        self.clients_per_round = Config().clients.per_round

        logging.info(
            "[Server #%d] Started training on %d clients with %d per round.",
            os.getpid(),
            self.total_clients,
            self.clients_per_round,
        )

    def configure(self) -> None:
        """
        Booting the federated learning server by setting up the data, model, and
        creating the clients.
        """
        super().configure()

        total_rounds = Config().trainer.rounds
        target_accuracy = None
        target_perplexity = None

        # A configured target accuracy takes precedence over a target perplexity
        if hasattr(Config().trainer, "target_accuracy"):
            target_accuracy = Config().trainer.target_accuracy
        elif hasattr(Config().trainer, "target_perplexity"):
            target_perplexity = Config().trainer.target_perplexity

        if target_accuracy:
            logging.info(
                "Training: %s rounds or accuracy above %.1f%%\n",
                total_rounds,
                100 * target_accuracy,
            )
        elif target_perplexity:
            logging.info(
                "Training: %s rounds or perplexity below %.1f\n",
                total_rounds,
                target_perplexity,
            )
        else:
            logging.info("Training: %s rounds\n", total_rounds)

        self.init_trainer()

        # Prepares this server for processors that processes outbound and inbound
        # data payloads
        self.outbound_processor, self.inbound_processor = processor_registry.get(
            "Server", server_id=os.getpid(), trainer=self.trainer
        )

        # The server loads a test set unless `server.do_test` is explicitly set
        # to a false value in the configuration
        if getattr(Config().server, "do_test", True):
            if self.datasource is None and self.custom_datasource is None:
                self.datasource = datasources_registry.get(client_id=0)
            elif self.datasource is None and self.custom_datasource is not None:
                self.datasource = self.custom_datasource()

            self.testset = self.datasource.get_test_set()
            if hasattr(Config().data, "testset_size"):
                self.testset_sampler = all_inclusive.Sampler(
                    self.datasource, testing=True
                )

        # Initialize the test accuracy csv file if clients compute locally
        if getattr(Config().clients, "do_test", False) and getattr(
            getattr(Config(), "results", None), "record_clients_accuracy", False
        ):
            accuracy_csv_file = (
                f"{Config().params['result_path']}/{os.getpid()}_accuracy.csv"
            )
            accuracy_headers = ["round", "client_id", "accuracy"]
            csv_processor.initialize_csv(
                accuracy_csv_file, accuracy_headers, Config().params["result_path"]
            )

    def init_trainer(self) -> None:
        """
        Setting up the global model, trainer, and algorithm.

        Components that already exist are left untouched, so calling this more
        than once is harmless.
        """
        if self.model is None and self.custom_model is not None:
            self.model = self.custom_model

        if self.trainer is None and self.custom_trainer is None:
            self.trainer = trainers_registry.get(model=self.model)
        elif self.trainer is None and self.custom_trainer is not None:
            self.trainer = self.custom_trainer(model=self.model)

        if self.algorithm is None and self.custom_algorithm is None:
            self.algorithm = algorithms_registry.get(trainer=self.trainer)
        elif self.algorithm is None and self.custom_algorithm is not None:
            self.algorithm = self.custom_algorithm(trainer=self.trainer)

    async def aggregate_deltas(self, updates, deltas_received):
        """
        Aggregate weight updates from the clients using federated averaging.

        `deltas_received[i]` is a mapping of layer name -> weight delta from the
        client whose update is `updates[i]`; it is weighted by that client's share
        of `report.num_samples`. If no client reported any samples, all deltas are
        weighted equally instead of dividing by zero.

        Returns a mapping of layer name -> averaged delta.
        """
        # Extract the total number of samples
        self.total_samples = sum(update.report.num_samples for update in updates)

        # Perform weighted averaging
        avg_update = {
            name: self.trainer.zeros(delta.shape)
            for name, delta in deltas_received[0].items()
        }

        num_clients = len(deltas_received)
        for i, update in enumerate(deltas_received):
            num_samples = updates[i].report.num_samples

            # The weight is the same for every layer of this client, so it is
            # computed once per client rather than once per layer
            if self.total_samples > 0:
                weight = num_samples / self.total_samples
            else:
                weight = 1 / num_clients

            for name, delta in update.items():
                # Use weighted average by the number of samples
                avg_update[name] += delta * weight

            # Yield to other tasks in the server
            await asyncio.sleep(0)

        return avg_update

    async def _process_reports(self):
        """Process the client reports by aggregating their weights."""
        weights_received = [update.payload for update in self.updates]

        weights_received = self.weights_received(weights_received)
        self.callback_handler.call_event("on_weights_received", self, weights_received)

        # Extract the current model weights as the baseline
        baseline_weights = self.algorithm.extract_weights()

        if hasattr(self, "aggregate_weights"):
            # Runs a server aggregation algorithm using weights rather than deltas
            logging.info(
                "[Server #%d] Aggregating model weights directly rather than weight deltas.",
                os.getpid(),
            )
            updated_weights = await self.aggregate_weights(
                self.updates, baseline_weights, weights_received
            )

            # Loads the new model weights
            self.algorithm.load_weights(updated_weights)
        else:
            # Computes the weight deltas by comparing the weights received with
            # the current global model weights
            deltas_received = self.algorithm.compute_weight_deltas(
                baseline_weights, weights_received
            )
            # Runs a framework-agnostic server aggregation algorithm, such as
            # the federated averaging algorithm
            logging.info("[Server #%d] Aggregating model weight deltas.", os.getpid())
            deltas = await self.aggregate_deltas(self.updates, deltas_received)
            # Updates the existing model weights from the provided deltas
            updated_weights = self.algorithm.update_weights(deltas)
            # Loads the new model weights
            self.algorithm.load_weights(updated_weights)

        # The model weights have already been aggregated, now calls the
        # corresponding hook and callback
        self.weights_aggregated(self.updates)
        self.callback_handler.call_event("on_weights_aggregated", self, self.updates)

        # Testing the global model accuracy
        if not getattr(Config().server, "do_test", True):
            # Server-side testing is disabled: compute the average accuracy from
            # client reports
            self.accuracy, self.accuracy_std = self.get_accuracy_mean_std(self.updates)
            logging.info(
                "[%s] Average client accuracy: %.2f%%.", self, 100 * self.accuracy
            )
        else:
            # Testing the updated model directly at the server
            logging.info("[%s] Started model testing.", self)
            self.accuracy = self.trainer.test(self.testset, self.testset_sampler)

        if hasattr(Config().trainer, "target_perplexity"):
            logging.info(
                fonts.colourize(
                    f"[{self}] Global model perplexity: {self.accuracy:.2f}\n"
                )
            )
        else:
            logging.info(
                fonts.colourize(
                    f"[{self}] Global model accuracy: {100 * self.accuracy:.2f}%\n"
                )
            )

        self.clients_processed()
        self.callback_handler.call_event("on_clients_processed", self)

    def clients_processed(self) -> None:
        """Additional work to be performed after client reports have been processed."""

    def get_logged_items(self) -> dict:
        """Get items to be logged by the LogProgressCallback class in a .csv file."""
        return {
            "round": self.current_round,
            "accuracy": self.accuracy,
            "accuracy_std": self.accuracy_std,
            "elapsed_time": self.wall_time - self.initial_wall_time,
            "processing_time": max(
                update.report.processing_time for update in self.updates
            ),
            "comm_time": max(update.report.comm_time for update in self.updates),
            # The slowest client (training + processing + communication)
            # determines the round time
            "round_time": max(
                update.report.training_time
                + update.report.processing_time
                + update.report.comm_time
                for update in self.updates
            ),
            "comm_overhead": self.comm_overhead,
        }

    @staticmethod
    def get_accuracy_mean_std(updates):
        """
        Compute the accuracy mean and standard deviation across clients.

        Each client's `report.accuracy` is weighted by its `report.num_samples`.
        If no client reported any samples, all clients are weighted equally
        rather than dividing by zero. Returns a `(mean, std)` tuple.
        """
        # Get total number of samples
        total_samples = sum(update.report.num_samples for update in updates)

        if total_samples > 0:
            weights = [update.report.num_samples / total_samples for update in updates]
        else:
            # An empty `updates` yields an empty weight list (and a 0 mean),
            # so this never divides by zero
            weights = [1 / len(updates) for _ in updates]

        updates_accuracy = [update.report.accuracy for update in updates]

        mean = sum(acc * weight for acc, weight in zip(updates_accuracy, weights))
        variance = sum(
            (acc - mean) ** 2 * weight for acc, weight in zip(updates_accuracy, weights)
        )
        std = variance**0.5

        return mean, std

    def weights_received(self, weights_received):
        """
        Method called after the updated weights have been received.
        """
        return weights_received

    def weights_aggregated(self, updates):
        """
        Method called after the updated weights have been aggregated.
        """
@@ -0,0 +1,335 @@
1
+ """
2
+ A cross-silo federated learning server using federated averaging, as either edge or central servers.
3
+ """
4
+
5
+ import asyncio
6
+ import logging
7
+ import os
8
+ import numpy as np
9
+
10
+ from plato.config import Config
11
+ from plato.datasources import registry as datasources_registry
12
+ from plato.processors import registry as processor_registry
13
+ from plato.samplers import registry as samplers_registry
14
+ from plato.samplers import all_inclusive
15
+ from plato.servers import fedavg
16
+ from plato.utils import fonts
17
+
18
+
19
class Server(fedavg.Server):
    """Cross-silo federated learning server using federated averaging."""

    def __init__(
        self, model=None, datasource=None, algorithm=None, trainer=None, callbacks=None
    ):
        """
        Create a cross-silo server.

        The server acts as an edge server or as the central server, depending on
        Config().is_edge_server() / Config().is_central_server().
        """
        super().__init__(
            model=model,
            datasource=datasource,
            algorithm=algorithm,
            trainer=trainer,
            callbacks=callbacks,
        )

        self.current_global_round = 0
        self.average_accuracy = 0
        self.std_accuracy = 0

        if Config().is_edge_server():
            # An edge client waits for the event that a certain number of
            # aggregations are completed
            self.model_aggregated = asyncio.Event()

            # An edge client waits for the event that a new global round begins
            # before starting the first round of local aggregation
            self.new_global_round_begins = asyncio.Event()

            # Edge server IDs are numbered after the client IDs, which makes this
            # a 1-based index of the silo
            edge_server_id = Config().args.id - Config().clients.total_clients

            # Compute the total number of clients in each silo for edge servers
            edges_total_clients = [
                len(i)
                for i in np.array_split(
                    np.arange(Config().clients.total_clients),
                    Config().algorithm.total_silos,
                )
            ]
            self.total_clients = edges_total_clients[edge_server_id - 1]

            # The per-round quota is split across silos in the same way
            self.clients_per_round = [
                len(i)
                for i in np.array_split(
                    np.arange(Config().clients.per_round),
                    Config().algorithm.total_silos,
                )
            ][edge_server_id - 1]

            # This silo owns the contiguous block of 1-based client IDs that
            # follows the clients of all preceding silos
            starting_client_id = sum(edges_total_clients[: edge_server_id - 1])
            self.clients_pool = list(
                range(
                    starting_client_id + 1, starting_client_id + 1 + self.total_clients
                )
            )

            logging.info(
                "[Edge server #%d (#%d)] Started training on %d clients with %d per round.",
                Config().args.id,
                os.getpid(),
                self.total_clients,
                self.clients_per_round,
            )

            # The training time of an edge server in one global round
            self.edge_training_time = 0

            # The communication time between an edge server and its clients in
            # one global round
            self.edge_comm_time = 0

        # Compute the number of clients for the central server
        if Config().is_central_server():
            # Each silo's edge server acts as one client of the central server
            self.clients_per_round = Config().algorithm.total_silos
            self.total_clients = self.clients_per_round

            logging.info(
                "The central server starts training with %s edge servers.",
                self.total_clients,
            )

    def configure(self) -> None:
        """
        Booting the federated learning server by setting up the data, model, and
        creating the clients.
        """
        super().configure()

        if Config().is_edge_server():
            logging.info(
                "Configuring edge server #%d as a %s server.",
                Config().args.id,
                Config().algorithm.type,
            )
            logging.info(
                "[Edge server #%d (#%d)] Training with %s local aggregation rounds.",
                Config().args.id,
                os.getpid(),
                Config().algorithm.local_rounds,
            )

            self.init_trainer()
            self.trainer.set_client_id(Config().args.id)

            # Prepares this server for processors that processes outbound and
            # inbound data payloads. super().configure() already created them,
            # but they are created again now that the trainer carries this edge
            # server's ID.
            self.outbound_processor, self.inbound_processor = processor_registry.get(
                "Server", server_id=os.getpid(), trainer=self.trainer
            )

            if getattr(Config().server, "edge_do_test", False):
                self.datasource = datasources_registry.get(client_id=0)
                self.testset = self.datasource.get_test_set()

                if hasattr(Config().data, "testset_sampler"):
                    # Set the sampler for test set
                    self.testset_sampler = samplers_registry.get(
                        self.datasource, Config().args.id, testing=True
                    )
                elif hasattr(Config().data, "testset_size"):
                    self.testset_sampler = all_inclusive.Sampler(
                        self.datasource, testing=True
                    )

    async def _select_clients(self, for_next_batch=False):
        """
        Select clients for the round.

        At the start of a global round (current_round == 0), an edge server first
        waits until the central server has selected it.
        """
        if Config().is_edge_server() and not for_next_batch:
            if self.current_round == 0:
                # Wait until this edge server is selected by the central server
                # to avoid the edge server selects clients and clients begin training
                # before the edge server is selected
                await self.new_global_round_begins.wait()
                self.new_global_round_begins.clear()

        await super()._select_clients(for_next_batch=for_next_batch)

    def customize_server_response(self, server_response: dict, client_id) -> dict:
        """Wrap up generating the server response with any additional information."""
        if Config().is_central_server():
            server_response["current_global_round"] = self.current_round
        return server_response

    async def _process_reports(self):
        """Process the client reports by aggregating their weights."""
        # To pass the client_id == 0 assertion during aggregation
        self.trainer.set_client_id(0)

        weights_received = [update.payload for update in self.updates]

        weights_received = self.weights_received(weights_received)
        self.callback_handler.call_event("on_weights_received", self, weights_received)

        # Extract the current model weights as the baseline
        baseline_weights = self.algorithm.extract_weights()

        if hasattr(self, "aggregate_weights"):
            # Runs a server aggregation algorithm using weights rather than deltas
            logging.info(
                "[Server #%d] Aggregating model weights directly rather than weight deltas.",
                os.getpid(),
            )
            updated_weights = await self.aggregate_weights(
                self.updates, baseline_weights, weights_received
            )

            # Loads the new model weights
            self.algorithm.load_weights(updated_weights)
        else:
            # Computes the weight deltas by comparing the weights received with
            # the current global model weights
            deltas_received = self.algorithm.compute_weight_deltas(
                baseline_weights, weights_received
            )
            # Runs a framework-agnostic server aggregation algorithm, such as
            # the federated averaging algorithm
            logging.info("[Server #%d] Aggregating model weight deltas.", os.getpid())
            deltas = await self.aggregate_deltas(self.updates, deltas_received)
            # Updates the existing model weights from the provided deltas
            updated_weights = self.algorithm.update_weights(deltas)
            # Loads the new model weights
            self.algorithm.load_weights(updated_weights)

        # The model weights have already been aggregated, now calls the
        # corresponding hook and callback
        self.weights_aggregated(self.updates)
        self.callback_handler.call_event("on_weights_aggregated", self, self.updates)

        if Config().is_edge_server():
            self.trainer.set_client_id(Config().args.id)

        # Optional settings default to disabled instead of raising AttributeError
        # when they are missing from the configuration
        clients_do_test = getattr(Config().clients, "do_test", False)
        edge_do_test = getattr(Config().server, "edge_do_test", False)

        # Testing the model accuracy
        if (Config().is_edge_server() and clients_do_test) or (
            Config().is_central_server() and edge_do_test
        ):
            # Compute the average accuracy from client reports
            (
                self.average_accuracy,
                self.std_accuracy,
            ) = self.get_accuracy_mean_std(self.updates)
            logging.info(
                "[%s] Average client accuracy: %.2f%%.",
                self,
                100 * self.average_accuracy,
            )
        elif Config().is_central_server() and clients_do_test:
            # Compute the sample-weighted mean of the `average_accuracy` value
            # carried by each report
            total_samples = sum(update.report.num_samples for update in self.updates)
            self.average_accuracy = (
                sum(
                    update.report.average_accuracy * update.report.num_samples
                    for update in self.updates
                )
                / total_samples
            )

            logging.info(
                "[%s] Average client accuracy: %.2f%%.",
                self,
                100 * self.average_accuracy,
            )

        if Config().is_central_server() and getattr(Config().server, "do_test", False):
            # Testing the updated model directly at the server
            self._test_and_log_aggregated_model()
        elif Config().is_edge_server() and edge_do_test:
            # Test the aggregated model directly at the edge server
            self._test_and_log_aggregated_model()
        else:
            self.accuracy = self.average_accuracy
            self.accuracy_std = self.std_accuracy

        self.clients_processed()
        self.callback_handler.call_event("on_clients_processed", self)

    def _test_and_log_aggregated_model(self):
        """Test the aggregated model on this server's test set and log the result."""
        logging.info("[%s] Started model testing.", self)
        self.accuracy = self.trainer.test(self.testset, self.testset_sampler)

        if hasattr(Config().trainer, "target_perplexity"):
            logging.info(
                fonts.colourize(
                    f"[{self}] Global model perplexity: {self.accuracy:.2f}\n"
                )
            )
        else:
            logging.info(
                fonts.colourize(
                    f"[{self}] Global model accuracy: {100 * self.accuracy:.2f}%\n"
                )
            )

    def clients_processed(self):
        """Additional work to be performed after client reports have been processed."""
        # Record results into a .csv file
        if Config().is_central_server():
            super().clients_processed()

        if Config().is_edge_server():
            logged_items = self.get_logged_items()
            self.edge_training_time += logged_items["round_time"]
            self.edge_comm_time += logged_items["comm_time"]

            # When a certain number of aggregations are completed, an edge client
            # needs to be signaled to send a report to the central server
            if self.current_round == Config().algorithm.local_rounds:
                logging.info(
                    "[Server #%d] Completed %s rounds of local aggregation.",
                    os.getpid(),
                    Config().algorithm.local_rounds,
                )
                self.model_aggregated.set()

                # Restart local round counting for the next global round
                self.current_round = 0
                self.current_global_round += 1

    def get_logged_items(self) -> dict:
        """Get items to be logged by the LogProgressCallback class in a .csv file."""
        logged_items = super().get_logged_items()

        logged_items["global_round"] = self.current_global_round
        logged_items["average_accuracy"] = self.average_accuracy
        logged_items["edge_agg_num"] = Config().algorithm.local_rounds
        logged_items["local_epoch_num"] = Config().trainer.epochs

        if Config().is_central_server():
            # Include the edge servers' own communication time with their clients
            logged_items["comm_time"] = max(
                update.report.comm_time + update.report.edge_server_comm_time
                for update in self.updates
            )

        return logged_items

    async def wrap_up(self) -> None:
        """Wrapping up when each round of training is done."""
        if Config().is_central_server():
            await super().wrap_up()
@@ -0,0 +1,74 @@
1
+ """
2
+ A federated learning server using federated averaging to train GAN models.
3
+ """
4
+
5
+ import asyncio
6
+
7
+ from plato.servers import fedavg
8
+ from plato.config import Config
9
+
10
+
11
class Server(fedavg.Server):
    """Federated learning server using federated averaging to train GAN models."""

    async def aggregate_deltas(self, updates, deltas_received):
        """
        Aggregate Generator and Discriminator weight updates using federated
        averaging.

        Each `deltas_received[i]` is a (generator deltas, discriminator deltas)
        pair of layer name -> delta mappings from the client whose update is
        `updates[i]`. It is weighted by that client's share of
        `report.num_samples`. If no client reported any samples, all clients are
        weighted equally instead of dividing by zero.

        Returns the pair of averaged (generator, discriminator) mappings.
        """
        # Total sample is the same for both Generator and Discriminator
        self.total_samples = sum(update.report.num_samples for update in updates)

        # Perform weighted averaging for both Generator and Discriminator
        gen_avg_update = {
            name: self.trainer.zeros(weights.shape)
            for name, weights in deltas_received[0][0].items()
        }
        disc_avg_update = {
            name: self.trainer.zeros(weights.shape)
            for name, weights in deltas_received[0][1].items()
        }

        num_clients = len(deltas_received)
        for i, (update_from_gen, update_from_disc) in enumerate(deltas_received):
            num_samples = updates[i].report.num_samples

            # One weight per client, shared by all layers of both networks
            if self.total_samples > 0:
                weight = num_samples / self.total_samples
            else:
                weight = 1 / num_clients

            for name, delta in update_from_gen.items():
                gen_avg_update[name] += delta * weight

            for name, delta in update_from_disc.items():
                disc_avg_update[name] += delta * weight

            # Yield to other tasks in the server
            await asyncio.sleep(0)

        return gen_avg_update, disc_avg_update

    def customize_server_payload(self, payload):
        """
        Customize the server payload before sending to the client.

        At the end of each round, the server can choose to only send the global Generator
        or Discriminator (or both or neither) model to the clients next round.

        Reference this paper for more detail:
        https://deepai.org/publication/federated-generative-adversarial-learning

        By default, both models will be sent to the clients. The choice is read
        from `server.network_to_sync` (case-insensitive): "none", "generator",
        "discriminator" or "both". Any other value raises ValueError.
        """
        if hasattr(Config().server, "network_to_sync"):
            network = Config().server.network_to_sync.lower()
        else:
            network = "both"

        weights_gen, weights_disc = payload
        if network == "none":
            server_payload = None, None
        elif network == "generator":
            server_payload = weights_gen, None
        elif network == "discriminator":
            server_payload = None, weights_disc
        elif network == "both":
            server_payload = payload
        else:
            raise ValueError(f"Unknown value to attribute network_to_sync: {network}")

        return server_payload