plato-learn 1.1__py3-none-any.whl
- plato/__init__.py +1 -0
- plato/algorithms/__init__.py +0 -0
- plato/algorithms/base.py +45 -0
- plato/algorithms/fedavg.py +48 -0
- plato/algorithms/fedavg_gan.py +79 -0
- plato/algorithms/fedavg_personalized.py +48 -0
- plato/algorithms/mistnet.py +52 -0
- plato/algorithms/registry.py +39 -0
- plato/algorithms/split_learning.py +89 -0
- plato/callbacks/__init__.py +0 -0
- plato/callbacks/client.py +56 -0
- plato/callbacks/handler.py +78 -0
- plato/callbacks/server.py +139 -0
- plato/callbacks/trainer.py +124 -0
- plato/client.py +67 -0
- plato/clients/__init__.py +0 -0
- plato/clients/base.py +467 -0
- plato/clients/edge.py +103 -0
- plato/clients/fedavg_personalized.py +40 -0
- plato/clients/mistnet.py +49 -0
- plato/clients/registry.py +43 -0
- plato/clients/self_supervised_learning.py +51 -0
- plato/clients/simple.py +218 -0
- plato/clients/split_learning.py +150 -0
- plato/config.py +339 -0
- plato/datasources/__init__.py +0 -0
- plato/datasources/base.py +123 -0
- plato/datasources/celeba.py +150 -0
- plato/datasources/cifar10.py +87 -0
- plato/datasources/cifar100.py +61 -0
- plato/datasources/cinic10.py +62 -0
- plato/datasources/coco.py +119 -0
- plato/datasources/datalib/__init__.py +0 -0
- plato/datasources/datalib/audio_extraction_tools.py +137 -0
- plato/datasources/datalib/data_utils.py +124 -0
- plato/datasources/datalib/flickr30kE_utils.py +336 -0
- plato/datasources/datalib/frames_extraction_tools.py +254 -0
- plato/datasources/datalib/gym_utils/__init__.py +0 -0
- plato/datasources/datalib/gym_utils/gym_trim.py +189 -0
- plato/datasources/datalib/modality_data_anntation_tools.py +163 -0
- plato/datasources/datalib/modality_extraction_base.py +59 -0
- plato/datasources/datalib/parse_datasets.py +212 -0
- plato/datasources/datalib/refer_utils/__init__.py +0 -0
- plato/datasources/datalib/refer_utils/referitgame_utils.py +237 -0
- plato/datasources/datalib/tiny_data_tools.py +81 -0
- plato/datasources/datalib/video_transform.py +79 -0
- plato/datasources/emnist.py +64 -0
- plato/datasources/fashion_mnist.py +41 -0
- plato/datasources/feature.py +24 -0
- plato/datasources/feature_dataset.py +15 -0
- plato/datasources/femnist.py +141 -0
- plato/datasources/flickr30k_entities.py +362 -0
- plato/datasources/gym.py +431 -0
- plato/datasources/huggingface.py +165 -0
- plato/datasources/kinetics.py +568 -0
- plato/datasources/mnist.py +44 -0
- plato/datasources/multimodal_base.py +328 -0
- plato/datasources/pascal_voc.py +56 -0
- plato/datasources/purchase.py +94 -0
- plato/datasources/qoenflx.py +127 -0
- plato/datasources/referitgame.py +330 -0
- plato/datasources/registry.py +119 -0
- plato/datasources/self_supervised_learning.py +98 -0
- plato/datasources/stl10.py +103 -0
- plato/datasources/texas.py +94 -0
- plato/datasources/tiny_imagenet.py +64 -0
- plato/datasources/yolov8.py +85 -0
- plato/models/__init__.py +0 -0
- plato/models/cnn_encoder.py +103 -0
- plato/models/dcgan.py +116 -0
- plato/models/general_multilayer.py +254 -0
- plato/models/huggingface.py +27 -0
- plato/models/lenet5.py +113 -0
- plato/models/multilayer.py +90 -0
- plato/models/multimodal/__init__.py +0 -0
- plato/models/multimodal/base_net.py +91 -0
- plato/models/multimodal/blending.py +142 -0
- plato/models/multimodal/fc_net.py +77 -0
- plato/models/multimodal/fusion_net.py +78 -0
- plato/models/multimodal/multimodal_module.py +152 -0
- plato/models/registry.py +99 -0
- plato/models/resnet.py +190 -0
- plato/models/torch_hub.py +19 -0
- plato/models/vgg.py +113 -0
- plato/models/vit.py +166 -0
- plato/models/yolov8.py +22 -0
- plato/processors/__init__.py +0 -0
- plato/processors/base.py +35 -0
- plato/processors/compress.py +46 -0
- plato/processors/decompress.py +48 -0
- plato/processors/feature.py +51 -0
- plato/processors/feature_additive_noise.py +48 -0
- plato/processors/feature_dequantize.py +34 -0
- plato/processors/feature_gaussian.py +17 -0
- plato/processors/feature_laplace.py +15 -0
- plato/processors/feature_quantize.py +34 -0
- plato/processors/feature_randomized_response.py +50 -0
- plato/processors/feature_unbatch.py +39 -0
- plato/processors/inbound_feature_tensors.py +39 -0
- plato/processors/model.py +55 -0
- plato/processors/model_compress.py +34 -0
- plato/processors/model_decompress.py +37 -0
- plato/processors/model_decrypt.py +41 -0
- plato/processors/model_deepcopy.py +21 -0
- plato/processors/model_dequantize.py +18 -0
- plato/processors/model_dequantize_qsgd.py +61 -0
- plato/processors/model_encrypt.py +43 -0
- plato/processors/model_quantize.py +18 -0
- plato/processors/model_quantize_qsgd.py +82 -0
- plato/processors/model_randomized_response.py +34 -0
- plato/processors/outbound_feature_ndarrays.py +38 -0
- plato/processors/pipeline.py +26 -0
- plato/processors/registry.py +124 -0
- plato/processors/structured_pruning.py +57 -0
- plato/processors/unstructured_pruning.py +73 -0
- plato/samplers/__init__.py +0 -0
- plato/samplers/all_inclusive.py +41 -0
- plato/samplers/base.py +31 -0
- plato/samplers/dirichlet.py +81 -0
- plato/samplers/distribution_noniid.py +132 -0
- plato/samplers/iid.py +53 -0
- plato/samplers/label_quantity_noniid.py +119 -0
- plato/samplers/mixed.py +44 -0
- plato/samplers/mixed_label_quantity_noniid.py +128 -0
- plato/samplers/modality_iid.py +42 -0
- plato/samplers/modality_quantity_noniid.py +56 -0
- plato/samplers/orthogonal.py +99 -0
- plato/samplers/registry.py +66 -0
- plato/samplers/sample_quantity_noniid.py +123 -0
- plato/samplers/sampler_utils.py +190 -0
- plato/servers/__init__.py +0 -0
- plato/servers/base.py +1395 -0
- plato/servers/fedavg.py +281 -0
- plato/servers/fedavg_cs.py +335 -0
- plato/servers/fedavg_gan.py +74 -0
- plato/servers/fedavg_he.py +106 -0
- plato/servers/fedavg_personalized.py +57 -0
- plato/servers/mistnet.py +67 -0
- plato/servers/registry.py +52 -0
- plato/servers/split_learning.py +109 -0
- plato/trainers/__init__.py +0 -0
- plato/trainers/base.py +99 -0
- plato/trainers/basic.py +649 -0
- plato/trainers/diff_privacy.py +178 -0
- plato/trainers/gan.py +330 -0
- plato/trainers/huggingface.py +173 -0
- plato/trainers/loss_criterion.py +70 -0
- plato/trainers/lr_schedulers.py +252 -0
- plato/trainers/optimizers.py +53 -0
- plato/trainers/pascal_voc.py +80 -0
- plato/trainers/registry.py +44 -0
- plato/trainers/self_supervised_learning.py +302 -0
- plato/trainers/split_learning.py +305 -0
- plato/trainers/tracking.py +96 -0
- plato/trainers/yolov8.py +41 -0
- plato/utils/__init__.py +0 -0
- plato/utils/count_parameters.py +30 -0
- plato/utils/csv_processor.py +26 -0
- plato/utils/data_loaders.py +148 -0
- plato/utils/decorators.py +24 -0
- plato/utils/fonts.py +23 -0
- plato/utils/homo_enc.py +187 -0
- plato/utils/reinforcement_learning/__init__.py +0 -0
- plato/utils/reinforcement_learning/policies/__init__.py +0 -0
- plato/utils/reinforcement_learning/policies/base.py +161 -0
- plato/utils/reinforcement_learning/policies/ddpg.py +75 -0
- plato/utils/reinforcement_learning/policies/registry.py +32 -0
- plato/utils/reinforcement_learning/policies/sac.py +343 -0
- plato/utils/reinforcement_learning/policies/td3.py +485 -0
- plato/utils/reinforcement_learning/rl_agent.py +142 -0
- plato/utils/reinforcement_learning/rl_server.py +113 -0
- plato/utils/rl_env.py +154 -0
- plato/utils/s3.py +141 -0
- plato/utils/trainer_utils.py +21 -0
- plato/utils/unary_encoding.py +47 -0
- plato_learn-1.1.dist-info/METADATA +35 -0
- plato_learn-1.1.dist-info/RECORD +179 -0
- plato_learn-1.1.dist-info/WHEEL +4 -0
- plato_learn-1.1.dist-info/licenses/LICENSE +201 -0
plato/servers/base.py
ADDED
@@ -0,0 +1,1395 @@
"""
The base class for federated learning servers.
"""

import asyncio
import heapq
import logging
import multiprocessing as mp
import os
import pickle
import random
import sys
import time
from abc import abstractmethod
from types import SimpleNamespace

import numpy as np
import socketio
from aiohttp import web

from plato.callbacks.handler import CallbackHandler
from plato.callbacks.server import LogProgressCallback
from plato.client import run
from plato.config import Config
from plato.utils import fonts, s3


# pylint: disable=unused-argument, protected-access
class ServerEvents(socketio.AsyncNamespace):
    """A custom namespace for socketio.AsyncServer."""

    def __init__(self, namespace, plato_server):
        super().__init__(namespace)
        self.plato_server = plato_server

    async def on_connect(self, sid, environ):
        """Upon a new connection from a client."""
        logging.info("[Server #%d] A new client just connected.", os.getpid())

    async def on_disconnect(self, sid, reason=None):
        """Upon a disconnection event."""
        logging.info("[Server #%d] An existing client just disconnected.", os.getpid())
        await self.plato_server._client_disconnected(sid)

    async def on_client_alive(self, sid, data):
        """A new client arrived."""
        await self.plato_server.register_client(sid, data["pid"], data["id"])

    async def on_client_report(self, sid, data):
        """An existing client sends a new report from local training."""
        await self.plato_server._client_report_arrived(sid, data["id"], data["report"])

    async def on_chunk(self, sid, data):
        """A chunk of data from a client arrived."""
        await self.plato_server._client_chunk_arrived(sid, data["data"])

    async def on_client_payload(self, sid, data):
        """An existing client sends a new payload from local training."""
        await self.plato_server._client_payload_arrived(sid, data["id"])

    async def on_client_payload_done(self, sid, data):
        """An existing client finished sending its payloads from local training."""
        if "s3_key" in data:
            await self.plato_server._client_payload_done(
                sid, data["id"], s3_key=data["s3_key"]
            )
        else:
            await self.plato_server._client_payload_done(sid, data["id"])


class Server:
    """The base class for federated learning servers."""

    def __init__(self, callbacks=None):
        self.sio = None
        self.client = None
        self.clients = {}
        self.total_clients = 0
        # The client ids are stored for client selection
        self.clients_pool = []
        self.clients_per_round = 0
        self.selected_clients = None
        self.selected_client_id = 0
        self.selected_sids = []
        self.current_round = 0
        self.resumed_session = False
        self.algorithm = None
        self.trainer = None
        self.accuracy = 0
        self.accuracy_std = 0
        self.reports = {}
        self.updates = []
        self.client_payload = {}
        self.client_chunks = {}
        self.s3_client = None
        self.outbound_processor = None
        self.inbound_processor = None
        self.comm_simulation = (
            Config().clients.comm_simulation
            if hasattr(Config().clients, "comm_simulation")
            else True
        )

        # Starting from the default server callback class, add all supplied server callbacks
        self.callbacks = [LogProgressCallback]
        if callbacks is not None:
            self.callbacks.extend(callbacks)
        self.callback_handler = CallbackHandler(self.callbacks)

        # Accumulated communication overhead (MB) throughout the FL training session
        self.comm_overhead = 0

        # Downlink and uplink bandwidth (Mbps)
        # for computing communication time in communication simulation mode
        self.downlink_bandwidth = (
            Config().server.downlink_bandwidth
            if hasattr(Config().server, "downlink_bandwidth")
            else 100
        )
        self.uplink_bandwidth = (
            Config().server.uplink_bandwidth
            if hasattr(Config().server, "uplink_bandwidth")
            else 100
        )
        if Config().is_edge_server():
            if hasattr(Config().server, "edge_downlink_bandwidth"):
                self.downlink_bandwidth = Config().server.edge_downlink_bandwidth
            if hasattr(Config().server, "edge_uplink_bandwidth"):
                self.uplink_bandwidth = Config().server.edge_uplink_bandwidth

        # Use dictionaries to record downlink/uplink communication time of each client
        self.downlink_comm_time = {}
        self.uplink_comm_time = {}

        # States that need to be maintained for asynchronous FL

        # sids that are currently in use
        self.training_sids = []

        # Clients whose new reports were received but not yet processed
        self.reported_clients = []

        # Clients who are still training since the last round of aggregation
        self.training_clients = {}

        # The wall clock time that is simulated to accommodate the fact that
        # clients can only run a batch at a time, controlled by `max_concurrency`
        self.initial_wall_time = time.time()
        self.wall_time = time.time()

        # The wall clock time when a communication round starts
        self.round_start_wall_time = self.wall_time

        # When simulating the wall clock time, the server needs to remember the
        # set of reporting clients received since the previous round of aggregation
        self.current_reported_clients = {}
        self.current_processed_clients = {}
        self.prng_state = random.getstate()

        self.ping_interval = 3600
        self.ping_timeout = 3600
        self.asynchronous_mode = False
        self.periodic_interval = 5
        self.staleness_bound = 1000
        self.minimum_clients = 1
        self.simulate_wall_time = False
        self.request_update = False
        self.disable_clients = False

        # With max_concurrency specified, selected clients run batch by batch.
        # The number of clients in a batch on an available device is the same as max_concurrency.
        # This list contains the ids of selected clients that have run in the current round.
        if hasattr(Config().trainer, "max_concurrency"):
            self.trained_clients = []

    def __repr__(self):
        return f"Server #{os.getpid()}"

    def __str__(self):
        return f"Server #{os.getpid()}"

    def configure(self) -> None:
        """Initializes configuration settings based on the configuration file."""
        logging.info("[%s] Configuring the server...", self)

        # Ping interval and timeout setup for the server
        self.ping_interval = (
            Config().server.ping_interval
            if hasattr(Config().server, "ping_interval")
            else 3600
        )
        self.ping_timeout = (
            Config().server.ping_timeout
            if hasattr(Config().server, "ping_timeout")
            else 3600
        )

        # Are we operating in asynchronous mode?
        self.asynchronous_mode = (
            hasattr(Config().server, "synchronous") and not Config().server.synchronous
        )

        # What is the periodic interval for running our periodic task in asynchronous mode?
        self.periodic_interval = (
            Config().server.periodic_interval
            if hasattr(Config().server, "periodic_interval")
            else 5
        )

        # The staleness threshold is used to determine whether a training client should be
        # considered 'stale', i.e., whether its starting round is too far behind the current
        # round on the server
        self.staleness_bound = (
            Config().server.staleness_bound
            if hasattr(Config().server, "staleness_bound")
            else 0
        )

        if not Config().is_central_server():
            # What is the minimum number of clients that must have reported before aggregation
            # takes place?
            self.minimum_clients = (
                Config().server.minimum_clients_aggregated
                if hasattr(Config().server, "minimum_clients_aggregated")
                else 1
            )
        else:
            # In cross-silo FL, what is the minimum number of edge servers that must have reported
            # before the central server conducts aggregation?
            self.minimum_clients = (
                Config().server.minimum_edges_aggregated
                if hasattr(Config().server, "minimum_edges_aggregated")
                else Config().algorithm.total_silos
            )

        # Are we simulating the wall clock time on the server? This is useful when the clients
        # are training in batches due to a lack of memory on the GPUs
        self.simulate_wall_time = (
            hasattr(Config().server, "simulate_wall_time")
            and Config().server.simulate_wall_time
        )

        # Do we wish to send urgent requests for model updates to the slow clients?
        self.request_update = (
            hasattr(Config().server, "request_update")
            and Config().server.request_update
        )

        # Are we disabling all clients and preventing them from running?
        self.disable_clients = (
            hasattr(Config().server, "disable_clients")
            and Config().server.disable_clients
        )

        # Compute the per-client uplink bandwidth
        if self.asynchronous_mode:
            self.uplink_bandwidth = self.uplink_bandwidth / self.minimum_clients
        else:
            self.uplink_bandwidth = self.uplink_bandwidth / self.clients_per_round

    def run(self, client=None, edge_server=None, edge_client=None, trainer=None):
        """Starts a run loop for the server."""
        self.client = client
        self.configure()

        if Config().args.resume:
            self._resume_from_checkpoint()

        if Config().is_central_server():
            # Start the edge servers as clients of the central server first
            # Once all edge servers are live, clients will be initialized in the
            # training_will_start() event call of the central server
            Server._start_clients(
                as_server=True,
                client=self.client,
                edge_server=edge_server,
                edge_client=edge_client,
                trainer=trainer,
            )

            asyncio.get_event_loop().create_task(self._periodic(self.periodic_interval))
            if hasattr(Config().server, "random_seed"):
                seed = Config().server.random_seed
                logging.info("Setting the random seed for selecting clients: %s", seed)
                random.seed(seed)
                self.prng_state = random.getstate()
            self.start()

        else:
            if self.disable_clients:
                logging.info("No clients are launched (server:disable_clients = true)")
            else:
                Server._start_clients(client=self.client)

            asyncio.get_event_loop().create_task(self._periodic(self.periodic_interval))

            if hasattr(Config().server, "random_seed"):
                seed = Config().server.random_seed
                logging.info("Setting the random seed for selecting clients: %s", seed)
                random.seed(seed)
                self.prng_state = random.getstate()

            self.start()

    def start(self, port=Config().server.port):
        """Starts running the socket.io server."""
        logging.info(
            "Starting a server at address %s and port %s.",
            Config().server.address,
            port,
        )

        self.sio = socketio.AsyncServer(
            ping_interval=self.ping_interval,
            max_http_buffer_size=2**31,
            ping_timeout=self.ping_timeout,
        )
        self.sio.register_namespace(ServerEvents(namespace="/", plato_server=self))

        if hasattr(Config().server, "s3_endpoint_url"):
            self.s3_client = s3.S3()

        app = web.Application()
        self.sio.attach(app)
        web.run_app(
            app,
            host=Config().server.address,
            port=port,
            loop=asyncio.get_event_loop(),
        )

    async def register_client(self, sid, client_process_id, client_id):
        """Adds a newly arrived client to the list of clients."""
        self.clients[client_process_id] = {
            "sid": sid,
            "client_id": client_id,
        }
        logging.info("[%s] New client with id #%d arrived.", self, client_id)
        logging.info("[%s] Client process #%d registered.", self, client_process_id)

        if (
            hasattr(Config().trainer, "max_concurrency")
            and not Config().is_central_server()
        ):
            required_launched_clients = min(
                Config().trainer.max_concurrency * max(1, Config().gpu_count()),
                self.clients_per_round,
            )
        else:
            required_launched_clients = self.clients_per_round

        if (self.current_round == 0 or self.resumed_session) and len(
            self.clients
        ) >= required_launched_clients:
            self.resumed_session = False

            self.training_will_start()
            self.callback_handler.call_event("on_training_will_start", self)

            await self._select_clients()

    @staticmethod
    def _start_clients(
        client=None,
        as_server=False,
        edge_server=None,
        edge_client=None,
        trainer=None,
    ):
        """Starts all the clients as separate processes."""
        starting_id = 1

        # We only need to launch the number of clients necessary for concurrent training.
        # If `max_concurrency` in `trainer` is specified, the limit is `max_concurrency`
        # multiplied by the number of available devices
        # (and by the number of edge servers in cross-silo training)
        if hasattr(Config().trainer, "max_concurrency"):
            if Config().is_central_server():
                client_processes = min(
                    Config().trainer.max_concurrency
                    * max(1, Config().gpu_count())
                    * Config().algorithm.total_silos,
                    Config().clients.per_round,
                )
            else:
                client_processes = min(
                    Config().trainer.max_concurrency * max(1, Config().gpu_count()),
                    Config().clients.per_round,
                )
        # Otherwise, the limit is the same as the number of clients per round
        else:
            client_processes = Config().clients.per_round

        if as_server:
            total_processes = Config().algorithm.total_silos
            starting_id += Config().clients.total_clients
        else:
            total_processes = client_processes

        if mp.get_start_method(allow_none=True) != "spawn":
            mp.set_start_method("spawn", force=True)

        for client_id in range(starting_id, total_processes + starting_id):
            if as_server:
                port = int(Config().server.port) + client_id
                logging.info(
                    "Starting client #%d as an edge server on port %s.",
                    client_id,
                    port,
                )
                proc = mp.Process(
                    target=run,
                    args=(
                        client_id,
                        port,
                        client,
                        edge_server,
                        edge_client,
                        trainer,
                    ),
                )
                proc.start()
            else:
                logging.info("Starting client #%d's process.", client_id)
                proc = mp.Process(
                    target=run, args=(client_id, None, client, None, None, None)
                )
                proc.start()

    async def _close_connections(self):
        """Closes all socket.io connections after training completes."""
        for client_id, client in dict(self.clients).items():
            logging.info("Closing the connection to client #%d.", client_id)
            await self.sio.emit("disconnect", room=client["sid"])

    async def _select_clients(self, for_next_batch=False):
        """Selects a subset of the clients and sends messages to them to start training."""
        if not for_next_batch:
            self.updates = []
            self.current_round += 1
            self.round_start_wall_time = self.wall_time

            if hasattr(Config().trainer, "max_concurrency"):
                self.trained_clients = []

            logging.info(
                fonts.colourize(
                    f"\n[{self}] Starting round {self.current_round}/{Config().trainer.rounds}."
                )
            )

            if Config().is_central_server():
                # In cross-silo FL, the central server selects from the pool of edge servers
                self.clients_pool = list(self.clients)

            elif not Config().is_edge_server():
                self.clients_pool = list(range(1, 1 + self.total_clients))

            # In asynchronous FL, avoid selecting new clients to replace those that are still
            # training at this time

            # When simulating the wall clock time, if len(self.reported_clients) is 0, the
            # server has aggregated all reporting clients already
            if (
                self.asynchronous_mode
                and self.selected_clients is not None
                and len(self.reported_clients) > 0
                and len(self.reported_clients) < self.clients_per_round
            ):
                # If self.selected_clients is None, it implies that it is the first iteration;
                # if len(self.reported_clients) == self.clients_per_round, it implies that
                # all selected clients have already reported.

                # Except for these two cases, we need to exclude the clients who are still
                # training.
                training_client_ids = [
                    self.training_clients[client_id]["id"]
                    for client_id in self.training_clients
                ]

                # If the server is simulating the wall clock time, some of the clients who
                # reported may not have been aggregated; they should be excluded from the next
                # round of client selection
                reporting_client_ids = [
                    client[2]["client_id"] for client in self.reported_clients
                ]

                selectable_clients = [
                    client
                    for client in self.clients_pool
                    if client not in training_client_ids
                    and client not in reporting_client_ids
                ]

                if self.simulate_wall_time:
                    self.selected_clients = self.choose_clients(
                        selectable_clients, len(self.current_processed_clients)
                    )
                else:
                    self.selected_clients = self.choose_clients(
                        selectable_clients, len(self.reported_clients)
                    )
            else:
                self.selected_clients = self.choose_clients(
                    self.clients_pool, self.clients_per_round
                )

            self.current_reported_clients = {}
            self.current_processed_clients = {}

            # There is no need to clear the list of reporting clients if we are
            # simulating the wall clock time on the server. This is because
            # when wall clock time is simulated, the server needs to wait for
            # all the clients to report before selecting a subset of clients for
            # replacement, and all remaining reporting clients will be processed
            # in the next round
            if not self.simulate_wall_time:
                self.reported_clients = []

        if len(self.selected_clients) > 0:
            self.selected_sids = []

            # If max_concurrency is specified, run selected clients batch by batch,
            # and the number of clients in each batch (on each GPU, if multiple GPUs are available)
            # is equal to (or possibly smaller than, for the last batch) max_concurrency
            if (
                hasattr(Config().trainer, "max_concurrency")
                and not Config().is_central_server()
            ):
                selected_clients = []
                if Config().gpu_count() > 1:
                    untrained_clients = list(
                        set(self.selected_clients).difference(self.trained_clients)
                    )
                    available_gpus = Config().gpu_count()
                    for cuda_id in range(available_gpus):
                        for client_id in untrained_clients:
                            if client_id % available_gpus == cuda_id:
                                selected_clients.append(client_id)
                            if len(selected_clients) >= min(
                                len(self.clients),
                                (cuda_id + 1) * Config().trainer.max_concurrency,
                                self.clients_per_round,
                            ):
                                break
                        # If there are not enough live clients, stop the selection
                        if len(selected_clients) >= len(self.clients):
                            break
                else:
                    selected_clients = self.selected_clients[
                        len(self.trained_clients) : min(
                            len(self.trained_clients) + len(self.clients),
                            len(self.selected_clients),
                        )
                    ]

                self.trained_clients += selected_clients

            else:
                selected_clients = self.selected_clients

            for selected_client_id in selected_clients:
                self.selected_client_id = selected_client_id

                if Config().is_central_server():
                    client_process_id = selected_client_id
                else:
                    client_processes = [client for client in self.clients]

                    # Find a client process that is currently not training
                    # or selected in this round
                    for process_id in client_processes:
                        current_sid = self.clients[process_id]["sid"]
                        if not (
                            current_sid in self.training_sids
                            or current_sid in self.selected_sids
                        ):
                            client_process_id = process_id
                            break

                sid = self.clients[client_process_id]["sid"]

                # Track the selected client process
                self.training_sids.append(sid)
                self.selected_sids.append(sid)

                # Assign the client id to the client process
                self.clients[client_process_id]["client_id"] = self.selected_client_id

                self.training_clients[self.selected_client_id] = {
                    "id": self.selected_client_id,
                    "starting_round": self.current_round,
                    "start_time": self.round_start_wall_time,
                    "update_requested": False,
                }

                logging.info(
                    "[%s] Selecting client #%d for training.",
                    self,
                    self.selected_client_id,
                )

                server_response = {
                    "id": self.selected_client_id,
                    "current_round": self.current_round,
                }
                server_response = self.customize_server_response(
                    server_response, client_id=self.selected_client_id
                )

                payload = self.algorithm.extract_weights()
                payload = self.customize_server_payload(payload)

                if self.comm_simulation:
                    logging.info(
                        "[%s] Sending the current model to client #%d (simulated).",
                        self,
                        self.selected_client_id,
                    )

                    # First apply outbound processors, if any
                    payload = self.outbound_processor.process(payload)

                    model_name = (
                        Config().trainer.model_name
                        if hasattr(Config().trainer, "model_name")
                        else "custom"
                    )
                    if "/" in model_name:
                        model_name = model_name.replace("/", "_")

                    checkpoint_path = Config().params["checkpoint_path"]

                    payload_filename = (
                        f"{checkpoint_path}/{model_name}_{self.selected_client_id}.pth"
                    )

                    with open(payload_filename, "wb") as payload_file:
                        pickle.dump(payload, payload_file)

                    server_response["payload_filename"] = payload_filename

                    payload_size = sys.getsizeof(pickle.dumps(payload)) / 1024**2

                    logging.info(
                        "[%s] Sending %.2f MB of payload data to client #%d (simulated).",
                        self,
                        payload_size,
                        self.selected_client_id,
                    )

                    self.comm_overhead += payload_size

                    # Compute the communication time to transfer the current global model to client
                    self.downlink_comm_time[self.selected_client_id] = payload_size / (
                        (self.downlink_bandwidth / 8) / len(self.selected_clients)
                    )

                # Send the server response as metadata to the clients (payload to follow)
                await self.sio.emit(
                    "payload_to_arrive", {"response": server_response}, room=sid
                )

                if not self.comm_simulation:
                    # Send the server payload to the client
                    logging.info(
                        "[%s] Sending the current model to client #%d.",
                        self,
                        selected_client_id,
                    )

                    await self._send(sid, payload, selected_client_id)

            self.clients_selected(self.selected_clients)
            self.callback_handler.call_event(
                "on_clients_selected", self, self.selected_clients
            )

    def choose_clients(self, clients_pool, clients_count):
        """Chooses a subset of the clients to participate in each round."""
        assert clients_count <= len(clients_pool)
        random.setstate(self.prng_state)

        # Select clients randomly
        selected_clients = random.sample(clients_pool, clients_count)

        self.prng_state = random.getstate()
        logging.info("[%s] Selected clients: %s", self, selected_clients)
        return selected_clients

    async def _periodic(self, periodic_interval):
        """Runs _periodic_task() periodically on the server. The time interval between
        its execution is defined in 'server:periodic_interval'.
        """
        while True:
            await self._periodic_task()
            await asyncio.sleep(periodic_interval)

    async def _periodic_task(self):
        """A periodic task that is executed from time to time, determined by
        'server:periodic_interval' with a default value of 5 seconds, in the configuration.
        """
        # Call the async function that defines a customized periodic task, if any
        await self.periodic_task()

        # If we are operating in asynchronous mode, aggregate the model updates received so far.
        if self.asynchronous_mode and not self.simulate_wall_time:
            # Are there any training clients currently working on models that are too
            # `stale`, as defined by the staleness threshold?
            for __, client_data in self.training_clients.items():
                # The client is still working at an early round, early enough to stop the
                # aggregation process as determined by 'staleness'
                client_staleness = self.current_round - client_data["starting_round"]
                if client_staleness > self.staleness_bound:
                    logging.info(
                        "[%s] Client %s is still working at round %s, which is "
                        "beyond the staleness bound %s compared to the current round %s. "
                        "Nothing to process.",
                        self,
                        client_data["id"],
                        client_data["starting_round"],
                        self.staleness_bound,
                        self.current_round,
                    )

                    return

            if len(self.updates) >= self.minimum_clients:
                logging.info(
                    "[%s] %d client report(s) received in asynchronous mode. Processing.",
                    self,
                    len(self.updates),
                )
                await self._process_reports()
                await self.wrap_up()
                await self._select_clients()
            else:
                logging.info(
                    "[%s] A sufficient number of client reports has not been received. "
                    "Nothing to process.",
                    self,
                )

    async def _send_in_chunks(self, data, sid, client_id) -> None:
        """Sends a bytes object in fixed-sized chunks to the client."""
        step = 1024**2
        chunks = [data[i : i + step] for i in range(0, len(data), step)]

        for chunk in chunks:
            await self.sio.emit("chunk", {"data": chunk}, room=sid)

        await self.sio.emit("payload", {"id": client_id}, room=sid)

    async def _send(self, sid, payload, client_id) -> None:
        """Sends a new data payload to the client using either S3 or socket.io."""
        # First apply outbound processors, if any
        payload = self.outbound_processor.process(payload)

        metadata = {"id": client_id}

        if self.s3_client is not None:
            s3_key = f"server_payload_{os.getpid()}_{self.current_round}"
            self.s3_client.send_to_s3(s3_key, payload)
            data_size = sys.getsizeof(pickle.dumps(payload))
            metadata["s3_key"] = s3_key
        else:
            data_size = 0

            if isinstance(payload, list):
                for data in payload:
                    _data = pickle.dumps(data)
                    await self._send_in_chunks(_data, sid, client_id)
                    data_size += sys.getsizeof(_data)

            else:
                _data = pickle.dumps(payload)
                await self._send_in_chunks(_data, sid, client_id)
                data_size = sys.getsizeof(_data)

        await self.sio.emit("payload_done", metadata, room=sid)

        logging.info(
            "[%s] Sent %.2f MB of payload data to client #%d.",
            self,
            data_size / 1024**2,
            client_id,
        )

        self.comm_overhead += data_size / 1024**2

    async def _client_report_arrived(self, sid, client_id, report):
        """Upon receiving a report from a client."""
        self.reports[sid] = pickle.loads(report)
        self.client_payload[sid] = None
        self.client_chunks[sid] = []

        if self.comm_simulation:
            model_name = (
                Config().trainer.model_name
                if hasattr(Config().trainer, "model_name")
                else "custom"
            )
            if "/" in model_name:
                model_name = model_name.replace("/", "_")
            checkpoint_path = Config().params["checkpoint_path"]
            payload_filename = f"{checkpoint_path}/{model_name}_client_{client_id}.pth"
            with open(payload_filename, "rb") as payload_file:
                self.client_payload[sid] = pickle.load(payload_file)

            payload_size = (
                sys.getsizeof(pickle.dumps(self.client_payload[sid])) / 1024**2
            )

            logging.info(
                "[%s] Received %.2f MB of payload data from client #%d (simulated).",
                self,
                payload_size,
                client_id,
            )

            self.comm_overhead += payload_size

            self.uplink_comm_time[client_id] = payload_size / (
                self.uplink_bandwidth / 8
            )

            await self.process_client_info(client_id, sid)

    async def _client_chunk_arrived(self, sid, data) -> None:
        """Upon receiving a chunk of data from a client."""
        self.client_chunks[sid].append(data)

    async def _client_payload_arrived(self, sid, client_id):
        """Upon receiving a portion of the payload from a client."""
        assert len(self.client_chunks[sid]) > 0 and client_id in self.training_clients

        payload = b"".join(self.client_chunks[sid])
        _data = pickle.loads(payload)
        self.client_chunks[sid] = []

        if self.client_payload[sid] is None:
            self.client_payload[sid] = _data
        elif isinstance(self.client_payload[sid], list):
            self.client_payload[sid].append(_data)
        else:
            self.client_payload[sid] = [self.client_payload[sid]]
            self.client_payload[sid].append(_data)

    async def _client_payload_done(self, sid, client_id, s3_key=None):
        """Upon receiving all the payload from a client, either via S3 or socket.io."""
        if s3_key is None:
            assert self.client_payload[sid] is not None

            payload_size = 0
            if isinstance(self.client_payload[sid], list):
                for _data in self.client_payload[sid]:
                    payload_size += sys.getsizeof(pickle.dumps(_data))
            else:
                payload_size = sys.getsizeof(pickle.dumps(self.client_payload[sid]))
        else:
            self.client_payload[sid] = self.s3_client.receive_from_s3(s3_key)
            payload_size = sys.getsizeof(pickle.dumps(self.client_payload[sid]))

        logging.info(
            "[%s] Received %.2f MB of payload data from client #%d.",
            self,
            payload_size / 1024**2,
            client_id,
        )

        self.comm_overhead += payload_size / 1024**2

        await self.process_client_info(client_id, sid)

    async def process_client_info(self, client_id, sid):
        """Processes the received metadata information from a reporting client."""
        # First pass through the inbound_processor(s), if any
        self.client_payload[sid] = self.inbound_processor.process(
            self.client_payload[sid]
        )

        if self.comm_simulation:
            if (
                hasattr(Config().clients, "compute_comm_time")
                and Config().clients.compute_comm_time
            ):
                self.reports[sid].comm_time = (
                    self.downlink_comm_time[client_id]
                    + self.uplink_comm_time[client_id]
                )
            else:
                self.reports[sid].comm_time = 0
        else:
            self.reports[sid].comm_time = time.time() - self.reports[sid].comm_time

        # When the client is responding to an urgent request for an update, it will
        # store its (possibly different) client ID in its report
        client_id = self.reports[sid].client_id

        start_time = self.training_clients[client_id]["start_time"]
        finish_time = (
            self.reports[sid].training_time
            + self.reports[sid].processing_time
            + self.reports[sid].comm_time
            + start_time
        )
        starting_round = self.training_clients[client_id]["starting_round"]

        if Config().is_central_server():
            self.comm_overhead += self.reports[sid].edge_server_comm_overhead

        client_info = (
            finish_time,  # sorted by the client's finish time
            client_id,  # in case two or more clients have the same finish time
            {
                "client_id": client_id,
                "sid": sid,
                "starting_round": starting_round,
                "start_time": start_time,
                "report": self.reports[sid],
                "payload": self.client_payload[sid],
            },
        )

        if self.asynchronous_mode and self.simulate_wall_time:
            heapq.heappush(self.reported_clients, client_info)
            self.current_reported_clients[client_info[2]["client_id"]] = True
            del self.training_clients[client_id]

        self.training_sids.remove(client_info[2]["sid"])

        await self._process_clients(client_info)

    # pylint: disable=unused-argument
    def should_request_update(
        self, client_id, start_time, finish_time, client_staleness, report
    ):
        """Determines if an explicit request for model update should be sent to the client."""
        return client_staleness > self.staleness_bound and finish_time > self.wall_time

    async def _process_clients(self, client_info):
        """Determines whether it is time to process the client reports and
        proceed with the aggregation process.

        When in asynchronous mode, additional processing is needed to simulate
        the wall clock time.
        """
        # In asynchronous mode with simulated wall clock time, we need to extract
        # the minimum number of clients from the list of all reporting clients, and then
        # proceed with report processing and replace these clients with a new set of
        # selected clients
        if (
            self.asynchronous_mode
            and self.simulate_wall_time
            and len(self.current_reported_clients) >= len(self.selected_clients)
        ):
            # Step 1: Sanity checks to see if there are any stale clients; if so, send them
            # an urgent request for model updates at the current simulated wall clock time
            if self.request_update:
                # We should not proceed with further processing if there are outstanding requests
                # for urgent client updates
                for __, client_data in self.training_clients.items():
                    if client_data["update_requested"]:
                        return

                request_sent = False
                for i, client_info in enumerate(self.reported_clients):
                    client = client_info[2]
                    client_staleness = self.current_round - client["starting_round"]

                    if (
                        self.should_request_update(
                            client_id=client["client_id"],
                            start_time=client["start_time"],
                            finish_time=client_info[0],
                            client_staleness=client_staleness,
                            report=client["report"],
                        )
                        and not client["report"].update_response
                    ):
                        # Sending an urgent request to the client for a model update at the
                        # currently simulated wall clock time
                        client_id = client["client_id"]

                        logging.info(
                            "[Server #%s] Requesting urgent model update from client #%s.",
                            os.getpid(),
                            client_id,
                        )

                        # Remove the client information from the list of reporting clients since
                        # this client will report again soon with another model update upon
                        # receiving the request from the server
                        del self.reported_clients[i]

                        self.training_clients[client_id] = {
                            "id": client_id,
                            "starting_round": client["starting_round"],
                            "start_time": client["start_time"],
                            "update_requested": True,
                        }

                        sid = client["sid"]

                        self.training_sids.append(sid)

                        await self.sio.emit(
                            "request_update",
                            {
                                "client_id": client_id,
                                "time": self.wall_time - client["start_time"],
                            },
                            room=sid,
                        )
                        request_sent = True

                # If an urgent request was sent, we will wait until the client gets back to proceed
                # with aggregation.
                if request_sent:
                    return

            # Step 2: Processing clients in chronological order of finish times in wall clock time
            for __ in range(
                0, min(len(self.current_reported_clients), self.minimum_clients)
            ):
                # Extract a client with the earliest finish time in wall clock time
                client_info = heapq.heappop(self.reported_clients)
                client = client_info[2]

                # Removing from the list of current reporting clients as well, if needed
                self.current_processed_clients[client["client_id"]] = True

                # Update the simulated wall clock time to be the finish time of this client
                self.wall_time = client_info[0]

                # Add the report and payload of the extracted reporting client into updates
                logging.info(
                    "[Server #%s] Adding client #%s to the list of clients for aggregation.",
                    os.getpid(),
                    client["client_id"],
                )

                client_staleness = self.current_round - client["starting_round"]
                self.updates.append(
                    SimpleNamespace(
                        client_id=client["client_id"],
                        report=client["report"],
                        payload=client["payload"],
                        staleness=client_staleness,
                    )
                )

            # Step 3: Processing stale clients that exceed a staleness threshold

            # If there are more clients in the list of reporting clients that violate the
            # staleness bound, the server needs to wait for these clients even when the minimum
            # number of clients has been reached, by simply advancing its simulated wall clock
            # time ahead to include the remaining clients, until no stale clients exist
            possibly_stale_clients = []

            # Are there any reporting clients currently training on models that are too
            # `stale`, as defined by the staleness threshold? If so, we need to advance the wall
            # clock time until no stale clients exist in the future
            for __ in range(0, len(self.reported_clients)):
                # Extract a client with the earliest finish time in wall clock time
                client_info = heapq.heappop(self.reported_clients)
                heapq.heappush(possibly_stale_clients, client_info)

                if (
                    client_info[2]["starting_round"]
                    < self.current_round - self.staleness_bound
                ):
                    for __ in range(0, len(possibly_stale_clients)):
                        stale_client_info = heapq.heappop(possibly_stale_clients)
                        # Update the simulated wall clock time to be the finish time of this client
                        self.wall_time = stale_client_info[0]
                        client = stale_client_info[2]

                        # Add the report and payload of the extracted reporting client into updates
                        logging.info(
                            "[Server #%s] Adding client #%s to the list of clients for "
                            "aggregation.",
                            os.getpid(),
                            client["client_id"],
                        )

                        client_staleness = self.current_round - client["starting_round"]
                        self.updates.append(
                            SimpleNamespace(
                                client_id=client["client_id"],
                                report=client["report"],
                                payload=client["payload"],
                                staleness=client_staleness,
                            )
                        )

            self.reported_clients = possibly_stale_clients
            logging.info(
                "[Server #%s] Aggregating %s clients in total.",
                os.getpid(),
                len(self.updates),
            )

            await self._process_reports()
            await self.wrap_up()
            await self._select_clients()
            return

        if not self.simulate_wall_time or not self.asynchronous_mode:
            # In both synchronous and asynchronous modes, if we are not simulating the wall clock
            # time, we need to add the client report to the list of updates so far;
            # the same applies when we are running in synchronous mode.
            client = client_info[2]
            client_staleness = self.current_round - client["starting_round"]

            self.updates.append(
                SimpleNamespace(
                    client_id=client["client_id"],
                    report=client["report"],
                    payload=client["payload"],
                    staleness=client_staleness,
                )
            )

        if not self.simulate_wall_time:
            # In both synchronous and asynchronous modes, if we are not simulating the wall clock
            # time, it will need to be updated to the real wall clock time
            self.wall_time = time.time()

        if not self.asynchronous_mode and self.simulate_wall_time:
            # In synchronous mode with the wall clock time simulated, in addition to adding
            # the client report to the list of updates, we will also need to advance the wall
            # clock time to the finish time of the reporting client
            client_finish_time = client_info[0]
            self.wall_time = max(client_finish_time, self.wall_time)

            logging.info(
                "[%s] Advancing the wall clock time to %.2f.",
                self,
                self.wall_time,
            )

        # If all updates have been received from selected clients, the aggregation process
        # proceeds regardless of synchronous or asynchronous modes. This guarantees that
        # if asynchronous mode uses an excessively long aggregation interval, it will not
        # unnecessarily delay the aggregation process.
        if len(self.updates) >= self.clients_per_round:
            logging.info(
                "[%s] All %d client report(s) received. Processing.",
                self,
                len(self.updates),
            )
            await self._process_reports()
            await self.wrap_up()
            await self._select_clients()

        elif (
            hasattr(Config().trainer, "max_concurrency")
            and not Config().is_central_server()
        ):
            # Clients in the current batch have finished training;
            # the server will select the next batch of clients to train
            if len(self.updates) >= len(self.trained_clients) or len(
                self.current_reported_clients
            ) >= len(self.trained_clients):
                await self._select_clients(for_next_batch=True)

    async def _client_disconnected(self, sid):
        """When a client process disconnects, it is removed from the server's internal state."""
        for client_process_id, client in dict(self.clients).items():
            if client["sid"] == sid:
                # Obtain the client id before deleting
                client_id = self.clients[client_process_id]["client_id"]

                # Remove the physical client from the server's list
                del self.clients[client_process_id]
                logging.warning(
                    "[%s] Client process #%d disconnected and removed from this server; "
                    "%d client processes are remaining.",
                    self,
                    client_process_id,
                    len(self.clients),
                )

                if len(self.clients) == 0:
                    logging.warning(
                        fonts.colourize(
                            f"[{self}] All clients disconnected, closing the server."
                        )
                    )
                    await self._close()

                # Handle the logical client under different situations
                if client_id in self.training_clients:
                    del self.training_clients[client_id]

                if client_id in self.current_reported_clients:
                    del self.current_reported_clients[client_id]

                # Decide whether to continue training or exit
                if (
                    hasattr(Config(), "general")
                    and hasattr(Config().general, "debug")
                    and not Config().general.debug
                ):
                    # Recover from the failed client and proceed with training
                    if (
                        client_id in self.selected_clients
                        and client_id in self.trained_clients
                    ):
                        self.trained_clients.remove(client_id)
                        fail_client_index = self.selected_clients.index(client_id)
                        untrained_client_index = len(self.trained_clients)

                        # Swap the current client to the beginning of the untrained clients
                        self.selected_clients[fail_client_index] = (
                            self.selected_clients[untrained_client_index]
                        )
                        self.selected_clients[untrained_client_index] = client_id

                    # Start the next batch of client selection if the current batch is done
                    if len(self.updates) >= len(self.trained_clients) or len(
                        self.current_reported_clients
                    ) >= len(self.trained_clients):
                        await self._select_clients(for_next_batch=True)
                else:
                    # Debug is either turned on or not specified; stop the training to avoid blocking.
                    logging.warning(
                        fonts.colourize(
                            f"[{self}] Closing the server due to a failed client."
                        )
                    )
                    await self._close()

    def save_to_checkpoint(self) -> None:
        """Saves a checkpoint for resuming the training session."""
        checkpoint_path = Config.params["checkpoint_path"]

        model_name = (
            Config().trainer.model_name
            if hasattr(Config().trainer, "model_name")
            else "custom"
        )
        if "/" in model_name:
            model_name = model_name.replace("/", "_")
        filename = f"checkpoint_{model_name}_{self.current_round}.pth"
        logging.info(
            "[%s] Saving the checkpoint to %s/%s.",
            self,
            checkpoint_path,
            filename,
        )
        self.trainer.save_model(filename, checkpoint_path)
        self._save_random_states(self.current_round, checkpoint_path)

        # Saving the current round in the server for resuming its session later on
        with open(f"{checkpoint_path}/current_round.pkl", "wb") as checkpoint_file:
            pickle.dump(self.current_round, checkpoint_file)

    def _resume_from_checkpoint(self):
        """Resumes a training session from a previously saved checkpoint."""
        logging.info(
            "[%s] Resume a training session from a previously saved checkpoint.",
            self,
        )

        # Loading important data in the server for resuming its session
        checkpoint_path = Config.params["checkpoint_path"]

        with open(f"{checkpoint_path}/current_round.pkl", "rb") as checkpoint_file:
            self.current_round = pickle.load(checkpoint_file)

        self._restore_random_states(self.current_round, checkpoint_path)
        self.resumed_session = True

        model_name = (
            Config().trainer.model_name
            if hasattr(Config().trainer, "model_name")
            else "custom"
        )
        filename = f"checkpoint_{model_name}_{self.current_round}.pth"
        self.trainer.load_model(filename, checkpoint_path)

    def _save_random_states(self, round_to_save, checkpoint_path):
        """Saves the random states in the server for resuming its session later on."""
        states_to_save = [
            f"numpy_prng_state_{round_to_save}",
            f"prng_state_{round_to_save}",
        ]

        variables_to_save = [
            np.random.get_state(),
            random.getstate(),
        ]

        for i, state in enumerate(states_to_save):
            with open(f"{checkpoint_path}/{state}.pkl", "wb") as checkpoint_file:
                pickle.dump(variables_to_save[i], checkpoint_file)

    def _restore_random_states(self, round_to_restore, checkpoint_path):
        """Restores the numpy.random and random states from previously saved checkpoints
        for a particular round.
        """
        states_to_load = ["numpy_prng_state", "prng_state"]
        variables_to_load = {}

        for i, state in enumerate(states_to_load):
            with open(
                f"{checkpoint_path}/{state}_{round_to_restore}.pkl", "rb"
            ) as checkpoint_file:
                variables_to_load[i] = pickle.load(checkpoint_file)

        numpy_prng_state = variables_to_load[0]
        self.prng_state = variables_to_load[1]

        np.random.set_state(numpy_prng_state)
        random.setstate(self.prng_state)

    async def wrap_up(self) -> None:
        """Wraps up when each round of training is done."""
        self.save_to_checkpoint()

        # Break the loop when the target accuracy is achieved
        target_accuracy = None
        target_perplexity = None

        if hasattr(Config().trainer, "target_accuracy"):
            target_accuracy = Config().trainer.target_accuracy
        elif hasattr(Config().trainer, "target_perplexity"):
            target_perplexity = Config().trainer.target_perplexity

        if target_accuracy and self.accuracy >= target_accuracy:
            logging.info("[%s] Target accuracy reached.", self)
            await self._close()

        if target_perplexity and self.accuracy <= target_perplexity:
            logging.info("[%s] Target perplexity reached.", self)
            await self._close()

        if self.current_round >= Config().trainer.rounds:
            logging.info("Target number of training rounds reached.")
            await self._close()

    async def _close(self):
        """Closes the server."""
        logging.info("[%s] Training concluded.", self)
        self.trainer.save_model()

        self.server_will_close()
        self.callback_handler.call_event("on_server_will_close", self)

        await self._close_connections()
        os._exit(0)

    def add_callbacks(self, callbacks):
        """Adds a list of callbacks to the server callback handler."""
        self.callback_handler.add_callbacks(callbacks)

    def customize_server_response(self, server_response: dict, client_id) -> dict:
        """Customizes the server response with any additional information."""
        return server_response

    def customize_server_payload(self, payload):
        """Customizes the server payload before sending to the client."""
        return payload

    @abstractmethod
    async def _process_reports(self) -> None:
        """Processes a client report."""

    async def periodic_task(self) -> None:
        """
        Async method called periodically in asynchronous mode.
        """

    def clients_selected(self, selected_clients) -> None:
        """
        Method called after clients have been selected in each round.
        """

    def clients_processed(self) -> None:
        """Additional work to be performed after client reports have been processed."""

    def training_will_start(self) -> None:
        """
        Method called before selecting clients for the first round of training.
        """
        if Config().is_central_server():
            if self.disable_clients:
                logging.info("No clients are launched (server:disable_clients = true)")
            else:
                Server._start_clients(client=self.client)

    def server_will_close(self) -> None:
        """
        Method called before closing the server.
        """