datamint 1.6.3.post1__py3-none-any.whl → 1.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of datamint might be problematic.
- datamint/apihandler/annotation_api_handler.py +125 -3
- datamint/apihandler/base_api_handler.py +30 -26
- datamint/apihandler/root_api_handler.py +160 -36
- datamint/dataset/annotation.py +221 -0
- datamint/dataset/base_dataset.py +735 -483
- datamint/dataset/dataset.py +33 -16
- {datamint-1.6.3.post1.dist-info → datamint-1.7.1.dist-info}/METADATA +1 -1
- {datamint-1.6.3.post1.dist-info → datamint-1.7.1.dist-info}/RECORD +10 -9
- {datamint-1.6.3.post1.dist-info → datamint-1.7.1.dist-info}/WHEEL +0 -0
- {datamint-1.6.3.post1.dist-info → datamint-1.7.1.dist-info}/entry_points.txt +0 -0
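The largest change is the refactor of datamint/dataset/base_dataset.py, shown below. Judging from the new `__init__` signature in the diff, version 1.7.1 might be used along these lines (a minimal sketch; the project name, filter value, and variable names are illustrative placeholders, not taken from datamint's documentation):

    from datamint.dataset.base_dataset import DatamintBaseDataset

    # 'project_name' is now the first positional parameter; 'root' moved to second.
    dataset = DatamintBaseDataset(
        project_name="my-project",            # hypothetical project name
        return_frame_by_frame=True,           # one item per video/DICOM frame
        include_segmentation_names=["lung"],  # hypothetical label filter
    )
    for item in dataset:
        image, metainfo = item['image'], item['metainfo']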
datamint/dataset/base_dataset.py CHANGED

@@ -15,9 +15,12 @@ import torch
 from torch import Tensor
 from datamint.apihandler.base_api_handler import DatamintException
 from medimgkit.dicom_utils import is_dicom
-import cv2
 from medimgkit.io_utils import read_array_normalized
 from datetime import datetime
+from pathlib import Path
+from mimetypes import guess_extension
+from datamint.dataset.annotation import Annotation
+import cv2

 _LOGGER = logging.getLogger(__name__)

@@ -27,12 +30,11 @@ class DatamintDatasetException(DatamintException):


 class DatamintBaseDataset:
-    """
-    Class to download and load datasets from the Datamint API.
+    """Class to download and load datasets from the Datamint API.

     Args:
-        root: Root directory of dataset where data already exists or will be downloaded.
         project_name: Name of the project to download.
+        root: Root directory of dataset where data already exists or will be downloaded.
         auto_update: If True, the dataset will be checked for updates and downloaded if necessary.
         api_key: API key to access the Datamint API. If not provided, it will look for the
             environment variable 'DATAMINT_API_KEY'. Not necessary if
@@ -41,68 +43,114 @@ class DatamintBaseDataset:
         return_metainfo: If True, the metainfo of the image will be returned.
         return_annotations: If True, the annotations of the image will be returned.
         return_frame_by_frame: If True, each frame of a video/DICOM/3d-image will be returned separately.
-        include_unannotated: If True, images without annotations will be included.
+        include_unannotated: If True, images without annotations will be included.
         all_annotations: If True, all annotations will be downloaded, including the ones that are not set as closed/done.
         server_url: URL of the Datamint server. If not provided, it will use the default server.
-        include_annotators: List of annotators to include. If None, all annotators will be included.
-        exclude_annotators: List of annotators to exclude. If None, no annotators will be excluded.
+        include_annotators: List of annotators to include. If None, all annotators will be included.
+        exclude_annotators: List of annotators to exclude. If None, no annotators will be excluded.
         include_segmentation_names: List of segmentation names to include. If None, all segmentations will be included.
         exclude_segmentation_names: List of segmentation names to exclude. If None, no segmentations will be excluded.
         include_image_label_names: List of image label names to include. If None, all image labels will be included.
         exclude_image_label_names: List of image label names to exclude. If None, no image labels will be excluded.
         include_frame_label_names: List of frame label names to include. If None, all frame labels will be included.
         exclude_frame_label_names: List of frame label names to exclude. If None, no frame labels will be excluded.
-
     """

     DATAMINT_DEFAULT_DIR = ".datamint"
     DATAMINT_DATASETS_DIR = "datasets"

-    def __init__(
+    def __init__(
+            self,
+            project_name: str,
+            root: str | None = None,
+            auto_update: bool = True,
+            api_key: Optional[str] = None,
+            server_url: Optional[str] = None,
+            return_dicom: bool = False,
+            return_metainfo: bool = True,
+            return_annotations: bool = True,
+            return_frame_by_frame: bool = False,
+            include_unannotated: bool = True,
+            all_annotations: bool = False,
+            # Filtering parameters
+            include_annotators: Optional[list[str]] = None,
+            exclude_annotators: Optional[list[str]] = None,
+            include_segmentation_names: Optional[list[str]] = None,
+            exclude_segmentation_names: Optional[list[str]] = None,
+            include_image_label_names: Optional[list[str]] = None,
+            exclude_image_label_names: Optional[list[str]] = None,
+            include_frame_label_names: Optional[list[str]] = None,
+            exclude_frame_label_names: Optional[list[str]] = None,
+    ):
+        self._validate_inputs(project_name, include_annotators, exclude_annotators,
+                              include_segmentation_names, exclude_segmentation_names,
+                              include_image_label_names, exclude_image_label_names,
+                              include_frame_label_names, exclude_frame_label_names)
+
+        self._initialize_config(
+            project_name, auto_update, all_annotations, return_dicom,
+            return_metainfo, return_annotations, return_frame_by_frame,
+            include_unannotated, include_annotators, exclude_annotators,
+            include_segmentation_names, exclude_segmentation_names,
+            include_image_label_names, exclude_image_label_names,
+            include_frame_label_names, exclude_frame_label_names
+        )

+        self._setup_api_handler(server_url, api_key, auto_update)
+        self._setup_directories(root)
+        self._setup_dataset()
+        self._post_process_data()
+
+    def _validate_inputs(
+            self,
+            project_name: str,
+            include_annotators: Optional[list[str]],
+            exclude_annotators: Optional[list[str]],
+            include_segmentation_names: Optional[list[str]],
+            exclude_segmentation_names: Optional[list[str]],
+            include_image_label_names: Optional[list[str]],
+            exclude_image_label_names: Optional[list[str]],
+            include_frame_label_names: Optional[list[str]],
+            exclude_frame_label_names: Optional[list[str]],
+    ) -> None:
+        """Validate input parameters."""
         if project_name is None:
             raise ValueError("project_name is required.")

+        # Validate mutually exclusive filtering parameters
+        filter_pairs = [
+            (include_annotators, exclude_annotators, "annotators"),
+            (include_segmentation_names, exclude_segmentation_names, "segmentation_names"),
+            (include_image_label_names, exclude_image_label_names, "image_label_names"),
+            (include_frame_label_names, exclude_frame_label_names, "frame_label_names"),
+        ]
+
+        for include_param, exclude_param, param_name in filter_pairs:
+            if include_param is not None and exclude_param is not None:
+                raise ValueError(f"Cannot set both include_{param_name} and exclude_{param_name} at the same time")
+
+    def _initialize_config(
+            self,
+            project_name: str,
+            auto_update: bool,
+            all_annotations: bool,
+            return_dicom: bool,
+            return_metainfo: bool,
+            return_annotations: bool,
+            return_frame_by_frame: bool,
+            include_unannotated: bool,
+            include_annotators: Optional[list[str]],
+            exclude_annotators: Optional[list[str]],
+            include_segmentation_names: Optional[list[str]],
+            exclude_segmentation_names: Optional[list[str]],
+            include_image_label_names: Optional[list[str]],
+            exclude_image_label_names: Optional[list[str]],
+            include_frame_label_names: Optional[list[str]],
+            exclude_frame_label_names: Optional[list[str]],
+    ) -> None:
+        """Initialize configuration parameters."""
+        self.project_name = project_name
         self.all_annotations = all_annotations
-        self.api_handler = APIHandler(root_url=server_url, api_key=api_key,
-                                      check_connection=auto_update)
-        self.server_url = self.api_handler.root_url
-        if root is None:
-            # store them in the home directory
-            root = os.path.join(os.path.expanduser("~"),
-                                DatamintBaseDataset.DATAMINT_DEFAULT_DIR)
-            root = os.path.join(root, DatamintBaseDataset.DATAMINT_DATASETS_DIR)
-            if not os.path.exists(root):
-                os.makedirs(root)
-        elif isinstance(root, str):
-            root = os.path.expanduser(root)
-        if not os.path.isdir(root):
-            raise NotADirectoryError(f"Root directory not found: {root}")
-
-        self.root = root
-
         self.return_dicom = return_dicom
         self.return_metainfo = return_metainfo
         self.return_frame_by_frame = return_frame_by_frame
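The hunk above replaces the old ad-hoc include/exclude checks (removed in the next hunk) with a single `filter_pairs` loop that also covers annotators. A sketch of the rule it enforces, with placeholder values:

    # Setting both sides of any include_*/exclude_* pair now raises uniformly:
    DatamintBaseDataset(
        project_name="my-project",
        include_annotators=["alice@example.com"],   # placeholder annotator ids
        exclude_annotators=["bob@example.com"],
    )
    # ValueError: Cannot set both include_annotators and exclude_annotators at the same time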
@@ -120,108 +168,153 @@ class DatamintBaseDataset:
         self.include_frame_label_names = include_frame_label_names
         self.exclude_frame_label_names = exclude_frame_label_names

-        #
-        if include_segmentation_names is not None and exclude_segmentation_names is not None:
-            raise ValueError("Cannot set both include_segmentation_names and exclude_segmentation_names at the same time")
-        if include_image_label_names is not None and exclude_image_label_names is not None:
-            raise ValueError("Cannot set both include_image_label_names and exclude_image_label_names at the same time")
-        if include_frame_label_names is not None and exclude_frame_label_names is not None:
-            raise ValueError("Cannot set both include_frame_label_names and exclude_frame_label_names at the same time")
+        # Internal state
+        self.__logged_uint16_conversion = False
+        self.auto_update = auto_update

+    def _setup_api_handler(self, server_url: Optional[str], api_key: Optional[str], auto_update: bool) -> None:
+        """Setup API handler and validate connection."""
+        from datamint.apihandler.api_handler import APIHandler

-        self.
+        self.api_handler = APIHandler(
+            root_url=server_url,
+            api_key=api_key,
+            check_connection=auto_update
+        )
+        self.server_url = self.api_handler.root_url
+        self.api_key = self.api_handler.api_key

+        if self.api_key is None:
+            _LOGGER.warning(
+                "API key not provided. If you want to download data, please provide an API key, "
+                f"either by passing it as an argument, "
+                f"setting environment variable {configs.ENV_VARS[configs.APIKEY_KEY]} or "
+                "using datamint-config command line tool."
+            )

+    def _setup_directories(self, root: str | None) -> None:
+        """Setup root and dataset directories."""
+        if root is None:
+            root = os.path.join(
+                os.path.expanduser("~"),
+                self.DATAMINT_DEFAULT_DIR,
+                self.DATAMINT_DATASETS_DIR
+            )
+            os.makedirs(root, exist_ok=True)
         else:
+            root = os.path.expanduser(root)
+            if not os.path.isdir(root):
+                raise NotADirectoryError(f"Root directory not found: {root}")

-        self.
+        self.root = root
+        self.dataset_dir = os.path.join(root, self.project_name)
+        self.dataset_zippath = os.path.join(root, f'{self.project_name}.zip')
+
+        if not os.path.exists(self.dataset_dir):
+            os.makedirs(self.dataset_dir, exist_ok=True)
+            os.makedirs(os.path.join(self.dataset_dir, 'images'), exist_ok=True)
+            os.makedirs(os.path.join(self.dataset_dir, 'masks'), exist_ok=True)
+
+    def _setup_dataset(self) -> None:
+        """Setup dataset by downloading or loading existing data."""
+        self._server_dataset_info = None
+        local_load_success = self._load_metadata()
+        self._handle_dataset_download_or_update(local_load_success)
+        self._apply_annotation_filters()
+
+    def _handle_dataset_download_or_update(self, local_load_success: bool) -> None:
+        """Handle dataset download or update logic."""
+
+        if local_load_success:
+            _LOGGER.debug(f"Dataset directory already exists: {self.dataset_dir}")
+            # Check for updates if auto_update is enabled and we have API access
+            if self.auto_update:
                 _LOGGER.info("Checking for updates...")
                 self._check_version()
         else:
+            self._check_version()
+
+    def _load_metadata(self) -> bool:
+        """Load and process dataset metadata."""
+        if hasattr(self, 'metainfo'):
+            _LOGGER.warning("Metadata already loaded.")
+        metadata_path = os.path.join(self.dataset_dir, 'dataset.json')
+        if not os.path.isfile(metadata_path):
+            # get the server info
+            self.project_info = self.get_info()
+            self.metainfo = self._get_datasetinfo().copy()
+            self.metainfo['updated_at'] = None
+            self.metainfo['resources'] = []
+            self.metainfo['all_annotations'] = self.all_annotations
+            self.images_metainfo = self.metainfo['resources']
+            return False
+        else:
+            with open(metadata_path, 'r') as file:
                 self.metainfo = json.load(file)
             self.images_metainfo = self.metainfo['resources']
+            # Convert annotations from dict to Annotation objects
+            self._convert_metainfo_to_clsobj()
+            return True

+    def _convert_metainfo_to_clsobj(self):
+        for imginfo in self.images_metainfo:
+            if 'annotations' in imginfo:
+                for ann in imginfo['annotations']:
+                    if 'resource_id' not in ann:
+                        ann['resource_id'] = imginfo['id']
+                    if 'id' not in ann:
+                        ann['id'] = None
+                imginfo['annotations'] = [Annotation.from_dict(ann) if isinstance(ann, dict) else ann
+                                          for ann in imginfo['annotations']]
+
+    def _apply_annotation_filters(self) -> None:
+        """Apply annotation filters and remove unannotated images if needed."""
+        # Filter annotations for each image
         for imginfo in self.images_metainfo:
             imginfo['annotations'] = self._filter_annotations(imginfo['annotations'])

-        #
+        # Filter out images with no annotations if needed
         if self.discard_without_annotations:
             original_count = len(self.images_metainfo)
             self.images_metainfo = self._filter_items(self.images_metainfo)
             _LOGGER.info(f"Discarded {original_count - len(self.images_metainfo)} images without annotations.")

+    def _post_process_data(self) -> None:
+        """Post-process data after loading metadata."""
         self._check_integrity()
+        self._calculate_dataset_length()
+        self._precompute_frame_data()
+        self._setup_labels()

-        # for imginfo in self.images_metainfo:
-        #     if imginfo['frame_labels'] is not None:
-        #         for flabels in imginfo['frame_labels']:
-        #             if flabels['label'] is None:
-        #                 flabels['label'] = []
-        #             elif isinstance(flabels['label'], str):
-        #                 flabels['label'] = flabels['label'].split(',')
+        if self.discard_without_annotations and self.return_frame_by_frame:
+            self._filter_unannotated()

+    def _calculate_dataset_length(self) -> None:
+        """Calculate the total dataset length based on frame-by-frame setting."""
         if self.return_frame_by_frame:
-            self.dataset_length =
+            self.dataset_length = sum(
+                self.read_number_of_frames(os.path.join(self.dataset_dir, imginfo['file']))
+                for imginfo in self.images_metainfo
+            )
         else:
             self.dataset_length = len(self.images_metainfo)

+    def _precompute_frame_data(self) -> None:
+        """Precompute frame-related data for efficient indexing."""
         self.num_frames_per_resource = self.__compute_num_frames_per_resource()
-
-        # Precompute cumulative frame counts for faster index lookup
         self._cumulative_frames = np.cumsum([0] + self.num_frames_per_resource)
-
         self.subset_indices = list(range(self.dataset_length))

+    def _setup_labels(self) -> None:
+        """Setup label sets and mappings."""
         self.frame_lsets, self.frame_lcodes = self._get_labels_set(framed=True)
         self.image_lsets, self.image_lcodes = self._get_labels_set(framed=False)
-        self.__logged_uint16_conversion = False
-        if self.discard_without_annotations and self.return_frame_by_frame:
-            # If we are returning frame by frame, we need to filter out frames without segmentations
-            self._filter_unannotated()

     def _filter_items(self, images_metainfo: list[dict]) -> list[dict]:
+        """Filter items that have annotations."""
         return [img for img in images_metainfo if len(img.get('annotations', []))]

-    def _filter_unannotated(self):
+    def _filter_unannotated(self) -> None:
         """Filter out frames that don't have any segmentations."""
         filtered_indices = []
         for i in range(len(self.subset_indices)):
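`_setup_directories` in the hunk above now pre-creates the on-disk layout in one place. Based on the constants and path joins shown, the resulting tree should look roughly like this (the project name is a placeholder):

    ~/.datamint/datasets/              # default root: DATAMINT_DEFAULT_DIR / DATAMINT_DATASETS_DIR
    ├── my-project.zip                 # dataset_zippath (download target)
    └── my-project/                    # dataset_dir
        ├── dataset.json               # metainfo; 'resources' carries per-image annotations
        ├── images/                    # resource files
        └── masks/                     # segmentation files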
@@ -229,124 +322,125 @@ class DatamintBaseDataset:
             annotations = item_meta.get('annotations', [])

             # Check if there are any segmentation annotations
-            has_segmentations = any(ann
+            has_segmentations = any(ann.type == 'segmentation' for ann in annotations)

             if has_segmentations:
                 filtered_indices.append(self.subset_indices[i])

         self.subset_indices = filtered_indices
-
+        _LOGGER.debug(f"Filtered dataset: {len(self.subset_indices)} frames with segmentations")

     def __compute_num_frames_per_resource(self) -> list[int]:
+        """Compute number of frames for each resource."""
+        return [
+            self.read_number_of_frames(os.path.join(self.dataset_dir, imginfo['file']))
+            for imginfo in self.images_metainfo
+        ]

     @property
     def frame_labels_set(self) -> list[str]:
-        """
-        Returns the set of independent labels in the dataset.
-        This is more related to multi-label tasks.
-        """
+        """Returns the set of independent labels in the dataset (multi-label tasks)."""
         return self.frame_lsets['multilabel']

     @property
     def frame_categories_set(self) -> list[tuple[str, str]]:
-        """
-        Returns the set of categories in the dataset.
-        This is more related to multi-class tasks.
-        """
+        """Returns the set of categories in the dataset (multi-class tasks)."""
         return self.frame_lsets['multiclass']

     @property
     def image_labels_set(self) -> list[str]:
-        """
-        Returns the set of independent labels in the dataset.
-        This is more related to multi-label tasks.
-        """
+        """Returns the set of independent labels in the dataset (multi-label tasks)."""
         return self.image_lsets['multilabel']

     @property
     def image_categories_set(self) -> list[tuple[str, str]]:
-        """
-        Returns the set of categories in the dataset.
-        This is more related to multi-class tasks.
-        """
+        """Returns the set of categories in the dataset (multi-class tasks)."""
         return self.image_lsets['multiclass']

     @property
     def segmentation_labels_set(self) -> list[str]:
-        """
-        Returns the set of segmentation labels in the dataset.
-        """
+        """Returns the set of segmentation labels in the dataset."""
         return self.frame_lsets['segmentation']

-    def _get_annotations_internal(
+    def _get_annotations_internal(
+            self,
+            annotations: list[Annotation],
+            type: Literal['label', 'category', 'segmentation', 'all'] = 'all',
+            scope: Literal['frame', 'image', 'all'] = 'all'
+    ) -> list[Annotation]:
+        """Internal method to filter annotations by type and scope."""
         if type not in ['label', 'category', 'segmentation', 'all']:
             raise ValueError(f"Invalid value for 'type': {type}")
         if scope not in ['frame', 'image', 'all']:
             raise ValueError(f"Invalid value for 'scope': {scope}")

+        filtered_annotations = []
         for ann in annotations:
-            ann_scope = 'image' if ann.
+            ann_scope = 'image' if ann.index is None else 'frame'
+
+            type_matches = type == 'all' or ann.type == type
+            scope_matches = scope == 'all' or scope == ann_scope
+
+            if type_matches and scope_matches:
+                filtered_annotations.append(ann)
+
+        return filtered_annotations
+
+    def get_annotations(
+            self,
+            index: int,
+            type: Literal['label', 'category', 'segmentation', 'all'] = 'all',
+            scope: Literal['frame', 'image', 'all'] = 'all'
+    ) -> list[Annotation]:
+        """Returns the annotations of the image at the given index.

         Args:
-            index
-            type
-            scope
+            index: Index of the image.
+            type: The type of the annotations. Can be 'label', 'category', 'segmentation' or 'all'.
+            scope: The scope of the annotations. Can be 'frame', 'image' or 'all'.

         Returns:
+            The annotations of the image.
         """
         if index >= len(self):
             raise IndexError(f"Index {index} out of bounds for dataset of length {len(self)}")
+
         imginfo = self._get_image_metainfo(index)
         return self._get_annotations_internal(imginfo['annotations'], type=type, scope=scope)

     @staticmethod
     def read_number_of_frames(filepath: str) -> int:
+        """Read the number of frames in a file."""
         if is_dicom(filepath):
             ds = pydicom.dcmread(filepath)
-            return
-        elif filepath.endswith('.mp4') or filepath.endswith('.avi'):
+            return getattr(ds, 'NumberOfFrames', 1)
+        elif filepath.lower().endswith(('.mp4', '.avi')):
             cap = cv2.VideoCapture(filepath)
+            try:
+                return int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+            finally:
+                cap.release()
+        elif filepath.lower().endswith(('.png', '.jpg', '.jpeg')):
             return 1
         else:
             raise ValueError(f"Unsupported file type: {filepath}")

     def get_resources_ids(self) -> list[str]:
+        """Get list of resource IDs."""
+        return [
+            self.__getitem_internal(i, only_load_metainfo=True)['metainfo']['id']
+            for i in self.subset_indices
+        ]

     def _get_labels_set(self, framed: bool) -> tuple[dict, dict[str, dict[str, int]]]:
-        """
+        """Returns the set of labels and mappings to integers.
+
+        Args:
+            framed: If True, get frame-level labels, otherwise image-level labels.

         Returns:
-            Tuple
+            Tuple containing label sets and label-to-code mappings.
         """
         scope = 'frame' if framed else 'image'

         multilabel_set = set()
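With annotations now typed as `Annotation` objects, the filtering API in the hunk above can be exercised like this (a sketch; the index and printed values are illustrative):

    anns = dataset.get_annotations(0, type='segmentation', scope='frame')
    for ann in anns:
        # Annotation objects expose .type, .name and .index (None means image-level)
        print(ann.name, ann.index)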
@@ -354,100 +448,113 @@ class DatamintBaseDataset:
         multiclass_set = set()

         for i in range(len(self)):
+            # Collect labels by type
+            label_anns = self.get_annotations(i, type='label', scope=scope)
+            multilabel_set.update(ann.name for ann in label_anns)

-            segmentation_labels.update(
+            seg_anns = self.get_annotations(i, type='segmentation', scope=scope)
+            segmentation_labels.update(ann.name for ann in seg_anns)

-            multiclass_set.update(
+            cat_anns = self.get_annotations(i, type='category', scope=scope)
+            multiclass_set.update((ann.name, ann.value) for ann in cat_anns)

+        # Sort and create mappings
+        multilabel_list = sorted(multilabel_set)
+        multiclass_list = sorted(multiclass_set)
+        segmentation_list = sorted(segmentation_labels)

+        sets = {
+            'multilabel': multilabel_list,
+            'segmentation': segmentation_list,
+            'multiclass': multiclass_list
+        }
+
+        codes_map = {
+            'multilabel': {label: idx for idx, label in enumerate(multilabel_list)},
+            'segmentation': {label: idx + 1 for idx, label in enumerate(segmentation_list)},
+            'multiclass': {label: idx for idx, label in enumerate(multiclass_list)}
+        }

-        sets = {'multilabel': multilabel_set,
-                'segmentation': segmentation_labels,
-                'multiclass': multiclass_set}
-        codes_map = {'multilabel': multilabel2code,
-                     'segmentation': segmentation_label2code,
-                     'multiclass': multiclass2code}
         return sets, codes_map

-    def get_framelabel_distribution(self, normalize=False) -> dict[str, float]:
-        """
+    def get_framelabel_distribution(self, normalize: bool = False) -> dict[str, float]:
+        """Returns the distribution of frame labels in the dataset."""
+        return self._get_label_distribution('label', 'frame', normalize)

-        label_distribution = {label: 0 for label in self.frame_labels_set}
-        for imginfo in self.images_metainfo:
-            for ann in imginfo['annotations']:
-                if ann['type'] == 'label' and ann['index'] is not None:
-                    label_distribution[ann['name']] += 1
+    def get_segmentationlabel_distribution(self, normalize: bool = False) -> dict[str, float]:
+        """Returns the distribution of segmentation labels in the dataset."""
+        return self._get_label_distribution('segmentation', 'all', normalize)

+    def _get_label_distribution(self, ann_type: str, scope: str, normalize: bool) -> dict[str, float]:
+        """Helper method to calculate label distributions."""
+        if ann_type == 'label' and scope == 'frame':
+            labels = self.frame_labels_set
+        elif ann_type == 'segmentation':
+            labels = self.segmentation_labels_set
+        else:
+            raise ValueError(f"Unsupported combination: type={ann_type}, scope={scope}")

-        """
-        Returns the distribution of segmentation labels in the dataset.
+        distribution = {label: 0 for label in labels}

-        Returns:
-            Dict[str, int]: The distribution of segmentation labels in the dataset.
-        """
-        label_distribution = {label: 0 for label in self.segmentation_labels_set}
         for imginfo in self.images_metainfo:
+            for ann in imginfo.get('annotations', []):
+                condition_met = (
+                    ann.type == ann_type and
+                    (scope == 'all' or
+                     (scope == 'frame' and ann.index is not None) or
+                     (scope == 'image' and ann.index is None))
+                )
+                if condition_met and ann.name in distribution:
+                    distribution[ann.name] += 1

         if normalize:
-            total = sum(
-            if total
-            label_distribution = {k: v/total for k, v in label_distribution.items()}
-        return label_distribution
+            total = sum(distribution.values())
+            if total > 0:
+                distribution = {k: v / total for k, v in distribution.items()}

+        return distribution
+
+    def _check_integrity(self) -> None:
+        """Check if all image files exist."""
+        missing_files = []
         for imginfo in self.images_metainfo:
+            filepath = os.path.join(self.dataset_dir, imginfo['file'])
+            if not os.path.isfile(filepath):
+                missing_files.append(imginfo['file'])
+
+        if missing_files:
+            raise DatamintDatasetException(f"Image files not found: {missing_files}")

     def _get_datasetinfo(self) -> dict:
+        """Get dataset information from API."""
+        if self._server_dataset_info is not None:
+            return self._server_dataset_info
         all_datasets = self.api_handler.get_datasets()

-            if d[field_to_search] == value_to_search:
-                return d
+        for dataset in all_datasets:
+            if dataset['id'] == self.dataset_id:
+                self._server_dataset_info = dataset
+                return dataset

         available_datasets = [(d['name'], d['id']) for d in all_datasets]
         raise DatamintDatasetException(
-            f"Dataset with
-            f"
+            f"Dataset with id '{self.dataset_id}' not found. "
+            f"Available datasets: {available_datasets}"
         )

     def get_info(self) -> dict:
+        """Get project information from API."""
+        if hasattr(self, 'project_info') and self.project_info is not None:
+            return self.project_info
         project = self.api_handler.get_project_by_name(self.project_name)
         if 'error' in project:
             available_projects = project['all_projects']
             raise DatamintDatasetException(
-                f"Project with name '{self.project_name}' not found.
+                f"Project with name '{self.project_name}' not found. "
+                f"Available projects: {available_projects}"
             )
+        self.project_info = project
+        self.dataset_id = project['dataset_id']
         return project

     def _run_request(self, session, request_args) -> requests.Response:
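The rebuilt `_get_labels_set` above sorts each label set before numbering, so the code maps are deterministic across runs; segmentation codes start at 1, presumably leaving 0 free for a background class. A sketch with made-up labels (note the method is internal):

    sets, codes = dataset._get_labels_set(framed=True)
    # sets['segmentation']  -> ['lung', 'tumor']              (sorted)
    # codes['segmentation'] -> {'lung': 1, 'tumor': 2}        (idx + 1)
    # codes['multilabel']   -> {'fracture': 0, 'implant': 1}  (idx)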
@@ -457,216 +564,191 @@ class DatamintBaseDataset:
         response.raise_for_status()
         return response

-    def _get_jwttoken(self, dataset_id, session) -> str:
-        if dataset_id is None:
-            raise ValueError("Dataset ID is required to download the dataset.")
-        request_params = {
-            'method': 'GET',
-            'url': f'{self.server_url}/datasets/{dataset_id}/download/png',
-            'headers': {'apikey': self.api_key},
-            'stream': True
-        }
-        _LOGGER.debug(f"Getting jwt token for dataset {dataset_id}...")
-        response = self._run_request(session, request_params)
-        progress_bar = None
-        number_processed_images = 0
-
-        # check if the response is a stream of data and everything is ok
-        if response.status_code != 200:
-            msg = f"Getting jwt token failed with status code={response.status_code}: {response.text}"
-            raise DatamintDatasetException(msg)
-
-        try:
-            response_iterator = response.iter_lines(decode_unicode=True)
-            for line in response_iterator:
-                line = line.strip()
-                if 'event: error' in line:
-                    error_msg = line+'\n'
-                    error_msg += '\n'.join(response_iterator)
-                    raise DatamintDatasetException(f"Getting jwt token failed:\n{error_msg}")
-                if not line.startswith('data:'):
-                    continue
-                dataline = yaml.safe_load(line)['data']
-                if 'zip' in dataline:
-                    _LOGGER.debug(f"Got jwt token for dataset {dataset_id}")
-                    return dataline['zip']  # Function normally ends here
-                elif 'processedImages' in dataline:
-                    if progress_bar is None:
-                        total_size = int(dataline['totalImages'])
-                        progress_bar = tqdm(total=total_size, unit='iB', unit_scale=True)
-                    processed_images = int(dataline['processedImages'])
-                    if number_processed_images < processed_images:
-                        progress_bar.update(processed_images - number_processed_images)
-                        number_processed_images = processed_images
-                else:
-                    _LOGGER.warning(f"Unknown data line: {dataline}")
-        except Exception as e:
-            raise e
-        finally:
-            if progress_bar is not None:
-                progress_bar.close()
-
-        raise DatamintDatasetException("Getting jwt token failed! No dataline with 'zip' entry found.")
-
     def __repr__(self) -> str:
-        """
-        Example:
-        .. code-block:: python
-
-            print(dataset)
-
-        Output:
-
-        .. code-block:: text
-
-            Dataset DatamintDataset
-                Number of datapoints: 3
-                Root location: /home/user/.datamint/datasets
-
-        """
+        """String representation of the dataset."""
         head = f"Dataset {self.project_name}"
         body = [f"Number of datapoints: {self.__len__()}"]
+
         if self.root is not None:
             body.append(f"Location: {self.dataset_dir}")

-        # Add filter information
-        if self.exclude_frame_label_names is not None:
-            body += [f"Excluding frame labels: {self.exclude_frame_label_names}"]
+        # Add filter information
+        filter_info = [
+            (self.include_annotators, "Including only annotators"),
+            (self.exclude_annotators, "Excluding annotators"),
+            (self.include_segmentation_names, "Including only segmentations"),
+            (self.exclude_segmentation_names, "Excluding segmentations"),
+            (self.include_image_label_names, "Including only image labels"),
+            (self.exclude_image_label_names, "Excluding image labels"),
+            (self.include_frame_label_names, "Including only frame labels"),
+            (self.exclude_frame_label_names, "Excluding frame labels"),
+        ]
+
+        for filter_value, description in filter_info:
+            if filter_value is not None:
+                body.append(f"{description}: {filter_value}")

         lines = [head] + [" " * 4 + line for line in body]
         return "\n".join(lines)

-    def download_project(self):
-        from
+    def download_project(self) -> None:
+        """Download project data from API."""

         dataset_info = self._get_datasetinfo()
         self.dataset_id = dataset_info['id']
         self.last_updaded_at = dataset_info['updated_at']

-        self.api_handler.download_project(
+        self.api_handler.download_project(
+            self.project_info['id'],
+            self.dataset_zippath,
+            all_annotations=self.all_annotations,
+            include_unannotated=self.include_unannotated
+        )
+
+        _LOGGER.debug("Downloaded dataset")
+
+        if os.path.getsize(self.dataset_zippath) == 0:
             raise DatamintDatasetException("Download failed.")

+        self._extract_and_update_metadata()
+
+    def _get_dataset_id(self) -> str:
+        if self.dataset_id is None:
+            dataset_info = self._get_datasetinfo()
+            self.dataset_id = dataset_info['id']
+        return self.dataset_id
+
+    def _extract_and_update_metadata(self) -> None:
+        """Extract downloaded archive and update metadata."""
+        from torchvision.datasets.utils import extract_archive
+
         if os.path.exists(self.dataset_dir):
             _LOGGER.info(f"Deleting existing dataset directory: {self.dataset_dir}")
             shutil.rmtree(self.dataset_dir)
-        with open(
+
+        extract_archive(self.dataset_zippath, self.dataset_dir, remove_finished=True)
+
+        # Load and update metadata
+        datasetjson_path = os.path.join(self.dataset_dir, 'dataset.json')
+        with open(datasetjson_path, 'r') as file:
             self.metainfo = json.load(file)
+
+        self._update_metadata_timestamps()
+
+        # Save updated metadata
+        with open(datasetjson_path, 'w') as file:
+            json.dump(self.metainfo, file, default=lambda o: o.to_dict() if hasattr(o, 'to_dict') else o)
+
+        self.images_metainfo = self.metainfo['resources']
+        # self._convert_metainfo_to_clsobj()
+
+    def _update_metadata_timestamps(self) -> None:
+        """Update metadata with correct timestamps."""
         if 'updated_at' not in self.metainfo:
             self.metainfo['updated_at'] = self.last_updaded_at
         else:
-            # if self.last_updated_at is newer than the one in the dataset, update it
             try:
+                local_time = datetime.fromisoformat(self.metainfo['updated_at'])
+                server_time = datetime.fromisoformat(self.last_updaded_at)
+
+                if local_time < server_time:
+                    _LOGGER.warning(
+                        f"Inconsistent updated_at dates detected "
+                        f"({self.metainfo['updated_at']} < {self.last_updaded_at}). "
+                        f"Fixing it to {self.last_updaded_at}"
+                    )
                 self.metainfo['updated_at'] = self.last_updaded_at
             except Exception as e:
                 _LOGGER.warning(f"Failed to parse updated_at date: {e}")

-        # Add all_annotations to the metadata
         self.metainfo['all_annotations'] = self.all_annotations

-        with
-            json.dump(self.metainfo, file)
-
-    def _load_image(self, filepath: str,
-                    index: int | None = None) -> tuple[Tensor, FileDataset | None]:
+    def _load_image(self, filepath: str, index: int | None = None) -> tuple[Tensor, FileDataset | None]:
+        """Load image from file with optional frame index."""
         if os.path.isdir(filepath):
-            raise NotImplementedError("Loading
+            raise NotImplementedError("Loading an image from a directory is not supported yet.")

         if self.return_frame_by_frame:
             img, ds = read_array_normalized(filepath, return_metainfo=True, index=index)
         else:
             img, ds = read_array_normalized(filepath, return_metainfo=True)

+        img = self._process_image_array(img)
+        return img, ds
+
+    def _process_image_array(self, img: np.ndarray) -> Tensor:
+        """Process numpy array to tensor with proper normalization."""
         if img.dtype == np.uint16:
             if not self.__logged_uint16_conversion:
                 _LOGGER.info("Original image is uint16, converting to uint8")
                 self.__logged_uint16_conversion = True

-            #
+            # Min-max normalization
             img = img.astype(np.float32)
-            img = (img -
+            min_val = img.min()
+            img = (img - min_val) / (img.max() - min_val) * 255
             img = img.astype(np.uint8)

-        if isinstance(img, torch.ByteTensor):
-            img = img.to(dtype=torch.get_default_dtype()).div(255)
+        img_tensor = torch.from_numpy(img).contiguous()

+        if isinstance(img_tensor, torch.ByteTensor):
+            img_tensor = img_tensor.to(dtype=torch.get_default_dtype()).div(255)
+
+        return img_tensor

-    def _get_image_metainfo(self, index: int, bypass_subset_indices=False) -> dict[str, Any]:
+    def _get_image_metainfo(self, index: int, bypass_subset_indices: bool = False) -> dict[str, Any]:
+        """Get metadata for image at given index."""
         if not bypass_subset_indices:
             index = self.subset_indices[index]
+
         if self.return_frame_by_frame:
-            # Find the correct filepath and index
             resource_id, frame_index = self.__find_index(index)
-
-            img_metainfo = self.images_metainfo[resource_id]
-            img_metainfo = dict(img_metainfo)  # copy
-            # insert frame index
+            img_metainfo = dict(self.images_metainfo[resource_id])  # Copy
             img_metainfo['frame_index'] = frame_index
-            img_metainfo['annotations'] = [
+            img_metainfo['annotations'] = [
+                ann for ann in img_metainfo['annotations']
+                if ann.index is None or ann.index == frame_index
+            ]
         else:
             img_metainfo = self.images_metainfo[index]
+
         return img_metainfo

     def __find_index(self, index: int) -> tuple[int, int]:
-        """
-        Find the resource index and frame index for a given global frame index.
-
-        """
-        # Use binary search to find the resource containing this frame
+        """Find the resource index and frame index for a given global frame index."""
         resource_index = np.searchsorted(self._cumulative_frames[1:], index, side='right')
         frame_index = index - self._cumulative_frames[resource_index]
-
         return resource_index, frame_index

-    def __getitem_internal(
+    def __getitem_internal(
+            self,
+            index: int,
+            only_load_metainfo: bool = False
+    ) -> dict[str, Tensor | FileDataset | dict | list]:
+        """Internal method to get item at index."""
         if self.return_frame_by_frame:
             resource_index, frame_idx = self.__find_index(index)
         else:
             resource_index = index
             frame_idx = None
+
         img_metainfo = self._get_image_metainfo(index, bypass_subset_indices=True)

         if only_load_metainfo:
             return {'metainfo': img_metainfo}

         filepath = os.path.join(self.dataset_dir, img_metainfo['file'])
-
-        # Can be multi-frame, Gray-scale and/or RGB. So the shape is really variable, but it's always a numpy array.
         img, ds = self._load_image(filepath, frame_idx)

+        return self._build_item_dict(img, ds, img_metainfo)
+
+    def _build_item_dict(
+            self,
+            img: Tensor,
+            ds: FileDataset | None,
+            img_metainfo: dict
+    ) -> dict[str, Any]:
+        """Build the return dictionary for __getitem__."""
         ret = {'image': img}

         if self.return_dicom:
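`_process_image_array` above factors the uint16 handling out of `_load_image`; the math is a plain min-max rescale to uint8 followed by the usual byte-to-float conversion. A standalone sketch of the same steps:

    import numpy as np
    import torch

    img = np.array([[0, 32768], [49152, 65535]], dtype=np.uint16)
    img = img.astype(np.float32)
    min_val = img.min()
    img = ((img - min_val) / (img.max() - min_val) * 255).astype(np.uint8)
    # -> [[0, 127], [191, 255]]
    t = torch.from_numpy(img).contiguous()
    t = t.to(dtype=torch.get_default_dtype()).div(255)  # uint8 -> float in [0, 1]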
@@ -678,52 +760,42 @@ class DatamintBaseDataset:

         return ret

-    def _filter_annotations(self, annotations: list[
-        """
-        Filter annotations based on the filtering settings.
-
-        Args:
-            annotations: list of annotations
-
-        Returns:
-            list[dict]: filtered list of annotations
-        """
+    def _filter_annotations(self, annotations: list[Annotation]) -> list[Annotation]:
+        """Filter annotations based on the filtering settings."""
         if annotations is None:
             return []

         filtered_annotations = []
         for ann in annotations:
-            if not self._should_include_annotator(ann['added_by']):
+            if not self._should_include_annotation(ann):
                 continue
-
-            # Filter by annotation type and name
-            if ann['type'] == 'segmentation':
-                if not self._should_include_segmentation(ann['name']):
-                    continue
-            elif ann['type'] == 'label':
-                # Check if it's a frame or image label
-                if ann.get('index', None) is None:
-                    # Image label
-                    if not self._should_include_image_label(ann['name']):
-                        continue
-                else:
-                    # Frame label
-                    if not self._should_include_frame_label(ann['name']):
-                        continue
-
-            # If we reach here, the annotation passed all filters
             filtered_annotations.append(ann)

         return filtered_annotations

+    def _should_include_annotation(self, ann: Annotation) -> bool:
+        """Check if an annotation should be included based on all filters."""
+        if not self._should_include_annotator(ann.created_by):
+            return False
+
+        if ann.type == 'segmentation':
+            return self._should_include_segmentation(ann.name)
+        elif ann.type == 'label':
+            if ann.index is None:
+                return self._should_include_image_label(ann.name)
+            else:
+                return self._should_include_frame_label(ann.name)
+
+        return True
+
     def __getitem__(self, index: int) -> dict[str, Tensor | FileDataset | dict | list]:
-        """
+        """Get item at index.
+
         Args:
-            index
+            index: Index

         Returns:
+            A dictionary containing 'image', 'metainfo' and 'annotations' keys.
         """
         if index >= len(self):
             raise IndexError(f"Index {index} out of bounds for dataset of length {len(self)}")
@@ -731,21 +803,28 @@ class DatamintBaseDataset:
         return self.__getitem_internal(self.subset_indices[index])

     def __iter__(self):
+        """Iterate over dataset items."""
         for index in self.subset_indices:
-            yield self.
+            yield self.__getitem__(index)
+            # do not use __getitem_internal__ here, so subclass only need to implement __getitem__

     def __len__(self) -> int:
+        """Return dataset length."""
         return len(self.subset_indices)

-    def _check_version(self):
+    def _check_version(self) -> None:
+        """Check if local dataset version is up to date."""
+        # metainfo_path = os.path.join(self.dataset_dir, 'dataset.json')
+        # if not os.path.exists(metainfo_path):
+        #     self.download_project()
+        #     return
+
+        if not hasattr(self, 'project_info'):
+            self.project_info = self.get_info()
+        self.dataset_id = self.project_info['dataset_id']
+
+        local_updated_at = self.metainfo.get('updated_at', None)
+        local_all_annotations = self.metainfo.get('all_annotations', None)

         try:
             external_metadata_info = self._get_datasetinfo()
@@ -756,82 +835,279 @@ class DatamintBaseDataset:
|
|
|
756
835
|
|
|
757
836
|
_LOGGER.debug(f"Local updated at: {local_updated_at}, Server updated at: {server_updated_at}")
|
|
758
837
|
|
|
759
|
-
# Check if all_annotations changed or doesn't exist
|
|
760
838
|
annotations_changed = local_all_annotations != self.all_annotations
|
|
839
|
+
version_outdated = local_updated_at is None or local_updated_at < server_updated_at
|
|
761
840
|
|
|
762
|
-
if
|
|
763
|
-
|
|
764
|
-
|
|
765
|
-
|
|
766
|
-
|
|
767
|
-
|
|
768
|
-
|
|
769
|
-
|
|
770
|
-
|
|
771
|
-
|
|
772
|
-
|
|
841
|
+
if annotations_changed:
|
|
842
|
+
_LOGGER.info(
|
|
843
|
+
f"The 'all_annotations' parameter has changed. "
|
|
844
|
+
f"Previous: {local_all_annotations}, Current: {self.all_annotations}."
|
|
845
|
+
)
|
|
846
|
+
# self.download_project()
|
|
847
|
+
self._incremental_update()
|
|
848
|
+
elif version_outdated:
|
|
849
|
+
_LOGGER.info(
|
|
850
|
+
f"A newer version of the dataset is available. "
|
|
851
|
+
f"Your version: {local_updated_at}. Last version: {server_updated_at}."
|
|
852
|
+
)
|
|
853
|
+
self._incremental_update()
|
|
773
854
|
else:
|
|
774
855
|
_LOGGER.info('Local version is up to date with the latest version.')
|
|
775
856
|
|
|
857
|
+
def _fetch_new_resources(self,
|
|
858
|
+
all_uptodate_resources: list[dict]) -> list[dict]:
|
|
859
|
+
local_resources = self.images_metainfo
|
|
860
|
+
local_resources_ids = [res['id'] for res in local_resources]
|
|
861
|
+
new_resources = []
|
|
862
|
+
for resource in all_uptodate_resources:
|
|
863
|
+
if resource['id'] not in local_resources_ids:
|
|
864
|
+
resource['file'] = str(self._get_resource_file_path(resource))
|
|
865
|
+
resource['annotations'] = []
|
|
866
|
+
new_resources.append(resource)
|
|
867
|
+
return new_resources
|
|
868
|
+
|
|
869
|
+
def _fetch_deleted_resources(self, all_uptodate_resources: list[dict]) -> list[dict]:
|
|
870
|
+
local_resources = self.images_metainfo
|
|
871
|
+
all_uptodate_resources_ids = [res['id'] for res in all_uptodate_resources]
|
|
872
|
+
deleted_resources = []
|
|
873
|
+
for resource in local_resources:
|
|
874
|
+
try:
|
|
875
|
+
res_idx = all_uptodate_resources_ids.index(resource['id'])
|
|
876
|
+
if resource.get('deleted_at', None): # was deleted on server
|
|
877
|
+
if local_resources[res_idx].get('deleted_at_local', None) is None:
|
|
878
|
+
deleted_resources.append(resource)
|
|
879
|
+
except ValueError:
|
|
880
|
+
deleted_resources.append(resource)
|
|
881
|
+
|
|
882
|
+
return deleted_resources
|
|
883
|
+
|
|
884
|
+
+    def _incremental_update(self) -> None:
+        # local_updated_at = self.metainfo.get('updated_at', None)
+        # external_metadata_info = self._get_datasetinfo()
+        # server_updated_at = external_metadata_info['updated_at']
+
+        ### RESOURCES ###
+        all_uptodate_resources = self.api_handler.get_project_resources(self.get_info()['id'])
+        new_resources = self._fetch_new_resources(all_uptodate_resources)
+        deleted_resources = self._fetch_deleted_resources(all_uptodate_resources)
+
+        if new_resources:
+            for r in new_resources:
+                self._new_resource_created(r)
+            new_resources_path = [Path(self.dataset_dir) / r['file'] for r in new_resources]
+            new_resources_ids = [r['id'] for r in new_resources]
+            _LOGGER.info(f"Downloading {len(new_resources)} new resources...")
+            self.api_handler.download_multiple_resources(new_resources_ids,
+                                                         save_path=new_resources_path)
+            _LOGGER.info(f"Downloaded {len(new_resources)} new resources.")
+
+        for r in deleted_resources:
+            self._resource_deleted(r)
+        ################
+
+        ### ANNOTATIONS ###
+        # No status filter when every annotation is requested; otherwise
+        # restrict to published (closed/done) annotations.
+        all_annotations = self.api_handler.get_annotations(worklist_id=self.project_info['worklist_id'],
+                                                           status=None if self.all_annotations else 'published')
+        # group annotations by resource ID, attaching the local filepath
+        annotations_by_resource = {}
+        for ann in all_annotations:
+            filepath = self._get_annotation_file_path(ann)
+            if filepath is not None:
+                ann['file'] = str(filepath)
+            resource_id = ann['resource_id']
+            if resource_id not in annotations_by_resource:
+                annotations_by_resource[resource_id] = []
+            annotations_by_resource[resource_id].append(ann)
+
+        # Collect all segmentation annotations that need to be downloaded
+        segmentations_to_download = []
+        segmentation_paths = []
+
+        # update annotations in resources
+        for resource in self.images_metainfo:
+            resource_id = resource['id']
+            new_resource_annotations = annotations_by_resource.get(resource_id, [])
+            old_resource_annotations = resource.get('annotations', [])
+
+            # Compare old and new annotation ids to find additions and removals
+            old_ann_ids = set(ann.id for ann in old_resource_annotations if hasattr(ann, 'id'))
+            new_ann_ids = set(ann['id'] for ann in new_resource_annotations)
+
+            annotations_to_add = [ann for ann in new_resource_annotations
+                                  if ann['id'] not in old_ann_ids]
+            annotations_to_remove = [ann for ann in old_resource_annotations
+                                     if getattr(ann, 'id', 'NA') not in new_ann_ids]
+
+            for ann in annotations_to_add:
+                filepath = self._get_annotation_file_path(ann)
+                if filepath is not None:  # None means it is not a segmentation
+                    # Collect for batch download
+                    filepath = Path(self.dataset_dir) / filepath
+                    filepath.parent.mkdir(parents=True, exist_ok=True)
+                    segmentations_to_download.append(ann)
+                    segmentation_paths.append(filepath)
+
+            for ann in annotations_to_remove:
+                filepath = getattr(ann, 'file', None) if hasattr(ann, 'file') else ann.get('file', None)
+                if filepath is None:
+                    # Not a segmentation annotation
+                    continue
+                try:
+                    filepath = Path(self.dataset_dir) / filepath
+                    # delete the local annotation file if it exists
+                    if filepath.exists():
+                        os.remove(filepath)
+                except Exception as e:
+                    _LOGGER.error(f"Error deleting annotation file {filepath}: {e}")
+
+            # Update resource annotations list - convert to Annotation objects
+            resource['annotations'] = [Annotation.from_dict(ann) for ann in new_resource_annotations]
+
+        # Batch download all segmentation files
+        if segmentations_to_download:
+            _LOGGER.info(f"Downloading {len(segmentations_to_download)} segmentation files...")
+            self.api_handler.download_multiple_segmentations(segmentations_to_download, segmentation_paths)
+            _LOGGER.info(f"Downloaded {len(segmentations_to_download)} segmentation files.")
+        ###################
+
+        # update and save metadata
+        self.metainfo['updated_at'] = self._get_datasetinfo()['updated_at']
+        self.metainfo['all_annotations'] = self.all_annotations
+        datasetjson_path = os.path.join(self.dataset_dir, 'dataset.json')
+        with open(datasetjson_path, 'w') as file:
+            json.dump(self.metainfo, file, default=lambda o: o.to_dict() if hasattr(o, 'to_dict') else o)
+
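From the user's side this method is not called directly; per the class docstring, constructing the dataset with auto_update=True checks the server and pulls only the delta. A hedged usage sketch (project name hypothetical):

    from datamint.dataset.base_dataset import DatamintBaseDataset

    # Re-running this after the project changed re-downloads only new
    # resources, removes deleted ones, and refreshes annotations.
    dataset = DatamintBaseDataset(project_name='my-project', auto_update=True)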
+    def _get_resource_file_path(self, resource: dict) -> Path:
+        """Get the local file path for a resource."""
+        if 'file' in resource and resource['file'] is not None:
+            return Path(resource['file'])
+        else:
+            ext = guess_extension(resource['mimetype'], strict=False)
+            if ext is None:
+                _LOGGER.warning(f"Could not guess extension for resource {resource['id']}.")
+                ext = ''
+            return Path('images', f"{resource['id']}{ext}")
+
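When a resource record carries no 'file' entry, the path is derived from its id and mimetype. A small sketch of that fallback (resource values hypothetical; guess_extension may return None on platforms whose mime table lacks the type, hence the '' fallback):

    from mimetypes import guess_extension
    from pathlib import Path

    resource = {'id': 'abc123', 'mimetype': 'application/dicom'}
    ext = guess_extension(resource['mimetype'], strict=False) or ''
    print(Path('images', f"{resource['id']}{ext}"))  # e.g. images/abc123.dcm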
+    def _get_annotation_file_path(self, annotation: dict | Annotation) -> Path | None:
+        """Get the local file path for an annotation."""
+        if isinstance(annotation, Annotation):
+            if annotation.file:
+                return Path(annotation.file)
+            elif annotation.type == 'segmentation':
+                return Path('masks',
+                            annotation.created_by,
+                            annotation.resource_id,
+                            annotation.id)
+        else:
+            # Handle dict format for backwards compatibility
+            if 'file' in annotation:
+                return Path(annotation['file'])
+            elif annotation.get('annotation_type', annotation.get('type')) == 'segmentation':
+                return Path('masks',
+                            annotation['created_by'],
+                            annotation['resource_id'],
+                            annotation['id'])
+        return None
+
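Taken together, the two path helpers imply the following local layout under the dataset root (inferred from the code above, not documented elsewhere in this file):

    <dataset_dir>/
        dataset.json                                        # saved metainfo
        images/<resource_id><ext>                           # resources without an explicit 'file'
        masks/<created_by>/<resource_id>/<annotation_id>    # segmentation annotations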
+    def _new_resource_created(self, resource: dict) -> None:
+        """Handle a new resource created in the dataset."""
+        if 'annotations' not in resource:
+            resource['annotations'] = []  # Initialize as empty list for Annotation objects
+        self.images_metainfo.append(resource)
+
+        if hasattr(self, 'num_frames_per_resource'):
+            raise NotImplementedError('Cannot handle new resources after dataset initialization')
+
+    def _resource_deleted(self, resource: dict) -> None:
+        """Handle a resource deleted from the dataset."""
+        # remove from metadata
+        for i, imginfo in enumerate(self.images_metainfo):
+            if imginfo['id'] == resource['id']:
+                deleted_metainfo = self.images_metainfo.pop(i)
+                break
+        else:
+            _LOGGER.warning(f"Resource {resource['id']} not found in dataset metadata.")
+            return
+
+        # delete the resource file from disk
+        resource_path = os.path.join(self.dataset_dir, deleted_metainfo['file'])
+        if os.path.exists(resource_path):
+            os.remove(resource_path)
+
+        # delete associated annotation files
+        for ann in deleted_metainfo.get('annotations', []):
+            ann_file = getattr(ann, 'file', None) if hasattr(ann, 'file') else ann.get('file', None)
+            if ann_file is not None:
+                ann_path = os.path.join(self.dataset_dir, ann_file)
+                if os.path.exists(ann_path):
+                    os.remove(ann_path)
+
     def __add__(self, other):
+        """Concatenate datasets."""
         from torch.utils.data import ConcatDataset
         return ConcatDataset([self, other])

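Since __add__ returns a plain torch.utils.data.ConcatDataset, two project datasets can be chained with '+' and used anywhere a PyTorch dataset is expected. A short usage sketch (project names hypothetical):

    ds_a = DatamintBaseDataset(project_name='project-a')
    ds_b = DatamintBaseDataset(project_name='project-b')
    combined = ds_a + ds_b
    print(len(combined))  # len(ds_a) + len(ds_b)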
     def get_dataloader(self, *args, **kwargs) -> DataLoader:
-        """
-        Returns a DataLoader for the dataset.
-        This is a wrapper around the PyTorch DataLoader, with the convinience of using a nice collate_fn
-        that properly handles the different types of data in this dataset.
+        """Returns a DataLoader for the dataset with proper collate function.

         Args:
-            *args: Positional arguments for the DataLoader.
-            **kwargs: Keyword arguments for the DataLoader.
-
+            *args: Positional arguments for the DataLoader.
+            **kwargs: Keyword arguments for the DataLoader.

+        Returns:
+            DataLoader instance with custom collate function.
         """
-        return DataLoader(self,
-                          *args,
-                          collate_fn=self.get_collate_fn(),
-                          **kwargs)
+        return DataLoader(self, *args, collate_fn=self.get_collate_fn(), **kwargs)

|
|
798
|
-
|
|
1067
|
+
"""Get collate function for DataLoader."""
|
|
1068
|
+
def collate_fn(batch: list[dict]) -> dict:
|
|
1069
|
+
if not batch:
|
|
1070
|
+
return {}
|
|
1071
|
+
|
|
799
1072
|
keys = batch[0].keys()
|
|
800
1073
|
collated_batch = {}
|
|
1074
|
+
|
|
801
1075
|
for key in keys:
|
|
802
|
-
|
|
803
|
-
|
|
804
|
-
|
|
805
|
-
shapes = [tensor.shape for tensor in
|
|
1076
|
+
values = [item[key] for item in batch]
|
|
1077
|
+
|
|
1078
|
+
if isinstance(values[0], torch.Tensor):
|
|
1079
|
+
shapes = [tensor.shape for tensor in values]
|
|
806
1080
|
if all(shape == shapes[0] for shape in shapes):
|
|
807
|
-
collated_batch[key] = torch.stack(
|
|
1081
|
+
collated_batch[key] = torch.stack(values)
|
|
808
1082
|
else:
|
|
809
|
-
_LOGGER.warning(f"Collating {key} tensors with different shapes: {shapes}
|
|
810
|
-
|
|
811
|
-
|
|
1083
|
+
_LOGGER.warning(f"Collating {key} tensors with different shapes: {shapes}")
|
|
1084
|
+
collated_batch[key] = values
|
|
1085
|
+
elif isinstance(values[0], np.ndarray):
|
|
1086
|
+
collated_batch[key] = np.stack(values)
|
|
1087
|
+
else:
|
|
1088
|
+
collated_batch[key] = values
|
|
812
1089
|
|
|
813
1090
|
return collated_batch
|
|
814
1091
|
|
|
815
1092
|
return collate_fn
|
|
816
1093
|
|
|
817
1094
|
def subset(self, indices: list[int]) -> 'DatamintBaseDataset':
|
|
818
|
-
|
|
1095
|
+
"""Create a subset of the dataset.
|
|
1096
|
+
|
|
1097
|
+
Args:
|
|
1098
|
+
indices: List of indices to include in the subset.
|
|
1099
|
+
|
|
1100
|
+
Returns:
|
|
1101
|
+
Self with updated subset indices.
|
|
1102
|
+
"""
|
|
1103
|
+
if max(indices, default=-1) >= self.dataset_length:
|
|
819
1104
|
raise ValueError(f"Subset indices must be less than the dataset length: {self.dataset_length}")
|
|
820
1105
|
|
|
821
1106
|
self.subset_indices = indices
|
|
822
|
-
|
|
823
1107
|
return self
|
|
824
1108
|
|
|
825
1109
|
def _should_include_annotator(self, annotator_id: str) -> bool:
|
|
826
|
-
"""
|
|
827
|
-
Check if an annotator should be included based on the filtering settings.
|
|
828
|
-
|
|
829
|
-
Args:
|
|
830
|
-
annotator_id: The ID of the annotator to check
|
|
831
|
-
|
|
832
|
-
Returns:
|
|
833
|
-
bool: True if the annotator should be included, False otherwise
|
|
834
|
-
"""
|
|
1110
|
+
"""Check if an annotator should be included based on filtering settings."""
|
|
835
1111
|
if self.include_annotators is not None:
|
|
836
1112
|
return annotator_id in self.include_annotators
|
|
837
1113
|
if self.exclude_annotators is not None:
|
|
@@ -839,15 +1115,7 @@ class DatamintBaseDataset:
|
|
|
839
1115
|
return True
|
|
840
1116
|
|
|
841
1117
|
def _should_include_segmentation(self, segmentation_name: str) -> bool:
|
|
842
|
-
"""
|
|
843
|
-
Check if a segmentation should be included based on the filtering settings.
|
|
844
|
-
|
|
845
|
-
Args:
|
|
846
|
-
segmentation_name: The name of the segmentation to check
|
|
847
|
-
|
|
848
|
-
Returns:
|
|
849
|
-
bool: True if the segmentation should be included, False otherwise
|
|
850
|
-
"""
|
|
1118
|
+
"""Check if a segmentation should be included based on filtering settings."""
|
|
851
1119
|
if self.include_segmentation_names is not None:
|
|
852
1120
|
return segmentation_name in self.include_segmentation_names
|
|
853
1121
|
if self.exclude_segmentation_names is not None:
|
|
@@ -855,15 +1123,7 @@ class DatamintBaseDataset:
|
|
|
855
1123
|
return True
|
|
856
1124
|
|
|
857
1125
|
def _should_include_image_label(self, label_name: str) -> bool:
|
|
858
|
-
"""
|
|
859
|
-
Check if an image label should be included based on the filtering settings.
|
|
860
|
-
|
|
861
|
-
Args:
|
|
862
|
-
label_name: The name of the image label to check
|
|
863
|
-
|
|
864
|
-
Returns:
|
|
865
|
-
bool: True if the image label should be included, False otherwise
|
|
866
|
-
"""
|
|
1126
|
+
"""Check if an image label should be included based on filtering settings."""
|
|
867
1127
|
if self.include_image_label_names is not None:
|
|
868
1128
|
return label_name in self.include_image_label_names
|
|
869
1129
|
if self.exclude_image_label_names is not None:
|
|
@@ -871,15 +1131,7 @@ class DatamintBaseDataset:
|
|
|
871
1131
|
return True
|
|
872
1132
|
|
|
873
1133
|
def _should_include_frame_label(self, label_name: str) -> bool:
|
|
874
|
-
"""
|
|
875
|
-
Check if a frame label should be included based on the filtering settings.
|
|
876
|
-
|
|
877
|
-
Args:
|
|
878
|
-
label_name: The name of the frame label to check
|
|
879
|
-
|
|
880
|
-
Returns:
|
|
881
|
-
bool: True if the frame label should be included, False otherwise
|
|
882
|
-
"""
|
|
1134
|
+
"""Check if a frame label should be included based on filtering settings."""
|
|
883
1135
|
if self.include_frame_label_names is not None:
|
|
884
1136
|
return label_name in self.include_frame_label_names
|
|
885
1137
|
if self.exclude_frame_label_names is not None:
|