flexllm 0.3.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flexllm/__init__.py +224 -0
- flexllm/__main__.py +1096 -0
- flexllm/async_api/__init__.py +9 -0
- flexllm/async_api/concurrent_call.py +100 -0
- flexllm/async_api/concurrent_executor.py +1036 -0
- flexllm/async_api/core.py +373 -0
- flexllm/async_api/interface.py +12 -0
- flexllm/async_api/progress.py +277 -0
- flexllm/base_client.py +988 -0
- flexllm/batch_tools/__init__.py +16 -0
- flexllm/batch_tools/folder_processor.py +317 -0
- flexllm/batch_tools/table_processor.py +363 -0
- flexllm/cache/__init__.py +10 -0
- flexllm/cache/response_cache.py +293 -0
- flexllm/chain_of_thought_client.py +1120 -0
- flexllm/claudeclient.py +402 -0
- flexllm/client_pool.py +698 -0
- flexllm/geminiclient.py +563 -0
- flexllm/llm_client.py +523 -0
- flexllm/llm_parser.py +60 -0
- flexllm/mllm_client.py +559 -0
- flexllm/msg_processors/__init__.py +174 -0
- flexllm/msg_processors/image_processor.py +729 -0
- flexllm/msg_processors/image_processor_helper.py +485 -0
- flexllm/msg_processors/messages_processor.py +341 -0
- flexllm/msg_processors/unified_processor.py +1404 -0
- flexllm/openaiclient.py +256 -0
- flexllm/pricing/__init__.py +104 -0
- flexllm/pricing/data.json +1201 -0
- flexllm/pricing/updater.py +223 -0
- flexllm/provider_router.py +213 -0
- flexllm/token_counter.py +270 -0
- flexllm/utils/__init__.py +1 -0
- flexllm/utils/core.py +41 -0
- flexllm-0.3.3.dist-info/METADATA +573 -0
- flexllm-0.3.3.dist-info/RECORD +39 -0
- flexllm-0.3.3.dist-info/WHEEL +4 -0
- flexllm-0.3.3.dist-info/entry_points.txt +3 -0
- flexllm-0.3.3.dist-info/licenses/LICENSE +201 -0
flexllm/geminiclient.py
ADDED
@@ -0,0 +1,563 @@
"""
Gemini API Client - a batch-calling client for Google Gemini models.

Keeps the same interface as OpenAIClient, so upper-layer code can switch
between the two seamlessly.
"""

import re
from typing import List, Optional, Any, Union

from loguru import logger

from .base_client import LLMClientBase
from .cache import ResponseCacheConfig

class GeminiClient(LLMClientBase):
    """
    Google Gemini API client.

    Supports both the Gemini Developer API and Vertex AI.

    Example (Gemini Developer API):
        >>> client = GeminiClient(api_key="your-key", model="gemini-3-flash-preview")
        >>> result = await client.chat_completions(messages)

    Example (Vertex AI):
        >>> client = GeminiClient(
        ...     project_id="your-project-id",
        ...     location="us-central1",
        ...     model="gemini-3-flash-preview",
        ...     use_vertex_ai=True,
        ... )

    Example (thinking parameter - unified control over model reasoning):
        >>> # Disable thinking (fastest responses)
        >>> result = client.chat_completions_sync(
        ...     messages=[{"role": "user", "content": "1+1=?"}],
        ...     thinking=False,
        ... )
        >>> # Enable thinking and return the thought content
        >>> result = client.chat_completions_sync(
        ...     messages=[{"role": "user", "content": "a complex reasoning problem"}],
        ...     thinking=True,
        ...     return_raw=True,
        ... )
        >>> parsed = GeminiClient.parse_thoughts(result.data)

    thinking parameter values:
        - False: disable thinking (thinkingLevel=minimal)
        - True: enable thinking and return the thought content (includeThoughts=True)
        - "minimal"/"low"/"medium"/"high": set the thinking depth (Gemini 3 only)
        - None: use the model's default behavior
    """

    DEFAULT_BASE_URL = "https://generativelanguage.googleapis.com/v1beta"
    VERTEX_AI_URL_TEMPLATE = "https://{location}-aiplatform.googleapis.com/v1"

    def __init__(
        self,
        api_key: str = None,
        model: str = None,
        base_url: str = None,
        concurrency_limit: int = 10,
        max_qps: int = 60,
        timeout: int = 120,
        retry_times: int = 3,
        retry_delay: float = 1.0,
        cache_image: bool = False,
        cache_dir: str = "image_cache",
        cache: Optional[ResponseCacheConfig] = None,
        use_vertex_ai: bool = False,
        project_id: str = None,
        location: str = "us-central1",
        credentials: Any = None,
        **kwargs,
    ):
        self._use_vertex_ai = use_vertex_ai
        self._project_id = project_id
        self._location = location
        self._credentials = credentials
        self._access_token = None
        self._token_expiry = None

        if use_vertex_ai:
            if not project_id:
                raise ValueError("Vertex AI mode requires a project_id")
            effective_base_url = base_url or self.VERTEX_AI_URL_TEMPLATE.format(location=location)
        else:
            if not api_key:
                raise ValueError("Gemini Developer API mode requires an api_key")
            effective_base_url = base_url or self.DEFAULT_BASE_URL

        super().__init__(
            base_url=effective_base_url,
            api_key=api_key,
            model=model,
            concurrency_limit=concurrency_limit,
            max_qps=max_qps,
            timeout=timeout,
            retry_times=retry_times,
            retry_delay=retry_delay,
            cache_image=cache_image,
            cache_dir=cache_dir,
            cache=cache,
            **kwargs,
        )

    # ========== Core methods required by the base class ==========

    def _get_url(self, model: str, stream: bool = False) -> str:
        action = "streamGenerateContent" if stream else "generateContent"
        if self._use_vertex_ai:
            return (
                f"{self._base_url}/projects/{self._project_id}"
                f"/locations/{self._location}/publishers/google/models/{model}:{action}"
            )
        return f"{self._base_url}/models/{model}:{action}?key={self._api_key}"

    def _get_stream_url(self, model: str) -> str:
        """Gemini streaming requires the alt=sse query parameter"""
        url = self._get_url(model, stream=True)
        return url + ("&alt=sse" if "?" in url else "?alt=sse")

    def _get_headers(self) -> dict:
        headers = {"Content-Type": "application/json"}
        if self._use_vertex_ai:
            headers["Authorization"] = f"Bearer {self._get_access_token()}"
        return headers

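A quick illustration (not from the wheel; the model name and key are placeholders) of the URLs the helpers above produce in Developer API mode:

client = GeminiClient(api_key="API_KEY", model="gemini-2.0-flash")
client._get_url("gemini-2.0-flash")
# -> https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent?key=API_KEY
client._get_stream_url("gemini-2.0-flash")
# -> https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:streamGenerateContent?key=API_KEY&alt=sse
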
    def _build_request_body(
        self,
        messages: List[dict],
        model: str,
        stream: bool = False,
        max_tokens: int = None,
        temperature: float = None,
        top_p: float = None,
        top_k: int = None,
        stop_sequences: List[str] = None,
        safety_settings: List[dict] = None,
        thinking: Union[bool, str, None] = None,
        **kwargs,
    ) -> dict:
        """
        Build the request body.

        Args:
            thinking: unified control over model reasoning
                - False: disable thinking (thinkingLevel=minimal)
                - True: enable thinking and return the thought content (includeThoughts=True)
                - "minimal"/"low"/"medium"/"high": set the thinking depth
                - None: use the model's default behavior
        """
        contents, system_obj = self._convert_messages_to_contents(messages)
        body = {"contents": contents}

        if system_obj:
            body["systemInstruction"] = system_obj

        gen_config = {}
        if max_tokens is not None:
            gen_config["maxOutputTokens"] = max_tokens
        if temperature is not None:
            gen_config["temperature"] = temperature
        if top_p is not None:
            gen_config["topP"] = top_p
        if top_k is not None:
            gen_config["topK"] = top_k
        if stop_sequences:
            gen_config["stopSequences"] = stop_sequences

        # Build thinkingConfig
        thinking_config = {}
        if thinking is False:
            # Disable thinking
            thinking_config["thinkingLevel"] = "minimal"
        elif thinking is True:
            # Enable thinking and return the thought content
            thinking_config["includeThoughts"] = True
        elif isinstance(thinking, str):
            # Set the thinking depth
            thinking_config["thinkingLevel"] = thinking
            thinking_config["includeThoughts"] = True
        # thinking=None: leave unset and use the model's default behavior

        if thinking_config:
            gen_config["thinkingConfig"] = thinking_config

        if gen_config:
            body["generationConfig"] = gen_config
        if safety_settings:
            body["safetySettings"] = safety_settings

        return body

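A sketch (not from the wheel) of the body this method assembles; with thinking=False the generationConfig picks up a minimal thinkingLevel:

body = client._build_request_body(
    messages=[
        {"role": "system", "content": "Be brief."},
        {"role": "user", "content": "1+1=?"},
    ],
    model="gemini-2.0-flash",
    max_tokens=64,
    thinking=False,
)
assert body == {
    "contents": [{"role": "user", "parts": [{"text": "1+1=?"}]}],
    "systemInstruction": {"parts": [{"text": "Be brief."}]},
    "generationConfig": {
        "maxOutputTokens": 64,
        "thinkingConfig": {"thinkingLevel": "minimal"},
    },
}
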
    def _extract_content(self, response_data: dict) -> Optional[str]:
        try:
            candidates = response_data.get("candidates", [])
            if not candidates:
                if "promptFeedback" in response_data:
                    block_reason = response_data["promptFeedback"].get("blockReason", "UNKNOWN")
                    logger.warning(f"Request blocked by Gemini: {block_reason}")
                return None

            parts = candidates[0].get("content", {}).get("parts", [])
            # Only extract the non-thought parts (i.e. the final answer)
            texts = [p.get("text", "") for p in parts if "text" in p and not p.get("thought")]
            return "".join(texts) if texts else None
        except Exception as e:
            logger.warning(f"Failed to extract response text: {e}")
            return None

    def _extract_usage(self, response_data: dict) -> Optional[dict]:
        """
        Extract the usage information from a Gemini API response.

        Gemini response format:
        {
            "candidates": [...],
            "usageMetadata": {
                "promptTokenCount": 100,
                "candidatesTokenCount": 50,
                "totalTokenCount": 150
            }
        }

        Converted to the unified format:
        {
            "prompt_tokens": 100,
            "completion_tokens": 50,
            "total_tokens": 150
        }
        """
        if not response_data:
            return None

        usage_metadata = response_data.get("usageMetadata")
        if not usage_metadata:
            return None

        return {
            "prompt_tokens": usage_metadata.get("promptTokenCount", 0),
            "completion_tokens": usage_metadata.get("candidatesTokenCount", 0),
            "total_tokens": usage_metadata.get("totalTokenCount", 0),
        }

    def _extract_tool_calls(self, response_data: dict):
        """
        Extract Gemini-format function calls.

        Gemini response format:
        {
            "candidates": [{
                "content": {
                    "parts": [{
                        "functionCall": {
                            "name": "get_weather",
                            "args": {"location": "Tokyo"}
                        }
                    }]
                }
            }]
        }
        """
        import json
        from .base_client import ToolCall

        try:
            candidates = response_data.get("candidates", [])
            if not candidates:
                return None

            parts = candidates[0].get("content", {}).get("parts", [])
            tool_calls = []
            for i, part in enumerate(parts):
                if "functionCall" in part:
                    fc = part["functionCall"]
                    tool_calls.append(ToolCall(
                        id=f"call_{i}",  # Gemini provides no id, so generate one
                        type="function",
                        function={
                            "name": fc.get("name", ""),
                            "arguments": json.dumps(fc.get("args", {}))
                        }
                    ))
            return tool_calls if tool_calls else None
        except Exception:
            return None

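An input/output sketch (not from the wheel; it assumes ToolCall exposes its constructor fields as attributes):

resp = {
    "candidates": [{
        "content": {"parts": [
            {"functionCall": {"name": "get_weather", "args": {"location": "Tokyo"}}},
        ]}
    }]
}
calls = client._extract_tool_calls(resp)
# calls[0].id == "call_0", calls[0].type == "function"
# calls[0].function == {"name": "get_weather", "arguments": '{"location": "Tokyo"}'}
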
    @staticmethod
    def parse_thoughts(response_data: dict) -> dict:
        """
        Parse the thought content and the answer out of a response.

        Use this method to parse responses produced with thinking=True.

        Args:
            response_data: raw response data (obtained via return_raw=True)

        Returns:
            dict: {
                "thought": str,  # summary of the reasoning process (may be empty)
                "answer": str,   # final answer
            }

        Example:
            >>> result = await client.chat_completions(
            ...     messages=[...],
            ...     thinking=True,
            ...     return_raw=True,
            ... )
            >>> parsed = GeminiClient.parse_thoughts(result.data)
            >>> print("thought:", parsed["thought"])
            >>> print("answer:", parsed["answer"])
        """
        thought_parts = []
        answer_parts = []

        try:
            candidates = response_data.get("candidates", [])
            if not candidates:
                return {"thought": "", "answer": ""}

            parts = candidates[0].get("content", {}).get("parts", [])
            for part in parts:
                text = part.get("text", "")
                if not text:
                    continue
                if part.get("thought"):
                    thought_parts.append(text)
                else:
                    answer_parts.append(text)

            return {
                "thought": "".join(thought_parts),
                "answer": "".join(answer_parts),
            }
        except Exception as e:
            logger.warning(f"Failed to parse thoughts: {e}")
            return {"thought": "", "answer": ""}

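A concrete input/output sketch for parse_thoughts (fabricated response data, shaped per the format notes above):

raw = {
    "candidates": [{
        "content": {"parts": [
            {"text": "Work through the sum step by step...", "thought": True},
            {"text": "The answer is 42."},
        ]}
    }]
}
parsed = GeminiClient.parse_thoughts(raw)
# parsed == {"thought": "Work through the sum step by step...",
#            "answer": "The answer is 42."}
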
    def _extract_stream_content(self, data: dict) -> Optional[str]:
        """
        Extract the text content from a Gemini streaming chunk.

        Gemini streaming response format:
        {
            "candidates": [{
                "content": {
                    "parts": [{"text": "partial text"}],
                    "role": "model"
                }
            }]
        }
        """
        try:
            candidates = data.get("candidates", [])
            if not candidates:
                return None

            # Take the first candidate's content
            content = candidates[0].get("content", {})
            parts = content.get("parts", [])

            # Return the first text part in the chunk
            for part in parts:
                if "text" in part:
                    return part["text"]
            return None
        except Exception:
            return None

    async def chat_completions_stream(
        self,
        messages: List[dict],
        model: str = None,
        return_usage: bool = False,
        preprocess_msg: bool = False,
        url: str = None,
        timeout: int = None,
        **kwargs,
    ):
        """
        Streaming chat completion (Gemini-specific implementation).

        In Gemini streaming responses, the usage information arrives in the
        usageMetadata field of the final chunk.

        Args:
            messages: list of messages
            model: model name
            return_usage: whether to yield usage information
            preprocess_msg: whether to preprocess the messages
            url: custom request URL
            timeout: timeout in seconds

        Yields:
            - return_usage=False: str content fragments
            - return_usage=True: dict with a "type" key and the matching payload
        """
        import aiohttp
        import json

        effective_model = self._get_effective_model(model)
        messages = await self._preprocess_messages(messages, preprocess_msg)

        body = self._build_request_body(messages, effective_model, stream=True, **kwargs)
        # Note: Gemini needs no stream_options; usage is included in the final chunk automatically

        effective_url = url or self._get_stream_url(effective_model)
        headers = self._get_headers()

        effective_timeout = timeout if timeout is not None else self._timeout
        aio_timeout = aiohttp.ClientTimeout(total=effective_timeout)

        async with aiohttp.ClientSession(trust_env=True) as session:
            async with session.post(effective_url, json=body, headers=headers, timeout=aio_timeout) as response:
                if response.status != 200:
                    error_text = await response.text()
                    raise Exception(f"HTTP {response.status}: {error_text}")

                async for line in response.content:
                    line = line.decode("utf-8").strip()
                    if line.startswith("data: "):
                        data_str = line[6:]
                        if data_str == "[DONE]":
                            break
                        try:
                            data = json.loads(data_str)

                            # Extract the content fragment
                            content = self._extract_stream_content(data)
                            if content:
                                if return_usage:
                                    yield {"type": "content", "content": content}
                                else:
                                    yield content

                            # Check for usage (Gemini puts usageMetadata in the final chunk)
                            if return_usage and "usageMetadata" in data:
                                usage = self._extract_usage(data)
                                if usage:
                                    yield {"type": "usage", "usage": usage}

                        except json.JSONDecodeError:
                            continue

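Consuming the stream per the Yields contract above (a usage sketch, not from the wheel; runs inside an async function):

async for chunk in client.chat_completions_stream(messages, return_usage=True):
    if chunk["type"] == "content":
        print(chunk["content"], end="", flush=True)
    elif chunk["type"] == "usage":
        print("\ntotal tokens:", chunk["usage"]["total_tokens"])
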
    # ========== Gemini-specific methods ==========

    def _get_access_token(self) -> str:
        """Obtain an access token for Vertex AI"""
        import time

        if self._access_token and self._token_expiry and time.time() < self._token_expiry - 60:
            return self._access_token

        try:
            import google.auth
            import google.auth.transport.requests

            credentials = self._credentials or google.auth.default(
                scopes=["https://www.googleapis.com/auth/cloud-platform"]
            )[0]

            request = google.auth.transport.requests.Request()
            credentials.refresh(request)

            self._access_token = credentials.token
            self._token_expiry = time.time() + 3600
            return self._access_token
        except ImportError:
            raise ImportError("Vertex AI mode requires google-auth: pip install google-auth")
        except Exception as e:
            raise RuntimeError(f"Failed to obtain a Vertex AI access token: {e}")

    def _convert_messages_to_contents(
        self, messages: List[dict], system_instruction: str = None
    ) -> tuple[List[dict], Optional[dict]]:
        """Convert OpenAI-format messages to Gemini format"""
        contents = []
        extracted_system = system_instruction

        for msg in messages:
            role = msg.get("role", "user")
            content = msg.get("content", "")

            if role == "system":
                if isinstance(content, str):
                    extracted_system = content
                elif isinstance(content, list):
                    texts = [p.get("text", "") for p in content if p.get("type") == "text"]
                    extracted_system = "\n".join(texts)
                continue

            gemini_role = "model" if role == "assistant" else "user"
            parts = self._convert_content_to_parts(content)

            if parts:
                contents.append({"role": gemini_role, "parts": parts})

        system_obj = {"parts": [{"text": extracted_system}]} if extracted_system else None
        return contents, system_obj

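A conversion sketch (not from the wheel): OpenAI-style messages in, Gemini contents plus a systemInstruction object out; note the system message is lifted out of the list:

contents, system_obj = client._convert_messages_to_contents([
    {"role": "system", "content": "You are terse."},
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello"},
])
assert contents == [
    {"role": "user", "parts": [{"text": "Hi"}]},
    {"role": "model", "parts": [{"text": "Hello"}]},
]
assert system_obj == {"parts": [{"text": "You are terse."}]}
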
    def _convert_content_to_parts(self, content: Any) -> List[dict]:
        """Convert OpenAI-format content to Gemini-format parts"""
        if content is None:
            return []
        if isinstance(content, str):
            return [{"text": content}]

        if isinstance(content, list):
            parts = []
            for item in content:
                if isinstance(item, str):
                    parts.append({"text": item})
                elif isinstance(item, dict):
                    item_type = item.get("type", "text")
                    if item_type == "text" and item.get("text"):
                        parts.append({"text": item["text"]})
                    elif item_type == "image_url":
                        if img := self._convert_image_url(item.get("image_url", {})):
                            parts.append(img)
                    elif item_type == "image":
                        if img := self._convert_image_direct(item):
                            parts.append(img)
            return parts
        return []

    def _convert_image_url(self, image_url_obj: dict) -> Optional[dict]:
        """Convert OpenAI's image_url format to Gemini's inline_data format"""
        url = image_url_obj.get("url", "")
        if not url:
            return None

        if url.startswith("data:"):
            match = re.match(r"data:([^;]+);base64,(.+)", url)
            if match:
                return {"inline_data": {"mime_type": match.group(1), "data": match.group(2)}}

        logger.warning(f"The Gemini API does not accept external URLs directly; convert to base64 first: {url[:50]}...")
        return None

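The data-URL path in a sketch (not from the wheel; the base64 payload is truncated for display):

part = client._convert_image_url({"url": "data:image/png;base64,iVBORw0KGgo..."})
# part == {"inline_data": {"mime_type": "image/png", "data": "iVBORw0KGgo..."}}
# A plain https:// URL logs a warning and returns None instead.
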
    def _convert_image_direct(self, image_obj: dict) -> Optional[dict]:
        """Handle raw image data"""
        data = image_obj.get("data", "")
        if data:
            return {"inline_data": {"mime_type": image_obj.get("mime_type", "image/jpeg"), "data": data}}
        return None

    def model_list(self) -> List[str]:
        """Fetch the list of available models"""
        import requests

        if self._use_vertex_ai:
            url = f"{self._base_url}/projects/{self._project_id}/locations/{self._location}/publishers/google/models"
            response = requests.get(url, headers=self._get_headers())
        else:
            response = requests.get(f"{self._base_url}/models?key={self._api_key}")

        if response.status_code == 200:
            models = response.json().get("models", [])
            return [m.get("name", "").replace("models/", "") for m in models]
        logger.error(f"Failed to fetch model list: {response.text}")
        return []
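An end-to-end usage sketch (not shipped in the wheel; it imports from the module path shown in the file list above, and the model name is a placeholder):

import asyncio

from flexllm.geminiclient import GeminiClient

async def main():
    client = GeminiClient(api_key="your-key", model="gemini-2.0-flash")
    print(client.model_list())  # e.g. ["gemini-2.0-flash", ...]
    result = await client.chat_completions(
        [{"role": "user", "content": "1+1=?"}]
    )
    print(result)

asyncio.run(main())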