plato-learn 1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (179)
  1. plato/__init__.py +1 -0
  2. plato/algorithms/__init__.py +0 -0
  3. plato/algorithms/base.py +45 -0
  4. plato/algorithms/fedavg.py +48 -0
  5. plato/algorithms/fedavg_gan.py +79 -0
  6. plato/algorithms/fedavg_personalized.py +48 -0
  7. plato/algorithms/mistnet.py +52 -0
  8. plato/algorithms/registry.py +39 -0
  9. plato/algorithms/split_learning.py +89 -0
  10. plato/callbacks/__init__.py +0 -0
  11. plato/callbacks/client.py +56 -0
  12. plato/callbacks/handler.py +78 -0
  13. plato/callbacks/server.py +139 -0
  14. plato/callbacks/trainer.py +124 -0
  15. plato/client.py +67 -0
  16. plato/clients/__init__.py +0 -0
  17. plato/clients/base.py +467 -0
  18. plato/clients/edge.py +103 -0
  19. plato/clients/fedavg_personalized.py +40 -0
  20. plato/clients/mistnet.py +49 -0
  21. plato/clients/registry.py +43 -0
  22. plato/clients/self_supervised_learning.py +51 -0
  23. plato/clients/simple.py +218 -0
  24. plato/clients/split_learning.py +150 -0
  25. plato/config.py +339 -0
  26. plato/datasources/__init__.py +0 -0
  27. plato/datasources/base.py +123 -0
  28. plato/datasources/celeba.py +150 -0
  29. plato/datasources/cifar10.py +87 -0
  30. plato/datasources/cifar100.py +61 -0
  31. plato/datasources/cinic10.py +62 -0
  32. plato/datasources/coco.py +119 -0
  33. plato/datasources/datalib/__init__.py +0 -0
  34. plato/datasources/datalib/audio_extraction_tools.py +137 -0
  35. plato/datasources/datalib/data_utils.py +124 -0
  36. plato/datasources/datalib/flickr30kE_utils.py +336 -0
  37. plato/datasources/datalib/frames_extraction_tools.py +254 -0
  38. plato/datasources/datalib/gym_utils/__init__.py +0 -0
  39. plato/datasources/datalib/gym_utils/gym_trim.py +189 -0
  40. plato/datasources/datalib/modality_data_anntation_tools.py +163 -0
  41. plato/datasources/datalib/modality_extraction_base.py +59 -0
  42. plato/datasources/datalib/parse_datasets.py +212 -0
  43. plato/datasources/datalib/refer_utils/__init__.py +0 -0
  44. plato/datasources/datalib/refer_utils/referitgame_utils.py +237 -0
  45. plato/datasources/datalib/tiny_data_tools.py +81 -0
  46. plato/datasources/datalib/video_transform.py +79 -0
  47. plato/datasources/emnist.py +64 -0
  48. plato/datasources/fashion_mnist.py +41 -0
  49. plato/datasources/feature.py +24 -0
  50. plato/datasources/feature_dataset.py +15 -0
  51. plato/datasources/femnist.py +141 -0
  52. plato/datasources/flickr30k_entities.py +362 -0
  53. plato/datasources/gym.py +431 -0
  54. plato/datasources/huggingface.py +165 -0
  55. plato/datasources/kinetics.py +568 -0
  56. plato/datasources/mnist.py +44 -0
  57. plato/datasources/multimodal_base.py +328 -0
  58. plato/datasources/pascal_voc.py +56 -0
  59. plato/datasources/purchase.py +94 -0
  60. plato/datasources/qoenflx.py +127 -0
  61. plato/datasources/referitgame.py +330 -0
  62. plato/datasources/registry.py +119 -0
  63. plato/datasources/self_supervised_learning.py +98 -0
  64. plato/datasources/stl10.py +103 -0
  65. plato/datasources/texas.py +94 -0
  66. plato/datasources/tiny_imagenet.py +64 -0
  67. plato/datasources/yolov8.py +85 -0
  68. plato/models/__init__.py +0 -0
  69. plato/models/cnn_encoder.py +103 -0
  70. plato/models/dcgan.py +116 -0
  71. plato/models/general_multilayer.py +254 -0
  72. plato/models/huggingface.py +27 -0
  73. plato/models/lenet5.py +113 -0
  74. plato/models/multilayer.py +90 -0
  75. plato/models/multimodal/__init__.py +0 -0
  76. plato/models/multimodal/base_net.py +91 -0
  77. plato/models/multimodal/blending.py +142 -0
  78. plato/models/multimodal/fc_net.py +77 -0
  79. plato/models/multimodal/fusion_net.py +78 -0
  80. plato/models/multimodal/multimodal_module.py +152 -0
  81. plato/models/registry.py +99 -0
  82. plato/models/resnet.py +190 -0
  83. plato/models/torch_hub.py +19 -0
  84. plato/models/vgg.py +113 -0
  85. plato/models/vit.py +166 -0
  86. plato/models/yolov8.py +22 -0
  87. plato/processors/__init__.py +0 -0
  88. plato/processors/base.py +35 -0
  89. plato/processors/compress.py +46 -0
  90. plato/processors/decompress.py +48 -0
  91. plato/processors/feature.py +51 -0
  92. plato/processors/feature_additive_noise.py +48 -0
  93. plato/processors/feature_dequantize.py +34 -0
  94. plato/processors/feature_gaussian.py +17 -0
  95. plato/processors/feature_laplace.py +15 -0
  96. plato/processors/feature_quantize.py +34 -0
  97. plato/processors/feature_randomized_response.py +50 -0
  98. plato/processors/feature_unbatch.py +39 -0
  99. plato/processors/inbound_feature_tensors.py +39 -0
  100. plato/processors/model.py +55 -0
  101. plato/processors/model_compress.py +34 -0
  102. plato/processors/model_decompress.py +37 -0
  103. plato/processors/model_decrypt.py +41 -0
  104. plato/processors/model_deepcopy.py +21 -0
  105. plato/processors/model_dequantize.py +18 -0
  106. plato/processors/model_dequantize_qsgd.py +61 -0
  107. plato/processors/model_encrypt.py +43 -0
  108. plato/processors/model_quantize.py +18 -0
  109. plato/processors/model_quantize_qsgd.py +82 -0
  110. plato/processors/model_randomized_response.py +34 -0
  111. plato/processors/outbound_feature_ndarrays.py +38 -0
  112. plato/processors/pipeline.py +26 -0
  113. plato/processors/registry.py +124 -0
  114. plato/processors/structured_pruning.py +57 -0
  115. plato/processors/unstructured_pruning.py +73 -0
  116. plato/samplers/__init__.py +0 -0
  117. plato/samplers/all_inclusive.py +41 -0
  118. plato/samplers/base.py +31 -0
  119. plato/samplers/dirichlet.py +81 -0
  120. plato/samplers/distribution_noniid.py +132 -0
  121. plato/samplers/iid.py +53 -0
  122. plato/samplers/label_quantity_noniid.py +119 -0
  123. plato/samplers/mixed.py +44 -0
  124. plato/samplers/mixed_label_quantity_noniid.py +128 -0
  125. plato/samplers/modality_iid.py +42 -0
  126. plato/samplers/modality_quantity_noniid.py +56 -0
  127. plato/samplers/orthogonal.py +99 -0
  128. plato/samplers/registry.py +66 -0
  129. plato/samplers/sample_quantity_noniid.py +123 -0
  130. plato/samplers/sampler_utils.py +190 -0
  131. plato/servers/__init__.py +0 -0
  132. plato/servers/base.py +1395 -0
  133. plato/servers/fedavg.py +281 -0
  134. plato/servers/fedavg_cs.py +335 -0
  135. plato/servers/fedavg_gan.py +74 -0
  136. plato/servers/fedavg_he.py +106 -0
  137. plato/servers/fedavg_personalized.py +57 -0
  138. plato/servers/mistnet.py +67 -0
  139. plato/servers/registry.py +52 -0
  140. plato/servers/split_learning.py +109 -0
  141. plato/trainers/__init__.py +0 -0
  142. plato/trainers/base.py +99 -0
  143. plato/trainers/basic.py +649 -0
  144. plato/trainers/diff_privacy.py +178 -0
  145. plato/trainers/gan.py +330 -0
  146. plato/trainers/huggingface.py +173 -0
  147. plato/trainers/loss_criterion.py +70 -0
  148. plato/trainers/lr_schedulers.py +252 -0
  149. plato/trainers/optimizers.py +53 -0
  150. plato/trainers/pascal_voc.py +80 -0
  151. plato/trainers/registry.py +44 -0
  152. plato/trainers/self_supervised_learning.py +302 -0
  153. plato/trainers/split_learning.py +305 -0
  154. plato/trainers/tracking.py +96 -0
  155. plato/trainers/yolov8.py +41 -0
  156. plato/utils/__init__.py +0 -0
  157. plato/utils/count_parameters.py +30 -0
  158. plato/utils/csv_processor.py +26 -0
  159. plato/utils/data_loaders.py +148 -0
  160. plato/utils/decorators.py +24 -0
  161. plato/utils/fonts.py +23 -0
  162. plato/utils/homo_enc.py +187 -0
  163. plato/utils/reinforcement_learning/__init__.py +0 -0
  164. plato/utils/reinforcement_learning/policies/__init__.py +0 -0
  165. plato/utils/reinforcement_learning/policies/base.py +161 -0
  166. plato/utils/reinforcement_learning/policies/ddpg.py +75 -0
  167. plato/utils/reinforcement_learning/policies/registry.py +32 -0
  168. plato/utils/reinforcement_learning/policies/sac.py +343 -0
  169. plato/utils/reinforcement_learning/policies/td3.py +485 -0
  170. plato/utils/reinforcement_learning/rl_agent.py +142 -0
  171. plato/utils/reinforcement_learning/rl_server.py +113 -0
  172. plato/utils/rl_env.py +154 -0
  173. plato/utils/s3.py +141 -0
  174. plato/utils/trainer_utils.py +21 -0
  175. plato/utils/unary_encoding.py +47 -0
  176. plato_learn-1.1.dist-info/METADATA +35 -0
  177. plato_learn-1.1.dist-info/RECORD +179 -0
  178. plato_learn-1.1.dist-info/WHEEL +4 -0
  179. plato_learn-1.1.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,336 @@
1
+ """
2
+ Necessary functions for the Flickr30K Entities dataset
3
+
4
+ """
5
+
6
+ import os
7
+ import json
8
+ import xml.etree.ElementTree as ET
9
+ import logging
10
+
11
+ from plato.datasources.datalib import data_utils
12
+
13
+
14
def phrase_boxes_alignment(flatten_boxes, ori_phrases_boxes):
    """Regroup a flat sequence of boxes so it mirrors a per-phrase grouping.

    Args:
        flatten_boxes: an indexable sequence of boxes, laid out phrase by
            phrase.
        ori_phrases_boxes: a sequence of per-phrase box lists. Only the length
            of each entry is used, to decide how many consecutive boxes of
            ``flatten_boxes`` belong to that phrase.

    Returns:
        A list with one sub-list per phrase. The i-th sub-list holds the next
        ``len(ori_phrases_boxes[i])`` items of ``flatten_boxes``.

    Raises:
        IndexError: if ``flatten_boxes`` holds fewer boxes than required.
    """
    expected_counts = [len(boxes) for boxes in ori_phrases_boxes]

    grouped = []
    offset = 0
    for count in expected_counts:
        grouped.append([flatten_boxes[idx] for idx in range(offset, offset + count)])
        offset += count

    # Sanity check: every regrouped phrase holds as many boxes as before.
    assert [len(boxes) for boxes in grouped] == expected_counts

    return grouped
38
+
39
+
40
def filter_bad_boxes(boxes_coor):
    """Drop degenerate boxes.

    Args:
        boxes_coor: an iterable of ``[xmin, ymin, xmax, ymax]`` boxes. Each box
            must unpack into exactly four values.

    Returns:
        A list of the boxes, kept in their original order, that satisfy both
        ``xmin < xmax`` and ``ymin < ymax``. Zero-width, zero-height and
        inverted boxes are removed.
    """

    def _has_positive_extent(box):
        x_min, y_min, x_max, y_max = box
        return x_min < x_max and y_min < y_max

    return [box for box in boxes_coor if _has_positive_extent(box)]
49
+
50
+
51
def get_sentence_data(parse_file_path):
    """Parse a sentence file from the Flickr30K Entities dataset.

    Each non-empty line of the file is one annotated sentence. A phrase
    starts with a token beginning with "[". That token is split on "/": the
    second field, minus its first three characters, is the phrase id, and the
    remaining fields are its types. The phrase ends at the first following
    token that ends with "]". Markup tokens are not part of the sentence.

    Args:
        parse_file_path: full file path to the sentence file to parse.

    Returns:
        A list with one dictionary per non-empty line, holding:
            sentence - the sentence with the phrase markup removed
            phrases - a list of dictionaries, one per annotated phrase:
                phrase - the text of the annotated phrase
                first_word_index - the position of the first word of
                    the phrase in the sentence
                phrase_id - an identifier for this phrase
                phrase_type - a list of the coarse categories this
                    phrase belongs to
    """
    # Decode explicitly as UTF-8 rather than relying on the platform locale,
    # so parsing is reproducible across machines. This also matches the UTF-8
    # JSON files this module writes.
    with open(parse_file_path, "r", encoding="utf-8") as opened_file:
        sentences = opened_file.read().split("\n")

    annotations = []
    for sentence in sentences:
        if not sentence:
            continue

        first_word = []
        phrases = []
        phrase_id = []
        phrase_type = []
        words = []
        current_phrase = []
        add_to_phrase = False
        for token in sentence.split():
            if add_to_phrase:
                if token[-1] == "]":
                    # The closing token ends the current phrase.
                    add_to_phrase = False
                    token = token[:-1]
                    current_phrase.append(token)
                    phrases.append(" ".join(current_phrase))
                    current_phrase = []
                else:
                    current_phrase.append(token)

                words.append(token)
            elif token[0] == "[":
                # Opening markup token: it carries the id and the types,
                # but it is not a word of the sentence itself.
                add_to_phrase = True
                first_word.append(len(words))
                parts = token.split("/")
                phrase_id.append(parts[1][3:])
                phrase_type.append(parts[2:])
            else:
                words.append(token)

        sentence_data = {"sentence": " ".join(words), "phrases": []}
        for index, phrase, p_id, p_type in zip(
            first_word, phrases, phrase_id, phrase_type
        ):
            sentence_data["phrases"].append(
                {
                    "first_word_index": index,
                    "phrase": phrase,
                    "phrase_id": p_id,
                    "phrase_type": p_type,
                }
            )

        annotations.append(sentence_data)

    return annotations
120
+
121
+
122
def get_annotations(parse_file_path):
    """Parse one Flickr30K Entities XML annotation file.

    Args:
        parse_file_path: full file path to the annotations file to parse.

    Returns:
        A dictionary with the following fields:
            boxes - maps each identifier to its list of boxes in the
                ``[xmin, ymin, xmax, ymax]`` format. Every coordinate is the
                file value minus 1.
            nobox - identifiers of objects without a ``bndbox`` whose
                ``nobndbox`` value is positive
            scene - identifiers of objects without a ``bndbox`` whose
                ``scene`` value is positive
        It also has one integer entry per child of the first ``size`` element,
        keyed by that child's tag.
    """
    root = ET.parse(parse_file_path).getroot()

    anno_info = {"boxes": {}, "scene": [], "nobox": []}
    for size_child in root.findall("size")[0]:
        anno_info[size_child.tag] = int(size_child.text)

    def _read_int(parent, tag):
        # Integer text of the first ``tag`` child of ``parent``.
        return int(parent.findall(tag)[0].text)

    for obj in root.findall("object"):
        bndboxes = obj.findall("bndbox")
        for name_element in obj.findall("name"):
            box_id = name_element.text
            if not bndboxes:
                if _read_int(obj, "nobndbox") > 0:
                    anno_info["nobox"].append(box_id)
                if _read_int(obj, "scene") > 0:
                    anno_info["scene"].append(box_id)
                continue

            # NOTE(review): the -1 presumably converts 1-based pixel indices
            # to 0-based ones. Confirm against the dataset documentation.
            coords = [
                _read_int(bndboxes[0], tag) - 1
                for tag in ("xmin", "ymin", "xmax", "ymax")
            ]
            anno_info["boxes"].setdefault(box_id, []).append(coords)

    return anno_info
164
+
165
+
166
def align_anno_sent(image_sents, image_annos):
    """Pair each sentence of an image with the boxes of its phrases.

    Args:
        image_sents (list): one dict per sentence, each with the keys
            ``"sentence"`` and ``"phrases"``. Each phrase is a dict with
            ``"phrase"``, ``"phrase_type"`` and ``"phrase_id"``.
        image_annos (dict): must contain ``"boxes"``, a dict that maps a
            phrase id to its list of boxes.

    Returns:
        aligned_items (list): one dict per sentence that kept at least one
        phrase. Each dict has the keys ``sentence``, ``sentence_phrases``,
        ``sentence_phrases_type``, ``sentence_phrases_id`` and
        ``sentence_phrases_boxes``.

        A phrase is kept only when its id is present in
        ``image_annos["boxes"]`` and at least one of its boxes survives
        ``filter_bad_boxes``. Sentences with no kept phrase are dropped, so
        the result can be shorter than ``image_sents``.
    """
    aligned_items = []
    for sent_info in image_sents:
        sentence_text = sent_info["sentence"]
        kept_phrases, kept_types, kept_ids, kept_boxes = [], [], [], []

        for phrase_info in sent_info["phrases"]:
            phrase = phrase_info["phrase"]
            phrase_type = phrase_info["phrase_type"]
            phrase_id = phrase_info["phrase_id"]

            known_boxes = image_annos["boxes"]
            if phrase_id not in known_boxes:
                continue

            valid_boxes = filter_bad_boxes(known_boxes[phrase_id])
            if not valid_boxes:
                continue

            kept_phrases.append(phrase)
            kept_types.append(phrase_type)
            kept_ids.append(phrase_id)
            kept_boxes.append(valid_boxes)

        if not kept_phrases:
            continue

        aligned_items.append(
            {
                # the sentence string
                "sentence": sentence_text,
                # the kept phrases
                "sentence_phrases": kept_phrases,
                # one list of types per kept phrase
                "sentence_phrases_type": kept_types,
                # the id of each kept phrase
                "sentence_phrases_id": kept_ids,
                # one list of valid boxes per kept phrase
                "sentence_phrases_boxes": kept_boxes,
            }
        )

    return aligned_items
223
+
224
+
225
def _collect_samples_path(data_types, types_info):
    """Return, for each data type, the ordered list of its sample file paths.

    Args:
        data_types: data type names, in the order the lists are wanted.
        types_info: dict mapping each data type to a dict with ``"path"`` (the
            directory holding its samples) and ``"format"`` (forwarded as
            ``flag_str`` to ``data_utils.list_inorder``).

    Returns:
        A list holding one list of full sample paths per data type.
    """
    samples_path = []
    for data_type in data_types:
        type_info = types_info[data_type]
        type_dir = type_info["path"]
        type_samples = data_utils.list_inorder(
            os.listdir(type_dir), flag_str=type_info["format"]
        )
        samples_path.append(
            [os.path.join(type_dir, sample) for sample in type_samples]
        )
    return samples_path


def _integrate_samples(images_name, images_annotations_path, images_sentences_path):
    """Align the annotations and sentences of every image.

    The three lists are paired by position: the i-th image goes with the i-th
    annotation file and the i-th sentence file.

    Returns:
        A dict mapping ``<image path><k>`` to the k-th item returned by
        ``align_anno_sent`` for that image. Images without any aligned item
        contribute nothing.

    Raises:
        ValueError: if the three lists differ in length. Positional pairing
            would then attach annotations and sentences to the wrong image.
    """
    counts = (
        len(images_name),
        len(images_annotations_path),
        len(images_sentences_path),
    )
    if len(set(counts)) != 1:
        raise ValueError(
            "Cannot pair samples by position: found %d images, "
            "%d annotation files and %d sentence files." % counts
        )

    integrated_data = {}
    for image_name, image_anno_path, image_sent_path in zip(
        images_name, images_annotations_path, images_sentences_path
    ):
        image_sents = get_sentence_data(image_sent_path)
        image_annos = get_annotations(image_anno_path)
        aligned_items = align_anno_sent(image_sents, image_annos)
        for item_idx, item in enumerate(aligned_items):
            integrated_data[image_name + str(item_idx)] = item
    return integrated_data


def integrate_data_to_json(
    splits_info, mm_data_info, data_types, split_wise=True, globally=True
):
    """Integrate the data into JSON files of aligned annotation-sentence items.

    The integrated data is a dict. Each entry maps an image path followed by
    an item index to one aligned annotation-sentence item of that image. For
    example, one item may look like:
    {
        ...,
        "./data/Flickr30KEntities/test/test_Images/1011572216.jpg0":
            {"sentence": "bride and groom",
             "sentence_phrases": ["bride", "groom"],
             "sentence_phrases_type": [["people"], ["people"]],
             "sentence_phrases_id": ["370", "372"],
             "sentence_phrases_boxes": [[[161, 21, 330, 357]],
                                        [[195, 82, 327, 241]]],
            },
        ...
    }

    Args:
        splits_info: maps each split name to a dict. The dict has ``"path"``,
            the directory where ``<split>_integrated_data.json`` is written.
            It also has, for each data type, a dict with ``"path"`` and
            ``"format"``.
        mm_data_info: dict with ``"data_path"``, the directory where
            ``total_integrated_data.json`` is written. It also has, for each
            data type, a dict with ``"path"`` and ``"format"``.
        data_types: ordered data type names. The first three are used as the
            images, annotations and sentences, in that order.
        split_wise: whether to write one integrated file per split.
        globally: whether to write one integrated file for the whole dataset.

    Output files that already exist are left untouched.
    """
    if split_wise:
        for split_type, split_info in splits_info.items():
            save_path = os.path.join(
                split_info["path"], split_type + "_integrated_data.json"
            )
            if os.path.exists(save_path):
                logging.info("Integrating %s: the file already exists.", split_type)
                continue

            split_samples_path = _collect_samples_path(data_types, split_info)
            split_integrated_data = _integrate_samples(
                images_name=split_samples_path[0],
                images_annotations_path=split_samples_path[1],
                images_sentences_path=split_samples_path[2],
            )
            with open(save_path, "w", encoding="utf-8") as outfile:
                json.dump(split_integrated_data, outfile)

            logging.info("The integration process for %s is done.", split_type)

    if globally:
        save_path = os.path.join(
            mm_data_info["data_path"], "total_integrated_data.json"
        )
        if os.path.exists(save_path):
            logging.info("Gloablly integrated file already exists.")
            return

        raw_samples_path = _collect_samples_path(data_types, mm_data_info)
        global_integrated_data = _integrate_samples(
            images_name=raw_samples_path[0],
            images_annotations_path=raw_samples_path[1],
            images_sentences_path=raw_samples_path[2],
        )
        with open(save_path, "w", encoding="utf-8") as outfile:
            json.dump(global_integrated_data, outfile)

        logging.info("Integration for the whole dataset, Done.")
@@ -0,0 +1,254 @@
1
+ """
2
+ Tools for extracting and processing the frames
3
+ The classes in this tool aim to extract different modalities,
4
+ including rgb, optical flow, and audio
5
+ from the raw video dataset.
6
+ """
7
+
8
+ import glob
9
+ import os
10
+ from multiprocessing import Pool
11
+
12
+ from mmaction.tools.misc.flow_extraction import extract_dense_flow
13
+
14
+ from plato.datasources.datalib import modality_extraction_base
15
+
16
+
17
def obtain_video_dest_dir(out_dir, video_path, is_classname_contained=True):
    """Build the output directory for the frames of one video.

    The video name is the file name up to its first ".". A name with several
    dots is therefore cut at the first one; for example, ``"cls/a.b.mp4"``
    yields ``"a"``.

    Args:
        out_dir: root output directory.
        video_path: path of the video file.
        is_classname_contained: when True, the name of the video's parent
            directory is inserted between ``out_dir`` and the video name.

    Returns:
        ``out_dir/<parent dir>/<video name>`` or ``out_dir/<video name>``.
    """
    video_name = os.path.basename(video_path).partition(".")[0]
    if not is_classname_contained:
        return os.path.join(out_dir, video_name)

    class_name = os.path.basename(os.path.dirname(video_path))
    return os.path.join(out_dir, class_name, video_name)
29
+
30
+
31
def extract_dense_flow_wrapper(items):
    """Run mmaction's ``extract_dense_flow`` on one video.

    Args:
        items: a 9-tuple ``(input_video_path, dest_dir, bound, save_rgb,
            start_idx, rgb_tmpl, flow_tmpl, method, is_classname_contained)``.
            It is packed as a single tuple so this function can be used with
            ``Pool.map``.
    """
    (
        video_path,
        root_dir,
        max_flow,
        keep_rgb,
        first_idx,
        rgb_template,
        flow_template,
        flow_method,
        with_class_dir,
    ) = items

    video_out_dir = obtain_video_dest_dir(
        root_dir, video_path, is_classname_contained=with_class_dir
    )

    # Positional order follows extract_dense_flow's signature after the
    # (source, destination) pair.
    flow_args = (
        max_flow,
        keep_rgb,
        first_idx,
        rgb_template,
        flow_template,
        flow_method,
    )
    extract_dense_flow(video_path, video_out_dir, *flow_args)
59
+
60
+
61
def extract_rgb_frame(videos_extraction_items):
    """Extract the RGB frames of one video with the ``denseflow`` tool.

    Args:
        videos_extraction_items (tuple): ``(full_path, vid_path, vid_id,
            out_dir, new_short, new_width, new_height,
            is_classname_contained)``. ``vid_id`` is unused. When
            ``new_short`` is 0, frames are resized with ``new_width`` and
            ``new_height``; otherwise ``new_short`` is passed as ``-ns``.

    Returns:
        bool: whether denseflow could be launched and exited with status 0.
    """
    import logging
    import subprocess

    # The field order must match VideoFramesExtractor.build_rgb_frames, which
    # zips (..., to_dir, new_short, new_width, new_height, ...). Unpacking in
    # any other order silently swaps the three resize values.
    (
        full_path,
        vid_path,
        _,
        out_dir,
        new_short,
        new_width,
        new_height,
        is_classname_contained,
    ) = videos_extraction_items
    out_full_path = obtain_video_dest_dir(
        out_dir=out_dir,
        video_path=vid_path,
        is_classname_contained=is_classname_contained,
    )

    # Pass an argument list instead of a shell string, so that paths
    # containing quotes or other shell metacharacters are handed to denseflow
    # verbatim.
    # NOTE(review): "-s=0" presumably selects denseflow's image-only mode
    # (the flow extraction uses "-s=1"). Confirm against denseflow's help.
    cmd = ["denseflow", full_path, "-b=20", "-s=0", f"-o={out_full_path}"]
    if new_short == 0:
        cmd += [f"-nw={new_width}", f"-nh={new_height}"]
    else:
        cmd.append(f"-ns={new_short}")
    cmd.append("-v")

    try:
        completed = subprocess.run(cmd, check=False)
    except OSError as err:
        # Best effort, as before: a missing denseflow binary must not abort
        # the whole pool of extractions.
        logging.warning("Could not run denseflow on %s: %s", full_path, err)
        return False
    return completed.returncode == 0
98
+
99
+
100
def extract_optical_flow(videos_items):
    """Extract the optical-flow frames of one video with the ``denseflow`` tool.

    Args:
        videos_items (tuple): ``(full_path, vid_path, vid_id, method, out_dir,
            new_short, new_width, new_height, is_classname_contained)``.
            ``vid_id`` is unused. ``method`` is the flow algorithm passed as
            ``-a``; when it is ``None``, no ``-a`` option is sent. When
            ``new_short`` is 0, frames are resized with ``new_width`` and
            ``new_height``; otherwise ``new_short`` is passed as ``-ns``.

    Returns:
        bool: whether denseflow could be launched and exited with status 0.
    """
    import logging
    import subprocess

    (
        full_path,
        vid_path,
        _,
        method,
        out_dir,
        new_short,
        new_width,
        new_height,
        is_classname_contained,
    ) = videos_items
    out_full_path = obtain_video_dest_dir(
        out_dir=out_dir,
        video_path=vid_path,
        is_classname_contained=is_classname_contained,
    )

    # Argument list rather than a shell string, so that paths with quotes or
    # other shell metacharacters reach denseflow verbatim.
    cmd = ["denseflow", full_path]
    # VideoFramesExtractor.build_optical_flow_frames defaults flow_type to
    # None. Formatting that would send the literal "-a=None", so leave the
    # choice to denseflow's own default instead.
    if method is not None:
        cmd.append(f"-a={method}")
    cmd += ["-b=20", "-s=1", f"-o={out_full_path}"]
    if new_short == 0:
        cmd += [f"-nw={new_width}", f"-nh={new_height}"]
    else:
        cmd.append(f"-ns={new_short}")
    cmd.append("-v")

    try:
        completed = subprocess.run(cmd, check=False)
    except OSError as err:
        # Best effort: a missing denseflow binary must not abort the pool.
        logging.warning("Could not run denseflow on %s: %s", full_path, err)
        return False
    return completed.returncode == 0
131
+
132
+
133
class VideoFramesExtractor(modality_extraction_base.VideoExtractorBase):
    """Extract RGB and optical-flow frames from the videos of a dataset.

    The per-video work is done by the module-level extraction functions. It
    is dispatched to a ``multiprocessing.Pool`` of ``self.num_worker``
    processes. ``fullpath_list``, ``videos_path_list``, ``video_src_dir``,
    ``dir_level``, ``num_worker`` and ``organize_modality_dir`` come from
    ``VideoExtractorBase``; presumably its constructor populates them
    (confirm there).
    """

    def __init__(
        self, video_src_dir, dir_level=2, num_worker=8, video_ext="mp4", mixed_ext=False
    ):
        super().__init__(video_src_dir, dir_level, num_worker, video_ext, mixed_ext)
        # With dir_level == 2 the videos are categorized into per-class
        # sub-directories, so the output paths keep the class name.
        self.is_classname_contained = dir_level == 2

    def _prepare_output_dir(self, to_dir):
        """Call ``organize_modality_dir`` for ``to_dir`` when ``dir_level == 2``."""
        if self.dir_level == 2:
            self.organize_modality_dir(src_dir=self.video_src_dir, to_dir=to_dir)

    def _repeat(self, value):
        """Return ``value`` repeated once per video, for zipping per-video args."""
        return len(self.videos_path_list) * [value]

    def _run_pool(self, worker, per_video_args):
        """Apply ``worker`` to every per-video argument tuple in a process pool.

        The pool is a context manager, so its worker processes are terminated
        once ``map`` has returned and every task has finished. Without this,
        each call leaked its pool.
        """
        with Pool(self.num_worker) as pool:
            pool.map(worker, per_video_args)

    def build_rgb_frames(self, to_dir, new_short=0, new_width=0, new_height=0):
        """Extract the RGB frames of every video into ``to_dir``.

        Each video is handled by ``extract_rgb_frame``.
        """
        self._prepare_output_dir(to_dir)
        self._run_pool(
            extract_rgb_frame,
            zip(
                self.fullpath_list,
                self.videos_path_list,
                range(len(self.videos_path_list)),
                self._repeat(to_dir),
                self._repeat(new_short),
                self._repeat(new_width),
                self._repeat(new_height),
                self._repeat(self.is_classname_contained),
            ),
        )

    def build_optical_flow_frames(
        self,
        to_dir,
        flow_type=None,  # None, 'tvl1', 'warp_tvl1', 'farn', 'brox',
        new_short=0,  # resize image short side length keeping ratio
        new_width=0,
        new_height=0,
    ):
        """Extract the optical-flow frames of every video into ``to_dir``.

        Each video is handled by ``extract_optical_flow``, with ``flow_type``
        as the flow algorithm.
        """
        self._prepare_output_dir(to_dir)
        self._run_pool(
            extract_optical_flow,
            zip(
                self.fullpath_list,
                self.videos_path_list,
                range(len(self.videos_path_list)),
                self._repeat(flow_type),
                self._repeat(to_dir),
                self._repeat(new_short),
                self._repeat(new_width),
                self._repeat(new_height),
                self._repeat(self.is_classname_contained),
            ),
        )

    def build_frames_gpu(
        self, rgb_out__path, flow_our__path, new_short=1, new_width=0, new_height=0
    ):
        """Extract RGB frames, then optical-flow frames, with denseflow.

        The flow pass uses the default ``flow_type`` of
        ``build_optical_flow_frames``.
        """
        # NOTE(review): new_short defaults to 1 here but to 0 in the
        # build_*_frames methods. A non-zero new_short selects denseflow's
        # "-ns" short-side resize; confirm that 1 is intended.
        self.build_rgb_frames(
            rgb_out__path,
            new_short=new_short,
            new_width=new_width,
            new_height=new_height,
        )
        self.build_optical_flow_frames(
            flow_our__path,
            new_short=new_short,
            new_width=new_width,
            new_height=new_height,
        )

    def build_full_frames_gpu(self, to__path, new_short=1, new_width=0, new_height=0):
        """Extract RGB and optical-flow frames into the same directory."""
        self.build_frames_gpu(
            rgb_out__path=to__path,
            flow_our__path=to__path,
            new_short=new_short,
            new_width=new_width,
            new_height=new_height,
        )

    def build_frames_cpu(
        self,
        to_dir,
        bound=20,  # maximum of optical flow
        save_rgb=True,  # also save rgb frames
        start_idx=1,  # index of extracted frames
        rgb_tmpl="img_{:05d}.jpg",  # template filename of rgb frames
        flow_tmpl="{}_{:05d}.jpg",  # template filename of flow frames
        method="tvl1",  # use which method to generate the flow
    ):
        """Extract optical-flow (and optionally RGB) frames of every video.

        Each video is handled by ``extract_dense_flow_wrapper``, which calls
        mmaction's ``extract_dense_flow`` instead of the denseflow binary.
        """
        self._prepare_output_dir(to_dir)
        self._run_pool(
            extract_dense_flow_wrapper,
            zip(
                self.fullpath_list,
                self._repeat(to_dir),
                self._repeat(bound),
                self._repeat(save_rgb),
                self._repeat(start_idx),
                self._repeat(rgb_tmpl),
                self._repeat(flow_tmpl),
                self._repeat(method),
                self._repeat(self.is_classname_contained),
            ),
        )
File without changes