nous-genai 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nous/__init__.py +3 -0
- nous/genai/__init__.py +56 -0
- nous/genai/__main__.py +3 -0
- nous/genai/_internal/__init__.py +1 -0
- nous/genai/_internal/capability_rules.py +476 -0
- nous/genai/_internal/config.py +102 -0
- nous/genai/_internal/errors.py +63 -0
- nous/genai/_internal/http.py +951 -0
- nous/genai/_internal/json_schema.py +54 -0
- nous/genai/cli.py +1316 -0
- nous/genai/client.py +719 -0
- nous/genai/mcp_cli.py +275 -0
- nous/genai/mcp_server.py +1080 -0
- nous/genai/providers/__init__.py +15 -0
- nous/genai/providers/aliyun.py +535 -0
- nous/genai/providers/anthropic.py +483 -0
- nous/genai/providers/gemini.py +1606 -0
- nous/genai/providers/openai.py +1909 -0
- nous/genai/providers/tuzi.py +1158 -0
- nous/genai/providers/volcengine.py +273 -0
- nous/genai/reference/__init__.py +17 -0
- nous/genai/reference/catalog.py +206 -0
- nous/genai/reference/mappings.py +467 -0
- nous/genai/reference/mode_overrides.py +26 -0
- nous/genai/reference/model_catalog.py +82 -0
- nous/genai/reference/model_catalog_data/__init__.py +1 -0
- nous/genai/reference/model_catalog_data/aliyun.py +98 -0
- nous/genai/reference/model_catalog_data/anthropic.py +10 -0
- nous/genai/reference/model_catalog_data/google.py +45 -0
- nous/genai/reference/model_catalog_data/openai.py +44 -0
- nous/genai/reference/model_catalog_data/tuzi_anthropic.py +21 -0
- nous/genai/reference/model_catalog_data/tuzi_google.py +19 -0
- nous/genai/reference/model_catalog_data/tuzi_openai.py +75 -0
- nous/genai/reference/model_catalog_data/tuzi_web.py +136 -0
- nous/genai/reference/model_catalog_data/volcengine.py +107 -0
- nous/genai/tools/__init__.py +13 -0
- nous/genai/tools/output_parser.py +119 -0
- nous/genai/types.py +416 -0
- nous/py.typed +1 -0
- nous_genai-0.1.0.dist-info/METADATA +200 -0
- nous_genai-0.1.0.dist-info/RECORD +45 -0
- nous_genai-0.1.0.dist-info/WHEEL +5 -0
- nous_genai-0.1.0.dist-info/entry_points.txt +4 -0
- nous_genai-0.1.0.dist-info/licenses/LICENSE +190 -0
- nous_genai-0.1.0.dist-info/top_level.txt +1 -0
nous/genai/providers/gemini.py
@@ -0,0 +1,1606 @@
from __future__ import annotations

import os
import re
import tempfile
import time
import urllib.parse
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any, Iterator, Literal
from uuid import uuid4

from .._internal.capability_rules import (
    gemini_image_input_modalities,
    gemini_model_kind,
    gemini_output_modalities,
)
from .._internal.errors import (
    invalid_request_error,
    not_supported_error,
    provider_error,
    timeout_error,
)
from .._internal.http import (
    download_to_tempfile,
    multipart_form_data_json_and_file,
    request_json,
    request_stream_json_sse,
    request_streaming_body_json,
)
from ..types import (
    Capability,
    GenerateEvent,
    GenerateRequest,
    GenerateResponse,
    JobInfo,
    Message,
    Part,
    PartType,
    PartSourceBytes,
    PartSourcePath,
    PartSourceRef,
    PartSourceUrl,
    Usage,
    bytes_to_base64,
    detect_mime_type,
    file_to_bytes,
    normalize_reasoning_effort,
)

_GEMINI_DEFAULT_BASE_URL = "https://generativelanguage.googleapis.com"

_ASYNCDATA_BASE_URL = "https://asyncdata.net"

_TUZI_TASK_ID_RE = re.compile(r"Task ID:\s*`([^`]+)`")
_MP4_URL_RE = re.compile(r"https?://[^\s)]+\.mp4")
_MD_IMAGE_URL_RE = re.compile(r"!\[[^\]]*]\((https?://[^\s)]+)\)")

_GEMINI_SCHEMA_TYPES = frozenset(
    {
        "TYPE_UNSPECIFIED",
        "STRING",
        "NUMBER",
        "INTEGER",
        "BOOLEAN",
        "ARRAY",
        "OBJECT",
    }
)

_JSON_SCHEMA_TO_GEMINI_TYPE: dict[str, str] = {
    "string": "STRING",
    "number": "NUMBER",
    "integer": "INTEGER",
    "boolean": "BOOLEAN",
    "array": "ARRAY",
    "object": "OBJECT",
}


def _looks_like_gemini_schema(schema: dict[str, Any]) -> bool:
    t = schema.get("type")
    return isinstance(t, str) and t in _GEMINI_SCHEMA_TYPES


def _resolve_json_schema_ref(ref: str, defs: dict[str, Any]) -> dict[str, Any]:
    if not ref.startswith("#/$defs/"):
        raise invalid_request_error(
            "Gemini responseSchema only supports local $defs $ref"
        )
    name = ref[len("#/$defs/") :]
    if not name or "/" in name:
        raise invalid_request_error(
            "Gemini responseSchema only supports simple $defs refs"
        )
    resolved = defs.get(name)
    if not isinstance(resolved, dict):
        raise invalid_request_error(f"unresolved $ref: {ref}")
    return resolved


def _is_json_schema_null(schema: Any) -> bool:
    if not isinstance(schema, dict):
        return False
    return schema.get("type") == "null"


def _convert_nullable_union(
    *,
    tag: str,
    options: Any,
    defs: dict[str, Any],
    ref_stack: tuple[str, ...],
) -> dict[str, Any]:
    if not isinstance(options, list) or not options:
        raise invalid_request_error(
            f"Gemini responseSchema {tag} must be a non-empty array"
        )
    non_null: list[dict[str, Any]] = []
    null_count = 0
    for item in options:
        if _is_json_schema_null(item):
            null_count += 1
            continue
        if not isinstance(item, dict):
            raise invalid_request_error(
                f"Gemini responseSchema {tag} items must be objects"
            )
        non_null.append(item)
    if null_count != 1 or len(non_null) != 1:
        raise invalid_request_error(
            f"Gemini responseSchema only supports nullable unions ({tag} with exactly one null and one schema)"
        )
    out = _json_schema_to_gemini_schema(non_null[0], defs=defs, ref_stack=ref_stack)
    out["nullable"] = True
    return out
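

# Illustrative note (editorial, not part of the packaged file): given the rules
# above, a nullable union such as
#     {"anyOf": [{"type": "string"}, {"type": "null"}]}
# converts to
#     {"type": "STRING", "nullable": True}
# while an anyOf/oneOf with two or more non-null alternatives is rejected.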


def _json_schema_to_gemini_schema(
    schema: Any,
    *,
    defs: dict[str, Any],
    ref_stack: tuple[str, ...],
) -> dict[str, Any]:
    if not isinstance(schema, dict):
        raise invalid_request_error(
            "output.text.json_schema must be an object for Gemini"
        )

    local_defs = schema.get("$defs")
    if isinstance(local_defs, dict) and local_defs:
        merged = dict(defs)
        merged.update(local_defs)
        defs = merged

    ref = schema.get("$ref")
    if ref is not None:
        if not isinstance(ref, str) or not ref:
            raise invalid_request_error("$ref must be a non-empty string")
        if ref in ref_stack:
            raise invalid_request_error(
                "Gemini responseSchema does not support recursive $ref"
            )
        resolved = _resolve_json_schema_ref(ref, defs)
        out = _json_schema_to_gemini_schema(
            resolved, defs=defs, ref_stack=ref_stack + (ref,)
        )
        desc = schema.get("description")
        if isinstance(desc, str) and desc.strip() and "description" not in out:
            out["description"] = desc.strip()
        return out

    all_of = schema.get("allOf")
    if all_of is not None:
        if isinstance(all_of, list) and len(all_of) == 1:
            out = _json_schema_to_gemini_schema(
                all_of[0], defs=defs, ref_stack=ref_stack
            )
            desc = schema.get("description")
            if isinstance(desc, str) and desc.strip() and "description" not in out:
                out["description"] = desc.strip()
            return out
        raise invalid_request_error(
            "Gemini responseSchema does not support allOf (except a single item)"
        )

    any_of = schema.get("anyOf")
    if any_of is not None:
        out = _convert_nullable_union(
            tag="anyOf", options=any_of, defs=defs, ref_stack=ref_stack
        )
        desc = schema.get("description")
        if isinstance(desc, str) and desc.strip() and "description" not in out:
            out["description"] = desc.strip()
        return out

    one_of = schema.get("oneOf")
    if one_of is not None:
        out = _convert_nullable_union(
            tag="oneOf", options=one_of, defs=defs, ref_stack=ref_stack
        )
        desc = schema.get("description")
        if isinstance(desc, str) and desc.strip() and "description" not in out:
            out["description"] = desc.strip()
        return out

    nullable = False
    t = schema.get("type")
    if isinstance(t, list):
        types = [x for x in t if isinstance(x, str) and x]
        if "null" in types:
            types = [x for x in types if x != "null"]
            if len(types) != 1:
                raise invalid_request_error(
                    "Gemini responseSchema only supports nullable union with one non-null type"
                )
            t = types[0]
            nullable = True
        elif len(types) == 1:
            t = types[0]
        else:
            raise invalid_request_error(
                "Gemini responseSchema does not support union types"
            )

    if not isinstance(t, str) or not t:
        if isinstance(schema.get("properties"), dict):
            t = "object"
        else:
            raise invalid_request_error("output.text.json_schema missing type")

    t_norm = t.strip().lower()
    gemini_type = _JSON_SCHEMA_TO_GEMINI_TYPE.get(t_norm)
    if gemini_type is None:
        raise invalid_request_error(f"Gemini responseSchema unsupported type: {t}")

    out = {"type": gemini_type}
    if nullable:
        out["nullable"] = True

    desc = schema.get("description")
    if isinstance(desc, str) and desc.strip():
        out["description"] = desc.strip()

    const = schema.get("const")
    if const is not None:
        out["enum"] = [const]
    enum = schema.get("enum")
    if isinstance(enum, list) and enum:
        out["enum"] = enum

    if gemini_type == "OBJECT":
        props = schema.get("properties")
        if isinstance(props, dict):
            out_props: dict[str, Any] = {}
            for k, v in props.items():
                if not isinstance(k, str) or not k:
                    continue
                out_props[k] = _json_schema_to_gemini_schema(
                    v, defs=defs, ref_stack=ref_stack
                )
            if out_props:
                out["properties"] = out_props
        required = schema.get("required")
        if isinstance(required, list):
            req = [x for x in required if isinstance(x, str) and x]
            if req:
                out["required"] = req

        addl = schema.get("additionalProperties")
        if isinstance(addl, dict):
            out["additionalProperties"] = _json_schema_to_gemini_schema(
                addl, defs=defs, ref_stack=ref_stack
            )
        elif isinstance(addl, bool):
            out["additionalProperties"] = addl

    if gemini_type == "ARRAY":
        items = schema.get("items")
        if isinstance(items, dict):
            out["items"] = _json_schema_to_gemini_schema(
                items, defs=defs, ref_stack=ref_stack
            )
        elif items is not None:
            raise invalid_request_error(
                "Gemini responseSchema array items must be an object"
            )

    return out


def _to_gemini_response_schema(schema: Any) -> dict[str, Any]:
    if not isinstance(schema, dict):
        raise invalid_request_error("output.text.json_schema must be an object")
    if _looks_like_gemini_schema(schema):
        raise invalid_request_error(
            "output.text.json_schema must be JSON Schema (not Gemini responseSchema); "
            "pass a Python type/model or use provider_options.google.generationConfig.responseSchema"
        )
    defs = schema.get("$defs")
    defs_map = defs if isinstance(defs, dict) else {}
    return _json_schema_to_gemini_schema(schema, defs=defs_map, ref_stack=())
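

# Illustrative note (editorial, not part of the packaged file): a typical
# conversion performed by _to_gemini_response_schema:
#     {"type": "object",
#      "properties": {"name": {"type": "string"},
#                     "age": {"type": "integer"}},
#      "required": ["name"]}
# becomes
#     {"type": "OBJECT",
#      "properties": {"name": {"type": "STRING"},
#                     "age": {"type": "INTEGER"}},
#      "required": ["name"]}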


@dataclass(frozen=True, slots=True)
class GeminiAdapter:
    api_key: str
    base_url: str = _GEMINI_DEFAULT_BASE_URL
    provider_name: str = "google"
    auth_mode: Literal["query_key", "bearer"] = "query_key"
    supports_file_upload: bool = True
    proxy_url: str | None = None

    def _auth_headers(self) -> dict[str, str]:
        if self.auth_mode == "bearer":
            return {"Authorization": f"Bearer {self.api_key}"}
        return {}

    def _with_key(self, url: str) -> str:
        if self.auth_mode != "query_key":
            return url
        sep = "&" if "?" in url else "?"
        return f"{url}{sep}key={self.api_key}"

    def _v1beta_url(self, path: str) -> str:
        base = self.base_url.rstrip("/")
        return self._with_key(f"{base}/v1beta/{path.lstrip('/')}")

    def _upload_url(self, path: str) -> str:
        base = self.base_url.rstrip("/")
        return self._with_key(f"{base}/upload/v1beta/{path.lstrip('/')}")

    def _download_headers(self) -> dict[str, str] | None:
        if self.auth_mode == "bearer":
            return {"Authorization": f"Bearer {self.api_key}"}
        return {"x-goog-api-key": self.api_key}

    def capabilities(self, model_id: str) -> Capability:
        kind = gemini_model_kind(model_id)
        out_mods = gemini_output_modalities(kind)

        if kind == "video":
            return Capability(
                input_modalities={"text"},
                output_modalities=out_mods,
                supports_stream=False,
                supports_job=True,
                supports_tools=False,
                supports_json_schema=False,
            )
        if kind == "embedding":
            return Capability(
                input_modalities={"text"},
                output_modalities=out_mods,
                supports_stream=False,
                supports_job=False,
                supports_tools=False,
                supports_json_schema=False,
            )
        if kind == "tts":
            return Capability(
                input_modalities={"text"},
                output_modalities=out_mods,
                supports_stream=False,
                supports_job=False,
                supports_tools=False,
                supports_json_schema=False,
            )
        if kind == "native_audio":
            return Capability(
                input_modalities={"text", "audio", "video"},
                output_modalities=out_mods,
                supports_stream=True,
                supports_job=False,
                supports_tools=True,
                supports_json_schema=True,
            )
        if kind == "image":
            return Capability(
                input_modalities=gemini_image_input_modalities(model_id),
                output_modalities=out_mods,
                supports_stream=False,
                supports_job=False,
                supports_tools=False,
                supports_json_schema=False,
            )
        return Capability(
            input_modalities={"text", "image", "audio", "video"},
            output_modalities=out_mods,
            supports_stream=True,
            supports_job=False,
            supports_tools=True,
            supports_json_schema=True,
        )

    def list_models(self, *, timeout_ms: int | None = None) -> list[str]:
        """
        Fetch remote model ids via Gemini Developer API GET /v1beta/models.

        Returns model ids without the leading "models/" prefix.
        """
        out: list[str] = []
        page_token: str | None = None
        for _ in range(20):
            url = self._v1beta_url("models")
            if page_token:
                token = urllib.parse.quote(page_token, safe="")
                sep = "&" if "?" in url else "?"
                url = f"{url}{sep}pageToken={token}"
            obj = request_json(
                method="GET",
                url=url,
                headers=self._auth_headers(),
                timeout_ms=timeout_ms,
                proxy_url=self.proxy_url,
            )
            models = obj.get("models")
            if isinstance(models, list):
                for m in models:
                    if not isinstance(m, dict):
                        continue
                    name = m.get("name")
                    if not isinstance(name, str) or not name:
                        continue
                    out.append(
                        name[len("models/") :] if name.startswith("models/") else name
                    )
            next_token = obj.get("nextPageToken")
            if not isinstance(next_token, str) or not next_token:
                break
            page_token = next_token
        return sorted(set(out))

    def generate(
        self, request: GenerateRequest, *, stream: bool
    ) -> GenerateResponse | Iterator[GenerateEvent]:
        model_id = request.model_id()
        model_name = (
            model_id if model_id.startswith("models/") else f"models/{model_id}"
        )
        modalities = set(request.output.modalities)
        if "embedding" in modalities:
            if modalities != {"embedding"}:
                raise not_supported_error(
                    "embedding cannot be combined with other output modalities"
                )
            if stream:
                raise not_supported_error("embedding does not support streaming")
            return self._embed(request, model_name=model_name)

        if "video" in modalities:
            if modalities != {"video"}:
                raise not_supported_error(
                    "video cannot be combined with other output modalities"
                )
            if stream:
                raise not_supported_error("video does not support streaming")
            return self._video(request, model_name=model_name)
        if stream and (modalities & {"image", "audio"}):
            raise not_supported_error(
                "streaming image/audio output is not supported in this SDK yet"
            )

        if stream:
            return self._generate_stream(request, model_name=model_name)
        return self._generate(request, model_name=model_name)

    def _generate(
        self, request: GenerateRequest, *, model_name: str
    ) -> GenerateResponse:
        url = self._v1beta_url(f"{model_name}:generateContent")
        body = self._generate_body(request, model_name=model_name)
        obj = request_json(
            method="POST",
            url=url,
            headers=self._auth_headers(),
            json_body=body,
            timeout_ms=request.params.timeout_ms,
            proxy_url=self.proxy_url,
        )
        return self._parse_generate(obj, model=model_name)

    def _video(self, request: GenerateRequest, *, model_name: str) -> GenerateResponse:
        if self.provider_name.startswith("tuzi") and model_name.startswith(
            "models/veo2"
        ):
            return self._tuzi_veo2_video(request, model_name=model_name)

        if not model_name.startswith("models/veo-"):
            raise not_supported_error(
                'video generation requires model like "google:veo-3.1-generate-preview"'
            )

        texts = _gather_text_inputs(request)
        if len(texts) != 1:
            raise invalid_request_error(
                "video generation requires exactly one text part"
            )
        prompt = texts[0]

        body: dict[str, Any] = {"instances": [{"prompt": prompt}]}
        params: dict[str, Any] = {}
        video = request.output.video
        if video and video.duration_sec is not None:
            duration = int(video.duration_sec)
            if duration < 5 or duration > 8:
                raise invalid_request_error(
                    "google veo duration_sec must be between 5 and 8 seconds"
                )
            params["durationSeconds"] = duration
        if video and video.aspect_ratio:
            params["aspectRatio"] = video.aspect_ratio

        opts = (
            request.provider_options.get(self.provider_name)
            or request.provider_options.get("google")
            or request.provider_options.get("gemini")
        )
        if isinstance(opts, dict):
            opt_params = opts.get("parameters")
            if opt_params is not None:
                if not isinstance(opt_params, dict):
                    raise invalid_request_error(
                        "provider_options.google.parameters must be an object"
                    )
                for k, v in opt_params.items():
                    if k in params:
                        raise invalid_request_error(
                            f"provider_options cannot override parameters.{k}"
                        )
                    params[k] = v
            for k, v in opts.items():
                if k == "parameters":
                    continue
                if k in body:
                    raise invalid_request_error(
                        f"provider_options cannot override body.{k}"
                    )
                body[k] = v

        if params:
            body["parameters"] = params

        budget_ms = (
            120_000 if request.params.timeout_ms is None else request.params.timeout_ms
        )
        deadline = time.time() + max(1, budget_ms) / 1000.0
        url = self._v1beta_url(f"{model_name}:predictLongRunning")
        obj = request_json(
            method="POST",
            url=url,
            headers=self._auth_headers(),
            json_body=body,
            timeout_ms=min(30_000, max(1, budget_ms)),
            proxy_url=self.proxy_url,
        )
        name = obj.get("name")
        if not isinstance(name, str) or not name:
            raise provider_error("gemini veo response missing operation name")

        if not request.wait:
            return GenerateResponse(
                id=f"sdk_{uuid4().hex}",
                provider="google",
                model=f"google:{model_name}",
                status="running",
                job=JobInfo(job_id=name, poll_after_ms=1_000),
            )

        if obj.get("done") is not True:
            obj = self._wait_operation_done(name=name, deadline=deadline)
        if obj.get("done") is not True:
            return GenerateResponse(
                id=f"sdk_{uuid4().hex}",
                provider="google",
                model=f"google:{model_name}",
                status="running",
                job=JobInfo(job_id=name, poll_after_ms=1_000),
            )

        err = obj.get("error")
        if isinstance(err, dict):
            msg = err.get("message") or str(err)
            raise provider_error(f"gemini veo operation failed: {msg}")

        video_uri = _extract_veo_video_uri(obj.get("response"))
        if not video_uri:
            raise provider_error("gemini veo operation response missing video uri")
        scheme = urllib.parse.urlparse(video_uri).scheme.lower()
        source = (
            PartSourceUrl(url=video_uri)
            if scheme in {"http", "https"}
            else PartSourceRef(provider=self.provider_name, id=video_uri)
        )
        part = Part(type="video", mime_type="video/mp4", source=source)
        return GenerateResponse(
            id=f"sdk_{uuid4().hex}",
            provider=self.provider_name,
            model=f"{self.provider_name}:{model_name}",
            status="completed",
            output=[Message(role="assistant", content=[part])],
            usage=None,
        )

    def _tuzi_veo2_video(
        self, request: GenerateRequest, *, model_name: str
    ) -> GenerateResponse:
        texts = _gather_text_inputs(request)
        if len(texts) != 1:
            raise invalid_request_error(
                "video generation requires exactly one text part"
            )
        prompt = texts[0]

        params: dict[str, Any] = {}
        video = request.output.video
        if video and video.duration_sec is not None:
            params["durationSeconds"] = int(video.duration_sec)
        if video and video.aspect_ratio:
            params["aspectRatio"] = video.aspect_ratio

        body: dict[str, Any] = {
            "contents": [{"role": "user", "parts": [{"text": prompt}]}]
        }
        if params:
            body["parameters"] = params

        budget_ms = (
            120_000 if request.params.timeout_ms is None else request.params.timeout_ms
        )
        obj = request_json(
            method="POST",
            url=self._v1beta_url(f"{model_name}:predictLongRunning"),
            headers=self._auth_headers(),
            json_body=body,
            timeout_ms=max(1, budget_ms),
            proxy_url=self.proxy_url,
        )

        text = _first_candidate_text(obj)
        mp4_url = _extract_first_mp4_url(text) if text else None
        task_id = _extract_tuzi_task_id(text) if text else None

        if mp4_url:
            part = Part(
                type="video", mime_type="video/mp4", source=PartSourceUrl(url=mp4_url)
            )
            return GenerateResponse(
                id=f"sdk_{uuid4().hex}",
                provider=self.provider_name,
                model=f"{self.provider_name}:{model_name}",
                status="completed",
                output=[Message(role="assistant", content=[part])],
                usage=_usage_from_gemini(obj.get("usageMetadata")),
            )

        if task_id and request.wait:
            deadline = time.time() + max(1, budget_ms) / 1000.0
            mp4_url = _poll_tuzi_video_mp4(
                task_id=task_id, deadline=deadline, proxy_url=self.proxy_url
            )
            if mp4_url:
                part = Part(
                    type="video",
                    mime_type="video/mp4",
                    source=PartSourceUrl(url=mp4_url),
                )
                return GenerateResponse(
                    id=f"sdk_{uuid4().hex}",
                    provider=self.provider_name,
                    model=f"{self.provider_name}:{model_name}",
                    status="completed",
                    output=[Message(role="assistant", content=[part])],
                    usage=None,
                )

        job = JobInfo(job_id=task_id, poll_after_ms=2_000) if task_id else None
        return GenerateResponse(
            id=f"sdk_{uuid4().hex}",
            provider=self.provider_name,
            model=f"{self.provider_name}:{model_name}",
            status="running",
            job=job,
        )

    def _generate_stream(
        self, request: GenerateRequest, *, model_name: str
    ) -> Iterator[GenerateEvent]:
        url = self._v1beta_url(f"{model_name}:streamGenerateContent?alt=sse")
        body = self._generate_body(request, model_name=model_name)
        events = request_stream_json_sse(
            method="POST",
            url=url,
            headers=self._auth_headers(),
            json_body=body,
            timeout_ms=request.params.timeout_ms,
            proxy_url=self.proxy_url,
        )

        def _iter() -> Iterator[GenerateEvent]:
            for obj in events:
                delta = self._extract_text_delta(obj)
                if delta:
                    yield GenerateEvent(type="output.text.delta", data={"delta": delta})
            yield GenerateEvent(type="done", data={})

        return _iter()

    def _embed(self, request: GenerateRequest, *, model_name: str) -> GenerateResponse:
        texts = _gather_text_inputs(request)
        emb = request.output.embedding
        dims = emb.dimensions if emb and emb.dimensions is not None else None
        if model_name == "models/embedding-gecko-001":
            if dims is not None:
                raise invalid_request_error(
                    "models/embedding-gecko-001 does not support embedding.dimensions"
                )
            if len(texts) != 1:
                raise invalid_request_error(
                    "models/embedding-gecko-001 only supports single text per request"
                )
            url = self._v1beta_url(f"{model_name}:embedText")
            obj = request_json(
                method="POST",
                url=url,
                headers=self._auth_headers(),
                json_body={"model": model_name, "text": texts[0]},
                timeout_ms=request.params.timeout_ms,
                proxy_url=self.proxy_url,
            )
            embedding = obj.get("embedding")
            if not isinstance(embedding, dict):
                raise provider_error("gemini embedText response missing embedding")
            values = embedding.get("value")
            if not isinstance(values, list) or not all(
                isinstance(x, (int, float)) for x in values
            ):
                raise provider_error(
                    "gemini embedText response missing embedding.value"
                )
            parts = [Part(type="embedding", embedding=[float(x) for x in values])]
        else:
            url = self._v1beta_url(f"{model_name}:batchEmbedContents")
            reqs: list[dict[str, Any]] = [
                {"model": model_name, "content": {"parts": [{"text": t}]}}
                for t in texts
            ]
            if dims is not None:
                for r in reqs:
                    r["outputDimensionality"] = dims
            obj = request_json(
                method="POST",
                url=url,
                headers=self._auth_headers(),
                json_body={"requests": reqs},
                timeout_ms=request.params.timeout_ms,
                proxy_url=self.proxy_url,
            )
            embeddings = obj.get("embeddings")
            if not isinstance(embeddings, list) or len(embeddings) != len(texts):
                raise provider_error(
                    "gemini batchEmbedContents response missing embeddings"
                )
            parts = []
            for emb in embeddings:
                if not isinstance(emb, dict):
                    raise provider_error("gemini embedding item is not object")
                values = emb.get("values")
                if not isinstance(values, list) or not all(
                    isinstance(x, (int, float)) for x in values
                ):
                    raise provider_error("gemini embedding item missing values")
                parts.append(
                    Part(type="embedding", embedding=[float(x) for x in values])
                )
        return GenerateResponse(
            id=f"sdk_{uuid4().hex}",
            provider=self.provider_name,
            model=f"{self.provider_name}:{model_name}",
            status="completed",
            output=[Message(role="assistant", content=parts)],
            usage=None,
        )

    def _generate_body(
        self, request: GenerateRequest, *, model_name: str
    ) -> dict[str, Any]:
        system_text = _extract_system_text(request)
        contents = [
            _message_to_content(self, m, timeout_ms=request.params.timeout_ms)
            for m in request.input
            if m.role != "system"
        ]
        if not contents:
            raise invalid_request_error(
                "request.input must contain at least one non-system message"
            )

        gen_cfg: dict[str, Any] = {}
        params = request.params
        if params.temperature is not None:
            gen_cfg["temperature"] = params.temperature
        if params.top_p is not None:
            gen_cfg["topP"] = params.top_p
        if params.seed is not None:
            gen_cfg["seed"] = params.seed
        text_spec = request.output.text
        max_out = (
            text_spec.max_output_tokens
            if text_spec and text_spec.max_output_tokens is not None
            else params.max_output_tokens
        )
        if max_out is not None:
            gen_cfg["maxOutputTokens"] = max_out
        if params.stop is not None:
            gen_cfg["stopSequences"] = params.stop

        if params.reasoning is not None:
            thinking_cfg = _thinking_config(params.reasoning, model_name=model_name)
            if thinking_cfg is not None:
                gen_cfg["thinkingConfig"] = thinking_cfg

        if text_spec and (
            text_spec.format != "text" or text_spec.json_schema is not None
        ):
            if set(request.output.modalities) != {"text"}:
                raise invalid_request_error("json output requires text-only modality")
            gen_cfg["responseMimeType"] = "application/json"
            if text_spec.json_schema is not None:
                gen_cfg["responseSchema"] = _to_gemini_response_schema(
                    text_spec.json_schema
                )

        response_modalities = _gemini_response_modalities(request.output.modalities)
        if response_modalities == ["IMAGE"] and model_name.endswith("image-generation"):
            response_modalities = ["TEXT", "IMAGE"]
        if response_modalities:
            gen_cfg["responseModalities"] = response_modalities

        if request.output.image and request.output.image.n is not None:
            gen_cfg["candidateCount"] = request.output.image.n

        if "audio" in request.output.modalities:
            audio = request.output.audio
            if audio is None or not audio.voice:
                raise invalid_request_error(
                    "output.audio.voice required for Gemini audio output"
                )
            speech_cfg: dict[str, Any] = {
                "voiceConfig": {"prebuiltVoiceConfig": {"voiceName": audio.voice}},
            }
            if audio.language:
                speech_cfg["languageCode"] = audio.language
            gen_cfg["speechConfig"] = speech_cfg

        if "image" in request.output.modalities and request.output.image:
            img_cfg = {}
            if request.output.image.size:
                img_cfg["imageSize"] = request.output.image.size
            if img_cfg:
                gen_cfg["imageConfig"] = img_cfg

        opts = (
            request.provider_options.get(self.provider_name)
            or request.provider_options.get("google")
            or request.provider_options.get("gemini")
        )
        if isinstance(opts, dict):
            opt_gen_cfg = opts.get("generationConfig")
            if opt_gen_cfg is not None:
                if not isinstance(opt_gen_cfg, dict):
                    raise invalid_request_error(
                        "provider_options.google.generationConfig must be an object"
                    )
                for k, v in opt_gen_cfg.items():
                    if k in gen_cfg:
                        raise invalid_request_error(
                            f"provider_options cannot override generationConfig.{k}"
                        )
                    gen_cfg[k] = v

        body: dict[str, Any] = {"contents": contents}
        if system_text:
            body["systemInstruction"] = {
                "role": "user",
                "parts": [{"text": system_text}],
            }
        if gen_cfg:
            body["generationConfig"] = gen_cfg

        if request.tools:
            decls: list[dict[str, Any]] = []
            for t in request.tools:
                name = t.name.strip()
                if not name:
                    raise invalid_request_error("tool.name must be non-empty")
                decl: dict[str, Any] = {"name": name}
                if isinstance(t.description, str) and t.description.strip():
                    decl["description"] = t.description.strip()
                decl["parameters"] = (
                    t.parameters if t.parameters is not None else {"type": "object"}
                )
                decls.append(decl)
            body["tools"] = [{"functionDeclarations": decls}]

        if request.tool_choice is not None:
            choice = request.tool_choice.normalized()
            if choice.mode in {"required", "tool"} and not request.tools:
                raise invalid_request_error("tool_choice requires request.tools")
            if choice.mode == "none":
                body["toolConfig"] = {"functionCallingConfig": {"mode": "NONE"}}
            elif choice.mode == "required":
                body["toolConfig"] = {"functionCallingConfig": {"mode": "ANY"}}
            elif choice.mode == "tool":
                body["toolConfig"] = {
                    "functionCallingConfig": {
                        "mode": "ANY",
                        "allowedFunctionNames": [choice.name],
                    }
                }

        if isinstance(opts, dict):
            for k, v in opts.items():
                if k == "generationConfig":
                    continue
                if k in body:
                    raise invalid_request_error(
                        f"provider_options cannot override body.{k}"
                    )
                body[k] = v
        return body
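
    # Illustrative note (editorial, not part of the packaged file): for a single
    # user message containing one text part "hi", with params.temperature=0.2
    # and the default text-only output, _generate_body produces roughly:
    #     {"contents": [{"role": "user", "parts": [{"text": "hi"}]}],
    #      "generationConfig": {"temperature": 0.2,
    #                           "responseModalities": ["TEXT"]}}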

    def _parse_generate(self, obj: dict[str, Any], *, model: str) -> GenerateResponse:
        candidates = obj.get("candidates")
        if not isinstance(candidates, list) or not candidates:
            raise provider_error("gemini response missing candidates")
        cand0 = candidates[0]
        if not isinstance(cand0, dict):
            raise provider_error("gemini candidate is not object")
        content = cand0.get("content")
        if not isinstance(content, dict):
            raise provider_error("gemini candidate missing content")
        parts = content.get("parts")
        if not isinstance(parts, list):
            raise provider_error("gemini candidate content missing parts")
        out_parts: list[Part] = []
        for p in parts:
            if not isinstance(p, dict):
                continue
            out_parts.extend(_gemini_part_to_parts(p, provider_name=self.provider_name))
        usage = _usage_from_gemini(obj.get("usageMetadata"))
        if not out_parts:
            out_parts.append(Part.from_text(""))
        return GenerateResponse(
            id=f"sdk_{uuid4().hex}",
            provider=self.provider_name,
            model=f"{self.provider_name}:{model}",
            status="completed",
            output=[Message(role="assistant", content=out_parts)],
            usage=usage,
        )

    def _extract_text_delta(self, obj: dict[str, Any]) -> str | None:
        candidates = obj.get("candidates")
        if not isinstance(candidates, list) or not candidates:
            return None
        cand0 = candidates[0]
        if not isinstance(cand0, dict):
            return None
        content = cand0.get("content")
        if not isinstance(content, dict):
            return None
        parts = content.get("parts")
        if not isinstance(parts, list) or not parts:
            return None
        p0 = parts[0]
        if not isinstance(p0, dict):
            return None
        text = p0.get("text")
        if isinstance(text, str) and text:
            return text
        return None

    def _upload_to_file_uri(
        self, file_path: str, *, mime_type: str, timeout_ms: int | None
    ) -> str:
        if not self.supports_file_upload:
            raise not_supported_error(
                "Gemini file upload is not supported for this provider"
            )
        meta = {"file": {"displayName": os.path.basename(file_path)}}
        body = multipart_form_data_json_and_file(
            metadata_field="metadata",
            metadata=meta,
            file_field="file",
            file_path=file_path,
            filename=os.path.basename(file_path),
            file_mime_type=mime_type,
        )
        url = self._upload_url("files")
        headers = {"X-Goog-Upload-Protocol": "multipart"}
        headers.update(self._auth_headers())
        obj = request_streaming_body_json(
            method="POST",
            url=url,
            headers=headers,
            body=body,
            timeout_ms=timeout_ms,
            proxy_url=self.proxy_url,
        )
        file_obj = obj.get("file")
        if not isinstance(file_obj, dict):
            raise provider_error("gemini upload response missing file")
        name = file_obj.get("name")
        uri = file_obj.get("uri")
        if not isinstance(name, str) or not name:
            raise provider_error("gemini upload response missing file.name")
        if not isinstance(uri, str) or not uri:
            raise provider_error("gemini upload response missing file.uri")
        state = file_obj.get("state")
        if state == "ACTIVE":
            return uri
        return self._wait_file_active(name=name, uri=uri, timeout_ms=timeout_ms)

    def _wait_file_active(self, *, name: str, uri: str, timeout_ms: int | None) -> str:
        if not self.supports_file_upload:
            raise not_supported_error(
                "Gemini file upload is not supported for this provider"
            )
        url = self._v1beta_url(name)
        budget_ms = 120_000 if timeout_ms is None else timeout_ms
        deadline = time.time() + max(1, budget_ms) / 1000.0
        while True:
            remaining_ms = int((deadline - time.time()) * 1000)
            if remaining_ms <= 0:
                break
            obj = request_json(
                method="GET",
                url=url,
                headers=self._auth_headers(),
                json_body=None,
                timeout_ms=min(30_000, remaining_ms),
                proxy_url=self.proxy_url,
            )
            if not isinstance(obj, dict):
                raise provider_error("gemini get file response is not object")
            state = obj.get("state")
            if state == "ACTIVE":
                return uri
            if state == "FAILED":
                err = obj.get("error")
                raise provider_error(f"gemini file processing failed: {err}")
            time.sleep(min(1.0, max(0.0, deadline - time.time())))
        raise timeout_error("gemini file processing timeout")

    def _wait_operation_done(self, *, name: str, deadline: float) -> dict[str, Any]:
        url = self._v1beta_url(name)
        while True:
            remaining_ms = int((deadline - time.time()) * 1000)
            if remaining_ms <= 0:
                break
            obj = request_json(
                method="GET",
                url=url,
                headers=self._auth_headers(),
                json_body=None,
                timeout_ms=min(30_000, remaining_ms),
                proxy_url=self.proxy_url,
            )
            if not isinstance(obj, dict):
                raise provider_error("gemini operation get response is not object")
            if obj.get("done") is True:
                return obj
            time.sleep(min(1.0, max(0.0, deadline - time.time())))
        return {"name": name, "done": False}
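

# Illustrative usage sketch (editorial, not part of the packaged file; request
# construction is elided because it depends on nous.genai types):
#     adapter = GeminiAdapter(api_key=os.environ["GEMINI_API_KEY"])
#     resp = adapter.generate(request, stream=False)   # -> GenerateResponse
#     for event in adapter.generate(request, stream=True):
#         ...  # GenerateEvent items, ending with type="done"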


def _gather_text_inputs(request: GenerateRequest) -> list[str]:
    texts: list[str] = []
    for message in request.input:
        for part in message.content:
            if part.type != "text":
                raise invalid_request_error("embedding requires text-only input")
            texts.append(part.require_text())
    if not texts:
        raise invalid_request_error("embedding requires at least one text part")
    return texts


def _extract_system_text(request: GenerateRequest) -> str | None:
    chunks: list[str] = []
    for m in request.input:
        if m.role != "system":
            continue
        for p in m.content:
            if p.type != "text":
                raise invalid_request_error(
                    "system message only supports text for Gemini"
                )
            chunks.append(p.require_text())
    joined = "\n\n".join([c for c in chunks if c.strip()])
    return joined or None


def _gemini_supports_thinking_level(model_name: str) -> bool:
    mid = (
        model_name[len("models/") :] if model_name.startswith("models/") else model_name
    )
    mid = mid.strip().lower()
    if not mid.startswith("gemini-"):
        return False
    rest = mid[len("gemini-") :]
    digits = []
    for ch in rest:
        if ch.isdigit():
            digits.append(ch)
        else:
            break
    if not digits:
        return False
    try:
        major = int("".join(digits))
    except ValueError:
        return False
    return major >= 3


def _gemini_supports_thinking_budget(model_name: str) -> bool:
    mid = (
        model_name[len("models/") :] if model_name.startswith("models/") else model_name
    )
    mid = mid.strip().lower()
    return (
        "gemini-2.5-" in mid or "robotics-er" in mid or "flash-live-native-audio" in mid
    )


def _thinking_config(reasoning, *, model_name: str) -> dict[str, Any] | None:
    cfg: dict[str, Any] = {}
    if reasoning.effort is not None:
        if _gemini_supports_thinking_level(model_name):
            cfg["thinkingLevel"] = _map_effort_to_thinking_level(
                reasoning.effort, model_name=model_name
            )
        elif _gemini_supports_thinking_budget(model_name):
            cfg["thinkingBudget"] = _map_effort_to_thinking_budget(
                reasoning.effort, model_name=model_name
            )
    return cfg or None
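

# Illustrative note (editorial, not part of the packaged file): with the
# helpers above, reasoning.effort="medium" yields
#     {"thinkingLevel": "high"}   for e.g. models/gemini-3-pro-preview (level-capable)
#     {"thinkingBudget": 2048}    for e.g. models/gemini-2.5-flash (budget-capable)
# and no thinkingConfig at all for models that support neither.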


def _map_effort_to_thinking_level(effort: object, *, model_name: str) -> str:
    """
    Gemini 3 thinkingLevel mapping (per Google docs):
    - Gemini 3 Pro: low/high only
    - Gemini 3 Flash: minimal/low/medium/high
    """
    eff = normalize_reasoning_effort(effort)
    mid = (
        model_name[len("models/") :] if model_name.startswith("models/") else model_name
    )
    mid = mid.strip().lower()
    is_flash = "gemini-3-flash" in mid
    if eff in {"none", "minimal"}:
        return "minimal" if is_flash else "low"
    if eff == "low":
        return "low"
    if eff == "medium":
        return "medium" if is_flash else "high"
    return "high"


def _map_effort_to_thinking_budget(effort: object, *, model_name: str) -> int:
    """
    Gemini 2.5+ thinkingBudget mapping (based on Google docs + LiteLLM defaults).

    - `none` disables thinking where possible (2.5 Pro cannot disable it, so the
      minimum budget is used instead)
    - `minimal/low/medium/high/xhigh` map to progressively larger budgets
    """
    eff = normalize_reasoning_effort(effort)
    mid = (
        model_name[len("models/") :] if model_name.startswith("models/") else model_name
    )
    mid = mid.strip().lower()

    is_25_flash_lite = "gemini-2.5-flash-lite" in mid
    is_25_pro = "gemini-2.5-pro" in mid
    is_25_flash = "gemini-2.5-flash" in mid

    if eff == "none":
        return 128 if is_25_pro else 0
    if eff == "minimal":
        if is_25_flash_lite:
            return 512
        if is_25_pro:
            return 128
        if is_25_flash:
            return 1
        return 128
    if eff == "low":
        return 1024
    if eff == "medium":
        return 2048
    if eff == "high":
        return 4096
    return 8192
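

# Illustrative note (editorial, not part of the packaged file): worked examples
# of the budget mapping above:
#     effort="none"    on gemini-2.5-pro        -> 128 (cannot be disabled)
#     effort="none"    on gemini-2.5-flash      -> 0
#     effort="minimal" on gemini-2.5-flash-lite -> 512
#     effort="high"    on any budget model      -> 4096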


def _gemini_response_modalities(modalities: Sequence[str]) -> list[str]:
    out: list[str] = []
    for m in modalities:
        m = m.lower()
        if m == "text":
            out.append("TEXT")
        elif m == "image":
            out.append("IMAGE")
        elif m == "audio":
            out.append("AUDIO")
        elif m == "video":
            raise not_supported_error("Gemini does not support video response modality")
        elif m == "embedding":
            raise not_supported_error(
                "embedding should use embedContent/batchEmbedContents"
            )
        else:
            raise invalid_request_error(f"unknown modality: {m}")
    return out
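

# Illustrative note (editorial, not part of the packaged file):
#     _gemini_response_modalities(["text", "image"]) == ["TEXT", "IMAGE"]
# while "video" or "embedding" anywhere in the list raises.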


def _message_to_content(
    adapter: GeminiAdapter, message: Message, *, timeout_ms: int | None
) -> dict[str, Any]:
    if message.role == "user" and any(p.type == "tool_call" for p in message.content):
        raise invalid_request_error(
            "tool_call parts are only allowed in assistant messages"
        )
    if message.role == "assistant" and any(
        p.type == "tool_result" for p in message.content
    ):
        raise invalid_request_error("tool_result parts must be sent as role='tool'")
    if message.role == "tool" and any(p.type != "tool_result" for p in message.content):
        raise invalid_request_error("tool messages may only contain tool_result parts")
    role = (
        "user"
        if message.role in {"user", "tool"}
        else "model"
        if message.role == "assistant"
        else message.role
    )
    if role not in {"user", "model"}:
        raise not_supported_error(f"Gemini does not support role: {message.role}")
    parts = [
        _part_to_gemini_part(adapter, p, timeout_ms=timeout_ms) for p in message.content
    ]
    return {"role": role, "parts": parts}
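

# Illustrative note (editorial, not part of the packaged file): role mapping
# performed by _message_to_content:
#     user      -> "user"
#     tool      -> "user"   (tool_result parts become functionResponse)
#     assistant -> "model"
# System messages never reach this function; _generate_body filters them out
# and sends their text as systemInstruction.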
|
|
1256
|
+
|
|
1257
|
+
|
|
1258
|
+
def _require_tool_call_meta(part: Part) -> tuple[str, Any]:
|
|
1259
|
+
name = part.meta.get("name")
|
|
1260
|
+
if not isinstance(name, str) or not name.strip():
|
|
1261
|
+
raise invalid_request_error("tool_call.meta.name must be a non-empty string")
|
|
1262
|
+
arguments = part.meta.get("arguments")
|
|
1263
|
+
if not isinstance(arguments, dict):
|
|
1264
|
+
raise invalid_request_error("Gemini tool_call.meta.arguments must be an object")
|
|
1265
|
+
return (name.strip(), arguments)
|
|
1266
|
+
|
|
1267
|
+
|
|
1268
|
+
def _require_tool_result_meta(part: Part) -> tuple[str, Any]:
|
|
1269
|
+
name = part.meta.get("name")
|
|
1270
|
+
if not isinstance(name, str) or not name.strip():
|
|
1271
|
+
raise invalid_request_error("tool_result.meta.name must be a non-empty string")
|
|
1272
|
+
return (name.strip(), part.meta.get("result"))
|
|
1273
|
+
|
|
1274
|
+
|
|
1275
|
+
def _part_to_gemini_part(
|
|
1276
|
+
adapter: GeminiAdapter, part: Part, *, timeout_ms: int | None
|
|
1277
|
+
) -> dict[str, Any]:
|
|
1278
|
+
if part.type == "text":
|
|
1279
|
+
return {"text": part.require_text()}
|
|
1280
|
+
if part.type == "tool_call":
|
|
1281
|
+
name, arguments = _require_tool_call_meta(part)
|
|
1282
|
+
return {"functionCall": {"name": name, "args": arguments}}
|
|
1283
|
+
if part.type == "tool_result":
|
|
1284
|
+
name, result = _require_tool_result_meta(part)
|
|
1285
|
+
response = result if isinstance(result, dict) else {"result": result}
|
|
1286
|
+
return {"functionResponse": {"name": name, "response": response}}
|
|
1287
|
+
if part.type in {"image", "audio", "video"}:
|
|
1288
|
+
source = part.require_source()
|
|
1289
|
+
mime_type = part.mime_type
|
|
1290
|
+
if mime_type is None and isinstance(source, PartSourcePath):
|
|
1291
|
+
mime_type = detect_mime_type(source.path)
|
|
1292
|
+
if not mime_type:
|
|
1293
|
+
raise invalid_request_error(
|
|
1294
|
+
f"{part.type} requires mime_type (or path extension)"
|
|
1295
|
+
)
|
|
1296
|
+
|
|
1297
|
+
if isinstance(source, PartSourceRef):
|
|
1298
|
+
return {"fileData": {"mimeType": mime_type, "fileUri": source.id}}
|
|
1299
|
+
|
|
1300
|
+
if isinstance(source, PartSourceUrl):
|
|
1301
|
+
tmp = download_to_tempfile(
|
|
1302
|
+
url=source.url,
|
|
1303
|
+
timeout_ms=timeout_ms,
|
|
1304
|
+
max_bytes=None,
|
|
1305
|
+
proxy_url=adapter.proxy_url,
|
|
1306
|
+
)
|
|
1307
|
+
try:
|
|
1308
|
+
return _upload_or_inline(
|
|
1309
|
+
adapter,
|
|
1310
|
+
part,
|
|
1311
|
+
file_path=tmp,
|
|
1312
|
+
mime_type=mime_type,
|
|
1313
|
+
timeout_ms=timeout_ms,
|
|
1314
|
+
)
|
|
1315
|
+
finally:
|
|
1316
|
+
try:
|
|
1317
|
+
os.unlink(tmp)
|
|
1318
|
+
except OSError:
|
|
1319
|
+
pass
|
|
1320
|
+
|
|
1321
|
+
if isinstance(source, PartSourcePath):
|
|
1322
|
+
return _upload_or_inline(
|
|
1323
|
+
adapter,
|
|
1324
|
+
part,
|
|
1325
|
+
file_path=source.path,
|
|
1326
|
+
mime_type=mime_type,
|
|
1327
|
+
timeout_ms=timeout_ms,
|
|
1328
|
+
)
|
|
1329
|
+
|
|
1330
|
+
if isinstance(source, PartSourceBytes) and source.encoding == "base64":
|
|
1331
|
+
b64 = source.data
|
|
1332
|
+
if not isinstance(b64, str) or not b64:
|
|
1333
|
+
raise invalid_request_error(
|
|
1334
|
+
f"{part.type} base64 data must be non-empty"
|
|
1335
|
+
)
|
|
1336
|
+
return {"inlineData": {"mimeType": mime_type, "data": b64}}
|
|
1337
|
+
|
|
1338
|
+
assert isinstance(source, PartSourceBytes)
|
|
1339
|
+
data = source.data
|
|
1340
|
+
if not isinstance(data, bytes):
|
|
1341
|
+
raise invalid_request_error(f"{part.type} bytes data must be bytes")
|
|
1342
|
+
if len(data) <= 20 * 1024 * 1024 and part.type == "image":
|
|
1343
|
+
return {
|
|
1344
|
+
"inlineData": {"mimeType": mime_type, "data": bytes_to_base64(data)}
|
|
1345
|
+
}
|
|
1346
|
+
with tempfile.NamedTemporaryFile(
|
|
1347
|
+
prefix="genaisdk-", suffix=".bin", delete=False
|
|
1348
|
+
) as f:
|
|
1349
|
+
f.write(data)
|
|
1350
|
+
tmp = f.name
|
|
1351
|
+
try:
|
|
1352
|
+
return _upload_or_inline(
|
|
1353
|
+
adapter, part, file_path=tmp, mime_type=mime_type, timeout_ms=timeout_ms
|
|
1354
|
+
)
|
|
1355
|
+
finally:
|
|
1356
|
+
try:
|
|
1357
|
+
os.unlink(tmp)
|
|
1358
|
+
except OSError:
|
|
1359
|
+
pass
|
|
1360
|
+
|
|
1361
|
+
if part.type == "embedding":
|
|
1362
|
+
raise not_supported_error(
|
|
1363
|
+
"embedding is not a valid Gemini generateContent input part"
|
|
1364
|
+
)
|
|
1365
|
+
raise not_supported_error(f"unsupported part type: {part.type}")
|
|
1366
|
+
|
|
1367
|
+
|
|
1368
|
+
+def _upload_or_inline(
+    adapter: GeminiAdapter,
+    part: Part,
+    *,
+    file_path: str,
+    mime_type: str,
+    timeout_ms: int | None,
+) -> dict[str, Any]:
+    max_inline = 20 * 1024 * 1024
+    if part.type == "image":
+        st = os.stat(file_path)
+        if st.st_size <= max_inline:
+            data = file_to_bytes(file_path, max_inline)
+            return {
+                "inlineData": {"mimeType": mime_type, "data": bytes_to_base64(data)}
+            }
+    if part.type == "audio":
+        st = os.stat(file_path)
+        if st.st_size <= max_inline:
+            data = file_to_bytes(file_path, max_inline)
+            return {
+                "inlineData": {"mimeType": mime_type, "data": bytes_to_base64(data)}
+            }
+    if part.type == "video":
+        st = os.stat(file_path)
+        if not adapter.supports_file_upload:
+            if st.st_size <= max_inline:
+                data = file_to_bytes(file_path, max_inline)
+                return {
+                    "inlineData": {"mimeType": mime_type, "data": bytes_to_base64(data)}
+                }
+            raise not_supported_error(
+                f"file too large for inline bytes ({st.st_size} > {max_inline}); use url/ref instead"
+            )
+
+    file_uri = adapter._upload_to_file_uri(
+        file_path, mime_type=mime_type, timeout_ms=timeout_ms
+    )
+    return {"fileData": {"mimeType": mime_type, "fileUri": file_uri}}
+
+
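+# _gemini_part_to_parts maps one Gemini wire part (text / functionCall /
+# inlineData / fileData) back onto SDK Part objects. For tuzi-backed
+# providers, markdown image URLs embedded in text are split out into
+# separate image parts.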
+def _gemini_part_to_parts(part: dict[str, Any], *, provider_name: str) -> list[Part]:
+    if "text" in part and isinstance(part["text"], str):
+        text = part["text"]
+        if provider_name.startswith("tuzi"):
+            urls = [
+                m.group(1).strip()
+                for m in _MD_IMAGE_URL_RE.finditer(text)
+                if m.group(1).strip()
+            ]
+            if urls:
+                out: list[Part] = []
+                remaining = _MD_IMAGE_URL_RE.sub("", text).strip()
+                if remaining:
+                    out.append(Part.from_text(remaining))
+                out.extend(
+                    [Part(type="image", source=PartSourceUrl(url=u)) for u in urls]
+                )
+                return out
+        return [Part.from_text(text)]
+    fc_obj = part.get("functionCall") or part.get("function_call")
+    if isinstance(fc_obj, dict):
+        fc = fc_obj
+        name = fc.get("name")
+        args = fc.get("args")
+        if isinstance(name, str) and name and isinstance(args, dict):
+            return [Part.tool_call(name=name, arguments=args)]
+        return []
+    if "inlineData" in part and isinstance(part["inlineData"], dict):
+        blob = part["inlineData"]
+        mime = blob.get("mimeType")
+        data_b64 = blob.get("data")
+        if not isinstance(mime, str) or not isinstance(data_b64, str):
+            return []
+        if mime.startswith("image/"):
+            return [
+                Part(
+                    type="image",
+                    mime_type=mime,
+                    source=PartSourceBytes(data=data_b64, encoding="base64"),
+                )
+            ]
+        if mime.startswith("audio/"):
+            return [
+                Part(
+                    type="audio",
+                    mime_type=mime,
+                    source=PartSourceBytes(data=data_b64, encoding="base64"),
+                )
+            ]
+        if mime.startswith("video/"):
+            return [
+                Part(
+                    type="video",
+                    mime_type=mime,
+                    source=PartSourceBytes(data=data_b64, encoding="base64"),
+                )
+            ]
+        return [
+            Part(
+                type="file",
+                mime_type=mime,
+                source=PartSourceBytes(data=data_b64, encoding="base64"),
+            )
+        ]
+    if "fileData" in part and isinstance(part["fileData"], dict):
+        fd = part["fileData"]
+        uri = fd.get("fileUri")
+        mime = fd.get("mimeType")
+        if not isinstance(uri, str) or not uri:
+            return []
+        mime_s = mime if isinstance(mime, str) else None
+        kind: PartType = "file"
+        if mime_s and mime_s.startswith("image/"):
+            kind = "image"
+        elif mime_s and mime_s.startswith("audio/"):
+            kind = "audio"
+        elif mime_s and mime_s.startswith("video/"):
+            kind = "video"
+        return [
+            Part(
+                type=kind,
+                mime_type=mime_s,
+                source=PartSourceRef(provider=provider_name, id=uri),
+            )
+        ]
+    return []
+
+
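+# Translate Gemini's usageMetadata counters into the SDK's Usage type.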
+def _usage_from_gemini(usage: Any) -> Usage | None:
+    if not isinstance(usage, dict):
+        return None
+    return Usage(
+        input_tokens=usage.get("promptTokenCount"),
+        output_tokens=usage.get("candidatesTokenCount"),
+        total_tokens=usage.get("totalTokenCount"),
+    )
+
+
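+# Veo operation responses nest the generated video URI in more than one
+# shape; probe generateVideoResponse.generatedSamples, top-level
+# generatedSamples, then generatedVideos[0].video, in that order.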
+def _extract_veo_video_uri(response: Any) -> str | None:
+    if not isinstance(response, dict):
+        return None
+    gvr = response.get("generateVideoResponse")
+    if isinstance(gvr, dict):
+        uri = _extract_veo_video_uri_from_samples(gvr.get("generatedSamples"))
+        if uri:
+            return uri
+    uri = _extract_veo_video_uri_from_samples(response.get("generatedSamples"))
+    if uri:
+        return uri
+    videos = response.get("generatedVideos") or response.get("generated_videos")
+    if isinstance(videos, list) and videos:
+        v0 = videos[0]
+        if isinstance(v0, dict):
+            v = v0.get("video")
+            if isinstance(v, dict):
+                uri = v.get("uri") or v.get("downloadUri") or v.get("fileUri")
+                if isinstance(uri, str) and uri:
+                    return uri
+    return None
+
+
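+# Within a sample, the URI may appear as video.uri / video.downloadUri /
+# video.fileUri, or directly as uri on the sample itself.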
+def _extract_veo_video_uri_from_samples(samples: Any) -> str | None:
+    if not isinstance(samples, list) or not samples:
+        return None
+    s0 = samples[0]
+    if not isinstance(s0, dict):
+        return None
+    v = s0.get("video")
+    if isinstance(v, dict):
+        uri = v.get("uri") or v.get("downloadUri") or v.get("fileUri")
+        if isinstance(uri, str) and uri:
+            return uri
+    uri = s0.get("uri")
+    if isinstance(uri, str) and uri:
+        return uri
+    return None
+
+
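+# Return the first non-empty text part of the first candidate, or None.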
+def _first_candidate_text(obj: dict[str, Any]) -> str | None:
+    candidates = obj.get("candidates")
+    if not isinstance(candidates, list) or not candidates:
+        return None
+    cand0 = candidates[0]
+    if not isinstance(cand0, dict):
+        return None
+    content = cand0.get("content")
+    if not isinstance(content, dict):
+        return None
+    parts = content.get("parts")
+    if not isinstance(parts, list):
+        return None
+    for p in parts:
+        if not isinstance(p, dict):
+            continue
+        text = p.get("text")
+        if isinstance(text, str) and text:
+            return text
+    return None
+
+
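+# Pull a tuzi task id out of model text via the module-level
+# _TUZI_TASK_ID_RE pattern (defined earlier in this file).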
+def _extract_tuzi_task_id(text: str) -> str | None:
+    m = _TUZI_TASK_ID_RE.search(text)
+    if m is None:
+        return None
+    tid = m.group(1).strip()
+    return tid or None
+
+
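+# Return the first .mp4 URL found in free-form text, or None.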
+def _extract_first_mp4_url(text: str) -> str | None:
+    m = _MP4_URL_RE.search(text)
+    if m is None:
+        return None
+    url = m.group(0).strip()
+    return url or None
+
+
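+# Poll the asyncdata "source/<task_id>" endpoint until the returned content
+# contains an .mp4 URL or the caller's deadline passes: each GET is capped
+# at 30 s, polls are spaced roughly 2 s apart, and None signals a timeout.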
+def _poll_tuzi_video_mp4(
+    *, task_id: str, deadline: float, proxy_url: str | None
+) -> str | None:
+    poll_url = f"{_ASYNCDATA_BASE_URL}/source/{task_id}"
+    while True:
+        remaining_ms = int((deadline - time.time()) * 1000)
+        if remaining_ms <= 0:
+            return None
+        obj = request_json(
+            method="GET",
+            url=poll_url,
+            headers=None,
+            json_body=None,
+            timeout_ms=min(30_000, remaining_ms),
+            proxy_url=proxy_url,
+        )
+        content = obj.get("content")
+        if isinstance(content, str) and content:
+            mp4_url = _extract_first_mp4_url(content)
+            if mp4_url:
+                return mp4_url
+        time.sleep(min(2.0, max(0.0, deadline - time.time())))