nous-genai 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nous/__init__.py +3 -0
- nous/genai/__init__.py +56 -0
- nous/genai/__main__.py +3 -0
- nous/genai/_internal/__init__.py +1 -0
- nous/genai/_internal/capability_rules.py +476 -0
- nous/genai/_internal/config.py +102 -0
- nous/genai/_internal/errors.py +63 -0
- nous/genai/_internal/http.py +951 -0
- nous/genai/_internal/json_schema.py +54 -0
- nous/genai/cli.py +1316 -0
- nous/genai/client.py +719 -0
- nous/genai/mcp_cli.py +275 -0
- nous/genai/mcp_server.py +1080 -0
- nous/genai/providers/__init__.py +15 -0
- nous/genai/providers/aliyun.py +535 -0
- nous/genai/providers/anthropic.py +483 -0
- nous/genai/providers/gemini.py +1606 -0
- nous/genai/providers/openai.py +1909 -0
- nous/genai/providers/tuzi.py +1158 -0
- nous/genai/providers/volcengine.py +273 -0
- nous/genai/reference/__init__.py +17 -0
- nous/genai/reference/catalog.py +206 -0
- nous/genai/reference/mappings.py +467 -0
- nous/genai/reference/mode_overrides.py +26 -0
- nous/genai/reference/model_catalog.py +82 -0
- nous/genai/reference/model_catalog_data/__init__.py +1 -0
- nous/genai/reference/model_catalog_data/aliyun.py +98 -0
- nous/genai/reference/model_catalog_data/anthropic.py +10 -0
- nous/genai/reference/model_catalog_data/google.py +45 -0
- nous/genai/reference/model_catalog_data/openai.py +44 -0
- nous/genai/reference/model_catalog_data/tuzi_anthropic.py +21 -0
- nous/genai/reference/model_catalog_data/tuzi_google.py +19 -0
- nous/genai/reference/model_catalog_data/tuzi_openai.py +75 -0
- nous/genai/reference/model_catalog_data/tuzi_web.py +136 -0
- nous/genai/reference/model_catalog_data/volcengine.py +107 -0
- nous/genai/tools/__init__.py +13 -0
- nous/genai/tools/output_parser.py +119 -0
- nous/genai/types.py +416 -0
- nous/py.typed +1 -0
- nous_genai-0.1.0.dist-info/METADATA +200 -0
- nous_genai-0.1.0.dist-info/RECORD +45 -0
- nous_genai-0.1.0.dist-info/WHEEL +5 -0
- nous_genai-0.1.0.dist-info/entry_points.txt +4 -0
- nous_genai-0.1.0.dist-info/licenses/LICENSE +190 -0
- nous_genai-0.1.0.dist-info/top_level.txt +1 -0
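The bulk of the wheel is the provider adapters; the largest, `nous/genai/providers/openai.py`, is shown below. As orientation, here is a minimal usage sketch based only on the dataclass fields and method signatures visible in the diff; the API key and model id are placeholders, and `GenerateRequest` construction (defined in `nous/genai/types.py`) is not shown here:

```python
# Sketch only: assumes the import path mirrors the package layout listed above.
from nous.genai.providers.openai import OpenAIAdapter

adapter = OpenAIAdapter(
    api_key="sk-...",      # placeholder
    chat_api="responses",  # or the default "chat_completions"
)

caps = adapter.capabilities("gpt-4o")            # static inference, no network call
models = adapter.list_models(timeout_ms=10_000)  # GET {base_url}/models
```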
nous/genai/providers/openai.py
@@ -0,0 +1,1909 @@

```python
from __future__ import annotations

import base64
import json
import os
import tempfile
import time
import urllib.parse
from dataclasses import dataclass
from typing import Any, Iterator
from uuid import uuid4

from .._internal.capability_rules import (
    chat_input_modalities,
    chat_output_modalities,
    image_input_modalities,
    infer_model_kind,
    is_transcribe_model,
    output_modalities_for_kind,
    transcribe_input_modalities,
    video_input_modalities,
)
from .._internal.errors import (
    invalid_request_error,
    not_supported_error,
    provider_error,
)
from .._internal.http import (
    download_to_tempfile,
    multipart_form_data,
    multipart_form_data_fields,
    request_bytes,
    request_json,
    request_stream_json_sse,
    request_streaming_body_json,
)
from ..types import (
    Capability,
    GenerateEvent,
    GenerateRequest,
    GenerateResponse,
    JobInfo,
    Message,
    Part,
    PartSourceBytes,
    PartSourcePath,
    PartSourceRef,
    PartSourceUrl,
    Usage,
    bytes_to_base64,
    detect_mime_type,
    file_to_bytes,
    normalize_reasoning_effort,
    sniff_image_mime_type,
)


_OPENAI_DEFAULT_BASE_URL = "https://api.openai.com/v1"

_INLINE_BYTES_LIMIT = 20 * 1024 * 1024


def _is_mcp_transport() -> bool:
    value = os.environ.get("NOUS_GENAI_TRANSPORT", "").strip().lower()
    return value in {"mcp", "sse", "streamable", "streamable-http", "streamable_http"}


def _download_image_url_as_data_url(
    url: str,
    *,
    mime_type: str | None,
    timeout_ms: int | None,
    proxy_url: str | None,
) -> str:
    suffix = urllib.parse.urlparse(url).path
    ext = os.path.splitext(suffix)[1]
    tmp = download_to_tempfile(
        url=url,
        timeout_ms=timeout_ms,
        max_bytes=_INLINE_BYTES_LIMIT,
        suffix=ext if ext else "",
        proxy_url=proxy_url,
    )
    try:
        data = file_to_bytes(tmp, _INLINE_BYTES_LIMIT)
        if not mime_type:
            mime_type = detect_mime_type(tmp) or sniff_image_mime_type(data)
        if not mime_type:
            raise invalid_request_error(
                "could not infer image mime_type from url content"
            )
        b64 = bytes_to_base64(data)
        return f"data:{mime_type};base64,{b64}"
    finally:
        try:
            os.unlink(tmp)
        except OSError:
            pass


def _audio_format_from_mime(mime_type: str | None) -> str | None:
    if mime_type is None:
        return None
    mime_type = mime_type.lower()
    if mime_type in {"audio/wav", "audio/wave"}:
        return "wav"
    if mime_type in {"audio/mpeg", "audio/mp3"}:
        return "mp3"
    if mime_type in {"audio/mp4", "audio/m4a"}:
        return "m4a"
    return None


def _audio_mime_from_format(fmt: str) -> str:
    f = fmt.strip().lower()
    if f == "mp3":
        return "audio/mpeg"
    if f == "wav":
        return "audio/wav"
    if f in {"m4a", "mp4"}:
        return "audio/mp4"
    if f == "aac":
        return "audio/aac"
    if f == "flac":
        return "audio/flac"
    if f == "opus":
        return "audio/opus"
    if f == "pcm":
        return "audio/pcm"
    return f"audio/{f}" if f else "application/octet-stream"
```
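Note the asymmetry between the two audio helpers: `_audio_format_from_mime` recognizes only the three formats accepted for chat audio input (wav/mp3/m4a), while `_audio_mime_from_format` maps a wider set of output formats back to MIME types and falls back to `audio/<fmt>`. Illustrative values, derived from the branches above:

```python
assert _audio_format_from_mime("audio/mpeg") == "mp3"
assert _audio_format_from_mime("audio/flac") is None  # not a supported chat input format
assert _audio_mime_from_format("flac") == "audio/flac"
assert _audio_mime_from_format("") == "application/octet-stream"
```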
```python
def _download_to_temp(
    url: str, *, timeout_ms: int | None, max_bytes: int | None, proxy_url: str | None
) -> str:
    return download_to_tempfile(
        url=url, timeout_ms=timeout_ms, max_bytes=max_bytes, proxy_url=proxy_url
    )


def _part_to_chat_content(
    part: Part, *, timeout_ms: int | None, provider_name: str, proxy_url: str | None
) -> dict[str, Any]:
    if part.type == "text":
        return {"type": "text", "text": part.require_text()}
    if part.type == "image":
        source = part.require_source()
        mime_type = part.mime_type
        if mime_type is None and isinstance(source, PartSourcePath):
            mime_type = detect_mime_type(source.path)
        if isinstance(source, PartSourceUrl):
            if not _is_mcp_transport():
                return {"type": "image_url", "image_url": {"url": source.url}}
            data_url = _download_image_url_as_data_url(
                source.url,
                mime_type=mime_type,
                timeout_ms=timeout_ms,
                proxy_url=proxy_url,
            )
            return {"type": "image_url", "image_url": {"url": data_url}}
        if isinstance(source, PartSourceRef):
            raise not_supported_error(
                "openai does not support image ref in chat input; use url/bytes/path"
            )
        if isinstance(source, PartSourceBytes) and source.encoding == "base64":
            if not mime_type:
                raise invalid_request_error("image mime_type required for base64 input")
            b64 = source.data
            if not isinstance(b64, str) or not b64:
                raise invalid_request_error("image base64 data must be non-empty")
            return {
                "type": "image_url",
                "image_url": {"url": f"data:{mime_type};base64,{b64}"},
            }
        if isinstance(source, PartSourcePath):
            data = file_to_bytes(source.path, _INLINE_BYTES_LIMIT)
        else:
            assert isinstance(source, PartSourceBytes)
            raw = source.data
            if not isinstance(raw, bytes):
                raise invalid_request_error("image bytes data must be bytes")
            data = raw
            if len(data) > _INLINE_BYTES_LIMIT:
                raise not_supported_error(
                    f"inline bytes too large ({len(data)} > {_INLINE_BYTES_LIMIT})"
                )
        if not mime_type:
            raise invalid_request_error("image mime_type required for bytes/path input")
        b64 = bytes_to_base64(data)
        return {
            "type": "image_url",
            "image_url": {"url": f"data:{mime_type};base64,{b64}"},
        }
    if part.type == "audio":
        source = part.require_source()
        fmt = _audio_format_from_mime(part.mime_type)
        if fmt is None and isinstance(source, PartSourcePath):
            fmt = _audio_format_from_mime(detect_mime_type(source.path))
        if fmt is None:
            raise invalid_request_error(
                "audio format (wav/mp3/m4a) required via mime_type or extension"
            )
        if provider_name == "aliyun":
            if isinstance(source, PartSourceUrl):
                return {"type": "input_audio", "input_audio": {"data": source.url}}
            if isinstance(source, PartSourceRef):
                raise not_supported_error(
                    "aliyun does not support audio ref in chat input; use url/bytes/path"
                )
            if isinstance(source, PartSourceBytes) and source.encoding == "base64":
                mime_type = part.mime_type or _audio_mime_from_format(fmt)
                b64 = source.data
                if not isinstance(b64, str) or not b64:
                    raise invalid_request_error("audio base64 data must be non-empty")
                return {
                    "type": "input_audio",
                    "input_audio": {"data": f"data:{mime_type};base64,{b64}"},
                }
            if isinstance(source, PartSourcePath):
                data = file_to_bytes(source.path, _INLINE_BYTES_LIMIT)
                mime_type = (
                    detect_mime_type(source.path)
                    or part.mime_type
                    or _audio_mime_from_format(fmt)
                )
            else:
                assert isinstance(source, PartSourceBytes)
                raw = source.data
                if not isinstance(raw, bytes):
                    raise invalid_request_error("audio bytes data must be bytes")
                data = raw
                if len(data) > _INLINE_BYTES_LIMIT:
                    raise not_supported_error(
                        f"inline bytes too large ({len(data)} > {_INLINE_BYTES_LIMIT})"
                    )
                mime_type = part.mime_type or _audio_mime_from_format(fmt)
            b64 = bytes_to_base64(data)
            return {
                "type": "input_audio",
                "input_audio": {"data": f"data:{mime_type};base64,{b64}"},
            }

        if isinstance(source, PartSourceUrl):
            tmp = download_to_tempfile(
                url=source.url,
                timeout_ms=timeout_ms,
                max_bytes=_INLINE_BYTES_LIMIT,
                proxy_url=proxy_url,
            )
            try:
                data = file_to_bytes(tmp, _INLINE_BYTES_LIMIT)
            finally:
                try:
                    os.unlink(tmp)
                except OSError:
                    pass
        elif isinstance(source, PartSourceBytes) and source.encoding == "base64":
            b64 = source.data
            if not isinstance(b64, str) or not b64:
                raise invalid_request_error("audio base64 data must be non-empty")
            return {"type": "input_audio", "input_audio": {"data": b64, "format": fmt}}
        elif isinstance(source, PartSourcePath):
            data = file_to_bytes(source.path, _INLINE_BYTES_LIMIT)
        elif isinstance(source, PartSourceRef):
            raise not_supported_error(
                "openai does not support audio ref in chat input; use url/bytes/path"
            )
        else:
            assert isinstance(source, PartSourceBytes)
            raw = source.data
            if not isinstance(raw, bytes):
                raise invalid_request_error("audio bytes data must be bytes")
            data = raw
            if len(data) > _INLINE_BYTES_LIMIT:
                raise not_supported_error(
                    f"inline bytes too large ({len(data)} > {_INLINE_BYTES_LIMIT})"
                )
        return {
            "type": "input_audio",
            "input_audio": {"data": bytes_to_base64(data), "format": fmt},
        }
    if part.type in {"video", "embedding"}:
        raise not_supported_error(
            f"openai chat input does not support part type: {part.type}"
        )
    raise not_supported_error(f"unsupported part type: {part.type}")
```
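`_part_to_chat_content` normalizes every input part into the OpenAI chat-completions content shapes. For example (shapes taken from the function above, payloads shortened):

```python
# text part
{"type": "text", "text": "hello"}
# image part (bytes/path/base64 sources are inlined as a data URL)
{"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}
# audio part (non-aliyun providers get raw base64 plus an explicit format)
{"type": "input_audio", "input_audio": {"data": "...", "format": "mp3"}}
```

URL image sources are passed through untouched unless the process is running as an MCP transport (see `_is_mcp_transport`), in which case the adapter downloads and inlines them first.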
```python
def _tool_result_to_string(result: Any) -> str:
    if isinstance(result, str):
        return result
    return json.dumps(result, ensure_ascii=False, separators=(",", ":"))


def _tool_call_to_json_arguments(arguments: Any) -> str:
    if isinstance(arguments, str):
        return arguments
    return json.dumps(arguments, ensure_ascii=False, separators=(",", ":"))


def _parse_tool_call_arguments(value: Any) -> Any:
    if not isinstance(value, str):
        return value
    try:
        return json.loads(value)
    except Exception:
        return value


def _require_tool_call_meta(part: Part) -> tuple[str | None, str, Any]:
    tool_call_id = part.meta.get("tool_call_id")
    if tool_call_id is not None and not isinstance(tool_call_id, str):
        raise invalid_request_error("tool_call.meta.tool_call_id must be a string")
    name = part.meta.get("name")
    if not isinstance(name, str) or not name.strip():
        raise invalid_request_error("tool_call.meta.name must be a non-empty string")
    arguments = part.meta.get("arguments")
    return (tool_call_id, name.strip(), arguments)


def _require_tool_result_meta(part: Part) -> tuple[str | None, str, Any, bool | None]:
    tool_call_id = part.meta.get("tool_call_id")
    if tool_call_id is not None and not isinstance(tool_call_id, str):
        raise invalid_request_error("tool_result.meta.tool_call_id must be a string")
    name = part.meta.get("name")
    if not isinstance(name, str) or not name.strip():
        raise invalid_request_error("tool_result.meta.name must be a non-empty string")
    result = part.meta.get("result")
    is_error = part.meta.get("is_error")
    if is_error is not None and not isinstance(is_error, bool):
        raise invalid_request_error("tool_result.meta.is_error must be a bool")
    return (tool_call_id, name.strip(), result, is_error)


def _part_to_responses_image_content(
    part: Part, *, timeout_ms: int | None, proxy_url: str | None
) -> dict[str, Any]:
    if part.type != "image":
        raise not_supported_error(
            f"responses protocol does not support part type: {part.type}"
        )
    source = part.require_source()
    mime_type = part.mime_type
    if mime_type is None and isinstance(source, PartSourcePath):
        mime_type = detect_mime_type(source.path)
    if isinstance(source, PartSourceUrl):
        if not _is_mcp_transport():
            return {"type": "input_image", "image_url": source.url}
        data_url = _download_image_url_as_data_url(
            source.url,
            mime_type=mime_type,
            timeout_ms=timeout_ms,
            proxy_url=proxy_url,
        )
        return {"type": "input_image", "image_url": data_url}
    if isinstance(source, PartSourceRef):
        raise not_supported_error(
            "responses protocol does not support image ref in input; use url/bytes/path"
        )
    if isinstance(source, PartSourcePath):
        data = file_to_bytes(source.path, _INLINE_BYTES_LIMIT)
    else:
        assert isinstance(source, PartSourceBytes)
        if source.encoding == "base64":
            b64 = source.data
            if not isinstance(b64, str) or not b64:
                raise invalid_request_error("image base64 data must be non-empty")
            estimated_bytes = (len(b64) * 3) // 4
            if estimated_bytes > _INLINE_BYTES_LIMIT:
                raise not_supported_error(
                    f"inline bytes too large ({estimated_bytes} > {_INLINE_BYTES_LIMIT})"
                )
            if not mime_type:
                raise invalid_request_error(
                    "image mime_type required for base64 bytes/path input"
                )
            return {
                "type": "input_image",
                "image_url": f"data:{mime_type};base64,{b64}",
            }

        raw = source.data
        if not isinstance(raw, bytes):
            raise invalid_request_error("image bytes data must be bytes")
        data = raw
        if len(data) > _INLINE_BYTES_LIMIT:
            raise not_supported_error(
                f"inline bytes too large ({len(data)} > {_INLINE_BYTES_LIMIT})"
            )
    if not mime_type:
        raise invalid_request_error("image mime_type required for bytes/path input")
    b64 = bytes_to_base64(data)
    return {"type": "input_image", "image_url": f"data:{mime_type};base64,{b64}"}


def _gather_text_inputs(request: GenerateRequest) -> list[str]:
    texts: list[str] = []
    for message in request.input:
        for part in message.content:
            if part.type != "text":
                raise invalid_request_error("embedding requires text-only input")
            texts.append(part.require_text())
    if not texts:
        raise invalid_request_error("embedding requires at least one text part")
    return texts


def _usage_from_openai(obj: dict[str, Any]) -> Usage | None:
    usage = obj.get("usage")
    if not isinstance(usage, dict):
        return None
    return Usage(
        input_tokens=usage.get("prompt_tokens"),
        output_tokens=usage.get("completion_tokens"),
        total_tokens=usage.get("total_tokens"),
    )


def _usage_from_openai_responses(obj: dict[str, Any]) -> Usage | None:
    usage = obj.get("usage")
    if not isinstance(usage, dict):
        return None
    return Usage(
        input_tokens=usage.get("input_tokens"),
        output_tokens=usage.get("output_tokens"),
        total_tokens=usage.get("total_tokens"),
    )
```
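The two usage parsers differ only in field names: chat completions reports `prompt_tokens`/`completion_tokens`, while the Responses API reports `input_tokens`/`output_tokens`; both normalize to the SDK's `Usage`:

```python
u = _usage_from_openai(
    {"usage": {"prompt_tokens": 12, "completion_tokens": 34, "total_tokens": 46}}
)
# -> Usage(input_tokens=12, output_tokens=34, total_tokens=46)
```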
```python
@dataclass(frozen=True, slots=True)
class OpenAIAdapter:
    api_key: str
    base_url: str = _OPENAI_DEFAULT_BASE_URL
    provider_name: str = "openai"
    chat_api: str = "chat_completions"
    proxy_url: str | None = None

    def capabilities(self, model_id: str) -> Capability:
        kind = infer_model_kind(model_id)
        kind_out_mods = output_modalities_for_kind(kind)

        if kind == "video":
            return Capability(
                input_modalities=video_input_modalities(model_id),
                output_modalities=kind_out_mods or {"video"},
                supports_stream=False,
                supports_job=True,
                supports_tools=False,
                supports_json_schema=False,
            )
        if kind == "image":
            return Capability(
                input_modalities=image_input_modalities(model_id),
                output_modalities=kind_out_mods or {"image"},
                supports_stream=False,
                supports_job=False,
                supports_tools=False,
                supports_json_schema=False,
            )
        if kind == "embedding":
            return Capability(
                input_modalities={"text"},
                output_modalities=kind_out_mods or {"embedding"},
                supports_stream=False,
                supports_job=False,
                supports_tools=False,
                supports_json_schema=False,
            )
        if kind == "tts":
            return Capability(
                input_modalities={"text"},
                output_modalities=kind_out_mods or {"audio"},
                supports_stream=False,
                supports_job=False,
                supports_tools=False,
                supports_json_schema=False,
            )
        if kind == "transcribe":
            return Capability(
                input_modalities=transcribe_input_modalities(model_id),
                output_modalities=kind_out_mods or {"text"},
                supports_stream=False,
                supports_job=False,
                supports_tools=False,
                supports_json_schema=False,
            )
        assert kind == "chat"
        if self.chat_api == "responses":
            in_mods = chat_input_modalities(model_id) & {"text", "image"}
            return Capability(
                input_modalities=in_mods,
                output_modalities={"text"},
                supports_stream=True,
                supports_job=False,
                supports_tools=True,
                supports_json_schema=True,
            )
        in_mods = chat_input_modalities(model_id)
        out_mods = chat_output_modalities(model_id)
        return Capability(
            input_modalities=in_mods,
            output_modalities=out_mods,
            supports_stream=True,
            supports_job=False,
            supports_tools=True,
            supports_json_schema=True,
        )
```
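The only behavioral split in `capabilities` is the chat branch: in `responses` mode, input modalities are intersected with `{"text", "image"}` and output is forced to text-only, while the default chat-completions path passes the inferred modalities through unchanged. For instance (assuming `chat_input_modalities` reports audio support for the model):

```python
caps = OpenAIAdapter(api_key="sk-...", chat_api="responses").capabilities("gpt-4o")
# caps.output_modalities == {"text"}; any audio input modality is filtered out
```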
```python
    def list_models(self, *, timeout_ms: int | None = None) -> list[str]:
        """
        Fetch remote model ids via OpenAI-compatible GET /models.
        """
        url = f"{self.base_url.rstrip('/')}/models"
        obj = request_json(
            method="GET",
            url=url,
            headers=self._headers(),
            timeout_ms=timeout_ms,
            proxy_url=self.proxy_url,
        )
        data = obj.get("data")
        if not isinstance(data, list):
            return []
        out: list[str] = []
        for item in data:
            if not isinstance(item, dict):
                continue
            mid = item.get("id")
            if isinstance(mid, str) and mid:
                out.append(mid)
        return sorted(set(out))

    def generate(
        self, request: GenerateRequest, *, stream: bool
    ) -> GenerateResponse | Iterator[GenerateEvent]:
        model_id = request.model_id()
        modalities = set(request.output.modalities)
        if "embedding" in modalities:
            if modalities != {"embedding"}:
                raise not_supported_error(
                    "embedding cannot be combined with other output modalities"
                )
            if stream:
                raise not_supported_error("embedding does not support streaming")
            return self._embed(request, model_id=model_id)

        if modalities == {"video"}:
            if stream:
                raise not_supported_error(
                    "openai video generation does not support streaming"
                )
            return self._video(request, model_id=model_id)

        if modalities == {"image"}:
            if stream:
                raise not_supported_error(
                    "openai image generation does not support streaming"
                )
            return self._images(request, model_id=model_id)

        if modalities == {"audio"}:
            if stream:
                raise not_supported_error("openai TTS does not support streaming")
            return self._tts(request, model_id=model_id)

        if (
            modalities == {"text"}
            and self._is_transcribe_model(model_id)
            and self._has_audio_input(request)
        ):
            if stream:
                raise not_supported_error(
                    "openai transcription does not support streaming"
                )
            return self._transcribe(request, model_id=model_id)

        if self.chat_api == "responses":
            if stream:
                if "audio" in modalities:
                    raise not_supported_error(
                        "responses protocol does not support audio output in this SDK yet"
                    )
                return self._responses_stream(request, model_id=model_id)
            if "audio" in modalities:
                raise not_supported_error(
                    "responses protocol does not support audio output in this SDK yet"
                )
            return self._responses(request, model_id=model_id)

        if stream:
            if "audio" in modalities:
                raise not_supported_error(
                    "streaming audio output is not supported in this SDK yet"
                )
            return self._chat_stream(request, model_id=model_id)
        return self._chat(request, model_id=model_id)
```
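`generate` dispatches purely on the requested output modalities, with one extra check for transcription (a text-only request against a transcribe model with audio input). Streaming is only honored on the chat and responses paths; every other modality raises `not_supported_error` when `stream=True`. A sketch (`chat_request` is a hypothetical `GenerateRequest`):

```python
resp = adapter.generate(chat_request, stream=False)        # GenerateResponse
for event in adapter.generate(chat_request, stream=True):  # Iterator[GenerateEvent]
    ...
```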
```python
    def _headers(self, request: GenerateRequest | None = None) -> dict[str, str]:
        headers = {"Authorization": f"Bearer {self.api_key}"}
        if request and request.params.idempotency_key:
            headers["Idempotency-Key"] = request.params.idempotency_key
        return headers

    def _apply_provider_options(
        self, body: dict[str, Any], request: GenerateRequest
    ) -> None:
        opts = request.provider_options.get(self.provider_name)
        if not isinstance(opts, dict):
            return
        for k, v in opts.items():
            if k in body:
                raise invalid_request_error(
                    f"provider_options cannot override body.{k}"
                )
            body[k] = v

    def _apply_provider_options_form_fields(
        self, fields: dict[str, str], request: GenerateRequest
    ) -> None:
        opts = request.provider_options.get(self.provider_name)
        if not isinstance(opts, dict):
            return
        for k, v in opts.items():
            if v is None:
                continue
            if k in fields:
                raise invalid_request_error(
                    f"provider_options cannot override fields.{k}"
                )
            if isinstance(v, bool):
                fields[k] = "true" if v else "false"
            elif isinstance(v, (int, float, str)):
                fields[k] = str(v)
            else:
                fields[k] = json.dumps(v, separators=(",", ":"))

    def _is_transcribe_model(self, model_id: str) -> bool:
        return is_transcribe_model(model_id)

    def _has_audio_input(self, request: GenerateRequest) -> bool:
        for m in request.input:
            for p in m.content:
                if p.type == "audio":
                    return True
        return False

    def _text_max_output_tokens(self, request: GenerateRequest) -> int | None:
        spec = request.output.text
        if spec and spec.max_output_tokens is not None:
            return spec.max_output_tokens
        return request.params.max_output_tokens

    def _chat_response_format(self, request: GenerateRequest) -> dict[str, Any] | None:
        spec = request.output.text
        if spec is None:
            return None
        if spec.format == "text" and spec.json_schema is None:
            return None
        if set(request.output.modalities) != {"text"}:
            raise invalid_request_error("json output requires text-only modality")
        if spec.json_schema is not None:
            return {
                "type": "json_schema",
                "json_schema": {
                    "name": "output",
                    "schema": spec.json_schema,
                    "strict": True,
                },
            }
        return {"type": "json_object"}

    def _responses_text_format(self, request: GenerateRequest) -> dict[str, Any] | None:
        spec = request.output.text
        if spec is None:
            return None
        if spec.format == "text" and spec.json_schema is None:
            return None
        if set(request.output.modalities) != {"text"}:
            raise invalid_request_error("json output requires text-only modality")
        if spec.json_schema is not None:
            return {
                "type": "json_schema",
                "name": "output",
                "schema": spec.json_schema,
                "strict": True,
            }
        return {"type": "json_object"}

    def _chat_body(self, request: GenerateRequest, *, model_id: str) -> dict[str, Any]:
        messages: list[dict[str, Any]] = []
        for m in request.input:
            if m.role == "tool":
                tool_parts = [p for p in m.content if p.type == "tool_result"]
                if len(tool_parts) != 1 or len(m.content) != 1:
                    raise invalid_request_error(
                        "tool messages must contain exactly one tool_result part"
                    )
                tool_call_id, _, result, _ = _require_tool_result_meta(tool_parts[0])
                if not tool_call_id:
                    raise invalid_request_error(
                        "tool_result.meta.tool_call_id required for OpenAI tool messages"
                    )
                messages.append(
                    {
                        "role": "tool",
                        "tool_call_id": tool_call_id,
                        "content": _tool_result_to_string(result),
                    }
                )
                continue

            tool_calls: list[dict[str, Any]] = []
            content: list[dict[str, Any]] = []
            for p in m.content:
                if p.type == "tool_call":
                    if m.role != "assistant":
                        raise invalid_request_error(
                            "tool_call parts are only allowed in assistant messages"
                        )
                    tool_call_id, name, arguments = _require_tool_call_meta(p)
                    if not tool_call_id:
                        raise invalid_request_error(
                            "tool_call.meta.tool_call_id required for OpenAI tool calls"
                        )
                    tool_calls.append(
                        {
                            "id": tool_call_id,
                            "type": "function",
                            "function": {
                                "name": name,
                                "arguments": _tool_call_to_json_arguments(arguments),
                            },
                        }
                    )
                    continue
                if p.type == "tool_result":
                    raise invalid_request_error(
                        "tool_result parts must be sent as role='tool'"
                    )
                content.append(
                    _part_to_chat_content(
                        p,
                        timeout_ms=request.params.timeout_ms,
                        provider_name=self.provider_name,
                        proxy_url=self.proxy_url,
                    )
                )

            msg: dict[str, Any] = {
                "role": m.role,
                "content": content if content else None,
            }
            if tool_calls:
                msg["tool_calls"] = tool_calls
            messages.append(msg)

        body: dict[str, Any] = {"model": model_id, "messages": messages}
        params = request.params
        if params.temperature is not None:
            body["temperature"] = params.temperature
        if params.top_p is not None:
            body["top_p"] = params.top_p
        if params.seed is not None:
            body["seed"] = params.seed
        if params.reasoning is not None:
            if params.reasoning.effort is not None:
                body["reasoning_effort"] = normalize_reasoning_effort(
                    params.reasoning.effort
                )
        max_out = self._text_max_output_tokens(request)
        if max_out is not None:
            body["max_completion_tokens"] = max_out
        if params.stop is not None:
            body["stop"] = params.stop
        resp_fmt = self._chat_response_format(request)
        if resp_fmt is not None:
            body["response_format"] = resp_fmt

        if request.tools:
            tools: list[dict[str, Any]] = []
            for t in request.tools:
                name = t.name.strip()
                if not name:
                    raise invalid_request_error("tool.name must be non-empty")
                fn: dict[str, Any] = {"name": name}
                if isinstance(t.description, str) and t.description.strip():
                    fn["description"] = t.description.strip()
                if t.parameters is not None:
                    fn["parameters"] = t.parameters
                if t.strict is not None:
                    fn["strict"] = bool(t.strict)
                tools.append({"type": "function", "function": fn})
            body["tools"] = tools

        if request.tool_choice is not None:
            choice = request.tool_choice.normalized()
            if choice.mode in {"required", "tool"} and not request.tools:
                raise invalid_request_error("tool_choice requires request.tools")
            if choice.mode == "tool":
                body["tool_choice"] = {
                    "type": "function",
                    "function": {"name": choice.name},
                }
            else:
                body["tool_choice"] = choice.mode

        modalities = request.output.modalities
        if "audio" in modalities:
            audio = request.output.audio
            if audio is None or not audio.voice:
                raise invalid_request_error(
                    "output.audio.voice required for audio output"
                )
            fmt = audio.format or "wav"
            body["modalities"] = (
                ["audio"] if modalities == ["audio"] else ["text", "audio"]
            )
            body["audio"] = {"voice": audio.voice, "format": fmt}

        self._apply_provider_options(body, request)
        return body
```
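Assembled, a simple text request with a JSON schema and one tool yields a chat-completions body shaped like this (a sketch of what `_chat_body` builds; the schema and parameters are placeholders):

```python
{
    "model": "gpt-4o",
    "messages": [
        {"role": "user", "content": [{"type": "text", "text": "hi"}]},
    ],
    "max_completion_tokens": 256,
    "response_format": {
        "type": "json_schema",
        "json_schema": {"name": "output", "schema": {"type": "object"}, "strict": True},
    },
    "tools": [
        {"type": "function", "function": {"name": "lookup", "parameters": {"type": "object"}}},
    ],
}
```

Note that the SDK's `max_output_tokens` is translated to the newer `max_completion_tokens` field, and provider options are merged last and may not collide with keys the SDK already set.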
```python
    def _chat(self, request: GenerateRequest, *, model_id: str) -> GenerateResponse:
        url = f"{self.base_url}/chat/completions"
        obj = request_json(
            method="POST",
            url=url,
            headers=self._headers(request),
            json_body=self._chat_body(request, model_id=model_id),
            timeout_ms=request.params.timeout_ms,
            proxy_url=self.proxy_url,
        )
        return self._parse_chat_response(
            obj,
            provider=self.provider_name,
            model=f"{self.provider_name}:{model_id}",
            request=request,
        )

    def _responses(
        self, request: GenerateRequest, *, model_id: str
    ) -> GenerateResponse:
        url = f"{self.base_url}/responses"
        body: dict[str, Any] = {
            "model": model_id,
            "input": self._responses_input(request),
        }
        params = request.params
        if params.temperature is not None:
            body["temperature"] = params.temperature
        if params.top_p is not None:
            body["top_p"] = params.top_p
        if params.reasoning is not None:
            if params.reasoning.effort is not None:
                body["reasoning"] = {
                    "effort": normalize_reasoning_effort(params.reasoning.effort)
                }
        max_out = self._text_max_output_tokens(request)
        if max_out is not None:
            body["max_output_tokens"] = max_out
        text_fmt = self._responses_text_format(request)
        if text_fmt is not None:
            body["text"] = {"format": text_fmt}

        if request.tools:
            tools: list[dict[str, Any]] = []
            for t in request.tools:
                name = t.name.strip()
                if not name:
                    raise invalid_request_error("tool.name must be non-empty")
                tool_obj: dict[str, Any] = {"type": "function", "name": name}
                if isinstance(t.description, str) and t.description.strip():
                    tool_obj["description"] = t.description.strip()
                tool_obj["parameters"] = (
                    t.parameters if t.parameters is not None else {"type": "object"}
                )
                if t.strict is not None:
                    tool_obj["strict"] = bool(t.strict)
                tools.append(tool_obj)
            body["tools"] = tools

        if request.tool_choice is not None:
            choice = request.tool_choice.normalized()
            if choice.mode in {"required", "tool"} and not request.tools:
                raise invalid_request_error("tool_choice requires request.tools")
            if choice.mode == "tool":
                if self.provider_name.startswith("tuzi"):
                    req_tools = request.tools or []
                    if len(req_tools) == 1 and req_tools[0].name.strip() == choice.name:
                        body["tool_choice"] = "required"
                    else:
                        raise not_supported_error(
                            "tuzi responses protocol does not support tool_choice by name; "
                            "use tool_choice.mode='required' with a single tool"
                        )
                else:
                    body["tool_choice"] = {"type": "function", "name": choice.name}
            else:
                body["tool_choice"] = choice.mode

        self._apply_provider_options(body, request)

        obj = request_json(
            method="POST",
            url=url,
            headers=self._headers(request),
            json_body=body,
            timeout_ms=request.params.timeout_ms,
            proxy_url=self.proxy_url,
        )
        return self._parse_responses_response(
            obj, provider=self.provider_name, model=f"{self.provider_name}:{model_id}"
        )

    def _responses_stream(
        self, request: GenerateRequest, *, model_id: str
    ) -> Iterator[GenerateEvent]:
        url = f"{self.base_url}/responses"
        body: dict[str, Any] = {
            "model": model_id,
            "input": self._responses_input(request),
            "stream": True,
        }
        params = request.params
        if params.temperature is not None:
            body["temperature"] = params.temperature
        if params.top_p is not None:
            body["top_p"] = params.top_p
        if params.reasoning is not None:
            if params.reasoning.effort is not None:
                body["reasoning"] = {
                    "effort": normalize_reasoning_effort(params.reasoning.effort)
                }
        max_out = self._text_max_output_tokens(request)
        if max_out is not None:
            body["max_output_tokens"] = max_out
        text_fmt = self._responses_text_format(request)
        if text_fmt is not None:
            body["text"] = {"format": text_fmt}

        if request.tools:
            tools: list[dict[str, Any]] = []
            for t in request.tools:
                name = t.name.strip()
                if not name:
                    raise invalid_request_error("tool.name must be non-empty")
                tool_obj: dict[str, Any] = {"type": "function", "name": name}
                if isinstance(t.description, str) and t.description.strip():
                    tool_obj["description"] = t.description.strip()
                tool_obj["parameters"] = (
                    t.parameters if t.parameters is not None else {"type": "object"}
                )
                if t.strict is not None:
                    tool_obj["strict"] = bool(t.strict)
                tools.append(tool_obj)
            body["tools"] = tools

        if request.tool_choice is not None:
            choice = request.tool_choice.normalized()
            if choice.mode in {"required", "tool"} and not request.tools:
                raise invalid_request_error("tool_choice requires request.tools")
            if choice.mode == "tool":
                if self.provider_name.startswith("tuzi"):
                    req_tools = request.tools or []
                    if len(req_tools) == 1 and req_tools[0].name.strip() == choice.name:
                        body["tool_choice"] = "required"
                    else:
                        raise not_supported_error(
                            "tuzi responses protocol does not support tool_choice by name; "
                            "use tool_choice.mode='required' with a single tool"
                        )
                else:
                    body["tool_choice"] = {"type": "function", "name": choice.name}
            else:
                body["tool_choice"] = choice.mode

        self._apply_provider_options(body, request)

        events = request_stream_json_sse(
            method="POST",
            url=url,
            headers=self._headers(request),
            json_body=body,
            timeout_ms=request.params.timeout_ms,
            proxy_url=self.proxy_url,
        )

        def _iter() -> Iterator[GenerateEvent]:
            for obj in events:
                typ = obj.get("type")
                if typ == "response.output_text.delta":
                    delta = obj.get("delta")
                    if isinstance(delta, str) and delta:
                        yield GenerateEvent(
                            type="output.text.delta", data={"delta": delta}
                        )
                    continue
                if typ == "response.completed":
                    break
                if typ == "response.incomplete":
                    resp = obj.get("response")
                    reason: str | None = None
                    if isinstance(resp, dict):
                        details = resp.get("incomplete_details")
                        if isinstance(details, dict) and isinstance(
                            details.get("reason"), str
                        ):
                            reason = details["reason"]
                    msg = "responses returned status: incomplete"
                    if reason:
                        msg = f"{msg} ({reason})"
                    raise provider_error(msg, retryable=False)
                if typ == "response.failed":
                    resp = obj.get("response")
                    code: str | None = None
                    msg = "responses returned status: failed"
                    if isinstance(resp, dict):
                        err = resp.get("error")
                        if isinstance(err, dict):
                            err_code = err.get("code")
                            if isinstance(err_code, str) and err_code:
                                code = err_code
                            err_msg = err.get("message")
                            if isinstance(err_msg, str) and err_msg:
                                msg = err_msg
                    raise provider_error(
                        msg[:2_000], provider_code=code, retryable=False
                    )
                if typ == "error":
                    code = obj.get("code") if isinstance(obj.get("code"), str) else None
                    err_msg = obj.get("message")
                    msg = (
                        err_msg
                        if isinstance(err_msg, str) and err_msg
                        else "responses stream error"
                    )
                    raise provider_error(
                        msg[:2_000], provider_code=code, retryable=False
                    )
            yield GenerateEvent(type="done", data={})

        return _iter()
```
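On the happy path the responses stream yields only `output.text.delta` events followed by a single `done`; error-shaped SSE events (`response.incomplete`, `response.failed`, `error`) are raised as `provider_error` rather than yielded. A consumption sketch (assuming `GenerateEvent` exposes `type` and `data` attributes, matching how it is constructed above):

```python
for event in adapter.generate(chat_request, stream=True):
    if event.type == "output.text.delta":
        print(event.data["delta"], end="", flush=True)
    elif event.type == "done":
        break
```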
|
|
1043
|
+
|
|
1044
|
+
def _responses_input(self, request: GenerateRequest) -> list[dict[str, Any]]:
|
|
1045
|
+
items: list[dict[str, Any]] = []
|
|
1046
|
+
for m in request.input:
|
|
1047
|
+
if m.role == "tool":
|
|
1048
|
+
tool_parts = [p for p in m.content if p.type == "tool_result"]
|
|
1049
|
+
if len(tool_parts) != 1 or len(m.content) != 1:
|
|
1050
|
+
raise invalid_request_error(
|
|
1051
|
+
"tool messages must contain exactly one tool_result part"
|
|
1052
|
+
)
|
|
1053
|
+
tool_call_id, _, result, _ = _require_tool_result_meta(tool_parts[0])
|
|
1054
|
+
if not tool_call_id:
|
|
1055
|
+
raise invalid_request_error(
|
|
1056
|
+
"tool_result.meta.tool_call_id required for responses protocol"
|
|
1057
|
+
)
|
|
1058
|
+
items.append(
|
|
1059
|
+
{
|
|
1060
|
+
"type": "function_call_output",
|
|
1061
|
+
"call_id": tool_call_id,
|
|
1062
|
+
"output": _tool_result_to_string(result),
|
|
1063
|
+
}
|
|
1064
|
+
)
|
|
1065
|
+
continue
|
|
1066
|
+
role = m.role
|
|
1067
|
+
content: list[dict[str, Any]] = []
|
|
1068
|
+
for p in m.content:
|
|
1069
|
+
if p.type == "text":
|
|
1070
|
+
content.append({"type": "input_text", "text": p.require_text()})
|
|
1071
|
+
continue
|
|
1072
|
+
if p.type == "image":
|
|
1073
|
+
content.append(
|
|
1074
|
+
_part_to_responses_image_content(
|
|
1075
|
+
p,
|
|
1076
|
+
timeout_ms=request.params.timeout_ms,
|
|
1077
|
+
proxy_url=self.proxy_url,
|
|
1078
|
+
)
|
|
1079
|
+
)
|
|
1080
|
+
continue
|
|
1081
|
+
if p.type in {"tool_call", "tool_result"}:
|
|
1082
|
+
raise not_supported_error(
|
|
1083
|
+
"responses protocol does not support tool parts in message input"
|
|
1084
|
+
)
|
|
1085
|
+
raise not_supported_error(
|
|
1086
|
+
f"responses protocol does not support input part: {p.type}"
|
|
1087
|
+
)
|
|
1088
|
+
items.append({"role": role, "content": content})
|
|
1089
|
+
return items
|
|
1090
|
+
|
|
1091
|
+
def _parse_responses_response(
|
|
1092
|
+
self, obj: dict[str, Any], *, provider: str, model: str
|
|
1093
|
+
) -> GenerateResponse:
|
|
1094
|
+
resp_id = obj.get("id") or f"sdk_{uuid4().hex}"
|
|
1095
|
+
status = obj.get("status")
|
|
1096
|
+
if status != "completed":
|
|
1097
|
+
raise provider_error(f"responses returned status: {status}")
|
|
1098
|
+
output = obj.get("output")
|
|
1099
|
+
if not isinstance(output, list):
|
|
1100
|
+
raise provider_error("responses missing output")
|
|
1101
|
+
|
|
1102
|
+
parts: list[Part] = []
|
|
1103
|
+
for item in output:
|
|
1104
|
+
if not isinstance(item, dict):
|
|
1105
|
+
continue
|
|
1106
|
+
typ = item.get("type")
|
|
1107
|
+
if typ == "message":
|
|
1108
|
+
content = item.get("content")
|
|
1109
|
+
if not isinstance(content, list):
|
|
1110
|
+
continue
|
|
1111
|
+
for c in content:
|
|
1112
|
+
if not isinstance(c, dict):
|
|
1113
|
+
continue
|
|
1114
|
+
if c.get("type") == "output_text" and isinstance(
|
|
1115
|
+
c.get("text"), str
|
|
1116
|
+
):
|
|
1117
|
+
parts.append(Part.from_text(c["text"]))
|
|
1118
|
+
continue
|
|
1119
|
+
if typ == "function_call":
|
|
1120
|
+
call_id = item.get("call_id")
|
|
1121
|
+
name = item.get("name")
|
|
1122
|
+
arguments = item.get("arguments")
|
|
1123
|
+
if (
|
|
1124
|
+
isinstance(call_id, str)
|
|
1125
|
+
and call_id
|
|
1126
|
+
and isinstance(name, str)
|
|
1127
|
+
and name
|
|
1128
|
+
):
|
|
1129
|
+
parts.append(
|
|
1130
|
+
Part.tool_call(
|
|
1131
|
+
tool_call_id=call_id,
|
|
1132
|
+
name=name,
|
|
1133
|
+
arguments=_parse_tool_call_arguments(arguments),
|
|
1134
|
+
)
|
|
1135
|
+
)
|
|
1136
|
+
|
|
1137
|
+
if not parts:
|
|
1138
|
+
parts.append(Part.from_text(""))
|
|
1139
|
+
usage = _usage_from_openai_responses(obj)
|
|
1140
|
+
return GenerateResponse(
|
|
1141
|
+
id=str(resp_id),
|
|
1142
|
+
provider=provider,
|
|
1143
|
+
model=model,
|
|
1144
|
+
status="completed",
|
|
1145
|
+
output=[Message(role="assistant", content=parts)],
|
|
1146
|
+
usage=usage,
|
|
1147
|
+
)
|
|
1148
|
+
|
|
1149
|
+
def _chat_stream(
|
|
1150
|
+
self, request: GenerateRequest, *, model_id: str
|
|
1151
|
+
) -> Iterator[GenerateEvent]:
|
|
1152
|
+
url = f"{self.base_url}/chat/completions"
|
|
1153
|
+
body = self._chat_body(request, model_id=model_id)
|
|
1154
|
+
body["stream"] = True
|
|
1155
|
+
events = request_stream_json_sse(
|
|
1156
|
+
method="POST",
|
|
1157
|
+
url=url,
|
|
1158
|
+
headers=self._headers(request),
|
|
1159
|
+
json_body=body,
|
|
1160
|
+
timeout_ms=request.params.timeout_ms,
|
|
1161
|
+
proxy_url=self.proxy_url,
|
|
1162
|
+
)
|
|
1163
|
+
|
|
1164
|
+
def _iter() -> Iterator[GenerateEvent]:
|
|
1165
|
+
for obj in events:
|
|
1166
|
+
choices = obj.get("choices")
|
|
1167
|
+
if not isinstance(choices, list) or not choices:
|
|
1168
|
+
continue
|
|
1169
|
+
delta = choices[0].get("delta")
|
|
1170
|
+
if not isinstance(delta, dict):
|
|
1171
|
+
continue
|
|
1172
|
+
text = delta.get("content")
|
|
1173
|
+
if isinstance(text, str) and text:
|
|
1174
|
+
yield GenerateEvent(type="output.text.delta", data={"delta": text})
|
|
1175
|
+
yield GenerateEvent(type="done", data={})
|
|
1176
|
+
|
|
1177
|
+
return _iter()
|
|
1178
|
+
|
|
1179
|
+
def _parse_chat_response(
|
|
1180
|
+
self,
|
|
1181
|
+
obj: dict[str, Any],
|
|
1182
|
+
*,
|
|
1183
|
+
provider: str,
|
|
1184
|
+
model: str,
|
|
1185
|
+
request: GenerateRequest | None = None,
|
|
1186
|
+
) -> GenerateResponse:
|
|
1187
|
+
resp_id = obj.get("id") or f"sdk_{uuid4().hex}"
|
|
1188
|
+
choices = obj.get("choices")
|
|
1189
|
+
if not isinstance(choices, list) or not choices:
|
|
1190
|
+
raise provider_error("openai chat response missing choices")
|
|
1191
|
+
msg = choices[0].get("message")
|
|
1192
|
+
if not isinstance(msg, dict):
|
|
1193
|
+
raise provider_error("openai chat response missing message")
|
|
1194
|
+
|
|
1195
|
+
parts: list[Part] = []
|
|
1196
|
+
content_text = msg.get("content")
|
|
1197
|
+
if isinstance(content_text, str) and content_text:
|
|
1198
|
+
parts.append(Part.from_text(content_text))
|
|
1199
|
+
|
|
1200
|
+
audio = msg.get("audio")
|
|
1201
|
+
if isinstance(audio, dict):
|
|
1202
|
+
data_b64 = audio.get("data")
|
|
1203
|
+
if isinstance(data_b64, str) and data_b64:
|
|
1204
|
+
fmt = None
|
|
1205
|
+
if request and request.output.audio and request.output.audio.format:
|
|
1206
|
+
fmt = request.output.audio.format
|
|
1207
|
+
mime = _audio_mime_from_format(fmt or "wav")
|
|
1208
|
+
parts.append(
|
|
1209
|
+
Part(
|
|
1210
|
+
type="audio",
|
|
1211
|
+
mime_type=mime,
|
|
1212
|
+
source=PartSourceBytes(data=data_b64, encoding="base64"),
|
|
1213
|
+
)
|
|
1214
|
+
)
|
|
1215
|
+
transcript = audio.get("transcript")
|
|
1216
|
+
if (
|
|
1217
|
+
isinstance(transcript, str)
|
|
1218
|
+
and transcript
|
|
1219
|
+
and not (isinstance(content_text, str) and content_text)
|
|
1220
|
+
):
|
|
1221
|
+
parts.append(Part.from_text(transcript))
|
|
1222
|
+
|
|
1223
|
+
tool_calls = msg.get("tool_calls")
|
|
1224
|
+
if isinstance(tool_calls, list):
|
|
1225
|
+
for call in tool_calls:
|
|
1226
|
+
if not isinstance(call, dict):
|
|
1227
|
+
continue
|
|
1228
|
+
tool_call_id = call.get("id")
|
|
1229
|
+
if not isinstance(tool_call_id, str) or not tool_call_id:
|
|
1230
|
+
continue
|
|
1231
|
+
fn = call.get("function")
|
|
1232
|
+
if not isinstance(fn, dict):
|
|
1233
|
+
continue
|
|
1234
|
+
name = fn.get("name")
|
|
1235
|
+
args = fn.get("arguments")
|
|
1236
|
+
if not isinstance(name, str) or not name:
|
|
1237
|
+
continue
|
|
1238
|
+
parts.append(
|
|
1239
|
+
Part.tool_call(
|
|
1240
|
+
tool_call_id=tool_call_id,
|
|
1241
|
+
name=name,
|
|
1242
|
+
arguments=_parse_tool_call_arguments(args),
|
|
1243
|
+
)
|
|
1244
|
+
)
|
|
1245
|
+
|
|
1246
|
+
if not parts:
|
|
1247
|
+
parts.append(Part.from_text(""))
|
|
1248
|
+
|
|
1249
|
+
usage = _usage_from_openai(obj)
|
|
1250
|
+
return GenerateResponse(
|
|
1251
|
+
id=str(resp_id),
|
|
1252
|
+
provider=provider,
|
|
1253
|
+
model=model,
|
|
1254
|
+
status="completed",
|
|
1255
|
+
output=[Message(role="assistant", content=parts)],
|
|
1256
|
+
usage=usage,
|
|
1257
|
+
)
|
|
1258
|
+
|
|
1259
|
+
def _images(self, request: GenerateRequest, *, model_id: str) -> GenerateResponse:
|
|
1260
|
+
if self.provider_name == "openai" and not (
|
|
1261
|
+
model_id.startswith("dall-e-")
|
|
1262
|
+
or model_id.startswith("gpt-image-")
|
|
1263
|
+
or model_id.startswith("chatgpt-image")
|
|
1264
|
+
):
|
|
1265
|
+
raise not_supported_error(
|
|
1266
|
+
f'image generation requires model like "{self.provider_name}:gpt-image-1"'
|
|
1267
|
+
)
|
|
1268
|
+
|
|
1269
|
+
texts: list[str] = []
|
|
1270
|
+
images: list[Part] = []
|
|
1271
|
+
for m in request.input:
|
|
1272
|
+
for p in m.content:
|
|
1273
|
+
if p.type == "text":
|
|
1274
|
+
t = p.require_text().strip()
|
|
1275
|
+
if t:
|
|
1276
|
+
texts.append(t)
|
|
1277
|
+
continue
|
|
1278
|
+
if p.type == "image":
|
|
1279
|
+
images.append(p)
|
|
1280
|
+
continue
|
|
1281
|
+
raise invalid_request_error(
|
|
1282
|
+
"image generation only supports text (+ optional image)"
|
|
1283
|
+
)
|
|
1284
|
+
if len(texts) != 1:
|
|
1285
|
+
raise invalid_request_error(
|
|
1286
|
+
"image generation requires exactly one text part"
|
|
1287
|
+
)
|
|
1288
|
+
if len(images) > 1:
|
|
1289
|
+
            raise invalid_request_error(
                "image generation supports at most one image input"
            )

        prompt = texts[0]
        image_part = images[0] if images else None

        response_format: str | None = None
        img = request.output.image
        if img and img.n is not None:
            n = img.n
        else:
            n = None
        if img and img.size is not None:
            size = img.size
        else:
            size = None
        if model_id.startswith("dall-e-"):
            response_format = "url"
        if img and img.format:
            fmt = img.format.strip().lower()
            if fmt in {"url"}:
                response_format = "url"
            elif fmt in {"b64_json", "base64", "bytes"}:
                response_format = "b64_json"

        if image_part is None:
            body: dict[str, Any] = {"model": model_id, "prompt": prompt}
            if n is not None:
                body["n"] = n
            if size is not None:
                body["size"] = size
            if response_format:
                body["response_format"] = response_format
            self._apply_provider_options(body, request)
            url = f"{self.base_url}/images/generations"
            obj = request_json(
                method="POST",
                url=url,
                headers=self._headers(request),
                json_body=body,
                timeout_ms=request.params.timeout_ms,
                proxy_url=self.proxy_url,
            )
        else:
            src = image_part.require_source()
            tmp_path: str | None = None
            if isinstance(src, PartSourceUrl):
                tmp_path = download_to_tempfile(
                    url=src.url,
                    timeout_ms=request.params.timeout_ms,
                    max_bytes=_INLINE_BYTES_LIMIT,
                    proxy_url=self.proxy_url,
                )
                file_path = tmp_path
            elif isinstance(src, PartSourcePath):
                file_path = src.path
            elif isinstance(src, PartSourceBytes) and src.encoding == "base64":
                try:
                    decoded = base64.b64decode(src.data)
                except Exception:
                    raise invalid_request_error("image base64 data is not valid base64")
                with tempfile.NamedTemporaryFile(
                    prefix="genaisdk-", suffix=".bin", delete=False
                ) as f:
                    f.write(decoded)
                    tmp_path = f.name
                file_path = tmp_path
            elif isinstance(src, PartSourceRef):
                raise not_supported_error(
                    "image edits do not support ref input; use url/bytes/path"
                )
            else:
                assert isinstance(src, PartSourceBytes)
                if not isinstance(src.data, bytes):
                    raise invalid_request_error("image bytes data must be bytes")
                with tempfile.NamedTemporaryFile(
                    prefix="genaisdk-", suffix=".bin", delete=False
                ) as f:
                    f.write(src.data)
                    tmp_path = f.name
                file_path = tmp_path

            try:
                fields: dict[str, str] = {"model": model_id, "prompt": prompt}
                if n is not None:
                    fields["n"] = str(n)
                if size is not None:
                    fields["size"] = str(size)
                if response_format:
                    fields["response_format"] = response_format
                self._apply_provider_options_form_fields(fields, request)
                streaming_body = multipart_form_data(
                    fields=fields,
                    file_field="image",
                    file_path=file_path,
                    filename=os.path.basename(file_path),
                    file_mime_type=image_part.mime_type
                    or detect_mime_type(file_path)
                    or "application/octet-stream",
                )
                obj = request_streaming_body_json(
                    method="POST",
                    url=f"{self.base_url}/images/edits",
                    headers=self._headers(request),
                    body=streaming_body,
                    timeout_ms=request.params.timeout_ms,
                    proxy_url=self.proxy_url,
                )
            finally:
                if tmp_path is not None:
                    try:
                        os.unlink(tmp_path)
                    except OSError:
                        pass
        resp_id = obj.get("created") or f"sdk_{uuid4().hex}"
        data_field = obj.get("data")
        items: list[object] | None = None
        if isinstance(data_field, list):
            items = data_field
        elif isinstance(data_field, dict):
            images_data = data_field.get("images")
            if isinstance(images_data, list):
                items = images_data
            else:
                inner = data_field.get("data")
                if isinstance(inner, list):
                    items = inner
        if not items:
            raise provider_error("openai images response missing data")
        parts = []
        for item in items:
            if not isinstance(item, dict):
                continue
            u = item.get("url")
            if isinstance(u, str) and u:
                parts.append(Part(type="image", source=PartSourceUrl(url=u)))
                continue
            b64 = item.get("b64_json")
            if isinstance(b64, str) and b64:
                parts.append(
                    Part(
                        type="image",
                        mime_type="image/png",
                        source=PartSourceBytes(data=b64, encoding="base64"),
                    )
                )
        if not parts:
            raise provider_error("openai images response missing urls")
        return GenerateResponse(
            id=str(resp_id),
            provider=self.provider_name,
            model=f"{self.provider_name}:{model_id}",
            status="completed",
            output=[Message(role="assistant", content=parts)],
            usage=None,
        )

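    # Illustrative only: the two payload shapes the image parser above accepts.
    # A url-format response looks roughly like
    #     {"created": 1700000000, "data": [{"url": "https://..."}]}
    # and a base64-format response looks roughly like
    #     {"created": 1700000000, "data": [{"b64_json": "iVBORw0..."}]}
    # Some gateways nest the list as {"data": {"images": [...]}} or
    # {"data": {"data": [...]}}, which is why the parser probes all three.
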
    def _tts(self, request: GenerateRequest, *, model_id: str) -> GenerateResponse:
        if self.provider_name == "openai" and not (
            model_id.startswith("tts-") or "-tts" in model_id
        ):
            raise invalid_request_error(
                f'TTS requires model like "{self.provider_name}:tts-1"'
            )
        text = self._single_text_prompt(request)
        audio = request.output.audio
        if audio is None or not audio.voice:
            raise invalid_request_error("output.audio.voice required for TTS")
        fmt = audio.format or "mp3"
        body: dict[str, Any] = {
            "model": model_id,
            "voice": audio.voice,
            "input": text,
            "response_format": fmt,
        }
        self._apply_provider_options(body, request)
        url = f"{self.base_url}/audio/speech"
        data = request_bytes(
            method="POST",
            url=url,
            headers={**self._headers(request), "Content-Type": "application/json"},
            body=json.dumps(body, separators=(",", ":")).encode("utf-8"),
            timeout_ms=request.params.timeout_ms,
            proxy_url=self.proxy_url,
        )
        part = Part(
            type="audio",
            mime_type=_audio_mime_from_format(fmt),
            source=PartSourceBytes(data=bytes_to_base64(data), encoding="base64"),
        )
        return GenerateResponse(
            id=f"sdk_{uuid4().hex}",
            provider=self.provider_name,
            model=f"{self.provider_name}:{model_id}",
            status="completed",
            output=[Message(role="assistant", content=[part])],
            usage=None,
        )

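    # Illustrative only (sample values, not the original test fixtures): a
    # minimal TTS call serializes a body like
    #     {"model": "tts-1", "voice": "alloy", "input": "Hello", "response_format": "mp3"}
    # posted to {base_url}/audio/speech; the raw audio bytes returned are
    # re-encoded as base64 in the Part above.
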
    def _transcribe(
        self, request: GenerateRequest, *, model_id: str
    ) -> GenerateResponse:
        audio_part = self._single_audio_part(request)
        prompt = self._transcription_prompt(request, audio_part=audio_part)
        src = audio_part.require_source()
        tmp_path: str | None = None
        if isinstance(src, PartSourceUrl):
            tmp_path = _download_to_temp(
                src.url,
                timeout_ms=request.params.timeout_ms,
                max_bytes=None,
                proxy_url=self.proxy_url,
            )
            file_path = tmp_path
        elif isinstance(src, PartSourcePath):
            file_path = src.path
        elif isinstance(src, PartSourceBytes) and src.encoding == "base64":
            try:
                data = base64.b64decode(src.data)
            except Exception:
                raise invalid_request_error("audio base64 data is not valid base64")
            with tempfile.NamedTemporaryFile(
                prefix="genaisdk-", suffix=".bin", delete=False
            ) as f:
                f.write(data)
                tmp_path = f.name
            file_path = tmp_path
        elif isinstance(src, PartSourceRef):
            raise not_supported_error("openai transcription does not support ref input")
        else:
            assert isinstance(src, PartSourceBytes)
            if not isinstance(src.data, bytes):
                raise invalid_request_error("audio bytes data must be bytes")
            with tempfile.NamedTemporaryFile(
                prefix="genaisdk-", suffix=".bin", delete=False
            ) as f:
                f.write(src.data)
                tmp_path = f.name
            file_path = tmp_path

        try:
            fields = {"model": model_id}
            if request.params.temperature is not None:
                fields["temperature"] = str(request.params.temperature)
            lang = audio_part.meta.get("language")
            if isinstance(lang, str) and lang.strip():
                fields["language"] = lang.strip()
            if prompt:
                fields["prompt"] = prompt
            self._apply_provider_options_form_fields(fields, request)
            if "diarize" in model_id and "chunking_strategy" not in fields:
                fields["chunking_strategy"] = "auto"
            body = multipart_form_data(
                fields=fields,
                file_field="file",
                file_path=file_path,
                filename=os.path.basename(file_path),
                file_mime_type=audio_part.mime_type
                or detect_mime_type(file_path)
                or "application/octet-stream",
            )
            url = f"{self.base_url}/audio/transcriptions"
            obj = request_streaming_body_json(
                method="POST",
                url=url,
                headers=self._headers(request),
                body=body,
                timeout_ms=request.params.timeout_ms,
                proxy_url=self.proxy_url,
            )
        finally:
            if tmp_path is not None:
                try:
                    os.unlink(tmp_path)
                except OSError:
                    pass

        text = obj.get("text")
        if not isinstance(text, str):
            raise provider_error("openai transcription missing text")
        return GenerateResponse(
            id=f"sdk_{uuid4().hex}",
            provider=self.provider_name,
            model=f"{self.provider_name}:{model_id}",
            status="completed",
            output=[Message(role="assistant", content=[Part.from_text(text)])],
            usage=None,
        )

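    # Illustrative only: the multipart form built above carries fields roughly
    # like
    #     model=whisper-1, temperature=0.2, language=en, prompt=..., file=<audio>
    # where temperature/language/prompt appear only when supplied, and model
    # ids containing "diarize" additionally get chunking_strategy=auto.
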
    def _embed(self, request: GenerateRequest, *, model_id: str) -> GenerateResponse:
        if self.provider_name == "openai" and not model_id.startswith(
            "text-embedding-"
        ):
            raise not_supported_error(
                f'embedding requires model like "{self.provider_name}:text-embedding-3-small"'
            )
        texts = _gather_text_inputs(request)
        url = f"{self.base_url}/embeddings"
        body: dict[str, Any] = {"model": model_id, "input": texts}
        emb = request.output.embedding
        if emb and emb.dimensions is not None:
            if self.provider_name == "openai" and not model_id.startswith(
                "text-embedding-3-"
            ):
                raise invalid_request_error(
                    "embedding.dimensions is only supported for OpenAI text-embedding-3 models"
                )
            body["dimensions"] = emb.dimensions
        self._apply_provider_options(body, request)
        obj = request_json(
            method="POST",
            url=url,
            headers=self._headers(request),
            json_body=body,
            timeout_ms=request.params.timeout_ms,
            proxy_url=self.proxy_url,
        )
        data = obj.get("data")
        if not isinstance(data, list) or len(data) != len(texts):
            raise provider_error("openai embeddings response missing data")
        parts: list[Part] = []
        for item in data:
            if not isinstance(item, dict):
                raise provider_error("openai embeddings item is not object")
            emb = item.get("embedding")
            if not isinstance(emb, list) or not all(
                isinstance(x, (int, float)) for x in emb
            ):
                raise provider_error("openai embeddings item missing embedding")
            parts.append(Part(type="embedding", embedding=[float(x) for x in emb]))

        usage = None
        u = obj.get("usage")
        if isinstance(u, dict):
            usage = Usage(
                input_tokens=u.get("prompt_tokens"),
                total_tokens=u.get("total_tokens"),
            )

        return GenerateResponse(
            id=f"sdk_{uuid4().hex}",
            provider=self.provider_name,
            model=f"{self.provider_name}:{model_id}",
            status="completed",
            output=[Message(role="assistant", content=parts)],
            usage=usage,
        )

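    # Illustrative only: an embeddings request body looks roughly like
    #     {"model": "text-embedding-3-small", "input": ["first text", "second text"]}
    # and the response is expected to carry one {"embedding": [...]} item per
    # input text, in order; a length mismatch is treated as a provider error.
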
    def _video(self, request: GenerateRequest, *, model_id: str) -> GenerateResponse:
        if self.provider_name == "openai" and not model_id.startswith("sora-"):
            raise not_supported_error(
                f'video generation requires model like "{self.provider_name}:sora-2"'
            )

        is_tuzi = self.provider_name.startswith("tuzi")

        def _tuzi_prompt_and_image(req: GenerateRequest) -> tuple[str, Part | None]:
            texts: list[str] = []
            images: list[Part] = []
            for msg in req.input:
                for part in msg.content:
                    if part.type == "text":
                        t = part.require_text().strip()
                        if t:
                            texts.append(t)
                        continue
                    if part.type == "image":
                        images.append(part)
                        continue
                    raise invalid_request_error(
                        "video generation only supports text (+ optional image)"
                    )
            if len(texts) != 1:
                raise invalid_request_error(
                    "video generation requires exactly one text part"
                )
            if len(images) > 1:
                raise invalid_request_error(
                    "video generation supports at most one image input"
                )
            return texts[0], images[0] if images else None

        if is_tuzi:
            prompt, image_part = _tuzi_prompt_and_image(request)
        else:
            prompt = self._single_text_prompt(request)

        video = request.output.video
        if is_tuzi:
            is_sora = model_id.lower().startswith("sora-")
            fields: dict[str, str] = {"model": model_id, "prompt": prompt}
            if video and video.duration_sec is not None:
                fields["seconds"] = str(
                    _closest_video_seconds(video.duration_sec, is_tuzi=not is_sora)
                )
            if video and video.aspect_ratio:
                size = _video_size_from_aspect_ratio(video.aspect_ratio)
                if size:
                    fields["size"] = size
            self._apply_provider_options_form_fields(fields, request)
            tmp_path: str | None = None
            try:
                if image_part is None:
                    stream_body = multipart_form_data_fields(fields=fields)
                else:
                    src = image_part.require_source()
                    if isinstance(src, PartSourceUrl):
                        fields["first_frame_image"] = src.url
                        fields["input_reference"] = src.url
                        stream_body = multipart_form_data_fields(fields=fields)
                    elif isinstance(src, PartSourceRef):
                        raise not_supported_error(
                            "tuzi video generation does not support ref image input"
                        )
                    else:
                        if isinstance(src, PartSourcePath):
                            file_path = src.path
                        elif (
                            isinstance(src, PartSourceBytes)
                            and src.encoding == "base64"
                        ):
                            try:
                                decoded = base64.b64decode(src.data)
                            except Exception:
                                raise invalid_request_error(
                                    "image base64 data is not valid base64"
                                )
                            with tempfile.NamedTemporaryFile(
                                prefix="genaisdk-", suffix=".bin", delete=False
                            ) as f:
                                f.write(decoded)
                                tmp_path = f.name
                            file_path = tmp_path
                        else:
                            assert isinstance(src, PartSourceBytes)
                            if not isinstance(src.data, bytes):
                                raise invalid_request_error(
                                    "image bytes data must be bytes"
                                )
                            with tempfile.NamedTemporaryFile(
                                prefix="genaisdk-", suffix=".bin", delete=False
                            ) as f:
                                f.write(src.data)
                                tmp_path = f.name
                            file_path = tmp_path

                        stream_body = multipart_form_data(
                            fields=fields,
                            file_field="input_reference",
                            file_path=file_path,
                            filename=os.path.basename(file_path),
                            file_mime_type=image_part.mime_type
                            or detect_mime_type(file_path)
                            or "application/octet-stream",
                        )
                obj = request_streaming_body_json(
                    method="POST",
                    url=f"{self.base_url}/videos",
                    headers=self._headers(request),
                    body=stream_body,
                    timeout_ms=request.params.timeout_ms,
                    proxy_url=self.proxy_url,
                )
            finally:
                if tmp_path is not None:
                    try:
                        os.unlink(tmp_path)
                    except OSError:
                        pass
        else:
            body: dict[str, Any] = {"model": model_id, "prompt": prompt}
            if video and video.duration_sec is not None:
                body["seconds"] = _closest_video_seconds(
                    video.duration_sec, is_tuzi=False
                )
            if video and video.aspect_ratio:
                size = _video_size_from_aspect_ratio(video.aspect_ratio)
                if size:
                    body["size"] = size
            self._apply_provider_options(body, request)
            obj = request_json(
                method="POST",
                url=f"{self.base_url}/videos",
                headers=self._headers(request),
                json_body=body,
                timeout_ms=request.params.timeout_ms,
                proxy_url=self.proxy_url,
            )
        video_id = obj.get("id")
        if not isinstance(video_id, str) or not video_id:
            raise provider_error("openai video response missing id")

        if not request.wait:
            return GenerateResponse(
                id=f"sdk_{uuid4().hex}",
                provider=self.provider_name,
                model=f"{self.provider_name}:{model_id}",
                status="running",
                job=JobInfo(job_id=video_id, poll_after_ms=1_000),
            )

        job = self._wait_video_job(video_id, timeout_ms=request.params.timeout_ms)
        status = job.get("status")
        if status != "completed":
            if status == "failed":
                err = job.get("error")
                raise provider_error(f"openai video generation failed: {err}")
            return GenerateResponse(
                id=f"sdk_{uuid4().hex}",
                provider=self.provider_name,
                model=f"{self.provider_name}:{model_id}",
                status="running",
                job=JobInfo(job_id=video_id, poll_after_ms=1_000),
            )

        data = request_bytes(
            method="GET",
            url=f"{self.base_url}/videos/{video_id}/content",
            headers=self._headers(request),
            timeout_ms=request.params.timeout_ms,
            proxy_url=self.proxy_url,
        )
        part = Part(
            type="video",
            mime_type="video/mp4",
            source=PartSourceBytes(data=bytes_to_base64(data), encoding="base64"),
        )
        return GenerateResponse(
            id=f"sdk_{uuid4().hex}",
            provider=self.provider_name,
            model=f"{self.provider_name}:{model_id}",
            status="completed",
            output=[Message(role="assistant", content=[part])],
            usage=None,
        )

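    # Illustrative only: the video flow above is create -> poll -> download
    # ("video_abc" is a made-up id; field names follow the code above):
    #     POST {base_url}/videos                -> {"id": "video_abc", ...}
    #     GET  {base_url}/videos/video_abc      -> {"status": ...}; anything
    #          other than "completed"/"failed" keeps the poll loop running
    #     GET  {base_url}/videos/video_abc/content -> raw mp4 bytes
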
    def _wait_video_job(
        self, video_id: str, *, timeout_ms: int | None
    ) -> dict[str, Any]:
        budget_ms = 120_000 if timeout_ms is None else timeout_ms
        deadline = time.time() + max(1, budget_ms) / 1000.0
        while True:
            remaining_ms = int((deadline - time.time()) * 1000)
            if remaining_ms <= 0:
                break
            obj = request_json(
                method="GET",
                url=f"{self.base_url}/videos/{video_id}",
                headers=self._headers(),
                timeout_ms=min(30_000, remaining_ms),
                proxy_url=self.proxy_url,
            )
            status = obj.get("status")
            if status in {"completed", "failed"}:
                return obj
            time.sleep(min(1.0, max(0.0, deadline - time.time())))
        return {"id": video_id, "status": "in_progress"}

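    # Illustrative only: with the default 120_000 ms budget the loop above
    # issues roughly one poll per second (at most ~120 polls), each with a
    # per-request timeout of min(30_000, remaining_ms), then gives up with an
    # "in_progress" placeholder rather than raising.
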
    def _single_text_prompt(self, request: GenerateRequest) -> str:
        texts = _gather_text_inputs(request)
        if len(texts) != 1:
            raise invalid_request_error("this operation requires exactly one text part")
        return texts[0]

    def _single_audio_part(self, request: GenerateRequest) -> Part:
        parts: list[Part] = []
        for m in request.input:
            for p in m.content:
                if p.type == "audio":
                    parts.append(p)
                elif p.type != "text":
                    raise invalid_request_error(
                        "transcription only supports audio (+ optional text)"
                    )
        if len(parts) != 1:
            raise invalid_request_error("transcription requires exactly one audio part")
        return parts[0]

    def _transcription_prompt(
        self, request: GenerateRequest, *, audio_part: Part
    ) -> str | None:
        v = audio_part.meta.get("transcription_prompt")
        if isinstance(v, str) and v.strip():
            return v.strip()
        chunks: list[str] = []
        for m in request.input:
            for p in m.content:
                if p.type != "text":
                    continue
                if p.meta.get("transcription_prompt") is not True:
                    continue
                t = p.require_text().strip()
                if t:
                    chunks.append(t)
        if not chunks:
            return None
        return "\n\n".join(chunks)


def _closest_video_seconds(duration_sec: int, *, is_tuzi: bool) -> int:
    if is_tuzi:
        if duration_sec <= 10:
            return 10
        if duration_sec <= 15:
            return 15
        return 25
    if duration_sec <= 4:
        return 4
    if duration_sec <= 8:
        return 8
    return 12


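# Illustrative only: how the bucketing above rounds a requested duration up
# to the nearest supported clip length:
#     _closest_video_seconds(6, is_tuzi=False)  -> 8   (buckets 4/8/12)
#     _closest_video_seconds(12, is_tuzi=True)  -> 15  (buckets 10/15/25)
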
def _video_size_from_aspect_ratio(aspect_ratio: str) -> str | None:
    ar = aspect_ratio.strip()
    if ar == "16:9":
        return "1280x720"
    if ar == "9:16":
        return "720x1280"
    return None
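

# Illustrative only: the aspect-ratio mapping above recognizes exactly two
# ratios, e.g.
#     _video_size_from_aspect_ratio("16:9") -> "1280x720"
#     _video_size_from_aspect_ratio("4:3")  -> None  (unsupported; size omitted)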