datamint 2.3.5__py3-none-any.whl → 2.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


datamint/api/base_api.py CHANGED
@@ -61,22 +61,56 @@ class BaseApi:
        client: Optional HTTP client instance. If None, a new one will be created.
        """
        self.config = config
-       self.client = client or self._create_client()
+       self._owns_client = client is None  # Track if we created the client
+       self.client = client or BaseApi._create_client(config)
        self.semaphore = asyncio.Semaphore(20)
        self._api_instance: 'Api | None' = None  # Injected by Api class

-   def _create_client(self) -> httpx.Client:
-       """Create and configure HTTP client with authentication and timeouts."""
-       headers = None
-       if self.config.api_key:
-           headers = {"apikey": self.config.api_key}
+   @staticmethod
+   def _create_client(config: ApiConfig) -> httpx.Client:
+       """Create and configure HTTP client with authentication and timeouts.
+
+       The client is designed to be long-lived and reused across multiple requests.
+       It maintains connection pooling for improved performance.
+       Default limits: max_keepalive_connections=20, max_connections=100
+       """
+       headers = {"apikey": config.api_key} if config.api_key else None

        return httpx.Client(
-           base_url=self.config.server_url,
+           base_url=config.server_url,
            headers=headers,
-           timeout=self.config.timeout
+           timeout=config.timeout,
+           limits=httpx.Limits(
+               max_keepalive_connections=5,  # Decreased from the default of 20
+               max_connections=20,           # Decreased from the default of 100
+               keepalive_expiry=8
+           )
        )

+   def close(self) -> None:
+       """Close the HTTP client and release resources.
+
+       Should be called when the API instance is no longer needed.
+       Only closes the client if it was created by this instance.
+       """
+       if self._owns_client and self.client is not None:
+           self.client.close()
+
+   def __enter__(self):
+       """Context manager entry."""
+       return self
+
+   def __exit__(self, exc_type, exc_val, exc_tb):
+       """Context manager exit - ensures client is closed."""
+       self.close()
+
+   def __del__(self):
+       """Destructor - ensures client is closed when instance is garbage collected."""
+       try:
+           self.close()
+       except Exception:
+           pass  # Ignore errors during cleanup
+
    def _stream_request(self, method: str, endpoint: str, **kwargs):
        """Make streaming HTTP request with error handling.
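
A minimal usage sketch of the new client lifecycle, assuming `BaseApi` (or a subclass) is constructed as `BaseApi(config, client=None)` and that `ApiConfig` carries the `server_url`, `api_key`, and `timeout` fields referenced above; the constructor call and config values are illustrative, not taken from this diff:

from datamint.api.base_api import BaseApi  # module path as shown in this diff

# ApiConfig construction is illustrative; import it from wherever the package defines it.
config = ApiConfig(server_url="https://api.datamint.io", api_key="...", timeout=60)

# The context manager closes the internally created httpx.Client on exit.
with BaseApi(config) as api:
    ...  # issue requests through api.client

# An externally supplied client is not closed by close(), since _owns_client is False.
shared_client = BaseApi._create_client(config)
api = BaseApi(config, client=shared_client)
api.close()            # no-op: this instance does not own the client
shared_client.close()  # the caller remains responsible for it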
datamint/api/client.py CHANGED
@@ -68,6 +68,8 @@ class Api:
                f" Please check your api_key and/or other configurations. {e}")

    def _get_endpoint(self, name: str):
+       if self._client is None:
+           self._client = BaseApi._create_client(self.config)
        if name not in self._endpoints:
            api_class = self._API_MAP[name]
            endpoint = api_class(self.config, self._client)
@@ -25,6 +25,7 @@ import nest_asyncio # For running asyncio in jupyter notebooks
from PIL import Image
import io
from datamint.types import ImagingData
+from collections import defaultdict


_LOGGER = logging.getLogger(__name__)
@@ -279,7 +280,6 @@ class ResourcesApi(CreatableEntityApi[Resource], DeletableEntityApi[Resource]):
        _LOGGER.warning(msg)
        _USER_LOGGER.warning(msg)

-
        mimetype = standardize_mimetype(mimetype)

        if is_a_dicom_file == True or is_dicom(file_path):
@@ -440,12 +440,12 @@ class ResourcesApi(CreatableEntityApi[Resource], DeletableEntityApi[Resource]):

        try:
            tasks = [__upload_single_resource(f, segfiles, metadata_file)
-                    for f, segfiles, metadata_file in zip(files_path, segmentation_files, metadata_files)]
+                     for f, segfiles, metadata_file in zip(files_path, segmentation_files, metadata_files)]
        except ValueError:
            msg = f"Error preparing upload tasks. Try `assemble_dicom=False`."
            _LOGGER.error(msg)
            _USER_LOGGER.error(msg)
-           raise
+           raise
        return await asyncio.gather(*tasks, return_exceptions=on_error == 'skip')

    def upload_resources(self,
@@ -996,22 +996,28 @@ class ResourcesApi(CreatableEntityApi[Resource], DeletableEntityApi[Resource]):
            raise

    def set_tags(self,
-                resource: str | Resource,
+                resource: str | Resource | Sequence[str | Resource],
                 tags: Sequence[str],
                 ):
        """
        Set tags for a resource, IMPORTANT: This replaces all existing tags.
        Args:
-           resource: The resource unique id or Resource object.
+           resource: The resource object or a list of resources.
            tags: The tags to set.
        """
        data = {'tags': tags}
-       resource_id = self._entid(resource)
-
-       response = self._make_entity_request('PUT',
-                                            resource_id,
-                                            add_path='tags',
-                                            json=data)
+       if isinstance(resource, Sequence):
+           resource_ids = [self._entid(res) for res in resource]
+           response = self._make_request('PUT',
+                                         f'{self.endpoint_base}/tags',
+                                         json={'resource_ids': resource_ids,
+                                               'tags': tags})
+       else:
+           resource_id = self._entid(resource)
+           response = self._make_entity_request('PUT',
+                                                resource_id,
+                                                add_path='tags',
+                                                json=data)
        return response

    # def get_projects(self, resource: Resource) -> Sequence[Project]:
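
A hedged sketch of the extended `set_tags`, assuming the `ResourcesApi` instance is reachable as `api.resources` on a configured `Api` object (that attribute name is an assumption, not shown in this diff); note that `set_tags` replaces all existing tags on the targeted resources:

# Batch form added in 2.4.1: a single PUT to <endpoint_base>/tags with resource_ids and tags.
api.resources.set_tags(["res-id-1", "res-id-2"], ["reviewed"])  # ids and tags are illustrative

# A single Resource object still goes through the per-entity tags path, as before.
resource = api.resources.get_by_id("res-id-1")
api.resources.set_tags(resource, ["reviewed", "chest-xray"])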
@@ -1029,7 +1035,7 @@ class ResourcesApi(CreatableEntityApi[Resource], DeletableEntityApi[Resource]):
        # return [proj for proj in self.projects_api.get_all() if proj.id in proj_ids]

    def add_tags(self,
-                resource: str | Resource,
+                resource: str | Resource | Sequence[str | Resource],
                 tags: Sequence[str],
                 ):
        """
@@ -1040,8 +1046,26 @@ class ResourcesApi(CreatableEntityApi[Resource], DeletableEntityApi[Resource]):
        """
        if isinstance(resource, str):
            resource = self.get_by_id(resource)
+       elif isinstance(resource, Sequence):
+           # Transform every str to Resource first.
+           resources = [self.get_by_id(res) if isinstance(res, str) else res for res in resource]
+
+           # group resources having the exact same tags to minimize requests
+           tag_map: dict[tuple, list[Resource]] = defaultdict(list)
+           for res in resources:
+               old_tags = res.tags if res.tags is not None else []
+               # key = tuple(sorted(old_tags))
+               key = tuple(old_tags)  # keep order, assuming order matters for tags
+               tag_map[key].append(res)
+
+           # finally, set tags for each group
+           for old_tags_tuple, res_group in tag_map.items():
+               old_tags = list(old_tags_tuple)
+               self.set_tags(res_group, old_tags + list(tags))
+           return
+
        old_tags = resource.tags if resource.tags is not None else []
-       return self.set_tags(resource, old_tags + list(tags))
+       self.set_tags(resource, old_tags + list(tags))

    def bulk_delete(self, entities: Sequence[str | Resource]) -> None:
        """Delete multiple entities. Faster than deleting them one by one.
@@ -30,7 +30,6 @@ ResourceFields: TypeAlias = Literal['modality', 'created_by', 'published_by', 'p

_PAGE_LIMIT = 5000

-
@deprecated(reason="Please use `from datamint import Api` instead.", version="2.0.0")
class BaseAPIHandler:
    """
@@ -178,6 +178,8 @@ class CreateAnnotationDto:
        if model_id is not None:
            if is_model == False:
                raise ValueError("model_id==False while self.model_id is provided.")
+           if not isinstance(model_id, str):
+               raise ValueError("model_id must be a string if provided.")
            is_model = True
        self.is_model = is_model
        self.geometry = geometry
@@ -307,6 +307,10 @@ class DatamintBaseDataset:
        self.image_lsets, self.image_lcodes = self._get_labels_set(framed=False)
        worklist_id = self.get_info()['worklist_id']
        groups: dict[str, dict] = self.api.annotationsets.get_segmentation_group(worklist_id)['groups']
+       if not groups:
+           self.seglabel_list = []
+           self.seglabel2code = {}
+           return
        # order by 'index' key
        max_index = max([g['index'] for g in groups.values()])
        self.seglabel_list : list[str] = ['UNKNOWN'] * max_index  # 1-based
@@ -0,0 +1 @@
+from .datamintdatamodule import DatamintDataModule
@@ -0,0 +1,103 @@
+from torch.utils.data import DataLoader
+from datamint import Dataset
+import lightning as L
+from typing import Any
+from copy import copy
+import numpy as np
+
+
+class DatamintDataModule(L.LightningDataModule):
+    """
+    LightningDataModule for Datamint datasets with train/val split.
+    TODO: Add support for test and predict dataloaders.
+    """
+
+    def __init__(
+        self,
+        project_name: str = "./",
+        batch_size: int = 32,
+        image_transform=None,
+        mask_transform=None,
+        alb_transform=None,
+        alb_train_transform=None,
+        alb_val_transform=None,
+        train_split: float = 0.9,
+        val_split: float = 0.1,
+        seed: int = 42,
+        num_workers: int = 4,
+        **dataset_kwargs: Any,
+    ):
+        super().__init__()
+        self.project_name = project_name
+        self.batch_size = batch_size
+        self.image_transform = image_transform
+        self.mask_transform = mask_transform
+
+        if alb_transform is not None and (alb_train_transform is not None or alb_val_transform is not None):
+            raise ValueError("You cannot specify both `alb_transform` and `alb_train_transform`/`alb_val_transform`.")
+
+        # Handle backward compatibility for alb_transform
+        if alb_transform is not None:
+            self.alb_train_transform = alb_transform
+            self.alb_val_transform = alb_transform
+        else:
+            self.alb_train_transform = alb_train_transform
+            self.alb_val_transform = alb_val_transform
+
+        self.train_split = train_split
+        self.val_split = val_split
+        self.seed = seed
+        self.dataset_kwargs = dataset_kwargs
+        self.num_workers = num_workers
+
+        self.dataset = None
+
+    def prepare_data(self) -> None:
+        """Download or update data if needed."""
+        Dataset(
+            project_name=self.project_name,
+            auto_update=True,
+        )
+
+    def setup(self, stage: str = None) -> None:
+        """Set up datasets and perform train/val split."""
+        if self.dataset is None:
+            # Create base dataset for getting indices
+            self.dataset = Dataset(
+                return_as_semantic_segmentation=True,
+                semantic_seg_merge_strategy="union",
+                return_frame_by_frame=True,
+                include_unannotated=False,
+                project_name=self.project_name,
+                image_transform=self.image_transform,
+                mask_transform=self.mask_transform,
+                alb_transform=None,  # No transform for base dataset
+                auto_update=False,
+                **self.dataset_kwargs,
+            )
+
+        indices = list(copy(self.dataset.subset_indices))
+        rs = np.random.RandomState(self.seed)
+        rs.shuffle(indices)
+        train_end = int(self.train_split * len(indices))
+        train_idx = indices[:train_end]
+        val_idx = indices[train_end:]
+
+        self.train_dataset = copy(self.dataset).subset(train_idx)
+        self.train_dataset.alb_transform = self.alb_train_transform
+        self.val_dataset = copy(self.dataset).subset(val_idx)
+        self.val_dataset.alb_transform = self.alb_val_transform
+
+    def train_dataloader(self) -> DataLoader:
+        return self.train_dataset.get_dataloader(batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True)
+
+    def val_dataloader(self) -> DataLoader:
+        return self.val_dataset.get_dataloader(batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
+
+    def test_dataloader(self):
+        # Use the same dataloader as validation for testing, because we have so few samples
+        return self.val_dataset.get_dataloader(batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
+
+    def predict_dataloader(self):
+        # Use the same dataloader as validation for testing, because we have so few samples
+        return self.val_dataset.get_dataloader(batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
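
A minimal training sketch for the new `DatamintDataModule` with Lightning; the import path and `MySegmentationModel` are placeholders, since this diff does not show the package location of the module or any model code:

import lightning as L
from datamint.lightning import DatamintDataModule  # assumed import path, exposed via the __init__ above

datamodule = DatamintDataModule(
    project_name="my-project",  # illustrative project name
    batch_size=8,
    train_split=0.9,
    seed=42,
)
model = MySegmentationModel()   # placeholder for a user-defined LightningModule
trainer = L.Trainer(max_epochs=10)
trainer.fit(model, datamodule=datamodule)  # runs prepare_data(), setup(), then the dataloaders above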
@@ -0,0 +1,46 @@
+# Monkey patch mlflow.tracking._tracking_service.utils.get_tracking_uri
+from .tracking.fluent import set_project
+import mlflow.tracking._tracking_service.utils as mlflow_utils
+from functools import wraps
+import logging
+from .env_utils import setup_mlflow_environment, ensure_mlflow_configured
+
+_LOGGER = logging.getLogger(__name__)
+
+# Store reference to original function
+_original_get_tracking_uri = mlflow_utils.get_tracking_uri
+_SETUP_CALLED_SUCCESSFULLY = False
+
+
+@wraps(_original_get_tracking_uri)
+def _patched_get_tracking_uri(*args, **kwargs):
+    """Patched version of get_tracking_uri that ensures MLflow environment is set up first.
+
+    This wrapper ensures that setup_mlflow_environment is called before any tracking
+    URI operations, guaranteeing proper MLflow configuration.
+
+    Args:
+        *args: Arguments passed to the original get_tracking_uri function.
+        **kwargs: Keyword arguments passed to the original get_tracking_uri function.
+
+    Returns:
+        The result of the original get_tracking_uri function.
+    """
+    global _SETUP_CALLED_SUCCESSFULLY
+    if _SETUP_CALLED_SUCCESSFULLY:
+        return _original_get_tracking_uri(*args, **kwargs)
+    try:
+        _SETUP_CALLED_SUCCESSFULLY = setup_mlflow_environment(set_mlflow=True)
+    except Exception as e:
+        _SETUP_CALLED_SUCCESSFULLY = False
+        _LOGGER.error("Failed to set up MLflow environment: %s", e)
+    ret = _original_get_tracking_uri(*args, **kwargs)
+    return ret
+
+
+setup_mlflow_environment(set_mlflow=False)
+# Replace the original function with our patched version
+mlflow_utils.get_tracking_uri = _patched_get_tracking_uri
+
+
+__all__ = ['set_project', 'setup_mlflow_environment', 'ensure_mlflow_configured']
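
Since importing this module patches `mlflow_utils.get_tracking_uri` at import time, a plain MLflow workflow picks up the Datamint-derived tracking URI and token without extra calls. A hedged sketch; the `datamint.mlflow` import path is an assumption, as the diff does not show where this `__init__` lives:

import datamint.mlflow  # assumed import path; importing applies the monkey patch
import mlflow

# The first tracking-URI lookup triggers setup_mlflow_environment(set_mlflow=True).
with mlflow.start_run():
    mlflow.log_metric("val_dice", 0.87)  # illustrative metric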
@@ -0,0 +1 @@
+from .datamint_artifacts_repo import DatamintArtifactsRepository
@@ -0,0 +1,8 @@
+from mlflow.store.artifact.mlflow_artifacts_repo import MlflowArtifactsRepository
+
+
+class DatamintArtifactsRepository(MlflowArtifactsRepository):
+    @classmethod
+    def resolve_uri(cls, artifact_uri, tracking_uri):
+        tracking_uri = tracking_uri.split('datamint://', maxsplit=1)[-1]
+        return super().resolve_uri(artifact_uri, tracking_uri)
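
The override only strips a leading `datamint://` scheme from the tracking URI before delegating to MLflow's standard resolver; a small illustration with made-up URIs:

# Equivalent to MlflowArtifactsRepository.resolve_uri(..., tracking_uri="http://localhost:5000")
DatamintArtifactsRepository.resolve_uri(
    artifact_uri="mlflow-artifacts:/1/run-id/artifacts",  # illustrative artifact URI
    tracking_uri="datamint://http://localhost:5000",      # 'datamint://' prefix is removed
)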
@@ -0,0 +1,109 @@
+"""
+Utility functions for automatically configuring MLflow environment variables
+based on Datamint configuration.
+"""
+
+import os
+import logging
+from typing import Optional
+from urllib.parse import urlparse
+from datamint import configs
+
+
+_LOGGER = logging.getLogger(__name__)
+
+
+def get_datamint_api_url() -> Optional[str]:
+    """Get the Datamint API URL from configuration or environment variables."""
+    # First check environment variable
+    api_url = os.getenv('DATAMINT_API_URL')
+    if api_url:
+        return api_url
+
+    # Then check configuration
+    api_url = configs.get_value(configs.APIURL_KEY)
+    if api_url:
+        return api_url
+
+    return None
+
+
+def get_datamint_api_key() -> Optional[str]:
+    """Get the Datamint API key from configuration or environment variables."""
+    # First check environment variable
+    api_key = os.getenv('DATAMINT_API_KEY')
+    if api_key:
+        return api_key
+
+    # Then check configuration
+    api_key = configs.get_value(configs.APIKEY_KEY)
+    if api_key:
+        return api_key
+
+    return None
+
+
+def _get_mlflowdatamint_uri() -> Optional[str]:
+    api_url = get_datamint_api_url()
+    if not api_url:
+        return None
+    _LOGGER.debug(f"Retrieved Datamint API URL: {api_url}")
+
+    # Remove trailing slash if present
+    api_url = api_url.rstrip('/')
+    # api_url samples:
+    # https://api.datamint.io
+    # http://localhost:3001
+
+    parsed_url = urlparse(api_url)
+    base_url = f"{parsed_url.scheme}://{parsed_url.hostname}"
+    _LOGGER.debug(f"Derived base URL for MLflow Datamint: {base_url}")
+    # FIXME: It should work with https or datamint-api server should forward https requests.
+    base_url = base_url.replace('https://', 'http://')
+    if len(base_url.replace('http:', '')) == 0:
+        return None
+
+    mlflow_uri = f"{base_url}:5000"
+    return mlflow_uri
+
+
+def setup_mlflow_environment(overwrite: bool = False,
+                             set_mlflow: bool = True) -> bool:
+    """
+    Automatically set up MLflow environment variables based on Datamint configuration.
+
+    Returns:
+        bool: True if MLflow environment was successfully configured, False otherwise.
+    """
+    _LOGGER.debug("Setting up MLflow environment variables from Datamint configuration")
+    api_key = get_datamint_api_key()
+    mlflow_uri = _get_mlflowdatamint_uri()
+    if not mlflow_uri or not api_key:
+        _LOGGER.warning("Datamint configuration incomplete, cannot auto-configure MLflow")
+        return False
+
+    if overwrite or not os.getenv('MLFLOW_TRACKING_TOKEN'):
+        os.environ['MLFLOW_TRACKING_TOKEN'] = api_key
+    if overwrite or not os.getenv('MLFLOW_TRACKING_URI'):
+        os.environ['MLFLOW_TRACKING_URI'] = mlflow_uri
+
+    if set_mlflow:
+        import mlflow
+        mlflow.set_tracking_uri(mlflow_uri)
+
+    return True
+
+
+def ensure_mlflow_configured() -> None:
+    """
+    Ensure MLflow environment is properly configured.
+    Raises an exception if configuration is incomplete.
+    """
+    if not setup_mlflow_environment():
+        if not os.getenv('MLFLOW_TRACKING_URI') or not os.getenv('MLFLOW_TRACKING_TOKEN'):
+            raise ValueError(
+                "MLflow environment not configured. Please either:\n"
+                "1. Run 'datamint-config' to set up Datamint configuration, or\n"
+                "2. Set DATAMINT_API_URL and DATAMINT_API_KEY environment variables, or\n"
+                "3. Manually set MLFLOW_TRACKING_URI and MLFLOW_TRACKING_TOKEN environment variables"
+            )
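
Given `_get_mlflowdatamint_uri` above, a worked example of the URI derivation with an illustrative API URL, repeating the same string handling outside the module:

from urllib.parse import urlparse

# Illustrative DATAMINT_API_URL value.
api_url = "https://api.datamint.io/".rstrip('/')
parsed = urlparse(api_url)
base_url = f"{parsed.scheme}://{parsed.hostname}".replace('https://', 'http://')
mlflow_uri = f"{base_url}:5000"
print(mlflow_uri)  # -> http://api.datamint.io:5000

# setup_mlflow_environment() then exports, unless already set:
#   MLFLOW_TRACKING_URI   = mlflow_uri
#   MLFLOW_TRACKING_TOKEN = <the Datamint API key>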
@@ -0,0 +1,5 @@
+from enum import Enum
+
+class EnvVars(Enum):
+    DATAMINT_PROJECT_ID = "DATAMINT_PROJECT_ID"
+    DATAMINT_PROJECT_NAME = "DATAMINT_PROJECT_NAME"
@@ -0,0 +1 @@
+from .modelcheckpoint import MLFlowModelCheckpoint