@weirdfingers/baseboards 0.6.2 → 0.7.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/index.js +54 -28
- package/dist/index.js.map +1 -1
- package/package.json +1 -1
- package/templates/README.md +2 -0
- package/templates/api/.env.example +3 -0
- package/templates/api/config/generators.yaml +58 -0
- package/templates/api/pyproject.toml +1 -1
- package/templates/api/src/boards/__init__.py +1 -1
- package/templates/api/src/boards/api/endpoints/storage.py +85 -4
- package/templates/api/src/boards/api/endpoints/uploads.py +1 -2
- package/templates/api/src/boards/database/connection.py +98 -58
- package/templates/api/src/boards/generators/implementations/fal/audio/__init__.py +4 -0
- package/templates/api/src/boards/generators/implementations/fal/audio/chatterbox_text_to_speech.py +176 -0
- package/templates/api/src/boards/generators/implementations/fal/audio/chatterbox_tts_turbo.py +195 -0
- package/templates/api/src/boards/generators/implementations/fal/image/__init__.py +14 -0
- package/templates/api/src/boards/generators/implementations/fal/image/bytedance_seedream_v45_edit.py +219 -0
- package/templates/api/src/boards/generators/implementations/fal/image/gemini_25_flash_image_edit.py +208 -0
- package/templates/api/src/boards/generators/implementations/fal/image/gpt_image_15_edit.py +216 -0
- package/templates/api/src/boards/generators/implementations/fal/image/gpt_image_1_5.py +177 -0
- package/templates/api/src/boards/generators/implementations/fal/image/reve_edit.py +178 -0
- package/templates/api/src/boards/generators/implementations/fal/image/reve_text_to_image.py +155 -0
- package/templates/api/src/boards/generators/implementations/fal/image/seedream_v45_text_to_image.py +180 -0
- package/templates/api/src/boards/generators/implementations/fal/video/__init__.py +18 -0
- package/templates/api/src/boards/generators/implementations/fal/video/kling_video_ai_avatar_v2_pro.py +168 -0
- package/templates/api/src/boards/generators/implementations/fal/video/kling_video_ai_avatar_v2_standard.py +159 -0
- package/templates/api/src/boards/generators/implementations/fal/video/veed_fabric_1_0.py +180 -0
- package/templates/api/src/boards/generators/implementations/fal/video/veo31.py +190 -0
- package/templates/api/src/boards/generators/implementations/fal/video/veo31_fast.py +190 -0
- package/templates/api/src/boards/generators/implementations/fal/video/veo31_fast_image_to_video.py +191 -0
- package/templates/api/src/boards/generators/implementations/fal/video/veo31_first_last_frame_to_video.py +13 -6
- package/templates/api/src/boards/generators/implementations/fal/video/wan_25_preview_image_to_video.py +212 -0
- package/templates/api/src/boards/generators/implementations/fal/video/wan_25_preview_text_to_video.py +208 -0
- package/templates/api/src/boards/generators/implementations/kie/__init__.py +11 -0
- package/templates/api/src/boards/generators/implementations/kie/base.py +316 -0
- package/templates/api/src/boards/generators/implementations/kie/image/__init__.py +3 -0
- package/templates/api/src/boards/generators/implementations/kie/image/nano_banana_edit.py +190 -0
- package/templates/api/src/boards/generators/implementations/kie/utils.py +98 -0
- package/templates/api/src/boards/generators/implementations/kie/video/__init__.py +8 -0
- package/templates/api/src/boards/generators/implementations/kie/video/veo3.py +161 -0
- package/templates/api/src/boards/graphql/resolvers/upload.py +1 -1
- package/templates/web/package.json +4 -1
- package/templates/web/src/app/boards/[boardId]/page.tsx +156 -24
- package/templates/web/src/app/globals.css +3 -0
- package/templates/web/src/app/layout.tsx +15 -5
- package/templates/web/src/components/boards/ArtifactInputSlots.tsx +9 -9
- package/templates/web/src/components/boards/ArtifactPreview.tsx +34 -18
- package/templates/web/src/components/boards/GenerationGrid.tsx +101 -7
- package/templates/web/src/components/boards/GenerationInput.tsx +21 -21
- package/templates/web/src/components/boards/GeneratorSelector.tsx +232 -30
- package/templates/web/src/components/boards/UploadArtifact.tsx +385 -75
- package/templates/web/src/components/header.tsx +3 -1
- package/templates/web/src/components/theme-provider.tsx +10 -0
- package/templates/web/src/components/theme-toggle.tsx +75 -0
- package/templates/web/src/components/ui/alert-dialog.tsx +157 -0
- package/templates/web/src/components/ui/toast.tsx +128 -0
- package/templates/web/src/components/ui/toaster.tsx +35 -0
- package/templates/web/src/components/ui/use-toast.ts +186 -0
|
@@ -0,0 +1,208 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Wan 2.5 Preview text-to-video generator.
|
|
3
|
+
|
|
4
|
+
A text-to-video generation model that converts text prompts into video content,
|
|
5
|
+
supporting both Chinese and English inputs up to 800 characters.
|
|
6
|
+
|
|
7
|
+
Based on Fal AI's fal-ai/wan-25-preview/text-to-video model.
|
|
8
|
+
See: https://fal.ai/models/fal-ai/wan-25-preview/text-to-video
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
import os
|
|
12
|
+
from typing import Literal
|
|
13
|
+
|
|
14
|
+
from pydantic import BaseModel, Field
|
|
15
|
+
|
|
16
|
+
from ....base import BaseGenerator, GeneratorExecutionContext, GeneratorResult
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class Wan25PreviewTextToVideoInput(BaseModel):
    """Input schema for Wan 2.5 Preview text-to-video generation.

    Every field maps one-to-one onto an argument of the fal.ai
    ``fal-ai/wan-25-preview/text-to-video`` request. This schema declares no
    artifact-typed fields: ``audio_url`` is a plain URL string that is passed
    through unchanged rather than resolved from a generation ID.
    """

    # Required prompt; the 1..800 character bound mirrors the limit advertised
    # for this model in the module docstring.
    prompt: str = Field(
        description="Text prompt for video generation. Supports Chinese and English.",
        min_length=1,
        max_length=800,
    )
    # Closed set of aspect ratios; the generator also uses this value to pick
    # fallback width/height when the API response omits dimensions.
    aspect_ratio: Literal["16:9", "9:16", "1:1"] = Field(
        default="16:9",
        description="Aspect ratio of the generated video",
    )
    # Closed set of resolutions; paired with aspect_ratio for fallback dimensions.
    resolution: Literal["480p", "720p", "1080p"] = Field(
        default="1080p",
        description="Resolution of the generated video",
    )
    # Only 5 or 10 seconds are accepted. The generator uses this value as the
    # stored duration when the response has none, and estimate_cost scales by it.
    duration: Literal[5, 10] = Field(
        default=5,
        description="Duration of the generated video in seconds",
    )
    # Optional; only sent to fal.ai when not None. The 3-30 s / 15MB limits in
    # the description are not validated here.
    audio_url: str | None = Field(
        default=None,
        description="URL of background audio (WAV/MP3). Must be 3-30 seconds and max 15MB. "
        "Audio longer than video is truncated; shorter audio results in silent sections.",
    )
    # Optional; only sent to fal.ai when not None.
    seed: int | None = Field(
        default=None,
        description="Random seed for reproducibility",
    )
    # Always sent to fal.ai.
    enable_safety_checker: bool = Field(
        default=True,
        description="Whether to enable safety filtering",
    )
    # Optional, capped at 500 characters; only sent to fal.ai when not None.
    negative_prompt: str | None = Field(
        default=None,
        description="Content to avoid in generation",
        max_length=500,
    )
    # Always sent to fal.ai.
    enable_prompt_expansion: bool = Field(
        default=True,
        description="Whether to expand short prompts for improved results. "
        "Increases processing time but improves quality for short prompts.",
    )
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class FalWan25PreviewTextToVideoGenerator(BaseGenerator):
    """Generator for text-to-video using Wan 2.5 Preview.

    Submits a job to the fal.ai ``fal-ai/wan-25-preview/text-to-video``
    endpoint, relays sampled log lines as progress updates, and stores the
    returned video as a single output artifact.
    """

    name = "fal-wan-25-preview-text-to-video"
    description = (
        "Fal: Wan 2.5 Preview - Text-to-video generation supporting "
        "Chinese/English prompts up to 800 characters"
    )
    artifact_type = "video"

    def get_input_schema(self) -> type[Wan25PreviewTextToVideoInput]:
        """Return the input schema for this generator."""
        return Wan25PreviewTextToVideoInput

    @staticmethod
    def _format_logs(logs: object) -> str:
        """Collapse a fal.ai event's ``logs`` payload into one progress message.

        fal_client reports log entries as dicts carrying a ``"message"`` key.
        For those entries only the message text is used, instead of the whole
        dict repr (level, timestamp, ...). Any other entry shape falls back to
        ``str()``. Empty entries are skipped. A non-list payload is
        stringified as-is.

        Args:
            logs: The ``logs`` attribute of a fal_client progress event.

        Returns:
            The entries joined with ``" | "``; may be an empty string.
        """
        if not isinstance(logs, list):
            return str(logs)

        parts: list[str] = []
        for entry in logs:
            if not entry:
                continue
            if isinstance(entry, dict) and "message" in entry:
                text = str(entry["message"])
            else:
                text = str(entry)
            if text:
                parts.append(text)
        return " | ".join(parts)

    async def generate(
        self, inputs: Wan25PreviewTextToVideoInput, context: GeneratorExecutionContext
    ) -> GeneratorResult:
        """Generate a video using the fal.ai Wan 2.5 Preview model.

        Args:
            inputs: Validated generation parameters.
            context: Execution context used to record the external job ID,
                publish progress updates, and store the resulting video.

        Returns:
            A ``GeneratorResult`` containing the single stored video artifact.

        Raises:
            ValueError: If ``FAL_KEY`` is not set, or the fal.ai response
                contains no video or no video URL.
            ImportError: If the ``fal_client`` SDK is not installed.
        """
        # Fail fast before touching the SDK or submitting any work.
        if not os.getenv("FAL_KEY"):
            raise ValueError("API configuration invalid. Missing FAL_KEY environment variable")

        # The SDK is an optional extra, so it is imported lazily with an install hint.
        try:
            import fal_client
        except ImportError as e:
            raise ImportError(
                "fal.ai SDK is required for FalWan25PreviewTextToVideoGenerator. "
                "Install with: pip install weirdfingers-boards[generators-fal]"
            ) from e

        # Always-sent arguments.
        arguments: dict = {
            "prompt": inputs.prompt,
            "aspect_ratio": inputs.aspect_ratio,
            "resolution": inputs.resolution,
            "duration": inputs.duration,
            "enable_safety_checker": inputs.enable_safety_checker,
            "enable_prompt_expansion": inputs.enable_prompt_expansion,
        }

        # Optional arguments are omitted entirely when unset, leaving the API defaults.
        if inputs.audio_url is not None:
            arguments["audio_url"] = inputs.audio_url
        if inputs.seed is not None:
            arguments["seed"] = inputs.seed
        if inputs.negative_prompt is not None:
            arguments["negative_prompt"] = inputs.negative_prompt

        # Submit async job
        handler = await fal_client.submit_async(
            "fal-ai/wan-25-preview/text-to-video",
            arguments=arguments,
        )

        # Store external job ID
        await context.set_external_job_id(handler.request_id)

        # Stream progress updates
        from .....progress.models import ProgressUpdate

        event_count = 0
        async for event in handler.iter_events(with_logs=True):
            event_count += 1
            # Sample every 3rd event to avoid flooding the progress channel.
            if event_count % 3 != 0:
                continue

            logs = getattr(event, "logs", None)
            if not logs:
                continue

            message = self._format_logs(logs)
            if message:
                await context.publish_progress(
                    ProgressUpdate(
                        job_id=handler.request_id,
                        status="processing",
                        progress=50.0,  # Approximate mid-point progress
                        phase="processing",
                        message=message,
                    )
                )

        # Get final result
        result = await handler.get()

        # Extract video from result
        video_data = result.get("video")
        if not video_data:
            raise ValueError("No video returned from fal.ai API")

        video_url = video_data.get("url")
        if not video_url:
            raise ValueError("Video missing URL in fal.ai response")

        # Prefer metadata reported by the API; fall back to the requested values.
        width = video_data.get("width")
        height = video_data.get("height")
        duration = video_data.get("duration")
        fps = video_data.get("fps")

        # If either dimension is missing, derive both from resolution + aspect ratio.
        if width is None or height is None:
            resolution_dimensions = {
                "480p": {"16:9": (854, 480), "9:16": (480, 854), "1:1": (480, 480)},
                "720p": {"16:9": (1280, 720), "9:16": (720, 1280), "1:1": (720, 720)},
                "1080p": {"16:9": (1920, 1080), "9:16": (1080, 1920), "1:1": (1080, 1080)},
            }
            width, height = resolution_dimensions.get(inputs.resolution, {}).get(
                inputs.aspect_ratio, (1920, 1080)
            )

        # Store video result
        artifact = await context.store_video_result(
            storage_url=video_url,
            format="mp4",
            width=width,
            height=height,
            duration=float(duration) if duration else float(inputs.duration),
            fps=fps,
            output_index=0,
        )

        return GeneratorResult(outputs=[artifact])

    async def estimate_cost(self, inputs: Wan25PreviewTextToVideoInput) -> float:
        """Estimate cost for Wan 2.5 Preview generation.

        Pricing information not provided in official documentation.
        Estimated at $0.10 per 5-second video based on typical video generation costs.
        Cost scales with duration.
        """
        # Base cost per 5-second video
        base_cost = 0.10
        # Scale by duration: 5s = 1x, 10s = 2x
        duration_multiplier = inputs.duration / 5
        return base_cost * duration_multiplier
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
"""Kie.ai generator implementations."""
|
|
2
|
+
|
|
3
|
+
from .image.nano_banana_edit import KieNanoBananaEditGenerator, NanoBananaEditInput
|
|
4
|
+
from .video.veo3 import KieVeo3Generator, KieVeo3Input
|
|
5
|
+
|
|
6
|
+
# Public API of the Kie.ai generator package: the generator classes imported
# above from the image and video subpackages, each paired with its input class.
__all__ = [
    "KieNanoBananaEditGenerator",
    "NanoBananaEditInput",
    "KieVeo3Generator",
    "KieVeo3Input",
]
|
|
@@ -0,0 +1,316 @@
|
|
|
1
|
+
"""Base classes for Kie.ai generators.
|
|
2
|
+
|
|
3
|
+
Provides common functionality shared across all Kie.ai generator implementations,
|
|
4
|
+
including API key validation, HTTP client setup, response validation, and polling logic.
|
|
5
|
+
|
|
6
|
+
Kie.ai supports two API patterns:
|
|
7
|
+
- Market API: Unified endpoint for 30+ models using /api/v1/jobs endpoints
|
|
8
|
+
- Dedicated API: Model-specific endpoints with custom paths
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
import asyncio
|
|
12
|
+
import os
|
|
13
|
+
from abc import abstractmethod
|
|
14
|
+
from typing import Any, ClassVar, Literal
|
|
15
|
+
|
|
16
|
+
import httpx
|
|
17
|
+
|
|
18
|
+
from ....progress.models import ProgressUpdate
|
|
19
|
+
from ...base import BaseGenerator, GeneratorExecutionContext
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class KieBaseGenerator(BaseGenerator):
    """Shared plumbing for every Kie.ai generator.

    Bundles API-key lookup, authenticated HTTP calls, and validation of the
    ``code``/``msg`` envelope that Kie.ai wraps around its responses.

    Subclasses must define:
        - api_pattern: Either "market" or "dedicated"
        - model_id: Model identifier (for market) or endpoint path (for dedicated)
    """

    # Set by concrete subclasses (or by the Market/Dedicated intermediate bases).
    api_pattern: ClassVar[Literal["market", "dedicated"]]
    model_id: str

    def _get_api_key(self) -> str:
        """Read KIE_API_KEY from the environment.

        Returns:
            The non-empty API key.

        Raises:
            ValueError: When the variable is missing or empty.
        """
        key = os.getenv("KIE_API_KEY")
        if key:
            return key
        raise ValueError("API configuration invalid. Missing KIE_API_KEY environment variable")

    def _validate_response(self, response: dict[str, Any]) -> None:
        """Reject any Kie.ai payload whose ``code`` field is not 200.

        Args:
            response: Decoded JSON body returned by Kie.ai.

        Raises:
            ValueError: Carrying the payload's ``msg`` (or "Unknown error").
        """
        if response.get("code") == 200:
            return
        reason = response.get("msg", "Unknown error")
        raise ValueError(f"Kie.ai API error: {reason}")

    async def _make_request(
        self,
        url: str,
        method: Literal["GET", "POST"],
        api_key: str,
        json: dict[str, Any] | None = None,
        timeout: float = 30.0,
    ) -> dict[str, Any]:
        """Send one authenticated request to Kie.ai and return the validated body.

        Anything other than "POST" is sent as a GET. The JSON body and the
        Content-Type header are only attached for POST.

        Args:
            url: Absolute endpoint URL.
            method: "POST" or "GET".
            api_key: Bearer token for the Authorization header.
            json: Request payload (POST only).
            timeout: Per-request timeout in seconds.

        Returns:
            The decoded JSON body, after envelope validation.

        Raises:
            ValueError: On a non-200 HTTP status, a non-JSON body, or a
                non-200 ``code`` in the envelope. Transport-level failures
                surface as ``httpx`` exceptions.
        """
        is_post = method == "POST"

        request_headers = {"Authorization": f"Bearer {api_key}"}
        if is_post:
            request_headers["Content-Type"] = "application/json"

        async with httpx.AsyncClient() as client:
            response = await client.request(
                "POST" if is_post else "GET",
                url,
                json=json if is_post else None,
                headers=request_headers,
                timeout=timeout,
            )

            if response.status_code != 200:
                raise ValueError(
                    f"Kie.ai API request failed: {response.status_code} {response.text}"
                )

            payload = response.json()
            self._validate_response(payload)
            return payload

    @abstractmethod
    async def _poll_for_completion(
        self,
        task_id: str,
        api_key: str,
        context: GeneratorExecutionContext,
    ) -> dict[str, Any]:
        """Block until the task finishes and return its data.

        Market and Dedicated subclasses each implement this against their own
        status endpoint and completion field.

        Args:
            task_id: Identifier returned when the task was submitted.
            api_key: Bearer token for status requests.
            context: Execution context used to publish progress.

        Returns:
            The finished task's data, including results.

        Raises:
            ValueError: If polling fails, the task fails, or polling times out.
        """
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
class KieMarketAPIGenerator(KieBaseGenerator):
    """Base class for Kie.ai Market API generators.

    Market API is used for 30+ models through a unified endpoint.
    - Submit: POST /api/v1/jobs/createTask with model parameter
    - Status: GET /api/v1/jobs/recordInfo?taskId={id}
    - Status field: "state" with values: "waiting", "pending", "processing", "success", "failed"
    """

    api_pattern: ClassVar[Literal["market"]] = "market"

    async def _poll_for_completion(
        self,
        task_id: str,
        api_key: str,
        context: GeneratorExecutionContext,
        max_polls: int = 120,
        poll_interval: int = 10,
    ) -> dict[str, Any]:
        """Poll Market API for task completion using the ``state`` field.

        Args:
            task_id: The task ID to poll
            api_key: API key for authorization
            context: Generator execution context for progress updates
            max_polls: Maximum number of polling attempts (default: 120 = 20 minutes)
            poll_interval: Seconds between polls (default: 10)

        Returns:
            The completed task data from the "data" field

        Raises:
            ValueError: If a status request fails, the task fails, the API
                reports an unrecognized state, or polling times out.
        """
        status_url = f"https://api.kie.ai/api/v1/jobs/recordInfo?taskId={task_id}"

        async with httpx.AsyncClient() as client:
            for poll_count in range(max_polls):
                # Don't sleep on first poll - check status immediately
                if poll_count > 0:
                    await asyncio.sleep(poll_interval)

                status_response = await client.get(
                    status_url,
                    headers={"Authorization": f"Bearer {api_key}"},
                    timeout=30.0,
                )

                if status_response.status_code != 200:
                    raise ValueError(
                        f"Status check failed: {status_response.status_code} {status_response.text}"
                    )

                status_result = status_response.json()
                self._validate_response(status_result)

                # `.get("data", {})` returns None when the key is present but null,
                # which would crash on `.get("state")`. Treat null like "no state yet".
                task_data = status_result.get("data") or {}
                state = task_data.get("state")

                if state == "success":
                    return task_data
                elif state == "failed":
                    error_msg = task_data.get("failMsg", "Unknown error")
                    raise ValueError(f"Generation failed: {error_msg}")
                elif state not in ["waiting", "pending", "processing", None]:
                    raise ValueError(
                        f"Unknown state '{state}' from Kie.ai API. Full response: {status_result}"
                    )

                # Elapsed-fraction progress, capped so it never claims completion.
                progress = min(90, (poll_count / max_polls) * 100)
                await context.publish_progress(
                    ProgressUpdate(
                        job_id=task_id,
                        status="processing",
                        progress=progress,
                        phase="processing",
                    )
                )
            else:
                timeout_minutes = (max_polls * poll_interval) / 60
                raise ValueError(f"Generation timed out after {timeout_minutes} minutes")
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
class KieDedicatedAPIGenerator(KieBaseGenerator):
    """Base class for Kie.ai Dedicated API generators.

    Dedicated APIs have model-specific endpoints with custom paths.
    - Submit: POST /api/v1/{model}/generate (no model parameter in body)
    - Status: GET /api/v1/{model}/record-info?taskId={id}
    - Status field: "successFlag" with values: 0 (processing), 1 (success), 2/3 (failed)
    """

    api_pattern: ClassVar[Literal["dedicated"]] = "dedicated"

    @abstractmethod
    def _get_status_url(self, task_id: str) -> str:
        """Get the status check URL for this specific dedicated API.

        Each dedicated API has its own status endpoint path.

        Args:
            task_id: The task ID to check status for

        Returns:
            Full URL for status checking
        """
        pass

    async def _poll_for_completion(
        self,
        task_id: str,
        api_key: str,
        context: GeneratorExecutionContext,
        max_polls: int = 180,
        poll_interval: int = 10,
    ) -> dict[str, Any]:
        """Poll Dedicated API for task completion using the ``successFlag`` field.

        Args:
            task_id: The task ID to poll
            api_key: API key for authorization
            context: Generator execution context for progress updates
            max_polls: Maximum number of polling attempts (default: 180 = 30 minutes)
            poll_interval: Seconds between polls (default: 10)

        Returns:
            The completed task data from the "data" field

        Raises:
            ValueError: If a status request fails, the task fails, or polling
                times out.
        """
        status_url = self._get_status_url(task_id)

        async with httpx.AsyncClient() as client:
            for poll_count in range(max_polls):
                # Don't sleep on first poll - check status immediately
                if poll_count > 0:
                    await asyncio.sleep(poll_interval)

                status_response = await client.get(
                    status_url,
                    headers={"Authorization": f"Bearer {api_key}"},
                    timeout=30.0,
                )

                if status_response.status_code != 200:
                    raise ValueError(
                        f"Status check failed: {status_response.status_code} {status_response.text}"
                    )

                status_result = status_response.json()
                self._validate_response(status_result)

                # `.get("data", {})` returns None when the key is present but null,
                # which would crash on `.get("successFlag")`. Treat null as "still running".
                task_data = status_result.get("data") or {}
                success_flag = task_data.get("successFlag")

                if success_flag == 1:
                    return task_data
                elif success_flag in [2, 3]:
                    error_msg = task_data.get("errorMsg", "Unknown error")
                    raise ValueError(f"Generation failed: {error_msg}")
                # NOTE(review): any other flag value (0, None, or unexpected) keeps
                # polling until timeout, unlike the Market poller, which rejects
                # unknown states — confirm against the Kie.ai docs.

                # Elapsed-fraction progress, capped so it never claims completion.
                progress = min(90, (poll_count / max_polls) * 100)
                await context.publish_progress(
                    ProgressUpdate(
                        job_id=task_id,
                        status="processing",
                        progress=progress,
                        phase="processing",
                    )
                )
            else:
                timeout_minutes = (max_polls * poll_interval) / 60
                raise ValueError(f"Generation timed out after {timeout_minutes} minutes")
|