dataroom-client 0.0.0.post1.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataroom_client/__init__.py +2 -0
- dataroom_client/client.py +1443 -0
- dataroom_client/counter.py +45 -0
- dataroom_client/loader.py +61 -0
- dataroom_client/print_utils.py +11 -0
- dataroom_client-0.0.0.post1.dev0.dist-info/METADATA +41 -0
- dataroom_client-0.0.0.post1.dev0.dist-info/RECORD +8 -0
- dataroom_client-0.0.0.post1.dev0.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,1443 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
from datetime import datetime
|
|
3
|
+
import json as json_module
|
|
4
|
+
import logging
|
|
5
|
+
import os
|
|
6
|
+
import uuid
|
|
7
|
+
from enum import Enum
|
|
8
|
+
from io import BytesIO
|
|
9
|
+
import mimetypes
|
|
10
|
+
from typing import List, TypedDict, Optional
|
|
11
|
+
from urllib.parse import urljoin
|
|
12
|
+
import httpx
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
mimetypes.add_type("image/webp", ".webp")
|
|
16
|
+
logger = logging.getLogger(__name__)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class DataRoomError(Exception):
    """Base exception class for DataRoomClient errors.

    An optional ``response`` keyword argument carries the HTTP response object
    associated with the failure. When one is attached, its ``status_code`` and
    ``text`` are appended to the string form of the exception.
    """

    def __init__(self, *args, **kwargs):
        # ``response`` is consumed here; everything else is forwarded to
        # ``Exception`` unchanged.
        self.response = kwargs.pop("response", None)
        super().__init__(*args, **kwargs)

    def __str__(self):
        message = super().__str__()
        # Compare against None explicitly instead of relying on truthiness:
        # a response object may define its own __bool__/__len__ and evaluate
        # as falsy, which would silently hide the status code and body.
        if self.response is None:
            return message
        return f"{message}\n{self.response.status_code}\n{self.response.text}"
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class DataRoomFile:
    """A wrapper for a file-like object that can be used with DataRoomClient"""

    def __init__(self, bytes_io, content_type, path=None, extension=None):
        """
        @param bytes_io: file-like object holding the file content
        @param content_type: MIME type of the content, or None when unknown
        @param path: optional path the content was loaded from
        @param extension: file extension, with or without the leading dot;
            guessed from content_type when None
        """
        if extension is None:
            extension = self._guess_extension(content_type)
        # Normalise to a leading dot, but keep an empty extension empty rather
        # than indexing into it (which used to raise IndexError).
        if extension and not extension.startswith("."):
            extension = f".{extension}"
        self.bytes_io = bytes_io
        self.content_type = content_type
        self.extension = extension
        # A random name is generated for every instance; the original file
        # name (if any) is only kept in ``path``.
        self.name = f"{uuid.uuid4().hex}"
        self.filename = f"{self.name}{extension}"
        self.path = path

    @staticmethod
    def _guess_extension(content_type):
        """Return the extension for a MIME type, or "" when none can be determined."""
        if not content_type:
            return ""
        # Drop parameters such as "; charset=binary" before the lookup;
        # mimetypes only knows bare media types.
        media_type = content_type.split(";", 1)[0].strip()
        return mimetypes.guess_extension(media_type) or ""

    @classmethod
    def from_path(cls, path: str):
        """
        Read the file at ``path`` fully into memory and wrap it.

        Raises DataRoomError when the content type cannot be guessed from the path.
        """
        content_type, _ = mimetypes.guess_type(path)
        if not content_type:
            raise DataRoomError(f"Could not guess content type for file: {path}")
        with open(path, "rb") as f:
            # ``cls`` (not the hard-coded class) so subclasses get their own type back.
            return cls(
                bytes_io=BytesIO(f.read()),
                content_type=content_type,
                path=path,
            )

    @classmethod
    def from_bytesio(cls, bytes_io, extension):
        """Wrap an in-memory buffer; ``extension`` is required, content type stays None."""
        assert extension is not None, "Please provide a file extension"
        return cls(
            bytes_io=bytes_io,
            extension=extension,
            content_type=None,
            path=None,
        )
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
class ClientDuplicateState(Enum):
    """Duplicate-processing states; ``.value`` is what gets sent as a filter value."""

    # The value is the string "None", not the None singleton.
    UNPROCESSED = "None"
    ORIGINAL = 1
    DUPLICATE = 2
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
class LatentType(TypedDict, total=False):
    """Typed dict pairing a latent type name with its file; all keys are optional."""

    # name identifying the kind of latent
    latent_type: str
    # wrapped file content for this latent
    file: DataRoomFile
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
class ImageUpdate(TypedDict, total=False):
    """Typed dict describing an update to an existing image; all keys are optional."""

    # identifier of the image being updated
    id: str  # noqa: A003
    source: Optional[str]
    attributes: Optional[dict]
    tags: Optional[List[str]]
    # string-encoded embedding vector
    coca_embedding: Optional[str]
    related_images: Optional[dict[str, str]]
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
class ImageCreate(TypedDict, total=False):
    """Typed dict describing a new image to create; all keys are optional at the type level."""

    id: str  # noqa: A003
    source: str
    # image content supplied directly as a wrapped file ...
    image_file: Optional[DataRoomFile]
    # ... or by URL
    image_url: Optional[str]
    attributes: Optional[dict]
    tags: Optional[list[str]]
    related_images: Optional[dict[str, str]]
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def arg_deprecation_msg(arg_name, msg=''):
    """Build the deprecation warning text for ``arg_name``, followed by an optional hint ``msg``."""
    template = 'DEPRECATION WARNING: Argument "{}" is deprecated, and will be removed in the future. {}'
    return template.format(arg_name, msg)
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
class DataRoomClient:
|
|
107
|
+
"""
|
|
108
|
+
The official client of the DataRoom API. See notebooks for usage examples.
|
|
109
|
+
"""
|
|
110
|
+
|
|
111
|
+
def __init__(self, api_key=None, api_url=None, timeout=120):
    """
    @param api_key: API key for DataRoom API; falls back to the DATAROOM_API_KEY environment variable
    @param api_url: URL of the DataRoom backend API; falls back to the DATAROOM_API_URL environment variable
    @param timeout: timeout value passed along with every request
    """
    # Explicit arguments win; the environment is only consulted when the
    # argument is missing or empty.
    self.api_key = api_key if api_key else os.getenv("DATAROOM_API_KEY")
    self.api_url = api_url if api_url else os.getenv("DATAROOM_API_URL")
    if not self.api_url:
        raise DataRoomError("DataRoom api_url is not set")
    self.client = httpx.AsyncClient()
    self.timeout = timeout
|
|
125
|
+
|
|
126
|
+
# -------------------- Private methods --------------------
|
|
127
|
+
|
|
128
|
+
async def _make_request(
|
|
129
|
+
self, url, params=None, method="GET", json=None, files=None, headers=None,
|
|
130
|
+
):
|
|
131
|
+
absolute_url = urljoin(self.api_url, url)
|
|
132
|
+
if headers is None:
|
|
133
|
+
headers = {}
|
|
134
|
+
headers.update({
|
|
135
|
+
"Authorization": f"Token {self.api_key}",
|
|
136
|
+
})
|
|
137
|
+
try:
|
|
138
|
+
response = await self.client.request(
|
|
139
|
+
method=method,
|
|
140
|
+
url=absolute_url,
|
|
141
|
+
params=params,
|
|
142
|
+
json=json,
|
|
143
|
+
files=files,
|
|
144
|
+
headers=headers,
|
|
145
|
+
timeout=self.timeout,
|
|
146
|
+
)
|
|
147
|
+
response.raise_for_status()
|
|
148
|
+
except httpx.HTTPError as e:
|
|
149
|
+
response = None
|
|
150
|
+
if hasattr(e, "response"):
|
|
151
|
+
response = e.response
|
|
152
|
+
raise DataRoomError(e, response=response) from e
|
|
153
|
+
else:
|
|
154
|
+
if response.content:
|
|
155
|
+
return response.json()
|
|
156
|
+
|
|
157
|
+
async def _make_paginated_request(
|
|
158
|
+
self, url, limit=1000, params=None, method="GET", json=None, headers=None,
|
|
159
|
+
):
|
|
160
|
+
items = []
|
|
161
|
+
next_url = url
|
|
162
|
+
while next_url:
|
|
163
|
+
response = await self._make_request(
|
|
164
|
+
next_url, params=params, method=method, json=json, headers=headers,
|
|
165
|
+
)
|
|
166
|
+
if "results" not in response:
|
|
167
|
+
raise NotImplementedError(f'No "results" in response to {url}')
|
|
168
|
+
if "next" not in response:
|
|
169
|
+
raise NotImplementedError(f'No "next" in response to {url}')
|
|
170
|
+
next_url = response["next"]
|
|
171
|
+
items += response["results"]
|
|
172
|
+
if limit is not None and len(items) >= limit:
|
|
173
|
+
break
|
|
174
|
+
|
|
175
|
+
if limit is not None:
|
|
176
|
+
return items[:limit]
|
|
177
|
+
return items
|
|
178
|
+
|
|
179
|
+
async def _make_paginated_request_iter(
|
|
180
|
+
self, url, limit=1000, params=None, method="GET", json=None, headers=None,
|
|
181
|
+
):
|
|
182
|
+
next_url = url
|
|
183
|
+
returned_items = 0
|
|
184
|
+
while next_url:
|
|
185
|
+
response = await self._make_request(
|
|
186
|
+
next_url, params=params, method=method, json=json, headers=headers,
|
|
187
|
+
)
|
|
188
|
+
if "results" not in response:
|
|
189
|
+
raise NotImplementedError(f'No "results" in response to {url}')
|
|
190
|
+
if "next" not in response:
|
|
191
|
+
raise NotImplementedError(f'No "next" in response to {url}')
|
|
192
|
+
next_url = response["next"]
|
|
193
|
+
for item in response["results"]:
|
|
194
|
+
yield item
|
|
195
|
+
returned_items += 1
|
|
196
|
+
if limit is not None and returned_items >= limit:
|
|
197
|
+
break
|
|
198
|
+
if limit is not None and returned_items >= limit:
|
|
199
|
+
break
|
|
200
|
+
|
|
201
|
+
@staticmethod
|
|
202
|
+
def _dict_filter_none(d: dict):
|
|
203
|
+
return {k: v for k, v in d.items() if v is not None}
|
|
204
|
+
|
|
205
|
+
@staticmethod
|
|
206
|
+
def _get_attributes_filter(attributes: dict | None):
|
|
207
|
+
if not attributes:
|
|
208
|
+
return None
|
|
209
|
+
for key, val in attributes.items():
|
|
210
|
+
val = str(val)
|
|
211
|
+
if "," in key or "," in val:
|
|
212
|
+
raise DataRoomError(
|
|
213
|
+
"Commas are not allowed in attribute keys or values"
|
|
214
|
+
)
|
|
215
|
+
if ":" in key or ":" in val:
|
|
216
|
+
raise DataRoomError(
|
|
217
|
+
"Colons are not allowed in attribute keys or values"
|
|
218
|
+
)
|
|
219
|
+
attrs_str = ",".join([f"{key}:{val}" for key, val in attributes.items()])
|
|
220
|
+
return attrs_str
|
|
221
|
+
|
|
222
|
+
@staticmethod
|
|
223
|
+
def _validate_vector(vector: str):
|
|
224
|
+
err_msg = "Argument vector must be a string representing a list of 768 floats."
|
|
225
|
+
if not isinstance(vector, str) or not len(vector) > 0:
|
|
226
|
+
raise DataRoomError(f"{err_msg} Not a string.")
|
|
227
|
+
if vector[0] != "[" or vector[-1] != "]":
|
|
228
|
+
raise DataRoomError(f"{err_msg} Not a list.")
|
|
229
|
+
if len(vector[1:-1].split(',')) != 768:
|
|
230
|
+
raise DataRoomError(f"{err_msg} Incorrect length.")
|
|
231
|
+
|
|
232
|
+
# -------------------- Utils --------------------
|
|
233
|
+
|
|
234
|
+
@classmethod
async def download_image_from_url(cls, image_url: str) -> DataRoomFile:
    """
    Download ``image_url`` with a one-off HTTP client and wrap the body in a DataRoomFile.

    @param image_url: URL to fetch with a GET request
    @return: DataRoomFile holding the response body, typed by the response's Content-Type
    @raise DataRoomError: on any httpx.HTTPError, or when the response has no Content-Type header
    """
    try:
        async with httpx.AsyncClient() as client:
            response = await client.get(image_url)
            response.raise_for_status()
    except httpx.HTTPError as e:
        # Not every HTTPError carries a response.
        raise DataRoomError(e, response=getattr(e, "response", None)) from e

    content_type = response.headers.get("Content-Type")
    if not content_type:
        # Without a content type the file extension cannot be derived; fail
        # with a client error instead of crashing inside mimetypes.
        raise DataRoomError(
            f"No Content-Type header in response from: {image_url}",
            response=response,
        )
    # Keep only the media type: parameters such as "; charset=binary" are not
    # understood by the mimetypes lookup used to pick the extension.
    media_type = content_type.split(";", 1)[0].strip()
    return DataRoomFile(
        bytes_io=BytesIO(response.content),
        content_type=media_type,
    )
|
|
251
|
+
|
|
252
|
+
# -------------------- Image API methods --------------------
|
|
253
|
+
|
|
254
|
+
async def get_images(
    self,
    limit: int | None = 1000,
    page_size: int = None,
    fields: List[str] = None,
    include_fields: List[str] = None,
    exclude_fields: List[str] = None,
    all_fields: bool = False,
    return_latents: List[str] = None,
    cache_ttl: int = None,
    partitions_count: int = None,
    partition: int = None,
    # filters
    short_edge: int = None,
    short_edge__gt: int = None,
    short_edge__gte: int = None,
    short_edge__lt: int = None,
    short_edge__lte: int = None,
    pixel_count: int = None,
    pixel_count__gt: int = None,
    pixel_count__gte: int = None,
    pixel_count__lt: int = None,
    pixel_count__lte: int = None,
    aspect_ratio_fraction: str = None,
    aspect_ratio: float = None,
    aspect_ratio__gt: float = None,
    aspect_ratio__gte: float = None,
    aspect_ratio__lt: float = None,
    aspect_ratio__lte: float = None,
    source: str = None,
    sources: List[str] = None,
    sources__ne: List[str] = None,
    attributes: dict = None,
    has_attributes: list = None,
    lacks_attributes: list = None,
    has_latents: List[str] = None,
    lacks_latents: List[str] = None,
    has_masks: List[str] = None,
    lacks_masks: List[str] = None,
    tags: list = None,
    tags__ne: list = None,
    tags__all: list = None,
    tags__ne_all: list = None,
    tags__empty: bool = None,
    coca_embedding__empty: bool = None,
    duplicate_state: ClientDuplicateState = None,
    date_created__gt: datetime = None,
    date_created__gte: datetime = None,
    date_created__lt: datetime = None,
    date_created__lte: datetime = None,
    date_updated__gt: datetime = None,
    date_updated__gte: datetime = None,
    date_updated__lt: datetime = None,
    date_updated__lte: datetime = None,
):
    """
    Return a list of images from the "images/" endpoint matching the given filters.

    List-valued arguments are sent comma-joined, datetimes in ISO format, and
    arguments left at None are omitted from the query. A truthy ``cache_ttl``
    is sent as a "Cache-Control: max-age" header. ``source`` is deprecated:
    when given it replaces ``sources`` and a warning is logged.

    @param limit: maximum number of images to return; None fetches every page
    """

    def _csv(values):
        # comma-joined string, or None for an empty/absent list
        return ",".join(values) if values else None

    def _iso(moment):
        return moment.isoformat() if moment else None

    request_headers = {"Cache-Control": f"max-age={cache_ttl}"} if cache_ttl else {}

    if source is not None:
        sources = [source]
        logger.warning(arg_deprecation_msg('source', 'Please use "sources" instead.'))

    query = {
        "fields": _csv(fields),
        "include_fields": _csv(include_fields),
        "exclude_fields": _csv(exclude_fields),
        "all_fields": all_fields or None,
        "return_latents": _csv(return_latents),
        "page_size": page_size,
        "partitions_count": partitions_count,
        "partition": partition,
        # filters
        "short_edge": short_edge,
        "short_edge__gt": short_edge__gt,
        "short_edge__gte": short_edge__gte,
        "short_edge__lt": short_edge__lt,
        "short_edge__lte": short_edge__lte,
        "pixel_count": pixel_count,
        "pixel_count__gt": pixel_count__gt,
        "pixel_count__gte": pixel_count__gte,
        "pixel_count__lt": pixel_count__lt,
        "pixel_count__lte": pixel_count__lte,
        "aspect_ratio_fraction": aspect_ratio_fraction,
        "aspect_ratio": aspect_ratio,
        "aspect_ratio__gt": aspect_ratio__gt,
        "aspect_ratio__gte": aspect_ratio__gte,
        "aspect_ratio__lt": aspect_ratio__lt,
        "aspect_ratio__lte": aspect_ratio__lte,
        "sources": _csv(sources),
        "sources__ne": _csv(sources__ne),
        "attributes": self._get_attributes_filter(attributes),
        "has_attributes": _csv(has_attributes),
        "lacks_attributes": _csv(lacks_attributes),
        "has_latents": _csv(has_latents),
        "lacks_latents": _csv(lacks_latents),
        "has_masks": _csv(has_masks),
        "lacks_masks": _csv(lacks_masks),
        "tags": _csv(tags),
        "tags__ne": _csv(tags__ne),
        "tags__all": _csv(tags__all),
        "tags__ne_all": _csv(tags__ne_all),
        "tags__empty": tags__empty,
        "coca_embedding__empty": coca_embedding__empty,
        "duplicate_state": duplicate_state.value if duplicate_state else None,
        "date_created__gt": _iso(date_created__gt),
        "date_created__gte": _iso(date_created__gte),
        "date_created__lt": _iso(date_created__lt),
        "date_created__lte": _iso(date_created__lte),
        "date_updated__gt": _iso(date_updated__gt),
        "date_updated__gte": _iso(date_updated__gte),
        "date_updated__lt": _iso(date_updated__lt),
        "date_updated__lte": _iso(date_updated__lte),
    }

    return await self._make_paginated_request(
        url="images/",
        limit=limit,
        params=self._dict_filter_none(query),
        headers=request_headers,
    )
|
|
375
|
+
|
|
376
|
+
async def get_images_iter(
    self,
    limit: int | None = 1000,
    page_size: int = None,
    fields: List[str] = None,
    include_fields: List[str] = None,
    exclude_fields: List[str] = None,
    all_fields: bool = False,
    return_latents: List[str] = None,
    cache_ttl: int = None,
    partitions_count: int = None,
    partition: int = None,
    # filters
    short_edge: int = None,
    short_edge__gt: int = None,
    short_edge__gte: int = None,
    short_edge__lt: int = None,
    short_edge__lte: int = None,
    pixel_count: int = None,
    pixel_count__gt: int = None,
    pixel_count__gte: int = None,
    pixel_count__lt: int = None,
    pixel_count__lte: int = None,
    aspect_ratio_fraction: str = None,
    aspect_ratio: float = None,
    aspect_ratio__gt: float = None,
    aspect_ratio__gte: float = None,
    aspect_ratio__lt: float = None,
    aspect_ratio__lte: float = None,
    source: str = None,
    sources: List[str] = None,
    sources__ne: List[str] = None,
    attributes: dict = None,
    has_attributes: list = None,
    lacks_attributes: list = None,
    has_latents: List[str] = None,
    lacks_latents: List[str] = None,
    has_masks: List[str] = None,
    lacks_masks: List[str] = None,
    tags: list = None,
    tags__ne: list = None,
    tags__all: list = None,
    tags__ne_all: list = None,
    tags__empty: bool = None,
    coca_embedding__empty: bool = None,
    duplicate_state: ClientDuplicateState = None,
    date_created__gt: datetime = None,
    date_created__gte: datetime = None,
    date_created__lt: datetime = None,
    date_created__lte: datetime = None,
    date_updated__gt: datetime = None,
    date_updated__gte: datetime = None,
    date_updated__lt: datetime = None,
    date_updated__lte: datetime = None,
):
    """
    Async-generator variant of the image listing: yields images from the
    "images/" endpoint one at a time instead of returning a list.

    List-valued arguments are sent comma-joined, datetimes in ISO format, and
    arguments left at None are omitted from the query. A truthy ``cache_ttl``
    is sent as a "Cache-Control: max-age" header. ``source`` is deprecated:
    when given it replaces ``sources`` and a warning is logged.

    @param limit: maximum number of images to yield; None yields every page
    """

    def _csv(values):
        # comma-joined string, or None for an empty/absent list
        return ",".join(values) if values else None

    def _iso(moment):
        return moment.isoformat() if moment else None

    request_headers = {"Cache-Control": f"max-age={cache_ttl}"} if cache_ttl else {}

    if source is not None:
        sources = [source]
        logger.warning(arg_deprecation_msg('source', 'Please use "sources" instead.'))

    query = {
        "fields": _csv(fields),
        "include_fields": _csv(include_fields),
        "exclude_fields": _csv(exclude_fields),
        "all_fields": all_fields or None,
        "return_latents": _csv(return_latents),
        "page_size": page_size,
        "partitions_count": partitions_count,
        "partition": partition,
        # filters
        "short_edge": short_edge,
        "short_edge__gt": short_edge__gt,
        "short_edge__gte": short_edge__gte,
        "short_edge__lt": short_edge__lt,
        "short_edge__lte": short_edge__lte,
        "pixel_count": pixel_count,
        "pixel_count__gt": pixel_count__gt,
        "pixel_count__gte": pixel_count__gte,
        "pixel_count__lt": pixel_count__lt,
        "pixel_count__lte": pixel_count__lte,
        "aspect_ratio_fraction": aspect_ratio_fraction,
        "aspect_ratio": aspect_ratio,
        "aspect_ratio__gt": aspect_ratio__gt,
        "aspect_ratio__gte": aspect_ratio__gte,
        "aspect_ratio__lt": aspect_ratio__lt,
        "aspect_ratio__lte": aspect_ratio__lte,
        "sources": _csv(sources),
        "sources__ne": _csv(sources__ne),
        "attributes": self._get_attributes_filter(attributes),
        "has_attributes": _csv(has_attributes),
        "lacks_attributes": _csv(lacks_attributes),
        "has_latents": _csv(has_latents),
        "lacks_latents": _csv(lacks_latents),
        "has_masks": _csv(has_masks),
        "lacks_masks": _csv(lacks_masks),
        "tags": _csv(tags),
        "tags__ne": _csv(tags__ne),
        "tags__all": _csv(tags__all),
        "tags__ne_all": _csv(tags__ne_all),
        "tags__empty": tags__empty,
        "coca_embedding__empty": coca_embedding__empty,
        "duplicate_state": duplicate_state.value if duplicate_state else None,
        "date_created__gt": _iso(date_created__gt),
        "date_created__gte": _iso(date_created__gte),
        "date_created__lt": _iso(date_created__lt),
        "date_created__lte": _iso(date_created__lte),
        "date_updated__gt": _iso(date_updated__gt),
        "date_updated__gte": _iso(date_updated__gte),
        "date_updated__lt": _iso(date_updated__lt),
        "date_updated__lte": _iso(date_updated__lte),
    }

    image_stream = self._make_paginated_request_iter(
        url="images/",
        limit=limit,
        params=self._dict_filter_none(query),
        headers=request_headers,
    )
    async for image in image_stream:
        yield image
|
|
498
|
+
|
|
499
|
+
async def get_random_images(
    self,
    limit: int | None = 1000,
    page_size: int = None,
    fields: List[str] = None,
    include_fields: List[str] = None,
    exclude_fields: List[str] = None,
    all_fields: bool = False,
    return_latents: List[str] = None,
    cache_ttl: int = None,
    prefix_length: int = None,
    num_prefixes: int = None,
    # filters
    short_edge: int = None,
    short_edge__gt: int = None,
    short_edge__gte: int = None,
    short_edge__lt: int = None,
    short_edge__lte: int = None,
    pixel_count: int = None,
    pixel_count__gt: int = None,
    pixel_count__gte: int = None,
    pixel_count__lt: int = None,
    pixel_count__lte: int = None,
    aspect_ratio_fraction: str = None,
    aspect_ratio: float = None,
    aspect_ratio__gt: float = None,
    aspect_ratio__gte: float = None,
    aspect_ratio__lt: float = None,
    aspect_ratio__lte: float = None,
    source: str = None,
    sources: List[str] = None,
    sources__ne: List[str] = None,
    attributes: dict = None,
    has_attributes: list = None,
    lacks_attributes: list = None,
    has_latents: List[str] = None,
    lacks_latents: List[str] = None,
    has_masks: List[str] = None,
    lacks_masks: List[str] = None,
    tags: list = None,
    tags__ne: list = None,
    tags__all: list = None,
    tags__ne_all: list = None,
    tags__empty: bool = None,
    coca_embedding__empty: bool = None,
    duplicate_state: ClientDuplicateState = None,
    date_created__gt: datetime = None,
    date_created__gte: datetime = None,
    date_created__lt: datetime = None,
    date_created__lte: datetime = None,
    date_updated__gt: datetime = None,
    date_updated__gte: datetime = None,
    date_updated__lt: datetime = None,
    date_updated__lte: datetime = None,
):
    """
    Get a list of random images from the "images/random/" endpoint.

    Random sampling works by filtering image_hash by a number of random hex prefixes. Use prefix_length and
    num_prefixes to adjust the randomness factor. In general, a smaller prefix_length will give you more samples,
    but less random and a higher num_prefixes will give you more samples, but slow down the query. The default
    values are prefix_length=5 and num_prefixes=100.

    The remaining arguments are filters: list-valued ones are sent comma-joined, datetimes in ISO format, and
    those left at None are omitted. ``source`` is deprecated: when given it replaces ``sources`` and a warning
    is logged.
    """

    def _csv(values):
        # comma-joined string, or None for an empty/absent list
        return ",".join(values) if values else None

    def _iso(moment):
        return moment.isoformat() if moment else None

    request_headers = {"Cache-Control": f"max-age={cache_ttl}"} if cache_ttl else {}

    if source is not None:
        sources = [source]
        logger.warning(arg_deprecation_msg('source', 'Please use "sources" instead.'))

    query = {
        "fields": _csv(fields),
        "include_fields": _csv(include_fields),
        "exclude_fields": _csv(exclude_fields),
        "all_fields": all_fields or None,
        "return_latents": _csv(return_latents),
        "page_size": page_size,
        "prefix_length": prefix_length,
        "num_prefixes": num_prefixes,
        # filters
        "short_edge": short_edge,
        "short_edge__gt": short_edge__gt,
        "short_edge__gte": short_edge__gte,
        "short_edge__lt": short_edge__lt,
        "short_edge__lte": short_edge__lte,
        "pixel_count": pixel_count,
        "pixel_count__gt": pixel_count__gt,
        "pixel_count__gte": pixel_count__gte,
        "pixel_count__lt": pixel_count__lt,
        "pixel_count__lte": pixel_count__lte,
        "aspect_ratio_fraction": aspect_ratio_fraction,
        "aspect_ratio": aspect_ratio,
        "aspect_ratio__gt": aspect_ratio__gt,
        "aspect_ratio__gte": aspect_ratio__gte,
        "aspect_ratio__lt": aspect_ratio__lt,
        "aspect_ratio__lte": aspect_ratio__lte,
        "sources": _csv(sources),
        "sources__ne": _csv(sources__ne),
        "attributes": self._get_attributes_filter(attributes),
        "has_attributes": _csv(has_attributes),
        "lacks_attributes": _csv(lacks_attributes),
        "has_latents": _csv(has_latents),
        "lacks_latents": _csv(lacks_latents),
        "has_masks": _csv(has_masks),
        "lacks_masks": _csv(lacks_masks),
        "tags": _csv(tags),
        "tags__ne": _csv(tags__ne),
        "tags__all": _csv(tags__all),
        "tags__ne_all": _csv(tags__ne_all),
        "tags__empty": tags__empty,
        "coca_embedding__empty": coca_embedding__empty,
        "duplicate_state": duplicate_state.value if duplicate_state else None,
        "date_created__gt": _iso(date_created__gt),
        "date_created__gte": _iso(date_created__gte),
        "date_created__lt": _iso(date_created__lt),
        "date_created__lte": _iso(date_created__lte),
        "date_updated__gt": _iso(date_updated__gt),
        "date_updated__gte": _iso(date_updated__gte),
        "date_updated__lt": _iso(date_updated__lt),
        "date_updated__lte": _iso(date_updated__lte),
    }

    return await self._make_paginated_request(
        url="images/random/",
        limit=limit,
        params=self._dict_filter_none(query),
        headers=request_headers,
    )
|
|
628
|
+
|
|
629
|
+
async def count_images(
    self,
    partitions_count: int = None,
    partition: int = None,
    # filters
    short_edge: int | None = None,
    short_edge__gt: int = None,
    short_edge__gte: int = None,
    short_edge__lt: int = None,
    short_edge__lte: int = None,
    pixel_count: int | None = None,
    pixel_count__gt: int = None,
    pixel_count__gte: int = None,
    pixel_count__lt: int = None,
    pixel_count__lte: int = None,
    aspect_ratio_fraction: str = None,
    aspect_ratio: float = None,
    aspect_ratio__gt: float = None,
    aspect_ratio__gte: float = None,
    aspect_ratio__lt: float = None,
    aspect_ratio__lte: float = None,
    source: str = None,
    sources: List[str] = None,
    sources__ne: List[str] = None,
    attributes: dict = None,
    has_attributes: list = None,
    lacks_attributes: list = None,
    has_latents: List[str] = None,
    lacks_latents: List[str] = None,
    has_masks: List[str] = None,
    lacks_masks: List[str] = None,
    tags: list = None,
    tags__ne: list = None,
    tags__all: list = None,
    tags__ne_all: list = None,
    tags__empty: bool = None,
    coca_embedding__empty: bool = None,
    duplicate_state: ClientDuplicateState = None,
    date_created__gt: datetime = None,
    date_created__gte: datetime = None,
    date_created__lt: datetime = None,
    date_created__lte: datetime = None,
    date_updated__gt: datetime = None,
    date_updated__gte: datetime = None,
    date_updated__lt: datetime = None,
    date_updated__lte: datetime = None,
):
    """
    Return the "count" value reported by the "images/count/" endpoint for the given filters.

    List-valued arguments are sent comma-joined, datetimes in ISO format, and
    arguments left at None are omitted from the query. ``source`` is
    deprecated: when given it replaces ``sources`` and a warning is logged.
    """

    def _csv(values):
        # comma-joined string, or None for an empty/absent list
        return ",".join(values) if values else None

    def _iso(moment):
        return moment.isoformat() if moment else None

    if source is not None:
        sources = [source]
        logger.warning(arg_deprecation_msg('source', 'Please use "sources" instead.'))

    query = {
        "partitions_count": partitions_count,
        "partition": partition,
        # filters
        "short_edge": short_edge,
        "short_edge__gt": short_edge__gt,
        "short_edge__gte": short_edge__gte,
        "short_edge__lt": short_edge__lt,
        "short_edge__lte": short_edge__lte,
        "pixel_count": pixel_count,
        "pixel_count__gt": pixel_count__gt,
        "pixel_count__gte": pixel_count__gte,
        "pixel_count__lt": pixel_count__lt,
        "pixel_count__lte": pixel_count__lte,
        "aspect_ratio_fraction": aspect_ratio_fraction,
        "aspect_ratio": aspect_ratio,
        "aspect_ratio__gt": aspect_ratio__gt,
        "aspect_ratio__gte": aspect_ratio__gte,
        "aspect_ratio__lt": aspect_ratio__lt,
        "aspect_ratio__lte": aspect_ratio__lte,
        "sources": _csv(sources),
        "sources__ne": _csv(sources__ne),
        "attributes": self._get_attributes_filter(attributes),
        "has_attributes": _csv(has_attributes),
        "lacks_attributes": _csv(lacks_attributes),
        "has_latents": _csv(has_latents),
        "lacks_latents": _csv(lacks_latents),
        "has_masks": _csv(has_masks),
        "lacks_masks": _csv(lacks_masks),
        "tags": _csv(tags),
        "tags__ne": _csv(tags__ne),
        "tags__all": _csv(tags__all),
        "tags__ne_all": _csv(tags__ne_all),
        "tags__empty": tags__empty,
        "coca_embedding__empty": coca_embedding__empty,
        "duplicate_state": duplicate_state.value if duplicate_state else None,
        "date_created__gt": _iso(date_created__gt),
        "date_created__gte": _iso(date_created__gte),
        "date_created__lt": _iso(date_created__lt),
        "date_created__lte": _iso(date_created__lte),
        "date_updated__gt": _iso(date_updated__gt),
        "date_updated__gte": _iso(date_updated__gte),
        "date_updated__lt": _iso(date_updated__lt),
        "date_updated__lte": _iso(date_updated__lte),
    }

    payload = await self._make_request(
        url="images/count/",
        params=self._dict_filter_none(query),
    )
    return payload["count"]
|
|
732
|
+
|
|
733
|
+
async def get_image(
    self,
    image_id: str,
    fields: List[str] = None,
    include_fields: List[str] = None,
    exclude_fields: List[str] = None,
    all_fields: bool = False,
    return_latents: List[str] = None,
):
    """
    Fetch a single image by its id.

    List arguments are sent as comma-separated query parameters. Falsy arguments are mapped
    to None before the query dict is passed through `self._dict_filter_none`.
    """

    def _csv(values):
        # Empty or missing lists become None instead of an empty string.
        return ",".join(values) if values else None

    endpoint = f"images/{image_id}/"
    query = {
        "fields": _csv(fields),
        "include_fields": _csv(include_fields),
        "exclude_fields": _csv(exclude_fields),
        "all_fields": all_fields or None,
        "return_latents": _csv(return_latents),
    }
    return await self._make_request(url=endpoint, params=self._dict_filter_none(query))
|
|
752
|
+
|
|
753
|
+
async def create_image(
    self,
    image_id: str = None,
    source: str = None,
    image_file: DataRoomFile = None,
    image_url: str = None,
    attributes: dict = None,
    tags: list[str] = None,
    related_images: dict[str, str] | None = None,
):
    """
    Create an image either from an uploaded file or from a URL.

    With `image_file` the request is sent as multipart/form-data, otherwise as JSON.

    Raises DataRoomError when neither `image_file` nor `image_url` is given, when neither
    `image_id` nor `image_url` is given, when `source` is missing, or when `image_file`
    is not a DataRoomFile.
    """
    if not (image_file or image_url):
        raise DataRoomError('Please provide either an "image_file" or "image_url" field')

    if not (image_id or image_url):
        raise DataRoomError('Please provide either an "image_id" or "image_url" field')

    if not source:
        raise DataRoomError('Please provide a "source" field')

    payload = self._dict_filter_none(
        {
            "id": image_id,
            "image_url": image_url,
            "source": source,
            "attributes": attributes,
            "tags": tags,
            "related_images": related_images,
        }
    )

    if not image_file:
        # No upload: a plain application/json request is enough.
        return await self._make_request(url="images/", method="POST", json=payload)

    # Uploading requires multipart/form-data: the image travels as a file part and the
    # remaining data as a text/plain part holding the serialized JSON.
    if not isinstance(image_file, DataRoomFile):
        raise DataRoomError("Argument image_file must be a DataRoomFile")

    image_part = (image_file.filename, image_file.bytes_io, image_file.content_type)
    json_part = (None, json_module.dumps(payload), "text/plain")
    return await self._make_request(
        url="images/",
        method="POST",
        files={"image": image_part, "json": json_part},
    )
|
|
804
|
+
|
|
805
|
+
async def create_images(
    self,
    images: List[ImageCreate],
):
    """
    Create several images with one multipart request.

    Every entry contributes a "json_<index>" part with its metadata and, when it carries an
    image file, an "image_<index>" part with the upload.

    Raises DataRoomError when an entry lacks 'id' or 'source', has neither 'image_file' nor
    'image_url', or has an 'image_file' that is not a DataRoomFile.
    """
    parts = []
    for index, entry in enumerate(images):
        if 'id' not in entry:
            raise DataRoomError("Missing 'id' field in image")
        if 'source' not in entry:
            raise DataRoomError("Missing 'source' field in image")
        if not ('image_file' in entry or 'image_url' in entry):
            raise DataRoomError('Please provide either an "image_file" or "image_url" field')

        upload = entry.get('image_file')
        if upload:
            if not isinstance(upload, DataRoomFile):
                raise DataRoomError("Argument image_file must be a DataRoomFile")
            parts.append(
                (f"image_{index}", (upload.filename, upload.bytes_io, upload.content_type))
            )

        metadata = self._dict_filter_none({
            "id": entry['id'],
            "source": entry['source'],
            "image_url": entry.get('image_url'),
            "attributes": entry.get('attributes'),
            "tags": entry.get('tags'),
            "related_images": entry.get('related_images'),
        })
        parts.append((f"json_{index}", (None, json_module.dumps(metadata), "text/plain")))

    return await self._make_request(url="images/", method="POST", files=parts)
|
|
846
|
+
|
|
847
|
+
async def update_image(
    self,
    image_id: str,
    source: str = None,
    attributes: dict = None,
    latents: List[LatentType] = None,
    tags: list[str] = None,
    coca_embedding: str = None,
    related_images: dict[str, str] | None = None,
):
    """
    Update the image, overwriting the tags and merging attributes, latents and related_images.

    Without latents the update is sent as JSON. With latents it is sent as multipart, with a
    "latent_<index>" file part and a "latent_json_<index>" metadata part per latent, plus a
    "json" part for the image fields.

    Raises DataRoomError when a latent lacks 'latent_type' or 'file', or when its 'file' is
    not a DataRoomFile.
    """
    if coca_embedding:
        self._validate_vector(coca_embedding)

    endpoint = f"images/{image_id}/"
    # Both request flavours send the same image fields.
    image_fields = {
        "source": source,
        "attributes": attributes,
        "tags": tags,
        "coca_embedding": coca_embedding,
        "related_images": related_images,
    }

    if not latents:
        return await self._make_request(
            url=endpoint,
            method="PUT",
            json=self._dict_filter_none(image_fields),
        )

    parts = []
    for index, latent in enumerate(latents):
        if 'latent_type' not in latent:
            raise DataRoomError("Missing 'latent_type' field in latent")
        if 'file' not in latent:
            raise DataRoomError("Missing 'file' field in latent")
        latent_file = latent['file']
        if not isinstance(latent_file, DataRoomFile):
            raise DataRoomError("Property 'file' must be a DataRoomFile")

        parts.append(
            (
                f"latent_{index}",
                (latent_file.filename, latent_file.bytes_io, latent_file.content_type),
            )
        )
        latent_meta = self._dict_filter_none({"latent_type": latent['latent_type']})
        parts.append(
            (f"latent_json_{index}", (None, json_module.dumps(latent_meta), "text/plain"))
        )

    image_data = self._dict_filter_none(image_fields)
    parts.append(("json", (None, json_module.dumps(image_data), "text/plain")))
    return await self._make_request(url=endpoint, method="PUT", files=parts)
|
|
916
|
+
|
|
917
|
+
async def update_images(
    self,
    images: List[ImageUpdate],
):
    """
    Bulk update images with a single request.

    Each entry must contain an 'id'. The optional keys 'source', 'attributes', 'tags',
    'coca_embedding' and 'related_images' are forwarded when present. Latents are not sent
    by this method.

    The entries are only read: missing optional keys are looked up with `.get()`, so no
    None placeholders are written back into the caller's dicts. `images` is walked exactly
    once, so one-shot iterables work as well.

    Raises DataRoomError when an entry has no 'id'; no request is made in that case.
    """
    optional_keys = ('source', 'attributes', 'tags', 'coca_embedding', 'related_images')

    payload = []
    for image in images:
        if 'id' not in image:
            raise DataRoomError("Missing 'id' field in image")
        entry = {"id": image['id']}
        for key in optional_keys:
            entry[key] = image.get(key)
        payload.append(self._dict_filter_none(entry))

    return await self._make_request(
        url="images/bulk_update/",
        method="PUT",
        json=payload,
    )
|
|
948
|
+
|
|
949
|
+
async def add_image_attributes(
    self,
    image_id: str,
    attributes: dict,
):
    """
    Merge the given attributes into the existing attributes of an image.

    Deprecated: a deprecation warning is logged on every call.
    """
    deprecation_notice = (
        'DEPRECATION WARNING: Method "add_image_attributes" is deprecated, and will be removed in the future. '
        'Please use "update_image" instead.'
    )
    logger.warning(deprecation_notice)

    endpoint = f"images/{image_id}/add_attributes/"
    body = {"attributes": attributes}
    return await self._make_request(url=endpoint, method="PUT", json=body)
|
|
969
|
+
|
|
970
|
+
async def add_image_attributes_in_bulk(
    self,
    ids_to_attributes: dict[str, dict],
):
    """
    Merge attributes into several images at once.

    `ids_to_attributes` maps an image id to the attributes to merge into that image.
    Deprecated: a deprecation warning is logged on every call.
    """
    deprecation_notice = (
        'DEPRECATION WARNING: Method "add_image_attributes_in_bulk" is deprecated, '
        'and will be removed in the future. Please use "update_image" instead.'
    )
    logger.warning(deprecation_notice)

    body = [
        {"image_id": image_id, "attributes": image_attributes}
        for image_id, image_attributes in ids_to_attributes.items()
    ]
    return await self._make_request(
        url="images/add_attributes_bulk/",
        method="POST",
        json=body,
    )
|
|
990
|
+
|
|
991
|
+
async def delete_image(self, image_id: str):
    """
    Send a DELETE request for the image with the given id.
    """
    endpoint = f"images/{image_id}/"
    return await self._make_request(url=endpoint, method="DELETE")
|
|
996
|
+
|
|
997
|
+
async def get_image_audit_logs(self, image_id: str):
    """
    Request the audit logs of the image with the given id.
    """
    endpoint = f"images/{image_id}/audit_logs/"
    return await self._make_request(url=endpoint)
|
|
1001
|
+
|
|
1002
|
+
async def get_image_similarity(self, image_id_1: str, image_id_2: str):
    """
    Return the "similarity" value the API reports for the two given images.
    """
    endpoint = f"images/{image_id_1}/similarity/"
    body = {"image_id": image_id_2}
    result = await self._make_request(url=endpoint, method="POST", json=body)
    return result["similarity"]
|
|
1011
|
+
|
|
1012
|
+
async def get_similar_images(
    self,
    # similarity by
    image_id: str = None,
    image_file: DataRoomFile = None,
    image_vector: str = None,
    image_text: str = None,
    # options
    number=5,
    fields: List[str] = None,
    include_fields: List[str] = None,
    exclude_fields: List[str] = None,
    all_fields: bool = False,
    return_latents: List[str] = None,
    # filters
    short_edge: int | None = None,
    short_edge__gt: int = None,
    short_edge__gte: int = None,
    short_edge__lt: int = None,
    short_edge__lte: int = None,
    pixel_count: int | None = None,
    pixel_count__gt: int = None,
    pixel_count__gte: int = None,
    pixel_count__lt: int = None,
    pixel_count__lte: int = None,
    aspect_ratio_fraction: str = None,
    aspect_ratio: float = None,
    aspect_ratio__gt: float = None,
    aspect_ratio__gte: float = None,
    aspect_ratio__lt: float = None,
    aspect_ratio__lte: float = None,
    sources: List[str] = None,
    sources__ne: List[str] = None,
    attributes: dict = None,
    has_attributes: list = None,
    lacks_attributes: list = None,
    has_latents: List[str] = None,
    lacks_latents: List[str] = None,
    has_masks: List[str] = None,
    lacks_masks: List[str] = None,
    tags: list = None,
    tags__ne: list = None,
    tags__all: list = None,
    tags__ne_all: list = None,
    tags__empty: bool = None,
    coca_embedding__empty: bool = None,
    duplicate_state: ClientDuplicateState = None,
    date_created__gt: datetime = None,
    date_created__gte: datetime = None,
    date_created__lt: datetime = None,
    date_created__lte: datetime = None,
    date_updated__gt: datetime = None,
    date_updated__gte: datetime = None,
    date_updated__lt: datetime = None,
    date_updated__lte: datetime = None,
):
    """
    Find images similar to a reference image, file, vector or text.

    Exactly one of `image_id`, `image_file`, `image_vector` and `image_text` must be truthy;
    it selects the endpoint that is called. `number` is sent along with every variant, and
    the option and filter arguments are sent as query parameters.

    Raises DataRoomError when not exactly one reference argument is given, or when
    `image_file` is not a DataRoomFile.
    """
    search_args = {
        'image_id': image_id, 'image_file': image_file, 'image_vector': image_vector, 'image_text': image_text,
    }
    provided = sum(1 for value in search_args.values() if value)
    if provided != 1:
        raise DataRoomError(f'Please provide one of the following arguments: {", ".join(search_args.keys())}')

    def _csv(values):
        # Lists travel as comma-separated strings; empty or missing lists become None.
        return ",".join(values) if values else None

    def _iso(moment):
        # Datetimes travel in ISO 8601 format; missing ones become None.
        return moment.isoformat() if moment else None

    params = self._dict_filter_none({
        "fields": _csv(fields),
        "include_fields": _csv(include_fields),
        "exclude_fields": _csv(exclude_fields),
        "all_fields": all_fields or None,
        "return_latents": _csv(return_latents),
        # filters
        "short_edge": short_edge,
        "short_edge__gt": short_edge__gt,
        "short_edge__gte": short_edge__gte,
        "short_edge__lt": short_edge__lt,
        "short_edge__lte": short_edge__lte,
        "pixel_count": pixel_count,
        "pixel_count__gt": pixel_count__gt,
        "pixel_count__gte": pixel_count__gte,
        "pixel_count__lt": pixel_count__lt,
        "pixel_count__lte": pixel_count__lte,
        "aspect_ratio_fraction": aspect_ratio_fraction,
        "aspect_ratio": aspect_ratio,
        "aspect_ratio__gt": aspect_ratio__gt,
        "aspect_ratio__gte": aspect_ratio__gte,
        "aspect_ratio__lt": aspect_ratio__lt,
        "aspect_ratio__lte": aspect_ratio__lte,
        "sources": _csv(sources),
        "sources__ne": _csv(sources__ne),
        "attributes": self._get_attributes_filter(attributes),
        "has_attributes": _csv(has_attributes),
        "lacks_attributes": _csv(lacks_attributes),
        "has_latents": _csv(has_latents),
        "lacks_latents": _csv(lacks_latents),
        "has_masks": _csv(has_masks),
        "lacks_masks": _csv(lacks_masks),
        "tags": _csv(tags),
        "tags__ne": _csv(tags__ne),
        "tags__all": _csv(tags__all),
        "tags__ne_all": _csv(tags__ne_all),
        "tags__empty": tags__empty,
        "coca_embedding__empty": coca_embedding__empty,
        "duplicate_state": duplicate_state.value if duplicate_state else None,
        "date_created__gt": _iso(date_created__gt),
        "date_created__gte": _iso(date_created__gte),
        "date_created__lt": _iso(date_created__lt),
        "date_created__lte": _iso(date_created__lte),
        "date_updated__gt": _iso(date_updated__gt),
        "date_updated__gte": _iso(date_updated__gte),
        "date_updated__lt": _iso(date_updated__lt),
        "date_updated__lte": _iso(date_updated__lte),
    })

    if image_file:
        # by image file: multipart upload, "number" goes into the JSON part
        if not isinstance(image_file, DataRoomFile):
            raise DataRoomError("Argument image_file must be a DataRoomFile")
        image_part = (image_file.filename, image_file.bytes_io, image_file.content_type)
        json_part = (None, json_module.dumps({"number": number}), "text/plain")
        return await self._make_request(
            url="images/similar_to_file/",
            method="POST",
            files={"image": image_part, "json": json_part},
            params=params,
        )

    if image_id:
        # by image id: "number" is sent as a query parameter next to the filters
        return await self._make_request(
            url=f"images/{image_id}/similar/",
            params={"number": number, **params},
        )

    if image_vector:
        # by image vector
        self._validate_vector(image_vector)
        return await self._make_request(
            url="images/similar_to_vector/",
            method="POST",
            json={"vector": image_vector, "number": number},
            params=params,
        )

    if image_text:
        # by text
        return await self._make_request(
            url="images/similar_to_text/",
            method="POST",
            json={"text": image_text, "number": number},
            params=params,
        )

    # Defensive fallback; the argument check above should make this unreachable.
    raise DataRoomError("Invalid arguments")
|
|
1179
|
+
|
|
1180
|
+
async def get_related_images(
    self,
    image_id: str,
    # options
    fields: List[str] = None,
    include_fields: List[str] = None,
    exclude_fields: List[str] = None,
    all_fields: bool = False,
    return_latents: List[str] = None,
):
    """
    Request the images related to the image with the given id.

    List arguments are sent as comma-separated query parameters. Falsy arguments are mapped
    to None before the query dict is passed through `self._dict_filter_none`.
    """

    def _csv(values):
        # Empty or missing lists become None instead of an empty string.
        return ",".join(values) if values else None

    query = self._dict_filter_none({
        "fields": _csv(fields),
        "include_fields": _csv(include_fields),
        "exclude_fields": _csv(exclude_fields),
        "all_fields": all_fields or None,
        "return_latents": _csv(return_latents),
    })
    endpoint = f"images/{image_id}/related/"
    return await self._make_request(url=endpoint, params=query)
|
|
1201
|
+
|
|
1202
|
+
async def set_image_latent(
    self,
    image_id: str,
    latent_file: DataRoomFile,
    latent_type: str,
    is_mask=None,
):
    """
    Upload a latent file of the given type for an image.

    Deprecated: a deprecation warning is logged on every call. `is_mask` only triggers an
    additional deprecation warning and is not sent to the API.

    Raises DataRoomError when `latent_file` is not a DataRoomFile.
    """
    deprecation_notice = (
        'DEPRECATION WARNING: Method "set_image_latent" is deprecated, and will be removed in the future. '
        'Please use "update_image" instead.'
    )
    logger.warning(deprecation_notice)

    if not isinstance(latent_file, DataRoomFile):
        raise DataRoomError("Argument latent_file must be a DataRoomFile")

    if is_mask is not None:
        logger.warning(arg_deprecation_msg('is_mask'))

    # Multipart request: the latent as a file part, its metadata as a text/plain JSON part.
    file_part = (latent_file.filename, latent_file.bytes_io, latent_file.content_type)
    json_part = (None, json_module.dumps({"latent_type": latent_type}), "text/plain")
    return await self._make_request(
        url=f"images/{image_id}/set_latent/",
        method="POST",
        files={"file": file_part, "json": json_part},
    )
|
|
1236
|
+
|
|
1237
|
+
async def delete_image_latent(self, image_id: str, latent_type: str):
    """
    Ask the API to delete the latent of the given type from an image.
    """
    endpoint = f"images/{image_id}/delete_latent/"
    body = {"latent_type": latent_type}
    return await self._make_request(url=endpoint, method="POST", json=body)
|
|
1245
|
+
|
|
1246
|
+
async def set_image_coca_embedding(self, image_id: str, vector: str):
    """
    Set the coca embedding of an image after validating the vector.

    Deprecated: a deprecation warning is logged on every call.
    """
    deprecation_notice = (
        'DEPRECATION WARNING: Method "set_image_coca_embedding" is deprecated, and will be removed in the future. '
        'Please use "update_image" instead.'
    )
    logger.warning(deprecation_notice)

    self._validate_vector(vector)

    endpoint = f"images/{image_id}/"
    body = {"coca_embedding": vector}
    return await self._make_request(url=endpoint, method="PUT", json=body)
|
|
1259
|
+
|
|
1260
|
+
async def aggregate_images(self, field, type):
    """
    Post an aggregation request for `field` with the given aggregation `type`.
    """
    body = {"field": field, "type": type}
    return await self._make_request(url="images/aggregate/", method="POST", json=body)
|
|
1269
|
+
|
|
1270
|
+
async def bucket_images(self, field, size):
    """
    Post a bucketing request for `field` with the given bucket `size`.
    """
    body = {"field": field, "size": size}
    return await self._make_request(url="images/bucket/", method="POST", json=body)
|
|
1279
|
+
|
|
1280
|
+
# -------------------- Tag API methods --------------------
|
|
1281
|
+
|
|
1282
|
+
async def get_tags(self, limit: int = 1000):
    """
    Request tags through the paginated request helper, passing `limit` along.
    """
    endpoint = "tags/"
    return await self._make_paginated_request(url=endpoint, limit=limit)
|
|
1287
|
+
|
|
1288
|
+
async def get_tag(self, tag_id: str):
    """
    Request a single tag by its id.
    """
    endpoint = f"tags/{tag_id}/"
    return await self._make_request(url=endpoint)
|
|
1292
|
+
|
|
1293
|
+
async def create_tag(self, name: str, description: str = None):
    """
    Create a tag with the given name and optional description.
    """
    tag_fields = {
        "name": name,
        "description": description,
    }
    body = self._dict_filter_none(tag_fields)
    return await self._make_request(url="tags/", method="POST", json=body)
|
|
1304
|
+
|
|
1305
|
+
async def tag_images(self, image_ids: List[str], tag_names: List[str]):
    """
    Send the given image ids and tag names to the tagging endpoint.
    """
    body = {
        "image_ids": image_ids,
        "tag_names": tag_names,
    }
    return await self._make_request(url="tags/tag_images/", method="PUT", json=body)
|
|
1314
|
+
|
|
1315
|
+
|
|
1316
|
+
class DataRoomClientSync:
    """
    The official synchronous client of the DataRoom API.

    Every public method delegates to the method of the same name on an internal
    DataRoomClient and blocks until the resulting coroutine has finished.
    """

    def __init__(self, api_key=None, api_url=None):
        """
        @param api_key: API key for DataRoom API, falls back to the DATAROOM_API_KEY environment variable
        @param api_url: URL of the DataRoom backend API, falls back to the DATAROOM_API_URL environment variable
        @raise DataRoomError: when no API URL is given and none is set in the environment
        """
        self.api_key = api_key or os.environ.get("DATAROOM_API_KEY")
        self.api_url = (
            api_url
            or os.environ.get("DATAROOM_API_URL")
        )
        if not self.api_url:
            raise DataRoomError("DataRoom api_url is not set")
        self._async_client = DataRoomClient(api_key=api_key, api_url=api_url)

    # -------------------- Private methods --------------------

    @classmethod
    def _run_sync(cls, coro):
        """
        Run the coroutine to completion and return its result.

        Without a running event loop in the current thread the coroutine is executed with
        asyncio.run(). When a loop is already running here (for example inside a notebook
        or when called from async code), that loop cannot be re-entered with
        run_until_complete(), which raises "This event loop is already running". In that
        case the coroutine is executed with asyncio.run() in a helper thread and the
        current thread blocks until it has finished.
        """
        try:
            # Check if there's an existing running event loop
            asyncio.get_running_loop()
        except RuntimeError:
            # No running event loop, create a new one
            return asyncio.run(coro)

        # Local import keeps this block self-contained.
        from concurrent.futures import ThreadPoolExecutor

        with ThreadPoolExecutor(max_workers=1) as executor:
            # result() re-raises any exception raised by the coroutine.
            return executor.submit(asyncio.run, coro).result()

    def _make_request(self, *args, **kwargs):
        return self._run_sync(self._async_client._make_request(*args, **kwargs))

    def _make_paginated_request(self, *args, **kwargs):
        return self._run_sync(
            self._async_client._make_paginated_request(*args, **kwargs)
        )

    # -------------------- Utils --------------------

    @classmethod
    def download_image_from_url(cls, *args, **kwargs) -> DataRoomFile:
        return cls._run_sync(DataRoomClient.download_image_from_url(*args, **kwargs))

    # -------------------- Image API methods --------------------

    def get_images(self, *args, **kwargs):
        return self._run_sync(self._async_client.get_images(*args, **kwargs))

    def get_images_iter(self, *args, **kwargs):
        # NOTE(review): asyncio.run() only accepts coroutines; if the async get_images_iter
        # is an async generator this call cannot work as written - confirm.
        return self._run_sync(self._async_client.get_images_iter(*args, **kwargs))

    def get_random_images(self, *args, **kwargs):
        return self._run_sync(self._async_client.get_random_images(*args, **kwargs))

    def count_images(self, *args, **kwargs):
        return self._run_sync(self._async_client.count_images(*args, **kwargs))

    def get_image(self, *args, **kwargs):
        return self._run_sync(self._async_client.get_image(*args, **kwargs))

    def create_image(self, *args, **kwargs):
        return self._run_sync(self._async_client.create_image(*args, **kwargs))

    def create_images(self, *args, **kwargs):
        return self._run_sync(self._async_client.create_images(*args, **kwargs))

    def delete_image(self, *args, **kwargs):
        return self._run_sync(self._async_client.delete_image(*args, **kwargs))

    def get_image_audit_logs(self, *args, **kwargs):
        return self._run_sync(self._async_client.get_image_audit_logs(*args, **kwargs))

    def get_image_similarity(self, *args, **kwargs):
        return self._run_sync(self._async_client.get_image_similarity(*args, **kwargs))

    def get_similar_images(self, *args, **kwargs):
        return self._run_sync(self._async_client.get_similar_images(*args, **kwargs))

    def set_image_latent(self, *args, **kwargs):
        return self._run_sync(self._async_client.set_image_latent(*args, **kwargs))

    def delete_image_latent(self, *args, **kwargs):
        return self._run_sync(self._async_client.delete_image_latent(*args, **kwargs))

    def update_image(self, *args, **kwargs):
        return self._run_sync(self._async_client.update_image(*args, **kwargs))

    def update_images(self, *args, **kwargs):
        return self._run_sync(self._async_client.update_images(*args, **kwargs))

    def add_image_attributes(self, *args, **kwargs):
        return self._run_sync(self._async_client.add_image_attributes(*args, **kwargs))

    def add_image_attributes_in_bulk(self, *args, **kwargs):
        return self._run_sync(self._async_client.add_image_attributes_in_bulk(*args, **kwargs))

    def set_image_coca_embedding(self, *args, **kwargs):
        return self._run_sync(self._async_client.set_image_coca_embedding(*args, **kwargs))

    def aggregate_images(self, *args, **kwargs):
        return self._run_sync(self._async_client.aggregate_images(*args, **kwargs))

    def bucket_images(self, *args, **kwargs):
        return self._run_sync(self._async_client.bucket_images(*args, **kwargs))

    # -------------------- Tag API methods --------------------

    def create_tag(self, *args, **kwargs):
        return self._run_sync(self._async_client.create_tag(*args, **kwargs))

    def get_tag(self, *args, **kwargs):
        return self._run_sync(self._async_client.get_tag(*args, **kwargs))

    def get_tags(self, *args, **kwargs):
        return self._run_sync(self._async_client.get_tags(*args, **kwargs))

    def tag_images(self, *args, **kwargs):
        return self._run_sync(self._async_client.tag_images(*args, **kwargs))
|
|
1438
|
+
|
|
1439
|
+
|
|
1440
|
+
# Import-time parity check: warn about every public DataRoomClient attribute that has no
# counterpart on DataRoomClientSync.
for method_name in dir(DataRoomClient):
    if method_name.startswith("_") or hasattr(DataRoomClientSync, method_name):
        continue
    logger.warning(f"Missing implementation: DataRoomClientSync.{method_name}")
|