plato-learn 1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- plato/__init__.py +1 -0
- plato/algorithms/__init__.py +0 -0
- plato/algorithms/base.py +45 -0
- plato/algorithms/fedavg.py +48 -0
- plato/algorithms/fedavg_gan.py +79 -0
- plato/algorithms/fedavg_personalized.py +48 -0
- plato/algorithms/mistnet.py +52 -0
- plato/algorithms/registry.py +39 -0
- plato/algorithms/split_learning.py +89 -0
- plato/callbacks/__init__.py +0 -0
- plato/callbacks/client.py +56 -0
- plato/callbacks/handler.py +78 -0
- plato/callbacks/server.py +139 -0
- plato/callbacks/trainer.py +124 -0
- plato/client.py +67 -0
- plato/clients/__init__.py +0 -0
- plato/clients/base.py +467 -0
- plato/clients/edge.py +103 -0
- plato/clients/fedavg_personalized.py +40 -0
- plato/clients/mistnet.py +49 -0
- plato/clients/registry.py +43 -0
- plato/clients/self_supervised_learning.py +51 -0
- plato/clients/simple.py +218 -0
- plato/clients/split_learning.py +150 -0
- plato/config.py +339 -0
- plato/datasources/__init__.py +0 -0
- plato/datasources/base.py +123 -0
- plato/datasources/celeba.py +150 -0
- plato/datasources/cifar10.py +87 -0
- plato/datasources/cifar100.py +61 -0
- plato/datasources/cinic10.py +62 -0
- plato/datasources/coco.py +119 -0
- plato/datasources/datalib/__init__.py +0 -0
- plato/datasources/datalib/audio_extraction_tools.py +137 -0
- plato/datasources/datalib/data_utils.py +124 -0
- plato/datasources/datalib/flickr30kE_utils.py +336 -0
- plato/datasources/datalib/frames_extraction_tools.py +254 -0
- plato/datasources/datalib/gym_utils/__init__.py +0 -0
- plato/datasources/datalib/gym_utils/gym_trim.py +189 -0
- plato/datasources/datalib/modality_data_anntation_tools.py +163 -0
- plato/datasources/datalib/modality_extraction_base.py +59 -0
- plato/datasources/datalib/parse_datasets.py +212 -0
- plato/datasources/datalib/refer_utils/__init__.py +0 -0
- plato/datasources/datalib/refer_utils/referitgame_utils.py +237 -0
- plato/datasources/datalib/tiny_data_tools.py +81 -0
- plato/datasources/datalib/video_transform.py +79 -0
- plato/datasources/emnist.py +64 -0
- plato/datasources/fashion_mnist.py +41 -0
- plato/datasources/feature.py +24 -0
- plato/datasources/feature_dataset.py +15 -0
- plato/datasources/femnist.py +141 -0
- plato/datasources/flickr30k_entities.py +362 -0
- plato/datasources/gym.py +431 -0
- plato/datasources/huggingface.py +165 -0
- plato/datasources/kinetics.py +568 -0
- plato/datasources/mnist.py +44 -0
- plato/datasources/multimodal_base.py +328 -0
- plato/datasources/pascal_voc.py +56 -0
- plato/datasources/purchase.py +94 -0
- plato/datasources/qoenflx.py +127 -0
- plato/datasources/referitgame.py +330 -0
- plato/datasources/registry.py +119 -0
- plato/datasources/self_supervised_learning.py +98 -0
- plato/datasources/stl10.py +103 -0
- plato/datasources/texas.py +94 -0
- plato/datasources/tiny_imagenet.py +64 -0
- plato/datasources/yolov8.py +85 -0
- plato/models/__init__.py +0 -0
- plato/models/cnn_encoder.py +103 -0
- plato/models/dcgan.py +116 -0
- plato/models/general_multilayer.py +254 -0
- plato/models/huggingface.py +27 -0
- plato/models/lenet5.py +113 -0
- plato/models/multilayer.py +90 -0
- plato/models/multimodal/__init__.py +0 -0
- plato/models/multimodal/base_net.py +91 -0
- plato/models/multimodal/blending.py +142 -0
- plato/models/multimodal/fc_net.py +77 -0
- plato/models/multimodal/fusion_net.py +78 -0
- plato/models/multimodal/multimodal_module.py +152 -0
- plato/models/registry.py +99 -0
- plato/models/resnet.py +190 -0
- plato/models/torch_hub.py +19 -0
- plato/models/vgg.py +113 -0
- plato/models/vit.py +166 -0
- plato/models/yolov8.py +22 -0
- plato/processors/__init__.py +0 -0
- plato/processors/base.py +35 -0
- plato/processors/compress.py +46 -0
- plato/processors/decompress.py +48 -0
- plato/processors/feature.py +51 -0
- plato/processors/feature_additive_noise.py +48 -0
- plato/processors/feature_dequantize.py +34 -0
- plato/processors/feature_gaussian.py +17 -0
- plato/processors/feature_laplace.py +15 -0
- plato/processors/feature_quantize.py +34 -0
- plato/processors/feature_randomized_response.py +50 -0
- plato/processors/feature_unbatch.py +39 -0
- plato/processors/inbound_feature_tensors.py +39 -0
- plato/processors/model.py +55 -0
- plato/processors/model_compress.py +34 -0
- plato/processors/model_decompress.py +37 -0
- plato/processors/model_decrypt.py +41 -0
- plato/processors/model_deepcopy.py +21 -0
- plato/processors/model_dequantize.py +18 -0
- plato/processors/model_dequantize_qsgd.py +61 -0
- plato/processors/model_encrypt.py +43 -0
- plato/processors/model_quantize.py +18 -0
- plato/processors/model_quantize_qsgd.py +82 -0
- plato/processors/model_randomized_response.py +34 -0
- plato/processors/outbound_feature_ndarrays.py +38 -0
- plato/processors/pipeline.py +26 -0
- plato/processors/registry.py +124 -0
- plato/processors/structured_pruning.py +57 -0
- plato/processors/unstructured_pruning.py +73 -0
- plato/samplers/__init__.py +0 -0
- plato/samplers/all_inclusive.py +41 -0
- plato/samplers/base.py +31 -0
- plato/samplers/dirichlet.py +81 -0
- plato/samplers/distribution_noniid.py +132 -0
- plato/samplers/iid.py +53 -0
- plato/samplers/label_quantity_noniid.py +119 -0
- plato/samplers/mixed.py +44 -0
- plato/samplers/mixed_label_quantity_noniid.py +128 -0
- plato/samplers/modality_iid.py +42 -0
- plato/samplers/modality_quantity_noniid.py +56 -0
- plato/samplers/orthogonal.py +99 -0
- plato/samplers/registry.py +66 -0
- plato/samplers/sample_quantity_noniid.py +123 -0
- plato/samplers/sampler_utils.py +190 -0
- plato/servers/__init__.py +0 -0
- plato/servers/base.py +1395 -0
- plato/servers/fedavg.py +281 -0
- plato/servers/fedavg_cs.py +335 -0
- plato/servers/fedavg_gan.py +74 -0
- plato/servers/fedavg_he.py +106 -0
- plato/servers/fedavg_personalized.py +57 -0
- plato/servers/mistnet.py +67 -0
- plato/servers/registry.py +52 -0
- plato/servers/split_learning.py +109 -0
- plato/trainers/__init__.py +0 -0
- plato/trainers/base.py +99 -0
- plato/trainers/basic.py +649 -0
- plato/trainers/diff_privacy.py +178 -0
- plato/trainers/gan.py +330 -0
- plato/trainers/huggingface.py +173 -0
- plato/trainers/loss_criterion.py +70 -0
- plato/trainers/lr_schedulers.py +252 -0
- plato/trainers/optimizers.py +53 -0
- plato/trainers/pascal_voc.py +80 -0
- plato/trainers/registry.py +44 -0
- plato/trainers/self_supervised_learning.py +302 -0
- plato/trainers/split_learning.py +305 -0
- plato/trainers/tracking.py +96 -0
- plato/trainers/yolov8.py +41 -0
- plato/utils/__init__.py +0 -0
- plato/utils/count_parameters.py +30 -0
- plato/utils/csv_processor.py +26 -0
- plato/utils/data_loaders.py +148 -0
- plato/utils/decorators.py +24 -0
- plato/utils/fonts.py +23 -0
- plato/utils/homo_enc.py +187 -0
- plato/utils/reinforcement_learning/__init__.py +0 -0
- plato/utils/reinforcement_learning/policies/__init__.py +0 -0
- plato/utils/reinforcement_learning/policies/base.py +161 -0
- plato/utils/reinforcement_learning/policies/ddpg.py +75 -0
- plato/utils/reinforcement_learning/policies/registry.py +32 -0
- plato/utils/reinforcement_learning/policies/sac.py +343 -0
- plato/utils/reinforcement_learning/policies/td3.py +485 -0
- plato/utils/reinforcement_learning/rl_agent.py +142 -0
- plato/utils/reinforcement_learning/rl_server.py +113 -0
- plato/utils/rl_env.py +154 -0
- plato/utils/s3.py +141 -0
- plato/utils/trainer_utils.py +21 -0
- plato/utils/unary_encoding.py +47 -0
- plato_learn-1.1.dist-info/METADATA +35 -0
- plato_learn-1.1.dist-info/RECORD +179 -0
- plato_learn-1.1.dist-info/WHEEL +4 -0
- plato_learn-1.1.dist-info/licenses/LICENSE +201 -0
plato/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
1
|
+
__version__ = "1.0"
|
File without changes
|
plato/algorithms/base.py
ADDED
@@ -0,0 +1,45 @@
|
|
1
|
+
"""
|
2
|
+
Base class for algorithms.
|
3
|
+
"""
|
4
|
+
|
5
|
+
import os
|
6
|
+
from abc import ABC, abstractmethod
|
7
|
+
|
8
|
+
from plato.trainers.base import Trainer
|
9
|
+
|
10
|
+
|
11
|
+
class Algorithm(ABC):
    """Base class for all the algorithms.

    An algorithm holds a trainer together with the trainer's model, and defines
    how model weights are extracted from and loaded into that model.
    """

    def __init__(self, trainer: Trainer):
        """Initializes the algorithm with the provided trainer.

        Arguments:
            trainer: The trainer for the model, which is a trainers.base.Trainer
                instance. The model to operate on is taken from `trainer.model`.
        """
        super().__init__()
        self.trainer = trainer
        self.model = trainer.model
        # __repr__ reports ID 0 as the server; set_client_id() changes it.
        self.client_id = 0

    def __repr__(self):
        if self.client_id == 0:
            return f"Server #{os.getpid()}"
        return f"Client #{self.client_id}"

    def set_client_id(self, client_id):
        """Sets the client ID."""
        self.client_id = client_id

    @abstractmethod
    def extract_weights(self, model=None):
        """Extracts weights from a model passed in as a parameter.

        Arguments:
            model: The model to extract weights from; how `None` is handled is
                left to the implementing subclass.
        """

    @abstractmethod
    def load_weights(self, weights):
        """Loads the model weights passed in as a parameter."""

    async def aggregate_weights(self, baseline_weights, weights_received, **kwargs):
        """Aggregates the weights received into baseline weights (optional).

        The default implementation does nothing and returns None; subclasses may
        override it to provide their own aggregation.
        """
|
@@ -0,0 +1,48 @@
|
|
1
|
+
"""
|
2
|
+
The federated averaging algorithm for PyTorch.
|
3
|
+
"""
|
4
|
+
|
5
|
+
from collections import OrderedDict
|
6
|
+
|
7
|
+
from plato.algorithms import base
|
8
|
+
|
9
|
+
|
10
|
+
class Algorithm(base.Algorithm):
    """PyTorch-based federated averaging algorithm, used by both the client and the server."""

    def compute_weight_deltas(self, baseline_weights, weights_received):
        """Computes, per received update, the difference from the baseline weights.

        Arguments:
            baseline_weights: A mapping from parameter name to its current weight.
            weights_received: A list of mappings from parameter name to the
                received weight, one mapping per update.

        Returns:
            A list of OrderedDicts (same order as `weights_received`), each
            mapping a parameter name to `received - baseline`.
        """
        return [
            OrderedDict(
                (name, received - baseline_weights[name])
                for name, received in update.items()
            )
            for update in weights_received
        ]

    def update_weights(self, deltas):
        """Adds `deltas` to the current model weights.

        Arguments:
            deltas: A mapping from parameter name to the update to apply; it must
                contain every name present in the model's weights.

        Returns:
            An OrderedDict of the updated weights (the model itself is not
            modified by this method).
        """
        current = self.extract_weights()
        return OrderedDict(
            (name, weight + deltas[name]) for name, weight in current.items()
        )

    def extract_weights(self, model=None):
        """Returns the state dict of `model`, or of `self.model` when None.

        Note that `.cpu()` moves the module to the CPU in place before the state
        dict is taken.
        """
        source = self.model if model is None else model
        return source.cpu().state_dict()

    def load_weights(self, weights):
        """Loads `weights` into `self.model`, requiring an exact key match."""
        self.model.load_state_dict(weights, strict=True)
|
@@ -0,0 +1,79 @@
|
|
1
|
+
"""
|
2
|
+
The federated averaging algorithm for GAN model.
|
3
|
+
"""
|
4
|
+
|
5
|
+
from collections import OrderedDict
|
6
|
+
|
7
|
+
from plato.algorithms import fedavg
|
8
|
+
from plato.trainers.base import Trainer
|
9
|
+
|
10
|
+
|
11
|
+
class Algorithm(fedavg.Algorithm):
    """Federated averaging algorithm for GAN models, used by both the client and the server.

    Weights are handled as a `(generator_weights, discriminator_weights)` pair.
    """

    def __init__(self, trainer: Trainer):
        """Initializes the algorithm and keeps handles to the GAN's two sub-models.

        Arguments:
            trainer: The trainer whose `model` exposes `generator` and
                `discriminator` attributes.
        """
        super().__init__(trainer=trainer)
        self.generator = self.model.generator
        self.discriminator = self.model.discriminator

    @staticmethod
    def _subtract(current_weights, baseline_weights):
        """Returns an OrderedDict of `current - baseline` for every name in `current_weights`."""
        return OrderedDict(
            (name, weight - baseline_weights[name])
            for name, weight in current_weights.items()
        )

    @staticmethod
    def _add(baseline_weights, deltas):
        """Returns an OrderedDict of `baseline + delta` for every name in `baseline_weights`."""
        return OrderedDict(
            (name, weight + deltas[name]) for name, weight in baseline_weights.items()
        )

    def compute_weight_deltas(self, baseline_weights=None, weights_received=None):
        """Computes the updates between the baseline weights and the weights received.

        Supports the two-argument form matching
        `fedavg.Algorithm.compute_weight_deltas(baseline_weights, weights_received)`
        as well as the older single-argument form
        `compute_weight_deltas(weights_received)`, in which the baseline is taken
        from the current model.

        Arguments:
            baseline_weights: A `(generator_weights, discriminator_weights)` pair;
                when None, `self.extract_weights()` is used.
            weights_received: A list of `(generator_weights, discriminator_weights)`
                pairs, one per received update.

        Returns:
            A list of `(generator_deltas, discriminator_deltas)` pairs of
            OrderedDicts, in the same order as `weights_received`.
        """
        if weights_received is None:
            # Legacy single-argument call: the only argument is the weights received.
            weights_received, baseline_weights = baseline_weights, None
        if baseline_weights is None:
            baseline_weights = self.extract_weights()
        baseline_weights_gen, baseline_weights_disc = baseline_weights

        return [
            (
                self._subtract(weight_gen, baseline_weights_gen),
                self._subtract(weight_disc, baseline_weights_disc),
            )
            for weight_gen, weight_disc in weights_received
        ]

    def update_weights(self, deltas):
        """Applies a `(generator_deltas, discriminator_deltas)` pair to the current weights.

        Returns:
            A `(generator_weights, discriminator_weights)` pair of OrderedDicts;
            the models themselves are not modified by this method.
        """
        baseline_weights_gen, baseline_weights_disc = self.extract_weights()
        update_gen, update_disc = deltas

        return (
            self._add(baseline_weights_gen, update_gen),
            self._add(baseline_weights_disc, update_disc),
        )

    def extract_weights(self, model=None):
        """Extracts `(generator_weights, discriminator_weights)` from the model.

        Uses `model.generator` / `model.discriminator` when `model` is given,
        otherwise this algorithm's own sub-models. `.cpu()` moves each sub-model
        to the CPU in place before its state dict is taken.
        """
        generator = self.generator
        discriminator = self.discriminator
        if model is not None:
            generator = model.generator
            discriminator = model.discriminator

        gen_weight = generator.cpu().state_dict()
        disc_weight = discriminator.cpu().state_dict()

        return gen_weight, disc_weight

    def load_weights(self, weights):
        """Loads a `(generator_weights, discriminator_weights)` pair into the model.

        Either element may be None, in which case that sub-model is left unchanged.
        """
        weights_gen, weights_disc = weights
        # The client might only receive one or none of the Generator
        # and Discriminator model weight.
        if weights_gen is not None:
            self.generator.load_state_dict(weights_gen, strict=True)
        if weights_disc is not None:
            self.discriminator.load_state_dict(weights_disc, strict=True)
|
@@ -0,0 +1,48 @@
|
|
1
|
+
"""
|
2
|
+
A personalized federate learning algorithm that loads and saves local layers of a model.
|
3
|
+
"""
|
4
|
+
|
5
|
+
import os
|
6
|
+
import logging
|
7
|
+
|
8
|
+
import torch
|
9
|
+
from plato.algorithms import fedavg
|
10
|
+
from plato.config import Config
|
11
|
+
|
12
|
+
|
13
|
+
class Algorithm(fedavg.Algorithm):
    """
    A personalized federated learning algorithm that loads and saves local layers
    of a model.
    """

    def _local_layers_path(self):
        """Returns the path of this client's previously saved local-layer file."""
        directory = Config().params["model_path"]
        name = Config().trainer.model_name
        return f"{directory}/{name}_{self.client_id}_local_layers.pth"

    def load_weights(self, weights):
        """
        Loads `weights` into the model with strict key matching.

        When the configuration defines `algorithm.local_layer_names` and a saved
        local-layer file exists for this client, the saved tensors are first
        merged into `weights` (in place), so they replace the corresponding
        received entries before loading.
        """
        if hasattr(Config().algorithm, "local_layer_names"):
            path = self._local_layers_path()

            if os.path.exists(path):
                saved_layers = torch.load(path, map_location=torch.device("cpu"))
                weights.update(saved_layers)

                logging.info(
                    "[Client #%d] Replaced portions of the global model with local layers.",
                    self.trainer.client_id,
                )

        self.model.load_state_dict(weights, strict=True)

    def save_local_layers(self, local_layers, filename):
        """
        Saves `local_layers` to `filename` using `torch.save`.
        """
        torch.save(local_layers, filename)
|
@@ -0,0 +1,52 @@
|
|
1
|
+
"""
|
2
|
+
The PyTorch-based MistNet algorithm, used by both the client and the server.
|
3
|
+
|
4
|
+
Reference:
|
5
|
+
|
6
|
+
P. Wang, et al. "MistNet: Towards Private Neural Network Training with Local
|
7
|
+
Differential Privacy," found in docs/papers.
|
8
|
+
"""
|
9
|
+
|
10
|
+
import logging
|
11
|
+
import time
|
12
|
+
|
13
|
+
import torch
|
14
|
+
from plato.algorithms import fedavg
|
15
|
+
from plato.datasources import feature_dataset
|
16
|
+
|
17
|
+
|
18
|
+
class Algorithm(fedavg.Algorithm):
    """The PyTorch-based MistNet algorithm, used by both the client and the
    server.
    """

    def extract_features(self, dataset, sampler):
        """Extracts features using the layers before the cut layer.

        Arguments:
            dataset: The training or testing dataset.
            sampler: A sampler whose `get()` result is handed to the trainer's
                train loader.

        Returns:
            A list of `(features, targets)` pairs, one per loader batch (the
            loader is created with a batch size of 1).
        """
        self.model.eval()

        loader = self.trainer.get_train_loader(
            batch_size=1, trainset=dataset, sampler=sampler.get(), extract_features=True
        )

        started = time.perf_counter()

        # Gradients are never needed for feature extraction.
        with torch.no_grad():
            extracted = [
                (self.model.forward_to(inputs), targets)
                for inputs, targets, *__ in loader
            ]

        elapsed = time.perf_counter() - started
        logging.info("[Client #%s] Time used: %.2f seconds.", self.client_id, elapsed)

        return extracted

    def train(self, trainset, sampler):
        """Trains the layers after the cut layer on `trainset.feature_dataset`."""
        features = feature_dataset.FeatureDataset(trainset.feature_dataset)
        self.trainer.train(features, sampler)
|
@@ -0,0 +1,39 @@
|
|
1
|
+
"""
|
2
|
+
The registry for algorithms that contains framework-specific implementations.
|
3
|
+
|
4
|
+
Having a registry of all available classes is convenient for retrieving an instance
|
5
|
+
based on a configuration at run-time.
|
6
|
+
"""
|
7
|
+
|
8
|
+
import logging
|
9
|
+
|
10
|
+
from plato.config import Config
|
11
|
+
|
12
|
+
|
13
|
+
from plato.algorithms import (
|
14
|
+
fedavg,
|
15
|
+
mistnet,
|
16
|
+
fedavg_gan,
|
17
|
+
fedavg_personalized,
|
18
|
+
split_learning,
|
19
|
+
)
|
20
|
+
|
21
|
+
# Maps each supported `algorithm.type` configuration value to its Algorithm class.
registered_algorithms = dict(
    fedavg=fedavg.Algorithm,
    mistnet=mistnet.Algorithm,
    fedavg_gan=fedavg_gan.Algorithm,
    fedavg_personalized=fedavg_personalized.Algorithm,
    split_learning=split_learning.Algorithm,
)
|
28
|
+
|
29
|
+
|
30
|
+
def get(trainer=None):
    """Instantiates the algorithm named by `Config().algorithm.type`.

    Arguments:
        trainer: Passed to the selected algorithm class's constructor.

    Raises:
        ValueError: If the configured type is not in `registered_algorithms`.
    """
    algorithm_type = Config().algorithm.type
    algorithm_class = registered_algorithms.get(algorithm_type)

    if algorithm_class is None:
        raise ValueError(f"No such algorithm: {algorithm_type}")

    logging.info("Algorithm: %s", algorithm_type)
    return algorithm_class(trainer)
|
@@ -0,0 +1,89 @@
|
|
1
|
+
"""
|
2
|
+
A federated learning algorithm using split learning.
|
3
|
+
|
4
|
+
Reference:
|
5
|
+
|
6
|
+
Vepakomma, et al., "Split Learning for Health: Distributed Deep Learning without Sharing
|
7
|
+
Raw Patient Data," in Proc. AI for Social Good Workshop, affiliated with ICLR 2018.
|
8
|
+
|
9
|
+
https://arxiv.org/pdf/1812.00564.pdf
|
10
|
+
|
11
|
+
Chopra, Ayush, et al. "AdaSplit: Adaptive Trade-offs for Resource-constrained Distributed
|
12
|
+
Deep Learning." arXiv preprint arXiv:2112.01637 (2021).
|
13
|
+
|
14
|
+
https://arxiv.org/pdf/2112.01637.pdf
|
15
|
+
"""
|
16
|
+
|
17
|
+
import logging
|
18
|
+
import time
|
19
|
+
|
20
|
+
from plato.algorithms import fedavg
|
21
|
+
from plato.config import Config
|
22
|
+
from plato.datasources import feature_dataset
|
23
|
+
|
24
|
+
|
25
|
+
class Algorithm(fedavg.Algorithm):
    """The PyTorch-based split learning algorithm, used by both the client and the
    server.
    """

    def extract_features(self, dataset, sampler):
        """Extracts features using the layers before the cut layer.

        Draws one batch of `Config().trainer.batch_size` samples through the
        trainer and runs it up to the intermediate (cut-layer) features.

        Returns:
            A tuple `(features, elapsed)`: `features` is a one-element list
            holding the `(outputs, targets)` pair returned by the trainer, and
            `elapsed` is the time taken in seconds.
        """
        device = self.trainer.device
        self.model.to(device)
        self.model.eval()

        started = time.perf_counter()

        batch_size = Config().trainer.batch_size
        inputs, targets = self.trainer.get_train_samples(batch_size, dataset, sampler)
        outputs, targets = self.trainer.forward_to_intermediate_feature(
            inputs.to(device), targets.to(device)
        )

        elapsed = time.perf_counter() - started
        logging.warning(
            "[Client #%d] Features extracted from %s examples in %.2f seconds.",
            self.client_id,
            batch_size,
            elapsed,
        )

        return [(outputs, targets)], elapsed

    def complete_train(self, gradients):
        """Updates the client-side model using the gradients received from the server.

        Returns:
            The time taken, in seconds.
        """
        started = time.perf_counter()

        # Retrieve the stored training samples and let the trainer do the training.
        samples, sampler = self.trainer.retrieve_train_samples()
        self.trainer.load_gradients(gradients)
        self.train(samples, sampler)

        elapsed = time.perf_counter() - started
        logging.warning(
            "[Client #%d] Training completed in %.2f seconds.",
            self.client_id,
            elapsed,
        )

        return elapsed

    def train(self, trainset, sampler):
        """Trains the model on `trainset.feature_dataset` with the given sampler."""
        features = feature_dataset.FeatureDataset(trainset.feature_dataset)
        self.trainer.train(features, sampler)

    def update_weights_before_cut(self, weights):
        """Updates the weights before the cut layer; called when testing accuracy."""
        merged = self.trainer.update_weights_before_cut(self.extract_weights(), weights)
        self.load_weights(merged)
|
File without changes
|
@@ -0,0 +1,56 @@
|
|
1
|
+
"""
|
2
|
+
Defines the ClientCallback class, which is the abstract base class to be subclassed
|
3
|
+
when creating new client callbacks.
|
4
|
+
|
5
|
+
Defines a default callback to print local training progress.
|
6
|
+
"""
|
7
|
+
|
8
|
+
from abc import ABC
|
9
|
+
import logging
|
10
|
+
|
11
|
+
|
12
|
+
class ClientCallback(ABC):
    """
    Base class for client callbacks.

    Every event handler below is a no-op by default; subclasses override only
    the events they are interested in.
    """

    def on_inbound_received(self, client, inbound_processor):
        """
        Called just before the inbound processors begin processing data.

        Arguments:
            client: The client raising the event.
            inbound_processor: The inbound processor about to run.
        """

    def on_inbound_processed(self, client, data):
        """
        Called once the inbound processors have finished with the payload.

        Arguments:
            client: The client raising the event.
            data: The processed payload.
        """

    def on_outbound_ready(self, client, report, outbound_processor):
        """
        Called just before the outbound processors begin processing data.

        Arguments:
            client: The client raising the event.
            report: The report being sent.
            outbound_processor: The outbound processor about to run.
        """
|
31
|
+
|
32
|
+
|
33
|
+
class LogProgressCallback(ClientCallback):
    """
    A client callback that logs an INFO message at each stage of payload processing.
    """

    def on_inbound_received(self, client, inbound_processor):
        """
        Logs that inbound processing is about to start.
        """
        message = "[%s] Start to process inbound data."
        logging.info(message, client)

    def on_inbound_processed(self, client, data):
        """
        Logs that the inbound payload has been processed.
        """
        message = "[%s] Inbound data has been processed."
        logging.info(message, client)

    def on_outbound_ready(self, client, report, outbound_processor):
        """
        Logs that outbound data is ready to be sent.
        """
        message = "[%s] Outbound data is ready to be sent after being processed."
        logging.info(message, client)
|
@@ -0,0 +1,78 @@
|
|
1
|
+
"""
|
2
|
+
Defines the :class:`CallbackHandler`, which is responsible for calling a list of callbacks.
|
3
|
+
"""
|
4
|
+
|
5
|
+
|
6
|
+
class CallbackHandler:
    """
    The :class:`CallbackHandler` is responsible for calling a list of callbacks.
    This class calls the callbacks in the order that they are given.
    """

    def __init__(self, callbacks):
        """
        Creates a handler holding the given callbacks.

        :param callbacks: an iterable of callback classes or callback instances.
        """
        self.callbacks = []
        self.add_callbacks(callbacks)

    def add_callbacks(self, callbacks):
        """
        Adds a list of callbacks to the callback handler.

        :param callbacks: an iterable of callback classes or callback instances
            (see :meth:`add_callback`).
        """
        for callback in callbacks:
            self.add_callback(callback)

    def add_callback(self, callback):
        """
        Adds a callback to the callback handler.

        :param callback: a callback class (instantiated here with no arguments)
            or a callback instance.
        :raises ValueError: if a callback of the same class is already registered.
        """
        is_class = isinstance(callback, type)
        callback_class = callback if is_class else callback.__class__

        # Check for duplicates before instantiating, so a rejected callback class
        # never has its constructor (and any side effects in it) run.
        if callback_class in {c.__class__ for c in self.callbacks}:
            existing_callbacks = "\n".join(self.callback_list)

            raise ValueError(
                f"You attempted to add multiple instances of the callback "
                f"{callback_class}.\n"
                f"The list of callbacks already present is: {existing_callbacks}"
            )

        self.callbacks.append(callback() if is_class else callback)

    def __iter__(self):
        """Returns an iterator over the registered callback instances."""
        # __iter__ must return an iterator; a bare list makes iter() raise TypeError.
        return iter(self.callbacks)

    def clear_callbacks(self):
        """
        Clears all the callbacks in the current list.
        """
        self.callbacks = []

    @property
    def callback_list(self):
        """
        Returns the class names for the current list of callbacks.
        """
        return [cb.__class__.__name__ for cb in self.callbacks]

    def call_event(self, event, *args, **kwargs):
        """
        For each callback which has been registered, sequentially call the method
        corresponding to the given event.

        :param event: The event corresponding to the method to call on each callback.
        :param args: positional arguments to be passed to each callback.
        :param kwargs: keyword arguments to be passed to each callback.
        :raises ValueError: if a registered callback has no method named ``event``.
        """
        for callback in self.callbacks:
            # Only the lookup is guarded: an AttributeError raised *inside* a
            # callback is a real error and must propagate unchanged.
            try:
                handler = getattr(callback, event)
            except AttributeError as exc:
                raise ValueError(
                    "The callback method has not been implemented"
                ) from exc
            handler(*args, **kwargs)
|
@@ -0,0 +1,139 @@
|
|
1
|
+
"""
|
2
|
+
Defines the ServerCallback class, which is the abstract base class to be subclassed
|
3
|
+
when creating new server callbacks.
|
4
|
+
|
5
|
+
Defines a default callback to print training progress.
|
6
|
+
"""
|
7
|
+
|
8
|
+
import logging
|
9
|
+
import os
|
10
|
+
from abc import ABC
|
11
|
+
from plato.config import Config
|
12
|
+
from plato.utils import csv_processor, fonts
|
13
|
+
|
14
|
+
|
15
|
+
class ServerCallback(ABC):
    """
    The abstract base class to be subclassed when creating new server callbacks.

    Every event handler is a no-op by default; subclasses override only the
    events they are interested in.
    """

    def __init__(self):
        """
        Initializer.
        """

    def on_weights_received(self, server, weights_received):
        """
        Event called after the updated weights have been received.
        """

    def on_weights_aggregated(self, server, updates):
        """
        Event called after the updated weights have been aggregated.
        """

    def on_clients_selected(self, server, selected_clients, **kwargs):
        """
        Event called after clients have been selected in each round.
        """

    def on_clients_processed(self, server, **kwargs):
        """Additional work to be performed after client reports have been processed."""

    def on_training_will_start(self, server, **kwargs):
        """
        Event called before selecting clients for the first round of training.
        """

    def on_server_will_close(self, server, **kwargs):
        """
        Event called at the start of closing the server.
        """
|
52
|
+
|
53
|
+
|
54
|
+
class LogProgressCallback(ServerCallback):
    """
    A callback which logs server progress and records runtime results to .csv files.
    """

    def __init__(self):
        """
        Reads the result types to record from `Config().params["result_types"]`
        (a comma-separated string) and initializes the runtime-results .csv file.
        """
        super().__init__()

        recorded_items = Config().params["result_types"]
        self.recorded_items = [x.strip() for x in recorded_items.split(",")]

        # Initialize the .csv file for logging runtime results
        result_csv_file = f"{Config().params['result_path']}/{os.getpid()}.csv"
        csv_processor.initialize_csv(
            result_csv_file, self.recorded_items, Config().params["result_path"]
        )

        logging.info(
            fonts.colourize(
                f"[{os.getpid()}] Logging runtime results to: {result_csv_file}."
            )
        )

    def on_weights_received(self, server, weights_received):
        """
        Event called after the updated weights have been received.
        """
        logging.info("[%s] Updated weights have been received.", server)

    def on_weights_aggregated(self, server, updates):
        """
        Event called after the updated weights have been aggregated.
        """
        logging.info("[%s] Finished aggregating updated weights.", server)

    def on_clients_selected(self, server, selected_clients, **kwargs):
        """
        Event called after clients have been selected in each round.

        Accepts `**kwargs` like `ServerCallback.on_clients_selected`, so extra
        keyword arguments passed with this event do not raise a TypeError.
        """

    def on_clients_processed(self, server, **kwargs):
        """
        Records results after client reports have been processed.

        Appends one row (the configured result types, in order) to the runtime
        results .csv file. When both `clients.do_test` and
        `results.record_clients_accuracy` are enabled, also appends one
        `[round, client_id, accuracy]` row per update to the accuracy .csv file.
        """
        # Take a single snapshot so every column of the row comes from one call.
        logged_items = server.get_logged_items()
        new_row = [logged_items[item] for item in self.recorded_items]

        result_csv_file = f"{Config().params['result_path']}/{os.getpid()}.csv"
        csv_processor.write_csv(result_csv_file, new_row)

        if (
            hasattr(Config().clients, "do_test")
            and Config().clients.do_test
            and (
                hasattr(Config(), "results")
                and hasattr(Config().results, "record_clients_accuracy")
                and Config().results.record_clients_accuracy
            )
        ):
            # Updates the log for client test accuracies
            accuracy_csv_file = (
                f"{Config().params['result_path']}/{os.getpid()}_accuracy.csv"
            )

            for update in server.updates:
                accuracy_row = [
                    server.current_round,
                    update.client_id,
                    update.report.accuracy,
                ]
                csv_processor.write_csv(accuracy_csv_file, accuracy_row)

        logging.info("[%s] All client reports have been processed.", server)

    def on_training_will_start(self, server, **kwargs):
        """
        Event called before selecting clients for the first round of training.
        """
        logging.info("[%s] Starting training.", server)

    def on_server_will_close(self, server, **kwargs):
        """
        Event called at the start of closing the server.
        """
        logging.info("[%s] Closing the server.", server)
|