datamint 2.3.3__py3-none-any.whl → 2.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59) hide show
  1. datamint/__init__.py +1 -3
  2. datamint/api/__init__.py +0 -3
  3. datamint/api/base_api.py +286 -54
  4. datamint/api/client.py +76 -13
  5. datamint/api/endpoints/__init__.py +2 -2
  6. datamint/api/endpoints/annotations_api.py +186 -28
  7. datamint/api/endpoints/deploy_model_api.py +78 -0
  8. datamint/api/endpoints/models_api.py +1 -0
  9. datamint/api/endpoints/projects_api.py +38 -7
  10. datamint/api/endpoints/resources_api.py +227 -100
  11. datamint/api/entity_base_api.py +66 -7
  12. datamint/apihandler/base_api_handler.py +0 -1
  13. datamint/apihandler/dto/annotation_dto.py +2 -0
  14. datamint/client_cmd_tools/datamint_config.py +0 -1
  15. datamint/client_cmd_tools/datamint_upload.py +3 -1
  16. datamint/configs.py +11 -7
  17. datamint/dataset/base_dataset.py +24 -4
  18. datamint/dataset/dataset.py +1 -1
  19. datamint/entities/__init__.py +1 -1
  20. datamint/entities/annotations/__init__.py +13 -0
  21. datamint/entities/{annotation.py → annotations/annotation.py} +81 -47
  22. datamint/entities/annotations/image_classification.py +12 -0
  23. datamint/entities/annotations/image_segmentation.py +252 -0
  24. datamint/entities/annotations/volume_segmentation.py +273 -0
  25. datamint/entities/base_entity.py +100 -6
  26. datamint/entities/cache_manager.py +129 -15
  27. datamint/entities/datasetinfo.py +60 -65
  28. datamint/entities/deployjob.py +18 -0
  29. datamint/entities/project.py +39 -0
  30. datamint/entities/resource.py +310 -46
  31. datamint/lightning/__init__.py +1 -0
  32. datamint/lightning/datamintdatamodule.py +103 -0
  33. datamint/mlflow/__init__.py +65 -0
  34. datamint/mlflow/artifact/__init__.py +1 -0
  35. datamint/mlflow/artifact/datamint_artifacts_repo.py +8 -0
  36. datamint/mlflow/env_utils.py +131 -0
  37. datamint/mlflow/env_vars.py +5 -0
  38. datamint/mlflow/flavors/__init__.py +17 -0
  39. datamint/mlflow/flavors/datamint_flavor.py +150 -0
  40. datamint/mlflow/flavors/model.py +877 -0
  41. datamint/mlflow/lightning/callbacks/__init__.py +1 -0
  42. datamint/mlflow/lightning/callbacks/modelcheckpoint.py +410 -0
  43. datamint/mlflow/models/__init__.py +93 -0
  44. datamint/mlflow/tracking/datamint_store.py +76 -0
  45. datamint/mlflow/tracking/default_experiment.py +27 -0
  46. datamint/mlflow/tracking/fluent.py +91 -0
  47. datamint/utils/env.py +27 -0
  48. datamint/utils/visualization.py +21 -13
  49. datamint-2.9.0.dist-info/METADATA +220 -0
  50. datamint-2.9.0.dist-info/RECORD +73 -0
  51. {datamint-2.3.3.dist-info → datamint-2.9.0.dist-info}/WHEEL +1 -1
  52. datamint-2.9.0.dist-info/entry_points.txt +18 -0
  53. datamint/apihandler/exp_api_handler.py +0 -204
  54. datamint/experiment/__init__.py +0 -1
  55. datamint/experiment/_patcher.py +0 -570
  56. datamint/experiment/experiment.py +0 -1049
  57. datamint-2.3.3.dist-info/METADATA +0 -125
  58. datamint-2.3.3.dist-info/RECORD +0 -54
  59. datamint-2.3.3.dist-info/entry_points.txt +0 -4
@@ -1,10 +1,11 @@
1
- from typing import Any, Optional, Sequence, TypeAlias, Literal, IO
1
+ from typing import TypeAlias, Literal, IO
2
+ from collections.abc import Sequence
2
3
  from ..base_api import ApiConfig, BaseApi
3
4
  from ..entity_base_api import CreatableEntityApi, DeletableEntityApi
4
5
  from datamint.entities.resource import Resource
5
- from datamint.entities.project import Project
6
- from datamint.entities.annotation import Annotation
6
+ from datamint.entities.annotations.annotation import Annotation
7
7
  from datamint.exceptions import DatamintException, ResourceNotFoundError
8
+ from datamint.api.dto import AnnotationType
8
9
  import httpx
9
10
  from datetime import date
10
11
  import json
@@ -21,10 +22,10 @@ from tqdm.auto import tqdm
21
22
  import asyncio
22
23
  import aiohttp
23
24
  from pathlib import Path
24
- import nest_asyncio # For running asyncio in jupyter notebooks
25
25
  from PIL import Image
26
26
  import io
27
27
  from datamint.types import ImagingData
28
+ from collections import defaultdict
28
29
 
29
30
 
30
31
  _LOGGER = logging.getLogger(__name__)
@@ -52,9 +53,11 @@ def _open_io(file_path: str | Path | IO, mode: str = 'rb') -> IO:
52
53
  class ResourcesApi(CreatableEntityApi[Resource], DeletableEntityApi[Resource]):
53
54
  """API handler for resource-related endpoints."""
54
55
 
56
+ _ENDPOINT_BASE = 'resources'
57
+
55
58
  def __init__(self,
56
59
  config: ApiConfig,
57
- client: Optional[httpx.Client] = None,
60
+ client: httpx.Client | None = None,
58
61
  annotations_api=None,
59
62
  projects_api=None
60
63
  ) -> None:
@@ -66,25 +69,24 @@ class ResourcesApi(CreatableEntityApi[Resource], DeletableEntityApi[Resource]):
66
69
  """
67
70
  from .annotations_api import AnnotationsApi
68
71
  from .projects_api import ProjectsApi
69
- super().__init__(config, Resource, 'resources', client)
70
- nest_asyncio.apply()
72
+ super().__init__(config, Resource, ResourcesApi._ENDPOINT_BASE, client)
71
73
  self.annotations_api = AnnotationsApi(
72
74
  config, client, resources_api=self) if annotations_api is None else annotations_api
73
- self.projects_api = ProjectsApi(config, client) if projects_api is None else projects_api
75
+ self.projects_api = projects_api or ProjectsApi(config, client, resources_api=self)
74
76
 
75
77
  def get_list(self,
76
- status: Optional[ResourceStatus] = None,
78
+ status: ResourceStatus | None = None,
77
79
  from_date: date | str | None = None,
78
80
  to_date: date | str | None = None,
79
- tags: Optional[Sequence[str]] = None,
80
- modality: Optional[str] = None,
81
- mimetype: Optional[str] = None,
81
+ tags: Sequence[str] | None = None,
82
+ modality: str | None = None,
83
+ mimetype: str | None = None,
82
84
  # return_ids_only: bool = False,
83
- order_field: Optional[ResourceFields] = None,
84
- order_ascending: Optional[bool] = None,
85
- channel: Optional[str] = None,
85
+ order_field: ResourceFields | None = None,
86
+ order_ascending: bool | None = None,
87
+ channel: str | None = None,
86
88
  project_name: str | list[str] | None = None,
87
- filename: Optional[str] = None,
89
+ filename: str | None = None,
88
90
  limit: int | None = None
89
91
  ) -> Sequence[Resource]:
90
92
  """Get resources with optional filtering.
@@ -146,7 +148,8 @@ class ResourcesApi(CreatableEntityApi[Resource], DeletableEntityApi[Resource]):
146
148
 
147
149
  return super().get_list(limit=limit, params=payload)
148
150
 
149
- def get_annotations(self, resource: str | Resource) -> Sequence[Annotation]:
151
+ def get_annotations(self, resource: str | Resource,
152
+ annotation_type: AnnotationType | str | None = None) -> Sequence[Annotation]:
150
153
  """Get annotations for a specific resource.
151
154
 
152
155
  Args:
@@ -155,7 +158,9 @@ class ResourcesApi(CreatableEntityApi[Resource], DeletableEntityApi[Resource]):
155
158
  Returns:
156
159
  A sequence of Annotation objects associated with the specified resource.
157
160
  """
158
- return self.annotations_api.get_list(resource=resource)
161
+ return self.annotations_api.get_list(resource=resource,
162
+ load_ai_segmentations=True,
163
+ annotation_type=annotation_type)
159
164
 
160
165
  @staticmethod
161
166
  def __process_files_parameter(file_path: str | Sequence[str | IO | pydicom.Dataset]
@@ -229,16 +234,16 @@ class ResourcesApi(CreatableEntityApi[Resource], DeletableEntityApi[Resource]):
229
234
 
230
235
  async def _upload_single_resource_async(self,
231
236
  file_path: str | IO,
232
- mimetype: Optional[str] = None,
237
+ mimetype: str | None = None,
233
238
  anonymize: bool = False,
234
239
  anonymize_retain_codes: Sequence[tuple] = [],
235
240
  tags: list[str] = [],
236
241
  mung_filename: Sequence[int] | Literal['all'] | None = None,
237
- channel: Optional[str] = None,
242
+ channel: str | None = None,
238
243
  session=None,
239
- modality: Optional[str] = None,
244
+ modality: str | None = None,
240
245
  publish: bool = False,
241
- metadata_file: Optional[str | dict] = None,
246
+ metadata_file: str | dict | None = None,
242
247
  ) -> str:
243
248
  if is_io_object(file_path):
244
249
  source_filepath = os.path.abspath(os.path.expanduser(file_path.name))
@@ -263,8 +268,8 @@ class ResourcesApi(CreatableEntityApi[Resource], DeletableEntityApi[Resource]):
263
268
  filename = new_filename
264
269
 
265
270
  is_a_dicom_file = None
266
- if mimetype is None:
267
- mimetype_list, ext = guess_typez(file_path, use_magic=True)
271
+ if not mimetype or mimetype == DEFAULT_MIME_TYPE:
272
+ mimetype_list, ext = guess_typez(file_path, use_magic=False)
268
273
  for mime in mimetype_list:
269
274
  if mime in NIFTI_MIMES:
270
275
  mimetype = DEFAULT_NIFTI_MIME
@@ -273,7 +278,11 @@ class ResourcesApi(CreatableEntityApi[Resource], DeletableEntityApi[Resource]):
273
278
  if ext == '.nii.gz' or filename.lower().endswith('nii.gz'):
274
279
  mimetype = DEFAULT_NIFTI_MIME
275
280
  else:
276
- mimetype = mimetype_list[-1] if mimetype_list else DEFAULT_MIME_TYPE
281
+ mimetype = mimetype_list[-1] if mimetype_list and mimetype_list[-1] else DEFAULT_MIME_TYPE
282
+ if (not mimetype) or mimetype == DEFAULT_MIME_TYPE:
283
+ msg = f"Could not determine mimetype for file {source_filepath}."
284
+ _LOGGER.warning(msg)
285
+ _USER_LOGGER.warning(msg)
277
286
 
278
287
  mimetype = standardize_mimetype(mimetype)
279
288
 
@@ -349,9 +358,12 @@ class ResourcesApi(CreatableEntityApi[Resource], DeletableEntityApi[Resource]):
349
358
  except Exception as e:
350
359
  _LOGGER.warning(f"Failed to add metadata to form: {e}")
351
360
 
361
+ timeout = aiohttp.ClientTimeout(total=300, connect=60, sock_read=300)
352
362
  resp_data = await self._make_request_async_json('POST',
353
363
  endpoint=self.endpoint_base,
354
- data=form)
364
+ data=form,
365
+ session=session,
366
+ timeout=timeout)
355
367
  if 'error' in resp_data:
356
368
  raise DatamintException(resp_data['error'])
357
369
  _LOGGER.debug(f"Response on uploading {filename}: {resp_data}")
@@ -367,20 +379,33 @@ class ResourcesApi(CreatableEntityApi[Resource], DeletableEntityApi[Resource]):
367
379
 
368
380
  async def _upload_resources_async(self,
369
381
  files_path: Sequence[str | IO],
370
- mimetype: Optional[str] = None,
382
+ mimetype: str | None = None,
371
383
  anonymize: bool = False,
372
384
  anonymize_retain_codes: Sequence[tuple] = [],
373
385
  on_error: Literal['raise', 'skip'] = 'raise',
374
386
  tags=None,
375
387
  mung_filename: Sequence[int] | Literal['all'] | None = None,
376
- channel: Optional[str] = None,
377
- modality: Optional[str] = None,
388
+ channel: str | None = None,
389
+ modality: str | None = None,
378
390
  publish: bool = False,
379
391
  segmentation_files: Sequence[dict] | None = None,
380
392
  transpose_segmentation: bool = False,
381
393
  metadata_files: Sequence[str | dict | None] | None = None,
382
394
  progress_bar: tqdm | None = None,
395
+ session: aiohttp.ClientSession | None = None,
383
396
  ) -> list[str]:
397
+ """Upload multiple resources asynchronously.
398
+
399
+ Args:
400
+ session: Optional aiohttp session. If None, uses a shared session managed by
401
+ this API instance. Callers should NOT close a session they pass in until
402
+ all async operations complete, as nested calls may reuse it.
403
+
404
+ Note:
405
+ Session ownership: When session=None (default), a long-lived shared session
406
+ is automatically used and managed. When explicitly passed, the caller retains
407
+ ownership and must ensure it remains open for the duration of this call.
408
+ """
384
409
  if on_error not in ['raise', 'skip']:
385
410
  raise ValueError("on_error must be either 'raise' or 'skip'")
386
411
 
@@ -390,52 +415,61 @@ class ResourcesApi(CreatableEntityApi[Resource], DeletableEntityApi[Resource]):
390
415
  if metadata_files is None:
391
416
  metadata_files = _infinite_gen(None)
392
417
 
393
- async with aiohttp.ClientSession() as session:
394
- async def __upload_single_resource(file_path, segfiles: dict[str, list | dict],
395
- metadata_file: str | dict | None):
396
- name = file_path.name if is_io_object(file_path) else file_path
397
- name = os.path.basename(name)
398
- rid = await self._upload_single_resource_async(
399
- file_path=file_path,
400
- mimetype=mimetype,
401
- anonymize=anonymize,
402
- anonymize_retain_codes=anonymize_retain_codes,
403
- tags=tags,
404
- session=session,
405
- mung_filename=mung_filename,
406
- channel=channel,
407
- modality=modality,
408
- publish=publish,
409
- metadata_file=metadata_file,
410
- )
411
- if progress_bar:
412
- progress_bar.update(1)
413
- progress_bar.set_postfix(file=name)
414
- else:
415
- _USER_LOGGER.info(f'"{name}" uploaded')
416
-
417
- if segfiles is not None:
418
- fpaths = segfiles['files']
419
- names = segfiles.get('names', _infinite_gen(None))
420
- if isinstance(names, dict):
421
- names = _infinite_gen(names)
422
- frame_indices = segfiles.get('frame_index', _infinite_gen(None))
423
- for f, name, frame_index in tqdm(zip(fpaths, names, frame_indices),
424
- desc=f"Uploading segmentations for {file_path}",
425
- total=len(fpaths)):
426
- if f is not None:
427
- await self.annotations_api._upload_segmentations_async(
428
- rid,
429
- file_path=f,
430
- name=name,
431
- frame_index=frame_index,
432
- transpose_segmentation=transpose_segmentation
433
- )
434
- return rid
418
+ # Use shared session by default for stability across multiple sequential calls.
419
+ # This prevents connection churn and intermittent SSL shutdown timeouts.
420
+ session = session or self._get_aiohttp_session()
421
+
422
+ async def __upload_single_resource(file_path, segfiles: dict[str, list | dict],
423
+ metadata_file: str | dict | None):
424
+ name = file_path.name if is_io_object(file_path) else file_path
425
+ name = os.path.basename(name)
426
+ rid = await self._upload_single_resource_async(
427
+ file_path=file_path,
428
+ mimetype=mimetype,
429
+ anonymize=anonymize,
430
+ anonymize_retain_codes=anonymize_retain_codes,
431
+ tags=tags,
432
+ session=session,
433
+ mung_filename=mung_filename,
434
+ channel=channel,
435
+ modality=modality,
436
+ publish=publish,
437
+ metadata_file=metadata_file,
438
+ )
439
+ if progress_bar:
440
+ progress_bar.update(1)
441
+ progress_bar.set_postfix(file=name)
442
+ else:
443
+ _USER_LOGGER.info(f'"{name}" uploaded')
444
+
445
+ if segfiles is not None:
446
+ fpaths = segfiles['files']
447
+ names = segfiles.get('names', _infinite_gen(None))
448
+ if isinstance(names, dict):
449
+ names = _infinite_gen(names)
450
+ frame_indices = segfiles.get('frame_index', _infinite_gen(None))
451
+ for f, name, frame_index in tqdm(zip(fpaths, names, frame_indices),
452
+ desc=f"Uploading segmentations for {file_path}",
453
+ total=len(fpaths)):
454
+ if f is not None:
455
+ await self.annotations_api._upload_segmentations_async(
456
+ rid,
457
+ file_path=f,
458
+ name=name,
459
+ frame_index=frame_index,
460
+ transpose_segmentation=transpose_segmentation
461
+ )
462
+ return rid
435
463
 
464
+ try:
436
465
  tasks = [__upload_single_resource(f, segfiles, metadata_file)
437
466
  for f, segfiles, metadata_file in zip(files_path, segmentation_files, metadata_files)]
438
- return await asyncio.gather(*tasks, return_exceptions=on_error == 'skip')
467
+ except ValueError:
468
+ msg = f"Error preparing upload tasks. Try `assemble_dicom=False`."
469
+ _LOGGER.error(msg)
470
+ _USER_LOGGER.error(msg)
471
+ raise
472
+ return await asyncio.gather(*tasks, return_exceptions=on_error == 'skip')
439
473
 
440
474
  def upload_resources(self,
441
475
  files_path: Sequence[str | IO | pydicom.Dataset],
@@ -503,6 +537,18 @@ class ResourcesApi(CreatableEntityApi[Resource], DeletableEntityApi[Resource]):
503
537
  raise ValueError(
504
538
  "upload_resources() only accepts multiple resources. For single resource upload, use upload_resource() instead.")
505
539
 
540
+ if publish_to:
541
+ publish = True
542
+ # Check if project exists
543
+ proj = self.projects_api.get_by_name(publish_to)
544
+ if proj is None:
545
+ try:
546
+ proj = self.projects_api.get_by_id(publish_to)
547
+ except Exception:
548
+ pass
549
+ if proj is None:
550
+ raise ResourceNotFoundError('Project', {'name_or_id': publish_to})
551
+
506
552
  files_path = ResourcesApi.__process_files_parameter(files_path)
507
553
 
508
554
  # Discard DICOM reports
@@ -682,7 +728,7 @@ class ResourcesApi(CreatableEntityApi[Resource], DeletableEntityApi[Resource]):
682
728
  )
683
729
  """
684
730
  # Convert segmentation_files to the format expected by upload_resources
685
- segmentation_files_list: Optional[list[list[str] | dict]] = None
731
+ segmentation_files_list: list[list[str] | dict] | None = None
686
732
  if segmentation_files is not None:
687
733
  segmentation_files_list = [segmentation_files]
688
734
 
@@ -737,22 +783,32 @@ class ResourcesApi(CreatableEntityApi[Resource], DeletableEntityApi[Resource]):
737
783
  """
738
784
  save_path = str(save_path) # Ensure save_path is a string for file operations
739
785
  resource_id = self._entid(resource)
786
+
787
+ # Disable total timeout for large file downloads, keep connect timeout
788
+ timeout = aiohttp.ClientTimeout(total=None, sock_connect=self.config.timeout)
789
+
740
790
  try:
741
791
  async with self._make_request_async('GET',
742
792
  f'{self.endpoint_base}/{resource_id}/file',
743
793
  session=session,
744
- headers={'accept': 'application/octet-stream'}) as resp:
745
- data_bytes = await resp.read()
794
+ headers={'accept': 'application/octet-stream'},
795
+ timeout=timeout) as resp:
796
+
797
+ final_save_path = save_path
798
+ target_path = final_save_path
799
+ if add_extension:
800
+ target_path = f"{save_path}.tmp"
801
+
802
+ with open(target_path, 'wb') as f:
803
+ async for chunk in resp.content.iter_chunked(1024 * 1024): # 1MB chunks
804
+ f.write(chunk)
746
805
 
747
- final_save_path = save_path
748
806
  if add_extension:
749
- # Save to temporary file first to determine mimetype from content
750
- temp_path = f"{save_path}.tmp"
751
- with open(temp_path, 'wb') as f:
752
- f.write(data_bytes)
807
+ # Determine mimetype from file content (read first 2KB)
808
+ with open(target_path, 'rb') as f:
809
+ head_content = f.read(2048)
753
810
 
754
- # Determine mimetype from file content
755
- mimetype, ext = BaseApi._determine_mimetype(content=data_bytes,
811
+ mimetype, ext = BaseApi._determine_mimetype(content=head_content,
756
812
  declared_mimetype=resource.mimetype if isinstance(resource, Resource) else None)
757
813
 
758
814
  # Generate final path with extension if needed
@@ -763,11 +819,7 @@ class ResourcesApi(CreatableEntityApi[Resource], DeletableEntityApi[Resource]):
763
819
  final_save_path = save_path + ext
764
820
 
765
821
  # Move file to final location
766
- os.rename(temp_path, final_save_path)
767
- else:
768
- # Standard save without extension detection
769
- with open(final_save_path, 'wb') as f:
770
- f.write(data_bytes)
822
+ os.rename(target_path, final_save_path)
771
823
 
772
824
  if progress_bar:
773
825
  progress_bar.update(1)
@@ -801,7 +853,9 @@ class ResourcesApi(CreatableEntityApi[Resource], DeletableEntityApi[Resource]):
801
853
  raise ValueError("resources must be a list of resources")
802
854
 
803
855
  async def _download_all_async():
804
- async with aiohttp.ClientSession() as session:
856
+ connector = self._create_aiohttp_connector(force_close=True)
857
+ timeout = aiohttp.ClientTimeout(total=None, connect=60, sock_read=600)
858
+ async with aiohttp.ClientSession(connector=connector, timeout=timeout) as session:
805
859
  tasks = [
806
860
  self._async_download_file(
807
861
  resource=r,
@@ -838,7 +892,7 @@ class ResourcesApi(CreatableEntityApi[Resource], DeletableEntityApi[Resource]):
838
892
 
839
893
  def download_resource_file(self,
840
894
  resource: str | Resource,
841
- save_path: Optional[str] = None,
895
+ save_path: str | None = None,
842
896
  auto_convert: bool = True,
843
897
  add_extension: bool = False
844
898
  ) -> ImagingData | tuple[ImagingData, str] | bytes:
@@ -916,6 +970,53 @@ class ResourcesApi(CreatableEntityApi[Resource], DeletableEntityApi[Resource]):
916
970
  return resource_file, save_path
917
971
  return resource_file
918
972
 
973
def cache_resources(
    self,
    resources: Sequence[Resource],
    progress_bar: bool = True,
) -> None:
    """Cache the file data of several resources, skipping already cached ones.

    Resources whose ``is_cached()`` returns True are left untouched. Each
    remaining resource is fetched one after another (sequentially, not in
    parallel) by calling
    ``fetch_file_data(auto_convert=False, use_cache=True)`` on it.

    Args:
        resources: Sequence of Resource instances to cache.
        progress_bar: Whether to show a progress bar. Default is True.

    Raises:
        Exception: Anything raised by ``fetch_file_data`` is propagated.
            The progress bar is closed first and no success message is
            logged in that case.

    Example:
        >>> resources = api.resources.get_list(limit=100)
        >>> api.resources.cache_resources(resources)
        Caching resources: 100%|██████████| 85/85 [00:12<00:00, 6.8files/s]
    """
    # Filter out already cached resources
    resources_to_cache = [res for res in resources if not res.is_cached()]

    if not resources_to_cache:
        _LOGGER.info("All resources are already cached.")
        return

    _LOGGER.info(f"Caching {len(resources_to_cache)} of {len(resources)} resources...")

    pbar = (tqdm(total=len(resources_to_cache), desc="Caching resources", unit="file")
            if progress_bar else None)

    try:
        for res in resources_to_cache:
            res._api = self  # Ensure the resource has a reference to the API
            res.fetch_file_data(auto_convert=False, use_cache=True)
            if pbar is not None:
                pbar.set_postfix(filename=res.filename)
                pbar.update(1)
    finally:
        # Always release the progress bar, even when a fetch fails.
        if pbar is not None:
            pbar.close()

    # Reached only when every fetch succeeded; a failure propagates above.
    _LOGGER.info(f"Successfully cached {len(resources_to_cache)} resources.")
1019
+
919
1020
  def download_resource_frame(self,
920
1021
  resource: str | Resource,
921
1022
  frame_index: int) -> Image.Image:
@@ -940,7 +1041,7 @@ class ResourcesApi(CreatableEntityApi[Resource], DeletableEntityApi[Resource]):
940
1041
  resource = self.get_by_id(resource)
941
1042
  if resource.mimetype.startswith('image/') or resource.storage == 'ImageResource':
942
1043
  if frame_index != 0:
943
- raise DatamintException(f"Resource {resource.id} is a single frame image, "
1044
+ raise DatamintException(f"Resource {resource.id} is not a multi-frame resource, "
944
1045
  f"but frame_index is {frame_index}.")
945
1046
  return self.download_resource_file(resource, auto_convert=True)
946
1047
 
@@ -985,22 +1086,30 @@ class ResourcesApi(CreatableEntityApi[Resource], DeletableEntityApi[Resource]):
985
1086
  raise
986
1087
 
987
1088
  def set_tags(self,
988
- resource: str | Resource,
1089
+ resource: str | Resource | Sequence[str | Resource],
989
1090
  tags: Sequence[str],
990
1091
  ):
991
1092
  """
992
1093
  Set tags for a resource, IMPORTANT: This replaces all existing tags.
993
1094
  Args:
994
- resource: The resource unique id or Resource object.
1095
+ resource: The resource object or a list of resources.
995
1096
  tags: The tags to set.
996
1097
  """
997
- data = {'tags': tags}
998
- resource_id = self._entid(resource)
999
1098
 
1000
- response = self._make_entity_request('PUT',
1001
- resource_id,
1002
- add_path='tags',
1003
- json=data)
1099
+ uniq_tags = set(tags) # remove duplicates
1100
+
1101
+ if isinstance(resource, Sequence):
1102
+ resource_ids = [self._entid(res) for res in resource]
1103
+ response = self._make_request('PUT',
1104
+ f'{self.endpoint_base}/tags',
1105
+ json={'resource_ids': resource_ids,
1106
+ 'tags': list(uniq_tags)})
1107
+ else:
1108
+ resource_id = self._entid(resource)
1109
+ response = self._make_entity_request('PUT',
1110
+ resource_id,
1111
+ add_path='tags',
1112
+ json={'tags': list(uniq_tags)})
1004
1113
  return response
1005
1114
 
1006
1115
  # def get_projects(self, resource: Resource) -> Sequence[Project]:
@@ -1018,7 +1127,7 @@ class ResourcesApi(CreatableEntityApi[Resource], DeletableEntityApi[Resource]):
1018
1127
  # return [proj for proj in self.projects_api.get_all() if proj.id in proj_ids]
1019
1128
 
1020
1129
  def add_tags(self,
1021
- resource: str | Resource,
1130
+ resource: str | Resource | Sequence[str | Resource],
1022
1131
  tags: Sequence[str],
1023
1132
  ):
1024
1133
  """
@@ -1029,8 +1138,26 @@ class ResourcesApi(CreatableEntityApi[Resource], DeletableEntityApi[Resource]):
1029
1138
  """
1030
1139
  if isinstance(resource, str):
1031
1140
  resource = self.get_by_id(resource)
1141
+ elif isinstance(resource, Sequence):
1142
+ # Transform every str to Resource first.
1143
+ resources = [self.get_by_id(res) if isinstance(res, str) else res for res in resource]
1144
+
1145
+ # group resource having the exact same tags to minimize requests
1146
+ tag_map: dict[tuple, list[Resource]] = defaultdict(list)
1147
+ for res in resources:
1148
+ old_tags = res.tags if res.tags is not None else []
1149
+ # key = tuple(sorted(old_tags))
1150
+ key = tuple(old_tags) # keep order, assuming order matters for tags
1151
+ tag_map[key].append(res)
1152
+
1153
+ # finally, set tags for each group
1154
+ for old_tags_tuple, res_group in tag_map.items():
1155
+ old_tags = list(old_tags_tuple)
1156
+ self.set_tags(res_group, old_tags + list(tags))
1157
+ return
1158
+
1032
1159
  old_tags = resource.tags if resource.tags is not None else []
1033
- return self.set_tags(resource, old_tags + list(tags))
1160
+ self.set_tags(resource, old_tags + list(tags))
1034
1161
 
1035
1162
  def bulk_delete(self, entities: Sequence[str | Resource]) -> None:
1036
1163
  """Delete multiple entities. Faster than deleting them one by one.