nous-genai 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45) hide show
  1. nous/__init__.py +3 -0
  2. nous/genai/__init__.py +56 -0
  3. nous/genai/__main__.py +3 -0
  4. nous/genai/_internal/__init__.py +1 -0
  5. nous/genai/_internal/capability_rules.py +476 -0
  6. nous/genai/_internal/config.py +102 -0
  7. nous/genai/_internal/errors.py +63 -0
  8. nous/genai/_internal/http.py +951 -0
  9. nous/genai/_internal/json_schema.py +54 -0
  10. nous/genai/cli.py +1316 -0
  11. nous/genai/client.py +719 -0
  12. nous/genai/mcp_cli.py +275 -0
  13. nous/genai/mcp_server.py +1080 -0
  14. nous/genai/providers/__init__.py +15 -0
  15. nous/genai/providers/aliyun.py +535 -0
  16. nous/genai/providers/anthropic.py +483 -0
  17. nous/genai/providers/gemini.py +1606 -0
  18. nous/genai/providers/openai.py +1909 -0
  19. nous/genai/providers/tuzi.py +1158 -0
  20. nous/genai/providers/volcengine.py +273 -0
  21. nous/genai/reference/__init__.py +17 -0
  22. nous/genai/reference/catalog.py +206 -0
  23. nous/genai/reference/mappings.py +467 -0
  24. nous/genai/reference/mode_overrides.py +26 -0
  25. nous/genai/reference/model_catalog.py +82 -0
  26. nous/genai/reference/model_catalog_data/__init__.py +1 -0
  27. nous/genai/reference/model_catalog_data/aliyun.py +98 -0
  28. nous/genai/reference/model_catalog_data/anthropic.py +10 -0
  29. nous/genai/reference/model_catalog_data/google.py +45 -0
  30. nous/genai/reference/model_catalog_data/openai.py +44 -0
  31. nous/genai/reference/model_catalog_data/tuzi_anthropic.py +21 -0
  32. nous/genai/reference/model_catalog_data/tuzi_google.py +19 -0
  33. nous/genai/reference/model_catalog_data/tuzi_openai.py +75 -0
  34. nous/genai/reference/model_catalog_data/tuzi_web.py +136 -0
  35. nous/genai/reference/model_catalog_data/volcengine.py +107 -0
  36. nous/genai/tools/__init__.py +13 -0
  37. nous/genai/tools/output_parser.py +119 -0
  38. nous/genai/types.py +416 -0
  39. nous/py.typed +1 -0
  40. nous_genai-0.1.0.dist-info/METADATA +200 -0
  41. nous_genai-0.1.0.dist-info/RECORD +45 -0
  42. nous_genai-0.1.0.dist-info/WHEEL +5 -0
  43. nous_genai-0.1.0.dist-info/entry_points.txt +4 -0
  44. nous_genai-0.1.0.dist-info/licenses/LICENSE +190 -0
  45. nous_genai-0.1.0.dist-info/top_level.txt +1 -0
nous/genai/cli.py ADDED
@@ -0,0 +1,1316 @@
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import base64
5
+ import json
6
+ import secrets
7
+ import sys
8
+ import threading
9
+ import time
10
+ import urllib.parse
11
+ from collections.abc import Callable
12
+ from dataclasses import replace
13
+ from typing import TypeVar
14
+
15
+ from .client import Client
16
+ from ._internal.errors import GenAIError
17
+ from ._internal.http import download_to_file
18
+ from .reference import (
19
+ get_model_catalog,
20
+ get_parameter_mappings,
21
+ get_sdk_supported_models,
22
+ )
23
+ from .types import (
24
+ GenerateRequest,
25
+ Message,
26
+ OutputAudioSpec,
27
+ OutputEmbeddingSpec,
28
+ OutputImageSpec,
29
+ OutputSpec,
30
+ OutputTextSpec,
31
+ OutputVideoSpec,
32
+ Part,
33
+ PartSourceBytes,
34
+ PartSourcePath,
35
+ PartSourceUrl,
36
+ detect_mime_type,
37
+ )
38
+
39
+
40
def main(argv: list[str] | None = None) -> None:
    """Entry point for the ``genai`` CLI.

    Builds the argument parser (subcommands: model/mapping/token/probe plus
    a default one-shot generate flow), dispatches to the matching handler,
    and exits via SystemExit on user errors.
    """
    parser = argparse.ArgumentParser(
        prog="genai", description="nous-genai CLI (minimal)"
    )
    sub = parser.add_subparsers(dest="command")

    # "model" family: catalog / remote listing / intersection / diff views.
    model = sub.add_parser("model", help="Model discovery and support listing")
    model_sub = model.add_subparsers(dest="model_command", required=True)

    model_sub.add_parser("sdk", help="List SDK curated models")

    pm = model_sub.add_parser(
        "provider", help="List remotely available model ids for a provider"
    )
    pm.add_argument(
        "--provider", required=True, help="Provider (e.g. openai/google/tuzi-openai)"
    )
    pm.add_argument("--timeout-ms", type=int, help="Timeout budget in milliseconds")

    av = model_sub.add_parser(
        "available",
        help="List available models (sdk ∩ provider) with capabilities",
    )
    av_scope = av.add_mutually_exclusive_group(required=True)
    av_scope.add_argument(
        "--provider", help="Provider (e.g. openai/google/tuzi-openai)"
    )
    av_scope.add_argument(
        "--all", action="store_true", help="List across all providers"
    )
    av.add_argument("--timeout-ms", type=int, help="Timeout budget in milliseconds")

    um = model_sub.add_parser(
        "unsupported",
        help="List provider-available but not in SDK catalog models",
    )
    um.add_argument("--provider", help="Provider (omit to scan all catalog providers)")
    um.add_argument("--timeout-ms", type=int, help="Timeout budget in milliseconds")

    st = model_sub.add_parser(
        "stale",
        help="List stale model ids (sdk catalog - provider) for a provider",
    )
    st.add_argument(
        "--provider", required=True, help="Provider (e.g. openai/google/tuzi-openai)"
    )
    st.add_argument("--timeout-ms", type=int, help="Timeout budget in milliseconds")

    sub.add_parser("mapping", help="Print parameter mapping table")

    token = sub.add_parser("token", help="Token utilities")
    token_sub = token.add_subparsers(dest="token_command", required=True)
    token_sub.add_parser("generate", help="Generate a new token (sk-...)")

    # Top-level flags drive the default (no subcommand) generate flow.
    parser.add_argument(
        "--model", default="openai:gpt-4o-mini", help='Model like "openai:gpt-4o-mini"'
    )
    parser.add_argument(
        "--protocol",
        choices=["chat_completions", "responses"],
        help='OpenAI chat protocol override ("chat_completions" or "responses")',
    )
    parser.add_argument("--prompt", help="Text prompt")
    parser.add_argument(
        "--prompt-path",
        help="Read prompt text from a file (lower priority than --prompt)",
    )
    parser.add_argument("--image-path", help="Input image file path")
    parser.add_argument("--audio-path", help="Input audio file path")
    parser.add_argument("--video-path", help="Input video file path")
    parser.add_argument("--output-path", help="Write output to file (text/json/binary)")
    # Hidden misspelled alias kept for backward compatibility.
    parser.add_argument("--ouput-path", dest="output_path", help=argparse.SUPPRESS)
    parser.add_argument(
        "--timeout-ms",
        type=int,
        help="Timeout budget in milliseconds (overrides NOUS_GENAI_TIMEOUT_MS)",
    )

    probe = sub.add_parser("probe", help="Probe model modalities/modes for a provider")
    probe.add_argument("--provider", required=True, help="Provider (e.g. tuzi-web)")
    probe.add_argument(
        "--model", help="Comma-separated model ids (or provider:model_id)"
    )
    probe.add_argument(
        "--all",
        action="store_true",
        help="Probe all SDK-supported models for the provider",
    )
    probe.add_argument(
        "--timeout-ms",
        type=int,
        help="Timeout budget in milliseconds (overrides NOUS_GENAI_TIMEOUT_MS)",
    )

    args = parser.parse_args(argv)
    timeout_ms: int | None = getattr(args, "timeout_ms", None)
    if timeout_ms is not None and timeout_ms < 1:
        raise SystemExit("--timeout-ms must be >= 1")

    # Listing subcommands tolerate a closed stdout pipe (e.g. `| head`).
    if args.command == "mapping":
        try:
            _print_mappings()
        except BrokenPipeError:
            return
        return

    if args.command == "token":
        cmd = str(getattr(args, "token_command", "") or "").strip()
        if cmd == "generate":
            print(_generate_token())
            return
        raise SystemExit(f"unknown token subcommand: {cmd}")

    if args.command == "model":
        try:
            cmd = str(getattr(args, "model_command", "") or "").strip()
            if cmd == "sdk":
                _print_sdk_supported()
                return
            if cmd == "provider":
                _print_provider_models(str(args.provider), timeout_ms=timeout_ms)
                return
            if cmd == "available":
                if bool(getattr(args, "all", False)):
                    _print_all_available_models(timeout_ms=timeout_ms)
                    return
                _print_available_models(str(args.provider), timeout_ms=timeout_ms)
                return
            if cmd == "unsupported":
                # With --provider: diff one provider; without: scan catalog.
                if getattr(args, "provider", None):
                    _print_unsupported_models(str(args.provider), timeout_ms=timeout_ms)
                    return
                _print_unsupported(timeout_ms=timeout_ms)
                return
            if cmd == "stale":
                _print_stale_models(str(args.provider), timeout_ms=timeout_ms)
                return
            raise SystemExit(f"unknown model subcommand: {cmd}")
        except BrokenPipeError:
            return

    if args.command == "probe":
        try:
            # _run_probe returns a process exit code (0 ok / 1 failures).
            raise SystemExit(_run_probe(args, timeout_ms=timeout_ms))
        except BrokenPipeError:
            return

    # Default flow: one-shot generate with the given model/prompt/media.
    provider, model_id = _split_model(args.model)
    prompt = args.prompt
    if prompt is None and args.prompt_path:
        try:
            with open(args.prompt_path, "r", encoding="utf-8") as f:
                prompt = f.read()
        except OSError as e:
            raise SystemExit(f"cannot read --prompt-path: {e}") from None
    client = Client()
    _apply_protocol_override(client, provider=provider, protocol=args.protocol)

    cap = client.capabilities(args.model)
    output = _infer_output_spec(provider=provider, model_id=model_id, cap=cap)

    parts = _build_input_parts(
        prompt=prompt,
        image_path=args.image_path,
        audio_path=args.audio_path,
        video_path=args.video_path,
        input_modalities=set(cap.input_modalities),
        output_modalities=set(output.modalities),
        provider=provider,
        model_id=model_id,
    )
    req = GenerateRequest(
        model=args.model,
        input=[Message(role="user", content=parts)],
        output=output,
        wait=True,
    )
    if timeout_ms is not None:
        req = replace(req, params=replace(req.params, timeout_ms=timeout_ms))

    try:
        # Show a spinner only for job-capable models on an interactive tty.
        wait_spinner = bool(req.wait) and bool(getattr(cap, "supports_job", False))
        show_progress = wait_spinner and sys.stderr.isatty()
        resp, elapsed_s = _run_with_spinner(
            lambda: client.generate(req),
            enabled=show_progress,
            label="等待任务完成",
        )
        if resp.status != "completed":
            # Surface the job id on stdout so the user can resume later.
            if resp.job and resp.job.job_id:
                print(resp.job.job_id)
            if resp.status == "running":
                effective_timeout_ms = timeout_ms
                if effective_timeout_ms is None:
                    effective_timeout_ms = getattr(
                        client, "_default_timeout_ms", None
                    )
                timeout_note = (
                    f"{effective_timeout_ms}ms"
                    if isinstance(effective_timeout_ms, int)
                    else "timeout"
                )
                print(
                    f"[INFO] 任务仍在运行(等待 {elapsed_s:.1f}s,可能已超时 {timeout_note});已返回 job_id。"
                    "可增大 --timeout-ms 或设置 NOUS_GENAI_TIMEOUT_MS 后重试。",
                    file=sys.stderr,
                )
                if args.output_path:
                    print(
                        f"[INFO] 未写入输出文件:{args.output_path}",
                        file=sys.stderr,
                    )
                return
            raise SystemExit(f"[FAIL]: request status={resp.status}")
        if not resp.output:
            raise SystemExit("[FAIL]: missing output")
        _write_response(
            resp.output[0].content,
            output=output,
            output_path=args.output_path,
            timeout_ms=timeout_ms,
            download_auth=_download_auth(client, provider=provider),
        )
        if show_progress:
            print(f"[INFO] 完成,用时 {elapsed_s:.1f}s", file=sys.stderr)
    except GenAIError as e:
        # Compact one-line failure with provider code / retryability hints.
        code = f" ({e.info.provider_code})" if e.info.provider_code else ""
        retryable = " retryable" if e.info.retryable else ""
        raise SystemExit(
            f"[FAIL]{code}{retryable}: {e.info.type}: {e.info.message}"
        ) from None
271
+
272
+
273
# CC0-licensed sample clip used as the default video input when probing
# video-understanding models (avoids bundling a media fixture).
_DEFAULT_VIDEO_URL = (
    "https://interactive-examples.mdn.mozilla.net/media/cc0-videos/flower.mp4"
)
276
+
277
+
278
def _run_probe(args: argparse.Namespace, *, timeout_ms: int | None) -> int:
    """Probe one provider's models and print per-model results.

    Model selection is either explicit (--model, comma-separated) or the
    whole SDK catalog for the provider (--all) — exactly one must be given.
    Returns a process exit code: 0 when no probe failed, else 1.
    """
    from .client import _normalize_provider
    from .reference import get_sdk_supported_models_for_provider

    provider = _normalize_provider(str(args.provider))
    if not provider:
        raise SystemExit("--provider must be non-empty")

    # XOR check: exactly one of --model / --all.
    if bool(args.all) == bool(args.model):
        raise SystemExit('probe requires exactly one of: "--model" or "--all"')

    if args.all:
        rows = get_sdk_supported_models_for_provider(provider)
        model_ids = [
            str(r["model_id"])
            for r in rows
            if isinstance(r, dict) and isinstance(r.get("model_id"), str)
        ]
    else:
        model_ids = _parse_probe_models(provider, str(args.model))

    if not model_ids:
        raise SystemExit(f"no models to probe for provider={provider}")

    client = Client()
    totals = {"ok": 0, "fail": 0, "skip": 0}

    for model_id in model_ids:
        model = f"{provider}:{model_id}"
        print(f"== {model} ==")
        try:
            cap = client.capabilities(model)
        except GenAIError as e:
            # Capabilities lookup failure counts as one probe failure.
            print(f"[FAIL] capabilities: {e.info.type}: {e.info.message}")
            totals["fail"] += 1
            print()
            continue

        modes = _probe_modes_for(cap)
        print(
            "declared:"
            f" modes={','.join(modes)}"
            f" in={','.join(sorted(cap.input_modalities))}"
            f" out={','.join(sorted(cap.output_modalities))}"
        )

        results = _probe_model(
            client,
            provider=provider,
            model_id=model_id,
            cap=cap,
            timeout_ms=timeout_ms,
        )
        for k in totals.keys():
            totals[k] += results[k]

        status = "OK" if results["fail"] == 0 else "FAIL"
        print(
            f"result: {status} (ok={results['ok']} fail={results['fail']} skip={results['skip']})"
        )
        print()

    print(f"[SUMMARY] ok={totals['ok']} fail={totals['fail']} skip={totals['skip']}")
    return 0 if totals["fail"] == 0 else 1
342
+
343
+
344
+ _T = TypeVar("_T")
345
+
346
+
347
+ def _run_with_spinner(
348
+ fn: Callable[[], _T], *, enabled: bool, label: str
349
+ ) -> tuple[_T, float]:
350
+ start = time.perf_counter()
351
+ if not enabled or not sys.stderr.isatty():
352
+ out = fn()
353
+ return out, time.perf_counter() - start
354
+
355
+ done = threading.Event()
356
+ result: dict[str, _T] = {}
357
+ error: dict[str, BaseException] = {}
358
+
359
+ def _worker() -> None:
360
+ try:
361
+ result["value"] = fn()
362
+ except BaseException as e: # noqa: BLE001
363
+ error["exc"] = e
364
+ finally:
365
+ done.set()
366
+
367
+ t = threading.Thread(target=_worker, name="genai-cli-wait", daemon=True)
368
+ t.start()
369
+
370
+ if done.wait(0.25):
371
+ t.join()
372
+ exc = error.get("exc")
373
+ if exc is not None:
374
+ raise exc
375
+ if "value" not in result:
376
+ raise RuntimeError("missing result value")
377
+ return result["value"], time.perf_counter() - start
378
+
379
+ frames = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]
380
+ i = 0
381
+ try:
382
+ while not done.wait(0.1):
383
+ frame = frames[i % len(frames)]
384
+ i += 1
385
+ elapsed = time.perf_counter() - start
386
+ sys.stderr.write(f"\r{frame} {label}... {elapsed:5.1f}s")
387
+ sys.stderr.flush()
388
+ finally:
389
+ sys.stderr.write("\r" + (" " * 64) + "\r")
390
+ sys.stderr.flush()
391
+ t.join()
392
+
393
+ exc = error.get("exc")
394
+ if exc is not None:
395
+ raise exc
396
+ if "value" not in result:
397
+ raise RuntimeError("missing result value")
398
+ return result["value"], time.perf_counter() - start
399
+
400
+
401
def _parse_probe_models(provider: str, value: str) -> list[str]:
    """Parse a comma-separated model list into unique, ordered model ids.

    Entries may be bare ids or "provider:model_id"; a prefixed provider must
    normalize to the same value as *provider*, otherwise we abort.  Empty
    entries and duplicates are dropped; first-seen order is preserved.
    """
    from .client import _normalize_provider

    expected = _normalize_provider(provider)
    ordered: list[str] = []
    known: set[str] = set()
    for token in (chunk.strip() for chunk in value.split(",")):
        if not token:
            continue
        if ":" in token:
            prefix, _, rest = token.partition(":")
            prefix = _normalize_provider(prefix)
            if prefix != expected:
                raise SystemExit(
                    f"model provider mismatch: expected {expected}, got {prefix}"
                )
            token = rest.strip()
            if not token:
                continue
        if token in known:
            continue
        known.add(token)
        ordered.append(token)
    return ordered
425
+
426
+
427
+ def _probe_modes_for(cap) -> list[str]:
428
+ modes: list[str] = ["sync"]
429
+ if cap.supports_stream:
430
+ modes.append("stream")
431
+ if cap.supports_job:
432
+ modes.append("job")
433
+ modes.append("async")
434
+ return modes
435
+
436
+
437
def _probe_model(
    client: Client,
    *,
    provider: str,
    model_id: str,
    cap,
    timeout_ms: int | None,
) -> dict[str, int]:
    """Exercise each declared output/input modality (plus streaming) once.

    Returns accumulated {"ok", "fail", "skip"} counters for the model.
    """
    tally = {"ok": 0, "fail": 0, "skip": 0}

    # Output modalities.  The first probe of a job-capable model doubles as
    # the job-mode check (probe_job), verifying a non-wait call returns a job.
    for idx, modality in enumerate(sorted(cap.output_modalities)):
        outcome = _probe_output_modality(
            client,
            provider=provider,
            model_id=model_id,
            cap=cap,
            out_modality=modality,
            timeout_ms=timeout_ms,
            probe_job=bool(cap.supports_job) and idx == 0,
        )
        _accumulate(tally, outcome)

    # Non-text inputs; plain text input is already covered above.
    for modality in sorted(set(cap.input_modalities) - {"text"}):
        outcome = _probe_input_modality(
            client,
            provider=provider,
            model_id=model_id,
            cap=cap,
            in_modality=modality,
            timeout_ms=timeout_ms,
        )
        _accumulate(tally, outcome)

    if cap.supports_stream:
        outcome = _probe_stream_mode(
            client,
            provider=provider,
            model_id=model_id,
            cap=cap,
            timeout_ms=timeout_ms,
        )
        _accumulate(tally, outcome)

    return tally
482
+
483
+
484
+ def _accumulate(totals: dict[str, int], outcome: dict[str, int]) -> None:
485
+ for k, v in outcome.items():
486
+ totals[k] += v
487
+
488
+
489
def _probe_output_modality(
    client: Client,
    *,
    provider: str,
    model_id: str,
    cap,
    out_modality: str,
    timeout_ms: int | None,
    probe_job: bool,
) -> dict[str, int]:
    """Probe one declared output modality; optionally verify job mode too.

    When *probe_job* is set, a "running" response also scores mode:job as
    OK; any other terminal status records mode:job as skipped.
    """
    label = f"output:{out_modality}"
    try:
        request = _build_probe_request(
            provider=provider,
            model_id=model_id,
            cap=cap,
            out_modality=out_modality,
            in_modality=None,
            timeout_ms=timeout_ms,
            # Job-capable models are probed without waiting so the job path
            # itself is exercised.
            force_wait=False if cap.supports_job else None,
        )
        resp = client.generate(request)
        _validate_probe_response(resp, expected_out=out_modality)
    except (SystemExit, Exception) as e:  # SystemExit is not an Exception
        return _probe_fail(label, e)

    counters = _probe_ok(label)
    if not probe_job:
        return counters
    if resp.status == "running":
        _accumulate(counters, _probe_ok("mode:job"))
    else:
        _accumulate(counters, _probe_skip("mode:job", f"response status={resp.status}"))
    return counters
526
+
527
+
528
def _probe_input_modality(
    client: Client,
    *,
    provider: str,
    model_id: str,
    cap,
    in_modality: str,
    timeout_ms: int | None,
) -> dict[str, int]:
    """Probe one input modality, preferring text output for validation."""
    declared_outputs = set(cap.output_modalities)
    if "text" in declared_outputs:
        target_out = "text"
    else:
        target_out = sorted(cap.output_modalities)[0]

    label = f"input:{in_modality}"
    try:
        request = _build_probe_request(
            provider=provider,
            model_id=model_id,
            cap=cap,
            out_modality=target_out,
            in_modality=in_modality,
            timeout_ms=timeout_ms,
            force_wait=False if cap.supports_job else None,
        )
        resp = client.generate(request)
        _validate_probe_response(resp, expected_out=target_out)
    except (SystemExit, Exception) as e:  # SystemExit is not an Exception
        return _probe_fail(label, e)
    return _probe_ok(label)
562
+
563
+
564
def _probe_stream_mode(
    client: Client,
    *,
    provider: str,
    model_id: str,
    cap,
    timeout_ms: int | None,
) -> dict[str, int]:
    """Probe streaming: request text output and expect at least one delta."""
    label = "mode:stream"
    if "text" not in set(cap.output_modalities):
        return _probe_skip(label, "stream probe requires text output")
    try:
        request = _build_probe_request(
            provider=provider,
            model_id=model_id,
            cap=cap,
            out_modality="text",
            in_modality=None,
            timeout_ms=timeout_ms,
        )
        saw_delta = False
        for event in client.generate_stream(request):
            if event.type == "output.text.delta":
                chunk = event.data.get("delta")
                if isinstance(chunk, str) and chunk:
                    # One non-empty delta is enough; stop consuming.
                    saw_delta = True
                    break
            if event.type == "done":
                break
        if not saw_delta:
            raise SystemExit("no output.text.delta received")
    except (SystemExit, Exception) as e:  # SystemExit is not an Exception
        return _probe_fail(label, e)
    return _probe_ok(label)
602
+
603
+
604
def _build_probe_request(
    *,
    provider: str,
    model_id: str,
    cap,
    out_modality: str,
    in_modality: str | None,
    timeout_ms: int | None,
    force_wait: bool | None = None,
) -> GenerateRequest:
    """Assemble a minimal GenerateRequest for one probe combination.

    *force_wait* overrides the wait flag; otherwise we wait synchronously
    for everything except video generation, which is expected to be a job.
    """
    if force_wait is not None:
        wait = force_wait
    else:
        wait = out_modality != "video"

    # Build the output spec first so an unknown out_modality fails before
    # input-part construction (matching the original error ordering).
    output = _output_spec_for_modality(
        provider=provider, model_id=model_id, modality=out_modality
    )
    parts = _probe_input_parts(
        provider=provider,
        model_id=model_id,
        cap=cap,
        out_modality=out_modality,
        in_modality=in_modality,
    )
    request = GenerateRequest(
        model=f"{provider}:{model_id}",
        input=[Message(role="user", content=parts)],
        output=output,
        wait=wait,
    )
    if timeout_ms is None:
        return request
    return replace(request, params=replace(request.params, timeout_ms=timeout_ms))
640
+
641
+
642
def _probe_input_parts(
    *,
    provider: str,
    model_id: str,
    cap,
    out_modality: str,
    in_modality: str | None,
) -> list[Part]:
    """Build the user-message parts for one probe combination.

    Raises SystemExit for an unrecognized *in_modality*.
    """
    prompt = _probe_prompt(out_modality=out_modality, in_modality=in_modality)

    # Embedding probes only ever need raw text.
    if out_modality == "embedding":
        return [Part.from_text(prompt)]

    def _silence_part() -> Part:
        return Part(
            type="audio",
            mime_type="audio/wav",
            source=PartSourceBytes(data=_probe_wav_bytes()),
        )

    if in_modality is None:
        # Audio-only models cannot take a text prompt; feed them silence.
        if set(cap.input_modalities) == {"audio"}:
            return [_silence_part()]
        return [Part.from_text(prompt)]

    if in_modality == "text":
        return [Part.from_text(prompt)]

    if in_modality == "image":
        return [Part.from_text(prompt), _probe_image_part()]

    if in_modality == "audio":
        # Audio-in/text-out models get the prompt tagged as a transcription
        # instruction so providers can treat it accordingly.
        if set(cap.input_modalities) == {"audio"} and out_modality == "text":
            meta: dict[str, object] = {"transcription_prompt": True}
        else:
            meta = {}
        parts = [_silence_part()]
        if prompt:
            parts.insert(0, Part(type="text", text=prompt, meta=meta))
        return parts

    if in_modality == "video":
        return [
            Part.from_text(prompt),
            Part(
                type="video",
                mime_type="video/mp4",
                source=PartSourceUrl(url=_DEFAULT_VIDEO_URL),
            ),
        ]

    raise SystemExit(f"unknown input modality: {in_modality}")
693
+
694
+
695
+ def _probe_prompt(*, out_modality: str, in_modality: str | None) -> str:
696
+ if out_modality == "embedding":
697
+ return "hello"
698
+ if out_modality == "image":
699
+ return "生成一张简单的红色方块图。"
700
+ if out_modality == "audio":
701
+ return "用中文说:你好。"
702
+ if out_modality == "video":
703
+ return "生成一个简单短视频:一只猫在草地上奔跑。"
704
+ if in_modality == "image":
705
+ return "请用一句话描述这张图。"
706
+ if in_modality == "audio":
707
+ return "请转写音频内容。"
708
+ if in_modality == "video":
709
+ return "请用一句话描述视频内容。"
710
+ return "只回复:pong"
711
+
712
+
713
def _output_spec_for_modality(
    *, provider: str, model_id: str, modality: str
) -> OutputSpec:
    """Build an OutputSpec requesting exactly one output *modality*.

    Raises SystemExit for an unrecognized modality.
    """
    if modality == "text":
        return OutputSpec(modalities=["text"], text=OutputTextSpec(format="text"))
    if modality == "embedding":
        return OutputSpec(modalities=["embedding"], embedding=OutputEmbeddingSpec())
    if modality == "image":
        return OutputSpec(modalities=["image"], image=OutputImageSpec(n=1))
    if modality == "audio":
        # Voice/language must match what the model family accepts.
        voice, language = _probe_audio_voice(provider=provider, model_id=model_id)
        audio = OutputAudioSpec(voice=voice, language=language, format="mp3")
        return OutputSpec(modalities=["audio"], audio=audio)
    if modality == "video":
        seconds = _probe_video_duration(provider=provider, model_id=model_id)
        video = OutputVideoSpec(duration_sec=seconds, aspect_ratio="16:9")
        return OutputSpec(modalities=["video"], video=video)
    raise SystemExit(f"unknown output modality: {modality}")
735
+
736
+
737
+ def _probe_audio_voice(*, provider: str, model_id: str) -> tuple[str, str | None]:
738
+ mid_l = model_id.lower().strip()
739
+ if (
740
+ provider in {"google", "tuzi-google"}
741
+ or mid_l.startswith(("gemini-", "gemma-"))
742
+ or "native-audio" in mid_l
743
+ ):
744
+ return ("Kore", "en-US")
745
+ if provider == "aliyun":
746
+ return ("Cherry", "zh-CN")
747
+ return ("alloy", None)
748
+
749
+
750
+ def _probe_video_duration(*, provider: str, model_id: str) -> int:
751
+ mid_l = model_id.lower().strip()
752
+ if mid_l.startswith("sora-"):
753
+ return 4
754
+ if mid_l.startswith("veo-"):
755
+ return 5
756
+ if provider.startswith("tuzi"):
757
+ return 10
758
+ return 4
759
+
760
+
761
def _probe_image_part() -> Part:
    """Wrap the synthetic red-square PNG in an image Part."""
    png = _probe_png_bytes()
    return Part(type="image", mime_type="image/png", source=PartSourceBytes(data=png))
767
+
768
+
769
+ def _probe_png_bytes() -> bytes:
770
+ import struct
771
+ import zlib
772
+
773
+ width = 64
774
+ height = 64
775
+
776
+ # RGB red square (unfiltered scanlines).
777
+ row = b"\x00" + (b"\xff\x00\x00" * width)
778
+ raw = row * height
779
+ compressed = zlib.compress(raw, level=6)
780
+
781
+ def chunk(typ: bytes, data: bytes) -> bytes:
782
+ crc = zlib.crc32(typ)
783
+ crc = zlib.crc32(data, crc)
784
+ return (
785
+ struct.pack(">I", len(data))
786
+ + typ
787
+ + data
788
+ + struct.pack(">I", crc & 0xFFFFFFFF)
789
+ )
790
+
791
+ signature = b"\x89PNG\r\n\x1a\n"
792
+ ihdr = struct.pack(">IIBBBBB", width, height, 8, 2, 0, 0, 0)
793
+ return (
794
+ signature
795
+ + chunk(b"IHDR", ihdr)
796
+ + chunk(b"IDAT", compressed)
797
+ + chunk(b"IEND", b"")
798
+ )
799
+
800
+
801
+ def _probe_wav_bytes() -> bytes:
802
+ import io
803
+ import wave
804
+
805
+ buf = io.BytesIO()
806
+ with wave.open(buf, "wb") as wf:
807
+ wf.setnchannels(1)
808
+ wf.setsampwidth(2)
809
+ wf.setframerate(16_000)
810
+ wf.writeframes(b"\x00\x00" * 16_000) # 1s silence
811
+ return buf.getvalue()
812
+
813
+
814
+ def _validate_probe_response(resp, *, expected_out: str) -> None:
815
+ if resp.status == "running":
816
+ if not resp.job or not resp.job.job_id:
817
+ raise SystemExit("running response missing job info")
818
+ return
819
+ if resp.status != "completed":
820
+ raise SystemExit(f"unexpected status: {resp.status}")
821
+ if not resp.output:
822
+ raise SystemExit("missing output")
823
+ parts = [p for m in resp.output for p in m.content]
824
+ if expected_out == "text":
825
+ if not any(p.type == "text" and isinstance(p.text, str) for p in parts):
826
+ raise SystemExit("missing text output part")
827
+ return
828
+ if expected_out == "embedding":
829
+ if not any(
830
+ p.type == "embedding" and isinstance(p.embedding, list) for p in parts
831
+ ):
832
+ raise SystemExit("missing embedding output part")
833
+ return
834
+ if expected_out in {"image", "audio", "video"}:
835
+ if not any(p.type == expected_out and p.source is not None for p in parts):
836
+ raise SystemExit(f"missing {expected_out} output part")
837
+ return
838
+ raise SystemExit(f"unknown expected output modality: {expected_out}")
839
+
840
+
841
+ def _probe_ok(label: str) -> dict[str, int]:
842
+ print(f"[OK] {label}")
843
+ return {"ok": 1, "fail": 0, "skip": 0}
844
+
845
+
846
+ def _probe_skip(label: str, reason: str) -> dict[str, int]:
847
+ print(f"[SKIP] {label}: {reason}")
848
+ return {"ok": 0, "fail": 0, "skip": 1}
849
+
850
+
851
def _probe_fail(label: str, err: BaseException) -> dict[str, int]:
    """Report a failing probe and return its unit counter.

    GenAIError gets its structured type/message; anything else falls back to
    str(err), or the class name when the message is empty.
    """
    if isinstance(err, GenAIError):
        detail = f"{err.info.type}: {err.info.message}"
    else:
        detail = str(err) or err.__class__.__name__
    print(f"[FAIL] {label}: {detail}")
    return dict(ok=0, fail=1, skip=0)
858
+
859
+
860
def _print_sdk_supported() -> None:
    """Print the curated SDK model catalog, grouped by provider."""
    grouped: dict[str, list[dict]] = {}
    for entry in get_sdk_supported_models():
        grouped.setdefault(entry["provider"], []).append(entry)

    for provider in sorted(grouped):
        print(f"== {provider} ==")
        ordered = sorted(grouped[provider], key=lambda e: (e["category"], e["model_id"]))
        for entry in ordered:
            ins = ",".join(entry["input_modalities"])
            outs = ",".join(entry["output_modalities"])
            modes = ",".join(entry["modes"])
            print(
                f"{entry['category']:13} {entry['model']:45} modes={modes:18} in={ins:18} out={outs:18}"
            )
        print()
876
+
877
+
878
def _print_provider_models(provider: str, *, timeout_ms: int | None) -> None:
    """Print "provider:model_id" for every model the provider lists remotely."""
    from .client import _normalize_provider

    normalized = _normalize_provider(provider)
    client = Client()
    remote = client.list_provider_models(provider, timeout_ms=timeout_ms)
    for model_id in sorted(remote):
        print(f"{normalized}:{model_id}")
887
+
888
+
889
def _print_available_models(provider: str, *, timeout_ms: int | None) -> None:
    """Print the SDK∩provider models for one provider with capability columns.

    Models without a catalog row fall back to a bare "provider:model_id" line.
    """
    from .client import _normalize_provider
    from .reference import get_sdk_supported_models_for_provider

    client = Client()
    normalized = _normalize_provider(provider)
    catalog = {
        row["model_id"]: row
        for row in get_sdk_supported_models_for_provider(normalized)
    }
    for model_id in client.list_available_models(provider, timeout_ms=timeout_ms):
        row = catalog.get(model_id)
        if row is None:
            print(f"{normalized}:{model_id}")
            continue
        ins = ",".join(row["input_modalities"])
        outs = ",".join(row["output_modalities"])
        modes = ",".join(row["modes"])
        print(f"{row['model']:45} modes={modes:18} in={ins:18} out={outs:18}")
906
+
907
+
908
def _print_all_available_models(*, timeout_ms: int | None) -> None:
    """Print available models across all providers with capability columns.

    Models without a catalog entry print as a bare model string.
    """
    catalog: dict[str, dict] = {m["model"]: m for m in get_sdk_supported_models()}
    client = Client()
    for model in client.list_all_available_models(timeout_ms=timeout_ms):
        row = catalog.get(model)
        if row is None:
            print(model)
            continue
        ins = ",".join(row["input_modalities"])
        outs = ",".join(row["output_modalities"])
        modes = ",".join(row["modes"])
        print(f"{model:45} modes={modes:18} in={ins:18} out={outs:18}")
921
+
922
+
923
def _print_unsupported_models(provider: str, *, timeout_ms: int | None) -> None:
    """Print provider-listed models that the SDK catalog does not cover."""
    from .client import _normalize_provider

    normalized = _normalize_provider(provider)
    client = Client()
    unsupported = client.list_unsupported_models(provider, timeout_ms=timeout_ms)
    for model_id in unsupported:
        print(f"{normalized}:{model_id}")
930
+
931
+
932
def _print_stale_models(provider: str, *, timeout_ms: int | None) -> None:
    """Print SDK-catalog models the provider no longer lists remotely."""
    from .client import _normalize_provider

    normalized = _normalize_provider(provider)
    client = Client()
    stale = client.list_stale_models(provider, timeout_ms=timeout_ms)
    for model_id in stale:
        print(f"{normalized}:{model_id}")
939
+
940
+
941
def _print_unsupported(*, timeout_ms: int | None) -> None:
    """Scan every catalog provider and print remotely listed unknown models."""
    catalog = get_model_catalog()
    supported: dict[str, set[str]] = {
        provider: {m for m in model_ids if isinstance(m, str) and m}
        for provider, model_ids in catalog.items()
    }

    client = Client()
    for provider in sorted(catalog.keys()):
        remote = set(client.list_provider_models(provider, timeout_ms=timeout_ms))
        if not remote:
            # Unreachable or empty providers are silently skipped.
            continue
        unknown = sorted(remote - supported.get(provider, set()))
        if not unknown:
            continue
        print(f"== {provider} ==")
        for model_id in unknown:
            print(f"{provider}:{model_id}")
        print()
959
+
960
+
961
def _print_mappings() -> None:
    """Print parameter mappings grouped by (provider, protocol)."""
    ordered = sorted(
        get_parameter_mappings(),
        key=lambda m: (
            m["provider"],
            m["protocol"],
            m["operation"],
            m["from"],
            m["to"],
        ),
    )
    previous_group = None
    for mapping in ordered:
        group = (mapping["provider"], mapping["protocol"])
        if group != previous_group:
            # Blank line between groups, header at the start of each.
            if previous_group is not None:
                print()
            previous_group = group
            print(f"== {mapping['provider']} (protocol={mapping['protocol']}) ==")
        note = f" # {mapping['notes']}" if mapping.get("notes") else ""
        print(f"{mapping['operation']:14} {mapping['from']:55} -> {mapping['to']}{note}")
983
+
984
+
985
+ def _split_model(model: str) -> tuple[str, str]:
986
+ if ":" not in model:
987
+ raise SystemExit('model must be "{provider}:{model_id}"')
988
+ provider, model_id = model.split(":", 1)
989
+ provider = provider.strip().lower()
990
+ model_id = model_id.strip()
991
+ if not provider or not model_id:
992
+ raise SystemExit('model must be "{provider}:{model_id}"')
993
+ return provider, model_id
994
+
995
+
996
def _apply_protocol_override(
    client: Client, *, provider: str, protocol: str | None
) -> None:
    """Force the OpenAI adapter's chat protocol; no-op when *protocol* is unset.

    Only valid for provider "openai"; exits when the adapter is not
    configured via its API-key environment variables.
    """
    if not protocol:
        return
    if provider != "openai":
        raise SystemExit("--protocol only applies to provider=openai")
    adapter = client._openai
    if adapter is None:
        raise SystemExit("NOUS_GENAI_OPENAI_API_KEY/OPENAI_API_KEY not configured")
    # Adapter is an immutable dataclass; swap in a copy with the new protocol.
    client._openai = replace(adapter, chat_api=protocol)
1006
+
1007
+
1008
def _infer_output_spec(*, provider: str, model_id: str, cap) -> OutputSpec:
    """Build a default OutputSpec from a model's declared output modalities.

    Exits with an error when no output modality can be inferred from *cap*.
    """
    outputs = set(cap.output_modalities)

    if outputs == {"embedding"}:
        return OutputSpec(modalities=["embedding"], embedding=OutputEmbeddingSpec())

    if outputs == {"image"}:
        return OutputSpec(modalities=["image"], image=OutputImageSpec(n=1))

    if outputs == {"audio"}:
        # Provider-specific default voices; OpenAI-style "alloy" otherwise.
        if provider == "google":
            audio_spec = OutputAudioSpec(voice="Kore", language="en-US")
        elif provider == "aliyun":
            audio_spec = OutputAudioSpec(voice="Cherry", language="zh-CN", format="wav")
        else:
            audio_spec = OutputAudioSpec(voice="alloy", format="mp3")
        return OutputSpec(modalities=["audio"], audio=audio_spec)

    if outputs == {"video"}:
        duration = 4
        if provider.startswith("tuzi"):
            # tuzi defaults to 10s clips, except sora-* models which cap at 4s.
            duration = 4 if model_id.lower().startswith("sora-") else 10
        if provider == "google" and model_id.lower().startswith("veo-"):
            duration = 5
        return OutputSpec(
            modalities=["video"],
            video=OutputVideoSpec(duration_sec=duration, aspect_ratio="16:9"),
        )

    if "text" in outputs:
        return OutputSpec(modalities=["text"], text=OutputTextSpec(format="text"))

    raise SystemExit(f"cannot infer output for model={provider}:{model_id}")
1037
+
1038
+
1039
def _build_input_parts(
    *,
    prompt: str | None,
    image_path: str | None,
    audio_path: str | None,
    video_path: str | None,
    input_modalities: set[str],
    output_modalities: set[str],
    provider: str,
    model_id: str,
) -> list[Part]:
    """Validate CLI prompt/media flags against the model and build input parts.

    Exits with a usage error when the flag combination is invalid for the
    requested output modality.
    """
    has_media = bool(image_path or audio_path or video_path)

    if output_modalities == {"embedding"}:
        if has_media:
            raise SystemExit("embedding only supports text input")
        if not prompt:
            raise SystemExit("embedding requires --prompt")
        return [Part.from_text(prompt)]

    if output_modalities == {"image"}:
        if audio_path or video_path:
            raise SystemExit(
                "image generation does not take --audio-path/--video-path input"
            )
        if image_path and "image" not in input_modalities:
            raise SystemExit(
                "image generation does not take --image-path input for this model"
            )
        if not prompt:
            raise SystemExit("image generation requires --prompt")
        # Intentional fall-through: an image-capable model may take a
        # reference image alongside the prompt, built below.

    if output_modalities == {"audio"}:
        if has_media:
            raise SystemExit(
                "TTS does not take --image-path/--audio-path/--video-path input"
            )
        if not prompt:
            raise SystemExit("audio generation requires --prompt")
        return [Part.from_text(prompt)]

    if output_modalities == {"video"}:
        if has_media:
            raise SystemExit(
                "video generation does not take --image-path/--audio-path/--video-path input"
            )
        if not prompt:
            raise SystemExit("video generation requires --prompt")
        return [Part.from_text(prompt)]

    parts: list[Part] = []
    if prompt:
        meta: dict = {}
        if provider == "openai" and (
            model_id == "whisper-1" or "-transcribe" in model_id
        ):
            # OpenAI transcription models treat the text as a transcription hint.
            meta = {"transcription_prompt": True}
        parts.append(Part(type="text", text=prompt, meta=meta))

    for path, kind in (
        (image_path, "image"),
        (audio_path, "audio"),
        (video_path, "video"),
    ):
        if not path:
            continue
        mime = detect_mime_type(path)
        if not mime or not mime.startswith(f"{kind}/"):
            raise SystemExit(f"cannot detect {kind} mime type for: {path}")
        parts.append(Part(type=kind, mime_type=mime, source=PartSourcePath(path=path)))

    if not parts:
        raise SystemExit(
            "missing input: provide --prompt and/or --image-path/--audio-path/--video-path"
        )
    return parts
1123
+
1124
+
1125
def _run_stream_text(
    client: Client, req: GenerateRequest, *, timeout_ms: int | None
) -> str:
    """Stream a generate request and return the concatenated text deltas."""
    if timeout_ms is not None:
        # Request objects are immutable; rebuild with the override applied.
        req = replace(req, params=replace(req.params, timeout_ms=timeout_ms))
    pieces: list[str] = []
    for event in client.generate_stream(req):
        if event.type == "output.text.delta":
            delta = event.data.get("delta")
            if isinstance(delta, str) and delta:
                pieces.append(delta)
    return "".join(pieces)
1138
+
1139
+
1140
+ def _write_text(text: str, *, output_path: str | None) -> None:
1141
+ if output_path:
1142
+ with open(output_path, "w", encoding="utf-8") as f:
1143
+ f.write(text)
1144
+ print(f"[OK] wrote {output_path}")
1145
+ return
1146
+ print(text)
1147
+
1148
+
1149
+ def _guess_ext(mime: str | None) -> str:
1150
+ if not mime:
1151
+ return ""
1152
+ m = mime.lower()
1153
+ if m == "image/png":
1154
+ return ".png"
1155
+ if m in {"image/jpeg", "image/jpg"}:
1156
+ return ".jpg"
1157
+ if m == "image/webp":
1158
+ return ".webp"
1159
+ if m in {"audio/mpeg", "audio/mp3"}:
1160
+ return ".mp3"
1161
+ if m in {"audio/wav", "audio/wave"}:
1162
+ return ".wav"
1163
+ if m in {"audio/mp4", "audio/m4a"}:
1164
+ return ".m4a"
1165
+ if m == "video/mp4":
1166
+ return ".mp4"
1167
+ if m == "video/quicktime":
1168
+ return ".mov"
1169
+ return ""
1170
+
1171
+
1172
def _download_with_headers(
    url: str,
    output_path: str,
    *,
    timeout_ms: int | None,
    headers: dict[str, str] | None,
) -> None:
    """Download *url* to *output_path*, forwarding optional auth headers.

    Thin wrapper over ``download_to_file`` with no size cap
    (``max_bytes=None``); callers decide which headers are safe to send.
    """
    download_to_file(
        url=url,
        output_path=output_path,
        timeout_ms=timeout_ms,
        max_bytes=None,
        headers=headers,
    )
1186
+
1187
+
1188
+ def _download_auth(
1189
+ client: object, *, provider: str
1190
+ ) -> tuple[dict[str, str], set[str]] | None:
1191
+ adapter_getter = getattr(client, "_adapter", None)
1192
+ if not callable(adapter_getter):
1193
+ return None
1194
+ try:
1195
+ adapter = adapter_getter(provider)
1196
+ except Exception:
1197
+ return None
1198
+ header_fn = getattr(adapter, "_download_headers", None)
1199
+ if not callable(header_fn):
1200
+ return None
1201
+ try:
1202
+ raw = header_fn()
1203
+ except Exception:
1204
+ return None
1205
+ if not isinstance(raw, dict):
1206
+ return None
1207
+ headers: dict[str, str] = {}
1208
+ for k, v in raw.items():
1209
+ if isinstance(k, str) and k and isinstance(v, str) and v:
1210
+ headers[k] = v
1211
+ if not headers:
1212
+ return None
1213
+
1214
+ base_url = getattr(adapter, "base_url", None)
1215
+ if not isinstance(base_url, str) or not base_url:
1216
+ return None
1217
+ host = urllib.parse.urlparse(base_url).hostname
1218
+ if not isinstance(host, str) or not host:
1219
+ return None
1220
+ return headers, {host.lower()}
1221
+
1222
+
1223
def _write_response(
    parts: list[Part],
    *,
    output: OutputSpec,
    output_path: str | None,
    timeout_ms: int | None,
    download_auth: tuple[dict[str, str], set[str]] | None,
) -> None:
    """Persist or print a generate response according to its output modality.

    - text: write/print the first part's text
    - embedding: dump the vector as JSON, or print dims plus a short preview
    - image/audio/video: download a URL source, print a ref source, or write
      raw/base64 bytes to disk

    Raises SystemExit on malformed or unsupported outputs.
    """
    modalities = set(output.modalities)
    if modalities == {"text"}:
        text = parts[0].text or ""
        _write_text(text, output_path=output_path)
        return

    if modalities == {"embedding"}:
        vec = parts[0].embedding or []
        if output_path:
            with open(output_path, "w", encoding="utf-8") as f:
                json.dump(vec, f, ensure_ascii=False)
            print(f"[OK] wrote {output_path} (dims={len(vec)})")
            return
        # No output path: preview the first 8 components only.
        head = ", ".join([f"{x:.6f}" for x in vec[:8]])
        print(f"dims={len(vec)} [{head}{', ...' if len(vec) > 8 else ''}]")
        return

    if modalities in ({"image"}, {"audio"}, {"video"}):
        p = next((x for x in parts if x.type in modalities), None)
        if not p:
            raise SystemExit("missing binary output part")
        if not p.source:
            raise SystemExit("missing output source")
        if p.source.kind == "url":
            if output_path:
                headers: dict[str, str] | None = None
                if download_auth is not None:
                    # Only attach auth headers when the URL host is in the
                    # adapter's allow-list, so credentials never leak to
                    # third-party hosts.
                    allowed = download_auth[1]
                    host = urllib.parse.urlparse(p.source.url).hostname
                    if isinstance(host, str) and host and host.lower() in allowed:
                        headers = download_auth[0]
                _download_with_headers(
                    p.source.url,
                    output_path,
                    timeout_ms=timeout_ms,
                    headers=headers,
                )
                print(f"[OK] downloaded to {output_path}")
            else:
                print(p.source.url)
            return
        if p.source.kind == "ref":
            if output_path:
                raise SystemExit(
                    "cannot write ref output; provider-specific download required"
                )
            # Print "provider:id" (or just the id when no provider is set).
            ref = (
                f"{p.source.provider}:{p.source.id}"
                if p.source.provider
                else p.source.id
            )
            print(ref)
            return
        if p.source.kind != "bytes":
            raise SystemExit(f"unsupported output source kind: {p.source.kind}")
        data: bytes
        if p.source.encoding == "base64":
            raw_b64 = p.source.data
            if not isinstance(raw_b64, str) or not raw_b64:
                raise SystemExit("invalid base64 output") from None
            try:
                data = base64.b64decode(raw_b64)
            except Exception:
                raise SystemExit("invalid base64 output") from None
        else:
            # Anything non-base64 must already be raw bytes (bytearray accepted).
            raw = p.source.data
            if isinstance(raw, bytearray):
                raw = bytes(raw)
            if not isinstance(raw, bytes):
                raise SystemExit(f"invalid bytes output (encoding={p.source.encoding})")
            data = raw
        # Default file name derives its extension from the part's MIME type.
        path = output_path or f"genai_output{_guess_ext(p.mime_type)}"
        with open(path, "wb") as f:
            f.write(data)
        print(f"[OK] wrote {path} ({p.mime_type}, {len(data)} bytes)")
        return

    raise SystemExit(f"unsupported output modalities: {output.modalities}")
1309
+
1310
+
1311
+ def _generate_token() -> str:
1312
+ return f"sk-{secrets.token_urlsafe(32)}"
1313
+
1314
+
1315
+ if __name__ == "__main__":
1316
+ main()