plato_learn-1.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- plato/__init__.py +1 -0
- plato/algorithms/__init__.py +0 -0
- plato/algorithms/base.py +45 -0
- plato/algorithms/fedavg.py +48 -0
- plato/algorithms/fedavg_gan.py +79 -0
- plato/algorithms/fedavg_personalized.py +48 -0
- plato/algorithms/mistnet.py +52 -0
- plato/algorithms/registry.py +39 -0
- plato/algorithms/split_learning.py +89 -0
- plato/callbacks/__init__.py +0 -0
- plato/callbacks/client.py +56 -0
- plato/callbacks/handler.py +78 -0
- plato/callbacks/server.py +139 -0
- plato/callbacks/trainer.py +124 -0
- plato/client.py +67 -0
- plato/clients/__init__.py +0 -0
- plato/clients/base.py +467 -0
- plato/clients/edge.py +103 -0
- plato/clients/fedavg_personalized.py +40 -0
- plato/clients/mistnet.py +49 -0
- plato/clients/registry.py +43 -0
- plato/clients/self_supervised_learning.py +51 -0
- plato/clients/simple.py +218 -0
- plato/clients/split_learning.py +150 -0
- plato/config.py +339 -0
- plato/datasources/__init__.py +0 -0
- plato/datasources/base.py +123 -0
- plato/datasources/celeba.py +150 -0
- plato/datasources/cifar10.py +87 -0
- plato/datasources/cifar100.py +61 -0
- plato/datasources/cinic10.py +62 -0
- plato/datasources/coco.py +119 -0
- plato/datasources/datalib/__init__.py +0 -0
- plato/datasources/datalib/audio_extraction_tools.py +137 -0
- plato/datasources/datalib/data_utils.py +124 -0
- plato/datasources/datalib/flickr30kE_utils.py +336 -0
- plato/datasources/datalib/frames_extraction_tools.py +254 -0
- plato/datasources/datalib/gym_utils/__init__.py +0 -0
- plato/datasources/datalib/gym_utils/gym_trim.py +189 -0
- plato/datasources/datalib/modality_data_anntation_tools.py +163 -0
- plato/datasources/datalib/modality_extraction_base.py +59 -0
- plato/datasources/datalib/parse_datasets.py +212 -0
- plato/datasources/datalib/refer_utils/__init__.py +0 -0
- plato/datasources/datalib/refer_utils/referitgame_utils.py +237 -0
- plato/datasources/datalib/tiny_data_tools.py +81 -0
- plato/datasources/datalib/video_transform.py +79 -0
- plato/datasources/emnist.py +64 -0
- plato/datasources/fashion_mnist.py +41 -0
- plato/datasources/feature.py +24 -0
- plato/datasources/feature_dataset.py +15 -0
- plato/datasources/femnist.py +141 -0
- plato/datasources/flickr30k_entities.py +362 -0
- plato/datasources/gym.py +431 -0
- plato/datasources/huggingface.py +165 -0
- plato/datasources/kinetics.py +568 -0
- plato/datasources/mnist.py +44 -0
- plato/datasources/multimodal_base.py +328 -0
- plato/datasources/pascal_voc.py +56 -0
- plato/datasources/purchase.py +94 -0
- plato/datasources/qoenflx.py +127 -0
- plato/datasources/referitgame.py +330 -0
- plato/datasources/registry.py +119 -0
- plato/datasources/self_supervised_learning.py +98 -0
- plato/datasources/stl10.py +103 -0
- plato/datasources/texas.py +94 -0
- plato/datasources/tiny_imagenet.py +64 -0
- plato/datasources/yolov8.py +85 -0
- plato/models/__init__.py +0 -0
- plato/models/cnn_encoder.py +103 -0
- plato/models/dcgan.py +116 -0
- plato/models/general_multilayer.py +254 -0
- plato/models/huggingface.py +27 -0
- plato/models/lenet5.py +113 -0
- plato/models/multilayer.py +90 -0
- plato/models/multimodal/__init__.py +0 -0
- plato/models/multimodal/base_net.py +91 -0
- plato/models/multimodal/blending.py +142 -0
- plato/models/multimodal/fc_net.py +77 -0
- plato/models/multimodal/fusion_net.py +78 -0
- plato/models/multimodal/multimodal_module.py +152 -0
- plato/models/registry.py +99 -0
- plato/models/resnet.py +190 -0
- plato/models/torch_hub.py +19 -0
- plato/models/vgg.py +113 -0
- plato/models/vit.py +166 -0
- plato/models/yolov8.py +22 -0
- plato/processors/__init__.py +0 -0
- plato/processors/base.py +35 -0
- plato/processors/compress.py +46 -0
- plato/processors/decompress.py +48 -0
- plato/processors/feature.py +51 -0
- plato/processors/feature_additive_noise.py +48 -0
- plato/processors/feature_dequantize.py +34 -0
- plato/processors/feature_gaussian.py +17 -0
- plato/processors/feature_laplace.py +15 -0
- plato/processors/feature_quantize.py +34 -0
- plato/processors/feature_randomized_response.py +50 -0
- plato/processors/feature_unbatch.py +39 -0
- plato/processors/inbound_feature_tensors.py +39 -0
- plato/processors/model.py +55 -0
- plato/processors/model_compress.py +34 -0
- plato/processors/model_decompress.py +37 -0
- plato/processors/model_decrypt.py +41 -0
- plato/processors/model_deepcopy.py +21 -0
- plato/processors/model_dequantize.py +18 -0
- plato/processors/model_dequantize_qsgd.py +61 -0
- plato/processors/model_encrypt.py +43 -0
- plato/processors/model_quantize.py +18 -0
- plato/processors/model_quantize_qsgd.py +82 -0
- plato/processors/model_randomized_response.py +34 -0
- plato/processors/outbound_feature_ndarrays.py +38 -0
- plato/processors/pipeline.py +26 -0
- plato/processors/registry.py +124 -0
- plato/processors/structured_pruning.py +57 -0
- plato/processors/unstructured_pruning.py +73 -0
- plato/samplers/__init__.py +0 -0
- plato/samplers/all_inclusive.py +41 -0
- plato/samplers/base.py +31 -0
- plato/samplers/dirichlet.py +81 -0
- plato/samplers/distribution_noniid.py +132 -0
- plato/samplers/iid.py +53 -0
- plato/samplers/label_quantity_noniid.py +119 -0
- plato/samplers/mixed.py +44 -0
- plato/samplers/mixed_label_quantity_noniid.py +128 -0
- plato/samplers/modality_iid.py +42 -0
- plato/samplers/modality_quantity_noniid.py +56 -0
- plato/samplers/orthogonal.py +99 -0
- plato/samplers/registry.py +66 -0
- plato/samplers/sample_quantity_noniid.py +123 -0
- plato/samplers/sampler_utils.py +190 -0
- plato/servers/__init__.py +0 -0
- plato/servers/base.py +1395 -0
- plato/servers/fedavg.py +281 -0
- plato/servers/fedavg_cs.py +335 -0
- plato/servers/fedavg_gan.py +74 -0
- plato/servers/fedavg_he.py +106 -0
- plato/servers/fedavg_personalized.py +57 -0
- plato/servers/mistnet.py +67 -0
- plato/servers/registry.py +52 -0
- plato/servers/split_learning.py +109 -0
- plato/trainers/__init__.py +0 -0
- plato/trainers/base.py +99 -0
- plato/trainers/basic.py +649 -0
- plato/trainers/diff_privacy.py +178 -0
- plato/trainers/gan.py +330 -0
- plato/trainers/huggingface.py +173 -0
- plato/trainers/loss_criterion.py +70 -0
- plato/trainers/lr_schedulers.py +252 -0
- plato/trainers/optimizers.py +53 -0
- plato/trainers/pascal_voc.py +80 -0
- plato/trainers/registry.py +44 -0
- plato/trainers/self_supervised_learning.py +302 -0
- plato/trainers/split_learning.py +305 -0
- plato/trainers/tracking.py +96 -0
- plato/trainers/yolov8.py +41 -0
- plato/utils/__init__.py +0 -0
- plato/utils/count_parameters.py +30 -0
- plato/utils/csv_processor.py +26 -0
- plato/utils/data_loaders.py +148 -0
- plato/utils/decorators.py +24 -0
- plato/utils/fonts.py +23 -0
- plato/utils/homo_enc.py +187 -0
- plato/utils/reinforcement_learning/__init__.py +0 -0
- plato/utils/reinforcement_learning/policies/__init__.py +0 -0
- plato/utils/reinforcement_learning/policies/base.py +161 -0
- plato/utils/reinforcement_learning/policies/ddpg.py +75 -0
- plato/utils/reinforcement_learning/policies/registry.py +32 -0
- plato/utils/reinforcement_learning/policies/sac.py +343 -0
- plato/utils/reinforcement_learning/policies/td3.py +485 -0
- plato/utils/reinforcement_learning/rl_agent.py +142 -0
- plato/utils/reinforcement_learning/rl_server.py +113 -0
- plato/utils/rl_env.py +154 -0
- plato/utils/s3.py +141 -0
- plato/utils/trainer_utils.py +21 -0
- plato/utils/unary_encoding.py +47 -0
- plato_learn-1.1.dist-info/METADATA +35 -0
- plato_learn-1.1.dist-info/RECORD +179 -0
- plato_learn-1.1.dist-info/WHEEL +4 -0
- plato_learn-1.1.dist-info/licenses/LICENSE +201 -0
plato/datasources/datalib/flickr30kE_utils.py
@@ -0,0 +1,336 @@
"""
Necessary functions for the Flickr30K Entities dataset.
"""

import os
import json
import xml.etree.ElementTree as ET
import logging

from plato.datasources.datalib import data_utils

def phrase_boxes_alignment(flatten_boxes, ori_phrases_boxes):
    """Align the bounding boxes with their corresponding phrases."""
    phrases_boxes = []

    ori_pb_boxes_count = []
    for ph_boxes in ori_phrases_boxes:
        ori_pb_boxes_count.append(len(ph_boxes))

    start_point = 0
    for pb_boxes_num in ori_pb_boxes_count:
        sub_boxes = []
        for i in range(start_point, start_point + pb_boxes_num):
            sub_boxes.append(flatten_boxes[i])

        start_point += pb_boxes_num
        phrases_boxes.append(sub_boxes)

    pb_boxes_count = []
    for ph_boxes in phrases_boxes:
        pb_boxes_count.append(len(ph_boxes))

    assert pb_boxes_count == ori_pb_boxes_count

    return phrases_boxes

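# A small worked example: the flat box list is re-nested according to the
# per-phrase counts of the original nesting (only the length of each inner
# list matters; box_a/box_b/box_c are placeholders):
#
#   phrase_boxes_alignment(
#       [[0, 0, 5, 5], [1, 1, 4, 4], [2, 2, 3, 3]],
#       [[box_a], [box_b, box_c]],  # counts per phrase: 1 and 2
#   )
#   -> [[[0, 0, 5, 5]], [[1, 1, 4, 4], [2, 2, 3, 3]]]
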
def filter_bad_boxes(boxes_coor):
    """Filter out boxes with invalid coordinates."""
    filtered_boxes = []
    for box_coor in boxes_coor:
        [xmin, ymin, xmax, ymax] = box_coor
        if xmin < xmax and ymin < ymax:
            filtered_boxes.append(box_coor)

    return filtered_boxes

def get_sentence_data(parse_file_path):
    """Parses a sentence file from the Flickr30K Entities dataset.

    Args:
        parse_file_path - full file path to the sentence file to parse
    Returns:
        a list of dictionaries, one for each sentence, with the following fields:
            sentence - the original sentence
            phrases - a list of dictionaries, one for each phrase, with the
                following fields:
                phrase - the text of the annotated phrase
                first_word_index - the position of the first word of
                    the phrase in the sentence
                phrase_id - an identifier for this phrase
                phrase_type - a list of the coarse categories this phrase
                    belongs to
    """
    with open(parse_file_path, "r") as opened_file:
        sentences = opened_file.read().split("\n")

    annotations = []
    for sentence in sentences:
        if not sentence:
            continue

        first_word = []
        phrases = []
        phrase_id = []
        phrase_type = []
        words = []
        current_phrase = []
        add_to_phrase = False
        for token in sentence.split():
            if add_to_phrase:
                if token[-1] == "]":
                    add_to_phrase = False
                    token = token[:-1]
                    current_phrase.append(token)
                    phrases.append(" ".join(current_phrase))
                    current_phrase = []
                else:
                    current_phrase.append(token)

                words.append(token)
            else:
                if token[0] == "[":
                    add_to_phrase = True
                    first_word.append(len(words))
                    # A phrase opens as "[/EN#<id>/<type1>/<type2> ...";
                    # parts[1][3:] strips the "EN#" prefix to get the id.
                    parts = token.split("/")
                    phrase_id.append(parts[1][3:])
                    phrase_type.append(parts[2:])
                else:
                    words.append(token)

        sentence_data = {"sentence": " ".join(words), "phrases": []}
        for index, phrase, p_id, p_type in zip(
            first_word, phrases, phrase_id, phrase_type
        ):
            sentence_data["phrases"].append(
                {
                    "first_word_index": index,
                    "phrase": phrase,
                    "phrase_id": p_id,
                    "phrase_type": p_type,
                }
            )

        annotations.append(sentence_data)

    return annotations

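# For example (illustrative line; real files follow the Flickr30K Entities
# release format), the annotated sentence
#
#   [/EN#283585/people Two young guys] sit on [/EN#283590/scene a bench] .
#
# parses to:
#
#   {"sentence": "Two young guys sit on a bench .",
#    "phrases": [{"first_word_index": 0, "phrase": "Two young guys",
#                 "phrase_id": "283585", "phrase_type": ["people"]},
#                {"first_word_index": 5, "phrase": "a bench",
#                 "phrase_id": "283590", "phrase_type": ["scene"]}]}
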
def get_annotations(parse_file_path):
    """Parses an annotation xml file from the Flickr30K Entities dataset.

    Args:
        parse_file_path - full file path to the annotation file to parse
    Returns:
        a dictionary with the following fields:
            scene - list of identifiers which were annotated as
                pertaining to the whole scene
            nobox - list of identifiers which were annotated as
                not being visible in the image
            boxes - a dictionary mapping each identifier to its list of
                boxes in the [xmin, ymin, xmax, ymax] format
    """
    tree = ET.parse(parse_file_path)
    root = tree.getroot()
    size_container = root.findall("size")[0]
    anno_info = {"boxes": {}, "scene": [], "nobox": []}
    for size_element in size_container:
        anno_info[size_element.tag] = int(size_element.text)

    for object_container in root.findall("object"):
        for names in object_container.findall("name"):
            box_id = names.text
            box_container = object_container.findall("bndbox")
            if len(box_container) > 0:
                if box_id not in anno_info["boxes"]:
                    anno_info["boxes"][box_id] = []
                # Convert from 1-based to 0-based pixel coordinates.
                xmin = int(box_container[0].findall("xmin")[0].text) - 1
                ymin = int(box_container[0].findall("ymin")[0].text) - 1
                xmax = int(box_container[0].findall("xmax")[0].text) - 1
                ymax = int(box_container[0].findall("ymax")[0].text) - 1
                anno_info["boxes"][box_id].append([xmin, ymin, xmax, ymax])
            else:
                nobndbox = int(object_container.findall("nobndbox")[0].text)
                if nobndbox > 0:
                    anno_info["nobox"].append(box_id)

                scene = int(object_container.findall("scene")[0].text)
                if scene > 0:
                    anno_info["scene"].append(box_id)

    return anno_info

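# For example (trimmed, illustrative XML; real files follow the Flickr30K
# Entities release format):
#
#   <annotation>
#     <size><width>500</width><height>375</height><depth>3</depth></size>
#     <object>
#       <name>370</name>
#       <bndbox><xmin>162</xmin><ymin>22</ymin>
#               <xmax>331</xmax><ymax>358</ymax></bndbox>
#     </object>
#   </annotation>
#
# parses to (coordinates shifted to be 0-based):
#
#   {"boxes": {"370": [[161, 21, 330, 357]]}, "scene": [], "nobox": [],
#    "width": 500, "height": 375, "depth": 3}
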
def align_anno_sent(image_sents, image_annos):
    """Align the items in annotations and sentences.

    Args:
        image_sents ([list]): [each item is a dict that contains 'sentence'
            and 'phrases']
        image_annos ([dict]): [contains 'boxes' - a dict mapping phrase_id
            to its boxes]

    Returns:
        aligned_items ([list]): [each item is a dict that contains one
            sentence with its corresponding phrase information; there are
            usually several items because each image has 5 sentences, but
            some sentences are unusable, which can make the number of items
            less than 5]
    """
    aligned_items = []  # each item is a dict
    for sent_info in image_sents:
        img_sent = sent_info["sentence"]
        img_sent_phrases = []
        img_sent_phrases_type = []
        img_sent_phrases_id = []
        img_sent_phrases_boxes = []
        for phrase_info_idx in range(len(sent_info["phrases"])):
            phrase_info = sent_info["phrases"][phrase_info_idx]

            phrase = phrase_info["phrase"]
            phrase_type = phrase_info["phrase_type"]
            phrase_id = phrase_info["phrase_id"]
            if phrase_id not in image_annos["boxes"].keys():
                continue

            phrase_boxes = image_annos["boxes"][phrase_id]  # a nested list
            filtered_boxes = filter_bad_boxes(phrase_boxes)
            if not filtered_boxes:
                continue

            img_sent_phrases.append(phrase)
            img_sent_phrases_type.append(phrase_type)
            img_sent_phrases_id.append(phrase_id)
            img_sent_phrases_boxes.append(filtered_boxes)

        if not img_sent_phrases:
            continue

        items = dict()
        # a string holding the sentence
        items["sentence"] = img_sent
        # a list that contains the phrases
        items["sentence_phrases"] = img_sent_phrases
        # a nested list that contains each phrase's types
        items["sentence_phrases_type"] = img_sent_phrases_type
        # a list that contains the phrase ids
        items["sentence_phrases_id"] = img_sent_phrases_id
        # a nested list that contains the boxes for each phrase
        items["sentence_phrases_boxes"] = img_sent_phrases_boxes

        aligned_items.append(items)

    return aligned_items

def integrate_data_to_json(
    splits_info, mm_data_info, data_types, split_wise=True, globally=True
):
    """Integrate the data into one json file that contains an aligned
    annotation-sentence pair for each image.

    The integrated data info is presented as a dict. Each item in the dict
    pairs an image with one of its annotations. For example, one random item:
    {
        ...,
        "./data/Flickr30KEntities/test/test_Images/1011572216.jpg0":
        {"sentence": "bride and groom",
         "sentence_phrases": ["bride", "groom"],
         "sentence_phrases_type": [["people"], ["people"]],
         "sentence_phrases_id": ["370", "372"],
         "sentence_phrases_boxes": [[[161, 21, 330, 357]],
                                    [[195, 82, 327, 241]]],
        },
        ....
    }
    """

    def operate_integration(
        images_name, images_annotations_path, images_sentences_path
    ):
        """Obtain the integrated data for the images."""
        integrated_data = dict()
        for image_name_idx, image_name in enumerate(images_name):
            image_sent_path = images_sentences_path[image_name_idx]
            image_anno_path = images_annotations_path[image_name_idx]

            image_sents = get_sentence_data(image_sent_path)
            image_annos = get_annotations(image_anno_path)

            aligned_items = align_anno_sent(image_sents, image_annos)
            if not aligned_items:
                continue
            for item_idx, item in enumerate(aligned_items):
                integrated_data[image_name + str(item_idx)] = item

        return integrated_data

    if split_wise:
        for split_type in list(splits_info.keys()):
            path = splits_info[split_type]["path"]
            save_path = os.path.join(path, split_type + "_integrated_data.json")
            if os.path.exists(save_path):
                logging.info("Integrating %s: the file already exists.", split_type)
                continue

            split_data_types_samples_path = []
            for data_type in data_types:
                data_type_format = splits_info[split_type][data_type]["format"]
                split_data_type_path = splits_info[split_type][data_type]["path"]

                split_data_type_samples = data_utils.list_inorder(
                    os.listdir(split_data_type_path), flag_str=data_type_format
                )

                split_data_type_samples_path = [
                    os.path.join(split_data_type_path, sample)
                    for sample in split_data_type_samples
                ]

                split_data_types_samples_path.append(split_data_type_samples_path)

            split_integrated_data = operate_integration(
                images_name=split_data_types_samples_path[0],
                images_annotations_path=split_data_types_samples_path[1],
                images_sentences_path=split_data_types_samples_path[2],
            )
            with open(save_path, "w", encoding="utf-8") as outfile:
                json.dump(split_integrated_data, outfile)

            logging.info("The integration process for %s is done.", split_type)

    if globally:
        save_path = os.path.join(
            mm_data_info["data_path"], "total_integrated_data.json"
        )
        if os.path.exists(save_path):
            logging.info("Globally integrated file already exists.")
            return

        raw_data_types_samples_path = []
        for data_type in data_types:
            data_type_format = mm_data_info[data_type]["format"]
            raw_data_type_path = mm_data_info[data_type]["path"]

            global_raw_type_samples = data_utils.list_inorder(
                os.listdir(raw_data_type_path), flag_str=data_type_format
            )

            global_raw_type_samples_path = [
                os.path.join(raw_data_type_path, sample)
                for sample in global_raw_type_samples
            ]
            raw_data_types_samples_path.append(global_raw_type_samples_path)

        global_integrated_data = operate_integration(
            images_name=raw_data_types_samples_path[0],
            images_annotations_path=raw_data_types_samples_path[1],
            images_sentences_path=raw_data_types_samples_path[2],
        )
        with open(save_path, "w", encoding="utf-8") as outfile:
            json.dump(global_integrated_data, outfile)

        logging.info("Integration for the whole dataset is done.")
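As a usage sketch, the configuration below shows only the keys that integrate_data_to_json actually reads; the paths, the ".jpg"/".xml"/".txt" formats, and the ordering of data_types (images, then annotations, then sentences, matching how operate_integration unpacks its arguments) are assumptions for illustration, not part of the package:

data_types = ["Images", "Annotations", "Sentences"]
splits_info = {
    "test": {
        "path": "./data/Flickr30KEntities/test",
        "Images": {"path": "./data/Flickr30KEntities/test/test_Images",
                   "format": ".jpg"},
        "Annotations": {"path": "./data/Flickr30KEntities/test/test_Annotations",
                        "format": ".xml"},
        "Sentences": {"path": "./data/Flickr30KEntities/test/test_Sentences",
                      "format": ".txt"},
    }
}
mm_data_info = {"data_path": "./data/Flickr30KEntities"}
# Writes ./data/Flickr30KEntities/test/test_integrated_data.json.
integrate_data_to_json(splits_info, mm_data_info, data_types,
                       split_wise=True, globally=False)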
plato/datasources/datalib/frames_extraction_tools.py
@@ -0,0 +1,254 @@
"""
Tools for extracting and processing video frames.

The classes in this module extract different modalities, including RGB
frames, optical flow, and audio, from a raw video dataset.
"""

import glob
import os
from multiprocessing import Pool

from mmaction.tools.misc.flow_extraction import extract_dense_flow

from plato.datasources.datalib import modality_extraction_base

def obtain_video_dest_dir(out_dir, video_path, is_classname_contained=True):
    """Get the destination directory for the video."""
    class_name = os.path.basename(os.path.dirname(video_path))
    _, tail = os.path.split(video_path)
    video_name = tail.split(".")[0]
    if is_classname_contained:
        out_full_path = os.path.join(out_dir, class_name, video_name)
    else:
        out_full_path = os.path.join(out_dir, video_name)

    return out_full_path


def extract_dense_flow_wrapper(items):
    """Extract the frames on the CPU."""
    (
        input_video_path,
        dest_dir,
        bound,
        save_rgb,
        start_idx,
        rgb_tmpl,
        flow_tmpl,
        method,
        is_classname_contained,
    ) = items

    out_full_path = obtain_video_dest_dir(
        dest_dir, input_video_path, is_classname_contained=is_classname_contained
    )

    extract_dense_flow(
        input_video_path,
        out_full_path,
        bound,
        save_rgb,
        start_idx,
        rgb_tmpl,
        flow_tmpl,
        method,
    )

def extract_rgb_frame(videos_extraction_items):
    """Extract RGB frames from a video using denseflow.

    Args:
        videos_extraction_items (tuple): the video's full path, its
            (short) path, its id, and the extraction settings.
    """
    (
        full_path,
        vid_path,
        _,
        out_dir,
        new_width,
        new_height,
        new_short,
        is_classname_contained,
    ) = videos_extraction_items
    out_full_path = obtain_video_dest_dir(
        out_dir=out_dir,
        video_path=vid_path,
        is_classname_contained=is_classname_contained,
    )

    # Build the denseflow shell command; -s=0 extracts RGB frames only.
    if new_short == 0:
        cmd = (
            f"denseflow '{full_path}' -b=20 -s=0 -o='{out_full_path}'"
            f" -nw={new_width} -nh={new_height} -v"
        )
    else:
        cmd = (
            f"denseflow '{full_path}' -b=20 -s=0 -o='{out_full_path}'"
            f" -ns={new_short} -v"
        )
    os.system(cmd)

def extract_optical_flow(videos_items):
    """Extract optical flow from a video using denseflow."""
    (
        full_path,
        vid_path,
        _,
        method,
        out_dir,
        new_short,
        new_width,
        new_height,
        is_classname_contained,
    ) = videos_items
    out_full_path = obtain_video_dest_dir(
        out_dir=out_dir,
        video_path=vid_path,
        is_classname_contained=is_classname_contained,
    )

    # Build the denseflow shell command; -s=1 computes flow between
    # consecutive frames with the algorithm selected by -a.
    if new_short == 0:
        cmd = (
            f"denseflow '{full_path}' -a={method} -b=20 -s=1 -o='{out_full_path}'"
            f" -nw={new_width} -nh={new_height} -v"
        )
    else:
        cmd = (
            f"denseflow '{full_path}' -a={method} -b=20 -s=1 -o='{out_full_path}'"
            f" -ns={new_short} -v"
        )

    os.system(cmd)

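# Both helpers shell out to the open-mmlab denseflow binary, which must be
# on the PATH. Reading the flags off the commands above (a sketch of this
# editor's reading, not the tool's documentation): -a selects the flow
# algorithm, -b bounds the flow magnitude, -s is the frame sampling step,
# -o the output directory, and -nw/-nh/-ns resize the width, height, or
# short side. With method="tvl1" and new_short=256, for example,
# extract_optical_flow runs (paths hypothetical):
#
#   denseflow 'videos/riding/clip1.mp4' -a=tvl1 -b=20 -s=1 \
#       -o='rawframes/riding/clip1' -ns=256 -v
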
class VideoFramesExtractor(modality_extraction_base.VideoExtractorBase):
    """A class for extracting frames from videos."""

    def __init__(
        self, video_src_dir, dir_level=2, num_worker=8, video_ext="mp4", mixed_ext=False
    ):
        super().__init__(video_src_dir, dir_level, num_worker, video_ext, mixed_ext)
        self.is_classname_contained = False
        # With dir_level == 2, the videos are categorized by their classes.
        if dir_level == 2:
            self.is_classname_contained = True

    def build_rgb_frames(self, to_dir, new_short=0, new_width=0, new_height=0):
        """Extract the RGB frames."""
        source_video_dir = self.video_src_dir
        if self.dir_level == 2:
            self.organize_modality_dir(src_dir=source_video_dir, to_dir=to_dir)
        _ = glob.glob(to_dir + "/*" * self.dir_level)

        pool = Pool(self.num_worker)
        # The argument order must match the unpacking in extract_rgb_frame:
        # out_dir, new_width, new_height, new_short.
        pool.map(
            extract_rgb_frame,
            zip(
                self.fullpath_list,
                self.videos_path_list,
                range(len(self.videos_path_list)),
                len(self.videos_path_list) * [to_dir],
                len(self.videos_path_list) * [new_width],
                len(self.videos_path_list) * [new_height],
                len(self.videos_path_list) * [new_short],
                len(self.videos_path_list) * [self.is_classname_contained],
            ),
        )

    def build_optical_flow_frames(
        self,
        to_dir,
        flow_type=None,  # None, 'tvl1', 'warp_tvl1', 'farn', 'brox'
        new_short=0,  # resize the short side of the image, keeping the ratio
        new_width=0,
        new_height=0,
    ):
        """Extract the optical flow frames using denseflow."""
        source_video_dir = self.video_src_dir
        if self.dir_level == 2:
            self.organize_modality_dir(src_dir=source_video_dir, to_dir=to_dir)
        _ = glob.glob(to_dir + "/*" * self.dir_level)

        pool = Pool(self.num_worker)
        pool.map(
            extract_optical_flow,
            zip(
                self.fullpath_list,
                self.videos_path_list,
                range(len(self.videos_path_list)),
                len(self.videos_path_list) * [flow_type],
                len(self.videos_path_list) * [to_dir],
                len(self.videos_path_list) * [new_short],
                len(self.videos_path_list) * [new_width],
                len(self.videos_path_list) * [new_height],
                len(self.videos_path_list) * [self.is_classname_contained],
            ),
        )

    def build_frames_gpu(
        self, rgb_out_path, flow_out_path, new_short=1, new_width=0, new_height=0
    ):
        """Extract the RGB and optical flow frames on the GPU."""
        self.build_rgb_frames(
            rgb_out_path,
            new_short=new_short,
            new_width=new_width,
            new_height=new_height,
        )
        self.build_optical_flow_frames(
            flow_out_path,
            new_short=new_short,
            new_width=new_width,
            new_height=new_height,
        )

    def build_full_frames_gpu(self, to_path, new_short=1, new_width=0, new_height=0):
        """The interface for extracting all frames on the GPU."""
        self.build_frames_gpu(
            rgb_out_path=to_path,
            flow_out_path=to_path,
            new_short=new_short,
            new_width=new_width,
            new_height=new_height,
        )

    def build_frames_cpu(
        self,
        to_dir,
        bound=20,  # maximum magnitude of the optical flow
        save_rgb=True,  # also save the rgb frames
        start_idx=1,  # index of the first extracted frame
        rgb_tmpl="img_{:05d}.jpg",  # filename template of the rgb frames
        flow_tmpl="{}_{:05d}.jpg",  # filename template of the flow frames
        method="tvl1",  # which method to use to generate the flow
    ):
        """Extract the full frames, including RGB and optical flow, on the CPU."""
        source_video_dir = self.video_src_dir
        if self.dir_level == 2:
            self.organize_modality_dir(src_dir=source_video_dir, to_dir=to_dir)
        _ = glob.glob(to_dir + "/*" * self.dir_level)

        pool = Pool(self.num_worker)
        pool.map(
            extract_dense_flow_wrapper,
            zip(
                self.fullpath_list,
                len(self.videos_path_list) * [to_dir],
                len(self.videos_path_list) * [bound],
                len(self.videos_path_list) * [save_rgb],
                len(self.videos_path_list) * [start_idx],
                len(self.videos_path_list) * [rgb_tmpl],
                len(self.videos_path_list) * [flow_tmpl],
                len(self.videos_path_list) * [method],
                len(self.videos_path_list) * [self.is_classname_contained],
            ),
        )
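A minimal usage sketch for the extractor; the directory paths are hypothetical, and a two-level class/video.mp4 layout is assumed so that dir_level=2 applies:

extractor = VideoFramesExtractor(
    video_src_dir="data/kinetics/videos_train",  # hypothetical class/video.mp4 layout
    dir_level=2,
    num_worker=8,
    video_ext="mp4",
)
# RGB frames plus TV-L1 optical flow through denseflow, short side resized to 256.
extractor.build_rgb_frames("data/kinetics/rawframes_train", new_short=256)
extractor.build_optical_flow_frames(
    "data/kinetics/rawframes_train", flow_type="tvl1", new_short=256
)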