aient 1.0.29__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aient/__init__.py +1 -0
- aient/core/.git +1 -0
- aient/core/__init__.py +1 -0
- aient/core/log_config.py +6 -0
- aient/core/models.py +227 -0
- aient/core/request.py +1361 -0
- aient/core/response.py +531 -0
- aient/core/test/test_base_api.py +17 -0
- aient/core/test/test_image.py +15 -0
- aient/core/test/test_payload.py +92 -0
- aient/core/utils.py +655 -0
- aient/models/__init__.py +9 -0
- aient/models/audio.py +63 -0
- aient/models/base.py +270 -0
- aient/models/chatgpt.py +856 -0
- aient/models/claude.py +640 -0
- aient/models/duckduckgo.py +241 -0
- aient/models/gemini.py +357 -0
- aient/models/groq.py +268 -0
- aient/models/vertex.py +420 -0
- aient/plugins/__init__.py +32 -0
- aient/plugins/arXiv.py +48 -0
- aient/plugins/config.py +178 -0
- aient/plugins/image.py +72 -0
- aient/plugins/registry.py +116 -0
- aient/plugins/run_python.py +156 -0
- aient/plugins/today.py +19 -0
- aient/plugins/websearch.py +393 -0
- aient/utils/__init__.py +0 -0
- aient/utils/prompt.py +143 -0
- aient/utils/scripts.py +235 -0
- aient-1.0.29.dist-info/METADATA +119 -0
- aient-1.0.29.dist-info/RECORD +36 -0
- aient-1.0.29.dist-info/WHEEL +5 -0
- aient-1.0.29.dist-info/licenses/LICENSE +7 -0
- aient-1.0.29.dist-info/top_level.txt +1 -0
aient/core/response.py
ADDED
@@ -0,0 +1,531 @@
import json
import random
import string
from datetime import datetime

from .log_config import logger

from .utils import safe_get, generate_sse_response, generate_no_stream_response, end_of_line

async def check_response(response, error_log):
    if response and not (200 <= response.status_code < 300):
        error_message = await response.aread()
        error_str = error_message.decode('utf-8', errors='replace')
        try:
            error_json = json.loads(error_str)
        except json.JSONDecodeError:
            error_json = error_str
        return {"error": f"{error_log} HTTP Error", "status_code": response.status_code, "details": error_json}
    return None

async def fetch_gemini_response_stream(client, url, headers, payload, model):
    timestamp = int(datetime.timestamp(datetime.now()))
    async with client.stream('POST', url, headers=headers, json=payload) as response:
        error_message = await check_response(response, "fetch_gemini_response_stream")
        if error_message:
            yield error_message
            return
        buffer = ""
        revicing_function_call = False
        function_full_response = "{"
        need_function_call = False
        is_finish = False
        # line_index = 0
        # last_text_line = 0
        # if "thinking" in model:
        #     is_thinking = True
        # else:
        #     is_thinking = False
        async for chunk in response.aiter_text():
            buffer += chunk

            while "\n" in buffer:
                line, buffer = buffer.split("\n", 1)
                # line_index += 1
                if line and '\"finishReason\": \"' in line:
                    is_finish = True
                    break
                # print(line)
                if line and '\"text\": \"' in line:
                    try:
                        json_data = json.loads("{" + line + "}")
                        content = json_data.get('text', '')
                        content = "\n".join(content.split("\\n"))
                        # content = content.replace("\n", "\n\n")
                        # if last_text_line == 0 and is_thinking:
                        #     content = "> " + content.lstrip()
                        # if is_thinking:
                        #     content = content.replace("\n", "\n> ")
                        # if last_text_line == line_index - 3:
                        #     is_thinking = False
                        #     content = "\n\n\n" + content.lstrip()
                        sse_string = await generate_sse_response(timestamp, model, content=content)
                        yield sse_string
                    except json.JSONDecodeError:
                        logger.error(f"Failed to parse JSON: {line}")
                    # last_text_line = line_index

                # Accumulate the multi-line functionCall JSON until its closing bracket arrives.
                if line and ('\"functionCall\": {' in line or revicing_function_call):
                    revicing_function_call = True
                    need_function_call = True
                    if ']' in line:
                        revicing_function_call = False
                        continue

                    function_full_response += line

            if is_finish:
                break

        if need_function_call:
            function_call = json.loads(function_full_response)
            function_call_name = function_call["functionCall"]["name"]
            sse_string = await generate_sse_response(timestamp, model, content=None, tools_id="chatcmpl-9inWv0yEtgn873CxMBzHeCeiHctTV", function_call_name=function_call_name)
            yield sse_string
            function_full_response = json.dumps(function_call["functionCall"]["args"])
            sse_string = await generate_sse_response(timestamp, model, content=None, tools_id="chatcmpl-9inWv0yEtgn873CxMBzHeCeiHctTV", function_call_name=None, function_call_content=function_full_response)
            yield sse_string
        yield "data: [DONE]" + end_of_line

async def fetch_vertex_claude_response_stream(client, url, headers, payload, model):
    timestamp = int(datetime.timestamp(datetime.now()))
    async with client.stream('POST', url, headers=headers, json=payload) as response:
        error_message = await check_response(response, "fetch_vertex_claude_response_stream")
        if error_message:
            yield error_message
            return

        buffer = ""
        revicing_function_call = False
        function_full_response = "{"
        need_function_call = False
        async for chunk in response.aiter_text():
            buffer += chunk
            while "\n" in buffer:
                line, buffer = buffer.split("\n", 1)
                # logger.info(f"{line}")
                if line and '\"text\": \"' in line:
                    try:
                        json_data = json.loads("{" + line + "}")
                        content = json_data.get('text', '')
                        content = "\n".join(content.split("\\n"))
                        sse_string = await generate_sse_response(timestamp, model, content=content)
                        yield sse_string
                    except json.JSONDecodeError:
                        logger.error(f"Failed to parse JSON: {line}")

                if line and ('\"type\": \"tool_use\"' in line or revicing_function_call):
                    revicing_function_call = True
                    need_function_call = True
                    if ']' in line:
                        revicing_function_call = False
                        continue

                    function_full_response += line

        if need_function_call:
            function_call = json.loads(function_full_response)
            function_call_name = function_call["name"]
            function_call_id = function_call["id"]
            sse_string = await generate_sse_response(timestamp, model, content=None, tools_id=function_call_id, function_call_name=function_call_name)
            yield sse_string
            function_full_response = json.dumps(function_call["input"])
            sse_string = await generate_sse_response(timestamp, model, content=None, tools_id=function_call_id, function_call_name=None, function_call_content=function_full_response)
            yield sse_string
        yield "data: [DONE]" + end_of_line

async def fetch_gpt_response_stream(client, url, headers, payload):
    timestamp = int(datetime.timestamp(datetime.now()))
    random.seed(timestamp)
    random_str = ''.join(random.choices(string.ascii_letters + string.digits, k=29))
    is_thinking = False
    has_send_thinking = False
    ark_tag = False
    async with client.stream('POST', url, headers=headers, json=payload) as response:
        error_message = await check_response(response, "fetch_gpt_response_stream")
        if error_message:
            yield error_message
            return

        buffer = ""
        enter_buffer = ""
        async for chunk in response.aiter_text():
            buffer += chunk
            while "\n" in buffer:
                line, buffer = buffer.split("\n", 1)
                # logger.info("line: %s", repr(line))
                if line and line != "data: " and line != "data:" and not line.startswith(": ") and (result:=line.lstrip("data: ").strip()):
                    if result.strip() == "[DONE]":
                        yield "data: [DONE]" + end_of_line
                        return
                    line = json.loads(result)
                    line['id'] = f"chatcmpl-{random_str}"

                    # Handle <think> tags
                    content = safe_get(line, "choices", 0, "delta", "content", default="")
                    if "<think>" in content:
                        is_thinking = True
                        ark_tag = True
                        content = content.replace("<think>", "")
                    if "</think>" in content:
                        end_think_reasoning_content = ""
                        end_think_content = ""
                        is_thinking = False

                        if content.rstrip('\n').endswith("</think>"):
                            end_think_reasoning_content = content.replace("</think>", "").rstrip('\n')
                        elif content.lstrip('\n').startswith("</think>"):
                            end_think_content = content.replace("</think>", "").lstrip('\n')
                        else:
                            end_think_reasoning_content = content.split("</think>")[0]
                            end_think_content = content.split("</think>")[1]

                        if end_think_reasoning_content:
                            sse_string = await generate_sse_response(timestamp, payload["model"], reasoning_content=end_think_reasoning_content)
                            yield sse_string
                        if end_think_content:
                            sse_string = await generate_sse_response(timestamp, payload["model"], content=end_think_content)
                            yield sse_string
                        continue
                    if is_thinking and ark_tag:
                        if not has_send_thinking:
                            content = content.replace("\n\n", "")
                        if content:
                            sse_string = await generate_sse_response(timestamp, payload["model"], reasoning_content=content)
                            yield sse_string
                            has_send_thinking = True
                        continue

                    # Handle poe "Thinking..." markers
                    if "Thinking..." in content and "\n> " in content:
                        is_thinking = True
                        content = content.replace("Thinking...", "").replace("\n> ", "")
                    if is_thinking and "\n\n" in content and not ark_tag:
                        is_thinking = False
                    if is_thinking and not ark_tag:
                        content = content.replace("\n> ", "")
                        if not has_send_thinking:
                            content = content.replace("\n", "")
                        if content:
                            sse_string = await generate_sse_response(timestamp, payload["model"], reasoning_content=content)
                            yield sse_string
                            has_send_thinking = True
                        continue

                    no_stream_content = safe_get(line, "choices", 0, "message", "content", default=None)
                    openrouter_reasoning = safe_get(line, "choices", 0, "delta", "reasoning", default="")
                    # print("openrouter_reasoning", repr(openrouter_reasoning), openrouter_reasoning.endswith("\\\\"), openrouter_reasoning.endswith("\\"))
                    # Reassemble "\n" escape sequences that can arrive split across reasoning deltas.
                    if openrouter_reasoning:
                        if openrouter_reasoning.endswith("\\"):
                            enter_buffer += openrouter_reasoning
                            continue
                        elif enter_buffer.endswith("\\") and openrouter_reasoning == 'n':
                            enter_buffer += "n"
                            continue
                        elif enter_buffer.endswith("\\n") and openrouter_reasoning == '\\n':
                            enter_buffer += "\\n"
                            continue
                        elif enter_buffer.endswith("\\n\\n"):
                            openrouter_reasoning = '\n\n' + openrouter_reasoning
                            enter_buffer = ""
                        elif enter_buffer:
                            openrouter_reasoning = enter_buffer + openrouter_reasoning
                            enter_buffer = ''
                        openrouter_reasoning = openrouter_reasoning.replace("\\n", "\n")

                        sse_string = await generate_sse_response(timestamp, payload["model"], reasoning_content=openrouter_reasoning)
                        yield sse_string
                    elif no_stream_content and has_send_thinking == False:
                        sse_string = await generate_sse_response(safe_get(line, "created", default=None), safe_get(line, "model", default=None), content=no_stream_content)
                        yield sse_string
                    else:
                        if no_stream_content:
                            del line["choices"][0]["message"]
                        yield "data: " + json.dumps(line).strip() + end_of_line

async def fetch_azure_response_stream(client, url, headers, payload):
    timestamp = int(datetime.timestamp(datetime.now()))
    is_thinking = False
    has_send_thinking = False
    ark_tag = False
    async with client.stream('POST', url, headers=headers, json=payload) as response:
        error_message = await check_response(response, "fetch_azure_response_stream")
        if error_message:
            yield error_message
            return

        buffer = ""
        sse_string = ""
        async for chunk in response.aiter_text():
            buffer += chunk
            while "\n" in buffer:
                line, buffer = buffer.split("\n", 1)
                # logger.info("line: %s", repr(line))
                if line and line != "data: " and line != "data:" and not line.startswith(": "):
                    result = line.lstrip("data: ")
                    if result.strip() == "[DONE]":
                        yield "data: [DONE]" + end_of_line
                        return
                    line = json.loads(result)
                    no_stream_content = safe_get(line, "choices", 0, "message", "content", default="")
                    content = safe_get(line, "choices", 0, "delta", "content", default="")

                    # Handle <think> tags
                    if "<think>" in content:
                        is_thinking = True
                        ark_tag = True
                        content = content.replace("<think>", "")
                    if "</think>" in content:
                        is_thinking = False
                        content = content.replace("</think>", "")
                        if not content:
                            continue
                    if is_thinking and ark_tag:
                        if not has_send_thinking:
                            content = content.replace("\n\n", "")
                        if content:
                            sse_string = await generate_sse_response(timestamp, payload["model"], reasoning_content=content)
                            yield sse_string
                            has_send_thinking = True
                        continue

                    if no_stream_content or content or sse_string:
                        sse_string = await generate_sse_response(timestamp, safe_get(line, "model", default=None), content=no_stream_content or content)
                        yield sse_string
                    else:
                        if no_stream_content:
                            del line["choices"][0]["message"]
                        yield "data: " + json.dumps(line).strip() + end_of_line
        yield "data: [DONE]" + end_of_line

async def fetch_cloudflare_response_stream(client, url, headers, payload, model):
    timestamp = int(datetime.timestamp(datetime.now()))
    async with client.stream('POST', url, headers=headers, json=payload) as response:
        error_message = await check_response(response, "fetch_cloudflare_response_stream")
        if error_message:
            yield error_message
            return

        buffer = ""
        async for chunk in response.aiter_text():
            buffer += chunk
            while "\n" in buffer:
                line, buffer = buffer.split("\n", 1)
                # logger.info("line: %s", repr(line))
                if line.startswith("data:"):
                    line = line.lstrip("data: ")
                    if line == "[DONE]":
                        yield "data: [DONE]" + end_of_line
                        return
                    resp: dict = json.loads(line)
                    message = resp.get("response")
                    if message:
                        sse_string = await generate_sse_response(timestamp, model, content=message)
                        yield sse_string

async def fetch_cohere_response_stream(client, url, headers, payload, model):
    timestamp = int(datetime.timestamp(datetime.now()))
    async with client.stream('POST', url, headers=headers, json=payload) as response:
        error_message = await check_response(response, "fetch_cohere_response_stream")
        if error_message:
            yield error_message
            return

        buffer = ""
        async for chunk in response.aiter_text():
            buffer += chunk
            while "\n" in buffer:
                line, buffer = buffer.split("\n", 1)
                # logger.info("line: %s", repr(line))
                resp: dict = json.loads(line)
                if resp.get("is_finished") == True:
                    yield "data: [DONE]" + end_of_line
                    return
                if resp.get("event_type") == "text-generation":
                    message = resp.get("text")
                    sse_string = await generate_sse_response(timestamp, model, content=message)
                    yield sse_string

async def fetch_claude_response_stream(client, url, headers, payload, model):
    timestamp = int(datetime.timestamp(datetime.now()))
    async with client.stream('POST', url, headers=headers, json=payload) as response:
        error_message = await check_response(response, "fetch_claude_response_stream")
        if error_message:
            yield error_message
            return
        buffer = ""
        input_tokens = 0
        async for chunk in response.aiter_text():
            # logger.info(f"chunk: {repr(chunk)}")
            buffer += chunk
            while "\n" in buffer:
                line, buffer = buffer.split("\n", 1)
                # logger.info(line)

                if line.startswith("data:"):
                    line = line.lstrip("data: ")
                    resp: dict = json.loads(line)
                    message = resp.get("message")
                    if message:
                        role = message.get("role")
                        if role:
                            sse_string = await generate_sse_response(timestamp, model, None, None, None, None, role)
                            yield sse_string
                        tokens_use = message.get("usage")
                        if tokens_use:
                            input_tokens = tokens_use.get("input_tokens", 0)
                    usage = resp.get("usage")
                    if usage:
                        output_tokens = usage.get("output_tokens", 0)
                        total_tokens = input_tokens + output_tokens
                        sse_string = await generate_sse_response(timestamp, model, None, None, None, None, None, total_tokens, input_tokens, output_tokens)
                        yield sse_string
                        # print("\n\rtotal_tokens", total_tokens)

                    tool_use = resp.get("content_block")
                    tools_id = None
                    function_call_name = None
                    if tool_use and "tool_use" == tool_use['type']:
                        # print("tool_use", tool_use)
                        tools_id = tool_use["id"]
                        if "name" in tool_use:
                            function_call_name = tool_use["name"]
                            sse_string = await generate_sse_response(timestamp, model, None, tools_id, function_call_name, None)
                            yield sse_string
                    delta = resp.get("delta")
                    # print("delta", delta)
                    if not delta:
                        continue
                    if "text" in delta:
                        content = delta["text"]
                        sse_string = await generate_sse_response(timestamp, model, content, None, None)
                        yield sse_string
                    if "thinking" in delta and delta["thinking"]:
                        content = delta["thinking"]
                        sse_string = await generate_sse_response(timestamp, model, reasoning_content=content)
                        yield sse_string
                    if "partial_json" in delta:
                        # {"type":"input_json_delta","partial_json":""}
                        function_call_content = delta["partial_json"]
                        sse_string = await generate_sse_response(timestamp, model, None, None, None, function_call_content)
                        yield sse_string
        yield "data: [DONE]" + end_of_line

async def fetch_response(client, url, headers, payload, engine, model):
    response = None
    if payload.get("file"):
        file = payload.pop("file")
        response = await client.post(url, headers=headers, data=payload, files={"file": file})
    else:
        response = await client.post(url, headers=headers, json=payload)
    error_message = await check_response(response, "fetch_response")
    if error_message:
        yield error_message
        return

    if engine == "tts":
        yield response.read()

    elif engine == "gemini" or engine == "vertex-gemini":
        response_json = response.json()

        if isinstance(response_json, str):
            import ast
            parsed_data = ast.literal_eval(str(response_json))
        elif isinstance(response_json, list):
            parsed_data = response_json
        else:
            logger.error(f"error fetch_response: Unknown response_json type: {type(response_json)}")
            parsed_data = response_json
        # print("parsed_data", json.dumps(parsed_data, indent=4, ensure_ascii=False))
        content = ""
        for item in parsed_data:
            chunk = safe_get(item, "candidates", 0, "content", "parts", 0, "text")
            # logger.info(f"chunk: {repr(chunk)}")
            if chunk:
                content += chunk

        usage_metadata = safe_get(parsed_data, -1, "usageMetadata")
        prompt_tokens = usage_metadata.get("promptTokenCount", 0)
        candidates_tokens = usage_metadata.get("candidatesTokenCount", 0)
        total_tokens = usage_metadata.get("totalTokenCount", 0)

        role = safe_get(parsed_data, -1, "candidates", 0, "content", "role")
        if role == "model":
            role = "assistant"
        else:
            logger.error(f"Unknown role: {role}")
            role = "assistant"

        function_call_name = safe_get(parsed_data, -1, "candidates", 0, "content", "parts", 0, "functionCall", "name", default=None)
        function_call_content = safe_get(parsed_data, -1, "candidates", 0, "content", "parts", 0, "functionCall", "args", default=None)

        timestamp = int(datetime.timestamp(datetime.now()))
        yield await generate_no_stream_response(timestamp, model, content=content, tools_id=None, function_call_name=function_call_name, function_call_content=function_call_content, role=role, total_tokens=total_tokens, prompt_tokens=prompt_tokens, completion_tokens=candidates_tokens)

    elif engine == "claude":
        response_json = response.json()
        # print("response_json", json.dumps(response_json, indent=4, ensure_ascii=False))

        content = safe_get(response_json, "content", 0, "text")

        prompt_tokens = safe_get(response_json, "usage", "input_tokens")
        output_tokens = safe_get(response_json, "usage", "output_tokens")
        total_tokens = prompt_tokens + output_tokens

        role = safe_get(response_json, "role")

        function_call_name = safe_get(response_json, "content", 1, "name", default=None)
        function_call_content = safe_get(response_json, "content", 1, "input", default=None)
        tools_id = safe_get(response_json, "content", 1, "id", default=None)

        timestamp = int(datetime.timestamp(datetime.now()))
        yield await generate_no_stream_response(timestamp, model, content=content, tools_id=tools_id, function_call_name=function_call_name, function_call_content=function_call_content, role=role, total_tokens=total_tokens, prompt_tokens=prompt_tokens, completion_tokens=output_tokens)

    elif engine == "azure":
        response_json = response.json()
        # Remove content_filter_results
        if "choices" in response_json:
            for choice in response_json["choices"]:
                if "content_filter_results" in choice:
                    del choice["content_filter_results"]

        # Remove prompt_filter_results
        if "prompt_filter_results" in response_json:
            del response_json["prompt_filter_results"]

        yield response_json

    else:
        response_json = response.json()
        yield response_json

async def fetch_response_stream(client, url, headers, payload, engine, model):
    # try:
    if engine == "gemini" or engine == "vertex-gemini":
        async for chunk in fetch_gemini_response_stream(client, url, headers, payload, model):
            yield chunk
    elif engine == "claude" or engine == "vertex-claude":
        async for chunk in fetch_claude_response_stream(client, url, headers, payload, model):
            yield chunk
    elif engine == "gpt":
        async for chunk in fetch_gpt_response_stream(client, url, headers, payload):
            yield chunk
    elif engine == "azure":
        async for chunk in fetch_azure_response_stream(client, url, headers, payload):
            yield chunk
    elif engine == "openrouter":
        async for chunk in fetch_gpt_response_stream(client, url, headers, payload):
            yield chunk
    elif engine == "cloudflare":
        async for chunk in fetch_cloudflare_response_stream(client, url, headers, payload, model):
            yield chunk
    elif engine == "cohere":
        async for chunk in fetch_cohere_response_stream(client, url, headers, payload, model):
            yield chunk
    else:
        raise ValueError("Unknown response")
    # except httpx.ConnectError as e:
    #     yield {"error": f"500", "details": "fetch_response_stream Connect Error"}
    # except httpx.ReadTimeout as e:
    #     yield {"error": f"500", "details": "fetch_response_stream Read Response Timeout"}
aient/core/test/test_base_api.py
ADDED
@@ -0,0 +1,17 @@
from ..utils import BaseAPI

import os

# GOOGLE_AI_API_KEY = os.environ.get('GOOGLE_AI_API_KEY', None)

# base_api = BaseAPI("https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:streamGenerateContent?key=" + GOOGLE_AI_API_KEY)

# print(base_api.chat_url)

base_api = BaseAPI("http://127.0.0.1:8000/v1")

print(base_api.chat_url)

"""
python -m core.test.test_base_api
"""
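BaseAPI itself lives in aient/core/utils.py and is not shown in this diff. As a rough, assumption-laden sketch of the behavior this test exercises (deriving an OpenAI-style chat endpoint from a base URL), something like the following hypothetical stand-in would satisfy the test; the real class may differ:

# Hypothetical stand-in for aient.core.utils.BaseAPI, for illustration only.
class BaseAPI:
    def __init__(self, api_url: str):
        self.base_url = api_url.rstrip("/")
        # Assumption: chat_url points at the OpenAI-compatible chat completions endpoint.
        self.chat_url = f"{self.base_url}/chat/completions"

print(BaseAPI("http://127.0.0.1:8000/v1").chat_url)
# -> http://127.0.0.1:8000/v1/chat/completions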
aient/core/test/test_image.py
ADDED
@@ -0,0 +1,15 @@
from ..request import get_image_message
import os
import asyncio
IMAGE_URL = os.getenv("IMAGE_URL")

async def test_image():
    image_message = await get_image_message(IMAGE_URL, engine="gemini")
    print(image_message)

if __name__ == "__main__":
    asyncio.run(test_image())

'''
python -m core.test.test_image
'''
aient/core/test/test_payload.py
ADDED
@@ -0,0 +1,92 @@
import json
import asyncio

from ..request import prepare_request_payload

"""
Test script: exercises the get_payload function in core/request.py.

The script builds a RequestModel object from core/models.py, then calls get_payload
in core/request.py, which returns url, headers, and payload. This makes it possible
to test the get_payload module in isolation.

Test case: a model call that includes tool functions.

The API configuration comes from the 'new-i1-pe' entry in api.yaml.

python -m core.test.test_payload
"""

async def test_payload():
    print("===== Testing the get_payload function =====")

    # Step 1: configure the provider
    provider = {
        "provider": "new",
        "base_url": "/v1/chat/completions",
        "api": "",
        "model": [
            "gpt-4"  # a model name that supports tool calls
        ],
        "tools": True  # enable tool support
    }

    request_data = {
        "model": "gpt-4",
        "messages": [
            {
                "role": "system",
                "content": "You are a helpful AI assistant."
            },
            {
                "role": "user",
                "content": "Hello, please introduce yourself."
            }
        ],
        "stream": True,
        "temperature": 0.7,
        "max_tokens": 1000,
        "tools": [
            {
                "type": "function",
                "function": {
                    "name": "get_current_weather",
                    "description": "Get the current weather",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "location": {
                                "type": "string",
                                "description": "City name, e.g. Beijing"
                            },
                            "unit": {
                                "type": "string",
                                "enum": ["celsius", "fahrenheit"],
                                "description": "Temperature unit"
                            }
                        },
                        "required": ["location"]
                    }
                }
            }
        ],
        "tool_choice": "auto"  # tool-choice parameter
    }

    # Build the request and collect the result
    url, headers, payload, engine = await prepare_request_payload(provider, request_data)

    # Print the results
    print("\nURL:")
    print(url)

    print("\nHeaders:")
    print(json.dumps(headers, indent=4, ensure_ascii=False))

    print("\nPayload:")
    print(json.dumps(payload, indent=4, ensure_ascii=False))

    print("\n===== Test complete =====")

if __name__ == "__main__":
    # Run the async test function
    asyncio.run(test_payload())