plato_learn-1.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- plato/__init__.py +1 -0
- plato/algorithms/__init__.py +0 -0
- plato/algorithms/base.py +45 -0
- plato/algorithms/fedavg.py +48 -0
- plato/algorithms/fedavg_gan.py +79 -0
- plato/algorithms/fedavg_personalized.py +48 -0
- plato/algorithms/mistnet.py +52 -0
- plato/algorithms/registry.py +39 -0
- plato/algorithms/split_learning.py +89 -0
- plato/callbacks/__init__.py +0 -0
- plato/callbacks/client.py +56 -0
- plato/callbacks/handler.py +78 -0
- plato/callbacks/server.py +139 -0
- plato/callbacks/trainer.py +124 -0
- plato/client.py +67 -0
- plato/clients/__init__.py +0 -0
- plato/clients/base.py +467 -0
- plato/clients/edge.py +103 -0
- plato/clients/fedavg_personalized.py +40 -0
- plato/clients/mistnet.py +49 -0
- plato/clients/registry.py +43 -0
- plato/clients/self_supervised_learning.py +51 -0
- plato/clients/simple.py +218 -0
- plato/clients/split_learning.py +150 -0
- plato/config.py +339 -0
- plato/datasources/__init__.py +0 -0
- plato/datasources/base.py +123 -0
- plato/datasources/celeba.py +150 -0
- plato/datasources/cifar10.py +87 -0
- plato/datasources/cifar100.py +61 -0
- plato/datasources/cinic10.py +62 -0
- plato/datasources/coco.py +119 -0
- plato/datasources/datalib/__init__.py +0 -0
- plato/datasources/datalib/audio_extraction_tools.py +137 -0
- plato/datasources/datalib/data_utils.py +124 -0
- plato/datasources/datalib/flickr30kE_utils.py +336 -0
- plato/datasources/datalib/frames_extraction_tools.py +254 -0
- plato/datasources/datalib/gym_utils/__init__.py +0 -0
- plato/datasources/datalib/gym_utils/gym_trim.py +189 -0
- plato/datasources/datalib/modality_data_anntation_tools.py +163 -0
- plato/datasources/datalib/modality_extraction_base.py +59 -0
- plato/datasources/datalib/parse_datasets.py +212 -0
- plato/datasources/datalib/refer_utils/__init__.py +0 -0
- plato/datasources/datalib/refer_utils/referitgame_utils.py +237 -0
- plato/datasources/datalib/tiny_data_tools.py +81 -0
- plato/datasources/datalib/video_transform.py +79 -0
- plato/datasources/emnist.py +64 -0
- plato/datasources/fashion_mnist.py +41 -0
- plato/datasources/feature.py +24 -0
- plato/datasources/feature_dataset.py +15 -0
- plato/datasources/femnist.py +141 -0
- plato/datasources/flickr30k_entities.py +362 -0
- plato/datasources/gym.py +431 -0
- plato/datasources/huggingface.py +165 -0
- plato/datasources/kinetics.py +568 -0
- plato/datasources/mnist.py +44 -0
- plato/datasources/multimodal_base.py +328 -0
- plato/datasources/pascal_voc.py +56 -0
- plato/datasources/purchase.py +94 -0
- plato/datasources/qoenflx.py +127 -0
- plato/datasources/referitgame.py +330 -0
- plato/datasources/registry.py +119 -0
- plato/datasources/self_supervised_learning.py +98 -0
- plato/datasources/stl10.py +103 -0
- plato/datasources/texas.py +94 -0
- plato/datasources/tiny_imagenet.py +64 -0
- plato/datasources/yolov8.py +85 -0
- plato/models/__init__.py +0 -0
- plato/models/cnn_encoder.py +103 -0
- plato/models/dcgan.py +116 -0
- plato/models/general_multilayer.py +254 -0
- plato/models/huggingface.py +27 -0
- plato/models/lenet5.py +113 -0
- plato/models/multilayer.py +90 -0
- plato/models/multimodal/__init__.py +0 -0
- plato/models/multimodal/base_net.py +91 -0
- plato/models/multimodal/blending.py +142 -0
- plato/models/multimodal/fc_net.py +77 -0
- plato/models/multimodal/fusion_net.py +78 -0
- plato/models/multimodal/multimodal_module.py +152 -0
- plato/models/registry.py +99 -0
- plato/models/resnet.py +190 -0
- plato/models/torch_hub.py +19 -0
- plato/models/vgg.py +113 -0
- plato/models/vit.py +166 -0
- plato/models/yolov8.py +22 -0
- plato/processors/__init__.py +0 -0
- plato/processors/base.py +35 -0
- plato/processors/compress.py +46 -0
- plato/processors/decompress.py +48 -0
- plato/processors/feature.py +51 -0
- plato/processors/feature_additive_noise.py +48 -0
- plato/processors/feature_dequantize.py +34 -0
- plato/processors/feature_gaussian.py +17 -0
- plato/processors/feature_laplace.py +15 -0
- plato/processors/feature_quantize.py +34 -0
- plato/processors/feature_randomized_response.py +50 -0
- plato/processors/feature_unbatch.py +39 -0
- plato/processors/inbound_feature_tensors.py +39 -0
- plato/processors/model.py +55 -0
- plato/processors/model_compress.py +34 -0
- plato/processors/model_decompress.py +37 -0
- plato/processors/model_decrypt.py +41 -0
- plato/processors/model_deepcopy.py +21 -0
- plato/processors/model_dequantize.py +18 -0
- plato/processors/model_dequantize_qsgd.py +61 -0
- plato/processors/model_encrypt.py +43 -0
- plato/processors/model_quantize.py +18 -0
- plato/processors/model_quantize_qsgd.py +82 -0
- plato/processors/model_randomized_response.py +34 -0
- plato/processors/outbound_feature_ndarrays.py +38 -0
- plato/processors/pipeline.py +26 -0
- plato/processors/registry.py +124 -0
- plato/processors/structured_pruning.py +57 -0
- plato/processors/unstructured_pruning.py +73 -0
- plato/samplers/__init__.py +0 -0
- plato/samplers/all_inclusive.py +41 -0
- plato/samplers/base.py +31 -0
- plato/samplers/dirichlet.py +81 -0
- plato/samplers/distribution_noniid.py +132 -0
- plato/samplers/iid.py +53 -0
- plato/samplers/label_quantity_noniid.py +119 -0
- plato/samplers/mixed.py +44 -0
- plato/samplers/mixed_label_quantity_noniid.py +128 -0
- plato/samplers/modality_iid.py +42 -0
- plato/samplers/modality_quantity_noniid.py +56 -0
- plato/samplers/orthogonal.py +99 -0
- plato/samplers/registry.py +66 -0
- plato/samplers/sample_quantity_noniid.py +123 -0
- plato/samplers/sampler_utils.py +190 -0
- plato/servers/__init__.py +0 -0
- plato/servers/base.py +1395 -0
- plato/servers/fedavg.py +281 -0
- plato/servers/fedavg_cs.py +335 -0
- plato/servers/fedavg_gan.py +74 -0
- plato/servers/fedavg_he.py +106 -0
- plato/servers/fedavg_personalized.py +57 -0
- plato/servers/mistnet.py +67 -0
- plato/servers/registry.py +52 -0
- plato/servers/split_learning.py +109 -0
- plato/trainers/__init__.py +0 -0
- plato/trainers/base.py +99 -0
- plato/trainers/basic.py +649 -0
- plato/trainers/diff_privacy.py +178 -0
- plato/trainers/gan.py +330 -0
- plato/trainers/huggingface.py +173 -0
- plato/trainers/loss_criterion.py +70 -0
- plato/trainers/lr_schedulers.py +252 -0
- plato/trainers/optimizers.py +53 -0
- plato/trainers/pascal_voc.py +80 -0
- plato/trainers/registry.py +44 -0
- plato/trainers/self_supervised_learning.py +302 -0
- plato/trainers/split_learning.py +305 -0
- plato/trainers/tracking.py +96 -0
- plato/trainers/yolov8.py +41 -0
- plato/utils/__init__.py +0 -0
- plato/utils/count_parameters.py +30 -0
- plato/utils/csv_processor.py +26 -0
- plato/utils/data_loaders.py +148 -0
- plato/utils/decorators.py +24 -0
- plato/utils/fonts.py +23 -0
- plato/utils/homo_enc.py +187 -0
- plato/utils/reinforcement_learning/__init__.py +0 -0
- plato/utils/reinforcement_learning/policies/__init__.py +0 -0
- plato/utils/reinforcement_learning/policies/base.py +161 -0
- plato/utils/reinforcement_learning/policies/ddpg.py +75 -0
- plato/utils/reinforcement_learning/policies/registry.py +32 -0
- plato/utils/reinforcement_learning/policies/sac.py +343 -0
- plato/utils/reinforcement_learning/policies/td3.py +485 -0
- plato/utils/reinforcement_learning/rl_agent.py +142 -0
- plato/utils/reinforcement_learning/rl_server.py +113 -0
- plato/utils/rl_env.py +154 -0
- plato/utils/s3.py +141 -0
- plato/utils/trainer_utils.py +21 -0
- plato/utils/unary_encoding.py +47 -0
- plato_learn-1.1.dist-info/METADATA +35 -0
- plato_learn-1.1.dist-info/RECORD +179 -0
- plato_learn-1.1.dist-info/WHEEL +4 -0
- plato_learn-1.1.dist-info/licenses/LICENSE +201 -0
plato/datasources/gym.py
ADDED
@@ -0,0 +1,431 @@
"""
The Gym dataset.

Note that the settings for the data loader are obtained from the GitHub repository
provided by the original authors of:
FineGym: A Hierarchical Video Dataset for Fine-grained Action Understanding

The data structure should be:

├── data
│   ├── gym99
│   │   ├── annotations
│   │   │   ├── gym99_train_org.txt
│   │   │   ├── gym99_val_org.txt
│   │   │   ├── gym99_train.txt
│   │   │   ├── gym99_val.txt
│   │   │   ├── annotation.json
│   │   │   └── event_annotation.json
│   │   ├── videos
│   │   │   ├── 0LtLS9wROrk.mp4
│   │   │   ├── ...
│   │   │   └── zfqS-wCJSsw.mp4
│   │   ├── events
│   │   │   ├── 0LtLS9wROrk_E_002407_002435.mp4
│   │   │   ├── ...
│   │   │   └── zfqS-wCJSsw_E_006732_006824.mp4
│   │   ├── subactions
│   │   │   ├── 0LtLS9wROrk_E_002407_002435_A_0003_0005.mp4
│   │   │   ├── ...
│   │   │   └── zfqS-wCJSsw_E_006244_006252_A_0000_0007.mp4
│   │   ├── subaction_frames
│   │   └── subaction_audios

"""

import logging
import os
import shutil

import torch

from mmaction.tools.data.gym import download as gym_downloader
from mmaction.datasets import build_dataset

from plato.config import Config
from plato.datasources.datalib.gym_utils import gym_trim
from plato.datasources import multimodal_base
from plato.datasources.datalib import frames_extraction_tools
from plato.datasources.datalib import audio_extraction_tools
from plato.datasources.datalib import data_utils


class GymDataset(multimodal_base.MultiModalDataset):
    """Prepare the Gym dataset."""

    def __init__(
        self, multimodal_data_holder, phase, phase_info, modality_sampler=None
    ):
        super().__init__()
        self.phase = phase
        # multimodal_data_holder is a dict:
        # {"rgb": rgb_dataset, "flow": flow_dataset, "audio": audio_dataset}
        self.phase_multimodal_data_record = multimodal_data_holder

        # a dict presented as:
        # "rgb": <rgb_annotation_file_path>
        self.phase_info = phase_info

        self.modalities_name = list(multimodal_data_holder.keys())

        self.supported_modalities = ["rgb", "flow", "audio_feature"]

        # utilize the full set of modalities by default
        if modality_sampler is None:
            self.modality_sampler = self.supported_modalities
        else:
            self.modality_sampler = modality_sampler

        self.targets = self.get_targets()

    def __len__(self):
        return len(self.phase_multimodal_data_record)

    def get_targets(self):
        """Obtain the labels of samples in the current phase dataset."""
        # No labels are provided in the FineGym dataset currently;
        # this part will be added afterward.
        return [0]

    def get_one_multimodal_sample(self, sample_idx):
        """Obtain one sample from the Gym dataset."""
        obtained_mm_sample = dict()

        for modality_name in self.modalities_name:
            modality_dataset = self.phase_multimodal_data_record[modality_name]
            obtained_mm_sample[modality_name] = modality_dataset[sample_idx]

        return obtained_mm_sample


class DataSource(multimodal_base.MultiModalDataSource):
    """The Gym dataset."""

    def __init__(self, **kwargs):
        super().__init__()

        self.data_name = Config().data.datasource

        # the raw frames contain both "flow" and "rgb";
        # thus, the flow and rgb data are put in the same directory rawframes/
        # self.modality_names = ["video", "audio", "rawframes", "audio_feature"]
        self.modality_names = ["video", "audio", "rgb", "flow", "audio_feature"]

        _path = Config().params["data_path"]
        self._data_path_process(data_path=_path, base_data_name=self.data_name)
        self._create_modalities_path(modality_names=self.modality_names)

        base_data_path = self.mm_data_info["data_path"]
        # define all the directories here
        gym_anno_dir_name = "annotations"
        self.data_annotation_path = os.path.join(base_data_path, gym_anno_dir_name)

        self.data_anno_file_path = os.path.join(
            self.data_annotation_path, "annotation.json"
        )
        self.category_anno_file_path = os.path.join(
            self.data_annotation_path, "gym99_categories.txt"
        )

        self.raw_videos_path = os.path.join(base_data_path, "videos")
        self.event__path = os.path.join(base_data_path, "event")
        self.event_subsection__path = os.path.join(base_data_path, "subactions")
        self.data_event_anno_file_path = os.path.join(
            self.data_annotation_path, "event_annotation.json"
        )
        self.event_subsection_frames__path = os.path.join(
            base_data_path, "subaction_rawframes"
        )
        self.event_subsection_audios__path = os.path.join(
            base_data_path, "subaction_audios"
        )

        self.event_subsection_audios_fea__path = os.path.join(
            base_data_path, "subaction_audios_features"
        )

        self.rawframes_splits_list_files_into = {
            "train": os.path.join(
                self.data_annotation_path, "gym99_train_rawframes.txt"
            ),
            "val": os.path.join(self.data_annotation_path, "gym99_val_rawframes.txt"),
        }

        self.audios_splits_list_files_into = {
            "train": os.path.join(self.data_annotation_path, "gym99_train_audios.txt"),
            "val": os.path.join(self.data_annotation_path, "gym99_val_audios.txt"),
        }
        self.audio_features_splits_list_files_into = {
            "train": os.path.join(
                self.data_annotation_path, "gym99_train_audio_features.txt"
            ),
            "val": os.path.join(
                self.data_annotation_path, "gym99_val_audio_features.txt"
            ),
        }

        set_level_category_url = (
            "https://sdolivia.github.io/FineGym/resources/dataset/set_categories.txt"
        )
        g99_category_url = (
            "https://sdolivia.github.io/FineGym/resources/dataset/gym99_categories.txt"
        )

        anno_url = "https://sdolivia.github.io/FineGym/resources/dataset/finegym_annotation_info_v1.0.json"

        train_url = "https://sdolivia.github.io/FineGym/resources/dataset/gym99_train_element_v1.0.txt"

        eval_url = (
            "https://sdolivia.github.io/FineGym/resources/dataset/gym99_val_element.txt"
        )

        _ = self._download_arrange_data(
            download_url_address=set_level_category_url,
            data_path=self.data_annotation_path,
            obtained_file_name="set_categories.txt",
        )

        _ = self._download_arrange_data(
            download_url_address=g99_category_url,
            data_path=self.data_annotation_path,
            obtained_file_name="gym99_categories.txt",
        )

        _ = self._download_arrange_data(
            download_url_address=anno_url,
            data_path=self.data_annotation_path,
            obtained_file_name="annotation.json",
        )

        _ = self._download_arrange_data(
            download_url_address=train_url,
            data_path=self.data_annotation_path,
            obtained_file_name="gym99_train_org.txt",
        )

        _ = self._download_arrange_data(
            download_url_address=eval_url,
            data_path=self.data_annotation_path,
            obtained_file_name="gym99_val_org.txt",
        )

        if not self._exists(self.raw_videos_path):
            logging.info(
                "Downloading the raw videos for the Gym dataset. This may take a long time."
            )

            gym_downloader.main(
                input=self.data_anno_file_path,
                output_dir=self.raw_videos_path,
                num_jobs=Config().data.downloader.num_workers,
            )
            logging.info("Done.")

        # trim videos into events
        if not self._exists(self.event__path):
            gym_trim.trim_event(
                video_root=self.raw_videos_path,
                anno_file=self.data_anno_file_path,
                event_anno_file=self.data_event_anno_file_path,
                event_root=self.event__path,
            )
        if not self._exists(self.event_subsection__path):
            gym_trim.trim_subsection(
                event_anno_file=self.data_event_anno_file_path,
                event_root=self.event__path,
                subaction_root=self.event_subsection__path,
            )

        logging.info("The Gym dataset has been prepared.")
        self.extract_videos_rgb_flow_audio()

    def extract_videos_rgb_flow_audio(self):
        """Extract RGB frames, optical flow, and audio from the videos."""
        src_videos_dir = self.event_subsection__path
        frames_out__path = self.event_subsection_frames__path
        rgb_out__path = self.event_subsection_frames__path
        flow_out__path = self.event_subsection_frames__path
        audio_out__path = self.event_subsection_audios__path
        audio_feature__path = self.event_subsection_audios_fea__path

        # define the modality extractors
        vdf_extractor = frames_extraction_tools.VideoFramesExtractor(
            video_src_dir=src_videos_dir,
            dir_level=1,
            num_worker=8,
            video_ext="mp4",
            mixed_ext=False,
        )
        vda_extractor = audio_extraction_tools.VideoAudioExtractor(
            video_src_dir=src_videos_dir,
            dir_level=1,
            num_worker=8,
            video_ext="mp4",
            mixed_ext=False,
        )

        if torch.cuda.is_available():
            if not self._exists(rgb_out__path) and not self._exists(flow_out__path):
                logging.info(
                    "Extracting frames by GPU from videos in %s to %s.",
                    src_videos_dir,
                    rgb_out__path,
                )
                vdf_extractor.build_full_frames_gpu(
                    to__path=frames_out__path, new_short=256, new_width=0, new_height=0
                )
        else:
            if not self._exists(rgb_out__path):
                logging.info(
                    "Extracting frames by CPU from videos in %s to %s.",
                    src_videos_dir,
                    rgb_out__path,
                )
                vdf_extractor.build_frames_cpu(to_dir=frames_out__path)

        if not self._exists(audio_out__path):
            logging.info(
                "Extracting audio by CPU from videos in %s to %s.",
                src_videos_dir,
                audio_out__path,
            )
            vda_extractor.build_audios(to_dir=audio_out__path)

        if not self._exists(audio_feature__path):
            logging.info(
                "Extracting audio features by CPU from audio in %s to %s.",
                audio_out__path,
                audio_feature__path,
            )
            # window_size: 32 ms, hop_size: 16 ms
            vda_extractor.build_audios_features(
                audio_src_path=audio_out__path,
                to_dir=audio_feature__path,
                fft_size=512,  # fft_size / sample_rate is the window size
                hop_size=256,
            )
        # extract the splits data into list files based on the frames information
        gym_trim.generate_splits_list(
            data_root=self.event_subsection__path,
            annotation_root=self.data_annotation_path,
            frame_data_root=frames_out__path,
        )

        # generate the audio and audio feature split files by
        # copying the frame split files to the audio ones
        for split in list(self.rawframes_splits_list_files_into.keys()):
            rawframes_split_file_path = self.rawframes_splits_list_files_into[split]
            audios_split_file_path = self.audios_splits_list_files_into[split]
            audio_features_split_file_path = self.audio_features_splits_list_files_into[
                split
            ]
            shutil.copy(src=rawframes_split_file_path, dst=audios_split_file_path)
            shutil.copy(
                src=rawframes_split_file_path, dst=audio_features_split_file_path
            )

    def correct_current_config(self, loaded_plato_config, mode, modality_name):
        """Correct the loaded configuration settings based on
        the on-hand data information."""

        # 1.1. convert the Plato config to the dict type
        loaded_config = data_utils.config_to_dict(loaded_plato_config)
        # 1.2. convert lists to tuples
        loaded_config = data_utils.dict_list2tuple(loaded_config)

        # 2. replace the user-set annotation files in the configuration
        # file with the obtained ones, mainly because the paths obtained
        # here are full paths
        cur_rawframes_anno_file_path = self.rawframes_splits_list_files_into[mode]
        cur_rawframes_data_path = self.event_subsection_frames__path
        cur_videos_anno_file_path = None
        cur_video_data_path = self.event_subsection__path
        cur_audio_feas_anno_file_path = self.audio_features_splits_list_files_into[mode]
        cur_audio_feas_data_path = self.event_subsection_audios__path

        if modality_name in ("rgb", "flow"):
            loaded_config["ann_file"] = cur_rawframes_anno_file_path
        elif modality_name == "audio_feature":
            loaded_config["ann_file"] = cur_audio_feas_anno_file_path
        else:
            loaded_config["ann_file"] = cur_videos_anno_file_path

        # 3. reset the data_prefix by using the modality path
        if modality_name in ("rgb", "flow"):
            loaded_config["data_prefix"] = cur_rawframes_data_path
        elif modality_name == "audio_feature":
            loaded_config["data_prefix"] = cur_audio_feas_data_path
        else:
            loaded_config["data_prefix"] = cur_video_data_path

        return loaded_config

    def get_phase_dataset(self, phase, modality_sampler):
        """Get the dataset for the specific phase."""
        rgb_mode_config = getattr(Config().data.multi_modal_configs.rgb, phase)
        flow_mode_config = getattr(Config().data.multi_modal_configs.flow, phase)
        audio_feature_mode_config = getattr(
            Config().data.multi_modal_configs.audio_feature, phase
        )

        rgb_mode_config = self.correct_current_config(
            loaded_plato_config=rgb_mode_config, mode=phase, modality_name="rgb"
        )
        flow_mode_config = self.correct_current_config(
            loaded_plato_config=flow_mode_config, mode=phase, modality_name="flow"
        )
        audio_feature_mode_config = self.correct_current_config(
            loaded_plato_config=audio_feature_mode_config,
            mode=phase,
            modality_name="audio_feature",
        )
        # build a RawframeDataset for each modality
        rgb_mode_dataset = build_dataset(rgb_mode_config)
        flow_mode_dataset = build_dataset(flow_mode_config)
        audio_feature_mode_dataset = build_dataset(audio_feature_mode_config)

        multi_modal_mode_data = {
            "rgb": rgb_mode_dataset,
            "flow": flow_mode_dataset,
            "audio_feature": audio_feature_mode_dataset,
        }

        multi_modal_mode_info = {
            "rgb": rgb_mode_config["ann_file"],
            "flow": flow_mode_config["ann_file"],
            "audio_feature": audio_feature_mode_config["ann_file"],
            "categories": self.category_anno_file_path,
        }

        gym_mode_dataset = GymDataset(
            multimodal_data_holder=multi_modal_mode_data,
            phase=phase,
            phase_info=multi_modal_mode_info,
            modality_sampler=modality_sampler,
        )

        return gym_mode_dataset

    def get_train_set(self, modality_sampler=None):
        """Obtain the trainset for multimodal data."""
        gym_train_dataset = self.get_phase_dataset(
            phase="train", modality_sampler=modality_sampler
        )

        return gym_train_dataset

    def get_test_set(self, modality_sampler=None):
        """Obtain the testset for multimodal data.

        Note: in the Gym dataset, there is no test set in which
        samples contain the ground-truth label.
        Thus, we utilize the validation set directly.
        """
        gym_val_dataset = self.get_phase_dataset(
            phase="val", modality_sampler=modality_sampler
        )

        return gym_val_dataset

    def get_modality_name(self):
        """Get all supported modalities."""
        return ["rgb", "flow", "audio"]
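For orientation, here is a minimal usage sketch; it is not part of the package. It assumes a Plato configuration file has already been loaded that defines data.datasource, data.downloader.num_workers, and the data.multi_modal_configs entries for rgb, flow, and audio_feature that the class reads above.

# Hypothetical usage sketch (assumptions as stated above; not in the wheel).
from plato.datasources import gym

datasource = gym.DataSource()          # downloads, trims, and extracts all modalities
trainset = datasource.get_train_set()  # GymDataset over rgb/flow/audio_feature
valset = datasource.get_test_set()     # backed by the validation split
print(len(trainset), datasource.get_modality_name())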
plato/datasources/huggingface.py
ADDED
@@ -0,0 +1,165 @@
"""
A data source for the HuggingFace datasets.

For more information about the HuggingFace datasets, refer to:

https://huggingface.co/docs/datasets/quicktour.html
"""

import logging
import os

from datasets import load_dataset, load_from_disk
from transformers import AutoConfig, AutoTokenizer, HfArgumentParser
from transformers import TrainingArguments, testing_utils, utils

from plato.config import Config
from plato.datasources import base


class DataSource(base.DataSource):
    """A data source for the HuggingFace datasets."""

    def __init__(self, **kwargs):
        super().__init__()

        dataset_name = Config().data.dataset_name
        logging.info("Dataset: %s", dataset_name)

        if hasattr(Config().data, "dataset_config"):
            dataset_config = Config().data.dataset_config
        else:
            dataset_config = None

        saved_data_path = (
            f"{Config().params['data_path']}/{dataset_name}_{dataset_config}"
        )

        if os.path.exists(saved_data_path):
            # the dataset has already been downloaded and saved
            self.dataset = load_from_disk(saved_data_path)
        else:
            # download and save the dataset
            self.dataset = load_dataset(dataset_name, dataset_config)
            self.dataset.save_to_disk(saved_data_path)

        parser = HfArgumentParser(TrainingArguments)
        (self.training_args,) = parser.parse_args_into_dataclasses(
            args=["--output_dir=/tmp", "--report_to=none"]
        )

        model_name = Config().trainer.model_name
        use_auth_token = None
        if hasattr(Config().parameters, "huggingface_token"):
            use_auth_token = Config().parameters.huggingface_token
        config_kwargs = {
            "cache_dir": Config().params["model_path"],
            "revision": "main",
            "use_auth_token": use_auth_token,
        }
        tokenizer_kwargs = {
            "cache_dir": Config().params["data_path"],
            "use_fast": True,
            "revision": "main",
            "use_auth_token": use_auth_token,
        }

        self.config = AutoConfig.from_pretrained(model_name, **config_kwargs)

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name, config=self.config, **tokenizer_kwargs
        )
        self.tok_logger = utils.logging.get_logger(
            "transformers.tokenization_utils_base"
        )

        self.block_size = 128

        self.column_names = ["text"]
        self.text_column_name = "text"
        self.trainset = self.preprocess_data(self.dataset["train"])
        self.testset = self.preprocess_data(self.dataset["validation"])

    def num_train_examples(self):
        return len(self.trainset)

    def num_test_examples(self):
        return len(self.testset)

    def get_train_set(self):
        return self.trainset

    def get_test_set(self):
        return self.testset

    @staticmethod
    def input_shape():
        """Returns the input shape of the dataset, useful for building
        a TF model."""
        raise ValueError("Not implemented.")

    def tokenize_function(self, examples):
        """Use the tokenizer from AutoTokenizer to tokenize the text."""
        with testing_utils.CaptureLogger(self.tok_logger) as cl:
            output = self.tokenizer(examples[self.text_column_name])
        # CLM inputs can be much longer than block_size
        if "Token indices sequence length is longer than the" in cl.out:
            self.tok_logger.warning(
                "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be "
                "chunked into smaller bits before being passed to the model."
            )
        return output

    def group_texts(self, examples):
        """Concatenate all texts."""
        concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}

        total_length = len(concatenated_examples[list(examples.keys())[0]])

        # We drop the small remainder; we could add padding instead if the
        # model supported it. You can customize this part to your needs.
        total_length = (total_length // self.block_size) * self.block_size

        # Split by chunks of max_len.
        result = {
            k: [
                t[i : i + self.block_size]
                for i in range(0, total_length, self.block_size)
            ]
            for k, t in concatenated_examples.items()
        }

        result["labels"] = result["input_ids"].copy()
        return result

    def preprocess_data(self, datasets):
        """Tokenize and group the raw dataset."""
        with self.training_args.main_process_first(desc="dataset map tokenization"):
            tokenized_datasets = datasets.map(
                self.tokenize_function,
                batched=True,
                num_proc=4,
                remove_columns=self.column_names,
                load_from_cache_file=True,
                desc="Running tokenizer on dataset",
            )

        block_size = self.tokenizer.model_max_length
        if block_size > 1024:
            logging.warning(
                "The tokenizer picked seems to have a very large `model_max_length` "
                "%s. Picking 1024 instead.",
                self.tokenizer.model_max_length,
            )
            block_size = 1024

        with self.training_args.main_process_first(desc="grouping texts together"):
            lm_datasets = tokenized_datasets.map(
                self.group_texts,
                batched=True,
                num_proc=4,
                load_from_cache_file=True,
                desc=f"Grouping texts in chunks of {block_size}",
            )

        return lm_datasets
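Likewise, a minimal sketch of how this data source might be exercised; it is not part of the package. It assumes the loaded Plato configuration defines data.dataset_name (and optionally data.dataset_config) plus trainer.model_name for the tokenizer; none of these values appear in the diff itself.

# Hypothetical usage sketch (assumptions as stated above; not in the wheel).
from plato.datasources import huggingface

datasource = huggingface.DataSource()   # loads, tokenizes, and groups the corpus
print(datasource.num_train_examples())  # number of block_size-token chunks
sample = datasource.get_train_set()[0]  # dict with input_ids, attention_mask, labels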