plato-learn 1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (179)
  1. plato/__init__.py +1 -0
  2. plato/algorithms/__init__.py +0 -0
  3. plato/algorithms/base.py +45 -0
  4. plato/algorithms/fedavg.py +48 -0
  5. plato/algorithms/fedavg_gan.py +79 -0
  6. plato/algorithms/fedavg_personalized.py +48 -0
  7. plato/algorithms/mistnet.py +52 -0
  8. plato/algorithms/registry.py +39 -0
  9. plato/algorithms/split_learning.py +89 -0
  10. plato/callbacks/__init__.py +0 -0
  11. plato/callbacks/client.py +56 -0
  12. plato/callbacks/handler.py +78 -0
  13. plato/callbacks/server.py +139 -0
  14. plato/callbacks/trainer.py +124 -0
  15. plato/client.py +67 -0
  16. plato/clients/__init__.py +0 -0
  17. plato/clients/base.py +467 -0
  18. plato/clients/edge.py +103 -0
  19. plato/clients/fedavg_personalized.py +40 -0
  20. plato/clients/mistnet.py +49 -0
  21. plato/clients/registry.py +43 -0
  22. plato/clients/self_supervised_learning.py +51 -0
  23. plato/clients/simple.py +218 -0
  24. plato/clients/split_learning.py +150 -0
  25. plato/config.py +339 -0
  26. plato/datasources/__init__.py +0 -0
  27. plato/datasources/base.py +123 -0
  28. plato/datasources/celeba.py +150 -0
  29. plato/datasources/cifar10.py +87 -0
  30. plato/datasources/cifar100.py +61 -0
  31. plato/datasources/cinic10.py +62 -0
  32. plato/datasources/coco.py +119 -0
  33. plato/datasources/datalib/__init__.py +0 -0
  34. plato/datasources/datalib/audio_extraction_tools.py +137 -0
  35. plato/datasources/datalib/data_utils.py +124 -0
  36. plato/datasources/datalib/flickr30kE_utils.py +336 -0
  37. plato/datasources/datalib/frames_extraction_tools.py +254 -0
  38. plato/datasources/datalib/gym_utils/__init__.py +0 -0
  39. plato/datasources/datalib/gym_utils/gym_trim.py +189 -0
  40. plato/datasources/datalib/modality_data_anntation_tools.py +163 -0
  41. plato/datasources/datalib/modality_extraction_base.py +59 -0
  42. plato/datasources/datalib/parse_datasets.py +212 -0
  43. plato/datasources/datalib/refer_utils/__init__.py +0 -0
  44. plato/datasources/datalib/refer_utils/referitgame_utils.py +237 -0
  45. plato/datasources/datalib/tiny_data_tools.py +81 -0
  46. plato/datasources/datalib/video_transform.py +79 -0
  47. plato/datasources/emnist.py +64 -0
  48. plato/datasources/fashion_mnist.py +41 -0
  49. plato/datasources/feature.py +24 -0
  50. plato/datasources/feature_dataset.py +15 -0
  51. plato/datasources/femnist.py +141 -0
  52. plato/datasources/flickr30k_entities.py +362 -0
  53. plato/datasources/gym.py +431 -0
  54. plato/datasources/huggingface.py +165 -0
  55. plato/datasources/kinetics.py +568 -0
  56. plato/datasources/mnist.py +44 -0
  57. plato/datasources/multimodal_base.py +328 -0
  58. plato/datasources/pascal_voc.py +56 -0
  59. plato/datasources/purchase.py +94 -0
  60. plato/datasources/qoenflx.py +127 -0
  61. plato/datasources/referitgame.py +330 -0
  62. plato/datasources/registry.py +119 -0
  63. plato/datasources/self_supervised_learning.py +98 -0
  64. plato/datasources/stl10.py +103 -0
  65. plato/datasources/texas.py +94 -0
  66. plato/datasources/tiny_imagenet.py +64 -0
  67. plato/datasources/yolov8.py +85 -0
  68. plato/models/__init__.py +0 -0
  69. plato/models/cnn_encoder.py +103 -0
  70. plato/models/dcgan.py +116 -0
  71. plato/models/general_multilayer.py +254 -0
  72. plato/models/huggingface.py +27 -0
  73. plato/models/lenet5.py +113 -0
  74. plato/models/multilayer.py +90 -0
  75. plato/models/multimodal/__init__.py +0 -0
  76. plato/models/multimodal/base_net.py +91 -0
  77. plato/models/multimodal/blending.py +142 -0
  78. plato/models/multimodal/fc_net.py +77 -0
  79. plato/models/multimodal/fusion_net.py +78 -0
  80. plato/models/multimodal/multimodal_module.py +152 -0
  81. plato/models/registry.py +99 -0
  82. plato/models/resnet.py +190 -0
  83. plato/models/torch_hub.py +19 -0
  84. plato/models/vgg.py +113 -0
  85. plato/models/vit.py +166 -0
  86. plato/models/yolov8.py +22 -0
  87. plato/processors/__init__.py +0 -0
  88. plato/processors/base.py +35 -0
  89. plato/processors/compress.py +46 -0
  90. plato/processors/decompress.py +48 -0
  91. plato/processors/feature.py +51 -0
  92. plato/processors/feature_additive_noise.py +48 -0
  93. plato/processors/feature_dequantize.py +34 -0
  94. plato/processors/feature_gaussian.py +17 -0
  95. plato/processors/feature_laplace.py +15 -0
  96. plato/processors/feature_quantize.py +34 -0
  97. plato/processors/feature_randomized_response.py +50 -0
  98. plato/processors/feature_unbatch.py +39 -0
  99. plato/processors/inbound_feature_tensors.py +39 -0
  100. plato/processors/model.py +55 -0
  101. plato/processors/model_compress.py +34 -0
  102. plato/processors/model_decompress.py +37 -0
  103. plato/processors/model_decrypt.py +41 -0
  104. plato/processors/model_deepcopy.py +21 -0
  105. plato/processors/model_dequantize.py +18 -0
  106. plato/processors/model_dequantize_qsgd.py +61 -0
  107. plato/processors/model_encrypt.py +43 -0
  108. plato/processors/model_quantize.py +18 -0
  109. plato/processors/model_quantize_qsgd.py +82 -0
  110. plato/processors/model_randomized_response.py +34 -0
  111. plato/processors/outbound_feature_ndarrays.py +38 -0
  112. plato/processors/pipeline.py +26 -0
  113. plato/processors/registry.py +124 -0
  114. plato/processors/structured_pruning.py +57 -0
  115. plato/processors/unstructured_pruning.py +73 -0
  116. plato/samplers/__init__.py +0 -0
  117. plato/samplers/all_inclusive.py +41 -0
  118. plato/samplers/base.py +31 -0
  119. plato/samplers/dirichlet.py +81 -0
  120. plato/samplers/distribution_noniid.py +132 -0
  121. plato/samplers/iid.py +53 -0
  122. plato/samplers/label_quantity_noniid.py +119 -0
  123. plato/samplers/mixed.py +44 -0
  124. plato/samplers/mixed_label_quantity_noniid.py +128 -0
  125. plato/samplers/modality_iid.py +42 -0
  126. plato/samplers/modality_quantity_noniid.py +56 -0
  127. plato/samplers/orthogonal.py +99 -0
  128. plato/samplers/registry.py +66 -0
  129. plato/samplers/sample_quantity_noniid.py +123 -0
  130. plato/samplers/sampler_utils.py +190 -0
  131. plato/servers/__init__.py +0 -0
  132. plato/servers/base.py +1395 -0
  133. plato/servers/fedavg.py +281 -0
  134. plato/servers/fedavg_cs.py +335 -0
  135. plato/servers/fedavg_gan.py +74 -0
  136. plato/servers/fedavg_he.py +106 -0
  137. plato/servers/fedavg_personalized.py +57 -0
  138. plato/servers/mistnet.py +67 -0
  139. plato/servers/registry.py +52 -0
  140. plato/servers/split_learning.py +109 -0
  141. plato/trainers/__init__.py +0 -0
  142. plato/trainers/base.py +99 -0
  143. plato/trainers/basic.py +649 -0
  144. plato/trainers/diff_privacy.py +178 -0
  145. plato/trainers/gan.py +330 -0
  146. plato/trainers/huggingface.py +173 -0
  147. plato/trainers/loss_criterion.py +70 -0
  148. plato/trainers/lr_schedulers.py +252 -0
  149. plato/trainers/optimizers.py +53 -0
  150. plato/trainers/pascal_voc.py +80 -0
  151. plato/trainers/registry.py +44 -0
  152. plato/trainers/self_supervised_learning.py +302 -0
  153. plato/trainers/split_learning.py +305 -0
  154. plato/trainers/tracking.py +96 -0
  155. plato/trainers/yolov8.py +41 -0
  156. plato/utils/__init__.py +0 -0
  157. plato/utils/count_parameters.py +30 -0
  158. plato/utils/csv_processor.py +26 -0
  159. plato/utils/data_loaders.py +148 -0
  160. plato/utils/decorators.py +24 -0
  161. plato/utils/fonts.py +23 -0
  162. plato/utils/homo_enc.py +187 -0
  163. plato/utils/reinforcement_learning/__init__.py +0 -0
  164. plato/utils/reinforcement_learning/policies/__init__.py +0 -0
  165. plato/utils/reinforcement_learning/policies/base.py +161 -0
  166. plato/utils/reinforcement_learning/policies/ddpg.py +75 -0
  167. plato/utils/reinforcement_learning/policies/registry.py +32 -0
  168. plato/utils/reinforcement_learning/policies/sac.py +343 -0
  169. plato/utils/reinforcement_learning/policies/td3.py +485 -0
  170. plato/utils/reinforcement_learning/rl_agent.py +142 -0
  171. plato/utils/reinforcement_learning/rl_server.py +113 -0
  172. plato/utils/rl_env.py +154 -0
  173. plato/utils/s3.py +141 -0
  174. plato/utils/trainer_utils.py +21 -0
  175. plato/utils/unary_encoding.py +47 -0
  176. plato_learn-1.1.dist-info/METADATA +35 -0
  177. plato_learn-1.1.dist-info/RECORD +179 -0
  178. plato_learn-1.1.dist-info/WHEEL +4 -0
  179. plato_learn-1.1.dist-info/licenses/LICENSE +201 -0
plato/datasources/datalib/refer_utils/referitgame_utils.py
@@ -0,0 +1,237 @@
+ import sys
+ import os.path as osp
+ import json
+ import logging
+ import pickle
+ import time
+ import itertools
+ import skimage.io as io
+
+ from pprint import pprint
+ import numpy as np
+
+
+ class REFER:
+     def __init__(
+         self,
+         data_root,  # the root folder where the data is stored
+         image_dataroot,  # the path to the source images
+         dataset="refcoco",
+         splitBy="unc",
+     ):
+         # provide the data_root folder, which contains refclef, refcoco,
+         # refcoco+ and refcocog, along with the dataset name and splitBy
+         # information, e.g., dataset = 'refcoco', splitBy = 'unc'
+         print("loading dataset %s into memory..." % dataset)
+
+         self.ROOT_DIR = osp.abspath(osp.dirname(__file__))
+         self.DATA_DIR = osp.join(data_root, dataset)
+         if dataset in ["refcoco", "refcoco+", "refcocog"]:
+             self.IMAGE_DIR = image_dataroot
+         else:
+             print("No refer dataset is called [%s]" % dataset)
+             sys.exit()
+
+         # load refs from data/dataset/refs(splitBy).p
+         tic = time.perf_counter()
+         ref_file = osp.join(self.DATA_DIR, "refs(" + splitBy + ").p")
+         self.data = {}
+         self.data["dataset"] = dataset
+         with open(ref_file, "rb") as fin:
+             self.data["refs"] = pickle.load(fin)
+
+         # load annotations from data/dataset/instances.json
+         instances_file = osp.join(self.DATA_DIR, "instances.json")
+         with open(instances_file, "r") as fin:
+             instances = json.load(fin)
+         self.data["images"] = instances["images"]
+         self.data["annotations"] = instances["annotations"]
+         self.data["categories"] = instances["categories"]
+
+         # create index
+         self.createIndex()
+         print("DONE (t=%.2fs)" % (time.perf_counter() - tic))
+
+     def createIndex(self):
+         # create the following mappings:
+         # 1)  Refs:         {ref_id: ref}
+         # 2)  Anns:         {ann_id: ann}
+         # 3)  Imgs:         {image_id: image}
+         # 4)  Cats:         {category_id: category_name}
+         # 5)  Sents:        {sent_id: sent}
+         # 6)  imgToRefs:    {image_id: refs}
+         # 7)  imgToAnns:    {image_id: anns}
+         # 8)  refToAnn:     {ref_id: ann}
+         # 9)  annToRef:     {ann_id: ref}
+         # 10) catToRefs:    {category_id: refs}
+         # 11) sentToRef:    {sent_id: ref}
+         # 12) sentToTokens: {sent_id: tokens}
+         print("creating index...")
+         # fetch info from instances
+         Anns, Imgs, Cats, imgToAnns = {}, {}, {}, {}
+         for ann in self.data["annotations"]:
+             Anns[ann["id"]] = ann
+             imgToAnns[ann["image_id"]] = imgToAnns.get(ann["image_id"], []) + [ann]
+         for img in self.data["images"]:
+             img["file_name"] = img["file_name"].strip().split("_")[-1]
+             Imgs[img["id"]] = img
+         for cat in self.data["categories"]:
+             Cats[cat["id"]] = cat["name"]
+
+         # fetch info from refs
+         Refs, imgToRefs, refToAnn, annToRef, catToRefs = {}, {}, {}, {}, {}
+         Sents, sentToRef, sentToTokens = {}, {}, {}
+         for ref in self.data["refs"]:
+             # ids
+             ref_id = ref["ref_id"]
+             ann_id = ref["ann_id"]
+             category_id = ref["category_id"]
+             image_id = ref["image_id"]
+
+             # add mappings related to this ref
+             Refs[ref_id] = ref
+             imgToRefs[image_id] = imgToRefs.get(image_id, []) + [ref]
+             catToRefs[category_id] = catToRefs.get(category_id, []) + [ref]
+             refToAnn[ref_id] = Anns[ann_id]
+             annToRef[ann_id] = ref
+
+             # add mappings for its sentences
+             for sent in ref["sentences"]:
+                 Sents[sent["sent_id"]] = sent
+                 sentToRef[sent["sent_id"]] = ref
+                 sentToTokens[sent["sent_id"]] = sent["tokens"]
+
+         # create class members
+         self.Refs = Refs
+         self.Anns = Anns
+         self.Imgs = Imgs
+         self.Cats = Cats
+         self.Sents = Sents
+         self.imgToRefs = imgToRefs
+         self.imgToAnns = imgToAnns
+         self.refToAnn = refToAnn
+         self.annToRef = annToRef
+         self.catToRefs = catToRefs
+         self.sentToRef = sentToRef
+         self.sentToTokens = sentToTokens
+         print("index created.")
+
+     def getRefIds(self, image_ids=[], cat_ids=[], ref_ids=[], split=""):
+         image_ids = image_ids if isinstance(image_ids, list) else [image_ids]
+         cat_ids = cat_ids if isinstance(cat_ids, list) else [cat_ids]
+         ref_ids = ref_ids if isinstance(ref_ids, list) else [ref_ids]
+
+         if len(image_ids) == len(cat_ids) == len(ref_ids) == len(split) == 0:
+             refs = self.data["refs"]
+         else:
+             if not len(image_ids) == 0:
+                 # flatten the per-image ref lists so that refs is a list of
+                 # refs, not a list of lists
+                 refs = list(
+                     itertools.chain.from_iterable(
+                         self.imgToRefs[image_id] for image_id in image_ids
+                     )
+                 )
+             else:
+                 refs = self.data["refs"]
+             if not len(cat_ids) == 0:
+                 refs = [ref for ref in refs if ref["category_id"] in cat_ids]
+             if not len(ref_ids) == 0:
+                 refs = [ref for ref in refs if ref["ref_id"] in ref_ids]
+             if not len(split) == 0:
+                 if split in ["testA", "testB", "testC"]:
+                     refs = [
+                         ref for ref in refs if split[-1] in ref["split"]
+                     ]  # we also consider testAB, testBC, ...
+                 elif split in ["testAB", "testBC", "testAC"]:
+                     refs = [
+                         ref for ref in refs if ref["split"] == split
+                     ]  # rarely used
+                 elif split == "test":
+                     refs = [ref for ref in refs if "test" in ref["split"]]
+                 elif split == "train" or split == "val":
+                     refs = [ref for ref in refs if ref["split"] == split]
+                 else:
+                     print("No such split [%s]" % split)
+                     sys.exit()
+         ref_ids = [ref["ref_id"] for ref in refs]
+         return ref_ids
+
+     def getAnnIds(self, image_ids=[], cat_ids=[], ref_ids=[]):
+         image_ids = image_ids if isinstance(image_ids, list) else [image_ids]
+         cat_ids = cat_ids if isinstance(cat_ids, list) else [cat_ids]
+         ref_ids = ref_ids if isinstance(ref_ids, list) else [ref_ids]
+
+         if len(image_ids) == len(cat_ids) == len(ref_ids) == 0:
+             ann_ids = [ann["id"] for ann in self.data["annotations"]]
+         else:
+             if not len(image_ids) == 0:
+                 lists = [
+                     self.imgToAnns[image_id]
+                     for image_id in image_ids
+                     if image_id in self.imgToAnns
+                 ]  # list of [anns]
+                 anns = list(itertools.chain.from_iterable(lists))
+             else:
+                 anns = self.data["annotations"]
+             if not len(cat_ids) == 0:
+                 anns = [ann for ann in anns if ann["category_id"] in cat_ids]
+             ann_ids = [ann["id"] for ann in anns]
+             if not len(ref_ids) == 0:
+                 # keep only the annotations referenced by the given refs
+                 # (the original computed this intersection but never used it)
+                 ids = set(ann_ids).intersection(
+                     set([self.Refs[ref_id]["ann_id"] for ref_id in ref_ids])
+                 )
+                 ann_ids = list(ids)
+         return ann_ids
+
+     def getImgIds(self, ref_ids=[]):
+         ref_ids = ref_ids if isinstance(ref_ids, list) else [ref_ids]
+
+         if not len(ref_ids) == 0:
+             image_ids = list(set([self.Refs[ref_id]["image_id"] for ref_id in ref_ids]))
+         else:
+             image_ids = list(self.Imgs.keys())
+         return image_ids
+
+     def getCatIds(self):
+         return list(self.Cats.keys())
+
+     def loadRefs(self, ref_ids=[]):
+         if isinstance(ref_ids, list):
+             return [self.Refs[ref_id] for ref_id in ref_ids]
+         elif isinstance(ref_ids, int):
+             return [self.Refs[ref_ids]]
+
+     def loadAnns(self, ann_ids=[]):
+         if isinstance(ann_ids, list):
+             return [self.Anns[ann_id] for ann_id in ann_ids]
+         elif isinstance(ann_ids, (int, str)):  # 'unicode' is Python 2; use str
+             return [self.Anns[ann_ids]]
+
+     def loadImgs(self, image_ids=[]):
+         if isinstance(image_ids, list):
+             return [self.Imgs[image_id] for image_id in image_ids]
+         elif isinstance(image_ids, int):
+             return [self.Imgs[image_ids]]
+
+     def loadImgspath(self, image_ids=[]):
+         if isinstance(image_ids, list):
+             return [
+                 osp.join(self.IMAGE_DIR, self.Imgs[image_id]["file_name"])
+                 for image_id in image_ids
+             ]
+         elif isinstance(image_ids, int):
+             return [osp.join(self.IMAGE_DIR, self.Imgs[image_ids]["file_name"])]
+
+     def loadImgsData(self, image_ids=[]):  # the image_ids are obtained from refs
+         if isinstance(image_ids, list):
+             return [
+                 io.imread(osp.join(self.IMAGE_DIR, self.Imgs[image_id]["file_name"]))
+                 for image_id in image_ids
+             ]
+         elif isinstance(image_ids, int):
+             return [
+                 io.imread(osp.join(self.IMAGE_DIR, self.Imgs[image_ids]["file_name"]))
+             ]
+
+     def loadCats(self, cat_ids=[]):
+         if isinstance(cat_ids, list):
+             return [self.Cats[cat_id] for cat_id in cat_ids]
+         elif isinstance(cat_ids, int):
+             return [self.Cats[cat_ids]]
+
+     def getRefBox(self, ref_id):
+         ann = self.refToAnn[ref_id]
+         return ann["bbox"]  # [x, y, w, h]
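For orientation, a minimal usage sketch of the REFER wrapper above. The paths are hypothetical: data_root must contain a refcoco folder holding refs(unc).p and instances.json, and image_dataroot must point at the matching MS COCO images, all downloaded separately.

    from plato.datasources.datalib.refer_utils.referitgame_utils import REFER

    # Hypothetical local paths to pre-downloaded ReferItGame/COCO data.
    refer = REFER(
        data_root="data/refer",
        image_dataroot="data/coco/train2014",
        dataset="refcoco",
        splitBy="unc",
    )

    train_ref_ids = refer.getRefIds(split="train")
    ref = refer.loadRefs(train_ref_ids[0])[0]
    print(ref["sentences"][0]["sent"])     # one referring expression
    print(refer.getRefBox(ref["ref_id"]))  # its [x, y, w, h] bounding box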
plato/datasources/datalib/tiny_data_tools.py
@@ -0,0 +1,81 @@
+ """
+ Tools for creating and processing the tiny dataset.
+
+ The tiny dataset is sampled from the whole dataset.
+ """
+
+ import os
+ import collections
+ import logging
+
+ import pandas as pd
+ import numpy as np
+
+
+ def create_tiny_kinetics_anno(kinetics_annotation_files_info, num_samples, random_seed):
+     """Creates the annotation files for tiny Kinetics.
+
+     Args:
+         kinetics_annotation_files_info (dict): a dict containing the annotation
+             files for the different splits, e.g., {"train": xxx, "val": xxx}.
+         num_samples (int): the number of samples used to create the
+             tiny dataset.
+         random_seed (int): the random seed for sampling.
+     """
+     train_anno_file_path = kinetics_annotation_files_info["train"]
+     test_anno_file_path = kinetics_annotation_files_info["test"]
+     val_anno_file_path = kinetics_annotation_files_info["val"]
+     np.random.seed(random_seed)
+
+     train_anno_df = pd.read_csv(train_anno_file_path)
+     train_selected_samples_df = train_anno_df.iloc[:num_samples]
+
+     train_selected_classes = train_selected_samples_df["label"].tolist()
+
+     # select samples from the val annotation file based on the train classes
+     def select_anchored_samples(src_anno_df, anchor_classes):
+         selected_df = None
+         counted_classes = collections.Counter(anchor_classes)
+         for cls in list(counted_classes.keys()):
+             num_cls = counted_classes[cls]
+             cls_anno_df = src_anno_df[src_anno_df["label"] == cls]
+             len_cls_df = len(cls_anno_df)
+             selected_samples_idx = np.random.choice(
+                 list(range(len_cls_df)), size=num_cls
+             )
+             cls_selected_df = cls_anno_df.iloc[selected_samples_idx]
+             if selected_df is None:
+                 selected_df = cls_selected_df
+             else:
+                 selected_df = pd.concat([selected_df, cls_selected_df])
+         return selected_df
+
+     def save_anno_df(anno_file_path, selected_anno_df):
+         anno_file_base_dir = os.path.dirname(anno_file_path)
+         anno_file_base_name, ext_type = os.path.basename(anno_file_path).split(".")
+         tiny_anno_file_path = os.path.join(
+             anno_file_base_dir, anno_file_base_name + "_tiny." + ext_type
+         )
+
+         if os.path.exists(tiny_anno_file_path):
+             logging.info(
+                 "Annotation file for tiny data exists: %s; using it directly",
+                 tiny_anno_file_path,
+             )
+         else:
+             selected_anno_df.to_csv(path_or_buf=tiny_anno_file_path, index=False)
+
+     test_anno_df = pd.read_csv(test_anno_file_path)
+     test_selected_df = test_anno_df.iloc[:num_samples]
+
+     val_anno_df = pd.read_csv(val_anno_file_path)
+     val_selected_df = select_anchored_samples(
+         src_anno_df=val_anno_df, anchor_classes=train_selected_classes
+     )
+
+     save_anno_df(
+         anno_file_path=train_anno_file_path, selected_anno_df=train_selected_samples_df
+     )
+     save_anno_df(anno_file_path=test_anno_file_path, selected_anno_df=test_selected_df)
+     save_anno_df(anno_file_path=val_anno_file_path, selected_anno_df=val_selected_df)
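A quick sketch of how this helper might be invoked; the CSV paths are hypothetical and assume Kinetics-style annotation files with a "label" column.

    from plato.datasources.datalib.tiny_data_tools import create_tiny_kinetics_anno

    # Hypothetical annotation paths; each CSV needs a "label" column.
    create_tiny_kinetics_anno(
        kinetics_annotation_files_info={
            "train": "data/kinetics/annotations/train.csv",
            "test": "data/kinetics/annotations/test.csv",
            "val": "data/kinetics/annotations/val.csv",
        },
        num_samples=100,
        random_seed=42,
    )
    # Writes train_tiny.csv, test_tiny.csv and val_tiny.csv next to the originals.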
plato/datasources/datalib/video_transform.py
@@ -0,0 +1,79 @@
+ """
+ The transforms for the video classification task.
+ """
+
+ import torch
+ import torch.nn as nn
+ from torchvision.transforms import transforms
+
+
+ class ConvertBHWCtoBCHW(nn.Module):
+     """Converts a tensor from (B, H, W, C) to (B, C, H, W)."""
+
+     def forward(self, vid: torch.Tensor) -> torch.Tensor:
+         """Changes the order of the tensor dimensions."""
+         return vid.permute(0, 3, 1, 2)
+
+
+ class ConvertBCHWtoCBHW(nn.Module):
+     """Converts a tensor from (B, C, H, W) to (C, B, H, W)."""
+
+     def forward(self, vid: torch.Tensor) -> torch.Tensor:
+         """Changes the order of the tensor dimensions."""
+         return vid.permute(1, 0, 2, 3)
+
+
+ class VideoClassificationTrainTransformer:
+     """The transform pipeline that processes data for training a video classification model."""
+
+     def __init__(
+         self,
+         resize_size,
+         crop_size,
+         mean=(0.43216, 0.394666, 0.37645),
+         std=(0.22803, 0.22145, 0.216989),
+         hflip_prob=0.5,
+     ):
+         trans = [
+             ConvertBHWCtoBCHW(),
+             transforms.ConvertImageDtype(torch.float32),
+             transforms.Resize(resize_size),
+         ]
+         if hflip_prob > 0:
+             trans.append(transforms.RandomHorizontalFlip(hflip_prob))
+         trans.extend(
+             [
+                 transforms.Normalize(mean=mean, std=std),
+                 transforms.RandomCrop(crop_size),
+                 ConvertBCHWtoCBHW(),
+             ]
+         )
+         self.transforms = transforms.Compose(trans)
+
+     def __call__(self, video_data):
+         return self.transforms(video_data)
+
+
+ class VideoClassificationEvalTransformer:
+     """The transform pipeline that processes data for evaluating a video classification model."""
+
+     def __init__(
+         self,
+         resize_size,
+         crop_size,
+         mean=(0.43216, 0.394666, 0.37645),
+         std=(0.22803, 0.22145, 0.216989),
+     ):
+         self.transforms = transforms.Compose(
+             [
+                 ConvertBHWCtoBCHW(),
+                 transforms.ConvertImageDtype(torch.float32),
+                 transforms.Resize(resize_size),
+                 transforms.Normalize(mean=mean, std=std),
+                 transforms.CenterCrop(crop_size),
+                 ConvertBCHWtoCBHW(),
+             ]
+         )
+
+     def __call__(self, video_data):
+         return self.transforms(video_data)
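A minimal sketch of the train pipeline applied to a dummy clip; the clip shape and sizes are made up, but follow the (B, H, W, C) layout the first transform expects.

    import torch

    from plato.datasources.datalib.video_transform import (
        VideoClassificationTrainTransformer,
    )

    # A dummy 16-frame RGB clip in (B, H, W, C) layout with uint8 pixels.
    clip = torch.randint(0, 256, (16, 128, 171, 3), dtype=torch.uint8)

    transform = VideoClassificationTrainTransformer(
        resize_size=(128, 171), crop_size=(112, 112)
    )
    print(transform(clip).shape)  # torch.Size([3, 16, 112, 112]), i.e., (C, B, H, W)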
plato/datasources/emnist.py
@@ -0,0 +1,64 @@
+ """
+ The Extended MNIST (EMNIST) dataset from the torchvision package.
+ """
+
+ from torchvision import datasets, transforms
+
+ from plato.config import Config
+ from plato.datasources import base
+
+
+ class DataSource(base.DataSource):
+     """The EMNIST dataset."""
+
+     def __init__(self, **kwargs):
+         super().__init__()
+         _path = Config().params["data_path"]
+
+         train_transform = (
+             kwargs["train_transform"]
+             if "train_transform" in kwargs
+             else (
+                 transforms.Compose(
+                     [
+                         transforms.RandomHorizontalFlip(),
+                         transforms.RandomAffine(
+                             degrees=10, translate=(0.2, 0.2), scale=(0.8, 1.2)
+                         ),
+                         transforms.ToTensor(),
+                         transforms.Normalize(mean=[0.5], std=[0.5]),
+                     ]
+                 )
+             )
+         )
+
+         test_transform = (
+             kwargs["test_transform"]
+             if "test_transform" in kwargs
+             else (
+                 transforms.Compose(
+                     [transforms.ToTensor(), transforms.Normalize(mean=[0.5], std=[0.5])]
+                 )
+             )
+         )
+
+         self.trainset = datasets.EMNIST(
+             root=_path,
+             split="balanced",
+             train=True,
+             download=True,
+             transform=train_transform,
+         )
+         self.testset = datasets.EMNIST(
+             root=_path,
+             split="balanced",
+             train=False,
+             download=True,
+             transform=test_transform,
+         )
+
+     def num_train_examples(self):
+         return 112800
+
+     def num_test_examples(self):
+         return 18800
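To illustrate the train_transform/test_transform hooks, a hedged sketch; it assumes Plato's Config() singleton has already been initialized from a configuration file that defines data_path, since the constructor reads it.

    from torchvision import transforms

    from plato.datasources import emnist

    # The default augmentation pipeline is used when no kwargs are given.
    datasource = emnist.DataSource()

    # Any torchvision pipeline can be swapped in through the kwargs hooks.
    plain = transforms.Compose([transforms.ToTensor()])
    datasource = emnist.DataSource(train_transform=plain, test_transform=plain)

    print(datasource.num_train_examples())  # 112800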
plato/datasources/fashion_mnist.py
@@ -0,0 +1,41 @@
+ """
+ The FashionMNIST dataset from the torchvision package.
+ """
+
+ from torchvision import datasets, transforms
+
+ from plato.config import Config
+ from plato.datasources import base
+
+
+ class DataSource(base.DataSource):
+     """The FashionMNIST dataset."""
+
+     def __init__(self, **kwargs):
+         super().__init__()
+         _path = Config().params["data_path"]
+
+         train_transform = (
+             kwargs["train_transform"]
+             if "train_transform" in kwargs
+             else (
+                 transforms.Compose(
+                     [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
+                 )
+             )
+         )
+         # Honor an explicit test_transform; otherwise reuse the train transform
+         # (the original always discarded a test_transform kwarg).
+         test_transform = kwargs.get("test_transform", train_transform)
+
+         self.trainset = datasets.FashionMNIST(
+             root=_path, train=True, download=True, transform=train_transform
+         )
+
+         self.testset = datasets.FashionMNIST(
+             root=_path, train=False, download=True, transform=test_transform
+         )
+
+     def num_train_examples(self):
+         return 60000
+
+     def num_test_examples(self):
+         return 10000
plato/datasources/feature.py
@@ -0,0 +1,24 @@
+ """
+ The feature dataset received by the server from its clients.
+ """
+
+ from itertools import chain
+
+ from plato.datasources import base
+
+
+ class DataSource(base.DataSource):
+     """The feature dataset."""
+
+     def __init__(self, features, **kwargs):
+         super().__init__()
+
+         # Deep-flattening a list of lists with chain.from_iterable is faster
+         # than the equivalent list comprehension.
+         self.feature_dataset = list(chain.from_iterable(features))
+         self.trainset = self.feature_dataset
+         self.testset = []
+
+     def __len__(self):
+         return len(self.trainset)
+
+     def __getitem__(self, item):
+         return self.trainset[item]
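A toy illustration of the flattening step, with made-up per-client feature lists:

    from itertools import chain

    # Two clients, each reporting a list of (feature, label) pairs.
    features = [[("f1", 0), ("f2", 1)], [("f3", 0)]]
    print(list(chain.from_iterable(features)))
    # [('f1', 0), ('f2', 1), ('f3', 0)]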
plato/datasources/feature_dataset.py
@@ -0,0 +1,15 @@
+ import torch
+
+
+ class FeatureDataset(torch.utils.data.Dataset):
+     """Used to prepare a feature dataset for a DataLoader in PyTorch."""
+
+     def __init__(self, dataset):
+         self.dataset = dataset
+
+     def __len__(self):
+         return len(self.dataset)
+
+     def __getitem__(self, item):
+         image, label = self.dataset[item]
+         return image, label
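A minimal sketch of wrapping features in a PyTorch DataLoader; the (feature, label) pairs here are made up.

    import torch
    from torch.utils.data import DataLoader

    from plato.datasources.feature_dataset import FeatureDataset

    # Made-up (feature, label) pairs standing in for features sent by clients.
    pairs = [(torch.randn(64), torch.tensor(i % 2)) for i in range(8)]

    loader = DataLoader(FeatureDataset(pairs), batch_size=4, shuffle=True)
    for images, labels in loader:
        print(images.shape, labels.shape)  # torch.Size([4, 64]) torch.Size([4])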
plato/datasources/femnist.py
@@ -0,0 +1,141 @@
+ """
+ The Federated EMNIST dataset.
+
+ The Federated EMNIST dataset originates from the EMNIST dataset, which contains
+ 817,851 images, each of which is a 28x28 greyscale image belonging to one of 62
+ classes. The difference between the Federated EMNIST dataset and its original
+ counterpart is that this dataset is already partitioned by client ID, using the
+ data provider IDs included in the original EMNIST dataset. As a result of this
+ partitioning, there are 3,597 clients in total, each of which has 227.37 images
+ on average (with a standard deviation of 88.84). For each client, 90% of the
+ data samples are used for training, while the remaining samples are used for
+ testing.
+
+ Reference:
+
+ G. Cohen, S. Afshar, J. Tapson, and A. Van Schaik, "EMNIST: Extending MNIST to
+ handwritten letters," in the 2017 International Joint Conference on Neural
+ Networks (IJCNN).
+ """
+
+ import json
+ import logging
+ import os
+
+ import numpy as np
+ from torch.utils.data import Dataset
+ from torchvision import transforms
+
+ from plato.config import Config
+ from plato.datasources import base
+
+
+ class CustomDictDataset(Dataset):
+     """A custom dataset constructed from a dictionary, with support for transforms."""
+
+     def __init__(self, loaded_data, transform=None):
+         """Initializes the custom dataset."""
+         super().__init__()
+         self.loaded_data = loaded_data
+         self.transform = transform
+
+     def __getitem__(self, index):
+         sample = self.loaded_data["x"][index]
+         target = self.loaded_data["y"][index]
+
+         if self.transform:
+             sample = self.transform(sample)
+
+         return sample, target
+
+     def __len__(self):
+         return len(self.loaded_data["y"])
+
+
+ class ReshapeListTransform:
+     """The transform that reshapes an image."""
+
+     def __init__(self, new_shape):
+         self.new_shape = new_shape
+
+     def __call__(self, img):
+         return np.array(img, dtype=np.float32).reshape(self.new_shape)
+
+
+ class DataSource(base.DataSource):
+     """The FEMNIST dataset."""
+
+     def __init__(self, client_id=0, **kwargs):
+         super().__init__()
+         self.trainset = None
+         self.testset = None
+
+         root_path = os.path.join(
+             Config().params["data_path"], "FEMNIST", "packaged_data"
+         )
+         if client_id == 0:
+             # If we are on the federated learning server
+             data_dir = os.path.join(root_path, "test")
+             data_url = (
+                 "http://iqua.ece.toronto.edu/baochun/FEMNIST/test/"
+                 + str(client_id)
+                 + ".zip"
+             )
+         else:
+             data_dir = os.path.join(root_path, "train")
+             data_url = (
+                 "http://iqua.ece.toronto.edu/baochun/FEMNIST/train/"
+                 + str(client_id)
+                 + ".zip"
+             )
+
+         if not os.path.exists(os.path.join(data_dir, str(client_id))):
+             logging.info(
+                 "Downloading the Federated EMNIST dataset "
+                 "with the client datasets pre-partitioned. This may take a while.",
+             )
+             self.download(url=data_url, data_path=data_dir)
+
+         loaded_data = DataSource.read_data(
+             file_path=os.path.join(data_dir, str(client_id), "data.json")
+         )
+
+         train_transform = (
+             kwargs["train_transform"]
+             if "train_transform" in kwargs
+             else (
+                 transforms.Compose(
+                     [
+                         ReshapeListTransform((28, 28, 1)),
+                         transforms.ToPILImage(),
+                         transforms.RandomCrop(
+                             28, padding=2, padding_mode="constant", fill=1.0
+                         ),
+                         transforms.RandomResizedCrop(
+                             28, scale=(0.8, 1.2), ratio=(4.0 / 5.0, 5.0 / 4.0)
+                         ),
+                         transforms.RandomRotation(5, fill=1.0),
+                         transforms.ToTensor(),
+                         transforms.Normalize(0.9637, 0.1597),
+                     ]
+                 )
+             )
+         )
+
+         dataset = CustomDictDataset(loaded_data=loaded_data, transform=train_transform)
+
+         self.testset = dataset
+         self.trainset = dataset
+
+     @staticmethod
+     def read_data(file_path):
+         """Reads the dataset specific to a client_id."""
+         with open(file_path, "r", encoding="utf-8") as fin:
+             loaded_data = json.load(fin)
+         return loaded_data
+
+     def num_train_examples(self):
+         return len(self.trainset)
+
+     def num_test_examples(self):
+         return len(self.testset)
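To make the data flow concrete, a small sketch of CustomDictDataset with a made-up dictionary in the same {"x": ..., "y": ...} layout as the downloaded data.json; no download or initialized Config() is needed for this.

    import numpy as np
    from torchvision import transforms

    from plato.datasources.femnist import CustomDictDataset, ReshapeListTransform

    # Made-up data in the data.json layout: flat pixel lists plus labels.
    loaded_data = {
        "x": [np.random.rand(784).tolist() for _ in range(4)],
        "y": [0, 1, 2, 3],
    }

    transform = transforms.Compose(
        [ReshapeListTransform((28, 28, 1)), transforms.ToTensor()]
    )
    dataset = CustomDictDataset(loaded_data=loaded_data, transform=transform)
    sample, target = dataset[0]
    print(sample.shape, target)  # torch.Size([1, 28, 28]) 0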