plato-learn 1.1 (plato_learn-1.1-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- plato/__init__.py +1 -0
- plato/algorithms/__init__.py +0 -0
- plato/algorithms/base.py +45 -0
- plato/algorithms/fedavg.py +48 -0
- plato/algorithms/fedavg_gan.py +79 -0
- plato/algorithms/fedavg_personalized.py +48 -0
- plato/algorithms/mistnet.py +52 -0
- plato/algorithms/registry.py +39 -0
- plato/algorithms/split_learning.py +89 -0
- plato/callbacks/__init__.py +0 -0
- plato/callbacks/client.py +56 -0
- plato/callbacks/handler.py +78 -0
- plato/callbacks/server.py +139 -0
- plato/callbacks/trainer.py +124 -0
- plato/client.py +67 -0
- plato/clients/__init__.py +0 -0
- plato/clients/base.py +467 -0
- plato/clients/edge.py +103 -0
- plato/clients/fedavg_personalized.py +40 -0
- plato/clients/mistnet.py +49 -0
- plato/clients/registry.py +43 -0
- plato/clients/self_supervised_learning.py +51 -0
- plato/clients/simple.py +218 -0
- plato/clients/split_learning.py +150 -0
- plato/config.py +339 -0
- plato/datasources/__init__.py +0 -0
- plato/datasources/base.py +123 -0
- plato/datasources/celeba.py +150 -0
- plato/datasources/cifar10.py +87 -0
- plato/datasources/cifar100.py +61 -0
- plato/datasources/cinic10.py +62 -0
- plato/datasources/coco.py +119 -0
- plato/datasources/datalib/__init__.py +0 -0
- plato/datasources/datalib/audio_extraction_tools.py +137 -0
- plato/datasources/datalib/data_utils.py +124 -0
- plato/datasources/datalib/flickr30kE_utils.py +336 -0
- plato/datasources/datalib/frames_extraction_tools.py +254 -0
- plato/datasources/datalib/gym_utils/__init__.py +0 -0
- plato/datasources/datalib/gym_utils/gym_trim.py +189 -0
- plato/datasources/datalib/modality_data_anntation_tools.py +163 -0
- plato/datasources/datalib/modality_extraction_base.py +59 -0
- plato/datasources/datalib/parse_datasets.py +212 -0
- plato/datasources/datalib/refer_utils/__init__.py +0 -0
- plato/datasources/datalib/refer_utils/referitgame_utils.py +237 -0
- plato/datasources/datalib/tiny_data_tools.py +81 -0
- plato/datasources/datalib/video_transform.py +79 -0
- plato/datasources/emnist.py +64 -0
- plato/datasources/fashion_mnist.py +41 -0
- plato/datasources/feature.py +24 -0
- plato/datasources/feature_dataset.py +15 -0
- plato/datasources/femnist.py +141 -0
- plato/datasources/flickr30k_entities.py +362 -0
- plato/datasources/gym.py +431 -0
- plato/datasources/huggingface.py +165 -0
- plato/datasources/kinetics.py +568 -0
- plato/datasources/mnist.py +44 -0
- plato/datasources/multimodal_base.py +328 -0
- plato/datasources/pascal_voc.py +56 -0
- plato/datasources/purchase.py +94 -0
- plato/datasources/qoenflx.py +127 -0
- plato/datasources/referitgame.py +330 -0
- plato/datasources/registry.py +119 -0
- plato/datasources/self_supervised_learning.py +98 -0
- plato/datasources/stl10.py +103 -0
- plato/datasources/texas.py +94 -0
- plato/datasources/tiny_imagenet.py +64 -0
- plato/datasources/yolov8.py +85 -0
- plato/models/__init__.py +0 -0
- plato/models/cnn_encoder.py +103 -0
- plato/models/dcgan.py +116 -0
- plato/models/general_multilayer.py +254 -0
- plato/models/huggingface.py +27 -0
- plato/models/lenet5.py +113 -0
- plato/models/multilayer.py +90 -0
- plato/models/multimodal/__init__.py +0 -0
- plato/models/multimodal/base_net.py +91 -0
- plato/models/multimodal/blending.py +142 -0
- plato/models/multimodal/fc_net.py +77 -0
- plato/models/multimodal/fusion_net.py +78 -0
- plato/models/multimodal/multimodal_module.py +152 -0
- plato/models/registry.py +99 -0
- plato/models/resnet.py +190 -0
- plato/models/torch_hub.py +19 -0
- plato/models/vgg.py +113 -0
- plato/models/vit.py +166 -0
- plato/models/yolov8.py +22 -0
- plato/processors/__init__.py +0 -0
- plato/processors/base.py +35 -0
- plato/processors/compress.py +46 -0
- plato/processors/decompress.py +48 -0
- plato/processors/feature.py +51 -0
- plato/processors/feature_additive_noise.py +48 -0
- plato/processors/feature_dequantize.py +34 -0
- plato/processors/feature_gaussian.py +17 -0
- plato/processors/feature_laplace.py +15 -0
- plato/processors/feature_quantize.py +34 -0
- plato/processors/feature_randomized_response.py +50 -0
- plato/processors/feature_unbatch.py +39 -0
- plato/processors/inbound_feature_tensors.py +39 -0
- plato/processors/model.py +55 -0
- plato/processors/model_compress.py +34 -0
- plato/processors/model_decompress.py +37 -0
- plato/processors/model_decrypt.py +41 -0
- plato/processors/model_deepcopy.py +21 -0
- plato/processors/model_dequantize.py +18 -0
- plato/processors/model_dequantize_qsgd.py +61 -0
- plato/processors/model_encrypt.py +43 -0
- plato/processors/model_quantize.py +18 -0
- plato/processors/model_quantize_qsgd.py +82 -0
- plato/processors/model_randomized_response.py +34 -0
- plato/processors/outbound_feature_ndarrays.py +38 -0
- plato/processors/pipeline.py +26 -0
- plato/processors/registry.py +124 -0
- plato/processors/structured_pruning.py +57 -0
- plato/processors/unstructured_pruning.py +73 -0
- plato/samplers/__init__.py +0 -0
- plato/samplers/all_inclusive.py +41 -0
- plato/samplers/base.py +31 -0
- plato/samplers/dirichlet.py +81 -0
- plato/samplers/distribution_noniid.py +132 -0
- plato/samplers/iid.py +53 -0
- plato/samplers/label_quantity_noniid.py +119 -0
- plato/samplers/mixed.py +44 -0
- plato/samplers/mixed_label_quantity_noniid.py +128 -0
- plato/samplers/modality_iid.py +42 -0
- plato/samplers/modality_quantity_noniid.py +56 -0
- plato/samplers/orthogonal.py +99 -0
- plato/samplers/registry.py +66 -0
- plato/samplers/sample_quantity_noniid.py +123 -0
- plato/samplers/sampler_utils.py +190 -0
- plato/servers/__init__.py +0 -0
- plato/servers/base.py +1395 -0
- plato/servers/fedavg.py +281 -0
- plato/servers/fedavg_cs.py +335 -0
- plato/servers/fedavg_gan.py +74 -0
- plato/servers/fedavg_he.py +106 -0
- plato/servers/fedavg_personalized.py +57 -0
- plato/servers/mistnet.py +67 -0
- plato/servers/registry.py +52 -0
- plato/servers/split_learning.py +109 -0
- plato/trainers/__init__.py +0 -0
- plato/trainers/base.py +99 -0
- plato/trainers/basic.py +649 -0
- plato/trainers/diff_privacy.py +178 -0
- plato/trainers/gan.py +330 -0
- plato/trainers/huggingface.py +173 -0
- plato/trainers/loss_criterion.py +70 -0
- plato/trainers/lr_schedulers.py +252 -0
- plato/trainers/optimizers.py +53 -0
- plato/trainers/pascal_voc.py +80 -0
- plato/trainers/registry.py +44 -0
- plato/trainers/self_supervised_learning.py +302 -0
- plato/trainers/split_learning.py +305 -0
- plato/trainers/tracking.py +96 -0
- plato/trainers/yolov8.py +41 -0
- plato/utils/__init__.py +0 -0
- plato/utils/count_parameters.py +30 -0
- plato/utils/csv_processor.py +26 -0
- plato/utils/data_loaders.py +148 -0
- plato/utils/decorators.py +24 -0
- plato/utils/fonts.py +23 -0
- plato/utils/homo_enc.py +187 -0
- plato/utils/reinforcement_learning/__init__.py +0 -0
- plato/utils/reinforcement_learning/policies/__init__.py +0 -0
- plato/utils/reinforcement_learning/policies/base.py +161 -0
- plato/utils/reinforcement_learning/policies/ddpg.py +75 -0
- plato/utils/reinforcement_learning/policies/registry.py +32 -0
- plato/utils/reinforcement_learning/policies/sac.py +343 -0
- plato/utils/reinforcement_learning/policies/td3.py +485 -0
- plato/utils/reinforcement_learning/rl_agent.py +142 -0
- plato/utils/reinforcement_learning/rl_server.py +113 -0
- plato/utils/rl_env.py +154 -0
- plato/utils/s3.py +141 -0
- plato/utils/trainer_utils.py +21 -0
- plato/utils/unary_encoding.py +47 -0
- plato_learn-1.1.dist-info/METADATA +35 -0
- plato_learn-1.1.dist-info/RECORD +179 -0
- plato_learn-1.1.dist-info/WHEEL +4 -0
- plato_learn-1.1.dist-info/licenses/LICENSE +201 -0
plato/trainers/diff_privacy.py
ADDED
@@ -0,0 +1,178 @@

"""
The training and testing loops for PyTorch.
"""

import logging
import time

from opacus import GradSampleModule
from opacus.privacy_engine import PrivacyEngine
from opacus.utils.batch_memory_manager import BatchMemoryManager
from opacus.validators import ModuleValidator
from torch.utils.data import Subset

from plato.config import Config
from plato.trainers import basic


class Trainer(basic.Trainer):
    """A differentially private federated learning trainer, used by the client."""

    def __init__(self, model=None, **kwargs):
        """Initializing the trainer with the provided model."""
        super().__init__(model=model)

        self.max_physical_batch_size = (
            Config().trainer.max_physical_batch_size
            if hasattr(Config().trainer, "max_physical_batch_size")
            else 128
        )

        self.make_model_private()

    def make_model_private(self):
        """Make the model private for use with the differential privacy engine."""
        errors = ModuleValidator.validate(self.model, strict=False)
        if len(errors) > 0:
            self.model = ModuleValidator.fix(self.model)
            errors = ModuleValidator.validate(self.model, strict=False)
            assert len(errors) == 0

    # pylint: disable=unused-argument
    def train_model(self, config, trainset, sampler, **kwargs):
        """The default training loop that supports differential privacy."""
        batch_size = config["batch_size"]
        self.sampler = sampler
        tic = time.perf_counter()

        self.train_run_start(config)
        self.callback_handler.call_event("on_train_run_start", self, config)

        # We have to use poisson sampling to sample the data, rather than the provided sampler.
        # Replacing the poisson sampler with the provided sampler is problematic since it may
        # violate the basic theory of DP-SGD. Therefore, we need to first obtain the train subset
        # based on the provided sampler, and then create a simple dataloader on the train subset
        # without the sampler. We will finally use Opacus to recreate the dataloader from the
        # simple dataloader (with poisson sampling).
        trainset = Subset(trainset, list(sampler))
        self.train_loader = self.get_train_loader(batch_size, trainset, sampler=None)

        # Initializing the loss criterion
        _loss_criterion = self.get_loss_criterion()

        # Initializing the optimizer
        optimizer = self.get_optimizer(self.model)
        self.lr_scheduler = self.get_lr_scheduler(config, optimizer)
        optimizer = self._adjust_lr(config, self.lr_scheduler, optimizer)

        self.model.to(self.device)
        total_epochs = config["epochs"]

        logging.info(
            "[Client #%s] Using differential privacy during training.",
            self.client_id,
        )

        privacy_engine = PrivacyEngine(accountant="rdp", secure_mode=False)

        self.model, optimizer, train_loader = privacy_engine.make_private_with_epsilon(
            module=self.model,
            optimizer=optimizer,
            data_loader=self.train_loader,
            target_epsilon=config["dp_epsilon"] if "dp_epsilon" in config else 10.0,
            target_delta=config["dp_delta"] if "dp_delta" in config else 1e-5,
            epochs=total_epochs,
            max_grad_norm=config["dp_max_grad_norm"]
            if "max_grad_norm" in config
            else 1.0,
        )

        self.model.train()

        for self.current_epoch in range(1, total_epochs + 1):
            with BatchMemoryManager(
                data_loader=train_loader,
                max_physical_batch_size=self.max_physical_batch_size,
                optimizer=optimizer,
            ) as memory_safe_train_loader:
                self._loss_tracker.reset()
                self.train_epoch_start(config)
                self.callback_handler.call_event("on_train_epoch_start", self, config)

                for batch_id, (examples, labels) in enumerate(memory_safe_train_loader):
                    examples, labels = (
                        examples.to(self.device),
                        labels.to(self.device),
                    )
                    optimizer.zero_grad(set_to_none=True)

                    outputs = self.model(examples)

                    loss = _loss_criterion(outputs, labels)
                    self._loss_tracker.update(loss, labels.size(0))

                    if "create_graph" in config:
                        loss.backward(create_graph=config["create_graph"])
                    else:
                        loss.backward()

                    optimizer.step()

                    self.train_step_end(config, batch=batch_id, loss=loss)
                    self.callback_handler.call_event(
                        "on_train_step_end",
                        self,
                        config,
                        batch=batch_id,
                        loss=loss,
                    )

                self.lr_scheduler_step()

                if hasattr(optimizer, "params_state_update"):
                    optimizer.params_state_update()

                # Simulate client's speed
                if (
                    self.client_id != 0
                    and hasattr(Config().clients, "speed_simulation")
                    and Config().clients.speed_simulation
                ):
                    self.simulate_sleep_time()

                # Saving the model at the end of this epoch to a file so that
                # it can later be retrieved to respond to server requests
                # in asynchronous mode when the wall clock time is simulated
                if (
                    hasattr(Config().server, "request_update")
                    and Config().server.request_update
                ):
                    self.model.cpu()
                    training_time = time.perf_counter() - tic
                    filename = f"{self.client_id}_{self.current_epoch}_{training_time}.pth"
                    self.save_model(filename)
                    self.model.to(self.device)

                self.run_history.update_metric("train_loss", self._loss_tracker.average)
                self.train_epoch_end(config)
                self.callback_handler.call_event("on_train_epoch_end", self, config)

        self.train_run_end(config)
        self.callback_handler.call_event("on_train_run_end", self, config)

    def train_run_start(self, config):
        """
        Method called at the start of training run.
        """
        self.model = GradSampleModule(self.model)

    def train_run_end(self, config):
        """
        Method called at the end of a training run.
        """
        # After GradSampleModule() conversion, the state_dict names have a `_module` prefix
        # We will need to save the weights with the original layer names without the prefix
        self.model_state_dict = {
            k[8:] if "_module." in k else k: v
            for k, v in self.model.state_dict().items()
        }
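The trainer above wires Opacus into Plato's training loop (ModuleValidator, make_private_with_epsilon, BatchMemoryManager). For reference, the following is a minimal standalone sketch of that same Opacus flow outside of Plato; the toy model, dataset, and the epsilon, delta, and clipping values are illustrative stand-ins for the config keys dp_epsilon, dp_delta, and dp_max_grad_norm read by train_model.

# Standalone sketch of the Opacus flow used by the trainer above.
# The model, data, and hyperparameter values are illustrative only.
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from opacus import PrivacyEngine
from opacus.utils.batch_memory_manager import BatchMemoryManager
from opacus.validators import ModuleValidator

model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
model = ModuleValidator.fix(model)  # replace layers Opacus cannot handle
assert not ModuleValidator.validate(model, strict=False)

dataset = TensorDataset(torch.randn(256, 1, 28, 28), torch.randint(0, 10, (256,)))
loader = DataLoader(dataset, batch_size=64)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

privacy_engine = PrivacyEngine(accountant="rdp", secure_mode=False)
model, optimizer, loader = privacy_engine.make_private_with_epsilon(
    module=model,
    optimizer=optimizer,
    data_loader=loader,          # rewrapped with Poisson sampling
    target_epsilon=10.0,         # corresponds to config["dp_epsilon"]
    target_delta=1e-5,           # corresponds to config["dp_delta"]
    epochs=1,
    max_grad_norm=1.0,           # corresponds to config["dp_max_grad_norm"]
)

model.train()
# Cap the physical batch size while keeping the logical (Poisson-sampled) batch size.
with BatchMemoryManager(
    data_loader=loader, max_physical_batch_size=32, optimizer=optimizer
) as safe_loader:
    for examples, labels in safe_loader:
        optimizer.zero_grad(set_to_none=True)
        loss = criterion(model(examples), labels)
        loss.backward()
        optimizer.step()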
plato/trainers/gan.py
ADDED
@@ -0,0 +1,330 @@

"""
The training and testing loops for GAN models.

Reference:
https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html
"""

import logging
import math
import os

import torch
import torch.nn as nn
import torchvision
import numpy as np
import scipy

from plato.config import Config
from plato.models import registry as models_registry
from plato.trainers import basic
from plato.trainers import optimizers


class Trainer(basic.Trainer):
    """A federated learning trainer for GAN models."""

    def __init__(self, model=None, **kwargs):
        super().__init__()

        if model is None:
            model = models_registry.get()
        gan_model = model
        self.generator = gan_model.generator
        self.discriminator = gan_model.discriminator
        self.loss_criterion = gan_model.loss_criterion
        self.model = gan_model

        # Use the pre-trained InceptionV3 model as a feature extractor
        # for testing
        self.inception_model = torchvision.models.inception_v3(
            pretrained=True, aux_logits=False
        )
        # Remove the last output layer of inception
        self.inception_model.fc = nn.Identity()
        self.inception_model.eval()

        self.training_start_time = 0

    def save_model(self, filename=None, location=None):
        """Saving the model to a file."""
        model_path = Config().params["model_path"] if location is None else location
        model_name = Config().trainer.model_name

        try:
            if not os.path.exists(model_path):
                os.makedirs(model_path)
        except FileExistsError:
            pass

        if filename is not None:
            net_gen_path = f"{model_path}/Generator_{filename}"
            net_disc_path = f"{model_path}/Discriminator_{filename}"
        else:
            net_gen_path = f"{model_path}/Generator_{model_name}.pth"
            net_disc_path = f"{model_path}/Discriminator_{model_name}.pth"

        torch.save(self.generator.state_dict(), net_gen_path)
        torch.save(self.discriminator.state_dict(), net_disc_path)

        if self.client_id == 0:
            logging.info(
                "[Server #%d] Generator Model saved to %s.", os.getpid(), net_gen_path
            )
            logging.info(
                "[Server #%d] Discriminator Model saved to %s.",
                os.getpid(),
                net_disc_path,
            )
        else:
            logging.info(
                "[Client #%d] Generator Model saved to %s.",
                self.client_id,
                net_gen_path,
            )
            logging.info(
                "[Client #%d] Discriminator Model saved to %s.",
                self.client_id,
                net_disc_path,
            )

    def load_model(self, filename=None, location=None):
        """Loading pre-trained model weights from a file."""
        model_path = Config().params["model_path"] if location is None else location
        model_name = Config().trainer.model_name

        if filename is not None:
            net_gen_path = f"{model_path}/Generator_{filename}"
            net_disc_path = f"{model_path}/Discriminator_{filename}"
        else:
            net_gen_path = f"{model_path}/Generator_{model_name}.pth"
            net_disc_path = f"{model_path}/Discriminator_{model_name}.pth"

        if self.client_id == 0:
            logging.info(
                "[Server #%d] Loading a Generator model from %s.",
                os.getpid(),
                net_gen_path,
            )
            logging.info(
                "[Server #%d] Loading a Discriminator model from %s.",
                os.getpid(),
                net_disc_path,
            )
        else:
            logging.info(
                "[Client #%d] Loading a Generator model from %s.",
                self.client_id,
                net_gen_path,
            )
            logging.info(
                "[Client #%d] Loading a Discriminator model from %s.",
                self.client_id,
                net_disc_path,
            )

        self.generator.load_state_dict(torch.load(net_gen_path))
        self.discriminator.load_state_dict(torch.load(net_disc_path))

    # pylint: disable=unused-argument
    def train_model(self, config, trainset, sampler, **kwargs):
        """The main training loop in a federated learning workload.

        Arguments:
        trainset: The training dataset.
        sampler: the sampler that extracts a partition for this client.

        Returns:
        float: The training time.
        """
        batch_size = config["batch_size"]
        log_interval = 10

        logging.info("[Client #%d] Loading the dataset.", self.client_id)

        train_loader = torch.utils.data.DataLoader(
            dataset=trainset, shuffle=False, batch_size=batch_size, sampler=sampler
        )

        self.model.to(self.device)
        self.model.train()

        # self.generator.apply(self.model.weights_init)
        # self.discriminator.apply(self.model.weights_init)

        optimizer_gen = optimizers.get(self.generator)
        optimizer_disc = optimizers.get(self.discriminator)

        real_label = 1.0
        fake_label = 0.0

        epochs = config["epochs"]
        for epoch in range(1, epochs + 1):
            # Here we assume the data samples still have labels attached to them,
            # but GAN training does not need labels, so we'll just discard them
            for batch_id, (examples, _) in enumerate(train_loader):
                cur_batch_size = len(examples)
                examples = examples.to(self.device)
                label = torch.full((cur_batch_size,), real_label, dtype=torch.float)
                label = label.to(self.device)

                ############################
                # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
                ###########################
                ## Train with all-real batch
                optimizer_disc.zero_grad()
                # Forward pass real batch through D
                output = self.discriminator(examples).view(-1)
                # Calculate loss on all-real batch
                err_disc_real = self.loss_criterion(output, label)
                # Calculate gradients for D in backward pass
                err_disc_real.backward()

                ## Train with all-fake batch
                # Generate batch of latent vectors
                noise = torch.randn(
                    cur_batch_size, self.model.nz, 1, 1, device=self.device
                )
                # Generate fake image batch with G
                fake = self.generator(noise)
                label.fill_(fake_label)
                # Classify all fake batch with D
                output = self.discriminator(fake.detach()).view(-1)
                # Calculate D's loss on the all-fake batch
                err_disc_fake = self.loss_criterion(output, label)
                # Calculate the gradients for this batch, accumulated (summed)
                # with previous gradients
                err_disc_fake.backward()
                # Compute error of D as sum over the fake and the real batches
                err_disc_total = err_disc_real + err_disc_fake
                # Update D
                optimizer_disc.step()

                ############################
                # (2) Update G network: maximize log(D(G(z)))
                ###########################
                optimizer_gen.zero_grad()
                label.fill_(real_label)  # fake labels are real for generator cost
                # Since we just updated D, perform another forward pass of all-fake batch through D
                output = self.discriminator(fake).view(-1)
                # Calculate G's loss based on this output
                err_gen = self.loss_criterion(output, label)
                # Calculate gradients for G
                err_gen.backward()
                # Update G
                optimizer_gen.step()

                if batch_id % log_interval == 0:
                    if self.client_id == 0:
                        logging.info(
                            "[Server #%d] Epoch: [%d/%d][%d/%d]\tGenerator Loss: %.6f\t"
                            "Discriminator Loss: %.6f",
                            os.getpid(),
                            epoch,
                            epochs,
                            batch_id,
                            len(train_loader),
                            err_gen.data.item(),
                            err_disc_total.data.item(),
                        )
                    else:
                        logging.info(
                            "[Client #%d] Epoch: [%d/%d][%d/%d]\tGenerator Loss: %.6f\t"
                            "Discriminator Loss: %.6f",
                            self.client_id,
                            epoch,
                            epochs,
                            batch_id,
                            len(train_loader),
                            err_gen.data.item(),
                            err_disc_total.data.item(),
                        )

    def test_model(self, config, testset, sampler=None, **kwargs):
        """Test the Generator model with the Frechet Inception Distance metric."""

        self.model.to(self.device)
        self.model.eval()

        perplexity = -1

        test_loader = torch.utils.data.DataLoader(
            testset, batch_size=config["batch_size"], shuffle=True
        )

        real_features, fake_features = [], []
        with torch.no_grad():
            for real_examples, _ in test_loader:
                real_examples = real_examples.to(self.device)

                noise = torch.randn(
                    config["batch_size"], self.model.nz, 1, 1, device=self.device
                )
                fake_examples = self.generator(noise)

                # Extract the feature of real and synthetic data with
                # InceptionV3 model pre-trained on ImageNet
                self.inception_model.to(self.device)
                feature_real = self.feature_extractor(real_examples)
                feature_fake = self.feature_extractor(fake_examples)

                # Store the feature of every real and synthetic data
                real_features.extend(list(feature_real))
                fake_features.extend(list(feature_fake))

        real_features, fake_features = (
            np.stack(real_features),
            np.stack(fake_features),
        )
        # Calculate the Frechet Distance between the feature distribution
        # of real data from testset and the feature distribution of data
        # generated by the generator.
        perplexity = self.calculate_fid(real_features, fake_features)

        return perplexity

    def feature_extractor(self, inputs):
        """Extract the feature of input data with InceptionV3.

        The feature extracted from each input is a NumPy array
        of length 2048.
        """
        # Since the input to InceptionV3 needs to be at least 75x75,
        # we will pad the input image if needed.
        hpad = math.ceil((75 - inputs.size(dim=-2)) / 2)
        vpad = math.ceil((75 - inputs.size(dim=-1)) / 2)
        hpad, vpad = max(0, hpad), max(0, vpad)
        pad = nn.ZeroPad2d((hpad, hpad, vpad, vpad))
        inputs = pad(inputs)

        # Extract feature with InceptionV3
        features = None
        with torch.no_grad():
            features = self.inception_model(inputs)
        features = features.cpu()
        features = np.array(features)

        return features

    def calculate_fid(self, real_features, fake_features):
        """Calculate the Frechet Inception Distance (FID) between the
        given real data feature and the synthetic data feature.

        A lower FID indicates a better Generator model.

        The implementation is borrowed from the following link:
        https://wandb.ai/ayush-thakur/gan-evaluation/reports/How-to-Evaluate-GANs-using-Frechet-Inception-Distance-FID---Vmlldzo0MTAxOTI
        """
        # calculate mean and covariance statistics
        mu1, sigma1 = real_features.mean(axis=0), np.cov(real_features, rowvar=False)
        mu2, sigma2 = fake_features.mean(axis=0), np.cov(fake_features, rowvar=False)
        # calculate sum squared difference between means
        ssdiff = np.sum((mu1 - mu2) ** 2.0)
        # calculate sqrt of product between cov
        covmean = scipy.linalg.sqrtm(sigma1.dot(sigma2))
        # check and correct imaginary numbers from sqrt
        if np.iscomplexobj(covmean):
            covmean = covmean.real
        # calculate score
        fid = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)

        return fid
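To make the FID statistic in calculate_fid above concrete, here is a standalone sketch of the same computation on randomly generated feature vectors; the sample count and feature dimension are arbitrary stand-ins for InceptionV3 features.

# Standalone sketch of the FID statistic used by calculate_fid() above.
# The feature arrays are random stand-ins for InceptionV3 features.
import numpy as np
import scipy.linalg


def frechet_distance(real_features, fake_features):
    """Frechet distance between Gaussian fits of the two feature sets."""
    mu1, sigma1 = real_features.mean(axis=0), np.cov(real_features, rowvar=False)
    mu2, sigma2 = fake_features.mean(axis=0), np.cov(fake_features, rowvar=False)
    ssdiff = np.sum((mu1 - mu2) ** 2.0)
    covmean = scipy.linalg.sqrtm(sigma1.dot(sigma2))
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    return ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)


real = np.random.randn(200, 64)
fake = np.random.randn(200, 64) + 0.3  # shifted distribution
print(frechet_distance(real, real))    # close to zero for identical feature sets
print(frechet_distance(real, fake))    # larger for mismatched distributions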
plato/trainers/huggingface.py
ADDED
@@ -0,0 +1,173 @@

"""
Training and testing loops for HuggingFace's transformer models for natural
language processing.
"""

import math
from typing import Optional

from torch.utils.data import RandomSampler, Sampler

from transformers import (
    AutoConfig,
    AutoTokenizer,
    HfArgumentParser,
    TrainerCallback,
    LlamaTokenizer,
)
from transformers import Trainer as HuggingFaceTrainer
from transformers import TrainingArguments, default_data_collator

from plato.config import Config
from plato.trainers import basic


class SampledHuggingFaceTrainer(HuggingFaceTrainer):
    """
    Training and testing loops for HuggingFace's transformer models for natural
    language processing.
    """

    def __init__(
        self,
        model,
        args,
        train_dataset,
        eval_dataset,
        tokenizer,
        data_collator,
        sampler,
        callbacks,
    ):
        super().__init__(
            model=model,
            args=args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=tokenizer,
            data_collator=data_collator,
            callbacks=callbacks,
        )
        self.sampler = sampler

    def _get_train_sampler(self) -> Optional[Sampler]:
        if self.sampler is None:
            return RandomSampler(self.train_dataset)

        return self.sampler

    def _get_eval_sampler(self, eval_dataset) -> Optional[Sampler]:
        if self.sampler is None:
            return super()._get_eval_sampler(eval_dataset)

        return self.sampler


class Trainer(basic.Trainer):
    """The trainer for HuggingFace transformer models for natural language processing."""

    def __init__(self, model=None, callbacks=None):
        super().__init__(model)

        self.trainer = None
        self.trainer_callbacks = []
        if callbacks:
            # Huggingface needs to check callback types
            self.add_callbacks(callbacks)

        self.model.train()

        parser = HfArgumentParser(TrainingArguments)
        (self.training_args,) = parser.parse_args_into_dataclasses(
            args=[
                "--output_dir=" + Config.params["checkpoint_path"],
                "--report_to=none",
            ]
        )

        model_name = Config().trainer.model_name
        config_kwargs = {
            "cache_dir": None,
            "revision": "main",
            "use_auth_token": None,
        }
        self.config = AutoConfig.from_pretrained(model_name, **config_kwargs)

        tokenizer_kwargs = {
            "cache_dir": None,
            "use_fast": True,
            "revision": "main",
            "use_auth_token": None,
        }
        if "llama" in model_name:
            self.tokenizer = LlamaTokenizer.from_pretrained(
                model_name, config=self.config, **tokenizer_kwargs
            )
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_name, config=self.config, **tokenizer_kwargs
            )

    # pylint: disable=unused-argument
    def train_model(self, config, trainset, sampler, **kwargs):
        """The training loop for HuggingFace models.

        Arguments:
        config: A dictionary of configuration parameters.
        trainset: The training dataset.
        sampler: the sampler that extracts a partition for this client.
        """

        self.training_args.num_train_epochs = config["epochs"]
        self.training_args.per_device_train_batch_size = config["batch_size"]

        self.trainer = SampledHuggingFaceTrainer(
            model=self.model,
            args=self.training_args,
            train_dataset=trainset,
            eval_dataset=None,
            tokenizer=self.tokenizer,
            data_collator=default_data_collator,
            sampler=sampler,
            callbacks=self.trainer_callbacks,
        )

        self.trainer.train()

    def test_model(self, config, testset, sampler=None, **kwargs):  # pylint: disable=unused-argument
        """The testing loop for HuggingFace models.

        Arguments:
        config: Configuration parameters as a dictionary.
        testset: The test dataset.
        """
        self.training_args.per_device_eval_batch_size = config["batch_size"]

        self.trainer = SampledHuggingFaceTrainer(
            model=self.model,
            args=self.training_args,
            train_dataset=None,
            eval_dataset=testset,
            tokenizer=self.tokenizer,
            data_collator=default_data_collator,
            sampler=sampler,
            callbacks=None,
        )

        metrics = self.trainer.evaluate()

        try:
            perplexity = math.exp(metrics["eval_loss"])
        except OverflowError:
            perplexity = float("inf")

        return perplexity

    def add_callbacks(self, callbacks):
        """Callbacks will be handled by Huggingface instead of Plato."""
        for callback in callbacks:
            if not issubclass(callback, TrainerCallback):
                raise ValueError(
                    f"Huggingface trainer expects subclass of {TrainerCallback}, got {callback} instead."
                )
        self.trainer_callbacks.extend(callbacks)