datamint-1.2.4-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


@@ -0,0 +1,881 @@
+ import os
+ import requests
+ from tqdm import tqdm
+ from typing import Optional, Callable, Any, Literal
+ import logging
+ import shutil
+ import json
+ import yaml
+ import pydicom
+ import numpy as np
+ from datamintapi import configs
+ from torch.utils.data import DataLoader
+ import torch
+ from datamintapi.apihandler.base_api_handler import DatamintException
+ from datamintapi.utils.dicom_utils import is_dicom
+ import cv2
+ from datamintapi.utils.io_utils import read_array_normalized
+ from deprecated import deprecated
+ from datetime import datetime
+
+ _LOGGER = logging.getLogger(__name__)
+
+
+ class DatamintDatasetException(DatamintException):
+     pass
+
+
+ class DatamintBaseDataset:
+     """
+     Class to download and load datasets from the Datamint API.
+
+     Args:
+         root: Root directory where the dataset already exists or will be downloaded to.
+         project_name: Name of the project to download.
+         auto_update: If True, the dataset will be checked for updates and downloaded if necessary.
+         api_key: API key to access the Datamint API. If not provided, it will look for the
+             environment variable 'DATAMINT_API_KEY'. Not necessary if
+             you don't want to download/update the dataset.
+         return_dicom: If True and the resource is a DICOM file, the DICOM object will be returned.
+         return_metainfo: If True, the metainfo of the image will be returned.
+         return_annotations: If True, the annotations of the image will be returned.
+         return_frame_by_frame: If True, each frame of a video/DICOM/3D image will be returned separately.
+         include_unannotated: If True, images without annotations will be included; if False, they will be discarded.
+         all_annotations: If True, all annotations will be downloaded, including those not yet set as closed/done.
+         server_url: URL of the Datamint server. If not provided, the default server is used.
+         include_annotators: List of annotators to include. If None, all annotators will be included. See parameter ``exclude_annotators``.
+         exclude_annotators: List of annotators to exclude. If None, no annotators will be excluded. See parameter ``include_annotators``.
+         include_segmentation_names: List of segmentation names to include. If None, all segmentations will be included.
+         exclude_segmentation_names: List of segmentation names to exclude. If None, no segmentations will be excluded.
+         include_image_label_names: List of image label names to include. If None, all image labels will be included.
+         exclude_image_label_names: List of image label names to exclude. If None, no image labels will be excluded.
+         include_frame_label_names: List of frame label names to include. If None, all frame labels will be included.
+         exclude_frame_label_names: List of frame label names to exclude. If None, no frame labels will be excluded.
+     """
+
+     DATAMINT_DEFAULT_DIR = ".datamint"
+     DATAMINT_DATASETS_DIR = "datasets"
+
+     def __init__(self,
+                  project_name: str,
+                  root: str | None = None,
+                  auto_update: bool = True,
+                  api_key: Optional[str] = None,
+                  server_url: Optional[str] = None,
+                  return_dicom: bool = False,
+                  return_metainfo: bool = True,
+                  return_annotations: bool = True,
+                  return_frame_by_frame: bool = False,
+                  include_unannotated: bool = True,
+                  all_annotations: bool = False,
+                  # filtering parameters
+                  include_annotators: Optional[list[str]] = None,
+                  exclude_annotators: Optional[list[str]] = None,
+                  include_segmentation_names: Optional[list[str]] = None,
+                  exclude_segmentation_names: Optional[list[str]] = None,
+                  include_image_label_names: Optional[list[str]] = None,
+                  exclude_image_label_names: Optional[list[str]] = None,
+                  include_frame_label_names: Optional[list[str]] = None,
+                  exclude_frame_label_names: Optional[list[str]] = None
+                  ):
+         from datamintapi.apihandler.api_handler import APIHandler
+
+         if project_name is None:
+             raise ValueError("project_name is required.")
+
+         self.all_annotations = all_annotations
+         self.api_handler = APIHandler(root_url=server_url, api_key=api_key,
+                                       check_connection=auto_update)
+         self.server_url = self.api_handler.root_url
+         if root is None:
+             # store them in the home directory
+             root = os.path.join(os.path.expanduser("~"),
+                                 DatamintBaseDataset.DATAMINT_DEFAULT_DIR)
+             root = os.path.join(root, DatamintBaseDataset.DATAMINT_DATASETS_DIR)
+             if not os.path.exists(root):
+                 os.makedirs(root)
+         elif isinstance(root, str):
+             root = os.path.expanduser(root)
+             if not os.path.isdir(root):
+                 raise NotADirectoryError(f"Root directory not found: {root}")
+
+         self.root = root
+
+         self.return_dicom = return_dicom
+         self.return_metainfo = return_metainfo
+         self.return_frame_by_frame = return_frame_by_frame
+         self.return_annotations = return_annotations
+         self.include_unannotated = include_unannotated
+         self.discard_without_annotations = not include_unannotated
+
+         # Filtering parameters
+         self.include_annotators = include_annotators
+         self.exclude_annotators = exclude_annotators
+         self.include_segmentation_names = include_segmentation_names
+         self.exclude_segmentation_names = exclude_segmentation_names
+         self.include_image_label_names = include_image_label_names
+         self.exclude_image_label_names = exclude_image_label_names
+         self.include_frame_label_names = include_frame_label_names
+         self.exclude_frame_label_names = exclude_frame_label_names
+
+         # Validate filtering parameters
+         if include_annotators is not None and exclude_annotators is not None:
+             raise ValueError("Cannot set both include_annotators and exclude_annotators at the same time")
+
+         if include_segmentation_names is not None and exclude_segmentation_names is not None:
+             raise ValueError("Cannot set both include_segmentation_names and exclude_segmentation_names at the same time")
+
+         if include_image_label_names is not None and exclude_image_label_names is not None:
+             raise ValueError("Cannot set both include_image_label_names and exclude_image_label_names at the same time")
+
+         if include_frame_label_names is not None and exclude_frame_label_names is not None:
+             raise ValueError("Cannot set both include_frame_label_names and exclude_frame_label_names at the same time")
+
+         self.project_name = project_name
+         dataset_name = project_name
+
+         self.dataset_dir = os.path.join(root, dataset_name)
+         self.dataset_zippath = os.path.join(root, f'{dataset_name}.zip')
+
+         local_dataset_exists = os.path.exists(os.path.join(self.dataset_dir, 'dataset.json'))
+
+         if local_dataset_exists and not auto_update:
+             # In this case, we don't need to check the API, so we don't need the id.
+             self.dataset_id = None
+         else:
+             self.project_info = self.get_info()
+             self.dataset_id = self.project_info['dataset_id']
+
+         self.api_key = self.api_handler.api_key
+         if self.api_key is None:
+             _LOGGER.warning("API key not provided. If you want to download data, please provide an API key, "
+                             "either by passing it as an argument, "
+                             f"setting the environment variable {configs.ENV_VARS[configs.APIKEY_KEY]}, or "
+                             "using the datamint-config command line tool.")
+
+         # Download/update the dataset, if necessary.
+         if local_dataset_exists:
+             _LOGGER.info(f"Dataset directory already exists: {self.dataset_dir}")
+             if auto_update:
+                 _LOGGER.info("Checking for updates...")
+                 self._check_version()
+         else:
+             if self.api_key is None:
+                 raise DatamintDatasetException("API key is required to download the dataset.")
+             _LOGGER.info(f"No data found at {self.dataset_dir}. Downloading...")
+             self.download_project()
169
+
170
+ # Loads the metadata
171
+ if not hasattr(self, 'metainfo'):
172
+ with open(os.path.join(self.dataset_dir, 'dataset.json'), 'r') as file:
173
+ self.metainfo = json.load(file)
174
+ self.images_metainfo = self.metainfo['resources']
175
+
176
+ # filter annotations
177
+ for imginfo in self.images_metainfo:
178
+ imginfo['annotations'] = self._filter_annotations(imginfo['annotations'])
179
+
180
+ # filter out images with no annotations.
181
+ if self.discard_without_annotations:
182
+ original_count = len(self.images_metainfo)
183
+ self.images_metainfo = self._filter_items(self.images_metainfo)
184
+ _LOGGER.info(f"Discarded {original_count - len(self.images_metainfo)} images without annotations.")
185
+
186
+ self._check_integrity()
187
+
188
+ # fix images_metainfo labels
189
+ # TODO: check tags
190
+ # for imginfo in self.images_metainfo:
191
+ # if imginfo['frame_labels'] is not None:
192
+ # for flabels in imginfo['frame_labels']:
193
+ # if flabels['label'] is None:
194
+ # flabels['label'] = []
195
+ # elif isinstance(flabels['label'], str):
196
+ # flabels['label'] = flabels['label'].split(',')
197
+
198
+ if self.return_frame_by_frame:
199
+ self.dataset_length = 0
200
+ for imginfo in self.images_metainfo:
201
+ filepath = os.path.join(self.dataset_dir, imginfo['file'])
202
+ self.dataset_length += self.read_number_of_frames(filepath)
203
+ else:
204
+ self.dataset_length = len(self.images_metainfo)
205
+
206
+ self.num_frames_per_resource = self.__compute_num_frames_per_resource()
207
+
208
+ self.subset_indices = list(range(self.dataset_length))
209
+ # self.labels_set, self.label2code, self.segmentation_labels, self.segmentation_label2code = self.get_labels_set()
210
+ self.frame_lsets, self.frame_lcodes = self._get_labels_set(framed=True)
211
+ self.image_lsets, self.image_lcodes = self._get_labels_set(framed=False)
212
+ self.__logged_uint16_conversion = False
213
+ if self.discard_without_annotations and self.return_frame_by_frame:
214
+ # If we are returning frame by frame, we need to filter out frames without segmentations
215
+ self._filter_unannotated()
216
+
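
To make the constructor's options concrete, here is a minimal usage sketch. All names are hypothetical placeholders, and the exact import path depends on how the package exposes this module:

    # Illustrative sketch: load a project frame by frame, keeping one annotator.
    dataset = DatamintBaseDataset(
        project_name='chest-xrays',                # placeholder project name
        api_key='my-api-key',                      # or set the DATAMINT_API_KEY env var
        return_frame_by_frame=True,                # one item per frame of a video/DICOM
        include_annotators=['alice@example.com'],  # placeholder annotator id
    )
    print(dataset)        # uses __repr__ defined below
    print(len(dataset))   # number of frames, since return_frame_by_frame=True
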
+     def _filter_items(self, images_metainfo: list[dict]) -> list[dict]:
+         return [img for img in images_metainfo if len(img.get('annotations', []))]
+
+     def _filter_unannotated(self):
+         """Filter out frames that don't have any segmentations."""
+         filtered_indices = []
+         for i in range(len(self.subset_indices)):
+             item_meta = self._get_image_metainfo(i)
+             annotations = item_meta.get('annotations', [])
+
+             # Check if there are any segmentation annotations
+             has_segmentations = any(ann['type'] == 'segmentation' for ann in annotations)
+
+             if has_segmentations:
+                 filtered_indices.append(self.subset_indices[i])
+
+         self.subset_indices = filtered_indices
+         _LOGGER.info(f"Filtered dataset: {len(self.subset_indices)} frames with segmentations")
+
+     def __compute_num_frames_per_resource(self) -> list[int]:
+         num_frames_per_resource = []
+         for imginfo in self.images_metainfo:
+             filepath = os.path.join(self.dataset_dir, imginfo['file'])
+             num_frames_per_resource.append(self.read_number_of_frames(filepath))
+         return num_frames_per_resource
+
+     @property
+     def frame_labels_set(self) -> list[str]:
+         """
+         Returns the set of independent frame-level labels in the dataset.
+         This is mostly relevant for multi-label tasks.
+         """
+         return self.frame_lsets['multilabel']
+
+     @property
+     def frame_categories_set(self) -> list[tuple[str, str]]:
+         """
+         Returns the set of frame-level categories in the dataset.
+         This is mostly relevant for multi-class tasks.
+         """
+         return self.frame_lsets['multiclass']
+
+     @property
+     def image_labels_set(self) -> list[str]:
+         """
+         Returns the set of independent image-level labels in the dataset.
+         This is mostly relevant for multi-label tasks.
+         """
+         return self.image_lsets['multilabel']
+
+     @property
+     def image_categories_set(self) -> list[tuple[str, str]]:
+         """
+         Returns the set of image-level categories in the dataset.
+         This is mostly relevant for multi-class tasks.
+         """
+         return self.image_lsets['multiclass']
+
+     @property
+     def segmentation_labels_set(self) -> list[str]:
+         """
+         Returns the set of segmentation labels in the dataset.
+         """
+         return self.frame_lsets['segmentation']
+
+     def _get_annotations_internal(self,
+                                   annotations: list[dict],
+                                   type: Literal['label', 'category', 'segmentation', 'all'] = 'all',
+                                   scope: Literal['frame', 'image', 'all'] = 'all') -> list[dict]:
+         # check parameters
+         if type not in ['label', 'category', 'segmentation', 'all']:
+             raise ValueError(f"Invalid value for 'type': {type}")
+         if scope not in ['frame', 'image', 'all']:
+             raise ValueError(f"Invalid value for 'scope': {scope}")
+
+         annots = []
+         for ann in annotations:
+             # annotations without a frame index apply to the whole image
+             ann_scope = 'image' if ann.get('index', None) is None else 'frame'
+             if (type == 'all' or ann['type'] == type) and (scope == 'all' or scope == ann_scope):
+                 annots.append(ann)
+         return annots
+
+     def get_annotations(self,
+                         index: int,
+                         type: Literal['label', 'category', 'segmentation', 'all'] = 'all',
+                         scope: Literal['frame', 'image', 'all'] = 'all') -> list[dict]:
+         """
+         Returns the annotations of the image at the given index.
+
+         Args:
+             index (int): Index of the image.
+             type (str): The type of the annotations: 'label', 'category', 'segmentation' or 'all'.
+             scope (str): The scope of the annotations: 'frame', 'image' or 'all'.
+
+         Returns:
+             list[dict]: The annotations of the image.
+         """
+         if index >= len(self):
+             raise IndexError(f"Index {index} out of bounds for dataset of length {len(self)}")
+         imginfo = self._get_image_metainfo(index)
+         return self._get_annotations_internal(imginfo['annotations'], type=type, scope=scope)
+
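
As a quick illustration of the ``type``/``scope`` filters above (continuing the sketch from the constructor; index 0 is arbitrary):

    # Frame-level segmentation annotations of the first item.
    segs = dataset.get_annotations(0, type='segmentation', scope='frame')
    for ann in segs:
        print(ann['name'], ann.get('index'))  # segmentation name and frame index
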
+     @staticmethod
+     def read_number_of_frames(filepath: str) -> int:
+         # DICOM file
+         if is_dicom(filepath):
+             ds = pydicom.dcmread(filepath)
+             return int(ds.NumberOfFrames) if hasattr(ds, 'NumberOfFrames') else 1
+         # video file
+         elif filepath.endswith('.mp4') or filepath.endswith('.avi'):
+             cap = cv2.VideoCapture(filepath)
+             try:
+                 return int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+             finally:
+                 cap.release()
+         # single-image file
+         elif filepath.endswith('.png') or filepath.endswith('.jpg') or filepath.endswith('.jpeg'):
+             return 1
+         else:
+             raise ValueError(f"Unsupported file type: {filepath}")
+
+     def get_resources_ids(self) -> list[str]:
+         return [self.__getitem_internal(i, only_load_metainfo=True)['metainfo']['id'] for i in self.subset_indices]
+
+     def _get_labels_set(self, framed: bool) -> tuple[dict, dict[str, dict[str, int]]]:
+         """
+         Returns the sets of labels and the dictionaries that map labels to integer codes.
+
+         Returns:
+             tuple[dict, dict]: One dict of label sets and one dict of label-to-code mappings,
+             each keyed by 'multilabel', 'segmentation' and 'multiclass'.
+         """
+
+         scope = 'frame' if framed else 'image'
+
+         multilabel_set = set()
+         segmentation_labels = set()
+         multiclass_set = set()
+
+         for i in range(len(self)):
+             anns = self.get_annotations(i, type='label', scope=scope)
+             multilabel_set.update([ann['name'] for ann in anns])
+
+             anns = self.get_annotations(i, type='segmentation', scope=scope)
+             segmentation_labels.update([ann['name'] for ann in anns])
+
+             anns = self.get_annotations(i, type='category', scope=scope)
+             multiclass_set.update([(ann['name'], ann['value']) for ann in anns])
+
+         multilabel_set = sorted(multilabel_set)
+         multiclass_set = sorted(multiclass_set)
+         segmentation_labels = sorted(segmentation_labels)
+
+         multilabel2code = {label: idx for idx, label in enumerate(multilabel_set)}
+         segmentation_label2code = {label: idx + 1 for idx, label in enumerate(segmentation_labels)}  # segmentation codes start at 1
+         multiclass2code = {label: idx for idx, label in enumerate(multiclass_set)}
+
+         sets = {'multilabel': multilabel_set,
+                 'segmentation': segmentation_labels,
+                 'multiclass': multiclass_set}
+         codes_map = {'multilabel': multilabel2code,
+                      'segmentation': segmentation_label2code,
+                      'multiclass': multiclass2code}
+         return sets, codes_map
+
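
A sketch of how these code maps are typically consumed, e.g. turning the frame labels of one item into a multi-hot target vector (the label names shown are hypothetical; this continues the earlier sketch):

    # Multi-hot encode the frame labels of item 0.
    codes = dataset.frame_lcodes['multilabel']   # e.g. {'fracture': 0, 'implant': 1}
    target = torch.zeros(len(codes))
    for ann in dataset.get_annotations(0, type='label', scope='frame'):
        target[codes[ann['name']]] = 1.0
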
+     def get_framelabel_distribution(self, normalize=False) -> dict[str, float]:
+         """
+         Returns the distribution of frame labels in the dataset.
+
+         Returns:
+             dict[str, float]: Mapping from label name to count (or to relative frequency if ``normalize`` is True).
+         """
+         label_distribution = {label: 0 for label in self.frame_labels_set}
+         for imginfo in self.images_metainfo:
+             for ann in imginfo['annotations']:
+                 if ann['type'] == 'label' and ann.get('index') is not None:
+                     label_distribution[ann['name']] += 1
+
+         if normalize:
+             total = sum(label_distribution.values())
+             if total == 0:
+                 return label_distribution
+             label_distribution = {k: v / total for k, v in label_distribution.items()}
+         return label_distribution
+
+     def get_segmentationlabel_distribution(self, normalize=False) -> dict[str, float]:
+         """
+         Returns the distribution of segmentation labels in the dataset.
+
+         Returns:
+             dict[str, float]: Mapping from segmentation label name to count (or to relative frequency if ``normalize`` is True).
+         """
+         label_distribution = {label: 0 for label in self.segmentation_labels_set}
+         for imginfo in self.images_metainfo:
+             if 'annotations' in imginfo and imginfo['annotations'] is not None:
+                 for ann in imginfo['annotations']:
+                     if ann['type'] == 'segmentation':
+                         label_distribution[ann['name']] += 1
+
+         if normalize:
+             total = sum(label_distribution.values())
+             if total == 0:
+                 return label_distribution
+             label_distribution = {k: v / total for k, v in label_distribution.items()}
+         return label_distribution
+
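
These distributions are convenient for dealing with class imbalance; for instance, a purely illustrative inverse-frequency weighting:

    # Inverse-frequency class weights from the segmentation label distribution.
    dist = dataset.get_segmentationlabel_distribution(normalize=True)
    weights = {name: 0.0 if freq == 0 else 1.0 / freq for name, freq in dist.items()}
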
+     def _check_integrity(self):
+         for imginfo in self.images_metainfo:
+             if not os.path.isfile(os.path.join(self.dataset_dir, imginfo['file'])):
+                 raise DatamintDatasetException(f"Image file {imginfo['file']} not found.")
+
+     def _get_datasetinfo(self) -> dict:
+         all_datasets = self.api_handler.get_datasets()
+
+         value_to_search = self.dataset_id
+         field_to_search = 'id'
+
+         for d in all_datasets:
+             if d[field_to_search] == value_to_search:
+                 return d
+
+         available_datasets = [(d['name'], d['id']) for d in all_datasets]
+         raise DatamintDatasetException(
+             f"Dataset with {field_to_search} '{value_to_search}' not found." +
+             f" Available datasets: {available_datasets}"
+         )
+
+     def get_info(self) -> dict:
+         project = self.api_handler.get_project_by_name(self.project_name)
+         if 'error' in project:
+             available_projects = project['all_projects']
+             raise DatamintDatasetException(
+                 f"Project with name '{self.project_name}' not found. Available projects: {available_projects}"
+             )
+         return project
+
+     def _run_request(self, session, request_args) -> requests.Response:
+         response = session.request(**request_args)
+         if response.status_code == 400:
+             _LOGGER.error(f"Bad request: {response.text}")
+         response.raise_for_status()
+         return response
+
+     def _get_jwttoken(self, dataset_id, session) -> str:
+         if dataset_id is None:
+             raise ValueError("Dataset ID is required to download the dataset.")
+         request_params = {
+             'method': 'GET',
+             'url': f'{self.server_url}/datasets/{dataset_id}/download/png',
+             'headers': {'apikey': self.api_key},
+             'stream': True
+         }
+         _LOGGER.debug(f"Getting jwt token for dataset {dataset_id}...")
+         response = self._run_request(session, request_params)
+         progress_bar = None
+         number_processed_images = 0
+
+         # check if the response is a stream of data and everything is ok
+         if response.status_code != 200:
+             msg = f"Getting jwt token failed with status code={response.status_code}: {response.text}"
+             raise DatamintDatasetException(msg)
+
+         try:
+             response_iterator = response.iter_lines(decode_unicode=True)
+             for line in response_iterator:
+                 line = line.strip()
+                 if 'event: error' in line:
+                     error_msg = line + '\n'
+                     error_msg += '\n'.join(response_iterator)
+                     raise DatamintDatasetException(f"Getting jwt token failed:\n{error_msg}")
+                 if not line.startswith('data:'):
+                     continue
+                 dataline = yaml.safe_load(line)['data']
+                 if 'zip' in dataline:
+                     _LOGGER.debug(f"Got jwt token for dataset {dataset_id}")
+                     return dataline['zip']  # Function normally ends here
+                 elif 'processedImages' in dataline:
+                     if progress_bar is None:
+                         total_size = int(dataline['totalImages'])
+                         progress_bar = tqdm(total=total_size, unit='iB', unit_scale=True)
+                     processed_images = int(dataline['processedImages'])
+                     if number_processed_images < processed_images:
+                         progress_bar.update(processed_images - number_processed_images)
+                         number_processed_images = processed_images
+                 else:
+                     _LOGGER.warning(f"Unknown data line: {dataline}")
+         finally:
+             if progress_bar is not None:
+                 progress_bar.close()
+
+         raise DatamintDatasetException("Getting jwt token failed! No dataline with 'zip' entry found.")
+
+     def __repr__(self) -> str:
+         """
+         Example:
+             .. code-block:: python
+
+                 print(dataset)
+
+             Output:
+
+             .. code-block:: text
+
+                 Dataset MyProjectName
+                     Number of datapoints: 3
+                     Location: /home/user/.datamint/datasets/MyProjectName
+
+         """
+         head = f"Dataset {self.project_name}"
+         body = [f"Number of datapoints: {self.__len__()}"]
+         if self.root is not None:
+             body.append(f"Location: {self.dataset_dir}")
+
+         # Add filter information to the representation
+         if self.include_annotators is not None:
+             body += [f"Including only annotators: {self.include_annotators}"]
+         if self.exclude_annotators is not None:
+             body += [f"Excluding annotators: {self.exclude_annotators}"]
+         if self.include_segmentation_names is not None:
+             body += [f"Including only segmentations: {self.include_segmentation_names}"]
+         if self.exclude_segmentation_names is not None:
+             body += [f"Excluding segmentations: {self.exclude_segmentation_names}"]
+         if self.include_image_label_names is not None:
+             body += [f"Including only image labels: {self.include_image_label_names}"]
+         if self.exclude_image_label_names is not None:
+             body += [f"Excluding image labels: {self.exclude_image_label_names}"]
+         if self.include_frame_label_names is not None:
+             body += [f"Including only frame labels: {self.include_frame_label_names}"]
+         if self.exclude_frame_label_names is not None:
+             body += [f"Excluding frame labels: {self.exclude_frame_label_names}"]
+
+         lines = [head] + [" " * 4 + line for line in body]
+         return "\n".join(lines)
+
+     def download_project(self):
+         from torchvision.datasets.utils import extract_archive
+
+         dataset_info = self._get_datasetinfo()
+         self.dataset_id = dataset_info['id']
+         self.last_updated_at = dataset_info['updated_at']
+
+         self.api_handler.download_project(self.project_info['id'],
+                                           self.dataset_zippath,
+                                           all_annotations=self.all_annotations,
+                                           include_unannotated=self.include_unannotated)
+         _LOGGER.debug("Downloaded dataset")
+         downloaded_size = os.path.getsize(self.dataset_zippath)
+         if downloaded_size == 0:
+             raise DatamintDatasetException("Download failed.")
+
+         if os.path.exists(self.dataset_dir):
+             _LOGGER.info(f"Deleting existing dataset directory: {self.dataset_dir}")
+             shutil.rmtree(self.dataset_dir)
+         extract_archive(self.dataset_zippath,
+                         self.dataset_dir,
+                         remove_finished=True)
+         datasetjson = os.path.join(self.dataset_dir, 'dataset.json')
+         with open(datasetjson, 'r') as file:
+             self.metainfo = json.load(file)
+         if 'updated_at' not in self.metainfo:
+             self.metainfo['updated_at'] = self.last_updated_at
+         else:
+             # if self.last_updated_at is newer than the one in the dataset, update it
+             try:
+                 if datetime.fromisoformat(self.metainfo['updated_at']) < datetime.fromisoformat(self.last_updated_at):
+                     _LOGGER.warning(f"Inconsistent updated_at dates detected ({self.metainfo['updated_at']} < {self.last_updated_at}). "
+                                     f"Fixing it to {self.last_updated_at}")
+                     self.metainfo['updated_at'] = self.last_updated_at
+             except Exception as e:
+                 _LOGGER.warning(f"Failed to parse updated_at date: {e}")
+
+         # Add all_annotations to the metadata
+         self.metainfo['all_annotations'] = self.all_annotations
+
+         # save the updated_at date
+         with open(datasetjson, 'w') as file:
+             json.dump(self.metainfo, file)
+
+     def _load_image(self, filepath: str, index: Optional[int] = None) -> tuple[torch.Tensor, pydicom.FileDataset]:
+         if os.path.isdir(filepath):
+             raise NotImplementedError("Loading an image from a directory is not supported yet.")
+
+         if self.return_frame_by_frame:
+             img, ds = read_array_normalized(filepath, return_metainfo=True, index=index)
+         else:
+             img, ds = read_array_normalized(filepath, return_metainfo=True)
+
+         if img.dtype == np.uint16:
+             # PyTorch doesn't support uint16
+             if not self.__logged_uint16_conversion:
+                 _LOGGER.info("Original image is uint16, converting to uint8")
+                 self.__logged_uint16_conversion = True
+
+             # min-max normalization
+             img = img.astype(np.float32)
+             img = (img - img.min()) / (img.max() - img.min()) * 255
+             img = img.astype(np.uint8)
+
+         img = torch.from_numpy(img).contiguous()
+         if isinstance(img, torch.ByteTensor):
+             img = img.to(dtype=torch.get_default_dtype()).div(255)
+
+         return img, ds
+
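
The uint16 branch above is a plain min-max rescale to the uint8 range; a tiny standalone illustration of the same arithmetic (note that a constant image would divide by zero, which the code above assumes does not happen):

    x = np.array([100, 500, 65535], dtype=np.uint16)
    y = (x.astype(np.float32) - x.min()) / (x.max() - x.min()) * 255
    print(y.astype(np.uint8))  # [0, 1, 255]
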
+     def _get_image_metainfo(self, index: int, bypass_subset_indices=False) -> dict[str, Any]:
+         if not bypass_subset_indices:
+             index = self.subset_indices[index]
+         if self.return_frame_by_frame:
+             # Find the correct resource and frame index
+             resource_idx, frame_index = self.__find_index(index)
+
+             img_metainfo = self.images_metainfo[resource_idx]
+             img_metainfo = dict(img_metainfo)  # copy
+             # insert frame index
+             img_metainfo['frame_index'] = frame_index
+             img_metainfo['annotations'] = [ann for ann in img_metainfo['annotations']
+                                            if ann['index'] is None or ann['index'] == frame_index]
+         else:
+             img_metainfo = self.images_metainfo[index]
+         return img_metainfo
+
+     def __find_index(self, index: int) -> tuple[int, int]:
+         frame_index = index
+         for i, num_frames in enumerate(self.num_frames_per_resource):
+             if frame_index < num_frames:
+                 break
+             frame_index -= num_frames
+         else:
+             raise IndexError(f"Index {index} out of bounds for dataset of length {len(self)}")
+
+         return i, frame_index
+
+     def __getitem_internal(self, index: int, only_load_metainfo=False) -> dict[str, Any]:
+         if self.return_frame_by_frame:
+             resource_index, frame_idx = self.__find_index(index)
+         else:
+             resource_index = index
+             frame_idx = None
+         img_metainfo = self._get_image_metainfo(index, bypass_subset_indices=True)
+
+         if only_load_metainfo:
+             return {'metainfo': img_metainfo}
+
+         filepath = os.path.join(self.dataset_dir, img_metainfo['file'])
+
+         # Can be multi-frame, gray-scale and/or RGB. The shape is variable, but it is always a numpy array.
+         img, ds = self._load_image(filepath, frame_idx)
+
+         ret = {'image': img}
+
+         if self.return_dicom:
+             ret['dicom'] = ds
+         if self.return_metainfo:
+             ret['metainfo'] = {k: v for k, v in img_metainfo.items() if k != 'annotations'}
+         if self.return_annotations:
+             ret['annotations'] = img_metainfo['annotations']
+
+         return ret
+
+     def _filter_annotations(self, annotations: list[dict]) -> list[dict]:
+         """
+         Filter annotations based on the filtering settings.
+
+         Args:
+             annotations: list of annotations
+
+         Returns:
+             list[dict]: filtered list of annotations
+         """
+         if annotations is None:
+             return []
+
+         filtered_annotations = []
+         for ann in annotations:
+             # Filter by annotator
+             if not self._should_include_annotator(ann['added_by']):
+                 continue
+
+             # Filter by annotation type and name
+             if ann['type'] == 'segmentation':
+                 if not self._should_include_segmentation(ann['name']):
+                     continue
+             elif ann['type'] == 'label':
+                 # Check if it's a frame or image label
+                 if ann.get('index', None) is None:
+                     # Image label
+                     if not self._should_include_image_label(ann['name']):
+                         continue
+                 else:
+                     # Frame label
+                     if not self._should_include_frame_label(ann['name']):
+                         continue
+
+             # If we reach here, the annotation passed all filters
+             filtered_annotations.append(ann)
+
+         return filtered_annotations
+
+     def __getitem__(self, index: int) -> dict[str, Any]:
+         """
+         Args:
+             index (int): Index
+
+         Returns:
+             dict: A dictionary with the key 'image' and, depending on the ``return_*`` flags,
+             the keys 'dicom', 'metainfo' and 'annotations'.
+         """
+         if index >= len(self):
+             raise IndexError(f"Index {index} out of bounds for dataset of length {len(self)}")
+
+         return self.__getitem_internal(self.subset_indices[index])
+
+     def __iter__(self):
+         for i in range(len(self)):
+             yield self[i]
+
+     def __len__(self) -> int:
+         return len(self.subset_indices)
+
+     def _check_version(self):
+         metainfo_path = os.path.join(self.dataset_dir, 'dataset.json')
+         if not os.path.exists(metainfo_path):
+             self.download_project()
+             return
+         with open(metainfo_path, 'r') as file:
+             local_dataset_info = json.load(file)
+             local_updated_at = local_dataset_info.get('updated_at', None)
+             local_all_annotations = local_dataset_info.get('all_annotations', None)
+
+         try:
+             external_metadata_info = self._get_datasetinfo()
+             server_updated_at = external_metadata_info['updated_at']
+         except Exception as e:
+             _LOGGER.warning(f"Failed to check for updates in {self.project_name}: {e}")
+             return
+
+         _LOGGER.debug(f"Local updated at: {local_updated_at}, Server updated at: {server_updated_at}")
+
+         # Check if all_annotations changed or doesn't exist
+         annotations_changed = local_all_annotations != self.all_annotations
+
+         if local_updated_at is None or local_updated_at < server_updated_at or annotations_changed:
+             if annotations_changed:
+                 _LOGGER.info(
+                     f"The 'all_annotations' parameter has changed. Previous: {local_all_annotations}, Current: {self.all_annotations}."
+                 )
+             else:
+                 _LOGGER.info(
+                     f"A newer version of the dataset is available. Your version: {local_updated_at}."
+                     f" Latest version: {server_updated_at}."
+                 )
+             self.download_project()
+         else:
+             _LOGGER.info('Local version is up to date with the latest version.')
+
+     def __add__(self, other):
+         from torch.utils.data import ConcatDataset
+         return ConcatDataset([self, other])
+
+     def get_dataloader(self, *args, **kwargs) -> DataLoader:
+         """
+         Returns a DataLoader for the dataset.
+         This is a wrapper around the PyTorch DataLoader, with the convenience of a collate_fn
+         that properly handles the different types of data in this dataset.
+
+         Args:
+             *args: Positional arguments for the DataLoader. See `torch.utils.data.DataLoader <https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader>`_ for details.
+             **kwargs: Keyword arguments for the DataLoader. See `torch.utils.data.DataLoader <https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader>`_ for details.
+         """
+         return DataLoader(self,
+                           *args,
+                           collate_fn=self.get_collate_fn(),
+                           **kwargs)
+
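
A sketch of the intended use (the batch size is arbitrary; this continues the earlier sketch):

    loader = dataset.get_dataloader(batch_size=4, shuffle=True)
    for batch in loader:
        images = batch['image']        # stacked tensor, or a list if shapes differ
        annots = batch['annotations']  # list of per-item annotation lists
        break
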
+     def get_collate_fn(self) -> Callable:
+         def collate_fn(batch: list[dict]) -> dict:
+             keys = batch[0].keys()
+             collated_batch = {}
+             for key in keys:
+                 collated_batch[key] = [item[key] for item in batch]
+                 if isinstance(collated_batch[key][0], torch.Tensor):
+                     # check if every tensor has the same shape
+                     shapes = [tensor.shape for tensor in collated_batch[key]]
+                     if all(shape == shapes[0] for shape in shapes):
+                         collated_batch[key] = torch.stack(collated_batch[key])
+                     else:
+                         _LOGGER.warning(f"Collating {key} tensors with different shapes: {shapes}.")
+                 elif isinstance(collated_batch[key][0], np.ndarray):
+                     collated_batch[key] = np.stack(collated_batch[key])
+
+             return collated_batch
+
+         return collate_fn
+
+     def subset(self, indices: list[int]) -> 'DatamintBaseDataset':
+         if any(idx >= self.dataset_length for idx in indices):
+             raise ValueError(f"Subset indices must be less than the dataset length: {self.dataset_length}")
+
+         self.subset_indices = indices
+
+         return self
+
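
Note that ``subset`` mutates and returns the same object rather than a copy, so restricting a dataset looks like this (the cutoff is arbitrary):

    # Keep only the first 100 items (or fewer, for small datasets).
    dataset.subset(list(range(min(100, len(dataset)))))
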
+     def _should_include_annotator(self, annotator_id: str) -> bool:
+         """
+         Check if an annotator should be included based on the filtering settings.
+
+         Args:
+             annotator_id: The ID of the annotator to check
+
+         Returns:
+             bool: True if the annotator should be included, False otherwise
+         """
+         if self.include_annotators is not None:
+             return annotator_id in self.include_annotators
+         if self.exclude_annotators is not None:
+             return annotator_id not in self.exclude_annotators
+         return True
+
+     def _should_include_segmentation(self, segmentation_name: str) -> bool:
+         """
+         Check if a segmentation should be included based on the filtering settings.
+
+         Args:
+             segmentation_name: The name of the segmentation to check
+
+         Returns:
+             bool: True if the segmentation should be included, False otherwise
+         """
+         if self.include_segmentation_names is not None:
+             return segmentation_name in self.include_segmentation_names
+         if self.exclude_segmentation_names is not None:
+             return segmentation_name not in self.exclude_segmentation_names
+         return True
+
+     def _should_include_image_label(self, label_name: str) -> bool:
+         """
+         Check if an image label should be included based on the filtering settings.
+
+         Args:
+             label_name: The name of the image label to check
+
+         Returns:
+             bool: True if the image label should be included, False otherwise
+         """
+         if self.include_image_label_names is not None:
+             return label_name in self.include_image_label_names
+         if self.exclude_image_label_names is not None:
+             return label_name not in self.exclude_image_label_names
+         return True
+
+     def _should_include_frame_label(self, label_name: str) -> bool:
+         """
+         Check if a frame label should be included based on the filtering settings.
+
+         Args:
+             label_name: The name of the frame label to check
+
+         Returns:
+             bool: True if the frame label should be included, False otherwise
+         """
+         if self.include_frame_label_names is not None:
+             return label_name in self.include_frame_label_names
+         if self.exclude_frame_label_names is not None:
+             return label_name not in self.exclude_frame_label_names
+         return True
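
Putting the pieces together, an end-to-end sketch of the class (all names are hypothetical placeholders; the import path depends on how the package exposes this module):

    dataset = DatamintBaseDataset(project_name='chest-xrays')  # placeholder name
    print(dataset.segmentation_labels_set)
    print(dataset.get_segmentationlabel_distribution(normalize=True))
    loader = dataset.get_dataloader(batch_size=8)
    for batch in loader:
        ...  # training / evaluation step goes here
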