datamint 2.3.5__py3-none-any.whl → 2.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of datamint has been flagged as potentially problematic.
- datamint/api/base_api.py +42 -8
- datamint/api/client.py +2 -0
- datamint/api/endpoints/resources_api.py +37 -13
- datamint/apihandler/base_api_handler.py +0 -1
- datamint/apihandler/dto/annotation_dto.py +2 -0
- datamint/dataset/base_dataset.py +4 -0
- datamint/lightning/__init__.py +1 -0
- datamint/lightning/datamintdatamodule.py +103 -0
- datamint/mlflow/__init__.py +46 -0
- datamint/mlflow/artifact/__init__.py +1 -0
- datamint/mlflow/artifact/datamint_artifacts_repo.py +8 -0
- datamint/mlflow/env_utils.py +109 -0
- datamint/mlflow/env_vars.py +5 -0
- datamint/mlflow/lightning/callbacks/__init__.py +1 -0
- datamint/mlflow/lightning/callbacks/modelcheckpoint.py +338 -0
- datamint/mlflow/models/__init__.py +94 -0
- datamint/mlflow/tracking/datamint_store.py +46 -0
- datamint/mlflow/tracking/default_experiment.py +27 -0
- datamint/mlflow/tracking/fluent.py +78 -0
- datamint-2.4.1.dist-info/METADATA +320 -0
- {datamint-2.3.5.dist-info → datamint-2.4.1.dist-info}/RECORD +23 -10
- datamint-2.4.1.dist-info/entry_points.txt +18 -0
- datamint-2.3.5.dist-info/METADATA +0 -125
- datamint-2.3.5.dist-info/entry_points.txt +0 -4
- {datamint-2.3.5.dist-info → datamint-2.4.1.dist-info}/WHEEL +0 -0
datamint/api/base_api.py
CHANGED
@@ -61,22 +61,56 @@ class BaseApi:
             client: Optional HTTP client instance. If None, a new one will be created.
         """
         self.config = config
-        self.
+        self._owns_client = client is None  # Track if we created the client
+        self.client = client or BaseApi._create_client(config)
         self.semaphore = asyncio.Semaphore(20)
         self._api_instance: 'Api | None' = None  # Injected by Api class

-
-
-
-
-
+    @staticmethod
+    def _create_client(config: ApiConfig) -> httpx.Client:
+        """Create and configure HTTP client with authentication and timeouts.
+
+        The client is designed to be long-lived and reused across multiple requests.
+        It maintains connection pooling for improved performance.
+        Default limits: max_keepalive_connections=20, max_connections=100
+        """
+        headers = {"apikey": config.api_key} if config.api_key else None

         return httpx.Client(
-            base_url=
+            base_url=config.server_url,
             headers=headers,
-            timeout=
+            timeout=config.timeout,
+            limits=httpx.Limits(
+                max_keepalive_connections=5,  # Increased from default 20
+                max_connections=20,  # Increased from default 100
+                keepalive_expiry=8
+            )
         )

+    def close(self) -> None:
+        """Close the HTTP client and release resources.
+
+        Should be called when the API instance is no longer needed.
+        Only closes the client if it was created by this instance.
+        """
+        if self._owns_client and self.client is not None:
+            self.client.close()
+
+    def __enter__(self):
+        """Context manager entry."""
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        """Context manager exit - ensures client is closed."""
+        self.close()
+
+    def __del__(self):
+        """Destructor - ensures client is closed when instance is garbage collected."""
+        try:
+            self.close()
+        except Exception:
+            pass  # Ignore errors during cleanup
+
     def _stream_request(self, method: str, endpoint: str, **kwargs):
         """Make streaming HTTP request with error handling.

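
The new `_owns_client` flag plus `close()`/`__enter__`/`__exit__` mean a BaseApi that built its own httpx client can release it deterministically. A minimal usage sketch (the ApiConfig import path and field names are assumptions inferred from the diff above, not verified against the package):

from datamint.api.base_api import BaseApi
from datamint.api.config import ApiConfig  # assumed location of ApiConfig

config = ApiConfig(server_url="https://api.datamint.io", api_key="...", timeout=60)  # field names assumed

# No explicit client is passed, so _owns_client is True and the pooled
# httpx.Client is closed automatically when the block exits.
with BaseApi(config) as api:
    ...  # issue requests through `api`

# If a pre-built httpx.Client is passed in, close() leaves it untouched
# and the caller remains responsible for it.
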
datamint/api/client.py
CHANGED
@@ -68,6 +68,8 @@ class Api:
                 f" Please check your api_key and/or other configurations. {e}")

     def _get_endpoint(self, name: str):
+        if self._client is None:
+            self._client = BaseApi._create_client(self.config)
         if name not in self._endpoints:
             api_class = self._API_MAP[name]
             endpoint = api_class(self.config, self._client)
datamint/api/endpoints/resources_api.py
CHANGED
@@ -25,6 +25,7 @@ import nest_asyncio  # For running asyncio in jupyter notebooks
 from PIL import Image
 import io
 from datamint.types import ImagingData
+from collections import defaultdict


 _LOGGER = logging.getLogger(__name__)

@@ -279,7 +280,6 @@ class ResourcesApi(CreatableEntityApi[Resource], DeletableEntityApi[Resource]):
             _LOGGER.warning(msg)
             _USER_LOGGER.warning(msg)

-
         mimetype = standardize_mimetype(mimetype)

         if is_a_dicom_file == True or is_dicom(file_path):

@@ -440,12 +440,12 @@ class ResourcesApi(CreatableEntityApi[Resource], DeletableEntityApi[Resource]):

         try:
             tasks = [__upload_single_resource(f, segfiles, metadata_file)
-
+                     for f, segfiles, metadata_file in zip(files_path, segmentation_files, metadata_files)]
         except ValueError:
             msg = f"Error preparing upload tasks. Try `assemble_dicom=False`."
             _LOGGER.error(msg)
             _USER_LOGGER.error(msg)
-                raise
+            raise
         return await asyncio.gather(*tasks, return_exceptions=on_error == 'skip')

     def upload_resources(self,

@@ -996,22 +996,28 @@ class ResourcesApi(CreatableEntityApi[Resource], DeletableEntityApi[Resource]):
                 raise

     def set_tags(self,
-                 resource: str | Resource,
+                 resource: str | Resource | Sequence[str | Resource],
                  tags: Sequence[str],
                  ):
         """
         Set tags for a resource, IMPORTANT: This replaces all existing tags.
         Args:
-            resource: The resource
+            resource: The resource object or a list of resources.
             tags: The tags to set.
         """
         data = {'tags': tags}
-
-
-
-
-
-
+        if isinstance(resource, Sequence):
+            resource_ids = [self._entid(res) for res in resource]
+            response = self._make_request('PUT',
+                                          f'{self.endpoint_base}/tags',
+                                          json={'resource_ids': resource_ids,
+                                                'tags': tags})
+        else:
+            resource_id = self._entid(resource)
+            response = self._make_entity_request('PUT',
+                                                 resource_id,
+                                                 add_path='tags',
+                                                 json=data)
         return response

     # def get_projects(self, resource: Resource) -> Sequence[Project]:

@@ -1029,7 +1035,7 @@ class ResourcesApi(CreatableEntityApi[Resource], DeletableEntityApi[Resource]):
     #     return [proj for proj in self.projects_api.get_all() if proj.id in proj_ids]

     def add_tags(self,
-                 resource: str | Resource,
+                 resource: str | Resource | Sequence[str | Resource],
                  tags: Sequence[str],
                  ):
         """

@@ -1040,8 +1046,26 @@ class ResourcesApi(CreatableEntityApi[Resource], DeletableEntityApi[Resource]):
         """
         if isinstance(resource, str):
             resource = self.get_by_id(resource)
+        elif isinstance(resource, Sequence):
+            # Transform every str to Resource first.
+            resources = [self.get_by_id(res) if isinstance(res, str) else res for res in resource]
+
+            # group resource having the exact same tags to minimize requests
+            tag_map: dict[tuple, list[Resource]] = defaultdict(list)
+            for res in resources:
+                old_tags = res.tags if res.tags is not None else []
+                # key = tuple(sorted(old_tags))
+                key = tuple(old_tags)  # keep order, assuming order matters for tags
+                tag_map[key].append(res)
+
+            # finally, set tags for each group
+            for old_tags_tuple, res_group in tag_map.items():
+                old_tags = list(old_tags_tuple)
+                self.set_tags(res_group, old_tags + list(tags))
+            return
+
         old_tags = resource.tags if resource.tags is not None else []
-
+        self.set_tags(resource, old_tags + list(tags))

     def bulk_delete(self, entities: Sequence[str | Resource]) -> None:
         """Delete multiple entities. Faster than deleting them one by one.
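
With these changes, set_tags and add_tags accept either a single resource or a sequence of resources/ids; the bulk add path groups resources by their current tag set so each group is updated with one call. A rough sketch of the call pattern (the top-level Api import and the `resources` attribute name are assumptions based on the rest of this diff):

from datamint import Api  # assumed top-level export

api = Api()  # assumed to pick up credentials from datamint-config or environment variables
ids = ["resource-id-1", "resource-id-2"]  # hypothetical resource ids

# Replaces all existing tags on every listed resource in a single PUT request.
api.resources.set_tags(ids, tags=["reviewed", "batch-a"])

# Appends tags; resources with identical existing tags share one set_tags call.
api.resources.add_tags(ids, tags=["needs-qa"])
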
datamint/apihandler/dto/annotation_dto.py
CHANGED
@@ -178,6 +178,8 @@ class CreateAnnotationDto:
         if model_id is not None:
             if is_model == False:
                 raise ValueError("model_id==False while self.model_id is provided.")
+            if not isinstance(model_id, str):
+                raise ValueError("model_id must be a string if provided.")
             is_model = True
         self.is_model = is_model
         self.geometry = geometry
datamint/dataset/base_dataset.py
CHANGED
@@ -307,6 +307,10 @@ class DatamintBaseDataset:
         self.image_lsets, self.image_lcodes = self._get_labels_set(framed=False)
         worklist_id = self.get_info()['worklist_id']
         groups: dict[str, dict] = self.api.annotationsets.get_segmentation_group(worklist_id)['groups']
+        if not groups:
+            self.seglabel_list = []
+            self.seglabel2code = {}
+            return
         # order by 'index' key
         max_index = max([g['index'] for g in groups.values()])
         self.seglabel_list : list[str] = ['UNKNOWN'] * max_index  # 1-based
datamint/lightning/__init__.py
ADDED
@@ -0,0 +1 @@
+from .datamintdatamodule import DatamintDataModule
datamint/lightning/datamintdatamodule.py
ADDED
@@ -0,0 +1,103 @@
+from torch.utils.data import DataLoader
+from datamint import Dataset
+import lightning as L
+from typing import Any
+from copy import copy
+import numpy as np
+
+
+class DatamintDataModule(L.LightningDataModule):
+    """
+    LightningDataModule for Datamint datasets with train/val split.
+    TODO: Add support for test and predict dataloaders.
+    """
+
+    def __init__(
+        self,
+        project_name: str = "./",
+        batch_size: int = 32,
+        image_transform=None,
+        mask_transform=None,
+        alb_transform=None,
+        alb_train_transform=None,
+        alb_val_transform=None,
+        train_split: float = 0.9,
+        val_split: float = 0.1,
+        seed: int = 42,
+        num_workers: int = 4,
+        **dataset_kwargs: Any,
+    ):
+        super().__init__()
+        self.project_name = project_name
+        self.batch_size = batch_size
+        self.image_transform = image_transform
+        self.mask_transform = mask_transform
+
+        if alb_transform is not None and (alb_train_transform is not None or alb_val_transform is not None):
+            raise ValueError("You cannot specify both `alb_transform` and `alb_train_transform`/`alb_val_transform`.")
+
+        # Handle backward compatibility for alb_transform
+        if alb_transform is not None:
+            self.alb_train_transform = alb_transform
+            self.alb_val_transform = alb_transform
+        else:
+            self.alb_train_transform = alb_train_transform
+            self.alb_val_transform = alb_val_transform
+
+        self.train_split = train_split
+        self.val_split = val_split
+        self.seed = seed
+        self.dataset_kwargs = dataset_kwargs
+        self.num_workers = num_workers
+
+        self.dataset = None
+
+    def prepare_data(self) -> None:
+        """Download or update data if needed."""
+        Dataset(
+            project_name=self.project_name,
+            auto_update=True,
+        )
+
+    def setup(self, stage: str = None) -> None:
+        """Set up datasets and perform train/val split."""
+        if self.dataset is None:
+            # Create base dataset for getting indices
+            self.dataset = Dataset(
+                return_as_semantic_segmentation=True,
+                semantic_seg_merge_strategy="union",
+                return_frame_by_frame=True,
+                include_unannotated=False,
+                project_name=self.project_name,
+                image_transform=self.image_transform,
+                mask_transform=self.mask_transform,
+                alb_transform=None,  # No transform for base dataset
+                auto_update=False,
+                **self.dataset_kwargs,
+            )
+
+        indices = list(copy(self.dataset.subset_indices))
+        rs = np.random.RandomState(self.seed)
+        rs.shuffle(indices)
+        train_end = int(self.train_split * len(indices))
+        train_idx = indices[:train_end]
+        val_idx = indices[train_end:]
+
+        self.train_dataset = copy(self.dataset).subset(train_idx)
+        self.train_dataset.alb_transform = self.alb_train_transform
+        self.val_dataset = copy(self.dataset).subset(val_idx)
+        self.val_dataset.alb_transform = self.alb_val_transform
+
+    def train_dataloader(self) -> DataLoader:
+        return self.train_dataset.get_dataloader(batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True)
+
+    def val_dataloader(self) -> DataLoader:
+        return self.val_dataset.get_dataloader(batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
+
+    def test_dataloader(self):
+        # Use the same dataloader as validation for testing, because we have so few samples
+        return self.val_dataset.get_dataloader(batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
+
+    def predict_dataloader(self):
+        # Use the same dataloader as validation for testing, because we have so few samples
+        return self.val_dataset.get_dataloader(batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
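
A short sketch of wiring the new data module into a Lightning training loop (the project name and model are placeholders; only constructor arguments visible in the file above are used):

import lightning as L
from datamint.lightning import DatamintDataModule

datamodule = DatamintDataModule(
    project_name="my-segmentation-project",  # hypothetical Datamint project
    batch_size=8,
    train_split=0.9,
    seed=42,
)

model = ...  # any L.LightningModule suited to the semantic-segmentation batches
trainer = L.Trainer(max_epochs=10)
trainer.fit(model, datamodule=datamodule)  # Trainer calls prepare_data() and setup() itself
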
datamint/mlflow/__init__.py
ADDED
@@ -0,0 +1,46 @@
+# Monkey patch mlflow.tracking._tracking_service.utils.get_tracking_uri
+from .tracking.fluent import set_project
+import mlflow.tracking._tracking_service.utils as mlflow_utils
+from functools import wraps
+import logging
+from .env_utils import setup_mlflow_environment, ensure_mlflow_configured
+
+_LOGGER = logging.getLogger(__name__)
+
+# Store reference to original function
+_original_get_tracking_uri = mlflow_utils.get_tracking_uri
+_SETUP_CALLED_SUCCESSFULLY = False
+
+
+@wraps(_original_get_tracking_uri)
+def _patched_get_tracking_uri(*args, **kwargs):
+    """Patched version of get_tracking_uri that ensures MLflow environment is set up first.
+
+    This wrapper ensures that setup_mlflow_environment is called before any tracking
+    URI operations, guaranteeing proper MLflow configuration.
+
+    Args:
+        *args: Arguments passed to the original get_tracking_uri function.
+        **kwargs: Keyword arguments passed to the original get_tracking_uri function.
+
+    Returns:
+        The result of the original get_tracking_uri function.
+    """
+    global _SETUP_CALLED_SUCCESSFULLY
+    if _SETUP_CALLED_SUCCESSFULLY:
+        return _original_get_tracking_uri(*args, **kwargs)
+    try:
+        _SETUP_CALLED_SUCCESSFULLY = setup_mlflow_environment(set_mlflow=True)
+    except Exception as e:
+        _SETUP_CALLED_SUCCESSFULLY = False
+        _LOGGER.error("Failed to set up MLflow environment: %s", e)
+    ret = _original_get_tracking_uri(*args, **kwargs)
+    return ret
+
+
+setup_mlflow_environment(set_mlflow=False)
+# Replace the original function with our patched version
+mlflow_utils.get_tracking_uri = _patched_get_tracking_uri
+
+
+__all__ = ['set_project', 'setup_mlflow_environment', 'ensure_mlflow_configured']
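
In effect, importing datamint.mlflow installs the patch and pre-populates MLFLOW_TRACKING_URI/MLFLOW_TRACKING_TOKEN from the Datamint configuration, so ordinary MLflow calls reach the Datamint-backed tracking server. A sketch, assuming the Datamint API URL/key are already configured and that set_project takes a project name (its signature lives in tracking/fluent.py, which is not shown in this diff):

import datamint.mlflow  # runs setup_mlflow_environment(set_mlflow=False) and patches get_tracking_uri
import mlflow
from datamint.mlflow import set_project

set_project("my-project")  # hypothetical project name

with mlflow.start_run():
    mlflow.log_metric("dice", 0.87)  # tracking URI and token were derived from the Datamint config
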
datamint/mlflow/artifact/__init__.py
ADDED
@@ -0,0 +1 @@
+from .datamint_artifacts_repo import DatamintArtifactsRepository
datamint/mlflow/artifact/datamint_artifacts_repo.py
ADDED
@@ -0,0 +1,8 @@
+from mlflow.store.artifact.mlflow_artifacts_repo import MlflowArtifactsRepository
+
+
+class DatamintArtifactsRepository(MlflowArtifactsRepository):
+    @classmethod
+    def resolve_uri(cls, artifact_uri, tracking_uri):
+        tracking_uri = tracking_uri.split('datamint://', maxsplit=1)[-1]
+        return super().resolve_uri(artifact_uri, tracking_uri)
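
The override only strips a leading datamint:// prefix from the tracking URI before delegating to MLflow's stock resolver. An illustrative call (the URIs are made up; the exact resolved value depends on MlflowArtifactsRepository's own behavior):

from datamint.mlflow.artifact import DatamintArtifactsRepository

# 'datamint://http://localhost:3001' is reduced to 'http://localhost:3001'
# before MlflowArtifactsRepository.resolve_uri builds the artifacts URL.
resolved = DatamintArtifactsRepository.resolve_uri(
    artifact_uri="mlflow-artifacts:/0/abc123/artifacts",
    tracking_uri="datamint://http://localhost:3001",
)
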
datamint/mlflow/env_utils.py
ADDED
@@ -0,0 +1,109 @@
+"""
+Utility functions for automatically configuring MLflow environment variables
+based on Datamint configuration.
+"""
+
+import os
+import logging
+from typing import Optional
+from urllib.parse import urlparse
+from datamint import configs
+
+
+_LOGGER = logging.getLogger(__name__)
+
+
+def get_datamint_api_url() -> Optional[str]:
+    """Get the Datamint API URL from configuration or environment variables."""
+    # First check environment variable
+    api_url = os.getenv('DATAMINT_API_URL')
+    if api_url:
+        return api_url
+
+    # Then check configuration
+    api_url = configs.get_value(configs.APIURL_KEY)
+    if api_url:
+        return api_url
+
+    return None
+
+
+def get_datamint_api_key() -> Optional[str]:
+    """Get the Datamint API key from configuration or environment variables."""
+    # First check environment variable
+    api_key = os.getenv('DATAMINT_API_KEY')
+    if api_key:
+        return api_key
+
+    # Then check configuration
+    api_key = configs.get_value(configs.APIKEY_KEY)
+    if api_key:
+        return api_key
+
+    return None
+
+
+def _get_mlflowdatamint_uri() -> Optional[str]:
+    api_url = get_datamint_api_url()
+    if not api_url:
+        return None
+    _LOGGER.debug(f"Retrieved Datamint API URL: {api_url}")
+
+    # Remove trailing slash if present
+    api_url = api_url.rstrip('/')
+    # api_url samples:
+    # https://api.datamint.io
+    # http://localhost:3001
+
+    parsed_url = urlparse(api_url)
+    base_url = f"{parsed_url.scheme}://{parsed_url.hostname}"
+    _LOGGER.debug(f"Derived base URL for MLflow Datamint: {base_url}")
+    # FIXME: It should work with https or datamint-api server should forward https requests.
+    base_url = base_url.replace('https://', 'http://')
+    if len(base_url.replace('http:', '')) == 0:
+        return None
+
+    mlflow_uri = f"{base_url}:5000"
+    return mlflow_uri
+
+
+def setup_mlflow_environment(overwrite: bool = False,
+                             set_mlflow: bool = True) -> bool:
+    """
+    Automatically set up MLflow environment variables based on Datamint configuration.
+
+    Returns:
+        bool: True if MLflow environment was successfully configured, False otherwise.
+    """
+    _LOGGER.debug("Setting up MLflow environment variables from Datamint configuration")
+    api_key = get_datamint_api_key()
+    mlflow_uri = _get_mlflowdatamint_uri()
+    if not mlflow_uri or not api_key:
+        _LOGGER.warning("Datamint configuration incomplete, cannot auto-configure MLflow")
+        return False
+
+    if overwrite or not os.getenv('MLFLOW_TRACKING_TOKEN'):
+        os.environ['MLFLOW_TRACKING_TOKEN'] = api_key
+    if overwrite or not os.getenv('MLFLOW_TRACKING_URI'):
+        os.environ['MLFLOW_TRACKING_URI'] = mlflow_uri
+
+    if set_mlflow:
+        import mlflow
+        mlflow.set_tracking_uri(mlflow_uri)
+
+    return True
+
+
+def ensure_mlflow_configured() -> None:
+    """
+    Ensure MLflow environment is properly configured.
+    Raises an exception if configuration is incomplete.
+    """
+    if not setup_mlflow_environment():
+        if not os.getenv('MLFLOW_TRACKING_URI') or not os.getenv('MLFLOW_TRACKING_TOKEN'):
+            raise ValueError(
+                "MLflow environment not configured. Please either:\n"
+                "1. Run 'datamint-config' to set up Datamint configuration, or\n"
+                "2. Set DATAMINT_API_URL and DATAMINT_API_KEY environment variables, or\n"
+                "3. Manually set MLFLOW_TRACKING_URI and MLFLOW_TRACKING_TOKEN environment variables"
+            )
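
Given a Datamint API URL such as http://localhost:3001, the helper derives an MLflow URI on port 5000 of the same host. A minimal sketch of calling these helpers directly (values are illustrative):

import os
from datamint.mlflow.env_utils import setup_mlflow_environment, ensure_mlflow_configured

os.environ.setdefault("DATAMINT_API_URL", "http://localhost:3001")  # illustrative values
os.environ.setdefault("DATAMINT_API_KEY", "my-api-key")

if setup_mlflow_environment(set_mlflow=False):
    # For the URL above this prints http://localhost:5000 (unless MLFLOW_TRACKING_URI was already set).
    print(os.environ["MLFLOW_TRACKING_URI"])

# Raises ValueError with setup instructions if neither Datamint nor MLflow variables are configured.
ensure_mlflow_configured()
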
datamint/mlflow/lightning/callbacks/__init__.py
ADDED
@@ -0,0 +1 @@
+from .modelcheckpoint import MLFlowModelCheckpoint