plato-learn 1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (179) hide show
  1. plato/__init__.py +1 -0
  2. plato/algorithms/__init__.py +0 -0
  3. plato/algorithms/base.py +45 -0
  4. plato/algorithms/fedavg.py +48 -0
  5. plato/algorithms/fedavg_gan.py +79 -0
  6. plato/algorithms/fedavg_personalized.py +48 -0
  7. plato/algorithms/mistnet.py +52 -0
  8. plato/algorithms/registry.py +39 -0
  9. plato/algorithms/split_learning.py +89 -0
  10. plato/callbacks/__init__.py +0 -0
  11. plato/callbacks/client.py +56 -0
  12. plato/callbacks/handler.py +78 -0
  13. plato/callbacks/server.py +139 -0
  14. plato/callbacks/trainer.py +124 -0
  15. plato/client.py +67 -0
  16. plato/clients/__init__.py +0 -0
  17. plato/clients/base.py +467 -0
  18. plato/clients/edge.py +103 -0
  19. plato/clients/fedavg_personalized.py +40 -0
  20. plato/clients/mistnet.py +49 -0
  21. plato/clients/registry.py +43 -0
  22. plato/clients/self_supervised_learning.py +51 -0
  23. plato/clients/simple.py +218 -0
  24. plato/clients/split_learning.py +150 -0
  25. plato/config.py +339 -0
  26. plato/datasources/__init__.py +0 -0
  27. plato/datasources/base.py +123 -0
  28. plato/datasources/celeba.py +150 -0
  29. plato/datasources/cifar10.py +87 -0
  30. plato/datasources/cifar100.py +61 -0
  31. plato/datasources/cinic10.py +62 -0
  32. plato/datasources/coco.py +119 -0
  33. plato/datasources/datalib/__init__.py +0 -0
  34. plato/datasources/datalib/audio_extraction_tools.py +137 -0
  35. plato/datasources/datalib/data_utils.py +124 -0
  36. plato/datasources/datalib/flickr30kE_utils.py +336 -0
  37. plato/datasources/datalib/frames_extraction_tools.py +254 -0
  38. plato/datasources/datalib/gym_utils/__init__.py +0 -0
  39. plato/datasources/datalib/gym_utils/gym_trim.py +189 -0
  40. plato/datasources/datalib/modality_data_anntation_tools.py +163 -0
  41. plato/datasources/datalib/modality_extraction_base.py +59 -0
  42. plato/datasources/datalib/parse_datasets.py +212 -0
  43. plato/datasources/datalib/refer_utils/__init__.py +0 -0
  44. plato/datasources/datalib/refer_utils/referitgame_utils.py +237 -0
  45. plato/datasources/datalib/tiny_data_tools.py +81 -0
  46. plato/datasources/datalib/video_transform.py +79 -0
  47. plato/datasources/emnist.py +64 -0
  48. plato/datasources/fashion_mnist.py +41 -0
  49. plato/datasources/feature.py +24 -0
  50. plato/datasources/feature_dataset.py +15 -0
  51. plato/datasources/femnist.py +141 -0
  52. plato/datasources/flickr30k_entities.py +362 -0
  53. plato/datasources/gym.py +431 -0
  54. plato/datasources/huggingface.py +165 -0
  55. plato/datasources/kinetics.py +568 -0
  56. plato/datasources/mnist.py +44 -0
  57. plato/datasources/multimodal_base.py +328 -0
  58. plato/datasources/pascal_voc.py +56 -0
  59. plato/datasources/purchase.py +94 -0
  60. plato/datasources/qoenflx.py +127 -0
  61. plato/datasources/referitgame.py +330 -0
  62. plato/datasources/registry.py +119 -0
  63. plato/datasources/self_supervised_learning.py +98 -0
  64. plato/datasources/stl10.py +103 -0
  65. plato/datasources/texas.py +94 -0
  66. plato/datasources/tiny_imagenet.py +64 -0
  67. plato/datasources/yolov8.py +85 -0
  68. plato/models/__init__.py +0 -0
  69. plato/models/cnn_encoder.py +103 -0
  70. plato/models/dcgan.py +116 -0
  71. plato/models/general_multilayer.py +254 -0
  72. plato/models/huggingface.py +27 -0
  73. plato/models/lenet5.py +113 -0
  74. plato/models/multilayer.py +90 -0
  75. plato/models/multimodal/__init__.py +0 -0
  76. plato/models/multimodal/base_net.py +91 -0
  77. plato/models/multimodal/blending.py +142 -0
  78. plato/models/multimodal/fc_net.py +77 -0
  79. plato/models/multimodal/fusion_net.py +78 -0
  80. plato/models/multimodal/multimodal_module.py +152 -0
  81. plato/models/registry.py +99 -0
  82. plato/models/resnet.py +190 -0
  83. plato/models/torch_hub.py +19 -0
  84. plato/models/vgg.py +113 -0
  85. plato/models/vit.py +166 -0
  86. plato/models/yolov8.py +22 -0
  87. plato/processors/__init__.py +0 -0
  88. plato/processors/base.py +35 -0
  89. plato/processors/compress.py +46 -0
  90. plato/processors/decompress.py +48 -0
  91. plato/processors/feature.py +51 -0
  92. plato/processors/feature_additive_noise.py +48 -0
  93. plato/processors/feature_dequantize.py +34 -0
  94. plato/processors/feature_gaussian.py +17 -0
  95. plato/processors/feature_laplace.py +15 -0
  96. plato/processors/feature_quantize.py +34 -0
  97. plato/processors/feature_randomized_response.py +50 -0
  98. plato/processors/feature_unbatch.py +39 -0
  99. plato/processors/inbound_feature_tensors.py +39 -0
  100. plato/processors/model.py +55 -0
  101. plato/processors/model_compress.py +34 -0
  102. plato/processors/model_decompress.py +37 -0
  103. plato/processors/model_decrypt.py +41 -0
  104. plato/processors/model_deepcopy.py +21 -0
  105. plato/processors/model_dequantize.py +18 -0
  106. plato/processors/model_dequantize_qsgd.py +61 -0
  107. plato/processors/model_encrypt.py +43 -0
  108. plato/processors/model_quantize.py +18 -0
  109. plato/processors/model_quantize_qsgd.py +82 -0
  110. plato/processors/model_randomized_response.py +34 -0
  111. plato/processors/outbound_feature_ndarrays.py +38 -0
  112. plato/processors/pipeline.py +26 -0
  113. plato/processors/registry.py +124 -0
  114. plato/processors/structured_pruning.py +57 -0
  115. plato/processors/unstructured_pruning.py +73 -0
  116. plato/samplers/__init__.py +0 -0
  117. plato/samplers/all_inclusive.py +41 -0
  118. plato/samplers/base.py +31 -0
  119. plato/samplers/dirichlet.py +81 -0
  120. plato/samplers/distribution_noniid.py +132 -0
  121. plato/samplers/iid.py +53 -0
  122. plato/samplers/label_quantity_noniid.py +119 -0
  123. plato/samplers/mixed.py +44 -0
  124. plato/samplers/mixed_label_quantity_noniid.py +128 -0
  125. plato/samplers/modality_iid.py +42 -0
  126. plato/samplers/modality_quantity_noniid.py +56 -0
  127. plato/samplers/orthogonal.py +99 -0
  128. plato/samplers/registry.py +66 -0
  129. plato/samplers/sample_quantity_noniid.py +123 -0
  130. plato/samplers/sampler_utils.py +190 -0
  131. plato/servers/__init__.py +0 -0
  132. plato/servers/base.py +1395 -0
  133. plato/servers/fedavg.py +281 -0
  134. plato/servers/fedavg_cs.py +335 -0
  135. plato/servers/fedavg_gan.py +74 -0
  136. plato/servers/fedavg_he.py +106 -0
  137. plato/servers/fedavg_personalized.py +57 -0
  138. plato/servers/mistnet.py +67 -0
  139. plato/servers/registry.py +52 -0
  140. plato/servers/split_learning.py +109 -0
  141. plato/trainers/__init__.py +0 -0
  142. plato/trainers/base.py +99 -0
  143. plato/trainers/basic.py +649 -0
  144. plato/trainers/diff_privacy.py +178 -0
  145. plato/trainers/gan.py +330 -0
  146. plato/trainers/huggingface.py +173 -0
  147. plato/trainers/loss_criterion.py +70 -0
  148. plato/trainers/lr_schedulers.py +252 -0
  149. plato/trainers/optimizers.py +53 -0
  150. plato/trainers/pascal_voc.py +80 -0
  151. plato/trainers/registry.py +44 -0
  152. plato/trainers/self_supervised_learning.py +302 -0
  153. plato/trainers/split_learning.py +305 -0
  154. plato/trainers/tracking.py +96 -0
  155. plato/trainers/yolov8.py +41 -0
  156. plato/utils/__init__.py +0 -0
  157. plato/utils/count_parameters.py +30 -0
  158. plato/utils/csv_processor.py +26 -0
  159. plato/utils/data_loaders.py +148 -0
  160. plato/utils/decorators.py +24 -0
  161. plato/utils/fonts.py +23 -0
  162. plato/utils/homo_enc.py +187 -0
  163. plato/utils/reinforcement_learning/__init__.py +0 -0
  164. plato/utils/reinforcement_learning/policies/__init__.py +0 -0
  165. plato/utils/reinforcement_learning/policies/base.py +161 -0
  166. plato/utils/reinforcement_learning/policies/ddpg.py +75 -0
  167. plato/utils/reinforcement_learning/policies/registry.py +32 -0
  168. plato/utils/reinforcement_learning/policies/sac.py +343 -0
  169. plato/utils/reinforcement_learning/policies/td3.py +485 -0
  170. plato/utils/reinforcement_learning/rl_agent.py +142 -0
  171. plato/utils/reinforcement_learning/rl_server.py +113 -0
  172. plato/utils/rl_env.py +154 -0
  173. plato/utils/s3.py +141 -0
  174. plato/utils/trainer_utils.py +21 -0
  175. plato/utils/unary_encoding.py +47 -0
  176. plato_learn-1.1.dist-info/METADATA +35 -0
  177. plato_learn-1.1.dist-info/RECORD +179 -0
  178. plato_learn-1.1.dist-info/WHEEL +4 -0
  179. plato_learn-1.1.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,41 @@
1
+ """The YOLOv8 trainer for PyTorch."""
2
+
3
+ import logging
4
+
5
+ from plato.config import Config
6
+ from plato.trainers import basic
7
+
8
+
9
class Trainer(basic.Trainer):
    """A trainer that hands YOLOv8 training and validation over to the wrapped model."""

    # pylint: disable=unused-argument
    def train_model(self, config, trainset, sampler, **kwargs):
        """Run the YOLOv8 training loop.

        The dataset description and the number of epochs are read from the
        global `Config()`; `trainset`, `sampler` and `kwargs` are not used by
        this implementation.

        Arguments:
        config: A dictionary of configuration parameters.
        trainset: The training dataset.
        sampler: The sampler for the training dataset.
        """
        run_training = self.model.train
        run_training(
            data=Config().data.data_params,
            epochs=Config().trainer.epochs,
        )

        # Signal the end of the training run to the trainer and its callbacks.
        self.train_run_end(config)
        self.callback_handler.call_event("on_train_run_end", self, config)

    def test_model(self, config, testset, sampler=None, **kwargs):
        """Run the YOLOv8 validation loop.

        Arguments:
        config: A dictionary of configuration parameters.
        testset: The test dataset.

        Returns:
        The `box.map50` attribute of the metrics object returned by
        `self.model.val`.
        """
        logging.info("[%s] Started model testing.", self)

        run_validation = self.model.val
        validation_metrics = run_validation(data=Config().data.data_params)

        return validation_metrics.box.map50
File without changes
@@ -0,0 +1,30 @@
1
+ from prettytable import PrettyTable
2
+ import torch
3
+
4
+
5
def count_parameters(model):
    """Print a per-module table of trainable parameter counts.

    Parameters whose `requires_grad` flag is false are skipped.

    Arguments:
    model: An object exposing `named_parameters()`.

    Returns:
    The total number of trainable parameters.
    """
    trainable_sizes = [
        (name, parameter.numel())
        for name, parameter in model.named_parameters()
        if parameter.requires_grad
    ]

    table = PrettyTable(["Modules", "Parameters"])
    for name, size in trainable_sizes:
        table.add_row([name, size])
    total_params = sum(size for _, size in trainable_sizes)

    print(table)
    print(f"Total Trainable Params: {total_params}")
    return total_params
17
+
18
+
19
# The statements below run when this module is executed or imported: they load
# three pretrained torchvision models through torch.hub and print the number of
# trainable parameters of each one.
resnet18 = torch.hub.load("pytorch/vision:v0.10.0", "resnet18", pretrained=True)
mobilenet = torch.hub.load("pytorch/vision:v0.10.0", "mobilenet_v2", pretrained=True)
alexnet = torch.hub.load("pytorch/vision:v0.10.0", "alexnet", pretrained=True)

for _heading, _network in (
    ("The size of ResNet-18:", resnet18),
    ("\nThe size of MobileNet:", mobilenet),
    ("\nThe size of AlexNet:", alexnet),
):
    print(_heading)
    count_parameters(_network)

# Keep the module namespace limited to the three model objects.
del _heading, _network
@@ -0,0 +1,26 @@
1
+ """
2
+ Utility functions that write results into a CSV file.
3
+ """
4
+
5
+ import csv
6
+ import os
7
+ from typing import List
8
+
9
+
10
def initialize_csv(result_csv_file: str, logged_items: List, result_path: str) -> None:
    """Create a CSV file and write the first (header) row.

    Arguments:
    result_csv_file: Path of the CSV file to create; an existing file is truncated.
    logged_items: The column names written as the first row.
    result_path: Directory that is created (including parents) if it does not exist.
    """
    # Create a new directory if it does not exist. `exist_ok=True` keeps this
    # safe when another process creates the directory between the check and
    # the call.
    if not os.path.exists(result_path):
        os.makedirs(result_path, exist_ok=True)

    # The csv module requires files to be opened with newline="" so that it
    # controls line endings itself (otherwise rows end in "\r\r\n" on Windows).
    with open(result_csv_file, "w", encoding="utf-8", newline="") as result_file:
        result_writer = csv.writer(result_file)
        result_writer.writerow(logged_items)
20
+
21
+
22
def write_csv(result_csv_file: str, new_row: List) -> None:
    """Append one row with the results of the current round to a CSV file.

    Arguments:
    result_csv_file: Path of the CSV file to append to (created if missing).
    new_row: The values written as a single row.
    """
    # newline="" lets the csv module control line endings, as its
    # documentation requires for files handed to csv.writer.
    with open(result_csv_file, "a", encoding="utf-8", newline="") as result_file:
        result_writer = csv.writer(result_file)
        result_writer.writerow(new_row)
@@ -0,0 +1,148 @@
1
+ """
2
+ The implementation of various wrappers to support flexible combinations of data loaders.
3
+
4
+ Two types of data loader wrappers are supported:
5
+
6
+ - ParallelDataLoader
7
+ - SequentialDataLoader
8
+
9
+ One typical use case is self-supervised learning, where datasets
10
+ such as STL10 contain training sets with and without labels, and the desired data
11
+ loader first loads the trainset with labels and then the one without labels.
12
+ We can use SequentialDataLoader for this purpose.
13
+
14
+ """
15
+
16
+ import numpy as np
17
+
18
+
19
class ParallelIterator:
    """An iterator that draws one batch from every wrapped data loader at each
    step and hands the collected batches to the compound loader's
    `combine_batch`."""

    def __init__(self, defined_compound_loader):
        self.defined_compound_loader = defined_compound_loader
        self.compound_loaders = defined_compound_loader.loaders
        self.loader_iters = list(map(iter, self.compound_loaders))

    def __iter__(self):
        return self

    def __next__(self):
        # Iteration stops as soon as any loader is exhausted: the
        # `StopIteration` raised by that loader's iterator propagates out of
        # this method unchanged. A plain loop is used because a generator
        # expression would turn `StopIteration` into `RuntimeError`.
        collected = []
        for loader_iter in self.loader_iters:
            collected.append(next(loader_iter))
        return self.defined_compound_loader.combine_batch(collected)

    def __len__(self):
        return len(self.defined_compound_loader)
41
+
42
+
43
class ParallelDataLoader:
    """A wrapper around several pytorch DataLoader objects that yields, at each
    step, one batch taken from every wrapped loader. It mimics the
    `for batch in loader:` interface of a pytorch `DataLoader`.

    :param defined_loaders: a list or tuple of pytorch DataLoader objects;
        entries that are `None` are dropped.

    [For example]
    Given two dataloaders A and B, one iteration extracts a batch 'A_b' from A
    and a batch 'B_b' from B, so the loaded item is the list [A_b, B_b].

    The length of this dataloader is the minimum length among the wrapped
    loaders.
    """

    def __init__(self, defined_loaders):
        self.loaders = list(
            filter(lambda loader: loader is not None, defined_loaders)
        )

    def __iter__(self):
        return ParallelIterator(self)

    def __len__(self):
        return min(map(len, self.loaders))

    def combine_batch(self, batches):
        """Customize the behavior of combining batches here."""
        return batches
73
+
74
+
75
class SequentialIterator:
    """An iterator to support iterating through each data loader sequentially.

    For example, there are three data loaders, A, B, and, C:
    the iteration will start from A, once A finished, B will start; then C will start.

    Thus, the length of this iterator is:
    len(A) + len(B) + len(C)
    """

    def __init__(self, defined_compound_loader):
        # `loaders` is expected to hold only valid loaders
        # (`SequentialDataLoader` drops `None` entries).
        self.defined_compound_loader = defined_compound_loader
        self.compound_loaders = self.defined_compound_loader.loaders
        self.loader_iters = [iter(loader) for loader in self.compound_loaders]

        # Cumulative batch counts: entry i is the index of the first batch
        # that no longer belongs to loader i.
        self.loaders_len = [len(loader) for loader in self.compound_loaders]
        self.loaders_batch_bound = np.cumsum(self.loaders_len, axis=0)

        self.num_loaders = len(self.loaders_len)
        self.batch_idx = 0

    def __iter__(self):
        return self

    def __next__(self):
        # With no loaders at all there is nothing to iterate over.
        if self.num_loaders == 0:
            raise StopIteration

        # Map the running batch index to the loader that owns it.
        cur_loader_idx = int(np.digitize(self.batch_idx, self.loaders_batch_bound))

        # Once every loader has been consumed, fall back to the final loader:
        # its exhausted iterator raises the `StopIteration` that ends this
        # iterator as well.
        if cur_loader_idx == self.num_loaders:
            cur_loader_idx -= 1

        loader_iter = self.loader_iters[cur_loader_idx]
        batch = next(loader_iter)

        self.batch_idx += 1

        return self.defined_compound_loader.process_batch(batch)

    def __len__(self):
        # Delegate to the compound loader; the attribute previously read here
        # (`target_loader`) was never assigned, so `len()` always failed.
        return len(self.defined_compound_loader)
123
+
124
+
125
class SequentialDataLoader:
    """This class wraps several pytorch DataLoader objects and iterates through
    them one after another: all batches of the first loader are produced, then
    all batches of the second one, and so on. Every batch is passed through
    `process_batch` before it is returned. This class mimics the
    `for batch in loader:` interface of pytorch `DataLoader`.

    :param defined_loaders: A list or tuple containing pytorch DataLoader
        objects; entries that are `None` are dropped.

    The size of this dataloader equals the sum of the lengths of the wrapped
    loaders.
    """

    def __init__(self, defined_loaders):
        self.loaders = [loader for loader in defined_loaders if loader is not None]

    def __iter__(self):
        return SequentialIterator(self)

    def __len__(self):
        return sum(len(loader) for loader in self.loaders)

    def process_batch(self, batch):
        """Customize the behavior of processing each batch here."""
        return batch
@@ -0,0 +1,24 @@
1
+ """Useful decorators."""
2
+
3
+ import time
4
+ from functools import wraps
5
+
6
+
7
def timeit(func_timed):
    """Decorate 'func_timed' so that every call reports how long it took.

    The wrapper prints the elapsed time and returns it: alone when the wrapped
    function returned `None`, otherwise as the second item of an
    `(output, elapsed)` tuple.
    """

    @wraps(func_timed)
    def timed(*args, **kwargs):
        tick = time.perf_counter()
        result = func_timed(*args, **kwargs)
        elapsed = time.perf_counter() - tick

        print(f'"{func_timed.__name__}" took {elapsed:.2f} seconds to execute.')

        return elapsed if result is None else (result, elapsed)

    return timed
plato/utils/fonts.py ADDED
@@ -0,0 +1,23 @@
1
+ """
2
+ Colours and fonts for logging messages
3
+ """
4
+
5
+
6
def colourize(message, colour="yellow", style="bold"):
    """Wrap the message in the ANSI escape codes of the given colour and style.

    Raises a ValueError when the colour or the style is not supported.
    """
    reset = "\033[0m"
    colours = {
        "green": "\033[92m",
        "blue": "\033[94m",
        "yellow": "\033[93m",
        "red": "\033[91m",
    }
    styles = {"standard": "", "bold": "\033[1m", "underline": "\033[4m"}

    if colour not in colours or style not in styles:
        supported_colours = ", ".join(colours)
        supported_styles = ", ".join(styles)
        raise ValueError(
            f"Your colour '{colour}' or your style '{style}' is not supported."
            f"\nThe supported colours are: {supported_colours}. \nThe supported styles are: {supported_styles}."
        )

    prefix = colours[colour] + styles[style]
    return prefix + message + reset
@@ -0,0 +1,187 @@
1
+ """
2
+ Utility functions for homomorphic encryption with TenSEAL.
3
+ """
4
+
5
+ import os
6
+ import pickle
7
+ import zlib
8
+ from typing import OrderedDict
9
+
10
+ import numpy as np
11
+ import tenseal as ts
12
+ import torch
13
+
14
+
15
def get_ckks_context():
    """Obtain a TenSEAL context for encryption and decryption.

    The context is read from `.ckks_context/context` (relative to the current
    working directory). When that file does not exist, a new CKKS context is
    created, written to that file and returned.
    """
    context_dir = ".ckks_context/"
    context_name = "context"
    context_path = os.path.join(context_dir, context_name)

    try:
        with open(context_path, "rb") as f:
            return ts.context_from(f.read())
    except FileNotFoundError:
        # Create a new context if it does not exist. `exist_ok=True` avoids a
        # failure when another process creates the directory concurrently.
        os.makedirs(context_dir, exist_ok=True)

        context = ts.context(
            ts.SCHEME_TYPE.CKKS,
            poly_modulus_degree=8192,
            coeff_mod_bit_sizes=[60, 40, 40, 60],
        )
        context.global_scale = 2**40

        # NOTE: the serialized context includes the secret key
        # (`save_secret_key=True`), so this file must be kept private.
        with open(context_path, "wb") as f:
            f.write(context.serialize(save_secret_key=True))

        return context
39
+
40
+
41
def encrypt_weights(
    plain_weights,
    serialize=True,
    context=None,
    indices=None,
):
    """Flatten the model weights and encrypt the selected ones.

    Arguments:
    plain_weights: A mapping from layer names to weight arrays/tensors.
    serialize: Whether the encrypted vector is returned in serialized form.
    context: The TenSEAL context used for encryption; required.
    indices: Positions (in the flattened weight vector) of the weights to
        encrypt. All weights are encrypted when it is `None`. The given
        sequence is left untouched; a sorted copy is used.

    Returns:
    The message built by `wrap_encrypted_model`, holding the unencrypted
    weights, the encrypted weights (`None` when nothing was selected) and the
    sorted list of encrypted indices.

    Raises:
    ValueError: If no context is given.
    """
    # An explicit check instead of `assert`, which is stripped under `-O`.
    if context is None:
        raise ValueError("A TenSEAL context is required to encrypt the weights.")

    # Step 1: flatten all weight tensors to a vector. A single concatenation
    # avoids re-copying the growing vector for every layer; the leading empty
    # float array keeps the float64 result of the former `np.append` loop.
    weights_vector = np.concatenate(
        [np.array([])] + [np.ravel(weight) for weight in plain_weights.values()]
    )

    # Step 2: set up the indices for encrypted weights
    if indices is None:
        encrypt_indices = np.arange(len(weights_vector)).tolist()
    else:
        # `sorted` returns a new list, so the caller's sequence is not mutated.
        encrypt_indices = sorted(indices)

    # Step 3: separate weights into encrypted and unencrypted ones
    unencrypted_weights = np.delete(weights_vector, encrypt_indices)
    weights_to_enc = weights_vector[encrypt_indices]

    if len(weights_to_enc) == 0:
        encrypted_weights = None
    else:
        encrypted_weights = _encrypt(weights_to_enc, context, serialize)

    # Finish by wrapping up the information
    return wrap_encrypted_model(
        unencrypted_weights, encrypted_weights, encrypt_indices
    )
78
+
79
+
80
def _encrypt(data_vector, context, serialize=True):
    """Build a TenSEAL CKKS vector from the data and optionally serialize it."""
    encrypted_vector = ts.ckks_vector(context, data_vector)
    return encrypted_vector.serialize() if serialize else encrypted_vector
85
+
86
+
87
def deserialize_weights(serialized_weights, context):
    """Deserialize the encrypted part of a message (without decrypting it).

    The `"encrypted_weights"` entry, when it is not `None`, is turned into a
    lazily loaded CKKS vector linked to the given context; all other entries
    are copied over unchanged.
    """
    restored = OrderedDict()
    for key, payload in serialized_weights.items():
        if key != "encrypted_weights" or payload is None:
            restored[key] = payload
            continue

        lazy_vector = ts.lazy_ckks_vector_from(payload)
        lazy_vector.link_context(context)
        restored[key] = lazy_vector

    return restored
99
+
100
+
101
def decrypt_weights(data, weight_shapes=None, para_nums=None):
    """Decrypt the vector and restore model weights according to the shapes.

    Arguments:
    data: The message holding unencrypted weights, encrypted weights and the
        indices of the encrypted entries.
    weight_shapes: A mapping from layer names to the shape of each layer.
    para_nums: A mapping whose values are the number of parameters per layer.

    Returns:
    An ordered mapping from layer names to the restored weights.
    """
    segment_sizes = list(para_nums.values())

    # Step 1: decrypt the encrypted weights
    unencrypted_weights, encrypted_weights, indices = extract_encrypted_model(data)

    if len(indices) == 0:
        # Nothing was encrypted, so the plain vector is already complete.
        flat_weights = unencrypted_weights
    else:
        decrypted_vector = np.array(encrypted_weights.decrypt())

        vector_size = len(unencrypted_weights) + len(indices)
        flat_weights = np.zeros(vector_size)
        flat_weights[indices] = decrypted_vector

        # The remaining positions are filled with the unencrypted weights.
        plain_positions = np.delete(range(vector_size), indices)
        flat_weights[plain_positions] = unencrypted_weights

    # Step 2: rebuild the original weight vector. Splitting at the cumulative
    # sizes leaves one trailing remainder segment, which is dropped.
    segments = np.split(flat_weights, np.cumsum(segment_sizes))[:-1]

    decrypted_weights = OrderedDict()
    for position, (name, shape) in enumerate(weight_shapes.items()):
        reshaped = segments[position].reshape(shape)
        decrypted_weights[name] = reshaped
        try:
            decrypted_weights[name] = torch.from_numpy(reshaped)
        except Exception:
            # The conversion to a tensor failed; keep the numpy array and let
            # the caller handle it.
            decrypted_weights[name] = reshaped

    return decrypted_weights
140
+
141
+
142
def wrap_encrypted_model(unencrypted_weights, encrypted_weights, indices):
    """Bundle the parts of an encrypted model into the dict that is exchanged
    between server and client."""
    return dict(
        unencrypted_weights=unencrypted_weights,
        encrypted_weights=encrypted_weights,
        indices=indices,
    )
151
+
152
+
153
def extract_encrypted_model(data):
    """Extract the parts of an encrypted-model message.

    Returns the tuple `(unencrypted_weights, encrypted_weights, indices)`.
    """
    fields = ("unencrypted_weights", "encrypted_weights", "indices")
    return tuple(data[field] for field in fields)
160
+
161
+
162
def indices_to_bitmap(indices):
    """Turn a sequence of indices into a compressed bitmap.

    Arguments:
    indices: A list, tuple or array of non-negative integer indices.

    Returns:
    The zlib-compressed pickle of the packed bit array, as bytes. An empty
    input is returned unchanged.
    """
    # `len()` covers empty lists, tuples and arrays alike; comparing with `[]`
    # only recognised an empty list.
    if len(indices) == 0:
        return indices

    # An explicit index array keeps tuples from being read as a
    # multi-dimensional index.
    index_array = np.asarray(indices, dtype=np.intp)
    bit_array = np.zeros(np.max(index_array) + 1, dtype=np.int8)
    bit_array[index_array] = 1
    bitmap = np.packbits(bit_array)

    # Compress the bitmap before sending it out
    compressed_bitmap = zlib.compress(pickle.dumps(bitmap))

    return compressed_bitmap
175
+
176
+
177
def bitmap_to_indices(bitmap):
    """Translate a compressed bitmap back to a list of indices.

    Arguments:
    bitmap: The bytes produced by compressing a pickled packed bit array, or
        an empty sequence.

    Returns:
    The sorted list of positions whose bit is set. An empty (non-bytes)
    sequence is returned unchanged.
    """
    # An empty list/tuple/array stands for "no indices"; byte payloads are
    # always decoded.
    if not isinstance(bitmap, (bytes, bytearray, memoryview)) and len(bitmap) == 0:
        return bitmap

    # NOTE(security): `pickle.loads` can execute arbitrary code. Only feed
    # payloads from trusted peers into this function.
    decompressed_bitmap = pickle.loads(zlib.decompress(bitmap))
    bit_array = np.unpackbits(decompressed_bitmap)
    indices = np.where(bit_array == 1)[0].tolist()

    return indices
File without changes
@@ -0,0 +1,161 @@
1
+ import copy
2
+ import logging
3
+ import os
4
+ import random
5
+ from abc import ABC, abstractmethod
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from plato.config import Config
11
+ from torch import nn
12
+
13
+
14
class ReplayMemory:
    """A simple replay memory: a fixed-capacity buffer that overwrites its
    oldest entries once it is full."""

    def __init__(self, state_dim, action_dim, capacity, seed):
        # NOTE(review): this seeds Python's `random`, while `sample` draws from
        # `np.random` — confirm whether the seed is meant to make sampling
        # reproducible.
        random.seed(seed)
        self.device = Config().device()
        self.capacity = int(capacity)
        self.ptr = 0  # position of the next write
        self.size = 0  # number of stored transitions

        rows = self.capacity
        self.state = np.zeros((rows, state_dim))
        self.action = np.zeros((rows, action_dim))
        self.reward = np.zeros((rows, 1))
        self.next_state = np.zeros((rows, state_dim))
        self.done = np.zeros((rows, 1))

    def push(self, data):
        """Store one transition given as (state, action, reward, next_state, done)."""
        slot = self.ptr
        columns = (self.state, self.action, self.reward, self.next_state, self.done)
        for position, column in enumerate(columns):
            column[slot] = data[position]

        # Advance the write position, wrapping around at the capacity.
        self.ptr = (slot + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def sample(self):
        """Draw a random batch (with replacement) of stored transitions; the
        batch size is read from `Config().algorithm.batch_size`."""
        ind = np.random.randint(0, self.size, size=int(Config().algorithm.batch_size))

        return (
            self.state[ind],
            self.action[ind],
            self.reward[ind],
            self.next_state[ind],
            self.done[ind],
        )

    def __len__(self):
        return self.size
53
+
54
+
55
class Actor(nn.Module):
    """A three-layer fully connected network whose output is squashed by
    `tanh` and scaled by `max_action`."""

    def __init__(self, state_dim, action_dim, max_action):
        super().__init__()

        self.l1 = nn.Linear(state_dim, 400)
        self.l2 = nn.Linear(400, 300)
        self.l3 = nn.Linear(300, action_dim)

        self.max_action = max_action

    def forward(self, x):
        hidden = F.relu(self.l1(x))
        hidden = F.relu(self.l2(hidden))
        return self.max_action * torch.tanh(self.l3(hidden))
70
+
71
+
72
class Critic(nn.Module):
    """A three-layer fully connected network that maps a concatenated
    (state, action) pair to a single value."""

    def __init__(self, state_dim, action_dim):
        super().__init__()

        self.l1 = nn.Linear(state_dim + action_dim, 400)
        self.l2 = nn.Linear(400, 300)
        self.l3 = nn.Linear(300, 1)

    def forward(self, x, u):
        # Join state and action along dimension 1 before the first layer.
        joint_input = torch.cat([x, u], 1)
        hidden = F.relu(self.l1(joint_input))
        hidden = F.relu(self.l2(hidden))
        return self.l3(hidden)
85
+
86
+
87
class Policy(ABC):
    """A simple example of DRL policy.

    Holds an actor and a critic network (each with a target copy and an Adam
    optimizer) plus a replay buffer, all configured from `Config().algorithm`.
    Subclasses must implement `select_action` and `update`.
    """

    def __init__(self, state_dim, action_dim):
        self.max_action = Config().algorithm.max_action
        self.device = Config().device()
        self.total_it = 0

        # Initialize NNs
        self.actor = Actor(state_dim, action_dim, self.max_action).to(self.device)
        self.actor_target = copy.deepcopy(self.actor)
        self.actor_optimizer = torch.optim.Adam(
            self.actor.parameters(), lr=Config().algorithm.learning_rate
        )

        self.critic = Critic(state_dim, action_dim).to(self.device)
        self.critic_target = copy.deepcopy(self.critic)
        self.critic_optimizer = torch.optim.Adam(
            self.critic.parameters(), lr=Config().algorithm.learning_rate
        )

        # Initialize replay memory
        self.replay_buffer = ReplayMemory(
            state_dim,
            action_dim,
            Config().algorithm.replay_size,
            Config().algorithm.replay_seed,
        )

    def save_model(self, ep=None):
        """Saving the model to a file.

        The actor, the critic and their optimizers are written below
        `./models/<model_name>/`; when `ep` is given, every file name is
        prefixed with `iter<ep>_`.
        """
        model_name = Config().algorithm.model_name
        model_path = f"./models/{model_name}/"
        # `exist_ok=True` avoids a failure when the directory is created
        # concurrently between an existence check and the call.
        os.makedirs(model_path, exist_ok=True)
        if ep is not None:
            model_path += "iter" + str(ep) + "_"

        torch.save(self.actor.state_dict(), model_path + "actor.pth")
        torch.save(
            self.actor_optimizer.state_dict(), model_path + "actor_optimizer.pth"
        )
        torch.save(self.critic.state_dict(), model_path + "critic.pth")
        torch.save(
            self.critic_optimizer.state_dict(), model_path + "critic_optimizer.pth"
        )

        logging.info("[RL Agent] Model saved to %s.", model_path)

    def load_model(self, ep=None):
        """Loading pre-trained model weights from a file.

        Reads the files written by `save_model` for the same `ep`.
        """
        model_name = Config().algorithm.model_name
        model_path = f"./models/{model_name}/"
        if ep is not None:
            model_path += "iter" + str(ep) + "_"

        logging.info("[RL Agent] Loading a model from %s.", model_path)

        # `map_location` remaps the stored tensors to this policy's device, so
        # checkpoints saved on a device that is unavailable here (for example
        # a GPU checkpoint on a CPU-only machine) can still be loaded.
        self.actor.load_state_dict(
            torch.load(model_path + "actor.pth", map_location=self.device)
        )
        self.actor_optimizer.load_state_dict(
            torch.load(model_path + "actor_optimizer.pth", map_location=self.device)
        )
        self.critic.load_state_dict(
            torch.load(model_path + "critic.pth", map_location=self.device)
        )
        self.critic_optimizer.load_state_dict(
            torch.load(model_path + "critic_optimizer.pth", map_location=self.device)
        )

    @abstractmethod
    def select_action(self, state, hidden=None, test=False):
        """Select action from policy."""
        raise NotImplementedError("Please Implement this method")

    @abstractmethod
    def update(self):
        """Update policy."""
        raise NotImplementedError("Please Implement this method")