plato-learn 1.1 (py3-none-any.whl)

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (179)
  1. plato/__init__.py +1 -0
  2. plato/algorithms/__init__.py +0 -0
  3. plato/algorithms/base.py +45 -0
  4. plato/algorithms/fedavg.py +48 -0
  5. plato/algorithms/fedavg_gan.py +79 -0
  6. plato/algorithms/fedavg_personalized.py +48 -0
  7. plato/algorithms/mistnet.py +52 -0
  8. plato/algorithms/registry.py +39 -0
  9. plato/algorithms/split_learning.py +89 -0
  10. plato/callbacks/__init__.py +0 -0
  11. plato/callbacks/client.py +56 -0
  12. plato/callbacks/handler.py +78 -0
  13. plato/callbacks/server.py +139 -0
  14. plato/callbacks/trainer.py +124 -0
  15. plato/client.py +67 -0
  16. plato/clients/__init__.py +0 -0
  17. plato/clients/base.py +467 -0
  18. plato/clients/edge.py +103 -0
  19. plato/clients/fedavg_personalized.py +40 -0
  20. plato/clients/mistnet.py +49 -0
  21. plato/clients/registry.py +43 -0
  22. plato/clients/self_supervised_learning.py +51 -0
  23. plato/clients/simple.py +218 -0
  24. plato/clients/split_learning.py +150 -0
  25. plato/config.py +339 -0
  26. plato/datasources/__init__.py +0 -0
  27. plato/datasources/base.py +123 -0
  28. plato/datasources/celeba.py +150 -0
  29. plato/datasources/cifar10.py +87 -0
  30. plato/datasources/cifar100.py +61 -0
  31. plato/datasources/cinic10.py +62 -0
  32. plato/datasources/coco.py +119 -0
  33. plato/datasources/datalib/__init__.py +0 -0
  34. plato/datasources/datalib/audio_extraction_tools.py +137 -0
  35. plato/datasources/datalib/data_utils.py +124 -0
  36. plato/datasources/datalib/flickr30kE_utils.py +336 -0
  37. plato/datasources/datalib/frames_extraction_tools.py +254 -0
  38. plato/datasources/datalib/gym_utils/__init__.py +0 -0
  39. plato/datasources/datalib/gym_utils/gym_trim.py +189 -0
  40. plato/datasources/datalib/modality_data_anntation_tools.py +163 -0
  41. plato/datasources/datalib/modality_extraction_base.py +59 -0
  42. plato/datasources/datalib/parse_datasets.py +212 -0
  43. plato/datasources/datalib/refer_utils/__init__.py +0 -0
  44. plato/datasources/datalib/refer_utils/referitgame_utils.py +237 -0
  45. plato/datasources/datalib/tiny_data_tools.py +81 -0
  46. plato/datasources/datalib/video_transform.py +79 -0
  47. plato/datasources/emnist.py +64 -0
  48. plato/datasources/fashion_mnist.py +41 -0
  49. plato/datasources/feature.py +24 -0
  50. plato/datasources/feature_dataset.py +15 -0
  51. plato/datasources/femnist.py +141 -0
  52. plato/datasources/flickr30k_entities.py +362 -0
  53. plato/datasources/gym.py +431 -0
  54. plato/datasources/huggingface.py +165 -0
  55. plato/datasources/kinetics.py +568 -0
  56. plato/datasources/mnist.py +44 -0
  57. plato/datasources/multimodal_base.py +328 -0
  58. plato/datasources/pascal_voc.py +56 -0
  59. plato/datasources/purchase.py +94 -0
  60. plato/datasources/qoenflx.py +127 -0
  61. plato/datasources/referitgame.py +330 -0
  62. plato/datasources/registry.py +119 -0
  63. plato/datasources/self_supervised_learning.py +98 -0
  64. plato/datasources/stl10.py +103 -0
  65. plato/datasources/texas.py +94 -0
  66. plato/datasources/tiny_imagenet.py +64 -0
  67. plato/datasources/yolov8.py +85 -0
  68. plato/models/__init__.py +0 -0
  69. plato/models/cnn_encoder.py +103 -0
  70. plato/models/dcgan.py +116 -0
  71. plato/models/general_multilayer.py +254 -0
  72. plato/models/huggingface.py +27 -0
  73. plato/models/lenet5.py +113 -0
  74. plato/models/multilayer.py +90 -0
  75. plato/models/multimodal/__init__.py +0 -0
  76. plato/models/multimodal/base_net.py +91 -0
  77. plato/models/multimodal/blending.py +142 -0
  78. plato/models/multimodal/fc_net.py +77 -0
  79. plato/models/multimodal/fusion_net.py +78 -0
  80. plato/models/multimodal/multimodal_module.py +152 -0
  81. plato/models/registry.py +99 -0
  82. plato/models/resnet.py +190 -0
  83. plato/models/torch_hub.py +19 -0
  84. plato/models/vgg.py +113 -0
  85. plato/models/vit.py +166 -0
  86. plato/models/yolov8.py +22 -0
  87. plato/processors/__init__.py +0 -0
  88. plato/processors/base.py +35 -0
  89. plato/processors/compress.py +46 -0
  90. plato/processors/decompress.py +48 -0
  91. plato/processors/feature.py +51 -0
  92. plato/processors/feature_additive_noise.py +48 -0
  93. plato/processors/feature_dequantize.py +34 -0
  94. plato/processors/feature_gaussian.py +17 -0
  95. plato/processors/feature_laplace.py +15 -0
  96. plato/processors/feature_quantize.py +34 -0
  97. plato/processors/feature_randomized_response.py +50 -0
  98. plato/processors/feature_unbatch.py +39 -0
  99. plato/processors/inbound_feature_tensors.py +39 -0
  100. plato/processors/model.py +55 -0
  101. plato/processors/model_compress.py +34 -0
  102. plato/processors/model_decompress.py +37 -0
  103. plato/processors/model_decrypt.py +41 -0
  104. plato/processors/model_deepcopy.py +21 -0
  105. plato/processors/model_dequantize.py +18 -0
  106. plato/processors/model_dequantize_qsgd.py +61 -0
  107. plato/processors/model_encrypt.py +43 -0
  108. plato/processors/model_quantize.py +18 -0
  109. plato/processors/model_quantize_qsgd.py +82 -0
  110. plato/processors/model_randomized_response.py +34 -0
  111. plato/processors/outbound_feature_ndarrays.py +38 -0
  112. plato/processors/pipeline.py +26 -0
  113. plato/processors/registry.py +124 -0
  114. plato/processors/structured_pruning.py +57 -0
  115. plato/processors/unstructured_pruning.py +73 -0
  116. plato/samplers/__init__.py +0 -0
  117. plato/samplers/all_inclusive.py +41 -0
  118. plato/samplers/base.py +31 -0
  119. plato/samplers/dirichlet.py +81 -0
  120. plato/samplers/distribution_noniid.py +132 -0
  121. plato/samplers/iid.py +53 -0
  122. plato/samplers/label_quantity_noniid.py +119 -0
  123. plato/samplers/mixed.py +44 -0
  124. plato/samplers/mixed_label_quantity_noniid.py +128 -0
  125. plato/samplers/modality_iid.py +42 -0
  126. plato/samplers/modality_quantity_noniid.py +56 -0
  127. plato/samplers/orthogonal.py +99 -0
  128. plato/samplers/registry.py +66 -0
  129. plato/samplers/sample_quantity_noniid.py +123 -0
  130. plato/samplers/sampler_utils.py +190 -0
  131. plato/servers/__init__.py +0 -0
  132. plato/servers/base.py +1395 -0
  133. plato/servers/fedavg.py +281 -0
  134. plato/servers/fedavg_cs.py +335 -0
  135. plato/servers/fedavg_gan.py +74 -0
  136. plato/servers/fedavg_he.py +106 -0
  137. plato/servers/fedavg_personalized.py +57 -0
  138. plato/servers/mistnet.py +67 -0
  139. plato/servers/registry.py +52 -0
  140. plato/servers/split_learning.py +109 -0
  141. plato/trainers/__init__.py +0 -0
  142. plato/trainers/base.py +99 -0
  143. plato/trainers/basic.py +649 -0
  144. plato/trainers/diff_privacy.py +178 -0
  145. plato/trainers/gan.py +330 -0
  146. plato/trainers/huggingface.py +173 -0
  147. plato/trainers/loss_criterion.py +70 -0
  148. plato/trainers/lr_schedulers.py +252 -0
  149. plato/trainers/optimizers.py +53 -0
  150. plato/trainers/pascal_voc.py +80 -0
  151. plato/trainers/registry.py +44 -0
  152. plato/trainers/self_supervised_learning.py +302 -0
  153. plato/trainers/split_learning.py +305 -0
  154. plato/trainers/tracking.py +96 -0
  155. plato/trainers/yolov8.py +41 -0
  156. plato/utils/__init__.py +0 -0
  157. plato/utils/count_parameters.py +30 -0
  158. plato/utils/csv_processor.py +26 -0
  159. plato/utils/data_loaders.py +148 -0
  160. plato/utils/decorators.py +24 -0
  161. plato/utils/fonts.py +23 -0
  162. plato/utils/homo_enc.py +187 -0
  163. plato/utils/reinforcement_learning/__init__.py +0 -0
  164. plato/utils/reinforcement_learning/policies/__init__.py +0 -0
  165. plato/utils/reinforcement_learning/policies/base.py +161 -0
  166. plato/utils/reinforcement_learning/policies/ddpg.py +75 -0
  167. plato/utils/reinforcement_learning/policies/registry.py +32 -0
  168. plato/utils/reinforcement_learning/policies/sac.py +343 -0
  169. plato/utils/reinforcement_learning/policies/td3.py +485 -0
  170. plato/utils/reinforcement_learning/rl_agent.py +142 -0
  171. plato/utils/reinforcement_learning/rl_server.py +113 -0
  172. plato/utils/rl_env.py +154 -0
  173. plato/utils/s3.py +141 -0
  174. plato/utils/trainer_utils.py +21 -0
  175. plato/utils/unary_encoding.py +47 -0
  176. plato_learn-1.1.dist-info/METADATA +35 -0
  177. plato_learn-1.1.dist-info/RECORD +179 -0
  178. plato_learn-1.1.dist-info/WHEEL +4 -0
  179. plato_learn-1.1.dist-info/licenses/LICENSE +201 -0
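
The listing above shows how the package is organized: clients, servers, trainers, datasources, and samplers each ship with a registry module, and plato/config.py loads the run-time configuration. As a rough orientation, the sketch below shows how these pieces are typically wired together in a Plato entry script. It is a minimal sketch only; it assumes the fedavg.Server and simple.Client entry points and the configuration-file handling behave as in Plato's documented examples, none of which is stated in this diff.

# Minimal sketch of a Plato entry script (assumed usage, not taken from this diff).
# It presumes a configuration file is passed on the command line (e.g. -c config.yml),
# which plato/config.py parses into the global Config() object.
from plato.clients import simple
from plato.servers import fedavg


def main():
    # The server aggregates updates with federated averaging; the simple client
    # trains locally with the basic trainer shown in the diff below.
    client = simple.Client()
    server = fedavg.Server()
    server.run(client)


if __name__ == "__main__":
    main()
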
plato/trainers/basic.py
@@ -0,0 +1,649 @@
+"""
+The training and testing loops for PyTorch.
+"""
+
+import copy
+import logging
+import multiprocessing as mp
+import os
+import pickle
+import re
+import time
+
+import torch
+
+from plato.callbacks.handler import CallbackHandler
+from plato.callbacks.trainer import LogProgressCallback
+from plato.config import Config
+from plato.models import registry as models_registry
+from plato.trainers import (
+    base,
+    loss_criterion,
+    lr_schedulers,
+    optimizers,
+    tracking,
+)
+
+
+class Trainer(base.Trainer):
+    """A basic federated learning trainer, used by both the client and the server."""
+
+    def __init__(self, model=None, callbacks=None):
+        """Initializing the trainer with the provided model.
+
+        Arguments:
+        model: The model to train.
+        callbacks: The callbacks that this trainer uses.
+        """
+        super().__init__()
+
+        self.training_start_time = time.time()
+        self.model_state_dict = None
+        self.current_round = 0
+
+        # Starting from the default trainer callback class, add all supplied trainer callbacks
+        self.callbacks = [LogProgressCallback]
+        if callbacks is not None:
+            self.callbacks.extend(callbacks)
+        self.callback_handler = CallbackHandler(self.callbacks)
+
+        # The run history of performance metrics
+        self.run_history = tracking.RunHistory()
+        self._loss_tracker = tracking.LossTracker()
+
+        if model is None:
+            self.model = models_registry.get()
+        else:
+            self.model = model()
+
+        self.train_loader = None
+        self.sampler = None
+        self._loss_criterion = None
+        self.optimizer = None
+        self.lr_scheduler = None
+        self.current_epoch = 0
+
+    def zeros(self, shape):
+        """Returns a PyTorch zero tensor with the given shape."""
+        # This should only be called from a server
+        assert self.client_id == 0
+        return torch.zeros(shape)
+
+    def save_model(self, filename=None, location=None):
+        """Saving the model to a file."""
+        model_path = Config().params["model_path"] if location is None else location
+        model_name = Config().trainer.model_name
+
+        try:
+            if not os.path.exists(model_path):
+                os.makedirs(model_path)
+        except FileExistsError:
+            pass
+
+        if filename is not None:
+            model_path = f"{model_path}/{filename}"
+        else:
+            model_path = f"{model_path}/{model_name}.pth"
+
+        if self.model_state_dict is None:
+            torch.save(self.model.state_dict(), model_path)
+        else:
+            torch.save(self.model_state_dict, model_path)
+
+        with open(model_path + ".pkl", "wb") as history_file:
+            pickle.dump(self.run_history, history_file)
+
+        if self.client_id == 0:
+            logging.info("[Server #%d] Model saved to %s.", os.getpid(), model_path)
+        else:
+            logging.info("[Client #%d] Model saved to %s.", self.client_id, model_path)
+
+    def load_model(self, filename=None, location=None):
+        """Loading pre-trained model weights from a file."""
+        model_path = Config().params["model_path"] if location is None else location
+        model_name = Config().trainer.model_name
+
+        if filename is not None:
+            model_path = f"{model_path}/{filename}"
+        else:
+            model_path = f"{model_path}/{model_name}.pth"
+
+        if self.client_id == 0:
+            logging.info(
+                "[Server #%d] Loading a model from %s.", os.getpid(), model_path
+            )
+        else:
+            logging.info(
+                "[Client #%d] Loading a model from %s.",
+                self.client_id,
+                model_path,
+            )
+
+        pretrained = None
+        if torch.cuda.is_available():
+            pretrained = torch.load(model_path)
+        else:
+            pretrained = torch.load(model_path, map_location=torch.device("cpu"))
+        self.model.load_state_dict(pretrained, strict=True)
+
+        with open(model_path + ".pkl", "rb") as history_file:
+            self.run_history = pickle.load(history_file)
+
+    def simulate_sleep_time(self):
+        """Simulate client's speed by putting it to sleep."""
+        if not (
+            hasattr(Config().clients, "sleep_simulation")
+            and Config().clients.sleep_simulation
+        ):
+            sleep_seconds = Config().client_sleep_times[self.client_id - 1]
+
+            # Put this client to sleep
+            logging.info(
+                "[Client #%d] Going to sleep for %.2f seconds.",
+                self.client_id,
+                sleep_seconds,
+            )
+            time.sleep(sleep_seconds)
+            logging.info("[Client #%d] Woke up.", self.client_id)
+
+    def train_process(self, config, trainset, sampler, **kwargs):
+        """
+        The main training loop in a federated learning workload, run in a
+        separate process with a new CUDA context, so that CUDA memory can be
+        released after the training completes.
+
+        Arguments:
+        self: the trainer itself.
+        config: a dictionary of configuration parameters.
+        trainset: The training dataset.
+        sampler: The sampler that extracts a partition for this client.
+        kwargs (optional): Additional keyword arguments.
+        """
+        try:
+            self.train_model(config, trainset, sampler.get(), **kwargs)
+        except Exception as training_exception:
+            logging.info("Training on client #%d failed.", self.client_id)
+            raise training_exception
+
+        if "max_concurrency" in config:
+            self.model.cpu()
+            model_name = config["model_name"]
+            filename = f"{model_name}_{self.client_id}_{config['run_id']}.pth"
+            self.save_model(filename)
+
+    def perform_forward_and_backward_passes(self, config, examples, labels):
+        """Perform forward and backward passes in the training loop.
+
+        Arguments:
+        config: the configuration.
+        examples: data samples in the current batch.
+        labels: labels in the current batch.
+
+        Returns: loss values after the current batch has been processed.
+        """
+        self.optimizer.zero_grad()
+
+        outputs = self.model(examples)
+
+        loss = self._loss_criterion(outputs, labels)
+        self._loss_tracker.update(loss, labels.size(0))
+
+        if "create_graph" in config:
+            loss.backward(create_graph=config["create_graph"])
+        else:
+            loss.backward()
+
+        self.optimizer.step()
+
+        return loss
+
+    # pylint: disable=unused-argument
+    def train_model(self, config, trainset, sampler, **kwargs):
+        """The default training loop when a custom training loop is not supplied."""
+        batch_size = config["batch_size"]
+        self.sampler = sampler
+        tic = time.perf_counter()
+
+        self.run_history.reset()
+
+        self.train_run_start(config)
+        self.callback_handler.call_event("on_train_run_start", self, config)
+
+        self.train_loader = self.get_train_loader(batch_size, trainset, sampler)
+
+        # Initializing the loss criterion
+        self._loss_criterion = self.get_loss_criterion()
+
+        # Initializing the optimizer
+        self.optimizer = self.get_optimizer(self.model)
+        self.lr_scheduler = self.get_lr_scheduler(config, self.optimizer)
+        self.optimizer = self._adjust_lr(config, self.lr_scheduler, self.optimizer)
+
+        self.model.to(self.device)
+        self.model.train()
+
+        total_epochs = config["epochs"]
+
+        for self.current_epoch in range(1, total_epochs + 1):
+            self._loss_tracker.reset()
+            self.train_epoch_start(config)
+            self.callback_handler.call_event("on_train_epoch_start", self, config)
+
+            for batch_id, (examples, labels) in enumerate(self.train_loader):
+                self.train_step_start(config, batch=batch_id)
+                self.callback_handler.call_event(
+                    "on_train_step_start", self, config, batch=batch_id
+                )
+
+                examples, labels = (
+                    examples.to(self.device),
+                    labels.to(self.device),
+                )
+
+                loss = self.perform_forward_and_backward_passes(
+                    config, examples, labels
+                )
+
+                self.train_step_end(config, batch=batch_id, loss=loss)
+                self.callback_handler.call_event(
+                    "on_train_step_end", self, config, batch=batch_id, loss=loss
+                )
+
+            self.lr_scheduler_step()
+
+            if hasattr(self.optimizer, "params_state_update"):
+                self.optimizer.params_state_update()
+
+            # Simulate client's speed
+            if (
+                self.client_id != 0
+                and hasattr(Config().clients, "speed_simulation")
+                and Config().clients.speed_simulation
+            ):
+                self.simulate_sleep_time()
+
+            # Saving the model at the end of this epoch to a file so that
+            # it can later be retrieved to respond to server requests
+            # in asynchronous mode when the wall clock time is simulated
+            if (
+                hasattr(Config().server, "request_update")
+                and Config().server.request_update
+            ):
+                self.model.cpu()
+                training_time = time.perf_counter() - tic
+                filename = f"{self.client_id}_{self.current_epoch}_{training_time}.pth"
+                self.save_model(filename)
+                self.model.to(self.device)
+
+            self.run_history.update_metric("train_loss", self._loss_tracker.average)
+            self.train_epoch_end(config)
+            self.callback_handler.call_event("on_train_epoch_end", self, config)
+
+        self.train_run_end(config)
+        self.callback_handler.call_event("on_train_run_end", self, config)
+
+    def train(self, trainset, sampler, **kwargs) -> float:
+        """The main training loop in a federated learning workload.
+
+        Arguments:
+        trainset: The training dataset.
+        sampler: the sampler that extracts a partition for this client.
+        kwargs (optional): Additional keyword arguments.
+
+        Returns:
+        float: Elapsed time during training.
+        """
+        config = Config().trainer._asdict()
+        config["run_id"] = Config().params["run_id"]
+
+        # Set the start time of training in absolute time
+        self.training_start_time = time.time()
+
+        if "max_concurrency" in config:
+            tic = time.perf_counter()
+
+            if mp.get_start_method(allow_none=True) != "spawn":
+                mp.set_start_method("spawn", force=True)
+
+            train_proc = mp.Process(
+                target=self.train_process,
+                args=(config, trainset, sampler),
+                kwargs=kwargs,
+            )
+            train_proc.start()
+            train_proc.join()
+
+            model_name = Config().trainer.model_name
+            filename = f"{model_name}_{self.client_id}_{Config().params['run_id']}.pth"
+
+            try:
+                self.load_model(filename)
+            except OSError as error:  # the model file is not found, training failed
+                raise ValueError(
+                    f"Training on client {self.client_id} failed."
+                ) from error
+
+            toc = time.perf_counter()
+            self.pause_training()
+        else:
+            tic = time.perf_counter()
+            self.train_process(config, trainset, sampler, **kwargs)
+            toc = time.perf_counter()
+
+        training_time = toc - tic
+
+        return training_time
+
+    def test_process(self, config, testset, sampler=None, **kwargs):
+        """The testing loop, run in a separate process with a new CUDA context,
+        so that CUDA memory can be released after the testing completes.
+
+        Arguments:
+        config: a dictionary of configuration parameters.
+        testset: The test dataset.
+        sampler: The sampler that extracts a partition of the test dataset.
+        kwargs (optional): Additional keyword arguments.
+        """
+        self.model.to(self.device)
+        self.model.eval()
+
+        # Initialize the accuracy to be returned to -1, so that the client can disconnect
+        # from the server when testing fails
+        accuracy = -1
+
+        try:
+            if sampler is None:
+                accuracy = self.test_model(config, testset, **kwargs)
+            else:
+                accuracy = self.test_model(config, testset, sampler.get(), **kwargs)
+
+        except Exception as testing_exception:
+            logging.info("Testing on client #%d failed.", self.client_id)
+            raise testing_exception
+
+        self.model.cpu()
+
+        if "max_concurrency" in config:
+            model_name = config["model_name"]
+            filename = f"{model_name}_{self.client_id}_{config['run_id']}.acc"
+            self.save_accuracy(accuracy, filename)
+        else:
+            return accuracy
+
+    def test(self, testset, sampler=None, **kwargs) -> float:
+        """Testing the model using the provided test dataset.
+
+        Arguments:
+        testset: The test dataset.
+        sampler: The sampler that extracts a partition of the test dataset.
+        kwargs (optional): Additional keyword arguments.
+        """
+        config = Config().trainer._asdict()
+        config["run_id"] = Config().params["run_id"]
+
+        if hasattr(Config().trainer, "max_concurrency"):
+            if mp.get_start_method(allow_none=True) != "spawn":
+                mp.set_start_method("spawn", force=True)
+
+            proc = mp.Process(
+                target=self.test_process,
+                args=(config, testset, sampler),
+                kwargs=kwargs,
+            )
+            proc.start()
+            proc.join()
+
+            accuracy = -1
+            try:
+                model_name = Config().trainer.model_name
+                filename = (
+                    f"{model_name}_{self.client_id}_{Config().params['run_id']}.acc"
+                )
+                accuracy = self.load_accuracy(filename)
+            except OSError as error:  # the accuracy file is not found, testing failed
+                raise ValueError(
+                    f"Testing on client #{self.client_id} failed."
+                ) from error
+
+            self.pause_training()
+        else:
+            accuracy = self.test_process(config, testset, **kwargs)
+
+        return accuracy
+
+    def obtain_model_update(self, client_id, requested_time):
+        """
+        Obtain a saved model for a particular epoch that finishes just after the provided
+        wall clock time is reached.
+        """
+        # Constructing a list of epochs and training times
+        models_per_epoch = {}
+
+        for filename in os.listdir(Config().params["model_path"]):
+            split = re.match(
+                r"(?P<client_id>\d+)_(?P<epoch>\d+)_(?P<training_time>\d+.\d+).pth$",
+                filename,
+            )
+
+            if split is not None:
+                epoch = split.group("epoch")
+                training_time = split.group("training_time")
+                if client_id == int(split.group("client_id")):
+                    models_per_epoch[epoch] = {
+                        "training_time": float(training_time),
+                        "model_checkpoint": filename,
+                    }
+        # Locate the model at a specific wall clock time
+        for epoch in sorted(models_per_epoch, reverse=True):
+            model_training_time = models_per_epoch[epoch]["training_time"]
+            model_checkpoint = models_per_epoch[epoch]["model_checkpoint"]
+
+            if model_training_time < requested_time:
+                model_path = f"{Config().params['model_path']}/{model_checkpoint}"
+
+                pretrained = None
+                if torch.cuda.is_available():
+                    pretrained = torch.load(model_path)
+                else:
+                    pretrained = torch.load(
+                        model_path, map_location=torch.device("cpu")
+                    )
+
+                model = models_registry.get()
+                model.load_state_dict(pretrained, strict=True)
+
+                logging.info(
+                    "[Client #%s] Responding to the server with the model after "
+                    "epoch %s finished, at time %s.",
+                    client_id,
+                    epoch,
+                    model_training_time,
+                )
+
+                return model
+
+        raise ValueError(
+            f"[Client #{client_id}] Cannot find an epoch that matches the wall-clock time provided."
+        )
+
+    # pylint: disable=unused-argument
+    def get_train_loader(self, batch_size, trainset, sampler, **kwargs):
+        """
+        Creates an instance of the trainloader.
+
+        Arguments:
+        batch_size: the batch size.
+        trainset: the training dataset.
+        sampler: the sampler for the trainloader to use.
+        """
+        return torch.utils.data.DataLoader(
+            dataset=trainset,
+            shuffle=False,
+            batch_size=batch_size,
+            sampler=sampler,
+        )
+
+    # pylint: disable=unused-argument
+    def test_model(self, config, testset, sampler=None, **kwargs):
+        """
+        Evaluates the model with the provided test dataset and test sampler.
+
+        Arguments:
+        testset: the test dataset.
+        sampler: the test sampler. The default is None.
+        kwargs (optional): Additional keyword arguments.
+        """
+        batch_size = config["batch_size"]
+
+        test_loader = torch.utils.data.DataLoader(
+            testset, batch_size=batch_size, shuffle=False, sampler=sampler
+        )
+
+        correct = 0
+        total = 0
+
+        self.model.to(self.device)
+        with torch.no_grad():
+            for examples, labels in test_loader:
+                examples, labels = (
+                    examples.to(self.device),
+                    labels.to(self.device),
+                )
+
+                outputs = self.model(examples)
+
+                outputs = self.process_outputs(outputs)
+
+                _, predicted = torch.max(outputs.data, 1)
+                total += labels.size(0)
+                correct += (predicted == labels).sum().item()
+
+        return correct / total
+
+    def add_callbacks(self, callbacks):
+        """Adds a list of callbacks to the trainer callback handler."""
+        self.callback_handler.add_callbacks(callbacks)
+
+    def get_optimizer(self, model):
+        """Returns the optimizer."""
+        return optimizers.get(model)
+
+    def get_lr_scheduler(self, config, optimizer):
+        """Returns the learning rate scheduler, if needed."""
+        if "lr_scheduler" not in config:
+            return None
+
+        return lr_schedulers.get(optimizer, len(self.train_loader))
+
+    def lr_scheduler_step(self):
+        """
+        Performs a single scheduler step if ``self.lr_scheduler`` has been assigned.
+        """
+        if self.lr_scheduler is not None:
+            self.lr_scheduler.step()
+
+    def _adjust_lr(self, config, lr_scheduler, optimizer) -> torch.optim.Optimizer:
+        """Returns an optimizer with an initial learning rate that has been
+        adjusted according to the current round, so that learning rate
+        schedulers can be effective throughout the communication rounds."""
+
+        if "global_lr_scheduler" in config and config["global_lr_scheduler"]:
+            global_lr_scheduler = copy.deepcopy(lr_scheduler)
+
+            for __ in range(self.current_round - 1):
+                for __ in range(Config().trainer.epochs):
+                    global_lr_scheduler.step()
+
+            initial_lr = global_lr_scheduler.get_last_lr()
+            optimizer.param_groups[0]["lr"] = initial_lr[0]
+
+        return optimizer
+
+    def get_loss_criterion(self):
+        """Returns the loss criterion."""
+        return loss_criterion.get()
+
+    def backward(self, config, loss):
+        """Perform the backpropagation pass."""
+
+    def train_run_start(self, config):
+        """Method called at the start of a training run."""
+
+    def train_run_end(self, config):
+        """Method called at the end of a training run."""
+
+    def train_epoch_start(self, config):
+        """Method called at the beginning of a training epoch."""
+
+    def train_epoch_end(self, config):
+        """Method called at the end of a training epoch."""
+
+    def train_step_start(self, config, batch=None):
+        """Method called at the beginning of a training step."""
+
+    def train_step_end(self, config, batch=None, loss=None):
+        """
+        Method called at the end of a training step.
+
+        :param batch: the current batch of training data.
+        :param loss: the loss computed in the current batch.
+        """
+
+    @staticmethod
+    def process_outputs(outputs):
+        """
+        Method called after the model updates have been generated.
+        """
+        return outputs
+
+
+class TrainerWithTimmScheduler(Trainer):
+    """
+    Subclass of the :class:`Trainer` that works with `timm schedulers
+    <https://fastai.github.io/timmdocs/schedulers>` instead of standard PyTorch
+    learning rate schedulers.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.num_updates = None
+        self.past_epochs = None
+
+    def train_epoch_start(self, config):
+        """Method called at the beginning of a training epoch."""
+        super().train_epoch_start(config)
+
+        self.num_updates = self.current_epoch * len(self.train_loader)
+
+        if "global_lr_scheduler" in config and config["global_lr_scheduler"]:
+            self.num_updates += self.past_epochs * len(self.train_loader)
+
+    def lr_scheduler_step(self):
+        self.num_updates += 1
+
+        if self.lr_scheduler is not None:
+            self.lr_scheduler.step_update(num_updates=self.num_updates)
+
+    def train_epoch_end(self, config):
+        """Method called at the end of a training epoch."""
+        super().train_epoch_end(config)
+
+        if self.lr_scheduler is not None:
+            if "global_lr_scheduler" in config and config["global_lr_scheduler"]:
+                self.lr_scheduler.step(self.past_epochs + self.current_epoch + 1)
+            else:
+                self.lr_scheduler.step(self.current_epoch + 1)
+
+    def _adjust_lr(self, config, lr_scheduler, optimizer) -> torch.optim.Optimizer:
+        """Returns an optimizer with an initial learning rate that has been
+        adjusted according to the current round, so that learning rate
+        schedulers can be effective throughout the communication rounds."""
+
+        if "global_lr_scheduler" in config and config["global_lr_scheduler"]:
+            past_epochs = (self.current_round - 1) * Config().trainer.epochs
+            self.past_epochs = past_epochs
+
+            lr_scheduler.step(past_epochs)
+            lr_scheduler.step_update(past_epochs * len(self.train_loader))
+
+        return optimizer
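
The Trainer above exposes its behaviour through overridable hooks (train_run_start, train_step_end, get_optimizer, and so on) and through the callbacks argument of its constructor. The sketch below shows one way a custom trainer might override those hooks; it is illustrative only, VerboseTrainer is a hypothetical name, and the SGD settings are arbitrary. Only the hook signatures are taken from the diff above.

# A sketch of customizing the basic Trainer via its hooks (hypothetical subclass).
import logging

import torch

from plato.trainers import basic


class VerboseTrainer(basic.Trainer):
    """Logs the loss periodically and swaps in a plain SGD optimizer."""

    def train_step_end(self, config, batch=None, loss=None):
        # Called once per batch with the loss returned by
        # perform_forward_and_backward_passes().
        if batch is not None and batch % 50 == 0:
            logging.info("Batch %d: loss %.4f", batch, loss.item())

    def get_optimizer(self, model):
        # Bypass the optimizers registry and return a fixed optimizer instead.
        return torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

A subclass like this would then be handed to a client in place of the default trainer, for example simple.Client(trainer=VerboseTrainer), assuming the client constructors accept a trainer argument as they do elsewhere in Plato.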