datamint 1.6.3.post1__py3-none-any.whl → 1.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of datamint has been flagged by the registry as possibly problematic.

@@ -15,9 +15,12 @@ import torch
 from torch import Tensor
 from datamint.apihandler.base_api_handler import DatamintException
 from medimgkit.dicom_utils import is_dicom
-import cv2
 from medimgkit.io_utils import read_array_normalized
 from datetime import datetime
+from pathlib import Path
+from mimetypes import guess_extension
+from datamint.dataset.annotation import Annotation
+import cv2
 
 _LOGGER = logging.getLogger(__name__)
 
@@ -27,12 +30,11 @@ class DatamintDatasetException(DatamintException):
 
 
 class DatamintBaseDataset:
-    """
-    Class to download and load datasets from the Datamint API.
+    """Class to download and load datasets from the Datamint API.
 
     Args:
-        root: Root directory of dataset where data already exists or will be downloaded.
         project_name: Name of the project to download.
+        root: Root directory of dataset where data already exists or will be downloaded.
         auto_update: If True, the dataset will be checked for updates and downloaded if necessary.
         api_key: API key to access the Datamint API. If not provided, it will look for the
             environment variable 'DATAMINT_API_KEY'. Not necessary if
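
The docstring reorder mirrors the new signature: `project_name` is now documented first, matching its position as the only required argument. Below is a minimal construction sketch based on the documented parameters; the import path, project name, and filter value are placeholders, not values taken from this diff:

```python
# Hypothetical usage sketch; the import path is assumed and
# "my-project" / "fracture" are placeholder values.
from datamint import DatamintBaseDataset  # import path assumed

dataset = DatamintBaseDataset(
    project_name="my-project",   # required, now the first argument
    root=None,                   # defaults to ~/.datamint/datasets
    auto_update=True,
    return_frame_by_frame=False,
    include_frame_label_names=["fracture"],  # placeholder filter
)
print(len(dataset))
```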
@@ -41,68 +43,114 @@ class DatamintBaseDataset:
         return_metainfo: If True, the metainfo of the image will be returned.
         return_annotations: If True, the annotations of the image will be returned.
         return_frame_by_frame: If True, each frame of a video/DICOM/3d-image will be returned separately.
-        include_unannotated: If True, images without annotations will be included. If False, images without annotations will be discarded.
+        include_unannotated: If True, images without annotations will be included.
         all_annotations: If True, all annotations will be downloaded, including the ones that are not set as closed/done.
         server_url: URL of the Datamint server. If not provided, it will use the default server.
-        include_annotators: List of annotators to include. If None, all annotators will be included. See parameter ``exclude_annotators``.
-        exclude_annotators: List of annotators to exclude. If None, no annotators will be excluded. See parameter ``include_annotators``.
+        include_annotators: List of annotators to include. If None, all annotators will be included.
+        exclude_annotators: List of annotators to exclude. If None, no annotators will be excluded.
         include_segmentation_names: List of segmentation names to include. If None, all segmentations will be included.
         exclude_segmentation_names: List of segmentation names to exclude. If None, no segmentations will be excluded.
         include_image_label_names: List of image label names to include. If None, all image labels will be included.
         exclude_image_label_names: List of image label names to exclude. If None, no image labels will be excluded.
         include_frame_label_names: List of frame label names to include. If None, all frame labels will be included.
         exclude_frame_label_names: List of frame label names to exclude. If None, no frame labels will be excluded.
-
     """
 
     DATAMINT_DEFAULT_DIR = ".datamint"
     DATAMINT_DATASETS_DIR = "datasets"
 
-    def __init__(self,
-                 project_name: str,
-                 root: str | None = None,
-                 auto_update: bool = True,
-                 api_key: Optional[str] = None,
-                 server_url: Optional[str] = None,
-                 return_dicom: bool = False,
-                 return_metainfo: bool = True,
-                 return_annotations: bool = True,
-                 return_frame_by_frame: bool = False,
-                 include_unannotated: bool = True,
-                 all_annotations: bool = False,
-                 # filtering parameters
-                 include_annotators: Optional[list[str]] = None,
-                 exclude_annotators: Optional[list[str]] = None,
-                 include_segmentation_names: Optional[list[str]] = None,
-                 exclude_segmentation_names: Optional[list[str]] = None,
-                 include_image_label_names: Optional[list[str]] = None,
-                 exclude_image_label_names: Optional[list[str]] = None,
-                 include_frame_label_names: Optional[list[str]] = None,
-                 exclude_frame_label_names: Optional[list[str]] = None
-                 ):
-        from datamint.apihandler.api_handler import APIHandler
+    def __init__(
+        self,
+        project_name: str,
+        root: str | None = None,
+        auto_update: bool = True,
+        api_key: Optional[str] = None,
+        server_url: Optional[str] = None,
+        return_dicom: bool = False,
+        return_metainfo: bool = True,
+        return_annotations: bool = True,
+        return_frame_by_frame: bool = False,
+        include_unannotated: bool = True,
+        all_annotations: bool = False,
+        # Filtering parameters
+        include_annotators: Optional[list[str]] = None,
+        exclude_annotators: Optional[list[str]] = None,
+        include_segmentation_names: Optional[list[str]] = None,
+        exclude_segmentation_names: Optional[list[str]] = None,
+        include_image_label_names: Optional[list[str]] = None,
+        exclude_image_label_names: Optional[list[str]] = None,
+        include_frame_label_names: Optional[list[str]] = None,
+        exclude_frame_label_names: Optional[list[str]] = None,
+    ):
+        self._validate_inputs(project_name, include_annotators, exclude_annotators,
+                              include_segmentation_names, exclude_segmentation_names,
+                              include_image_label_names, exclude_image_label_names,
+                              include_frame_label_names, exclude_frame_label_names)
+
+        self._initialize_config(
+            project_name, auto_update, all_annotations, return_dicom,
+            return_metainfo, return_annotations, return_frame_by_frame,
+            include_unannotated, include_annotators, exclude_annotators,
+            include_segmentation_names, exclude_segmentation_names,
+            include_image_label_names, exclude_image_label_names,
+            include_frame_label_names, exclude_frame_label_names
+        )
 
+        self._setup_api_handler(server_url, api_key, auto_update)
+        self._setup_directories(root)
+        self._setup_dataset()
+        self._post_process_data()
+
+    def _validate_inputs(
+        self,
+        project_name: str,
+        include_annotators: Optional[list[str]],
+        exclude_annotators: Optional[list[str]],
+        include_segmentation_names: Optional[list[str]],
+        exclude_segmentation_names: Optional[list[str]],
+        include_image_label_names: Optional[list[str]],
+        exclude_image_label_names: Optional[list[str]],
+        include_frame_label_names: Optional[list[str]],
+        exclude_frame_label_names: Optional[list[str]],
+    ) -> None:
+        """Validate input parameters."""
         if project_name is None:
             raise ValueError("project_name is required.")
 
+        # Validate mutually exclusive filtering parameters
+        filter_pairs = [
+            (include_annotators, exclude_annotators, "annotators"),
+            (include_segmentation_names, exclude_segmentation_names, "segmentation_names"),
+            (include_image_label_names, exclude_image_label_names, "image_label_names"),
+            (include_frame_label_names, exclude_frame_label_names, "frame_label_names"),
+        ]
+
+        for include_param, exclude_param, param_name in filter_pairs:
+            if include_param is not None and exclude_param is not None:
+                raise ValueError(f"Cannot set both include_{param_name} and exclude_{param_name} at the same time")
+
+    def _initialize_config(
+        self,
+        project_name: str,
+        auto_update: bool,
+        all_annotations: bool,
+        return_dicom: bool,
+        return_metainfo: bool,
+        return_annotations: bool,
+        return_frame_by_frame: bool,
+        include_unannotated: bool,
+        include_annotators: Optional[list[str]],
+        exclude_annotators: Optional[list[str]],
+        include_segmentation_names: Optional[list[str]],
+        exclude_segmentation_names: Optional[list[str]],
+        include_image_label_names: Optional[list[str]],
+        exclude_image_label_names: Optional[list[str]],
+        include_frame_label_names: Optional[list[str]],
+        exclude_frame_label_names: Optional[list[str]],
+    ) -> None:
+        """Initialize configuration parameters."""
+        self.project_name = project_name
         self.all_annotations = all_annotations
-        self.api_handler = APIHandler(root_url=server_url, api_key=api_key,
-                                      check_connection=auto_update)
-        self.server_url = self.api_handler.root_url
-        if root is None:
-            # store them in the home directory
-            root = os.path.join(os.path.expanduser("~"),
-                                DatamintBaseDataset.DATAMINT_DEFAULT_DIR)
-            root = os.path.join(root, DatamintBaseDataset.DATAMINT_DATASETS_DIR)
-            if not os.path.exists(root):
-                os.makedirs(root)
-        elif isinstance(root, str):
-            root = os.path.expanduser(root)
-            if not os.path.isdir(root):
-                raise NotADirectoryError(f"Root directory not found: {root}")
-
-        self.root = root
-
         self.return_dicom = return_dicom
         self.return_metainfo = return_metainfo
         self.return_frame_by_frame = return_frame_by_frame
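
The four copy-pasted include/exclude checks of 1.6.3 are collapsed here into one loop over `filter_pairs`. A standalone sketch of that pattern (a re-implementation for illustration, not the package API) showing the error it raises:

```python
# Minimal sketch of the mutually-exclusive filter check, assuming the
# same error message format as the diff above.
def check_filter_pair(include, exclude, name: str) -> None:
    # Either list may be None; setting both is rejected.
    if include is not None and exclude is not None:
        raise ValueError(f"Cannot set both include_{name} and exclude_{name} at the same time")

check_filter_pair(["alice"], None, "annotators")   # OK
check_filter_pair(None, ["bob"], "annotators")     # OK
try:
    check_filter_pair(["alice"], ["bob"], "annotators")
except ValueError as e:
    print(e)  # Cannot set both include_annotators and exclude_annotators at the same time
```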
@@ -120,108 +168,153 @@ class DatamintBaseDataset:
         self.include_frame_label_names = include_frame_label_names
         self.exclude_frame_label_names = exclude_frame_label_names
 
-        # Validate filtering parameters
-        if include_annotators is not None and exclude_annotators is not None:
-            raise ValueError("Cannot set both include_annotators and exclude_annotators at the same time")
-
-        if include_segmentation_names is not None and exclude_segmentation_names is not None:
-            raise ValueError("Cannot set both include_segmentation_names and exclude_segmentation_names at the same time")
-
-        if include_image_label_names is not None and exclude_image_label_names is not None:
-            raise ValueError("Cannot set both include_image_label_names and exclude_image_label_names at the same time")
-
-        if include_frame_label_names is not None and exclude_frame_label_names is not None:
-            raise ValueError("Cannot set both include_frame_label_names and exclude_frame_label_names at the same time")
+        # Internal state
+        self.__logged_uint16_conversion = False
+        self.auto_update = auto_update
 
-        self.project_name = project_name
-        dataset_name = project_name
+    def _setup_api_handler(self, server_url: Optional[str], api_key: Optional[str], auto_update: bool) -> None:
+        """Setup API handler and validate connection."""
+        from datamint.apihandler.api_handler import APIHandler
 
-        self.dataset_dir = os.path.join(root, dataset_name)
-        self.dataset_zippath = os.path.join(root, f'{dataset_name}.zip')
+        self.api_handler = APIHandler(
+            root_url=server_url,
+            api_key=api_key,
+            check_connection=auto_update
+        )
+        self.server_url = self.api_handler.root_url
+        self.api_key = self.api_handler.api_key
 
-        local_dataset_exists = os.path.exists(os.path.join(self.dataset_dir, 'dataset.json'))
+        if self.api_key is None:
+            _LOGGER.warning(
+                "API key not provided. If you want to download data, please provide an API key, "
+                f"either by passing it as an argument, "
+                f"setting environment variable {configs.ENV_VARS[configs.APIKEY_KEY]} or "
+                "using datamint-config command line tool."
+            )
 
-        if local_dataset_exists and auto_update == False:
-            # In this case, we don't need to check the API, so we don't need the id.
-            self.dataset_id = None
+    def _setup_directories(self, root: str | None) -> None:
+        """Setup root and dataset directories."""
+        if root is None:
+            root = os.path.join(
+                os.path.expanduser("~"),
+                self.DATAMINT_DEFAULT_DIR,
+                self.DATAMINT_DATASETS_DIR
+            )
+            os.makedirs(root, exist_ok=True)
         else:
-            self.project_info = self.get_info()
-            self.dataset_id = self.project_info['dataset_id']
+            root = os.path.expanduser(root)
+            if not os.path.isdir(root):
+                raise NotADirectoryError(f"Root directory not found: {root}")
 
-        self.api_key = self.api_handler.api_key
-        if self.api_key is None:
-            _LOGGER.warning("API key not provided. If you want to download data, please provide an API key, " +
-                            f"either by passing it as an argument," +
-                            f"setting environment variable {configs.ENV_VARS[configs.APIKEY_KEY]} or " +
-                            "using datamint-config command line tool."
-                            )
-
-        # Download/Updates the dataset, if necessary.
-        if local_dataset_exists:
-            _LOGGER.info(f"Dataset directory already exists: {self.dataset_dir}")
-            if auto_update:
+        self.root = root
+        self.dataset_dir = os.path.join(root, self.project_name)
+        self.dataset_zippath = os.path.join(root, f'{self.project_name}.zip')
+
+        if not os.path.exists(self.dataset_dir):
+            os.makedirs(self.dataset_dir, exist_ok=True)
+            os.makedirs(os.path.join(self.dataset_dir, 'images'), exist_ok=True)
+            os.makedirs(os.path.join(self.dataset_dir, 'masks'), exist_ok=True)
+
+    def _setup_dataset(self) -> None:
+        """Setup dataset by downloading or loading existing data."""
+        self._server_dataset_info = None
+        local_load_success = self._load_metadata()
+        self._handle_dataset_download_or_update(local_load_success)
+        self._apply_annotation_filters()
+
+    def _handle_dataset_download_or_update(self, local_load_success: bool) -> None:
+        """Handle dataset download or update logic."""
+
+        if local_load_success:
+            _LOGGER.debug(f"Dataset directory already exists: {self.dataset_dir}")
+            # Check for updates if auto_update is enabled and we have API access
+            if self.auto_update:
                 _LOGGER.info("Checking for updates...")
                 self._check_version()
         else:
-            if self.api_key is None:
-                raise DatamintDatasetException("API key is required to download the dataset.")
-            _LOGGER.info(f"No data found at {self.dataset_dir}. Downloading...")
-            self.download_project()
-
-        # Loads the metadata
-        if not hasattr(self, 'metainfo'):
-            with open(os.path.join(self.dataset_dir, 'dataset.json'), 'r') as file:
+            self._check_version()
+
+    def _load_metadata(self) -> bool:
+        """Load and process dataset metadata."""
+        if hasattr(self, 'metainfo'):
+            _LOGGER.warning("Metadata already loaded.")
+        metadata_path = os.path.join(self.dataset_dir, 'dataset.json')
+        if not os.path.isfile(metadata_path):
+            # get the server info
+            self.project_info = self.get_info()
+            self.metainfo = self._get_datasetinfo().copy()
+            self.metainfo['updated_at'] = None
+            self.metainfo['resources'] = []
+            self.metainfo['all_annotations'] = self.all_annotations
+            self.images_metainfo = self.metainfo['resources']
+            return False
+        else:
+            with open(metadata_path, 'r') as file:
                 self.metainfo = json.load(file)
             self.images_metainfo = self.metainfo['resources']
+            # Convert annotations from dict to Annotation objects
+            self._convert_metainfo_to_clsobj()
+            return True
 
-        # filter annotations
+    def _convert_metainfo_to_clsobj(self):
+        for imginfo in self.images_metainfo:
+            if 'annotations' in imginfo:
+                for ann in imginfo['annotations']:
+                    if 'resource_id' not in ann:
+                        ann['resource_id'] = imginfo['id']
+                    if 'id' not in ann:
+                        ann['id'] = None
+                imginfo['annotations'] = [Annotation.from_dict(ann) if isinstance(ann, dict) else ann
+                                          for ann in imginfo['annotations']]
+
+    def _apply_annotation_filters(self) -> None:
+        """Apply annotation filters and remove unannotated images if needed."""
+        # Filter annotations for each image
         for imginfo in self.images_metainfo:
             imginfo['annotations'] = self._filter_annotations(imginfo['annotations'])
 
-        # filter out images with no annotations.
+        # Filter out images with no annotations if needed
         if self.discard_without_annotations:
             original_count = len(self.images_metainfo)
             self.images_metainfo = self._filter_items(self.images_metainfo)
             _LOGGER.info(f"Discarded {original_count - len(self.images_metainfo)} images without annotations.")
 
+    def _post_process_data(self) -> None:
+        """Post-process data after loading metadata."""
         self._check_integrity()
+        self._calculate_dataset_length()
+        self._precompute_frame_data()
+        self._setup_labels()
 
-        # fix images_metainfo labels
-        # TODO: check tags
-        # for imginfo in self.images_metainfo:
-        #     if imginfo['frame_labels'] is not None:
-        #         for flabels in imginfo['frame_labels']:
-        #             if flabels['label'] is None:
-        #                 flabels['label'] = []
-        #             elif isinstance(flabels['label'], str):
-        #                 flabels['label'] = flabels['label'].split(',')
+        if self.discard_without_annotations and self.return_frame_by_frame:
+            self._filter_unannotated()
 
+    def _calculate_dataset_length(self) -> None:
+        """Calculate the total dataset length based on frame-by-frame setting."""
         if self.return_frame_by_frame:
-            self.dataset_length = 0
-            for imginfo in self.images_metainfo:
-                filepath = os.path.join(self.dataset_dir, imginfo['file'])
-                self.dataset_length += self.read_number_of_frames(filepath)
+            self.dataset_length = sum(
+                self.read_number_of_frames(os.path.join(self.dataset_dir, imginfo['file']))
+                for imginfo in self.images_metainfo
+            )
         else:
             self.dataset_length = len(self.images_metainfo)
 
+    def _precompute_frame_data(self) -> None:
+        """Precompute frame-related data for efficient indexing."""
         self.num_frames_per_resource = self.__compute_num_frames_per_resource()
-
-        # Precompute cumulative frame counts for faster index lookup
         self._cumulative_frames = np.cumsum([0] + self.num_frames_per_resource)
-
         self.subset_indices = list(range(self.dataset_length))
-        # self.labels_set, self.label2code, self.segmentation_labels, self.segmentation_label2code = self.get_labels_set()
+
+    def _setup_labels(self) -> None:
+        """Setup label sets and mappings."""
         self.frame_lsets, self.frame_lcodes = self._get_labels_set(framed=True)
         self.image_lsets, self.image_lcodes = self._get_labels_set(framed=False)
-        self.__logged_uint16_conversion = False
-        if self.discard_without_annotations and self.return_frame_by_frame:
-            # If we are returning frame by frame, we need to filter out frames without segmentations
-            self._filter_unannotated()
 
     def _filter_items(self, images_metainfo: list[dict]) -> list[dict]:
+        """Filter items that have annotations."""
         return [img for img in images_metainfo if len(img.get('annotations', []))]
 
-    def _filter_unannotated(self):
+    def _filter_unannotated(self) -> None:
         """Filter out frames that don't have any segmentations."""
         filtered_indices = []
         for i in range(len(self.subset_indices)):
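
`_precompute_frame_data` pairs `np.cumsum` with the binary search in the private `__find_index` helper (see the later hunk) to map a global frame index to a (resource, frame) pair. A self-contained sketch of that mapping, with invented frame counts:

```python
# Standalone illustration of the cumulative-frames indexing; the
# per-resource frame counts are made up for the example.
import numpy as np

num_frames_per_resource = [10, 1, 25]                  # e.g. a CT series, a PNG, a video
cumulative = np.cumsum([0] + num_frames_per_resource)  # array([ 0, 10, 11, 36])

def find_index(global_index: int) -> tuple[int, int]:
    # Locate the resource whose cumulative range contains the index,
    # then offset into that resource.
    resource_index = int(np.searchsorted(cumulative[1:], global_index, side='right'))
    return resource_index, global_index - int(cumulative[resource_index])

print(find_index(0))   # (0, 0)  -> first frame of the first resource
print(find_index(10))  # (1, 0)  -> the single-frame image
print(find_index(12))  # (2, 1)  -> second frame of the video
```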
@@ -229,124 +322,125 @@ class DatamintBaseDataset:
             annotations = item_meta.get('annotations', [])
 
             # Check if there are any segmentation annotations
-            has_segmentations = any(ann['type'] == 'segmentation' for ann in annotations)
+            has_segmentations = any(ann.type == 'segmentation' for ann in annotations)
 
             if has_segmentations:
                 filtered_indices.append(self.subset_indices[i])
 
         self.subset_indices = filtered_indices
-        print(f"Filtered dataset: {len(self.subset_indices)} frames with segmentations")
+        _LOGGER.debug(f"Filtered dataset: {len(self.subset_indices)} frames with segmentations")
 
     def __compute_num_frames_per_resource(self) -> list[int]:
-        num_frames_per_dicom = []
-        for imginfo in self.images_metainfo:
-            filepath = os.path.join(self.dataset_dir, imginfo['file'])
-            num_frames_per_dicom.append(self.read_number_of_frames(filepath))
-        return num_frames_per_dicom
+        """Compute number of frames for each resource."""
+        return [
+            self.read_number_of_frames(os.path.join(self.dataset_dir, imginfo['file']))
+            for imginfo in self.images_metainfo
+        ]
 
     @property
     def frame_labels_set(self) -> list[str]:
-        """
-        Returns the set of independent labels in the dataset.
-        This is more related to multi-label tasks.
-        """
+        """Returns the set of independent labels in the dataset (multi-label tasks)."""
         return self.frame_lsets['multilabel']
 
     @property
     def frame_categories_set(self) -> list[tuple[str, str]]:
-        """
-        Returns the set of categories in the dataset.
-        This is more related to multi-class tasks.
-        """
+        """Returns the set of categories in the dataset (multi-class tasks)."""
         return self.frame_lsets['multiclass']
 
     @property
     def image_labels_set(self) -> list[str]:
-        """
-        Returns the set of independent labels in the dataset.
-        This is more related to multi-label tasks.
-        """
+        """Returns the set of independent labels in the dataset (multi-label tasks)."""
         return self.image_lsets['multilabel']
 
     @property
     def image_categories_set(self) -> list[tuple[str, str]]:
-        """
-        Returns the set of categories in the dataset.
-        This is more related to multi-class tasks.
-        """
+        """Returns the set of categories in the dataset (multi-class tasks)."""
         return self.image_lsets['multiclass']
 
     @property
     def segmentation_labels_set(self) -> list[str]:
-        """
-        Returns the set of segmentation labels in the dataset.
-        """
+        """Returns the set of segmentation labels in the dataset."""
         return self.frame_lsets['segmentation']
 
-    def _get_annotations_internal(self,
-                                  annotations: list[dict],
-                                  type: Literal['label', 'category', 'segmentation', 'all'] = 'all',
-                                  scope: Literal['frame', 'image', 'all'] = 'all') -> list[dict]:
-        # check parameters
+    def _get_annotations_internal(
+        self,
+        annotations: list[Annotation],
+        type: Literal['label', 'category', 'segmentation', 'all'] = 'all',
+        scope: Literal['frame', 'image', 'all'] = 'all'
+    ) -> list[Annotation]:
+        """Internal method to filter annotations by type and scope."""
         if type not in ['label', 'category', 'segmentation', 'all']:
             raise ValueError(f"Invalid value for 'type': {type}")
         if scope not in ['frame', 'image', 'all']:
             raise ValueError(f"Invalid value for 'scope': {scope}")
 
-        annots = []
+        filtered_annotations = []
         for ann in annotations:
-            ann_scope = 'image' if ann.get('index', None) is None else 'frame'
-            if (type == 'all' or ann['type'] == type) and (scope == 'all' or scope == ann_scope):
-                annots.append(ann)
-        return annots
-
-    def get_annotations(self,
-                        index: int,
-                        type: Literal['label', 'category', 'segmentation', 'all'] = 'all',
-                        scope: Literal['frame', 'image', 'all'] = 'all') -> list[dict]:
-        """
-        Returns the annotations of the image at the given index.
+            ann_scope = 'image' if ann.index is None else 'frame'
+
+            type_matches = type == 'all' or ann.type == type
+            scope_matches = scope == 'all' or scope == ann_scope
+
+            if type_matches and scope_matches:
+                filtered_annotations.append(ann)
+
+        return filtered_annotations
+
+    def get_annotations(
+        self,
+        index: int,
+        type: Literal['label', 'category', 'segmentation', 'all'] = 'all',
+        scope: Literal['frame', 'image', 'all'] = 'all'
+    ) -> list[Annotation]:
+        """Returns the annotations of the image at the given index.
 
         Args:
-            index (int): Index of the image.
-            type (str): The type of the annotations. It can be 'label', 'category', 'segmentation' or 'all'.
-            scope (str): The scope of the annotations. It can be 'frame', 'image' or 'all'.
+            index: Index of the image.
+            type: The type of the annotations. Can be 'label', 'category', 'segmentation' or 'all'.
+            scope: The scope of the annotations. Can be 'frame', 'image' or 'all'.
 
         Returns:
-            list[dict]: The annotations of the image.
+            The annotations of the image.
         """
         if index >= len(self):
             raise IndexError(f"Index {index} out of bounds for dataset of length {len(self)}")
+
         imginfo = self._get_image_metainfo(index)
         return self._get_annotations_internal(imginfo['annotations'], type=type, scope=scope)
 
     @staticmethod
     def read_number_of_frames(filepath: str) -> int:
-        # if is dicom
+        """Read the number of frames in a file."""
         if is_dicom(filepath):
             ds = pydicom.dcmread(filepath)
-            return ds.NumberOfFrames if hasattr(ds, 'NumberOfFrames') else 1
-        # if is a video
-        elif filepath.endswith('.mp4') or filepath.endswith('.avi'):
+            return getattr(ds, 'NumberOfFrames', 1)
+        elif filepath.lower().endswith(('.mp4', '.avi')):
             cap = cv2.VideoCapture(filepath)
-            return int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
-        # if is a image
-        elif filepath.endswith('.png') or filepath.endswith('.jpg') or filepath.endswith('.jpeg'):
+            try:
+                return int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+            finally:
+                cap.release()
+        elif filepath.lower().endswith(('.png', '.jpg', '.jpeg')):
             return 1
         else:
             raise ValueError(f"Unsupported file type: {filepath}")
 
     def get_resources_ids(self) -> list[str]:
-        return [self.__getitem_internal(i, only_load_metainfo=True)['metainfo']['id'] for i in self.subset_indices]
+        """Get list of resource IDs."""
+        return [
+            self.__getitem_internal(i, only_load_metainfo=True)['metainfo']['id']
+            for i in self.subset_indices
+        ]
 
     def _get_labels_set(self, framed: bool) -> tuple[dict, dict[str, dict[str, int]]]:
-        """
-        Returns the set of labels and a dictionary that maps labels to integers.
+        """Returns the set of labels and mappings to integers.
+
+        Args:
+            framed: If True, get frame-level labels, otherwise image-level labels.
 
         Returns:
-            Tuple[List[str], Dict[str, int]]: The set of labels and the dictionary that maps labels to integers
+            Tuple containing label sets and label-to-code mappings.
         """
-
         scope = 'frame' if framed else 'image'
 
         multilabel_set = set()
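
A notable fix in this hunk: `read_number_of_frames` now releases the OpenCV capture handle in a `finally` block instead of leaking it on every call. The pattern in isolation (requires opencv-python; the path below is a placeholder):

```python
# Minimal sketch of the new try/finally pattern; "some_video.mp4" is
# a placeholder path, not a file from this package.
import cv2

def count_video_frames(filepath: str) -> int:
    cap = cv2.VideoCapture(filepath)
    try:
        # CAP_PROP_FRAME_COUNT may be approximate for some containers.
        return int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    finally:
        cap.release()  # previously the handle was never released
```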
@@ -354,100 +448,113 @@ class DatamintBaseDataset:
         multiclass_set = set()
 
         for i in range(len(self)):
-            anns = self.get_annotations(i, type='label', scope=scope)
-            multilabel_set.update([ann['name'] for ann in anns])
+            # Collect labels by type
+            label_anns = self.get_annotations(i, type='label', scope=scope)
+            multilabel_set.update(ann.name for ann in label_anns)
 
-            anns = self.get_annotations(i, type='segmentation', scope=scope)
-            segmentation_labels.update([ann['name'] for ann in anns])
+            seg_anns = self.get_annotations(i, type='segmentation', scope=scope)
+            segmentation_labels.update(ann.name for ann in seg_anns)
 
-            anns = self.get_annotations(i, type='category', scope=scope)
-            multiclass_set.update([(ann['name'], ann['value']) for ann in anns])
+            cat_anns = self.get_annotations(i, type='category', scope=scope)
+            multiclass_set.update((ann.name, ann.value) for ann in cat_anns)
 
-        multilabel_set = sorted(list(multilabel_set))
-        multiclass_set = sorted(list(multiclass_set))
-        segmentation_labels = sorted(list(segmentation_labels))
+        # Sort and create mappings
+        multilabel_list = sorted(multilabel_set)
+        multiclass_list = sorted(multiclass_set)
+        segmentation_list = sorted(segmentation_labels)
 
-        multilabel2code = {label: idx for idx, label in enumerate(multilabel_set)}
-        segmentation_label2code = {label: idx+1 for idx, label in enumerate(segmentation_labels)}
-        multiclass2code = {label: idx for idx, label in enumerate(multiclass_set)}
+        sets = {
+            'multilabel': multilabel_list,
+            'segmentation': segmentation_list,
+            'multiclass': multiclass_list
+        }
+
+        codes_map = {
+            'multilabel': {label: idx for idx, label in enumerate(multilabel_list)},
+            'segmentation': {label: idx + 1 for idx, label in enumerate(segmentation_list)},
+            'multiclass': {label: idx for idx, label in enumerate(multiclass_list)}
+        }
 
-        sets = {'multilabel': multilabel_set,
-                'segmentation': segmentation_labels,
-                'multiclass': multiclass_set}
-        codes_map = {'multilabel': multilabel2code,
-                     'segmentation': segmentation_label2code,
-                     'multiclass': multiclass2code}
         return sets, codes_map
 
-    def get_framelabel_distribution(self, normalize=False) -> dict[str, float]:
-        """
-        Returns the distribution of labels in the dataset.
+    def get_framelabel_distribution(self, normalize: bool = False) -> dict[str, float]:
+        """Returns the distribution of frame labels in the dataset."""
+        return self._get_label_distribution('label', 'frame', normalize)
 
-        Returns:
-            Dict[str, int]: The distribution of labels in the dataset.
-        """
-        label_distribution = {label: 0 for label in self.frame_labels_set}
-        for imginfo in self.images_metainfo:
-            for ann in imginfo['annotations']:
-                if ann['type'] == 'label' and ann['index'] is not None:
-                    label_distribution[ann['name']] += 1
+    def get_segmentationlabel_distribution(self, normalize: bool = False) -> dict[str, float]:
+        """Returns the distribution of segmentation labels in the dataset."""
+        return self._get_label_distribution('segmentation', 'all', normalize)
 
-        if normalize:
-            total = sum(label_distribution.values())
-            if total == 0:
-                return label_distribution
-            label_distribution = {k: v/total for k, v in label_distribution.items()}
-        return label_distribution
+    def _get_label_distribution(self, ann_type: str, scope: str, normalize: bool) -> dict[str, float]:
+        """Helper method to calculate label distributions."""
+        if ann_type == 'label' and scope == 'frame':
+            labels = self.frame_labels_set
+        elif ann_type == 'segmentation':
+            labels = self.segmentation_labels_set
+        else:
+            raise ValueError(f"Unsupported combination: type={ann_type}, scope={scope}")
 
-    def get_segmentationlabel_distribution(self, normalize=False) -> dict[str, float]:
-        """
-        Returns the distribution of segmentation labels in the dataset.
+        distribution = {label: 0 for label in labels}
 
-        Returns:
-            Dict[str, int]: The distribution of segmentation labels in the dataset.
-        """
-        label_distribution = {label: 0 for label in self.segmentation_labels_set}
         for imginfo in self.images_metainfo:
-            if 'annotations' in imginfo and imginfo['annotations'] is not None:
-                for ann in imginfo['annotations']:
-                    if ann['type'] == 'segmentation':
-                        label_distribution[ann['name']] += 1
+            for ann in imginfo.get('annotations', []):
+                condition_met = (
+                    ann.type == ann_type and
+                    (scope == 'all' or
+                     (scope == 'frame' and ann.index is not None) or
+                     (scope == 'image' and ann.index is None))
+                )
+                if condition_met and ann.name in distribution:
+                    distribution[ann.name] += 1
 
         if normalize:
-            total = sum(label_distribution.values())
-            if total == 0:
-                return label_distribution
-            label_distribution = {k: v/total for k, v in label_distribution.items()}
-        return label_distribution
+            total = sum(distribution.values())
+            if total > 0:
+                distribution = {k: v / total for k, v in distribution.items()}
 
-    def _check_integrity(self):
+        return distribution
+
+    def _check_integrity(self) -> None:
+        """Check if all image files exist."""
+        missing_files = []
         for imginfo in self.images_metainfo:
-            if not os.path.isfile(os.path.join(self.dataset_dir, imginfo['file'])):
-                raise DatamintDatasetException(f"Image file {imginfo['file']} not found.")
+            filepath = os.path.join(self.dataset_dir, imginfo['file'])
+            if not os.path.isfile(filepath):
+                missing_files.append(imginfo['file'])
+
+        if missing_files:
+            raise DatamintDatasetException(f"Image files not found: {missing_files}")
 
     def _get_datasetinfo(self) -> dict:
+        """Get dataset information from API."""
+        if self._server_dataset_info is not None:
+            return self._server_dataset_info
         all_datasets = self.api_handler.get_datasets()
 
-        value_to_search = self.dataset_id
-        field_to_search = 'id'
-
-        for d in all_datasets:
-            if d[field_to_search] == value_to_search:
-                return d
+        for dataset in all_datasets:
+            if dataset['id'] == self.dataset_id:
+                self._server_dataset_info = dataset
+                return dataset
 
         available_datasets = [(d['name'], d['id']) for d in all_datasets]
         raise DatamintDatasetException(
-            f"Dataset with {field_to_search} '{value_to_search}' not found." +
-            f" Available datasets: {available_datasets}"
+            f"Dataset with id '{self.dataset_id}' not found. "
+            f"Available datasets: {available_datasets}"
         )
 
     def get_info(self) -> dict:
+        """Get project information from API."""
+        if hasattr(self, 'project_info') and self.project_info is not None:
+            return self.project_info
         project = self.api_handler.get_project_by_name(self.project_name)
         if 'error' in project:
             available_projects = project['all_projects']
             raise DatamintDatasetException(
-                f"Project with name '{self.project_name}' not found. Available projects: {available_projects}"
+                f"Project with name '{self.project_name}' not found. "
+                f"Available projects: {available_projects}"
            )
+        self.project_info = project
+        self.dataset_id = project['dataset_id']
         return project
 
     def _run_request(self, session, request_args) -> requests.Response:
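
The label-to-code mappings in `_get_labels_set` are unchanged in substance: multilabel and multiclass codes start at 0, while segmentation codes start at 1, presumably keeping 0 free for the background class. A standalone sketch with invented label names:

```python
# Illustration of the code maps built in _get_labels_set; the label
# names here are made up, not taken from any real project.
multilabel_list = sorted({"fracture", "implant"})
segmentation_list = sorted({"bone", "lesion"})

codes_map = {
    'multilabel': {label: idx for idx, label in enumerate(multilabel_list)},
    # Segmentation codes are shifted by one, leaving 0 unassigned.
    'segmentation': {label: idx + 1 for idx, label in enumerate(segmentation_list)},
}
print(codes_map['multilabel'])    # {'fracture': 0, 'implant': 1}
print(codes_map['segmentation'])  # {'bone': 1, 'lesion': 2}
```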
@@ -457,216 +564,191 @@ class DatamintBaseDataset:
         response.raise_for_status()
         return response
 
-    def _get_jwttoken(self, dataset_id, session) -> str:
-        if dataset_id is None:
-            raise ValueError("Dataset ID is required to download the dataset.")
-        request_params = {
-            'method': 'GET',
-            'url': f'{self.server_url}/datasets/{dataset_id}/download/png',
-            'headers': {'apikey': self.api_key},
-            'stream': True
-        }
-        _LOGGER.debug(f"Getting jwt token for dataset {dataset_id}...")
-        response = self._run_request(session, request_params)
-        progress_bar = None
-        number_processed_images = 0
-
-        # check if the response is a stream of data and everything is ok
-        if response.status_code != 200:
-            msg = f"Getting jwt token failed with status code={response.status_code}: {response.text}"
-            raise DatamintDatasetException(msg)
-
-        try:
-            response_iterator = response.iter_lines(decode_unicode=True)
-            for line in response_iterator:
-                line = line.strip()
-                if 'event: error' in line:
-                    error_msg = line+'\n'
-                    error_msg += '\n'.join(response_iterator)
-                    raise DatamintDatasetException(f"Getting jwt token failed:\n{error_msg}")
-                if not line.startswith('data:'):
-                    continue
-                dataline = yaml.safe_load(line)['data']
-                if 'zip' in dataline:
-                    _LOGGER.debug(f"Got jwt token for dataset {dataset_id}")
-                    return dataline['zip']  # Function normally ends here
-                elif 'processedImages' in dataline:
-                    if progress_bar is None:
-                        total_size = int(dataline['totalImages'])
-                        progress_bar = tqdm(total=total_size, unit='iB', unit_scale=True)
-                    processed_images = int(dataline['processedImages'])
-                    if number_processed_images < processed_images:
-                        progress_bar.update(processed_images - number_processed_images)
-                        number_processed_images = processed_images
-                else:
-                    _LOGGER.warning(f"Unknown data line: {dataline}")
-        except Exception as e:
-            raise e
-        finally:
-            if progress_bar is not None:
-                progress_bar.close()
-
-        raise DatamintDatasetException("Getting jwt token failed! No dataline with 'zip' entry found.")
-
     def __repr__(self) -> str:
-        """
-        Example:
-        .. code-block:: python
-
-            print(dataset)
-
-        Output:
-
-        .. code-block:: text
-
-            Dataset DatamintDataset
-                Number of datapoints: 3
-                Root location: /home/user/.datamint/datasets
-
-        """
+        """String representation of the dataset."""
         head = f"Dataset {self.project_name}"
         body = [f"Number of datapoints: {self.__len__()}"]
+
         if self.root is not None:
             body.append(f"Location: {self.dataset_dir}")
 
-        # Add filter information to representation
-        if self.include_annotators is not None:
-            body += [f"Including only annotators: {self.include_annotators}"]
-        if self.exclude_annotators is not None:
-            body += [f"Excluding annotators: {self.exclude_annotators}"]
-        if self.include_segmentation_names is not None:
-            body += [f"Including only segmentations: {self.include_segmentation_names}"]
-        if self.exclude_segmentation_names is not None:
-            body += [f"Excluding segmentations: {self.exclude_segmentation_names}"]
-        if self.include_image_label_names is not None:
-            body += [f"Including only image labels: {self.include_image_label_names}"]
-        if self.exclude_image_label_names is not None:
-            body += [f"Excluding image labels: {self.exclude_image_label_names}"]
-        if self.include_frame_label_names is not None:
-            body += [f"Including only frame labels: {self.include_frame_label_names}"]
-        if self.exclude_frame_label_names is not None:
-            body += [f"Excluding frame labels: {self.exclude_frame_label_names}"]
+        # Add filter information
+        filter_info = [
+            (self.include_annotators, "Including only annotators"),
+            (self.exclude_annotators, "Excluding annotators"),
+            (self.include_segmentation_names, "Including only segmentations"),
+            (self.exclude_segmentation_names, "Excluding segmentations"),
+            (self.include_image_label_names, "Including only image labels"),
+            (self.exclude_image_label_names, "Excluding image labels"),
+            (self.include_frame_label_names, "Including only frame labels"),
+            (self.exclude_frame_label_names, "Excluding frame labels"),
+        ]
+
+        for filter_value, description in filter_info:
+            if filter_value is not None:
+                body.append(f"{description}: {filter_value}")
 
         lines = [head] + [" " * 4 + line for line in body]
         return "\n".join(lines)
 
-    def download_project(self):
-        from torchvision.datasets.utils import extract_archive
+    def download_project(self) -> None:
+        """Download project data from API."""
 
         dataset_info = self._get_datasetinfo()
         self.dataset_id = dataset_info['id']
         self.last_updaded_at = dataset_info['updated_at']
 
-        self.api_handler.download_project(self.project_info['id'],
-                                          self.dataset_zippath,
-                                          all_annotations=self.all_annotations,
-                                          include_unannotated=self.include_unannotated)
-        _LOGGER.debug(f"Downloaded dataset")
-        downloaded_size = os.path.getsize(self.dataset_zippath)
-        if downloaded_size == 0:
+        self.api_handler.download_project(
+            self.project_info['id'],
+            self.dataset_zippath,
+            all_annotations=self.all_annotations,
+            include_unannotated=self.include_unannotated
+        )
+
+        _LOGGER.debug("Downloaded dataset")
+
+        if os.path.getsize(self.dataset_zippath) == 0:
             raise DatamintDatasetException("Download failed.")
 
+        self._extract_and_update_metadata()
+
+    def _get_dataset_id(self) -> str:
+        if self.dataset_id is None:
+            dataset_info = self._get_datasetinfo()
+            self.dataset_id = dataset_info['id']
+        return self.dataset_id
+
+    def _extract_and_update_metadata(self) -> None:
+        """Extract downloaded archive and update metadata."""
+        from torchvision.datasets.utils import extract_archive
+
         if os.path.exists(self.dataset_dir):
             _LOGGER.info(f"Deleting existing dataset directory: {self.dataset_dir}")
             shutil.rmtree(self.dataset_dir)
-        extract_archive(self.dataset_zippath,
-                        self.dataset_dir,
-                        remove_finished=True
-                        )
-        datasetjson = os.path.join(self.dataset_dir, 'dataset.json')
-        with open(datasetjson, 'r') as file:
+
+        extract_archive(self.dataset_zippath, self.dataset_dir, remove_finished=True)
+
+        # Load and update metadata
+        datasetjson_path = os.path.join(self.dataset_dir, 'dataset.json')
+        with open(datasetjson_path, 'r') as file:
             self.metainfo = json.load(file)
+
+        self._update_metadata_timestamps()
+
+        # Save updated metadata
+        with open(datasetjson_path, 'w') as file:
+            json.dump(self.metainfo, file, default=lambda o: o.to_dict() if hasattr(o, 'to_dict') else o)
+
+        self.images_metainfo = self.metainfo['resources']
+        # self._convert_metainfo_to_clsobj()
+
+    def _update_metadata_timestamps(self) -> None:
+        """Update metadata with correct timestamps."""
         if 'updated_at' not in self.metainfo:
             self.metainfo['updated_at'] = self.last_updaded_at
         else:
-            # if self.last_updated_at is newer than the one in the dataset, update it
             try:
-                if datetime.fromisoformat(self.metainfo['updated_at']) < datetime.fromisoformat(self.last_updaded_at):
-                    _LOGGER.warning(f"Inconsistent updated_at dates detected ({self.metainfo['updated_at']} < {self.last_updaded_at})." +
-                                    f"Fixing it to {self.last_updaded_at}")
+                local_time = datetime.fromisoformat(self.metainfo['updated_at'])
+                server_time = datetime.fromisoformat(self.last_updaded_at)
+
+                if local_time < server_time:
+                    _LOGGER.warning(
+                        f"Inconsistent updated_at dates detected "
+                        f"({self.metainfo['updated_at']} < {self.last_updaded_at}). "
+                        f"Fixing it to {self.last_updaded_at}"
+                    )
                     self.metainfo['updated_at'] = self.last_updaded_at
             except Exception as e:
                 _LOGGER.warning(f"Failed to parse updated_at date: {e}")
 
-        # Add all_annotations to the metadata
         self.metainfo['all_annotations'] = self.all_annotations
 
-        # save the updated_at date
-        with open(datasetjson, 'w') as file:
-            json.dump(self.metainfo, file)
-
-    def _load_image(self, filepath: str,
-                    index: int | None = None) -> tuple[Tensor, FileDataset | None]:
+    def _load_image(self, filepath: str, index: int | None = None) -> tuple[Tensor, FileDataset | None]:
+        """Load image from file with optional frame index."""
         if os.path.isdir(filepath):
-            raise NotImplementedError("Loading a image from a directory is not supported yet.")
+            raise NotImplementedError("Loading an image from a directory is not supported yet.")
 
         if self.return_frame_by_frame:
             img, ds = read_array_normalized(filepath, return_metainfo=True, index=index)
         else:
             img, ds = read_array_normalized(filepath, return_metainfo=True)
 
+        img = self._process_image_array(img)
+        return img, ds
+
+    def _process_image_array(self, img: np.ndarray) -> Tensor:
+        """Process numpy array to tensor with proper normalization."""
         if img.dtype == np.uint16:
             if not self.__logged_uint16_conversion:
                 _LOGGER.info("Original image is uint16, converting to uint8")
                 self.__logged_uint16_conversion = True
 
-            # min-max normalization
+            # Min-max normalization
             img = img.astype(np.float32)
-            mn = img.min()
-            img = (img - mn) / (img.max() - mn) * 255
+            min_val = img.min()
+            img = (img - min_val) / (img.max() - min_val) * 255
             img = img.astype(np.uint8)
 
-        img = torch.from_numpy(img).contiguous()
-        if isinstance(img, torch.ByteTensor):
-            img = img.to(dtype=torch.get_default_dtype()).div(255)
+        img_tensor = torch.from_numpy(img).contiguous()
 
-        return img, ds
+        if isinstance(img_tensor, torch.ByteTensor):
+            img_tensor = img_tensor.to(dtype=torch.get_default_dtype()).div(255)
+
+        return img_tensor
 
-    def _get_image_metainfo(self, index: int, bypass_subset_indices=False) -> dict[str, Any]:
+    def _get_image_metainfo(self, index: int, bypass_subset_indices: bool = False) -> dict[str, Any]:
+        """Get metadata for image at given index."""
         if not bypass_subset_indices:
             index = self.subset_indices[index]
+
         if self.return_frame_by_frame:
-            # Find the correct filepath and index
             resource_id, frame_index = self.__find_index(index)
-
-            img_metainfo = self.images_metainfo[resource_id]
-            img_metainfo = dict(img_metainfo)  # copy
-            # insert frame index
+            img_metainfo = dict(self.images_metainfo[resource_id])  # Copy
             img_metainfo['frame_index'] = frame_index
-            img_metainfo['annotations'] = [ann for ann in img_metainfo['annotations']
-                                           if ann['index'] is None or ann['index'] == frame_index]
+            img_metainfo['annotations'] = [
+                ann for ann in img_metainfo['annotations']
+                if ann.index is None or ann.index == frame_index
+            ]
         else:
             img_metainfo = self.images_metainfo[index]
+
         return img_metainfo
 
     def __find_index(self, index: int) -> tuple[int, int]:
-        """
-        Find the resource index and frame index for a given global frame index.
-
-        """
-        # Use binary search to find the resource containing this frame
+        """Find the resource index and frame index for a given global frame index."""
         resource_index = np.searchsorted(self._cumulative_frames[1:], index, side='right')
         frame_index = index - self._cumulative_frames[resource_index]
-
         return resource_index, frame_index
 
-    def __getitem_internal(self, index: int,
-                           only_load_metainfo=False) -> dict[str, Tensor | FileDataset | dict | list]:
+    def __getitem_internal(
+        self,
+        index: int,
+        only_load_metainfo: bool = False
+    ) -> dict[str, Tensor | FileDataset | dict | list]:
+        """Internal method to get item at index."""
         if self.return_frame_by_frame:
             resource_index, frame_idx = self.__find_index(index)
         else:
             resource_index = index
             frame_idx = None
+
         img_metainfo = self._get_image_metainfo(index, bypass_subset_indices=True)
 
         if only_load_metainfo:
             return {'metainfo': img_metainfo}
 
         filepath = os.path.join(self.dataset_dir, img_metainfo['file'])
-
-        # Can be multi-frame, Gray-scale and/or RGB. So the shape is really variable, but it's always a numpy array.
         img, ds = self._load_image(filepath, frame_idx)
 
+        return self._build_item_dict(img, ds, img_metainfo)
+
+    def _build_item_dict(
+        self,
+        img: Tensor,
+        ds: FileDataset | None,
+        img_metainfo: dict
+    ) -> dict[str, Any]:
+        """Build the return dictionary for __getitem__."""
         ret = {'image': img}
 
         if self.return_dicom:
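
The extracted `_process_image_array` keeps the old uint16 handling: min-max normalize to [0, 255], cast to uint8, then let the ByteTensor branch rescale to [0, 1]. A worked example of the normalization step with invented pixel values:

```python
# Standalone illustration of the uint16 -> uint8 min-max normalization;
# the pixel values are made up for the example.
import numpy as np

img = np.array([[0, 1000], [3000, 4000]], dtype=np.uint16)
img = img.astype(np.float32)
min_val = img.min()
img = (img - min_val) / (img.max() - min_val) * 255
img = img.astype(np.uint8)  # truncates toward zero
print(img)  # [[  0  63]
            #  [191 255]]
```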
@@ -678,52 +760,42 @@ class DatamintBaseDataset:
 
         return ret
 
-    def _filter_annotations(self, annotations: list[dict]) -> list[dict]:
-        """
-        Filter annotations based on the filtering settings.
-
-        Args:
-            annotations: list of annotations
-
-        Returns:
-            list[dict]: filtered list of annotations
-        """
+    def _filter_annotations(self, annotations: list[Annotation]) -> list[Annotation]:
+        """Filter annotations based on the filtering settings."""
         if annotations is None:
             return []
 
         filtered_annotations = []
         for ann in annotations:
-            # Filter by annotator
-            if not self._should_include_annotator(ann['added_by']):
+            if not self._should_include_annotation(ann):
                 continue
-
-            # Filter by annotation type and name
-            if ann['type'] == 'segmentation':
-                if not self._should_include_segmentation(ann['name']):
-                    continue
-            elif ann['type'] == 'label':
-                # Check if it's a frame or image label
-                if ann.get('index', None) is None:
-                    # Image label
-                    if not self._should_include_image_label(ann['name']):
-                        continue
-                else:
-                    # Frame label
-                    if not self._should_include_frame_label(ann['name']):
-                        continue
-
-            # If we reach here, the annotation passed all filters
             filtered_annotations.append(ann)
 
         return filtered_annotations
 
+    def _should_include_annotation(self, ann: Annotation) -> bool:
+        """Check if an annotation should be included based on all filters."""
+        if not self._should_include_annotator(ann.created_by):
+            return False
+
+        if ann.type == 'segmentation':
+            return self._should_include_segmentation(ann.name)
+        elif ann.type == 'label':
+            if ann.index is None:
+                return self._should_include_image_label(ann.name)
+            else:
+                return self._should_include_frame_label(ann.name)
+
+        return True
+
     def __getitem__(self, index: int) -> dict[str, Tensor | FileDataset | dict | list]:
-        """
+        """Get item at index.
+
         Args:
-            index (int): Index
+            index: Index
 
         Returns:
-            dict: A dictionary containing three keys: 'image', 'metainfo' and 'annotations'.
+            A dictionary containing 'image', 'metainfo' and 'annotations' keys.
         """
         if index >= len(self):
             raise IndexError(f"Index {index} out of bounds for dataset of length {len(self)}")
@@ -731,21 +803,28 @@ class DatamintBaseDataset:
         return self.__getitem_internal(self.subset_indices[index])
 
     def __iter__(self):
+        """Iterate over dataset items."""
         for index in self.subset_indices:
-            yield self.__getitem_internal(index)
+            yield self.__getitem__(index)
+            # do not use __getitem_internal__ here, so subclass only need to implement __getitem__
 
     def __len__(self) -> int:
+        """Return dataset length."""
         return len(self.subset_indices)
 
-    def _check_version(self):
-        metainfo_path = os.path.join(self.dataset_dir, 'dataset.json')
-        if not os.path.exists(metainfo_path):
-            self.download_project()
-            return
-        with open(metainfo_path, 'r') as file:
-            local_dataset_info = json.load(file)
-        local_updated_at = local_dataset_info.get('updated_at', None)
-        local_all_annotations = local_dataset_info.get('all_annotations', None)
+    def _check_version(self) -> None:
+        """Check if local dataset version is up to date."""
+        # metainfo_path = os.path.join(self.dataset_dir, 'dataset.json')
+        # if not os.path.exists(metainfo_path):
+        #     self.download_project()
+        #     return
+
+        if not hasattr(self, 'project_info'):
+            self.project_info = self.get_info()
+            self.dataset_id = self.project_info['dataset_id']
+
+        local_updated_at = self.metainfo.get('updated_at', None)
+        local_all_annotations = self.metainfo.get('all_annotations', None)
 
         try:
             external_metadata_info = self._get_datasetinfo()
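
`__iter__` now routes through `__getitem__` so that subclasses only need to override that one method. A hypothetical access sketch, assuming `dataset` is an already-constructed `DatamintBaseDataset` and relying only on the keys documented in `__getitem__`:

```python
# Hypothetical access sketch; `dataset` is assumed to exist and have
# at least one item, with return_metainfo and return_annotations True.
item = dataset[0]
image = item['image']              # torch.Tensor
metainfo = item['metainfo']        # dict of resource metadata
annotations = item['annotations']  # list of Annotation objects
print(tuple(image.shape), len(annotations))
```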
@@ -756,82 +835,279 @@ class DatamintBaseDataset:
756
835
 
757
836
  _LOGGER.debug(f"Local updated at: {local_updated_at}, Server updated at: {server_updated_at}")
758
837
 
759
- # Check if all_annotations changed or doesn't exist
760
838
  annotations_changed = local_all_annotations != self.all_annotations
839
+ version_outdated = local_updated_at is None or local_updated_at < server_updated_at
761
840
 
762
- if local_updated_at is None or local_updated_at < server_updated_at or annotations_changed:
763
- if annotations_changed:
764
- _LOGGER.info(
765
- f"The 'all_annotations' parameter has changed. Previous: {local_all_annotations}, Current: {self.all_annotations}."
766
- )
767
- else:
768
- _LOGGER.info(
769
- f"A newer version of the dataset is available. Your version: {local_updated_at}." +
770
- f" Last version: {server_updated_at}."
771
- )
772
- self.download_project()
841
+ if annotations_changed:
842
+ _LOGGER.info(
843
+ f"The 'all_annotations' parameter has changed. "
844
+ f"Previous: {local_all_annotations}, Current: {self.all_annotations}."
845
+ )
846
+ # self.download_project()
847
+ self._incremental_update()
848
+ elif version_outdated:
849
+ _LOGGER.info(
850
+ f"A newer version of the dataset is available. "
851
+ f"Your version: {local_updated_at}. Last version: {server_updated_at}."
852
+ )
853
+ self._incremental_update()
773
854
  else:
774
855
  _LOGGER.info('Local version is up to date with the latest version.')
775
856
 
857
+ def _fetch_new_resources(self,
858
+ all_uptodate_resources: list[dict]) -> list[dict]:
859
+ local_resources = self.images_metainfo
860
+ local_resources_ids = [res['id'] for res in local_resources]
861
+ new_resources = []
862
+ for resource in all_uptodate_resources:
863
+ if resource['id'] not in local_resources_ids:
864
+ resource['file'] = str(self._get_resource_file_path(resource))
865
+ resource['annotations'] = []
866
+ new_resources.append(resource)
867
+ return new_resources
868
+
869
+ def _fetch_deleted_resources(self, all_uptodate_resources: list[dict]) -> list[dict]:
870
+ local_resources = self.images_metainfo
871
+ all_uptodate_resources_ids = [res['id'] for res in all_uptodate_resources]
872
+ deleted_resources = []
873
+ for resource in local_resources:
874
+ try:
875
+ res_idx = all_uptodate_resources_ids.index(resource['id'])
876
+ if resource.get('deleted_at', None): # was deleted on server
877
+ if local_resources[res_idx].get('deleted_at_local', None) is None:
878
+ deleted_resources.append(resource)
879
+ except ValueError:
880
+ deleted_resources.append(resource)
881
+
882
+ return deleted_resources
883
+
+    def _incremental_update(self) -> None:
+        """Synchronize the local dataset with the server without a full re-download."""
+        ### RESOURCES ###
+        all_uptodate_resources = self.api_handler.get_project_resources(self.get_info()['id'])
+        new_resources = self._fetch_new_resources(all_uptodate_resources)
+        deleted_resources = self._fetch_deleted_resources(all_uptodate_resources)
+
+        if new_resources:
+            for r in new_resources:
+                self._new_resource_created(r)
+            new_resources_path = [Path(self.dataset_dir) / r['file'] for r in new_resources]
+            new_resources_ids = [r['id'] for r in new_resources]
+            _LOGGER.info(f"Downloading {len(new_resources)} new resources...")
+            self.api_handler.download_multiple_resources(new_resources_ids,
+                                                         save_path=new_resources_path)
+            _LOGGER.info(f"Downloaded {len(new_resources)} new resources.")
+
+        for r in deleted_resources:
+            self._resource_deleted(r)
+        ################
+
+        ### ANNOTATIONS ###
+        all_annotations = self.api_handler.get_annotations(worklist_id=self.project_info['worklist_id'],
+                                                           status=None if self.all_annotations else 'published')
+        # group annotations by resource ID, attaching the local filepath
+        annotations_by_resource = {}
+        for ann in all_annotations:
+            filepath = self._get_annotation_file_path(ann)
+            if filepath is not None:
+                ann['file'] = str(filepath)
+            annotations_by_resource.setdefault(ann['resource_id'], []).append(ann)
+
+        # Collect all segmentation annotations that need to be downloaded
+        segmentations_to_download = []
+        segmentation_paths = []
+
+        # update annotations in resources
+        for resource in self.images_metainfo:
+            resource_id = resource['id']
+            new_resource_annotations = annotations_by_resource.get(resource_id, [])
+            old_resource_annotations = resource.get('annotations', [])
+
+            # Determine which annotations were added or removed since the last sync
+            old_ann_ids = {ann.id for ann in old_resource_annotations if hasattr(ann, 'id')}
+            new_ann_ids = {ann['id'] for ann in new_resource_annotations}
+
+            annotations_to_add = [ann for ann in new_resource_annotations
+                                  if ann['id'] not in old_ann_ids]
+            annotations_to_remove = [ann for ann in old_resource_annotations
+                                     if getattr(ann, 'id', 'NA') not in new_ann_ids]
+
+            for ann in annotations_to_add:
+                filepath = self._get_annotation_file_path(ann)
+                if filepath is not None:  # None means it is not a segmentation
+                    # Collect for batch download
+                    filepath = Path(self.dataset_dir) / filepath
+                    filepath.parent.mkdir(parents=True, exist_ok=True)
+                    segmentations_to_download.append(ann)
+                    segmentation_paths.append(filepath)
+
+            # Delete local files of removed segmentation annotations
+            for ann in annotations_to_remove:
+                filepath = ann.file if isinstance(ann, Annotation) else ann.get('file')
+                if filepath is None:
+                    # Not a segmentation annotation
+                    continue
+                try:
+                    filepath = Path(self.dataset_dir) / filepath
+                    # delete the local annotation file if it exists
+                    if filepath.exists():
+                        os.remove(filepath)
+                except OSError as e:
+                    _LOGGER.error(f"Error deleting annotation file {filepath}: {e}")
+
+            # Update resource annotations list - convert to Annotation objects
+            resource['annotations'] = [Annotation.from_dict(ann) for ann in new_resource_annotations]
+
+        # Batch download all segmentation files
+        if segmentations_to_download:
+            _LOGGER.info(f"Downloading {len(segmentations_to_download)} segmentation files...")
+            self.api_handler.download_multiple_segmentations(segmentations_to_download, segmentation_paths)
+            _LOGGER.info(f"Downloaded {len(segmentations_to_download)} segmentation files.")
+        ###################
+
+        # update metadata and persist it
+        self.metainfo['updated_at'] = self._get_datasetinfo()['updated_at']
+        self.metainfo['all_annotations'] = self.all_annotations
+        datasetjson_path = os.path.join(self.dataset_dir, 'dataset.json')
+        with open(datasetjson_path, 'w') as file:
+            json.dump(self.metainfo, file, default=lambda o: o.to_dict() if hasattr(o, 'to_dict') else o)
+
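The `default=` hook in the final `json.dump` call is what lets the mixed metadata (plain dicts alongside `Annotation` objects) serialize in one pass: anything exposing a `to_dict()` method is converted, everything else passes through unchanged. A self-contained illustration with a hypothetical `Box` class:

    import json

    class Box:
        def __init__(self, w, h):
            self.w, self.h = w, h

        def to_dict(self):
            return {'w': self.w, 'h': self.h}

    meta = {'name': 'demo', 'roi': Box(3, 4)}
    print(json.dumps(meta, default=lambda o: o.to_dict() if hasattr(o, 'to_dict') else o))
    # {"name": "demo", "roi": {"w": 3, "h": 4}}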
+    def _get_resource_file_path(self, resource: dict) -> Path:
+        """Get the local file path for a resource."""
+        if 'file' in resource and resource['file'] is not None:
+            return Path(resource['file'])
+        else:
+            ext = guess_extension(resource['mimetype'], strict=False)
+            if ext is None:
+                _LOGGER.warning(f"Could not guess extension for resource {resource['id']}.")
+                ext = ''
+            return Path('images', f"{resource['id']}{ext}")
+
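`guess_extension` maps a MIME type back to a file suffix and can legitimately return None for types missing from the platform's MIME database, which is why the code falls back to an empty extension. For instance:

    from mimetypes import guess_extension

    print(guess_extension('image/png', strict=False))          # '.png'
    print(guess_extension('application/dicom', strict=False))  # may be None, depending on the platform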
+    def _get_annotation_file_path(self, annotation: dict | Annotation) -> Path | None:
+        """Get the local file path for an annotation, or None if it has no file."""
+        if isinstance(annotation, Annotation):
+            if annotation.file:
+                return Path(annotation.file)
+            elif annotation.type == 'segmentation':
+                return Path('masks',
+                            annotation.created_by,
+                            annotation.resource_id,
+                            annotation.id)
+        else:
+            # Handle dict format for backwards compatibility
+            if annotation.get('file'):
+                return Path(annotation['file'])
+            elif annotation.get('annotation_type', annotation.get('type')) == 'segmentation':
+                return Path('masks',
+                            annotation['created_by'],
+                            annotation['resource_id'],
+                            annotation['id'])
+        return None
+
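For segmentations with no recorded file, the path is derived deterministically from the annotation metadata, so a given mask always lands at masks/<created_by>/<resource_id>/<annotation_id>. With made-up values:

    from pathlib import Path

    print(Path('masks', 'annotator@example.com', 'res-123', 'ann-456'))
    # masks/annotator@example.com/res-123/ann-456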
+    def _new_resource_created(self, resource: dict) -> None:
+        """Handle a new resource created in the dataset."""
+        if hasattr(self, 'num_frames_per_resource'):
+            raise NotImplementedError('Cannot handle new resources after dataset initialization')
+
+        if 'annotations' not in resource:
+            resource['annotations'] = []  # Initialize as empty list for Annotation objects
+        self.images_metainfo.append(resource)
+
+    def _resource_deleted(self, resource: dict) -> None:
+        """Handle a resource deleted from the dataset."""
+        # remove from metadata
+        for i, imginfo in enumerate(self.images_metainfo):
+            if imginfo['id'] == resource['id']:
+                deleted_metainfo = self.images_metainfo.pop(i)
+                break
+        else:
+            _LOGGER.warning(f"Resource {resource['id']} not found in dataset metadata.")
+            return
+
+        # delete the resource file from disk
+        resource_path = os.path.join(self.dataset_dir, deleted_metainfo['file'])
+        if os.path.exists(resource_path):
+            os.remove(resource_path)
+
+        # delete associated annotation files
+        for ann in deleted_metainfo.get('annotations', []):
+            ann_file = ann.file if isinstance(ann, Annotation) else ann.get('file')
+            if ann_file is not None:
+                ann_path = os.path.join(self.dataset_dir, ann_file)
+                if os.path.exists(ann_path):
+                    os.remove(ann_path)
+
     def __add__(self, other):
+        """Concatenate datasets."""
         from torch.utils.data import ConcatDataset
         return ConcatDataset([self, other])
 
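Because `__add__` delegates to `torch.utils.data.ConcatDataset`, two datasets can be chained with the `+` operator. A minimal sketch, assuming two projects with these hypothetical names exist on the server:

    ds_a = DatamintBaseDataset(project_name='chest-xrays')
    ds_b = DatamintBaseDataset(project_name='knee-mris')
    combined = ds_a + ds_b  # a torch.utils.data.ConcatDataset
    assert len(combined) == len(ds_a) + len(ds_b)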
     def get_dataloader(self, *args, **kwargs) -> DataLoader:
-        """
-        Returns a DataLoader for the dataset.
-        This is a wrapper around the PyTorch DataLoader, with the convinience of using a nice collate_fn
-        that properly handles the different types of data in this dataset.
+        """Returns a DataLoader for the dataset with proper collate function.
 
         Args:
-            *args: Positional arguments for the DataLoader. See `torch.utils.data.DataLoader <https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader>`_ for details.
-            **kwargs: Keyword arguments for the DataLoader. See `torch.utils.data.DataLoader <https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader>`_ for details.
-
+            *args: Positional arguments for the DataLoader.
+            **kwargs: Keyword arguments for the DataLoader.
 
+        Returns:
+            DataLoader instance with custom collate function.
         """
-        return DataLoader(self,
-                          *args,
-                          collate_fn=self.get_collate_fn(),
-                          **kwargs)
+        return DataLoader(self, *args, collate_fn=self.get_collate_fn(), **kwargs)
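A minimal usage sketch (the project name is hypothetical; standard DataLoader keywords such as batch_size and shuffle pass straight through):

    dataset = DatamintBaseDataset(project_name='chest-xrays')
    loader = dataset.get_dataloader(batch_size=4, shuffle=True)
    batch = next(iter(loader))  # dict whose contents depend on the return_* flags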
 
     def get_collate_fn(self) -> Callable:
-        def collate_fn(batch: dict) -> dict:
+        """Get collate function for DataLoader."""
+        def collate_fn(batch: list[dict]) -> dict:
+            if not batch:
+                return {}
+
             keys = batch[0].keys()
             collated_batch = {}
+
             for key in keys:
-                collated_batch[key] = [item[key] for item in batch]
-                if isinstance(collated_batch[key][0], torch.Tensor):
-                    # check if every tensor has the same shape
-                    shapes = [tensor.shape for tensor in collated_batch[key]]
+                values = [item[key] for item in batch]
+
+                if isinstance(values[0], torch.Tensor):
+                    shapes = [tensor.shape for tensor in values]
                     if all(shape == shapes[0] for shape in shapes):
-                        collated_batch[key] = torch.stack(collated_batch[key])
+                        collated_batch[key] = torch.stack(values)
                     else:
-                        _LOGGER.warning(f"Collating {key} tensors with different shapes: {shapes}. ")
-                elif isinstance(collated_batch[key][0], np.ndarray):
-                    collated_batch[key] = np.stack(collated_batch[key])
+                        _LOGGER.warning(f"Collating {key} tensors with different shapes: {shapes}")
+                        collated_batch[key] = values
+                elif isinstance(values[0], np.ndarray):
+                    collated_batch[key] = np.stack(values)
+                else:
+                    collated_batch[key] = values
 
             return collated_batch
 
         return collate_fn
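The collate function's shape handling can be checked in isolation: equal-shaped tensors are stacked into one batch tensor, while ragged ones are left as a plain list (after a warning). Continuing with the `dataset` from the sketch above:

    import torch

    collate = dataset.get_collate_fn()
    same = [{'img': torch.zeros(3, 8, 8)}, {'img': torch.zeros(3, 8, 8)}]
    ragged = [{'img': torch.zeros(3, 8, 8)}, {'img': torch.zeros(3, 16, 16)}]
    print(collate(same)['img'].shape)    # torch.Size([2, 3, 8, 8])
    print(type(collate(ragged)['img']))  # <class 'list'>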
 
     def subset(self, indices: list[int]) -> 'DatamintBaseDataset':
-        if len(self.subset_indices) > self.dataset_length:
+        """Create a subset of the dataset.
+
+        Args:
+            indices: List of indices to include in the subset.
+
+        Returns:
+            Self with updated subset indices.
+        """
+        if max(indices, default=-1) >= self.dataset_length:
             raise ValueError(f"Subset indices must be less than the dataset length: {self.dataset_length}")
 
         self.subset_indices = indices
-
         return self
 
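`subset` filters in place and returns self, so calls can be chained; indices must stay below the current dataset length. A brief sketch, assuming `__len__` honors `subset_indices`:

    first_ten = dataset.subset(list(range(10)))
    assert first_ten is dataset  # same object, now restricted to ten items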
     def _should_include_annotator(self, annotator_id: str) -> bool:
-        """
-        Check if an annotator should be included based on the filtering settings.
-
-        Args:
-            annotator_id: The ID of the annotator to check
-
-        Returns:
-            bool: True if the annotator should be included, False otherwise
-        """
+        """Check if an annotator should be included based on filtering settings."""
         if self.include_annotators is not None:
             return annotator_id in self.include_annotators
         if self.exclude_annotators is not None:
@@ -839,15 +1115,7 @@ class DatamintBaseDataset:
         return True
 
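The `_should_include_*` predicates (this one and the three that follow) share one precedence rule: when an include list is set it is authoritative and the exclude list is never consulted; the exclude list applies only when no include list is given. Restated standalone (the exclude branch bodies are folded out of these hunks; the `not in` check is an assumption consistent with the constructor's parameter docs):

    def should_include(name, include=None, exclude=None):
        if include is not None:
            return name in include
        if exclude is not None:
            return name not in exclude
        return True

    print(should_include('a', include=['b']))                 # False
    print(should_include('a', exclude=['b']))                 # True
    print(should_include('a', include=['a'], exclude=['a']))  # True: include wins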
     def _should_include_segmentation(self, segmentation_name: str) -> bool:
-        """
-        Check if a segmentation should be included based on the filtering settings.
-
-        Args:
-            segmentation_name: The name of the segmentation to check
-
-        Returns:
-            bool: True if the segmentation should be included, False otherwise
-        """
+        """Check if a segmentation should be included based on filtering settings."""
         if self.include_segmentation_names is not None:
             return segmentation_name in self.include_segmentation_names
         if self.exclude_segmentation_names is not None:
@@ -855,15 +1123,7 @@ class DatamintBaseDataset:
         return True
 
     def _should_include_image_label(self, label_name: str) -> bool:
-        """
-        Check if an image label should be included based on the filtering settings.
-
-        Args:
-            label_name: The name of the image label to check
-
-        Returns:
-            bool: True if the image label should be included, False otherwise
-        """
+        """Check if an image label should be included based on filtering settings."""
         if self.include_image_label_names is not None:
             return label_name in self.include_image_label_names
         if self.exclude_image_label_names is not None:
@@ -871,15 +1131,7 @@ class DatamintBaseDataset:
         return True
 
     def _should_include_frame_label(self, label_name: str) -> bool:
-        """
-        Check if a frame label should be included based on the filtering settings.
-
-        Args:
-            label_name: The name of the frame label to check
-
-        Returns:
-            bool: True if the frame label should be included, False otherwise
-        """
+        """Check if a frame label should be included based on filtering settings."""
         if self.include_frame_label_names is not None:
             return label_name in self.include_frame_label_names
         if self.exclude_frame_label_names is not None: