plato-learn 1.1 (plato_learn-1.1-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- plato/__init__.py +1 -0
- plato/algorithms/__init__.py +0 -0
- plato/algorithms/base.py +45 -0
- plato/algorithms/fedavg.py +48 -0
- plato/algorithms/fedavg_gan.py +79 -0
- plato/algorithms/fedavg_personalized.py +48 -0
- plato/algorithms/mistnet.py +52 -0
- plato/algorithms/registry.py +39 -0
- plato/algorithms/split_learning.py +89 -0
- plato/callbacks/__init__.py +0 -0
- plato/callbacks/client.py +56 -0
- plato/callbacks/handler.py +78 -0
- plato/callbacks/server.py +139 -0
- plato/callbacks/trainer.py +124 -0
- plato/client.py +67 -0
- plato/clients/__init__.py +0 -0
- plato/clients/base.py +467 -0
- plato/clients/edge.py +103 -0
- plato/clients/fedavg_personalized.py +40 -0
- plato/clients/mistnet.py +49 -0
- plato/clients/registry.py +43 -0
- plato/clients/self_supervised_learning.py +51 -0
- plato/clients/simple.py +218 -0
- plato/clients/split_learning.py +150 -0
- plato/config.py +339 -0
- plato/datasources/__init__.py +0 -0
- plato/datasources/base.py +123 -0
- plato/datasources/celeba.py +150 -0
- plato/datasources/cifar10.py +87 -0
- plato/datasources/cifar100.py +61 -0
- plato/datasources/cinic10.py +62 -0
- plato/datasources/coco.py +119 -0
- plato/datasources/datalib/__init__.py +0 -0
- plato/datasources/datalib/audio_extraction_tools.py +137 -0
- plato/datasources/datalib/data_utils.py +124 -0
- plato/datasources/datalib/flickr30kE_utils.py +336 -0
- plato/datasources/datalib/frames_extraction_tools.py +254 -0
- plato/datasources/datalib/gym_utils/__init__.py +0 -0
- plato/datasources/datalib/gym_utils/gym_trim.py +189 -0
- plato/datasources/datalib/modality_data_anntation_tools.py +163 -0
- plato/datasources/datalib/modality_extraction_base.py +59 -0
- plato/datasources/datalib/parse_datasets.py +212 -0
- plato/datasources/datalib/refer_utils/__init__.py +0 -0
- plato/datasources/datalib/refer_utils/referitgame_utils.py +237 -0
- plato/datasources/datalib/tiny_data_tools.py +81 -0
- plato/datasources/datalib/video_transform.py +79 -0
- plato/datasources/emnist.py +64 -0
- plato/datasources/fashion_mnist.py +41 -0
- plato/datasources/feature.py +24 -0
- plato/datasources/feature_dataset.py +15 -0
- plato/datasources/femnist.py +141 -0
- plato/datasources/flickr30k_entities.py +362 -0
- plato/datasources/gym.py +431 -0
- plato/datasources/huggingface.py +165 -0
- plato/datasources/kinetics.py +568 -0
- plato/datasources/mnist.py +44 -0
- plato/datasources/multimodal_base.py +328 -0
- plato/datasources/pascal_voc.py +56 -0
- plato/datasources/purchase.py +94 -0
- plato/datasources/qoenflx.py +127 -0
- plato/datasources/referitgame.py +330 -0
- plato/datasources/registry.py +119 -0
- plato/datasources/self_supervised_learning.py +98 -0
- plato/datasources/stl10.py +103 -0
- plato/datasources/texas.py +94 -0
- plato/datasources/tiny_imagenet.py +64 -0
- plato/datasources/yolov8.py +85 -0
- plato/models/__init__.py +0 -0
- plato/models/cnn_encoder.py +103 -0
- plato/models/dcgan.py +116 -0
- plato/models/general_multilayer.py +254 -0
- plato/models/huggingface.py +27 -0
- plato/models/lenet5.py +113 -0
- plato/models/multilayer.py +90 -0
- plato/models/multimodal/__init__.py +0 -0
- plato/models/multimodal/base_net.py +91 -0
- plato/models/multimodal/blending.py +142 -0
- plato/models/multimodal/fc_net.py +77 -0
- plato/models/multimodal/fusion_net.py +78 -0
- plato/models/multimodal/multimodal_module.py +152 -0
- plato/models/registry.py +99 -0
- plato/models/resnet.py +190 -0
- plato/models/torch_hub.py +19 -0
- plato/models/vgg.py +113 -0
- plato/models/vit.py +166 -0
- plato/models/yolov8.py +22 -0
- plato/processors/__init__.py +0 -0
- plato/processors/base.py +35 -0
- plato/processors/compress.py +46 -0
- plato/processors/decompress.py +48 -0
- plato/processors/feature.py +51 -0
- plato/processors/feature_additive_noise.py +48 -0
- plato/processors/feature_dequantize.py +34 -0
- plato/processors/feature_gaussian.py +17 -0
- plato/processors/feature_laplace.py +15 -0
- plato/processors/feature_quantize.py +34 -0
- plato/processors/feature_randomized_response.py +50 -0
- plato/processors/feature_unbatch.py +39 -0
- plato/processors/inbound_feature_tensors.py +39 -0
- plato/processors/model.py +55 -0
- plato/processors/model_compress.py +34 -0
- plato/processors/model_decompress.py +37 -0
- plato/processors/model_decrypt.py +41 -0
- plato/processors/model_deepcopy.py +21 -0
- plato/processors/model_dequantize.py +18 -0
- plato/processors/model_dequantize_qsgd.py +61 -0
- plato/processors/model_encrypt.py +43 -0
- plato/processors/model_quantize.py +18 -0
- plato/processors/model_quantize_qsgd.py +82 -0
- plato/processors/model_randomized_response.py +34 -0
- plato/processors/outbound_feature_ndarrays.py +38 -0
- plato/processors/pipeline.py +26 -0
- plato/processors/registry.py +124 -0
- plato/processors/structured_pruning.py +57 -0
- plato/processors/unstructured_pruning.py +73 -0
- plato/samplers/__init__.py +0 -0
- plato/samplers/all_inclusive.py +41 -0
- plato/samplers/base.py +31 -0
- plato/samplers/dirichlet.py +81 -0
- plato/samplers/distribution_noniid.py +132 -0
- plato/samplers/iid.py +53 -0
- plato/samplers/label_quantity_noniid.py +119 -0
- plato/samplers/mixed.py +44 -0
- plato/samplers/mixed_label_quantity_noniid.py +128 -0
- plato/samplers/modality_iid.py +42 -0
- plato/samplers/modality_quantity_noniid.py +56 -0
- plato/samplers/orthogonal.py +99 -0
- plato/samplers/registry.py +66 -0
- plato/samplers/sample_quantity_noniid.py +123 -0
- plato/samplers/sampler_utils.py +190 -0
- plato/servers/__init__.py +0 -0
- plato/servers/base.py +1395 -0
- plato/servers/fedavg.py +281 -0
- plato/servers/fedavg_cs.py +335 -0
- plato/servers/fedavg_gan.py +74 -0
- plato/servers/fedavg_he.py +106 -0
- plato/servers/fedavg_personalized.py +57 -0
- plato/servers/mistnet.py +67 -0
- plato/servers/registry.py +52 -0
- plato/servers/split_learning.py +109 -0
- plato/trainers/__init__.py +0 -0
- plato/trainers/base.py +99 -0
- plato/trainers/basic.py +649 -0
- plato/trainers/diff_privacy.py +178 -0
- plato/trainers/gan.py +330 -0
- plato/trainers/huggingface.py +173 -0
- plato/trainers/loss_criterion.py +70 -0
- plato/trainers/lr_schedulers.py +252 -0
- plato/trainers/optimizers.py +53 -0
- plato/trainers/pascal_voc.py +80 -0
- plato/trainers/registry.py +44 -0
- plato/trainers/self_supervised_learning.py +302 -0
- plato/trainers/split_learning.py +305 -0
- plato/trainers/tracking.py +96 -0
- plato/trainers/yolov8.py +41 -0
- plato/utils/__init__.py +0 -0
- plato/utils/count_parameters.py +30 -0
- plato/utils/csv_processor.py +26 -0
- plato/utils/data_loaders.py +148 -0
- plato/utils/decorators.py +24 -0
- plato/utils/fonts.py +23 -0
- plato/utils/homo_enc.py +187 -0
- plato/utils/reinforcement_learning/__init__.py +0 -0
- plato/utils/reinforcement_learning/policies/__init__.py +0 -0
- plato/utils/reinforcement_learning/policies/base.py +161 -0
- plato/utils/reinforcement_learning/policies/ddpg.py +75 -0
- plato/utils/reinforcement_learning/policies/registry.py +32 -0
- plato/utils/reinforcement_learning/policies/sac.py +343 -0
- plato/utils/reinforcement_learning/policies/td3.py +485 -0
- plato/utils/reinforcement_learning/rl_agent.py +142 -0
- plato/utils/reinforcement_learning/rl_server.py +113 -0
- plato/utils/rl_env.py +154 -0
- plato/utils/s3.py +141 -0
- plato/utils/trainer_utils.py +21 -0
- plato/utils/unary_encoding.py +47 -0
- plato_learn-1.1.dist-info/METADATA +35 -0
- plato_learn-1.1.dist-info/RECORD +179 -0
- plato_learn-1.1.dist-info/WHEEL +4 -0
- plato_learn-1.1.dist-info/licenses/LICENSE +201 -0
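The modules above are organized around per-package registries (plato/*/registry.py) that resolve concrete algorithms, clients, datasources, models, samplers, servers, and trainers from a loaded configuration. As a rough orientation only, a conventional entry point might be wired as in the sketch below; it assumes a Plato configuration file is supplied on the command line and that fedavg.Server.run() accepts an optional client, neither of which is spelled out in this diff.

```python
# Hypothetical entry point, not part of the wheel: an illustration of how the
# registry-driven modules listed above are typically combined.
from plato.clients import simple
from plato.servers import fedavg


def main():
    client = simple.Client()   # plato/clients/simple.py
    server = fedavg.Server()   # plato/servers/fedavg.py
    server.run(client)         # assumed entry point; the server drives the rounds


if __name__ == "__main__":
    main()
```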
plato/trainers/basic.py
ADDED
@@ -0,0 +1,649 @@
"""
The training and testing loops for PyTorch.
"""

import copy
import logging
import multiprocessing as mp
import os
import pickle
import re
import time

import torch

from plato.callbacks.handler import CallbackHandler
from plato.callbacks.trainer import LogProgressCallback
from plato.config import Config
from plato.models import registry as models_registry
from plato.trainers import (
    base,
    loss_criterion,
    lr_schedulers,
    optimizers,
    tracking,
)


class Trainer(base.Trainer):
    """A basic federated learning trainer, used by both the client and the server."""

    def __init__(self, model=None, callbacks=None):
        """Initializing the trainer with the provided model.

        Arguments:
        model: The model to train.
        callbacks: The callbacks that this trainer uses.
        """
        super().__init__()

        self.training_start_time = time.time()
        self.model_state_dict = None
        self.current_round = 0

        # Starting from the default trainer callback class, add all supplied trainer callbacks
        self.callbacks = [LogProgressCallback]
        if callbacks is not None:
            self.callbacks.extend(callbacks)
        self.callback_handler = CallbackHandler(self.callbacks)

        # The run history of performance metrics
        self.run_history = tracking.RunHistory()
        self._loss_tracker = tracking.LossTracker()

        if model is None:
            self.model = models_registry.get()
        else:
            self.model = model()

        self.train_loader = None
        self.sampler = None
        self._loss_criterion = None
        self.optimizer = None
        self.lr_scheduler = None
        self.current_epoch = 0

    def zeros(self, shape):
        """Returns a PyTorch zero tensor with the given shape."""
        # This should only be called from a server
        assert self.client_id == 0
        return torch.zeros(shape)

    def save_model(self, filename=None, location=None):
        """Saving the model to a file."""
        model_path = Config().params["model_path"] if location is None else location
        model_name = Config().trainer.model_name

        try:
            if not os.path.exists(model_path):
                os.makedirs(model_path)
        except FileExistsError:
            pass

        if filename is not None:
            model_path = f"{model_path}/{filename}"
        else:
            model_path = f"{model_path}/{model_name}.pth"

        if self.model_state_dict is None:
            torch.save(self.model.state_dict(), model_path)
        else:
            torch.save(self.model_state_dict, model_path)

        with open(model_path + ".pkl", "wb") as history_file:
            pickle.dump(self.run_history, history_file)

        if self.client_id == 0:
            logging.info("[Server #%d] Model saved to %s.", os.getpid(), model_path)
        else:
            logging.info("[Client #%d] Model saved to %s.", self.client_id, model_path)

    def load_model(self, filename=None, location=None):
        """Loading pre-trained model weights from a file."""
        model_path = Config().params["model_path"] if location is None else location
        model_name = Config().trainer.model_name

        if filename is not None:
            model_path = f"{model_path}/{filename}"
        else:
            model_path = f"{model_path}/{model_name}.pth"

        if self.client_id == 0:
            logging.info(
                "[Server #%d] Loading a model from %s.", os.getpid(), model_path
            )
        else:
            logging.info(
                "[Client #%d] Loading a model from %s.",
                self.client_id,
                model_path,
            )

        pretrained = None
        if torch.cuda.is_available():
            pretrained = torch.load(model_path)
        else:
            pretrained = torch.load(model_path, map_location=torch.device("cpu"))
        self.model.load_state_dict(pretrained, strict=True)

        with open(model_path + ".pkl", "rb") as history_file:
            self.run_history = pickle.load(history_file)

    def simulate_sleep_time(self):
        """Simulate client's speed by putting it to sleep."""
        if not (
            hasattr(Config().clients, "sleep_simulation")
            and Config().clients.sleep_simulation
        ):
            sleep_seconds = Config().client_sleep_times[self.client_id - 1]

            # Put this client to sleep
            logging.info(
                "[Client #%d] Going to sleep for %.2f seconds.",
                self.client_id,
                sleep_seconds,
            )
            time.sleep(sleep_seconds)
            logging.info("[Client #%d] Woke up.", self.client_id)

    def train_process(self, config, trainset, sampler, **kwargs):
        """
        The main training loop in a federated learning workload, run in a
        separate process with a new CUDA context, so that CUDA memory can be
        released after the training completes.

        Arguments:
        self: the trainer itself.
        config: a dictionary of configuration parameters.
        trainset: The training dataset.
        sampler: The sampler that extracts a partition for this client.
        kwargs (optional): Additional keyword arguments.
        """
        try:
            self.train_model(config, trainset, sampler.get(), **kwargs)
        except Exception as training_exception:
            logging.info("Training on client #%d failed.", self.client_id)
            raise training_exception

        if "max_concurrency" in config:
            self.model.cpu()
            model_name = config["model_name"]
            filename = f"{model_name}_{self.client_id}_{config['run_id']}.pth"
            self.save_model(filename)

    def perform_forward_and_backward_passes(self, config, examples, labels):
        """Perform forward and backward passes in the training loop.

        Arguments:
        config: the configuration.
        examples: data samples in the current batch.
        labels: labels in the current batch.

        Returns: loss values after the current batch has been processed.
        """
        self.optimizer.zero_grad()

        outputs = self.model(examples)

        loss = self._loss_criterion(outputs, labels)
        self._loss_tracker.update(loss, labels.size(0))

        if "create_graph" in config:
            loss.backward(create_graph=config["create_graph"])
        else:
            loss.backward()

        self.optimizer.step()

        return loss

    # pylint: disable=unused-argument
    def train_model(self, config, trainset, sampler, **kwargs):
        """The default training loop when a custom training loop is not supplied."""
        batch_size = config["batch_size"]
        self.sampler = sampler
        tic = time.perf_counter()

        self.run_history.reset()

        self.train_run_start(config)
        self.callback_handler.call_event("on_train_run_start", self, config)

        self.train_loader = self.get_train_loader(batch_size, trainset, sampler)

        # Initializing the loss criterion
        self._loss_criterion = self.get_loss_criterion()

        # Initializing the optimizer
        self.optimizer = self.get_optimizer(self.model)
        self.lr_scheduler = self.get_lr_scheduler(config, self.optimizer)
        self.optimizer = self._adjust_lr(config, self.lr_scheduler, self.optimizer)

        self.model.to(self.device)
        self.model.train()

        total_epochs = config["epochs"]

        for self.current_epoch in range(1, total_epochs + 1):
            self._loss_tracker.reset()
            self.train_epoch_start(config)
            self.callback_handler.call_event("on_train_epoch_start", self, config)

            for batch_id, (examples, labels) in enumerate(self.train_loader):
                self.train_step_start(config, batch=batch_id)
                self.callback_handler.call_event(
                    "on_train_step_start", self, config, batch=batch_id
                )

                examples, labels = (
                    examples.to(self.device),
                    labels.to(self.device),
                )

                loss = self.perform_forward_and_backward_passes(
                    config, examples, labels
                )

                self.train_step_end(config, batch=batch_id, loss=loss)
                self.callback_handler.call_event(
                    "on_train_step_end", self, config, batch=batch_id, loss=loss
                )

            self.lr_scheduler_step()

            if hasattr(self.optimizer, "params_state_update"):
                self.optimizer.params_state_update()

            # Simulate client's speed
            if (
                self.client_id != 0
                and hasattr(Config().clients, "speed_simulation")
                and Config().clients.speed_simulation
            ):
                self.simulate_sleep_time()

            # Saving the model at the end of this epoch to a file so that
            # it can later be retrieved to respond to server requests
            # in asynchronous mode when the wall clock time is simulated
            if (
                hasattr(Config().server, "request_update")
                and Config().server.request_update
            ):
                self.model.cpu()
                training_time = time.perf_counter() - tic
                filename = f"{self.client_id}_{self.current_epoch}_{training_time}.pth"
                self.save_model(filename)
                self.model.to(self.device)

            self.run_history.update_metric("train_loss", self._loss_tracker.average)
            self.train_epoch_end(config)
            self.callback_handler.call_event("on_train_epoch_end", self, config)

        self.train_run_end(config)
        self.callback_handler.call_event("on_train_run_end", self, config)

    def train(self, trainset, sampler, **kwargs) -> float:
        """The main training loop in a federated learning workload.

        Arguments:
        trainset: The training dataset.
        sampler: the sampler that extracts a partition for this client.
        kwargs (optional): Additional keyword arguments.

        Returns:
        float: Elapsed time during training.
        """
        config = Config().trainer._asdict()
        config["run_id"] = Config().params["run_id"]

        # Set the start time of training in absolute time
        self.training_start_time = time.time()

        if "max_concurrency" in config:
            tic = time.perf_counter()

            if mp.get_start_method(allow_none=True) != "spawn":
                mp.set_start_method("spawn", force=True)

            train_proc = mp.Process(
                target=self.train_process,
                args=(config, trainset, sampler),
                kwargs=kwargs,
            )
            train_proc.start()
            train_proc.join()

            model_name = Config().trainer.model_name
            filename = f"{model_name}_{self.client_id}_{Config().params['run_id']}.pth"

            try:
                self.load_model(filename)
            except OSError as error:  # the model file is not found, training failed
                raise ValueError(
                    f"Training on client {self.client_id} failed."
                ) from error

            toc = time.perf_counter()
            self.pause_training()
        else:
            tic = time.perf_counter()
            self.train_process(config, trainset, sampler, **kwargs)
            toc = time.perf_counter()

        training_time = toc - tic

        return training_time

    def test_process(self, config, testset, sampler=None, **kwargs):
        """The testing loop, run in a separate process with a new CUDA context,
        so that CUDA memory can be released after the training completes.

        Arguments:
        config: a dictionary of configuration parameters.
        testset: The test dataset.
        sampler: The sampler that extracts a partition of the test dataset.
        kwargs (optional): Additional keyword arguments.
        """
        self.model.to(self.device)
        self.model.eval()

        # Initialize accuracy to be returned to -1, so that the client can disconnect
        # from the server when testing fails
        accuracy = -1

        try:
            if sampler is None:
                accuracy = self.test_model(config, testset, **kwargs)
            else:
                accuracy = self.test_model(config, testset, sampler.get(), **kwargs)

        except Exception as testing_exception:
            logging.info("Testing on client #%d failed.", self.client_id)
            raise testing_exception

        self.model.cpu()

        if "max_concurrency" in config:
            model_name = config["model_name"]
            filename = f"{model_name}_{self.client_id}_{config['run_id']}.acc"
            self.save_accuracy(accuracy, filename)
        else:
            return accuracy

    def test(self, testset, sampler=None, **kwargs) -> float:
        """Testing the model using the provided test dataset.

        Arguments:
        testset: The test dataset.
        sampler: The sampler that extracts a partition of the test dataset.
        kwargs (optional): Additional keyword arguments.
        """
        config = Config().trainer._asdict()
        config["run_id"] = Config().params["run_id"]

        if hasattr(Config().trainer, "max_concurrency"):
            if mp.get_start_method(allow_none=True) != "spawn":
                mp.set_start_method("spawn", force=True)

            proc = mp.Process(
                target=self.test_process,
                args=(config, testset, sampler),
                kwargs=kwargs,
            )
            proc.start()
            proc.join()

            accuracy = -1
            try:
                model_name = Config().trainer.model_name
                filename = (
                    f"{model_name}_{self.client_id}_{Config().params['run_id']}.acc"
                )
                accuracy = self.load_accuracy(filename)
            except OSError as error:  # the model file is not found, training failed
                raise ValueError(
                    f"Testing on client #{self.client_id} failed."
                ) from error

            self.pause_training()
        else:
            accuracy = self.test_process(config, testset, **kwargs)

        return accuracy

    def obtain_model_update(self, client_id, requested_time):
        """
        Obtain a saved model for a particular epoch that finishes just after the provided
        wall clock time is reached.
        """
        # Constructing a list of epochs and training times
        models_per_epoch = {}

        for filename in os.listdir(Config().params["model_path"]):
            split = re.match(
                r"(?P<client_id>\d+)_(?P<epoch>\d+)_(?P<training_time>\d+.\d+).pth$",
                filename,
            )

            if split is not None:
                epoch = split.group("epoch")
                training_time = split.group("training_time")
                if client_id == int(split.group("client_id")):
                    models_per_epoch[epoch] = {
                        "training_time": float(training_time),
                        "model_checkpoint": filename,
                    }
        # Locate the model at a specific wall clock time
        for epoch in sorted(models_per_epoch, reverse=True):
            model_training_time = models_per_epoch[epoch]["training_time"]
            model_checkpoint = models_per_epoch[epoch]["model_checkpoint"]

            if model_training_time < requested_time:
                model_path = f"{Config().params['model_path']}/{model_checkpoint}"

                pretrained = None
                if torch.cuda.is_available():
                    pretrained = torch.load(model_path)
                else:
                    pretrained = torch.load(
                        model_path, map_location=torch.device("cpu")
                    )

                model = models_registry.get()
                model.load_state_dict(pretrained, strict=True)

                logging.info(
                    "[Client #%s] Responding to the server with the model after "
                    "epoch %s finished, at time %s.",
                    client_id,
                    epoch,
                    model_training_time,
                )

                return model

        raise ValueError(
            f"[Client #{client_id}] Cannot find an epoch that matches the wall-clock time provided."
        )

    # pylint: disable=unused-argument
    def get_train_loader(self, batch_size, trainset, sampler, **kwargs):
        """
        Creates an instance of the trainloader.

        Arguments:
        batch_size: the batch size.
        trainset: the training dataset.
        sampler: the sampler for the trainloader to use.
        """
        return torch.utils.data.DataLoader(
            dataset=trainset,
            shuffle=False,
            batch_size=batch_size,
            sampler=sampler,
        )

    # pylint: disable=unused-argument
    def test_model(self, config, testset, sampler=None, **kwargs):
        """
        Evaluates the model with the provided test dataset and test sampler.

        Arguments:
        testset: the test dataset.
        sampler: the test sampler. The default is None.
        kwargs (optional): Additional keyword arguments.
        """
        batch_size = config["batch_size"]

        test_loader = torch.utils.data.DataLoader(
            testset, batch_size=batch_size, shuffle=False, sampler=sampler
        )

        correct = 0
        total = 0

        self.model.to(self.device)
        with torch.no_grad():
            for examples, labels in test_loader:
                examples, labels = (
                    examples.to(self.device),
                    labels.to(self.device),
                )

                outputs = self.model(examples)

                outputs = self.process_outputs(outputs)

                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        return correct / total

    def add_callbacks(self, callbacks):
        """Adds a list of callbacks to the trainer callback handler."""
        self.callback_handler.add_callbacks(callbacks)

    def get_optimizer(self, model):
        """Returns the optimizer."""
        return optimizers.get(model)

    def get_lr_scheduler(self, config, optimizer):
        """Returns the learning rate scheduler, if needed."""
        if "lr_scheduler" not in config:
            return None

        return lr_schedulers.get(optimizer, len(self.train_loader))

    def lr_scheduler_step(self):
        """
        Performs a single scheduler step if ``self.lr_scheduler`` has been assigned.
        """
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

    def _adjust_lr(self, config, lr_scheduler, optimizer) -> torch.optim.Optimizer:
        """Returns an optimizer with an initial learning rate that has been
        adjusted according to the current round, so that learning rate
        schedulers can be effective throughout the communication rounds."""

        if "global_lr_scheduler" in config and config["global_lr_scheduler"]:
            global_lr_scheduler = copy.deepcopy(lr_scheduler)

            for __ in range(self.current_round - 1):
                for __ in range(Config().trainer.epochs):
                    global_lr_scheduler.step()

            initial_lr = global_lr_scheduler.get_last_lr()
            optimizer.param_groups[0]["lr"] = initial_lr[0]

        return optimizer

    def get_loss_criterion(self):
        """Returns the loss criterion."""
        return loss_criterion.get()

    def backward(self, config, loss):
        """Perform the backpropagation pass."""

    def train_run_start(self, config):
        """Method called at the start of training run."""

    def train_run_end(self, config):
        """Method called at the end of a training run."""

    def train_epoch_start(self, config):
        """Method called at the beginning of a training epoch."""

    def train_epoch_end(self, config):
        """Method called at the end of a training epoch."""

    def train_step_start(self, config, batch=None):
        """Method called at the beginning of a training step."""

    def train_step_end(self, config, batch=None, loss=None):
        """
        Method called at the end of a training step.

        :param batch: the current batch of training data.
        :param loss: the loss computed in the current batch.
        """

    @staticmethod
    def process_outputs(outputs):
        """
        Method called after the model updates have been generated.
        """
        return outputs


class TrainerWithTimmScheduler(Trainer):
    """
    Subclass of the :class:`Trainer` that works with `timm schedulers
    <https://fastai.github.io/timmdocs/schedulers>` instead of standard PyTorch
    learning rate schedulers.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.num_updates = None
        self.past_epochs = None

    def train_epoch_start(self, config):
        """Method called at the beginning of a training epoch."""
        super().train_epoch_start(config)

        self.num_updates = self.current_epoch * len(self.train_loader)

        if "global_lr_scheduler" in config and config["global_lr_scheduler"]:
            self.num_updates += self.past_epochs * len(self.train_loader)

    def lr_scheduler_step(self):
        self.num_updates += 1

        if self.lr_scheduler is not None:
            self.lr_scheduler.step_update(num_updates=self.num_updates)

    def train_epoch_end(self, config):
        """Method called at the end of a training epoch."""
        super().train_epoch_end(config)

        if self.lr_scheduler is not None:
            if "global_lr_scheduler" in config and config["global_lr_scheduler"]:
                self.lr_scheduler.step(self.past_epochs + self.current_epoch + 1)
            else:
                self.lr_scheduler.step(self.current_epoch + 1)

    def _adjust_lr(self, config, lr_scheduler, optimizer) -> torch.optim.Optimizer:
        """Returns an optimizer with an initial learning rate that has been
        adjusted according to the current round, so that learning rate
        schedulers can be effective throughout the communication rounds."""

        if "global_lr_scheduler" in config and config["global_lr_scheduler"]:
            past_epochs = (self.current_round - 1) * Config().trainer.epochs
            self.past_epochs = past_epochs

            lr_scheduler.step(past_epochs)
            lr_scheduler.step_update(past_epochs * len(self.train_loader))

        return optimizer