plato_learn-1.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- plato/__init__.py +1 -0
- plato/algorithms/__init__.py +0 -0
- plato/algorithms/base.py +45 -0
- plato/algorithms/fedavg.py +48 -0
- plato/algorithms/fedavg_gan.py +79 -0
- plato/algorithms/fedavg_personalized.py +48 -0
- plato/algorithms/mistnet.py +52 -0
- plato/algorithms/registry.py +39 -0
- plato/algorithms/split_learning.py +89 -0
- plato/callbacks/__init__.py +0 -0
- plato/callbacks/client.py +56 -0
- plato/callbacks/handler.py +78 -0
- plato/callbacks/server.py +139 -0
- plato/callbacks/trainer.py +124 -0
- plato/client.py +67 -0
- plato/clients/__init__.py +0 -0
- plato/clients/base.py +467 -0
- plato/clients/edge.py +103 -0
- plato/clients/fedavg_personalized.py +40 -0
- plato/clients/mistnet.py +49 -0
- plato/clients/registry.py +43 -0
- plato/clients/self_supervised_learning.py +51 -0
- plato/clients/simple.py +218 -0
- plato/clients/split_learning.py +150 -0
- plato/config.py +339 -0
- plato/datasources/__init__.py +0 -0
- plato/datasources/base.py +123 -0
- plato/datasources/celeba.py +150 -0
- plato/datasources/cifar10.py +87 -0
- plato/datasources/cifar100.py +61 -0
- plato/datasources/cinic10.py +62 -0
- plato/datasources/coco.py +119 -0
- plato/datasources/datalib/__init__.py +0 -0
- plato/datasources/datalib/audio_extraction_tools.py +137 -0
- plato/datasources/datalib/data_utils.py +124 -0
- plato/datasources/datalib/flickr30kE_utils.py +336 -0
- plato/datasources/datalib/frames_extraction_tools.py +254 -0
- plato/datasources/datalib/gym_utils/__init__.py +0 -0
- plato/datasources/datalib/gym_utils/gym_trim.py +189 -0
- plato/datasources/datalib/modality_data_anntation_tools.py +163 -0
- plato/datasources/datalib/modality_extraction_base.py +59 -0
- plato/datasources/datalib/parse_datasets.py +212 -0
- plato/datasources/datalib/refer_utils/__init__.py +0 -0
- plato/datasources/datalib/refer_utils/referitgame_utils.py +237 -0
- plato/datasources/datalib/tiny_data_tools.py +81 -0
- plato/datasources/datalib/video_transform.py +79 -0
- plato/datasources/emnist.py +64 -0
- plato/datasources/fashion_mnist.py +41 -0
- plato/datasources/feature.py +24 -0
- plato/datasources/feature_dataset.py +15 -0
- plato/datasources/femnist.py +141 -0
- plato/datasources/flickr30k_entities.py +362 -0
- plato/datasources/gym.py +431 -0
- plato/datasources/huggingface.py +165 -0
- plato/datasources/kinetics.py +568 -0
- plato/datasources/mnist.py +44 -0
- plato/datasources/multimodal_base.py +328 -0
- plato/datasources/pascal_voc.py +56 -0
- plato/datasources/purchase.py +94 -0
- plato/datasources/qoenflx.py +127 -0
- plato/datasources/referitgame.py +330 -0
- plato/datasources/registry.py +119 -0
- plato/datasources/self_supervised_learning.py +98 -0
- plato/datasources/stl10.py +103 -0
- plato/datasources/texas.py +94 -0
- plato/datasources/tiny_imagenet.py +64 -0
- plato/datasources/yolov8.py +85 -0
- plato/models/__init__.py +0 -0
- plato/models/cnn_encoder.py +103 -0
- plato/models/dcgan.py +116 -0
- plato/models/general_multilayer.py +254 -0
- plato/models/huggingface.py +27 -0
- plato/models/lenet5.py +113 -0
- plato/models/multilayer.py +90 -0
- plato/models/multimodal/__init__.py +0 -0
- plato/models/multimodal/base_net.py +91 -0
- plato/models/multimodal/blending.py +142 -0
- plato/models/multimodal/fc_net.py +77 -0
- plato/models/multimodal/fusion_net.py +78 -0
- plato/models/multimodal/multimodal_module.py +152 -0
- plato/models/registry.py +99 -0
- plato/models/resnet.py +190 -0
- plato/models/torch_hub.py +19 -0
- plato/models/vgg.py +113 -0
- plato/models/vit.py +166 -0
- plato/models/yolov8.py +22 -0
- plato/processors/__init__.py +0 -0
- plato/processors/base.py +35 -0
- plato/processors/compress.py +46 -0
- plato/processors/decompress.py +48 -0
- plato/processors/feature.py +51 -0
- plato/processors/feature_additive_noise.py +48 -0
- plato/processors/feature_dequantize.py +34 -0
- plato/processors/feature_gaussian.py +17 -0
- plato/processors/feature_laplace.py +15 -0
- plato/processors/feature_quantize.py +34 -0
- plato/processors/feature_randomized_response.py +50 -0
- plato/processors/feature_unbatch.py +39 -0
- plato/processors/inbound_feature_tensors.py +39 -0
- plato/processors/model.py +55 -0
- plato/processors/model_compress.py +34 -0
- plato/processors/model_decompress.py +37 -0
- plato/processors/model_decrypt.py +41 -0
- plato/processors/model_deepcopy.py +21 -0
- plato/processors/model_dequantize.py +18 -0
- plato/processors/model_dequantize_qsgd.py +61 -0
- plato/processors/model_encrypt.py +43 -0
- plato/processors/model_quantize.py +18 -0
- plato/processors/model_quantize_qsgd.py +82 -0
- plato/processors/model_randomized_response.py +34 -0
- plato/processors/outbound_feature_ndarrays.py +38 -0
- plato/processors/pipeline.py +26 -0
- plato/processors/registry.py +124 -0
- plato/processors/structured_pruning.py +57 -0
- plato/processors/unstructured_pruning.py +73 -0
- plato/samplers/__init__.py +0 -0
- plato/samplers/all_inclusive.py +41 -0
- plato/samplers/base.py +31 -0
- plato/samplers/dirichlet.py +81 -0
- plato/samplers/distribution_noniid.py +132 -0
- plato/samplers/iid.py +53 -0
- plato/samplers/label_quantity_noniid.py +119 -0
- plato/samplers/mixed.py +44 -0
- plato/samplers/mixed_label_quantity_noniid.py +128 -0
- plato/samplers/modality_iid.py +42 -0
- plato/samplers/modality_quantity_noniid.py +56 -0
- plato/samplers/orthogonal.py +99 -0
- plato/samplers/registry.py +66 -0
- plato/samplers/sample_quantity_noniid.py +123 -0
- plato/samplers/sampler_utils.py +190 -0
- plato/servers/__init__.py +0 -0
- plato/servers/base.py +1395 -0
- plato/servers/fedavg.py +281 -0
- plato/servers/fedavg_cs.py +335 -0
- plato/servers/fedavg_gan.py +74 -0
- plato/servers/fedavg_he.py +106 -0
- plato/servers/fedavg_personalized.py +57 -0
- plato/servers/mistnet.py +67 -0
- plato/servers/registry.py +52 -0
- plato/servers/split_learning.py +109 -0
- plato/trainers/__init__.py +0 -0
- plato/trainers/base.py +99 -0
- plato/trainers/basic.py +649 -0
- plato/trainers/diff_privacy.py +178 -0
- plato/trainers/gan.py +330 -0
- plato/trainers/huggingface.py +173 -0
- plato/trainers/loss_criterion.py +70 -0
- plato/trainers/lr_schedulers.py +252 -0
- plato/trainers/optimizers.py +53 -0
- plato/trainers/pascal_voc.py +80 -0
- plato/trainers/registry.py +44 -0
- plato/trainers/self_supervised_learning.py +302 -0
- plato/trainers/split_learning.py +305 -0
- plato/trainers/tracking.py +96 -0
- plato/trainers/yolov8.py +41 -0
- plato/utils/__init__.py +0 -0
- plato/utils/count_parameters.py +30 -0
- plato/utils/csv_processor.py +26 -0
- plato/utils/data_loaders.py +148 -0
- plato/utils/decorators.py +24 -0
- plato/utils/fonts.py +23 -0
- plato/utils/homo_enc.py +187 -0
- plato/utils/reinforcement_learning/__init__.py +0 -0
- plato/utils/reinforcement_learning/policies/__init__.py +0 -0
- plato/utils/reinforcement_learning/policies/base.py +161 -0
- plato/utils/reinforcement_learning/policies/ddpg.py +75 -0
- plato/utils/reinforcement_learning/policies/registry.py +32 -0
- plato/utils/reinforcement_learning/policies/sac.py +343 -0
- plato/utils/reinforcement_learning/policies/td3.py +485 -0
- plato/utils/reinforcement_learning/rl_agent.py +142 -0
- plato/utils/reinforcement_learning/rl_server.py +113 -0
- plato/utils/rl_env.py +154 -0
- plato/utils/s3.py +141 -0
- plato/utils/trainer_utils.py +21 -0
- plato/utils/unary_encoding.py +47 -0
- plato_learn-1.1.dist-info/METADATA +35 -0
- plato_learn-1.1.dist-info/RECORD +179 -0
- plato_learn-1.1.dist-info/WHEEL +4 -0
- plato_learn-1.1.dist-info/licenses/LICENSE +201 -0
plato/samplers/registry.py
@@ -0,0 +1,66 @@
```python
"""
The registry for samplers designed to partition the dataset across the clients.

Having a registry of all available classes is convenient for retrieving an instance based
on a configuration at run-time.
"""

import logging
from collections import OrderedDict

from plato.config import Config

from plato.samplers import (
    iid,
    dirichlet,
    mixed,
    orthogonal,
    all_inclusive,
    distribution_noniid,
    label_quantity_noniid,
    mixed_label_quantity_noniid,
    sample_quantity_noniid,
    modality_iid,
    modality_quantity_noniid,
)

registered_samplers = OrderedDict(
    [
        ("iid", iid.Sampler),
        ("noniid", dirichlet.Sampler),
        ("mixed", mixed.Sampler),
        ("orthogonal", orthogonal.Sampler),
        ("all_inclusive", all_inclusive.Sampler),
        ("distribution_noniid", distribution_noniid.Sampler),
        ("label_quantity_noniid", label_quantity_noniid.Sampler),
        ("mixed_label_quantity_noniid", mixed_label_quantity_noniid.Sampler),
        ("sample_quantity_noniid", sample_quantity_noniid.Sampler),
        ("modality_iid", modality_iid.Sampler),
        ("modality_quantity_noniid", modality_quantity_noniid.Sampler),
    ]
)


def get(datasource, client_id, testing=False, **kwargs):
    """Get an instance of the sampler."""

    sampler_type = (
        kwargs["sampler_type"]
        if "sampler_type" in kwargs
        else Config().data.testset_sampler
        if testing and hasattr(Config().data, "testset_sampler")
        else Config().data.sampler
    )
    if testing:
        logging.info("[Client #%d] Test set sampler: %s", client_id, sampler_type)
    else:
        logging.info("[Client #%d] Sampler: %s", client_id, sampler_type)

    if sampler_type in registered_samplers:
        registered_sampler = registered_samplers[sampler_type](
            datasource, client_id, testing=testing
        )
    else:
        raise ValueError(f"No such sampler: {sampler_type}")

    return registered_sampler
```
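The nested conditional that selects `sampler_type` in `get()` is easy to misread. The following self-contained sketch is not part of the package; the helper name and the configuration dictionary are made up for illustration. It restates the same precedence: an explicit `sampler_type` keyword wins, then `data.testset_sampler` when sampling for testing, then `data.sampler`.

```python
# Illustrative restatement of the sampler_type resolution order in get() above.
# resolve_sampler_type() and data_config are hypothetical, for clarity only.
def resolve_sampler_type(kwargs, testing, data_config):
    if "sampler_type" in kwargs:
        return kwargs["sampler_type"]          # explicit override wins
    if testing and "testset_sampler" in data_config:
        return data_config["testset_sampler"]  # dedicated test-set sampler
    return data_config["sampler"]              # default training sampler


data_config = {"sampler": "noniid", "testset_sampler": "all_inclusive"}
print(resolve_sampler_type({}, testing=False, data_config=data_config))  # noniid
print(resolve_sampler_type({}, testing=True, data_config=data_config))   # all_inclusive
print(resolve_sampler_type({"sampler_type": "iid"}, True, data_config))  # iid
```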
plato/samplers/sample_quantity_noniid.py
@@ -0,0 +1,123 @@
```python
"""
Samples data from a dataset, biased across quantity size of clients.

This sampler implements one type of sample distribution skew called:

Quantity skewness:
    The local dataset sizes of clients follow the Dirichlet distribution that is
    parameterized by the "client_quantity_concentration".

    Within each client, sample sizes of different classes are the same.

    For example:
        Setting client_quantity_concentration = 0.1 will induce extreme data scale
        unbalance between clients.
        The sample sizes of clients follow the Dirichlet distribution.
            classes  1    2    3   ...  8    9
            client1  5    6    7        5    8
            client2  50   45   67       49   56
            ...
            clientN  6    7    11       10   7

"""

import numpy as np
import torch
from torch.utils.data import SubsetRandomSampler

from plato.config import Config
from plato.samplers import base
from plato.samplers import sampler_utils


class Sampler(base.Sampler):
    """Create a data sampler for each client to use a divided partition of the
    dataset, biased across partition size."""

    def __init__(self, datasource, client_id, testing):
        super().__init__()

        self.client_id = client_id

        np.random.seed(self.random_seed)

        # Obtain the dataset information
        if testing:
            dataset = datasource.get_test_set()
        else:
            dataset = datasource.get_train_set()

        # The list of labels (targets) for all the examples
        self.targets_list = datasource.targets

        self.dataset_size = len(dataset)

        indices = list(range(self.dataset_size))

        np.random.shuffle(indices)

        # Concentration parameter to be used in the Dirichlet distribution
        concentration = (
            Config().data.client_quantity_concentration
            if hasattr(Config().data, "client_quantity_concentration")
            else 1.0
        )

        min_partition_size = Config().data.min_partition_size
        total_clients = Config().clients.total_clients

        self.subset_indices = self.sample_quantity_skew(
            dataset_indices=indices,
            dataset_size=self.dataset_size,
            min_partition_size=min_partition_size,
            concentration=concentration,
            num_clients=total_clients,
        )[client_id - 1]

    def sample_quantity_skew(
        self,
        dataset_indices,
        dataset_size,
        min_partition_size,
        concentration,
        num_clients,
    ):
        """Create the quantity-based sample skewness."""
        proportions = sampler_utils.create_dirichlet_skew(
            total_size=dataset_size,
            concentration=concentration,
            min_partition_size=min_partition_size,
            number_partitions=num_clients + 1,
            is_extend_total_size=True,
        )

        proportions_range = (np.cumsum(proportions) * dataset_size).astype(int)[:-1]

        required_total_size = proportions_range[-1]
        extended_dataset_indices = sampler_utils.extend_indices(
            indices=dataset_indices, required_total_size=required_total_size
        )

        # Obtain the assigned sub-dataset indices for the current client
        clients_assigned_idxs = np.split(extended_dataset_indices, proportions_range)[
            :-1
        ]

        return clients_assigned_idxs

    def get(self):
        """Obtains an instance of the sampler."""
        gen = torch.Generator()
        gen.manual_seed(self.random_seed)
        return SubsetRandomSampler(self.subset_indices, generator=gen)

    def num_samples(self):
        """Returns the length of the dataset after sampling."""
        return len(self.subset_indices)

    def get_sampled_data_condition(self):
        """Get the detailed info of the trainset."""
        targets_array = np.array(self.targets_list)
        client_sampled_subset_labels = targets_array[self.subset_indices]
        unique, counts = np.unique(client_sampled_subset_labels, return_counts=True)
        return np.asarray((unique, counts)).T
```
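To make the quantity skew described in the docstring above concrete, here is a small self-contained sketch, independent of Plato; the dataset size, client count, and concentration values are illustrative only. It draws per-client dataset sizes from a Dirichlet distribution, as the sampler does internally.

```python
# Self-contained sketch of Dirichlet quantity skew: per-client dataset sizes
# drawn from a Dirichlet distribution, with illustrative sizes and seeds.
import numpy as np

rng = np.random.default_rng(1)
dataset_size, num_clients = 50_000, 10

for concentration in (0.1, 1.0, 10.0):
    proportions = rng.dirichlet(np.repeat(concentration, num_clients))
    sizes = (proportions * dataset_size).astype(int)
    print(f"concentration={concentration:>4}: min={sizes.min()}, max={sizes.max()}")

# Small concentrations (e.g. 0.1) give highly unbalanced client sizes, matching
# the "extreme data scale unbalance" noted in the docstring above.
```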
plato/samplers/sampler_utils.py
@@ -0,0 +1,190 @@
```python
"""
Utility functions that sample local datasets.
"""

import numpy as np


def extend_indices(indices, required_total_size):
    """Extend the indices to obtain the required total size
    by duplicating the indices."""
    # Add extra samples to make it evenly divisible, if needed
    if len(indices) < required_total_size:
        while len(indices) < required_total_size:
            indices += indices[: (required_total_size - len(indices))]
    else:
        indices = indices[:required_total_size]
    assert len(indices) == required_total_size

    return indices


def generate_left_classes_pool(anchor_classes, all_classes, keep_anchor_size=1):
    """Generate the classes pool by 1. removing the anchor classes from all classes,
    and 2. randomly selecting 'keep_anchor_size' anchor classes to add back to the
    remaining class pool."""

    if anchor_classes is None:
        return all_classes

    # Obtain a subset of classes from the anchor classes
    left_anchor_classes = np.random.choice(
        anchor_classes, size=keep_anchor_size, replace=False
    )
    # Remove the anchor classes from the whole set of classes
    left_classes_id_list = [
        class_id for class_id in all_classes if class_id not in anchor_classes
    ]

    # Combine the kept anchor classes and the remaining classes to
    # obtain the class pool used for the global classes assignment
    left_classes_id_list = left_anchor_classes.tolist() + left_classes_id_list

    return left_classes_id_list


def assign_fully_classes(dataset_labels, dataset_classes, num_clients, client_id):
    """Assign all classes to each client."""

    # Define the client_id to sample index mapper
    clients_dataidx_map = {
        client_id: np.ndarray(0, dtype=np.int64) for client_id in range(num_clients)
    }

    dataset_labels = np.array(dataset_labels)

    for class_id in dataset_classes:
        idx_k = np.where(dataset_labels == class_id)[0]

        # The samples of each class are evenly assigned to this client
        split = np.array_split(idx_k, num_clients)
        clients_dataidx_map[client_id] = np.append(
            clients_dataidx_map[client_id], split[client_id]
        )
    return clients_dataidx_map


def assign_sub_classes(
    dataset_labels,
    dataset_classes,
    num_clients,
    per_client_classes_size,
    anchor_classes=None,
    consistent_clients=None,
    keep_anchor_classes_size=None,
):
    """Assign a subset of classes to each client and assign the corresponding
    samples of these classes.

    Args:
        dataset_labels (list): a list of labels of the global samples
        dataset_classes (list): a list containing the classes of the dataset
        num_clients (int): total number of clients for the classes assignment
        per_client_classes_size (int): the number of classes assigned to each client
        anchor_classes (list, default None): subset of classes assigned to "consistent_clients"
        consistent_clients (list, default None): subset of clients that all hold the
            same anchor classes
        keep_anchor_classes_size (int, default None): how many anchor classes are kept
            in the class pool used for the global classes assignment.
    """
    # Define the client_id to sample index mapper
    clients_dataidx_map = {
        client_id: np.ndarray(0, dtype=np.int64) for client_id in range(num_clients)
    }
    dataset_labels = np.array(dataset_labels)

    classes_assigned_count = {cls_i: 0 for cls_i in dataset_classes}
    clients_contain_classes = {cli_i: [] for cli_i in range(num_clients)}

    for client_id in range(num_clients):
        if consistent_clients is not None and client_id in consistent_clients:
            current_assigned_cls = anchor_classes
            for assigned_cls in current_assigned_cls:
                classes_assigned_count[assigned_cls] += 1
        else:
            left_classes_id_list = generate_left_classes_pool(
                anchor_classes=anchor_classes,
                all_classes=dataset_classes,
                keep_anchor_size=keep_anchor_classes_size,
            )

            num_classes = len(left_classes_id_list)
            current_assigned_cls_idx = client_id % num_classes
            assigned_cls = dataset_classes[current_assigned_cls_idx]
            current_assigned_cls = [assigned_cls]
            classes_assigned_count[assigned_cls] += 1
            j = 1
            while j < per_client_classes_size:
                # ind = np.random.randint(0, max_class_id)
                ind = np.random.choice(left_classes_id_list, size=1)[0]
                if ind not in current_assigned_cls:
                    j = j + 1
                    current_assigned_cls.append(ind)
                    classes_assigned_count[ind] += 1
        clients_contain_classes[client_id] = current_assigned_cls

    for class_id in dataset_classes:
        # Skip if this class is never assigned to any client
        if classes_assigned_count[class_id] == 0:
            continue

        idx_k = np.where(dataset_labels == class_id)[0]

        # The samples of the current class are evenly assigned to the corresponding clients
        split = np.array_split(idx_k, classes_assigned_count[class_id])
        ids = 0
        for client_id in range(num_clients):
            if class_id in clients_contain_classes[client_id]:
                clients_dataidx_map[client_id] = np.append(
                    clients_dataidx_map[client_id], split[ids]
                )
                ids += 1
    return clients_dataidx_map


def create_dirichlet_skew(
    total_size,  # the total size over which partitions are generated
    concentration,  # the concentration parameter (beta) of the Dirichlet distribution
    number_partitions,  # number of partitions
    min_partition_size=None,  # minimum required size for each partition
    is_extend_total_size=False,
):
    """Create the distribution skewness based on the Dirichlet distribution.

    Note:
        is_extend_total_size (boolean) determines whether to generate the
        partitions satisfying min_partition_size by directly extending
        the total data size.
    """
    if min_partition_size is not None:
        if not is_extend_total_size:
            min_size = 0
            while min_size < min_partition_size:
                proportions = np.random.dirichlet(
                    np.repeat(concentration, number_partitions)
                )

                proportions = proportions / proportions.sum()
                min_size = np.min(proportions * total_size)

        else:  # Extend the total size to satisfy the minimum requirement
            minimum_proportion_bound = float(min_partition_size / total_size)

            proportions = np.random.dirichlet(
                np.repeat(concentration, number_partitions)
            )

            proportions = proportions / proportions.sum()

            # Raise each proportion to satisfy the minimum size
            def set_min_bound(proportion):
                if proportion > minimum_proportion_bound:
                    return proportion
                else:
                    return minimum_proportion_bound

            proportions = list(map(set_min_bound, proportions))

    else:
        proportions = np.random.dirichlet(np.repeat(concentration, number_partitions))

    return proportions
```