plato_learn-1.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- plato/__init__.py +1 -0
- plato/algorithms/__init__.py +0 -0
- plato/algorithms/base.py +45 -0
- plato/algorithms/fedavg.py +48 -0
- plato/algorithms/fedavg_gan.py +79 -0
- plato/algorithms/fedavg_personalized.py +48 -0
- plato/algorithms/mistnet.py +52 -0
- plato/algorithms/registry.py +39 -0
- plato/algorithms/split_learning.py +89 -0
- plato/callbacks/__init__.py +0 -0
- plato/callbacks/client.py +56 -0
- plato/callbacks/handler.py +78 -0
- plato/callbacks/server.py +139 -0
- plato/callbacks/trainer.py +124 -0
- plato/client.py +67 -0
- plato/clients/__init__.py +0 -0
- plato/clients/base.py +467 -0
- plato/clients/edge.py +103 -0
- plato/clients/fedavg_personalized.py +40 -0
- plato/clients/mistnet.py +49 -0
- plato/clients/registry.py +43 -0
- plato/clients/self_supervised_learning.py +51 -0
- plato/clients/simple.py +218 -0
- plato/clients/split_learning.py +150 -0
- plato/config.py +339 -0
- plato/datasources/__init__.py +0 -0
- plato/datasources/base.py +123 -0
- plato/datasources/celeba.py +150 -0
- plato/datasources/cifar10.py +87 -0
- plato/datasources/cifar100.py +61 -0
- plato/datasources/cinic10.py +62 -0
- plato/datasources/coco.py +119 -0
- plato/datasources/datalib/__init__.py +0 -0
- plato/datasources/datalib/audio_extraction_tools.py +137 -0
- plato/datasources/datalib/data_utils.py +124 -0
- plato/datasources/datalib/flickr30kE_utils.py +336 -0
- plato/datasources/datalib/frames_extraction_tools.py +254 -0
- plato/datasources/datalib/gym_utils/__init__.py +0 -0
- plato/datasources/datalib/gym_utils/gym_trim.py +189 -0
- plato/datasources/datalib/modality_data_anntation_tools.py +163 -0
- plato/datasources/datalib/modality_extraction_base.py +59 -0
- plato/datasources/datalib/parse_datasets.py +212 -0
- plato/datasources/datalib/refer_utils/__init__.py +0 -0
- plato/datasources/datalib/refer_utils/referitgame_utils.py +237 -0
- plato/datasources/datalib/tiny_data_tools.py +81 -0
- plato/datasources/datalib/video_transform.py +79 -0
- plato/datasources/emnist.py +64 -0
- plato/datasources/fashion_mnist.py +41 -0
- plato/datasources/feature.py +24 -0
- plato/datasources/feature_dataset.py +15 -0
- plato/datasources/femnist.py +141 -0
- plato/datasources/flickr30k_entities.py +362 -0
- plato/datasources/gym.py +431 -0
- plato/datasources/huggingface.py +165 -0
- plato/datasources/kinetics.py +568 -0
- plato/datasources/mnist.py +44 -0
- plato/datasources/multimodal_base.py +328 -0
- plato/datasources/pascal_voc.py +56 -0
- plato/datasources/purchase.py +94 -0
- plato/datasources/qoenflx.py +127 -0
- plato/datasources/referitgame.py +330 -0
- plato/datasources/registry.py +119 -0
- plato/datasources/self_supervised_learning.py +98 -0
- plato/datasources/stl10.py +103 -0
- plato/datasources/texas.py +94 -0
- plato/datasources/tiny_imagenet.py +64 -0
- plato/datasources/yolov8.py +85 -0
- plato/models/__init__.py +0 -0
- plato/models/cnn_encoder.py +103 -0
- plato/models/dcgan.py +116 -0
- plato/models/general_multilayer.py +254 -0
- plato/models/huggingface.py +27 -0
- plato/models/lenet5.py +113 -0
- plato/models/multilayer.py +90 -0
- plato/models/multimodal/__init__.py +0 -0
- plato/models/multimodal/base_net.py +91 -0
- plato/models/multimodal/blending.py +142 -0
- plato/models/multimodal/fc_net.py +77 -0
- plato/models/multimodal/fusion_net.py +78 -0
- plato/models/multimodal/multimodal_module.py +152 -0
- plato/models/registry.py +99 -0
- plato/models/resnet.py +190 -0
- plato/models/torch_hub.py +19 -0
- plato/models/vgg.py +113 -0
- plato/models/vit.py +166 -0
- plato/models/yolov8.py +22 -0
- plato/processors/__init__.py +0 -0
- plato/processors/base.py +35 -0
- plato/processors/compress.py +46 -0
- plato/processors/decompress.py +48 -0
- plato/processors/feature.py +51 -0
- plato/processors/feature_additive_noise.py +48 -0
- plato/processors/feature_dequantize.py +34 -0
- plato/processors/feature_gaussian.py +17 -0
- plato/processors/feature_laplace.py +15 -0
- plato/processors/feature_quantize.py +34 -0
- plato/processors/feature_randomized_response.py +50 -0
- plato/processors/feature_unbatch.py +39 -0
- plato/processors/inbound_feature_tensors.py +39 -0
- plato/processors/model.py +55 -0
- plato/processors/model_compress.py +34 -0
- plato/processors/model_decompress.py +37 -0
- plato/processors/model_decrypt.py +41 -0
- plato/processors/model_deepcopy.py +21 -0
- plato/processors/model_dequantize.py +18 -0
- plato/processors/model_dequantize_qsgd.py +61 -0
- plato/processors/model_encrypt.py +43 -0
- plato/processors/model_quantize.py +18 -0
- plato/processors/model_quantize_qsgd.py +82 -0
- plato/processors/model_randomized_response.py +34 -0
- plato/processors/outbound_feature_ndarrays.py +38 -0
- plato/processors/pipeline.py +26 -0
- plato/processors/registry.py +124 -0
- plato/processors/structured_pruning.py +57 -0
- plato/processors/unstructured_pruning.py +73 -0
- plato/samplers/__init__.py +0 -0
- plato/samplers/all_inclusive.py +41 -0
- plato/samplers/base.py +31 -0
- plato/samplers/dirichlet.py +81 -0
- plato/samplers/distribution_noniid.py +132 -0
- plato/samplers/iid.py +53 -0
- plato/samplers/label_quantity_noniid.py +119 -0
- plato/samplers/mixed.py +44 -0
- plato/samplers/mixed_label_quantity_noniid.py +128 -0
- plato/samplers/modality_iid.py +42 -0
- plato/samplers/modality_quantity_noniid.py +56 -0
- plato/samplers/orthogonal.py +99 -0
- plato/samplers/registry.py +66 -0
- plato/samplers/sample_quantity_noniid.py +123 -0
- plato/samplers/sampler_utils.py +190 -0
- plato/servers/__init__.py +0 -0
- plato/servers/base.py +1395 -0
- plato/servers/fedavg.py +281 -0
- plato/servers/fedavg_cs.py +335 -0
- plato/servers/fedavg_gan.py +74 -0
- plato/servers/fedavg_he.py +106 -0
- plato/servers/fedavg_personalized.py +57 -0
- plato/servers/mistnet.py +67 -0
- plato/servers/registry.py +52 -0
- plato/servers/split_learning.py +109 -0
- plato/trainers/__init__.py +0 -0
- plato/trainers/base.py +99 -0
- plato/trainers/basic.py +649 -0
- plato/trainers/diff_privacy.py +178 -0
- plato/trainers/gan.py +330 -0
- plato/trainers/huggingface.py +173 -0
- plato/trainers/loss_criterion.py +70 -0
- plato/trainers/lr_schedulers.py +252 -0
- plato/trainers/optimizers.py +53 -0
- plato/trainers/pascal_voc.py +80 -0
- plato/trainers/registry.py +44 -0
- plato/trainers/self_supervised_learning.py +302 -0
- plato/trainers/split_learning.py +305 -0
- plato/trainers/tracking.py +96 -0
- plato/trainers/yolov8.py +41 -0
- plato/utils/__init__.py +0 -0
- plato/utils/count_parameters.py +30 -0
- plato/utils/csv_processor.py +26 -0
- plato/utils/data_loaders.py +148 -0
- plato/utils/decorators.py +24 -0
- plato/utils/fonts.py +23 -0
- plato/utils/homo_enc.py +187 -0
- plato/utils/reinforcement_learning/__init__.py +0 -0
- plato/utils/reinforcement_learning/policies/__init__.py +0 -0
- plato/utils/reinforcement_learning/policies/base.py +161 -0
- plato/utils/reinforcement_learning/policies/ddpg.py +75 -0
- plato/utils/reinforcement_learning/policies/registry.py +32 -0
- plato/utils/reinforcement_learning/policies/sac.py +343 -0
- plato/utils/reinforcement_learning/policies/td3.py +485 -0
- plato/utils/reinforcement_learning/rl_agent.py +142 -0
- plato/utils/reinforcement_learning/rl_server.py +113 -0
- plato/utils/rl_env.py +154 -0
- plato/utils/s3.py +141 -0
- plato/utils/trainer_utils.py +21 -0
- plato/utils/unary_encoding.py +47 -0
- plato_learn-1.1.dist-info/METADATA +35 -0
- plato_learn-1.1.dist-info/RECORD +179 -0
- plato_learn-1.1.dist-info/WHEEL +4 -0
- plato_learn-1.1.dist-info/licenses/LICENSE +201 -0
plato/clients/edge.py
ADDED
@@ -0,0 +1,103 @@
"""
A federated learning client at the edge server in a cross-silo training workload.
"""

import time
from types import SimpleNamespace

from plato.clients import simple
from plato.config import Config
from plato.processors import registry as processor_registry


class Client(simple.Client):
    """A federated learning client at the edge server in a cross-silo training workload."""

    def __init__(
        self,
        server,
        model=None,
        datasource=None,
        algorithm=None,
        trainer=None,
        callbacks=None,
    ):
        super().__init__(
            model=model,
            datasource=datasource,
            algorithm=algorithm,
            trainer=trainer,
            callbacks=callbacks,
        )
        self.server = server

    def configure(self) -> None:
        """Prepare this edge client for training."""
        super().configure()

        # Pass inbound and outbound data payloads through processors for
        # additional data processing
        self.outbound_processor, self.inbound_processor = processor_registry.get(
            "Client", client_id=self.client_id, trainer=self.server.trainer
        )

    def load_data(self) -> None:
        """The edge client does not need to train models using local data."""

    def _load_payload(self, server_payload) -> None:
        """The edge client loads the model from the central server."""
        self.server.algorithm.load_weights(server_payload)

    def process_server_response(self, server_response):
        """Additional client-specific processing on the server response."""
        if "current_global_round" in server_response:
            self.server.current_global_round = server_response["current_global_round"]

    async def _train(self):
        """The aggregation workload on an edge client."""
        training_start_time = time.perf_counter()
        # Signal edge server to select clients to start a new round of local aggregation
        self.server.new_global_round_begins.set()

        # Wait for the edge server to finish model aggregation
        await self.server.model_aggregated.wait()
        self.server.model_aggregated.clear()

        # Extract model weights and biases
        weights = self.server.algorithm.extract_weights()

        average_accuracy = self.server.average_accuracy
        accuracy = self.server.accuracy

        if (
            hasattr(Config().clients, "sleep_simulation")
            and Config().clients.sleep_simulation
        ):
            training_time = self.server.edge_training_time
            self.server.edge_training_time = 0
        else:
            training_time = time.perf_counter() - training_start_time

        comm_time = time.time()

        edge_server_comm_time = self.server.edge_comm_time
        self.server.edge_comm_time = 0

        # Generate a report for the central server
        report = SimpleNamespace(
            client_id=self.client_id,
            num_samples=self.server.total_samples,
            accuracy=accuracy,
            training_time=training_time,
            comm_time=comm_time,
            update_response=False,
            average_accuracy=average_accuracy,
            edge_server_comm_overhead=self.server.comm_overhead,
            edge_server_comm_time=edge_server_comm_time,
        )

        self._report = self.customize_report(report)

        self.server.comm_overhead = 0

        return self._report, weights
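The _train method above performs no local training: it coordinates with the co-located edge server through two asyncio events, signalling the start of a round and then blocking until aggregation finishes. Below is a minimal runnable sketch of that handshake; the stub server and its sleep are stand-ins, and only the two event attributes mirror the code above.

import asyncio


class StubEdgeServer:
    """Stands in for the co-located edge server that owns the two events."""

    def __init__(self):
        self.new_global_round_begins = asyncio.Event()
        self.model_aggregated = asyncio.Event()

    async def run(self):
        # Wait until the edge client signals the start of a round
        await self.new_global_round_begins.wait()
        self.new_global_round_begins.clear()
        await asyncio.sleep(0.1)  # stand-in for aggregating local clients
        self.model_aggregated.set()  # wake the waiting edge client


async def edge_client_train(server):
    server.new_global_round_begins.set()  # signal the edge server
    await server.model_aggregated.wait()  # block until aggregation finishes
    server.model_aggregated.clear()
    return "report for the central server"


async def main():
    server = StubEdgeServer()
    _, report = await asyncio.gather(server.run(), edge_client_train(server))
    print(report)


asyncio.run(main())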
plato/clients/fedavg_personalized.py
ADDED
@@ -0,0 +1,40 @@
"""
A personalized federated learning client that saves its local layers before
sending the shared global model to the server after local training.
"""

from collections import OrderedDict

from plato.clients import simple
from plato.config import Config


class Client(simple.Client):
    """
    A personalized federated learning client that saves its local layers before sending the
    shared global model to the server after local training.
    """

    def outbound_ready(self, report, outbound_processor):
        super().outbound_ready(report, outbound_processor)
        weights = self.algorithm.extract_weights()

        # Save local layers before giving them to the outbound processor
        if hasattr(Config().algorithm, "local_layer_names"):
            # Extract weights of desired local layers
            local_layers = OrderedDict(
                [
                    (name, param)
                    for name, param in weights.items()
                    if any(
                        param_name in name.strip().split(".")
                        for param_name in Config().algorithm.local_layer_names
                    )
                ]
            )

            model_path = Config().params["model_path"]
            model_name = Config().trainer.model_name
            filename = f"{model_path}/{model_name}_{self.client_id}_local_layers.pth"

            self.algorithm.save_local_layers(local_layers, filename)
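The filter above treats a parameter as local when any configured layer name matches one of the dot-separated components of its full parameter name, so "fc" matches "fc.weight" and "fc.bias" but not "conv1.weight". A standalone sketch, with a made-up weights dict and a hypothetical local_layer_names value:

from collections import OrderedDict

# Hypothetical value of Config().algorithm.local_layer_names
local_layer_names = ["fc"]

# Made-up state dict standing in for self.algorithm.extract_weights()
weights = OrderedDict(
    [("conv1.weight", "w0"), ("fc.weight", "w1"), ("fc.bias", "w2")]
)

local_layers = OrderedDict(
    [
        (name, param)
        for name, param in weights.items()
        if any(
            param_name in name.strip().split(".")
            for param_name in local_layer_names
        )
    ]
)

print(list(local_layers))  # ['fc.weight', 'fc.bias']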
plato/clients/mistnet.py
ADDED
@@ -0,0 +1,49 @@
"""
A federated learning client for MistNet.

Reference:

P. Wang, et al. "MistNet: Towards Private Neural Network Training with Local
Differential Privacy," found in docs/papers.
"""

import logging
import time
from types import SimpleNamespace

from plato.config import Config
from plato.clients import simple


class Client(simple.Client):
    """A federated learning client for MistNet."""

    async def _train(self):
        """A MistNet client only uses the first several layers in a forward pass."""
        logging.info("Training on MistNet client #%d", self.client_id)

        # Since training is performed on the server, the client should not be doing
        # its own testing for the model accuracy
        assert not Config().clients.do_test

        tic = time.perf_counter()

        # Perform a forward pass till the cut layer in the model
        features = self.algorithm.extract_features(self.trainset, self.sampler)

        training_time = time.perf_counter() - tic

        # Generate a report for the server, performing model testing if applicable
        comm_time = time.time()
        return (
            SimpleNamespace(
                client_id=self.client_id,
                num_samples=self.sampler.num_samples(),
                accuracy=0,
                training_time=training_time,
                comm_time=comm_time,
                update_response=False,
                payload_length=len(features),
            ),
            features,
        )
plato/clients/registry.py
ADDED
@@ -0,0 +1,43 @@
"""
The registry that contains all available federated learning clients.

Having a registry of all available classes is convenient for retrieving an instance based
on a configuration at run-time.
"""

import logging

from plato.config import Config
from plato.clients import (
    self_supervised_learning,
    simple,
    mistnet,
    fedavg_personalized,
    split_learning,
)

registered_clients = {
    "simple": simple.Client,
    "mistnet": mistnet.Client,
    "fedavg_personalized": fedavg_personalized.Client,
    "self_supervised_learning": self_supervised_learning.Client,
    "split_learning": split_learning.Client,
}


def get(model=None, datasource=None, algorithm=None, trainer=None):
    """Get an instance of the client."""
    if hasattr(Config().clients, "type"):
        client_type = Config().clients.type
    else:
        client_type = Config().algorithm.type

    if client_type in registered_clients:
        logging.info("Client: %s", client_type)
        registered_client = registered_clients[client_type](
            model=model, datasource=datasource, algorithm=algorithm, trainer=trainer
        )
    else:
        raise ValueError(f"No such client: {client_type}")

    return registered_client
plato/clients/self_supervised_learning.py
ADDED
@@ -0,0 +1,51 @@
"""
A self-supervised learning (SSL) client prepares a personalized datasource for
the personalization process, which will be performed after finishing the FL
training process with SSL.

Specifically, the conventional FL training process with SSL will train the model
with the datasource and objective function of SSL. Yet, the datasource used in
personalization should be a supervised learning one. Therefore, a client needs
to prepare the personalized datasource.
"""

from plato.datasources import registry as datasources_registry
from plato.clients import simple


class Client(simple.Client):
    """An SSL client to prepare the datasource for personalization."""

    def __init__(
        self,
        model=None,
        datasource=None,
        algorithm=None,
        trainer=None,
        callbacks=None,
        trainer_callbacks=None,
    ):
        super().__init__(
            model=model,
            datasource=datasource,
            algorithm=algorithm,
            trainer=trainer,
            callbacks=callbacks,
            trainer_callbacks=trainer_callbacks,
        )
        # The datasource used in personalization
        self.personalized_datasource = None

    def configure(self) -> None:
        """Prepare this client for training."""
        super().configure()

        # Get the personalized datasource
        if self.personalized_datasource is None:
            personalized_datasource = datasources_registry.get()

            # Set the train and the test set for the trainer
            self.trainer.set_personalized_datasets(
                personalized_datasource.get_train_set(),
                personalized_datasource.get_test_set(),
            )
plato/clients/simple.py
ADDED
@@ -0,0 +1,218 @@
"""
A basic federated learning client who sends weight updates to the server.
"""

import logging
import time
from types import SimpleNamespace

from plato.algorithms import registry as algorithms_registry
from plato.clients import base
from plato.config import Config
from plato.datasources import registry as datasources_registry
from plato.processors import registry as processor_registry
from plato.samplers import registry as samplers_registry
from plato.trainers import registry as trainers_registry
from plato.utils import fonts


class Client(base.Client):
    """A basic federated learning client who sends simple weight updates."""

    def __init__(
        self,
        model=None,
        datasource=None,
        algorithm=None,
        trainer=None,
        callbacks=None,
        trainer_callbacks=None,
    ):
        super().__init__(callbacks=callbacks)
        # Save the callbacks that will be passed to the trainer later
        self.trainer_callbacks = trainer_callbacks

        self.custom_model = model
        self.model = None

        self.custom_datasource = datasource
        self.datasource = None

        self.custom_algorithm = algorithm
        self.algorithm = None

        self.custom_trainer = trainer
        self.trainer = None

        self.trainset = None  # Training dataset
        self.testset = None  # Testing dataset
        self.sampler = None
        self.testset_sampler = None  # Sampler for the test set

        self._report = None

    def configure(self) -> None:
        """Prepares this client for training."""
        super().configure()

        if self.model is None and self.custom_model is not None:
            self.model = self.custom_model

        if self.trainer is None and self.custom_trainer is None:
            self.trainer = trainers_registry.get(
                model=self.model, callbacks=self.trainer_callbacks
            )
        elif self.trainer is None and self.custom_trainer is not None:
            self.trainer = self.custom_trainer(
                model=self.model, callbacks=self.trainer_callbacks
            )

        self.trainer.set_client_id(self.client_id)

        if self.algorithm is None and self.custom_algorithm is None:
            self.algorithm = algorithms_registry.get(trainer=self.trainer)
        elif self.algorithm is None and self.custom_algorithm is not None:
            self.algorithm = self.custom_algorithm(trainer=self.trainer)

        self.algorithm.set_client_id(self.client_id)

        # Pass inbound and outbound data payloads through processors for
        # additional data processing
        self.outbound_processor, self.inbound_processor = processor_registry.get(
            "Client", client_id=self.client_id, trainer=self.trainer
        )

        # Setting up the data sampler
        if self.datasource:
            self.sampler = samplers_registry.get(self.datasource, self.client_id)

            if (
                hasattr(Config().clients, "do_test")
                and Config().clients.do_test
                and hasattr(Config().data, "testset_sampler")
            ):
                # Set the sampler for the test set
                self.testset_sampler = samplers_registry.get(
                    self.datasource, self.client_id, testing=True
                )

    def _load_data(self) -> None:
        """Generates data and loads them onto this client."""
        # The only case where Config().data.reload_data is set to true is
        # when clients with different client IDs need to load from different datasets,
        # such as in the pre-partitioned Federated EMNIST dataset. We do not support
        # reloading data from a custom datasource at this time.
        if (
            self.datasource is None
            or hasattr(Config().data, "reload_data")
            and Config().data.reload_data
        ):
            logging.info("[%s] Loading its data source...", self)

            if self.custom_datasource is None:
                self.datasource = datasources_registry.get(client_id=self.client_id)
            elif self.custom_datasource is not None:
                self.datasource = self.custom_datasource()

            logging.info(
                "[%s] Dataset size: %s", self, self.datasource.num_train_examples()
            )

    def _allocate_data(self) -> None:
        """Allocate training or testing dataset of this client."""
        # PyTorch uses samplers when loading data with a data loader
        self.trainset = self.datasource.get_train_set()

        if hasattr(Config().clients, "do_test") and Config().clients.do_test:
            # Set the testset if local testing is needed
            self.testset = self.datasource.get_test_set()

    def _load_payload(self, server_payload) -> None:
        """Loads the server model onto this client."""
        self.algorithm.load_weights(server_payload)

    async def _train(self):
        """The machine learning training workload on a client."""
        logging.info(
            fonts.colourize(
                f"[{self}] Started training in communication round #{self.current_round}."
            )
        )

        # Perform model training
        try:
            if hasattr(self.trainer, "current_round"):
                self.trainer.current_round = self.current_round
            training_time = self.trainer.train(self.trainset, self.sampler)

        except ValueError as exc:
            logging.info(
                fonts.colourize(f"[{self}] Error occurred during training: {exc}")
            )
            await self.sio.disconnect()

        # Extract model weights and biases
        weights = self.algorithm.extract_weights()

        # Generate a report for the server, performing model testing if applicable
        if (hasattr(Config().clients, "do_test") and Config().clients.do_test) and (
            not hasattr(Config().clients, "test_interval")
            or self.current_round % Config().clients.test_interval == 0
        ):
            accuracy = self.trainer.test(self.testset, self.testset_sampler)

            if accuracy == -1:
                # The testing process failed, disconnect from the server
                logging.info(
                    fonts.colourize(
                        f"[{self}] Accuracy is -1 when testing. Disconnecting from the server."
                    )
                )
                await self.sio.disconnect()

            if hasattr(Config().trainer, "target_perplexity"):
                logging.info("[%s] Test perplexity: %.2f", self, accuracy)
            else:
                logging.info("[%s] Test accuracy: %.2f%%", self, 100 * accuracy)
        else:
            accuracy = 0

        comm_time = time.time()

        if (
            hasattr(Config().clients, "sleep_simulation")
            and Config().clients.sleep_simulation
        ):
            sleep_seconds = Config().client_sleep_times[self.client_id - 1]
            avg_training_time = Config().clients.avg_training_time

            training_time = (
                avg_training_time + sleep_seconds
            ) * Config().trainer.epochs

        report = SimpleNamespace(
            client_id=self.client_id,
            num_samples=self.sampler.num_samples(),
            accuracy=accuracy,
            training_time=training_time,
            comm_time=comm_time,
            update_response=False,
        )

        self._report = self.customize_report(report)

        return self._report, weights

    async def _obtain_model_update(self, client_id, requested_time):
        """Retrieves a model update corresponding to a particular wall clock time."""
        model = self.trainer.obtain_model_update(client_id, requested_time)
        weights = self.algorithm.extract_weights(model)
        self._report.comm_time = time.time()
        self._report.client_id = client_id
        self._report.update_response = True

        return self._report, weights

    def customize_report(self, report: SimpleNamespace) -> SimpleNamespace:
        """Customizes the report with any additional information."""
        return report
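configure() above resolves each component the same way: keep a component that is already set, otherwise instantiate the user-supplied callable, otherwise fall back to the registry. A standalone sketch of that resolution order for the trainer; DefaultTrainer stands in for whatever trainers_registry.get() would return, and MyTrainer is a hypothetical user-supplied class.

class DefaultTrainer:
    """Stands in for the trainer the registry would construct."""


class MyTrainer:
    """A hypothetical user-supplied trainer class."""

    def __init__(self, model=None, callbacks=None):
        self.model = model
        self.callbacks = callbacks


def resolve_trainer(custom_trainer=None, model=None, callbacks=None):
    # Mirrors the order in configure(): registry default unless a custom
    # callable was supplied, in which case it is instantiated here
    if custom_trainer is None:
        return DefaultTrainer()
    return custom_trainer(model=model, callbacks=callbacks)


print(type(resolve_trainer()).__name__)           # DefaultTrainer
print(type(resolve_trainer(MyTrainer)).__name__)  # MyTrainer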
plato/clients/split_learning.py
ADDED
@@ -0,0 +1,150 @@
"""
A federated learning client using split learning.

Reference:

Vepakomma, et al., "Split Learning for Health: Distributed Deep Learning without Sharing
Raw Patient Data," in Proc. AI for Social Good Workshop, affiliated with ICLR 2018.

https://arxiv.org/pdf/1812.00564.pdf

Chopra, Ayush, et al. "AdaSplit: Adaptive Trade-offs for Resource-constrained Distributed
Deep Learning." arXiv preprint arXiv:2112.01637 (2021).

https://arxiv.org/pdf/2112.01637.pdf
"""

import logging
import time
from types import SimpleNamespace

from plato.clients import simple
from plato.config import Config
from plato.utils import fonts


class Client(simple.Client):
    """A split learning client."""

    # pylint:disable=too-many-arguments
    def __init__(
        self,
        model=None,
        datasource=None,
        algorithm=None,
        trainer=None,
        callbacks=None,
    ):
        super().__init__(
            model=model,
            datasource=datasource,
            algorithm=algorithm,
            trainer=trainer,
            callbacks=callbacks,
        )
        assert not Config().clients.do_test

        self.model_received = False
        self.gradient_received = False
        self.contexts = {}
        self.original_weights = None

        # Iteration control
        self.iterations = Config().clients.iteration
        self.iter_left = Config().clients.iteration

        # The sampler cannot be reconfigured, otherwise the same training samples
        # will be selected every round
        self.static_sampler = None

    async def inbound_processed(self, processed_inbound_payload):
        """Extract features or complete the training using split learning."""
        server_payload, info = processed_inbound_payload

        # Preparing the client response
        report, payload = None, None

        if info == "prompt":
            # The server prompts a new client to conduct split learning
            self._load_context(self.client_id)
            report, payload = self._extract_features()
        elif info == "gradients":
            # The server sends the gradients of the features, i.e., complete training
            logging.warning("[%s] Gradients received, complete training.", self)
            training_time, weights = self._complete_training(server_payload)
            self.iter_left -= 1

            if self.iter_left == 0:
                logging.warning(
                    "[%s] Finished training, sending weights to the server.", self
                )
                # Send weights to the server for evaluation
                report = SimpleNamespace(
                    client_id=self.client_id,
                    num_samples=self.sampler.num_samples(),
                    accuracy=0,
                    training_time=training_time,
                    comm_time=time.time(),
                    update_response=False,
                    type="weights",
                )
                payload = weights
                self.iter_left = self.iterations
            else:
                # Continue feature extraction
                report, payload = self._extract_features()
                report.training_time += training_time

        # Save the state of the current client
        self._save_context(self.client_id)
        return report, payload

    def _save_context(self, client_id):
        """Save the extracted weights and the data sampler for a given client."""
        # The sampler needs to be saved, otherwise the same data samples will be
        # selected every round
        self.contexts[client_id] = (
            self.algorithm.extract_weights(),
            self.static_sampler,
        )

    def _load_context(self, client_id):
        """Load the client's model weights and the sampler from the last selected round."""
        if client_id not in self.contexts:
            if self.original_weights is None:
                self.original_weights = self.algorithm.extract_weights()
            self.algorithm.load_weights(self.original_weights)
            self.static_sampler = self.sampler.get()
        else:
            weights, sampler = self.contexts.pop(client_id)
            self.algorithm.load_weights(weights)
            self.static_sampler = sampler

    def _extract_features(self):
        """Extract the features up to the cut layer."""
        round_number = self.iterations - self.iter_left + 1
        logging.warning(
            fonts.colourize(
                f"[{self}] Started split learning in round #{round_number}/{self.iterations}"
                + f" (Global round {self.current_round})."
            )
        )

        features, training_time = self.algorithm.extract_features(
            self.trainset, self.static_sampler
        )
        report = SimpleNamespace(
            client_id=self.client_id,
            num_samples=self.sampler.num_samples(),
            accuracy=0,
            training_time=training_time,
            comm_time=time.time(),
            update_response=False,
            type="features",
        )
        return report, features

    def _complete_training(self, payload):
        """Complete the training based on the gradients from the server."""
        training_time = self.algorithm.complete_train(payload)
        weights = self.algorithm.extract_weights()
        return training_time, weights
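Putting inbound_processed together: a "prompt" restores the client's context and returns features; each "gradients" message completes one training step and decrements iter_left; only when iter_left reaches zero are the full weights returned and the counter reset. A toy trace of that bookkeeping, assuming Config().clients.iteration is 3; payload types only, no real training:

iterations = 3  # assumed value of Config().clients.iteration
iter_left = iterations

responses = []
for _ in range(iterations):
    # Each "gradients" message from the server completes one training step
    iter_left -= 1
    if iter_left == 0:
        responses.append("weights")   # final step: send weights for evaluation
        iter_left = iterations        # reset for the next global round
    else:
        responses.append("features")  # otherwise keep extracting features

print(responses)  # ['features', 'features', 'weights']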