plato-learn 1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (179) hide show
  1. plato/__init__.py +1 -0
  2. plato/algorithms/__init__.py +0 -0
  3. plato/algorithms/base.py +45 -0
  4. plato/algorithms/fedavg.py +48 -0
  5. plato/algorithms/fedavg_gan.py +79 -0
  6. plato/algorithms/fedavg_personalized.py +48 -0
  7. plato/algorithms/mistnet.py +52 -0
  8. plato/algorithms/registry.py +39 -0
  9. plato/algorithms/split_learning.py +89 -0
  10. plato/callbacks/__init__.py +0 -0
  11. plato/callbacks/client.py +56 -0
  12. plato/callbacks/handler.py +78 -0
  13. plato/callbacks/server.py +139 -0
  14. plato/callbacks/trainer.py +124 -0
  15. plato/client.py +67 -0
  16. plato/clients/__init__.py +0 -0
  17. plato/clients/base.py +467 -0
  18. plato/clients/edge.py +103 -0
  19. plato/clients/fedavg_personalized.py +40 -0
  20. plato/clients/mistnet.py +49 -0
  21. plato/clients/registry.py +43 -0
  22. plato/clients/self_supervised_learning.py +51 -0
  23. plato/clients/simple.py +218 -0
  24. plato/clients/split_learning.py +150 -0
  25. plato/config.py +339 -0
  26. plato/datasources/__init__.py +0 -0
  27. plato/datasources/base.py +123 -0
  28. plato/datasources/celeba.py +150 -0
  29. plato/datasources/cifar10.py +87 -0
  30. plato/datasources/cifar100.py +61 -0
  31. plato/datasources/cinic10.py +62 -0
  32. plato/datasources/coco.py +119 -0
  33. plato/datasources/datalib/__init__.py +0 -0
  34. plato/datasources/datalib/audio_extraction_tools.py +137 -0
  35. plato/datasources/datalib/data_utils.py +124 -0
  36. plato/datasources/datalib/flickr30kE_utils.py +336 -0
  37. plato/datasources/datalib/frames_extraction_tools.py +254 -0
  38. plato/datasources/datalib/gym_utils/__init__.py +0 -0
  39. plato/datasources/datalib/gym_utils/gym_trim.py +189 -0
  40. plato/datasources/datalib/modality_data_anntation_tools.py +163 -0
  41. plato/datasources/datalib/modality_extraction_base.py +59 -0
  42. plato/datasources/datalib/parse_datasets.py +212 -0
  43. plato/datasources/datalib/refer_utils/__init__.py +0 -0
  44. plato/datasources/datalib/refer_utils/referitgame_utils.py +237 -0
  45. plato/datasources/datalib/tiny_data_tools.py +81 -0
  46. plato/datasources/datalib/video_transform.py +79 -0
  47. plato/datasources/emnist.py +64 -0
  48. plato/datasources/fashion_mnist.py +41 -0
  49. plato/datasources/feature.py +24 -0
  50. plato/datasources/feature_dataset.py +15 -0
  51. plato/datasources/femnist.py +141 -0
  52. plato/datasources/flickr30k_entities.py +362 -0
  53. plato/datasources/gym.py +431 -0
  54. plato/datasources/huggingface.py +165 -0
  55. plato/datasources/kinetics.py +568 -0
  56. plato/datasources/mnist.py +44 -0
  57. plato/datasources/multimodal_base.py +328 -0
  58. plato/datasources/pascal_voc.py +56 -0
  59. plato/datasources/purchase.py +94 -0
  60. plato/datasources/qoenflx.py +127 -0
  61. plato/datasources/referitgame.py +330 -0
  62. plato/datasources/registry.py +119 -0
  63. plato/datasources/self_supervised_learning.py +98 -0
  64. plato/datasources/stl10.py +103 -0
  65. plato/datasources/texas.py +94 -0
  66. plato/datasources/tiny_imagenet.py +64 -0
  67. plato/datasources/yolov8.py +85 -0
  68. plato/models/__init__.py +0 -0
  69. plato/models/cnn_encoder.py +103 -0
  70. plato/models/dcgan.py +116 -0
  71. plato/models/general_multilayer.py +254 -0
  72. plato/models/huggingface.py +27 -0
  73. plato/models/lenet5.py +113 -0
  74. plato/models/multilayer.py +90 -0
  75. plato/models/multimodal/__init__.py +0 -0
  76. plato/models/multimodal/base_net.py +91 -0
  77. plato/models/multimodal/blending.py +142 -0
  78. plato/models/multimodal/fc_net.py +77 -0
  79. plato/models/multimodal/fusion_net.py +78 -0
  80. plato/models/multimodal/multimodal_module.py +152 -0
  81. plato/models/registry.py +99 -0
  82. plato/models/resnet.py +190 -0
  83. plato/models/torch_hub.py +19 -0
  84. plato/models/vgg.py +113 -0
  85. plato/models/vit.py +166 -0
  86. plato/models/yolov8.py +22 -0
  87. plato/processors/__init__.py +0 -0
  88. plato/processors/base.py +35 -0
  89. plato/processors/compress.py +46 -0
  90. plato/processors/decompress.py +48 -0
  91. plato/processors/feature.py +51 -0
  92. plato/processors/feature_additive_noise.py +48 -0
  93. plato/processors/feature_dequantize.py +34 -0
  94. plato/processors/feature_gaussian.py +17 -0
  95. plato/processors/feature_laplace.py +15 -0
  96. plato/processors/feature_quantize.py +34 -0
  97. plato/processors/feature_randomized_response.py +50 -0
  98. plato/processors/feature_unbatch.py +39 -0
  99. plato/processors/inbound_feature_tensors.py +39 -0
  100. plato/processors/model.py +55 -0
  101. plato/processors/model_compress.py +34 -0
  102. plato/processors/model_decompress.py +37 -0
  103. plato/processors/model_decrypt.py +41 -0
  104. plato/processors/model_deepcopy.py +21 -0
  105. plato/processors/model_dequantize.py +18 -0
  106. plato/processors/model_dequantize_qsgd.py +61 -0
  107. plato/processors/model_encrypt.py +43 -0
  108. plato/processors/model_quantize.py +18 -0
  109. plato/processors/model_quantize_qsgd.py +82 -0
  110. plato/processors/model_randomized_response.py +34 -0
  111. plato/processors/outbound_feature_ndarrays.py +38 -0
  112. plato/processors/pipeline.py +26 -0
  113. plato/processors/registry.py +124 -0
  114. plato/processors/structured_pruning.py +57 -0
  115. plato/processors/unstructured_pruning.py +73 -0
  116. plato/samplers/__init__.py +0 -0
  117. plato/samplers/all_inclusive.py +41 -0
  118. plato/samplers/base.py +31 -0
  119. plato/samplers/dirichlet.py +81 -0
  120. plato/samplers/distribution_noniid.py +132 -0
  121. plato/samplers/iid.py +53 -0
  122. plato/samplers/label_quantity_noniid.py +119 -0
  123. plato/samplers/mixed.py +44 -0
  124. plato/samplers/mixed_label_quantity_noniid.py +128 -0
  125. plato/samplers/modality_iid.py +42 -0
  126. plato/samplers/modality_quantity_noniid.py +56 -0
  127. plato/samplers/orthogonal.py +99 -0
  128. plato/samplers/registry.py +66 -0
  129. plato/samplers/sample_quantity_noniid.py +123 -0
  130. plato/samplers/sampler_utils.py +190 -0
  131. plato/servers/__init__.py +0 -0
  132. plato/servers/base.py +1395 -0
  133. plato/servers/fedavg.py +281 -0
  134. plato/servers/fedavg_cs.py +335 -0
  135. plato/servers/fedavg_gan.py +74 -0
  136. plato/servers/fedavg_he.py +106 -0
  137. plato/servers/fedavg_personalized.py +57 -0
  138. plato/servers/mistnet.py +67 -0
  139. plato/servers/registry.py +52 -0
  140. plato/servers/split_learning.py +109 -0
  141. plato/trainers/__init__.py +0 -0
  142. plato/trainers/base.py +99 -0
  143. plato/trainers/basic.py +649 -0
  144. plato/trainers/diff_privacy.py +178 -0
  145. plato/trainers/gan.py +330 -0
  146. plato/trainers/huggingface.py +173 -0
  147. plato/trainers/loss_criterion.py +70 -0
  148. plato/trainers/lr_schedulers.py +252 -0
  149. plato/trainers/optimizers.py +53 -0
  150. plato/trainers/pascal_voc.py +80 -0
  151. plato/trainers/registry.py +44 -0
  152. plato/trainers/self_supervised_learning.py +302 -0
  153. plato/trainers/split_learning.py +305 -0
  154. plato/trainers/tracking.py +96 -0
  155. plato/trainers/yolov8.py +41 -0
  156. plato/utils/__init__.py +0 -0
  157. plato/utils/count_parameters.py +30 -0
  158. plato/utils/csv_processor.py +26 -0
  159. plato/utils/data_loaders.py +148 -0
  160. plato/utils/decorators.py +24 -0
  161. plato/utils/fonts.py +23 -0
  162. plato/utils/homo_enc.py +187 -0
  163. plato/utils/reinforcement_learning/__init__.py +0 -0
  164. plato/utils/reinforcement_learning/policies/__init__.py +0 -0
  165. plato/utils/reinforcement_learning/policies/base.py +161 -0
  166. plato/utils/reinforcement_learning/policies/ddpg.py +75 -0
  167. plato/utils/reinforcement_learning/policies/registry.py +32 -0
  168. plato/utils/reinforcement_learning/policies/sac.py +343 -0
  169. plato/utils/reinforcement_learning/policies/td3.py +485 -0
  170. plato/utils/reinforcement_learning/rl_agent.py +142 -0
  171. plato/utils/reinforcement_learning/rl_server.py +113 -0
  172. plato/utils/rl_env.py +154 -0
  173. plato/utils/s3.py +141 -0
  174. plato/utils/trainer_utils.py +21 -0
  175. plato/utils/unary_encoding.py +47 -0
  176. plato_learn-1.1.dist-info/METADATA +35 -0
  177. plato_learn-1.1.dist-info/RECORD +179 -0
  178. plato_learn-1.1.dist-info/WHEEL +4 -0
  179. plato_learn-1.1.dist-info/licenses/LICENSE +201 -0
plato/__init__.py ADDED
@@ -0,0 +1 @@
1
+ __version__ = "1.0"
File without changes
@@ -0,0 +1,45 @@
1
+ """
2
+ Base class for algorithms.
3
+ """
4
+
5
+ import os
6
+ from abc import ABC, abstractmethod
7
+
8
+ from plato.trainers.base import Trainer
9
+
10
+
11
class Algorithm(ABC):
    """Base class for all the algorithms.

    An algorithm wraps a trainer (and the trainer's model) and defines how model
    weights are extracted, loaded and, optionally, aggregated.
    """

    def __init__(self, trainer: Trainer):
        """Initializes the algorithm with the provided trainer.

        Arguments:
            trainer: The trainer for the model, which is a trainers.base.Trainer
                instance. Its ``model`` attribute becomes ``self.model``.
        """
        super().__init__()
        self.trainer = trainer
        self.model = trainer.model
        # A client ID of 0 is treated as the server (see __repr__); clients
        # receive their own ID through set_client_id().
        self.client_id = 0

    def __repr__(self):
        if self.client_id == 0:
            # The server has no client ID, so it is labelled by its process ID.
            return f"Server #{os.getpid()}"
        return f"Client #{self.client_id}"

    def set_client_id(self, client_id):
        """Sets the client ID."""
        self.client_id = client_id

    @abstractmethod
    def extract_weights(self, model=None):
        """Extracts weights from a model passed in as a parameter."""

    @abstractmethod
    def load_weights(self, weights):
        """Loads the model weights passed in as a parameter."""

    async def aggregate_weights(self, baseline_weights, weights_received, **kwargs):
        """Aggregates the weights received into baseline weights (optional).

        The default implementation does nothing and returns None; subclasses may
        override it.
        """
@@ -0,0 +1,48 @@
1
+ """
2
+ The federated averaging algorithm for PyTorch.
3
+ """
4
+
5
+ from collections import OrderedDict
6
+
7
+ from plato.algorithms import base
8
+
9
+
10
class Algorithm(base.Algorithm):
    """Federated averaging algorithm for PyTorch, shared by clients and the server."""

    def compute_weight_deltas(self, baseline_weights, weights_received):
        """Computes how far each received set of weights is from the baseline.

        Arguments:
            baseline_weights: Mapping from parameter name to the baseline weight.
            weights_received: Iterable of mappings (one per update) from parameter
                name to the received weight.

        Returns:
            A list holding, per received update, an OrderedDict of
            ``received - baseline`` keyed by parameter name.
        """
        return [
            OrderedDict(
                (name, received - baseline_weights[name])
                for name, received in client_weights.items()
            )
            for client_weights in weights_received
        ]

    def update_weights(self, deltas):
        """Adds ``deltas`` to the current model weights and returns the sum.

        Arguments:
            deltas: Mapping from parameter name to the update for that parameter.

        Returns:
            A new OrderedDict of updated weights; this method does not load it
            into the model.
        """
        current = self.extract_weights()
        return OrderedDict(
            (name, weight + deltas[name]) for name, weight in current.items()
        )

    def extract_weights(self, model=None):
        """Returns the state dict of ``model``, or of ``self.model`` if omitted.

        The model is moved to the CPU before its state dict is taken.
        """
        target = self.model if model is None else model
        return target.cpu().state_dict()

    def load_weights(self, weights):
        """Loads ``weights`` into ``self.model`` (strict key matching)."""
        self.model.load_state_dict(weights, strict=True)
@@ -0,0 +1,79 @@
1
+ """
2
+ The federated averaging algorithm for GAN model.
3
+ """
4
+
5
+ from collections import OrderedDict
6
+
7
+ from plato.algorithms import fedavg
8
+ from plato.trainers.base import Trainer
9
+
10
+
11
class Algorithm(fedavg.Algorithm):
    """Federated averaging algorithm for GAN models, used by both the client and the server.

    Weights are handled as ``(generator_weights, discriminator_weights)`` pairs.
    """

    def __init__(self, trainer: Trainer):
        super().__init__(trainer=trainer)
        # The trainer's model must expose `generator` and `discriminator` submodules.
        self.generator = self.model.generator
        self.discriminator = self.model.discriminator

    def compute_weight_deltas(self, baseline_weights=None, weights_received=None):
        """Computes the updates of the received GAN weights against a baseline.

        Supports the parent-class convention
        ``compute_weight_deltas(baseline_weights, weights_received)`` as well as the
        one-argument form ``compute_weight_deltas(weights_received)``; when no
        baseline is supplied it is extracted from the current model.

        Arguments:
            baseline_weights: ``(generator_weights, discriminator_weights)`` baseline
                (or, in the one-argument form, the received weights).
            weights_received: Iterable of ``(generator_weights,
                discriminator_weights)`` pairs, one per client.

        Returns:
            A list of ``(generator_delta, discriminator_delta)`` OrderedDict pairs.

        Raises:
            TypeError: If no received weights are provided.
        """
        if weights_received is None:
            # One-argument form: the only argument carries the received weights.
            weights_received, baseline_weights = baseline_weights, None
        if weights_received is None:
            raise TypeError("compute_weight_deltas() requires the received weights")
        if baseline_weights is None:
            baseline_weights = self.extract_weights()

        baseline_weights_gen, baseline_weights_disc = baseline_weights

        deltas = []
        for weight_gen, weight_disc in weights_received:
            delta_gen = self._subtract(weight_gen, baseline_weights_gen)
            delta_disc = self._subtract(weight_disc, baseline_weights_disc)
            deltas.append((delta_gen, delta_disc))

        return deltas

    @staticmethod
    def _subtract(weights, baseline):
        """Returns an OrderedDict of ``weights[name] - baseline[name]`` per name."""
        return OrderedDict(
            (name, current_weight - baseline[name])
            for name, current_weight in weights.items()
        )

    def update_weights(self, deltas):
        """Applies ``(generator_delta, discriminator_delta)`` to the current weights.

        Returns:
            A ``(generator_weights, discriminator_weights)`` pair of new OrderedDicts.
        """
        baseline_weights_gen, baseline_weights_disc = self.extract_weights()
        update_gen, update_disc = deltas

        updated_weights_gen = OrderedDict(
            (name, weight + update_gen[name])
            for name, weight in baseline_weights_gen.items()
        )
        updated_weights_disc = OrderedDict(
            (name, weight + update_disc[name])
            for name, weight in baseline_weights_disc.items()
        )

        return updated_weights_gen, updated_weights_disc

    def extract_weights(self, model=None):
        """Extracts ``(generator, discriminator)`` state dicts from the model.

        Uses ``model`` when provided, otherwise this algorithm's own submodules.
        Both submodules are moved to the CPU first.
        """
        generator = self.generator
        discriminator = self.discriminator
        if model is not None:
            generator = model.generator
            discriminator = model.discriminator

        gen_weight = generator.cpu().state_dict()
        disc_weight = discriminator.cpu().state_dict()

        return gen_weight, disc_weight

    def load_weights(self, weights):
        """Loads a ``(generator_weights, discriminator_weights)`` pair into the model."""
        weights_gen, weights_disc = weights
        # The client might only receive one or none of the Generator
        # and Discriminator model weights.
        if weights_gen is not None:
            self.generator.load_state_dict(weights_gen, strict=True)
        if weights_disc is not None:
            self.discriminator.load_state_dict(weights_disc, strict=True)
@@ -0,0 +1,48 @@
1
+ """
2
+ A personalized federate learning algorithm that loads and saves local layers of a model.
3
+ """
4
+
5
+ import os
6
+ import logging
7
+
8
+ import torch
9
+ from plato.algorithms import fedavg
10
+ from plato.config import Config
11
+
12
+
13
class Algorithm(fedavg.Algorithm):
    """
    A personalized federated learning algorithm that loads and saves local layers
    of a model.
    """

    def load_weights(self, weights):
        """
        Loads the received weights into the model, first overriding them with this
        client's previously saved local layers if such a file exists.

        Local layers are only considered when ``local_layer_names`` is configured
        under ``algorithm``. They are read from
        ``{model_path}/{model_name}_{client_id}_local_layers.pth``, and every entry
        of that file replaces the matching entry of ``weights``.

        Note: ``weights`` is updated in place.
        """
        if hasattr(Config().algorithm, "local_layer_names"):
            # Get the filename of the previously saved local layers
            model_path = Config().params["model_path"]
            model_name = Config().trainer.model_name
            filename = f"{model_path}/{model_name}_{self.client_id}_local_layers.pth"

            # Load local layers into the weights when the file exists.
            # NOTE(review): torch.load unpickles the file; it should only ever be
            # a file written by save_local_layers().
            if os.path.exists(filename):
                local_layers = torch.load(filename, map_location=torch.device("cpu"))

                # Update the received weights with the loaded local layers
                weights.update(local_layers)

                # Report the same client ID that was used to locate the file.
                logging.info(
                    "[Client #%d] Replaced portions of the global model with local layers.",
                    self.client_id,
                )

        self.model.load_state_dict(weights, strict=True)

    def save_local_layers(self, local_layers, filename):
        """
        Saves local layers to a file with the filename provided.
        """
        torch.save(local_layers, filename)
@@ -0,0 +1,52 @@
1
+ """
2
+ The PyTorch-based MistNet algorithm, used by both the client and the server.
3
+
4
+ Reference:
5
+
6
+ P. Wang, et al. "MistNet: Towards Private Neural Network Training with Local
7
+ Differential Privacy," found in docs/papers.
8
+ """
9
+
10
+ import logging
11
+ import time
12
+
13
+ import torch
14
+ from plato.algorithms import fedavg
15
+ from plato.datasources import feature_dataset
16
+
17
+
18
class Algorithm(fedavg.Algorithm):
    """MistNet algorithm for PyTorch, shared by the client and the server."""

    def extract_features(self, dataset, sampler):
        """Runs the layers before the cut layer over a dataset and collects the output.

        Arguments:
            dataset: The training or testing dataset.
            sampler: A sampler whose ``get()`` result is handed to the data loader.

        Returns:
            A list of ``(features, targets)`` pairs, one per loaded batch (the
            loader is built with a batch size of 1).
        """
        self.model.eval()

        loader = self.trainer.get_train_loader(
            batch_size=1, trainset=dataset, sampler=sampler.get(), extract_features=True
        )

        started = time.perf_counter()

        collected = []
        for inputs, targets, *_unused in loader:
            with torch.no_grad():
                collected.append((self.model.forward_to(inputs), targets))

        elapsed = time.perf_counter() - started
        logging.info("[Client #%s] Time used: %.2f seconds.", self.client_id, elapsed)

        return collected

    def train(self, trainset, sampler):
        """Trains the part of the model after the cut layer on extracted features."""
        features = feature_dataset.FeatureDataset(trainset.feature_dataset)
        self.trainer.train(features, sampler)
@@ -0,0 +1,39 @@
1
+ """
2
+ The registry for algorithms that contains framework-specific implementations.
3
+
4
+ Having a registry of all available classes is convenient for retrieving an instance
5
+ based on a configuration at run-time.
6
+ """
7
+
8
+ import logging
9
+
10
+ from plato.config import Config
11
+
12
+
13
+ from plato.algorithms import (
14
+ fedavg,
15
+ mistnet,
16
+ fedavg_gan,
17
+ fedavg_personalized,
18
+ split_learning,
19
+ )
20
+
21
# Maps the `algorithm.type` configuration value to its algorithm class.
registered_algorithms = {
    "fedavg": fedavg.Algorithm,
    "mistnet": mistnet.Algorithm,
    "fedavg_gan": fedavg_gan.Algorithm,
    "fedavg_personalized": fedavg_personalized.Algorithm,
    "split_learning": split_learning.Algorithm,
}


def get(trainer=None):
    """Instantiates the algorithm selected by ``Config().algorithm.type``.

    Arguments:
        trainer: The trainer handed to the algorithm's constructor.

    Raises:
        ValueError: If the configured algorithm type is not registered.
    """
    algorithm_type = Config().algorithm.type

    if algorithm_type not in registered_algorithms:
        raise ValueError(f"No such algorithm: {algorithm_type}")

    logging.info("Algorithm: %s", algorithm_type)
    algorithm_class = registered_algorithms[algorithm_type]
    return algorithm_class(trainer)
@@ -0,0 +1,89 @@
1
+ """
2
+ A federated learning algorithm using split learning.
3
+
4
+ Reference:
5
+
6
+ Vepakomma, et al., "Split Learning for Health: Distributed Deep Learning without Sharing
7
+ Raw Patient Data," in Proc. AI for Social Good Workshop, affiliated with ICLR 2018.
8
+
9
+ https://arxiv.org/pdf/1812.00564.pdf
10
+
11
+ Chopra, Ayush, et al. "AdaSplit: Adaptive Trade-offs for Resource-constrained Distributed
12
+ Deep Learning." arXiv preprint arXiv:2112.01637 (2021).
13
+
14
+ https://arxiv.org/pdf/2112.01637.pdf
15
+ """
16
+
17
+ import logging
18
+ import time
19
+
20
+ from plato.algorithms import fedavg
21
+ from plato.config import Config
22
+ from plato.datasources import feature_dataset
23
+
24
+
25
class Algorithm(fedavg.Algorithm):
    """Split learning algorithm for PyTorch, shared by the client and the server."""

    def extract_features(self, dataset, sampler):
        """Computes intermediate (cut-layer) features for one batch of samples.

        A batch of ``Config().trainer.batch_size`` samples is obtained from the
        trainer, moved to the trainer's device and forwarded up to the cut layer.

        Returns:
            ``(features_dataset, elapsed)``: a one-element list holding
            ``(outputs, targets)``, and the time taken in seconds.
        """
        self.model.to(self.trainer.device)
        self.model.eval()

        started = time.perf_counter()

        batch_size = Config().trainer.batch_size
        inputs, targets = self.trainer.get_train_samples(batch_size, dataset, sampler)
        inputs = inputs.to(self.trainer.device)
        targets = targets.to(self.trainer.device)
        outputs, targets = self.trainer.forward_to_intermediate_feature(inputs, targets)
        features_dataset = [(outputs, targets)]

        elapsed = time.perf_counter() - started
        logging.warning(
            "[Client #%d] Features extracted from %s examples in %.2f seconds.",
            self.client_id,
            batch_size,
            elapsed,
        )

        return features_dataset, elapsed

    def complete_train(self, gradients):
        """Finishes local training using the gradients received from the server.

        Returns:
            The time taken, in seconds.
        """
        started = time.perf_counter()

        # The trainer supplies the stored training samples; gradients are loaded
        # into it before training runs.
        samples, sampler = self.trainer.retrieve_train_samples()
        self.trainer.load_gradients(gradients)
        self.train(samples, sampler)

        elapsed = time.perf_counter() - started
        logging.warning(
            "[Client #%d] Training completed in %.2f seconds.",
            self.client_id,
            elapsed,
        )

        return elapsed

    def train(self, trainset, sampler):
        """Trains the model on ``trainset.feature_dataset`` with the given sampler."""
        features = feature_dataset.FeatureDataset(trainset.feature_dataset)
        self.trainer.train(features, sampler)

    def update_weights_before_cut(self, weights):
        """Merges ``weights`` into the layers before the cut, used when testing accuracy."""
        merged = self.trainer.update_weights_before_cut(self.extract_weights(), weights)
        self.load_weights(merged)
File without changes
@@ -0,0 +1,56 @@
1
+ """
2
+ Defines the ClientCallback class, which is the abstract base class to be subclassed
3
+ when creating new client callbacks.
4
+
5
+ Defines a default callback to print local training progress.
6
+ """
7
+
8
+ from abc import ABC
9
+ import logging
10
+
11
+
12
class ClientCallback(ABC):
    """
    Base class for client callbacks.

    Every hook is a no-op here; subclasses override only the events they care about.
    """

    def on_inbound_received(self, client, inbound_processor):
        """
        Hook invoked right before the inbound processors start on the data.
        """

    def on_inbound_processed(self, client, data):
        """
        Hook invoked once the payload has gone through the inbound processors.
        """

    def on_outbound_ready(self, client, report, outbound_processor):
        """
        Hook invoked right before the outbound processors start on the data.
        """


class LogProgressCallback(ClientCallback):
    """
    Client callback that logs a progress message for each event.
    """

    def on_inbound_received(self, client, inbound_processor):
        """
        Logs that inbound processing is about to start.
        """
        message = "[%s] Start to process inbound data."
        logging.info(message, client)

    def on_inbound_processed(self, client, data):
        """
        Logs that inbound processing has finished.
        """
        message = "[%s] Inbound data has been processed."
        logging.info(message, client)

    def on_outbound_ready(self, client, report, outbound_processor):
        """
        Logs that outbound data is ready for the outbound processors.
        """
        message = "[%s] Outbound data is ready to be sent after being processed."
        logging.info(message, client)
@@ -0,0 +1,78 @@
1
+ """
2
+ Defines the :class:`CallbackHandler`, which is responsible for calling a list of callbacks.
3
+ """
4
+
5
+
6
class CallbackHandler:
    """
    The :class:`CallbackHandler` is responsible for calling a list of callbacks.
    This class calls the callbacks in the order that they are given.
    """

    def __init__(self, callbacks):
        """
        :param callbacks: an iterable of callback classes and/or callback instances.
        """
        self.callbacks = []
        self.add_callbacks(callbacks)

    def add_callbacks(self, callbacks):
        """
        Adds a list of callbacks to the callback handler.

        :param callbacks: an iterable of callback classes and/or callback instances.
        """
        for callback in callbacks:
            self.add_callback(callback)

    def add_callback(self, callback):
        """
        Adds a callback to the callback handler.

        :param callback: a callback class (instantiated here with no arguments) or
            an instance of a callback class.
        :raises ValueError: if a callback of the same class is already registered.
        """
        is_class = isinstance(callback, type)
        callback_class = callback if is_class else callback.__class__

        if callback_class in {c.__class__ for c in self.callbacks}:
            existing_callbacks = "\n".join(self.callback_list)
            raise ValueError(
                f"You attempted to add multiple instances of the callback "
                f"{callback_class}.\n"
                f"The list of callbacks already present is: {existing_callbacks}"
            )

        # Instantiate only after the duplicate check, so a rejected callback class
        # never runs its constructor (which may have side effects).
        self.callbacks.append(callback() if is_class else callback)

    def __iter__(self):
        # __iter__ must return an iterator; returning the list itself makes
        # iter(handler) raise TypeError.
        return iter(self.callbacks)

    def clear_callbacks(self):
        """
        Clears all the callbacks in the current list.
        """
        self.callbacks = []

    @property
    def callback_list(self):
        """
        Returns the class names of the currently registered callbacks.
        """
        return [cb.__class__.__name__ for cb in self.callbacks]

    def call_event(self, event, *args, **kwargs):
        """
        For each callback which has been registered, sequentially call the method
        corresponding to the given event.

        :param event: The event corresponding to the method to call on each callback.
        :param args: a list of arguments to be passed to each callback.
        :param kwargs: a list of keyword arguments to be passed to each callback.
        :raises ValueError: if a registered callback has no method named ``event``.
        """
        for callback in self.callbacks:
            try:
                method = getattr(callback, event)
            except AttributeError as exc:
                raise ValueError(
                    "The callback method has not been implemented: "
                    f"{callback.__class__.__name__}.{event}"
                ) from exc

            # Invoked outside the try block so that an AttributeError raised by the
            # callback itself propagates instead of being misreported as a missing
            # method.
            method(*args, **kwargs)
@@ -0,0 +1,139 @@
1
+ """
2
+ Defines the ServerCallback class, which is the abstract base class to be subclassed
3
+ when creating new server callbacks.
4
+
5
+ Defines a default callback to print training progress.
6
+ """
7
+
8
+ import logging
9
+ import os
10
+ from abc import ABC
11
+ from plato.config import Config
12
+ from plato.utils import csv_processor, fonts
13
+
14
+
15
class ServerCallback(ABC):
    """
    The abstract base class to be subclassed when creating new server callbacks.

    Every hook is a no-op by default; subclasses override the events they need.
    """

    def __init__(self):
        """
        Initializer.
        """

    def on_weights_received(self, server, weights_received):
        """
        Event called after the updated weights have been received.
        """

    def on_weights_aggregated(self, server, updates):
        """
        Event called after the updated weights have been aggregated.
        """

    def on_clients_selected(self, server, selected_clients, **kwargs):
        """
        Event called after clients have been selected in each round.
        """

    def on_clients_processed(self, server, **kwargs):
        """
        Event called after client reports have been processed, for any additional
        work to be performed.
        """

    def on_training_will_start(self, server, **kwargs):
        """
        Event called before selecting clients for the first round of training.
        """

    def on_server_will_close(self, server, **kwargs):
        """
        Event called at the start of closing the server.
        """
52
+
53
+
54
class LogProgressCallback(ServerCallback):
    """
    A server callback that logs progress messages and writes results to .csv files
    under ``Config().params['result_path']``.
    """

    def __init__(self):
        super().__init__()

        # `result_types` is a comma-separated list of item names to record.
        recorded_items = Config().params["result_types"]
        self.recorded_items = [x.strip() for x in recorded_items.split(",")]

        # Initialize the .csv file for logging runtime results
        result_csv_file = f"{Config().params['result_path']}/{os.getpid()}.csv"
        csv_processor.initialize_csv(
            result_csv_file, self.recorded_items, Config().params["result_path"]
        )

        logging.info(
            fonts.colourize(
                f"[{os.getpid()}] Logging runtime results to: {result_csv_file}."
            )
        )

    def on_weights_received(self, server, weights_received):
        """
        Event called after the updated weights have been received.
        """
        logging.info("[%s] Updated weights have been received.", server)

    def on_weights_aggregated(self, server, updates):
        """
        Event called after the updated weights have been aggregated.
        """
        logging.info("[%s] Finished aggregating updated weights.", server)

    def on_clients_selected(self, server, selected_clients, **kwargs):
        """
        Event called after clients have been selected in each round.

        Accepts the same extra keyword arguments as the base hook, so event calls
        that pass them do not fail.
        """

    def on_clients_processed(self, server, **kwargs):
        """
        Appends a row with the configured result items to the results .csv file
        and, when client testing and client-accuracy recording are both enabled,
        one row per client update to the accuracy .csv file.
        """
        # Take one snapshot of the logged items instead of recomputing it per item.
        logged_items = server.get_logged_items()
        new_row = [logged_items[item] for item in self.recorded_items]

        result_csv_file = f"{Config().params['result_path']}/{os.getpid()}.csv"
        csv_processor.write_csv(result_csv_file, new_row)

        clients_do_test = getattr(Config().clients, "do_test", False)
        record_clients_accuracy = hasattr(Config(), "results") and getattr(
            Config().results, "record_clients_accuracy", False
        )

        if clients_do_test and record_clients_accuracy:
            # Updates the log for client test accuracies
            accuracy_csv_file = (
                f"{Config().params['result_path']}/{os.getpid()}_accuracy.csv"
            )

            for update in server.updates:
                accuracy_row = [
                    server.current_round,
                    update.client_id,
                    update.report.accuracy,
                ]
                csv_processor.write_csv(accuracy_csv_file, accuracy_row)

        logging.info("[%s] All client reports have been processed.", server)

    def on_training_will_start(self, server, **kwargs):
        """
        Event called before selecting clients for the first round of training.
        """
        logging.info("[%s] Starting training.", server)

    def on_server_will_close(self, server, **kwargs):
        """
        Event called at the start of closing the server.
        """
        logging.info("[%s] Closing the server.", server)