plato_learn-1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (179)
  1. plato/__init__.py +1 -0
  2. plato/algorithms/__init__.py +0 -0
  3. plato/algorithms/base.py +45 -0
  4. plato/algorithms/fedavg.py +48 -0
  5. plato/algorithms/fedavg_gan.py +79 -0
  6. plato/algorithms/fedavg_personalized.py +48 -0
  7. plato/algorithms/mistnet.py +52 -0
  8. plato/algorithms/registry.py +39 -0
  9. plato/algorithms/split_learning.py +89 -0
  10. plato/callbacks/__init__.py +0 -0
  11. plato/callbacks/client.py +56 -0
  12. plato/callbacks/handler.py +78 -0
  13. plato/callbacks/server.py +139 -0
  14. plato/callbacks/trainer.py +124 -0
  15. plato/client.py +67 -0
  16. plato/clients/__init__.py +0 -0
  17. plato/clients/base.py +467 -0
  18. plato/clients/edge.py +103 -0
  19. plato/clients/fedavg_personalized.py +40 -0
  20. plato/clients/mistnet.py +49 -0
  21. plato/clients/registry.py +43 -0
  22. plato/clients/self_supervised_learning.py +51 -0
  23. plato/clients/simple.py +218 -0
  24. plato/clients/split_learning.py +150 -0
  25. plato/config.py +339 -0
  26. plato/datasources/__init__.py +0 -0
  27. plato/datasources/base.py +123 -0
  28. plato/datasources/celeba.py +150 -0
  29. plato/datasources/cifar10.py +87 -0
  30. plato/datasources/cifar100.py +61 -0
  31. plato/datasources/cinic10.py +62 -0
  32. plato/datasources/coco.py +119 -0
  33. plato/datasources/datalib/__init__.py +0 -0
  34. plato/datasources/datalib/audio_extraction_tools.py +137 -0
  35. plato/datasources/datalib/data_utils.py +124 -0
  36. plato/datasources/datalib/flickr30kE_utils.py +336 -0
  37. plato/datasources/datalib/frames_extraction_tools.py +254 -0
  38. plato/datasources/datalib/gym_utils/__init__.py +0 -0
  39. plato/datasources/datalib/gym_utils/gym_trim.py +189 -0
  40. plato/datasources/datalib/modality_data_anntation_tools.py +163 -0
  41. plato/datasources/datalib/modality_extraction_base.py +59 -0
  42. plato/datasources/datalib/parse_datasets.py +212 -0
  43. plato/datasources/datalib/refer_utils/__init__.py +0 -0
  44. plato/datasources/datalib/refer_utils/referitgame_utils.py +237 -0
  45. plato/datasources/datalib/tiny_data_tools.py +81 -0
  46. plato/datasources/datalib/video_transform.py +79 -0
  47. plato/datasources/emnist.py +64 -0
  48. plato/datasources/fashion_mnist.py +41 -0
  49. plato/datasources/feature.py +24 -0
  50. plato/datasources/feature_dataset.py +15 -0
  51. plato/datasources/femnist.py +141 -0
  52. plato/datasources/flickr30k_entities.py +362 -0
  53. plato/datasources/gym.py +431 -0
  54. plato/datasources/huggingface.py +165 -0
  55. plato/datasources/kinetics.py +568 -0
  56. plato/datasources/mnist.py +44 -0
  57. plato/datasources/multimodal_base.py +328 -0
  58. plato/datasources/pascal_voc.py +56 -0
  59. plato/datasources/purchase.py +94 -0
  60. plato/datasources/qoenflx.py +127 -0
  61. plato/datasources/referitgame.py +330 -0
  62. plato/datasources/registry.py +119 -0
  63. plato/datasources/self_supervised_learning.py +98 -0
  64. plato/datasources/stl10.py +103 -0
  65. plato/datasources/texas.py +94 -0
  66. plato/datasources/tiny_imagenet.py +64 -0
  67. plato/datasources/yolov8.py +85 -0
  68. plato/models/__init__.py +0 -0
  69. plato/models/cnn_encoder.py +103 -0
  70. plato/models/dcgan.py +116 -0
  71. plato/models/general_multilayer.py +254 -0
  72. plato/models/huggingface.py +27 -0
  73. plato/models/lenet5.py +113 -0
  74. plato/models/multilayer.py +90 -0
  75. plato/models/multimodal/__init__.py +0 -0
  76. plato/models/multimodal/base_net.py +91 -0
  77. plato/models/multimodal/blending.py +142 -0
  78. plato/models/multimodal/fc_net.py +77 -0
  79. plato/models/multimodal/fusion_net.py +78 -0
  80. plato/models/multimodal/multimodal_module.py +152 -0
  81. plato/models/registry.py +99 -0
  82. plato/models/resnet.py +190 -0
  83. plato/models/torch_hub.py +19 -0
  84. plato/models/vgg.py +113 -0
  85. plato/models/vit.py +166 -0
  86. plato/models/yolov8.py +22 -0
  87. plato/processors/__init__.py +0 -0
  88. plato/processors/base.py +35 -0
  89. plato/processors/compress.py +46 -0
  90. plato/processors/decompress.py +48 -0
  91. plato/processors/feature.py +51 -0
  92. plato/processors/feature_additive_noise.py +48 -0
  93. plato/processors/feature_dequantize.py +34 -0
  94. plato/processors/feature_gaussian.py +17 -0
  95. plato/processors/feature_laplace.py +15 -0
  96. plato/processors/feature_quantize.py +34 -0
  97. plato/processors/feature_randomized_response.py +50 -0
  98. plato/processors/feature_unbatch.py +39 -0
  99. plato/processors/inbound_feature_tensors.py +39 -0
  100. plato/processors/model.py +55 -0
  101. plato/processors/model_compress.py +34 -0
  102. plato/processors/model_decompress.py +37 -0
  103. plato/processors/model_decrypt.py +41 -0
  104. plato/processors/model_deepcopy.py +21 -0
  105. plato/processors/model_dequantize.py +18 -0
  106. plato/processors/model_dequantize_qsgd.py +61 -0
  107. plato/processors/model_encrypt.py +43 -0
  108. plato/processors/model_quantize.py +18 -0
  109. plato/processors/model_quantize_qsgd.py +82 -0
  110. plato/processors/model_randomized_response.py +34 -0
  111. plato/processors/outbound_feature_ndarrays.py +38 -0
  112. plato/processors/pipeline.py +26 -0
  113. plato/processors/registry.py +124 -0
  114. plato/processors/structured_pruning.py +57 -0
  115. plato/processors/unstructured_pruning.py +73 -0
  116. plato/samplers/__init__.py +0 -0
  117. plato/samplers/all_inclusive.py +41 -0
  118. plato/samplers/base.py +31 -0
  119. plato/samplers/dirichlet.py +81 -0
  120. plato/samplers/distribution_noniid.py +132 -0
  121. plato/samplers/iid.py +53 -0
  122. plato/samplers/label_quantity_noniid.py +119 -0
  123. plato/samplers/mixed.py +44 -0
  124. plato/samplers/mixed_label_quantity_noniid.py +128 -0
  125. plato/samplers/modality_iid.py +42 -0
  126. plato/samplers/modality_quantity_noniid.py +56 -0
  127. plato/samplers/orthogonal.py +99 -0
  128. plato/samplers/registry.py +66 -0
  129. plato/samplers/sample_quantity_noniid.py +123 -0
  130. plato/samplers/sampler_utils.py +190 -0
  131. plato/servers/__init__.py +0 -0
  132. plato/servers/base.py +1395 -0
  133. plato/servers/fedavg.py +281 -0
  134. plato/servers/fedavg_cs.py +335 -0
  135. plato/servers/fedavg_gan.py +74 -0
  136. plato/servers/fedavg_he.py +106 -0
  137. plato/servers/fedavg_personalized.py +57 -0
  138. plato/servers/mistnet.py +67 -0
  139. plato/servers/registry.py +52 -0
  140. plato/servers/split_learning.py +109 -0
  141. plato/trainers/__init__.py +0 -0
  142. plato/trainers/base.py +99 -0
  143. plato/trainers/basic.py +649 -0
  144. plato/trainers/diff_privacy.py +178 -0
  145. plato/trainers/gan.py +330 -0
  146. plato/trainers/huggingface.py +173 -0
  147. plato/trainers/loss_criterion.py +70 -0
  148. plato/trainers/lr_schedulers.py +252 -0
  149. plato/trainers/optimizers.py +53 -0
  150. plato/trainers/pascal_voc.py +80 -0
  151. plato/trainers/registry.py +44 -0
  152. plato/trainers/self_supervised_learning.py +302 -0
  153. plato/trainers/split_learning.py +305 -0
  154. plato/trainers/tracking.py +96 -0
  155. plato/trainers/yolov8.py +41 -0
  156. plato/utils/__init__.py +0 -0
  157. plato/utils/count_parameters.py +30 -0
  158. plato/utils/csv_processor.py +26 -0
  159. plato/utils/data_loaders.py +148 -0
  160. plato/utils/decorators.py +24 -0
  161. plato/utils/fonts.py +23 -0
  162. plato/utils/homo_enc.py +187 -0
  163. plato/utils/reinforcement_learning/__init__.py +0 -0
  164. plato/utils/reinforcement_learning/policies/__init__.py +0 -0
  165. plato/utils/reinforcement_learning/policies/base.py +161 -0
  166. plato/utils/reinforcement_learning/policies/ddpg.py +75 -0
  167. plato/utils/reinforcement_learning/policies/registry.py +32 -0
  168. plato/utils/reinforcement_learning/policies/sac.py +343 -0
  169. plato/utils/reinforcement_learning/policies/td3.py +485 -0
  170. plato/utils/reinforcement_learning/rl_agent.py +142 -0
  171. plato/utils/reinforcement_learning/rl_server.py +113 -0
  172. plato/utils/rl_env.py +154 -0
  173. plato/utils/s3.py +141 -0
  174. plato/utils/trainer_utils.py +21 -0
  175. plato/utils/unary_encoding.py +47 -0
  176. plato_learn-1.1.dist-info/METADATA +35 -0
  177. plato_learn-1.1.dist-info/RECORD +179 -0
  178. plato_learn-1.1.dist-info/WHEEL +4 -0
  179. plato_learn-1.1.dist-info/licenses/LICENSE +201 -0
plato/callbacks/trainer.py ADDED
@@ -0,0 +1,124 @@
+"""
+Defines the TrainerCallback class, which is the abstract base class to be subclassed
+when creating new trainer callbacks.
+
+Defines a default callback to print training progress.
+"""
+
+import logging
+import os
+from abc import ABC
+
+from plato.utils import fonts
+
+
+class TrainerCallback(ABC):
+    """
+    The abstract base class to be subclassed when creating new trainer callbacks.
+    """
+
+    def on_train_run_start(self, trainer, config, **kwargs):
+        """
+        Event called at the start of training run.
+        """
+
+    def on_train_run_end(self, trainer, config, **kwargs):
+        """
+        Event called at the end of training run.
+        """
+
+    def on_train_epoch_start(self, trainer, config, **kwargs):
+        """
+        Event called at the beginning of a training epoch.
+        """
+
+    def on_train_step_start(self, trainer, config, batch, **kwargs):
+        """
+        Event called at the beginning of a training step.
+
+        :param batch: the current batch of training data.
+        """
+
+    def on_train_step_end(self, trainer, config, batch, loss, **kwargs):
+        """
+        Event called at the end of a training step.
+
+        :param batch: the current batch of training data.
+        :param loss: the loss computed in the current batch.
+        """
+
+    def on_train_epoch_end(self, trainer, config, **kwargs):
+        """
+        Event called at the end of a training epoch.
+        """
+
+
+class LogProgressCallback(TrainerCallback):
+    """
+    A callback which prints a message at the start of each epoch, and at the end of each step.
+    """
+
+    def on_train_run_start(self, trainer, config, **kwargs):
+        """
+        Event called at the start of training run.
+        """
+        if trainer.client_id == 0:
+            logging.info(
+                "[Server #%s] Loading the dataset with size %d.",
+                os.getpid(),
+                len(list(trainer.sampler)),
+            )
+        else:
+            logging.info(
+                "[Client #%d] Loading the dataset with size %d.",
+                trainer.client_id,
+                len(list(trainer.sampler)),
+            )
+
+    def on_train_epoch_start(self, trainer, config, **kwargs):
+        """
+        Event called at the beginning of a training epoch.
+        """
+        if trainer.client_id == 0:
+            logging.info(
+                fonts.colourize(
+                    f"[Server #{os.getpid()}] Started training epoch {trainer.current_epoch}."
+                )
+            )
+        else:
+            logging.info(
+                fonts.colourize(
+                    f"[Client #{trainer.client_id}] Started training epoch {trainer.current_epoch}."
+                )
+            )
+
+    def on_train_step_end(self, trainer, config, batch=None, loss=None, **kwargs):
+        """
+        Event called at the end of a training step.
+
+        :param batch: the current batch of training data.
+        :param loss: the loss computed in the current batch.
+        """
+        log_interval = 10
+
+        if batch % log_interval == 0:
+            if trainer.client_id == 0:
+                logging.info(
+                    "[Server #%d] Epoch: [%d/%d][%d/%d]\tLoss: %.6f",
+                    os.getpid(),
+                    trainer.current_epoch,
+                    config["epochs"],
+                    batch,
+                    len(trainer.train_loader),
+                    loss.data.item(),
+                )
+            else:
+                logging.info(
+                    "[Client #%d] Epoch: [%d/%d][%d/%d]\tLoss: %.6f",
+                    trainer.client_id,
+                    trainer.current_epoch,
+                    config["epochs"],
+                    batch,
+                    len(trainer.train_loader),
+                    loss.data.item(),
+                )
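For orientation, new trainer callbacks are written by subclassing TrainerCallback and overriding only the events of interest. The sketch below is a minimal illustration and not part of the package: MyLossLogger is a hypothetical name, and the idea that it is attached through a callbacks=[...] list mirrors how the client class later in this diff feeds callback classes to a CallbackHandler; the trainer-side wiring is assumed, not shown here.

import logging

from plato.callbacks.trainer import TrainerCallback


class MyLossLogger(TrainerCallback):
    """Accumulates per-step losses and reports their mean at the end of a run."""

    def __init__(self):
        self.losses = []

    def on_train_step_end(self, trainer, config, batch=None, loss=None, **kwargs):
        # Record the loss reported for the current batch, if any.
        if loss is not None:
            self.losses.append(loss.data.item())

    def on_train_run_end(self, trainer, config, **kwargs):
        if self.losses:
            logging.info(
                "[Client #%d] Mean loss over the run: %.6f",
                trainer.client_id,
                sum(self.losses) / len(self.losses),
            )

Such a class would then be passed alongside LogProgressCallback wherever a CallbackHandler is constructed, for example via callbacks=[MyLossLogger] on a trainer or client that supports it.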
plato/client.py ADDED
@@ -0,0 +1,67 @@
+"""
+Starting point for a Plato federated learning client.
+"""
+
+import asyncio
+import logging
+import os
+
+from plato.clients import registry as client_registry
+from plato.config import Config
+
+
+def run(client_id, port, client=None, edge_server=None, edge_client=None, trainer=None):
+    """Starting a client to connect to the server."""
+    Config().args.id = client_id
+    if port is not None:
+        Config().args.port = port
+
+    # If a server needs to be running concurrently
+    if Config().is_edge_server():
+        Config().trainer = Config().trainer._replace(
+            rounds=Config().algorithm.local_rounds
+        )
+
+        if edge_server is None:
+            from plato.clients import edge
+            from plato.servers import fedavg_cs
+
+            server = fedavg_cs.Server()
+            client = edge.Client(server)
+        else:
+            # A customized edge server
+            if trainer is not None:
+                server = edge_server(trainer=trainer())
+            else:
+                server = edge_server()
+            client = edge_client(server)
+
+        server.configure()
+        client.configure()
+
+        logging.info("Starting an edge server as client #%d", Config().args.id)
+        asyncio.ensure_future(client.start_client())
+
+        logging.info(
+            "Starting an edge server as server #%d on port %d",
+            os.getpid(),
+            Config().args.port,
+        )
+        server.start(port=Config().args.port)
+
+    else:
+        if client is None:
+            client = client_registry.get()
+            logging.info("Starting a %s client #%d.", Config().clients.type, client_id)
+        else:
+            client.client_id = client_id
+            logging.info("Starting a custom client #%d", client_id)
+
+        client.configure()
+
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(client.start_client())
+
+
+if __name__ == "__main__":
+    run(Config().args.id, Config().args.port)
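As a usage note, plato/client.py is the entry point that launches a single client process. The launcher sketch below is an assumption built from the run() signature above rather than an example shipped with the package: it presumes a configuration file has already been supplied through Plato's usual command-line arguments (so that Config() can be constructed) and that plato.clients.simple.Client can be instantiated with its defaults.

from plato import client
from plato.clients import simple


def main():
    # Hand a stock client to run(); run() sets its client_id, calls
    # configure(), and drives start_client() inside an asyncio event loop,
    # as shown in plato/client.py above.
    custom_client = simple.Client()
    client.run(client_id=1, port=None, client=custom_client)


if __name__ == "__main__":
    main()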
plato/clients/base.py ADDED
@@ -0,0 +1,467 @@
+"""
+The base class for all federated learning clients on edge devices or edge servers.
+"""
+
+import asyncio
+import logging
+import os
+import pickle
+import re
+import sys
+import time
+import uuid
+from abc import abstractmethod
+
+import numpy as np
+import socketio
+
+from plato.callbacks.client import LogProgressCallback
+from plato.callbacks.handler import CallbackHandler
+from plato.config import Config
+from plato.utils import s3
+
+
+# pylint: disable=unused-argument, protected-access
+class ClientEvents(socketio.AsyncClientNamespace):
+    """A custom namespace for socketio.AsyncClient."""
+
+    def __init__(self, namespace, plato_client):
+        super().__init__(namespace)
+        self.plato_client = plato_client
+        self.client_id = plato_client.client_id
+
+    async def on_connect(self):
+        """Upon a new connection to the server."""
+        logging.info("[Client #%d] Connected to the server.", self.client_id)
+
+    async def on_disconnect(self):
+        """Upon a disconnection event."""
+        logging.info(
+            "[Client #%d] The server disconnected the connection.", self.client_id
+        )
+        self.plato_client._clear_checkpoint_files()
+        os._exit(0)
+
+    async def on_connect_error(self, data):
+        """Upon a failed connection attempt to the server."""
+        logging.info(
+            "[Client #%d] A connection attempt to the server failed.", self.client_id
+        )
+
+    async def on_payload_to_arrive(self, data):
+        """New payload is about to arrive from the server."""
+        await self.plato_client._payload_to_arrive(data["response"])
+
+    async def on_request_update(self, data):
+        """The server is requesting an urgent model update."""
+        await self.plato_client._request_update(data)
+
+    async def on_chunk(self, data):
+        """A chunk of data from the server arrived."""
+        await self.plato_client._chunk_arrived(data["data"])
+
+    async def on_payload(self, data):
+        """A portion of the new payload from the server arrived."""
+        await self.plato_client._payload_arrived(data["id"])
+
+    async def on_payload_done(self, data):
+        """All of the new payload sent from the server arrived."""
+        if "s3_key" in data:
+            await self.plato_client._payload_done(data["id"], s3_key=data["s3_key"])
+        else:
+            await self.plato_client._payload_done(data["id"])
+
+
+class Client:
+    """A basic federated learning client."""
+
+    def __init__(self, callbacks=None) -> None:
+        self.client_id = Config().args.id
+        self.current_round = 0
+        self.sio = None
+        self.chunks = []
+        self.server_payload = None
+        self.s3_client = None
+        self.outbound_processor = None
+        self.inbound_processor = None
+        self.payload = None
+        self.report = None
+
+        self.processing_time = 0
+
+        self.comm_simulation = (
+            Config().clients.comm_simulation
+            if hasattr(Config().clients, "comm_simulation")
+            else True
+        )
+
+        if hasattr(Config().algorithm, "cross_silo") and not Config().is_edge_server():
+            self.edge_server_id = None
+
+            assert hasattr(Config().algorithm, "total_silos")
+
+        # Starting from the default client callback class, add all supplied client callbacks
+        self.callbacks = [LogProgressCallback]
+        if callbacks is not None:
+            self.callbacks.extend(callbacks)
+        self.callback_handler = CallbackHandler(self.callbacks)
+
+    def __repr__(self):
+        return f"Client #{self.client_id}"
+
+    async def start_client(self) -> None:
+        """Startup function for a client."""
+        if hasattr(Config().algorithm, "cross_silo") and not Config().is_edge_server():
+            # Contact one of the edge servers
+            self.edge_server_id = self.get_edge_server_id()
+
+            logging.info(
+                "[Client #%d] Contacting Edge Server #%d.",
+                self.client_id,
+                self.edge_server_id,
+            )
+        else:
+            await asyncio.sleep(5)
+            logging.info("[Client #%d] Contacting the server.", self.client_id)
+
+        self.sio = socketio.AsyncClient(reconnection=True)
+        self.sio.register_namespace(ClientEvents(namespace="/", plato_client=self))
+
+        if hasattr(Config().server, "s3_endpoint_url"):
+            self.s3_client = s3.S3()
+
+        if hasattr(Config().server, "use_https"):
+            uri = f"https://{Config().server.address}"
+        else:
+            uri = f"http://{Config().server.address}"
+
+        if hasattr(Config().server, "port"):
+            # If we are not using a production server deployed in the cloud
+            if (
+                hasattr(Config().algorithm, "cross_silo")
+                and not Config().is_edge_server()
+            ):
+                uri = f"{uri}:{int(Config().server.port) + int(self.edge_server_id)}"
+            else:
+                uri = f"{uri}:{Config().server.port}"
+
+        logging.info("[%s] Connecting to the server at %s.", self, uri)
+        await self.sio.connect(uri, wait_timeout=600)
+        await self.sio.emit("client_alive", {"pid": os.getpid(), "id": self.client_id})
+
+        logging.info("[Client #%d] Waiting to be selected.", self.client_id)
+        await self.sio.wait()
+
+    def get_edge_server_id(self):
+        """Returns the edge server id of the client in cross-silo FL."""
+        launched_client_num = (
+            min(
+                Config().trainer.max_concurrency
+                * max(1, Config().gpu_count())
+                * Config().algorithm.total_silos,
+                Config().clients.per_round,
+            )
+            if hasattr(Config().trainer, "max_concurrency")
+            else Config().clients.per_round
+        )
+
+        edges_launched_clients = [
+            len(i)
+            for i in np.array_split(
+                np.arange(launched_client_num), Config().algorithm.total_silos
+            )
+        ]
+
+        total = 0
+        for i, count in enumerate(edges_launched_clients):
+            total += count
+            if self.client_id <= total:
+                return i + 1 + Config().clients.total_clients
+
+    async def _payload_to_arrive(self, response) -> None:
+        """Upon receiving a response from the server."""
+        self.current_round = response["current_round"]
+
+        # Update (virtual) client id for client, trainer and algorithm
+        self.client_id = response["id"]
+
+        logging.info("[Client #%d] Selected by the server.", self.client_id)
+
+        self.process_server_response(response)
+
+        self._load_data()
+        self.configure()
+        self._allocate_data()
+
+        self.server_payload = None
+
+        if self.comm_simulation:
+            payload_filename = response["payload_filename"]
+            with open(payload_filename, "rb") as payload_file:
+                self.server_payload = pickle.load(payload_file)
+
+            payload_size = sys.getsizeof(pickle.dumps(self.server_payload))
+
+            logging.info(
+                "[%s] Received %.2f MB of payload data from the server (simulated).",
+                self,
+                payload_size / 1024**2,
+            )
+
+            await self._handle_payload(self.server_payload)
+
+    async def _handle_payload(self, inbound_payload):
+        """Handles the inbound payload upon receiving it from the server."""
+        self.inbound_received(self.inbound_processor)
+        self.callback_handler.call_event(
+            "on_inbound_received", self, self.inbound_processor
+        )
+
+        tic = time.perf_counter()
+        processed_inbound_payload = self.inbound_processor.process(inbound_payload)
+        self.processing_time = time.perf_counter() - tic
+
+        # Inbound data is processed, computing outbound response
+        report, outbound_payload = await self.inbound_processed(
+            processed_inbound_payload
+        )
+        self.callback_handler.call_event(
+            "on_inbound_processed", self, processed_inbound_payload
+        )
+
+        # Outbound data is ready to be processed
+        tic = time.perf_counter()
+        self.outbound_ready(report, self.outbound_processor)
+        self.callback_handler.call_event(
+            "on_outbound_ready", self, report, self.outbound_processor
+        )
+        processed_outbound_payload = self.outbound_processor.process(outbound_payload)
+        self.processing_time += time.perf_counter() - tic
+        report.processing_time = self.processing_time
+
+        # Sending the client report as metadata to the server (payload to follow)
+        await self.sio.emit(
+            "client_report", {"id": self.client_id, "report": pickle.dumps(report)}
+        )
+
+        # Sending the client training payload to the server
+        await self._send(processed_outbound_payload)
+
+    def inbound_received(self, inbound_processor):
+        """
+        Override this method to complete additional tasks before the inbound processors start to
+        process the data received from the server.
+        """
+
+    async def inbound_processed(self, processed_inbound_payload):
+        """
+        Override this method to conduct customized operations to generate a client's response to
+        the server when inbound payload from the server has been processed.
+        """
+        report, outbound_payload = await self._start_training(processed_inbound_payload)
+        return report, outbound_payload
+
+    def outbound_ready(self, report, outbound_processor):
+        """
+        Override this method to complete additional tasks before the outbound processors start
+        to process the data to be sent to the server.
+        """
+
+    async def _chunk_arrived(self, data) -> None:
+        """Upon receiving a chunk of data from the server."""
+        self.chunks.append(data)
+
+    async def _request_update(self, data) -> None:
+        """Upon receiving a request for an urgent model update."""
+        logging.info(
+            "[Client #%s] Urgent request received for model update at time %s.",
+            data["client_id"],
+            data["time"],
+        )
+
+        report, payload = await self._obtain_model_update(
+            client_id=data["client_id"],
+            requested_time=data["time"],
+        )
+
+        # Process outbound data when necessary
+        self.callback_handler.call_event(
+            "on_outbound_ready", self, report, self.outbound_processor
+        )
+        self.outbound_ready(report, self.outbound_processor)
+        payload = self.outbound_processor.process(payload)
+
+        # Sending the client report as metadata to the server (payload to follow)
+        await self.sio.emit(
+            "client_report", {"id": self.client_id, "report": pickle.dumps(report)}
+        )
+
+        # Sending the client training payload to the server
+        await self._send(payload)
+
+    async def _payload_arrived(self, client_id) -> None:
+        """Upon receiving a portion of the new payload from the server."""
+        assert client_id == self.client_id
+
+        payload = b"".join(self.chunks)
+        _data = pickle.loads(payload)
+        self.chunks = []
+
+        if self.server_payload is None:
+            self.server_payload = _data
+        elif isinstance(self.server_payload, list):
+            self.server_payload.append(_data)
+        else:
+            self.server_payload = [self.server_payload]
+            self.server_payload.append(_data)
+
+    async def _payload_done(self, client_id, s3_key=None) -> None:
+        """Upon receiving all the new payload from the server."""
+        payload_size = 0
+
+        if s3_key is None:
+            if isinstance(self.server_payload, list):
+                for _data in self.server_payload:
+                    payload_size += sys.getsizeof(pickle.dumps(_data))
+            elif isinstance(self.server_payload, dict):
+                for key, value in self.server_payload.items():
+                    payload_size += sys.getsizeof(pickle.dumps({key: value}))
+            else:
+                payload_size = sys.getsizeof(pickle.dumps(self.server_payload))
+        else:
+            self.server_payload = self.s3_client.receive_from_s3(s3_key)
+            payload_size = sys.getsizeof(pickle.dumps(self.server_payload))
+
+        assert client_id == self.client_id
+
+        logging.info(
+            "[Client #%d] Received %.2f MB of payload data from the server.",
+            client_id,
+            payload_size / 1024**2,
+        )
+
+        await self._handle_payload(self.server_payload)
+
+    async def _start_training(self, inbound_payload):
+        """Complete one round of training on this client."""
+        self._load_payload(inbound_payload)
+
+        report, outbound_payload = await self._train()
+
+        if Config().is_edge_server():
+            logging.info(
+                "[Server #%d] Model aggregated on edge server (%s).", os.getpid(), self
+            )
+        else:
+            logging.info("[%s] Model trained.", self)
+
+        return report, outbound_payload
+
+    async def _send_in_chunks(self, data) -> None:
+        """Sending a bytes object in fixed-sized chunks to the server."""
+        step = 1024**2
+        chunks = [data[i : i + step] for i in range(0, len(data), step)]
+
+        for chunk in chunks:
+            await self.sio.emit("chunk", {"data": chunk})
+
+        await self.sio.emit("client_payload", {"id": self.client_id})
+
+    async def _send(self, payload) -> None:
+        """Sending the client payload to the server using simulation, S3 or socket.io."""
+        if self.comm_simulation:
+            # If we are using the filesystem to simulate communication over a network
+            model_name = (
+                Config().trainer.model_name
+                if hasattr(Config().trainer, "model_name")
+                else "custom"
+            )
+            if "/" in model_name:
+                model_name = model_name.replace("/", "_")
+            checkpoint_path = Config().params["checkpoint_path"]
+            payload_filename = (
+                f"{checkpoint_path}/{model_name}_client_{self.client_id}.pth"
+            )
+            with open(payload_filename, "wb") as payload_file:
+                pickle.dump(payload, payload_file)
+
+            data_size = sys.getsizeof(pickle.dumps(payload))
+
+            logging.info(
+                "[%s] Sent %.2f MB of payload data to the server (simulated).",
+                self,
+                data_size / 1024**2,
+            )
+
+        else:
+            metadata = {"id": self.client_id}
+
+            if self.s3_client is not None:
+                unique_key = uuid.uuid4().hex[:6].upper()
+                s3_key = f"client_payload_{self.client_id}_{unique_key}"
+                self.s3_client.send_to_s3(s3_key, payload)
+                data_size = sys.getsizeof(pickle.dumps(payload))
+                metadata["s3_key"] = s3_key
+            else:
+                if isinstance(payload, list):
+                    data_size: int = 0
+
+                    for data in payload:
+                        _data = pickle.dumps(data)
+                        await self._send_in_chunks(_data)
+                        data_size += sys.getsizeof(_data)
+                else:
+                    _data = pickle.dumps(payload)
+                    await self._send_in_chunks(_data)
+                    data_size = sys.getsizeof(_data)
+
+            await self.sio.emit("client_payload_done", metadata)
+
+            logging.info(
+                "[%s] Sent %.2f MB of payload data to the server.",
+                self,
+                data_size / 1024**2,
+            )
+
+    def _clear_checkpoint_files(self):
+        """Delete all the temporary checkpoint files created by the client."""
+        model_path = Config().params["model_path"]
+        for filename in os.listdir(model_path):
+            split = re.match(
+                r"(?P<client_id>\d+)_(?P<epoch>\d+)_(?P<training_time>\d+.\d+).pth",
+                filename,
+            )
+            if split is not None:
+                file_path = f"{model_path}/{filename}"
+                os.remove(file_path)
+
+    def add_callbacks(self, callbacks):
+        """Adds a list of callbacks to the client callback handler."""
+        self.callback_handler.add_callbacks(callbacks)
+
+    @abstractmethod
+    async def _train(self):
+        """The machine learning training workload on a client."""
+
+    @abstractmethod
+    def configure(self) -> None:
+        """Prepare this client for training."""
+
+    @abstractmethod
+    def _load_data(self) -> None:
+        """Generating data and loading them onto this client."""
+
+    @abstractmethod
+    def _allocate_data(self) -> None:
+        """Allocate training or testing dataset of this client."""
+
+    @abstractmethod
+    def _load_payload(self, server_payload) -> None:
+        """Loading the payload onto this client."""
+
+    def process_server_response(self, server_response) -> None:
+        """Additional client-specific processing on the server response."""
+
+    @abstractmethod
+    async def _obtain_model_update(self, client_id, requested_time):
+        """Retrieving a model update corresponding to a particular wall clock time."""