auto-coder 0.1.399__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of auto-coder might be problematic. Click here for more details.
- {auto_coder-0.1.399.dist-info → auto_coder-1.0.0.dist-info}/METADATA +1 -1
- {auto_coder-0.1.399.dist-info → auto_coder-1.0.0.dist-info}/RECORD +71 -35
- autocoder/agent/agentic_filter.py +1 -1
- autocoder/agent/base_agentic/tools/read_file_tool_resolver.py +1 -1
- autocoder/auto_coder_runner.py +121 -26
- autocoder/chat_auto_coder.py +81 -22
- autocoder/commands/auto_command.py +1 -1
- autocoder/common/__init__.py +2 -2
- autocoder/common/ac_style_command_parser/parser.py +27 -12
- autocoder/common/auto_coder_lang.py +78 -0
- autocoder/common/command_completer_v2.py +1 -1
- autocoder/common/file_monitor/test_file_monitor.py +307 -0
- autocoder/common/git_utils.py +7 -2
- autocoder/common/pruner/__init__.py +0 -0
- autocoder/common/pruner/agentic_conversation_pruner.py +197 -0
- autocoder/common/pruner/context_pruner.py +574 -0
- autocoder/common/pruner/conversation_pruner.py +132 -0
- autocoder/common/pruner/test_agentic_conversation_pruner.py +342 -0
- autocoder/common/pruner/test_context_pruner.py +546 -0
- autocoder/common/pull_requests/__init__.py +256 -0
- autocoder/common/pull_requests/base_provider.py +191 -0
- autocoder/common/pull_requests/config.py +66 -0
- autocoder/common/pull_requests/example.py +1 -0
- autocoder/common/pull_requests/exceptions.py +46 -0
- autocoder/common/pull_requests/manager.py +201 -0
- autocoder/common/pull_requests/models.py +164 -0
- autocoder/common/pull_requests/providers/__init__.py +23 -0
- autocoder/common/pull_requests/providers/gitcode_provider.py +19 -0
- autocoder/common/pull_requests/providers/gitee_provider.py +20 -0
- autocoder/common/pull_requests/providers/github_provider.py +214 -0
- autocoder/common/pull_requests/providers/gitlab_provider.py +29 -0
- autocoder/common/pull_requests/test_module.py +1 -0
- autocoder/common/pull_requests/utils.py +344 -0
- autocoder/common/tokens/__init__.py +77 -0
- autocoder/common/tokens/counter.py +231 -0
- autocoder/common/tokens/file_detector.py +105 -0
- autocoder/common/tokens/filters.py +111 -0
- autocoder/common/tokens/models.py +28 -0
- autocoder/common/v2/agent/agentic_edit.py +538 -590
- autocoder/common/v2/agent/agentic_edit_tools/__init__.py +8 -1
- autocoder/common/v2/agent/agentic_edit_tools/ac_mod_read_tool_resolver.py +40 -0
- autocoder/common/v2/agent/agentic_edit_tools/ac_mod_write_tool_resolver.py +43 -0
- autocoder/common/v2/agent/agentic_edit_tools/ask_followup_question_tool_resolver.py +8 -0
- autocoder/common/v2/agent/agentic_edit_tools/execute_command_tool_resolver.py +1 -1
- autocoder/common/v2/agent/agentic_edit_tools/read_file_tool_resolver.py +1 -1
- autocoder/common/v2/agent/agentic_edit_tools/search_files_tool_resolver.py +33 -88
- autocoder/common/v2/agent/agentic_edit_tools/test_write_to_file_tool_resolver.py +8 -8
- autocoder/common/v2/agent/agentic_edit_tools/todo_read_tool_resolver.py +118 -0
- autocoder/common/v2/agent/agentic_edit_tools/todo_write_tool_resolver.py +324 -0
- autocoder/common/v2/agent/agentic_edit_types.py +47 -4
- autocoder/common/v2/agent/runner/__init__.py +31 -0
- autocoder/common/v2/agent/runner/base_runner.py +106 -0
- autocoder/common/v2/agent/runner/event_runner.py +216 -0
- autocoder/common/v2/agent/runner/sdk_runner.py +40 -0
- autocoder/common/v2/agent/runner/terminal_runner.py +283 -0
- autocoder/common/v2/agent/runner/tool_display.py +191 -0
- autocoder/index/entry.py +1 -1
- autocoder/plugins/token_helper_plugin.py +107 -7
- autocoder/run_context.py +9 -0
- autocoder/sdk/__init__.py +114 -81
- autocoder/sdk/cli/handlers.py +2 -1
- autocoder/sdk/cli/main.py +9 -2
- autocoder/sdk/cli/options.py +4 -3
- autocoder/sdk/core/auto_coder_core.py +7 -152
- autocoder/sdk/core/bridge.py +5 -4
- autocoder/sdk/models/options.py +8 -6
- autocoder/version.py +1 -1
- {auto_coder-0.1.399.dist-info → auto_coder-1.0.0.dist-info}/WHEEL +0 -0
- {auto_coder-0.1.399.dist-info → auto_coder-1.0.0.dist-info}/entry_points.txt +0 -0
- {auto_coder-0.1.399.dist-info → auto_coder-1.0.0.dist-info}/licenses/LICENSE +0 -0
- {auto_coder-0.1.399.dist-info → auto_coder-1.0.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,344 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Pull Request 模块工具函数
|
|
3
|
+
"""
|
|
4
|
+
import re
|
|
5
|
+
import os
|
|
6
|
+
import subprocess
|
|
7
|
+
from typing import Optional, Tuple
|
|
8
|
+
from urllib.parse import urlparse
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
from loguru import logger
|
|
11
|
+
|
|
12
|
+
from .models import PlatformType, RepoInfo
|
|
13
|
+
from .exceptions import ValidationError
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def parse_git_url(url: str) -> Tuple[Optional[PlatformType], Optional[str], Optional[str]]:
    """
    Parse a Git remote URL into platform type, owner and repository name.

    Two URL shapes are recognised:

    * scp-like SSH: ``git@domain:owner/repo`` with an optional ``.git`` suffix
    * URLs with a scheme: ``https://domain/owner/repo`` with an optional
      ``.git`` suffix. User info, port and letter case of the host are
      ignored when looking up the platform, so
      ``https://token@github.com/owner/repo.git`` and
      ``ssh://git@github.com/owner/repo.git`` are recognised as well.

    Args:
        url: Remote URL to parse.

    Returns:
        Tuple ``(platform, owner, repo)``. For URLs with a scheme all three
        values are ``None`` when the host is not a known platform or the path
        has fewer than two segments. For scp-like SSH URLs owner and repo are
        returned even when the host is unknown; only ``platform`` is ``None``
        in that case.
    """
    if not url:
        return None, None, None

    platform_domains = {
        'github.com': PlatformType.GITHUB,
        'gitlab.com': PlatformType.GITLAB,
        'gitee.com': PlatformType.GITEE,
        'gitcode.net': PlatformType.GITCODE
    }

    # scp-like SSH form: git@domain:owner/repo.git
    ssh_match = re.match(r'^git@([^:]+):([^/]+)/([^/]+?)(?:\.git)?/?$', url)
    if ssh_match:
        domain, owner, repo = ssh_match.groups()
        return platform_domains.get(domain), owner, repo

    # URL form: scheme://[user[:password]@]domain[:port]/owner/repo.git
    try:
        parsed = urlparse(url)
        # ``hostname`` (unlike ``netloc``) drops user info and port and is
        # lower-cased, which is what the domain table is keyed on.
        domain = parsed.hostname
    except ValueError as e:
        logger.error(f"解析Git URL失败: {e}")
        return None, None, None

    platform = platform_domains.get(domain)
    if not platform:
        return None, None, None

    path_parts = [part for part in parsed.path.split('/') if part]
    if len(path_parts) < 2:
        return None, None, None

    owner, repo = path_parts[0], path_parts[1]
    if repo.endswith('.git'):
        repo = repo[:-4]
    return platform, owner, repo
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def get_repo_remote_url(repo_path: str, remote_name: str = 'origin') -> Optional[str]:
    """
    Return the URL configured for a remote of the repository.

    Runs ``git remote get-url <remote_name>`` with ``repo_path`` as the
    working directory.

    Args:
        repo_path: Path of the repository working tree.
        remote_name: Name of the remote, ``'origin'`` by default.

    Returns:
        The command output with surrounding whitespace stripped, or ``None``
        when git exits with a non-zero status or cannot be started at all.
    """
    try:
        result = subprocess.run(
            ['git', 'remote', 'get-url', remote_name],
            cwd=repo_path,
            capture_output=True,
            text=True,
            check=True
        )
    except (subprocess.CalledProcessError, OSError):
        # OSError covers a missing/invalid ``repo_path`` and a missing git
        # executable; both are reported as "no URL" like any other failure.
        return None
    return result.stdout.strip()
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def get_default_remote_branch(repo_path: str, remote_name: str = 'origin') -> Optional[str]:
    """
    Determine the default branch of a remote.

    Strategies, in order:

    1. Read the local symbolic ref ``refs/remotes/<remote_name>/HEAD``.
    2. If that command fails, ask the remote with
       ``git ls-remote --symref <remote_name> HEAD``.
    3. Otherwise probe ``<remote_name>`` for the common branch names
       ``main``, ``master`` and ``develop`` (first match wins).

    Args:
        repo_path: Path of the repository working tree.
        remote_name: Name of the remote, ``'origin'`` by default.

    Returns:
        The default branch name, or ``None`` when it cannot be determined
        (including when git cannot be started in ``repo_path``).
    """
    prefix = f'refs/remotes/{remote_name}/'

    head_output = _git_stdout_or_none(repo_path, 'symbolic-ref', f'{prefix}HEAD')
    if head_output is not None:
        # Expected output looks like "refs/remotes/origin/main".
        head_ref = head_output.strip()
        if head_ref.startswith(prefix):
            return head_ref[len(prefix):]
    else:
        # The local symbolic ref is missing; query the remote instead.
        symref_output = _git_stdout_or_none(
            repo_path, 'ls-remote', '--symref', remote_name, 'HEAD'
        )
        if symref_output is not None:
            marker = 'ref: refs/heads/'
            # Expected line looks like "ref: refs/heads/main\tHEAD".
            for line in symref_output.strip().split('\n'):
                if line.startswith(marker):
                    return line[len(marker):].split('\t')[0]

    # Last resort: see which of the usual default branches exists on the
    # requested remote. One ls-remote call with full ref names avoids both
    # repeated round-trips and tail-pattern false positives.
    common_branches = ('main', 'master', 'develop')
    listing = _git_stdout_or_none(
        repo_path, 'ls-remote', '--heads', remote_name,
        *(f'refs/heads/{branch}' for branch in common_branches)
    )
    if listing:
        # Each output line is "<sha>\t<ref>".
        remote_refs = {line.split('\t')[-1].strip() for line in listing.splitlines()}
        for branch in common_branches:
            if f'refs/heads/{branch}' in remote_refs:
                return branch

    return None


def _git_stdout_or_none(repo_path: str, *args: str) -> Optional[str]:
    """Run ``git <args>`` in ``repo_path``; return its stdout, or None on any failure."""
    try:
        completed = subprocess.run(
            ['git', *args],
            cwd=repo_path,
            capture_output=True,
            text=True,
            check=True
        )
    except (subprocess.CalledProcessError, OSError):
        return None
    return completed.stdout
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def detect_platform_from_repo(repo_path: str) -> Optional[PlatformType]:
    """
    Detect the hosting platform of a repository from its remote URL.

    The URL is obtained with ``get_repo_remote_url`` (default remote) and
    handed to ``parse_git_url``.

    Args:
        repo_path: Path of the repository working tree.

    Returns:
        The platform component reported by ``parse_git_url``, or ``None``
        when no remote URL is available.
    """
    url = get_repo_remote_url(repo_path)
    return parse_git_url(url)[0] if url else None
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
def get_repo_info_from_path(repo_path: str) -> Optional[RepoInfo]:
    """
    Build a ``RepoInfo`` for the repository at ``repo_path``.

    The remote URL from ``get_repo_remote_url`` is parsed with
    ``parse_git_url``; a ``RepoInfo`` is only created when platform, owner
    and repository name are all truthy.

    Args:
        repo_path: Path of the repository working tree.

    Returns:
        The populated ``RepoInfo``, or ``None`` when the remote URL is
        missing or any of its three parts could not be determined.
    """
    url = get_repo_remote_url(repo_path)
    if url:
        platform, owner, name = parse_git_url(url)
        if platform and owner and name:
            return RepoInfo(
                platform=platform,  # type: ignore
                owner=owner,  # type: ignore
                name=name  # type: ignore
            )
    return None
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
def get_current_branch(repo_path: str) -> Optional[str]:
    """
    Return the name of the branch currently checked out in ``repo_path``.

    Runs ``git branch --show-current`` with ``repo_path`` as the working
    directory.

    Args:
        repo_path: Path of the repository working tree.

    Returns:
        The branch name, or ``None`` when git fails, cannot be started, or
        prints nothing (which git does for a detached HEAD).
    """
    try:
        result = subprocess.run(
            ['git', 'branch', '--show-current'],
            cwd=repo_path,
            capture_output=True,
            text=True,
            check=True
        )
    except (subprocess.CalledProcessError, OSError):
        # OSError: ``repo_path`` missing/invalid or git not installed.
        return None
    # An empty name is not a branch; report it the same way as a failure.
    return result.stdout.strip() or None
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
def branch_exists(repo_path: str, branch_name: str, remote: bool = False,
                  remote_name: str = 'origin') -> bool:
    """
    Check whether a branch exists locally or on a remote.

    Args:
        repo_path: Path of the repository working tree.
        branch_name: Branch name without the ``refs/heads/`` prefix.
        remote: When true, ask the remote via ``git ls-remote``; otherwise
            verify the local ref with ``git show-ref``.
        remote_name: Remote to query when ``remote`` is true, ``'origin'``
            by default.

    Returns:
        ``True`` when the exact ref ``refs/heads/<branch_name>`` exists,
        ``False`` otherwise, including when git fails or cannot be started.
    """
    ref = f'refs/heads/{branch_name}'
    try:
        if remote:
            result = subprocess.run(
                ['git', 'ls-remote', '--heads', remote_name, ref],
                cwd=repo_path,
                capture_output=True,
                text=True,
                check=True
            )
            # ls-remote patterns match the *tail* of a ref, so compare the ref
            # column ("<sha>\t<ref>") exactly to rule out e.g. "x/<branch>".
            return any(
                line.split('\t')[-1].strip() == ref
                for line in result.stdout.splitlines()
            )

        result = subprocess.run(
            ['git', 'show-ref', '--verify', '--quiet', ref],
            cwd=repo_path,
            capture_output=True
        )
        return result.returncode == 0
    except (subprocess.CalledProcessError, OSError):
        # OSError: ``repo_path`` missing/invalid or git not installed.
        return False
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
def is_git_repo(path: str) -> bool:
    """
    Check whether ``path`` contains a ``.git`` entry.

    Args:
        path: Directory to inspect.

    Returns:
        ``True`` when ``<path>/.git`` exists, whether it is a directory or a
        regular file; ``False`` otherwise.
    """
    # ``exists()`` is already true for files, so no separate is_file() check.
    return (Path(path) / '.git').exists()
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
def push_branch_to_remote(repo_path: str, branch_name: str, remote_name: str = 'origin') -> bool:
    """
    Push a local branch to a remote.

    Runs ``git push <remote_name> <branch_name>`` with ``repo_path`` as the
    working directory and logs the outcome.

    Args:
        repo_path: Path of the repository working tree.
        branch_name: Name of the branch to push.
        remote_name: Name of the remote, ``'origin'`` by default.

    Returns:
        ``True`` when git exits successfully, ``False`` when the push fails
        or git cannot be started.
    """
    logger.info(f"正在推送分支 '{branch_name}' 到远程仓库...")
    try:
        subprocess.run(
            ['git', 'push', remote_name, branch_name],
            cwd=repo_path,
            capture_output=True,
            text=True,
            check=True
        )
    except subprocess.CalledProcessError as e:
        logger.error(f"推送分支 '{branch_name}' 失败: {e.stderr}")
        return False
    except OSError as e:
        # ``repo_path`` missing/invalid or git not installed; there is no
        # stderr to report, so log the exception itself.
        logger.error(f"推送分支 '{branch_name}' 失败: {e}")
        return False

    logger.info(f"分支 '{branch_name}' 推送成功")
    return True
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
def ensure_branch_exists_remotely(repo_path: str, branch_name: str, remote_name: str = 'origin') -> bool:
    """
    Make sure a branch is present on the remote, pushing it when necessary.

    Args:
        repo_path: Path of the repository working tree.
        branch_name: Name of the branch.
        remote_name: Remote used for the push, ``'origin'`` by default.

    Returns:
        ``True`` when the branch already exists remotely or the push
        succeeded; ``False`` when the branch is missing locally as well or
        the push failed.
    """
    # NOTE(review): ``remote_name`` is only forwarded to the push below; the
    # existence check is called without it - confirm that is intended.
    if branch_exists(repo_path, branch_name, remote=True):
        logger.debug(f"分支 '{branch_name}' 已存在于远程仓库")
        return True

    # Not on the remote yet: it can only be pushed if it exists locally.
    if branch_exists(repo_path, branch_name, remote=False):
        return push_branch_to_remote(repo_path, branch_name, remote_name)

    logger.error(f"分支 '{branch_name}' 在本地也不存在")
    return False
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
def is_main_branch(branch_name: Optional[str]) -> bool:
    """
    Check whether a branch name denotes one of the main branches.

    The comparison is case-insensitive against ``main``, ``master``,
    ``develop`` and ``dev``.

    Args:
        branch_name: Branch name to test. ``None`` and the empty string are
            accepted and are never main branches.

    Returns:
        ``True`` when the name is one of the main branches, else ``False``.
    """
    if not branch_name:
        # Covers None (e.g. an undetermined current branch) and ''.
        return False
    main_branches = ('main', 'master', 'develop', 'dev')
    return branch_name.lower() in main_branches
|
|
268
|
+
|
|
269
|
+
|
|
270
|
+
def create_and_checkout_branch(repo_path: str, branch_name: str) -> bool:
    """
    Create a new branch and switch to it.

    Runs ``git checkout -b <branch_name>`` with ``repo_path`` as the working
    directory and logs the outcome.

    Args:
        repo_path: Path of the repository working tree.
        branch_name: Name of the branch to create.

    Returns:
        ``True`` when git exits successfully, ``False`` when the command
        fails or git cannot be started.
    """
    logger.info(f"正在创建并切换到新分支: {branch_name}")
    try:
        subprocess.run(
            ['git', 'checkout', '-b', branch_name],
            cwd=repo_path,
            capture_output=True,
            text=True,
            check=True
        )
    except subprocess.CalledProcessError as e:
        logger.error(f"创建分支 '{branch_name}' 失败: {e.stderr}")
        return False
    except OSError as e:
        # ``repo_path`` missing/invalid or git not installed; there is no
        # stderr to report, so log the exception itself.
        logger.error(f"创建分支 '{branch_name}' 失败: {e}")
        return False

    logger.info(f"成功创建并切换到分支: {branch_name}")
    return True
|
|
295
|
+
|
|
296
|
+
|
|
297
|
+
def generate_auto_branch_name() -> str:
    """
    Generate an automatic branch name of the form ``ac-<yyyyMMdd-HH-mm-ss>``.

    The timestamp is taken from the local clock at call time.

    Returns:
        The generated branch name.
    """
    from datetime import datetime as _datetime

    timestamp = _datetime.now().strftime('%Y%m%d-%H-%M-%S')
    return "ac-" + timestamp
|
|
307
|
+
|
|
308
|
+
|
|
309
|
+
def validate_repo_path(repo_path: str) -> str:
    """
    Validate a repository path and return it in resolved form.

    Args:
        repo_path: Path to validate.

    Returns:
        The resolved path as a string.

    Raises:
        ValidationError: When the path is empty, does not exist, is not a
            directory, or ``is_git_repo`` reports it is not a Git repository.
    """
    if not repo_path:
        raise ValidationError("仓库路径不能为空")

    resolved = Path(repo_path).resolve()
    resolved_text = str(resolved)

    # Checks run in order and stop at the first one that fails.
    checks = (
        (resolved.exists, f"仓库路径不存在: {resolved}"),
        (resolved.is_dir, f"仓库路径不是目录: {resolved}"),
        (lambda: is_git_repo(resolved_text), f"路径不是Git仓库: {resolved}"),
    )
    for passes, message in checks:
        if not passes():
            raise ValidationError(message)

    return resolved_text
|
|
326
|
+
|
|
327
|
+
|
|
328
|
+
def build_pr_url(platform: PlatformType, repo_info: RepoInfo, pr_number: int) -> str:
    """
    Build the web URL of a pull/merge request.

    Args:
        platform: Hosting platform of the repository.
        repo_info: Repository description; its ``full_name`` is used as the
            path of the repository on the platform.
        pr_number: Number of the pull/merge request.

    Returns:
        The URL, or an empty string when the platform has no known base URL.
    """
    base_urls = {
        PlatformType.GITHUB: "https://github.com",
        PlatformType.GITLAB: "https://gitlab.com",
        PlatformType.GITEE: "https://gitee.com",
        PlatformType.GITCODE: "https://gitcode.net"
    }

    base_url = base_urls.get(platform)
    if not base_url:
        return ""

    # The path segment in front of the number differs per platform.
    if platform == PlatformType.GITLAB or platform == PlatformType.GITCODE:
        segment = "-/merge_requests"
    elif platform == PlatformType.GITEE:
        # Gitee serves pull requests under the plural "pulls".
        segment = "pulls"
    else:
        segment = "pull"

    return f"{base_url}/{repo_info.full_name}/{segment}/{pr_number}"
|
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Tokens 模块 - 高效统计文件和目录中的 token 数量
|
|
3
|
+
|
|
4
|
+
提供了简单易用的接口,支持正则过滤和智能文件类型识别。
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from .models import TokenResult, DirectoryTokenResult
|
|
8
|
+
from .counter import TokenCounter
|
|
9
|
+
from .file_detector import FileTypeDetector
|
|
10
|
+
from .filters import FileFilter
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def count_file_tokens(file_path: str) -> TokenResult:
    """
    Count the tokens of a single file.

    Convenience wrapper that creates a ``TokenCounter`` with its default
    settings and delegates to ``TokenCounter.count_file``.

    Args:
        file_path: Path of the file.

    Returns:
        TokenResult: Counting result for the file.
    """
    return TokenCounter().count_file(file_path)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def count_directory_tokens(
    dir_path: str,
    pattern: str = None,
    exclude_pattern: str = None,
    recursive: bool = True
) -> DirectoryTokenResult:
    """
    Count the tokens of the files in a directory.

    Convenience wrapper that creates a ``TokenCounter`` with its default
    settings and delegates to ``TokenCounter.count_directory``.

    Args:
        dir_path: Path of the directory.
        pattern: File name pattern to include (regular expression).
        exclude_pattern: File name pattern to exclude (regular expression).
        recursive: Whether sub-directories are processed as well.

    Returns:
        DirectoryTokenResult: Counting result for the directory.
    """
    options = dict(
        dir_path=dir_path,
        pattern=pattern,
        exclude_pattern=exclude_pattern,
        recursive=recursive,
    )
    return TokenCounter().count_directory(**options)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def count_string_tokens(text: str) -> int:
    """
    Count the tokens of a string.

    Convenience wrapper that creates a ``TokenCounter`` with its default
    settings and delegates to ``TokenCounter.count_string_tokens``.

    Args:
        text: String content to count.

    Returns:
        int: Number of tokens.
    """
    return TokenCounter().count_string_tokens(text)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
__all__ = [
|
|
69
|
+
'TokenResult',
|
|
70
|
+
'DirectoryTokenResult',
|
|
71
|
+
'TokenCounter',
|
|
72
|
+
'FileTypeDetector',
|
|
73
|
+
'FileFilter',
|
|
74
|
+
'count_file_tokens',
|
|
75
|
+
'count_directory_tokens',
|
|
76
|
+
'count_string_tokens',
|
|
77
|
+
]
|
|
@@ -0,0 +1,231 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import concurrent.futures
|
|
3
|
+
from typing import List, Dict, Optional, Union, Callable
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
import time
|
|
6
|
+
import re
|
|
7
|
+
|
|
8
|
+
from autocoder.rag.variable_holder import VariableHolder
|
|
9
|
+
from .models import TokenResult, DirectoryTokenResult
|
|
10
|
+
from .file_detector import FileTypeDetector
|
|
11
|
+
from .filters import FileFilter
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class TokenCounter:
    """Token counter for files, directories and plain strings.

    Tokenisation is delegated to ``VariableHolder.TOKENIZER_MODEL``, which
    must already be set when an instance is created.
    """

    def __init__(self,
                 timeout: int = 30,
                 parallel: bool = True,
                 max_workers: int = 4):
        """
        Initialise the token counter.

        Args:
            timeout: Per-file processing timeout in seconds.
                NOTE(review): the value is stored, but no method of this
                class reads it, so it is currently not enforced.
            parallel: Whether several files are processed in parallel.
            max_workers: Maximum number of worker threads.

        Raises:
            RuntimeError: If ``VariableHolder.TOKENIZER_MODEL`` is ``None``.
        """
        self.timeout = timeout
        self.parallel = parallel
        self.max_workers = max_workers

        # The tokenizer has to be loaded before any counting can happen.
        if VariableHolder.TOKENIZER_MODEL is None:
            raise RuntimeError("Tokenizer model not initialized. Please call load_tokenizer() first.")

    @staticmethod
    def _failed_result(file_path: str, error: str) -> TokenResult:
        """Build the all-zero ``TokenResult`` used to report a failure."""
        return TokenResult(
            file_path=file_path,
            token_count=0,
            char_count=0,
            line_count=0,
            success=False,
            error=error
        )

    def count_file(self, file_path: str) -> TokenResult:
        """
        Count tokens, characters and lines of a single file.

        Failures are never raised; they are reported through a result with
        ``success=False`` and an ``error`` message.

        Args:
            file_path: Path of the file.

        Returns:
            TokenResult: Counting result for the file.
        """
        try:
            if not os.path.isfile(file_path):
                return self._failed_result(file_path, "File does not exist")

            # Only text files are tokenised.
            if not FileTypeDetector.is_text_file(file_path):
                return self._failed_result(file_path, "Not a text file")

            encoding = FileTypeDetector.detect_encoding(file_path)

            # Undecodable bytes are replaced rather than aborting the count.
            with open(file_path, 'r', encoding=encoding, errors='replace') as f:
                content = f.read()

            # A final line without a trailing newline still counts as a line.
            line_count = content.count('\n') + (0 if content == "" or content.endswith('\n') else 1)
            char_count = len(content)
            token_count = len(VariableHolder.TOKENIZER_MODEL.encode(content))

            return TokenResult(
                file_path=file_path,
                token_count=token_count,
                char_count=char_count,
                line_count=line_count
            )
        except Exception as e:
            # Deliberate catch-all: one unreadable file must not abort a
            # whole batch, so the error is carried in the result instead.
            return self._failed_result(file_path, str(e))

    def count_files(self, file_paths: List[str]) -> List[TokenResult]:
        """
        Count the tokens of several files.

        Args:
            file_paths: List of file paths.

        Returns:
            List[TokenResult]: One result per input path, in input order for
            both the serial and the parallel code path.
        """
        if not self.parallel or len(file_paths) <= 1:
            return [self.count_file(file_path) for file_path in file_paths]

        with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            # ``map`` yields results in submission order, which keeps the
            # output deterministic regardless of which worker finishes first.
            return list(executor.map(self.count_file, file_paths))

    def count_directory(self,
                        dir_path: str,
                        pattern: Optional[str] = None,
                        exclude_pattern: Optional[str] = None,
                        recursive: bool = True,
                        max_depth: Optional[int] = None) -> DirectoryTokenResult:
        """
        Count the tokens of the files in a directory.

        Args:
            dir_path: Path of the directory.
            pattern: File name pattern to include (regular expression).
            exclude_pattern: File name pattern to exclude (regular expression).
            recursive: Whether sub-directories are processed as well.
            max_depth: Maximum recursion depth; ``dir_path`` itself is depth 0
                and directories at ``max_depth`` or deeper are not descended
                into. ``None`` means unlimited.

        Returns:
            DirectoryTokenResult: Totals, per-file results and error messages.
            ``skipped_count`` is the number of files rejected by the filter.
        """
        if not os.path.isdir(dir_path):
            return DirectoryTokenResult(
                directory_path=dir_path,
                total_tokens=0,
                file_count=0,
                skipped_count=0,
                files=[],
                errors=["Directory does not exist"]
            )

        # Build the file filter from the optional include/exclude patterns.
        patterns = [pattern] if pattern else []
        exclude_patterns = [exclude_pattern] if exclude_pattern else []
        file_filter = FileFilter(patterns=patterns, exclude_patterns=exclude_patterns)

        # Collect every file accepted by the filter.
        all_files = []
        skipped_count = 0

        for root, dirs, files in os.walk(dir_path):
            if max_depth is not None:
                # Depth is derived from the relative path so a trailing
                # separator on ``dir_path`` cannot shift the count.
                relative_root = os.path.relpath(root, dir_path)
                if relative_root == os.curdir:
                    current_depth = 0
                else:
                    current_depth = relative_root.count(os.sep) + 1
                if current_depth >= max_depth:
                    dirs.clear()  # stop os.walk from descending further

            for file in files:
                file_path = os.path.join(root, file)
                if file_filter.matches(file_path):
                    all_files.append(file_path)
                else:
                    skipped_count += 1

            if not recursive:
                break  # only the top-level directory is processed

        file_results = self.count_files(all_files)

        # Only successfully counted files contribute to the total.
        total_tokens = sum(result.token_count for result in file_results if result.success)

        errors = [
            f"{result.file_path}: {result.error}"
            for result in file_results if not result.success
        ]

        return DirectoryTokenResult(
            directory_path=dir_path,
            total_tokens=total_tokens,
            file_count=len(file_results),
            skipped_count=skipped_count,
            files=file_results,
            errors=errors
        )

    def count_string_tokens(self, text: str) -> int:
        """
        Count the tokens of a string.

        Args:
            text: String content to count.

        Returns:
            int: Number of tokens.

        Raises:
            RuntimeError: If ``text`` is not a string or the tokenizer fails;
                the original exception is attached as the cause.
        """
        try:
            if not isinstance(text, str):
                raise ValueError("Input must be a string")

            tokens = VariableHolder.TOKENIZER_MODEL.encode(text)
            return len(tokens)
        except Exception as e:
            raise RuntimeError(f"Failed to count tokens: {str(e)}") from e

    def set_tokenizer(self, tokenizer_name: str) -> None:
        """
        Change the tokenizer (not supported yet; placeholder for the interface).

        Args:
            tokenizer_name: Name of the tokenizer.
        """
        # Only the default tokenizer is supported for now, so this is a no-op.
        pass
|