plato_learn-1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (179)
  1. plato/__init__.py +1 -0
  2. plato/algorithms/__init__.py +0 -0
  3. plato/algorithms/base.py +45 -0
  4. plato/algorithms/fedavg.py +48 -0
  5. plato/algorithms/fedavg_gan.py +79 -0
  6. plato/algorithms/fedavg_personalized.py +48 -0
  7. plato/algorithms/mistnet.py +52 -0
  8. plato/algorithms/registry.py +39 -0
  9. plato/algorithms/split_learning.py +89 -0
  10. plato/callbacks/__init__.py +0 -0
  11. plato/callbacks/client.py +56 -0
  12. plato/callbacks/handler.py +78 -0
  13. plato/callbacks/server.py +139 -0
  14. plato/callbacks/trainer.py +124 -0
  15. plato/client.py +67 -0
  16. plato/clients/__init__.py +0 -0
  17. plato/clients/base.py +467 -0
  18. plato/clients/edge.py +103 -0
  19. plato/clients/fedavg_personalized.py +40 -0
  20. plato/clients/mistnet.py +49 -0
  21. plato/clients/registry.py +43 -0
  22. plato/clients/self_supervised_learning.py +51 -0
  23. plato/clients/simple.py +218 -0
  24. plato/clients/split_learning.py +150 -0
  25. plato/config.py +339 -0
  26. plato/datasources/__init__.py +0 -0
  27. plato/datasources/base.py +123 -0
  28. plato/datasources/celeba.py +150 -0
  29. plato/datasources/cifar10.py +87 -0
  30. plato/datasources/cifar100.py +61 -0
  31. plato/datasources/cinic10.py +62 -0
  32. plato/datasources/coco.py +119 -0
  33. plato/datasources/datalib/__init__.py +0 -0
  34. plato/datasources/datalib/audio_extraction_tools.py +137 -0
  35. plato/datasources/datalib/data_utils.py +124 -0
  36. plato/datasources/datalib/flickr30kE_utils.py +336 -0
  37. plato/datasources/datalib/frames_extraction_tools.py +254 -0
  38. plato/datasources/datalib/gym_utils/__init__.py +0 -0
  39. plato/datasources/datalib/gym_utils/gym_trim.py +189 -0
  40. plato/datasources/datalib/modality_data_anntation_tools.py +163 -0
  41. plato/datasources/datalib/modality_extraction_base.py +59 -0
  42. plato/datasources/datalib/parse_datasets.py +212 -0
  43. plato/datasources/datalib/refer_utils/__init__.py +0 -0
  44. plato/datasources/datalib/refer_utils/referitgame_utils.py +237 -0
  45. plato/datasources/datalib/tiny_data_tools.py +81 -0
  46. plato/datasources/datalib/video_transform.py +79 -0
  47. plato/datasources/emnist.py +64 -0
  48. plato/datasources/fashion_mnist.py +41 -0
  49. plato/datasources/feature.py +24 -0
  50. plato/datasources/feature_dataset.py +15 -0
  51. plato/datasources/femnist.py +141 -0
  52. plato/datasources/flickr30k_entities.py +362 -0
  53. plato/datasources/gym.py +431 -0
  54. plato/datasources/huggingface.py +165 -0
  55. plato/datasources/kinetics.py +568 -0
  56. plato/datasources/mnist.py +44 -0
  57. plato/datasources/multimodal_base.py +328 -0
  58. plato/datasources/pascal_voc.py +56 -0
  59. plato/datasources/purchase.py +94 -0
  60. plato/datasources/qoenflx.py +127 -0
  61. plato/datasources/referitgame.py +330 -0
  62. plato/datasources/registry.py +119 -0
  63. plato/datasources/self_supervised_learning.py +98 -0
  64. plato/datasources/stl10.py +103 -0
  65. plato/datasources/texas.py +94 -0
  66. plato/datasources/tiny_imagenet.py +64 -0
  67. plato/datasources/yolov8.py +85 -0
  68. plato/models/__init__.py +0 -0
  69. plato/models/cnn_encoder.py +103 -0
  70. plato/models/dcgan.py +116 -0
  71. plato/models/general_multilayer.py +254 -0
  72. plato/models/huggingface.py +27 -0
  73. plato/models/lenet5.py +113 -0
  74. plato/models/multilayer.py +90 -0
  75. plato/models/multimodal/__init__.py +0 -0
  76. plato/models/multimodal/base_net.py +91 -0
  77. plato/models/multimodal/blending.py +142 -0
  78. plato/models/multimodal/fc_net.py +77 -0
  79. plato/models/multimodal/fusion_net.py +78 -0
  80. plato/models/multimodal/multimodal_module.py +152 -0
  81. plato/models/registry.py +99 -0
  82. plato/models/resnet.py +190 -0
  83. plato/models/torch_hub.py +19 -0
  84. plato/models/vgg.py +113 -0
  85. plato/models/vit.py +166 -0
  86. plato/models/yolov8.py +22 -0
  87. plato/processors/__init__.py +0 -0
  88. plato/processors/base.py +35 -0
  89. plato/processors/compress.py +46 -0
  90. plato/processors/decompress.py +48 -0
  91. plato/processors/feature.py +51 -0
  92. plato/processors/feature_additive_noise.py +48 -0
  93. plato/processors/feature_dequantize.py +34 -0
  94. plato/processors/feature_gaussian.py +17 -0
  95. plato/processors/feature_laplace.py +15 -0
  96. plato/processors/feature_quantize.py +34 -0
  97. plato/processors/feature_randomized_response.py +50 -0
  98. plato/processors/feature_unbatch.py +39 -0
  99. plato/processors/inbound_feature_tensors.py +39 -0
  100. plato/processors/model.py +55 -0
  101. plato/processors/model_compress.py +34 -0
  102. plato/processors/model_decompress.py +37 -0
  103. plato/processors/model_decrypt.py +41 -0
  104. plato/processors/model_deepcopy.py +21 -0
  105. plato/processors/model_dequantize.py +18 -0
  106. plato/processors/model_dequantize_qsgd.py +61 -0
  107. plato/processors/model_encrypt.py +43 -0
  108. plato/processors/model_quantize.py +18 -0
  109. plato/processors/model_quantize_qsgd.py +82 -0
  110. plato/processors/model_randomized_response.py +34 -0
  111. plato/processors/outbound_feature_ndarrays.py +38 -0
  112. plato/processors/pipeline.py +26 -0
  113. plato/processors/registry.py +124 -0
  114. plato/processors/structured_pruning.py +57 -0
  115. plato/processors/unstructured_pruning.py +73 -0
  116. plato/samplers/__init__.py +0 -0
  117. plato/samplers/all_inclusive.py +41 -0
  118. plato/samplers/base.py +31 -0
  119. plato/samplers/dirichlet.py +81 -0
  120. plato/samplers/distribution_noniid.py +132 -0
  121. plato/samplers/iid.py +53 -0
  122. plato/samplers/label_quantity_noniid.py +119 -0
  123. plato/samplers/mixed.py +44 -0
  124. plato/samplers/mixed_label_quantity_noniid.py +128 -0
  125. plato/samplers/modality_iid.py +42 -0
  126. plato/samplers/modality_quantity_noniid.py +56 -0
  127. plato/samplers/orthogonal.py +99 -0
  128. plato/samplers/registry.py +66 -0
  129. plato/samplers/sample_quantity_noniid.py +123 -0
  130. plato/samplers/sampler_utils.py +190 -0
  131. plato/servers/__init__.py +0 -0
  132. plato/servers/base.py +1395 -0
  133. plato/servers/fedavg.py +281 -0
  134. plato/servers/fedavg_cs.py +335 -0
  135. plato/servers/fedavg_gan.py +74 -0
  136. plato/servers/fedavg_he.py +106 -0
  137. plato/servers/fedavg_personalized.py +57 -0
  138. plato/servers/mistnet.py +67 -0
  139. plato/servers/registry.py +52 -0
  140. plato/servers/split_learning.py +109 -0
  141. plato/trainers/__init__.py +0 -0
  142. plato/trainers/base.py +99 -0
  143. plato/trainers/basic.py +649 -0
  144. plato/trainers/diff_privacy.py +178 -0
  145. plato/trainers/gan.py +330 -0
  146. plato/trainers/huggingface.py +173 -0
  147. plato/trainers/loss_criterion.py +70 -0
  148. plato/trainers/lr_schedulers.py +252 -0
  149. plato/trainers/optimizers.py +53 -0
  150. plato/trainers/pascal_voc.py +80 -0
  151. plato/trainers/registry.py +44 -0
  152. plato/trainers/self_supervised_learning.py +302 -0
  153. plato/trainers/split_learning.py +305 -0
  154. plato/trainers/tracking.py +96 -0
  155. plato/trainers/yolov8.py +41 -0
  156. plato/utils/__init__.py +0 -0
  157. plato/utils/count_parameters.py +30 -0
  158. plato/utils/csv_processor.py +26 -0
  159. plato/utils/data_loaders.py +148 -0
  160. plato/utils/decorators.py +24 -0
  161. plato/utils/fonts.py +23 -0
  162. plato/utils/homo_enc.py +187 -0
  163. plato/utils/reinforcement_learning/__init__.py +0 -0
  164. plato/utils/reinforcement_learning/policies/__init__.py +0 -0
  165. plato/utils/reinforcement_learning/policies/base.py +161 -0
  166. plato/utils/reinforcement_learning/policies/ddpg.py +75 -0
  167. plato/utils/reinforcement_learning/policies/registry.py +32 -0
  168. plato/utils/reinforcement_learning/policies/sac.py +343 -0
  169. plato/utils/reinforcement_learning/policies/td3.py +485 -0
  170. plato/utils/reinforcement_learning/rl_agent.py +142 -0
  171. plato/utils/reinforcement_learning/rl_server.py +113 -0
  172. plato/utils/rl_env.py +154 -0
  173. plato/utils/s3.py +141 -0
  174. plato/utils/trainer_utils.py +21 -0
  175. plato/utils/unary_encoding.py +47 -0
  176. plato_learn-1.1.dist-info/METADATA +35 -0
  177. plato_learn-1.1.dist-info/RECORD +179 -0
  178. plato_learn-1.1.dist-info/WHEEL +4 -0
  179. plato_learn-1.1.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,568 @@
+ """
+ The Kinetics700 dataset.
+
+ Note that the settings for the data loader are obtained from the official
+ reference script:
+ https://github.com/pytorch/vision/references/video_classification/train.py
+
+ We consider three modalities: RGB, optical flow, and audio.
+ For RGB and flow, we use input clips of 16×224×224.
+ We follow [1] for visual pre-processing and augmentation.
+ For audio, we use log-Mel spectrograms with 100 temporal frames by 40 Mel
+ filters.
+
+ The audio and visual streams are temporally aligned.
+
+ [1] Video classification with channel-separated convolutional networks.
+     In ICCV, 2019. (This is the CSN network, as implemented in the
+     mmaction package.)
+
+ Our implementation is also based on mmaction2 from OpenMMLab:
+ https://openmmlab.com/.
+
+ The data structure is:
+ ├── data
+ │   ├── ${DATASET}
+ │   │   ├── ${DATASET}_train_list_videos.txt
+ │   │   ├── ${DATASET}_val_list_videos.txt
+ │   │   ├── annotations
+ │   │   ├── videos_train
+ │   │   ├── videos_val
+ │   │   │   ├── abseiling
+ │   │   │   │   ├── 0wR5jVB-WPk_000417_000427.mp4
+ │   │   │   │   ├── ...
+ │   │   │   ├── ...
+ │   │   │   ├── wrapping_present
+ │   │   │   ├── ...
+ │   │   │   ├── zumba
+ │   │   ├── rawframes_train
+ │   │   ├── rawframes_val
+
+ For data_format, we support "videos", "rawframes", "audios", and
+ "audio_features"; for modality, we support "video", "audio",
+ "audio_feature", "rgb", and "flow".
+ """
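+
+ # For reference, a hypothetical Plato configuration exercising this
+ # datasource could look as follows (YAML shown inside a comment; the field
+ # names are the ones this module reads from Config().data, while the
+ # values are illustrative assumptions only):
+ #
+ #     data:
+ #         datasource: kinetics700
+ #         downloader:
+ #             num_workers: 4
+ #         tiny_data: true
+ #         tiny_data_number: 16
+ #         random_seed: 1
+ #         multi_modal_configs:
+ #             rgb: ...
+ #             flow: ...
+ #             audio_feature: ...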
+
+ import logging
+ import os
+ import re
+ import shutil
+ from collections import defaultdict
+
+ import torch
+
+ from mmaction.tools.data.kinetics import download as kinetics_downloader
+ from mmaction.datasets import build_dataset
+
+ from plato.config import Config
+ from plato.datasources import multimodal_base
+ from plato.datasources.datalib import frames_extraction_tools
+ from plato.datasources.datalib import audio_extraction_tools
+ from plato.datasources.datalib import modality_data_anntation_tools
+ from plato.datasources.datalib import data_utils
+ from plato.datasources.datalib import tiny_data_tools
+
+
+ def obtain_required_anno_files(splits_info):
+     """Obtain the full or tiny annotation files for all splits."""
+     required_anno_files = {"train": "", "test": "", "val": ""}
+     for split in ["train", "test", "val"]:
+         split_info = splits_info[split]
+         if hasattr(Config().data, "tiny_data") and Config().data.tiny_data:
+             # Obtain the annotation files for the tiny dataset
+             split_anno_path = split_info["split_tiny_anno_file"]
+         else:
+             # Obtain the annotation files for the whole dataset
+             split_anno_path = split_info["split_anno_file"]
+
+         required_anno_files[split] = split_anno_path
+     return required_anno_files
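+ # For example, with tiny_data enabled and datasource "kinetics700", the
+ # returned dict maps each split to a path of the form (illustrative only):
+ #     {"train": ".../annotations/kinetics_train_tiny.csv",
+ #      "test": ".../annotations/kinetics_test_tiny.csv",
+ #      "val": ".../annotations/kinetics_validate_tiny.csv"}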
+
+
+ class KineticsDataset(multimodal_base.MultiModalDataset):
+     """Prepare the Kinetics dataset."""
+
+     def __init__(
+         self, multimodal_data_holder, phase, phase_info, modality_sampler=None
+     ):
+         super().__init__()
+         self.phase = phase
+         # multimodal_data_holder is a dict:
+         # {"rgb": rgb_dataset, "flow": flow_dataset,
+         #  "audio_feature": audio_feature_dataset}
+         self.phase_multimodal_data_record = multimodal_data_holder
+
+         # a dict of annotation file paths keyed by modality, e.g.,
+         # {"rgb": <rgb_annotation_file_path>, ...}
+         self.phase_info = phase_info
+
+         self.modalities_name = list(multimodal_data_holder.keys())
+
+         self.supported_modalities = ["rgb", "flow", "audio_feature"]
+
+         # by default, utilize the full set of modalities
+         if modality_sampler is None:
+             self.modality_sampler = self.supported_modalities
+         else:
+             self.modality_sampler = modality_sampler
+
+         self.targets = self.get_targets()
+
+     def __len__(self):
+         # the number of samples, not the number of modalities
+         return len(self.targets)
+
+     def get_targets(self):
+         """Obtain the labels of the samples in the current phase's dataset."""
+         # The order of samples is the same across the rgb, flow, and audio
+         # annotation files, so reading any one of them suffices.
+         # Normally, rgb and flow belong to the rawframes.
+         rawframes_anno_list_file_path = self.phase_info["rgb"]
+         annos_list = data_utils.read_anno_file(rawframes_anno_list_file_path)
+
+         obtained_targets = [anno_item["label"][0] for anno_item in annos_list]
+
+         return obtained_targets
+
+     def get_one_multimodal_sample(self, sample_idx):
+         """Obtain one sample from the Kinetics dataset."""
+         obtained_mm_sample = dict()
+
+         for modality_name in self.modalities_name:
+             modality_dataset = self.phase_multimodal_data_record[modality_name]
+             obtained_mm_sample[modality_name] = modality_dataset[sample_idx]
+
+         return obtained_mm_sample
+
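+ # A KineticsDataset sample is a dict keyed by modality name, e.g.
+ # (illustrative only):
+ #     sample = dataset.get_one_multimodal_sample(0)
+ #     rgb_clip = sample["rgb"]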
+
+ class DataSource(multimodal_base.MultiModalDataSource):
+     """The Kinetics datasource."""
+
+     def __init__(self, **kwargs):
+         super().__init__()
+
+         self.data_name = Config().data.datasource
+         base_data_name = re.findall(r"\D+", self.data_name)[0]
+
+         # The rawframes contain the "rgb" and "flow" modalities;
+         # thus, rgb and flow will be put in the same directory, rawframes/
+         self.modality_names = ["video", "audio", "rgb", "flow", "audio_feature"]
+
+         _path = Config().params["data_path"]
+         # Generate the base paths for the dataset; this:
+         # 1. assigns the path to self.mm_data_info
+         # 2. assigns the splits' paths to self.splits_info,
+         #    where the root path for the splits is the data_path
+         #    in self.mm_data_info
+         self._data_path_process(data_path=_path, base_data_name=self.data_name)
+         # Generate the modality paths for all splits; this adds a path for
+         # each modality, with keys of the form {modality}_path.
+         # Note: the rgb and flow modalities are merged into 'rawframes_path'
+         # as they belong to the same prototype, "rawframes".
+         self._create_modalities_path(modality_names=self.modality_names)
+
+         # Set the annotation file path
+         base_data_path = self.mm_data_info["data_path"]
+
+         # Define all the directories here
+         kinetics_anno_dir_name = "annotations"
+         self.data_annotation_path = os.path.join(
+             base_data_path, kinetics_anno_dir_name
+         )
+
+         for split in ["train", "test", "validate"]:
+             split_anno_path = os.path.join(
+                 self.data_annotation_path, base_data_name + "_" + split + ".csv"
+             )
+             split_tiny_anno_path = os.path.join(
+                 self.data_annotation_path, base_data_name + "_" + split + "_tiny.csv"
+             )
+             split_name = split if split != "validate" else "val"
+             self.splits_info[split_name]["split_anno_file"] = split_anno_path
+             self.splits_info[split_name]["split_tiny_anno_file"] = split_tiny_anno_path
+
+         # After the two function calls above, self.splits_info contains, e.g.,
+         # {'train':
+         #     'path': xxx,
+         #     'split_anno_file': xxx,
+         #     'split_tiny_anno_file': xxx,
+         #     'rawframes_path': xxx,
+         #     'video_path': xxx}
+         # and self.mm_data_info contains
+         # - source_data_path
+         # - data_path
+
+         anno_download_url = (
+             "https://storage.googleapis.com/deepmind-media/Datasets/{}.tar.gz"
+         ).format(self.data_name)
+
+         extracted_anno_file_name = self._download_arrange_data(
+             download_url_address=anno_download_url,
+             data_path=self.data_annotation_path,
+             obtained_file_name=None,
+         )
+         download_anno_path = os.path.join(
+             self.data_annotation_path, extracted_anno_file_name
+         )
+
+         downloaded_files = os.listdir(download_anno_path)
+         for file_name in downloaded_files:
+             new_file_name = base_data_name + "_" + file_name
+             shutil.move(
+                 os.path.join(download_anno_path, file_name),
+                 os.path.join(self.data_annotation_path, new_file_name),
+             )
+
+         # Create the tiny dataset's annotation files, if required; this only
+         # needs to run once, as it covers all three splits at the same time
+         if hasattr(Config().data, "tiny_data") and Config().data.tiny_data:
+             anno_files_info = {
+                 "train": self.splits_info["train"]["split_anno_file"],
+                 "test": self.splits_info["test"]["split_anno_file"],
+                 "val": self.splits_info["val"]["split_anno_file"],
+             }
+             tiny_data_tools.create_tiny_kinetics_anno(
+                 kinetics_annotation_files_info=anno_files_info,
+                 num_samples=Config().data.tiny_data_number,
+                 random_seed=Config().data.random_seed,
+             )
+
+         # Download the raw datasets for the splits.
+         # There is no need to download data for "test", as the Kinetics test
+         # set does not contain labels.
+         required_anno_files = obtain_required_anno_files(self.splits_info)
+         for split in ["train", "val"]:
+             split_anno_path = required_anno_files[split]
+             video_path_format = self.set_modality_path_key_format(modality_name="video")
+             video_dir = self.splits_info[split][video_path_format]
+             if not self._exists(video_dir):
+                 num_workers = Config().data.downloader.num_workers
+                 # Set tmp_dir to hold the raw videos; the raw videos are then
+                 # clipped and saved to the target video_dir
+                 tmp_dir = os.path.join(video_dir, "tmp")
+                 logging.info(
+                     "Downloading the raw videos for the %s %s dataset. This may take a long time.",
+                     self.data_name,
+                     split,
+                 )
+                 kinetics_downloader.main(
+                     input_csv=split_anno_path,
+                     output_dir=video_dir,
+                     trim_format="%06d",
+                     num_jobs=num_workers,
+                     tmp_dir=tmp_dir,
+                 )
+         # Rename the class directories, replacing whitespace with underscores
+         for split in ["train", "val"]:
+             self.rename_classes(mode=split)
+
+         logging.info("Done.")
+
+         logging.info("The %s dataset has been prepared.", self.data_name)
+
+         # Extract rgb, flow, audio, and audio_feature from the videos
+         for split in ["train", "val"]:
+             self.extract_videos_rgb_flow_audio(mode=split)
+
+         # Extract the splits' information into the corresponding list files
+         self.audios_splits_list_files_info = self.extract_splits_list_files(
+             data_format="audio_features", splits=["train", "val"]
+         )
+         self.video_splits_list_files_info = self.extract_splits_list_files(
+             data_format="videos", splits=["train", "val"]
+         )
+         self.rawframes_splits_list_files_info = self.extract_splits_list_files(
+             data_format="rawframes", splits=["train", "val"]
+         )
+
+     def get_modality_name(self):
+         """Get all supported modalities."""
+         return ["rgb", "flow", "audio"]
+
+     def rename_classes(self, mode):
+         """Rename the class directories by replacing whitespace with underscores."""
+         video_format_path_key = self.set_modality_path_key_format(modality_name="video")
+         videos_root_path = self.splits_info[mode][video_format_path_key]
+         videos_dirs_name = [
+             dir_name
+             for dir_name in os.listdir(videos_root_path)
+             if os.path.isdir(os.path.join(videos_root_path, dir_name))
+         ]
+
+         new_videos_dirs_name = [
+             dir_name.replace(" ", "_") for dir_name in videos_dirs_name
+         ]
+
+         videos_dirs_path = [
+             os.path.join(videos_root_path, dir_name) for dir_name in videos_dirs_name
+         ]
+         new_videos_dirs_path = [
+             os.path.join(videos_root_path, dir_name)
+             for dir_name in new_videos_dirs_name
+         ]
+         for i, _ in enumerate(videos_dirs_path):
+             os.rename(videos_dirs_path[i], new_videos_dirs_path[i])
+
+     def get_modality_data_path(self, mode, modality_name):
+         """Obtain the path of the modality's data in the specific mode."""
+         modality_key = self.set_modality_path_key_format(modality_name=modality_name)
+
+         return self.splits_info[mode][modality_key]
+
+     def extract_videos_rgb_flow_audio(self, mode="train"):
+         """Extract rgb frames, optical flow, and audio from the videos."""
+         video_data_path = self.get_modality_data_path(mode=mode, modality_name="video")
+         src_mode_videos_dir = video_data_path
+
+         rgb_out_path = self.get_modality_data_path(mode=mode, modality_name="rgb")
+         flow_out_path = self.get_modality_data_path(mode=mode, modality_name="flow")
+         audio_out_path = self.get_modality_data_path(mode=mode, modality_name="audio")
+         audio_feature_path = self.get_modality_data_path(
+             mode=mode, modality_name="audio_feature"
+         )
+
+         # define the modality extractors
+         if not self._exists(rgb_out_path):
+             vdf_extractor = frames_extraction_tools.VideoFramesExtractor(
+                 video_src_dir=src_mode_videos_dir,
+                 dir_level=2,
+                 num_worker=8,
+                 video_ext="mp4",
+                 mixed_ext=False,
+             )
+         if not self._exists(audio_out_path) or not self._exists(audio_feature_path):
+             vda_extractor = audio_extraction_tools.VideoAudioExtractor(
+                 video_src_dir=src_mode_videos_dir,
+                 dir_level=2,
+                 num_worker=8,
+                 video_ext="mp4",
+                 mixed_ext=False,
+             )
+
+         if torch.cuda.is_available():
+             if not self._exists(rgb_out_path):
+                 logging.info(
+                     "Extracting frames by GPU from videos in %s to %s.",
+                     src_mode_videos_dir,
+                     rgb_out_path,
+                 )
+                 vdf_extractor.build_frames_gpu(
+                     rgb_out_path,
+                     flow_out_path,
+                     new_short=1,
+                     new_width=0,
+                     new_height=0,
+                 )
+         else:
+             if not self._exists(rgb_out_path):
+                 logging.info(
+                     "Extracting frames by CPU from videos in %s to %s.",
+                     src_mode_videos_dir,
+                     rgb_out_path,
+                 )
+                 vdf_extractor.build_frames_cpu(to_dir=rgb_out_path)
+
+         if not self._exists(audio_out_path):
+             logging.info(
+                 "Extracting audio by CPU from videos in %s to %s.",
+                 src_mode_videos_dir,
+                 audio_out_path,
+             )
+             vda_extractor.build_audios(to_dir=audio_out_path)
+
+         if not self._exists(audio_feature_path):
+             logging.info(
+                 "Extracting audio features by CPU from audio in %s to %s.",
+                 audio_out_path,
+                 audio_feature_path,
+             )
+             # window_size: 32ms, hop_size: 16ms (at a 16 kHz sample rate)
+             vda_extractor.build_audios_features(
+                 audio_src_path=audio_out_path,
+                 to_dir=audio_feature_path,
+                 fft_size=512,  # fft_size / sample_rate is the window size
+                 hop_size=256,
+             )
+
+     def extract_splits_list_files(self, data_format, splits):
+         """Extract and generate the split information for the current mode/phase."""
+         output_format = "json"
+         out_path = self.mm_data_info["data_path"]
+
+         # obtain a dict that contains the required data splits' file paths;
+         # these can point to the full data or the tiny data
+         required_anno_files = obtain_required_anno_files(self.splits_info)
+         data_splits_file_info = required_anno_files
+         gen_annots_op = modality_data_anntation_tools.GenerateMDataAnnotation(
+             data_src_dir=self.mm_data_info["data_path"],
+             data_annos_files_info=data_splits_file_info,
+             dataset_name=self.data_name,
+             data_format=data_format,  # 'rawframes', 'videos', 'audio_features'
+             rgb_prefix="img_",  # prefix of rgb frames
+             flow_x_prefix="x_",  # prefix of flow x frames
+             flow_y_prefix="y_",  # prefix of flow y frames
+             out_path=out_path,
+             output_format=output_format,
+         )
+
+         # e.g., with data_format="rawframes" this searches for list files
+         # whose names end with "_rawframes.json" under out_path
+         target_list_regu = f"_{data_format}.{output_format}"
+         if not self._file_exists(
+             tg_file_name=target_list_regu, search_path=out_path, is_partial_name=True
+         ):
+             logging.info("Extracting the annotation list for %s.", data_format)
+
+             gen_annots_op.read_data_splits_csv_info()
+
+             for split_name in splits:
+                 gen_annots_op.generate_data_splits_info_file(split_name=split_name)
+
+         # obtain the extracted files' paths
+         generated_list_files_info = {}
+         for split_name in splits:
+             generated_list_files_info[split_name] = gen_annots_op.get_anno_file_path(
+                 split_name
+             )
+
+         return generated_list_files_info
+
+     def correct_current_config(self, loaded_plato_config, mode, modality_name):
+         """Correct the loaded configuration settings based on the data on hand."""
+
+         # 1.1. convert the plato config to the dict type
+         loaded_config = data_utils.config_to_dict(loaded_plato_config)
+         # 1.2. convert lists to tuples
+         loaded_config = data_utils.dict_list2tuple(loaded_config)
+
+         # 2. replace the user-set annotation files in the configuration
+         # with the ones obtained above, mainly because the paths obtained
+         # here are full paths
+         cur_rawframes_anno_file_path = self.rawframes_splits_list_files_info[mode]
+         cur_rawframes_data_path = self.get_modality_data_path(
+             mode=mode, modality_name="rgb"
+         )
+         cur_videos_anno_file_path = self.video_splits_list_files_info[mode]
+         cur_video_data_path = self.get_modality_data_path(
+             mode=mode, modality_name="video"
+         )
+         cur_audio_feas_anno_file_path = self.audios_splits_list_files_info[mode]
+         cur_audio_feas_data_path = self.get_modality_data_path(
+             mode=mode, modality_name="audio_feature"
+         )
+
+         if modality_name in ("rgb", "flow"):
+             loaded_config["ann_file"] = cur_rawframes_anno_file_path
+         elif modality_name == "audio_feature":
+             loaded_config["ann_file"] = cur_audio_feas_anno_file_path
+         else:
+             loaded_config["ann_file"] = cur_videos_anno_file_path
+
+         # 3. reset the data_prefix using the modality path
+         if modality_name in ("rgb", "flow"):
+             loaded_config["data_prefix"] = cur_rawframes_data_path
+         elif modality_name == "audio_feature":
+             loaded_config["data_prefix"] = cur_audio_feas_data_path
+         else:
+             loaded_config["data_prefix"] = cur_video_data_path
+
+         return loaded_config
+
+     def get_phase_dataset(self, phase, modality_sampler):
+         """Get the dataset for the specific phase."""
+         rgb_mode_config = getattr(Config().data.multi_modal_configs.rgb, phase)
+         flow_mode_config = getattr(Config().data.multi_modal_configs.flow, phase)
+         audio_feature_mode_config = getattr(
+             Config().data.multi_modal_configs.audio_feature, phase
+         )
+
+         rgb_mode_config = self.correct_current_config(
+             loaded_plato_config=rgb_mode_config, mode=phase, modality_name="rgb"
+         )
+         flow_mode_config = self.correct_current_config(
+             loaded_plato_config=flow_mode_config, mode=phase, modality_name="flow"
+         )
+         audio_feature_mode_config = self.correct_current_config(
+             loaded_plato_config=audio_feature_mode_config,
+             mode=phase,
+             modality_name="audio_feature",
+         )
+         # build a RawframeDataset for each modality
+         rgb_mode_dataset = build_dataset(rgb_mode_config)
+         flow_mode_dataset = build_dataset(flow_mode_config)
+         audio_feature_mode_dataset = build_dataset(audio_feature_mode_config)
+
+         multi_modal_mode_data = {
+             "rgb": rgb_mode_dataset,
+             "flow": flow_mode_dataset,
+             "audio_feature": audio_feature_mode_dataset,
+         }
+
+         multi_modal_mode_info = {
+             "rgb": rgb_mode_config["ann_file"],
+             "flow": flow_mode_config["ann_file"],
+             "audio_feature": audio_feature_mode_config["ann_file"],
+         }
+
+         kinetics_mode_dataset = KineticsDataset(
+             multimodal_data_holder=multi_modal_mode_data,
+             phase=phase,  # pass the actual phase, not a hard-coded "train"
+             phase_info=multi_modal_mode_info,
+             modality_sampler=modality_sampler,
+         )
+
+         return kinetics_mode_dataset
+
+
+     def get_train_set(self, modality_sampler=None):
+         """Obtain the trainset for multimodal data."""
+         kinetics_train_dataset = self.get_phase_dataset(
+             phase="train", modality_sampler=modality_sampler
+         )
+
+         return kinetics_train_dataset
+
+     def get_test_set(self, modality_sampler=None):
+         """Obtain the testset for multimodal data.
+
+         Note that in the Kinetics dataset, there is no test set whose
+         samples contain ground-truth labels.
+         Thus, we utilize the validation set directly.
+         """
+         kinetics_val_dataset = self.get_phase_dataset(
+             phase="val", modality_sampler=modality_sampler
+         )
+
+         return kinetics_val_dataset
+
+     def get_class_label_mapper(self):
+         """Obtain the mapper used to map the text classes to integers."""
+         textclass_integer_mapper = defaultdict(list)
+         # obtain the classes from the trainset; each annotation item looks like
+         # {"frame_dir": "clay_pottery_making/---0dWlqevI_000019_000029",
+         #  "total_frames": 300, "label": [0]}
+         train_anno_list_path = self.rawframes_splits_list_files_info["train"]
+         train_anno_list = data_utils.read_anno_file(train_anno_list_path)
+         for item in train_anno_list:
+             textclass = item["frame_dir"].split("/")[0]
+             integer_label = item["label"][0]
+
+             textclass_integer_mapper[textclass].append(integer_label)
+
+         return textclass_integer_mapper
+
+     def classes(self):
+         """The classes of the dataset."""
+
+         # obtain the classes from the trainset
+         train_anno_list_path = self.rawframes_splits_list_files_info["train"]
+         train_anno_list = data_utils.read_anno_file(train_anno_list_path)
+
+         integer_labels = [anno_elem["label"][0] for anno_elem in train_anno_list]
+         integer_classes = list(set(integer_labels))
+         integer_classes.sort()
+
+         return integer_classes
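
A hypothetical usage sketch for this datasource (not part of the package): it assumes a Plato configuration has been loaded that defines data.datasource as "kinetics700" along with the data.multi_modal_configs entries read above, and that the mmaction2 dependencies are installed; the modality subset passed to the sampler is an arbitrary illustrative choice.

    from plato.datasources import kinetics

    # Builds the annotation files, downloads and clips the raw videos, and
    # extracts rgb/flow/audio features on first use.
    datasource = kinetics.DataSource()

    # Restrict the sampler to two of the supported modalities
    # ("rgb", "flow", "audio_feature"); passing None keeps all of them.
    train_set = datasource.get_train_set(modality_sampler=["rgb", "audio_feature"])

    sample = train_set.get_one_multimodal_sample(0)  # a dict keyed by modality
    print(len(train_set), train_set.targets[:5], datasource.classes()[:5])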
@@ -0,0 +1,44 @@
+ """
+ The MNIST dataset from the torchvision package.
+ """
+
+ from torchvision import datasets, transforms
+
+ from plato.config import Config
+ from plato.datasources import base
+
+
+ class DataSource(base.DataSource):
+     """The MNIST dataset."""
+
+     def __init__(self, **kwargs):
+         super().__init__()
+         _path = Config().params["data_path"]
+
+         # (0.1307,) and (0.3081,) are the channel mean and standard
+         # deviation of the MNIST training set
+         train_transform = (
+             kwargs["train_transform"]
+             if "train_transform" in kwargs
+             else transforms.Compose(
+                 [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
+             )
+         )
+
+         test_transform = (
+             kwargs["test_transform"]
+             if "test_transform" in kwargs
+             else transforms.Compose(
+                 [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
+             )
+         )
+         self.trainset = datasets.MNIST(
+             root=_path, train=True, download=True, transform=train_transform
+         )
+         self.testset = datasets.MNIST(
+             root=_path, train=False, download=True, transform=test_transform
+         )
+
+     def num_train_examples(self):
+         return 60000
+
+     def num_test_examples(self):
+         return 10000
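
For reference, a minimal self-contained sketch of what this DataSource wires up, using torchvision directly; the "./data" directory and the batch size are illustrative choices:

    import torch
    from torchvision import datasets, transforms

    # The same normalization the DataSource applies by default.
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    trainset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
    loader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True)

    images, labels = next(iter(loader))
    print(images.shape)  # torch.Size([32, 1, 28, 28])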