plato-learn 1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (179) hide show
  1. plato/__init__.py +1 -0
  2. plato/algorithms/__init__.py +0 -0
  3. plato/algorithms/base.py +45 -0
  4. plato/algorithms/fedavg.py +48 -0
  5. plato/algorithms/fedavg_gan.py +79 -0
  6. plato/algorithms/fedavg_personalized.py +48 -0
  7. plato/algorithms/mistnet.py +52 -0
  8. plato/algorithms/registry.py +39 -0
  9. plato/algorithms/split_learning.py +89 -0
  10. plato/callbacks/__init__.py +0 -0
  11. plato/callbacks/client.py +56 -0
  12. plato/callbacks/handler.py +78 -0
  13. plato/callbacks/server.py +139 -0
  14. plato/callbacks/trainer.py +124 -0
  15. plato/client.py +67 -0
  16. plato/clients/__init__.py +0 -0
  17. plato/clients/base.py +467 -0
  18. plato/clients/edge.py +103 -0
  19. plato/clients/fedavg_personalized.py +40 -0
  20. plato/clients/mistnet.py +49 -0
  21. plato/clients/registry.py +43 -0
  22. plato/clients/self_supervised_learning.py +51 -0
  23. plato/clients/simple.py +218 -0
  24. plato/clients/split_learning.py +150 -0
  25. plato/config.py +339 -0
  26. plato/datasources/__init__.py +0 -0
  27. plato/datasources/base.py +123 -0
  28. plato/datasources/celeba.py +150 -0
  29. plato/datasources/cifar10.py +87 -0
  30. plato/datasources/cifar100.py +61 -0
  31. plato/datasources/cinic10.py +62 -0
  32. plato/datasources/coco.py +119 -0
  33. plato/datasources/datalib/__init__.py +0 -0
  34. plato/datasources/datalib/audio_extraction_tools.py +137 -0
  35. plato/datasources/datalib/data_utils.py +124 -0
  36. plato/datasources/datalib/flickr30kE_utils.py +336 -0
  37. plato/datasources/datalib/frames_extraction_tools.py +254 -0
  38. plato/datasources/datalib/gym_utils/__init__.py +0 -0
  39. plato/datasources/datalib/gym_utils/gym_trim.py +189 -0
  40. plato/datasources/datalib/modality_data_anntation_tools.py +163 -0
  41. plato/datasources/datalib/modality_extraction_base.py +59 -0
  42. plato/datasources/datalib/parse_datasets.py +212 -0
  43. plato/datasources/datalib/refer_utils/__init__.py +0 -0
  44. plato/datasources/datalib/refer_utils/referitgame_utils.py +237 -0
  45. plato/datasources/datalib/tiny_data_tools.py +81 -0
  46. plato/datasources/datalib/video_transform.py +79 -0
  47. plato/datasources/emnist.py +64 -0
  48. plato/datasources/fashion_mnist.py +41 -0
  49. plato/datasources/feature.py +24 -0
  50. plato/datasources/feature_dataset.py +15 -0
  51. plato/datasources/femnist.py +141 -0
  52. plato/datasources/flickr30k_entities.py +362 -0
  53. plato/datasources/gym.py +431 -0
  54. plato/datasources/huggingface.py +165 -0
  55. plato/datasources/kinetics.py +568 -0
  56. plato/datasources/mnist.py +44 -0
  57. plato/datasources/multimodal_base.py +328 -0
  58. plato/datasources/pascal_voc.py +56 -0
  59. plato/datasources/purchase.py +94 -0
  60. plato/datasources/qoenflx.py +127 -0
  61. plato/datasources/referitgame.py +330 -0
  62. plato/datasources/registry.py +119 -0
  63. plato/datasources/self_supervised_learning.py +98 -0
  64. plato/datasources/stl10.py +103 -0
  65. plato/datasources/texas.py +94 -0
  66. plato/datasources/tiny_imagenet.py +64 -0
  67. plato/datasources/yolov8.py +85 -0
  68. plato/models/__init__.py +0 -0
  69. plato/models/cnn_encoder.py +103 -0
  70. plato/models/dcgan.py +116 -0
  71. plato/models/general_multilayer.py +254 -0
  72. plato/models/huggingface.py +27 -0
  73. plato/models/lenet5.py +113 -0
  74. plato/models/multilayer.py +90 -0
  75. plato/models/multimodal/__init__.py +0 -0
  76. plato/models/multimodal/base_net.py +91 -0
  77. plato/models/multimodal/blending.py +142 -0
  78. plato/models/multimodal/fc_net.py +77 -0
  79. plato/models/multimodal/fusion_net.py +78 -0
  80. plato/models/multimodal/multimodal_module.py +152 -0
  81. plato/models/registry.py +99 -0
  82. plato/models/resnet.py +190 -0
  83. plato/models/torch_hub.py +19 -0
  84. plato/models/vgg.py +113 -0
  85. plato/models/vit.py +166 -0
  86. plato/models/yolov8.py +22 -0
  87. plato/processors/__init__.py +0 -0
  88. plato/processors/base.py +35 -0
  89. plato/processors/compress.py +46 -0
  90. plato/processors/decompress.py +48 -0
  91. plato/processors/feature.py +51 -0
  92. plato/processors/feature_additive_noise.py +48 -0
  93. plato/processors/feature_dequantize.py +34 -0
  94. plato/processors/feature_gaussian.py +17 -0
  95. plato/processors/feature_laplace.py +15 -0
  96. plato/processors/feature_quantize.py +34 -0
  97. plato/processors/feature_randomized_response.py +50 -0
  98. plato/processors/feature_unbatch.py +39 -0
  99. plato/processors/inbound_feature_tensors.py +39 -0
  100. plato/processors/model.py +55 -0
  101. plato/processors/model_compress.py +34 -0
  102. plato/processors/model_decompress.py +37 -0
  103. plato/processors/model_decrypt.py +41 -0
  104. plato/processors/model_deepcopy.py +21 -0
  105. plato/processors/model_dequantize.py +18 -0
  106. plato/processors/model_dequantize_qsgd.py +61 -0
  107. plato/processors/model_encrypt.py +43 -0
  108. plato/processors/model_quantize.py +18 -0
  109. plato/processors/model_quantize_qsgd.py +82 -0
  110. plato/processors/model_randomized_response.py +34 -0
  111. plato/processors/outbound_feature_ndarrays.py +38 -0
  112. plato/processors/pipeline.py +26 -0
  113. plato/processors/registry.py +124 -0
  114. plato/processors/structured_pruning.py +57 -0
  115. plato/processors/unstructured_pruning.py +73 -0
  116. plato/samplers/__init__.py +0 -0
  117. plato/samplers/all_inclusive.py +41 -0
  118. plato/samplers/base.py +31 -0
  119. plato/samplers/dirichlet.py +81 -0
  120. plato/samplers/distribution_noniid.py +132 -0
  121. plato/samplers/iid.py +53 -0
  122. plato/samplers/label_quantity_noniid.py +119 -0
  123. plato/samplers/mixed.py +44 -0
  124. plato/samplers/mixed_label_quantity_noniid.py +128 -0
  125. plato/samplers/modality_iid.py +42 -0
  126. plato/samplers/modality_quantity_noniid.py +56 -0
  127. plato/samplers/orthogonal.py +99 -0
  128. plato/samplers/registry.py +66 -0
  129. plato/samplers/sample_quantity_noniid.py +123 -0
  130. plato/samplers/sampler_utils.py +190 -0
  131. plato/servers/__init__.py +0 -0
  132. plato/servers/base.py +1395 -0
  133. plato/servers/fedavg.py +281 -0
  134. plato/servers/fedavg_cs.py +335 -0
  135. plato/servers/fedavg_gan.py +74 -0
  136. plato/servers/fedavg_he.py +106 -0
  137. plato/servers/fedavg_personalized.py +57 -0
  138. plato/servers/mistnet.py +67 -0
  139. plato/servers/registry.py +52 -0
  140. plato/servers/split_learning.py +109 -0
  141. plato/trainers/__init__.py +0 -0
  142. plato/trainers/base.py +99 -0
  143. plato/trainers/basic.py +649 -0
  144. plato/trainers/diff_privacy.py +178 -0
  145. plato/trainers/gan.py +330 -0
  146. plato/trainers/huggingface.py +173 -0
  147. plato/trainers/loss_criterion.py +70 -0
  148. plato/trainers/lr_schedulers.py +252 -0
  149. plato/trainers/optimizers.py +53 -0
  150. plato/trainers/pascal_voc.py +80 -0
  151. plato/trainers/registry.py +44 -0
  152. plato/trainers/self_supervised_learning.py +302 -0
  153. plato/trainers/split_learning.py +305 -0
  154. plato/trainers/tracking.py +96 -0
  155. plato/trainers/yolov8.py +41 -0
  156. plato/utils/__init__.py +0 -0
  157. plato/utils/count_parameters.py +30 -0
  158. plato/utils/csv_processor.py +26 -0
  159. plato/utils/data_loaders.py +148 -0
  160. plato/utils/decorators.py +24 -0
  161. plato/utils/fonts.py +23 -0
  162. plato/utils/homo_enc.py +187 -0
  163. plato/utils/reinforcement_learning/__init__.py +0 -0
  164. plato/utils/reinforcement_learning/policies/__init__.py +0 -0
  165. plato/utils/reinforcement_learning/policies/base.py +161 -0
  166. plato/utils/reinforcement_learning/policies/ddpg.py +75 -0
  167. plato/utils/reinforcement_learning/policies/registry.py +32 -0
  168. plato/utils/reinforcement_learning/policies/sac.py +343 -0
  169. plato/utils/reinforcement_learning/policies/td3.py +485 -0
  170. plato/utils/reinforcement_learning/rl_agent.py +142 -0
  171. plato/utils/reinforcement_learning/rl_server.py +113 -0
  172. plato/utils/rl_env.py +154 -0
  173. plato/utils/s3.py +141 -0
  174. plato/utils/trainer_utils.py +21 -0
  175. plato/utils/unary_encoding.py +47 -0
  176. plato_learn-1.1.dist-info/METADATA +35 -0
  177. plato_learn-1.1.dist-info/RECORD +179 -0
  178. plato_learn-1.1.dist-info/WHEEL +4 -0
  179. plato_learn-1.1.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,431 @@
1
+ """
2
+ The Gym dataset.
3
+
4
+ Note that the setting for the data loader is obtained from the github repo provided
5
+ by the authors of the paper:
6
+ Finegym: A hierarchical video dataset for fine-grained action understanding
7
+
8
+ The data structure should be:
9
+
10
+ ├── data
11
+ │ ├── gym99
12
+ | | ├── annotations
13
+ | | | ├── gym99_train_org.txt
14
+ | | | ├── gym99_val_org.txt
15
+ | | | ├── gym99_train.txt
16
+ | | | ├── gym99_val.txt
17
+ | | | ├── annotation.json
18
+ | | | └── event_annotation.json
19
+ │ │ ├── videos
20
+ | | | ├── 0LtLS9wROrk.mp4
21
+ | | | ├── ...
22
+ | | | └── zfqS-wCJSsw.mp4
23
+ │ │ ├── events
24
+ | | | ├── 0LtLS9wROrk_E_002407_002435.mp4
25
+ | | | ├── ...
26
+ | | | └── zfqS-wCJSsw_E_006732_006824.mp4
27
+ │ │ ├── subactions
28
+ | | | ├── 0LtLS9wROrk_E_002407_002435_A_0003_0005.mp4
29
+ | | | ├── ...
30
+ | | | └── zfqS-wCJSsw_E_006244_006252_A_0000_0007.mp4
31
+ | | └── subaction_frames
32
+ | | |── subaction_audios
33
+
34
+ """
35
+
36
+ import logging
37
+ import os
38
+ import shutil
39
+
40
+ import torch
41
+
42
+ from mmaction.tools.data.gym import download as gym_downloader
43
+ from mmaction.datasets import build_dataset
44
+
45
+ from plato.config import Config
46
+ from plato.datasources.datalib.gym_utils import gym_trim
47
+ from plato.datasources import multimodal_base
48
+ from plato.datasources.datalib import frames_extraction_tools
49
+ from plato.datasources.datalib import audio_extraction_tools
50
+ from plato.datasources.datalib import data_utils
51
+
52
+
53
class GymDataset(multimodal_base.MultiModalDataset):
    """Multimodal view over the per-modality Gym datasets of one phase.

    Each entry of ``multimodal_data_holder`` is an indexable dataset for a
    single modality; a multimodal sample is built by reading the same index
    from every one of them.
    """

    def __init__(
        self, multimodal_data_holder, phase, phase_info, modality_sampler=None
    ):
        """Store the per-modality datasets and their annotation information.

        Arguments:
            multimodal_data_holder: dict mapping a modality name to its
                dataset, e.g. {"rgb": rgb_dataset, "flow": flow_dataset,
                "audio_feature": audio_feature_dataset}.
            phase: name of the phase this dataset belongs to.
            phase_info: dict mapping a modality name to the path of its
                annotation file.
            modality_sampler: the modalities to be used; all supported
                modalities are used when it is None.
        """
        super().__init__()
        self.phase = phase
        self.phase_multimodal_data_record = multimodal_data_holder
        self.phase_info = phase_info

        self.modalities_name = list(multimodal_data_holder.keys())

        self.supported_modalities = ["rgb", "flow", "audio_feature"]

        # default utilizing the full modalities
        self.modality_sampler = (
            self.supported_modalities if modality_sampler is None else modality_sampler
        )

        self.targets = self.get_targets()

    def __len__(self):
        """Return the number of samples that can be read from every modality.

        The record is a dict keyed by modality, so its own length is the
        number of modalities rather than the number of samples. Because
        get_one_multimodal_sample() reads the same index from each modality
        dataset, the usable length is that of the shortest one.
        """
        return min(
            (
                len(modality_dataset)
                for modality_dataset in self.phase_multimodal_data_record.values()
            ),
            default=0,
        )

    def get_targets(self):
        """Obtain the labels of samples in current phase dataset."""
        # There is no label provided in the fine gym dataset currently
        # This part will be added afterward
        return [0]

    def get_one_multimodal_sample(self, sample_idx):
        """Obtain one sample from the Gym dataset.

        Returns a dict mapping every modality name to the item stored at
        ``sample_idx`` in that modality's dataset.
        """
        return {
            modality_name: self.phase_multimodal_data_record[modality_name][sample_idx]
            for modality_name in self.modalities_name
        }
99
+
100
+
101
class DataSource(multimodal_base.MultiModalDataSource):
    """The Gym dataset.

    On construction the annotation files are fetched, the raw videos are
    downloaded and trimmed into events and sub-actions when the corresponding
    directories are missing, and the rgb/flow frames, audios and audio
    features are extracted from the sub-action videos.
    """

    def __init__(self, **kwargs):
        super().__init__()

        self.data_name = Config().data.datasource

        # the rawframes contains the "flow", "rgb"
        # thus, the flow and rgb will be put in in same directory rawframes/
        self.modality_names = ["video", "audio", "rgb", "flow", "audio_feature"]

        _path = Config().params["data_path"]
        self._data_path_process(data_path=_path, base_data_name=self.data_name)
        self._create_modalities_path(modality_names=self.modality_names)

        base_data_path = self.mm_data_info["data_path"]

        # define all the dir here
        gym_anno_dir_name = "annotations"
        self.data_annotation_path = os.path.join(base_data_path, gym_anno_dir_name)

        self.data_anno_file_path = os.path.join(
            self.data_annotation_path, "annotation.json"
        )
        self.categoty_anno_file_path = os.path.join(
            self.data_annotation_path, "gym99_categories.txt"
        )

        self.raw_videos_path = os.path.join(base_data_path, "videos")
        self.event__path = os.path.join(base_data_path, "event")
        self.event_subsection__path = os.path.join(base_data_path, "subactions")
        self.data_event_anno_file_path = os.path.join(
            self.data_annotation_path, "event_annotation.json"
        )
        self.event_subsection_frames__path = os.path.join(
            base_data_path, "subaction_rawframes"
        )
        self.event_subsection_audios__path = os.path.join(
            base_data_path, "subaction_audios"
        )
        self.event_subsection_audios_fea__path = os.path.join(
            base_data_path, "subaction_audios_features"
        )

        # Per-split list files, one family per modality group.
        self.rawframes_splits_list_files_into = {
            "train": os.path.join(
                self.data_annotation_path, "gym99_train_rawframes.txt"
            ),
            "val": os.path.join(self.data_annotation_path, "gym99_val_rawframes.txt"),
        }
        self.audios_splits_list_files_into = {
            "train": os.path.join(self.data_annotation_path, "gym99_train_audios.txt"),
            "val": os.path.join(self.data_annotation_path, "gym99_val_audios.txt"),
        }
        self.audio_features_splits_list_files_into = {
            "train": os.path.join(
                self.data_annotation_path, "gym99_train_audio_features.txt"
            ),
            "val": os.path.join(
                self.data_annotation_path, "gym99_val_audio_features.txt"
            ),
        }

        # (remote file name, local file name) pairs fetched, in this order,
        # from the FineGym resources page into the annotation directory.
        finegym_resources_url = "https://sdolivia.github.io/FineGym/resources/dataset/"
        annotation_downloads = (
            ("set_categories.txt", "set_categories.txt"),
            ("gym99_categories.txt", "gym99_categories.txt"),
            ("finegym_annotation_info_v1.0.json", "annotation.json"),
            ("gym99_train_element_v1.0.txt", "gym99_train_org.txt"),
            ("gym99_val_element.txt", "gym99_val_org.txt"),
        )
        for remote_name, local_name in annotation_downloads:
            self._download_arrange_data(
                download_url_address=finegym_resources_url + remote_name,
                data_path=self.data_annotation_path,
                obtained_file_name=local_name,
            )

        if not self._exists(self.raw_videos_path):
            logging.info(
                "Downloading the raw videos for the Gym dataset. This may take a long time."
            )

            gym_downloader.main(
                input=self.data_anno_file_path,
                output_dir=self.raw_videos_path,
                num_jobs=Config().data.downloader.num_workers,
            )
            logging.info("Done.")

        # Trim Videos into Events
        if not self._exists(self.event__path):
            gym_trim.trim_event(
                video_root=self.raw_videos_path,
                anno_file=self.data_anno_file_path,
                event_anno_file=self.data_event_anno_file_path,
                event_root=self.event__path,
            )
        # Trim Events into sub-actions
        if not self._exists(self.event_subsection__path):
            gym_trim.trim_subsection(
                event_anno_file=self.data_event_anno_file_path,
                event_root=self.event__path,
                subaction_root=self.event_subsection__path,
            )

        logging.info("The Gym dataset has been prepared")
        self.extract_videos_rgb_flow_audio()

    def extract_videos_rgb_flow_audio(self):
        """Extract the rgb frames, optical flow and audios from the videos.

        Every extraction step is skipped when its output directory already
        exists. The split list files are regenerated on every call, and the
        rawframes list of each split is then copied to the audio and the
        audio-feature list files of that split.
        """
        src_videos_dir = self.event_subsection__path
        # rgb and flow frames share the same output directory
        frames_out__path = self.event_subsection_frames__path
        rgb_out__path = self.event_subsection_frames__path
        flow_out__path = self.event_subsection_frames__path
        audio_out__path = self.event_subsection_audios__path
        audio_feature__path = self.event_subsection_audios_fea__path

        # define the modalities extractor
        vdf_extractor = frames_extraction_tools.VideoFramesExtractor(
            video_src_dir=src_videos_dir,
            dir_level=1,
            num_worker=8,
            video_ext="mp4",
            mixed_ext=False,
        )
        vda_extractor = audio_extraction_tools.VideoAudioExtractor(
            video_src_dir=src_videos_dir,
            dir_level=1,
            num_worker=8,
            video_ext="mp4",
            mixed_ext=False,
        )

        if torch.cuda.is_available():
            if not self._exists(rgb_out__path) and not self._exists(flow_out__path):
                logging.info(
                    "Extracting frames by GPU from videos in %s to %s.",
                    src_videos_dir,
                    rgb_out__path,
                )
                vdf_extractor.build_full_frames_gpu(
                    to__path=frames_out__path, new_short=256, new_width=0, new_height=0
                )
        else:
            if not self._exists(rgb_out__path):
                logging.info(
                    "Extracting frames by CPU from videos in %s to %s.",
                    src_videos_dir,
                    rgb_out__path,
                )
                vdf_extractor.build_frames_cpu(to_dir=frames_out__path)

        if not self._exists(audio_out__path):
            logging.info(
                "Extracting audios by CPU from videos in %s to %s.",
                src_videos_dir,
                audio_out__path,
            )
            vda_extractor.build_audios(to_dir=audio_out__path)

        if not self._exists(audio_feature__path):
            logging.info(
                "Extracting audios feature by CPU from audios in %s to %s.",
                audio_out__path,
                audio_feature__path,
            )
            # window_size:32ms hop_size:16ms
            vda_extractor.build_audios_features(
                audio_src_path=audio_out__path,
                to_dir=audio_feature__path,
                fft_size=512,  # fft_size / sample_rate is window size
                hop_size=256,
            )

        # extract the splits data into list files based on the frames information
        gym_trim.generate_splits_list(
            data_root=self.event_subsection__path,
            annotation_root=self.data_annotation_path,
            frame_data_root=frames_out__path,
        )

        # generate the audio and audio features splits file
        # just copy the frame files to the audio ones
        for split, rawframes_split_file_path in (
            self.rawframes_splits_list_files_into.items()
        ):
            audios_split_file_path = self.audios_splits_list_files_into[split]
            # The audio-feature list has its own file; reading this path from
            # the audios mapping would leave the feature list file unwritten.
            audio_features_split_file_path = (
                self.audio_features_splits_list_files_into[split]
            )
            shutil.copy(src=rawframes_split_file_path, dst=audios_split_file_path)
            shutil.copy(
                src=rawframes_split_file_path, dst=audio_features_split_file_path
            )

    def correct_current_config(self, loaded_plato_config, mode, modality_name):
        """Correct the loaded configuration settings based on
        on-hand data information.

        Arguments:
            loaded_plato_config: the modality configuration loaded by Plato.
            mode: the split name, a key of the split list mappings
                ("train" or "val").
            modality_name: "rgb", "flow", "audio_feature"; any other value is
                treated as the video modality.

        Returns:
            The configuration as a dict whose "ann_file" and "data_prefix"
            entries point at the paths prepared by this data source.
        """
        # 1.1. convert plato config to dict type
        loaded_config = data_utils.config_to_dict(loaded_plato_config)
        # 1.2. convert the list to tuple
        loaded_config = data_utils.dict_list2tuple(loaded_config)

        # 2. replace the user-set annotation file and data prefix with the
        #    ones obtained here, because these are the full paths
        if modality_name in ("rgb", "flow"):
            ann_file = self.rawframes_splits_list_files_into[mode]
            data_prefix = self.event_subsection_frames__path
        elif modality_name == "audio_feature":
            # use the audio-feature list file and feature directory, i.e. the
            # directory that build_audios_features() writes into
            ann_file = self.audio_features_splits_list_files_into[mode]
            data_prefix = self.event_subsection_audios_fea__path
        else:
            # no annotation list file is generated for the raw videos
            ann_file = None
            data_prefix = self.event_subsection__path

        loaded_config["ann_file"] = ann_file
        loaded_config["data_prefix"] = data_prefix

        return loaded_config

    def get_phase_dataset(self, phase, modality_sampler):
        """Get the dataset for the specific phase.

        Arguments:
            phase: "train" or "val"; selects both the per-modality
                configuration and the split list files.
            modality_sampler: forwarded to the created GymDataset.
        """
        multi_modal_configs = Config().data.multi_modal_configs
        rgb_mode_config = getattr(multi_modal_configs.rgb, phase)
        flow_mode_config = getattr(multi_modal_configs.flow, phase)
        audio_feature_mode_config = getattr(multi_modal_configs.audio_feature, phase)

        rgb_mode_config = self.correct_current_config(
            loaded_plato_config=rgb_mode_config, mode=phase, modality_name="rgb"
        )
        flow_mode_config = self.correct_current_config(
            loaded_plato_config=flow_mode_config, mode=phase, modality_name="flow"
        )
        audio_feature_mode_config = self.correct_current_config(
            loaded_plato_config=audio_feature_mode_config,
            mode=phase,
            modality_name="audio_feature",
        )

        # build one dataset per modality
        multi_modal_mode_data = {
            "rgb": build_dataset(rgb_mode_config),
            "flow": build_dataset(flow_mode_config),
            "audio_feature": build_dataset(audio_feature_mode_config),
        }

        multi_modal_mode_info = {
            "rgb": rgb_mode_config["ann_file"],
            "flow": flow_mode_config["ann_file"],
            "audio_feature": audio_feature_mode_config["ann_file"],
            "categories": self.categoty_anno_file_path,
        }

        # pass the requested phase through instead of always labelling the
        # dataset as "train"
        gym_mode_dataset = GymDataset(
            multimodal_data_holder=multi_modal_mode_data,
            phase=phase,
            phase_info=multi_modal_mode_info,
            modality_sampler=modality_sampler,
        )

        return gym_mode_dataset

    def get_train_set(self, modality_sampler=None):
        """Obtain the trainset for multimodal data."""
        return self.get_phase_dataset(phase="train", modality_sampler=modality_sampler)

    def get_test_set(self, modality_sampler=None):
        """Obtain the testset for multimodal data.

        The validation split is used directly as the testset.
        """
        return self.get_phase_dataset(phase="val", modality_sampler=modality_sampler)

    def get_modality_name(self):
        """Get all supports modalities"""
        # NOTE(review): the datasets built here are keyed "audio_feature",
        # not "audio" - confirm which name callers expect.
        return ["rgb", "flow", "audio"]
@@ -0,0 +1,165 @@
1
+ """
2
+ A data source for the HuggingFace datasets.
3
+
4
+ For more information about the HuggingFace datasets, refer to:
5
+
6
+ https://huggingface.co/docs/datasets/quicktour.html
7
+ """
8
+
9
+ import logging
10
+ import os
11
+
12
+ from datasets import load_dataset, load_from_disk
13
+ from transformers import AutoConfig, AutoTokenizer, HfArgumentParser
14
+ from transformers import TrainingArguments, testing_utils, utils
15
+
16
+ from plato.config import Config
17
+ from plato.datasources import base
18
+
19
+
20
class DataSource(base.DataSource):
    """A data source for the HuggingFace datasets.

    The dataset named in the configuration is loaded from disk when a saved
    copy exists, otherwise downloaded and saved. Its "train" and "validation"
    splits are tokenized and grouped into fixed-size blocks.
    """

    def __init__(self, **kwargs):
        super().__init__()

        dataset_name = Config().data.dataset_name
        logging.info("Dataset: %s", dataset_name)

        # dataset_config is optional in the configuration
        dataset_config = getattr(Config().data, "dataset_config", None)

        saved_data_path = (
            f"{Config().params['data_path']}/{dataset_name}_{dataset_config}"
        )

        if os.path.exists(saved_data_path):
            # If the dataset has already been downloaded and saved
            self.dataset = load_from_disk(saved_data_path)
        else:
            # Download and save the dataset
            self.dataset = load_dataset(dataset_name, dataset_config)
            self.dataset.save_to_disk(saved_data_path)

        parser = HfArgumentParser(TrainingArguments)
        (self.training_args,) = parser.parse_args_into_dataclasses(
            args=["--output_dir=/tmp", "--report_to=none"]
        )

        model_name = Config().trainer.model_name

        # The token is optional; guard against a configuration that has no
        # `parameters` section at all.
        parameters = getattr(Config(), "parameters", None)
        use_auth_token = getattr(parameters, "huggingface_token", None)

        config_kwargs = {
            "cache_dir": Config().params["model_path"],
            "revision": "main",
            "use_auth_token": use_auth_token,
        }
        tokenizer_kwargs = {
            "cache_dir": Config().params["data_path"],
            "use_fast": True,
            "revision": "main",
            "use_auth_token": use_auth_token,
        }

        self.config = AutoConfig.from_pretrained(model_name, **config_kwargs)

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name, config=self.config, **tokenizer_kwargs
        )
        self.tok_logger = utils.logging.get_logger(
            "transformers.tokenization_utils_base"
        )

        # Length of the token chunks produced by group_texts()
        self.block_size = 128

        self.column_names = ["text"]
        self.text_column_name = "text"
        self.trainset = self.preprocess_data(self.dataset["train"])
        self.testset = self.preprocess_data(self.dataset["validation"])

    def num_train_examples(self):
        """Return the number of grouped training examples."""
        return len(self.trainset)

    def num_test_examples(self):
        """Return the number of grouped test examples."""
        return len(self.testset)

    def get_train_set(self):
        """Return the preprocessed training set."""
        return self.trainset

    def get_test_set(self):
        """Return the preprocessed test set (the "validation" split)."""
        return self.testset

    @staticmethod
    def input_shape():
        """Returns the input shape of the dataset, useful for building
        a TF model."""
        raise ValueError("Not implemented.")

    def tokenize_function(self, examples):
        """Using the tokenizer from AutoTokenizer to tokenize the text."""
        with testing_utils.CaptureLogger(self.tok_logger) as cl:
            output = self.tokenizer(examples[self.text_column_name])
        # clm input could be much much longer than block_size
        if "Token indices sequence length is longer than the" in cl.out:
            self.tok_logger.warning(
                "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be "
                "chunked into smaller bits before being passed to the model."
            )
        return output

    def group_texts(self, examples):
        """Concatenate all texts and split them into blocks.

        Every column of the batch is flattened into one long list, which is
        cut into consecutive chunks of ``self.block_size`` items; the trailing
        remainder shorter than a block is dropped. A "labels" column is added
        as a copy of "input_ids".
        """
        # Flatten in linear time; sum(lists, []) re-copies the accumulated
        # list on every addition.
        concatenated_examples = {
            k: [item for sequence in examples[k] for item in sequence]
            for k in examples.keys()
        }

        total_length = len(concatenated_examples[list(examples.keys())[0]])

        # We drop the small remainder, we could add padding if the model supported it
        # instead of this drop, you can customize this part to your needs.
        total_length = (total_length // self.block_size) * self.block_size

        # Split by chunks of block_size.
        result = {
            k: [
                t[i : i + self.block_size]
                for i in range(0, total_length, self.block_size)
            ]
            for k, t in concatenated_examples.items()
        }

        result["labels"] = result["input_ids"].copy()
        return result

    def preprocess_data(self, datasets):
        """Tokenizing and grouping the raw dataset."""
        with self.training_args.main_process_first(desc="dataset map tokenization"):
            tokenized_datasets = datasets.map(
                self.tokenize_function,
                batched=True,
                num_proc=4,
                remove_columns=self.column_names,
                load_from_cache_file=True,
                desc="Running tokenizer on dataset",
            )

        # group_texts() chunks by self.block_size, so that is the size reported
        # here; the tokenizer's model_max_length plays no part in the chunking.
        with self.training_args.main_process_first(desc="grouping texts together"):
            lm_datasets = tokenized_datasets.map(
                self.group_texts,
                batched=True,
                num_proc=4,
                load_from_cache_file=True,
                desc=f"Grouping texts in chunks of {self.block_size}",
            )

        return lm_datasets