datamint 2.3.3__py3-none-any.whl → 2.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- datamint/__init__.py +1 -3
- datamint/api/__init__.py +0 -3
- datamint/api/base_api.py +286 -54
- datamint/api/client.py +76 -13
- datamint/api/endpoints/__init__.py +2 -2
- datamint/api/endpoints/annotations_api.py +186 -28
- datamint/api/endpoints/deploy_model_api.py +78 -0
- datamint/api/endpoints/models_api.py +1 -0
- datamint/api/endpoints/projects_api.py +38 -7
- datamint/api/endpoints/resources_api.py +227 -100
- datamint/api/entity_base_api.py +66 -7
- datamint/apihandler/base_api_handler.py +0 -1
- datamint/apihandler/dto/annotation_dto.py +2 -0
- datamint/client_cmd_tools/datamint_config.py +0 -1
- datamint/client_cmd_tools/datamint_upload.py +3 -1
- datamint/configs.py +11 -7
- datamint/dataset/base_dataset.py +24 -4
- datamint/dataset/dataset.py +1 -1
- datamint/entities/__init__.py +1 -1
- datamint/entities/annotations/__init__.py +13 -0
- datamint/entities/{annotation.py → annotations/annotation.py} +81 -47
- datamint/entities/annotations/image_classification.py +12 -0
- datamint/entities/annotations/image_segmentation.py +252 -0
- datamint/entities/annotations/volume_segmentation.py +273 -0
- datamint/entities/base_entity.py +100 -6
- datamint/entities/cache_manager.py +129 -15
- datamint/entities/datasetinfo.py +60 -65
- datamint/entities/deployjob.py +18 -0
- datamint/entities/project.py +39 -0
- datamint/entities/resource.py +310 -46
- datamint/lightning/__init__.py +1 -0
- datamint/lightning/datamintdatamodule.py +103 -0
- datamint/mlflow/__init__.py +65 -0
- datamint/mlflow/artifact/__init__.py +1 -0
- datamint/mlflow/artifact/datamint_artifacts_repo.py +8 -0
- datamint/mlflow/env_utils.py +131 -0
- datamint/mlflow/env_vars.py +5 -0
- datamint/mlflow/flavors/__init__.py +17 -0
- datamint/mlflow/flavors/datamint_flavor.py +150 -0
- datamint/mlflow/flavors/model.py +877 -0
- datamint/mlflow/lightning/callbacks/__init__.py +1 -0
- datamint/mlflow/lightning/callbacks/modelcheckpoint.py +410 -0
- datamint/mlflow/models/__init__.py +93 -0
- datamint/mlflow/tracking/datamint_store.py +76 -0
- datamint/mlflow/tracking/default_experiment.py +27 -0
- datamint/mlflow/tracking/fluent.py +91 -0
- datamint/utils/env.py +27 -0
- datamint/utils/visualization.py +21 -13
- datamint-2.9.0.dist-info/METADATA +220 -0
- datamint-2.9.0.dist-info/RECORD +73 -0
- {datamint-2.3.3.dist-info → datamint-2.9.0.dist-info}/WHEEL +1 -1
- datamint-2.9.0.dist-info/entry_points.txt +18 -0
- datamint/apihandler/exp_api_handler.py +0 -204
- datamint/experiment/__init__.py +0 -1
- datamint/experiment/_patcher.py +0 -570
- datamint/experiment/experiment.py +0 -1049
- datamint-2.3.3.dist-info/METADATA +0 -125
- datamint-2.3.3.dist-info/RECORD +0 -54
- datamint-2.3.3.dist-info/entry_points.txt +0 -4
|
@@ -0,0 +1,252 @@
|
|
|
1
|
+
"""Image segmentation annotation entity module for DataMint API.
|
|
2
|
+
|
|
3
|
+
This module defines the ImageSegmentation class for representing 2D segmentation
|
|
4
|
+
annotations in medical images.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from .annotation import Annotation
|
|
8
|
+
from datamint.api.dto import AnnotationType
|
|
9
|
+
import numpy as np
|
|
10
|
+
from PIL import Image
|
|
11
|
+
from pydantic import PrivateAttr
|
|
12
|
+
import logging
|
|
13
|
+
|
|
14
|
+
_LOGGER = logging.getLogger(__name__)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class ImageSegmentation(Annotation):
    """
    Image-level (2D) segmentation annotation entity.

    Represents a 2D segmentation mask for a single 2D image.
    Supports both binary segmentation (single class) and multi-class
    semantic segmentation.

    This class provides factory methods to create annotations from numpy
    arrays or PIL Images, which can then be uploaded via AnnotationsApi.

    Example:
        >>> # From binary mask
        >>> mask = np.zeros((256, 256), dtype=np.uint8)
        >>> mask[100:150, 100:150] = 1  # lesion region
        >>> img_seg = ImageSegmentation.from_mask(
        ...     mask=mask,
        ...     name='lesion'
        ... )
        >>>
        >>> # Upload via API
        >>> api.annotations.upload_segmentations(
        ...     resource='resource_id',
        ...     file_path=img_seg.mask,
        ...     name=img_seg.name
        ... )
    """

    # Mask exactly as given to the constructor. from_mask() always stores a
    # validated 2D numpy array; direct construction may store a PIL Image.
    _mask: np.ndarray | Image.Image | None = PrivateAttr(default=None)
    _class_name: str | None = PrivateAttr(default=None)

    def __init__(self,
                 name: str | None = None,
                 mask: np.ndarray | Image.Image | None = None,
                 **kwargs):
        """
        Initialize an ImageSegmentation annotation.

        The annotation is always created with ``scope='image'`` and
        ``annotation_type=AnnotationType.SEGMENTATION``; ``name`` (or ``""``)
        is used as the annotation identifier. The mask is stored as given,
        without validation; use :meth:`from_mask` to validate it.

        Args:
            name: The name/label for this segmentation class
            mask: Optional 2D numpy array or PIL Image containing the segmentation mask
            **kwargs: Additional fields passed to parent Annotation class
        """
        super().__init__(
            identifier=name or "",
            scope='image',
            annotation_type=AnnotationType.SEGMENTATION,
            **kwargs
        )

        self._mask = mask
        self._class_name = name

    @classmethod
    def from_mask(cls,
                  mask: np.ndarray | Image.Image,
                  name: str,
                  **kwargs) -> 'ImageSegmentation':
        """
        Create ImageSegmentation from a binary or class mask.

        Args:
            mask: 2D numpy array (H x W) with integer, boolean or
                integer-valued float labels, or a PIL Image that converts to
                such an array
            name: The name/label for this segmentation
            **kwargs: Additional annotation fields (imported_from, model_id, etc.)

        Returns:
            ImageSegmentation instance ready for upload

        Raises:
            ValueError: If mask shape is invalid or data types are incorrect

        Example:
            >>> mask = np.zeros((512, 512), dtype=np.uint8)
            >>> mask[200:300, 200:300] = 255  # binary mask
            >>> img_seg = ImageSegmentation.from_mask(
            ...     mask=mask,
            ...     name='tumor',
            ... )
        """
        # Convert PIL Image to numpy if needed
        mask_array = np.array(mask) if isinstance(mask, Image.Image) else mask
        mask_array = cls._validate_mask_array(mask_array)
        return cls(name=name, mask=mask_array, **kwargs)

    @staticmethod
    def _validate_mask_array(arr: np.ndarray) -> np.ndarray:
        """
        Validate mask array shape and dtype.

        Boolean masks are converted to 0/1 ``uint8``. Integer-valued float
        masks are converted to ``uint8`` when every label fits, and otherwise
        to ``int32``, so that labels are never silently wrapped. Integer masks
        are returned unchanged.

        Args:
            arr: Input array to validate

        Returns:
            Validated array (possibly with dtype conversion)

        Raises:
            ValueError: If array is invalid (not a numpy array, not 2D,
                non-integer/non-finite floats, unsupported dtype, negative
                values, or float labels beyond the int32 range)
        """
        if not isinstance(arr, np.ndarray):
            raise ValueError(f"Expected numpy array, got {type(arr)}")

        # Check dimensionality - should be 2D (H x W)
        if arr.ndim != 2:
            raise ValueError(
                f"Mask must be 2D (H x W), got shape {arr.shape}"
            )

        # Boolean masks (e.g. PIL mode '1' images or ``x > t`` results) are
        # plain binary masks: store them as 0/1 uint8.
        if arr.dtype == np.bool_:
            return arr.astype(np.uint8)

        if np.issubdtype(arr.dtype, np.floating):
            rounded = np.rint(arr)
            # rint leaves NaN/inf unchanged, so reject non-finite values explicitly.
            if not (np.all(np.isfinite(arr)) and np.allclose(arr, rounded)):
                raise ValueError(
                    "Mask array contains non-integer float values"
                )
            # Check for negatives BEFORE casting: a cast to an unsigned dtype
            # would wrap -1.0 to 255 and hide it from the check.
            if np.any(rounded < 0):
                raise ValueError("Mask array contains negative values")
            max_label = rounded.max() if rounded.size else 0
            if max_label <= np.iinfo(np.uint8).max:
                return rounded.astype(np.uint8)
            if max_label <= np.iinfo(np.int32).max:
                # Labels above 255 would wrap around in uint8; keep them exact.
                return rounded.astype(np.int32)
            raise ValueError("Mask array values exceed the int32 range")

        if not np.issubdtype(arr.dtype, np.integer):
            raise ValueError(
                f"Mask must have integer dtype, got {arr.dtype}"
            )

        # Check for negative values
        if np.any(arr < 0):
            raise ValueError("Mask array contains negative values")

        return arr

    def _mask_as_array(self) -> np.ndarray | None:
        """Return the stored mask as a numpy array (converting a PIL Image), or None."""
        if self._mask is None:
            return None
        if isinstance(self._mask, Image.Image):
            return np.array(self._mask)
        return self._mask

    @property
    def mask(self) -> np.ndarray | Image.Image | None:
        """
        Get the stored segmentation mask.

        Returns:
            The mask as stored: a 2D numpy array when created via
            :meth:`from_mask`, possibly a PIL Image when passed directly to
            the constructor, or None if not stored
        """
        return self._mask

    @property
    def mask_shape(self) -> tuple[int, int] | None:
        """
        Get the shape of the stored mask.

        Returns:
            Shape tuple (H, W) or None if no mask stored
        """
        if self._mask is None:
            return None

        if isinstance(self._mask, Image.Image):
            return (self._mask.height, self._mask.width)

        return self._mask.shape

    @property
    def class_name(self) -> str | None:
        """
        Get the class name for this segmentation.

        Returns:
            Class name string or None
        """
        return self._class_name

    @property
    def name(self) -> str | None:
        """
        Alias for class_name.

        Returns:
            Class name string or None
        """
        return self._class_name

    def to_pil_image(self) -> Image.Image | None:
        """
        Convert the mask to a PIL Image.

        A stored PIL Image is returned as-is. Pillow cannot build images from
        64-bit integer arrays (numpy's default integer dtype), so such masks
        are first narrowed to ``uint8`` (labels 0..255) or ``int32`` when their
        values fit.

        Returns:
            PIL Image or None if no mask stored
        """
        if self._mask is None:
            return None

        if isinstance(self._mask, Image.Image):
            return self._mask

        arr = self._mask
        if np.issubdtype(arr.dtype, np.integer) and arr.dtype.itemsize > 4:
            min_val = int(arr.min()) if arr.size else 0
            max_val = int(arr.max()) if arr.size else 0
            int32_info = np.iinfo(np.int32)
            if 0 <= min_val and max_val <= np.iinfo(np.uint8).max:
                arr = arr.astype(np.uint8)
            elif int32_info.min <= min_val and max_val <= int32_info.max:
                arr = arr.astype(np.int32)
            # Otherwise leave as-is and let Pillow report the unsupported dtype.

        return Image.fromarray(arr)

    def get_binary_mask(self, threshold: int = 0) -> np.ndarray | None:
        """
        Get a binary version of the mask.

        Args:
            threshold: Values strictly above this threshold are set to 1

        Returns:
            Binary uint8 numpy array (0s and 1s) or None if no mask stored
        """
        mask_array = self._mask_as_array()
        if mask_array is None:
            return None
        return (mask_array > threshold).astype(np.uint8)

    def get_area(self) -> int | None:
        """
        Get the area (number of positive pixels) of the mask.

        Returns:
            Number of non-zero pixels or None if no mask stored
        """
        mask_array = self._mask_as_array()
        if mask_array is None:
            return None
        return int(np.count_nonzero(mask_array))
|
|
@@ -0,0 +1,273 @@
|
|
|
1
|
+
"""Volume segmentation annotation entity module for DataMint API.
|
|
2
|
+
|
|
3
|
+
This module defines the VolumeSegmentation class for representing 3D segmentation
|
|
4
|
+
annotations in medical imaging volumes.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from .annotation import Annotation
|
|
8
|
+
from datamint.api.dto import AnnotationType
|
|
9
|
+
import numpy as np
|
|
10
|
+
from nibabel.nifti1 import Nifti1Image
|
|
11
|
+
from pydantic import PrivateAttr
|
|
12
|
+
import logging
|
|
13
|
+
|
|
14
|
+
_LOGGER = logging.getLogger(__name__)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class VolumeSegmentation(Annotation):
    """
    Volume-level segmentation annotation entity.

    Represents a 3D segmentation mask for medical imaging volumes.
    Supports both semantic segmentation (class per voxel) and instance
    segmentation (unique ID per object).

    This class provides factory methods to create annotations from numpy
    arrays or NIfTI images, which can then be uploaded via AnnotationsApi.

    Example:
        >>> # From semantic segmentation
        >>> seg_data = np.array([...])  # Shape: (H, W, D)
        >>> class_map = {1: 'tumor', 2: 'edema'}
        >>> vol_seg = VolumeSegmentation.from_semantic_segmentation(
        ...     segmentation=seg_data,
        ...     class_map=class_map
        ... )
        >>>
        >>> # Upload via API
        >>> api.annotations.upload_segmentations(
        ...     resource='resource_id',
        ...     file_path=vol_seg.segmentation_data,
        ...     name=vol_seg.class_map
        ... )
    """

    raw_data: bytes | None = None

    # Populated by from_semantic_segmentation(). They default to None so the
    # accessor properties can report "not stored" instead of raising
    # AttributeError for an unset private attribute.
    _segmentation_data: np.ndarray | Nifti1Image | None = PrivateAttr(default=None)
    _class_map: dict[int, str] | None = PrivateAttr(default=None)

    def __init__(self,
                 **kwargs):
        """
        Initialize a VolumeSegmentation annotation.

        ``scope`` is always set to ``'image'`` and ``annotation_type`` to
        ``AnnotationType.SEGMENTATION``, overriding any values in ``kwargs``.
        The identifier is always ``""``.

        Args:
            **kwargs: Additional fields passed to parent Annotation class
        """
        kwargs['scope'] = 'image'
        kwargs['annotation_type'] = AnnotationType.SEGMENTATION

        super().__init__(
            identifier="",
            **kwargs
        )

    @classmethod
    def from_semantic_segmentation(cls,
                                   segmentation: np.ndarray | Nifti1Image,
                                   class_map: dict[int, str] | str,
                                   **kwargs) -> 'VolumeSegmentation':
        """
        Create VolumeSegmentation from semantic segmentation data.

        Semantic segmentation: each voxel has a single integer label
        corresponding to its class.

        Args:
            segmentation: 3D numpy array (H x W x D) or Nifti1Image with
                integer labels representing classes. Boolean and
                integer-valued float data are accepted and converted.
            class_map: Mapping from label integers to class names, or a
                single class name for binary segmentation (the volume must
                then contain exactly one non-zero label)
            **kwargs: Additional annotation fields (imported_from, model_id, etc.)

        Returns:
            VolumeSegmentation instance ready for upload. A NIfTI input is
            stored as the original image; a numpy input is stored in its
            validated (possibly dtype-converted) form.

        Raises:
            ValueError: If segmentation shape is invalid, class_map is malformed,
                or data types are incorrect

        Example:
            >>> seg = np.zeros((256, 256, 128), dtype=np.int32)
            >>> seg[100:150, 100:150, 50:75] = 1  # tumor region
            >>> vol_seg = VolumeSegmentation.from_semantic_segmentation(
            ...     segmentation=seg,
            ...     class_map={1: 'tumor'},  # or just ``class_map='tumor'``
            ... )
        """
        is_nifti = isinstance(segmentation, Nifti1Image)

        # Step 1: get the voxel values. For NIfTI, validate the raw float data;
        # casting to int first would silently truncate non-integer labels.
        seg_array = segmentation.get_fdata() if is_nifti else segmentation

        # Step 2: Validate segmentation array
        seg_array = cls._validate_segmentation_array(seg_array)

        # Step 3: Standardize class_map to dict[int, str]
        standardized_class_map = cls._standardize_class_map(class_map, seg_array)

        instance = cls(**kwargs)

        # Keep NIfTI images as given (they carry affine/header information);
        # keep numpy input in its validated form.
        instance._segmentation_data = segmentation if is_nifti else seg_array
        instance._class_map = standardized_class_map

        return instance

    @staticmethod
    def _validate_segmentation_array(arr: np.ndarray) -> np.ndarray:
        """
        Validate segmentation array shape and dtype.

        Boolean volumes and integer-valued float volumes are converted to
        ``int32``. Integer volumes are returned unchanged.

        Args:
            arr: Input array to validate

        Returns:
            Validated array (possibly with dtype conversion)

        Raises:
            ValueError: If array is invalid (not a numpy array, not 3D,
                non-integer/non-finite floats, unsupported dtype, values
                outside int32 for floats, or negative values)
        """
        if not isinstance(arr, np.ndarray):
            raise ValueError(f"Expected numpy array, got {type(arr)}")

        # Check dimensionality
        if arr.ndim != 3:
            raise ValueError(
                f"Segmentation must be 3D (H x W x D), got shape {arr.shape}"
            )

        if arr.dtype == np.bool_:
            # A boolean volume is a binary segmentation: False -> 0, True -> 1.
            return arr.astype(np.int32)

        if np.issubdtype(arr.dtype, np.floating):
            rounded = np.rint(arr)
            # rint leaves NaN/inf unchanged, so reject non-finite values explicitly.
            if not (np.all(np.isfinite(arr)) and np.allclose(arr, rounded)):
                raise ValueError(
                    "Segmentation array contains non-integer float values"
                )
            int32_info = np.iinfo(np.int32)
            if rounded.size and (rounded.min() < int32_info.min
                                 or rounded.max() > int32_info.max):
                raise ValueError("Segmentation array values exceed the int32 range")
            arr = rounded.astype(np.int32)
        elif not np.issubdtype(arr.dtype, np.integer):
            raise ValueError(
                f"Segmentation must have integer dtype, got {arr.dtype}"
            )

        # Check for negative values
        if np.any(arr < 0):
            raise ValueError("Segmentation array contains negative values")

        return arr

    @staticmethod
    def _standardize_class_map(
        class_map: dict[int, str] | str,
        segmentation: np.ndarray
    ) -> dict[int, str]:
        """
        Convert class_map to standard dict[int, str] format.

        Args:
            class_map: Either a dict or a single class name for binary seg.
                A single name is mapped to the one non-zero label present in
                ``segmentation``.
            segmentation: The segmentation array to infer labels from

        Returns:
            Standardized dictionary mapping labels to class names

        Raises:
            ValueError: If class_map format is invalid, or a single name is
                given but the segmentation does not contain exactly one
                non-zero label
        """
        if isinstance(class_map, str):
            # Binary segmentation: the single non-zero label gets this name, 0 = background
            unique_labels = np.unique(segmentation)
            unique_labels = unique_labels[unique_labels > 0]  # Exclude 0

            if len(unique_labels) != 1:
                raise ValueError(
                    f"Single class name provided but segmentation has "
                    f"{len(unique_labels)} non-zero labels: {unique_labels.tolist()}"
                )

            return {int(unique_labels[0]): class_map}

        if isinstance(class_map, dict):
            # Validate all keys are integers, all values are strings
            standardized = {}
            for k, v in class_map.items():
                if not isinstance(k, (int, np.integer)):
                    raise ValueError(f"class_map key must be integer, got {type(k)}")
                if not isinstance(v, str):
                    raise ValueError(f"class_map value must be string, got {type(v)}")
                standardized[int(k)] = v

            return standardized

        raise ValueError(
            f"class_map must be dict[int, str] or str, got {type(class_map)}"
        )

    @property
    def volume_shape(self) -> tuple[int, int, int] | None:
        """
        Get the shape of the stored segmentation volume.

        Returns:
            Shape tuple (H, W, D) or None if no data stored. For a NIfTI
            image, only the first three dimensions are returned.
        """
        if self._segmentation_data is None:
            return None

        if isinstance(self._segmentation_data, Nifti1Image):
            shape = self._segmentation_data.shape
            return (shape[0], shape[1], shape[2])

        return self._segmentation_data.shape

    @property
    def class_names(self) -> list[str] | None:
        """
        Get list of class names from stored class_map.

        Returns:
            Alphabetically sorted list of class names or None if no class_map stored
        """
        if self._class_map is None:
            return None
        return sorted(self._class_map.values())

    @property
    def num_classes(self) -> int | None:
        """
        Get number of classes in this segmentation.

        Returns:
            Number of classes or None if no class_map stored
        """
        if self._class_map is None:
            return None
        return len(self._class_map)

    @property
    def class_map(self) -> dict[int, str] | None:
        """
        Get the stored class map.

        Returns:
            Dictionary mapping labels to class names, or None
        """
        return self._class_map

    @property
    def segmentation_data(self) -> np.ndarray | Nifti1Image | None:
        """
        Get the stored segmentation data.

        Returns:
            Segmentation array/image or None if not stored
        """
        return self._segmentation_data
|
|
273
|
+
|
datamint/entities/base_entity.py
CHANGED
|
@@ -31,13 +31,24 @@ class BaseEntity(BaseModel):
|
|
|
31
31
|
are created through API endpoints.
|
|
32
32
|
"""
|
|
33
33
|
|
|
34
|
-
model_config = ConfigDict(extra='allow',
|
|
34
|
+
model_config = ConfigDict(extra='allow',
|
|
35
|
+
arbitrary_types_allowed=True, # Allow extra fields and arbitrary types
|
|
36
|
+
ser_json_bytes='base64',
|
|
37
|
+
val_json_bytes='base64')
|
|
35
38
|
|
|
36
39
|
_api: 'EntityBaseApi[Self] | EntityBaseApi' = PrivateAttr()
|
|
37
40
|
|
|
41
|
+
def __init__(self, **data):
    """Initialize the entity and drop fields left at the MISSING_FIELD sentinel.

    After pydantic validation, every declared field whose value is the
    ``MISSING_FIELD`` sentinel is deleted from the instance.

    Args:
        **data: Field values forwarded to the pydantic model constructor.
    """
    super().__init__(**data)
    # check attributes for MISSING_FIELD and delete them
    sentinel_type = type(MISSING_FIELD)
    for field_name in self.__pydantic_fields__.keys():
        if not hasattr(self, field_name):
            continue
        value = getattr(self, field_name)
        # Compare by identity, and by equality only for values of the
        # sentinel's own type: with arbitrary_types_allowed a field may hold
        # e.g. a numpy array, whose element-wise ``==`` cannot be used in an ``if``.
        if value is MISSING_FIELD or (type(value) is sentinel_type and value == MISSING_FIELD):
            delattr(self, field_name)
|
|
47
|
+
|
|
38
48
|
def asdict(self) -> dict[str, Any]:
    """Convert the entity to a dictionary, including unknown fields.

    Serialization warnings from ``model_dump`` are suppressed, and entries
    whose value is the ``MISSING_FIELD`` sentinel are omitted.

    Returns:
        Dictionary of field names to dumped values.
    """
    d = self.model_dump(warnings='none')
    sentinel_type = type(MISSING_FIELD)

    def _is_missing(value: Any) -> bool:
        # Equality is only evaluated for values of the sentinel's own type:
        # dumped arbitrary types (e.g. numpy arrays) may define element-wise
        # ``==``/``!=``, which cannot be used as a filter condition.
        return value is MISSING_FIELD or (type(value) is sentinel_type and value == MISSING_FIELD)

    return {k: v for k, v in d.items() if not _is_missing(v)}
|
|
41
52
|
|
|
42
53
|
def asjson(self) -> str:
|
|
43
54
|
"""Convert the entity to a JSON string, including unknown fields."""
|
|
@@ -59,10 +70,13 @@ class BaseEntity(BaseModel):
|
|
|
59
70
|
if have_to_log:
|
|
60
71
|
_LOGGER.warning(f"Unknown fields {list(self.__pydantic_extra__.keys())} found in {class_name}")
|
|
61
72
|
|
|
62
|
-
|
|
63
|
-
def is_attr_missing(value: Any) -> bool:
|
|
73
|
+
def is_attr_missing(self, attr_name: str) -> bool:
    """Check whether a declared field has no value on this entity.

    A field counts as missing when it is not set on the instance at all, or
    when it still holds the ``MISSING_FIELD`` sentinel.

    Args:
        attr_name: Name of a declared pydantic field of this entity.

    Returns:
        True if the field is missing, False otherwise.

    Raises:
        AttributeError: If ``attr_name`` is not a declared field of this entity type.
    """
    if attr_name not in self.__pydantic_fields__.keys():
        raise AttributeError(f"Attribute '{attr_name}' not found in entity of type '{self.__class__.__name__}'")
    if not hasattr(self, attr_name):
        return True
    value = getattr(self, attr_name)
    # deprecated: sentinel stored as a value. Compare by identity, and by
    # equality only for values of the sentinel's own type, because fields of
    # arbitrary types (e.g. numpy arrays) may implement element-wise ``==``.
    return value is MISSING_FIELD or (type(value) is type(MISSING_FIELD) and value == MISSING_FIELD)
|
|
66
80
|
|
|
67
81
|
def _refresh(self) -> Self:
|
|
68
82
|
"""Refresh the entity data from the server.
|
|
@@ -88,5 +102,85 @@ class BaseEntity(BaseModel):
|
|
|
88
102
|
Args:
|
|
89
103
|
attr_name: Name of the attribute to check and ensure
|
|
90
104
|
"""
|
|
91
|
-
if self.
|
|
105
|
+
if attr_name not in self.__pydantic_fields__.keys():
|
|
106
|
+
raise AttributeError(f"Attribute '{attr_name}' not found in entity of type '{self.__class__.__name__}'")
|
|
107
|
+
|
|
108
|
+
if self.is_attr_missing(attr_name):
|
|
92
109
|
self._refresh()
|
|
110
|
+
|
|
111
|
+
def has_missing_attrs(self) -> bool:
    """Report whether any declared field of this entity is missing.

    Fields are checked in declaration order with ``is_attr_missing``, and the
    scan stops at the first missing one.

    Returns:
        True if at least one declared field is missing, False otherwise.
    """
    for field in self.__pydantic_fields__:
        if self.is_attr_missing(field):
            return True
    return False
|
|
118
|
+
|
|
119
|
+
def _fetch_and_cache_file_data(
    self,
    cache_manager: 'Any',  # CacheManager[bytes]
    data_key: str,
    version_info: dict[str, Any],
    download_callback: 'Any',  # Callable[[str | None], bytes]
    save_path: str | None = None,
    use_cache: bool = False,
) -> bytes:
    """Return this entity's file bytes, consulting and populating the cache.

    This shared strategy is used by both Resource and Annotation entities:

    * ``use_cache`` is False: always call ``download_callback(save_path)``.
    * ``use_cache`` is True and the cache has the data: return it, also
      writing it to ``save_path`` when one is given.
    * ``use_cache`` is True, cache miss, ``save_path`` given: download
      straight into ``save_path`` and register that location with the cache.
    * ``use_cache`` is True, cache miss, no ``save_path``: download into the
      cache's expected path, then record the result via ``cache_manager.set``.

    Args:
        cache_manager: The CacheManager instance to use
        data_key: Key identifying the type of data (e.g., 'image_data', 'annotation_data')
        version_info: Version information for cache validation
        download_callback: Function to call to download the file, takes save_path as parameter
        save_path: Optional path to save the file locally
        use_cache: If True, uses cached data when available

    Returns:
        File data as bytes
    """
    from pathlib import Path

    data = cache_manager.get(self.id, data_key, version_info) if use_cache else None

    if data is not None:
        # Cache hit.
        _LOGGER.debug(f"Using cached data for {self.__class__.__name__} {self.id}")
        if save_path:
            _LOGGER.debug(f"Saving cached data to specified path: {save_path}")
            Path(save_path).parent.mkdir(parents=True, exist_ok=True)
            with open(save_path, 'wb') as f:
                f.write(data)
        return data

    if not use_cache:
        # Caching disabled: the callback handles save_path (which may be None).
        _LOGGER.debug(f"Fetching data from server for {self.__class__.__name__} {self.id}")
        return download_callback(save_path)

    if save_path:
        # Download directly to save_path and point the cache metadata at it,
        # so the file is not duplicated into the cache directory.
        _LOGGER.debug(f"Downloading to save_path: {save_path}")
        Path(save_path).parent.mkdir(parents=True, exist_ok=True)
        data = download_callback(save_path)
        cache_manager.register_file_location(
            self.id, data_key, save_path, version_info
        )
        return data

    # No save_path: download into the cache directory and record it.
    cache_path = cache_manager.get_expected_path(self.id, data_key)
    _LOGGER.debug(f"Downloading to cache: {cache_path}")
    data = download_callback(str(cache_path))
    cache_manager.set(self.id, data_key, data, version_info)
    return data
|