plato-learn 1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- plato/__init__.py +1 -0
- plato/algorithms/__init__.py +0 -0
- plato/algorithms/base.py +45 -0
- plato/algorithms/fedavg.py +48 -0
- plato/algorithms/fedavg_gan.py +79 -0
- plato/algorithms/fedavg_personalized.py +48 -0
- plato/algorithms/mistnet.py +52 -0
- plato/algorithms/registry.py +39 -0
- plato/algorithms/split_learning.py +89 -0
- plato/callbacks/__init__.py +0 -0
- plato/callbacks/client.py +56 -0
- plato/callbacks/handler.py +78 -0
- plato/callbacks/server.py +139 -0
- plato/callbacks/trainer.py +124 -0
- plato/client.py +67 -0
- plato/clients/__init__.py +0 -0
- plato/clients/base.py +467 -0
- plato/clients/edge.py +103 -0
- plato/clients/fedavg_personalized.py +40 -0
- plato/clients/mistnet.py +49 -0
- plato/clients/registry.py +43 -0
- plato/clients/self_supervised_learning.py +51 -0
- plato/clients/simple.py +218 -0
- plato/clients/split_learning.py +150 -0
- plato/config.py +339 -0
- plato/datasources/__init__.py +0 -0
- plato/datasources/base.py +123 -0
- plato/datasources/celeba.py +150 -0
- plato/datasources/cifar10.py +87 -0
- plato/datasources/cifar100.py +61 -0
- plato/datasources/cinic10.py +62 -0
- plato/datasources/coco.py +119 -0
- plato/datasources/datalib/__init__.py +0 -0
- plato/datasources/datalib/audio_extraction_tools.py +137 -0
- plato/datasources/datalib/data_utils.py +124 -0
- plato/datasources/datalib/flickr30kE_utils.py +336 -0
- plato/datasources/datalib/frames_extraction_tools.py +254 -0
- plato/datasources/datalib/gym_utils/__init__.py +0 -0
- plato/datasources/datalib/gym_utils/gym_trim.py +189 -0
- plato/datasources/datalib/modality_data_anntation_tools.py +163 -0
- plato/datasources/datalib/modality_extraction_base.py +59 -0
- plato/datasources/datalib/parse_datasets.py +212 -0
- plato/datasources/datalib/refer_utils/__init__.py +0 -0
- plato/datasources/datalib/refer_utils/referitgame_utils.py +237 -0
- plato/datasources/datalib/tiny_data_tools.py +81 -0
- plato/datasources/datalib/video_transform.py +79 -0
- plato/datasources/emnist.py +64 -0
- plato/datasources/fashion_mnist.py +41 -0
- plato/datasources/feature.py +24 -0
- plato/datasources/feature_dataset.py +15 -0
- plato/datasources/femnist.py +141 -0
- plato/datasources/flickr30k_entities.py +362 -0
- plato/datasources/gym.py +431 -0
- plato/datasources/huggingface.py +165 -0
- plato/datasources/kinetics.py +568 -0
- plato/datasources/mnist.py +44 -0
- plato/datasources/multimodal_base.py +328 -0
- plato/datasources/pascal_voc.py +56 -0
- plato/datasources/purchase.py +94 -0
- plato/datasources/qoenflx.py +127 -0
- plato/datasources/referitgame.py +330 -0
- plato/datasources/registry.py +119 -0
- plato/datasources/self_supervised_learning.py +98 -0
- plato/datasources/stl10.py +103 -0
- plato/datasources/texas.py +94 -0
- plato/datasources/tiny_imagenet.py +64 -0
- plato/datasources/yolov8.py +85 -0
- plato/models/__init__.py +0 -0
- plato/models/cnn_encoder.py +103 -0
- plato/models/dcgan.py +116 -0
- plato/models/general_multilayer.py +254 -0
- plato/models/huggingface.py +27 -0
- plato/models/lenet5.py +113 -0
- plato/models/multilayer.py +90 -0
- plato/models/multimodal/__init__.py +0 -0
- plato/models/multimodal/base_net.py +91 -0
- plato/models/multimodal/blending.py +142 -0
- plato/models/multimodal/fc_net.py +77 -0
- plato/models/multimodal/fusion_net.py +78 -0
- plato/models/multimodal/multimodal_module.py +152 -0
- plato/models/registry.py +99 -0
- plato/models/resnet.py +190 -0
- plato/models/torch_hub.py +19 -0
- plato/models/vgg.py +113 -0
- plato/models/vit.py +166 -0
- plato/models/yolov8.py +22 -0
- plato/processors/__init__.py +0 -0
- plato/processors/base.py +35 -0
- plato/processors/compress.py +46 -0
- plato/processors/decompress.py +48 -0
- plato/processors/feature.py +51 -0
- plato/processors/feature_additive_noise.py +48 -0
- plato/processors/feature_dequantize.py +34 -0
- plato/processors/feature_gaussian.py +17 -0
- plato/processors/feature_laplace.py +15 -0
- plato/processors/feature_quantize.py +34 -0
- plato/processors/feature_randomized_response.py +50 -0
- plato/processors/feature_unbatch.py +39 -0
- plato/processors/inbound_feature_tensors.py +39 -0
- plato/processors/model.py +55 -0
- plato/processors/model_compress.py +34 -0
- plato/processors/model_decompress.py +37 -0
- plato/processors/model_decrypt.py +41 -0
- plato/processors/model_deepcopy.py +21 -0
- plato/processors/model_dequantize.py +18 -0
- plato/processors/model_dequantize_qsgd.py +61 -0
- plato/processors/model_encrypt.py +43 -0
- plato/processors/model_quantize.py +18 -0
- plato/processors/model_quantize_qsgd.py +82 -0
- plato/processors/model_randomized_response.py +34 -0
- plato/processors/outbound_feature_ndarrays.py +38 -0
- plato/processors/pipeline.py +26 -0
- plato/processors/registry.py +124 -0
- plato/processors/structured_pruning.py +57 -0
- plato/processors/unstructured_pruning.py +73 -0
- plato/samplers/__init__.py +0 -0
- plato/samplers/all_inclusive.py +41 -0
- plato/samplers/base.py +31 -0
- plato/samplers/dirichlet.py +81 -0
- plato/samplers/distribution_noniid.py +132 -0
- plato/samplers/iid.py +53 -0
- plato/samplers/label_quantity_noniid.py +119 -0
- plato/samplers/mixed.py +44 -0
- plato/samplers/mixed_label_quantity_noniid.py +128 -0
- plato/samplers/modality_iid.py +42 -0
- plato/samplers/modality_quantity_noniid.py +56 -0
- plato/samplers/orthogonal.py +99 -0
- plato/samplers/registry.py +66 -0
- plato/samplers/sample_quantity_noniid.py +123 -0
- plato/samplers/sampler_utils.py +190 -0
- plato/servers/__init__.py +0 -0
- plato/servers/base.py +1395 -0
- plato/servers/fedavg.py +281 -0
- plato/servers/fedavg_cs.py +335 -0
- plato/servers/fedavg_gan.py +74 -0
- plato/servers/fedavg_he.py +106 -0
- plato/servers/fedavg_personalized.py +57 -0
- plato/servers/mistnet.py +67 -0
- plato/servers/registry.py +52 -0
- plato/servers/split_learning.py +109 -0
- plato/trainers/__init__.py +0 -0
- plato/trainers/base.py +99 -0
- plato/trainers/basic.py +649 -0
- plato/trainers/diff_privacy.py +178 -0
- plato/trainers/gan.py +330 -0
- plato/trainers/huggingface.py +173 -0
- plato/trainers/loss_criterion.py +70 -0
- plato/trainers/lr_schedulers.py +252 -0
- plato/trainers/optimizers.py +53 -0
- plato/trainers/pascal_voc.py +80 -0
- plato/trainers/registry.py +44 -0
- plato/trainers/self_supervised_learning.py +302 -0
- plato/trainers/split_learning.py +305 -0
- plato/trainers/tracking.py +96 -0
- plato/trainers/yolov8.py +41 -0
- plato/utils/__init__.py +0 -0
- plato/utils/count_parameters.py +30 -0
- plato/utils/csv_processor.py +26 -0
- plato/utils/data_loaders.py +148 -0
- plato/utils/decorators.py +24 -0
- plato/utils/fonts.py +23 -0
- plato/utils/homo_enc.py +187 -0
- plato/utils/reinforcement_learning/__init__.py +0 -0
- plato/utils/reinforcement_learning/policies/__init__.py +0 -0
- plato/utils/reinforcement_learning/policies/base.py +161 -0
- plato/utils/reinforcement_learning/policies/ddpg.py +75 -0
- plato/utils/reinforcement_learning/policies/registry.py +32 -0
- plato/utils/reinforcement_learning/policies/sac.py +343 -0
- plato/utils/reinforcement_learning/policies/td3.py +485 -0
- plato/utils/reinforcement_learning/rl_agent.py +142 -0
- plato/utils/reinforcement_learning/rl_server.py +113 -0
- plato/utils/rl_env.py +154 -0
- plato/utils/s3.py +141 -0
- plato/utils/trainer_utils.py +21 -0
- plato/utils/unary_encoding.py +47 -0
- plato_learn-1.1.dist-info/METADATA +35 -0
- plato_learn-1.1.dist-info/RECORD +179 -0
- plato_learn-1.1.dist-info/WHEEL +4 -0
- plato_learn-1.1.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,302 @@
|
|
1
|
+
"""
|
2
|
+
A self-supervised learning (SSL) trainer for SSL training and testing.
|
3
|
+
|
4
|
+
Federated learning with SSL trains the global model based on the data loader and
|
5
|
+
objective function of SSL algorithms. For this unsupervised learning process, we
|
6
|
+
cannot test the model directly as the model only extracts features from the
|
7
|
+
data. Therefore, we use KNN as a classifier to get the accuracy of the global
|
8
|
+
model during the regular federated training process.
|
9
|
+
|
10
|
+
In the personalization process, each client trains a linear layer locally, based
|
11
|
+
on the features extracted by the trained global model.
|
12
|
+
|
13
|
+
The accuracy obtained by KNN during the regular federated training rounds may
|
14
|
+
not be used to compare with the accuracy in supervised learning methods.
|
15
|
+
"""
|
16
|
+
|
17
|
+
import logging
|
18
|
+
from collections import UserList
|
19
|
+
|
20
|
+
import torch
|
21
|
+
from lightly.data.multi_view_collate import MultiViewCollate
|
22
|
+
|
23
|
+
from plato.config import Config
|
24
|
+
from plato.trainers import basic
|
25
|
+
from plato.models import registry as models_registry
|
26
|
+
from plato.trainers import optimizers, lr_schedulers, loss_criterion
|
27
|
+
|
28
|
+
|
29
|
+
class SSLSamples(UserList):
    """A list-like holder for the multiple views that make up one SSL sample."""

    def to(self, device):
        """Move every tensor view onto `device` and return the underlying list.

        Entries that are not `torch.Tensor` instances are left untouched.
        The views are replaced in place, and the plain list stored in
        `self.data` (not this wrapper) is what gets returned.
        """
        for position, item in enumerate(self.data):
            if isinstance(item, torch.Tensor):
                self.data[position] = item.to(device)

        return self.data
|
41
|
+
|
42
|
+
|
43
|
+
class MultiViewCollateWrapper(MultiViewCollate):
    """
    Adapter between the multi-view collate of lightly and Plato's data loading.
    """

    def __call__(self, batch):
        """Collate a batch of sample tuples into `(views, labels)`.

        Each sample tuple gets a placeholder file name appended, because the
        parent collate unpacks three parts; the collated file names are then
        discarded and the views are wrapped in `SSLSamples`.
        """
        # Append a dummy fname so every sample has the parts lightly expects
        padded_batch = [sample + (" ",) for sample in batch]

        # Let the lightly collate assemble the views and the labels
        views, labels, _ = super().__call__(padded_batch)

        # Wrap the list of view tensors so it offers a `.to(device)` method
        return SSLSamples(views), labels
|
59
|
+
|
60
|
+
|
61
|
+
class Trainer(basic.Trainer):
    """A federated SSL trainer.

    While `current_round` does not exceed `Config().trainer.rounds`, the
    trainer runs SSL training and evaluates the encoder with a KNN classifier.
    Once `current_round` exceeds it, the trainer switches to personalization:
    `self.local_layers` is trained and tested on the encoder's features.
    """

    def __init__(self, model=None, callbacks=None):
        """Initialize the trainer and build the personalized local layers."""
        super().__init__(model=model, callbacks=callbacks)

        # Filled in later through set_personalized_datasets()
        self.personalized_trainset = None
        self.personalized_testset = None

        # The local layers take the encoder's output as input, so their input
        # size is the encoding dimension and their output size the class count
        layer_params = Config().parameters.personalization.model._asdict()
        layer_params["input_dim"] = self.model.encoder.encoding_dim
        layer_params["output_dim"] = layer_params["num_classes"]

        personalization = Config().algorithm.personalization
        self.local_layers = models_registry.get(
            model_name=personalization.model_name,
            model_type=personalization.model_type,
            model_params=layer_params,
        )

    def set_personalized_datasets(self, trainset, testset):
        """Store the train and test datasets used for personalization."""
        self.personalized_trainset = trainset
        self.personalized_testset = testset

    def get_train_loader(self, batch_size, trainset, sampler, **kwargs):
        """Build the training loader that matches the current learning mode.

        During personalization the loader reads `self.personalized_trainset`
        with the default collate; during SSL training it reads `trainset`
        through `MultiViewCollateWrapper`.
        """
        loader_options = {
            "shuffle": False,
            "batch_size": batch_size,
            "sampler": sampler,
        }

        if self.current_round > Config().trainer.rounds:
            return torch.utils.data.DataLoader(
                dataset=self.personalized_trainset, **loader_options
            )

        return torch.utils.data.DataLoader(
            dataset=trainset,
            collate_fn=MultiViewCollateWrapper(),
            **loader_options,
        )

    def get_optimizer(self, model):
        """Return the optimizer for SSL training or for personalization."""
        if self.current_round > Config().trainer.rounds:
            # Personalization only optimizes the local layers
            return optimizers.get(
                self.local_layers,
                optimizer_name=Config().algorithm.personalization.optimizer,
                optimizer_params=(
                    Config().parameters.personalization.optimizer._asdict()
                ),
            )

        return super().get_optimizer(model)

    def get_ssl_criterion(self):
        """
        Build the loss criterion used for SSL training. SSL algorithms that
        need a specific loss function, for example BYOL, override this method.
        """
        criterion = loss_criterion.get()

        # Model outputs may be a single tensor or a list/tuple of tensors; the
        # wrapper spreads the latter over the positional arguments of the loss.
        # The second argument (the labels) is ignored.
        def compute_loss(outputs, __):
            views = outputs if isinstance(outputs, (list, tuple)) else (outputs,)
            return criterion(*views)

        return compute_loss

    def get_loss_criterion(self):
        """Return the loss criterion for the current learning mode."""
        if self.current_round <= Config().trainer.rounds:
            return self.get_ssl_criterion()

        # Personalization: the criterion name is required, its parameters are
        # optional in the configuration
        criterion_name = Config().algorithm.personalization.loss_criterion
        personalization_params = Config().parameters.personalization
        criterion_params = (
            personalization_params.loss_criterion._asdict()
            if hasattr(personalization_params, "loss_criterion")
            else {}
        )

        return loss_criterion.get(
            loss_criterion=criterion_name,
            loss_criterion_params=criterion_params,
        )

    def get_lr_scheduler(self, config, optimizer):
        """Return the learning rate scheduler for the current learning mode."""
        if self.current_round <= Config().trainer.rounds:
            # SSL training uses the scheduler of the parent trainer
            return super().get_lr_scheduler(config, optimizer)

        return lr_schedulers.get(
            optimizer,
            len(self.train_loader),
            lr_scheduler=Config().algorithm.personalization.lr_scheduler,
            lr_params=Config().parameters.personalization.learning_rate._asdict(),
        )

    def train_run_start(self, config):
        """Adjust the training config before a personalization run starts."""
        if self.current_round <= Config().trainer.rounds:
            return

        # Personalization uses its own batch size and number of epochs
        personalization = Config().algorithm.personalization
        config["batch_size"] = personalization.batch_size
        config["epochs"] = personalization.epochs

        # The local layers are the part being trained in this mode
        self.local_layers.to(self.device)
        self.local_layers.train()

    def perform_forward_and_backward_passes(self, config, examples, labels):
        """Run one forward and backward pass of the training loop.

        During SSL training this defers to the parent trainer. During
        personalization the encoder of `self.model` extracts features, which
        are fed into `self.local_layers`; the loss is computed on the outputs
        of the local layers and `self.optimizer` takes a step.
        """
        # Perform SSL training in the first `Config().trainer.rounds` rounds
        if self.current_round <= Config().trainer.rounds:
            return super().perform_forward_and_backward_passes(
                config, examples, labels
            )

        # Personalization after the final round: update self.local_layers
        self.optimizer.zero_grad()

        # The trained encoder only produces the features here
        outputs = self.local_layers(self.model.encoder(examples))

        loss = self._loss_criterion(outputs, labels)
        self._loss_tracker.update(loss, labels.size(0))

        # Pass create_graph through only when the config specifies it
        backward_options = (
            {"create_graph": config["create_graph"]}
            if "create_graph" in config
            else {}
        )
        loss.backward(**backward_options)

        self.optimizer.step()

        return loss

    def collect_encodings(self, data_loader):
        """Encode every batch of `data_loader` with the encoder of self.model.

        Returns the encodings and the labels, each concatenated along the
        first dimension, or `(None, None)` when the loader yields no batch.
        """
        encoding_chunks = []
        label_chunks = []

        self.model.eval()
        self.model.to(self.device)

        for examples, labels in data_loader:
            examples, labels = examples.to(self.device), labels.to(self.device)
            with torch.no_grad():
                encoding_chunks.append(self.model.encoder(examples))
            label_chunks.append(labels)

        if not encoding_chunks:
            return None, None

        return torch.cat(encoding_chunks, dim=0), torch.cat(label_chunks, dim=0)

    def test_model(self, config, testset, sampler=None, **kwargs):
        """Test the model and report the accuracy for the current round.

        After the final round the personalized local layers are tested;
        before that, a KNN classifier evaluates the extracted features.
        """
        batch_size = config["batch_size"]

        if self.current_round > Config().trainer.rounds:
            return self._test_personalized_model(batch_size, testset, sampler)

        return self._test_with_knn(batch_size, testset, sampler)

    def _test_personalized_model(self, batch_size, testset, sampler):
        """Compute the accuracy of the encoder followed by the local layers."""
        self.local_layers.eval()
        self.local_layers.to(self.device)

        self.model.eval()
        self.model.to(self.device)

        test_loader = torch.utils.data.DataLoader(
            testset, batch_size=batch_size, shuffle=False, sampler=sampler
        )

        num_correct = 0
        num_samples = 0
        with torch.no_grad():
            for examples, labels in test_loader:
                examples, labels = examples.to(self.device), labels.to(self.device)

                outputs = self.local_layers(self.model.encoder(examples))

                # The predicted class is the index of the largest output
                _, predicted = torch.max(outputs.data, 1)
                num_samples += labels.size(0)
                num_correct += (predicted == labels).sum().item()

        return num_correct / num_samples

    def _test_with_knn(self, batch_size, testset, sampler):
        """Compute the accuracy of a 1-nearest-neighbour classifier.

        The classifier works on encodings: each test sample receives the
        label of the closest sample from `self.personalized_trainset`.
        """
        logging.info("[Client #%d] Testing the model with KNN.", self.client_id)

        # Both loaders are built with the same batch size and sampler
        reference_loader = torch.utils.data.DataLoader(
            dataset=self.personalized_trainset,
            shuffle=False,
            batch_size=batch_size,
            sampler=sampler,
        )
        query_loader = torch.utils.data.DataLoader(
            testset, batch_size=batch_size, shuffle=False, sampler=sampler
        )

        reference_encodings, reference_labels = self.collect_encodings(
            reference_loader
        )
        query_encodings, query_labels = self.collect_encodings(query_loader)

        # Euclidean distance from every test sample to every training sample
        pairwise = torch.cdist(query_encodings, reference_encodings, p=2)
        closest = pairwise.topk(1, largest=False).indices

        predictions = reference_labels[closest].view(-1)
        query_labels = query_labels.view(-1)

        hits = torch.sum(predictions == query_labels).item()
        return hits / len(query_labels)
|
@@ -0,0 +1,305 @@
|
|
1
|
+
"""
|
2
|
+
A federated learning trainer using split learning.
|
3
|
+
|
4
|
+
Reference:
|
5
|
+
|
6
|
+
Vepakomma, et al., "Split Learning for Health: Distributed Deep Learning without Sharing
|
7
|
+
Raw Patient Data," in Proc. AI for Social Good Workshop, affiliated with ICLR 2018.
|
8
|
+
|
9
|
+
https://arxiv.org/pdf/1812.00564.pdf
|
10
|
+
|
11
|
+
Chopra, Ayush, et al. "AdaSplit: Adaptive Trade-offs for Resource-constrained Distributed
|
12
|
+
Deep Learning." arXiv preprint arXiv:2112.01637 (2021).
|
13
|
+
|
14
|
+
https://arxiv.org/pdf/2112.01637.pdf
|
15
|
+
"""
|
16
|
+
|
17
|
+
import logging
|
18
|
+
import os
|
19
|
+
|
20
|
+
import torch
|
21
|
+
from plato.config import Config
|
22
|
+
|
23
|
+
from plato.trainers import basic
|
24
|
+
from plato.datasources import feature
|
25
|
+
from plato.samplers import all_inclusive
|
26
|
+
|
27
|
+
|
28
|
+
# pylint:disable=too-many-instance-attributes
|
29
|
+
class Trainer(basic.Trainer):
    """The split learning trainer.

    The same class runs on both sides of the cut layer. A `client_id` of 0
    selects the server-side training loop; any other value selects the
    client-side loop.
    """

    def __init__(self, model=None, callbacks=None):
        """Initializing the trainer with the provided model.

        Arguments:
        model: The model to train.
        callbacks: The callbacks that this trainer uses.
        """
        super().__init__(model=model, callbacks=callbacks)

        # Remembered so get_optimizer() can reuse the optimizer between rounds
        self.last_client_id = None
        self.last_optimizer = None

        # Client side variables
        self.training_samples = None
        self.gradients = None
        self.data_loader = None

        # Server side variables
        self.cut_layer_grad = []

    def get_train_loader(self, batch_size, trainset, sampler, **kwargs):
        """
        Creates an instance of the trainloader.

        The training set is returned unchanged; no data loader is wrapped
        around it, and `batch_size` and `sampler` are not used.

        Arguments:
        batch_size: the batch size.
        trainset: the training dataset.
        sampler: the sampler for the trainloader to use.
        """
        return trainset

    def perform_forward_and_backward_passes(self, config, examples, labels):
        """Perform forward and backward passes in the training loop.

        Arguments:
        config: the configuration.
        examples: data samples in the current batch.
        labels: labels in the current batch.

        Returns: loss values after the current batch has been processed.
        """
        if self.client_id == 0:
            return self._server_train_loop(config, examples, labels)

        return self._client_train_loop(examples)

    def train_run_end(self, config):
        """Additional tasks after training."""
        if self.client_id == 0:
            # Server needs to save gradients, clients not
            self.save_gradients(config)

    def get_optimizer(self, model):
        """Return the optimizer used in the last round to avoid reconfiguration.

        A new optimizer is created only when none exists yet or when the
        client ID differs from the one the cached optimizer was created for.
        """
        if self.last_optimizer is None or self.last_client_id != self.client_id:
            self.last_optimizer = super().get_optimizer(model)
            self.last_client_id = self.client_id

        return self.last_optimizer

    def get_train_samples(self, batch_size, trainset, sampler):
        """
        Get a batch of training samples to extract feature, the trainer has to save these
        samples to complete training later.

        Only the first batch produced by the data loader is taken.
        """
        data_loader = torch.utils.data.DataLoader(
            dataset=trainset, shuffle=False, batch_size=batch_size, sampler=sampler
        )
        data_loader = iter(data_loader)
        self.training_samples = next(data_loader)
        # Wrap the training samples with datasource and sampler to be fed into Plato trainer
        self.training_samples = self.process_training_samples_before_retrieving(
            self.training_samples
        )
        return self.training_samples

    def retrieve_train_samples(self):
        """Retrieve the training samples to complete client training.

        Returns the saved samples wrapped in a feature data source, together
        with an all-inclusive sampler over that data source.
        """
        samples = feature.DataSource([[self.training_samples]])
        sampler = all_inclusive.Sampler(samples)

        return samples, sampler

    def load_gradients(self, gradients):
        """Load the gradients which will be used to complete client training."""
        self.gradients = gradients

    def _client_train_loop(self, examples):
        """Complete the client side training with gradients from server.

        Returns a zero loss tensor, as no loss is computed on the client.
        """
        self.optimizer.zero_grad()
        examples, batch_size = self.process_samples_before_client_forwarding(examples)
        outputs = self.model.forward_to(examples)

        # Backpropagate with gradients from the server
        gradients = self.gradients[0]
        if gradients is None:
            # Nothing to backpropagate, so the optimizer step is skipped too
            logging.warning("[Client #%d] Gradients from server is None.", os.getpid())
        else:
            gradients = gradients.to(self.device)
            outputs.backward(gradients)
            self.optimizer.step()

        # No loss value on the client side
        loss = torch.zeros(1)
        self._loss_tracker.update(loss, batch_size)
        return loss

    def _server_train_loop(self, config, examples, labels):
        """The training loop on the server.

        Returns the loss of the batch as a detached CPU tensor.
        """
        self.optimizer.zero_grad()
        loss, grad, batch_size = self.server_forward_from((examples, labels), config)
        loss = loss.cpu().detach()
        self._loss_tracker.update(loss, batch_size)

        # Record gradients within the cut layer
        if grad is not None:
            grad = grad.cpu().clone().detach()
        # A None gradient is recorded as well; the client loop checks for it
        self.cut_layer_grad = [grad]
        self.optimizer.step()

        logging.warning(
            "[Server #%d] Gradients computed with training loss: %.4f",
            os.getpid(),
            loss,
        )

        return loss

    def save_gradients(self, config):
        """Server saves recorded gradients to a file.

        The file is `<model_path>/<model_name>_gradients.pth`, with any "/"
        in the model name replaced by "_".
        """
        model_name = config["model_name"]
        model_path = Config().params["model_path"]

        # exist_ok avoids the race in a separate existence check followed by
        # creation, should another process create the directory in between
        os.makedirs(model_path, exist_ok=True)

        if "/" in model_name:
            model_name = model_name.replace("/", "_")

        model_gradients_path = f"{model_path}/{model_name}_gradients.pth"
        torch.save(self.cut_layer_grad, model_gradients_path)

        logging.info(
            "[Server #%d] Gradients saved to %s.", os.getpid(), model_gradients_path
        )

    def get_gradients(self):
        """Read gradients from a file.

        The path is built the same way as in `save_gradients`, except that
        the model name comes from `Config().trainer.model_name`.
        """
        model_path = Config().params["model_path"]
        model_name = Config().trainer.model_name

        if "/" in model_name:
            model_name = model_name.replace("/", "_")

        model_gradients_path = f"{model_path}/{model_name}_gradients.pth"
        logging.info(
            "[Server #%d] Loading gradients from %s.", os.getpid(), model_gradients_path
        )

        # NOTE(review): torch.load unpickles the file; this presumably only
        # ever reads a file written by save_gradients() -- confirm that the
        # model path cannot contain untrusted files.
        return torch.load(model_gradients_path)

    def test_model(self, config, testset, sampler=None, **kwargs):
        """
        Evaluates the model with the provided test dataset and test sampler.

        Arguments:
        config: the configuration; its "batch_size" entry is used.
        testset: the test dataset.
        sampler: the test sampler. The default is None.
        kwargs (optional): Additional keyword arguments.
        """
        batch_size = config["batch_size"]
        accuracy = self.test_model_split_learning(batch_size, testset, sampler)
        return accuracy

    # API functions for split learning
    def process_training_samples_before_retrieving(self, training_samples) -> ...:
        """Process training samples before completing retrieving samples.

        The default implementation returns the samples unchanged.
        """
        return training_samples

    def process_samples_before_client_forwarding(self, examples) -> ...:
        """Process the examples before client conducting forwarding.

        Returns the examples unchanged along with the size of their first
        dimension, which is used as the batch size.
        """
        return examples, examples.size(0)

    # pylint:disable=unused-argument
    def server_forward_from(self, batch, config) -> (..., ..., int):
        """
        The event for server completing training by forwarding from intermediate features.
        Users may override this function for training different models with split learning.

        Inputs:
        batch: the batch of inputs for forwarding.
        config: training configuration.
        Returns:
        loss: the calculated loss.
        grad: the gradients over the intermediate feature.
        batch_size: the batch size of the current sample.
        """

        inputs, target = batch
        batch_size = inputs.size(0)
        # Detach the features and track gradients on them, so that the
        # backward pass fills inputs.grad with the cut-layer gradients
        inputs = inputs.detach().requires_grad_(True)
        outputs = self.model.forward_from(inputs)
        loss = self._loss_criterion(outputs, target)
        loss.backward()
        grad = inputs.grad
        return loss, grad, batch_size

    def update_weights_before_cut(self, current_weights, weights) -> ...:
        """
        Update the weights before cut layer, called when testing accuracy in trainer.

        Only the ".weight" and ".bias" entries of the layers listed before the
        cut layer are copied, and only when they exist in `current_weights`.

        Inputs:
        current_weights: the current weights extracted by the algorithm.
        weights: the weights to load.
        Output:
        current_weights: the updated current weights of the model.
        """
        cut_layer_idx = self.model.layers.index(self.model.cut_layer)

        for i in range(0, cut_layer_idx):
            weight_name = f"{self.model.layers[i]}.weight"
            bias_name = f"{self.model.layers[i]}.bias"

            if weight_name in current_weights:
                current_weights[weight_name] = weights[weight_name]

            if bias_name in current_weights:
                current_weights[bias_name] = weights[bias_name]

        return current_weights

    def forward_to_intermediate_feature(self, inputs, targets) -> (..., ...):
        """
        The process to forward to get intermediate feature on the client.

        Arguments:
        inputs: the inputs for the model on the clients.
        targets: the targets to get of the whole model.

        Return:
        outputs: the intermediate feature, detached and moved to the CPU.
        targets: the targets to get of the whole model, detached and moved
        to the CPU.
        """
        with torch.no_grad():
            logits = self.model.forward_to(inputs)

        outputs = logits.detach().cpu()
        targets = targets.detach().cpu()
        return outputs, targets

    def test_model_split_learning(self, batch_size, testset, sampler=None) -> ...:
        """
        The test model process for split learning.

        Returns:
        accuracy: the fraction of test samples whose predicted class equals
        the label.
        """
        test_loader = torch.utils.data.DataLoader(
            testset, batch_size=batch_size, shuffle=False, sampler=sampler
        )
        correct = 0
        total = 0

        self.model.to(self.device)
        with torch.no_grad():
            for examples, labels in test_loader:
                examples, labels = examples.to(self.device), labels.to(self.device)

                outputs = self.model(examples)

                # process_outputs is not defined in this class; presumably
                # inherited from basic.Trainer -- verify
                outputs = self.process_outputs(outputs)

                # The predicted class is the index of the largest output
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        return correct / total
|
@@ -0,0 +1,96 @@
|
|
1
|
+
"""
|
2
|
+
Keeping a history of metrics during the training run.
|
3
|
+
"""
|
4
|
+
|
5
|
+
from collections import defaultdict
|
6
|
+
from typing import Iterable
|
7
|
+
|
8
|
+
|
9
|
+
class RunHistory:
    """
    A simple history of metrics during a training or evaluation run.

    Values are kept per metric name, in the order they were recorded.
    """

    def __init__(self):
        self._metrics = defaultdict(list)

    def get_metric_names(self) -> Iterable:
        """
        Returns an iterable set containing of all unique metric names which are
        being tracked.

        :return: an iterable of the unique metric names.
        """
        return set(self._metrics)

    def get_metric_values(self, metric_name) -> Iterable:
        """
        Returns an ordered iterable list of values that has been stored since
        the last reset corresponding to the provided metric name.

        :param metric_name: the name of the metric being tracked.
        :return: an ordered iterable of values that have been recorded for that
            metric; an empty list if nothing has been recorded under that name.
        """
        # Indexing the defaultdict would create an empty entry for an unknown
        # name, which get_metric_names() would then report as tracked
        return self._metrics.get(metric_name, [])

    def get_latest_metric(self, metric_name):
        """
        Returns the most recent value that has been recorded for the given metric.

        :param metric_name: the name of the metric being tracked.
        :return: the last recorded value.
        :raises ValueError: if no value has been recorded for the metric.
        """
        # .get() keeps the lookup free of side effects on a miss
        recorded = self._metrics.get(metric_name)
        if not recorded:
            raise ValueError(
                f"No values have been recorded for the metric {metric_name}"
            )

        return recorded[-1]

    def update_metric(self, metric_name, metric_value):
        """
        Records a new value for the given metric.

        :param metric_name: the name of the metric being tracked.
        :param metric_value: the value to record.
        """
        self._metrics[metric_name].append(metric_value)

    def reset(self):
        """
        Resets the state of the :class:`RunHistory`.

        """
        self._metrics = defaultdict(list)
|
65
|
+
|
66
|
+
|
67
|
+
class LossTracker:
    """A simple tracker for computing the average loss.

    The average is weighted by batch size: every update adds
    `loss * batch_size` to a running total, which is divided by the number
    of samples seen so far.
    """

    def __init__(self):
        self.loss_value = 0
        self._average = 0
        self.total_loss = 0
        self.running_count = 0

    def reset(self):
        """Resets this loss tracker."""

        self.loss_value = 0
        self._average = 0
        self.total_loss = 0
        self.running_count = 0

    def update(self, loss_batch_value, batch_size=1):
        """Updates the loss tracker with another loss value from a batch.

        `loss_value` keeps the value exactly as passed in. The running total
        is accumulated from a detached copy when the value offers `detach()`,
        so it does not keep autograd history alive across batches.
        """

        self.loss_value = loss_batch_value

        detach = getattr(loss_batch_value, "detach", None)
        contribution = detach() if callable(detach) else loss_batch_value

        self.total_loss += contribution * batch_size
        self.running_count += batch_size
        self._average = self.total_loss / self.running_count

    @property
    def average(self):
        """Returns the computed average of loss values tracked.

        The result is a plain Python number: 0.0 before any update or right
        after a reset.
        """

        running_average = self._average

        # Before the first update, or when plain numbers were tracked, the
        # average is not a tensor and has no .cpu()/.detach() to call
        if hasattr(running_average, "detach"):
            return running_average.cpu().detach().mean().item()

        return float(running_average)
|