auto-coder 0.1.316__py3-none-any.whl → 0.1.318__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of auto-coder might be problematic. Click here for more details.
- {auto_coder-0.1.316.dist-info → auto_coder-0.1.318.dist-info}/METADATA +2 -2
- {auto_coder-0.1.316.dist-info → auto_coder-0.1.318.dist-info}/RECORD +41 -20
- autocoder/auto_coder_runner.py +1 -2
- autocoder/common/__init__.py +3 -0
- autocoder/common/auto_coder_lang.py +24 -0
- autocoder/common/code_auto_merge_editblock.py +2 -42
- autocoder/common/git_utils.py +2 -2
- autocoder/common/token_cost_caculate.py +103 -42
- autocoder/common/v2/__init__.py +0 -0
- autocoder/common/v2/code_auto_generate.py +199 -0
- autocoder/common/v2/code_auto_generate_diff.py +361 -0
- autocoder/common/v2/code_auto_generate_editblock.py +380 -0
- autocoder/common/v2/code_auto_generate_strict_diff.py +269 -0
- autocoder/common/v2/code_auto_merge.py +211 -0
- autocoder/common/v2/code_auto_merge_diff.py +354 -0
- autocoder/common/v2/code_auto_merge_editblock.py +523 -0
- autocoder/common/v2/code_auto_merge_strict_diff.py +259 -0
- autocoder/common/v2/code_diff_manager.py +266 -0
- autocoder/common/v2/code_editblock_manager.py +282 -0
- autocoder/common/v2/code_manager.py +238 -0
- autocoder/common/v2/code_strict_diff_manager.py +241 -0
- autocoder/dispacher/actions/action.py +16 -0
- autocoder/dispacher/actions/plugins/action_regex_project.py +6 -0
- autocoder/events/event_manager_singleton.py +2 -2
- autocoder/helper/__init__.py +0 -0
- autocoder/helper/project_creator.py +570 -0
- autocoder/linters/linter_factory.py +44 -25
- autocoder/linters/models.py +220 -0
- autocoder/linters/python_linter.py +1 -7
- autocoder/linters/reactjs_linter.py +580 -0
- autocoder/linters/shadow_linter.py +390 -0
- autocoder/linters/vue_linter.py +576 -0
- autocoder/memory/active_context_manager.py +0 -4
- autocoder/memory/active_package.py +12 -12
- autocoder/shadows/__init__.py +0 -0
- autocoder/shadows/shadow_manager.py +235 -0
- autocoder/version.py +1 -1
- {auto_coder-0.1.316.dist-info → auto_coder-0.1.318.dist-info}/LICENSE +0 -0
- {auto_coder-0.1.316.dist-info → auto_coder-0.1.318.dist-info}/WHEEL +0 -0
- {auto_coder-0.1.316.dist-info → auto_coder-0.1.318.dist-info}/entry_points.txt +0 -0
- {auto_coder-0.1.316.dist-info → auto_coder-0.1.318.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,259 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import difflib
|
|
3
|
+
import diff_match_patch as dmp_module
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import List, Dict, Tuple
|
|
6
|
+
import pydantic
|
|
7
|
+
import byzerllm
|
|
8
|
+
from autocoder.common import AutoCoderArgs, git_utils
|
|
9
|
+
from autocoder.common.types import CodeGenerateResult, MergeCodeWithoutEffect
|
|
10
|
+
from autocoder.common.v2.code_auto_merge import CodeAutoMerge
|
|
11
|
+
from autocoder.common import files as FileUtils
|
|
12
|
+
|
|
13
|
+
class PathAndCode(pydantic.BaseModel):
    """One parsed diff block: the target file path and the diff text for it."""

    # Path taken from the block's "+++" header line (may be relative or absolute).
    path: str
    # The whole diff block text (fences removed, headers and hunks kept).
    content: str
|
|
16
|
+
|
|
17
|
+
def safe_abs_path(res):
    """Return *res* as a fully resolved absolute path string.

    ``Path.resolve`` is used so that, on Windows, the full (not 8.3 short-name)
    form of the path is returned.
    """
    return str(Path(res).resolve())
|
|
21
|
+
|
|
22
|
+
def apply_hunk(content, hunk):
    """Apply a single unified-diff hunk to *content* and return the new text.

    The hunk is positioned with the new-file ("+") start line of its
    ``@@ -a,b +c,d @@`` header. That is the correct position when the hunks of
    one diff are applied in order, each to the text produced by the previous one.

    Args:
        content: Current file text.
        hunk: List of hunk lines. ``hunk[0]`` is the ``@@`` header; the other
            lines carry a leading ``' '``, ``'-'`` or ``'+'`` marker.

    Returns:
        The patched text, joined with ``"\\n"`` (no trailing newline added).

    Raises:
        ValueError: If the header has no ``+c[,d]`` range or its numbers are
            not integers.
    """
    before, after = hunk_to_before_after(hunk)
    before_lines = before.splitlines()
    after_lines = after.splitlines()

    # Header: "@@ -a[,b] +c[,d] @@ optional section". The ",d" count is
    # optional (absent means 1), so parse start and count separately instead
    # of calling int() on "c,d".
    fields = hunk[0].split("@@")
    ranges = fields[1].split() if len(fields) > 1 else []
    new_range = next((r for r in ranges if r.startswith("+")), None)
    if new_range is None:
        raise ValueError(f"Malformed hunk header: {hunk[0]!r}")
    start_text, _, count_text = new_range[1:].partition(",")
    new_start = int(start_text)
    new_count = int(count_text) if count_text else 1

    # Header numbers are 1-based. A zero-length range names the line *after
    # which* the change applies, so it is already the 0-based insert index.
    start = new_start if new_count == 0 else new_start - 1
    start = max(start, 0)

    content_lines = content.splitlines()
    # Replace exactly the lines the hunk covers (context + removed) with the
    # hunk's result (context + added). Removed lines are thereby dropped.
    content_out = (
        content_lines[:start]
        + after_lines
        + content_lines[start + len(before_lines):]
    )
    return "\n".join(content_out)
|
|
51
|
+
|
|
52
|
+
def hunk_to_before_after(hunk, lines=False):
    """Split a diff hunk into its "before" and "after" sides.

    The first character of each line is its diff marker:
      - ``' '`` (context) goes to both sides;
      - ``'-'`` goes only to *before*;
      - ``'+'`` goes only to *after*.

    Any other marker (for example the ``'@'`` of the header) is dropped.
    Lines shorter than two characters are kept unchanged as context.

    Args:
        hunk: Iterable of hunk lines (normally still ending in newlines).
        lines: If true, return the two sides as lists instead of strings.

    Returns:
        ``(before, after)``, either as two concatenated strings or, when
        *lines* is true, as two lists of line strings.
    """
    old_side = []
    new_side = []
    for raw in hunk:
        if len(raw) >= 2:
            marker, text = raw[0], raw[1:]
        else:
            marker, text = " ", raw

        if marker in (" ", "-"):
            old_side.append(text)
        if marker in (" ", "+"):
            new_side.append(text)

    if not lines:
        return "".join(old_side), "".join(new_side)
    return old_side, new_side
|
|
79
|
+
|
|
80
|
+
class CodeAutoMergeStrictDiff(CodeAutoMerge):
    """Merge LLM output made of strict unified-diff blocks into the project.

    The LLM response is expected to contain fenced ```diff blocks. Each block
    starts with ``--- old`` / ``+++ new`` header lines followed by ``@@`` hunks.
    """

    def parse_diff_block(self, text: str) -> List[PathAndCode]:
        """Extract every fenced ```diff block from *text*.

        Fences opened inside a diff block (e.g. a ```python line that is part of
        the diffed content) are counted, so their closing ``` does not end
        the outer block.

        Args:
            text: Raw LLM output.

        Returns:
            One PathAndCode per diff block. ``path`` is the second block line
            (the ``+++`` header) with its 4-character prefix stripped, and
            ``content`` is the block text without the fences. Blocks with fewer
            than two lines are skipped, since they cannot hold a ``+++`` header.
        """
        lines = text.split('\n')
        lines_len = len(lines)
        start_marker_count = 0
        inline_start_marker_count = 0
        block = []
        path_and_code_list = []

        def guard(index):
            return index + 1 < lines_len

        def start_marker(line, index):
            # A ```diff fence on the very last line cannot open a block.
            return line.startswith('```diff') and guard(index)

        def inline_start_marker(line, index):
            # Any other opening fence, such as ```python.
            return line.startswith('```') and not line.startswith('```diff') and line.strip() != '```'

        def end_marker(line, index):
            return line.startswith('```') and line.strip() == '```'

        for index, line in enumerate(lines):
            if start_marker(line, index) and start_marker_count == 0:
                start_marker_count += 1
            elif (start_marker(line, index) or inline_start_marker(line, index)) and start_marker_count > 0:
                inline_start_marker_count += 1
                block.append(line)
            elif end_marker(line, index) and start_marker_count == 1 and inline_start_marker_count == 0:
                start_marker_count -= 1
                # block[1] is the "+++ new_path" header; guard against short blocks.
                if len(block) >= 2:
                    new_path = block[1][4:].strip()
                    path_and_code_list.append(PathAndCode(path=new_path, content='\n'.join(block)))
                block = []
            elif end_marker(line, index) and inline_start_marker_count > 0:
                inline_start_marker_count -= 1
                block.append(line)
            elif start_marker_count > 0:
                block.append(line)

        return path_and_code_list

    def abs_root_path(self, path):
        """Resolve *path* to an absolute path, relative to ``args.source_dir``
        unless it already starts with it."""
        if path.startswith(self.args.source_dir):
            return safe_abs_path(Path(path))
        res = Path(self.args.source_dir) / path
        return safe_abs_path(res)

    @staticmethod
    def _parse_hunks(diff_text: str):
        """Split one file's unified diff into hunks.

        Lines before the first ``@@`` header (the ``---``/``+++`` headers) are
        ignored.

        Returns:
            A list of ``(old_start, old_count, body_lines)`` tuples, one per
            ``@@ -a[,b] +c[,d] @@`` header. ``body_lines`` are the raw lines
            after the header, with trailing empty lines removed.

        Raises:
            ValueError: If a header lacks a ``-a[,b]`` range or its numbers are
                not integers.
        """
        hunks = []
        for line in diff_text.split("\n"):
            if line.startswith("@@"):
                fields = line.split("@@")
                ranges = fields[1].split() if len(fields) > 1 else []
                old_range = next((r for r in ranges if r.startswith("-")), None)
                if old_range is None:
                    raise ValueError(f"Malformed hunk header: {line!r}")
                start_text, _, count_text = old_range[1:].partition(",")
                hunks.append((int(start_text), int(count_text) if count_text else 1, []))
            elif hunks:
                hunks[-1][2].append(line)

        for _, _, body in hunks:
            # Blank lines right before the closing fence are not part of the hunk.
            while body and body[-1] == "":
                body.pop()
        return hunks

    @classmethod
    def _apply_diff_in_memory(cls, original: str, diff_text: str):
        """Apply a single-file unified diff to *original* in memory.

        Each hunk's context and removed lines must match the current text
        exactly. They are looked for first at the header position (shifted by
        earlier hunks), then anywhere after the previous hunk.

        Returns:
            The patched text, or None if the diff has no hunks or a hunk cannot
            be located.

        Raises:
            ValueError: On a malformed hunk header.
        """
        hunks = cls._parse_hunks(diff_text)
        if not hunks:
            return None

        lines = original.splitlines()
        offset = 0  # drift between header line numbers and the current text
        floor = 0   # first index not already produced by an earlier hunk
        for old_start, old_count, body in hunks:
            before, after = [], []
            for raw in body:
                # An empty line is treated as an empty context line; LLMs
                # often drop the leading space on blank context lines.
                marker, text = (raw[0], raw[1:]) if raw else (" ", "")
                if marker in (" ", "-"):
                    before.append(text)
                if marker in (" ", "+"):
                    after.append(text)
                # Other markers (e.g. "\ No newline at end of file") are ignored.

            # A zero-length range names the line *after which* the hunk applies.
            header_index = old_start if old_count == 0 else old_start - 1
            expected = max(header_index, 0) + offset
            size = len(before)
            last = len(lines) - size
            candidates = [expected] + list(range(floor, last + 1))
            pos = next(
                (i for i in candidates if floor <= i <= last and lines[i:i + size] == before),
                None,
            )
            if pos is None:
                return None

            lines[pos:pos + size] = after
            offset += (pos - expected) + len(after) - size
            floor = pos + len(after)

        result = "\n".join(lines)
        # Keep a trailing newline if the original had one (or the file is new).
        if lines and (original.endswith("\n") or not original):
            result += "\n"
        return result

    def _merge_code_without_effect(self, content: str) -> MergeCodeWithoutEffect:
        """Apply all diff blocks in memory, without git operations or file writes.

        Existing files are read from disk once. Later blocks for the same file
        are applied on top of earlier results. Files that do not exist yet are
        patched starting from empty content.

        Returns:
            MergeCodeWithoutEffect with:
              - ``success_blocks``: ``(absolute_path, content)`` for every
                existing file touched (patched, or unchanged if its block
                failed) and for every newly created file;
              - ``failed_blocks``: ``(absolute_path, diff_text)`` for blocks
                that could not be applied.
        """
        diff_blocks = self.parse_diff_block(content)
        file_content_mapping = {}
        failed_blocks = []

        for block in diff_blocks:
            full_path = self.abs_root_path(block.path)
            diff_text = block.content

            current = file_content_mapping.get(full_path)
            if current is None and os.path.exists(full_path):
                current = FileUtils.read_file(full_path)
                # Keep the unmodified content even if this block fails.
                file_content_mapping[full_path] = current

            try:
                new_content = self._apply_diff_in_memory(current or "", diff_text)
            except Exception as e:
                self.printer.print_in_terminal("merge_failed", style="yellow", path=full_path, error=str(e))
                new_content = None

            if new_content is None:
                failed_blocks.append((full_path, diff_text))
            else:
                file_content_mapping[full_path] = new_content

        return MergeCodeWithoutEffect(
            success_blocks=[(path, text) for path, text in file_content_mapping.items()],
            failed_blocks=failed_blocks
        )

    def print_diff_blocks(self, diff_blocks: List[PathAndCode]):
        """Print diff blocks, grouped by file path, in a rich panel for review."""
        from rich.syntax import Syntax
        from rich.panel import Panel

        # Group blocks by file path, keeping first-seen order.
        file_blocks = {}
        for block in diff_blocks:
            file_blocks.setdefault(block.path, []).append(block.content)

        formatted_text = ""
        for path, contents in file_blocks.items():
            formatted_text += f"##File: {path}\n"
            for content in contents:
                formatted_text += content + "\n"
            formatted_text += "\n"

        self.printer.print_in_terminal("diff_blocks_title", style="bold green")
        self.printer.console.print(
            Panel(
                Syntax(formatted_text, "diff", theme="monokai"),
                title="Diff Blocks",
                border_style="green",
                expand=False
            )
        )

    def _merge_code(self, content: str, force_skip_git: bool = False):
        """Apply the diff blocks in *content* to the files on disk.

        When git is used (neither ``force_skip_git`` nor ``args.skip_commit``):
          1. A "pre" commit is made before patching.
          2. After patching, a commit with the user query is made.
          3. The changed files are recorded in the action YAML file.
          4. When enabled, the changes are passed to the active-context manager.

        Otherwise, the diff blocks are printed for review.

        Raises:
            Exception: If a diff block cannot be parsed or applied by python-patch.
        """
        import hashlib

        use_git = not force_skip_git and not self.args.skip_commit

        file_content = FileUtils.read_file(self.args.file)
        md5 = hashlib.md5(file_content.encode('utf-8')).hexdigest()
        file_name = os.path.basename(self.args.file)

        if use_git:
            try:
                git_utils.commit_changes(self.args.source_dir, f"auto_coder_pre_{file_name}_{md5}")
            except Exception as e:
                self.printer.print_in_terminal("git_init_required", style="red", source_dir=self.args.source_dir, error=str(e))
                return

        diff_blocks = self.parse_diff_block(content)

        merged_paths = set()
        for block in diff_blocks:
            import patch

            patch_obj = patch.fromstring(block.content.encode('utf-8'))
            root_path = None
            if not block.path.startswith(self.args.source_dir):
                root_path = self.args.source_dir

            # fromstring returns False when the diff cannot be parsed.
            if not patch_obj or not patch_obj.apply(root=root_path):
                raise Exception("Error applying diff to file: " + block.path)
            merged_paths.add(block.path)

        total = len(merged_paths)
        self.printer.print_in_terminal("files_merged_total", total=total)
        if use_git:
            from autocoder.common.action_yml_file_manager import ActionYmlFileManager

            commit_result = git_utils.commit_changes(
                self.args.source_dir, f"{self.args.query}\nauto_coder_{file_name}"
            )
            action_yml_file_manager = ActionYmlFileManager(self.args.source_dir)
            action_file_name = os.path.basename(self.args.file)
            add_updated_urls = [
                os.path.join(self.args.source_dir, file) for file in commit_result.changed_files
            ]

            self.args.add_updated_urls = add_updated_urls
            update_yaml_success = action_yml_file_manager.update_yaml_field(action_file_name, "add_updated_urls", add_updated_urls)
            if not update_yaml_success:
                self.printer.print_in_terminal("yaml_save_error", style="red", yaml_file=action_file_name)

            if self.args.enable_active_context:
                from autocoder.memory.active_context_manager import ActiveContextManager

                active_context_manager = ActiveContextManager(self.llm, self.args.source_dir)
                active_context_manager.process_changes(self.args)

            git_utils.print_commit_info(commit_result=commit_result)
        else:
            # Show the full list of diff blocks for manual review.
            self.print_diff_blocks(diff_blocks)
|
|
@@ -0,0 +1,266 @@
|
|
|
1
|
+
from typing import List, Dict, Tuple, Optional, Any
|
|
2
|
+
import os
|
|
3
|
+
import json
|
|
4
|
+
import time
|
|
5
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
6
|
+
|
|
7
|
+
import byzerllm
|
|
8
|
+
from byzerllm.utils.client import code_utils
|
|
9
|
+
|
|
10
|
+
from autocoder.common.types import Mode, CodeGenerateResult, MergeCodeWithoutEffect
|
|
11
|
+
from autocoder.common import AutoCoderArgs, git_utils, SourceCodeList
|
|
12
|
+
from autocoder.common import sys_prompt
|
|
13
|
+
from autocoder.privacy.model_filter import ModelPathFilter
|
|
14
|
+
from autocoder.common.utils_code_auto_generate import chat_with_continue, stream_chat_with_continue, ChatWithContinueResult
|
|
15
|
+
from autocoder.utils.auto_coder_utils.chat_stream_out import stream_out
|
|
16
|
+
from autocoder.common.stream_out_type import CodeGenerateStreamOutType
|
|
17
|
+
from autocoder.common.auto_coder_lang import get_message_with_format
|
|
18
|
+
from autocoder.common.printer import Printer
|
|
19
|
+
from autocoder.rag.token_counter import count_tokens
|
|
20
|
+
from autocoder.utils import llms as llm_utils
|
|
21
|
+
from autocoder.memory.active_context_manager import ActiveContextManager
|
|
22
|
+
from autocoder.common.v2.code_auto_generate_diff import CodeAutoGenerateDiff
|
|
23
|
+
from autocoder.common.v2.code_auto_merge_diff import CodeAutoMergeDiff
|
|
24
|
+
from autocoder.shadows.shadow_manager import ShadowManager
|
|
25
|
+
from autocoder.linters.shadow_linter import ShadowLinter
|
|
26
|
+
from autocoder.linters.models import IssueSeverity
|
|
27
|
+
from loguru import logger
|
|
28
|
+
from autocoder.common.global_cancel import global_cancel
|
|
29
|
+
from autocoder.linters.models import ProjectLintResult
|
|
30
|
+
from autocoder.common.token_cost_caculate import TokenCostCalculator
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class CodeDiffManager:
    """
    Combine code generation, linting and merging with automatic error correction.

    Code is generated in unified-diff format and applied to shadow copies of the
    affected files, which are then linted. While lint errors remain, the code is
    regenerated with the lint issues in the prompt. This repeats for at most
    ``args.auto_fix_lint_max_attempts`` lint rounds, after which the final
    result is merged into the real files.
    """

    def __init__(
        self,
        llm: byzerllm.ByzerLLM,
        args: AutoCoderArgs,
        action=None,
    ) -> None:
        self.llm = llm
        self.args = args
        self.action = action
        self.generate_times_same_model = args.generate_times_same_model
        self.max_correction_attempts = args.auto_fix_lint_max_attempts
        self.printer = Printer()

        # Sub-components: diff generator and diff merger.
        self.code_generator = CodeAutoGenerateDiff(llm, args, action)
        self.code_merger = CodeAutoMergeDiff(llm, args)

        # Shadow copies of project files are linted instead of the real files.
        self.shadow_manager = ShadowManager(args.source_dir)
        self.shadow_linter = ShadowLinter(self.shadow_manager, verbose=False)

    # NOTE: the docstring below is the prompt template rendered by
    # byzerllm.prompt(); it is sent to the model verbatim and must not be edited
    # as documentation.
    @byzerllm.prompt()
    def fix_linter_errors(self, query: str, lint_issues: str) -> str:
        """
        Linter 检测到的问题:
        <lint_issues>
        {{ lint_issues }}
        </lint_issues>

        用户原始需求:
        <user_query_wrapper>
        {{ query }}
        </user_query_wrapper>

        修复上述问题,请确保代码质量问题被解决,同时保持代码的原有功能。
        请使用 unified diff 格式输出修改。
        """

    def _create_shadow_files_from_edits(self, generation_result: CodeGenerateResult) -> Dict[str, str]:
        """
        Merge the best candidate of *generation_result* in memory and write the
        merged files into shadow copies for linting.

        Args:
            generation_result (CodeGenerateResult): Generated diff-format content.

        Returns:
            Dict[str, str]: Mapping ``{shadow_file_path: content}``. It is empty
            when there is nothing to merge.
        """
        # An empty generation (e.g. a failed fix round) has nothing to merge.
        if not generation_result.contents:
            return {}
        result = self.code_merger.choose_best_choice(generation_result)
        if not result.contents:
            return {}

        merge = self.code_merger._merge_code_without_effect(result.contents[0])
        shadow_files = {}
        for file_path, new_content in merge.success_blocks:
            self.shadow_manager.update_file(file_path, new_content)
            shadow_files[self.shadow_manager.to_shadow_path(file_path)] = new_content

        return shadow_files

    def _format_lint_issues(self, lint_results: ProjectLintResult, level: IssueSeverity) -> str:
        """
        Format lint results as text for the model prompt.

        Args:
            lint_results: Project-wide linter result.
            level: Only issues with exactly this severity are included.

        Returns:
            str: One section per file with matching issues, each section followed
            by a blank line. The output text itself is Chinese, matching the
            prompt.
        """
        formatted_issues = []

        for file_path, result in lint_results.file_results.items():
            file_has_issues = False
            file_issues = []

            for issue in result.issues:
                if issue.severity.value != level.value:
                    continue

                if not file_has_issues:
                    file_has_issues = True
                    file_issues.append(f"文件: {file_path}")

                severity = "错误" if issue.severity == IssueSeverity.ERROR else "警告" if issue.severity == IssueSeverity.WARNING else "信息"
                line_info = f"第{issue.position.line}行"
                if issue.position.column:
                    line_info += f", 第{issue.position.column}列"

                file_issues.append(
                    f"  - [{severity}] {line_info}: {issue.message} (规则: {issue.code})"
                )

            if file_has_issues:
                formatted_issues.extend(file_issues)
                formatted_issues.append("")  # blank line between files

        return "\n".join(formatted_issues)

    def _count_errors(self, lint_results: ProjectLintResult) -> int:
        """
        Count the errors in lint results.

        Args:
            lint_results: Project-wide linter result.

        Returns:
            int: Sum of ``error_count`` over all file results.
        """
        return sum(result.error_count for result in lint_results.file_results.values())

    def generate_and_fix(self, query: str, source_code_list: SourceCodeList) -> CodeGenerateResult:
        """
        Generate code, lint it, and regenerate to fix lint errors.

        Each lint round counts as one attempt; there are at most
        ``max_correction_attempts`` attempts. Shadow files are always cleaned up
        when the loop ends, including on cancellation or error.

        Args:
            query (str): User query.
            source_code_list (SourceCodeList): Source code context.

        Returns:
            CodeGenerateResult: The last usable generation result.
        """
        # Initial code generation.
        self.printer.print_in_terminal("generating_initial_code")
        start_time = time.time()
        generation_result = self.code_generator.single_round_run(query, source_code_list)

        token_cost_calculator = TokenCostCalculator(args=self.args)
        token_cost_calculator.track_token_usage_by_generate(
            llm=self.llm,
            generate=generation_result,
            operation_name="code_generation_complete",
            start_time=start_time,
            end_time=time.time()
        )

        # Nothing to lint or merge if the model produced no candidates.
        if not generation_result.contents:
            self.printer.print_in_terminal("generation_failed", style="red")
            return generation_result

        try:
            for attempt in range(self.max_correction_attempts):
                global_cancel.check_and_raise()
                # Write the current generation into the shadow files.
                shadow_files = self._create_shadow_files_from_edits(generation_result)

                if not shadow_files:
                    self.printer.print_in_terminal("no_files_to_lint", style="yellow")
                    break

                lint_results = self.shadow_linter.lint_all_shadow_files()
                error_count = self._count_errors(lint_results)

                if error_count == 0:
                    self.printer.print_in_terminal("no_lint_errors_found", style="green")
                    break

                formatted_issues = self._format_lint_issues(lint_results, IssueSeverity.ERROR)

                self.printer.print_in_terminal(
                    "lint_attempt_status",
                    style="yellow",
                    attempt=(attempt + 1),
                    max_correction_attempts=self.max_correction_attempts,
                    error_count=error_count,
                    formatted_issues=formatted_issues
                )

                # On the last attempt, keep the current result rather than
                # generating a fix that would never be linted.
                if attempt == self.max_correction_attempts - 1:
                    self.printer.print_in_terminal("max_attempts_reached", style="yellow")
                    break

                fix_prompt = self.fix_linter_errors.prompt(
                    query=query,
                    lint_issues=formatted_issues
                )
                logger.debug(
                    f"Regenerating to fix lint errors for: "
                    f"{[source.module_name for source in source_code_list.sources]}"
                )
                logger.debug(f"fix_prompt: {fix_prompt}")

                # Use the shadow (already-modified) contents as the new context.
                # NOTE(review): the fix diff is generated against the shadow
                # content; confirm code_merger applies it against the same base.
                source_code_list = self.code_merger.get_source_code_list_from_shadow_files(shadow_files)
                start_time = time.time()
                fixed_result = self.code_generator.single_round_run(fix_prompt, source_code_list)
                token_cost_calculator.track_token_usage_by_generate(
                    llm=self.llm,
                    generate=fixed_result,
                    operation_name="code_generation_complete",
                    start_time=start_time,
                    end_time=time.time()
                )

                if not fixed_result.contents:
                    # Keep the previous (linted) result instead of an empty one.
                    logger.warning("Lint-fix generation returned no content; keeping previous result")
                    break
                generation_result = fixed_result
        finally:
            # Remove temporary shadow files even when cancelled or on error.
            self.shadow_manager.clean_shadows()

        return generation_result

    def run(self, query: str, source_code_list: SourceCodeList) -> CodeGenerateResult:
        """
        Run the full pipeline: generate, lint-fix, then merge.

        Args:
            query (str): User query.
            source_code_list (SourceCodeList): Source code context.

        Returns:
            CodeGenerateResult: The generated (and possibly lint-fixed) result
            that was merged.
        """
        generation_result = self.generate_and_fix(query, source_code_list)
        global_cancel.check_and_raise()
        self.code_merger.merge_code(generation_result)

        return generation_result
|