datamint 2.3.3__py3-none-any.whl → 2.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59):
  1. datamint/__init__.py +1 -3
  2. datamint/api/__init__.py +0 -3
  3. datamint/api/base_api.py +286 -54
  4. datamint/api/client.py +76 -13
  5. datamint/api/endpoints/__init__.py +2 -2
  6. datamint/api/endpoints/annotations_api.py +186 -28
  7. datamint/api/endpoints/deploy_model_api.py +78 -0
  8. datamint/api/endpoints/models_api.py +1 -0
  9. datamint/api/endpoints/projects_api.py +38 -7
  10. datamint/api/endpoints/resources_api.py +227 -100
  11. datamint/api/entity_base_api.py +66 -7
  12. datamint/apihandler/base_api_handler.py +0 -1
  13. datamint/apihandler/dto/annotation_dto.py +2 -0
  14. datamint/client_cmd_tools/datamint_config.py +0 -1
  15. datamint/client_cmd_tools/datamint_upload.py +3 -1
  16. datamint/configs.py +11 -7
  17. datamint/dataset/base_dataset.py +24 -4
  18. datamint/dataset/dataset.py +1 -1
  19. datamint/entities/__init__.py +1 -1
  20. datamint/entities/annotations/__init__.py +13 -0
  21. datamint/entities/{annotation.py → annotations/annotation.py} +81 -47
  22. datamint/entities/annotations/image_classification.py +12 -0
  23. datamint/entities/annotations/image_segmentation.py +252 -0
  24. datamint/entities/annotations/volume_segmentation.py +273 -0
  25. datamint/entities/base_entity.py +100 -6
  26. datamint/entities/cache_manager.py +129 -15
  27. datamint/entities/datasetinfo.py +60 -65
  28. datamint/entities/deployjob.py +18 -0
  29. datamint/entities/project.py +39 -0
  30. datamint/entities/resource.py +310 -46
  31. datamint/lightning/__init__.py +1 -0
  32. datamint/lightning/datamintdatamodule.py +103 -0
  33. datamint/mlflow/__init__.py +65 -0
  34. datamint/mlflow/artifact/__init__.py +1 -0
  35. datamint/mlflow/artifact/datamint_artifacts_repo.py +8 -0
  36. datamint/mlflow/env_utils.py +131 -0
  37. datamint/mlflow/env_vars.py +5 -0
  38. datamint/mlflow/flavors/__init__.py +17 -0
  39. datamint/mlflow/flavors/datamint_flavor.py +150 -0
  40. datamint/mlflow/flavors/model.py +877 -0
  41. datamint/mlflow/lightning/callbacks/__init__.py +1 -0
  42. datamint/mlflow/lightning/callbacks/modelcheckpoint.py +410 -0
  43. datamint/mlflow/models/__init__.py +93 -0
  44. datamint/mlflow/tracking/datamint_store.py +76 -0
  45. datamint/mlflow/tracking/default_experiment.py +27 -0
  46. datamint/mlflow/tracking/fluent.py +91 -0
  47. datamint/utils/env.py +27 -0
  48. datamint/utils/visualization.py +21 -13
  49. datamint-2.9.0.dist-info/METADATA +220 -0
  50. datamint-2.9.0.dist-info/RECORD +73 -0
  51. {datamint-2.3.3.dist-info → datamint-2.9.0.dist-info}/WHEEL +1 -1
  52. datamint-2.9.0.dist-info/entry_points.txt +18 -0
  53. datamint/apihandler/exp_api_handler.py +0 -204
  54. datamint/experiment/__init__.py +0 -1
  55. datamint/experiment/_patcher.py +0 -570
  56. datamint/experiment/experiment.py +0 -1049
  57. datamint-2.3.3.dist-info/METADATA +0 -125
  58. datamint-2.3.3.dist-info/RECORD +0 -54
  59. datamint-2.3.3.dist-info/entry_points.txt +0 -4
@@ -0,0 +1,252 @@
1
+ """Image segmentation annotation entity module for DataMint API.
2
+
3
+ This module defines the ImageSegmentation class for representing 2D segmentation
4
+ annotations in medical images.
5
+ """
6
+
7
+ from .annotation import Annotation
8
+ from datamint.api.dto import AnnotationType
9
+ import numpy as np
10
+ from PIL import Image
11
+ from pydantic import PrivateAttr
12
+ import logging
13
+
14
+ _LOGGER = logging.getLogger(__name__)
15
+
16
+
17
class ImageSegmentation(Annotation):
    """
    Image-level (2D) segmentation annotation entity.

    Represents a 2D segmentation mask for a single 2D image.
    Supports both binary segmentation (single class) and multi-class
    semantic segmentation.

    This class provides factory methods to create annotations from numpy
    arrays or PIL Images, which can then be uploaded via AnnotationsApi.

    Example:
        >>> # From binary mask
        >>> mask = np.zeros((256, 256), dtype=np.uint8)
        >>> mask[100:150, 100:150] = 1  # lesion region
        >>> img_seg = ImageSegmentation.from_mask(
        ...     mask=mask,
        ...     name='lesion'
        ... )
        >>>
        >>> # Upload via API
        >>> api.annotations.upload_segmentations(
        ...     resource='resource_id',
        ...     file_path=img_seg.mask,
        ...     name=img_seg.name
        ... )
    """

    # Declared as pydantic private attributes: the in-memory mask and the
    # class label given at construction time.
    _mask: np.ndarray | Image.Image | None = PrivateAttr(default=None)
    _class_name: str | None = PrivateAttr(default=None)

    def __init__(self,
                 name: str | None = None,
                 mask: np.ndarray | Image.Image | None = None,
                 **kwargs):
        """
        Initialize an ImageSegmentation annotation.

        The mask is stored as given, without validation; use
        :meth:`from_mask` to get a validated instance.

        Args:
            name: The name/label for this segmentation class
            mask: Optional 2D numpy array or PIL Image containing the segmentation mask
            **kwargs: Additional fields passed to parent Annotation class
        """
        super().__init__(
            identifier=name or "",
            scope='image',
            annotation_type=AnnotationType.SEGMENTATION,
            **kwargs
        )

        self._mask = mask
        self._class_name = name

    @classmethod
    def from_mask(cls,
                  mask: np.ndarray | Image.Image,
                  name: str,
                  **kwargs) -> 'ImageSegmentation':
        """
        Create ImageSegmentation from a binary or class mask.

        Args:
            mask: 2D numpy array (H x W) with integer labels or binary
                (0/1 or boolean) values, or a PIL Image
            name: The name/label for this segmentation
            **kwargs: Additional annotation fields (imported_from, model_id, etc.)

        Returns:
            ImageSegmentation instance ready for upload. The stored mask is
            always a numpy array (PIL input is converted).

        Raises:
            ValueError: If mask shape is invalid or data types are incorrect

        Example:
            >>> mask = np.zeros((512, 512), dtype=np.uint8)
            >>> mask[200:300, 200:300] = 255  # binary mask
            >>> img_seg = ImageSegmentation.from_mask(
            ...     mask=mask,
            ...     name='tumor',
            ... )
        """
        # Convert PIL Image to numpy if needed
        mask_array = np.array(mask) if isinstance(mask, Image.Image) else mask

        # Validate (and, for bool/float input, convert) the mask array
        mask_array = cls._validate_mask_array(mask_array)

        return cls(name=name, mask=mask_array, **kwargs)

    @staticmethod
    def _validate_mask_array(arr: np.ndarray) -> np.ndarray:
        """
        Validate mask array shape and dtype.

        Integer arrays are returned unchanged. Boolean arrays are converted
        to ``uint8``. Float arrays holding whole numbers are converted to
        ``uint8`` when every value fits, otherwise to ``int32``.

        Args:
            arr: Input array to validate

        Returns:
            Validated array (possibly with dtype conversion)

        Raises:
            ValueError: If ``arr`` is not a 2D numpy array, has a
                non-numeric dtype, holds non-integer / non-finite floats,
                or contains negative values
        """
        if not isinstance(arr, np.ndarray):
            raise ValueError(f"Expected numpy array, got {type(arr)}")

        # Check dimensionality - should be 2D (H x W)
        if arr.ndim != 2:
            raise ValueError(
                f"Mask must be 2D (H x W), got shape {arr.shape}"
            )

        # Boolean masks (e.g. PIL mode "1" images) are plain binary masks.
        if arr.dtype == np.bool_:
            return arr.astype(np.uint8)

        if np.issubdtype(arr.dtype, np.floating):
            rounded = np.rint(arr)
            if not (np.all(np.isfinite(arr)) and np.allclose(arr, rounded)):
                raise ValueError(
                    "Mask array contains non-integer float values"
                )
            # The sign check must precede the cast: casting a negative
            # float to an unsigned dtype would wrap and hide the error.
            if np.any(rounded < 0):
                raise ValueError("Mask array contains negative values")
            # Keep uint8 when it is lossless; widen instead of wrapping
            # labels above 255.
            target = np.uint8 if rounded.max(initial=0) <= 255 else np.int32
            return rounded.astype(target)

        if not np.issubdtype(arr.dtype, np.integer):
            raise ValueError(
                f"Mask must have integer dtype, got {arr.dtype}"
            )

        # Check for negative values
        if np.any(arr < 0):
            raise ValueError("Mask array contains negative values")

        return arr

    def _mask_as_array(self) -> np.ndarray | None:
        """Return the stored mask as a numpy array, or None if no mask is stored."""
        if self._mask is None:
            return None
        if isinstance(self._mask, Image.Image):
            return np.array(self._mask)
        return self._mask

    @property
    def mask(self) -> np.ndarray | Image.Image | None:
        """
        Get the stored segmentation mask.

        Returns:
            The mask exactly as stored (numpy array or PIL Image), or None
            if not stored
        """
        return self._mask

    @property
    def mask_shape(self) -> tuple[int, int] | None:
        """
        Get the shape of the stored mask.

        Returns:
            Shape tuple (H, W) or None if no mask stored
        """
        if self._mask is None:
            return None

        if isinstance(self._mask, Image.Image):
            return (self._mask.height, self._mask.width)

        return self._mask.shape

    @property
    def class_name(self) -> str | None:
        """
        Get the class name for this segmentation.

        Returns:
            Class name string or None
        """
        return self._class_name

    @property
    def name(self) -> str | None:
        """
        Alias for class_name.

        Returns:
            Class name string or None
        """
        return self._class_name

    def to_pil_image(self) -> Image.Image | None:
        """
        Convert the mask to a PIL Image.

        Returns:
            PIL Image or None if no mask stored
        """
        if self._mask is None:
            return None

        if isinstance(self._mask, Image.Image):
            return self._mask

        try:
            return Image.fromarray(self._mask)
        except TypeError:
            # Pillow has no image mode for some dtypes (e.g. 64-bit
            # integers); retry with a dtype it can represent.
            fallback = np.uint8 if self._mask.max(initial=0) <= 255 else np.int32
            return Image.fromarray(self._mask.astype(fallback))

    def get_binary_mask(self, threshold: int = 0) -> np.ndarray | None:
        """
        Get a binary version of the mask.

        Args:
            threshold: Values above this threshold are set to 1

        Returns:
            Binary numpy array (0s and 1s) or None if no mask stored
        """
        mask_array = self._mask_as_array()
        if mask_array is None:
            return None

        return (mask_array > threshold).astype(np.uint8)

    def get_area(self) -> int | None:
        """
        Get the area (number of positive pixels) of the mask.

        Returns:
            Number of non-zero pixels or None if no mask stored
        """
        mask_array = self._mask_as_array()
        if mask_array is None:
            return None

        return int(np.count_nonzero(mask_array))
@@ -0,0 +1,273 @@
1
+ """Volume segmentation annotation entity module for DataMint API.
2
+
3
+ This module defines the VolumeSegmentation class for representing 3D segmentation
4
+ annotations in medical imaging volumes.
5
+ """
6
+
7
+ from .annotation import Annotation
8
+ from datamint.api.dto import AnnotationType
9
+ import numpy as np
10
+ from nibabel.nifti1 import Nifti1Image
11
+ from pydantic import PrivateAttr
12
+ import logging
13
+
14
+ _LOGGER = logging.getLogger(__name__)
15
+
16
+
17
class VolumeSegmentation(Annotation):
    """
    Volume-level segmentation annotation entity.

    Represents a 3D segmentation mask for medical imaging volumes.
    Supports both semantic segmentation (class per voxel) and instance
    segmentation (unique ID per object).

    This class provides factory methods to create annotations from numpy
    arrays or NIfTI images, which can then be uploaded via AnnotationsApi.

    Example:
        >>> # From semantic segmentation
        >>> seg_data = np.array([...])  # Shape: (H, W, D)
        >>> class_map = {1: 'tumor', 2: 'edema'}
        >>> vol_seg = VolumeSegmentation.from_semantic_segmentation(
        ...     segmentation=seg_data,
        ...     class_map=class_map
        ... )
        >>>
        >>> # Upload via API
        >>> api.annotations.upload_segmentations(
        ...     resource='resource_id',
        ...     file_path=vol_seg.segmentation_data,
        ...     name=vol_seg.class_map
        ... )
    """

    # Optional raw payload; never assigned inside this class.
    raw_data: bytes | None = None

    # Default to None so the accessors below can report "nothing stored"
    # on instances that were not built through a factory method.
    _segmentation_data: np.ndarray | Nifti1Image | None = PrivateAttr(default=None)
    _class_map: dict[int, str] | None = PrivateAttr(default=None)

    def __init__(self,
                 **kwargs):
        """
        Initialize a VolumeSegmentation annotation.

        ``scope`` and ``annotation_type`` are always overridden with
        ``'image'`` and ``AnnotationType.SEGMENTATION``.

        Args:
            **kwargs: Additional fields passed to parent Annotation class
        """
        kwargs['scope'] = 'image'
        kwargs['annotation_type'] = AnnotationType.SEGMENTATION

        super().__init__(
            identifier="",
            **kwargs
        )

    @classmethod
    def from_semantic_segmentation(cls,
                                   segmentation: np.ndarray | Nifti1Image,
                                   class_map: dict[int, str] | str,
                                   **kwargs) -> 'VolumeSegmentation':
        """
        Create VolumeSegmentation from semantic segmentation data.

        Semantic segmentation: each voxel has a single integer label
        corresponding to its class.

        Args:
            segmentation: 3D numpy array (H x W x D) or Nifti1Image with
                integer labels representing classes
            class_map: Mapping from label integers to class names, or a
                single class name for binary segmentation (background=0, class=1)
            **kwargs: Additional annotation fields (imported_from, model_id, etc.)

        Returns:
            VolumeSegmentation instance ready for upload. A Nifti1Image is
            stored as given; a numpy array is stored in its validated
            (possibly dtype-converted) form.

        Raises:
            ValueError: If segmentation shape is invalid, class_map has an
                invalid format, or data types are incorrect

        Example:
            >>> seg = np.zeros((256, 256, 128), dtype=np.int32)
            >>> seg[100:150, 100:150, 50:75] = 1  # tumor region
            >>> vol_seg = VolumeSegmentation.from_semantic_segmentation(
            ...     segmentation=seg,
            ...     class_map={1: 'tumor'},  # or just ``class_map='tumor'``
            ... )
        """
        is_nifti = isinstance(segmentation, Nifti1Image)

        # Step 1: Get a numpy view of the voxel data. The float data is
        # handed to the validator untouched so that fractional values are
        # rejected rather than silently truncated.
        seg_array = segmentation.get_fdata() if is_nifti else segmentation

        # Step 2: Validate segmentation array
        seg_array = cls._validate_segmentation_array(seg_array)

        # Step 3: Standardize class_map to dict[int, str]
        standardized_class_map = cls._standardize_class_map(class_map, seg_array)

        instance = cls(**kwargs)

        # Keep the NIfTI object itself so the image is stored as provided;
        # for arrays keep the validated version so dtype fixes are not lost.
        instance._segmentation_data = segmentation if is_nifti else seg_array
        instance._class_map = standardized_class_map

        return instance

    @staticmethod
    def _validate_segmentation_array(arr: np.ndarray) -> np.ndarray:
        """
        Validate segmentation array shape and dtype.

        Integer arrays are returned unchanged. Boolean arrays are converted
        to ``uint8`` and whole-number float arrays to ``int32``.

        Args:
            arr: Input array to validate

        Returns:
            Validated array (possibly with dtype conversion)

        Raises:
            ValueError: If ``arr`` is not a 3D numpy array, has a
                non-numeric dtype, holds non-integer / non-finite floats,
                or contains negative values
        """
        if not isinstance(arr, np.ndarray):
            raise ValueError(f"Expected numpy array, got {type(arr)}")

        # Check dimensionality
        if arr.ndim != 3:
            raise ValueError(
                f"Segmentation must be 3D (H x W x D), got shape {arr.shape}"
            )

        # Check dtype
        if arr.dtype == np.bool_:
            # A boolean volume is a binary segmentation (False=0, True=1).
            arr = arr.astype(np.uint8)
        elif np.issubdtype(arr.dtype, np.floating):
            # Accept floats only when they are effectively integers
            rounded = np.rint(arr)
            if not (np.all(np.isfinite(arr)) and np.allclose(arr, rounded)):
                raise ValueError(
                    "Segmentation array contains non-integer float values"
                )
            arr = rounded.astype(np.int32)
        elif not np.issubdtype(arr.dtype, np.integer):
            raise ValueError(
                f"Segmentation must have integer dtype, got {arr.dtype}"
            )

        # Check for negative values
        if np.any(arr < 0):
            raise ValueError("Segmentation array contains negative values")

        return arr

    @staticmethod
    def _standardize_class_map(
        class_map: dict[int, str] | str,
        segmentation: np.ndarray
    ) -> dict[int, str]:
        """
        Convert class_map to standard dict[int, str] format.

        Args:
            class_map: Either a dict or a single class name for binary seg
            segmentation: The segmentation array to infer labels from
                (only consulted when ``class_map`` is a string)

        Returns:
            Standardized dictionary mapping labels to class names

        Raises:
            ValueError: If class_map format is invalid, or a single class
                name is given but the segmentation does not contain exactly
                one non-zero label
        """
        if isinstance(class_map, str):
            # Binary segmentation: the single non-zero label found in the
            # data is mapped to the given class name; 0 is background.
            unique_labels = np.unique(segmentation)
            unique_labels = unique_labels[unique_labels > 0]  # Exclude 0

            if len(unique_labels) != 1:
                raise ValueError(
                    f"Single class name provided but segmentation has "
                    f"{len(unique_labels)} non-zero labels: {unique_labels.tolist()}"
                )

            return {int(unique_labels[0]): class_map}

        if isinstance(class_map, dict):
            # Validate all keys are integers, all values are strings
            standardized = {}
            for k, v in class_map.items():
                if not isinstance(k, (int, np.integer)):
                    raise ValueError(f"class_map key must be integer, got {type(k)}")
                if not isinstance(v, str):
                    raise ValueError(f"class_map value must be string, got {type(v)}")
                standardized[int(k)] = v

            return standardized

        raise ValueError(
            f"class_map must be dict[int, str] or str, got {type(class_map)}"
        )

    @property
    def volume_shape(self) -> tuple[int, int, int] | None:
        """
        Get the shape of the stored segmentation volume.

        Returns:
            Shape tuple (H, W, D) or None if no data stored
        """
        if self._segmentation_data is None:
            return None

        if isinstance(self._segmentation_data, Nifti1Image):
            shape = self._segmentation_data.shape
            return (shape[0], shape[1], shape[2])

        return self._segmentation_data.shape

    @property
    def class_names(self) -> list[str] | None:
        """
        Get list of class names from stored class_map.

        Returns:
            Alphabetically sorted list of class names or None if no
            class_map stored
        """
        if self._class_map is None:
            return None
        return sorted(self._class_map.values())

    @property
    def num_classes(self) -> int | None:
        """
        Get number of classes in this segmentation.

        Returns:
            Number of classes or None if no class_map stored
        """
        if self._class_map is None:
            return None
        return len(self._class_map)

    @property
    def class_map(self) -> dict[int, str] | None:
        """
        Get the stored class map.

        Returns:
            Dictionary mapping labels to class names, or None
        """
        return self._class_map

    @property
    def segmentation_data(self) -> np.ndarray | Nifti1Image | None:
        """
        Get the stored segmentation data.

        Returns:
            Segmentation array/image or None if not stored
        """
        return self._segmentation_data
273
+
@@ -31,13 +31,24 @@ class BaseEntity(BaseModel):
31
31
  are created through API endpoints.
32
32
  """
33
33
 
34
- model_config = ConfigDict(extra='allow', arbitrary_types_allowed=True) # Allow extra fields and arbitrary types
34
+ model_config = ConfigDict(extra='allow',
35
+ arbitrary_types_allowed=True, # Allow extra fields and arbitrary types
36
+ ser_json_bytes='base64',
37
+ val_json_bytes='base64')
35
38
 
36
39
  _api: 'EntityBaseApi[Self] | EntityBaseApi' = PrivateAttr()
37
40
 
41
def __init__(self, **data):
    """Build the entity, then strip declared fields left at ``MISSING_FIELD``.

    After the parent initializer has run, every declared pydantic field
    that is present on the instance and compares equal to ``MISSING_FIELD``
    is deleted from the instance.

    Args:
        **data: Field values forwarded to the parent initializer
    """
    super().__init__(**data)

    # Local marker to tell "attribute not set" apart from any real value.
    not_set = object()
    for field_name in self.__pydantic_fields__:
        current = getattr(self, field_name, not_set)
        if current is not_set:
            continue
        if current == MISSING_FIELD:
            delattr(self, field_name)
47
+
38
48
  def asdict(self) -> dict[str, Any]:
39
49
  """Convert the entity to a dictionary, including unknown fields."""
40
- return self.model_dump(warnings='none')
50
+ d = self.model_dump(warnings='none')
51
+ return {k: v for k, v in d.items() if v != MISSING_FIELD}
41
52
 
42
53
  def asjson(self) -> str:
43
54
  """Convert the entity to a JSON string, including unknown fields."""
@@ -59,10 +70,13 @@ class BaseEntity(BaseModel):
59
70
  if have_to_log:
60
71
  _LOGGER.warning(f"Unknown fields {list(self.__pydantic_extra__.keys())} found in {class_name}")
61
72
 
62
- @staticmethod
63
- def is_attr_missing(value: Any) -> bool:
73
+ def is_attr_missing(self, attr_name: str) -> bool:
64
74
  """Check if a value is the MISSING_FIELD sentinel."""
65
- return value == MISSING_FIELD
75
+ if attr_name not in self.__pydantic_fields__.keys():
76
+ raise AttributeError(f"Attribute '{attr_name}' not found in entity of type '{self.__class__.__name__}'")
77
+ if not hasattr(self, attr_name):
78
+ return True
79
+ return getattr(self, attr_name) == MISSING_FIELD # deprecated
66
80
 
67
81
  def _refresh(self) -> Self:
68
82
  """Refresh the entity data from the server.
@@ -88,5 +102,85 @@ class BaseEntity(BaseModel):
88
102
  Args:
89
103
  attr_name: Name of the attribute to check and ensure
90
104
  """
91
- if self.is_attr_missing(getattr(self, attr_name)):
105
+ if attr_name not in self.__pydantic_fields__.keys():
106
+ raise AttributeError(f"Attribute '{attr_name}' not found in entity of type '{self.__class__.__name__}'")
107
+
108
+ if self.is_attr_missing(attr_name):
92
109
  self._refresh()
110
+
111
def has_missing_attrs(self) -> bool:
    """Report whether at least one declared field is missing.

    Each name in ``__pydantic_fields__`` is passed to ``is_attr_missing``;
    the scan stops at the first field reported as missing.

    Returns:
        True if any attribute is MISSING_FIELD, False otherwise
    """
    for attr_name in self.__pydantic_fields__:
        if self.is_attr_missing(attr_name):
            return True
    return False
118
+
119
def _fetch_and_cache_file_data(
    self,
    cache_manager: 'Any',  # CacheManager[bytes]
    data_key: str,
    version_info: dict[str, Any],
    download_callback: 'Any',  # Callable[[str | None], bytes]
    save_path: str | None = None,
    use_cache: bool = False,
) -> bytes:
    """Shared logic for fetching and caching file data.

    This method handles the caching strategy for both Resource and Annotation entities.

    Behaviour by case:

    * ``use_cache`` and cache hit: the cached bytes are returned and, when
      ``save_path`` is given, also written to that path.
    * ``use_cache``, cache miss, ``save_path`` given: the file is downloaded
      to ``save_path`` and that location is registered with the cache.
    * ``use_cache``, cache miss, no ``save_path``: the file is downloaded to
      the cache's expected path and stored through ``cache_manager.set``.
    * ``use_cache`` false: the cache is never consulted; the download
      callback is called with ``save_path`` (which may be None).

    Whenever ``save_path`` is given, its parent directory is created first.

    Args:
        cache_manager: The CacheManager instance to use
        data_key: Key identifying the type of data (e.g., 'image_data', 'annotation_data')
        version_info: Version information for cache validation
        download_callback: Function to call to download the file, takes save_path as parameter
        save_path: Optional path to save the file locally
        use_cache: If True, uses cached data when available

    Returns:
        File data as bytes
    """
    from pathlib import Path

    # Every branch that touches save_path writes a file there, so make sure
    # the target directory exists once, up front (including the no-cache
    # branch, which previously skipped this step).
    if save_path:
        Path(save_path).parent.mkdir(parents=True, exist_ok=True)

    # Try to get from cache
    file_data = None
    if use_cache:
        file_data = cache_manager.get(self.id, data_key, version_info)

    if file_data is not None:
        _LOGGER.debug(f"Using cached data for {self.__class__.__name__} {self.id}")
        if save_path:
            # Cached data found, but user wants to save to a specific path
            _LOGGER.debug(f"Saving cached data to specified path: {save_path}")
            with open(save_path, 'wb') as f:
                f.write(file_data)
        return file_data

    if not use_cache:
        # No caching - direct download to save_path (or just return bytes)
        _LOGGER.debug(f"Fetching data from server for {self.__class__.__name__} {self.id}")
        return download_callback(save_path)

    # Cache miss - fetch from server
    if save_path:
        # Download directly to save_path, register location in cache metadata
        _LOGGER.debug(f"Downloading to save_path: {save_path}")
        file_data = download_callback(save_path)

        # Register save_path in cache metadata (no file duplication)
        cache_manager.register_file_location(
            self.id, data_key, save_path, version_info
        )
    else:
        # No save_path - download to cache directory
        cache_path = cache_manager.get_expected_path(self.id, data_key)
        _LOGGER.debug(f"Downloading to cache: {cache_path}")
        file_data = download_callback(str(cache_path))

        # Register in cache metadata
        cache_manager.set(self.id, data_key, file_data, version_info)

    return file_data