datamint 1.2.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of datamint might be problematic.
- datamint/__init__.py +11 -0
- datamint-1.2.4.dist-info/METADATA +118 -0
- datamint-1.2.4.dist-info/RECORD +30 -0
- datamint-1.2.4.dist-info/WHEEL +4 -0
- datamint-1.2.4.dist-info/entry_points.txt +4 -0
- datamintapi/__init__.py +25 -0
- datamintapi/apihandler/annotation_api_handler.py +748 -0
- datamintapi/apihandler/api_handler.py +15 -0
- datamintapi/apihandler/base_api_handler.py +300 -0
- datamintapi/apihandler/dto/annotation_dto.py +149 -0
- datamintapi/apihandler/exp_api_handler.py +204 -0
- datamintapi/apihandler/root_api_handler.py +1013 -0
- datamintapi/client_cmd_tools/__init__.py +0 -0
- datamintapi/client_cmd_tools/datamint_config.py +168 -0
- datamintapi/client_cmd_tools/datamint_upload.py +483 -0
- datamintapi/configs.py +58 -0
- datamintapi/dataset/__init__.py +1 -0
- datamintapi/dataset/base_dataset.py +881 -0
- datamintapi/dataset/dataset.py +492 -0
- datamintapi/examples/__init__.py +1 -0
- datamintapi/examples/example_projects.py +75 -0
- datamintapi/experiment/__init__.py +1 -0
- datamintapi/experiment/_patcher.py +570 -0
- datamintapi/experiment/experiment.py +1049 -0
- datamintapi/logging.yaml +27 -0
- datamintapi/utils/dicom_utils.py +640 -0
- datamintapi/utils/io_utils.py +149 -0
- datamintapi/utils/logging_utils.py +55 -0
- datamintapi/utils/torchmetrics.py +70 -0
- datamintapi/utils/visualization.py +129 -0
datamintapi/dataset/base_dataset.py

@@ -0,0 +1,881 @@

```python
import os
import requests
from tqdm import tqdm
from typing import Optional, Callable, Any, Literal
import logging
import shutil
import json
import yaml
import pydicom
import numpy as np
from datamintapi import configs
from torch.utils.data import DataLoader
import torch
from datamintapi.apihandler.base_api_handler import DatamintException
from datamintapi.utils.dicom_utils import is_dicom
import cv2
from datamintapi.utils.io_utils import read_array_normalized
from deprecated import deprecated
from datetime import datetime

_LOGGER = logging.getLogger(__name__)


class DatamintDatasetException(DatamintException):
    pass


class DatamintBaseDataset:
    """
    Class to download and load datasets from the Datamint API.

    Args:
        root: Root directory of dataset where data already exists or will be downloaded.
        project_name: Name of the project to download.
        auto_update: If True, the dataset will be checked for updates and downloaded if necessary.
        api_key: API key to access the Datamint API. If not provided, it will look for the
            environment variable 'DATAMINT_API_KEY'. Not necessary if
            you don't want to download/update the dataset.
        return_dicom: If True, the DICOM object will be returned, if the image is a DICOM file.
        return_metainfo: If True, the metainfo of the image will be returned.
        return_annotations: If True, the annotations of the image will be returned.
        return_frame_by_frame: If True, each frame of a video/DICOM/3D image will be returned separately.
        include_unannotated: If True, images without annotations will be included. If False, images without annotations will be discarded.
        all_annotations: If True, all annotations will be downloaded, including the ones that are not set as closed/done.
        server_url: URL of the Datamint server. If not provided, it will use the default server.
        include_annotators: List of annotators to include. If None, all annotators will be included. See parameter ``exclude_annotators``.
        exclude_annotators: List of annotators to exclude. If None, no annotators will be excluded. See parameter ``include_annotators``.
        include_segmentation_names: List of segmentation names to include. If None, all segmentations will be included.
        exclude_segmentation_names: List of segmentation names to exclude. If None, no segmentations will be excluded.
        include_image_label_names: List of image label names to include. If None, all image labels will be included.
        exclude_image_label_names: List of image label names to exclude. If None, no image labels will be excluded.
        include_frame_label_names: List of frame label names to include. If None, all frame labels will be included.
        exclude_frame_label_names: List of frame label names to exclude. If None, no frame labels will be excluded.
    """
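
    # A minimal construction sketch (the project name and filter below are
    # hypothetical, not shipped with the package):
    #
    #   dataset = DatamintBaseDataset(project_name='my-project',
    #                                 return_frame_by_frame=True,
    #                                 include_segmentation_names=['lung'])
    #
    # Only `project_name` is required; the API key falls back to the
    # DATAMINT_API_KEY environment variable or the datamint-config tool.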

    DATAMINT_DEFAULT_DIR = ".datamint"
    DATAMINT_DATASETS_DIR = "datasets"

    def __init__(self,
                 project_name: str,
                 root: str | None = None,
                 auto_update: bool = True,
                 api_key: Optional[str] = None,
                 server_url: Optional[str] = None,
                 return_dicom: bool = False,
                 return_metainfo: bool = True,
                 return_annotations: bool = True,
                 return_frame_by_frame: bool = False,
                 include_unannotated: bool = True,
                 all_annotations: bool = False,
                 # filtering parameters
                 include_annotators: Optional[list[str]] = None,
                 exclude_annotators: Optional[list[str]] = None,
                 include_segmentation_names: Optional[list[str]] = None,
                 exclude_segmentation_names: Optional[list[str]] = None,
                 include_image_label_names: Optional[list[str]] = None,
                 exclude_image_label_names: Optional[list[str]] = None,
                 include_frame_label_names: Optional[list[str]] = None,
                 exclude_frame_label_names: Optional[list[str]] = None
                 ):
        from datamintapi.apihandler.api_handler import APIHandler

        if project_name is None:
            raise ValueError("project_name is required.")

        self.all_annotations = all_annotations
        self.api_handler = APIHandler(root_url=server_url, api_key=api_key,
                                      check_connection=auto_update)
        self.server_url = self.api_handler.root_url
        if root is None:
            # store them in the home directory
            root = os.path.join(os.path.expanduser("~"),
                                DatamintBaseDataset.DATAMINT_DEFAULT_DIR)
            root = os.path.join(root, DatamintBaseDataset.DATAMINT_DATASETS_DIR)
            if not os.path.exists(root):
                os.makedirs(root)
        elif isinstance(root, str):
            root = os.path.expanduser(root)
        if not os.path.isdir(root):
            raise NotADirectoryError(f"Root directory not found: {root}")

        self.root = root

        self.return_dicom = return_dicom
        self.return_metainfo = return_metainfo
        self.return_frame_by_frame = return_frame_by_frame
        self.return_annotations = return_annotations
        self.include_unannotated = include_unannotated
        self.discard_without_annotations = not include_unannotated

        # Filtering parameters
        self.include_annotators = include_annotators
        self.exclude_annotators = exclude_annotators
        self.include_segmentation_names = include_segmentation_names
        self.exclude_segmentation_names = exclude_segmentation_names
        self.include_image_label_names = include_image_label_names
        self.exclude_image_label_names = exclude_image_label_names
        self.include_frame_label_names = include_frame_label_names
        self.exclude_frame_label_names = exclude_frame_label_names

        # Validate filtering parameters
        if include_annotators is not None and exclude_annotators is not None:
            raise ValueError("Cannot set both include_annotators and exclude_annotators at the same time")

        if include_segmentation_names is not None and exclude_segmentation_names is not None:
            raise ValueError("Cannot set both include_segmentation_names and exclude_segmentation_names at the same time")

        if include_image_label_names is not None and exclude_image_label_names is not None:
            raise ValueError("Cannot set both include_image_label_names and exclude_image_label_names at the same time")

        if include_frame_label_names is not None and exclude_frame_label_names is not None:
            raise ValueError("Cannot set both include_frame_label_names and exclude_frame_label_names at the same time")

        self.project_name = project_name
        dataset_name = project_name

        self.dataset_dir = os.path.join(root, dataset_name)
        self.dataset_zippath = os.path.join(root, f'{dataset_name}.zip')

        local_dataset_exists = os.path.exists(os.path.join(self.dataset_dir, 'dataset.json'))

        if local_dataset_exists and not auto_update:
            # In this case, we don't need to check the API, so we don't need the id.
            self.dataset_id = None
        else:
            self.project_info = self.get_info()
            self.dataset_id = self.project_info['dataset_id']

        self.api_key = self.api_handler.api_key
        if self.api_key is None:
            _LOGGER.warning("API key not provided. If you want to download data, please provide an API key, "
                            "either by passing it as an argument, "
                            f"setting the environment variable {configs.ENV_VARS[configs.APIKEY_KEY]} or "
                            "using the datamint-config command line tool.")

        # Download/update the dataset, if necessary.
        if local_dataset_exists:
            _LOGGER.info(f"Dataset directory already exists: {self.dataset_dir}")
            if auto_update:
                _LOGGER.info("Checking for updates...")
                self._check_version()
        else:
            if self.api_key is None:
                raise DatamintDatasetException("API key is required to download the dataset.")
            _LOGGER.info(f"No data found at {self.dataset_dir}. Downloading...")
            self.download_project()

        # Load the metadata
        if not hasattr(self, 'metainfo'):
            with open(os.path.join(self.dataset_dir, 'dataset.json'), 'r') as file:
                self.metainfo = json.load(file)
        self.images_metainfo = self.metainfo['resources']

        # filter annotations
        for imginfo in self.images_metainfo:
            imginfo['annotations'] = self._filter_annotations(imginfo['annotations'])

        # filter out images with no annotations
        if self.discard_without_annotations:
            original_count = len(self.images_metainfo)
            self.images_metainfo = self._filter_items(self.images_metainfo)
            _LOGGER.info(f"Discarded {original_count - len(self.images_metainfo)} images without annotations.")

        self._check_integrity()

        # fix images_metainfo labels
        # TODO: check tags
        # for imginfo in self.images_metainfo:
        #     if imginfo['frame_labels'] is not None:
        #         for flabels in imginfo['frame_labels']:
        #             if flabels['label'] is None:
        #                 flabels['label'] = []
        #             elif isinstance(flabels['label'], str):
        #                 flabels['label'] = flabels['label'].split(',')

        if self.return_frame_by_frame:
            self.dataset_length = 0
            for imginfo in self.images_metainfo:
                filepath = os.path.join(self.dataset_dir, imginfo['file'])
                self.dataset_length += self.read_number_of_frames(filepath)
        else:
            self.dataset_length = len(self.images_metainfo)

        self.num_frames_per_resource = self.__compute_num_frames_per_resource()

        self.subset_indices = list(range(self.dataset_length))
        # self.labels_set, self.label2code, self.segmentation_labels, self.segmentation_label2code = self.get_labels_set()
        self.frame_lsets, self.frame_lcodes = self._get_labels_set(framed=True)
        self.image_lsets, self.image_lcodes = self._get_labels_set(framed=False)
        self.__logged_uint16_conversion = False
        if self.discard_without_annotations and self.return_frame_by_frame:
            # If we are returning frame by frame, we need to filter out frames without segmentations
            self._filter_unannotated()

    def _filter_items(self, images_metainfo: list[dict]) -> list[dict]:
        return [img for img in images_metainfo if len(img.get('annotations', []))]

    def _filter_unannotated(self):
        """Filter out frames that don't have any segmentations."""
        filtered_indices = []
        for i in range(len(self.subset_indices)):
            item_meta = self._get_image_metainfo(i)
            annotations = item_meta.get('annotations', [])

            # Check if there are any segmentation annotations
            has_segmentations = any(ann['type'] == 'segmentation' for ann in annotations)

            if has_segmentations:
                filtered_indices.append(self.subset_indices[i])

        self.subset_indices = filtered_indices
        _LOGGER.info(f"Filtered dataset: {len(self.subset_indices)} frames with segmentations")

    def __compute_num_frames_per_resource(self) -> list[int]:
        num_frames_per_dicom = []
        for imginfo in self.images_metainfo:
            filepath = os.path.join(self.dataset_dir, imginfo['file'])
            num_frames_per_dicom.append(self.read_number_of_frames(filepath))
        return num_frames_per_dicom

    @property
    def frame_labels_set(self) -> list[str]:
        """
        Returns the set of independent frame labels in the dataset.
        This is more related to multi-label tasks.
        """
        return self.frame_lsets['multilabel']

    @property
    def frame_categories_set(self) -> list[tuple[str, str]]:
        """
        Returns the set of frame categories in the dataset.
        This is more related to multi-class tasks.
        """
        return self.frame_lsets['multiclass']

    @property
    def image_labels_set(self) -> list[str]:
        """
        Returns the set of independent image labels in the dataset.
        This is more related to multi-label tasks.
        """
        return self.image_lsets['multilabel']

    @property
    def image_categories_set(self) -> list[tuple[str, str]]:
        """
        Returns the set of image categories in the dataset.
        This is more related to multi-class tasks.
        """
        return self.image_lsets['multiclass']

    @property
    def segmentation_labels_set(self) -> list[str]:
        """
        Returns the set of segmentation labels in the dataset.
        """
        return self.frame_lsets['segmentation']

    def _get_annotations_internal(self,
                                  annotations: list[dict],
                                  type: Literal['label', 'category', 'segmentation', 'all'] = 'all',
                                  scope: Literal['frame', 'image', 'all'] = 'all') -> list[dict]:
        # check parameters
        if type not in ['label', 'category', 'segmentation', 'all']:
            raise ValueError(f"Invalid value for 'type': {type}")
        if scope not in ['frame', 'image', 'all']:
            raise ValueError(f"Invalid value for 'scope': {scope}")

        annots = []
        for ann in annotations:
            ann_scope = 'image' if ann.get('index', None) is None else 'frame'
            if (type == 'all' or ann['type'] == type) and (scope == 'all' or scope == ann_scope):
                annots.append(ann)
        return annots

    def get_annotations(self,
                        index: int,
                        type: Literal['label', 'category', 'segmentation', 'all'] = 'all',
                        scope: Literal['frame', 'image', 'all'] = 'all') -> list[dict]:
        """
        Returns the annotations of the image at the given index.

        Args:
            index (int): Index of the image.
            type (str): The type of the annotations: 'label', 'category', 'segmentation' or 'all'.
            scope (str): The scope of the annotations: 'frame', 'image' or 'all'.

        Returns:
            list[dict]: The annotations of the image.
        """
        if index >= len(self):
            raise IndexError(f"Index {index} out of bounds for dataset of length {len(self)}")
        imginfo = self._get_image_metainfo(index)
        return self._get_annotations_internal(imginfo['annotations'], type=type, scope=scope)
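
    # For example (hypothetical annotation contents, not from the package):
    #   dataset.get_annotations(0, type='segmentation', scope='frame')
    # keeps only the per-frame segmentation dicts of item 0, while
    # type='label', scope='image' keeps the image-wide labels. The scope is
    # derived from the annotation's 'index' field: None means image-wide.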

    @staticmethod
    def read_number_of_frames(filepath: str) -> int:
        # DICOM file
        if is_dicom(filepath):
            ds = pydicom.dcmread(filepath)
            return ds.NumberOfFrames if hasattr(ds, 'NumberOfFrames') else 1
        # video file
        elif filepath.endswith('.mp4') or filepath.endswith('.avi'):
            cap = cv2.VideoCapture(filepath)
            try:
                return int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            finally:
                cap.release()
        # single image file
        elif filepath.endswith('.png') or filepath.endswith('.jpg') or filepath.endswith('.jpeg'):
            return 1
        else:
            raise ValueError(f"Unsupported file type: {filepath}")

    def get_resources_ids(self) -> list[str]:
        return [self.__getitem_internal(i, only_load_metainfo=True)['metainfo']['id'] for i in self.subset_indices]

    def _get_labels_set(self, framed: bool) -> tuple[dict, dict[str, dict[str, int]]]:
        """
        Returns the sets of labels and the dictionaries that map labels to integer codes.

        Returns:
            tuple[dict, dict]: The label sets (keys 'multilabel', 'segmentation', 'multiclass')
            and the corresponding label-to-code mappings.
        """
        scope = 'frame' if framed else 'image'

        multilabel_set = set()
        segmentation_labels = set()
        multiclass_set = set()

        for i in range(len(self)):
            anns = self.get_annotations(i, type='label', scope=scope)
            multilabel_set.update([ann['name'] for ann in anns])

            anns = self.get_annotations(i, type='segmentation', scope=scope)
            segmentation_labels.update([ann['name'] for ann in anns])

            anns = self.get_annotations(i, type='category', scope=scope)
            multiclass_set.update([(ann['name'], ann['value']) for ann in anns])

        multilabel_set = sorted(multilabel_set)
        multiclass_set = sorted(multiclass_set)
        segmentation_labels = sorted(segmentation_labels)

        multilabel2code = {label: idx for idx, label in enumerate(multilabel_set)}
        segmentation_label2code = {label: idx + 1 for idx, label in enumerate(segmentation_labels)}
        multiclass2code = {label: idx for idx, label in enumerate(multiclass_set)}

        sets = {'multilabel': multilabel_set,
                'segmentation': segmentation_labels,
                'multiclass': multiclass_set}
        codes_map = {'multilabel': multilabel2code,
                     'segmentation': segmentation_label2code,
                     'multiclass': multiclass2code}
        return sets, codes_map
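
    # Sketch of the returned structures, with made-up label names:
    #   sets      == {'multilabel': ['fracture'], 'segmentation': ['lung'],
    #                 'multiclass': [('view', 'AP')]}
    #   codes_map == {'multilabel': {'fracture': 0}, 'segmentation': {'lung': 1},
    #                 'multiclass': {('view', 'AP'): 0}}
    # Segmentation codes start at 1, presumably reserving 0 for the background.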

    def get_framelabel_distribution(self, normalize=False) -> dict[str, float]:
        """
        Returns the distribution of frame labels in the dataset.

        Returns:
            dict[str, float]: The number (or, if ``normalize`` is True, the fraction) of occurrences per label.
        """
        label_distribution = {label: 0 for label in self.frame_labels_set}
        for imginfo in self.images_metainfo:
            for ann in imginfo['annotations']:
                if ann['type'] == 'label' and ann['index'] is not None:
                    label_distribution[ann['name']] += 1

        if normalize:
            total = sum(label_distribution.values())
            if total == 0:
                return label_distribution
            label_distribution = {k: v / total for k, v in label_distribution.items()}
        return label_distribution

    def get_segmentationlabel_distribution(self, normalize=False) -> dict[str, float]:
        """
        Returns the distribution of segmentation labels in the dataset.

        Returns:
            dict[str, float]: The number (or, if ``normalize`` is True, the fraction) of occurrences per label.
        """
        label_distribution = {label: 0 for label in self.segmentation_labels_set}
        for imginfo in self.images_metainfo:
            if 'annotations' in imginfo and imginfo['annotations'] is not None:
                for ann in imginfo['annotations']:
                    if ann['type'] == 'segmentation':
                        label_distribution[ann['name']] += 1

        if normalize:
            total = sum(label_distribution.values())
            if total == 0:
                return label_distribution
            label_distribution = {k: v / total for k, v in label_distribution.items()}
        return label_distribution

    def _check_integrity(self):
        for imginfo in self.images_metainfo:
            if not os.path.isfile(os.path.join(self.dataset_dir, imginfo['file'])):
                raise DatamintDatasetException(f"Image file {imginfo['file']} not found.")

    def _get_datasetinfo(self) -> dict:
        all_datasets = self.api_handler.get_datasets()

        value_to_search = self.dataset_id
        field_to_search = 'id'

        for d in all_datasets:
            if d[field_to_search] == value_to_search:
                return d

        available_datasets = [(d['name'], d['id']) for d in all_datasets]
        raise DatamintDatasetException(
            f"Dataset with {field_to_search} '{value_to_search}' not found."
            f" Available datasets: {available_datasets}"
        )

    def get_info(self) -> dict:
        project = self.api_handler.get_project_by_name(self.project_name)
        if 'error' in project:
            available_projects = project['all_projects']
            raise DatamintDatasetException(
                f"Project with name '{self.project_name}' not found. Available projects: {available_projects}"
            )
        return project

    def _run_request(self, session, request_args) -> requests.Response:
        response = session.request(**request_args)
        if response.status_code == 400:
            _LOGGER.error(f"Bad request: {response.text}")
        response.raise_for_status()
        return response

    def _get_jwttoken(self, dataset_id, session) -> str:
        if dataset_id is None:
            raise ValueError("Dataset ID is required to download the dataset.")
        request_params = {
            'method': 'GET',
            'url': f'{self.server_url}/datasets/{dataset_id}/download/png',
            'headers': {'apikey': self.api_key},
            'stream': True
        }
        _LOGGER.debug(f"Getting jwt token for dataset {dataset_id}...")
        response = self._run_request(session, request_params)
        progress_bar = None
        number_processed_images = 0

        # check if the response is a stream of data and everything is ok
        if response.status_code != 200:
            msg = f"Getting jwt token failed with status code={response.status_code}: {response.text}"
            raise DatamintDatasetException(msg)

        try:
            response_iterator = response.iter_lines(decode_unicode=True)
            for line in response_iterator:
                line = line.strip()
                if 'event: error' in line:
                    error_msg = line + '\n'
                    error_msg += '\n'.join(response_iterator)
                    raise DatamintDatasetException(f"Getting jwt token failed:\n{error_msg}")
                if not line.startswith('data:'):
                    continue
                dataline = yaml.safe_load(line)['data']
                if 'zip' in dataline:
                    _LOGGER.debug(f"Got jwt token for dataset {dataset_id}")
                    return dataline['zip']  # Function normally ends here
                elif 'processedImages' in dataline:
                    if progress_bar is None:
                        total_size = int(dataline['totalImages'])
                        progress_bar = tqdm(total=total_size, unit='iB', unit_scale=True)
                    processed_images = int(dataline['processedImages'])
                    if number_processed_images < processed_images:
                        progress_bar.update(processed_images - number_processed_images)
                        number_processed_images = processed_images
                else:
                    _LOGGER.warning(f"Unknown data line: {dataline}")
        finally:
            if progress_bar is not None:
                progress_bar.close()

        raise DatamintDatasetException("Getting jwt token failed! No dataline with 'zip' entry found.")

    def __repr__(self) -> str:
        """
        Example:

        .. code-block:: python

            print(dataset)

        Output:

        .. code-block:: text

            Dataset DatamintDataset
                Number of datapoints: 3
                Root location: /home/user/.datamint/datasets

        """
        head = f"Dataset {self.project_name}"
        body = [f"Number of datapoints: {self.__len__()}"]
        if self.root is not None:
            body.append(f"Location: {self.dataset_dir}")

        # Add filter information to the representation
        if self.include_annotators is not None:
            body += [f"Including only annotators: {self.include_annotators}"]
        if self.exclude_annotators is not None:
            body += [f"Excluding annotators: {self.exclude_annotators}"]
        if self.include_segmentation_names is not None:
            body += [f"Including only segmentations: {self.include_segmentation_names}"]
        if self.exclude_segmentation_names is not None:
            body += [f"Excluding segmentations: {self.exclude_segmentation_names}"]
        if self.include_image_label_names is not None:
            body += [f"Including only image labels: {self.include_image_label_names}"]
        if self.exclude_image_label_names is not None:
            body += [f"Excluding image labels: {self.exclude_image_label_names}"]
        if self.include_frame_label_names is not None:
            body += [f"Including only frame labels: {self.include_frame_label_names}"]
        if self.exclude_frame_label_names is not None:
            body += [f"Excluding frame labels: {self.exclude_frame_label_names}"]

        lines = [head] + [" " * 4 + line for line in body]
        return "\n".join(lines)

    def download_project(self):
        from torchvision.datasets.utils import extract_archive

        dataset_info = self._get_datasetinfo()
        self.dataset_id = dataset_info['id']
        self.last_updated_at = dataset_info['updated_at']

        self.api_handler.download_project(self.project_info['id'],
                                          self.dataset_zippath,
                                          all_annotations=self.all_annotations,
                                          include_unannotated=self.include_unannotated)
        _LOGGER.debug("Downloaded dataset")
        downloaded_size = os.path.getsize(self.dataset_zippath)
        if downloaded_size == 0:
            raise DatamintDatasetException("Download failed.")

        if os.path.exists(self.dataset_dir):
            _LOGGER.info(f"Deleting existing dataset directory: {self.dataset_dir}")
            shutil.rmtree(self.dataset_dir)
        extract_archive(self.dataset_zippath,
                        self.dataset_dir,
                        remove_finished=True)
        datasetjson = os.path.join(self.dataset_dir, 'dataset.json')
        with open(datasetjson, 'r') as file:
            self.metainfo = json.load(file)
        if 'updated_at' not in self.metainfo:
            self.metainfo['updated_at'] = self.last_updated_at
        else:
            # if self.last_updated_at is newer than the one in the dataset, fix it
            try:
                if datetime.fromisoformat(self.metainfo['updated_at']) < datetime.fromisoformat(self.last_updated_at):
                    _LOGGER.warning(f"Inconsistent updated_at dates detected ({self.metainfo['updated_at']} < {self.last_updated_at})."
                                    f" Fixing it to {self.last_updated_at}")
                    self.metainfo['updated_at'] = self.last_updated_at
            except Exception as e:
                _LOGGER.warning(f"Failed to parse updated_at date: {e}")

        # Add all_annotations to the metadata
        self.metainfo['all_annotations'] = self.all_annotations

        # save the updated_at date
        with open(datasetjson, 'w') as file:
            json.dump(self.metainfo, file)

    def _load_image(self, filepath: str, index: Optional[int] = None) -> tuple[torch.Tensor, pydicom.FileDataset]:
        if os.path.isdir(filepath):
            raise NotImplementedError("Loading an image from a directory is not supported yet.")

        if self.return_frame_by_frame:
            img, ds = read_array_normalized(filepath, return_metainfo=True, index=index)
        else:
            img, ds = read_array_normalized(filepath, return_metainfo=True)

        if img.dtype == np.uint16:
            # PyTorch doesn't support uint16
            if not self.__logged_uint16_conversion:
                _LOGGER.info("Original image is uint16, converting to uint8")
                self.__logged_uint16_conversion = True

            # min-max normalization
            img = img.astype(np.float32)
            img = (img - img.min()) / (img.max() - img.min()) * 255
            img = img.astype(np.uint8)

        img = torch.from_numpy(img).contiguous()
        if isinstance(img, torch.ByteTensor):
            img = img.to(dtype=torch.get_default_dtype()).div(255)

        return img, ds

    def _get_image_metainfo(self, index: int, bypass_subset_indices=False) -> dict[str, Any]:
        if not bypass_subset_indices:
            index = self.subset_indices[index]
        if self.return_frame_by_frame:
            # Find the correct resource and frame index
            resource_id, frame_index = self.__find_index(index)

            img_metainfo = dict(self.images_metainfo[resource_id])  # copy
            # insert the frame index and keep only image-wide or matching-frame annotations
            img_metainfo['frame_index'] = frame_index
            img_metainfo['annotations'] = [ann for ann in img_metainfo['annotations']
                                           if ann['index'] is None or ann['index'] == frame_index]
        else:
            img_metainfo = self.images_metainfo[index]
        return img_metainfo

    def __find_index(self, index: int) -> tuple[int, int]:
        frame_index = index
        for i, num_frames in enumerate(self.num_frames_per_resource):
            if frame_index < num_frames:
                break
            frame_index -= num_frames
        else:
            raise IndexError(f"Index {index} out of bounds for dataset of length {len(self)}")

        return i, frame_index
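
    # Worked example of the flat-index mapping (made-up frame counts): with
    # num_frames_per_resource == [10, 5, 8], flat index 12 skips the first
    # resource (12 - 10 == 2) and lands in the second, so
    # __find_index(12) == (1, 2), i.e. resource 1, frame 2.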

    def __getitem_internal(self, index: int, only_load_metainfo=False) -> dict[str, Any]:
        if self.return_frame_by_frame:
            resource_index, frame_idx = self.__find_index(index)
        else:
            resource_index = index
            frame_idx = None
        img_metainfo = self._get_image_metainfo(index, bypass_subset_indices=True)

        if only_load_metainfo:
            return {'metainfo': img_metainfo}

        filepath = os.path.join(self.dataset_dir, img_metainfo['file'])

        # Can be multi-frame, gray-scale and/or RGB, so the shape is variable, but it is always a numpy array.
        img, ds = self._load_image(filepath, frame_idx)

        ret = {'image': img}

        if self.return_dicom:
            ret['dicom'] = ds
        if self.return_metainfo:
            ret['metainfo'] = {k: v for k, v in img_metainfo.items() if k != 'annotations'}
        if self.return_annotations:
            ret['annotations'] = img_metainfo['annotations']

        return ret

    def _filter_annotations(self, annotations: list[dict]) -> list[dict]:
        """
        Filter annotations based on the filtering settings.

        Args:
            annotations: list of annotations

        Returns:
            list[dict]: filtered list of annotations
        """
        if annotations is None:
            return []

        filtered_annotations = []
        for ann in annotations:
            # Filter by annotator
            if not self._should_include_annotator(ann['added_by']):
                continue

            # Filter by annotation type and name
            if ann['type'] == 'segmentation':
                if not self._should_include_segmentation(ann['name']):
                    continue
            elif ann['type'] == 'label':
                # Check if it's a frame or image label
                if ann.get('index', None) is None:
                    # Image label
                    if not self._should_include_image_label(ann['name']):
                        continue
                else:
                    # Frame label
                    if not self._should_include_frame_label(ann['name']):
                        continue

            # If we reach here, the annotation passed all filters
            filtered_annotations.append(ann)

        return filtered_annotations

    def __getitem__(self, index: int) -> dict[str, Any]:
        """
        Args:
            index (int): Index

        Returns:
            dict: A dictionary containing three keys: 'image', 'metainfo' and 'annotations'.
        """
        if index >= len(self):
            raise IndexError(f"Index {index} out of bounds for dataset of length {len(self)}")

        return self.__getitem_internal(self.subset_indices[index])

    def __iter__(self):
        for i in range(len(self)):
            yield self[i]

    def __len__(self) -> int:
        return len(self.subset_indices)

    def _check_version(self):
        metainfo_path = os.path.join(self.dataset_dir, 'dataset.json')
        if not os.path.exists(metainfo_path):
            self.download_project()
            return
        with open(metainfo_path, 'r') as file:
            local_dataset_info = json.load(file)
        local_updated_at = local_dataset_info.get('updated_at', None)
        local_all_annotations = local_dataset_info.get('all_annotations', None)

        try:
            external_metadata_info = self._get_datasetinfo()
            server_updated_at = external_metadata_info['updated_at']
        except Exception as e:
            _LOGGER.warning(f"Failed to check for updates in {self.project_name}: {e}")
            return

        _LOGGER.debug(f"Local updated at: {local_updated_at}, Server updated at: {server_updated_at}")

        # Check if all_annotations changed or doesn't exist
        annotations_changed = local_all_annotations != self.all_annotations

        if local_updated_at is None or local_updated_at < server_updated_at or annotations_changed:
            if annotations_changed:
                _LOGGER.info(
                    f"The 'all_annotations' parameter has changed. Previous: {local_all_annotations}, Current: {self.all_annotations}."
                )
            else:
                _LOGGER.info(
                    f"A newer version of the dataset is available. Your version: {local_updated_at}."
                    f" Latest version: {server_updated_at}."
                )
            self.download_project()
        else:
            _LOGGER.info('Local version is up to date with the latest version.')

    def __add__(self, other):
        from torch.utils.data import ConcatDataset
        return ConcatDataset([self, other])

    def get_dataloader(self, *args, **kwargs) -> DataLoader:
        """
        Returns a DataLoader for the dataset.
        This is a wrapper around the PyTorch DataLoader, with the convenience of a collate_fn
        that properly handles the different types of data in this dataset.

        Args:
            *args: Positional arguments for the DataLoader. See `torch.utils.data.DataLoader <https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader>`_ for details.
            **kwargs: Keyword arguments for the DataLoader. See `torch.utils.data.DataLoader <https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader>`_ for details.
        """
        return DataLoader(self,
                          *args,
                          collate_fn=self.get_collate_fn(),
                          **kwargs)

    def get_collate_fn(self) -> Callable:
        def collate_fn(batch: list[dict]) -> dict:
            keys = batch[0].keys()
            collated_batch = {}
            for key in keys:
                collated_batch[key] = [item[key] for item in batch]
                if isinstance(collated_batch[key][0], torch.Tensor):
                    # check if every tensor has the same shape
                    shapes = [tensor.shape for tensor in collated_batch[key]]
                    if all(shape == shapes[0] for shape in shapes):
                        collated_batch[key] = torch.stack(collated_batch[key])
                    else:
                        _LOGGER.warning(f"Collating {key} tensors with different shapes: {shapes}.")
                elif isinstance(collated_batch[key][0], np.ndarray):
                    collated_batch[key] = np.stack(collated_batch[key])

            return collated_batch

        return collate_fn

    def subset(self, indices: list[int]) -> 'DatamintBaseDataset':
        # validate the requested indices against the full dataset length
        if len(indices) > self.dataset_length:
            raise ValueError(f"Subset indices must be less than the dataset length: {self.dataset_length}")

        self.subset_indices = indices

        return self

    def _should_include_annotator(self, annotator_id: str) -> bool:
        """
        Check if an annotator should be included based on the filtering settings.

        Args:
            annotator_id: The ID of the annotator to check

        Returns:
            bool: True if the annotator should be included, False otherwise
        """
        if self.include_annotators is not None:
            return annotator_id in self.include_annotators
        if self.exclude_annotators is not None:
            return annotator_id not in self.exclude_annotators
        return True

    def _should_include_segmentation(self, segmentation_name: str) -> bool:
        """
        Check if a segmentation should be included based on the filtering settings.

        Args:
            segmentation_name: The name of the segmentation to check

        Returns:
            bool: True if the segmentation should be included, False otherwise
        """
        if self.include_segmentation_names is not None:
            return segmentation_name in self.include_segmentation_names
        if self.exclude_segmentation_names is not None:
            return segmentation_name not in self.exclude_segmentation_names
        return True

    def _should_include_image_label(self, label_name: str) -> bool:
        """
        Check if an image label should be included based on the filtering settings.

        Args:
            label_name: The name of the image label to check

        Returns:
            bool: True if the image label should be included, False otherwise
        """
        if self.include_image_label_names is not None:
            return label_name in self.include_image_label_names
        if self.exclude_image_label_names is not None:
            return label_name not in self.exclude_image_label_names
        return True

    def _should_include_frame_label(self, label_name: str) -> bool:
        """
        Check if a frame label should be included based on the filtering settings.

        Args:
            label_name: The name of the frame label to check

        Returns:
            bool: True if the frame label should be included, False otherwise
        """
        if self.include_frame_label_names is not None:
            return label_name in self.include_frame_label_names
        if self.exclude_frame_label_names is not None:
            return label_name not in self.exclude_frame_label_names
        return True
```
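
Taken together, typical usage of this class looks roughly like the sketch below; the project name, filter values and label keys are hypothetical stand-ins, not values shipped with the package:

```python
from datamintapi.dataset.base_dataset import DatamintBaseDataset

# Hypothetical project name and segmentation filter; substitute your own.
dataset = DatamintBaseDataset(project_name='chest-xray-project',
                              return_frame_by_frame=True,
                              include_segmentation_names=['lung'])

print(dataset)  # repr shows the name, datapoint count, location and active filters

# Batch through the provided DataLoader wrapper: its collate_fn stacks
# same-shaped tensors and leaves ragged ones as plain lists.
loader = dataset.get_dataloader(batch_size=4, shuffle=True)
for batch in loader:
    images = batch['image']             # stacked tensor or list of tensors
    annotations = batch['annotations']  # one annotation list per item
    break
```

Note that `subset()` mutates the dataset in place and returns `self`, and `__add__` delegates to `torch.utils.data.ConcatDataset`, so two datasets can be combined with `dataset_a + dataset_b`.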