plato-learn 1.1 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (179)
  1. plato/__init__.py +1 -0
  2. plato/algorithms/__init__.py +0 -0
  3. plato/algorithms/base.py +45 -0
  4. plato/algorithms/fedavg.py +48 -0
  5. plato/algorithms/fedavg_gan.py +79 -0
  6. plato/algorithms/fedavg_personalized.py +48 -0
  7. plato/algorithms/mistnet.py +52 -0
  8. plato/algorithms/registry.py +39 -0
  9. plato/algorithms/split_learning.py +89 -0
  10. plato/callbacks/__init__.py +0 -0
  11. plato/callbacks/client.py +56 -0
  12. plato/callbacks/handler.py +78 -0
  13. plato/callbacks/server.py +139 -0
  14. plato/callbacks/trainer.py +124 -0
  15. plato/client.py +67 -0
  16. plato/clients/__init__.py +0 -0
  17. plato/clients/base.py +467 -0
  18. plato/clients/edge.py +103 -0
  19. plato/clients/fedavg_personalized.py +40 -0
  20. plato/clients/mistnet.py +49 -0
  21. plato/clients/registry.py +43 -0
  22. plato/clients/self_supervised_learning.py +51 -0
  23. plato/clients/simple.py +218 -0
  24. plato/clients/split_learning.py +150 -0
  25. plato/config.py +339 -0
  26. plato/datasources/__init__.py +0 -0
  27. plato/datasources/base.py +123 -0
  28. plato/datasources/celeba.py +150 -0
  29. plato/datasources/cifar10.py +87 -0
  30. plato/datasources/cifar100.py +61 -0
  31. plato/datasources/cinic10.py +62 -0
  32. plato/datasources/coco.py +119 -0
  33. plato/datasources/datalib/__init__.py +0 -0
  34. plato/datasources/datalib/audio_extraction_tools.py +137 -0
  35. plato/datasources/datalib/data_utils.py +124 -0
  36. plato/datasources/datalib/flickr30kE_utils.py +336 -0
  37. plato/datasources/datalib/frames_extraction_tools.py +254 -0
  38. plato/datasources/datalib/gym_utils/__init__.py +0 -0
  39. plato/datasources/datalib/gym_utils/gym_trim.py +189 -0
  40. plato/datasources/datalib/modality_data_anntation_tools.py +163 -0
  41. plato/datasources/datalib/modality_extraction_base.py +59 -0
  42. plato/datasources/datalib/parse_datasets.py +212 -0
  43. plato/datasources/datalib/refer_utils/__init__.py +0 -0
  44. plato/datasources/datalib/refer_utils/referitgame_utils.py +237 -0
  45. plato/datasources/datalib/tiny_data_tools.py +81 -0
  46. plato/datasources/datalib/video_transform.py +79 -0
  47. plato/datasources/emnist.py +64 -0
  48. plato/datasources/fashion_mnist.py +41 -0
  49. plato/datasources/feature.py +24 -0
  50. plato/datasources/feature_dataset.py +15 -0
  51. plato/datasources/femnist.py +141 -0
  52. plato/datasources/flickr30k_entities.py +362 -0
  53. plato/datasources/gym.py +431 -0
  54. plato/datasources/huggingface.py +165 -0
  55. plato/datasources/kinetics.py +568 -0
  56. plato/datasources/mnist.py +44 -0
  57. plato/datasources/multimodal_base.py +328 -0
  58. plato/datasources/pascal_voc.py +56 -0
  59. plato/datasources/purchase.py +94 -0
  60. plato/datasources/qoenflx.py +127 -0
  61. plato/datasources/referitgame.py +330 -0
  62. plato/datasources/registry.py +119 -0
  63. plato/datasources/self_supervised_learning.py +98 -0
  64. plato/datasources/stl10.py +103 -0
  65. plato/datasources/texas.py +94 -0
  66. plato/datasources/tiny_imagenet.py +64 -0
  67. plato/datasources/yolov8.py +85 -0
  68. plato/models/__init__.py +0 -0
  69. plato/models/cnn_encoder.py +103 -0
  70. plato/models/dcgan.py +116 -0
  71. plato/models/general_multilayer.py +254 -0
  72. plato/models/huggingface.py +27 -0
  73. plato/models/lenet5.py +113 -0
  74. plato/models/multilayer.py +90 -0
  75. plato/models/multimodal/__init__.py +0 -0
  76. plato/models/multimodal/base_net.py +91 -0
  77. plato/models/multimodal/blending.py +142 -0
  78. plato/models/multimodal/fc_net.py +77 -0
  79. plato/models/multimodal/fusion_net.py +78 -0
  80. plato/models/multimodal/multimodal_module.py +152 -0
  81. plato/models/registry.py +99 -0
  82. plato/models/resnet.py +190 -0
  83. plato/models/torch_hub.py +19 -0
  84. plato/models/vgg.py +113 -0
  85. plato/models/vit.py +166 -0
  86. plato/models/yolov8.py +22 -0
  87. plato/processors/__init__.py +0 -0
  88. plato/processors/base.py +35 -0
  89. plato/processors/compress.py +46 -0
  90. plato/processors/decompress.py +48 -0
  91. plato/processors/feature.py +51 -0
  92. plato/processors/feature_additive_noise.py +48 -0
  93. plato/processors/feature_dequantize.py +34 -0
  94. plato/processors/feature_gaussian.py +17 -0
  95. plato/processors/feature_laplace.py +15 -0
  96. plato/processors/feature_quantize.py +34 -0
  97. plato/processors/feature_randomized_response.py +50 -0
  98. plato/processors/feature_unbatch.py +39 -0
  99. plato/processors/inbound_feature_tensors.py +39 -0
  100. plato/processors/model.py +55 -0
  101. plato/processors/model_compress.py +34 -0
  102. plato/processors/model_decompress.py +37 -0
  103. plato/processors/model_decrypt.py +41 -0
  104. plato/processors/model_deepcopy.py +21 -0
  105. plato/processors/model_dequantize.py +18 -0
  106. plato/processors/model_dequantize_qsgd.py +61 -0
  107. plato/processors/model_encrypt.py +43 -0
  108. plato/processors/model_quantize.py +18 -0
  109. plato/processors/model_quantize_qsgd.py +82 -0
  110. plato/processors/model_randomized_response.py +34 -0
  111. plato/processors/outbound_feature_ndarrays.py +38 -0
  112. plato/processors/pipeline.py +26 -0
  113. plato/processors/registry.py +124 -0
  114. plato/processors/structured_pruning.py +57 -0
  115. plato/processors/unstructured_pruning.py +73 -0
  116. plato/samplers/__init__.py +0 -0
  117. plato/samplers/all_inclusive.py +41 -0
  118. plato/samplers/base.py +31 -0
  119. plato/samplers/dirichlet.py +81 -0
  120. plato/samplers/distribution_noniid.py +132 -0
  121. plato/samplers/iid.py +53 -0
  122. plato/samplers/label_quantity_noniid.py +119 -0
  123. plato/samplers/mixed.py +44 -0
  124. plato/samplers/mixed_label_quantity_noniid.py +128 -0
  125. plato/samplers/modality_iid.py +42 -0
  126. plato/samplers/modality_quantity_noniid.py +56 -0
  127. plato/samplers/orthogonal.py +99 -0
  128. plato/samplers/registry.py +66 -0
  129. plato/samplers/sample_quantity_noniid.py +123 -0
  130. plato/samplers/sampler_utils.py +190 -0
  131. plato/servers/__init__.py +0 -0
  132. plato/servers/base.py +1395 -0
  133. plato/servers/fedavg.py +281 -0
  134. plato/servers/fedavg_cs.py +335 -0
  135. plato/servers/fedavg_gan.py +74 -0
  136. plato/servers/fedavg_he.py +106 -0
  137. plato/servers/fedavg_personalized.py +57 -0
  138. plato/servers/mistnet.py +67 -0
  139. plato/servers/registry.py +52 -0
  140. plato/servers/split_learning.py +109 -0
  141. plato/trainers/__init__.py +0 -0
  142. plato/trainers/base.py +99 -0
  143. plato/trainers/basic.py +649 -0
  144. plato/trainers/diff_privacy.py +178 -0
  145. plato/trainers/gan.py +330 -0
  146. plato/trainers/huggingface.py +173 -0
  147. plato/trainers/loss_criterion.py +70 -0
  148. plato/trainers/lr_schedulers.py +252 -0
  149. plato/trainers/optimizers.py +53 -0
  150. plato/trainers/pascal_voc.py +80 -0
  151. plato/trainers/registry.py +44 -0
  152. plato/trainers/self_supervised_learning.py +302 -0
  153. plato/trainers/split_learning.py +305 -0
  154. plato/trainers/tracking.py +96 -0
  155. plato/trainers/yolov8.py +41 -0
  156. plato/utils/__init__.py +0 -0
  157. plato/utils/count_parameters.py +30 -0
  158. plato/utils/csv_processor.py +26 -0
  159. plato/utils/data_loaders.py +148 -0
  160. plato/utils/decorators.py +24 -0
  161. plato/utils/fonts.py +23 -0
  162. plato/utils/homo_enc.py +187 -0
  163. plato/utils/reinforcement_learning/__init__.py +0 -0
  164. plato/utils/reinforcement_learning/policies/__init__.py +0 -0
  165. plato/utils/reinforcement_learning/policies/base.py +161 -0
  166. plato/utils/reinforcement_learning/policies/ddpg.py +75 -0
  167. plato/utils/reinforcement_learning/policies/registry.py +32 -0
  168. plato/utils/reinforcement_learning/policies/sac.py +343 -0
  169. plato/utils/reinforcement_learning/policies/td3.py +485 -0
  170. plato/utils/reinforcement_learning/rl_agent.py +142 -0
  171. plato/utils/reinforcement_learning/rl_server.py +113 -0
  172. plato/utils/rl_env.py +154 -0
  173. plato/utils/s3.py +141 -0
  174. plato/utils/trainer_utils.py +21 -0
  175. plato/utils/unary_encoding.py +47 -0
  176. plato_learn-1.1.dist-info/METADATA +35 -0
  177. plato_learn-1.1.dist-info/RECORD +179 -0
  178. plato_learn-1.1.dist-info/WHEEL +4 -0
  179. plato_learn-1.1.dist-info/licenses/LICENSE +201 -0
plato/datasources/cifar100.py
@@ -0,0 +1,61 @@
+ """
+ The CIFAR-100 dataset from the torchvision package.
+ """
+
+ from torchvision import datasets, transforms
+
+ from plato.config import Config
+ from plato.datasources import base
+
+
+ class DataSource(base.DataSource):
+     """The CIFAR-100 dataset."""
+
+     def __init__(self, **kwargs):
+         super().__init__()
+         _path = Config().params["data_path"]
+
+         train_transform = (
+             kwargs["train_transform"]
+             if "train_transform" in kwargs
+             else (
+                 transforms.Compose(
+                     [
+                         transforms.RandomHorizontalFlip(),
+                         transforms.RandomCrop(32, 4),
+                         transforms.ToTensor(),
+                         transforms.Normalize(
+                             [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
+                         ),
+                     ]
+                 )
+             )
+         )
+
+         test_transform = (
+             kwargs["test_transform"]
+             if "test_transform" in kwargs
+             else (
+                 transforms.Compose(
+                     [
+                         transforms.ToTensor(),
+                         transforms.Normalize(
+                             [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
+                         ),
+                     ]
+                 )
+             )
+         )
+
+         self.trainset = datasets.CIFAR100(
+             root=_path, train=True, download=True, transform=train_transform
+         )
+         self.testset = datasets.CIFAR100(
+             root=_path, train=False, download=True, transform=test_transform
+         )
+
+     def num_train_examples(self):
+         return 50000
+
+     def num_test_examples(self):
+         return 10000
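
The constructor above falls back to ImageNet normalization statistics unless transforms are supplied through kwargs. A minimal usage sketch (not part of the diff), assuming a Plato configuration has already been loaded so that Config().params["data_path"] resolves; the CIFAR-100 channel statistics shown are the commonly cited ones:

from torchvision import transforms

from plato.datasources import cifar100

# Override only the training transform; the test transform keeps its default.
custom_train = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, 4),
        transforms.ToTensor(),
        # CIFAR-100 channel statistics instead of the ImageNet defaults
        transforms.Normalize([0.5071, 0.4865, 0.4409], [0.2673, 0.2564, 0.2762]),
    ]
)

datasource = cifar100.DataSource(train_transform=custom_train)
print(datasource.num_train_examples())  # 50000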
plato/datasources/cinic10.py
@@ -0,0 +1,62 @@
+ """
+ The CINIC-10 dataset.
+
+ For more information about CINIC-10, refer to:
+
+ https://github.com/BayesWatch/cinic-10
+ """
+
+ import logging
+ import os
+
+ from torchvision import datasets, transforms
+
+ from plato.config import Config
+ from plato.datasources import base
+
+
+ class DataSource(base.DataSource):
+     """The CINIC-10 dataset."""
+
+     def __init__(self, **kwargs):
+         super().__init__()
+         _path = Config().params["data_path"]
+
+         if not os.path.exists(_path):
+             logging.info("Downloading the CINIC-10 dataset. This may take a while.")
+             url = (
+                 Config().data.download_url
+                 if hasattr(Config().data, "download_url")
+                 else "http://iqua.ece.toronto.edu/baochun/CINIC-10.tar.gz"
+             )
+             DataSource.download(url, _path)
+
+         train_transform = (
+             kwargs["train_transform"]
+             if "train_transform" in kwargs
+             else (
+                 transforms.Compose(
+                     [
+                         transforms.ToTensor(),
+                         transforms.Normalize(
+                             [0.47889522, 0.47227842, 0.43047404],
+                             [0.24205776, 0.23828046, 0.25874835],
+                         ),
+                     ]
+                 )
+             )
+         )
+         test_transform = train_transform
+
+         self.trainset = datasets.ImageFolder(
+             root=os.path.join(_path, "train"), transform=train_transform
+         )
+         self.testset = datasets.ImageFolder(
+             root=os.path.join(_path, "test"), transform=test_transform
+         )
+
+     def num_train_examples(self):
+         return 90000
+
+     def num_test_examples(self):
+         return 90000
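
Because the datasource wraps plain ImageFolder directories, it plugs directly into a standard PyTorch DataLoader. A brief usage sketch, assuming a loaded Plato configuration whose data_path points at (or will receive) the extracted CINIC-10 archive; the first instantiation triggers the download shown above:

from torch.utils.data import DataLoader

from plato.datasources import cinic10

datasource = cinic10.DataSource()
train_loader = DataLoader(datasource.trainset, batch_size=64, shuffle=True)

images, labels = next(iter(train_loader))
print(images.shape)  # torch.Size([64, 3, 32, 32]) for CINIC-10's 32x32 RGB images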
plato/datasources/coco.py
@@ -0,0 +1,119 @@
+ """
+ MS COCO stands for Common Objects in Context; the dataset is designed
+ to represent the vast array of objects that we
+ regularly encounter in everyday life.
+
+ We mainly utilize COCO-17 (25.20 GB):
+ - COCO has 121,408 images in total.
+ - It has 883,331 object annotations.
+ - COCO defines 91 classes, but the data only uses 80 classes.
+ - Some images from the train and validation sets don't have annotations.
+ - The test set does not have annotations.
+ - COCO 2014 and 2017 use the same images, but the splits are different.
+ -> used for image detection and segmentation tasks.
+
+ The data structure and settings follow:
+ "https://cocodataset.org/#home".
+
+ The download URLs are obtained from:
+ "https://gist.github.com/mkocabas/a6177fc00315403d31572e17700d7fd9".
+
+ We utilize the official splits, which contain:
+ - train: 118,287 images
+ - val: 5,000 images
+ - test: 40,670 images
+
+ The file structure of this dataset is:
+ - train2017: train images
+ - test2017: test images
+ - val2017: validation images
+ - annotations_trainval2017: captions for train/val
+
+ The data structure under 'data/' is:
+ ├── COCO2017          # root dir of the COCO2017 dataset
+ │   ├── COCO2017Raw   # raw images/annotations and the official splits
+ │   │   ├── annotations
+ │   │   ├── train2017
+ │   │   ├── test2017
+ │   │   └── val2017
+ │   ├── train         # images for the train phase
+ │   ├── test          # images for the test phase
+ │   └── val           # images for the validation phase
+
+ Note:
+ Currently, we have not utilized the COCO dataset to train any model.
+ Thus, we only implement the code for downloading and arranging the data,
+ which is required when using the ReferItGame dataset.
+ """
+
+ import os
+ import shutil
+
+ from plato.config import Config
+ from plato.datasources import multimodal_base
+
+
+ class DataSource(multimodal_base.MultiModalDataSource):
+     """The COCO dataset."""
+
+     def __init__(self, **kwargs):
+         super().__init__()
+
+         self.data_name = Config().data.dataname
+         self.data_source = Config().data.datasource
+
+         self.modality_names = ["image", "text"]
+
+         _path = Config().params["data_path"]
+         self._data_path_process(data_path=_path, base_data_name=self.data_name)
+
+         base_data_path = self.mm_data_info["data_path"]
+         raw_data_name = self.data_name + "Raw"
+         raw_data_path = os.path.join(base_data_path, raw_data_name)
+         if not self._exists(raw_data_path):
+             os.makedirs(raw_data_path, exist_ok=True)
+
+         download_train_url = Config().data.download_train_url
+         download_test_url = Config().data.download_test_url
+         download_val_url = Config().data.download_val_url
+         download_annotation_url = Config().data.download_annotation_url
+
+         splits_downloads = {
+             "train": download_train_url,
+             "test": download_test_url,
+             "val": download_val_url,
+         }
+
+         # Download the raw data and extract it into the different splits
+         for split_name in list(self.splits_info.keys()):
+             split_path = self.splits_info[split_name]["path"]
+             split_download_url = splits_downloads[split_name]
+             split_file_name = self._download_arrange_data(
+                 download_url_address=split_download_url,
+                 data_path=raw_data_path,
+                 extract_to_dir=split_path,
+             )
+             # Rename the extracted directory to "images"
+             extracted_path = os.path.join(split_path, split_file_name)
+             renamed_path = os.path.join(split_path, "images")
+             os.rename(src=extracted_path, dst=renamed_path)
+
+         # Download the annotations
+         self._download_arrange_data(
+             download_url_address=download_annotation_url, data_path=raw_data_path
+         )
+         annotation_path = os.path.join(raw_data_path, "annotations")
+
+         # Copy the captions file into each split
+         splits_caption_name = {
+             "train": "captions_train2017.json",
+             "val": "captions_val2017.json",
+         }
+         for split_name in list(splits_caption_name.keys()):
+             split_caption_name = splits_caption_name[split_name]
+             to_split_path = os.path.join(
+                 self.splits_info[split_name]["path"], "captions.json"
+             )
+             shutil.copyfile(
+                 src=os.path.join(annotation_path, split_caption_name), dst=to_split_path
+             )
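
After the constructor finishes, each split directory should contain an images/ folder, and the train/ and val/ splits should also hold a captions.json. A hedged verification sketch, assuming data_path is ./data and dataname is COCO2017; the split paths below mirror the docstring's tree and may differ from the actual splits_info entries:

import os

base = os.path.join("./data", "COCO2017")

for split in ("train", "test", "val"):
    images_dir = os.path.join(base, split, "images")
    print(split, "images present:", os.path.isdir(images_dir))

for split in ("train", "val"):
    captions = os.path.join(base, split, "captions.json")
    print(split, "captions present:", os.path.isfile(captions))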
plato/datasources/datalib/__init__.py
File without changes
plato/datasources/datalib/audio_extraction_tools.py
@@ -0,0 +1,137 @@
+ """
+ Tools for extracting information from the audio tracks of videos.
+
+ """
+
+ import glob
+ import os
+
+ from multiprocessing import Pool
+
+ from mmaction.tools.data import build_audio_features
+
+ from plato.datasources.datalib import modality_extraction_base
+
+
+ def obtain_audio_dest_dir(out_dir, audio_path, dir_level):
+     """Get the destination path for saving the audio."""
+
+     if dir_level == 2:
+         class_name = os.path.basename(os.path.dirname(audio_path))
+         _, tail = os.path.split(audio_path)
+         _ = tail.split(".")[0]  # the audio name, currently unused
+         out_dir_full_path = os.path.join(out_dir, class_name)
+     else:  # the class name is not contained in the path
+         _ = audio_path.split(".")[0]  # the audio name, currently unused
+         out_dir_full_path = out_dir
+
+     return out_dir_full_path
+
+
+ def extract_audio_wav(line_times):
+     """Extract the audio wave from video streams using FFMPEG."""
+     line, root, out_dir = line_times
+     video_id, _ = os.path.splitext(os.path.basename(line))
+     video_dir = os.path.dirname(line)
+     video_rel_dir = os.path.relpath(video_dir, root)
+     dst_dir = os.path.join(out_dir, video_rel_dir)
+     os.makedirs(dst_dir, exist_ok=True)
+     try:
+         if os.path.exists(f"{dst_dir}/{video_id}.wav"):
+             return
+         cmd = f"ffmpeg -i {line} -map 0:a -y {dst_dir}/{video_id}.wav"
+         os.popen(cmd)
+     except OSError:
+         with open("extract_wav_err_file.txt", "a+") as error_file:
+             error_file.write(f"{line}\n")
+
+
+ class VideoAudioExtractor(modality_extraction_base.VideoExtractorBase):
+     """A class for extracting audio from videos."""
+
+     def __init__(
+         self,
+         video_src_dir,
+         dir_level=2,
+         num_worker=2,
+         video_ext="mp4",
+         mixed_ext=False,
+         audio_ext="wav",
+     ):
+         super().__init__(video_src_dir, dir_level, num_worker, video_ext, mixed_ext)
+         self.audio_ext = audio_ext
+
+     def build_audios(
+         self,
+         to_dir,
+     ):
+         """Extract audio files in parallel."""
+         source_video_dir = self.video_src_dir
+         if self.dir_level == 2:
+             self.organize_modality_dir(src_dir=source_video_dir, to_dir=to_dir)
+         _ = glob.glob(to_dir + "/*" * self.dir_level + ".wav")
+
+         pool = Pool(self.num_worker)
+         pool.map(
+             extract_audio_wav,
+             zip(
+                 self.fullpath_list,
+                 len(self.videos_path_list) * [source_video_dir],
+                 len(self.videos_path_list) * [to_dir],
+             ),
+         )
+
+     def build_audios_features(
+         self,
+         audio_src_path,  # the dir that contains the source audio files
+         to_dir,  # the dir in which to save the extracted features
+         frame_rate=30,  # the frame rate per second of the video
+         sample_rate=16000,  # the sample rate for audio sampling
+         num_mels=80,  # the number of channels of the melspectrogram
+         fft_size=1280,  # fft_size / sample_rate is the window size
+         hop_size=320,  # hop_size / sample_rate is the step size
+         spectrogram_type="lws",  # "lws" or "librosa"; "lws" is recommended
+         # part: determines how many parts the files are split into and which part to run.
+         # e.g., "2/5" means splitting all files into 5 folds and executing the 2nd part.
+         # This is useful if you have several machines.
+         part="1/1",
+     ):
+         """Obtain the features from the audio files."""
+         audio_tools = build_audio_features.AudioTools(
+             frame_rate=frame_rate,
+             sample_rate=sample_rate,
+             num_mels=num_mels,
+             fft_size=fft_size,
+             hop_size=hop_size,
+             spectrogram_type=spectrogram_type,
+         )
+
+         audio_files = glob.glob(
+             audio_src_path + "/*" * self.dir_level + "." + self.audio_ext
+         )
+         files = sorted(audio_files)
+
+         if part is not None:
+             [this_part, num_parts] = [int(i) for i in part.split("/")]
+             part_len = len(files) // num_parts
+
+         extractor_pool = Pool(self.num_worker)
+         for file in files[
+             part_len * (this_part - 1) : (part_len * this_part)
+             if this_part != num_parts
+             else len(files)
+         ]:
+             out_full_path = obtain_audio_dest_dir(
+                 out_dir=to_dir, audio_path=file, dir_level=self.dir_level
+             )
+
+             # Create the output dir if it does not exist yet
+             if not os.path.exists(out_full_path):
+                 os.makedirs(out_full_path)
+
+             extractor_pool.apply_async(
+                 build_audio_features.extract_audio_feature,
+                 args=(file, audio_tools, out_full_path),
+             )
+         extractor_pool.close()
+         extractor_pool.join()
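
The part="i/n" convention in build_audios_features divides the sorted file list into n folds and processes only the i-th, with the final fold absorbing the remainder. A standalone sketch of that slicing arithmetic (the helper name select_part is illustrative, not part of the package):

files = [f"audio_{i:02d}.wav" for i in range(11)]

def select_part(files, part):
    """Select the i-th of n folds, mirroring the slice in build_audios_features."""
    this_part, num_parts = (int(i) for i in part.split("/"))
    part_len = len(files) // num_parts
    start = part_len * (this_part - 1)
    end = len(files) if this_part == num_parts else part_len * this_part
    return files[start:end]

print(select_part(files, "2/5"))  # ['audio_02.wav', 'audio_03.wav']
print(select_part(files, "5/5"))  # the last fold keeps the remainder: 3 files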
plato/datasources/datalib/data_utils.py
@@ -0,0 +1,124 @@
+ """
+ Useful tools for processing the data.
+
+ """
+
+ import shutil
+ import os
+ import json
+ import numpy as np
+
+
+ def config_to_dict(plato_config):
+     """Convert the (possibly nested) Plato config instance to a dict."""
+     # convert the whole config to a dict - an OrderedDict
+     plato_config_dict = plato_config._asdict()
+
+     def to_dict(elem):
+         for key, value in elem.items():
+             try:
+                 value = value._asdict()
+                 elem[key] = to_dict(value)
+             except AttributeError:
+                 pass  # not a namedtuple; leave the value as-is
+             if isinstance(value, list):
+                 for idx, value_item in enumerate(value):
+                     try:
+                         value_item = value_item._asdict()
+                         value[idx] = to_dict(value_item)
+                     except AttributeError:
+                         pass
+                 elem[key] = value
+         return elem
+
+     plato_config_dict = to_dict(plato_config_dict)
+
+     return plato_config_dict
+
+
+ def dict_list2tuple(dict_obj):
+     """Convert every list element in the dict to a tuple."""
+     for key, value in dict_obj.items():
+         if isinstance(value, dict):
+             for inner_key, inner_v in value.items():
+                 if isinstance(inner_v, list):
+                     # an empty or None list, mainly for meta_keys
+                     if not inner_v or inner_v[0] is None:
+                         dict_obj[key][inner_key] = ()
+                     else:
+                         dict_obj[key][inner_key] = tuple(inner_v)
+         else:
+             if isinstance(value, list):
+                 # an empty or None list, mainly for meta_keys
+                 if not value or value[0] is None:
+                     dict_obj[key] = ()
+                 else:
+                     dict_obj[key] = tuple(value)
+                     for idx, item in enumerate(value):
+                         item = value[idx]
+                         if isinstance(item, dict):
+                             value[idx] = dict_list2tuple(item)
+
+     return dict_obj
+
+
+ def phrase_boxes_alignment(flatten_boxes, ori_phrases_boxes):
+     """Align each phrase with its corresponding boxes."""
+     phrases_boxes = list()
+
+     ori_pb_boxes_count = list()
+     for ph_boxes in ori_phrases_boxes:
+         ori_pb_boxes_count.append(len(ph_boxes))
+
+     start_point = 0
+     for pb_boxes_num in ori_pb_boxes_count:
+         sub_boxes = list()
+         for i in range(start_point, start_point + pb_boxes_num):
+             sub_boxes.append(flatten_boxes[i])
+
+         start_point += pb_boxes_num
+         phrases_boxes.append(sub_boxes)
+
+     pb_boxes_count = list()
+     for ph_boxes in phrases_boxes:
+         pb_boxes_count.append(len(ph_boxes))
+
+     assert pb_boxes_count == ori_pb_boxes_count
+
+     return phrases_boxes
+
+
+ def list_inorder(listed_files, flag_str):
+     """List the files in order based on the file name."""
+     filtered_listed_files = [fn for fn in listed_files if flag_str in fn]
+     listed_files = sorted(filtered_listed_files, key=lambda x: x.strip().split(".")[0])
+     return listed_files
+
+
+ def copy_files(src_files, dst_dir):
+     """Copy files from the source to the destination directory."""
+     for file in src_files:
+         shutil.copy(file, dst_dir)
+
+
+ def union_shuffled_lists(src_lists):
+     """Shuffle multiple equal-length lists with the same permutation."""
+     for i in range(1, len(src_lists)):
+         assert len(src_lists[i]) == len(src_lists[i - 1])
+     processed = np.random.permutation(len(src_lists[0]))
+
+     return [np.array(ele)[processed] for ele in src_lists]
+
+
+ def read_anno_file(anno_file_path):
+     _, tail = os.path.split(anno_file_path)
+     file_type = tail.split(".")[-1]
+
+     if file_type == "json":
+         with open(anno_file_path, "r") as anno_file:
+             annos_list = json.load(anno_file)
+     else:
+         with open(anno_file_path, "r") as anno_file:
+             annos_list = anno_file.readlines()
+
+     return annos_list
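
union_shuffled_lists applies a single random permutation across all of the lists, so aligned entries (e.g., samples and their labels) stay paired. A quick demonstration, assuming the module is importable as plato.datasources.datalib.data_utils:

import numpy as np

from plato.datasources.datalib import data_utils

np.random.seed(42)  # for a reproducible permutation
images = ["img_a", "img_b", "img_c", "img_d"]
labels = [0, 1, 2, 3]

shuffled_images, shuffled_labels = data_utils.union_shuffled_lists([images, labels])
# The i-th image still carries its original label after shuffling.
print(list(zip(shuffled_images, shuffled_labels)))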