plato-learn 1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- plato/__init__.py +1 -0
- plato/algorithms/__init__.py +0 -0
- plato/algorithms/base.py +45 -0
- plato/algorithms/fedavg.py +48 -0
- plato/algorithms/fedavg_gan.py +79 -0
- plato/algorithms/fedavg_personalized.py +48 -0
- plato/algorithms/mistnet.py +52 -0
- plato/algorithms/registry.py +39 -0
- plato/algorithms/split_learning.py +89 -0
- plato/callbacks/__init__.py +0 -0
- plato/callbacks/client.py +56 -0
- plato/callbacks/handler.py +78 -0
- plato/callbacks/server.py +139 -0
- plato/callbacks/trainer.py +124 -0
- plato/client.py +67 -0
- plato/clients/__init__.py +0 -0
- plato/clients/base.py +467 -0
- plato/clients/edge.py +103 -0
- plato/clients/fedavg_personalized.py +40 -0
- plato/clients/mistnet.py +49 -0
- plato/clients/registry.py +43 -0
- plato/clients/self_supervised_learning.py +51 -0
- plato/clients/simple.py +218 -0
- plato/clients/split_learning.py +150 -0
- plato/config.py +339 -0
- plato/datasources/__init__.py +0 -0
- plato/datasources/base.py +123 -0
- plato/datasources/celeba.py +150 -0
- plato/datasources/cifar10.py +87 -0
- plato/datasources/cifar100.py +61 -0
- plato/datasources/cinic10.py +62 -0
- plato/datasources/coco.py +119 -0
- plato/datasources/datalib/__init__.py +0 -0
- plato/datasources/datalib/audio_extraction_tools.py +137 -0
- plato/datasources/datalib/data_utils.py +124 -0
- plato/datasources/datalib/flickr30kE_utils.py +336 -0
- plato/datasources/datalib/frames_extraction_tools.py +254 -0
- plato/datasources/datalib/gym_utils/__init__.py +0 -0
- plato/datasources/datalib/gym_utils/gym_trim.py +189 -0
- plato/datasources/datalib/modality_data_anntation_tools.py +163 -0
- plato/datasources/datalib/modality_extraction_base.py +59 -0
- plato/datasources/datalib/parse_datasets.py +212 -0
- plato/datasources/datalib/refer_utils/__init__.py +0 -0
- plato/datasources/datalib/refer_utils/referitgame_utils.py +237 -0
- plato/datasources/datalib/tiny_data_tools.py +81 -0
- plato/datasources/datalib/video_transform.py +79 -0
- plato/datasources/emnist.py +64 -0
- plato/datasources/fashion_mnist.py +41 -0
- plato/datasources/feature.py +24 -0
- plato/datasources/feature_dataset.py +15 -0
- plato/datasources/femnist.py +141 -0
- plato/datasources/flickr30k_entities.py +362 -0
- plato/datasources/gym.py +431 -0
- plato/datasources/huggingface.py +165 -0
- plato/datasources/kinetics.py +568 -0
- plato/datasources/mnist.py +44 -0
- plato/datasources/multimodal_base.py +328 -0
- plato/datasources/pascal_voc.py +56 -0
- plato/datasources/purchase.py +94 -0
- plato/datasources/qoenflx.py +127 -0
- plato/datasources/referitgame.py +330 -0
- plato/datasources/registry.py +119 -0
- plato/datasources/self_supervised_learning.py +98 -0
- plato/datasources/stl10.py +103 -0
- plato/datasources/texas.py +94 -0
- plato/datasources/tiny_imagenet.py +64 -0
- plato/datasources/yolov8.py +85 -0
- plato/models/__init__.py +0 -0
- plato/models/cnn_encoder.py +103 -0
- plato/models/dcgan.py +116 -0
- plato/models/general_multilayer.py +254 -0
- plato/models/huggingface.py +27 -0
- plato/models/lenet5.py +113 -0
- plato/models/multilayer.py +90 -0
- plato/models/multimodal/__init__.py +0 -0
- plato/models/multimodal/base_net.py +91 -0
- plato/models/multimodal/blending.py +142 -0
- plato/models/multimodal/fc_net.py +77 -0
- plato/models/multimodal/fusion_net.py +78 -0
- plato/models/multimodal/multimodal_module.py +152 -0
- plato/models/registry.py +99 -0
- plato/models/resnet.py +190 -0
- plato/models/torch_hub.py +19 -0
- plato/models/vgg.py +113 -0
- plato/models/vit.py +166 -0
- plato/models/yolov8.py +22 -0
- plato/processors/__init__.py +0 -0
- plato/processors/base.py +35 -0
- plato/processors/compress.py +46 -0
- plato/processors/decompress.py +48 -0
- plato/processors/feature.py +51 -0
- plato/processors/feature_additive_noise.py +48 -0
- plato/processors/feature_dequantize.py +34 -0
- plato/processors/feature_gaussian.py +17 -0
- plato/processors/feature_laplace.py +15 -0
- plato/processors/feature_quantize.py +34 -0
- plato/processors/feature_randomized_response.py +50 -0
- plato/processors/feature_unbatch.py +39 -0
- plato/processors/inbound_feature_tensors.py +39 -0
- plato/processors/model.py +55 -0
- plato/processors/model_compress.py +34 -0
- plato/processors/model_decompress.py +37 -0
- plato/processors/model_decrypt.py +41 -0
- plato/processors/model_deepcopy.py +21 -0
- plato/processors/model_dequantize.py +18 -0
- plato/processors/model_dequantize_qsgd.py +61 -0
- plato/processors/model_encrypt.py +43 -0
- plato/processors/model_quantize.py +18 -0
- plato/processors/model_quantize_qsgd.py +82 -0
- plato/processors/model_randomized_response.py +34 -0
- plato/processors/outbound_feature_ndarrays.py +38 -0
- plato/processors/pipeline.py +26 -0
- plato/processors/registry.py +124 -0
- plato/processors/structured_pruning.py +57 -0
- plato/processors/unstructured_pruning.py +73 -0
- plato/samplers/__init__.py +0 -0
- plato/samplers/all_inclusive.py +41 -0
- plato/samplers/base.py +31 -0
- plato/samplers/dirichlet.py +81 -0
- plato/samplers/distribution_noniid.py +132 -0
- plato/samplers/iid.py +53 -0
- plato/samplers/label_quantity_noniid.py +119 -0
- plato/samplers/mixed.py +44 -0
- plato/samplers/mixed_label_quantity_noniid.py +128 -0
- plato/samplers/modality_iid.py +42 -0
- plato/samplers/modality_quantity_noniid.py +56 -0
- plato/samplers/orthogonal.py +99 -0
- plato/samplers/registry.py +66 -0
- plato/samplers/sample_quantity_noniid.py +123 -0
- plato/samplers/sampler_utils.py +190 -0
- plato/servers/__init__.py +0 -0
- plato/servers/base.py +1395 -0
- plato/servers/fedavg.py +281 -0
- plato/servers/fedavg_cs.py +335 -0
- plato/servers/fedavg_gan.py +74 -0
- plato/servers/fedavg_he.py +106 -0
- plato/servers/fedavg_personalized.py +57 -0
- plato/servers/mistnet.py +67 -0
- plato/servers/registry.py +52 -0
- plato/servers/split_learning.py +109 -0
- plato/trainers/__init__.py +0 -0
- plato/trainers/base.py +99 -0
- plato/trainers/basic.py +649 -0
- plato/trainers/diff_privacy.py +178 -0
- plato/trainers/gan.py +330 -0
- plato/trainers/huggingface.py +173 -0
- plato/trainers/loss_criterion.py +70 -0
- plato/trainers/lr_schedulers.py +252 -0
- plato/trainers/optimizers.py +53 -0
- plato/trainers/pascal_voc.py +80 -0
- plato/trainers/registry.py +44 -0
- plato/trainers/self_supervised_learning.py +302 -0
- plato/trainers/split_learning.py +305 -0
- plato/trainers/tracking.py +96 -0
- plato/trainers/yolov8.py +41 -0
- plato/utils/__init__.py +0 -0
- plato/utils/count_parameters.py +30 -0
- plato/utils/csv_processor.py +26 -0
- plato/utils/data_loaders.py +148 -0
- plato/utils/decorators.py +24 -0
- plato/utils/fonts.py +23 -0
- plato/utils/homo_enc.py +187 -0
- plato/utils/reinforcement_learning/__init__.py +0 -0
- plato/utils/reinforcement_learning/policies/__init__.py +0 -0
- plato/utils/reinforcement_learning/policies/base.py +161 -0
- plato/utils/reinforcement_learning/policies/ddpg.py +75 -0
- plato/utils/reinforcement_learning/policies/registry.py +32 -0
- plato/utils/reinforcement_learning/policies/sac.py +343 -0
- plato/utils/reinforcement_learning/policies/td3.py +485 -0
- plato/utils/reinforcement_learning/rl_agent.py +142 -0
- plato/utils/reinforcement_learning/rl_server.py +113 -0
- plato/utils/rl_env.py +154 -0
- plato/utils/s3.py +141 -0
- plato/utils/trainer_utils.py +21 -0
- plato/utils/unary_encoding.py +47 -0
- plato_learn-1.1.dist-info/METADATA +35 -0
- plato_learn-1.1.dist-info/RECORD +179 -0
- plato_learn-1.1.dist-info/WHEEL +4 -0
- plato_learn-1.1.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,48 @@
|
|
1
|
+
"""
|
2
|
+
Implements a Processor for applying local differential privacy using additive noise mechanism.
|
3
|
+
"""
|
4
|
+
|
5
|
+
import logging
|
6
|
+
from typing import Any
|
7
|
+
|
8
|
+
import numpy
|
9
|
+
|
10
|
+
from plato.processors import feature
|
11
|
+
|
12
|
+
|
13
|
+
class Processor(feature.Processor):
    """
    Implements a Processor for applying local differential privacy using additive noise mechanism.

    Each feature batch ``logits`` is replaced by a sample drawn from the selected
    NumPy distribution with ``loc=logits`` (i.e. ``logits`` plus zero-mean noise);
    ``targets`` are passed through unchanged.
    """

    # Supported mechanisms, mapped to the NumPy sampler that draws the noisy
    # features. Both samplers are invoked as ``sampler(loc, scale)``.
    methods = {
        "gaussian": numpy.random.normal,
        "laplace": numpy.random.laplace,
    }

    def __init__(self, method="", scale=None, **kwargs) -> None:
        """
        :param method: Name of the noise mechanism; must be a key of ``methods``.
        :param scale: Scale argument forwarded unchanged to the NumPy sampler.
        :raises ValueError: If ``method`` is not a supported mechanism. This is
            checked here so a misconfiguration fails at construction time rather
            than on the first call to ``process``.
        """
        if method not in self.methods:
            raise ValueError(
                f"Unsupported additive noise mechanism {method!r}; "
                f"expected one of {sorted(self.methods)}."
            )

        self._method = method
        # Resolve the sampler once instead of looking it up for every batch.
        sampler = self.methods[method]

        def func(logits, targets):
            return sampler(logits, scale), targets

        super().__init__(method=func, **kwargs)

    def process(self, data: Any) -> Any:
        """
        Applies the configured additive-noise mechanism as the local
        differential privacy mechanism.

        Delegates to the base class, which presumably applies ``func`` to each
        ``(logits, targets)`` pair in ``data`` (TODO confirm in feature.Processor).
        """
        output = super().process(data)

        logging.info(
            "[Client #%d] Local differential privacy (using the %s mechanism) applied.",
            self.client_id,
            self._method,
        )

        return output
|
@@ -0,0 +1,34 @@
|
|
1
|
+
"""
|
2
|
+
Implements a Processor for applying dequantization to MistNet PyTorch features.
|
3
|
+
"""
|
4
|
+
|
5
|
+
import logging
|
6
|
+
from typing import Any
|
7
|
+
|
8
|
+
import torch
|
9
|
+
|
10
|
+
from plato.processors import base
|
11
|
+
|
12
|
+
|
13
|
+
class Processor(base.Processor):
    """
    Converts quantized MistNet PyTorch features back to regular tensors.
    """

    def __init__(self, server_id=None, **kwargs) -> None:
        super().__init__(**kwargs)

        # Only used to tag the log message emitted by ``process``.
        self.server_id = server_id

    def process(self, data: Any) -> Any:
        """
        Returns a list of ``(torch.dequantize(logit), target)`` pairs, one per
        ``(logit, target)`` pair in ``data``; targets are left untouched.
        """
        dequantized = [(torch.dequantize(logit), target) for logit, target in data]

        logging.info("[Server #%d] Dequantized features.", self.server_id)

        return dequantized
|
@@ -0,0 +1,17 @@
|
|
1
|
+
"""
|
2
|
+
Implements a Processor for applying local differential privacy using the Gaussian mechanism.
|
3
|
+
"""
|
4
|
+
|
5
|
+
import math
|
6
|
+
|
7
|
+
from plato.processors import feature_additive_noise
|
8
|
+
|
9
|
+
|
10
|
+
class Processor(feature_additive_noise.Processor):
    """
    Implements a Processor for applying local differential privacy using the Gaussian mechanism.

    The noise standard deviation follows the classic Gaussian mechanism:
    ``sigma = sqrt(2 * ln(1.25 / delta)) * sensitivity / epsilon``.
    """

    def __init__(self, epsilon=None, delta=None, sensitivity=None, **kwargs) -> None:
        """
        :param epsilon: Privacy budget (expected to be positive).
        :param delta: Failure probability of the (epsilon, delta) guarantee
            (expected in (0, 1)).
        :param sensitivity: Sensitivity of the features being privatized.
        """
        # ``numpy.random.normal`` interprets ``scale`` as the *standard deviation*.
        # The expression 2 * ln(1.25 / delta) * sensitivity**2 / epsilon**2 is the
        # variance sigma**2; passing it as the scale would apply the wrong amount
        # of noise, so take the square root of the variance here.
        scale = math.sqrt(2 * math.log(1.25 / delta)) * sensitivity / epsilon
        super().__init__(method="gaussian", scale=scale, **kwargs)
|
@@ -0,0 +1,15 @@
|
|
1
|
+
"""
|
2
|
+
Implements a Processor for applying local differential privacy using the Laplace mechanism.
|
3
|
+
"""
|
4
|
+
|
5
|
+
from plato.processors import feature_additive_noise
|
6
|
+
|
7
|
+
|
8
|
+
class Processor(feature_additive_noise.Processor):
    """
    Local differential privacy through the Laplace mechanism.

    The noise scale handed to the additive-noise base class is
    ``sensitivity / epsilon``.
    """

    def __init__(self, epsilon=None, sensitivity=None, **kwargs) -> None:
        laplace_scale = sensitivity / epsilon
        super().__init__(scale=laplace_scale, method="laplace", **kwargs)
|
@@ -0,0 +1,34 @@
|
|
1
|
+
"""
|
2
|
+
Implements a Processor for applying quantization to MistNet PyTorch features.
|
3
|
+
"""
|
4
|
+
|
5
|
+
import logging
|
6
|
+
from typing import Any
|
7
|
+
|
8
|
+
import torch
|
9
|
+
|
10
|
+
from plato.processors import feature
|
11
|
+
|
12
|
+
|
13
|
+
class Processor(feature.Processor):
    """
    Quantizes MistNet PyTorch features with ``torch.quantize_per_tensor``.
    """

    def __init__(self, scale=0.1, zero_point=10, dtype=torch.quint8, **kwargs) -> None:
        def quantize(logits, targets):
            # Only the features are quantized; targets pass through as-is.
            return torch.quantize_per_tensor(logits, scale, zero_point, dtype), targets

        # Features stay as PyTorch tensors (no NumPy conversion) since
        # quantize_per_tensor operates on tensors.
        super().__init__(method=quantize, use_numpy=False, **kwargs)

    def process(self, data: Any) -> Any:
        """
        Runs the base-class feature processing with the quantization function
        and logs completion.
        """
        quantized = super().process(data)
        logging.info("[Client #%d] Quantization applied.", self.client_id)
        return quantized
|
@@ -0,0 +1,50 @@
|
|
1
|
+
"""
|
2
|
+
Implements a Processor for applying local differential privacy using randomized response.
|
3
|
+
"""
|
4
|
+
|
5
|
+
import logging
|
6
|
+
from typing import Any
|
7
|
+
|
8
|
+
from plato.config import Config
|
9
|
+
from plato.processors import feature
|
10
|
+
from plato.utils import unary_encoding
|
11
|
+
|
12
|
+
|
13
|
+
class Processor(feature.Processor):
    """
    Implements a Processor for applying local differential privacy using randomized response.

    Features are always unary-encoded. If ``Config().algorithm.epsilon`` is not
    None, the encoded features are then randomized: through the trainer's own
    ``randomize(logits, targets, epsilon)`` when it provides a callable one,
    otherwise through ``unary_encoding.randomize(logits, epsilon)``.
    """

    def __init__(self, **kwargs) -> None:
        def func(logits, targets):
            logits = unary_encoding.encode(logits)

            epsilon = Config().algorithm.epsilon
            if epsilon is None:
                # No privacy budget configured: only the unary encoding is applied.
                return logits, targets

            # ``self.trainer`` is presumably set by the base class (TODO confirm).
            trainer_randomize = getattr(self.trainer, "randomize", None)
            if callable(trainer_randomize):
                logits = trainer_randomize(logits, targets, epsilon)
            else:
                logits = unary_encoding.randomize(logits, epsilon)

            return logits, targets

        super().__init__(method=func, **kwargs)

    def process(self, data: Any) -> Any:
        """
        Implements a Processor for applying randomized response as the
        local differential privacy mechanism.

        The log message reports whether randomization actually took place, so
        that runs without a configured epsilon are not reported as private.
        """

        output = super().process(data)

        if Config().algorithm.epsilon is None:
            logging.info(
                "[Client #%d] Features unary-encoded; randomized response skipped "
                "because no epsilon is configured.",
                self.client_id,
            )
        else:
            logging.info(
                "[Client #%d] Local differential privacy (using randomized response) applied.",
                self.client_id,
            )

        return output
|
@@ -0,0 +1,39 @@
|
|
1
|
+
"""
|
2
|
+
Implements a Processor for unbatching MistNet PyTorch features into the dataset form.
|
3
|
+
"""
|
4
|
+
|
5
|
+
import logging
|
6
|
+
from typing import Any
|
7
|
+
|
8
|
+
import numpy as np
|
9
|
+
|
10
|
+
from plato.processors import base
|
11
|
+
|
12
|
+
|
13
|
+
class Processor(base.Processor):
    """
    Flattens batched MistNet PyTorch features into a per-sample dataset.
    """

    def __init__(self, client_id=None, **kwargs) -> None:
        super().__init__(**kwargs)

        # Only used to tag the log message emitted by ``process``.
        self.client_id = client_id

    def process(self, data: Any) -> Any:
        """
        Splits every ``(logits, targets)`` batch in ``data`` along its first
        dimension and returns a flat list of cloned ``(logit, target)`` pairs.
        """
        samples = []

        for batch_logits, batch_targets in data:
            # Clone each row so the returned samples do not alias the batch tensors.
            samples.extend(
                (batch_logits[idx].clone(), batch_targets[idx].clone())
                for idx in range(batch_logits.shape[0])
            )

        logging.info(
            "[Client #%d] Features extracted from %s examples.",
            self.client_id,
            len(samples),
        )

        return samples
|
@@ -0,0 +1,39 @@
|
|
1
|
+
"""
|
2
|
+
Implements a Processor for converting MistNet features from numpy ndarrays to PyTorch tensors.
|
3
|
+
"""
|
4
|
+
|
5
|
+
import logging
|
6
|
+
from typing import Any
|
7
|
+
|
8
|
+
import torch
|
9
|
+
|
10
|
+
from plato.processors import base
|
11
|
+
|
12
|
+
|
13
|
+
class Processor(base.Processor):
    """
    Turns MistNet features received as NumPy ndarrays into PyTorch tensors.
    """

    def __init__(self, server_id=None, **kwargs) -> None:
        super().__init__(**kwargs)

        # Only used to tag the log message emitted by ``process``.
        self.server_id = server_id

    def process(self, data: Any) -> Any:
        """
        Returns a list of ``(features, label)`` tensor pairs, one per pair in ``data``.
        """
        # torch.as_tensor (unlike torch.tensor) avoids copying the input data
        # when it can, see https://pytorch.org/docs/stable/generated/torch.tensor.html
        tensors = [
            (torch.as_tensor(features), torch.as_tensor(label))
            for features, label in data
        ]

        logging.info(
            "[Server #%d] Features converted from ndarrays to PyTorch tensors.",
            self.server_id,
        )

        return tensors
|
@@ -0,0 +1,55 @@
|
|
1
|
+
"""
|
2
|
+
Base processor for processing PyTorch models.
|
3
|
+
"""
|
4
|
+
|
5
|
+
import logging
|
6
|
+
import pickle
|
7
|
+
import sys
|
8
|
+
from typing import OrderedDict
|
9
|
+
|
10
|
+
import torch
|
11
|
+
|
12
|
+
from plato.processors import base
|
13
|
+
|
14
|
+
|
15
|
+
class Processor(base.Processor):
    """Base processor for processing PyTorch models."""

    def __init__(self, client_id=None, server_id=None, **kwargs) -> None:
        super().__init__(**kwargs)
        # Log messages are tagged with the client ID when it is set, and with
        # the server ID otherwise.
        self.client_id = client_id
        self.server_id = server_id

    def process(self, data: OrderedDict) -> OrderedDict:
        """
        Processes PyTorch model parameters.

        ``data`` is a state_dict of a PyTorch model. Every entry is passed
        through ``_process_layer`` and the results are returned in a new
        ordered dict with the same keys; ``data`` itself is not modified here.
        """
        # Serializing the whole state_dict is only needed for the size report,
        # so skip the (potentially expensive) pickling when INFO messages
        # would be discarded anyway.
        report_size = logging.getLogger().isEnabledFor(logging.INFO)

        # Measured before processing in case a subclass mutates layers in place.
        old_data_size = len(pickle.dumps(data)) if report_size else 0

        new_data = OrderedDict()
        for layer_name, layer_params in data.items():
            new_data[layer_name] = self._process_layer(layer_params)

        if report_size:
            # len() gives the serialized payload size; sys.getsizeof would also
            # count the bytes-object header.
            new_data_size = len(pickle.dumps(new_data))

            if self.client_id is None:
                role, role_id = "Server", self.server_id
            else:
                role, role_id = "Client", self.client_id

            logging.info(
                "[%s #%d] Processed the model and changed its size from %.2f MB to %.2f MB.",
                role,
                role_id,
                old_data_size / 1024**2,
                new_data_size / 1024**2,
            )

        return new_data

    def _process_layer(self, layer: torch.Tensor) -> torch.Tensor:
        """Processes an individual layer of the model (identity by default)."""
        return layer
|
@@ -0,0 +1,34 @@
|
|
1
|
+
"""
|
2
|
+
Implements a Processor for compressing model weights.
|
3
|
+
"""
|
4
|
+
|
5
|
+
import logging
|
6
|
+
import pickle
|
7
|
+
from typing import Any
|
8
|
+
|
9
|
+
import zstd
|
10
|
+
|
11
|
+
from plato.processors import model
|
12
|
+
|
13
|
+
|
14
|
+
class Processor(model.Processor):
    """
    Serializes model parameters with ``pickle`` and compresses them with ``zstd``.
    """

    def __init__(self, compression_level=1, **kwargs) -> None:
        super().__init__(**kwargs)

        # Forwarded to zstd.compress on every call to ``process``.
        self.compression_level = compression_level

    def process(self, data: Any) -> Any:
        """Pickles ``data`` and returns the zstd-compressed result."""
        serialized = pickle.dumps(data)
        compressed = zstd.compress(serialized, self.compression_level)

        if self.client_id is not None:
            logging.info("[Client #%d] Compressed model parameters.", self.client_id)
        else:
            logging.info("[Server #%d] Compressed model parameters.", self.server_id)

        return compressed
|
@@ -0,0 +1,37 @@
|
|
1
|
+
"""
|
2
|
+
Implements a Processor for decompressing model weights.
|
3
|
+
"""
|
4
|
+
|
5
|
+
import logging
|
6
|
+
import pickle
|
7
|
+
from typing import Any
|
8
|
+
|
9
|
+
import zstd
|
10
|
+
|
11
|
+
from plato.processors import model
|
12
|
+
|
13
|
+
|
14
|
+
class Processor(model.Processor):
    """
    Implements a Processor for decompressing model parameters.

    ``data`` is zstd-decompressed and the resulting payload is unpickled.
    """

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)

    def process(self, data: Any) -> Any:
        """Decompresses ``data`` with zstd and unpickles the resulting payload."""

        # SECURITY: ``pickle.loads`` can execute arbitrary code embedded in the
        # payload. ``data`` holds model parameters *received* from a peer, so
        # this is only safe when every peer is trusted. If peers may be
        # untrusted, switch to a non-executable serialization format.
        output = pickle.loads(zstd.decompress(data))

        if self.client_id is None:
            logging.info(
                "[Server #%d] Decompressed received model parameters.",
                self.server_id,
            )
        else:
            logging.info(
                "[Client #%d] Decompressed received model parameters.",
                self.client_id,
            )

        return output
|
@@ -0,0 +1,41 @@
|
|
1
|
+
"""
|
2
|
+
A processor that decrypts model weights of MaskCrypt.
|
3
|
+
"""
|
4
|
+
|
5
|
+
from typing import Any
|
6
|
+
|
7
|
+
import torch
|
8
|
+
|
9
|
+
from plato.processors import model
|
10
|
+
from plato.utils import homo_enc
|
11
|
+
|
12
|
+
|
13
|
+
class Processor(model.Processor):
    """
    Deserializes and decrypts CKKS-encrypted model weights (MaskCrypt).
    """

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)

        self.context = homo_enc.get_ckks_context()

        # Shapes and element counts of the local model's tensors are needed to
        # rebuild the decrypted weights. ``self.trainer`` is presumably provided
        # by the base class (TODO confirm).
        # NOTE(review): if ``self.trainer.model`` is an nn.Module, ``.cpu()``
        # moves it to the CPU in place; confirm this side effect is intended.
        state_dict = self.trainer.model.cpu().state_dict()

        self.weight_shapes = {name: tensor.size() for name, tensor in state_dict.items()}
        self.para_nums = {name: torch.numel(tensor) for name, tensor in state_dict.items()}

    def process(self, data: Any) -> Any:
        """Deserializes ``data`` and decrypts it into model weights."""
        encrypted = homo_enc.deserialize_weights(data, self.context)
        return homo_enc.decrypt_weights(encrypted, self.weight_shapes, self.para_nums)
|
@@ -0,0 +1,21 @@
|
|
1
|
+
"""
|
2
|
+
Processor for creating a deep copy of the PyTorch model state_dict.
|
3
|
+
"""
|
4
|
+
|
5
|
+
import copy
|
6
|
+
|
7
|
+
import torch
|
8
|
+
|
9
|
+
from plato.processors import model
|
10
|
+
|
11
|
+
|
12
|
+
class Processor(model.Processor):
    """
    Produces a deep copy of every tensor in a PyTorch model state_dict.
    """

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)

    def _process_layer(self, layer: torch.Tensor) -> torch.Tensor:
        """Returns an independent deep copy of ``layer``."""
        duplicate = copy.deepcopy(layer)
        return duplicate
|
@@ -0,0 +1,18 @@
|
|
1
|
+
"""
|
2
|
+
Implements a Processor for dequantizing model parameters.
|
3
|
+
"""
|
4
|
+
|
5
|
+
import torch
|
6
|
+
|
7
|
+
from plato.processors import model
|
8
|
+
|
9
|
+
|
10
|
+
class Processor(model.Processor):
    """
    Implements a Processor for dequantizing model parameters.

    Every tensor of the received state_dict is cast to ``torch.float32``,
    regardless of its incoming dtype.
    """

    def _process_layer(self, layer: torch.Tensor) -> torch.Tensor:
        """Dequantizes an individual layer of the model by casting it to float32."""

        return layer.to(torch.float32)
|
@@ -0,0 +1,61 @@
|
|
1
|
+
"""
|
2
|
+
Implements a Processor to decompress and dequantize uploaded models.
|
3
|
+
|
4
|
+
In more detail, this processor first decompresses each received parameter.
|
5
|
+
Next, it dequantizes each uploaded parameter at the given quantization level.
|
6
|
+
Hence, 8-bit received parameters can be converted into 32-bit parameters.
|
7
|
+
|
8
|
+
Reference:
|
9
|
+
|
10
|
+
Alistarh, D., Grubic, D., Li, J., Tomioka, R., & Vojnovic, M. (2017).
|
11
|
+
"QSGD: Communication-efficient SGD via gradient quantization and encoding."
|
12
|
+
Advances in neural information processing systems.
|
13
|
+
|
14
|
+
https://proceedings.neurips.cc/paper/2017/file/6c340f25839e6acdc73414517203f5f0-Paper.pdf
|
15
|
+
"""
|
16
|
+
|
17
|
+
from struct import unpack
|
18
|
+
from typing import Any
|
19
|
+
|
20
|
+
import torch
|
21
|
+
|
22
|
+
from plato.processors import model
|
23
|
+
|
24
|
+
|
25
|
+
class Processor(model.Processor):
    """
    Implements a Processor to dequantize model parameters quantized with QSGD.

    Each layer arrives as a byte string with the following big-endian layout:

    * bytes 0-3: ``max_v`` (float32), the layer's scaling factor;
    * bytes 4-7: ``numel`` (unsigned 32-bit), the number of quantized elements;
    * bytes 8-9: ``dimensions`` (signed 16-bit), the rank of the tensor;
    * next ``2 * dimensions`` bytes: the shape, one signed 16-bit value per dim;
    * then ``numel`` bytes, one per element in sign-magnitude form: a byte
      ``b >= 128`` encodes ``-(b - 128)``, smaller bytes encode ``b``.
    """

    def __init__(self, quantization_level=64, **kwargs) -> None:
        super().__init__(**kwargs)

        # Magnitudes live in the low 7 bits of each byte (bit 7 is the sign), so
        # the level must be <= 128 for ``quantization_level - 1`` to fit.
        self.quantization_level = quantization_level

    def _process_layer(self, layer: Any) -> Any:
        """Dequantizes each individual layer of the model."""

        # Step 1: decode the header
        tuning_param = self.quantization_level - 1
        max_v = unpack("!f", layer[0:4])[0]
        numel = unpack("!I", layer[4:8])[0]
        dimensions = unpack("!h", layer[8:10])[0]
        size = [
            unpack("!h", layer[10 + 2 * i : 12 + 2 * i])[0] for i in range(dimensions)
        ]

        # Step 2: decode the content. Converting all payload bytes at once and
        # decoding the sign with one vectorized op replaces a per-element
        # struct.unpack loop in Python, which is very slow for large layers.
        body_start = 10 + 2 * dimensions
        payload = layer[body_start : body_start + numel]
        if len(payload) < numel:
            raise ValueError(
                f"Truncated QSGD layer: expected {numel} payload bytes, "
                f"got {len(payload)}."
            )

        raw = torch.tensor(list(payload), dtype=torch.int64)
        # Sign-magnitude decoding: -(b - 128) == 128 - b for bytes b >= 128.
        zeta = torch.where(raw >= 128, 128 - raw, raw).reshape(size)

        # Step 3: dequantize the content
        zeta = zeta * max_v / tuning_param

        return zeta
|
@@ -0,0 +1,43 @@
|
|
1
|
+
"""
|
2
|
+
A processor that encrypts model weights in MaskCrypt.
|
3
|
+
"""
|
4
|
+
|
5
|
+
import logging
|
6
|
+
from typing import Any
|
7
|
+
|
8
|
+
import torch
|
9
|
+
|
10
|
+
from plato.processors import model
|
11
|
+
from plato.utils import homo_enc
|
12
|
+
|
13
|
+
|
14
|
+
class Processor(model.Processor):
    """
    Encrypts model weights for MaskCrypt with a CKKS context; the encryption
    mask is forwarded as ``indices`` to ``homo_enc.encrypt_weights``.
    """

    def __init__(self, mask=None, **kwargs) -> None:
        super().__init__(**kwargs)

        self.context = homo_enc.get_ckks_context()
        self.mask = mask

        # Number of elements per tensor of the local model. ``self.trainer`` is
        # presumably provided by the base class (TODO confirm).
        # NOTE(review): if ``self.trainer.model`` is an nn.Module, ``.cpu()``
        # moves it to the CPU in place; confirm this side effect is intended.
        state_dict = self.trainer.model.cpu().state_dict()
        self.para_nums = {name: torch.numel(tensor) for name, tensor in state_dict.items()}

    def process(self, data: Any) -> Any:
        """Encrypts ``data`` under the stored mask and returns the serialized result."""
        logging.info(
            "[Client #%d] Encrypt the model weights with given encryption mask.",
            self.client_id,
        )

        return homo_enc.encrypt_weights(
            data,
            serialize=True,
            context=self.context,
            indices=self.mask,
        )
|
@@ -0,0 +1,18 @@
|
|
1
|
+
"""
|
2
|
+
Implements a Processor for quantizing model parameters.
|
3
|
+
"""
|
4
|
+
|
5
|
+
import torch
|
6
|
+
|
7
|
+
from plato.processors import model
|
8
|
+
|
9
|
+
|
10
|
+
class Processor(model.Processor):
    """
    Implements a Processor to quantize model parameters to 16-bit floating points.

    Floating-point tensors are cast to ``torch.bfloat16``. Tensors that are not
    floating point (e.g. integer counters stored in a state_dict) are returned
    unchanged. bfloat16 keeps only 8 significant bits, so casting integers
    would round any value above 256.
    """

    def _process_layer(self, layer: torch.Tensor) -> torch.Tensor:
        """Quantizes an individual floating-point layer of the model to bfloat16."""

        if not layer.is_floating_point():
            return layer

        return layer.to(torch.bfloat16)
|