datamint 2.3.3__py3-none-any.whl → 2.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- datamint/__init__.py +1 -3
- datamint/api/__init__.py +0 -3
- datamint/api/base_api.py +286 -54
- datamint/api/client.py +76 -13
- datamint/api/endpoints/__init__.py +2 -2
- datamint/api/endpoints/annotations_api.py +186 -28
- datamint/api/endpoints/deploy_model_api.py +78 -0
- datamint/api/endpoints/models_api.py +1 -0
- datamint/api/endpoints/projects_api.py +38 -7
- datamint/api/endpoints/resources_api.py +227 -100
- datamint/api/entity_base_api.py +66 -7
- datamint/apihandler/base_api_handler.py +0 -1
- datamint/apihandler/dto/annotation_dto.py +2 -0
- datamint/client_cmd_tools/datamint_config.py +0 -1
- datamint/client_cmd_tools/datamint_upload.py +3 -1
- datamint/configs.py +11 -7
- datamint/dataset/base_dataset.py +24 -4
- datamint/dataset/dataset.py +1 -1
- datamint/entities/__init__.py +1 -1
- datamint/entities/annotations/__init__.py +13 -0
- datamint/entities/{annotation.py → annotations/annotation.py} +81 -47
- datamint/entities/annotations/image_classification.py +12 -0
- datamint/entities/annotations/image_segmentation.py +252 -0
- datamint/entities/annotations/volume_segmentation.py +273 -0
- datamint/entities/base_entity.py +100 -6
- datamint/entities/cache_manager.py +129 -15
- datamint/entities/datasetinfo.py +60 -65
- datamint/entities/deployjob.py +18 -0
- datamint/entities/project.py +39 -0
- datamint/entities/resource.py +310 -46
- datamint/lightning/__init__.py +1 -0
- datamint/lightning/datamintdatamodule.py +103 -0
- datamint/mlflow/__init__.py +65 -0
- datamint/mlflow/artifact/__init__.py +1 -0
- datamint/mlflow/artifact/datamint_artifacts_repo.py +8 -0
- datamint/mlflow/env_utils.py +131 -0
- datamint/mlflow/env_vars.py +5 -0
- datamint/mlflow/flavors/__init__.py +17 -0
- datamint/mlflow/flavors/datamint_flavor.py +150 -0
- datamint/mlflow/flavors/model.py +877 -0
- datamint/mlflow/lightning/callbacks/__init__.py +1 -0
- datamint/mlflow/lightning/callbacks/modelcheckpoint.py +410 -0
- datamint/mlflow/models/__init__.py +93 -0
- datamint/mlflow/tracking/datamint_store.py +76 -0
- datamint/mlflow/tracking/default_experiment.py +27 -0
- datamint/mlflow/tracking/fluent.py +91 -0
- datamint/utils/env.py +27 -0
- datamint/utils/visualization.py +21 -13
- datamint-2.9.0.dist-info/METADATA +220 -0
- datamint-2.9.0.dist-info/RECORD +73 -0
- {datamint-2.3.3.dist-info → datamint-2.9.0.dist-info}/WHEEL +1 -1
- datamint-2.9.0.dist-info/entry_points.txt +18 -0
- datamint/apihandler/exp_api_handler.py +0 -204
- datamint/experiment/__init__.py +0 -1
- datamint/experiment/_patcher.py +0 -570
- datamint/experiment/experiment.py +0 -1049
- datamint-2.3.3.dist-info/METADATA +0 -125
- datamint-2.3.3.dist-info/RECORD +0 -54
- datamint-2.3.3.dist-info/entry_points.txt +0 -4
datamint/entities/resource.py
CHANGED
@@ -1,20 +1,27 @@
 """Resource entity module for DataMint API."""
 
 from datetime import datetime
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Optional
+from collections.abc import Sequence
 import logging
+import urllib.parse
+import urllib.request
 
 from .base_entity import BaseEntity, MISSING_FIELD
 from .cache_manager import CacheManager
 from pydantic import PrivateAttr
-from datamint.api.dto import AnnotationType
 import webbrowser
-
+import shutil
+from pathlib import Path
+from datamint.api.base_api import BaseApi
 
 if TYPE_CHECKING:
     from datamint.api.endpoints.resources_api import ResourcesApi
     from .project import Project
-    from .annotation import Annotation
+    from .annotations.annotation import Annotation
+    from datamint.types import ImagingData
+    from datamint.api.dto import AnnotationType
+
 
 logger = logging.getLogger(__name__)
 
@@ -70,47 +77,57 @@ class Resource(BaseEntity):
     location: str
     upload_channel: str
     filename: str
-    modality: str
     mimetype: str
     size: int
-    upload_mechanism: str
     customer_id: str
     status: str
     created_at: str
     created_by: str
     published: bool
     deleted: bool
-
-    metadata: dict
-
-
-
+    upload_mechanism: str | None = None
+    # metadata: dict[str,Any] = {}
+    modality: str | None = None
+    source_filepath: str | None = None
+    # projects: list[dict[str, Any]] | None = None
+    published_on: str | None = None
+    published_by: str | None = None
     tags: list[str] | None = None
-    publish_transforms:
+    # publish_transforms: dict[str, Any] | None = None
     deleted_at: Optional[str] = None
     deleted_by: Optional[str] = None
     instance_uid: Optional[str] = None
     series_uid: Optional[str] = None
     study_uid: Optional[str] = None
     patient_id: Optional[str] = None
-    segmentations: Optional[Any] = None  # TODO: Define proper type when spec available
-    measurements: Optional[Any] = None  # TODO: Define proper type when spec available
-    categories: Optional[Any] = None  # TODO: Define proper type when spec available
-    user_info:
+    # segmentations: Optional[Any] = None  # TODO: Define proper type when spec available
+    # measurements: Optional[Any] = None  # TODO: Define proper type when spec available
+    # categories: Optional[Any] = None  # TODO: Define proper type when spec available
+    user_info: dict[str, str | None] = MISSING_FIELD
 
     _api: 'ResourcesApi' = PrivateAttr()
 
+    def __new__(cls, *args, **kwargs):
+        if cls is Resource and ('local_filepath' in kwargs or 'raw_data' in kwargs):
+            return super().__new__(LocalResource)
+        return super().__new__(cls)
+
     def __init__(self, **data):
         """Initialize the resource entity."""
         super().__init__(**data)
-
+
+    @property
+    def _cache(self) -> CacheManager[bytes]:
+        if not hasattr(self, '__cache'):
+            self.__cache = CacheManager[bytes]('resources')
+        return self.__cache
 
     def fetch_file_data(
         self,
         auto_convert: bool = True,
         save_path: str | None = None,
         use_cache: bool = False,
-    ) -> bytes | ImagingData:
+    ) -> 'bytes | ImagingData':
         """Get the file data for this resource.
 
         This method automatically caches the file data locally. On subsequent
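The `__new__` override above means that constructing a `Resource` with local data transparently yields a `LocalResource` instance (the subclass is added at the bottom of this file in this release). A minimal sketch of that dispatch; the file name is hypothetical:

    from datamint.entities.resource import Resource, LocalResource

    # Keyword arguments drive the dispatch in Resource.__new__:
    r = Resource(local_filepath="scan.png")  # hypothetical file on disk
    assert isinstance(r, LocalResource)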
@@ -119,7 +136,9 @@ class Resource(BaseEntity):
         Args:
             use_cache: If True, uses cached data when available and valid
             auto_convert: If True, automatically converts to appropriate format (pydicom.Dataset, PIL Image, etc.)
-            save_path: Optional path to save the file locally
+            save_path: Optional path to save the file locally. If use_cache is also True,
+                the file is saved to save_path and cache metadata points to that location
+                (no duplication - only one file on disk).
 
         Returns:
             File data (format depends on auto_convert and file type)
@@ -127,31 +146,30 @@ class Resource(BaseEntity):
         # Version info for cache validation
         version_info = self._generate_version_info()
 
-        #
-
-
-        img_data = self._cache.get(self.id, _IMAGE_CACHEKEY, version_info)
-        if img_data is not None:
-            logger.debug(f"Using cached image data for resource {self.id}")
-
-        if img_data is None:
-            # Fetch from server using download_resource_file
-            logger.debug(f"Fetching image data from server for resource {self.id}")
-            img_data = self._api.download_resource_file(
+        # Download callback for the shared caching logic
+        def download_callback(path: str | None) -> bytes:
+            return self._api.download_resource_file(
                 self,
-                save_path=
+                save_path=path,
                 auto_convert=False
             )
-
-
-
+
+        # Use shared caching logic from BaseEntity
+        img_data = self._fetch_and_cache_file_data(
+            cache_manager=self._cache,
+            data_key=_IMAGE_CACHEKEY,
+            version_info=version_info,
+            download_callback=download_callback,
+            save_path=save_path,
+            use_cache=use_cache,
+        )
 
         if auto_convert:
             try:
-                mimetype, ext =
-                img_data =
-
-
+                mimetype, ext = BaseApi._determine_mimetype(img_data, self.mimetype)
+                img_data = BaseApi.convert_format(img_data,
+                                                  mimetype=mimetype,
+                                                  file_path=save_path)
             except Exception as e:
                 logger.error(f"Failed to auto-convert resource {self.id}: {e}")
 
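With this refactor, `fetch_file_data` no longer talks to the cache directly: lookup, download, and population are delegated to `BaseEntity._fetch_and_cache_file_data` (part of the +100-line change to `entities/base_entity.py` in this release), which receives the download closure above. A hedged usage sketch; the client accessor name is an assumption, not shown in this diff:

    # Hypothetical lookup; the accessor is an assumption.
    resource = client.resources.get("resource-id")

    # First call downloads and caches; later calls with use_cache=True reuse the bytes.
    raw = resource.fetch_file_data(use_cache=True, auto_convert=False)
    image = resource.fetch_file_data(use_cache=True)  # converted, e.g. pydicom.Dataset for DICOM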
@@ -170,17 +188,36 @@ class Resource(BaseEntity):
         version_info = self._generate_version_info()
         self._cache.set(self.id, _IMAGE_CACHEKEY, data, version_info)
 
+    def is_cached(self) -> bool:
+        """Check if the resource's file data is already cached locally and valid.
+
+        Returns:
+            True if valid cached data exists, False otherwise.
+        """
+        version_info = self._generate_version_info()
+        cached_data = self._cache.get(self.id, _IMAGE_CACHEKEY, version_info)
+        return cached_data is not None
+
+    @property
+    def filepath_cached(self) -> Path | None:
+        """Get the file path of the cached resource data, if available.
+
+        Returns:
+            Path to the cached file data, or None if not cached.
+        """
+        if self._cache is None:
+            return None
+        version_info = self._generate_version_info()
+        path = self._cache.get_path(self.id, _IMAGE_CACHEKEY, version_info)
+        return path
+
     def fetch_annotations(
         self,
-        annotation_type: AnnotationType | str | None = None
+        annotation_type: 'AnnotationType | str | None' = None
     ) -> Sequence['Annotation']:
         """Get annotations associated with this resource."""
 
-        annotations = self._api.get_annotations(self)
-
-        if annotation_type:
-            annotation_type = AnnotationType(annotation_type)
-            annotations = [a for a in annotations if a.annotation_type == annotation_type]
+        annotations = self._api.get_annotations(self, annotation_type=annotation_type)
         return annotations
 
     # def get_projects(
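The new `is_cached`/`filepath_cached` helpers make the cache state inspectable without triggering a download, and annotation filtering now happens inside `ResourcesApi.get_annotations` rather than client-side. A short sketch; the annotation type string is illustrative:

    if not resource.is_cached():
        resource.fetch_file_data(use_cache=True, auto_convert=False)
    print(resource.filepath_cached)  # pathlib.Path to the cached bytes, or None

    # Filtering is delegated to the API; the value below is illustrative.
    segmentations = resource.fetch_annotations(annotation_type="segmentation")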
@@ -193,7 +230,6 @@ class Resource(BaseEntity):
     #     """
     #     return self._api.get_projects(self)
 
-
     def invalidate_cache(self) -> None:
         """Invalidate cached data for this resource.
         """
@@ -245,7 +281,7 @@ class Resource(BaseEntity):
             f"modality='{self.modality}', status='{self.status}', "
             f"published={self.published})"
         )
-
+
     @property
     def url(self) -> str:
         """Get the URL to access this resource in the DataMint web application."""
@@ -255,3 +291,231 @@ class Resource(BaseEntity):
     def show(self) -> None:
         """Open the resource in the default web browser."""
         webbrowser.open(self.url)
+
+    @staticmethod
+    def from_local_file(file_path: str | Path):
+        """Create a LocalResource instance from a local file path.
+
+        Args:
+            file_path: Path to the local file
+        """
+        return LocalResource(local_filepath=file_path)
+
+
+class LocalResource(Resource):
+    """Represents a local resource that hasn't been uploaded to DataMint API yet."""
+
+    local_filepath: str | None = None
+    raw_data: bytes | None = None
+
+    @property
+    def filepath_cached(self) -> str | None:
+        """Get the file path of the local resource data.
+
+        Returns:
+            Path to the local file, or None if only raw data is available.
+        """
+        return self.local_filepath
+
+    def __init__(self,
+                 local_filepath: str | Path | None = None,
+                 raw_data: bytes | None = None,
+                 convert_to_bytes: bool = False,
+                 **kwargs):
+        """Initialize a local resource from a local file path, URL, or raw data.
+
+        Args:
+            local_filepath: Path to the local file or URL to an online image
+            raw_data: Raw bytes of the file data
+            convert_to_bytes: If True and local_filepath is provided, read file into raw_data
+        """
+        from medimgkit.format_detection import guess_type, DEFAULT_MIME_TYPE
+        from medimgkit.modality_detector import detect_modality
+
+        if raw_data is None and local_filepath is None:
+            raise ValueError("Either local_filepath or raw_data must be provided.")
+        if raw_data is not None and local_filepath is not None:
+            raise ValueError("Only one of local_filepath or raw_data should be provided.")
+
+        # Check if local_filepath is a URL
+        if local_filepath is not None:
+            local_filepath_str = str(local_filepath)
+            if local_filepath_str.startswith(('http://', 'https://')):
+                # Download content from URL
+                logger.debug(f"Downloading resource from URL: {local_filepath_str}")
+                try:
+                    with urllib.request.urlopen(local_filepath_str) as response:
+                        raw_data = response.read()
+                        # Try to get content-type from response headers
+                        content_type = response.headers.get('Content-Type', '').split(';')[0].strip()
+                except Exception as e:
+                    raise ValueError(f"Failed to download from URL: {local_filepath_str}") from e
+
+                # Extract filename from URL
+                parsed_url = urllib.parse.urlparse(local_filepath_str)
+                url_path = urllib.parse.unquote(parsed_url.path)
+                filename = Path(url_path).name if url_path else 'downloaded_file'
+
+                # Determine mimetype
+                mimetype, _ = guess_type(raw_data)
+                if mimetype is None and content_type:
+                    mimetype = content_type
+                if mimetype is None:
+                    mimetype = DEFAULT_MIME_TYPE
+
+                default_values = {
+                    'id': '',
+                    'resource_uri': '',
+                    'storage': '',
+                    'location': local_filepath_str,
+                    'upload_channel': '',
+                    'filename': filename,
+                    'modality': None,
+                    'mimetype': mimetype,
+                    'size': len(raw_data),
+                    'upload_mechanism': '',
+                    'customer_id': '',
+                    'status': 'local',
+                    'created_at': datetime.now().isoformat(),
+                    'created_by': '',
+                    'published': False,
+                    'deleted': False,
+                    'source_filepath': local_filepath_str,
+                }
+                new_kwargs = kwargs.copy()
+                for key, value in default_values.items():
+                    new_kwargs.setdefault(key, value)
+                super(Resource, self).__init__(
+                    local_filepath=None,
+                    raw_data=raw_data,
+                    **new_kwargs
+                )
+                return
+
+        if convert_to_bytes and local_filepath:
+            with open(local_filepath, 'rb') as f:
+                raw_data = f.read()
+            local_filepath = None
+        if raw_data is not None:
+            # import io
+            if isinstance(raw_data, str):
+                mimetype, _ = guess_type(raw_data.encode())
+            else:
+                mimetype, _ = guess_type(raw_data)
+            default_values = {
+                'id': '',
+                'resource_uri': '',
+                'storage': '',
+                'location': '',
+                'upload_channel': '',
+                'filename': 'raw_data',
+                'modality': None,
+                'mimetype': mimetype if mimetype else DEFAULT_MIME_TYPE,
+                'size': len(raw_data),
+                'upload_mechanism': '',
+                'customer_id': '',
+                'status': 'local',
+                'created_at': datetime.now().isoformat(),
+                'created_by': '',
+                'published': False,
+                'deleted': False,
+                'source_filepath': None,
+            }
+            new_kwargs = kwargs.copy()
+            for key, value in default_values.items():
+                new_kwargs.setdefault(key, value)
+            super().__init__(
+                local_filepath=None,
+                raw_data=raw_data,
+                **new_kwargs
+            )
+        elif local_filepath is not None:
+            file_path = Path(local_filepath)
+            if not file_path.exists():
+                raise FileNotFoundError(f"File not found: {file_path}")
+
+            mimetype, _ = guess_type(file_path)
+            if mimetype is None or mimetype == DEFAULT_MIME_TYPE:
+                logger.warning(f"Could not determine mimetype for file: {file_path}")
+            size = file_path.stat().st_size
+            created_at = datetime.fromtimestamp(file_path.stat().st_ctime).isoformat()
+
+            super().__init__(
+                id="",
+                resource_uri="",
+                storage="",
+                location=str(file_path),
+                upload_channel="",
+                filename=file_path.name,
+                modality=detect_modality(file_path),
+                mimetype=mimetype,
+                size=size,
+                upload_mechanism="",
+                customer_id="",
+                status="local",
+                created_at=created_at,
+                created_by="",
+                published=False,
+                deleted=False,
+                source_filepath=str(file_path),
+                local_filepath=str(file_path),
+                raw_data=None,
+            )
+
+    def fetch_file_data(
+        self, *args,
+        auto_convert: bool = True,
+        save_path: str | None = None,
+        **kwargs,
+    ) -> 'bytes | ImagingData':
+        """Get the file data for this local resource.
+
+        Args:
+            auto_convert: If True, automatically converts to appropriate format (pydicom.Dataset, PIL Image, etc.)
+            save_path: Optional path to save the file locally
+        Returns:
+            File data (format depends on auto_convert and file type)
+        """
+        if self.raw_data is not None:
+            img_data = self.raw_data
+            local_filepath = None
+        else:
+            local_filepath = str(self.local_filepath)
+            with open(local_filepath, 'rb') as f:
+                img_data = f.read()
+
+        if save_path:
+            with open(save_path, 'wb') as f:
+                f.write(img_data)
+
+        if auto_convert:
+            try:
+                mimetype, ext = BaseApi._determine_mimetype(img_data, self.mimetype)
+                img_data = BaseApi.convert_format(img_data,
+                                                  mimetype=mimetype,
+                                                  file_path=local_filepath)
+            except Exception as e:
+                logger.error(f"Failed to auto-convert local resource: {e}")
+                logger.error(e, exc_info=True)
+
+        return img_data
+
+    def __str__(self) -> str:
+        """String representation of the local resource.
+
+        Returns:
+            Human-readable string describing the local resource
+        """
+        return f"LocalResource(filepath='{self.local_filepath}', size={self.size_mb}MB)"
+
+    def __repr__(self) -> str:
+        """Detailed string representation of the local resource.
+
+        Returns:
+            Detailed string representation for debugging
+        """
+        return (
+            f"LocalResource(filepath='{self.local_filepath}', "
+            f"filename='{self.filename}', modality='{self.modality}', "
+            f"size={self.size_mb}MB)"
+        )
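`LocalResource` gives not-yet-uploaded files the same interface as server-side resources, covering the three construction paths handled by `__init__` above (disk file, URL, raw bytes). A minimal sketch; the paths and URL are placeholders:

    from datamint.entities.resource import Resource, LocalResource

    res_file = Resource.from_local_file("scans/chest_xray.dcm")             # from disk
    res_url = LocalResource(local_filepath="https://example.com/img.png")   # downloaded into raw_data
    res_raw = LocalResource(raw_data=b"...")                                # from bytes

    data = res_file.fetch_file_data(auto_convert=True)  # e.g. pydicom.Dataset for DICOM input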
datamint/lightning/__init__.py
ADDED
@@ -0,0 +1 @@
+from .datamintdatamodule import DatamintDataModule
datamint/lightning/datamintdatamodule.py
ADDED
@@ -0,0 +1,103 @@
+from torch.utils.data import DataLoader
+from datamint import Dataset
+import lightning as L
+from typing import Any
+from copy import copy
+import numpy as np
+
+
+class DatamintDataModule(L.LightningDataModule):
+    """
+    LightningDataModule for Datamint datasets with train/val split.
+    TODO: Add support for test and predict dataloaders.
+    """
+
+    def __init__(
+        self,
+        project_name: str = "./",
+        batch_size: int = 32,
+        image_transform=None,
+        mask_transform=None,
+        alb_transform=None,
+        alb_train_transform=None,
+        alb_val_transform=None,
+        train_split: float = 0.9,
+        val_split: float = 0.1,
+        seed: int = 42,
+        num_workers: int = 4,
+        **dataset_kwargs: Any,
+    ):
+        super().__init__()
+        self.project_name = project_name
+        self.batch_size = batch_size
+        self.image_transform = image_transform
+        self.mask_transform = mask_transform
+
+        if alb_transform is not None and (alb_train_transform is not None or alb_val_transform is not None):
+            raise ValueError("You cannot specify both `alb_transform` and `alb_train_transform`/`alb_val_transform`.")
+
+        # Handle backward compatibility for alb_transform
+        if alb_transform is not None:
+            self.alb_train_transform = alb_transform
+            self.alb_val_transform = alb_transform
+        else:
+            self.alb_train_transform = alb_train_transform
+            self.alb_val_transform = alb_val_transform
+
+        self.train_split = train_split
+        self.val_split = val_split
+        self.seed = seed
+        self.dataset_kwargs = dataset_kwargs
+        self.num_workers = num_workers
+
+        self.dataset = None
+
+    def prepare_data(self) -> None:
+        """Download or update data if needed."""
+        Dataset(
+            project_name=self.project_name,
+            auto_update=True,
+        )
+
+    def setup(self, stage: str = None) -> None:
+        """Set up datasets and perform train/val split."""
+        if self.dataset is None:
+            # Create base dataset for getting indices
+            self.dataset = Dataset(
+                return_as_semantic_segmentation=True,
+                semantic_seg_merge_strategy="union",
+                return_frame_by_frame=True,
+                include_unannotated=False,
+                project_name=self.project_name,
+                image_transform=self.image_transform,
+                mask_transform=self.mask_transform,
+                alb_transform=None,  # No transform for base dataset
+                auto_update=False,
+                **self.dataset_kwargs,
+            )
+
+        indices = list(copy(self.dataset.subset_indices))
+        rs = np.random.RandomState(self.seed)
+        rs.shuffle(indices)
+        train_end = int(self.train_split * len(indices))
+        train_idx = indices[:train_end]
+        val_idx = indices[train_end:]
+
+        self.train_dataset = copy(self.dataset).subset(train_idx)
+        self.train_dataset.alb_transform = self.alb_train_transform
+        self.val_dataset = copy(self.dataset).subset(val_idx)
+        self.val_dataset.alb_transform = self.alb_val_transform
+
+    def train_dataloader(self) -> DataLoader:
+        return self.train_dataset.get_dataloader(batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True)
+
+    def val_dataloader(self) -> DataLoader:
+        return self.val_dataset.get_dataloader(batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
+
+    def test_dataloader(self):
+        # Use the same dataloader as validation for testing, because we have so few samples
+        return self.val_dataset.get_dataloader(batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
+
+    def predict_dataloader(self):
+        # Use the same dataloader as validation for prediction, because we have so few samples
+        return self.val_dataset.get_dataloader(batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
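A hedged usage sketch of the data module with a Lightning `Trainer`; the project name and `model` are placeholders:

    import lightning as L
    from datamint.lightning import DatamintDataModule

    dm = DatamintDataModule(project_name="my-seg-project", batch_size=16, num_workers=2)
    trainer = L.Trainer(max_epochs=10)
    trainer.fit(model, datamodule=dm)  # `model` is your LightningModule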
datamint/mlflow/__init__.py
ADDED
@@ -0,0 +1,65 @@
+# Monkey patch mlflow.tracking._tracking_service.utils.get_tracking_uri
+from .tracking.fluent import set_project
+import mlflow.tracking._tracking_service.utils as mlflow_utils
+from functools import wraps
+import logging
+from .env_utils import setup_mlflow_environment, ensure_mlflow_configured
+from typing import TYPE_CHECKING
+
+_LOGGER = logging.getLogger(__name__)
+
+# Store reference to original function
+_original_get_tracking_uri = mlflow_utils.get_tracking_uri
+_SETUP_CALLED_SUCCESSFULLY = False
+
+if mlflow_utils.is_tracking_uri_set():
+    _LOGGER.warning("MLflow tracking URI is already set before patching get_tracking_uri.")
+
+@wraps(_original_get_tracking_uri)
+def _patched_get_tracking_uri(*args, **kwargs):
+    """Patched version of get_tracking_uri that ensures MLflow environment is set up first.
+
+    This wrapper ensures that setup_mlflow_environment is called before any tracking
+    URI operations, guaranteeing proper MLflow configuration.
+
+    Args:
+        *args: Arguments passed to the original get_tracking_uri function.
+        **kwargs: Keyword arguments passed to the original get_tracking_uri function.
+
+    Returns:
+        The result of the original get_tracking_uri function.
+    """
+    global _SETUP_CALLED_SUCCESSFULLY
+    if _SETUP_CALLED_SUCCESSFULLY:
+        return _original_get_tracking_uri(*args, **kwargs)
+    if mlflow_utils.is_tracking_uri_set():
+        _LOGGER.warning("MLflow tracking URI is already set before patching get_tracking_uri.")
+    try:
+        _SETUP_CALLED_SUCCESSFULLY = setup_mlflow_environment(set_mlflow=True)
+    except Exception as e:
+        _SETUP_CALLED_SUCCESSFULLY = False
+        _LOGGER.error("Failed to set up MLflow environment: %s", e)
+    ret = _original_get_tracking_uri(*args, **kwargs)
+    return ret
+
+
+setup_mlflow_environment(set_mlflow=False)
+# Replace the original function with our patched version
+mlflow_utils.get_tracking_uri = _patched_get_tracking_uri
+
+
+if TYPE_CHECKING:
+    from .flavors.model import DatamintModel
+else:
+    import lazy_loader as lazy
+
+    __getattr__, __dir__, __all__ = lazy.attach(
+        __name__,
+        submodules=['flavors.model', 'flavors.datamint_flavor'],
+        submod_attrs={
+            "flavors.model": ["DatamintModel"],
+            "flavors.datamint_flavor": ["log_model", "load_model"],
+        },
+    )
+
+__all__ = ['set_project', 'setup_mlflow_environment', 'ensure_mlflow_configured', 'DatamintModel']
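Importing `datamint.mlflow` is itself the setup step: the import replaces `mlflow.tracking._tracking_service.utils.get_tracking_uri`, and the first tracking-URI lookup completes configuration lazily via `setup_mlflow_environment(set_mlflow=True)`. A sketch of the intended flow; the project name is hypothetical and `set_project`'s signature is assumed, as it is not shown in this diff:

    import datamint.mlflow  # import applies the get_tracking_uri patch
    import mlflow

    datamint.mlflow.set_project("my-project")  # hypothetical project; exact signature not shown here
    with mlflow.start_run():
        mlflow.log_metric("loss", 0.42)  # first URI lookup triggers the lazy environment setup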
datamint/mlflow/artifact/__init__.py
ADDED
@@ -0,0 +1 @@
+from .datamint_artifacts_repo import DatamintArtifactsRepository
datamint/mlflow/artifact/datamint_artifacts_repo.py
ADDED
@@ -0,0 +1,8 @@
+from mlflow.store.artifact.mlflow_artifacts_repo import MlflowArtifactsRepository
+
+
+class DatamintArtifactsRepository(MlflowArtifactsRepository):
+    @classmethod
+    def resolve_uri(cls, artifact_uri, tracking_uri):
+        tracking_uri = tracking_uri.split('datamint://', maxsplit=1)[-1]
+        return super().resolve_uri(artifact_uri, tracking_uri)
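`resolve_uri` strips the custom `datamint://` scheme before delegating to the stock MLflow artifacts repository. The split behaves like this (the URI is illustrative):

    tracking_uri = "datamint://https://app.datamint.io/mlflow"  # hypothetical tracking URI
    stripped = tracking_uri.split('datamint://', maxsplit=1)[-1]
    print(stripped)  # -> https://app.datamint.io/mlflow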