plato-learn 1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- plato/__init__.py +1 -0
- plato/algorithms/__init__.py +0 -0
- plato/algorithms/base.py +45 -0
- plato/algorithms/fedavg.py +48 -0
- plato/algorithms/fedavg_gan.py +79 -0
- plato/algorithms/fedavg_personalized.py +48 -0
- plato/algorithms/mistnet.py +52 -0
- plato/algorithms/registry.py +39 -0
- plato/algorithms/split_learning.py +89 -0
- plato/callbacks/__init__.py +0 -0
- plato/callbacks/client.py +56 -0
- plato/callbacks/handler.py +78 -0
- plato/callbacks/server.py +139 -0
- plato/callbacks/trainer.py +124 -0
- plato/client.py +67 -0
- plato/clients/__init__.py +0 -0
- plato/clients/base.py +467 -0
- plato/clients/edge.py +103 -0
- plato/clients/fedavg_personalized.py +40 -0
- plato/clients/mistnet.py +49 -0
- plato/clients/registry.py +43 -0
- plato/clients/self_supervised_learning.py +51 -0
- plato/clients/simple.py +218 -0
- plato/clients/split_learning.py +150 -0
- plato/config.py +339 -0
- plato/datasources/__init__.py +0 -0
- plato/datasources/base.py +123 -0
- plato/datasources/celeba.py +150 -0
- plato/datasources/cifar10.py +87 -0
- plato/datasources/cifar100.py +61 -0
- plato/datasources/cinic10.py +62 -0
- plato/datasources/coco.py +119 -0
- plato/datasources/datalib/__init__.py +0 -0
- plato/datasources/datalib/audio_extraction_tools.py +137 -0
- plato/datasources/datalib/data_utils.py +124 -0
- plato/datasources/datalib/flickr30kE_utils.py +336 -0
- plato/datasources/datalib/frames_extraction_tools.py +254 -0
- plato/datasources/datalib/gym_utils/__init__.py +0 -0
- plato/datasources/datalib/gym_utils/gym_trim.py +189 -0
- plato/datasources/datalib/modality_data_anntation_tools.py +163 -0
- plato/datasources/datalib/modality_extraction_base.py +59 -0
- plato/datasources/datalib/parse_datasets.py +212 -0
- plato/datasources/datalib/refer_utils/__init__.py +0 -0
- plato/datasources/datalib/refer_utils/referitgame_utils.py +237 -0
- plato/datasources/datalib/tiny_data_tools.py +81 -0
- plato/datasources/datalib/video_transform.py +79 -0
- plato/datasources/emnist.py +64 -0
- plato/datasources/fashion_mnist.py +41 -0
- plato/datasources/feature.py +24 -0
- plato/datasources/feature_dataset.py +15 -0
- plato/datasources/femnist.py +141 -0
- plato/datasources/flickr30k_entities.py +362 -0
- plato/datasources/gym.py +431 -0
- plato/datasources/huggingface.py +165 -0
- plato/datasources/kinetics.py +568 -0
- plato/datasources/mnist.py +44 -0
- plato/datasources/multimodal_base.py +328 -0
- plato/datasources/pascal_voc.py +56 -0
- plato/datasources/purchase.py +94 -0
- plato/datasources/qoenflx.py +127 -0
- plato/datasources/referitgame.py +330 -0
- plato/datasources/registry.py +119 -0
- plato/datasources/self_supervised_learning.py +98 -0
- plato/datasources/stl10.py +103 -0
- plato/datasources/texas.py +94 -0
- plato/datasources/tiny_imagenet.py +64 -0
- plato/datasources/yolov8.py +85 -0
- plato/models/__init__.py +0 -0
- plato/models/cnn_encoder.py +103 -0
- plato/models/dcgan.py +116 -0
- plato/models/general_multilayer.py +254 -0
- plato/models/huggingface.py +27 -0
- plato/models/lenet5.py +113 -0
- plato/models/multilayer.py +90 -0
- plato/models/multimodal/__init__.py +0 -0
- plato/models/multimodal/base_net.py +91 -0
- plato/models/multimodal/blending.py +142 -0
- plato/models/multimodal/fc_net.py +77 -0
- plato/models/multimodal/fusion_net.py +78 -0
- plato/models/multimodal/multimodal_module.py +152 -0
- plato/models/registry.py +99 -0
- plato/models/resnet.py +190 -0
- plato/models/torch_hub.py +19 -0
- plato/models/vgg.py +113 -0
- plato/models/vit.py +166 -0
- plato/models/yolov8.py +22 -0
- plato/processors/__init__.py +0 -0
- plato/processors/base.py +35 -0
- plato/processors/compress.py +46 -0
- plato/processors/decompress.py +48 -0
- plato/processors/feature.py +51 -0
- plato/processors/feature_additive_noise.py +48 -0
- plato/processors/feature_dequantize.py +34 -0
- plato/processors/feature_gaussian.py +17 -0
- plato/processors/feature_laplace.py +15 -0
- plato/processors/feature_quantize.py +34 -0
- plato/processors/feature_randomized_response.py +50 -0
- plato/processors/feature_unbatch.py +39 -0
- plato/processors/inbound_feature_tensors.py +39 -0
- plato/processors/model.py +55 -0
- plato/processors/model_compress.py +34 -0
- plato/processors/model_decompress.py +37 -0
- plato/processors/model_decrypt.py +41 -0
- plato/processors/model_deepcopy.py +21 -0
- plato/processors/model_dequantize.py +18 -0
- plato/processors/model_dequantize_qsgd.py +61 -0
- plato/processors/model_encrypt.py +43 -0
- plato/processors/model_quantize.py +18 -0
- plato/processors/model_quantize_qsgd.py +82 -0
- plato/processors/model_randomized_response.py +34 -0
- plato/processors/outbound_feature_ndarrays.py +38 -0
- plato/processors/pipeline.py +26 -0
- plato/processors/registry.py +124 -0
- plato/processors/structured_pruning.py +57 -0
- plato/processors/unstructured_pruning.py +73 -0
- plato/samplers/__init__.py +0 -0
- plato/samplers/all_inclusive.py +41 -0
- plato/samplers/base.py +31 -0
- plato/samplers/dirichlet.py +81 -0
- plato/samplers/distribution_noniid.py +132 -0
- plato/samplers/iid.py +53 -0
- plato/samplers/label_quantity_noniid.py +119 -0
- plato/samplers/mixed.py +44 -0
- plato/samplers/mixed_label_quantity_noniid.py +128 -0
- plato/samplers/modality_iid.py +42 -0
- plato/samplers/modality_quantity_noniid.py +56 -0
- plato/samplers/orthogonal.py +99 -0
- plato/samplers/registry.py +66 -0
- plato/samplers/sample_quantity_noniid.py +123 -0
- plato/samplers/sampler_utils.py +190 -0
- plato/servers/__init__.py +0 -0
- plato/servers/base.py +1395 -0
- plato/servers/fedavg.py +281 -0
- plato/servers/fedavg_cs.py +335 -0
- plato/servers/fedavg_gan.py +74 -0
- plato/servers/fedavg_he.py +106 -0
- plato/servers/fedavg_personalized.py +57 -0
- plato/servers/mistnet.py +67 -0
- plato/servers/registry.py +52 -0
- plato/servers/split_learning.py +109 -0
- plato/trainers/__init__.py +0 -0
- plato/trainers/base.py +99 -0
- plato/trainers/basic.py +649 -0
- plato/trainers/diff_privacy.py +178 -0
- plato/trainers/gan.py +330 -0
- plato/trainers/huggingface.py +173 -0
- plato/trainers/loss_criterion.py +70 -0
- plato/trainers/lr_schedulers.py +252 -0
- plato/trainers/optimizers.py +53 -0
- plato/trainers/pascal_voc.py +80 -0
- plato/trainers/registry.py +44 -0
- plato/trainers/self_supervised_learning.py +302 -0
- plato/trainers/split_learning.py +305 -0
- plato/trainers/tracking.py +96 -0
- plato/trainers/yolov8.py +41 -0
- plato/utils/__init__.py +0 -0
- plato/utils/count_parameters.py +30 -0
- plato/utils/csv_processor.py +26 -0
- plato/utils/data_loaders.py +148 -0
- plato/utils/decorators.py +24 -0
- plato/utils/fonts.py +23 -0
- plato/utils/homo_enc.py +187 -0
- plato/utils/reinforcement_learning/__init__.py +0 -0
- plato/utils/reinforcement_learning/policies/__init__.py +0 -0
- plato/utils/reinforcement_learning/policies/base.py +161 -0
- plato/utils/reinforcement_learning/policies/ddpg.py +75 -0
- plato/utils/reinforcement_learning/policies/registry.py +32 -0
- plato/utils/reinforcement_learning/policies/sac.py +343 -0
- plato/utils/reinforcement_learning/policies/td3.py +485 -0
- plato/utils/reinforcement_learning/rl_agent.py +142 -0
- plato/utils/reinforcement_learning/rl_server.py +113 -0
- plato/utils/rl_env.py +154 -0
- plato/utils/s3.py +141 -0
- plato/utils/trainer_utils.py +21 -0
- plato/utils/unary_encoding.py +47 -0
- plato_learn-1.1.dist-info/METADATA +35 -0
- plato_learn-1.1.dist-info/RECORD +179 -0
- plato_learn-1.1.dist-info/WHEEL +4 -0
- plato_learn-1.1.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,568 @@
|
|
1
|
+
"""
|
2
|
+
|
3
|
+
The Kinetics700 dataset.
|
4
|
+
|
5
|
+
Note that the setting for the data loader is obtained from the github
|
6
|
+
repo provided by the official workers:
|
7
|
+
https://github.com/pytorch/vision/references/video_classification/train.py
|
8
|
+
|
9
|
+
We consider three modalities: RGB, optical flow and audio.
|
10
|
+
For RGB and flow, we use input clips of 16×224×224 as input.
|
11
|
+
We follow [1] for visual pre-processing and augmentation.
|
12
|
+
For audio, we use log-Mel with 100 temporal frames by 40 Mel filters.
|
13
|
+
|
14
|
+
Audio and visual are temporally aligned.
|
15
|
+
|
16
|
+
[1]. Video classification with channel-separated convolutional networks.
|
17
|
+
In ICCV, 2019. (CSN network)
|
18
|
+
This is actually the CSN network in the mmaction package.
|
19
|
+
|
20
|
+
Also, the implementation of our code is based on the mmaction2 of the
|
21
|
+
openmmlab https://openmmlab.com/.
|
22
|
+
|
23
|
+
The data structure is:
|
24
|
+
├── data
|
25
|
+
│ ├── ${DATASET}
|
26
|
+
│ │ ├── ${DATASET}_train_list_videos.txt
|
27
|
+
│ │ ├── ${DATASET}_val_list_videos.txt
|
28
|
+
│ │ ├── annotations
|
29
|
+
│ │ ├── videos_train
|
30
|
+
│ │ ├── videos_val
|
31
|
+
│ │ │ ├── abseiling
|
32
|
+
│ │ │ │ ├── 0wR5jVB-WPk_000417_000427.mp4
|
33
|
+
│ │ │ │ ├── ...
|
34
|
+
│ │ │ ├── ...
|
35
|
+
│ │ │ ├── wrapping_present
|
36
|
+
│ │ │ ├── ...
|
37
|
+
│ │ │ ├── zumba
|
38
|
+
│ │ ├── rawframes_train
|
39
|
+
│ │ ├── rawframes_val
|
40
|
+
|
41
|
+
|
42
|
+
For data formats, we support "videos", "rawframes", "audios", "audio_features"
|
43
|
+
For modalities, we support "video", "audio", "audio_feature", "rgb", "flow"
|
44
|
+
|
45
|
+
"""
|
46
|
+
|
47
|
+
import re
|
48
|
+
import logging
|
49
|
+
import os
|
50
|
+
import shutil
|
51
|
+
from collections import defaultdict
|
52
|
+
|
53
|
+
import torch
|
54
|
+
|
55
|
+
from mmaction.tools.data.kinetics import download as kinetics_downloader
|
56
|
+
from mmaction.datasets import build_dataset
|
57
|
+
|
58
|
+
from plato.config import Config
|
59
|
+
from plato.datasources import multimodal_base
|
60
|
+
from plato.datasources.datalib import frames_extraction_tools
|
61
|
+
from plato.datasources.datalib import audio_extraction_tools
|
62
|
+
from plato.datasources.datalib import modality_data_anntation_tools
|
63
|
+
from plato.datasources.datalib import data_utils
|
64
|
+
from plato.datasources.datalib import tiny_data_tools
|
65
|
+
|
66
|
+
|
67
|
+
def obtain_required_anno_files(splits_info):
    """Select the annotation file to use for each of the train/test/val splits.

    Args:
        splits_info: mapping from split name ("train", "test", "val") to a dict
            that holds the "split_anno_file" and "split_tiny_anno_file" entries.

    Returns:
        A dict keyed by "train", "test" and "val". Each value is the tiny
        annotation file when `Config().data.tiny_data` exists and is truthy,
        and the full annotation file otherwise.
    """
    required_anno_files = {"train": "", "test": "", "val": ""}
    for split in ("train", "test", "val"):
        split_info = splits_info[split]
        # Tiny annotations are used only when the option is present and enabled;
        # in every other case the full-dataset annotations are returned.
        use_tiny = hasattr(Config().data, "tiny_data") and Config().data.tiny_data
        anno_key = "split_tiny_anno_file" if use_tiny else "split_anno_file"
        required_anno_files[split] = split_info[anno_key]
    return required_anno_files
|
80
|
+
|
81
|
+
|
82
|
+
class KineticsDataset(multimodal_base.MultiModalDataset):
    """One phase of the Kinetics dataset, exposed as aligned per-modality datasets."""

    def __init__(
        self, multimodal_data_holder, phase, phase_info, modality_sampler=None
    ):
        """Store the per-modality datasets and read the labels of this phase.

        Args:
            multimodal_data_holder: dict mapping a modality name to its dataset,
                e.g. {"rgb": rgb_dataset, "flow": flow_dataset, ...}.
            phase: name of the phase this dataset represents.
            phase_info: dict mapping a modality name to its annotation file
                path; the "rgb" entry is read by `get_targets`.
            modality_sampler: modalities to use; defaults to
                `self.supported_modalities` when None.
        """
        super().__init__()
        self.phase = phase
        self.phase_multimodal_data_record = multimodal_data_holder
        self.phase_info = phase_info

        self.modalities_name = list(multimodal_data_holder.keys())
        self.supported_modalities = ["rgb", "flow", "audio_feature"]

        # Fall back to every supported modality when no sampler is given.
        if modality_sampler is None:
            self.modality_sampler = self.supported_modalities
        else:
            self.modality_sampler = modality_sampler

        self.targets = self.get_targets()

    def __len__(self):
        """Return the number of samples in this phase.

        The previous implementation returned the length of the modality dict,
        i.e. the number of modalities. The sample count is taken from the
        labels read in `get_targets`, one per annotation entry.
        """
        return len(self.targets)

    def get_targets(self):
        """Return the first label of every entry in the "rgb" annotation file.

        Only the "rgb" annotation file is read. The original comment states
        that the sample order is kept identical across the rgb, flow and audio
        annotation files, so any one of them would do.
        """
        rawframes_anno_list_file_path = self.phase_info["rgb"]
        annos_list = data_utils.read_anno_file(rawframes_anno_list_file_path)
        return [anno_item["label"][0] for anno_item in annos_list]

    def get_one_multimodal_sample(self, sample_idx):
        """Return a dict mapping each held modality name to its sample at `sample_idx`."""
        return {
            modality_name: self.phase_multimodal_data_record[modality_name][sample_idx]
            for modality_name in self.modalities_name
        }
|
134
|
+
|
135
|
+
|
136
|
+
class DataSource(multimodal_base.MultiModalDataSource):
|
137
|
+
"""The Kinetics datasource."""
|
138
|
+
|
139
|
+
def __init__(self, **kwargs):
    """Prepare annotations, raw videos and per-modality data for Kinetics.

    The constructor, in order:
      1. resolves the dataset, split and modality directories through the
         `_data_path_process` and `_create_modalities_path` helpers;
      2. records the full and tiny annotation csv paths of every split in
         `self.splits_info`;
      3. downloads the annotation archive and moves its files into the
         annotation directory, prefixed with the digit-free dataset name;
      4. creates the tiny annotation files when `Config().data.tiny_data` is
         set and truthy;
      5. downloads the train/val videos whose directory is missing, then
         renames the class directories;
      6. extracts the rgb/flow/audio/audio-feature data and generates the list
         files for the "audio_features", "videos" and "rawframes" formats.

    Args:
        **kwargs: accepted for interface compatibility; not read here.

    NOTE(review): the reviewed copy of this method had lost its indentation.
    The nesting below (tiny-annotation creation performed once, class renaming
    performed after the download loop) is the reading under which every name
    is defined before use -- confirm against the packaged file.
    """
    super().__init__()

    self.data_name = Config().data.datasource
    # First run of non-digit characters, e.g. "kinetics700" -> "kinetics".
    base_data_name = re.findall(r"\D+", self.data_name)[0]

    # "rgb" and "flow" both live under the rawframes directory.
    self.modality_names = ["video", "audio", "rgb", "flow", "audio_feature"]

    _path = Config().params["data_path"]
    # Fills self.mm_data_info and the per-split root paths in self.splits_info.
    self._data_path_process(data_path=_path, base_data_name=self.data_name)
    # Adds a "{modality}_path" style entry per modality to every split.
    self._create_modalities_path(modality_names=self.modality_names)

    base_data_path = self.mm_data_info["data_path"]
    kinetics_anno_dir_name = "annotations"
    self.data_annotation_path = os.path.join(base_data_path, kinetics_anno_dir_name)

    # The csv files use "validate" while the split key used here is "val".
    for split in ["train", "test", "validate"]:
        split_name = split if split != "validate" else "val"
        anno_file_stem = base_data_name + "_" + split
        self.splits_info[split_name]["split_anno_file"] = os.path.join(
            self.data_annotation_path, anno_file_stem + ".csv"
        )
        self.splits_info[split_name]["split_tiny_anno_file"] = os.path.join(
            self.data_annotation_path, anno_file_stem + "_tiny.csv"
        )

    anno_download_url = (
        "https://storage.googleapis.com/deepmind-media/Datasets/{}.tar.gz"
    ).format(self.data_name)

    extracted_anno_file_name = self._download_arrange_data(
        download_url_address=anno_download_url,
        data_path=self.data_annotation_path,
        obtained_file_name=None,
    )
    download_anno_path = os.path.join(
        self.data_annotation_path, extracted_anno_file_name
    )

    # Move the extracted files up into the annotation directory, prefixing
    # them so they match the "<base_data_name>_<split>.csv" names set above.
    for file_name in os.listdir(download_anno_path):
        shutil.move(
            os.path.join(download_anno_path, file_name),
            os.path.join(self.data_annotation_path, base_data_name + "_" + file_name),
        )

    # Create the tiny annotation files once. The original wrapped this in a
    # per-split loop whose only statement was an unused assignment.
    if hasattr(Config().data, "tiny_data") and Config().data.tiny_data:
        anno_files_info = {
            "train": self.splits_info["train"]["split_anno_file"],
            "test": self.splits_info["test"]["split_anno_file"],
            "val": self.splits_info["val"]["split_anno_file"],
        }
        tiny_data_tools.create_tiny_kinetics_anno(
            kinetics_annotation_files_info=anno_files_info,
            num_samples=Config().data.tiny_data_number,
            random_seed=Config().data.random_seed,
        )

    # Download the raw videos. The test split is skipped because, per the
    # original comment, the Kinetics test set has no labels.
    required_anno_files = obtain_required_anno_files(self.splits_info)
    video_path_format = self.set_modality_path_key_format(modality_name="video")
    for split in ["train", "val"]:
        split_anno_path = required_anno_files[split]
        video_dir = self.splits_info[split][video_path_format]
        if not self._exists(video_dir):
            num_workers = Config().data.downloader.num_workers
            # Raw videos are downloaded to tmp_dir and the clips are saved
            # to video_dir.
            tmp_dir = os.path.join(video_dir, "tmp")
            logging.info(
                "Downloading the raw videos for the %s %s dataset. This may take a long time.",
                self.data_name,
                split,
            )
            kinetics_downloader.main(
                input_csv=split_anno_path,
                output_dir=video_dir,
                trim_format="%06d",
                num_jobs=num_workers,
                tmp_dir=tmp_dir,
            )

    # Replace whitespace in the class directory names.
    for split in ["train", "val"]:
        self.rename_classes(mode=split)

    logging.info("Done.")

    logging.info("The %s dataset has been prepared", self.data_name)

    # Extract rgb, flow, audio and audio features from the videos.
    for split in ["train", "val"]:
        self.extract_videos_rgb_flow_audio(mode=split)

    # Generate the per-split list files for each data format.
    self.audios_splits_list_files_into = self.extract_splits_list_files(
        data_format="audio_features", splits=["train", "val"]
    )
    self.video_splits_list_files_into = self.extract_splits_list_files(
        data_format="videos", splits=["train", "val"]
    )
    self.rawframes_splits_list_files_into = self.extract_splits_list_files(
        data_format="rawframes", splits=["train", "val"]
    )
|
282
|
+
|
283
|
+
def get_modality_name(self):
    """Return a new list with the modality names reported by this datasource."""
    modality_names = ["rgb", "flow", "audio"]
    return modality_names
|
286
|
+
|
287
|
+
def rename_classes(self, mode):
    """Replace spaces with underscores in the class directory names of a split.

    Args:
        mode: split key used to look up the video directory in
            `self.splits_info`.

    Only direct sub-directories of the video directory are renamed; plain
    files are left untouched.
    """
    path_key = self.set_modality_path_key_format(modality_name="video")
    videos_root = self.splits_info[mode][path_key]

    # Snapshot the class directories first, then rename them.
    class_dir_names = []
    for entry in os.listdir(videos_root):
        if os.path.isdir(os.path.join(videos_root, entry)):
            class_dir_names.append(entry)

    for old_name in class_dir_names:
        os.rename(
            os.path.join(videos_root, old_name),
            os.path.join(videos_root, old_name.replace(" ", "_")),
        )
|
310
|
+
|
311
|
+
def get_modality_data_path(self, mode, modality_name):
    """Return the data path stored in `self.splits_info` for a split and modality.

    Args:
        mode: split key of `self.splits_info`.
        modality_name: modality whose path key is obtained from
            `set_modality_path_key_format`.
    """
    path_key = self.set_modality_path_key_format(modality_name=modality_name)
    split_paths = self.splits_info[mode]
    return split_paths[path_key]
|
317
|
+
|
318
|
+
def extract_videos_rgb_flow_audio(self, mode="train"):
    """Extract frames, audio and audio features from the videos of one split.

    Each output is produced only when `self._exists` reports its target path
    as missing. Frames are built with the GPU routine when
    `torch.cuda.is_available()` is true and with the CPU routine otherwise.

    Args:
        mode: split key passed to `get_modality_data_path`.

    NOTE(review): nesting reconstructed from a copy without indentation; the
    audio steps are taken to run regardless of the CUDA branch -- confirm.
    """
    videos_dir = self.get_modality_data_path(mode=mode, modality_name="video")
    rgb_dir = self.get_modality_data_path(mode=mode, modality_name="rgb")
    flow_dir = self.get_modality_data_path(mode=mode, modality_name="flow")
    audio_dir = self.get_modality_data_path(mode=mode, modality_name="audio")
    audio_feature_dir = self.get_modality_data_path(
        mode=mode, modality_name="audio_feature"
    )

    # Both extractors are configured with the same source options.
    extractor_options = {
        "video_src_dir": videos_dir,
        "dir_level": 2,
        "num_worker": 8,
        "video_ext": "mp4",
        "mixed_ext": False,
    }

    # Extractors are created only when some of their outputs are missing.
    if not self._exists(rgb_dir):
        frames_extractor = frames_extraction_tools.VideoFramesExtractor(
            **extractor_options
        )
    if not (self._exists(audio_dir) and self._exists(audio_feature_dir)):
        audio_extractor = audio_extraction_tools.VideoAudioExtractor(
            **extractor_options
        )

    use_gpu = torch.cuda.is_available()
    if not self._exists(rgb_dir):
        if use_gpu:
            logging.info(
                "Extracting frames by GPU from videos in %s to %s.",
                videos_dir,
                rgb_dir,
            )
            frames_extractor.build_frames_gpu(
                rgb_dir,
                flow_dir,
                new_short=1,
                new_width=0,
                new_height=0,
            )
        else:
            logging.info(
                "Extracting frames by CPU from videos in %s to %s.",
                videos_dir,
                rgb_dir,
            )
            frames_extractor.build_frames_cpu(to_dir=rgb_dir)

    if not self._exists(audio_dir):
        logging.info(
            "Extracting audios by CPU from videos in %s to %s.",
            videos_dir,
            audio_dir,
        )
        audio_extractor.build_audios(to_dir=audio_dir)

    if not self._exists(audio_feature_dir):
        logging.info(
            "Extracting audios feature by CPU from audios in %s to %s.",
            audio_dir,
            audio_feature_dir,
        )
        # Original note: window_size 32ms, hop_size 16ms;
        # fft_size / sample_rate is the window size.
        audio_extractor.build_audios_features(
            audio_src_path=audio_dir,
            to_dir=audio_feature_dir,
            fft_size=512,
            hop_size=256,
        )
|
392
|
+
|
393
|
+
def extract_splits_list_files(self, data_format, splits):
    """Generate the list files of one data format if needed and return their paths.

    Args:
        data_format: one of 'rawframes', 'videos' or 'audio_features'.
        splits: iterable of split names to generate and report.

    Returns:
        A dict mapping each split name to the path reported by the annotation
        generator's `get_anno_file_path`.

    NOTE(review): nesting reconstructed from a copy without indentation; the
    csv read and per-split generation are taken to run only when no matching
    list file exists yet -- confirm.
    """
    list_file_format = "json"
    data_root = self.mm_data_info["data_path"]

    # Full or tiny annotation files, depending on the configuration.
    splits_anno_files = obtain_required_anno_files(self.splits_info)
    annotation_generator = modality_data_anntation_tools.GenerateMDataAnnotation(
        data_src_dir=self.mm_data_info["data_path"],
        data_annos_files_info=splits_anno_files,
        dataset_name=self.data_name,
        data_format=data_format,
        rgb_prefix="img_",  # prefix of rgb frames
        flow_x_prefix="x_",  # prefix of flow x frames
        flow_y_prefix="y_",  # prefix of flow y frames
        out_path=data_root,
        output_format=list_file_format,
    )

    list_file_suffix = f"_{data_format}.{list_file_format}"
    already_generated = self._file_exists(
        tg_file_name=list_file_suffix, search_path=data_root, is_partial_name=True
    )
    if not already_generated:
        logging.info("Extracting annotation list for %s. ", data_format)
        annotation_generator.read_data_splits_csv_info()
        for split_name in splits:
            annotation_generator.generate_data_splits_info_file(split_name=split_name)

    return {
        split_name: annotation_generator.get_anno_file_path(split_name)
        for split_name in splits
    }
|
433
|
+
|
434
|
+
def correct_current_config(self, loaded_plato_config, mode, modality_name):
    """Return the modality config as a dict with on-disk annotation and data paths.

    The loaded config is converted to a dict, its lists are converted to
    tuples, and its "ann_file" and "data_prefix" entries are overwritten with
    the paths this datasource generated for `mode`.

    Args:
        loaded_plato_config: config object loaded for one modality and phase.
        mode: split key used for the list files and data paths.
        modality_name: "rgb" and "flow" select the rawframes paths,
            "audio_feature" the audio-feature paths, anything else the video
            paths.
    """
    config = data_utils.dict_list2tuple(data_utils.config_to_dict(loaded_plato_config))

    # (annotation list file, data directory) for every supported source.
    rawframes_source = (
        self.rawframes_splits_list_files_into[mode],
        self.get_modality_data_path(mode=mode, modality_name="rgb"),
    )
    videos_source = (
        self.video_splits_list_files_into[mode],
        self.get_modality_data_path(mode=mode, modality_name="video"),
    )
    audio_features_source = (
        self.audios_splits_list_files_into[mode],
        self.get_modality_data_path(mode=mode, modality_name="audio_feature"),
    )

    if modality_name in ("rgb", "flow"):
        ann_file, data_prefix = rawframes_source
    elif modality_name == "audio_feature":
        ann_file, data_prefix = audio_features_source
    else:
        ann_file, data_prefix = videos_source

    config["ann_file"] = ann_file
    config["data_prefix"] = data_prefix
    return config
|
474
|
+
|
475
|
+
def get_phase_dataset(self, phase, modality_sampler):
    """Build the multimodal Kinetics dataset for one phase.

    For each of the "rgb", "flow" and "audio_feature" modalities, the
    phase-specific config is read from `Config().data.multi_modal_configs`,
    corrected with `correct_current_config` and passed to `build_dataset`.

    Args:
        phase: attribute name of the per-modality config to load; it is also
            the split key used when correcting the config.
        modality_sampler: forwarded unchanged to `KineticsDataset`.

    Returns:
        A `KineticsDataset` holding the three built datasets and their
        annotation file paths.
    """
    modality_names = ("rgb", "flow", "audio_feature")
    modal_configs = Config().data.multi_modal_configs

    # Load every config first, then correct, then build, so a missing config
    # fails before any dataset is constructed.
    loaded_configs = {
        name: getattr(getattr(modal_configs, name), phase) for name in modality_names
    }
    corrected_configs = {
        name: self.correct_current_config(
            loaded_plato_config=loaded_config, mode=phase, modality_name=name
        )
        for name, loaded_config in loaded_configs.items()
    }

    multi_modal_mode_data = {
        name: build_dataset(config) for name, config in corrected_configs.items()
    }
    multi_modal_mode_info = {
        name: config["ann_file"] for name, config in corrected_configs.items()
    }

    # Pass the requested phase through; it used to be hard-coded to "train",
    # which mislabelled the dataset built for "val".
    return KineticsDataset(
        multimodal_data_holder=multi_modal_mode_data,
        phase=phase,
        phase_info=multi_modal_mode_info,
        modality_sampler=modality_sampler,
    )
|
519
|
+
|
520
|
+
def get_train_set(self, modality_sampler=None):
    """Return the dataset built by `get_phase_dataset` for the "train" phase."""
    return self.get_phase_dataset(phase="train", modality_sampler=modality_sampler)
|
527
|
+
|
528
|
+
def get_test_set(self, modality_sampler=None):
    """Return the dataset used for testing.

    Kinetics has no test set whose samples carry ground-truth labels, so the
    dataset of the "val" phase is returned instead.
    """
    return self.get_phase_dataset(phase="val", modality_sampler=modality_sampler)
|
540
|
+
|
541
|
+
def get_class_label_mapper(self):
    """Map each textual class name to the integer labels found for it.

    The classes are read from the "train" rawframes list file, whose entries
    look like:
        {"frame_dir": "clay_pottery_making/---0dWlqevI_000019_000029",
         "total_frames": 300, "label": [0]}

    Returns:
        A `defaultdict(list)` keyed by the first path component of
        "frame_dir". One label is appended per annotation entry, so a class
        with several samples holds repeated labels.
    """
    textclass_integer_mapper = defaultdict(list)
    train_anno_list_path = self.rawframes_splits_list_files_into["train"]
    train_anno_list = data_utils.read_anno_file(train_anno_list_path)

    for item in train_anno_list:
        textclass = item["frame_dir"].split("/")[0]
        # The label is a top-level key of the entry. The previous code looked
        # it up inside the "frame_dir" string, which raises TypeError.
        integer_label = item["label"][0]
        textclass_integer_mapper[textclass].append(integer_label)

    return textclass_integer_mapper
|
556
|
+
|
557
|
+
def classes(self):
    """Return the sorted distinct integer labels found in the "train" rawframes list file."""
    anno_items = data_utils.read_anno_file(
        self.rawframes_splits_list_files_into["train"]
    )
    # Only the first label of each entry is considered.
    distinct_labels = {anno_item["label"][0] for anno_item in anno_items}
    return sorted(distinct_labels)
|
@@ -0,0 +1,44 @@
|
|
1
|
+
"""
|
2
|
+
The MNIST dataset from the torchvision package.
|
3
|
+
"""
|
4
|
+
|
5
|
+
from torchvision import datasets, transforms
|
6
|
+
|
7
|
+
from plato.config import Config
|
8
|
+
from plato.datasources import base
|
9
|
+
|
10
|
+
|
11
|
+
class DataSource(base.DataSource):
    """Datasource wrapping the torchvision MNIST train and test sets."""

    def __init__(self, **kwargs):
        """Download MNIST if needed and build the train and test sets.

        Args:
            **kwargs: optional "train_transform" and "test_transform" entries
                replace the default transform (ToTensor followed by
                Normalize((0.1307,), (0.3081,))) of the respective set.
        """
        super().__init__()
        _path = Config().params["data_path"]

        if "train_transform" in kwargs:
            train_transform = kwargs["train_transform"]
        else:
            train_transform = transforms.Compose(
                [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
            )

        if "test_transform" in kwargs:
            test_transform = kwargs["test_transform"]
        else:
            test_transform = transforms.Compose(
                [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
            )

        self.trainset = datasets.MNIST(
            root=_path, train=True, download=True, transform=train_transform
        )
        self.testset = datasets.MNIST(
            root=_path, train=False, download=True, transform=test_transform
        )

    def num_train_examples(self):
        """Return the fixed number of MNIST training examples."""
        return 60000

    def num_test_examples(self):
        """Return the fixed number of MNIST test examples."""
        return 10000
|