plato-learn 1.1 (plato_learn-1.1-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (179)
  1. plato/__init__.py +1 -0
  2. plato/algorithms/__init__.py +0 -0
  3. plato/algorithms/base.py +45 -0
  4. plato/algorithms/fedavg.py +48 -0
  5. plato/algorithms/fedavg_gan.py +79 -0
  6. plato/algorithms/fedavg_personalized.py +48 -0
  7. plato/algorithms/mistnet.py +52 -0
  8. plato/algorithms/registry.py +39 -0
  9. plato/algorithms/split_learning.py +89 -0
  10. plato/callbacks/__init__.py +0 -0
  11. plato/callbacks/client.py +56 -0
  12. plato/callbacks/handler.py +78 -0
  13. plato/callbacks/server.py +139 -0
  14. plato/callbacks/trainer.py +124 -0
  15. plato/client.py +67 -0
  16. plato/clients/__init__.py +0 -0
  17. plato/clients/base.py +467 -0
  18. plato/clients/edge.py +103 -0
  19. plato/clients/fedavg_personalized.py +40 -0
  20. plato/clients/mistnet.py +49 -0
  21. plato/clients/registry.py +43 -0
  22. plato/clients/self_supervised_learning.py +51 -0
  23. plato/clients/simple.py +218 -0
  24. plato/clients/split_learning.py +150 -0
  25. plato/config.py +339 -0
  26. plato/datasources/__init__.py +0 -0
  27. plato/datasources/base.py +123 -0
  28. plato/datasources/celeba.py +150 -0
  29. plato/datasources/cifar10.py +87 -0
  30. plato/datasources/cifar100.py +61 -0
  31. plato/datasources/cinic10.py +62 -0
  32. plato/datasources/coco.py +119 -0
  33. plato/datasources/datalib/__init__.py +0 -0
  34. plato/datasources/datalib/audio_extraction_tools.py +137 -0
  35. plato/datasources/datalib/data_utils.py +124 -0
  36. plato/datasources/datalib/flickr30kE_utils.py +336 -0
  37. plato/datasources/datalib/frames_extraction_tools.py +254 -0
  38. plato/datasources/datalib/gym_utils/__init__.py +0 -0
  39. plato/datasources/datalib/gym_utils/gym_trim.py +189 -0
  40. plato/datasources/datalib/modality_data_anntation_tools.py +163 -0
  41. plato/datasources/datalib/modality_extraction_base.py +59 -0
  42. plato/datasources/datalib/parse_datasets.py +212 -0
  43. plato/datasources/datalib/refer_utils/__init__.py +0 -0
  44. plato/datasources/datalib/refer_utils/referitgame_utils.py +237 -0
  45. plato/datasources/datalib/tiny_data_tools.py +81 -0
  46. plato/datasources/datalib/video_transform.py +79 -0
  47. plato/datasources/emnist.py +64 -0
  48. plato/datasources/fashion_mnist.py +41 -0
  49. plato/datasources/feature.py +24 -0
  50. plato/datasources/feature_dataset.py +15 -0
  51. plato/datasources/femnist.py +141 -0
  52. plato/datasources/flickr30k_entities.py +362 -0
  53. plato/datasources/gym.py +431 -0
  54. plato/datasources/huggingface.py +165 -0
  55. plato/datasources/kinetics.py +568 -0
  56. plato/datasources/mnist.py +44 -0
  57. plato/datasources/multimodal_base.py +328 -0
  58. plato/datasources/pascal_voc.py +56 -0
  59. plato/datasources/purchase.py +94 -0
  60. plato/datasources/qoenflx.py +127 -0
  61. plato/datasources/referitgame.py +330 -0
  62. plato/datasources/registry.py +119 -0
  63. plato/datasources/self_supervised_learning.py +98 -0
  64. plato/datasources/stl10.py +103 -0
  65. plato/datasources/texas.py +94 -0
  66. plato/datasources/tiny_imagenet.py +64 -0
  67. plato/datasources/yolov8.py +85 -0
  68. plato/models/__init__.py +0 -0
  69. plato/models/cnn_encoder.py +103 -0
  70. plato/models/dcgan.py +116 -0
  71. plato/models/general_multilayer.py +254 -0
  72. plato/models/huggingface.py +27 -0
  73. plato/models/lenet5.py +113 -0
  74. plato/models/multilayer.py +90 -0
  75. plato/models/multimodal/__init__.py +0 -0
  76. plato/models/multimodal/base_net.py +91 -0
  77. plato/models/multimodal/blending.py +142 -0
  78. plato/models/multimodal/fc_net.py +77 -0
  79. plato/models/multimodal/fusion_net.py +78 -0
  80. plato/models/multimodal/multimodal_module.py +152 -0
  81. plato/models/registry.py +99 -0
  82. plato/models/resnet.py +190 -0
  83. plato/models/torch_hub.py +19 -0
  84. plato/models/vgg.py +113 -0
  85. plato/models/vit.py +166 -0
  86. plato/models/yolov8.py +22 -0
  87. plato/processors/__init__.py +0 -0
  88. plato/processors/base.py +35 -0
  89. plato/processors/compress.py +46 -0
  90. plato/processors/decompress.py +48 -0
  91. plato/processors/feature.py +51 -0
  92. plato/processors/feature_additive_noise.py +48 -0
  93. plato/processors/feature_dequantize.py +34 -0
  94. plato/processors/feature_gaussian.py +17 -0
  95. plato/processors/feature_laplace.py +15 -0
  96. plato/processors/feature_quantize.py +34 -0
  97. plato/processors/feature_randomized_response.py +50 -0
  98. plato/processors/feature_unbatch.py +39 -0
  99. plato/processors/inbound_feature_tensors.py +39 -0
  100. plato/processors/model.py +55 -0
  101. plato/processors/model_compress.py +34 -0
  102. plato/processors/model_decompress.py +37 -0
  103. plato/processors/model_decrypt.py +41 -0
  104. plato/processors/model_deepcopy.py +21 -0
  105. plato/processors/model_dequantize.py +18 -0
  106. plato/processors/model_dequantize_qsgd.py +61 -0
  107. plato/processors/model_encrypt.py +43 -0
  108. plato/processors/model_quantize.py +18 -0
  109. plato/processors/model_quantize_qsgd.py +82 -0
  110. plato/processors/model_randomized_response.py +34 -0
  111. plato/processors/outbound_feature_ndarrays.py +38 -0
  112. plato/processors/pipeline.py +26 -0
  113. plato/processors/registry.py +124 -0
  114. plato/processors/structured_pruning.py +57 -0
  115. plato/processors/unstructured_pruning.py +73 -0
  116. plato/samplers/__init__.py +0 -0
  117. plato/samplers/all_inclusive.py +41 -0
  118. plato/samplers/base.py +31 -0
  119. plato/samplers/dirichlet.py +81 -0
  120. plato/samplers/distribution_noniid.py +132 -0
  121. plato/samplers/iid.py +53 -0
  122. plato/samplers/label_quantity_noniid.py +119 -0
  123. plato/samplers/mixed.py +44 -0
  124. plato/samplers/mixed_label_quantity_noniid.py +128 -0
  125. plato/samplers/modality_iid.py +42 -0
  126. plato/samplers/modality_quantity_noniid.py +56 -0
  127. plato/samplers/orthogonal.py +99 -0
  128. plato/samplers/registry.py +66 -0
  129. plato/samplers/sample_quantity_noniid.py +123 -0
  130. plato/samplers/sampler_utils.py +190 -0
  131. plato/servers/__init__.py +0 -0
  132. plato/servers/base.py +1395 -0
  133. plato/servers/fedavg.py +281 -0
  134. plato/servers/fedavg_cs.py +335 -0
  135. plato/servers/fedavg_gan.py +74 -0
  136. plato/servers/fedavg_he.py +106 -0
  137. plato/servers/fedavg_personalized.py +57 -0
  138. plato/servers/mistnet.py +67 -0
  139. plato/servers/registry.py +52 -0
  140. plato/servers/split_learning.py +109 -0
  141. plato/trainers/__init__.py +0 -0
  142. plato/trainers/base.py +99 -0
  143. plato/trainers/basic.py +649 -0
  144. plato/trainers/diff_privacy.py +178 -0
  145. plato/trainers/gan.py +330 -0
  146. plato/trainers/huggingface.py +173 -0
  147. plato/trainers/loss_criterion.py +70 -0
  148. plato/trainers/lr_schedulers.py +252 -0
  149. plato/trainers/optimizers.py +53 -0
  150. plato/trainers/pascal_voc.py +80 -0
  151. plato/trainers/registry.py +44 -0
  152. plato/trainers/self_supervised_learning.py +302 -0
  153. plato/trainers/split_learning.py +305 -0
  154. plato/trainers/tracking.py +96 -0
  155. plato/trainers/yolov8.py +41 -0
  156. plato/utils/__init__.py +0 -0
  157. plato/utils/count_parameters.py +30 -0
  158. plato/utils/csv_processor.py +26 -0
  159. plato/utils/data_loaders.py +148 -0
  160. plato/utils/decorators.py +24 -0
  161. plato/utils/fonts.py +23 -0
  162. plato/utils/homo_enc.py +187 -0
  163. plato/utils/reinforcement_learning/__init__.py +0 -0
  164. plato/utils/reinforcement_learning/policies/__init__.py +0 -0
  165. plato/utils/reinforcement_learning/policies/base.py +161 -0
  166. plato/utils/reinforcement_learning/policies/ddpg.py +75 -0
  167. plato/utils/reinforcement_learning/policies/registry.py +32 -0
  168. plato/utils/reinforcement_learning/policies/sac.py +343 -0
  169. plato/utils/reinforcement_learning/policies/td3.py +485 -0
  170. plato/utils/reinforcement_learning/rl_agent.py +142 -0
  171. plato/utils/reinforcement_learning/rl_server.py +113 -0
  172. plato/utils/rl_env.py +154 -0
  173. plato/utils/s3.py +141 -0
  174. plato/utils/trainer_utils.py +21 -0
  175. plato/utils/unary_encoding.py +47 -0
  176. plato_learn-1.1.dist-info/METADATA +35 -0
  177. plato_learn-1.1.dist-info/RECORD +179 -0
  178. plato_learn-1.1.dist-info/WHEEL +4 -0
  179. plato_learn-1.1.dist-info/licenses/LICENSE +201 -0
plato/servers/base.py ADDED
@@ -0,0 +1,1395 @@
1
+ """
2
+ The base class for federated learning servers.
3
+ """
4
+
5
+ import asyncio
6
+ import heapq
7
+ import logging
8
+ import multiprocessing as mp
9
+ import os
10
+ import pickle
11
+ import random
12
+ import sys
13
+ import time
14
+ from abc import abstractmethod
15
+ from types import SimpleNamespace
16
+
17
+ import numpy as np
18
+ import socketio
19
+ from aiohttp import web
20
+
21
+ from plato.callbacks.handler import CallbackHandler
22
+ from plato.callbacks.server import LogProgressCallback
23
+ from plato.client import run
24
+ from plato.config import Config
25
+ from plato.utils import fonts, s3
26
+
27
+
28
+ # pylint: disable=unused-argument, protected-access
29
+ class ServerEvents(socketio.AsyncNamespace):
30
+ """A custom namespace for socketio.AsyncServer."""
31
+
32
+ def __init__(self, namespace, plato_server):
33
+ super().__init__(namespace)
34
+ self.plato_server = plato_server
35
+
36
+ async def on_connect(self, sid, environ):
37
+ """Upon a new connection from a client."""
38
+ logging.info("[Server #%d] A new client just connected.", os.getpid())
39
+
40
+ async def on_disconnect(self, sid, reason=None):
41
+ """Upon a disconnection event."""
42
+ logging.info("[Server #%d] An existing client just disconnected.", os.getpid())
43
+ await self.plato_server._client_disconnected(sid)
44
+
45
+ async def on_client_alive(self, sid, data):
46
+ """A new client arrived."""
47
+ await self.plato_server.register_client(sid, data["pid"], data["id"])
48
+
49
+ async def on_client_report(self, sid, data):
50
+ """An existing client sends a new report from local training."""
51
+ await self.plato_server._client_report_arrived(sid, data["id"], data["report"])
52
+
53
+ async def on_chunk(self, sid, data):
54
+ """A chunk of data from the server arrived."""
55
+ await self.plato_server._client_chunk_arrived(sid, data["data"])
56
+
57
+ async def on_client_payload(self, sid, data):
58
+ """An existing client sends a new payload from local training."""
59
+ await self.plato_server._client_payload_arrived(sid, data["id"])
60
+
61
+ async def on_client_payload_done(self, sid, data):
62
+ """An existing client finished sending its payloads from local training."""
63
+ if "s3_key" in data:
64
+ await self.plato_server._client_payload_done(
65
+ sid, data["id"], s3_key=data["s3_key"]
66
+ )
67
+ else:
68
+ await self.plato_server._client_payload_done(sid, data["id"])
69
+
70
+
71
+ class Server:
72
+ """The base class for federated learning servers."""
73
+
74
+ def __init__(self, callbacks=None):
75
+ self.sio = None
76
+ self.client = None
77
+ self.clients = {}
78
+ self.total_clients = 0
79
+ # The client ids are stored for client selection
80
+ self.clients_pool = []
81
+ self.clients_per_round = 0
82
+ self.selected_clients = None
83
+ self.selected_client_id = 0
84
+ self.selected_sids = []
85
+ self.current_round = 0
86
+ self.resumed_session = False
87
+ self.algorithm = None
88
+ self.trainer = None
89
+ self.accuracy = 0
90
+ self.accuracy_std = 0
91
+ self.reports = {}
92
+ self.updates = []
93
+ self.client_payload = {}
94
+ self.client_chunks = {}
95
+ self.s3_client = None
96
+ self.outbound_processor = None
97
+ self.inbound_processor = None
98
+ self.comm_simulation = (
99
+ Config().clients.comm_simulation
100
+ if hasattr(Config().clients, "comm_simulation")
101
+ else True
102
+ )
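+ # When comm_simulation is enabled (the default), model payloads are exchanged
+ # through files under the checkpoint path rather than over socket.io, and only
+ # metadata is transmitted; see _select_clients() and _client_report_arrived() below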
103
+
104
+ # Starting from the default server callback class, add all supplied server callbacks
105
+ self.callbacks = [LogProgressCallback]
106
+ if callbacks is not None:
107
+ self.callbacks.extend(callbacks)
108
+ self.callback_handler = CallbackHandler(self.callbacks)
109
+
110
+ # Accumulated communication overhead (MB) throughout the FL training session
111
+ self.comm_overhead = 0
112
+
113
+ # Downlink and uplink bandwidth (Mbps)
114
+ # for computing communication time in communication simulation mode
115
+ self.downlink_bandwidth = (
116
+ Config().server.downlink_bandwidth
117
+ if hasattr(Config().server, "downlink_bandwidth")
118
+ else 100
119
+ )
120
+ self.uplink_bandwidth = (
121
+ Config().server.uplink_bandwidth
122
+ if hasattr(Config().server, "uplink_bandwidth")
123
+ else 100
124
+ )
125
+ if Config().is_edge_server():
126
+ if hasattr(Config().server, "edge_downlink_bandwidth"):
127
+ self.downlink_bandwidth = Config().server.edge_downlink_bandwidth
128
+ if hasattr(Config().server, "edge_uplink_bandwidth"):
129
+ self.uplink_bandwidth = Config().server.edge_uplink_bandwidth
130
+
131
+ # Use dictionaries to record downlink/uplink communication time of each client
132
+ self.downlink_comm_time = {}
133
+ self.uplink_comm_time = {}
134
+
135
+ # States that need to be maintained for asynchronous FL
136
+
137
+ # sids that are currently in use
138
+ self.training_sids = []
139
+
140
+ # Clients whose new reports were received but not yet processed
141
+ self.reported_clients = []
142
+
143
+ # Clients who are still training since the last round of aggregation
144
+ self.training_clients = {}
145
+
146
+ # The wall clock time that is simulated to accommodate the fact that
147
+ # clients can only run a batch at a time, controlled by `max_concurrency`
148
+ self.initial_wall_time = time.time()
149
+ self.wall_time = time.time()
150
+
151
+ # The wall clock time when a communication round starts
152
+ self.round_start_wall_time = self.wall_time
153
+
154
+ # When simulating the wall clock time, the server needs to remember the
155
+ # set of reporting clients received since the previous round of aggregation
156
+ self.current_reported_clients = {}
157
+ self.current_processed_clients = {}
158
+ self.prng_state = random.getstate()
159
+
160
+ self.ping_interval = 3600
161
+ self.ping_timeout = 3600
162
+ self.asynchronous_mode = False
163
+ self.periodic_interval = 5
164
+ self.staleness_bound = 1000
165
+ self.minimum_clients = 1
166
+ self.simulate_wall_time = False
167
+ self.request_update = False
168
+ self.disable_clients = False
169
+
170
+ # When max_concurrency is specified, selected clients run batch by batch
171
+ # The number of clients in a batch on an available device is the same as the max_concurrency
172
+ # This list contains the ids of selected clients that have run in the current round
173
+ if hasattr(Config().trainer, "max_concurrency"):
174
+ self.trained_clients = []
175
+
176
+ def __repr__(self):
177
+ return f"Server #{os.getpid()}"
178
+
179
+ def __str__(self):
180
+ return f"Server #{os.getpid()}"
181
+
182
+ def configure(self) -> None:
183
+ """Initializes configuration settings based on the configuration file."""
184
+ logging.info("[%s] Configuring the server...", self)
185
+
186
+ # Ping interval and timeout setup for the server
187
+ self.ping_interval = (
188
+ Config().server.ping_interval
189
+ if hasattr(Config().server, "ping_interval")
190
+ else 3600
191
+ )
192
+ self.ping_timeout = (
193
+ Config().server.ping_timeout
194
+ if hasattr(Config().server, "ping_timeout")
195
+ else 3600
196
+ )
197
+
198
+ # Are we operating in asynchronous mode?
199
+ self.asynchronous_mode = (
200
+ hasattr(Config().server, "synchronous") and not Config().server.synchronous
201
+ )
202
+
203
+ # What is the periodic interval for running our periodic task in asynchronous mode?
204
+ self.periodic_interval = (
205
+ Config().server.periodic_interval
206
+ if hasattr(Config().server, "periodic_interval")
207
+ else 5
208
+ )
209
+
210
+ # The staleness threshold is used to determine whether a training client should be
211
+ # considered 'stale', i.e., if its starting round is too far behind the current round
212
+ # on the server
213
+ self.staleness_bound = (
214
+ Config().server.staleness_bound
215
+ if hasattr(Config().server, "staleness_bound")
216
+ else 0
217
+ )
218
+
219
+ if not Config().is_central_server():
220
+ # What is the minimum number of clients that must have reported before aggregation
221
+ # takes place?
222
+ self.minimum_clients = (
223
+ Config().server.minimum_clients_aggregated
224
+ if hasattr(Config().server, "minimum_clients_aggregated")
225
+ else 1
226
+ )
227
+ else:
228
+ # In cross-silo FL, what is the minimum number of edge servers that must have reported
229
+ # before the central server conducts aggregation?
230
+ self.minimum_clients = (
231
+ Config().server.minimum_edges_aggregated
232
+ if hasattr(Config().server, "minimum_edges_aggregated")
233
+ else Config().algorithm.total_silos
234
+ )
235
+
236
+ # Are we simulating the wall clock time on the server? This is useful when the clients
237
+ # are training in batches due to a lack of memory on the GPUs
238
+ self.simulate_wall_time = (
239
+ hasattr(Config().server, "simulate_wall_time")
240
+ and Config().server.simulate_wall_time
241
+ )
242
+
243
+ # Do we wish to send urgent requests for model updates to the slow clients?
244
+ self.request_update = (
245
+ hasattr(Config().server, "request_update")
246
+ and Config().server.request_update
247
+ )
248
+
249
+ # Are we disabling all clients and preventing them from running?
250
+ self.disable_clients = (
251
+ hasattr(Config().server, "disable_clients")
252
+ and Config().server.disable_clients
253
+ )
254
+
255
+ # Compute the per-client uplink bandwidth
256
+ if self.asynchronous_mode:
257
+ self.uplink_bandwidth = self.uplink_bandwidth / self.minimum_clients
258
+ else:
259
+ self.uplink_bandwidth = self.uplink_bandwidth / self.clients_per_round
260
+
261
+ def run(self, client=None, edge_server=None, edge_client=None, trainer=None):
262
+ """Starts a run loop for the server."""
263
+ self.client = client
264
+ self.configure()
265
+
266
+ if Config().args.resume:
267
+ self._resume_from_checkpoint()
268
+
269
+ if Config().is_central_server():
270
+ # Start the edge servers as clients of the central server first
271
+ # Once all edge servers are live, clients will be initialized in the
272
+ # training_will_start() event call of the central server
273
+ Server._start_clients(
274
+ as_server=True,
275
+ client=self.client,
276
+ edge_server=edge_server,
277
+ edge_client=edge_client,
278
+ trainer=trainer,
279
+ )
280
+
281
+ asyncio.get_event_loop().create_task(self._periodic(self.periodic_interval))
282
+ if hasattr(Config().server, "random_seed"):
283
+ seed = Config().server.random_seed
284
+ logging.info("Setting the random seed for selecting clients: %s", seed)
285
+ random.seed(seed)
286
+ self.prng_state = random.getstate()
287
+ self.start()
288
+
289
+ else:
290
+ if self.disable_clients:
291
+ logging.info("No clients are launched (server:disable_clients = true)")
292
+ else:
293
+ Server._start_clients(client=self.client)
294
+
295
+ asyncio.get_event_loop().create_task(self._periodic(self.periodic_interval))
296
+
297
+ if hasattr(Config().server, "random_seed"):
298
+ seed = Config().server.random_seed
299
+ logging.info("Setting the random seed for selecting clients: %s", seed)
300
+ random.seed(seed)
301
+ self.prng_state = random.getstate()
302
+
303
+ self.start()
304
+
305
+ def start(self, port=Config().server.port):
306
+ """Starts running the socket.io server."""
307
+ logging.info(
308
+ "Starting a server at address %s and port %s.",
309
+ Config().server.address,
310
+ port,
311
+ )
312
+
313
+ self.sio = socketio.AsyncServer(
314
+ ping_interval=self.ping_interval,
315
+ max_http_buffer_size=2**31,
316
+ ping_timeout=self.ping_timeout,
317
+ )
318
+ self.sio.register_namespace(ServerEvents(namespace="/", plato_server=self))
319
+
320
+ if hasattr(Config().server, "s3_endpoint_url"):
321
+ self.s3_client = s3.S3()
322
+
323
+ app = web.Application()
324
+ self.sio.attach(app)
325
+ web.run_app(
326
+ app,
327
+ host=Config().server.address,
328
+ port=port,
329
+ loop=asyncio.get_event_loop(),
330
+ )
331
+
332
+ async def register_client(self, sid, client_process_id, client_id):
333
+ """Adds a newly arrived client to the list of clients."""
334
+ self.clients[client_process_id] = {
335
+ "sid": sid,
336
+ "client_id": client_id,
337
+ }
338
+ logging.info("[%s] New client with id #%d arrived.", self, client_id)
339
+ logging.info("[%s] Client process #%d registered.", self, client_process_id)
340
+
341
+ if (
342
+ hasattr(Config().trainer, "max_concurrency")
343
+ and not Config().is_central_server()
344
+ ):
345
+ required_launched_clients = min(
346
+ Config().trainer.max_concurrency * max(1, Config().gpu_count()),
347
+ self.clients_per_round,
348
+ )
349
+ else:
350
+ required_launched_clients = self.clients_per_round
351
+
352
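+ # Training begins only once enough client processes have registered, either at the
+ # start of a fresh session or when resuming from a checkpoint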
+ if (self.current_round == 0 or self.resumed_session) and len(
353
+ self.clients
354
+ ) >= required_launched_clients:
355
+ self.resumed_session = False
356
+
357
+ self.training_will_start()
358
+ self.callback_handler.call_event("on_training_will_start", self)
359
+
360
+ await self._select_clients()
361
+
362
+ @staticmethod
363
+ def _start_clients(
364
+ client=None,
365
+ as_server=False,
366
+ edge_server=None,
367
+ edge_client=None,
368
+ trainer=None,
369
+ ):
370
+ """Starts all the clients as separate processes."""
371
+ starting_id = 1
372
+
373
+ # We only need to launch the number of clients necessary for concurrent training
374
+ # If `max_concurrency` in `trainer` is specified, the limit is
375
+ # `max_concurrency` multiplied by the number of available devices
376
+ # (and by the number of edge servers in cross-silo training)
377
+ if hasattr(Config().trainer, "max_concurrency"):
378
+ if Config().is_central_server():
379
+ client_processes = min(
380
+ Config().trainer.max_concurrency
381
+ * max(1, Config().gpu_count())
382
+ * Config().algorithm.total_silos,
383
+ Config().clients.per_round,
384
+ )
385
+ else:
386
+ client_processes = min(
387
+ Config().trainer.max_concurrency * max(1, Config().gpu_count()),
388
+ Config().clients.per_round,
389
+ )
390
+ # Otherwise, the limit is the same as the number of clients per round
391
+ else:
392
+ client_processes = Config().clients.per_round
393
+
394
+ if as_server:
395
+ total_processes = Config().algorithm.total_silos
396
+ starting_id += Config().clients.total_clients
397
+ else:
398
+ total_processes = client_processes
399
+
400
+ if mp.get_start_method(allow_none=True) != "spawn":
401
+ mp.set_start_method("spawn", force=True)
402
+
403
+ for client_id in range(starting_id, total_processes + starting_id):
404
+ if as_server:
405
+ port = int(Config().server.port) + client_id
406
+ logging.info(
407
+ "Starting client #%d as an edge server on port %s.",
408
+ client_id,
409
+ port,
410
+ )
411
+ proc = mp.Process(
412
+ target=run,
413
+ args=(
414
+ client_id,
415
+ port,
416
+ client,
417
+ edge_server,
418
+ edge_client,
419
+ trainer,
420
+ ),
421
+ )
422
+ proc.start()
423
+ else:
424
+ logging.info("Starting client #%d's process.", client_id)
425
+ proc = mp.Process(
426
+ target=run, args=(client_id, None, client, None, None, None)
427
+ )
428
+ proc.start()
429
+
430
+ async def _close_connections(self):
431
+ """Closes all socket.io connections after training completes."""
432
+ for client_id, client in dict(self.clients).items():
433
+ logging.info("Closing the connection to client #%d.", client_id)
434
+ await self.sio.emit("disconnect", room=client["sid"])
435
+
436
+ async def _select_clients(self, for_next_batch=False):
437
+ """Selects a subset of the clients and send messages to them to start training."""
438
+ if not for_next_batch:
439
+ self.updates = []
440
+ self.current_round += 1
441
+ self.round_start_wall_time = self.wall_time
442
+
443
+ if hasattr(Config().trainer, "max_concurrency"):
444
+ self.trained_clients = []
445
+
446
+ logging.info(
447
+ fonts.colourize(
448
+ f"\n[{self}] Starting round {self.current_round}/{Config().trainer.rounds}."
449
+ )
450
+ )
451
+
452
+ if Config().is_central_server():
453
+ # In cross-silo FL, the central server selects from the pool of edge servers
454
+ self.clients_pool = list(self.clients)
455
+
456
+ elif not Config().is_edge_server():
457
+ self.clients_pool = list(range(1, 1 + self.total_clients))
458
+
459
+ # In asynchronous FL, avoid selecting new clients to replace those that are still
460
+ # training at this time
461
+
462
+ # When simulating the wall clock time, if len(self.reported_clients) is 0, the
463
+ # server has aggregated all reporting clients already
464
+ if (
465
+ self.asynchronous_mode
466
+ and self.selected_clients is not None
467
+ and len(self.reported_clients) > 0
468
+ and len(self.reported_clients) < self.clients_per_round
469
+ ):
470
+ # If self.selected_clients is None, it implies that it is the first iteration;
471
+ # If len(self.reported_clients) == self.clients_per_round, it implies that
472
+ # all selected clients have already reported.
473
+
474
+ # Except for these two cases, we need to exclude the clients who are still
475
+ # training.
476
+ training_client_ids = [
477
+ self.training_clients[client_id]["id"]
478
+ for client_id in self.training_clients
479
+ ]
480
+
481
+ # If the server is simulating the wall clock time, some of the clients who
482
+ # reported may not have been aggregated; they should be excluded from the next
483
+ # round of client selection
484
+ reporting_client_ids = [
485
+ client[2]["client_id"] for client in self.reported_clients
486
+ ]
487
+
488
+ selectable_clients = [
489
+ client
490
+ for client in self.clients_pool
491
+ if client not in training_client_ids
492
+ and client not in reporting_client_ids
493
+ ]
494
+
495
+ if self.simulate_wall_time:
496
+ self.selected_clients = self.choose_clients(
497
+ selectable_clients, len(self.current_processed_clients)
498
+ )
499
+ else:
500
+ self.selected_clients = self.choose_clients(
501
+ selectable_clients, len(self.reported_clients)
502
+ )
503
+ else:
504
+ self.selected_clients = self.choose_clients(
505
+ self.clients_pool, self.clients_per_round
506
+ )
507
+
508
+ self.current_reported_clients = {}
509
+ self.current_processed_clients = {}
510
+
511
+ # There is no need to clear the list of reporting clients if we are
512
+ # simulating the wall clock time on the server. This is because
513
+ # when wall clock time is simulated, the server needs to wait for
514
+ # all the clients to report before selecting a subset of clients for
515
+ # replacement, and all remaining reporting clients will be processed
516
+ # in the next round
517
+ if not self.simulate_wall_time:
518
+ self.reported_clients = []
519
+
520
+ if len(self.selected_clients) > 0:
521
+ self.selected_sids = []
522
+
523
+ # If max_concurrency is specified, run selected clients batch by batch,
524
+ # and the number of clients in each batch (on each GPU, if multiple GPUs are available)
525
+ # is equal to (or, for the last batch, possibly smaller than) max_concurrency
526
+ if (
527
+ hasattr(Config().trainer, "max_concurrency")
528
+ and not Config().is_central_server()
529
+ ):
530
+ selected_clients = []
531
+ if Config().gpu_count() > 1:
532
+ untrained_clients = list(
533
+ set(self.selected_clients).difference(self.trained_clients)
534
+ )
535
+ available_gpus = Config().gpu_count()
536
+ for cuda_id in range(available_gpus):
537
+ for client_id in untrained_clients:
538
+ if client_id % available_gpus == cuda_id:
539
+ selected_clients.append(client_id)
540
+ if len(selected_clients) >= min(
541
+ len(self.clients),
542
+ (cuda_id + 1) * Config().trainer.max_concurrency,
543
+ self.clients_per_round,
544
+ ):
545
+ break
546
+ # If there are not enough alive clients, stop the selection
547
+ if len(selected_clients) >= len(self.clients):
548
+ break
549
+ else:
550
+ selected_clients = self.selected_clients[
551
+ len(self.trained_clients) : min(
552
+ len(self.trained_clients) + len(self.clients),
553
+ len(self.selected_clients),
554
+ )
555
+ ]
556
+
557
+ self.trained_clients += selected_clients
558
+
559
+ else:
560
+ selected_clients = self.selected_clients
561
+
562
+ for selected_client_id in selected_clients:
563
+ self.selected_client_id = selected_client_id
564
+
565
+ if Config().is_central_server():
566
+ client_process_id = selected_client_id
567
+ else:
568
+ client_processes = [client for client in self.clients]
569
+
570
+ # Find a client process that is currently not training
571
+ # or selected in this round
572
+ for process_id in client_processes:
573
+ current_sid = self.clients[process_id]["sid"]
574
+ if not (
575
+ current_sid in self.training_sids
576
+ or current_sid in self.selected_sids
577
+ ):
578
+ client_process_id = process_id
579
+ break
580
+
581
+ sid = self.clients[client_process_id]["sid"]
582
+
583
+ # Track the selected client process
584
+ self.training_sids.append(sid)
585
+ self.selected_sids.append(sid)
586
+
587
+ # Assign the client id to the client process
588
+ self.clients[client_process_id]["client_id"] = self.selected_client_id
589
+
590
+ self.training_clients[self.selected_client_id] = {
591
+ "id": self.selected_client_id,
592
+ "starting_round": self.current_round,
593
+ "start_time": self.round_start_wall_time,
594
+ "update_requested": False,
595
+ }
596
+
597
+ logging.info(
598
+ "[%s] Selecting client #%d for training.",
599
+ self,
600
+ self.selected_client_id,
601
+ )
602
+
603
+ server_response = {
604
+ "id": self.selected_client_id,
605
+ "current_round": self.current_round,
606
+ }
607
+ server_response = self.customize_server_response(
608
+ server_response, client_id=self.selected_client_id
609
+ )
610
+
611
+ payload = self.algorithm.extract_weights()
612
+ payload = self.customize_server_payload(payload)
613
+
614
+ if self.comm_simulation:
615
+ logging.info(
616
+ "[%s] Sending the current model to client #%d (simulated).",
617
+ self,
618
+ self.selected_client_id,
619
+ )
620
+
621
+ # First apply outbound processors, if any
622
+ payload = self.outbound_processor.process(payload)
623
+
624
+ model_name = (
625
+ Config().trainer.model_name
626
+ if hasattr(Config().trainer, "model_name")
627
+ else "custom"
628
+ )
629
+ if "/" in model_name:
630
+ model_name = model_name.replace("/", "_")
631
+
632
+ checkpoint_path = Config().params["checkpoint_path"]
633
+
634
+ payload_filename = (
635
+ f"{checkpoint_path}/{model_name}_{self.selected_client_id}.pth"
636
+ )
637
+
638
+ with open(payload_filename, "wb") as payload_file:
639
+ pickle.dump(payload, payload_file)
640
+
641
+ server_response["payload_filename"] = payload_filename
642
+
643
+ payload_size = sys.getsizeof(pickle.dumps(payload)) / 1024**2
644
+
645
+ logging.info(
646
+ "[%s] Sending %.2f MB of payload data to client #%d (simulated).",
647
+ self,
648
+ payload_size,
649
+ self.selected_client_id,
650
+ )
651
+
652
+ self.comm_overhead += payload_size
653
+
654
+ # Compute the communication time to transfer the current global model to client
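+ # payload_size is measured in MB and the downlink bandwidth in Mbps, so dividing
+ # the bandwidth by 8 yields MB/s, shared equally among the clients selected this round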
655
+ self.downlink_comm_time[self.selected_client_id] = payload_size / (
656
+ (self.downlink_bandwidth / 8) / len(self.selected_clients)
657
+ )
658
+
659
+ # Send the server response as metadata to the clients (payload to follow)
660
+ await self.sio.emit(
661
+ "payload_to_arrive", {"response": server_response}, room=sid
662
+ )
663
+
664
+ if not self.comm_simulation:
665
+ # Send the server payload to the client
666
+ logging.info(
667
+ "[%s] Sending the current model to client #%d.",
668
+ self,
669
+ selected_client_id,
670
+ )
671
+
672
+ await self._send(sid, payload, selected_client_id)
673
+
674
+ self.clients_selected(self.selected_clients)
675
+ self.callback_handler.call_event(
676
+ "on_clients_selected", self, self.selected_clients
677
+ )
678
+
679
+ def choose_clients(self, clients_pool, clients_count):
680
+ """Chooses a subset of the clients to participate in each round."""
681
+ assert clients_count <= len(clients_pool)
682
+ random.setstate(self.prng_state)
683
+
684
+ # Select clients randomly
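+ # Restoring the saved PRNG state before sampling and saving it again afterwards keeps
+ # client selection on its own reproducible random stream, independent of other uses
+ # of the random module elsewhere in the process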
685
+ selected_clients = random.sample(clients_pool, clients_count)
686
+
687
+ self.prng_state = random.getstate()
688
+ logging.info("[%s] Selected clients: %s", self, selected_clients)
689
+ return selected_clients
690
+
691
+ async def _periodic(self, periodic_interval):
692
+ """Runs _periodic_task() periodically on the server. The time interval between
693
+ its execution is defined in 'server:periodic_interval'.
694
+ """
695
+ while True:
696
+ await self._periodic_task()
697
+ await asyncio.sleep(periodic_interval)
698
+
699
+ async def _periodic_task(self):
700
+ """A periodic task that is executed from time to time, determined by
701
+ 'server:periodic_interval' with a default value of 5 seconds, in the configuration.
702
+ """
703
+ # Call the async function that defines a customized periodic task, if any
704
+ await self.periodic_task()
705
+
706
+ # If we are operating in asynchronous mode, aggregate the model updates received so far.
707
+ if self.asynchronous_mode and not self.simulate_wall_time:
708
+ # Are there any training clients who are currently training on models that are too
709
+ # 'stale', as defined by the staleness threshold?
710
+ for __, client_data in self.training_clients.items():
711
+ # A client that is still working on a model from a round too far behind the
712
+ # current round (beyond the staleness bound) stops the aggregation process
713
+ client_staleness = self.current_round - client_data["starting_round"]
714
+ if client_staleness > self.staleness_bound:
715
+ logging.info(
716
+ "[%s] Client %s is still working at round %s, which is "
717
+ "beyond the staleness bound %s compared to the current round %s. "
718
+ "Nothing to process.",
719
+ self,
720
+ client_data["id"],
721
+ client_data["starting_round"],
722
+ self.staleness_bound,
723
+ self.current_round,
724
+ )
725
+
726
+ return
727
+
728
+ if len(self.updates) >= self.minimum_clients:
729
+ logging.info(
730
+ "[%s] %d client report(s) received in asynchronous mode. Processing.",
731
+ self,
732
+ len(self.updates),
733
+ )
734
+ await self._process_reports()
735
+ await self.wrap_up()
736
+ await self._select_clients()
737
+ else:
738
+ logging.info(
739
+ "[%s] No sufficient number of client reports have been received. "
740
+ "Nothing to process.",
741
+ self,
742
+ )
743
+
744
+ async def _send_in_chunks(self, data, sid, client_id) -> None:
745
+ """Sends a bytes object in fixed-sized chunks to the client."""
746
+ step = 1024**2
747
+ chunks = [data[i : i + step] for i in range(0, len(data), step)]
748
+
749
+ for chunk in chunks:
750
+ await self.sio.emit("chunk", {"data": chunk}, room=sid)
751
+
752
+ await self.sio.emit("payload", {"id": client_id}, room=sid)
753
+
754
+ async def _send(self, sid, payload, client_id) -> None:
755
+ """Sends a new data payload to the client using either S3 or socket.io."""
756
+ # First apply outbound processors, if any
757
+ payload = self.outbound_processor.process(payload)
758
+
759
+ metadata = {"id": client_id}
760
+
761
+ if self.s3_client is not None:
762
+ s3_key = f"server_payload_{os.getpid()}_{self.current_round}"
763
+ self.s3_client.send_to_s3(s3_key, payload)
764
+ data_size = sys.getsizeof(pickle.dumps(payload))
765
+ metadata["s3_key"] = s3_key
766
+ else:
767
+ data_size = 0
768
+
769
+ if isinstance(payload, list):
770
+ for data in payload:
771
+ _data = pickle.dumps(data)
772
+ await self._send_in_chunks(_data, sid, client_id)
773
+ data_size += sys.getsizeof(_data)
774
+
775
+ else:
776
+ _data = pickle.dumps(payload)
777
+ await self._send_in_chunks(_data, sid, client_id)
778
+ data_size = sys.getsizeof(_data)
779
+
780
+ await self.sio.emit("payload_done", metadata, room=sid)
781
+
782
+ logging.info(
783
+ "[%s] Sent %.2f MB of payload data to client #%d.",
784
+ self,
785
+ data_size / 1024**2,
786
+ client_id,
787
+ )
788
+
789
+ self.comm_overhead += data_size / 1024**2
790
+
791
+ async def _client_report_arrived(self, sid, client_id, report):
792
+ """Upon receiving a report from a client."""
793
+ self.reports[sid] = pickle.loads(report)
794
+ self.client_payload[sid] = None
795
+ self.client_chunks[sid] = []
796
+
797
+ if self.comm_simulation:
798
+ model_name = (
799
+ Config().trainer.model_name
800
+ if hasattr(Config().trainer, "model_name")
801
+ else "custom"
802
+ )
803
+ if "/" in model_name:
804
+ model_name = model_name.replace("/", "_")
805
+ checkpoint_path = Config().params["checkpoint_path"]
806
+ payload_filename = f"{checkpoint_path}/{model_name}_client_{client_id}.pth"
807
+ with open(payload_filename, "rb") as payload_file:
808
+ self.client_payload[sid] = pickle.load(payload_file)
809
+
810
+ payload_size = (
811
+ sys.getsizeof(pickle.dumps(self.client_payload[sid])) / 1024**2
812
+ )
813
+
814
+ logging.info(
815
+ "[%s] Received %.2f MB of payload data from client #%d (simulated).",
816
+ self,
817
+ payload_size,
818
+ client_id,
819
+ )
820
+
821
+ self.comm_overhead += payload_size
822
+
823
+ self.uplink_comm_time[client_id] = payload_size / (
824
+ self.uplink_bandwidth / 8
825
+ )
826
+
827
+ await self.process_client_info(client_id, sid)
828
+
829
+ async def _client_chunk_arrived(self, sid, data) -> None:
830
+ """Upon receiving a chunk of data from a client."""
831
+ self.client_chunks[sid].append(data)
832
+
833
+ async def _client_payload_arrived(self, sid, client_id):
834
+ """Upon receiving a portion of the payload from a client."""
835
+ assert len(self.client_chunks[sid]) > 0 and client_id in self.training_clients
836
+
837
+ payload = b"".join(self.client_chunks[sid])
838
+ _data = pickle.loads(payload)
839
+ self.client_chunks[sid] = []
840
+
841
+ if self.client_payload[sid] is None:
842
+ self.client_payload[sid] = _data
843
+ elif isinstance(self.client_payload[sid], list):
844
+ self.client_payload[sid].append(_data)
845
+ else:
846
+ self.client_payload[sid] = [self.client_payload[sid]]
847
+ self.client_payload[sid].append(_data)
848
+
849
+ async def _client_payload_done(self, sid, client_id, s3_key=None):
850
+ """Upon receiving all the payload from a client, either via S3 or socket.io."""
851
+ if s3_key is None:
852
+ assert self.client_payload[sid] is not None
853
+
854
+ payload_size = 0
855
+ if isinstance(self.client_payload[sid], list):
856
+ for _data in self.client_payload[sid]:
857
+ payload_size += sys.getsizeof(pickle.dumps(_data))
858
+ else:
859
+ payload_size = sys.getsizeof(pickle.dumps(self.client_payload[sid]))
860
+ else:
861
+ self.client_payload[sid] = self.s3_client.receive_from_s3(s3_key)
862
+ payload_size = sys.getsizeof(pickle.dumps(self.client_payload[sid]))
863
+
864
+ logging.info(
865
+ "[%s] Received %.2f MB of payload data from client #%d.",
866
+ self,
867
+ payload_size / 1024**2,
868
+ client_id,
869
+ )
870
+
871
+ self.comm_overhead += payload_size / 1024**2
872
+
873
+ await self.process_client_info(client_id, sid)
874
+
875
+ async def process_client_info(self, client_id, sid):
876
+ """Processes the received metadata information from a reporting client."""
877
+ # First pass through the inbound_processor(s), if any
878
+ self.client_payload[sid] = self.inbound_processor.process(
879
+ self.client_payload[sid]
880
+ )
881
+
882
+ if self.comm_simulation:
883
+ if (
884
+ hasattr(Config().clients, "compute_comm_time")
885
+ and Config().clients.compute_comm_time
886
+ ):
887
+ self.reports[sid].comm_time = (
888
+ self.downlink_comm_time[client_id]
889
+ + self.uplink_comm_time[client_id]
890
+ )
891
+ else:
892
+ self.reports[sid].comm_time = 0
893
+ else:
894
+ self.reports[sid].comm_time = time.time() - self.reports[sid].comm_time
895
+
896
+ # When the client is responding to an urgent request for an update, it will
897
+ # store its (possibly different) client ID in its report
898
+ client_id = self.reports[sid].client_id
899
+
900
+ start_time = self.training_clients[client_id]["start_time"]
901
+ finish_time = (
902
+ self.reports[sid].training_time
903
+ + self.reports[sid].processing_time
904
+ + self.reports[sid].comm_time
905
+ + start_time
906
+ )
907
+ starting_round = self.training_clients[client_id]["starting_round"]
908
+
909
+ if Config().is_central_server():
910
+ self.comm_overhead += self.reports[sid].edge_server_comm_overhead
911
+
912
+ client_info = (
913
+ finish_time, # sorted by the client's finish time
914
+ client_id, # in case two or more clients have the same finish time
915
+ {
916
+ "client_id": client_id,
917
+ "sid": sid,
918
+ "starting_round": starting_round,
919
+ "start_time": start_time,
920
+ "report": self.reports[sid],
921
+ "payload": self.client_payload[sid],
922
+ },
923
+ )
924
+
925
+ if self.asynchronous_mode and self.simulate_wall_time:
926
+ heapq.heappush(self.reported_clients, client_info)
927
+ self.current_reported_clients[client_info[2]["client_id"]] = True
928
+ del self.training_clients[client_id]
929
+
930
+ self.training_sids.remove(client_info[2]["sid"])
931
+
932
+ await self._process_clients(client_info)
933
+
934
+ # pylint: disable=unused-argument
935
+ def should_request_update(
936
+ self, client_id, start_time, finish_time, client_staleness, report
937
+ ):
938
+ """Determines if an explicit request for model update should be sent to the client."""
939
+ return client_staleness > self.staleness_bound and finish_time > self.wall_time
940
+
941
+ async def _process_clients(self, client_info):
942
+ """Determines whether it is time to process the client reports and
943
+ proceed with the aggregation process.
944
+
945
+ When in asynchronous mode, additional processing is needed to simulate
946
+ the wall clock time.
947
+ """
948
+ # In asynchronous mode with simulated wall clock time, we need to extract
949
+ # the minimum number of clients from the list of all reporting clients, and then
950
+ # proceed with report processing and replace these clients with a new set of
951
+ # selected clients
952
+ if (
953
+ self.asynchronous_mode
954
+ and self.simulate_wall_time
955
+ and len(self.current_reported_clients) >= len(self.selected_clients)
956
+ ):
957
+ # Step 1: Sanity checks to see if there are any stale clients; if so, send them
958
+ # an urgent request for model updates at the current simulated wall clock time
959
+ if self.request_update:
960
+ # We should not proceed with further processing if there are outstanding requests
961
+ # for urgent client updates
962
+ for __, client_data in self.training_clients.items():
963
+ if client_data["update_requested"]:
964
+ return
965
+
966
+ request_sent = False
967
+ for i, client_info in enumerate(self.reported_clients):
968
+ client = client_info[2]
969
+ client_staleness = self.current_round - client["starting_round"]
970
+
971
+ if (
972
+ self.should_request_update(
973
+ client_id=client["client_id"],
974
+ start_time=client["start_time"],
975
+ finish_time=client_info[0],
976
+ client_staleness=client_staleness,
977
+ report=client["report"],
978
+ )
979
+ and not client["report"].update_response
980
+ ):
981
+ # Sending an urgent request to the client for a model update at the
982
+ # currently simulated wall clock time
983
+ client_id = client["client_id"]
984
+
985
+ logging.info(
986
+ "[Server #%s] Requesting urgent model update from client #%s.",
987
+ os.getpid(),
988
+ client_id,
989
+ )
990
+
991
+ # Remove the client information from the list of reporting clients since
992
+ # this client will report again soon with another model update upon
993
+ # receiving the request from the server
994
+ del self.reported_clients[i]
995
+
996
+ self.training_clients[client_id] = {
997
+ "id": client_id,
998
+ "starting_round": client["starting_round"],
999
+ "start_time": client["start_time"],
1000
+ "update_requested": True,
1001
+ }
1002
+
1003
+ sid = client["sid"]
1004
+
1005
+ self.training_sids.append(sid)
1006
+
1007
+ await self.sio.emit(
1008
+ "request_update",
1009
+ {
1010
+ "client_id": client_id,
1011
+ "time": self.wall_time - client["start_time"],
1012
+ },
1013
+ room=sid,
1014
+ )
1015
+ request_sent = True
1016
+
1017
+ # If an urgent request was sent, we will wait until the client gets back to proceed
1018
+ # with aggregation.
1019
+ if request_sent:
1020
+ return
1021
+
1022
+ # Step 2: Processing clients in chronological order of finish times in wall clock time
1023
+ for __ in range(
1024
+ 0, min(len(self.current_reported_clients), self.minimum_clients)
1025
+ ):
1026
+ # Extract a client with the earliest finish time in wall clock time
1027
+ client_info = heapq.heappop(self.reported_clients)
1028
+ client = client_info[2]
1029
+
1030
+ # Record this client as processed among the currently reporting clients
1031
+ self.current_processed_clients[client["client_id"]] = True
1032
+
1033
+ # Update the simulated wall clock time to be the finish time of this client
1034
+ self.wall_time = client_info[0]
1035
+
1036
+ # Add the report and payload of the extracted reporting client into updates
1037
+ logging.info(
1038
+ "[Server #%s] Adding client #%s to the list of clients for aggregation.",
1039
+ os.getpid(),
1040
+ client["client_id"],
1041
+ )
1042
+
1043
+ client_staleness = self.current_round - client["starting_round"]
1044
+ self.updates.append(
1045
+ SimpleNamespace(
1046
+ client_id=client["client_id"],
1047
+ report=client["report"],
1048
+ payload=client["payload"],
1049
+ staleness=client_staleness,
1050
+ )
1051
+ )
1052
+
1053
+ # Step 3: Processing stale clients that exceed a staleness threshold
1054
+
1055
+ # If there are more clients in the list of reporting clients that violate the
1056
+ # staleness bound, the server needs to wait for these clients even when the minimum
1057
+ # number of clients has been reached, by simply advancing its simulated wall clock
1058
+ # time ahead to include the remaining clients, until no stale clients exist
1059
+ possibly_stale_clients = []
1060
+
1061
+ # Are there any reporting clients who trained on models that are too
1062
+ # 'stale', as defined by the staleness threshold? If so, we need to advance the wall
1063
+ # clock time until no stale clients exist in the future
1064
+ for __ in range(0, len(self.reported_clients)):
1065
+ # Extract a client with the earliest finish time in wall clock time
1066
+ client_info = heapq.heappop(self.reported_clients)
1067
+ heapq.heappush(possibly_stale_clients, client_info)
1068
+
1069
+ if (
1070
+ client_info[2]["starting_round"]
1071
+ < self.current_round - self.staleness_bound
1072
+ ):
1073
+ for __ in range(0, len(possibly_stale_clients)):
1074
+ stale_client_info = heapq.heappop(possibly_stale_clients)
1075
+ # Update the simulated wall clock time to be the finish time of this client
1076
+ self.wall_time = stale_client_info[0]
1077
+ client = stale_client_info[2]
1078
+
1079
+ # Add the report and payload of the extracted reporting client into updates
1080
+ logging.info(
1081
+ "[Server #%s] Adding client #%s to the list of clients for "
1082
+ "aggregation.",
1083
+ os.getpid(),
1084
+ client["client_id"],
1085
+ )
1086
+
1087
+ client_staleness = self.current_round - client["starting_round"]
1088
+ self.updates.append(
1089
+ SimpleNamespace(
1090
+ client_id=client["client_id"],
1091
+ report=client["report"],
1092
+ payload=client["payload"],
1093
+ staleness=client_staleness,
1094
+ )
1095
+ )
1096
+
1097
+ self.reported_clients = possibly_stale_clients
1098
+ logging.info(
1099
+ "[Server #%s] Aggregating %s clients in total.",
1100
+ os.getpid(),
1101
+ len(self.updates),
1102
+ )
1103
+
1104
+ await self._process_reports()
1105
+ await self.wrap_up()
1106
+ await self._select_clients()
1107
+ return
1108
+
1109
+ if not self.simulate_wall_time or not self.asynchronous_mode:
1110
+ # In both synchronous and asynchronous modes, if we are not simulating the wall clock
1111
+ # time, we need to add the client report to the list of updates so far;
1112
+ # the same applies when we are running in synchronous mode.
1113
+ client = client_info[2]
1114
+ client_staleness = self.current_round - client["starting_round"]
1115
+
1116
+ self.updates.append(
1117
+ SimpleNamespace(
1118
+ client_id=client["client_id"],
1119
+ report=client["report"],
1120
+ payload=client["payload"],
1121
+ staleness=client_staleness,
1122
+ )
1123
+ )
1124
+
1125
+ if not self.simulate_wall_time:
1126
+ # In both synchronous and asynchronous modes, if we are not simulating the wall clock
1127
+ # time, it will need to be updated to the real wall clock time
1128
+ self.wall_time = time.time()
1129
+
1130
+ if not self.asynchronous_mode and self.simulate_wall_time:
1131
+ # In synchronous mode with the wall clock time simulated, in addition to adding
1132
+ # the client report to the list of updates, we will also need to advance the wall
1133
+ # clock time to the finish time of the reporting client
1134
+ client_finish_time = client_info[0]
1135
+ self.wall_time = max(client_finish_time, self.wall_time)
1136
+
1137
+ logging.info(
1138
+ "[%s] Advancing the wall clock time to %.2f.",
1139
+ self,
1140
+ self.wall_time,
1141
+ )
1142
+
1143
+ # If all updates have been received from selected clients, the aggregation process
1144
+ # proceeds regardless of synchronous or asynchronous modes. This guarantees that
1145
+ # if asynchronous mode uses an excessively long aggregation interval, it will not
1146
+ # unnecessarily delay the aggregation process.
1147
+ if len(self.updates) >= self.clients_per_round:
1148
+ logging.info(
1149
+ "[%s] All %d client report(s) received. Processing.",
1150
+ self,
1151
+ len(self.updates),
1152
+ )
1153
+ await self._process_reports()
1154
+ await self.wrap_up()
1155
+ await self._select_clients()
1156
+
1157
+ elif (
1158
+ hasattr(Config().trainer, "max_concurrency")
1159
+ and not Config().is_central_server()
1160
+ ):
1161
+ # Clients in the current batch finish training
1162
+ # The server will select the next batch of clients to train
1163
+ if len(self.updates) >= len(self.trained_clients) or len(
1164
+ self.current_reported_clients
1165
+ ) >= len(self.trained_clients):
1166
+ await self._select_clients(for_next_batch=True)
1167
+
1168
+ async def _client_disconnected(self, sid):
1169
+ """When a client process disconnected it should be removed from its internal states."""
1170
+ for client_process_id, client in dict(self.clients).items():
1171
+ if client["sid"] == sid:
1172
+ # Obtain the client id before deleting
1173
+ client_id = self.clients[client_process_id]["client_id"]
1174
+
1175
+ # Remove the physical client from server list
1176
+ del self.clients[client_process_id]
1177
+ logging.warning(
1178
+ "[%s] Client process #%d disconnected and removed from this server, %d client processes are remaining.",
1179
+ self,
1180
+ client_process_id,
1181
+ len(self.clients),
1182
+ )
1183
+
1184
+ if len(self.clients) == 0:
1185
+ logging.warning(
1186
+ fonts.colourize(
1187
+ f"[{self}] All clients disconnected, closing the server."
1188
+ )
1189
+ )
1190
+ await self._close()
1191
+
1192
+ # Handle the logical client under different situations
1193
+ if client_id in self.training_clients:
1194
+ del self.training_clients[client_id]
1195
+
1196
+ if client_id in self.current_reported_clients:
1197
+ del self.current_reported_clients[client_id]
1198
+
1199
+ # Decide whether to continue or stop training
1200
+ if (
1201
+ hasattr(Config(), "general")
1202
+ and hasattr(Config().general, "debug")
1203
+ and not Config().general.debug
1204
+ ):
1205
+ # Recover from the failed client and proceed with training
1206
+ if (
1207
+ client_id in self.selected_clients
1208
+ and client_id in self.trained_clients
1209
+ ):
1210
+ self.trained_clients.remove(client_id)
1211
+ fail_client_index = self.selected_clients.index(client_id)
1212
+ untrained_client_index = len(self.trained_clients)
1213
+
1214
+ # Swap the current client to the beginning of the untrained clients
1215
+ self.selected_clients[fail_client_index] = (
1216
+ self.selected_clients[untrained_client_index]
1217
+ )
1218
+ self.selected_clients[untrained_client_index] = client_id
1219
+
1220
+ # Start next batch of client selection if current batch is done
1221
+ if len(self.updates) >= len(self.trained_clients) or len(
1222
+ self.current_reported_clients
1223
+ ) >= len(self.trained_clients):
1224
+ await self._select_clients(for_next_batch=True)
1225
+ else:
1226
+ # Debug is either turned on or not specified; stop the training to avoid blocking.
1227
+ logging.warning(
1228
+ fonts.colourize(
1229
+ f"[{self}] Closing the server due to a failed client."
1230
+ )
1231
+ )
1232
+ await self._close()
1233
+
1234
+ def save_to_checkpoint(self) -> None:
1235
+ """Saves a checkpoint for resuming the training session."""
1236
+ checkpoint_path = Config.params["checkpoint_path"]
1237
+
1238
+ model_name = (
1239
+ Config().trainer.model_name
1240
+ if hasattr(Config().trainer, "model_name")
1241
+ else "custom"
1242
+ )
1243
+ if "/" in model_name:
1244
+ model_name = model_name.replace("/", "_")
1245
+ filename = f"checkpoint_{model_name}_{self.current_round}.pth"
1246
+ logging.info(
1247
+ "[%s] Saving the checkpoint to %s/%s.",
1248
+ self,
1249
+ checkpoint_path,
1250
+ filename,
1251
+ )
1252
+ self.trainer.save_model(filename, checkpoint_path)
1253
+ self._save_random_states(self.current_round, checkpoint_path)
1254
+
1255
+ # Saving the current round in the server for resuming its session later on
1256
+ with open(f"{checkpoint_path}/current_round.pkl", "wb") as checkpoint_file:
1257
+ pickle.dump(self.current_round, checkpoint_file)
1258
+
1259
+ def _resume_from_checkpoint(self):
1260
+ """Resumes a training session from a previously saved checkpoint."""
1261
+ logging.info(
1262
+ "[%s] Resume a training session from a previously saved checkpoint.",
1263
+ self,
1264
+ )
1265
+
1266
+ # Loading important data in the server for resuming its session
1267
+ checkpoint_path = Config.params["checkpoint_path"]
1268
+
1269
+ with open(f"{checkpoint_path}/current_round.pkl", "rb") as checkpoint_file:
1270
+ self.current_round = pickle.load(checkpoint_file)
1271
+
1272
+ self._restore_random_states(self.current_round, checkpoint_path)
1273
+ self.resumed_session = True
1274
+
1275
+ model_name = (
1276
+ Config().trainer.model_name
1277
+ if hasattr(Config().trainer, "model_name")
1278
+ else "custom"
1279
+ )
1280
+ filename = f"checkpoint_{model_name}_{self.current_round}.pth"
1281
+ self.trainer.load_model(filename, checkpoint_path)
1282
+
1283
+ def _save_random_states(self, round_to_save, checkpoint_path):
1284
+ """Saves the random states in the server for resuming its session later on."""
1285
+ states_to_save = [
1286
+ f"numpy_prng_state_{round_to_save}",
1287
+ f"prng_state_{round_to_save}",
1288
+ ]
1289
+
1290
+ variables_to_save = [
1291
+ np.random.get_state(),
1292
+ random.getstate(),
1293
+ ]
1294
+
1295
+ for i, state in enumerate(states_to_save):
1296
+ with open(f"{checkpoint_path}/{state}.pkl", "wb") as checkpoint_file:
1297
+ pickle.dump(variables_to_save[i], checkpoint_file)
1298
+
1299
+ def _restore_random_states(self, round_to_restore, checkpoint_path):
1300
+ """Restors the numpy.random and random states from previously saved checkpoints
1301
+ for a particular round.
1302
+ """
1303
+ states_to_load = ["numpy_prng_state", "prng_state"]
1304
+ variables_to_load = {}
1305
+
1306
+ for i, state in enumerate(states_to_load):
1307
+ with open(
1308
+ f"{checkpoint_path}/{state}_{round_to_restore}.pkl", "rb"
1309
+ ) as checkpoint_file:
1310
+ variables_to_load[i] = pickle.load(checkpoint_file)
1311
+
1312
+ numpy_prng_state = variables_to_load[0]
1313
+ self.prng_state = variables_to_load[1]
1314
+
1315
+ np.random.set_state(numpy_prng_state)
1316
+ random.setstate(self.prng_state)
1317
+
1318
+ async def wrap_up(self) -> None:
1319
+ """Wraps up when each round of training is done."""
1320
+ self.save_to_checkpoint()
1321
+
1322
+ # Break the loop when the target accuracy is achieved
1323
+ target_accuracy = None
1324
+ target_perplexity = None
1325
+
1326
+ if hasattr(Config().trainer, "target_accuracy"):
1327
+ target_accuracy = Config().trainer.target_accuracy
1328
+ elif hasattr(Config().trainer, "target_perplexity"):
1329
+ target_perplexity = Config().trainer.target_perplexity
1330
+
1331
+ if target_accuracy and self.accuracy >= target_accuracy:
1332
+ logging.info("[%s] Target accuracy reached.", self)
1333
+ await self._close()
1334
+
1335
+ if target_perplexity and self.accuracy <= target_perplexity:
1336
+ logging.info("[%s] Target perplexity reached.", self)
1337
+ await self._close()
1338
+
1339
+ if self.current_round >= Config().trainer.rounds:
1340
+ logging.info("Target number of training rounds reached.")
1341
+ await self._close()
1342
+
1343
+ async def _close(self):
1344
+ """Closes the server."""
1345
+ logging.info("[%s] Training concluded.", self)
1346
+ self.trainer.save_model()
1347
+
1348
+ self.server_will_close()
1349
+ self.callback_handler.call_event("on_server_will_close", self)
1350
+
1351
+ await self._close_connections()
1352
+ os._exit(0)
1353
+
1354
+ def add_callbacks(self, callbacks):
1355
+ """Adds a list of callbacks to the server callback handler."""
1356
+ self.callback_handler.add_callbacks(callbacks)
1357
+
1358
+ def customize_server_response(self, server_response: dict, client_id) -> dict:
1359
+ """Customizes the server response with any additional information."""
1360
+ return server_response
1361
+
1362
+ def customize_server_payload(self, payload):
1363
+ """Customizes the server payload before sending to the client."""
1364
+ return payload
1365
+
1366
+ @abstractmethod
1367
+ async def _process_reports(self) -> None:
1368
+ """Processes a client report."""
1369
+
1370
+ async def periodic_task(self) -> None:
1371
+ """
1372
+ Async method called periodically in asynchronous mode.
1373
+ """
1374
+
1375
+ def clients_selected(self, selected_clients) -> None:
1376
+ """
1377
+ Method called after clients have been selected in each round."""
1378
+
1379
+ def clients_processed(self) -> None:
1380
+ """Additional work to be performed after client reports have been processed."""
1381
+
1382
+ def training_will_start(self) -> None:
1383
+ """
1384
+ Method called before selecting clients for the first round of training.
1385
+ """
1386
+ if Config().is_central_server():
1387
+ if self.disable_clients:
1388
+ logging.info("No clients are launched (server:disable_clients = true)")
1389
+ else:
1390
+ Server._start_clients(client=self.client)
1391
+
1392
+ def server_will_close(self) -> None:
1393
+ """
1394
+ Method called before closing the server.
1395
+ """