aient 1.1.71__py3-none-any.whl → 1.1.73__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aient/models/chatgpt.py +36 -205
- aient/plugins/config.py +1 -1
- {aient-1.1.71.dist-info → aient-1.1.73.dist-info}/METADATA +1 -1
- {aient-1.1.71.dist-info → aient-1.1.73.dist-info}/RECORD +7 -7
- {aient-1.1.71.dist-info → aient-1.1.73.dist-info}/WHEEL +0 -0
- {aient-1.1.71.dist-info → aient-1.1.73.dist-info}/licenses/LICENSE +0 -0
- {aient-1.1.71.dist-info → aient-1.1.73.dist-info}/top_level.txt +0 -0
aient/models/chatgpt.py
CHANGED
@@ -6,10 +6,7 @@ import httpx
|
|
6
6
|
import asyncio
|
7
7
|
import logging
|
8
8
|
import inspect
|
9
|
-
from typing import Set
|
10
9
|
from typing import Union, Optional, Callable
|
11
|
-
from pathlib import Path
|
12
|
-
|
13
10
|
|
14
11
|
from .base import BaseLLM
|
15
12
|
from ..plugins.registry import registry
|
@@ -18,27 +15,6 @@ from ..utils.scripts import safe_get, async_generator_to_sync, parse_function_xm
|
|
18
15
|
from ..core.request import prepare_request_payload
|
19
16
|
from ..core.response import fetch_response_stream, fetch_response
|
20
17
|
|
21
|
-
def get_filtered_keys_from_object(obj: object, *keys: str) -> Set[str]:
|
22
|
-
"""
|
23
|
-
Get filtered list of object variable names.
|
24
|
-
:param keys: List of keys to include. If the first key is "not", the remaining keys will be removed from the class keys.
|
25
|
-
:return: List of class keys.
|
26
|
-
"""
|
27
|
-
class_keys = obj.__dict__.keys()
|
28
|
-
if not keys:
|
29
|
-
return set(class_keys)
|
30
|
-
|
31
|
-
# Remove the passed keys from the class keys.
|
32
|
-
if keys[0] == "not":
|
33
|
-
return {key for key in class_keys if key not in keys[1:]}
|
34
|
-
# Check if all passed keys are valid
|
35
|
-
if invalid_keys := set(keys) - class_keys:
|
36
|
-
raise ValueError(
|
37
|
-
f"Invalid keys: {invalid_keys}",
|
38
|
-
)
|
39
|
-
# Only return specified keys that are in class_keys
|
40
|
-
return {key for key in keys if key in class_keys}
|
41
|
-
|
42
18
|
class chatgpt(BaseLLM):
|
43
19
|
"""
|
44
20
|
Official ChatGPT API
|
@@ -407,7 +383,7 @@ class chatgpt(BaseLLM):
|
|
407
383
|
|
408
384
|
resp = json.loads(line) if isinstance(line, str) else line
|
409
385
|
if "error" in resp:
|
410
|
-
raise Exception(
|
386
|
+
raise Exception(json.dumps({"type": "api_error", "details": resp}, ensure_ascii=False))
|
411
387
|
|
412
388
|
total_tokens = total_tokens or safe_get(resp, "usage", "total_tokens", default=0)
|
413
389
|
delta = safe_get(resp, "choices", 0, "delta")
|
@@ -459,6 +435,9 @@ class chatgpt(BaseLLM):
|
|
459
435
|
for chunk in process_sync():
|
460
436
|
yield chunk
|
461
437
|
|
438
|
+
if not full_response.strip():
|
439
|
+
raise Exception(json.dumps({"type": "response_empty_error", "message": "Response is empty"}, ensure_ascii=False))
|
440
|
+
|
462
441
|
if self.print_log:
|
463
442
|
self.logger.info(f"total_tokens: {total_tokens}")
|
464
443
|
|
@@ -471,9 +450,7 @@ class chatgpt(BaseLLM):
|
|
471
450
|
if self.check_done:
|
472
451
|
# self.logger.info(f"worker Response: {full_response}")
|
473
452
|
if not full_response.strip().endswith('[done]'):
|
474
|
-
raise Exception(
|
475
|
-
elif not full_response.strip():
|
476
|
-
raise Exception(f"Response is empty")
|
453
|
+
raise Exception(json.dumps({"type": "validation_error", "message": "Response is not ended with [done]", "response": full_response}, ensure_ascii=False))
|
477
454
|
else:
|
478
455
|
full_response = full_response.strip().rstrip('[done]')
|
479
456
|
full_response = full_response.replace("<tool_code>", "").replace("</tool_code>", "")
|
@@ -537,8 +514,10 @@ class chatgpt(BaseLLM):
|
|
537
514
|
# 处理函数调用
|
538
515
|
if need_function_call and self.use_plugins == True:
|
539
516
|
if self.print_log:
|
540
|
-
|
541
|
-
|
517
|
+
if function_parameter:
|
518
|
+
self.logger.info(f"function_parameter: {function_parameter}")
|
519
|
+
else:
|
520
|
+
self.logger.info(f"function_full_response: {function_full_response}")
|
542
521
|
|
543
522
|
function_response = ""
|
544
523
|
# 定义处理单个工具调用的辅助函数
|
@@ -553,17 +532,13 @@ class chatgpt(BaseLLM):
|
|
553
532
|
tool_response = ""
|
554
533
|
has_args = safe_get(self.function_call_list, tool_name, "parameters", "required", default=False)
|
555
534
|
if self.function_calls_counter[tool_name] <= self.function_call_max_loop and (tool_args != "{}" or not has_args):
|
556
|
-
function_call_max_tokens = self.truncate_limit - 1000
|
557
|
-
if function_call_max_tokens <= 0:
|
558
|
-
function_call_max_tokens = int(self.truncate_limit / 2)
|
559
535
|
if self.print_log:
|
560
|
-
self.logger.info(f"
|
536
|
+
self.logger.info(f"Tool use, calling: {tool_name}")
|
561
537
|
|
562
538
|
# 处理函数调用结果
|
563
539
|
if is_async:
|
564
540
|
async for chunk in get_tools_result_async(
|
565
|
-
tool_name, tool_args,
|
566
|
-
model or self.engine, chatgpt, kwargs.get('api_key', self.api_key),
|
541
|
+
tool_name, tool_args, model or self.engine, chatgpt, kwargs.get('api_key', self.api_key),
|
567
542
|
kwargs.get('api_url', self.api_url.chat_url), use_plugins=False, model=model or self.engine,
|
568
543
|
add_message=self.add_to_conversation, convo_id=convo_id, language=language
|
569
544
|
):
|
@@ -571,8 +546,7 @@ class chatgpt(BaseLLM):
|
|
571
546
|
else:
|
572
547
|
async def run_async():
|
573
548
|
async for chunk in get_tools_result_async(
|
574
|
-
tool_name, tool_args,
|
575
|
-
model or self.engine, chatgpt, kwargs.get('api_key', self.api_key),
|
549
|
+
tool_name, tool_args, model or self.engine, chatgpt, kwargs.get('api_key', self.api_key),
|
576
550
|
kwargs.get('api_url', self.api_url.chat_url), use_plugins=False, model=model or self.engine,
|
577
551
|
add_message=self.add_to_conversation, convo_id=convo_id, language=language
|
578
552
|
):
|
@@ -638,8 +612,8 @@ class chatgpt(BaseLLM):
|
|
638
612
|
else:
|
639
613
|
all_responses.append(f"[{tool_name}({tool_args}) Result]:\n\n{tool_response}")
|
640
614
|
|
641
|
-
|
642
|
-
|
615
|
+
if self.check_done:
|
616
|
+
all_responses.append("Your message **must** end with [done] to signify the end of your output.")
|
643
617
|
|
644
618
|
# 合并所有工具响应
|
645
619
|
function_response = "\n\n".join(all_responses).strip()
|
@@ -721,13 +695,17 @@ class chatgpt(BaseLLM):
|
|
721
695
|
|
722
696
|
# 打印日志
|
723
697
|
if self.print_log:
|
724
|
-
self.logger.
|
725
|
-
self.logger.
|
698
|
+
self.logger.debug(f"api_url: {kwargs.get('api_url', self.api_url.chat_url)}")
|
699
|
+
self.logger.debug(f"api_key: {kwargs.get('api_key', self.api_key)}")
|
700
|
+
need_done_prompt = False
|
726
701
|
|
727
702
|
# 发送请求并处理响应
|
728
703
|
for i in range(30):
|
704
|
+
tmp_post_json = copy.deepcopy(json_post)
|
705
|
+
if need_done_prompt:
|
706
|
+
tmp_post_json["messages"].extend(need_done_prompt)
|
729
707
|
if self.print_log:
|
730
|
-
replaced_text = json.loads(re.sub(r';base64,([A-Za-z0-9+/=]+)', ';base64,***', json.dumps(
|
708
|
+
replaced_text = json.loads(re.sub(r';base64,([A-Za-z0-9+/=]+)', ';base64,***', json.dumps(tmp_post_json)))
|
731
709
|
replaced_text_str = json.dumps(replaced_text, indent=4, ensure_ascii=False)
|
732
710
|
self.logger.info(f"Request Body:\n{replaced_text_str}")
|
733
711
|
|
@@ -753,11 +731,11 @@ class chatgpt(BaseLLM):
|
|
753
731
|
else:
|
754
732
|
if stream:
|
755
733
|
generator = fetch_response_stream(
|
756
|
-
self.aclient, url, headers,
|
734
|
+
self.aclient, url, headers, tmp_post_json, engine_type, model or self.engine,
|
757
735
|
)
|
758
736
|
else:
|
759
737
|
generator = fetch_response(
|
760
|
-
self.aclient, url, headers,
|
738
|
+
self.aclient, url, headers, tmp_post_json, engine_type, model or self.engine,
|
761
739
|
)
|
762
740
|
|
763
741
|
# 处理正常响应
|
@@ -777,18 +755,24 @@ class chatgpt(BaseLLM):
|
|
777
755
|
except httpx.RemoteProtocolError:
|
778
756
|
continue
|
779
757
|
except Exception as e:
|
780
|
-
|
781
|
-
|
758
|
+
self.logger.error(f"{e}")
|
759
|
+
if "validation_error" in str(e):
|
760
|
+
bad_assistant_message = json.loads(str(e))["response"]
|
761
|
+
need_done_prompt = [
|
762
|
+
{"role": "assistant", "content": bad_assistant_message},
|
763
|
+
{"role": "user", "content": "你的消息没有以[done]结尾,请重新输出"}
|
764
|
+
]
|
765
|
+
continue
|
766
|
+
if "response_empty_error" in str(e):
|
782
767
|
continue
|
783
|
-
self.logger.error(f"发生了未预料的错误:{e}")
|
784
768
|
import traceback
|
785
769
|
self.logger.error(traceback.format_exc())
|
786
770
|
if "Invalid URL" in str(e):
|
787
|
-
|
788
|
-
raise Exception(
|
771
|
+
error_message = "您输入了无效的API URL,请使用正确的URL并使用`/start`命令重新设置API URL。具体错误如下:\n\n" + str(e)
|
772
|
+
raise Exception(json.dumps({"type": "configuration_error", "message": error_message}, ensure_ascii=False))
|
789
773
|
# 最后一次重试失败,向上抛出异常
|
790
774
|
if i == 11:
|
791
|
-
raise Exception(
|
775
|
+
raise Exception(json.dumps({"type": "retry_failed", "message": str(e)}, ensure_ascii=False))
|
792
776
|
|
793
777
|
def ask_stream(
|
794
778
|
self,
|
@@ -915,157 +899,4 @@ class chatgpt(BaseLLM):
|
|
915
899
|
{"role": "system", "content": self.system_prompt},
|
916
900
|
]
|
917
901
|
self.tokens_usage[convo_id] = 0
|
918
|
-
self.current_tokens[convo_id] = 0
|
919
|
-
|
920
|
-
def save(self, file: str, *keys: str) -> None:
|
921
|
-
"""
|
922
|
-
Save the Chatbot configuration to a JSON file
|
923
|
-
"""
|
924
|
-
with open(file, "w", encoding="utf-8") as f:
|
925
|
-
data = {
|
926
|
-
key: self.__dict__[key]
|
927
|
-
for key in get_filtered_keys_from_object(self, *keys)
|
928
|
-
}
|
929
|
-
# saves session.proxies dict as session
|
930
|
-
# leave this here for compatibility
|
931
|
-
data["session"] = data["proxy"]
|
932
|
-
del data["aclient"]
|
933
|
-
json.dump(
|
934
|
-
data,
|
935
|
-
f,
|
936
|
-
indent=2,
|
937
|
-
)
|
938
|
-
|
939
|
-
def load(self, file: Path, *keys_: str) -> None:
|
940
|
-
"""
|
941
|
-
Load the Chatbot configuration from a JSON file
|
942
|
-
"""
|
943
|
-
with open(file, encoding="utf-8") as f:
|
944
|
-
# load json, if session is in keys, load proxies
|
945
|
-
loaded_config = json.load(f)
|
946
|
-
keys = get_filtered_keys_from_object(self, *keys_)
|
947
|
-
|
948
|
-
if (
|
949
|
-
"session" in keys
|
950
|
-
and loaded_config["session"]
|
951
|
-
or "proxy" in keys
|
952
|
-
and loaded_config["proxy"]
|
953
|
-
):
|
954
|
-
self.proxy = loaded_config.get("session", loaded_config["proxy"])
|
955
|
-
self.session = httpx.Client(
|
956
|
-
follow_redirects=True,
|
957
|
-
proxies=self.proxy,
|
958
|
-
timeout=self.timeout,
|
959
|
-
cookies=self.session.cookies,
|
960
|
-
headers=self.session.headers,
|
961
|
-
)
|
962
|
-
self.aclient = httpx.AsyncClient(
|
963
|
-
follow_redirects=True,
|
964
|
-
proxies=self.proxy,
|
965
|
-
timeout=self.timeout,
|
966
|
-
cookies=self.session.cookies,
|
967
|
-
headers=self.session.headers,
|
968
|
-
)
|
969
|
-
if "session" in keys:
|
970
|
-
keys.remove("session")
|
971
|
-
if "aclient" in keys:
|
972
|
-
keys.remove("aclient")
|
973
|
-
self.__dict__.update({key: loaded_config[key] for key in keys})
|
974
|
-
|
975
|
-
def _handle_response_error_common(self, response_text, json_post):
|
976
|
-
"""通用的响应错误处理逻辑,适用于同步和异步场景"""
|
977
|
-
try:
|
978
|
-
# 检查内容审核失败
|
979
|
-
if "Content did not pass the moral check" in response_text:
|
980
|
-
return json_post, False, f"内容未通过道德检查:{response_text[:400]}"
|
981
|
-
|
982
|
-
# 处理函数调用相关错误
|
983
|
-
if "function calling" in response_text:
|
984
|
-
if "tools" in json_post:
|
985
|
-
del json_post["tools"]
|
986
|
-
if "tool_choice" in json_post:
|
987
|
-
del json_post["tool_choice"]
|
988
|
-
return json_post, True, None
|
989
|
-
|
990
|
-
# 处理请求格式错误
|
991
|
-
elif "invalid_request_error" in response_text:
|
992
|
-
for index, mess in enumerate(json_post["messages"]):
|
993
|
-
if type(mess["content"]) == list and "text" in mess["content"][0]:
|
994
|
-
json_post["messages"][index] = {
|
995
|
-
"role": mess["role"],
|
996
|
-
"content": mess["content"][0]["text"]
|
997
|
-
}
|
998
|
-
return json_post, True, None
|
999
|
-
|
1000
|
-
# 处理角色不允许错误
|
1001
|
-
elif "'function' is not an allowed role" in response_text:
|
1002
|
-
if json_post["messages"][-1]["role"] == "tool":
|
1003
|
-
mess = json_post["messages"][-1]
|
1004
|
-
json_post["messages"][-1] = {
|
1005
|
-
"role": "assistant",
|
1006
|
-
"name": mess["name"],
|
1007
|
-
"content": mess["content"]
|
1008
|
-
}
|
1009
|
-
return json_post, True, None
|
1010
|
-
|
1011
|
-
# 处理服务器繁忙错误
|
1012
|
-
elif "Sorry, server is busy" in response_text:
|
1013
|
-
for index, mess in enumerate(json_post["messages"]):
|
1014
|
-
if type(mess["content"]) == list and "text" in mess["content"][0]:
|
1015
|
-
json_post["messages"][index] = {
|
1016
|
-
"role": mess["role"],
|
1017
|
-
"content": mess["content"][0]["text"]
|
1018
|
-
}
|
1019
|
-
return json_post, True, None
|
1020
|
-
|
1021
|
-
# 处理token超限错误
|
1022
|
-
elif "is not possible because the prompts occupy" in response_text:
|
1023
|
-
max_tokens = re.findall(r"only\s(\d+)\stokens", response_text)
|
1024
|
-
if max_tokens:
|
1025
|
-
json_post["max_tokens"] = int(max_tokens[0])
|
1026
|
-
return json_post, True, None
|
1027
|
-
|
1028
|
-
# 默认移除工具相关设置
|
1029
|
-
else:
|
1030
|
-
if "tools" in json_post:
|
1031
|
-
del json_post["tools"]
|
1032
|
-
if "tool_choice" in json_post:
|
1033
|
-
del json_post["tool_choice"]
|
1034
|
-
return json_post, True, None
|
1035
|
-
|
1036
|
-
except Exception as e:
|
1037
|
-
self.logger.error(f"处理响应错误时出现异常: {e}")
|
1038
|
-
return json_post, False, str(e)
|
1039
|
-
|
1040
|
-
def _handle_response_error_sync(self, response, json_post):
|
1041
|
-
"""处理API响应错误并相应地修改请求体(同步版本)"""
|
1042
|
-
response_text = response.text
|
1043
|
-
|
1044
|
-
# 处理空响应
|
1045
|
-
if response.status_code == 200 and response_text == "":
|
1046
|
-
for index, mess in enumerate(json_post["messages"]):
|
1047
|
-
if type(mess["content"]) == list and "text" in mess["content"][0]:
|
1048
|
-
json_post["messages"][index] = {
|
1049
|
-
"role": mess["role"],
|
1050
|
-
"content": mess["content"][0]["text"]
|
1051
|
-
}
|
1052
|
-
return json_post, True
|
1053
|
-
|
1054
|
-
json_post, should_retry, error_msg = self._handle_response_error_common(response_text, json_post)
|
1055
|
-
|
1056
|
-
if error_msg:
|
1057
|
-
raise Exception(f"{response.status_code} {response.reason} {error_msg}")
|
1058
|
-
|
1059
|
-
return json_post, should_retry
|
1060
|
-
|
1061
|
-
async def _handle_response_error(self, response, json_post):
|
1062
|
-
"""处理API响应错误并相应地修改请求体(异步版本)"""
|
1063
|
-
await response.aread()
|
1064
|
-
response_text = response.text
|
1065
|
-
|
1066
|
-
json_post, should_retry, error_msg = self._handle_response_error_common(response_text, json_post)
|
1067
|
-
|
1068
|
-
if error_msg:
|
1069
|
-
raise Exception(f"{response.status_code} {response.reason_phrase} {error_msg}")
|
1070
|
-
|
1071
|
-
return json_post, should_retry
|
902
|
+
self.current_tokens[convo_id] = 0
|
aient/plugins/config.py
CHANGED
@@ -5,7 +5,7 @@ import inspect
|
|
5
5
|
from .registry import registry
|
6
6
|
from ..utils.prompt import search_key_word_prompt
|
7
7
|
|
8
|
-
async def get_tools_result_async(function_call_name, function_full_response,
|
8
|
+
async def get_tools_result_async(function_call_name, function_full_response, engine, robot, api_key, api_url, use_plugins, model, add_message, convo_id, language):
|
9
9
|
function_response = ""
|
10
10
|
function_to_call = None
|
11
11
|
call_args = json.loads(function_full_response)
|
@@ -12,10 +12,10 @@ aient/core/test/test_payload.py,sha256=8jBiJY1uidm1jzL-EiK0s6UGmW9XkdsuuKFGrwFhF
|
|
12
12
|
aient/models/__init__.py,sha256=ZTiZgbfBPTjIPSKURE7t6hlFBVLRS9lluGbmqc1WjxQ,43
|
13
13
|
aient/models/audio.py,sha256=kRd-8-WXzv4vwvsTGwnstK-WR8--vr9CdfCZzu8y9LA,1934
|
14
14
|
aient/models/base.py,sha256=-nnihYnx-vHZMqeVO9ljjt3k4FcD3n-iMk4tT-10nRQ,7232
|
15
|
-
aient/models/chatgpt.py,sha256=
|
15
|
+
aient/models/chatgpt.py,sha256=S_6s_3epO9UPWea9zyf4-cy8ekMZXNct3AuvSJtO9Pg,43056
|
16
16
|
aient/plugins/__init__.py,sha256=p3KO6Aa3Lupos4i2SjzLQw1hzQTigOAfEHngsldrsyk,986
|
17
17
|
aient/plugins/arXiv.py,sha256=yHjb6PS3GUWazpOYRMKMzghKJlxnZ5TX8z9F6UtUVow,1461
|
18
|
-
aient/plugins/config.py,sha256=
|
18
|
+
aient/plugins/config.py,sha256=TGgZ5SnNKZ8MmdznrZ-TEq7s2ulhAAwTSKH89bci3dA,7079
|
19
19
|
aient/plugins/excute_command.py,sha256=urbOFUI-Wd-XaNyH3EfNBn7vnimzF84b2uq4jMXPxYM,10642
|
20
20
|
aient/plugins/get_time.py,sha256=Ih5XIW5SDAIhrZ9W4Qe5Hs1k4ieKPUc_LAd6ySNyqZk,654
|
21
21
|
aient/plugins/image.py,sha256=ZElCIaZznE06TN9xW3DrSukS7U3A5_cjk1Jge4NzPxw,2072
|
@@ -30,8 +30,8 @@ aient/plugins/write_file.py,sha256=Jt8fOEwqhYiSWpCbwfAr1xoi_BmFnx3076GMhuL06uI,3
|
|
30
30
|
aient/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
31
31
|
aient/utils/prompt.py,sha256=UcSzKkFE4-h_1b6NofI6xgk3GoleqALRKY8VBaXLjmI,11311
|
32
32
|
aient/utils/scripts.py,sha256=VqtK4RFEx7KxkmcqG3lFDS1DxoNlFFGErEjopVcc8IE,40974
|
33
|
-
aient-1.1.
|
34
|
-
aient-1.1.
|
35
|
-
aient-1.1.
|
36
|
-
aient-1.1.
|
37
|
-
aient-1.1.
|
33
|
+
aient-1.1.73.dist-info/licenses/LICENSE,sha256=XNdbcWldt0yaNXXWB_Bakoqnxb3OVhUft4MgMA_71ds,1051
|
34
|
+
aient-1.1.73.dist-info/METADATA,sha256=tBDP7wGDeeRRD-UEww0-UxF5IgbmHhQV4O7qXuDVXOA,4842
|
35
|
+
aient-1.1.73.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
36
|
+
aient-1.1.73.dist-info/top_level.txt,sha256=3oXzrP5sAVvyyqabpeq8A2_vfMtY554r4bVE-OHBrZk,6
|
37
|
+
aient-1.1.73.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|