plato-learn 1.1 (plato_learn-1.1-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- plato/__init__.py +1 -0
- plato/algorithms/__init__.py +0 -0
- plato/algorithms/base.py +45 -0
- plato/algorithms/fedavg.py +48 -0
- plato/algorithms/fedavg_gan.py +79 -0
- plato/algorithms/fedavg_personalized.py +48 -0
- plato/algorithms/mistnet.py +52 -0
- plato/algorithms/registry.py +39 -0
- plato/algorithms/split_learning.py +89 -0
- plato/callbacks/__init__.py +0 -0
- plato/callbacks/client.py +56 -0
- plato/callbacks/handler.py +78 -0
- plato/callbacks/server.py +139 -0
- plato/callbacks/trainer.py +124 -0
- plato/client.py +67 -0
- plato/clients/__init__.py +0 -0
- plato/clients/base.py +467 -0
- plato/clients/edge.py +103 -0
- plato/clients/fedavg_personalized.py +40 -0
- plato/clients/mistnet.py +49 -0
- plato/clients/registry.py +43 -0
- plato/clients/self_supervised_learning.py +51 -0
- plato/clients/simple.py +218 -0
- plato/clients/split_learning.py +150 -0
- plato/config.py +339 -0
- plato/datasources/__init__.py +0 -0
- plato/datasources/base.py +123 -0
- plato/datasources/celeba.py +150 -0
- plato/datasources/cifar10.py +87 -0
- plato/datasources/cifar100.py +61 -0
- plato/datasources/cinic10.py +62 -0
- plato/datasources/coco.py +119 -0
- plato/datasources/datalib/__init__.py +0 -0
- plato/datasources/datalib/audio_extraction_tools.py +137 -0
- plato/datasources/datalib/data_utils.py +124 -0
- plato/datasources/datalib/flickr30kE_utils.py +336 -0
- plato/datasources/datalib/frames_extraction_tools.py +254 -0
- plato/datasources/datalib/gym_utils/__init__.py +0 -0
- plato/datasources/datalib/gym_utils/gym_trim.py +189 -0
- plato/datasources/datalib/modality_data_anntation_tools.py +163 -0
- plato/datasources/datalib/modality_extraction_base.py +59 -0
- plato/datasources/datalib/parse_datasets.py +212 -0
- plato/datasources/datalib/refer_utils/__init__.py +0 -0
- plato/datasources/datalib/refer_utils/referitgame_utils.py +237 -0
- plato/datasources/datalib/tiny_data_tools.py +81 -0
- plato/datasources/datalib/video_transform.py +79 -0
- plato/datasources/emnist.py +64 -0
- plato/datasources/fashion_mnist.py +41 -0
- plato/datasources/feature.py +24 -0
- plato/datasources/feature_dataset.py +15 -0
- plato/datasources/femnist.py +141 -0
- plato/datasources/flickr30k_entities.py +362 -0
- plato/datasources/gym.py +431 -0
- plato/datasources/huggingface.py +165 -0
- plato/datasources/kinetics.py +568 -0
- plato/datasources/mnist.py +44 -0
- plato/datasources/multimodal_base.py +328 -0
- plato/datasources/pascal_voc.py +56 -0
- plato/datasources/purchase.py +94 -0
- plato/datasources/qoenflx.py +127 -0
- plato/datasources/referitgame.py +330 -0
- plato/datasources/registry.py +119 -0
- plato/datasources/self_supervised_learning.py +98 -0
- plato/datasources/stl10.py +103 -0
- plato/datasources/texas.py +94 -0
- plato/datasources/tiny_imagenet.py +64 -0
- plato/datasources/yolov8.py +85 -0
- plato/models/__init__.py +0 -0
- plato/models/cnn_encoder.py +103 -0
- plato/models/dcgan.py +116 -0
- plato/models/general_multilayer.py +254 -0
- plato/models/huggingface.py +27 -0
- plato/models/lenet5.py +113 -0
- plato/models/multilayer.py +90 -0
- plato/models/multimodal/__init__.py +0 -0
- plato/models/multimodal/base_net.py +91 -0
- plato/models/multimodal/blending.py +142 -0
- plato/models/multimodal/fc_net.py +77 -0
- plato/models/multimodal/fusion_net.py +78 -0
- plato/models/multimodal/multimodal_module.py +152 -0
- plato/models/registry.py +99 -0
- plato/models/resnet.py +190 -0
- plato/models/torch_hub.py +19 -0
- plato/models/vgg.py +113 -0
- plato/models/vit.py +166 -0
- plato/models/yolov8.py +22 -0
- plato/processors/__init__.py +0 -0
- plato/processors/base.py +35 -0
- plato/processors/compress.py +46 -0
- plato/processors/decompress.py +48 -0
- plato/processors/feature.py +51 -0
- plato/processors/feature_additive_noise.py +48 -0
- plato/processors/feature_dequantize.py +34 -0
- plato/processors/feature_gaussian.py +17 -0
- plato/processors/feature_laplace.py +15 -0
- plato/processors/feature_quantize.py +34 -0
- plato/processors/feature_randomized_response.py +50 -0
- plato/processors/feature_unbatch.py +39 -0
- plato/processors/inbound_feature_tensors.py +39 -0
- plato/processors/model.py +55 -0
- plato/processors/model_compress.py +34 -0
- plato/processors/model_decompress.py +37 -0
- plato/processors/model_decrypt.py +41 -0
- plato/processors/model_deepcopy.py +21 -0
- plato/processors/model_dequantize.py +18 -0
- plato/processors/model_dequantize_qsgd.py +61 -0
- plato/processors/model_encrypt.py +43 -0
- plato/processors/model_quantize.py +18 -0
- plato/processors/model_quantize_qsgd.py +82 -0
- plato/processors/model_randomized_response.py +34 -0
- plato/processors/outbound_feature_ndarrays.py +38 -0
- plato/processors/pipeline.py +26 -0
- plato/processors/registry.py +124 -0
- plato/processors/structured_pruning.py +57 -0
- plato/processors/unstructured_pruning.py +73 -0
- plato/samplers/__init__.py +0 -0
- plato/samplers/all_inclusive.py +41 -0
- plato/samplers/base.py +31 -0
- plato/samplers/dirichlet.py +81 -0
- plato/samplers/distribution_noniid.py +132 -0
- plato/samplers/iid.py +53 -0
- plato/samplers/label_quantity_noniid.py +119 -0
- plato/samplers/mixed.py +44 -0
- plato/samplers/mixed_label_quantity_noniid.py +128 -0
- plato/samplers/modality_iid.py +42 -0
- plato/samplers/modality_quantity_noniid.py +56 -0
- plato/samplers/orthogonal.py +99 -0
- plato/samplers/registry.py +66 -0
- plato/samplers/sample_quantity_noniid.py +123 -0
- plato/samplers/sampler_utils.py +190 -0
- plato/servers/__init__.py +0 -0
- plato/servers/base.py +1395 -0
- plato/servers/fedavg.py +281 -0
- plato/servers/fedavg_cs.py +335 -0
- plato/servers/fedavg_gan.py +74 -0
- plato/servers/fedavg_he.py +106 -0
- plato/servers/fedavg_personalized.py +57 -0
- plato/servers/mistnet.py +67 -0
- plato/servers/registry.py +52 -0
- plato/servers/split_learning.py +109 -0
- plato/trainers/__init__.py +0 -0
- plato/trainers/base.py +99 -0
- plato/trainers/basic.py +649 -0
- plato/trainers/diff_privacy.py +178 -0
- plato/trainers/gan.py +330 -0
- plato/trainers/huggingface.py +173 -0
- plato/trainers/loss_criterion.py +70 -0
- plato/trainers/lr_schedulers.py +252 -0
- plato/trainers/optimizers.py +53 -0
- plato/trainers/pascal_voc.py +80 -0
- plato/trainers/registry.py +44 -0
- plato/trainers/self_supervised_learning.py +302 -0
- plato/trainers/split_learning.py +305 -0
- plato/trainers/tracking.py +96 -0
- plato/trainers/yolov8.py +41 -0
- plato/utils/__init__.py +0 -0
- plato/utils/count_parameters.py +30 -0
- plato/utils/csv_processor.py +26 -0
- plato/utils/data_loaders.py +148 -0
- plato/utils/decorators.py +24 -0
- plato/utils/fonts.py +23 -0
- plato/utils/homo_enc.py +187 -0
- plato/utils/reinforcement_learning/__init__.py +0 -0
- plato/utils/reinforcement_learning/policies/__init__.py +0 -0
- plato/utils/reinforcement_learning/policies/base.py +161 -0
- plato/utils/reinforcement_learning/policies/ddpg.py +75 -0
- plato/utils/reinforcement_learning/policies/registry.py +32 -0
- plato/utils/reinforcement_learning/policies/sac.py +343 -0
- plato/utils/reinforcement_learning/policies/td3.py +485 -0
- plato/utils/reinforcement_learning/rl_agent.py +142 -0
- plato/utils/reinforcement_learning/rl_server.py +113 -0
- plato/utils/rl_env.py +154 -0
- plato/utils/s3.py +141 -0
- plato/utils/trainer_utils.py +21 -0
- plato/utils/unary_encoding.py +47 -0
- plato_learn-1.1.dist-info/METADATA +35 -0
- plato_learn-1.1.dist-info/RECORD +179 -0
- plato_learn-1.1.dist-info/WHEEL +4 -0
- plato_learn-1.1.dist-info/licenses/LICENSE +201 -0
plato/servers/fedavg_he.py
ADDED
@@ -0,0 +1,106 @@
"""
A federated learning server using federated averaging to aggregate updates after homomorphic encryption.
"""

from functools import reduce
from plato.servers import fedavg
from plato.utils import homo_enc


class Server(fedavg.Server):
    """
    Federated learning server using federated averaging to aggregate updates after homomorphic
    encryption.
    """

    def __init__(
        self, model=None, datasource=None, algorithm=None, trainer=None, callbacks=None
    ):
        super().__init__(
            model=model,
            datasource=datasource,
            algorithm=algorithm,
            trainer=trainer,
            callbacks=callbacks,
        )

        self.context = homo_enc.get_ckks_context()
        self.encrypted_model = None
        self.weight_shapes = {}
        self.para_nums = {}

    def configure(self) -> None:
        """Configure the model information like weight shapes and parameter numbers."""
        super().configure()

        extract_model = self.trainer.model.cpu().state_dict()

        for key in extract_model.keys():
            self.weight_shapes[key] = extract_model[key].size()
            self.para_nums[key] = extract_model[key].numel()

        self.encrypted_model = homo_enc.encrypt_weights(
            extract_model, True, self.context, []
        )

    def customize_server_payload(self, payload):
        """Server can only send the encrypted aggregation result to clients."""
        return self.encrypted_model

    # pylint: disable=unused-argument
    async def aggregate_weights(self, updates, baseline_weights, weights_received):
        """Aggregate the model updates and decrypt the result for evaluation purposes."""
        self.encrypted_model = self._fedavg_hybrid(updates)

        # Decrypt model weights for test accuracy
        decrypted_weights = homo_enc.decrypt_weights(
            self.encrypted_model, self.weight_shapes, self.para_nums
        )
        # Serialize the encrypted weights after decryption
        self.encrypted_model["encrypted_weights"] = self.encrypted_model[
            "encrypted_weights"
        ].serialize()

        return decrypted_weights

    def _fedavg_hybrid(self, updates):
        """Aggregate the model updates in the hybrid form of encrypted and unencrypted weights."""
        weights_received = [
            homo_enc.deserialize_weights(update.payload, self.context)
            for update in updates
        ]
        unencrypted_weights = [
            homo_enc.extract_encrypted_model(x)[0] for x in weights_received
        ]
        encrypted_weights = [
            homo_enc.extract_encrypted_model(x)[1] for x in weights_received
        ]
        # Assert the encrypted weights from all clients are aligned
        indices = [homo_enc.extract_encrypted_model(x)[2] for x in weights_received]
        for i in range(1, len(indices)):
            assert indices[i] == indices[0]
        encrypt_indices = indices[0]

        # Extract the total number of samples
        self.total_samples = sum(update.report.num_samples for update in updates)

        # Perform weighted averaging on unencrypted weights
        unencrypted_avg_update = self.trainer.zeros(unencrypted_weights[0].size)
        encrypted_avg_update = self.trainer.zeros(encrypted_weights[0].size())

        for i, (unenc_w, enc_w) in enumerate(
            zip(unencrypted_weights, encrypted_weights)
        ):
            report = updates[i].report
            num_samples = report.num_samples

            unencrypted_avg_update += unenc_w * (num_samples / self.total_samples)
            encrypted_avg_update += enc_w * (num_samples / self.total_samples)

        if len(encrypt_indices) == 0:
            # No weights are encrypted, set to None
            encrypted_avg_update = None

        return homo_enc.wrap_encrypted_model(
            unencrypted_avg_update, encrypted_avg_update, encrypt_indices
        )
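The aggregation above averages ciphertexts directly, which works because CKKS supports addition between ciphertexts and multiplication by plaintext scalars. A standalone sketch of that weighted average follows, assuming a TenSEAL backend (suggested by the get_ckks_context call, but not confirmed by this diff); the context parameters and client values are illustrative, not Plato's:

# Standalone sketch: homomorphic weighted averaging with TenSEAL CKKS.
# All parameters and values below are illustrative assumptions.
import tenseal as ts

context = ts.context(
    ts.SCHEME_TYPE.CKKS,
    poly_modulus_degree=8192,
    coeff_mod_bit_sizes=[60, 40, 40, 60],
)
context.global_scale = 2**40

# Two clients' flattened weight vectors, encrypted client-side.
client_vectors = [
    ts.ckks_vector(context, [0.1, 0.2]),
    ts.ckks_vector(context, [0.3, 0.4]),
]
num_samples = [60, 40]
total = sum(num_samples)

# Ciphertext * plaintext scalar, then ciphertext + ciphertext:
# exactly the operations _fedavg_hybrid relies on.
avg = client_vectors[0] * (num_samples[0] / total)
avg += client_vectors[1] * (num_samples[1] / total)

print(avg.decrypt())  # approximately [0.18, 0.28]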
plato/servers/fedavg_personalized.py
ADDED
@@ -0,0 +1,57 @@
"""
A personalized federated learning server that starts from a number of regular
rounds of federated learning. In these regular rounds, only a subset of the
total clients can be selected to perform the local update (the ratio of which is
a configuration setting). After all regular rounds are completed, it starts a
final round of personalization, where a selected subset of clients perform local
training using their local dataset.
"""

from plato.servers import fedavg
from plato.config import Config


class Server(fedavg.Server):
    """
    A personalized FL server that controls how many clients will participate in
    the training process, and that adds a final personalization round with all
    clients sampled.
    """

    def __init__(
        self, model=None, datasource=None, algorithm=None, trainer=None, callbacks=None
    ):
        super().__init__(
            model=model,
            datasource=datasource,
            algorithm=algorithm,
            trainer=trainer,
            callbacks=callbacks,
        )
        # Personalization starts after the final regular round of training
        self.personalization_started = False

    def choose_clients(self, clients_pool, clients_count):
        """Choose a subset of the clients to participate in each round."""
        if self.current_round > Config().trainer.rounds:
            # In the final personalization round, choose from all clients
            return super().choose_clients(clients_pool, clients_count)
        else:
            ratio = Config().algorithm.personalization.participating_client_ratio

            return super().choose_clients(
                clients_pool[: int(self.total_clients * ratio)],
                clients_count,
            )

    async def wrap_up(self) -> None:
        """Wraps up when each round of training is done."""
        if self.personalization_started:
            await super().wrap_up()
        else:
            # If the target number of training rounds has been reached, start
            # the final round of training for personalization on the clients
            self.save_to_checkpoint()

            if self.current_round >= Config().trainer.rounds:
                self.personalization_started = True
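To make the schedule concrete, here is a toy restatement of the selection rule above; the numbers are hypothetical, and random sampling stands in for super().choose_clients:

# Toy illustration of the selection schedule (values are hypothetical).
import random

total_clients, ratio, rounds = 10, 0.4, 3
clients_pool = list(range(1, total_clients + 1))

for current_round in range(1, rounds + 2):  # regular rounds + 1 personalization round
    if current_round > rounds:
        pool = clients_pool  # personalization: every client is eligible
    else:
        pool = clients_pool[: int(total_clients * ratio)]  # first 40% only
    print(current_round, sorted(random.sample(pool, 2)))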
plato/servers/mistnet.py
ADDED
@@ -0,0 +1,67 @@
"""
A federated learning server for MistNet.

Reference:
P. Wang, et al. "MistNet: Towards Private Neural Network Training with Local
Differential Privacy," found in docs/papers.
"""

import logging
import os

from plato.config import Config
from plato.datasources import feature
from plato.samplers import all_inclusive
from plato.servers import fedavg


class Server(fedavg.Server):
    """The MistNet server for federated learning."""

    def __init__(
        self, model=None, datasource=None, algorithm=None, trainer=None, callbacks=None
    ):
        super().__init__(
            model=model,
            datasource=datasource,
            algorithm=algorithm,
            trainer=trainer,
            callbacks=callbacks,
        )

        # MistNet requires one round of client-server communication
        assert Config().trainer.rounds == 1

    def init_trainer(self) -> None:
        """Setting up a pre-trained model to be loaded on the server."""
        super().init_trainer()

        model_path = Config().params["model_path"]
        model_file_name = (
            Config().trainer.pretrained_model
            if hasattr(Config().trainer, "pretrained_model")
            else f"{Config().trainer.model_name}.pth"
        )
        pretrained_model_path = f"{model_path}/{model_file_name}"

        if os.path.exists(pretrained_model_path):
            logging.info("[Server #%d] Loading a pre-trained model.", os.getpid())
            self.trainer.load_model(filename=model_file_name)

    async def _process_reports(self):
        """Process the features extracted by the client and perform server-side training."""
        features = [update.payload for update in self.updates]
        feature_dataset = feature.DataSource(features)

        # Training the model using all the features received from the client
        sampler = all_inclusive.Sampler(feature_dataset)
        self.algorithm.train(feature_dataset, sampler)

        # Test the updated model
        if not hasattr(Config().server, "do_test") or Config().server.do_test:
            self.accuracy = self.trainer.test(self.testset)
            logging.info(
                "[%s] Global model accuracy: %.2f%%\n", self, 100 * self.accuracy
            )

        self.clients_processed()
plato/servers/registry.py
ADDED
@@ -0,0 +1,52 @@
"""
The registry for servers that contains framework-agnostic implementations of a
federated learning server.

Having a registry of all available classes is convenient for retrieving an
instance based on a configuration at run-time.
"""

import logging

from plato.config import Config

from plato.servers import (
    fedavg,
    fedavg_cs,
    mistnet,
    fedavg_gan,
    fedavg_personalized,
    split_learning,
)

if hasattr(Config().server, "type") and Config().server.type == "fedavg_he":
    # FedAvg server with homomorphic encryption supports PyTorch only
    from plato.servers import fedavg_he

    registered_servers = {"fedavg_he": fedavg_he.Server}

else:
    registered_servers = {
        "fedavg": fedavg.Server,
        "fedavg_cross_silo": fedavg_cs.Server,
        "mistnet": mistnet.Server,
        "fedavg_gan": fedavg_gan.Server,
        "fedavg_personalized": fedavg_personalized.Server,
        "split_learning": split_learning.Server,
    }


def get(model=None, algorithm=None, trainer=None):
    """Get an instance of the server."""
    if hasattr(Config().server, "type"):
        server_type = Config().server.type
    else:
        server_type = Config().algorithm.type

    if server_type in registered_servers:
        logging.info("Server: %s", server_type)
        return registered_servers[server_type](
            model=model, algorithm=algorithm, trainer=trainer
        )
    else:
        raise ValueError(f"No such server: {server_type}")
plato/servers/split_learning.py
ADDED
@@ -0,0 +1,109 @@
"""
A federated learning server using split learning.

Reference:

Vepakomma, et al., "Split Learning for Health: Distributed Deep Learning without Sharing
Raw Patient Data," in Proc. AI for Social Good Workshop, affiliated with ICLR 2018.

https://arxiv.org/pdf/1812.00564.pdf

Chopra, Ayush, et al. "AdaSplit: Adaptive Trade-offs for Resource-constrained Distributed
Deep Learning." arXiv preprint arXiv:2112.01637 (2021).

https://arxiv.org/pdf/2112.01637.pdf
"""

import logging

from plato.config import Config
from plato.datasources import feature
from plato.samplers import all_inclusive
from plato.servers import fedavg
from plato.utils import fonts
from plato.datasources import registry as datasources_registry


# pylint:disable=too-many-instance-attributes
class Server(fedavg.Server):
    """The split learning server."""

    # pylint:disable=too-many-arguments
    def __init__(
        self, model=None, datasource=None, algorithm=None, trainer=None, callbacks=None
    ):
        super().__init__(model, datasource, algorithm, trainer, callbacks)
        # Split learning clients interact with the server sequentially
        assert Config().clients.per_round == 1
        self.phase = "prompt"
        self.clients_list = []
        self.client_last = None
        self.next_client = True
        self.test_accuracy = 0.0

        # Manually set up the testset since do_test is turned off in config
        if self.datasource is None and self.custom_datasource is None:
            self.datasource = datasources_registry.get(client_id=0)
        elif self.datasource is None and self.custom_datasource is not None:
            self.datasource = self.custom_datasource()
        self.testset = self.datasource.get_test_set()
        self.testset_sampler = all_inclusive.Sampler(self.datasource, testing=True)

    def choose_clients(self, clients_pool, clients_count):
        """Shuffle the clients and sequentially select them when the previous one is done."""
        if len(self.clients_list) == 0 and self.next_client:
            # Shuffle the client list
            self.clients_list = super().choose_clients(clients_pool, len(clients_pool))
            logging.warning("Client order: %s", str(self.clients_list))

        if self.next_client:
            # Sequentially select clients
            self.client_last = [self.clients_list.pop(0)]
            self.next_client = False
        return self.client_last

    def customize_server_payload(self, payload):
        """Wrap up generating the server payload with any additional information."""
        if self.phase == "prompt":
            # Split learning server doesn't send weights to the client
            return (None, "prompt")
        return (self.trainer.get_gradients(), "gradients")

    # pylint: disable=unused-argument
    async def aggregate_weights(self, updates, baseline_weights, weights_received):
        """Aggregate weight updates from the clients or train the model."""
        update = updates[0]
        report = update.report
        if report.type == "features":
            logging.warning("[%s] Features received, compute gradients.", self)
            feature_dataset = feature.DataSource([update.payload])

            # Training the model using all the features received from the client
            sampler = all_inclusive.Sampler(feature_dataset)
            self.algorithm.train(feature_dataset, sampler)

            self.phase = "gradient"
        elif report.type == "weights":
            logging.warning("[%s] Weights received, start testing accuracy.", self)
            weights = update.payload

            # The weights after the cut layer are not trained by clients
            self.algorithm.update_weights_before_cut(weights)

            self.test_accuracy = self.trainer.test(self.testset, self.testset_sampler)

            logging.warning(
                fonts.colourize(
                    f"[{self}] Global model accuracy: {100 * self.test_accuracy:.2f}%\n"
                )
            )
            self.phase = "prompt"
            # Change to a new client in the next round
            self.next_client = True

        updated_weights = self.algorithm.extract_weights()
        return updated_weights

    def clients_processed(self):
        # Replace the default accuracy with the manually tested accuracy
        self.accuracy = self.test_accuracy
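A toy restatement of the prompt/gradient alternation implemented above; the function is illustrative only and does not mirror Plato's wire format:

# Illustrative state machine for one client's split-learning exchange.
def server_step(report_type):
    """Return (payload sent back, next server phase) for a client report."""
    if report_type == "features":
        # Train the layers after the cut, then return gradients at the cut.
        return "gradients", "gradient"
    if report_type == "weights":
        # Absorb client-trained layers, test the model, move to the next client.
        return "prompt", "prompt"
    raise ValueError(report_type)

assert server_step("features") == ("gradients", "gradient")
assert server_step("weights") == ("prompt", "prompt")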
plato/trainers/__init__.py
File without changes
plato/trainers/base.py
ADDED
@@ -0,0 +1,99 @@
"""
Base class for trainers.
"""

from abc import ABC, abstractmethod
import os

from plato.config import Config


class Trainer(ABC):
    """Base class for all the trainers."""

    def __init__(self):
        self.device = Config().device()
        self.client_id = 0

    def set_client_id(self, client_id):
        """Setting the client ID."""
        self.client_id = client_id

    @abstractmethod
    def save_model(self, filename=None, location=None):
        """Saving the model to a file."""
        raise TypeError("save_model() not implemented.")

    @abstractmethod
    def load_model(self, filename=None, location=None):
        """Loading pre-trained model weights from a file."""
        raise TypeError("load_model() not implemented.")

    @staticmethod
    def save_accuracy(accuracy, filename=None):
        """Saving the test accuracy to a file."""
        model_path = Config().params["model_path"]
        model_name = Config().trainer.model_name

        if not os.path.exists(model_path):
            os.makedirs(model_path)

        if filename is not None:
            accuracy_path = f"{model_path}/{filename}"
        else:
            accuracy_path = f"{model_path}/{model_name}.acc"

        with open(accuracy_path, "w", encoding="utf-8") as file:
            file.write(str(accuracy))

    @staticmethod
    def load_accuracy(filename=None):
        """Loading the test accuracy from a file."""
        model_path = Config().params["model_path"]
        model_name = Config().trainer.model_name

        if filename is not None:
            accuracy_path = f"{model_path}/{filename}"
        else:
            accuracy_path = f"{model_path}/{model_name}.acc"

        with open(accuracy_path, "r", encoding="utf-8") as file:
            accuracy = float(file.read())

        return accuracy

    def pause_training(self):
        """Remove files of running trainers."""
        if hasattr(Config().trainer, "max_concurrency"):
            model_name = Config().trainer.model_name
            model_path = Config().params["model_path"]
            model_file = f"{model_path}/{model_name}_{self.client_id}_{Config().params['run_id']}.pth"
            accuracy_file = f"{model_path}/{model_name}_{self.client_id}_{Config().params['run_id']}.acc"

            if os.path.exists(model_file):
                os.remove(model_file)
                os.remove(model_file + ".pkl")

            if os.path.exists(accuracy_file):
                os.remove(accuracy_file)

    @abstractmethod
    def train(self, trainset, sampler, **kwargs) -> float:
        """The main training loop in a federated learning workload.

        Arguments:
        trainset: The training dataset.
        sampler: The sampler that extracts a partition for this client.

        Returns:
        float: The training time.
        """

    @abstractmethod
    def test(self, testset, sampler=None, **kwargs) -> float:
        """Testing the model using the provided test dataset.

        Arguments:
        testset: The test dataset.
        sampler: The sampler that extracts a partition of the test dataset.
        """