plato_learn-1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (179)
  1. plato/__init__.py +1 -0
  2. plato/algorithms/__init__.py +0 -0
  3. plato/algorithms/base.py +45 -0
  4. plato/algorithms/fedavg.py +48 -0
  5. plato/algorithms/fedavg_gan.py +79 -0
  6. plato/algorithms/fedavg_personalized.py +48 -0
  7. plato/algorithms/mistnet.py +52 -0
  8. plato/algorithms/registry.py +39 -0
  9. plato/algorithms/split_learning.py +89 -0
  10. plato/callbacks/__init__.py +0 -0
  11. plato/callbacks/client.py +56 -0
  12. plato/callbacks/handler.py +78 -0
  13. plato/callbacks/server.py +139 -0
  14. plato/callbacks/trainer.py +124 -0
  15. plato/client.py +67 -0
  16. plato/clients/__init__.py +0 -0
  17. plato/clients/base.py +467 -0
  18. plato/clients/edge.py +103 -0
  19. plato/clients/fedavg_personalized.py +40 -0
  20. plato/clients/mistnet.py +49 -0
  21. plato/clients/registry.py +43 -0
  22. plato/clients/self_supervised_learning.py +51 -0
  23. plato/clients/simple.py +218 -0
  24. plato/clients/split_learning.py +150 -0
  25. plato/config.py +339 -0
  26. plato/datasources/__init__.py +0 -0
  27. plato/datasources/base.py +123 -0
  28. plato/datasources/celeba.py +150 -0
  29. plato/datasources/cifar10.py +87 -0
  30. plato/datasources/cifar100.py +61 -0
  31. plato/datasources/cinic10.py +62 -0
  32. plato/datasources/coco.py +119 -0
  33. plato/datasources/datalib/__init__.py +0 -0
  34. plato/datasources/datalib/audio_extraction_tools.py +137 -0
  35. plato/datasources/datalib/data_utils.py +124 -0
  36. plato/datasources/datalib/flickr30kE_utils.py +336 -0
  37. plato/datasources/datalib/frames_extraction_tools.py +254 -0
  38. plato/datasources/datalib/gym_utils/__init__.py +0 -0
  39. plato/datasources/datalib/gym_utils/gym_trim.py +189 -0
  40. plato/datasources/datalib/modality_data_anntation_tools.py +163 -0
  41. plato/datasources/datalib/modality_extraction_base.py +59 -0
  42. plato/datasources/datalib/parse_datasets.py +212 -0
  43. plato/datasources/datalib/refer_utils/__init__.py +0 -0
  44. plato/datasources/datalib/refer_utils/referitgame_utils.py +237 -0
  45. plato/datasources/datalib/tiny_data_tools.py +81 -0
  46. plato/datasources/datalib/video_transform.py +79 -0
  47. plato/datasources/emnist.py +64 -0
  48. plato/datasources/fashion_mnist.py +41 -0
  49. plato/datasources/feature.py +24 -0
  50. plato/datasources/feature_dataset.py +15 -0
  51. plato/datasources/femnist.py +141 -0
  52. plato/datasources/flickr30k_entities.py +362 -0
  53. plato/datasources/gym.py +431 -0
  54. plato/datasources/huggingface.py +165 -0
  55. plato/datasources/kinetics.py +568 -0
  56. plato/datasources/mnist.py +44 -0
  57. plato/datasources/multimodal_base.py +328 -0
  58. plato/datasources/pascal_voc.py +56 -0
  59. plato/datasources/purchase.py +94 -0
  60. plato/datasources/qoenflx.py +127 -0
  61. plato/datasources/referitgame.py +330 -0
  62. plato/datasources/registry.py +119 -0
  63. plato/datasources/self_supervised_learning.py +98 -0
  64. plato/datasources/stl10.py +103 -0
  65. plato/datasources/texas.py +94 -0
  66. plato/datasources/tiny_imagenet.py +64 -0
  67. plato/datasources/yolov8.py +85 -0
  68. plato/models/__init__.py +0 -0
  69. plato/models/cnn_encoder.py +103 -0
  70. plato/models/dcgan.py +116 -0
  71. plato/models/general_multilayer.py +254 -0
  72. plato/models/huggingface.py +27 -0
  73. plato/models/lenet5.py +113 -0
  74. plato/models/multilayer.py +90 -0
  75. plato/models/multimodal/__init__.py +0 -0
  76. plato/models/multimodal/base_net.py +91 -0
  77. plato/models/multimodal/blending.py +142 -0
  78. plato/models/multimodal/fc_net.py +77 -0
  79. plato/models/multimodal/fusion_net.py +78 -0
  80. plato/models/multimodal/multimodal_module.py +152 -0
  81. plato/models/registry.py +99 -0
  82. plato/models/resnet.py +190 -0
  83. plato/models/torch_hub.py +19 -0
  84. plato/models/vgg.py +113 -0
  85. plato/models/vit.py +166 -0
  86. plato/models/yolov8.py +22 -0
  87. plato/processors/__init__.py +0 -0
  88. plato/processors/base.py +35 -0
  89. plato/processors/compress.py +46 -0
  90. plato/processors/decompress.py +48 -0
  91. plato/processors/feature.py +51 -0
  92. plato/processors/feature_additive_noise.py +48 -0
  93. plato/processors/feature_dequantize.py +34 -0
  94. plato/processors/feature_gaussian.py +17 -0
  95. plato/processors/feature_laplace.py +15 -0
  96. plato/processors/feature_quantize.py +34 -0
  97. plato/processors/feature_randomized_response.py +50 -0
  98. plato/processors/feature_unbatch.py +39 -0
  99. plato/processors/inbound_feature_tensors.py +39 -0
  100. plato/processors/model.py +55 -0
  101. plato/processors/model_compress.py +34 -0
  102. plato/processors/model_decompress.py +37 -0
  103. plato/processors/model_decrypt.py +41 -0
  104. plato/processors/model_deepcopy.py +21 -0
  105. plato/processors/model_dequantize.py +18 -0
  106. plato/processors/model_dequantize_qsgd.py +61 -0
  107. plato/processors/model_encrypt.py +43 -0
  108. plato/processors/model_quantize.py +18 -0
  109. plato/processors/model_quantize_qsgd.py +82 -0
  110. plato/processors/model_randomized_response.py +34 -0
  111. plato/processors/outbound_feature_ndarrays.py +38 -0
  112. plato/processors/pipeline.py +26 -0
  113. plato/processors/registry.py +124 -0
  114. plato/processors/structured_pruning.py +57 -0
  115. plato/processors/unstructured_pruning.py +73 -0
  116. plato/samplers/__init__.py +0 -0
  117. plato/samplers/all_inclusive.py +41 -0
  118. plato/samplers/base.py +31 -0
  119. plato/samplers/dirichlet.py +81 -0
  120. plato/samplers/distribution_noniid.py +132 -0
  121. plato/samplers/iid.py +53 -0
  122. plato/samplers/label_quantity_noniid.py +119 -0
  123. plato/samplers/mixed.py +44 -0
  124. plato/samplers/mixed_label_quantity_noniid.py +128 -0
  125. plato/samplers/modality_iid.py +42 -0
  126. plato/samplers/modality_quantity_noniid.py +56 -0
  127. plato/samplers/orthogonal.py +99 -0
  128. plato/samplers/registry.py +66 -0
  129. plato/samplers/sample_quantity_noniid.py +123 -0
  130. plato/samplers/sampler_utils.py +190 -0
  131. plato/servers/__init__.py +0 -0
  132. plato/servers/base.py +1395 -0
  133. plato/servers/fedavg.py +281 -0
  134. plato/servers/fedavg_cs.py +335 -0
  135. plato/servers/fedavg_gan.py +74 -0
  136. plato/servers/fedavg_he.py +106 -0
  137. plato/servers/fedavg_personalized.py +57 -0
  138. plato/servers/mistnet.py +67 -0
  139. plato/servers/registry.py +52 -0
  140. plato/servers/split_learning.py +109 -0
  141. plato/trainers/__init__.py +0 -0
  142. plato/trainers/base.py +99 -0
  143. plato/trainers/basic.py +649 -0
  144. plato/trainers/diff_privacy.py +178 -0
  145. plato/trainers/gan.py +330 -0
  146. plato/trainers/huggingface.py +173 -0
  147. plato/trainers/loss_criterion.py +70 -0
  148. plato/trainers/lr_schedulers.py +252 -0
  149. plato/trainers/optimizers.py +53 -0
  150. plato/trainers/pascal_voc.py +80 -0
  151. plato/trainers/registry.py +44 -0
  152. plato/trainers/self_supervised_learning.py +302 -0
  153. plato/trainers/split_learning.py +305 -0
  154. plato/trainers/tracking.py +96 -0
  155. plato/trainers/yolov8.py +41 -0
  156. plato/utils/__init__.py +0 -0
  157. plato/utils/count_parameters.py +30 -0
  158. plato/utils/csv_processor.py +26 -0
  159. plato/utils/data_loaders.py +148 -0
  160. plato/utils/decorators.py +24 -0
  161. plato/utils/fonts.py +23 -0
  162. plato/utils/homo_enc.py +187 -0
  163. plato/utils/reinforcement_learning/__init__.py +0 -0
  164. plato/utils/reinforcement_learning/policies/__init__.py +0 -0
  165. plato/utils/reinforcement_learning/policies/base.py +161 -0
  166. plato/utils/reinforcement_learning/policies/ddpg.py +75 -0
  167. plato/utils/reinforcement_learning/policies/registry.py +32 -0
  168. plato/utils/reinforcement_learning/policies/sac.py +343 -0
  169. plato/utils/reinforcement_learning/policies/td3.py +485 -0
  170. plato/utils/reinforcement_learning/rl_agent.py +142 -0
  171. plato/utils/reinforcement_learning/rl_server.py +113 -0
  172. plato/utils/rl_env.py +154 -0
  173. plato/utils/s3.py +141 -0
  174. plato/utils/trainer_utils.py +21 -0
  175. plato/utils/unary_encoding.py +47 -0
  176. plato_learn-1.1.dist-info/METADATA +35 -0
  177. plato_learn-1.1.dist-info/RECORD +179 -0
  178. plato_learn-1.1.dist-info/WHEEL +4 -0
  179. plato_learn-1.1.dist-info/licenses/LICENSE +201 -0
plato/trainers/diff_privacy.py ADDED
@@ -0,0 +1,178 @@
+ """
+ The training and testing loops for PyTorch.
+ """
+
+ import logging
+ import time
+
+ from opacus import GradSampleModule
+ from opacus.privacy_engine import PrivacyEngine
+ from opacus.utils.batch_memory_manager import BatchMemoryManager
+ from opacus.validators import ModuleValidator
+ from torch.utils.data import Subset
+
+ from plato.config import Config
+ from plato.trainers import basic
+
+
+ class Trainer(basic.Trainer):
+     """A differentially private federated learning trainer, used by the client."""
+
+     def __init__(self, model=None, **kwargs):
+         """Initializing the trainer with the provided model."""
+         super().__init__(model=model)
+
+         self.max_physical_batch_size = (
+             Config().trainer.max_physical_batch_size
+             if hasattr(Config().trainer, "max_physical_batch_size")
+             else 128
+         )
+
+         self.make_model_private()
+
+     def make_model_private(self):
+         """Make the model private for use with the differential privacy engine."""
+         errors = ModuleValidator.validate(self.model, strict=False)
+         if len(errors) > 0:
+             self.model = ModuleValidator.fix(self.model)
+             errors = ModuleValidator.validate(self.model, strict=False)
+             assert len(errors) == 0
+
+     # pylint: disable=unused-argument
+     def train_model(self, config, trainset, sampler, **kwargs):
+         """The default training loop that supports differential privacy."""
+         batch_size = config["batch_size"]
+         self.sampler = sampler
+         tic = time.perf_counter()
+
+         self.train_run_start(config)
+         self.callback_handler.call_event("on_train_run_start", self, config)
+
+         # We have to use poisson sampling to sample the data, rather than the provided sampler.
+         # Replacing the poisson sampler with the provided sampler is problematic since it may
+         # violate the basic theory of DP-SGD. Therefore, we need to first obtain the train subset
+         # based on the provided sampler, and then create a simple dataloader on the train subset
+         # without the sampler. We will finally use Opacus to recreate the dataloader from the
+         # simple dataloader (with poisson sampling).
+         trainset = Subset(trainset, list(sampler))
+         self.train_loader = self.get_train_loader(batch_size, trainset, sampler=None)
+
+         # Initializing the loss criterion
+         _loss_criterion = self.get_loss_criterion()
+
+         # Initializing the optimizer
+         optimizer = self.get_optimizer(self.model)
+         self.lr_scheduler = self.get_lr_scheduler(config, optimizer)
+         optimizer = self._adjust_lr(config, self.lr_scheduler, optimizer)
+
+         self.model.to(self.device)
+         total_epochs = config["epochs"]
+
+         logging.info(
+             "[Client #%s] Using differential privacy during training.",
+             self.client_id,
+         )
+
+         privacy_engine = PrivacyEngine(accountant="rdp", secure_mode=False)
+
+         self.model, optimizer, train_loader = privacy_engine.make_private_with_epsilon(
+             module=self.model,
+             optimizer=optimizer,
+             data_loader=self.train_loader,
+             target_epsilon=config["dp_epsilon"] if "dp_epsilon" in config else 10.0,
+             target_delta=config["dp_delta"] if "dp_delta" in config else 1e-5,
+             epochs=total_epochs,
+             max_grad_norm=config["dp_max_grad_norm"]
+             if "dp_max_grad_norm" in config
+             else 1.0,
+         )
+
+         self.model.train()
+
+         for self.current_epoch in range(1, total_epochs + 1):
+             with BatchMemoryManager(
+                 data_loader=train_loader,
+                 max_physical_batch_size=self.max_physical_batch_size,
+                 optimizer=optimizer,
+             ) as memory_safe_train_loader:
+                 self._loss_tracker.reset()
+                 self.train_epoch_start(config)
+                 self.callback_handler.call_event("on_train_epoch_start", self, config)
+
+                 for batch_id, (examples, labels) in enumerate(memory_safe_train_loader):
+                     examples, labels = (
+                         examples.to(self.device),
+                         labels.to(self.device),
+                     )
+                     optimizer.zero_grad(set_to_none=True)
+
+                     outputs = self.model(examples)
+
+                     loss = _loss_criterion(outputs, labels)
+                     self._loss_tracker.update(loss, labels.size(0))
+
+                     if "create_graph" in config:
+                         loss.backward(create_graph=config["create_graph"])
+                     else:
+                         loss.backward()
+
+                     optimizer.step()
+
+                     self.train_step_end(config, batch=batch_id, loss=loss)
+                     self.callback_handler.call_event(
+                         "on_train_step_end",
+                         self,
+                         config,
+                         batch=batch_id,
+                         loss=loss,
+                     )
+
+                 self.lr_scheduler_step()
+
+                 if hasattr(optimizer, "params_state_update"):
+                     optimizer.params_state_update()
+
+                 # Simulate client's speed
+                 if (
+                     self.client_id != 0
+                     and hasattr(Config().clients, "speed_simulation")
+                     and Config().clients.speed_simulation
+                 ):
+                     self.simulate_sleep_time()
+
+                 # Saving the model at the end of this epoch to a file so that
+                 # it can later be retrieved to respond to server requests
+                 # in asynchronous mode when the wall clock time is simulated
+                 if (
+                     hasattr(Config().server, "request_update")
+                     and Config().server.request_update
+                 ):
+                     self.model.cpu()
+                     training_time = time.perf_counter() - tic
+                     filename = f"{self.client_id}_{self.current_epoch}_{training_time}.pth"
+                     self.save_model(filename)
+                     self.model.to(self.device)
+
+                 self.run_history.update_metric("train_loss", self._loss_tracker.average)
+                 self.train_epoch_end(config)
+                 self.callback_handler.call_event("on_train_epoch_end", self, config)
+
+         self.train_run_end(config)
+         self.callback_handler.call_event("on_train_run_end", self, config)
+
+     def train_run_start(self, config):
+         """
+         Method called at the start of training run.
+         """
+         self.model = GradSampleModule(self.model)
+
+     def train_run_end(self, config):
+         """
+         Method called at the end of a training run.
+         """
+         # After GradSampleModule() conversion, the state_dict names have a `_module` prefix
+         # We will need to save the weights with the original layer names without the prefix
+         self.model_state_dict = {
+             k[8:] if "_module." in k else k: v
+             for k, v in self.model.state_dict().items()
+         }
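
Note: Trainer.train_model() above reads its differential-privacy settings from the config dictionary passed in by the Plato client (and max_physical_batch_size from Config().trainer). The sketch below is illustrative only: the key names match the lookups in the code above, but the values are placeholders, not defaults shipped with the package.

    # Illustrative only -- key names mirror the lookups in train_model() above;
    # the values are placeholders chosen for this example.
    config = {
        "batch_size": 32,
        "epochs": 5,
        "dp_epsilon": 10.0,       # target epsilon of the (epsilon, delta) privacy budget
        "dp_delta": 1e-5,         # target delta
        "dp_max_grad_norm": 1.0,  # per-sample gradient clipping bound
    }
    # trainer = Trainer(model=my_model)               # my_model is hypothetical
    # trainer.train_model(config, trainset, sampler)  # trainset and sampler are hypothetical
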
plato/trainers/gan.py ADDED
@@ -0,0 +1,330 @@
+ """
+ The training and testing loops for GAN models.
+
+ Reference:
+ https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html
+ """
+
+ import logging
+ import math
+ import os
+
+ import torch
+ import torch.nn as nn
+ import torchvision
+ import numpy as np
+ import scipy.linalg
+
+ from plato.config import Config
+ from plato.models import registry as models_registry
+ from plato.trainers import basic
+ from plato.trainers import optimizers
+
+
+ class Trainer(basic.Trainer):
+     """A federated learning trainer for GAN models."""
+
+     def __init__(self, model=None, **kwargs):
+         super().__init__()
+
+         if model is None:
+             model = models_registry.get()
+         gan_model = model
+         self.generator = gan_model.generator
+         self.discriminator = gan_model.discriminator
+         self.loss_criterion = gan_model.loss_criterion
+         self.model = gan_model
+
+         # Use the pre-trained InceptionV3 model as a feature extractor
+         # for testing
+         self.inception_model = torchvision.models.inception_v3(
+             pretrained=True, aux_logits=False
+         )
+         # Remove the last output layer of inception
+         self.inception_model.fc = nn.Identity()
+         self.inception_model.eval()
+
+         self.training_start_time = 0
+
+     def save_model(self, filename=None, location=None):
+         """Saving the model to a file."""
+         model_path = Config().params["model_path"] if location is None else location
+         model_name = Config().trainer.model_name
+
+         try:
+             if not os.path.exists(model_path):
+                 os.makedirs(model_path)
+         except FileExistsError:
+             pass
+
+         if filename is not None:
+             net_gen_path = f"{model_path}/Generator_{filename}"
+             net_disc_path = f"{model_path}/Discriminator_{filename}"
+         else:
+             net_gen_path = f"{model_path}/Generator_{model_name}.pth"
+             net_disc_path = f"{model_path}/Discriminator_{model_name}.pth"
+
+         torch.save(self.generator.state_dict(), net_gen_path)
+         torch.save(self.discriminator.state_dict(), net_disc_path)
+
+         if self.client_id == 0:
+             logging.info(
+                 "[Server #%d] Generator Model saved to %s.", os.getpid(), net_gen_path
+             )
+             logging.info(
+                 "[Server #%d] Discriminator Model saved to %s.",
+                 os.getpid(),
+                 net_disc_path,
+             )
+         else:
+             logging.info(
+                 "[Client #%d] Generator Model saved to %s.",
+                 self.client_id,
+                 net_gen_path,
+             )
+             logging.info(
+                 "[Client #%d] Discriminator Model saved to %s.",
+                 self.client_id,
+                 net_disc_path,
+             )
+
+     def load_model(self, filename=None, location=None):
+         """Loading pre-trained model weights from a file."""
+         model_path = Config().params["model_path"] if location is None else location
+         model_name = Config().trainer.model_name
+
+         if filename is not None:
+             net_gen_path = f"{model_path}/Generator_{filename}"
+             net_disc_path = f"{model_path}/Discriminator_{filename}"
+         else:
+             net_gen_path = f"{model_path}/Generator_{model_name}.pth"
+             net_disc_path = f"{model_path}/Discriminator_{model_name}.pth"
+
+         if self.client_id == 0:
+             logging.info(
+                 "[Server #%d] Loading a Generator model from %s.",
+                 os.getpid(),
+                 net_gen_path,
+             )
+             logging.info(
+                 "[Server #%d] Loading a Discriminator model from %s.",
+                 os.getpid(),
+                 net_disc_path,
+             )
+         else:
+             logging.info(
+                 "[Client #%d] Loading a Generator model from %s.",
+                 self.client_id,
+                 net_gen_path,
+             )
+             logging.info(
+                 "[Client #%d] Loading a Discriminator model from %s.",
+                 self.client_id,
+                 net_disc_path,
+             )
+
+         self.generator.load_state_dict(torch.load(net_gen_path))
+         self.discriminator.load_state_dict(torch.load(net_disc_path))
+
+     # pylint: disable=unused-argument
+     def train_model(self, config, trainset, sampler, **kwargs):
+         """The main training loop in a federated learning workload.
+
+         Arguments:
+             trainset: The training dataset.
+             sampler: the sampler that extracts a partition for this client.
+
+         Returns:
+             float: The training time.
+         """
+         batch_size = config["batch_size"]
+         log_interval = 10
+
+         logging.info("[Client #%d] Loading the dataset.", self.client_id)
+
+         train_loader = torch.utils.data.DataLoader(
+             dataset=trainset, shuffle=False, batch_size=batch_size, sampler=sampler
+         )
+
+         self.model.to(self.device)
+         self.model.train()
+
+         # self.generator.apply(self.model.weights_init)
+         # self.discriminator.apply(self.model.weights_init)
+
+         optimizer_gen = optimizers.get(self.generator)
+         optimizer_disc = optimizers.get(self.discriminator)
+
+         real_label = 1.0
+         fake_label = 0.0
+
+         epochs = config["epochs"]
+         for epoch in range(1, epochs + 1):
+             # Here we assume the data samples still have labels attached to them,
+             # but GAN training does not need labels, so we'll just discard them
+             for batch_id, (examples, _) in enumerate(train_loader):
+                 cur_batch_size = len(examples)
+                 examples = examples.to(self.device)
+                 label = torch.full((cur_batch_size,), real_label, dtype=torch.float)
+                 label = label.to(self.device)
+                 ############################
+                 # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
+                 ###########################
+                 ## Train with all-real batch
+                 optimizer_disc.zero_grad()
+                 # Forward pass real batch through D
+                 output = self.discriminator(examples).view(-1)
+                 # Calculate loss on all-real batch
+                 err_disc_real = self.loss_criterion(output, label)
+                 # Calculate gradients for D in backward pass
+                 err_disc_real.backward()
+
+                 ## Train with all-fake batch
+                 # Generate batch of latent vectors
+                 noise = torch.randn(
+                     cur_batch_size, self.model.nz, 1, 1, device=self.device
+                 )
+                 # Generate fake image batch with G
+                 fake = self.generator(noise)
+                 label.fill_(fake_label)
+                 # Classify all fake batch with D
+                 output = self.discriminator(fake.detach()).view(-1)
+                 # Calculate D's loss on the all-fake batch
+                 err_disc_fake = self.loss_criterion(output, label)
+                 # Calculate the gradients for this batch, accumulated (summed)
+                 # with previous gradients
+                 err_disc_fake.backward()
+                 # Compute error of D as sum over the fake and the real batches
+                 err_disc_total = err_disc_real + err_disc_fake
+                 # Update D
+                 optimizer_disc.step()
+
+                 ############################
+                 # (2) Update G network: maximize log(D(G(z)))
+                 ###########################
+                 optimizer_gen.zero_grad()
+                 label.fill_(real_label)  # fake labels are real for generator cost
+                 # Since we just updated D, perform another forward pass of all-fake batch through D
+                 output = self.discriminator(fake).view(-1)
+                 # Calculate G's loss based on this output
+                 err_gen = self.loss_criterion(output, label)
+                 # Calculate gradients for G
+                 err_gen.backward()
+                 # Update G
+                 optimizer_gen.step()
+
+                 if batch_id % log_interval == 0:
+                     if self.client_id == 0:
+                         logging.info(
+                             "[Server #%d] Epoch: [%d/%d][%d/%d]\tGenerator Loss: %.6f\t"
+                             "Discriminator Loss: %.6f",
+                             os.getpid(),
+                             epoch,
+                             epochs,
+                             batch_id,
+                             len(train_loader),
+                             err_gen.data.item(),
+                             err_disc_total.data.item(),
+                         )
+                     else:
+                         logging.info(
+                             "[Client #%d] Epoch: [%d/%d][%d/%d]\tGenerator Loss: %.6f\t"
+                             "Discriminator Loss: %.6f",
+                             self.client_id,
+                             epoch,
+                             epochs,
+                             batch_id,
+                             len(train_loader),
+                             err_gen.data.item(),
+                             err_disc_total.data.item(),
+                         )
+
+     def test_model(self, config, testset, sampler=None, **kwargs):
+         """Test the Generator model with the Frechet Inception Distance metric."""
+
+         self.model.to(self.device)
+         self.model.eval()
+
+         perplexity = -1
+
+         test_loader = torch.utils.data.DataLoader(
+             testset, batch_size=config["batch_size"], shuffle=True
+         )
+
+         real_features, fake_features = [], []
+         with torch.no_grad():
+             for real_examples, _ in test_loader:
+                 real_examples = real_examples.to(self.device)
+
+                 noise = torch.randn(
+                     config["batch_size"], self.model.nz, 1, 1, device=self.device
+                 )
+                 fake_examples = self.generator(noise)
+
+                 # Extract the feature of real and synthetic data with
+                 # InceptionV3 model pre-trained on ImageNet
+                 self.inception_model.to(self.device)
+                 feature_real = self.feature_extractor(real_examples)
+                 feature_fake = self.feature_extractor(fake_examples)
+
+                 # Store the feature of every real and synthetic data
+                 real_features.extend(list(feature_real))
+                 fake_features.extend(list(feature_fake))
+
+         real_features, fake_features = (
+             np.stack(real_features),
+             np.stack(fake_features),
+         )
+         # Calculate the Frechet Distance between the feature distribution
+         # of real data from testset and the feature distribution of data
+         # generated by the generator.
+         perplexity = self.calculate_fid(real_features, fake_features)
+
+         return perplexity
+
+     def feature_extractor(self, inputs):
+         """Extract the feature of input data with InceptionV3.
+
+         The feature extracted from each input is a NumPy array
+         of length 2048.
+         """
+         # Since the input to InceptionV3 needs to be at least 75x75,
+         # we will pad the input image if needed.
+         hpad = math.ceil((75 - inputs.size(dim=-2)) / 2)
+         vpad = math.ceil((75 - inputs.size(dim=-1)) / 2)
+         hpad, vpad = max(0, hpad), max(0, vpad)
+         pad = nn.ZeroPad2d((hpad, hpad, vpad, vpad))
+         inputs = pad(inputs)
+
+         # Extract feature with InceptionV3
+         features = None
+         with torch.no_grad():
+             features = self.inception_model(inputs)
+         features = features.cpu()
+         features = np.array(features)
+
+         return features
+
+     def calculate_fid(self, real_features, fake_features):
+         """Calculate the Frechet Inception Distance (FID) between the
+         given real data feature and the synthetic data feature.
+
+         A lower FID indicates a better Generator model.
+
+         The implementation is borrowed from the following link:
+         https://wandb.ai/ayush-thakur/gan-evaluation/reports/How-to-Evaluate-GANs-using-Frechet-Inception-Distance-FID---Vmlldzo0MTAxOTI
+         """
+         # calculate mean and covariance statistics
+         mu1, sigma1 = real_features.mean(axis=0), np.cov(real_features, rowvar=False)
+         mu2, sigma2 = fake_features.mean(axis=0), np.cov(fake_features, rowvar=False)
+         # calculate sum squared difference between means
+         ssdiff = np.sum((mu1 - mu2) ** 2.0)
+         # calculate sqrt of product between cov
+         covmean = scipy.linalg.sqrtm(sigma1.dot(sigma2))
+         # check and correct imaginary numbers from sqrt
+         if np.iscomplexobj(covmean):
+             covmean = covmean.real
+         # calculate score
+         fid = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)
+
+         return fid
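
For reference, calculate_fid() above implements the standard Frechet Inception Distance between Gaussians fitted to the real and generated InceptionV3 features. In the notation of the code, with mu1, sigma1 computed from the real features and mu2, sigma2 from the generated ones:

    FID = \lVert \mu_1 - \mu_2 \rVert_2^2 + \operatorname{Tr}\left( \Sigma_1 + \Sigma_2 - 2 (\Sigma_1 \Sigma_2)^{1/2} \right)

which is exactly ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean) in the code; lower values indicate that the generated feature distribution is closer to the real one.
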
plato/trainers/huggingface.py ADDED
@@ -0,0 +1,173 @@
+ """
+ Training and testing loops for HuggingFace's transformer models for natural
+ language processing.
+ """
+
+ import math
+ from typing import Optional
+
+ from torch.utils.data import RandomSampler, Sampler
+
+ from transformers import (
+     AutoConfig,
+     AutoTokenizer,
+     HfArgumentParser,
+     TrainerCallback,
+     LlamaTokenizer,
+ )
+ from transformers import Trainer as HuggingFaceTrainer
+ from transformers import TrainingArguments, default_data_collator
+
+ from plato.config import Config
+ from plato.trainers import basic
+
+
+ class SampledHuggingFaceTrainer(HuggingFaceTrainer):
+     """
+     Training and testing loops for HuggingFace's transformer models for natural
+     language processing.
+     """
+
+     def __init__(
+         self,
+         model,
+         args,
+         train_dataset,
+         eval_dataset,
+         tokenizer,
+         data_collator,
+         sampler,
+         callbacks,
+     ):
+         super().__init__(
+             model=model,
+             args=args,
+             train_dataset=train_dataset,
+             eval_dataset=eval_dataset,
+             tokenizer=tokenizer,
+             data_collator=data_collator,
+             callbacks=callbacks,
+         )
+         self.sampler = sampler
+
+     def _get_train_sampler(self) -> Optional[Sampler]:
+         if self.sampler is None:
+             return RandomSampler(self.train_dataset)
+
+         return self.sampler
+
+     def _get_eval_sampler(self, eval_dataset) -> Optional[Sampler]:
+         if self.sampler is None:
+             return super()._get_eval_sampler(eval_dataset)
+
+         return self.sampler
+
+
+ class Trainer(basic.Trainer):
+     """The trainer for HuggingFace transformer models for natural language processing."""
+
+     def __init__(self, model=None, callbacks=None):
+         super().__init__(model)
+
+         self.trainer = None
+         self.trainer_callbacks = []
+         if callbacks:
+             # Huggingface needs to check callback types
+             self.add_callbacks(callbacks)
+
+         self.model.train()
+
+         parser = HfArgumentParser(TrainingArguments)
+         (self.training_args,) = parser.parse_args_into_dataclasses(
+             args=[
+                 "--output_dir=" + Config.params["checkpoint_path"],
+                 "--report_to=none",
+             ]
+         )
+
+         model_name = Config().trainer.model_name
+         config_kwargs = {
+             "cache_dir": None,
+             "revision": "main",
+             "use_auth_token": None,
+         }
+         self.config = AutoConfig.from_pretrained(model_name, **config_kwargs)
+
+         tokenizer_kwargs = {
+             "cache_dir": None,
+             "use_fast": True,
+             "revision": "main",
+             "use_auth_token": None,
+         }
+         if "llama" in model_name:
+             self.tokenizer = LlamaTokenizer.from_pretrained(
+                 model_name, config=self.config, **tokenizer_kwargs
+             )
+         else:
+             self.tokenizer = AutoTokenizer.from_pretrained(
+                 model_name, config=self.config, **tokenizer_kwargs
+             )
+
+     # pylint: disable=unused-argument
+     def train_model(self, config, trainset, sampler, **kwargs):
+         """The training loop for HuggingFace models.
+
+         Arguments:
+             config: A dictionary of configuration parameters.
+             trainset: The training dataset.
+             sampler: the sampler that extracts a partition for this client.
+         """
+
+         self.training_args.num_train_epochs = config["epochs"]
+         self.training_args.per_device_train_batch_size = config["batch_size"]
+
+         self.trainer = SampledHuggingFaceTrainer(
+             model=self.model,
+             args=self.training_args,
+             train_dataset=trainset,
+             eval_dataset=None,
+             tokenizer=self.tokenizer,
+             data_collator=default_data_collator,
+             sampler=sampler,
+             callbacks=self.trainer_callbacks,
+         )
+
+         self.trainer.train()
+
+     def test_model(self, config, testset, sampler=None, **kwargs):  # pylint: disable=unused-argument
+         """The testing loop for HuggingFace models.
+
+         Arguments:
+             config: Configuration parameters as a dictionary.
+             testset: The test dataset.
+         """
+         self.training_args.per_device_eval_batch_size = config["batch_size"]
+
+         self.trainer = SampledHuggingFaceTrainer(
+             model=self.model,
+             args=self.training_args,
+             train_dataset=None,
+             eval_dataset=testset,
+             tokenizer=self.tokenizer,
+             data_collator=default_data_collator,
+             sampler=sampler,
+             callbacks=None,
+         )
+
+         metrics = self.trainer.evaluate()
+
+         try:
+             perplexity = math.exp(metrics["eval_loss"])
+         except OverflowError:
+             perplexity = float("inf")
+
+         return perplexity
+
+     def add_callbacks(self, callbacks):
+         """Callbacks will be handled by Huggingface instead of Plato."""
+         for callback in callbacks:
+             if not issubclass(callback, TrainerCallback):
+                 raise ValueError(
+                     f"Huggingface trainer expects subclass of {TrainerCallback}, got {callback} instead."
+                 )
+         self.trainer_callbacks.extend(callbacks)
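
Since add_callbacks() above validates callback classes (not instances) against transformers.TrainerCallback, callers pass the classes themselves when constructing the trainer. A minimal sketch, using a made-up callback name (LogEpochCallback) purely for illustration:

    # Hypothetical example: LogEpochCallback is not part of this package.
    from transformers import TrainerCallback

    class LogEpochCallback(TrainerCallback):
        """Prints a message at the end of every training epoch."""

        def on_epoch_end(self, args, state, control, **kwargs):
            print(f"Finished epoch {state.epoch}")

    # trainer = Trainer(model=my_model, callbacks=[LogEpochCallback])  # my_model is hypothetical
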