plato-learn 1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- plato/__init__.py +1 -0
- plato/algorithms/__init__.py +0 -0
- plato/algorithms/base.py +45 -0
- plato/algorithms/fedavg.py +48 -0
- plato/algorithms/fedavg_gan.py +79 -0
- plato/algorithms/fedavg_personalized.py +48 -0
- plato/algorithms/mistnet.py +52 -0
- plato/algorithms/registry.py +39 -0
- plato/algorithms/split_learning.py +89 -0
- plato/callbacks/__init__.py +0 -0
- plato/callbacks/client.py +56 -0
- plato/callbacks/handler.py +78 -0
- plato/callbacks/server.py +139 -0
- plato/callbacks/trainer.py +124 -0
- plato/client.py +67 -0
- plato/clients/__init__.py +0 -0
- plato/clients/base.py +467 -0
- plato/clients/edge.py +103 -0
- plato/clients/fedavg_personalized.py +40 -0
- plato/clients/mistnet.py +49 -0
- plato/clients/registry.py +43 -0
- plato/clients/self_supervised_learning.py +51 -0
- plato/clients/simple.py +218 -0
- plato/clients/split_learning.py +150 -0
- plato/config.py +339 -0
- plato/datasources/__init__.py +0 -0
- plato/datasources/base.py +123 -0
- plato/datasources/celeba.py +150 -0
- plato/datasources/cifar10.py +87 -0
- plato/datasources/cifar100.py +61 -0
- plato/datasources/cinic10.py +62 -0
- plato/datasources/coco.py +119 -0
- plato/datasources/datalib/__init__.py +0 -0
- plato/datasources/datalib/audio_extraction_tools.py +137 -0
- plato/datasources/datalib/data_utils.py +124 -0
- plato/datasources/datalib/flickr30kE_utils.py +336 -0
- plato/datasources/datalib/frames_extraction_tools.py +254 -0
- plato/datasources/datalib/gym_utils/__init__.py +0 -0
- plato/datasources/datalib/gym_utils/gym_trim.py +189 -0
- plato/datasources/datalib/modality_data_anntation_tools.py +163 -0
- plato/datasources/datalib/modality_extraction_base.py +59 -0
- plato/datasources/datalib/parse_datasets.py +212 -0
- plato/datasources/datalib/refer_utils/__init__.py +0 -0
- plato/datasources/datalib/refer_utils/referitgame_utils.py +237 -0
- plato/datasources/datalib/tiny_data_tools.py +81 -0
- plato/datasources/datalib/video_transform.py +79 -0
- plato/datasources/emnist.py +64 -0
- plato/datasources/fashion_mnist.py +41 -0
- plato/datasources/feature.py +24 -0
- plato/datasources/feature_dataset.py +15 -0
- plato/datasources/femnist.py +141 -0
- plato/datasources/flickr30k_entities.py +362 -0
- plato/datasources/gym.py +431 -0
- plato/datasources/huggingface.py +165 -0
- plato/datasources/kinetics.py +568 -0
- plato/datasources/mnist.py +44 -0
- plato/datasources/multimodal_base.py +328 -0
- plato/datasources/pascal_voc.py +56 -0
- plato/datasources/purchase.py +94 -0
- plato/datasources/qoenflx.py +127 -0
- plato/datasources/referitgame.py +330 -0
- plato/datasources/registry.py +119 -0
- plato/datasources/self_supervised_learning.py +98 -0
- plato/datasources/stl10.py +103 -0
- plato/datasources/texas.py +94 -0
- plato/datasources/tiny_imagenet.py +64 -0
- plato/datasources/yolov8.py +85 -0
- plato/models/__init__.py +0 -0
- plato/models/cnn_encoder.py +103 -0
- plato/models/dcgan.py +116 -0
- plato/models/general_multilayer.py +254 -0
- plato/models/huggingface.py +27 -0
- plato/models/lenet5.py +113 -0
- plato/models/multilayer.py +90 -0
- plato/models/multimodal/__init__.py +0 -0
- plato/models/multimodal/base_net.py +91 -0
- plato/models/multimodal/blending.py +142 -0
- plato/models/multimodal/fc_net.py +77 -0
- plato/models/multimodal/fusion_net.py +78 -0
- plato/models/multimodal/multimodal_module.py +152 -0
- plato/models/registry.py +99 -0
- plato/models/resnet.py +190 -0
- plato/models/torch_hub.py +19 -0
- plato/models/vgg.py +113 -0
- plato/models/vit.py +166 -0
- plato/models/yolov8.py +22 -0
- plato/processors/__init__.py +0 -0
- plato/processors/base.py +35 -0
- plato/processors/compress.py +46 -0
- plato/processors/decompress.py +48 -0
- plato/processors/feature.py +51 -0
- plato/processors/feature_additive_noise.py +48 -0
- plato/processors/feature_dequantize.py +34 -0
- plato/processors/feature_gaussian.py +17 -0
- plato/processors/feature_laplace.py +15 -0
- plato/processors/feature_quantize.py +34 -0
- plato/processors/feature_randomized_response.py +50 -0
- plato/processors/feature_unbatch.py +39 -0
- plato/processors/inbound_feature_tensors.py +39 -0
- plato/processors/model.py +55 -0
- plato/processors/model_compress.py +34 -0
- plato/processors/model_decompress.py +37 -0
- plato/processors/model_decrypt.py +41 -0
- plato/processors/model_deepcopy.py +21 -0
- plato/processors/model_dequantize.py +18 -0
- plato/processors/model_dequantize_qsgd.py +61 -0
- plato/processors/model_encrypt.py +43 -0
- plato/processors/model_quantize.py +18 -0
- plato/processors/model_quantize_qsgd.py +82 -0
- plato/processors/model_randomized_response.py +34 -0
- plato/processors/outbound_feature_ndarrays.py +38 -0
- plato/processors/pipeline.py +26 -0
- plato/processors/registry.py +124 -0
- plato/processors/structured_pruning.py +57 -0
- plato/processors/unstructured_pruning.py +73 -0
- plato/samplers/__init__.py +0 -0
- plato/samplers/all_inclusive.py +41 -0
- plato/samplers/base.py +31 -0
- plato/samplers/dirichlet.py +81 -0
- plato/samplers/distribution_noniid.py +132 -0
- plato/samplers/iid.py +53 -0
- plato/samplers/label_quantity_noniid.py +119 -0
- plato/samplers/mixed.py +44 -0
- plato/samplers/mixed_label_quantity_noniid.py +128 -0
- plato/samplers/modality_iid.py +42 -0
- plato/samplers/modality_quantity_noniid.py +56 -0
- plato/samplers/orthogonal.py +99 -0
- plato/samplers/registry.py +66 -0
- plato/samplers/sample_quantity_noniid.py +123 -0
- plato/samplers/sampler_utils.py +190 -0
- plato/servers/__init__.py +0 -0
- plato/servers/base.py +1395 -0
- plato/servers/fedavg.py +281 -0
- plato/servers/fedavg_cs.py +335 -0
- plato/servers/fedavg_gan.py +74 -0
- plato/servers/fedavg_he.py +106 -0
- plato/servers/fedavg_personalized.py +57 -0
- plato/servers/mistnet.py +67 -0
- plato/servers/registry.py +52 -0
- plato/servers/split_learning.py +109 -0
- plato/trainers/__init__.py +0 -0
- plato/trainers/base.py +99 -0
- plato/trainers/basic.py +649 -0
- plato/trainers/diff_privacy.py +178 -0
- plato/trainers/gan.py +330 -0
- plato/trainers/huggingface.py +173 -0
- plato/trainers/loss_criterion.py +70 -0
- plato/trainers/lr_schedulers.py +252 -0
- plato/trainers/optimizers.py +53 -0
- plato/trainers/pascal_voc.py +80 -0
- plato/trainers/registry.py +44 -0
- plato/trainers/self_supervised_learning.py +302 -0
- plato/trainers/split_learning.py +305 -0
- plato/trainers/tracking.py +96 -0
- plato/trainers/yolov8.py +41 -0
- plato/utils/__init__.py +0 -0
- plato/utils/count_parameters.py +30 -0
- plato/utils/csv_processor.py +26 -0
- plato/utils/data_loaders.py +148 -0
- plato/utils/decorators.py +24 -0
- plato/utils/fonts.py +23 -0
- plato/utils/homo_enc.py +187 -0
- plato/utils/reinforcement_learning/__init__.py +0 -0
- plato/utils/reinforcement_learning/policies/__init__.py +0 -0
- plato/utils/reinforcement_learning/policies/base.py +161 -0
- plato/utils/reinforcement_learning/policies/ddpg.py +75 -0
- plato/utils/reinforcement_learning/policies/registry.py +32 -0
- plato/utils/reinforcement_learning/policies/sac.py +343 -0
- plato/utils/reinforcement_learning/policies/td3.py +485 -0
- plato/utils/reinforcement_learning/rl_agent.py +142 -0
- plato/utils/reinforcement_learning/rl_server.py +113 -0
- plato/utils/rl_env.py +154 -0
- plato/utils/s3.py +141 -0
- plato/utils/trainer_utils.py +21 -0
- plato/utils/unary_encoding.py +47 -0
- plato_learn-1.1.dist-info/METADATA +35 -0
- plato_learn-1.1.dist-info/RECORD +179 -0
- plato_learn-1.1.dist-info/WHEEL +4 -0
- plato_learn-1.1.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,362 @@
|
|
1
|
+
"""
|
2
|
+
The Flickr30K Entities dataset.
|
3
|
+
|
4
|
+
The data structure and setting follow:
|
5
|
+
"http://bryanplummer.com/Flickr30kEntities/".
|
6
|
+
|
7
|
+
We utilize the official splits that contain:
|
8
|
+
- train: 29783 images,
|
9
|
+
- val: 1000 images,
|
10
|
+
- test: 1000 images
|
11
|
+
|
12
|
+
The file structure of this dataset is:
|
13
|
+
- Images (jpg): the raw images
|
14
|
+
- Annotations (xml): the bounding boxes
|
15
|
+
- Sentence (txt): captions of the image
|
16
|
+
|
17
|
+
The data structure under the 'data/' is:
|
18
|
+
├── Flickr30KEntities # root dir of Flickr30K Entities dataset
|
19
|
+
│ ├── Flickr30KEntitiesRaw # Raw images/annotations and the official splits
|
20
|
+
│ ├── train # data dir for the train phase
|
21
|
+
│ │ └── train_Annotations
|
22
|
+
│ │ └── train_Images
|
23
|
+
│ │ └── train_Sentences
|
24
|
+
│ └── test
|
25
|
+
│ └── val
|
26
|
+
|
27
|
+
|
28
|
+
Detailed loaded sample structure:
|
29
|
+
|
30
|
+
One sample is presented as the dict type:
|
31
|
+
- rgb: the image data.
|
32
|
+
- text:
|
33
|
+
- caption : a nested list, such as
|
34
|
+
[['The woman is applying mascara while looking in the mirror.']],
|
35
|
+
- caption_phrases: a nested list, each item is a list containing
|
36
|
+
the phrases of the caption, such as:
|
37
|
+
[['Military personnel'], ['greenish gray uniforms'], ['matching hats']]
|
38
|
+
- box:
|
39
|
+
- caption_phrase_bboxs: a 2-depth nested list, each item is a list that
|
40
|
+
contains boxes of the corresponding phrase, such as:
|
41
|
+
[[[295, 130, 366, 244], [209, 123, 300, 246], [347, 1, 439, 236]],
|
42
|
+
[[0, 21, 377, 220]], [[0, 209, 214, 332]]]
|
43
|
+
- target:
|
44
|
+
- caption_phrases_cate: a nested list, each item is a string that
|
45
|
+
presents the categories of the phrase, such as:
|
46
|
+
[['people'], ['bodyparts'], ['other']].
|
47
|
+
|
48
|
+
- caption_phrases_cate_id: a list, each item is the id (an int or
|
49
|
+
a str) of the phrase, such as:
|
50
|
+
['121973', '121976', '121975']
|
51
|
+
|
52
|
+
One batch of samples is presented as a list,
|
53
|
+
For example, the corresponding caption_phrase_bboxs in one batch is:
|
54
|
+
[
|
55
|
+
[[[295, 130, 366, 244], [209, 123, 300, 246], [347, 1, 439, 236]], [[0, 21, 377, 220]],
|
56
|
+
[[0, 209, 214, 332]]], - sample 1
|
57
|
+
[[[90, 68, 325, 374]], [[118, 64, 192, 128]]], - sample 2
|
58
|
+
[[[1, 0, 148, 451]], [[153, 148, 400, 413]], [[374, 320, 450, 440]]], - sample 3
|
59
|
+
]
|
60
|
+
"""
|
61
|
+
|
62
|
+
import json
|
63
|
+
import logging
|
64
|
+
import os
|
65
|
+
|
66
|
+
import torch
|
67
|
+
import skimage.io as io
|
68
|
+
import cv2
|
69
|
+
|
70
|
+
from plato.config import Config
|
71
|
+
from plato.datasources import multimodal_base
|
72
|
+
from plato.datasources.multimodal_base import TextData, BoxData, TargetData
|
73
|
+
from plato.datasources.datalib import data_utils
|
74
|
+
from plato.datasources.datalib import flickr30kE_utils
|
75
|
+
|
76
|
+
|
77
|
+
def collate_fn(batch):
    """Collate a loaded batch without any stacking or padding.

    The samples are handed to the caller exactly as the data loader
    gathered them; no tensors are stacked and no fields are merged.

    Args:
        batch (list): the samples gathered by the data loader, one element
            per loaded item.

    Returns:
        list: the very same ``batch`` object that was passed in.
    """
    collated = batch  # pass-through: no transformation is applied
    return collated
|
89
|
+
|
90
|
+
|
91
|
+
class Flickr30KEDataset(multimodal_base.MultiModalDataset):
    """Prepare the Flickr30K Entities dataset.

    Each sample is built from one entry of ``dataset_info`` (keyed by the
    sample name) plus the image file with the same base name, read from the
    image directory of the current phase.
    """

    def __init__(
        self,
        dataset_info,
        phase,
        phase_info,
        data_types,
        modality_sampler=None,
        transform_image_dec_func=None,
        transform_text_func=None,
    ):
        """Initialize the dataset of one phase.

        Args:
            dataset_info (dict): maps each sample name to its annotation
                record; a record must provide the "sentence*" keys read by
                ``extract_sample_anno_data``.
            phase (str): the phase name, such as "train" or "test".
            phase_info (dict): per-data-type information of this phase; the
                entry for ``data_types[0]`` must hold the image directory
                ("path") and file extension ("format").
            data_types (list): data type names; the first one is the image type.
            modality_sampler (list, optional): the modalities to use; defaults
                to all supported modalities.
            transform_image_dec_func (callable, optional): joint image/box
                transform, called as ``f(image=..., bboxes=..., category_ids=...)``
                and expected to return a dict with "image" and "bboxes".
            transform_text_func (callable, optional): transform applied to the
                caption phrases.
        """
        super().__init__()

        self.phase = phase
        self.phase_multimodal_data_record = dataset_info
        self.phase_info = phase_info
        self.data_types = data_types
        self.transform_image_dec_func = transform_image_dec_func
        self.transform_text_func = transform_text_func

        self.phase_samples_name = list(self.phase_multimodal_data_record.keys())

        self.supported_modalities = ["rgb", "text"]

        # default utilizing the full modalities
        if modality_sampler is None:
            self.modality_sampler = self.supported_modalities
        else:
            self.modality_sampler = modality_sampler

    def __len__(self):
        """Return the number of samples in this phase."""
        return len(self.phase_multimodal_data_record)

    def get_sample_image_data(self, image_id):
        """Load one image of the sample as an (H, W, 3) RGB array.

        Args:
            image_id: the image file name without its extension.

        Returns:
            numpy.ndarray: the decoded image in RGB channel order.
        """
        image_type = self.data_types[0]
        image_phase_path = self.phase_info[image_type]["path"]
        image_phase_format = self.phase_info[image_type]["format"]
        image_path = os.path.join(image_phase_path, str(image_id) + image_phase_format)

        image_data = io.imread(image_path)

        # skimage.io.imread already returns RGB(A) data, so converting it with
        # COLOR_BGR2RGB (as done before) swapped the red and blue channels.
        # Only normalise the channel count so callers always get 3-channel RGB.
        if image_data.ndim == 2:
            image_data = cv2.cvtColor(image_data, cv2.COLOR_GRAY2RGB)
        elif image_data.ndim == 3 and image_data.shape[2] == 4:
            image_data = cv2.cvtColor(image_data, cv2.COLOR_RGBA2RGB)

        return image_data

    def extract_sample_anno_data(self, image_anno_sent):
        """Extract the annotation fields of one sample record.

        Returns:
            tuple: (sentence, sentence_phrases, sentence_phrases_type,
            sentence_phrases_id, sentence_phrases_boxes).
        """
        sentence = image_anno_sent["sentence"]  # a string
        sentence_phrases = image_anno_sent["sentence_phrases"]  # a list
        sentence_phrases_type = image_anno_sent[
            "sentence_phrases_type"
        ]  # a nested list
        sentence_phrases_id = image_anno_sent["sentence_phrases_id"]  # a list
        sentence_phrases_boxes = image_anno_sent[
            "sentence_phrases_boxes"
        ]  # a nested list

        return (
            sentence,
            sentence_phrases,
            sentence_phrases_type,
            sentence_phrases_id,
            sentence_phrases_boxes,
        )

    def get_one_multimodal_sample(self, sample_idx):
        """Obtain one sample from the Flickr30K Entities dataset.

        Args:
            sample_idx (int): index into the sample names of this phase.

        Returns:
            dict: with keys "rgb" (image), "text" (TextData), "box" (BoxData)
            and "target" (TargetData).
        """
        sample_retrieval_name = self.phase_samples_name[sample_idx]
        image_file_name = os.path.basename(sample_retrieval_name)
        image_id = os.path.splitext(image_file_name)[0]

        image_data = self.get_sample_image_data(image_id)

        image_anno_sent = self.phase_multimodal_data_record[sample_retrieval_name]

        (
            sentence,
            sentence_phrases,
            sentence_phrases_type,
            sentence_phrases_id,
            sentence_phrases_boxes,
        ) = self.extract_sample_anno_data(image_anno_sent)

        # Wrap a plain caption into the nested [[caption]] layout; a caption
        # that already contains nested lists is kept unchanged.
        caption = (
            sentence
            if any(isinstance(iter_i, list) for iter_i in sentence)
            else [[sentence]]
        )
        # Flatten the per-phrase box lists into one list for the image transform.
        flatten_caption_phrase_bboxs = [
            box for boxes in sentence_phrases_boxes for box in boxes
        ]
        # ['The woman', 'mascara', 'the mirror'] -> [['The woman'], ['mascara'], ...]
        caption_phrases = [[phrase] for phrase in sentence_phrases]
        caption_phrases_cate = sentence_phrases_type
        caption_phrases_cate_id = sentence_phrases_id

        if self.transform_image_dec_func is not None:
            transformed = self.transform_image_dec_func(
                image=image_data,
                bboxes=flatten_caption_phrase_bboxs,
                category_ids=range(len(flatten_caption_phrase_bboxs)),
            )

            image_data = torch.from_numpy(transformed["image"])
            flatten_caption_phrase_bboxs = transformed["bboxes"]
            # presumably regroups the transformed flat boxes per phrase,
            # following the layout of sentence_phrases_boxes — verify helper
            caption_phrase_bboxs = flickr30kE_utils.phrase_boxes_alignment(
                flatten_caption_phrase_bboxs, sentence_phrases_boxes
            )
        else:
            caption_phrase_bboxs = sentence_phrases_boxes

        if self.transform_text_func is not None:
            caption_phrases = self.transform_text_func(caption_phrases)

        text_data = TextData(caption=caption, caption_phrases=caption_phrases)
        box_data = BoxData(caption_phrase_bboxs=caption_phrase_bboxs)
        target_data = TargetData(
            caption_phrases_cate=caption_phrases_cate,
            caption_phrases_cate_id=caption_phrases_cate_id,
        )

        return {
            "rgb": image_data,
            "text": text_data,
            "box": box_data,
            "target": target_data,
        }
|
223
|
+
|
224
|
+
|
225
|
+
class DataSource(multimodal_base.MultiModalDataSource):
    """The Flickr30K Entities dataset."""

    def __init__(self, **kwargs):
        """Download/arrange the raw data and build the per-split directories.

        ``kwargs`` are accepted for interface compatibility but are unused.
        """
        super().__init__()

        self.data_name = Config().data.dataname

        self.modality_names = ["image", "text"]

        _path = Config().params["data_path"]
        self._data_path_process(data_path=_path, base_data_name=self.data_name)

        raw_data_name = self.data_name + "Raw"
        base_data_path = self.mm_data_info["data_path"]

        download_url = Config().data.download_url

        self._download_arrange_data(
            download_url_address=download_url,
            data_path=base_data_path,
            extract_to_dir=base_data_path,
        )

        # The three lists below are index-aligned: raw directory name, file
        # extension, and the data type name used throughout this module.
        # The annotation is .xml, the sentence is in .txt.
        self.raw_data_types = ["Flickr30k_images", "Annotations", "Sentences"]
        self.raw_data_file_format = [".jpg", ".xml", ".txt"]
        self.data_types = ["Images", "Annotations", "Sentences"]

        # extract the data information and structure
        for raw_type, raw_file_format, data_type in zip(
            self.raw_data_types, self.raw_data_file_format, self.data_types
        ):
            raw_type_path = os.path.join(base_data_path, raw_data_name, raw_type)
            self.mm_data_info[data_type] = {
                "path": raw_type_path,
                "format": raw_file_format,
                "num_samples": len(os.listdir(raw_type_path)),
            }

        # generate path/type information for splits
        for split_type in list(self.splits_info.keys()):
            split_info = self.splits_info[split_type]
            split_info["split_file"] = os.path.join(
                base_data_path, raw_data_name, split_type + ".txt"
            )
            split_path = split_info["path"]
            for dt_type, dt_type_format in zip(
                self.data_types, self.raw_data_file_format
            ):
                split_info[dt_type] = {
                    "path": os.path.join(split_path, f"{split_type}_{dt_type}"),
                    "format": dt_type_format,
                }

        # distribute data to splits
        self.create_splits_data()

        # generate the splits information for further utilization
        flickr30kE_utils.integrate_data_to_json(
            splits_info=self.splits_info,
            mm_data_info=self.mm_data_info,
            data_types=self.data_types,
            split_wise=True,
            globally=True,
        )

    @staticmethod
    def _read_split_sample_ids(split_info_file):
        """Read the sample ids listed one per line in a split file.

        Surrounding whitespace (including a "\\r" from CRLF files) is
        stripped and blank lines are skipped, so no empty id is produced.
        """
        with open(split_info_file, "r", encoding="utf-8") as loaded_file:
            return [line.strip() for line in loaded_file if line.strip()]

    def create_splits_data(self):
        """Copy the raw files of every split into the split's directories.

        A split/data-type directory that already exists is skipped (its
        content is assumed to be complete from a previous run).
        """
        # saving the images and entities to the corresponding directory
        for split_type in list(self.splits_info.keys()):
            logging.info("Creating split %s data..........", split_type)
            # 0. getting the sample ids of this split
            split_info_file = self.splits_info[split_type]["split_file"]
            split_data_samples = self._read_split_sample_ids(split_info_file)
            self.splits_info[split_type]["num_samples"] = len(split_data_samples)

            # 1. create directory for the split data if necessary
            for dt_type in self.data_types:
                split_dt_type_path = self.splits_info[split_type][dt_type]["path"]

                if not self._exists(split_dt_type_path):
                    os.makedirs(split_dt_type_path, exist_ok=True)
                else:
                    logging.info("The path %s exists.", split_dt_type_path)
                    continue

                raw_data_type_path = self.mm_data_info[dt_type]["path"]
                raw_data_format = self.mm_data_info[dt_type]["format"]
                split_samples_path = [
                    os.path.join(raw_data_type_path, sample_id + raw_data_format)
                    for sample_id in split_data_samples
                ]
                # 2. saving the split data into the target directory
                data_utils.copy_files(split_samples_path, split_dt_type_path)

        logging.info("Done.")

    def get_phase_data_info(self, phase):
        """Load the integrated JSON data information of the required phase."""
        path = self.splits_info[phase]["path"]
        save_path = os.path.join(path, phase + "_integrated_data.json")
        with open(save_path, "r", encoding="utf-8") as infile:
            phase_data_info = json.load(infile)
        return phase_data_info

    def get_phase_dataset(self, phase, modality_sampler):
        """Obtain the dataset for the specific phase."""
        phase_data_info = self.get_phase_data_info(phase)
        phase_split_info = self.splits_info[phase]
        dataset = Flickr30KEDataset(
            dataset_info=phase_data_info,
            phase_info=phase_split_info,
            data_types=self.data_types,
            phase=phase,
            modality_sampler=modality_sampler,
        )
        return dataset

    def get_train_set(self, modality_sampler=None):
        """Obtains the training dataset."""
        phase = "train"

        self.trainset = self.get_phase_dataset(phase, modality_sampler)
        return self.trainset

    def get_test_set(self, modality_sampler=None):
        """Obtains the test dataset."""
        phase = "test"

        self.testset = self.get_phase_dataset(phase, modality_sampler)
        return self.testset
|