plato_learn-1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (179)
  1. plato/__init__.py +1 -0
  2. plato/algorithms/__init__.py +0 -0
  3. plato/algorithms/base.py +45 -0
  4. plato/algorithms/fedavg.py +48 -0
  5. plato/algorithms/fedavg_gan.py +79 -0
  6. plato/algorithms/fedavg_personalized.py +48 -0
  7. plato/algorithms/mistnet.py +52 -0
  8. plato/algorithms/registry.py +39 -0
  9. plato/algorithms/split_learning.py +89 -0
  10. plato/callbacks/__init__.py +0 -0
  11. plato/callbacks/client.py +56 -0
  12. plato/callbacks/handler.py +78 -0
  13. plato/callbacks/server.py +139 -0
  14. plato/callbacks/trainer.py +124 -0
  15. plato/client.py +67 -0
  16. plato/clients/__init__.py +0 -0
  17. plato/clients/base.py +467 -0
  18. plato/clients/edge.py +103 -0
  19. plato/clients/fedavg_personalized.py +40 -0
  20. plato/clients/mistnet.py +49 -0
  21. plato/clients/registry.py +43 -0
  22. plato/clients/self_supervised_learning.py +51 -0
  23. plato/clients/simple.py +218 -0
  24. plato/clients/split_learning.py +150 -0
  25. plato/config.py +339 -0
  26. plato/datasources/__init__.py +0 -0
  27. plato/datasources/base.py +123 -0
  28. plato/datasources/celeba.py +150 -0
  29. plato/datasources/cifar10.py +87 -0
  30. plato/datasources/cifar100.py +61 -0
  31. plato/datasources/cinic10.py +62 -0
  32. plato/datasources/coco.py +119 -0
  33. plato/datasources/datalib/__init__.py +0 -0
  34. plato/datasources/datalib/audio_extraction_tools.py +137 -0
  35. plato/datasources/datalib/data_utils.py +124 -0
  36. plato/datasources/datalib/flickr30kE_utils.py +336 -0
  37. plato/datasources/datalib/frames_extraction_tools.py +254 -0
  38. plato/datasources/datalib/gym_utils/__init__.py +0 -0
  39. plato/datasources/datalib/gym_utils/gym_trim.py +189 -0
  40. plato/datasources/datalib/modality_data_anntation_tools.py +163 -0
  41. plato/datasources/datalib/modality_extraction_base.py +59 -0
  42. plato/datasources/datalib/parse_datasets.py +212 -0
  43. plato/datasources/datalib/refer_utils/__init__.py +0 -0
  44. plato/datasources/datalib/refer_utils/referitgame_utils.py +237 -0
  45. plato/datasources/datalib/tiny_data_tools.py +81 -0
  46. plato/datasources/datalib/video_transform.py +79 -0
  47. plato/datasources/emnist.py +64 -0
  48. plato/datasources/fashion_mnist.py +41 -0
  49. plato/datasources/feature.py +24 -0
  50. plato/datasources/feature_dataset.py +15 -0
  51. plato/datasources/femnist.py +141 -0
  52. plato/datasources/flickr30k_entities.py +362 -0
  53. plato/datasources/gym.py +431 -0
  54. plato/datasources/huggingface.py +165 -0
  55. plato/datasources/kinetics.py +568 -0
  56. plato/datasources/mnist.py +44 -0
  57. plato/datasources/multimodal_base.py +328 -0
  58. plato/datasources/pascal_voc.py +56 -0
  59. plato/datasources/purchase.py +94 -0
  60. plato/datasources/qoenflx.py +127 -0
  61. plato/datasources/referitgame.py +330 -0
  62. plato/datasources/registry.py +119 -0
  63. plato/datasources/self_supervised_learning.py +98 -0
  64. plato/datasources/stl10.py +103 -0
  65. plato/datasources/texas.py +94 -0
  66. plato/datasources/tiny_imagenet.py +64 -0
  67. plato/datasources/yolov8.py +85 -0
  68. plato/models/__init__.py +0 -0
  69. plato/models/cnn_encoder.py +103 -0
  70. plato/models/dcgan.py +116 -0
  71. plato/models/general_multilayer.py +254 -0
  72. plato/models/huggingface.py +27 -0
  73. plato/models/lenet5.py +113 -0
  74. plato/models/multilayer.py +90 -0
  75. plato/models/multimodal/__init__.py +0 -0
  76. plato/models/multimodal/base_net.py +91 -0
  77. plato/models/multimodal/blending.py +142 -0
  78. plato/models/multimodal/fc_net.py +77 -0
  79. plato/models/multimodal/fusion_net.py +78 -0
  80. plato/models/multimodal/multimodal_module.py +152 -0
  81. plato/models/registry.py +99 -0
  82. plato/models/resnet.py +190 -0
  83. plato/models/torch_hub.py +19 -0
  84. plato/models/vgg.py +113 -0
  85. plato/models/vit.py +166 -0
  86. plato/models/yolov8.py +22 -0
  87. plato/processors/__init__.py +0 -0
  88. plato/processors/base.py +35 -0
  89. plato/processors/compress.py +46 -0
  90. plato/processors/decompress.py +48 -0
  91. plato/processors/feature.py +51 -0
  92. plato/processors/feature_additive_noise.py +48 -0
  93. plato/processors/feature_dequantize.py +34 -0
  94. plato/processors/feature_gaussian.py +17 -0
  95. plato/processors/feature_laplace.py +15 -0
  96. plato/processors/feature_quantize.py +34 -0
  97. plato/processors/feature_randomized_response.py +50 -0
  98. plato/processors/feature_unbatch.py +39 -0
  99. plato/processors/inbound_feature_tensors.py +39 -0
  100. plato/processors/model.py +55 -0
  101. plato/processors/model_compress.py +34 -0
  102. plato/processors/model_decompress.py +37 -0
  103. plato/processors/model_decrypt.py +41 -0
  104. plato/processors/model_deepcopy.py +21 -0
  105. plato/processors/model_dequantize.py +18 -0
  106. plato/processors/model_dequantize_qsgd.py +61 -0
  107. plato/processors/model_encrypt.py +43 -0
  108. plato/processors/model_quantize.py +18 -0
  109. plato/processors/model_quantize_qsgd.py +82 -0
  110. plato/processors/model_randomized_response.py +34 -0
  111. plato/processors/outbound_feature_ndarrays.py +38 -0
  112. plato/processors/pipeline.py +26 -0
  113. plato/processors/registry.py +124 -0
  114. plato/processors/structured_pruning.py +57 -0
  115. plato/processors/unstructured_pruning.py +73 -0
  116. plato/samplers/__init__.py +0 -0
  117. plato/samplers/all_inclusive.py +41 -0
  118. plato/samplers/base.py +31 -0
  119. plato/samplers/dirichlet.py +81 -0
  120. plato/samplers/distribution_noniid.py +132 -0
  121. plato/samplers/iid.py +53 -0
  122. plato/samplers/label_quantity_noniid.py +119 -0
  123. plato/samplers/mixed.py +44 -0
  124. plato/samplers/mixed_label_quantity_noniid.py +128 -0
  125. plato/samplers/modality_iid.py +42 -0
  126. plato/samplers/modality_quantity_noniid.py +56 -0
  127. plato/samplers/orthogonal.py +99 -0
  128. plato/samplers/registry.py +66 -0
  129. plato/samplers/sample_quantity_noniid.py +123 -0
  130. plato/samplers/sampler_utils.py +190 -0
  131. plato/servers/__init__.py +0 -0
  132. plato/servers/base.py +1395 -0
  133. plato/servers/fedavg.py +281 -0
  134. plato/servers/fedavg_cs.py +335 -0
  135. plato/servers/fedavg_gan.py +74 -0
  136. plato/servers/fedavg_he.py +106 -0
  137. plato/servers/fedavg_personalized.py +57 -0
  138. plato/servers/mistnet.py +67 -0
  139. plato/servers/registry.py +52 -0
  140. plato/servers/split_learning.py +109 -0
  141. plato/trainers/__init__.py +0 -0
  142. plato/trainers/base.py +99 -0
  143. plato/trainers/basic.py +649 -0
  144. plato/trainers/diff_privacy.py +178 -0
  145. plato/trainers/gan.py +330 -0
  146. plato/trainers/huggingface.py +173 -0
  147. plato/trainers/loss_criterion.py +70 -0
  148. plato/trainers/lr_schedulers.py +252 -0
  149. plato/trainers/optimizers.py +53 -0
  150. plato/trainers/pascal_voc.py +80 -0
  151. plato/trainers/registry.py +44 -0
  152. plato/trainers/self_supervised_learning.py +302 -0
  153. plato/trainers/split_learning.py +305 -0
  154. plato/trainers/tracking.py +96 -0
  155. plato/trainers/yolov8.py +41 -0
  156. plato/utils/__init__.py +0 -0
  157. plato/utils/count_parameters.py +30 -0
  158. plato/utils/csv_processor.py +26 -0
  159. plato/utils/data_loaders.py +148 -0
  160. plato/utils/decorators.py +24 -0
  161. plato/utils/fonts.py +23 -0
  162. plato/utils/homo_enc.py +187 -0
  163. plato/utils/reinforcement_learning/__init__.py +0 -0
  164. plato/utils/reinforcement_learning/policies/__init__.py +0 -0
  165. plato/utils/reinforcement_learning/policies/base.py +161 -0
  166. plato/utils/reinforcement_learning/policies/ddpg.py +75 -0
  167. plato/utils/reinforcement_learning/policies/registry.py +32 -0
  168. plato/utils/reinforcement_learning/policies/sac.py +343 -0
  169. plato/utils/reinforcement_learning/policies/td3.py +485 -0
  170. plato/utils/reinforcement_learning/rl_agent.py +142 -0
  171. plato/utils/reinforcement_learning/rl_server.py +113 -0
  172. plato/utils/rl_env.py +154 -0
  173. plato/utils/s3.py +141 -0
  174. plato/utils/trainer_utils.py +21 -0
  175. plato/utils/unary_encoding.py +47 -0
  176. plato_learn-1.1.dist-info/METADATA +35 -0
  177. plato_learn-1.1.dist-info/RECORD +179 -0
  178. plato_learn-1.1.dist-info/WHEEL +4 -0
  179. plato_learn-1.1.dist-info/licenses/LICENSE +201 -0
plato/servers/fedavg_he.py ADDED
@@ -0,0 +1,106 @@
+ """
+ A federated learning server using federated averaging to aggregate updates after homomorphic encryption.
+ """
+
+ from functools import reduce
+ from plato.servers import fedavg
+ from plato.utils import homo_enc
+
+
+ class Server(fedavg.Server):
+     """
+     Federated learning server using federated averaging to aggregate updates after homomorphic
+     encryption.
+     """
+
+     def __init__(
+         self, model=None, datasource=None, algorithm=None, trainer=None, callbacks=None
+     ):
+         super().__init__(
+             model=model,
+             datasource=datasource,
+             algorithm=algorithm,
+             trainer=trainer,
+             callbacks=callbacks,
+         )
+
+         self.context = homo_enc.get_ckks_context()
+         self.encrypted_model = None
+         self.weight_shapes = {}
+         self.para_nums = {}
+
+     def configure(self) -> None:
+         """Configure the model information, such as weight shapes and parameter numbers."""
+         super().configure()
+
+         extract_model = self.trainer.model.cpu().state_dict()
+
+         for key in extract_model.keys():
+             self.weight_shapes[key] = extract_model[key].size()
+             self.para_nums[key] = extract_model[key].numel()
+
+         self.encrypted_model = homo_enc.encrypt_weights(
+             extract_model, True, self.context, []
+         )
+
+     def customize_server_payload(self, payload):
+         """The server can only send the encrypted aggregation result to clients."""
+         return self.encrypted_model
+
+     # pylint: disable=unused-argument
+     async def aggregate_weights(self, updates, baseline_weights, weights_received):
+         """Aggregate the model updates and decrypt the result for evaluation purposes."""
+         self.encrypted_model = self._fedavg_hybrid(updates)
+
+         # Decrypt the model weights to compute the test accuracy
+         decrypted_weights = homo_enc.decrypt_weights(
+             self.encrypted_model, self.weight_shapes, self.para_nums
+         )
+         # Serialize the encrypted weights after decryption
+         self.encrypted_model["encrypted_weights"] = self.encrypted_model[
+             "encrypted_weights"
+         ].serialize()
+
+         return decrypted_weights
+
+     def _fedavg_hybrid(self, updates):
+         """Aggregate the model updates in the hybrid form of encrypted and unencrypted weights."""
+         weights_received = [
+             homo_enc.deserialize_weights(update.payload, self.context)
+             for update in updates
+         ]
+         unencrypted_weights = [
+             homo_enc.extract_encrypted_model(x)[0] for x in weights_received
+         ]
+         encrypted_weights = [
+             homo_enc.extract_encrypted_model(x)[1] for x in weights_received
+         ]
+         # Assert that the encrypted weights from all clients are aligned
+         indices = [homo_enc.extract_encrypted_model(x)[2] for x in weights_received]
+         for i in range(1, len(indices)):
+             assert indices[i] == indices[0]
+         encrypt_indices = indices[0]
+
+         # Extract the total number of samples
+         self.total_samples = sum(update.report.num_samples for update in updates)
+
+         # Perform weighted averaging on the unencrypted and the encrypted weights
+         unencrypted_avg_update = self.trainer.zeros(unencrypted_weights[0].size())
+         encrypted_avg_update = self.trainer.zeros(encrypted_weights[0].size())
+
+         for i, (unenc_w, enc_w) in enumerate(
+             zip(unencrypted_weights, encrypted_weights)
+         ):
+             report = updates[i].report
+             num_samples = report.num_samples
+
+             unencrypted_avg_update += unenc_w * (num_samples / self.total_samples)
+             encrypted_avg_update += enc_w * (num_samples / self.total_samples)
+
+         if len(encrypt_indices) == 0:
+             # No weights are encrypted, so set the encrypted part to None
+             encrypted_avg_update = None
+
+         return homo_enc.wrap_encrypted_model(
+             unencrypted_avg_update, encrypted_avg_update, encrypt_indices
+         )
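The hybrid aggregation above is a plain sample-weighted mean, applied separately to the unencrypted segment and to the CKKS-encrypted segment of the flattened weights. A minimal sketch of that weighting, with ordinary torch tensors standing in for both segments; the client updates and sample counts are hypothetical:

```python
# A sketch of the sample weighting in _fedavg_hybrid(), using plain torch
# tensors in place of CKKS vectors; all values are hypothetical.
import torch

client_updates = [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])]
num_samples = [10, 30]  # per-client sample counts from the reports
total_samples = sum(num_samples)

avg_update = torch.zeros_like(client_updates[0])
for update, num in zip(client_updates, num_samples):
    avg_update += update * (num / total_samples)  # same weighting as above

print(avg_update)  # tensor([2.5000, 3.5000])
```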
plato/servers/fedavg_personalized.py ADDED
@@ -0,0 +1,57 @@
+ """
+ A personalized federated learning server that starts from a number of regular
+ rounds of federated learning. In these regular rounds, only a subset of the
+ total clients can be selected to perform the local update (the ratio of which is
+ a configuration setting). After all regular rounds are completed, it starts a
+ final round of personalization, where a selected subset of clients perform local
+ training using their local datasets.
+ """
+
+ from plato.servers import fedavg
+ from plato.config import Config
+
+
+ class Server(fedavg.Server):
+     """
+     A personalized FL server that controls how many clients will participate in
+     the training process, and that adds a final personalization round with all
+     clients sampled.
+     """
+
+     def __init__(
+         self, model=None, datasource=None, algorithm=None, trainer=None, callbacks=None
+     ):
+         super().__init__(
+             model=model,
+             datasource=datasource,
+             algorithm=algorithm,
+             trainer=trainer,
+             callbacks=callbacks,
+         )
+         # Personalization starts after the final regular round of training
+         self.personalization_started = False
+
+     def choose_clients(self, clients_pool, clients_count):
+         """Choose a subset of the clients to participate in each round."""
+         if self.current_round > Config().trainer.rounds:
+             # In the final personalization round, choose from all clients
+             return super().choose_clients(clients_pool, clients_count)
+         else:
+             ratio = Config().algorithm.personalization.participating_client_ratio
+
+             return super().choose_clients(
+                 clients_pool[: int(self.total_clients * ratio)],
+                 clients_count,
+             )
+
+     async def wrap_up(self) -> None:
+         """Wraps up when each round of training is done."""
+         if self.personalization_started:
+             await super().wrap_up()
+         else:
+             # If the target number of training rounds has been reached, start
+             # the final round of training for personalization on the clients
+             self.save_to_checkpoint()
+
+             if self.current_round >= Config().trainer.rounds:
+                 self.personalization_started = True
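During the regular rounds, choose_clients() samples only from a prefix of the client pool whose size is set by participating_client_ratio; the full pool is used once the personalization round begins. A minimal sketch of that truncation, assuming a hypothetical pool of 100 clients and a ratio of 0.4:

```python
# Illustrative only: the pool truncation performed during regular rounds.
total_clients = 100  # hypothetical Config().clients.total_clients
clients_pool = list(range(1, total_clients + 1))
ratio = 0.4          # hypothetical participating_client_ratio

regular_pool = clients_pool[: int(total_clients * ratio)]
print(len(regular_pool))  # 40 -- only clients 1..40 are eligible in regular rounds
# Once current_round > Config().trainer.rounds, the full pool of 100 is used.
```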
plato/servers/mistnet.py ADDED
@@ -0,0 +1,67 @@
+ """
+ A federated learning server for MistNet.
+
+ Reference:
+ P. Wang, et al. "MistNet: Towards Private Neural Network Training with Local
+ Differential Privacy," found in docs/papers.
+ """
+
+ import logging
+ import os
+
+ from plato.config import Config
+ from plato.datasources import feature
+ from plato.samplers import all_inclusive
+ from plato.servers import fedavg
+
+
+ class Server(fedavg.Server):
+     """The MistNet server for federated learning."""
+
+     def __init__(
+         self, model=None, datasource=None, algorithm=None, trainer=None, callbacks=None
+     ):
+         super().__init__(
+             model=model,
+             datasource=datasource,
+             algorithm=algorithm,
+             trainer=trainer,
+             callbacks=callbacks,
+         )
+
+         # MistNet requires only one round of client-server communication
+         assert Config().trainer.rounds == 1
+
+     def init_trainer(self) -> None:
+         """Setting up a pre-trained model to be loaded on the server."""
+         super().init_trainer()
+
+         model_path = Config().params["model_path"]
+         model_file_name = (
+             Config().trainer.pretrained_model
+             if hasattr(Config().trainer, "pretrained_model")
+             else f"{Config().trainer.model_name}.pth"
+         )
+         pretrained_model_path = f"{model_path}/{model_file_name}"
+
+         if os.path.exists(pretrained_model_path):
+             logging.info("[Server #%d] Loading a pre-trained model.", os.getpid())
+             self.trainer.load_model(filename=model_file_name)
+
+     async def _process_reports(self):
+         """Process the features extracted by the clients and perform server-side training."""
+         features = [update.payload for update in self.updates]
+         feature_dataset = feature.DataSource(features)
+
+         # Train the model using all the features received from the clients
+         sampler = all_inclusive.Sampler(feature_dataset)
+         self.algorithm.train(feature_dataset, sampler)
+
+         # Test the updated model
+         if not hasattr(Config().server, "do_test") or Config().server.do_test:
+             self.accuracy = self.trainer.test(self.testset)
+             logging.info(
+                 "[%s] Global model accuracy: %.2f%%\n", self, 100 * self.accuracy
+             )
+
+         self.clients_processed()
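init_trainer() falls back to <model_name>.pth when no pretrained_model setting exists, and loads the file only if it is present under model_path. A sketch of the same resolution with hypothetical configuration values standing in for Config():

```python
# Illustrative only: the filename fallback in init_trainer(); the path and
# model name below are hypothetical stand-ins for Config() settings.
import os

model_path = "./models"   # hypothetical Config().params["model_path"]
pretrained_model = None   # Config().trainer.pretrained_model, if the setting exists
model_name = "lenet5"     # hypothetical Config().trainer.model_name

model_file_name = pretrained_model if pretrained_model else f"{model_name}.pth"
pretrained_model_path = f"{model_path}/{model_file_name}"
print(pretrained_model_path)                  # ./models/lenet5.pth
print(os.path.exists(pretrained_model_path))  # loaded only if this is True
```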
plato/servers/registry.py ADDED
@@ -0,0 +1,52 @@
+ """
+ The registry for servers that contains framework-agnostic implementations of a
+ federated learning server.
+
+ Having a registry of all available classes is convenient for retrieving an
+ instance based on a configuration at run-time.
+ """
+
+ import logging
+
+ from plato.config import Config
+
+ from plato.servers import (
+     fedavg,
+     fedavg_cs,
+     mistnet,
+     fedavg_gan,
+     fedavg_personalized,
+     split_learning,
+ )
+
+ if hasattr(Config().server, "type") and Config().server.type == "fedavg_he":
+     # The FedAvg server with homomorphic encryption supports PyTorch only
+     from plato.servers import fedavg_he
+
+     registered_servers = {"fedavg_he": fedavg_he.Server}
+
+ else:
+     registered_servers = {
+         "fedavg": fedavg.Server,
+         "fedavg_cross_silo": fedavg_cs.Server,
+         "mistnet": mistnet.Server,
+         "fedavg_gan": fedavg_gan.Server,
+         "fedavg_personalized": fedavg_personalized.Server,
+         "split_learning": split_learning.Server,
+     }
+
+
+ def get(model=None, algorithm=None, trainer=None):
+     """Get an instance of the server."""
+     if hasattr(Config().server, "type"):
+         server_type = Config().server.type
+     else:
+         server_type = Config().algorithm.type
+
+     if server_type in registered_servers:
+         logging.info("Server: %s", server_type)
+         return registered_servers[server_type](
+             model=model, algorithm=algorithm, trainer=trainer
+         )
+     else:
+         raise ValueError(f"No such server: {server_type}")
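Retrieving a server through this registry needs nothing beyond the configured type. A hypothetical usage sketch, assuming a configuration whose server section sets type: fedavg has already been loaded (importing the registry module evaluates Config() at import time):

```python
# Hypothetical usage sketch; assumes a Plato configuration is already loaded.
from plato.servers import registry as server_registry

server = server_registry.get()  # resolves Config().server.type, else Config().algorithm.type
server.run()                    # starts the federated learning server (plato.servers.base)
```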
plato/servers/split_learning.py ADDED
@@ -0,0 +1,109 @@
+ """
+ A federated learning server using split learning.
+
+ Reference:
+
+ Vepakomma, et al., "Split Learning for Health: Distributed Deep Learning without Sharing
+ Raw Patient Data," in Proc. AI for Social Good Workshop, affiliated with ICLR 2018.
+
+ https://arxiv.org/pdf/1812.00564.pdf
+
+ Chopra, Ayush, et al. "AdaSplit: Adaptive Trade-offs for Resource-constrained Distributed
+ Deep Learning." arXiv preprint arXiv:2112.01637 (2021).
+
+ https://arxiv.org/pdf/2112.01637.pdf
+ """
+
+ import logging
+
+ from plato.config import Config
+ from plato.datasources import feature
+ from plato.samplers import all_inclusive
+ from plato.servers import fedavg
+ from plato.utils import fonts
+ from plato.datasources import registry as datasources_registry
+
+
+ # pylint:disable=too-many-instance-attributes
+ class Server(fedavg.Server):
+     """The split learning server."""
+
+     # pylint:disable=too-many-arguments
+     def __init__(
+         self, model=None, datasource=None, algorithm=None, trainer=None, callbacks=None
+     ):
+         super().__init__(model, datasource, algorithm, trainer, callbacks)
+         # Split learning clients interact with the server sequentially
+         assert Config().clients.per_round == 1
+         self.phase = "prompt"
+         self.clients_list = []
+         self.client_last = None
+         self.next_client = True
+         self.test_accuracy = 0.0
+
+         # Manually set up the test set, since do_test is turned off in the config
+         if self.datasource is None and self.custom_datasource is None:
+             self.datasource = datasources_registry.get(client_id=0)
+         elif self.datasource is None and self.custom_datasource is not None:
+             self.datasource = self.custom_datasource()
+         self.testset = self.datasource.get_test_set()
+         self.testset_sampler = all_inclusive.Sampler(self.datasource, testing=True)
+
+     def choose_clients(self, clients_pool, clients_count):
+         """Shuffle the clients and select them sequentially, each after the previous one is done."""
+         if len(self.clients_list) == 0 and self.next_client:
+             # Shuffle the client list
+             self.clients_list = super().choose_clients(clients_pool, len(clients_pool))
+             logging.warning("Client order: %s", str(self.clients_list))
+
+         if self.next_client:
+             # Sequentially select clients
+             self.client_last = [self.clients_list.pop(0)]
+             self.next_client = False
+         return self.client_last
+
+     def customize_server_payload(self, payload):
+         """Wrap up generating the server payload with any additional information."""
+         if self.phase == "prompt":
+             # The split learning server doesn't send weights to the client
+             return (None, "prompt")
+         return (self.trainer.get_gradients(), "gradients")
+
+     # pylint: disable=unused-argument
+     async def aggregate_weights(self, updates, baseline_weights, weights_received):
+         """Aggregate weight updates from the clients or train the model."""
+         update = updates[0]
+         report = update.report
+         if report.type == "features":
+             logging.warning("[%s] Features received; computing gradients.", self)
+             feature_dataset = feature.DataSource([update.payload])
+
+             # Train the model using all the features received from the client
+             sampler = all_inclusive.Sampler(feature_dataset)
+             self.algorithm.train(feature_dataset, sampler)
+
+             self.phase = "gradient"
+         elif report.type == "weights":
+             logging.warning("[%s] Weights received; starting to test accuracy.", self)
+             weights = update.payload
+
+             # The weights after the cut layer are not trained by clients
+             self.algorithm.update_weights_before_cut(weights)
+
+             self.test_accuracy = self.trainer.test(self.testset, self.testset_sampler)
+
+             logging.warning(
+                 fonts.colourize(
+                     f"[{self}] Global model accuracy: {100 * self.test_accuracy:.2f}%\n"
+                 )
+             )
+             self.phase = "prompt"
+             # Move on to the next client in the next round
+             self.next_client = True
+
+         updated_weights = self.algorithm.extract_weights()
+         return updated_weights
+
+     def clients_processed(self):
+         """Replace the default accuracy with the manually tested accuracy."""
+         self.accuracy = self.test_accuracy
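The server above cycles through two phases per client: it first sends a "prompt" carrying no weights, trains its own layers on the features it receives, returns gradients, and only flips back to "prompt" (and on to the next client) after receiving the client's weights. A simplified state-machine sketch of this message flow (not the Plato API; the payloads are stand-ins):

```python
# Illustrative state machine for the split-learning flow above; "<gradients>"
# is a placeholder for the real server-side gradients.
phase = "prompt"

def payload_for_client():
    # Mirrors customize_server_payload(): no weights in the prompt phase,
    # the server-side gradients otherwise.
    return (None, "prompt") if phase == "prompt" else ("<gradients>", "gradients")

def on_report(report_type):
    global phase
    if report_type == "features":
        phase = "gradient"  # train server layers, then send gradients back
    elif report_type == "weights":
        phase = "prompt"    # test accuracy, then move on to the next client

print(payload_for_client())  # (None, 'prompt')
on_report("features")
print(payload_for_client())  # ('<gradients>', 'gradients')
on_report("weights")
print(payload_for_client())  # (None, 'prompt')
```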
plato/trainers/base.py ADDED
@@ -0,0 +1,99 @@
+ """
+ Base class for trainers.
+ """
+
+ from abc import ABC, abstractmethod
+ import os
+
+ from plato.config import Config
+
+
+ class Trainer(ABC):
+     """Base class for all the trainers."""
+
+     def __init__(self):
+         self.device = Config().device()
+         self.client_id = 0
+
+     def set_client_id(self, client_id):
+         """Setting the client ID."""
+         self.client_id = client_id
+
+     @abstractmethod
+     def save_model(self, filename=None, location=None):
+         """Saving the model to a file."""
+         raise TypeError("save_model() not implemented.")
+
+     @abstractmethod
+     def load_model(self, filename=None, location=None):
+         """Loading pre-trained model weights from a file."""
+         raise TypeError("load_model() not implemented.")
+
+     @staticmethod
+     def save_accuracy(accuracy, filename=None):
+         """Saving the test accuracy to a file."""
+         model_path = Config().params["model_path"]
+         model_name = Config().trainer.model_name
+
+         if not os.path.exists(model_path):
+             os.makedirs(model_path)
+
+         if filename is not None:
+             accuracy_path = f"{model_path}/{filename}"
+         else:
+             accuracy_path = f"{model_path}/{model_name}.acc"
+
+         with open(accuracy_path, "w", encoding="utf-8") as file:
+             file.write(str(accuracy))
+
+     @staticmethod
+     def load_accuracy(filename=None):
+         """Loading the test accuracy from a file."""
+         model_path = Config().params["model_path"]
+         model_name = Config().trainer.model_name
+
+         if filename is not None:
+             accuracy_path = f"{model_path}/{filename}"
+         else:
+             accuracy_path = f"{model_path}/{model_name}.acc"
+
+         with open(accuracy_path, "r", encoding="utf-8") as file:
+             accuracy = float(file.read())
+
+         return accuracy
+
+     def pause_training(self):
+         """Remove the files belonging to running trainers."""
+         if hasattr(Config().trainer, "max_concurrency"):
+             model_name = Config().trainer.model_name
+             model_path = Config().params["model_path"]
+             model_file = f"{model_path}/{model_name}_{self.client_id}_{Config().params['run_id']}.pth"
+             accuracy_file = f"{model_path}/{model_name}_{self.client_id}_{Config().params['run_id']}.acc"
+
+             if os.path.exists(model_file):
+                 os.remove(model_file)
+                 os.remove(model_file + ".pkl")
+
+             if os.path.exists(accuracy_file):
+                 os.remove(accuracy_file)
+
+     @abstractmethod
+     def train(self, trainset, sampler, **kwargs) -> float:
+         """The main training loop in a federated learning workload.
+
+         Arguments:
+             trainset: The training dataset.
+             sampler: The sampler that extracts a partition for this client.
+
+         Returns:
+             float: The training time.
+         """
+
+     @abstractmethod
+     def test(self, testset, sampler=None, **kwargs) -> float:
+         """Testing the model using the provided test dataset.
+
+         Arguments:
+             testset: The test dataset.
+             sampler: The sampler that extracts a partition of the test dataset.
+         """