datamint 2.3.5__py3-none-any.whl → 2.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of datamint might be problematic. Click here for more details.
- datamint/api/base_api.py +42 -8
- datamint/api/client.py +2 -0
- datamint/apihandler/base_api_handler.py +0 -1
- datamint/apihandler/dto/annotation_dto.py +2 -0
- datamint/dataset/base_dataset.py +4 -0
- datamint/lightning/__init__.py +1 -0
- datamint/lightning/datamintdatamodule.py +103 -0
- datamint/mlflow/__init__.py +46 -0
- datamint/mlflow/artifact/__init__.py +1 -0
- datamint/mlflow/artifact/datamint_artifacts_repo.py +8 -0
- datamint/mlflow/env_utils.py +109 -0
- datamint/mlflow/env_vars.py +5 -0
- datamint/mlflow/lightning/callbacks/__init__.py +1 -0
- datamint/mlflow/lightning/callbacks/modelcheckpoint.py +338 -0
- datamint/mlflow/models/__init__.py +94 -0
- datamint/mlflow/tracking/datamint_store.py +46 -0
- datamint/mlflow/tracking/default_experiment.py +27 -0
- datamint/mlflow/tracking/fluent.py +78 -0
- datamint-2.4.0.dist-info/METADATA +320 -0
- {datamint-2.3.5.dist-info → datamint-2.4.0.dist-info}/RECORD +22 -9
- datamint-2.4.0.dist-info/entry_points.txt +18 -0
- datamint-2.3.5.dist-info/METADATA +0 -125
- datamint-2.3.5.dist-info/entry_points.txt +0 -4
- {datamint-2.3.5.dist-info → datamint-2.4.0.dist-info}/WHEEL +0 -0
datamint/api/base_api.py
CHANGED
|
@@ -61,22 +61,56 @@ class BaseApi:
|
|
|
61
61
|
client: Optional HTTP client instance. If None, a new one will be created.
|
|
62
62
|
"""
|
|
63
63
|
self.config = config
|
|
64
|
-
self.
|
|
64
|
+
self._owns_client = client is None # Track if we created the client
|
|
65
|
+
self.client = client or BaseApi._create_client(config)
|
|
65
66
|
self.semaphore = asyncio.Semaphore(20)
|
|
66
67
|
self._api_instance: 'Api | None' = None # Injected by Api class
|
|
67
68
|
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
69
|
+
@staticmethod
def _create_client(config: ApiConfig) -> httpx.Client:
    """Create and configure the HTTP client (base URL, auth header, timeout, pool limits).

    The client is designed to be long-lived and reused across many requests so that
    its connection pool is shared.

    Pool limits are intentionally *smaller* than httpx's defaults
    (max_keepalive_connections=20, max_connections=100). At most 20 concurrent
    connections are allowed, and at most 5 idle ones are kept alive for 8 seconds.

    Args:
        config: API configuration providing ``api_key``, ``server_url`` and ``timeout``.

    Returns:
        A new ``httpx.Client``; whoever receives it owns it and must close it.
    """
    # Only send the ``apikey`` header when a key is actually configured.
    headers = {"apikey": config.api_key} if config.api_key else None

    return httpx.Client(
        base_url=config.server_url,
        headers=headers,
        timeout=config.timeout,
        limits=httpx.Limits(
            max_keepalive_connections=5,  # reduced from the httpx default of 20
            max_connections=20,  # reduced from the httpx default of 100
            keepalive_expiry=8,  # seconds an idle pooled connection is kept open
        ),
    )
|
|
79
89
|
|
|
90
|
+
def close(self) -> None:
    """Release the underlying HTTP client once this API object is no longer needed.

    A client that was passed in by the caller is left untouched. Only a client this
    instance created itself (``_owns_client``) is closed, and only if one exists.
    """
    if not self._owns_client:
        return
    owned_client = self.client
    if owned_client is not None:
        owned_client.close()
|
|
98
|
+
|
|
99
|
+
def __enter__(self):
    """Enter a ``with`` block, yielding this very API instance.

    Leaving the block runs ``__exit__``, which calls ``close()``.
    """
    instance = self
    return instance
|
|
102
|
+
|
|
103
|
+
def __exit__(self, exc_type, exc_val, exc_tb):
    """Leave a ``with`` block by calling ``close()``.

    Returns None, so any exception raised inside the block propagates.
    """
    self.close()
    return None
|
|
106
|
+
|
|
107
|
+
def __del__(self):
    """Finalizer: call ``close()`` when the instance is garbage collected.

    Any exception is swallowed on purpose. A finalizer must not raise, and the
    instance may be only partially initialized (e.g. if client creation in
    ``__init__`` failed after ``_owns_client`` was set).
    """
    try:
        self.close()
    except Exception:
        return  # best-effort cleanup only
|
|
113
|
+
|
|
80
114
|
def _stream_request(self, method: str, endpoint: str, **kwargs):
|
|
81
115
|
"""Make streaming HTTP request with error handling.
|
|
82
116
|
|
datamint/api/client.py
CHANGED
|
@@ -68,6 +68,8 @@ class Api:
|
|
|
68
68
|
f" Please check your api_key and/or other configurations. {e}")
|
|
69
69
|
|
|
70
70
|
def _get_endpoint(self, name: str):
|
|
71
|
+
if self._client is None:
|
|
72
|
+
self._client = BaseApi._create_client(self.config)
|
|
71
73
|
if name not in self._endpoints:
|
|
72
74
|
api_class = self._API_MAP[name]
|
|
73
75
|
endpoint = api_class(self.config, self._client)
|
|
@@ -178,6 +178,8 @@ class CreateAnnotationDto:
|
|
|
178
178
|
if model_id is not None:
|
|
179
179
|
if is_model == False:
|
|
180
180
|
raise ValueError("model_id==False while self.model_id is provided.")
|
|
181
|
+
if not isinstance(model_id, str):
|
|
182
|
+
raise ValueError("model_id must be a string if provided.")
|
|
181
183
|
is_model = True
|
|
182
184
|
self.is_model = is_model
|
|
183
185
|
self.geometry = geometry
|
datamint/dataset/base_dataset.py
CHANGED
|
@@ -307,6 +307,10 @@ class DatamintBaseDataset:
|
|
|
307
307
|
self.image_lsets, self.image_lcodes = self._get_labels_set(framed=False)
|
|
308
308
|
worklist_id = self.get_info()['worklist_id']
|
|
309
309
|
groups: dict[str, dict] = self.api.annotationsets.get_segmentation_group(worklist_id)['groups']
|
|
310
|
+
if not groups:
|
|
311
|
+
self.seglabel_list = []
|
|
312
|
+
self.seglabel2code = {}
|
|
313
|
+
return
|
|
310
314
|
# order by 'index' key
|
|
311
315
|
max_index = max([g['index'] for g in groups.values()])
|
|
312
316
|
self.seglabel_list : list[str] = ['UNKNOWN'] * max_index # 1-based
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .datamintdatamodule import DatamintDataModule
|
|
@@ -0,0 +1,103 @@
|
|
|
1
|
+
from torch.utils.data import DataLoader
|
|
2
|
+
from datamint import Dataset
|
|
3
|
+
import lightning as L
|
|
4
|
+
from typing import Any
|
|
5
|
+
from copy import copy
|
|
6
|
+
import numpy as np
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class DatamintDataModule(L.LightningDataModule):
    """LightningDataModule for Datamint datasets with a seeded train/val split.

    ``setup`` builds a ``Dataset`` configured for frame-by-frame semantic segmentation
    (``semantic_seg_merge_strategy="union"``, unannotated items excluded). It shuffles
    the dataset's ``subset_indices`` with ``seed`` and then splits them. The first
    ``int(train_split * N)`` indices go to training, and *all remaining* indices go
    to validation.

    TODO: Add dedicated test and predict splits (both currently reuse the validation split).

    Args:
        project_name: Datamint project to load.
        batch_size: Batch size for every dataloader.
        image_transform: Forwarded to ``Dataset(image_transform=...)``.
        mask_transform: Forwarded to ``Dataset(mask_transform=...)``.
        alb_transform: Transform assigned to the ``alb_transform`` of both subsets
            (legacy shortcut). Mutually exclusive with the split-specific arguments.
        alb_train_transform: ``alb_transform`` for the training subset.
        alb_val_transform: ``alb_transform`` for the validation subset.
        train_split: Fraction of samples, in ``[0, 1]``, used for training.
        val_split: Stored but NOT used to size the split. Validation receives every
            sample left after ``train_split``. Kept for backward compatibility.
        seed: Seed of the shuffle that precedes the split.
        num_workers: ``num_workers`` for every dataloader.
        **dataset_kwargs: Extra keyword arguments forwarded to ``Dataset`` in both
            ``prepare_data`` and ``setup``.

    Raises:
        ValueError: If ``alb_transform`` is combined with a split-specific transform,
            or if ``train_split`` is outside ``[0, 1]``.
    """

    def __init__(
        self,
        project_name: str = "./",
        batch_size: int = 32,
        image_transform=None,
        mask_transform=None,
        alb_transform=None,
        alb_train_transform=None,
        alb_val_transform=None,
        train_split: float = 0.9,
        val_split: float = 0.1,
        seed: int = 42,
        num_workers: int = 4,
        **dataset_kwargs: Any,
    ):
        super().__init__()

        if alb_transform is not None and (alb_train_transform is not None or alb_val_transform is not None):
            raise ValueError("You cannot specify both `alb_transform` and `alb_train_transform`/`alb_val_transform`.")
        # Values outside [0, 1] (including NaN) would make the slice-based split meaningless.
        if not 0.0 <= train_split <= 1.0:
            raise ValueError(f"`train_split` must be in [0, 1], got {train_split!r}.")

        self.project_name = project_name
        self.batch_size = batch_size
        self.image_transform = image_transform
        self.mask_transform = mask_transform

        # Handle backward compatibility for alb_transform: one transform for both splits.
        if alb_transform is not None:
            self.alb_train_transform = alb_transform
            self.alb_val_transform = alb_transform
        else:
            self.alb_train_transform = alb_train_transform
            self.alb_val_transform = alb_val_transform

        self.train_split = train_split
        self.val_split = val_split  # see class docstring: not used to size the split
        self.seed = seed
        self.dataset_kwargs = dataset_kwargs
        self.num_workers = num_workers

        self.dataset = None
        self.train_dataset = None
        self.val_dataset = None

    def prepare_data(self) -> None:
        """Download or update the project data if needed.

        ``dataset_kwargs`` are forwarded so the download uses the same ``Dataset``
        options (for example a custom data location or credentials, if supplied)
        that ``setup`` later reads with ``auto_update=False``.
        """
        Dataset(
            project_name=self.project_name,
            auto_update=True,
            **self.dataset_kwargs,
        )

    def setup(self, stage: "str | None" = None) -> None:
        """Build the base dataset once and derive the seeded train/val subsets.

        Repeated calls (e.g. for different ``stage`` values) reuse the first split.
        """
        if self.dataset is not None:
            return

        # Base dataset without augmentation; each subset gets its own alb_transform below.
        self.dataset = Dataset(
            return_as_semantic_segmentation=True,
            semantic_seg_merge_strategy="union",
            return_frame_by_frame=True,
            include_unannotated=False,
            project_name=self.project_name,
            image_transform=self.image_transform,
            mask_transform=self.mask_transform,
            alb_transform=None,
            auto_update=False,
            **self.dataset_kwargs,
        )

        # list() already copies, so shuffling never mutates the dataset's own indices.
        indices = list(self.dataset.subset_indices)
        rs = np.random.RandomState(self.seed)
        rs.shuffle(indices)
        train_end = int(self.train_split * len(indices))
        train_idx = indices[:train_end]
        val_idx = indices[train_end:]

        # Shallow copies so that each subset can carry its own alb_transform.
        self.train_dataset = copy(self.dataset).subset(train_idx)
        self.train_dataset.alb_transform = self.alb_train_transform
        self.val_dataset = copy(self.dataset).subset(val_idx)
        self.val_dataset.alb_transform = self.alb_val_transform

    def train_dataloader(self) -> DataLoader:
        """Shuffled dataloader over the training subset."""
        return self.train_dataset.get_dataloader(batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True)

    def val_dataloader(self) -> DataLoader:
        """Non-shuffled dataloader over the validation subset."""
        return self._eval_dataloader()

    def test_dataloader(self) -> DataLoader:
        # No dedicated test split yet: reuse the validation subset, because there are so few samples.
        return self._eval_dataloader()

    def predict_dataloader(self) -> DataLoader:
        # No dedicated predict split yet: reuse the validation subset, because there are so few samples.
        return self._eval_dataloader()

    def _eval_dataloader(self) -> DataLoader:
        """Shared non-shuffled loader over the validation subset (val/test/predict)."""
        return self.val_dataset.get_dataloader(batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
# Monkey patch mlflow.tracking._tracking_service.utils.get_tracking_uri
|
|
2
|
+
from .tracking.fluent import set_project
|
|
3
|
+
import mlflow.tracking._tracking_service.utils as mlflow_utils
|
|
4
|
+
from functools import wraps
|
|
5
|
+
import logging
|
|
6
|
+
from .env_utils import setup_mlflow_environment, ensure_mlflow_configured
|
|
7
|
+
|
|
8
|
+
_LOGGER = logging.getLogger(__name__)

# Keep MLflow's own implementation so the patched wrapper can delegate to it.
_original_get_tracking_uri = mlflow_utils.get_tracking_uri

# Becomes True once setup_mlflow_environment(set_mlflow=True) has reported success;
# until then the patched get_tracking_uri retries the setup on every call.
_SETUP_CALLED_SUCCESSFULLY = False
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@wraps(_original_get_tracking_uri)
def _patched_get_tracking_uri(*args, **kwargs):
    """Drop-in replacement for MLflow's ``get_tracking_uri`` that configures Datamint first.

    Until a setup attempt succeeds, every call runs
    ``setup_mlflow_environment(set_mlflow=True)``. An exception from that setup is
    logged and treated as failure. After that, the call is delegated unchanged to
    the original MLflow function.

    Args:
        *args: Positional arguments for the original ``get_tracking_uri``.
        **kwargs: Keyword arguments for the original ``get_tracking_uri``.

    Returns:
        Whatever the original ``get_tracking_uri`` returns.
    """
    global _SETUP_CALLED_SUCCESSFULLY
    if not _SETUP_CALLED_SUCCESSFULLY:
        try:
            succeeded = setup_mlflow_environment(set_mlflow=True)
        except Exception as exc:
            succeeded = False
            _LOGGER.error("Failed to set up MLflow environment: %s", exc)
        _SETUP_CALLED_SUCCESSFULLY = succeeded
    return _original_get_tracking_uri(*args, **kwargs)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
# At import time, only export MLFLOW_TRACKING_* environment variables (best effort);
# mlflow's in-process tracking URI is left alone until the first lookup.
setup_mlflow_environment(set_mlflow=False)

# Route every tracking-URI lookup through the wrapper so Datamint setup happens lazily.
mlflow_utils.get_tracking_uri = _patched_get_tracking_uri


__all__ = [
    'set_project',
    'setup_mlflow_environment',
    'ensure_mlflow_configured',
]
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .datamint_artifacts_repo import DatamintArtifactsRepository
|
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
from mlflow.store.artifact.mlflow_artifacts_repo import MlflowArtifactsRepository
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class DatamintArtifactsRepository(MlflowArtifactsRepository):
    """``MlflowArtifactsRepository`` that accepts ``datamint://``-prefixed tracking URIs.

    Everything up to and including the first ``datamint://`` in the tracking URI is
    dropped before delegating to the parent ``resolve_uri``. URIs without that
    marker are passed through unchanged.
    """

    @classmethod
    def resolve_uri(cls, artifact_uri, tracking_uri):
        _, marker, remainder = tracking_uri.partition('datamint://')
        effective_uri = remainder if marker else tracking_uri
        return super().resolve_uri(artifact_uri, effective_uri)
|
|
@@ -0,0 +1,109 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Utility functions for automatically configuring MLflow environment variables
|
|
3
|
+
based on Datamint configuration.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import os
|
|
7
|
+
import logging
|
|
8
|
+
from typing import Optional
|
|
9
|
+
from urllib.parse import urlparse
|
|
10
|
+
from datamint import configs
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
# Module-scoped logger for MLflow auto-configuration diagnostics.
_LOGGER = logging.getLogger(name=__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def get_datamint_api_url() -> Optional[str]:
    """Return the Datamint API URL, or None when it is not configured.

    A non-empty ``DATAMINT_API_URL`` environment variable takes precedence. Otherwise
    the stored Datamint configuration (``configs.APIURL_KEY``) is consulted. Empty
    values count as missing.
    """
    return os.getenv('DATAMINT_API_URL') or configs.get_value(configs.APIURL_KEY) or None
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def get_datamint_api_key() -> Optional[str]:
    """Return the Datamint API key, or None when it is not configured.

    A non-empty ``DATAMINT_API_KEY`` environment variable takes precedence. Otherwise
    the stored Datamint configuration (``configs.APIKEY_KEY``) is consulted. Empty
    values count as missing.
    """
    return os.getenv('DATAMINT_API_KEY') or configs.get_value(configs.APIKEY_KEY) or None
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def _get_mlflowdatamint_uri() -> Optional[str]:
    """Derive the MLflow tracking-server URI from the Datamint API URL.

    The MLflow server is addressed on the same host as the Datamint API, on port
    5000, over plain HTTP (see the FIXME below). Any port in the API URL is ignored.
    For example, ``https://api.datamint.io/`` becomes ``http://api.datamint.io:5000``.

    Returns:
        The MLflow URI, or None when no Datamint API URL is configured or no host
        name can be parsed from it (e.g. a value without a ``scheme://`` prefix
        such as ``api.datamint.io``).
    """
    api_url = get_datamint_api_url()
    if not api_url:
        return None
    _LOGGER.debug("Retrieved Datamint API URL: %s", api_url)

    # api_url samples:
    #   https://api.datamint.io
    #   http://localhost:3001
    parsed_url = urlparse(api_url.rstrip('/'))
    hostname = parsed_url.hostname
    if not hostname:
        # Without a netloc, urlparse yields hostname=None. Formatting it anyway
        # would produce the bogus URI "://None:5000".
        _LOGGER.debug("Could not extract a host name from Datamint API URL %r", api_url)
        return None
    if ':' in hostname:
        # IPv6 literal: urlparse strips the brackets, which are required before a port.
        hostname = f"[{hostname}]"

    # FIXME: It should work with https or datamint-api server should forward https requests.
    scheme = 'http' if parsed_url.scheme in ('', 'https') else parsed_url.scheme
    base_url = f"{scheme}://{hostname}"
    _LOGGER.debug("Derived base URL for MLflow Datamint: %s", base_url)

    return f"{base_url}:5000"
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def setup_mlflow_environment(overwrite: bool = False,
                             set_mlflow: bool = True) -> bool:
    """
    Automatically set up MLflow environment variables based on Datamint configuration.

    ``MLFLOW_TRACKING_TOKEN`` is set to the Datamint API key. ``MLFLOW_TRACKING_URI``
    is set to the URI derived from the Datamint API URL. Values already present in
    the environment are kept unless ``overwrite`` is True.

    Args:
        overwrite: Replace ``MLFLOW_TRACKING_*`` variables that are already set.
        set_mlflow: Also call ``mlflow.set_tracking_uri`` with the *effective*
            ``MLFLOW_TRACKING_URI``. When ``overwrite`` is False, this is the
            pre-existing user value if there was one.

    Returns:
        bool: True if the MLflow environment was configured. False if no Datamint API
        key or no derivable MLflow URI is available (nothing is changed then).
    """
    _LOGGER.debug("Setting up MLflow environment variables from Datamint configuration")
    api_key = get_datamint_api_key()
    mlflow_uri = _get_mlflowdatamint_uri()
    if not mlflow_uri or not api_key:
        _LOGGER.warning("Datamint configuration incomplete, cannot auto-configure MLflow")
        return False

    if overwrite or not os.getenv('MLFLOW_TRACKING_TOKEN'):
        os.environ['MLFLOW_TRACKING_TOKEN'] = api_key
    if overwrite or not os.getenv('MLFLOW_TRACKING_URI'):
        os.environ['MLFLOW_TRACKING_URI'] = mlflow_uri

    if set_mlflow:
        import mlflow
        # Use the effective URI so that a user-provided MLFLOW_TRACKING_URI preserved
        # above is not silently overridden in mlflow's in-process state.
        mlflow.set_tracking_uri(os.environ['MLFLOW_TRACKING_URI'])

    return True
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def ensure_mlflow_configured() -> None:
    """
    Make sure MLflow has a tracking URI and token, configuring them from Datamint if possible.

    Raises:
        ValueError: If Datamint auto-configuration fails and either
            ``MLFLOW_TRACKING_URI`` or ``MLFLOW_TRACKING_TOKEN`` is unset or empty.
    """
    if setup_mlflow_environment():
        return

    has_manual_config = bool(os.getenv('MLFLOW_TRACKING_URI')) and bool(os.getenv('MLFLOW_TRACKING_TOKEN'))
    if not has_manual_config:
        raise ValueError(
            "MLflow environment not configured. Please either:\n"
            "1. Run 'datamint-config' to set up Datamint configuration, or\n"
            "2. Set DATAMINT_API_URL and DATAMINT_API_KEY environment variables, or\n"
            "3. Manually set MLFLOW_TRACKING_URI and MLFLOW_TRACKING_TOKEN environment variables"
        )
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .modelcheckpoint import MLFlowModelCheckpoint
|