datamint 1.2.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of datamint might be problematic. Click here for more details.
- datamint/__init__.py +11 -0
- datamint-1.2.4.dist-info/METADATA +118 -0
- datamint-1.2.4.dist-info/RECORD +30 -0
- datamint-1.2.4.dist-info/WHEEL +4 -0
- datamint-1.2.4.dist-info/entry_points.txt +4 -0
- datamintapi/__init__.py +25 -0
- datamintapi/apihandler/annotation_api_handler.py +748 -0
- datamintapi/apihandler/api_handler.py +15 -0
- datamintapi/apihandler/base_api_handler.py +300 -0
- datamintapi/apihandler/dto/annotation_dto.py +149 -0
- datamintapi/apihandler/exp_api_handler.py +204 -0
- datamintapi/apihandler/root_api_handler.py +1013 -0
- datamintapi/client_cmd_tools/__init__.py +0 -0
- datamintapi/client_cmd_tools/datamint_config.py +168 -0
- datamintapi/client_cmd_tools/datamint_upload.py +483 -0
- datamintapi/configs.py +58 -0
- datamintapi/dataset/__init__.py +1 -0
- datamintapi/dataset/base_dataset.py +881 -0
- datamintapi/dataset/dataset.py +492 -0
- datamintapi/examples/__init__.py +1 -0
- datamintapi/examples/example_projects.py +75 -0
- datamintapi/experiment/__init__.py +1 -0
- datamintapi/experiment/_patcher.py +570 -0
- datamintapi/experiment/experiment.py +1049 -0
- datamintapi/logging.yaml +27 -0
- datamintapi/utils/dicom_utils.py +640 -0
- datamintapi/utils/io_utils.py +149 -0
- datamintapi/utils/logging_utils.py +55 -0
- datamintapi/utils/torchmetrics.py +70 -0
- datamintapi/utils/visualization.py +129 -0
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
from .root_api_handler import RootAPIHandler
|
|
2
|
+
from .annotation_api_handler import AnnotationAPIHandler
|
|
3
|
+
from .exp_api_handler import ExperimentAPIHandler
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class APIHandler(RootAPIHandler, ExperimentAPIHandler, AnnotationAPIHandler):
    """Single entry point to the Datamint API.

    Aggregates the resource, experiment and annotation handlers through multiple
    inheritance; this class adds no behavior of its own.

    Import using this code:

    .. code-block:: python

        from datamintapi import APIHandler
        api = APIHandler()
    """
|
|
@@ -0,0 +1,300 @@
|
|
|
1
|
+
from typing import Optional, Literal, Generator, TypeAlias, Dict, Union, List
|
|
2
|
+
import pydicom.dataset
|
|
3
|
+
from requests import Session
|
|
4
|
+
from requests.exceptions import HTTPError
|
|
5
|
+
import logging
|
|
6
|
+
import asyncio
|
|
7
|
+
import aiohttp
|
|
8
|
+
import nest_asyncio # For running asyncio in jupyter notebooks
|
|
9
|
+
import pydicom
|
|
10
|
+
import json
|
|
11
|
+
from PIL import Image
|
|
12
|
+
from io import BytesIO
|
|
13
|
+
import cv2
|
|
14
|
+
import nibabel as nib
|
|
15
|
+
from nibabel.filebasedimages import FileBasedImage as nib_FileBasedImage
|
|
16
|
+
from datamintapi import configs
|
|
17
|
+
from functools import wraps
|
|
18
|
+
|
|
19
|
+
_LOGGER = logging.getLogger(__name__)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
ResourceStatus: TypeAlias = Literal['new', 'inbox', 'published', 'archived']
|
|
23
|
+
"""TypeAlias: The available resource status. Possible values: 'new', 'inbox', 'published', 'archived'.
|
|
24
|
+
"""
|
|
25
|
+
ResourceFields: TypeAlias = Literal['modality', 'created_by', 'published_by', 'published_on', 'filename', 'created_at']
|
|
26
|
+
"""TypeAlias: The available fields to order resources. Possible values: 'modality', 'created_by', 'published_by', 'published_on', 'filename', 'created_at' (default).
|
|
27
|
+
"""
|
|
28
|
+
|
|
29
|
+
_PAGE_LIMIT = 5000
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class DatamintException(Exception):
    """
    Base class for exceptions in this module.
    """


class ResourceNotFoundError(DatamintException):
    """
    Exception raised when a resource is not found.

    For instance, when trying to get a resource by a non-existing id.
    """

    def __init__(self,
                 resource_type: str,
                 params: dict):
        """ Constructor.

        Args:
            resource_type (str): A resource type.
            params (dict): Dict of params identifying the sought resource.
        """
        # Forward the arguments to Exception so ``self.args`` mirrors them. Pickling
        # (e.g. sending the exception to another process) re-creates the instance as
        # ``cls(*self.args)``; with empty args that call would fail with TypeError.
        super().__init__(resource_type, params)
        self.resource_type = resource_type
        self.params = params

    def set_params(self, resource_type: str, params: dict):
        """Replace the resource type and identifying params after construction.

        Presumably used by callers that catch an instance raised with placeholder
        values and fill in the real ones. ``args`` is updated too so a pickled copy
        reflects the new values.

        Args:
            resource_type (str): A resource type.
            params (dict): Dict of params identifying the sought resource.
        """
        self.args = (resource_type, params)
        self.resource_type = resource_type
        self.params = params

    def __str__(self):
        return f"Resource '{self.resource_type}' not found for parameters: {self.params}"
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class BaseAPIHandler:
    """
    Class to handle the API requests to the Datamint API.

    Provides the plumbing shared by the concrete handlers:
    - root URL / API key configuration;
    - synchronous (``requests``) and asynchronous (``aiohttp``) request execution;
    - HTTP error translation and pagination;
    - decoding of downloaded bytes by mimetype.
    """
    DATAMINT_API_VENV_NAME = configs.ENV_VARS[configs.APIKEY_KEY]
    DEFAULT_ROOT_URL = 'https://api.datamint.io'

    def __init__(self,
                 root_url: Optional[str] = None,
                 api_key: Optional[str] = None,
                 check_connection: bool = True):
        """
        Args:
            root_url: Base URL of the API. Falls back to the configured value, then to
                ``DEFAULT_ROOT_URL``. Trailing slashes are removed.
            api_key: API key. Falls back to the configured value.
            check_connection: If True, run ``check_connection()`` during construction.

        Raises:
            DatamintException: if no API key is available or the connection check fails.
        """
        nest_asyncio.apply()  # For running asyncio in jupyter notebooks
        self.root_url = root_url if root_url is not None else configs.get_value(configs.APIURL_KEY)
        if self.root_url is None:
            self.root_url = BaseAPIHandler.DEFAULT_ROOT_URL
        # str.rstrip returns a new string, so the result must be assigned; otherwise
        # URLs built as f'{root_url}/{endpoint}' would contain a double slash.
        self.root_url = self.root_url.rstrip('/')

        self.api_key = api_key if api_key is not None else configs.get_value(configs.APIKEY_KEY)
        if self.api_key is None:
            msg = ("API key not provided! Use the environment variable "
                   f"{BaseAPIHandler.DATAMINT_API_VENV_NAME} or pass it as an argument.")
            raise DatamintException(msg)
        self.semaphore = asyncio.Semaphore(10)  # Limit to 10 parallel requests

        if check_connection:
            self.check_connection()

    def check_connection(self):
        """Verify that the API is reachable with the current configuration.

        Calls ``self.get_projects()``. That method is not defined on this class;
        presumably a subclass (e.g. the root handler) provides it. TODO confirm.

        Raises:
            DatamintException: wrapping (and chained to) whatever the test call raised.
        """
        try:
            self.get_projects()
        except Exception as e:
            raise DatamintException("Error connecting to the Datamint API." +
                                    f" Please check your api_key and/or other configurations. {e}") from e

    def _generate_curl_command(self, request_args: dict) -> str:
        """
        Generate a curl command for debugging purposes.

        Query parameters are folded into the URL, and the ``apikey`` header is masked.

        Args:
            request_args (dict): Request arguments dictionary containing method, url, headers, etc.

        Returns:
            str: Equivalent curl command
        """
        method = request_args.get('method', 'GET').upper()
        url = request_args['url']
        headers = request_args.get('headers') or {}
        data = request_args.get('json') or request_args.get('data')
        params = request_args.get('params')

        curl_command = ['curl']

        # Add method if not GET
        if method != 'GET':
            curl_command.extend(['-X', method])

        # Query parameters must be merged into the URL *before* it is added to the command.
        if params:
            items = params.items() if isinstance(params, dict) else params
            param_str = '&'.join(f"{k}={v}" for k, v in items)
            separator = '&' if '?' in url else '?'
            url = f"{url}{separator}{param_str}"
        curl_command.append(f"'{url}'")

        # Add headers
        for key, value in headers.items():
            if key.lower() == 'apikey':
                value = '<YOUR-API-KEY>'  # Mask API key for security
            curl_command.extend(['-H', f"'{key}: {value}'"])

        # Add data
        if data:
            if isinstance(data, aiohttp.FormData):
                # Extract the fields of the FormData (relies on aiohttp's private ``_fields``).
                for options, _field_headers, value in data._fields:
                    name = options.get('name', 'file')
                    if hasattr(value, 'read'):  # File-like object
                        filename = getattr(value, 'name', 'file')
                        curl_command.extend(['-F', f"'{name}=@{filename}'"])
                    else:
                        curl_command.extend(['-F', f"'{name}={value}'"])
            elif isinstance(data, dict):
                curl_command.extend(['-d', f"'{json.dumps(data)}'"])
            else:
                curl_command.extend(['-d', f"'{data}'"])

        return ' '.join(curl_command)

    def _log_request(self, request_args: dict) -> None:
        """Log the target URL and an equivalent curl command at DEBUG level.

        The curl command is only built when DEBUG is enabled. Any error while building
        it is logged and swallowed: this is diagnostics only and must never prevent the
        request from being sent.
        """
        if not _LOGGER.isEnabledFor(logging.DEBUG):
            return
        try:
            _LOGGER.debug(f"Running request to {request_args['url']}")
            _LOGGER.debug(f'Equivalent curl command: "{self._generate_curl_command(request_args)}"')
        except Exception as e:
            _LOGGER.debug(f"Error generating curl command: {e}")

    def _with_api_key(self, request_args: dict) -> dict:
        """Return a shallow copy of ``request_args`` whose headers include the API key.

        A copy is returned so the caller's dictionaries (which may be logged or reused)
        never have the secret written into them.
        """
        headers = dict(request_args.get('headers') or {})
        headers['apikey'] = self.api_key
        return {**request_args, 'headers': headers}

    async def _run_request_async(self,
                                 request_args: dict,
                                 session: aiohttp.ClientSession | None = None,
                                 data_to_get: str = 'json'):
        """Send a request with aiohttp and return the decoded body.

        Args:
            request_args: Keyword arguments for ``aiohttp.ClientSession.request``
                (``method``, ``url``, ``headers``, ...). The dict itself is not modified.
            session: Session to use. If None, a temporary session is created for this call.
            data_to_get: ``'json'`` returns ``await response.json()``; ``'text'`` returns
                ``await response.text()``.

        Raises:
            ValueError: if ``data_to_get`` is invalid. This is checked before anything is sent.
        """
        # Validate up front so an invalid option never triggers a (possibly mutating) request.
        if data_to_get not in ('json', 'text'):
            raise ValueError("data_to_get must be either 'json' or 'text'")
        if session is None:
            async with aiohttp.ClientSession() as s:
                # Forward data_to_get so a temporary session still honours the caller's choice.
                return await self._run_request_async(request_args, s, data_to_get)

        self._log_request(request_args)
        request_args = self._with_api_key(request_args)

        async with session.request(**request_args) as response:
            self._check_errors_response(response, request_args)
            if data_to_get == 'json':
                return await response.json()
            return await response.text()

    def _check_errors_response(self,
                               response,
                               request_args: dict):
        """Raise if ``response`` carries an HTTP error status.

        Only ``requests.exceptions.HTTPError`` is intercepted. Other exceptions raised by
        ``response.raise_for_status()`` propagate unchanged; for example, aiohttp responses
        raise their own error type.

        Raises:
            ResourceNotFoundError: for a 4xx response whose JSON ``message`` contains
                " not found". It is raised with placeholder values; the caller is expected
                to fill in the real ones.
            HTTPError: for any other HTTP error status.
        """
        try:
            response.raise_for_status()
        except HTTPError as e:
            status_code = BaseAPIHandler.get_status_code(e)
            if 500 <= status_code < 600:
                _LOGGER.error(f"Error in request to {request_args['url']}: {e}")
            if 400 <= status_code < 500:
                try:
                    _LOGGER.error(f"Error response: {response.text}")
                    error_data = response.json()
                except Exception as e2:
                    _LOGGER.error(f"Error parsing the response. {e2}")
                else:
                    # The body may not be a dict or may lack 'message'; this must never
                    # replace the original HTTPError with a KeyError/TypeError.
                    message = error_data.get('message') if isinstance(error_data, dict) else None
                    if isinstance(message, str) and ' not found' in message.lower():
                        # Will be caught by the caller and properly initialized:
                        raise ResourceNotFoundError('unknown', {}) from e
            raise

    def _check_errors_response_json(self,
                                    response):
        """Raise ``DatamintException`` if the JSON body reports an error.

        The body may be a single object or a list of objects. The first object that
        contains an ``'error'`` key triggers the exception, with that value as the message.
        """
        response_json = response.json()
        if isinstance(response_json, dict):
            response_json = [response_json]
        if isinstance(response_json, list):
            for r in response_json:
                if isinstance(r, dict) and 'error' in r:
                    if hasattr(response, 'text'):
                        _LOGGER.error(f"Error response: {response.text}")
                    raise DatamintException(r['error'])

    def _run_request(self,
                     request_args: dict,
                     session: Optional[Session] = None):
        """Send a request with ``requests`` and return the response after error checking.

        Args:
            request_args: Keyword arguments for ``Session.request``. The dict itself is not
                modified; the API key is added to a copy.
            session: Session to use. If None, a temporary session is created for this call.
        """
        if session is None:
            with Session() as s:
                return self._run_request(request_args, s)
        self._log_request(request_args)
        request_args = self._with_api_key(request_args)
        response = session.request(**request_args)
        self._check_errors_response(response, request_args)
        return response

    def _get_endpoint_url(self, endpoint: str) -> str:
        """Return the absolute URL of ``endpoint`` under ``root_url``."""
        return f'{self.root_url}/{endpoint}'

    def _run_pagination_request(self,
                                request_params: Dict,
                                return_field: Optional[Union[str, List]] = None
                                ) -> Generator[Dict, None, None]:
        """Yield every item of a paginated listing endpoint.

        Pages are requested with ``offset``/``limit`` query parameters, ``_PAGE_LIMIT``
        items at a time. Iteration stops at the first page shorter than that.

        Args:
            request_params: Request arguments. Its ``params`` dict is created if missing
                and updated in place with ``offset`` and ``limit``.
            return_field: Key, or sequence of nested keys, that locates the item list
                inside each JSON response. None means the response itself is the list.
        """
        # The params dict must live *inside* request_params; otherwise offset/limit would
        # never be sent, and pagination would either truncate or loop on the first page.
        if request_params.get('params') is None:
            request_params['params'] = {}
        params = request_params['params']

        offset = 0
        while True:
            params['offset'] = offset
            params['limit'] = _PAGE_LIMIT

            response = self._run_request(request_params)
            self._check_errors_response_json(response)
            items = response.json()
            if return_field is not None:
                fields = return_field if isinstance(return_field, (list, tuple)) else [return_field]
                for field in fields:
                    items = items[field]
            yield from items

            if len(items) < _PAGE_LIMIT:
                _LOGGER.debug(f"Last page reached. Total resources: {offset + len(items)}")
                break

            offset += _PAGE_LIMIT

    @staticmethod
    def get_status_code(e) -> int:
        """Return the HTTP status code attached to exception ``e``, or -1 if it has no response."""
        if not hasattr(e, 'response') or e.response is None:
            return -1
        return e.response.status_code

    @staticmethod
    def _has_status_code(e, status_code: int) -> bool:
        """Return True if exception ``e`` carries the given HTTP status code."""
        return BaseAPIHandler.get_status_code(e) == status_code

    @staticmethod
    def convert_format(bytes_array: bytes,
                       mimetype: str,
                       file_path: Optional[str] = None
                       ) -> pydicom.dataset.Dataset | Image.Image | cv2.VideoCapture | bytes | nib_FileBasedImage:
        """Decode downloaded content according to its mimetype.

        Args:
            bytes_array: Raw content.
            mimetype: Content type that decides how the content is interpreted.
            file_path: Path to the content on disk. Required for ``video/mp4`` and
                ``application/nifti``, which are read from this path instead of ``bytes_array``.

        Returns:
            The decoded object, depending on the mimetype: a pydicom Dataset, a PIL Image,
            a ``cv2.VideoCapture``, parsed JSON, the raw bytes, or a nibabel image.

        Raises:
            NotImplementedError: for video or NIfTI content when ``file_path`` is None.
            ValueError: for an unsupported mimetype.
        """
        content_io = BytesIO(bytes_array)
        if mimetype == 'application/dicom':
            return pydicom.dcmread(content_io)
        elif mimetype in ('image/jpeg', 'image/png', 'image/tiff'):
            return Image.open(content_io)
        elif mimetype == 'video/mp4':
            if file_path is None:
                raise NotImplementedError("file_path=None is not implemented yet for video/mp4.")
            return cv2.VideoCapture(file_path)
        elif mimetype == 'application/json':
            return json.loads(bytes_array)
        elif mimetype == 'application/octet-stream':
            return bytes_array
        elif mimetype == 'application/nifti':
            if file_path is None:
                raise NotImplementedError("file_path=None is not implemented yet for application/nifti.")
            return nib.load(file_path)

        raise ValueError(f"Unsupported mimetype: {mimetype}")
|
|
@@ -0,0 +1,149 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Data Transfer Objects (DTOs) for handling annotations in the datamint-python-api.

This module provides classes for creating and manipulating annotation data
that can be sent to or received from the Datamint API. It includes annotation
types, geometry structures, and formatting utilities.

Classes:
    AnnotationType: Enumeration of the supported annotation types.
    Box: Box given by two corner coordinates on a frame.
    Geometry: Base class for annotation geometries.
    LineGeometry: Line geometry defined by two points, optionally built from DICOM pixel coordinates.
    CreateAnnotationDto: Main DTO for creating annotation requests.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
import json
|
|
18
|
+
from typing import Any, TypeAlias, Literal
|
|
19
|
+
import logging
|
|
20
|
+
from enum import Enum
|
|
21
|
+
from datamintapi.utils.dicom_utils import pixel_to_patient
|
|
22
|
+
import pydicom
|
|
23
|
+
import numpy as np
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
_LOGGER = logging.getLogger(__name__)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
CoordinateSystem: TypeAlias = Literal['pixel', 'patient']
|
|
30
|
+
"""TypeAlias: The available coordinate systems for annotation geometry. Possible values are 'pixel' and 'patient' (used in DICOMs).
|
|
31
|
+
"""
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class AnnotationType(Enum):
    """Kinds of annotation understood by the API; each value is the wire string."""

    # Pixel/region based
    SEGMENTATION = 'segmentation'
    AREA = 'area'

    # Measurements
    DISTANCE = 'distance'
    ANGLE = 'angle'

    # Geometric primitives
    POINT = 'point'
    LINE = 'line'
    REGION = 'region'
    SQUARE = 'square'
    CIRCLE = 'circle'

    # Whole-item tags
    CATEGORY = 'category'
    LABEL = 'label'
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def _remove_none(d: dict) -> dict:
    """Recursively drop ``None``-valued entries from nested dictionaries.

    Only dict values are descended into; lists and other values are kept unchanged.
    Non-dict input is returned as is.
    """
    if not isinstance(d, dict):
        return d
    cleaned = {}
    for key, value in d.items():
        if value is None:
            continue
        cleaned[key] = _remove_none(value)
    return cleaned
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class Box:
    """Box given by the coordinates (x0, y0) and (x1, y1) on frame ``frame_index``.

    Values are stored as given. The corners are neither validated nor reordered;
    presumably they are opposite corners.
    """

    def __init__(self, x0, y0, x1, y1, frame_index):
        self.x0, self.y0 = x0, y0
        self.x1, self.y1 = x1, y1
        self.frame_index = frame_index
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class Geometry:
    """Base class for annotation geometries; subclasses provide ``to_dict``."""

    def __init__(self, type: AnnotationType | str):
        # Accept either the enum member or its string value. AnnotationType(...)
        # raises ValueError for an unknown string.
        if not isinstance(type, AnnotationType):
            type = AnnotationType(type)
        self.type = type

    def to_dict(self) -> dict:
        """Serialize the geometry to a plain dict; must be overridden."""
        raise NotImplementedError("Subclasses must implement to_dict method.")
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
class LineGeometry(Geometry):
    """Line annotation geometry defined by its two end points."""

    def __init__(self, point1: tuple[float, float, float],
                 point2: tuple[float, float, float]):
        super().__init__(AnnotationType.LINE)
        # numpy arrays are turned into plain lists; other sequences are stored as given.
        self.point1 = point1.tolist() if isinstance(point1, np.ndarray) else point1
        self.point2 = point2.tolist() if isinstance(point2, np.ndarray) else point2

    def to_dict(self) -> dict:
        """Return ``{'points': [point1, point2]}``."""
        return {'points': [self.point1, self.point2]}

    @staticmethod
    def from_dicom(ds: pydicom.Dataset,
                   point1: tuple[int, int],
                   point2: tuple[int, int],
                   slice_index: int | None = None) -> 'LineGeometry':
        """Build a line from two (x, y) pixel positions of a DICOM dataset.

        Each point is converted with ``pixel_to_patient(ds, x, y, slice_index=...)``,
        point1 first, and the results become the line's end points.
        """
        converted = [pixel_to_patient(ds, px, py, slice_index=slice_index)
                     for px, py in (point1, point2)]
        return LineGeometry(*converted)
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
class CreateAnnotationDto:
    """Payload for an annotation-creation request.

    ``to_dict`` builds the request body and omits every field whose value is None.
    """

    def __init__(self,
                 type: AnnotationType | str,
                 identifier: str,
                 scope: str,
                 annotation_worklist_id: str,
                 value=None,
                 imported_from: str | None = None,
                 import_author: str | None = None,
                 frame_index: int | None = None,
                 is_model: bool | None = None,
                 model_id: str | None = None,
                 geometry: Geometry | None = None,
                 units: str | None = None):
        """
        Args:
            type: Annotation type, as an enum member or its string value.
            identifier: Sent as ``identifier``.
            scope: Sent as ``scope``.
            annotation_worklist_id: Sent as ``annotation_worklist_id``.
            value: Sent as ``value``.
            imported_from: Sent as ``imported_from``.
            import_author: Sent as ``import_author``.
            frame_index: Sent as ``frame_index``.
            is_model: Sent as ``is_model``. Forced to True when ``model_id`` is given.
            model_id: Sent as ``model_id``.
            geometry: Optional geometry; its type must equal ``type``.
            units: Sent as ``units``.

        Raises:
            ValueError: if ``type`` is not a valid AnnotationType value, if ``is_model``
                is explicitly falsy while ``model_id`` is given, or if the geometry type
                differs from ``type``.
        """
        self.type = type if isinstance(type, AnnotationType) else AnnotationType(type)
        self.value = value
        self.identifier = identifier
        self.scope = scope
        self.annotation_worklist_id = annotation_worklist_id
        self.imported_from = imported_from
        self.import_author = import_author
        self.frame_index = frame_index
        self.units = units
        self.model_id = model_id
        if model_id is not None:
            # An explicit falsy is_model contradicts a provided model_id.
            if is_model is not None and not is_model:
                raise ValueError("is_model==False while model_id is provided.")
            is_model = True
        self.is_model = is_model
        self.geometry = geometry

        if geometry is not None and self.type != self.geometry.type:
            raise ValueError(f"Annotation type {self.type} does not match geometry type {self.geometry.type}.")

    def to_dict(self) -> dict[str, Any]:
        """Return the request body, with None-valued fields (including nested ones) removed."""
        ret = {
            "value": self.value,
            "type": self.type.value,
            "identifier": self.identifier,
            "scope": self.scope,
            'frame_index': self.frame_index,
            'annotation_worklist_id': self.annotation_worklist_id,
            'imported_from': self.imported_from,
            'import_author': self.import_author,
            'units': self.units,
            "geometry": self.geometry.to_dict() if self.geometry else None,
            "is_model": self.is_model,
            "model_id": self.model_id
        }
        return _remove_none(ret)
|
|
@@ -0,0 +1,204 @@
|
|
|
1
|
+
from datamintapi.apihandler.base_api_handler import BaseAPIHandler
|
|
2
|
+
from typing import Optional, Dict, List, Union, Any
|
|
3
|
+
import json
|
|
4
|
+
import logging
|
|
5
|
+
from io import BytesIO
|
|
6
|
+
|
|
7
|
+
# Logger for this module, named after the module itself.
_LOGGER = logging.getLogger(__name__)
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class ExperimentAPIHandler(BaseAPIHandler):
    """Handler for the experiment endpoints under ``{root_url}/experiments``."""

    def __init__(self,
                 root_url: Optional[str] = None,
                 api_key: Optional[str] = None,
                 check_connection: bool = True,
                 **kwargs):
        super().__init__(root_url=root_url, api_key=api_key, check_connection=check_connection, **kwargs)
        self.exp_url = f"{self.root_url}/experiments"

    def create_experiment(self,
                          dataset_id: str,
                          name: str,
                          description: str,
                          environment: Dict) -> str:
        """Create an experiment and return its id.

        Returns:
            str: The ``id`` field of the JSON response.
        """
        request_params = {
            'method': 'POST',
            'url': self.exp_url,
            'json': {"dataset_id": dataset_id,
                     "name": name,
                     "description": description,
                     "environment": environment
                     }
        }

        # Only serialize the payload for logging when DEBUG is actually enabled.
        if _LOGGER.isEnabledFor(logging.DEBUG):
            _LOGGER.debug(f"Creating experiment with name {name} and params {json.dumps(request_params)}")

        response = self._run_request(request_params)

        return response.json()['id']

    def get_experiment_by_id(self, exp_id: str) -> Dict:
        """Return the decoded JSON of ``GET {exp_url}/{exp_id}``."""
        request_params = {
            'method': 'GET',
            'url': f"{self.exp_url}/{exp_id}"
        }

        response = self._run_request(request_params)

        return response.json()

    def get_experiments(self) -> List[Dict]:
        """Return the decoded JSON of ``GET {exp_url}``."""
        request_params = {
            'method': 'GET',
            'url': self.exp_url
        }

        response = self._run_request(request_params)

        return response.json()

    def get_experiment_logs(self, exp_id: str) -> List[Dict]:
        """Return the decoded JSON of ``GET {exp_url}/{exp_id}/log``."""
        request_params = {
            'method': 'GET',
            'url': f"{self.exp_url}/{exp_id}/log"
        }

        response = self._run_request(request_params)

        return response.json()

    def log_summary(self,
                    exp_id: str,
                    result_summary: Dict,
                    ) -> None:
        """POST ``{"result_summary": result_summary}`` to ``{exp_url}/{exp_id}/summary``."""
        request_params = {
            'method': 'POST',
            'url': f"{self.exp_url}/{exp_id}/summary",
            'json': {"result_summary": result_summary}
        }

        self._run_request(request_params)

    def update_experiment(self,
                          exp_id: str,
                          name: Optional[str] = None,
                          description: Optional[str] = None,
                          result_summary: Optional[Dict] = None) -> None:
        """PATCH the given fields of an experiment.

        Only arguments that are not None are sent, so falsy values such as ``""`` or
        ``{}`` are still sent. If every optional argument is None, no request is made.
        """
        candidates = {'name': name, 'description': description, 'result_summary': result_summary}
        data = {key: value for key, value in candidates.items() if value is not None}
        # Nothing to update.
        if not data:
            return

        headers = {
            'accept': 'application/json',
            'Content-Type': 'application/json',
        }

        request_params = {
            'method': 'PATCH',
            'url': f"{self.exp_url}/{exp_id}",
            'json': data,
            'headers': headers
        }

        self._run_request(request_params)

    def log_entry(self,
                  exp_id: str,
                  entry: Dict):
        """POST ``entry`` as JSON to ``{exp_url}/{exp_id}/log`` and return the raw response.

        Raises:
            ValueError: if ``entry`` is not a dict.
        """
        if not isinstance(entry, dict):
            raise ValueError(f"Invalid type for entry: {type(entry)}")

        request_params = {
            'method': 'POST',
            'url': f"{self.exp_url}/{exp_id}/log",
            'json': entry
        }

        resp = self._run_request(request_params)
        return resp

    def finish_experiment(self, exp_id: str):
        """Mark an experiment as finished.

        Currently a no-op: no request is sent.
        """
        # TODO: enable POST f"{self.exp_url}/{exp_id}/finish" (disabled in the original implementation).

    def log_model(self,
                  exp_id: str,
                  model: Union[Any, str, BytesIO],
                  hyper_params: Optional[Dict] = None,
                  torch_save_kwargs: Optional[Dict] = None) -> Dict:
        """Upload a model file to an experiment.

        Args:
            exp_id: Experiment id.
            model: One of:
                - a ``torch.nn.Module``, serialized with ``torch.save``;
                - a path to a file whose bytes are uploaded;
                - an in-memory ``BytesIO`` buffer.
            hyper_params: Extra form fields sent with the upload. The dict is not modified.
            torch_save_kwargs: Extra keyword arguments for ``torch.save``.

        Returns:
            The first element of the JSON response.

        Raises:
            ValueError: if ``model`` is of an unsupported type.
        """
        # Buffers created here are closed afterwards; a caller-provided buffer is left open.
        owns_buffer = True
        if isinstance(model, str):
            with open(model, 'rb') as f1:
                f = BytesIO(f1.read())
        elif isinstance(model, BytesIO):
            f = model
            owns_buffer = False
        else:
            # Import torch lazily so uploading a path or buffer does not require torch.
            import torch
            if not isinstance(model, torch.nn.Module):
                raise ValueError(f"Invalid type for model: {type(model)}")
            f = BytesIO()
            torch.save(model, f, **(torch_save_kwargs or {}))
            f.seek(0)

        name = None
        f.name = name

        try:
            # Copy so the caller's dict is never mutated (and None is accepted).
            json_data = dict(hyper_params) if hyper_params is not None else {}
            json_data['model_name'] = name
            request_params = {
                'method': 'POST',
                'url': f"{self.exp_url}/{exp_id}/model",
                'data': json_data,
                'files': [(None, f)],
            }

            resp = self._run_request(request_params).json()
            return resp[0]
        finally:
            if owns_buffer:
                f.close()

    def get_experiment_by_name(self, name: str, project: Dict) -> Optional[Dict]:
        """
        Get the experiment by name of the project.

        Args:
            name (str): Name of the experiment.
            project (Dict): The project to search for the experiment.

        Returns:
            Optional[Dict]: The experiment if found, otherwise None.
        """
        # uses GET /projects/{project_id}/experiments
        project_id = project['id']
        request_params = {
            'method': 'GET',
            'url': f"{self.root_url}/projects/{project_id}/experiments"
        }

        response = self._run_request(request_params)
        experiments = response.json()
        for exp in experiments:
            if exp['name'] == name:
                return exp
        return None
|