plato-learn 1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (179) hide show
  1. plato/__init__.py +1 -0
  2. plato/algorithms/__init__.py +0 -0
  3. plato/algorithms/base.py +45 -0
  4. plato/algorithms/fedavg.py +48 -0
  5. plato/algorithms/fedavg_gan.py +79 -0
  6. plato/algorithms/fedavg_personalized.py +48 -0
  7. plato/algorithms/mistnet.py +52 -0
  8. plato/algorithms/registry.py +39 -0
  9. plato/algorithms/split_learning.py +89 -0
  10. plato/callbacks/__init__.py +0 -0
  11. plato/callbacks/client.py +56 -0
  12. plato/callbacks/handler.py +78 -0
  13. plato/callbacks/server.py +139 -0
  14. plato/callbacks/trainer.py +124 -0
  15. plato/client.py +67 -0
  16. plato/clients/__init__.py +0 -0
  17. plato/clients/base.py +467 -0
  18. plato/clients/edge.py +103 -0
  19. plato/clients/fedavg_personalized.py +40 -0
  20. plato/clients/mistnet.py +49 -0
  21. plato/clients/registry.py +43 -0
  22. plato/clients/self_supervised_learning.py +51 -0
  23. plato/clients/simple.py +218 -0
  24. plato/clients/split_learning.py +150 -0
  25. plato/config.py +339 -0
  26. plato/datasources/__init__.py +0 -0
  27. plato/datasources/base.py +123 -0
  28. plato/datasources/celeba.py +150 -0
  29. plato/datasources/cifar10.py +87 -0
  30. plato/datasources/cifar100.py +61 -0
  31. plato/datasources/cinic10.py +62 -0
  32. plato/datasources/coco.py +119 -0
  33. plato/datasources/datalib/__init__.py +0 -0
  34. plato/datasources/datalib/audio_extraction_tools.py +137 -0
  35. plato/datasources/datalib/data_utils.py +124 -0
  36. plato/datasources/datalib/flickr30kE_utils.py +336 -0
  37. plato/datasources/datalib/frames_extraction_tools.py +254 -0
  38. plato/datasources/datalib/gym_utils/__init__.py +0 -0
  39. plato/datasources/datalib/gym_utils/gym_trim.py +189 -0
  40. plato/datasources/datalib/modality_data_anntation_tools.py +163 -0
  41. plato/datasources/datalib/modality_extraction_base.py +59 -0
  42. plato/datasources/datalib/parse_datasets.py +212 -0
  43. plato/datasources/datalib/refer_utils/__init__.py +0 -0
  44. plato/datasources/datalib/refer_utils/referitgame_utils.py +237 -0
  45. plato/datasources/datalib/tiny_data_tools.py +81 -0
  46. plato/datasources/datalib/video_transform.py +79 -0
  47. plato/datasources/emnist.py +64 -0
  48. plato/datasources/fashion_mnist.py +41 -0
  49. plato/datasources/feature.py +24 -0
  50. plato/datasources/feature_dataset.py +15 -0
  51. plato/datasources/femnist.py +141 -0
  52. plato/datasources/flickr30k_entities.py +362 -0
  53. plato/datasources/gym.py +431 -0
  54. plato/datasources/huggingface.py +165 -0
  55. plato/datasources/kinetics.py +568 -0
  56. plato/datasources/mnist.py +44 -0
  57. plato/datasources/multimodal_base.py +328 -0
  58. plato/datasources/pascal_voc.py +56 -0
  59. plato/datasources/purchase.py +94 -0
  60. plato/datasources/qoenflx.py +127 -0
  61. plato/datasources/referitgame.py +330 -0
  62. plato/datasources/registry.py +119 -0
  63. plato/datasources/self_supervised_learning.py +98 -0
  64. plato/datasources/stl10.py +103 -0
  65. plato/datasources/texas.py +94 -0
  66. plato/datasources/tiny_imagenet.py +64 -0
  67. plato/datasources/yolov8.py +85 -0
  68. plato/models/__init__.py +0 -0
  69. plato/models/cnn_encoder.py +103 -0
  70. plato/models/dcgan.py +116 -0
  71. plato/models/general_multilayer.py +254 -0
  72. plato/models/huggingface.py +27 -0
  73. plato/models/lenet5.py +113 -0
  74. plato/models/multilayer.py +90 -0
  75. plato/models/multimodal/__init__.py +0 -0
  76. plato/models/multimodal/base_net.py +91 -0
  77. plato/models/multimodal/blending.py +142 -0
  78. plato/models/multimodal/fc_net.py +77 -0
  79. plato/models/multimodal/fusion_net.py +78 -0
  80. plato/models/multimodal/multimodal_module.py +152 -0
  81. plato/models/registry.py +99 -0
  82. plato/models/resnet.py +190 -0
  83. plato/models/torch_hub.py +19 -0
  84. plato/models/vgg.py +113 -0
  85. plato/models/vit.py +166 -0
  86. plato/models/yolov8.py +22 -0
  87. plato/processors/__init__.py +0 -0
  88. plato/processors/base.py +35 -0
  89. plato/processors/compress.py +46 -0
  90. plato/processors/decompress.py +48 -0
  91. plato/processors/feature.py +51 -0
  92. plato/processors/feature_additive_noise.py +48 -0
  93. plato/processors/feature_dequantize.py +34 -0
  94. plato/processors/feature_gaussian.py +17 -0
  95. plato/processors/feature_laplace.py +15 -0
  96. plato/processors/feature_quantize.py +34 -0
  97. plato/processors/feature_randomized_response.py +50 -0
  98. plato/processors/feature_unbatch.py +39 -0
  99. plato/processors/inbound_feature_tensors.py +39 -0
  100. plato/processors/model.py +55 -0
  101. plato/processors/model_compress.py +34 -0
  102. plato/processors/model_decompress.py +37 -0
  103. plato/processors/model_decrypt.py +41 -0
  104. plato/processors/model_deepcopy.py +21 -0
  105. plato/processors/model_dequantize.py +18 -0
  106. plato/processors/model_dequantize_qsgd.py +61 -0
  107. plato/processors/model_encrypt.py +43 -0
  108. plato/processors/model_quantize.py +18 -0
  109. plato/processors/model_quantize_qsgd.py +82 -0
  110. plato/processors/model_randomized_response.py +34 -0
  111. plato/processors/outbound_feature_ndarrays.py +38 -0
  112. plato/processors/pipeline.py +26 -0
  113. plato/processors/registry.py +124 -0
  114. plato/processors/structured_pruning.py +57 -0
  115. plato/processors/unstructured_pruning.py +73 -0
  116. plato/samplers/__init__.py +0 -0
  117. plato/samplers/all_inclusive.py +41 -0
  118. plato/samplers/base.py +31 -0
  119. plato/samplers/dirichlet.py +81 -0
  120. plato/samplers/distribution_noniid.py +132 -0
  121. plato/samplers/iid.py +53 -0
  122. plato/samplers/label_quantity_noniid.py +119 -0
  123. plato/samplers/mixed.py +44 -0
  124. plato/samplers/mixed_label_quantity_noniid.py +128 -0
  125. plato/samplers/modality_iid.py +42 -0
  126. plato/samplers/modality_quantity_noniid.py +56 -0
  127. plato/samplers/orthogonal.py +99 -0
  128. plato/samplers/registry.py +66 -0
  129. plato/samplers/sample_quantity_noniid.py +123 -0
  130. plato/samplers/sampler_utils.py +190 -0
  131. plato/servers/__init__.py +0 -0
  132. plato/servers/base.py +1395 -0
  133. plato/servers/fedavg.py +281 -0
  134. plato/servers/fedavg_cs.py +335 -0
  135. plato/servers/fedavg_gan.py +74 -0
  136. plato/servers/fedavg_he.py +106 -0
  137. plato/servers/fedavg_personalized.py +57 -0
  138. plato/servers/mistnet.py +67 -0
  139. plato/servers/registry.py +52 -0
  140. plato/servers/split_learning.py +109 -0
  141. plato/trainers/__init__.py +0 -0
  142. plato/trainers/base.py +99 -0
  143. plato/trainers/basic.py +649 -0
  144. plato/trainers/diff_privacy.py +178 -0
  145. plato/trainers/gan.py +330 -0
  146. plato/trainers/huggingface.py +173 -0
  147. plato/trainers/loss_criterion.py +70 -0
  148. plato/trainers/lr_schedulers.py +252 -0
  149. plato/trainers/optimizers.py +53 -0
  150. plato/trainers/pascal_voc.py +80 -0
  151. plato/trainers/registry.py +44 -0
  152. plato/trainers/self_supervised_learning.py +302 -0
  153. plato/trainers/split_learning.py +305 -0
  154. plato/trainers/tracking.py +96 -0
  155. plato/trainers/yolov8.py +41 -0
  156. plato/utils/__init__.py +0 -0
  157. plato/utils/count_parameters.py +30 -0
  158. plato/utils/csv_processor.py +26 -0
  159. plato/utils/data_loaders.py +148 -0
  160. plato/utils/decorators.py +24 -0
  161. plato/utils/fonts.py +23 -0
  162. plato/utils/homo_enc.py +187 -0
  163. plato/utils/reinforcement_learning/__init__.py +0 -0
  164. plato/utils/reinforcement_learning/policies/__init__.py +0 -0
  165. plato/utils/reinforcement_learning/policies/base.py +161 -0
  166. plato/utils/reinforcement_learning/policies/ddpg.py +75 -0
  167. plato/utils/reinforcement_learning/policies/registry.py +32 -0
  168. plato/utils/reinforcement_learning/policies/sac.py +343 -0
  169. plato/utils/reinforcement_learning/policies/td3.py +485 -0
  170. plato/utils/reinforcement_learning/rl_agent.py +142 -0
  171. plato/utils/reinforcement_learning/rl_server.py +113 -0
  172. plato/utils/rl_env.py +154 -0
  173. plato/utils/s3.py +141 -0
  174. plato/utils/trainer_utils.py +21 -0
  175. plato/utils/unary_encoding.py +47 -0
  176. plato_learn-1.1.dist-info/METADATA +35 -0
  177. plato_learn-1.1.dist-info/RECORD +179 -0
  178. plato_learn-1.1.dist-info/WHEEL +4 -0
  179. plato_learn-1.1.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,48 @@
1
+ """
2
+ Implements a Processor for applying local differential privacy using additive noise mechanism.
3
+ """
4
+
5
+ import logging
6
+ from typing import Any
7
+
8
+ import numpy
9
+
10
+ from plato.processors import feature
11
+
12
+
13
class Processor(feature.Processor):
    """
    Implements a Processor for applying local differential privacy using additive noise mechanism.
    """

    # Maps a mechanism name to the numpy sampler used to perturb the features.
    # Both samplers are called as ``sampler(loc, scale)``: passing the features as
    # ``loc`` draws one sample centred on each feature value, i.e. ``features + noise``.
    methods = {
        "gaussian": numpy.random.normal,
        "laplace": numpy.random.laplace,
    }

    def __init__(self, method="", scale=None, **kwargs) -> None:
        """
        :param method: key into ``Processor.methods`` ("gaussian" or "laplace").
        :param scale: forwarded unchanged as the sampler's ``scale`` argument
            (the standard deviation for "gaussian", the diversity b for "laplace").
        """
        self._method = method

        def func(logits, targets):
            # The sampler is looked up when data is processed, so an unknown
            # ``method`` only surfaces (as a KeyError) at processing time.
            # NOTE(review): assumes ``logits`` is array-like (numpy) here — confirm
            # against feature.Processor's default conversion.
            return (
                Processor.methods[method](logits, scale),
                targets,
            )

        super().__init__(method=func, **kwargs)

    def process(self, data: Any) -> Any:
        """
        Applies the configured additive-noise mechanism (Gaussian or Laplace) to the
        features in ``data`` as the local differential privacy mechanism, and logs
        which mechanism was used.
        """

        output = super().process(data)

        logging.info(
            "[Client #%d] Local differential privacy (using the %s mechanism) applied.",
            self.client_id,
            self._method,
        )

        return output
@@ -0,0 +1,34 @@
1
+ """
2
+ Implements a Processor for applying dequantization to MistNet PyTorch features.
3
+ """
4
+
5
+ import logging
6
+ from typing import Any
7
+
8
+ import torch
9
+
10
+ from plato.processors import base
11
+
12
+
13
class Processor(base.Processor):
    """
    Dequantizes MistNet PyTorch features on the server side.
    """

    def __init__(self, server_id=None, **kwargs) -> None:
        super().__init__(**kwargs)

        # Used only to tag the log message emitted by ``process``.
        self.server_id = server_id

    def process(self, data: Any) -> Any:
        """
        Returns a list holding ``(torch.dequantize(logit), target)`` for every
        ``(logit, target)`` pair in ``data``; targets are passed through untouched.
        """
        dequantized = [(torch.dequantize(logit), target) for logit, target in data]

        logging.info("[Server #%d] Dequantized features.", self.server_id)

        return dequantized
@@ -0,0 +1,17 @@
1
+ """
2
+ Implements a Processor for applying local differential privacy using gaussian mechanism.
3
+ """
4
+
5
+ import math
6
+
7
+ from plato.processors import feature_additive_noise
8
+
9
+
10
class Processor(feature_additive_noise.Processor):
    """
    Implements a Processor for applying local differential privacy using gaussian mechanism.

    The noise standard deviation follows the classical (epsilon, delta) Gaussian
    mechanism: sigma = sqrt(2 * ln(1.25 / delta)) * sensitivity / epsilon.
    """

    def __init__(self, epsilon=None, delta=None, sensitivity=None, **kwargs) -> None:
        """
        :param epsilon: privacy budget epsilon.
        :param delta: privacy failure probability delta.
        :param sensitivity: L2 sensitivity of the released features.
        """
        # ``numpy.random.normal`` interprets ``scale`` as the standard deviation,
        # so sigma must be passed here, not the variance sigma**2.
        scale = math.sqrt(2 * math.log(1.25 / delta)) * sensitivity / epsilon
        super().__init__(method="gaussian", scale=scale, **kwargs)
@@ -0,0 +1,15 @@
1
+ """
2
+ Implements a Processor for applying local differential privacy using laplace mechanism.
3
+ """
4
+
5
+ from plato.processors import feature_additive_noise
6
+
7
+
8
class Processor(feature_additive_noise.Processor):
    """
    Local differential privacy via the Laplace mechanism: configures the
    additive-noise processor with method "laplace" and scale sensitivity / epsilon.
    """

    def __init__(self, epsilon=None, sensitivity=None, **kwargs) -> None:
        # Laplace diversity b = sensitivity / epsilon.
        laplace_scale = sensitivity / epsilon
        super().__init__(scale=laplace_scale, method="laplace", **kwargs)
@@ -0,0 +1,34 @@
1
+ """
2
+ Implements a Processor for applying quantization to MistNet PyTorch features.
3
+ """
4
+
5
+ import logging
6
+ from typing import Any
7
+
8
+ import torch
9
+
10
+ from plato.processors import feature
11
+
12
+
13
class Processor(feature.Processor):
    """
    Quantizes MistNet PyTorch features (logits) with ``torch.quantize_per_tensor``.
    """

    def __init__(self, scale=0.1, zero_point=10, dtype=torch.quint8, **kwargs) -> None:
        """
        :param scale: quantization scale forwarded to ``torch.quantize_per_tensor``.
        :param zero_point: quantization zero point.
        :param dtype: target quantized dtype.
        """

        def quantize(logits, targets):
            quantized = torch.quantize_per_tensor(logits, scale, zero_point, dtype)
            return quantized, targets

        # use_numpy=False presumably keeps the features as torch tensors, which
        # quantize_per_tensor needs — confirm against feature.Processor.
        super().__init__(method=quantize, use_numpy=False, **kwargs)

    def process(self, data: Any) -> Any:
        """
        Quantizes the features in ``data`` and logs that quantization was applied.
        """
        result = super().process(data)
        logging.info("[Client #%d] Quantization applied.", self.client_id)
        return result
@@ -0,0 +1,50 @@
1
+ """
2
+ Implements a Processor for applying local differential privacy using randomized response.
3
+ """
4
+
5
+ import logging
6
+ from typing import Any
7
+
8
+ from plato.config import Config
9
+ from plato.processors import feature
10
+ from plato.utils import unary_encoding
11
+
12
+
13
class Processor(feature.Processor):
    """
    Local differential privacy for features via unary encoding followed by
    randomized response.
    """

    def __init__(self, **kwargs) -> None:
        def encode_and_randomize(logits, targets):
            encoded = unary_encoding.encode(logits)

            # Without a configured privacy budget, only the encoding is applied.
            if Config().algorithm.epsilon is None:
                return encoded, targets

            trainer_randomize = getattr(self.trainer, "randomize", None)
            epsilon = Config().algorithm.epsilon

            # Prefer a trainer-provided randomizer; fall back to the generic one.
            if callable(trainer_randomize):
                randomized = trainer_randomize(encoded, targets, epsilon)
            else:
                randomized = unary_encoding.randomize(encoded, epsilon)

            return randomized, targets

        super().__init__(method=encode_and_randomize, **kwargs)

    def process(self, data: Any) -> Any:
        """
        Applies unary encoding plus randomized response to ``data`` as the local
        differential privacy mechanism, then logs it.
        """
        result = super().process(data)

        logging.info(
            "[Client #%d] Local differential privacy (using randomized response) applied.",
            self.client_id,
        )

        return result
@@ -0,0 +1,39 @@
1
+ """
2
+ Implements a Processor for unbatching MistNet PyTorch features into the dataset form.
3
+ """
4
+
5
+ import logging
6
+ from typing import Any
7
+
8
+ import numpy as np
9
+
10
+ from plato.processors import base
11
+
12
+
13
class Processor(base.Processor):
    """
    Splits batched MistNet PyTorch features into a flat, per-sample dataset.
    """

    def __init__(self, client_id=None, **kwargs) -> None:
        super().__init__(**kwargs)

        # Used only to tag the log message emitted by ``process``.
        self.client_id = client_id

    def process(self, data: Any) -> Any:
        """
        Flattens ``data``, an iterable of ``(logits, targets)`` batches, into a list
        of ``(logit, target)`` pairs, one per sample along the first dimension.
        Each element is cloned so it is detached from its batch's storage.
        """
        samples = []

        for batch_logits, batch_targets in data:
            batch_size = batch_logits.shape[0]
            samples.extend(
                (batch_logits[i].clone(), batch_targets[i].clone())
                for i in np.arange(batch_size)
            )

        logging.info(
            "[Client #%d] Features extracted from %s examples.",
            self.client_id,
            len(samples),
        )

        return samples
@@ -0,0 +1,39 @@
1
+ """
2
+ Implements a Processor for converting MistNet features from numpy ndarrays to PyTorch tensors.
3
+ """
4
+
5
+ import logging
6
+ from typing import Any
7
+
8
+ import torch
9
+
10
+ from plato.processors import base
11
+
12
+
13
class Processor(base.Processor):
    """
    Converts MistNet features received as numpy ndarrays into PyTorch tensors.
    """

    def __init__(self, server_id=None, **kwargs) -> None:
        super().__init__(**kwargs)

        # Used only to tag the log message emitted by ``process``.
        self.server_id = server_id

    def process(self, data: Any) -> Any:
        """
        Returns ``data`` as a list of ``(logit, target)`` pairs of PyTorch tensors.
        """
        # torch.as_tensor() reuses the ndarray's memory where possible, whereas
        # torch.tensor() always copies
        # (see https://pytorch.org/docs/stable/generated/torch.tensor.html).
        as_tensors = [
            (torch.as_tensor(logit), torch.as_tensor(target)) for logit, target in data
        ]

        logging.info(
            "[Server #%d] Features converted from ndarrays to PyTorch tensors.",
            self.server_id,
        )

        return as_tensors
@@ -0,0 +1,55 @@
1
+ """
2
+ Base processor for processing PyTorch models.
3
+ """
4
+
5
+ import logging
6
+ import pickle
7
+ import sys
8
+ from typing import OrderedDict
9
+
10
+ import torch
11
+
12
+ from plato.processors import base
13
+
14
+
15
class Processor(base.Processor):
    """
    Base processor for processing PyTorch models.

    ``process`` applies ``_process_layer`` to every tensor of a state_dict and logs
    how the serialized size changed; subclasses may override ``_process_layer``.
    """

    def __init__(self, client_id=None, server_id=None, **kwargs) -> None:
        super().__init__(**kwargs)
        # ``client_id is None`` is taken to mean the processor runs on the server;
        # both ids are used only to tag log messages here.
        self.client_id = client_id
        self.server_id = server_id

    def process(self, data: OrderedDict) -> OrderedDict:
        """
        Processes PyTorch model parameters layer by layer.

        :param data: the state_dict of a PyTorch model (layer name -> tensor).
        :return: a new OrderedDict with the same keys, each value being the result
            of ``self._process_layer`` on the corresponding layer.
        """
        # len() of the pickled bytes is the serialized payload size;
        # sys.getsizeof would also count the bytes object's interpreter overhead.
        old_data_size = len(pickle.dumps(data))

        new_data = OrderedDict()
        for layer_name, layer_params in data.items():
            new_data[layer_name] = self._process_layer(layer_params)

        new_data_size = len(pickle.dumps(new_data))

        if self.client_id is None:
            role, role_id = "Server", self.server_id
        else:
            role, role_id = "Client", self.client_id

        logging.info(
            "[%s #%d] Processed the model and changed its size from %.2f MB to %.2f MB.",
            role,
            role_id,
            old_data_size / 1024**2,
            new_data_size / 1024**2,
        )

        return new_data

    def _process_layer(self, layer: torch.Tensor) -> torch.Tensor:
        """Processes an individual layer of the model (identity by default)."""
        return layer
@@ -0,0 +1,34 @@
1
+ """
2
+ Implements a Processor for compressing model weights.
3
+ """
4
+
5
+ import logging
6
+ import pickle
7
+ from typing import Any
8
+
9
+ import zstd
10
+
11
+ from plato.processors import model
12
+
13
+
14
class Processor(model.Processor):
    """
    Serializes model parameters with pickle and compresses them with zstd.
    """

    def __init__(self, compression_level=1, **kwargs) -> None:
        super().__init__(**kwargs)

        # zstd compression level used by ``process``.
        self.compression_level = compression_level

    def process(self, data: Any) -> Any:
        """Returns the zstd-compressed pickle of ``data``."""

        serialized = pickle.dumps(data)
        compressed = zstd.compress(serialized, self.compression_level)

        if self.client_id is not None:
            logging.info("[Client #%d] Compressed model parameters.", self.client_id)
        else:
            logging.info("[Server #%d] Compressed model parameters.", self.server_id)

        return compressed
@@ -0,0 +1,37 @@
1
+ """
2
+ Implements a Processor for decompressing model weights.
3
+ """
4
+
5
+ import logging
6
+ import pickle
7
+ from typing import Any
8
+
9
+ import zstd
10
+
11
+ from plato.processors import model
12
+
13
+
14
class Processor(model.Processor):
    """
    Decompresses zstd-compressed, pickled model parameters.
    """

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)

    def process(self, data: Any) -> Any:
        """Returns the object obtained by zstd-decompressing and unpickling ``data``."""

        # SECURITY NOTE(review): pickle.loads can execute arbitrary code embedded in
        # the payload; this is only safe if ``data`` comes from a trusted peer.
        decompressed = zstd.decompress(data)
        output = pickle.loads(decompressed)

        if self.client_id is not None:
            logging.info(
                "[Client #%d] Decompressed received model parameters.",
                self.client_id,
            )
        else:
            logging.info(
                "[Server #%d] Decompressed received model parameters.",
                self.server_id,
            )
        return output
@@ -0,0 +1,41 @@
1
+ """
2
+ A processor that decrypts model weights of MaskCrypt.
3
+ """
4
+
5
+ from typing import Any
6
+
7
+ import torch
8
+
9
+ from plato.processors import model
10
+ from plato.utils import homo_enc
11
+
12
+
13
class Processor(model.Processor):
    """
    Deserializes and decrypts model weights produced by MaskCrypt encryption.
    """

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)

        self.context = homo_enc.get_ckks_context()

        # NOTE(review): nn.Module.cpu() moves the trainer's model to the CPU in place
        # as a side effect; only shapes and element counts are needed here.
        state_dict = self.trainer.model.cpu().state_dict()

        # Per-layer shapes and element counts, handed to homo_enc.decrypt_weights.
        self.weight_shapes = {name: tensor.size() for name, tensor in state_dict.items()}
        self.para_nums = {name: torch.numel(tensor) for name, tensor in state_dict.items()}

    def process(self, data: Any) -> Any:
        """Deserializes ``data`` with the stored context and decrypts the model weights."""
        encrypted = homo_enc.deserialize_weights(data, self.context)

        return homo_enc.decrypt_weights(encrypted, self.weight_shapes, self.para_nums)
@@ -0,0 +1,21 @@
1
+ """
2
+ Processor for creating a deep copy of the PyTorch model state_dict.
3
+ """
4
+
5
+ import copy
6
+
7
+ import torch
8
+
9
+ from plato.processors import model
10
+
11
+
12
class Processor(model.Processor):
    """
    Produces a deep copy of a PyTorch model state_dict, one layer at a time.
    """

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)

    def _process_layer(self, layer: torch.Tensor) -> torch.Tensor:
        """Returns an independent deep copy of ``layer``."""
        duplicate = copy.deepcopy(layer)
        return duplicate
@@ -0,0 +1,18 @@
1
+ """
2
+ Implements a Processor for dequantizing model parameters.
3
+ """
4
+
5
+ import torch
6
+
7
+ from plato.processors import model
8
+
9
+
10
class Processor(model.Processor):
    """
    Implements a Processor for dequantizing model parameters.

    Every layer is cast back to 32-bit floating point (``torch.float32``).
    """

    def _process_layer(self, layer: torch.Tensor) -> torch.Tensor:
        """Dequantizes each individual layer of the model by casting it to float32."""

        return layer.to(torch.float32)
@@ -0,0 +1,61 @@
1
+ """
2
+ Implements a Processor to decompress and dequantize upload models.
3
+
4
+ In more detail, this processor first decompresses each received parameter.
5
+ Next, dequantize each upload parameter under the given quantization level.
6
+ Hence, 8-bit received parameters can be converted into 32-bit parameters.
7
+
8
+ Reference:
9
+
10
+ Alistarh, D., Grubic, D., Li, J., Tomioka, R., & Vojnovic, M. (2017).
11
+ "QSGD: Communication-efficient SGD via gradient quantization and encoding."
12
+ Advances in neural information processing systems.
13
+
14
+ https://proceedings.neurips.cc/paper/2017/file/6c340f25839e6acdc73414517203f5f0-Paper.pdf
15
+ """
16
+
17
+ from struct import unpack
18
+ from typing import Any
19
+
20
+ import torch
21
+
22
+ from plato.processors import model
23
+
24
+
25
class Processor(model.Processor):
    """
    Implements a Processor to dequantize model parameters quantized with QSGD.

    Each layer arrives as a byte string in network byte order:

    * ``!f`` (4 bytes): max_v, multiplied into the dequantized values;
    * ``!I`` (4 bytes): numel, the number of elements;
    * ``!h`` (2 bytes): the number of dimensions, then one ``!h`` per dimension;
    * numel bytes: one sign-magnitude byte per element -- a value >= 128 encodes
      the negative magnitude ``value - 128``.
    """

    def __init__(self, quantization_level=64, **kwargs) -> None:
        super().__init__(**kwargs)

        # Must be <= 128: element magnitudes occupy only the low 7 bits of a byte.
        self.quantization_level = quantization_level

    def _process_layer(self, layer: Any) -> Any:
        """
        Dequantizes each individual layer of the model.

        :param layer: the serialized layer described in the class docstring.
        :return: a tensor of the encoded shape holding the dequantized values.
        :raises ValueError: if ``layer`` holds fewer element bytes than declared.
        """
        tuning_param = self.quantization_level - 1

        # Step 1: decode the header (max_v, numel, number of dimensions, dimensions)
        max_v, numel, dimensions = unpack("!fIh", layer[0:10])
        header_end = 10 + 2 * dimensions
        size = list(unpack("!" + "h" * dimensions, layer[10:header_end]))

        # Step 2: decode the sign-magnitude element bytes
        payload = layer[header_end : header_end + numel]
        if len(payload) < numel:
            raise ValueError(
                f"Truncated QSGD layer: expected {numel} element bytes, "
                f"got {len(payload)}."
            )
        # Iterating bytes yields ints in [0, 255] at C speed, avoiding one
        # struct.unpack call per element; the sign is then decoded in one
        # vectorized step.
        raw = torch.tensor(list(payload))
        zeta = torch.where(raw >= 128, 128 - raw, raw).reshape(size)

        # Step 3: dequantize the content
        return zeta * max_v / tuning_param
@@ -0,0 +1,43 @@
1
+ """
2
+ A processor that encrypts model weights in MaskCrypt.
3
+ """
4
+
5
+ import logging
6
+ from typing import Any
7
+
8
+ import torch
9
+
10
+ from plato.processors import model
11
+ from plato.utils import homo_enc
12
+
13
+
14
class Processor(model.Processor):
    """
    Encrypts model weights for MaskCrypt using the given encryption mask.
    """

    def __init__(self, mask=None, **kwargs) -> None:
        super().__init__(**kwargs)
        self.context = homo_enc.get_ckks_context()
        # Forwarded to homo_enc.encrypt_weights as ``indices``; presumably selects
        # which weights are encrypted — confirm against homo_enc.
        self.mask = mask

        # NOTE(review): nn.Module.cpu() moves the trainer's model to the CPU in place.
        state_dict = self.trainer.model.cpu().state_dict()
        # Number of elements in each layer, keyed by layer name.
        self.para_nums = {name: torch.numel(tensor) for name, tensor in state_dict.items()}

    def process(self, data: Any) -> Any:
        """Encrypts ``data`` with the stored context and mask (serialize=True)."""
        logging.info(
            "[Client #%d] Encrypt the model weights with given encryption mask.",
            self.client_id,
        )

        return homo_enc.encrypt_weights(
            data,
            serialize=True,
            context=self.context,
            indices=self.mask,
        )
@@ -0,0 +1,18 @@
1
+ """
2
+ Implements a Processor for quantizing model parameters.
3
+ """
4
+
5
+ import torch
6
+
7
+ from plato.processors import model
8
+
9
+
10
class Processor(model.Processor):
    """
    Implements a Processor to quantize model parameters to 16-bit floating points
    (``torch.bfloat16``).
    """

    def _process_layer(self, layer: torch.Tensor) -> torch.Tensor:
        """Quantizes each individual layer of the model by casting it to bfloat16."""

        return layer.to(torch.bfloat16)