plato-learn 1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (179)
  1. plato/__init__.py +1 -0
  2. plato/algorithms/__init__.py +0 -0
  3. plato/algorithms/base.py +45 -0
  4. plato/algorithms/fedavg.py +48 -0
  5. plato/algorithms/fedavg_gan.py +79 -0
  6. plato/algorithms/fedavg_personalized.py +48 -0
  7. plato/algorithms/mistnet.py +52 -0
  8. plato/algorithms/registry.py +39 -0
  9. plato/algorithms/split_learning.py +89 -0
  10. plato/callbacks/__init__.py +0 -0
  11. plato/callbacks/client.py +56 -0
  12. plato/callbacks/handler.py +78 -0
  13. plato/callbacks/server.py +139 -0
  14. plato/callbacks/trainer.py +124 -0
  15. plato/client.py +67 -0
  16. plato/clients/__init__.py +0 -0
  17. plato/clients/base.py +467 -0
  18. plato/clients/edge.py +103 -0
  19. plato/clients/fedavg_personalized.py +40 -0
  20. plato/clients/mistnet.py +49 -0
  21. plato/clients/registry.py +43 -0
  22. plato/clients/self_supervised_learning.py +51 -0
  23. plato/clients/simple.py +218 -0
  24. plato/clients/split_learning.py +150 -0
  25. plato/config.py +339 -0
  26. plato/datasources/__init__.py +0 -0
  27. plato/datasources/base.py +123 -0
  28. plato/datasources/celeba.py +150 -0
  29. plato/datasources/cifar10.py +87 -0
  30. plato/datasources/cifar100.py +61 -0
  31. plato/datasources/cinic10.py +62 -0
  32. plato/datasources/coco.py +119 -0
  33. plato/datasources/datalib/__init__.py +0 -0
  34. plato/datasources/datalib/audio_extraction_tools.py +137 -0
  35. plato/datasources/datalib/data_utils.py +124 -0
  36. plato/datasources/datalib/flickr30kE_utils.py +336 -0
  37. plato/datasources/datalib/frames_extraction_tools.py +254 -0
  38. plato/datasources/datalib/gym_utils/__init__.py +0 -0
  39. plato/datasources/datalib/gym_utils/gym_trim.py +189 -0
  40. plato/datasources/datalib/modality_data_anntation_tools.py +163 -0
  41. plato/datasources/datalib/modality_extraction_base.py +59 -0
  42. plato/datasources/datalib/parse_datasets.py +212 -0
  43. plato/datasources/datalib/refer_utils/__init__.py +0 -0
  44. plato/datasources/datalib/refer_utils/referitgame_utils.py +237 -0
  45. plato/datasources/datalib/tiny_data_tools.py +81 -0
  46. plato/datasources/datalib/video_transform.py +79 -0
  47. plato/datasources/emnist.py +64 -0
  48. plato/datasources/fashion_mnist.py +41 -0
  49. plato/datasources/feature.py +24 -0
  50. plato/datasources/feature_dataset.py +15 -0
  51. plato/datasources/femnist.py +141 -0
  52. plato/datasources/flickr30k_entities.py +362 -0
  53. plato/datasources/gym.py +431 -0
  54. plato/datasources/huggingface.py +165 -0
  55. plato/datasources/kinetics.py +568 -0
  56. plato/datasources/mnist.py +44 -0
  57. plato/datasources/multimodal_base.py +328 -0
  58. plato/datasources/pascal_voc.py +56 -0
  59. plato/datasources/purchase.py +94 -0
  60. plato/datasources/qoenflx.py +127 -0
  61. plato/datasources/referitgame.py +330 -0
  62. plato/datasources/registry.py +119 -0
  63. plato/datasources/self_supervised_learning.py +98 -0
  64. plato/datasources/stl10.py +103 -0
  65. plato/datasources/texas.py +94 -0
  66. plato/datasources/tiny_imagenet.py +64 -0
  67. plato/datasources/yolov8.py +85 -0
  68. plato/models/__init__.py +0 -0
  69. plato/models/cnn_encoder.py +103 -0
  70. plato/models/dcgan.py +116 -0
  71. plato/models/general_multilayer.py +254 -0
  72. plato/models/huggingface.py +27 -0
  73. plato/models/lenet5.py +113 -0
  74. plato/models/multilayer.py +90 -0
  75. plato/models/multimodal/__init__.py +0 -0
  76. plato/models/multimodal/base_net.py +91 -0
  77. plato/models/multimodal/blending.py +142 -0
  78. plato/models/multimodal/fc_net.py +77 -0
  79. plato/models/multimodal/fusion_net.py +78 -0
  80. plato/models/multimodal/multimodal_module.py +152 -0
  81. plato/models/registry.py +99 -0
  82. plato/models/resnet.py +190 -0
  83. plato/models/torch_hub.py +19 -0
  84. plato/models/vgg.py +113 -0
  85. plato/models/vit.py +166 -0
  86. plato/models/yolov8.py +22 -0
  87. plato/processors/__init__.py +0 -0
  88. plato/processors/base.py +35 -0
  89. plato/processors/compress.py +46 -0
  90. plato/processors/decompress.py +48 -0
  91. plato/processors/feature.py +51 -0
  92. plato/processors/feature_additive_noise.py +48 -0
  93. plato/processors/feature_dequantize.py +34 -0
  94. plato/processors/feature_gaussian.py +17 -0
  95. plato/processors/feature_laplace.py +15 -0
  96. plato/processors/feature_quantize.py +34 -0
  97. plato/processors/feature_randomized_response.py +50 -0
  98. plato/processors/feature_unbatch.py +39 -0
  99. plato/processors/inbound_feature_tensors.py +39 -0
  100. plato/processors/model.py +55 -0
  101. plato/processors/model_compress.py +34 -0
  102. plato/processors/model_decompress.py +37 -0
  103. plato/processors/model_decrypt.py +41 -0
  104. plato/processors/model_deepcopy.py +21 -0
  105. plato/processors/model_dequantize.py +18 -0
  106. plato/processors/model_dequantize_qsgd.py +61 -0
  107. plato/processors/model_encrypt.py +43 -0
  108. plato/processors/model_quantize.py +18 -0
  109. plato/processors/model_quantize_qsgd.py +82 -0
  110. plato/processors/model_randomized_response.py +34 -0
  111. plato/processors/outbound_feature_ndarrays.py +38 -0
  112. plato/processors/pipeline.py +26 -0
  113. plato/processors/registry.py +124 -0
  114. plato/processors/structured_pruning.py +57 -0
  115. plato/processors/unstructured_pruning.py +73 -0
  116. plato/samplers/__init__.py +0 -0
  117. plato/samplers/all_inclusive.py +41 -0
  118. plato/samplers/base.py +31 -0
  119. plato/samplers/dirichlet.py +81 -0
  120. plato/samplers/distribution_noniid.py +132 -0
  121. plato/samplers/iid.py +53 -0
  122. plato/samplers/label_quantity_noniid.py +119 -0
  123. plato/samplers/mixed.py +44 -0
  124. plato/samplers/mixed_label_quantity_noniid.py +128 -0
  125. plato/samplers/modality_iid.py +42 -0
  126. plato/samplers/modality_quantity_noniid.py +56 -0
  127. plato/samplers/orthogonal.py +99 -0
  128. plato/samplers/registry.py +66 -0
  129. plato/samplers/sample_quantity_noniid.py +123 -0
  130. plato/samplers/sampler_utils.py +190 -0
  131. plato/servers/__init__.py +0 -0
  132. plato/servers/base.py +1395 -0
  133. plato/servers/fedavg.py +281 -0
  134. plato/servers/fedavg_cs.py +335 -0
  135. plato/servers/fedavg_gan.py +74 -0
  136. plato/servers/fedavg_he.py +106 -0
  137. plato/servers/fedavg_personalized.py +57 -0
  138. plato/servers/mistnet.py +67 -0
  139. plato/servers/registry.py +52 -0
  140. plato/servers/split_learning.py +109 -0
  141. plato/trainers/__init__.py +0 -0
  142. plato/trainers/base.py +99 -0
  143. plato/trainers/basic.py +649 -0
  144. plato/trainers/diff_privacy.py +178 -0
  145. plato/trainers/gan.py +330 -0
  146. plato/trainers/huggingface.py +173 -0
  147. plato/trainers/loss_criterion.py +70 -0
  148. plato/trainers/lr_schedulers.py +252 -0
  149. plato/trainers/optimizers.py +53 -0
  150. plato/trainers/pascal_voc.py +80 -0
  151. plato/trainers/registry.py +44 -0
  152. plato/trainers/self_supervised_learning.py +302 -0
  153. plato/trainers/split_learning.py +305 -0
  154. plato/trainers/tracking.py +96 -0
  155. plato/trainers/yolov8.py +41 -0
  156. plato/utils/__init__.py +0 -0
  157. plato/utils/count_parameters.py +30 -0
  158. plato/utils/csv_processor.py +26 -0
  159. plato/utils/data_loaders.py +148 -0
  160. plato/utils/decorators.py +24 -0
  161. plato/utils/fonts.py +23 -0
  162. plato/utils/homo_enc.py +187 -0
  163. plato/utils/reinforcement_learning/__init__.py +0 -0
  164. plato/utils/reinforcement_learning/policies/__init__.py +0 -0
  165. plato/utils/reinforcement_learning/policies/base.py +161 -0
  166. plato/utils/reinforcement_learning/policies/ddpg.py +75 -0
  167. plato/utils/reinforcement_learning/policies/registry.py +32 -0
  168. plato/utils/reinforcement_learning/policies/sac.py +343 -0
  169. plato/utils/reinforcement_learning/policies/td3.py +485 -0
  170. plato/utils/reinforcement_learning/rl_agent.py +142 -0
  171. plato/utils/reinforcement_learning/rl_server.py +113 -0
  172. plato/utils/rl_env.py +154 -0
  173. plato/utils/s3.py +141 -0
  174. plato/utils/trainer_utils.py +21 -0
  175. plato/utils/unary_encoding.py +47 -0
  176. plato_learn-1.1.dist-info/METADATA +35 -0
  177. plato_learn-1.1.dist-info/RECORD +179 -0
  178. plato_learn-1.1.dist-info/WHEEL +4 -0
  179. plato_learn-1.1.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,302 @@
1
+ """
2
+ A self-supervised learning (SSL) trainer for SSL training and testing.
3
+
4
+ Federated learning with SSL trains the global model based on the data loader and
5
+ objective function of SSL algorithms. For this unsupervised learning process, we
6
+ cannot test the model directly as the model only extracts features from the
7
+ data. Therefore, we use KNN as a classifier to get the accuracy of the global
8
+ model during the regular federated training process.
9
+
10
+ In the personalization process, each client trains a linear layer locally, based
11
+ on the features extracted by the trained global model.
12
+
13
+ The accuracy obtained by KNN during the regular federated training rounds may
14
+ not be used to compare with the accuracy in supervised learning methods.
15
+ """
16
+
17
+ import logging
18
+ from collections import UserList
19
+
20
+ import torch
21
+ from lightly.data.multi_view_collate import MultiViewCollate
22
+
23
+ from plato.config import Config
24
+ from plato.trainers import basic
25
+ from plato.models import registry as models_registry
26
+ from plato.trainers import optimizers, lr_schedulers, loss_criterion
27
+
28
+
29
class SSLSamples(UserList):
    """A list-like container holding the views of an SSL sample batch."""

    def to(self, device):
        """Move every tensor view onto ``device``.

        Non-tensor entries are kept as they are. The move happens in place on
        the underlying list, and that same list (``self.data``) is returned.
        """
        # Slice assignment keeps the identity of self.data while replacing its items.
        self.data[:] = [
            item.to(device) if isinstance(item, torch.Tensor) else item
            for item in self.data
        ]
        return self.data
41
+
42
+
43
class MultiViewCollateWrapper(MultiViewCollate):
    """
    An interface to connect collate from lightly with Plato's data loading mechanism.
    """

    def __call__(self, batch):
        """Turn a batch of (views, label) samples into ``(SSLSamples, labels)``.

        Each sample gets a placeholder file name appended so that the batch
        matches the three-part samples unpacked from lightly's collate below.
        The file names lightly returns are discarded.
        """
        # tuple() also accepts samples delivered as lists, not only tuples.
        batch = [tuple(sample) + (" ",) for sample in batch]

        # Let lightly collate the views and labels; drop the placeholder names.
        views, labels, _ = super().__call__(batch)

        # Wrap the list of view tensors so the whole batch can be moved with `.to()`.
        return SSLSamples(views), labels
59
+
60
+
61
class Trainer(basic.Trainer):
    """A federated SSL trainer.

    During the first ``Config().trainer.rounds`` rounds, the global model is
    trained with the SSL objective. It is evaluated with a 1-nearest-neighbour
    classifier on its encoder features. After those rounds (personalization),
    only ``self.local_layers`` is trained and evaluated, on top of the features
    produced by ``self.model.encoder``.
    """

    def __init__(self, model=None, callbacks=None):
        """Initialize the trainer and build the personalized model."""
        super().__init__(model=model, callbacks=callbacks)

        # Datasets for personalization, set by set_personalized_datasets().
        self.personalized_trainset = None
        self.personalized_testset = None

        # Define the personalized model. Its input size is the encoder's
        # `encoding_dim`; its output size is the configured `num_classes`.
        model_params = Config().parameters.personalization.model._asdict()
        model_params["input_dim"] = self.model.encoder.encoding_dim
        model_params["output_dim"] = model_params["num_classes"]
        self.local_layers = models_registry.get(
            model_name=Config().algorithm.personalization.model_name,
            model_type=Config().algorithm.personalization.model_type,
            model_params=model_params,
        )

    def _in_personalization(self):
        """Return True once all regular federated training rounds are done."""
        return self.current_round > Config().trainer.rounds

    def set_personalized_datasets(self, trainset, testset):
        """Set the trainset and testset used for personalization."""
        self.personalized_trainset = trainset
        self.personalized_testset = testset

    def get_train_loader(self, batch_size, trainset, sampler, **kwargs):
        """Obtain the training loader based on the learning mode.

        During personalization the loader iterates ``self.personalized_trainset``
        (``trainset`` is ignored) with the default collate function. During SSL
        training it iterates ``trainset`` and collates multi-view samples with
        ``MultiViewCollateWrapper``.
        """
        if self._in_personalization():
            return torch.utils.data.DataLoader(
                dataset=self.personalized_trainset,
                shuffle=False,
                batch_size=batch_size,
                sampler=sampler,
            )

        return torch.utils.data.DataLoader(
            dataset=trainset,
            shuffle=False,
            batch_size=batch_size,
            sampler=sampler,
            collate_fn=MultiViewCollateWrapper(),
        )

    def get_optimizer(self, model):
        """Return the optimizer for SSL or for the personalized model.

        During personalization only ``self.local_layers`` is optimized and the
        ``model`` argument is ignored.
        """
        if not self._in_personalization():
            return super().get_optimizer(model)

        optimizer_name = Config().algorithm.personalization.optimizer
        optimizer_params = Config().parameters.personalization.optimizer._asdict()
        return optimizers.get(
            self.local_layers,
            optimizer_name=optimizer_name,
            optimizer_params=optimizer_params,
        )

    def get_ssl_criterion(self):
        """
        Get the loss criterion for SSL. Some SSL algorithms, for example,
        BYOL, will overwrite this function for specific loss functions.
        """
        ssl_loss_function = loss_criterion.get()

        # SSL models may output a single tensor or a list/tuple of tensors.
        # The wrapper unpacks the latter into positional arguments. The labels
        # argument is ignored because SSL is unsupervised.
        def compute_loss(outputs, __):
            if isinstance(outputs, (list, tuple)):
                return ssl_loss_function(*outputs)

            return ssl_loss_function(outputs)

        return compute_loss

    def get_loss_criterion(self):
        """Return the personalization loss after the final round, else the SSL loss."""
        if not self._in_personalization():
            return self.get_ssl_criterion()

        loss_criterion_type = Config().algorithm.personalization.loss_criterion
        loss_criterion_params = {}
        # Parameters for the personalization loss are optional in the config.
        if hasattr(Config().parameters.personalization, "loss_criterion"):
            loss_criterion_params = (
                Config().parameters.personalization.loss_criterion._asdict()
            )
        return loss_criterion.get(
            loss_criterion=loss_criterion_type,
            loss_criterion_params=loss_criterion_params,
        )

    def get_lr_scheduler(self, config, optimizer):
        """Return the LR scheduler for personalization or, otherwise, for SSL."""
        if not self._in_personalization():
            return super().get_lr_scheduler(config, optimizer)

        lr_scheduler = Config().algorithm.personalization.lr_scheduler
        lr_params = Config().parameters.personalization.learning_rate._asdict()
        return lr_schedulers.get(
            optimizer,
            len(self.train_loader),
            lr_scheduler=lr_scheduler,
            lr_params=lr_params,
        )

    def train_run_start(self, config):
        """Set the config before training.

        During personalization, overrides ``batch_size`` and ``epochs`` in
        ``config`` with the personalization settings. It also moves the local
        layers to the device in train mode.
        """
        if self._in_personalization():
            config["batch_size"] = Config().algorithm.personalization.batch_size
            config["epochs"] = Config().algorithm.personalization.epochs

            self.local_layers.to(self.device)
            self.local_layers.train()

    def perform_forward_and_backward_passes(self, config, examples, labels):
        """Perform forward and backward passes in the training loop.

        During the regular rounds this is the standard Plato step on the SSL
        model. During personalization the encoder of ``self.model`` extracts
        features, and only ``self.local_layers`` is updated.
        """
        if not self._in_personalization():
            return super().perform_forward_and_backward_passes(config, examples, labels)

        self.optimizer.zero_grad()

        # The encoder is not optimized here. The personalization optimizer only
        # covers local_layers, so encoder gradients would be wasted work and
        # would pile up on parameters this optimizer never clears.
        with torch.no_grad():
            features = self.model.encoder(examples)
        outputs = self.local_layers(features)

        loss = self._loss_criterion(outputs, labels)
        self._loss_tracker.update(loss, labels.size(0))

        if "create_graph" in config:
            loss.backward(create_graph=config["create_graph"])
        else:
            loss.backward()

        self.optimizer.step()

        return loss

    def collect_encodings(self, data_loader):
        """Encode every sample of ``data_loader`` with ``self.model.encoder``.

        Returns:
            A tuple ``(encodings, labels)`` of tensors concatenated along dim 0,
            or ``(None, None)`` when the loader yields no batches.
        """
        self.model.eval()
        self.model.to(self.device)

        encoding_batches = []
        label_batches = []
        with torch.no_grad():
            for examples, labels in data_loader:
                examples, labels = examples.to(self.device), labels.to(self.device)
                encoding_batches.append(self.model.encoder(examples))
                label_batches.append(labels)

        if not encoding_batches:
            return None, None

        # Concatenate once rather than per batch to avoid quadratic copying.
        return torch.cat(encoding_batches, dim=0), torch.cat(label_batches, dim=0)

    def test_model(self, config, testset, sampler=None, **kwargs):
        """Test the model to report the accuracy in each round.

        After the final round this evaluates the encoder plus the personalized
        model. Before that it evaluates the encoder features with 1-NN.
        """
        batch_size = config["batch_size"]
        if self._in_personalization():
            return self._test_personalized_model(batch_size, testset, sampler)
        return self._test_with_knn(batch_size, testset, sampler)

    def _test_personalized_model(self, batch_size, testset, sampler):
        """Return the accuracy of encoder + personalized model on ``testset``."""
        self.local_layers.eval()
        self.local_layers.to(self.device)

        self.model.eval()
        self.model.to(self.device)

        test_loader = torch.utils.data.DataLoader(
            testset, batch_size=batch_size, shuffle=False, sampler=sampler
        )

        correct = 0
        total = 0
        with torch.no_grad():
            for examples, labels in test_loader:
                examples, labels = examples.to(self.device), labels.to(self.device)

                features = self.model.encoder(examples)
                outputs = self.local_layers(features)

                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        return correct / total

    def _test_with_knn(self, batch_size, testset, sampler):
        """Return the 1-NN accuracy of the global model's encoder features.

        Before personalization the SSL model only extracts features, so each
        test sample is labelled with the label of its nearest (L2) training
        sample from ``self.personalized_trainset``.
        """
        logging.info("[Client #%d] Testing the model with KNN.", self.client_id)

        # NOTE(review): the same `sampler` (given for the test set) is also
        # applied to the personalized trainset here — confirm this is intended.
        train_loader = torch.utils.data.DataLoader(
            dataset=self.personalized_trainset,
            shuffle=False,
            batch_size=batch_size,
            sampler=sampler,
        )
        test_loader = torch.utils.data.DataLoader(
            testset, batch_size=batch_size, shuffle=False, sampler=sampler
        )

        # KNN needs the distances between training and testing encodings.
        train_encodings, train_labels = self.collect_encodings(train_loader)
        test_encodings, test_labels = self.collect_encodings(test_loader)

        distances = torch.cdist(test_encodings, train_encodings, p=2)
        nearest_idx = distances.topk(1, largest=False).indices
        predicted_labels = train_labels[nearest_idx].view(-1)
        test_labels = test_labels.view(-1)

        num_correct = torch.sum(predicted_labels == test_labels).item()
        return num_correct / len(test_labels)
@@ -0,0 +1,305 @@
1
+ """
2
+ A federated learning trainer using split learning.
3
+
4
+ Reference:
5
+
6
+ Vepakomma, et al., "Split Learning for Health: Distributed Deep Learning without Sharing
7
+ Raw Patient Data," in Proc. AI for Social Good Workshop, affiliated with ICLR 2018.
8
+
9
+ https://arxiv.org/pdf/1812.00564.pdf
10
+
11
+ Chopra, Ayush, et al. "AdaSplit: Adaptive Trade-offs for Resource-constrained Distributed
12
+ Deep Learning." arXiv preprint arXiv:2112.01637 (2021).
13
+
14
+ https://arxiv.org/pdf/2112.01637.pdf
15
+ """
16
+
17
+ import logging
18
+ import os
19
+
20
+ import torch
21
+ from plato.config import Config
22
+
23
+ from plato.trainers import basic
24
+ from plato.datasources import feature
25
+ from plato.samplers import all_inclusive
26
+
27
+
28
+ # pylint:disable=too-many-instance-attributes
29
class Trainer(basic.Trainer):
    """The split learning trainer.

    One class serves both roles. With ``self.client_id == 0`` it runs the
    server-side step: ``self.model.forward_from`` on intermediate features,
    recording the gradients w.r.t. those features. Otherwise it runs the
    client-side step: ``self.model.forward_to`` and back-propagation of the
    gradients loaded from the server.
    """

    def __init__(self, model=None, callbacks=None):
        """Initializing the trainer with the provided model.

        Arguments:
            model: The model to train.
            callbacks: The callbacks that this trainer uses.
        """
        super().__init__(model=model, callbacks=callbacks)
        # Optimizer cache; rebuilt only when the client id changes.
        self.last_client_id = None
        self.last_optimizer = None

        # Client side variables
        self.training_samples = None
        self.gradients = None
        self.data_loader = None

        # Server side variables: the most recent gradients at the cut layer.
        self.cut_layer_grad = []

    def get_train_loader(self, batch_size, trainset, sampler, **kwargs):
        """
        Return ``trainset`` unchanged; it is used directly as the train loader.

        Arguments:
            batch_size: the batch size (unused).
            trainset: the training dataset.
            sampler: the sampler for the trainloader to use (unused).
        """
        return trainset

    def perform_forward_and_backward_passes(self, config, examples, labels):
        """Perform forward and backward passes in the training loop.

        Arguments:
            config: the configuration.
            examples: data samples in the current batch.
            labels: labels in the current batch.

        Returns: loss values after the current batch has been processed.
        """
        if self.client_id == 0:
            return self._server_train_loop(config, examples, labels)

        return self._client_train_loop(examples)

    def train_run_end(self, config):
        """Additional tasks after training: the server saves its gradients."""
        if self.client_id == 0:
            self.save_gradients(config)

    def get_optimizer(self, model):
        """Return the optimizer used in the last round to avoid reconfiguration."""
        if self.last_optimizer is None or self.last_client_id != self.client_id:
            self.last_optimizer = super().get_optimizer(model)
            self.last_client_id = self.client_id

        return self.last_optimizer

    def get_train_samples(self, batch_size, trainset, sampler):
        """
        Get a batch of training samples to extract feature, the trainer has to save these
        samples to complete training later.
        """
        data_loader = torch.utils.data.DataLoader(
            dataset=trainset, shuffle=False, batch_size=batch_size, sampler=sampler
        )
        # Only the first batch is taken; it is kept to complete training later.
        self.training_samples = next(iter(data_loader))
        self.training_samples = self.process_training_samples_before_retrieving(
            self.training_samples
        )
        return self.training_samples

    def retrieve_train_samples(self):
        """Wrap the saved training samples into a datasource and sampler."""
        samples = feature.DataSource([[self.training_samples]])
        sampler = all_inclusive.Sampler(samples)

        return samples, sampler

    def load_gradients(self, gradients):
        """Load the gradients which will be used to complete client training."""
        self.gradients = gradients

    def _client_train_loop(self, examples):
        """Complete the client side training with gradients from server."""
        self.optimizer.zero_grad()
        examples, batch_size = self.process_samples_before_client_forwarding(examples)
        outputs = self.model.forward_to(examples)

        # Backpropagate with gradients from the server
        gradients = self.gradients[0]
        if gradients is None:
            logging.warning(
                "[Client #%d] Gradients from server is None.", self.client_id
            )
        else:
            gradients = gradients.to(self.device)
            outputs.backward(gradients)
        self.optimizer.step()

        # No loss value on the client side; a zero is tracked instead.
        loss = torch.zeros(1)
        self._loss_tracker.update(loss, batch_size)
        return loss

    def _server_train_loop(self, config, examples, labels):
        """The training loop on the server."""
        self.optimizer.zero_grad()
        loss, grad, batch_size = self.server_forward_from((examples, labels), config)
        loss = loss.cpu().detach()
        self._loss_tracker.update(loss, batch_size)

        # Record gradients within the cut layer
        if grad is not None:
            grad = grad.cpu().clone().detach()
        self.cut_layer_grad = [grad]
        self.optimizer.step()

        # Routine per-batch progress, so logged at INFO rather than WARNING.
        logging.info(
            "[Server #%d] Gradients computed with training loss: %.4f",
            os.getpid(),
            loss,
        )

        return loss

    @staticmethod
    def _gradients_path(model_path, model_name):
        """Return the file path where the cut-layer gradients are stored."""
        # Slashes in model names (e.g. "org/model") would otherwise create
        # sub-directories.
        model_name = model_name.replace("/", "_")
        return f"{model_path}/{model_name}_gradients.pth"

    def save_gradients(self, config):
        """Server saves recorded gradients to a file."""
        model_path = Config().params["model_path"]
        os.makedirs(model_path, exist_ok=True)

        model_gradients_path = self._gradients_path(model_path, config["model_name"])
        torch.save(self.cut_layer_grad, model_gradients_path)

        logging.info(
            "[Server #%d] Gradients saved to %s.", os.getpid(), model_gradients_path
        )

    def get_gradients(self):
        """Read gradients from a file."""
        model_gradients_path = self._gradients_path(
            Config().params["model_path"], Config().trainer.model_name
        )
        logging.info(
            "[Server #%d] Loading gradients from %s.", os.getpid(), model_gradients_path
        )

        return torch.load(model_gradients_path)

    def test_model(self, config, testset, sampler=None, **kwargs):
        """
        Evaluates the model with the provided test dataset and test sampler.

        Arguments:
            config: the configuration; only ``batch_size`` is read.
            testset: the test dataset.
            sampler: the test sampler. The default is None.
            kwargs (optional): Additional keyword arguments.
        """
        return self.test_model_split_learning(config["batch_size"], testset, sampler)

    # API functions for split learning
    def process_training_samples_before_retrieving(self, training_samples) -> ...:
        """Process training samples before completing retrieving samples."""
        return training_samples

    def process_samples_before_client_forwarding(self, examples) -> ...:
        """Process the examples before client forwarding; return them and the batch size."""
        return examples, examples.size(0)

    # pylint:disable=unused-argument
    def server_forward_from(self, batch, config) -> (..., ..., int):
        """
        The event for server completing training by forwarding from intermediate features.
        Users may override this function for training different models with split learning.

        Inputs:
            batch: the batch of inputs for forwarding.
            config: training configuration.
        Returns:
            loss: the calculated loss.
            grad: the gradients over the intermediate feature.
            batch_size: the batch size of the current sample.
        """
        inputs, target = batch
        batch_size = inputs.size(0)
        # Make the received features a leaf tensor so their gradient is kept.
        inputs = inputs.detach().requires_grad_(True)
        outputs = self.model.forward_from(inputs)
        loss = self._loss_criterion(outputs, target)
        loss.backward()
        return loss, inputs.grad, batch_size

    def update_weights_before_cut(self, current_weights, weights) -> ...:
        """
        Update the weights before cut layer, called when testing accuracy in trainer.
        Inputs:
            current_weights: the current weights extracted by the algorithm.
            weights: the weights to load.
        Output:
            current_weights: the updated current weights of the model.
        """
        cut_layer_idx = self.model.layers.index(self.model.cut_layer)

        for layer in self.model.layers[:cut_layer_idx]:
            for suffix in ("weight", "bias"):
                name = f"{layer}.{suffix}"
                if name in current_weights:
                    current_weights[name] = weights[name]

        return current_weights

    def forward_to_intermediate_feature(self, inputs, targets) -> (..., ...):
        """
        The process to forward to get intermediate feature on the client.
        Arguments:
            inputs: the inputs for the model on the clients.
            targets: the targets to get of the whole model.

        Return:
            outputs: the intermediate feature.
            targets: the targets to get of the whole model.
        """
        with torch.no_grad():
            logits = self.model.forward_to(inputs)

        outputs = logits.detach().cpu()
        targets = targets.detach().cpu()
        return outputs, targets

    def test_model_split_learning(self, batch_size, testset, sampler=None) -> ...:
        """
        The test model process for split learning.

        Returns:
            accuracy: the metrics for evaluating the model.
        """
        test_loader = torch.utils.data.DataLoader(
            testset, batch_size=batch_size, shuffle=False, sampler=sampler
        )
        correct = 0
        total = 0

        self.model.to(self.device)
        # Evaluate in inference mode (e.g. dropout disabled) and restore the
        # previous mode afterwards so later training is unaffected.
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                for examples, labels in test_loader:
                    examples, labels = examples.to(self.device), labels.to(self.device)

                    outputs = self.model(examples)
                    outputs = self.process_outputs(outputs)

                    _, predicted = torch.max(outputs.data, 1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()
        finally:
            self.model.train(was_training)

        return correct / total
@@ -0,0 +1,96 @@
1
+ """
2
+ Keeping a history of metrics during the training run.
3
+ """
4
+
5
+ from collections import defaultdict
6
+ from typing import Iterable
7
+
8
+
9
class RunHistory:
    """
    A simple history of metrics during a training or evaluation run.
    """

    def __init__(self):
        # Maps each metric name to the ordered list of values recorded for it.
        self._metrics = defaultdict(list)

    def get_metric_names(self) -> Iterable:
        """
        Returns an iterable set containing of all unique metric names which are
        being tracked.

        :return: an iterable of the unique metric names.
        """
        return set(self._metrics.keys())

    def get_metric_values(self, metric_name) -> Iterable:
        """
        Returns an ordered iterable list of values that has been stored since
        the last reset corresponding to the provided metric name.

        An unknown metric yields an empty list and is NOT registered as tracked.

        :param metric_name: the name of the metric being tracked.
        :return: an ordered iterable of values that have been recorded for that metric.
        """
        # `.get` avoids the defaultdict side effect of inserting an empty entry.
        return self._metrics.get(metric_name, [])

    def get_latest_metric(self, metric_name):
        """
        Returns the most recent value that has been recorded for the given metric.

        :param metric_name: the name of the metric being tracked.
        :return: the last recorded value.
        :raises ValueError: if no value has been recorded for the metric.
        """
        values = self._metrics.get(metric_name)
        if not values:
            raise ValueError(
                f"No values have been recorded for the metric {metric_name}"
            )
        return values[-1]

    def update_metric(self, metric_name, metric_value):
        """
        Records a new value for the given metric.

        :param metric_name: the name of the metric being tracked.
        :param metric_value: the value to record.
        """
        self._metrics[metric_name].append(metric_value)

    def reset(self):
        """
        Resets the state of the :class:`RunHistory`, forgetting all metrics.
        """
        self._metrics = defaultdict(list)
65
+
66
+
67
class LossTracker:
    """A simple tracker for computing the sample-weighted average loss."""

    def __init__(self):
        self.loss_value = 0
        self._average = 0
        self.total_loss = 0
        self.running_count = 0

    def reset(self):
        """Resets this loss tracker."""

        self.loss_value = 0
        self._average = 0
        self.total_loss = 0
        self.running_count = 0

    def update(self, loss_batch_value, batch_size=1):
        """Updates the loss tracker with another loss value from a batch.

        :param loss_batch_value: the loss of the batch, a number or a tensor.
        :param batch_size: the weight (number of samples) of this loss value.
        """
        # Detach tensors so the running sum does not chain autograd graphs
        # from every batch together.
        if hasattr(loss_batch_value, "detach"):
            loss_batch_value = loss_batch_value.detach()

        self.loss_value = loss_batch_value
        self.total_loss += loss_batch_value * batch_size
        self.running_count += batch_size
        self._average = self.total_loss / self.running_count

    @property
    def average(self):
        """Returns the computed average of loss values tracked, as a float.

        Returns 0.0 when nothing has been tracked yet. Tensor averages are
        reduced with ``mean()`` before conversion.
        """
        average = self._average
        if hasattr(average, "detach"):
            return average.detach().cpu().mean().item()
        return float(average)