aient 1.0.29__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
aient/core/response.py ADDED
@@ -0,0 +1,531 @@
+ import json
+ import random
+ import string
+ from datetime import datetime
+
+ from .log_config import logger
+
+ from .utils import safe_get, generate_sse_response, generate_no_stream_response, end_of_line
+
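+ # Read the body of a non-2xx response and normalize it into an error dict;
+ # returns None when the response is OK.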
+ async def check_response(response, error_log):
+     if response and not (200 <= response.status_code < 300):
+         error_message = await response.aread()
+         error_str = error_message.decode('utf-8', errors='replace')
+         try:
+             error_json = json.loads(error_str)
+         except json.JSONDecodeError:
+             error_json = error_str
+         return {"error": f"{error_log} HTTP Error", "status_code": response.status_code, "details": error_json}
+     return None
+
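+ # Relay a Gemini streamGenerateContent stream: "text" parts are re-emitted as
+ # OpenAI-style SSE chunks, while "functionCall" fragments are buffered and
+ # re-emitted as a tool call once complete.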
+ async def fetch_gemini_response_stream(client, url, headers, payload, model):
+     timestamp = int(datetime.timestamp(datetime.now()))
+     async with client.stream('POST', url, headers=headers, json=payload) as response:
+         error_message = await check_response(response, "fetch_gemini_response_stream")
+         if error_message:
+             yield error_message
+             return
+         buffer = ""
+         revicing_function_call = False
+         function_full_response = "{"
+         need_function_call = False
+         is_finish = False
+         # line_index = 0
+         # last_text_line = 0
+         # if "thinking" in model:
+         #     is_thinking = True
+         # else:
+         #     is_thinking = False
+         async for chunk in response.aiter_text():
+             buffer += chunk
+
+             while "\n" in buffer:
+                 line, buffer = buffer.split("\n", 1)
+                 # line_index += 1
+                 if line and '"finishReason": "' in line:
+                     is_finish = True
+                     break
+                 # print(line)
+                 if line and '"text": "' in line:
+                     try:
+                         json_data = json.loads("{" + line + "}")
+                         content = json_data.get('text', '')
+                         content = "\n".join(content.split("\\n"))
+                         # content = content.replace("\n", "\n\n")
+                         # if last_text_line == 0 and is_thinking:
+                         #     content = "> " + content.lstrip()
+                         # if is_thinking:
+                         #     content = content.replace("\n", "\n> ")
+                         #     if last_text_line == line_index - 3:
+                         #         is_thinking = False
+                         #         content = "\n\n\n" + content.lstrip()
+                         sse_string = await generate_sse_response(timestamp, model, content=content)
+                         yield sse_string
+                     except json.JSONDecodeError:
+                         logger.error(f"Failed to parse JSON: {line}")
+                 # last_text_line = line_index
+
+                 if line and ('"functionCall": {' in line or revicing_function_call):
+                     revicing_function_call = True
+                     need_function_call = True
+                     if ']' in line:
+                         revicing_function_call = False
+                         continue
+
+                     function_full_response += line
+
+             if is_finish:
+                 break
+
+         if need_function_call:
+             function_call = json.loads(function_full_response)
+             function_call_name = function_call["functionCall"]["name"]
+             sse_string = await generate_sse_response(timestamp, model, content=None, tools_id="chatcmpl-9inWv0yEtgn873CxMBzHeCeiHctTV", function_call_name=function_call_name)
+             yield sse_string
+             function_full_response = json.dumps(function_call["functionCall"]["args"])
+             sse_string = await generate_sse_response(timestamp, model, content=None, tools_id="chatcmpl-9inWv0yEtgn873CxMBzHeCeiHctTV", function_call_name=None, function_call_content=function_full_response)
+             yield sse_string
+         yield "data: [DONE]" + end_of_line
+
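+ # Claude-on-Vertex variant of the same line scanner: "text" lines become SSE
+ # content and "tool_use" blocks are buffered into a single function call.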
+ async def fetch_vertex_claude_response_stream(client, url, headers, payload, model):
+     timestamp = int(datetime.timestamp(datetime.now()))
+     async with client.stream('POST', url, headers=headers, json=payload) as response:
+         error_message = await check_response(response, "fetch_vertex_claude_response_stream")
+         if error_message:
+             yield error_message
+             return
+
+         buffer = ""
+         revicing_function_call = False
+         function_full_response = "{"
+         need_function_call = False
+         async for chunk in response.aiter_text():
+             buffer += chunk
+             while "\n" in buffer:
+                 line, buffer = buffer.split("\n", 1)
+                 # logger.info(f"{line}")
+                 if line and '"text": "' in line:
+                     try:
+                         json_data = json.loads("{" + line + "}")
+                         content = json_data.get('text', '')
+                         content = "\n".join(content.split("\\n"))
+                         sse_string = await generate_sse_response(timestamp, model, content=content)
+                         yield sse_string
+                     except json.JSONDecodeError:
+                         logger.error(f"Failed to parse JSON: {line}")
+
+                 if line and ('"type": "tool_use"' in line or revicing_function_call):
+                     revicing_function_call = True
+                     need_function_call = True
+                     if ']' in line:
+                         revicing_function_call = False
+                         continue
+
+                     function_full_response += line
+
+         if need_function_call:
+             function_call = json.loads(function_full_response)
+             function_call_name = function_call["name"]
+             function_call_id = function_call["id"]
+             sse_string = await generate_sse_response(timestamp, model, content=None, tools_id=function_call_id, function_call_name=function_call_name)
+             yield sse_string
+             function_full_response = json.dumps(function_call["input"])
+             sse_string = await generate_sse_response(timestamp, model, content=None, tools_id=function_call_id, function_call_name=None, function_call_content=function_full_response)
+             yield sse_string
+         yield "data: [DONE]" + end_of_line
+
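+ # Relay an OpenAI-compatible SSE stream: rewrites chunk ids, splits <think>
+ # tags and poe-style "Thinking..." markers into reasoning_content deltas, and
+ # reassembles OpenRouter reasoning fragments split across "\n" escapes.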
+ async def fetch_gpt_response_stream(client, url, headers, payload):
+     timestamp = int(datetime.timestamp(datetime.now()))
+     random.seed(timestamp)
+     random_str = ''.join(random.choices(string.ascii_letters + string.digits, k=29))
+     is_thinking = False
+     has_send_thinking = False
+     ark_tag = False
+     async with client.stream('POST', url, headers=headers, json=payload) as response:
+         error_message = await check_response(response, "fetch_gpt_response_stream")
+         if error_message:
+             yield error_message
+             return
+
+         buffer = ""
+         enter_buffer = ""
+         async for chunk in response.aiter_text():
+             buffer += chunk
+             while "\n" in buffer:
+                 line, buffer = buffer.split("\n", 1)
+                 # logger.info("line: %s", repr(line))
+                 if line and line != "data: " and line != "data:" and not line.startswith(": ") and (result := line.lstrip("data: ").strip()):
+                     if result.strip() == "[DONE]":
+                         yield "data: [DONE]" + end_of_line
+                         return
+                     line = json.loads(result)
+                     line['id'] = f"chatcmpl-{random_str}"
+
+                     # Handle the <think> tag
+                     content = safe_get(line, "choices", 0, "delta", "content", default="")
+                     if "<think>" in content:
+                         is_thinking = True
+                         ark_tag = True
+                         content = content.replace("<think>", "")
+                     if "</think>" in content:
+                         end_think_reasoning_content = ""
+                         end_think_content = ""
+                         is_thinking = False
+
+                         if content.rstrip('\n').endswith("</think>"):
+                             end_think_reasoning_content = content.replace("</think>", "").rstrip('\n')
+                         elif content.lstrip('\n').startswith("</think>"):
+                             end_think_content = content.replace("</think>", "").lstrip('\n')
+                         else:
+                             end_think_reasoning_content = content.split("</think>")[0]
+                             end_think_content = content.split("</think>")[1]
+
+                         if end_think_reasoning_content:
+                             sse_string = await generate_sse_response(timestamp, payload["model"], reasoning_content=end_think_reasoning_content)
+                             yield sse_string
+                         if end_think_content:
+                             sse_string = await generate_sse_response(timestamp, payload["model"], content=end_think_content)
+                             yield sse_string
+                         continue
+                     if is_thinking and ark_tag:
+                         if not has_send_thinking:
+                             content = content.replace("\n\n", "")
+                         if content:
+                             sse_string = await generate_sse_response(timestamp, payload["model"], reasoning_content=content)
+                             yield sse_string
+                             has_send_thinking = True
+                         continue
+
+                     # Handle the poe "Thinking..." marker
+                     if "Thinking..." in content and "\n> " in content:
+                         is_thinking = True
+                         content = content.replace("Thinking...", "").replace("\n> ", "")
+                     if is_thinking and "\n\n" in content and not ark_tag:
+                         is_thinking = False
+                     if is_thinking and not ark_tag:
+                         content = content.replace("\n> ", "")
+                         if not has_send_thinking:
+                             content = content.replace("\n", "")
+                         if content:
+                             sse_string = await generate_sse_response(timestamp, payload["model"], reasoning_content=content)
+                             yield sse_string
+                             has_send_thinking = True
+                         continue
+
+                     no_stream_content = safe_get(line, "choices", 0, "message", "content", default=None)
+                     openrouter_reasoning = safe_get(line, "choices", 0, "delta", "reasoning", default="")
+                     # print("openrouter_reasoning", repr(openrouter_reasoning), openrouter_reasoning.endswith("\\\\"), openrouter_reasoning.endswith("\\"))
+                     if openrouter_reasoning:
+                         if openrouter_reasoning.endswith("\\"):
+                             enter_buffer += openrouter_reasoning
+                             continue
+                         elif enter_buffer.endswith("\\") and openrouter_reasoning == 'n':
+                             enter_buffer += "n"
+                             continue
+                         elif enter_buffer.endswith("\\n") and openrouter_reasoning == '\\n':
+                             enter_buffer += "\\n"
+                             continue
+                         elif enter_buffer.endswith("\\n\\n"):
+                             openrouter_reasoning = '\n\n' + openrouter_reasoning
+                             enter_buffer = ""
+                         elif enter_buffer:
+                             openrouter_reasoning = enter_buffer + openrouter_reasoning
+                             enter_buffer = ''
+                         openrouter_reasoning = openrouter_reasoning.replace("\\n", "\n")
+
+                         sse_string = await generate_sse_response(timestamp, payload["model"], reasoning_content=openrouter_reasoning)
+                         yield sse_string
+                     elif no_stream_content and has_send_thinking == False:
+                         sse_string = await generate_sse_response(safe_get(line, "created", default=None), safe_get(line, "model", default=None), content=no_stream_content)
+                         yield sse_string
+                     else:
+                         if no_stream_content:
+                             del line["choices"][0]["message"]
+                         yield "data: " + json.dumps(line).strip() + end_of_line
+
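+ # Azure variant of the GPT relay: strips <think> tags and re-wraps both delta
+ # and non-stream message content as SSE chunks.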
+ async def fetch_azure_response_stream(client, url, headers, payload):
+     timestamp = int(datetime.timestamp(datetime.now()))
+     is_thinking = False
+     has_send_thinking = False
+     ark_tag = False
+     async with client.stream('POST', url, headers=headers, json=payload) as response:
+         error_message = await check_response(response, "fetch_azure_response_stream")
+         if error_message:
+             yield error_message
+             return
+
+         buffer = ""
+         sse_string = ""
+         async for chunk in response.aiter_text():
+             buffer += chunk
+             while "\n" in buffer:
+                 line, buffer = buffer.split("\n", 1)
+                 # logger.info("line: %s", repr(line))
+                 if line and line != "data: " and line != "data:" and not line.startswith(": "):
+                     result = line.lstrip("data: ")
+                     if result.strip() == "[DONE]":
+                         yield "data: [DONE]" + end_of_line
+                         return
+                     line = json.loads(result)
+                     no_stream_content = safe_get(line, "choices", 0, "message", "content", default="")
+                     content = safe_get(line, "choices", 0, "delta", "content", default="")
+
+                     # Handle the <think> tag
+                     if "<think>" in content:
+                         is_thinking = True
+                         ark_tag = True
+                         content = content.replace("<think>", "")
+                     if "</think>" in content:
+                         is_thinking = False
+                         content = content.replace("</think>", "")
+                         if not content:
+                             continue
+                     if is_thinking and ark_tag:
+                         if not has_send_thinking:
+                             content = content.replace("\n\n", "")
+                         if content:
+                             sse_string = await generate_sse_response(timestamp, payload["model"], reasoning_content=content)
+                             yield sse_string
+                             has_send_thinking = True
+                         continue
+
+                     if no_stream_content or content or sse_string:
+                         sse_string = await generate_sse_response(timestamp, safe_get(line, "model", default=None), content=no_stream_content or content)
+                         yield sse_string
+                     else:
+                         if no_stream_content:
+                             del line["choices"][0]["message"]
+                         yield "data: " + json.dumps(line).strip() + end_of_line
+         yield "data: [DONE]" + end_of_line
+
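+ # Cloudflare Workers AI emits {"response": ...} objects; each one is converted
+ # into an SSE content chunk.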
+ async def fetch_cloudflare_response_stream(client, url, headers, payload, model):
+     timestamp = int(datetime.timestamp(datetime.now()))
+     async with client.stream('POST', url, headers=headers, json=payload) as response:
+         error_message = await check_response(response, "fetch_cloudflare_response_stream")
+         if error_message:
+             yield error_message
+             return
+
+         buffer = ""
+         async for chunk in response.aiter_text():
+             buffer += chunk
+             while "\n" in buffer:
+                 line, buffer = buffer.split("\n", 1)
+                 # logger.info("line: %s", repr(line))
+                 if line.startswith("data:"):
+                     line = line.lstrip("data: ")
+                     if line == "[DONE]":
+                         yield "data: [DONE]" + end_of_line
+                         return
+                     resp: dict = json.loads(line)
+                     message = resp.get("response")
+                     if message:
+                         sse_string = await generate_sse_response(timestamp, model, content=message)
+                         yield sse_string
+
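+ # Cohere streams newline-delimited JSON events: "text-generation" events carry
+ # the content and "is_finished" terminates the stream.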
+ async def fetch_cohere_response_stream(client, url, headers, payload, model):
+     timestamp = int(datetime.timestamp(datetime.now()))
+     async with client.stream('POST', url, headers=headers, json=payload) as response:
+         error_message = await check_response(response, "fetch_cohere_response_stream")
+         if error_message:
+             yield error_message
+             return
+
+         buffer = ""
+         async for chunk in response.aiter_text():
+             buffer += chunk
+             while "\n" in buffer:
+                 line, buffer = buffer.split("\n", 1)
+                 # logger.info("line: %s", repr(line))
+                 resp: dict = json.loads(line)
+                 if resp.get("is_finished") == True:
+                     yield "data: [DONE]" + end_of_line
+                     return
+                 if resp.get("event_type") == "text-generation":
+                     message = resp.get("text")
+                     sse_string = await generate_sse_response(timestamp, model, content=message)
+                     yield sse_string
+
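+ # Anthropic Claude SSE relay: forwards role, text/thinking deltas, tool_use
+ # blocks, partial_json tool arguments and token usage as OpenAI-style chunks.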
+ async def fetch_claude_response_stream(client, url, headers, payload, model):
+     timestamp = int(datetime.timestamp(datetime.now()))
+     async with client.stream('POST', url, headers=headers, json=payload) as response:
+         error_message = await check_response(response, "fetch_claude_response_stream")
+         if error_message:
+             yield error_message
+             return
+         buffer = ""
+         input_tokens = 0
+         async for chunk in response.aiter_text():
+             # logger.info(f"chunk: {repr(chunk)}")
+             buffer += chunk
+             while "\n" in buffer:
+                 line, buffer = buffer.split("\n", 1)
+                 # logger.info(line)
+
+                 if line.startswith("data:"):
+                     line = line.lstrip("data: ")
+                     resp: dict = json.loads(line)
+                     message = resp.get("message")
+                     if message:
+                         role = message.get("role")
+                         if role:
+                             sse_string = await generate_sse_response(timestamp, model, None, None, None, None, role)
+                             yield sse_string
+                         tokens_use = message.get("usage")
+                         if tokens_use:
+                             input_tokens = tokens_use.get("input_tokens", 0)
+                     usage = resp.get("usage")
+                     if usage:
+                         output_tokens = usage.get("output_tokens", 0)
+                         total_tokens = input_tokens + output_tokens
+                         sse_string = await generate_sse_response(timestamp, model, None, None, None, None, None, total_tokens, input_tokens, output_tokens)
+                         yield sse_string
+                         # print("\n\rtotal_tokens", total_tokens)
+
+                     tool_use = resp.get("content_block")
+                     tools_id = None
+                     function_call_name = None
+                     if tool_use and "tool_use" == tool_use['type']:
+                         # print("tool_use", tool_use)
+                         tools_id = tool_use["id"]
+                         if "name" in tool_use:
+                             function_call_name = tool_use["name"]
+                         sse_string = await generate_sse_response(timestamp, model, None, tools_id, function_call_name, None)
+                         yield sse_string
+                     delta = resp.get("delta")
+                     # print("delta", delta)
+                     if not delta:
+                         continue
+                     if "text" in delta:
+                         content = delta["text"]
+                         sse_string = await generate_sse_response(timestamp, model, content, None, None)
+                         yield sse_string
+                     if "thinking" in delta and delta["thinking"]:
+                         content = delta["thinking"]
+                         sse_string = await generate_sse_response(timestamp, model, reasoning_content=content)
+                         yield sse_string
+                     if "partial_json" in delta:
+                         # {"type":"input_json_delta","partial_json":""}
+                         function_call_content = delta["partial_json"]
+                         sse_string = await generate_sse_response(timestamp, model, None, None, None, function_call_content)
+                         yield sse_string
+         yield "data: [DONE]" + end_of_line
+
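+ # Non-streaming path: POST once, then convert the provider-specific JSON body
+ # into a single OpenAI-style response.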
+ async def fetch_response(client, url, headers, payload, engine, model):
+     response = None
+     if payload.get("file"):
+         file = payload.pop("file")
+         response = await client.post(url, headers=headers, data=payload, files={"file": file})
+     else:
+         response = await client.post(url, headers=headers, json=payload)
+     error_message = await check_response(response, "fetch_response")
+     if error_message:
+         yield error_message
+         return
+
+     if engine == "tts":
+         yield response.read()
+
+     elif engine == "gemini" or engine == "vertex-gemini":
+         response_json = response.json()
+
+         if isinstance(response_json, str):
+             import ast
+             parsed_data = ast.literal_eval(str(response_json))
+         elif isinstance(response_json, list):
+             parsed_data = response_json
+         else:
+             logger.error(f"error fetch_response: Unknown response_json type: {type(response_json)}")
+             parsed_data = response_json
+         # print("parsed_data", json.dumps(parsed_data, indent=4, ensure_ascii=False))
+         content = ""
+         for item in parsed_data:
+             chunk = safe_get(item, "candidates", 0, "content", "parts", 0, "text")
+             # logger.info(f"chunk: {repr(chunk)}")
+             if chunk:
+                 content += chunk
+
+         usage_metadata = safe_get(parsed_data, -1, "usageMetadata")
+         prompt_tokens = usage_metadata.get("promptTokenCount", 0)
+         candidates_tokens = usage_metadata.get("candidatesTokenCount", 0)
+         total_tokens = usage_metadata.get("totalTokenCount", 0)
+
+         role = safe_get(parsed_data, -1, "candidates", 0, "content", "role")
+         if role == "model":
+             role = "assistant"
+         else:
+             logger.error(f"Unknown role: {role}")
+             role = "assistant"
+
+         function_call_name = safe_get(parsed_data, -1, "candidates", 0, "content", "parts", 0, "functionCall", "name", default=None)
+         function_call_content = safe_get(parsed_data, -1, "candidates", 0, "content", "parts", 0, "functionCall", "args", default=None)
+
+         timestamp = int(datetime.timestamp(datetime.now()))
+         yield await generate_no_stream_response(timestamp, model, content=content, tools_id=None, function_call_name=function_call_name, function_call_content=function_call_content, role=role, total_tokens=total_tokens, prompt_tokens=prompt_tokens, completion_tokens=candidates_tokens)
+
+     elif engine == "claude":
+         response_json = response.json()
+         # print("response_json", json.dumps(response_json, indent=4, ensure_ascii=False))
+
+         content = safe_get(response_json, "content", 0, "text")
+
+         prompt_tokens = safe_get(response_json, "usage", "input_tokens")
+         output_tokens = safe_get(response_json, "usage", "output_tokens")
+         total_tokens = prompt_tokens + output_tokens
+
+         role = safe_get(response_json, "role")
+
+         function_call_name = safe_get(response_json, "content", 1, "name", default=None)
+         function_call_content = safe_get(response_json, "content", 1, "input", default=None)
+         tools_id = safe_get(response_json, "content", 1, "id", default=None)
+
+         timestamp = int(datetime.timestamp(datetime.now()))
+         yield await generate_no_stream_response(timestamp, model, content=content, tools_id=tools_id, function_call_name=function_call_name, function_call_content=function_call_content, role=role, total_tokens=total_tokens, prompt_tokens=prompt_tokens, completion_tokens=output_tokens)
+
+     elif engine == "azure":
+         response_json = response.json()
+         # Remove content_filter_results
+         if "choices" in response_json:
+             for choice in response_json["choices"]:
+                 if "content_filter_results" in choice:
+                     del choice["content_filter_results"]
+
+         # Remove prompt_filter_results
+         if "prompt_filter_results" in response_json:
+             del response_json["prompt_filter_results"]
+
+         yield response_json
+
+     else:
+         response_json = response.json()
+         yield response_json
+
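+ # Streaming dispatcher: select the provider-specific handler by engine name.
+ # A minimal consumption sketch (assuming an httpx.AsyncClient and url/headers/
+ # payload prepared elsewhere, e.g. by prepare_request_payload):
+ #
+ #     async with httpx.AsyncClient() as client:
+ #         async for chunk in fetch_response_stream(client, url, headers, payload, "gpt", model):
+ #             ...  # each chunk is an SSE "data: ..." string or an error dict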
+ async def fetch_response_stream(client, url, headers, payload, engine, model):
+     # try:
+     if engine == "gemini" or engine == "vertex-gemini":
+         async for chunk in fetch_gemini_response_stream(client, url, headers, payload, model):
+             yield chunk
+     elif engine == "claude" or engine == "vertex-claude":
+         async for chunk in fetch_claude_response_stream(client, url, headers, payload, model):
+             yield chunk
+     elif engine == "gpt":
+         async for chunk in fetch_gpt_response_stream(client, url, headers, payload):
+             yield chunk
+     elif engine == "azure":
+         async for chunk in fetch_azure_response_stream(client, url, headers, payload):
+             yield chunk
+     elif engine == "openrouter":
+         async for chunk in fetch_gpt_response_stream(client, url, headers, payload):
+             yield chunk
+     elif engine == "cloudflare":
+         async for chunk in fetch_cloudflare_response_stream(client, url, headers, payload, model):
+             yield chunk
+     elif engine == "cohere":
+         async for chunk in fetch_cohere_response_stream(client, url, headers, payload, model):
+             yield chunk
+     else:
+         raise ValueError("Unknown response")
+     # except httpx.ConnectError as e:
+     #     yield {"error": f"500", "details": "fetch_response_stream Connect Error"}
+     # except httpx.ReadTimeout as e:
+     #     yield {"error": f"500", "details": "fetch_response_stream Read Response Timeout"}
aient/core/test/test_base_api.py ADDED
@@ -0,0 +1,17 @@
+ from ..utils import BaseAPI
+
+ import os
+
+ # GOOGLE_AI_API_KEY = os.environ.get('GOOGLE_AI_API_KEY', None)
+
+ # base_api = BaseAPI("https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:streamGenerateContent?key=" + GOOGLE_AI_API_KEY)
+
+ # print(base_api.chat_url)
+
+ base_api = BaseAPI("http://127.0.0.1:8000/v1")
+
+ print(base_api.chat_url)
+
+ """
+ python -m core.test.test_base_api
+ """
aient/core/test/test_image.py ADDED
@@ -0,0 +1,15 @@
+ from ..request import get_image_message
+ import os
+ import asyncio
+ IMAGE_URL = os.getenv("IMAGE_URL")
+
+ async def test_image():
+     image_message = await get_image_message(IMAGE_URL, engine="gemini")
+     print(image_message)
+
+ if __name__ == "__main__":
+     asyncio.run(test_image())
+
+ '''
+ python -m core.test.test_image
+ '''
aient/core/test/test_payload.py ADDED
@@ -0,0 +1,92 @@
+ import json
+ import asyncio
+
+ from ..request import prepare_request_payload
+
+ """
+ Test script for the get_payload function in core/request.py.
+
+ It builds the RequestModel object from core/models.py, then calls get_payload
+ in core/request.py to obtain url, headers and payload, so get_payload can be tested in isolation.
+
+ Test case: a model call that includes tool functions.
+
+ The API configuration comes from the 'new-i1-pe' entry in api.yaml.
+
+ python -m core.test.test_payload
+ """
+
+ async def test_payload():
+     print("===== Testing the get_payload function =====")
+
+     # Step 1: configure the provider
+     provider = {
+         "provider": "new",
+         "base_url": "/v1/chat/completions",
+         "api": "",
+         "model": [
+             "gpt-4"  # a model name that supports tool calls
+         ],
+         "tools": True  # enable tool support
+     }
+
+     request_data = {
+         "model": "gpt-4",
+         "messages": [
+             {
+                 "role": "system",
+                 "content": "You are a helpful AI assistant."
+             },
+             {
+                 "role": "user",
+                 "content": "Hello, please introduce yourself."
+             }
+         ],
+         "stream": True,
+         "temperature": 0.7,
+         "max_tokens": 1000,
+         "tools": [
+             {
+                 "type": "function",
+                 "function": {
+                     "name": "get_current_weather",
+                     "description": "Get the current weather",
+                     "parameters": {
+                         "type": "object",
+                         "properties": {
+                             "location": {
+                                 "type": "string",
+                                 "description": "City name, e.g. Beijing"
+                             },
+                             "unit": {
+                                 "type": "string",
+                                 "enum": ["celsius", "fahrenheit"],
+                                 "description": "Temperature unit"
+                             }
+                         },
+                         "required": ["location"]
+                     }
+                 }
+             }
+         ],
+         "tool_choice": "auto"  # the tool choice parameter
+     }
+
+     # Process the request and collect the result
+     url, headers, payload, engine = await prepare_request_payload(provider, request_data)
+
+     # Print the results
+     print("\nURL:")
+     print(url)
+
+     print("\nHeaders:")
+     print(json.dumps(headers, indent=4, ensure_ascii=False))
+
+     print("\nPayload:")
+     print(json.dumps(payload, indent=4, ensure_ascii=False))
+
+     print("\n===== Test finished =====")
+
+ if __name__ == "__main__":
+     # Run the async test function
+     asyncio.run(test_payload())