plato-learn 1.1 (plato_learn-1.1-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (179)
  1. plato/__init__.py +1 -0
  2. plato/algorithms/__init__.py +0 -0
  3. plato/algorithms/base.py +45 -0
  4. plato/algorithms/fedavg.py +48 -0
  5. plato/algorithms/fedavg_gan.py +79 -0
  6. plato/algorithms/fedavg_personalized.py +48 -0
  7. plato/algorithms/mistnet.py +52 -0
  8. plato/algorithms/registry.py +39 -0
  9. plato/algorithms/split_learning.py +89 -0
  10. plato/callbacks/__init__.py +0 -0
  11. plato/callbacks/client.py +56 -0
  12. plato/callbacks/handler.py +78 -0
  13. plato/callbacks/server.py +139 -0
  14. plato/callbacks/trainer.py +124 -0
  15. plato/client.py +67 -0
  16. plato/clients/__init__.py +0 -0
  17. plato/clients/base.py +467 -0
  18. plato/clients/edge.py +103 -0
  19. plato/clients/fedavg_personalized.py +40 -0
  20. plato/clients/mistnet.py +49 -0
  21. plato/clients/registry.py +43 -0
  22. plato/clients/self_supervised_learning.py +51 -0
  23. plato/clients/simple.py +218 -0
  24. plato/clients/split_learning.py +150 -0
  25. plato/config.py +339 -0
  26. plato/datasources/__init__.py +0 -0
  27. plato/datasources/base.py +123 -0
  28. plato/datasources/celeba.py +150 -0
  29. plato/datasources/cifar10.py +87 -0
  30. plato/datasources/cifar100.py +61 -0
  31. plato/datasources/cinic10.py +62 -0
  32. plato/datasources/coco.py +119 -0
  33. plato/datasources/datalib/__init__.py +0 -0
  34. plato/datasources/datalib/audio_extraction_tools.py +137 -0
  35. plato/datasources/datalib/data_utils.py +124 -0
  36. plato/datasources/datalib/flickr30kE_utils.py +336 -0
  37. plato/datasources/datalib/frames_extraction_tools.py +254 -0
  38. plato/datasources/datalib/gym_utils/__init__.py +0 -0
  39. plato/datasources/datalib/gym_utils/gym_trim.py +189 -0
  40. plato/datasources/datalib/modality_data_anntation_tools.py +163 -0
  41. plato/datasources/datalib/modality_extraction_base.py +59 -0
  42. plato/datasources/datalib/parse_datasets.py +212 -0
  43. plato/datasources/datalib/refer_utils/__init__.py +0 -0
  44. plato/datasources/datalib/refer_utils/referitgame_utils.py +237 -0
  45. plato/datasources/datalib/tiny_data_tools.py +81 -0
  46. plato/datasources/datalib/video_transform.py +79 -0
  47. plato/datasources/emnist.py +64 -0
  48. plato/datasources/fashion_mnist.py +41 -0
  49. plato/datasources/feature.py +24 -0
  50. plato/datasources/feature_dataset.py +15 -0
  51. plato/datasources/femnist.py +141 -0
  52. plato/datasources/flickr30k_entities.py +362 -0
  53. plato/datasources/gym.py +431 -0
  54. plato/datasources/huggingface.py +165 -0
  55. plato/datasources/kinetics.py +568 -0
  56. plato/datasources/mnist.py +44 -0
  57. plato/datasources/multimodal_base.py +328 -0
  58. plato/datasources/pascal_voc.py +56 -0
  59. plato/datasources/purchase.py +94 -0
  60. plato/datasources/qoenflx.py +127 -0
  61. plato/datasources/referitgame.py +330 -0
  62. plato/datasources/registry.py +119 -0
  63. plato/datasources/self_supervised_learning.py +98 -0
  64. plato/datasources/stl10.py +103 -0
  65. plato/datasources/texas.py +94 -0
  66. plato/datasources/tiny_imagenet.py +64 -0
  67. plato/datasources/yolov8.py +85 -0
  68. plato/models/__init__.py +0 -0
  69. plato/models/cnn_encoder.py +103 -0
  70. plato/models/dcgan.py +116 -0
  71. plato/models/general_multilayer.py +254 -0
  72. plato/models/huggingface.py +27 -0
  73. plato/models/lenet5.py +113 -0
  74. plato/models/multilayer.py +90 -0
  75. plato/models/multimodal/__init__.py +0 -0
  76. plato/models/multimodal/base_net.py +91 -0
  77. plato/models/multimodal/blending.py +142 -0
  78. plato/models/multimodal/fc_net.py +77 -0
  79. plato/models/multimodal/fusion_net.py +78 -0
  80. plato/models/multimodal/multimodal_module.py +152 -0
  81. plato/models/registry.py +99 -0
  82. plato/models/resnet.py +190 -0
  83. plato/models/torch_hub.py +19 -0
  84. plato/models/vgg.py +113 -0
  85. plato/models/vit.py +166 -0
  86. plato/models/yolov8.py +22 -0
  87. plato/processors/__init__.py +0 -0
  88. plato/processors/base.py +35 -0
  89. plato/processors/compress.py +46 -0
  90. plato/processors/decompress.py +48 -0
  91. plato/processors/feature.py +51 -0
  92. plato/processors/feature_additive_noise.py +48 -0
  93. plato/processors/feature_dequantize.py +34 -0
  94. plato/processors/feature_gaussian.py +17 -0
  95. plato/processors/feature_laplace.py +15 -0
  96. plato/processors/feature_quantize.py +34 -0
  97. plato/processors/feature_randomized_response.py +50 -0
  98. plato/processors/feature_unbatch.py +39 -0
  99. plato/processors/inbound_feature_tensors.py +39 -0
  100. plato/processors/model.py +55 -0
  101. plato/processors/model_compress.py +34 -0
  102. plato/processors/model_decompress.py +37 -0
  103. plato/processors/model_decrypt.py +41 -0
  104. plato/processors/model_deepcopy.py +21 -0
  105. plato/processors/model_dequantize.py +18 -0
  106. plato/processors/model_dequantize_qsgd.py +61 -0
  107. plato/processors/model_encrypt.py +43 -0
  108. plato/processors/model_quantize.py +18 -0
  109. plato/processors/model_quantize_qsgd.py +82 -0
  110. plato/processors/model_randomized_response.py +34 -0
  111. plato/processors/outbound_feature_ndarrays.py +38 -0
  112. plato/processors/pipeline.py +26 -0
  113. plato/processors/registry.py +124 -0
  114. plato/processors/structured_pruning.py +57 -0
  115. plato/processors/unstructured_pruning.py +73 -0
  116. plato/samplers/__init__.py +0 -0
  117. plato/samplers/all_inclusive.py +41 -0
  118. plato/samplers/base.py +31 -0
  119. plato/samplers/dirichlet.py +81 -0
  120. plato/samplers/distribution_noniid.py +132 -0
  121. plato/samplers/iid.py +53 -0
  122. plato/samplers/label_quantity_noniid.py +119 -0
  123. plato/samplers/mixed.py +44 -0
  124. plato/samplers/mixed_label_quantity_noniid.py +128 -0
  125. plato/samplers/modality_iid.py +42 -0
  126. plato/samplers/modality_quantity_noniid.py +56 -0
  127. plato/samplers/orthogonal.py +99 -0
  128. plato/samplers/registry.py +66 -0
  129. plato/samplers/sample_quantity_noniid.py +123 -0
  130. plato/samplers/sampler_utils.py +190 -0
  131. plato/servers/__init__.py +0 -0
  132. plato/servers/base.py +1395 -0
  133. plato/servers/fedavg.py +281 -0
  134. plato/servers/fedavg_cs.py +335 -0
  135. plato/servers/fedavg_gan.py +74 -0
  136. plato/servers/fedavg_he.py +106 -0
  137. plato/servers/fedavg_personalized.py +57 -0
  138. plato/servers/mistnet.py +67 -0
  139. plato/servers/registry.py +52 -0
  140. plato/servers/split_learning.py +109 -0
  141. plato/trainers/__init__.py +0 -0
  142. plato/trainers/base.py +99 -0
  143. plato/trainers/basic.py +649 -0
  144. plato/trainers/diff_privacy.py +178 -0
  145. plato/trainers/gan.py +330 -0
  146. plato/trainers/huggingface.py +173 -0
  147. plato/trainers/loss_criterion.py +70 -0
  148. plato/trainers/lr_schedulers.py +252 -0
  149. plato/trainers/optimizers.py +53 -0
  150. plato/trainers/pascal_voc.py +80 -0
  151. plato/trainers/registry.py +44 -0
  152. plato/trainers/self_supervised_learning.py +302 -0
  153. plato/trainers/split_learning.py +305 -0
  154. plato/trainers/tracking.py +96 -0
  155. plato/trainers/yolov8.py +41 -0
  156. plato/utils/__init__.py +0 -0
  157. plato/utils/count_parameters.py +30 -0
  158. plato/utils/csv_processor.py +26 -0
  159. plato/utils/data_loaders.py +148 -0
  160. plato/utils/decorators.py +24 -0
  161. plato/utils/fonts.py +23 -0
  162. plato/utils/homo_enc.py +187 -0
  163. plato/utils/reinforcement_learning/__init__.py +0 -0
  164. plato/utils/reinforcement_learning/policies/__init__.py +0 -0
  165. plato/utils/reinforcement_learning/policies/base.py +161 -0
  166. plato/utils/reinforcement_learning/policies/ddpg.py +75 -0
  167. plato/utils/reinforcement_learning/policies/registry.py +32 -0
  168. plato/utils/reinforcement_learning/policies/sac.py +343 -0
  169. plato/utils/reinforcement_learning/policies/td3.py +485 -0
  170. plato/utils/reinforcement_learning/rl_agent.py +142 -0
  171. plato/utils/reinforcement_learning/rl_server.py +113 -0
  172. plato/utils/rl_env.py +154 -0
  173. plato/utils/s3.py +141 -0
  174. plato/utils/trainer_utils.py +21 -0
  175. plato/utils/unary_encoding.py +47 -0
  176. plato_learn-1.1.dist-info/METADATA +35 -0
  177. plato_learn-1.1.dist-info/RECORD +179 -0
  178. plato_learn-1.1.dist-info/WHEEL +4 -0
  179. plato_learn-1.1.dist-info/licenses/LICENSE +201 -0
plato/datasources/multimodal_base.py
@@ -0,0 +1,328 @@
+ """
+ Base class for multimodal datasets.
+ """
+
+ from abc import abstractmethod
+ import logging
+ import os
+ import subprocess
+ from collections import namedtuple
+
+ import torch
+ from torchvision.datasets.utils import download_url, extract_archive
+ from torchvision.datasets.utils import download_file_from_google_drive
+
+ from plato.datasources import base
+
+ TextData = namedtuple("TextData", ["caption", "caption_phrases"])
+ BoxData = namedtuple("BoxData", ["caption_phrase_bboxs"])
+ TargetData = namedtuple(
+     "TargetData", ["caption_phrases_cate", "caption_phrases_cate_id"]
+ )
+
+
+ class MultiModalDataSource(base.DataSource):
+     """
+     The training or testing dataset that accommodates custom augmentation and transforms.
+     """
+
+     def __init__(self):
+         super().__init__()
+
+         # data name
+         self.data_name = ""
+
+         # the text names of the contained modalities
+         self.modality_names = []
+
+         # define the information container for the source data
+         # - source_data_path: the original downloaded data
+         # - data_path: the source data used for the model
+         # For some datasets, we directly utilize the data_path as
+         # there is no need to process the original downloaded data to put them
+         # in the data_path dir.
+         self.mm_data_info = {"source_data_path": "", "data_path": ""}
+
+         # define the paths for the split root data - train, test, and val
+         self.splits_info = {
+             "train": {"path": "", "split_anno_file": ""},
+             "test": {"path": "", "split_anno_file": ""},
+             "val": {"path": "", "split_anno_file": ""},
+         }
+
+     def set_modality_format(self, modality_name):
+         """An interface to obtain the modality format.
+         Call this method wherever the modality format is needed
+         so that it stays consistent across the class.
+         """
+         if modality_name in ["rgb", "flow"]:
+             modality_format = "rawframes"
+         else:  # convert to plural
+             modality_format = modality_name + "s"
+
+         return modality_format
+
+     def set_modality_path_key_format(self, modality_name):
+         """An interface to obtain the modality path key.
+         Call this method wherever the modality path key is needed
+         so that it stays consistent across the class.
+         """
+         modality_format = self.set_modality_format(modality_name)
+
+         return modality_format + "_" + "path"
+
+     def _create_modalities_path(self, modality_names=None):
+         if modality_names is None:
+             assert len(self.modality_names) != 0
+             modality_names = self.modality_names
+
+         for split_type in list(self.splits_info.keys()):
+             split_path = self.splits_info[split_type]["path"]
+             for modality_nm in modality_names:
+                 modality_format = self.set_modality_format(modality_nm)
+                 split_modality_path = os.path.join(split_path, modality_format)
+                 # modality data dir
+                 modality_path_format = self.set_modality_path_key_format(modality_nm)
+                 self.splits_info[split_type][modality_path_format] = split_modality_path
+                 if not os.path.exists(split_modality_path):
+                     try:
+                         os.makedirs(split_modality_path)
+                     except FileExistsError:
+                         pass
+
+     def _data_path_process(
+         self,
+         data_path,  # the base directory for the data
+         base_data_name=None,  # the directory name of the working data
+     ):
+         """Generate the data structure based on the defined data path"""
+
+         # Create the full path by introducing the project path
+         base_data_path = os.path.join(data_path, base_data_name)
+
+         if not os.path.exists(base_data_path):
+             os.makedirs(base_data_path)
+
+         self.mm_data_info["data_path"] = base_data_path
+
+         # create the split dirs for the current dataset
+         for split_type in list(self.splits_info.keys()):
+             split_path = os.path.join(base_data_path, split_type)
+             self.splits_info[split_type]["path"] = split_path
+             if not os.path.exists(split_path):
+                 try:
+                     os.makedirs(split_path)
+                 except FileExistsError:
+                     pass
+
+     def _download_arrange_data(
+         self,
+         download_url_address,
+         data_path,
+         extract_to_dir=None,
+         obtained_file_name=None,
+     ):
+         """Download the raw data and arrange the data"""
+         # Extract to the same dir as the download dir
+         if extract_to_dir is None:
+             extract_to_dir = data_path
+
+         download_file_name = os.path.basename(download_url_address)
+         download_file_path = os.path.join(data_path, download_file_name)
+
+         download_extracted_file_name = download_file_name.split(".")[0]
+         download_extracted_path = os.path.join(
+             extract_to_dir, download_extracted_file_name
+         )
+         # Download the raw data if necessary
+         if not self._exists(download_file_path):
+             logging.info("Downloading the %s data.....", download_file_name)
+             download_url(
+                 url=download_url_address, root=data_path, filename=obtained_file_name
+             )
+
+         # Extract the data to the specific dir
+         if ".zip" in download_file_name or ".tar.gz" in download_file_name:
+             if not self._exists(download_extracted_path):
+                 logging.info("Extracting data to %s dir.....", extract_to_dir)
+                 extract_archive(
+                     from_path=download_file_path,
+                     to_path=extract_to_dir,
+                     remove_finished=False,
+                 )
+
+         return download_extracted_file_name
+
+     def _download_google_driver_arrange_data(
+         self,
+         download_file_id,
+         extract_download_file_name,
+         data_path,
+     ):
+         download_data_file_name = extract_download_file_name + ".zip"
+         download_data_path = os.path.join(data_path, download_data_file_name)
+         extract_data_path = os.path.join(data_path, extract_download_file_name)
+         if not self._exists(download_data_path):
+             logging.info("Downloading the data to %s", download_data_path)
+             download_file_from_google_drive(
+                 file_id=download_file_id,
+                 root=data_path,
+                 filename=download_data_file_name,
+             )
+         if not self._exists(extract_data_path):
+             extract_archive(
+                 from_path=download_data_path, to_path=data_path, remove_finished=True
+             )
+
+     def _file_exists(self, tg_file_name, search_path, is_partial_name=True):
+         """Check whether the input file exists in the search_path."""
+         # tg_file_name matches a file if it matches part of that file's name
+         if is_partial_name:
+             is_included = lambda src_f_name: tg_file_name in src_f_name
+         else:
+             is_included = lambda src_f_name: tg_file_name == src_f_name
+         exists = any(is_included(f_name) for f_name in os.listdir(search_path))
+
+         return exists
+
+     def _exists(self, target_path):
+         """Does the input path/file exist and does it contain useful data?"""
+         if not os.path.exists(target_path):
+             logging.info("The path %s does not exist.", target_path)
+             return False
+
+         # remove all .DS_Store files
+         command = ["find", ".", "-name", '".DS_Store"', "-delete"]
+         command = " ".join(command)
+         subprocess.check_output(command, shell=True, stderr=subprocess.STDOUT)
+
+         def get_size(folder):
+             # total size of the top-level entries, ignoring hidden files
+             size = 0
+             for ele in os.scandir(folder):
+                 if not ele.name.startswith("."):
+                     size += os.path.getsize(ele)
+             return size
+
+         def is_contain_useful_file(target_dir):
+             """Return True once reaching one useful file"""
+             for __, __, files in os.walk(target_dir):
+                 for file in files:
+                     # whether a useful file
+                     if not file.startswith("."):
+                         return True
+             return False
+
+         if os.path.isdir(target_path):
+             if get_size(target_path) == 0 or not is_contain_useful_file(target_path):
+                 logging.info("The path %s exists but contains no data.", target_path)
+                 return False
+
+             return True
+
+         logging.info("The file %s exists.", target_path)
+         return True
+
+     def num_modalities(self) -> int:
+         """The number of modalities."""
+         return len(self.modality_names)
+
+     @abstractmethod
+     def get_phase_dataset(self, phase, modality_sampler):
+         """Obtain the dataset with the modality_sampler for the
+         specific phase (train/test/val)."""
+         raise NotImplementedError("Please implement the 'get_phase_dataset' method.")
+
+     @abstractmethod
+     def get_train_set(self, modality_sampler):
+         """Obtain the train dataset with the modality_sampler."""
+         raise NotImplementedError("Please implement the 'get_train_set' method.")
+
+     @abstractmethod
+     def get_test_set(self, modality_sampler):
+         """Obtain the test dataset with the modality_sampler."""
+         raise NotImplementedError("Please implement the 'get_test_set' method.")
+
+
+ class MultiModalDataset(torch.utils.data.Dataset):
+     """The base interface for multimodal data."""
+
+     def __init__(self):
+         self.phase = None  # 'train', 'test', or 'val'
+
+         # The recorded samples for the current dataset:
+         # In the Flickr30K Entities dataset, this is a dict whose keys are the
+         # sample names/ids and whose values hold the sample's information,
+         # for example the annotation with its bounding boxes.
+         # In Kinetics, this is a dict:
+         # {"rgb": rgb_dataset, "flow": flow_dataset, "audio": audio_dataset}
+         self.phase_multimodal_data_record = None
+
+         # Detailed information of the selected split:
+         # i.e., path, paths for different modalities, etc.
+         self.phase_info = None
+         # the data types included,
+         # e.g. in Flickr30K Entities, ["Images", "Annotations", "Sentences"]
+         self.data_types = None
+
+         # the names of the modalities in the dataset
+         self.modalities_name = None
+
+         # the sampler for modalities;
+         # specific modalities can be masked by this sampler
+         self.modality_sampler = None
+         # transformation functions for image and text, if provided
+         self.transform_image_dec_func = None
+         self.transform_text_func = None
+
+         # the basic modalities
+         self.basic_modalities = ["rgb", "flow", "text", "audio"]
+         # the additional data/annotations
+         self.basic_items = ["box", "target"]
+
+     @abstractmethod
+     def get_targets(self):
+         """Obtain the labels of samples in the current phase dataset."""
+         raise NotImplementedError("Please implement the 'get_targets' method.")
+
+     @abstractmethod
+     def get_one_multimodal_sample(self, sample_idx):
+         """Get the sample containing different modalities.
+         Different multimodal datasets should provide their
+         own 'get_one_multimodal_sample' method.
+
+         Args:
+             sample_idx (int): the index of the sample
+
+         Output:
+             a dict containing different modalities; each key of the dict
+             is a modality name that should be included in
+             basic_modalities or basic_items.
+         """
+         raise NotImplementedError(
+             "Please implement the 'get_one_multimodal_sample(self, sample_idx)' method."
+         )
+
+     def __getitem__(self, sample_idx):
+         """Get the sample for either training or testing given its index."""
+         sampled_multimodal_data = self.get_one_multimodal_sample(sample_idx)
+
+         # utilize the modality sampler to mask specific modalities
+         sampled_modality_data = {}
+         for item_name, item_data in sampled_multimodal_data.items():
+             # keep the modality data selected by the sampler
+             # and always keep the additional items
+             if item_name in self.modality_sampler or item_name in self.basic_items:
+                 sampled_modality_data[item_name] = item_data
+
+         return sampled_modality_data
+
+     @abstractmethod
+     def __len__(self):
+         """Obtain the length of the multimodal data."""
+         raise NotImplementedError("Please implement this method.")
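
A minimal sketch (not part of the package) of how a concrete datasource might build on MultiModalDataSource above; the class name, modality list, local data path, and download URL are hypothetical placeholders.

from plato.datasources import multimodal_base


class ToyVideoDataSource(multimodal_base.MultiModalDataSource):
    """A hypothetical datasource with two modalities."""

    def __init__(self):
        super().__init__()
        self.data_name = "toy_video"
        self.modality_names = ["rgb", "audio"]

        # Create <data_path>/toy_video/{train,test,val}, record the paths in
        # self.splits_info, then add per-modality sub-directories such as
        # .../train/rawframes and .../train/audios.
        self._data_path_process(data_path="./data", base_data_name=self.data_name)
        self._create_modalities_path()

        # An archive could then be fetched and extracted into the working
        # directory; the URL is a placeholder, so the call stays commented out:
        # self._download_arrange_data(
        #     download_url_address="https://example.com/toy_video.zip",
        #     data_path=self.mm_data_info["data_path"],
        # )
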
plato/datasources/pascal_voc.py
@@ -0,0 +1,56 @@
+ """
+ The PASCAL VOC dataset for image segmentation.
+ """
+
+ from torchvision import datasets, transforms
+ from plato.config import Config
+
+ from plato.datasources import base
+
+
+ class DataSource(base.DataSource):
+     """The PASCAL VOC dataset."""
+
+     def __init__(self, **kwargs):
+         super().__init__()
+         _path = Config().params["data_path"]
+         self.mean = [0.45734706, 0.43338275, 0.40058118]
+         self.std = [0.23965294, 0.23532275, 0.2398498]
+
+         train_transform = (
+             kwargs["train_transform"]
+             if "train_transform" in kwargs
+             else (
+                 transforms.Compose(
+                     [
+                         transforms.Resize((96, 96)),
+                         transforms.ToTensor(),
+                     ]
+                 )
+             )
+         )
+
+         test_transform = train_transform
+
+         self.trainset = datasets.VOCSegmentation(
+             root=_path,
+             year="2012",
+             image_set="train",
+             download=True,
+             transform=train_transform,
+             target_transform=train_transform,
+         )
+         self.testset = datasets.VOCSegmentation(
+             root=_path,
+             year="2012",
+             image_set="val",
+             download=True,
+             transform=test_transform,
+             target_transform=test_transform,
+         )
+
+     def num_train_examples(self):
+         return len(self.trainset)
+
+     def num_test_examples(self):
+         return len(self.testset)
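
A minimal usage sketch (not part of the package), assuming Config() has already been loaded with a valid "data_path"; the 128x128 resize is an arbitrary example value passed through the train_transform keyword argument shown above.

from torchvision import transforms

from plato.datasources import pascal_voc

custom_transform = transforms.Compose(
    [transforms.Resize((128, 128)), transforms.ToTensor()]
)
datasource = pascal_voc.DataSource(train_transform=custom_transform)
print(datasource.num_train_examples(), datasource.num_test_examples())
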
plato/datasources/purchase.py
@@ -0,0 +1,94 @@
+ """
+ The Purchase100 dataset.
+ """
+
+ import os
+ import logging
+ import urllib.request
+ import tarfile
+ import torch
+ import numpy as np
+ from torch.utils import data
+ from plato.config import Config
+ from plato.datasources import base
+
+
+ class DataSource(base.DataSource):
+     """The Purchase100 dataset."""
+
+     def __init__(self, **kwargs):
+         super().__init__()
+         root_path = Config().params["data_path"]
+         dataset_path = os.path.join(root_path, "dataset_purchase")
+         if not os.path.isdir(root_path):
+             os.mkdir(root_path)
+         if not os.path.isfile(dataset_path):
+             self.download_dataset(root_path, dataset_path)
+
+         self.trainset, self.testset = self.extract_data(root_path)
+
+     def download_dataset(self, root_path, dataset_path):
+         """Download the Purchase100 dataset."""
+         logging.info("Downloading the Purchase100 dataset...")
+         filename = "https://www.comp.nus.edu.sg/~reza/files/dataset_purchase.tgz"
+         urllib.request.urlretrieve(
+             filename, os.path.join(root_path, "tmp_purchase.tgz")
+         )
+         logging.info("Dataset downloaded.")
+         tar = tarfile.open(os.path.join(root_path, "tmp_purchase.tgz"))
+         tar.extractall(path=root_path)
+
+         logging.info("Processing the dataset...")
+         data_set = np.genfromtxt(dataset_path, delimiter=",")
+         logging.info("Finished processing the dataset.")
+
+         X = data_set[:, 1:].astype(np.float64)
+         Y = (data_set[:, 0]).astype(np.int32) - 1
+         np.savez(os.path.join(root_path, "purchase_numpy.npz"), X=X, Y=Y)
+
+     def extract_data(self, root_path):
+         """Extract data."""
+         dataset = np.load(os.path.join(root_path, "purchase_numpy.npz"))
+
+         ## randomly shuffle the data
+         X, Y = dataset["X"], dataset["Y"]
+         np.random.seed(0)
+         indices = np.arange(len(X))
+         np.random.shuffle(indices)
+         X, Y = X[indices], Y[indices]
+
+         ## extract 20000 samples each for training and testing
+         num_train = 20000
+         train_data = X[:num_train]
+         test_data = X[num_train : num_train * 2]
+         train_label = Y[:num_train]
+         test_label = Y[num_train : num_train * 2]
+
+         ## create datasets
+         train_dataset = VectorDataset(train_data, train_label)
+         test_dataset = VectorDataset(test_data, test_label)
+
+         return train_dataset, test_dataset
+
+     def num_train_examples(self):
+         return 20000
+
+     def num_test_examples(self):
+         return 20000
+
+
+ class VectorDataset(data.Dataset):
+     """
+     A Purchase100 dataset built from feature vectors and labels.
+     """
+
+     def __init__(self, features, labels):
+         self.data = torch.stack([torch.FloatTensor(i) for i in features])
+         self.targets = torch.stack([torch.LongTensor([i]) for i in labels])[:, 0]
+         self.classes = [f"Style #{i}" for i in range(100)]
+
+     def __getitem__(self, index):
+         return self.data[index], self.targets[index]
+
+     def __len__(self):
+         return self.data.size(0)
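
A minimal usage sketch (not part of the package), assuming Config() provides a writable "data_path"; it simply wraps the extracted train split, a VectorDataset, in a standard PyTorch DataLoader.

import torch

from plato.datasources import purchase

datasource = purchase.DataSource()
train_loader = torch.utils.data.DataLoader(
    datasource.trainset, batch_size=32, shuffle=True
)
features, labels = next(iter(train_loader))  # one mini-batch of vectors and labels
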
plato/datasources/qoenflx.py
@@ -0,0 +1,127 @@
+ """
+ The LIVE-Netflix Video QoE dataset.
+
+ For more information about the dataset, refer to
+ https://live.ece.utexas.edu/research/LIVE_NFLXStudy/nflx_index.html.
+ """
+
+ import copy
+ import logging
+ import os
+ import re
+
+ import numpy as np
+ import scipy.io as sio
+ import torch
+
+ from plato.config import Config
+ from plato.datasources import base
+
+ FEATURE_NAMES = ["VQA", "R$_1$", "R$_2$", "M", "I"]
+
+
+ class QoENFLXDataset(torch.utils.data.Dataset):
+     """A PyTorch dataset wrapping the QoENFLX feature matrix."""
+
+     def __init__(self, dataset):
+         self.dataset = dataset
+
+     def __len__(self):
+         return len(self.dataset)
+
+     def __getitem__(self, idx):
+         VQA = torch.from_numpy(self.dataset[idx, [0]].astype(np.float64)).float()
+         R1 = torch.from_numpy(self.dataset[idx, [1]].astype(np.float64)).float()
+         R2 = self.dataset[idx, [2]].astype(np.int64)
+         M = torch.from_numpy(self.dataset[idx, [3]].astype(np.float64)).float()
+         I = torch.from_numpy(self.dataset[idx, [4]].astype(np.float64)).float()
+         label = self.dataset[idx, [5]]
+         sample = {"VQA": VQA, "R1": R1, "R2": R2, "Mem": M, "Impair": I, "label": label}
+
+         return sample
+
+
+ class DataSource(base.DataSource):
+     """A data source for the QoENFLX dataset."""
+
+     def __init__(self, **kwargs):
+         super().__init__()
+
+         logging.info("Dataset: QoENFLX")
+         dataset_path = Config().params["data_path"] + "/QoENFLX/VideoATLAS/"
+         db_files = os.listdir(dataset_path)
+         db_files.sort(
+             key=lambda var: [
+                 int(x) if x.isdigit() else x for x in re.findall(r"[^0-9]|[0-9]+", var)
+             ]
+         )
+         Nvideos = len(db_files)
+
+         pre_load_train_test_data_LIVE_Netflix = sio.loadmat(
+             Config().params["data_path"]
+             + "/QoENFLX/TrainingMatrix_LIVENetflix_1000_trials.mat"
+         )["TrainingMatrix_LIVENetflix_1000_trials"]
+
+         # randomly pick a trial out of the 1000
+         nt_rand = np.random.choice(
+             np.shape(pre_load_train_test_data_LIVE_Netflix)[1], 1
+         )
+         n_train = [
+             ind
+             for ind in range(0, Nvideos)
+             if pre_load_train_test_data_LIVE_Netflix[ind, nt_rand] == 1
+         ]
+         n_test = [
+             ind
+             for ind in range(0, Nvideos)
+             if pre_load_train_test_data_LIVE_Netflix[ind, nt_rand] == 0
+         ]
+
+         X = np.zeros((len(db_files), len(FEATURE_NAMES)))
+         y = np.zeros((len(db_files), 1))
+
+         feature_labels = list()
+         for typ in FEATURE_NAMES:
+             if typ == "VQA":
+                 feature_labels.append("STRRED" + "_" + "mean")
+             elif typ == "R$_1$":
+                 feature_labels.append("ds_norm")
+             elif typ == "R$_2$":
+                 feature_labels.append("ns")
+             elif typ == "M":
+                 feature_labels.append("tsl_norm")
+             else:
+                 feature_labels.append("lt_norm")
+
+         for i, f in enumerate(db_files):
+             data = sio.loadmat(dataset_path + f)
+             for feat_cnt, feat in enumerate(feature_labels):
+                 X[i, feat_cnt] = data[feat]
+             y[i] = data["final_subj_score"]
+
+         X_train_before_scaling = X[n_train, :]
+         X_test_before_scaling = X[n_test, :]
+         y_train = y[n_train]
+         y_test = y[n_test]
+
+         self.trainset = copy.deepcopy(
+             np.concatenate((X_train_before_scaling, y_train), axis=1)
+         )
+         self.testset = copy.deepcopy(
+             np.concatenate((X_test_before_scaling, y_test), axis=1)
+         )
+
+     @staticmethod
+     def get_train_loader(batch_size, trainset, sampler, shuffle=False):
+         """The custom train loader for QoENFLX."""
+         return torch.utils.data.DataLoader(
+             QoENFLXDataset(trainset),
+             batch_size=batch_size,
+             sampler=sampler,
+             shuffle=shuffle,
+         )
+
+     @staticmethod
+     def get_test_loader(batch_size, testset):
+         """The custom test loader for QoENFLX."""
+         return torch.utils.data.DataLoader(
+             QoENFLXDataset(testset), batch_size=batch_size, shuffle=False
+         )
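
A minimal sketch (not part of the package) of the custom loaders above, assuming the QoENFLX .mat files are already present under the configured "data_path"; the batch size is an arbitrary example value and no sampler is passed.

from plato.datasources import qoenflx

datasource = qoenflx.DataSource()
train_loader = qoenflx.DataSource.get_train_loader(
    batch_size=8, trainset=datasource.trainset, sampler=None, shuffle=True
)
test_loader = qoenflx.DataSource.get_test_loader(batch_size=8, testset=datasource.testset)
for batch in train_loader:
    print(batch["VQA"].shape, batch["label"].shape)
    break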