datamint 2.3.1__py3-none-any.whl → 2.3.3__py3-none-any.whl

This diff shows the changes between publicly released versions of the package, as they appear in the supported public registries. It is provided for informational purposes only.
@@ -1,7 +1,6 @@
 from typing import Any, TypeVar, Generic, Type, Sequence
 import logging
 import httpx
-from dataclasses import dataclass
 from datamint.entities.base_entity import BaseEntity
 from datamint.exceptions import DatamintException, ResourceNotFoundError
 import aiohttp
@@ -37,9 +36,14 @@ class EntityBaseApi(BaseApi, Generic[T]):
             client: Optional HTTP client instance. If None, a new one will be created.
         """
         super().__init__(config, client)
-        self.entity_class = entity_class
+        self.__entity_class = entity_class
         self.endpoint_base = endpoint_base.strip('/')

+    def _init_entity_obj(self, **kwargs) -> T:
+        obj = self.__entity_class(**kwargs)
+        obj._api = self
+        return obj
+
     @staticmethod
     def _entid(entity: BaseEntity | str) -> str:
         return entity if isinstance(entity, str) else entity.id
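Note on the hunk above: entity construction now funnels through `_init_entity_obj`, which injects the owning API client into the entity's private `_api` attribute, so objects returned by the API can call back to the server (see `BaseEntity._refresh` further down). A minimal sketch of the effect, assuming a concrete endpoint whose constructor follows the `(config, entity_class, endpoint_base, client)` pattern implied by this hunk; `ProjectsApi` and `Project` are illustrative names, not part of this diff:

    # Hypothetical endpoint built on EntityBaseApi (names are assumptions).
    class ProjectsApi(EntityBaseApi[Project]):
        def __init__(self, config, client=None):
            super().__init__(config, Project, 'projects', client)

    projects_api = ProjectsApi(config)
    project = projects_api.get_by_id('some-project-id')  # built via _init_entity_obj
    assert project._api is projects_api                   # API client injected automatically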
@@ -117,7 +121,7 @@ class EntityBaseApi(BaseApi, Generic[T]):
         for resp, items in items_gen:
             all_items.extend(items)

-        return [self.entity_class(**item) for item in all_items]
+        return [self._init_entity_obj(**item) for item in all_items]

    def get_all(self, limit: int | None = None) -> Sequence[T]:
        """Get all entities with optional pagination and filtering.
@@ -143,7 +147,7 @@ class EntityBaseApi(BaseApi, Generic[T]):
            httpx.HTTPStatusError: If the entity is not found or request fails.
        """
        response = self._make_entity_request('GET', entity_id)
-       return self.entity_class(**response.json())
+       return self._init_entity_obj(**response.json())

    async def _create_async(self, entity_data: dict[str, Any]) -> str | Sequence[str | dict]:
        """Create a new entity.
@@ -177,42 +181,6 @@ class EntityBaseApi(BaseApi, Generic[T]):
                                         add_path=child_entity_name)
        return response

-   # def bulk_create(self, entities_data: list[dict[str, Any]]) -> list[T]:
-   #     """Create multiple entities in a single request.
-
-   #     Args:
-   #         entities_data: List of dictionaries containing entity data
-
-   #     Returns:
-   #         List of created entity instances
-
-   #     Raises:
-   #         httpx.HTTPStatusError: If bulk creation fails
-   #     """
-   #     payload = {'items': entities_data} # Common bulk API format
-   #     response = self._make_request('POST', f'/{self.endpoint_base}/bulk', json=payload)
-   #     data = response.json()
-
-   #     # Handle response format - may be direct list or wrapped
-   #     items = data if isinstance(data, list) else data.get('items', [])
-   #     return [self.entity_class(**item) for item in items]
-
-   # def count(self, **params: Any) -> int:
-   #     """Get the total count of entities matching the given filters.
-
-   #     Args:
-   #         **params: Query parameters for filtering
-
-   #     Returns:
-   #         Total count of matching entities
-
-   #     Raises:
-   #         httpx.HTTPStatusError: If the request fails
-   #     """
-   #     response = self._make_request('GET', f'/{self.endpoint_base}/count', params=params)
-   #     data = response.json()
-   #     return data.get('count', 0) if isinstance(data, dict) else data
-

class DeletableEntityApi(EntityBaseApi[T]):
    """Extension of EntityBaseApi for entities that support soft deletion.
@@ -221,7 +189,7 @@ class DeletableEntityApi(EntityBaseApi[T]):
    retrieval and restoration of such entities.
    """

-   def delete(self, entity: str | BaseEntity) -> None:
+   def delete(self, entity: str | T) -> None:
        """Delete an entity.

        Args:
@@ -232,7 +200,7 @@ class DeletableEntityApi(EntityBaseApi[T]):
        """
        self._make_entity_request('DELETE', entity)

-   def bulk_delete(self, entities: Sequence[str | BaseEntity]) -> None:
+   def bulk_delete(self, entities: Sequence[str | T]) -> None:
        """Delete multiple entities.

        Args:
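Narrowing these signatures from `BaseEntity` to the type variable `T` lets a type checker verify that a typed endpoint only receives its own entity type or a plain id string. Illustrative only, reusing the hypothetical `ProjectsApi` sketch from above:

    projects_api.delete('some-project-id')                    # OK: id string
    projects_api.delete(project)                              # OK: Project instance (T == Project)
    projects_api.bulk_delete([project, 'another-project-id']) # OK: mix of entities and ids
    # projects_api.delete(some_annotation)                    # would be flagged by a type checker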
@@ -264,7 +232,7 @@ class DeletableEntityApi(EntityBaseApi[T]):
            httpx.HTTPStatusError: If deletion fails or entity not found
        """
        async with self._make_entity_request_async('DELETE', entity,
-                                                   session=session) as resp:
+                                                    session=session) as resp:
            await resp.text() # Consume response to complete request

    # def get_deleted(self, **kwargs) -> Sequence[T]:
@@ -17,7 +17,11 @@ Classes:
 import json
 from typing import Any, TypeAlias, Literal
 import logging
-from enum import Enum
+import sys
+if sys.version_info >= (3, 11):
+    from enum import StrEnum
+else:
+    from backports.strenum import StrEnum
 from medimgkit.dicom_utils import pixel_to_patient
 import pydicom
 import numpy as np
@@ -31,7 +35,7 @@ CoordinateSystem: TypeAlias = Literal['pixel', 'patient']
 """


-class AnnotationType(Enum):
+class AnnotationType(StrEnum):
    SEGMENTATION = 'segmentation'
    AREA = 'area'
    DISTANCE = 'distance'
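Because `StrEnum` members are also plain strings, code that compares `annotation_type` against string literals keeps working after this change while still getting enum validation. A quick sketch of the behavioral difference:

    # With StrEnum (new), members compare and format as their string values:
    assert AnnotationType.SEGMENTATION == 'segmentation'
    assert f'type={AnnotationType.AREA}' == 'type=area'
    # With the previous plain Enum, both assertions would fail, since Enum
    # members are not equal to their .value strings.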
datamint/configs.py CHANGED
@@ -18,6 +18,12 @@ _LOGGER = logging.getLogger(__name__)

DIRS = PlatformDirs(appname='datamintapi')
CONFIG_FILE = os.path.join(DIRS.user_config_dir, 'datamintapi.yaml')
+try:
+    DATAMINT_DATA_DIR = os.path.join(os.path.expanduser("~"), '.datamint')
+except Exception as e:
+    _LOGGER.error(f"Could not determine home directory: {e}")
+    DATAMINT_DATA_DIR = None
+


def get_env_var_name(key: str) -> str:
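`DATAMINT_DATA_DIR` centralizes the local data directory that was previously hard-coded as `.datamint` in the dataset class (see the dataset hunks below). Callers should handle the `None` fallback when the home directory cannot be resolved; a small usage sketch:

    import os
    import datamint.configs

    if datamint.configs.DATAMINT_DATA_DIR is not None:
        datasets_root = os.path.join(datamint.configs.DATAMINT_DATA_DIR, 'datasets')
    else:
        # Home directory could not be determined; choose an explicit root instead.
        datasets_root = os.path.abspath('./datamint_data/datasets')
    os.makedirs(datasets_root, exist_ok=True)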
@@ -11,7 +11,6 @@ from torch.utils.data import DataLoader
 import torch
 from torch import Tensor
 from datamint.exceptions import DatamintException
-from medimgkit.dicom_utils import is_dicom
 from medimgkit.readers import read_array_normalized
 from medimgkit.format_detection import guess_extension, guess_typez
 from medimgkit.nifti_utils import NIFTI_MIMES, get_nifti_shape
@@ -20,6 +19,7 @@ from pathlib import Path
 from datamint.entities import Annotation, DatasetInfo
 import cv2
 from datamint.entities import Resource
+import datamint.configs

_LOGGER = logging.getLogger(__name__)

@@ -55,7 +55,7 @@ class DatamintBaseDataset:
        exclude_frame_label_names: List of frame label names to exclude. If None, no frame labels will be excluded.
    """

-   DATAMINT_DEFAULT_DIR = ".datamint"
+
    DATAMINT_DATASETS_DIR = "datasets"

    def __init__(
@@ -184,8 +184,7 @@ class DatamintBaseDataset:
        """Setup root and dataset directories."""
        if root is None:
            root = os.path.join(
-               os.path.expanduser("~"),
-               self.DATAMINT_DEFAULT_DIR,
+               datamint.configs.DATAMINT_DATA_DIR,
                self.DATAMINT_DATASETS_DIR
            )
        os.makedirs(root, exist_ok=True)
@@ -306,6 +305,15 @@
        """Setup label sets and mappings."""
        self.frame_lsets, self.frame_lcodes = self._get_labels_set(framed=True)
        self.image_lsets, self.image_lcodes = self._get_labels_set(framed=False)
+       worklist_id = self.get_info()['worklist_id']
+       groups: dict[str, dict] = self.api.annotationsets.get_segmentation_group(worklist_id)['groups']
+       # order by 'index' key
+       max_index = max([g['index'] for g in groups.values()])
+       self.seglabel_list : list[str] = ['UNKNOWN'] * max_index # 1-based
+       for segname, g in groups.items():
+           self.seglabel_list[g['index'] - 1] = segname
+
+       self.seglabel2code = {label: idx + 1 for idx, label in enumerate(self.seglabel_list)}

    def _filter_items(self, images_metainfo: list[dict]) -> list[dict]:
        """Filter items that have annotations."""
@@ -357,9 +365,7 @@
    @property
    def segmentation_labels_set(self) -> list[str]:
        """Returns the set of segmentation labels in the dataset."""
-       a = set(self.frame_lsets['segmentation'])
-       b = set(self.image_lsets['segmentation'])
-       return list(a.union(b))
+       return self.seglabel_list

    def _get_annotations_internal(
        self,
@@ -463,8 +469,8 @@
            label_anns = self.get_annotations(i, type='label', scope=scope)
            multilabel_set.update(ann.name for ann in label_anns)

-           seg_anns = self.get_annotations(i, type='segmentation', scope=scope)
-           segmentation_labels.update(ann.name for ann in seg_anns)
+           # seg_anns = self.get_annotations(i, type='segmentation', scope=scope)
+           # segmentation_labels.update(ann.name for ann in seg_anns)

            cat_anns = self.get_annotations(i, type='category', scope=scope)
            multiclass_set.update((ann.name, ann.value) for ann in cat_anns)
@@ -472,17 +478,17 @@
        # Sort and create mappings
        multilabel_list = sorted(multilabel_set)
        multiclass_list = sorted(multiclass_set)
-       segmentation_list = sorted(segmentation_labels)
+       # segmentation_list = sorted(segmentation_labels)

        sets = {
            'multilabel': multilabel_list,
-           'segmentation': segmentation_list,
+           # 'segmentation': segmentation_list,
            'multiclass': multiclass_list
        }

        codes_map = {
            'multilabel': {label: idx for idx, label in enumerate(multilabel_list)},
-           'segmentation': {label: idx + 1 for idx, label in enumerate(segmentation_list)},
+           # 'segmentation': {label: idx + 1 for idx, label in enumerate(segmentation_list)},
            'multiclass': {label: idx for idx, label in enumerate(multiclass_list)}
        }

@@ -191,7 +191,7 @@ class DatamintDataset(DatamintBaseDataset):
                author_labels = seg_labels[author]

                if frame_index is not None and ann.scope == 'frame':
-                   seg_code = self.frame_lcodes['segmentation'][ann.name]
+                   seg_code = self.seglabel2code[ann.name]
                    if author_segs[frame_index] is None:
                        author_segs[frame_index] = []
                        author_labels[frame_index] = []
@@ -199,7 +199,7 @@
                    author_segs[frame_index].append(s)
                    author_labels[frame_index].append(seg_code)
                elif frame_index is None and ann.scope == 'image':
-                   seg_code = self.image_lcodes['segmentation'][ann.name]
+                   seg_code = self.seglabel2code[ann.name]
                    # apply to all frames
                    for i in range(nframes):
                        if author_segs[i] is None:
@@ -7,14 +7,16 @@ from .project import Project
 from .resource import Resource
 from .user import User # new export
 from .datasetinfo import DatasetInfo
+from .cache_manager import CacheManager

__all__ = [
    'Annotation',
    'BaseEntity',
+   'CacheManager',
    'Channel',
    'ChannelResourceData',
+   'DatasetInfo',
    'Project',
    'Resource',
-   "User",
-   'DatasetInfo',
+   'User',
]
@@ -5,11 +5,20 @@ This module defines the Annotation model used to represent annotation
 records returned by the DataMint API.
 """

-from typing import Any
+from typing import TYPE_CHECKING, Any
 import logging
+import os
+
 from .base_entity import BaseEntity, MISSING_FIELD
-from pydantic import Field
+from .cache_manager import CacheManager
+from pydantic import PrivateAttr
 from datetime import datetime
+from datamint.api.dto import AnnotationType
+from datamint.types import ImagingData
+
+if TYPE_CHECKING:
+    from datamint.api.endpoints.annotations_api import AnnotationsApi
+    from .resource import Resource

logger = logging.getLogger(__name__)

@@ -21,6 +30,8 @@ _FIELD_MAPPING = {
    'index': 'frame_index',
}

+_ANNOTATION_CACHE_KEY = "annotation_data"
+

class Annotation(BaseEntity):
    """Pydantic Model representing a DataMint annotation.
@@ -60,7 +71,7 @@ class Annotation(BaseEntity):
    identifier: str
    scope: str
    frame_index: int | None
-   annotation_type: str
+   annotation_type: AnnotationType
    text_value: str | None
    numeric_value: float | int | None
    units: str | None
@@ -83,7 +94,66 @@
    annotation_worklist_name: str | None
    user_info: dict | None
    values: list | None = MISSING_FIELD
-   file: str | None = None # Add file field for segmentations
+   file: str | None = None
+
+   _api: 'AnnotationsApi' = PrivateAttr()
+
+   def __init__(self, **data):
+       """Initialize the annotation entity."""
+       super().__init__(**data)
+       self._cache: CacheManager = CacheManager('annotations')
+       self._resource: 'Resource | None' = None
+
+   @property
+   def resource(self) -> 'Resource':
+       """Lazily load and cache the associated Resource entity."""
+       if self._resource is None:
+           self._resource = self._api._get_resource(self)
+       return self._resource
+
+   def fetch_file_data(
+       self,
+       save_path: os.PathLike | str | None = None,
+       auto_convert: bool = True,
+       use_cache: bool = False,
+   ) -> bytes | ImagingData:
+       # Version info for cache validation
+       version_info = self._generate_version_info()
+
+       # Try to get from cache
+       img_data = None
+       if use_cache:
+           img_data = self._cache.get(self.id, _ANNOTATION_CACHE_KEY, version_info)
+
+       if img_data is None:
+           # Fetch from server using download_resource_file
+           logger.debug(f"Fetching image data from server for resource {self.id}")
+           img_data = self._api.download_file(
+               self,
+               fpath_out=save_path
+           )
+           # Cache the data
+           if use_cache:
+               self._cache.set(self.id, _ANNOTATION_CACHE_KEY, img_data, version_info)
+
+       if auto_convert:
+           return self._api.convert_format(img_data)
+
+       return img_data
+
+   def _generate_version_info(self) -> dict:
+       """Helper to generate version info for caching."""
+       return {
+           'created_at': self.created_at,
+           'deleted_at': self.deleted_at,
+           'associated_file': self.associated_file,
+       }
+
+   def invalidate_cache(self) -> None:
+       """Invalidate all cached data for this annotation."""
+       self._cache.invalidate(self.id)
+       self._resource = None
+       logger.debug(f"Invalidated cache for annotation {self.id}")

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> 'Annotation':
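`fetch_file_data` gives annotations a lazy download-and-cache path: cached bytes are reused only while the version fields (`created_at`, `deleted_at`, `associated_file`) are unchanged, and `invalidate_cache` drops both the cached file and the memoized resource. A minimal usage sketch, assuming the annotation was obtained through the API (so `_api` is populated) and that the endpoint is exposed as `api.annotations` (name assumed, not shown in this diff):

    annotation = api.annotations.get_by_id('some-annotation-id')   # endpoint name assumed

    raw_bytes = annotation.fetch_file_data(auto_convert=False, use_cache=True)
    image = annotation.fetch_file_data(use_cache=True)   # auto_convert=True returns ImagingData

    annotation.invalidate_cache()        # clear cached bytes and memoized resource
    resource = annotation.resource       # lazily re-fetched on next access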
@@ -1,7 +1,11 @@
 import logging
 import sys
-from typing import Any
-from pydantic import ConfigDict, BaseModel
+from typing import Any, TYPE_CHECKING
+from pydantic import ConfigDict, BaseModel, PrivateAttr
+
+if TYPE_CHECKING:
+    from datamint.api.client import Api
+    from datamint.api.entity_base_api import EntityBaseApi

 if sys.version_info >= (3, 11):
     from typing import Self
@@ -22,9 +26,14 @@ class BaseEntity(BaseModel):
    This class provides common functionality for all entities, such as
    serialization and deserialization from dictionaries, as well as
    handling unknown fields gracefully.
+
+   The API client is automatically injected by the Api class when entities
+   are created through API endpoints.
    """

-   model_config = ConfigDict(extra='allow') # Allow extra fields not defined in the model
+   model_config = ConfigDict(extra='allow', arbitrary_types_allowed=True) # Allow extra fields and arbitrary types
+
+   _api: 'EntityBaseApi[Self] | EntityBaseApi' = PrivateAttr()

    def asdict(self) -> dict[str, Any]:
        """Convert the entity to a dictionary, including unknown fields."""
@@ -38,14 +47,46 @@ class BaseEntity(BaseModel):
        """Handle unknown fields by logging a warning once per class/field combination in debug mode."""
        if self.__pydantic_extra__ and _LOGGER.isEnabledFor(logging.DEBUG):
            class_name = self.__class__.__name__
-
+
            have_to_log = False
            for key in self.__pydantic_extra__.keys():
                warning_key = (class_name, key)
-
+
                if warning_key not in _LOGGED_WARNINGS:
                    _LOGGED_WARNINGS.add(warning_key)
                    have_to_log = True
-
+
            if have_to_log:
                _LOGGER.warning(f"Unknown fields {list(self.__pydantic_extra__.keys())} found in {class_name}")
+
+   @staticmethod
+   def is_attr_missing(value: Any) -> bool:
+       """Check if a value is the MISSING_FIELD sentinel."""
+       return value == MISSING_FIELD
+
+   def _refresh(self) -> Self:
+       """Refresh the entity data from the server.
+
+       This method fetches the latest data from the server and updates
+       the current instance with any missing or updated fields.
+
+       Returns:
+           The updated Entity instance (self)
+       """
+       updated_ent = self._api.get_by_id(self._api._entid(self))
+
+       # Update all fields from the fresh data
+       for field_name, field_value in updated_ent.model_dump().items():
+           if field_value != MISSING_FIELD:
+               setattr(self, field_name, field_value)
+
+       return self
+
+   def _ensure_attr(self, attr_name: str) -> None:
+       """Ensure that a given attribute is not MISSING_FIELD, refreshing if necessary.
+
+       Args:
+           attr_name: Name of the attribute to check and ensure
+       """
+       if self.is_attr_missing(getattr(self, attr_name)):
+           self._refresh()
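`_ensure_attr` together with `_refresh` is the building block for lazily completed fields: if a field still holds the `MISSING_FIELD` sentinel (typical for listing endpoints that omit heavy fields), the entity re-fetches itself once through its injected `_api` before the value is used. A sketch of how a subclass could rely on it; the wrapper property is illustrative, only the `values` field and the helpers come from this diff:

    class AnnotationWithLazyValues(Annotation):
        @property
        def values_loaded(self) -> list | None:
            # Refreshes from the server only if 'values' was omitted by the
            # endpoint that produced this instance.
            self._ensure_attr('values')
            return self.values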