datamint 2.3.3__py3-none-any.whl → 2.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. datamint/__init__.py +1 -3
  2. datamint/api/__init__.py +0 -3
  3. datamint/api/base_api.py +286 -54
  4. datamint/api/client.py +76 -13
  5. datamint/api/endpoints/__init__.py +2 -2
  6. datamint/api/endpoints/annotations_api.py +186 -28
  7. datamint/api/endpoints/deploy_model_api.py +78 -0
  8. datamint/api/endpoints/models_api.py +1 -0
  9. datamint/api/endpoints/projects_api.py +38 -7
  10. datamint/api/endpoints/resources_api.py +227 -100
  11. datamint/api/entity_base_api.py +66 -7
  12. datamint/apihandler/base_api_handler.py +0 -1
  13. datamint/apihandler/dto/annotation_dto.py +2 -0
  14. datamint/client_cmd_tools/datamint_config.py +0 -1
  15. datamint/client_cmd_tools/datamint_upload.py +3 -1
  16. datamint/configs.py +11 -7
  17. datamint/dataset/base_dataset.py +24 -4
  18. datamint/dataset/dataset.py +1 -1
  19. datamint/entities/__init__.py +1 -1
  20. datamint/entities/annotations/__init__.py +13 -0
  21. datamint/entities/{annotation.py → annotations/annotation.py} +81 -47
  22. datamint/entities/annotations/image_classification.py +12 -0
  23. datamint/entities/annotations/image_segmentation.py +252 -0
  24. datamint/entities/annotations/volume_segmentation.py +273 -0
  25. datamint/entities/base_entity.py +100 -6
  26. datamint/entities/cache_manager.py +129 -15
  27. datamint/entities/datasetinfo.py +60 -65
  28. datamint/entities/deployjob.py +18 -0
  29. datamint/entities/project.py +39 -0
  30. datamint/entities/resource.py +310 -46
  31. datamint/lightning/__init__.py +1 -0
  32. datamint/lightning/datamintdatamodule.py +103 -0
  33. datamint/mlflow/__init__.py +65 -0
  34. datamint/mlflow/artifact/__init__.py +1 -0
  35. datamint/mlflow/artifact/datamint_artifacts_repo.py +8 -0
  36. datamint/mlflow/env_utils.py +131 -0
  37. datamint/mlflow/env_vars.py +5 -0
  38. datamint/mlflow/flavors/__init__.py +17 -0
  39. datamint/mlflow/flavors/datamint_flavor.py +150 -0
  40. datamint/mlflow/flavors/model.py +877 -0
  41. datamint/mlflow/lightning/callbacks/__init__.py +1 -0
  42. datamint/mlflow/lightning/callbacks/modelcheckpoint.py +410 -0
  43. datamint/mlflow/models/__init__.py +93 -0
  44. datamint/mlflow/tracking/datamint_store.py +76 -0
  45. datamint/mlflow/tracking/default_experiment.py +27 -0
  46. datamint/mlflow/tracking/fluent.py +91 -0
  47. datamint/utils/env.py +27 -0
  48. datamint/utils/visualization.py +21 -13
  49. datamint-2.9.0.dist-info/METADATA +220 -0
  50. datamint-2.9.0.dist-info/RECORD +73 -0
  51. {datamint-2.3.3.dist-info → datamint-2.9.0.dist-info}/WHEEL +1 -1
  52. datamint-2.9.0.dist-info/entry_points.txt +18 -0
  53. datamint/apihandler/exp_api_handler.py +0 -204
  54. datamint/experiment/__init__.py +0 -1
  55. datamint/experiment/_patcher.py +0 -570
  56. datamint/experiment/experiment.py +0 -1049
  57. datamint-2.3.3.dist-info/METADATA +0 -125
  58. datamint-2.3.3.dist-info/RECORD +0 -54
  59. datamint-2.3.3.dist-info/entry_points.txt +0 -4

datamint/entities/resource.py

@@ -1,20 +1,27 @@
 """Resource entity module for DataMint API."""
 
 from datetime import datetime
-from typing import TYPE_CHECKING, Optional, Any, Sequence
+from typing import TYPE_CHECKING, Optional
+from collections.abc import Sequence
 import logging
+import urllib.parse
+import urllib.request
 
 from .base_entity import BaseEntity, MISSING_FIELD
 from .cache_manager import CacheManager
 from pydantic import PrivateAttr
-from datamint.api.dto import AnnotationType
 import webbrowser
-from datamint.types import ImagingData
+import shutil
+from pathlib import Path
+from datamint.api.base_api import BaseApi
 
 if TYPE_CHECKING:
     from datamint.api.endpoints.resources_api import ResourcesApi
     from .project import Project
-    from .annotation import Annotation
+    from .annotations.annotation import Annotation
+    from datamint.types import ImagingData
+    from datamint.api.dto import AnnotationType
+
 
 logger = logging.getLogger(__name__)
 
@@ -70,47 +77,57 @@ class Resource(BaseEntity):
     location: str
     upload_channel: str
     filename: str
-    modality: str
     mimetype: str
     size: int
-    upload_mechanism: str
     customer_id: str
     status: str
     created_at: str
     created_by: str
     published: bool
     deleted: bool
-    source_filepath: str | None
-    metadata: dict
-    projects: list[dict] = MISSING_FIELD
-    published_on: str | None
-    published_by: str | None
+    upload_mechanism: str | None = None
+    # metadata: dict[str,Any] = {}
+    modality: str | None = None
+    source_filepath: str | None = None
+    # projects: list[dict[str, Any]] | None = None
+    published_on: str | None = None
+    published_by: str | None = None
     tags: list[str] | None = None
-    publish_transforms: Optional[Any] = None
+    # publish_transforms: dict[str, Any] | None = None
     deleted_at: Optional[str] = None
     deleted_by: Optional[str] = None
     instance_uid: Optional[str] = None
     series_uid: Optional[str] = None
     study_uid: Optional[str] = None
     patient_id: Optional[str] = None
-    segmentations: Optional[Any] = None  # TODO: Define proper type when spec available
-    measurements: Optional[Any] = None  # TODO: Define proper type when spec available
-    categories: Optional[Any] = None  # TODO: Define proper type when spec available
-    user_info: Optional[dict] = None
+    # segmentations: Optional[Any] = None  # TODO: Define proper type when spec available
+    # measurements: Optional[Any] = None  # TODO: Define proper type when spec available
+    # categories: Optional[Any] = None  # TODO: Define proper type when spec available
+    user_info: dict[str, str | None] = MISSING_FIELD
 
     _api: 'ResourcesApi' = PrivateAttr()
 
+    def __new__(cls, *args, **kwargs):
+        if cls is Resource and ('local_filepath' in kwargs or 'raw_data' in kwargs):
+            return super().__new__(LocalResource)
+        return super().__new__(cls)
+
     def __init__(self, **data):
         """Initialize the resource entity."""
         super().__init__(**data)
-        self._cache: CacheManager[bytes] = CacheManager[bytes]('resources')
+
+    @property
+    def _cache(self) -> CacheManager[bytes]:
+        if not hasattr(self, '__cache'):
+            self.__cache = CacheManager[bytes]('resources')
+        return self.__cache
 
     def fetch_file_data(
         self,
         auto_convert: bool = True,
         save_path: str | None = None,
         use_cache: bool = False,
-    ) -> bytes | ImagingData:
+    ) -> 'bytes | ImagingData':
         """Get the file data for this resource.
 
         This method automatically caches the file data locally. On subsequent
@@ -119,7 +136,9 @@ class Resource(BaseEntity):
         Args:
             use_cache: If True, uses cached data when available and valid
             auto_convert: If True, automatically converts to appropriate format (pydicom.Dataset, PIL Image, etc.)
-            save_path: Optional path to save the file locally
+            save_path: Optional path to save the file locally. If use_cache is also True,
+                the file is saved to save_path and cache metadata points to that location
+                (no duplication - only one file on disk).
 
         Returns:
             File data (format depends on auto_convert and file type)
@@ -127,31 +146,30 @@ class Resource(BaseEntity):
         # Version info for cache validation
         version_info = self._generate_version_info()
 
-        # Try to get from cache
-        img_data = None
-        if use_cache:
-            img_data = self._cache.get(self.id, _IMAGE_CACHEKEY, version_info)
-            if img_data is not None:
-                logger.debug(f"Using cached image data for resource {self.id}")
-
-        if img_data is None:
-            # Fetch from server using download_resource_file
-            logger.debug(f"Fetching image data from server for resource {self.id}")
-            img_data = self._api.download_resource_file(
+        # Download callback for the shared caching logic
+        def download_callback(path: str | None) -> bytes:
+            return self._api.download_resource_file(
                 self,
-                save_path=save_path,
+                save_path=path,
                 auto_convert=False
             )
-        # Cache the data
-        if use_cache:
-            self._cache.set(self.id, _IMAGE_CACHEKEY, img_data, version_info)
+
+        # Use shared caching logic from BaseEntity
+        img_data = self._fetch_and_cache_file_data(
+            cache_manager=self._cache,
+            data_key=_IMAGE_CACHEKEY,
+            version_info=version_info,
+            download_callback=download_callback,
+            save_path=save_path,
+            use_cache=use_cache,
+        )
 
         if auto_convert:
             try:
-                mimetype, ext = self._api._determine_mimetype(img_data, self)
-                img_data = self._api.convert_format(img_data,
-                                                    mimetype=mimetype,
-                                                    file_path=save_path)
+                mimetype, ext = BaseApi._determine_mimetype(img_data, self.mimetype)
+                img_data = BaseApi.convert_format(img_data,
+                                                  mimetype=mimetype,
+                                                  file_path=save_path)
             except Exception as e:
                 logger.error(f"Failed to auto-convert resource {self.id}: {e}")
 
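The cache-check / download / cache-write sequence that used to live inline is now a shared helper on BaseEntity (see base_entity.py, +100 -6, not shown in this diff). A minimal sketch of the contract implied by the call site above; the internals are an assumption, only the keyword arguments are taken from the diff:

    from collections.abc import Callable

    def _fetch_and_cache_file_data(self, cache_manager, data_key: str, version_info,
                                   download_callback: Callable[[str | None], bytes],
                                   save_path: str | None, use_cache: bool) -> bytes:
        # Hypothetical body: return a valid cached copy when allowed and available
        if use_cache:
            cached = cache_manager.get(self.id, data_key, version_info)
            if cached is not None:
                return cached
        # Otherwise download (saving to save_path if given) and cache the result
        data = download_callback(save_path)
        if use_cache:
            cache_manager.set(self.id, data_key, data, version_info)
        return data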
@@ -170,17 +188,36 @@ class Resource(BaseEntity):
         version_info = self._generate_version_info()
         self._cache.set(self.id, _IMAGE_CACHEKEY, data, version_info)
 
+    def is_cached(self) -> bool:
+        """Check if the resource's file data is already cached locally and valid.
+
+        Returns:
+            True if valid cached data exists, False otherwise.
+        """
+        version_info = self._generate_version_info()
+        cached_data = self._cache.get(self.id, _IMAGE_CACHEKEY, version_info)
+        return cached_data is not None
+
+    @property
+    def filepath_cached(self) -> Path | None:
+        """Get the file path of the cached resource data, if available.
+
+        Returns:
+            Path to the cached file data, or None if not cached.
+        """
+        if self._cache is None:
+            return None
+        version_info = self._generate_version_info()
+        path = self._cache.get_path(self.id, _IMAGE_CACHEKEY, version_info)
+        return path
+
     def fetch_annotations(
         self,
-        annotation_type: AnnotationType | str | None = None
+        annotation_type: 'AnnotationType | str | None' = None
     ) -> Sequence['Annotation']:
         """Get annotations associated with this resource."""
 
-        annotations = self._api.get_annotations(self)
-
-        if annotation_type:
-            annotation_type = AnnotationType(annotation_type)
-            annotations = [a for a in annotations if a.annotation_type == annotation_type]
+        annotations = self._api.get_annotations(self, annotation_type=annotation_type)
         return annotations
 
     # def get_projects(
@@ -193,7 +230,6 @@ class Resource(BaseEntity):
    #     """
    #     return self._api.get_projects(self)
 
-
    def invalidate_cache(self) -> None:
        """Invalidate cached data for this resource.
        """
@@ -245,7 +281,7 @@ class Resource(BaseEntity):
            f"modality='{self.modality}', status='{self.status}', "
            f"published={self.published})"
        )
-
+
    @property
    def url(self) -> str:
        """Get the URL to access this resource in the DataMint web application."""
@@ -255,3 +291,231 @@ class Resource(BaseEntity):
     def show(self) -> None:
         """Open the resource in the default web browser."""
         webbrowser.open(self.url)
+
+    @staticmethod
+    def from_local_file(file_path: str | Path):
+        """Create a LocalResource instance from a local file path.
+
+        Args:
+            file_path: Path to the local file
+        """
+        return LocalResource(local_filepath=file_path)
+
+
+class LocalResource(Resource):
+    """Represents a local resource that hasn't been uploaded to DataMint API yet."""
+
+    local_filepath: str | None = None
+    raw_data: bytes | None = None
+
+    @property
+    def filepath_cached(self) -> str | None:
+        """Get the file path of the local resource data.
+
+        Returns:
+            Path to the local file, or None if only raw data is available.
+        """
+        return self.local_filepath
+
+    def __init__(self,
+                 local_filepath: str | Path | None = None,
+                 raw_data: bytes | None = None,
+                 convert_to_bytes: bool = False,
+                 **kwargs):
+        """Initialize a local resource from a local file path, URL, or raw data.
+
+        Args:
+            local_filepath: Path to the local file or URL to an online image
+            raw_data: Raw bytes of the file data
+            convert_to_bytes: If True and local_filepath is provided, read file into raw_data
+        """
+        from medimgkit.format_detection import guess_type, DEFAULT_MIME_TYPE
+        from medimgkit.modality_detector import detect_modality
+
+        if raw_data is None and local_filepath is None:
+            raise ValueError("Either local_filepath or raw_data must be provided.")
+        if raw_data is not None and local_filepath is not None:
+            raise ValueError("Only one of local_filepath or raw_data should be provided.")
+
+        # Check if local_filepath is a URL
+        if local_filepath is not None:
+            local_filepath_str = str(local_filepath)
+            if local_filepath_str.startswith(('http://', 'https://')):
+                # Download content from URL
+                logger.debug(f"Downloading resource from URL: {local_filepath_str}")
+                try:
+                    with urllib.request.urlopen(local_filepath_str) as response:
+                        raw_data = response.read()
+                        # Try to get content-type from response headers
+                        content_type = response.headers.get('Content-Type', '').split(';')[0].strip()
+                except Exception as e:
+                    raise ValueError(f"Failed to download from URL: {local_filepath_str}") from e
+
+                # Extract filename from URL
+                parsed_url = urllib.parse.urlparse(local_filepath_str)
+                url_path = urllib.parse.unquote(parsed_url.path)
+                filename = Path(url_path).name if url_path else 'downloaded_file'
+
+                # Determine mimetype
+                mimetype, _ = guess_type(raw_data)
+                if mimetype is None and content_type:
+                    mimetype = content_type
+                if mimetype is None:
+                    mimetype = DEFAULT_MIME_TYPE
+
+                default_values = {
+                    'id': '',
+                    'resource_uri': '',
+                    'storage': '',
+                    'location': local_filepath_str,
+                    'upload_channel': '',
+                    'filename': filename,
+                    'modality': None,
+                    'mimetype': mimetype,
+                    'size': len(raw_data),
+                    'upload_mechanism': '',
+                    'customer_id': '',
+                    'status': 'local',
+                    'created_at': datetime.now().isoformat(),
+                    'created_by': '',
+                    'published': False,
+                    'deleted': False,
+                    'source_filepath': local_filepath_str,
+                }
+                new_kwargs = kwargs.copy()
+                for key, value in default_values.items():
+                    new_kwargs.setdefault(key, value)
+                super(Resource, self).__init__(
+                    local_filepath=None,
+                    raw_data=raw_data,
+                    **new_kwargs
+                )
+                return
+
+        if convert_to_bytes and local_filepath:
+            with open(local_filepath, 'rb') as f:
+                raw_data = f.read()
+            local_filepath = None
+        if raw_data is not None:
+            # import io
+            if isinstance(raw_data, str):
+                mimetype, _ = guess_type(raw_data.encode())
+            else:
+                mimetype, _ = guess_type(raw_data)
+            default_values = {
+                'id': '',
+                'resource_uri': '',
+                'storage': '',
+                'location': '',
+                'upload_channel': '',
+                'filename': 'raw_data',
+                'modality': None,
+                'mimetype': mimetype if mimetype else DEFAULT_MIME_TYPE,
+                'size': len(raw_data),
+                'upload_mechanism': '',
+                'customer_id': '',
+                'status': 'local',
+                'created_at': datetime.now().isoformat(),
+                'created_by': '',
+                'published': False,
+                'deleted': False,
+                'source_filepath': None,
+            }
+            new_kwargs = kwargs.copy()
+            for key, value in default_values.items():
+                new_kwargs.setdefault(key, value)
+            super().__init__(
+                local_filepath=None,
+                raw_data=raw_data,
+                **new_kwargs
+            )
+        elif local_filepath is not None:
+            file_path = Path(local_filepath)
+            if not file_path.exists():
+                raise FileNotFoundError(f"File not found: {file_path}")
+
+            mimetype, _ = guess_type(file_path)
+            if mimetype is None or mimetype == DEFAULT_MIME_TYPE:
+                logger.warning(f"Could not determine mimetype for file: {file_path}")
+            size = file_path.stat().st_size
+            created_at = datetime.fromtimestamp(file_path.stat().st_ctime).isoformat()
+
+            super().__init__(
+                id="",
+                resource_uri="",
+                storage="",
+                location=str(file_path),
+                upload_channel="",
+                filename=file_path.name,
+                modality=detect_modality(file_path),
+                mimetype=mimetype,
+                size=size,
+                upload_mechanism="",
+                customer_id="",
+                status="local",
+                created_at=created_at,
+                created_by="",
+                published=False,
+                deleted=False,
+                source_filepath=str(file_path),
+                local_filepath=str(file_path),
+                raw_data=None,
+            )
+
+    def fetch_file_data(
+        self, *args,
+        auto_convert: bool = True,
+        save_path: str | None = None,
+        **kwargs,
+    ) -> 'bytes | ImagingData':
+        """Get the file data for this local resource.
+
+        Args:
+            auto_convert: If True, automatically converts to appropriate format (pydicom.Dataset, PIL Image, etc.)
+            save_path: Optional path to save the file locally
+        Returns:
+            File data (format depends on auto_convert and file type)
+        """
+        if self.raw_data is not None:
+            img_data = self.raw_data
+            local_filepath = None
+        else:
+            local_filepath = str(self.local_filepath)
+            with open(local_filepath, 'rb') as f:
+                img_data = f.read()
+
+        if save_path:
+            with open(save_path, 'wb') as f:
+                f.write(img_data)
+
+        if auto_convert:
+            try:
+                mimetype, ext = BaseApi._determine_mimetype(img_data, self.mimetype)
+                img_data = BaseApi.convert_format(img_data,
+                                                  mimetype=mimetype,
+                                                  file_path=local_filepath)
+            except Exception as e:
+                logger.error(f"Failed to auto-convert local resource: {e}")
+                logger.error(e, exc_info=True)
+
+        return img_data
+
+    def __str__(self) -> str:
+        """String representation of the local resource.
+
+        Returns:
+            Human-readable string describing the local resource
+        """
+        return f"LocalResource(filepath='{self.local_filepath}', size={self.size_mb}MB)"
+
+    def __repr__(self) -> str:
+        """Detailed string representation of the local resource.
+
+        Returns:
+            Detailed string representation for debugging
+        """
+        return (
+            f"LocalResource(filepath='{self.local_filepath}', "
+            f"filename='{self.filename}', modality='{self.modality}', "
+            f"size={self.size_mb}MB)"
+        )
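
With the __new__ hook added above, Resource itself acts as a factory for files that have not been uploaded yet. A usage sketch (the file name and URL are hypothetical):

    from datamint.entities.resource import Resource, LocalResource

    # from_local_file (or Resource(local_filepath=...)) yields a LocalResource
    res = Resource.from_local_file('scan.dcm')     # hypothetical local file
    data = res.fetch_file_data(auto_convert=True)  # read from disk, auto-converted

    # Raw bytes and http(s) URLs are also accepted
    with open('scan.dcm', 'rb') as f:
        raw = LocalResource(raw_data=f.read())
    remote = LocalResource(local_filepath='https://example.com/img.png')  # downloaded at init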

datamint/lightning/__init__.py (new file)

@@ -0,0 +1 @@
+from .datamintdatamodule import DatamintDataModule

datamint/lightning/datamintdatamodule.py (new file)

@@ -0,0 +1,103 @@
+from torch.utils.data import DataLoader
+from datamint import Dataset
+import lightning as L
+from typing import Any
+from copy import copy
+import numpy as np
+
+
+class DatamintDataModule(L.LightningDataModule):
+    """
+    LightningDataModule for Datamint datasets with train/val split.
+    TODO: Add support for test and predict dataloaders.
+    """
+
+    def __init__(
+        self,
+        project_name: str = "./",
+        batch_size: int = 32,
+        image_transform=None,
+        mask_transform=None,
+        alb_transform=None,
+        alb_train_transform=None,
+        alb_val_transform=None,
+        train_split: float = 0.9,
+        val_split: float = 0.1,
+        seed: int = 42,
+        num_workers: int = 4,
+        **dataset_kwargs: Any,
+    ):
+        super().__init__()
+        self.project_name = project_name
+        self.batch_size = batch_size
+        self.image_transform = image_transform
+        self.mask_transform = mask_transform
+
+        if alb_transform is not None and (alb_train_transform is not None or alb_val_transform is not None):
+            raise ValueError("You cannot specify both `alb_transform` and `alb_train_transform`/`alb_val_transform`.")
+
+        # Handle backward compatibility for alb_transform
+        if alb_transform is not None:
+            self.alb_train_transform = alb_transform
+            self.alb_val_transform = alb_transform
+        else:
+            self.alb_train_transform = alb_train_transform
+            self.alb_val_transform = alb_val_transform
+
+        self.train_split = train_split
+        self.val_split = val_split
+        self.seed = seed
+        self.dataset_kwargs = dataset_kwargs
+        self.num_workers = num_workers
+
+        self.dataset = None
+
+    def prepare_data(self) -> None:
+        """Download or update data if needed."""
+        Dataset(
+            project_name=self.project_name,
+            auto_update=True,
+        )
+
+    def setup(self, stage: str = None) -> None:
+        """Set up datasets and perform train/val split."""
+        if self.dataset is None:
+            # Create base dataset for getting indices
+            self.dataset = Dataset(
+                return_as_semantic_segmentation=True,
+                semantic_seg_merge_strategy="union",
+                return_frame_by_frame=True,
+                include_unannotated=False,
+                project_name=self.project_name,
+                image_transform=self.image_transform,
+                mask_transform=self.mask_transform,
+                alb_transform=None,  # No transform for base dataset
+                auto_update=False,
+                **self.dataset_kwargs,
+            )
+
+        indices = list(copy(self.dataset.subset_indices))
+        rs = np.random.RandomState(self.seed)
+        rs.shuffle(indices)
+        train_end = int(self.train_split * len(indices))
+        train_idx = indices[:train_end]
+        val_idx = indices[train_end:]
+
+        self.train_dataset = copy(self.dataset).subset(train_idx)
+        self.train_dataset.alb_transform = self.alb_train_transform
+        self.val_dataset = copy(self.dataset).subset(val_idx)
+        self.val_dataset.alb_transform = self.alb_val_transform
+
+    def train_dataloader(self) -> DataLoader:
+        return self.train_dataset.get_dataloader(batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True)
+
+    def val_dataloader(self) -> DataLoader:
+        return self.val_dataset.get_dataloader(batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
+
+    def test_dataloader(self):
+        # Use the same dataloader as validation for testing, because we have so few samples
+        return self.val_dataset.get_dataloader(batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
+
+    def predict_dataloader(self):
+        # Use the same dataloader as validation for testing, because we have so few samples
+        return self.val_dataset.get_dataloader(batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
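
A brief usage sketch of the new data module with a Lightning Trainer; the project name is a placeholder and MySegmentationModel stands in for any LightningModule:

    import lightning as L
    from datamint.lightning import DatamintDataModule

    dm = DatamintDataModule(project_name="my-project", batch_size=8, seed=42)  # hypothetical project
    trainer = L.Trainer(max_epochs=10)
    trainer.fit(MySegmentationModel(), datamodule=dm)  # prepare_data(), setup(), then the dataloaders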

datamint/mlflow/__init__.py (new file)

@@ -0,0 +1,65 @@
+# Monkey patch mlflow.tracking._tracking_service.utils.get_tracking_uri
+from .tracking.fluent import set_project
+import mlflow.tracking._tracking_service.utils as mlflow_utils
+from functools import wraps
+import logging
+from .env_utils import setup_mlflow_environment, ensure_mlflow_configured
+from typing import TYPE_CHECKING
+
+_LOGGER = logging.getLogger(__name__)
+
+# Store reference to original function
+_original_get_tracking_uri = mlflow_utils.get_tracking_uri
+_SETUP_CALLED_SUCCESSFULLY = False
+
+if mlflow_utils.is_tracking_uri_set():
+    _LOGGER.warning("MLflow tracking URI is already set before patching get_tracking_uri.")
+
+@wraps(_original_get_tracking_uri)
+def _patched_get_tracking_uri(*args, **kwargs):
+    """Patched version of get_tracking_uri that ensures MLflow environment is set up first.
+
+    This wrapper ensures that setup_mlflow_environment is called before any tracking
+    URI operations, guaranteeing proper MLflow configuration.
+
+    Args:
+        *args: Arguments passed to the original get_tracking_uri function.
+        **kwargs: Keyword arguments passed to the original get_tracking_uri function.
+
+    Returns:
+        The result of the original get_tracking_uri function.
+    """
+    global _SETUP_CALLED_SUCCESSFULLY
+    if _SETUP_CALLED_SUCCESSFULLY:
+        return _original_get_tracking_uri(*args, **kwargs)
+    if mlflow_utils.is_tracking_uri_set():
+        _LOGGER.warning("MLflow tracking URI is already set before patching get_tracking_uri.")
+    try:
+        _SETUP_CALLED_SUCCESSFULLY = setup_mlflow_environment(set_mlflow=True)
+    except Exception as e:
+        _SETUP_CALLED_SUCCESSFULLY = False
+        _LOGGER.error("Failed to set up MLflow environment: %s", e)
+    ret = _original_get_tracking_uri(*args, **kwargs)
+    return ret
+
+
+setup_mlflow_environment(set_mlflow=False)
+# Replace the original function with our patched version
+mlflow_utils.get_tracking_uri = _patched_get_tracking_uri
+
+
+if TYPE_CHECKING:
+    from .flavors.model import DatamintModel
+else:
+    import lazy_loader as lazy
+
+    __getattr__, __dir__, __all__ = lazy.attach(
+        __name__,
+        submodules=['flavors.model', 'flavors.datamint_flavor'],
+        submod_attrs={
+            "flavors.model": ["DatamintModel"],
+            "flavors.datamint_flavor": ["log_model", "load_model"],
+        },
+    )
+
+__all__ = ['set_project', 'setup_mlflow_environment', 'ensure_mlflow_configured', 'DatamintModel']
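
The upshot is that importing datamint.mlflow is enough to route MLflow through the DataMint backend: the first operation that resolves the tracking URI runs the full environment setup, and the flavor API loads lazily. A sketch of the intended flow, assuming DataMint credentials are already configured and using a hypothetical project name:

    import datamint.mlflow  # importing applies the get_tracking_uri patch
    import mlflow

    datamint.mlflow.set_project("my-project")  # hypothetical project name
    with mlflow.start_run():  # resolving the tracking URI triggers setup_mlflow_environment(set_mlflow=True)
        mlflow.log_metric("loss", 0.1)

    # DatamintModel / log_model / load_model resolve lazily on first attribute access
    ModelCls = datamint.mlflow.DatamintModel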

datamint/mlflow/artifact/__init__.py (new file)

@@ -0,0 +1 @@
+from .datamint_artifacts_repo import DatamintArtifactsRepository

datamint/mlflow/artifact/datamint_artifacts_repo.py (new file)

@@ -0,0 +1,8 @@
+from mlflow.store.artifact.mlflow_artifacts_repo import MlflowArtifactsRepository
+
+
+class DatamintArtifactsRepository(MlflowArtifactsRepository):
+    @classmethod
+    def resolve_uri(cls, artifact_uri, tracking_uri):
+        tracking_uri = tracking_uri.split('datamint://', maxsplit=1)[-1]
+        return super().resolve_uri(artifact_uri, tracking_uri)
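
resolve_uri strips a leading datamint:// scheme before delegating, so mlflow-artifacts URIs resolve against the plain HTTP(S) tracking URI underneath. For example (both URI values are illustrative):

    # 'datamint://https://app.datamint.example' -> 'https://app.datamint.example'
    DatamintArtifactsRepository.resolve_uri(
        artifact_uri="mlflow-artifacts:/0/abc123/artifacts",
        tracking_uri="datamint://https://app.datamint.example",
    )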