runware 0.4.30__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- runware/__init__.py +9 -0
- runware/async_retry.py +96 -0
- runware/base.py +2342 -0
- runware/logging_config.py +30 -0
- runware/reconnection.py +101 -0
- runware/server.py +399 -0
- runware/types.py +1269 -0
- runware/utils.py +909 -0
- runware-0.4.30.dist-info/METADATA +1124 -0
- runware-0.4.30.dist-info/RECORD +18 -0
- runware-0.4.30.dist-info/WHEEL +5 -0
- runware-0.4.30.dist-info/licenses/LICENSE +0 -0
- runware-0.4.30.dist-info/top_level.txt +2 -0
- tests/__init__.py +0 -0
- tests/test_base.py +0 -0
- tests/test_server.py +0 -0
- tests/test_types.py +221 -0
- tests/test_utils.py +297 -0
runware/base.py
ADDED
|
@@ -0,0 +1,2342 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import inspect
|
|
3
|
+
import logging
|
|
4
|
+
import os
|
|
5
|
+
import re
|
|
6
|
+
import uuid
|
|
7
|
+
from asyncio import gather
|
|
8
|
+
from dataclasses import asdict
|
|
9
|
+
from typing import List, Optional, Union, Callable, Any, Dict
|
|
10
|
+
|
|
11
|
+
from websockets.protocol import State
|
|
12
|
+
|
|
13
|
+
from .logging_config import configure_logging
|
|
14
|
+
|
|
15
|
+
from .async_retry import asyncRetry
|
|
16
|
+
from .reconnection import ConnectionState, ReconnectionManager
|
|
17
|
+
from .types import (
|
|
18
|
+
Environment,
|
|
19
|
+
IImageInference,
|
|
20
|
+
IPhotoMaker,
|
|
21
|
+
IImageCaption,
|
|
22
|
+
IImageToText,
|
|
23
|
+
IImageBackgroundRemoval,
|
|
24
|
+
ISafety,
|
|
25
|
+
IPromptEnhance,
|
|
26
|
+
IEnhancedPrompt,
|
|
27
|
+
IImageUpscale,
|
|
28
|
+
IUploadModelBaseType,
|
|
29
|
+
IUploadModelResponse,
|
|
30
|
+
ReconnectingWebsocketProps,
|
|
31
|
+
UploadImageType,
|
|
32
|
+
MediaStorageType,
|
|
33
|
+
EPreProcessorGroup,
|
|
34
|
+
File,
|
|
35
|
+
ETaskType,
|
|
36
|
+
IModelSearch,
|
|
37
|
+
IModelSearchResponse,
|
|
38
|
+
IControlNet,
|
|
39
|
+
IVideo,
|
|
40
|
+
IVideoCaption,
|
|
41
|
+
IVideoToText,
|
|
42
|
+
IVideoBackgroundRemoval,
|
|
43
|
+
IVideoUpscale,
|
|
44
|
+
IVideoInference,
|
|
45
|
+
IVideoAdvancedFeatures,
|
|
46
|
+
IAcceleratorOptions,
|
|
47
|
+
IAudio,
|
|
48
|
+
IAudioInference,
|
|
49
|
+
IFrameImage,
|
|
50
|
+
IAsyncTaskResponse,
|
|
51
|
+
IVectorize,
|
|
52
|
+
)
|
|
53
|
+
from .types import IImage, IError, SdkType, ListenerType
|
|
54
|
+
from .utils import (
|
|
55
|
+
BASE_RUNWARE_URLS,
|
|
56
|
+
getUUID,
|
|
57
|
+
fileToBase64,
|
|
58
|
+
createImageFromResponse,
|
|
59
|
+
createImageToTextFromResponse,
|
|
60
|
+
createVideoToTextFromResponse,
|
|
61
|
+
createEnhancedPromptsFromResponse,
|
|
62
|
+
instantiateDataclassList,
|
|
63
|
+
RunwareAPIError,
|
|
64
|
+
RunwareError,
|
|
65
|
+
instantiateDataclass,
|
|
66
|
+
TIMEOUT_DURATION,
|
|
67
|
+
accessDeepObject,
|
|
68
|
+
getIntervalWithPromise,
|
|
69
|
+
removeListener,
|
|
70
|
+
LISTEN_TO_IMAGES_KEY,
|
|
71
|
+
isLocalFile,
|
|
72
|
+
process_image,
|
|
73
|
+
createAsyncTaskResponse,
|
|
74
|
+
VIDEO_INITIAL_TIMEOUT,
|
|
75
|
+
VIDEO_POLLING_DELAY,
|
|
76
|
+
WEBHOOK_TIMEOUT,
|
|
77
|
+
IMAGE_INFERENCE_TIMEOUT,
|
|
78
|
+
IMAGE_OPERATION_TIMEOUT,
|
|
79
|
+
PROMPT_ENHANCE_TIMEOUT,
|
|
80
|
+
IMAGE_UPLOAD_TIMEOUT,
|
|
81
|
+
AUDIO_INFERENCE_TIMEOUT,
|
|
82
|
+
AUDIO_POLLING_DELAY,
|
|
83
|
+
MAX_POLLS_AUDIO_GENERATION,
|
|
84
|
+
MAX_POLLS_VIDEO_GENERATION,
|
|
85
|
+
)
|
|
86
|
+
|
|
87
|
+
# Apply the package's logging configuration at import time with the CRITICAL
# threshold; RunwareBase.__init__ calls configure_logging again with the
# caller-supplied level.
configure_logging(log_level=logging.CRITICAL)

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
class RunwareBase:
|
|
94
|
+
def __init__(
    self,
    api_key: str,
    url: str = BASE_RUNWARE_URLS[Environment.PRODUCTION],
    timeout: int = TIMEOUT_DURATION,
    log_level=logging.CRITICAL,
):
    """
    Validate the timeout, configure logging and initialise the client state.

    :param api_key: API key stored on the instance as ``_apiKey``.
    :param url: Endpoint URL; defaults to the production entry of
        ``BASE_RUNWARE_URLS``.
    :param timeout: Timeout value stored as ``_timeout``; must be positive.
    :param log_level: Level passed to ``configure_logging`` and applied to
        this module's logger.
    :raises ValueError: If ``timeout`` is zero or negative.
    """
    if timeout <= 0:
        raise ValueError("Timeout must be greater than 0 milliseconds")

    # Logging is (re)configured per instance with the requested level.
    configure_logging(log_level)
    instance_logger = logging.getLogger(__name__)
    instance_logger.setLevel(log_level)
    self.logger = instance_logger

    # Caller-supplied settings.
    self._apiKey: str = api_key
    self._url: Optional[str] = url
    self._timeout: int = timeout
    self._sdkType: SdkType = SdkType.SERVER

    # Connection / session state; populated once a connection exists.
    self._ws: Optional[ReconnectingWebsocketProps] = None
    self._connectionSessionUUID: Optional[str] = None
    self._invalidAPIkey: Optional[str] = None

    # Message, image and error stores shared between listeners and requests.
    self._listeners: List[ListenerType] = []
    self._globalMessages: Dict[str, Any] = {}
    self._globalImages: List[IImage] = []
    self._globalError: Optional[IError] = None

    # Locks guarding the shared stores, and the set of in-flight listener tasks.
    self._messages_lock = asyncio.Lock()
    self._images_lock = asyncio.Lock()
    self._listener_tasks = set()

    self._reconnection_manager = ReconnectionManager(logger=self.logger)
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def _create_safe_async_listener(self, async_func):
|
|
127
|
+
def wrapper(m):
|
|
128
|
+
task = asyncio.create_task(async_func(m))
|
|
129
|
+
self._listener_tasks.add(task)
|
|
130
|
+
def handle_task_exception(t):
|
|
131
|
+
self._listener_tasks.discard(t)
|
|
132
|
+
if not t.cancelled():
|
|
133
|
+
try:
|
|
134
|
+
t.result()
|
|
135
|
+
except Exception as e:
|
|
136
|
+
logger.error(f"Unhandled exception in async listener: {e}", exc_info=True)
|
|
137
|
+
|
|
138
|
+
task.add_done_callback(handle_task_exception)
|
|
139
|
+
return None
|
|
140
|
+
|
|
141
|
+
return wrapper
|
|
142
|
+
|
|
143
|
+
async def _cleanup_listener_tasks(self):
|
|
144
|
+
if not self._listener_tasks:
|
|
145
|
+
return
|
|
146
|
+
|
|
147
|
+
self.logger.info(f"Cleaning up {len(self._listener_tasks)} listener tasks")
|
|
148
|
+
|
|
149
|
+
for task in list(self._listener_tasks):
|
|
150
|
+
if not task.done():
|
|
151
|
+
task.cancel()
|
|
152
|
+
|
|
153
|
+
if self._listener_tasks:
|
|
154
|
+
await asyncio.gather(*self._listener_tasks, return_exceptions=True)
|
|
155
|
+
|
|
156
|
+
self._listener_tasks.clear()
|
|
157
|
+
self.logger.info("All listener tasks cleaned up")
|
|
158
|
+
|
|
159
|
+
def isWebsocketReadyState(self) -> bool:
    """Return True only when a websocket exists and its state is ``State.OPEN``."""
    socket = self._ws
    return socket is not None and socket.state is State.OPEN
|
|
163
|
+
|
|
164
|
+
def isAuthenticated(self):
    """Return True when a connection session UUID has been stored."""
    session_uuid = self._connectionSessionUUID
    if session_uuid is None:
        return False
    return True
|
|
166
|
+
|
|
167
|
+
def addListener(
    self,
    lis: Callable[[Any], Any],
    check: Callable[[Any], Any],
    groupKey: Optional[str] = None,
) -> Dict[str, Callable[[], None]]:
    """
    Register a message listener and return a handle that removes it.

    The registered callback forwards a message to ``lis`` when the message
    has a truthy ``"error"`` entry, or otherwise when ``check(msg)`` is
    truthy.

    :param lis: Callback invoked with each accepted message.
    :param check: Predicate deciding whether a non-error message is accepted.
    :param groupKey: Optional group key stored on the listener record.
    :return: ``{"destroy": fn}`` where ``fn()`` removes this listener from
        ``self._listeners``.
    :raises ValueError: If ``lis`` or ``check`` is missing.
    """
    # Validate before doing any other work.
    if not lis or not check:
        raise ValueError("Listener and check functions are required")

    # Caller details are recorded for debugging only. inspect.currentframe()
    # may return None on interpreters without stack-frame support, so fall
    # back to placeholders instead of failing on attribute access.
    caller_name = "<unknown>"
    caller_line_number = "?"
    current_frame = inspect.currentframe()
    caller_frame = current_frame.f_back if current_frame is not None else None
    if caller_frame is not None:
        caller_name = caller_frame.f_code.co_name
        caller_line_number = caller_frame.f_lineno
    # Drop the frame references so they do not keep reference cycles alive.
    del current_frame, caller_frame

    debug_message = f"Listener {self.addListener.__name__} created by {caller_name} at line {caller_line_number} with listener: {lis} and check: {check}"

    def listener(msg: Any) -> None:
        # Error messages always reach the callback; others must pass check().
        if msg.get("error") or check(msg):
            lis(msg)

    groupListener: ListenerType = ListenerType(
        key=getUUID(),
        listener=listener,
        group_key=groupKey,
        debug_message=debug_message,
    )
    self._listeners.append(groupListener)

    def destroy() -> None:
        self._listeners = removeListener(self._listeners, groupListener)

    return {"destroy": destroy}
|
|
211
|
+
|
|
212
|
+
def handle_connection_response(self, m):
    """
    Record the outcome of a connection/authentication response.

    On an error message, ``self._invalidAPIkey`` is set to
    ``"Invalid API key"`` when ``errorId`` is 19 and to
    ``"Error connection"`` for any other (or missing) ``errorId``; the
    session UUID is left untouched. Otherwise the session UUID is read from
    ``m["newConnectionSessionUUID"]["connectionSessionUUID"]`` (``None`` if
    absent) and ``self._invalidAPIkey`` is cleared.

    :param m: Response message as a dict.
    """
    if m.get("error"):
        # .get avoids a KeyError when the error payload carries no errorId.
        if m.get("errorId") == 19:
            self._invalidAPIkey = "Invalid API key"
        else:
            self._invalidAPIkey = "Error connection"
        return

    # "or {}" also covers an explicit None value for the key.
    session_info = m.get("newConnectionSessionUUID") or {}
    self._connectionSessionUUID = session_info.get("connectionSessionUUID")
    self._invalidAPIkey = None
|
|
223
|
+
|
|
224
|
+
async def photoMaker(self, requestPhotoMaker: IPhotoMaker) -> Union[List[IImage], IAsyncTaskResponse]:
    """
    Send a PhotoMaker task and wait for its results.

    The request's ``taskUUID`` is filled in when missing, and local input
    images are replaced in place by the result of ``fileToBase64``. When a
    ``webhookURL`` is set, the webhook acknowledgment is returned instead of
    images. Otherwise incoming messages for the task are polled until
    ``numberResults`` distinct images (by ``imageUUID``) have arrived.

    :param requestPhotoMaker: The PhotoMaker request; mutated in place
        (``taskUUID`` and ``inputImages``).
    :return: A list of ``IImage``, the webhook acknowledgment, or ``None``
        if polling produced an empty response.
    :raises RunwareAPIError: If a message for the task carries a ``code``
        entry, or the polled response is an error payload.
    """
    # NOTE: the original wrapped this body in a try/except whose
    # "failed after retries" branch could never run (its counter was never
    # incremented); exceptions propagate unchanged, exactly as before.
    await self.ensureConnection()

    task_uuid = requestPhotoMaker.taskUUID or getUUID()
    requestPhotoMaker.taskUUID = task_uuid

    # inputImages is treated as optional further down, so tolerate None here.
    for i, image in enumerate(requestPhotoMaker.inputImages or []):
        if isLocalFile(image) and not str(image).startswith("http"):
            requestPhotoMaker.inputImages[i] = await fileToBase64(image)

    request_object = {
        "taskUUID": requestPhotoMaker.taskUUID,
        "model": requestPhotoMaker.model,
        "positivePrompt": f"{requestPhotoMaker.positivePrompt}".strip(),
        "numberResults": requestPhotoMaker.numberResults,
        "height": requestPhotoMaker.height,
        "width": requestPhotoMaker.width,
        "taskType": ETaskType.PHOTO_MAKER.value,
        "style": requestPhotoMaker.style,
        "strength": requestPhotoMaker.strength,
    }

    # Optional fields are only sent when set.
    if requestPhotoMaker.inputImages:
        request_object["inputImages"] = requestPhotoMaker.inputImages
    if requestPhotoMaker.steps:
        request_object["steps"] = requestPhotoMaker.steps
    if requestPhotoMaker.outputFormat is not None:
        request_object["outputFormat"] = requestPhotoMaker.outputFormat
    if requestPhotoMaker.includeCost:
        request_object["includeCost"] = requestPhotoMaker.includeCost
    if requestPhotoMaker.outputType:
        request_object["outputType"] = requestPhotoMaker.outputType
    if requestPhotoMaker.webhookURL:
        request_object["webhookURL"] = requestPhotoMaker.webhookURL

    await self.send([request_object])

    if requestPhotoMaker.webhookURL:
        return await self._handleWebhookAcknowledgment(
            task_uuid=task_uuid,
            task_type="photoMaker",
            debug_key="photo-maker-webhook"
        )

    lis = self.globalListener(
        taskUUID=task_uuid,
    )

    numberOfResults = requestPhotoMaker.numberResults

    async def check(resolve: Callable[..., Any], reject: Callable[..., Any], *args: Any) -> bool:
        async with self._messages_lock:
            photo_maker_list = self._globalMessages.get(task_uuid, [])
            unique_results = {}

            for made_photo in photo_maker_list:
                if made_photo.get("code"):
                    raise RunwareAPIError(made_photo)

                if made_photo.get("taskType") != "photoMaker":
                    continue

                # Keep only the first message seen per imageUUID.
                image_uuid = made_photo.get("imageUUID")
                if image_uuid not in unique_results:
                    unique_results[image_uuid] = made_photo

            if 0 < numberOfResults <= len(unique_results):
                del self._globalMessages[task_uuid]
                resolve(list(unique_results.values()))
                return True

        return False

    try:
        response = await getIntervalWithPromise(check, debugKey="photo-maker", timeOutDuration=IMAGE_INFERENCE_TIMEOUT)
    finally:
        # Remove the listener even when polling raises, so it cannot leak.
        lis["destroy"]()

    if not response:
        return None

    if "code" in response:
        # This indicates an error response
        raise RunwareAPIError(response)

    if not isinstance(response, list):
        response = [response]

    return instantiateDataclassList(IImage, response)
|
|
327
|
+
|
|
328
|
+
async def imageInference(
    self, requestImage: IImageInference
) -> Union[List[IImage], IAsyncTaskResponse]:
    """
    Prepare an image-inference request and submit it through ``asyncRetry``.

    Image-valued fields (mask, seed, reference images, InstantID, IP-adapter,
    ACE++ and PuLID inputs) are passed through ``process_image``; ControlNet
    guide images are uploaded via ``uploadImage`` and replaced by the
    returned ``imageUUID``. The request is then assembled by
    ``_buildImageRequest`` and sent by ``_requestImages``.

    :param requestImage: The inference request; several of its fields are
        rewritten in place during preparation.
    :return: Whatever ``_requestImages`` returns, or ``[]`` when a
        ControlNet guide image could not be uploaded.
    """
    # NOTE: the original wrapped this body in a try/except whose
    # "failed after retries" branch was unreachable (the counter it tested
    # is passed by value and never changes here); exceptions propagate
    # unchanged, exactly as before.
    let_lis: Optional[Any] = None
    task_uuids: List[str] = []
    retry_count = 0

    await self.ensureConnection()

    requestImage.maskImage = await process_image(requestImage.maskImage)
    requestImage.seedImage = await process_image(requestImage.seedImage)
    if requestImage.referenceImages:
        requestImage.referenceImages = await process_image(
            requestImage.referenceImages
        )

    control_net_data: List[IControlNet] = []
    if requestImage.controlNet:
        for control_data in requestImage.controlNet:
            image_uploaded = await self.uploadImage(control_data.guideImage)
            if not image_uploaded:
                return []
            # Convert the preprocessor to its raw value only when it still
            # has one: the object is mutated in place, so a reused request
            # may already hold the converted value (or None), and calling
            # ``.value`` on that would raise AttributeError.
            preprocessor = getattr(control_data, "preprocessor", None)
            if hasattr(preprocessor, "value"):
                control_data.preprocessor = preprocessor.value
            control_data.guideImage = image_uploaded.imageUUID
            control_net_data.append(control_data)

    prompt = f"{requestImage.positivePrompt}".strip()
    control_net_data_dicts = [asdict(item) for item in control_net_data]

    instant_id_data = {}
    if requestImage.instantID:
        instant_id_data = {
            k: v
            for k, v in vars(requestImage.instantID).items()
            if v is not None
        }
        for image_key in ("inputImage", "poseImage"):
            if image_key in instant_id_data:
                instant_id_data[image_key] = await process_image(
                    instant_id_data[image_key]
                )

    ip_adapters_data = []
    if requestImage.ipAdapters:
        for ip_adapter in requestImage.ipAdapters:
            ip_adapter_data = {
                k: v for k, v in vars(ip_adapter).items() if v is not None
            }
            if "guideImage" in ip_adapter_data:
                ip_adapter_data["guideImage"] = await process_image(
                    ip_adapter_data["guideImage"]
                )
            ip_adapters_data.append(ip_adapter_data)

    ace_plus_plus_data = {}
    if requestImage.acePlusPlus:
        ace_plus_plus_data = {
            "inputImages": [],
            "repaintingScale": requestImage.acePlusPlus.repaintingScale,
            "type": requestImage.acePlusPlus.taskType,
        }
        if requestImage.acePlusPlus.inputImages:
            ace_plus_plus_data["inputImages"] = await process_image(
                requestImage.acePlusPlus.inputImages
            )
        if requestImage.acePlusPlus.inputMasks:
            ace_plus_plus_data["inputMasks"] = await process_image(
                requestImage.acePlusPlus.inputMasks
            )

    pulid_data = {}
    if requestImage.puLID:
        pulid_data = {
            "inputImages": [],
            "idWeight": requestImage.puLID.idWeight,
            "trueCFGScale": requestImage.puLID.trueCFGScale,
            "CFGStartStep": requestImage.puLID.CFGStartStep,
            "CFGStartStepPercentage": requestImage.puLID.CFGStartStepPercentage,
        }
        if requestImage.puLID.inputImages:
            pulid_data["inputImages"] = await process_image(
                requestImage.puLID.inputImages
            )

    request_object = self._buildImageRequest(requestImage, prompt, control_net_data_dicts, instant_id_data, ip_adapters_data, ace_plus_plus_data, pulid_data)

    return await asyncRetry(
        lambda: self._requestImages(
            request_object=request_object,
            task_uuids=task_uuids,
            let_lis=let_lis,
            retry_count=retry_count,
            number_of_images=requestImage.numberResults,
            on_partial_images=requestImage.onPartialImages,
        )
    )
|
|
436
|
+
|
|
437
|
+
async def _requestImages(
    self,
    request_object: Dict[str, Any],
    task_uuids: List[str],
    let_lis: Optional[Any],
    retry_count: int,
    number_of_images: int,
    on_partial_images: Optional[Callable[[List[IImage], Optional[IError]], None]],
) -> Union[List[IImage], IAsyncTaskResponse]:
    """
    Send one image-inference attempt and collect its images.

    Images already stored in ``self._globalImages`` for any UUID in
    ``task_uuids`` are counted, and only the remainder is requested. The
    task UUID used for this attempt is appended to ``task_uuids`` (the
    caller's list is mutated).

    :param request_object: Prepared request payload.
    :param task_uuids: Task UUIDs from earlier attempts; extended in place.
    :param let_lis: Listener handle from a previous attempt; destroyed first
        when given.
    :param retry_count: Accepted for interface compatibility; not used.
    :param number_of_images: Total number of images wanted.
    :param on_partial_images: Callback forwarded to ``listenToImages``.
    :return: A list of ``IImage``, the webhook acknowledgment when the
        request has a ``webhookURL``, or ``None`` when no images came back.
    :raises RunwareAPIError: If the collected result is an error payload.
    """
    if let_lis:
        let_lis["destroy"]()

    already_received = sum(
        1 for img in self._globalImages if img.get("taskUUID") in task_uuids
    )

    task_uuid = request_object.get("taskUUID")
    if task_uuid is None:
        task_uuid = getUUID()

    task_uuids.append(task_uuid)

    new_request_object = {
        **request_object,
        "taskUUID": task_uuid,
        "numberResults": number_of_images - already_received,
    }

    await self.send([new_request_object])

    if new_request_object.get("webhookURL"):
        return await self._handleWebhookAcknowledgment(
            task_uuid=task_uuid,
            task_type="imageInference",
            debug_key="image-inference-webhook"
        )

    let_lis = await self.listenToImages(
        onPartialImages=on_partial_images,
        taskUUID=task_uuid,
        groupKey=LISTEN_TO_IMAGES_KEY.REQUEST_IMAGES,
    )
    try:
        images = await self.getSimililarImage(
            taskUUID=task_uuids,
            numberOfImages=number_of_images,
            shouldThrowError=True,
            lis=let_lis,
        )
    finally:
        # Remove the listener even when collection raises, so a failed
        # attempt does not leave it registered.
        let_lis["destroy"]()

    if not images:
        return None

    if "code" in images:
        # This indicates an error response
        raise RunwareAPIError(images)

    return instantiateDataclassList(IImage, images)
|
|
497
|
+
|
|
498
|
+
async def imageCaption(self, requestImageToText: IImageCaption) -> Union[IImageToText, IAsyncTaskResponse]:
    """
    Ensure a connection, then run the image-caption request through
    ``asyncRetry``. Any exception is propagated to the caller unchanged.
    """
    await self.ensureConnection()
    return await asyncRetry(
        lambda: self._requestImageToText(requestImageToText)
    )
|
|
506
|
+
|
|
507
|
+
async def _requestImageToText(
    self, requestImageToText: IImageCaption
) -> Union[IImageToText, IAsyncTaskResponse]:
    """
    Upload the input image(s), send an image-caption task and wait for it.

    ``inputImages`` takes precedence over ``inputImage``. A single uploaded
    image is sent as ``inputImage``; any other count is sent as
    ``inputImages``. ``prompt`` is only sent when no ``template`` is set.

    :param requestImageToText: The caption request.
    :return: The parsed caption result, the webhook acknowledgment when a
        ``webhookURL`` is set, or ``None`` when an upload fails or polling
        yields nothing.
    :raises ValueError: If neither ``inputImages`` nor ``inputImage`` is set.
    :raises RunwareAPIError: If the polled response is an error payload.
    """
    # Prepare image list - inputImages is primary, inputImage is convenience
    if requestImageToText.inputImages is not None:
        images_to_process = requestImageToText.inputImages
    elif requestImageToText.inputImage is not None:
        images_to_process = [requestImageToText.inputImage]
    else:
        raise ValueError("Either inputImages or inputImage must be provided")

    # Upload all images; give up as soon as one upload yields no UUID.
    uploaded_images = []
    for image in images_to_process:
        image_uploaded = await self.uploadImage(image)
        if not image_uploaded or not image_uploaded.imageUUID:
            return None
        uploaded_images.append(image_uploaded.imageUUID)

    taskUUID = getUUID()

    task_params = {
        "taskType": ETaskType.IMAGE_CAPTION.value,
        "taskUUID": taskUUID,
    }

    # Add either inputImage or inputImages, but not both (API requirement)
    if len(uploaded_images) == 1:
        task_params["inputImage"] = uploaded_images[0]
    else:
        task_params["inputImages"] = uploaded_images

    # Add model parameter only if specified - backend handles default
    if requestImageToText.model is not None:
        task_params["model"] = requestImageToText.model

    # When using template, do NOT include prompt parameter
    if requestImageToText.template is not None:
        task_params["template"] = requestImageToText.template
    else:
        task_params["prompt"] = requestImageToText.prompt

    if requestImageToText.includeCost:
        task_params["includeCost"] = requestImageToText.includeCost
    if requestImageToText.webhookURL:
        task_params["webhookURL"] = requestImageToText.webhookURL

    await self.send([task_params])

    if requestImageToText.webhookURL:
        return await self._handleWebhookAcknowledgment(
            task_uuid=taskUUID,
            task_type="imageCaption",
            debug_key="image-caption-webhook"
        )

    lis = self.globalListener(
        taskUUID=taskUUID,
    )

    async def check(resolve: Callable[..., Any], reject: Callable[..., Any], *args: Any) -> bool:
        async with self._messages_lock:
            messages = self._globalMessages.get(taskUUID)
            # Only the first stored message for the task is considered.
            image_to_text = messages[0] if messages else None

            if image_to_text and image_to_text.get("error"):
                reject(image_to_text)
                return True

            if image_to_text:
                del self._globalMessages[taskUUID]
                resolve(image_to_text)
                return True

        return False

    try:
        response = await getIntervalWithPromise(
            check, debugKey="image-to-text", timeOutDuration=IMAGE_OPERATION_TIMEOUT
        )
    finally:
        # Remove the listener even when polling raises, so it cannot leak.
        lis["destroy"]()

    # Test for emptiness first: ``"code" in None`` would raise TypeError.
    if not response:
        return None

    if "code" in response:
        # This indicates an error response
        raise RunwareAPIError(response)

    return createImageToTextFromResponse(response)
|
|
610
|
+
|
|
611
|
+
async def videoCaption(self, requestVideoCaption: IVideoCaption) -> Union[List[IVideoToText], IAsyncTaskResponse]:
    """
    Ensure a connection, then run the video-caption request through
    ``asyncRetry``. Any exception is propagated to the caller unchanged.
    """
    await self.ensureConnection()
    return await asyncRetry(
        lambda: self._requestVideoCaption(requestVideoCaption)
    )
|
|
619
|
+
|
|
620
|
+
async def _requestVideoCaption(
    self, requestVideoCaption: IVideoCaption
) -> Union[List[IVideoToText], IAsyncTaskResponse]:
    """
    Send a video-caption task and return its outcome.

    With a ``webhookURL`` the webhook acknowledgment is returned; otherwise
    the result comes from ``_pollVideoResults`` for a single
    ``IVideoToText`` item.
    """
    req = requestVideoCaption
    task_id = req.taskUUID or getUUID()

    # Mandatory fields.
    payload = {
        "taskType": ETaskType.VIDEO_CAPTION.value,
        "taskUUID": task_id,
        "model": req.model,
        "inputs": {"video": req.inputs.video},
        "deliveryMethod": req.deliveryMethod,
    }

    # Optional fields: includeCost is sent whenever it is not None,
    # webhookURL only when truthy.
    if req.includeCost is not None:
        payload["includeCost"] = req.includeCost
    if req.webhookURL:
        payload["webhookURL"] = req.webhookURL

    await self.send([payload])

    if not req.webhookURL:
        # No webhook: poll for the result.
        return await self._pollVideoResults(task_id, 1, IVideoToText)

    return await self._handleWebhookAcknowledgment(
        task_uuid=task_id,
        task_type="caption",
        debug_key="video-caption-webhook"
    )
|
|
653
|
+
|
|
654
|
+
async def videoBackgroundRemoval(self, requestVideoBackgroundRemoval: IVideoBackgroundRemoval) -> Union[List[IVideo], IAsyncTaskResponse]:
    """
    Ensure a connection, then run the video background-removal request
    through ``asyncRetry``. Any exception is propagated unchanged.
    """
    await self.ensureConnection()
    return await asyncRetry(
        lambda: self._requestVideoBackgroundRemoval(requestVideoBackgroundRemoval)
    )
|
|
663
|
+
|
|
664
|
+
async def _requestVideoBackgroundRemoval(
    self, requestVideoBackgroundRemoval: IVideoBackgroundRemoval
) -> Union[List[IVideo], IAsyncTaskResponse]:
    """
    Send a video background-removal task and return its outcome.

    With a ``webhookURL`` the webhook acknowledgment is returned; otherwise
    the result comes from ``_pollVideoResults`` for a single ``IVideo``.
    """
    req = requestVideoBackgroundRemoval
    task_id = req.taskUUID or getUUID()

    # Mandatory fields.
    payload = {
        "taskType": ETaskType.VIDEO_BACKGROUND_REMOVAL.value,
        "taskUUID": task_id,
        "model": req.model,
        "inputs": {"video": req.inputs.video},
        "deliveryMethod": req.deliveryMethod,
    }

    # Optional fields.
    if req.outputFormat:
        payload["outputFormat"] = req.outputFormat
    if req.includeCost is not None:
        payload["includeCost"] = req.includeCost
    if req.webhookURL:
        payload["webhookURL"] = req.webhookURL
    if req.settings:
        # Copy the settings object's attributes, leaving out unset (None) ones.
        chosen_settings = {}
        for field_name, field_value in vars(req.settings).items():
            if field_value is not None:
                chosen_settings[field_name] = field_value
        payload["settings"] = chosen_settings

    await self.send([payload])

    if not req.webhookURL:
        return await self._pollVideoResults(task_id, 1, IVideo)

    return await self._handleWebhookAcknowledgment(
        task_uuid=task_id,
        task_type="removeBackground",
        debug_key="video-background-removal-webhook"
    )
|
|
706
|
+
|
|
707
|
+
async def videoUpscale(self, requestVideoUpscale: IVideoUpscale) -> Union[List[IVideo], IAsyncTaskResponse]:
    """
    Ensure a connection, then run the video-upscale request through
    ``asyncRetry``. Any exception is propagated to the caller unchanged.
    """
    await self.ensureConnection()
    return await asyncRetry(
        lambda: self._requestVideoUpscale(requestVideoUpscale)
    )
|
|
715
|
+
|
|
716
|
+
async def _requestVideoUpscale(
    self, requestVideoUpscale: IVideoUpscale
) -> Union[List[IVideo], IAsyncTaskResponse]:
    """
    Send a video-upscale task and return its outcome.

    With a ``webhookURL`` the webhook acknowledgment is returned; otherwise
    the result comes from ``_pollVideoResults`` for a single ``IVideo``.
    """
    req = requestVideoUpscale
    task_id = req.taskUUID or getUUID()

    # Mandatory fields.
    payload = {
        "taskType": ETaskType.VIDEO_UPSCALE.value,
        "taskUUID": task_id,
        "model": req.model,
        "inputs": {"video": req.inputs.video},
        "upscaleFactor": req.upscaleFactor,
        "deliveryMethod": req.deliveryMethod,
    }

    # Optional fields: includeCost is sent whenever it is not None, the
    # others only when truthy.
    if req.outputFormat:
        payload["outputFormat"] = req.outputFormat
    if req.outputType:
        payload["outputType"] = req.outputType
    if req.includeCost is not None:
        payload["includeCost"] = req.includeCost
    if req.webhookURL:
        payload["webhookURL"] = req.webhookURL

    await self.send([payload])

    if not req.webhookURL:
        return await self._pollVideoResults(task_id, 1, IVideo)

    return await self._handleWebhookAcknowledgment(
        task_uuid=task_id,
        task_type="upscale",
        debug_key="video-upscale-webhook"
    )
|
|
753
|
+
|
|
754
|
+
async def imageBackgroundRemoval(
    self, removeImageBackgroundPayload: IImageBackgroundRemoval
) -> Union[List[IImage], IAsyncTaskResponse]:
    """
    Ensure a connection, then run the image background-removal request
    through ``asyncRetry``. Any exception is propagated unchanged.
    """
    await self.ensureConnection()
    return await asyncRetry(
        lambda: self._removeImageBackground(removeImageBackgroundPayload)
    )
|
|
764
|
+
|
|
765
|
+
async def _removeImageBackground(
    self, removeImageBackgroundPayload: IImageBackgroundRemoval
) -> Union[List[IImage], IAsyncTaskResponse]:
    """
    Upload the input image, submit a background-removal task and wait for the result.

    :param removeImageBackgroundPayload: The background-removal request.
    :return: An empty list when the image upload yields no ``imageUUID``; the
        result of ``_handleWebhookAcknowledgment`` when a ``webhookURL`` is set;
        otherwise a one-element list built with ``createImageFromResponse``.
    :raises RunwareAPIError: If the received response contains a ``"code"`` key.
    """
    inputImage = removeImageBackgroundPayload.inputImage

    image_uploaded = await self.uploadImage(inputImage)

    if not image_uploaded or not image_uploaded.imageUUID:
        return []

    # Honour a caller-supplied task UUID, otherwise generate one.
    if removeImageBackgroundPayload.taskUUID is not None:
        taskUUID = removeImageBackgroundPayload.taskUUID
    else:
        taskUUID = getUUID()

    # Create a dictionary with mandatory parameters
    task_params = {
        "taskType": ETaskType.IMAGE_BACKGROUND_REMOVAL.value,
        "taskUUID": taskUUID,
        "inputImage": image_uploaded.imageUUID,
    }

    # Add optional parameters if they are provided
    if removeImageBackgroundPayload.outputType is not None:
        task_params["outputType"] = removeImageBackgroundPayload.outputType
    if removeImageBackgroundPayload.outputFormat is not None:
        task_params["outputFormat"] = removeImageBackgroundPayload.outputFormat
    if removeImageBackgroundPayload.includeCost:
        task_params["includeCost"] = removeImageBackgroundPayload.includeCost
    if removeImageBackgroundPayload.model:
        task_params["model"] = removeImageBackgroundPayload.model
    if removeImageBackgroundPayload.outputQuality:
        task_params["outputQuality"] = removeImageBackgroundPayload.outputQuality
    if removeImageBackgroundPayload.webhookURL:
        task_params["webhookURL"] = removeImageBackgroundPayload.webhookURL

    # Settings are merged into the top level of the task (not nested under a
    # "settings" key); attributes that are None are left out.
    if removeImageBackgroundPayload.settings:
        settings_dict = {
            k: v
            for k, v in vars(removeImageBackgroundPayload.settings).items()
            if v is not None
        }
        task_params.update(settings_dict)

    # Add provider settings if provided
    if removeImageBackgroundPayload.providerSettings:
        self._addImageProviderSettings(task_params, removeImageBackgroundPayload)

    # Add safety settings if provided
    if removeImageBackgroundPayload.safety:
        self._addSafetySettings(task_params, removeImageBackgroundPayload.safety)

    # Send the task with all applicable parameters
    await self.send([task_params])

    if removeImageBackgroundPayload.webhookURL:
        return await self._handleWebhookAcknowledgment(
            task_uuid=taskUUID,
            task_type="imageBackgroundRemoval",
            debug_key="image-background-removal-webhook"
        )

    lis = self.globalListener(
        taskUUID=taskUUID,
    )

    async def check(resolve: callable, reject: callable, *args: Any) -> bool:
        async with self._messages_lock:
            response = self._globalMessages.get(taskUUID)
            # globalListener stores either a list of result dicts or a single
            # dict (e.g. an error message); indexing a dict with [0] would
            # raise KeyError, so both shapes are handled explicitly.
            if isinstance(response, dict):
                new_remove_background = response
            elif response:
                new_remove_background = response[0]
            else:
                new_remove_background = None

            if not new_remove_background:
                return False

            # Consume the entry in both outcomes so a stale error cannot leak
            # into a later attempt that reuses the same task UUID.
            del self._globalMessages[taskUUID]
            if new_remove_background.get("error"):
                reject(new_remove_background)
            else:
                resolve(new_remove_background)
            return True

    try:
        response = await getIntervalWithPromise(
            check, debugKey="remove-image-background", timeOutDuration=IMAGE_OPERATION_TIMEOUT
        )
    finally:
        # Always unregister the listener, including when the wait raises.
        lis["destroy"]()

    if "code" in response:
        # This indicates an error response
        raise RunwareAPIError(response)

    image = createImageFromResponse(response)
    image_list: List[IImage] = [image]

    return image_list
|
|
863
|
+
|
|
864
|
+
async def imageUpscale(self, upscaleGanPayload: IImageUpscale) -> Union[List[IImage], IAsyncTaskResponse]:
    """
    Public entry point for image upscaling.

    Makes sure the connection is up, then runs ``_upscaleGan`` through
    ``asyncRetry``.

    :param upscaleGanPayload: The upscale request.
    :return: Whatever ``_upscaleGan`` returns.
    :raises: Any exception raised by ``ensureConnection`` or by the retried
        call propagates to the caller unchanged.
    """
    await self.ensureConnection()

    def attempt():
        # Build a fresh coroutine for every attempt made by asyncRetry.
        return self._upscaleGan(upscaleGanPayload)

    return await asyncRetry(attempt)
|
|
870
|
+
|
|
871
|
+
async def _upscaleGan(self, upscaleGanPayload: IImageUpscale) -> Union[List[IImage], IAsyncTaskResponse]:
    """
    Upload the input image, submit an upscale task and wait for the result.

    :param upscaleGanPayload: The upscale request. The image is taken from
        ``inputImage`` (legacy) or, when that is empty, from ``inputs.image``.
    :return: An empty list when the image upload yields no ``imageUUID``; the
        result of ``_handleWebhookAcknowledgment`` when a ``webhookURL`` is set;
        otherwise a one-element list built with ``createImageFromResponse``.
    :raises ValueError: If neither ``inputImage`` nor ``inputs.image`` is provided.
    :raises RunwareAPIError: If the received response contains a ``"code"`` key.
    """
    # Support both inputImage (legacy) and inputs.image (new format)
    inputImage = upscaleGanPayload.inputImage
    if not inputImage and upscaleGanPayload.inputs and upscaleGanPayload.inputs.image:
        inputImage = upscaleGanPayload.inputs.image

    if not inputImage:
        raise ValueError("Either inputImage or inputs.image must be provided")

    image_uploaded = await self.uploadImage(inputImage)

    if not image_uploaded or not image_uploaded.imageUUID:
        return []

    taskUUID = getUUID()

    # Create a dictionary with mandatory parameters
    task_params = {
        "taskType": ETaskType.IMAGE_UPSCALE.value,
        "taskUUID": taskUUID,
        "upscaleFactor": upscaleGanPayload.upscaleFactor,
    }

    # Use inputs.image format if inputs is provided, otherwise use inputImage (legacy)
    if upscaleGanPayload.inputs and upscaleGanPayload.inputs.image:
        task_params["inputs"] = {"image": image_uploaded.imageUUID}
    else:
        task_params["inputImage"] = image_uploaded.imageUUID

    # Add model parameter if specified
    if upscaleGanPayload.model is not None:
        task_params["model"] = upscaleGanPayload.model

    # Add settings if provided; None values are dropped and an empty result
    # is not sent at all.
    if upscaleGanPayload.settings is not None:
        settings_dict = {
            k: v
            for k, v in asdict(upscaleGanPayload.settings).items()
            if v is not None
        }
        if settings_dict:
            task_params["settings"] = settings_dict

    # Add optional parameters if they are provided
    if upscaleGanPayload.outputType is not None:
        task_params["outputType"] = upscaleGanPayload.outputType
    if upscaleGanPayload.outputFormat is not None:
        task_params["outputFormat"] = upscaleGanPayload.outputFormat
    if upscaleGanPayload.includeCost:
        task_params["includeCost"] = upscaleGanPayload.includeCost
    if upscaleGanPayload.webhookURL:
        task_params["webhookURL"] = upscaleGanPayload.webhookURL

    # Add provider settings if provided
    if upscaleGanPayload.providerSettings:
        self._addImageProviderSettings(task_params, upscaleGanPayload)

    # Add safety settings if provided
    if upscaleGanPayload.safety:
        self._addSafetySettings(task_params, upscaleGanPayload.safety)

    # Send the task with all applicable parameters
    await self.send([task_params])

    if upscaleGanPayload.webhookURL:
        return await self._handleWebhookAcknowledgment(
            task_uuid=taskUUID,
            task_type="imageUpscale",
            debug_key="image-upscale-webhook"
        )

    lis = self.globalListener(
        taskUUID=taskUUID,
    )

    async def check(resolve: callable, reject: callable, *args: Any) -> bool:
        async with self._messages_lock:
            response = self._globalMessages.get(taskUUID)
            # globalListener stores either a list of result dicts or a single
            # dict (e.g. an error message); indexing a dict with [0] would
            # raise KeyError, so both shapes are handled explicitly.
            if isinstance(response, dict):
                upscaled_image = response
            elif response:
                upscaled_image = response[0]
            else:
                upscaled_image = None

            if not upscaled_image:
                return False

            # Consume the entry in both outcomes so it does not linger in
            # the shared message store.
            del self._globalMessages[taskUUID]
            if upscaled_image.get("error"):
                reject(upscaled_image)
            else:
                resolve(upscaled_image)
            return True

    try:
        response = await getIntervalWithPromise(
            check, debugKey="upscale-gan", timeOutDuration=IMAGE_OPERATION_TIMEOUT
        )
    finally:
        # Always unregister the listener, including when the wait raises.
        lis["destroy"]()

    if "code" in response:
        # This indicates an error response
        raise RunwareAPIError(response)

    image = createImageFromResponse(response)
    image_list: List[IImage] = [image]
    return image_list
|
|
978
|
+
|
|
979
|
+
async def imageVectorize(self, vectorizePayload: IVectorize) -> Union[List[IImage], IAsyncTaskResponse]:
    """
    Public entry point for image vectorization.

    Makes sure the connection is up, then runs ``_vectorize`` through
    ``asyncRetry``.

    :param vectorizePayload: The vectorize request.
    :return: Whatever ``_vectorize`` returns.
    :raises: Any exception raised by ``ensureConnection`` or by the retried
        call propagates to the caller unchanged.
    """
    await self.ensureConnection()

    def attempt():
        # Build a fresh coroutine for every attempt made by asyncRetry.
        return self._vectorize(vectorizePayload)

    return await asyncRetry(attempt)
|
|
985
|
+
|
|
986
|
+
async def _vectorize(self, vectorizePayload: IVectorize) -> Union[List[IImage], IAsyncTaskResponse]:
    """
    Upload the input image, submit a vectorize task and wait for the result.

    :param vectorizePayload: The vectorize request; the image is read from
        ``inputs.image``.
    :return: An empty list when the image upload yields no ``imageUUID``; the
        result of ``_handleWebhookAcknowledgment`` when a ``webhookURL`` is set;
        otherwise the received images converted with ``instantiateDataclassList``.
    :raises ValueError: If ``inputs.image`` is empty.
    :raises RunwareAPIError: If the result contains ``"code"`` or ``"errors"``.
    """
    # Process the image from inputs
    input_image = vectorizePayload.inputs.image

    if not input_image:
        raise ValueError("Image is required in inputs for vectorize task")

    # uploadImage resolves the input to an object carrying an imageUUID
    image_uploaded = await self.uploadImage(input_image)

    if not image_uploaded or not image_uploaded.imageUUID:
        return []

    taskUUID = getUUID()

    # Create a dictionary with mandatory parameters
    task_params = {
        "taskType": ETaskType.IMAGE_VECTORIZE.value,
        "taskUUID": taskUUID,
        "inputs": {
            "image": image_uploaded.imageUUID
        }
    }

    # Add optional parameters if they are provided
    if vectorizePayload.model is not None:
        task_params["model"] = vectorizePayload.model
    if vectorizePayload.outputType is not None:
        task_params["outputType"] = vectorizePayload.outputType
    if vectorizePayload.outputFormat is not None:
        task_params["outputFormat"] = vectorizePayload.outputFormat
    if vectorizePayload.includeCost:
        task_params["includeCost"] = vectorizePayload.includeCost
    if vectorizePayload.webhookURL:
        task_params["webhookURL"] = vectorizePayload.webhookURL

    # Send the task with all applicable parameters
    await self.send([task_params])

    if vectorizePayload.webhookURL:
        return await self._handleWebhookAcknowledgment(
            task_uuid=taskUUID,
            task_type="vectorize",
            debug_key="image-vectorize-webhook"
        )

    let_lis = await self.listenToImages(
        onPartialImages=None,
        taskUUID=taskUUID,
        groupKey=LISTEN_TO_IMAGES_KEY.REQUEST_IMAGES,
    )

    try:
        images = await self.getSimililarImage(
            taskUUID=taskUUID,
            numberOfImages=1,
            shouldThrowError=True,
            lis=let_lis,
        )
    finally:
        # Unregister the listener even when waiting for the image fails;
        # otherwise every failed attempt leaves a listener behind.
        let_lis["destroy"]()

    if "code" in images or "errors" in images:
        # This indicates an error response
        raise RunwareAPIError(images)

    return instantiateDataclassList(IImage, images)
|
|
1052
|
+
|
|
1053
|
+
async def promptEnhance(
    self, promptEnhancer: IPromptEnhance
) -> Union[List[IEnhancedPrompt], IAsyncTaskResponse]:
    """
    Public entry point for prompt enhancement.

    Makes sure the connection is up, then runs ``_enhancePrompt`` through
    ``asyncRetry``.

    :param promptEnhancer: An IPromptEnhance object containing the prompt details.
    :return: Whatever ``_enhancePrompt`` returns.
    :raises: Any exception raised by ``ensureConnection`` or by the retried
        call propagates to the caller unchanged.
    """
    await self.ensureConnection()

    def attempt():
        # Build a fresh coroutine for every attempt made by asyncRetry.
        return self._enhancePrompt(promptEnhancer)

    return await asyncRetry(attempt)
|
|
1068
|
+
|
|
1069
|
+
async def _enhancePrompt(
    self, promptEnhancer: IPromptEnhance
) -> Union[List[IEnhancedPrompt], IAsyncTaskResponse]:
    """
    Internal method to perform the actual prompt enhancement.

    :param promptEnhancer: An IPromptEnhance object containing the prompt details.
        ``promptMaxLength`` falls back to 380 when the attribute is missing and
        ``promptVersions`` falls back to 1 when falsy.
    :return: The result of ``_handleWebhookAcknowledgment`` when a ``webhookURL``
        is set; otherwise the de-duplicated enhanced prompts, in the order in
        which they first appear in the response.
    :raises RunwareAPIError: If the first response item contains a ``"code"`` key.
    """
    prompt = promptEnhancer.prompt
    promptMaxLength = getattr(promptEnhancer, "promptMaxLength", 380)

    promptVersions = promptEnhancer.promptVersions or 1

    taskUUID = getUUID()

    # Create a dictionary with mandatory parameters
    task_params = {
        "taskType": ETaskType.PROMPT_ENHANCE.value,
        "taskUUID": taskUUID,
        "prompt": prompt,
        "promptMaxLength": promptMaxLength,
        "promptVersions": promptVersions,
    }

    # Add optional parameters if they are provided
    if promptEnhancer.includeCost:
        task_params["includeCost"] = promptEnhancer.includeCost

    has_webhook = promptEnhancer.webhookURL
    if has_webhook:
        task_params["webhookURL"] = promptEnhancer.webhookURL

    # Send the task with all applicable parameters
    await self.send([task_params])

    if has_webhook:
        return await self._handleWebhookAcknowledgment(
            task_uuid=taskUUID,
            task_type="promptEnhance",
            debug_key="prompt-enhance-webhook"
        )

    lis = self.globalListener(
        taskUUID=taskUUID,
    )

    async def check(resolve: Any, reject: Any, *args: Any) -> bool:
        async with self._messages_lock:
            response = self._globalMessages.get(taskUUID)
            if isinstance(response, dict) and response.get("error"):
                # Drop the consumed error so it does not accumulate in the
                # shared message store.
                self._globalMessages.pop(taskUUID, None)
                reject(response)
                return True
            if response:
                del self._globalMessages[taskUUID]
                resolve(response)
                return True

        return False

    try:
        response = await getIntervalWithPromise(
            check, debugKey="enhance-prompt", timeOutDuration=PROMPT_ENHANCE_TIMEOUT
        )
    finally:
        # Always unregister the listener, including when the wait raises.
        lis["destroy"]()

    if "code" in response[0]:
        # This indicates an error response
        raise RunwareAPIError(response[0])

    # Transform the response to a list of IEnhancedPrompt objects
    enhanced_prompts = createEnhancedPromptsFromResponse(response)

    # De-duplicate while keeping first-seen order; a set would return the
    # prompts in arbitrary order.
    return list(dict.fromkeys(enhanced_prompts))
|
|
1143
|
+
|
|
1144
|
+
async def uploadImage(self, file: Union[File, str]) -> Optional[UploadImageType]:
    """
    Public entry point for image upload.

    Makes sure the connection is up, then runs ``_uploadImage`` through
    ``asyncRetry``.

    :param file: The image to upload, as accepted by ``_uploadImage``.
    :return: Whatever ``_uploadImage`` returns.
    :raises: Any exception raised by ``ensureConnection`` or by the retried
        call propagates to the caller unchanged.
    """
    await self.ensureConnection()

    def attempt():
        # Build a fresh coroutine for every attempt made by asyncRetry.
        return self._uploadImage(file)

    return await asyncRetry(attempt)
|
|
1150
|
+
|
|
1151
|
+
async def _uploadImage(self, file: Union[File, str]) -> Optional[UploadImageType]:
    """
    Upload an image, or pass it through when it does not need uploading.

    A string that is not an existing path and is either reported as non-local
    by ``isLocalFile``, starts with ``"data:"``, or consists only of base64
    characters is returned as-is (used as both ``imageUUID`` and ``imageURL``)
    without contacting the server. Anything else is converted with
    ``fileToBase64`` and sent as an image-upload task.

    :param file: A file object or a string (path, URL, data URI, base64 or UUID).
    :return: The upload result, or ``None`` when no response was obtained.
    :raises RunwareAPIError: If the received response contains a ``"code"`` key.
    """
    task_uuid = getUUID()
    local_file = True
    if isinstance(file, str):
        if not os.path.exists(file):
            local_file = isLocalFile(file)

            # Check if it's a base64 string (with or without data URI prefix)
            if file.startswith("data:") or re.match(
                r"^[A-Za-z0-9+/]+={0,2}$", file
            ):
                # Assume it's a base64 string (with or without data URI prefix)
                local_file = False

        if not local_file:
            return UploadImageType(
                imageUUID=file,
                imageURL=file,
                taskUUID=task_uuid,
            )

    file = await fileToBase64(file)

    await self.send(
        [
            {
                "taskType": ETaskType.IMAGE_UPLOAD.value,
                "taskUUID": task_uuid,
                "image": file,
            }
        ]
    )

    lis = self.globalListener(taskUUID=task_uuid)

    async def check(resolve: callable, reject: callable, *args: Any) -> bool:
        async with self._messages_lock:
            stored = self._globalMessages.get(task_uuid)
            # globalListener stores either a list of result dicts or a single
            # dict (e.g. an error message); indexing a dict with [0] would
            # raise KeyError, so both shapes are handled explicitly.
            if isinstance(stored, dict):
                uploaded_image = stored
            else:
                uploaded_image = stored[0] if stored else None

            if not uploaded_image:
                return False

            # Consume the entry in both outcomes so it does not linger in
            # the shared message store.
            del self._globalMessages[task_uuid]
            if uploaded_image.get("error"):
                reject(uploaded_image)
            else:
                resolve(uploaded_image)
            return True

    try:
        response = await getIntervalWithPromise(
            check, debugKey="upload-image", timeOutDuration=IMAGE_UPLOAD_TIMEOUT
        )
    finally:
        # Always unregister the listener, including when the wait raises.
        lis["destroy"]()

    # Guard first: a membership test on None would raise TypeError.
    if not response:
        return None

    if "code" in response:
        # This indicates an error response
        raise RunwareAPIError(response)

    return UploadImageType(
        imageUUID=response["imageUUID"],
        imageURL=response["imageURL"],
        taskUUID=response["taskUUID"],
    )
|
|
1222
|
+
|
|
1223
|
+
async def uploadMedia(self, media_url: str) -> Optional[MediaStorageType]:
    """
    Public entry point for media upload.

    Makes sure the connection is up, then runs ``_uploadMedia`` through
    ``asyncRetry``.

    :param media_url: The media to upload, as accepted by ``_uploadMedia``.
    :return: Whatever ``_uploadMedia`` returns.
    :raises: Any exception raised by ``ensureConnection`` or by the retried
        call propagates to the caller unchanged.
    """
    await self.ensureConnection()

    def attempt():
        # Build a fresh coroutine for every attempt made by asyncRetry.
        return self._uploadMedia(media_url)

    return await asyncRetry(attempt)
|
|
1229
|
+
|
|
1230
|
+
async def _uploadMedia(self, media_url: str) -> Optional[MediaStorageType]:
    """
    Send a media-storage upload task and wait for the result.

    When ``media_url`` is an existing local path it is converted with
    ``fileToBase64`` and a leading data-URI prefix is stripped; any other
    string is sent to the API unchanged.

    :param media_url: Local path, URL or base64 string of the media.
    :return: The stored media reference, or ``None`` when no response was obtained.
    :raises RunwareAPIError: If the received response contains a ``"code"`` key.
    """
    task_uuid = getUUID()

    if isinstance(media_url, str):
        if os.path.exists(media_url):
            # Local file - convert to base64
            media_url = await fileToBase64(media_url)
            # Strip the data URI prefix for media storage API
            if media_url.startswith("data:"):
                media_url = media_url.split(",", 1)[1]
        # For URLs and base64 strings, send them directly to the API

    await self.send(
        [
            {
                "taskType": ETaskType.MEDIA_STORAGE.value,
                "taskUUID": task_uuid,
                "operation": "upload",
                "media": media_url,
            }
        ]
    )

    lis = self.globalListener(taskUUID=task_uuid)

    def check(resolve: callable, reject: callable, *args: Any) -> bool:
        stored = self._globalMessages.get(task_uuid)
        # globalListener stores either a list of result dicts or a single
        # dict (e.g. an error message); indexing a dict with [0] would raise
        # KeyError, so both shapes are handled explicitly.
        if isinstance(stored, dict):
            uploaded_media = stored
        else:
            uploaded_media = stored[0] if stored else None

        if not uploaded_media:
            return False

        # Consume the entry in both outcomes so it does not linger in the
        # shared message store.
        del self._globalMessages[task_uuid]
        if uploaded_media.get("error"):
            reject(uploaded_media)
        else:
            resolve(uploaded_media)
        return True

    try:
        response = await getIntervalWithPromise(
            check, debugKey="upload-media", timeOutDuration=self._timeout
        )
    finally:
        # Always unregister the listener, including when the wait raises.
        lis["destroy"]()

    # Guard first: a membership test on None would raise TypeError.
    if not response:
        return None

    if "code" in response:
        # This indicates an error response
        raise RunwareAPIError(response)

    return MediaStorageType(
        mediaUUID=response["mediaUUID"],
        taskUUID=response["taskUUID"],
    )
|
|
1289
|
+
|
|
1290
|
+
async def uploadUnprocessedImage(
    self,
    file: Union[File, str],
    preProcessorType: EPreProcessorGroup,
    width: int = None,
    height: int = None,
    lowThresholdCanny: int = None,
    highThresholdCanny: int = None,
    includeHandsAndFaceOpenPose: bool = True,
) -> Optional[UploadImageType]:
    """
    Placeholder implementation: returns a dummy upload result.

    None of the arguments are read and nothing is sent to the server. The
    returned object carries freshly generated random UUIDs and a fixed
    example URL.

    :return: A dummy UploadImageType object.
    """
    # NOTE(review): stub - no upload or pre-processing happens here.
    placeholder_fields = {
        "imageUUID": str(uuid.uuid4()),
        "imageURL": "https://example.com/uploaded_unprocessed_image.jpg",
        "taskUUID": str(uuid.uuid4()),
    }
    return UploadImageType(**placeholder_fields)
|
|
1308
|
+
|
|
1309
|
+
async def listenToImages(
    self,
    onPartialImages: Optional[Callable[[List[IImage], Optional[IError]], None]],
    taskUUID: str,
    groupKey: LISTEN_TO_IMAGES_KEY,
) -> Dict[str, Callable[[], None]]:
    """
    Register a listener that collects image results for one task.

    Matching items of type ``imageInference`` or ``vectorize`` are appended to
    ``self._globalImages``; task errors are recorded in ``self._globalError``.

    :param onPartialImages: Optional callback invoked with ``(images, None)``
        for each batch of matching images and with ``([], error)`` for a
        matching error.
    :param taskUUID: The task whose results should be collected.
    :param groupKey: Group key passed through to ``addListener``.
    :return: The listener handle returned by ``addListener``.
    """
    logger.debug("Setting up images listener for taskUUID: %s", taskUUID)

    image_task_types = ["imageInference", "vectorize"]

    async def listen_to_images_lis(m: Dict[str, Any]) -> None:
        if isinstance(m.get("data"), list):
            matching = [
                item
                for item in m["data"]
                if item.get("taskType") in image_task_types
                and item.get("taskUUID") == taskUUID
            ]
            if not matching:
                return

            async with self._images_lock:
                self._globalImages.extend(matching)

            # Conversion or callback failures are logged, not raised.
            try:
                partial_images = instantiateDataclassList(IImage, matching)
                if onPartialImages:
                    onPartialImages(partial_images, None)
            except Exception as e:
                logger.error(
                    f"Error occurred in user on_partial_images callback function: {e}"
                )
            return

        if not isinstance(m.get("errors"), list):
            return

        task_errors = [
            entry for entry in m["errors"] if entry.get("taskUUID") == taskUUID
        ]
        if not task_errors:
            return

        # Only the first error for this task is reported.
        first = task_errors[0]
        self._globalError = IError(
            error=True,
            error_message=first.get("message", "Unknown error"),
            task_uuid=first.get("taskUUID", ""),
            error_code=first.get("code"),
            error_type=first.get("type"),
            parameter=first.get("parameter"),
            documentation=first.get("documentation"),
        )
        if onPartialImages:
            onPartialImages([], self._globalError)

    def listen_to_images_check(m):
        logger.debug("Images check message: %s", m)

        has_image_items = isinstance(m.get("data"), list) and any(
            item.get("taskType") in image_task_types for item in m["data"]
        )
        has_task_error = isinstance(m.get("errors"), list) and any(
            entry.get("taskUUID") == taskUUID for entry in m["errors"]
        )

        # Any error carrying a code is recorded, whichever task it names.
        if any([entry.get("code") for entry in m.get("errors", [])]):
            self._globalError = IError(
                error=True,
                error_message=f"Error in image inference: {m.get('errors')}",
                task_uuid=taskUUID,
            )

        return has_image_items or has_task_error

    temp_listener = self.addListener(
        check=listen_to_images_check,
        lis=self._create_safe_async_listener(listen_to_images_lis),
        groupKey=groupKey
    )

    logger.debug("listenToImages :: Temp listener: %s", temp_listener)

    return temp_listener
|
|
1392
|
+
|
|
1393
|
+
def globalListener(self, taskUUID: str) -> Dict[str, Callable[[], None]]:
    """
    Register a listener that stores messages for ``taskUUID`` in ``self._globalMessages``.

    A message with a truthy ``"error"`` entry is stored whole under
    ``taskUUID``. Otherwise the value extracted by ``accessDeepObject`` is
    stored under its own ``"taskUUID"``: list items are appended to a list per
    task, a single value is stored directly.

    :param taskUUID: The unique identifier of the task associated with the listener.
    :return: The listener handle returned by ``addListener``.
    """
    logger.debug("Setting up global listener for taskUUID: %s", taskUUID)

    async def global_lis(m: Dict[str, Any]) -> None:
        logger.debug("Global listener message: %s", m)
        logger.debug("Global listener taskUUID: %s", taskUUID)

        async with self._messages_lock:
            if m.get("error"):
                self._globalMessages[taskUUID] = m
                return

            value = accessDeepObject(taskUUID, m)

            if not isinstance(value, list):
                self._globalMessages[value["taskUUID"]] = value
                return

            for v in value:
                # A new list is built each time instead of mutating the
                # stored one.
                collected = self._globalMessages.get(v["taskUUID"], [])
                self._globalMessages[v["taskUUID"]] = collected + [v]
                logger.debug("Global messages v: %s", v)
                logger.debug(
                    "self._globalMessages[v[taskUUID]]: %s",
                    self._globalMessages[v["taskUUID"]],
                )

    def global_check(m):
        logger.debug("Global check message: %s", m)
        return accessDeepObject(taskUUID, m)

    logger.debug("Global Listener taskUUID: %s", taskUUID)

    safe_listener = self._create_safe_async_listener(global_lis)
    temp_listener = self.addListener(check=global_check, lis=safe_listener)
    logger.debug("globalListener :: Temp listener: %s", temp_listener)

    return temp_listener
|
|
1438
|
+
|
|
1439
|
+
async def handleIncompleteImages(
    self, taskUUIDs: List[str], error: Any
) -> Optional[List[IImage]]:
    """
    Return the images collected so far for the given tasks, or raise ``error``.

    :param taskUUIDs: A list of task UUIDs to filter the images.
    :param error: The exception to raise when fewer than two matching images
        are available.
    :return: The matching images when there is more than one; they are removed
        from ``self._globalImages``.
    :raises: The provided error if there are no or only one matching image.
    """
    async with self._images_lock:
        matched = []
        remaining = []
        for img in self._globalImages:
            bucket = matched if img["taskUUID"] in taskUUIDs else remaining
            bucket.append(img)

        if len(matched) <= 1:
            # Nothing is removed from the store on this path.
            raise error

        self._globalImages = remaining
        return matched
|
|
1461
|
+
|
|
1462
|
+
async def ensureConnection(self) -> None:
    """
    Ensure that a connection is established with the server.

    Fails fast when an invalid-API-key message is recorded and either no
    authentication has ever succeeded or the reconnection manager reports
    ``CIRCUIT_OPEN``. Otherwise calls ``connect()`` when the client is not
    connected or the websocket is not in the ``OPEN`` state.

    :raises ConnectionError: For any failure. The message is the recorded
        invalid-API-key message when there is one, otherwise a generic
        connection message; the underlying exception is attached as the cause.
    """
    isConnected = self.connected() and self._ws.state is State.OPEN

    try:
        if self._invalidAPIkey:
            if not self._reconnection_manager._had_successful_auth:
                raise ConnectionError(self._invalidAPIkey)

            circuit_state = self._reconnection_manager.get_state()
            if circuit_state == ConnectionState.CIRCUIT_OPEN:
                raise ConnectionError(self._invalidAPIkey)

        if not isConnected:
            await self.connect()

            # connect() may have recorded an invalid API key.
            if self._invalidAPIkey and not self._reconnection_manager._had_successful_auth:
                raise ConnectionError(self._invalidAPIkey)

    except Exception as e:
        # Normalise every failure to ConnectionError, but chain the original
        # exception explicitly so the real cause stays visible to callers.
        raise ConnectionError(
            self._invalidAPIkey
            or "Could not connect to server. Ensure your API key is correct"
        ) from e
|
|
1493
|
+
|
|
1494
|
+
async def getSimililarImage(
    self,
    taskUUID: Union[str, List[str]],
    numberOfImages: int = 1,
    shouldThrowError: bool = True,
    lis: Optional[ListenerType] = None,
) -> List[IImage]:
    """
    Wait until enough images for the given task UUID(s) have been collected.

    Polls ``self._globalImages`` through ``getIntervalWithPromise`` using
    ``IMAGE_INFERENCE_TIMEOUT``. A pending ``self._globalError`` is cleared
    and rejected as a ``RunwareError``.

    :param taskUUID: A single task UUID or a list of task UUIDs.
    :param numberOfImages: The number of images to retrieve. Defaults to 1.
    :param shouldThrowError: Forwarded to ``getIntervalWithPromise``. Defaults to True.
    :param lis: Accepted for compatibility; not used by this method.
    :return: The first ``numberOfImages`` matching image entries.
    :raises Exception: Wraps any failure of the wait in a message listing the
        task UUIDs, expected and received counts, and the original error.
    """
    taskUUIDs = taskUUID if isinstance(taskUUID, list) else [taskUUID]

    def is_image_task(img) -> bool:
        return img.get("taskType") in ["imageInference", "vectorize"]

    async def check(
        resolve: Callable[[List[IImage]], None],
        reject: Callable[[IError], None],
        intervalId: Any,
    ) -> Optional[bool]:
        async with self._images_lock:
            logger.debug(f"Check # Global images: {self._globalImages}")
            imagesWithSimilarTask = [
                img
                for img in self._globalImages
                if is_image_task(img) and img.get("taskUUID") in taskUUIDs
            ]

            if self._globalError:
                logger.debug(f"Check # _globalError: {self._globalError}")
                error = self._globalError
                self._globalError = None
                logger.debug(f"Rejecting with error: {error}")
                reject(RunwareError(error))
                return True

            if len(imagesWithSimilarTask) >= numberOfImages:
                # Keep only image entries that belong to other tasks.
                self._globalImages = [
                    img
                    for img in self._globalImages
                    if is_image_task(img) and img.get("taskUUID") not in taskUUIDs
                ]
                resolve(imagesWithSimilarTask[:numberOfImages])
                return True

        return False

    try:
        return await getIntervalWithPromise(
            check,
            debugKey="getting images",
            shouldThrowError=shouldThrowError,
            timeOutDuration=IMAGE_INFERENCE_TIMEOUT,
        )
    except Exception as e:
        async with self._images_lock:
            current_images = sum(
                1
                for img in self._globalImages
                if is_image_task(img) and img.get("taskUUID") in taskUUIDs
            )
        error_msg = " | ".join(
            [
                "Timeout waiting for images",
                f"TaskUUIDs: {taskUUIDs}",
                f"Expected: {numberOfImages} images",
                f"Received: {current_images} images",
                f"Timeout: {IMAGE_INFERENCE_TIMEOUT}ms",
                f"Original error: {str(e)}",
            ]
        )
        raise Exception(error_msg) from e
|
|
1569
|
+
|
|
1570
|
+
async def _modelUpload(
    self, requestModel: IUploadModelBaseType
) -> Optional[IUploadModelResponse]:
    """
    Send a model upload task and wait until a "ready" status is received.

    :param requestModel: The model description; its mandatory attributes are
        always sent, the optional ones only when they are not None.
    :return: A list of dicts (taskType, taskUUID, status, message, air), one per
        distinct status seen up to and including "ready", or None when the wait
        produced no result.
    :raises RunwareAPIError: If a received message carries a "code" or a status
        containing "error".
    """
    task_uuid = getUUID()
    base_fields = {
        "taskType": ETaskType.MODEL_UPLOAD.value,
        "taskUUID": task_uuid,
        "air": requestModel.air,
        "name": requestModel.name,
        "downloadURL": requestModel.downloadURL,
        "uniqueIdentifier": requestModel.uniqueIdentifier,
        "version": requestModel.version,
        "format": requestModel.format,
        "private": requestModel.private,
        "category": requestModel.category,
        "architecture": requestModel.architecture,
    }

    optional_fields = [
        "retry",
        "heroImageURL",
        "tags",
        "shortDescription",
        "comment",
        "positiveTriggerWords",
        "type",
        "negativeTriggerWords",
        "defaultWeight",
        "defaultStrength",
        "defaultGuidanceScale",
        "defaultSteps",
        "defaultScheduler",
        "conditioning",
    ]

    request_object = {
        **base_fields,
        **{
            field: getattr(requestModel, field)
            for field in optional_fields
            if getattr(requestModel, field, None) is not None
        },
    }

    await self.send([request_object])

    lis = self.globalListener(
        taskUUID=task_uuid,
    )

    async def check(resolve: Callable, reject: Callable, *args: Any) -> bool:
        async with self._messages_lock:
            uploaded_model_list = self._globalMessages.get(task_uuid, [])
            unique_statuses = set()
            all_models = []

            for uploaded_model in uploaded_model_list:
                if uploaded_model.get("code"):
                    raise RunwareAPIError(uploaded_model)

                status = uploaded_model.get("status")

                # Keep only the first message seen for each distinct status.
                if status not in unique_statuses:
                    all_models.append(uploaded_model)
                    unique_statuses.add(status)

                if status is not None and "error" in status:
                    raise RunwareAPIError(uploaded_model)

                if status == "ready":
                    # Mutating the list mid-iteration is safe here only
                    # because we return immediately afterwards.
                    uploaded_model_list.remove(uploaded_model)
                    if not uploaded_model_list:
                        del self._globalMessages[task_uuid]
                    else:
                        self._globalMessages[task_uuid] = uploaded_model_list
                    resolve(all_models)
                    return True

        return False

    try:
        response = await getIntervalWithPromise(
            check, debugKey="upload-model", timeOutDuration=self._timeout
        )
    finally:
        # Always unregister the listener, including on timeout or API error.
        lis["destroy"]()

    # Only a dict can be an error payload; the success path resolves a list.
    if isinstance(response, dict) and "code" in response:
        raise RunwareAPIError(response)

    if not response:
        return None

    items = response if isinstance(response, list) else [response]
    return [
        {
            "taskType": item.get("taskType"),
            "taskUUID": item.get("taskUUID"),
            "status": item.get("status"),
            "message": item.get("message"),
            "air": item.get("air"),
        }
        for item in items
    ]
|
|
1678
|
+
|
|
1679
|
+
async def modelUpload(
    self, requestModel: IUploadModelBaseType
) -> Optional[IUploadModelResponse]:
    """
    Upload a model, making sure the connection is up first.

    The upload itself is delegated to ``_modelUpload`` and run through
    ``asyncRetry``. Exceptions propagate unchanged to the caller.

    :param requestModel: The model description to upload.
    :return: Whatever ``_modelUpload`` returns.
    """
    await self.ensureConnection()
    return await asyncRetry(lambda: self._modelUpload(requestModel))
|
|
1687
|
+
|
|
1688
|
+
async def modelSearch(self, payload: IModelSearch) -> IModelSearchResponse:
    """
    Send a model search task and return the parsed response.

    Every non-None attribute of ``payload`` except ``additional_params`` is
    forwarded in the request.

    :param payload: The search criteria.
    :return: The first message received for the task, instantiated as
        IModelSearchResponse.
    :raises RunwareAPIError: If the response contains a "code", or wrapping any
        other exception raised while searching (chained as its cause).
    """
    try:
        await self.ensureConnection()
        task_uuid = getUUID()

        request_object = {
            "taskUUID": task_uuid,
            "taskType": ETaskType.MODEL_SEARCH.value,
            **({"tags": payload.tags} if payload.tags else {}),
        }

        request_object.update(
            {
                key: value
                for key, value in vars(payload).items()
                if value is not None and key != "additional_params"
            }
        )

        await self.send([request_object])

        listener = self.globalListener(taskUUID=task_uuid)

        async def check(resolve: Callable, reject: Callable, *args: Any) -> bool:
            async with self._messages_lock:
                response = self._globalMessages.get(task_uuid)
                if response:
                    if response[0].get("error"):
                        reject(response[0])
                        return True
                    del self._globalMessages[task_uuid]
                    resolve(response[0])
                    return True
            return False

        try:
            response = await getIntervalWithPromise(
                check, debugKey="model-search", timeOutDuration=self._timeout
            )
        finally:
            # Unregister the listener even when the wait raises or times out.
            listener["destroy"]()

        if "code" in response:
            # This indicates an error response
            raise RunwareAPIError(response)

        return instantiateDataclass(IModelSearchResponse, response)

    except RunwareAPIError:
        raise
    except Exception as e:
        # Keep the original exception attached as the explicit cause.
        raise RunwareAPIError({"message": str(e)}) from e
|
|
1740
|
+
|
|
1741
|
+
async def videoInference(self, requestVideo: IVideoInference) -> Union[List[IVideo], IAsyncTaskResponse]:
    """
    Run a video inference task after ensuring the connection is available.

    :param requestVideo: The video inference request.
    :return: The result of ``_requestVideo``, executed through ``asyncRetry``.
    """
    await self.ensureConnection()

    def attempt():
        # Build a fresh coroutine for each attempt made by asyncRetry.
        return self._requestVideo(requestVideo)

    return await asyncRetry(attempt)
|
|
1744
|
+
|
|
1745
|
+
async def getResponse(self, taskUUID: str, numberResults: int = 1) -> List[IVideo]:
    """
    Poll for the results of a previously submitted task.

    :param taskUUID: The UUID of the task to poll.
    :param numberResults: How many completed results to wait for. Defaults to 1.
    :return: The polled results instantiated as IVideo objects.
    """
    await self.ensureConnection()
    videos = await self._pollVideoResults(taskUUID, numberResults, IVideo)
    return videos
|
|
1752
|
+
|
|
1753
|
+
async def _requestVideo(self, requestVideo: IVideoInference) -> Union[List[IVideo], IAsyncTaskResponse]:
    """
    Prepare, send and follow up a single video inference request.

    :param requestVideo: The video request; a task UUID is generated and stored
        on it when it has none.
    :return: An IAsyncTaskResponse straight away when ``skipResponse`` is set,
        otherwise the outcome of ``_handleInitialVideoResponse``.
    """
    await self._processVideoImages(requestVideo)

    if not requestVideo.taskUUID:
        requestVideo.taskUUID = getUUID()

    payload = self._buildVideoRequest(requestVideo)

    webhook = requestVideo.webhookURL
    if webhook:
        payload["webhookURL"] = webhook

    await self.send([payload])

    if requestVideo.skipResponse:
        # Caller does not want to wait: hand back just the task identity.
        return IAsyncTaskResponse(
            taskType=ETaskType.VIDEO_INFERENCE.value,
            taskUUID=requestVideo.taskUUID,
        )

    return await self._handleInitialVideoResponse(
        requestVideo.taskUUID,
        requestVideo.numberResults,
        payload.get("webhookURL"),
    )
|
|
1774
|
+
|
|
1775
|
+
async def _processVideoImages(self, requestVideo: IVideoInference) -> None:
    """
    Run ``process_image`` over the frame and reference images of a request.

    Frame entries that are IFrameImage instances have their ``inputImage``
    replaced by the processed value; other frame entries are kept untouched.
    Reference images are replaced wholesale by their processed values.

    :param requestVideo: The video request, modified in place.
    """
    frame_tasks = []
    reference_tasks = []

    if requestVideo.frameImages:
        frame_tasks = [
            process_image(frame_item.inputImage)
            for frame_item in requestVideo.frameImages
            if isinstance(frame_item, IFrameImage)
        ]

    if requestVideo.referenceImages:
        reference_tasks = [
            process_image(reference_item)
            for reference_item in requestVideo.referenceImages
        ]

    frame_results = await gather(*frame_tasks) if frame_tasks else []
    reference_results = await gather(*reference_tasks) if reference_tasks else []

    if requestVideo.frameImages and frame_results:
        processed_frame_images = []
        result_index = 0
        for frame_item in requestVideo.frameImages:
            if isinstance(frame_item, IFrameImage):
                # Write back to the attribute that was read above. The
                # previous code assigned to ``inputImages``, which is not the
                # field that was processed.
                # NOTE(review): presumably IFrameImage only declares
                # ``inputImage`` - confirm against types.py.
                frame_item.inputImage = frame_results[result_index]
                result_index += 1
            processed_frame_images.append(frame_item)
        requestVideo.frameImages = processed_frame_images

    if requestVideo.referenceImages and reference_results:
        requestVideo.referenceImages = reference_results
|
|
1807
|
+
|
|
1808
|
+
def _buildVideoRequest(self, requestVideo: IVideoInference) -> Dict[str, Any]:
    """
    Assemble the request payload for a video inference task.

    :param requestVideo: The video request to serialise.
    :return: The payload dict with the mandatory fields plus whatever the
        individual ``_add*`` helpers contribute.
    """
    payload = dict(
        deliveryMethod=requestVideo.deliveryMethod,
        taskType=ETaskType.VIDEO_INFERENCE.value,
        taskUUID=requestVideo.taskUUID,
        model=requestVideo.model,
        numberResults=requestVideo.numberResults,
    )

    # The prompt is optional; strip surrounding whitespace when it is given.
    prompt = requestVideo.positivePrompt
    if prompt is not None:
        payload["positivePrompt"] = prompt.strip()

    self._addOptionalField(payload, requestVideo.speech)
    self._addOptionalVideoFields(payload, requestVideo)
    self._addVideoImages(payload, requestVideo)
    self._addVideoInputs(payload, requestVideo)
    self._addProviderSettings(payload, requestVideo)

    # Helpers merge into the payload in this order; later ones may overwrite.
    for attribute in ("safety", "advancedFeatures", "acceleratorOptions"):
        self._addOptionalField(payload, getattr(requestVideo, attribute))

    return payload
|
|
1832
|
+
|
|
1833
|
+
def _addOptionalVideoFields(self, request_object: Dict[str, Any], requestVideo: IVideoInference) -> None:
|
|
1834
|
+
optional_fields = [
|
|
1835
|
+
"outputType", "outputFormat", "outputQuality", "uploadEndpoint",
|
|
1836
|
+
"includeCost", "negativePrompt", "inputAudios", "referenceVideos", "fps", "steps", "seed",
|
|
1837
|
+
"CFGScale", "seedImage", "duration", "width", "height", "nsfw_check",
|
|
1838
|
+
]
|
|
1839
|
+
|
|
1840
|
+
for field in optional_fields:
|
|
1841
|
+
value = getattr(requestVideo, field, None)
|
|
1842
|
+
if value is not None:
|
|
1843
|
+
request_object[field] = value
|
|
1844
|
+
|
|
1845
|
+
def _addVideoImages(self, request_object: Dict[str, Any], requestVideo: IVideoInference) -> None:
|
|
1846
|
+
if requestVideo.frameImages:
|
|
1847
|
+
frame_images_data = []
|
|
1848
|
+
for frame_item in requestVideo.frameImages:
|
|
1849
|
+
frame_images_data.append({k: v for k, v in asdict(frame_item).items() if v is not None})
|
|
1850
|
+
request_object["frameImages"] = frame_images_data
|
|
1851
|
+
|
|
1852
|
+
if requestVideo.referenceImages:
|
|
1853
|
+
request_object["referenceImages"] = requestVideo.referenceImages
|
|
1854
|
+
|
|
1855
|
+
# Add lora if present
|
|
1856
|
+
if requestVideo.lora:
|
|
1857
|
+
request_object["lora"] = [
|
|
1858
|
+
{"model": lora.model, "weight": lora.weight}
|
|
1859
|
+
for lora in requestVideo.lora
|
|
1860
|
+
]
|
|
1861
|
+
|
|
1862
|
+
def _buildImageRequest(self, requestImage: IImageInference, prompt: str, control_net_data_dicts: List[Dict], instant_id_data: Optional[Dict], ip_adapters_data: Optional[List[Dict]], ace_plus_plus_data: Optional[Dict], pulid_data: Optional[Dict]) -> Dict[str, Any]:
    """
    Assemble the request payload for an image inference task.

    :param requestImage: The image request to serialise.
    :param prompt: The positive prompt to send.
    :param control_net_data_dicts: Pre-built controlNet entries.
    :param instant_id_data: Pre-built instantID data, if any.
    :param ip_adapters_data: Pre-built ipAdapters entries, if any.
    :param ace_plus_plus_data: Pre-built acePlusPlus data, if any.
    :param pulid_data: Pre-built puLID data, if any.
    :return: The payload dict.
    """
    payload = dict(
        taskType=ETaskType.IMAGE_INFERENCE.value,
        model=requestImage.model,
        positivePrompt=prompt,
    )

    self._addOptionalImageFields(payload, requestImage)
    self._addImageSpecialFields(
        payload,
        requestImage,
        control_net_data_dicts,
        instant_id_data,
        ip_adapters_data,
        ace_plus_plus_data,
        pulid_data,
    )
    self._addImageInputs(payload, requestImage)
    self._addImageProviderSettings(payload, requestImage)

    return payload
|
|
1875
|
+
|
|
1876
|
+
def _addOptionalImageFields(self, request_object: Dict[str, Any], requestImage: IImageInference) -> None:
|
|
1877
|
+
optional_fields = [
|
|
1878
|
+
"outputType", "outputFormat", "outputQuality", "uploadEndpoint",
|
|
1879
|
+
"includeCost", "checkNsfw", "negativePrompt", "seedImage", "maskImage",
|
|
1880
|
+
"strength", "height", "width", "steps", "scheduler", "seed", "CFGScale",
|
|
1881
|
+
"clipSkip", "promptWeighting", "maskMargin", "vae", "webhookURL", "acceleration"
|
|
1882
|
+
]
|
|
1883
|
+
|
|
1884
|
+
for field in optional_fields:
|
|
1885
|
+
value = getattr(requestImage, field, None)
|
|
1886
|
+
if value is not None:
|
|
1887
|
+
# Special handling for checkNsfw -> checkNSFW
|
|
1888
|
+
if field == "checkNsfw":
|
|
1889
|
+
request_object["checkNSFW"] = value
|
|
1890
|
+
else:
|
|
1891
|
+
request_object[field] = value
|
|
1892
|
+
|
|
1893
|
+
def _addImageSpecialFields(self, request_object: Dict[str, Any], requestImage: IImageInference, control_net_data_dicts: List[Dict], instant_id_data: Optional[Dict], ip_adapters_data: Optional[List[Dict]], ace_plus_plus_data: Optional[Dict], pulid_data: Optional[Dict]) -> None:
|
|
1894
|
+
# Add controlNet if present
|
|
1895
|
+
if control_net_data_dicts:
|
|
1896
|
+
request_object["controlNet"] = control_net_data_dicts
|
|
1897
|
+
|
|
1898
|
+
# Add lora if present
|
|
1899
|
+
if requestImage.lora:
|
|
1900
|
+
request_object["lora"] = [
|
|
1901
|
+
{"model": lora.model, "weight": lora.weight}
|
|
1902
|
+
for lora in requestImage.lora
|
|
1903
|
+
]
|
|
1904
|
+
|
|
1905
|
+
# Add lycoris if present
|
|
1906
|
+
if requestImage.lycoris:
|
|
1907
|
+
request_object["lycoris"] = [
|
|
1908
|
+
{"model": lycoris.model, "weight": lycoris.weight}
|
|
1909
|
+
for lycoris in requestImage.lycoris
|
|
1910
|
+
]
|
|
1911
|
+
|
|
1912
|
+
# Add embeddings if present
|
|
1913
|
+
if requestImage.embeddings:
|
|
1914
|
+
request_object["embeddings"] = [
|
|
1915
|
+
{"model": embedding.model}
|
|
1916
|
+
for embedding in requestImage.embeddings
|
|
1917
|
+
]
|
|
1918
|
+
|
|
1919
|
+
# Add refiner if present
|
|
1920
|
+
if requestImage.refiner:
|
|
1921
|
+
refiner_dict = {"model": requestImage.refiner.model}
|
|
1922
|
+
if requestImage.refiner.startStep is not None:
|
|
1923
|
+
refiner_dict["startStep"] = requestImage.refiner.startStep
|
|
1924
|
+
if requestImage.refiner.startStepPercentage is not None:
|
|
1925
|
+
refiner_dict["startStepPercentage"] = requestImage.refiner.startStepPercentage
|
|
1926
|
+
request_object["refiner"] = refiner_dict
|
|
1927
|
+
|
|
1928
|
+
# Add instantID if present
|
|
1929
|
+
if instant_id_data:
|
|
1930
|
+
request_object["instantID"] = instant_id_data
|
|
1931
|
+
|
|
1932
|
+
# Add outpaint if present
|
|
1933
|
+
if requestImage.outpaint:
|
|
1934
|
+
outpaint_dict = {
|
|
1935
|
+
k: v
|
|
1936
|
+
for k, v in vars(requestImage.outpaint).items()
|
|
1937
|
+
if v is not None
|
|
1938
|
+
}
|
|
1939
|
+
request_object["outpaint"] = outpaint_dict
|
|
1940
|
+
|
|
1941
|
+
# Add ipAdapters if present
|
|
1942
|
+
if ip_adapters_data:
|
|
1943
|
+
request_object["ipAdapters"] = ip_adapters_data
|
|
1944
|
+
|
|
1945
|
+
# Add acePlusPlus if present
|
|
1946
|
+
if ace_plus_plus_data:
|
|
1947
|
+
request_object["acePlusPlus"] = ace_plus_plus_data
|
|
1948
|
+
|
|
1949
|
+
# Add puLID if present
|
|
1950
|
+
if pulid_data:
|
|
1951
|
+
request_object["puLID"] = pulid_data
|
|
1952
|
+
|
|
1953
|
+
# Add referenceImages if present
|
|
1954
|
+
if requestImage.referenceImages:
|
|
1955
|
+
request_object["referenceImages"] = requestImage.referenceImages
|
|
1956
|
+
|
|
1957
|
+
# Add acceleratorOptions if present
|
|
1958
|
+
self._addOptionalField(request_object, requestImage.acceleratorOptions)
|
|
1959
|
+
|
|
1960
|
+
# Add advancedFeatures if present
|
|
1961
|
+
if requestImage.advancedFeatures:
|
|
1962
|
+
pipeline_options_dict = {
|
|
1963
|
+
k: v.__dict__
|
|
1964
|
+
for k, v in vars(requestImage.advancedFeatures).items()
|
|
1965
|
+
if v is not None
|
|
1966
|
+
}
|
|
1967
|
+
request_object["advancedFeatures"] = pipeline_options_dict
|
|
1968
|
+
|
|
1969
|
+
# Add extraArgs if present
|
|
1970
|
+
if hasattr(requestImage, "extraArgs") and isinstance(requestImage.extraArgs, dict):
|
|
1971
|
+
request_object.update(requestImage.extraArgs)
|
|
1972
|
+
|
|
1973
|
+
def _addImageInputs(self, request_object: Dict[str, Any], requestImage: IImageInference) -> None:
|
|
1974
|
+
# Add inputs if present
|
|
1975
|
+
if requestImage.inputs:
|
|
1976
|
+
inputs_dict = {
|
|
1977
|
+
k: v for k, v in asdict(requestImage.inputs).items()
|
|
1978
|
+
if v is not None
|
|
1979
|
+
}
|
|
1980
|
+
if inputs_dict:
|
|
1981
|
+
request_object["inputs"] = inputs_dict
|
|
1982
|
+
|
|
1983
|
+
def _addVideoInputs(self, request_object: Dict[str, Any], requestVideo: IVideoInference) -> None:
|
|
1984
|
+
# Add inputs if present
|
|
1985
|
+
if requestVideo.inputs:
|
|
1986
|
+
inputs_dict = {
|
|
1987
|
+
k: v for k, v in asdict(requestVideo.inputs).items()
|
|
1988
|
+
if v is not None
|
|
1989
|
+
}
|
|
1990
|
+
if inputs_dict:
|
|
1991
|
+
request_object["inputs"] = inputs_dict
|
|
1992
|
+
|
|
1993
|
+
def _addSafetySettings(self, request_object: Dict[str, Any], safety: ISafety) -> None:
|
|
1994
|
+
safety_dict = asdict(safety)
|
|
1995
|
+
safety_dict = {k: v for k, v in safety_dict.items() if v is not None}
|
|
1996
|
+
if safety_dict:
|
|
1997
|
+
request_object["safety"] = safety_dict
|
|
1998
|
+
|
|
1999
|
+
def _addImageProviderSettings(self, request_object: Dict[str, Any], requestImage: IImageInference) -> None:
|
|
2000
|
+
if not requestImage.providerSettings:
|
|
2001
|
+
return
|
|
2002
|
+
provider_dict = requestImage.providerSettings.to_request_dict()
|
|
2003
|
+
if provider_dict:
|
|
2004
|
+
request_object["providerSettings"] = provider_dict
|
|
2005
|
+
|
|
2006
|
+
def _addProviderSettings(self, request_object: Dict[str, Any], requestVideo: IVideoInference) -> None:
|
|
2007
|
+
if not requestVideo.providerSettings:
|
|
2008
|
+
return
|
|
2009
|
+
provider_dict = requestVideo.providerSettings.to_request_dict()
|
|
2010
|
+
if provider_dict:
|
|
2011
|
+
request_object["providerSettings"] = provider_dict
|
|
2012
|
+
|
|
2013
|
+
def _addOptionalField(self, request_object: Dict[str, Any], obj: Any) -> None:
|
|
2014
|
+
if not obj:
|
|
2015
|
+
return
|
|
2016
|
+
obj_dict = obj.to_request_dict()
|
|
2017
|
+
if obj_dict:
|
|
2018
|
+
request_object.update(obj_dict)
|
|
2019
|
+
|
|
2020
|
+
async def _handleWebhookAcknowledgment(
    self,
    task_uuid: str,
    task_type: str,
    debug_key: str,
) -> IAsyncTaskResponse:
    """
    Wait for the server to acknowledge a task submitted with a webhook.

    :param task_uuid: The UUID of the submitted task.
    :param task_type: The task type the acknowledgment must carry.
    :param debug_key: Debug label passed on to ``getIntervalWithPromise``.
    :return: The acknowledgment produced by ``createAsyncTaskResponse``.
    :raises RunwareAPIError: If the message (or the awaited result) has a "code".
    """
    listener = self.globalListener(taskUUID=task_uuid)

    async def check_webhook_ack(resolve: callable, reject: callable, *args: Any) -> bool:
        async with self._messages_lock:
            messages = self._globalMessages.get(task_uuid, [])
            if not messages:
                return False

            # Stored value is normally a list; tolerate a bare message too.
            message = messages[0] if isinstance(messages, list) else messages

            if message.get("code"):
                raise RunwareAPIError(message)

            if isinstance(message, dict) and message.get("error"):
                reject(message)
                return True

            if message.get("taskType") == task_type:
                del self._globalMessages[task_uuid]
                resolve(createAsyncTaskResponse(message))
                return True

        return False

    try:
        response = await getIntervalWithPromise(
            check_webhook_ack, debugKey=debug_key, timeOutDuration=WEBHOOK_TIMEOUT
        )
    finally:
        listener["destroy"]()

    if isinstance(response, dict) and "code" in response:
        raise RunwareAPIError(response)

    return response
|
|
2063
|
+
|
|
2064
|
+
async def _handleInitialVideoResponse(self, task_uuid: str, number_results: int, webhook_url: Optional[str] = None) -> Union[List[IVideo], IAsyncTaskResponse]:
    """
    Wait for the first message of a video task and decide how to continue.

    The first message leads to one of three outcomes: a "success" status is
    returned directly as IVideo objects; a message without "imageUUID" while a
    webhook URL is set is returned as an IAsyncTaskResponse; anything else
    switches to polling via ``_pollVideoResults``.

    :param task_uuid: The UUID of the submitted task.
    :param number_results: Number of results to wait for when polling.
    :param webhook_url: The webhook URL sent with the request, if any.
    :return: The video results or the asynchronous task acknowledgment.
    :raises RunwareAPIError: If the first message carries a "code".
    """
    lis = self.globalListener(taskUUID=task_uuid)

    async def check_initial_response(resolve: callable, reject: callable, *args: Any) -> bool:
        async with self._messages_lock:
            response_list = self._globalMessages.get(task_uuid, [])

            if not response_list:
                return False

            response = response_list[0]

            if response.get("code"):
                raise RunwareAPIError(response)

            if response.get("status") == "success":
                del self._globalMessages[task_uuid]
                resolve([response])
                return True

            if not response.get("imageUUID") and webhook_url:
                del self._globalMessages[task_uuid]
                async_response = createAsyncTaskResponse(response)
                resolve([async_response])
                return True

            # Any other first message means the task is still running. Every
            # path above returns, so no trailing fallback is needed.
            del self._globalMessages[task_uuid]
            resolve("POLL_NEEDED")
            return True

    try:
        initial_response = await getIntervalWithPromise(
            check_initial_response,
            debugKey="video-inference-initial",
            timeOutDuration=VIDEO_INITIAL_TIMEOUT
        )
    finally:
        lis["destroy"]()

    if initial_response == "POLL_NEEDED":
        return await self._pollVideoResults(task_uuid, number_results)

    if initial_response and isinstance(initial_response[0], IAsyncTaskResponse):
        return initial_response[0]
    return instantiateDataclassList(IVideo, initial_response)
|
|
2111
|
+
|
|
2112
|
+
async def _pollVideoResults(self, task_uuid: str, number_results: int, response_cls: type = IVideo) -> Union[List[IVideo], List[IVideoToText]]:
    """
    Poll the server until enough completed results are available.

    :param task_uuid: The UUID of the task to poll.
    :param number_results: How many "success" results are required.
    :param response_cls: The dataclass (the class itself, not an instance) used
        to instantiate the results. Defaults to IVideo.
    :return: The first ``number_results`` completed results as ``response_cls``.
    :raises RunwareAPIError: On an error code in a response, on a response that
        is neither pending nor completed, or when all polls are exhausted.
    """
    for poll_count in range(MAX_POLLS_VIDEO_GENERATION):
        try:
            responses = await self._sendPollRequest(task_uuid, poll_count)

            # _processVideoPollingResponse raises RunwareAPIError itself when a
            # response carries an error code, so no separate scan is needed.
            completed_results = self._processVideoPollingResponse(responses)

            if len(completed_results) >= number_results:
                return instantiateDataclassList(response_cls, completed_results[:number_results])

            if not self._hasPendingVideos(responses) and not completed_results:
                raise RunwareAPIError({"message": f"Unexpected polling response at poll {poll_count}"})

        except RunwareAPIError:
            raise
        except Exception:
            # Other failures are retried; only the last poll lets them escape.
            if poll_count >= MAX_POLLS_VIDEO_GENERATION - 1:
                raise

        await asyncio.sleep(VIDEO_POLLING_DELAY / 1000)

    raise RunwareAPIError({"message": "Timed out"})
|
|
2143
|
+
|
|
2144
|
+
async def _sendPollRequest(self, task_uuid: str, poll_count: int) -> List[Dict[str, Any]]:
    """
    Send one getResponse request and wait for the messages that answer it.

    :param task_uuid: The UUID of the task being polled.
    :param poll_count: The poll index, used only in the debug key.
    :return: The list of messages received for the task.
    """
    listener = self.globalListener(taskUUID=task_uuid)

    try:
        poll_task = {
            "taskType": ETaskType.GET_RESPONSE.value,
            "taskUUID": task_uuid,
        }
        await self.send([poll_task])

        async def check_poll_response(resolve: callable, reject: callable, *args: Any) -> bool:
            async with self._messages_lock:
                pending = self._globalMessages.get(task_uuid, [])
                if not pending:
                    return False
                # Consume the messages so the next poll starts clean.
                del self._globalMessages[task_uuid]
                resolve(pending)
                return True

        return await getIntervalWithPromise(
            check_poll_response,
            debugKey=f"video-poll-{poll_count}",
            timeOutDuration=VIDEO_INITIAL_TIMEOUT,
        )
    finally:
        listener["destroy"]()
|
|
2169
|
+
|
|
2170
|
+
def _processVideoPollingResponse(self, responses: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
|
2171
|
+
completed_results = []
|
|
2172
|
+
|
|
2173
|
+
for response in responses:
|
|
2174
|
+
if response.get("code"):
|
|
2175
|
+
raise RunwareAPIError(response)
|
|
2176
|
+
status = response.get("status")
|
|
2177
|
+
|
|
2178
|
+
if status == "success":
|
|
2179
|
+
completed_results.append(response)
|
|
2180
|
+
return completed_results
|
|
2181
|
+
|
|
2182
|
+
def _hasPendingVideos(self, responses: List[Dict[str, Any]]) -> bool:
|
|
2183
|
+
return any(response.get("status") == "processing" for response in responses)
|
|
2184
|
+
|
|
2185
|
+
async def audioInference(self, requestAudio: IAudioInference) -> List[IAudio]:
    """
    Run an audio inference task after ensuring the connection is available.

    :param requestAudio: The audio inference request.
    :return: The result of ``_requestAudio``, executed through ``asyncRetry``.
    """
    await self.ensureConnection()

    def attempt():
        # Build a fresh coroutine for each attempt made by asyncRetry.
        return self._requestAudio(requestAudio)

    return await asyncRetry(attempt)
|
|
2188
|
+
|
|
2189
|
+
async def _requestAudio(self, requestAudio: IAudioInference) -> List[IAudio]:
    """
    Build and send one audio inference request, then collect its results.

    :param requestAudio: The audio request; a task UUID is generated and stored
        on it when it has none.
    :return: The audio results from ``_handleInitialAudioResponse``.
    """
    if not requestAudio.taskUUID:
        requestAudio.taskUUID = getUUID()

    payload = self._buildAudioRequest(requestAudio)
    await self.send([payload])

    return await self._handleInitialAudioResponse(
        requestAudio.taskUUID, requestAudio.numberResults
    )
|
|
2194
|
+
|
|
2195
|
+
def _buildAudioRequest(self, requestAudio: IAudioInference) -> Dict[str, Any]:
    """
    Assemble the request payload for an audio inference task.

    :param requestAudio: The audio request to serialise.
    :return: The payload dict with the mandatory fields, the prompt and
        duration when they are not None, and the helper-provided extras.
    """
    payload = dict(
        deliveryMethod=requestAudio.deliveryMethod,
        taskType=ETaskType.AUDIO_INFERENCE.value,
        taskUUID=requestAudio.taskUUID,
        model=requestAudio.model,
        numberResults=requestAudio.numberResults,
    )

    # The prompt is optional; strip surrounding whitespace when it is given.
    prompt = requestAudio.positivePrompt
    if prompt is not None:
        payload["positivePrompt"] = prompt.strip()

    # Duration is forwarded whenever it is set.
    length = requestAudio.duration
    if length is not None:
        payload["duration"] = length

    self._addOptionalAudioFields(payload, requestAudio)
    self._addOptionalField(payload, requestAudio.audioSettings)
    self._addAudioProviderSettings(payload, requestAudio)

    return payload
|
|
2217
|
+
|
|
2218
|
+
def _addOptionalAudioFields(self, request_object: Dict[str, Any], requestAudio: IAudioInference) -> None:
|
|
2219
|
+
optional_fields = [
|
|
2220
|
+
"outputType", "outputFormat", "includeCost", "uploadEndpoint", "webhookURL"
|
|
2221
|
+
]
|
|
2222
|
+
|
|
2223
|
+
for field in optional_fields:
|
|
2224
|
+
value = getattr(requestAudio, field, None)
|
|
2225
|
+
if value is not None:
|
|
2226
|
+
request_object[field] = value
|
|
2227
|
+
|
|
2228
|
+
|
|
2229
|
+
def _addAudioProviderSettings(self, request_object: Dict[str, Any], requestAudio: IAudioInference) -> None:
|
|
2230
|
+
if not requestAudio.providerSettings:
|
|
2231
|
+
return
|
|
2232
|
+
provider_dict = requestAudio.providerSettings.to_request_dict()
|
|
2233
|
+
if provider_dict:
|
|
2234
|
+
request_object["providerSettings"] = provider_dict
|
|
2235
|
+
|
|
2236
|
+
async def _handleInitialAudioResponse(self, task_uuid: str, number_results: int) -> List[IAudio]:
|
|
2237
|
+
if number_results == 1:
|
|
2238
|
+
# Single result - wait for completion
|
|
2239
|
+
response = await self._waitForAudioCompletion(task_uuid)
|
|
2240
|
+
return [response] if response else []
|
|
2241
|
+
else:
|
|
2242
|
+
# Multiple results - use polling
|
|
2243
|
+
return await self._pollAudioResults(task_uuid, number_results)
|
|
2244
|
+
|
|
2245
|
+
async def _waitForAudioCompletion(self, task_uuid: str) -> Optional[IAudio]:
    """
    Wait for the single completion message of an audio task.

    :param task_uuid: The UUID of the submitted task.
    :return: The IAudio built from the message, or None when the wait produced
        no result.
    :raises RunwareAPIError: If the received message contains a "code".
    """
    lis = self.globalListener(taskUUID=task_uuid)

    async def check(resolve: Callable, reject: Callable, *args: Any) -> bool:
        async with self._messages_lock:
            response = self._globalMessages.get(task_uuid)
            # Stored value is normally a list; tolerate a bare message too.
            if response and isinstance(response, list):
                audio_response = response[0]
            else:
                audio_response = response

            if not audio_response:
                return False

            if audio_response.get("error"):
                reject(audio_response)
                return True

            del self._globalMessages[task_uuid]
            resolve(audio_response)
            return True

    try:
        response = await getIntervalWithPromise(
            check, debugKey="audio-inference", timeOutDuration=AUDIO_INFERENCE_TIMEOUT
        )
    finally:
        # ``finally`` runs on every exit path, including task cancellation
        # (CancelledError is not an Exception), and destroys exactly once.
        lis["destroy"]()

    # Test for an empty result before probing it for an error code.
    if not response:
        return None

    if "code" in response:
        raise RunwareAPIError(response)

    return self._createAudioFromResponse(response)
|
|
2280
|
+
|
|
2281
|
+
async def _pollAudioResults(self, task_uuid: str, number_results: int) -> List[IAudio]:
    """
    Watch the messages of an audio task until enough results have completed.

    Messages stay in ``_globalMessages`` for the duration of the polling and
    are removed, together with the listener, when this method exits.

    :param task_uuid: The UUID of the submitted task.
    :param number_results: How many "success" results are required.
    :return: The first ``number_results`` completed results as IAudio objects.
    :raises RunwareAPIError: If a message carries a "code", or when the maximum
        number of polls is reached without enough results.
    """
    completed_results: List[Dict[str, Any]] = []
    lis = self.globalListener(taskUUID=task_uuid)

    try:
        for poll_count in range(MAX_POLLS_AUDIO_GENERATION):
            async with self._messages_lock:
                responses = self._globalMessages.get(task_uuid, [])
                if not isinstance(responses, list):
                    responses = [responses] if responses else []

                # Each poll sees every message received so far (nothing is
                # removed between polls), so rebuild the list of completed
                # results rather than extending it - extending would count
                # the same result again on every poll.
                completed_results = self._processAudioPollingResponse(responses)

            if len(completed_results) >= number_results:
                break

            if poll_count >= MAX_POLLS_AUDIO_GENERATION - 1:
                raise RunwareAPIError(
                    {"message": f"Audio generation timeout after {MAX_POLLS_AUDIO_GENERATION} polls"})

            await asyncio.sleep(AUDIO_POLLING_DELAY / 1000)

    finally:
        lis["destroy"]()
        async with self._messages_lock:
            self._globalMessages.pop(task_uuid, None)

    return [self._createAudioFromResponse(response) for response in completed_results[:number_results]]
|
|
2311
|
+
|
|
2312
|
+
def _processAudioPollingResponse(self, responses: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
|
2313
|
+
completed_results = []
|
|
2314
|
+
|
|
2315
|
+
for response in responses:
|
|
2316
|
+
if response.get("code"):
|
|
2317
|
+
raise RunwareAPIError(response)
|
|
2318
|
+
status = response.get("status")
|
|
2319
|
+
if status == "success":
|
|
2320
|
+
completed_results.append(response)
|
|
2321
|
+
|
|
2322
|
+
return completed_results
|
|
2323
|
+
|
|
2324
|
+
def _createAudioFromResponse(self, response: Dict[str, Any]) -> IAudio:
    """
    Build an IAudio from a raw response message.

    :param response: The response dict; ``taskType`` and ``taskUUID`` default
        to an empty string, every other missing key to None.
    :return: The populated IAudio.
    """
    read = response.get
    return IAudio(
        taskType=read("taskType", ""),
        taskUUID=read("taskUUID", ""),
        status=read("status"),
        audioUUID=read("audioUUID"),
        audioURL=read("audioURL"),
        audioBase64Data=read("audioBase64Data"),
        audioDataURI=read("audioDataURI"),
        cost=read("cost"),
    )
|
|
2335
|
+
|
|
2336
|
+
def connected(self) -> bool:
    """
    Report whether the WebSocket connection is usable.

    :return: A truthy value only when ``isWebsocketReadyState()`` is truthy and
        a connection session UUID has been set.
    """
    socket_ready = self.isWebsocketReadyState()
    if not socket_ready:
        # Hand back the ready-state result itself, as the `and` chain would.
        return socket_ready
    return self._connectionSessionUUID is not None
|