plato_learn-1.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- plato/__init__.py +1 -0
- plato/algorithms/__init__.py +0 -0
- plato/algorithms/base.py +45 -0
- plato/algorithms/fedavg.py +48 -0
- plato/algorithms/fedavg_gan.py +79 -0
- plato/algorithms/fedavg_personalized.py +48 -0
- plato/algorithms/mistnet.py +52 -0
- plato/algorithms/registry.py +39 -0
- plato/algorithms/split_learning.py +89 -0
- plato/callbacks/__init__.py +0 -0
- plato/callbacks/client.py +56 -0
- plato/callbacks/handler.py +78 -0
- plato/callbacks/server.py +139 -0
- plato/callbacks/trainer.py +124 -0
- plato/client.py +67 -0
- plato/clients/__init__.py +0 -0
- plato/clients/base.py +467 -0
- plato/clients/edge.py +103 -0
- plato/clients/fedavg_personalized.py +40 -0
- plato/clients/mistnet.py +49 -0
- plato/clients/registry.py +43 -0
- plato/clients/self_supervised_learning.py +51 -0
- plato/clients/simple.py +218 -0
- plato/clients/split_learning.py +150 -0
- plato/config.py +339 -0
- plato/datasources/__init__.py +0 -0
- plato/datasources/base.py +123 -0
- plato/datasources/celeba.py +150 -0
- plato/datasources/cifar10.py +87 -0
- plato/datasources/cifar100.py +61 -0
- plato/datasources/cinic10.py +62 -0
- plato/datasources/coco.py +119 -0
- plato/datasources/datalib/__init__.py +0 -0
- plato/datasources/datalib/audio_extraction_tools.py +137 -0
- plato/datasources/datalib/data_utils.py +124 -0
- plato/datasources/datalib/flickr30kE_utils.py +336 -0
- plato/datasources/datalib/frames_extraction_tools.py +254 -0
- plato/datasources/datalib/gym_utils/__init__.py +0 -0
- plato/datasources/datalib/gym_utils/gym_trim.py +189 -0
- plato/datasources/datalib/modality_data_anntation_tools.py +163 -0
- plato/datasources/datalib/modality_extraction_base.py +59 -0
- plato/datasources/datalib/parse_datasets.py +212 -0
- plato/datasources/datalib/refer_utils/__init__.py +0 -0
- plato/datasources/datalib/refer_utils/referitgame_utils.py +237 -0
- plato/datasources/datalib/tiny_data_tools.py +81 -0
- plato/datasources/datalib/video_transform.py +79 -0
- plato/datasources/emnist.py +64 -0
- plato/datasources/fashion_mnist.py +41 -0
- plato/datasources/feature.py +24 -0
- plato/datasources/feature_dataset.py +15 -0
- plato/datasources/femnist.py +141 -0
- plato/datasources/flickr30k_entities.py +362 -0
- plato/datasources/gym.py +431 -0
- plato/datasources/huggingface.py +165 -0
- plato/datasources/kinetics.py +568 -0
- plato/datasources/mnist.py +44 -0
- plato/datasources/multimodal_base.py +328 -0
- plato/datasources/pascal_voc.py +56 -0
- plato/datasources/purchase.py +94 -0
- plato/datasources/qoenflx.py +127 -0
- plato/datasources/referitgame.py +330 -0
- plato/datasources/registry.py +119 -0
- plato/datasources/self_supervised_learning.py +98 -0
- plato/datasources/stl10.py +103 -0
- plato/datasources/texas.py +94 -0
- plato/datasources/tiny_imagenet.py +64 -0
- plato/datasources/yolov8.py +85 -0
- plato/models/__init__.py +0 -0
- plato/models/cnn_encoder.py +103 -0
- plato/models/dcgan.py +116 -0
- plato/models/general_multilayer.py +254 -0
- plato/models/huggingface.py +27 -0
- plato/models/lenet5.py +113 -0
- plato/models/multilayer.py +90 -0
- plato/models/multimodal/__init__.py +0 -0
- plato/models/multimodal/base_net.py +91 -0
- plato/models/multimodal/blending.py +142 -0
- plato/models/multimodal/fc_net.py +77 -0
- plato/models/multimodal/fusion_net.py +78 -0
- plato/models/multimodal/multimodal_module.py +152 -0
- plato/models/registry.py +99 -0
- plato/models/resnet.py +190 -0
- plato/models/torch_hub.py +19 -0
- plato/models/vgg.py +113 -0
- plato/models/vit.py +166 -0
- plato/models/yolov8.py +22 -0
- plato/processors/__init__.py +0 -0
- plato/processors/base.py +35 -0
- plato/processors/compress.py +46 -0
- plato/processors/decompress.py +48 -0
- plato/processors/feature.py +51 -0
- plato/processors/feature_additive_noise.py +48 -0
- plato/processors/feature_dequantize.py +34 -0
- plato/processors/feature_gaussian.py +17 -0
- plato/processors/feature_laplace.py +15 -0
- plato/processors/feature_quantize.py +34 -0
- plato/processors/feature_randomized_response.py +50 -0
- plato/processors/feature_unbatch.py +39 -0
- plato/processors/inbound_feature_tensors.py +39 -0
- plato/processors/model.py +55 -0
- plato/processors/model_compress.py +34 -0
- plato/processors/model_decompress.py +37 -0
- plato/processors/model_decrypt.py +41 -0
- plato/processors/model_deepcopy.py +21 -0
- plato/processors/model_dequantize.py +18 -0
- plato/processors/model_dequantize_qsgd.py +61 -0
- plato/processors/model_encrypt.py +43 -0
- plato/processors/model_quantize.py +18 -0
- plato/processors/model_quantize_qsgd.py +82 -0
- plato/processors/model_randomized_response.py +34 -0
- plato/processors/outbound_feature_ndarrays.py +38 -0
- plato/processors/pipeline.py +26 -0
- plato/processors/registry.py +124 -0
- plato/processors/structured_pruning.py +57 -0
- plato/processors/unstructured_pruning.py +73 -0
- plato/samplers/__init__.py +0 -0
- plato/samplers/all_inclusive.py +41 -0
- plato/samplers/base.py +31 -0
- plato/samplers/dirichlet.py +81 -0
- plato/samplers/distribution_noniid.py +132 -0
- plato/samplers/iid.py +53 -0
- plato/samplers/label_quantity_noniid.py +119 -0
- plato/samplers/mixed.py +44 -0
- plato/samplers/mixed_label_quantity_noniid.py +128 -0
- plato/samplers/modality_iid.py +42 -0
- plato/samplers/modality_quantity_noniid.py +56 -0
- plato/samplers/orthogonal.py +99 -0
- plato/samplers/registry.py +66 -0
- plato/samplers/sample_quantity_noniid.py +123 -0
- plato/samplers/sampler_utils.py +190 -0
- plato/servers/__init__.py +0 -0
- plato/servers/base.py +1395 -0
- plato/servers/fedavg.py +281 -0
- plato/servers/fedavg_cs.py +335 -0
- plato/servers/fedavg_gan.py +74 -0
- plato/servers/fedavg_he.py +106 -0
- plato/servers/fedavg_personalized.py +57 -0
- plato/servers/mistnet.py +67 -0
- plato/servers/registry.py +52 -0
- plato/servers/split_learning.py +109 -0
- plato/trainers/__init__.py +0 -0
- plato/trainers/base.py +99 -0
- plato/trainers/basic.py +649 -0
- plato/trainers/diff_privacy.py +178 -0
- plato/trainers/gan.py +330 -0
- plato/trainers/huggingface.py +173 -0
- plato/trainers/loss_criterion.py +70 -0
- plato/trainers/lr_schedulers.py +252 -0
- plato/trainers/optimizers.py +53 -0
- plato/trainers/pascal_voc.py +80 -0
- plato/trainers/registry.py +44 -0
- plato/trainers/self_supervised_learning.py +302 -0
- plato/trainers/split_learning.py +305 -0
- plato/trainers/tracking.py +96 -0
- plato/trainers/yolov8.py +41 -0
- plato/utils/__init__.py +0 -0
- plato/utils/count_parameters.py +30 -0
- plato/utils/csv_processor.py +26 -0
- plato/utils/data_loaders.py +148 -0
- plato/utils/decorators.py +24 -0
- plato/utils/fonts.py +23 -0
- plato/utils/homo_enc.py +187 -0
- plato/utils/reinforcement_learning/__init__.py +0 -0
- plato/utils/reinforcement_learning/policies/__init__.py +0 -0
- plato/utils/reinforcement_learning/policies/base.py +161 -0
- plato/utils/reinforcement_learning/policies/ddpg.py +75 -0
- plato/utils/reinforcement_learning/policies/registry.py +32 -0
- plato/utils/reinforcement_learning/policies/sac.py +343 -0
- plato/utils/reinforcement_learning/policies/td3.py +485 -0
- plato/utils/reinforcement_learning/rl_agent.py +142 -0
- plato/utils/reinforcement_learning/rl_server.py +113 -0
- plato/utils/rl_env.py +154 -0
- plato/utils/s3.py +141 -0
- plato/utils/trainer_utils.py +21 -0
- plato/utils/unary_encoding.py +47 -0
- plato_learn-1.1.dist-info/METADATA +35 -0
- plato_learn-1.1.dist-info/RECORD +179 -0
- plato_learn-1.1.dist-info/WHEEL +4 -0
- plato_learn-1.1.dist-info/licenses/LICENSE +201 -0
plato/samplers/distribution_noniid.py
ADDED
@@ -0,0 +1,132 @@
+"""
+Samples data from a dataset, biased across labels according to the Dirichlet
+distribution and biased across data size according to the Dirichlet distribution.
+
+This sampler can introduce the hardest non-IID data scenarios because it contains:
+
+    - Label skewness - equivalent to the sampler called dirichlet, i.e., dirichlet.py
+      The number of classes contained in clients follows the Dirichlet distribution
+      that is parameterized by the "label_concentration".
+    - Quantity skewness - equivalent to the sampler called "sample_quantity_noniid.py".
+      The local dataset sizes of clients follow the Dirichlet distribution that is
+      parameterized by the "client_quantity_concentration".
+
+For example,
+    1. Setting label_concentration = 0.1 will induce extreme label imbalance between
+       clients. When there are ten classes, each client only contains sufficient
+       samples from one class.
+
+        classes  1    2    3   ...  8    9
+        client1  100  8    9        3    7
+        client2  4    108  7        9    6
+        ...
+        clientN  3    10   11       99   2
+
+    2. Setting client_quantity_concentration = 0.1 will induce extreme data scale
+       imbalance between clients. The sample sizes of clients follow the Dirichlet
+       distribution.
+
+        classes  1    2    3   ...  8    9
+        client1  5    6    7        5    8
+        client2  50   45   67       49   56
+        ...
+        clientN  6    7    11       10   7
+
+    3. This sampler then introduces the above two imbalance conditions simultaneously.
+
+        classes  1    2    3   ...  8    9
+        client1  60   66   380      45   38
+        client2  90   5    3        6    8
+        ...
+        clientN  1    50   1        1    1
+"""
+
+import numpy as np
+import torch
+
+from torch.utils.data import WeightedRandomSampler, SubsetRandomSampler
+
+from plato.config import Config
+from plato.samplers import base
+from plato.samplers import sampler_utils
+
+
+class Sampler(base.Sampler):
+    """Create a data sampler for each client to use a divided partition of the
+    dataset, biased across labels according to the Dirichlet distribution
+    and biased partition size."""
+
+    def __init__(self, datasource, client_id, testing):
+        super().__init__()
+        self.client_id = client_id
+
+        # Set the random seed based on the client ID
+        np.random.seed(self.random_seed * int(client_id))
+
+        # Obtain the dataset information
+        if testing:
+            target_list = datasource.get_test_set().targets
+        else:
+            # The list of labels (targets) for all the examples
+            target_list = datasource.get_train_set().targets
+
+        class_list = datasource.classes()
+        total_data_size = len(target_list)
+
+        # Obtain the configuration
+        min_partition_size = (
+            Config().data.min_partition_size
+            if hasattr(Config().data, "min_partition_size")
+            else 100
+        )
+        total_clients = Config().clients.total_clients
+
+        client_quantity_concentration = (
+            Config().data.client_quantity_concentration
+            if hasattr(Config().data, "client_quantity_concentration")
+            else 1.0
+        )
+
+        label_concentration = (
+            Config().data.label_concentration
+            if hasattr(Config().data, "label_concentration")
+            else 1.0
+        )
+
+        self.client_partition = sampler_utils.create_dirichlet_skew(
+            total_size=total_data_size,
+            concentration=client_quantity_concentration,
+            min_partition_size=None,
+            number_partitions=total_clients,
+        )[client_id - 1]
+
+        self.client_partition_size = int(total_data_size * self.client_partition)
+        self.client_partition_size = max(self.client_partition_size, min_partition_size)
+
+        self.client_label_proportions = sampler_utils.create_dirichlet_skew(
+            total_size=len(class_list),
+            concentration=label_concentration,
+            min_partition_size=None,
+            number_partitions=len(class_list),
+        )
+
+        self.sample_weights = self.client_label_proportions[target_list]
+
+    def get(self):
+        """Obtains an instance of the sampler."""
+        gen = torch.Generator()
+        gen.manual_seed(self.random_seed)
+
+        # Samples without replacement using the sample weights
+        subset_indices = list(
+            WeightedRandomSampler(
+                weights=self.sample_weights,
+                num_samples=self.client_partition_size,
+                replacement=False,
+                generator=gen,
+            )
+        )
+
+        return SubsetRandomSampler(subset_indices, generator=gen)
+
+    def num_samples(self):
+        """Returns the length of the dataset after sampling."""
+        return self.client_partition_size
+
+    def get_sampler_condition(self):
+        """Obtains the label ratio and the sampler configuration."""
+        return self.client_partition, self.client_label_proportions
plato/samplers/iid.py
ADDED
@@ -0,0 +1,53 @@
+"""
+Samples data from a dataset in an independent and identically distributed fashion.
+"""
+
+import numpy as np
+import torch
+from torch.utils.data import SubsetRandomSampler
+
+from plato.config import Config
+from plato.samplers import base
+
+
+class Sampler(base.Sampler):
+    """Create a data sampler for each client to use a randomly divided partition of the
+    dataset."""
+
+    def __init__(self, datasource, client_id, testing):
+        super().__init__()
+
+        if testing:
+            dataset = datasource.get_test_set()
+        else:
+            dataset = datasource.get_train_set()
+
+        self.dataset_size = len(dataset)
+        indices = list(range(self.dataset_size))
+        np.random.seed(self.random_seed)
+        np.random.shuffle(indices)
+
+        partition_size = Config().data.partition_size
+        total_clients = Config().clients.total_clients
+        total_size = partition_size * total_clients
+
+        # Add extra samples to make it evenly divisible, if needed
+        if len(indices) < total_size:
+            while len(indices) < total_size:
+                indices += indices[: (total_size - len(indices))]
+        else:
+            indices = indices[:total_size]
+        assert len(indices) == total_size
+
+        # Compute the indices of data in the subset for this client
+        self.subset_indices = indices[(int(client_id) - 1) : total_size : total_clients]
+
+    def get(self):
+        """Obtains an instance of the sampler."""
+        gen = torch.Generator()
+        gen.manual_seed(self.random_seed)
+        return SubsetRandomSampler(self.subset_indices, generator=gen)
+
+    def num_samples(self):
+        """Returns the length of the dataset after sampling."""
+        return len(self.subset_indices)
plato/samplers/label_quantity_noniid.py
ADDED
@@ -0,0 +1,119 @@
+"""
+Samples data from a dataset, biased across labels, and the number of labels
+(corresponding classes) in different clients is the same.
+
+This sampler implements one type of label distribution skew called:
+
+    Quantity-based label imbalance:
+        Each client contains a fixed number of classes parameterized by the
+        "per_client_classes_size", while the number of samples in each class
+        is almost the same. Besides, the classes assigned to each client are
+        randomly selected from all classes.
+
+        The samples of one class are equally divided and assigned to clients
+        who contain this class. Thus, the samples of different clients
+        are mutually exclusive.
+
+    For example:
+        Setting per_client_classes_size = 2 will induce the condition that each client
+        only contains two classes.
+
+        classes  1    2    3   ...  8    9
+        client1  100  0    100      0    0
+        client2  0    108  100      0    0
+        ...
+        clientN  0    0    0        100  100
+
+        We have N clients while K clients contain class c. As class c contains D_c
+        samples, each client in K will contain D_c / K samples of this class.
+"""
+
+import numpy as np
+import torch
+from torch.utils.data import SubsetRandomSampler
+
+from plato.config import Config
+from plato.samplers import base
+
+from plato.samplers import sampler_utils
+
+
+class Sampler(base.Sampler):
+    """Create a data sampler for each client to use a divided partition of the
+    dataset, biased across classes according to the parameter per_client_classes_size."""
+
+    def __init__(self, datasource, client_id, testing):
+        super().__init__()
+        self.client_id = client_id
+
+        # Different clients should share the randomness
+        # as the assignment of classes is completed in each
+        # sampling process.
+        # Thus, they share the clients_dataidx_map
+        np.random.seed(self.random_seed)
+
+        per_client_classes_size = Config().data.per_client_classes_size
+        total_clients = Config().clients.total_clients
+
+        # Obtain the dataset information
+        if testing:
+            target_list = datasource.get_test_set().targets
+        else:
+            # The list of labels (targets) for all the examples
+            target_list = datasource.get_train_set().targets
+
+        self.targets_list = target_list
+        classes_text_list = datasource.classes()
+        classes_id_list = list(range(len(classes_text_list)))
+
+        self.clients_dataidx_map = {
+            client_id: np.ndarray(0, dtype=np.int64)
+            for client_id in range(total_clients)
+        }
+        # Construct the quantity label skewness
+        self.quantity_label_skew(
+            dataset_labels=self.targets_list,
+            dataset_classes=classes_id_list,
+            num_clients=total_clients,
+            per_client_classes_size=per_client_classes_size,
+        )
+
+        self.subset_indices = self.clients_dataidx_map[client_id - 1]
+
+    def quantity_label_skew(
+        self, dataset_labels, dataset_classes, num_clients, per_client_classes_size
+    ):
+        """Achieves the quantity-based label skewness."""
+        client_id = self.client_id
+        # Each client contains the full set of classes
+        if per_client_classes_size == len(dataset_classes):
+            self.clients_dataidx_map = sampler_utils.assign_fully_classes(
+                dataset_labels, dataset_classes, num_clients, client_id
+            )
+        else:
+            self.clients_dataidx_map = sampler_utils.assign_sub_classes(
+                dataset_labels,
+                dataset_classes,
+                num_clients,
+                per_client_classes_size,
+                anchor_classes=None,
+                consistent_clients=None,
+                keep_anchor_classes_size=None,
+            )
+
+    def get(self):
+        """Obtains an instance of the sampler."""
+        gen = torch.Generator()
+        gen.manual_seed(self.random_seed)
+
+        return SubsetRandomSampler(self.subset_indices, generator=gen)
+
+    def num_samples(self):
+        """Returns the length of the dataset after sampling."""
+        return len(self.subset_indices)
+
+    def get_trainset_condition(self):
+        """Obtains the detailed distribution of labels in the trainset."""
+        targets_array = np.array(self.targets_list)
+        client_sampled_subset_labels = targets_array[self.subset_indices]
+        unique, counts = np.unique(client_sampled_subset_labels, return_counts=True)
+        return np.asarray((unique, counts)).T
plato/samplers/mixed.py
ADDED
@@ -0,0 +1,44 @@
+"""
+Samples data from a dataset as clients' local datasets.
+Some are biased across labels according to the Dirichlet distribution,
+while some are in an independent and identically distributed fashion.
+"""
+
+import numpy as np
+from plato.config import Config
+from plato.samplers import dirichlet
+
+
+class Sampler(dirichlet.Sampler):
+    """Create a data sampler for each client to use a divided partition of the dataset,
+    either biased across labels according to the Dirichlet distribution, or in an IID fashion."""
+
+    def __init__(self, datasource, client_id, testing):
+        super().__init__(datasource, client_id, testing)
+
+        assert hasattr(Config().data, "non_iid_clients")
+        non_iid_clients = Config().data.non_iid_clients
+
+        if isinstance(non_iid_clients, int):
+            # Given the number of non-IID clients
+            self.non_iid_clients_list = [
+                x + 1 for x in range(int(non_iid_clients))
+            ]  # [int(non_iid_clients)]
+        else:
+            # Given the list of non-IID clients
+            self.non_iid_clients_list = [
+                int(x.strip()) for x in non_iid_clients.split(",")
+            ]
+
+        if int(client_id) not in self.non_iid_clients_list:
+            if testing:
+                target_list = datasource.get_test_set().targets
+            else:
+                target_list = datasource.targets()
+            class_list = datasource.classes()
+            self.sample_weights = np.array(
+                [1 / len(class_list) for _ in range(len(class_list))]
+            )[target_list]
+
+            # Different IID clients should have a different random seed for the Generator
+            self.random_seed = self.random_seed * int(client_id)
plato/samplers/mixed_label_quantity_noniid.py
ADDED
@@ -0,0 +1,128 @@
+"""
+Samples data from a dataset such that: 1) it is biased across labels; 2) the number
+of labels (corresponding classes) in different clients is the same; 3) some of the
+clients share the same classes; and 4) the classes shared by these clients are
+partly used by other clients.
+
+This sampler implements the basic label quantity non-IID as in "label_quantity_noniid.py".
+However, a subset of clients of size "consistent_clients_size" contains the same classes
+"anchor_classes". Then, "keep_anchor_classes_size" of the classes of "consistent_clients"
+are also placed back into the class pool to complete the class assignment.
+
+For example:
+    Setting per_client_classes_size = 3, anchor_classes=[2, 3, 9], consistent_clients=[0,1,N],
+    keep_anchor_classes_size=1 will induce the condition:
+
+        classes  1    2    3   ...  7    8    9
+        client1  0    350  350      0    0    350
+        client2  0    350  350      0    0    350
+        client3  100  20   0        0    100  0
+        client4  100  0    0        100  100  0
+        ...
+        clientN  0    350  350      0    0    350
+
+    We have N clients while K clients contain class c. As class c contains D_c samples,
+    each client in K will contain D_c / K samples of this class.
+"""
+
+import numpy as np
+import torch
+from torch.utils.data import SubsetRandomSampler
+
+from plato.config import Config
+from plato.samplers import base
+
+from plato.samplers import sampler_utils
+
+
+class Sampler(base.Sampler):
+    """Create a data sampler for each client to use a divided partition of the
+    dataset, biased across classes according to the parameter per_client_classes_size."""
+
+    def __init__(self, datasource, client_id, testing):
+        super().__init__()
+        self.client_id = client_id
+
+        # Different clients should share the randomness
+        # as the assignment of classes is completed in each
+        # sampling process.
+        # Thus, they share the clients_dataidx_map
+        np.random.seed(self.random_seed)
+
+        per_client_classes_size = Config().data.per_client_classes_size
+        anchor_classes = Config().data.anchor_classes
+        consistent_clients_size = Config().data.consistent_clients_size
+        keep_anchor_classes_size = Config().data.keep_anchor_classes_size
+        total_clients = Config().clients.total_clients
+
+        assert per_client_classes_size == len(anchor_classes)
+
+        self.consistent_clients = np.random.choice(
+            list(range(total_clients)), size=consistent_clients_size, replace=False
+        )
+        self.anchor_classes = anchor_classes
+        self.keep_anchor_classes_size = keep_anchor_classes_size
+
+        # Obtain the dataset information
+        if testing:
+            target_list = datasource.get_test_set().targets
+        else:
+            # The list of labels (targets) for all the examples
+            target_list = datasource.get_train_set().targets
+
+        self.targets_list = target_list
+        classes_text_list = datasource.classes()
+        classes_id_list = list(range(len(classes_text_list)))
+
+        self.clients_dataidx_map = {
+            client_id: np.ndarray(0, dtype=np.int64)
+            for client_id in range(total_clients)
+        }
+        # Construct the quantity label skewness
+        self.quantity_label_skew(
+            dataset_labels=self.targets_list,
+            dataset_classes=classes_id_list,
+            num_clients=total_clients,
+            per_client_classes_size=per_client_classes_size,
+        )
+
+        self.subset_indices = self.clients_dataidx_map[client_id - 1]
+
+    def quantity_label_skew(
+        self, dataset_labels, dataset_classes, num_clients, per_client_classes_size
+    ):
+        """Achieves the quantity-based label skewness."""
+        client_id = self.client_id
+        # Each client contains the full set of classes
+        if per_client_classes_size == len(dataset_classes):
+            self.clients_dataidx_map = sampler_utils.assign_fully_classes(
+                dataset_labels, dataset_classes, num_clients, client_id
+            )
+        else:
+            self.clients_dataidx_map = sampler_utils.assign_sub_classes(
+                dataset_labels,
+                dataset_classes,
+                num_clients,
+                per_client_classes_size,
+                anchor_classes=self.anchor_classes,
+                consistent_clients=self.consistent_clients,
+                keep_anchor_classes_size=self.keep_anchor_classes_size,
+            )
+
+    def get(self):
+        """Obtains an instance of the sampler."""
+        gen = torch.Generator()
+        gen.manual_seed(self.random_seed)
+
+        return SubsetRandomSampler(self.subset_indices, generator=gen)
+
+    def num_samples(self):
+        """Returns the length of the dataset after sampling."""
+        return len(self.subset_indices)
+
+    def get_trainset_condition(self):
+        """Obtains the detailed distribution of labels in the trainset."""
+        targets_array = np.array(self.targets_list)
+        client_sampled_subset_labels = targets_array[self.subset_indices]
+        unique, counts = np.unique(client_sampled_subset_labels, return_counts=True)
+        return np.asarray((unique, counts)).T
plato/samplers/modality_iid.py
ADDED
@@ -0,0 +1,42 @@
+"""
+Samples data from a dataset, biased across modalities in an
+independent and identically distributed (IID) fashion.
+
+Thus, all modalities of one sample are utilized as the input.
+
+There is no difference between the train sampler and test sampler.
+"""
+
+import numpy as np
+
+from plato.samplers import base
+
+
+class Sampler(base.Sampler):
+    """Create a data sampler for each client to use a randomly divided partition of the
+    dataset."""
+
+    def __init__(self, datasource, client_id):
+        super().__init__()
+
+        self.client_id = client_id
+        if hasattr(datasource, "get_modality_name"):
+            modalities_name = datasource.get_modality_name()
+        else:  # Default: it only contains image data
+            modalities_name = ["rgb"]
+
+        np.random.seed(self.random_seed)
+
+        # Obtain the modalities that hold for this data
+        self.subset_modalities = modalities_name
+
+    def get(self):
+        """Obtains the modality sampler.
+        Note: the sampler here is utilized as the mask to
+        remove modalities.
+        """
+        return self.subset_modalities
+
+    def modality_size(self):
+        """Obtains the utilized modality size."""
+        return len(self.subset_modalities)
plato/samplers/modality_quantity_noniid.py
ADDED
@@ -0,0 +1,56 @@
+"""
+Samples data from a dataset, biased across modalities in a
+quantity-based non-IID fashion.
+
+Thus, the quantity-based modality non-IID can be achieved by just keeping
+one subset of modalities in each sample.
+"""
+
+import numpy as np
+
+from plato.config import Config
+from plato.samplers import base
+
+
+class Sampler(base.Sampler):
+    """Create a data sampler for each client to use a randomly divided partition of the
+    dataset."""
+
+    def __init__(self, datasource, client_id):
+        super().__init__()
+        self.client_id = client_id
+        if hasattr(datasource, "get_modality_name"):
+            modalities_name = datasource.get_modality_name()
+        else:  # Default: it only contains image data
+            modalities_name = ["rgb"]
+
+        # Different clients should have a different bias across modalities
+        np.random.seed(self.random_seed * int(client_id))
+
+        # By default, one sample holds only one modality
+        per_client_modalties_size = (
+            Config().data.per_client_modalties_size
+            if hasattr(Config().data, "per_client_modalties_size")
+            else 1
+        )
+
+        assert per_client_modalties_size < len(modalities_name)
+
+        # Obtain the modalities that hold for this data
+        self.subset_modalities = np.random.choice(
+            modalities_name,
+            per_client_modalties_size,
+            replace=False,
+        )
+
+    def get(self):
+        """Obtains the modality sampler.
+        Note: the sampler here is utilized as the mask to
+        remove modalities.
+        """
+        return self.subset_modalities
+
+    def modality_size(self):
+        """Obtains the utilized modality size."""
+        return len(self.subset_modalities)
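Because the seed is scaled by client_id, each client draws its own modality subset, and the assert requires strictly fewer modalities per client than the datasource provides (so the single-modality "rgb" default cannot use this sampler). A short sketch of the per-client draw; the modality names and the 42 standing in for the inherited random_seed are hypothetical:

import numpy as np

modalities_name = ["rgb", "flow", "audio"]
per_client_modalties_size = 1  # must stay below len(modalities_name)

for client_id in (1, 2, 3):
    np.random.seed(42 * client_id)  # random_seed * client_id, as in the sampler
    subset = np.random.choice(modalities_name, per_client_modalties_size, replace=False)
    print(client_id, subset)  # different clients keep different modalities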
|
@@ -0,0 +1,99 @@
|
|
1
|
+
"""
|
2
|
+
A sampler for orthogonal cross-silo federated learning.
|
3
|
+
Each insitution's clients have data of different classes.
|
4
|
+
"""
|
5
|
+
|
6
|
+
import numpy as np
|
7
|
+
import torch
|
8
|
+
from torch.utils.data import WeightedRandomSampler, SubsetRandomSampler
|
9
|
+
from plato.config import Config
|
10
|
+
|
11
|
+
from plato.samplers import base
|
12
|
+
|
13
|
+
|
14
|
+
class Sampler(base.Sampler):
|
15
|
+
"""Create a data sampler for each client to use a divided partition of the dataset.
|
16
|
+
A client only has data of certain classes."""
|
17
|
+
|
18
|
+
def __init__(self, datasource, client_id, testing):
|
19
|
+
super().__init__()
|
20
|
+
|
21
|
+
# Different clients should have a different bias across the labels
|
22
|
+
np.random.seed(self.random_seed * int(client_id))
|
23
|
+
|
24
|
+
self.partition_size = Config().data.partition_size
|
25
|
+
|
26
|
+
if testing:
|
27
|
+
target_list = datasource.get_test_set().targets
|
28
|
+
else:
|
29
|
+
# The list of labels (targets) for all the examples
|
30
|
+
target_list = datasource.targets()
|
31
|
+
class_list = datasource.classes()
|
32
|
+
|
33
|
+
max_client_id = int(Config().clients.total_clients)
|
34
|
+
|
35
|
+
if client_id > max_client_id:
|
36
|
+
# This client is an edge server
|
37
|
+
institution_id = client_id - 1 - max_client_id
|
38
|
+
else:
|
39
|
+
institution_id = (client_id - 1) % int(Config().algorithm.total_silos)
|
40
|
+
|
41
|
+
if hasattr(Config().data, "institution_class_ids"):
|
42
|
+
institution_class_ids = Config().data.institution_class_ids
|
43
|
+
class_ids = [x.strip() for x in institution_class_ids.split(";")][
|
44
|
+
institution_id
|
45
|
+
]
|
46
|
+
class_id_list = [int(x.strip()) for x in class_ids.split(",")]
|
47
|
+
else:
|
48
|
+
class_ids = np.array_split(
|
49
|
+
[i for i in range(len(class_list))], Config().algorithm.total_silos
|
50
|
+
)[institution_id]
|
51
|
+
class_id_list = class_ids.tolist()
|
52
|
+
|
53
|
+
if (
|
54
|
+
hasattr(Config().data, "label_distribution")
|
55
|
+
and Config().data.label_distribution == "noniid"
|
56
|
+
):
|
57
|
+
# Concentration parameter to be used in the Dirichlet distribution
|
58
|
+
concentration = (
|
59
|
+
Config().data.concentration
|
60
|
+
if hasattr(Config().data, "concentration")
|
61
|
+
else 1.0
|
62
|
+
)
|
63
|
+
|
64
|
+
class_proportions = np.random.dirichlet(
|
65
|
+
np.repeat(concentration, len(class_id_list))
|
66
|
+
)
|
67
|
+
|
68
|
+
else:
|
69
|
+
class_proportions = [
|
70
|
+
1.0 / len(class_id_list) for i in range(len(class_id_list))
|
71
|
+
]
|
72
|
+
|
73
|
+
target_proportions = [0 for i in range(len(class_list))]
|
74
|
+
for index, class_id in enumerate(class_id_list):
|
75
|
+
target_proportions[class_id] = class_proportions[index]
|
76
|
+
target_proportions = np.asarray(target_proportions)
|
77
|
+
|
78
|
+
self.sample_weights = target_proportions[target_list]
|
79
|
+
|
80
|
+
def get(self):
|
81
|
+
"""Obtains an instance of the sampler."""
|
82
|
+
gen = torch.Generator()
|
83
|
+
gen.manual_seed(self.random_seed)
|
84
|
+
|
85
|
+
# Samples without replacement using the sample weights
|
86
|
+
subset_indices = list(
|
87
|
+
WeightedRandomSampler(
|
88
|
+
weights=self.sample_weights,
|
89
|
+
num_samples=self.partition_size,
|
90
|
+
replacement=False,
|
91
|
+
generator=gen,
|
92
|
+
)
|
93
|
+
)
|
94
|
+
|
95
|
+
return SubsetRandomSampler(subset_indices, generator=gen)
|
96
|
+
|
97
|
+
def num_samples(self):
|
98
|
+
"""Returns the length of the dataset after sampling."""
|
99
|
+
return self.partition_size
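When institution_class_ids is not configured, class IDs are split evenly across silos with np.array_split, and per-class proportions are scattered into a weight vector over all classes, so examples outside a silo's classes get weight 0 and are never drawn. A compact sketch of that weight construction under the uniform (non-Dirichlet) branch, with hypothetical sizes:

import numpy as np

np.random.seed(0)
num_classes, total_silos = 10, 2
targets = np.random.randint(num_classes, size=20)

for institution_id in range(total_silos):
    class_id_list = np.array_split(
        list(range(num_classes)), total_silos
    )[institution_id].tolist()
    class_proportions = [1.0 / len(class_id_list)] * len(class_id_list)

    # Scatter per-class proportions into a weight vector over all classes
    target_proportions = np.zeros(num_classes)
    target_proportions[class_id_list] = class_proportions
    sample_weights = target_proportions[targets]
    # Only examples from this institution's classes have nonzero weight
    print(institution_id, class_id_list, int((sample_weights > 0).sum()))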
|