plato-learn 1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- plato/__init__.py +1 -0
- plato/algorithms/__init__.py +0 -0
- plato/algorithms/base.py +45 -0
- plato/algorithms/fedavg.py +48 -0
- plato/algorithms/fedavg_gan.py +79 -0
- plato/algorithms/fedavg_personalized.py +48 -0
- plato/algorithms/mistnet.py +52 -0
- plato/algorithms/registry.py +39 -0
- plato/algorithms/split_learning.py +89 -0
- plato/callbacks/__init__.py +0 -0
- plato/callbacks/client.py +56 -0
- plato/callbacks/handler.py +78 -0
- plato/callbacks/server.py +139 -0
- plato/callbacks/trainer.py +124 -0
- plato/client.py +67 -0
- plato/clients/__init__.py +0 -0
- plato/clients/base.py +467 -0
- plato/clients/edge.py +103 -0
- plato/clients/fedavg_personalized.py +40 -0
- plato/clients/mistnet.py +49 -0
- plato/clients/registry.py +43 -0
- plato/clients/self_supervised_learning.py +51 -0
- plato/clients/simple.py +218 -0
- plato/clients/split_learning.py +150 -0
- plato/config.py +339 -0
- plato/datasources/__init__.py +0 -0
- plato/datasources/base.py +123 -0
- plato/datasources/celeba.py +150 -0
- plato/datasources/cifar10.py +87 -0
- plato/datasources/cifar100.py +61 -0
- plato/datasources/cinic10.py +62 -0
- plato/datasources/coco.py +119 -0
- plato/datasources/datalib/__init__.py +0 -0
- plato/datasources/datalib/audio_extraction_tools.py +137 -0
- plato/datasources/datalib/data_utils.py +124 -0
- plato/datasources/datalib/flickr30kE_utils.py +336 -0
- plato/datasources/datalib/frames_extraction_tools.py +254 -0
- plato/datasources/datalib/gym_utils/__init__.py +0 -0
- plato/datasources/datalib/gym_utils/gym_trim.py +189 -0
- plato/datasources/datalib/modality_data_anntation_tools.py +163 -0
- plato/datasources/datalib/modality_extraction_base.py +59 -0
- plato/datasources/datalib/parse_datasets.py +212 -0
- plato/datasources/datalib/refer_utils/__init__.py +0 -0
- plato/datasources/datalib/refer_utils/referitgame_utils.py +237 -0
- plato/datasources/datalib/tiny_data_tools.py +81 -0
- plato/datasources/datalib/video_transform.py +79 -0
- plato/datasources/emnist.py +64 -0
- plato/datasources/fashion_mnist.py +41 -0
- plato/datasources/feature.py +24 -0
- plato/datasources/feature_dataset.py +15 -0
- plato/datasources/femnist.py +141 -0
- plato/datasources/flickr30k_entities.py +362 -0
- plato/datasources/gym.py +431 -0
- plato/datasources/huggingface.py +165 -0
- plato/datasources/kinetics.py +568 -0
- plato/datasources/mnist.py +44 -0
- plato/datasources/multimodal_base.py +328 -0
- plato/datasources/pascal_voc.py +56 -0
- plato/datasources/purchase.py +94 -0
- plato/datasources/qoenflx.py +127 -0
- plato/datasources/referitgame.py +330 -0
- plato/datasources/registry.py +119 -0
- plato/datasources/self_supervised_learning.py +98 -0
- plato/datasources/stl10.py +103 -0
- plato/datasources/texas.py +94 -0
- plato/datasources/tiny_imagenet.py +64 -0
- plato/datasources/yolov8.py +85 -0
- plato/models/__init__.py +0 -0
- plato/models/cnn_encoder.py +103 -0
- plato/models/dcgan.py +116 -0
- plato/models/general_multilayer.py +254 -0
- plato/models/huggingface.py +27 -0
- plato/models/lenet5.py +113 -0
- plato/models/multilayer.py +90 -0
- plato/models/multimodal/__init__.py +0 -0
- plato/models/multimodal/base_net.py +91 -0
- plato/models/multimodal/blending.py +142 -0
- plato/models/multimodal/fc_net.py +77 -0
- plato/models/multimodal/fusion_net.py +78 -0
- plato/models/multimodal/multimodal_module.py +152 -0
- plato/models/registry.py +99 -0
- plato/models/resnet.py +190 -0
- plato/models/torch_hub.py +19 -0
- plato/models/vgg.py +113 -0
- plato/models/vit.py +166 -0
- plato/models/yolov8.py +22 -0
- plato/processors/__init__.py +0 -0
- plato/processors/base.py +35 -0
- plato/processors/compress.py +46 -0
- plato/processors/decompress.py +48 -0
- plato/processors/feature.py +51 -0
- plato/processors/feature_additive_noise.py +48 -0
- plato/processors/feature_dequantize.py +34 -0
- plato/processors/feature_gaussian.py +17 -0
- plato/processors/feature_laplace.py +15 -0
- plato/processors/feature_quantize.py +34 -0
- plato/processors/feature_randomized_response.py +50 -0
- plato/processors/feature_unbatch.py +39 -0
- plato/processors/inbound_feature_tensors.py +39 -0
- plato/processors/model.py +55 -0
- plato/processors/model_compress.py +34 -0
- plato/processors/model_decompress.py +37 -0
- plato/processors/model_decrypt.py +41 -0
- plato/processors/model_deepcopy.py +21 -0
- plato/processors/model_dequantize.py +18 -0
- plato/processors/model_dequantize_qsgd.py +61 -0
- plato/processors/model_encrypt.py +43 -0
- plato/processors/model_quantize.py +18 -0
- plato/processors/model_quantize_qsgd.py +82 -0
- plato/processors/model_randomized_response.py +34 -0
- plato/processors/outbound_feature_ndarrays.py +38 -0
- plato/processors/pipeline.py +26 -0
- plato/processors/registry.py +124 -0
- plato/processors/structured_pruning.py +57 -0
- plato/processors/unstructured_pruning.py +73 -0
- plato/samplers/__init__.py +0 -0
- plato/samplers/all_inclusive.py +41 -0
- plato/samplers/base.py +31 -0
- plato/samplers/dirichlet.py +81 -0
- plato/samplers/distribution_noniid.py +132 -0
- plato/samplers/iid.py +53 -0
- plato/samplers/label_quantity_noniid.py +119 -0
- plato/samplers/mixed.py +44 -0
- plato/samplers/mixed_label_quantity_noniid.py +128 -0
- plato/samplers/modality_iid.py +42 -0
- plato/samplers/modality_quantity_noniid.py +56 -0
- plato/samplers/orthogonal.py +99 -0
- plato/samplers/registry.py +66 -0
- plato/samplers/sample_quantity_noniid.py +123 -0
- plato/samplers/sampler_utils.py +190 -0
- plato/servers/__init__.py +0 -0
- plato/servers/base.py +1395 -0
- plato/servers/fedavg.py +281 -0
- plato/servers/fedavg_cs.py +335 -0
- plato/servers/fedavg_gan.py +74 -0
- plato/servers/fedavg_he.py +106 -0
- plato/servers/fedavg_personalized.py +57 -0
- plato/servers/mistnet.py +67 -0
- plato/servers/registry.py +52 -0
- plato/servers/split_learning.py +109 -0
- plato/trainers/__init__.py +0 -0
- plato/trainers/base.py +99 -0
- plato/trainers/basic.py +649 -0
- plato/trainers/diff_privacy.py +178 -0
- plato/trainers/gan.py +330 -0
- plato/trainers/huggingface.py +173 -0
- plato/trainers/loss_criterion.py +70 -0
- plato/trainers/lr_schedulers.py +252 -0
- plato/trainers/optimizers.py +53 -0
- plato/trainers/pascal_voc.py +80 -0
- plato/trainers/registry.py +44 -0
- plato/trainers/self_supervised_learning.py +302 -0
- plato/trainers/split_learning.py +305 -0
- plato/trainers/tracking.py +96 -0
- plato/trainers/yolov8.py +41 -0
- plato/utils/__init__.py +0 -0
- plato/utils/count_parameters.py +30 -0
- plato/utils/csv_processor.py +26 -0
- plato/utils/data_loaders.py +148 -0
- plato/utils/decorators.py +24 -0
- plato/utils/fonts.py +23 -0
- plato/utils/homo_enc.py +187 -0
- plato/utils/reinforcement_learning/__init__.py +0 -0
- plato/utils/reinforcement_learning/policies/__init__.py +0 -0
- plato/utils/reinforcement_learning/policies/base.py +161 -0
- plato/utils/reinforcement_learning/policies/ddpg.py +75 -0
- plato/utils/reinforcement_learning/policies/registry.py +32 -0
- plato/utils/reinforcement_learning/policies/sac.py +343 -0
- plato/utils/reinforcement_learning/policies/td3.py +485 -0
- plato/utils/reinforcement_learning/rl_agent.py +142 -0
- plato/utils/reinforcement_learning/rl_server.py +113 -0
- plato/utils/rl_env.py +154 -0
- plato/utils/s3.py +141 -0
- plato/utils/trainer_utils.py +21 -0
- plato/utils/unary_encoding.py +47 -0
- plato_learn-1.1.dist-info/METADATA +35 -0
- plato_learn-1.1.dist-info/RECORD +179 -0
- plato_learn-1.1.dist-info/WHEEL +4 -0
- plato_learn-1.1.dist-info/licenses/LICENSE +201 -0
plato/trainers/yolov8.py
ADDED
@@ -0,0 +1,41 @@
|
|
1
|
+
"""The YOLOv8 trainer for PyTorch."""
|
2
|
+
|
3
|
+
import logging
|
4
|
+
|
5
|
+
from plato.config import Config
|
6
|
+
from plato.trainers import basic
|
7
|
+
|
8
|
+
|
9
|
+
class Trainer(basic.Trainer):
    """Trainer that hands training and validation over to the wrapped YOLOv8 model."""

    # pylint: disable=unused-argument
    def train_model(self, config, trainset, sampler, **kwargs):
        """Run the training loop for YOLOv8.

        The dataset description and the epoch count are read from the global
        configuration; `trainset`, `sampler` and `kwargs` are not used here.

        Arguments:
        config: A dictionary of configuration parameters, forwarded to the
            end-of-run hooks.
        trainset: The training dataset.
        """
        data_params = Config().data.data_params
        num_epochs = Config().trainer.epochs
        self.model.train(data=data_params, epochs=num_epochs)

        # Notify the trainer itself first, then the registered callbacks.
        self.train_run_end(config)
        self.callback_handler.call_event("on_train_run_end", self, config)

    def test_model(self, config, testset, sampler=None, **kwargs):
        """Run the validation loop for YOLOv8.

        Arguments:
        config: A dictionary of configuration parameters.
        testset: The test dataset.

        Returns the `box.map50` value of the metrics produced by `model.val`.
        """
        logging.info("[%s] Started model testing.", self)

        validation_metrics = self.model.val(data=Config().data.data_params)
        return validation_metrics.box.map50
|
plato/utils/__init__.py
ADDED
File without changes
|
@@ -0,0 +1,30 @@
|
|
1
|
+
from prettytable import PrettyTable
|
2
|
+
import torch
|
3
|
+
|
4
|
+
|
5
|
+
def count_parameters(model):
    """Print a table of trainable parameter sizes and return their total.

    Every entry of `model.named_parameters()` whose `requires_grad` is true
    contributes one table row (name, element count); the others are skipped.
    """
    trainable_sizes = [
        (name, parameter.numel())
        for name, parameter in model.named_parameters()
        if parameter.requires_grad
    ]

    table = PrettyTable(["Modules", "Parameters"])
    for name, size in trainable_sizes:
        table.add_row([name, size])

    total_params = sum(size for _, size in trainable_sizes)
    print(table)
    print(f"Total Trainable Params: {total_params}")
    return total_params
|
17
|
+
|
18
|
+
|
19
|
+
# NOTE(review): these statements execute whenever the module is imported, so an
# import triggers three torch.hub loads and prints three tables — confirm that
# this is intended before relying on `count_parameters` from other modules.
resnet18 = torch.hub.load("pytorch/vision:v0.10.0", "resnet18", pretrained=True)
mobilenet = torch.hub.load("pytorch/vision:v0.10.0", "mobilenet_v2", pretrained=True)
alexnet = torch.hub.load("pytorch/vision:v0.10.0", "alexnet", pretrained=True)

# Report the models in the same order, with the same headings, as before.
for _heading, _model in (
    ("The size of ResNet-18:", resnet18),
    ("\nThe size of MobileNet:", mobilenet),
    ("\nThe size of AlexNet:", alexnet),
):
    print(_heading)
    count_parameters(_model)
|
@@ -0,0 +1,26 @@
|
|
1
|
+
"""
|
2
|
+
Utility functions that write results into a CSV file.
|
3
|
+
"""
|
4
|
+
|
5
|
+
import csv
|
6
|
+
import os
|
7
|
+
from typing import List
|
8
|
+
|
9
|
+
|
10
|
+
def initialize_csv(result_csv_file: str, logged_items: List, result_path: str) -> None:
    """Create a CSV file and write the first (header) row.

    Arguments:
    result_csv_file: Path of the CSV file to create; an existing file is
        truncated.
    logged_items: The column names written as the header row.
    result_path: Directory that is created (with parents) if it is missing.
    """
    # `exist_ok=True` avoids the race between checking for the directory and
    # creating it when several processes initialise results at the same time.
    os.makedirs(result_path, exist_ok=True)

    # The csv module requires `newline=""`; without it, platforms that
    # translate "\n" end up with an extra blank line after every row.
    with open(result_csv_file, "w", encoding="utf-8", newline="") as result_file:
        csv.writer(result_file).writerow(logged_items)
|
20
|
+
|
21
|
+
|
22
|
+
def write_csv(result_csv_file: str, new_row: List) -> None:
    """Append one row (the results of the current round) to the CSV file.

    Arguments:
    result_csv_file: Path of the CSV file; it is created if it does not exist.
    new_row: The values written as a single row.
    """
    # `newline=""` is what the csv module expects; it prevents doubled line
    # endings on platforms that translate "\n" when writing text files.
    with open(result_csv_file, "a", encoding="utf-8", newline="") as result_file:
        csv.writer(result_file).writerow(new_row)
|
@@ -0,0 +1,148 @@
|
|
1
|
+
"""
|
2
|
+
The implementation of various wrappers to support flexible combinations of data loaders.
|
3
|
+
|
4
|
+
Two types of data loader wrappers are supported:
|
5
|
+
|
6
|
+
- ParallelDataLoader
|
7
|
+
- SequentialDataLoader
|
8
|
+
|
9
|
+
One specific utilization condition is self-supervised learning, where datasets,
|
10
|
+
such as STL10, contain trainsets with and without labels, and the desired data
|
11
|
+
loader first loads the trainset with labels and then the one without labels.
|
12
|
+
We can use SequentialDataLoader for this purpose.
|
13
|
+
|
14
|
+
"""
|
15
|
+
|
16
|
+
import numpy as np
|
17
|
+
|
18
|
+
|
19
|
+
class ParallelIterator:
    """Iterator that draws one batch from every wrapped loader per step and
    hands the collected batches to the compound loader for combination."""

    def __init__(self, defined_compound_loader):
        self.defined_compound_loader = defined_compound_loader
        self.compound_loaders = defined_compound_loader.loaders
        self.loader_iters = list(map(iter, self.compound_loaders))

    def __iter__(self):
        return self

    def __next__(self):
        # The first loader to run out raises `StopIteration` from `next`,
        # which propagates out of this method and ends the whole iteration;
        # the loader with the fewest batches therefore determines the length.
        batches = []
        for loader_iter in self.loader_iters:
            batches.append(next(loader_iter))
        return self.defined_compound_loader.combine_batch(batches)

    def __len__(self):
        return len(self.defined_compound_loader)
|
41
|
+
|
42
|
+
|
43
|
+
class ParallelDataLoader:
    """Wrap several loaders so that each iteration step yields one batch from
    every loader, mimicking the `for batch in loader:` interface of a pytorch
    `DataLoader`.

    :param defined_loaders: a list or tuple of loader objects; entries that
        are `None` are dropped.

    [For example]
    Given two loaders A and B, one step takes a batch 'A_b' from A and a batch
    'B_b' from B and yields whatever `combine_batch` returns for [A_b, B_b]
    (by default, that list itself).

    The length of this loader is the minimum length among the wrapped loaders.
    """

    def __init__(self, defined_loaders):
        self.loaders = list(
            filter(lambda loader: loader is not None, defined_loaders)
        )

    def __iter__(self):
        return ParallelIterator(self)

    def __len__(self):
        return min(map(len, self.loaders))

    def combine_batch(self, batches):
        """Hook for customizing how the per-loader batches are combined."""
        return batches
|
73
|
+
|
74
|
+
|
75
|
+
class SequentialIterator:
    """An iterator to support iterating through each data loader sequentially.

    For example, there are three data loaders, A, B, and, C:
    the iteration will start from A, once A finished, B will start; then C will start.

    Thus, the length of this iterator is:
    len(A) + len(B) + len(C)
    """

    def __init__(self, defined_compound_loader):
        self.defined_compound_loader = defined_compound_loader
        self.compound_loaders = self.defined_compound_loader.loaders
        self.loader_iters = [iter(loader) for loader in self.compound_loaders]

        # Cumulative batch counts: entry i is the global batch index at which
        # loader i is exhausted and loader i + 1 takes over.
        self.loaders_len = [len(loader) for loader in self.compound_loaders]
        self.loaders_batch_bound = np.cumsum(self.loaders_len, axis=0)

        self.num_loaders = len(self.loaders_len)
        self.batch_idx = 0

    def __iter__(self):
        return self

    def __next__(self):
        # With nothing to iterate over, finish immediately instead of indexing
        # into an empty list of loader iterators.
        if self.num_loaders == 0:
            raise StopIteration

        # Locate the loader that owns the current global batch index.
        cur_loader_idx = int(np.digitize(self.batch_idx, self.loaders_batch_bound))

        # Once every batch has been served, the index points one past the last
        # loader. Fall back to the last loader: its iterator is exhausted, so
        # the `StopIteration` it raises terminates this iterator as well.
        cur_loader_idx = min(cur_loader_idx, self.num_loaders - 1)

        batch = next(self.loader_iters[cur_loader_idx])
        self.batch_idx += 1

        return self.defined_compound_loader.process_batch(batch)

    def __len__(self):
        # Total number of batches across all loaders. The previous
        # implementation read a `target_loader` attribute that was never set.
        return sum(self.loaders_len)
|
123
|
+
|
124
|
+
|
125
|
+
class SequentialDataLoader:
    """Wrap several loaders so that they are iterated one after another,
    mimicking the `for batch in loader:` interface of a pytorch `DataLoader`.

    :param defined_loaders: A list or tuple containing loader objects; entries
        that are `None` are dropped.

    Each yielded item is a single batch from the loader that is currently
    active, passed through `process_batch`. The length of this loader is the
    sum of the lengths of the wrapped loaders.
    """

    def __init__(self, defined_loaders):
        self.loaders = list(
            filter(lambda loader: loader is not None, defined_loaders)
        )

    def __iter__(self):
        return SequentialIterator(self)

    def __len__(self):
        return sum(map(len, self.loaders))

    def process_batch(self, batch):
        """Hook for customizing each batch before it is yielded."""
        return batch
|
@@ -0,0 +1,24 @@
|
|
1
|
+
"""Useful decorators."""
|
2
|
+
|
3
|
+
import time
|
4
|
+
from functools import wraps
|
5
|
+
|
6
|
+
|
7
|
+
def timeit(func_timed):
    """Decorator that reports how long a call to 'func_timed' takes.

    The wrapper prints the elapsed wall-clock time (from `time.perf_counter`)
    and returns `elapsed` alone when the wrapped function returned `None`,
    otherwise the pair `(output, elapsed)`.
    """

    @wraps(func_timed)
    def timed(*args, **kwargs):
        tick = time.perf_counter()
        result = func_timed(*args, **kwargs)
        elapsed = time.perf_counter() - tick

        report = '"{}" took {:.2f} seconds to execute.'.format(
            func_timed.__name__, elapsed
        )
        print(report)

        return elapsed if result is None else (result, elapsed)

    return timed
|
plato/utils/fonts.py
ADDED
@@ -0,0 +1,23 @@
|
|
1
|
+
"""
|
2
|
+
Colours and fonts for logging messages
|
3
|
+
"""
|
4
|
+
|
5
|
+
|
6
|
+
def colourize(message, colour="yellow", style="bold"):
    """Wrap `message` in the ANSI escape codes for the given colour and style.

    Raises a `ValueError` listing the supported options when either the colour
    or the style is unknown.
    """
    colours = {
        "green": "\033[92m",
        "blue": "\033[94m",
        "yellow": "\033[93m",
        "red": "\033[91m",
    }
    styles = {"standard": "", "bold": "\033[1m", "underline": "\033[4m"}
    reset = "\033[0m"

    if colour not in colours or style not in styles:
        raise ValueError(
            f"Your colour '{colour}' or your style '{style}' is not supported."
            f"\nThe supported colours are: {', '.join(colours)}. \nThe supported styles are: {', '.join(styles)}."
        )

    prefix = colours[colour] + styles[style]
    return prefix + message + reset
|
plato/utils/homo_enc.py
ADDED
@@ -0,0 +1,187 @@
|
|
1
|
+
"""
|
2
|
+
Utility functions for homomorphic encryption with TenSEAL.
|
3
|
+
"""
|
4
|
+
|
5
|
+
import os
|
6
|
+
import pickle
|
7
|
+
import zlib
|
8
|
+
from typing import OrderedDict
|
9
|
+
|
10
|
+
import numpy as np
|
11
|
+
import tenseal as ts
|
12
|
+
import torch
|
13
|
+
|
14
|
+
|
15
|
+
def get_ckks_context():
    """Obtain a TenSEAL context for encryption and decryption.

    The serialized context is read from `.ckks_context/context` when that file
    exists. Otherwise a new CKKS context is created, written to that file
    together with its secret key, and returned.
    """
    context_dir = ".ckks_context/"
    context_name = "context"
    context_path = os.path.join(context_dir, context_name)

    try:
        with open(context_path, "rb") as f:
            return ts.context_from(f.read())
    except FileNotFoundError:
        # Create a new context if it does not exist. `exist_ok=True` keeps
        # this safe when another process creates the directory concurrently.
        os.makedirs(context_dir, exist_ok=True)

        context = ts.context(
            ts.SCHEME_TYPE.CKKS,
            poly_modulus_degree=8192,
            coeff_mod_bit_sizes=[60, 40, 40, 60],
        )
        context.global_scale = 2**40

        # NOTE(review): the secret key is persisted alongside the context;
        # confirm the file's location and permissions are acceptable.
        # The `with` block closes the file, so no explicit close is needed.
        with open(context_path, "wb") as f:
            f.write(context.serialize(save_secret_key=True))

        return context
|
39
|
+
|
40
|
+
|
41
|
+
def encrypt_weights(
    plain_weights,
    serialize=True,
    context=None,
    indices=None,
):
    """Flatten the model weights and encrypt the selected ones.

    Arguments:
    plain_weights: A mapping whose values are the weight arrays/tensors.
    serialize: Whether the encrypted vector is serialized (passed on to
        `_encrypt`).
    context: The TenSEAL context used for encryption; must not be None.
    indices: Positions in the flattened weight vector to encrypt. When None,
        every position is encrypted. The caller's sequence is left unmodified.

    Returns the message built by `wrap_encrypted_model` from the unencrypted
    weights, the encrypted weights (None if nothing was selected) and the
    sorted indices.
    """
    assert context is not None

    # Step 1: flatten all weight tensors to a vector. A single concatenation
    # avoids re-copying the growing vector for every tensor; the leading empty
    # float64 array keeps the same dtype promotion as repeated `np.append`.
    flattened = [np.ravel(weight) for weight in plain_weights.values()]
    weights_vector = np.concatenate([np.array([]), *flattened])

    # Step 2: set up the indices for encrypted weights
    if indices is None:
        encrypt_indices = np.arange(len(weights_vector)).tolist()
    else:
        # Sort a copy so that the caller's list is not reordered in place.
        encrypt_indices = sorted(indices)

    # Step 3: separate weights into encrypted and unencrypted ones
    unencrypted_weights = np.delete(weights_vector, encrypt_indices)
    weights_to_enc = weights_vector[encrypt_indices]

    if len(weights_to_enc) == 0:
        encrypted_weights = None
    else:
        encrypted_weights = _encrypt(weights_to_enc, context, serialize)

    # Finish by wrapping up the information
    return wrap_encrypted_model(
        unencrypted_weights, encrypted_weights, encrypt_indices
    )
|
78
|
+
|
79
|
+
|
80
|
+
def _encrypt(data_vector, context, serialize=True):
    """Build a CKKS vector from `data_vector` under `context`.

    Returns the result of the vector's `serialize()` when `serialize` is
    truthy, otherwise the vector object itself.
    """
    encrypted_vector = ts.ckks_vector(context, data_vector)
    return encrypted_vector.serialize() if serialize else encrypted_vector
|
85
|
+
|
86
|
+
|
87
|
+
def deserialize_weights(serialized_weights, context):
    """Deserialize the encrypted weights (not decrypted yet).

    The entry stored under "encrypted_weights" is turned into a lazily loaded
    CKKS vector linked to `context`, unless it is None; every other entry is
    copied through unchanged. Entry order is preserved.
    """
    deserialized_weights = OrderedDict()

    for name, weight in serialized_weights.items():
        if name != "encrypted_weights" or weight is None:
            deserialized_weights[name] = weight
            continue

        lazy_vector = ts.lazy_ckks_vector_from(weight)
        lazy_vector.link_context(context)
        deserialized_weights[name] = lazy_vector

    return deserialized_weights
|
99
|
+
|
100
|
+
|
101
|
+
def decrypt_weights(data, weight_shapes=None, para_nums=None):
    """Decrypt the vector and restore model weights according to the shapes.

    Arguments:
    data: The message holding unencrypted weights, encrypted weights and the
        indices of the encrypted positions.
    weight_shapes: A mapping from weight name to the shape to restore.
    para_nums: A mapping whose values are the number of elements per weight,
        used to split the flat vector.

    Returns an ordered mapping from weight name to the reshaped weight, as a
    torch tensor when the conversion succeeds and as a numpy array otherwise.
    """
    vector_length = list(para_nums.values())

    # Step 1: decrypt the encrypted weights
    unencrypted_weights, encrypted_weights, indices = extract_encrypted_model(data)

    if len(indices) == 0:
        # Nothing was encrypted: the plain vector is already complete.
        flat_weights = unencrypted_weights
    else:
        decrypted_vector = np.array(encrypted_weights.decrypt())

        vector_size = len(unencrypted_weights) + len(indices)
        flat_weights = np.zeros(vector_size)
        flat_weights[indices] = decrypted_vector

        # Every position that was not encrypted takes the plain values.
        unencrypted_indices = np.delete(range(vector_size), indices)
        flat_weights[unencrypted_indices] = unencrypted_weights

    # Step 2: rebuild the original weight vector. Splitting at the cumulative
    # sizes leaves one trailing remainder chunk, which is discarded.
    chunks = np.split(flat_weights, np.cumsum(vector_length))[:-1]

    decrypted_weights = OrderedDict()
    for weight_index, (name, shape) in enumerate(weight_shapes.items()):
        reshaped = chunks[weight_index].reshape(shape)
        try:
            decrypted_weights[name] = torch.from_numpy(reshaped)
        except Exception:
            # Conversion failed: keep the numpy array and let the caller
            # handle it.
            decrypted_weights[name] = reshaped

    return decrypted_weights
|
140
|
+
|
141
|
+
|
142
|
+
def wrap_encrypted_model(unencrypted_weights, encrypted_weights, indices):
    """Bundle the parts of an encrypted model into the dict that is exchanged
    between server and client."""
    return dict(
        unencrypted_weights=unencrypted_weights,
        encrypted_weights=encrypted_weights,
        indices=indices,
    )
|
151
|
+
|
152
|
+
|
153
|
+
def extract_encrypted_model(data):
    """Extract the information carried by an encrypted-model message.

    Returns the tuple (unencrypted weights, encrypted weights, indices).
    """
    return (
        data["unencrypted_weights"],
        data["encrypted_weights"],
        data["indices"],
    )
|
160
|
+
|
161
|
+
|
162
|
+
def indices_to_bitmap(indices):
    """Turn a list of indices into a compressed bitmap.

    An empty list is returned unchanged. Otherwise a bit array with ones at
    the given positions is packed into bytes, pickled and zlib-compressed.
    """
    if indices == []:
        # In case of empty list
        return indices

    bit_array = np.zeros(np.max(indices) + 1, dtype=np.int8)
    bit_array[indices] = 1

    # Compress the packed bitmap before sending it out
    return zlib.compress(pickle.dumps(np.packbits(bit_array)))
|
175
|
+
|
176
|
+
|
177
|
+
def bitmap_to_indices(bitmap):
    """Translate a compressed bitmap back to a list of indices.

    An empty list is returned unchanged. Otherwise the bitmap is decompressed,
    unpickled and unpacked, and the positions of the set bits are returned.
    """
    if bitmap == []:
        # In case of empty list
        return bitmap

    # NOTE(review): `pickle.loads` executes arbitrary code for crafted input;
    # confirm that bitmaps only ever come from trusted peers.
    packed_bits = pickle.loads(zlib.decompress(bitmap))
    bit_array = np.unpackbits(packed_bits)

    (set_positions,) = np.where(bit_array == 1)
    return set_positions.tolist()
|
File without changes
|
File without changes
|
@@ -0,0 +1,161 @@
|
|
1
|
+
import copy
|
2
|
+
import logging
|
3
|
+
import os
|
4
|
+
import random
|
5
|
+
from abc import ABC, abstractmethod
|
6
|
+
|
7
|
+
import numpy as np
|
8
|
+
import torch
|
9
|
+
import torch.nn.functional as F
|
10
|
+
from plato.config import Config
|
11
|
+
from torch import nn
|
12
|
+
|
13
|
+
|
14
|
+
class ReplayMemory:
    """A fixed-capacity ring buffer of (state, action, reward, next_state,
    done) transitions backed by numpy arrays."""

    def __init__(self, state_dim, action_dim, capacity, seed):
        # NOTE(review): the seed is applied to Python's `random` module, while
        # `sample` draws from `np.random` — confirm whether sampling is meant
        # to be reproducible.
        random.seed(seed)
        self.device = Config().device()
        self.capacity = int(capacity)

        # `ptr` is the slot written next; `size` counts the filled slots.
        self.ptr = 0
        self.size = 0

        rows = self.capacity
        self.state = np.zeros((rows, state_dim))
        self.action = np.zeros((rows, action_dim))
        self.reward = np.zeros((rows, 1))
        self.next_state = np.zeros((rows, state_dim))
        self.done = np.zeros((rows, 1))

    def push(self, data):
        """Store one transition, overwriting the oldest slot once full.

        `data` is indexed at positions 0 to 4 for state, action, reward,
        next state and done flag.
        """
        slot = self.ptr
        buffers = (self.state, self.action, self.reward, self.next_state, self.done)
        for position, buffer in enumerate(buffers):
            buffer[slot] = data[position]

        self.ptr = (slot + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def sample(self):
        """Draw a random batch (with replacement) from the filled slots.

        The batch size is read from `Config().algorithm.batch_size`.
        """
        batch_size = int(Config().algorithm.batch_size)
        ind = np.random.randint(0, self.size, size=batch_size)

        return (
            self.state[ind],
            self.action[ind],
            self.reward[ind],
            self.next_state[ind],
            self.done[ind],
        )

    def __len__(self):
        return self.size
|
53
|
+
|
54
|
+
|
55
|
+
class Actor(nn.Module):
    """Three-layer fully connected network that maps a state to an action
    whose components are bounded by `max_action` through a tanh output."""

    def __init__(self, state_dim, action_dim, max_action):
        super().__init__()

        self.l1 = nn.Linear(state_dim, 400)
        self.l2 = nn.Linear(400, 300)
        self.l3 = nn.Linear(300, action_dim)

        self.max_action = max_action

    def forward(self, x):
        hidden = F.relu(self.l2(F.relu(self.l1(x))))
        return self.max_action * torch.tanh(self.l3(hidden))
|
70
|
+
|
71
|
+
|
72
|
+
class Critic(nn.Module):
    """Three-layer fully connected network that maps a concatenated
    (state, action) pair to a single output value."""

    def __init__(self, state_dim, action_dim):
        super().__init__()

        self.l1 = nn.Linear(state_dim + action_dim, 400)
        self.l2 = nn.Linear(400, 300)
        self.l3 = nn.Linear(300, 1)

    def forward(self, x, u):
        # Join state and action along dimension 1 before the first layer.
        joint = torch.cat([x, u], 1)
        hidden = F.relu(self.l2(F.relu(self.l1(joint))))
        return self.l3(hidden)
|
85
|
+
|
86
|
+
|
87
|
+
class Policy(ABC):
    """A simple example of DRL policy.

    Holds an actor and a critic network, a deep copy of each as its target
    network, one Adam optimizer per network, and a replay buffer. Subclasses
    provide `select_action` and `update`.
    """

    def __init__(self, state_dim, action_dim):
        self.max_action = Config().algorithm.max_action
        self.device = Config().device()
        self.total_it = 0

        # Initialize NNs
        self.actor = Actor(state_dim, action_dim, self.max_action).to(self.device)
        self.actor_target = copy.deepcopy(self.actor)
        self.actor_optimizer = torch.optim.Adam(
            self.actor.parameters(), lr=Config().algorithm.learning_rate
        )

        self.critic = Critic(state_dim, action_dim).to(self.device)
        self.critic_target = copy.deepcopy(self.critic)
        self.critic_optimizer = torch.optim.Adam(
            self.critic.parameters(), lr=Config().algorithm.learning_rate
        )

        # Initialize replay memory
        self.replay_buffer = ReplayMemory(
            state_dim,
            action_dim,
            Config().algorithm.replay_size,
            Config().algorithm.replay_seed,
        )

    def save_model(self, ep=None):
        """Save the actor, the critic and their optimizers to files.

        Files are written under `./models/<model_name>/`; when `ep` is given,
        each file name is prefixed with `iter<ep>_`.
        """
        model_name = Config().algorithm.model_name
        model_path = f"./models/{model_name}/"
        # `exist_ok=True` avoids failing when the directory is created
        # concurrently between an existence check and the creation.
        os.makedirs(model_path, exist_ok=True)
        if ep is not None:
            model_path += "iter" + str(ep) + "_"

        torch.save(self.actor.state_dict(), model_path + "actor.pth")
        torch.save(
            self.actor_optimizer.state_dict(), model_path + "actor_optimizer.pth"
        )
        torch.save(self.critic.state_dict(), model_path + "critic.pth")
        torch.save(
            self.critic_optimizer.state_dict(), model_path + "critic_optimizer.pth"
        )

        logging.info("[RL Agent] Model saved to %s.", model_path)

    def load_model(self, ep=None):
        """Load pre-trained actor, critic and optimizer states from files.

        Uses the same file naming as `save_model`. Tensors are mapped onto
        `self.device` while loading, so a checkpoint saved on a device that is
        unavailable here (for example a GPU checkpoint on a CPU-only machine)
        can still be restored.
        """
        model_name = Config().algorithm.model_name
        model_path = f"./models/{model_name}/"
        if ep is not None:
            model_path += "iter" + str(ep) + "_"

        logging.info("[RL Agent] Loading a model from %s.", model_path)

        self.actor.load_state_dict(
            torch.load(model_path + "actor.pth", map_location=self.device)
        )
        self.actor_optimizer.load_state_dict(
            torch.load(model_path + "actor_optimizer.pth", map_location=self.device)
        )
        self.critic.load_state_dict(
            torch.load(model_path + "critic.pth", map_location=self.device)
        )
        self.critic_optimizer.load_state_dict(
            torch.load(model_path + "critic_optimizer.pth", map_location=self.device)
        )

    @abstractmethod
    def select_action(self, state, hidden=None, test=False):
        """Select action from policy."""
        raise NotImplementedError("Please Implement this method")

    @abstractmethod
    def update(self):
        """Update policy."""
        raise NotImplementedError("Please Implement this method")
|