plato-learn 1.1 (plato_learn-1.1-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (179)
  1. plato/__init__.py +1 -0
  2. plato/algorithms/__init__.py +0 -0
  3. plato/algorithms/base.py +45 -0
  4. plato/algorithms/fedavg.py +48 -0
  5. plato/algorithms/fedavg_gan.py +79 -0
  6. plato/algorithms/fedavg_personalized.py +48 -0
  7. plato/algorithms/mistnet.py +52 -0
  8. plato/algorithms/registry.py +39 -0
  9. plato/algorithms/split_learning.py +89 -0
  10. plato/callbacks/__init__.py +0 -0
  11. plato/callbacks/client.py +56 -0
  12. plato/callbacks/handler.py +78 -0
  13. plato/callbacks/server.py +139 -0
  14. plato/callbacks/trainer.py +124 -0
  15. plato/client.py +67 -0
  16. plato/clients/__init__.py +0 -0
  17. plato/clients/base.py +467 -0
  18. plato/clients/edge.py +103 -0
  19. plato/clients/fedavg_personalized.py +40 -0
  20. plato/clients/mistnet.py +49 -0
  21. plato/clients/registry.py +43 -0
  22. plato/clients/self_supervised_learning.py +51 -0
  23. plato/clients/simple.py +218 -0
  24. plato/clients/split_learning.py +150 -0
  25. plato/config.py +339 -0
  26. plato/datasources/__init__.py +0 -0
  27. plato/datasources/base.py +123 -0
  28. plato/datasources/celeba.py +150 -0
  29. plato/datasources/cifar10.py +87 -0
  30. plato/datasources/cifar100.py +61 -0
  31. plato/datasources/cinic10.py +62 -0
  32. plato/datasources/coco.py +119 -0
  33. plato/datasources/datalib/__init__.py +0 -0
  34. plato/datasources/datalib/audio_extraction_tools.py +137 -0
  35. plato/datasources/datalib/data_utils.py +124 -0
  36. plato/datasources/datalib/flickr30kE_utils.py +336 -0
  37. plato/datasources/datalib/frames_extraction_tools.py +254 -0
  38. plato/datasources/datalib/gym_utils/__init__.py +0 -0
  39. plato/datasources/datalib/gym_utils/gym_trim.py +189 -0
  40. plato/datasources/datalib/modality_data_anntation_tools.py +163 -0
  41. plato/datasources/datalib/modality_extraction_base.py +59 -0
  42. plato/datasources/datalib/parse_datasets.py +212 -0
  43. plato/datasources/datalib/refer_utils/__init__.py +0 -0
  44. plato/datasources/datalib/refer_utils/referitgame_utils.py +237 -0
  45. plato/datasources/datalib/tiny_data_tools.py +81 -0
  46. plato/datasources/datalib/video_transform.py +79 -0
  47. plato/datasources/emnist.py +64 -0
  48. plato/datasources/fashion_mnist.py +41 -0
  49. plato/datasources/feature.py +24 -0
  50. plato/datasources/feature_dataset.py +15 -0
  51. plato/datasources/femnist.py +141 -0
  52. plato/datasources/flickr30k_entities.py +362 -0
  53. plato/datasources/gym.py +431 -0
  54. plato/datasources/huggingface.py +165 -0
  55. plato/datasources/kinetics.py +568 -0
  56. plato/datasources/mnist.py +44 -0
  57. plato/datasources/multimodal_base.py +328 -0
  58. plato/datasources/pascal_voc.py +56 -0
  59. plato/datasources/purchase.py +94 -0
  60. plato/datasources/qoenflx.py +127 -0
  61. plato/datasources/referitgame.py +330 -0
  62. plato/datasources/registry.py +119 -0
  63. plato/datasources/self_supervised_learning.py +98 -0
  64. plato/datasources/stl10.py +103 -0
  65. plato/datasources/texas.py +94 -0
  66. plato/datasources/tiny_imagenet.py +64 -0
  67. plato/datasources/yolov8.py +85 -0
  68. plato/models/__init__.py +0 -0
  69. plato/models/cnn_encoder.py +103 -0
  70. plato/models/dcgan.py +116 -0
  71. plato/models/general_multilayer.py +254 -0
  72. plato/models/huggingface.py +27 -0
  73. plato/models/lenet5.py +113 -0
  74. plato/models/multilayer.py +90 -0
  75. plato/models/multimodal/__init__.py +0 -0
  76. plato/models/multimodal/base_net.py +91 -0
  77. plato/models/multimodal/blending.py +142 -0
  78. plato/models/multimodal/fc_net.py +77 -0
  79. plato/models/multimodal/fusion_net.py +78 -0
  80. plato/models/multimodal/multimodal_module.py +152 -0
  81. plato/models/registry.py +99 -0
  82. plato/models/resnet.py +190 -0
  83. plato/models/torch_hub.py +19 -0
  84. plato/models/vgg.py +113 -0
  85. plato/models/vit.py +166 -0
  86. plato/models/yolov8.py +22 -0
  87. plato/processors/__init__.py +0 -0
  88. plato/processors/base.py +35 -0
  89. plato/processors/compress.py +46 -0
  90. plato/processors/decompress.py +48 -0
  91. plato/processors/feature.py +51 -0
  92. plato/processors/feature_additive_noise.py +48 -0
  93. plato/processors/feature_dequantize.py +34 -0
  94. plato/processors/feature_gaussian.py +17 -0
  95. plato/processors/feature_laplace.py +15 -0
  96. plato/processors/feature_quantize.py +34 -0
  97. plato/processors/feature_randomized_response.py +50 -0
  98. plato/processors/feature_unbatch.py +39 -0
  99. plato/processors/inbound_feature_tensors.py +39 -0
  100. plato/processors/model.py +55 -0
  101. plato/processors/model_compress.py +34 -0
  102. plato/processors/model_decompress.py +37 -0
  103. plato/processors/model_decrypt.py +41 -0
  104. plato/processors/model_deepcopy.py +21 -0
  105. plato/processors/model_dequantize.py +18 -0
  106. plato/processors/model_dequantize_qsgd.py +61 -0
  107. plato/processors/model_encrypt.py +43 -0
  108. plato/processors/model_quantize.py +18 -0
  109. plato/processors/model_quantize_qsgd.py +82 -0
  110. plato/processors/model_randomized_response.py +34 -0
  111. plato/processors/outbound_feature_ndarrays.py +38 -0
  112. plato/processors/pipeline.py +26 -0
  113. plato/processors/registry.py +124 -0
  114. plato/processors/structured_pruning.py +57 -0
  115. plato/processors/unstructured_pruning.py +73 -0
  116. plato/samplers/__init__.py +0 -0
  117. plato/samplers/all_inclusive.py +41 -0
  118. plato/samplers/base.py +31 -0
  119. plato/samplers/dirichlet.py +81 -0
  120. plato/samplers/distribution_noniid.py +132 -0
  121. plato/samplers/iid.py +53 -0
  122. plato/samplers/label_quantity_noniid.py +119 -0
  123. plato/samplers/mixed.py +44 -0
  124. plato/samplers/mixed_label_quantity_noniid.py +128 -0
  125. plato/samplers/modality_iid.py +42 -0
  126. plato/samplers/modality_quantity_noniid.py +56 -0
  127. plato/samplers/orthogonal.py +99 -0
  128. plato/samplers/registry.py +66 -0
  129. plato/samplers/sample_quantity_noniid.py +123 -0
  130. plato/samplers/sampler_utils.py +190 -0
  131. plato/servers/__init__.py +0 -0
  132. plato/servers/base.py +1395 -0
  133. plato/servers/fedavg.py +281 -0
  134. plato/servers/fedavg_cs.py +335 -0
  135. plato/servers/fedavg_gan.py +74 -0
  136. plato/servers/fedavg_he.py +106 -0
  137. plato/servers/fedavg_personalized.py +57 -0
  138. plato/servers/mistnet.py +67 -0
  139. plato/servers/registry.py +52 -0
  140. plato/servers/split_learning.py +109 -0
  141. plato/trainers/__init__.py +0 -0
  142. plato/trainers/base.py +99 -0
  143. plato/trainers/basic.py +649 -0
  144. plato/trainers/diff_privacy.py +178 -0
  145. plato/trainers/gan.py +330 -0
  146. plato/trainers/huggingface.py +173 -0
  147. plato/trainers/loss_criterion.py +70 -0
  148. plato/trainers/lr_schedulers.py +252 -0
  149. plato/trainers/optimizers.py +53 -0
  150. plato/trainers/pascal_voc.py +80 -0
  151. plato/trainers/registry.py +44 -0
  152. plato/trainers/self_supervised_learning.py +302 -0
  153. plato/trainers/split_learning.py +305 -0
  154. plato/trainers/tracking.py +96 -0
  155. plato/trainers/yolov8.py +41 -0
  156. plato/utils/__init__.py +0 -0
  157. plato/utils/count_parameters.py +30 -0
  158. plato/utils/csv_processor.py +26 -0
  159. plato/utils/data_loaders.py +148 -0
  160. plato/utils/decorators.py +24 -0
  161. plato/utils/fonts.py +23 -0
  162. plato/utils/homo_enc.py +187 -0
  163. plato/utils/reinforcement_learning/__init__.py +0 -0
  164. plato/utils/reinforcement_learning/policies/__init__.py +0 -0
  165. plato/utils/reinforcement_learning/policies/base.py +161 -0
  166. plato/utils/reinforcement_learning/policies/ddpg.py +75 -0
  167. plato/utils/reinforcement_learning/policies/registry.py +32 -0
  168. plato/utils/reinforcement_learning/policies/sac.py +343 -0
  169. plato/utils/reinforcement_learning/policies/td3.py +485 -0
  170. plato/utils/reinforcement_learning/rl_agent.py +142 -0
  171. plato/utils/reinforcement_learning/rl_server.py +113 -0
  172. plato/utils/rl_env.py +154 -0
  173. plato/utils/s3.py +141 -0
  174. plato/utils/trainer_utils.py +21 -0
  175. plato/utils/unary_encoding.py +47 -0
  176. plato_learn-1.1.dist-info/METADATA +35 -0
  177. plato_learn-1.1.dist-info/RECORD +179 -0
  178. plato_learn-1.1.dist-info/WHEEL +4 -0
  179. plato_learn-1.1.dist-info/licenses/LICENSE +201 -0
plato/clients/edge.py ADDED
@@ -0,0 +1,103 @@
+ """
+ A federated learning client at the edge server in a cross-silo training workload.
+ """
+
+ import time
+ from types import SimpleNamespace
+
+ from plato.clients import simple
+ from plato.config import Config
+ from plato.processors import registry as processor_registry
+
+
+ class Client(simple.Client):
+     """A federated learning client at the edge server in a cross-silo training workload."""
+
+     def __init__(
+         self,
+         server,
+         model=None,
+         datasource=None,
+         algorithm=None,
+         trainer=None,
+         callbacks=None,
+     ):
+         super().__init__(
+             model=model,
+             datasource=datasource,
+             algorithm=algorithm,
+             trainer=trainer,
+             callbacks=callbacks,
+         )
+         self.server = server
+
+     def configure(self) -> None:
+         """Prepare this edge client for training."""
+         super().configure()
+
+         # Pass inbound and outbound data payloads through processors for
+         # additional data processing
+         self.outbound_processor, self.inbound_processor = processor_registry.get(
+             "Client", client_id=self.client_id, trainer=self.server.trainer
+         )
+
+     def load_data(self) -> None:
+         """The edge client does not need to train models using local data."""
+
+     def _load_payload(self, server_payload) -> None:
+         """The edge client loads the model from the central server."""
+         self.server.algorithm.load_weights(server_payload)
+
+     def process_server_response(self, server_response):
+         """Additional client-specific processing on the server response."""
+         if "current_global_round" in server_response:
+             self.server.current_global_round = server_response["current_global_round"]
+
+     async def _train(self):
+         """The aggregation workload on an edge client."""
+         training_start_time = time.perf_counter()
+         # Signal edge server to select clients to start a new round of local aggregation
+         self.server.new_global_round_begins.set()
+
+         # Wait for the edge server to finish model aggregation
+         await self.server.model_aggregated.wait()
+         self.server.model_aggregated.clear()
+
+         # Extract model weights and biases
+         weights = self.server.algorithm.extract_weights()
+
+         average_accuracy = self.server.average_accuracy
+         accuracy = self.server.accuracy
+
+         if (
+             hasattr(Config().clients, "sleep_simulation")
+             and Config().clients.sleep_simulation
+         ):
+             training_time = self.server.edge_training_time
+             self.server.edge_training_time = 0
+         else:
+             training_time = time.perf_counter() - training_start_time
+
+         comm_time = time.time()
+
+         edge_server_comm_time = self.server.edge_comm_time
+         self.server.edge_comm_time = 0
+
+         # Generate a report for the central server
+         report = SimpleNamespace(
+             client_id=self.client_id,
+             num_samples=self.server.total_samples,
+             accuracy=accuracy,
+             training_time=training_time,
+             comm_time=comm_time,
+             update_response=False,
+             average_accuracy=average_accuracy,
+             edge_server_comm_overhead=self.server.comm_overhead,
+             edge_server_comm_time=edge_server_comm_time,
+         )
+
+         self._report = self.customize_report(report)
+
+         self.server.comm_overhead = 0
+
+         return self._report, weights
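The round trip in _train() above is driven by a pair of asyncio events shared with the co-located edge server. A minimal, self-contained sketch of that handshake (ToyEdgeServer below is an illustrative stand-in, not part of the package):

import asyncio

class ToyEdgeServer:
    """Illustrative stand-in for the cross-silo edge server's aggregation loop."""

    def __init__(self):
        self.new_global_round_begins = asyncio.Event()
        self.model_aggregated = asyncio.Event()

    async def aggregate_when_signalled(self):
        # Woken by the edge client at the start of a global round
        await self.new_global_round_begins.wait()
        self.new_global_round_begins.clear()
        # ... local aggregation over institutional clients would happen here ...
        self.model_aggregated.set()  # wakes the edge client's _train()

async def edge_client_side(server):
    # Mirrors the first half of _train() above
    server.new_global_round_begins.set()
    await server.model_aggregated.wait()
    server.model_aggregated.clear()

async def main():
    server = ToyEdgeServer()
    await asyncio.gather(server.aggregate_when_signalled(), edge_client_side(server))

asyncio.run(main())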
plato/clients/fedavg_personalized.py ADDED
@@ -0,0 +1,40 @@
+ """
+ A personalized federated learning client that saves its local layers before
+ sending the shared global model to the server after local training.
+ """
+
+ from collections import OrderedDict
+
+ from plato.clients import simple
+ from plato.config import Config
+
+
+ class Client(simple.Client):
+     """
+     A personalized federated learning client that saves its local layers before sending the
+     shared global model to the server after local training.
+     """
+
+     def outbound_ready(self, report, outbound_processor):
+         super().outbound_ready(report, outbound_processor)
+         weights = self.algorithm.extract_weights()
+
+         # Save local layers before giving them to the outbound processor
+         if hasattr(Config().algorithm, "local_layer_names"):
+             # Extract weights of desired local layers
+             local_layers = OrderedDict(
+                 [
+                     (name, param)
+                     for name, param in weights.items()
+                     if any(
+                         param_name in name.strip().split(".")
+                         for param_name in Config().algorithm.local_layer_names
+                     )
+                 ]
+             )
+
+             model_path = Config().params["model_path"]
+             model_name = Config().trainer.model_name
+             filename = f"{model_path}/{model_name}_{self.client_id}_local_layers.pth"
+
+             self.algorithm.save_local_layers(local_layers, filename)
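The comprehension in outbound_ready() treats a parameter as local when any configured layer name matches one of the dotted components of the parameter's name. A standalone illustration with made-up parameter names:

from collections import OrderedDict

weights = OrderedDict([("conv1.weight", 0), ("fc.weight", 1), ("fc.bias", 2)])
local_layer_names = ["fc"]  # illustrative stand-in for Config().algorithm.local_layer_names

local_layers = OrderedDict(
    (name, param)
    for name, param in weights.items()
    if any(n in name.strip().split(".") for n in local_layer_names)
)

assert list(local_layers) == ["fc.weight", "fc.bias"]  # conv1.weight is filtered out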
plato/clients/mistnet.py ADDED
@@ -0,0 +1,49 @@
+ """
+ A federated learning client for MistNet.
+
+ Reference:
+
+ P. Wang, et al. "MistNet: Towards Private Neural Network Training with Local
+ Differential Privacy," found in docs/papers.
+ """
+
+ import logging
+ import time
+ from types import SimpleNamespace
+
+ from plato.config import Config
+ from plato.clients import simple
+
+
+ class Client(simple.Client):
+     """A federated learning client for MistNet."""
+
+     async def _train(self):
+         """A MistNet client only uses the first several layers in a forward pass."""
+         logging.info("Training on MistNet client #%d", self.client_id)
+
+         # Since training is performed on the server, the client should not be doing
+         # its own testing for the model accuracy
+         assert not Config().clients.do_test
+
+         tic = time.perf_counter()
+
+         # Perform a forward pass till the cut layer in the model
+         features = self.algorithm.extract_features(self.trainset, self.sampler)
+
+         training_time = time.perf_counter() - tic
+
+         # Generate a report for the server, performing model testing if applicable
+         comm_time = time.time()
+         return (
+             SimpleNamespace(
+                 client_id=self.client_id,
+                 num_samples=self.sampler.num_samples(),
+                 accuracy=0,
+                 training_time=training_time,
+                 comm_time=comm_time,
+                 update_response=False,
+                 payload_length=len(features),
+             ),
+             features,
+         )
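extract_features() runs the forward pass only up to a cut layer, so the payload sent to the server is intermediate features rather than raw samples. A hedged PyTorch sketch of the idea (the toy model and cut index are illustrative; the package's actual cut-layer logic lives in plato/algorithms/mistnet.py):

import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Conv2d(1, 6, kernel_size=5),   # stays on the client
    nn.ReLU(),                        # stays on the client
    nn.Flatten(),                     # would run on the server
    nn.Linear(6 * 24 * 24, 10),       # would run on the server
)
cut_layer = 2  # the client computes only the first two layers

client_half = nn.Sequential(*list(model.children())[:cut_layer])
with torch.no_grad():
    features = client_half(torch.randn(4, 1, 28, 28))

print(features.shape)  # torch.Size([4, 6, 24, 24]): the features payload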
plato/clients/registry.py ADDED
@@ -0,0 +1,43 @@
+ """
+ The registry that contains all available federated learning clients.
+
+ Having a registry of all available classes is convenient for retrieving an instance based
+ on a configuration at run-time.
+ """
+
+ import logging
+
+ from plato.config import Config
+ from plato.clients import (
+     self_supervised_learning,
+     simple,
+     mistnet,
+     fedavg_personalized,
+     split_learning,
+ )
+
+ registered_clients = {
+     "simple": simple.Client,
+     "mistnet": mistnet.Client,
+     "fedavg_personalized": fedavg_personalized.Client,
+     "self_supervised_learning": self_supervised_learning.Client,
+     "split_learning": split_learning.Client,
+ }
+
+
+ def get(model=None, datasource=None, algorithm=None, trainer=None):
+     """Get an instance of the client."""
+     if hasattr(Config().clients, "type"):
+         client_type = Config().clients.type
+     else:
+         client_type = Config().algorithm.type
+
+     if client_type in registered_clients:
+         logging.info("Client: %s", client_type)
+         registered_client = registered_clients[client_type](
+             model=model, datasource=datasource, algorithm=algorithm, trainer=trainer
+         )
+     else:
+         raise ValueError(f"No such client: {client_type}")
+
+     return registered_client
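A hedged usage sketch, assuming a session whose loaded configuration sets clients.type to "simple": get() logs the resolved type and returns an instance of the matching class. Because registered_clients is a module-level dict, a custom type can also be slotted in before get() is called (a pattern the module permits rather than a documented API):

from plato.clients import registry, simple

class MyClient(simple.Client):
    """A hypothetical custom client."""

registry.registered_clients["my_client"] = MyClient

# With clients.type set to "simple" (or "my_client") in the configuration:
# client = registry.get()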
plato/clients/self_supervised_learning.py ADDED
@@ -0,0 +1,51 @@
+ """
+ A self-supervised learning (SSL) client prepares a personalized datasource for
+ the personalization process, which will be performed after finishing the FL
+ training process with SSL.
+
+ Specifically, the conventional FL training process with SSL trains the model
+ with the datasource and objective function of SSL. Yet, the datasource used in
+ personalization should be a supervised learning one. Therefore, a client needs
+ to prepare the personalized datasource.
+ """
+
+ from plato.datasources import registry as datasources_registry
+ from plato.clients import simple
+
+
+ class Client(simple.Client):
+     """An SSL client to prepare the datasource for personalization."""
+
+     def __init__(
+         self,
+         model=None,
+         datasource=None,
+         algorithm=None,
+         trainer=None,
+         callbacks=None,
+         trainer_callbacks=None,
+     ):
+         super().__init__(
+             model=model,
+             datasource=datasource,
+             algorithm=algorithm,
+             trainer=trainer,
+             callbacks=callbacks,
+             trainer_callbacks=trainer_callbacks,
+         )
+         # The datasource used in personalization
+         self.personalized_datasource = None
+
+     def configure(self) -> None:
+         """Prepare this client for training."""
+         super().configure()
+
+         # Get the personalized datasource
+         if self.personalized_datasource is None:
+             self.personalized_datasource = datasources_registry.get()
+
+         # Set the train and the test set for the trainer
+         self.trainer.set_personalized_datasets(
+             self.personalized_datasource.get_train_set(),
+             self.personalized_datasource.get_test_set(),
+         )
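A hedged sketch (not package API) of substituting a custom supervised datasource for personalization: preset the personalized_datasource attribute before configure() runs, so the registry lookup above is skipped while the custom splits are still handed to the trainer. MyDatasource is a hypothetical stand-in:

class MyDatasource:
    """Stand-in for a plato datasource exposing train and test splits."""

    def get_train_set(self):
        return ...  # a labeled training dataset

    def get_test_set(self):
        return ...  # a labeled test dataset

# Inside a running Plato session (configuration already loaded):
# client = self_supervised_learning.Client()
# client.personalized_datasource = MyDatasource()
# client.configure()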
plato/clients/simple.py ADDED
@@ -0,0 +1,218 @@
+ """
+ A basic federated learning client that sends weight updates to the server.
+ """
+
+ import logging
+ import time
+ from types import SimpleNamespace
+
+ from plato.algorithms import registry as algorithms_registry
+ from plato.clients import base
+ from plato.config import Config
+ from plato.datasources import registry as datasources_registry
+ from plato.processors import registry as processor_registry
+ from plato.samplers import registry as samplers_registry
+ from plato.trainers import registry as trainers_registry
+ from plato.utils import fonts
+
+
+ class Client(base.Client):
+     """A basic federated learning client that sends simple weight updates."""
+
+     def __init__(
+         self,
+         model=None,
+         datasource=None,
+         algorithm=None,
+         trainer=None,
+         callbacks=None,
+         trainer_callbacks=None,
+     ):
+         super().__init__(callbacks=callbacks)
+         # Save the callbacks that will be passed to the trainer later
+         self.trainer_callbacks = trainer_callbacks
+
+         self.custom_model = model
+         self.model = None
+
+         self.custom_datasource = datasource
+         self.datasource = None
+
+         self.custom_algorithm = algorithm
+         self.algorithm = None
+
+         self.custom_trainer = trainer
+         self.trainer = None
+
+         self.trainset = None  # Training dataset
+         self.testset = None  # Testing dataset
+         self.sampler = None  # Sampler for the training set
+         self.testset_sampler = None  # Sampler for the test set
+
+         self._report = None
+
+     def configure(self) -> None:
+         """Prepares this client for training."""
+         super().configure()
+
+         if self.model is None and self.custom_model is not None:
+             self.model = self.custom_model
+
+         if self.trainer is None and self.custom_trainer is None:
+             self.trainer = trainers_registry.get(
+                 model=self.model, callbacks=self.trainer_callbacks
+             )
+         elif self.trainer is None and self.custom_trainer is not None:
+             self.trainer = self.custom_trainer(
+                 model=self.model, callbacks=self.trainer_callbacks
+             )
+
+         self.trainer.set_client_id(self.client_id)
+
+         if self.algorithm is None and self.custom_algorithm is None:
+             self.algorithm = algorithms_registry.get(trainer=self.trainer)
+         elif self.algorithm is None and self.custom_algorithm is not None:
+             self.algorithm = self.custom_algorithm(trainer=self.trainer)
+
+         self.algorithm.set_client_id(self.client_id)
+
+         # Pass inbound and outbound data payloads through processors for
+         # additional data processing
+         self.outbound_processor, self.inbound_processor = processor_registry.get(
+             "Client", client_id=self.client_id, trainer=self.trainer
+         )
+
+         # Setting up the data sampler
+         if self.datasource:
+             self.sampler = samplers_registry.get(self.datasource, self.client_id)
+
+             if (
+                 hasattr(Config().clients, "do_test")
+                 and Config().clients.do_test
+                 and hasattr(Config().data, "testset_sampler")
+             ):
+                 # Set the sampler for the test set
+                 self.testset_sampler = samplers_registry.get(
+                     self.datasource, self.client_id, testing=True
+                 )
+
+     def _load_data(self) -> None:
+         """Generates data and loads it onto this client."""
+         # The only case where Config().data.reload_data is set to true is
+         # when clients with different client IDs need to load from different datasets,
+         # such as in the pre-partitioned Federated EMNIST dataset. We do not support
+         # reloading data from a custom datasource at this time.
+         if self.datasource is None or (
+             hasattr(Config().data, "reload_data") and Config().data.reload_data
+         ):
+             logging.info("[%s] Loading its data source...", self)
+
+             if self.custom_datasource is None:
+                 self.datasource = datasources_registry.get(client_id=self.client_id)
+             else:
+                 self.datasource = self.custom_datasource()
+
+             logging.info(
+                 "[%s] Dataset size: %s", self, self.datasource.num_train_examples()
+             )
+
+     def _allocate_data(self) -> None:
+         """Allocates the training or testing dataset of this client."""
+         # PyTorch uses samplers when loading data with a data loader
+         self.trainset = self.datasource.get_train_set()
+
+         if hasattr(Config().clients, "do_test") and Config().clients.do_test:
+             # Set the testset if local testing is needed
+             self.testset = self.datasource.get_test_set()
+
+     def _load_payload(self, server_payload) -> None:
+         """Loads the server model onto this client."""
+         self.algorithm.load_weights(server_payload)
+
+     async def _train(self):
+         """The machine learning training workload on a client."""
+         logging.info(
+             fonts.colourize(
+                 f"[{self}] Started training in communication round #{self.current_round}."
+             )
+         )
+
+         # Perform model training
+         try:
+             if hasattr(self.trainer, "current_round"):
+                 self.trainer.current_round = self.current_round
+             training_time = self.trainer.train(self.trainset, self.sampler)
+         except ValueError as exc:
+             logging.info(
+                 fonts.colourize(f"[{self}] Error occurred during training: {exc}")
+             )
+             await self.sio.disconnect()
+
+         # Extract model weights and biases
+         weights = self.algorithm.extract_weights()
+
+         # Generate a report for the server, performing model testing if applicable
+         if (hasattr(Config().clients, "do_test") and Config().clients.do_test) and (
+             not hasattr(Config().clients, "test_interval")
+             or self.current_round % Config().clients.test_interval == 0
+         ):
+             accuracy = self.trainer.test(self.testset, self.testset_sampler)
+
+             if accuracy == -1:
+                 # The testing process failed; disconnect from the server
+                 logging.info(
+                     fonts.colourize(
+                         f"[{self}] Accuracy is -1 when testing. Disconnecting from the server."
+                     )
+                 )
+                 await self.sio.disconnect()
+
+             if hasattr(Config().trainer, "target_perplexity"):
+                 logging.info("[%s] Test perplexity: %.2f", self, accuracy)
+             else:
+                 logging.info("[%s] Test accuracy: %.2f%%", self, 100 * accuracy)
+         else:
+             accuracy = 0
+
+         comm_time = time.time()
+
+         if (
+             hasattr(Config().clients, "sleep_simulation")
+             and Config().clients.sleep_simulation
+         ):
+             sleep_seconds = Config().client_sleep_times[self.client_id - 1]
+             avg_training_time = Config().clients.avg_training_time
+
+             training_time = (
+                 avg_training_time + sleep_seconds
+             ) * Config().trainer.epochs
+
+         report = SimpleNamespace(
+             client_id=self.client_id,
+             num_samples=self.sampler.num_samples(),
+             accuracy=accuracy,
+             training_time=training_time,
+             comm_time=comm_time,
+             update_response=False,
+         )
+
+         self._report = self.customize_report(report)
+
+         return self._report, weights
+
+     async def _obtain_model_update(self, client_id, requested_time):
+         """Retrieves a model update corresponding to a particular wall-clock time."""
+         model = self.trainer.obtain_model_update(client_id, requested_time)
+         weights = self.algorithm.extract_weights(model)
+         self._report.comm_time = time.time()
+         self._report.client_id = client_id
+         self._report.update_response = True
+
+         return self._report, weights
+
+     def customize_report(self, report: SimpleNamespace) -> SimpleNamespace:
+         """Customizes the report with any additional information."""
+         return report
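When sleep_simulation is enabled, the reported training time is synthesized from configuration values rather than measured. A worked instance of the formula above, with illustrative numbers:

# Illustrative values, not taken from any shipped configuration:
avg_training_time = 2.0  # Config().clients.avg_training_time, seconds per epoch
sleep_seconds = 1.5      # Config().client_sleep_times[client_id - 1]
epochs = 5               # Config().trainer.epochs

training_time = (avg_training_time + sleep_seconds) * epochs
assert training_time == 17.5  # seconds reported to the server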
plato/clients/split_learning.py ADDED
@@ -0,0 +1,150 @@
+ """
+ A federated learning client using split learning.
+
+ Reference:
+
+ Vepakomma, et al., "Split Learning for Health: Distributed Deep Learning without Sharing
+ Raw Patient Data," in Proc. AI for Social Good Workshop, affiliated with ICLR 2018.
+
+ https://arxiv.org/pdf/1812.00564.pdf
+
+ Chopra, Ayush, et al. "AdaSplit: Adaptive Trade-offs for Resource-constrained Distributed
+ Deep Learning." arXiv preprint arXiv:2112.01637 (2021).
+
+ https://arxiv.org/pdf/2112.01637.pdf
+ """
+
+ import logging
+ import time
+ from types import SimpleNamespace
+
+ from plato.clients import simple
+ from plato.config import Config
+ from plato.utils import fonts
+
+
+ class Client(simple.Client):
+     """A split learning client."""
+
+     # pylint:disable=too-many-arguments
+     def __init__(
+         self,
+         model=None,
+         datasource=None,
+         algorithm=None,
+         trainer=None,
+         callbacks=None,
+     ):
+         super().__init__(
+             model=model,
+             datasource=datasource,
+             algorithm=algorithm,
+             trainer=trainer,
+             callbacks=callbacks,
+         )
+         assert not Config().clients.do_test
+
+         self.model_received = False
+         self.gradient_received = False
+         self.contexts = {}
+         self.original_weights = None
+
+         # Iteration control
+         self.iterations = Config().clients.iteration
+         self.iter_left = Config().clients.iteration
+
+         # The sampler cannot be reconfigured, otherwise the same training samples
+         # would be selected every round
+         self.static_sampler = None
+
+     async def inbound_processed(self, processed_inbound_payload):
+         """Extracts features or completes the training using split learning."""
+         server_payload, info = processed_inbound_payload
+
+         # Preparing the client response
+         report, payload = None, None
+
+         if info == "prompt":
+             # The server prompts a new client to conduct split learning
+             self._load_context(self.client_id)
+             report, payload = self._extract_features()
+         elif info == "gradients":
+             # The server sends the gradients of the features, i.e., completes the training
+             logging.warning("[%s] Gradients received; completing the training.", self)
+             training_time, weights = self._complete_training(server_payload)
+             self.iter_left -= 1
+
+             if self.iter_left == 0:
+                 logging.warning(
+                     "[%s] Finished training, sending weights to the server.", self
+                 )
+                 # Send the weights to the server for evaluation
+                 report = SimpleNamespace(
+                     client_id=self.client_id,
+                     num_samples=self.sampler.num_samples(),
+                     accuracy=0,
+                     training_time=training_time,
+                     comm_time=time.time(),
+                     update_response=False,
+                     type="weights",
+                 )
+                 payload = weights
+                 self.iter_left = self.iterations
+             else:
+                 # Continue feature extraction
+                 report, payload = self._extract_features()
+                 report.training_time += training_time
+
+         # Save the state of the current client
+         self._save_context(self.client_id)
+         return report, payload
+
+     def _save_context(self, client_id):
+         """Saves the extracted weights and the data sampler for a given client."""
+         # The sampler needs to be saved, otherwise the same data samples would be
+         # selected every round
+         self.contexts[client_id] = (
+             self.algorithm.extract_weights(),
+             self.static_sampler,
+         )
+
+     def _load_context(self, client_id):
+         """Loads the client's model weights and sampler from the last selected round."""
+         if client_id not in self.contexts:
+             if self.original_weights is None:
+                 self.original_weights = self.algorithm.extract_weights()
+             self.algorithm.load_weights(self.original_weights)
+             self.static_sampler = self.sampler.get()
+         else:
+             weights, sampler = self.contexts.pop(client_id)
+             self.algorithm.load_weights(weights)
+             self.static_sampler = sampler
+
+     def _extract_features(self):
+         """Extracts features up to the cut layer."""
+         round_number = self.iterations - self.iter_left + 1
+         logging.warning(
+             fonts.colourize(
+                 f"[{self}] Started split learning in round #{round_number}/{self.iterations}"
+                 + f" (Global round {self.current_round})."
+             )
+         )
+
+         features, training_time = self.algorithm.extract_features(
+             self.trainset, self.static_sampler
+         )
+         report = SimpleNamespace(
+             client_id=self.client_id,
+             num_samples=self.sampler.num_samples(),
+             accuracy=0,
+             training_time=training_time,
+             comm_time=time.time(),
+             update_response=False,
+             type="features",
+         )
+         return report, features
+
+     def _complete_training(self, payload):
+         """Completes the training based on the gradients received from the server."""
+         training_time = self.algorithm.complete_train(payload)
+         weights = self.algorithm.extract_weights()
+         return training_time, weights
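The per-client context cache maintained by _save_context() and _load_context() reduces to a pop-on-load, save-on-return life cycle. A standalone sketch of that life cycle:

contexts = {}

def save_context(client_id, weights, sampler):
    # Called at the end of every inbound_processed() round trip
    contexts[client_id] = (weights, sampler)

def load_context(client_id, original_weights, fresh_sampler):
    # First selection: start from the original weights and a fresh sampler;
    # later selections: resume from the stashed state (pop avoids stale reuse)
    if client_id not in contexts:
        return original_weights, fresh_sampler
    return contexts.pop(client_id)

w, s = load_context(7, "w0", "sampler0")   # round 1: nothing cached yet
save_context(7, "w1", s)
assert load_context(7, "w0", "sampler0") == ("w1", "sampler0")  # round 2 resumes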