beswarm 0.1.12__py3-none-any.whl → 0.1.13__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
Files changed (75)
  1. beswarm/aient/main.py +50 -0
  2. beswarm/aient/setup.py +15 -0
  3. beswarm/aient/src/aient/__init__.py +1 -0
  4. beswarm/aient/src/aient/core/__init__.py +1 -0
  5. beswarm/aient/src/aient/core/log_config.py +6 -0
  6. beswarm/aient/src/aient/core/models.py +232 -0
  7. beswarm/aient/src/aient/core/request.py +1665 -0
  8. beswarm/aient/src/aient/core/response.py +617 -0
  9. beswarm/aient/src/aient/core/test/test_base_api.py +18 -0
  10. beswarm/aient/src/aient/core/test/test_image.py +15 -0
  11. beswarm/aient/src/aient/core/test/test_payload.py +92 -0
  12. beswarm/aient/src/aient/core/utils.py +715 -0
  13. beswarm/aient/src/aient/models/__init__.py +9 -0
  14. beswarm/aient/src/aient/models/audio.py +63 -0
  15. beswarm/aient/src/aient/models/base.py +251 -0
  16. beswarm/aient/src/aient/models/chatgpt.py +938 -0
  17. beswarm/aient/src/aient/models/claude.py +640 -0
  18. beswarm/aient/src/aient/models/duckduckgo.py +241 -0
  19. beswarm/aient/src/aient/models/gemini.py +357 -0
  20. beswarm/aient/src/aient/models/groq.py +268 -0
  21. beswarm/aient/src/aient/models/vertex.py +420 -0
  22. beswarm/aient/src/aient/plugins/__init__.py +33 -0
  23. beswarm/aient/src/aient/plugins/arXiv.py +48 -0
  24. beswarm/aient/src/aient/plugins/config.py +172 -0
  25. beswarm/aient/src/aient/plugins/excute_command.py +35 -0
  26. beswarm/aient/src/aient/plugins/get_time.py +19 -0
  27. beswarm/aient/src/aient/plugins/image.py +72 -0
  28. beswarm/aient/src/aient/plugins/list_directory.py +50 -0
  29. beswarm/aient/src/aient/plugins/read_file.py +79 -0
  30. beswarm/aient/src/aient/plugins/registry.py +116 -0
  31. beswarm/aient/src/aient/plugins/run_python.py +156 -0
  32. beswarm/aient/src/aient/plugins/websearch.py +394 -0
  33. beswarm/aient/src/aient/plugins/write_file.py +51 -0
  34. beswarm/aient/src/aient/prompt/__init__.py +1 -0
  35. beswarm/aient/src/aient/prompt/agent.py +280 -0
  36. beswarm/aient/src/aient/utils/__init__.py +0 -0
  37. beswarm/aient/src/aient/utils/prompt.py +143 -0
  38. beswarm/aient/src/aient/utils/scripts.py +721 -0
  39. beswarm/aient/test/chatgpt.py +161 -0
  40. beswarm/aient/test/claude.py +32 -0
  41. beswarm/aient/test/test.py +2 -0
  42. beswarm/aient/test/test_API.py +6 -0
  43. beswarm/aient/test/test_Deepbricks.py +20 -0
  44. beswarm/aient/test/test_Web_crawler.py +262 -0
  45. beswarm/aient/test/test_aiwaves.py +25 -0
  46. beswarm/aient/test/test_aiwaves_arxiv.py +19 -0
  47. beswarm/aient/test/test_ask_gemini.py +8 -0
  48. beswarm/aient/test/test_class.py +17 -0
  49. beswarm/aient/test/test_claude.py +23 -0
  50. beswarm/aient/test/test_claude_zh_char.py +26 -0
  51. beswarm/aient/test/test_ddg_search.py +50 -0
  52. beswarm/aient/test/test_download_pdf.py +56 -0
  53. beswarm/aient/test/test_gemini.py +97 -0
  54. beswarm/aient/test/test_get_token_dict.py +21 -0
  55. beswarm/aient/test/test_google_search.py +35 -0
  56. beswarm/aient/test/test_jieba.py +32 -0
  57. beswarm/aient/test/test_json.py +65 -0
  58. beswarm/aient/test/test_langchain_search_old.py +235 -0
  59. beswarm/aient/test/test_logging.py +32 -0
  60. beswarm/aient/test/test_ollama.py +55 -0
  61. beswarm/aient/test/test_plugin.py +16 -0
  62. beswarm/aient/test/test_py_run.py +26 -0
  63. beswarm/aient/test/test_requests.py +162 -0
  64. beswarm/aient/test/test_search.py +18 -0
  65. beswarm/aient/test/test_tikitoken.py +19 -0
  66. beswarm/aient/test/test_token.py +94 -0
  67. beswarm/aient/test/test_url.py +33 -0
  68. beswarm/aient/test/test_whisper.py +14 -0
  69. beswarm/aient/test/test_wildcard.py +20 -0
  70. beswarm/aient/test/test_yjh.py +21 -0
  71. {beswarm-0.1.12.dist-info → beswarm-0.1.13.dist-info}/METADATA +1 -1
  72. beswarm-0.1.13.dist-info/RECORD +131 -0
  73. beswarm-0.1.12.dist-info/RECORD +0 -61
  74. {beswarm-0.1.12.dist-info → beswarm-0.1.13.dist-info}/WHEEL +0 -0
  75. {beswarm-0.1.12.dist-info → beswarm-0.1.13.dist-info}/top_level.txt +0 -0
beswarm/aient/src/aient/core/response.py
@@ -0,0 +1,617 @@
+ import re
+ import json
+ import random
+ import string
+ import base64
+ from datetime import datetime
+
+ from .log_config import logger
+
+ from .utils import safe_get, generate_sse_response, generate_no_stream_response, end_of_line, parse_json_safely
+
+ async def check_response(response, error_log):
+     if response and not (200 <= response.status_code < 300):
+         error_message = await response.aread()
+         error_str = error_message.decode('utf-8', errors='replace')
+         try:
+             error_json = json.loads(error_str)
+         except json.JSONDecodeError:
+             error_json = error_str
+         return {"error": f"{error_log} HTTP Error", "status_code": response.status_code, "details": error_json}
+     return None
+
+ async def fetch_gemini_response_stream(client, url, headers, payload, model):
+     timestamp = int(datetime.timestamp(datetime.now()))
+     async with client.stream('POST', url, headers=headers, json=payload) as response:
+         error_message = await check_response(response, "fetch_gemini_response_stream")
+         if error_message:
+             yield error_message
+             return
+         buffer = ""
+         revicing_function_call = False
+         function_full_response = "{"
+         need_function_call = False
+         is_finish = False
+         promptTokenCount = 0
+         candidatesTokenCount = 0
+         totalTokenCount = 0
+         # line_index = 0
+         # last_text_line = 0
+         # if "thinking" in model:
+         #     is_thinking = True
+         # else:
+         #     is_thinking = False
+         async for chunk in response.aiter_text():
+             buffer += chunk
+
+             while "\n" in buffer:
+                 line, buffer = buffer.split("\n", 1)
+                 # line_index += 1
+
+                 # https://ai.google.dev/api/generate-content?hl=zh-cn#FinishReason
+                 if line and '\"finishReason\": \"' in line:
+                     if "stop" not in line.lower():
+                         logger.error(f"finishReason: {line}")
+                     is_finish = True
+                 if is_finish and '\"promptTokenCount\": ' in line:
+                     json_data = parse_json_safely( "{" + line + "}")
+                     promptTokenCount = json_data.get('promptTokenCount', 0)
+                 if is_finish and '\"candidatesTokenCount\": ' in line:
+                     json_data = parse_json_safely( "{" + line + "}")
+                     candidatesTokenCount = json_data.get('candidatesTokenCount', 0)
+                 if is_finish and '\"totalTokenCount\": ' in line:
+                     json_data = parse_json_safely( "{" + line + "}")
+                     totalTokenCount = json_data.get('totalTokenCount', 0)
+
+                 # print(line)
+                 if line and '\"text\": \"' in line and is_finish == False:
+                     try:
+                         json_data = json.loads( "{" + line + "}")
+                         content = json_data.get('text', '')
+                         # content = content.replace("\n", "\n\n")
+                         # if last_text_line == 0 and is_thinking:
+                         #     content = "> " + content.lstrip()
+                         # if is_thinking:
+                         #     content = content.replace("\n", "\n> ")
+                         # if last_text_line == line_index - 3:
+                         #     is_thinking = False
+                         #     content = "\n\n\n" + content.lstrip()
+                         sse_string = await generate_sse_response(timestamp, model, content=content)
+                         yield sse_string
+                     except json.JSONDecodeError:
+                         logger.error(f"Failed to parse JSON: {line}")
+                     # last_text_line = line_index
+
+                 if line and ('\"functionCall\": {' in line or revicing_function_call):
+                     revicing_function_call = True
+                     need_function_call = True
+                     if ']' in line:
+                         revicing_function_call = False
+                         continue
+
+                     function_full_response += line
+
+         if need_function_call:
+             function_call = json.loads(function_full_response)
+             function_call_name = function_call["functionCall"]["name"]
+             sse_string = await generate_sse_response(timestamp, model, content=None, tools_id="chatcmpl-9inWv0yEtgn873CxMBzHeCeiHctTV", function_call_name=function_call_name)
+             yield sse_string
+             function_full_response = json.dumps(function_call["functionCall"]["args"])
+             sse_string = await generate_sse_response(timestamp, model, content=None, tools_id="chatcmpl-9inWv0yEtgn873CxMBzHeCeiHctTV", function_call_name=None, function_call_content=function_full_response)
+             yield sse_string
+
+         sse_string = await generate_sse_response(timestamp, model, None, None, None, None, None, totalTokenCount, promptTokenCount, candidatesTokenCount)
+         yield sse_string
+
+     yield "data: [DONE]" + end_of_line
+
+ async def fetch_vertex_claude_response_stream(client, url, headers, payload, model):
+     timestamp = int(datetime.timestamp(datetime.now()))
+     async with client.stream('POST', url, headers=headers, json=payload) as response:
+         error_message = await check_response(response, "fetch_vertex_claude_response_stream")
+         if error_message:
+             yield error_message
+             return
+
+         buffer = ""
+         revicing_function_call = False
+         function_full_response = "{"
+         need_function_call = False
+         is_finish = False
+         promptTokenCount = 0
+         candidatesTokenCount = 0
+         totalTokenCount = 0
+
+         async for chunk in response.aiter_text():
+             buffer += chunk
+             while "\n" in buffer:
+                 line, buffer = buffer.split("\n", 1)
+                 # logger.info(f"{line}")
+
+                 if line and '\"finishReason\": \"' in line:
+                     is_finish = True
+                 if is_finish and '\"promptTokenCount\": ' in line:
+                     json_data = parse_json_safely( "{" + line + "}")
+                     promptTokenCount = json_data.get('promptTokenCount', 0)
+                 if is_finish and '\"candidatesTokenCount\": ' in line:
+                     json_data = parse_json_safely( "{" + line + "}")
+                     candidatesTokenCount = json_data.get('candidatesTokenCount', 0)
+                 if is_finish and '\"totalTokenCount\": ' in line:
+                     json_data = parse_json_safely( "{" + line + "}")
+                     totalTokenCount = json_data.get('totalTokenCount', 0)
+
+                 if line and '\"text\": \"' in line and is_finish == False:
+                     try:
+                         json_data = json.loads( "{" + line + "}")
+                         content = json_data.get('text', '')
+                         sse_string = await generate_sse_response(timestamp, model, content=content)
+                         yield sse_string
+                     except json.JSONDecodeError:
+                         logger.error(f"Failed to parse JSON: {line}")
+
+                 if line and ('\"type\": \"tool_use\"' in line or revicing_function_call):
+                     revicing_function_call = True
+                     need_function_call = True
+                     if ']' in line:
+                         revicing_function_call = False
+                         continue
+
+                     function_full_response += line
+
+         if need_function_call:
+             function_call = json.loads(function_full_response)
+             function_call_name = function_call["name"]
+             function_call_id = function_call["id"]
+             sse_string = await generate_sse_response(timestamp, model, content=None, tools_id=function_call_id, function_call_name=function_call_name)
+             yield sse_string
+             function_full_response = json.dumps(function_call["input"])
+             sse_string = await generate_sse_response(timestamp, model, content=None, tools_id=function_call_id, function_call_name=None, function_call_content=function_full_response)
+             yield sse_string
+
+         sse_string = await generate_sse_response(timestamp, model, None, None, None, None, None, totalTokenCount, promptTokenCount, candidatesTokenCount)
+         yield sse_string
+
+     yield "data: [DONE]" + end_of_line
+
+ async def fetch_gpt_response_stream(client, url, headers, payload):
+     timestamp = int(datetime.timestamp(datetime.now()))
+     random.seed(timestamp)
+     random_str = ''.join(random.choices(string.ascii_letters + string.digits, k=29))
+     is_thinking = False
+     has_send_thinking = False
+     ark_tag = False
+     async with client.stream('POST', url, headers=headers, json=payload) as response:
+         error_message = await check_response(response, "fetch_gpt_response_stream")
+         if error_message:
+             yield error_message
+             return
+
+         buffer = ""
+         enter_buffer = ""
+         async for chunk in response.aiter_text():
+             buffer += chunk
+             while "\n" in buffer:
+                 line, buffer = buffer.split("\n", 1)
+                 # logger.info("line: %s", repr(line))
+                 if line and not line.startswith(": ") and (result:=line.lstrip("data: ").strip()):
+                     if result.strip() == "[DONE]":
+                         break
+                     line = json.loads(result)
+                     line['id'] = f"chatcmpl-{random_str}"
+
+                     # Handle the <think> tag
+                     content = safe_get(line, "choices", 0, "delta", "content", default="")
+                     if "<think>" in content:
+                         is_thinking = True
+                         ark_tag = True
+                         content = content.replace("<think>", "")
+                     if "</think>" in content:
+                         end_think_reasoning_content = ""
+                         end_think_content = ""
+                         is_thinking = False
+
+                         if content.rstrip('\n').endswith("</think>"):
+                             end_think_reasoning_content = content.replace("</think>", "").rstrip('\n')
+                         elif content.lstrip('\n').startswith("</think>"):
+                             end_think_content = content.replace("</think>", "").lstrip('\n')
+                         else:
+                             end_think_reasoning_content = content.split("</think>")[0]
+                             end_think_content = content.split("</think>")[1]
+
+                         if end_think_reasoning_content:
+                             sse_string = await generate_sse_response(timestamp, payload["model"], reasoning_content=end_think_reasoning_content)
+                             yield sse_string
+                         if end_think_content:
+                             sse_string = await generate_sse_response(timestamp, payload["model"], content=end_think_content)
+                             yield sse_string
+                         continue
+                     if is_thinking and ark_tag:
+                         if not has_send_thinking:
+                             content = content.replace("\n\n", "")
+                         if content:
+                             sse_string = await generate_sse_response(timestamp, payload["model"], reasoning_content=content)
+                             yield sse_string
+                             has_send_thinking = True
+                         continue
+
+                     # Handle the poe "Thinking..." tag
+                     if "Thinking..." in content and "\n> " in content:
+                         is_thinking = True
+                         content = content.replace("Thinking...", "").replace("\n> ", "")
+                     if is_thinking and "\n\n" in content and not ark_tag:
+                         is_thinking = False
+                     if is_thinking and not ark_tag:
+                         content = content.replace("\n> ", "")
+                         if not has_send_thinking:
+                             content = content.replace("\n", "")
+                         if content:
+                             sse_string = await generate_sse_response(timestamp, payload["model"], reasoning_content=content)
+                             yield sse_string
+                             has_send_thinking = True
+                         continue
+
+                     no_stream_content = safe_get(line, "choices", 0, "message", "content", default=None)
+                     openrouter_reasoning = safe_get(line, "choices", 0, "delta", "reasoning", default="")
+                     # print("openrouter_reasoning", repr(openrouter_reasoning), openrouter_reasoning.endswith("\\\\"), openrouter_reasoning.endswith("\\"))
+                     if openrouter_reasoning:
+                         if openrouter_reasoning.endswith("\\"):
+                             enter_buffer += openrouter_reasoning
+                             continue
+                         elif enter_buffer.endswith("\\") and openrouter_reasoning == 'n':
+                             enter_buffer += "n"
+                             continue
+                         elif enter_buffer.endswith("\\n") and openrouter_reasoning == '\\n':
+                             enter_buffer += "\\n"
+                             continue
+                         elif enter_buffer.endswith("\\n\\n"):
+                             openrouter_reasoning = '\n\n' + openrouter_reasoning
+                             enter_buffer = ""
+                         elif enter_buffer:
+                             openrouter_reasoning = enter_buffer + openrouter_reasoning
+                             enter_buffer = ''
+                         openrouter_reasoning = openrouter_reasoning.replace("\\n", "\n")
+
+                         sse_string = await generate_sse_response(timestamp, payload["model"], reasoning_content=openrouter_reasoning)
+                         yield sse_string
+                     elif no_stream_content and has_send_thinking == False:
+                         sse_string = await generate_sse_response(safe_get(line, "created", default=None), safe_get(line, "model", default=None), content=no_stream_content)
+                         yield sse_string
+                     else:
+                         if no_stream_content:
+                             del line["choices"][0]["message"]
+                         yield "data: " + json.dumps(line).strip() + end_of_line
+     yield "data: [DONE]" + end_of_line
+
+ async def fetch_azure_response_stream(client, url, headers, payload):
+     timestamp = int(datetime.timestamp(datetime.now()))
+     is_thinking = False
+     has_send_thinking = False
+     ark_tag = False
+     async with client.stream('POST', url, headers=headers, json=payload) as response:
+         error_message = await check_response(response, "fetch_azure_response_stream")
+         if error_message:
+             yield error_message
+             return
+
+         buffer = ""
+         sse_string = ""
+         async for chunk in response.aiter_text():
+             buffer += chunk
+             while "\n" in buffer:
+                 line, buffer = buffer.split("\n", 1)
+                 # logger.info("line: %s", repr(line))
+                 if line and not line.startswith(": ") and (result:=line.lstrip("data: ").strip()):
+                     if result.strip() == "[DONE]":
+                         break
+                     line = json.loads(result)
+                     no_stream_content = safe_get(line, "choices", 0, "message", "content", default="")
+                     content = safe_get(line, "choices", 0, "delta", "content", default="")
+
+                     # Handle the <think> tag
+                     if "<think>" in content:
+                         is_thinking = True
+                         ark_tag = True
+                         content = content.replace("<think>", "")
+                     if "</think>" in content:
+                         is_thinking = False
+                         content = content.replace("</think>", "")
+                     if not content:
+                         continue
+                     if is_thinking and ark_tag:
+                         if not has_send_thinking:
+                             content = content.replace("\n\n", "")
+                         if content:
+                             sse_string = await generate_sse_response(timestamp, payload["model"], reasoning_content=content)
+                             yield sse_string
+                             has_send_thinking = True
+                         continue
+
+                     if no_stream_content or content or sse_string:
+                         sse_string = await generate_sse_response(timestamp, safe_get(line, "model", default=None), content=no_stream_content or content)
+                         yield sse_string
+                     else:
+                         if no_stream_content:
+                             del line["choices"][0]["message"]
+                         yield "data: " + json.dumps(line).strip() + end_of_line
+     yield "data: [DONE]" + end_of_line
+
+ async def fetch_cloudflare_response_stream(client, url, headers, payload, model):
+     timestamp = int(datetime.timestamp(datetime.now()))
+     async with client.stream('POST', url, headers=headers, json=payload) as response:
+         error_message = await check_response(response, "fetch_cloudflare_response_stream")
+         if error_message:
+             yield error_message
+             return
+
+         buffer = ""
+         async for chunk in response.aiter_text():
+             buffer += chunk
+             while "\n" in buffer:
+                 line, buffer = buffer.split("\n", 1)
+                 # logger.info("line: %s", repr(line))
+                 if line.startswith("data:"):
+                     line = line.lstrip("data: ")
+                     if line == "[DONE]":
+                         break
+                     resp: dict = json.loads(line)
+                     message = resp.get("response")
+                     if message:
+                         sse_string = await generate_sse_response(timestamp, model, content=message)
+                         yield sse_string
+     yield "data: [DONE]" + end_of_line
+
+ async def fetch_cohere_response_stream(client, url, headers, payload, model):
+     timestamp = int(datetime.timestamp(datetime.now()))
+     async with client.stream('POST', url, headers=headers, json=payload) as response:
+         error_message = await check_response(response, "fetch_cohere_response_stream")
+         if error_message:
+             yield error_message
+             return
+
+         buffer = ""
+         async for chunk in response.aiter_text():
+             buffer += chunk
+             while "\n" in buffer:
+                 line, buffer = buffer.split("\n", 1)
+                 # logger.info("line: %s", repr(line))
+                 resp: dict = json.loads(line)
+                 if resp.get("is_finished") == True:
+                     break
+                 if resp.get("event_type") == "text-generation":
+                     message = resp.get("text")
+                     sse_string = await generate_sse_response(timestamp, model, content=message)
+                     yield sse_string
+     yield "data: [DONE]" + end_of_line
+
+ async def fetch_claude_response_stream(client, url, headers, payload, model):
+     timestamp = int(datetime.timestamp(datetime.now()))
+     async with client.stream('POST', url, headers=headers, json=payload) as response:
+         error_message = await check_response(response, "fetch_claude_response_stream")
+         if error_message:
+             yield error_message
+             return
+         buffer = ""
+         input_tokens = 0
+         async for chunk in response.aiter_text():
+             # logger.info(f"chunk: {repr(chunk)}")
+             buffer += chunk
+             while "\n" in buffer:
+                 line, buffer = buffer.split("\n", 1)
+                 # logger.info(line)
+
+                 if line.startswith("data:"):
+                     line = line.lstrip("data: ")
+                     resp: dict = json.loads(line)
+                     message = resp.get("message")
+                     if message:
+                         role = message.get("role")
+                         if role:
+                             sse_string = await generate_sse_response(timestamp, model, None, None, None, None, role)
+                             yield sse_string
+                         tokens_use = message.get("usage")
+                         if tokens_use:
+                             input_tokens = tokens_use.get("input_tokens", 0)
+                     usage = resp.get("usage")
+                     if usage:
+                         output_tokens = usage.get("output_tokens", 0)
+                         total_tokens = input_tokens + output_tokens
+                         sse_string = await generate_sse_response(timestamp, model, None, None, None, None, None, total_tokens, input_tokens, output_tokens)
+                         yield sse_string
+                         # print("\n\rtotal_tokens", total_tokens)
+
+                     tool_use = resp.get("content_block")
+                     tools_id = None
+                     function_call_name = None
+                     if tool_use and "tool_use" == tool_use['type']:
+                         # print("tool_use", tool_use)
+                         tools_id = tool_use["id"]
+                         if "name" in tool_use:
+                             function_call_name = tool_use["name"]
+                             sse_string = await generate_sse_response(timestamp, model, None, tools_id, function_call_name, None)
+                             yield sse_string
+                     delta = resp.get("delta")
+                     # print("delta", delta)
+                     if not delta:
+                         continue
+                     if "text" in delta:
+                         content = delta["text"]
+                         sse_string = await generate_sse_response(timestamp, model, content, None, None)
+                         yield sse_string
+                     if "thinking" in delta and delta["thinking"]:
+                         content = delta["thinking"]
+                         sse_string = await generate_sse_response(timestamp, model, reasoning_content=content)
+                         yield sse_string
+                     if "partial_json" in delta:
+                         # {"type":"input_json_delta","partial_json":""}
+                         function_call_content = delta["partial_json"]
+                         sse_string = await generate_sse_response(timestamp, model, None, None, None, function_call_content)
+                         yield sse_string
+     yield "data: [DONE]" + end_of_line
+
+ async def fetch_aws_response_stream(client, url, headers, payload, model):
+     timestamp = int(datetime.timestamp(datetime.now()))
+     async with client.stream('POST', url, headers=headers, json=payload) as response:
+         error_message = await check_response(response, "fetch_aws_response_stream")
+         if error_message:
+             yield error_message
+             return
+
+         buffer = ""
+         async for line in response.aiter_text():
+             buffer += line
+             while "\r" in buffer:
+                 line, buffer = buffer.split("\r", 1)
+                 if not line or \
+                 line.strip() == "" or \
+                 line.strip().startswith(':content-type') or \
+                 line.strip().startswith(':event-type'): # Skip lines that are completely empty or whitespace-only
+                     continue
+
+                 json_match = re.search(r'event{.*?}', line)
+                 if not json_match:
+                     continue
+                 try:
+                     chunk_data = json.loads(json_match.group(0).lstrip('event'))
+                 except json.JSONDecodeError:
+                     logger.error(f"DEBUG json.JSONDecodeError: {json_match.group(0).lstrip('event')!r}")
+                     continue
+
+                 # --- Subsequent processing is unchanged ---
+                 if "bytes" in chunk_data:
+                     # Decode the Base64-encoded bytes
+                     decoded_bytes = base64.b64decode(chunk_data["bytes"])
+                     # Parse the decoded bytes as JSON again
+                     payload_chunk = json.loads(decoded_bytes.decode('utf-8'))
+                     # print(f"DEBUG payload_chunk: {payload_chunk!r}")
+
+                     text = safe_get(payload_chunk, "delta", "text", default="")
+                     if text:
+                         sse_string = await generate_sse_response(timestamp, model, text, None, None)
+                         yield sse_string
+
+                     usage = safe_get(payload_chunk, "amazon-bedrock-invocationMetrics", default="")
+                     if usage:
+                         input_tokens = usage.get("inputTokenCount", 0)
+                         output_tokens = usage.get("outputTokenCount", 0)
+                         total_tokens = input_tokens + output_tokens
+                         sse_string = await generate_sse_response(timestamp, model, None, None, None, None, None, total_tokens, input_tokens, output_tokens)
+                         yield sse_string
+
+     yield "data: [DONE]" + end_of_line
+
+ async def fetch_response(client, url, headers, payload, engine, model):
+     response = None
+     if payload.get("file"):
+         file = payload.pop("file")
+         response = await client.post(url, headers=headers, data=payload, files={"file": file})
+     else:
+         response = await client.post(url, headers=headers, json=payload)
+     error_message = await check_response(response, "fetch_response")
+     if error_message:
+         yield error_message
+         return
+
+     if engine == "tts":
+         yield response.read()
+
+     elif engine == "gemini" or engine == "vertex-gemini" or engine == "aws":
+         response_json = response.json()
+
+         if isinstance(response_json, str):
+             import ast
+             parsed_data = ast.literal_eval(str(response_json))
+         elif isinstance(response_json, list):
+             parsed_data = response_json
+         else:
+             logger.error(f"error fetch_response: Unknown response_json type: {type(response_json)}")
+             parsed_data = response_json
+         # print("parsed_data", json.dumps(parsed_data, indent=4, ensure_ascii=False))
+         content = ""
+         for item in parsed_data:
+             chunk = safe_get(item, "candidates", 0, "content", "parts", 0, "text")
+             # logger.info(f"chunk: {repr(chunk)}")
+             if chunk:
+                 content += chunk
+
+         usage_metadata = safe_get(parsed_data, -1, "usageMetadata")
+         prompt_tokens = usage_metadata.get("promptTokenCount", 0)
+         candidates_tokens = usage_metadata.get("candidatesTokenCount", 0)
+         total_tokens = usage_metadata.get("totalTokenCount", 0)
+
+         role = safe_get(parsed_data, -1, "candidates", 0, "content", "role")
+         if role == "model":
+             role = "assistant"
+         else:
+             logger.error(f"Unknown role: {role}, parsed_data: {parsed_data}")
+             role = "assistant"
+
+         function_call_name = safe_get(parsed_data, -1, "candidates", 0, "content", "parts", 0, "functionCall", "name", default=None)
+         function_call_content = safe_get(parsed_data, -1, "candidates", 0, "content", "parts", 0, "functionCall", "args", default=None)
+
+         timestamp = int(datetime.timestamp(datetime.now()))
+         yield await generate_no_stream_response(timestamp, model, content=content, tools_id=None, function_call_name=function_call_name, function_call_content=function_call_content, role=role, total_tokens=total_tokens, prompt_tokens=prompt_tokens, completion_tokens=candidates_tokens)
+
+     elif engine == "claude":
+         response_json = response.json()
+         # print("response_json", json.dumps(response_json, indent=4, ensure_ascii=False))
+
+         content = safe_get(response_json, "content", 0, "text")
+
+         prompt_tokens = safe_get(response_json, "usage", "input_tokens")
+         output_tokens = safe_get(response_json, "usage", "output_tokens")
+         total_tokens = prompt_tokens + output_tokens
+
+         role = safe_get(response_json, "role")
+
+         function_call_name = safe_get(response_json, "content", 1, "name", default=None)
+         function_call_content = safe_get(response_json, "content", 1, "input", default=None)
+         tools_id = safe_get(response_json, "content", 1, "id", default=None)
+
+         timestamp = int(datetime.timestamp(datetime.now()))
+         yield await generate_no_stream_response(timestamp, model, content=content, tools_id=tools_id, function_call_name=function_call_name, function_call_content=function_call_content, role=role, total_tokens=total_tokens, prompt_tokens=prompt_tokens, completion_tokens=output_tokens)
+
+     elif engine == "azure":
+         response_json = response.json()
+         # Strip content_filter_results
+         if "choices" in response_json:
+             for choice in response_json["choices"]:
+                 if "content_filter_results" in choice:
+                     del choice["content_filter_results"]
+
+         # Strip prompt_filter_results
+         if "prompt_filter_results" in response_json:
+             del response_json["prompt_filter_results"]
+
+         yield response_json
+
+     else:
+         response_json = response.json()
+         yield response_json
+
+ async def fetch_response_stream(client, url, headers, payload, engine, model):
+     if engine == "gemini" or engine == "vertex-gemini":
+         async for chunk in fetch_gemini_response_stream(client, url, headers, payload, model):
+             yield chunk
+     elif engine == "claude" or engine == "vertex-claude":
+         async for chunk in fetch_claude_response_stream(client, url, headers, payload, model):
+             yield chunk
+     elif engine == "aws":
+         async for chunk in fetch_aws_response_stream(client, url, headers, payload, model):
+             yield chunk
+     elif engine == "gpt":
+         async for chunk in fetch_gpt_response_stream(client, url, headers, payload):
+             yield chunk
+     elif engine == "azure":
+         async for chunk in fetch_azure_response_stream(client, url, headers, payload):
+             yield chunk
+     elif engine == "openrouter":
+         async for chunk in fetch_gpt_response_stream(client, url, headers, payload):
+             yield chunk
+     elif engine == "cloudflare":
+         async for chunk in fetch_cloudflare_response_stream(client, url, headers, payload, model):
+             yield chunk
+     elif engine == "cohere":
+         async for chunk in fetch_cohere_response_stream(client, url, headers, payload, model):
+             yield chunk
+     else:
+         raise ValueError("Unknown response")
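
`fetch_response_stream` is the engine dispatcher the rest of aient drives: whatever the upstream provider, each branch re-emits OpenAI-style SSE chunks via `generate_sse_response`, so one consumer loop covers every engine. A minimal consumer sketch, assuming an `httpx.AsyncClient` and an OpenAI-compatible endpoint; the URL, API key, and model name are placeholders, and the import path is assumed from this wheel's `src` layout:

```python
import asyncio
import httpx

from aient.core.response import fetch_response_stream  # path assumed from this wheel's layout

async def main():
    payload = {"model": "gpt-4o-mini", "stream": True,
               "messages": [{"role": "user", "content": "hello"}]}
    headers = {"Authorization": "Bearer sk-placeholder"}
    async with httpx.AsyncClient(timeout=60) as client:
        # engine="gpt" routes through fetch_gpt_response_stream above
        async for chunk in fetch_response_stream(
            client, "https://api.example.com/v1/chat/completions",
            headers, payload, engine="gpt", model="gpt-4o-mini",
        ):
            if isinstance(chunk, dict):  # check_response surfaces HTTP errors as dicts
                print("error:", chunk)
                break
            print(chunk, end="")  # "data: ..." SSE lines, terminated by "data: [DONE]"

asyncio.run(main())
```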
beswarm/aient/src/aient/core/test/test_base_api.py
@@ -0,0 +1,18 @@
+ from ..utils import BaseAPI
+
+ # GOOGLE_AI_API_KEY = os.environ.get('GOOGLE_AI_API_KEY', None)
+
+ # base_api = BaseAPI("https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:streamGenerateContent?key=" + GOOGLE_AI_API_KEY)
+
+ # print(base_api.chat_url)
+
+ base_api = BaseAPI("http://127.0.0.1:8000/v1")
+ print(base_api.chat_url)
+
+ base_api = BaseAPI("http://127.0.0.1:8000/v1/images/generations")
+ print(base_api.image_url)
+
+ """
+ python -m core.test.test_base_api
+ python -m aient.src.aient.core.test.test_base_api
+ """
beswarm/aient/src/aient/core/test/test_image.py
@@ -0,0 +1,15 @@
+ from ..request import get_image_message
+ import os
+ import asyncio
+ IMAGE_URL = os.getenv("IMAGE_URL")
+
+ async def test_image():
+     image_message = await get_image_message(IMAGE_URL, engine="gemini")
+     print(image_message)
+
+ if __name__ == "__main__":
+     asyncio.run(test_image())
+
+ '''
+ python -m core.test.test_image
+ '''
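
One caveat the test leaves implicit: the target image comes from the `IMAGE_URL` environment variable, so `get_image_message` receives `None` when it is unset. A hedged standalone sketch with a placeholder URL; the import path is assumed from this wheel's `src` layout:

```python
import asyncio
import os

from aient.core.request import get_image_message  # path assumed from this wheel's layout

async def main():
    url = os.getenv("IMAGE_URL", "https://example.com/sample.png")  # placeholder fallback
    # engine="gemini" selects the Gemini-style image message shape, as in the test above
    print(await get_image_message(url, engine="gemini"))

asyncio.run(main())
```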