beswarm 0.1.12-py3-none-any.whl → 0.1.13-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- beswarm/aient/main.py +50 -0
- beswarm/aient/setup.py +15 -0
- beswarm/aient/src/aient/__init__.py +1 -0
- beswarm/aient/src/aient/core/__init__.py +1 -0
- beswarm/aient/src/aient/core/log_config.py +6 -0
- beswarm/aient/src/aient/core/models.py +232 -0
- beswarm/aient/src/aient/core/request.py +1665 -0
- beswarm/aient/src/aient/core/response.py +617 -0
- beswarm/aient/src/aient/core/test/test_base_api.py +18 -0
- beswarm/aient/src/aient/core/test/test_image.py +15 -0
- beswarm/aient/src/aient/core/test/test_payload.py +92 -0
- beswarm/aient/src/aient/core/utils.py +715 -0
- beswarm/aient/src/aient/models/__init__.py +9 -0
- beswarm/aient/src/aient/models/audio.py +63 -0
- beswarm/aient/src/aient/models/base.py +251 -0
- beswarm/aient/src/aient/models/chatgpt.py +938 -0
- beswarm/aient/src/aient/models/claude.py +640 -0
- beswarm/aient/src/aient/models/duckduckgo.py +241 -0
- beswarm/aient/src/aient/models/gemini.py +357 -0
- beswarm/aient/src/aient/models/groq.py +268 -0
- beswarm/aient/src/aient/models/vertex.py +420 -0
- beswarm/aient/src/aient/plugins/__init__.py +33 -0
- beswarm/aient/src/aient/plugins/arXiv.py +48 -0
- beswarm/aient/src/aient/plugins/config.py +172 -0
- beswarm/aient/src/aient/plugins/excute_command.py +35 -0
- beswarm/aient/src/aient/plugins/get_time.py +19 -0
- beswarm/aient/src/aient/plugins/image.py +72 -0
- beswarm/aient/src/aient/plugins/list_directory.py +50 -0
- beswarm/aient/src/aient/plugins/read_file.py +79 -0
- beswarm/aient/src/aient/plugins/registry.py +116 -0
- beswarm/aient/src/aient/plugins/run_python.py +156 -0
- beswarm/aient/src/aient/plugins/websearch.py +394 -0
- beswarm/aient/src/aient/plugins/write_file.py +51 -0
- beswarm/aient/src/aient/prompt/__init__.py +1 -0
- beswarm/aient/src/aient/prompt/agent.py +280 -0
- beswarm/aient/src/aient/utils/__init__.py +0 -0
- beswarm/aient/src/aient/utils/prompt.py +143 -0
- beswarm/aient/src/aient/utils/scripts.py +721 -0
- beswarm/aient/test/chatgpt.py +161 -0
- beswarm/aient/test/claude.py +32 -0
- beswarm/aient/test/test.py +2 -0
- beswarm/aient/test/test_API.py +6 -0
- beswarm/aient/test/test_Deepbricks.py +20 -0
- beswarm/aient/test/test_Web_crawler.py +262 -0
- beswarm/aient/test/test_aiwaves.py +25 -0
- beswarm/aient/test/test_aiwaves_arxiv.py +19 -0
- beswarm/aient/test/test_ask_gemini.py +8 -0
- beswarm/aient/test/test_class.py +17 -0
- beswarm/aient/test/test_claude.py +23 -0
- beswarm/aient/test/test_claude_zh_char.py +26 -0
- beswarm/aient/test/test_ddg_search.py +50 -0
- beswarm/aient/test/test_download_pdf.py +56 -0
- beswarm/aient/test/test_gemini.py +97 -0
- beswarm/aient/test/test_get_token_dict.py +21 -0
- beswarm/aient/test/test_google_search.py +35 -0
- beswarm/aient/test/test_jieba.py +32 -0
- beswarm/aient/test/test_json.py +65 -0
- beswarm/aient/test/test_langchain_search_old.py +235 -0
- beswarm/aient/test/test_logging.py +32 -0
- beswarm/aient/test/test_ollama.py +55 -0
- beswarm/aient/test/test_plugin.py +16 -0
- beswarm/aient/test/test_py_run.py +26 -0
- beswarm/aient/test/test_requests.py +162 -0
- beswarm/aient/test/test_search.py +18 -0
- beswarm/aient/test/test_tikitoken.py +19 -0
- beswarm/aient/test/test_token.py +94 -0
- beswarm/aient/test/test_url.py +33 -0
- beswarm/aient/test/test_whisper.py +14 -0
- beswarm/aient/test/test_wildcard.py +20 -0
- beswarm/aient/test/test_yjh.py +21 -0
- {beswarm-0.1.12.dist-info → beswarm-0.1.13.dist-info}/METADATA +1 -1
- beswarm-0.1.13.dist-info/RECORD +131 -0
- beswarm-0.1.12.dist-info/RECORD +0 -61
- {beswarm-0.1.12.dist-info → beswarm-0.1.13.dist-info}/WHEEL +0 -0
- {beswarm-0.1.12.dist-info → beswarm-0.1.13.dist-info}/top_level.txt +0 -0
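The hunks below add the vendored `aient` tree to the 0.1.13 wheel. For orientation, a file-level version of the summary above can be reproduced locally; this is a minimal sketch, assuming both wheel files have already been downloaded into the current directory (for example with `pip download beswarm==0.1.12 --no-deps` and `pip download beswarm==0.1.13 --no-deps`):

```python
# Minimal sketch: list files added or removed between the two wheels.
# Wheels are ordinary zip archives, so zipfile can read them directly.
import zipfile

old = set(zipfile.ZipFile("beswarm-0.1.12-py3-none-any.whl").namelist())
new = set(zipfile.ZipFile("beswarm-0.1.13-py3-none-any.whl").namelist())

for path in sorted(new - old):
    print("+", path)  # new in 0.1.13 (the beswarm/aient/... tree above)
for path in sorted(old - new):
    print("-", path)  # removed since 0.1.12 (e.g. the old RECORD)
```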
```diff
--- /dev/null
+++ b/beswarm/aient/src/aient/core/response.py
@@ -0,0 +1,617 @@
+import re
+import json
+import random
+import string
+import base64
+from datetime import datetime
+
+from .log_config import logger
+
+from .utils import safe_get, generate_sse_response, generate_no_stream_response, end_of_line, parse_json_safely
+
+async def check_response(response, error_log):
+    if response and not (200 <= response.status_code < 300):
+        error_message = await response.aread()
+        error_str = error_message.decode('utf-8', errors='replace')
+        try:
+            error_json = json.loads(error_str)
+        except json.JSONDecodeError:
+            error_json = error_str
+        return {"error": f"{error_log} HTTP Error", "status_code": response.status_code, "details": error_json}
+    return None
+
+async def fetch_gemini_response_stream(client, url, headers, payload, model):
+    timestamp = int(datetime.timestamp(datetime.now()))
+    async with client.stream('POST', url, headers=headers, json=payload) as response:
+        error_message = await check_response(response, "fetch_gemini_response_stream")
+        if error_message:
+            yield error_message
+            return
+        buffer = ""
+        revicing_function_call = False
+        function_full_response = "{"
+        need_function_call = False
+        is_finish = False
+        promptTokenCount = 0
+        candidatesTokenCount = 0
+        totalTokenCount = 0
+        # line_index = 0
+        # last_text_line = 0
+        # if "thinking" in model:
+        #     is_thinking = True
+        # else:
+        #     is_thinking = False
+        async for chunk in response.aiter_text():
+            buffer += chunk
+
+            while "\n" in buffer:
+                line, buffer = buffer.split("\n", 1)
+                # line_index += 1
+
+                # https://ai.google.dev/api/generate-content?hl=zh-cn#FinishReason
+                if line and '\"finishReason\": \"' in line:
+                    if "stop" not in line.lower():
+                        logger.error(f"finishReason: {line}")
+                    is_finish = True
+                if is_finish and '\"promptTokenCount\": ' in line:
+                    json_data = parse_json_safely("{" + line + "}")
+                    promptTokenCount = json_data.get('promptTokenCount', 0)
+                if is_finish and '\"candidatesTokenCount\": ' in line:
+                    json_data = parse_json_safely("{" + line + "}")
+                    candidatesTokenCount = json_data.get('candidatesTokenCount', 0)
+                if is_finish and '\"totalTokenCount\": ' in line:
+                    json_data = parse_json_safely("{" + line + "}")
+                    totalTokenCount = json_data.get('totalTokenCount', 0)
+
+                # print(line)
+                if line and '\"text\": \"' in line and is_finish == False:
+                    try:
+                        json_data = json.loads("{" + line + "}")
+                        content = json_data.get('text', '')
+                        # content = content.replace("\n", "\n\n")
+                        # if last_text_line == 0 and is_thinking:
+                        #     content = "> " + content.lstrip()
+                        # if is_thinking:
+                        #     content = content.replace("\n", "\n> ")
+                        # if last_text_line == line_index - 3:
+                        #     is_thinking = False
+                        #     content = "\n\n\n" + content.lstrip()
+                        sse_string = await generate_sse_response(timestamp, model, content=content)
+                        yield sse_string
+                    except json.JSONDecodeError:
+                        logger.error(f"Failed to parse JSON: {line}")
+                    # last_text_line = line_index
+
+                if line and ('\"functionCall\": {' in line or revicing_function_call):
+                    revicing_function_call = True
+                    need_function_call = True
+                    if ']' in line:
+                        revicing_function_call = False
+                        continue
+
+                    function_full_response += line
+
+        if need_function_call:
+            function_call = json.loads(function_full_response)
+            function_call_name = function_call["functionCall"]["name"]
+            sse_string = await generate_sse_response(timestamp, model, content=None, tools_id="chatcmpl-9inWv0yEtgn873CxMBzHeCeiHctTV", function_call_name=function_call_name)
+            yield sse_string
+            function_full_response = json.dumps(function_call["functionCall"]["args"])
+            sse_string = await generate_sse_response(timestamp, model, content=None, tools_id="chatcmpl-9inWv0yEtgn873CxMBzHeCeiHctTV", function_call_name=None, function_call_content=function_full_response)
+            yield sse_string
+
+        sse_string = await generate_sse_response(timestamp, model, None, None, None, None, None, totalTokenCount, promptTokenCount, candidatesTokenCount)
+        yield sse_string
+
+    yield "data: [DONE]" + end_of_line
+
+async def fetch_vertex_claude_response_stream(client, url, headers, payload, model):
+    timestamp = int(datetime.timestamp(datetime.now()))
+    async with client.stream('POST', url, headers=headers, json=payload) as response:
+        error_message = await check_response(response, "fetch_vertex_claude_response_stream")
+        if error_message:
+            yield error_message
+            return
+
+        buffer = ""
+        revicing_function_call = False
+        function_full_response = "{"
+        need_function_call = False
+        is_finish = False
+        promptTokenCount = 0
+        candidatesTokenCount = 0
+        totalTokenCount = 0
+
+        async for chunk in response.aiter_text():
+            buffer += chunk
+            while "\n" in buffer:
+                line, buffer = buffer.split("\n", 1)
+                # logger.info(f"{line}")
+
+                if line and '\"finishReason\": \"' in line:
+                    is_finish = True
+                if is_finish and '\"promptTokenCount\": ' in line:
+                    json_data = parse_json_safely("{" + line + "}")
+                    promptTokenCount = json_data.get('promptTokenCount', 0)
+                if is_finish and '\"candidatesTokenCount\": ' in line:
+                    json_data = parse_json_safely("{" + line + "}")
+                    candidatesTokenCount = json_data.get('candidatesTokenCount', 0)
+                if is_finish and '\"totalTokenCount\": ' in line:
+                    json_data = parse_json_safely("{" + line + "}")
+                    totalTokenCount = json_data.get('totalTokenCount', 0)
+
+                if line and '\"text\": \"' in line and is_finish == False:
+                    try:
+                        json_data = json.loads("{" + line + "}")
+                        content = json_data.get('text', '')
+                        sse_string = await generate_sse_response(timestamp, model, content=content)
+                        yield sse_string
+                    except json.JSONDecodeError:
+                        logger.error(f"Failed to parse JSON: {line}")
+
+                if line and ('\"type\": \"tool_use\"' in line or revicing_function_call):
+                    revicing_function_call = True
+                    need_function_call = True
+                    if ']' in line:
+                        revicing_function_call = False
+                        continue
+
+                    function_full_response += line
+
+        if need_function_call:
+            function_call = json.loads(function_full_response)
+            function_call_name = function_call["name"]
+            function_call_id = function_call["id"]
+            sse_string = await generate_sse_response(timestamp, model, content=None, tools_id=function_call_id, function_call_name=function_call_name)
+            yield sse_string
+            function_full_response = json.dumps(function_call["input"])
+            sse_string = await generate_sse_response(timestamp, model, content=None, tools_id=function_call_id, function_call_name=None, function_call_content=function_full_response)
+            yield sse_string
+
+        sse_string = await generate_sse_response(timestamp, model, None, None, None, None, None, totalTokenCount, promptTokenCount, candidatesTokenCount)
+        yield sse_string
+
+    yield "data: [DONE]" + end_of_line
+
+async def fetch_gpt_response_stream(client, url, headers, payload):
+    timestamp = int(datetime.timestamp(datetime.now()))
+    random.seed(timestamp)
+    random_str = ''.join(random.choices(string.ascii_letters + string.digits, k=29))
+    is_thinking = False
+    has_send_thinking = False
+    ark_tag = False
+    async with client.stream('POST', url, headers=headers, json=payload) as response:
+        error_message = await check_response(response, "fetch_gpt_response_stream")
+        if error_message:
+            yield error_message
+            return
+
+        buffer = ""
+        enter_buffer = ""
+        async for chunk in response.aiter_text():
+            buffer += chunk
+            while "\n" in buffer:
+                line, buffer = buffer.split("\n", 1)
+                # logger.info("line: %s", repr(line))
+                if line and not line.startswith(": ") and (result:=line.lstrip("data: ").strip()):
+                    if result.strip() == "[DONE]":
+                        break
+                    line = json.loads(result)
+                    line['id'] = f"chatcmpl-{random_str}"
+
+                    # Handle the <think> tag
+                    content = safe_get(line, "choices", 0, "delta", "content", default="")
+                    if "<think>" in content:
+                        is_thinking = True
+                        ark_tag = True
+                        content = content.replace("<think>", "")
+                    if "</think>" in content:
+                        end_think_reasoning_content = ""
+                        end_think_content = ""
+                        is_thinking = False
+
+                        if content.rstrip('\n').endswith("</think>"):
+                            end_think_reasoning_content = content.replace("</think>", "").rstrip('\n')
+                        elif content.lstrip('\n').startswith("</think>"):
+                            end_think_content = content.replace("</think>", "").lstrip('\n')
+                        else:
+                            end_think_reasoning_content = content.split("</think>")[0]
+                            end_think_content = content.split("</think>")[1]
+
+                        if end_think_reasoning_content:
+                            sse_string = await generate_sse_response(timestamp, payload["model"], reasoning_content=end_think_reasoning_content)
+                            yield sse_string
+                        if end_think_content:
+                            sse_string = await generate_sse_response(timestamp, payload["model"], content=end_think_content)
+                            yield sse_string
+                        continue
+                    if is_thinking and ark_tag:
+                        if not has_send_thinking:
+                            content = content.replace("\n\n", "")
+                        if content:
+                            sse_string = await generate_sse_response(timestamp, payload["model"], reasoning_content=content)
+                            yield sse_string
+                            has_send_thinking = True
+                        continue
+
+                    # Handle the poe thinking tag
+                    if "Thinking..." in content and "\n> " in content:
+                        is_thinking = True
+                        content = content.replace("Thinking...", "").replace("\n> ", "")
+                    if is_thinking and "\n\n" in content and not ark_tag:
+                        is_thinking = False
+                    if is_thinking and not ark_tag:
+                        content = content.replace("\n> ", "")
+                        if not has_send_thinking:
+                            content = content.replace("\n", "")
+                        if content:
+                            sse_string = await generate_sse_response(timestamp, payload["model"], reasoning_content=content)
+                            yield sse_string
+                            has_send_thinking = True
+                        continue
+
+                    no_stream_content = safe_get(line, "choices", 0, "message", "content", default=None)
+                    openrouter_reasoning = safe_get(line, "choices", 0, "delta", "reasoning", default="")
+                    # print("openrouter_reasoning", repr(openrouter_reasoning), openrouter_reasoning.endswith("\\\\"), openrouter_reasoning.endswith("\\"))
+                    if openrouter_reasoning:
+                        if openrouter_reasoning.endswith("\\"):
+                            enter_buffer += openrouter_reasoning
+                            continue
+                        elif enter_buffer.endswith("\\") and openrouter_reasoning == 'n':
+                            enter_buffer += "n"
+                            continue
+                        elif enter_buffer.endswith("\\n") and openrouter_reasoning == '\\n':
+                            enter_buffer += "\\n"
+                            continue
+                        elif enter_buffer.endswith("\\n\\n"):
+                            openrouter_reasoning = '\n\n' + openrouter_reasoning
+                            enter_buffer = ""
+                        elif enter_buffer:
+                            openrouter_reasoning = enter_buffer + openrouter_reasoning
+                            enter_buffer = ''
+                        openrouter_reasoning = openrouter_reasoning.replace("\\n", "\n")
+
+                        sse_string = await generate_sse_response(timestamp, payload["model"], reasoning_content=openrouter_reasoning)
+                        yield sse_string
+                    elif no_stream_content and has_send_thinking == False:
+                        sse_string = await generate_sse_response(safe_get(line, "created", default=None), safe_get(line, "model", default=None), content=no_stream_content)
+                        yield sse_string
+                    else:
+                        if no_stream_content:
+                            del line["choices"][0]["message"]
+                        yield "data: " + json.dumps(line).strip() + end_of_line
+    yield "data: [DONE]" + end_of_line
+
+async def fetch_azure_response_stream(client, url, headers, payload):
+    timestamp = int(datetime.timestamp(datetime.now()))
+    is_thinking = False
+    has_send_thinking = False
+    ark_tag = False
+    async with client.stream('POST', url, headers=headers, json=payload) as response:
+        error_message = await check_response(response, "fetch_azure_response_stream")
+        if error_message:
+            yield error_message
+            return
+
+        buffer = ""
+        sse_string = ""
+        async for chunk in response.aiter_text():
+            buffer += chunk
+            while "\n" in buffer:
+                line, buffer = buffer.split("\n", 1)
+                # logger.info("line: %s", repr(line))
+                if line and not line.startswith(": ") and (result:=line.lstrip("data: ").strip()):
+                    if result.strip() == "[DONE]":
+                        break
+                    line = json.loads(result)
+                    no_stream_content = safe_get(line, "choices", 0, "message", "content", default="")
+                    content = safe_get(line, "choices", 0, "delta", "content", default="")
+
+                    # Handle the <think> tag
+                    if "<think>" in content:
+                        is_thinking = True
+                        ark_tag = True
+                        content = content.replace("<think>", "")
+                    if "</think>" in content:
+                        is_thinking = False
+                        content = content.replace("</think>", "")
+                        if not content:
+                            continue
+                    if is_thinking and ark_tag:
+                        if not has_send_thinking:
+                            content = content.replace("\n\n", "")
+                        if content:
+                            sse_string = await generate_sse_response(timestamp, payload["model"], reasoning_content=content)
+                            yield sse_string
+                            has_send_thinking = True
+                        continue
+
+                    if no_stream_content or content or sse_string:
+                        sse_string = await generate_sse_response(timestamp, safe_get(line, "model", default=None), content=no_stream_content or content)
+                        yield sse_string
+                    else:
+                        if no_stream_content:
+                            del line["choices"][0]["message"]
+                        yield "data: " + json.dumps(line).strip() + end_of_line
+    yield "data: [DONE]" + end_of_line
+
+async def fetch_cloudflare_response_stream(client, url, headers, payload, model):
+    timestamp = int(datetime.timestamp(datetime.now()))
+    async with client.stream('POST', url, headers=headers, json=payload) as response:
+        error_message = await check_response(response, "fetch_cloudflare_response_stream")
+        if error_message:
+            yield error_message
+            return
+
+        buffer = ""
+        async for chunk in response.aiter_text():
+            buffer += chunk
+            while "\n" in buffer:
+                line, buffer = buffer.split("\n", 1)
+                # logger.info("line: %s", repr(line))
+                if line.startswith("data:"):
+                    line = line.lstrip("data: ")
+                    if line == "[DONE]":
+                        break
+                    resp: dict = json.loads(line)
+                    message = resp.get("response")
+                    if message:
+                        sse_string = await generate_sse_response(timestamp, model, content=message)
+                        yield sse_string
+    yield "data: [DONE]" + end_of_line
+
+async def fetch_cohere_response_stream(client, url, headers, payload, model):
+    timestamp = int(datetime.timestamp(datetime.now()))
+    async with client.stream('POST', url, headers=headers, json=payload) as response:
+        error_message = await check_response(response, "fetch_cohere_response_stream")
+        if error_message:
+            yield error_message
+            return
+
+        buffer = ""
+        async for chunk in response.aiter_text():
+            buffer += chunk
+            while "\n" in buffer:
+                line, buffer = buffer.split("\n", 1)
+                # logger.info("line: %s", repr(line))
+                resp: dict = json.loads(line)
+                if resp.get("is_finished") == True:
+                    break
+                if resp.get("event_type") == "text-generation":
+                    message = resp.get("text")
+                    sse_string = await generate_sse_response(timestamp, model, content=message)
+                    yield sse_string
+    yield "data: [DONE]" + end_of_line
+
+async def fetch_claude_response_stream(client, url, headers, payload, model):
+    timestamp = int(datetime.timestamp(datetime.now()))
+    async with client.stream('POST', url, headers=headers, json=payload) as response:
+        error_message = await check_response(response, "fetch_claude_response_stream")
+        if error_message:
+            yield error_message
+            return
+        buffer = ""
+        input_tokens = 0
+        async for chunk in response.aiter_text():
+            # logger.info(f"chunk: {repr(chunk)}")
+            buffer += chunk
+            while "\n" in buffer:
+                line, buffer = buffer.split("\n", 1)
+                # logger.info(line)
+
+                if line.startswith("data:"):
+                    line = line.lstrip("data: ")
+                    resp: dict = json.loads(line)
+                    message = resp.get("message")
+                    if message:
+                        role = message.get("role")
+                        if role:
+                            sse_string = await generate_sse_response(timestamp, model, None, None, None, None, role)
+                            yield sse_string
+                        tokens_use = message.get("usage")
+                        if tokens_use:
+                            input_tokens = tokens_use.get("input_tokens", 0)
+                    usage = resp.get("usage")
+                    if usage:
+                        output_tokens = usage.get("output_tokens", 0)
+                        total_tokens = input_tokens + output_tokens
+                        sse_string = await generate_sse_response(timestamp, model, None, None, None, None, None, total_tokens, input_tokens, output_tokens)
+                        yield sse_string
+                        # print("\n\rtotal_tokens", total_tokens)
+
+                    tool_use = resp.get("content_block")
+                    tools_id = None
+                    function_call_name = None
+                    if tool_use and "tool_use" == tool_use['type']:
+                        # print("tool_use", tool_use)
+                        tools_id = tool_use["id"]
+                        if "name" in tool_use:
+                            function_call_name = tool_use["name"]
+                            sse_string = await generate_sse_response(timestamp, model, None, tools_id, function_call_name, None)
+                            yield sse_string
+                    delta = resp.get("delta")
+                    # print("delta", delta)
+                    if not delta:
+                        continue
+                    if "text" in delta:
+                        content = delta["text"]
+                        sse_string = await generate_sse_response(timestamp, model, content, None, None)
+                        yield sse_string
+                    if "thinking" in delta and delta["thinking"]:
+                        content = delta["thinking"]
+                        sse_string = await generate_sse_response(timestamp, model, reasoning_content=content)
+                        yield sse_string
+                    if "partial_json" in delta:
+                        # {"type":"input_json_delta","partial_json":""}
+                        function_call_content = delta["partial_json"]
+                        sse_string = await generate_sse_response(timestamp, model, None, None, None, function_call_content)
+                        yield sse_string
+    yield "data: [DONE]" + end_of_line
+
+async def fetch_aws_response_stream(client, url, headers, payload, model):
+    timestamp = int(datetime.timestamp(datetime.now()))
+    async with client.stream('POST', url, headers=headers, json=payload) as response:
+        error_message = await check_response(response, "fetch_aws_response_stream")
+        if error_message:
+            yield error_message
+            return
+
+        buffer = ""
+        async for line in response.aiter_text():
+            buffer += line
+            while "\r" in buffer:
+                line, buffer = buffer.split("\r", 1)
+                if not line or \
+                line.strip() == "" or \
+                line.strip().startswith(':content-type') or \
+                line.strip().startswith(':event-type'):  # Filter out empty or whitespace-only lines
+                    continue
+
+                json_match = re.search(r'event{.*?}', line)
+                if not json_match:
+                    continue
+                try:
+                    chunk_data = json.loads(json_match.group(0).lstrip('event'))
+                except json.JSONDecodeError:
+                    logger.error(f"DEBUG json.JSONDecodeError: {json_match.group(0).lstrip('event')!r}")
+                    continue
+
+                # --- subsequent processing unchanged ---
+                if "bytes" in chunk_data:
+                    # Decode the Base64-encoded bytes
+                    decoded_bytes = base64.b64decode(chunk_data["bytes"])
+                    # Parse the decoded bytes as JSON again
+                    payload_chunk = json.loads(decoded_bytes.decode('utf-8'))
+                    # print(f"DEBUG payload_chunk: {payload_chunk!r}")
+
+                    text = safe_get(payload_chunk, "delta", "text", default="")
+                    if text:
+                        sse_string = await generate_sse_response(timestamp, model, text, None, None)
+                        yield sse_string
+
+                    usage = safe_get(payload_chunk, "amazon-bedrock-invocationMetrics", default="")
+                    if usage:
+                        input_tokens = usage.get("inputTokenCount", 0)
+                        output_tokens = usage.get("outputTokenCount", 0)
+                        total_tokens = input_tokens + output_tokens
+                        sse_string = await generate_sse_response(timestamp, model, None, None, None, None, None, total_tokens, input_tokens, output_tokens)
+                        yield sse_string
+
+    yield "data: [DONE]" + end_of_line
+
+async def fetch_response(client, url, headers, payload, engine, model):
+    response = None
+    if payload.get("file"):
+        file = payload.pop("file")
+        response = await client.post(url, headers=headers, data=payload, files={"file": file})
+    else:
+        response = await client.post(url, headers=headers, json=payload)
+    error_message = await check_response(response, "fetch_response")
+    if error_message:
+        yield error_message
+        return
+
+    if engine == "tts":
+        yield response.read()
+
+    elif engine == "gemini" or engine == "vertex-gemini" or engine == "aws":
+        response_json = response.json()
+
+        if isinstance(response_json, str):
+            import ast
+            parsed_data = ast.literal_eval(str(response_json))
+        elif isinstance(response_json, list):
+            parsed_data = response_json
+        else:
+            logger.error(f"error fetch_response: Unknown response_json type: {type(response_json)}")
+            parsed_data = response_json
+        # print("parsed_data", json.dumps(parsed_data, indent=4, ensure_ascii=False))
+        content = ""
+        for item in parsed_data:
+            chunk = safe_get(item, "candidates", 0, "content", "parts", 0, "text")
+            # logger.info(f"chunk: {repr(chunk)}")
+            if chunk:
+                content += chunk
+
+        usage_metadata = safe_get(parsed_data, -1, "usageMetadata")
+        prompt_tokens = usage_metadata.get("promptTokenCount", 0)
+        candidates_tokens = usage_metadata.get("candidatesTokenCount", 0)
+        total_tokens = usage_metadata.get("totalTokenCount", 0)
+
+        role = safe_get(parsed_data, -1, "candidates", 0, "content", "role")
+        if role == "model":
+            role = "assistant"
+        else:
+            logger.error(f"Unknown role: {role}, parsed_data: {parsed_data}")
+            role = "assistant"
+
+        function_call_name = safe_get(parsed_data, -1, "candidates", 0, "content", "parts", 0, "functionCall", "name", default=None)
+        function_call_content = safe_get(parsed_data, -1, "candidates", 0, "content", "parts", 0, "functionCall", "args", default=None)
+
+        timestamp = int(datetime.timestamp(datetime.now()))
+        yield await generate_no_stream_response(timestamp, model, content=content, tools_id=None, function_call_name=function_call_name, function_call_content=function_call_content, role=role, total_tokens=total_tokens, prompt_tokens=prompt_tokens, completion_tokens=candidates_tokens)
+
+    elif engine == "claude":
+        response_json = response.json()
+        # print("response_json", json.dumps(response_json, indent=4, ensure_ascii=False))
+
+        content = safe_get(response_json, "content", 0, "text")
+
+        prompt_tokens = safe_get(response_json, "usage", "input_tokens")
+        output_tokens = safe_get(response_json, "usage", "output_tokens")
+        total_tokens = prompt_tokens + output_tokens
+
+        role = safe_get(response_json, "role")
+
+        function_call_name = safe_get(response_json, "content", 1, "name", default=None)
+        function_call_content = safe_get(response_json, "content", 1, "input", default=None)
+        tools_id = safe_get(response_json, "content", 1, "id", default=None)
+
+        timestamp = int(datetime.timestamp(datetime.now()))
+        yield await generate_no_stream_response(timestamp, model, content=content, tools_id=tools_id, function_call_name=function_call_name, function_call_content=function_call_content, role=role, total_tokens=total_tokens, prompt_tokens=prompt_tokens, completion_tokens=output_tokens)
+
+    elif engine == "azure":
+        response_json = response.json()
+        # Remove content_filter_results
+        if "choices" in response_json:
+            for choice in response_json["choices"]:
+                if "content_filter_results" in choice:
+                    del choice["content_filter_results"]
+
+        # Remove prompt_filter_results
+        if "prompt_filter_results" in response_json:
+            del response_json["prompt_filter_results"]
+
+        yield response_json
+
+    else:
+        response_json = response.json()
+        yield response_json
+
+async def fetch_response_stream(client, url, headers, payload, engine, model):
+    if engine == "gemini" or engine == "vertex-gemini":
+        async for chunk in fetch_gemini_response_stream(client, url, headers, payload, model):
+            yield chunk
+    elif engine == "claude" or engine == "vertex-claude":
+        async for chunk in fetch_claude_response_stream(client, url, headers, payload, model):
+            yield chunk
+    elif engine == "aws":
+        async for chunk in fetch_aws_response_stream(client, url, headers, payload, model):
+            yield chunk
+    elif engine == "gpt":
+        async for chunk in fetch_gpt_response_stream(client, url, headers, payload):
+            yield chunk
+    elif engine == "azure":
+        async for chunk in fetch_azure_response_stream(client, url, headers, payload):
+            yield chunk
+    elif engine == "openrouter":
+        async for chunk in fetch_gpt_response_stream(client, url, headers, payload):
+            yield chunk
+    elif engine == "cloudflare":
+        async for chunk in fetch_cloudflare_response_stream(client, url, headers, payload, model):
+            yield chunk
+    elif engine == "cohere":
+        async for chunk in fetch_cohere_response_stream(client, url, headers, payload, model):
+            yield chunk
+    else:
+        raise ValueError("Unknown response")
```
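The fetchers above are async generators that normalize each upstream's streaming format into OpenAI-style `data: ...` SSE strings, with `fetch_response_stream` dispatching on `engine`. A minimal consumption sketch follows; the `httpx` client matches the `client.stream(...)` calls the module makes, but the import path, endpoint URL, model name, and API key here are illustrative assumptions:

```python
# Hypothetical usage sketch, not part of the package.
import asyncio
import httpx

# Import path assumed from the wheel layout in the file list above.
from beswarm.aient.src.aient.core.response import fetch_response_stream

async def main():
    headers = {"Authorization": "Bearer sk-placeholder"}  # placeholder key
    payload = {
        "model": "gpt-4o-mini",  # illustrative model name
        "stream": True,
        "messages": [{"role": "user", "content": "hello"}],
    }
    async with httpx.AsyncClient(timeout=60) as client:
        async for chunk in fetch_response_stream(
            client,
            "https://api.openai.com/v1/chat/completions",
            headers,
            payload,
            engine="gpt",
            model="gpt-4o-mini",
        ):
            # Chunks are "data: ..." SSE strings, or an error dict from check_response.
            print(chunk if isinstance(chunk, str) else f"error: {chunk}")

asyncio.run(main())
```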
```diff
--- /dev/null
+++ b/beswarm/aient/src/aient/core/test/test_base_api.py
@@ -0,0 +1,18 @@
+from ..utils import BaseAPI
+
+# GOOGLE_AI_API_KEY = os.environ.get('GOOGLE_AI_API_KEY', None)
+
+# base_api = BaseAPI("https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:streamGenerateContent?key=" + GOOGLE_AI_API_KEY)
+
+# print(base_api.chat_url)
+
+base_api = BaseAPI("http://127.0.0.1:8000/v1")
+print(base_api.chat_url)
+
+base_api = BaseAPI("http://127.0.0.1:8000/v1/images/generations")
+print(base_api.image_url)
+
+"""
+python -m core.test.test_base_api
+python -m aient.src.aient.core.test.test_base_api
+"""
```
```diff
--- /dev/null
+++ b/beswarm/aient/src/aient/core/test/test_image.py
@@ -0,0 +1,15 @@
+from ..request import get_image_message
+import os
+import asyncio
+IMAGE_URL = os.getenv("IMAGE_URL")
+
+async def test_image():
+    image_message = await get_image_message(IMAGE_URL, engine="gemini")
+    print(image_message)
+
+if __name__ == "__main__":
+    asyncio.run(test_image())
+
+'''
+python -m core.test.test_image
+'''
```
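Both test modules are meant to be run with `python -m`, as their trailing docstrings note, and `test_image.py` additionally expects an `IMAGE_URL` environment variable. A hedged runner sketch (the image URL is a placeholder, and the module path assumes the working directory makes `core.test` importable):

```python
# Hypothetical runner for the test above: set IMAGE_URL, then execute the
# module the same way `IMAGE_URL=... python -m core.test.test_image` would.
import os
import runpy

os.environ["IMAGE_URL"] = "https://example.com/sample.png"  # placeholder URL
runpy.run_module("core.test.test_image", run_name="__main__")
```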