auto-coder 0.1.399__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of auto-coder might be problematic. Click here for more details.

Files changed (71)
  1. {auto_coder-0.1.399.dist-info → auto_coder-1.0.0.dist-info}/METADATA +1 -1
  2. {auto_coder-0.1.399.dist-info → auto_coder-1.0.0.dist-info}/RECORD +71 -35
  3. autocoder/agent/agentic_filter.py +1 -1
  4. autocoder/agent/base_agentic/tools/read_file_tool_resolver.py +1 -1
  5. autocoder/auto_coder_runner.py +121 -26
  6. autocoder/chat_auto_coder.py +81 -22
  7. autocoder/commands/auto_command.py +1 -1
  8. autocoder/common/__init__.py +2 -2
  9. autocoder/common/ac_style_command_parser/parser.py +27 -12
  10. autocoder/common/auto_coder_lang.py +78 -0
  11. autocoder/common/command_completer_v2.py +1 -1
  12. autocoder/common/file_monitor/test_file_monitor.py +307 -0
  13. autocoder/common/git_utils.py +7 -2
  14. autocoder/common/pruner/__init__.py +0 -0
  15. autocoder/common/pruner/agentic_conversation_pruner.py +197 -0
  16. autocoder/common/pruner/context_pruner.py +574 -0
  17. autocoder/common/pruner/conversation_pruner.py +132 -0
  18. autocoder/common/pruner/test_agentic_conversation_pruner.py +342 -0
  19. autocoder/common/pruner/test_context_pruner.py +546 -0
  20. autocoder/common/pull_requests/__init__.py +256 -0
  21. autocoder/common/pull_requests/base_provider.py +191 -0
  22. autocoder/common/pull_requests/config.py +66 -0
  23. autocoder/common/pull_requests/example.py +1 -0
  24. autocoder/common/pull_requests/exceptions.py +46 -0
  25. autocoder/common/pull_requests/manager.py +201 -0
  26. autocoder/common/pull_requests/models.py +164 -0
  27. autocoder/common/pull_requests/providers/__init__.py +23 -0
  28. autocoder/common/pull_requests/providers/gitcode_provider.py +19 -0
  29. autocoder/common/pull_requests/providers/gitee_provider.py +20 -0
  30. autocoder/common/pull_requests/providers/github_provider.py +214 -0
  31. autocoder/common/pull_requests/providers/gitlab_provider.py +29 -0
  32. autocoder/common/pull_requests/test_module.py +1 -0
  33. autocoder/common/pull_requests/utils.py +344 -0
  34. autocoder/common/tokens/__init__.py +77 -0
  35. autocoder/common/tokens/counter.py +231 -0
  36. autocoder/common/tokens/file_detector.py +105 -0
  37. autocoder/common/tokens/filters.py +111 -0
  38. autocoder/common/tokens/models.py +28 -0
  39. autocoder/common/v2/agent/agentic_edit.py +538 -590
  40. autocoder/common/v2/agent/agentic_edit_tools/__init__.py +8 -1
  41. autocoder/common/v2/agent/agentic_edit_tools/ac_mod_read_tool_resolver.py +40 -0
  42. autocoder/common/v2/agent/agentic_edit_tools/ac_mod_write_tool_resolver.py +43 -0
  43. autocoder/common/v2/agent/agentic_edit_tools/ask_followup_question_tool_resolver.py +8 -0
  44. autocoder/common/v2/agent/agentic_edit_tools/execute_command_tool_resolver.py +1 -1
  45. autocoder/common/v2/agent/agentic_edit_tools/read_file_tool_resolver.py +1 -1
  46. autocoder/common/v2/agent/agentic_edit_tools/search_files_tool_resolver.py +33 -88
  47. autocoder/common/v2/agent/agentic_edit_tools/test_write_to_file_tool_resolver.py +8 -8
  48. autocoder/common/v2/agent/agentic_edit_tools/todo_read_tool_resolver.py +118 -0
  49. autocoder/common/v2/agent/agentic_edit_tools/todo_write_tool_resolver.py +324 -0
  50. autocoder/common/v2/agent/agentic_edit_types.py +47 -4
  51. autocoder/common/v2/agent/runner/__init__.py +31 -0
  52. autocoder/common/v2/agent/runner/base_runner.py +106 -0
  53. autocoder/common/v2/agent/runner/event_runner.py +216 -0
  54. autocoder/common/v2/agent/runner/sdk_runner.py +40 -0
  55. autocoder/common/v2/agent/runner/terminal_runner.py +283 -0
  56. autocoder/common/v2/agent/runner/tool_display.py +191 -0
  57. autocoder/index/entry.py +1 -1
  58. autocoder/plugins/token_helper_plugin.py +107 -7
  59. autocoder/run_context.py +9 -0
  60. autocoder/sdk/__init__.py +114 -81
  61. autocoder/sdk/cli/handlers.py +2 -1
  62. autocoder/sdk/cli/main.py +9 -2
  63. autocoder/sdk/cli/options.py +4 -3
  64. autocoder/sdk/core/auto_coder_core.py +7 -152
  65. autocoder/sdk/core/bridge.py +5 -4
  66. autocoder/sdk/models/options.py +8 -6
  67. autocoder/version.py +1 -1
  68. {auto_coder-0.1.399.dist-info → auto_coder-1.0.0.dist-info}/WHEEL +0 -0
  69. {auto_coder-0.1.399.dist-info → auto_coder-1.0.0.dist-info}/entry_points.txt +0 -0
  70. {auto_coder-0.1.399.dist-info → auto_coder-1.0.0.dist-info}/licenses/LICENSE +0 -0
  71. {auto_coder-0.1.399.dist-info → auto_coder-1.0.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,201 @@
1
+ """
2
+ Pull Request 主管理器
3
+ """
4
+ from typing import Optional, List, Dict, Any
5
+ from loguru import logger
6
+
7
+ from .models import PRConfig, PRData, PRResult, PRInfo, RepoInfo, PlatformType, DEFAULT_TEMPLATES
8
+ from .providers import PROVIDERS
9
+ from .utils import detect_platform_from_repo, get_repo_info_from_path, validate_repo_path
10
+ from .exceptions import (
11
+ PRError, PlatformNotSupportedError, ConfigurationError, ValidationError
12
+ )
13
+ from .config import get_config
14
+
15
+
16
class PullRequestManager:
    """Facade that picks a platform provider and runs pull request operations.

    Public operations never propagate ``Exception`` subclasses: failures are
    logged and reported as a failed ``PRResult`` (or ``[]`` / ``False``).
    """

    def __init__(self, config: Optional[PRConfig] = None):
        """Create a manager, optionally bound to an explicit configuration.

        Args:
            config: Configuration used when an operation does not name a
                platform explicitly, and reused when it names the same one.
        """
        self.config = config
        self._provider = None
        # NOTE(review): templates are copied here but no method in view applies them.
        self._templates = DEFAULT_TEMPLATES.copy()

    @staticmethod
    def _platform_or_none(platform: Optional[str]) -> Optional[PlatformType]:
        """Convert *platform* to ``PlatformType`` for an error result, or ``None``.

        Called from ``except`` handlers, so it must never raise: an unknown
        platform string (possibly the very cause of the failure) maps to
        ``None`` instead of raising ``ValueError``.
        """
        if not platform:
            return None
        try:
            return PlatformType(platform)
        except ValueError:
            return None

    def _failure(self, action: str, error: Exception, platform: Optional[str]) -> PRResult:
        """Log *error* prefixed by *action* and build a failed ``PRResult``."""
        logger.error(f"{action}: {error}")
        return PRResult(
            success=False,
            error_message=str(error),
            platform=self._platform_or_none(platform)
        )

    @staticmethod
    def _resolve_repo(repo_path: str):
        """Validate *repo_path* and read its repository information.

        Returns:
            A ``(validated_path, repo_info)`` tuple.

        Raises:
            ValidationError: if the repository information cannot be obtained
                (plus whatever ``validate_repo_path`` raises).
        """
        validated_path = validate_repo_path(repo_path)
        repo_info = get_repo_info_from_path(validated_path)
        if not repo_info:
            raise ValidationError("无法获取仓库信息")
        return validated_path, repo_info

    def _get_provider(self, platform: Optional[str] = None, repo_path: Optional[str] = None):
        """Return a provider instance for the resolved platform.

        The platform is resolved in this order: the explicit *platform*
        argument, then ``self.config.platform``, then detection from
        *repo_path*.

        Raises:
            PlatformNotSupportedError: detection failed or the platform has
                no registered provider.
            ValidationError: neither a platform nor a repository path is known.
            ConfigurationError: no configuration exists for the platform.
        """
        if self.config and not platform:
            platform = self.config.platform.value
        elif not platform and repo_path:
            # Auto-detect the platform from the repository's remote.
            detected_platform = detect_platform_from_repo(repo_path)
            if not detected_platform:
                raise PlatformNotSupportedError("无法从仓库路径检测平台类型")
            platform = detected_platform.value
        elif not platform:
            raise ValidationError("必须指定平台类型或提供仓库路径")

        if platform not in PROVIDERS:
            raise PlatformNotSupportedError(f"不支持的平台: {platform}")

        # Reuse the bound configuration when it targets the same platform;
        # otherwise load the platform's configuration.
        if self.config and self.config.platform.value == platform:
            config = self.config
        else:
            try:
                config = get_config(platform)
            except ConfigurationError as exc:
                raise ConfigurationError(f"平台 {platform} 的配置未找到") from exc

        provider_class = PROVIDERS[platform]
        return provider_class(config)

    def create_pull_request(
        self,
        repo_path: str,
        source_branch: str,
        target_branch: str,
        title: str,
        description: str = "",
        labels: Optional[List[str]] = None,
        assignees: Optional[List[str]] = None,
        reviewers: Optional[List[str]] = None,
        draft: bool = False,
        template_type: Optional[str] = None,
        template_vars: Optional[Dict[str, str]] = None,
        platform: Optional[str] = None,
        **kwargs
    ) -> PRResult:
        """Create a pull request for the repository at *repo_path*.

        Args:
            repo_path: Local path of the git repository.
            source_branch: Branch containing the changes.
            target_branch: Branch to merge into.
            title: PR title.
            description: PR body.
            labels, assignees, reviewers: Optional lists; ``None`` means empty.
            draft: Request a draft PR.
            template_type, template_vars: Stored on the ``PRData`` passed to
                the provider.
            platform: Platform name; resolved automatically when omitted.
            **kwargs: Accepted for compatibility; currently unused.

        Returns:
            The provider's ``PRResult``, or a failed ``PRResult`` on any error.
        """
        try:
            repo_path, repo_info = self._resolve_repo(repo_path)
            provider = self._get_provider(platform, repo_path)

            pr_data = PRData(
                title=title,
                description=description,
                source_branch=source_branch,
                target_branch=target_branch,
                labels=labels or [],
                assignees=assignees or [],
                reviewers=reviewers or [],
                draft=draft,
                template_type=template_type,
                template_vars=template_vars or {}
            )
            return provider.create_pr(repo_info, pr_data)

        except Exception as e:
            return self._failure("创建PR失败", e, platform)

    def get_pull_request(
        self,
        repo_path: str,
        pr_number: int,
        platform: Optional[str] = None
    ) -> PRResult:
        """Fetch pull request *pr_number*; returns a failed ``PRResult`` on error."""
        try:
            repo_path, repo_info = self._resolve_repo(repo_path)
            provider = self._get_provider(platform, repo_path)
            return provider.get_pr(repo_info, pr_number)

        except Exception as e:
            return self._failure("获取PR失败", e, platform)

    def update_pull_request(
        self,
        repo_path: str,
        pr_number: int,
        platform: Optional[str] = None,
        **kwargs
    ) -> PRResult:
        """Update pull request *pr_number*, forwarding *kwargs* to the provider.

        Returns a failed ``PRResult`` on error.
        """
        try:
            repo_path, repo_info = self._resolve_repo(repo_path)
            provider = self._get_provider(platform, repo_path)
            return provider.update_pr(repo_info, pr_number, **kwargs)

        except Exception as e:
            return self._failure("更新PR失败", e, platform)

    def list_pull_requests(
        self,
        repo_path: str,
        state: str = "open",
        per_page: int = 30,
        page: int = 1,
        platform: Optional[str] = None
    ) -> List[PRInfo]:
        """List pull requests of the repository; returns ``[]`` on error."""
        try:
            repo_path, repo_info = self._resolve_repo(repo_path)
            provider = self._get_provider(platform, repo_path)
            return provider.list_prs(repo_info, state, per_page, page)

        except Exception as e:
            logger.error(f"列出PR失败: {e}")
            return []

    def health_check(self, platform: Optional[str] = None, repo_path: Optional[str] = None) -> bool:
        """Return the provider's health-check result, or ``False`` on any error."""
        try:
            provider = self._get_provider(platform, repo_path)
            return provider.health_check()
        except Exception as e:
            logger.error(f"健康检查失败: {e}")
            return False
187
+
188
+
189
# Process-wide default manager (created without an explicit configuration).
_global_manager = PullRequestManager()


def set_global_config(config: PRConfig) -> None:
    """Replace the process-wide manager with a new one bound to *config*.

    References previously obtained from ``get_global_manager`` keep pointing
    at the old instance.
    """
    global _global_manager
    replacement = PullRequestManager(config)
    _global_manager = replacement


def get_global_manager() -> PullRequestManager:
    """Return the current process-wide manager instance."""
    current = _global_manager
    return current
@@ -0,0 +1,164 @@
1
+ """
2
+ Pull Request 数据模型定义
3
+ """
4
+ from typing import Optional, List, Dict, Any
5
+ from dataclasses import dataclass, field
6
+ from enum import Enum
7
+ import os
8
+ import json
9
+ from pathlib import Path
10
+
11
+
12
class PlatformType(str, Enum):
    """Code-hosting platforms supported by the pull request module.

    Members subclass ``str``, so each compares equal to its lowercase value
    (``PlatformType.GITHUB == "github"``).
    """
    GITHUB = "github"
    GITLAB = "gitlab"
    GITEE = "gitee"
    GITCODE = "gitcode"


@dataclass
class PRConfig:
    """Connection and behaviour settings for one platform provider.

    ``platform`` may be given as a plain string; ``__post_init__`` coerces it
    to ``PlatformType`` (raising ``ValueError`` for unknown values). When
    ``base_url`` is omitted, the platform's default API root is filled in.
    """
    platform: PlatformType
    # Excluded from repr so the secret does not leak into logs or tracebacks.
    token: str = field(repr=False)
    base_url: Optional[str] = None
    timeout: int = 30  # presumably seconds — TODO confirm in base provider
    verify_ssl: bool = True
    retry_count: int = 3
    default_labels: List[str] = field(default_factory=list)
    default_assignees: List[str] = field(default_factory=list)

    # Platform-specific options
    draft: bool = False
    maintainer_can_modify: bool = True
    remove_source_branch: bool = False
    squash: bool = False

    def __post_init__(self):
        if isinstance(self.platform, str):
            self.platform = PlatformType(self.platform)
        if self.base_url is None:
            self.base_url = self._get_default_base_url()

    def _get_default_base_url(self) -> str:
        """Return the default REST API root for ``self.platform`` ("" if unknown)."""
        default_urls = {
            PlatformType.GITHUB: "https://api.github.com",
            PlatformType.GITLAB: "https://gitlab.com/api/v4",
            PlatformType.GITEE: "https://gitee.com/api/v5",
            # NOTE(review): gitcode.net is GitCode's legacy host; the current
            # service may live at https://api.gitcode.com/api/v5 — confirm.
            PlatformType.GITCODE: "https://gitcode.net/api/v4"
        }
        return default_urls.get(self.platform, "")

    @classmethod
    def from_env(cls, platform: str) -> 'PRConfig':
        """Build a config whose token comes from the platform's environment variable.

        Args:
            platform: Platform name (e.g. ``"github"``) or a ``PlatformType``.

        Raises:
            ValueError: the platform is unknown, or its token variable
                (``GITHUB_TOKEN``, ``GITLAB_TOKEN``, ``GITEE_TOKEN`` or
                ``GITCODE_TOKEN``) is unset or empty.
        """
        env_mappings = {
            PlatformType.GITHUB: "GITHUB_TOKEN",
            PlatformType.GITLAB: "GITLAB_TOKEN",
            PlatformType.GITEE: "GITEE_TOKEN",
            PlatformType.GITCODE: "GITCODE_TOKEN"
        }
        try:
            platform_type = PlatformType(platform)
        except ValueError as exc:
            # Report unknown platforms with the module's own message instead
            # of the generic enum error (this path used to be unreachable).
            raise ValueError(f"不支持的平台类型: {platform}") from exc
        token_env = env_mappings.get(platform_type)
        if not token_env:
            raise ValueError(f"不支持的平台类型: {platform}")
        token = os.getenv(token_env)
        if not token:
            raise ValueError(f"环境变量 {token_env} 未设置")
        return cls(platform=platform_type, token=token)
69
+
70
+
71
@dataclass
class RepoInfo:
    """Identity of a repository on a hosting platform.

    ``full_name`` is not a constructor argument; it is derived from ``owner``
    and ``name`` once, at construction time.
    """
    owner: str
    name: str
    platform: PlatformType
    # "<owner>/<name>"; computed once, not refreshed if owner/name change later.
    full_name: str = field(init=False)

    def __post_init__(self):
        # Build the "owner/name" slug that providers embed in API paths.
        slug = "{}/{}".format(self.owner, self.name)
        self.full_name = slug
81
+
82
+
83
@dataclass
class PRData:
    """Input describing a pull request to create.

    ``PullRequestManager.create_pull_request`` builds it and hands it to a
    provider's ``create_pr``.
    """
    title: str
    description: str  # PR body text
    source_branch: str  # branch with the changes (sent as "head" by GitHubProvider)
    target_branch: str  # branch to merge into (sent as "base" by GitHubProvider)
    # NOTE(review): GitHubProvider.create_pr does not send labels/assignees/reviewers.
    labels: List[str] = field(default_factory=list)
    assignees: List[str] = field(default_factory=list)
    reviewers: List[str] = field(default_factory=list)
    draft: bool = False
    # Template selection; NOTE(review): no code in view applies these — confirm consumer.
    template_type: Optional[str] = None
    template_vars: Dict[str, str] = field(default_factory=dict)
96
+
97
+
98
@dataclass
class PRResult:
    """Outcome of a pull request operation.

    On failure ``success`` is False and ``error_message`` holds the reason;
    providers fill in the remaining fields as available.
    """
    success: bool
    pr_number: Optional[int] = None
    pr_url: Optional[str] = None
    pr_id: Optional[str] = None  # platform id, stored as a string (GitHubProvider uses str(data['id']))
    error_message: Optional[str] = None
    error_code: Optional[str] = None
    platform: Optional[PlatformType] = None
    raw_response: Optional[Dict[str, Any]] = None  # decoded JSON body of the API response
    # NOTE(review): not set by any code in view; presumably a rate-limit back-off hint.
    retry_after: Optional[int] = None
110
+
111
+
112
@dataclass
class PRInfo:
    """Platform-neutral details of one pull request.

    Providers build it from the platform's JSON; ``raw_data`` keeps that JSON.
    """
    number: int
    title: str
    description: str  # PR body text
    state: str  # state string as returned by the platform (not normalised here)
    source_branch: str
    target_branch: str
    author: str  # author's login/username
    created_at: str  # timestamp string as returned by the platform
    updated_at: str
    merged_at: Optional[str] = None  # presumably None until merged
    pr_url: str = ""  # web URL of the PR
    labels: List[str] = field(default_factory=list)  # label names
    assignees: List[str] = field(default_factory=list)  # assignee logins
    reviewers: List[str] = field(default_factory=list)
    mergeable: Optional[bool] = None  # None when the platform did not report it
    draft: bool = False
    raw_data: Optional[Dict[str, Any]] = None
132
+
133
+
134
# Default pull request templates, keyed by template type.
# Each entry has a "title_prefix" and a "description_template" containing
# {placeholder} fields that look intended for str.format-style substitution.
# NOTE(review): no code in view performs that substitution — confirm the consumer.
DEFAULT_TEMPLATES = {
"bug_fix": {
"title_prefix": "🐛 Bug Fix:",
"description_template": """
## 问题描述
{problem_description}

## 解决方案
{solution_description}

## 测试
- [ ] 单元测试通过
- [ ] 集成测试通过
- [ ] 手动测试验证
"""
},
"feature": {
"title_prefix": "✨ Feature:",
"description_template": """
## 新功能说明
{feature_description}

## 实现细节
{implementation_details}

## 使用示例
{usage_examples}
"""
}
}
@@ -0,0 +1,23 @@
1
+ """
2
+ Pull Request 平台提供者模块
3
+ """
4
+ from .github_provider import GitHubProvider
5
+ from .gitlab_provider import GitLabProvider
6
+ from .gitee_provider import GiteeProvider
7
+ from .gitcode_provider import GitCodeProvider
8
+
9
# Registry from platform identifier (the PlatformType value string) to the
# provider class that implements it.
PROVIDERS = dict(
    github=GitHubProvider,
    gitlab=GitLabProvider,
    gitee=GiteeProvider,
    gitcode=GitCodeProvider,
)

__all__ = [
    'GitHubProvider',
    'GitLabProvider',
    'GiteeProvider',
    'GitCodeProvider',
    'PROVIDERS',
]
@@ -0,0 +1,19 @@
1
+ """
2
+ GitCode API 提供者实现
3
+ """
4
+ from typing import List
5
+ from .github_provider import GitHubProvider
6
+ from ..models import RepoInfo, PRData, PRResult, PRInfo
7
+
8
+
9
class GitCodeProvider(GitHubProvider):
    """GitCode provider reusing the GitHub-style request flow.

    Only the authorization value differs from the parent: a ``Bearer`` token.
    """

    def _get_auth_header(self) -> str:
        """Return ``Bearer <token>`` (presumably used as the Authorization header by the base provider)."""
        return "Bearer {}".format(self.config.token)

    def create_pr(self, repo_info: RepoInfo, pr_data: PRData) -> PRResult:
        """Create a merge request by delegating to the GitHub implementation."""
        # TODO: simplified implementation; should call the GitCode API directly.
        return GitHubProvider.create_pr(self, repo_info, pr_data)
@@ -0,0 +1,20 @@
1
+ """
2
+ Gitee API 提供者实现
3
+ """
4
+ from typing import List
5
+ from .github_provider import GitHubProvider
6
+ from ..models import RepoInfo, PRData, PRResult, PRInfo
7
+
8
+
9
class GiteeProvider(GitHubProvider):
    """Gitee provider reusing the GitHub-style request flow.

    Gitee authenticates with an ``access_token`` request parameter rather
    than an Authorization header. Previously the header value was blanked but
    the parameter was never sent, so every request went out unauthenticated.
    The token is now added to every request's query parameters.
    """

    def _get_auth_header(self) -> str:
        """Return an empty value; Gitee auth goes through ``access_token``."""
        # Gitee uses the token parameter instead of an Authorization header.
        return ""

    def _make_request(self, method, url, **kwargs):
        """Add ``access_token`` to the query parameters, then defer to the base class.

        ``setdefault`` keeps any caller-supplied token. The caller's params
        mapping is copied, never mutated.
        """
        params = dict(kwargs.pop("params", None) or {})
        params.setdefault("access_token", self.config.token)
        return super()._make_request(method, url, params=params, **kwargs)

    def create_pr(self, repo_info: RepoInfo, pr_data: PRData) -> PRResult:
        """Create a pull request via the GitHub-compatible implementation."""
        # TODO: simplified implementation; should use Gitee-specific payload fields.
        return super().create_pr(repo_info, pr_data)
@@ -0,0 +1,214 @@
1
+ """
2
+ GitHub API 提供者实现
3
+ """
4
+ from typing import List, Dict, Any
5
+ from loguru import logger
6
+
7
+ from ..base_provider import BasePlatformProvider
8
+ from ..models import RepoInfo, PRData, PRResult, PRInfo
9
+ from ..utils import build_pr_url
10
+
11
+
12
class GitHubProvider(BasePlatformProvider):
    """GitHub REST API provider.

    The Gitee, GitCode and GitLab providers in this package subclass it.
    Except for the input validation at the start of ``create_pr``, every
    operation catches exceptions, logs them and reports failure through its
    return value: a failed ``PRResult``, or an empty list.
    """

    def _get_auth_header(self) -> str:
        """Return the GitHub auth value ``token <token>``."""
        return f"token {self.config.token}"

    @staticmethod
    def _parse_pr_info(data: Dict[str, Any]) -> PRInfo:
        """Convert one GitHub pull-request JSON object into a ``PRInfo``.

        GitHub sends ``"body": null`` for PRs without a description, and a
        ``.get`` default only covers a *missing* key. ``or`` is therefore used
        so that ``description`` is always a string and list fields are never
        ``None``.
        """
        return PRInfo(
            number=data['number'],
            title=data['title'],
            description=data.get('body') or '',
            state=data['state'],
            source_branch=data['head']['ref'],
            target_branch=data['base']['ref'],
            author=data['user']['login'],
            created_at=data['created_at'],
            updated_at=data['updated_at'],
            merged_at=data.get('merged_at'),
            pr_url=data['html_url'],
            labels=[label['name'] for label in data.get('labels') or []],
            assignees=[assignee['login'] for assignee in data.get('assignees') or []],
            reviewers=[reviewer['login'] for reviewer in data.get('requested_reviewers') or []],
            mergeable=data.get('mergeable'),
            draft=bool(data.get('draft', False)),
            raw_data=data
        )

    def create_pr(self, repo_info: RepoInfo, pr_data: PRData) -> PRResult:
        """Create a pull request.

        Validation errors from ``_validate_repo_info`` / ``_validate_pr_data``
        propagate. API errors are returned as a failed ``PRResult``.
        """
        self._validate_repo_info(repo_info)
        self._validate_pr_data(pr_data)

        url = f"{self.config.base_url}/repos/{repo_info.full_name}/pulls"

        # NOTE(review): pr_data.labels/assignees/reviewers are not sent. GitHub's
        # create endpoint does not accept them; they need follow-up API calls.
        payload = {
            "title": pr_data.title,
            "body": pr_data.description,
            "head": pr_data.source_branch,
            "base": pr_data.target_branch,
            "draft": pr_data.draft or self.config.draft,
            "maintainer_can_modify": self.config.maintainer_can_modify
        }

        try:
            response = self._make_request('POST', url, data=payload)
            data = response.json()

            pr_url = build_pr_url(self.config.platform, repo_info, data['number'])

            result = PRResult(
                success=True,
                pr_number=data['number'],
                pr_url=pr_url,
                pr_id=str(data['id']),
                platform=self.config.platform,
                raw_response=data
            )

            logger.info(f"GitHub PR 创建成功: {pr_url}")
            return result

        except Exception as e:
            logger.error(f"创建 GitHub PR 失败: {e}")
            return PRResult(
                success=False,
                error_message=str(e),
                platform=self.config.platform
            )

    def get_pr(self, repo_info: RepoInfo, pr_number: int) -> PRResult:
        """Fetch one pull request; the full JSON is in ``raw_response``."""
        url = f"{self.config.base_url}/repos/{repo_info.full_name}/pulls/{pr_number}"

        try:
            response = self._make_request('GET', url)
            data = response.json()
            pr_info = self._parse_pr_info(data)

            return PRResult(
                success=True,
                pr_number=pr_info.number,
                pr_url=pr_info.pr_url,
                platform=self.config.platform,
                raw_response=data
            )

        except Exception as e:
            logger.error(f"获取 GitHub PR 失败: {e}")
            return PRResult(
                success=False,
                error_message=str(e),
                platform=self.config.platform
            )

    def update_pr(self, repo_info: RepoInfo, pr_number: int, **kwargs) -> PRResult:
        """Update a pull request.

        Recognised keyword arguments: ``title``, ``description`` (sent as
        ``body``), ``state`` (``"open"``/``"closed"``) and ``target_branch``
        (sent as ``base``). Other keywords are ignored.
        """
        url = f"{self.config.base_url}/repos/{repo_info.full_name}/pulls/{pr_number}"

        # Map caller-facing argument names to GitHub API field names.
        field_map = {
            'title': 'title',
            'description': 'body',
            'state': 'state',
            'target_branch': 'base',
        }
        payload = {api_key: kwargs[arg] for arg, api_key in field_map.items() if arg in kwargs}

        try:
            response = self._make_request('PATCH', url, data=payload)
            data = response.json()

            return PRResult(
                success=True,
                pr_number=data['number'],
                pr_url=data['html_url'],
                platform=self.config.platform,
                raw_response=data
            )

        except Exception as e:
            logger.error(f"更新 GitHub PR 失败: {e}")
            return PRResult(
                success=False,
                error_message=str(e),
                platform=self.config.platform
            )

    def close_pr(self, repo_info: RepoInfo, pr_number: int) -> PRResult:
        """Close a pull request (an update with ``state='closed'``)."""
        return self.update_pr(repo_info, pr_number, state='closed')

    def merge_pr(self, repo_info: RepoInfo, pr_number: int, **kwargs) -> PRResult:
        """Merge a pull request.

        Keyword arguments: ``merge_method`` (default ``"merge"``), plus the
        optional ``commit_title``, ``commit_message`` and ``sha``. Optional
        fields are sent only when given, so GitHub's defaults apply instead of
        forcing an empty commit title or message.
        """
        url = f"{self.config.base_url}/repos/{repo_info.full_name}/pulls/{pr_number}/merge"

        payload = {'merge_method': kwargs.get('merge_method', 'merge')}
        for key in ('commit_title', 'commit_message', 'sha'):
            if kwargs.get(key):
                payload[key] = kwargs[key]

        try:
            response = self._make_request('PUT', url, data=payload)
            data = response.json()

            return PRResult(
                success=True,
                pr_number=pr_number,
                platform=self.config.platform,
                raw_response=data
            )

        except Exception as e:
            logger.error(f"合并 GitHub PR 失败: {e}")
            return PRResult(
                success=False,
                error_message=str(e),
                platform=self.config.platform
            )

    def list_prs(
        self,
        repo_info: RepoInfo,
        state: str = "open",
        per_page: int = 30,
        page: int = 1
    ) -> List[PRInfo]:
        """List one page of the repository's pull requests; ``[]`` on error."""
        url = f"{self.config.base_url}/repos/{repo_info.full_name}/pulls"
        params = {
            'state': state,
            'per_page': str(per_page),
            'page': str(page)
        }

        try:
            response = self._make_request('GET', url, params=params)
            return [self._parse_pr_info(item) for item in response.json()]

        except Exception as e:
            logger.error(f"列出 GitHub PR 失败: {e}")
            return []
@@ -0,0 +1,29 @@
1
+ """
2
+ GitLab API 提供者实现
3
+ """
4
from typing import List
from urllib.parse import quote

from loguru import logger

from .github_provider import GitHubProvider
from ..models import RepoInfo, PRData, PRResult, PRInfo
7
+
8
+
9
class GitLabProvider(GitHubProvider):
    """GitLab provider built on the Merge Requests API.

    GitLab has no GitHub-style ``/repos/{owner}/{repo}/pulls`` routes, so
    every endpoint is overridden to use ``/projects/:id/merge_requests``. Only
    the request plumbing (``_make_request``, the validation helpers) and
    ``close_pr`` are inherited. ``pr_number`` is the MR's project-scoped
    ``iid``.
    """

    # GitHub-style state names callers may pass (the manager's default is
    # "open"), mapped to GitLab merge request state names.
    _STATE_ALIASES = {"open": "opened"}

    def _get_auth_header(self) -> str:
        """Return the ``Bearer <token>`` auth value."""
        return f"Bearer {self.config.token}"

    def _mr_url(self, repo_info: RepoInfo, mr_iid=None) -> str:
        """Build the merge-requests URL of the project, optionally for one MR.

        GitLab identifies a project by its URL-encoded ``namespace/name`` path.
        """
        project_id = quote(repo_info.full_name, safe="")
        url = f"{self.config.base_url}/projects/{project_id}/merge_requests"
        return url if mr_iid is None else f"{url}/{mr_iid}"

    def _error_result(self, message: str, error: Exception) -> PRResult:
        """Log *error* under *message* and build a failed ``PRResult``."""
        logger.error(f"{message}: {error}")
        return PRResult(
            success=False,
            error_message=str(error),
            platform=self.config.platform
        )

    @staticmethod
    def _parse_mr_info(mr) -> PRInfo:
        """Convert one GitLab merge-request JSON object into a ``PRInfo``."""
        status = mr.get("detailed_merge_status") or mr.get("merge_status")
        if status in (None, "unchecked", "checking"):
            mergeable = None  # GitLab has not computed mergeability yet
        else:
            mergeable = status in ("mergeable", "can_be_merged")
        return PRInfo(
            number=mr["iid"],
            title=mr.get("title") or "",
            description=mr.get("description") or "",
            state=mr.get("state") or "",
            source_branch=mr.get("source_branch") or "",
            target_branch=mr.get("target_branch") or "",
            author=(mr.get("author") or {}).get("username", ""),
            created_at=mr.get("created_at") or "",
            updated_at=mr.get("updated_at") or "",
            merged_at=mr.get("merged_at"),
            pr_url=mr.get("web_url") or "",
            labels=list(mr.get("labels") or []),
            assignees=[user.get("username", "") for user in mr.get("assignees") or []],
            reviewers=[user.get("username", "") for user in mr.get("reviewers") or []],
            mergeable=mergeable,
            draft=bool(mr.get("draft", mr.get("work_in_progress", False))),
            raw_data=mr
        )

    def create_pr(self, repo_info: RepoInfo, pr_data: PRData) -> PRResult:
        """Create a merge request (GitLab's name for a PR).

        A draft is requested through the ``Draft:`` title prefix. Labels are
        sent as a comma-separated string. Assignees and reviewers are not sent,
        because GitLab expects numeric user ids while ``PRData`` holds names.
        Validation errors propagate; API errors become a failed ``PRResult``.
        """
        self._validate_repo_info(repo_info)
        self._validate_pr_data(pr_data)

        title = pr_data.title
        if (pr_data.draft or self.config.draft) and not title.lower().startswith("draft:"):
            title = f"Draft: {title}"

        payload = {
            "source_branch": pr_data.source_branch,
            "target_branch": pr_data.target_branch,
            "title": title,
            "description": pr_data.description,
            "remove_source_branch": self.config.remove_source_branch,
            "squash": self.config.squash,
        }
        if pr_data.labels:
            payload["labels"] = ",".join(pr_data.labels)

        try:
            response = self._make_request("POST", self._mr_url(repo_info), data=payload)
            data = response.json()
            result = PRResult(
                success=True,
                pr_number=data["iid"],
                pr_url=data.get("web_url"),
                pr_id=str(data["id"]),
                platform=self.config.platform,
                raw_response=data
            )
            logger.info(f"GitLab MR 创建成功: {result.pr_url}")
            return result
        except Exception as e:
            return self._error_result("创建 GitLab MR 失败", e)

    def get_pr(self, repo_info: RepoInfo, pr_number: int) -> PRResult:
        """Fetch one merge request by ``iid``; the JSON is in ``raw_response``."""
        try:
            response = self._make_request("GET", self._mr_url(repo_info, pr_number))
            data = response.json()
            info = self._parse_mr_info(data)
            return PRResult(
                success=True,
                pr_number=info.number,
                pr_url=info.pr_url,
                platform=self.config.platform,
                raw_response=data
            )
        except Exception as e:
            return self._error_result("获取 GitLab MR 失败", e)

    def update_pr(self, repo_info: RepoInfo, pr_number: int, **kwargs) -> PRResult:
        """Update a merge request.

        Recognised keyword arguments: ``title``, ``description``,
        ``target_branch`` and ``state``. ``"closed"`` closes the MR, and
        ``"open"``/``"opened"``/``"reopen"`` reopens it; any other state is
        rejected. Other keywords are ignored.
        """
        payload = {key: kwargs[key] for key in ("title", "description", "target_branch") if key in kwargs}
        if "state" in kwargs:
            state = kwargs["state"]
            if state in ("closed", "close"):
                payload["state_event"] = "close"
            elif state in ("open", "opened", "reopen"):
                payload["state_event"] = "reopen"
            else:
                return PRResult(
                    success=False,
                    error_message=f"不支持的状态: {state}",
                    platform=self.config.platform
                )

        try:
            response = self._make_request("PUT", self._mr_url(repo_info, pr_number), data=payload)
            data = response.json()
            return PRResult(
                success=True,
                pr_number=data["iid"],
                pr_url=data.get("web_url"),
                platform=self.config.platform,
                raw_response=data
            )
        except Exception as e:
            return self._error_result("更新 GitLab MR 失败", e)

    def merge_pr(self, repo_info: RepoInfo, pr_number: int, **kwargs) -> PRResult:
        """Merge (accept) a merge request.

        Keyword arguments:
            commit_message: Merge commit message; sent only when given.
            squash: Defaults to ``merge_method == "squash"`` or ``config.squash``.
            remove_source_branch: Defaults to ``config.remove_source_branch``.
        """
        squash = kwargs.get("squash", kwargs.get("merge_method") == "squash" or self.config.squash)
        payload = {
            "squash": bool(squash),
            "should_remove_source_branch": bool(
                kwargs.get("remove_source_branch", self.config.remove_source_branch)
            ),
        }
        if kwargs.get("commit_message"):
            payload["merge_commit_message"] = kwargs["commit_message"]

        try:
            response = self._make_request("PUT", f"{self._mr_url(repo_info, pr_number)}/merge", data=payload)
            data = response.json()
            return PRResult(
                success=True,
                pr_number=pr_number,
                pr_url=data.get("web_url"),
                platform=self.config.platform,
                raw_response=data
            )
        except Exception as e:
            return self._error_result("合并 GitLab MR 失败", e)

    def list_prs(
        self,
        repo_info: RepoInfo,
        state: str = "opened",  # GitLab uses "opened" rather than "open"
        per_page: int = 30,
        page: int = 1
    ) -> List[PRInfo]:
        """List one page of the project's merge requests; ``[]`` on error.

        ``state`` accepts GitLab names (``opened``, ``closed``, ``merged``,
        ``locked``, ``all``). The GitHub-style ``"open"`` is also accepted;
        it is the manager's default.
        """
        params = {
            "state": self._STATE_ALIASES.get(state, state),
            "per_page": str(per_page),
            "page": str(page),
        }
        try:
            response = self._make_request("GET", self._mr_url(repo_info), params=params)
            return [self._parse_mr_info(mr) for mr in response.json()]
        except Exception as e:
            logger.error(f"列出 GitLab MR 失败: {e}")
            return []