plato-learn 1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (179) hide show
  1. plato/__init__.py +1 -0
  2. plato/algorithms/__init__.py +0 -0
  3. plato/algorithms/base.py +45 -0
  4. plato/algorithms/fedavg.py +48 -0
  5. plato/algorithms/fedavg_gan.py +79 -0
  6. plato/algorithms/fedavg_personalized.py +48 -0
  7. plato/algorithms/mistnet.py +52 -0
  8. plato/algorithms/registry.py +39 -0
  9. plato/algorithms/split_learning.py +89 -0
  10. plato/callbacks/__init__.py +0 -0
  11. plato/callbacks/client.py +56 -0
  12. plato/callbacks/handler.py +78 -0
  13. plato/callbacks/server.py +139 -0
  14. plato/callbacks/trainer.py +124 -0
  15. plato/client.py +67 -0
  16. plato/clients/__init__.py +0 -0
  17. plato/clients/base.py +467 -0
  18. plato/clients/edge.py +103 -0
  19. plato/clients/fedavg_personalized.py +40 -0
  20. plato/clients/mistnet.py +49 -0
  21. plato/clients/registry.py +43 -0
  22. plato/clients/self_supervised_learning.py +51 -0
  23. plato/clients/simple.py +218 -0
  24. plato/clients/split_learning.py +150 -0
  25. plato/config.py +339 -0
  26. plato/datasources/__init__.py +0 -0
  27. plato/datasources/base.py +123 -0
  28. plato/datasources/celeba.py +150 -0
  29. plato/datasources/cifar10.py +87 -0
  30. plato/datasources/cifar100.py +61 -0
  31. plato/datasources/cinic10.py +62 -0
  32. plato/datasources/coco.py +119 -0
  33. plato/datasources/datalib/__init__.py +0 -0
  34. plato/datasources/datalib/audio_extraction_tools.py +137 -0
  35. plato/datasources/datalib/data_utils.py +124 -0
  36. plato/datasources/datalib/flickr30kE_utils.py +336 -0
  37. plato/datasources/datalib/frames_extraction_tools.py +254 -0
  38. plato/datasources/datalib/gym_utils/__init__.py +0 -0
  39. plato/datasources/datalib/gym_utils/gym_trim.py +189 -0
  40. plato/datasources/datalib/modality_data_anntation_tools.py +163 -0
  41. plato/datasources/datalib/modality_extraction_base.py +59 -0
  42. plato/datasources/datalib/parse_datasets.py +212 -0
  43. plato/datasources/datalib/refer_utils/__init__.py +0 -0
  44. plato/datasources/datalib/refer_utils/referitgame_utils.py +237 -0
  45. plato/datasources/datalib/tiny_data_tools.py +81 -0
  46. plato/datasources/datalib/video_transform.py +79 -0
  47. plato/datasources/emnist.py +64 -0
  48. plato/datasources/fashion_mnist.py +41 -0
  49. plato/datasources/feature.py +24 -0
  50. plato/datasources/feature_dataset.py +15 -0
  51. plato/datasources/femnist.py +141 -0
  52. plato/datasources/flickr30k_entities.py +362 -0
  53. plato/datasources/gym.py +431 -0
  54. plato/datasources/huggingface.py +165 -0
  55. plato/datasources/kinetics.py +568 -0
  56. plato/datasources/mnist.py +44 -0
  57. plato/datasources/multimodal_base.py +328 -0
  58. plato/datasources/pascal_voc.py +56 -0
  59. plato/datasources/purchase.py +94 -0
  60. plato/datasources/qoenflx.py +127 -0
  61. plato/datasources/referitgame.py +330 -0
  62. plato/datasources/registry.py +119 -0
  63. plato/datasources/self_supervised_learning.py +98 -0
  64. plato/datasources/stl10.py +103 -0
  65. plato/datasources/texas.py +94 -0
  66. plato/datasources/tiny_imagenet.py +64 -0
  67. plato/datasources/yolov8.py +85 -0
  68. plato/models/__init__.py +0 -0
  69. plato/models/cnn_encoder.py +103 -0
  70. plato/models/dcgan.py +116 -0
  71. plato/models/general_multilayer.py +254 -0
  72. plato/models/huggingface.py +27 -0
  73. plato/models/lenet5.py +113 -0
  74. plato/models/multilayer.py +90 -0
  75. plato/models/multimodal/__init__.py +0 -0
  76. plato/models/multimodal/base_net.py +91 -0
  77. plato/models/multimodal/blending.py +142 -0
  78. plato/models/multimodal/fc_net.py +77 -0
  79. plato/models/multimodal/fusion_net.py +78 -0
  80. plato/models/multimodal/multimodal_module.py +152 -0
  81. plato/models/registry.py +99 -0
  82. plato/models/resnet.py +190 -0
  83. plato/models/torch_hub.py +19 -0
  84. plato/models/vgg.py +113 -0
  85. plato/models/vit.py +166 -0
  86. plato/models/yolov8.py +22 -0
  87. plato/processors/__init__.py +0 -0
  88. plato/processors/base.py +35 -0
  89. plato/processors/compress.py +46 -0
  90. plato/processors/decompress.py +48 -0
  91. plato/processors/feature.py +51 -0
  92. plato/processors/feature_additive_noise.py +48 -0
  93. plato/processors/feature_dequantize.py +34 -0
  94. plato/processors/feature_gaussian.py +17 -0
  95. plato/processors/feature_laplace.py +15 -0
  96. plato/processors/feature_quantize.py +34 -0
  97. plato/processors/feature_randomized_response.py +50 -0
  98. plato/processors/feature_unbatch.py +39 -0
  99. plato/processors/inbound_feature_tensors.py +39 -0
  100. plato/processors/model.py +55 -0
  101. plato/processors/model_compress.py +34 -0
  102. plato/processors/model_decompress.py +37 -0
  103. plato/processors/model_decrypt.py +41 -0
  104. plato/processors/model_deepcopy.py +21 -0
  105. plato/processors/model_dequantize.py +18 -0
  106. plato/processors/model_dequantize_qsgd.py +61 -0
  107. plato/processors/model_encrypt.py +43 -0
  108. plato/processors/model_quantize.py +18 -0
  109. plato/processors/model_quantize_qsgd.py +82 -0
  110. plato/processors/model_randomized_response.py +34 -0
  111. plato/processors/outbound_feature_ndarrays.py +38 -0
  112. plato/processors/pipeline.py +26 -0
  113. plato/processors/registry.py +124 -0
  114. plato/processors/structured_pruning.py +57 -0
  115. plato/processors/unstructured_pruning.py +73 -0
  116. plato/samplers/__init__.py +0 -0
  117. plato/samplers/all_inclusive.py +41 -0
  118. plato/samplers/base.py +31 -0
  119. plato/samplers/dirichlet.py +81 -0
  120. plato/samplers/distribution_noniid.py +132 -0
  121. plato/samplers/iid.py +53 -0
  122. plato/samplers/label_quantity_noniid.py +119 -0
  123. plato/samplers/mixed.py +44 -0
  124. plato/samplers/mixed_label_quantity_noniid.py +128 -0
  125. plato/samplers/modality_iid.py +42 -0
  126. plato/samplers/modality_quantity_noniid.py +56 -0
  127. plato/samplers/orthogonal.py +99 -0
  128. plato/samplers/registry.py +66 -0
  129. plato/samplers/sample_quantity_noniid.py +123 -0
  130. plato/samplers/sampler_utils.py +190 -0
  131. plato/servers/__init__.py +0 -0
  132. plato/servers/base.py +1395 -0
  133. plato/servers/fedavg.py +281 -0
  134. plato/servers/fedavg_cs.py +335 -0
  135. plato/servers/fedavg_gan.py +74 -0
  136. plato/servers/fedavg_he.py +106 -0
  137. plato/servers/fedavg_personalized.py +57 -0
  138. plato/servers/mistnet.py +67 -0
  139. plato/servers/registry.py +52 -0
  140. plato/servers/split_learning.py +109 -0
  141. plato/trainers/__init__.py +0 -0
  142. plato/trainers/base.py +99 -0
  143. plato/trainers/basic.py +649 -0
  144. plato/trainers/diff_privacy.py +178 -0
  145. plato/trainers/gan.py +330 -0
  146. plato/trainers/huggingface.py +173 -0
  147. plato/trainers/loss_criterion.py +70 -0
  148. plato/trainers/lr_schedulers.py +252 -0
  149. plato/trainers/optimizers.py +53 -0
  150. plato/trainers/pascal_voc.py +80 -0
  151. plato/trainers/registry.py +44 -0
  152. plato/trainers/self_supervised_learning.py +302 -0
  153. plato/trainers/split_learning.py +305 -0
  154. plato/trainers/tracking.py +96 -0
  155. plato/trainers/yolov8.py +41 -0
  156. plato/utils/__init__.py +0 -0
  157. plato/utils/count_parameters.py +30 -0
  158. plato/utils/csv_processor.py +26 -0
  159. plato/utils/data_loaders.py +148 -0
  160. plato/utils/decorators.py +24 -0
  161. plato/utils/fonts.py +23 -0
  162. plato/utils/homo_enc.py +187 -0
  163. plato/utils/reinforcement_learning/__init__.py +0 -0
  164. plato/utils/reinforcement_learning/policies/__init__.py +0 -0
  165. plato/utils/reinforcement_learning/policies/base.py +161 -0
  166. plato/utils/reinforcement_learning/policies/ddpg.py +75 -0
  167. plato/utils/reinforcement_learning/policies/registry.py +32 -0
  168. plato/utils/reinforcement_learning/policies/sac.py +343 -0
  169. plato/utils/reinforcement_learning/policies/td3.py +485 -0
  170. plato/utils/reinforcement_learning/rl_agent.py +142 -0
  171. plato/utils/reinforcement_learning/rl_server.py +113 -0
  172. plato/utils/rl_env.py +154 -0
  173. plato/utils/s3.py +141 -0
  174. plato/utils/trainer_utils.py +21 -0
  175. plato/utils/unary_encoding.py +47 -0
  176. plato_learn-1.1.dist-info/METADATA +35 -0
  177. plato_learn-1.1.dist-info/RECORD +179 -0
  178. plato_learn-1.1.dist-info/WHEEL +4 -0
  179. plato_learn-1.1.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,362 @@
1
+ """
2
+ The Flickr30K Entities dataset.
3
+
4
+ The data structure and setting follow:
5
+ "http://bryanplummer.com/Flickr30kEntities/".
6
+
7
+ We utilize the official splits that contain:
8
+ - train: 29783 images,
9
+ - val: 1000 images,
10
+ - test: 1000 images
11
+
12
+ The file structure of this dataset is:
13
+ - Images (jpg): the raw images
14
+ - Annotations (xml): the bounding boxes
15
+ - Sentence (txt): captions of the image
16
+
17
+ The data structure under the 'data/' is:
18
+ ├── Flickr30KEntities # root dir of Flickr30K Entities dataset
19
+ │ ├── Flickr30KEntitiesRaw # Raw images/annotations and the official splits
20
+ │ ├── train # data dir for the train phase
21
+ │ │ └── train_Annotations
22
+ │ │ └── train_Images
23
+ │ │ └── train_Sentences
24
+ │ └── test
25
+ │ └── val
26
+
27
+
28
+ Detailed loaded sample structure:
29
+
30
+ One sample is presented as the dict type:
31
+ - rgb: the image data.
32
+ - text:
33
+ - caption : a nested list, such as
34
+ [['The woman is applying mascara while looking in the mirror.']],
35
+ - caption_phrases: a nested list, each item is the list contains
36
+ the phrases of the caption, such as:
37
+ [['Military personnel'], ['greenish gray uniforms'], ['matching hats']]
38
+ - box:
39
+ - caption_phrase_bboxs: a 2-depth nested list, each item is a list that
40
+ contains boxes of the corresponding phrase, such as:
41
+ [[[295, 130, 366, 244], [209, 123, 300, 246], [347, 1, 439, 236]],
42
+ [[0, 21, 377, 220]], [[0, 209, 214, 332]]]
43
+ - target:
44
+ - caption_phrases_cate: a nested list, each item is a string that
45
+ presents the categories of the phrase, such as:
46
+ [['people'], ['bodyparts'], ['other']].
47
+
48
+ - caption_phrases_cate_id: a list, each item is an integer/str that gives
49
+ the category id of the phrase, such as:
50
+ ['121973', '121976', '121975']
51
+
52
+ One batch of samples is presented as a list,
53
+ For example, the corresponding caption_phrase_bboxs in one batch is:
54
+ [
55
+ [[[295, 130, 366, 244], [209, 123, 300, 246], [347, 1, 439, 236]], [[0, 21, 377, 220]],
56
+ [[0, 209, 214, 332]]], - sample-1
57
+ [[[90, 68, 325, 374]], [[118, 64, 192, 128]]], - sample-2
58
+ [[[1, 0, 148, 451]], [[153, 148, 400, 413]], [[374, 320, 450, 440]]], - sample-3
59
+ ]
60
+ """
61
+
62
+ import json
63
+ import logging
64
+ import os
65
+
66
+ import torch
67
+ import skimage.io as io
68
+ import cv2
69
+
70
+ from plato.config import Config
71
+ from plato.datasources import multimodal_base
72
+ from plato.datasources.multimodal_base import TextData, BoxData, TargetData
73
+ from plato.datasources.datalib import data_utils
74
+ from plato.datasources.datalib import flickr30kE_utils
75
+
76
+
77
def collate_fn(batch):
    """Identity collate function for the data loader.

    The samples of this dataset are dicts holding nested, variable-length
    lists, so no stacking is attempted here; the batch is handed back as-is.

    Args:
        batch (list): the samples gathered for one batch.

    Returns:
        list: the very same ``batch`` object, unmodified.
    """
    collated = batch
    return collated
89
+
90
+
91
class Flickr30KEDataset(multimodal_base.MultiModalDataset):
    """Prepare the Flickr30K Entities dataset.

    Each sample is keyed by an entry of ``dataset_info`` and is returned as a
    dict holding the image ("rgb"), the caption text ("text"), the phrase
    boxes ("box") and the phrase categories ("target").
    """

    def __init__(
        self,
        dataset_info,
        phase,
        phase_info,
        data_types,
        modality_sampler=None,
        transform_image_dec_func=None,
        transform_text_func=None,
    ):
        """Store the per-phase records and the optional transforms.

        Args:
            dataset_info (dict): maps a sample name to its annotation record;
                the keys define the sample order of the dataset.
            phase (str): name of the phase, stored as ``self.phase``.
            phase_info (dict): per data-type information of this phase; each
                entry is indexed with "path" and "format".
            data_types (list): names of the data types; the first entry is
                used to look up the images.
            modality_sampler: stored as ``self.modality_sampler``; when None,
                all supported modalities are used.
            transform_image_dec_func: optional callable invoked with the
                keywords ``image``, ``bboxes`` and ``category_ids`` whose
                result is indexed with "image" and "bboxes".
                NOTE(review): looks like an albumentations pipeline - confirm.
            transform_text_func: optional callable applied to the list of
                caption phrases.
        """
        super().__init__()

        self.phase = phase
        self.phase_multimodal_data_record = dataset_info
        self.phase_info = phase_info
        self.data_types = data_types
        self.transform_image_dec_func = transform_image_dec_func
        self.transform_text_func = transform_text_func

        # the retrieval names, in record order, index the samples
        self.phase_samples_name = list(self.phase_multimodal_data_record.keys())

        self.supported_modalities = ["rgb", "text"]

        # default utilizing the full modalities
        if modality_sampler is None:
            self.modality_sampler = self.supported_modalities
        else:
            self.modality_sampler = modality_sampler

    def __len__(self):
        """Return the number of samples recorded for this phase."""
        return len(self.phase_multimodal_data_record)

    def get_sample_image_data(self, image_id):
        """Load one image of this phase as a three-channel RGB array.

        Args:
            image_id: identifier of the image; its string form plus the
                configured format suffix is the file name.

        Returns:
            The image array in RGB channel order.
        """
        image_info = self.phase_info[self.data_types[0]]
        image_path = os.path.join(
            image_info["path"], str(image_id) + image_info["format"]
        )

        image_data = io.imread(image_path)

        # skimage.io.imread already yields RGB channel order, so the former
        # cv2.COLOR_BGR2RGB conversion swapped the channels to BGR. Only
        # normalize the channel count here so that the result is always RGB.
        if image_data.ndim == 2:
            # grayscale image: replicate the single channel
            image_data = cv2.cvtColor(image_data, cv2.COLOR_GRAY2RGB)
        elif image_data.shape[-1] == 4:
            # drop the alpha channel to keep a three-channel output
            image_data = cv2.cvtColor(image_data, cv2.COLOR_RGBA2RGB)

        return image_data

    def extract_sample_anno_data(self, image_anno_sent):
        """Extract the annotation fields of one sample record.

        Args:
            image_anno_sent (dict): the annotation record of one sample.

        Returns:
            tuple: ``(sentence, sentence_phrases, sentence_phrases_type,
            sentence_phrases_id, sentence_phrases_boxes)``.
        """
        sentence = image_anno_sent["sentence"]  # a string
        sentence_phrases = image_anno_sent["sentence_phrases"]  # a list
        # a nested list
        sentence_phrases_type = image_anno_sent["sentence_phrases_type"]
        sentence_phrases_id = image_anno_sent["sentence_phrases_id"]  # a list
        # a nested list
        sentence_phrases_boxes = image_anno_sent["sentence_phrases_boxes"]

        return (
            sentence,
            sentence_phrases,
            sentence_phrases_type,
            sentence_phrases_id,
            sentence_phrases_boxes,
        )

    def get_one_multimodal_sample(self, sample_idx):
        """Obtain one sample from the Flickr30K Entities dataset.

        Args:
            sample_idx (int): position of the sample in the record order.

        Returns:
            dict: with the keys "rgb" (image data), "text" (``TextData``),
            "box" (``BoxData``) and "target" (``TargetData``).
        """
        sample_retrieval_name = self.phase_samples_name[sample_idx]
        # the image id is the retrieval name without directory and extension
        image_file_name = os.path.basename(sample_retrieval_name)
        image_id = os.path.splitext(image_file_name)[0]

        image_data = self.get_sample_image_data(image_id)

        image_anno_sent = self.phase_multimodal_data_record[sample_retrieval_name]

        (
            sentence,
            sentence_phrases,
            sentence_phrases_type,
            sentence_phrases_id,
            sentence_phrases_boxes,
        ) = self.extract_sample_anno_data(image_anno_sent)

        # keep an already nested caption, otherwise wrap it as [[sentence]]
        caption = (
            sentence
            if any(isinstance(iter_i, list) for iter_i in sentence)
            else [[sentence]]
        )
        # one flat list of boxes over all phrases, as the transform expects
        flatten_caption_phrase_bboxs = [
            box for boxes in sentence_phrases_boxes for box in boxes
        ]
        # ['The woman', 'mascara', 'the mirror'] -> one single-item list each
        caption_phrases = [[phrase] for phrase in sentence_phrases]
        caption_phrases_cate = sentence_phrases_type
        caption_phrases_cate_id = sentence_phrases_id

        if self.transform_image_dec_func is not None:
            transformed = self.transform_image_dec_func(
                image=image_data,
                bboxes=flatten_caption_phrase_bboxs,
                category_ids=range(len(flatten_caption_phrase_bboxs)),
            )

            image_data = torch.from_numpy(transformed["image"])
            flatten_caption_phrase_bboxs = transformed["bboxes"]
            # presumably regroups the flat transformed boxes per phrase,
            # following the layout of the original nested boxes - confirm
            caption_phrase_bboxs = flickr30kE_utils.phrase_boxes_alignment(
                flatten_caption_phrase_bboxs, sentence_phrases_boxes
            )
        else:
            caption_phrase_bboxs = sentence_phrases_boxes

        if self.transform_text_func is not None:
            caption_phrases = self.transform_text_func(caption_phrases)

        text_data = TextData(caption=caption, caption_phrases=caption_phrases)
        box_data = BoxData(caption_phrase_bboxs=caption_phrase_bboxs)
        target_data = TargetData(
            caption_phrases_cate=caption_phrases_cate,
            caption_phrases_cate_id=caption_phrases_cate_id,
        )

        return {
            "rgb": image_data,
            "text": text_data,
            "box": box_data,
            "target": target_data,
        }
223
+
224
+
225
class DataSource(multimodal_base.MultiModalDataSource):
    """The Flickr30K Entities dataset.

    On construction the raw data is downloaded/arranged, the raw files are
    distributed into one directory per split and data type, and the
    integrated per-split information is generated.
    """

    def __init__(self, **kwargs):
        """Prepare the raw data, the split directories and their metadata.

        Args:
            **kwargs: accepted for interface compatibility; not used here.
        """
        super().__init__()

        self.data_name = Config().data.dataname

        self.modality_names = ["image", "text"]

        _path = Config().params["data_path"]
        self._data_path_process(data_path=_path, base_data_name=self.data_name)

        raw_data_name = self.data_name + "Raw"
        base_data_path = self.mm_data_info["data_path"]

        download_url = Config().data.download_url

        self._download_arrange_data(
            download_url_address=download_url,
            data_path=base_data_path,
            extract_to_dir=base_data_path,
        )

        # define the path of different data source,
        # the annotation is .xml, the sentence is in .txt
        self.raw_data_types = ["Flickr30k_images", "Annotations", "Sentences"]
        self.raw_data_file_format = [".jpg", ".xml", ".txt"]
        self.data_types = ["Images", "Annotations", "Sentences"]

        # extract the data information and structure; the three lists above
        # are aligned position by position
        for raw_type, raw_file_format, data_type in zip(
            self.raw_data_types, self.raw_data_file_format, self.data_types
        ):
            raw_type_path = os.path.join(base_data_path, raw_data_name, raw_type)

            self.mm_data_info[data_type] = {
                "path": raw_type_path,
                "format": raw_file_format,
                "num_samples": len(os.listdir(raw_type_path)),
            }

        # generate path/type information for splits
        for split_type, split_info in self.splits_info.items():
            split_info["split_file"] = os.path.join(
                base_data_path, raw_data_name, split_type + ".txt"
            )
            split_path = split_info["path"]
            for dt_type, dt_type_format in zip(
                self.data_types, self.raw_data_file_format
            ):
                split_info[dt_type] = {
                    "path": os.path.join(split_path, f"{split_type}_{dt_type}"),
                    "format": dt_type_format,
                }

        # distribute the data to the splits
        self.create_splits_data()

        # generate the splits information txt for further utilization
        flickr30kE_utils.integrate_data_to_json(
            splits_info=self.splits_info,
            mm_data_info=self.mm_data_info,
            data_types=self.data_types,
            split_wise=True,
            globally=True,
        )

    def create_splits_data(self):
        """Create datasets for different splits.

        For every split, the sample ids are read from its split file, the
        number of samples is recorded, and the raw files of each data type
        are copied into the split directory unless that directory exists.
        """
        # saving the images and entities to the corresponding directory
        for split_type, split_info in self.splits_info.items():
            logging.info("Creating split %s data..........", split_type)
            # 0. obtain the sample ids of the split, one id per line; blank
            # lines are skipped so that they neither count as samples nor
            # produce a file path made of the bare extension
            with open(split_info["split_file"], "r") as loaded_file:
                stripped_lines = (line.strip() for line in loaded_file)
                split_data_samples = [
                    sample_id for sample_id in stripped_lines if sample_id
                ]
            split_info["num_samples"] = len(split_data_samples)

            # 1. create directory for the split data if necessary
            for dt_type in self.data_types:
                split_dt_type_path = split_info[dt_type]["path"]

                # an existing directory is treated as already populated
                if self._exists(split_dt_type_path):
                    logging.info("The path %s exists.", split_dt_type_path)
                    continue
                os.makedirs(split_dt_type_path, exist_ok=True)

                raw_data_type_path = self.mm_data_info[dt_type]["path"]
                raw_data_format = self.mm_data_info[dt_type]["format"]
                split_samples_path = [
                    os.path.join(raw_data_type_path, sample_id + raw_data_format)
                    for sample_id in split_data_samples
                ]
                # 2. saving the split data into the target directory
                data_utils.copy_files(split_samples_path, split_dt_type_path)

            logging.info("Done.")

    def get_phase_data_info(self, phase):
        """Obtain the data information for the required phase.

        Args:
            phase (str): name of the split, a key of ``self.splits_info``.

        Returns:
            The content of ``<phase>_integrated_data.json`` located in the
            directory of the split, parsed from JSON.
        """
        path = self.splits_info[phase]["path"]
        save_path = os.path.join(path, phase + "_integrated_data.json")
        with open(save_path, "r") as infile:
            phase_data_info = json.load(infile)
        return phase_data_info

    def get_phase_dataset(self, phase, modality_sampler):
        """Obtain the dataset for the specific phase.

        Args:
            phase (str): name of the split, a key of ``self.splits_info``.
            modality_sampler: forwarded unchanged to ``Flickr30KEDataset``.

        Returns:
            Flickr30KEDataset: the dataset built from the integrated data
            information and the split information of the phase.
        """
        phase_data_info = self.get_phase_data_info(phase)
        phase_split_info = self.splits_info[phase]
        dataset = Flickr30KEDataset(
            dataset_info=phase_data_info,
            phase_info=phase_split_info,
            data_types=self.data_types,
            phase=phase,
            modality_sampler=modality_sampler,
        )
        return dataset

    def get_train_set(self, modality_sampler=None):
        """Obtains the training dataset and caches it as ``self.trainset``."""
        phase = "train"

        self.trainset = self.get_phase_dataset(phase, modality_sampler)
        return self.trainset

    def get_test_set(self, modality_sampler=None):
        """Obtains the test dataset and caches it as ``self.testset``."""
        phase = "test"

        self.testset = self.get_phase_dataset(phase, modality_sampler)
        return self.testset