datamint 2.3.5__py3-none-any.whl → 2.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


datamint/api/base_api.py CHANGED
@@ -61,22 +61,56 @@ class BaseApi:
             client: Optional HTTP client instance. If None, a new one will be created.
         """
         self.config = config
-        self.client = client or self._create_client()
+        self._owns_client = client is None  # Track if we created the client
+        self.client = client or BaseApi._create_client(config)
         self.semaphore = asyncio.Semaphore(20)
         self._api_instance: 'Api | None' = None  # Injected by Api class
 
-    def _create_client(self) -> httpx.Client:
-        """Create and configure HTTP client with authentication and timeouts."""
-        headers = None
-        if self.config.api_key:
-            headers = {"apikey": self.config.api_key}
+    @staticmethod
+    def _create_client(config: ApiConfig) -> httpx.Client:
+        """Create and configure an HTTP client with authentication and timeouts.
+
+        The client is designed to be long-lived and reused across multiple requests.
+        It maintains connection pooling for improved performance.
+        httpx defaults: max_keepalive_connections=20, max_connections=100.
+        """
+        headers = {"apikey": config.api_key} if config.api_key else None
 
         return httpx.Client(
-            base_url=self.config.server_url,
+            base_url=config.server_url,
             headers=headers,
-            timeout=self.config.timeout
+            timeout=config.timeout,
+            limits=httpx.Limits(
+                max_keepalive_connections=5,  # lowered from the httpx default of 20
+                max_connections=20,  # lowered from the httpx default of 100
+                keepalive_expiry=8
+            )
         )
 
+    def close(self) -> None:
+        """Close the HTTP client and release resources.
+
+        Should be called when the API instance is no longer needed.
+        Only closes the client if it was created by this instance.
+        """
+        if self._owns_client and self.client is not None:
+            self.client.close()
+
+    def __enter__(self):
+        """Context manager entry."""
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        """Context manager exit - ensures the client is closed."""
+        self.close()
+
+    def __del__(self):
+        """Destructor - ensures the client is closed when the instance is garbage collected."""
+        try:
+            self.close()
+        except Exception:
+            pass  # Ignore errors during cleanup
+
     def _stream_request(self, method: str, endpoint: str, **kwargs):
         """Make streaming HTTP request with error handling.
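With ownership tracking, close(), and the context-manager hooks above, typical usage looks like the following sketch (the ApiConfig field names are taken from this hunk; the construction values and the ApiConfig import are assumptions):

    from datamint.api.base_api import BaseApi

    config = ApiConfig(server_url="https://api.datamint.io", api_key="...", timeout=30)

    # Owned client: created internally, closed automatically on exit.
    with BaseApi(config) as api:
        ...  # all requests reuse the pooled connections

    # Injected client: shared across instances, so close() is a no-op here.
    shared = BaseApi._create_client(config)
    api2 = BaseApi(config, client=shared)
    api2.close()    # does nothing: _owns_client is False
    shared.close()  # whoever owns the shared client closes it explicitly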
datamint/api/client.py CHANGED
@@ -68,6 +68,8 @@ class Api:
                 f" Please check your api_key and/or other configurations. {e}")
 
     def _get_endpoint(self, name: str):
+        if self._client is None:
+            self._client = BaseApi._create_client(self.config)
         if name not in self._endpoints:
             api_class = self._API_MAP[name]
             endpoint = api_class(self.config, self._client)
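This guard makes client creation lazy: if the Api instance has no client yet, the first endpoint access builds one through the shared factory instead of failing. A minimal sketch of the pattern in isolation (the class and its names are hypothetical):

    class LazyClientHolder:
        # Hypothetical illustration of the lazy-creation pattern used by Api above.
        def __init__(self, config):
            self.config = config
            self._client = None  # nothing is created until first use

        def get_client(self):
            if self._client is None:
                self._client = BaseApi._create_client(self.config)
            return self._client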
@@ -30,7 +30,6 @@ ResourceFields: TypeAlias = Literal['modality', 'created_by', 'published_by', 'p
 
 _PAGE_LIMIT = 5000
 
-
 @deprecated(reason="Please use `from datamint import Api` instead.", version="2.0.0")
 class BaseAPIHandler:
     """
@@ -178,6 +178,8 @@ class CreateAnnotationDto:
         if model_id is not None:
             if is_model == False:
                 raise ValueError("model_id==False while self.model_id is provided.")
+            if not isinstance(model_id, str):
+                raise ValueError("model_id must be a string if provided.")
             is_model = True
         self.is_model = is_model
         self.geometry = geometry
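The added isinstance check turns a malformed model_id into an early, explicit error instead of a confusing failure later. An illustration of the intended behavior (constructor arguments other than model_id are elided):

    CreateAnnotationDto(..., model_id="model-123")  # ok; is_model is forced to True
    CreateAnnotationDto(..., model_id=12345)        # raises ValueError: model_id must be a string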
@@ -307,6 +307,10 @@ class DatamintBaseDataset:
         self.image_lsets, self.image_lcodes = self._get_labels_set(framed=False)
         worklist_id = self.get_info()['worklist_id']
         groups: dict[str, dict] = self.api.annotationsets.get_segmentation_group(worklist_id)['groups']
+        if not groups:
+            self.seglabel_list = []
+            self.seglabel2code = {}
+            return
         # order by 'index' key
         max_index = max([g['index'] for g in groups.values()])
         self.seglabel_list: list[str] = ['UNKNOWN'] * max_index  # 1-based
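The new guard matters because max() over an empty collection raises ValueError; with no segmentation groups, the label structures are simply left empty. For the non-empty path, a sketch of how the 1-based label list is assembled (the per-group 'name' key is an assumption; only 'index' appears in this hunk):

    groups = {
        "g1": {"index": 1, "name": "lung"},
        "g2": {"index": 3, "name": "heart"},
    }
    max_index = max(g["index"] for g in groups.values())
    seglabel_list = ["UNKNOWN"] * max_index  # 1-based; unassigned slots stay 'UNKNOWN'
    for g in groups.values():
        seglabel_list[g["index"] - 1] = g["name"]
    # seglabel_list == ['lung', 'UNKNOWN', 'heart']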
@@ -0,0 +1 @@
+from .datamintdatamodule import DatamintDataModule
@@ -0,0 +1,103 @@
+from torch.utils.data import DataLoader
+from datamint import Dataset
+import lightning as L
+from typing import Any
+from copy import copy
+import numpy as np
+
+
+class DatamintDataModule(L.LightningDataModule):
+    """
+    LightningDataModule for Datamint datasets with train/val split.
+    TODO: Add support for test and predict dataloaders.
+    """
+
+    def __init__(
+        self,
+        project_name: str = "./",
+        batch_size: int = 32,
+        image_transform=None,
+        mask_transform=None,
+        alb_transform=None,
+        alb_train_transform=None,
+        alb_val_transform=None,
+        train_split: float = 0.9,
+        val_split: float = 0.1,
+        seed: int = 42,
+        num_workers: int = 4,
+        **dataset_kwargs: Any,
+    ):
+        super().__init__()
+        self.project_name = project_name
+        self.batch_size = batch_size
+        self.image_transform = image_transform
+        self.mask_transform = mask_transform
+
+        if alb_transform is not None and (alb_train_transform is not None or alb_val_transform is not None):
+            raise ValueError("You cannot specify both `alb_transform` and `alb_train_transform`/`alb_val_transform`.")
+
+        # Handle backward compatibility for alb_transform
+        if alb_transform is not None:
+            self.alb_train_transform = alb_transform
+            self.alb_val_transform = alb_transform
+        else:
+            self.alb_train_transform = alb_train_transform
+            self.alb_val_transform = alb_val_transform
+
+        self.train_split = train_split
+        self.val_split = val_split
+        self.seed = seed
+        self.dataset_kwargs = dataset_kwargs
+        self.num_workers = num_workers
+
+        self.dataset = None
+
+    def prepare_data(self) -> None:
+        """Download or update data if needed."""
+        Dataset(
+            project_name=self.project_name,
+            auto_update=True,
+        )
+
+    def setup(self, stage: str = None) -> None:
+        """Set up datasets and perform the train/val split."""
+        if self.dataset is None:
+            # Create base dataset for getting indices
+            self.dataset = Dataset(
+                return_as_semantic_segmentation=True,
+                semantic_seg_merge_strategy="union",
+                return_frame_by_frame=True,
+                include_unannotated=False,
+                project_name=self.project_name,
+                image_transform=self.image_transform,
+                mask_transform=self.mask_transform,
+                alb_transform=None,  # No transform for the base dataset
+                auto_update=False,
+                **self.dataset_kwargs,
+            )
+
+            indices = list(copy(self.dataset.subset_indices))
+            rs = np.random.RandomState(self.seed)
+            rs.shuffle(indices)
+            train_end = int(self.train_split * len(indices))
+            train_idx = indices[:train_end]
+            val_idx = indices[train_end:]
+
+            self.train_dataset = copy(self.dataset).subset(train_idx)
+            self.train_dataset.alb_transform = self.alb_train_transform
+            self.val_dataset = copy(self.dataset).subset(val_idx)
+            self.val_dataset.alb_transform = self.alb_val_transform
+
+    def train_dataloader(self) -> DataLoader:
+        return self.train_dataset.get_dataloader(batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True)
+
+    def val_dataloader(self) -> DataLoader:
+        return self.val_dataset.get_dataloader(batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
+
+    def test_dataloader(self):
+        # Use the same dataloader as validation for testing, because we have so few samples
+        return self.val_dataset.get_dataloader(batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
+
+    def predict_dataloader(self):
+        # Use the same dataloader as validation for prediction, because we have so few samples
+        return self.val_dataset.get_dataloader(batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
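A usage sketch for the new data module with a standard Lightning training loop (the import path and the model class are hypothetical; the DatamintDataModule parameters are as defined above):

    import lightning as L
    from datamint.lightning import DatamintDataModule  # import path assumed

    dm = DatamintDataModule(
        project_name="my-project",  # hypothetical project
        batch_size=16,
        train_split=0.9,
        seed=42,
        num_workers=4,
    )
    trainer = L.Trainer(max_epochs=10)
    trainer.fit(MySegmentationModel(), datamodule=dm)  # model class is hypothetical

Because the split is driven by a fixed seed through np.random.RandomState, train/val membership is reproducible across runs.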
@@ -0,0 +1,46 @@
+# Monkey patch mlflow.tracking._tracking_service.utils.get_tracking_uri
+from .tracking.fluent import set_project
+import mlflow.tracking._tracking_service.utils as mlflow_utils
+from functools import wraps
+import logging
+from .env_utils import setup_mlflow_environment, ensure_mlflow_configured
+
+_LOGGER = logging.getLogger(__name__)
+
+# Store reference to original function
+_original_get_tracking_uri = mlflow_utils.get_tracking_uri
+_SETUP_CALLED_SUCCESSFULLY = False
+
+
+@wraps(_original_get_tracking_uri)
+def _patched_get_tracking_uri(*args, **kwargs):
+    """Patched version of get_tracking_uri that ensures the MLflow environment is set up first.
+
+    This wrapper ensures that setup_mlflow_environment is called before any tracking
+    URI operations, guaranteeing proper MLflow configuration.
+
+    Args:
+        *args: Arguments passed to the original get_tracking_uri function.
+        **kwargs: Keyword arguments passed to the original get_tracking_uri function.
+
+    Returns:
+        The result of the original get_tracking_uri function.
+    """
+    global _SETUP_CALLED_SUCCESSFULLY
+    if _SETUP_CALLED_SUCCESSFULLY:
+        return _original_get_tracking_uri(*args, **kwargs)
+    try:
+        _SETUP_CALLED_SUCCESSFULLY = setup_mlflow_environment(set_mlflow=True)
+    except Exception as e:
+        _SETUP_CALLED_SUCCESSFULLY = False
+        _LOGGER.error("Failed to set up MLflow environment: %s", e)
+    ret = _original_get_tracking_uri(*args, **kwargs)
+    return ret
+
+
+setup_mlflow_environment(set_mlflow=False)
+# Replace the original function with our patched version
+mlflow_utils.get_tracking_uri = _patched_get_tracking_uri
+
+
+__all__ = ['set_project', 'setup_mlflow_environment', 'ensure_mlflow_configured']
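The net effect is that plain MLflow user code picks up Datamint credentials without any explicit configuration. A sketch of what that looks like, assuming this module is importable as shown (the package path is an assumption):

    import datamint.mlflow  # importing applies the monkey patch (path assumed)
    import mlflow

    # No manual mlflow.set_tracking_uri() needed: the first operation that
    # resolves the tracking URI runs setup_mlflow_environment(), which derives
    # the URI and token from the Datamint configuration.
    with mlflow.start_run():
        mlflow.log_metric("loss", 0.42)

Note the failure path: if setup raises, the error is logged and the original get_tracking_uri still runs, so stock MLflow behavior is preserved.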
@@ -0,0 +1 @@
+from .datamint_artifacts_repo import DatamintArtifactsRepository
@@ -0,0 +1,8 @@
+from mlflow.store.artifact.mlflow_artifacts_repo import MlflowArtifactsRepository
+
+
+class DatamintArtifactsRepository(MlflowArtifactsRepository):
+    @classmethod
+    def resolve_uri(cls, artifact_uri, tracking_uri):
+        tracking_uri = tracking_uri.split('datamint://', maxsplit=1)[-1]
+        return super().resolve_uri(artifact_uri, tracking_uri)
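resolve_uri only strips a leading datamint:// scheme before delegating, so a Datamint tracking URI reduces to a plain MLflow one:

    # Illustration of the prefix stripping (values hypothetical):
    tracking_uri = "datamint://http://localhost:5000"
    stripped = tracking_uri.split('datamint://', maxsplit=1)[-1]
    assert stripped == "http://localhost:5000"
    # With the stripped URI, the class behaves exactly like MlflowArtifactsRepository.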
@@ -0,0 +1,109 @@
+"""
+Utility functions for automatically configuring MLflow environment variables
+based on the Datamint configuration.
+"""
+
+import os
+import logging
+from typing import Optional
+from urllib.parse import urlparse
+from datamint import configs
+
+_LOGGER = logging.getLogger(__name__)
+
+
+def get_datamint_api_url() -> Optional[str]:
+    """Get the Datamint API URL from configuration or environment variables."""
+    # First check the environment variable
+    api_url = os.getenv('DATAMINT_API_URL')
+    if api_url:
+        return api_url
+
+    # Then check the configuration
+    api_url = configs.get_value(configs.APIURL_KEY)
+    if api_url:
+        return api_url
+
+    return None
+
+
+def get_datamint_api_key() -> Optional[str]:
+    """Get the Datamint API key from configuration or environment variables."""
+    # First check the environment variable
+    api_key = os.getenv('DATAMINT_API_KEY')
+    if api_key:
+        return api_key
+
+    # Then check the configuration
+    api_key = configs.get_value(configs.APIKEY_KEY)
+    if api_key:
+        return api_key
+
+    return None
+
+
+def _get_mlflowdatamint_uri() -> Optional[str]:
+    api_url = get_datamint_api_url()
+    if not api_url:
+        return None
+    _LOGGER.debug(f"Retrieved Datamint API URL: {api_url}")
+
+    # Remove trailing slash if present
+    api_url = api_url.rstrip('/')
+    # api_url samples:
+    #   https://api.datamint.io
+    #   http://localhost:3001
+
+    parsed_url = urlparse(api_url)
+    base_url = f"{parsed_url.scheme}://{parsed_url.hostname}"
+    _LOGGER.debug(f"Derived base URL for MLflow Datamint: {base_url}")
+    # FIXME: It should work with https, or the datamint-api server should forward https requests.
+    base_url = base_url.replace('https://', 'http://')
+    if len(base_url.replace('http:', '')) == 0:
+        return None
+
+    mlflow_uri = f"{base_url}:5000"
+    return mlflow_uri
+
+
+def setup_mlflow_environment(overwrite: bool = False,
+                             set_mlflow: bool = True) -> bool:
+    """
+    Automatically set up MLflow environment variables based on the Datamint configuration.
+
+    Returns:
+        bool: True if the MLflow environment was successfully configured, False otherwise.
+    """
+    _LOGGER.debug("Setting up MLflow environment variables from Datamint configuration")
+    api_key = get_datamint_api_key()
+    mlflow_uri = _get_mlflowdatamint_uri()
+    if not mlflow_uri or not api_key:
+        _LOGGER.warning("Datamint configuration incomplete, cannot auto-configure MLflow")
+        return False
+
+    if overwrite or not os.getenv('MLFLOW_TRACKING_TOKEN'):
+        os.environ['MLFLOW_TRACKING_TOKEN'] = api_key
+    if overwrite or not os.getenv('MLFLOW_TRACKING_URI'):
+        os.environ['MLFLOW_TRACKING_URI'] = mlflow_uri
+
+    if set_mlflow:
+        import mlflow
+        mlflow.set_tracking_uri(mlflow_uri)
+
+    return True
+
+
+def ensure_mlflow_configured() -> None:
+    """
+    Ensure the MLflow environment is properly configured.
+    Raises an exception if configuration is incomplete.
+    """
+    if not setup_mlflow_environment():
+        if not os.getenv('MLFLOW_TRACKING_URI') or not os.getenv('MLFLOW_TRACKING_TOKEN'):
+            raise ValueError(
+                "MLflow environment not configured. Please either:\n"
+                "1. Run 'datamint-config' to set up Datamint configuration, or\n"
+                "2. Set DATAMINT_API_URL and DATAMINT_API_KEY environment variables, or\n"
+                "3. Manually set MLFLOW_TRACKING_URI and MLFLOW_TRACKING_TOKEN environment variables"
+            )
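Putting the helpers together: the MLflow URI is derived from the API URL by keeping only the scheme and host (any explicit port such as :3001 is dropped), forcing http, and appending port 5000. A worked sketch with hypothetical values:

    import os

    os.environ['DATAMINT_API_URL'] = 'https://api.datamint.io'  # hypothetical
    os.environ['DATAMINT_API_KEY'] = 'my-secret-key'            # hypothetical

    # Derivation: https://api.datamint.io -> http://api.datamint.io -> http://api.datamint.io:5000
    if setup_mlflow_environment(set_mlflow=False):
        print(os.environ['MLFLOW_TRACKING_URI'])    # http://api.datamint.io:5000
        print(os.environ['MLFLOW_TRACKING_TOKEN'])  # my-secret-key
    else:
        ensure_mlflow_configured()  # raises ValueError with setup instructions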
@@ -0,0 +1,5 @@
+from enum import Enum
+
+class EnvVars(Enum):
+    DATAMINT_PROJECT_ID = "DATAMINT_PROJECT_ID"
+    DATAMINT_PROJECT_NAME = "DATAMINT_PROJECT_NAME"
@@ -0,0 +1 @@
+from .modelcheckpoint import MLFlowModelCheckpoint