plato-learn 1.1 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- plato/__init__.py +1 -0
- plato/algorithms/__init__.py +0 -0
- plato/algorithms/base.py +45 -0
- plato/algorithms/fedavg.py +48 -0
- plato/algorithms/fedavg_gan.py +79 -0
- plato/algorithms/fedavg_personalized.py +48 -0
- plato/algorithms/mistnet.py +52 -0
- plato/algorithms/registry.py +39 -0
- plato/algorithms/split_learning.py +89 -0
- plato/callbacks/__init__.py +0 -0
- plato/callbacks/client.py +56 -0
- plato/callbacks/handler.py +78 -0
- plato/callbacks/server.py +139 -0
- plato/callbacks/trainer.py +124 -0
- plato/client.py +67 -0
- plato/clients/__init__.py +0 -0
- plato/clients/base.py +467 -0
- plato/clients/edge.py +103 -0
- plato/clients/fedavg_personalized.py +40 -0
- plato/clients/mistnet.py +49 -0
- plato/clients/registry.py +43 -0
- plato/clients/self_supervised_learning.py +51 -0
- plato/clients/simple.py +218 -0
- plato/clients/split_learning.py +150 -0
- plato/config.py +339 -0
- plato/datasources/__init__.py +0 -0
- plato/datasources/base.py +123 -0
- plato/datasources/celeba.py +150 -0
- plato/datasources/cifar10.py +87 -0
- plato/datasources/cifar100.py +61 -0
- plato/datasources/cinic10.py +62 -0
- plato/datasources/coco.py +119 -0
- plato/datasources/datalib/__init__.py +0 -0
- plato/datasources/datalib/audio_extraction_tools.py +137 -0
- plato/datasources/datalib/data_utils.py +124 -0
- plato/datasources/datalib/flickr30kE_utils.py +336 -0
- plato/datasources/datalib/frames_extraction_tools.py +254 -0
- plato/datasources/datalib/gym_utils/__init__.py +0 -0
- plato/datasources/datalib/gym_utils/gym_trim.py +189 -0
- plato/datasources/datalib/modality_data_anntation_tools.py +163 -0
- plato/datasources/datalib/modality_extraction_base.py +59 -0
- plato/datasources/datalib/parse_datasets.py +212 -0
- plato/datasources/datalib/refer_utils/__init__.py +0 -0
- plato/datasources/datalib/refer_utils/referitgame_utils.py +237 -0
- plato/datasources/datalib/tiny_data_tools.py +81 -0
- plato/datasources/datalib/video_transform.py +79 -0
- plato/datasources/emnist.py +64 -0
- plato/datasources/fashion_mnist.py +41 -0
- plato/datasources/feature.py +24 -0
- plato/datasources/feature_dataset.py +15 -0
- plato/datasources/femnist.py +141 -0
- plato/datasources/flickr30k_entities.py +362 -0
- plato/datasources/gym.py +431 -0
- plato/datasources/huggingface.py +165 -0
- plato/datasources/kinetics.py +568 -0
- plato/datasources/mnist.py +44 -0
- plato/datasources/multimodal_base.py +328 -0
- plato/datasources/pascal_voc.py +56 -0
- plato/datasources/purchase.py +94 -0
- plato/datasources/qoenflx.py +127 -0
- plato/datasources/referitgame.py +330 -0
- plato/datasources/registry.py +119 -0
- plato/datasources/self_supervised_learning.py +98 -0
- plato/datasources/stl10.py +103 -0
- plato/datasources/texas.py +94 -0
- plato/datasources/tiny_imagenet.py +64 -0
- plato/datasources/yolov8.py +85 -0
- plato/models/__init__.py +0 -0
- plato/models/cnn_encoder.py +103 -0
- plato/models/dcgan.py +116 -0
- plato/models/general_multilayer.py +254 -0
- plato/models/huggingface.py +27 -0
- plato/models/lenet5.py +113 -0
- plato/models/multilayer.py +90 -0
- plato/models/multimodal/__init__.py +0 -0
- plato/models/multimodal/base_net.py +91 -0
- plato/models/multimodal/blending.py +142 -0
- plato/models/multimodal/fc_net.py +77 -0
- plato/models/multimodal/fusion_net.py +78 -0
- plato/models/multimodal/multimodal_module.py +152 -0
- plato/models/registry.py +99 -0
- plato/models/resnet.py +190 -0
- plato/models/torch_hub.py +19 -0
- plato/models/vgg.py +113 -0
- plato/models/vit.py +166 -0
- plato/models/yolov8.py +22 -0
- plato/processors/__init__.py +0 -0
- plato/processors/base.py +35 -0
- plato/processors/compress.py +46 -0
- plato/processors/decompress.py +48 -0
- plato/processors/feature.py +51 -0
- plato/processors/feature_additive_noise.py +48 -0
- plato/processors/feature_dequantize.py +34 -0
- plato/processors/feature_gaussian.py +17 -0
- plato/processors/feature_laplace.py +15 -0
- plato/processors/feature_quantize.py +34 -0
- plato/processors/feature_randomized_response.py +50 -0
- plato/processors/feature_unbatch.py +39 -0
- plato/processors/inbound_feature_tensors.py +39 -0
- plato/processors/model.py +55 -0
- plato/processors/model_compress.py +34 -0
- plato/processors/model_decompress.py +37 -0
- plato/processors/model_decrypt.py +41 -0
- plato/processors/model_deepcopy.py +21 -0
- plato/processors/model_dequantize.py +18 -0
- plato/processors/model_dequantize_qsgd.py +61 -0
- plato/processors/model_encrypt.py +43 -0
- plato/processors/model_quantize.py +18 -0
- plato/processors/model_quantize_qsgd.py +82 -0
- plato/processors/model_randomized_response.py +34 -0
- plato/processors/outbound_feature_ndarrays.py +38 -0
- plato/processors/pipeline.py +26 -0
- plato/processors/registry.py +124 -0
- plato/processors/structured_pruning.py +57 -0
- plato/processors/unstructured_pruning.py +73 -0
- plato/samplers/__init__.py +0 -0
- plato/samplers/all_inclusive.py +41 -0
- plato/samplers/base.py +31 -0
- plato/samplers/dirichlet.py +81 -0
- plato/samplers/distribution_noniid.py +132 -0
- plato/samplers/iid.py +53 -0
- plato/samplers/label_quantity_noniid.py +119 -0
- plato/samplers/mixed.py +44 -0
- plato/samplers/mixed_label_quantity_noniid.py +128 -0
- plato/samplers/modality_iid.py +42 -0
- plato/samplers/modality_quantity_noniid.py +56 -0
- plato/samplers/orthogonal.py +99 -0
- plato/samplers/registry.py +66 -0
- plato/samplers/sample_quantity_noniid.py +123 -0
- plato/samplers/sampler_utils.py +190 -0
- plato/servers/__init__.py +0 -0
- plato/servers/base.py +1395 -0
- plato/servers/fedavg.py +281 -0
- plato/servers/fedavg_cs.py +335 -0
- plato/servers/fedavg_gan.py +74 -0
- plato/servers/fedavg_he.py +106 -0
- plato/servers/fedavg_personalized.py +57 -0
- plato/servers/mistnet.py +67 -0
- plato/servers/registry.py +52 -0
- plato/servers/split_learning.py +109 -0
- plato/trainers/__init__.py +0 -0
- plato/trainers/base.py +99 -0
- plato/trainers/basic.py +649 -0
- plato/trainers/diff_privacy.py +178 -0
- plato/trainers/gan.py +330 -0
- plato/trainers/huggingface.py +173 -0
- plato/trainers/loss_criterion.py +70 -0
- plato/trainers/lr_schedulers.py +252 -0
- plato/trainers/optimizers.py +53 -0
- plato/trainers/pascal_voc.py +80 -0
- plato/trainers/registry.py +44 -0
- plato/trainers/self_supervised_learning.py +302 -0
- plato/trainers/split_learning.py +305 -0
- plato/trainers/tracking.py +96 -0
- plato/trainers/yolov8.py +41 -0
- plato/utils/__init__.py +0 -0
- plato/utils/count_parameters.py +30 -0
- plato/utils/csv_processor.py +26 -0
- plato/utils/data_loaders.py +148 -0
- plato/utils/decorators.py +24 -0
- plato/utils/fonts.py +23 -0
- plato/utils/homo_enc.py +187 -0
- plato/utils/reinforcement_learning/__init__.py +0 -0
- plato/utils/reinforcement_learning/policies/__init__.py +0 -0
- plato/utils/reinforcement_learning/policies/base.py +161 -0
- plato/utils/reinforcement_learning/policies/ddpg.py +75 -0
- plato/utils/reinforcement_learning/policies/registry.py +32 -0
- plato/utils/reinforcement_learning/policies/sac.py +343 -0
- plato/utils/reinforcement_learning/policies/td3.py +485 -0
- plato/utils/reinforcement_learning/rl_agent.py +142 -0
- plato/utils/reinforcement_learning/rl_server.py +113 -0
- plato/utils/rl_env.py +154 -0
- plato/utils/s3.py +141 -0
- plato/utils/trainer_utils.py +21 -0
- plato/utils/unary_encoding.py +47 -0
- plato_learn-1.1.dist-info/METADATA +35 -0
- plato_learn-1.1.dist-info/RECORD +179 -0
- plato_learn-1.1.dist-info/WHEEL +4 -0
- plato_learn-1.1.dist-info/licenses/LICENSE +201 -0
plato/datasources/datalib/refer_utils/referitgame_utils.py
@@ -0,0 +1,237 @@
import sys
import os.path as osp
import json
import logging
import pickle
import time
import itertools
import skimage.io as io

from pprint import pprint
import numpy as np


class REFER:
    def __init__(
        self,
        data_root,  # the root where the data is stored
        image_dataroot,  # the path of the source images
        dataset="refcoco",
        splitBy="unc",
    ):
        # provide the data_root folder, which contains refclef, refcoco, refcoco+ and refcocog
        # also provide the dataset name and splitBy information
        # e.g., dataset = 'refcoco', splitBy = 'unc'
        print("loading dataset %s into memory..." % dataset)

        self.ROOT_DIR = osp.abspath(osp.dirname(__file__))
        self.DATA_DIR = osp.join(data_root, dataset)
        if dataset in ["refcoco", "refcoco+", "refcocog"]:
            self.IMAGE_DIR = image_dataroot
        else:
            print("No refer dataset is called [%s]" % dataset)
            sys.exit()

        # load refs from data/dataset/refs(splitBy).p
        tic = time.perf_counter()
        ref_file = osp.join(self.DATA_DIR, "refs(" + splitBy + ").p")
        self.data = {}
        self.data["dataset"] = dataset
        self.data["refs"] = pickle.load(open(ref_file, "rb"))

        # load annotations from data/dataset/instances.json
        instances_file = osp.join(self.DATA_DIR, "instances.json")
        instances = json.load(open(instances_file, "r"))
        self.data["images"] = instances["images"]
        self.data["annotations"] = instances["annotations"]
        self.data["categories"] = instances["categories"]

        # create index
        self.createIndex()
        print("DONE (t=%.2fs)" % (time.perf_counter() - tic))

    def createIndex(self):
        # create sets of mapping
        # 1)  Refs:         {ref_id: ref}
        # 2)  Anns:         {ann_id: ann}
        # 3)  Imgs:         {image_id: image}
        # 4)  Cats:         {category_id: category_name}
        # 5)  Sents:        {sent_id: sent}
        # 6)  imgToRefs:    {image_id: refs}
        # 7)  imgToAnns:    {image_id: anns}
        # 8)  refToAnn:     {ref_id: ann}
        # 9)  annToRef:     {ann_id: ref}
        # 10) catToRefs:    {category_id: refs}
        # 11) sentToRef:    {sent_id: ref}
        # 12) sentToTokens: {sent_id: tokens}
        print("creating index...")
        # fetch info from instances
        Anns, Imgs, Cats, imgToAnns = {}, {}, {}, {}
        for ann in self.data["annotations"]:
            Anns[ann["id"]] = ann
            imgToAnns[ann["image_id"]] = imgToAnns.get(ann["image_id"], []) + [ann]
        for img in self.data["images"]:
            img["file_name"] = img["file_name"].strip().split("_")[-1]
            Imgs[img["id"]] = img
        for cat in self.data["categories"]:
            Cats[cat["id"]] = cat["name"]

        # fetch info from refs
        Refs, imgToRefs, refToAnn, annToRef, catToRefs = {}, {}, {}, {}, {}
        Sents, sentToRef, sentToTokens = {}, {}, {}
        for ref in self.data["refs"]:
            # ids
            ref_id = ref["ref_id"]
            ann_id = ref["ann_id"]
            category_id = ref["category_id"]
            image_id = ref["image_id"]

            # add mapping related to ref
            Refs[ref_id] = ref
            imgToRefs[image_id] = imgToRefs.get(image_id, []) + [ref]
            catToRefs[category_id] = catToRefs.get(category_id, []) + [ref]
            refToAnn[ref_id] = Anns[ann_id]
            annToRef[ann_id] = ref

            # add mapping of sent
            for sent in ref["sentences"]:
                Sents[sent["sent_id"]] = sent
                sentToRef[sent["sent_id"]] = ref
                sentToTokens[sent["sent_id"]] = sent["tokens"]

        # create class members
        self.Refs = Refs
        self.Anns = Anns
        self.Imgs = Imgs
        self.Cats = Cats
        self.Sents = Sents
        self.imgToRefs = imgToRefs
        self.imgToAnns = imgToAnns
        self.refToAnn = refToAnn
        self.annToRef = annToRef
        self.catToRefs = catToRefs
        self.sentToRef = sentToRef
        self.sentToTokens = sentToTokens
        print("index created.")

    def getRefIds(self, image_ids=[], cat_ids=[], ref_ids=[], split=""):
        image_ids = image_ids if type(image_ids) == list else [image_ids]
        cat_ids = cat_ids if type(cat_ids) == list else [cat_ids]
        ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]

        if len(image_ids) == len(cat_ids) == len(ref_ids) == len(split) == 0:
            refs = self.data["refs"]
        else:
            if not len(image_ids) == 0:
                # flatten the per-image lists of refs into a single list of refs
                lists = [
                    self.imgToRefs[image_id]
                    for image_id in image_ids
                    if image_id in self.imgToRefs
                ]
                refs = list(itertools.chain.from_iterable(lists))
            else:
                refs = self.data["refs"]
            if not len(cat_ids) == 0:
                refs = [ref for ref in refs if ref["category_id"] in cat_ids]
            if not len(ref_ids) == 0:
                refs = [ref for ref in refs if ref["ref_id"] in ref_ids]
            if not len(split) == 0:
                if split in ["testA", "testB", "testC"]:
                    refs = [
                        ref for ref in refs if split[-1] in ref["split"]
                    ]  # we also consider testAB, testBC, ...
                elif split in ["testAB", "testBC", "testAC"]:
                    refs = [
                        ref for ref in refs if ref["split"] == split
                    ]  # rarely used
                elif split == "test":
                    refs = [ref for ref in refs if "test" in ref["split"]]
                elif split == "train" or split == "val":
                    refs = [ref for ref in refs if ref["split"] == split]
                else:
                    print("No such split [%s]" % split)
                    sys.exit()
        ref_ids = [ref["ref_id"] for ref in refs]
        return ref_ids

    def getAnnIds(self, image_ids=[], cat_ids=[], ref_ids=[]):
        image_ids = image_ids if type(image_ids) == list else [image_ids]
        cat_ids = cat_ids if type(cat_ids) == list else [cat_ids]
        ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]

        if len(image_ids) == len(cat_ids) == len(ref_ids) == 0:
            ann_ids = [ann["id"] for ann in self.data["annotations"]]
        else:
            if not len(image_ids) == 0:
                lists = [
                    self.imgToAnns[image_id]
                    for image_id in image_ids
                    if image_id in self.imgToAnns
                ]  # list of [anns]
                anns = list(itertools.chain.from_iterable(lists))
            else:
                anns = self.data["annotations"]
            if not len(cat_ids) == 0:
                anns = [ann for ann in anns if ann["category_id"] in cat_ids]
            ann_ids = [ann["id"] for ann in anns]
            if not len(ref_ids) == 0:
                # keep only the annotations referenced by the given ref_ids
                ids = set(ann_ids).intersection(
                    set([self.Refs[ref_id]["ann_id"] for ref_id in ref_ids])
                )
                ann_ids = list(ids)
        return ann_ids

    def getImgIds(self, ref_ids=[]):
        ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]

        if not len(ref_ids) == 0:
            image_ids = list(set([self.Refs[ref_id]["image_id"] for ref_id in ref_ids]))
        else:
            image_ids = self.Imgs.keys()
        return image_ids

    def getCatIds(self):
        return self.Cats.keys()

    def loadRefs(self, ref_ids=[]):
        if type(ref_ids) == list:
            return [self.Refs[ref_id] for ref_id in ref_ids]
        elif type(ref_ids) == int:
            return [self.Refs[ref_ids]]

    def loadAnns(self, ann_ids=[]):
        if type(ann_ids) == list:
            return [self.Anns[ann_id] for ann_id in ann_ids]
        elif type(ann_ids) == int or type(ann_ids) == str:  # str replaces Python 2's unicode
            return [self.Anns[ann_ids]]

    def loadImgs(self, image_ids=[]):
        if type(image_ids) == list:
            return [self.Imgs[image_id] for image_id in image_ids]
        elif type(image_ids) == int:
            return [self.Imgs[image_ids]]

    def loadImgspath(self, image_ids=[]):
        if type(image_ids) == list:
            return [
                osp.join(self.IMAGE_DIR, self.Imgs[image_id]["file_name"])
                for image_id in image_ids
            ]
        elif type(image_ids) == int:
            return [osp.join(self.IMAGE_DIR, self.Imgs[image_ids]["file_name"])]

    def loadImgsData(self, image_ids=[]):  # the image_ids are obtained from refs
        if type(image_ids) == list:
            return [
                io.imread(osp.join(self.IMAGE_DIR, self.Imgs[image_id]["file_name"]))
                for image_id in image_ids
            ]
        elif type(image_ids) == int:
            return [
                io.imread(osp.join(self.IMAGE_DIR, self.Imgs[image_ids]["file_name"]))
            ]

    def loadCats(self, cat_ids=[]):
        if type(cat_ids) == list:
            return [self.Cats[cat_id] for cat_id in cat_ids]
        elif type(cat_ids) == int:
            return [self.Cats[cat_ids]]

    def getRefBox(self, ref_id):
        ann = self.refToAnn[ref_id]
        return ann["bbox"]  # [x, y, w, h]
plato/datasources/datalib/tiny_data_tools.py
@@ -0,0 +1,81 @@
"""
This includes tools for creating and processing the tiny dataset.

The tiny dataset is sampled from the whole dataset.
"""

import os
import collections
import logging

import pandas as pd
import numpy as np


def create_tiny_kinetics_anno(kinetics_annotation_files_info, num_samples, random_seed):
    """Creates the annotation files for tiny Kinetics.

    Args:
        kinetics_annotation_files_info (dict): a dict that contains the annotation
            files for different splits, e.g., {"train": xxx, "test": xxx, "val": xxx}.
        num_samples (int): the number of samples utilized to create the
            tiny dataset.
        random_seed (int): the random seed for sampling.
    """
    train_anno_file_path = kinetics_annotation_files_info["train"]
    test_anno_file_path = kinetics_annotation_files_info["test"]
    val_anno_file_path = kinetics_annotation_files_info["val"]
    np.random.seed(random_seed)

    train_anno_df = pd.read_csv(train_anno_file_path)
    train_selected_samples_df = train_anno_df.iloc[:num_samples]

    train_selected_classes = train_selected_samples_df["label"].tolist()

    # select from the val anno file based on the train classes
    def select_anchored_samples(src_anno_df, anchor_classes):
        selected_df = None
        counted_classes = collections.Counter(anchor_classes)
        for cls in list(counted_classes.keys()):
            num_cls = counted_classes[cls]
            cls_anno_df = src_anno_df[src_anno_df["label"] == cls]
            len_cls_df = len(cls_anno_df)
            selected_samples_idx = np.random.choice(
                list(range(len_cls_df)), size=num_cls
            )
            cls_selected_df = cls_anno_df.iloc[selected_samples_idx]
            if selected_df is None:
                selected_df = cls_selected_df
            else:
                selected_df = pd.concat([selected_df, cls_selected_df])
        return selected_df

    def save_anno_df(anno_file_path, selected_anno_df):
        anno_file_base_dir = os.path.dirname(anno_file_path)
        anno_file_base_name, ext_type = os.path.basename(anno_file_path).split(".")
        tiny_anno_file_path = os.path.join(
            anno_file_base_dir, anno_file_base_name + "_tiny." + ext_type
        )

        if os.path.exists(tiny_anno_file_path):
            logging.info(
                "Annotation file for tiny data exists: %s, using it directly",
                tiny_anno_file_path,
            )
        else:
            selected_anno_df.to_csv(path_or_buf=tiny_anno_file_path, index=False)

    test_anno_df = pd.read_csv(test_anno_file_path)
    test_selected_df = test_anno_df.iloc[:num_samples]

    val_anno_df = pd.read_csv(val_anno_file_path)
    val_selected_df = select_anchored_samples(
        src_anno_df=val_anno_df, anchor_classes=train_selected_classes
    )

    save_anno_df(
        anno_file_path=train_anno_file_path, selected_anno_df=train_selected_samples_df
    )
    save_anno_df(anno_file_path=test_anno_file_path, selected_anno_df=test_selected_df)
    save_anno_df(anno_file_path=val_anno_file_path, selected_anno_df=val_selected_df)
plato/datasources/datalib/video_transform.py
@@ -0,0 +1,79 @@
"""
The transformers for the video classification task.
"""

import torch
import torch.nn as nn
from torchvision.transforms import transforms


class ConvertBHWCtoBCHW(nn.Module):
    """Convert tensor from (B, H, W, C) to (B, C, H, W)"""

    def forward(self, vid: torch.Tensor) -> torch.Tensor:
        """Change the order of tensor dims"""
        return vid.permute(0, 3, 1, 2)


class ConvertBCHWtoCBHW(nn.Module):
    """Convert tensor from (B, C, H, W) to (C, B, H, W)"""

    def forward(self, vid: torch.Tensor) -> torch.Tensor:
        """Change the order of tensor dims"""
        return vid.permute(1, 0, 2, 3)


class VideoClassificationTrainTransformer:
    """The transformer to process the data for video classification model training"""

    def __init__(
        self,
        resize_size,
        crop_size,
        mean=(0.43216, 0.394666, 0.37645),
        std=(0.22803, 0.22145, 0.216989),
        hflip_prob=0.5,
    ):
        trans = [
            ConvertBHWCtoBCHW(),
            transforms.ConvertImageDtype(torch.float32),
            transforms.Resize(resize_size),
        ]
        if hflip_prob > 0:
            trans.append(transforms.RandomHorizontalFlip(hflip_prob))
        trans.extend(
            [
                transforms.Normalize(mean=mean, std=std),
                transforms.RandomCrop(crop_size),
                ConvertBCHWtoCBHW(),
            ]
        )
        self.transforms = transforms.Compose(trans)

    def __call__(self, video_data):
        return self.transforms(video_data)


class VideoClassificationEvalTransformer:
    """The transformer to process the data for video classification model evaluation"""

    def __init__(
        self,
        resize_size,
        crop_size,
        mean=(0.43216, 0.394666, 0.37645),
        std=(0.22803, 0.22145, 0.216989),
    ):
        self.transforms = transforms.Compose(
            [
                ConvertBHWCtoBCHW(),
                transforms.ConvertImageDtype(torch.float32),
                transforms.Resize(resize_size),
                transforms.Normalize(mean=mean, std=std),
                transforms.CenterCrop(crop_size),
                ConvertBCHWtoCBHW(),
            ]
        )

    def __call__(self, video_data):
        return self.transforms(video_data)
plato/datasources/emnist.py
@@ -0,0 +1,64 @@
"""
The Extended MNIST (EMNIST) dataset from the torchvision package.
"""

from torchvision import datasets, transforms

from plato.config import Config
from plato.datasources import base


class DataSource(base.DataSource):
    """The EMNIST dataset."""

    def __init__(self, **kwargs):
        super().__init__()
        _path = Config().params["data_path"]

        train_transform = (
            kwargs["train_transform"]
            if "train_transform" in kwargs
            else (
                transforms.Compose(
                    [
                        transforms.RandomHorizontalFlip(),
                        transforms.RandomAffine(
                            degrees=10, translate=(0.2, 0.2), scale=(0.8, 1.2)
                        ),
                        transforms.ToTensor(),
                        transforms.Normalize(mean=[0.5], std=[0.5]),
                    ]
                )
            )
        )

        test_transform = (
            kwargs["test_transform"]
            if "test_transform" in kwargs
            else (
                transforms.Compose(
                    [transforms.ToTensor(), transforms.Normalize(mean=[0.5], std=[0.5])]
                )
            )
        )

        self.trainset = datasets.EMNIST(
            root=_path,
            split="balanced",
            train=True,
            download=True,
            transform=train_transform,
        )
        self.testset = datasets.EMNIST(
            root=_path,
            split="balanced",
            train=False,
            download=True,
            transform=test_transform,
        )

    def num_train_examples(self):
        return 112800

    def num_test_examples(self):
        return 18800
plato/datasources/fashion_mnist.py
@@ -0,0 +1,41 @@
"""
The FashionMNIST dataset from the torchvision package.
"""

from torchvision import datasets, transforms

from plato.config import Config
from plato.datasources import base


class DataSource(base.DataSource):
    """The FashionMNIST dataset."""

    def __init__(self, **kwargs):
        super().__init__()
        _path = Config().params["data_path"]

        train_transform = (
            kwargs["train_transform"]
            if "train_transform" in kwargs
            else (
                transforms.Compose(
                    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
                )
            )
        )
        test_transform = train_transform

        self.trainset = datasets.FashionMNIST(
            root=_path, train=True, download=True, transform=train_transform
        )

        self.testset = datasets.FashionMNIST(
            root=_path, train=False, download=True, transform=test_transform
        )

    def num_train_examples(self):
        return 60000

    def num_test_examples(self):
        return 10000
plato/datasources/feature.py
@@ -0,0 +1,24 @@
"""
The feature dataset that the server received from clients.
"""

from itertools import chain
from plato.datasources import base


class DataSource(base.DataSource):
    """The feature dataset."""

    def __init__(self, features, **kwargs):
        super().__init__()

        # Flattening a list of lists with chain.from_iterable is faster
        # than a nested list comprehension
        self.feature_dataset = list(chain.from_iterable(features))
        self.trainset = self.feature_dataset
        self.testset = []

    def __len__(self):
        return len(self.trainset)

    def __getitem__(self, item):
        return self.trainset[item]
plato/datasources/feature_dataset.py
@@ -0,0 +1,15 @@
import torch


class FeatureDataset(torch.utils.data.Dataset):
    """Used to prepare a feature dataset for a DataLoader in PyTorch."""

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, item):
        image, label = self.dataset[item]
        return image, label
plato/datasources/femnist.py
@@ -0,0 +1,141 @@
"""
The Federated EMNIST dataset.

The Federated EMNIST dataset originates from the EMNIST dataset, which contains
817,851 images, each of which is a 28x28 greyscale image in one of 62 classes.
The difference between the Federated EMNIST dataset and its original counterpart
is that this dataset is already partitioned by the client ID, using the data
provider IDs included in the original EMNIST dataset. As a result of this
partitioning, there are 3597 clients in total, each of which has 227.37 images
on average (std is 88.84). For each client, 90% of the data samples are used for
training, while the remaining samples are used for testing.

Reference:

G. Cohen, S. Afshar, J. Tapson, and A. Van Schaik, "EMNIST: Extending MNIST to
handwritten letters," in the 2017 International Joint Conference on Neural
Networks (IJCNN).
"""

import json
import logging
import os

import numpy as np
from torch.utils.data import Dataset
from torchvision import transforms

from plato.config import Config
from plato.datasources import base


class CustomDictDataset(Dataset):
    """Custom dataset from a dictionary with support of transforms."""

    def __init__(self, loaded_data, transform=None):
        """Initializing the custom dataset."""
        super().__init__()
        self.loaded_data = loaded_data
        self.transform = transform

    def __getitem__(self, index):
        sample = self.loaded_data["x"][index]
        target = self.loaded_data["y"][index]

        if self.transform:
            sample = self.transform(sample)

        return sample, target

    def __len__(self):
        return len(self.loaded_data["y"])


class ReshapeListTransform:
    """The transform that reshapes an image."""

    def __init__(self, new_shape):
        self.new_shape = new_shape

    def __call__(self, img):
        return np.array(img, dtype=np.float32).reshape(self.new_shape)


class DataSource(base.DataSource):
    """The FEMNIST dataset."""

    def __init__(self, client_id=0, **kwargs):
        super().__init__()
        self.trainset = None
        self.testset = None

        root_path = os.path.join(
            Config().params["data_path"], "FEMNIST", "packaged_data"
        )
        if client_id == 0:
            # If we are on the federated learning server
            data_dir = os.path.join(root_path, "test")
            data_url = (
                "http://iqua.ece.toronto.edu/baochun/FEMNIST/test/"
                + str(client_id)
                + ".zip"
            )
        else:
            data_dir = os.path.join(root_path, "train")
            data_url = (
                "http://iqua.ece.toronto.edu/baochun/FEMNIST/train/"
                + str(client_id)
                + ".zip"
            )

        if not os.path.exists(os.path.join(data_dir, str(client_id))):
            logging.info(
                "Downloading the Federated EMNIST dataset "
                "with the client datasets pre-partitioned. This may take a while.",
            )
            self.download(url=data_url, data_path=data_dir)

        loaded_data = DataSource.read_data(
            file_path=os.path.join(data_dir, str(client_id), "data.json")
        )

        train_transform = (
            kwargs["train_transform"]
            if "train_transform" in kwargs
            else (
                transforms.Compose(
                    [
                        ReshapeListTransform((28, 28, 1)),
                        transforms.ToPILImage(),
                        transforms.RandomCrop(
                            28, padding=2, padding_mode="constant", fill=1.0
                        ),
                        transforms.RandomResizedCrop(
                            28, scale=(0.8, 1.2), ratio=(4.0 / 5.0, 5.0 / 4.0)
                        ),
                        transforms.RandomRotation(5, fill=1.0),
                        transforms.ToTensor(),
                        transforms.Normalize(0.9637, 0.1597),
                    ]
                )
            )
        )

        dataset = CustomDictDataset(loaded_data=loaded_data, transform=train_transform)

        self.testset = dataset
        self.trainset = dataset

    @staticmethod
    def read_data(file_path):
        """Reading the dataset specific to a client_id."""
        with open(file_path, "r", encoding="utf-8") as fin:
            loaded_data = json.load(fin)
        return loaded_data

    def num_train_examples(self):
        return len(self.trainset)

    def num_test_examples(self):
        return len(self.testset)