nous-genai 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nous/__init__.py +3 -0
- nous/genai/__init__.py +56 -0
- nous/genai/__main__.py +3 -0
- nous/genai/_internal/__init__.py +1 -0
- nous/genai/_internal/capability_rules.py +476 -0
- nous/genai/_internal/config.py +102 -0
- nous/genai/_internal/errors.py +63 -0
- nous/genai/_internal/http.py +951 -0
- nous/genai/_internal/json_schema.py +54 -0
- nous/genai/cli.py +1316 -0
- nous/genai/client.py +719 -0
- nous/genai/mcp_cli.py +275 -0
- nous/genai/mcp_server.py +1080 -0
- nous/genai/providers/__init__.py +15 -0
- nous/genai/providers/aliyun.py +535 -0
- nous/genai/providers/anthropic.py +483 -0
- nous/genai/providers/gemini.py +1606 -0
- nous/genai/providers/openai.py +1909 -0
- nous/genai/providers/tuzi.py +1158 -0
- nous/genai/providers/volcengine.py +273 -0
- nous/genai/reference/__init__.py +17 -0
- nous/genai/reference/catalog.py +206 -0
- nous/genai/reference/mappings.py +467 -0
- nous/genai/reference/mode_overrides.py +26 -0
- nous/genai/reference/model_catalog.py +82 -0
- nous/genai/reference/model_catalog_data/__init__.py +1 -0
- nous/genai/reference/model_catalog_data/aliyun.py +98 -0
- nous/genai/reference/model_catalog_data/anthropic.py +10 -0
- nous/genai/reference/model_catalog_data/google.py +45 -0
- nous/genai/reference/model_catalog_data/openai.py +44 -0
- nous/genai/reference/model_catalog_data/tuzi_anthropic.py +21 -0
- nous/genai/reference/model_catalog_data/tuzi_google.py +19 -0
- nous/genai/reference/model_catalog_data/tuzi_openai.py +75 -0
- nous/genai/reference/model_catalog_data/tuzi_web.py +136 -0
- nous/genai/reference/model_catalog_data/volcengine.py +107 -0
- nous/genai/tools/__init__.py +13 -0
- nous/genai/tools/output_parser.py +119 -0
- nous/genai/types.py +416 -0
- nous/py.typed +1 -0
- nous_genai-0.1.0.dist-info/METADATA +200 -0
- nous_genai-0.1.0.dist-info/RECORD +45 -0
- nous_genai-0.1.0.dist-info/WHEEL +5 -0
- nous_genai-0.1.0.dist-info/entry_points.txt +4 -0
- nous_genai-0.1.0.dist-info/licenses/LICENSE +190 -0
- nous_genai-0.1.0.dist-info/top_level.txt +1 -0
nous/genai/cli.py
ADDED
|
@@ -0,0 +1,1316 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import argparse
|
|
4
|
+
import base64
|
|
5
|
+
import json
|
|
6
|
+
import secrets
|
|
7
|
+
import sys
|
|
8
|
+
import threading
|
|
9
|
+
import time
|
|
10
|
+
import urllib.parse
|
|
11
|
+
from collections.abc import Callable
|
|
12
|
+
from dataclasses import replace
|
|
13
|
+
from typing import TypeVar
|
|
14
|
+
|
|
15
|
+
from .client import Client
|
|
16
|
+
from ._internal.errors import GenAIError
|
|
17
|
+
from ._internal.http import download_to_file
|
|
18
|
+
from .reference import (
|
|
19
|
+
get_model_catalog,
|
|
20
|
+
get_parameter_mappings,
|
|
21
|
+
get_sdk_supported_models,
|
|
22
|
+
)
|
|
23
|
+
from .types import (
|
|
24
|
+
GenerateRequest,
|
|
25
|
+
Message,
|
|
26
|
+
OutputAudioSpec,
|
|
27
|
+
OutputEmbeddingSpec,
|
|
28
|
+
OutputImageSpec,
|
|
29
|
+
OutputSpec,
|
|
30
|
+
OutputTextSpec,
|
|
31
|
+
OutputVideoSpec,
|
|
32
|
+
Part,
|
|
33
|
+
PartSourceBytes,
|
|
34
|
+
PartSourcePath,
|
|
35
|
+
PartSourceUrl,
|
|
36
|
+
detect_mime_type,
|
|
37
|
+
)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def main(argv: list[str] | None = None) -> None:
    """Entry point for the ``genai`` CLI.

    Dispatches to the ``model``/``mapping``/``token``/``probe`` subcommands,
    or — with no subcommand — runs a single generate request built from the
    top-level flags and prints/writes the result.

    Raises:
        SystemExit: on invalid arguments or a failed request.
    """
    parser = argparse.ArgumentParser(
        prog="genai", description="nous-genai CLI (minimal)"
    )
    sub = parser.add_subparsers(dest="command")

    # -- "model" subcommand: catalog/provider discovery listings -------------
    model = sub.add_parser("model", help="Model discovery and support listing")
    model_sub = model.add_subparsers(dest="model_command", required=True)

    model_sub.add_parser("sdk", help="List SDK curated models")

    pm = model_sub.add_parser(
        "provider", help="List remotely available model ids for a provider"
    )
    pm.add_argument(
        "--provider", required=True, help="Provider (e.g. openai/google/tuzi-openai)"
    )
    pm.add_argument("--timeout-ms", type=int, help="Timeout budget in milliseconds")

    av = model_sub.add_parser(
        "available",
        help="List available models (sdk ∩ provider) with capabilities",
    )
    av_scope = av.add_mutually_exclusive_group(required=True)
    av_scope.add_argument(
        "--provider", help="Provider (e.g. openai/google/tuzi-openai)"
    )
    av_scope.add_argument(
        "--all", action="store_true", help="List across all providers"
    )
    av.add_argument("--timeout-ms", type=int, help="Timeout budget in milliseconds")

    um = model_sub.add_parser(
        "unsupported",
        help="List provider-available but not in SDK catalog models",
    )
    um.add_argument("--provider", help="Provider (omit to scan all catalog providers)")
    um.add_argument("--timeout-ms", type=int, help="Timeout budget in milliseconds")

    st = model_sub.add_parser(
        "stale",
        help="List stale model ids (sdk catalog - provider) for a provider",
    )
    st.add_argument(
        "--provider", required=True, help="Provider (e.g. openai/google/tuzi-openai)"
    )
    st.add_argument("--timeout-ms", type=int, help="Timeout budget in milliseconds")

    sub.add_parser("mapping", help="Print parameter mapping table")

    # -- "token" subcommand: local token helpers -----------------------------
    token = sub.add_parser("token", help="Token utilities")
    token_sub = token.add_subparsers(dest="token_command", required=True)
    token_sub.add_parser("generate", help="Generate a new token (sk-...)")

    # -- Top-level flags for the default (no-subcommand) generate flow -------
    parser.add_argument(
        "--model", default="openai:gpt-4o-mini", help='Model like "openai:gpt-4o-mini"'
    )
    parser.add_argument(
        "--protocol",
        choices=["chat_completions", "responses"],
        help='OpenAI chat protocol override ("chat_completions" or "responses")',
    )
    parser.add_argument("--prompt", help="Text prompt")
    parser.add_argument(
        "--prompt-path",
        help="Read prompt text from a file (lower priority than --prompt)",
    )
    parser.add_argument("--image-path", help="Input image file path")
    parser.add_argument("--audio-path", help="Input audio file path")
    parser.add_argument("--video-path", help="Input video file path")
    parser.add_argument("--output-path", help="Write output to file (text/json/binary)")
    # Hidden alias that tolerates the common "--ouput-path" typo.
    parser.add_argument("--ouput-path", dest="output_path", help=argparse.SUPPRESS)
    parser.add_argument(
        "--timeout-ms",
        type=int,
        help="Timeout budget in milliseconds (overrides NOUS_GENAI_TIMEOUT_MS)",
    )

    # -- "probe" subcommand: live modality/mode verification -----------------
    probe = sub.add_parser("probe", help="Probe model modalities/modes for a provider")
    probe.add_argument("--provider", required=True, help="Provider (e.g. tuzi-web)")
    probe.add_argument(
        "--model", help="Comma-separated model ids (or provider:model_id)"
    )
    probe.add_argument(
        "--all",
        action="store_true",
        help="Probe all SDK-supported models for the provider",
    )
    probe.add_argument(
        "--timeout-ms",
        type=int,
        help="Timeout budget in milliseconds (overrides NOUS_GENAI_TIMEOUT_MS)",
    )

    args = parser.parse_args(argv)
    timeout_ms: int | None = getattr(args, "timeout_ms", None)
    if timeout_ms is not None and timeout_ms < 1:
        raise SystemExit("--timeout-ms must be >= 1")

    # BrokenPipeError is swallowed around listing commands so piping into
    # e.g. `head` exits cleanly.
    if args.command == "mapping":
        try:
            _print_mappings()
        except BrokenPipeError:
            return
        return

    if args.command == "token":
        cmd = str(getattr(args, "token_command", "") or "").strip()
        if cmd == "generate":
            print(_generate_token())
            return
        raise SystemExit(f"unknown token subcommand: {cmd}")

    if args.command == "model":
        try:
            cmd = str(getattr(args, "model_command", "") or "").strip()
            if cmd == "sdk":
                _print_sdk_supported()
                return
            if cmd == "provider":
                _print_provider_models(str(args.provider), timeout_ms=timeout_ms)
                return
            if cmd == "available":
                if bool(getattr(args, "all", False)):
                    _print_all_available_models(timeout_ms=timeout_ms)
                    return
                _print_available_models(str(args.provider), timeout_ms=timeout_ms)
                return
            if cmd == "unsupported":
                if getattr(args, "provider", None):
                    _print_unsupported_models(str(args.provider), timeout_ms=timeout_ms)
                    return
                _print_unsupported(timeout_ms=timeout_ms)
                return
            if cmd == "stale":
                _print_stale_models(str(args.provider), timeout_ms=timeout_ms)
                return
            raise SystemExit(f"unknown model subcommand: {cmd}")
        except BrokenPipeError:
            return

    if args.command == "probe":
        try:
            # _run_probe returns a process exit code (0 ok / 1 failures).
            raise SystemExit(_run_probe(args, timeout_ms=timeout_ms))
        except BrokenPipeError:
            return

    # -- Default flow: one generate request ----------------------------------
    provider, model_id = _split_model(args.model)
    prompt = args.prompt
    if prompt is None and args.prompt_path:
        try:
            with open(args.prompt_path, "r", encoding="utf-8") as f:
                prompt = f.read()
        except OSError as e:
            raise SystemExit(f"cannot read --prompt-path: {e}") from None
    client = Client()
    _apply_protocol_override(client, provider=provider, protocol=args.protocol)

    cap = client.capabilities(args.model)
    output = _infer_output_spec(provider=provider, model_id=model_id, cap=cap)

    parts = _build_input_parts(
        prompt=prompt,
        image_path=args.image_path,
        audio_path=args.audio_path,
        video_path=args.video_path,
        input_modalities=set(cap.input_modalities),
        output_modalities=set(output.modalities),
        provider=provider,
        model_id=model_id,
    )
    req = GenerateRequest(
        model=args.model,
        input=[Message(role="user", content=parts)],
        output=output,
        wait=True,
    )
    if timeout_ms is not None:
        req = replace(req, params=replace(req.params, timeout_ms=timeout_ms))

    try:
        # A spinner is only useful when we are blocking on a job-capable model
        # and stderr is an interactive terminal.
        wait_spinner = bool(req.wait) and bool(getattr(cap, "supports_job", False))
        show_progress = wait_spinner and sys.stderr.isatty()
        resp, elapsed_s = _run_with_spinner(
            lambda: client.generate(req),
            enabled=show_progress,
            label="等待任务完成",
        )
        if resp.status != "completed":
            # Print the job id (if any) so the user can resume/poll later.
            if resp.job and resp.job.job_id:
                print(resp.job.job_id)
            if resp.status == "running":
                effective_timeout_ms = timeout_ms
                if effective_timeout_ms is None:
                    effective_timeout_ms = getattr(
                        client, "_default_timeout_ms", None
                    )
                timeout_note = (
                    f"{effective_timeout_ms}ms"
                    if isinstance(effective_timeout_ms, int)
                    else "timeout"
                )
                print(
                    f"[INFO] 任务仍在运行(等待 {elapsed_s:.1f}s,可能已超时 {timeout_note});已返回 job_id。"
                    "可增大 --timeout-ms 或设置 NOUS_GENAI_TIMEOUT_MS 后重试。",
                    file=sys.stderr,
                )
                if args.output_path:
                    print(
                        f"[INFO] 未写入输出文件:{args.output_path}",
                        file=sys.stderr,
                    )
                return
            raise SystemExit(f"[FAIL]: request status={resp.status}")
        if not resp.output:
            raise SystemExit("[FAIL]: missing output")
        _write_response(
            resp.output[0].content,
            output=output,
            output_path=args.output_path,
            timeout_ms=timeout_ms,
            download_auth=_download_auth(client, provider=provider),
        )
        if show_progress:
            print(f"[INFO] 完成,用时 {elapsed_s:.1f}s", file=sys.stderr)
    except GenAIError as e:
        # Surface the normalized error info (and retryability) as exit text.
        code = f" ({e.info.provider_code})" if e.info.provider_code else ""
        retryable = " retryable" if e.info.retryable else ""
        raise SystemExit(
            f"[FAIL]{code}{retryable}: {e.info.type}: {e.info.message}"
        ) from None
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
# Publicly hosted CC0 sample clip used as the default video-input probe asset.
_DEFAULT_VIDEO_URL = (
    "https://interactive-examples.mdn.mozilla.net/media/cc0-videos/flower.mp4"
)
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
def _run_probe(args: argparse.Namespace, *, timeout_ms: int | None) -> int:
    """Probe each requested model's declared modalities/modes against the live API.

    Returns:
        Process exit code: 0 when every probe passed or was skipped, 1 when
        any probe failed.

    Raises:
        SystemExit: on invalid/ambiguous probe arguments or an empty model set.
    """
    # Deferred imports keep CLI startup light and avoid import cycles.
    from .client import _normalize_provider
    from .reference import get_sdk_supported_models_for_provider

    provider = _normalize_provider(str(args.provider))
    if not provider:
        raise SystemExit("--provider must be non-empty")

    # Exactly one selection mechanism: an explicit model list or --all.
    if bool(args.all) == bool(args.model):
        raise SystemExit('probe requires exactly one of: "--model" or "--all"')

    if args.all:
        rows = get_sdk_supported_models_for_provider(provider)
        # Defensive extraction: only well-formed catalog rows contribute ids.
        model_ids = [
            str(r["model_id"])
            for r in rows
            if isinstance(r, dict) and isinstance(r.get("model_id"), str)
        ]
    else:
        model_ids = _parse_probe_models(provider, str(args.model))

    if not model_ids:
        raise SystemExit(f"no models to probe for provider={provider}")

    client = Client()
    totals = {"ok": 0, "fail": 0, "skip": 0}

    for model_id in model_ids:
        model = f"{provider}:{model_id}"
        print(f"== {model} ==")
        try:
            cap = client.capabilities(model)
        except GenAIError as e:
            # A model whose capabilities cannot be fetched counts as one failure.
            print(f"[FAIL] capabilities: {e.info.type}: {e.info.message}")
            totals["fail"] += 1
            print()
            continue

        modes = _probe_modes_for(cap)
        print(
            "declared:"
            f" modes={','.join(modes)}"
            f" in={','.join(sorted(cap.input_modalities))}"
            f" out={','.join(sorted(cap.output_modalities))}"
        )

        results = _probe_model(
            client,
            provider=provider,
            model_id=model_id,
            cap=cap,
            timeout_ms=timeout_ms,
        )
        for k in totals.keys():
            totals[k] += results[k]

        status = "OK" if results["fail"] == 0 else "FAIL"
        print(
            f"result: {status} (ok={results['ok']} fail={results['fail']} skip={results['skip']})"
        )
        print()

    print(f"[SUMMARY] ok={totals['ok']} fail={totals['fail']} skip={totals['skip']}")
    return 0 if totals["fail"] == 0 else 1
|
|
342
|
+
|
|
343
|
+
|
|
344
|
+
# Generic result type threaded through _run_with_spinner.
_T = TypeVar("_T")
|
|
345
|
+
|
|
346
|
+
|
|
347
|
+
def _run_with_spinner(
|
|
348
|
+
fn: Callable[[], _T], *, enabled: bool, label: str
|
|
349
|
+
) -> tuple[_T, float]:
|
|
350
|
+
start = time.perf_counter()
|
|
351
|
+
if not enabled or not sys.stderr.isatty():
|
|
352
|
+
out = fn()
|
|
353
|
+
return out, time.perf_counter() - start
|
|
354
|
+
|
|
355
|
+
done = threading.Event()
|
|
356
|
+
result: dict[str, _T] = {}
|
|
357
|
+
error: dict[str, BaseException] = {}
|
|
358
|
+
|
|
359
|
+
def _worker() -> None:
|
|
360
|
+
try:
|
|
361
|
+
result["value"] = fn()
|
|
362
|
+
except BaseException as e: # noqa: BLE001
|
|
363
|
+
error["exc"] = e
|
|
364
|
+
finally:
|
|
365
|
+
done.set()
|
|
366
|
+
|
|
367
|
+
t = threading.Thread(target=_worker, name="genai-cli-wait", daemon=True)
|
|
368
|
+
t.start()
|
|
369
|
+
|
|
370
|
+
if done.wait(0.25):
|
|
371
|
+
t.join()
|
|
372
|
+
exc = error.get("exc")
|
|
373
|
+
if exc is not None:
|
|
374
|
+
raise exc
|
|
375
|
+
if "value" not in result:
|
|
376
|
+
raise RuntimeError("missing result value")
|
|
377
|
+
return result["value"], time.perf_counter() - start
|
|
378
|
+
|
|
379
|
+
frames = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]
|
|
380
|
+
i = 0
|
|
381
|
+
try:
|
|
382
|
+
while not done.wait(0.1):
|
|
383
|
+
frame = frames[i % len(frames)]
|
|
384
|
+
i += 1
|
|
385
|
+
elapsed = time.perf_counter() - start
|
|
386
|
+
sys.stderr.write(f"\r{frame} {label}... {elapsed:5.1f}s")
|
|
387
|
+
sys.stderr.flush()
|
|
388
|
+
finally:
|
|
389
|
+
sys.stderr.write("\r" + (" " * 64) + "\r")
|
|
390
|
+
sys.stderr.flush()
|
|
391
|
+
t.join()
|
|
392
|
+
|
|
393
|
+
exc = error.get("exc")
|
|
394
|
+
if exc is not None:
|
|
395
|
+
raise exc
|
|
396
|
+
if "value" not in result:
|
|
397
|
+
raise RuntimeError("missing result value")
|
|
398
|
+
return result["value"], time.perf_counter() - start
|
|
399
|
+
|
|
400
|
+
|
|
401
|
+
def _parse_probe_models(provider: str, value: str) -> list[str]:
    """Split a comma-separated model list, stripping optional provider prefixes.

    Entries may be bare model ids or ``provider:model_id``; a prefix that does
    not normalize to *provider* aborts the run. Blank entries and duplicates
    are dropped while preserving first-seen order.
    """
    from .client import _normalize_provider

    expected = _normalize_provider(provider)
    collected: list[str] = []
    known: set[str] = set()
    for token in (piece.strip() for piece in value.split(",")):
        if not token:
            continue
        if ":" in token:
            prefix, _, rest = token.partition(":")
            prefix = _normalize_provider(prefix)
            if prefix != expected:
                raise SystemExit(
                    f"model provider mismatch: expected {expected}, got {prefix}"
                )
            token = rest.strip()
            if not token:
                continue
        if token in known:
            continue
        known.add(token)
        collected.append(token)
    return collected
|
|
425
|
+
|
|
426
|
+
|
|
427
|
+
def _probe_modes_for(cap) -> list[str]:
|
|
428
|
+
modes: list[str] = ["sync"]
|
|
429
|
+
if cap.supports_stream:
|
|
430
|
+
modes.append("stream")
|
|
431
|
+
if cap.supports_job:
|
|
432
|
+
modes.append("job")
|
|
433
|
+
modes.append("async")
|
|
434
|
+
return modes
|
|
435
|
+
|
|
436
|
+
|
|
437
|
+
def _probe_model(
    client: Client,
    *,
    provider: str,
    model_id: str,
    cap,
    timeout_ms: int | None,
) -> dict[str, int]:
    """Exercise one model across its declared output/input modalities and modes.

    Returns:
        An ``{"ok", "fail", "skip"}`` tally accumulated over all probes.
    """
    results = {"ok": 0, "fail": 0, "skip": 0}

    out_modalities = sorted(cap.output_modalities)
    for i, out_modality in enumerate(out_modalities):
        ok = _probe_output_modality(
            client,
            provider=provider,
            model_id=model_id,
            cap=cap,
            out_modality=out_modality,
            timeout_ms=timeout_ms,
            # Job-mode is verified at most once per model (first modality only).
            probe_job=bool(cap.supports_job) and i == 0,
        )
        _accumulate(results, ok)

    # Non-text inputs get dedicated probes; text input is covered implicitly
    # by the output-modality probes above.
    for in_modality in sorted(set(cap.input_modalities) - {"text"}):
        ok = _probe_input_modality(
            client,
            provider=provider,
            model_id=model_id,
            cap=cap,
            in_modality=in_modality,
            timeout_ms=timeout_ms,
        )
        _accumulate(results, ok)

    if cap.supports_stream:
        ok = _probe_stream_mode(
            client,
            provider=provider,
            model_id=model_id,
            cap=cap,
            timeout_ms=timeout_ms,
        )
        _accumulate(results, ok)

    return results
|
|
482
|
+
|
|
483
|
+
|
|
484
|
+
def _accumulate(totals: dict[str, int], outcome: dict[str, int]) -> None:
|
|
485
|
+
for k, v in outcome.items():
|
|
486
|
+
totals[k] += v
|
|
487
|
+
|
|
488
|
+
|
|
489
|
+
def _probe_output_modality(
    client: Client,
    *,
    provider: str,
    model_id: str,
    cap,
    out_modality: str,
    timeout_ms: int | None,
    probe_job: bool,
) -> dict[str, int]:
    """Probe a single output modality; optionally also verify job submission.

    Returns:
        A tally dict. When *probe_job* is true, the same response is reused
        to check that the provider accepted the request as an async job.
    """
    label = f"output:{out_modality}"
    try:
        req = _build_probe_request(
            provider=provider,
            model_id=model_id,
            cap=cap,
            out_modality=out_modality,
            in_modality=None,
            timeout_ms=timeout_ms,
            # Job-capable models submit without waiting so the job path is observable.
            force_wait=False if cap.supports_job else None,
        )
        resp = client.generate(req)
        _validate_probe_response(resp, expected_out=out_modality)
    except GenAIError as e:
        return _probe_fail(label, e)
    except SystemExit as e:
        # _validate_probe_response reports mismatches via SystemExit.
        return _probe_fail(label, e)
    except Exception as e:
        return _probe_fail(label, e)
    out = _probe_ok(label)
    if not probe_job:
        return out
    if resp.status != "running":
        # Completed synchronously; job-mode could not be observed this round.
        _accumulate(out, _probe_skip("mode:job", f"response status={resp.status}"))
        return out
    _accumulate(out, _probe_ok("mode:job"))
    return out
|
|
526
|
+
|
|
527
|
+
|
|
528
|
+
def _probe_input_modality(
    client: Client,
    *,
    provider: str,
    model_id: str,
    cap,
    in_modality: str,
    timeout_ms: int | None,
) -> dict[str, int]:
    """Probe one non-text input modality, pairing it with a cheap output.

    Prefers text output when declared; otherwise falls back to the first
    declared output modality (sorted). Returns a tally dict.
    """
    out_modality = (
        "text"
        if "text" in set(cap.output_modalities)
        else sorted(cap.output_modalities)[0]
    )
    label = f"input:{in_modality}"
    try:
        req = _build_probe_request(
            provider=provider,
            model_id=model_id,
            cap=cap,
            out_modality=out_modality,
            in_modality=in_modality,
            timeout_ms=timeout_ms,
            # Job-capable models submit without waiting, matching the output probes.
            force_wait=False if cap.supports_job else None,
        )
        resp = client.generate(req)
        _validate_probe_response(resp, expected_out=out_modality)
    except GenAIError as e:
        return _probe_fail(label, e)
    except SystemExit as e:
        return _probe_fail(label, e)
    except Exception as e:
        return _probe_fail(label, e)
    return _probe_ok(label)
|
|
562
|
+
|
|
563
|
+
|
|
564
|
+
def _probe_stream_mode(
    client: Client,
    *,
    provider: str,
    model_id: str,
    cap,
    timeout_ms: int | None,
) -> dict[str, int]:
    """Verify streaming works by waiting for the first non-empty text delta.

    Skips when the model declares no text output (the probe reads
    ``output.text.delta`` events). Returns a tally dict.
    """
    label = "mode:stream"
    if "text" not in set(cap.output_modalities):
        return _probe_skip(label, "stream probe requires text output")
    try:
        req = _build_probe_request(
            provider=provider,
            model_id=model_id,
            cap=cap,
            out_modality="text",
            in_modality=None,
            timeout_ms=timeout_ms,
        )
        deltas = 0
        for ev in client.generate_stream(req):
            if ev.type == "output.text.delta":
                delta = ev.data.get("delta")
                if isinstance(delta, str) and delta:
                    # One real delta is enough; stop consuming the stream.
                    deltas += 1
                    break
            if ev.type == "done":
                break
        if deltas == 0:
            raise SystemExit("no output.text.delta received")
    except GenAIError as e:
        return _probe_fail(label, e)
    except SystemExit as e:
        return _probe_fail(label, e)
    except Exception as e:
        return _probe_fail(label, e)
    return _probe_ok(label)
|
|
602
|
+
|
|
603
|
+
|
|
604
|
+
def _build_probe_request(
    *,
    provider: str,
    model_id: str,
    cap,
    out_modality: str,
    in_modality: str | None,
    timeout_ms: int | None,
    force_wait: bool | None = None,
) -> GenerateRequest:
    """Assemble a minimal GenerateRequest for one probe combination.

    *force_wait* overrides the default wait behavior; otherwise requests
    wait synchronously except for video output (typically long-running).
    """
    model = f"{provider}:{model_id}"
    wait = True
    if force_wait is not None:
        wait = force_wait
    elif out_modality == "video":
        wait = False

    output = _output_spec_for_modality(
        provider=provider, model_id=model_id, modality=out_modality
    )
    parts = _probe_input_parts(
        provider=provider,
        model_id=model_id,
        cap=cap,
        out_modality=out_modality,
        in_modality=in_modality,
    )
    req = GenerateRequest(
        model=model,
        input=[Message(role="user", content=parts)],
        output=output,
        wait=wait,
    )
    if timeout_ms is not None:
        # GenerateRequest is immutable; rebuild params with the explicit budget.
        req = replace(req, params=replace(req.params, timeout_ms=timeout_ms))
    return req
|
|
640
|
+
|
|
641
|
+
|
|
642
|
+
def _probe_input_parts(
    *,
    provider: str,
    model_id: str,
    cap,
    out_modality: str,
    in_modality: str | None,
) -> list[Part]:
    """Build the message parts for one probe combination.

    Synthetic media (1s silent WAV, 64x64 red PNG, sample MP4 URL) keeps
    probes cheap and deterministic.

    Raises:
        SystemExit: for an unrecognized *in_modality*.
    """
    prompt = _probe_prompt(out_modality=out_modality, in_modality=in_modality)
    if out_modality == "embedding":
        return [Part.from_text(prompt)]

    if in_modality is None:
        # Audio-only models (e.g. transcription) cannot accept a text prompt.
        if set(cap.input_modalities) == {"audio"}:
            return [
                Part(
                    type="audio",
                    mime_type="audio/wav",
                    source=PartSourceBytes(data=_probe_wav_bytes()),
                )
            ]
        return [Part.from_text(prompt)]

    if in_modality == "text":
        return [Part.from_text(prompt)]
    if in_modality == "image":
        return [Part.from_text(prompt), _probe_image_part()]
    if in_modality == "audio":
        text_meta: dict[str, object] = {}
        # Mark the text part as a transcription hint for audio-only -> text models.
        if set(cap.input_modalities) == {"audio"} and out_modality == "text":
            text_meta = {"transcription_prompt": True}
        parts = [
            Part(
                type="audio",
                mime_type="audio/wav",
                source=PartSourceBytes(data=_probe_wav_bytes()),
            )
        ]
        if prompt:
            parts.insert(0, Part(type="text", text=prompt, meta=text_meta))
        return parts
    if in_modality == "video":
        return [
            Part.from_text(prompt),
            Part(
                type="video",
                mime_type="video/mp4",
                source=PartSourceUrl(url=_DEFAULT_VIDEO_URL),
            ),
        ]
    raise SystemExit(f"unknown input modality: {in_modality}")
|
|
693
|
+
|
|
694
|
+
|
|
695
|
+
def _probe_prompt(*, out_modality: str, in_modality: str | None) -> str:
|
|
696
|
+
if out_modality == "embedding":
|
|
697
|
+
return "hello"
|
|
698
|
+
if out_modality == "image":
|
|
699
|
+
return "生成一张简单的红色方块图。"
|
|
700
|
+
if out_modality == "audio":
|
|
701
|
+
return "用中文说:你好。"
|
|
702
|
+
if out_modality == "video":
|
|
703
|
+
return "生成一个简单短视频:一只猫在草地上奔跑。"
|
|
704
|
+
if in_modality == "image":
|
|
705
|
+
return "请用一句话描述这张图。"
|
|
706
|
+
if in_modality == "audio":
|
|
707
|
+
return "请转写音频内容。"
|
|
708
|
+
if in_modality == "video":
|
|
709
|
+
return "请用一句话描述视频内容。"
|
|
710
|
+
return "只回复:pong"
|
|
711
|
+
|
|
712
|
+
|
|
713
|
+
def _output_spec_for_modality(
    *, provider: str, model_id: str, modality: str
) -> OutputSpec:
    """Build the minimal OutputSpec needed to probe one output *modality*.

    Raises:
        SystemExit: for an unrecognized modality name.
    """
    if modality == "text":
        return OutputSpec(modalities=["text"], text=OutputTextSpec(format="text"))
    if modality == "embedding":
        return OutputSpec(modalities=["embedding"], embedding=OutputEmbeddingSpec())
    if modality == "image":
        return OutputSpec(modalities=["image"], image=OutputImageSpec(n=1))
    if modality == "audio":
        # Voice/language defaults differ per provider family.
        voice, language = _probe_audio_voice(provider=provider, model_id=model_id)
        audio_spec = OutputAudioSpec(voice=voice, language=language, format="mp3")
        return OutputSpec(modalities=["audio"], audio=audio_spec)
    if modality == "video":
        # Short clips keep the probe cheap; duration varies by model family.
        seconds = _probe_video_duration(provider=provider, model_id=model_id)
        video_spec = OutputVideoSpec(duration_sec=seconds, aspect_ratio="16:9")
        return OutputSpec(modalities=["video"], video=video_spec)
    raise SystemExit(f"unknown output modality: {modality}")
|
|
735
|
+
|
|
736
|
+
|
|
737
|
+
def _probe_audio_voice(*, provider: str, model_id: str) -> tuple[str, str | None]:
|
|
738
|
+
mid_l = model_id.lower().strip()
|
|
739
|
+
if (
|
|
740
|
+
provider in {"google", "tuzi-google"}
|
|
741
|
+
or mid_l.startswith(("gemini-", "gemma-"))
|
|
742
|
+
or "native-audio" in mid_l
|
|
743
|
+
):
|
|
744
|
+
return ("Kore", "en-US")
|
|
745
|
+
if provider == "aliyun":
|
|
746
|
+
return ("Cherry", "zh-CN")
|
|
747
|
+
return ("alloy", None)
|
|
748
|
+
|
|
749
|
+
|
|
750
|
+
def _probe_video_duration(*, provider: str, model_id: str) -> int:
|
|
751
|
+
mid_l = model_id.lower().strip()
|
|
752
|
+
if mid_l.startswith("sora-"):
|
|
753
|
+
return 4
|
|
754
|
+
if mid_l.startswith("veo-"):
|
|
755
|
+
return 5
|
|
756
|
+
if provider.startswith("tuzi"):
|
|
757
|
+
return 10
|
|
758
|
+
return 4
|
|
759
|
+
|
|
760
|
+
|
|
761
|
+
def _probe_image_part() -> Part:
    """Wrap the synthetic 64x64 red PNG as an inline image part."""
    return Part(
        type="image",
        mime_type="image/png",
        source=PartSourceBytes(data=_probe_png_bytes()),
    )
|
|
767
|
+
|
|
768
|
+
|
|
769
|
+
def _probe_png_bytes() -> bytes:
|
|
770
|
+
import struct
|
|
771
|
+
import zlib
|
|
772
|
+
|
|
773
|
+
width = 64
|
|
774
|
+
height = 64
|
|
775
|
+
|
|
776
|
+
# RGB red square (unfiltered scanlines).
|
|
777
|
+
row = b"\x00" + (b"\xff\x00\x00" * width)
|
|
778
|
+
raw = row * height
|
|
779
|
+
compressed = zlib.compress(raw, level=6)
|
|
780
|
+
|
|
781
|
+
def chunk(typ: bytes, data: bytes) -> bytes:
|
|
782
|
+
crc = zlib.crc32(typ)
|
|
783
|
+
crc = zlib.crc32(data, crc)
|
|
784
|
+
return (
|
|
785
|
+
struct.pack(">I", len(data))
|
|
786
|
+
+ typ
|
|
787
|
+
+ data
|
|
788
|
+
+ struct.pack(">I", crc & 0xFFFFFFFF)
|
|
789
|
+
)
|
|
790
|
+
|
|
791
|
+
signature = b"\x89PNG\r\n\x1a\n"
|
|
792
|
+
ihdr = struct.pack(">IIBBBBB", width, height, 8, 2, 0, 0, 0)
|
|
793
|
+
return (
|
|
794
|
+
signature
|
|
795
|
+
+ chunk(b"IHDR", ihdr)
|
|
796
|
+
+ chunk(b"IDAT", compressed)
|
|
797
|
+
+ chunk(b"IEND", b"")
|
|
798
|
+
)
|
|
799
|
+
|
|
800
|
+
|
|
801
|
+
def _probe_wav_bytes() -> bytes:
|
|
802
|
+
import io
|
|
803
|
+
import wave
|
|
804
|
+
|
|
805
|
+
buf = io.BytesIO()
|
|
806
|
+
with wave.open(buf, "wb") as wf:
|
|
807
|
+
wf.setnchannels(1)
|
|
808
|
+
wf.setsampwidth(2)
|
|
809
|
+
wf.setframerate(16_000)
|
|
810
|
+
wf.writeframes(b"\x00\x00" * 16_000) # 1s silence
|
|
811
|
+
return buf.getvalue()
|
|
812
|
+
|
|
813
|
+
|
|
814
|
+
def _validate_probe_response(resp, *, expected_out: str) -> None:
|
|
815
|
+
if resp.status == "running":
|
|
816
|
+
if not resp.job or not resp.job.job_id:
|
|
817
|
+
raise SystemExit("running response missing job info")
|
|
818
|
+
return
|
|
819
|
+
if resp.status != "completed":
|
|
820
|
+
raise SystemExit(f"unexpected status: {resp.status}")
|
|
821
|
+
if not resp.output:
|
|
822
|
+
raise SystemExit("missing output")
|
|
823
|
+
parts = [p for m in resp.output for p in m.content]
|
|
824
|
+
if expected_out == "text":
|
|
825
|
+
if not any(p.type == "text" and isinstance(p.text, str) for p in parts):
|
|
826
|
+
raise SystemExit("missing text output part")
|
|
827
|
+
return
|
|
828
|
+
if expected_out == "embedding":
|
|
829
|
+
if not any(
|
|
830
|
+
p.type == "embedding" and isinstance(p.embedding, list) for p in parts
|
|
831
|
+
):
|
|
832
|
+
raise SystemExit("missing embedding output part")
|
|
833
|
+
return
|
|
834
|
+
if expected_out in {"image", "audio", "video"}:
|
|
835
|
+
if not any(p.type == expected_out and p.source is not None for p in parts):
|
|
836
|
+
raise SystemExit(f"missing {expected_out} output part")
|
|
837
|
+
return
|
|
838
|
+
raise SystemExit(f"unknown expected output modality: {expected_out}")
|
|
839
|
+
|
|
840
|
+
|
|
841
|
+
def _probe_ok(label: str) -> dict[str, int]:
|
|
842
|
+
print(f"[OK] {label}")
|
|
843
|
+
return {"ok": 1, "fail": 0, "skip": 0}
|
|
844
|
+
|
|
845
|
+
|
|
846
|
+
def _probe_skip(label: str, reason: str) -> dict[str, int]:
|
|
847
|
+
print(f"[SKIP] {label}: {reason}")
|
|
848
|
+
return {"ok": 0, "fail": 0, "skip": 1}
|
|
849
|
+
|
|
850
|
+
|
|
851
|
+
def _probe_fail(label: str, err: BaseException) -> dict[str, int]:
    """Report a failing probe with a readable message and return its tally row."""
    detail = (
        f"{err.info.type}: {err.info.message}"
        if isinstance(err, GenAIError)
        # Fall back to the class name when the exception has no message text.
        else (str(err) or err.__class__.__name__)
    )
    print(f"[FAIL] {label}: {detail}")
    return {"ok": 0, "fail": 1, "skip": 0}
|
|
858
|
+
|
|
859
|
+
|
|
860
|
+
def _print_sdk_supported() -> None:
    """Print every SDK-supported model, grouped by provider.

    Each provider section lists rows sorted by (category, model_id) in
    fixed-width columns: category, model name, modes, input and output
    modalities.
    """
    models = get_sdk_supported_models()
    # Group the catalog rows by their "provider" key.
    by_provider: dict[str, list[dict]] = {}
    for m in models:
        by_provider.setdefault(m["provider"], []).append(m)

    for p in sorted(by_provider.keys()):
        print(f"== {p} ==")
        # Stable ordering inside a section: category first, then model id.
        for m in sorted(by_provider[p], key=lambda x: (x["category"], x["model_id"])):
            inp = ",".join(m["input_modalities"])
            out = ",".join(m["output_modalities"])
            modes = ",".join(m["modes"])
            print(
                f"{m['category']:13} {m['model']:45} modes={modes:18} in={inp:18} out={out:18}"
            )
        print()
|
|
876
|
+
|
|
877
|
+
|
|
878
|
+
def _print_provider_models(provider: str, *, timeout_ms: int | None) -> None:
    """Print the provider's remotely listed model ids, one per line, sorted."""
    from .client import _normalize_provider

    client = Client()
    normalized = _normalize_provider(provider)
    model_ids = client.list_provider_models(provider, timeout_ms=timeout_ms)
    for model_id in sorted(model_ids):
        print(f"{normalized}:{model_id}")
|
|
887
|
+
|
|
888
|
+
|
|
889
|
+
def _print_available_models(provider: str, *, timeout_ms: int | None) -> None:
    """Print the provider's available models with catalog metadata when known.

    Models absent from the SDK catalog are printed as bare
    "provider:model_id" lines with no metadata columns.
    """
    from .client import _normalize_provider
    from .reference import get_sdk_supported_models_for_provider

    client = Client()
    p = _normalize_provider(provider)
    rows = get_sdk_supported_models_for_provider(p)
    # Index catalog rows by model_id for O(1) lookups in the loop below.
    by_model_id = {m["model_id"]: m for m in rows}
    for model_id in client.list_available_models(provider, timeout_ms=timeout_ms):
        m = by_model_id.get(model_id)
        if m is None:
            # Not in the SDK catalog: nothing more to show than the id.
            print(f"{p}:{model_id}")
            continue
        inp = ",".join(m["input_modalities"])
        out = ",".join(m["output_modalities"])
        modes = ",".join(m["modes"])
        print(f"{m['model']:45} modes={modes:18} in={inp:18} out={out:18}")
|
|
906
|
+
|
|
907
|
+
|
|
908
|
+
def _print_all_available_models(*, timeout_ms: int | None) -> None:
    """Print every available model across providers, annotated when cataloged."""
    catalog_by_model: dict[str, dict] = {
        row["model"]: row for row in get_sdk_supported_models()
    }
    client = Client()
    for model in client.list_all_available_models(timeout_ms=timeout_ms):
        row = catalog_by_model.get(model)
        if row is None:
            # Unknown to the SDK catalog: print the bare name.
            print(model)
            continue
        modes = ",".join(row["modes"])
        inp = ",".join(row["input_modalities"])
        out = ",".join(row["output_modalities"])
        print(f"{model:45} modes={modes:18} in={inp:18} out={out:18}")
|
|
921
|
+
|
|
922
|
+
|
|
923
|
+
def _print_unsupported_models(provider: str, *, timeout_ms: int | None) -> None:
    """Print the provider's unsupported model ids, one per line."""
    from .client import _normalize_provider

    normalized = _normalize_provider(provider)
    client = Client()
    for model_id in client.list_unsupported_models(provider, timeout_ms=timeout_ms):
        print(f"{normalized}:{model_id}")
|
|
930
|
+
|
|
931
|
+
|
|
932
|
+
def _print_stale_models(provider: str, *, timeout_ms: int | None) -> None:
    """Print the provider's stale model ids, one per line."""
    from .client import _normalize_provider

    normalized = _normalize_provider(provider)
    client = Client()
    for model_id in client.list_stale_models(provider, timeout_ms=timeout_ms):
        print(f"{normalized}:{model_id}")
|
|
939
|
+
|
|
940
|
+
|
|
941
|
+
def _print_unsupported(*, timeout_ms: int | None) -> None:
    """For each cataloged provider, print remote models unknown to the catalog."""
    catalog = get_model_catalog()
    # Keep only non-empty string ids per provider for the set difference below.
    supported: dict[str, set[str]] = {
        provider: {m for m in model_ids if isinstance(m, str) and m}
        for provider, model_ids in catalog.items()
    }

    client = Client()
    for provider in sorted(catalog):
        remote = set(client.list_provider_models(provider, timeout_ms=timeout_ms))
        if not remote:
            continue
        unknown = sorted(remote - supported.get(provider, set()))
        if not unknown:
            continue
        print(f"== {provider} ==")
        for model_id in unknown:
            print(f"{provider}:{model_id}")
        print()
|
|
959
|
+
|
|
960
|
+
|
|
961
|
+
def _print_mappings() -> None:
    """Print parameter mappings grouped into (provider, protocol) sections.

    Rows are fully sorted first so a simple group-break loop can emit a
    section header whenever the (provider, protocol) pair changes.
    """
    items = get_parameter_mappings()
    items = sorted(
        items,
        key=lambda x: (
            x["provider"],
            x["protocol"],
            x["operation"],
            x["from"],
            x["to"],
        ),
    )
    # Group-break state: the (provider, protocol) pair of the current section.
    cur = None
    for m in items:
        key = (m["provider"], m["protocol"])
        if key != cur:
            # Blank line between sections, but not before the first one.
            if cur is not None:
                print()
            cur = key
            print(f"== {m['provider']} (protocol={m['protocol']}) ==")
        note = f" # {m['notes']}" if m.get("notes") else ""
        print(f"{m['operation']:14} {m['from']:55} -> {m['to']}{note}")
|
|
983
|
+
|
|
984
|
+
|
|
985
|
+
def _split_model(model: str) -> tuple[str, str]:
|
|
986
|
+
if ":" not in model:
|
|
987
|
+
raise SystemExit('model must be "{provider}:{model_id}"')
|
|
988
|
+
provider, model_id = model.split(":", 1)
|
|
989
|
+
provider = provider.strip().lower()
|
|
990
|
+
model_id = model_id.strip()
|
|
991
|
+
if not provider or not model_id:
|
|
992
|
+
raise SystemExit('model must be "{provider}:{model_id}"')
|
|
993
|
+
return provider, model_id
|
|
994
|
+
|
|
995
|
+
|
|
996
|
+
def _apply_protocol_override(
    client: Client, *, provider: str, protocol: str | None
) -> None:
    """Force the OpenAI adapter onto a specific chat protocol, if requested."""
    if not protocol:
        return
    if provider != "openai":
        raise SystemExit("--protocol only applies to provider=openai")
    adapter = client._openai
    if adapter is None:
        raise SystemExit("NOUS_GENAI_OPENAI_API_KEY/OPENAI_API_KEY not configured")
    # Swap in a copy of the adapter with the chat API overridden.
    client._openai = replace(adapter, chat_api=protocol)
|
|
1006
|
+
|
|
1007
|
+
|
|
1008
|
+
def _infer_output_spec(*, provider: str, model_id: str, cap) -> OutputSpec:
    """Infer a reasonable default OutputSpec from a model's capabilities.

    `cap.output_modalities` drives the choice; text is the fallback when a
    model outputs text alongside other modalities. Raises SystemExit when
    no supported modality is found.
    """
    out = set(cap.output_modalities)
    if out == {"embedding"}:
        return OutputSpec(modalities=["embedding"], embedding=OutputEmbeddingSpec())
    if out == {"image"}:
        return OutputSpec(modalities=["image"], image=OutputImageSpec(n=1))
    if out == {"audio"}:
        # Provider-specific default voices/formats for TTS.
        if provider == "google":
            audio = OutputAudioSpec(voice="Kore", language="en-US")
        elif provider == "aliyun":
            audio = OutputAudioSpec(voice="Cherry", language="zh-CN", format="wav")
        else:
            audio = OutputAudioSpec(voice="alloy", format="mp3")
        return OutputSpec(modalities=["audio"], audio=audio)
    if out == {"video"}:
        # Later checks override earlier ones: sora-/veo- beat the tuzi default.
        # NOTE(review): _probe_video_duration applies the veo- rule for any
        # provider, while this path requires provider == "google" — confirm
        # whether the divergence for tuzi-hosted veo models is intentional.
        duration = 4
        if provider.startswith("tuzi"):
            duration = 10
        if model_id.lower().startswith("sora-"):
            duration = 4
        if provider == "google" and model_id.lower().startswith("veo-"):
            duration = 5
        return OutputSpec(
            modalities=["video"],
            video=OutputVideoSpec(duration_sec=duration, aspect_ratio="16:9"),
        )
    if "text" in out:
        return OutputSpec(modalities=["text"], text=OutputTextSpec(format="text"))
    raise SystemExit(f"cannot infer output for model={provider}:{model_id}")
|
|
1037
|
+
|
|
1038
|
+
|
|
1039
|
+
def _build_input_parts(
    *,
    prompt: str | None,
    image_path: str | None,
    audio_path: str | None,
    video_path: str | None,
    input_modalities: set[str],
    output_modalities: set[str],
    provider: str,
    model_id: str,
) -> list[Part]:
    """Build request input parts from CLI flags, validating per output modality.

    Embedding/audio/video outputs accept only a text prompt and return early.
    Image output validates its flags but falls through to the generic part
    assembly below, so an input image can be attached for models that accept
    one. Raises SystemExit on any invalid flag combination or when no input
    was provided at all.
    """
    if output_modalities == {"embedding"}:
        if image_path or audio_path or video_path:
            raise SystemExit("embedding only supports text input")
        if not prompt:
            raise SystemExit("embedding requires --prompt")
        return [Part.from_text(prompt)]

    if output_modalities == {"image"}:
        if audio_path or video_path:
            raise SystemExit(
                "image generation does not take --audio-path/--video-path input"
            )
        # An input image is only allowed when the model accepts image input.
        if image_path and "image" not in input_modalities:
            raise SystemExit(
                "image generation does not take --image-path input for this model"
            )
        if not prompt:
            raise SystemExit("image generation requires --prompt")
        # No early return: continues to the generic assembly below.

    if output_modalities == {"audio"}:
        if image_path or audio_path or video_path:
            raise SystemExit(
                "TTS does not take --image-path/--audio-path/--video-path input"
            )
        if not prompt:
            raise SystemExit("audio generation requires --prompt")
        return [Part.from_text(prompt)]

    if output_modalities == {"video"}:
        if image_path or audio_path or video_path:
            raise SystemExit(
                "video generation does not take --image-path/--audio-path/--video-path input"
            )
        if not prompt:
            raise SystemExit("video generation requires --prompt")
        return [Part.from_text(prompt)]

    parts: list[Part] = []
    if prompt:
        meta = {}
        # OpenAI transcription models treat the prompt as a hint, not content;
        # the meta flag lets the adapter route it accordingly.
        if provider == "openai" and (
            model_id == "whisper-1" or "-transcribe" in model_id
        ):
            meta = {"transcription_prompt": True}
        parts.append(Part(type="text", text=prompt, meta=meta))

    # Each media flag becomes a path-sourced part; the mime type must both be
    # detectable and match the expected top-level category.
    if image_path:
        mime = detect_mime_type(image_path)
        if not mime or not mime.startswith("image/"):
            raise SystemExit(f"cannot detect image mime type for: {image_path}")
        parts.append(
            Part(type="image", mime_type=mime, source=PartSourcePath(path=image_path))
        )
    if audio_path:
        mime = detect_mime_type(audio_path)
        if not mime or not mime.startswith("audio/"):
            raise SystemExit(f"cannot detect audio mime type for: {audio_path}")
        parts.append(
            Part(type="audio", mime_type=mime, source=PartSourcePath(path=audio_path))
        )
    if video_path:
        mime = detect_mime_type(video_path)
        if not mime or not mime.startswith("video/"):
            raise SystemExit(f"cannot detect video mime type for: {video_path}")
        parts.append(
            Part(type="video", mime_type=mime, source=PartSourcePath(path=video_path))
        )

    if not parts:
        raise SystemExit(
            "missing input: provide --prompt and/or --image-path/--audio-path/--video-path"
        )
    return parts
|
|
1123
|
+
|
|
1124
|
+
|
|
1125
|
+
def _run_stream_text(
    client: Client, req: GenerateRequest, *, timeout_ms: int | None
) -> str:
    """Stream a generate request and return the concatenated text deltas."""
    if timeout_ms is not None:
        # Clone the request with the timeout applied; never mutate the caller's.
        req = replace(req, params=replace(req.params, timeout_ms=timeout_ms))
    pieces: list[str] = []
    for event in client.generate_stream(req):
        if event.type == "output.text.delta":
            delta = event.data.get("delta")
            if isinstance(delta, str) and delta:
                pieces.append(delta)
    return "".join(pieces)
|
|
1138
|
+
|
|
1139
|
+
|
|
1140
|
+
def _write_text(text: str, *, output_path: str | None) -> None:
|
|
1141
|
+
if output_path:
|
|
1142
|
+
with open(output_path, "w", encoding="utf-8") as f:
|
|
1143
|
+
f.write(text)
|
|
1144
|
+
print(f"[OK] wrote {output_path}")
|
|
1145
|
+
return
|
|
1146
|
+
print(text)
|
|
1147
|
+
|
|
1148
|
+
|
|
1149
|
+
def _guess_ext(mime: str | None) -> str:
|
|
1150
|
+
if not mime:
|
|
1151
|
+
return ""
|
|
1152
|
+
m = mime.lower()
|
|
1153
|
+
if m == "image/png":
|
|
1154
|
+
return ".png"
|
|
1155
|
+
if m in {"image/jpeg", "image/jpg"}:
|
|
1156
|
+
return ".jpg"
|
|
1157
|
+
if m == "image/webp":
|
|
1158
|
+
return ".webp"
|
|
1159
|
+
if m in {"audio/mpeg", "audio/mp3"}:
|
|
1160
|
+
return ".mp3"
|
|
1161
|
+
if m in {"audio/wav", "audio/wave"}:
|
|
1162
|
+
return ".wav"
|
|
1163
|
+
if m in {"audio/mp4", "audio/m4a"}:
|
|
1164
|
+
return ".m4a"
|
|
1165
|
+
if m == "video/mp4":
|
|
1166
|
+
return ".mp4"
|
|
1167
|
+
if m == "video/quicktime":
|
|
1168
|
+
return ".mov"
|
|
1169
|
+
return ""
|
|
1170
|
+
|
|
1171
|
+
|
|
1172
|
+
def _download_with_headers(
    url: str,
    output_path: str,
    *,
    timeout_ms: int | None,
    headers: dict[str, str] | None,
) -> None:
    """Download *url* to *output_path*, forwarding optional auth headers."""
    # No size cap (max_bytes=None): callers are saving full generated assets.
    download_to_file(
        url=url,
        output_path=output_path,
        timeout_ms=timeout_ms,
        max_bytes=None,
        headers=headers,
    )
|
|
1186
|
+
|
|
1187
|
+
|
|
1188
|
+
def _download_auth(
    client: object, *, provider: str
) -> tuple[dict[str, str], set[str]] | None:
    """Best-effort: collect the provider adapter's download auth headers.

    Returns (headers, allowed_hosts) where allowed_hosts is the single
    lowercased host of the adapter's base_url, or None whenever any step
    fails — missing adapter, missing/failing `_download_headers` hook,
    malformed header dict, or unusable base_url. Every access is guarded
    so this never raises.
    """
    # The client's private `_adapter(provider)` hook; duck-typed on purpose.
    adapter_getter = getattr(client, "_adapter", None)
    if not callable(adapter_getter):
        return None
    try:
        adapter = adapter_getter(provider)
    except Exception:
        return None
    header_fn = getattr(adapter, "_download_headers", None)
    if not callable(header_fn):
        return None
    try:
        raw = header_fn()
    except Exception:
        return None
    if not isinstance(raw, dict):
        return None
    # Keep only non-empty str->str entries.
    headers: dict[str, str] = {}
    for k, v in raw.items():
        if isinstance(k, str) and k and isinstance(v, str) and v:
            headers[k] = v
    if not headers:
        return None

    # Restrict the headers to the adapter's own host so credentials are
    # never sent to arbitrary URLs.
    base_url = getattr(adapter, "base_url", None)
    if not isinstance(base_url, str) or not base_url:
        return None
    host = urllib.parse.urlparse(base_url).hostname
    if not isinstance(host, str) or not host:
        return None
    return headers, {host.lower()}
|
|
1221
|
+
|
|
1222
|
+
|
|
1223
|
+
def _write_response(
    parts: list[Part],
    *,
    output: OutputSpec,
    output_path: str | None,
    timeout_ms: int | None,
    download_auth: tuple[dict[str, str], set[str]] | None,
) -> None:
    """Write a response's output parts to a file or stdout, per modality.

    Text prints or writes plainly; embeddings write JSON or print a short
    vector preview; binary modalities (image/audio/video) handle url, ref,
    and bytes sources. Raises SystemExit on any unsupported or malformed
    output.
    """
    modalities = set(output.modalities)
    if modalities == {"text"}:
        text = parts[0].text or ""
        _write_text(text, output_path=output_path)
        return

    if modalities == {"embedding"}:
        vec = parts[0].embedding or []
        if output_path:
            with open(output_path, "w", encoding="utf-8") as f:
                json.dump(vec, f, ensure_ascii=False)
            print(f"[OK] wrote {output_path} (dims={len(vec)})")
            return
        # No file requested: print dims and the first 8 components.
        head = ", ".join([f"{x:.6f}" for x in vec[:8]])
        print(f"dims={len(vec)} [{head}{', ...' if len(vec) > 8 else ''}]")
        return

    if modalities in ({"image"}, {"audio"}, {"video"}):
        # First part whose type matches the (single) requested modality.
        p = next((x for x in parts if x.type in modalities), None)
        if not p:
            raise SystemExit("missing binary output part")
        if not p.source:
            raise SystemExit("missing output source")
        if p.source.kind == "url":
            if output_path:
                # Only attach auth headers when the URL's host is in the
                # adapter-approved allow-list (avoids credential leakage).
                headers: dict[str, str] | None = None
                if download_auth is not None:
                    allowed = download_auth[1]
                    host = urllib.parse.urlparse(p.source.url).hostname
                    if isinstance(host, str) and host and host.lower() in allowed:
                        headers = download_auth[0]
                _download_with_headers(
                    p.source.url,
                    output_path,
                    timeout_ms=timeout_ms,
                    headers=headers,
                )
                print(f"[OK] downloaded to {output_path}")
            else:
                print(p.source.url)
            return
        if p.source.kind == "ref":
            # Refs are provider-side handles; we can only print them.
            if output_path:
                raise SystemExit(
                    "cannot write ref output; provider-specific download required"
                )
            ref = (
                f"{p.source.provider}:{p.source.id}"
                if p.source.provider
                else p.source.id
            )
            print(ref)
            return
        if p.source.kind != "bytes":
            raise SystemExit(f"unsupported output source kind: {p.source.kind}")
        data: bytes
        if p.source.encoding == "base64":
            raw_b64 = p.source.data
            if not isinstance(raw_b64, str) or not raw_b64:
                raise SystemExit("invalid base64 output") from None
            try:
                data = base64.b64decode(raw_b64)
            except Exception:
                raise SystemExit("invalid base64 output") from None
        else:
            raw = p.source.data
            # Accept bytearray by normalizing; reject anything else.
            if isinstance(raw, bytearray):
                raw = bytes(raw)
            if not isinstance(raw, bytes):
                raise SystemExit(f"invalid bytes output (encoding={p.source.encoding})")
            data = raw
        # Default filename derives its extension from the part's mime type.
        path = output_path or f"genai_output{_guess_ext(p.mime_type)}"
        with open(path, "wb") as f:
            f.write(data)
        print(f"[OK] wrote {path} ({p.mime_type}, {len(data)} bytes)")
        return

    raise SystemExit(f"unsupported output modalities: {output.modalities}")
|
|
1309
|
+
|
|
1310
|
+
|
|
1311
|
+
def _generate_token() -> str:
|
|
1312
|
+
return f"sk-{secrets.token_urlsafe(32)}"
|
|
1313
|
+
|
|
1314
|
+
|
|
1315
|
+
# Script entry point: delegate to the CLI's main() when run directly.
if __name__ == "__main__":
    main()
|