runware 0.4.30__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
runware/base.py ADDED
@@ -0,0 +1,2342 @@
1
+ import asyncio
2
+ import inspect
3
+ import logging
4
+ import os
5
+ import re
6
+ import uuid
7
+ from asyncio import gather
8
+ from dataclasses import asdict
9
+ from typing import List, Optional, Union, Callable, Any, Dict
10
+
11
+ from websockets.protocol import State
12
+
13
+ from .logging_config import configure_logging
14
+
15
+ from .async_retry import asyncRetry
16
+ from .reconnection import ConnectionState, ReconnectionManager
17
+ from .types import (
18
+ Environment,
19
+ IImageInference,
20
+ IPhotoMaker,
21
+ IImageCaption,
22
+ IImageToText,
23
+ IImageBackgroundRemoval,
24
+ ISafety,
25
+ IPromptEnhance,
26
+ IEnhancedPrompt,
27
+ IImageUpscale,
28
+ IUploadModelBaseType,
29
+ IUploadModelResponse,
30
+ ReconnectingWebsocketProps,
31
+ UploadImageType,
32
+ MediaStorageType,
33
+ EPreProcessorGroup,
34
+ File,
35
+ ETaskType,
36
+ IModelSearch,
37
+ IModelSearchResponse,
38
+ IControlNet,
39
+ IVideo,
40
+ IVideoCaption,
41
+ IVideoToText,
42
+ IVideoBackgroundRemoval,
43
+ IVideoUpscale,
44
+ IVideoInference,
45
+ IVideoAdvancedFeatures,
46
+ IAcceleratorOptions,
47
+ IAudio,
48
+ IAudioInference,
49
+ IFrameImage,
50
+ IAsyncTaskResponse,
51
+ IVectorize,
52
+ )
53
+ from .types import IImage, IError, SdkType, ListenerType
54
+ from .utils import (
55
+ BASE_RUNWARE_URLS,
56
+ getUUID,
57
+ fileToBase64,
58
+ createImageFromResponse,
59
+ createImageToTextFromResponse,
60
+ createVideoToTextFromResponse,
61
+ createEnhancedPromptsFromResponse,
62
+ instantiateDataclassList,
63
+ RunwareAPIError,
64
+ RunwareError,
65
+ instantiateDataclass,
66
+ TIMEOUT_DURATION,
67
+ accessDeepObject,
68
+ getIntervalWithPromise,
69
+ removeListener,
70
+ LISTEN_TO_IMAGES_KEY,
71
+ isLocalFile,
72
+ process_image,
73
+ createAsyncTaskResponse,
74
+ VIDEO_INITIAL_TIMEOUT,
75
+ VIDEO_POLLING_DELAY,
76
+ WEBHOOK_TIMEOUT,
77
+ IMAGE_INFERENCE_TIMEOUT,
78
+ IMAGE_OPERATION_TIMEOUT,
79
+ PROMPT_ENHANCE_TIMEOUT,
80
+ IMAGE_UPLOAD_TIMEOUT,
81
+ AUDIO_INFERENCE_TIMEOUT,
82
+ AUDIO_POLLING_DELAY,
83
+ MAX_POLLS_AUDIO_GENERATION,
84
+ MAX_POLLS_VIDEO_GENERATION,
85
+ )
86
+
87
# Module-level logging defaults. CRITICAL is used here; each RunwareBase
# instance calls configure_logging again with its own log_level.
configure_logging(log_level=logging.CRITICAL)

# Module logger. This is the same named logger that RunwareBase.__init__
# stores on self.logger (both use logging.getLogger(__name__)).
logger = logging.getLogger(__name__)
91
+
92
+
93
+ class RunwareBase:
94
def __init__(
    self,
    api_key: str,
    url: str = BASE_RUNWARE_URLS[Environment.PRODUCTION],
    timeout: int = TIMEOUT_DURATION,
    log_level=logging.CRITICAL,
):
    """Initialise client state without opening any connection.

    Args:
        api_key: Runware API key used for authentication.
        url: Endpoint URL; defaults to the production entry of BASE_RUNWARE_URLS.
        timeout: Request timeout; must be strictly positive.
        log_level: Level passed to configure_logging and set on self.logger.

    Raises:
        ValueError: If ``timeout`` is not greater than zero.
    """
    if timeout <= 0:
        raise ValueError("Timeout must be greater than 0 milliseconds")

    # Logging: apply the requested level globally and to this instance's logger.
    configure_logging(log_level)
    self.logger = logging.getLogger(__name__)
    self.logger.setLevel(log_level)

    # Connection and authentication state.
    self._apiKey: str = api_key
    self._url: Optional[str] = url
    self._timeout: int = timeout
    self._ws: Optional[ReconnectingWebsocketProps] = None
    self._connectionSessionUUID: Optional[str] = None
    self._invalidAPIkey: Optional[str] = None
    self._sdkType: SdkType = SdkType.SERVER

    # Message routing: registered listeners plus buffered responses.
    self._listeners: List[ListenerType] = []
    self._globalMessages: Dict[str, Any] = {}
    self._globalImages: List[IImage] = []
    self._globalError: Optional[IError] = None

    # _messages_lock is held while polling checks read/modify _globalMessages.
    self._messages_lock = asyncio.Lock()
    self._images_lock = asyncio.Lock()
    # Keeps references to in-flight listener tasks until they finish.
    self._listener_tasks = set()

    # Must come after self.logger, which it receives.
    self._reconnection_manager = ReconnectionManager(logger=self.logger)
124
+
125
+
126
def _create_safe_async_listener(self, async_func):
    """Adapt a coroutine function into a synchronous listener callback.

    The returned callable schedules ``async_func(message)`` as a task on the
    running event loop and tracks it in ``self._listener_tasks`` until it
    finishes. Any ``Exception`` the task raises is logged rather than
    propagated. The callback itself always returns ``None``.
    """

    def _on_done(finished):
        # Stop tracking the task first, whatever its outcome.
        self._listener_tasks.discard(finished)
        if finished.cancelled():
            return
        try:
            finished.result()
        except Exception as exc:
            logger.error(f"Unhandled exception in async listener: {exc}", exc_info=True)

    def wrapper(m):
        scheduled = asyncio.create_task(async_func(m))
        self._listener_tasks.add(scheduled)
        scheduled.add_done_callback(_on_done)
        return None

    return wrapper
142
+
143
async def _cleanup_listener_tasks(self):
    """Cancel every tracked listener task and wait for all of them to settle.

    Results and exceptions (including cancellation) are collected by
    ``asyncio.gather(..., return_exceptions=True)`` instead of being raised.
    The tracking set is emptied afterwards. Does nothing when no tasks are tracked.
    """
    pending = list(self._listener_tasks)
    if not pending:
        return

    self.logger.info(f"Cleaning up {len(pending)} listener tasks")

    for task in pending:
        if not task.done():
            task.cancel()

    # No await happens between the snapshot and here, so `pending` still
    # matches the tracking set.
    await asyncio.gather(*pending, return_exceptions=True)

    self._listener_tasks.clear()
    self.logger.info("All listener tasks cleaned up")
158
+
159
def isWebsocketReadyState(self) -> bool:
    """Return True only when a websocket exists and its state is ``State.OPEN``."""
    ws = self._ws
    return ws is not None and ws.state is State.OPEN
163
+
164
def isAuthenticated(self):
    """Report whether a connection session UUID has been recorded."""
    return not (self._connectionSessionUUID is None)
166
+
167
def addListener(
    self,
    lis: Callable[[Any], Any],
    check: Callable[[Any], Any],
    groupKey: Optional[str] = None,
) -> Dict[str, Callable[[], None]]:
    """Register a message listener and return a handle that removes it.

    ``lis`` is called for every incoming message that has a truthy ``"error"``
    field, or for which ``check(message)`` is truthy.

    Args:
        lis: Callback receiving each matching message.
        check: Predicate selecting which non-error messages reach ``lis``.
        groupKey: Optional group identifier stored on the ListenerType entry.

    Returns:
        ``{"destroy": fn}``. Calling ``fn()`` removes the listener from
        ``self._listeners``.

    Raises:
        ValueError: If ``lis`` or ``check`` is missing.
    """
    if not lis or not check:
        raise ValueError("Listener and check functions are required")

    # Record who registered the listener, for debugging. inspect.currentframe()
    # may return None on interpreters without stack-frame support. Frame
    # references are deleted explicitly to avoid keeping reference cycles alive.
    caller_name, caller_line_number = "<unknown>", -1
    frame = inspect.currentframe()
    caller = frame.f_back if frame is not None else None
    try:
        if caller is not None:
            caller_name = caller.f_code.co_name
            caller_line_number = caller.f_lineno
    finally:
        del frame, caller

    debug_message = f"Listener {self.addListener.__name__} created by {caller_name} at line {caller_line_number} with listener: {lis} and check: {check}"

    def listener(msg: Any) -> None:
        # Error messages bypass `check` and are always forwarded.
        if msg.get("error") or check(msg):
            lis(msg)

    groupListener: ListenerType = ListenerType(
        key=getUUID(),
        listener=listener,
        group_key=groupKey,
        debug_message=debug_message,
    )
    self._listeners.append(groupListener)

    def destroy() -> None:
        self._listeners = removeListener(self._listeners, groupListener)

    return {"destroy": destroy}
211
+
212
def handle_connection_response(self, m):
    """Update authentication state from a connection handshake message.

    On an error message, ``self._invalidAPIkey`` is set to "Invalid API key"
    for errorId 19 and to "Error connection" otherwise. The session UUID is
    left untouched in that case. Otherwise the method stores
    ``newConnectionSessionUUID.connectionSessionUUID`` (None if absent) and
    clears ``self._invalidAPIkey``.
    """
    if m.get("error"):
        # .get() so an error payload without "errorId" is reported as a
        # generic connection error instead of raising KeyError.
        if m.get("errorId") == 19:
            self._invalidAPIkey = "Invalid API key"
        else:
            self._invalidAPIkey = "Error connection"
        return
    # "or {}" also covers an explicit null newConnectionSessionUUID.
    session_info = m.get("newConnectionSessionUUID") or {}
    self._connectionSessionUUID = session_info.get("connectionSessionUUID")
    self._invalidAPIkey = None
223
+
224
async def photoMaker(self, requestPhotoMaker: IPhotoMaker) -> Union[List[IImage], IAsyncTaskResponse]:
    """Run a PhotoMaker task and return the generated images.

    The task UUID is written back onto ``requestPhotoMaker``. Local input
    images are replaced in place with their base64 encoding.

    Returns:
        A list of IImage once ``numberResults`` distinct images have arrived.
        When ``webhookURL`` is set, the webhook acknowledgment is returned
        instead. Returns None if polling produced an empty result.

    Raises:
        RunwareAPIError: If a message for this task carries an error code.
    """
    await self.ensureConnection()

    task_uuid = requestPhotoMaker.taskUUID or getUUID()
    requestPhotoMaker.taskUUID = task_uuid

    # `or []` tolerates a missing inputImages list.
    for i, image in enumerate(requestPhotoMaker.inputImages or []):
        if isLocalFile(image) and not str(image).startswith("http"):
            requestPhotoMaker.inputImages[i] = await fileToBase64(image)

    request_object: Dict[str, Any] = {
        "taskUUID": task_uuid,
        "model": requestPhotoMaker.model,
        "positivePrompt": f"{requestPhotoMaker.positivePrompt}".strip(),
        "numberResults": requestPhotoMaker.numberResults,
        "height": requestPhotoMaker.height,
        "width": requestPhotoMaker.width,
        "taskType": ETaskType.PHOTO_MAKER.value,
        "style": requestPhotoMaker.style,
        "strength": requestPhotoMaker.strength,
    }
    if requestPhotoMaker.inputImages:
        request_object["inputImages"] = requestPhotoMaker.inputImages
    if requestPhotoMaker.steps:
        request_object["steps"] = requestPhotoMaker.steps
    if requestPhotoMaker.outputFormat is not None:
        request_object["outputFormat"] = requestPhotoMaker.outputFormat
    if requestPhotoMaker.includeCost:
        request_object["includeCost"] = requestPhotoMaker.includeCost
    if requestPhotoMaker.outputType:
        request_object["outputType"] = requestPhotoMaker.outputType
    if requestPhotoMaker.webhookURL:
        request_object["webhookURL"] = requestPhotoMaker.webhookURL

    await self.send([request_object])

    if requestPhotoMaker.webhookURL:
        return await self._handleWebhookAcknowledgment(
            task_uuid=task_uuid,
            task_type="photoMaker",
            debug_key="photo-maker-webhook",
        )

    number_of_results = requestPhotoMaker.numberResults

    async def check(resolve: callable, reject: callable, *args: Any) -> bool:
        async with self._messages_lock:
            unique_results: Dict[Any, Any] = {}
            for made_photo in self._globalMessages.get(task_uuid, []):
                if made_photo.get("code"):
                    raise RunwareAPIError(made_photo)
                if made_photo.get("taskType") != "photoMaker":
                    continue
                # Keep the first message seen for each imageUUID.
                unique_results.setdefault(made_photo.get("imageUUID"), made_photo)

            if 0 < number_of_results <= len(unique_results):
                del self._globalMessages[task_uuid]
                resolve(list(unique_results.values()))
                return True
            return False

    lis = self.globalListener(taskUUID=task_uuid)
    try:
        response = await getIntervalWithPromise(
            check, debugKey="photo-maker", timeOutDuration=IMAGE_INFERENCE_TIMEOUT
        )
    finally:
        # Always unregister, including on timeout or API error.
        lis["destroy"]()

    if isinstance(response, dict) and "code" in response:
        raise RunwareAPIError(response)

    if not response:
        return None
    if not isinstance(response, list):
        response = [response]
    return instantiateDataclassList(IImage, response)
327
+
328
async def imageInference(
    self, requestImage: IImageInference
) -> Union[List[IImage], IAsyncTaskResponse]:
    """Run an image inference task through asyncRetry.

    Before building the request:
    - mask, seed, reference, InstantID, IP-Adapter, ACE++ and PuLID images
      are passed through ``process_image``;
    - ControlNet guide images are uploaded.
    Several of these results are written back onto ``requestImage`` (and its
    ControlNet entries) in place.

    Returns:
        Whatever ``_requestImages`` returns: the images, or a webhook
        acknowledgment when the built request has a webhookURL. Returns ``[]``
        if a ControlNet guide image fails to upload.
    """
    await self.ensureConnection()

    requestImage.maskImage = await process_image(requestImage.maskImage)
    requestImage.seedImage = await process_image(requestImage.seedImage)
    if requestImage.referenceImages:
        requestImage.referenceImages = await process_image(requestImage.referenceImages)

    control_net_data: List[IControlNet] = []
    for control_data in requestImage.controlNet or []:
        image_uploaded = await self.uploadImage(control_data.guideImage)
        if not image_uploaded:
            return []
        # Only unwrap values that expose `.value` (enum members). A None
        # preprocessor, or a string left by an earlier submission of the same
        # request object, is kept unchanged.
        preprocessor = getattr(control_data, "preprocessor", None)
        if hasattr(preprocessor, "value"):
            control_data.preprocessor = preprocessor.value
        control_data.guideImage = image_uploaded.imageUUID
        control_net_data.append(control_data)

    prompt = f"{requestImage.positivePrompt}".strip()
    control_net_data_dicts = [asdict(item) for item in control_net_data]

    instant_id_data: Dict[str, Any] = {}
    if requestImage.instantID:
        instant_id_data = {
            k: v for k, v in vars(requestImage.instantID).items() if v is not None
        }
        for key in ("inputImage", "poseImage"):
            if key in instant_id_data:
                instant_id_data[key] = await process_image(instant_id_data[key])

    ip_adapters_data = []
    for ip_adapter in requestImage.ipAdapters or []:
        ip_adapter_data = {k: v for k, v in vars(ip_adapter).items() if v is not None}
        if "guideImage" in ip_adapter_data:
            ip_adapter_data["guideImage"] = await process_image(ip_adapter_data["guideImage"])
        ip_adapters_data.append(ip_adapter_data)

    ace_plus_plus_data: Dict[str, Any] = {}
    ace = requestImage.acePlusPlus
    if ace:
        ace_plus_plus_data = {
            "inputImages": [],
            "repaintingScale": ace.repaintingScale,
            "type": ace.taskType,
        }
        if ace.inputImages:
            ace_plus_plus_data["inputImages"] = await process_image(ace.inputImages)
        if ace.inputMasks:
            ace_plus_plus_data["inputMasks"] = await process_image(ace.inputMasks)

    pulid_data: Dict[str, Any] = {}
    pulid = requestImage.puLID
    if pulid:
        pulid_data = {
            "inputImages": [],
            "idWeight": pulid.idWeight,
            "trueCFGScale": pulid.trueCFGScale,
            "CFGStartStep": pulid.CFGStartStep,
            "CFGStartStepPercentage": pulid.CFGStartStepPercentage,
        }
        if pulid.inputImages:
            pulid_data["inputImages"] = await process_image(pulid.inputImages)

    request_object = self._buildImageRequest(
        requestImage,
        prompt,
        control_net_data_dicts,
        instant_id_data,
        ip_adapters_data,
        ace_plus_plus_data,
        pulid_data,
    )

    # Shared by every retry attempt: _requestImages appends each attempt's
    # task UUID and counts already-received images across all of them.
    task_uuids: List[str] = []
    return await asyncRetry(
        lambda: self._requestImages(
            request_object=request_object,
            task_uuids=task_uuids,
            let_lis=None,
            retry_count=0,
            number_of_images=requestImage.numberResults,
            on_partial_images=requestImage.onPartialImages,
        )
    )
436
+
437
async def _requestImages(
    self,
    request_object: Dict[str, Any],
    task_uuids: List[str],
    let_lis: Optional[Any],
    retry_count: int,
    number_of_images: int,
    on_partial_images: Optional[Callable[[List[IImage], Optional[IError]], None]],
) -> Union[List[IImage], IAsyncTaskResponse]:
    """Send (or re-send) an image inference request and wait for its images.

    Images already buffered in ``self._globalImages`` under any UUID in
    ``task_uuids`` are subtracted from the requested count. The current task
    UUID is then appended to ``task_uuids``.

    Args:
        request_object: Request payload built by the caller.
        task_uuids: UUIDs of this and earlier attempts; mutated in place.
        let_lis: Listener handle from an earlier attempt to destroy, or None.
        retry_count: Accepted for interface compatibility; not used.
        number_of_images: Total number of images requested.
        on_partial_images: Optional callback forwarded to listenToImages.

    Returns:
        A list of IImage. When the request has a webhookURL, the webhook
        acknowledgment is returned instead. Returns None if nothing came back.

    Raises:
        RunwareAPIError: If the collected result is an error payload.
    """
    if let_lis:
        let_lis["destroy"]()

    images_with_similar_task = [
        img for img in self._globalImages if img.get("taskUUID") in task_uuids
    ]

    task_uuid = request_object.get("taskUUID")
    if task_uuid is None:
        task_uuid = getUUID()
    task_uuids.append(task_uuid)

    new_request_object = {
        **request_object,
        "taskUUID": task_uuid,
        "numberResults": number_of_images - len(images_with_similar_task),
    }

    await self.send([new_request_object])

    if new_request_object.get("webhookURL"):
        return await self._handleWebhookAcknowledgment(
            task_uuid=task_uuid,
            task_type="imageInference",
            debug_key="image-inference-webhook",
        )

    listener = await self.listenToImages(
        onPartialImages=on_partial_images,
        taskUUID=task_uuid,
        groupKey=LISTEN_TO_IMAGES_KEY.REQUEST_IMAGES,
    )
    try:
        images = await self.getSimililarImage(
            taskUUID=task_uuids,
            numberOfImages=number_of_images,
            shouldThrowError=True,
            lis=listener,
        )
    finally:
        # Unregister on success and on failure alike, so a timeout or API
        # error does not leave a dangling listener behind.
        listener["destroy"]()

    if not images:
        return None
    if isinstance(images, dict) and "code" in images:
        raise RunwareAPIError(images)
    return instantiateDataclassList(IImage, images)
497
+
498
async def imageCaption(self, requestImageToText: IImageCaption) -> Union[IImageToText, IAsyncTaskResponse]:
    """Caption one or more images, retrying the request through asyncRetry.

    The connection is ensured first. Any exception propagates unchanged.
    """
    await self.ensureConnection()
    return await asyncRetry(lambda: self._requestImageToText(requestImageToText))
506
+
507
async def _requestImageToText(
    self, requestImageToText: IImageCaption
) -> Union[IImageToText, IAsyncTaskResponse]:
    """Upload input image(s), submit an image caption task, and await the result.

    A non-empty ``inputImages`` list takes precedence. Otherwise ``inputImage``
    is used as a single image.

    Returns:
        The parsed caption. When ``webhookURL`` is set, the webhook
        acknowledgment is returned instead. Returns None if an upload fails or
        no result is produced.

    Raises:
        ValueError: If neither a non-empty ``inputImages`` nor ``inputImage``
            is provided.
        RunwareAPIError: If the API returns an error payload.
    """
    # An empty inputImages list falls back to inputImage instead of sending [].
    if requestImageToText.inputImages:
        images_to_process = requestImageToText.inputImages
    elif requestImageToText.inputImage is not None:
        images_to_process = [requestImageToText.inputImage]
    else:
        raise ValueError("Either inputImages or inputImage must be provided")

    uploaded_images = []
    for image in images_to_process:
        image_uploaded = await self.uploadImage(image)
        if not image_uploaded or not image_uploaded.imageUUID:
            return None
        uploaded_images.append(image_uploaded.imageUUID)

    taskUUID = getUUID()

    task_params: Dict[str, Any] = {
        "taskType": ETaskType.IMAGE_CAPTION.value,
        "taskUUID": taskUUID,
    }
    # Send either inputImage or inputImages, never both (API requirement).
    if len(uploaded_images) == 1:
        task_params["inputImage"] = uploaded_images[0]
    else:
        task_params["inputImages"] = uploaded_images

    # Only send model when given; the backend applies its own default.
    if requestImageToText.model is not None:
        task_params["model"] = requestImageToText.model

    # template and prompt are mutually exclusive: prompt is sent only without a template.
    if requestImageToText.template is not None:
        task_params["template"] = requestImageToText.template
    else:
        task_params["prompt"] = requestImageToText.prompt

    if requestImageToText.includeCost:
        task_params["includeCost"] = requestImageToText.includeCost
    if requestImageToText.webhookURL:
        task_params["webhookURL"] = requestImageToText.webhookURL

    await self.send([task_params])

    if requestImageToText.webhookURL:
        return await self._handleWebhookAcknowledgment(
            task_uuid=taskUUID,
            task_type="imageCaption",
            debug_key="image-caption-webhook",
        )

    async def check(resolve: callable, reject: callable, *args: Any) -> bool:
        async with self._messages_lock:
            messages = self._globalMessages.get(taskUUID)
            image_to_text = messages[0] if messages else None
            if image_to_text and image_to_text.get("error"):
                reject(image_to_text)
                return True
            if image_to_text:
                del self._globalMessages[taskUUID]
                resolve(image_to_text)
                return True
            return False

    lis = self.globalListener(taskUUID=taskUUID)
    try:
        response = await getIntervalWithPromise(
            check, debugKey="image-to-text", timeOutDuration=IMAGE_OPERATION_TIMEOUT
        )
    finally:
        # Unregister even when polling is rejected or times out.
        lis["destroy"]()

    if isinstance(response, dict) and "code" in response:
        raise RunwareAPIError(response)

    if response:
        return createImageToTextFromResponse(response)
    return None
610
+
611
async def videoCaption(self, requestVideoCaption: IVideoCaption) -> Union[List[IVideoToText], IAsyncTaskResponse]:
    """Caption a video, retrying the request through asyncRetry.

    The connection is ensured first. Any exception propagates unchanged.
    """
    await self.ensureConnection()
    return await asyncRetry(lambda: self._requestVideoCaption(requestVideoCaption))
619
+
620
async def _requestVideoCaption(
    self, requestVideoCaption: IVideoCaption
) -> Union[List[IVideoToText], IAsyncTaskResponse]:
    """Submit a video caption task.

    includeCost is sent when not None; webhookURL is sent when truthy.
    With a webhook the acknowledgment is returned. Otherwise results are
    polled via ``_pollVideoResults`` (the ``1`` argument is presumably the
    expected result count).
    """
    req = requestVideoCaption
    taskUUID = req.taskUUID or getUUID()

    task_params = dict(
        taskType=ETaskType.VIDEO_CAPTION.value,
        taskUUID=taskUUID,
        model=req.model,
        inputs={"video": req.inputs.video},
        deliveryMethod=req.deliveryMethod,
    )
    if req.includeCost is not None:
        task_params["includeCost"] = req.includeCost
    if req.webhookURL:
        task_params["webhookURL"] = req.webhookURL

    await self.send([task_params])

    if not req.webhookURL:
        return await self._pollVideoResults(taskUUID, 1, IVideoToText)

    return await self._handleWebhookAcknowledgment(
        task_uuid=taskUUID,
        task_type="caption",
        debug_key="video-caption-webhook",
    )
653
+
654
async def videoBackgroundRemoval(self, requestVideoBackgroundRemoval: IVideoBackgroundRemoval) -> Union[List[IVideo], IAsyncTaskResponse]:
    """Remove the background from a video, retrying through asyncRetry.

    The connection is ensured first. Any exception propagates unchanged.
    """
    await self.ensureConnection()
    return await asyncRetry(
        lambda: self._requestVideoBackgroundRemoval(requestVideoBackgroundRemoval)
    )
663
+
664
async def _requestVideoBackgroundRemoval(
    self, requestVideoBackgroundRemoval: IVideoBackgroundRemoval
) -> Union[List[IVideo], IAsyncTaskResponse]:
    """Submit a video background-removal task.

    Optional fields are added only when set:
    - outputFormat, webhookURL and settings when truthy;
    - includeCost when not None.
    ``settings`` is sent as a nested object with None-valued fields dropped.
    With a webhook the acknowledgment is returned; otherwise results are polled.
    """
    req = requestVideoBackgroundRemoval
    taskUUID = req.taskUUID or getUUID()

    task_params = dict(
        taskType=ETaskType.VIDEO_BACKGROUND_REMOVAL.value,  # "removeBackground"
        taskUUID=taskUUID,
        model=req.model,
        inputs={"video": req.inputs.video},
        deliveryMethod=req.deliveryMethod,
    )
    if req.outputFormat:
        task_params["outputFormat"] = req.outputFormat
    if req.includeCost is not None:
        task_params["includeCost"] = req.includeCost
    if req.webhookURL:
        task_params["webhookURL"] = req.webhookURL
    if req.settings:
        task_params["settings"] = {
            name: value
            for name, value in vars(req.settings).items()
            if value is not None
        }

    await self.send([task_params])

    if not req.webhookURL:
        return await self._pollVideoResults(taskUUID, 1, IVideo)

    return await self._handleWebhookAcknowledgment(
        task_uuid=taskUUID,
        task_type="removeBackground",
        debug_key="video-background-removal-webhook",
    )
706
+
707
async def videoUpscale(self, requestVideoUpscale: IVideoUpscale) -> Union[List[IVideo], IAsyncTaskResponse]:
    """Upscale a video, retrying the request through asyncRetry.

    The connection is ensured first. Any exception propagates unchanged.
    """
    await self.ensureConnection()
    return await asyncRetry(lambda: self._requestVideoUpscale(requestVideoUpscale))
715
+
716
async def _requestVideoUpscale(
    self, requestVideoUpscale: IVideoUpscale
) -> Union[List[IVideo], IAsyncTaskResponse]:
    """Submit a video upscale task.

    outputFormat, outputType and webhookURL are sent when truthy; includeCost
    is sent when not None. With a webhook the acknowledgment is returned;
    otherwise results are polled.
    """
    req = requestVideoUpscale
    taskUUID = req.taskUUID or getUUID()

    task_params = dict(
        taskType=ETaskType.VIDEO_UPSCALE.value,  # "upscale"
        taskUUID=taskUUID,
        model=req.model,
        inputs={"video": req.inputs.video},
        upscaleFactor=req.upscaleFactor,
        deliveryMethod=req.deliveryMethod,
    )
    if req.outputFormat:
        task_params["outputFormat"] = req.outputFormat
    if req.outputType:
        task_params["outputType"] = req.outputType
    if req.includeCost is not None:
        task_params["includeCost"] = req.includeCost
    if req.webhookURL:
        task_params["webhookURL"] = req.webhookURL

    await self.send([task_params])

    if not req.webhookURL:
        return await self._pollVideoResults(taskUUID, 1, IVideo)

    return await self._handleWebhookAcknowledgment(
        task_uuid=taskUUID,
        task_type="upscale",
        debug_key="video-upscale-webhook",
    )
753
+
754
async def imageBackgroundRemoval(
    self, removeImageBackgroundPayload: IImageBackgroundRemoval
) -> Union[List[IImage], IAsyncTaskResponse]:
    """Remove an image background, retrying the request through asyncRetry.

    The connection is ensured first. Any exception propagates unchanged.
    """
    await self.ensureConnection()
    return await asyncRetry(
        lambda: self._removeImageBackground(removeImageBackgroundPayload)
    )
764
+
765
async def _removeImageBackground(
    self, removeImageBackgroundPayload: IImageBackgroundRemoval
) -> Union[List[IImage], IAsyncTaskResponse]:
    """Upload the input image, submit a background-removal task, and await it.

    Non-None ``settings`` fields are merged into the top level of the task
    payload. Provider and safety settings are added through their helpers.

    Returns:
        A one-element list of IImage. When ``webhookURL`` is set, the webhook
        acknowledgment is returned instead. Returns ``[]`` if the upload fails.

    Raises:
        RunwareAPIError: If the API returns an error payload.
    """
    payload = removeImageBackgroundPayload

    image_uploaded = await self.uploadImage(payload.inputImage)
    if not image_uploaded or not image_uploaded.imageUUID:
        return []

    taskUUID = payload.taskUUID if payload.taskUUID is not None else getUUID()

    task_params: Dict[str, Any] = {
        "taskType": ETaskType.IMAGE_BACKGROUND_REMOVAL.value,
        "taskUUID": taskUUID,
        "inputImage": image_uploaded.imageUUID,
    }

    if payload.outputType is not None:
        task_params["outputType"] = payload.outputType
    if payload.outputFormat is not None:
        task_params["outputFormat"] = payload.outputFormat
    if payload.includeCost:
        task_params["includeCost"] = payload.includeCost
    if payload.model:
        task_params["model"] = payload.model
    if payload.outputQuality:
        task_params["outputQuality"] = payload.outputQuality
    if payload.webhookURL:
        task_params["webhookURL"] = payload.webhookURL

    if payload.settings:
        task_params.update(
            {k: v for k, v in vars(payload.settings).items() if v is not None}
        )

    if payload.providerSettings:
        self._addImageProviderSettings(task_params, payload)

    if payload.safety:
        self._addSafetySettings(task_params, payload.safety)

    await self.send([task_params])

    if payload.webhookURL:
        return await self._handleWebhookAcknowledgment(
            task_uuid=taskUUID,
            task_type="imageBackgroundRemoval",
            debug_key="image-background-removal-webhook",
        )

    async def check(resolve: callable, reject: callable, *args: Any) -> bool:
        async with self._messages_lock:
            messages = self._globalMessages.get(taskUUID)
            result = messages[0] if messages else None
            if result and result.get("error"):
                reject(result)
                return True
            if result:
                del self._globalMessages[taskUUID]
                resolve(result)
                return True
            return False

    lis = self.globalListener(taskUUID=taskUUID)
    try:
        response = await getIntervalWithPromise(
            check, debugKey="remove-image-background", timeOutDuration=IMAGE_OPERATION_TIMEOUT
        )
    finally:
        # Unregister even when polling is rejected or times out.
        lis["destroy"]()

    if isinstance(response, dict) and "code" in response:
        raise RunwareAPIError(response)

    image_list: List[IImage] = [createImageFromResponse(response)]
    return image_list
863
+
864
async def imageUpscale(self, upscaleGanPayload: IImageUpscale) -> Union[List[IImage], IAsyncTaskResponse]:
    """Upscale an image, retrying the request through asyncRetry.

    The connection is ensured first. Any exception propagates unchanged.
    """
    await self.ensureConnection()
    return await asyncRetry(lambda: self._upscaleGan(upscaleGanPayload))
870
+
871
async def _upscaleGan(self, upscaleGanPayload: IImageUpscale) -> Union[List[IImage], IAsyncTaskResponse]:
    """Upload the input image, submit an image upscale task, and await the result.

    The uploaded image is ``inputImage`` (legacy) when set, otherwise
    ``inputs.image``. The upload is sent under ``inputs.image`` whenever the
    payload provides ``inputs.image``, and as ``inputImage`` otherwise.

    Returns:
        A one-element list of IImage. When ``webhookURL`` is set, the webhook
        acknowledgment is returned instead. Returns ``[]`` if the upload fails.

    Raises:
        ValueError: If neither ``inputImage`` nor ``inputs.image`` is provided.
        RunwareAPIError: If the API returns an error payload.
    """
    inputs = upscaleGanPayload.inputs
    has_inputs_image = bool(inputs and inputs.image)

    inputImage = upscaleGanPayload.inputImage
    if not inputImage and has_inputs_image:
        inputImage = inputs.image
    if not inputImage:
        raise ValueError("Either inputImage or inputs.image must be provided")

    image_uploaded = await self.uploadImage(inputImage)
    if not image_uploaded or not image_uploaded.imageUUID:
        return []

    taskUUID = getUUID()

    task_params: Dict[str, Any] = {
        "taskType": ETaskType.IMAGE_UPSCALE.value,
        "taskUUID": taskUUID,
        "upscaleFactor": upscaleGanPayload.upscaleFactor,
    }
    if has_inputs_image:
        task_params["inputs"] = {"image": image_uploaded.imageUUID}
    else:
        task_params["inputImage"] = image_uploaded.imageUUID

    if upscaleGanPayload.model is not None:
        task_params["model"] = upscaleGanPayload.model

    # Settings are sent nested, without None values, and only when non-empty.
    if upscaleGanPayload.settings is not None:
        settings_dict = {
            k: v for k, v in asdict(upscaleGanPayload.settings).items() if v is not None
        }
        if settings_dict:
            task_params["settings"] = settings_dict

    if upscaleGanPayload.outputType is not None:
        task_params["outputType"] = upscaleGanPayload.outputType
    if upscaleGanPayload.outputFormat is not None:
        task_params["outputFormat"] = upscaleGanPayload.outputFormat
    if upscaleGanPayload.includeCost:
        task_params["includeCost"] = upscaleGanPayload.includeCost
    if upscaleGanPayload.webhookURL:
        task_params["webhookURL"] = upscaleGanPayload.webhookURL

    if upscaleGanPayload.providerSettings:
        self._addImageProviderSettings(task_params, upscaleGanPayload)

    if upscaleGanPayload.safety:
        self._addSafetySettings(task_params, upscaleGanPayload.safety)

    await self.send([task_params])

    if upscaleGanPayload.webhookURL:
        return await self._handleWebhookAcknowledgment(
            task_uuid=taskUUID,
            task_type="imageUpscale",
            debug_key="image-upscale-webhook",
        )

    async def check(resolve: callable, reject: callable, *args: Any) -> bool:
        async with self._messages_lock:
            messages = self._globalMessages.get(taskUUID)
            upscaled_image = messages[0] if messages else None
            if upscaled_image and upscaled_image.get("error"):
                reject(upscaled_image)
                return True
            if upscaled_image:
                del self._globalMessages[taskUUID]
                resolve(upscaled_image)
                return True
            return False

    lis = self.globalListener(taskUUID=taskUUID)
    try:
        response = await getIntervalWithPromise(
            check, debugKey="upscale-gan", timeOutDuration=IMAGE_OPERATION_TIMEOUT
        )
    finally:
        # Unregister even when polling is rejected or times out.
        lis["destroy"]()

    if isinstance(response, dict) and "code" in response:
        raise RunwareAPIError(response)

    image_list: List[IImage] = [createImageFromResponse(response)]
    return image_list
978
+
979
async def imageVectorize(self, vectorizePayload: IVectorize) -> Union[List[IImage], IAsyncTaskResponse]:
    """
    Vectorize an image (public entry point).

    Makes sure the connection is established, then runs ``_vectorize``
    through ``asyncRetry``. Exceptions propagate to the caller unchanged.

    :param vectorizePayload: The vectorize request.
    :return: The vectorized image(s), or a webhook acknowledgment when the
        request carries a ``webhookURL``.
    """
    await self.ensureConnection()
    return await asyncRetry(lambda: self._vectorize(vectorizePayload))
985
+
986
async def _vectorize(self, vectorizePayload: IVectorize) -> Union[List[IImage], IAsyncTaskResponse]:
    """
    Send one image-vectorize task and wait for its result.

    The input image (``vectorizePayload.inputs.image``) is passed through
    ``uploadImage`` first, and the resulting image UUID is what gets sent.

    :param vectorizePayload: The vectorize request; ``inputs.image`` is required.
    :return: A one-element list with the vectorized image, an empty list when
        the upload yields no image UUID, or the webhook acknowledgment when
        ``webhookURL`` is set.
    :raises ValueError: If no input image is provided (including when
        ``inputs`` itself is missing).
    :raises RunwareAPIError: If the returned payload is an error object.
    """
    inputs = vectorizePayload.inputs
    input_image = inputs.image if inputs is not None else None
    if not input_image:
        raise ValueError("Image is required in inputs for vectorize task")

    # Upload the image if it's a local file
    image_uploaded = await self.uploadImage(input_image)
    if not image_uploaded or not image_uploaded.imageUUID:
        return []

    taskUUID = getUUID()

    # Mandatory parameters
    task_params = {
        "taskType": ETaskType.IMAGE_VECTORIZE.value,
        "taskUUID": taskUUID,
        "inputs": {"image": image_uploaded.imageUUID},
    }

    # Optional parameters are only sent when set.
    for field in ("model", "outputType", "outputFormat"):
        value = getattr(vectorizePayload, field)
        if value is not None:
            task_params[field] = value
    if vectorizePayload.includeCost:
        task_params["includeCost"] = vectorizePayload.includeCost
    if vectorizePayload.webhookURL:
        task_params["webhookURL"] = vectorizePayload.webhookURL

    await self.send([task_params])

    if vectorizePayload.webhookURL:
        return await self._handleWebhookAcknowledgment(
            task_uuid=taskUUID,
            task_type="vectorize",
            debug_key="image-vectorize-webhook",
        )

    let_lis = await self.listenToImages(
        onPartialImages=None,
        taskUUID=taskUUID,
        groupKey=LISTEN_TO_IMAGES_KEY.REQUEST_IMAGES,
    )
    try:
        images = await self.getSimililarImage(
            taskUUID=taskUUID,
            numberOfImages=1,
            shouldThrowError=True,
            lis=let_lis,
        )
    finally:
        # Always unregister the listener, also on timeout/error, so failed
        # attempts (and asyncRetry retries) do not leave listeners behind.
        let_lis["destroy"]()

    if "code" in images or "errors" in images:
        # This indicates an error response
        raise RunwareAPIError(images)

    return instantiateDataclassList(IImage, images)
1052
+
1053
async def promptEnhance(
    self, promptEnhancer: IPromptEnhance
) -> Union[List[IEnhancedPrompt], IAsyncTaskResponse]:
    """
    Enhance a prompt, producing one or more rewritten versions of it.

    Ensures the connection is open, then delegates to ``_enhancePrompt``
    wrapped in ``asyncRetry``. Any exception propagates unchanged.

    :param promptEnhancer: An ``IPromptEnhance`` request describing the prompt.
    :return: The enhanced prompts, or a webhook acknowledgment when the
        request carries a ``webhookURL``.
    """
    await self.ensureConnection()
    return await asyncRetry(lambda: self._enhancePrompt(promptEnhancer))
1068
+
1069
async def _enhancePrompt(
    self, promptEnhancer: IPromptEnhance
) -> Union[List[IEnhancedPrompt], IAsyncTaskResponse]:
    """
    Send a prompt-enhance task and collect its results.

    :param promptEnhancer: An ``IPromptEnhance`` request. ``promptMaxLength``
        defaults to 380 when the attribute is missing or ``None``;
        ``promptVersions`` defaults to 1 when falsy.
    :return: The distinct enhanced prompts in the order the API returned them,
        or the webhook acknowledgment when ``webhookURL`` is set.
    :raises RunwareAPIError: If the first returned item is an error payload.
    """
    prompt = promptEnhancer.prompt
    promptMaxLength = getattr(promptEnhancer, "promptMaxLength", None)
    if promptMaxLength is None:
        promptMaxLength = 380
    promptVersions = promptEnhancer.promptVersions or 1

    taskUUID = getUUID()

    # Mandatory parameters
    task_params = {
        "taskType": ETaskType.PROMPT_ENHANCE.value,
        "taskUUID": taskUUID,
        "prompt": prompt,
        "promptMaxLength": promptMaxLength,
        "promptVersions": promptVersions,
    }

    if promptEnhancer.includeCost:
        task_params["includeCost"] = promptEnhancer.includeCost

    has_webhook = promptEnhancer.webhookURL
    if has_webhook:
        task_params["webhookURL"] = promptEnhancer.webhookURL

    await self.send([task_params])

    if has_webhook:
        return await self._handleWebhookAcknowledgment(
            task_uuid=taskUUID,
            task_type="promptEnhance",
            debug_key="prompt-enhance-webhook",
        )

    lis = self.globalListener(taskUUID=taskUUID)

    async def check(resolve: Any, reject: Any, *args: Any) -> bool:
        async with self._messages_lock:
            response = self._globalMessages.get(taskUUID)
            if isinstance(response, dict) and response.get("error"):
                # Drop the stored error so it does not linger in the buffer.
                self._globalMessages.pop(taskUUID, None)
                reject(response)
                return True
            if response:
                del self._globalMessages[taskUUID]
                resolve(response)
                return True
            return False

    try:
        response = await getIntervalWithPromise(
            check, debugKey="enhance-prompt", timeOutDuration=PROMPT_ENHANCE_TIMEOUT
        )
    finally:
        # Unregister the listener even when waiting fails or times out.
        lis["destroy"]()

    if "code" in response[0]:
        # This indicates an error response
        raise RunwareAPIError(response[0])

    enhanced_prompts = createEnhancedPromptsFromResponse(response)

    # Deduplicate while keeping the API's order. Equality (not hashing) is
    # used, so this also works if the prompt objects are unhashable (as
    # non-frozen dataclasses are); the list is expected to be small.
    unique_prompts: List[IEnhancedPrompt] = []
    for candidate in enhanced_prompts:
        if candidate not in unique_prompts:
            unique_prompts.append(candidate)
    return unique_prompts
1143
+
1144
async def uploadImage(self, file: Union[File, str]) -> Optional[UploadImageType]:
    """
    Upload an image (public entry point).

    Ensures the connection is open, then runs ``_uploadImage`` through
    ``asyncRetry``. Exceptions propagate unchanged.

    :param file: A file object or a string (path, URL, UUID or base64 data).
    :return: The uploaded image info, or ``None``.
    """
    await self.ensureConnection()
    return await asyncRetry(lambda: self._uploadImage(file))
1150
+
1151
async def _uploadImage(self, file: Union[File, str]) -> Optional[UploadImageType]:
    """
    Upload an image, or pass through a reference that needs no upload.

    Strings naming an existing path are always read and uploaded. Other
    strings are checked with ``isLocalFile``; data URIs and bare
    base64-looking strings are treated as non-local. A non-local string is
    returned directly as an ``UploadImageType`` whose ``imageUUID`` and
    ``imageURL`` are the input string itself.

    :param file: A file object or a string (path, URL, UUID or base64 data).
    :return: The uploaded image info, or ``None`` if the API returned nothing.
    :raises RunwareAPIError: If the API answers with an error payload.
    """
    task_uuid = getUUID()

    if isinstance(file, str) and not os.path.exists(file):
        local_file = isLocalFile(file)

        # Data URIs and bare base64 payloads are sent as-is. This test only
        # runs for strings that are NOT existing paths, so a real file whose
        # name happens to fit the base64 alphabet (e.g. "images/cat") is
        # still uploaded.
        if file.startswith("data:") or re.match(r"^[A-Za-z0-9+/]+={0,2}$", file):
            local_file = False

        if not local_file:
            return UploadImageType(
                imageUUID=file,
                imageURL=file,
                taskUUID=task_uuid,
            )

    file = await fileToBase64(file)

    await self.send(
        [
            {
                "taskType": ETaskType.IMAGE_UPLOAD.value,
                "taskUUID": task_uuid,
                "image": file,
            }
        ]
    )

    lis = self.globalListener(taskUUID=task_uuid)

    async def check(resolve: Callable, reject: Callable, *args: Any) -> bool:
        async with self._messages_lock:
            entry = self._globalMessages.get(task_uuid)
            # globalListener stores either a list of messages or a single bare
            # dict (error payloads / non-list values); accept both shapes.
            if isinstance(entry, dict):
                uploaded_image = entry
            else:
                uploaded_image = entry[0] if entry else None

            if uploaded_image and uploaded_image.get("error"):
                self._globalMessages.pop(task_uuid, None)
                reject(uploaded_image)
                return True

            if uploaded_image:
                del self._globalMessages[task_uuid]
                resolve(uploaded_image)
                return True

            return False

    try:
        response = await getIntervalWithPromise(
            check, debugKey="upload-image", timeOutDuration=IMAGE_UPLOAD_TIMEOUT
        )
    finally:
        # Unregister the listener even when waiting fails or times out.
        lis["destroy"]()

    if response and "code" in response:
        # This indicates an error response
        raise RunwareAPIError(response)

    if not response:
        return None

    return UploadImageType(
        imageUUID=response["imageUUID"],
        imageURL=response["imageURL"],
        taskUUID=response["taskUUID"],
    )
1222
+
1223
async def uploadMedia(self, media_url: str) -> Optional[MediaStorageType]:
    """
    Store a piece of media (public entry point).

    Ensures the connection is open, then runs ``_uploadMedia`` through
    ``asyncRetry``. Exceptions propagate unchanged.

    :param media_url: Local path, URL or base64 payload.
    :return: The stored media info, or ``None``.
    """
    await self.ensureConnection()
    return await asyncRetry(lambda: self._uploadMedia(media_url))
1229
+
1230
async def _uploadMedia(self, media_url: str) -> Optional[MediaStorageType]:
    """
    Send a media-storage upload task and wait for the stored media's UUID.

    An existing local path is read with ``fileToBase64`` and any data-URI
    prefix is stripped from the result. Every other string (URL, base64, ...)
    is sent unchanged.

    :param media_url: Local path, URL or base64 payload.
    :return: The stored media's UUID and task UUID, or ``None`` if the API
        returned nothing.
    :raises RunwareAPIError: If the API answers with an error payload.
    """
    task_uuid = getUUID()

    if isinstance(media_url, str) and os.path.exists(media_url):
        # Local file - convert to base64
        media_url = await fileToBase64(media_url)
        # Strip the data URI prefix for media storage API
        if media_url.startswith("data:"):
            media_url = media_url.split(",", 1)[1]
    # For URLs and base64 strings, send them directly to the API

    await self.send(
        [
            {
                "taskType": ETaskType.MEDIA_STORAGE.value,
                "taskUUID": task_uuid,
                "operation": "upload",
                "media": media_url,
            }
        ]
    )

    lis = self.globalListener(taskUUID=task_uuid)

    async def check(resolve: Callable, reject: Callable, *args: Any) -> bool:
        # Read under the same lock globalListener writes with, as the other
        # upload checks do.
        async with self._messages_lock:
            entry = self._globalMessages.get(task_uuid)
            # globalListener stores either a list or a single bare dict.
            if isinstance(entry, dict):
                uploaded_media = entry
            else:
                uploaded_media = entry[0] if entry else None

            if uploaded_media and uploaded_media.get("error"):
                self._globalMessages.pop(task_uuid, None)
                reject(uploaded_media)
                return True

            if uploaded_media:
                del self._globalMessages[task_uuid]
                resolve(uploaded_media)
                return True

            return False

    try:
        response = await getIntervalWithPromise(
            check, debugKey="upload-media", timeOutDuration=self._timeout
        )
    finally:
        # Unregister the listener even when waiting fails or times out.
        lis["destroy"]()

    if response and "code" in response:
        # This indicates an error response
        raise RunwareAPIError(response)

    if not response:
        return None

    return MediaStorageType(
        mediaUUID=response["mediaUUID"],
        taskUUID=response["taskUUID"],
    )
1289
+
1290
async def uploadUnprocessedImage(
    self,
    file: Union[File, str],
    preProcessorType: EPreProcessorGroup,
    width: int = None,
    height: int = None,
    lowThresholdCanny: int = None,
    highThresholdCanny: int = None,
    includeHandsAndFaceOpenPose: bool = True,
) -> Optional[UploadImageType]:
    """
    Placeholder implementation: sends nothing to the API.

    All arguments are ignored. The return value is a dummy ``UploadImageType``
    with two freshly generated random UUIDs and a fixed example.com URL, so
    it must not be treated as a real upload.
    """
    return UploadImageType(
        imageUUID=str(uuid.uuid4()),
        imageURL="https://example.com/uploaded_unprocessed_image.jpg",
        taskUUID=str(uuid.uuid4()),
    )
1308
+
1309
async def listenToImages(
    self,
    onPartialImages: Optional[Callable[[List[IImage], Optional[IError]], None]],
    taskUUID: str,
    groupKey: LISTEN_TO_IMAGES_KEY,
) -> Dict[str, Callable[[], None]]:
    """
    Register a listener that buffers image results for one task.

    Items in a message's ``data`` list whose ``taskType`` is
    ``imageInference``/``vectorize`` and whose ``taskUUID`` matches are
    appended to ``self._globalImages`` and passed to ``onPartialImages``.
    Entries in a message's ``errors`` list for this task are stored in
    ``self._globalError`` and also reported through ``onPartialImages``.

    :param onPartialImages: Optional callback ``(images, error)``; exceptions
        it raises are logged, not propagated.
    :param taskUUID: The task whose results should be captured.
    :param groupKey: Listener group forwarded to ``addListener``.
    :return: The handle returned by ``addListener`` (callers call its
        ``"destroy"`` entry to unregister it).
    """
    logger.debug("Setting up images listener for taskUUID: %s", taskUUID)
    image_task_types = ("imageInference", "vectorize")

    def _message_errors(m: Dict[str, Any]) -> List[Dict[str, Any]]:
        # Tolerate a missing or non-list "errors" field.
        errors = m.get("errors")
        return errors if isinstance(errors, list) else []

    def _notify(images: List[IImage], error: Optional[IError]) -> None:
        # User callbacks must never break message dispatch.
        if not onPartialImages:
            return
        try:
            onPartialImages(images, error)
        except Exception as e:
            logger.error(
                f"Error occurred in user on_partial_images callback function: {e}"
            )

    async def listen_to_images_lis(m: Dict[str, Any]) -> None:
        if isinstance(m.get("data"), list):
            images = [
                img
                for img in m["data"]
                if img.get("taskType") in image_task_types
                and img.get("taskUUID") == taskUUID
            ]

            if images:
                async with self._images_lock:
                    self._globalImages.extend(images)

                try:
                    partial_images = instantiateDataclassList(IImage, images)
                except Exception as e:
                    logger.error(
                        f"Error occurred in user on_partial_images callback function: {e}"
                    )
                else:
                    _notify(partial_images, None)
        elif isinstance(m.get("errors"), list):
            errors = [
                error for error in m["errors"] if error.get("taskUUID") == taskUUID
            ]
            if errors:
                first = errors[0]
                self._globalError = IError(
                    error=True,
                    error_message=first.get("message", "Unknown error"),
                    task_uuid=first.get("taskUUID", ""),
                    error_code=first.get("code"),
                    error_type=first.get("type"),
                    parameter=first.get("parameter"),
                    documentation=first.get("documentation"),
                )
                _notify([], self._globalError)

    def listen_to_images_check(m):
        logger.debug("Images check message: %s", m)
        data = m.get("data")
        errors = _message_errors(m)

        image_inference_check = isinstance(data, list) and any(
            item.get("taskType") in image_task_types for item in data
        )
        error_check = any(error.get("taskUUID") == taskUUID for error in errors)

        # Record coded errors for THIS task, or ones not tied to any task.
        # A coded error that belongs to a different task must not fail this
        # request.
        relevant_error = any(
            error.get("code")
            and (not error.get("taskUUID") or error.get("taskUUID") == taskUUID)
            for error in errors
        )
        if relevant_error:
            self._globalError = IError(
                error=True,
                error_message=f"Error in image inference: {m.get('errors')}",
                task_uuid=taskUUID,
            )

        return image_inference_check or error_check

    temp_listener = self.addListener(
        check=listen_to_images_check,
        lis=self._create_safe_async_listener(listen_to_images_lis),
        groupKey=groupKey,
    )

    logger.debug("listenToImages :: Temp listener: %s", temp_listener)

    return temp_listener
1392
+
1393
def globalListener(self, taskUUID: str) -> Dict[str, Callable[[], None]]:
    """
    Register a listener that records messages for ``taskUUID`` in ``self._globalMessages``.

    Storage shape, under ``self._messages_lock``:
    - a message with a truthy ``"error"`` field is stored as-is (a dict) under
      ``taskUUID``;
    - otherwise ``accessDeepObject(taskUUID, m)`` is evaluated. A list result
      has each item appended to the list kept under that item's own
      ``taskUUID``, and any other result is stored directly under its
      ``taskUUID``.

    :param taskUUID: The task identifier the listener filters on.
    :return: The handle returned by ``addListener`` (its ``"destroy"`` entry
        removes the listener).
    """
    logger.debug("Setting up global listener for taskUUID: %s", taskUUID)

    async def global_lis(m: Dict[str, Any]) -> None:
        logger.debug("Global listener message: %s", m)
        logger.debug("Global listener taskUUID: %s", taskUUID)

        async with self._messages_lock:
            if m.get("error"):
                self._globalMessages[taskUUID] = m
                return

            found = accessDeepObject(taskUUID, m)

            if not isinstance(found, list):
                self._globalMessages[found["taskUUID"]] = found
                return

            for item in found:
                key = item["taskUUID"]
                previous = self._globalMessages.get(key, [])
                self._globalMessages[key] = previous + [item]
                logger.debug("Global messages v: %s", item)
                logger.debug(
                    "self._globalMessages[v[taskUUID]]: %s",
                    self._globalMessages[key],
                )

    def global_check(m):
        logger.debug("Global check message: %s", m)
        return accessDeepObject(taskUUID, m)

    logger.debug("Global Listener taskUUID: %s", taskUUID)

    safe_lis = self._create_safe_async_listener(global_lis)
    temp_listener = self.addListener(check=global_check, lis=safe_lis)
    logger.debug("globalListener :: Temp listener: %s", temp_listener)

    return temp_listener
1438
+
1439
async def handleIncompleteImages(
    self, taskUUIDs: List[str], error: Any
) -> Optional[List[IImage]]:
    """
    Salvage partially received images for the given tasks.

    If MORE than one buffered image in ``self._globalImages`` belongs to
    ``taskUUIDs``, those images are removed from the buffer and returned.
    With zero or one matching image, ``error`` is raised and the buffer is
    left untouched. Despite the ``Optional`` annotation, ``None`` is never
    returned.

    :param taskUUIDs: Task UUIDs whose images should be collected.
    :param error: The exception to raise when too few images are available.
    :return: The matching images (two or more).
    :raises: ``error`` when at most one matching image is buffered.
    """
    async with self._images_lock:
        matching = []
        remaining = []
        for img in self._globalImages:
            if img["taskUUID"] in taskUUIDs:
                matching.append(img)
            else:
                remaining.append(img)

        if len(matching) <= 1:
            raise error

        self._globalImages = remaining
        return matching
1461
+
1462
async def ensureConnection(self) -> None:
    """
    Ensure that a connection is established with the server.

    If ``self._invalidAPIkey`` is set, fail fast when authentication never
    succeeded or when the reconnection manager's circuit is open. Otherwise
    connect if the websocket is not open, then re-check the API-key state.

    :raises ConnectionError: With ``self._invalidAPIkey`` as the message when
        set, or a generic "could not connect" message. The triggering
        exception is chained as ``__cause__`` so the root cause is kept.
    """
    isConnected = self.connected() and self._ws.state is State.OPEN

    try:
        if self._invalidAPIkey:
            if not self._reconnection_manager._had_successful_auth:
                raise ConnectionError(self._invalidAPIkey)

            circuit_state = self._reconnection_manager.get_state()
            if circuit_state == ConnectionState.CIRCUIT_OPEN:
                raise ConnectionError(self._invalidAPIkey)

        if not isConnected:
            await self.connect()

        if self._invalidAPIkey and not self._reconnection_manager._had_successful_auth:
            raise ConnectionError(self._invalidAPIkey)

    except Exception as e:
        raise ConnectionError(
            self._invalidAPIkey
            or "Could not connect to server. Ensure your API key is correct"
        ) from e
1493
+
1494
async def getSimililarImage(
    self,
    taskUUID: Union[str, List[str]],
    numberOfImages: int = 1,
    shouldThrowError: bool = True,
    lis: Optional[ListenerType] = None,
) -> List[IImage]:
    """
    Wait until ``numberOfImages`` buffered images exist for the given task(s).

    Images are read from ``self._globalImages``, which is filled by an images
    listener. Once enough images whose ``taskType`` is
    ``imageInference``/``vectorize`` and whose ``taskUUID`` matches are
    buffered, all of them are removed from the buffer and the first
    ``numberOfImages`` are returned. Unrelated buffered images are kept.

    :param taskUUID: A single task UUID or a list of task UUIDs.
    :param numberOfImages: The number of images to wait for. Defaults to 1.
    :param shouldThrowError: Forwarded to ``getIntervalWithPromise``.
    :param lis: Accepted for API compatibility; it is not used here, and the
        caller remains responsible for destroying its listener.
    :return: The collected images.
    :raises RunwareError: When an error recorded in ``self._globalError`` is
        picked up while waiting; it is re-raised unchanged.
    :raises Exception: On any other failure (e.g. timeout), with a message
        that reports how many images were received.
    """
    taskUUIDs = taskUUID if isinstance(taskUUID, list) else [taskUUID]
    image_task_types = ("imageInference", "vectorize")

    def _is_ours(img: Dict[str, Any]) -> bool:
        return img.get("taskType") in image_task_types and img.get("taskUUID") in taskUUIDs

    async def check(
        resolve: Callable[[List[IImage]], None],
        reject: Callable[[IError], None],
        intervalId: Any,
    ) -> Optional[bool]:
        async with self._images_lock:
            logger.debug("Check # Global images: %s", self._globalImages)
            imagesWithSimilarTask = [img for img in self._globalImages if _is_ours(img)]

            if self._globalError:
                logger.debug("Check # _globalError: %s", self._globalError)
                error = self._globalError
                self._globalError = None
                logger.debug("Rejecting with error: %s", error)
                reject(RunwareError(error))
                return True

            if len(imagesWithSimilarTask) >= numberOfImages:
                # Remove only this task's images; images of other tasks or
                # other task types stay buffered for their own consumers.
                self._globalImages = [
                    img for img in self._globalImages if not _is_ours(img)
                ]
                resolve(imagesWithSimilarTask[:numberOfImages])
                return True

        return False

    try:
        return await getIntervalWithPromise(
            check,
            debugKey="getting images",
            shouldThrowError=shouldThrowError,
            timeOutDuration=IMAGE_INFERENCE_TIMEOUT,
        )
    except RunwareError:
        # API errors are surfaced as-is instead of being reported as timeouts.
        raise
    except Exception as e:
        async with self._images_lock:
            current_images = sum(1 for img in self._globalImages if _is_ours(img))
        error_msg = (
            f"Timeout waiting for images | "
            f"TaskUUIDs: {taskUUIDs} | "
            f"Expected: {numberOfImages} images | "
            f"Received: {current_images} images | "
            f"Timeout: {IMAGE_INFERENCE_TIMEOUT}ms | "
            f"Original error: {str(e)}"
        )
        raise Exception(error_msg) from e
1569
+
1570
async def _modelUpload(
    self, requestModel: IUploadModelBaseType
) -> Optional[IUploadModelResponse]:
    """
    Send a model-upload task and wait until the API reports status "ready".

    Status messages for the task are collected, keeping the first message of
    each distinct ``status``. They are returned together once a message with
    status ``"ready"`` is seen.

    :param requestModel: The model description; optional attributes are sent
        only when not ``None``.
    :return: A list of dicts with ``taskType``, ``taskUUID``, ``status``,
        ``message`` and ``air`` (one per distinct status), or ``None`` when
        nothing was returned. The ``IUploadModelResponse`` annotation is kept
        for compatibility.
    :raises RunwareAPIError: If a message carries an error code, an ``error``
        field or an error status, or the final payload is an error object.
    """
    task_uuid = getUUID()
    base_fields = {
        "taskType": ETaskType.MODEL_UPLOAD.value,
        "taskUUID": task_uuid,
        "air": requestModel.air,
        "name": requestModel.name,
        "downloadURL": requestModel.downloadURL,
        "uniqueIdentifier": requestModel.uniqueIdentifier,
        "version": requestModel.version,
        "format": requestModel.format,
        "private": requestModel.private,
        "category": requestModel.category,
        "architecture": requestModel.architecture,
    }

    optional_fields = [
        "retry",
        "heroImageURL",
        "tags",
        "shortDescription",
        "comment",
        "positiveTriggerWords",
        "type",
        "negativeTriggerWords",
        "defaultWeight",
        "defaultStrength",
        "defaultGuidanceScale",
        "defaultSteps",
        "defaultScheduler",
        "conditioning",
    ]

    request_object = {
        **base_fields,
        **{
            field: getattr(requestModel, field)
            for field in optional_fields
            if getattr(requestModel, field, None) is not None
        },
    }

    await self.send([request_object])

    lis = self.globalListener(taskUUID=task_uuid)

    async def check(resolve: Callable, reject: Callable, *args: Any) -> bool:
        async with self._messages_lock:
            entry = self._globalMessages.get(task_uuid, [])
            # globalListener may store a single bare dict (error payloads or
            # non-list values) instead of a list; normalize to a list.
            uploaded_model_list = [entry] if isinstance(entry, dict) else entry
            unique_statuses = set()
            all_models = []

            for uploaded_model in uploaded_model_list:
                status = uploaded_model.get("status")

                if (
                    uploaded_model.get("code")
                    or uploaded_model.get("error")
                    or (status is not None and "error" in status)
                ):
                    self._globalMessages.pop(task_uuid, None)
                    raise RunwareAPIError(uploaded_model)

                if status not in unique_statuses:
                    all_models.append(uploaded_model)
                    unique_statuses.add(status)

                if status == "ready":
                    # The whole status stream is consumed here; drop the entry
                    # so earlier status messages do not stay buffered forever.
                    self._globalMessages.pop(task_uuid, None)
                    resolve(all_models)
                    return True

            return False

    try:
        response = await getIntervalWithPromise(
            check, debugKey="upload-model", timeOutDuration=self._timeout
        )
    finally:
        # Unregister the listener even when waiting fails or times out.
        lis["destroy"]()

    if response and "code" in response:
        # This indicates an error response
        raise RunwareAPIError(response)

    if not response:
        return None

    if not isinstance(response, list):
        response = [response]

    return [
        {
            "taskType": item.get("taskType"),
            "taskUUID": item.get("taskUUID"),
            "status": item.get("status"),
            "message": item.get("message"),
            "air": item.get("air"),
        }
        for item in response
    ]
1678
+
1679
async def modelUpload(
    self, requestModel: IUploadModelBaseType
) -> Optional[IUploadModelResponse]:
    """
    Upload a model (public entry point).

    Ensures the connection is open, then runs ``_modelUpload`` through
    ``asyncRetry``. Exceptions propagate unchanged.

    :param requestModel: The model description to upload.
    :return: Whatever ``_modelUpload`` returns.
    """
    await self.ensureConnection()
    return await asyncRetry(lambda: self._modelUpload(requestModel))
1687
+
1688
async def modelSearch(self, payload: IModelSearch) -> IModelSearchResponse:
    """
    Run a model-search task and return the first result message.

    Every attribute of ``payload`` that is not ``None`` (except
    ``additional_params``) is copied into the request.

    :param payload: The search parameters.
    :return: The response converted to ``IModelSearchResponse``.
    :raises RunwareAPIError: For API error payloads (re-raised unchanged) and
        for any other failure, wrapped with its message. The original
        exception is chained as ``__cause__``.
    """
    try:
        await self.ensureConnection()
        task_uuid = getUUID()

        request_object = {
            "taskUUID": task_uuid,
            "taskType": ETaskType.MODEL_SEARCH.value,
            **({"tags": payload.tags} if payload.tags else {}),
        }

        request_object.update(
            {
                key: value
                for key, value in vars(payload).items()
                if value is not None and key != "additional_params"
            }
        )

        await self.send([request_object])

        listener = self.globalListener(taskUUID=task_uuid)

        async def check(resolve: Callable, reject: Callable, *args: Any) -> bool:
            async with self._messages_lock:
                entry = self._globalMessages.get(task_uuid)
                if not entry:
                    return False
                # globalListener stores either a list or a single bare dict.
                result = entry if isinstance(entry, dict) else entry[0]
                # Consume the entry on both paths so errors do not linger.
                self._globalMessages.pop(task_uuid, None)
                if result.get("error"):
                    reject(result)
                else:
                    resolve(result)
                return True

        try:
            response = await getIntervalWithPromise(
                check, debugKey="model-search", timeOutDuration=self._timeout
            )
        finally:
            # Unregister the listener even when waiting fails or times out.
            listener["destroy"]()

        if "code" in response:
            # This indicates an error response
            raise RunwareAPIError(response)

        return instantiateDataclass(IModelSearchResponse, response)

    except RunwareAPIError:
        raise
    except Exception as e:
        raise RunwareAPIError({"message": str(e)}) from e
1740
+
1741
async def videoInference(self, requestVideo: IVideoInference) -> Union[List[IVideo], IAsyncTaskResponse]:
    """
    Submit a video-inference request (public entry point).

    Ensures the connection is open, then runs ``_requestVideo`` through
    ``asyncRetry``.

    :param requestVideo: The video request.
    :return: The generated videos, or an ``IAsyncTaskResponse`` when the
        response is skipped.
    """
    await self.ensureConnection()
    attempt = lambda: self._requestVideo(requestVideo)  # noqa: E731 - retried factory
    return await asyncRetry(attempt)
1744
+
1745
async def getResponse(
    self,
    taskUUID: str,
    numberResults: int = 1
) -> List[IVideo]:
    """
    Poll for the results of a previously submitted video task.

    :param taskUUID: The task to poll.
    :param numberResults: How many results to wait for (default 1).
    :return: The results returned by ``_pollVideoResults``, built as ``IVideo``.
    """
    await self.ensureConnection()
    result_type = IVideo
    return await self._pollVideoResults(taskUUID, numberResults, result_type)
1752
+
1753
async def _requestVideo(self, requestVideo: IVideoInference) -> Union[List[IVideo], IAsyncTaskResponse]:
    """
    Build and send a single video-inference task.

    Frame and reference images are processed first, and a task UUID is
    generated when the request has none. With ``skipResponse`` set, an
    ``IAsyncTaskResponse`` is returned right after sending. Otherwise the
    initial response is handled by ``_handleInitialVideoResponse``.

    :param requestVideo: The video request; it is mutated in place
        (``taskUUID`` and processed images).
    """
    await self._processVideoImages(requestVideo)
    if not requestVideo.taskUUID:
        requestVideo.taskUUID = getUUID()

    request_object = self._buildVideoRequest(requestVideo)
    webhook_url = requestVideo.webhookURL
    if webhook_url:
        request_object["webhookURL"] = webhook_url

    await self.send([request_object])

    if not requestVideo.skipResponse:
        return await self._handleInitialVideoResponse(
            requestVideo.taskUUID,
            requestVideo.numberResults,
            request_object.get("webhookURL"),
        )

    return IAsyncTaskResponse(
        taskType=ETaskType.VIDEO_INFERENCE.value,
        taskUUID=requestVideo.taskUUID,
    )
1774
+
1775
async def _processVideoImages(self, requestVideo: IVideoInference) -> None:
    """
    Resolve frame and reference images in place before the request is built.

    Each ``IFrameImage.inputImage`` and each entry of ``referenceImages`` is
    passed through ``process_image``; the frame and reference batches are
    each awaited with ``gather``. The results replace the original values on
    ``requestVideo``. Frame items that are not ``IFrameImage`` are left
    untouched and kept in the list.

    :param requestVideo: The video request to update in place.
    """
    frame_items = [
        item for item in (requestVideo.frameImages or []) if isinstance(item, IFrameImage)
    ]
    frame_tasks = [process_image(item.inputImage) for item in frame_items]
    reference_tasks = [
        process_image(reference_item)
        for reference_item in (requestVideo.referenceImages or [])
    ]

    frame_results = await gather(*frame_tasks) if frame_tasks else []
    reference_results = await gather(*reference_tasks) if reference_tasks else []

    # Write each processed value back to ``inputImage``, the attribute read
    # above. ``_addVideoImages`` serializes frames with ``asdict``, which only
    # includes dataclass fields, so assigning to a new ``inputImages``
    # attribute would silently send the unprocessed value.
    for item, processed in zip(frame_items, frame_results):
        item.inputImage = processed

    if requestVideo.referenceImages and reference_results:
        requestVideo.referenceImages = reference_results
1807
+
1808
def _buildVideoRequest(self, requestVideo: IVideoInference) -> Dict[str, Any]:
    """
    Assemble the videoInference request payload.

    Starts from deliveryMethod/taskType/taskUUID/model/numberResults, adds a
    stripped ``positivePrompt`` when one is given, then delegates to the
    helper methods in a fixed order: speech, optional fields, images, inputs,
    provider settings, safety, advanced features and accelerator options.

    :param requestVideo: The video request.
    :return: The request dict to send.
    """
    payload: Dict[str, Any] = {
        "deliveryMethod": requestVideo.deliveryMethod,
        "taskType": ETaskType.VIDEO_INFERENCE.value,
        "taskUUID": requestVideo.taskUUID,
        "model": requestVideo.model,
        "numberResults": requestVideo.numberResults,
    }

    prompt = requestVideo.positivePrompt
    if prompt is not None:
        payload["positivePrompt"] = prompt.strip()

    self._addOptionalField(payload, requestVideo.speech)
    self._addOptionalVideoFields(payload, requestVideo)
    self._addVideoImages(payload, requestVideo)
    self._addVideoInputs(payload, requestVideo)
    self._addProviderSettings(payload, requestVideo)
    self._addOptionalField(payload, requestVideo.safety)
    self._addOptionalField(payload, requestVideo.advancedFeatures)
    self._addOptionalField(payload, requestVideo.acceleratorOptions)

    return payload
1832
+
1833
+ def _addOptionalVideoFields(self, request_object: Dict[str, Any], requestVideo: IVideoInference) -> None:
1834
+ optional_fields = [
1835
+ "outputType", "outputFormat", "outputQuality", "uploadEndpoint",
1836
+ "includeCost", "negativePrompt", "inputAudios", "referenceVideos", "fps", "steps", "seed",
1837
+ "CFGScale", "seedImage", "duration", "width", "height", "nsfw_check",
1838
+ ]
1839
+
1840
+ for field in optional_fields:
1841
+ value = getattr(requestVideo, field, None)
1842
+ if value is not None:
1843
+ request_object[field] = value
1844
+
1845
+ def _addVideoImages(self, request_object: Dict[str, Any], requestVideo: IVideoInference) -> None:
1846
+ if requestVideo.frameImages:
1847
+ frame_images_data = []
1848
+ for frame_item in requestVideo.frameImages:
1849
+ frame_images_data.append({k: v for k, v in asdict(frame_item).items() if v is not None})
1850
+ request_object["frameImages"] = frame_images_data
1851
+
1852
+ if requestVideo.referenceImages:
1853
+ request_object["referenceImages"] = requestVideo.referenceImages
1854
+
1855
+ # Add lora if present
1856
+ if requestVideo.lora:
1857
+ request_object["lora"] = [
1858
+ {"model": lora.model, "weight": lora.weight}
1859
+ for lora in requestVideo.lora
1860
+ ]
1861
+
1862
def _buildImageRequest(
    self,
    requestImage: IImageInference,
    prompt: str,
    control_net_data_dicts: List[Dict],
    instant_id_data: Optional[Dict],
    ip_adapters_data: Optional[List[Dict]],
    ace_plus_plus_data: Optional[Dict],
    pulid_data: Optional[Dict],
) -> Dict[str, Any]:
    """
    Assemble the imageInference request payload.

    Starts from taskType/model/positivePrompt, then fills in, in order:
    optional scalar fields, special structured fields (controlNet, LoRA,
    ...), inputs, and provider settings.

    :param requestImage: The image request.
    :param prompt: The positive prompt to send.
    :return: The request dict to send.
    """
    payload: Dict[str, Any] = {
        "taskType": ETaskType.IMAGE_INFERENCE.value,
        "model": requestImage.model,
        "positivePrompt": prompt,
    }

    self._addOptionalImageFields(payload, requestImage)
    self._addImageSpecialFields(
        payload,
        requestImage,
        control_net_data_dicts,
        instant_id_data,
        ip_adapters_data,
        ace_plus_plus_data,
        pulid_data,
    )
    self._addImageInputs(payload, requestImage)
    self._addImageProviderSettings(payload, requestImage)

    return payload
1875
+
1876
+ def _addOptionalImageFields(self, request_object: Dict[str, Any], requestImage: IImageInference) -> None:
1877
+ optional_fields = [
1878
+ "outputType", "outputFormat", "outputQuality", "uploadEndpoint",
1879
+ "includeCost", "checkNsfw", "negativePrompt", "seedImage", "maskImage",
1880
+ "strength", "height", "width", "steps", "scheduler", "seed", "CFGScale",
1881
+ "clipSkip", "promptWeighting", "maskMargin", "vae", "webhookURL", "acceleration"
1882
+ ]
1883
+
1884
+ for field in optional_fields:
1885
+ value = getattr(requestImage, field, None)
1886
+ if value is not None:
1887
+ # Special handling for checkNsfw -> checkNSFW
1888
+ if field == "checkNsfw":
1889
+ request_object["checkNSFW"] = value
1890
+ else:
1891
+ request_object[field] = value
1892
+
1893
def _addImageSpecialFields(self, request_object: Dict[str, Any], requestImage: IImageInference, control_net_data_dicts: List[Dict], instant_id_data: Optional[Dict], ip_adapters_data: Optional[List[Dict]], ace_plus_plus_data: Optional[Dict], pulid_data: Optional[Dict]) -> None:
    """Attach the structured (non-scalar) image-inference options to the payload.

    Keys are written in a fixed order, skipping any whose source is falsy:
    ``controlNet``, ``lora``, ``lycoris``, ``embeddings``, ``refiner``,
    ``instantID``, ``outpaint``, ``ipAdapters``, ``acePlusPlus``, ``puLID``,
    ``referenceImages``. After those come the flattened ``acceleratorOptions``
    keys, then ``advancedFeatures``. Last of all, a dict ``extraArgs`` is merged
    into the top level, so it can overwrite anything written earlier.

    :param request_object: the payload dictionary being built (mutated in place).
    :param requestImage: the user-supplied image inference request.
    """

    def model_weight_pairs(adapters):
        # LoRA and LyCORIS entries are both reduced to their model id and weight.
        return [{"model": item.model, "weight": item.weight} for item in adapters]

    if control_net_data_dicts:
        request_object["controlNet"] = control_net_data_dicts

    if requestImage.lora:
        request_object["lora"] = model_weight_pairs(requestImage.lora)

    if requestImage.lycoris:
        request_object["lycoris"] = model_weight_pairs(requestImage.lycoris)

    if requestImage.embeddings:
        request_object["embeddings"] = [{"model": emb.model} for emb in requestImage.embeddings]

    refiner = requestImage.refiner
    if refiner:
        refiner_payload = {"model": refiner.model}
        for step_key in ("startStep", "startStepPercentage"):
            step_value = getattr(refiner, step_key)
            if step_value is not None:
                refiner_payload[step_key] = step_value
        request_object["refiner"] = refiner_payload

    if instant_id_data:
        request_object["instantID"] = instant_id_data

    outpaint = requestImage.outpaint
    if outpaint:
        request_object["outpaint"] = {key: val for key, val in vars(outpaint).items() if val is not None}

    for wire_key, prepared in (
        ("ipAdapters", ip_adapters_data),
        ("acePlusPlus", ace_plus_plus_data),
        ("puLID", pulid_data),
    ):
        if prepared:
            request_object[wire_key] = prepared

    if requestImage.referenceImages:
        request_object["referenceImages"] = requestImage.referenceImages

    # acceleratorOptions is flattened into the top level of the payload.
    self._addOptionalField(request_object, requestImage.acceleratorOptions)

    features = requestImage.advancedFeatures
    if features:
        # Each non-None feature object is serialized through its instance __dict__.
        request_object["advancedFeatures"] = {
            name: option.__dict__
            for name, option in vars(features).items()
            if option is not None
        }

    extra = getattr(requestImage, "extraArgs", None)
    if isinstance(extra, dict):
        request_object.update(extra)
+
1973
+ def _addImageInputs(self, request_object: Dict[str, Any], requestImage: IImageInference) -> None:
1974
+ # Add inputs if present
1975
+ if requestImage.inputs:
1976
+ inputs_dict = {
1977
+ k: v for k, v in asdict(requestImage.inputs).items()
1978
+ if v is not None
1979
+ }
1980
+ if inputs_dict:
1981
+ request_object["inputs"] = inputs_dict
1982
+
1983
+ def _addVideoInputs(self, request_object: Dict[str, Any], requestVideo: IVideoInference) -> None:
1984
+ # Add inputs if present
1985
+ if requestVideo.inputs:
1986
+ inputs_dict = {
1987
+ k: v for k, v in asdict(requestVideo.inputs).items()
1988
+ if v is not None
1989
+ }
1990
+ if inputs_dict:
1991
+ request_object["inputs"] = inputs_dict
1992
+
1993
+ def _addSafetySettings(self, request_object: Dict[str, Any], safety: ISafety) -> None:
1994
+ safety_dict = asdict(safety)
1995
+ safety_dict = {k: v for k, v in safety_dict.items() if v is not None}
1996
+ if safety_dict:
1997
+ request_object["safety"] = safety_dict
1998
+
1999
+ def _addImageProviderSettings(self, request_object: Dict[str, Any], requestImage: IImageInference) -> None:
2000
+ if not requestImage.providerSettings:
2001
+ return
2002
+ provider_dict = requestImage.providerSettings.to_request_dict()
2003
+ if provider_dict:
2004
+ request_object["providerSettings"] = provider_dict
2005
+
2006
+ def _addProviderSettings(self, request_object: Dict[str, Any], requestVideo: IVideoInference) -> None:
2007
+ if not requestVideo.providerSettings:
2008
+ return
2009
+ provider_dict = requestVideo.providerSettings.to_request_dict()
2010
+ if provider_dict:
2011
+ request_object["providerSettings"] = provider_dict
2012
+
2013
+ def _addOptionalField(self, request_object: Dict[str, Any], obj: Any) -> None:
2014
+ if not obj:
2015
+ return
2016
+ obj_dict = obj.to_request_dict()
2017
+ if obj_dict:
2018
+ request_object.update(obj_dict)
2019
+
2020
async def _handleWebhookAcknowledgment(
    self,
    task_uuid: str,
    task_type: str,
    debug_key: str,
) -> IAsyncTaskResponse:
    """Wait for the server to acknowledge a webhook-delivered task.

    The first message stored under ``task_uuid`` in ``self._globalMessages`` is
    inspected repeatedly (via ``getIntervalWithPromise``, bounded by
    ``WEBHOOK_TIMEOUT``):

    - a truthy ``code`` raises :class:`RunwareAPIError`;
    - a truthy ``error`` rejects the wait with that message;
    - a ``taskType`` equal to ``task_type`` removes the stored entry and
      resolves with ``createAsyncTaskResponse(message)``.

    The listener registered for ``task_uuid`` is always destroyed.

    :return: the async-task acknowledgment.
    :raises RunwareAPIError: when the acknowledgment carries an error code.
    """
    listener = self.globalListener(taskUUID=task_uuid)

    async def _ack_arrived(resolve: Callable, reject: Callable, *args: Any) -> bool:
        async with self._messages_lock:
            stored = self._globalMessages.get(task_uuid, [])
            if not stored:
                return False

            first = stored[0] if isinstance(stored, list) else stored

            if first.get("code"):
                raise RunwareAPIError(first)

            if isinstance(first, dict) and first.get("error"):
                reject(first)
                return True

            if first.get("taskType") != task_type:
                return False

            del self._globalMessages[task_uuid]
            resolve(createAsyncTaskResponse(first))
            return True

    try:
        result = await getIntervalWithPromise(
            _ack_arrived, debugKey=debug_key, timeOutDuration=WEBHOOK_TIMEOUT
        )
    finally:
        listener["destroy"]()

    if isinstance(result, dict) and "code" in result:
        raise RunwareAPIError(result)

    return result
+
2064
async def _handleInitialVideoResponse(self, task_uuid: str, number_results: int, webhook_url: Optional[str] = None) -> Union[List[IVideo], IAsyncTaskResponse]:
    """Wait for the first reply to a video task and decide how to finish it.

    The first message stored under ``task_uuid`` is examined (bounded by
    ``VIDEO_INITIAL_TIMEOUT``):

    - a truthy ``code`` raises :class:`RunwareAPIError`;
    - ``status == "success"``: the message is returned as a one-element
      ``IVideo`` list;
    - no ``imageUUID`` while ``webhook_url`` is set: the async-task
      acknowledgment is returned;
    - anything else: results are fetched with ``_pollVideoResults``.

    In every non-error case the stored messages for the task are removed. The
    listener for ``task_uuid`` is always destroyed.
    """
    listener = self.globalListener(taskUUID=task_uuid)

    async def _first_reply(resolve: Callable, reject: Callable, *args: Any) -> bool:
        async with self._messages_lock:
            stored = self._globalMessages.get(task_uuid, [])
            if not stored:
                return False

            head = stored[0]
            if head.get("code"):
                raise RunwareAPIError(head)

            # Every remaining branch consumes the stored messages for this task.
            del self._globalMessages[task_uuid]
            if head.get("status") == "success":
                resolve([head])
            # NOTE(review): this tests the *image* key on a video reply; a video
            # acknowledgment presumably carries videoUUID instead — confirm intent.
            elif not head.get("imageUUID") and webhook_url:
                resolve([createAsyncTaskResponse(head)])
            else:
                resolve("POLL_NEEDED")
            return True

    try:
        initial = await getIntervalWithPromise(
            _first_reply,
            debugKey="video-inference-initial",
            timeOutDuration=VIDEO_INITIAL_TIMEOUT,
        )
    finally:
        listener["destroy"]()

    if initial == "POLL_NEEDED":
        return await self._pollVideoResults(task_uuid, number_results)

    if initial and isinstance(initial[0], IAsyncTaskResponse):
        return initial[0]

    return instantiateDataclassList(IVideo, initial)
+
2112
async def _pollVideoResults(self, task_uuid: str, number_results: int, response_cls: type = IVideo) -> Union[List[IVideo], List[IVideoToText]]:
    """Poll the server until ``number_results`` successful results are available.

    Each round sends a ``getResponse`` request for ``task_uuid`` via
    ``_sendPollRequest``. It raises as soon as any returned message carries an
    error ``code``, and returns the first ``number_results`` ``"success"``
    messages instantiated as ``response_cls``.

    :param task_uuid: UUID of the task being polled.
    :param number_results: how many successful results to wait for.
    :param response_cls: the dataclass *class* used to build each result
        (``IVideo`` by default, e.g. ``IVideoToText`` for captioning).
    :return: a list of ``response_cls`` instances of length ``number_results``.
    :raises RunwareAPIError: on an API error code; on a poll reply that holds
        neither finished nor ``"processing"`` items; or after
        ``MAX_POLLS_VIDEO_GENERATION`` rounds without enough results.
    """
    # The default is annotated as a class (`type`). The previous annotation
    # `IVideo | IVideoToText` described an instance and evaluated `type | type`
    # at definition time, which raises TypeError on Python < 3.10.
    last_poll = MAX_POLLS_VIDEO_GENERATION - 1

    for poll_count in range(MAX_POLLS_VIDEO_GENERATION):
        try:
            responses = await self._sendPollRequest(task_uuid, poll_count)

            # Surface API errors before looking at any results.
            for response in responses:
                if response.get("code"):
                    raise RunwareAPIError(response)

            completed_results = self._processVideoPollingResponse(responses)

            if len(completed_results) >= number_results:
                return instantiateDataclassList(response_cls, completed_results[:number_results])

            if not self._hasPendingVideos(responses) and not completed_results:
                raise RunwareAPIError({"message": f"Unexpected polling response at poll {poll_count}"})

        except RunwareAPIError:
            raise
        except Exception:
            # Other failures (e.g. a single poll timing out) are retried;
            # only the final attempt propagates them.
            if poll_count >= last_poll:
                raise

        # No point waiting after the final attempt: the timeout error follows.
        if poll_count < last_poll:
            await asyncio.sleep(VIDEO_POLLING_DELAY / 1000)

    raise RunwareAPIError({"message": "Timed out"})
+
2144
async def _sendPollRequest(self, task_uuid: str, poll_count: int) -> List[Dict[str, Any]]:
    """Send one ``getResponse`` request for ``task_uuid`` and wait for the reply.

    Resolves with the whole list of messages stored under ``task_uuid`` and
    removes that entry from ``self._globalMessages``. The wait is bounded by
    ``VIDEO_INITIAL_TIMEOUT``, and the temporary listener is always destroyed,
    including when sending fails.

    :param task_uuid: UUID of the task to query.
    :param poll_count: index of this poll attempt (used only for the debug key).
    :return: the stored response messages for the task.
    """
    listener = self.globalListener(taskUUID=task_uuid)

    async def _drain(resolve: Callable, reject: Callable, *args: Any) -> bool:
        async with self._messages_lock:
            pending = self._globalMessages.get(task_uuid, [])
            if not pending:
                return False
            del self._globalMessages[task_uuid]
            resolve(pending)
            return True

    try:
        poll_message = {
            "taskType": ETaskType.GET_RESPONSE.value,
            "taskUUID": task_uuid,
        }
        await self.send([poll_message])
        return await getIntervalWithPromise(
            _drain,
            debugKey=f"video-poll-{poll_count}",
            timeOutDuration=VIDEO_INITIAL_TIMEOUT,
        )
    finally:
        listener["destroy"]()
+
2170
+ def _processVideoPollingResponse(self, responses: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
2171
+ completed_results = []
2172
+
2173
+ for response in responses:
2174
+ if response.get("code"):
2175
+ raise RunwareAPIError(response)
2176
+ status = response.get("status")
2177
+
2178
+ if status == "success":
2179
+ completed_results.append(response)
2180
+ return completed_results
2181
+
2182
+ def _hasPendingVideos(self, responses: List[Dict[str, Any]]) -> bool:
2183
+ return any(response.get("status") == "processing" for response in responses)
2184
+
2185
async def audioInference(self, requestAudio: IAudioInference) -> List[IAudio]:
    """Run an audio inference request, retrying through ``asyncRetry``.

    The WebSocket connection is ensured first. The request itself is issued by
    ``_requestAudio``, wrapped in a zero-argument callable that ``asyncRetry``
    may invoke more than once.

    :param requestAudio: the audio inference request.
    :return: the generated audio results.
    """
    await self.ensureConnection()
    results = await asyncRetry(lambda: self._requestAudio(requestAudio))
    return results
+
2189
async def _requestAudio(self, requestAudio: IAudioInference) -> List[IAudio]:
    """Send one audio inference request and wait for its results.

    A task UUID is generated and stored on ``requestAudio`` when it has none.

    :param requestAudio: the audio inference request (its ``taskUUID`` may be set).
    :return: the generated audio results.
    """
    if not requestAudio.taskUUID:
        requestAudio.taskUUID = getUUID()

    payload = self._buildAudioRequest(requestAudio)
    await self.send([payload])
    return await self._handleInitialAudioResponse(
        requestAudio.taskUUID, requestAudio.numberResults
    )
+
2195
def _buildAudioRequest(self, requestAudio: IAudioInference) -> Dict[str, Any]:
    """Assemble the ``audioInference`` task payload.

    The mandatory keys come first. ``positivePrompt`` (whitespace-stripped) and
    ``duration`` are added only when they are not ``None``. The optional scalar
    fields, the flattened ``audioSettings`` and ``providerSettings`` follow.

    :param requestAudio: the audio inference request.
    :return: the request dictionary ready to be sent.
    """
    payload: Dict[str, Any] = {
        "deliveryMethod": requestAudio.deliveryMethod,
        "taskType": ETaskType.AUDIO_INFERENCE.value,
        "taskUUID": requestAudio.taskUUID,
        "model": requestAudio.model,
        "numberResults": requestAudio.numberResults,
    }

    prompt = requestAudio.positivePrompt
    if prompt is not None:
        payload["positivePrompt"] = prompt.strip()

    duration = requestAudio.duration
    if duration is not None:
        payload["duration"] = duration

    self._addOptionalAudioFields(payload, requestAudio)
    self._addOptionalField(payload, requestAudio.audioSettings)
    self._addAudioProviderSettings(payload, requestAudio)
    return payload
+
2218
+ def _addOptionalAudioFields(self, request_object: Dict[str, Any], requestAudio: IAudioInference) -> None:
2219
+ optional_fields = [
2220
+ "outputType", "outputFormat", "includeCost", "uploadEndpoint", "webhookURL"
2221
+ ]
2222
+
2223
+ for field in optional_fields:
2224
+ value = getattr(requestAudio, field, None)
2225
+ if value is not None:
2226
+ request_object[field] = value
2227
+
2228
+
2229
+ def _addAudioProviderSettings(self, request_object: Dict[str, Any], requestAudio: IAudioInference) -> None:
2230
+ if not requestAudio.providerSettings:
2231
+ return
2232
+ provider_dict = requestAudio.providerSettings.to_request_dict()
2233
+ if provider_dict:
2234
+ request_object["providerSettings"] = provider_dict
2235
+
2236
+ async def _handleInitialAudioResponse(self, task_uuid: str, number_results: int) -> List[IAudio]:
2237
+ if number_results == 1:
2238
+ # Single result - wait for completion
2239
+ response = await self._waitForAudioCompletion(task_uuid)
2240
+ return [response] if response else []
2241
+ else:
2242
+ # Multiple results - use polling
2243
+ return await self._pollAudioResults(task_uuid, number_results)
2244
+
2245
async def _waitForAudioCompletion(self, task_uuid: str) -> Optional[IAudio]:
    """Wait for the first message stored under ``task_uuid`` and build an ``IAudio`` from it.

    The stored entry may be a list (its first element is used) or a single
    dict. A message with a truthy ``error`` rejects the wait. Any other message
    is removed from ``self._globalMessages`` and resolves it. The wait is
    bounded by ``AUDIO_INFERENCE_TIMEOUT``.

    :param task_uuid: UUID of the audio task.
    :return: the audio result, or ``None`` if the wait produced no message.
    :raises RunwareAPIError: when the resolved message carries a ``code``.
    """
    lis = self.globalListener(taskUUID=task_uuid)

    async def check(resolve: Callable, reject: Callable, *args: Any) -> bool:
        async with self._messages_lock:
            stored = self._globalMessages.get(task_uuid)
            audio_response = stored[0] if isinstance(stored, list) and stored else stored

            if not audio_response:
                return False

            if audio_response.get("error"):
                reject(audio_response)
                return True

            del self._globalMessages[task_uuid]
            resolve(audio_response)
            return True

    # The listener is destroyed exactly once. Previously it was destroyed after
    # the wait and then again in the `except` path whenever the post-processing
    # below raised (e.g. on a response carrying an error code).
    try:
        response = await getIntervalWithPromise(
            check, debugKey="audio-inference", timeOutDuration=AUDIO_INFERENCE_TIMEOUT
        )
    finally:
        lis["destroy"]()

    if not response:
        return None

    if "code" in response:
        raise RunwareAPIError(response)

    return self._createAudioFromResponse(response)
+
2281
async def _pollAudioResults(self, task_uuid: str, number_results: int) -> List[IAudio]:
    """Collect ``number_results`` finished audio clips for ``task_uuid``.

    Up to ``MAX_POLLS_AUDIO_GENERATION`` times, the messages stored under
    ``task_uuid`` in ``self._globalMessages`` are inspected, sleeping
    ``AUDIO_POLLING_DELAY`` milliseconds between rounds. Successful messages are
    gathered, each counted once. On exit (success or failure) the listener is
    destroyed and the stored messages for the task are discarded.

    :param task_uuid: UUID of the audio task.
    :param number_results: how many successful clips to collect.
    :return: up to ``number_results`` ``IAudio`` objects.
    :raises RunwareAPIError: on a message carrying an error code, or when not
        enough results arrive within the polling budget.
    """
    completed_results: List[Dict[str, Any]] = []
    # ids of messages already collected. They stay unique because
    # completed_results keeps every counted object alive.
    collected_ids = set()
    lis = self.globalListener(taskUUID=task_uuid)

    try:
        for poll_count in range(MAX_POLLS_AUDIO_GENERATION):
            async with self._messages_lock:
                stored = self._globalMessages.get(task_uuid, [])
                if isinstance(stored, list):
                    # Snapshot so later appends do not affect this round.
                    responses = list(stored)
                else:
                    responses = [stored] if stored else []

            # Stored messages are not drained between polls, so the same message
            # is seen on every round. Count each one only once; extending blindly
            # re-counted earlier successes and returned duplicate clips.
            for response in self._processAudioPollingResponse(responses):
                if id(response) not in collected_ids:
                    collected_ids.add(id(response))
                    completed_results.append(response)

            if len(completed_results) >= number_results:
                break

            if poll_count >= MAX_POLLS_AUDIO_GENERATION - 1:
                raise RunwareAPIError(
                    {"message": f"Audio generation timeout after {MAX_POLLS_AUDIO_GENERATION} polls"})

            await asyncio.sleep(AUDIO_POLLING_DELAY / 1000)

    finally:
        lis["destroy"]()
        async with self._messages_lock:
            if task_uuid in self._globalMessages:
                del self._globalMessages[task_uuid]

    return [self._createAudioFromResponse(response) for response in completed_results[:number_results]]
+
2312
+ def _processAudioPollingResponse(self, responses: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
2313
+ completed_results = []
2314
+
2315
+ for response in responses:
2316
+ if response.get("code"):
2317
+ raise RunwareAPIError(response)
2318
+ status = response.get("status")
2319
+ if status == "success":
2320
+ completed_results.append(response)
2321
+
2322
+ return completed_results
2323
+
2324
def _createAudioFromResponse(self, response: Dict[str, Any]) -> IAudio:
    """Build an ``IAudio`` from a raw audio message.

    ``taskType`` and ``taskUUID`` default to ``""`` when absent. Every other
    field defaults to ``None``.

    :param response: the raw audio message dictionary.
    :return: the corresponding ``IAudio`` instance.
    """
    nullable_fields = ("status", "audioUUID", "audioURL", "audioBase64Data", "audioDataURI", "cost")
    return IAudio(
        taskType=response.get("taskType", ""),
        taskUUID=response.get("taskUUID", ""),
        **{name: response.get(name) for name in nullable_fields},
    )
+
2336
def connected(self) -> bool:
    """
    Report whether the WebSocket is ready and a connection session has been established.

    :return: the (falsy) result of ``isWebsocketReadyState()`` when the socket is
        not ready; otherwise whether ``_connectionSessionUUID`` is set.
    """
    ready = self.isWebsocketReadyState()
    if not ready:
        return ready
    return self._connectionSessionUUID is not None