plato_learn-1.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- plato/__init__.py +1 -0
- plato/algorithms/__init__.py +0 -0
- plato/algorithms/base.py +45 -0
- plato/algorithms/fedavg.py +48 -0
- plato/algorithms/fedavg_gan.py +79 -0
- plato/algorithms/fedavg_personalized.py +48 -0
- plato/algorithms/mistnet.py +52 -0
- plato/algorithms/registry.py +39 -0
- plato/algorithms/split_learning.py +89 -0
- plato/callbacks/__init__.py +0 -0
- plato/callbacks/client.py +56 -0
- plato/callbacks/handler.py +78 -0
- plato/callbacks/server.py +139 -0
- plato/callbacks/trainer.py +124 -0
- plato/client.py +67 -0
- plato/clients/__init__.py +0 -0
- plato/clients/base.py +467 -0
- plato/clients/edge.py +103 -0
- plato/clients/fedavg_personalized.py +40 -0
- plato/clients/mistnet.py +49 -0
- plato/clients/registry.py +43 -0
- plato/clients/self_supervised_learning.py +51 -0
- plato/clients/simple.py +218 -0
- plato/clients/split_learning.py +150 -0
- plato/config.py +339 -0
- plato/datasources/__init__.py +0 -0
- plato/datasources/base.py +123 -0
- plato/datasources/celeba.py +150 -0
- plato/datasources/cifar10.py +87 -0
- plato/datasources/cifar100.py +61 -0
- plato/datasources/cinic10.py +62 -0
- plato/datasources/coco.py +119 -0
- plato/datasources/datalib/__init__.py +0 -0
- plato/datasources/datalib/audio_extraction_tools.py +137 -0
- plato/datasources/datalib/data_utils.py +124 -0
- plato/datasources/datalib/flickr30kE_utils.py +336 -0
- plato/datasources/datalib/frames_extraction_tools.py +254 -0
- plato/datasources/datalib/gym_utils/__init__.py +0 -0
- plato/datasources/datalib/gym_utils/gym_trim.py +189 -0
- plato/datasources/datalib/modality_data_anntation_tools.py +163 -0
- plato/datasources/datalib/modality_extraction_base.py +59 -0
- plato/datasources/datalib/parse_datasets.py +212 -0
- plato/datasources/datalib/refer_utils/__init__.py +0 -0
- plato/datasources/datalib/refer_utils/referitgame_utils.py +237 -0
- plato/datasources/datalib/tiny_data_tools.py +81 -0
- plato/datasources/datalib/video_transform.py +79 -0
- plato/datasources/emnist.py +64 -0
- plato/datasources/fashion_mnist.py +41 -0
- plato/datasources/feature.py +24 -0
- plato/datasources/feature_dataset.py +15 -0
- plato/datasources/femnist.py +141 -0
- plato/datasources/flickr30k_entities.py +362 -0
- plato/datasources/gym.py +431 -0
- plato/datasources/huggingface.py +165 -0
- plato/datasources/kinetics.py +568 -0
- plato/datasources/mnist.py +44 -0
- plato/datasources/multimodal_base.py +328 -0
- plato/datasources/pascal_voc.py +56 -0
- plato/datasources/purchase.py +94 -0
- plato/datasources/qoenflx.py +127 -0
- plato/datasources/referitgame.py +330 -0
- plato/datasources/registry.py +119 -0
- plato/datasources/self_supervised_learning.py +98 -0
- plato/datasources/stl10.py +103 -0
- plato/datasources/texas.py +94 -0
- plato/datasources/tiny_imagenet.py +64 -0
- plato/datasources/yolov8.py +85 -0
- plato/models/__init__.py +0 -0
- plato/models/cnn_encoder.py +103 -0
- plato/models/dcgan.py +116 -0
- plato/models/general_multilayer.py +254 -0
- plato/models/huggingface.py +27 -0
- plato/models/lenet5.py +113 -0
- plato/models/multilayer.py +90 -0
- plato/models/multimodal/__init__.py +0 -0
- plato/models/multimodal/base_net.py +91 -0
- plato/models/multimodal/blending.py +142 -0
- plato/models/multimodal/fc_net.py +77 -0
- plato/models/multimodal/fusion_net.py +78 -0
- plato/models/multimodal/multimodal_module.py +152 -0
- plato/models/registry.py +99 -0
- plato/models/resnet.py +190 -0
- plato/models/torch_hub.py +19 -0
- plato/models/vgg.py +113 -0
- plato/models/vit.py +166 -0
- plato/models/yolov8.py +22 -0
- plato/processors/__init__.py +0 -0
- plato/processors/base.py +35 -0
- plato/processors/compress.py +46 -0
- plato/processors/decompress.py +48 -0
- plato/processors/feature.py +51 -0
- plato/processors/feature_additive_noise.py +48 -0
- plato/processors/feature_dequantize.py +34 -0
- plato/processors/feature_gaussian.py +17 -0
- plato/processors/feature_laplace.py +15 -0
- plato/processors/feature_quantize.py +34 -0
- plato/processors/feature_randomized_response.py +50 -0
- plato/processors/feature_unbatch.py +39 -0
- plato/processors/inbound_feature_tensors.py +39 -0
- plato/processors/model.py +55 -0
- plato/processors/model_compress.py +34 -0
- plato/processors/model_decompress.py +37 -0
- plato/processors/model_decrypt.py +41 -0
- plato/processors/model_deepcopy.py +21 -0
- plato/processors/model_dequantize.py +18 -0
- plato/processors/model_dequantize_qsgd.py +61 -0
- plato/processors/model_encrypt.py +43 -0
- plato/processors/model_quantize.py +18 -0
- plato/processors/model_quantize_qsgd.py +82 -0
- plato/processors/model_randomized_response.py +34 -0
- plato/processors/outbound_feature_ndarrays.py +38 -0
- plato/processors/pipeline.py +26 -0
- plato/processors/registry.py +124 -0
- plato/processors/structured_pruning.py +57 -0
- plato/processors/unstructured_pruning.py +73 -0
- plato/samplers/__init__.py +0 -0
- plato/samplers/all_inclusive.py +41 -0
- plato/samplers/base.py +31 -0
- plato/samplers/dirichlet.py +81 -0
- plato/samplers/distribution_noniid.py +132 -0
- plato/samplers/iid.py +53 -0
- plato/samplers/label_quantity_noniid.py +119 -0
- plato/samplers/mixed.py +44 -0
- plato/samplers/mixed_label_quantity_noniid.py +128 -0
- plato/samplers/modality_iid.py +42 -0
- plato/samplers/modality_quantity_noniid.py +56 -0
- plato/samplers/orthogonal.py +99 -0
- plato/samplers/registry.py +66 -0
- plato/samplers/sample_quantity_noniid.py +123 -0
- plato/samplers/sampler_utils.py +190 -0
- plato/servers/__init__.py +0 -0
- plato/servers/base.py +1395 -0
- plato/servers/fedavg.py +281 -0
- plato/servers/fedavg_cs.py +335 -0
- plato/servers/fedavg_gan.py +74 -0
- plato/servers/fedavg_he.py +106 -0
- plato/servers/fedavg_personalized.py +57 -0
- plato/servers/mistnet.py +67 -0
- plato/servers/registry.py +52 -0
- plato/servers/split_learning.py +109 -0
- plato/trainers/__init__.py +0 -0
- plato/trainers/base.py +99 -0
- plato/trainers/basic.py +649 -0
- plato/trainers/diff_privacy.py +178 -0
- plato/trainers/gan.py +330 -0
- plato/trainers/huggingface.py +173 -0
- plato/trainers/loss_criterion.py +70 -0
- plato/trainers/lr_schedulers.py +252 -0
- plato/trainers/optimizers.py +53 -0
- plato/trainers/pascal_voc.py +80 -0
- plato/trainers/registry.py +44 -0
- plato/trainers/self_supervised_learning.py +302 -0
- plato/trainers/split_learning.py +305 -0
- plato/trainers/tracking.py +96 -0
- plato/trainers/yolov8.py +41 -0
- plato/utils/__init__.py +0 -0
- plato/utils/count_parameters.py +30 -0
- plato/utils/csv_processor.py +26 -0
- plato/utils/data_loaders.py +148 -0
- plato/utils/decorators.py +24 -0
- plato/utils/fonts.py +23 -0
- plato/utils/homo_enc.py +187 -0
- plato/utils/reinforcement_learning/__init__.py +0 -0
- plato/utils/reinforcement_learning/policies/__init__.py +0 -0
- plato/utils/reinforcement_learning/policies/base.py +161 -0
- plato/utils/reinforcement_learning/policies/ddpg.py +75 -0
- plato/utils/reinforcement_learning/policies/registry.py +32 -0
- plato/utils/reinforcement_learning/policies/sac.py +343 -0
- plato/utils/reinforcement_learning/policies/td3.py +485 -0
- plato/utils/reinforcement_learning/rl_agent.py +142 -0
- plato/utils/reinforcement_learning/rl_server.py +113 -0
- plato/utils/rl_env.py +154 -0
- plato/utils/s3.py +141 -0
- plato/utils/trainer_utils.py +21 -0
- plato/utils/unary_encoding.py +47 -0
- plato_learn-1.1.dist-info/METADATA +35 -0
- plato_learn-1.1.dist-info/RECORD +179 -0
- plato_learn-1.1.dist-info/WHEEL +4 -0
- plato_learn-1.1.dist-info/licenses/LICENSE +201 -0
plato/datasources/referitgame.py
@@ -0,0 +1,330 @@
+"""
+Although the name of this dataset is referitgame, it actually contains four datasets:
+- ReferItGame http://tamaraberg.com/referitgame/.
+Then, the refer-based datasets http://vision2.cs.unc.edu/refer/:
+- RefCOCO
+- RefCOCO+
+- RefCOCOg
+
+'split_config' needs to be set to select one of the following datasets:
+- referitgame: 130,525 referring expressions for 96,654 objects in 19,894 images.
+  The samples are split into three subsets: train (54,127 referring expressions),
+  test (5,842), and val (60,103).
+- refcoco: 142,209 referring expressions for 50,000 objects.
+- refcoco+: 141,564 expressions for 49,856 objects.
+- refcocog (google): 25,799 images with 49,856 referred objects and expressions.
+
+The output sample structure of this data is consistent with that
+of the flickr30k entities dataset.
+"""
+
+import logging
+
+import collections
+
+import torch
+import cv2
+
+from plato.config import Config
+from plato.datasources import multimodal_base
+from plato.datasources.multimodal_base import TextData, BoxData, TargetData
+from plato.datasources.datalib.refer_utils import referitgame_utils
+
+SplitedDatasets = collections.namedtuple(
+    "SplitedDatasets",
+    [
+        "train_ref_ids",
+        "val_ref_ids",
+        "test_ref_ids",
+        "testA_ref_ids",
+        "testB_ref_ids",
+        "testC_ref_ids",
+    ],
+)
+
+
+def collate_fn(batch):
+    """Constructs the loaded batch of data.
+
+    Args:
+        batch (list): a list in which each element contains the data for one task;
+            len(batch) equals the number of tasks, and len(batch[i]) == 6, i.e., the
+            output of the create_task_examples_data function.
+
+    Returns:
+        batch: the original batch of data, returned directly.
+    """
+    return batch
+
+
+class ReferItGameDataset(multimodal_base.MultiModalDataset):
+    """Prepares the ReferItGame dataset."""
+
+    def __init__(
+        self,
+        dataset_info,
+        phase,
+        phase_info,
+        modality_sampler=None,
+        transform_image_dec_func=None,
+        transform_text_func=None,
+    ):
+        super().__init__()
+
+        self.phase = phase
+        self.phase_multimodal_data_record = dataset_info
+        self.phase_info = phase_info
+        self.transform_image_dec_func = transform_image_dec_func
+        self.transform_text_func = transform_text_func
+
+        # The phase data record in referitgame is a list;
+        # each item contains the information of one image as
+        # presented in line-258.
+        self.phase_samples_name = self.phase_multimodal_data_record
+
+        self.supported_modalities = ["rgb", "text"]
+
+        # By default, utilize the full set of modalities
+        if modality_sampler is None:
+            self.modality_sampler = self.supported_modalities
+        else:
+            self.modality_sampler = modality_sampler
+
+    def __len__(self):
+        return len(self.phase_multimodal_data_record)
+
+    def get_one_multimodal_sample(self, sample_idx):
+        [
+            image_id,
+            _,
+            caption,
+            caption_phrases,
+            caption_phrase_bboxs,
+            caption_phrases_cate,
+            caption_phrases_cate_id,
+        ] = self.phase_multimodal_data_record[sample_idx]
+
+        image_data = self.phase_info.loadImgsData(image_id)[0]
+
+        image_data = cv2.cvtColor(image_data, cv2.COLOR_BGR2RGB)
+
+        caption = (
+            caption
+            if any(isinstance(boxes_i, list) for boxes_i in caption)
+            else [caption]
+        )
+        caption_phrase_bboxs = (
+            caption_phrase_bboxs
+            if any(isinstance(boxes_i, list) for boxes_i in caption_phrase_bboxs)
+            else [caption_phrase_bboxs]
+        )
+        caption_phrases = (
+            caption_phrases
+            if any(isinstance(boxes_i, list) for boxes_i in caption_phrases)
+            else [caption_phrases]
+        )
+        caption_phrases_cate = (
+            caption_phrases_cate
+            if any(isinstance(boxes_i, list) for boxes_i in caption_phrases_cate)
+            else [[caption_phrases_cate]]
+        )
+        caption_phrases_cate_id = (
+            caption_phrases_cate_id
+            if isinstance(caption_phrases_cate_id, list)
+            else [caption_phrases_cate_id]
+        )
+
+        assert len(caption_phrase_bboxs) == len(caption_phrases)
+        if self.transform_image_dec_func is not None:
+            transformed = self.transform_image_dec_func(
+                image=image_data,
+                bboxes=caption_phrase_bboxs,
+                category_ids=caption_phrases_cate_id,
+            )
+
+            image_data = transformed["image"]
+            image_data = torch.from_numpy(image_data)
+            caption_phrase_bboxs = transformed["bboxes"]
+
+        if self.transform_text_func is not None:
+            caption_phrases = self.transform_text_func(caption_phrases)
+
+        caption_phrase_bboxs = [
+            caption_phrase_bboxs
+        ]  # convert to the standard structure
+
+        text_data = TextData(caption=caption, caption_phrases=caption_phrases)
+        box_data = BoxData(caption_phrase_bboxs=caption_phrase_bboxs)
+        target_data = TargetData(
+            caption_phrases_cate=caption_phrases_cate,
+            caption_phrases_cate_id=caption_phrases_cate_id,
+        )
+
+        return {
+            "rgb": image_data,
+            "text": text_data,
+            "box": box_data,
+            "target": target_data,
+        }
+
+
+class DataSource(multimodal_base.MultiModalDataSource):
+    """The ReferItGame dataset."""
+
+    def __init__(self, **kwargs):
+        super().__init__()
+
+        self.split_configs = ["refcoco", "refcoco+", "refcocog"]
+        self.modality_names = ["image", "text"]
+
+        self.data_name = Config().data.dataname
+        self.base_coco = Config().data.base_coco_images_path
+        self.data_source = "COCO2017"
+
+        # Obtain which split to use:
+        # refclef, refcoco, refcoco+ and refcocog
+        self.split_config = Config().data.split_config
+        # Obtain which specific setting to use:
+        # unc, google
+        self.split_name = Config().data.split_name
+        if self.split_config not in self.split_configs:
+            logging.info(
+                "%s does not exist in the official configurations %s.",
+                self.split_config,
+                self.split_configs,
+            )
+
+        _path = Config().params["data_path"]
+        self._data_path_process(data_path=_path, base_data_name=self.data_name)
+        base_data_path = self.mm_data_info["data_path"]
+
+        # The path to the raw COCO images
+        coco_raw_imgs_path = self.base_coco
+        if self._exists(coco_raw_imgs_path):
+            logging.info(
+                "Successfully connected to the source COCO2017 images data at the path %s",
+                coco_raw_imgs_path,
+            )
+        else:
+            logging.info(
+                "Failed to connect to the source COCO2017 images data at the path %s",
+                coco_raw_imgs_path,
+            )
+
+        # Download the public official code and the required config
+        download_split_url = (
+            Config().data.download_splits_base_url + self.split_config + ".zip"
+        )
+        for dd_url in [download_split_url]:
+            self._download_arrange_data(
+                download_url_address=dd_url, data_path=base_data_path
+            )
+
+        self._dataset_refer = referitgame_utils.REFER(
+            data_root=base_data_path,
+            image_dataroot=coco_raw_imgs_path,
+            dataset=self.split_config,
+            splitBy=self.split_name,
+        )  # the default is unc or google
+
+        self._splited_referids_holder = {}
+        self._connect_to_splits()
+
+    def _connect_to_splits(self):
+        split_types = SplitedDatasets._fields
+        for split_type in split_types:
+            formatted_split_type = split_type.split("_", maxsplit=1)[0]
+            self._splited_referids_holder[formatted_split_type] = (
+                self._dataset_refer.getRefIds(split=formatted_split_type)
+            )
+
+    def get_phase_data(self, phase):
+        """Gets the data for the given phase from the raw data."""
+        mode_refer_ids = self._splited_referids_holder[phase]
+
+        mode_elements_holder = {}
+        mode_flatten_elements = []
+
+        for refer_id in mode_refer_ids:
+            ref = self._dataset_refer.loadRefs(refer_id)[0]
+            image_id = ref["image_id"]
+            image_file_path = self._dataset_refer.loadImgspath(image_id)
+            caption_phrases_cate = self._dataset_refer.Cats[ref["category_id"]]
+            caption_phrases_cate_id = ref["category_id"]
+
+            mode_elements_holder[refer_id] = {}
+            mode_elements_holder[refer_id]["image_id"] = image_id
+            mode_elements_holder[refer_id]["image_file_path"] = image_file_path
+
+            mode_elements_holder[refer_id]["sentences"] = []
+            for send in ref["sentences"]:
+                caption = send["tokens"]
+                caption_phrase = send["tokens"]
+
+                caption_phrase_bboxs = self._dataset_refer.getRefBox(
+                    ref["ref_id"]
+                )  # [x, y, w, h]
+                # convert to [xmin, ymin, xmax, ymax]
+                caption_phrase_bboxs = [
+                    caption_phrase_bboxs[0],
+                    caption_phrase_bboxs[1],
+                    caption_phrase_bboxs[0] + caption_phrase_bboxs[2],
+                    caption_phrase_bboxs[1] + caption_phrase_bboxs[3],
+                ]
+
+                sent_infos = {
+                    "caption": caption,
+                    "caption_phrase": caption_phrase,
+                    "caption_phrase_bboxs": caption_phrase_bboxs,
+                    "caption_phrases_cate": caption_phrases_cate,
+                    "caption_phrases_cate_id": caption_phrases_cate_id,
+                }
+
+                mode_elements_holder[refer_id]["sentences"].append(sent_infos)
+
+                mode_flatten_elements.append(
+                    [
+                        image_id,
+                        image_file_path,
+                        caption,
+                        caption_phrase,
+                        caption_phrase_bboxs,
+                        caption_phrases_cate,
+                        caption_phrases_cate_id,
+                    ]
+                )
+
+        return mode_elements_holder, mode_flatten_elements
+
+    def get_phase_dataset(self, phase, modality_sampler):
+        """Obtains the dataset for the specific phase."""
+        _, mode_flatten_elements = self.get_phase_data(phase)
+
+        dataset = ReferItGameDataset(
+            dataset_info=mode_flatten_elements,
+            phase_info=self._dataset_refer,
+            phase=phase,
+            modality_sampler=modality_sampler,
+        )
+        return dataset
+
+    def get_train_set(self, modality_sampler=None):
+        """Obtains the training dataset."""
+        phase = "train"
+
+        self.trainset = self.get_phase_dataset(phase, modality_sampler)
+        return self.trainset
+
+    def get_test_set(self, modality_sampler=None):
+        """Obtains the test dataset."""
+        phase = "test"
+
+        self.testset = self.get_phase_dataset(phase, modality_sampler)
+        return self.testset
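
A minimal usage sketch of the datasource above, assuming a Plato config has already been loaded with the 'data' settings this class reads (dataname, base_coco_images_path, split_config, split_name, and download_splits_base_url); the sample index and printed keys simply mirror what get_one_multimodal_sample returns:

    from plato.datasources import referitgame

    # Builds the refer-based dataset selected by Config().data.split_config
    datasource = referitgame.DataSource()

    # Each phase dataset returns one sample as a dict of modalities
    trainset = datasource.get_train_set()
    sample = trainset.get_one_multimodal_sample(0)
    print(sample.keys())  # dict_keys(['rgb', 'text', 'box', 'target'])
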
plato/datasources/registry.py
@@ -0,0 +1,119 @@
+"""
+Having a registry of all available classes is convenient for retrieving an instance
+based on a configuration at run-time.
+"""
+
+import logging
+
+from plato.config import Config
+
+from plato.datasources import (
+    mnist,
+    fashion_mnist,
+    emnist,
+    cifar10,
+    cifar100,
+    cinic10,
+    purchase,
+    texas,
+    huggingface,
+    pascal_voc,
+    tiny_imagenet,
+    femnist,
+    feature,
+    qoenflx,
+    celeba,
+    stl10,
+)
+
+registered_datasources = {
+    "MNIST": mnist,
+    "FashionMNIST": fashion_mnist,
+    "EMNIST": emnist,
+    "CIFAR10": cifar10,
+    "CIFAR100": cifar100,
+    "CINIC10": cinic10,
+    "Purchase": purchase,
+    "Texas": texas,
+    "HuggingFace": huggingface,
+    "PASCAL_VOC": pascal_voc,
+    "TinyImageNet": tiny_imagenet,
+    "Feature": feature,
+    "QoENFLX": qoenflx,
+    "CelebA": celeba,
+    "STL10": stl10,
+}
+
+registered_partitioned_datasources = {"FEMNIST": femnist}
+
+
+def get(client_id: int = 0, **kwargs):
+    """Get the data source with the provided name."""
+    datasource_name = (
+        kwargs["datasource_name"]
+        if "datasource_name" in kwargs
+        else Config().data.datasource
+    )
+
+    logging.info("Data source: %s", datasource_name)
+
+    if datasource_name == "kinetics700":
+        from plato.datasources import kinetics
+
+        return kinetics.DataSource(**kwargs)
+
+    if datasource_name == "Gym":
+        from plato.datasources import gym
+
+        return gym.DataSource(**kwargs)
+
+    if datasource_name == "Flickr30KE":
+        from plato.datasources import flickr30k_entities
+
+        return flickr30k_entities.DataSource(**kwargs)
+
+    if datasource_name == "ReferItGame":
+        from plato.datasources import referitgame
+
+        return referitgame.DataSource(**kwargs)
+
+    if datasource_name == "COCO":
+        from plato.datasources import coco
+
+        return coco.DataSource(**kwargs)
+
+    if datasource_name == "YOLOv8":
+        from plato.datasources import yolov8
+
+        return yolov8.DataSource(**kwargs)
+
+    if datasource_name in registered_datasources:
+        dataset = registered_datasources[datasource_name].DataSource(**kwargs)
+    elif datasource_name in registered_partitioned_datasources:
+        dataset = registered_partitioned_datasources[datasource_name].DataSource(
+            client_id, **kwargs
+        )
+    else:
+        raise ValueError(f"No such data source: {datasource_name}")
+
+    return dataset
+
+
+def get_input_shape():
+    """Get the input shape of the data source with the provided name."""
+    datasource_name = Config().data.datasource
+
+    logging.info("Data source: %s", datasource_name)
+    if datasource_name == "YOLO":
+        from plato.datasources import yolo
+
+        return yolo.DataSource.input_shape()
+    if datasource_name in registered_datasources:
+        input_shape = registered_datasources[datasource_name].DataSource.input_shape()
+    elif datasource_name in registered_partitioned_datasources:
+        input_shape = registered_partitioned_datasources[
+            datasource_name
+        ].DataSource.input_shape()
+    else:
+        raise ValueError(f"No such data source: {datasource_name}")
+
+    return input_shape
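
A minimal sketch of how this registry resolves a datasource, assuming a loaded config; the explicit 'datasource_name' keyword overrides Config().data.datasource:

    from plato.datasources import registry as datasources_registry

    # Resolved from Config().data.datasource
    datasource = datasources_registry.get()

    # Resolved from the keyword argument instead
    mnist_datasource = datasources_registry.get(datasource_name="MNIST")
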
plato/datasources/self_supervised_learning.py
@@ -0,0 +1,98 @@
+"""
+A self-supervised learning dataset working as a wrapper to add the SSL data
+transform to the datasource of Plato.
+
+To allow the SSL transform to use the desired parameters, one should place the
+'data_transforms' sub-block under the 'algorithm' block in the config file.
+"""
+
+from lightly import transforms
+
+from plato.datasources import base
+from plato.datasources import registry as datasources_registry
+from plato.config import Config
+
+
+# The normalizations for different datasets
+MNIST_NORMALIZE = {"mean": [0.1307], "std": [0.3081]}
+FashionMNIST_NORMALIZE = {"mean": [0.1307], "std": [0.3081]}
+CIFAR10_NORMALIZE = {"mean": [0.491, 0.482, 0.447], "std": [0.247, 0.243, 0.262]}
+CIFAR100_NORMALIZE = {"mean": [0.491, 0.482, 0.447], "std": [0.247, 0.243, 0.262]}
+IMAGENET_NORMALIZE = {"mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]}
+STL10_NORMALIZE = {"mean": [0.4914, 0.4823, 0.4466], "std": [0.247, 0.243, 0.261]}
+
+dataset_normalizations = {
+    "MNIST": MNIST_NORMALIZE,
+    "FashionMNIST": FashionMNIST_NORMALIZE,
+    "CIFAR10": CIFAR10_NORMALIZE,
+    "CIFAR100": CIFAR100_NORMALIZE,
+    "IMAGENET": IMAGENET_NORMALIZE,
+    "STL10": STL10_NORMALIZE,
+}
+
+
+# All transforms for different SSL algorithms
+registered_transforms = {
+    "SimCLR": transforms.SimCLRTransform,
+    "DINO": transforms.DINOTransform,
+    "MAE": transforms.MAETransform,
+    "MoCoV1": transforms.MoCoV1Transform,
+    "MoCoV2": transforms.MoCoV2Transform,
+    "MSN": transforms.MSNTransform,
+    "PIRL": transforms.PIRLTransform,
+    "SimSiam": transforms.SimSiamTransform,
+    "SMoG": transforms.SMoGTransform,
+    "SwaV": transforms.SwaVTransform,
+    "VICReg": transforms.VICRegTransform,
+    "VICRegL": transforms.VICRegLTransform,
+    "FastSiam": transforms.FastSiamTransform,
+}
+
+
+def get_transforms():
+    """Obtains the train/test transforms for the corresponding data."""
+
+    # Get the transform details set in the config file
+    transforms_config = Config().algorithm.data_transforms._asdict()
+
+    # Set the data transform, which will be used as parameters to define the
+    # SSL transform in registered_transforms
+    data_transforms = {}
+    if "train_transform" in transforms_config:
+        transform_config = transforms_config["train_transform"]._asdict()
+        transform_name = transform_config["name"]
+        transform_params = transform_config["parameters"]._asdict()
+
+        # Get the data normalization for the datasource
+        datasource_name = Config().data.datasource
+        transform_params["normalize"] = dataset_normalizations[datasource_name]
+        # Get the SSL transform
+        if transform_name in registered_transforms:
+            dataset_transform = registered_transforms[transform_name](
+                **transform_params
+            )
+        else:
+            raise ValueError(f"No such SSL transform: {transform_name}")
+
+        # Insert the obtained transform into data_transforms.
+        # It is used by the datasource of Plato to get the train/test set.
+        data_transforms.update({"train_transform": dataset_transform})
+
+    return data_transforms
+
+
+# pylint: disable=abstract-method
+class SSLDataSource(base.DataSource):
+    """
+    An SSL datasource to define the datasource for self-supervised learning.
+    """
+
+    def __init__(self):
+        super().__init__()
+
+        # Get the transforms for the data
+        data_transforms = get_transforms()
+
+        self.datasource = datasources_registry.get(**data_transforms)
+        self.trainset = self.datasource.trainset
+        self.testset = self.datasource.testset
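
For reference, a sketch of the lookup get_transforms() performs once the config has been read; "SimCLR" is a key of registered_transforms, and the parameter values (input_size, cj_prob) are illustrative, not defaults taken from this package:

    from lightly import transforms

    # Mirrors the 'parameters' sub-block of the train_transform config,
    # plus the normalization injected for the current datasource
    params = {"input_size": 32, "cj_prob": 0.8}
    params["normalize"] = {"mean": [0.491, 0.482, 0.447], "std": [0.247, 0.243, 0.262]}
    train_transform = transforms.SimCLRTransform(**params)
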
plato/datasources/stl10.py
@@ -0,0 +1,103 @@
+"""
+The STL-10 dataset from the torchvision package.
+The details of this dataset can be found on the websites:
+https://cs.stanford.edu/~acoates/stl10/
+and
+https://www.kaggle.com/datasets/jessicali9530/stl10.
+"""
+
+from torch.utils.data.dataset import Dataset
+from torchvision import datasets, transforms
+
+from plato.config import Config
+from plato.datasources import base
+
+
+class STL10Dataset(Dataset):
+    """Prepares the STL-10 dataset for subsequent usage.
+
+    The class annotation of the STL-10 dataset is stored as
+    'labels', whereas subsequent learning stages of Plato
+    expect 'targets'.
+    """
+
+    def __init__(self, dataset):
+        self.dataset = dataset
+
+        # Expose the raw data for subsequent usage,
+        # such as self-supervised learning
+        self.data = self.dataset.data
+        self.targets = self.dataset.labels
+        self.target_transform = self.dataset.target_transform
+        self.classes = self.dataset.classes
+
+    def __getitem__(self, index):
+        return self.dataset[index]
+
+    def __len__(self):
+        return len(self.dataset)
+
+
+class DataSource(base.DataSource):
+    """The STL-10 dataset."""
+
+    def __init__(self, **kwargs):
+        super().__init__()
+        _path = Config().params["data_path"]
+
+        normalize = transforms.Normalize(
+            mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
+            std=[x / 255.0 for x in [63.0, 62.1, 66.7]],
+        )
+        train_transform = (
+            kwargs["train_transform"]
+            if "train_transform" in kwargs
+            else (
+                transforms.Compose(
+                    [
+                        transforms.RandomCrop(96, padding=4),
+                        transforms.RandomHorizontalFlip(),
+                        transforms.ToTensor(),
+                        normalize,
+                    ]
+                )
+            )
+        )
+
+        test_transform = (
+            kwargs["test_transform"]
+            if "test_transform" in kwargs
+            else (
+                transforms.Compose(
+                    [
+                        transforms.ToTensor(),
+                        normalize,
+                    ]
+                )
+            )
+        )
+
+        stl10_trainset = datasets.STL10(
+            root=_path, split="train", download=True, transform=train_transform
+        )
+        stl10_unlabeled_set = datasets.STL10(
+            root=_path, split="unlabeled", download=True, transform=train_transform
+        )
+
+        stl10_testset = datasets.STL10(
+            root=_path, split="test", download=True, transform=test_transform
+        )
+
+        self.trainset = STL10Dataset(stl10_trainset)
+        self.unlabeledset = STL10Dataset(stl10_unlabeled_set)
+        self.testset = STL10Dataset(stl10_testset)
+
+    def num_train_examples(self):
+        return len(self.trainset)
+
+    def num_test_examples(self):
+        return len(self.testset)
+
+    def get_unlabeled_set(self):
+        """Obtains the unlabeled dataset."""
+        return self.unlabeledset
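
A minimal sketch of using the unlabeled split, assuming Config() has been initialized so that Config().params["data_path"] points at a writable directory (torchvision downloads STL-10 there on first use):

    from plato.datasources import stl10

    datasource = stl10.DataSource()
    print(datasource.num_train_examples())  # 5,000 labeled training images

    # The 100,000-image unlabeled split, wrapped so that .targets is available
    unlabeled_set = datasource.get_unlabeled_set()
    print(len(unlabeled_set))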