plato_learn-1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (179)
  1. plato/__init__.py +1 -0
  2. plato/algorithms/__init__.py +0 -0
  3. plato/algorithms/base.py +45 -0
  4. plato/algorithms/fedavg.py +48 -0
  5. plato/algorithms/fedavg_gan.py +79 -0
  6. plato/algorithms/fedavg_personalized.py +48 -0
  7. plato/algorithms/mistnet.py +52 -0
  8. plato/algorithms/registry.py +39 -0
  9. plato/algorithms/split_learning.py +89 -0
  10. plato/callbacks/__init__.py +0 -0
  11. plato/callbacks/client.py +56 -0
  12. plato/callbacks/handler.py +78 -0
  13. plato/callbacks/server.py +139 -0
  14. plato/callbacks/trainer.py +124 -0
  15. plato/client.py +67 -0
  16. plato/clients/__init__.py +0 -0
  17. plato/clients/base.py +467 -0
  18. plato/clients/edge.py +103 -0
  19. plato/clients/fedavg_personalized.py +40 -0
  20. plato/clients/mistnet.py +49 -0
  21. plato/clients/registry.py +43 -0
  22. plato/clients/self_supervised_learning.py +51 -0
  23. plato/clients/simple.py +218 -0
  24. plato/clients/split_learning.py +150 -0
  25. plato/config.py +339 -0
  26. plato/datasources/__init__.py +0 -0
  27. plato/datasources/base.py +123 -0
  28. plato/datasources/celeba.py +150 -0
  29. plato/datasources/cifar10.py +87 -0
  30. plato/datasources/cifar100.py +61 -0
  31. plato/datasources/cinic10.py +62 -0
  32. plato/datasources/coco.py +119 -0
  33. plato/datasources/datalib/__init__.py +0 -0
  34. plato/datasources/datalib/audio_extraction_tools.py +137 -0
  35. plato/datasources/datalib/data_utils.py +124 -0
  36. plato/datasources/datalib/flickr30kE_utils.py +336 -0
  37. plato/datasources/datalib/frames_extraction_tools.py +254 -0
  38. plato/datasources/datalib/gym_utils/__init__.py +0 -0
  39. plato/datasources/datalib/gym_utils/gym_trim.py +189 -0
  40. plato/datasources/datalib/modality_data_anntation_tools.py +163 -0
  41. plato/datasources/datalib/modality_extraction_base.py +59 -0
  42. plato/datasources/datalib/parse_datasets.py +212 -0
  43. plato/datasources/datalib/refer_utils/__init__.py +0 -0
  44. plato/datasources/datalib/refer_utils/referitgame_utils.py +237 -0
  45. plato/datasources/datalib/tiny_data_tools.py +81 -0
  46. plato/datasources/datalib/video_transform.py +79 -0
  47. plato/datasources/emnist.py +64 -0
  48. plato/datasources/fashion_mnist.py +41 -0
  49. plato/datasources/feature.py +24 -0
  50. plato/datasources/feature_dataset.py +15 -0
  51. plato/datasources/femnist.py +141 -0
  52. plato/datasources/flickr30k_entities.py +362 -0
  53. plato/datasources/gym.py +431 -0
  54. plato/datasources/huggingface.py +165 -0
  55. plato/datasources/kinetics.py +568 -0
  56. plato/datasources/mnist.py +44 -0
  57. plato/datasources/multimodal_base.py +328 -0
  58. plato/datasources/pascal_voc.py +56 -0
  59. plato/datasources/purchase.py +94 -0
  60. plato/datasources/qoenflx.py +127 -0
  61. plato/datasources/referitgame.py +330 -0
  62. plato/datasources/registry.py +119 -0
  63. plato/datasources/self_supervised_learning.py +98 -0
  64. plato/datasources/stl10.py +103 -0
  65. plato/datasources/texas.py +94 -0
  66. plato/datasources/tiny_imagenet.py +64 -0
  67. plato/datasources/yolov8.py +85 -0
  68. plato/models/__init__.py +0 -0
  69. plato/models/cnn_encoder.py +103 -0
  70. plato/models/dcgan.py +116 -0
  71. plato/models/general_multilayer.py +254 -0
  72. plato/models/huggingface.py +27 -0
  73. plato/models/lenet5.py +113 -0
  74. plato/models/multilayer.py +90 -0
  75. plato/models/multimodal/__init__.py +0 -0
  76. plato/models/multimodal/base_net.py +91 -0
  77. plato/models/multimodal/blending.py +142 -0
  78. plato/models/multimodal/fc_net.py +77 -0
  79. plato/models/multimodal/fusion_net.py +78 -0
  80. plato/models/multimodal/multimodal_module.py +152 -0
  81. plato/models/registry.py +99 -0
  82. plato/models/resnet.py +190 -0
  83. plato/models/torch_hub.py +19 -0
  84. plato/models/vgg.py +113 -0
  85. plato/models/vit.py +166 -0
  86. plato/models/yolov8.py +22 -0
  87. plato/processors/__init__.py +0 -0
  88. plato/processors/base.py +35 -0
  89. plato/processors/compress.py +46 -0
  90. plato/processors/decompress.py +48 -0
  91. plato/processors/feature.py +51 -0
  92. plato/processors/feature_additive_noise.py +48 -0
  93. plato/processors/feature_dequantize.py +34 -0
  94. plato/processors/feature_gaussian.py +17 -0
  95. plato/processors/feature_laplace.py +15 -0
  96. plato/processors/feature_quantize.py +34 -0
  97. plato/processors/feature_randomized_response.py +50 -0
  98. plato/processors/feature_unbatch.py +39 -0
  99. plato/processors/inbound_feature_tensors.py +39 -0
  100. plato/processors/model.py +55 -0
  101. plato/processors/model_compress.py +34 -0
  102. plato/processors/model_decompress.py +37 -0
  103. plato/processors/model_decrypt.py +41 -0
  104. plato/processors/model_deepcopy.py +21 -0
  105. plato/processors/model_dequantize.py +18 -0
  106. plato/processors/model_dequantize_qsgd.py +61 -0
  107. plato/processors/model_encrypt.py +43 -0
  108. plato/processors/model_quantize.py +18 -0
  109. plato/processors/model_quantize_qsgd.py +82 -0
  110. plato/processors/model_randomized_response.py +34 -0
  111. plato/processors/outbound_feature_ndarrays.py +38 -0
  112. plato/processors/pipeline.py +26 -0
  113. plato/processors/registry.py +124 -0
  114. plato/processors/structured_pruning.py +57 -0
  115. plato/processors/unstructured_pruning.py +73 -0
  116. plato/samplers/__init__.py +0 -0
  117. plato/samplers/all_inclusive.py +41 -0
  118. plato/samplers/base.py +31 -0
  119. plato/samplers/dirichlet.py +81 -0
  120. plato/samplers/distribution_noniid.py +132 -0
  121. plato/samplers/iid.py +53 -0
  122. plato/samplers/label_quantity_noniid.py +119 -0
  123. plato/samplers/mixed.py +44 -0
  124. plato/samplers/mixed_label_quantity_noniid.py +128 -0
  125. plato/samplers/modality_iid.py +42 -0
  126. plato/samplers/modality_quantity_noniid.py +56 -0
  127. plato/samplers/orthogonal.py +99 -0
  128. plato/samplers/registry.py +66 -0
  129. plato/samplers/sample_quantity_noniid.py +123 -0
  130. plato/samplers/sampler_utils.py +190 -0
  131. plato/servers/__init__.py +0 -0
  132. plato/servers/base.py +1395 -0
  133. plato/servers/fedavg.py +281 -0
  134. plato/servers/fedavg_cs.py +335 -0
  135. plato/servers/fedavg_gan.py +74 -0
  136. plato/servers/fedavg_he.py +106 -0
  137. plato/servers/fedavg_personalized.py +57 -0
  138. plato/servers/mistnet.py +67 -0
  139. plato/servers/registry.py +52 -0
  140. plato/servers/split_learning.py +109 -0
  141. plato/trainers/__init__.py +0 -0
  142. plato/trainers/base.py +99 -0
  143. plato/trainers/basic.py +649 -0
  144. plato/trainers/diff_privacy.py +178 -0
  145. plato/trainers/gan.py +330 -0
  146. plato/trainers/huggingface.py +173 -0
  147. plato/trainers/loss_criterion.py +70 -0
  148. plato/trainers/lr_schedulers.py +252 -0
  149. plato/trainers/optimizers.py +53 -0
  150. plato/trainers/pascal_voc.py +80 -0
  151. plato/trainers/registry.py +44 -0
  152. plato/trainers/self_supervised_learning.py +302 -0
  153. plato/trainers/split_learning.py +305 -0
  154. plato/trainers/tracking.py +96 -0
  155. plato/trainers/yolov8.py +41 -0
  156. plato/utils/__init__.py +0 -0
  157. plato/utils/count_parameters.py +30 -0
  158. plato/utils/csv_processor.py +26 -0
  159. plato/utils/data_loaders.py +148 -0
  160. plato/utils/decorators.py +24 -0
  161. plato/utils/fonts.py +23 -0
  162. plato/utils/homo_enc.py +187 -0
  163. plato/utils/reinforcement_learning/__init__.py +0 -0
  164. plato/utils/reinforcement_learning/policies/__init__.py +0 -0
  165. plato/utils/reinforcement_learning/policies/base.py +161 -0
  166. plato/utils/reinforcement_learning/policies/ddpg.py +75 -0
  167. plato/utils/reinforcement_learning/policies/registry.py +32 -0
  168. plato/utils/reinforcement_learning/policies/sac.py +343 -0
  169. plato/utils/reinforcement_learning/policies/td3.py +485 -0
  170. plato/utils/reinforcement_learning/rl_agent.py +142 -0
  171. plato/utils/reinforcement_learning/rl_server.py +113 -0
  172. plato/utils/rl_env.py +154 -0
  173. plato/utils/s3.py +141 -0
  174. plato/utils/trainer_utils.py +21 -0
  175. plato/utils/unary_encoding.py +47 -0
  176. plato_learn-1.1.dist-info/METADATA +35 -0
  177. plato_learn-1.1.dist-info/RECORD +179 -0
  178. plato_learn-1.1.dist-info/WHEEL +4 -0
  179. plato_learn-1.1.dist-info/licenses/LICENSE +201 -0
plato/datasources/datalib/gym_utils/gym_trim.py
@@ -0,0 +1,189 @@
+ """
+ Cut whole GYM videos into event clips (and event clips into sub-action
+ clips) based on the dataset annotations. Expected directory layout:
+
+ data_root = '../../../data/gym'
+ video_root = f'{data_root}/videos'
+ anno_root = f'{data_root}/annotations'
+ anno_file = f'{anno_root}/annotation.json'
+
+ event_anno_file = f'{anno_root}/event_annotation.json'
+ event_root = f'{data_root}/events'
+ """
+
+ import os
+ import os.path as osp
+ import subprocess
+
+ import mmcv
+
+
+ def trim_event(video_root, anno_file, event_anno_file, event_root):
+     """Trim each downloaded video into its annotated event clips."""
+     videos = set(os.listdir(video_root))
+     annotation = mmcv.load(anno_file)
+     event_annotation = {}
+
+     mmcv.mkdir_or_exist(event_root)
+
+     for anno_key, anno_value in annotation.items():
+         if anno_key + ".mp4" not in videos:
+             print(f"video {anno_key} has not been downloaded")
+             continue
+
+         video_path = osp.join(video_root, anno_key + ".mp4")
+
+         for event_id, event_anno in anno_value.items():
+             start_time, end_time = event_anno["timestamps"][0]
+             event_name = anno_key + "_" + event_id
+             output_filename = event_name + ".mp4"
+
+             # re-encode the video stream with libx264 and copy the audio stream
+             command = [
+                 "ffmpeg",
+                 "-i",
+                 '"%s"' % video_path,
+                 "-ss",
+                 str(start_time),
+                 "-t",
+                 str(end_time - start_time),
+                 "-c:v",
+                 "libx264",
+                 "-c:a",
+                 "copy",
+                 "-threads",
+                 "8",
+                 "-loglevel",
+                 "panic",
+                 '"%s"' % osp.join(event_root, output_filename),
+             ]
+             command = " ".join(command)
+             try:
+                 subprocess.check_output(command, shell=True, stderr=subprocess.STDOUT)
+             except subprocess.CalledProcessError:
+                 print(
+                     f"Trimming of the Event {event_name} of Video {anno_key} Failed",
+                     flush=True,
+                 )
+
+             segments = event_anno["segments"]
+             if segments is not None:
+                 event_annotation[event_name] = segments
+
+     mmcv.dump(event_annotation, event_anno_file)
+
+
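For reference, for an event annotated to span seconds 12.5 to 17.0, the joined command run through the shell would look like the following (paths illustrative, not taken from the package):

ffmpeg -i "<video_root>/VIDEO_ID.mp4" -ss 12.5 -t 4.5 -c:v libx264 -c:a copy -threads 8 -loglevel panic "<event_root>/VIDEO_ID_EVENT_ID.mp4"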
+ # data_root = '../../../data/gym'
+ # anno_root = f'{data_root}/annotations'
+
+ # event_anno_file = f'{anno_root}/event_annotation.json'
+ # event_root = f'{data_root}/events'
+ # subaction_root = f'{data_root}/subactions'
+
+
+ def trim_subsection(event_anno_file, event_root, subaction_root):
+     """Further trim each event clip into its sub-action clips."""
+     events = set(os.listdir(event_root))
+     annotation = mmcv.load(event_anno_file)
+
+     mmcv.mkdir_or_exist(subaction_root)
+
+     for anno_key, anno_value in annotation.items():
+         if anno_key + ".mp4" not in events:
+             print(
+                 f"video {anno_key[:11]} has not been downloaded "
+                 f"or the event clip {anno_key} has not been generated"
+             )
+             continue
+
+         video_path = osp.join(event_root, anno_key + ".mp4")
+
+         for subaction_id, subaction_anno in anno_value.items():
+             timestamps = subaction_anno["timestamps"]
+             start_time, end_time = timestamps[0][0], timestamps[-1][1]
+             subaction_name = anno_key + "_" + subaction_id
+             output_filename = subaction_name + ".mp4"
+
+             command = [
+                 "ffmpeg",
+                 "-i",
+                 '"%s"' % video_path,
+                 "-ss",
+                 str(start_time),
+                 "-t",
+                 str(end_time - start_time),
+                 "-c:v",
+                 "libx264",
+                 "-c:a",
+                 "copy",
+                 "-threads",
+                 "8",
+                 "-loglevel",
+                 "panic",
+                 '"%s"' % osp.join(subaction_root, output_filename),
+             ]
+             command = " ".join(command)
+             try:
+                 subprocess.check_output(command, shell=True, stderr=subprocess.STDOUT)
+             except subprocess.CalledProcessError:
+                 print(
+                     f"Trimming of the Subaction {subaction_name} of Event "
+                     f"{anno_key} Failed",
+                     flush=True,
+                 )
+
+
+ def generate_splits_list(data_root, annotation_root, frame_data_root):
+     """Generate the split lists based on the predefined annotation files."""
+     videos = set(os.listdir(data_root))
+     train_file_org = osp.join(annotation_root, "gym99_train_org.txt")
+     val_file_org = osp.join(annotation_root, "gym99_val_org.txt")
+     train_file = osp.join(annotation_root, "gym99_train.txt")
+     val_file = osp.join(annotation_root, "gym99_val.txt")
+     train_frame_file = osp.join(annotation_root, "gym99_train_rawframes.txt")
+     val_frame_file = osp.join(annotation_root, "gym99_val_rawframes.txt")
+
+     with open(train_file_org) as fin:
+         train_org = [x.strip().split() for x in fin.readlines()]
+     train = [x for x in train_org if x[0] + ".mp4" in videos]
+
+     if osp.exists(frame_data_root):
+         train_frames = []
+         for line in train:
+             # skip clips whose frames have not been extracted
+             if not osp.exists(osp.join(frame_data_root, line[0])):
+                 continue
+             # each frame index has an RGB image plus flow_x and flow_y images
+             length = len(os.listdir(osp.join(frame_data_root, line[0])))
+             train_frames.append([line[0], str(length // 3), line[1]])
+         train_frames = [" ".join(x) for x in train_frames]
+         with open(train_frame_file, "w") as fout:
+             fout.write("\n".join(train_frames))
+
+     train = [x[0] + ".mp4 " + x[1] for x in train]
+     with open(train_file, "w") as fout:
+         fout.write("\n".join(train))
+
+     with open(val_file_org) as fin:
+         val_org = [x.strip().split() for x in fin.readlines()]
+     val = [x for x in val_org if x[0] + ".mp4" in videos]
+
+     if osp.exists(frame_data_root):
+         val_frames = []
+         for line in val:
+             if not osp.exists(osp.join(frame_data_root, line[0])):
+                 continue
+             length = len(os.listdir(osp.join(frame_data_root, line[0])))
+             val_frames.append([line[0], str(length // 3), line[1]])
+         val_frames = [" ".join(x) for x in val_frames]
+         with open(val_frame_file, "w") as fout:
+             fout.write("\n".join(val_frames))
+
+     val = [x[0] + ".mp4 " + x[1] for x in val]
+     with open(val_file, "w") as fout:
+         fout.write("\n".join(val))
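A minimal driver for these helpers might look like the following sketch; the directory layout mirrors the module docstring and is an assumption, not something shipped with the wheel:

from plato.datasources.datalib.gym_utils import gym_trim

data_root = "../../../data/gym"  # assumed layout from the docstring
anno_root = f"{data_root}/annotations"

# cut the full videos into event clips, then the events into sub-action clips
gym_trim.trim_event(
    video_root=f"{data_root}/videos",
    anno_file=f"{anno_root}/annotation.json",
    event_anno_file=f"{anno_root}/event_annotation.json",
    event_root=f"{data_root}/events",
)
gym_trim.trim_subsection(
    event_anno_file=f"{anno_root}/event_annotation.json",
    event_root=f"{data_root}/events",
    subaction_root=f"{data_root}/subactions",
)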
plato/datasources/datalib/modality_data_anntation_tools.py
@@ -0,0 +1,163 @@
+ """
+ The class in this file builds on mmaction's tools/data/build_file_list.
+ """
+
+ import os
+ import glob
+ import json
+
+ from mmaction.tools.data.anno_txt2json import lines2dictlist
+ from mmaction.tools.data.parse_file_list import parse_directory
+
+ from plato.datasources.datalib.parse_datasets import build_list, obtain_data_splits_info
+
+
+ class GenerateMDataAnnotation:
+     """Generate the annotation file for the existing data modality."""
+
+     def __init__(
+         self,
+         data_src_dir,
+         data_annos_files_info,  # a dict that contains the data splits' file paths
+         data_format,  # 'rawframes', 'videos', 'audios', or 'audio_features'
+         out_path,
+         dataset_name,
+         data_dir_level=2,
+         rgb_prefix="img_",  # prefix of rgb frames
+         flow_x_prefix="flow_x_",  # prefix of flow x frames ('flow_x_' or 'x_')
+         flow_y_prefix="flow_y_",  # prefix of flow y frames ('flow_y_' or 'y_')
+         output_format="json",  # 'txt' or 'json'
+     ):
+         self.data_src_dir = data_src_dir
+         self.data_annos_files_info = data_annos_files_info
+         self.dataset_name = dataset_name
+         self.data_format = data_format
+         self.annotations_out_path = out_path
+         self.data_dir_level = data_dir_level
+         self.rgb_prefix = rgb_prefix
+         self.flow_x_prefix = flow_x_prefix
+         self.flow_y_prefix = flow_y_prefix
+
+         self.output_format = output_format
+
+         self.data_splits_info = None
+         self.frame_info = None
+
+     def read_data_splits_csv_info(self):
+         """Get the data splits information from the csv annotation files."""
+         self.data_splits_info = obtain_data_splits_info(
+             data_annos_files_info=self.data_annos_files_info,
+             data_dir_level=self.data_dir_level,
+             data_name=self.dataset_name,
+         )
+
+     def parse_levels_dir(self, data_src_dir):
+         """Parse a directory with one or two levels of nesting."""
+         data_dir_info = {}
+         if self.data_dir_level == 1:
+             # search a one-level directory
+             files_list = glob.glob(os.path.join(data_src_dir, "*"))
+         elif self.data_dir_level == 2:
+             # search a two-level directory
+             files_list = glob.glob(os.path.join(data_src_dir, "*", "*"))
+         else:
+             raise ValueError(f"level must be 1 or 2, but got {self.data_dir_level}")
+         for file in files_list:
+             file_path = os.path.relpath(file, data_src_dir)
+             # for video: video_id: (video_relative_path, -1, -1)
+             # for audio: audio_id: (audio_relative_path, -1, -1)
+             data_dir_info[os.path.splitext(file_path)[0]] = (file_path, -1, -1)
+
+         return data_dir_info
+
+     def parse_dir_files(self, split):
+         """Parse the data dir of one split to summarize the data information."""
+         # The annotations for audio spectrogram features are identical
+         # to those of rawframes.
+         split_format_data_src_dir = os.path.join(
+             self.data_src_dir, split, self.data_format
+         )
+         frame_info = None
+         if self.data_format == "rawframes":
+             frame_info = parse_directory(
+                 split_format_data_src_dir,
+                 rgb_prefix=self.rgb_prefix,
+                 flow_x_prefix=self.flow_x_prefix,
+                 flow_y_prefix=self.flow_y_prefix,
+                 level=self.data_dir_level,
+             )
+         elif self.data_format == "videos":
+             frame_info = self.parse_levels_dir(split_format_data_src_dir)
+         elif self.data_format in ["audio_features", "audios"]:
+             # the audio annotation list should be consistent with that of rawframes
+             rawframes_src_path = os.path.join(self.data_src_dir, split, "rawframes")
+             frame_info = parse_directory(
+                 rawframes_src_path,
+                 rgb_prefix=self.rgb_prefix,
+                 flow_x_prefix=self.flow_x_prefix,
+                 flow_y_prefix=self.flow_y_prefix,
+                 level=self.data_dir_level,
+             )
+         else:
+             raise NotImplementedError(
+                 "only rawframes, videos, audios, and audio_features are supported"
+             )
+         self.frame_info = frame_info
+
+     def get_anno_file_path(self, split_name):
+         """Get the annotation file path of one split."""
+         filename = f"{self.dataset_name}_{split_name}_list_{self.data_format}.txt"
+
+         if self.output_format == "json":
+             filename = filename.replace(".txt", ".json")
+
+         return os.path.join(self.annotations_out_path, filename)
+
+     def generate_data_splits_info_file(self, split_name):
+         """Generate the data split information and write it to file."""
+         self.parse_dir_files(split_name)
+
+         split_info = self.data_splits_info[split_name]
+
+         # (rgb_list, flow_list)
+         split_built_list = build_list(
+             split=split_info, frame_info=self.frame_info, shuffle=False
+         )
+
+         output_file_path = self.get_anno_file_path(split_name=split_name)
+
+         data_format = (
+             "rawframes"
+             if self.data_format in ["audio_features", "audios"]
+             else self.data_format
+         )
+
+         if self.output_format == "txt":
+             with open(output_file_path, "w") as anno_file:
+                 anno_file.writelines(split_built_list[0])
+         elif self.output_format == "json":
+             data_list = lines2dictlist(split_built_list[0], data_format)
+             if self.data_format in ["audios", "audio_features"]:
+
+                 def add_audio_path(elem):
+                     """Add an 'audio_path' key pointing to the audio file
+                     that corresponds to each video's frame directory."""
+                     if self.data_format == "audio_features":
+                         elem["audio_path"] = elem["frame_dir"] + ".npy"
+                     else:
+                         elem["audio_path"] = elem["frame_dir"] + ".wav"
+                     return elem
+
+                 data_list = [add_audio_path(elem) for elem in data_list]
+
+             with open(output_file_path, "w") as anno_file:
+                 json.dump(data_list, anno_file)
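A hedged usage sketch, assuming a Kinetics-style tree (<data_src_dir>/<split>/<data_format>/...); the root and csv paths below are placeholders, not files shipped with the package:

generator = GenerateMDataAnnotation(
    data_src_dir="data/kinetics700",  # hypothetical root
    data_annos_files_info={
        "train": "annotations/train.csv",  # placeholder paths
        "val": "annotations/val.csv",
        "test": "annotations/test.csv",
    },
    data_format="rawframes",
    out_path="data/kinetics700/annotations",
    dataset_name="kinetics700",
)
generator.read_data_splits_csv_info()
# with output_format="json" this writes kinetics700_train_list_rawframes.json
generator.generate_data_splits_info_file("train")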
plato/datasources/datalib/modality_extraction_base.py
@@ -0,0 +1,59 @@
+ """
+ Base class shared by the video extractor classes.
+ """
+
+ import glob
+ import logging
+ import os
+
+
+ class VideoExtractorBase:
+     """The base class for the video extractor classes."""
+
+     def __init__(
+         self, video_src_dir, dir_level=2, num_worker=8, video_ext="mp4", mixed_ext=False
+     ):
+         self.video_src_dir = video_src_dir
+         self.dir_level = dir_level
+         self.num_worker = num_worker
+         self.video_ext = video_ext  # supports 'avi', 'mp4', 'webm'
+         self.mixed_ext = mixed_ext
+
+         # assert self.dir_level == 2  # we insist on a two-level data directory
+
+         logging.info("Reading videos from folder: %s", self.video_src_dir)
+         if self.mixed_ext:
+             logging.info("Using mixed video extensions")
+             fullpath_list = glob.glob(self.video_src_dir + "/*" * self.dir_level)
+         else:
+             logging.info("Using a single video extension: %s", self.video_ext)
+             fullpath_list = glob.glob(
+                 self.video_src_dir + "/*" * self.dir_level + "." + self.video_ext
+             )
+
+         logging.info("Total number of videos found: %s", len(fullpath_list))
+
+         # the full path of each video, for example:
+         # ./data/Kinetics/Kinetics700/train/video/clay_pottery_making/RE6YNPccYK4.mp4
+         self.fullpath_list = fullpath_list
+
+         # each video item contains classname/filename,
+         # for example: clay_pottery_making/RE6YNPccYK4.mp4
+         self.videos_path_list = list(
+             map(
+                 lambda p: os.path.join(
+                     os.path.basename(os.path.dirname(p)), os.path.basename(p)
+                 ),
+                 self.fullpath_list,
+             )
+         )
+
+     def organize_modality_dir(self, src_dir, to_dir):
+         """Organize the data dir of the modality into two levels: classname/data_id."""
+         classes = os.listdir(src_dir)
+         for classname in classes:
+             new_dir = os.path.join(to_dir, classname)
+             if not os.path.isdir(new_dir):
+                 os.makedirs(new_dir)
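VideoExtractorBase only enumerates video paths; the extractor classes elsewhere in datalib build on it. A toy, hypothetical subclass illustrating the intended usage (the class name and directory are mine, not from the package):

class VideoLister(VideoExtractorBase):
    """Toy extractor that just prints the classname/filename items it found."""

    def run(self):
        for rel_path in self.videos_path_list:
            print(rel_path)


lister = VideoLister(video_src_dir="data/kinetics700/train/videos", dir_level=2)
lister.run()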
plato/datasources/datalib/parse_datasets.py
@@ -0,0 +1,212 @@
+ """This part of the code depends heavily on
+ tools/data/build_file_list.py provided by mmaction.
+ """
+
+ import csv
+ import random
+
+ from mmaction.tools.data.parse_file_list import (
+     parse_diving48_splits,
+     parse_hmdb51_split,
+     parse_jester_splits,
+     parse_mit_splits,
+     parse_mmit_splits,
+     parse_sthv1_splits,
+     parse_sthv2_splits,
+     parse_ucf101_splits,
+ )
+
+
+ def build_list(split, frame_info, shuffle=False):
+     """Build RGB and Flow file lists with a given split.
+
+     Args:
+         split (list): Split for which the file lists are generated.
+         frame_info (dict): Maps each item id to (path, rgb count, flow count);
+             both counts are -1 for video (rather than rawframe) items.
+         shuffle (bool): Whether to shuffle the resulting lists. Default: False.
+
+     Returns:
+         tuple[list, list]: (rgb_list, flow_list), rgb_list is the
+             generated file list for rgb, flow_list is the generated
+             file list for flow.
+     """
+     rgb_list, flow_list = list(), list()
+     for item in split:
+         if item[0] not in frame_info:
+             continue
+         elif frame_info[item[0]][1] > 0:
+             # rawframes
+             rgb_cnt = frame_info[item[0]][1]
+             flow_cnt = frame_info[item[0]][2]
+             if isinstance(item[1], int):
+                 rgb_list.append(f"{item[0]} {rgb_cnt} {item[1]}\n")
+                 flow_list.append(f"{item[0]} {flow_cnt} {item[1]}\n")
+             elif isinstance(item[1], list):
+                 # only for multi-label datasets like mmit
+                 rgb_list.append(
+                     f"{item[0]} {rgb_cnt} "
+                     + " ".join([str(digit) for digit in item[1]])
+                     + "\n"
+                 )
+                 flow_list.append(
+                     f"{item[0]} {flow_cnt} "
+                     + " ".join([str(digit) for digit in item[1]])
+                     + "\n"
+                 )
+             else:
+                 raise ValueError(
+                     "each split item should be "
+                     "[`video`(str), `label`(int) | `labels`(list[int])]"
+                 )
+         else:
+             # videos
+             if isinstance(item[1], int):
+                 rgb_list.append(f"{frame_info[item[0]][0]} {item[1]}\n")
+                 flow_list.append(f"{frame_info[item[0]][0]} {item[1]}\n")
+             elif isinstance(item[1], list):
+                 # only for multi-label datasets like mmit
+                 rgb_list.append(
+                     f"{frame_info[item[0]][0]} "
+                     + " ".join([str(digit) for digit in item[1]])
+                     + "\n"
+                 )
+                 flow_list.append(
+                     f"{frame_info[item[0]][0]} "
+                     + " ".join([str(digit) for digit in item[1]])
+                     + "\n"
+                 )
+             else:
+                 raise ValueError(
+                     "each split item should be "
+                     "[`video`(str), `label`(int) | `labels`(list[int])]"
+                 )
+     if shuffle:
+         random.shuffle(rgb_list)
+         random.shuffle(flow_list)
+     return (rgb_list, flow_list)
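To make the two branches concrete, here is a small hand-made example (the inputs are illustrative, not package data); rawframe entries carry positive frame counts, video entries carry (path, -1, -1):

frame_info = {
    "cls/clip_a": ("cls/clip_a", 300, 300),    # rawframes: (dir, rgb count, flow count)
    "cls/clip_b": ("cls/clip_b.mp4", -1, -1),  # videos
}
split = [("cls/clip_a", 3), ("cls/clip_b", 7)]

rgb_list, flow_list = build_list(split, frame_info)
# rgb_list  == ["cls/clip_a 300 3\n", "cls/clip_b.mp4 7\n"]
# flow_list == ["cls/clip_a 300 3\n", "cls/clip_b.mp4 7\n"]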
+
+
+ def parse_kinetics_splits(kinetics_annotation_files_info, level, dataset_name):
+     """Parse the Kinetics dataset into "train", "val", and "test" splits.
+
+     Args:
+         kinetics_annotation_files_info (dict): The file paths of the original
+             annotation files, which should be the "*.csv" files provided on
+             the official website. For example: {"train": ""}
+         level (int): Directory level of data. 1 for a single-level directory,
+             2 for a two-level directory.
+         dataset_name (str): The version of Kinetics to be parsed; choices are
+             "kinetics400", "kinetics600", and "kinetics700".
+
+     Returns:
+         dict: "train", "val", and "test" splits of Kinetics.
+     """
+
+     def convert_label(label_str, keep_whitespaces=False):
+         """Convert a label name to a formal string.
+
+         Remove redundant '"' and convert whitespace to '_'.
+
+         Args:
+             label_str (str): String to be converted.
+             keep_whitespaces (bool): Whether to keep whitespace. Default: False.
+
+         Returns:
+             str: Converted string.
+         """
+         if not keep_whitespaces:
+             return label_str.replace('"', "").replace(" ", "_")
+         return label_str.replace('"', "")
+
+     def line_to_map(line_str, test=False):
+         """Map a parsed csv line to a video and its label.
+
+         Args:
+             line_str (list): A parsed line from a Kinetics csv file.
+             test (bool): Whether the line comes from the test annotation file.
+
+         Returns:
+             tuple[str, int]: (video, label), video is the video id,
+                 label is the video label.
+         """
+         if test:  # e.g. ['---v8pgm1eQ', '0', '10', 'test']
+             video = f"{line_str[0]}_{int(float(line_str[1])):06d}_{int(float(line_str[2])):06d}"
+             label = -1  # label unknown for the test split
+             return video, label
+         # e.g. ['clay pottery making', '---0dWlqevI', '19', '29', 'train']
+         video = f"{line_str[1]}_{int(float(line_str[2])):06d}_{int(float(line_str[3])):06d}"
+         if level == 2:
+             video = f"{convert_label(line_str[0])}/{video}"
+         else:
+             assert level == 1
+         label = class_mapping[convert_label(line_str[0])]
+         return video, label
+
+     train_file = kinetics_annotation_files_info["train"]
+     test_file = kinetics_annotation_files_info["test"]
+     val_file = kinetics_annotation_files_info["val"]
+
+     with open(train_file) as fin:
+         csv_reader = csv.reader(fin)
+         next(csv_reader)  # skip the header line
+         labels_sorted = sorted({convert_label(row[0]) for row in csv_reader})
+     class_mapping = {label: i for i, label in enumerate(labels_sorted)}
+
+     with open(train_file) as fin:
+         csv_reader = csv.reader(fin)
+         next(csv_reader)
+         train_list = [line_to_map(x) for x in csv_reader]
+
+     with open(val_file) as fin:
+         csv_reader = csv.reader(fin)
+         next(csv_reader)
+         val_list = [line_to_map(x) for x in csv_reader]
+
+     with open(test_file) as fin:
+         csv_reader = csv.reader(fin)
+         next(csv_reader)
+         test_list = [line_to_map(x, test=True) for x in csv_reader]
+
+     return {"train": train_list, "test": test_list, "val": val_list}
+
+
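To illustrate line_to_map, the made-up train row from the comment above maps as follows when level == 2 (values are hypothetical):

row = ["clay pottery making", "---0dWlqevI", "19", "29", "train"]
# line_to_map(row) returns:
#   video == "clay_pottery_making/---0dWlqevI_000019_000029"
#   label == class_mapping["clay_pottery_making"]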
+ def obtain_data_splits_info(
+     data_annos_files_info,  # a dict containing the original splits' file paths
+     data_dir_level=2,
+     data_name="kinetics700",
+ ):
+     """Parse the raw annotation files to obtain the splits' information."""
+     if data_name == "ucf101":
+         splits = parse_ucf101_splits(data_dir_level)
+     elif data_name == "sthv1":
+         splits = parse_sthv1_splits(data_dir_level)
+     elif data_name == "sthv2":
+         splits = parse_sthv2_splits(data_dir_level)
+     elif data_name == "mit":
+         splits = parse_mit_splits()
+     elif data_name == "mmit":
+         splits = parse_mmit_splits()
+     elif data_name in ["kinetics400", "kinetics600", "kinetics700"]:
+         splits = parse_kinetics_splits(
+             data_annos_files_info, data_dir_level, data_name
+         )
+     elif data_name == "hmdb51":
+         splits = parse_hmdb51_split(data_dir_level)
+     elif data_name == "jester":
+         splits = parse_jester_splits(data_dir_level)
+     elif data_name == "diving48":
+         splits = parse_diving48_splits()
+     else:
+         raise ValueError(
+             f"Supported datasets are 'ucf101', 'sthv1', 'sthv2', 'jester', "
+             f"'mmit', 'mit', 'hmdb51', 'diving48', 'kinetics400', "
+             f"'kinetics600', and 'kinetics700', but got {data_name}"
+         )
+
+     return splits
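Putting the two entry points together, a hypothetical Kinetics-700 run (csv paths are placeholders, and frame_info would be built as in the example after build_list above):

splits = obtain_data_splits_info(
    data_annos_files_info={
        "train": "annotations/kinetics700_train.csv",  # placeholder paths
        "val": "annotations/kinetics700_val.csv",
        "test": "annotations/kinetics700_test.csv",
    },
    data_dir_level=2,
    data_name="kinetics700",
)
rgb_train, flow_train = build_list(splits["train"], frame_info, shuffle=False)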