auto-coder 0.1.399__py3-none-any.whl → 0.1.400__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of auto-coder might be problematic; see the package registry's release page for details.
- {auto_coder-0.1.399.dist-info → auto_coder-0.1.400.dist-info}/METADATA +1 -1
- {auto_coder-0.1.399.dist-info → auto_coder-0.1.400.dist-info}/RECORD +38 -19
- autocoder/auto_coder_runner.py +2 -1
- autocoder/common/ac_style_command_parser/parser.py +27 -12
- autocoder/common/auto_coder_lang.py +78 -0
- autocoder/common/command_completer_v2.py +1 -1
- autocoder/common/pull_requests/__init__.py +256 -0
- autocoder/common/pull_requests/base_provider.py +191 -0
- autocoder/common/pull_requests/config.py +66 -0
- autocoder/common/pull_requests/example.py +1 -0
- autocoder/common/pull_requests/exceptions.py +46 -0
- autocoder/common/pull_requests/manager.py +201 -0
- autocoder/common/pull_requests/models.py +164 -0
- autocoder/common/pull_requests/providers/__init__.py +23 -0
- autocoder/common/pull_requests/providers/gitcode_provider.py +19 -0
- autocoder/common/pull_requests/providers/gitee_provider.py +20 -0
- autocoder/common/pull_requests/providers/github_provider.py +214 -0
- autocoder/common/pull_requests/providers/gitlab_provider.py +29 -0
- autocoder/common/pull_requests/test_module.py +1 -0
- autocoder/common/pull_requests/utils.py +344 -0
- autocoder/common/tokens/__init__.py +62 -0
- autocoder/common/tokens/counter.py +211 -0
- autocoder/common/tokens/file_detector.py +105 -0
- autocoder/common/tokens/filters.py +111 -0
- autocoder/common/tokens/models.py +28 -0
- autocoder/common/v2/agent/agentic_edit.py +182 -68
- autocoder/common/v2/agent/agentic_edit_types.py +1 -0
- autocoder/sdk/cli/handlers.py +2 -1
- autocoder/sdk/cli/main.py +4 -2
- autocoder/sdk/cli/options.py +4 -3
- autocoder/sdk/core/auto_coder_core.py +14 -1
- autocoder/sdk/core/bridge.py +3 -0
- autocoder/sdk/models/options.py +8 -6
- autocoder/version.py +1 -1
- {auto_coder-0.1.399.dist-info → auto_coder-0.1.400.dist-info}/WHEEL +0 -0
- {auto_coder-0.1.399.dist-info → auto_coder-0.1.400.dist-info}/entry_points.txt +0 -0
- {auto_coder-0.1.399.dist-info → auto_coder-0.1.400.dist-info}/licenses/LICENSE +0 -0
- {auto_coder-0.1.399.dist-info → auto_coder-0.1.400.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,191 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Pull Request 平台提供者基类
|
|
3
|
+
"""
|
|
4
|
+
from abc import ABC, abstractmethod
|
|
5
|
+
from typing import Optional, Dict, Any, List
|
|
6
|
+
import requests
|
|
7
|
+
import time
|
|
8
|
+
from loguru import logger
|
|
9
|
+
|
|
10
|
+
from .models import RepoInfo, PRData, PRResult, PRInfo, PRConfig
|
|
11
|
+
from .exceptions import (
|
|
12
|
+
PRError, NetworkError, AuthenticationError, RateLimitError,
|
|
13
|
+
RepositoryNotFoundError, ValidationError
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class BasePlatformProvider(ABC):
    """Base class for code-hosting platform providers; defines the unified PR interface.

    Subclasses supply the ``Authorization`` header value and implement the PR
    operations. This base class owns the shared HTTP session, the request/retry
    loop, the mapping of HTTP error statuses to exceptions, and basic input
    validation.
    """

    # Fallback wait (seconds) when a 429 response has no usable Retry-After header.
    _DEFAULT_RETRY_AFTER = 60

    def __init__(self, config: PRConfig):
        """Store *config* and build the HTTP session used for every request."""
        self.config = config
        self.session = self._create_session()

    def _create_session(self) -> requests.Session:
        """Create the HTTP session with JSON headers, the auth header and the SSL setting."""
        session = requests.Session()
        session.headers.update({
            'User-Agent': 'AutoCoder-PR/1.0',
            'Authorization': self._get_auth_header(),
            'Accept': 'application/json',
            'Content-Type': 'application/json'
        })
        session.verify = self.config.verify_ssl
        return session

    @abstractmethod
    def _get_auth_header(self) -> str:
        """Return the value of the ``Authorization`` request header."""
        pass

    @classmethod
    def _parse_retry_after(cls, header_value: Optional[str]) -> int:
        """Convert a ``Retry-After`` header value into a non-negative number of seconds.

        Only the delta-seconds form is interpreted. A missing header, the
        HTTP-date form, or any other unparsable value yields the default
        instead of raising. Negative values are clamped to 0 so that
        ``time.sleep`` never receives a negative argument.
        """
        try:
            seconds = int(header_value)
        except (TypeError, ValueError):
            return cls._DEFAULT_RETRY_AFTER
        return max(seconds, 0)

    @staticmethod
    def _json_body(response: requests.Response) -> Any:
        """Return the decoded JSON body, or ``None`` if the body is not valid JSON."""
        try:
            return response.json()
        except ValueError:
            # requests' JSONDecodeError is a ValueError subclass. It must be caught
            # here: it is also a RequestException and would otherwise be retried
            # as a network failure.
            return None

    @classmethod
    def _describe_validation_failure(cls, response: requests.Response) -> str:
        """Build the ValidationError message for an HTTP 422 response."""
        error_data = cls._json_body(response)
        if not isinstance(error_data, dict):
            return f"请求验证失败: HTTP 422 - {response.text}"
        if 'errors' in error_data:
            errors = error_data['errors']
            if not isinstance(errors, list):
                errors = [errors]
            error_msgs = []
            for error in errors:
                if isinstance(error, dict):
                    field_name = error.get('field', '')
                    code = error.get('code', '')
                    # str() guards against a null or structured 'message' breaking the join.
                    message = str(error.get('message', error))
                    if field_name and code:
                        error_msgs.append(f"{field_name}: {message} (code: {code})")
                    else:
                        error_msgs.append(message)
                else:
                    error_msgs.append(str(error))
            return f"请求验证失败: {'; '.join(error_msgs)}"
        if 'message' in error_data:
            return f"请求验证失败: {error_data['message']}"
        return f"请求验证失败: {response.text}"

    @classmethod
    def _describe_http_failure(cls, response: requests.Response) -> str:
        """Build the PRError message for any other HTTP status >= 400."""
        error_data = cls._json_body(response)
        if isinstance(error_data, dict) and 'message' in error_data:
            return f"API请求失败: HTTP {response.status_code} - {error_data['message']}"
        return f"API请求失败: HTTP {response.status_code} - {response.text}"

    def _raise_for_error_status(self, response: requests.Response) -> None:
        """Raise the module exception matching an error status; return for status < 400."""
        status = response.status_code
        if status == 401:
            raise AuthenticationError("认证失败,请检查token是否正确")
        if status == 404:
            raise RepositoryNotFoundError("仓库或资源不存在")
        if status == 429:
            retry_after = self._parse_retry_after(response.headers.get('Retry-After'))
            raise RateLimitError(f"API限流,请等待{retry_after}秒后重试", retry_after=retry_after)
        if status == 422:
            raise ValidationError(self._describe_validation_failure(response))
        if status >= 400:
            raise PRError(self._describe_http_failure(response))

    def _make_request(
        self,
        method: str,
        url: str,
        data: Optional[Dict[str, Any]] = None,
        params: Optional[Dict[str, str]] = None,
        retry_count: Optional[int] = None
    ) -> requests.Response:
        """Send an HTTP request, retrying transport failures and rate limiting.

        Transport failures are retried with exponential backoff (1s, 2s, 4s, ...).
        HTTP 429 responses are retried after the server-advertised delay.

        Args:
            method: HTTP verb, e.g. ``'GET'``.
            url: Absolute request URL.
            data: JSON-serialisable request body, if any.
            params: Query-string parameters, if any.
            retry_count: Retries after the first attempt; defaults to
                ``self.config.retry_count``.

        Returns:
            The response; its status code is below 400.

        Raises:
            AuthenticationError: HTTP 401.
            RepositoryNotFoundError: HTTP 404.
            RateLimitError: HTTP 429 once retries are exhausted.
            ValidationError: HTTP 422.
            PRError: any other status >= 400, or when ``retry_count`` is negative.
            NetworkError: transport failure once retries are exhausted.
        """
        if retry_count is None:
            retry_count = self.config.retry_count

        for attempt in range(retry_count + 1):
            try:
                response = self.session.request(
                    method=method,
                    url=url,
                    json=data,
                    params=params,
                    timeout=self.config.timeout
                )
                self._raise_for_error_status(response)
                return response

            except (requests.exceptions.RequestException, ConnectionError) as e:
                if attempt < retry_count:
                    delay = 2 ** attempt
                    logger.warning(f"请求失败,{delay}秒后重试")
                    time.sleep(delay)
                    continue
                raise NetworkError(f"网络请求失败: {str(e)}") from e

            except RateLimitError as e:
                if attempt < retry_count:
                    logger.warning(f"遇到限流,等待{e.retry_after}秒后重试")
                    time.sleep(e.retry_after)
                    continue
                raise

        # Only reachable when retry_count < 0, i.e. no attempt was made.
        raise PRError("请求失败,已达到最大重试次数")

    def _validate_pr_data(self, pr_data: PRData) -> None:
        """Validate PR data before it is sent.

        Raises:
            ValidationError: the title or either branch is blank, or the source
                and target branches are equal once surrounding whitespace is removed.
        """
        if not (pr_data.title or "").strip():
            raise ValidationError("PR标题不能为空")
        source = (pr_data.source_branch or "").strip()
        target = (pr_data.target_branch or "").strip()
        if not source:
            raise ValidationError("源分支不能为空")
        if not target:
            raise ValidationError("目标分支不能为空")
        if source == target:
            raise ValidationError("源分支和目标分支不能相同")

    def _validate_repo_info(self, repo_info: RepoInfo) -> None:
        """Validate repository info.

        Raises:
            ValidationError: the owner or the repository name is blank.
        """
        if not repo_info.owner.strip():
            raise ValidationError("仓库所有者不能为空")
        if not repo_info.name.strip():
            raise ValidationError("仓库名称不能为空")

    @abstractmethod
    def create_pr(self, repo_info: RepoInfo, pr_data: PRData) -> PRResult:
        """Create a Pull Request."""
        pass

    @abstractmethod
    def get_pr(self, repo_info: RepoInfo, pr_number: int) -> PRResult:
        """Fetch information about a Pull Request."""
        pass

    @abstractmethod
    def update_pr(self, repo_info: RepoInfo, pr_number: int, **kwargs) -> PRResult:
        """Update a Pull Request with the fields given in ``kwargs``."""
        pass

    @abstractmethod
    def close_pr(self, repo_info: RepoInfo, pr_number: int) -> PRResult:
        """Close a Pull Request."""
        pass

    @abstractmethod
    def merge_pr(self, repo_info: RepoInfo, pr_number: int, **kwargs) -> PRResult:
        """Merge a Pull Request; ``kwargs`` carry platform-specific merge options."""
        pass

    @abstractmethod
    def list_prs(
        self,
        repo_info: RepoInfo,
        state: str = "open",
        per_page: int = 30,
        page: int = 1
    ) -> List[PRInfo]:
        """List one page of the repository's Pull Requests filtered by ``state``."""
        pass

    def health_check(self) -> bool:
        """Check connectivity and authentication.

        Returns:
            True when ``GET {base_url}/user`` answers 200; False on any error.
        """
        try:
            response = self._make_request('GET', f"{self.config.base_url}/user")
            return response.status_code == 200
        except Exception:
            return False
|
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Pull Request 配置管理
|
|
3
|
+
"""
|
|
4
|
+
import os
|
|
5
|
+
from typing import Dict, Any, Optional
|
|
6
|
+
from .models import PRConfig, PlatformType
|
|
7
|
+
from .exceptions import ConfigurationError
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def get_config(platform: str, **overrides) -> PRConfig:
    """
    Build the configuration for a platform.

    Values come from the platform's environment variables (see
    ``_load_from_env``). Keyword ``overrides`` take precedence over them.

    Args:
        platform: Platform name, e.g. ``"github"``.
        **overrides: ``PRConfig`` fields overriding the environment values.

    Returns:
        The resulting ``PRConfig``.

    Raises:
        ConfigurationError: no non-empty token is available.
        ValueError: raised by ``PlatformType(platform)`` for an unknown platform name.
    """
    merged_config: Dict[str, Any] = dict(_load_from_env(platform))
    merged_config.update(overrides)

    # A None/empty token (e.g. token=None passed as an override) is as unusable
    # as a missing one, so reject it here instead of failing later with a 401.
    if not merged_config.get('token'):
        raise ConfigurationError(f"平台 {platform} 缺少必需的 token 配置")

    return PRConfig(platform=PlatformType(platform), **merged_config)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def _load_from_env(platform: str) -> Dict[str, Any]:
    """Collect PR settings for *platform* from environment variables.

    Only variables set to a non-empty value are included. An unknown platform
    name yields an empty dict.

    Returns:
        A dict with some subset of the keys ``'token'`` and ``'base_url'``.
    """
    variables_by_platform = {
        'github': (('token', 'GITHUB_TOKEN'), ('base_url', 'GITHUB_BASE_URL')),
        'gitlab': (('token', 'GITLAB_TOKEN'), ('base_url', 'GITLAB_BASE_URL')),
        'gitee': (('token', 'GITEE_TOKEN'), ('base_url', 'GITEE_BASE_URL')),
        'gitcode': (('token', 'GITCODE_TOKEN'), ('base_url', 'GITCODE_BASE_URL')),
    }

    loaded: Dict[str, Any] = {}
    for setting, variable_name in variables_by_platform.get(platform, ()):
        raw_value = os.getenv(variable_name)
        if not raw_value:
            continue
        loaded[setting] = raw_value
    return loaded
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Pull Request 模块自定义异常类
|
|
3
|
+
"""
|
|
4
|
+
from typing import Optional
|
|
5
|
+
|
|
6
|
+
class PRError(Exception):
    """Root of the Pull Request exception hierarchy.

    Attributes:
        message: Human-readable description; also the exception's ``str()``.
        error_code: Optional machine-readable code.
        platform: Optional name of the platform involved.
    """

    def __init__(self, message: str, error_code: Optional[str] = None, platform: Optional[str] = None):
        super().__init__(message)
        self.message, self.error_code, self.platform = message, error_code, platform
|
|
13
|
+
|
|
14
|
+
class AuthenticationError(PRError):
    """Raised when the platform rejects the configured credentials."""
|
|
17
|
+
|
|
18
|
+
class RepositoryNotFoundError(PRError):
    """Raised when the repository (or another requested resource) does not exist."""
|
|
21
|
+
|
|
22
|
+
class BranchNotFoundError(PRError):
    """Raised when a referenced branch does not exist."""
|
|
25
|
+
|
|
26
|
+
class NetworkError(PRError):
    """Raised when the platform API cannot be reached."""
|
|
29
|
+
|
|
30
|
+
class RateLimitError(PRError):
    """Raised when the platform API rejects a request due to rate limiting.

    Attributes:
        retry_after: Seconds to wait before retrying (60 unless given).
    """

    def __init__(self, message: str, retry_after: int = 60, **kwargs):
        super().__init__(message, **kwargs)
        self.retry_after = retry_after
|
|
35
|
+
|
|
36
|
+
class ValidationError(PRError):
    """Raised when request parameters fail local or server-side validation."""
|
|
39
|
+
|
|
40
|
+
class PlatformNotSupportedError(PRError):
    """Raised when a platform is unknown or cannot be determined."""
|
|
43
|
+
|
|
44
|
+
class ConfigurationError(PRError):
    """Raised when required configuration is missing or invalid."""
|
|
@@ -0,0 +1,201 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Pull Request 主管理器
|
|
3
|
+
"""
|
|
4
|
+
from typing import Optional, List, Dict, Any
|
|
5
|
+
from loguru import logger
|
|
6
|
+
|
|
7
|
+
from .models import PRConfig, PRData, PRResult, PRInfo, RepoInfo, PlatformType, DEFAULT_TEMPLATES
|
|
8
|
+
from .providers import PROVIDERS
|
|
9
|
+
from .utils import detect_platform_from_repo, get_repo_info_from_path, validate_repo_path
|
|
10
|
+
from .exceptions import (
|
|
11
|
+
PRError, PlatformNotSupportedError, ConfigurationError, ValidationError
|
|
12
|
+
)
|
|
13
|
+
from .config import get_config
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class PullRequestManager:
    """High-level facade for Pull Request operations.

    Every public operation resolves repository information from a local path,
    chooses a platform provider and delegates to it. Errors are logged and
    reported through the return value (an unsuccessful ``PRResult``, an empty
    list, or ``False``) instead of being raised.
    """

    def __init__(self, config: Optional[PRConfig] = None):
        """
        Args:
            config: Optional default configuration. Its platform is used when an
                operation is called without an explicit ``platform``.
        """
        self.config = config
        # Neither attribute is read by the methods of this class.
        # NOTE(review): kept for compatibility; confirm whether external code uses them.
        self._provider = None
        self._templates = DEFAULT_TEMPLATES.copy()

    def _get_provider(self, platform: Optional[str] = None, repo_path: Optional[str] = None):
        """Instantiate the provider for the requested platform.

        The platform is chosen in this order:

        1. The explicit ``platform`` argument.
        2. The platform of ``self.config``.
        3. Auto-detection from ``repo_path``.

        ``self.config`` is reused only when its platform matches; otherwise the
        configuration is loaded with ``get_config``.

        Raises:
            PlatformNotSupportedError: the platform cannot be detected or has no provider.
            ValidationError: neither a platform nor a repository path is available.
            ConfigurationError: no configuration could be loaded for the platform.
        """
        if self.config and not platform:
            platform = self.config.platform.value
        elif not platform and repo_path:
            detected_platform = detect_platform_from_repo(repo_path)
            if not detected_platform:
                raise PlatformNotSupportedError("无法从仓库路径检测平台类型")
            platform = detected_platform.value
        elif not platform:
            raise ValidationError("必须指定平台类型或提供仓库路径")

        if platform not in PROVIDERS:
            raise PlatformNotSupportedError(f"不支持的平台: {platform}")

        if self.config and self.config.platform.value == platform:
            config = self.config
        else:
            try:
                config = get_config(platform)
            except ConfigurationError as e:
                raise ConfigurationError(f"平台 {platform} 的配置未找到") from e

        provider_class = PROVIDERS[platform]
        return provider_class(config)

    @staticmethod
    def _resolve_repo(repo_path: str):
        """Validate *repo_path* and read its repository info.

        Returns:
            A ``(validated_repo_path, repo_info)`` tuple.

        Raises:
            ValidationError: repository info could not be determined.
        """
        repo_path = validate_repo_path(repo_path)
        repo_info = get_repo_info_from_path(repo_path)
        if not repo_info:
            raise ValidationError("无法获取仓库信息")
        return repo_path, repo_info

    @staticmethod
    def _failure_result(log_prefix: str, error: Exception, platform: Optional[str]) -> PRResult:
        """Log *error* and wrap it in an unsuccessful ``PRResult``.

        ``platform`` is converted to ``PlatformType`` only when it is a valid
        value. An unknown platform name (the usual reason for the failure)
        therefore cannot raise ``ValueError`` from inside the caller's
        ``except`` handler and mask the original error.
        """
        logger.error(f"{log_prefix}: {error}")
        platform_type = None
        if platform:
            try:
                platform_type = PlatformType(platform)
            except ValueError:
                platform_type = None
        return PRResult(
            success=False,
            error_message=str(error),
            platform=platform_type
        )

    def create_pull_request(
        self,
        repo_path: str,
        source_branch: str,
        target_branch: str,
        title: str,
        description: str = "",
        labels: Optional[List[str]] = None,
        assignees: Optional[List[str]] = None,
        reviewers: Optional[List[str]] = None,
        draft: bool = False,
        template_type: Optional[str] = None,
        template_vars: Optional[Dict[str, str]] = None,
        platform: Optional[str] = None,
        **kwargs
    ) -> PRResult:
        """Create a Pull Request from ``source_branch`` into ``target_branch``.

        ``kwargs`` are accepted but currently ignored.

        Returns:
            The provider's result, or an unsuccessful ``PRResult`` on any error.
        """
        try:
            repo_path, repo_info = self._resolve_repo(repo_path)
            provider = self._get_provider(platform, repo_path)

            pr_data = PRData(
                title=title,
                description=description,
                source_branch=source_branch,
                target_branch=target_branch,
                labels=labels or [],
                assignees=assignees or [],
                reviewers=reviewers or [],
                draft=draft,
                template_type=template_type,
                template_vars=template_vars or {}
            )
            return provider.create_pr(repo_info, pr_data)

        except Exception as e:
            return self._failure_result("创建PR失败", e, platform)

    def get_pull_request(
        self,
        repo_path: str,
        pr_number: int,
        platform: Optional[str] = None
    ) -> PRResult:
        """Fetch a Pull Request; returns an unsuccessful ``PRResult`` on any error."""
        try:
            repo_path, repo_info = self._resolve_repo(repo_path)
            provider = self._get_provider(platform, repo_path)
            return provider.get_pr(repo_info, pr_number)

        except Exception as e:
            return self._failure_result("获取PR失败", e, platform)

    def update_pull_request(
        self,
        repo_path: str,
        pr_number: int,
        platform: Optional[str] = None,
        **kwargs
    ) -> PRResult:
        """Update a Pull Request with ``kwargs``.

        Returns an unsuccessful ``PRResult`` on any error.
        """
        try:
            repo_path, repo_info = self._resolve_repo(repo_path)
            provider = self._get_provider(platform, repo_path)
            return provider.update_pr(repo_info, pr_number, **kwargs)

        except Exception as e:
            return self._failure_result("更新PR失败", e, platform)

    def list_pull_requests(
        self,
        repo_path: str,
        state: str = "open",
        per_page: int = 30,
        page: int = 1,
        platform: Optional[str] = None
    ) -> List[PRInfo]:
        """List one page of Pull Requests; returns ``[]`` on any error (the error is logged)."""
        try:
            repo_path, repo_info = self._resolve_repo(repo_path)
            provider = self._get_provider(platform, repo_path)
            return provider.list_prs(repo_info, state, per_page, page)

        except Exception as e:
            logger.error(f"列出PR失败: {e}")
            return []

    def health_check(self, platform: Optional[str] = None, repo_path: Optional[str] = None) -> bool:
        """Return the provider's health check result, or False if the provider cannot be built."""
        try:
            provider = self._get_provider(platform, repo_path)
            return provider.health_check()
        except Exception as e:
            logger.error(f"健康检查失败: {e}")
            return False
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
# Module-wide default manager, created without a configuration.
_global_manager: PullRequestManager = PullRequestManager(config=None)
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
def set_global_config(config: PRConfig) -> None:
    """Replace the module-wide manager with a new one bound to *config*."""
    global _global_manager
    replacement = PullRequestManager(config)
    _global_manager = replacement
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
def get_global_manager() -> PullRequestManager:
    """Return the current module-wide ``PullRequestManager`` instance."""
    manager = _global_manager
    return manager
|
|
@@ -0,0 +1,164 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Pull Request 数据模型定义
|
|
3
|
+
"""
|
|
4
|
+
from typing import Optional, List, Dict, Any
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from enum import Enum
|
|
7
|
+
import os
|
|
8
|
+
import json
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class PlatformType(str, Enum):
    """Supported code-hosting platform types.

    Members mix in ``str``, so each compares equal to its lowercase value
    (``PlatformType.GITHUB == "github"``), and ``PlatformType("github")`` looks
    a member up by value.
    """
    GITHUB = "github"
    GITLAB = "gitlab"
    GITEE = "gitee"
    GITCODE = "gitcode"
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclass
class PRConfig:
    """Pull Request configuration for one platform.

    ``platform`` may be a ``PlatformType`` or its string value; it is normalised
    in ``__post_init__``. An empty or missing ``base_url`` is replaced by the
    platform's public API URL, and trailing slashes are stripped because
    request URLs are built as ``f"{base_url}/..."``. The token is excluded
    from ``repr()`` so that logging a config does not leak the credential.
    """
    platform: PlatformType
    # API access token; excluded from repr() to keep it out of logs and tracebacks.
    token: str = field(repr=False)
    base_url: Optional[str] = None
    timeout: int = 30  # request timeout, passed to requests (seconds)
    verify_ssl: bool = True
    retry_count: int = 3  # retries after the first attempt
    default_labels: List[str] = field(default_factory=list)
    default_assignees: List[str] = field(default_factory=list)

    # Platform-specific options; presumably consumed by individual providers — TODO confirm.
    draft: bool = False
    maintainer_can_modify: bool = True
    remove_source_branch: bool = False
    squash: bool = False

    def __post_init__(self):
        """Normalise ``platform`` to ``PlatformType`` and fill or clean up ``base_url``."""
        if isinstance(self.platform, str):
            self.platform = PlatformType(self.platform)
        if not self.base_url:
            self.base_url = self._get_default_base_url()
        # Avoid "//" in URLs such as f"{base_url}/user".
        self.base_url = self.base_url.rstrip('/')

    def _get_default_base_url(self) -> str:
        """Return the public API base URL for ``self.platform`` ("" if unknown)."""
        default_urls = {
            PlatformType.GITHUB: "https://api.github.com",
            PlatformType.GITLAB: "https://gitlab.com/api/v4",
            PlatformType.GITEE: "https://gitee.com/api/v5",
            PlatformType.GITCODE: "https://gitcode.net/api/v4"
        }
        return default_urls.get(self.platform, "")

    @classmethod
    def from_env(cls, platform: str) -> 'PRConfig':
        """Build a config from ``<PLATFORM>_TOKEN`` and, if set, ``<PLATFORM>_BASE_URL``.

        Raises:
            ValueError: the platform is unknown, or the token variable is unset or empty.
        """
        env_mappings = {
            PlatformType.GITHUB: ("GITHUB_TOKEN", "GITHUB_BASE_URL"),
            PlatformType.GITLAB: ("GITLAB_TOKEN", "GITLAB_BASE_URL"),
            PlatformType.GITEE: ("GITEE_TOKEN", "GITEE_BASE_URL"),
            PlatformType.GITCODE: ("GITCODE_TOKEN", "GITCODE_BASE_URL")
        }
        platform_type = PlatformType(platform)
        env_names = env_mappings.get(platform_type)
        if not env_names:
            raise ValueError(f"不支持的平台类型: {platform}")
        token_env, base_url_env = env_names
        token = os.getenv(token_env)
        if not token:
            raise ValueError(f"环境变量 {token_env} 未设置")
        # Honour the same base-URL variable as the env-based loader in config.py.
        # An unset variable yields None, which __post_init__ replaces with the default.
        return cls(platform=platform_type, token=token, base_url=os.getenv(base_url_env))
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
@dataclass
class RepoInfo:
    """Identifies a repository on a hosting platform.

    ``full_name`` is derived as ``"<owner>/<name>"`` after initialisation and
    is not a constructor argument.
    """
    owner: str
    name: str
    platform: PlatformType
    full_name: str = field(init=False)

    def __post_init__(self):
        self.full_name = "{}/{}".format(self.owner, self.name)
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
@dataclass
class PRData:
    """Input data for creating a Pull Request.

    Field order defines the positional constructor signature.
    """
    title: str
    description: str
    source_branch: str  # branch containing the changes
    target_branch: str  # branch the changes should be merged into
    labels: List[str] = field(default_factory=list)
    assignees: List[str] = field(default_factory=list)
    reviewers: List[str] = field(default_factory=list)
    draft: bool = False
    # Presumably a key of DEFAULT_TEMPLATES (e.g. "bug_fix", "feature") — TODO confirm where it is applied.
    template_type: Optional[str] = None
    # Presumably values for the {placeholders} of the chosen template — TODO confirm.
    template_vars: Dict[str, str] = field(default_factory=dict)
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
@dataclass
class PRResult:
    """Outcome of a Pull Request operation.

    ``success`` is the only required field. Failures typically populate
    ``error_message`` and leave the PR fields as ``None``.
    """
    success: bool
    pr_number: Optional[int] = None
    pr_url: Optional[str] = None
    pr_id: Optional[str] = None  # platform-internal id, presumably distinct from pr_number — TODO confirm
    error_message: Optional[str] = None
    error_code: Optional[str] = None
    platform: Optional[PlatformType] = None
    raw_response: Optional[Dict[str, Any]] = None  # presumably the decoded API payload — TODO confirm
    retry_after: Optional[int] = None  # presumably seconds to wait after rate limiting — TODO confirm
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
@dataclass
class PRInfo:
    """Detailed information about one Pull Request (the item type returned by ``list_prs``).

    Field order defines the positional constructor signature.
    """
    number: int
    title: str
    description: str
    state: str  # platform state string such as "open"; exact values are platform-specific
    source_branch: str
    target_branch: str
    author: str
    created_at: str  # timestamp string; format presumably platform-defined (ISO 8601?) — TODO confirm
    updated_at: str
    merged_at: Optional[str] = None  # None when the PR has not been merged
    pr_url: str = ""
    labels: List[str] = field(default_factory=list)
    assignees: List[str] = field(default_factory=list)
    reviewers: List[str] = field(default_factory=list)
    mergeable: Optional[bool] = None  # None when the platform has not reported mergeability
    draft: bool = False
    raw_data: Optional[Dict[str, Any]] = None  # presumably the raw API payload — TODO confirm
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
# Default PR templates, keyed by template type.
# Each entry has a "title_prefix" and a "description_template" whose {placeholders}
# are presumably filled from PRData.template_vars (e.g. via str.format) — TODO confirm.
DEFAULT_TEMPLATES = {
    # Bug fix: problem, solution and a testing checklist.
    "bug_fix": {
        "title_prefix": "🐛 Bug Fix:",
        "description_template": """
## 问题描述
{problem_description}

## 解决方案
{solution_description}

## 测试
- [ ] 单元测试通过
- [ ] 集成测试通过
- [ ] 手动测试验证
"""
    },
    # New feature: description, implementation details and usage examples.
    "feature": {
        "title_prefix": "✨ Feature:",
        "description_template": """
## 新功能说明
{feature_description}

## 实现细节
{implementation_details}

## 使用示例
{usage_examples}
"""
    }
}
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Pull Request 平台提供者模块
|
|
3
|
+
"""
|
|
4
|
+
from .github_provider import GitHubProvider
|
|
5
|
+
from .gitlab_provider import GitLabProvider
|
|
6
|
+
from .gitee_provider import GiteeProvider
|
|
7
|
+
from .gitcode_provider import GitCodeProvider
|
|
8
|
+
|
|
9
|
+
# Registry of provider classes keyed by platform identifier. The keys match the
# PlatformType values that PullRequestManager looks up.
PROVIDERS = dict(
    github=GitHubProvider,
    gitlab=GitLabProvider,
    gitee=GiteeProvider,
    gitcode=GitCodeProvider,
)

# Public names re-exported by this package.
__all__ = [
    'GitHubProvider',
    'GitLabProvider',
    'GiteeProvider',
    'GitCodeProvider',
    'PROVIDERS',
]
|