plato-learn 1.1 (plato_learn-1.1-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- plato/__init__.py +1 -0
- plato/algorithms/__init__.py +0 -0
- plato/algorithms/base.py +45 -0
- plato/algorithms/fedavg.py +48 -0
- plato/algorithms/fedavg_gan.py +79 -0
- plato/algorithms/fedavg_personalized.py +48 -0
- plato/algorithms/mistnet.py +52 -0
- plato/algorithms/registry.py +39 -0
- plato/algorithms/split_learning.py +89 -0
- plato/callbacks/__init__.py +0 -0
- plato/callbacks/client.py +56 -0
- plato/callbacks/handler.py +78 -0
- plato/callbacks/server.py +139 -0
- plato/callbacks/trainer.py +124 -0
- plato/client.py +67 -0
- plato/clients/__init__.py +0 -0
- plato/clients/base.py +467 -0
- plato/clients/edge.py +103 -0
- plato/clients/fedavg_personalized.py +40 -0
- plato/clients/mistnet.py +49 -0
- plato/clients/registry.py +43 -0
- plato/clients/self_supervised_learning.py +51 -0
- plato/clients/simple.py +218 -0
- plato/clients/split_learning.py +150 -0
- plato/config.py +339 -0
- plato/datasources/__init__.py +0 -0
- plato/datasources/base.py +123 -0
- plato/datasources/celeba.py +150 -0
- plato/datasources/cifar10.py +87 -0
- plato/datasources/cifar100.py +61 -0
- plato/datasources/cinic10.py +62 -0
- plato/datasources/coco.py +119 -0
- plato/datasources/datalib/__init__.py +0 -0
- plato/datasources/datalib/audio_extraction_tools.py +137 -0
- plato/datasources/datalib/data_utils.py +124 -0
- plato/datasources/datalib/flickr30kE_utils.py +336 -0
- plato/datasources/datalib/frames_extraction_tools.py +254 -0
- plato/datasources/datalib/gym_utils/__init__.py +0 -0
- plato/datasources/datalib/gym_utils/gym_trim.py +189 -0
- plato/datasources/datalib/modality_data_anntation_tools.py +163 -0
- plato/datasources/datalib/modality_extraction_base.py +59 -0
- plato/datasources/datalib/parse_datasets.py +212 -0
- plato/datasources/datalib/refer_utils/__init__.py +0 -0
- plato/datasources/datalib/refer_utils/referitgame_utils.py +237 -0
- plato/datasources/datalib/tiny_data_tools.py +81 -0
- plato/datasources/datalib/video_transform.py +79 -0
- plato/datasources/emnist.py +64 -0
- plato/datasources/fashion_mnist.py +41 -0
- plato/datasources/feature.py +24 -0
- plato/datasources/feature_dataset.py +15 -0
- plato/datasources/femnist.py +141 -0
- plato/datasources/flickr30k_entities.py +362 -0
- plato/datasources/gym.py +431 -0
- plato/datasources/huggingface.py +165 -0
- plato/datasources/kinetics.py +568 -0
- plato/datasources/mnist.py +44 -0
- plato/datasources/multimodal_base.py +328 -0
- plato/datasources/pascal_voc.py +56 -0
- plato/datasources/purchase.py +94 -0
- plato/datasources/qoenflx.py +127 -0
- plato/datasources/referitgame.py +330 -0
- plato/datasources/registry.py +119 -0
- plato/datasources/self_supervised_learning.py +98 -0
- plato/datasources/stl10.py +103 -0
- plato/datasources/texas.py +94 -0
- plato/datasources/tiny_imagenet.py +64 -0
- plato/datasources/yolov8.py +85 -0
- plato/models/__init__.py +0 -0
- plato/models/cnn_encoder.py +103 -0
- plato/models/dcgan.py +116 -0
- plato/models/general_multilayer.py +254 -0
- plato/models/huggingface.py +27 -0
- plato/models/lenet5.py +113 -0
- plato/models/multilayer.py +90 -0
- plato/models/multimodal/__init__.py +0 -0
- plato/models/multimodal/base_net.py +91 -0
- plato/models/multimodal/blending.py +142 -0
- plato/models/multimodal/fc_net.py +77 -0
- plato/models/multimodal/fusion_net.py +78 -0
- plato/models/multimodal/multimodal_module.py +152 -0
- plato/models/registry.py +99 -0
- plato/models/resnet.py +190 -0
- plato/models/torch_hub.py +19 -0
- plato/models/vgg.py +113 -0
- plato/models/vit.py +166 -0
- plato/models/yolov8.py +22 -0
- plato/processors/__init__.py +0 -0
- plato/processors/base.py +35 -0
- plato/processors/compress.py +46 -0
- plato/processors/decompress.py +48 -0
- plato/processors/feature.py +51 -0
- plato/processors/feature_additive_noise.py +48 -0
- plato/processors/feature_dequantize.py +34 -0
- plato/processors/feature_gaussian.py +17 -0
- plato/processors/feature_laplace.py +15 -0
- plato/processors/feature_quantize.py +34 -0
- plato/processors/feature_randomized_response.py +50 -0
- plato/processors/feature_unbatch.py +39 -0
- plato/processors/inbound_feature_tensors.py +39 -0
- plato/processors/model.py +55 -0
- plato/processors/model_compress.py +34 -0
- plato/processors/model_decompress.py +37 -0
- plato/processors/model_decrypt.py +41 -0
- plato/processors/model_deepcopy.py +21 -0
- plato/processors/model_dequantize.py +18 -0
- plato/processors/model_dequantize_qsgd.py +61 -0
- plato/processors/model_encrypt.py +43 -0
- plato/processors/model_quantize.py +18 -0
- plato/processors/model_quantize_qsgd.py +82 -0
- plato/processors/model_randomized_response.py +34 -0
- plato/processors/outbound_feature_ndarrays.py +38 -0
- plato/processors/pipeline.py +26 -0
- plato/processors/registry.py +124 -0
- plato/processors/structured_pruning.py +57 -0
- plato/processors/unstructured_pruning.py +73 -0
- plato/samplers/__init__.py +0 -0
- plato/samplers/all_inclusive.py +41 -0
- plato/samplers/base.py +31 -0
- plato/samplers/dirichlet.py +81 -0
- plato/samplers/distribution_noniid.py +132 -0
- plato/samplers/iid.py +53 -0
- plato/samplers/label_quantity_noniid.py +119 -0
- plato/samplers/mixed.py +44 -0
- plato/samplers/mixed_label_quantity_noniid.py +128 -0
- plato/samplers/modality_iid.py +42 -0
- plato/samplers/modality_quantity_noniid.py +56 -0
- plato/samplers/orthogonal.py +99 -0
- plato/samplers/registry.py +66 -0
- plato/samplers/sample_quantity_noniid.py +123 -0
- plato/samplers/sampler_utils.py +190 -0
- plato/servers/__init__.py +0 -0
- plato/servers/base.py +1395 -0
- plato/servers/fedavg.py +281 -0
- plato/servers/fedavg_cs.py +335 -0
- plato/servers/fedavg_gan.py +74 -0
- plato/servers/fedavg_he.py +106 -0
- plato/servers/fedavg_personalized.py +57 -0
- plato/servers/mistnet.py +67 -0
- plato/servers/registry.py +52 -0
- plato/servers/split_learning.py +109 -0
- plato/trainers/__init__.py +0 -0
- plato/trainers/base.py +99 -0
- plato/trainers/basic.py +649 -0
- plato/trainers/diff_privacy.py +178 -0
- plato/trainers/gan.py +330 -0
- plato/trainers/huggingface.py +173 -0
- plato/trainers/loss_criterion.py +70 -0
- plato/trainers/lr_schedulers.py +252 -0
- plato/trainers/optimizers.py +53 -0
- plato/trainers/pascal_voc.py +80 -0
- plato/trainers/registry.py +44 -0
- plato/trainers/self_supervised_learning.py +302 -0
- plato/trainers/split_learning.py +305 -0
- plato/trainers/tracking.py +96 -0
- plato/trainers/yolov8.py +41 -0
- plato/utils/__init__.py +0 -0
- plato/utils/count_parameters.py +30 -0
- plato/utils/csv_processor.py +26 -0
- plato/utils/data_loaders.py +148 -0
- plato/utils/decorators.py +24 -0
- plato/utils/fonts.py +23 -0
- plato/utils/homo_enc.py +187 -0
- plato/utils/reinforcement_learning/__init__.py +0 -0
- plato/utils/reinforcement_learning/policies/__init__.py +0 -0
- plato/utils/reinforcement_learning/policies/base.py +161 -0
- plato/utils/reinforcement_learning/policies/ddpg.py +75 -0
- plato/utils/reinforcement_learning/policies/registry.py +32 -0
- plato/utils/reinforcement_learning/policies/sac.py +343 -0
- plato/utils/reinforcement_learning/policies/td3.py +485 -0
- plato/utils/reinforcement_learning/rl_agent.py +142 -0
- plato/utils/reinforcement_learning/rl_server.py +113 -0
- plato/utils/rl_env.py +154 -0
- plato/utils/s3.py +141 -0
- plato/utils/trainer_utils.py +21 -0
- plato/utils/unary_encoding.py +47 -0
- plato_learn-1.1.dist-info/METADATA +35 -0
- plato_learn-1.1.dist-info/RECORD +179 -0
- plato_learn-1.1.dist-info/WHEEL +4 -0
- plato_learn-1.1.dist-info/licenses/LICENSE +201 -0
plato/datasources/datalib/gym_utils/gym_trim.py
@@ -0,0 +1,189 @@
"""

Cut the whole video based on the requirements

data_root = '../../../data/gym'
video_root = f'{data_root}/videos'
anno_root = f'{data_root}/annotations'
anno_file = f'{anno_root}/annotation.json'

event_anno_file = f'{anno_root}/event_annotation.json'
event_root = f'{data_root}/events'

"""

import os
import os.path as osp
import subprocess

import mmcv


def trim_event(video_root, anno_file, event_anno_file, event_root):
    """Trim the videos into many events"""
    videos = os.listdir(video_root)
    videos = set(videos)
    annotation = mmcv.load(anno_file)
    event_annotation = {}

    mmcv.mkdir_or_exist(event_root)

    for anno_key, anno_value in annotation.items():
        if anno_key + ".mp4" not in videos:
            print(f"video {anno_key} has not been downloaded")
            continue

        video_path = osp.join(video_root, anno_key + ".mp4")

        for event_id, event_anno in anno_value.items():
            timestamps = event_anno["timestamps"][0]
            start_time, end_time = timestamps
            event_name = anno_key + "_" + event_id

            output_filename = event_name + ".mp4"

            command = [
                "ffmpeg",
                "-i",
                '"%s"' % video_path,
                "-ss",
                str(start_time),
                "-t",
                str(end_time - start_time),
                "-c:v",
                "libx264",
                "-c:a",
                "copy",
                "-threads",
                "8",
                "-loglevel",
                "panic",
                '"%s"' % osp.join(event_root, output_filename),
            ]
            command = " ".join(command)
            try:
                subprocess.check_output(command, shell=True, stderr=subprocess.STDOUT)
            except subprocess.CalledProcessError:
                print(
                    f"Trimming of the Event {event_name} of Video {anno_key} Failed",
                    flush=True,
                )

            segments = event_anno["segments"]
            if segments is not None:
                event_annotation[event_name] = segments

    mmcv.dump(event_annotation, event_anno_file)


# data_root = '../../../data/gym'
# anno_root = f'{data_root}/annotations'

# event_anno_file = f'{anno_root}/event_annotation.json'
# event_root = f'{data_root}/events'
# subaction_root = f'{data_root}/subactions'


def trim_subsection(event_anno_file, event_root, subaction_root):
    """Further trim the event into several subsections"""
    events = os.listdir(event_root)
    events = set(events)
    annotation = mmcv.load(event_anno_file)

    mmcv.mkdir_or_exist(subaction_root)

    for anno_key, anno_value in annotation.items():
        if anno_key + ".mp4" not in events:
            print(
                f"video {anno_key[:11]} has not been downloaded "
                f"or the event clip {anno_key} not generated"
            )
            continue

        video_path = osp.join(event_root, anno_key + ".mp4")

        for subaction_id, subaction_anno in anno_value.items():
            timestamps = subaction_anno["timestamps"]
            start_time, end_time = timestamps[0][0], timestamps[-1][1]
            subaction_name = anno_key + "_" + subaction_id

            output_filename = subaction_name + ".mp4"

            command = [
                "ffmpeg",
                "-i",
                '"%s"' % video_path,
                "-ss",
                str(start_time),
                "-t",
                str(end_time - start_time),
                "-c:v",
                "libx264",
                "-c:a",
                "copy",
                "-threads",
                "8",
                "-loglevel",
                "panic",
                '"%s"' % osp.join(subaction_root, output_filename),
            ]
            command = " ".join(command)
            try:
                subprocess.check_output(command, shell=True, stderr=subprocess.STDOUT)
            except subprocess.CalledProcessError:
                print(
                    f"Trimming of the Subaction {subaction_name} of Event "
                    f"{anno_key} Failed",
                    flush=True,
                )


def generate_splits_list(data_root, annotation_root, frame_data_root):
    """Generate the split information based on the predefined files"""
    videos = os.listdir(data_root)
    videos = set(videos)
    train_file_org = osp.join(annotation_root, "gym99_train_org.txt")
    val_file_org = osp.join(annotation_root, "gym99_val_org.txt")
    train_file = osp.join(annotation_root, "gym99_train.txt")
    val_file = osp.join(annotation_root, "gym99_val.txt")
    train_frame_file = osp.join(annotation_root, "gym99_train_rawframes.txt")
    val_frame_file = osp.join(annotation_root, "gym99_val_rawframes.txt")

    train_org = open(train_file_org).readlines()

    train_org = [x.strip().split() for x in train_org]

    train = [x for x in train_org if x[0] + ".mp4" in videos]

    if osp.exists(frame_data_root):
        train_frames = []
        for line in train:
            length = len(os.listdir(osp.join(frame_data_root, line[0])))
            train_frames.append([line[0], str(length // 3), line[1]])
        train_frames = [" ".join(x) for x in train_frames]
        with open(train_frame_file, "w") as fout:
            fout.write("\n".join(train_frames))

    train = [x[0] + ".mp4 " + x[1] for x in train]

    with open(train_file, "w") as fout:
        fout.write("\n".join(train))

    val_org = open(val_file_org).readlines()
    val_org = [x.strip().split() for x in val_org]
    val = [x for x in val_org if x[0] + ".mp4" in videos]

    if osp.exists(frame_data_root):
        val_frames = []
        for line in val:
            if not os.path.exists(osp.join(frame_data_root, line[0])):
                continue
            length = len(os.listdir(osp.join(frame_data_root, line[0])))
            val_frames.append([line[0], str(length // 3), line[1]])
        val_frames = [" ".join(x) for x in val_frames]
        with open(val_frame_file, "w") as fout:
            fout.write("\n".join(val_frames))

    val = [x[0] + ".mp4 " + x[1] for x in val]
    with open(val_file, "w") as fout:
        fout.write("\n".join(val))
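Usage sketch (illustrative, not part of the package diff): a minimal driver chaining the three helpers above for the GYM layout named in the module docstring; it assumes ffmpeg, mmcv and the downloaded GYM annotation files are available, and all paths are placeholders.

from plato.datasources.datalib.gym_utils.gym_trim import (
    trim_event,
    trim_subsection,
    generate_splits_list,
)

data_root = "../../../data/gym"
anno_root = f"{data_root}/annotations"

# 1) cut the whole videos into event clips and record their sub-action segments
trim_event(
    video_root=f"{data_root}/videos",
    anno_file=f"{anno_root}/annotation.json",
    event_anno_file=f"{anno_root}/event_annotation.json",
    event_root=f"{data_root}/events",
)

# 2) cut the event clips further into sub-action clips
trim_subsection(
    event_anno_file=f"{anno_root}/event_annotation.json",
    event_root=f"{data_root}/events",
    subaction_root=f"{data_root}/subactions",
)

# 3) write the gym99 train/val split files for the clips that actually exist
generate_splits_list(
    data_root=f"{data_root}/subactions",
    annotation_root=anno_root,
    frame_data_root=f"{data_root}/subaction_frames",
)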
plato/datasources/datalib/modality_data_anntation_tools.py
@@ -0,0 +1,163 @@
"""
The class in this file builds on mmaction's tools/data/build_file_list.


"""

import os
import glob
import json

from mmaction.tools.data.anno_txt2json import lines2dictlist
from mmaction.tools.data.parse_file_list import parse_directory

from plato.datasources.datalib.parse_datasets import build_list, obtain_data_splits_info


class GenerateMDataAnnotation(object):
    """Generate the annotation file for the existing data modality"""

    def __init__(
        self,
        data_src_dir,
        data_annos_files_info,  # a dict that contains the data splits' file path
        data_format,  # 'rawframes', 'videos'
        out_path,
        dataset_name,
        data_dir_level=2,
        rgb_prefix="img_",  # prefix of rgb frames
        flow_x_prefix="flow_x_",  # prefix of flow x frames [flow_x_ or x_]
        flow_y_prefix="flow_y_",  # prefix of flow y frames [flow_y_ or y_]
        # shuffle=False,  # whether to shuffle the file list
        output_format="json",
    ):  # txt or json
        self.data_src_dir = data_src_dir
        self.data_annos_files_info = data_annos_files_info
        self.dataset_name = dataset_name
        self.data_format = data_format
        self.annotations_out_path = out_path
        self.data_dir_level = data_dir_level
        self.rgb_prefix = rgb_prefix
        self.flow_x_prefix = flow_x_prefix
        self.flow_y_prefix = flow_y_prefix

        self.output_format = output_format

        self.data_splits_info = None
        self.frame_info = None

    def read_data_splits_csv_info(self):
        """Get the data splits information from the csv annotation files"""
        self.data_splits_info = obtain_data_splits_info(
            data_annos_files_info=self.data_annos_files_info,
            data_fir_level=2,
            data_name=self.dataset_name,
        )

    def parse_levels_dir(self, data_src_dir):
        """Parse a data dir with one or two directory levels."""
        data_dir_info = {}
        if self.data_dir_level == 1:
            # search for one-level directory
            files_list = glob.glob(os.path.join(data_src_dir, "*"))
        elif self.data_dir_level == 2:
            # search for two-level directory
            files_list = glob.glob(os.path.join(data_src_dir, "*", "*"))
        else:
            raise ValueError(f"level must be 1 or 2, but got {self.data_dir_level}")
        for file in files_list:
            file_path = os.path.relpath(file, data_src_dir)
            # for video: video_id: (video_relative_path, -1, -1)
            # for audio: audio_id: (audio_relative_path, -1, -1)
            data_dir_info[os.path.splitext(file_path)[0]] = (file_path, -1, -1)

        return data_dir_info

    def parse_dir_files(self, split):
        """Parse the dir to summarize the data information"""
        # The annotations for audio spectrogram features are identical to those of rawframes.

        # data_format = "rawframes" if self.data_format == "audio_features" else self.data_format
        # split_format_data_src_dir = os.path.join(self.data_src_dir, split,
        #                                          data_format)

        split_format_data_src_dir = os.path.join(
            self.data_src_dir, split, self.data_format
        )
        frame_info = None
        if self.data_format == "rawframes":
            frame_info = parse_directory(
                split_format_data_src_dir,
                rgb_prefix=self.rgb_prefix,
                flow_x_prefix=self.flow_x_prefix,
                flow_y_prefix=self.flow_y_prefix,
                level=self.data_dir_level,
            )
        elif self.data_format == "videos":
            frame_info = self.parse_levels_dir(split_format_data_src_dir)
        elif self.data_format in ["audio_features", "audios"]:
            # the audio anno list should be consistent with that of rawframes
            rawframes_src_path = os.path.join(self.data_src_dir, split, "rawframes")
            frame_info = parse_directory(
                rawframes_src_path,
                rgb_prefix=self.rgb_prefix,
                flow_x_prefix=self.flow_x_prefix,
                flow_y_prefix=self.flow_y_prefix,
                level=self.data_dir_level,
            )
        else:
            raise NotImplementedError(f"unsupported data format: {self.data_format}")
        self.frame_info = frame_info

    def get_anno_file_path(self, split_name):
        """Get the annotation file path"""
        filename = f"{self.dataset_name}_{split_name}_list_{self.data_format}.txt"

        if self.output_format == "json":
            filename = filename.replace(".txt", ".json")

        output_anno_file_path = os.path.join(self.annotations_out_path, filename)

        return output_anno_file_path

    def generate_data_splits_info_file(self, split_name):
        """Generate the data split information and write the info to file"""
        self.parse_dir_files(split_name)

        split_info = self.data_splits_info[split_name]

        # (rgb_list, flow_list)
        split_built_list = build_list(
            split=split_info, frame_info=self.frame_info, shuffle=False
        )

        output_file_path = self.get_anno_file_path(split_name=split_name)

        data_format = (
            "rawframes"
            if self.data_format in ["audio_features", "audios"]
            else self.data_format
        )

        if self.output_format == "txt":
            with open(output_file_path, "w") as anno_file:
                anno_file.writelines(split_built_list[0])
        elif self.output_format == "json":
            data_list = lines2dictlist(split_built_list[0], data_format)
            if self.data_format in ["audios", "audio_features"]:

                def change_title_func(elem):
                    """Add the path of the corresponding audio file to each entry."""
                    # add the audio_path key whose value is the
                    # path of the corresponding audio file
                    if self.data_format == "audio_features":
                        elem["audio_path"] = elem["frame_dir"] + ".npy"
                    else:
                        elem["audio_path"] = elem["frame_dir"] + ".wav"

                    return elem

                data_list = [change_title_func(elem) for elem in data_list]

            with open(output_file_path, "w") as anno_file:
                json.dump(data_list, anno_file)
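Usage sketch (illustrative, not from the package): how GenerateMDataAnnotation might be driven for a Kinetics-style layout; the directory names and csv paths are placeholders.

from plato.datasources.datalib.modality_data_anntation_tools import (
    GenerateMDataAnnotation,
)

anno_generator = GenerateMDataAnnotation(
    data_src_dir="data/kinetics700",
    data_annos_files_info={
        "train": "data/kinetics700/annotations/train.csv",
        "val": "data/kinetics700/annotations/val.csv",
        "test": "data/kinetics700/annotations/test.csv",
    },
    data_format="videos",  # or "rawframes", "audios", "audio_features"
    out_path="data/kinetics700/annotations",
    dataset_name="kinetics700",
    output_format="json",
)

# parse the official csv splits, then scan data/kinetics700/train/videos and
# write data/kinetics700/annotations/kinetics700_train_list_videos.json
anno_generator.read_data_splits_csv_info()
anno_generator.generate_data_splits_info_file("train")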
plato/datasources/datalib/modality_extraction_base.py
@@ -0,0 +1,59 @@
"""
Classes for parsing the structured files for the data

"""

import glob
import logging
import os


class VideoExtractorBase:
    """The base class for the following video extractor classes"""

    def __init__(
        self, video_src_dir, dir_level=2, num_worker=8, video_ext="mp4", mixed_ext=False
    ):
        self.video_src_dir = video_src_dir
        self.dir_level = dir_level
        self.num_worker = num_worker
        self.video_ext = video_ext  # support 'avi', 'mp4', 'webm'
        self.mixed_ext = mixed_ext

        # assert self.dir_level == 2  # we insist two-level data directory setting

        logging.info("Reading videos from folder: %s", self.video_src_dir)
        if self.mixed_ext:
            logging.info("Using videos with mixed extensions")
            fullpath_list = glob.glob(self.video_src_dir + "/*" * self.dir_level)
        else:
            logging.info("Using videos with the extension: %s", self.video_ext)
            fullpath_list = glob.glob(
                self.video_src_dir + "/*" * self.dir_level + "." + self.video_ext
            )

        logging.info("Total number of videos found: %s", len(fullpath_list))

        # fullpath_list holds the full path of each video, for example:
        # ./data/Kinetics/Kinetics700/train/video/clay_pottery_making/RE6YNPccYK4.mp4
        self.fullpath_list = fullpath_list

        # Video item containing the class-relative video path,
        # for example: clay_pottery_making/RE6YNPccYK4.mp4
        self.videos_path_list = list(
            map(
                lambda p: os.path.join(
                    os.path.basename(os.path.dirname(p)), os.path.basename(p)
                ),
                self.fullpath_list,
            )
        )

    def organize_modality_dir(self, src_dir, to_dir):
        """Organize the data dir of the modality into two levels - classname/data_id"""

        classes = os.listdir(src_dir)
        for classname in classes:
            new_dir = os.path.join(to_dir, classname)
            if not os.path.isdir(new_dir):
                os.makedirs(new_dir)
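The docstring above presents VideoExtractorBase as a base class for video extractor classes; the following is a hedged sketch of a possible subclass (the AudioExtractor name, its ffmpeg command and the paths are illustrative, not part of the package).

import os
import subprocess

from plato.datasources.datalib.modality_extraction_base import VideoExtractorBase


class AudioExtractor(VideoExtractorBase):
    """Illustrative subclass: dump the audio track of every discovered video."""

    def extract_all(self, to_dir):
        # self.videos_path_list holds classname/video_id.<ext> entries collected
        # by the base class constructor
        for rel_path in self.videos_path_list:
            src = os.path.join(self.video_src_dir, rel_path)
            dst = os.path.join(to_dir, os.path.splitext(rel_path)[0] + ".wav")
            os.makedirs(os.path.dirname(dst), exist_ok=True)
            subprocess.run(["ffmpeg", "-i", src, "-vn", dst], check=False)


extractor = AudioExtractor("data/kinetics700/train/videos", dir_level=2)
extractor.extract_all("data/kinetics700/train/audios")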
plato/datasources/datalib/parse_datasets.py
@@ -0,0 +1,212 @@
"""This part of the code heavily depends on the
tools/data/build_file_list.py provided by mmaction.

"""

import csv
import random

from mmaction.tools.data.parse_file_list import (
    parse_diving48_splits,
    parse_hmdb51_split,
    parse_jester_splits,
    parse_mit_splits,
    parse_mmit_splits,
    parse_sthv1_splits,
    parse_sthv2_splits,
    parse_ucf101_splits,
)


def build_list(split, frame_info, shuffle=False):
    """Build RGB and Flow file list with a given split.

    Args:
        split (list): Split for which to generate the file list.

    Returns:
        tuple[list, list]: (rgb_list, flow_list), rgb_list is the
            generated file list for rgb, flow_list is the generated
            file list for flow.
    """
    rgb_list, flow_list = list(), list()
    for item in split:
        if item[0] not in frame_info:
            continue
        elif frame_info[item[0]][1] > 0:
            # rawframes
            rgb_cnt = frame_info[item[0]][1]
            flow_cnt = frame_info[item[0]][2]
            if isinstance(item[1], int):
                rgb_list.append(f"{item[0]} {rgb_cnt} {item[1]}\n")
                flow_list.append(f"{item[0]} {flow_cnt} {item[1]}\n")
            elif isinstance(item[1], list):
                # only for multi-label datasets like mmit
                rgb_list.append(
                    f"{item[0]} {rgb_cnt} "
                    + " ".join([str(digit) for digit in item[1]])
                    + "\n"
                )
                flow_list.append(
                    f"{item[0]} {flow_cnt} "
                    + " ".join([str(digit) for digit in item[1]])
                    + "\n"
                )
            else:
                raise ValueError(
                    "frame_info should be "
                    + "[`video`(str), `label`(int)|`labels(list[int])`"
                )
        else:
            # videos
            if isinstance(item[1], int):
                rgb_list.append(f"{frame_info[item[0]][0]} {item[1]}\n")
                flow_list.append(f"{frame_info[item[0]][0]} {item[1]}\n")
            elif isinstance(item[1], list):
                # only for multi-label datasets like mmit
                rgb_list.append(
                    f"{frame_info[item[0]][0]} "
                    + " ".join([str(digit) for digit in item[1]])
                    + "\n"
                )
                flow_list.append(
                    f"{frame_info[item[0]][0]} "
                    + " ".join([str(digit) for digit in item[1]])
                    + "\n"
                )
            else:
                raise ValueError(
                    "frame_info should be "
                    + "[`video`(str), `label`(int)|`labels(list[int])`"
                )
    if shuffle:
        random.shuffle(rgb_list)
        random.shuffle(flow_list)
    return (rgb_list, flow_list)


def parse_kinetics_splits(kinetics_anntation_files_info, level, dataset_name):
    """Parse Kinetics dataset into "train", "val", "test" splits.

    Args:
        kinetics_anntation_files_info (dict): The file paths of the original annotation files.
            Each file should be the "*.csv" provided on the
            official website.
            For example:
                {"train": ""}
        level (int): Directory level of data. 1 for the single-level directory,
            2 for the two-level directory.
        dataset_name (str): Denotes the version of Kinetics that needs to be parsed,
            choices are "kinetics400", "kinetics600" and "kinetics700".

    Returns:
        dict: "train", "val", "test" splits of Kinetics, keyed by split name.
    """

    def convert_label(label_str, keep_whitespaces=False):
        """Convert label name to a formal string.

        Remove redundant '"' and convert whitespace to '_'.

        Args:
            label_str (str): String to be converted.
            keep_whitespaces(bool): Whether to keep whitespace. Default: False.

        Returns:
            str: Converted string.
        """
        if not keep_whitespaces:
            return label_str.replace('"', "").replace(" ", "_")
        else:
            return label_str.replace('"', "")

    def line_to_map(line_str, test=False):
        """A function to map line string to video and label.

        Args:
            line_str (str): A single line from a Kinetics csv file.
            test (bool): Indicate whether the line comes from the test
                annotation file.

        Returns:
            tuple[str, str]: (video, label), video is the video id,
                label is the video label.
        """
        if test:  # line_str: ['---v8pgm1eQ', '0', '10', 'test']
            # video = f'{x[0]}_{int(x[1]):06d}_{int(x[2]):06d}'
            # video = f'{x[1]}_{int(float(x[2])):06d}_{int(float(x[3])):06d}'
            video = f"{line_str[0]}_{int(float(line_str[1])):06d}_{int(float(line_str[2])):06d}"
            label = -1  # label unknown
            return video, label
        else:  # ['clay pottery making', '---0dWlqevI', '19', '29', 'train']
            video = f"{line_str[1]}_{int(float(line_str[2])):06d}_{int(float(line_str[3])):06d}"
            if level == 2:
                video = f"{convert_label(line_str[0])}/{video}"
            else:
                assert level == 1
            label = class_mapping[convert_label(line_str[0])]
            return video, label

    train_file = kinetics_anntation_files_info["train"]
    test_file = kinetics_anntation_files_info["test"]
    val_file = kinetics_anntation_files_info["val"]

    csv_reader = csv.reader(open(train_file))
    # skip the first line
    next(csv_reader)

    labels_sorted = sorted({convert_label(row[0]) for row in csv_reader})
    class_mapping = {label: i for i, label in enumerate(labels_sorted)}

    csv_reader = csv.reader(open(train_file))
    next(csv_reader)
    train_list = [line_to_map(x) for x in csv_reader]

    csv_reader = csv.reader(open(val_file))
    next(csv_reader)
    val_list = [line_to_map(x) for x in csv_reader]

    csv_reader = csv.reader(open(test_file))
    next(csv_reader)
    test_list = [line_to_map(x, test=True) for x in csv_reader]

    # collect the parsed lists into a dict keyed by split name
    splits = {"train": train_list, "test": test_list, "val": val_list}
    return splits


def obtain_data_splits_info(
    data_annos_files_info,  # a dict containing the data original splits' file path
    data_fir_level=2,
    data_name="kinetics700",
):
    """Parse the raw data file to obtain different splits info"""
    if data_name == "ucf101":
        splits = parse_ucf101_splits(data_fir_level)
    elif data_name == "sthv1":
        splits = parse_sthv1_splits(data_fir_level)
    elif data_name == "sthv2":
        splits = parse_sthv2_splits(data_fir_level)
    elif data_name == "mit":
        splits = parse_mit_splits()
    elif data_name == "mmit":
        splits = parse_mmit_splits()
    elif data_name in ["kinetics400", "kinetics600", "kinetics700"]:
        kinetics_anntation_files_info = data_annos_files_info
        splits = parse_kinetics_splits(
            kinetics_anntation_files_info, data_fir_level, data_name
        )
    elif data_name == "hmdb51":
        splits = parse_hmdb51_split(data_fir_level)
    elif data_name == "jester":
        splits = parse_jester_splits(data_fir_level)
    elif data_name == "diving48":
        splits = parse_diving48_splits()
    else:
        raise ValueError(
            f"Supported datasets are 'ucf101', 'sthv1', 'sthv2', 'jester', 'mmit', "
            f"'mit', 'hmdb51', 'diving48', 'kinetics400', 'kinetics600' and "
            f"'kinetics700', but got {data_name}"
        )

    return splits
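Usage sketch (illustrative, not from the package): how the two helpers above combine; the csv paths and the frame_info entry are placeholders.

from plato.datasources.datalib.parse_datasets import build_list, obtain_data_splits_info

splits = obtain_data_splits_info(
    data_annos_files_info={
        "train": "annotations/train.csv",
        "val": "annotations/val.csv",
        "test": "annotations/test.csv",
    },
    data_fir_level=2,
    data_name="kinetics700",
)

# frame_info maps each clip id to (relative_path, rgb_frame_count, flow_frame_count);
# counts of -1 make build_list treat the entry as a whole video file
frame_info = {
    "abseiling/xxEn-p6e7Gk_000000_000010": (
        "abseiling/xxEn-p6e7Gk_000000_000010.mp4",
        -1,
        -1,
    )
}

rgb_list, flow_list = build_list(split=splits["train"], frame_info=frame_info, shuffle=False)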