auto-coder 0.1.399__py3-none-any.whl → 0.1.400__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of auto-coder might be problematic. Click here for more details.
- {auto_coder-0.1.399.dist-info → auto_coder-0.1.400.dist-info}/METADATA +1 -1
- {auto_coder-0.1.399.dist-info → auto_coder-0.1.400.dist-info}/RECORD +38 -19
- autocoder/auto_coder_runner.py +2 -1
- autocoder/common/ac_style_command_parser/parser.py +27 -12
- autocoder/common/auto_coder_lang.py +78 -0
- autocoder/common/command_completer_v2.py +1 -1
- autocoder/common/pull_requests/__init__.py +256 -0
- autocoder/common/pull_requests/base_provider.py +191 -0
- autocoder/common/pull_requests/config.py +66 -0
- autocoder/common/pull_requests/example.py +1 -0
- autocoder/common/pull_requests/exceptions.py +46 -0
- autocoder/common/pull_requests/manager.py +201 -0
- autocoder/common/pull_requests/models.py +164 -0
- autocoder/common/pull_requests/providers/__init__.py +23 -0
- autocoder/common/pull_requests/providers/gitcode_provider.py +19 -0
- autocoder/common/pull_requests/providers/gitee_provider.py +20 -0
- autocoder/common/pull_requests/providers/github_provider.py +214 -0
- autocoder/common/pull_requests/providers/gitlab_provider.py +29 -0
- autocoder/common/pull_requests/test_module.py +1 -0
- autocoder/common/pull_requests/utils.py +344 -0
- autocoder/common/tokens/__init__.py +62 -0
- autocoder/common/tokens/counter.py +211 -0
- autocoder/common/tokens/file_detector.py +105 -0
- autocoder/common/tokens/filters.py +111 -0
- autocoder/common/tokens/models.py +28 -0
- autocoder/common/v2/agent/agentic_edit.py +182 -68
- autocoder/common/v2/agent/agentic_edit_types.py +1 -0
- autocoder/sdk/cli/handlers.py +2 -1
- autocoder/sdk/cli/main.py +4 -2
- autocoder/sdk/cli/options.py +4 -3
- autocoder/sdk/core/auto_coder_core.py +14 -1
- autocoder/sdk/core/bridge.py +3 -0
- autocoder/sdk/models/options.py +8 -6
- autocoder/version.py +1 -1
- {auto_coder-0.1.399.dist-info → auto_coder-0.1.400.dist-info}/WHEEL +0 -0
- {auto_coder-0.1.399.dist-info → auto_coder-0.1.400.dist-info}/entry_points.txt +0 -0
- {auto_coder-0.1.399.dist-info → auto_coder-0.1.400.dist-info}/licenses/LICENSE +0 -0
- {auto_coder-0.1.399.dist-info → auto_coder-0.1.400.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
"""
|
|
2
|
+
GitCode API 提供者实现
|
|
3
|
+
"""
|
|
4
|
+
from typing import List
|
|
5
|
+
from .github_provider import GitHubProvider
|
|
6
|
+
from ..models import RepoInfo, PRData, PRResult, PRInfo
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class GitCodeProvider(GitHubProvider):
    """GitCode provider built on top of the GitHub provider.

    Only the ``Authorization`` scheme is overridden (``Bearer`` instead of
    GitHub's ``token``); every other request is inherited unchanged from
    ``GitHubProvider``.
    """

    def _get_auth_header(self) -> str:
        """Return the ``Authorization`` header value using the Bearer scheme."""
        return "Bearer {}".format(self.config.token)

    def create_pr(self, repo_info: RepoInfo, pr_data: PRData) -> PRResult:
        """Create a merge request.

        TODO: placeholder - this simply delegates to the GitHub implementation;
        a real version should call the GitCode API.
        """
        return super(GitCodeProvider, self).create_pr(repo_info, pr_data)
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Gitee API 提供者实现
|
|
3
|
+
"""
|
|
4
|
+
from typing import List
|
|
5
|
+
from .github_provider import GitHubProvider
|
|
6
|
+
from ..models import RepoInfo, PRData, PRResult, PRInfo
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class GiteeProvider(GitHubProvider):
    """Gitee provider built on top of the GitHub provider.

    Overrides the auth header and (as a placeholder) ``create_pr``; all other
    requests are inherited unchanged from ``GitHubProvider``.
    """

    def _get_auth_header(self) -> str:
        """Return an empty ``Authorization`` header value.

        Gitee authenticates through a token parameter rather than an
        Authorization header.
        NOTE(review): nothing in this class adds that token parameter to
        requests - confirm the base provider handles it, otherwise Gitee calls
        are unauthenticated.
        """
        no_header = ""
        return no_header

    def create_pr(self, repo_info: RepoInfo, pr_data: PRData) -> PRResult:
        """Create a pull request.

        TODO: placeholder - delegates to the GitHub implementation; a real
        version should call the Gitee API.
        """
        return super(GiteeProvider, self).create_pr(repo_info, pr_data)
|
|
@@ -0,0 +1,214 @@
|
|
|
1
|
+
"""
|
|
2
|
+
GitHub API 提供者实现
|
|
3
|
+
"""
|
|
4
|
+
from typing import List, Dict, Any
|
|
5
|
+
from loguru import logger
|
|
6
|
+
|
|
7
|
+
from ..base_provider import BasePlatformProvider
|
|
8
|
+
from ..models import RepoInfo, PRData, PRResult, PRInfo
|
|
9
|
+
from ..utils import build_pr_url
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class GitHubProvider(BasePlatformProvider):
    """GitHub REST API provider.

    Every request targets ``{config.base_url}/repos/{full_name}/pulls...``.
    Request failures are never raised to the caller: they are logged and
    reported as ``PRResult(success=False, ...)`` (or ``[]`` from ``list_prs``).
    The input validation at the start of ``create_pr`` runs outside the
    ``try`` block, so anything it raises propagates.
    """

    def _get_auth_header(self) -> str:
        """Return the ``Authorization`` header value (GitHub ``token`` scheme)."""
        return f"token {self.config.token}"

    def _parse_pr_info(self, pr: Dict[str, Any]) -> PRInfo:
        """Convert one pull-request object returned by the GitHub API into ``PRInfo``."""
        return PRInfo(
            number=pr['number'],
            title=pr['title'],
            # GitHub sends ``"body": null`` for PRs without a description, so
            # ``.get('body', '')`` would still return None here.
            description=pr.get('body') or '',
            state=pr['state'],
            source_branch=pr['head']['ref'],
            target_branch=pr['base']['ref'],
            author=pr['user']['login'],
            created_at=pr['created_at'],
            updated_at=pr['updated_at'],
            merged_at=pr.get('merged_at'),
            pr_url=pr['html_url'],
            labels=[label['name'] for label in pr.get('labels') or []],
            assignees=[assignee['login'] for assignee in pr.get('assignees') or []],
            mergeable=pr.get('mergeable'),
            draft=pr.get('draft', False),
            raw_data=pr,
        )

    def create_pr(self, repo_info: RepoInfo, pr_data: PRData) -> PRResult:
        """Open a pull request from ``pr_data.source_branch`` into ``pr_data.target_branch``.

        The PR is created as a draft when either ``pr_data.draft`` or
        ``config.draft`` is truthy.
        """
        self._validate_repo_info(repo_info)
        self._validate_pr_data(pr_data)

        url = f"{self.config.base_url}/repos/{repo_info.full_name}/pulls"
        payload = {
            "title": pr_data.title,
            "body": pr_data.description,
            "head": pr_data.source_branch,
            "base": pr_data.target_branch,
            "draft": pr_data.draft or self.config.draft,
            "maintainer_can_modify": self.config.maintainer_can_modify,
        }

        try:
            data = self._make_request('POST', url, data=payload).json()
            pr_url = build_pr_url(self.config.platform, repo_info, data['number'])
            result = PRResult(
                success=True,
                pr_number=data['number'],
                pr_url=pr_url,
                pr_id=str(data['id']),
                platform=self.config.platform,
                raw_response=data,
            )
            logger.info(f"GitHub PR 创建成功: {pr_url}")
            return result
        except Exception as e:
            logger.error(f"创建 GitHub PR 失败: {e}")
            return PRResult(
                success=False,
                error_message=str(e),
                platform=self.config.platform,
            )

    def get_pr(self, repo_info: RepoInfo, pr_number: int) -> PRResult:
        """Fetch a single pull request by number."""
        url = f"{self.config.base_url}/repos/{repo_info.full_name}/pulls/{pr_number}"

        try:
            data = self._make_request('GET', url).json()
            pr_info = self._parse_pr_info(data)
            return PRResult(
                success=True,
                pr_number=pr_info.number,
                pr_url=pr_info.pr_url,
                platform=self.config.platform,
                raw_response=data,
            )
        except Exception as e:
            logger.error(f"获取 GitHub PR 失败: {e}")
            return PRResult(
                success=False,
                error_message=str(e),
                platform=self.config.platform,
            )

    def update_pr(self, repo_info: RepoInfo, pr_number: int, **kwargs) -> PRResult:
        """Update a pull request.

        Keyword Args:
            title: New title.
            description: New body text (sent as GitHub's ``body`` field).
            state: New state, e.g. ``'closed'``.
        Only the keys actually supplied are sent.
        """
        url = f"{self.config.base_url}/repos/{repo_info.full_name}/pulls/{pr_number}"

        # Caller-facing kwarg name -> GitHub API field name.
        field_map = {'title': 'title', 'description': 'body', 'state': 'state'}
        payload = {
            api_field: kwargs[arg]
            for arg, api_field in field_map.items()
            if arg in kwargs
        }

        try:
            data = self._make_request('PATCH', url, data=payload).json()
            return PRResult(
                success=True,
                pr_number=data['number'],
                pr_url=data['html_url'],
                platform=self.config.platform,
                raw_response=data,
            )
        except Exception as e:
            logger.error(f"更新 GitHub PR 失败: {e}")
            return PRResult(
                success=False,
                error_message=str(e),
                platform=self.config.platform,
            )

    def close_pr(self, repo_info: RepoInfo, pr_number: int) -> PRResult:
        """Close a pull request by setting its state to ``'closed'``."""
        return self.update_pr(repo_info, pr_number, state='closed')

    def merge_pr(self, repo_info: RepoInfo, pr_number: int, **kwargs) -> PRResult:
        """Merge a pull request.

        Keyword Args:
            commit_title: Title for the merge commit. Omitted when not given,
                so GitHub applies its own default instead of an empty title.
            commit_message: Extra detail for the merge commit. Omitted when not given.
            merge_method: ``'merge'`` (default), ``'squash'`` or ``'rebase'``.
        """
        url = f"{self.config.base_url}/repos/{repo_info.full_name}/pulls/{pr_number}/merge"

        payload: Dict[str, Any] = {'merge_method': kwargs.get('merge_method', 'merge')}
        for key in ('commit_title', 'commit_message'):
            if key in kwargs:
                payload[key] = kwargs[key]

        try:
            data = self._make_request('PUT', url, data=payload).json()
        except Exception as e:
            logger.error(f"合并 GitHub PR 失败: {e}")
            return PRResult(
                success=False,
                error_message=str(e),
                platform=self.config.platform,
            )

        # The merge endpoint reports the outcome in the body ({"merged": ..., "message": ...});
        # never report success when it explicitly says the merge did not happen.
        if isinstance(data, dict) and data.get('merged') is False:
            message = data.get('message') or 'merge was not performed'
            logger.error(f"合并 GitHub PR 失败: {message}")
            return PRResult(
                success=False,
                error_message=message,
                platform=self.config.platform,
                raw_response=data,
            )

        return PRResult(
            success=True,
            pr_number=pr_number,
            pr_url=build_pr_url(self.config.platform, repo_info, pr_number),
            platform=self.config.platform,
            raw_response=data,
        )

    def list_prs(
        self,
        repo_info: RepoInfo,
        state: str = "open",
        per_page: int = 30,
        page: int = 1
    ) -> List[PRInfo]:
        """List one page of a repository's pull requests.

        Returns an empty list (after logging) if the request fails.
        """
        url = f"{self.config.base_url}/repos/{repo_info.full_name}/pulls"
        params = {
            'state': state,
            'per_page': str(per_page),
            'page': str(page),
        }

        try:
            data = self._make_request('GET', url, params=params).json()
            return [self._parse_pr_info(pr) for pr in data]
        except Exception as e:
            logger.error(f"列出 GitHub PR 失败: {e}")
            return []
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
"""
|
|
2
|
+
GitLab API 提供者实现
|
|
3
|
+
"""
|
|
4
|
+
from typing import List
|
|
5
|
+
from .github_provider import GitHubProvider
|
|
6
|
+
from ..models import RepoInfo, PRData, PRResult, PRInfo
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class GitLabProvider(GitHubProvider):
    """GitLab provider built on top of the GitHub provider.

    GitLab calls pull requests "merge requests" (MRs).
    NOTE(review): every request path is inherited from ``GitHubProvider``
    (``/repos/{full_name}/pulls``), which is not GitLab's API shape - confirm
    before relying on this provider.
    """

    def _get_auth_header(self) -> str:
        """Return the ``Authorization`` header value using the Bearer scheme."""
        return "Bearer {}".format(self.config.token)

    def create_pr(self, repo_info: RepoInfo, pr_data: PRData) -> PRResult:
        """Create a merge request.

        TODO: placeholder - delegates to the GitHub implementation; a real
        version should call the GitLab API.
        """
        return super(GitLabProvider, self).create_pr(repo_info, pr_data)

    def list_prs(
        self,
        repo_info: RepoInfo,
        state: str = "opened",  # GitLab spells the open state "opened", not "open"
        per_page: int = 30,
        page: int = 1
    ) -> List[PRInfo]:
        """List one page of the repository's merge requests."""
        return super(GitLabProvider, self).list_prs(
            repo_info, state=state, per_page=per_page, page=page
        )
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1,344 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Pull Request 模块工具函数
|
|
3
|
+
"""
|
|
4
|
+
import re
|
|
5
|
+
import os
|
|
6
|
+
import subprocess
|
|
7
|
+
from typing import Optional, Tuple
|
|
8
|
+
from urllib.parse import urlparse
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
from loguru import logger
|
|
11
|
+
|
|
12
|
+
from .models import PlatformType, RepoInfo
|
|
13
|
+
from .exceptions import ValidationError
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def parse_git_url(url: str) -> Tuple[Optional[PlatformType], Optional[str], Optional[str]]:
    """
    Parse a Git remote URL into its platform type, owner and repository name.

    Accepted forms:
        * scp-like SSH:  ``git@github.com:owner/repo.git``
        * URL forms:     ``https://github.com/owner/repo(.git)``,
                         ``https://user:token@github.com/owner/repo.git``,
                         ``ssh://git@github.com:22/owner/repo.git``

    Returns:
        Tuple[platform, owner, repo]. For scp-like SSH URLs on an unknown host,
        platform is None but owner/repo are still returned; for any other
        unrecognised input all three values are None.
    """
    if not url:
        return None, None, None

    platform_domains = {
        'github.com': PlatformType.GITHUB,
        'gitlab.com': PlatformType.GITLAB,
        'gitee.com': PlatformType.GITEE,
        'gitcode.net': PlatformType.GITCODE,
    }

    # scp-like SSH syntax: git@domain:owner/repo.git
    ssh_match = re.match(r'^git@([^:]+):([^/]+)/([^/]+?)(?:\.git)?/?$', url)
    if ssh_match:
        domain, owner, repo = ssh_match.groups()
        return platform_domains.get(domain.lower()), owner, repo

    # URL syntax: scheme://[user[:password]@]domain[:port]/owner/repo[.git]
    try:
        parsed = urlparse(url)
    except ValueError as e:
        logger.error(f"解析Git URL失败: {e}")
        return None, None, None

    # ``hostname`` strips credentials and the port and lower-cases the host;
    # ``netloc`` would keep "user:token@" / ":22" and miss the domain lookup.
    domain = parsed.hostname
    platform = platform_domains.get(domain) if domain else None
    if not platform:
        return None, None, None

    path_parts = [p for p in parsed.path.split('/') if p]
    if len(path_parts) < 2:
        return None, None, None

    owner, repo = path_parts[0], path_parts[1]
    if repo.endswith('.git'):
        repo = repo[:-4]
    return platform, owner, repo
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def get_repo_remote_url(repo_path: str, remote_name: str = 'origin') -> Optional[str]:
    """
    Return the URL configured for ``remote_name`` in the repository at ``repo_path``.

    Args:
        repo_path: Working directory in which ``git`` is run.
        remote_name: Remote to query, ``'origin'`` by default.

    Returns:
        The remote URL, or None if the remote is not configured, the path is
        not a Git repository or does not exist, or ``git`` cannot be launched.
    """
    try:
        result = subprocess.run(
            ['git', 'remote', 'get-url', remote_name],
            cwd=repo_path,
            capture_output=True,
            text=True,
            check=True,
        )
    except (subprocess.CalledProcessError, OSError):
        # OSError covers a missing git executable or a non-existent cwd; both
        # previously escaped despite the Optional return contract.
        return None
    return result.stdout.strip() or None
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def get_default_remote_branch(repo_path: str, remote_name: str = 'origin') -> Optional[str]:
    """
    Determine the default branch of a remote.

    Strategies, in order:
        1. the local symbolic ref ``refs/remotes/<remote>/HEAD`` (no network);
        2. ``git ls-remote --symref <remote> HEAD``;
        3. the first of ``main``, ``master``, ``develop`` that exists on the remote
           (single ``ls-remote --heads`` call).

    Args:
        repo_path: Repository path.
        remote_name: Remote name, ``'origin'`` by default. Every strategy,
            including the fallback, queries this remote.

    Returns:
        The default branch name, or None if it cannot be determined (including
        when ``git`` cannot be launched or ``repo_path`` does not exist).
    """
    def _git(*args: str) -> Optional[str]:
        # Run git in repo_path and return stdout, or None on any failure.
        try:
            return subprocess.run(
                ['git', *args],
                cwd=repo_path,
                capture_output=True,
                text=True,
                check=True,
            ).stdout
        except (subprocess.CalledProcessError, OSError):
            return None

    # 1. Locally cached remote HEAD, output like "refs/remotes/origin/main".
    prefix = f'refs/remotes/{remote_name}/'
    head_ref = _git('symbolic-ref', f'{prefix}HEAD')
    if head_ref is not None:
        head_ref = head_ref.strip()
        if head_ref.startswith(prefix):
            return head_ref[len(prefix):]

    # 2. Ask the remote; the relevant line looks like "ref: refs/heads/main\tHEAD".
    symref_prefix = 'ref: refs/heads/'
    symref_out = _git('ls-remote', '--symref', remote_name, 'HEAD')
    if symref_out:
        for line in symref_out.splitlines():
            if line.startswith(symref_prefix):
                return line[len(symref_prefix):].split('\t')[0]

    # 3. Fall back to well-known default branch names present on the remote.
    heads_out = _git('ls-remote', '--heads', remote_name)
    if heads_out:
        remote_branches = {
            ref[len('refs/heads/'):]
            for _, _, ref in (line.partition('\t') for line in heads_out.splitlines())
            if ref.startswith('refs/heads/')
        }
        for branch in ('main', 'master', 'develop'):
            if branch in remote_branches:
                return branch

    return None
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def detect_platform_from_repo(repo_path: str) -> Optional[PlatformType]:
    """Infer the hosting platform from the ``origin`` remote of ``repo_path``.

    Returns None when no origin URL is available or its host is not one of
    the platforms recognised by ``parse_git_url``.
    """
    origin_url = get_repo_remote_url(repo_path)
    return parse_git_url(origin_url)[0] if origin_url else None
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
def get_repo_info_from_path(repo_path: str) -> Optional[RepoInfo]:
    """Build a ``RepoInfo`` from the ``origin`` remote of ``repo_path``.

    Returns None when there is no origin URL, or when the platform, owner or
    repository name cannot all be extracted from it.
    """
    origin_url = get_repo_remote_url(repo_path)
    if not origin_url:
        return None

    platform, owner, name = parse_git_url(origin_url)
    if platform and owner and name:
        return RepoInfo(platform=platform, owner=owner, name=name)
    return None
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
def get_current_branch(repo_path: str) -> Optional[str]:
    """
    Return the name of the currently checked-out branch.

    Returns:
        The branch name, or None when HEAD is detached (``git branch
        --show-current`` then prints nothing), when the command fails (not a
        repository, or a git version without ``--show-current``), or when
        ``git`` cannot be launched / ``repo_path`` does not exist.
    """
    try:
        result = subprocess.run(
            ['git', 'branch', '--show-current'],
            cwd=repo_path,
            capture_output=True,
            text=True,
            check=True,
        )
    except (subprocess.CalledProcessError, OSError):
        return None
    # An empty string (detached HEAD) is not a branch name.
    return result.stdout.strip() or None
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
def branch_exists(
    repo_path: str,
    branch_name: str,
    remote: bool = False,
    remote_name: str = 'origin',
) -> bool:
    """
    Check whether a branch exists locally or on a remote.

    Args:
        repo_path: Repository path.
        branch_name: Short branch name, e.g. ``'main'`` or ``'feature/x'``.
        remote: If True, query the remote with ``git ls-remote``; otherwise
            check the local ``refs/heads`` namespace.
        remote_name: Remote queried when ``remote`` is True (default ``'origin'``).

    Returns:
        True if the exact branch exists. False if it does not, or if git fails,
        cannot be launched, or ``repo_path`` does not exist.
    """
    ref = f'refs/heads/{branch_name}'
    try:
        if remote:
            # Query with the fully-qualified ref and compare exactly. A bare
            # name is tail-matched by ls-remote, so "main" would also match
            # "refs/heads/feature/main".
            result = subprocess.run(
                ['git', 'ls-remote', '--heads', remote_name, ref],
                cwd=repo_path,
                capture_output=True,
                text=True,
                check=True,
            )
            return any(
                line.split('\t')[-1] == ref for line in result.stdout.splitlines()
            )

        result = subprocess.run(
            ['git', 'show-ref', '--verify', '--quiet', ref],
            cwd=repo_path,
            capture_output=True,
        )
        return result.returncode == 0
    except (subprocess.CalledProcessError, OSError):
        return False
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
def is_git_repo(path: str) -> bool:
    """Report whether ``path`` directly contains a ``.git`` entry.

    The entry may be a directory or a file (e.g. the ``.git`` file used by
    worktrees); either counts. Parent directories are not searched.
    """
    return Path(path).joinpath('.git').exists()
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
def push_branch_to_remote(repo_path: str, branch_name: str, remote_name: str = 'origin') -> bool:
    """
    Push a branch to a remote with ``git push <remote> <branch>``.

    Args:
        repo_path: Repository path.
        branch_name: Branch to push.
        remote_name: Target remote, ``'origin'`` by default.

    Returns:
        True if the push succeeded. False if git reported an error, or if git
        could not be launched / ``repo_path`` does not exist (the error is logged).
    """
    logger.info(f"正在推送分支 '{branch_name}' 到远程仓库...")
    try:
        subprocess.run(
            ['git', 'push', remote_name, branch_name],
            cwd=repo_path,
            capture_output=True,
            text=True,
            check=True,
        )
    except subprocess.CalledProcessError as e:
        logger.error(f"推送分支 '{branch_name}' 失败: {e.stderr}")
        return False
    except OSError as e:
        # Missing git executable or invalid cwd: honour the bool contract.
        logger.error(f"推送分支 '{branch_name}' 失败: {e}")
        return False

    logger.info(f"分支 '{branch_name}' 推送成功")
    return True
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
def ensure_branch_exists_remotely(repo_path: str, branch_name: str, remote_name: str = 'origin') -> bool:
    """
    Make sure a branch is present on the remote, pushing it if necessary.

    Args:
        repo_path: Repository path.
        branch_name: Branch to check / push.
        remote_name: Remote the branch is pushed to, ``'origin'`` by default.

    Returns:
        True if the branch already existed remotely or was pushed successfully;
        False if it exists neither remotely nor locally, or the push failed.
    """
    # NOTE(review): the remote-existence check calls branch_exists without a
    # remote name, so it targets 'origin' rather than remote_name - results
    # may be wrong when pushing to another remote.
    if branch_exists(repo_path, branch_name, remote=True):
        logger.debug(f"分支 '{branch_name}' 已存在于远程仓库")
        return True

    have_local_branch = branch_exists(repo_path, branch_name, remote=False)
    if have_local_branch:
        return push_branch_to_remote(repo_path, branch_name, remote_name)

    logger.error(f"分支 '{branch_name}' 在本地也不存在")
    return False
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
def is_main_branch(branch_name: str) -> bool:
    """
    Tell whether ``branch_name`` is one of the conventional long-lived branches.

    The comparison is case-insensitive against ``main``, ``master``,
    ``develop`` and ``dev``.

    Args:
        branch_name: Branch name to test.

    Returns:
        True for a main/integration branch, False otherwise.
    """
    protected = {'main', 'master', 'develop', 'dev'}
    return branch_name.lower() in protected
|
|
268
|
+
|
|
269
|
+
|
|
270
|
+
def create_and_checkout_branch(repo_path: str, branch_name: str) -> bool:
    """
    Create a new branch and switch to it (``git checkout -b <branch>``).

    Args:
        repo_path: Repository path.
        branch_name: Name of the new branch.

    Returns:
        True on success. False if git reported an error (e.g. the branch
        already exists), or if git could not be launched / ``repo_path`` does
        not exist (the error is logged).
    """
    logger.info(f"正在创建并切换到新分支: {branch_name}")
    try:
        subprocess.run(
            ['git', 'checkout', '-b', branch_name],
            cwd=repo_path,
            capture_output=True,
            text=True,
            check=True,
        )
    except subprocess.CalledProcessError as e:
        logger.error(f"创建分支 '{branch_name}' 失败: {e.stderr}")
        return False
    except OSError as e:
        # Missing git executable or invalid cwd: honour the bool contract.
        logger.error(f"创建分支 '{branch_name}' 失败: {e}")
        return False

    logger.info(f"成功创建并切换到分支: {branch_name}")
    return True
|
|
295
|
+
|
|
296
|
+
|
|
297
|
+
def generate_auto_branch_name() -> str:
    """
    Build a timestamped branch name of the form ``ac-<yyyyMMdd-HH-mm-ss>``.

    The timestamp is the current local (naive) time.

    Returns:
        The generated branch name, e.g. ``ac-20240131-09-05-07``.
    """
    from datetime import datetime as _dt
    return _dt.now().strftime('ac-%Y%m%d-%H-%M-%S')
|
|
307
|
+
|
|
308
|
+
|
|
309
|
+
def validate_repo_path(repo_path: str) -> str:
    """
    Validate a repository path and return it in absolute, resolved form.

    Raises:
        ValidationError: if the path is empty, does not exist, is not a
            directory, or has no ``.git`` entry (checked in that order).
    """
    if not repo_path:
        raise ValidationError("仓库路径不能为空")

    resolved = Path(repo_path).resolve()

    # Ordered checks; each predicate runs only if the previous ones passed.
    checks = (
        (resolved.exists, f"仓库路径不存在: {resolved}"),
        (resolved.is_dir, f"仓库路径不是目录: {resolved}"),
        (lambda: is_git_repo(str(resolved)), f"路径不是Git仓库: {resolved}"),
    )
    for passes, failure_message in checks:
        if not passes():
            raise ValidationError(failure_message)

    return str(resolved)
|
|
326
|
+
|
|
327
|
+
|
|
328
|
+
def build_pr_url(platform: PlatformType, repo_info: RepoInfo, pr_number: int) -> str:
    """Build the web URL of a pull/merge request.

    GitLab and GitCode use ``/-/merge_requests/<n>``; the other platforms use
    ``/pull/<n>``. Returns an empty string for an unknown platform.
    """
    hosts = {
        PlatformType.GITHUB: "https://github.com",
        PlatformType.GITLAB: "https://gitlab.com",
        PlatformType.GITEE: "https://gitee.com",
        PlatformType.GITCODE: "https://gitcode.net",
    }

    base_url = hosts.get(platform)
    if not base_url:
        return ""

    if platform in (PlatformType.GITLAB, PlatformType.GITCODE):
        return f"{base_url}/{repo_info.full_name}/-/merge_requests/{pr_number}"
    return f"{base_url}/{repo_info.full_name}/pull/{pr_number}"
|
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Tokens 模块 - 高效统计文件和目录中的 token 数量
|
|
3
|
+
|
|
4
|
+
提供了简单易用的接口,支持正则过滤和智能文件类型识别。
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from typing import Optional

from .models import TokenResult, DirectoryTokenResult
from .counter import TokenCounter
from .file_detector import FileTypeDetector
from .filters import FileFilter
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def count_file_tokens(file_path: str) -> TokenResult:
    """
    Count the tokens of a single file.

    A fresh ``TokenCounter`` is created for each call.

    Args:
        file_path: Path of the file to count.

    Returns:
        TokenResult: the result of ``TokenCounter.count_file``.
    """
    return TokenCounter().count_file(file_path)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def count_directory_tokens(
    dir_path: str,
    pattern: Optional[str] = None,
    exclude_pattern: Optional[str] = None,
    recursive: bool = True
) -> DirectoryTokenResult:
    """
    Count the tokens of the files in a directory.

    A fresh ``TokenCounter`` is created for each call; all arguments are
    forwarded unchanged to ``TokenCounter.count_directory``.

    Args:
        dir_path: Directory to scan.
        pattern: Optional regular expression for file names to include
            (None is passed through as-is).
        exclude_pattern: Optional regular expression for file names to
            exclude (None is passed through as-is).
        recursive: Whether to descend into sub-directories.

    Returns:
        DirectoryTokenResult: the aggregated directory result.
    """
    counter = TokenCounter()
    return counter.count_directory(
        dir_path=dir_path,
        pattern=pattern,
        exclude_pattern=exclude_pattern,
        recursive=recursive,
    )
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
# Public API of the tokens package.
__all__ = [
    # result models
    'TokenResult', 'DirectoryTokenResult',
    # counting / detection / filtering classes
    'TokenCounter', 'FileTypeDetector', 'FileFilter',
    # convenience functions defined in this module
    'count_file_tokens', 'count_directory_tokens',
]
|