plato-learn 1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- plato/__init__.py +1 -0
- plato/algorithms/__init__.py +0 -0
- plato/algorithms/base.py +45 -0
- plato/algorithms/fedavg.py +48 -0
- plato/algorithms/fedavg_gan.py +79 -0
- plato/algorithms/fedavg_personalized.py +48 -0
- plato/algorithms/mistnet.py +52 -0
- plato/algorithms/registry.py +39 -0
- plato/algorithms/split_learning.py +89 -0
- plato/callbacks/__init__.py +0 -0
- plato/callbacks/client.py +56 -0
- plato/callbacks/handler.py +78 -0
- plato/callbacks/server.py +139 -0
- plato/callbacks/trainer.py +124 -0
- plato/client.py +67 -0
- plato/clients/__init__.py +0 -0
- plato/clients/base.py +467 -0
- plato/clients/edge.py +103 -0
- plato/clients/fedavg_personalized.py +40 -0
- plato/clients/mistnet.py +49 -0
- plato/clients/registry.py +43 -0
- plato/clients/self_supervised_learning.py +51 -0
- plato/clients/simple.py +218 -0
- plato/clients/split_learning.py +150 -0
- plato/config.py +339 -0
- plato/datasources/__init__.py +0 -0
- plato/datasources/base.py +123 -0
- plato/datasources/celeba.py +150 -0
- plato/datasources/cifar10.py +87 -0
- plato/datasources/cifar100.py +61 -0
- plato/datasources/cinic10.py +62 -0
- plato/datasources/coco.py +119 -0
- plato/datasources/datalib/__init__.py +0 -0
- plato/datasources/datalib/audio_extraction_tools.py +137 -0
- plato/datasources/datalib/data_utils.py +124 -0
- plato/datasources/datalib/flickr30kE_utils.py +336 -0
- plato/datasources/datalib/frames_extraction_tools.py +254 -0
- plato/datasources/datalib/gym_utils/__init__.py +0 -0
- plato/datasources/datalib/gym_utils/gym_trim.py +189 -0
- plato/datasources/datalib/modality_data_anntation_tools.py +163 -0
- plato/datasources/datalib/modality_extraction_base.py +59 -0
- plato/datasources/datalib/parse_datasets.py +212 -0
- plato/datasources/datalib/refer_utils/__init__.py +0 -0
- plato/datasources/datalib/refer_utils/referitgame_utils.py +237 -0
- plato/datasources/datalib/tiny_data_tools.py +81 -0
- plato/datasources/datalib/video_transform.py +79 -0
- plato/datasources/emnist.py +64 -0
- plato/datasources/fashion_mnist.py +41 -0
- plato/datasources/feature.py +24 -0
- plato/datasources/feature_dataset.py +15 -0
- plato/datasources/femnist.py +141 -0
- plato/datasources/flickr30k_entities.py +362 -0
- plato/datasources/gym.py +431 -0
- plato/datasources/huggingface.py +165 -0
- plato/datasources/kinetics.py +568 -0
- plato/datasources/mnist.py +44 -0
- plato/datasources/multimodal_base.py +328 -0
- plato/datasources/pascal_voc.py +56 -0
- plato/datasources/purchase.py +94 -0
- plato/datasources/qoenflx.py +127 -0
- plato/datasources/referitgame.py +330 -0
- plato/datasources/registry.py +119 -0
- plato/datasources/self_supervised_learning.py +98 -0
- plato/datasources/stl10.py +103 -0
- plato/datasources/texas.py +94 -0
- plato/datasources/tiny_imagenet.py +64 -0
- plato/datasources/yolov8.py +85 -0
- plato/models/__init__.py +0 -0
- plato/models/cnn_encoder.py +103 -0
- plato/models/dcgan.py +116 -0
- plato/models/general_multilayer.py +254 -0
- plato/models/huggingface.py +27 -0
- plato/models/lenet5.py +113 -0
- plato/models/multilayer.py +90 -0
- plato/models/multimodal/__init__.py +0 -0
- plato/models/multimodal/base_net.py +91 -0
- plato/models/multimodal/blending.py +142 -0
- plato/models/multimodal/fc_net.py +77 -0
- plato/models/multimodal/fusion_net.py +78 -0
- plato/models/multimodal/multimodal_module.py +152 -0
- plato/models/registry.py +99 -0
- plato/models/resnet.py +190 -0
- plato/models/torch_hub.py +19 -0
- plato/models/vgg.py +113 -0
- plato/models/vit.py +166 -0
- plato/models/yolov8.py +22 -0
- plato/processors/__init__.py +0 -0
- plato/processors/base.py +35 -0
- plato/processors/compress.py +46 -0
- plato/processors/decompress.py +48 -0
- plato/processors/feature.py +51 -0
- plato/processors/feature_additive_noise.py +48 -0
- plato/processors/feature_dequantize.py +34 -0
- plato/processors/feature_gaussian.py +17 -0
- plato/processors/feature_laplace.py +15 -0
- plato/processors/feature_quantize.py +34 -0
- plato/processors/feature_randomized_response.py +50 -0
- plato/processors/feature_unbatch.py +39 -0
- plato/processors/inbound_feature_tensors.py +39 -0
- plato/processors/model.py +55 -0
- plato/processors/model_compress.py +34 -0
- plato/processors/model_decompress.py +37 -0
- plato/processors/model_decrypt.py +41 -0
- plato/processors/model_deepcopy.py +21 -0
- plato/processors/model_dequantize.py +18 -0
- plato/processors/model_dequantize_qsgd.py +61 -0
- plato/processors/model_encrypt.py +43 -0
- plato/processors/model_quantize.py +18 -0
- plato/processors/model_quantize_qsgd.py +82 -0
- plato/processors/model_randomized_response.py +34 -0
- plato/processors/outbound_feature_ndarrays.py +38 -0
- plato/processors/pipeline.py +26 -0
- plato/processors/registry.py +124 -0
- plato/processors/structured_pruning.py +57 -0
- plato/processors/unstructured_pruning.py +73 -0
- plato/samplers/__init__.py +0 -0
- plato/samplers/all_inclusive.py +41 -0
- plato/samplers/base.py +31 -0
- plato/samplers/dirichlet.py +81 -0
- plato/samplers/distribution_noniid.py +132 -0
- plato/samplers/iid.py +53 -0
- plato/samplers/label_quantity_noniid.py +119 -0
- plato/samplers/mixed.py +44 -0
- plato/samplers/mixed_label_quantity_noniid.py +128 -0
- plato/samplers/modality_iid.py +42 -0
- plato/samplers/modality_quantity_noniid.py +56 -0
- plato/samplers/orthogonal.py +99 -0
- plato/samplers/registry.py +66 -0
- plato/samplers/sample_quantity_noniid.py +123 -0
- plato/samplers/sampler_utils.py +190 -0
- plato/servers/__init__.py +0 -0
- plato/servers/base.py +1395 -0
- plato/servers/fedavg.py +281 -0
- plato/servers/fedavg_cs.py +335 -0
- plato/servers/fedavg_gan.py +74 -0
- plato/servers/fedavg_he.py +106 -0
- plato/servers/fedavg_personalized.py +57 -0
- plato/servers/mistnet.py +67 -0
- plato/servers/registry.py +52 -0
- plato/servers/split_learning.py +109 -0
- plato/trainers/__init__.py +0 -0
- plato/trainers/base.py +99 -0
- plato/trainers/basic.py +649 -0
- plato/trainers/diff_privacy.py +178 -0
- plato/trainers/gan.py +330 -0
- plato/trainers/huggingface.py +173 -0
- plato/trainers/loss_criterion.py +70 -0
- plato/trainers/lr_schedulers.py +252 -0
- plato/trainers/optimizers.py +53 -0
- plato/trainers/pascal_voc.py +80 -0
- plato/trainers/registry.py +44 -0
- plato/trainers/self_supervised_learning.py +302 -0
- plato/trainers/split_learning.py +305 -0
- plato/trainers/tracking.py +96 -0
- plato/trainers/yolov8.py +41 -0
- plato/utils/__init__.py +0 -0
- plato/utils/count_parameters.py +30 -0
- plato/utils/csv_processor.py +26 -0
- plato/utils/data_loaders.py +148 -0
- plato/utils/decorators.py +24 -0
- plato/utils/fonts.py +23 -0
- plato/utils/homo_enc.py +187 -0
- plato/utils/reinforcement_learning/__init__.py +0 -0
- plato/utils/reinforcement_learning/policies/__init__.py +0 -0
- plato/utils/reinforcement_learning/policies/base.py +161 -0
- plato/utils/reinforcement_learning/policies/ddpg.py +75 -0
- plato/utils/reinforcement_learning/policies/registry.py +32 -0
- plato/utils/reinforcement_learning/policies/sac.py +343 -0
- plato/utils/reinforcement_learning/policies/td3.py +485 -0
- plato/utils/reinforcement_learning/rl_agent.py +142 -0
- plato/utils/reinforcement_learning/rl_server.py +113 -0
- plato/utils/rl_env.py +154 -0
- plato/utils/s3.py +141 -0
- plato/utils/trainer_utils.py +21 -0
- plato/utils/unary_encoding.py +47 -0
- plato_learn-1.1.dist-info/METADATA +35 -0
- plato_learn-1.1.dist-info/RECORD +179 -0
- plato_learn-1.1.dist-info/WHEEL +4 -0
- plato_learn-1.1.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,61 @@
|
|
1
|
+
"""
|
2
|
+
The CIFAR-100 dataset from the torchvision package.
|
3
|
+
"""
|
4
|
+
|
5
|
+
from torchvision import datasets, transforms
|
6
|
+
|
7
|
+
from plato.config import Config
|
8
|
+
from plato.datasources import base
|
9
|
+
|
10
|
+
|
11
|
+
class DataSource(base.DataSource):
    """CIFAR-100 loaded through ``torchvision.datasets.CIFAR100``.

    Both splits live under ``Config().params["data_path"]`` and are fetched
    on first use (``download=True``).

    Keyword Args:
        train_transform: transform applied to training images. When omitted,
            a random horizontal flip, a random 32x32 crop with 4 pixels of
            padding, tensor conversion and per-channel normalization are used.
        test_transform: transform applied to test images. When omitted, only
            tensor conversion and the same normalization are used.
    """

    def __init__(self, **kwargs):
        super().__init__()
        root_dir = Config().params["data_path"]

        def make_normalize():
            # A fresh Normalize for each pipeline.
            # NOTE(review): these values match the commonly used ImageNet
            # channel mean/std rather than statistics computed on CIFAR-100.
            return transforms.Normalize(
                [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
            )

        if "train_transform" in kwargs:
            train_transform = kwargs["train_transform"]
        else:
            train_transform = transforms.Compose(
                [
                    transforms.RandomHorizontalFlip(),
                    transforms.RandomCrop(32, 4),
                    transforms.ToTensor(),
                    make_normalize(),
                ]
            )

        if "test_transform" in kwargs:
            test_transform = kwargs["test_transform"]
        else:
            test_transform = transforms.Compose(
                [transforms.ToTensor(), make_normalize()]
            )

        self.trainset = datasets.CIFAR100(
            root=root_dir, train=True, download=True, transform=train_transform
        )
        self.testset = datasets.CIFAR100(
            root=root_dir, train=False, download=True, transform=test_transform
        )

    def num_train_examples(self):
        """Return the fixed number of CIFAR-100 training images (50,000)."""
        return 50_000

    def num_test_examples(self):
        """Return the fixed number of CIFAR-100 test images (10,000)."""
        return 10_000
|
@@ -0,0 +1,62 @@
|
|
1
|
+
"""
|
2
|
+
The CINIC-10 dataset.
|
3
|
+
|
4
|
+
For more information about CINIC-10, refer to:
|
5
|
+
|
6
|
+
https://github.com/BayesWatch/cinic-10
|
7
|
+
"""
|
8
|
+
|
9
|
+
import logging
|
10
|
+
import os
|
11
|
+
|
12
|
+
from torchvision import datasets, transforms
|
13
|
+
|
14
|
+
from plato.config import Config
|
15
|
+
from plato.datasources import base
|
16
|
+
|
17
|
+
|
18
|
+
class DataSource(base.DataSource):
    """The CINIC-10 dataset, read with ``torchvision.datasets.ImageFolder``.

    Images are expected under ``<data_path>/train`` and ``<data_path>/test``,
    where ``data_path`` is ``Config().params["data_path"]``. If that directory
    does not exist, the dataset archive is downloaded from
    ``Config().data.download_url`` (or a built-in default URL) first.

    Keyword Args:
        train_transform: transform for training images. Defaults to tensor
            conversion followed by per-channel normalization.
        test_transform: transform for test images. Defaults to whatever
            transform is used for training. Before this change, a supplied
            ``test_transform`` was silently ignored.
    """

    def __init__(self, **kwargs):
        super().__init__()
        _path = Config().params["data_path"]

        if not os.path.exists(_path):
            logging.info("Downloading the CINIC-10 dataset. This may take a while.")
            url = (
                Config().data.download_url
                if hasattr(Config().data, "download_url")
                else "http://iqua.ece.toronto.edu/baochun/CINIC-10.tar.gz"
            )
            DataSource.download(url, _path)

        train_transform = (
            kwargs["train_transform"]
            if "train_transform" in kwargs
            else (
                transforms.Compose(
                    [
                        transforms.ToTensor(),
                        transforms.Normalize(
                            [0.47889522, 0.47227842, 0.43047404],
                            [0.24205776, 0.23828046, 0.25874835],
                        ),
                    ]
                )
            )
        )
        # Honour an explicit test transform. Without one, keep the previous
        # behaviour of reusing the training transform.
        test_transform = (
            kwargs["test_transform"]
            if "test_transform" in kwargs
            else train_transform
        )

        self.trainset = datasets.ImageFolder(
            root=os.path.join(_path, "train"), transform=train_transform
        )
        self.testset = datasets.ImageFolder(
            root=os.path.join(_path, "test"), transform=test_transform
        )

    def num_train_examples(self):
        """Return the fixed size of the CINIC-10 training split (90,000)."""
        return 90000

    def num_test_examples(self):
        """Return the fixed size of the CINIC-10 test split (90,000)."""
        return 90000
|
@@ -0,0 +1,119 @@
|
|
1
|
+
"""
|
2
|
+
The MS COCO (Common Objects in Context) dataset is
|
3
|
+
designed to represent a vast array of objects that we
|
4
|
+
regularly encounter in everyday life.
|
5
|
+
|
6
|
+
We mainly utilize COCO-17 (25.20 GB):
|
7
|
+
- COCO has 121,408 images in total.
|
8
|
+
- has 883,331 object annotations
|
9
|
+
- COCO defines 91 classes but the data only uses 80 classes.
|
10
|
+
- Some images from the train and validation sets don’t have annotations.
|
11
|
+
- The test set does not have annotations.
|
12
|
+
- COCO 2014 and 2017 use the same images, but the splits are different.
|
13
|
+
-> for image, detection, segmentation.
|
14
|
+
|
15
|
+
The data structure and setting follows:
|
16
|
+
"https://cocodataset.org/#home".
|
17
|
+
|
18
|
+
Then, the download urls are obtained from:
|
19
|
+
"https://gist.github.com/mkocabas/a6177fc00315403d31572e17700d7fd9".
|
20
|
+
|
21
|
+
We utilize the official splits that contain:
|
22
|
+
- train: 118,287 images
|
23
|
+
- val: 5,000 images
|
24
|
+
- test: 40,670 images
|
25
|
+
|
26
|
+
The file structure of this dataset is:
|
27
|
+
- train2017: train images
|
28
|
+
- test2017: test images
|
29
|
+
- val2017: validation images
|
30
|
+
- annotations_trainval2017: captions for train/val
|
31
|
+
|
32
|
+
The data structure under the 'data/' is:
|
33
|
+
├── COCO2017 # root dir of the COCO2017 dataset
|
34
|
+
│ ├── COCO2017Raw # Raw images/annotations and the official splits
|
35
|
+
│ │ └── annotations
|
36
|
+
│ │ └── train2017
|
37
|
+
│ │ └── test2017
|
38
|
+
│ │ └── val2017
|
39
|
+
│ ├── train # images for the train phase
|
40
|
+
│ └── test # images for the test phase
|
41
|
+
│ └── val # images for the validation phase
|
42
|
+
|
43
|
+
Note:
|
44
|
+
Currently, we do not use the COCO dataset to train models.
|
45
|
+
Thus, we only implement code to download and arrange the data,
|
46
|
+
which is required when using the referitgame dataset.
|
47
|
+
"""
|
48
|
+
|
49
|
+
import os
|
50
|
+
import shutil
|
51
|
+
|
52
|
+
from plato.config import Config
|
53
|
+
from plato.datasources import multimodal_base
|
54
|
+
|
55
|
+
|
56
|
+
class DataSource(multimodal_base.MultiModalDataSource):
    """The MS COCO 2017 dataset with image and text modalities.

    Only downloading and arranging the raw data is implemented. On first use,
    meaning the ``<dataname>Raw`` directory does not exist yet:

    * the data for every split is downloaded and unpacked into that split's
      directory, and the unpacked folder is renamed to ``images``;
    * the annotation archive is downloaded, and the train/val caption files
      are copied into the matching split directories as ``captions.json``.
    """

    def __init__(self, **kwargs):
        super().__init__()

        self.data_name = Config().data.dataname
        self.data_source = Config().data.datasource
        self.modality_names = ["image", "text"]

        self._data_path_process(
            data_path=Config().params["data_path"], base_data_name=self.data_name
        )

        raw_data_path = os.path.join(
            self.mm_data_info["data_path"], self.data_name + "Raw"
        )
        if self._exists(raw_data_path):
            # Raw data were already fetched on an earlier run.
            return
        os.makedirs(raw_data_path, exist_ok=True)

        # Read every download URL before any download begins.
        split_urls = {
            "train": Config().data.download_train_url,
            "test": Config().data.download_test_url,
            "val": Config().data.download_val_url,
        }
        annotation_url = Config().data.download_annotation_url

        for split_name in list(self.splits_info):
            split_dir = self.splits_info[split_name]["path"]
            extracted_name = self._download_arrange_data(
                download_url_address=split_urls[split_name],
                data_path=raw_data_path,
                extract_to_dir=split_dir,
            )
            # Give the unpacked folder a split-independent name.
            os.rename(
                src=os.path.join(split_dir, extracted_name),
                dst=os.path.join(split_dir, "images"),
            )

        self._download_arrange_data(
            download_url_address=annotation_url, data_path=raw_data_path
        )
        annotation_dir = os.path.join(raw_data_path, "annotations")

        # The test split has no annotations, so only train/val get captions.
        caption_files = {
            "train": "captions_train2017.json",
            "val": "captions_val2017.json",
        }
        for split_name, caption_file in caption_files.items():
            shutil.copyfile(
                src=os.path.join(annotation_dir, caption_file),
                dst=os.path.join(
                    self.splits_info[split_name]["path"], "captions.json"
                ),
            )
|
File without changes
|
@@ -0,0 +1,137 @@
|
|
1
|
+
"""
|
2
|
+
Tools for extracting information from the audio
|
3
|
+
|
4
|
+
"""
|
5
|
+
|
6
|
+
import glob
|
7
|
+
import os
|
8
|
+
|
9
|
+
from multiprocessing import Pool
|
10
|
+
|
11
|
+
from mmaction.tools.data import build_audio_features
|
12
|
+
|
13
|
+
from plato.datasources.datalib import modality_extraction_base
|
14
|
+
|
15
|
+
|
16
|
+
def obtain_audio_dest_dir(out_dir, audio_path, dir_level):
    """Return the directory in which output for ``audio_path`` should be saved.

    Args:
        out_dir: root output directory.
        audio_path: path of the source audio file.
        dir_level: with ``2``, ``audio_path`` is expected to sit inside a class
            sub-directory. That directory's name is appended to ``out_dir``.
            With any other value, ``out_dir`` is returned unchanged.

    Returns:
        The destination directory path.
    """
    if dir_level != 2:
        # No class sub-directory in the layout.
        return out_dir
    class_name = os.path.basename(os.path.dirname(audio_path))
    return os.path.join(out_dir, class_name)
|
29
|
+
|
30
|
+
|
31
|
+
def extract_audio_wav(line_times):
    """Extract the audio track of one video into a ``.wav`` file with FFmpeg.

    Args:
        line_times: a ``(video_path, root, out_dir)`` tuple. The wav file is
            written to ``out_dir/<video dir relative to root>/<video stem>.wav``.

    A video whose wav file already exists is skipped. The video path is
    appended to ``extract_wav_err_file.txt`` in the current working directory
    in three cases: the output directory cannot be created, FFmpeg cannot be
    launched, or FFmpeg exits with a non-zero status.
    """
    import subprocess  # local import keeps this module's import block unchanged

    line, root, out_dir = line_times
    video_id, _ = os.path.splitext(os.path.basename(line))
    video_rel_dir = os.path.relpath(os.path.dirname(line), root)
    dst_dir = os.path.join(out_dir, video_rel_dir)
    wav_path = os.path.join(dst_dir, f"{video_id}.wav")

    try:
        # Create the directory synchronously. An un-waited `mkdir -p` via
        # os.popen could still be running when FFmpeg starts writing.
        os.makedirs(dst_dir, exist_ok=True)
        if os.path.exists(wav_path):
            return
        # An argument list (no shell) keeps paths with spaces or shell
        # metacharacters intact. `run` blocks until FFmpeg exits, so the
        # caller's worker pool bounds how many FFmpeg processes run at once,
        # and `check=True` surfaces failures.
        subprocess.run(
            ["ffmpeg", "-i", line, "-map", "0:a", "-y", wav_path],
            stdin=subprocess.DEVNULL,
            stdout=subprocess.DEVNULL,
            check=True,
        )
    except (OSError, subprocess.CalledProcessError):
        with open("extract_wav_err_file.txt", "a+", encoding="utf-8") as error_file:
            error_file.write(f"{line}\n")
|
47
|
+
|
48
|
+
|
49
|
+
class VideoAudioExtractor(modality_extraction_base.VideoExtractorBase):
    """Extract audio tracks, and features of those tracks, from videos."""

    def __init__(
        self,
        video_src_dir,
        dir_level=2,
        num_worker=2,
        video_ext="mp4",
        mixed_ext=False,
        audio_ext="wav",
    ):
        """Create the extractor.

        ``video_src_dir``, ``dir_level``, ``num_worker``, ``video_ext`` and
        ``mixed_ext`` are forwarded to ``VideoExtractorBase``. ``audio_ext``
        is the extension used to find audio files in
        ``build_audios_features``.
        """
        super().__init__(video_src_dir, dir_level, num_worker, video_ext, mixed_ext)
        self.audio_ext = audio_ext

    def build_audios(
        self,
        to_dir,
    ):
        """Extract the audio of every video into ``to_dir`` in parallel.

        Uses ``self.num_worker`` processes. The output files are always
        written as ``.wav`` by ``extract_audio_wav``, whatever
        ``self.audio_ext`` is.
        """
        src_video_dir = self.video_src_dir
        if self.dir_level == 2:
            self.organize_modality_dir(src_dir=src_video_dir, to_dir=to_dir)

        num_videos = len(self.videos_path_list)
        # The context manager releases the worker processes. Before this
        # change the pool was never closed or joined.
        with Pool(self.num_worker) as pool:
            pool.map(
                extract_audio_wav,
                zip(
                    self.fullpath_list,
                    num_videos * [src_video_dir],
                    num_videos * [to_dir],
                ),
            )

    def build_audios_features(
        self,
        audio_src_path,  # the dir that contains the src audio files
        to_dir,  # dir to save the extracted features
        frame_rate=30,  # The frame rate per second of the video.
        sample_rate=16000,  # The sample rate for audio sampling
        num_mels=80,  # Number of channels of the melspectrogram.
        fft_size=1280,  # fft_size / sample_rate is window size
        hop_size=320,  # hop_size / sample_rate is step size
        spectrogram_type="lws",  # 'lws' (recommended) or 'librosa'
        part="1/1",
    ):
        """Compute features for the audio files under ``audio_src_path``.

        Files are matched with ``self.dir_level`` levels of ``*`` followed by
        ``.<self.audio_ext>`` and processed in sorted order, in parallel with
        ``self.num_worker`` processes. A failure on an individual file is
        logged as a warning; it does not stop the other files.

        Args:
            part: ``"k/n"``. Split the sorted file list into ``n`` parts and
                process only the ``k``-th (1-based), e.g. ``"2/5"``. Useful
                when several machines share the work. ``None`` means
                ``"1/1"``.

        Raises:
            ValueError: if ``part`` is not of the form ``"k/n"`` with
                ``1 <= k <= n``.
        """
        import logging  # local import keeps this module's import block unchanged

        audio_tools = build_audio_features.AudioTools(
            frame_rate=frame_rate,
            sample_rate=sample_rate,
            num_mels=num_mels,
            fft_size=fft_size,
            hop_size=hop_size,
            spectrogram_type=spectrogram_type,
        )

        files = sorted(
            glob.glob(audio_src_path + "/*" * self.dir_level + "." + self.audio_ext)
        )

        # `part=None` previously left this_part/num_parts undefined.
        if part is None:
            part = "1/1"
        this_part, num_parts = (int(piece) for piece in part.split("/"))
        if not 1 <= this_part <= num_parts:
            raise ValueError(f"part must be 'k/n' with 1 <= k <= n, got {part!r}")

        part_len = len(files) // num_parts
        start = part_len * (this_part - 1)
        # The last part also takes the remainder left by integer division.
        stop = len(files) if this_part == num_parts else part_len * this_part

        with Pool(self.num_worker) as extractor_pool:
            for audio_file in files[start:stop]:
                out_full_path = obtain_audio_dest_dir(
                    out_dir=to_dir, audio_path=audio_file, dir_level=self.dir_level
                )
                os.makedirs(out_full_path, exist_ok=True)

                extractor_pool.apply_async(
                    build_audio_features.extract_audio_feature,
                    args=(audio_file, audio_tools, out_full_path),
                    # Runs in this process. Without it, worker exceptions
                    # were silently discarded.
                    error_callback=lambda exc, path=audio_file: logging.warning(
                        "Failed to extract audio features from %s: %s", path, exc
                    ),
                )
            # Wait for every queued task before the context manager exits.
            extractor_pool.close()
            extractor_pool.join()
|
@@ -0,0 +1,124 @@
|
|
1
|
+
"""
|
2
|
+
Useful tools for processing the data
|
3
|
+
|
4
|
+
"""
|
5
|
+
|
6
|
+
import shutil
|
7
|
+
import os
|
8
|
+
import json
|
9
|
+
import numpy as np
|
10
|
+
|
11
|
+
|
12
|
+
def config_to_dict(plato_config):
    """Convert a (possibly nested) plato config namedtuple into plain dicts.

    Nested namedtuples become dicts, and lists are copied with their items
    converted recursively. Other values are kept as they are. The input is
    NOT modified. Before this change, namedtuples stored inside lists were
    overwritten in place, which corrupted the live config.

    Args:
        plato_config: a namedtuple (anything exposing ``_asdict()``).

    Returns:
        A ``dict`` mirroring ``plato_config``.
    """

    def _convert(value):
        # Standard namedtuple check: a tuple subclass providing _asdict().
        if isinstance(value, tuple) and hasattr(value, "_asdict"):
            return {key: _convert(item) for key, item in value._asdict().items()}
        if isinstance(value, list):
            return [_convert(item) for item in value]
        return value

    return {key: _convert(value) for key, value in plato_config._asdict().items()}
|
37
|
+
|
38
|
+
|
39
|
+
def dict_list2tuple(dict_obj):
    """Convert list values in ``dict_obj`` to tuples, in place.

    Conversions performed:

    * a top-level list value becomes a tuple, and dicts inside that list are
      converted recursively;
    * list values of a nested dict (one level deep) become tuples.

    A list that is empty, or whose first element is ``None`` (mainly for
    ``meta_keys``), becomes ``()``.

    Returns:
        The same ``dict_obj`` object, after conversion.
    """

    def _as_tuple(items):
        if not items or items[0] is None:
            return ()
        return tuple(items)

    for key, value in dict_obj.items():
        if isinstance(value, dict):
            for inner_key, inner_value in value.items():
                if isinstance(inner_value, list):
                    # Test the inner list itself so that an empty inner list
                    # maps to () instead of raising IndexError.
                    value[inner_key] = _as_tuple(inner_value)
        elif isinstance(value, list):
            dict_obj[key] = tuple(
                dict_list2tuple(item) if isinstance(item, dict) else item
                for item in _as_tuple(value)
            )

    return dict_obj
|
63
|
+
|
64
|
+
|
65
|
+
def phrase_boxes_alignment(flatten_boxes, ori_phrases_boxes):
    """Regroup a flat sequence of boxes to follow the per-phrase grouping.

    The i-th returned group holds as many consecutive entries of
    ``flatten_boxes`` as the i-th entry of ``ori_phrases_boxes`` has.

    Args:
        flatten_boxes: indexable sequence of boxes, in phrase order.
        ori_phrases_boxes: sequence of per-phrase box collections; only their
            lengths are used.

    Returns:
        A list of lists of boxes, one list per phrase.
    """
    group_sizes = [len(group) for group in ori_phrases_boxes]

    regrouped = []
    offset = 0
    for size in group_sizes:
        regrouped.append([flatten_boxes[idx] for idx in range(offset, offset + size)])
        offset += size

    assert [len(group) for group in regrouped] == group_sizes

    return regrouped
|
89
|
+
|
90
|
+
|
91
|
+
def list_inorder(listed_files, flag_str):
    """Return the names containing ``flag_str``, sorted by their stem.

    The stem is the text before the first ``.`` after stripping whitespace.
    Stems are compared as strings, so ``"10"`` sorts before ``"2"``. The
    sort is stable, so names with equal stems keep their input order.
    """

    def stem_of(name):
        return name.strip().split(".")[0]

    return sorted((name for name in listed_files if flag_str in name), key=stem_of)
|
96
|
+
|
97
|
+
|
98
|
+
def copy_files(src_files, dst_dir):
    """Copy each path in ``src_files`` to ``dst_dir`` with :func:`shutil.copy`."""
    for source_path in src_files:
        shutil.copy(src=source_path, dst=dst_dir)
|
102
|
+
|
103
|
+
|
104
|
+
def union_shuffled_lists(src_lists):
    """Shuffle several equally long sequences with one shared permutation.

    Args:
        src_lists: sequence of array-likes that all have the same length.

    Returns:
        A list containing ``np.array(seq)[perm]`` for each input ``seq``.
        ``perm`` is a single permutation drawn from NumPy's global RNG, so the
        i-th entries of all outputs still belong together. An empty
        ``src_lists`` yields ``[]`` (previously an IndexError).

    Raises:
        ValueError: if the sequences differ in length. This used to be an
            ``assert``, which is skipped under ``python -O``.
    """
    # len() rather than truthiness, so a NumPy array of lists also works.
    if len(src_lists) == 0:
        return []

    expected_len = len(src_lists[0])
    if any(len(seq) != expected_len for seq in src_lists):
        raise ValueError("all lists passed to union_shuffled_lists must have equal length")

    permutation = np.random.permutation(expected_len)
    return [np.array(seq)[permutation] for seq in src_lists]
|
111
|
+
|
112
|
+
|
113
|
+
def read_anno_file(anno_file_path):
    """Load an annotation file.

    Returns the parsed object for a ``.json`` file (decided by the text after
    the last ``.`` in the file name). Any other file is returned as a list of
    its lines, with newlines kept.
    """
    extension = os.path.basename(anno_file_path).split(".")[-1]

    with open(anno_file_path, "r") as anno_file:
        if extension == "json":
            return json.load(anno_file)
        return anno_file.readlines()
|