plato-learn 1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- plato/__init__.py +1 -0
- plato/algorithms/__init__.py +0 -0
- plato/algorithms/base.py +45 -0
- plato/algorithms/fedavg.py +48 -0
- plato/algorithms/fedavg_gan.py +79 -0
- plato/algorithms/fedavg_personalized.py +48 -0
- plato/algorithms/mistnet.py +52 -0
- plato/algorithms/registry.py +39 -0
- plato/algorithms/split_learning.py +89 -0
- plato/callbacks/__init__.py +0 -0
- plato/callbacks/client.py +56 -0
- plato/callbacks/handler.py +78 -0
- plato/callbacks/server.py +139 -0
- plato/callbacks/trainer.py +124 -0
- plato/client.py +67 -0
- plato/clients/__init__.py +0 -0
- plato/clients/base.py +467 -0
- plato/clients/edge.py +103 -0
- plato/clients/fedavg_personalized.py +40 -0
- plato/clients/mistnet.py +49 -0
- plato/clients/registry.py +43 -0
- plato/clients/self_supervised_learning.py +51 -0
- plato/clients/simple.py +218 -0
- plato/clients/split_learning.py +150 -0
- plato/config.py +339 -0
- plato/datasources/__init__.py +0 -0
- plato/datasources/base.py +123 -0
- plato/datasources/celeba.py +150 -0
- plato/datasources/cifar10.py +87 -0
- plato/datasources/cifar100.py +61 -0
- plato/datasources/cinic10.py +62 -0
- plato/datasources/coco.py +119 -0
- plato/datasources/datalib/__init__.py +0 -0
- plato/datasources/datalib/audio_extraction_tools.py +137 -0
- plato/datasources/datalib/data_utils.py +124 -0
- plato/datasources/datalib/flickr30kE_utils.py +336 -0
- plato/datasources/datalib/frames_extraction_tools.py +254 -0
- plato/datasources/datalib/gym_utils/__init__.py +0 -0
- plato/datasources/datalib/gym_utils/gym_trim.py +189 -0
- plato/datasources/datalib/modality_data_anntation_tools.py +163 -0
- plato/datasources/datalib/modality_extraction_base.py +59 -0
- plato/datasources/datalib/parse_datasets.py +212 -0
- plato/datasources/datalib/refer_utils/__init__.py +0 -0
- plato/datasources/datalib/refer_utils/referitgame_utils.py +237 -0
- plato/datasources/datalib/tiny_data_tools.py +81 -0
- plato/datasources/datalib/video_transform.py +79 -0
- plato/datasources/emnist.py +64 -0
- plato/datasources/fashion_mnist.py +41 -0
- plato/datasources/feature.py +24 -0
- plato/datasources/feature_dataset.py +15 -0
- plato/datasources/femnist.py +141 -0
- plato/datasources/flickr30k_entities.py +362 -0
- plato/datasources/gym.py +431 -0
- plato/datasources/huggingface.py +165 -0
- plato/datasources/kinetics.py +568 -0
- plato/datasources/mnist.py +44 -0
- plato/datasources/multimodal_base.py +328 -0
- plato/datasources/pascal_voc.py +56 -0
- plato/datasources/purchase.py +94 -0
- plato/datasources/qoenflx.py +127 -0
- plato/datasources/referitgame.py +330 -0
- plato/datasources/registry.py +119 -0
- plato/datasources/self_supervised_learning.py +98 -0
- plato/datasources/stl10.py +103 -0
- plato/datasources/texas.py +94 -0
- plato/datasources/tiny_imagenet.py +64 -0
- plato/datasources/yolov8.py +85 -0
- plato/models/__init__.py +0 -0
- plato/models/cnn_encoder.py +103 -0
- plato/models/dcgan.py +116 -0
- plato/models/general_multilayer.py +254 -0
- plato/models/huggingface.py +27 -0
- plato/models/lenet5.py +113 -0
- plato/models/multilayer.py +90 -0
- plato/models/multimodal/__init__.py +0 -0
- plato/models/multimodal/base_net.py +91 -0
- plato/models/multimodal/blending.py +142 -0
- plato/models/multimodal/fc_net.py +77 -0
- plato/models/multimodal/fusion_net.py +78 -0
- plato/models/multimodal/multimodal_module.py +152 -0
- plato/models/registry.py +99 -0
- plato/models/resnet.py +190 -0
- plato/models/torch_hub.py +19 -0
- plato/models/vgg.py +113 -0
- plato/models/vit.py +166 -0
- plato/models/yolov8.py +22 -0
- plato/processors/__init__.py +0 -0
- plato/processors/base.py +35 -0
- plato/processors/compress.py +46 -0
- plato/processors/decompress.py +48 -0
- plato/processors/feature.py +51 -0
- plato/processors/feature_additive_noise.py +48 -0
- plato/processors/feature_dequantize.py +34 -0
- plato/processors/feature_gaussian.py +17 -0
- plato/processors/feature_laplace.py +15 -0
- plato/processors/feature_quantize.py +34 -0
- plato/processors/feature_randomized_response.py +50 -0
- plato/processors/feature_unbatch.py +39 -0
- plato/processors/inbound_feature_tensors.py +39 -0
- plato/processors/model.py +55 -0
- plato/processors/model_compress.py +34 -0
- plato/processors/model_decompress.py +37 -0
- plato/processors/model_decrypt.py +41 -0
- plato/processors/model_deepcopy.py +21 -0
- plato/processors/model_dequantize.py +18 -0
- plato/processors/model_dequantize_qsgd.py +61 -0
- plato/processors/model_encrypt.py +43 -0
- plato/processors/model_quantize.py +18 -0
- plato/processors/model_quantize_qsgd.py +82 -0
- plato/processors/model_randomized_response.py +34 -0
- plato/processors/outbound_feature_ndarrays.py +38 -0
- plato/processors/pipeline.py +26 -0
- plato/processors/registry.py +124 -0
- plato/processors/structured_pruning.py +57 -0
- plato/processors/unstructured_pruning.py +73 -0
- plato/samplers/__init__.py +0 -0
- plato/samplers/all_inclusive.py +41 -0
- plato/samplers/base.py +31 -0
- plato/samplers/dirichlet.py +81 -0
- plato/samplers/distribution_noniid.py +132 -0
- plato/samplers/iid.py +53 -0
- plato/samplers/label_quantity_noniid.py +119 -0
- plato/samplers/mixed.py +44 -0
- plato/samplers/mixed_label_quantity_noniid.py +128 -0
- plato/samplers/modality_iid.py +42 -0
- plato/samplers/modality_quantity_noniid.py +56 -0
- plato/samplers/orthogonal.py +99 -0
- plato/samplers/registry.py +66 -0
- plato/samplers/sample_quantity_noniid.py +123 -0
- plato/samplers/sampler_utils.py +190 -0
- plato/servers/__init__.py +0 -0
- plato/servers/base.py +1395 -0
- plato/servers/fedavg.py +281 -0
- plato/servers/fedavg_cs.py +335 -0
- plato/servers/fedavg_gan.py +74 -0
- plato/servers/fedavg_he.py +106 -0
- plato/servers/fedavg_personalized.py +57 -0
- plato/servers/mistnet.py +67 -0
- plato/servers/registry.py +52 -0
- plato/servers/split_learning.py +109 -0
- plato/trainers/__init__.py +0 -0
- plato/trainers/base.py +99 -0
- plato/trainers/basic.py +649 -0
- plato/trainers/diff_privacy.py +178 -0
- plato/trainers/gan.py +330 -0
- plato/trainers/huggingface.py +173 -0
- plato/trainers/loss_criterion.py +70 -0
- plato/trainers/lr_schedulers.py +252 -0
- plato/trainers/optimizers.py +53 -0
- plato/trainers/pascal_voc.py +80 -0
- plato/trainers/registry.py +44 -0
- plato/trainers/self_supervised_learning.py +302 -0
- plato/trainers/split_learning.py +305 -0
- plato/trainers/tracking.py +96 -0
- plato/trainers/yolov8.py +41 -0
- plato/utils/__init__.py +0 -0
- plato/utils/count_parameters.py +30 -0
- plato/utils/csv_processor.py +26 -0
- plato/utils/data_loaders.py +148 -0
- plato/utils/decorators.py +24 -0
- plato/utils/fonts.py +23 -0
- plato/utils/homo_enc.py +187 -0
- plato/utils/reinforcement_learning/__init__.py +0 -0
- plato/utils/reinforcement_learning/policies/__init__.py +0 -0
- plato/utils/reinforcement_learning/policies/base.py +161 -0
- plato/utils/reinforcement_learning/policies/ddpg.py +75 -0
- plato/utils/reinforcement_learning/policies/registry.py +32 -0
- plato/utils/reinforcement_learning/policies/sac.py +343 -0
- plato/utils/reinforcement_learning/policies/td3.py +485 -0
- plato/utils/reinforcement_learning/rl_agent.py +142 -0
- plato/utils/reinforcement_learning/rl_server.py +113 -0
- plato/utils/rl_env.py +154 -0
- plato/utils/s3.py +141 -0
- plato/utils/trainer_utils.py +21 -0
- plato/utils/unary_encoding.py +47 -0
- plato_learn-1.1.dist-info/METADATA +35 -0
- plato_learn-1.1.dist-info/RECORD +179 -0
- plato_learn-1.1.dist-info/WHEEL +4 -0
- plato_learn-1.1.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,328 @@
|
|
1
|
+
"""
|
2
|
+
Base class for multimodal datasets.
|
3
|
+
"""
|
4
|
+
|
5
|
+
from abc import abstractmethod
|
6
|
+
import logging
|
7
|
+
import os
|
8
|
+
import subprocess
|
9
|
+
from collections import namedtuple
|
10
|
+
|
11
|
+
import torch
|
12
|
+
from torchvision.datasets.utils import download_url, extract_archive
|
13
|
+
from torchvision.datasets.utils import download_file_from_google_drive
|
14
|
+
|
15
|
+
from plato.datasources import base
|
16
|
+
|
17
|
+
# Lightweight record types for the per-sample fields of multimodal datasets.
# The field names are passed as whitespace-separated strings, which
# ``collections.namedtuple`` treats the same as a list of names.
TextData = namedtuple("TextData", "caption caption_phrases")
BoxData = namedtuple("BoxData", "caption_phrase_bboxs")
TargetData = namedtuple("TargetData", "caption_phrases_cate caption_phrases_cate_id")
|
22
|
+
|
23
|
+
|
24
|
+
class MultiModalDataSource(base.DataSource):
    """
    The training or testing dataset that accommodates custom augmentation and transforms.

    Subclasses describe where the raw and processed data live
    (``mm_data_info``), how the data is split (``splits_info``) and which
    modalities it contains (``modality_names``), and implement
    ``get_phase_dataset``, ``get_train_set`` and ``get_test_set``.
    """

    def __init__(self):
        super().__init__()

        # The name of the dataset.
        self.data_name = ""

        # The text names of the contained modalities.
        self.modality_names = []

        # The information container for the source data:
        # - source_data_path: the original downloaded data
        # - data_path: the source data used for the model
        # For some datasets, data_path is used directly, as there is no need
        # to process the original downloaded data into the data_path dir.
        self.mm_data_info = {"source_data_path": "", "data_path": ""}

        # The paths for the split root data - train, test, and val.
        self.splits_info = {
            "train": {"path": "", "split_anno_file": ""},
            "test": {"path": "", "split_anno_file": ""},
            "val": {"path": "", "split_anno_file": ""},
        }

    def set_modality_format(self, modality_name):
        """Return the format (directory) name used for a modality.

        "rgb" and "flow" are both stored as "rawframes"; any other modality
        name is pluralized by appending "s". Calling this everywhere keeps the
        naming consistent across the class.
        """
        if modality_name in ["rgb", "flow"]:
            return "rawframes"
        # convert to plurality
        return modality_name + "s"

    def set_modality_path_key_format(self, modality_name):
        """Return the ``splits_info`` key under which a modality's path is
        stored, i.e. "<modality format>_path"."""
        modality_format = self.set_modality_format(modality_name)
        return modality_format + "_" + "path"

    def _create_modalities_path(self, modality_names=None):
        """Create ``<split path>/<modality format>`` for every split and
        modality, and record it in ``splits_info`` under the modality's key.

        Args:
            modality_names: the modalities to create directories for. When
                None, ``self.modality_names`` is used and must be non-empty.
        """
        if modality_names is None:
            assert len(self.modality_names) != 0
            modality_names = self.modality_names

        for split_info in self.splits_info.values():
            split_path = split_info["path"]
            for modality_nm in modality_names:
                modality_format = self.set_modality_format(modality_nm)
                split_modality_path = os.path.join(split_path, modality_format)
                # modality data dir
                modality_path_key = self.set_modality_path_key_format(modality_nm)
                split_info[modality_path_key] = split_modality_path
                # exist_ok tolerates a directory that already exists or is
                # created concurrently by another process.
                os.makedirs(split_modality_path, exist_ok=True)

    def _data_path_process(
        self,
        data_path,
        base_data_name=None,  # the base directory for the data
    ):
        """Generate the data structure based on the defined data path.

        The working directory is ``data_path/base_data_name``, or ``data_path``
        itself when ``base_data_name`` is None. It is stored in
        ``mm_data_info["data_path"]``; ``<working dir>/<split>`` is created for
        every split and stored in ``splits_info[<split>]["path"]``.
        """
        # os.path.join(data_path, None) would raise a TypeError, so the
        # default must be handled explicitly.
        if base_data_name is None:
            base_data_path = data_path
        else:
            base_data_path = os.path.join(data_path, base_data_name)

        os.makedirs(base_data_path, exist_ok=True)
        self.mm_data_info["data_path"] = base_data_path

        # create the split dirs for current dataset
        for split_type, split_info in self.splits_info.items():
            split_path = os.path.join(base_data_path, split_type)
            split_info["path"] = split_path
            os.makedirs(split_path, exist_ok=True)

    def _download_arrange_data(
        self,
        download_url_address,
        data_path,
        extract_to_dir=None,
        obtained_file_name=None,
    ):
        """Download the raw data and arrange the data.

        Args:
            download_url_address: the URL to download from.
            data_path: the directory that receives the downloaded file.
            extract_to_dir: where ".zip"/".tar.gz" archives are extracted;
                defaults to ``data_path``.
            obtained_file_name: the file name to save the download under;
                defaults to the last component of the URL.

        Returns:
            The downloaded file name up to its first ".", which is also the
            directory checked to decide whether extraction already happened.
        """
        # Extract to the same dir as the download dir
        if extract_to_dir is None:
            extract_to_dir = data_path

        # download_url saves the file as obtained_file_name when one is given,
        # so all later checks must use the name the file is actually saved as.
        download_file_name = obtained_file_name or os.path.basename(
            download_url_address
        )
        download_file_path = os.path.join(data_path, download_file_name)

        download_extracted_file_name = download_file_name.split(".")[0]
        download_extracted_path = os.path.join(
            extract_to_dir, download_extracted_file_name
        )
        # Download the raw data if necessary
        if not self._exists(download_file_path):
            logging.info("Downloading the %s data.....", download_file_name)
            download_url(
                url=download_url_address, root=data_path, filename=obtained_file_name
            )

        # Extract the data to the specific dir
        if ".zip" in download_file_name or ".tar.gz" in download_file_name:
            if not self._exists(download_extracted_path):
                logging.info("Extracting data to %s dir.....", extract_to_dir)
                extract_archive(
                    from_path=download_file_path,
                    to_path=extract_to_dir,
                    remove_finished=False,
                )

        return download_extracted_file_name

    def _download_google_driver_arrange_data(
        self,
        download_file_id,
        extract_download_file_name,
        data_path,
    ):
        """Download a zip archive from Google Drive and extract it.

        The archive is saved as ``data_path/<extract_download_file_name>.zip``
        and extracted into ``data_path``, with the result expected at
        ``data_path/<extract_download_file_name>``. The archive is deleted
        after extraction.
        """
        download_data_file_name = extract_download_file_name + ".zip"
        download_data_path = os.path.join(data_path, download_data_file_name)
        extract_data_path = os.path.join(data_path, extract_download_file_name)

        # The archive is removed after extraction, so a missing archive does
        # not mean missing data: skip everything once the data is extracted.
        if self._exists(extract_data_path):
            return

        if not self._exists(download_data_path):
            logging.info("Downloading the data to %s", download_data_path)
            download_file_from_google_drive(
                file_id=download_file_id,
                root=data_path,
                filename=download_data_file_name,
            )
        extract_archive(
            from_path=download_data_path, to_path=data_path, remove_finished=True
        )

    def _file_exists(self, tg_file_name, search_path, is_partial_name=True):
        """Judge whether the input file exists in the search_path.

        Only the direct entries of ``search_path`` are checked. With
        ``is_partial_name`` (the default) an entry matches when its name
        contains ``tg_file_name``; otherwise the names must be equal.
        """
        entries = os.listdir(search_path)
        if is_partial_name:
            return any(tg_file_name in f_name for f_name in entries)
        return tg_file_name in entries

    def _exists(self, target_path):
        """Does the input path/file exist and does the file contain useful data?

        A file counts as existing as soon as it is present. A directory counts
        only if it has a direct entry whose name does not start with "." and,
        somewhere beneath it, a file whose name does not start with ".".
        ".DS_Store" files under an inspected directory are deleted.
        """
        if not os.path.exists(target_path):
            logging.info("The path %s does not exist.", target_path)
            return False

        def remove_ds_store_files(folder):
            # Best-effort cleanup limited to the inspected directory (instead
            # of shelling out to `find` over the whole working directory).
            for root, __, files in os.walk(folder):
                for file in files:
                    if file == ".DS_Store":
                        file_path = os.path.join(root, file)
                        try:
                            os.remove(file_path)
                        except OSError as error:
                            logging.warning(
                                "Cannot remove %s: %s", file_path, error
                            )

        def get_size(folder):
            # Total size of the direct, non-hidden entries of folder.
            size = 0
            for ele in os.scandir(folder):
                if not ele.name.startswith("."):
                    size += os.path.getsize(ele)
            return size

        def is_contain_useful_file(target_dir):
            """Return True once reaching one useful (non-hidden) file."""
            for __, __, files in os.walk(target_dir):
                for file in files:
                    if not file.startswith("."):
                        return True
            return False

        if os.path.isdir(target_path):
            remove_ds_store_files(target_path)
            if get_size(target_path) == 0 or not is_contain_useful_file(target_path):
                logging.info("The path %s exists but contains no data.", target_path)
                return False

            return True

        logging.info("The file %s exists.", target_path)
        return True

    def num_modalities(self) -> int:
        """The number of modalities."""
        return len(self.modality_names)

    @abstractmethod
    def get_phase_dataset(self, phase, modality_sampler):
        """Obtain the dataset with the modality_sampler for the
        specific phase (train/test/val)"""
        raise NotImplementedError("Please implement the 'get_phase_dataset' method.")

    @abstractmethod
    def get_train_set(self, modality_sampler):
        """Obtain the train dataset with the modality_sampler"""
        raise NotImplementedError("Please implement the 'get_train_set' method.")

    @abstractmethod
    def get_test_set(self, modality_sampler):
        """Obtain the test dataset with the modality_sampler."""
        raise NotImplementedError("Please implement the 'get_test_set' method.")
|
247
|
+
|
248
|
+
|
249
|
+
class MultiModalDataset(torch.utils.data.Dataset):
    """The base interface for multimodal data.

    Subclasses implement ``get_one_multimodal_sample``, ``get_targets`` and
    ``__len__``. Indexing then keeps only the modalities contained in
    ``modality_sampler``, plus the items named in ``basic_items``.
    """

    def __init__(self):
        # The split this dataset represents: 'train', 'test' or 'val'.
        self.phase = None

        # The recorded samples of the current split; the layout is dataset
        # specific. For example:
        # - Flickr30K Entities: a dict whose keys are the sample names/ids and
        #   whose values are the sample's information, such as the annotation
        #   with its bounding boxes.
        # - Kinetics: {"rgb": rgb_dataset, "flow": flow_dataset, "audio": audio_dataset}
        self.phase_multimodal_data_record = None

        # Detailed information on the selected split, i.e., its path, the
        # paths of the different modalities, etc.
        self.phase_info = None

        # The data types included, e.g. ["Images", "Annotations", "Sentences"]
        # in Flickr30K Entities.
        self.data_types = None

        # The names of the modalities in the dataset.
        self.modalities_name = None

        # The modality sampler: modalities not contained in it are masked out
        # when a sample is fetched.
        self.modality_sampler = None

        # Optional transformation functions for images and text.
        self.transform_image_dec_func = None
        self.transform_text_func = None

        # The basic modalities.
        self.basic_modalities = ["rgb", "flow", "text", "audio"]
        # Additional data/annotations, which are never masked.
        self.basic_items = ["box", "target"]

    @abstractmethod
    def get_targets(self):
        """Obtain the labels of samples in current phase dataset."""
        raise NotImplementedError("Please Implement the 'targets' function")

    @abstractmethod
    def get_one_multimodal_sample(self, sample_idx):
        """Get one sample containing its different modalities.

        Every multimodal dataset provides its own implementation.

        Args:
            sample_idx (int): the index of the sample

        Output:
            a dict mapping names to data; the names should come from
            basic_modalities and basic_items.
        """
        raise NotImplementedError(
            "Please Implement the 'get_one_multimodal_sample(self, sample_idx)' function"
        )

    def __getitem__(self, sample_idx):
        """Get the sample for either training or testing given index,
        with unsampled modalities removed."""
        full_sample = self.get_one_multimodal_sample(sample_idx)

        def is_kept(name):
            # Sampled modalities and the external items survive the masking.
            return name in self.modality_sampler or name in self.basic_items

        return {name: value for name, value in full_sample.items() if is_kept(name)}

    @abstractmethod
    def __len__(self):
        """obtain the length of the multi-modal data"""
        raise NotImplementedError("Please Implement this method")
|
@@ -0,0 +1,56 @@
|
|
1
|
+
"""
|
2
|
+
The PASCAL VOC dataset for image segmentation.
|
3
|
+
"""
|
4
|
+
|
5
|
+
from torchvision import datasets, transforms
|
6
|
+
from plato.config import Config
|
7
|
+
|
8
|
+
from plato.datasources import base
|
9
|
+
|
10
|
+
|
11
|
+
class DataSource(base.DataSource):
    """The PASCAL VOC 2012 segmentation dataset.

    Keyword Args:
        train_transform: the transform applied to both the training images and
            their segmentation targets. Defaults to resizing to 96x96 followed
            by conversion to a tensor.
        test_transform: the transform applied to both the validation images
            and their targets. Defaults to the training transform.
    """

    def __init__(self, **kwargs):
        super().__init__()
        _path = Config().params["data_path"]
        self.mean = [0.45734706, 0.43338275, 0.40058118]
        self.std = [0.23965294, 0.23532275, 0.2398498]

        # The membership test must use the string key: the local name
        # `train_transform` is only bound by this very statement.
        train_transform = (
            kwargs["train_transform"]
            if "train_transform" in kwargs
            else transforms.Compose(
                [
                    transforms.Resize((96, 96)),
                    transforms.ToTensor(),
                ]
            )
        )

        test_transform = (
            kwargs["test_transform"] if "test_transform" in kwargs else train_transform
        )

        # NOTE(review): the image transform is reused as target_transform.
        # ToTensor scales 8-bit images to [0, 1], which would also rescale the
        # mask's class indices — confirm the segmentation trainer expects this.
        self.trainset = datasets.VOCSegmentation(
            root=_path,
            year="2012",
            image_set="train",
            download=True,
            transform=train_transform,
            target_transform=train_transform,
        )
        self.testset = datasets.VOCSegmentation(
            root=_path,
            year="2012",
            image_set="val",
            download=True,
            transform=test_transform,
            target_transform=test_transform,
        )

    def num_train_examples(self):
        """The number of samples in the training split."""
        return len(self.trainset)

    def num_test_examples(self):
        """The number of samples in the validation split used for testing."""
        return len(self.testset)
|
@@ -0,0 +1,94 @@
|
|
1
|
+
"""
|
2
|
+
The Purchase100 dataset.
|
3
|
+
"""
|
4
|
+
|
5
|
+
import os
|
6
|
+
import logging
|
7
|
+
import urllib
|
8
|
+
import tarfile
|
9
|
+
import torch
|
10
|
+
import numpy as np
|
11
|
+
from torch.utils import data
|
12
|
+
from plato.config import Config
|
13
|
+
from plato.datasources import base
|
14
|
+
|
15
|
+
|
16
|
+
class DataSource(base.DataSource):
    """The Purchase100 dataset.

    The raw data is downloaded into ``data_path`` and converted to
    ``data_path/purchase_numpy.npz`` on first use; later runs only read the
    .npz file. After a fixed-seed shuffle, the first 20,000 samples form the
    training set and the next 20,000 the test set.
    """

    def __init__(self, **kwargs):
        super().__init__()
        root_path = Config().params["data_path"]
        dataset_path = os.path.join(root_path, "dataset_purchase")
        numpy_path = os.path.join(root_path, "purchase_numpy.npz")

        # makedirs also creates missing parent directories.
        os.makedirs(root_path, exist_ok=True)

        # extract_data() reads the converted .npz file, so that file (not the
        # raw one) decides whether any preparation is needed.
        if not os.path.isfile(numpy_path):
            if os.path.isfile(dataset_path):
                # The raw file is already present; only convert it.
                self._convert_to_numpy(root_path, dataset_path)
            else:
                self.download_dataset(root_path, dataset_path)

        self.trainset, self.testset = self.extract_data(root_path)

    def download_dataset(self, root_path, dataset_path):
        """Download and unpack the Purchase100 dataset into root_path, then
        convert it to root_path/purchase_numpy.npz."""
        # `import urllib` alone does not guarantee that the `urllib.request`
        # submodule is loaded.
        import urllib.request

        logging.info("Downloading the Purchase100 dataset...")
        filename = "https://www.comp.nus.edu.sg/~reza/files/dataset_purchase.tgz"
        archive_path = os.path.join(root_path, "tmp_purchase.tgz")
        urllib.request.urlretrieve(filename, archive_path)
        logging.info("Dataset downloaded.")

        with tarfile.open(archive_path) as tar:
            # Use the "data" extraction filter where available; it rejects
            # members that would be written outside root_path.
            if hasattr(tarfile, "data_filter"):
                tar.extractall(path=root_path, filter="data")
            else:
                tar.extractall(path=root_path)

        self._convert_to_numpy(root_path, dataset_path)

    @staticmethod
    def _convert_to_numpy(root_path, dataset_path):
        """Parse the comma-separated raw file (label in column 0, features in
        the remaining columns) and save features X and zero-based labels Y to
        root_path/purchase_numpy.npz."""
        logging.info("Processing the dataset...")
        data_set = np.genfromtxt(dataset_path, delimiter=",")
        logging.info("Finish processing the dataset.")

        X = data_set[:, 1:].astype(np.float64)
        Y = (data_set[:, 0]).astype(np.int32) - 1
        np.savez(os.path.join(root_path, "purchase_numpy.npz"), X=X, Y=Y)

    def extract_data(self, root_path):
        """Load purchase_numpy.npz, shuffle it with seed 0, and return
        (train_dataset, test_dataset) holding 20,000 samples each."""
        with np.load(os.path.join(root_path, "purchase_numpy.npz")) as dataset:
            X, Y = dataset["X"], dataset["Y"]

        # Randomly shuffle the data. A private RandomState(0) yields the same
        # permutation as seeding the global generator with 0, without
        # resetting NumPy's global random state for the rest of the program.
        indices = np.arange(len(X))
        np.random.RandomState(0).shuffle(indices)
        X, Y = X[indices], Y[indices]

        # Extract 20000 data samples for training and testing respectively.
        num_train = 20000
        train_data = X[:num_train]
        test_data = X[num_train : num_train * 2]
        train_label = Y[:num_train]
        test_label = Y[num_train : num_train * 2]

        train_dataset = VectorDataset(train_data, train_label)
        test_dataset = VectorDataset(test_data, test_label)

        return train_dataset, test_dataset

    def num_train_examples(self):
        """The number of training samples (20,000)."""
        return len(self.trainset)

    def num_test_examples(self):
        """The number of test samples (20,000)."""
        return len(self.testset)
|
78
|
+
|
79
|
+
|
80
|
+
class VectorDataset(data.Dataset):
    """
    Create a Purchase100 dataset based on features and labels.

    Args:
        features: 2-D array-like, one row of feature values per sample;
            stored as a float32 tensor.
        labels: 1-D array-like of integer class labels; stored as an int64
            tensor.
    """

    def __init__(self, features, labels):
        # Convert in one call instead of creating and stacking one small
        # tensor per sample: much faster, and empty inputs no longer make
        # torch.stack fail. torch.tensor always copies, so the dataset does
        # not share memory with the caller's arrays.
        self.data = torch.tensor(np.asarray(features), dtype=torch.float32)
        self.targets = torch.tensor(np.asarray(labels), dtype=torch.int64)
        self.classes = [f"Style #{i}" for i in range(100)]

    def __getitem__(self, index):
        """Return the (features, label) pair at index."""
        return self.data[index], self.targets[index]

    def __len__(self):
        """The number of samples."""
        return self.data.size(0)
|
@@ -0,0 +1,127 @@
|
|
1
|
+
"""
|
2
|
+
The LIVE Netflix Video QoE datasets.
|
3
|
+
|
4
|
+
For more information about the datasets, refer to
|
5
|
+
https://live.ece.utexas.edu/research/LIVE_NFLXStudy/nflx_index.html.
|
6
|
+
"""
|
7
|
+
|
8
|
+
import copy
|
9
|
+
import logging
|
10
|
+
import os
|
11
|
+
import re
|
12
|
+
|
13
|
+
import numpy as np
|
14
|
+
import scipy.io as sio
|
15
|
+
import torch
|
16
|
+
|
17
|
+
from plato.config import Config
|
18
|
+
from plato.datasources import base
|
19
|
+
|
20
|
+
# Names of the five per-video features, in the column order used for the
# feature matrices. The per-video .mat variable each one is read from is
# noted alongside.
FEATURE_NAMES = [
    "VQA",  # "STRRED_mean"
    "R$_1$",  # "ds_norm"
    "R$_2$",  # "ns"
    "M",  # "tsl_norm"
    "I",  # "lt_norm"
]
|
21
|
+
|
22
|
+
|
23
|
+
class QoENFLXDataset(torch.utils.data.Dataset):
    """A PyTorch dataset over a QoENFLX sample matrix.

    Each row of ``dataset`` holds the VQA, R1, R2, M and I features in
    columns 0-4, followed by the label in column 5.
    """

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the sample at idx as a dict.

        "VQA", "R1", "Mem" and "Impair" are float32 tensors of shape (1,).
        "R2" stays a NumPy integer array of shape (1,), and "label" the raw
        NumPy slice of column 5.
        """
        # np.float / np.int were aliases of the builtins float / int and were
        # removed in NumPy 1.24; the explicit types below are equivalent.
        VQA = torch.from_numpy(self.dataset[idx, [0]].astype(np.float64)).float()
        R1 = torch.from_numpy(self.dataset[idx, [1]].astype(np.float64)).float()
        R2 = self.dataset[idx, [2]].astype(int)
        M = torch.from_numpy(self.dataset[idx, [3]].astype(np.float64)).float()
        I = torch.from_numpy(self.dataset[idx, [4]].astype(np.float64)).float()
        label = self.dataset[idx, [5]]
        sample = {"VQA": VQA, "R1": R1, "R2": R2, "Mem": M, "Impair": I, "label": label}

        return sample
|
40
|
+
|
41
|
+
|
42
|
+
class DataSource(base.DataSource):
    """A data source for QoENFLX datasets.

    Reads one .mat file per video from ``<data_path>/QoENFLX/VideoATLAS/``
    and the train/test assignment matrix from
    ``<data_path>/QoENFLX/TrainingMatrix_LIVENetflix_1000_trials.mat``. One
    trial column of that matrix is picked at random: videos marked 1 go to
    the training set and videos marked 0 to the test set. Both sets are
    arrays whose rows hold the features in FEATURE_NAMES order followed by
    the subjective score.
    """

    def __init__(self, **kwargs):
        super().__init__()

        logging.info("Dataset: QoENFLX")
        qoe_root = os.path.join(Config().params["data_path"], "QoENFLX")
        dataset_path = os.path.join(qoe_root, "VideoATLAS")

        # Skip hidden files (e.g. macOS ".DS_Store"): they are not per-video
        # records and would make sio.loadmat fail below.
        db_files = [f for f in os.listdir(dataset_path) if not f.startswith(".")]
        # Natural sort, so that embedded numbers compare numerically.
        db_files.sort(
            key=lambda var: [
                int(x) if x.isdigit() else x for x in re.findall(r"[^0-9]|[0-9]+", var)
            ]
        )
        Nvideos = len(db_files)

        pre_load_train_test_data_LIVE_Netflix = sio.loadmat(
            os.path.join(qoe_root, "TrainingMatrix_LIVENetflix_1000_trials.mat")
        )["TrainingMatrix_LIVENetflix_1000_trials"]

        # randomly pick a trial out of the 1000
        nt_rand = np.random.choice(
            np.shape(pre_load_train_test_data_LIVE_Netflix)[1], 1
        )
        n_train = [
            ind
            for ind in range(0, Nvideos)
            if pre_load_train_test_data_LIVE_Netflix[ind, nt_rand] == 1
        ]
        n_test = [
            ind
            for ind in range(0, Nvideos)
            if pre_load_train_test_data_LIVE_Netflix[ind, nt_rand] == 0
        ]

        # The .mat variable holding each feature; any other name maps to
        # "lt_norm" (used for "I").
        mat_keys = {
            "VQA": "STRRED_mean",
            "R$_1$": "ds_norm",
            "R$_2$": "ns",
            "M": "tsl_norm",
        }
        feature_labels = [mat_keys.get(typ, "lt_norm") for typ in FEATURE_NAMES]

        X = np.zeros((Nvideos, len(FEATURE_NAMES)))
        y = np.zeros((Nvideos, 1))

        for i, f in enumerate(db_files):
            data = sio.loadmat(os.path.join(dataset_path, f))
            # Each variable must hold a single value; .item() extracts it
            # explicitly instead of relying on size-1 array-to-scalar coercion.
            for feat_cnt, feat in enumerate(feature_labels):
                X[i, feat_cnt] = np.asarray(data[feat]).item()
            y[i] = np.asarray(data["final_subj_score"]).item()

        # np.concatenate returns fresh arrays, so no extra copy is needed.
        self.trainset = np.concatenate((X[n_train, :], y[n_train]), axis=1)
        self.testset = np.concatenate((X[n_test, :], y[n_test]), axis=1)

    @staticmethod
    def get_train_loader(batch_size, trainset, sampler, shuffle=False):
        """The custom train loader for QoENFLX."""
        return torch.utils.data.DataLoader(
            QoENFLXDataset(trainset),
            batch_size=batch_size,
            sampler=sampler,
            shuffle=shuffle,
        )

    @staticmethod
    def get_test_loader(batch_size, testset):
        """The custom test loader for QoENFLX."""
        return torch.utils.data.DataLoader(
            QoENFLXDataset(testset), batch_size=batch_size, shuffle=False
        )
|