plato_learn-1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (179)
  1. plato/__init__.py +1 -0
  2. plato/algorithms/__init__.py +0 -0
  3. plato/algorithms/base.py +45 -0
  4. plato/algorithms/fedavg.py +48 -0
  5. plato/algorithms/fedavg_gan.py +79 -0
  6. plato/algorithms/fedavg_personalized.py +48 -0
  7. plato/algorithms/mistnet.py +52 -0
  8. plato/algorithms/registry.py +39 -0
  9. plato/algorithms/split_learning.py +89 -0
  10. plato/callbacks/__init__.py +0 -0
  11. plato/callbacks/client.py +56 -0
  12. plato/callbacks/handler.py +78 -0
  13. plato/callbacks/server.py +139 -0
  14. plato/callbacks/trainer.py +124 -0
  15. plato/client.py +67 -0
  16. plato/clients/__init__.py +0 -0
  17. plato/clients/base.py +467 -0
  18. plato/clients/edge.py +103 -0
  19. plato/clients/fedavg_personalized.py +40 -0
  20. plato/clients/mistnet.py +49 -0
  21. plato/clients/registry.py +43 -0
  22. plato/clients/self_supervised_learning.py +51 -0
  23. plato/clients/simple.py +218 -0
  24. plato/clients/split_learning.py +150 -0
  25. plato/config.py +339 -0
  26. plato/datasources/__init__.py +0 -0
  27. plato/datasources/base.py +123 -0
  28. plato/datasources/celeba.py +150 -0
  29. plato/datasources/cifar10.py +87 -0
  30. plato/datasources/cifar100.py +61 -0
  31. plato/datasources/cinic10.py +62 -0
  32. plato/datasources/coco.py +119 -0
  33. plato/datasources/datalib/__init__.py +0 -0
  34. plato/datasources/datalib/audio_extraction_tools.py +137 -0
  35. plato/datasources/datalib/data_utils.py +124 -0
  36. plato/datasources/datalib/flickr30kE_utils.py +336 -0
  37. plato/datasources/datalib/frames_extraction_tools.py +254 -0
  38. plato/datasources/datalib/gym_utils/__init__.py +0 -0
  39. plato/datasources/datalib/gym_utils/gym_trim.py +189 -0
  40. plato/datasources/datalib/modality_data_anntation_tools.py +163 -0
  41. plato/datasources/datalib/modality_extraction_base.py +59 -0
  42. plato/datasources/datalib/parse_datasets.py +212 -0
  43. plato/datasources/datalib/refer_utils/__init__.py +0 -0
  44. plato/datasources/datalib/refer_utils/referitgame_utils.py +237 -0
  45. plato/datasources/datalib/tiny_data_tools.py +81 -0
  46. plato/datasources/datalib/video_transform.py +79 -0
  47. plato/datasources/emnist.py +64 -0
  48. plato/datasources/fashion_mnist.py +41 -0
  49. plato/datasources/feature.py +24 -0
  50. plato/datasources/feature_dataset.py +15 -0
  51. plato/datasources/femnist.py +141 -0
  52. plato/datasources/flickr30k_entities.py +362 -0
  53. plato/datasources/gym.py +431 -0
  54. plato/datasources/huggingface.py +165 -0
  55. plato/datasources/kinetics.py +568 -0
  56. plato/datasources/mnist.py +44 -0
  57. plato/datasources/multimodal_base.py +328 -0
  58. plato/datasources/pascal_voc.py +56 -0
  59. plato/datasources/purchase.py +94 -0
  60. plato/datasources/qoenflx.py +127 -0
  61. plato/datasources/referitgame.py +330 -0
  62. plato/datasources/registry.py +119 -0
  63. plato/datasources/self_supervised_learning.py +98 -0
  64. plato/datasources/stl10.py +103 -0
  65. plato/datasources/texas.py +94 -0
  66. plato/datasources/tiny_imagenet.py +64 -0
  67. plato/datasources/yolov8.py +85 -0
  68. plato/models/__init__.py +0 -0
  69. plato/models/cnn_encoder.py +103 -0
  70. plato/models/dcgan.py +116 -0
  71. plato/models/general_multilayer.py +254 -0
  72. plato/models/huggingface.py +27 -0
  73. plato/models/lenet5.py +113 -0
  74. plato/models/multilayer.py +90 -0
  75. plato/models/multimodal/__init__.py +0 -0
  76. plato/models/multimodal/base_net.py +91 -0
  77. plato/models/multimodal/blending.py +142 -0
  78. plato/models/multimodal/fc_net.py +77 -0
  79. plato/models/multimodal/fusion_net.py +78 -0
  80. plato/models/multimodal/multimodal_module.py +152 -0
  81. plato/models/registry.py +99 -0
  82. plato/models/resnet.py +190 -0
  83. plato/models/torch_hub.py +19 -0
  84. plato/models/vgg.py +113 -0
  85. plato/models/vit.py +166 -0
  86. plato/models/yolov8.py +22 -0
  87. plato/processors/__init__.py +0 -0
  88. plato/processors/base.py +35 -0
  89. plato/processors/compress.py +46 -0
  90. plato/processors/decompress.py +48 -0
  91. plato/processors/feature.py +51 -0
  92. plato/processors/feature_additive_noise.py +48 -0
  93. plato/processors/feature_dequantize.py +34 -0
  94. plato/processors/feature_gaussian.py +17 -0
  95. plato/processors/feature_laplace.py +15 -0
  96. plato/processors/feature_quantize.py +34 -0
  97. plato/processors/feature_randomized_response.py +50 -0
  98. plato/processors/feature_unbatch.py +39 -0
  99. plato/processors/inbound_feature_tensors.py +39 -0
  100. plato/processors/model.py +55 -0
  101. plato/processors/model_compress.py +34 -0
  102. plato/processors/model_decompress.py +37 -0
  103. plato/processors/model_decrypt.py +41 -0
  104. plato/processors/model_deepcopy.py +21 -0
  105. plato/processors/model_dequantize.py +18 -0
  106. plato/processors/model_dequantize_qsgd.py +61 -0
  107. plato/processors/model_encrypt.py +43 -0
  108. plato/processors/model_quantize.py +18 -0
  109. plato/processors/model_quantize_qsgd.py +82 -0
  110. plato/processors/model_randomized_response.py +34 -0
  111. plato/processors/outbound_feature_ndarrays.py +38 -0
  112. plato/processors/pipeline.py +26 -0
  113. plato/processors/registry.py +124 -0
  114. plato/processors/structured_pruning.py +57 -0
  115. plato/processors/unstructured_pruning.py +73 -0
  116. plato/samplers/__init__.py +0 -0
  117. plato/samplers/all_inclusive.py +41 -0
  118. plato/samplers/base.py +31 -0
  119. plato/samplers/dirichlet.py +81 -0
  120. plato/samplers/distribution_noniid.py +132 -0
  121. plato/samplers/iid.py +53 -0
  122. plato/samplers/label_quantity_noniid.py +119 -0
  123. plato/samplers/mixed.py +44 -0
  124. plato/samplers/mixed_label_quantity_noniid.py +128 -0
  125. plato/samplers/modality_iid.py +42 -0
  126. plato/samplers/modality_quantity_noniid.py +56 -0
  127. plato/samplers/orthogonal.py +99 -0
  128. plato/samplers/registry.py +66 -0
  129. plato/samplers/sample_quantity_noniid.py +123 -0
  130. plato/samplers/sampler_utils.py +190 -0
  131. plato/servers/__init__.py +0 -0
  132. plato/servers/base.py +1395 -0
  133. plato/servers/fedavg.py +281 -0
  134. plato/servers/fedavg_cs.py +335 -0
  135. plato/servers/fedavg_gan.py +74 -0
  136. plato/servers/fedavg_he.py +106 -0
  137. plato/servers/fedavg_personalized.py +57 -0
  138. plato/servers/mistnet.py +67 -0
  139. plato/servers/registry.py +52 -0
  140. plato/servers/split_learning.py +109 -0
  141. plato/trainers/__init__.py +0 -0
  142. plato/trainers/base.py +99 -0
  143. plato/trainers/basic.py +649 -0
  144. plato/trainers/diff_privacy.py +178 -0
  145. plato/trainers/gan.py +330 -0
  146. plato/trainers/huggingface.py +173 -0
  147. plato/trainers/loss_criterion.py +70 -0
  148. plato/trainers/lr_schedulers.py +252 -0
  149. plato/trainers/optimizers.py +53 -0
  150. plato/trainers/pascal_voc.py +80 -0
  151. plato/trainers/registry.py +44 -0
  152. plato/trainers/self_supervised_learning.py +302 -0
  153. plato/trainers/split_learning.py +305 -0
  154. plato/trainers/tracking.py +96 -0
  155. plato/trainers/yolov8.py +41 -0
  156. plato/utils/__init__.py +0 -0
  157. plato/utils/count_parameters.py +30 -0
  158. plato/utils/csv_processor.py +26 -0
  159. plato/utils/data_loaders.py +148 -0
  160. plato/utils/decorators.py +24 -0
  161. plato/utils/fonts.py +23 -0
  162. plato/utils/homo_enc.py +187 -0
  163. plato/utils/reinforcement_learning/__init__.py +0 -0
  164. plato/utils/reinforcement_learning/policies/__init__.py +0 -0
  165. plato/utils/reinforcement_learning/policies/base.py +161 -0
  166. plato/utils/reinforcement_learning/policies/ddpg.py +75 -0
  167. plato/utils/reinforcement_learning/policies/registry.py +32 -0
  168. plato/utils/reinforcement_learning/policies/sac.py +343 -0
  169. plato/utils/reinforcement_learning/policies/td3.py +485 -0
  170. plato/utils/reinforcement_learning/rl_agent.py +142 -0
  171. plato/utils/reinforcement_learning/rl_server.py +113 -0
  172. plato/utils/rl_env.py +154 -0
  173. plato/utils/s3.py +141 -0
  174. plato/utils/trainer_utils.py +21 -0
  175. plato/utils/unary_encoding.py +47 -0
  176. plato_learn-1.1.dist-info/METADATA +35 -0
  177. plato_learn-1.1.dist-info/RECORD +179 -0
  178. plato_learn-1.1.dist-info/WHEEL +4 -0
  179. plato_learn-1.1.dist-info/licenses/LICENSE +201 -0
plato/samplers/registry.py
@@ -0,0 +1,66 @@
+ """
+ The registry for samplers designed to partition the dataset across the clients.
+
+ Having a registry of all available classes is convenient for retrieving an instance based
+ on a configuration at run-time.
+ """
+
+ import logging
+ from collections import OrderedDict
+
+ from plato.config import Config
+
+ from plato.samplers import (
+     iid,
+     dirichlet,
+     mixed,
+     orthogonal,
+     all_inclusive,
+     distribution_noniid,
+     label_quantity_noniid,
+     mixed_label_quantity_noniid,
+     sample_quantity_noniid,
+     modality_iid,
+     modality_quantity_noniid,
+ )
+
+ registered_samplers = OrderedDict(
+     [
+         ("iid", iid.Sampler),
+         ("noniid", dirichlet.Sampler),
+         ("mixed", mixed.Sampler),
+         ("orthogonal", orthogonal.Sampler),
+         ("all_inclusive", all_inclusive.Sampler),
+         ("distribution_noniid", distribution_noniid.Sampler),
+         ("label_quantity_noniid", label_quantity_noniid.Sampler),
+         ("mixed_label_quantity_noniid", mixed_label_quantity_noniid.Sampler),
+         ("sample_quantity_noniid", sample_quantity_noniid.Sampler),
+         ("modality_iid", modality_iid.Sampler),
+         ("modality_quantity_noniid", modality_quantity_noniid.Sampler),
+     ]
+ )
+
+
+ def get(datasource, client_id, testing=False, **kwargs):
+     """Get an instance of the sampler."""
+
+     sampler_type = (
+         kwargs["sampler_type"]
+         if "sampler_type" in kwargs
+         else Config().data.testset_sampler
+         if testing and hasattr(Config().data, "testset_sampler")
+         else Config().data.sampler
+     )
+     if testing:
+         logging.info("[Client #%d] Test set sampler: %s", client_id, sampler_type)
+     else:
+         logging.info("[Client #%d] Sampler: %s", client_id, sampler_type)
+
+     if sampler_type in registered_samplers:
+         registered_sampler = registered_samplers[sampler_type](
+             datasource, client_id, testing=testing
+         )
+     else:
+         raise ValueError(f"No such sampler: {sampler_type}")
+
+     return registered_sampler
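The `get` function above is a dispatch table keyed by the configuration's `sampler` setting (`testset_sampler` when sampling the test set), with a `sampler_type` keyword that overrides both. A minimal usage sketch follows, under stated assumptions: `datasource` is a placeholder for any Plato datasource object and a Plato configuration file is assumed to have been loaded; neither is part of this diff.

```python
# Hypothetical sketch, not part of the wheel: wiring a registered sampler
# into a PyTorch DataLoader. Assumes Config() has been initialized from a
# Plato configuration file, and that `datasource` is a placeholder for any
# Plato datasource object (its construction is omitted here).
from torch.utils.data import DataLoader

from plato.samplers import registry as sampler_registry

# Passing sampler_type= overrides Config().data.sampler at the call site.
sampler = sampler_registry.get(datasource, client_id=1, sampler_type="iid")

# Plato samplers expose get(), returning a torch.utils.data sampler that a
# DataLoader accepts directly; batch_size=32 is an arbitrary choice.
loader = DataLoader(datasource.get_train_set(), batch_size=32, sampler=sampler.get())
```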
plato/samplers/sample_quantity_noniid.py
@@ -0,0 +1,123 @@
+ """
+ Samples data from a dataset, biased across the local dataset sizes of clients.
+
+ This sampler implements one type of sample distribution skew:
+
+ Quantity skewness:
+     The local dataset sizes of clients follow the Dirichlet distribution, which is
+     parameterized by the "client_quantity_concentration".
+
+     Within each client, the sample sizes of different classes are the same.
+
+     For example:
+     Setting client_quantity_concentration = 0.1 induces extreme data-scale
+     imbalance between clients.
+     The sample sizes of clients follow the Dirichlet distribution:
+         classes   1   2   3  ...  8   9
+         client1   5   6   7       5   8
+         client2  50  45  67      49  56
+         ...
+         clientN   6   7  11      10   7
+
+ """
+
+ import numpy as np
+ import torch
+ from torch.utils.data import SubsetRandomSampler
+
+ from plato.config import Config
+ from plato.samplers import base
+ from plato.samplers import sampler_utils
+
+
+ class Sampler(base.Sampler):
+     """Create a data sampler for each client to use a divided partition of the
+     dataset, biased across partition sizes."""
+
+     def __init__(self, datasource, client_id, testing):
+         super().__init__()
+
+         self.client_id = client_id
+
+         np.random.seed(self.random_seed)
+
+         # Obtain the dataset information
+         if testing:
+             dataset = datasource.get_test_set()
+         else:
+             dataset = datasource.get_train_set()
+
+         # The list of labels (targets) for all the examples
+         self.targets_list = datasource.targets
+
+         self.dataset_size = len(dataset)
+
+         indices = list(range(self.dataset_size))
+
+         np.random.shuffle(indices)
+
+         # Concentration parameter to be used in the Dirichlet distribution
+         concentration = (
+             Config().data.client_quantity_concentration
+             if hasattr(Config().data, "client_quantity_concentration")
+             else 1.0
+         )
+
+         min_partition_size = Config().data.min_partition_size
+         total_clients = Config().clients.total_clients
+
+         self.subset_indices = self.sample_quantity_skew(
+             dataset_indices=indices,
+             dataset_size=self.dataset_size,
+             min_partition_size=min_partition_size,
+             concentration=concentration,
+             num_clients=total_clients,
+         )[client_id - 1]
+
+     def sample_quantity_skew(
+         self,
+         dataset_indices,
+         dataset_size,
+         min_partition_size,
+         concentration,
+         num_clients,
+     ):
+         """Create the quantity-based sample skewness."""
+         proportions = sampler_utils.create_dirichlet_skew(
+             total_size=dataset_size,
+             concentration=concentration,
+             min_partition_size=min_partition_size,
+             number_partitions=num_clients + 1,
+             is_extend_total_size=True,
+         )
+
+         proportions_range = (np.cumsum(proportions) * dataset_size).astype(int)[:-1]
+
+         required_total_size = proportions_range[-1]
+         extended_dataset_indices = sampler_utils.extend_indices(
+             indices=dataset_indices, required_total_size=required_total_size
+         )
+
+         # Obtain the subset of indices assigned to the current client
+         clients_assigned_idxs = np.split(extended_dataset_indices, proportions_range)[
+             :-1
+         ]
+
+         return clients_assigned_idxs
+
+     def get(self):
+         """Obtains an instance of the sampler."""
+         gen = torch.Generator()
+         gen.manual_seed(self.random_seed)
+         return SubsetRandomSampler(self.subset_indices, generator=gen)
+
+     def num_samples(self):
+         """Returns the number of samples after sampling."""
+         return len(self.subset_indices)
+
+     def get_sampled_data_condition(self):
+         """Get per-class sample counts of this client's sampled subset."""
+         targets_array = np.array(self.targets_list)
+         client_sampled_subset_labels = targets_array[self.subset_indices]
+         unique, counts = np.unique(client_sampled_subset_labels, return_counts=True)
+         return np.asarray((unique, counts)).T
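To make the cumsum-and-split arithmetic in `sample_quantity_skew` concrete, here is a self-contained NumPy sketch of the same partitioning logic with hypothetical sizes (5 clients, 1,000 samples); it mirrors the code above but is not part of the package.

```python
import numpy as np

rng = np.random.default_rng(1)
dataset_size, num_clients = 1000, 5

# As in the sampler, one extra partition (num_clients + 1) absorbs rounding.
proportions = rng.dirichlet(np.repeat(0.5, num_clients + 1))

# Cumulative proportions scaled to the dataset size become split boundaries;
# the final cumulative value is dropped, matching the [:-1] above.
boundaries = (np.cumsum(proportions) * dataset_size).astype(int)[:-1]

indices = rng.permutation(dataset_size)

# np.split with k boundaries yields k + 1 chunks; dropping the trailing chunk
# leaves exactly num_clients partitions with Dirichlet-skewed sizes.
partitions = np.split(indices, boundaries)[:-1]
print([len(p) for p in partitions])  # num_clients skewed partition sizes

# Client i (1-indexed in Plato) would receive partitions[i - 1].
```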
plato/samplers/sampler_utils.py
@@ -0,0 +1,190 @@
+ """
+ Utility functions that sample local datasets.
+ """
+
+ import numpy as np
+
+
+ def extend_indices(indices, required_total_size):
+     """Extend the indices to reach the required total size
+     by duplicating existing indices."""
+     # Add extra samples to make the total evenly divisible, if needed
+     if len(indices) < required_total_size:
+         while len(indices) < required_total_size:
+             indices += indices[: (required_total_size - len(indices))]
+     else:
+         indices = indices[:required_total_size]
+     assert len(indices) == required_total_size
+
+     return indices
+
+
+ def generate_left_classes_pool(anchor_classes, all_classes, keep_anchor_size=1):
+     """Generate the remaining classes pool by (1) removing the anchor classes
+     from all classes and (2) randomly selecting 'keep_anchor_size' anchor
+     classes to add back to the remaining class pool."""
+
+     if anchor_classes is None:
+         return all_classes
+
+     # Obtain a subset of classes from the anchor classes
+     left_anchor_classes = np.random.choice(
+         anchor_classes, size=keep_anchor_size, replace=False
+     )
+     # Remove the anchor classes from the whole set of classes
+     left_classes_id_list = [
+         class_id for class_id in all_classes if class_id not in anchor_classes
+     ]
+
+     # Combine the kept anchor classes and the remaining classes to
+     # obtain the classes pool for global class assignment
+     left_classes_id_list = left_anchor_classes.tolist() + left_classes_id_list
+
+     return left_classes_id_list
+
+
+ def assign_fully_classes(dataset_labels, dataset_classes, num_clients, client_id):
+     """Assign all classes to each client."""
+
+     # Define the mapping from client_id to sample indices
+     clients_dataidx_map = {
+         client_id: np.ndarray(0, dtype=np.int64) for client_id in range(num_clients)
+     }
+
+     dataset_labels = np.array(dataset_labels)
+
+     for class_id in dataset_classes:
+         idx_k = np.where(dataset_labels == class_id)[0]
+
+         # The samples of each class are evenly assigned to this client
+         split = np.array_split(idx_k, num_clients)
+         clients_dataidx_map[client_id] = np.append(
+             clients_dataidx_map[client_id], split[client_id]
+         )
+     return clients_dataidx_map
+
+
+ def assign_sub_classes(
+     dataset_labels,
+     dataset_classes,
+     num_clients,
+     per_client_classes_size,
+     anchor_classes=None,
+     consistent_clients=None,
+     keep_anchor_classes_size=None,
+ ):
+     """Assign a subset of classes to each client, along with the corresponding samples.
+
+     Args:
+         dataset_labels (list): a list of labels of the global samples
+         dataset_classes (list): a list containing the classes of the dataset
+         num_clients (int): total number of clients for class assignment
+         per_client_classes_size (int): the number of classes assigned to each client
+         anchor_classes (list, default None): subset of classes assigned to
+                                              "consistent_clients"
+         consistent_clients (list, default None): subset of clients that all share
+                                                  the anchor classes
+         keep_anchor_classes_size (int, default None): how many anchor classes are
+                                                       kept in the class pool used
+                                                       for global class assignment
+     """
+     # Define the mapping from client_id to sample indices
+     clients_dataidx_map = {
+         client_id: np.ndarray(0, dtype=np.int64) for client_id in range(num_clients)
+     }
+     dataset_labels = np.array(dataset_labels)
+
+     classes_assigned_count = {cls_i: 0 for cls_i in dataset_classes}
+     clients_contain_classes = {cli_i: [] for cli_i in range(num_clients)}
+
+     for client_id in range(num_clients):
+         if consistent_clients is not None and client_id in consistent_clients:
+             current_assigned_cls = anchor_classes
+             for assigned_cls in current_assigned_cls:
+                 classes_assigned_count[assigned_cls] += 1
+         else:
+             left_classes_id_list = generate_left_classes_pool(
+                 anchor_classes=anchor_classes,
+                 all_classes=dataset_classes,
+                 keep_anchor_size=keep_anchor_classes_size,
+             )
+
+             num_classes = len(left_classes_id_list)
+             current_assigned_cls_idx = client_id % num_classes
+             assigned_cls = dataset_classes[current_assigned_cls_idx]
+             current_assigned_cls = [assigned_cls]
+             classes_assigned_count[assigned_cls] += 1
+             j = 1
+             while j < per_client_classes_size:
+                 ind = np.random.choice(left_classes_id_list, size=1)[0]
+                 if ind not in current_assigned_cls:
+                     j = j + 1
+                     current_assigned_cls.append(ind)
+                     classes_assigned_count[ind] += 1
+         clients_contain_classes[client_id] = current_assigned_cls
+
+     for class_id in dataset_classes:
+         # Skip if this class is never assigned to any client
+         if classes_assigned_count[class_id] == 0:
+             continue
+
+         idx_k = np.where(dataset_labels == class_id)[0]
+
+         # The samples of the current class are evenly assigned to the corresponding clients
+         split = np.array_split(idx_k, classes_assigned_count[class_id])
+         ids = 0
+         for client_id in range(num_clients):
+             if class_id in clients_contain_classes[client_id]:
+                 clients_dataidx_map[client_id] = np.append(
+                     clients_dataidx_map[client_id], split[ids]
+                 )
+                 ids += 1
+     return clients_dataidx_map
+
+
+ def create_dirichlet_skew(
+     total_size,  # the total size from which to generate partitions
+     concentration,  # the concentration parameter of the Dirichlet distribution
+     number_partitions,  # the number of partitions
+     min_partition_size=None,  # the minimum required size for partitions
+     is_extend_total_size=False,
+ ):
+     """Create the distribution skewness based on the Dirichlet distribution.
+
+     Note:
+         is_extend_total_size (boolean) determines whether the partitions satisfy
+         min_partition_size by directly extending the total data size.
+     """
+     if min_partition_size is not None:
+         if not is_extend_total_size:
+             min_size = 0
+             while min_size < min_partition_size:
+                 proportions = np.random.dirichlet(
+                     np.repeat(concentration, number_partitions)
+                 )
+
+                 proportions = proportions / proportions.sum()
+                 min_size = np.min(proportions * total_size)
+
+         else:  # extend the total size to satisfy the minimum requirement
+             minimum_proportion_bound = float(min_partition_size / total_size)
+
+             proportions = np.random.dirichlet(
+                 np.repeat(concentration, number_partitions)
+             )
+
+             proportions = proportions / proportions.sum()
+
+             # Raise each proportion to the minimum bound when necessary
+             proportions = [
+                 max(proportion, minimum_proportion_bound)
+                 for proportion in proportions
+             ]
+
+     else:
+         proportions = np.random.dirichlet(np.repeat(concentration, number_partitions))
+
+     return proportions
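The quantity-skew claim in `sample_quantity_noniid.py` — that a small concentration such as 0.1 produces extreme imbalance — is easy to verify against the no-minimum branch of `create_dirichlet_skew`. A short, self-contained demonstration with illustrative values:

```python
import numpy as np

# Partition 10,000 samples across 5 partitions at several concentrations,
# mirroring create_dirichlet_skew when min_partition_size is None.
np.random.seed(42)
for concentration in (100.0, 1.0, 0.1):
    proportions = np.random.dirichlet(np.repeat(concentration, 5))
    sizes = (proportions * 10_000).astype(int)
    print(f"concentration={concentration}: {sorted(sizes, reverse=True)}")

# Expected pattern: a large concentration yields near-equal partition sizes,
# while concentration = 0.1 lets one or two partitions dominate.
```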