plato-learn 1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- plato/__init__.py +1 -0
- plato/algorithms/__init__.py +0 -0
- plato/algorithms/base.py +45 -0
- plato/algorithms/fedavg.py +48 -0
- plato/algorithms/fedavg_gan.py +79 -0
- plato/algorithms/fedavg_personalized.py +48 -0
- plato/algorithms/mistnet.py +52 -0
- plato/algorithms/registry.py +39 -0
- plato/algorithms/split_learning.py +89 -0
- plato/callbacks/__init__.py +0 -0
- plato/callbacks/client.py +56 -0
- plato/callbacks/handler.py +78 -0
- plato/callbacks/server.py +139 -0
- plato/callbacks/trainer.py +124 -0
- plato/client.py +67 -0
- plato/clients/__init__.py +0 -0
- plato/clients/base.py +467 -0
- plato/clients/edge.py +103 -0
- plato/clients/fedavg_personalized.py +40 -0
- plato/clients/mistnet.py +49 -0
- plato/clients/registry.py +43 -0
- plato/clients/self_supervised_learning.py +51 -0
- plato/clients/simple.py +218 -0
- plato/clients/split_learning.py +150 -0
- plato/config.py +339 -0
- plato/datasources/__init__.py +0 -0
- plato/datasources/base.py +123 -0
- plato/datasources/celeba.py +150 -0
- plato/datasources/cifar10.py +87 -0
- plato/datasources/cifar100.py +61 -0
- plato/datasources/cinic10.py +62 -0
- plato/datasources/coco.py +119 -0
- plato/datasources/datalib/__init__.py +0 -0
- plato/datasources/datalib/audio_extraction_tools.py +137 -0
- plato/datasources/datalib/data_utils.py +124 -0
- plato/datasources/datalib/flickr30kE_utils.py +336 -0
- plato/datasources/datalib/frames_extraction_tools.py +254 -0
- plato/datasources/datalib/gym_utils/__init__.py +0 -0
- plato/datasources/datalib/gym_utils/gym_trim.py +189 -0
- plato/datasources/datalib/modality_data_anntation_tools.py +163 -0
- plato/datasources/datalib/modality_extraction_base.py +59 -0
- plato/datasources/datalib/parse_datasets.py +212 -0
- plato/datasources/datalib/refer_utils/__init__.py +0 -0
- plato/datasources/datalib/refer_utils/referitgame_utils.py +237 -0
- plato/datasources/datalib/tiny_data_tools.py +81 -0
- plato/datasources/datalib/video_transform.py +79 -0
- plato/datasources/emnist.py +64 -0
- plato/datasources/fashion_mnist.py +41 -0
- plato/datasources/feature.py +24 -0
- plato/datasources/feature_dataset.py +15 -0
- plato/datasources/femnist.py +141 -0
- plato/datasources/flickr30k_entities.py +362 -0
- plato/datasources/gym.py +431 -0
- plato/datasources/huggingface.py +165 -0
- plato/datasources/kinetics.py +568 -0
- plato/datasources/mnist.py +44 -0
- plato/datasources/multimodal_base.py +328 -0
- plato/datasources/pascal_voc.py +56 -0
- plato/datasources/purchase.py +94 -0
- plato/datasources/qoenflx.py +127 -0
- plato/datasources/referitgame.py +330 -0
- plato/datasources/registry.py +119 -0
- plato/datasources/self_supervised_learning.py +98 -0
- plato/datasources/stl10.py +103 -0
- plato/datasources/texas.py +94 -0
- plato/datasources/tiny_imagenet.py +64 -0
- plato/datasources/yolov8.py +85 -0
- plato/models/__init__.py +0 -0
- plato/models/cnn_encoder.py +103 -0
- plato/models/dcgan.py +116 -0
- plato/models/general_multilayer.py +254 -0
- plato/models/huggingface.py +27 -0
- plato/models/lenet5.py +113 -0
- plato/models/multilayer.py +90 -0
- plato/models/multimodal/__init__.py +0 -0
- plato/models/multimodal/base_net.py +91 -0
- plato/models/multimodal/blending.py +142 -0
- plato/models/multimodal/fc_net.py +77 -0
- plato/models/multimodal/fusion_net.py +78 -0
- plato/models/multimodal/multimodal_module.py +152 -0
- plato/models/registry.py +99 -0
- plato/models/resnet.py +190 -0
- plato/models/torch_hub.py +19 -0
- plato/models/vgg.py +113 -0
- plato/models/vit.py +166 -0
- plato/models/yolov8.py +22 -0
- plato/processors/__init__.py +0 -0
- plato/processors/base.py +35 -0
- plato/processors/compress.py +46 -0
- plato/processors/decompress.py +48 -0
- plato/processors/feature.py +51 -0
- plato/processors/feature_additive_noise.py +48 -0
- plato/processors/feature_dequantize.py +34 -0
- plato/processors/feature_gaussian.py +17 -0
- plato/processors/feature_laplace.py +15 -0
- plato/processors/feature_quantize.py +34 -0
- plato/processors/feature_randomized_response.py +50 -0
- plato/processors/feature_unbatch.py +39 -0
- plato/processors/inbound_feature_tensors.py +39 -0
- plato/processors/model.py +55 -0
- plato/processors/model_compress.py +34 -0
- plato/processors/model_decompress.py +37 -0
- plato/processors/model_decrypt.py +41 -0
- plato/processors/model_deepcopy.py +21 -0
- plato/processors/model_dequantize.py +18 -0
- plato/processors/model_dequantize_qsgd.py +61 -0
- plato/processors/model_encrypt.py +43 -0
- plato/processors/model_quantize.py +18 -0
- plato/processors/model_quantize_qsgd.py +82 -0
- plato/processors/model_randomized_response.py +34 -0
- plato/processors/outbound_feature_ndarrays.py +38 -0
- plato/processors/pipeline.py +26 -0
- plato/processors/registry.py +124 -0
- plato/processors/structured_pruning.py +57 -0
- plato/processors/unstructured_pruning.py +73 -0
- plato/samplers/__init__.py +0 -0
- plato/samplers/all_inclusive.py +41 -0
- plato/samplers/base.py +31 -0
- plato/samplers/dirichlet.py +81 -0
- plato/samplers/distribution_noniid.py +132 -0
- plato/samplers/iid.py +53 -0
- plato/samplers/label_quantity_noniid.py +119 -0
- plato/samplers/mixed.py +44 -0
- plato/samplers/mixed_label_quantity_noniid.py +128 -0
- plato/samplers/modality_iid.py +42 -0
- plato/samplers/modality_quantity_noniid.py +56 -0
- plato/samplers/orthogonal.py +99 -0
- plato/samplers/registry.py +66 -0
- plato/samplers/sample_quantity_noniid.py +123 -0
- plato/samplers/sampler_utils.py +190 -0
- plato/servers/__init__.py +0 -0
- plato/servers/base.py +1395 -0
- plato/servers/fedavg.py +281 -0
- plato/servers/fedavg_cs.py +335 -0
- plato/servers/fedavg_gan.py +74 -0
- plato/servers/fedavg_he.py +106 -0
- plato/servers/fedavg_personalized.py +57 -0
- plato/servers/mistnet.py +67 -0
- plato/servers/registry.py +52 -0
- plato/servers/split_learning.py +109 -0
- plato/trainers/__init__.py +0 -0
- plato/trainers/base.py +99 -0
- plato/trainers/basic.py +649 -0
- plato/trainers/diff_privacy.py +178 -0
- plato/trainers/gan.py +330 -0
- plato/trainers/huggingface.py +173 -0
- plato/trainers/loss_criterion.py +70 -0
- plato/trainers/lr_schedulers.py +252 -0
- plato/trainers/optimizers.py +53 -0
- plato/trainers/pascal_voc.py +80 -0
- plato/trainers/registry.py +44 -0
- plato/trainers/self_supervised_learning.py +302 -0
- plato/trainers/split_learning.py +305 -0
- plato/trainers/tracking.py +96 -0
- plato/trainers/yolov8.py +41 -0
- plato/utils/__init__.py +0 -0
- plato/utils/count_parameters.py +30 -0
- plato/utils/csv_processor.py +26 -0
- plato/utils/data_loaders.py +148 -0
- plato/utils/decorators.py +24 -0
- plato/utils/fonts.py +23 -0
- plato/utils/homo_enc.py +187 -0
- plato/utils/reinforcement_learning/__init__.py +0 -0
- plato/utils/reinforcement_learning/policies/__init__.py +0 -0
- plato/utils/reinforcement_learning/policies/base.py +161 -0
- plato/utils/reinforcement_learning/policies/ddpg.py +75 -0
- plato/utils/reinforcement_learning/policies/registry.py +32 -0
- plato/utils/reinforcement_learning/policies/sac.py +343 -0
- plato/utils/reinforcement_learning/policies/td3.py +485 -0
- plato/utils/reinforcement_learning/rl_agent.py +142 -0
- plato/utils/reinforcement_learning/rl_server.py +113 -0
- plato/utils/rl_env.py +154 -0
- plato/utils/s3.py +141 -0
- plato/utils/trainer_utils.py +21 -0
- plato/utils/unary_encoding.py +47 -0
- plato_learn-1.1.dist-info/METADATA +35 -0
- plato_learn-1.1.dist-info/RECORD +179 -0
- plato_learn-1.1.dist-info/WHEEL +4 -0
- plato_learn-1.1.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,124 @@
|
|
1
|
+
"""
|
2
|
+
Defines the TrainerCallback class, which is the abstract base class to be subclassed
|
3
|
+
when creating new trainer callbacks.
|
4
|
+
|
5
|
+
Defines a default callback to print training progress.
|
6
|
+
"""
|
7
|
+
|
8
|
+
import logging
|
9
|
+
import os
|
10
|
+
from abc import ABC
|
11
|
+
|
12
|
+
from plato.utils import fonts
|
13
|
+
|
14
|
+
|
15
|
+
class TrainerCallback(ABC):
    """
    Base class for trainer callbacks.

    Every hook defined here is a no-op, so a subclass only needs to override
    the events it is interested in. Each hook accepts extra keyword arguments
    through ``**kwargs`` and returns nothing.
    """

    def on_train_run_start(self, trainer, config, **kwargs):
        """
        Hook fired when a training run begins.

        :param trainer: the trainer that fired the event.
        :param config: the training configuration passed along with the event.
        """

    def on_train_run_end(self, trainer, config, **kwargs):
        """
        Hook fired when a training run finishes.

        :param trainer: the trainer that fired the event.
        :param config: the training configuration passed along with the event.
        """

    def on_train_epoch_start(self, trainer, config, **kwargs):
        """
        Hook fired when a training epoch begins.

        :param trainer: the trainer that fired the event.
        :param config: the training configuration passed along with the event.
        """

    def on_train_step_start(self, trainer, config, batch, **kwargs):
        """
        Hook fired when a training step begins.

        :param trainer: the trainer that fired the event.
        :param config: the training configuration passed along with the event.
        :param batch: the current batch of training data.
        """

    def on_train_step_end(self, trainer, config, batch, loss, **kwargs):
        """
        Hook fired when a training step finishes.

        :param trainer: the trainer that fired the event.
        :param config: the training configuration passed along with the event.
        :param batch: the current batch of training data.
        :param loss: the loss computed in the current batch.
        """

    def on_train_epoch_end(self, trainer, config, **kwargs):
        """
        Hook fired when a training epoch finishes.

        :param trainer: the trainer that fired the event.
        :param config: the training configuration passed along with the event.
        """
|
54
|
+
|
55
|
+
|
56
|
+
class LogProgressCallback(TrainerCallback):
    """
    A callback which logs a message at the start of a training run, at the
    start of each epoch, and at the end of every tenth training step.

    A trainer whose ``client_id`` is 0 is reported as the server (identified
    by the process id); any other trainer is reported as a client.
    """

    def on_train_run_start(self, trainer, config, **kwargs):
        """
        Event called at the start of training run.

        Logs the number of samples produced by ``trainer.sampler``.
        """
        # The sampler is materialised to count its entries, exactly once.
        dataset_size = len(list(trainer.sampler))

        if trainer.client_id == 0:
            logging.info(
                "[Server #%s] Loading the dataset with size %d.",
                os.getpid(),
                dataset_size,
            )
        else:
            logging.info(
                "[Client #%d] Loading the dataset with size %d.",
                trainer.client_id,
                dataset_size,
            )

    def on_train_epoch_start(self, trainer, config, **kwargs):
        """
        Event called at the beginning of a training epoch.

        Logs the epoch number read from ``trainer.current_epoch``.
        """
        if trainer.client_id == 0:
            logging.info(
                fonts.colourize(
                    f"[Server #{os.getpid()}] Started training epoch {trainer.current_epoch}."
                )
            )
        else:
            logging.info(
                fonts.colourize(
                    f"[Client #{trainer.client_id}] Started training epoch {trainer.current_epoch}."
                )
            )

    def on_train_step_end(self, trainer, config, batch=None, loss=None, **kwargs):
        """
        Event called at the end of a training step.

        Logs the epoch, batch position and loss once every ten batches.

        :param batch: the index of the current batch; nothing is logged when
            it is None.
        :param loss: the loss computed in the current batch; nothing is logged
            when it is None.
        """
        log_interval = 10

        # Both arguments default to None in the signature, so a caller that
        # omits them must not crash on `batch % log_interval` or on
        # `loss.data.item()`; there is simply nothing to report.
        if batch is None or loss is None:
            return

        if batch % log_interval != 0:
            return

        if trainer.client_id == 0:
            logging.info(
                "[Server #%d] Epoch: [%d/%d][%d/%d]\tLoss: %.6f",
                os.getpid(),
                trainer.current_epoch,
                config["epochs"],
                batch,
                len(trainer.train_loader),
                loss.data.item(),
            )
        else:
            logging.info(
                "[Client #%d] Epoch: [%d/%d][%d/%d]\tLoss: %.6f",
                trainer.client_id,
                trainer.current_epoch,
                config["epochs"],
                batch,
                len(trainer.train_loader),
                loss.data.item(),
            )
|
plato/client.py
ADDED
@@ -0,0 +1,67 @@
|
|
1
|
+
"""
|
2
|
+
Starting point for a Plato federated learning client.
|
3
|
+
"""
|
4
|
+
|
5
|
+
import asyncio
|
6
|
+
import logging
|
7
|
+
import os
|
8
|
+
|
9
|
+
from plato.clients import registry as client_registry
|
10
|
+
from plato.config import Config
|
11
|
+
|
12
|
+
|
13
|
+
def run(client_id, port, client=None, edge_server=None, edge_client=None, trainer=None):
    """
    Launch a client process and connect it to the server.

    The client id (and the port, when one is given) are first stored in the
    global configuration. If this process is configured as an edge server, an
    edge server and its companion edge client are both started; otherwise a
    single client is started and run to completion on the event loop.

    :param client_id: the id to assign to this client.
    :param port: the port to store in the configuration, or None to keep it.
    :param client: an optional pre-built client used on the non-edge path.
    :param edge_server: an optional callable that builds a custom edge server.
    :param edge_client: a callable that builds the edge client from the custom
        edge server; used only when ``edge_server`` is supplied.
    :param trainer: an optional callable whose result is passed to the custom
        edge server as ``trainer``.
    """
    Config().args.id = client_id
    if port is not None:
        Config().args.port = port

    if not Config().is_edge_server():
        # Plain client: take one from the registry unless a custom one was supplied.
        if client is None:
            client = client_registry.get()
            logging.info("Starting a %s client #%d.", Config().clients.type, client_id)
        else:
            client.client_id = client_id
            logging.info("Starting a custom client #%d", client_id)

        client.configure()

        asyncio.get_event_loop().run_until_complete(client.start_client())
        return

    # From here on, a server needs to be running concurrently with the client.
    Config().trainer = Config().trainer._replace(
        rounds=Config().algorithm.local_rounds
    )

    if edge_server is None:
        from plato.clients import edge
        from plato.servers import fedavg_cs

        server = fedavg_cs.Server()
        client = edge.Client(server)
    else:
        # A customized edge server, optionally built around a custom trainer
        server = edge_server(trainer=trainer()) if trainer is not None else edge_server()
        client = edge_client(server)

    server.configure()
    client.configure()

    logging.info("Starting an edge server as client #%d", Config().args.id)
    # Only scheduled here; the coroutine runs once the server drives the event loop.
    asyncio.ensure_future(client.start_client())

    logging.info(
        "Starting an edge server as server #%d on port %d",
        os.getpid(),
        Config().args.port,
    )
    server.start(port=Config().args.port)
|
64
|
+
|
65
|
+
|
66
|
+
if __name__ == "__main__":
    # Script entry point: take the client id and port from the parsed configuration.
    cli_args = Config().args
    run(cli_args.id, cli_args.port)
|
File without changes
|
plato/clients/base.py
ADDED
@@ -0,0 +1,467 @@
|
|
1
|
+
"""
|
2
|
+
The base class for all federated learning clients on edge devices or edge servers.
|
3
|
+
"""
|
4
|
+
|
5
|
+
import asyncio
|
6
|
+
import logging
|
7
|
+
import os
|
8
|
+
import pickle
|
9
|
+
import re
|
10
|
+
import sys
|
11
|
+
import time
|
12
|
+
import uuid
|
13
|
+
from abc import abstractmethod
|
14
|
+
|
15
|
+
import numpy as np
|
16
|
+
import socketio
|
17
|
+
|
18
|
+
from plato.callbacks.client import LogProgressCallback
|
19
|
+
from plato.callbacks.handler import CallbackHandler
|
20
|
+
from plato.config import Config
|
21
|
+
from plato.utils import s3
|
22
|
+
|
23
|
+
|
24
|
+
# pylint: disable=unused-argument, protected-access
|
25
|
+
class ClientEvents(socketio.AsyncClientNamespace):
    """
    A socket.io client namespace that forwards events received from the
    server to the owning Plato client.
    """

    def __init__(self, namespace, plato_client):
        """
        :param namespace: the socket.io namespace to register under.
        :param plato_client: the client whose private handlers are invoked.
        """
        super().__init__(namespace)
        self.plato_client = plato_client
        # Captured once at construction time and used in log messages only.
        self.client_id = plato_client.client_id

    async def on_connect(self):
        """Logs that a connection to the server has been established."""
        logging.info("[Client #%d] Connected to the server.", self.client_id)

    async def on_disconnect(self):
        """Cleans up checkpoint files and terminates the process on disconnection."""
        logging.info(
            "[Client #%d] The server disconnected the connection.", self.client_id
        )
        self.plato_client._clear_checkpoint_files()
        # Terminates the process immediately with exit status 0.
        os._exit(0)

    async def on_connect_error(self, data):
        """Logs that an attempt to connect to the server has failed."""
        logging.info(
            "[Client #%d] A connection attempt to the server failed.", self.client_id
        )

    async def on_payload_to_arrive(self, data):
        """The server announced that a new payload is about to be sent."""
        response = data["response"]
        await self.plato_client._payload_to_arrive(response)

    async def on_request_update(self, data):
        """The server asked for an urgent model update."""
        await self.plato_client._request_update(data)

    async def on_chunk(self, data):
        """One chunk of data from the server has arrived."""
        chunk = data["data"]
        await self.plato_client._chunk_arrived(chunk)

    async def on_payload(self, data):
        """One portion of the new payload from the server has arrived."""
        client_id = data["id"]
        await self.plato_client._payload_arrived(client_id)

    async def on_payload_done(self, data):
        """The whole payload from the server has arrived."""
        # The S3 key is forwarded only when the server included one.
        extra = {"s3_key": data["s3_key"]} if "s3_key" in data else {}
        await self.plato_client._payload_done(data["id"], **extra)
|
73
|
+
|
74
|
+
|
75
|
+
class Client:
    """
    A basic federated learning client.

    The client talks to the server over socket.io, receives a payload (through
    the local filesystem when communication is simulated, from S3, or in
    chunks over the socket), trains on it, and sends a report followed by its
    own payload back.

    NOTE: this class does not derive from ``abc.ABC``, so the methods marked
    with ``@abstractmethod`` are not enforced at instantiation time; subclasses
    are nevertheless expected to override them.
    """

    # Name of a temporary checkpoint file: <client_id>_<epoch>_<training_time>.pth,
    # where the training time is a decimal number such as 12.5.
    _CHECKPOINT_FILENAME = re.compile(
        r"(?P<client_id>\d+)_(?P<epoch>\d+)_(?P<training_time>\d+\.\d+)\.pth"
    )

    def __init__(self, callbacks=None) -> None:
        """
        :param callbacks: optional list of callbacks appended after the
            default ``LogProgressCallback``.
        """
        self.client_id = Config().args.id
        self.current_round = 0
        self.sio = None
        self.chunks = []
        self.server_payload = None
        self.s3_client = None
        self.outbound_processor = None
        self.inbound_processor = None
        self.payload = None
        self.report = None

        self.processing_time = 0

        # Communication is simulated through the filesystem unless configured otherwise.
        self.comm_simulation = (
            Config().clients.comm_simulation
            if hasattr(Config().clients, "comm_simulation")
            else True
        )

        if hasattr(Config().algorithm, "cross_silo") and not Config().is_edge_server():
            self.edge_server_id = None

            assert hasattr(Config().algorithm, "total_silos")

        # Starting from the default client callback class, add all supplied client callbacks
        self.callbacks = [LogProgressCallback]
        if callbacks is not None:
            self.callbacks.extend(callbacks)
        self.callback_handler = CallbackHandler(self.callbacks)

    def __repr__(self):
        return f"Client #{self.client_id}"

    async def start_client(self) -> None:
        """
        Startup function for a client.

        Builds the server URI from the configuration, connects over socket.io,
        announces the client with a ``client_alive`` event and then waits for
        events until the connection ends.
        """
        if hasattr(Config().algorithm, "cross_silo") and not Config().is_edge_server():
            # Contact one of the edge servers
            self.edge_server_id = self.get_edge_server_id()

            logging.info(
                "[Client #%d] Contacting Edge Server #%d.",
                self.client_id,
                self.edge_server_id,
            )
        else:
            await asyncio.sleep(5)
            logging.info("[Client #%d] Contacting the server.", self.client_id)

        self.sio = socketio.AsyncClient(reconnection=True)
        self.sio.register_namespace(ClientEvents(namespace="/", plato_client=self))

        if hasattr(Config().server, "s3_endpoint_url"):
            self.s3_client = s3.S3()

        # NOTE(review): HTTPS is chosen whenever `use_https` is present in the
        # configuration, whatever its value — confirm this matches the server side.
        if hasattr(Config().server, "use_https"):
            uri = f"https://{Config().server.address}"
        else:
            uri = f"http://{Config().server.address}"

        if hasattr(Config().server, "port"):
            # If we are not using a production server deployed in the cloud
            if (
                hasattr(Config().algorithm, "cross_silo")
                and not Config().is_edge_server()
            ):
                # The port is the base server port offset by the edge server id.
                uri = f"{uri}:{int(Config().server.port) + int(self.edge_server_id)}"
            else:
                uri = f"{uri}:{Config().server.port}"

        logging.info("[%s] Connecting to the server at %s.", self, uri)
        await self.sio.connect(uri, wait_timeout=600)
        await self.sio.emit("client_alive", {"pid": os.getpid(), "id": self.client_id})

        logging.info("[Client #%d] Waiting to be selected.", self.client_id)
        await self.sio.wait()

    def get_edge_server_id(self):
        """
        Returns the edge server id of the client in cross-silo FL.

        The launched clients are split as evenly as possible across
        ``total_silos`` edge servers, and the client is assigned to the first
        silo whose cumulative client count reaches its id. Returns None when
        the client id exceeds the number of launched clients.
        """
        launched_client_num = (
            min(
                Config().trainer.max_concurrency
                * max(1, Config().gpu_count())
                * Config().algorithm.total_silos,
                Config().clients.per_round,
            )
            if hasattr(Config().trainer, "max_concurrency")
            else Config().clients.per_round
        )

        edges_launched_clients = [
            len(i)
            for i in np.array_split(
                np.arange(launched_client_num), Config().algorithm.total_silos
            )
        ]

        total = 0
        for i, count in enumerate(edges_launched_clients):
            total += count
            if self.client_id <= total:
                # Edge server ids are numbered after all the client ids.
                return i + 1 + Config().clients.total_clients

        # No silo covers this client id.
        return None

    async def _payload_to_arrive(self, response) -> None:
        """
        Upon receiving a response from the server.

        Records the current round and the (virtual) client id from the
        response, prepares the data, and — when communication is simulated —
        loads the payload from the file named in the response and handles it.
        """
        self.current_round = response["current_round"]

        # Update (virtual) client id for client, trainer and algorithm
        self.client_id = response["id"]

        logging.info("[Client #%d] Selected by the server.", self.client_id)

        self.process_server_response(response)

        self._load_data()
        self.configure()
        self._allocate_data()

        self.server_payload = None

        if self.comm_simulation:
            payload_filename = response["payload_filename"]
            with open(payload_filename, "rb") as payload_file:
                self.server_payload = pickle.load(payload_file)

            payload_size = sys.getsizeof(pickle.dumps(self.server_payload))

            logging.info(
                "[%s] Received %.2f MB of payload data from the server (simulated).",
                self,
                payload_size / 1024**2,
            )

            await self._handle_payload(self.server_payload)

    async def _handle_payload(self, inbound_payload):
        """
        Handles the inbound payload upon receiving it from the server.

        Runs the inbound processor, computes the response through
        ``inbound_processed``, runs the outbound processor, and sends the
        report followed by the processed payload. The time spent in both
        processors is stored in ``report.processing_time``.
        """
        self.inbound_received(self.inbound_processor)
        self.callback_handler.call_event(
            "on_inbound_received", self, self.inbound_processor
        )

        tic = time.perf_counter()
        processed_inbound_payload = self.inbound_processor.process(inbound_payload)
        self.processing_time = time.perf_counter() - tic

        # Inbound data is processed, computing outbound response
        report, outbound_payload = await self.inbound_processed(
            processed_inbound_payload
        )
        self.callback_handler.call_event(
            "on_inbound_processed", self, processed_inbound_payload
        )

        # Outbound data is ready to be processed
        tic = time.perf_counter()
        self.outbound_ready(report, self.outbound_processor)
        self.callback_handler.call_event(
            "on_outbound_ready", self, report, self.outbound_processor
        )
        processed_outbound_payload = self.outbound_processor.process(outbound_payload)
        self.processing_time += time.perf_counter() - tic
        report.processing_time = self.processing_time

        # Sending the client report as metadata to the server (payload to follow)
        await self.sio.emit(
            "client_report", {"id": self.client_id, "report": pickle.dumps(report)}
        )

        # Sending the client training payload to the server
        await self._send(processed_outbound_payload)

    def inbound_received(self, inbound_processor):
        """
        Override this method to complete additional tasks before the inbound processors start to
        process the data received from the server.
        """

    async def inbound_processed(self, processed_inbound_payload):
        """
        Override this method to conduct customized operations to generate a client's response to
        the server when inbound payload from the server has been processed.

        :return: a ``(report, outbound_payload)`` pair; by default, the result
            of one round of training.
        """
        report, outbound_payload = await self._start_training(processed_inbound_payload)
        return report, outbound_payload

    def outbound_ready(self, report, outbound_processor):
        """
        Override this method to complete additional tasks before the outbound processors start
        to process the data to be sent to the server.
        """

    async def _chunk_arrived(self, data) -> None:
        """Upon receiving a chunk of data from the server."""
        self.chunks.append(data)

    async def _request_update(self, data) -> None:
        """
        Upon receiving a request for an urgent model update.

        Obtains the model update for the requested client id and time, runs
        the outbound processor on it, and sends the report and the payload.
        """
        logging.info(
            "[Client #%s] Urgent request received for model update at time %s.",
            data["client_id"],
            data["time"],
        )

        report, payload = await self._obtain_model_update(
            client_id=data["client_id"],
            requested_time=data["time"],
        )

        # Process outbound data when necessary
        self.callback_handler.call_event(
            "on_outbound_ready", self, report, self.outbound_processor
        )
        self.outbound_ready(report, self.outbound_processor)
        payload = self.outbound_processor.process(payload)

        # Sending the client report as metadata to the server (payload to follow)
        await self.sio.emit(
            "client_report", {"id": self.client_id, "report": pickle.dumps(report)}
        )

        # Sending the client training payload to the server
        await self._send(payload)

    async def _payload_arrived(self, client_id) -> None:
        """
        Upon receiving a portion of the new payload from the server.

        Joins the chunks collected so far and unpickles them into one item.
        The first item becomes the server payload; later items turn it into a
        list and are appended to it.
        """
        assert client_id == self.client_id

        payload = b"".join(self.chunks)
        # SECURITY: unpickling can execute arbitrary code, so this is only
        # safe when the server that sent the chunks is trusted.
        _data = pickle.loads(payload)
        self.chunks = []

        if self.server_payload is None:
            self.server_payload = _data
        elif isinstance(self.server_payload, list):
            self.server_payload.append(_data)
        else:
            self.server_payload = [self.server_payload]
            self.server_payload.append(_data)

    async def _payload_done(self, client_id, s3_key=None) -> None:
        """
        Upon receiving all the new payload from the server.

        Fetches the payload from S3 when ``s3_key`` is given, logs the size of
        the payload, and hands it over to ``_handle_payload``.
        """
        payload_size = 0

        if s3_key is None:
            if isinstance(self.server_payload, list):
                for _data in self.server_payload:
                    payload_size += sys.getsizeof(pickle.dumps(_data))
            elif isinstance(self.server_payload, dict):
                for key, value in self.server_payload.items():
                    payload_size += sys.getsizeof(pickle.dumps({key: value}))
            else:
                payload_size = sys.getsizeof(pickle.dumps(self.server_payload))
        else:
            self.server_payload = self.s3_client.receive_from_s3(s3_key)
            payload_size = sys.getsizeof(pickle.dumps(self.server_payload))

        assert client_id == self.client_id

        logging.info(
            "[Client #%d] Received %.2f MB of payload data from the server.",
            client_id,
            payload_size / 1024**2,
        )

        await self._handle_payload(self.server_payload)

    async def _start_training(self, inbound_payload):
        """
        Complete one round of training on this client.

        :return: the ``(report, outbound_payload)`` pair produced by ``_train``.
        """
        self._load_payload(inbound_payload)

        report, outbound_payload = await self._train()

        if Config().is_edge_server():
            logging.info(
                "[Server #%d] Model aggregated on edge server (%s).", os.getpid(), self
            )
        else:
            logging.info("[%s] Model trained.", self)

        return report, outbound_payload

    async def _send_in_chunks(self, data) -> None:
        """
        Sending a bytes object in fixed-sized chunks to the server.

        Each chunk holds at most 1 MiB; a ``client_payload`` event follows the
        last chunk.
        """
        step = 1024**2
        chunks = [data[i : i + step] for i in range(0, len(data), step)]

        for chunk in chunks:
            await self.sio.emit("chunk", {"data": chunk})

        await self.sio.emit("client_payload", {"id": self.client_id})

    async def _send(self, payload) -> None:
        """Sending the client payload to the server using simulation, S3 or socket.io."""
        if self.comm_simulation:
            # If we are using the filesystem to simulate communication over a network
            model_name = (
                Config().trainer.model_name
                if hasattr(Config().trainer, "model_name")
                else "custom"
            )
            if "/" in model_name:
                model_name = model_name.replace("/", "_")
            checkpoint_path = Config().params["checkpoint_path"]
            payload_filename = (
                f"{checkpoint_path}/{model_name}_client_{self.client_id}.pth"
            )

            # Serialize once: the same bytes are written to the file and measured.
            serialized_payload = pickle.dumps(payload)
            with open(payload_filename, "wb") as payload_file:
                payload_file.write(serialized_payload)

            data_size = sys.getsizeof(serialized_payload)

            logging.info(
                "[%s] Sent %.2f MB of payload data to the server (simulated).",
                self,
                data_size / 1024**2,
            )

        else:
            metadata = {"id": self.client_id}

            if self.s3_client is not None:
                unique_key = uuid.uuid4().hex[:6].upper()
                s3_key = f"client_payload_{self.client_id}_{unique_key}"
                self.s3_client.send_to_s3(s3_key, payload)
                data_size = sys.getsizeof(pickle.dumps(payload))
                metadata["s3_key"] = s3_key
            else:
                if isinstance(payload, list):
                    data_size: int = 0

                    # Each list item is pickled and sent as a payload portion of its own.
                    for data in payload:
                        _data = pickle.dumps(data)
                        await self._send_in_chunks(_data)
                        data_size += sys.getsizeof(_data)
                else:
                    _data = pickle.dumps(payload)
                    await self._send_in_chunks(_data)
                    data_size = sys.getsizeof(_data)

            await self.sio.emit("client_payload_done", metadata)

            logging.info(
                "[%s] Sent %.2f MB of payload data to the server.",
                self,
                data_size / 1024**2,
            )

    @staticmethod
    def _is_checkpoint_filename(filename) -> bool:
        """Tells whether the whole filename is <client_id>_<epoch>_<training_time>.pth."""
        return Client._CHECKPOINT_FILENAME.fullmatch(filename) is not None

    def _clear_checkpoint_files(self):
        """
        Delete the temporary checkpoint files found in the model directory.

        Every file whose name is ``<client_id>_<epoch>_<training_time>.pth`` is
        removed, whichever client id it carries. A missing directory, or a
        file that disappears before it can be removed, is ignored.
        """
        model_path = Config().params["model_path"]

        try:
            filenames = os.listdir(model_path)
        except FileNotFoundError:
            # Nothing was ever written, so there is nothing to clean up.
            return

        for filename in filenames:
            if not self._is_checkpoint_filename(filename):
                continue

            try:
                os.remove(os.path.join(model_path, filename))
            except FileNotFoundError:
                # Another process sharing this directory removed the file first.
                pass

    def add_callbacks(self, callbacks):
        """Adds a list of callbacks to the client callback handler."""
        self.callback_handler.add_callbacks(callbacks)

    @abstractmethod
    async def _train(self):
        """The machine learning training workload on a client."""

    @abstractmethod
    def configure(self) -> None:
        """Prepare this client for training."""

    @abstractmethod
    def _load_data(self) -> None:
        """Generating data and loading them onto this client."""

    @abstractmethod
    def _allocate_data(self) -> None:
        """Allocate training or testing dataset of this client."""

    @abstractmethod
    def _load_payload(self, server_payload) -> None:
        """Loading the payload onto this client."""

    def process_server_response(self, server_response) -> None:
        """Additional client-specific processing on the server response."""

    @abstractmethod
    async def _obtain_model_update(self, client_id, requested_time):
        """Retrieving a model update corresponding to a particular wall clock time."""
|