plato-learn 1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (179) hide show
  1. plato/__init__.py +1 -0
  2. plato/algorithms/__init__.py +0 -0
  3. plato/algorithms/base.py +45 -0
  4. plato/algorithms/fedavg.py +48 -0
  5. plato/algorithms/fedavg_gan.py +79 -0
  6. plato/algorithms/fedavg_personalized.py +48 -0
  7. plato/algorithms/mistnet.py +52 -0
  8. plato/algorithms/registry.py +39 -0
  9. plato/algorithms/split_learning.py +89 -0
  10. plato/callbacks/__init__.py +0 -0
  11. plato/callbacks/client.py +56 -0
  12. plato/callbacks/handler.py +78 -0
  13. plato/callbacks/server.py +139 -0
  14. plato/callbacks/trainer.py +124 -0
  15. plato/client.py +67 -0
  16. plato/clients/__init__.py +0 -0
  17. plato/clients/base.py +467 -0
  18. plato/clients/edge.py +103 -0
  19. plato/clients/fedavg_personalized.py +40 -0
  20. plato/clients/mistnet.py +49 -0
  21. plato/clients/registry.py +43 -0
  22. plato/clients/self_supervised_learning.py +51 -0
  23. plato/clients/simple.py +218 -0
  24. plato/clients/split_learning.py +150 -0
  25. plato/config.py +339 -0
  26. plato/datasources/__init__.py +0 -0
  27. plato/datasources/base.py +123 -0
  28. plato/datasources/celeba.py +150 -0
  29. plato/datasources/cifar10.py +87 -0
  30. plato/datasources/cifar100.py +61 -0
  31. plato/datasources/cinic10.py +62 -0
  32. plato/datasources/coco.py +119 -0
  33. plato/datasources/datalib/__init__.py +0 -0
  34. plato/datasources/datalib/audio_extraction_tools.py +137 -0
  35. plato/datasources/datalib/data_utils.py +124 -0
  36. plato/datasources/datalib/flickr30kE_utils.py +336 -0
  37. plato/datasources/datalib/frames_extraction_tools.py +254 -0
  38. plato/datasources/datalib/gym_utils/__init__.py +0 -0
  39. plato/datasources/datalib/gym_utils/gym_trim.py +189 -0
  40. plato/datasources/datalib/modality_data_anntation_tools.py +163 -0
  41. plato/datasources/datalib/modality_extraction_base.py +59 -0
  42. plato/datasources/datalib/parse_datasets.py +212 -0
  43. plato/datasources/datalib/refer_utils/__init__.py +0 -0
  44. plato/datasources/datalib/refer_utils/referitgame_utils.py +237 -0
  45. plato/datasources/datalib/tiny_data_tools.py +81 -0
  46. plato/datasources/datalib/video_transform.py +79 -0
  47. plato/datasources/emnist.py +64 -0
  48. plato/datasources/fashion_mnist.py +41 -0
  49. plato/datasources/feature.py +24 -0
  50. plato/datasources/feature_dataset.py +15 -0
  51. plato/datasources/femnist.py +141 -0
  52. plato/datasources/flickr30k_entities.py +362 -0
  53. plato/datasources/gym.py +431 -0
  54. plato/datasources/huggingface.py +165 -0
  55. plato/datasources/kinetics.py +568 -0
  56. plato/datasources/mnist.py +44 -0
  57. plato/datasources/multimodal_base.py +328 -0
  58. plato/datasources/pascal_voc.py +56 -0
  59. plato/datasources/purchase.py +94 -0
  60. plato/datasources/qoenflx.py +127 -0
  61. plato/datasources/referitgame.py +330 -0
  62. plato/datasources/registry.py +119 -0
  63. plato/datasources/self_supervised_learning.py +98 -0
  64. plato/datasources/stl10.py +103 -0
  65. plato/datasources/texas.py +94 -0
  66. plato/datasources/tiny_imagenet.py +64 -0
  67. plato/datasources/yolov8.py +85 -0
  68. plato/models/__init__.py +0 -0
  69. plato/models/cnn_encoder.py +103 -0
  70. plato/models/dcgan.py +116 -0
  71. plato/models/general_multilayer.py +254 -0
  72. plato/models/huggingface.py +27 -0
  73. plato/models/lenet5.py +113 -0
  74. plato/models/multilayer.py +90 -0
  75. plato/models/multimodal/__init__.py +0 -0
  76. plato/models/multimodal/base_net.py +91 -0
  77. plato/models/multimodal/blending.py +142 -0
  78. plato/models/multimodal/fc_net.py +77 -0
  79. plato/models/multimodal/fusion_net.py +78 -0
  80. plato/models/multimodal/multimodal_module.py +152 -0
  81. plato/models/registry.py +99 -0
  82. plato/models/resnet.py +190 -0
  83. plato/models/torch_hub.py +19 -0
  84. plato/models/vgg.py +113 -0
  85. plato/models/vit.py +166 -0
  86. plato/models/yolov8.py +22 -0
  87. plato/processors/__init__.py +0 -0
  88. plato/processors/base.py +35 -0
  89. plato/processors/compress.py +46 -0
  90. plato/processors/decompress.py +48 -0
  91. plato/processors/feature.py +51 -0
  92. plato/processors/feature_additive_noise.py +48 -0
  93. plato/processors/feature_dequantize.py +34 -0
  94. plato/processors/feature_gaussian.py +17 -0
  95. plato/processors/feature_laplace.py +15 -0
  96. plato/processors/feature_quantize.py +34 -0
  97. plato/processors/feature_randomized_response.py +50 -0
  98. plato/processors/feature_unbatch.py +39 -0
  99. plato/processors/inbound_feature_tensors.py +39 -0
  100. plato/processors/model.py +55 -0
  101. plato/processors/model_compress.py +34 -0
  102. plato/processors/model_decompress.py +37 -0
  103. plato/processors/model_decrypt.py +41 -0
  104. plato/processors/model_deepcopy.py +21 -0
  105. plato/processors/model_dequantize.py +18 -0
  106. plato/processors/model_dequantize_qsgd.py +61 -0
  107. plato/processors/model_encrypt.py +43 -0
  108. plato/processors/model_quantize.py +18 -0
  109. plato/processors/model_quantize_qsgd.py +82 -0
  110. plato/processors/model_randomized_response.py +34 -0
  111. plato/processors/outbound_feature_ndarrays.py +38 -0
  112. plato/processors/pipeline.py +26 -0
  113. plato/processors/registry.py +124 -0
  114. plato/processors/structured_pruning.py +57 -0
  115. plato/processors/unstructured_pruning.py +73 -0
  116. plato/samplers/__init__.py +0 -0
  117. plato/samplers/all_inclusive.py +41 -0
  118. plato/samplers/base.py +31 -0
  119. plato/samplers/dirichlet.py +81 -0
  120. plato/samplers/distribution_noniid.py +132 -0
  121. plato/samplers/iid.py +53 -0
  122. plato/samplers/label_quantity_noniid.py +119 -0
  123. plato/samplers/mixed.py +44 -0
  124. plato/samplers/mixed_label_quantity_noniid.py +128 -0
  125. plato/samplers/modality_iid.py +42 -0
  126. plato/samplers/modality_quantity_noniid.py +56 -0
  127. plato/samplers/orthogonal.py +99 -0
  128. plato/samplers/registry.py +66 -0
  129. plato/samplers/sample_quantity_noniid.py +123 -0
  130. plato/samplers/sampler_utils.py +190 -0
  131. plato/servers/__init__.py +0 -0
  132. plato/servers/base.py +1395 -0
  133. plato/servers/fedavg.py +281 -0
  134. plato/servers/fedavg_cs.py +335 -0
  135. plato/servers/fedavg_gan.py +74 -0
  136. plato/servers/fedavg_he.py +106 -0
  137. plato/servers/fedavg_personalized.py +57 -0
  138. plato/servers/mistnet.py +67 -0
  139. plato/servers/registry.py +52 -0
  140. plato/servers/split_learning.py +109 -0
  141. plato/trainers/__init__.py +0 -0
  142. plato/trainers/base.py +99 -0
  143. plato/trainers/basic.py +649 -0
  144. plato/trainers/diff_privacy.py +178 -0
  145. plato/trainers/gan.py +330 -0
  146. plato/trainers/huggingface.py +173 -0
  147. plato/trainers/loss_criterion.py +70 -0
  148. plato/trainers/lr_schedulers.py +252 -0
  149. plato/trainers/optimizers.py +53 -0
  150. plato/trainers/pascal_voc.py +80 -0
  151. plato/trainers/registry.py +44 -0
  152. plato/trainers/self_supervised_learning.py +302 -0
  153. plato/trainers/split_learning.py +305 -0
  154. plato/trainers/tracking.py +96 -0
  155. plato/trainers/yolov8.py +41 -0
  156. plato/utils/__init__.py +0 -0
  157. plato/utils/count_parameters.py +30 -0
  158. plato/utils/csv_processor.py +26 -0
  159. plato/utils/data_loaders.py +148 -0
  160. plato/utils/decorators.py +24 -0
  161. plato/utils/fonts.py +23 -0
  162. plato/utils/homo_enc.py +187 -0
  163. plato/utils/reinforcement_learning/__init__.py +0 -0
  164. plato/utils/reinforcement_learning/policies/__init__.py +0 -0
  165. plato/utils/reinforcement_learning/policies/base.py +161 -0
  166. plato/utils/reinforcement_learning/policies/ddpg.py +75 -0
  167. plato/utils/reinforcement_learning/policies/registry.py +32 -0
  168. plato/utils/reinforcement_learning/policies/sac.py +343 -0
  169. plato/utils/reinforcement_learning/policies/td3.py +485 -0
  170. plato/utils/reinforcement_learning/rl_agent.py +142 -0
  171. plato/utils/reinforcement_learning/rl_server.py +113 -0
  172. plato/utils/rl_env.py +154 -0
  173. plato/utils/s3.py +141 -0
  174. plato/utils/trainer_utils.py +21 -0
  175. plato/utils/unary_encoding.py +47 -0
  176. plato_learn-1.1.dist-info/METADATA +35 -0
  177. plato_learn-1.1.dist-info/RECORD +179 -0
  178. plato_learn-1.1.dist-info/WHEEL +4 -0
  179. plato_learn-1.1.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,82 @@
1
+ """
2
+ Implements a Processor to quantize and compress upload models.
3
+
4
+ In more detail, this processor first quantizes each upload parameter under
5
+ the given quantization level. Next, compress and store each quantized value.
6
+ Hence, these 32-bit parameters can be converted into 8-bit parameters.
7
+
8
+ Reference:
9
+
10
+ Alistarh, D., Grubic, D., Li, J., Tomioka, R., & Vojnovic, M. (2017).
11
+ "QSGD: Communication-efficient SGD via gradient quantization and encoding."
12
+ Advances in neural information processing systems.
13
+
14
+ https://proceedings.neurips.cc/paper/2017/file/6c340f25839e6acdc73414517203f5f0-Paper.pdf
15
+ """
16
+
17
+ import random
18
+ from struct import pack, unpack
19
+ from typing import Any
20
+
21
+ import torch
22
+
23
+ from plato.processors import model
24
+
25
+
26
class Processor(model.Processor):
    """
    Implements a Processor to quantize model parameters with QSGD.

    Every layer is serialized into ``bytes``: a big-endian header (maximum
    absolute value as a float, element count, number of dimensions and the
    size of each dimension) followed by one byte per parameter, made of a
    sign bit and a 7-bit quantized magnitude.
    """

    def __init__(self, quantization_level=64, **kwargs) -> None:
        super().__init__(**kwargs)

        # Magnitudes are stored in 7 bits, so the level must be <= 128.
        self.quantization_level = quantization_level

    def _process_layer(self, layer: Any) -> Any:
        """Quantizes each individual layer of the model with QSGD.

        Arguments:
            layer: the tensor holding the parameters of one layer.

        Returns:
            The header and the one-byte-per-parameter content as ``bytes``.
        """
        sign_bit = 0x80  # highest bit of the byte that stores one parameter

        def add_prob(prob: Any) -> Any:
            """Rounds every entry to 1 with probability equal to its value, else to 0."""
            size = prob.size()
            prob = prob.reshape(-1)
            random.seed()
            for count, value in enumerate(prob):
                prob[count] = 1 if random.random() <= value else 0
            return torch.reshape(prob, size)

        def handler(tensor: Any) -> Any:
            """Packs each quantized value into one byte: sign bit plus magnitude."""
            # A bytearray avoids the quadratic cost of repeated bytes concatenation.
            content = bytearray()
            for num in tensor.reshape(-1).tolist():
                if num < 0:
                    num = abs(num) ^ sign_bit
                content.append(num & 0xFF)  # keep the lowest byte only
            return bytes(content)

        # Step 1: quantization
        tuning_param = self.quantization_level - 1  # tuning parameter
        max_v = torch.max(abs(layer))  # max absolute value
        # An all-zero layer would otherwise divide 0 by 0 and produce NaNs,
        # which cannot be cast to integers or packed; dividing by 1 keeps
        # every ratio at 0 so the layer is encoded as zeros.
        scale = max_v if max_v.item() > 0 else 1.0
        neg = (-1) * layer.lt(0) + 1 * layer.ge(0)
        ratio = abs(layer) / scale  # |v_i| / ||v||
        level = (ratio * tuning_param - 1).ceil()
        zeta = level + add_prob(ratio * tuning_param - level)
        zeta = zeta.mul(neg).to(int)

        # Step 2: handle the header ("!" selects big-endian byte order)
        output = pack("!f", max_v.item())
        output += pack("!I", zeta.numel())
        dimensions = len(zeta.size())
        output += pack("!h", dimensions)
        for i in range(dimensions):
            output += pack("!h", zeta.size(i))

        # Step 3: handle the content, each consists of 1 sign bit followed by 7 bits
        output += handler(zeta)

        return output
@@ -0,0 +1,34 @@
1
+ """
2
+ Implements a Processor for applying local differential privacy using randomized response.
3
+ """
4
+
5
+ import torch
6
+
7
+ from plato.config import Config
8
+ from plato.processors import model
9
+ from plato.utils import unary_encoding
10
+
11
+
12
class Processor(model.Processor):
    """
    A Processor that perturbs model layers with randomized response as a
    local differential privacy mechanism.
    """

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)

    def _process_layer(self, layer: torch.Tensor) -> torch.Tensor:
        """Applies unary encoding and randomized response to one layer.

        The layer is returned untouched when no epsilon is configured under
        the algorithm section of the configuration.
        """
        epsilon = Config().algorithm.epsilon
        if epsilon is None:
            return layer

        # Randomized response operates on a numpy copy of the layer
        encoded = unary_encoding.encode(layer.detach().cpu().numpy())
        randomized = unary_encoding.randomize(encoded, epsilon)

        return torch.tensor(randomized, dtype=torch.float32)
@@ -0,0 +1,38 @@
1
+ """
2
+ Implements a Processor for converting MistNet features from PyTorch tensors to numpy ndarrays.
3
+ """
4
+
5
+ import logging
6
+ from typing import Any
7
+
8
+ from plato.processors import base
9
+
10
+
11
class Processor(base.Processor):
    """
    Implements a Processor for converting MistNet features from PyTorch tensors to numpy ndarrays.
    This is used only by MistNet clients at this time.
    """

    def __init__(self, client_id=None, **kwargs) -> None:
        super().__init__(**kwargs)

        self.client_id = client_id

    def process(self, data: Any) -> Any:
        """
        Converts MistNet features from PyTorch tensors to numpy ndarrays.

        Arguments:
            data: an iterable of (logit, target) pairs of PyTorch tensors.

        Returns:
            A list of (logit, target) pairs converted to numpy ndarrays.
        """
        feature_dataset = [
            (logit.detach().cpu().numpy(), target.detach().cpu().numpy())
            for logit, target in data
        ]

        # %s rather than %d: client_id defaults to None, which %d cannot format.
        logging.info(
            "[Client #%s] Features converted from PyTorch tensors to ndarrays.",
            self.client_id,
        )

        return feature_dataset
@@ -0,0 +1,26 @@
1
+ """
2
+ Implements a pipeline of processors for data payloads to pass through.
3
+ """
4
+
5
+ from typing import Any, List
6
+
7
+ from plato.processors import base
8
+
9
+
10
class Processor(base.Processor):
    """
    Chains a list of Processors so that a payload passes through them in order.
    """

    def __init__(self, processors: List[base.Processor], *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.processors = processors

    def process(self, data: Any) -> Any:
        """
        Feeds the payload through every Processor, each one receiving the
        output of the previous one, and returns the final result.
        """
        payload = data

        for stage in self.processors:
            payload = stage.process(payload)

        return payload
@@ -0,0 +1,124 @@
1
+ """
2
+ This registry for Processors contains framework-specific implementations of
3
+ Processors for data payloads.
4
+
5
+ Having a registry of all available classes is convenient for retrieving an instance
6
+ based on a configuration at run-time.
7
+ """
8
+
9
+ import logging
10
+ from typing import Tuple
11
+
12
+ from plato.config import Config
13
+ from plato.processors import (
14
+ base,
15
+ compress,
16
+ decompress,
17
+ feature_dequantize,
18
+ feature_gaussian,
19
+ feature_laplace,
20
+ feature_quantize,
21
+ feature_randomized_response,
22
+ feature_unbatch,
23
+ inbound_feature_tensors,
24
+ model_compress,
25
+ model_decompress,
26
+ model_deepcopy,
27
+ model_dequantize,
28
+ model_dequantize_qsgd,
29
+ model_quantize,
30
+ model_quantize_qsgd,
31
+ model_randomized_response,
32
+ outbound_feature_ndarrays,
33
+ pipeline,
34
+ structured_pruning,
35
+ unstructured_pruning,
36
+ )
37
+
38
# Maps the processor names listed in the configuration to their classes.
registered_processors = dict(
    base=base.Processor,
    compress=compress.Processor,
    decompress=decompress.Processor,
    feature_randomized_response=feature_randomized_response.Processor,
    feature_gaussian=feature_gaussian.Processor,
    feature_laplace=feature_laplace.Processor,
    feature_quantize=feature_quantize.Processor,
    feature_dequantize=feature_dequantize.Processor,
    feature_unbatch=feature_unbatch.Processor,
    inbound_feature_tensors=inbound_feature_tensors.Processor,
    outbound_feature_ndarrays=outbound_feature_ndarrays.Processor,
    model_deepcopy=model_deepcopy.Processor,
    model_quantize=model_quantize.Processor,
    model_dequantize=model_dequantize.Processor,
    model_compress=model_compress.Processor,
    model_quantize_qsgd=model_quantize_qsgd.Processor,
    model_decompress=model_decompress.Processor,
    model_dequantize_qsgd=model_dequantize_qsgd.Processor,
    model_randomized_response=model_randomized_response.Processor,
    structured_pruning=structured_pruning.Processor,
    unstructured_pruning=unstructured_pruning.Processor,
)
61
+
62
+
63
def register_he_processors():
    """Placeholder for registering homomorphic encryption processors.

    This function currently performs no work and returns None.
    """
    return None
65
+
66
+
67
def get(
    user: str, processor_kwargs=None, **kwargs
) -> Tuple[pipeline.Processor, pipeline.Processor]:
    """Builds the outbound and inbound processor pipelines for a user.

    Arguments:
        user: either "Server" or "Client"; selects which configuration
            section the processor names are read from.
        processor_kwargs: optional mapping from a processor name to extra
            keyword arguments that override `kwargs` for that processor.
        kwargs: keyword arguments passed to every processor.

    Returns:
        A tuple of (outbound pipeline, inbound pipeline).
    """
    assert user in ("Server", "Client")

    config = Config().server if user == "Server" else Config().clients

    def configured_names(attribute):
        """Returns the list configured under `attribute`, or an empty list."""
        names = getattr(config, attribute, None)
        return names if isinstance(names, list) else []

    outbound_names = configured_names("outbound_processors")
    inbound_names = configured_names("inbound_processors")

    for name in outbound_names:
        logging.info("%s: Using Processor for sending payload: %s", user, name)
    for name in inbound_names:
        logging.info("%s: Using Processor for receiving payload: %s", user, name)

    # Check if HE processors are needed based on server configuration
    if getattr(config, "type", None) == "fedavg_he":
        # FedAvg server with homomorphic encryption needs to import tenseal,
        # which is not available on all platforms such as macOS
        from plato.processors import model_decrypt, model_encrypt

        registered_processors.update(
            {
                "model_encrypt": model_encrypt.Processor,
                "model_decrypt": model_decrypt.Processor,
            }
        )

        logging.info("%s: Using homomorphic encryption processors.", user)

    def instantiate(name):
        """Creates the registered processor, applying per-processor overrides."""
        params = kwargs
        if processor_kwargs is not None and name in processor_kwargs:
            params = {**kwargs, **(processor_kwargs[name])}

        return registered_processors[name](name=name, **params)

    outbound = [instantiate(name) for name in outbound_names]
    inbound = [instantiate(name) for name in inbound_names]

    return pipeline.Processor(outbound), pipeline.Processor(inbound)
@@ -0,0 +1,57 @@
1
+ """
2
+ Processor for structured pruning of model weights.
3
+ """
4
+
5
+ import logging
6
+
7
+ import torch
8
+ import torch.nn.utils.prune as prune
9
+
10
+ from plato.processors import model
11
+
12
+
13
class Processor(model.Processor):
    """
    A processor for the structured pruning of model weights.
    """

    def __init__(
        self, pruning_method="ln", amount=0.2, norm=1, dim=-1, **kwargs
    ) -> None:
        super().__init__(**kwargs)

        self.pruning_method = pruning_method  # "ln" or "random"
        self.amount = amount
        self.norm = norm  # passed as `n` to ln_structured
        self.dim = dim
        self.model = None

    def process(self, data):
        """
        Processes structured pruning of model weights layer by layer.

        The incoming `data` is not used: the weights of every Conv2d and
        Linear module of the trainer's model are pruned and made permanent.

        Returns:
            The state dict of the pruned model, after moving it to the CPU.

        Raises:
            ValueError: if a module is to be pruned with a pruning method
                other than "ln" or "random".
        """
        self.model = self.trainer.model

        for _, module in self.model.named_modules():
            if not isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
                continue

            if self.pruning_method == "ln":
                prune.ln_structured(
                    module, "weight", self.amount, n=self.norm, dim=self.dim
                )
            elif self.pruning_method == "random":
                prune.random_structured(module, "weight", self.amount, dim=self.dim)
            else:
                # Without this check, prune.remove() below would fail with an
                # unrelated "has to be pruned" ValueError.
                raise ValueError(
                    f"Unsupported structured pruning method: {self.pruning_method!r} "
                    "(expected 'ln' or 'random')."
                )

            # Make the pruning permanent by removing the re-parametrization
            prune.remove(module, "weight")

        output = self.model.cpu().state_dict()

        if self.client_id is None:
            logging.info("[Server #%d] Structured pruning applied.", self.server_id)
        else:
            logging.info("[Client #%d] Structured pruning applied.", self.client_id)

        return output

    def _process_layer(self, layer: torch.Tensor) -> torch.Tensor:
        """No need to process individual layer of the model"""
        return layer
@@ -0,0 +1,73 @@
1
+ """
2
+ Implements a Processor for global unstructured pruning of model weights.
3
+ """
4
+
5
+ import logging
6
+ from typing import Any
7
+
8
+ import torch
9
+ import torch.nn.utils.prune as prune
10
+
11
+ from plato.processors import model
12
+
13
+
14
class Processor(model.Processor):
    """
    A Processor that applies global unstructured pruning to model weights.
    """

    def __init__(
        self,
        parameters_to_prune=None,
        pruning_method=prune.L1Unstructured,
        amount=0.2,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)

        self.model = None
        self.amount = amount
        self.pruning_method = pruning_method
        self.parameters_to_prune = parameters_to_prune

    def process(self, data: Any) -> Any:
        """
        Prunes the trainer's model globally and returns its state dict.

        When no parameters to prune were given, the weights of all Conv2d
        and Linear modules are selected. The incoming `data` is not used.
        """
        self.model = self.trainer.model

        if self.parameters_to_prune is None:
            self.parameters_to_prune = [
                (module, "weight")
                for _, module in self.model.named_modules()
                if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear))
            ]

        prune.global_unstructured(
            self.parameters_to_prune,
            pruning_method=self.pruning_method,
            amount=self.amount,
        )

        # Make the pruning permanent on every selected parameter
        for pruned_module, parameter_name in self.parameters_to_prune:
            prune.remove(pruned_module, parameter_name)

        state = self.model.cpu().state_dict()

        if self.client_id is not None:
            logging.info(
                "[Client #%d] Global unstructured pruning applied.",
                self.client_id,
            )
        else:
            logging.info(
                "[Server #%d] Global unstructured pruning applied.",
                self.server_id,
            )

        return state

    def _process_layer(self, layer: torch.Tensor) -> torch.Tensor:
        """Returns the layer unchanged; pruning is done on the whole model."""
        return layer
File without changes
@@ -0,0 +1,41 @@
1
+ """
2
+ Samples all the data from a dataset. Applicable in cases where the dataset comes from
3
+ local sources only. Used by the Federated EMNIST dataset and the MistNet server.
4
+ """
5
+
6
+ import random
7
+
8
+ from plato.samplers import base
9
+ from plato.config import Config
10
+
11
+
12
class Sampler(base.Sampler):
    """A data sampler that covers every example in the dataset.

    Used by the MistNet server.
    """

    def __init__(self, datasource, client_id=0, testing=False):
        super().__init__()
        self.client_id = client_id

        if not testing:
            self.data_samples = range(len(datasource.get_train_set()))
            return

        test_indices = range(len(datasource.get_test_set()))

        # A configured test set size draws a random subset of the test indices
        if hasattr(Config().data, "testset_size"):
            self.data_samples = random.sample(
                test_indices, Config().data.testset_size
            )
        else:
            self.data_samples = test_indices

    def get(self):
        """Obtains a random sampler over the selected indices, seeded with the random seed."""
        import torch

        generator = torch.Generator()
        generator.manual_seed(self.random_seed)

        return torch.utils.data.SubsetRandomSampler(
            self.data_samples, generator=generator
        )

    def num_samples(self):
        """Returns the number of selected sample indices."""
        return len(self.data_samples)
plato/samplers/base.py ADDED
@@ -0,0 +1,31 @@
1
+ """
2
+ Base class for sampling data so that a dataset can be divided across the clients.
3
+ """
4
+
5
+ import os
6
+ from abc import abstractmethod
7
+
8
+ from plato.config import Config
9
+
10
+
11
class Sampler:
    """Base class for data samplers, which divide a dataset into
    partitions across the clients."""

    def __init__(self):
        data_config = Config().data

        # A configured seed is shared by all the clients, which keeps the
        # experiments reproducible; without one, the process id is used and
        # the seed differs across runs.
        self.random_seed = (
            data_config.random_seed
            if hasattr(data_config, "random_seed")
            else os.getpid()
        )

    @abstractmethod
    def get(self):
        """Obtains an instance of the sampler."""

    @abstractmethod
    def num_samples(self):
        """Returns the length of the dataset after sampling."""
@@ -0,0 +1,81 @@
1
+ """
2
+ Samples data from a dataset, biased across labels according to the Dirichlet distribution.
3
+ """
4
+
5
+ import numpy as np
6
+ import torch
7
+ from torch.utils.data import WeightedRandomSampler, SubsetRandomSampler
8
+ from plato.config import Config
9
+
10
+ from plato.samplers import base
11
+
12
+
13
class Sampler(base.Sampler):
    """A data sampler giving each client a partition of the dataset whose
    label proportions are drawn from a Dirichlet distribution."""

    def __init__(self, datasource, client_id, testing):
        super().__init__()

        # Seeding with the client id gives each client its own label bias
        # and partition size
        np.random.seed(self.random_seed * int(client_id))

        # Concentration parameter of the Dirichlet distribution, 1.0 by default
        concentration = getattr(Config().data, "concentration", 1.0)

        # The list of labels (targets) for all the examples
        if testing:
            target_list = datasource.get_test_set().targets
        else:
            target_list = datasource.targets()

        num_classes = len(datasource.classes())

        target_proportions = np.random.dirichlet(
            np.repeat(concentration, num_classes)
        )

        # Fall back to a single randomly chosen class if the draw produced NaNs
        if np.isnan(np.sum(target_proportions)):
            target_proportions = np.repeat(0, num_classes)
            target_proportions[np.random.randint(0, num_classes)] = 1

        # Each example is weighted by the proportion assigned to its label
        self.sample_weights = target_proportions[target_list]

    def num_samples(self) -> int:
        """Returns the length of the dataset after sampling."""
        sampled_size = Config().data.partition_size

        # Variable partition size across clients
        if hasattr(Config().data, "partition_distribution"):
            dist = Config().data.partition_distribution
            distribution = dist.distribution.lower()

            if distribution == "uniform":
                sampled_size *= np.random.uniform(dist.low, dist.high)
            elif distribution == "normal":
                sampled_size *= np.random.normal(dist.mean, dist.high)

            sampled_size = int(sampled_size)

        return sampled_size

    def get(self):
        """Obtains an instance of the sampler."""
        generator = torch.Generator()
        generator.manual_seed(self.random_seed)

        # Draw the subset without replacement, guided by the sample weights
        weighted_sampler = WeightedRandomSampler(
            weights=self.sample_weights,
            num_samples=self.num_samples(),
            replacement=False,
            generator=generator,
        )
        subset_indices = list(weighted_sampler)

        return SubsetRandomSampler(subset_indices, generator=generator)