datamint 1.6.3.post1__py3-none-any.whl → 1.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of datamint might be problematic. Click here for more details.
- datamint/apihandler/annotation_api_handler.py +125 -3
- datamint/apihandler/base_api_handler.py +30 -26
- datamint/apihandler/root_api_handler.py +160 -36
- datamint/dataset/annotation.py +221 -0
- datamint/dataset/base_dataset.py +735 -483
- datamint/dataset/dataset.py +33 -16
- {datamint-1.6.3.post1.dist-info → datamint-1.7.1.dist-info}/METADATA +1 -1
- {datamint-1.6.3.post1.dist-info → datamint-1.7.1.dist-info}/RECORD +10 -9
- {datamint-1.6.3.post1.dist-info → datamint-1.7.1.dist-info}/WHEEL +0 -0
- {datamint-1.6.3.post1.dist-info → datamint-1.7.1.dist-info}/entry_points.txt +0 -0
|
@@ -13,6 +13,9 @@ from requests.exceptions import HTTPError
|
|
|
13
13
|
from .dto.annotation_dto import CreateAnnotationDto, LineGeometry, BoxGeometry, CoordinateSystem, AnnotationType
|
|
14
14
|
import pydicom
|
|
15
15
|
import json
|
|
16
|
+
from deprecated import deprecated
|
|
17
|
+
from pathlib import Path
|
|
18
|
+
from tqdm.auto import tqdm
|
|
16
19
|
|
|
17
20
|
_LOGGER = logging.getLogger(__name__)
|
|
18
21
|
_USER_LOGGER = logging.getLogger('user_logger')
|
|
@@ -267,8 +270,9 @@ class AnnotationAPIHandler(BaseAPIHandler):
|
|
|
267
270
|
raise NotImplementedError("`name=string` is not supported yet for volume segmentation.")
|
|
268
271
|
if isinstance(name, dict):
|
|
269
272
|
if any(isinstance(k, tuple) for k in name.keys()):
|
|
270
|
-
raise NotImplementedError(
|
|
271
|
-
|
|
273
|
+
raise NotImplementedError(
|
|
274
|
+
"For volume segmentations, `name` must be a dictionary with integer keys only.")
|
|
275
|
+
|
|
272
276
|
# Prepare file for upload
|
|
273
277
|
if isinstance(file_path, str):
|
|
274
278
|
if file_path.endswith('.nii') or file_path.endswith('.nii.gz'):
|
|
@@ -892,7 +896,7 @@ class AnnotationAPIHandler(BaseAPIHandler):
|
|
|
892
896
|
dataset_id: Optional[str] = None,
|
|
893
897
|
worklist_id: Optional[str] = None,
|
|
894
898
|
status: Optional[Literal['new', 'published']] = None,
|
|
895
|
-
load_ai_segmentations: bool = None,
|
|
899
|
+
load_ai_segmentations: bool | None = None,
|
|
896
900
|
) -> Generator[dict, None, None]:
|
|
897
901
|
"""
|
|
898
902
|
Get annotations for a resource.
|
|
@@ -1098,6 +1102,29 @@ class AnnotationAPIHandler(BaseAPIHandler):
|
|
|
1098
1102
|
resp = self._run_request(request_params)
|
|
1099
1103
|
self._check_errors_response_json(resp)
|
|
1100
1104
|
|
|
1105
|
+
def get_annotation_by_id(self, annotation_id: str) -> dict:
    """
    Fetch a single annotation by its unique id.

    Args:
        annotation_id (str): The annotation unique id.

    Returns:
        dict: The annotation information.

    Raises:
        HTTPError: If the server rejects the request (logged before re-raising).
    """
    params = {
        'method': 'GET',
        'url': f'{self.root_url}/annotations/{annotation_id}',
    }
    try:
        response = self._run_request(params)
    except HTTPError as e:
        _LOGGER.error(f"Error getting annotation by id {annotation_id}: {e}")
        raise
    return response.json()
|
|
1126
|
+
|
|
1127
|
+
@deprecated(reason="Use download_segmentation_file instead")
|
|
1101
1128
|
def get_segmentation_file(self, resource_id: str, annotation_id: str) -> bytes:
|
|
1102
1129
|
request_params = {
|
|
1103
1130
|
'method': 'GET',
|
|
@@ -1107,6 +1134,35 @@ class AnnotationAPIHandler(BaseAPIHandler):
|
|
|
1107
1134
|
resp = self._run_request(request_params)
|
|
1108
1135
|
return resp.content
|
|
1109
1136
|
|
|
1137
|
+
def download_segmentation_file(self, annotation: str | dict, fpath_out: str | Path | None = None) -> bytes:
    """
    Download the segmentation file for a given annotation.

    Args:
        annotation (str | dict): The annotation unique id or an annotation object
            (a dict containing at least the 'id' and 'resource_id' keys).
        fpath_out (str | Path | None): (Optional) The file path to save the downloaded segmentation file.
            If None (the default), the content is only returned, not written to disk.

    Returns:
        bytes: The content of the downloaded segmentation file in bytes format.
    """
    if isinstance(annotation, dict):
        annotation_id = annotation['id']
        resource_id = annotation['resource_id']
    else:
        annotation_id = annotation
        # Resolving the resource id costs one extra API call per annotation id.
        resource_id = self.get_annotation_by_id(annotation_id)['resource_id']

    # NOTE(review): the path is 'annotations/{resource_id}/annotations/{annotation_id}/file';
    # confirm against the API spec whether 'resources/{resource_id}/...' was intended.
    request_params = {
        'method': 'GET',
        'url': f'{self.root_url}/annotations/{resource_id}/annotations/{annotation_id}/file',
    }

    resp = self._run_request(request_params)
    if fpath_out is not None:
        with open(str(fpath_out), 'wb') as f:
            f.write(resp.content)
    return resp.content
|
|
1165
|
+
|
|
1110
1166
|
def set_annotation_status(self,
|
|
1111
1167
|
project_id: str,
|
|
1112
1168
|
resource_id: str,
|
|
@@ -1124,3 +1180,69 @@ class AnnotationAPIHandler(BaseAPIHandler):
|
|
|
1124
1180
|
}
|
|
1125
1181
|
resp = self._run_request(request_params)
|
|
1126
1182
|
self._check_errors_response_json(resp)
|
|
1183
|
+
|
|
1184
|
+
|
|
1185
|
+
async def _async_download_segmentation_file(self,
                                            annotation: str | dict,
                                            save_path: str | Path,
                                            session: aiohttp.ClientSession | None = None,
                                            progress_bar: tqdm | None = None):
    """
    Asynchronously download a single segmentation file to disk.

    Args:
        annotation (str | dict): The annotation unique id or an annotation object.
        save_path (str | Path): The path to save the file.
        session (aiohttp.ClientSession): The aiohttp session to use for the request.
        progress_bar (tqdm | None): Optional progress bar to update after download completion.
    """
    if isinstance(annotation, dict):
        annotation_id, resource_id = annotation['id'], annotation['resource_id']
    else:
        annotation_id = annotation
        # TODO: This is inefficient as it requires an extra API call per annotation.
        # Consider passing resource_id separately or caching annotation info.
        resource_id = self.get_annotation_by_id(annotation_id)['resource_id']

    request_params = {
        'method': 'GET',
        'url': f'{self.root_url}/annotations/{resource_id}/annotations/{annotation_id}/file',
    }

    try:
        content = await self._run_request_async(request_params, session, 'content')
    except ResourceNotFoundError as e:
        # Attach context so callers know which annotation was missing.
        e.set_params('annotation', {'annotation_id': annotation_id})
        raise e

    with open(save_path, 'wb') as out_file:
        out_file.write(content)
    if progress_bar:
        progress_bar.update(1)
|
|
1223
|
+
|
|
1224
|
+
def download_multiple_segmentations(self,
                                    annotations: list[str | dict],
                                    save_paths: list[str | Path] | str
                                    ) -> None:
    """
    Download multiple segmentation files concurrently and save them to the specified paths.

    Args:
        annotations (list[str | dict]): A list of annotation unique ids or annotation objects.
        save_paths (list[str | Path] | str): A list of paths to save the files (one per
            annotation) or a single directory path; in the latter case each file is saved
            under the directory using the annotation id as the filename.

    Raises:
        ValueError: If `save_paths` is a list whose length differs from `annotations`.
    """
    if isinstance(save_paths, str):
        # Directory given: derive one output file per annotation, named by its id.
        save_paths = [os.path.join(save_paths, f"{ann['id'] if isinstance(ann, dict) else ann}")
                      for ann in annotations]
    elif len(save_paths) != len(annotations):
        # Previously a length mismatch silently truncated the downloads via zip().
        raise ValueError("save_paths must have the same length as annotations.")

    async def _download_all_async():
        async with aiohttp.ClientSession() as session:
            tasks = [
                self._async_download_segmentation_file(annotation, save_path=path, session=session, progress_bar=progress_bar)
                for annotation, path in zip(annotations, save_paths)
            ]
            await asyncio.gather(*tasks)

    with tqdm(total=len(annotations), desc="Downloading segmentations", unit="file") as progress_bar:
        loop = asyncio.get_event_loop()
        loop.run_until_complete(_download_all_async())
|
|
@@ -85,7 +85,7 @@ class BaseAPIHandler:
|
|
|
85
85
|
msg = f"API key not provided! Use the environment variable " + \
|
|
86
86
|
f"{BaseAPIHandler.DATAMINT_API_VENV_NAME} or pass it as an argument."
|
|
87
87
|
raise DatamintException(msg)
|
|
88
|
-
self.semaphore = asyncio.Semaphore(
|
|
88
|
+
self.semaphore = asyncio.Semaphore(20)
|
|
89
89
|
|
|
90
90
|
if check_connection:
|
|
91
91
|
self.check_connection()
|
|
@@ -157,30 +157,34 @@ class BaseAPIHandler:
|
|
|
157
157
|
async def _run_request_async(self,
|
|
158
158
|
request_args: dict,
|
|
159
159
|
session: aiohttp.ClientSession | None = None,
|
|
160
|
-
data_to_get:
|
|
160
|
+
data_to_get: Literal['json', 'text', 'content'] = 'json'):
|
|
161
161
|
if session is None:
|
|
162
162
|
async with aiohttp.ClientSession() as s:
|
|
163
|
-
return await self._run_request_async(request_args, s)
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
163
|
+
return await self._run_request_async(request_args, s, data_to_get)
|
|
164
|
+
|
|
165
|
+
async with self.semaphore:
|
|
166
|
+
try:
|
|
167
|
+
_LOGGER.debug(f"Running request to {request_args['url']}")
|
|
168
|
+
_LOGGER.debug(f'Equivalent curl command: "{self._generate_curl_command(request_args)}"')
|
|
169
|
+
except Exception as e:
|
|
170
|
+
_LOGGER.debug(f"Error generating curl command: {e}")
|
|
171
|
+
|
|
172
|
+
# add apikey to the headers
|
|
173
|
+
if 'headers' not in request_args:
|
|
174
|
+
request_args['headers'] = {}
|
|
175
|
+
|
|
176
|
+
request_args['headers']['apikey'] = self.api_key
|
|
177
|
+
|
|
178
|
+
async with session.request(**request_args) as response:
|
|
179
|
+
self._check_errors_response(response, request_args)
|
|
180
|
+
if data_to_get == 'json':
|
|
181
|
+
return await response.json()
|
|
182
|
+
elif data_to_get == 'text':
|
|
183
|
+
return await response.text()
|
|
184
|
+
elif data_to_get == 'content':
|
|
185
|
+
return await response.read()
|
|
186
|
+
else:
|
|
187
|
+
raise ValueError("data_to_get must be either 'json' or 'text'")
|
|
184
188
|
|
|
185
189
|
def _check_errors_response(self,
|
|
186
190
|
response,
|
|
@@ -237,9 +241,9 @@ class BaseAPIHandler:
|
|
|
237
241
|
return f'{self.root_url}/{endpoint}'
|
|
238
242
|
|
|
239
243
|
def _run_pagination_request(self,
|
|
240
|
-
request_params:
|
|
241
|
-
return_field:
|
|
242
|
-
) -> Generator[
|
|
244
|
+
request_params: dict,
|
|
245
|
+
return_field: str | list | None = None
|
|
246
|
+
) -> Generator[dict | list, None, None]:
|
|
243
247
|
offset = 0
|
|
244
248
|
params = request_params.get('params', {})
|
|
245
249
|
while True:
|
|
@@ -219,36 +219,35 @@ class RootAPIHandler(BaseAPIHandler):
|
|
|
219
219
|
|
|
220
220
|
async with aiohttp.ClientSession() as session:
|
|
221
221
|
async def __upload_single_resource(file_path, segfiles: dict[str, list | dict], metadata_file: str | dict | None):
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
return rid
|
|
222
|
+
rid = await self._upload_single_resource_async(
|
|
223
|
+
file_path=file_path,
|
|
224
|
+
mimetype=mimetype,
|
|
225
|
+
anonymize=anonymize,
|
|
226
|
+
anonymize_retain_codes=anonymize_retain_codes,
|
|
227
|
+
tags=tags,
|
|
228
|
+
session=session,
|
|
229
|
+
mung_filename=mung_filename,
|
|
230
|
+
channel=channel,
|
|
231
|
+
modality=modality,
|
|
232
|
+
publish=publish,
|
|
233
|
+
metadata_file=metadata_file,
|
|
234
|
+
)
|
|
235
|
+
if segfiles is not None:
|
|
236
|
+
fpaths = segfiles['files']
|
|
237
|
+
names = segfiles.get('names', _infinite_gen(None))
|
|
238
|
+
if isinstance(names, dict):
|
|
239
|
+
names = _infinite_gen(names)
|
|
240
|
+
frame_indices = segfiles.get('frame_index', _infinite_gen(None))
|
|
241
|
+
for f, name, frame_index in tqdm(zip(fpaths, names, frame_indices),
|
|
242
|
+
desc=f"Uploading segmentations for {file_path}",
|
|
243
|
+
total=len(fpaths)):
|
|
244
|
+
if f is not None:
|
|
245
|
+
await self._upload_segmentations_async(rid,
|
|
246
|
+
file_path=f,
|
|
247
|
+
name=name,
|
|
248
|
+
frame_index=frame_index,
|
|
249
|
+
transpose_segmentation=transpose_segmentation)
|
|
250
|
+
return rid
|
|
252
251
|
|
|
253
252
|
tasks = [__upload_single_resource(f, segfiles, metadata_file)
|
|
254
253
|
for f, segfiles, metadata_file in zip(files_path, segmentation_files, metadata_files)]
|
|
@@ -365,6 +364,32 @@ class RootAPIHandler(BaseAPIHandler):
|
|
|
365
364
|
|
|
366
365
|
return result[0]
|
|
367
366
|
|
|
367
|
+
@staticmethod
def _is_dicom_report(file_path: str | IO) -> bool:
    """
    Check if a DICOM file is a report (e.g., Structured Report).

    Args:
        file_path: Path to the DICOM file or file-like object.

    Returns:
        bool: True if the DICOM file is a report, False otherwise.
    """
    def _rewind(f):
        # File-like inputs are consumed by the checks below (and presumably by
        # `is_dicom` — TODO confirm) and may be re-read later by the upload code,
        # so restore the stream position. No-op for plain path strings.
        if hasattr(f, 'seek'):
            f.seek(0)

    try:
        if not is_dicom(file_path):
            return False

        _rewind(file_path)
        ds = pydicom.dcmread(file_path, stop_before_pixels=True)
        modality = getattr(ds, 'Modality', None)

        # Common report modalities
        report_modalities = {'SR', 'DOC', 'KO', 'PR', 'ESR'}  # SR=Structured Report, DOC=Document, KO=Key Object, PR=Presentation State

        return modality in report_modalities
    except Exception as e:
        # Deliberate best-effort: unreadable input is treated as "not a report".
        _LOGGER.debug(f"Error checking if DICOM is a report: {e}")
        return False
    finally:
        _rewind(file_path)
|
|
392
|
+
|
|
368
393
|
def upload_resources(self,
|
|
369
394
|
files_path: str | IO | Sequence[str | IO] | pydicom.dataset.Dataset,
|
|
370
395
|
mimetype: Optional[str] = None,
|
|
@@ -380,7 +405,8 @@ class RootAPIHandler(BaseAPIHandler):
|
|
|
380
405
|
transpose_segmentation: bool = False,
|
|
381
406
|
modality: Optional[str] = None,
|
|
382
407
|
assemble_dicoms: bool = True,
|
|
383
|
-
metadata: list[str | dict | None] | dict | str | None = None
|
|
408
|
+
metadata: list[str | dict | None] | dict | str | None = None,
|
|
409
|
+
discard_dicom_reports: bool = True
|
|
384
410
|
) -> list[str | Exception] | str | Exception:
|
|
385
411
|
"""
|
|
386
412
|
Upload resources.
|
|
@@ -417,6 +443,17 @@ class RootAPIHandler(BaseAPIHandler):
|
|
|
417
443
|
list[str | Exception]: A list of resource IDs or errors.
|
|
418
444
|
"""
|
|
419
445
|
|
|
446
|
+
if discard_dicom_reports:
|
|
447
|
+
if isinstance(files_path, (str, Path)):
|
|
448
|
+
files_path = [files_path]
|
|
449
|
+
elif isinstance(files_path, pydicom.dataset.Dataset):
|
|
450
|
+
files_path = [files_path]
|
|
451
|
+
|
|
452
|
+
old_size = len(files_path)
|
|
453
|
+
files_path = [f for f in files_path if not RootAPIHandler._is_dicom_report(f)]
|
|
454
|
+
if old_size != len(files_path):
|
|
455
|
+
_LOGGER.info(f"Discarded {old_size - len(files_path)} DICOM report files from upload.")
|
|
456
|
+
|
|
420
457
|
if on_error not in ['raise', 'skip']:
|
|
421
458
|
raise ValueError("on_error must be either 'raise' or 'skip'")
|
|
422
459
|
|
|
@@ -445,7 +482,7 @@ class RootAPIHandler(BaseAPIHandler):
|
|
|
445
482
|
|
|
446
483
|
segmentation_files = [segfiles if (isinstance(segfiles, dict) or segfiles is None) else {'files': segfiles}
|
|
447
484
|
for segfiles in segmentation_files]
|
|
448
|
-
|
|
485
|
+
|
|
449
486
|
for segfiles in segmentation_files:
|
|
450
487
|
if segfiles is None:
|
|
451
488
|
continue
|
|
@@ -454,7 +491,8 @@ class RootAPIHandler(BaseAPIHandler):
|
|
|
454
491
|
if 'names' in segfiles:
|
|
455
492
|
# same length as files
|
|
456
493
|
if isinstance(segfiles['names'], (list, tuple)) and len(segfiles['names']) != len(segfiles['files']):
|
|
457
|
-
raise ValueError(
|
|
494
|
+
raise ValueError(
|
|
495
|
+
"segmentation_files['names'] must have the same length as segmentation_files['files'].")
|
|
458
496
|
|
|
459
497
|
loop = asyncio.get_event_loop()
|
|
460
498
|
task = self._upload_resources_async(files_path=files_path,
|
|
@@ -699,7 +737,7 @@ class RootAPIHandler(BaseAPIHandler):
|
|
|
699
737
|
order_field: Optional[ResourceFields] = None,
|
|
700
738
|
order_ascending: Optional[bool] = None,
|
|
701
739
|
channel: Optional[str] = None,
|
|
702
|
-
project_name:
|
|
740
|
+
project_name: str | list[str] | None = None,
|
|
703
741
|
filename: Optional[str] = None
|
|
704
742
|
) -> Generator[dict, None, None]:
|
|
705
743
|
"""
|
|
@@ -717,6 +755,8 @@ class RootAPIHandler(BaseAPIHandler):
|
|
|
717
755
|
return_ids_only (bool): Whether to return only the ids of the resources.
|
|
718
756
|
order_field (Optional[ResourceFields]): The field to order the resources. See :data:`~.base_api_handler.ResourceFields`.
|
|
719
757
|
order_ascending (Optional[bool]): Whether to order the resources in ascending order.
|
|
758
|
+
project_name (str | list[str] | None): The project name or a list of project names to filter resources by project.
|
|
759
|
+
If multiple projects are provided, resources will be filtered to include only those belonging to ALL of the specified projects.
|
|
720
760
|
|
|
721
761
|
Returns:
|
|
722
762
|
Generator[dict, None, None]: A generator of dictionaries with the resources information.
|
|
@@ -745,7 +785,10 @@ class RootAPIHandler(BaseAPIHandler):
|
|
|
745
785
|
"filename": filename,
|
|
746
786
|
}
|
|
747
787
|
if project_name is not None:
|
|
748
|
-
|
|
788
|
+
if isinstance(project_name, str):
|
|
789
|
+
project_name = [project_name]
|
|
790
|
+
payload["project"] = json.dumps({'items': project_name,
|
|
791
|
+
'filterType': 'intersection'}) # union or intersection
|
|
749
792
|
|
|
750
793
|
if tags is not None:
|
|
751
794
|
if isinstance(tags, str):
|
|
@@ -802,7 +845,7 @@ class RootAPIHandler(BaseAPIHandler):
|
|
|
802
845
|
yield from self._run_pagination_request(request_params,
|
|
803
846
|
return_field='data')
|
|
804
847
|
|
|
805
|
-
def set_resource_tags(self,
|
|
848
|
+
def set_resource_tags(self,
|
|
806
849
|
resource_id: str,
|
|
807
850
|
tags: Sequence[str],
|
|
808
851
|
):
|
|
@@ -824,6 +867,62 @@ class RootAPIHandler(BaseAPIHandler):
|
|
|
824
867
|
def _has_status_code(e, status_code: int) -> bool:
|
|
825
868
|
return hasattr(e, 'response') and (e.response is not None) and e.response.status_code == status_code
|
|
826
869
|
|
|
870
|
+
async def _async_download_file(self,
                               resource_id: str,
                               save_path: str,
                               session: aiohttp.ClientSession | None = None,
                               progress_bar: tqdm | None = None):
    """
    Asynchronously download a resource file from the server to disk.

    Args:
        resource_id (str): The resource unique id.
        save_path (str): The path to save the file.
        session (aiohttp.ClientSession): The aiohttp session to use for the request.
        progress_bar (tqdm | None): Optional progress bar to update after download completion.
    """
    endpoint = self._get_endpoint_url(RootAPIHandler.ENDPOINT_RESOURCES)
    request_params = {
        'method': 'GET',
        'headers': {'accept': 'application/octet-stream'},
        'url': f"{endpoint}/{resource_id}/file",
    }

    try:
        content = await self._run_request_async(request_params, session, 'content')
    except ResourceNotFoundError as e:
        # Attach context so callers know which resource was missing.
        e.set_params('resource', {'resource_id': resource_id})
        raise e

    with open(save_path, 'wb') as out_file:
        out_file.write(content)
    if progress_bar:
        progress_bar.update(1)
|
|
899
|
+
|
|
900
|
+
def download_multiple_resources(self,
                                resource_ids: list[str],
                                save_path: list[str] | str
                                ) -> None:
    """
    Download multiple resources concurrently and save them to the specified paths.

    Args:
        resource_ids (list[str]): A list of resource unique ids.
        save_path (list[str] | str): A list of paths to save the files (one per resource)
            or a single directory path; in the latter case each file is saved under the
            directory using the resource id as the filename.

    Raises:
        ValueError: If `save_path` is a list whose length differs from `resource_ids`.
    """
    if isinstance(save_path, str):
        # Directory given: derive one output file per resource id.
        save_path = [os.path.join(save_path, r) for r in resource_ids]
    elif len(save_path) != len(resource_ids):
        # Previously a length mismatch silently truncated the downloads via zip().
        raise ValueError("save_path must have the same length as resource_ids.")

    async def _download_all_async():
        async with aiohttp.ClientSession() as session:
            tasks = [
                self._async_download_file(resource_id, save_path=path, session=session, progress_bar=progress_bar)
                for resource_id, path in zip(resource_ids, save_path)
            ]
            await asyncio.gather(*tasks)

    with tqdm(total=len(resource_ids), desc="Downloading resources", unit="file") as progress_bar:
        loop = asyncio.get_event_loop()
        loop.run_until_complete(_download_all_async())
|
|
925
|
+
|
|
827
926
|
def download_resource_file(self,
|
|
828
927
|
resource_id: str,
|
|
829
928
|
save_path: Optional[str] = None,
|
|
@@ -982,6 +1081,7 @@ class RootAPIHandler(BaseAPIHandler):
|
|
|
982
1081
|
response = self._run_request(request_params)
|
|
983
1082
|
return response.json()['data']
|
|
984
1083
|
|
|
1084
|
+
@deprecated(version='1.7')
|
|
985
1085
|
def get_datasetsinfo_by_name(self, dataset_name: str) -> list[dict]:
|
|
986
1086
|
request_params = {
|
|
987
1087
|
'method': 'GET',
|
|
@@ -1076,6 +1176,30 @@ class RootAPIHandler(BaseAPIHandler):
|
|
|
1076
1176
|
}
|
|
1077
1177
|
return self._run_request(request_params).json()['data']
|
|
1078
1178
|
|
|
1179
|
+
def get_project_resources(self, project_id: str) -> list[dict]:
    """
    List all resources belonging to a project.

    Args:
        project_id (str): The project id.

    Returns:
        list[dict]: The list of resources in the project.

    Raises:
        ResourceNotFoundError: If the project does not exists.
    """
    try:
        response = self._run_request({
            'method': 'GET',
            'url': f'{self.root_url}/projects/{project_id}/resources'
        })
    except HTTPError as e:
        # An HTTP 500 here is surfaced as "project not found" — presumably the
        # backend's response for an unknown project id (TODO confirm with API spec).
        if e.response is not None and e.response.status_code == 500:
            raise ResourceNotFoundError('project', {'project_id': project_id})
        raise e
    return response.json()
|
|
1202
|
+
|
|
1079
1203
|
def create_project(self,
|
|
1080
1204
|
name: str,
|
|
1081
1205
|
description: str,
|