auto-coder 0.1.205__py3-none-any.whl → 0.1.207__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of auto-coder might be problematic. Click here for more details.
- {auto_coder-0.1.205.dist-info → auto_coder-0.1.207.dist-info}/METADATA +1 -1
- {auto_coder-0.1.205.dist-info → auto_coder-0.1.207.dist-info}/RECORD +16 -10
- autocoder/agent/auto_filegroup.py +202 -0
- autocoder/auto_coder_rag.py +168 -33
- autocoder/benchmark.py +138 -0
- autocoder/chat_auto_coder.py +9 -3
- autocoder/common/chunk_validation.py +91 -0
- autocoder/common/recall_validation.py +58 -0
- autocoder/data/tokenizer.json +199865 -0
- autocoder/rag/token_counter.py +3 -3
- autocoder/utils/operate_config_api.py +148 -0
- autocoder/version.py +1 -1
- {auto_coder-0.1.205.dist-info → auto_coder-0.1.207.dist-info}/LICENSE +0 -0
- {auto_coder-0.1.205.dist-info → auto_coder-0.1.207.dist-info}/WHEEL +0 -0
- {auto_coder-0.1.205.dist-info → auto_coder-0.1.207.dist-info}/entry_points.txt +0 -0
- {auto_coder-0.1.205.dist-info → auto_coder-0.1.207.dist-info}/top_level.txt +0 -0
|
@@ -1,14 +1,16 @@
|
|
|
1
1
|
autocoder/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
2
|
autocoder/auto_coder.py,sha256=IP4aSmuZh3HaIzLIThTKHF0PB38eTPUU1GDdfuJBam4,38971
|
|
3
3
|
autocoder/auto_coder_lang.py,sha256=Rtupq6N3_HT7JRhDKdgCBcwRaiAnyCOR_Gsp4jUomrI,3229
|
|
4
|
-
autocoder/auto_coder_rag.py,sha256=
|
|
4
|
+
autocoder/auto_coder_rag.py,sha256=UNP69PTJA_Iz0MIF_RS26o9slSUvm3f3D7yQBxGYKAY,21578
|
|
5
5
|
autocoder/auto_coder_server.py,sha256=XU9b4SBH7zjPPXaTWWHV4_zJm-XYa6njuLQaplYJH_c,20290
|
|
6
|
-
autocoder/
|
|
6
|
+
autocoder/benchmark.py,sha256=Ypomkdzd1T3GE6dRICY3Hj547dZ6_inqJbBJIp5QMco,4423
|
|
7
|
+
autocoder/chat_auto_coder.py,sha256=PtES0XppQ4OTuMucQZFXN5j-6rXS4Rc7RL5gmBy3GZU,87182
|
|
7
8
|
autocoder/chat_auto_coder_lang.py,sha256=zU9VRY-l80fZnLJ0Op8A3wq27UhQHh9WcpSYU4SmnqU,8708
|
|
8
9
|
autocoder/command_args.py,sha256=nmVD3xgLKKNt3fLJ4wVNUdJBtBOkfAdZ2ZmaRMlS7qg,29891
|
|
9
10
|
autocoder/lang.py,sha256=Ajng6m7towmx-cvQfEHPFp43iEfddPvr8ju5GH4H8qA,13819
|
|
10
|
-
autocoder/version.py,sha256=
|
|
11
|
+
autocoder/version.py,sha256=Zou96E1H6FClEDuqHKcUc9kgUuTRvVu0lQVcFzPxZRY,24
|
|
11
12
|
autocoder/agent/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
13
|
+
autocoder/agent/auto_filegroup.py,sha256=CW7bqp0FW1GIEMnl-blyAc2UGT7O9Mom0q66ITz1ckM,6635
|
|
12
14
|
autocoder/agent/auto_tool.py,sha256=DBzip-P_T6ZtT2eHexPcusmKYD0h7ufzp7TLwXAY10E,11554
|
|
13
15
|
autocoder/agent/coder.py,sha256=x6bdJwDuETGg9ebQnYlUWCxCtQcDGg73LtI6McpWslQ,72034
|
|
14
16
|
autocoder/agent/designer.py,sha256=EpRbzO58Xym3GrnppIT1Z8ZFAlnNfgzHbIzZ3PX-Yv8,27037
|
|
@@ -20,6 +22,7 @@ autocoder/common/ShellClient.py,sha256=fM1q8t_XMSbLBl2zkCNC2J9xuyKN3eXzGm6hHhqL2
|
|
|
20
22
|
autocoder/common/__init__.py,sha256=szE6bhkWT_KDTxY73hDQDRYCaH7DDTrBuEQBxzqAEOM,11201
|
|
21
23
|
autocoder/common/anything2images.py,sha256=0ILBbWzY02M-CiWB-vzuomb_J1hVdxRcenAfIrAXq9M,25283
|
|
22
24
|
autocoder/common/audio.py,sha256=Kn9nWKQddWnUrAz0a_ZUgjcu4VUU_IcZBigT7n3N3qc,7439
|
|
25
|
+
autocoder/common/chunk_validation.py,sha256=BrR_ZWavW8IANuueEE7hS8NFAwEvm8TX34WnPx_1hs8,3030
|
|
23
26
|
autocoder/common/cleaner.py,sha256=NU72i8C6o9m0vXExab7nao5bstBUsfJFcj11cXa9l4U,1089
|
|
24
27
|
autocoder/common/code_auto_execute.py,sha256=4KXGmiGObr_B1d6tzV9dwS6MifCSc3Gm4j2d6ildBXQ,6867
|
|
25
28
|
autocoder/common/code_auto_generate.py,sha256=bREGPj2yNQ1oy35ivHRHQSXAdmIsPaW4HQ9BkO7Ytco,8099
|
|
@@ -38,12 +41,14 @@ autocoder/common/git_utils.py,sha256=btK45sxvfm4tX3fBRNUPRZoGQuZuOEQrWSAwLy1yoLw
|
|
|
38
41
|
autocoder/common/image_to_page.py,sha256=O0cNO_vHHUP-fP4GXiVojShmNqkPnZXeIyiY1MRLpKg,13936
|
|
39
42
|
autocoder/common/interpreter.py,sha256=62-dIakOunYB4yjmX8SHC0Gdy2h8NtxdgbpdqRZJ5vk,2833
|
|
40
43
|
autocoder/common/llm_rerank.py,sha256=FbvtCzaR661Mt2wn0qsuiEL1Y3puD6jeIJS4zg_e7Bs,3260
|
|
44
|
+
autocoder/common/recall_validation.py,sha256=Avt9Q9dX3kG6Pf2zsdlOHmsjd-OeSj7U1PFBDp_Cve0,1700
|
|
41
45
|
autocoder/common/screenshots.py,sha256=_gA-z1HxGjPShBrtgkdideq58MG6rqFB2qMUJKjrycs,3769
|
|
42
46
|
autocoder/common/search.py,sha256=_ZX03ph89rDPGMY1OrfqaDfxsDR-flh6YEHixherjwM,16616
|
|
43
47
|
autocoder/common/search_replace.py,sha256=GphFkc57Hb673CAwmbiocqTbw8vrV7TrZxtOhD0332g,22147
|
|
44
48
|
autocoder/common/sys_prompt.py,sha256=JlexfjZt554faqbgkCmzOJqYUzDHfbnxly5ugFfHfEE,26403
|
|
45
49
|
autocoder/common/text.py,sha256=KGRQq314GHBmY4MWG8ossRoQi1_DTotvhxchpn78c-k,1003
|
|
46
50
|
autocoder/common/types.py,sha256=_kl8q2oCAZe-HqU9LjQgSqVHAw4ojMojlF58-bp2xiE,354
|
|
51
|
+
autocoder/data/tokenizer.json,sha256=QfO_ZCE9qMAS2L0IcaWKH99wRj6PCPEQ3bsQgvUp9mk,4607451
|
|
47
52
|
autocoder/db/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
48
53
|
autocoder/db/store.py,sha256=tFT66bP2ZKIqZip-uhLkHRSLaaOAUUDZfozJwcqix3c,1908
|
|
49
54
|
autocoder/dispacher/__init__.py,sha256=YoA64dIxnx4jcE1pwSfg81sjkQtjDkhddkfac1-cMWo,1230
|
|
@@ -71,7 +76,7 @@ autocoder/rag/relevant_utils.py,sha256=OGfp98OXG4jr3jNmtHIeXGPF8mOlIbTnolPIVTZzY
|
|
|
71
76
|
autocoder/rag/simple_directory_reader.py,sha256=LkKreCkNdEOoL4fNhc3_hDoyyWTQUte4uqextISRz4U,24485
|
|
72
77
|
autocoder/rag/simple_rag.py,sha256=I902EUqOK1WM0Y2WFd7RzDJYofElvTZNLVCBtX5A9rc,14885
|
|
73
78
|
autocoder/rag/token_checker.py,sha256=jc76x6KWmvVxds6W8juZfQGaoErudc2HenG3sNQfSLs,2819
|
|
74
|
-
autocoder/rag/token_counter.py,sha256=
|
|
79
|
+
autocoder/rag/token_counter.py,sha256=C-Lwc4oIjJpZDEqp9WLHGOe6hb4yhrdJpMtkrtp_1qc,2125
|
|
75
80
|
autocoder/rag/token_limiter.py,sha256=dtIxCtHswZ2ut-XKbx8_SiWyv-xqnR1WAIcmh6f8Ktw,11137
|
|
76
81
|
autocoder/rag/types.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
77
82
|
autocoder/rag/utils.py,sha256=MOMy0l2_YQ8foWnQQSvYmYAwghAWJ8_wVrdZh5DSaTg,4822
|
|
@@ -100,15 +105,16 @@ autocoder/utils/conversation_store.py,sha256=sz-hhY7sttPAUOAQU6Pze-5zJc3j0_Emj22
|
|
|
100
105
|
autocoder/utils/llm_client_interceptors.py,sha256=FEHNXoFZlCjAHQcjPRyX8FOMjo6rPXpO2AJ2zn2KTTo,901
|
|
101
106
|
autocoder/utils/log_capture.py,sha256=I-bsJFLWoGUiX-GKoZsH9kWJCKSV7ZlUnRt7jh-fOL0,1548
|
|
102
107
|
autocoder/utils/multi_turn.py,sha256=unK9OpqVRbK6uIcTKXgggX2wNmyj7s5eyEAQ2xUwHoM,88
|
|
108
|
+
autocoder/utils/operate_config_api.py,sha256=99YAKsuUFLPwrRvj0CJal_bAPgyiXWMma6ZKMU56thw,5790
|
|
103
109
|
autocoder/utils/print_table.py,sha256=ZMRhCA9DD0FUfKyJBWd5bDdj1RrtPtgOMWSJwtvZcLs,403
|
|
104
110
|
autocoder/utils/queue_communicate.py,sha256=buyEzdvab1QA4i2QKbq35rG5v_9x9PWVLWWMTznWcYM,6832
|
|
105
111
|
autocoder/utils/request_event_queue.py,sha256=r3lo5qGsB1dIjzVQ05dnr0z_9Z3zOkBdP1vmRciKdi4,2095
|
|
106
112
|
autocoder/utils/request_queue.py,sha256=nwp6PMtgTCiuwJI24p8OLNZjUiprC-TsefQrhMI-yPE,3889
|
|
107
113
|
autocoder/utils/rest.py,sha256=HawagAap3wMIDROGhY1730zSZrJR_EycODAA5qOj83c,8807
|
|
108
114
|
autocoder/utils/tests.py,sha256=BqphrwyycGAvs-5mhH8pKtMZdObwhFtJ5MC_ZAOiLq8,1340
|
|
109
|
-
auto_coder-0.1.
|
|
110
|
-
auto_coder-0.1.
|
|
111
|
-
auto_coder-0.1.
|
|
112
|
-
auto_coder-0.1.
|
|
113
|
-
auto_coder-0.1.
|
|
114
|
-
auto_coder-0.1.
|
|
115
|
+
auto_coder-0.1.207.dist-info/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
|
|
116
|
+
auto_coder-0.1.207.dist-info/METADATA,sha256=CIJPaXnDWOPKCFvDnjZHWQEmftK9Jde99NTbFdxANng,2575
|
|
117
|
+
auto_coder-0.1.207.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
|
|
118
|
+
auto_coder-0.1.207.dist-info/entry_points.txt,sha256=0nzHtHH4pNcM7xq4EBA2toS28Qelrvcbrr59GqD_0Ak,350
|
|
119
|
+
auto_coder-0.1.207.dist-info/top_level.txt,sha256=Jqc0_uJSw2GwoFQAa9iJxYns-2mWla-9ok_Y3Gcznjk,10
|
|
120
|
+
auto_coder-0.1.207.dist-info/RECORD,,
|
|
@@ -0,0 +1,202 @@
|
|
|
1
|
+
from typing import List, Dict, Optional, Any, Tuple
|
|
2
|
+
import os
|
|
3
|
+
import yaml
|
|
4
|
+
from loguru import logger
|
|
5
|
+
import byzerllm
|
|
6
|
+
import pydantic
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class FileGroup(pydantic.BaseModel):
    """One group of related development tasks.

    The field names match the keys of each object in the ``groups`` array
    that the grouping prompt in this module asks the LLM to return.
    """

    # Short label for the group.
    name: str
    # Summary of what the grouped tasks have in common.
    description: str
    # The task queries assigned to this group.
    queries: List[str]
    # The file paths associated with this group.
    urls: List[str]
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class FileGroups(pydantic.BaseModel):
    """Top-level container for a grouping result: a list of ``FileGroup``."""

    # All groups produced by one grouping run.
    groups: List[FileGroup]
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def load_yaml_config(yaml_file: str) -> Dict:
    """Load a YAML configuration file and return its top-level mapping.

    Args:
        yaml_file: Path of the YAML file to read (decoded as UTF-8).

    Returns:
        The parsed mapping. An empty dict is returned when the file cannot be
        read or parsed, when it is empty, or when its top-level value is not
        a mapping, so callers can always use ``.get`` on the result.
    """
    try:
        with open(yaml_file, 'r', encoding='utf-8') as f:
            loaded = yaml.safe_load(f)
    except Exception as e:
        # Deliberately best-effort: a broken file is logged, never raised.
        logger.error(f"Error loading yaml file {yaml_file}: {str(e)}")
        return {}

    # safe_load yields None for an empty document and a list/scalar for
    # non-mapping YAML; neither satisfies the declared Dict return type.
    if not isinstance(loaded, dict):
        if loaded is not None:
            logger.warning(
                f"Ignoring yaml file {yaml_file}: top-level value is not a mapping")
        return {}
    return loaded
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class AutoFileGroup:
    """Cluster recorded action files into groups of related tasks via an LLM.

    Action files are the ``NNN_*.yml`` files under ``<project_dir>/actions``.
    Each contributes its ``query`` and ``urls`` and, unless ``skip_diff`` is
    set, the git diff of the commit whose message contains the file's
    response id.
    """

    def __init__(self, llm: byzerllm.ByzerLLM,
                 project_dir: str,
                 skip_diff: bool = False,
                 group_num_limit: int = 10,
                 file_size_limit: int = 100):
        """Initialize AutoFileGroup.

        Args:
            llm: LLM client used to run the grouping prompt.
            project_dir: Project root; it is opened as a git repository and
                its ``actions`` sub-directory is scanned for YAML files.
            skip_diff: When True, no commit diff is looked up for the tasks.
            group_num_limit: Upper bound on the number of groups, passed to
                the prompt.
            file_size_limit: Maximum number of action files to consider
                (the ones with the highest sequence numbers are kept).
        """
        self.project_dir = project_dir
        self.actions_dir = os.path.join(project_dir, "actions")
        self.llm = llm
        self.file_size_limit = file_size_limit
        self.skip_diff = skip_diff
        self.group_num_limit = group_num_limit

    # NOTE: the docstring below is the prompt template rendered by
    # byzerllm.prompt(); it is runtime text, so it is intentionally kept in
    # its original language. The scraped source lost its indentation, so the
    # whitespace inside the template is reconstructed - TODO confirm against
    # the packaged file.
    @byzerllm.prompt()
    def group_by_similarity(self, querie_with_urls: List[Tuple[str, List[str], str]]) -> str:
        """
        分析多个开发任务的关联性,将相互关联的任务进行分组。

        输入说明:
        querie_with_urls 包含多个开发任务信息,每个任务由以下部分组成:
        1. query: 任务需求描述
        2. urls: 需要修改的文件路径列表
        3. diff: Git diff信息,展示具体的代码修改

        示例数据:
        <queries>
        {% for query,urls,diff in querie_with_urls %}
        ## {{ query }}

        修改的文件:
        {% for url in urls %}
        - {{ url }}
        {% endfor %}
        {% if diff %}

        代码变更:
        ```diff
        {{ diff }}
        ```
        {% endif %}
        {% endfor %}
        </queries>

        分组规则:
        1. 每个分组至少包含2个query
        2. 根据以下维度判断任务的关联性:
           - 功能相似性:任务是否属于同一个功能模块
           - 文件关联:修改的文件是否有重叠或紧密关联
           - 代码依赖:代码修改是否存在依赖关系
           - 业务目的:任务的最终业务目标是否一致
        3. 输出的分组数量最多不超过 {{ group_num_limit }}

        期望输出:
        返回符合以下格式的JSON:
        {
            "groups": [
                {
                    "name": "分组名称",
                    "description": "分组的功能概述,描述该组任务的共同目标",
                    "queries": ["相关的query1", "相关的query2"],
                    "urls": ["相关的文件1", "相关的文件2"]
                }
            ]
        }

        特别说明:
        1. 分组名称应该简洁且具有描述性,能反映该组任务的主要特征
        2. 分组描述应突出任务间的共同点和关联性
        3. 返回的urls应该是该组任务涉及的所有相关文件的并集
        """
        return {
            "group_num_limit": self.group_num_limit
        }

    def _list_action_files(self) -> List[str]:
        """Return the newest action file names, highest sequence number first."""
        def get_seq(name):
            return int(name.split("_")[0])

        action_files = [
            f for f in os.listdir(self.actions_dir)
            if f[:3].isdigit() and "_" in f and f.endswith('.yml')
            # The whole prefix before "_" must be numeric, otherwise get_seq
            # would raise ValueError (e.g. "001abc_x.yml").
            and f.split("_")[0].isdecimal()
        ]

        # Sort ascending then reverse, so the newest files come first.
        action_files = sorted(action_files, key=get_seq)
        action_files.reverse()
        return action_files[:self.file_size_limit]

    @staticmethod
    def _find_commit_diff(repo, response_id: str) -> str:
        """Return the diff of the first commit mentioning ``response_id``, or ""."""
        import git

        try:
            for commit in repo.iter_commits():
                if response_id in commit.message:
                    if commit.parents:
                        parent = commit.parents[0]
                        return repo.git.diff(parent.hexsha, commit.hexsha)
                    # Root commit: there is no parent to diff against.
                    return repo.git.show(commit.hexsha)
        except git.exc.GitCommandError as e:
            logger.error(f"Git命令执行错误: {str(e)}")
        except Exception as e:
            logger.error(f"获取commit diff时出错: {str(e)}")
        return ""

    def group_files(self) -> List[Dict]:
        """Group the action files by query/urls, attaching git commit diffs.

        Returns:
            The ``groups`` of the LLM result, or an empty list when there is
            nothing to group or the LLM call fails.
            NOTE(review): the annotation says List[Dict], but the result is
            requested as a ``FileGroups`` model, so the elements are
            presumably ``FileGroup`` instances - confirm with callers.
        """
        import git
        import hashlib

        action_files = self._list_action_files()

        querie_with_urls_and_diffs = []
        repo = git.Repo(self.project_dir)

        # Collect every query, its urls and the matching commit diff.
        for yaml_file in action_files:
            yaml_path = os.path.join(self.actions_dir, yaml_file)
            config = load_yaml_config(yaml_path)

            # Skip unreadable/empty files and YAML that is not a mapping.
            if not config or not isinstance(config, dict):
                continue

            query = config.get('query', '')
            urls = config.get('urls', [])
            if not (query and urls):
                continue

            commit_diff = ""
            if not self.skip_diff:
                # The file's MD5 is part of the id used to find its commit.
                with open(yaml_path, 'rb') as f:
                    file_md5 = hashlib.md5(f.read()).hexdigest()
                response_id = f"auto_coder_{yaml_file}_{file_md5}"
                commit_diff = self._find_commit_diff(repo, response_id)

            querie_with_urls_and_diffs.append((query, urls, commit_diff))

        if not querie_with_urls_and_diffs:
            return []

        # Ask the LLM to do the grouping.
        try:
            result = self.group_by_similarity.with_llm(self.llm).with_return_type(FileGroups).run(
                querie_with_urls=querie_with_urls_and_diffs
            )
            return result.groups
        except Exception as e:
            import traceback
            traceback.print_exc()
            logger.error(f"Error during grouping: {str(e)}")
            return []
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
def create_file_groups(actions_dir: str, llm: Any = None) -> List[Dict]:
    """Convenience wrapper that groups the action files in ``actions_dir``.

    The original body called ``AutoFileGroup(actions_dir)``, which passed the
    directory as the ``llm`` argument and omitted the required
    ``project_dir``, so it always raised ``TypeError``. The LLM client is now
    an explicit argument.

    Args:
        actions_dir: Directory containing the YAML action files.
        llm: LLM client handed to ``AutoFileGroup``. Required; it only has a
            default so the original one-argument signature stays valid.

    Returns:
        Whatever ``AutoFileGroup.group_files`` returns for that directory.

    Raises:
        TypeError: If ``llm`` is not supplied.
    """
    if llm is None:
        raise TypeError(
            "create_file_groups() requires an 'llm' argument to run the grouping prompt")

    # assumes the git project root is the parent of actions_dir, mirroring
    # AutoFileGroup's "<project_dir>/actions" layout - TODO confirm
    project_dir = os.path.dirname(os.path.abspath(actions_dir))
    grouper = AutoFileGroup(llm, project_dir)
    # Use the directory exactly as the caller gave it.
    grouper.actions_dir = actions_dir
    return grouper.group_files()
|
autocoder/auto_coder_rag.py
CHANGED
|
@@ -18,8 +18,10 @@ from rich.console import Console
|
|
|
18
18
|
from rich.table import Table
|
|
19
19
|
import os
|
|
20
20
|
from loguru import logger
|
|
21
|
+
import asyncio
|
|
21
22
|
|
|
22
23
|
from autocoder.rag.document_retriever import process_file_local
|
|
24
|
+
import pkg_resources
|
|
23
25
|
from autocoder.rag.token_counter import TokenCounter
|
|
24
26
|
|
|
25
27
|
if platform.system() == "Windows":
|
|
@@ -139,6 +141,13 @@ def initialize_system():
|
|
|
139
141
|
|
|
140
142
|
def main(input_args: Optional[List[str]] = None):
|
|
141
143
|
|
|
144
|
+
try:
|
|
145
|
+
tokenizer_path = pkg_resources.resource_filename(
|
|
146
|
+
"autocoder", "data/tokenizer.json"
|
|
147
|
+
)
|
|
148
|
+
except FileNotFoundError:
|
|
149
|
+
tokenizer_path = None
|
|
150
|
+
|
|
142
151
|
system_lang, _ = locale.getdefaultlocale()
|
|
143
152
|
lang = "zh" if system_lang and system_lang.startswith("zh") else "en"
|
|
144
153
|
desc = lang_desc[lang]
|
|
@@ -146,18 +155,38 @@ def main(input_args: Optional[List[str]] = None):
|
|
|
146
155
|
subparsers = parser.add_subparsers(dest="command", help="Available commands")
|
|
147
156
|
|
|
148
157
|
# Build hybrid index command
|
|
149
|
-
build_index_parser = subparsers.add_parser(
|
|
150
|
-
|
|
158
|
+
build_index_parser = subparsers.add_parser(
|
|
159
|
+
"build_hybrid_index", help="Build hybrid index for RAG"
|
|
160
|
+
)
|
|
161
|
+
build_index_parser.add_argument(
|
|
162
|
+
"--quick", action="store_true", help="Skip system initialization"
|
|
163
|
+
)
|
|
151
164
|
build_index_parser.add_argument("--file", default="", help=desc["file"])
|
|
152
|
-
build_index_parser.add_argument(
|
|
153
|
-
|
|
165
|
+
build_index_parser.add_argument(
|
|
166
|
+
"--model", default="deepseek_chat", help=desc["model"]
|
|
167
|
+
)
|
|
168
|
+
build_index_parser.add_argument(
|
|
169
|
+
"--index_model", default="", help=desc["index_model"]
|
|
170
|
+
)
|
|
154
171
|
build_index_parser.add_argument("--emb_model", default="", help=desc["emb_model"])
|
|
155
|
-
build_index_parser.add_argument(
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
build_index_parser.add_argument(
|
|
159
|
-
|
|
160
|
-
|
|
172
|
+
build_index_parser.add_argument(
|
|
173
|
+
"--ray_address", default="auto", help=desc["ray_address"]
|
|
174
|
+
)
|
|
175
|
+
build_index_parser.add_argument(
|
|
176
|
+
"--required_exts", default="", help=desc["doc_build_parse_required_exts"]
|
|
177
|
+
)
|
|
178
|
+
build_index_parser.add_argument(
|
|
179
|
+
"--source_dir", default=".", help="Source directory path"
|
|
180
|
+
)
|
|
181
|
+
build_index_parser.add_argument(
|
|
182
|
+
"--tokenizer_path", default=tokenizer_path, help="Path to tokenizer file"
|
|
183
|
+
)
|
|
184
|
+
build_index_parser.add_argument(
|
|
185
|
+
"--doc_dir", default="", help="Document directory path"
|
|
186
|
+
)
|
|
187
|
+
build_index_parser.add_argument(
|
|
188
|
+
"--enable_hybrid_index", action="store_true", help="Enable hybrid index"
|
|
189
|
+
)
|
|
161
190
|
|
|
162
191
|
# Serve command
|
|
163
192
|
serve_parser = subparsers.add_parser("serve", help="Start the RAG server")
|
|
@@ -220,7 +249,7 @@ def main(input_args: Optional[List[str]] = None):
|
|
|
220
249
|
serve_parser.add_argument("--ssl_certfile", default="", help="")
|
|
221
250
|
serve_parser.add_argument("--response_role", default="assistant", help="")
|
|
222
251
|
serve_parser.add_argument("--doc_dir", default="", help="")
|
|
223
|
-
serve_parser.add_argument("--tokenizer_path", default=
|
|
252
|
+
serve_parser.add_argument("--tokenizer_path", default=tokenizer_path, help="")
|
|
224
253
|
serve_parser.add_argument(
|
|
225
254
|
"--collections", default="", help="Collection name for indexing"
|
|
226
255
|
)
|
|
@@ -282,7 +311,7 @@ def main(input_args: Optional[List[str]] = None):
|
|
|
282
311
|
|
|
283
312
|
serve_parser.add_argument(
|
|
284
313
|
"--without_contexts",
|
|
285
|
-
action="store_true",
|
|
314
|
+
action="store_true",
|
|
286
315
|
help="Whether to return responses without contexts. only works when pro plugin is installed",
|
|
287
316
|
)
|
|
288
317
|
|
|
@@ -304,14 +333,73 @@ def main(input_args: Optional[List[str]] = None):
|
|
|
304
333
|
help="The model used for question answering",
|
|
305
334
|
)
|
|
306
335
|
|
|
336
|
+
# Benchmark command
|
|
337
|
+
benchmark_parser = subparsers.add_parser(
|
|
338
|
+
"benchmark", help="Benchmark LLM client performance"
|
|
339
|
+
)
|
|
340
|
+
benchmark_parser.add_argument(
|
|
341
|
+
"--model", default="deepseek_chat", help="Model to benchmark"
|
|
342
|
+
)
|
|
343
|
+
benchmark_parser.add_argument(
|
|
344
|
+
"--parallel", type=int, default=10, help="Number of parallel requests"
|
|
345
|
+
)
|
|
346
|
+
benchmark_parser.add_argument(
|
|
347
|
+
"--rounds", type=int, default=1, help="Number of rounds to run"
|
|
348
|
+
)
|
|
349
|
+
benchmark_parser.add_argument(
|
|
350
|
+
"--type",
|
|
351
|
+
choices=["openai", "byzerllm"],
|
|
352
|
+
default="byzerllm",
|
|
353
|
+
help="Client type to benchmark",
|
|
354
|
+
)
|
|
355
|
+
benchmark_parser.add_argument(
|
|
356
|
+
"--api_key", default="", help="OpenAI API key for OpenAI client"
|
|
357
|
+
)
|
|
358
|
+
benchmark_parser.add_argument(
|
|
359
|
+
"--base_url", default="", help="Base URL for OpenAI client"
|
|
360
|
+
)
|
|
361
|
+
benchmark_parser.add_argument(
|
|
362
|
+
"--query", default="Hello, how are you?", help="Query to use for benchmarking"
|
|
363
|
+
)
|
|
364
|
+
|
|
307
365
|
# Tools command
|
|
308
366
|
tools_parser = subparsers.add_parser("tools", help="Various tools")
|
|
309
367
|
tools_subparsers = tools_parser.add_subparsers(dest="tool", help="Available tools")
|
|
310
368
|
|
|
311
369
|
# Count tool
|
|
312
370
|
count_parser = tools_subparsers.add_parser("count", help="Count tokens in a file")
|
|
371
|
+
|
|
372
|
+
# Recall validation tool
|
|
373
|
+
recall_parser = tools_subparsers.add_parser(
|
|
374
|
+
"recall", help="Validate recall model performance"
|
|
375
|
+
)
|
|
376
|
+
recall_parser.add_argument(
|
|
377
|
+
"--model", required=True, help="Model to use for recall validation"
|
|
378
|
+
)
|
|
379
|
+
recall_parser.add_argument(
|
|
380
|
+
"--content", default=None, help="Content to validate against"
|
|
381
|
+
)
|
|
382
|
+
recall_parser.add_argument(
|
|
383
|
+
"--query", default=None, help="Query to use for validation"
|
|
384
|
+
)
|
|
385
|
+
|
|
386
|
+
# Add chunk model validation tool
|
|
387
|
+
chunk_parser = tools_subparsers.add_parser(
|
|
388
|
+
"chunk", help="Validate chunk model performance"
|
|
389
|
+
)
|
|
390
|
+
chunk_parser.add_argument(
|
|
391
|
+
"--model", required=True, help="Model to use for chunk validation"
|
|
392
|
+
)
|
|
393
|
+
chunk_parser.add_argument(
|
|
394
|
+
"--content", default=None, help="Content to validate against"
|
|
395
|
+
)
|
|
396
|
+
chunk_parser.add_argument(
|
|
397
|
+
"--query", default=None, help="Query to use for validation"
|
|
398
|
+
)
|
|
313
399
|
count_parser.add_argument(
|
|
314
|
-
"--tokenizer_path",
|
|
400
|
+
"--tokenizer_path",
|
|
401
|
+
default=tokenizer_path,
|
|
402
|
+
help="Path to the tokenizer",
|
|
315
403
|
)
|
|
316
404
|
count_parser.add_argument(
|
|
317
405
|
"--file", required=True, help="Path to the file to count tokens"
|
|
@@ -319,7 +407,22 @@ def main(input_args: Optional[List[str]] = None):
|
|
|
319
407
|
|
|
320
408
|
args = parser.parse_args(input_args)
|
|
321
409
|
|
|
322
|
-
if args.command == "
|
|
410
|
+
if args.command == "benchmark":
|
|
411
|
+
from .benchmark import benchmark_openai, benchmark_byzerllm
|
|
412
|
+
|
|
413
|
+
if args.type == "openai":
|
|
414
|
+
if not args.api_key:
|
|
415
|
+
print("OpenAI API key is required for OpenAI client benchmark")
|
|
416
|
+
return
|
|
417
|
+
asyncio.run(
|
|
418
|
+
benchmark_openai(
|
|
419
|
+
args.model, args.parallel, args.api_key, args.base_url, args.rounds, args.query
|
|
420
|
+
)
|
|
421
|
+
)
|
|
422
|
+
else: # byzerllm
|
|
423
|
+
benchmark_byzerllm(args.model, args.parallel, args.rounds, args.query)
|
|
424
|
+
|
|
425
|
+
elif args.command == "serve":
|
|
323
426
|
if not args.quick:
|
|
324
427
|
initialize_system()
|
|
325
428
|
server_args = ServerArgs(
|
|
@@ -337,14 +440,17 @@ def main(input_args: Optional[List[str]] = None):
|
|
|
337
440
|
}
|
|
338
441
|
)
|
|
339
442
|
|
|
340
|
-
if auto_coder_args.enable_hybrid_index:
|
|
341
|
-
# 尝试连接storage
|
|
443
|
+
if auto_coder_args.enable_hybrid_index:
|
|
444
|
+
# 尝试连接storage
|
|
342
445
|
try:
|
|
343
446
|
from byzerllm.apps.byzer_storage.simple_api import ByzerStorage
|
|
447
|
+
|
|
344
448
|
storage = ByzerStorage("byzerai_store", "rag", "files")
|
|
345
449
|
storage.retrieval.cluster_info("byzerai_store")
|
|
346
450
|
except Exception as e:
|
|
347
|
-
logger.error(
|
|
451
|
+
logger.error(
|
|
452
|
+
"When enable_hybrid_index is true, ByzerStorage must be started"
|
|
453
|
+
)
|
|
348
454
|
logger.error("Please run 'byzerllm storage start' first")
|
|
349
455
|
return
|
|
350
456
|
else:
|
|
@@ -369,12 +475,14 @@ def main(input_args: Optional[List[str]] = None):
|
|
|
369
475
|
llm.setup_sub_client("qa_model", qa_model)
|
|
370
476
|
|
|
371
477
|
# 当启用hybrid_index时,检查必要的组件
|
|
372
|
-
if auto_coder_args.enable_hybrid_index:
|
|
478
|
+
if auto_coder_args.enable_hybrid_index:
|
|
373
479
|
if not llm.is_model_exist("emb"):
|
|
374
|
-
logger.error(
|
|
480
|
+
logger.error(
|
|
481
|
+
"When enable_hybrid_index is true, an 'emb' model must be deployed"
|
|
482
|
+
)
|
|
375
483
|
return
|
|
376
484
|
llm.setup_default_emb_model_name("emb")
|
|
377
|
-
|
|
485
|
+
|
|
378
486
|
if server_args.doc_dir:
|
|
379
487
|
auto_coder_args.rag_type = "simple"
|
|
380
488
|
rag = RAGFactory.get_rag(
|
|
@@ -391,7 +499,7 @@ def main(input_args: Optional[List[str]] = None):
|
|
|
391
499
|
elif args.command == "build_hybrid_index":
|
|
392
500
|
if not args.quick:
|
|
393
501
|
initialize_system()
|
|
394
|
-
|
|
502
|
+
|
|
395
503
|
auto_coder_args = AutoCoderArgs(
|
|
396
504
|
**{
|
|
397
505
|
arg: getattr(args, arg)
|
|
@@ -402,25 +510,30 @@ def main(input_args: Optional[List[str]] = None):
|
|
|
402
510
|
|
|
403
511
|
auto_coder_args.enable_hybrid_index = True
|
|
404
512
|
auto_coder_args.rag_type = "simple"
|
|
405
|
-
|
|
513
|
+
|
|
406
514
|
try:
|
|
407
515
|
from byzerllm.apps.byzer_storage.simple_api import ByzerStorage
|
|
516
|
+
|
|
408
517
|
storage = ByzerStorage("byzerai_store", "rag", "files")
|
|
409
518
|
storage.retrieval.cluster_info("byzerai_store")
|
|
410
519
|
except Exception as e:
|
|
411
|
-
logger.error(
|
|
520
|
+
logger.error(
|
|
521
|
+
"When enable_hybrid_index is true, ByzerStorage must be started"
|
|
522
|
+
)
|
|
412
523
|
logger.error("Please run 'byzerllm storage start' first")
|
|
413
524
|
return
|
|
414
|
-
|
|
525
|
+
|
|
415
526
|
llm = byzerllm.ByzerLLM()
|
|
416
527
|
llm.setup_default_model_name(args.model)
|
|
417
528
|
|
|
418
529
|
# 当启用hybrid_index时,检查必要的组件
|
|
419
|
-
if auto_coder_args.enable_hybrid_index:
|
|
530
|
+
if auto_coder_args.enable_hybrid_index:
|
|
420
531
|
if not llm.is_model_exist("emb"):
|
|
421
|
-
logger.error(
|
|
532
|
+
logger.error(
|
|
533
|
+
"When enable_hybrid_index is true, an 'emb' model must be deployed"
|
|
534
|
+
)
|
|
422
535
|
return
|
|
423
|
-
llm.setup_default_emb_model_name("emb")
|
|
536
|
+
llm.setup_default_emb_model_name("emb")
|
|
424
537
|
|
|
425
538
|
rag = RAGFactory.get_rag(
|
|
426
539
|
llm=llm,
|
|
@@ -428,19 +541,41 @@ def main(input_args: Optional[List[str]] = None):
|
|
|
428
541
|
path=args.doc_dir,
|
|
429
542
|
tokenizer_path=args.tokenizer_path,
|
|
430
543
|
)
|
|
431
|
-
|
|
544
|
+
|
|
432
545
|
if hasattr(rag.document_retriever, "cacher"):
|
|
433
546
|
rag.document_retriever.cacher.build_cache()
|
|
434
547
|
else:
|
|
435
|
-
logger.error(
|
|
548
|
+
logger.error(
|
|
549
|
+
"The document retriever does not support hybrid index building"
|
|
550
|
+
)
|
|
551
|
+
|
|
552
|
+
elif args.command == "tools":
|
|
553
|
+
if args.tool == "count":
|
|
554
|
+
# auto-coder.rag tools count --tokenizer_path /Users/allwefantasy/Downloads/tokenizer.json --file /Users/allwefantasy/data/yum/schema/schema.xlsx
|
|
555
|
+
count_tokens(args.tokenizer_path, args.file)
|
|
556
|
+
elif args.tool == "recall":
|
|
557
|
+
from .common.recall_validation import validate_recall
|
|
558
|
+
|
|
559
|
+
llm = byzerllm.ByzerLLM.from_default_model(args.model)
|
|
560
|
+
|
|
561
|
+
content = None if not args.content else [args.content]
|
|
562
|
+
result = validate_recall(llm, content=content, query=args.query)
|
|
563
|
+
print(f"Recall Validation Result:\n{result}")
|
|
564
|
+
elif args.tool == "chunk":
|
|
565
|
+
from .common.chunk_validation import validate_chunk
|
|
436
566
|
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
567
|
+
llm = byzerllm.ByzerLLM.from_default_model(args.model)
|
|
568
|
+
content = None if not args.content else [args.content]
|
|
569
|
+
result = validate_chunk(llm, content=content, query=args.query)
|
|
570
|
+
print(f"Chunk Model Validation Result:\n{result}")
|
|
440
571
|
|
|
441
572
|
|
|
442
573
|
def count_tokens(tokenizer_path: str, file_path: str):
|
|
443
|
-
|
|
574
|
+
from autocoder.rag.variable_holder import VariableHolder
|
|
575
|
+
from tokenizers import Tokenizer
|
|
576
|
+
VariableHolder.TOKENIZER_PATH = tokenizer_path
|
|
577
|
+
VariableHolder.TOKENIZER_MODEL = Tokenizer.from_file(tokenizer_path)
|
|
578
|
+
token_counter = TokenCounter(tokenizer_path)
|
|
444
579
|
source_codes = process_file_local(file_path)
|
|
445
580
|
|
|
446
581
|
console = Console()
|