dataroom-client 0.0.0.post1.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1443 @@
1
+ import asyncio
2
+ from datetime import datetime
3
+ import json as json_module
4
+ import logging
5
+ import os
6
+ import uuid
7
+ from enum import Enum
8
+ from io import BytesIO
9
+ import mimetypes
10
+ from typing import List, TypedDict, Optional
11
+ from urllib.parse import urljoin
12
+ import httpx
13
+
14
+
15
# Register WebP explicitly in case the platform's mimetypes table lacks the ".webp" mapping,
# so DataRoomFile.from_path() can guess a content type for .webp files.
mimetypes.add_type("image/webp", ".webp")
logger: logging.Logger = logging.getLogger(__name__)
17
+
18
+
19
class DataRoomError(Exception):
    """Base exception class for DataRoomClient errors.

    An optional ``response`` keyword argument (the HTTP response that caused the
    error) is stored on the instance. When present, its ``status_code`` and
    ``text`` are appended to the string form of the exception.
    """

    def __init__(self, *args, **kwargs):
        # Pop before delegating: Exception.__init__ accepts no keyword arguments.
        self.response = kwargs.pop("response", None)
        super().__init__(*args, **kwargs)

    def __str__(self):
        message = super().__str__()
        # Compare against None explicitly: some HTTP libraries' response objects
        # define __bool__ (e.g. "status is OK"), which would hide error details.
        if self.response is None:
            return message
        return f"{message}\n{self.response.status_code}\n{self.response.text}"
31
+
32
+
33
class DataRoomFile:
    """A wrapper for a file-like object that can be used with DataRoomClient.

    Attributes:
        bytes_io: the file-like object holding the content.
        content_type: MIME type of the content, or None if unknown.
        extension: file extension including the leading dot, or "" when it
            cannot be determined.
        name: random hex identifier (uuid4) used as the upload file stem.
        filename: ``name`` followed by ``extension``.
        path: filesystem path the content was read from, if any.
    """

    def __init__(self, bytes_io, content_type, path=None, extension=None):
        if extension is None:
            # Derive the extension from the MIME type. guess_extension() cannot take
            # None and returns None for unknown types; both cases fall back to "".
            guessed = mimetypes.guess_extension(content_type) if content_type else None
            extension = guessed or ""
        # Normalise "png" -> ".png", but leave an empty extension empty
        # (indexing extension[0] on "" would raise IndexError).
        if extension and not extension.startswith("."):
            extension = f".{extension}"
        self.bytes_io = bytes_io
        self.content_type = content_type
        self.extension = extension
        self.name = uuid.uuid4().hex
        self.filename = f"{self.name}{extension}"
        self.path = path

    @classmethod
    def from_path(cls, path: str):
        """Read the file at ``path`` into memory, guessing its content type from the name.

        Raises:
            DataRoomError: if mimetypes cannot guess a content type for ``path``.
        """
        content_type, _encoding = mimetypes.guess_type(path)
        if not content_type:
            raise DataRoomError(f"Could not guess content type for file: {path}")
        with open(path, "rb") as f:
            data = f.read()
        return cls(
            bytes_io=BytesIO(data),
            content_type=content_type,
            path=path,
        )

    @classmethod
    def from_bytesio(cls, bytes_io, extension):
        """Wrap an in-memory buffer; the content type is left unset (None).

        Raises:
            AssertionError: if ``extension`` is None.
        """
        assert extension is not None, "Please provide a file extension"
        return cls(
            bytes_io=bytes_io,
            extension=extension,
            content_type=None,
            path=None,
        )
70
+
71
+
72
class ClientDuplicateState(Enum):
    """Duplicate-state filter value; the client sends ``member.value`` as a query parameter."""

    # NOTE(review): the string 'None' (not the None object) is presumably how the
    # server encodes "not yet processed" — confirm against the API.
    UNPROCESSED = 'None'
    ORIGINAL = 1
    DUPLICATE = 2
76
+
77
+
78
class LatentType(TypedDict, total=False):
    """One latent to attach to an image: its type name plus the file holding the data."""

    latent_type: str  # name of the latent kind
    file: DataRoomFile  # uploaded as a multipart file part
81
+
82
+
83
class ImageUpdate(TypedDict, total=False):
    """Partial image-update payload; every key is optional (``total=False``)."""

    id: str  # noqa: A003
    source: Optional[str]
    attributes: Optional[dict]
    tags: Optional[List[str]]
    coca_embedding: Optional[str]  # string-encoded embedding vector
    related_images: Optional[dict[str, str]]
90
+
91
+
92
class ImageCreate(TypedDict, total=False):
    """One entry for a batch image-creation request.

    Declared with ``total=False``, although the batch creator rejects entries
    lacking ``id`` or ``source``, and requires ``image_file`` or ``image_url``.
    """

    id: str  # noqa: A003
    source: str
    image_file: Optional[DataRoomFile]
    image_url: Optional[str]
    attributes: Optional[dict]
    tags: Optional[list[str]]
    related_images: Optional[dict[str, str]]
100
+
101
+
102
def arg_deprecation_msg(arg_name, msg=''):
    """Build the warning text logged when a caller passes the deprecated argument ``arg_name``.

    ``msg`` is appended verbatim (e.g. a hint about the replacement argument).
    """
    template = 'DEPRECATION WARNING: Argument "{}" is deprecated, and will be removed in the future. {}'
    return template.format(arg_name, msg)
104
+
105
+
106
+ class DataRoomClient:
107
+ """
108
+ The official client of the DataRoom API. See notebooks for usage examples.
109
+ """
110
+
111
def __init__(self, api_key=None, api_url=None, timeout=120):
    """
    @param api_key: API key for DataRoom API; falls back to the DATAROOM_API_KEY env var
    @param api_url: URL of the DataRoom backend API; falls back to the DATAROOM_API_URL env var
    @param timeout: timeout passed to every httpx request made by this client
    @raises DataRoomError: if no API URL is given by argument or environment
    """
    self.api_key = api_key or os.environ.get("DATAROOM_API_KEY")
    resolved_url = api_url or os.environ.get("DATAROOM_API_URL")
    if not resolved_url:
        raise DataRoomError("DataRoom api_url is not set")
    self.api_url = resolved_url
    # One shared async HTTP client for all requests issued by this instance.
    self.client = httpx.AsyncClient()
    self.timeout = timeout
125
+
126
+ # -------------------- Private methods --------------------
127
+
128
async def _make_request(
    self, url, params=None, method="GET", json=None, files=None, headers=None,
):
    """Send an authenticated request to the DataRoom API and decode the JSON reply.

    ``url`` is resolved against ``self.api_url`` (absolute URLs, such as pagination
    "next" links, are used as-is). An ``Authorization: Token <api_key>`` header is
    added to a copy of ``headers``; the caller's dict is not modified.

    Returns:
        The decoded JSON body, or None when the response body is empty.

    Raises:
        DataRoomError: wrapping any httpx.HTTPError (transport failure or non-2xx
            status); ``.response`` holds the HTTP response when one exists.
    """
    # urljoin replaces the last path segment of a base without a trailing slash
    # ("https://host/api" + "images/" -> "https://host/images/"), so normalise it.
    base_url = self.api_url if self.api_url.endswith("/") else f"{self.api_url}/"
    absolute_url = urljoin(base_url, url)
    # Copy so that callers reusing their headers dict never see it mutated.
    request_headers = dict(headers) if headers else {}
    request_headers["Authorization"] = f"Token {self.api_key}"
    try:
        response = await self.client.request(
            method=method,
            url=absolute_url,
            params=params,
            json=json,
            files=files,
            headers=request_headers,
            timeout=self.timeout,
        )
        response.raise_for_status()
    except httpx.HTTPError as e:
        # Only status errors carry a response; transport errors do not.
        raise DataRoomError(e, response=getattr(e, "response", None)) from e
    if response.content:
        return response.json()
    return None
156
+
157
async def _make_paginated_request(
    self, url, limit=1000, params=None, method="GET", json=None, headers=None,
):
    """Collect items from a paginated endpoint into a list.

    Each page must be a dict with ``results`` (items) and ``next`` (URL of the
    following page, falsy on the last page). Pages are fetched until ``next`` is
    falsy or ``limit`` items are gathered.

    Args:
        limit: maximum number of items to return; None means all. Zero or a
            negative value returns an empty list without making a request.

    Raises:
        NotImplementedError: if a page lacks ``results`` or ``next``.
        DataRoomError: propagated from ``_make_request`` on HTTP failures.
    """
    if limit is not None and limit <= 0:
        return []
    items = []
    next_url = url
    while next_url:
        response = await self._make_request(
            next_url, params=params, method=method, json=json, headers=headers,
        )
        if "results" not in response:
            raise NotImplementedError(f'No "results" in response to {url}')
        if "next" not in response:
            raise NotImplementedError(f'No "next" in response to {url}')
        items.extend(response["results"])
        if limit is not None and len(items) >= limit:
            # A page may overshoot the limit; trim to exactly `limit` items.
            return items[:limit]
        next_url = response["next"]
    return items
178
+
179
async def _make_paginated_request_iter(
    self, url, limit=1000, params=None, method="GET", json=None, headers=None,
):
    """Async generator over the items of a paginated endpoint.

    Pages (dicts with ``results`` and ``next``) are fetched lazily, one request
    per page, until ``next`` is falsy or ``limit`` items have been yielded.

    Args:
        limit: maximum number of items to yield; None means all. Zero or a
            negative value yields nothing, matching ``_make_paginated_request``.

    Raises:
        NotImplementedError: if a page lacks ``results`` or ``next``.
        DataRoomError: propagated from ``_make_request`` on HTTP failures.
    """
    if limit is not None and limit <= 0:
        return
    next_url = url
    returned_items = 0
    while next_url:
        response = await self._make_request(
            next_url, params=params, method=method, json=json, headers=headers,
        )
        if "results" not in response:
            raise NotImplementedError(f'No "results" in response to {url}')
        if "next" not in response:
            raise NotImplementedError(f'No "next" in response to {url}')
        next_url = response["next"]
        for item in response["results"]:
            yield item
            returned_items += 1
            if limit is not None and returned_items >= limit:
                return
200
+
201
+ @staticmethod
202
+ def _dict_filter_none(d: dict):
203
+ return {k: v for k, v in d.items() if v is not None}
204
+
205
+ @staticmethod
206
+ def _get_attributes_filter(attributes: dict | None):
207
+ if not attributes:
208
+ return None
209
+ for key, val in attributes.items():
210
+ val = str(val)
211
+ if "," in key or "," in val:
212
+ raise DataRoomError(
213
+ "Commas are not allowed in attribute keys or values"
214
+ )
215
+ if ":" in key or ":" in val:
216
+ raise DataRoomError(
217
+ "Colons are not allowed in attribute keys or values"
218
+ )
219
+ attrs_str = ",".join([f"{key}:{val}" for key, val in attributes.items()])
220
+ return attrs_str
221
+
222
+ @staticmethod
223
+ def _validate_vector(vector: str):
224
+ err_msg = "Argument vector must be a string representing a list of 768 floats."
225
+ if not isinstance(vector, str) or not len(vector) > 0:
226
+ raise DataRoomError(f"{err_msg} Not a string.")
227
+ if vector[0] != "[" or vector[-1] != "]":
228
+ raise DataRoomError(f"{err_msg} Not a list.")
229
+ if len(vector[1:-1].split(',')) != 768:
230
+ raise DataRoomError(f"{err_msg} Incorrect length.")
231
+
232
+ # -------------------- Utils --------------------
233
+
234
@classmethod
async def download_image_from_url(cls, image_url: str) -> DataRoomFile:
    """Download an image over HTTP and wrap it in a DataRoomFile.

    Redirects are followed. The content type is taken from the Content-Type
    header with any MIME parameters (e.g. ``; charset=...``) stripped. If the
    header is absent, the type is guessed from the URL path.

    Raises:
        DataRoomError: on transport errors or non-success statuses, or when no
            content type can be determined.
    """
    try:
        # httpx does not follow redirects by default and raise_for_status()
        # treats 3xx as an error, which breaks images served via CDN redirects.
        async with httpx.AsyncClient(follow_redirects=True) as client:
            response = await client.get(image_url)
            response.raise_for_status()
    except httpx.HTTPError as e:
        raise DataRoomError(e, response=getattr(e, "response", None)) from e

    header_value = response.headers.get("Content-Type") or ""
    # "image/png; charset=binary" -> "image/png", so guess_extension() recognises it.
    content_type = header_value.split(";", 1)[0].strip() or None
    if content_type is None:
        from urllib.parse import urlparse
        content_type, _encoding = mimetypes.guess_type(urlparse(image_url).path)
    if not content_type:
        raise DataRoomError(f"Could not determine content type for image: {image_url}")
    return DataRoomFile(
        bytes_io=BytesIO(response.content),
        content_type=content_type,
    )
251
+
252
+ # -------------------- Image API methods --------------------
253
+
254
async def get_images(
    self,
    limit: int | None = 1000,
    page_size: int = None,
    fields: List[str] = None,
    include_fields: List[str] = None,
    exclude_fields: List[str] = None,
    all_fields: bool = False,
    return_latents: List[str] = None,
    cache_ttl: int = None,
    partitions_count: int = None,
    partition: int = None,
    # filters
    short_edge: int = None,
    short_edge__gt: int = None,
    short_edge__gte: int = None,
    short_edge__lt: int = None,
    short_edge__lte: int = None,
    pixel_count: int = None,
    pixel_count__gt: int = None,
    pixel_count__gte: int = None,
    pixel_count__lt: int = None,
    pixel_count__lte: int = None,
    aspect_ratio_fraction: str = None,
    aspect_ratio: float = None,
    aspect_ratio__gt: float = None,
    aspect_ratio__gte: float = None,
    aspect_ratio__lt: float = None,
    aspect_ratio__lte: float = None,
    source: str = None,
    sources: List[str] = None,
    sources__ne: List[str] = None,
    attributes: dict = None,
    has_attributes: list = None,
    lacks_attributes: list = None,
    has_latents: List[str] = None,
    lacks_latents: List[str] = None,
    has_masks: List[str] = None,
    lacks_masks: List[str] = None,
    tags: list = None,
    tags__ne: list = None,
    tags__all: list = None,
    tags__ne_all: list = None,
    tags__empty: bool = None,
    coca_embedding__empty: bool = None,
    duplicate_state: ClientDuplicateState = None,
    date_created__gt: datetime = None,
    date_created__gte: datetime = None,
    date_created__lt: datetime = None,
    date_created__lte: datetime = None,
    date_updated__gt: datetime = None,
    date_updated__gte: datetime = None,
    date_updated__lt: datetime = None,
    date_updated__lte: datetime = None,
):
    """Fetch images from ``images/`` matching the given filters, following pagination.

    List-valued arguments are sent comma-joined, and datetimes are sent in ISO 8601
    form. ``duplicate_state`` is sent as its enum value, and ``attributes`` as
    ``key:value`` pairs. Arguments that are None (or empty lists) are not sent.
    A truthy ``cache_ttl`` adds a ``Cache-Control: max-age=<cache_ttl>`` header.
    ``source`` is deprecated: it replaces ``sources`` with ``[source]`` and logs a
    warning.

    Returns:
        A list of at most ``limit`` image dicts (all matches when ``limit`` is None).
    """
    def csv_or_none(values):
        return ",".join(values) if values else None

    def iso_or_none(moment):
        return moment.isoformat() if moment else None

    request_headers = {"Cache-Control": f"max-age={cache_ttl}"} if cache_ttl else {}

    if source is not None:
        sources = [source]
        logger.warning(arg_deprecation_msg('source', 'Please use "sources" instead.'))

    query = {
        "fields": csv_or_none(fields),
        "include_fields": csv_or_none(include_fields),
        "exclude_fields": csv_or_none(exclude_fields),
        "all_fields": all_fields or None,
        "return_latents": csv_or_none(return_latents),
        "page_size": page_size,
        "partitions_count": partitions_count,
        "partition": partition,
        # filters
        "short_edge": short_edge,
        "short_edge__gt": short_edge__gt,
        "short_edge__gte": short_edge__gte,
        "short_edge__lt": short_edge__lt,
        "short_edge__lte": short_edge__lte,
        "pixel_count": pixel_count,
        "pixel_count__gt": pixel_count__gt,
        "pixel_count__gte": pixel_count__gte,
        "pixel_count__lt": pixel_count__lt,
        "pixel_count__lte": pixel_count__lte,
        "aspect_ratio_fraction": aspect_ratio_fraction,
        "aspect_ratio": aspect_ratio,
        "aspect_ratio__gt": aspect_ratio__gt,
        "aspect_ratio__gte": aspect_ratio__gte,
        "aspect_ratio__lt": aspect_ratio__lt,
        "aspect_ratio__lte": aspect_ratio__lte,
        "sources": csv_or_none(sources),
        "sources__ne": csv_or_none(sources__ne),
        "attributes": self._get_attributes_filter(attributes),
        "has_attributes": csv_or_none(has_attributes),
        "lacks_attributes": csv_or_none(lacks_attributes),
        "has_latents": csv_or_none(has_latents),
        "lacks_latents": csv_or_none(lacks_latents),
        "has_masks": csv_or_none(has_masks),
        "lacks_masks": csv_or_none(lacks_masks),
        "tags": csv_or_none(tags),
        "tags__ne": csv_or_none(tags__ne),
        "tags__all": csv_or_none(tags__all),
        "tags__ne_all": csv_or_none(tags__ne_all),
        "tags__empty": tags__empty,
        "coca_embedding__empty": coca_embedding__empty,
        "duplicate_state": duplicate_state.value if duplicate_state else None,
        "date_created__gt": iso_or_none(date_created__gt),
        "date_created__gte": iso_or_none(date_created__gte),
        "date_created__lt": iso_or_none(date_created__lt),
        "date_created__lte": iso_or_none(date_created__lte),
        "date_updated__gt": iso_or_none(date_updated__gt),
        "date_updated__gte": iso_or_none(date_updated__gte),
        "date_updated__lt": iso_or_none(date_updated__lt),
        "date_updated__lte": iso_or_none(date_updated__lte),
    }

    return await self._make_paginated_request(
        url="images/",
        limit=limit,
        params=self._dict_filter_none(query),
        headers=request_headers,
    )
375
+
376
async def get_images_iter(
    self,
    limit: int | None = 1000,
    page_size: int = None,
    fields: List[str] = None,
    include_fields: List[str] = None,
    exclude_fields: List[str] = None,
    all_fields: bool = False,
    return_latents: List[str] = None,
    cache_ttl: int = None,
    partitions_count: int = None,
    partition: int = None,
    # filters
    short_edge: int = None,
    short_edge__gt: int = None,
    short_edge__gte: int = None,
    short_edge__lt: int = None,
    short_edge__lte: int = None,
    pixel_count: int = None,
    pixel_count__gt: int = None,
    pixel_count__gte: int = None,
    pixel_count__lt: int = None,
    pixel_count__lte: int = None,
    aspect_ratio_fraction: str = None,
    aspect_ratio: float = None,
    aspect_ratio__gt: float = None,
    aspect_ratio__gte: float = None,
    aspect_ratio__lt: float = None,
    aspect_ratio__lte: float = None,
    source: str = None,
    sources: List[str] = None,
    sources__ne: List[str] = None,
    attributes: dict = None,
    has_attributes: list = None,
    lacks_attributes: list = None,
    has_latents: List[str] = None,
    lacks_latents: List[str] = None,
    has_masks: List[str] = None,
    lacks_masks: List[str] = None,
    tags: list = None,
    tags__ne: list = None,
    tags__all: list = None,
    tags__ne_all: list = None,
    tags__empty: bool = None,
    coca_embedding__empty: bool = None,
    duplicate_state: ClientDuplicateState = None,
    date_created__gt: datetime = None,
    date_created__gte: datetime = None,
    date_created__lt: datetime = None,
    date_created__lte: datetime = None,
    date_updated__gt: datetime = None,
    date_updated__gte: datetime = None,
    date_updated__lt: datetime = None,
    date_updated__lte: datetime = None,
):
    """Async generator over images from ``images/`` matching the given filters.

    Takes the same arguments as ``get_images``, but fetches pages lazily and yields
    image dicts one at a time, stopping after ``limit`` items (None means all).
    List arguments are comma-joined and datetimes ISO-formatted. None/empty
    arguments are omitted, and a truthy ``cache_ttl`` adds a Cache-Control header.
    The deprecated ``source`` replaces ``sources`` and logs a warning.
    """
    def csv_or_none(values):
        return ",".join(values) if values else None

    def iso_or_none(moment):
        return moment.isoformat() if moment else None

    request_headers = {"Cache-Control": f"max-age={cache_ttl}"} if cache_ttl else {}

    if source is not None:
        sources = [source]
        logger.warning(arg_deprecation_msg('source', 'Please use "sources" instead.'))

    query = {
        "fields": csv_or_none(fields),
        "include_fields": csv_or_none(include_fields),
        "exclude_fields": csv_or_none(exclude_fields),
        "all_fields": all_fields or None,
        "return_latents": csv_or_none(return_latents),
        "page_size": page_size,
        "partitions_count": partitions_count,
        "partition": partition,
        # filters
        "short_edge": short_edge,
        "short_edge__gt": short_edge__gt,
        "short_edge__gte": short_edge__gte,
        "short_edge__lt": short_edge__lt,
        "short_edge__lte": short_edge__lte,
        "pixel_count": pixel_count,
        "pixel_count__gt": pixel_count__gt,
        "pixel_count__gte": pixel_count__gte,
        "pixel_count__lt": pixel_count__lt,
        "pixel_count__lte": pixel_count__lte,
        "aspect_ratio_fraction": aspect_ratio_fraction,
        "aspect_ratio": aspect_ratio,
        "aspect_ratio__gt": aspect_ratio__gt,
        "aspect_ratio__gte": aspect_ratio__gte,
        "aspect_ratio__lt": aspect_ratio__lt,
        "aspect_ratio__lte": aspect_ratio__lte,
        "sources": csv_or_none(sources),
        "sources__ne": csv_or_none(sources__ne),
        "attributes": self._get_attributes_filter(attributes),
        "has_attributes": csv_or_none(has_attributes),
        "lacks_attributes": csv_or_none(lacks_attributes),
        "has_latents": csv_or_none(has_latents),
        "lacks_latents": csv_or_none(lacks_latents),
        "has_masks": csv_or_none(has_masks),
        "lacks_masks": csv_or_none(lacks_masks),
        "tags": csv_or_none(tags),
        "tags__ne": csv_or_none(tags__ne),
        "tags__all": csv_or_none(tags__all),
        "tags__ne_all": csv_or_none(tags__ne_all),
        "tags__empty": tags__empty,
        "coca_embedding__empty": coca_embedding__empty,
        "duplicate_state": duplicate_state.value if duplicate_state else None,
        "date_created__gt": iso_or_none(date_created__gt),
        "date_created__gte": iso_or_none(date_created__gte),
        "date_created__lt": iso_or_none(date_created__lt),
        "date_created__lte": iso_or_none(date_created__lte),
        "date_updated__gt": iso_or_none(date_updated__gt),
        "date_updated__gte": iso_or_none(date_updated__gte),
        "date_updated__lt": iso_or_none(date_updated__lt),
        "date_updated__lte": iso_or_none(date_updated__lte),
    }

    pages = self._make_paginated_request_iter(
        url="images/",
        limit=limit,
        params=self._dict_filter_none(query),
        headers=request_headers,
    )
    async for image in pages:
        yield image
498
+
499
async def get_random_images(
    self,
    limit: int | None = 1000,
    page_size: int = None,
    fields: List[str] = None,
    include_fields: List[str] = None,
    exclude_fields: List[str] = None,
    all_fields: bool = False,
    return_latents: List[str] = None,
    cache_ttl: int = None,
    prefix_length: int = None,
    num_prefixes: int = None,
    # filters
    short_edge: int = None,
    short_edge__gt: int = None,
    short_edge__gte: int = None,
    short_edge__lt: int = None,
    short_edge__lte: int = None,
    pixel_count: int = None,
    pixel_count__gt: int = None,
    pixel_count__gte: int = None,
    pixel_count__lt: int = None,
    pixel_count__lte: int = None,
    aspect_ratio_fraction: str = None,
    aspect_ratio: float = None,
    aspect_ratio__gt: float = None,
    aspect_ratio__gte: float = None,
    aspect_ratio__lt: float = None,
    aspect_ratio__lte: float = None,
    source: str = None,
    sources: List[str] = None,
    sources__ne: List[str] = None,
    attributes: dict = None,
    has_attributes: list = None,
    lacks_attributes: list = None,
    has_latents: List[str] = None,
    lacks_latents: List[str] = None,
    has_masks: List[str] = None,
    lacks_masks: List[str] = None,
    tags: list = None,
    tags__ne: list = None,
    tags__all: list = None,
    tags__ne_all: list = None,
    tags__empty: bool = None,
    coca_embedding__empty: bool = None,
    duplicate_state: ClientDuplicateState = None,
    date_created__gt: datetime = None,
    date_created__gte: datetime = None,
    date_created__lt: datetime = None,
    date_created__lte: datetime = None,
    date_updated__gt: datetime = None,
    date_updated__gte: datetime = None,
    date_updated__lt: datetime = None,
    date_updated__lte: datetime = None,
):
    """
    Get a list of random images from ``images/random/``.

    Random sampling works by filtering image_hash by a number of random hex prefixes.
    Use prefix_length and num_prefixes to tune it:
    - a smaller prefix_length gives more samples, but they are less random;
    - a higher num_prefixes gives more samples, but slows down the query.
    The documented defaults are prefix_length=5 and num_prefixes=100. They are
    applied server-side, because None arguments are not sent.

    The remaining filter arguments behave as in ``get_images``.
    """
    def csv_or_none(values):
        return ",".join(values) if values else None

    def iso_or_none(moment):
        return moment.isoformat() if moment else None

    request_headers = {"Cache-Control": f"max-age={cache_ttl}"} if cache_ttl else {}

    if source is not None:
        sources = [source]
        logger.warning(arg_deprecation_msg('source', 'Please use "sources" instead.'))

    query = {
        "fields": csv_or_none(fields),
        "include_fields": csv_or_none(include_fields),
        "exclude_fields": csv_or_none(exclude_fields),
        "all_fields": all_fields or None,
        "return_latents": csv_or_none(return_latents),
        "page_size": page_size,
        "prefix_length": prefix_length,
        "num_prefixes": num_prefixes,
        # filters
        "short_edge": short_edge,
        "short_edge__gt": short_edge__gt,
        "short_edge__gte": short_edge__gte,
        "short_edge__lt": short_edge__lt,
        "short_edge__lte": short_edge__lte,
        "pixel_count": pixel_count,
        "pixel_count__gt": pixel_count__gt,
        "pixel_count__gte": pixel_count__gte,
        "pixel_count__lt": pixel_count__lt,
        "pixel_count__lte": pixel_count__lte,
        "aspect_ratio_fraction": aspect_ratio_fraction,
        "aspect_ratio": aspect_ratio,
        "aspect_ratio__gt": aspect_ratio__gt,
        "aspect_ratio__gte": aspect_ratio__gte,
        "aspect_ratio__lt": aspect_ratio__lt,
        "aspect_ratio__lte": aspect_ratio__lte,
        "sources": csv_or_none(sources),
        "sources__ne": csv_or_none(sources__ne),
        "attributes": self._get_attributes_filter(attributes),
        "has_attributes": csv_or_none(has_attributes),
        "lacks_attributes": csv_or_none(lacks_attributes),
        "has_latents": csv_or_none(has_latents),
        "lacks_latents": csv_or_none(lacks_latents),
        "has_masks": csv_or_none(has_masks),
        "lacks_masks": csv_or_none(lacks_masks),
        "tags": csv_or_none(tags),
        "tags__ne": csv_or_none(tags__ne),
        "tags__all": csv_or_none(tags__all),
        "tags__ne_all": csv_or_none(tags__ne_all),
        "tags__empty": tags__empty,
        "coca_embedding__empty": coca_embedding__empty,
        "duplicate_state": duplicate_state.value if duplicate_state else None,
        "date_created__gt": iso_or_none(date_created__gt),
        "date_created__gte": iso_or_none(date_created__gte),
        "date_created__lt": iso_or_none(date_created__lt),
        "date_created__lte": iso_or_none(date_created__lte),
        "date_updated__gt": iso_or_none(date_updated__gt),
        "date_updated__gte": iso_or_none(date_updated__gte),
        "date_updated__lt": iso_or_none(date_updated__lt),
        "date_updated__lte": iso_or_none(date_updated__lte),
    }

    return await self._make_paginated_request(
        url="images/random/",
        limit=limit,
        params=self._dict_filter_none(query),
        headers=request_headers,
    )
628
+
629
async def count_images(
    self,
    partitions_count: int = None,
    partition: int = None,
    # filters
    short_edge: int | None = None,
    short_edge__gt: int = None,
    short_edge__gte: int = None,
    short_edge__lt: int = None,
    short_edge__lte: int = None,
    pixel_count: int | None = None,
    pixel_count__gt: int = None,
    pixel_count__gte: int = None,
    pixel_count__lt: int = None,
    pixel_count__lte: int = None,
    aspect_ratio_fraction: str = None,
    aspect_ratio: float = None,
    aspect_ratio__gt: float = None,
    aspect_ratio__gte: float = None,
    aspect_ratio__lt: float = None,
    aspect_ratio__lte: float = None,
    source: str = None,
    sources: List[str] = None,
    sources__ne: List[str] = None,
    attributes: dict = None,
    has_attributes: list = None,
    lacks_attributes: list = None,
    has_latents: List[str] = None,
    lacks_latents: List[str] = None,
    has_masks: List[str] = None,
    lacks_masks: List[str] = None,
    tags: list = None,
    tags__ne: list = None,
    tags__all: list = None,
    tags__ne_all: list = None,
    tags__empty: bool = None,
    coca_embedding__empty: bool = None,
    duplicate_state: ClientDuplicateState = None,
    date_created__gt: datetime = None,
    date_created__gte: datetime = None,
    date_created__lt: datetime = None,
    date_created__lte: datetime = None,
    date_updated__gt: datetime = None,
    date_updated__gte: datetime = None,
    date_updated__lt: datetime = None,
    date_updated__lte: datetime = None,
):
    """Return the number of images matching the filters (the ``count`` field of ``images/count/``).

    Filter arguments are encoded as in ``get_images``. The deprecated ``source``
    replaces ``sources`` and logs a warning.
    """
    def csv_or_none(values):
        return ",".join(values) if values else None

    def iso_or_none(moment):
        return moment.isoformat() if moment else None

    if source is not None:
        sources = [source]
        logger.warning(arg_deprecation_msg('source', 'Please use "sources" instead.'))

    query = {
        "partitions_count": partitions_count,
        "partition": partition,
        # filters
        "short_edge": short_edge,
        "short_edge__gt": short_edge__gt,
        "short_edge__gte": short_edge__gte,
        "short_edge__lt": short_edge__lt,
        "short_edge__lte": short_edge__lte,
        "pixel_count": pixel_count,
        "pixel_count__gt": pixel_count__gt,
        "pixel_count__gte": pixel_count__gte,
        "pixel_count__lt": pixel_count__lt,
        "pixel_count__lte": pixel_count__lte,
        "aspect_ratio_fraction": aspect_ratio_fraction,
        "aspect_ratio": aspect_ratio,
        "aspect_ratio__gt": aspect_ratio__gt,
        "aspect_ratio__gte": aspect_ratio__gte,
        "aspect_ratio__lt": aspect_ratio__lt,
        "aspect_ratio__lte": aspect_ratio__lte,
        "sources": csv_or_none(sources),
        "sources__ne": csv_or_none(sources__ne),
        "attributes": self._get_attributes_filter(attributes),
        "has_attributes": csv_or_none(has_attributes),
        "lacks_attributes": csv_or_none(lacks_attributes),
        "has_latents": csv_or_none(has_latents),
        "lacks_latents": csv_or_none(lacks_latents),
        "has_masks": csv_or_none(has_masks),
        "lacks_masks": csv_or_none(lacks_masks),
        "tags": csv_or_none(tags),
        "tags__ne": csv_or_none(tags__ne),
        "tags__all": csv_or_none(tags__all),
        "tags__ne_all": csv_or_none(tags__ne_all),
        "tags__empty": tags__empty,
        "coca_embedding__empty": coca_embedding__empty,
        "duplicate_state": duplicate_state.value if duplicate_state else None,
        "date_created__gt": iso_or_none(date_created__gt),
        "date_created__gte": iso_or_none(date_created__gte),
        "date_created__lt": iso_or_none(date_created__lt),
        "date_created__lte": iso_or_none(date_created__lte),
        "date_updated__gt": iso_or_none(date_updated__gt),
        "date_updated__gte": iso_or_none(date_updated__gte),
        "date_updated__lt": iso_or_none(date_updated__lt),
        "date_updated__lte": iso_or_none(date_updated__lte),
    }

    payload = await self._make_request(
        url="images/count/",
        params=self._dict_filter_none(query),
    )
    return payload["count"]
732
+
733
async def get_image(
    self,
    image_id: str,
    fields: List[str] = None,
    include_fields: List[str] = None,
    exclude_fields: List[str] = None,
    all_fields: bool = False,
    return_latents: List[str] = None,
):
    """Fetch a single image from ``images/<image_id>/``.

    Field-selection lists are sent comma-joined; None or empty arguments are omitted.
    Returns the decoded JSON response.
    """
    def csv_or_none(values):
        return ",".join(values) if values else None

    query = {
        "fields": csv_or_none(fields),
        "include_fields": csv_or_none(include_fields),
        "exclude_fields": csv_or_none(exclude_fields),
        "all_fields": all_fields or None,
        "return_latents": csv_or_none(return_latents),
    }
    return await self._make_request(
        url=f"images/{image_id}/",
        params=self._dict_filter_none(query),
    )
752
+
753
async def create_image(
    self,
    image_id: str = None,
    source: str = None,
    image_file: DataRoomFile = None,
    image_url: str = None,
    attributes: dict = None,
    tags: list[str] = None,
    related_images: dict[str, str] | None = None,
):
    """Create a single image record with ``POST images/``.

    Either ``image_file`` (uploaded as multipart/form-data) or ``image_url``
    (sent as JSON) must be supplied, and ``source`` is required. ``image_id``
    may only be omitted when ``image_url`` is given. None fields are left out
    of the payload.

    Returns:
        The decoded JSON response, or None if the response body is empty.

    Raises:
        DataRoomError: on missing required arguments, a non-DataRoomFile
            ``image_file``, or any HTTP failure.
    """
    has_upload = bool(image_file)
    if not has_upload and not image_url:
        raise DataRoomError('Please provide either an "image_file" or "image_url" field')
    if not (image_id or image_url):
        raise DataRoomError('Please provide either an "image_id" or "image_url" field')
    if not source:
        raise DataRoomError('Please provide a "source" field')

    payload = self._dict_filter_none(
        {
            "id": image_id,
            "image_url": image_url,
            "source": source,
            "attributes": attributes,
            "tags": tags,
            "related_images": related_images,
        }
    )

    if not has_upload:
        # No file to send: a plain application/json request suffices.
        return await self._make_request(
            url="images/",
            method="POST",
            json=payload,
        )

    # Uploads need multipart/form-data: the image travels as a file part and the
    # metadata as a separate text/plain part containing the JSON payload.
    if not isinstance(image_file, DataRoomFile):
        raise DataRoomError("Argument image_file must be a DataRoomFile")
    multipart = {
        "image": (image_file.filename, image_file.bytes_io, image_file.content_type),
        "json": (None, json_module.dumps(payload), "text/plain"),
    }
    return await self._make_request(url="images/", method="POST", files=multipart)
804
+
805
async def create_images(
    self,
    images: List[ImageCreate],
):
    """Create several images with a single multipart ``POST images/`` request.

    Each entry must contain ``id`` and ``source``, plus a non-empty ``image_file``
    or ``image_url``. For entry ``i`` the request carries an optional ``image_{i}``
    file part and a ``json_{i}`` text/plain part with the metadata (None values
    omitted).

    Returns:
        The decoded JSON response, or None if the response body is empty.

    Raises:
        DataRoomError: if an entry lacks a required field, ``image_file`` is not a
            DataRoomFile, or the HTTP request fails.
    """
    files = []
    for i, image in enumerate(images):
        if 'id' not in image:
            raise DataRoomError("Missing 'id' field in image")
        if 'source' not in image:
            raise DataRoomError("Missing 'source' field in image")

        image_file = image.get('image_file')
        # Check values, not key presence: an explicit None (or empty URL) would
        # otherwise create an image with no content. This mirrors create_image().
        if not image_file and not image.get('image_url'):
            raise DataRoomError('Please provide either an "image_file" or "image_url" field')
        if image_file and not isinstance(image_file, DataRoomFile):
            raise DataRoomError("Argument image_file must be a DataRoomFile")

        if image_file:
            files.append((
                f"image_{i}",
                (
                    image_file.filename,
                    image_file.bytes_io,
                    image_file.content_type,
                ),
            ))

        json_data = self._dict_filter_none({
            "id": image['id'],
            "source": image['source'],
            "image_url": image.get('image_url'),
            "attributes": image.get('attributes'),
            "tags": image.get('tags'),
            "related_images": image.get('related_images'),
        })
        files.append((
            f"json_{i}",
            (None, json_module.dumps(json_data), "text/plain"),
        ))

    return await self._make_request(url="images/", method="POST", files=files)
846
+
847
+ async def update_image(
848
+ self,
849
+ image_id: str,
850
+ source: str = None,
851
+ attributes: dict = None,
852
+ latents: List[LatentType] = None,
853
+ tags: list[str] = None,
854
+ coca_embedding: str = None,
855
+ related_images: dict[str, str] | None = None,
856
+ ):
857
+ """
858
+ Update the image, overwriting the tags and merging attributes, latents and related_images.
859
+ """
860
+
861
+ if coca_embedding:
862
+ self._validate_vector(coca_embedding)
863
+
864
+ if latents:
865
+ files = []
866
+ for i, latent in enumerate(latents):
867
+ if 'latent_type' not in latent:
868
+ raise DataRoomError("Missing 'latent_type' field in latent")
869
+ if 'file' not in latent:
870
+ raise DataRoomError("Missing 'file' field in latent")
871
+ if not isinstance(latent['file'], DataRoomFile):
872
+ raise DataRoomError("Property 'file' must be a DataRoomFile")
873
+
874
+ latent_file = latent['file']
875
+ files.append((
876
+ f"latent_{i}",
877
+ (
878
+ latent_file.filename,
879
+ latent_file.bytes_io,
880
+ latent_file.content_type,
881
+ ),
882
+ ))
883
+
884
+ json_data = self._dict_filter_none({
885
+ "latent_type": latent['latent_type'],
886
+ })
887
+ files.append((
888
+ f"latent_json_{i}",
889
+ (None, json_module.dumps(json_data), "text/plain")
890
+ ))
891
+
892
+ image_data = self._dict_filter_none({
893
+ "source": source,
894
+ "attributes": attributes,
895
+ "tags": tags,
896
+ "coca_embedding": coca_embedding,
897
+ "related_images": related_images,
898
+ })
899
+ files.append((
900
+ "json",
901
+ (None, json_module.dumps(image_data), "text/plain")
902
+ ))
903
+ return await self._make_request(url=f"images/{image_id}/", method="PUT", files=files)
904
+ else:
905
+ return await self._make_request(
906
+ url=f"images/{image_id}/",
907
+ method="PUT",
908
+ json=self._dict_filter_none({
909
+ "source": source,
910
+ "attributes": attributes,
911
+ "tags": tags,
912
+ "coca_embedding": coca_embedding,
913
+ "related_images": related_images,
914
+ }),
915
+ )
916
+
917
+ async def update_images(
918
+ self,
919
+ images: List[ImageUpdate],
920
+ ):
921
+ """
922
+ Bulk update images, overwriting the tags and merging attributes, latents and related_images.
923
+ """
924
+ for image in images:
925
+ if 'id' not in image:
926
+ raise DataRoomError("Missing 'id' field in image")
927
+ image.setdefault('source', None)
928
+ image.setdefault('attributes', None)
929
+ image.setdefault('tags', None)
930
+ image.setdefault('coca_embedding', None)
931
+ image.setdefault('related_images', None)
932
+
933
+ return await self._make_request(
934
+ url=f"images/bulk_update/",
935
+ method="PUT",
936
+ json=[
937
+ self._dict_filter_none({
938
+ "id": image['id'],
939
+ "source": image['source'],
940
+ "attributes": image['attributes'],
941
+ "tags": image['tags'],
942
+ "coca_embedding": image['coca_embedding'],
943
+ "related_images": image['related_images'],
944
+ })
945
+ for image in images
946
+ ],
947
+ )
948
+
949
+ async def add_image_attributes(
950
+ self,
951
+ image_id: str,
952
+ attributes: dict,
953
+ ):
954
+ """
955
+ Update attributes of an image, merging them with the existing attributes.
956
+ """
957
+ logger.warning(
958
+ 'DEPRECATION WARNING: Method "add_image_attributes" is deprecated, and will be removed in the future. '
959
+ 'Please use "update_image" instead.'
960
+ )
961
+
962
+ return await self._make_request(
963
+ url=f"images/{image_id}/add_attributes/",
964
+ method="PUT",
965
+ json={
966
+ "attributes": attributes,
967
+ },
968
+ )
969
+
970
+ async def add_image_attributes_in_bulk(
971
+ self,
972
+ ids_to_attributes: dict[str, dict],
973
+ ):
974
+ """
975
+ Update attributes of a list of images, merging them with the existing attributes.
976
+ """
977
+ logger.warning(
978
+ 'DEPRECATION WARNING: Method "add_image_attributes_in_bulk" is deprecated, '
979
+ 'and will be removed in the future. Please use "update_image" instead.'
980
+ )
981
+
982
+ return await self._make_request(
983
+ url=f"images/add_attributes_bulk/",
984
+ method="POST",
985
+ json=[
986
+ {"image_id": key, "attributes": val}
987
+ for key, val in ids_to_attributes.items()
988
+ ],
989
+ )
990
+
991
+ async def delete_image(self, image_id: str):
992
+ return await self._make_request(
993
+ url=f"images/{image_id}/",
994
+ method="DELETE",
995
+ )
996
+
997
+ async def get_image_audit_logs(self, image_id: str):
998
+ return await self._make_request(
999
+ url=f"images/{image_id}/audit_logs/",
1000
+ )
1001
+
1002
+ async def get_image_similarity(self, image_id_1: str, image_id_2: str):
1003
+ response = await self._make_request(
1004
+ url=f"images/{image_id_1}/similarity/",
1005
+ method="POST",
1006
+ json={
1007
+ "image_id": image_id_2,
1008
+ },
1009
+ )
1010
+ return response["similarity"]
1011
+
1012
+ async def get_similar_images(
1013
+ self,
1014
+ # similarity by
1015
+ image_id: str = None,
1016
+ image_file: DataRoomFile = None,
1017
+ image_vector: str = None,
1018
+ image_text: str = None,
1019
+ # options
1020
+ number=5,
1021
+ fields: List[str] = None,
1022
+ include_fields: List[str] = None,
1023
+ exclude_fields: List[str] = None,
1024
+ all_fields: bool = False,
1025
+ return_latents: List[str] = None,
1026
+ # filters
1027
+ short_edge: int | None = None,
1028
+ short_edge__gt: int = None,
1029
+ short_edge__gte: int = None,
1030
+ short_edge__lt: int = None,
1031
+ short_edge__lte: int = None,
1032
+ pixel_count: int | None = None,
1033
+ pixel_count__gt: int = None,
1034
+ pixel_count__gte: int = None,
1035
+ pixel_count__lt: int = None,
1036
+ pixel_count__lte: int = None,
1037
+ aspect_ratio_fraction: str = None,
1038
+ aspect_ratio: float = None,
1039
+ aspect_ratio__gt: float = None,
1040
+ aspect_ratio__gte: float = None,
1041
+ aspect_ratio__lt: float = None,
1042
+ aspect_ratio__lte: float = None,
1043
+ sources: List[str] = None,
1044
+ sources__ne: List[str] = None,
1045
+ attributes: dict = None,
1046
+ has_attributes: list = None,
1047
+ lacks_attributes: list = None,
1048
+ has_latents: List[str] = None,
1049
+ lacks_latents: List[str] = None,
1050
+ has_masks: List[str] = None,
1051
+ lacks_masks: List[str] = None,
1052
+ tags: list = None,
1053
+ tags__ne: list = None,
1054
+ tags__all: list = None,
1055
+ tags__ne_all: list = None,
1056
+ tags__empty: bool = None,
1057
+ coca_embedding__empty: bool = None,
1058
+ duplicate_state: ClientDuplicateState = None,
1059
+ date_created__gt: datetime = None,
1060
+ date_created__gte: datetime = None,
1061
+ date_created__lt: datetime = None,
1062
+ date_created__lte: datetime = None,
1063
+ date_updated__gt: datetime = None,
1064
+ date_updated__gte: datetime = None,
1065
+ date_updated__lt: datetime = None,
1066
+ date_updated__lte: datetime = None,
1067
+ ):
1068
+ search_args = {
1069
+ 'image_id': image_id, 'image_file': image_file, 'image_vector': image_vector, 'image_text': image_text,
1070
+ }
1071
+ if sum([bool(arg) for arg in search_args.values()]) != 1:
1072
+ raise DataRoomError(f'Please provide one of the following arguments: {", ".join(search_args.keys())}')
1073
+
1074
+ params = self._dict_filter_none({
1075
+ "fields": ",".join(fields) if fields else None,
1076
+ "include_fields": ",".join(include_fields) if include_fields else None,
1077
+ "exclude_fields": ",".join(exclude_fields) if exclude_fields else None,
1078
+ "all_fields": all_fields if all_fields else None,
1079
+ "return_latents": ",".join(return_latents) if return_latents else None,
1080
+ # filters
1081
+ "short_edge": short_edge,
1082
+ "short_edge__gt": short_edge__gt,
1083
+ "short_edge__gte": short_edge__gte,
1084
+ "short_edge__lt": short_edge__lt,
1085
+ "short_edge__lte": short_edge__lte,
1086
+ "pixel_count": pixel_count,
1087
+ "pixel_count__gt": pixel_count__gt,
1088
+ "pixel_count__gte": pixel_count__gte,
1089
+ "pixel_count__lt": pixel_count__lt,
1090
+ "pixel_count__lte": pixel_count__lte,
1091
+ "aspect_ratio_fraction": aspect_ratio_fraction,
1092
+ "aspect_ratio": aspect_ratio,
1093
+ "aspect_ratio__gt": aspect_ratio__gt,
1094
+ "aspect_ratio__gte": aspect_ratio__gte,
1095
+ "aspect_ratio__lt": aspect_ratio__lt,
1096
+ "aspect_ratio__lte": aspect_ratio__lte,
1097
+ "sources": ",".join(sources) if sources else None,
1098
+ "sources__ne": ",".join(sources__ne) if sources__ne else None,
1099
+ "attributes": self._get_attributes_filter(attributes),
1100
+ "has_attributes": ",".join(has_attributes) if has_attributes else None,
1101
+ "lacks_attributes": ",".join(lacks_attributes) if lacks_attributes else None,
1102
+ "has_latents": ",".join(has_latents) if has_latents else None,
1103
+ "lacks_latents": ",".join(lacks_latents) if lacks_latents else None,
1104
+ "has_masks": ",".join(has_masks) if has_masks else None,
1105
+ "lacks_masks": ",".join(lacks_masks) if lacks_masks else None,
1106
+ "tags": ",".join(tags) if tags else None,
1107
+ "tags__ne": ",".join(tags__ne) if tags__ne else None,
1108
+ "tags__all": ",".join(tags__all) if tags__all else None,
1109
+ "tags__ne_all": ",".join(tags__ne_all) if tags__ne_all else None,
1110
+ "tags__empty": tags__empty,
1111
+ "coca_embedding__empty": coca_embedding__empty,
1112
+ "duplicate_state": duplicate_state.value if duplicate_state else None,
1113
+ "date_created__gt": date_created__gt.isoformat() if date_created__gt else None,
1114
+ "date_created__gte": date_created__gte.isoformat() if date_created__gte else None,
1115
+ "date_created__lt": date_created__lt.isoformat() if date_created__lt else None,
1116
+ "date_created__lte": date_created__lte.isoformat() if date_created__lte else None,
1117
+ "date_updated__gt": date_updated__gt.isoformat() if date_updated__gt else None,
1118
+ "date_updated__gte": date_updated__gte.isoformat() if date_updated__gte else None,
1119
+ "date_updated__lt": date_updated__lt.isoformat() if date_updated__lt else None,
1120
+ "date_updated__lte": date_updated__lte.isoformat() if date_updated__lte else None,
1121
+ })
1122
+
1123
+ if image_file:
1124
+ # by image file
1125
+ if not isinstance(image_file, DataRoomFile):
1126
+ raise DataRoomError("Argument image_file must be a DataRoomFile")
1127
+ json_data = {
1128
+ "number": number,
1129
+ }
1130
+ files = {
1131
+ "image": (
1132
+ image_file.filename,
1133
+ image_file.bytes_io,
1134
+ image_file.content_type,
1135
+ ),
1136
+ "json": (None, json_module.dumps(json_data), "text/plain"),
1137
+ }
1138
+ return await self._make_request(
1139
+ url=f"images/similar_to_file/",
1140
+ method="POST",
1141
+ files=files,
1142
+ params=params,
1143
+ )
1144
+ elif image_id:
1145
+ # by image id
1146
+ response = await self._make_request(
1147
+ url=f"images/{image_id}/similar/",
1148
+ params={
1149
+ "number": number,
1150
+ **params,
1151
+ },
1152
+ )
1153
+ return response
1154
+ elif image_vector:
1155
+ # by image vector
1156
+ self._validate_vector(image_vector)
1157
+ return await self._make_request(
1158
+ url=f"images/similar_to_vector/",
1159
+ method="POST",
1160
+ json={
1161
+ "vector": image_vector,
1162
+ "number": number,
1163
+ },
1164
+ params=params,
1165
+ )
1166
+ elif image_text:
1167
+ # by text
1168
+ return await self._make_request(
1169
+ url=f"images/similar_to_text/",
1170
+ method="POST",
1171
+ json={
1172
+ "text": image_text,
1173
+ "number": number,
1174
+ },
1175
+ params=params,
1176
+ )
1177
+ else:
1178
+ raise DataRoomError("Invalid arguments")
1179
+
1180
+ async def get_related_images(
1181
+ self,
1182
+ image_id: str,
1183
+ # options
1184
+ fields: List[str] = None,
1185
+ include_fields: List[str] = None,
1186
+ exclude_fields: List[str] = None,
1187
+ all_fields: bool = False,
1188
+ return_latents: List[str] = None,
1189
+ ):
1190
+ params = self._dict_filter_none({
1191
+ "fields": ",".join(fields) if fields else None,
1192
+ "include_fields": ",".join(include_fields) if include_fields else None,
1193
+ "exclude_fields": ",".join(exclude_fields) if exclude_fields else None,
1194
+ "all_fields": all_fields if all_fields else None,
1195
+ "return_latents": ",".join(return_latents) if return_latents else None,
1196
+ })
1197
+ return await self._make_request(
1198
+ url=f"images/{image_id}/related/",
1199
+ params=params,
1200
+ )
1201
+
1202
+ async def set_image_latent(
1203
+ self,
1204
+ image_id: str,
1205
+ latent_file: DataRoomFile,
1206
+ latent_type: str,
1207
+ is_mask=None,
1208
+ ):
1209
+ logger.warning(
1210
+ 'DEPRECATION WARNING: Method "set_image_latent" is deprecated, and will be removed in the future. '
1211
+ 'Please use "update_image" instead.'
1212
+ )
1213
+
1214
+ if not isinstance(latent_file, DataRoomFile):
1215
+ raise DataRoomError("Argument latent_file must be a DataRoomFile")
1216
+
1217
+ if is_mask is not None:
1218
+ logger.warning(arg_deprecation_msg('is_mask'))
1219
+
1220
+ json_data = {
1221
+ "latent_type": latent_type,
1222
+ }
1223
+ files = {
1224
+ "file": (
1225
+ latent_file.filename,
1226
+ latent_file.bytes_io,
1227
+ latent_file.content_type,
1228
+ ),
1229
+ "json": (None, json_module.dumps(json_data), "text/plain"),
1230
+ }
1231
+ return await self._make_request(
1232
+ url=f"images/{image_id}/set_latent/",
1233
+ method="POST",
1234
+ files=files,
1235
+ )
1236
+
1237
+ async def delete_image_latent(self, image_id: str, latent_type: str):
1238
+ return await self._make_request(
1239
+ url=f"images/{image_id}/delete_latent/",
1240
+ method="POST",
1241
+ json={
1242
+ "latent_type": latent_type,
1243
+ },
1244
+ )
1245
+
1246
+ async def set_image_coca_embedding(self, image_id: str, vector: str):
1247
+ logger.warning(
1248
+ 'DEPRECATION WARNING: Method "set_image_coca_embedding" is deprecated, and will be removed in the future. '
1249
+ 'Please use "update_image" instead.'
1250
+ )
1251
+ self._validate_vector(vector)
1252
+ return await self._make_request(
1253
+ url=f"images/{image_id}/",
1254
+ method="PUT",
1255
+ json={
1256
+ "coca_embedding": vector,
1257
+ },
1258
+ )
1259
+
1260
+ async def aggregate_images(self, field, type):
1261
+ return await self._make_request(
1262
+ url="images/aggregate/",
1263
+ method="POST",
1264
+ json={
1265
+ "field": field,
1266
+ "type": type,
1267
+ },
1268
+ )
1269
+
1270
+ async def bucket_images(self, field, size):
1271
+ return await self._make_request(
1272
+ url="images/bucket/",
1273
+ method="POST",
1274
+ json={
1275
+ "field": field,
1276
+ "size": size,
1277
+ },
1278
+ )
1279
+
1280
+ # -------------------- Tag API methods --------------------
1281
+
1282
+ async def get_tags(self, limit: int = 1000):
1283
+ return await self._make_paginated_request(
1284
+ url=f"tags/",
1285
+ limit=limit,
1286
+ )
1287
+
1288
+ async def get_tag(self, tag_id: str):
1289
+ return await self._make_request(
1290
+ url=f"tags/{tag_id}/",
1291
+ )
1292
+
1293
+ async def create_tag(self, name: str, description: str = None):
1294
+ return await self._make_request(
1295
+ url="tags/",
1296
+ method="POST",
1297
+ json=self._dict_filter_none(
1298
+ {
1299
+ "name": name,
1300
+ "description": description,
1301
+ }
1302
+ ),
1303
+ )
1304
+
1305
+ async def tag_images(self, image_ids: List[str], tag_names: List[str]):
1306
+ return await self._make_request(
1307
+ url="tags/tag_images/",
1308
+ method="PUT",
1309
+ json={
1310
+ "image_ids": image_ids,
1311
+ "tag_names": tag_names,
1312
+ },
1313
+ )
1314
+
1315
+
1316
+ class DataRoomClientSync:
1317
+ """
1318
+ The official client of the DataRoom API using synchronous method and requests.
1319
+ """
1320
+
1321
+ def __init__(self, api_key=None, api_url=None):
1322
+ """
1323
+ @param api_key: API key for DataRoom API
1324
+ @param api_url: URL of the DataRoom backend API
1325
+ """
1326
+ self.api_key = api_key or os.environ.get("DATAROOM_API_KEY")
1327
+ self.api_url = (
1328
+ api_url
1329
+ or os.environ.get("DATAROOM_API_URL")
1330
+ )
1331
+ if not self.api_url:
1332
+ raise DataRoomError("DataRoom api_url is not set")
1333
+ self._async_client = DataRoomClient(api_key=api_key, api_url=api_url)
1334
+
1335
+ # -------------------- Private methods --------------------
1336
+
1337
+ @classmethod
1338
+ def _run_sync(cls, coro):
1339
+ try:
1340
+ # Check if there's an existing running event loop
1341
+ loop = asyncio.get_running_loop()
1342
+ except RuntimeError:
1343
+ # No running event loop, create a new one
1344
+ return asyncio.run(coro)
1345
+ else:
1346
+ # A running event loop exists, use run_until_complete
1347
+ return loop.run_until_complete(coro)
1348
+
1349
+ def _make_request(self, *args, **kwargs):
1350
+ return self._run_sync(self._async_client._make_request(*args, **kwargs))
1351
+
1352
+ def _make_paginated_request(self, *args, **kwargs):
1353
+ return self._run_sync(
1354
+ self._async_client._make_paginated_request(*args, **kwargs)
1355
+ )
1356
+
1357
+ # -------------------- Utils --------------------
1358
+
1359
+ @classmethod
1360
+ def download_image_from_url(cls, *args, **kwargs) -> DataRoomFile:
1361
+ return cls._run_sync(DataRoomClient.download_image_from_url(*args, **kwargs))
1362
+
1363
+ # -------------------- Image API methods --------------------
1364
+
1365
+ def get_images(self, *args, **kwargs):
1366
+ return self._run_sync(self._async_client.get_images(*args, **kwargs))
1367
+
1368
+ def get_images_iter(self, *args, **kwargs):
1369
+ return self._run_sync(self._async_client.get_images_iter(*args, **kwargs))
1370
+
1371
+ def get_random_images(self, *args, **kwargs):
1372
+ return self._run_sync(self._async_client.get_random_images(*args, **kwargs))
1373
+
1374
+ def count_images(self, *args, **kwargs):
1375
+ return self._run_sync(self._async_client.count_images(*args, **kwargs))
1376
+
1377
+ def get_image(self, *args, **kwargs):
1378
+ return self._run_sync(self._async_client.get_image(*args, **kwargs))
1379
+
1380
+ def create_image(self, *args, **kwargs):
1381
+ return self._run_sync(self._async_client.create_image(*args, **kwargs))
1382
+
1383
+ def create_images(self, *args, **kwargs):
1384
+ return self._run_sync(self._async_client.create_images(*args, **kwargs))
1385
+
1386
+ def delete_image(self, *args, **kwargs):
1387
+ return self._run_sync(self._async_client.delete_image(*args, **kwargs))
1388
+
1389
+ def get_image_audit_logs(self, *args, **kwargs):
1390
+ return self._run_sync(self._async_client.get_image_audit_logs(*args, **kwargs))
1391
+
1392
+ def get_image_similarity(self, *args, **kwargs):
1393
+ return self._run_sync(self._async_client.get_image_similarity(*args, **kwargs))
1394
+
1395
+ def get_similar_images(self, *args, **kwargs):
1396
+ return self._run_sync(self._async_client.get_similar_images(*args, **kwargs))
1397
+
1398
+ def set_image_latent(self, *args, **kwargs):
1399
+ return self._run_sync(self._async_client.set_image_latent(*args, **kwargs))
1400
+
1401
+ def delete_image_latent(self, *args, **kwargs):
1402
+ return self._run_sync(self._async_client.delete_image_latent(*args, **kwargs))
1403
+
1404
+ def update_image(self,*args, **kwargs):
1405
+ return self._run_sync(self._async_client.update_image(*args, **kwargs))
1406
+
1407
+ def update_images(self,*args, **kwargs):
1408
+ return self._run_sync(self._async_client.update_images(*args, **kwargs))
1409
+
1410
+ def add_image_attributes(self, *args, **kwargs):
1411
+ return self._run_sync(self._async_client.add_image_attributes(*args, **kwargs))
1412
+
1413
+ def add_image_attributes_in_bulk(self, *args, **kwargs):
1414
+ return self._run_sync(self._async_client.add_image_attributes_in_bulk(*args, **kwargs))
1415
+
1416
+ def set_image_coca_embedding(self, *args, **kwargs):
1417
+ return self._run_sync(self._async_client.set_image_coca_embedding(*args, **kwargs))
1418
+
1419
+ def aggregate_images(self, *args, **kwargs):
1420
+ return self._run_sync(self._async_client.aggregate_images(*args, **kwargs))
1421
+
1422
+ def bucket_images(self, *args, **kwargs):
1423
+ return self._run_sync(self._async_client.bucket_images(*args, **kwargs))
1424
+
1425
+ # -------------------- Tag API methods --------------------
1426
+
1427
+ def create_tag(self, *args, **kwargs):
1428
+ return self._run_sync(self._async_client.create_tag(*args, **kwargs))
1429
+
1430
+ def get_tag(self, *args, **kwargs):
1431
+ return self._run_sync(self._async_client.get_tag(*args, **kwargs))
1432
+
1433
+ def get_tags(self, *args, **kwargs):
1434
+ return self._run_sync(self._async_client.get_tags(*args, **kwargs))
1435
+
1436
+ def tag_images(self, *args, **kwargs):
1437
+ return self._run_sync(self._async_client.tag_images(*args, **kwargs))
1438
+
1439
+
1440
+ for method_name in dir(DataRoomClient):
1441
+ if not method_name.startswith("_"):
1442
+ if not hasattr(DataRoomClientSync, method_name):
1443
+ logger.warning(f"Missing implementation: DataRoomClientSync.{method_name}")