plato-learn 1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (179) hide show
  1. plato/__init__.py +1 -0
  2. plato/algorithms/__init__.py +0 -0
  3. plato/algorithms/base.py +45 -0
  4. plato/algorithms/fedavg.py +48 -0
  5. plato/algorithms/fedavg_gan.py +79 -0
  6. plato/algorithms/fedavg_personalized.py +48 -0
  7. plato/algorithms/mistnet.py +52 -0
  8. plato/algorithms/registry.py +39 -0
  9. plato/algorithms/split_learning.py +89 -0
  10. plato/callbacks/__init__.py +0 -0
  11. plato/callbacks/client.py +56 -0
  12. plato/callbacks/handler.py +78 -0
  13. plato/callbacks/server.py +139 -0
  14. plato/callbacks/trainer.py +124 -0
  15. plato/client.py +67 -0
  16. plato/clients/__init__.py +0 -0
  17. plato/clients/base.py +467 -0
  18. plato/clients/edge.py +103 -0
  19. plato/clients/fedavg_personalized.py +40 -0
  20. plato/clients/mistnet.py +49 -0
  21. plato/clients/registry.py +43 -0
  22. plato/clients/self_supervised_learning.py +51 -0
  23. plato/clients/simple.py +218 -0
  24. plato/clients/split_learning.py +150 -0
  25. plato/config.py +339 -0
  26. plato/datasources/__init__.py +0 -0
  27. plato/datasources/base.py +123 -0
  28. plato/datasources/celeba.py +150 -0
  29. plato/datasources/cifar10.py +87 -0
  30. plato/datasources/cifar100.py +61 -0
  31. plato/datasources/cinic10.py +62 -0
  32. plato/datasources/coco.py +119 -0
  33. plato/datasources/datalib/__init__.py +0 -0
  34. plato/datasources/datalib/audio_extraction_tools.py +137 -0
  35. plato/datasources/datalib/data_utils.py +124 -0
  36. plato/datasources/datalib/flickr30kE_utils.py +336 -0
  37. plato/datasources/datalib/frames_extraction_tools.py +254 -0
  38. plato/datasources/datalib/gym_utils/__init__.py +0 -0
  39. plato/datasources/datalib/gym_utils/gym_trim.py +189 -0
  40. plato/datasources/datalib/modality_data_anntation_tools.py +163 -0
  41. plato/datasources/datalib/modality_extraction_base.py +59 -0
  42. plato/datasources/datalib/parse_datasets.py +212 -0
  43. plato/datasources/datalib/refer_utils/__init__.py +0 -0
  44. plato/datasources/datalib/refer_utils/referitgame_utils.py +237 -0
  45. plato/datasources/datalib/tiny_data_tools.py +81 -0
  46. plato/datasources/datalib/video_transform.py +79 -0
  47. plato/datasources/emnist.py +64 -0
  48. plato/datasources/fashion_mnist.py +41 -0
  49. plato/datasources/feature.py +24 -0
  50. plato/datasources/feature_dataset.py +15 -0
  51. plato/datasources/femnist.py +141 -0
  52. plato/datasources/flickr30k_entities.py +362 -0
  53. plato/datasources/gym.py +431 -0
  54. plato/datasources/huggingface.py +165 -0
  55. plato/datasources/kinetics.py +568 -0
  56. plato/datasources/mnist.py +44 -0
  57. plato/datasources/multimodal_base.py +328 -0
  58. plato/datasources/pascal_voc.py +56 -0
  59. plato/datasources/purchase.py +94 -0
  60. plato/datasources/qoenflx.py +127 -0
  61. plato/datasources/referitgame.py +330 -0
  62. plato/datasources/registry.py +119 -0
  63. plato/datasources/self_supervised_learning.py +98 -0
  64. plato/datasources/stl10.py +103 -0
  65. plato/datasources/texas.py +94 -0
  66. plato/datasources/tiny_imagenet.py +64 -0
  67. plato/datasources/yolov8.py +85 -0
  68. plato/models/__init__.py +0 -0
  69. plato/models/cnn_encoder.py +103 -0
  70. plato/models/dcgan.py +116 -0
  71. plato/models/general_multilayer.py +254 -0
  72. plato/models/huggingface.py +27 -0
  73. plato/models/lenet5.py +113 -0
  74. plato/models/multilayer.py +90 -0
  75. plato/models/multimodal/__init__.py +0 -0
  76. plato/models/multimodal/base_net.py +91 -0
  77. plato/models/multimodal/blending.py +142 -0
  78. plato/models/multimodal/fc_net.py +77 -0
  79. plato/models/multimodal/fusion_net.py +78 -0
  80. plato/models/multimodal/multimodal_module.py +152 -0
  81. plato/models/registry.py +99 -0
  82. plato/models/resnet.py +190 -0
  83. plato/models/torch_hub.py +19 -0
  84. plato/models/vgg.py +113 -0
  85. plato/models/vit.py +166 -0
  86. plato/models/yolov8.py +22 -0
  87. plato/processors/__init__.py +0 -0
  88. plato/processors/base.py +35 -0
  89. plato/processors/compress.py +46 -0
  90. plato/processors/decompress.py +48 -0
  91. plato/processors/feature.py +51 -0
  92. plato/processors/feature_additive_noise.py +48 -0
  93. plato/processors/feature_dequantize.py +34 -0
  94. plato/processors/feature_gaussian.py +17 -0
  95. plato/processors/feature_laplace.py +15 -0
  96. plato/processors/feature_quantize.py +34 -0
  97. plato/processors/feature_randomized_response.py +50 -0
  98. plato/processors/feature_unbatch.py +39 -0
  99. plato/processors/inbound_feature_tensors.py +39 -0
  100. plato/processors/model.py +55 -0
  101. plato/processors/model_compress.py +34 -0
  102. plato/processors/model_decompress.py +37 -0
  103. plato/processors/model_decrypt.py +41 -0
  104. plato/processors/model_deepcopy.py +21 -0
  105. plato/processors/model_dequantize.py +18 -0
  106. plato/processors/model_dequantize_qsgd.py +61 -0
  107. plato/processors/model_encrypt.py +43 -0
  108. plato/processors/model_quantize.py +18 -0
  109. plato/processors/model_quantize_qsgd.py +82 -0
  110. plato/processors/model_randomized_response.py +34 -0
  111. plato/processors/outbound_feature_ndarrays.py +38 -0
  112. plato/processors/pipeline.py +26 -0
  113. plato/processors/registry.py +124 -0
  114. plato/processors/structured_pruning.py +57 -0
  115. plato/processors/unstructured_pruning.py +73 -0
  116. plato/samplers/__init__.py +0 -0
  117. plato/samplers/all_inclusive.py +41 -0
  118. plato/samplers/base.py +31 -0
  119. plato/samplers/dirichlet.py +81 -0
  120. plato/samplers/distribution_noniid.py +132 -0
  121. plato/samplers/iid.py +53 -0
  122. plato/samplers/label_quantity_noniid.py +119 -0
  123. plato/samplers/mixed.py +44 -0
  124. plato/samplers/mixed_label_quantity_noniid.py +128 -0
  125. plato/samplers/modality_iid.py +42 -0
  126. plato/samplers/modality_quantity_noniid.py +56 -0
  127. plato/samplers/orthogonal.py +99 -0
  128. plato/samplers/registry.py +66 -0
  129. plato/samplers/sample_quantity_noniid.py +123 -0
  130. plato/samplers/sampler_utils.py +190 -0
  131. plato/servers/__init__.py +0 -0
  132. plato/servers/base.py +1395 -0
  133. plato/servers/fedavg.py +281 -0
  134. plato/servers/fedavg_cs.py +335 -0
  135. plato/servers/fedavg_gan.py +74 -0
  136. plato/servers/fedavg_he.py +106 -0
  137. plato/servers/fedavg_personalized.py +57 -0
  138. plato/servers/mistnet.py +67 -0
  139. plato/servers/registry.py +52 -0
  140. plato/servers/split_learning.py +109 -0
  141. plato/trainers/__init__.py +0 -0
  142. plato/trainers/base.py +99 -0
  143. plato/trainers/basic.py +649 -0
  144. plato/trainers/diff_privacy.py +178 -0
  145. plato/trainers/gan.py +330 -0
  146. plato/trainers/huggingface.py +173 -0
  147. plato/trainers/loss_criterion.py +70 -0
  148. plato/trainers/lr_schedulers.py +252 -0
  149. plato/trainers/optimizers.py +53 -0
  150. plato/trainers/pascal_voc.py +80 -0
  151. plato/trainers/registry.py +44 -0
  152. plato/trainers/self_supervised_learning.py +302 -0
  153. plato/trainers/split_learning.py +305 -0
  154. plato/trainers/tracking.py +96 -0
  155. plato/trainers/yolov8.py +41 -0
  156. plato/utils/__init__.py +0 -0
  157. plato/utils/count_parameters.py +30 -0
  158. plato/utils/csv_processor.py +26 -0
  159. plato/utils/data_loaders.py +148 -0
  160. plato/utils/decorators.py +24 -0
  161. plato/utils/fonts.py +23 -0
  162. plato/utils/homo_enc.py +187 -0
  163. plato/utils/reinforcement_learning/__init__.py +0 -0
  164. plato/utils/reinforcement_learning/policies/__init__.py +0 -0
  165. plato/utils/reinforcement_learning/policies/base.py +161 -0
  166. plato/utils/reinforcement_learning/policies/ddpg.py +75 -0
  167. plato/utils/reinforcement_learning/policies/registry.py +32 -0
  168. plato/utils/reinforcement_learning/policies/sac.py +343 -0
  169. plato/utils/reinforcement_learning/policies/td3.py +485 -0
  170. plato/utils/reinforcement_learning/rl_agent.py +142 -0
  171. plato/utils/reinforcement_learning/rl_server.py +113 -0
  172. plato/utils/rl_env.py +154 -0
  173. plato/utils/s3.py +141 -0
  174. plato/utils/trainer_utils.py +21 -0
  175. plato/utils/unary_encoding.py +47 -0
  176. plato_learn-1.1.dist-info/METADATA +35 -0
  177. plato_learn-1.1.dist-info/RECORD +179 -0
  178. plato_learn-1.1.dist-info/WHEEL +4 -0
  179. plato_learn-1.1.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,132 @@
1
+ """
2
+ Samples data from a dataset, biased across labels according to the Dirichlet
3
+ distribution and biased across data size according to Dirichlet distribution.
4
+
5
+ This sampler can introduce the hardest non-IID data scenarios because it contains:
6
+
7
+ - Label skewness - equals to the sampler called dirichlet, i.e., dirichlet.py
8
+ The number of classes contained in clients follows the Dirichlet distribution
9
+ that is parameterized by the "label_concentration".
10
+ - Quantity skewness - equals to the sampler called "sample_quantity_noniid.py".
11
+ The local dataset sizes of clients follow the Dirichlet distribution that is
12
+ parameterized by the "client_quantity_concentration".
13
+
14
+ For example,
15
+ 1. Setting label_concentration = 0.1 will induce extreme label unbalance between clients.
16
+ When there are ten classes, each client only contains sufficient samples from one class.
17
+ classes 1 2 3 ... 8 9
18
+ client1 100 8 9 3 7
19
+ client2 4 108 7 9 6
20
+ ...
21
+ clientN 3 10 11 99 2
22
+ 2. Setting client_quantity_concentration = 0.1 will induce extreme data scale
23
+ unbalance between clients.
24
+ The sample sizes of clients follow the Dirichlet distribution.
25
+ classes 1 2 3 ... 8 9
26
+ client1 5 6 7 5 8
27
+ client2 50 45 67 49 56
28
+ ...
29
+ clientN 6 7 11 10 7
30
+ 3. Then, this sampler introduces the above two unbalance conditions simultaneously.
31
+ classes 1 2 3 ... 8 9
32
+ client1 60 66 380 45 38
33
+ client2 90 5 3 6 8
34
+ ...
35
+ clientN 1 50 1 1 1
36
+ """
37
+
38
+ import numpy as np
39
+ import torch
40
+
41
+ from torch.utils.data import WeightedRandomSampler, SubsetRandomSampler
42
+
43
+ from plato.config import Config
44
+ from plato.samplers import base
45
+ from plato.samplers import sampler_utils
46
+
47
+
48
class Sampler(base.Sampler):
    """Sampler whose label mix and partition size are both Dirichlet-skewed.

    One Dirichlet draw (via `sampler_utils.create_dirichlet_skew`) yields the
    fraction of the dataset given to this client; a second draw yields one
    weight per class. Examples are then drawn without replacement using the
    per-class weight of their label.

    `self.random_seed` is not set in this class; it is expected to be
    provided by `base.Sampler`.
    """

    def __init__(self, datasource, client_id, testing):
        """Compute this client's partition size and per-example weights.

        Args:
            datasource: object exposing `get_train_set()`, `get_test_set()`
                (each with a `.targets` attribute) and `classes()`.
            client_id: 1-based client identifier; anything `int()` accepts.
            testing: if truthy, sample from the test set instead of the
                training set.
        """
        super().__init__()
        self.client_id = client_id

        # Normalise once so that seeding and the partition lookup agree on
        # the same integer (the seed was already computed from int(client_id)).
        client_index = int(client_id)

        # Seed NumPy with a client-dependent value so each client draws its
        # own skew.
        np.random.seed(self.random_seed * client_index)

        # The list of labels (targets) for all the examples in the split.
        if testing:
            target_list = datasource.get_test_set().targets
        else:
            target_list = datasource.get_train_set().targets

        class_list = datasource.classes()
        total_data_size = len(target_list)

        # Optional settings fall back to their defaults when absent.
        data_config = Config().data
        min_partition_size = getattr(data_config, "min_partition_size", 100)
        total_clients = Config().clients.total_clients
        client_quantity_concentration = getattr(
            data_config, "client_quantity_concentration", 1.0
        )
        label_concentration = getattr(data_config, "label_concentration", 1.0)

        # Fraction of the dataset assigned to this client.
        self.client_partition = sampler_utils.create_dirichlet_skew(
            total_size=total_data_size,
            concentration=client_quantity_concentration,
            min_partition_size=None,
            number_partitions=total_clients,
        )[client_index - 1]

        requested_size = max(
            int(total_data_size * self.client_partition), min_partition_size
        )
        # Sampling in get() is done without replacement, so asking for more
        # examples than the dataset holds can never succeed; cap the size.
        self.client_partition_size = min(requested_size, total_data_size)

        # One weight per class, then broadcast to one weight per example.
        self.client_label_proportions = sampler_utils.create_dirichlet_skew(
            total_size=len(class_list),
            concentration=label_concentration,
            min_partition_size=None,
            number_partitions=len(class_list),
        )

        self.sample_weights = self.client_label_proportions[target_list]

    def get(self):
        """Obtains an instance of the sampler.

        Returns:
            A `SubsetRandomSampler` over `client_partition_size` indices drawn
            without replacement according to `sample_weights`.
        """
        gen = torch.Generator()
        gen.manual_seed(self.random_seed)

        # Samples without replacement using the sample weights
        subset_indices = list(
            WeightedRandomSampler(
                weights=self.sample_weights,
                num_samples=self.client_partition_size,
                replacement=False,
                generator=gen,
            )
        )

        return SubsetRandomSampler(subset_indices, generator=gen)

    def num_samples(self):
        """Returns the length of the dataset after sampling."""
        return self.client_partition_size

    def get_sampler_condition(self):
        """Returns `(client_partition, client_label_proportions)`.

        The first item is this client's share of the dataset, the second the
        per-class weights used to bias sampling.
        """
        return self.client_partition, self.client_label_proportions
plato/samplers/iid.py ADDED
@@ -0,0 +1,53 @@
1
+ """
2
+ Samples data from a dataset in an independent and identically distributed fashion.
3
+ """
4
+
5
+ import numpy as np
6
+ import torch
7
+ from torch.utils.data import SubsetRandomSampler
8
+
9
+ from plato.config import Config
10
+ from plato.samplers import base
11
+
12
+
13
class Sampler(base.Sampler):
    """Create a data sampler for each client to use a randomly divided
    partition of the dataset.

    `self.random_seed` is not set in this class; it is expected to be
    provided by `base.Sampler`.
    """

    def __init__(self, datasource, client_id, testing):
        """Select this client's share of a seeded shuffle of the dataset.

        Args:
            datasource: object exposing `get_train_set()` / `get_test_set()`.
            client_id: 1-based client identifier; anything `int()` accepts.
            testing: if truthy, partition the test set instead of the
                training set.

        Raises:
            ValueError: if the dataset is empty while a non-empty partition
                is requested (it cannot be padded by repetition).
        """
        super().__init__()

        if testing:
            dataset = datasource.get_test_set()
        else:
            dataset = datasource.get_train_set()

        self.dataset_size = len(dataset)
        indices = list(range(self.dataset_size))
        # Every client uses the same seed, hence the same permutation, so the
        # strided slices taken below do not overlap between clients.
        np.random.seed(self.random_seed)
        np.random.shuffle(indices)

        partition_size = Config().data.partition_size
        total_clients = Config().clients.total_clients
        total_size = partition_size * total_clients

        # Padding an empty list with copies of itself never grows it, so the
        # loop below would spin forever; fail loudly instead.
        if not indices and total_size > 0:
            raise ValueError(
                "Cannot partition an empty dataset into non-empty partitions."
            )

        # add extra samples to make it evenly divisible, if needed
        if len(indices) < total_size:
            while len(indices) < total_size:
                indices += indices[: (total_size - len(indices))]
        else:
            indices = indices[:total_size]
        assert len(indices) == total_size

        # Compute the indices of data in the subset for this client
        self.subset_indices = indices[
            (int(client_id) - 1) : total_size : total_clients
        ]

    def get(self):
        """Obtains an instance of the sampler."""
        gen = torch.Generator()
        gen.manual_seed(self.random_seed)
        return SubsetRandomSampler(self.subset_indices, generator=gen)

    def num_samples(self):
        """Returns the length of the dataset after sampling."""
        return len(self.subset_indices)
@@ -0,0 +1,119 @@
1
+ """
2
+ Samples data from a dataset, biased across labels, and the number of labels
3
+ (corresponding classes) in different clients is the same.
4
+
5
+ This sampler implements one type of label distribution skew called:
6
+
7
+ Quantity-based label imbalance:
8
+ Each client contains a fixed number of classes parameterized by the
9
+ "per_client_classes_size", while the number of samples in each class
10
+ is almost the same. Besides, the classes assigned to each client are
11
+ randomly selected from all classes.
12
+
13
+ The samples of one class are equally divided and assigned to clients
14
+ who contain this class. Thus, the samples of different clients
15
+ are mutually exclusive.
16
+
17
+ For Example:
18
+ Setting per_client_classes_size = 2 will induce the condition that each client
19
+ only contains two classes.
20
+ classes 1 2 3 ... 8 9
21
+ client1 100 0 100 0 0
22
+ client2 0 108 100 0 0
23
+ ...
24
+ clientN 0 0 0 100 100
25
+
26
+ We have N clients while K clients contain class c. As class c contains D_c samples,
27
+ each client in K will contain D_c / K samples of this class.
28
+ """
29
+
30
+ import numpy as np
31
+ import torch
32
+ from torch.utils.data import SubsetRandomSampler
33
+
34
+ from plato.config import Config
35
+ from plato.samplers import base
36
+
37
+ from plato.samplers import sampler_utils
38
+
39
+
40
class Sampler(base.Sampler):
    """Per-client sampler with quantity-based label skew.

    The number of classes held by each client is read from the
    `per_client_classes_size` data setting; the actual index assignment is
    delegated to `sampler_utils`.

    `self.random_seed` is not set in this class; it is expected to be
    provided by `base.Sampler`.
    """

    def __init__(self, datasource, client_id, testing):
        """Build the client-to-indices map and keep this client's entry.

        Args:
            datasource: object exposing `get_train_set()`, `get_test_set()`
                (each with a `.targets` attribute) and `classes()`.
            client_id: 1-based client identifier (used in arithmetic as-is).
            testing: if truthy, use the test set labels instead of the
                training set labels.
        """
        super().__init__()
        self.client_id = client_id

        # Every client seeds NumPy with the same value: the class assignment
        # is recomputed inside each sampler, so all of them must arrive at
        # the same clients_dataidx_map.
        np.random.seed(self.random_seed)

        classes_per_client = Config().data.per_client_classes_size
        client_count = Config().clients.total_clients

        # Labels (targets) of every example in the chosen split.
        split = datasource.get_test_set() if testing else datasource.get_train_set()
        self.targets_list = split.targets

        class_names = datasource.classes()
        class_ids = list(range(len(class_names)))

        # Start with an empty index array per client; quantity_label_skew
        # replaces this map with the real assignment.
        self.clients_dataidx_map = {}
        for slot in range(client_count):
            self.clients_dataidx_map[slot] = np.ndarray(0, dtype=np.int64)

        self.quantity_label_skew(
            dataset_labels=self.targets_list,
            dataset_classes=class_ids,
            num_clients=client_count,
            per_client_classes_size=classes_per_client,
        )

        # The map is indexed from 0 while client ids start at 1.
        self.subset_indices = self.clients_dataidx_map[client_id - 1]

    def quantity_label_skew(
        self, dataset_labels, dataset_classes, num_clients, per_client_classes_size
    ):
        """Fill `clients_dataidx_map` with the quantity-based label skew.

        Uses `sampler_utils.assign_fully_classes` when every client holds all
        classes, and `sampler_utils.assign_sub_classes` (without anchor
        classes) otherwise.
        """
        if per_client_classes_size != len(dataset_classes):
            self.clients_dataidx_map = sampler_utils.assign_sub_classes(
                dataset_labels,
                dataset_classes,
                num_clients,
                per_client_classes_size,
                anchor_classes=None,
                consistent_clients=None,
                keep_anchor_classes_size=None,
            )
            return

        # Each client contains the full set of classes.
        self.clients_dataidx_map = sampler_utils.assign_fully_classes(
            dataset_labels, dataset_classes, num_clients, self.client_id
        )

    def get(self):
        """Obtains an instance of the sampler."""
        generator = torch.Generator()
        generator.manual_seed(self.random_seed)
        return SubsetRandomSampler(self.subset_indices, generator=generator)

    def num_samples(self):
        """Returns the length of the dataset after sampling."""
        return len(self.subset_indices)

    def get_trainset_condition(self):
        """Returns an array of `(label, count)` rows for this client's subset."""
        subset_labels = np.array(self.targets_list)[self.subset_indices]
        labels, frequencies = np.unique(subset_labels, return_counts=True)
        return np.asarray((labels, frequencies)).T
@@ -0,0 +1,44 @@
1
+ """
2
+ Samples data from a dataset as clients' local datasets.
3
+ Some are biased across labels according to the Dirichlet distribution,
4
+ while some are in an independent and identically distributed fashion.
5
+ """
6
+
7
+ import numpy as np
8
+ from plato.config import Config
9
+ from plato.samplers import dirichlet
10
+
11
+
12
class Sampler(dirichlet.Sampler):
    """Create a data sampler for each client to use a divided partition of
    the dataset, either biased across labels according to the Dirichlet
    distribution, or in an iid fashion.

    Clients listed in the `non_iid_clients` data setting keep whatever the
    parent `dirichlet.Sampler` set up; every other client gets uniform
    per-class sample weights and a client-specific random seed.
    """

    def __init__(self, datasource, client_id, testing):
        """Initialise via the Dirichlet sampler, then relax to iid if needed.

        Args:
            datasource: object exposing `targets()`, `get_test_set()` and
                `classes()`.
            client_id: client identifier; anything `int()` accepts.
            testing: if truthy, weights are computed over the test set labels.

        The `non_iid_clients` setting may be an int `n` (clients `1..n` are
        non-iid), a comma-separated string of client ids, or any other
        iterable of client ids.
        """
        super().__init__(datasource, client_id, testing)

        assert hasattr(Config().data, "non_iid_clients")
        non_iid_clients = Config().data.non_iid_clients

        if isinstance(non_iid_clients, int):
            # Given the number of non-iid clients: ids 1..n
            self.non_iid_clients_list = list(range(1, int(non_iid_clients) + 1))
        elif isinstance(non_iid_clients, str):
            # Given a comma-separated list of non-iid client ids
            self.non_iid_clients_list = [
                int(x.strip()) for x in non_iid_clients.split(",")
            ]
        else:
            # Given a sequence of ids (e.g. a list in the config file);
            # previously this case failed on the missing `.split` attribute.
            self.non_iid_clients_list = [int(x) for x in non_iid_clients]

        if int(client_id) not in self.non_iid_clients_list:
            if testing:
                target_list = datasource.get_test_set().targets
            else:
                target_list = datasource.targets()
            class_list = datasource.classes()

            # Same weight for every class, broadcast to one weight per example.
            num_classes = len(class_list)
            self.sample_weights = np.array(
                [1 / num_classes for _ in range(num_classes)]
            )[target_list]

            # Different iid clients should have a different random seed for Generator
            self.random_seed = self.random_seed * int(client_id)
@@ -0,0 +1,128 @@
1
+ """
2
+ Samples data from a dataset, 1). biased across labels, and 2). the number of labels
3
+ (corresponding classes) in different clients is the same, and 3). part of clients
4
+ share same classes 4) the classes shared by these clients are partly used by other
5
+ clients.
6
+
7
+ This sampler implements the basic label quantity noniid as that in "label_quantity_noniid.py".
8
+ However, part of clients "consistent_clients_size" contain same classes "anchor_classes". Then,
9
+ the "keep_anchor_classes_size" classes of "consistent_clients" are also used in classes pool
10
+ to complete the class assignment.
11
+
12
+ For Example:
13
+ Setting per_client_classes_size = 3, anchor_classes=[2, 3, 9], consistent_clients=[0,1,N],
14
+ keep_anchor_classes_size=1 will induce the condition:
15
+
16
+ classes 1 2 3 ... 7 8 9
17
+ client1 0 350 350 0 0 350
18
+ client2 0 350 350 0 0 350
19
+ client3 100 20 0 0 100 0
20
+ client4 100 0 0 100 100 0
21
+ ...
22
+ clientN 0 350 350 0 0 350
23
+
24
+ We have N clients while K clients contain class c. As class c contains D_c samples,
25
+ each client in K will contain D_c / K samples of this class.
26
+ """
27
+
28
+ import numpy as np
29
+ import torch
30
+ from torch.utils.data import SubsetRandomSampler
31
+
32
+ from plato.config import Config
33
+ from plato.samplers import base
34
+
35
+ from plato.samplers import sampler_utils
36
+
37
+
38
class Sampler(base.Sampler):
    """Per-client sampler with quantity-based label skew and anchor classes.

    A randomly chosen group of `consistent_clients_size` clients is passed to
    `sampler_utils.assign_sub_classes` together with the configured
    `anchor_classes` and `keep_anchor_classes_size`; the index assignment
    itself is done there.

    `self.random_seed` is not set in this class; it is expected to be
    provided by `base.Sampler`.
    """

    def __init__(self, datasource, client_id, testing):
        """Build the client-to-indices map and keep this client's entry.

        Args:
            datasource: object exposing `get_train_set()`, `get_test_set()`
                (each with a `.targets` attribute) and `classes()`.
            client_id: 1-based client identifier (used in arithmetic as-is).
            testing: if truthy, use the test set labels instead of the
                training set labels.
        """
        super().__init__()
        self.client_id = client_id

        # Every client seeds NumPy with the same value: the class assignment
        # is recomputed inside each sampler, so all of them must arrive at
        # the same clients_dataidx_map.
        np.random.seed(self.random_seed)

        data_config = Config().data
        classes_per_client = data_config.per_client_classes_size
        anchors = data_config.anchor_classes
        consistent_count = data_config.consistent_clients_size
        kept_anchor_count = data_config.keep_anchor_classes_size
        client_count = Config().clients.total_clients

        assert classes_per_client == len(anchors)

        # Pick, without repetition, the clients that share the anchor classes.
        candidate_clients = list(range(client_count))
        self.consistent_clients = np.random.choice(
            candidate_clients, size=consistent_count, replace=False
        )
        self.anchor_classes = anchors
        self.keep_anchor_classes_size = kept_anchor_count

        # Labels (targets) of every example in the chosen split.
        split = datasource.get_test_set() if testing else datasource.get_train_set()
        self.targets_list = split.targets

        class_names = datasource.classes()
        class_ids = list(range(len(class_names)))

        # Start with an empty index array per client; quantity_label_skew
        # replaces this map with the real assignment.
        self.clients_dataidx_map = {}
        for slot in range(client_count):
            self.clients_dataidx_map[slot] = np.ndarray(0, dtype=np.int64)

        self.quantity_label_skew(
            dataset_labels=self.targets_list,
            dataset_classes=class_ids,
            num_clients=client_count,
            per_client_classes_size=classes_per_client,
        )

        # The map is indexed from 0 while client ids start at 1.
        self.subset_indices = self.clients_dataidx_map[client_id - 1]

    def quantity_label_skew(
        self, dataset_labels, dataset_classes, num_clients, per_client_classes_size
    ):
        """Fill `clients_dataidx_map` with the quantity-based label skew.

        Uses `sampler_utils.assign_fully_classes` when every client holds all
        classes; otherwise `sampler_utils.assign_sub_classes` with this
        sampler's anchor-class settings.
        """
        if per_client_classes_size != len(dataset_classes):
            self.clients_dataidx_map = sampler_utils.assign_sub_classes(
                dataset_labels,
                dataset_classes,
                num_clients,
                per_client_classes_size,
                anchor_classes=self.anchor_classes,
                consistent_clients=self.consistent_clients,
                keep_anchor_classes_size=self.keep_anchor_classes_size,
            )
            return

        # Each client contains the full set of classes.
        self.clients_dataidx_map = sampler_utils.assign_fully_classes(
            dataset_labels, dataset_classes, num_clients, self.client_id
        )

    def get(self):
        """Obtains an instance of the sampler."""
        generator = torch.Generator()
        generator.manual_seed(self.random_seed)
        return SubsetRandomSampler(self.subset_indices, generator=generator)

    def num_samples(self):
        """Returns the length of the dataset after sampling."""
        return len(self.subset_indices)

    def get_trainset_condition(self):
        """Returns an array of `(label, count)` rows for this client's subset."""
        subset_labels = np.array(self.targets_list)[self.subset_indices]
        labels, frequencies = np.unique(subset_labels, return_counts=True)
        return np.asarray((labels, frequencies)).T
@@ -0,0 +1,42 @@
1
+ """
2
+ Samples data from a dataset, biased across modalities in an
3
+ independent and identically distributed (IID) fashion.
4
+
5
+ Thus, all modalities of one sample are utilized as the input.
6
+
7
+ There is no difference between the train sampler and test sampler.
8
+ """
9
+
10
+ import numpy as np
11
+
12
+ from plato.samplers import base
13
+
14
+
15
class Sampler(base.Sampler):
    """Modality sampler that keeps every modality of the data source.

    `self.random_seed` is not set in this class; it is expected to be
    provided by `base.Sampler`.
    """

    def __init__(self, datasource, client_id):
        """Record the modalities offered by `datasource`.

        Args:
            datasource: data source; if it has `get_modality_name()`, that
                method supplies the modality names.
            client_id: client identifier, stored unchanged.
        """
        super().__init__()
        self.client_id = client_id

        if not hasattr(datasource, "get_modality_name"):
            # default: it only contains image data
            available = ["rgb"]
        else:
            available = datasource.get_modality_name()

        # No random draw happens in this class, but the global NumPy RNG is
        # still reseeded here.
        np.random.seed(self.random_seed)

        # All available modalities are kept for this client.
        self.subset_modalities = available

    def get(self):
        """Obtains the modality sampler.

        Note: the returned modality names act as a mask that tells the
        consumer which modalities to keep.
        """
        return self.subset_modalities

    def modality_size(self):
        """Returns how many modalities are used."""
        return len(self.subset_modalities)
@@ -0,0 +1,56 @@
1
+ """
2
+ Samples data from a dataset, biased across modalities in an
3
+ quantity-based nonIID fashion.
4
+
5
+ Thus, the quantity-based modality non-IID can be achieved by just keeping
6
+ one subset of modalities in each sample.
7
+
8
+ """
9
+
10
+ import numpy as np
11
+
12
+ from plato.config import Config
13
+ from plato.samplers import base
14
+
15
+
16
class Sampler(base.Sampler):
    """Modality sampler that keeps a random subset of modalities per client.

    `self.random_seed` is not set in this class; it is expected to be
    provided by `base.Sampler`.
    """

    def __init__(self, datasource, client_id):
        """Draw the modalities this client will use.

        Args:
            datasource: data source; if it has `get_modality_name()`, that
                method supplies the modality names.
            client_id: client identifier; anything `int()` accepts.
        """
        super().__init__()
        self.client_id = client_id

        if not hasattr(datasource, "get_modality_name"):
            # default: it only contains image data
            available = ["rgb"]
        else:
            available = datasource.get_modality_name()

        # Different clients should have a different bias across modalities
        np.random.seed(self.random_seed * int(client_id))

        # By default a client holds a single modality.
        subset_size = getattr(Config().data, "per_client_modalties_size", 1)

        # NOTE(review): the comparison is strict, so the fallback ["rgb"] with
        # the default size of 1 always fails here — confirm whether `<=` was
        # intended.
        assert subset_size < len(available)

        # Modalities kept for this client, chosen without repetition.
        self.subset_modalities = np.random.choice(
            available, subset_size, replace=False
        )

    def get(self):
        """Obtains the modality sampler.

        Note: the returned modality names act as a mask that tells the
        consumer which modalities to keep.
        """
        return self.subset_modalities

    def modality_size(self):
        """Returns how many modalities are used."""
        return len(self.subset_modalities)
@@ -0,0 +1,99 @@
1
+ """
2
+ A sampler for orthogonal cross-silo federated learning.
3
+ Each institution's clients have data of different classes.
4
+ """
5
+
6
+ import numpy as np
7
+ import torch
8
+ from torch.utils.data import WeightedRandomSampler, SubsetRandomSampler
9
+ from plato.config import Config
10
+
11
+ from plato.samplers import base
12
+
13
+
14
class Sampler(base.Sampler):
    """Sampler for orthogonal cross-silo setups: a client only receives
    non-zero sampling weight for the classes of its institution.

    `self.random_seed` is not set in this class; it is expected to be
    provided by `base.Sampler`.
    """

    def __init__(self, datasource, client_id, testing):
        """Compute per-example sampling weights for this client.

        Args:
            datasource: object exposing `targets()`, `get_test_set()` and
                `classes()`.
            client_id: 1-based identifier; ids above `total_clients` are
                treated as edge servers.
            testing: if truthy, weights are computed over the test set labels.
        """
        super().__init__()

        # Different clients should have a different bias across the labels
        np.random.seed(self.random_seed * int(client_id))

        self.partition_size = Config().data.partition_size

        # Labels (targets) of every example in the chosen split.
        if testing:
            target_list = datasource.get_test_set().targets
        else:
            target_list = datasource.targets()
        class_list = datasource.classes()

        max_client_id = int(Config().clients.total_clients)

        if client_id <= max_client_id:
            institution_id = (client_id - 1) % int(Config().algorithm.total_silos)
        else:
            # This client is an edge server
            institution_id = client_id - 1 - max_client_id

        if hasattr(Config().data, "institution_class_ids"):
            # Format: institutions separated by ";", class ids by ",".
            per_institution = [
                chunk.strip()
                for chunk in Config().data.institution_class_ids.split(";")
            ]
            own_classes = per_institution[institution_id]
            class_id_list = [int(token.strip()) for token in own_classes.split(",")]
        else:
            # No explicit mapping: split the class ids evenly across silos.
            all_class_ids = list(range(len(class_list)))
            own_classes = np.array_split(
                all_class_ids, Config().algorithm.total_silos
            )[institution_id]
            class_id_list = own_classes.tolist()

        wants_noniid = (
            hasattr(Config().data, "label_distribution")
            and Config().data.label_distribution == "noniid"
        )
        if wants_noniid:
            # Concentration parameter to be used in the Dirichlet distribution
            concentration = getattr(Config().data, "concentration", 1.0)
            class_proportions = np.random.dirichlet(
                np.repeat(concentration, len(class_id_list))
            )
        else:
            # Equal share for each of the institution's classes.
            class_proportions = [
                1.0 / len(class_id_list) for _ in range(len(class_id_list))
            ]

        # Classes outside the institution keep weight 0.
        target_proportions = [0] * len(class_list)
        for class_id, proportion in zip(class_id_list, class_proportions):
            target_proportions[class_id] = proportion
        target_proportions = np.asarray(target_proportions)

        self.sample_weights = target_proportions[target_list]

    def get(self):
        """Obtains an instance of the sampler.

        Returns:
            A `SubsetRandomSampler` over `partition_size` indices drawn
            without replacement according to `sample_weights`.
        """
        generator = torch.Generator()
        generator.manual_seed(self.random_seed)

        weighted = WeightedRandomSampler(
            weights=self.sample_weights,
            num_samples=self.partition_size,
            replacement=False,
            generator=generator,
        )
        chosen_indices = list(weighted)

        return SubsetRandomSampler(chosen_indices, generator=generator)

    def num_samples(self):
        """Returns the length of the dataset after sampling."""
        return self.partition_size