plato-learn 1.1 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- plato/__init__.py +1 -0
- plato/algorithms/__init__.py +0 -0
- plato/algorithms/base.py +45 -0
- plato/algorithms/fedavg.py +48 -0
- plato/algorithms/fedavg_gan.py +79 -0
- plato/algorithms/fedavg_personalized.py +48 -0
- plato/algorithms/mistnet.py +52 -0
- plato/algorithms/registry.py +39 -0
- plato/algorithms/split_learning.py +89 -0
- plato/callbacks/__init__.py +0 -0
- plato/callbacks/client.py +56 -0
- plato/callbacks/handler.py +78 -0
- plato/callbacks/server.py +139 -0
- plato/callbacks/trainer.py +124 -0
- plato/client.py +67 -0
- plato/clients/__init__.py +0 -0
- plato/clients/base.py +467 -0
- plato/clients/edge.py +103 -0
- plato/clients/fedavg_personalized.py +40 -0
- plato/clients/mistnet.py +49 -0
- plato/clients/registry.py +43 -0
- plato/clients/self_supervised_learning.py +51 -0
- plato/clients/simple.py +218 -0
- plato/clients/split_learning.py +150 -0
- plato/config.py +339 -0
- plato/datasources/__init__.py +0 -0
- plato/datasources/base.py +123 -0
- plato/datasources/celeba.py +150 -0
- plato/datasources/cifar10.py +87 -0
- plato/datasources/cifar100.py +61 -0
- plato/datasources/cinic10.py +62 -0
- plato/datasources/coco.py +119 -0
- plato/datasources/datalib/__init__.py +0 -0
- plato/datasources/datalib/audio_extraction_tools.py +137 -0
- plato/datasources/datalib/data_utils.py +124 -0
- plato/datasources/datalib/flickr30kE_utils.py +336 -0
- plato/datasources/datalib/frames_extraction_tools.py +254 -0
- plato/datasources/datalib/gym_utils/__init__.py +0 -0
- plato/datasources/datalib/gym_utils/gym_trim.py +189 -0
- plato/datasources/datalib/modality_data_anntation_tools.py +163 -0
- plato/datasources/datalib/modality_extraction_base.py +59 -0
- plato/datasources/datalib/parse_datasets.py +212 -0
- plato/datasources/datalib/refer_utils/__init__.py +0 -0
- plato/datasources/datalib/refer_utils/referitgame_utils.py +237 -0
- plato/datasources/datalib/tiny_data_tools.py +81 -0
- plato/datasources/datalib/video_transform.py +79 -0
- plato/datasources/emnist.py +64 -0
- plato/datasources/fashion_mnist.py +41 -0
- plato/datasources/feature.py +24 -0
- plato/datasources/feature_dataset.py +15 -0
- plato/datasources/femnist.py +141 -0
- plato/datasources/flickr30k_entities.py +362 -0
- plato/datasources/gym.py +431 -0
- plato/datasources/huggingface.py +165 -0
- plato/datasources/kinetics.py +568 -0
- plato/datasources/mnist.py +44 -0
- plato/datasources/multimodal_base.py +328 -0
- plato/datasources/pascal_voc.py +56 -0
- plato/datasources/purchase.py +94 -0
- plato/datasources/qoenflx.py +127 -0
- plato/datasources/referitgame.py +330 -0
- plato/datasources/registry.py +119 -0
- plato/datasources/self_supervised_learning.py +98 -0
- plato/datasources/stl10.py +103 -0
- plato/datasources/texas.py +94 -0
- plato/datasources/tiny_imagenet.py +64 -0
- plato/datasources/yolov8.py +85 -0
- plato/models/__init__.py +0 -0
- plato/models/cnn_encoder.py +103 -0
- plato/models/dcgan.py +116 -0
- plato/models/general_multilayer.py +254 -0
- plato/models/huggingface.py +27 -0
- plato/models/lenet5.py +113 -0
- plato/models/multilayer.py +90 -0
- plato/models/multimodal/__init__.py +0 -0
- plato/models/multimodal/base_net.py +91 -0
- plato/models/multimodal/blending.py +142 -0
- plato/models/multimodal/fc_net.py +77 -0
- plato/models/multimodal/fusion_net.py +78 -0
- plato/models/multimodal/multimodal_module.py +152 -0
- plato/models/registry.py +99 -0
- plato/models/resnet.py +190 -0
- plato/models/torch_hub.py +19 -0
- plato/models/vgg.py +113 -0
- plato/models/vit.py +166 -0
- plato/models/yolov8.py +22 -0
- plato/processors/__init__.py +0 -0
- plato/processors/base.py +35 -0
- plato/processors/compress.py +46 -0
- plato/processors/decompress.py +48 -0
- plato/processors/feature.py +51 -0
- plato/processors/feature_additive_noise.py +48 -0
- plato/processors/feature_dequantize.py +34 -0
- plato/processors/feature_gaussian.py +17 -0
- plato/processors/feature_laplace.py +15 -0
- plato/processors/feature_quantize.py +34 -0
- plato/processors/feature_randomized_response.py +50 -0
- plato/processors/feature_unbatch.py +39 -0
- plato/processors/inbound_feature_tensors.py +39 -0
- plato/processors/model.py +55 -0
- plato/processors/model_compress.py +34 -0
- plato/processors/model_decompress.py +37 -0
- plato/processors/model_decrypt.py +41 -0
- plato/processors/model_deepcopy.py +21 -0
- plato/processors/model_dequantize.py +18 -0
- plato/processors/model_dequantize_qsgd.py +61 -0
- plato/processors/model_encrypt.py +43 -0
- plato/processors/model_quantize.py +18 -0
- plato/processors/model_quantize_qsgd.py +82 -0
- plato/processors/model_randomized_response.py +34 -0
- plato/processors/outbound_feature_ndarrays.py +38 -0
- plato/processors/pipeline.py +26 -0
- plato/processors/registry.py +124 -0
- plato/processors/structured_pruning.py +57 -0
- plato/processors/unstructured_pruning.py +73 -0
- plato/samplers/__init__.py +0 -0
- plato/samplers/all_inclusive.py +41 -0
- plato/samplers/base.py +31 -0
- plato/samplers/dirichlet.py +81 -0
- plato/samplers/distribution_noniid.py +132 -0
- plato/samplers/iid.py +53 -0
- plato/samplers/label_quantity_noniid.py +119 -0
- plato/samplers/mixed.py +44 -0
- plato/samplers/mixed_label_quantity_noniid.py +128 -0
- plato/samplers/modality_iid.py +42 -0
- plato/samplers/modality_quantity_noniid.py +56 -0
- plato/samplers/orthogonal.py +99 -0
- plato/samplers/registry.py +66 -0
- plato/samplers/sample_quantity_noniid.py +123 -0
- plato/samplers/sampler_utils.py +190 -0
- plato/servers/__init__.py +0 -0
- plato/servers/base.py +1395 -0
- plato/servers/fedavg.py +281 -0
- plato/servers/fedavg_cs.py +335 -0
- plato/servers/fedavg_gan.py +74 -0
- plato/servers/fedavg_he.py +106 -0
- plato/servers/fedavg_personalized.py +57 -0
- plato/servers/mistnet.py +67 -0
- plato/servers/registry.py +52 -0
- plato/servers/split_learning.py +109 -0
- plato/trainers/__init__.py +0 -0
- plato/trainers/base.py +99 -0
- plato/trainers/basic.py +649 -0
- plato/trainers/diff_privacy.py +178 -0
- plato/trainers/gan.py +330 -0
- plato/trainers/huggingface.py +173 -0
- plato/trainers/loss_criterion.py +70 -0
- plato/trainers/lr_schedulers.py +252 -0
- plato/trainers/optimizers.py +53 -0
- plato/trainers/pascal_voc.py +80 -0
- plato/trainers/registry.py +44 -0
- plato/trainers/self_supervised_learning.py +302 -0
- plato/trainers/split_learning.py +305 -0
- plato/trainers/tracking.py +96 -0
- plato/trainers/yolov8.py +41 -0
- plato/utils/__init__.py +0 -0
- plato/utils/count_parameters.py +30 -0
- plato/utils/csv_processor.py +26 -0
- plato/utils/data_loaders.py +148 -0
- plato/utils/decorators.py +24 -0
- plato/utils/fonts.py +23 -0
- plato/utils/homo_enc.py +187 -0
- plato/utils/reinforcement_learning/__init__.py +0 -0
- plato/utils/reinforcement_learning/policies/__init__.py +0 -0
- plato/utils/reinforcement_learning/policies/base.py +161 -0
- plato/utils/reinforcement_learning/policies/ddpg.py +75 -0
- plato/utils/reinforcement_learning/policies/registry.py +32 -0
- plato/utils/reinforcement_learning/policies/sac.py +343 -0
- plato/utils/reinforcement_learning/policies/td3.py +485 -0
- plato/utils/reinforcement_learning/rl_agent.py +142 -0
- plato/utils/reinforcement_learning/rl_server.py +113 -0
- plato/utils/rl_env.py +154 -0
- plato/utils/s3.py +141 -0
- plato/utils/trainer_utils.py +21 -0
- plato/utils/unary_encoding.py +47 -0
- plato_learn-1.1.dist-info/METADATA +35 -0
- plato_learn-1.1.dist-info/RECORD +179 -0
- plato_learn-1.1.dist-info/WHEEL +4 -0
- plato_learn-1.1.dist-info/licenses/LICENSE +201 -0
plato/processors/model_quantize_qsgd.py
ADDED
@@ -0,0 +1,82 @@
"""
Implements a Processor to quantize and compress upload models.

In more detail, this processor first quantizes each upload parameter under
the given quantization level. Next, it compresses and stores each quantized
value, so that 32-bit parameters are converted into 8-bit parameters.

Reference:

Alistarh, D., Grubic, D., Li, J., Tomioka, R., & Vojnovic, M. (2017).
"QSGD: Communication-efficient SGD via gradient quantization and encoding."
Advances in Neural Information Processing Systems.

https://proceedings.neurips.cc/paper/2017/file/6c340f25839e6acdc73414517203f5f0-Paper.pdf
"""

import random
from struct import pack, unpack
from typing import Any

import torch

from plato.processors import model


class Processor(model.Processor):
    """
    Implements a Processor to quantize model parameters with QSGD.
    """

    def __init__(self, quantization_level=64, **kwargs) -> None:
        super().__init__(**kwargs)

        self.quantization_level = quantization_level  # must be <= 128!

    def _process_layer(self, layer: Any) -> Any:
        """Quantizes each individual layer of the model with QSGD."""

        def add_prob(prob: Any) -> Any:
            """Adds 1 to the corresponding positions with the given probability."""
            size = prob.size()
            prob = prob.reshape(-1)
            random.seed()
            for count, value in enumerate(prob):
                if random.random() <= value:
                    prob[count] = 1
                else:
                    prob[count] = 0
            return torch.reshape(prob, size)

        def handler(tensor: Any) -> Any:
            """Handler function for the compression of quantized values."""
            content = b""
            tensor = tensor.reshape(-1)
            for _, value in enumerate(tensor):
                num = value.item()
                if num < 0:
                    num = abs(num) ^ unpack("!i", b"\x00\x00\x00\x80")[0]
                content += pack("!I", num)[3:4]  # represent each parameter in 1 byte
            return content

        # Step 1: quantization
        tuning_param = self.quantization_level - 1  # tuning parameter
        max_v = torch.max(abs(layer))  # max absolute value
        neg = (-1) * layer.lt(0) + 1 * layer.ge(0)
        ratio = abs(layer) / max_v  # |v_i| / ||v||
        level = (ratio * tuning_param - 1).ceil()
        zeta = level + add_prob(ratio * tuning_param - level)
        zeta = zeta.mul(neg).to(int)

        # Step 2: handle the header
        output = pack("!f", max_v.item())  # ! stands for big-endian
        output += pack("!I", zeta.numel())
        dimensions = len(zeta.size())
        output += pack("!h", dimensions)
        for i in range(dimensions):
            output += pack("!h", zeta.size(i))

        # Step 3: handle the content; each byte consists of 1 sign bit followed by 7 bits
        output += handler(zeta)

        return output
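Note: the matching decoder ships in plato/processors/model_dequantize_qsgd.py (+61 above), which this diff view does not expand. The following is a minimal sketch, not package code, of how the byte layout written by _process_layer could be unpacked; the function name and signature are ours.

from struct import unpack

import torch


def decode_qsgd(payload: bytes, quantization_level: int = 64) -> torch.Tensor:
    """Unpacks one layer serialized by the QSGD processor above (illustrative)."""
    max_v = unpack("!f", payload[0:4])[0]  # 4 bytes: maximum absolute value
    numel = unpack("!I", payload[4:8])[0]  # 4 bytes: number of elements
    ndim = unpack("!h", payload[8:10])[0]  # 2 bytes: number of dimensions
    shape = [unpack("!h", payload[10 + 2 * i : 12 + 2 * i])[0] for i in range(ndim)]
    body = payload[10 + 2 * ndim :]  # 1 byte per parameter

    values = []
    for i in range(numel):
        byte = body[i]
        magnitude = byte & 0x7F  # low 7 bits: the quantization level zeta
        sign = -1.0 if byte & 0x80 else 1.0  # high bit: sign flag
        values.append(sign * magnitude * max_v / (quantization_level - 1))
    return torch.tensor(values).reshape(shape)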
plato/processors/model_randomized_response.py
ADDED
@@ -0,0 +1,34 @@
"""
Implements a Processor for applying local differential privacy using randomized response.
"""

import torch

from plato.config import Config
from plato.processors import model
from plato.utils import unary_encoding


class Processor(model.Processor):
    """
    Implements a Processor for applying local differential privacy using randomized response.
    """

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)

    def _process_layer(self, layer: torch.Tensor) -> torch.Tensor:
        if Config().algorithm.epsilon is None:
            return layer

        epsilon = Config().algorithm.epsilon

        # Apply randomized response as the local differential privacy mechanism
        layer = layer.detach().cpu().numpy()

        layer = unary_encoding.encode(layer)
        layer = unary_encoding.randomize(layer, epsilon)

        layer = torch.tensor(layer, dtype=torch.float32)

        return layer
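Note: plato/utils/unary_encoding.py (+47 above) is not expanded in this diff. As a rough sketch of the mechanism the processor relies on, symmetric unary-encoding randomized response thresholds each value into a bit and then flips each bit with a probability controlled by epsilon; the exact thresholds and probabilities used by the package may differ.

import numpy as np


def encode(values: np.ndarray, threshold: float = 0.0) -> np.ndarray:
    # Unary-encode values into bits by thresholding (illustrative).
    return (values > threshold).astype(np.float32)


def randomize(bits: np.ndarray, epsilon: float) -> np.ndarray:
    # Keep each bit with probability p = e^(eps/2) / (e^(eps/2) + 1), flip otherwise.
    p = np.exp(epsilon / 2) / (np.exp(epsilon / 2) + 1)
    keep = np.random.rand(*bits.shape) < p
    return np.where(keep, bits, 1 - bits)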
plato/processors/outbound_feature_ndarrays.py
ADDED
@@ -0,0 +1,38 @@
"""
Implements a Processor for converting MistNet features from PyTorch tensors to numpy ndarrays.
"""

import logging
from typing import Any

from plato.processors import base


class Processor(base.Processor):
    """
    Implements a Processor for converting MistNet features from PyTorch tensors
    to numpy ndarrays. This is used only by MistNet clients at this time.
    """

    def __init__(self, client_id=None, **kwargs) -> None:
        super().__init__(**kwargs)

        self.client_id = client_id

    def process(self, data: Any) -> Any:
        """
        Converts MistNet features from PyTorch tensors to numpy ndarrays.
        """
        feature_dataset = []

        for logit, target in data:
            feature_dataset.append(
                (logit.detach().cpu().numpy(), target.detach().cpu().numpy())
            )

        logging.info(
            "[Client #%d] Features converted from PyTorch tensors to ndarrays.",
            self.client_id,
        )

        return feature_dataset
plato/processors/pipeline.py
ADDED
@@ -0,0 +1,26 @@
"""
Implements a pipeline of processors for data payloads to pass through.
"""

from typing import Any, List

from plato.processors import base


class Processor(base.Processor):
    """
    Pipelines a list of Processors from the configuration file.
    """

    def __init__(self, processors: List[base.Processor], *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.processors = processors

    def process(self, data: Any) -> Any:
        """
        Implements a pipeline of Processors for data payloads.
        """
        for processor in self.processors:
            data = processor.process(data)

        return data
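Note: a minimal usage sketch of the pipeline, assuming the quantize and compress processors construct with their defaults; the toy payload stands in for a trained model's weights.

import torch

from plato.processors import compress, model_quantize, pipeline

payload = torch.nn.Linear(4, 2).state_dict()  # stand-in for a model payload
send_pipeline = pipeline.Processor([model_quantize.Processor(), compress.Processor()])
outbound_payload = send_pipeline.process(payload)  # runs the processors in list order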
plato/processors/registry.py
ADDED
@@ -0,0 +1,124 @@
"""
This registry for Processors contains framework-specific implementations of
Processors for data payloads.

Having a registry of all available classes is convenient for retrieving an instance
based on a configuration at run-time.
"""

import logging
from typing import Tuple

from plato.config import Config
from plato.processors import (
    base,
    compress,
    decompress,
    feature_dequantize,
    feature_gaussian,
    feature_laplace,
    feature_quantize,
    feature_randomized_response,
    feature_unbatch,
    inbound_feature_tensors,
    model_compress,
    model_decompress,
    model_deepcopy,
    model_dequantize,
    model_dequantize_qsgd,
    model_quantize,
    model_quantize_qsgd,
    model_randomized_response,
    outbound_feature_ndarrays,
    pipeline,
    structured_pruning,
    unstructured_pruning,
)

registered_processors = {
    "base": base.Processor,
    "compress": compress.Processor,
    "decompress": decompress.Processor,
    "feature_randomized_response": feature_randomized_response.Processor,
    "feature_gaussian": feature_gaussian.Processor,
    "feature_laplace": feature_laplace.Processor,
    "feature_quantize": feature_quantize.Processor,
    "feature_dequantize": feature_dequantize.Processor,
    "feature_unbatch": feature_unbatch.Processor,
    "inbound_feature_tensors": inbound_feature_tensors.Processor,
    "outbound_feature_ndarrays": outbound_feature_ndarrays.Processor,
    "model_deepcopy": model_deepcopy.Processor,
    "model_quantize": model_quantize.Processor,
    "model_dequantize": model_dequantize.Processor,
    "model_compress": model_compress.Processor,
    "model_quantize_qsgd": model_quantize_qsgd.Processor,
    "model_decompress": model_decompress.Processor,
    "model_dequantize_qsgd": model_dequantize_qsgd.Processor,
    "model_randomized_response": model_randomized_response.Processor,
    "structured_pruning": structured_pruning.Processor,
    "unstructured_pruning": unstructured_pruning.Processor,
}


def register_he_processors():
    """Register homomorphic encryption processors if needed."""


def get(
    user: str, processor_kwargs=None, **kwargs
) -> Tuple[pipeline.Processor, pipeline.Processor]:
    """Get an instance of the processor."""
    outbound_processors = []
    inbound_processors = []

    assert user in ("Server", "Client")

    if user == "Server":
        config = Config().server
    else:
        config = Config().clients

    if hasattr(config, "outbound_processors") and isinstance(
        config.outbound_processors, list
    ):
        outbound_processors = config.outbound_processors

    if hasattr(config, "inbound_processors") and isinstance(
        config.inbound_processors, list
    ):
        inbound_processors = config.inbound_processors

    for processor in outbound_processors:
        logging.info("%s: Using Processor for sending payload: %s", user, processor)
    for processor in inbound_processors:
        logging.info("%s: Using Processor for receiving payload: %s", user, processor)

    # Check if HE processors are needed based on the server configuration
    if hasattr(config, "type") and config.type == "fedavg_he":
        # A FedAvg server with homomorphic encryption needs to import tenseal,
        # which is not available on all platforms, such as macOS
        from plato.processors import model_decrypt, model_encrypt

        registered_processors.update(
            {
                "model_encrypt": model_encrypt.Processor,
                "model_decrypt": model_decrypt.Processor,
            }
        )

        logging.info("%s: Using homomorphic encryption processors.", user)

    def map_f(name):
        if processor_kwargs is not None and name in processor_kwargs:
            this_kwargs = {**kwargs, **(processor_kwargs[name])}
        else:
            this_kwargs = kwargs

        return registered_processors[name](name=name, **this_kwargs)

    outbound_processors = list(map(map_f, outbound_processors))
    inbound_processors = list(map(map_f, inbound_processors))

    return pipeline.Processor(outbound_processors), pipeline.Processor(
        inbound_processors
    )
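Note: which names in registered_processors get instantiated is driven by the outbound_processors and inbound_processors lists under the clients or server section of the configuration file. Below is a hedged sketch of how a client might call get(); `trainer` and `payload` are assumed to exist in the caller's scope, and forwarding trainer as a keyword to every constructed processor is an assumption based on map_f above.

from plato.processors import registry as processor_registry

# map_f forwards extra keyword arguments to every processor it constructs.
outbound, inbound = processor_registry.get("Client", trainer=trainer)

payload = outbound.process(payload)  # applied before sending to the server
payload = inbound.process(payload)   # applied after receiving from the server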
plato/processors/structured_pruning.py
ADDED
@@ -0,0 +1,57 @@
"""
Processor for structured pruning of model weights.
"""

import logging

import torch
import torch.nn.utils.prune as prune

from plato.processors import model


class Processor(model.Processor):
    """
    A processor for the structured pruning of model weights.
    """

    def __init__(
        self, pruning_method="ln", amount=0.2, norm=1, dim=-1, **kwargs
    ) -> None:
        super().__init__(**kwargs)

        self.pruning_method = pruning_method
        self.amount = amount
        self.norm = norm
        self.dim = dim
        self.model = None

    def process(self, data):
        """
        Processes structured pruning of model weights layer by layer.
        """
        self.model = self.trainer.model

        for _, module in self.model.named_modules():
            if isinstance(module, torch.nn.Conv2d) or isinstance(
                module, torch.nn.Linear
            ):
                if self.pruning_method == "ln":
                    prune.ln_structured(
                        module, "weight", self.amount, n=self.norm, dim=self.dim
                    )
                elif self.pruning_method == "random":
                    prune.random_structured(module, "weight", self.amount, dim=self.dim)
                prune.remove(module, "weight")

        output = self.model.cpu().state_dict()

        if self.client_id is None:
            logging.info("[Server #%d] Structured pruning applied.", self.server_id)
        else:
            logging.info("[Client #%d] Structured pruning applied.", self.client_id)

        return output

    def _process_layer(self, layer: torch.Tensor) -> torch.Tensor:
        """No need to process individual layers of the model."""
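Note: a standalone sketch of the torch.nn.utils.prune calls used above, applied to a toy layer with the processor defaults (amount=0.2, n=1, dim=-1).

import torch
import torch.nn.utils.prune as prune

layer = torch.nn.Linear(8, 4)

# Zero out entire columns of `weight`, ranked by L1 norm, as the "ln" branch does
prune.ln_structured(layer, "weight", amount=0.2, n=1, dim=-1)
prune.remove(layer, "weight")  # bake the pruning mask into the tensor, as process() does

print(float((layer.weight == 0).float().mean()))  # fraction of zeroed weights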
plato/processors/unstructured_pruning.py
ADDED
@@ -0,0 +1,73 @@
"""
Implements a Processor for global unstructured pruning of model weights.
"""

import logging
from typing import Any

import torch
import torch.nn.utils.prune as prune

from plato.processors import model


class Processor(model.Processor):
    """
    Implements a Processor for global unstructured pruning of model weights.
    """

    def __init__(
        self,
        parameters_to_prune=None,
        pruning_method=prune.L1Unstructured,
        amount=0.2,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)

        self.parameters_to_prune = parameters_to_prune
        self.pruning_method = pruning_method
        self.amount = amount
        self.model = None

    def process(self, data: Any) -> Any:
        """
        Processes global unstructured pruning on model weights.
        """

        self.model = self.trainer.model

        if self.parameters_to_prune is None:
            self.parameters_to_prune = []
            for _, module in self.model.named_modules():
                if isinstance(module, torch.nn.Conv2d) or isinstance(
                    module, torch.nn.Linear
                ):
                    self.parameters_to_prune.append((module, "weight"))

        prune.global_unstructured(
            self.parameters_to_prune,
            pruning_method=self.pruning_method,
            amount=self.amount,
        )

        for module, name in self.parameters_to_prune:
            prune.remove(module, name)

        output = self.model.cpu().state_dict()

        if self.client_id is None:
            logging.info(
                "[Server #%d] Global unstructured pruning applied.",
                self.server_id,
            )
        else:
            logging.info(
                "[Client #%d] Global unstructured pruning applied.",
                self.client_id,
            )

        return output

    def _process_layer(self, layer: torch.Tensor) -> torch.Tensor:
        return layer
plato/samplers/__init__.py
File without changes
plato/samplers/all_inclusive.py
ADDED
@@ -0,0 +1,41 @@
"""
Samples all the data from a dataset. Applicable in cases where the dataset comes from
local sources only. Used by the Federated EMNIST dataset and the MistNet server.
"""

import random

from plato.samplers import base
from plato.config import Config


class Sampler(base.Sampler):
    """Create a data sampler that samples all the data in the dataset.
    Used by the MistNet server.
    """

    def __init__(self, datasource, client_id=0, testing=False):
        super().__init__()
        self.client_id = client_id

        if testing:
            all_inclusive = range(len(datasource.get_test_set()))
            if hasattr(Config().data, "testset_size"):
                self.data_samples = random.sample(
                    all_inclusive, Config().data.testset_size
                )
            else:
                self.data_samples = all_inclusive
        else:
            self.data_samples = range(len(datasource.get_train_set()))

    def get(self):
        import torch

        gen = torch.Generator()
        gen.manual_seed(self.random_seed)
        return torch.utils.data.SubsetRandomSampler(self.data_samples, generator=gen)

    def num_samples(self):
        """Returns the length of the dataset after sampling."""
        return len(self.data_samples)
plato/samplers/base.py
ADDED
@@ -0,0 +1,31 @@
"""
Base class for sampling data so that a dataset can be divided across the clients.
"""

import os
from abc import abstractmethod

from plato.config import Config


class Sampler:
    """Base class for data samplers so that the dataset is divided into
    partitions across the clients."""

    def __init__(self):
        if hasattr(Config().data, "random_seed"):
            # Keeping the random seed the same across the clients
            # so that the experiments are reproducible
            self.random_seed = Config().data.random_seed
        else:
            # The random seed will be different across different
            # runs if it is not provided.
            self.random_seed = os.getpid()

    @abstractmethod
    def get(self):
        """Obtains an instance of the sampler."""

    @abstractmethod
    def num_samples(self):
        """Returns the length of the dataset after sampling."""
plato/samplers/dirichlet.py
ADDED
@@ -0,0 +1,81 @@
"""
Samples data from a dataset, biased across labels according to the Dirichlet distribution.
"""

import numpy as np
import torch
from torch.utils.data import WeightedRandomSampler, SubsetRandomSampler
from plato.config import Config

from plato.samplers import base


class Sampler(base.Sampler):
    """Create a data sampler for each client to use a divided partition of the
    dataset, biased across labels according to the Dirichlet distribution."""

    def __init__(self, datasource, client_id, testing):
        super().__init__()

        # Different clients should have a different bias across the labels & partition size
        np.random.seed(self.random_seed * int(client_id))

        # Concentration parameter to be used in the Dirichlet distribution
        concentration = (
            Config().data.concentration
            if hasattr(Config().data, "concentration")
            else 1.0
        )

        if testing:
            target_list = datasource.get_test_set().targets
        else:
            # The list of labels (targets) for all the examples
            target_list = datasource.targets()

        class_list = datasource.classes()

        target_proportions = np.random.dirichlet(
            np.repeat(concentration, len(class_list))
        )

        if np.isnan(np.sum(target_proportions)):
            target_proportions = np.repeat(0, len(class_list))
            target_proportions[np.random.randint(0, len(class_list))] = 1

        self.sample_weights = target_proportions[target_list]

    def num_samples(self) -> int:
        """Returns the length of the dataset after sampling."""
        sampled_size = Config().data.partition_size

        # Variable partition size across clients
        if hasattr(Config().data, "partition_distribution"):
            dist = Config().data.partition_distribution

            if dist.distribution.lower() == "uniform":
                sampled_size *= np.random.uniform(dist.low, dist.high)

            if dist.distribution.lower() == "normal":
                sampled_size *= np.random.normal(dist.mean, dist.high)

            sampled_size = int(sampled_size)

        return sampled_size

    def get(self):
        """Obtains an instance of the sampler."""
        gen = torch.Generator()
        gen.manual_seed(self.random_seed)

        # Samples without replacement using the sample weights
        subset_indices = list(
            WeightedRandomSampler(
                weights=self.sample_weights,
                num_samples=self.num_samples(),
                replacement=False,
                generator=gen,
            )
        )

        return SubsetRandomSampler(subset_indices, generator=gen)
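Note: a hedged usage sketch of the Dirichlet sampler with a PyTorch DataLoader. `datasource` stands for any Plato datasource exposing get_train_set(), targets(), and classes() (e.g., the CIFAR-10 datasource listed above), and Config().data.partition_size (optionally concentration) is assumed to be set in the configuration file.

from torch.utils.data import DataLoader

from plato.samplers import dirichlet

# `datasource` is assumed to be constructed elsewhere from a Plato config
sampler = dirichlet.Sampler(datasource, client_id=1, testing=False)
loader = DataLoader(
    datasource.get_train_set(),
    batch_size=32,
    sampler=sampler.get(),  # SubsetRandomSampler over the Dirichlet-weighted draw
)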