plato-learn 1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (179) hide show
  1. plato/__init__.py +1 -0
  2. plato/algorithms/__init__.py +0 -0
  3. plato/algorithms/base.py +45 -0
  4. plato/algorithms/fedavg.py +48 -0
  5. plato/algorithms/fedavg_gan.py +79 -0
  6. plato/algorithms/fedavg_personalized.py +48 -0
  7. plato/algorithms/mistnet.py +52 -0
  8. plato/algorithms/registry.py +39 -0
  9. plato/algorithms/split_learning.py +89 -0
  10. plato/callbacks/__init__.py +0 -0
  11. plato/callbacks/client.py +56 -0
  12. plato/callbacks/handler.py +78 -0
  13. plato/callbacks/server.py +139 -0
  14. plato/callbacks/trainer.py +124 -0
  15. plato/client.py +67 -0
  16. plato/clients/__init__.py +0 -0
  17. plato/clients/base.py +467 -0
  18. plato/clients/edge.py +103 -0
  19. plato/clients/fedavg_personalized.py +40 -0
  20. plato/clients/mistnet.py +49 -0
  21. plato/clients/registry.py +43 -0
  22. plato/clients/self_supervised_learning.py +51 -0
  23. plato/clients/simple.py +218 -0
  24. plato/clients/split_learning.py +150 -0
  25. plato/config.py +339 -0
  26. plato/datasources/__init__.py +0 -0
  27. plato/datasources/base.py +123 -0
  28. plato/datasources/celeba.py +150 -0
  29. plato/datasources/cifar10.py +87 -0
  30. plato/datasources/cifar100.py +61 -0
  31. plato/datasources/cinic10.py +62 -0
  32. plato/datasources/coco.py +119 -0
  33. plato/datasources/datalib/__init__.py +0 -0
  34. plato/datasources/datalib/audio_extraction_tools.py +137 -0
  35. plato/datasources/datalib/data_utils.py +124 -0
  36. plato/datasources/datalib/flickr30kE_utils.py +336 -0
  37. plato/datasources/datalib/frames_extraction_tools.py +254 -0
  38. plato/datasources/datalib/gym_utils/__init__.py +0 -0
  39. plato/datasources/datalib/gym_utils/gym_trim.py +189 -0
  40. plato/datasources/datalib/modality_data_anntation_tools.py +163 -0
  41. plato/datasources/datalib/modality_extraction_base.py +59 -0
  42. plato/datasources/datalib/parse_datasets.py +212 -0
  43. plato/datasources/datalib/refer_utils/__init__.py +0 -0
  44. plato/datasources/datalib/refer_utils/referitgame_utils.py +237 -0
  45. plato/datasources/datalib/tiny_data_tools.py +81 -0
  46. plato/datasources/datalib/video_transform.py +79 -0
  47. plato/datasources/emnist.py +64 -0
  48. plato/datasources/fashion_mnist.py +41 -0
  49. plato/datasources/feature.py +24 -0
  50. plato/datasources/feature_dataset.py +15 -0
  51. plato/datasources/femnist.py +141 -0
  52. plato/datasources/flickr30k_entities.py +362 -0
  53. plato/datasources/gym.py +431 -0
  54. plato/datasources/huggingface.py +165 -0
  55. plato/datasources/kinetics.py +568 -0
  56. plato/datasources/mnist.py +44 -0
  57. plato/datasources/multimodal_base.py +328 -0
  58. plato/datasources/pascal_voc.py +56 -0
  59. plato/datasources/purchase.py +94 -0
  60. plato/datasources/qoenflx.py +127 -0
  61. plato/datasources/referitgame.py +330 -0
  62. plato/datasources/registry.py +119 -0
  63. plato/datasources/self_supervised_learning.py +98 -0
  64. plato/datasources/stl10.py +103 -0
  65. plato/datasources/texas.py +94 -0
  66. plato/datasources/tiny_imagenet.py +64 -0
  67. plato/datasources/yolov8.py +85 -0
  68. plato/models/__init__.py +0 -0
  69. plato/models/cnn_encoder.py +103 -0
  70. plato/models/dcgan.py +116 -0
  71. plato/models/general_multilayer.py +254 -0
  72. plato/models/huggingface.py +27 -0
  73. plato/models/lenet5.py +113 -0
  74. plato/models/multilayer.py +90 -0
  75. plato/models/multimodal/__init__.py +0 -0
  76. plato/models/multimodal/base_net.py +91 -0
  77. plato/models/multimodal/blending.py +142 -0
  78. plato/models/multimodal/fc_net.py +77 -0
  79. plato/models/multimodal/fusion_net.py +78 -0
  80. plato/models/multimodal/multimodal_module.py +152 -0
  81. plato/models/registry.py +99 -0
  82. plato/models/resnet.py +190 -0
  83. plato/models/torch_hub.py +19 -0
  84. plato/models/vgg.py +113 -0
  85. plato/models/vit.py +166 -0
  86. plato/models/yolov8.py +22 -0
  87. plato/processors/__init__.py +0 -0
  88. plato/processors/base.py +35 -0
  89. plato/processors/compress.py +46 -0
  90. plato/processors/decompress.py +48 -0
  91. plato/processors/feature.py +51 -0
  92. plato/processors/feature_additive_noise.py +48 -0
  93. plato/processors/feature_dequantize.py +34 -0
  94. plato/processors/feature_gaussian.py +17 -0
  95. plato/processors/feature_laplace.py +15 -0
  96. plato/processors/feature_quantize.py +34 -0
  97. plato/processors/feature_randomized_response.py +50 -0
  98. plato/processors/feature_unbatch.py +39 -0
  99. plato/processors/inbound_feature_tensors.py +39 -0
  100. plato/processors/model.py +55 -0
  101. plato/processors/model_compress.py +34 -0
  102. plato/processors/model_decompress.py +37 -0
  103. plato/processors/model_decrypt.py +41 -0
  104. plato/processors/model_deepcopy.py +21 -0
  105. plato/processors/model_dequantize.py +18 -0
  106. plato/processors/model_dequantize_qsgd.py +61 -0
  107. plato/processors/model_encrypt.py +43 -0
  108. plato/processors/model_quantize.py +18 -0
  109. plato/processors/model_quantize_qsgd.py +82 -0
  110. plato/processors/model_randomized_response.py +34 -0
  111. plato/processors/outbound_feature_ndarrays.py +38 -0
  112. plato/processors/pipeline.py +26 -0
  113. plato/processors/registry.py +124 -0
  114. plato/processors/structured_pruning.py +57 -0
  115. plato/processors/unstructured_pruning.py +73 -0
  116. plato/samplers/__init__.py +0 -0
  117. plato/samplers/all_inclusive.py +41 -0
  118. plato/samplers/base.py +31 -0
  119. plato/samplers/dirichlet.py +81 -0
  120. plato/samplers/distribution_noniid.py +132 -0
  121. plato/samplers/iid.py +53 -0
  122. plato/samplers/label_quantity_noniid.py +119 -0
  123. plato/samplers/mixed.py +44 -0
  124. plato/samplers/mixed_label_quantity_noniid.py +128 -0
  125. plato/samplers/modality_iid.py +42 -0
  126. plato/samplers/modality_quantity_noniid.py +56 -0
  127. plato/samplers/orthogonal.py +99 -0
  128. plato/samplers/registry.py +66 -0
  129. plato/samplers/sample_quantity_noniid.py +123 -0
  130. plato/samplers/sampler_utils.py +190 -0
  131. plato/servers/__init__.py +0 -0
  132. plato/servers/base.py +1395 -0
  133. plato/servers/fedavg.py +281 -0
  134. plato/servers/fedavg_cs.py +335 -0
  135. plato/servers/fedavg_gan.py +74 -0
  136. plato/servers/fedavg_he.py +106 -0
  137. plato/servers/fedavg_personalized.py +57 -0
  138. plato/servers/mistnet.py +67 -0
  139. plato/servers/registry.py +52 -0
  140. plato/servers/split_learning.py +109 -0
  141. plato/trainers/__init__.py +0 -0
  142. plato/trainers/base.py +99 -0
  143. plato/trainers/basic.py +649 -0
  144. plato/trainers/diff_privacy.py +178 -0
  145. plato/trainers/gan.py +330 -0
  146. plato/trainers/huggingface.py +173 -0
  147. plato/trainers/loss_criterion.py +70 -0
  148. plato/trainers/lr_schedulers.py +252 -0
  149. plato/trainers/optimizers.py +53 -0
  150. plato/trainers/pascal_voc.py +80 -0
  151. plato/trainers/registry.py +44 -0
  152. plato/trainers/self_supervised_learning.py +302 -0
  153. plato/trainers/split_learning.py +305 -0
  154. plato/trainers/tracking.py +96 -0
  155. plato/trainers/yolov8.py +41 -0
  156. plato/utils/__init__.py +0 -0
  157. plato/utils/count_parameters.py +30 -0
  158. plato/utils/csv_processor.py +26 -0
  159. plato/utils/data_loaders.py +148 -0
  160. plato/utils/decorators.py +24 -0
  161. plato/utils/fonts.py +23 -0
  162. plato/utils/homo_enc.py +187 -0
  163. plato/utils/reinforcement_learning/__init__.py +0 -0
  164. plato/utils/reinforcement_learning/policies/__init__.py +0 -0
  165. plato/utils/reinforcement_learning/policies/base.py +161 -0
  166. plato/utils/reinforcement_learning/policies/ddpg.py +75 -0
  167. plato/utils/reinforcement_learning/policies/registry.py +32 -0
  168. plato/utils/reinforcement_learning/policies/sac.py +343 -0
  169. plato/utils/reinforcement_learning/policies/td3.py +485 -0
  170. plato/utils/reinforcement_learning/rl_agent.py +142 -0
  171. plato/utils/reinforcement_learning/rl_server.py +113 -0
  172. plato/utils/rl_env.py +154 -0
  173. plato/utils/s3.py +141 -0
  174. plato/utils/trainer_utils.py +21 -0
  175. plato/utils/unary_encoding.py +47 -0
  176. plato_learn-1.1.dist-info/METADATA +35 -0
  177. plato_learn-1.1.dist-info/RECORD +179 -0
  178. plato_learn-1.1.dist-info/WHEEL +4 -0
  179. plato_learn-1.1.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,330 @@
1
+ """
2
+
3
+ Although this datasource is named referitgame, it actually covers four datasets:
+ - ReferItGame http://tamaraberg.com/referitgame/.
+ and the refer-based datasets http://vision2.cs.unc.edu/refer/:
+ - RefCOCO
+ - RefCOCO+
+ - RefCOCOg
+
+ The 'split_config' option must be set to select one of the following datasets:
+ - referitgame: 130,525 expressions referring to 96,654 objects in 19,894 images.
+   The samples are split into three subsets: train/54,127, val/5,842 and
+   test/60,103 referring expressions.
+ - refcoco: 142,209 referring expressions for 50,000 objects.
+ - refcoco+: 141,564 expressions for 49,856 objects.
+ - refcocog (google): 25,799 images with 49,856 referred objects and expressions.
+
+ The output sample structure of this dataset is consistent with that
+ of the Flickr30K Entities dataset.
20
+
21
+ """
22
+
23
+ import logging
24
+
25
+ import collections
26
+
27
+ import torch
28
+ import cv2
29
+
30
+ from plato.config import Config
31
+ from plato.datasources import multimodal_base
32
+ from plato.datasources.multimodal_base import TextData, BoxData, TargetData
33
+ from plato.datasources.datalib.refer_utils import referitgame_utils
34
+
35
# Reference-id containers for every split of the REFER annotations. The
# prefix of each field name (before the first "_") is the split name that
# ``DataSource._connect_to_splits`` queries.
_SPLIT_PREFIXES = ("train", "val", "test", "testA", "testB", "testC")
SplitedDatasets = collections.namedtuple(
    "SplitedDatasets", [f"{prefix}_ref_ids" for prefix in _SPLIT_PREFIXES]
)
46
+
47
+
48
def collate_fn(batch):
    """Collate a loaded batch without any stacking or conversion.

    The samples of this dataset mix images, token lists, boxes and category
    information of varying lengths, so they are not merged into tensors here.

    Args:
        batch (list): the list of samples gathered by the data loader.

    Returns:
        list: the very same ``batch`` object, unchanged.
    """
    return batch
61
+
62
+
63
class ReferItGameDataset(multimodal_base.MultiModalDataset):
    """Prepares one phase of the ReferItGame / RefCOCO-family dataset.

    ``dataset_info`` is the flat record list built by
    ``DataSource.get_phase_data``; each record is
    ``[image_id, image_file_path, caption, caption_phrases,
    caption_phrase_bboxs, caption_phrases_cate, caption_phrases_cate_id]``.
    ``phase_info`` is the REFER helper object used to load the image of a
    record via ``loadImgsData``.
    """

    def __init__(
        self,
        dataset_info,
        phase,
        phase_info,
        modality_sampler=None,
        transform_image_dec_func=None,
        transform_text_func=None,
    ):
        super().__init__()

        self.phase = phase
        self.phase_multimodal_data_record = dataset_info
        self.phase_info = phase_info
        self.transform_image_dec_func = transform_image_dec_func
        self.transform_text_func = transform_text_func

        # The phase data record in referitgame is a flat list with one item
        # per referring sentence, so the records double as sample "names".
        self.phase_samples_name = self.phase_multimodal_data_record

        self.supported_modalities = ["rgb", "text"]

        # By default, utilize all supported modalities.
        if modality_sampler is None:
            self.modality_sampler = self.supported_modalities
        else:
            self.modality_sampler = modality_sampler

    def __len__(self):
        return len(self.phase_multimodal_data_record)

    @staticmethod
    def _contains_list(values):
        """Return True if any element of ``values`` is itself a list."""
        return any(isinstance(item, list) for item in values)

    def get_one_multimodal_sample(self, sample_idx):
        """Build the multimodal sample stored at ``sample_idx``.

        Returns:
            dict: with keys
              - "rgb": the image; a tensor built with ``torch.from_numpy`` when
                ``transform_image_dec_func`` is set, otherwise the RGB array
                returned by ``cv2.cvtColor``;
              - "text": ``TextData`` (caption, caption_phrases);
              - "box": ``BoxData`` (caption_phrase_bboxs);
              - "target": ``TargetData`` (category names and ids).
        """
        [
            image_id,
            _,  # image file path; the image is loaded through phase_info
            caption,
            caption_phrases,
            caption_phrase_bboxs,
            caption_phrases_cate,
            caption_phrases_cate_id,
        ] = self.phase_multimodal_data_record[sample_idx]

        image_data = self.phase_info.loadImgsData(image_id)[0]
        # Reorder the channels from BGR to RGB.
        image_data = cv2.cvtColor(image_data, cv2.COLOR_BGR2RGB)

        # Bring every field to the nested (list-of-lists) structure used by
        # the Flickr30K Entities samples; a single record is wrapped once.
        if not self._contains_list(caption):
            caption = [caption]
        if not self._contains_list(caption_phrase_bboxs):
            caption_phrase_bboxs = [caption_phrase_bboxs]
        if not self._contains_list(caption_phrases):
            caption_phrases = [caption_phrases]
        if not self._contains_list(caption_phrases_cate):
            caption_phrases_cate = [[caption_phrases_cate]]
        if not isinstance(caption_phrases_cate_id, list):
            caption_phrases_cate_id = [caption_phrases_cate_id]

        # Each phrase group must have exactly one box group.
        assert len(caption_phrase_bboxs) == len(caption_phrases)
        if self.transform_image_dec_func is not None:
            transformed = self.transform_image_dec_func(
                image=image_data,
                bboxes=caption_phrase_bboxs,
                category_ids=caption_phrases_cate_id,
            )

            image_data = transformed["image"]
            image_data = torch.from_numpy(image_data)
            caption_phrase_bboxs = transformed["bboxes"]

        if self.transform_text_func is not None:
            caption_phrases = self.transform_text_func(caption_phrases)

        # Convert to the standard structure.
        caption_phrase_bboxs = [caption_phrase_bboxs]

        text_data = TextData(caption=caption, caption_phrases=caption_phrases)
        box_data = BoxData(caption_phrase_bboxs=caption_phrase_bboxs)
        target_data = TargetData(
            caption_phrases_cate=caption_phrases_cate,
            caption_phrases_cate_id=caption_phrases_cate_id,
        )

        return {
            "rgb": image_data,
            "text": text_data,
            "box": box_data,
            "target": target_data,
        }
175
+
176
+
177
class DataSource(multimodal_base.MultiModalDataSource):
    """The ReferItGame / RefCOCO-family data source on top of COCO2017 images."""

    def __init__(self, **kwargs):
        super().__init__()

        self.split_configs = ["refcoco", "refcoco+", "refcocog"]
        self.modality_names = ["image", "text"]

        self.data_name = Config().data.dataname
        self.base_coco = Config().data.base_coco_images_path
        self.data_source = "COCO2017"

        # Which annotation set to use, e.g. refcoco, refcoco+ or refcocog
        # (the official configurations checked below).
        self.split_config = Config().data.split_config
        # Which split setting to use, e.g. unc or google.
        self.split_name = Config().data.split_name
        if self.split_config not in self.split_configs:
            logging.warning(
                "%s does not exist in the official configurations %s.",
                self.split_config,
                self.split_configs,
            )

        _path = Config().params["data_path"]
        self._data_path_process(data_path=_path, base_data_name=self.data_name)
        base_data_path = self.mm_data_info["data_path"]

        # Raw COCO2017 images referenced by the annotations.
        coco_raw_imgs_path = self.base_coco
        if self._exists(coco_raw_imgs_path):
            logging.info(
                "Successfully connecting the source COCO2017 images data from the path %s",
                coco_raw_imgs_path,
            )
        else:
            logging.warning(
                "Fail to connect the source COCO2017 images data from the path %s",
                coco_raw_imgs_path,
            )

        # Download the public annotation archive for the chosen split config.
        download_split_url = (
            Config().data.download_splits_base_url + self.split_config + ".zip"
        )
        self._download_arrange_data(
            download_url_address=download_split_url, data_path=base_data_path
        )

        self._dataset_refer = referitgame_utils.REFER(
            data_root=base_data_path,
            image_dataroot=coco_raw_imgs_path,
            dataset=self.split_config,
            splitBy=self.split_name,
        )  # default is unc or google

        self._splited_referids_holder = {}
        self._connect_to_splits()

    def _connect_to_splits(self):
        """Cache the reference ids of every split named in ``SplitedDatasets``.

        A field such as "train_ref_ids" or "testA_ref_ids" is reduced to its
        prefix ("train", "testA", ...), which is used both as the REFER split
        name and as the key in ``self._splited_referids_holder``.
        """
        for split_type in SplitedDatasets._fields:
            split_key = split_type.split("_", maxsplit=1)[0]
            self._splited_referids_holder[split_key] = (
                self._dataset_refer.getRefIds(split=split_key)
            )

    @staticmethod
    def _xywh_to_xyxy(box):
        """Convert an [x, y, w, h] box to [xmin, ymin, xmax, ymax]."""
        return [box[0], box[1], box[0] + box[2], box[1] + box[3]]

    def get_phase_data(self, phase):
        """Collect the referring expressions of one split.

        Args:
            phase (str): a split key cached by ``_connect_to_splits``, e.g.
                "train", "val", "test", "testA", "testB" or "testC".

        Returns:
            tuple: ``(mode_elements_holder, mode_flatten_emelemts)`` where the
            first is a dict keyed by ref id holding "image_id",
            "image_file_path" and a per-sentence "sentences" list, and the
            second is a flat list with one record per sentence:
            ``[image_id, image_file_path, caption, caption_phrase,
            caption_phrase_bboxs, caption_phrases_cate,
            caption_phrases_cate_id]``.
        """
        mode_refer_ids = self._splited_referids_holder[phase]

        mode_elements_holder = {}
        mode_flatten_emelemts = []

        for refer_id in mode_refer_ids:
            ref = self._dataset_refer.loadRefs(refer_id)[0]
            image_id = ref["image_id"]
            image_file_path = self._dataset_refer.loadImgspath(image_id)
            caption_phrases_cate = self._dataset_refer.Cats[ref["category_id"]]
            caption_phrases_cate_id = ref["category_id"]
            # The box only depends on the reference, not on the sentence, so
            # it is looked up and converted once per reference.
            ref_box = self._xywh_to_xyxy(self._dataset_refer.getRefBox(ref["ref_id"]))

            sentences = []
            mode_elements_holder[refer_id] = {
                "image_id": image_id,
                "image_file_path": image_file_path,
                "sentences": sentences,
            }

            for sent in ref["sentences"]:
                caption = sent["tokens"]
                caption_phrase = sent["tokens"]
                # A fresh copy per sentence so records never share one list.
                caption_phrase_bboxs = list(ref_box)

                sentences.append(
                    {
                        "caption": caption,
                        "caption_phrase": caption_phrase,
                        "caption_phrase_bboxs": caption_phrase_bboxs,
                        "caption_phrases_cate": caption_phrases_cate,
                        "caption_phrases_cate_id": caption_phrases_cate_id,
                    }
                )

                mode_flatten_emelemts.append(
                    [
                        image_id,
                        image_file_path,
                        caption,
                        caption_phrase,
                        caption_phrase_bboxs,
                        caption_phrases_cate,
                        caption_phrases_cate_id,
                    ]
                )

        return mode_elements_holder, mode_flatten_emelemts

    def get_phase_dataset(self, phase, modality_sampler):
        """Obtain the ``ReferItGameDataset`` for the specific phase."""
        _, mode_flatten_emelemts = self.get_phase_data(phase)

        dataset = ReferItGameDataset(
            dataset_info=mode_flatten_emelemts,
            phase_info=self._dataset_refer,
            phase=phase,
            modality_sampler=modality_sampler,
        )
        return dataset

    def get_train_set(self, modality_sampler=None):
        """Obtains the training dataset."""
        phase = "train"

        self.trainset = self.get_phase_dataset(phase, modality_sampler)
        return self.trainset

    def get_test_set(self, modality_sampler=None):
        """Obtains the test dataset."""
        phase = "test"

        self.testset = self.get_phase_dataset(phase, modality_sampler)
        return self.testset
@@ -0,0 +1,119 @@
1
+ """
2
+ Having a registry of all available classes is convenient for retrieving an instance
3
+ based on a configuration at run-time.
4
+ """
5
+
6
+ import logging
7
+
8
+ from plato.config import Config
9
+
10
+ from plato.datasources import (
11
+ mnist,
12
+ fashion_mnist,
13
+ emnist,
14
+ cifar10,
15
+ cifar100,
16
+ cinic10,
17
+ purchase,
18
+ texas,
19
+ huggingface,
20
+ pascal_voc,
21
+ tiny_imagenet,
22
+ femnist,
23
+ feature,
24
+ qoenflx,
25
+ celeba,
26
+ stl10,
27
+ )
28
+
29
# Data sources built as ``module.DataSource(**kwargs)`` by ``get``.
_DATASOURCE_MODULES = (
    ("MNIST", mnist),
    ("FashionMNIST", fashion_mnist),
    ("EMNIST", emnist),
    ("CIFAR10", cifar10),
    ("CIFAR100", cifar100),
    ("CINIC10", cinic10),
    ("Purchase", purchase),
    ("Texas", texas),
    ("HuggingFace", huggingface),
    ("PASCAL_VOC", pascal_voc),
    ("TinyImageNet", tiny_imagenet),
    ("Feature", feature),
    ("QoENFLX", qoenflx),
    ("CelebA", celeba),
    ("STL10", stl10),
)
registered_datasources = dict(_DATASOURCE_MODULES)

# Data sources that additionally receive the client id:
# ``module.DataSource(client_id, **kwargs)``.
_PARTITIONED_DATASOURCE_MODULES = (("FEMNIST", femnist),)
registered_partitioned_datasources = dict(_PARTITIONED_DATASOURCE_MODULES)
48
+
49
+
50
def get(client_id: int = 0, **kwargs):
    """Instantiate the data source selected by name.

    The name is taken from ``kwargs["datasource_name"]`` when present and
    from ``Config().data.datasource`` otherwise. All keyword arguments
    (including ``datasource_name`` itself) are forwarded to the selected
    ``DataSource``; ``client_id`` is passed positionally only to partitioned
    data sources.

    Raises:
        ValueError: if the name matches no known data source.
    """
    if "datasource_name" in kwargs:
        datasource_name = kwargs["datasource_name"]
    else:
        datasource_name = Config().data.datasource

    logging.info("Data source: %s", datasource_name)

    # These modules are only imported when selected (presumably to avoid
    # loading their extra dependencies otherwise).
    lazy_module = None
    if datasource_name == "kinetics700":
        from plato.datasources import kinetics as lazy_module
    elif datasource_name == "Gym":
        from plato.datasources import gym as lazy_module
    elif datasource_name == "Flickr30KE":
        from plato.datasources import flickr30k_entities as lazy_module
    elif datasource_name == "ReferItGame":
        from plato.datasources import referitgame as lazy_module
    elif datasource_name == "COCO":
        from plato.datasources import coco as lazy_module
    elif datasource_name == "YOLOv8":
        from plato.datasources import yolov8 as lazy_module

    if lazy_module is not None:
        return lazy_module.DataSource(**kwargs)

    if datasource_name in registered_datasources:
        return registered_datasources[datasource_name].DataSource(**kwargs)

    if datasource_name in registered_partitioned_datasources:
        partitioned_module = registered_partitioned_datasources[datasource_name]
        return partitioned_module.DataSource(client_id, **kwargs)

    raise ValueError(f"No such data source: {datasource_name}")
99
+
100
+
101
def get_input_shape():
    """Get the input shape of the data source named in the config.

    Returns:
        Whatever the selected ``DataSource.input_shape()`` returns.

    Raises:
        ValueError: if the configured data source is not handled here.
    """
    datasource_name = Config().data.datasource

    logging.info("Data source: %s", datasource_name)
    if datasource_name == "YOLOv8":
        # Same name and lazily imported module as in ``get``; the package
        # ships ``plato/datasources/yolov8.py`` and no ``yolo`` module.
        from plato.datasources import yolov8

        return yolov8.DataSource.input_shape()

    if datasource_name in registered_datasources:
        return registered_datasources[datasource_name].DataSource.input_shape()

    if datasource_name in registered_partitioned_datasources:
        return registered_partitioned_datasources[
            datasource_name
        ].DataSource.input_shape()

    raise ValueError(f"No such data source: {datasource_name}")
@@ -0,0 +1,98 @@
1
+ """
2
+ A self-supervised learning dataset working as a wrapper to add the SSL data
3
+ transform to the datasource of Plato.
4
+
5
+ To allow SSL transform to use the desired parameters, one should place the
6
+ 'data_transforms' sub-block under the 'algorithm' block in the config file.
7
+ """
8
+
9
+ from lightly import transforms
10
+
11
+ from plato.datasources import base
12
+ from plato.datasources import registry as datasources_registry
13
+ from plato.config import Config
14
+
15
+
16
+ # The normalizations for different datasets
17
MNIST_NORMALIZE = {"mean": [0.1307], "std": [0.3081]}
# NOTE(review): identical to the MNIST statistics — confirm this is intended.
FashionMNIST_NORMALIZE = {"mean": [0.1307], "std": [0.3081]}
CIFAR10_NORMALIZE = {"mean": [0.491, 0.482, 0.447], "std": [0.247, 0.243, 0.262]}
# NOTE(review): identical to the CIFAR-10 statistics — confirm this is intended.
CIFAR100_NORMALIZE = {"mean": [0.491, 0.482, 0.447], "std": [0.247, 0.243, 0.262]}
IMAGENET_NORMALIZE = {"mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]}
STL10_NORMALIZE = {"mean": [0.4914, 0.4823, 0.4466], "std": [0.247, 0.243, 0.261]}

# Keyed by the datasource name read from ``Config().data.datasource``.
dataset_normalizations = {
    "MNIST": MNIST_NORMALIZE,
    "FashionMNIST": FashionMNIST_NORMALIZE,
    "CIFAR10": CIFAR10_NORMALIZE,
    "CIFAR100": CIFAR100_NORMALIZE,
    "IMAGENET": IMAGENET_NORMALIZE,
    # The datasource registry names the ImageNet subset "TinyImageNet";
    # without this entry the lookup above fails for it.
    "TinyImageNet": IMAGENET_NORMALIZE,
    "STL10": STL10_NORMALIZE,
}
32
+
33
+
34
+ # All transforms for different SSL algorithms
35
# (algorithm name, lightly transform class) pairs; the name is what the
# config's ``train_transform.name`` must contain.
_SSL_TRANSFORM_CLASSES = (
    ("SimCLR", transforms.SimCLRTransform),
    ("DINO", transforms.DINOTransform),
    ("MAE", transforms.MAETransform),
    ("MoCoV1", transforms.MoCoV1Transform),
    ("MoCoV2", transforms.MoCoV2Transform),
    ("MSN", transforms.MSNTransform),
    ("PIRL", transforms.PIRLTransform),
    ("SimSiam", transforms.SimSiamTransform),
    ("SMoG", transforms.SMoGTransform),
    ("SwaV", transforms.SwaVTransform),
    ("VICReg", transforms.VICRegTransform),
    ("VICRegL", transforms.VICRegLTransform),
    ("FastSiam", transforms.FastSiamTransform),
)
registered_transforms = dict(_SSL_TRANSFORM_CLASSES)
50
+
51
+
52
def get_transforms():
    """Obtain the SSL data transforms described in the config file.

    Reads ``Config().algorithm.data_transforms``. When it contains a
    ``train_transform`` sub-block, its ``name`` selects a class from
    ``registered_transforms`` and its optional ``parameters`` are passed as
    keyword arguments, together with the normalization registered for
    ``Config().data.datasource``.

    Returns:
        dict: empty, or ``{"train_transform": <transform instance>}``; it is
        meant to be passed as keyword arguments to the datasource registry.

    Raises:
        ValueError: if no normalization is registered for the configured
            datasource, or the transform name is unknown.
    """
    transforms_config = Config().algorithm.data_transforms._asdict()

    data_transforms = {}
    if "train_transform" in transforms_config:
        transform_config = transforms_config["train_transform"]._asdict()
        transform_name = transform_config["name"]
        # "parameters" may be omitted to use the transform's own defaults.
        raw_params = transform_config.get("parameters")
        transform_params = raw_params._asdict() if raw_params is not None else {}

        # Get the data normalization for the datasource.
        datasource_name = Config().data.datasource
        if datasource_name not in dataset_normalizations:
            raise ValueError(
                f"No SSL normalization defined for data source: {datasource_name}"
            )
        transform_params["normalize"] = dataset_normalizations[datasource_name]

        # Get the SSL transform.
        if transform_name not in registered_transforms:
            raise ValueError(f"No such SSL transform: {transform_name}")
        dataset_transform = registered_transforms[transform_name](**transform_params)

        # Used by the Plato datasource to build its train set.
        data_transforms["train_transform"] = dataset_transform

    return data_transforms
82
+
83
+
84
# pylint: disable=abstract-method
class SSLDataSource(base.DataSource):
    """
    A datasource wrapper for self-supervised learning.

    It builds the SSL transform(s) from the config, creates the configured
    Plato datasource with them, and re-exposes its train and test sets.
    """

    def __init__(self):
        super().__init__()

        wrapped_source = datasources_registry.get(**get_transforms())
        self.datasource = wrapped_source
        self.trainset = wrapped_source.trainset
        self.testset = wrapped_source.testset
@@ -0,0 +1,103 @@
1
+ """
2
+ The STL-10 dataset from the torchvision package.
3
+ The details of this data can be found on the websites:
4
+ https://cs.stanford.edu/~acoates/stl10/
5
+ and
6
+ https://www.kaggle.com/datasets/jessicali9530/stl10
7
+ .
8
+ """
9
+
10
+ from torch.utils.data.dataset import Dataset
11
+ from torchvision import datasets, transforms
12
+
13
+ from plato.config import Config
14
+ from plato.datasources import base
15
+
16
+
17
class STL10Dataset(Dataset):
    """Wraps a torchvision STL10 dataset for subsequent use in Plato.

    The STL10 dataset stores its class annotations in ``labels``; this
    wrapper re-exposes them as ``targets``, the name used by the rest of
    Plato, along with ``data``, ``target_transform`` and ``classes``.
    Indexing and length are delegated to the wrapped dataset.
    """

    def __init__(self, dataset):
        self.dataset = dataset

        # Surface the raw data for subsequent consumers such as
        # self-supervised learning.
        source = self.dataset
        self.data = source.data
        self.targets = source.labels
        self.target_transform = source.target_transform
        self.classes = source.classes

    def __getitem__(self, index):
        source = self.dataset
        return source[index]

    def __len__(self):
        source = self.dataset
        return len(source)
39
+
40
+
41
class DataSource(base.DataSource):
    """The STL-10 dataset with its train, unlabeled and test splits."""

    def __init__(self, **kwargs):
        super().__init__()
        root_dir = Config().params["data_path"]

        # Per-channel statistics given on the 0-255 scale, rescaled to 0-1.
        normalize = transforms.Normalize(
            mean=[value / 255.0 for value in [125.3, 123.0, 113.9]],
            std=[value / 255.0 for value in [63.0, 62.1, 66.7]],
        )

        # Caller-supplied transforms take precedence over the defaults.
        if "train_transform" in kwargs:
            train_transform = kwargs["train_transform"]
        else:
            train_transform = transforms.Compose(
                [
                    transforms.RandomCrop(96, padding=4),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    normalize,
                ]
            )

        if "test_transform" in kwargs:
            test_transform = kwargs["test_transform"]
        else:
            test_transform = transforms.Compose([transforms.ToTensor(), normalize])

        # The unlabeled split shares the training transform.
        split_plan = (
            ("train", train_transform),
            ("unlabeled", train_transform),
            ("test", test_transform),
        )
        raw_splits = {}
        for split_name, split_transform in split_plan:
            raw_splits[split_name] = datasets.STL10(
                root=root_dir,
                split=split_name,
                download=True,
                transform=split_transform,
            )

        self.trainset = STL10Dataset(raw_splits["train"])
        self.unlabeledset = STL10Dataset(raw_splits["unlabeled"])
        self.testset = STL10Dataset(raw_splits["test"])

    def num_train_examples(self):
        """Return the number of samples in the labeled training split."""
        return len(self.trainset)

    def num_test_examples(self):
        """Return the number of samples in the test split."""
        return len(self.testset)

    def get_unlabeled_set(self):
        """Obtains the unlabeled dataset."""
        return self.unlabeledset