autocoder-nano 0.1.25__py3-none-any.whl → 0.1.27__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- autocoder_nano/agent/agent_base.py +376 -63
- autocoder_nano/auto_coder_nano.py +147 -1842
- autocoder_nano/edit/__init__.py +20 -0
- autocoder_nano/edit/actions.py +136 -0
- autocoder_nano/edit/code/__init__.py +0 -0
- autocoder_nano/edit/code/generate_editblock.py +403 -0
- autocoder_nano/edit/code/merge_editblock.py +418 -0
- autocoder_nano/edit/code/modification_ranker.py +90 -0
- autocoder_nano/edit/text.py +38 -0
- autocoder_nano/index/__init__.py +0 -0
- autocoder_nano/index/entry.py +166 -0
- autocoder_nano/index/index_manager.py +410 -0
- autocoder_nano/index/symbols_utils.py +43 -0
- autocoder_nano/llm_types.py +12 -8
- autocoder_nano/version.py +1 -1
- {autocoder_nano-0.1.25.dist-info → autocoder_nano-0.1.27.dist-info}/METADATA +1 -1
- {autocoder_nano-0.1.25.dist-info → autocoder_nano-0.1.27.dist-info}/RECORD +21 -10
- {autocoder_nano-0.1.25.dist-info → autocoder_nano-0.1.27.dist-info}/LICENSE +0 -0
- {autocoder_nano-0.1.25.dist-info → autocoder_nano-0.1.27.dist-info}/WHEEL +0 -0
- {autocoder_nano-0.1.25.dist-info → autocoder_nano-0.1.27.dist-info}/entry_points.txt +0 -0
- {autocoder_nano-0.1.25.dist-info → autocoder_nano-0.1.27.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,418 @@
|
|
1
|
+
import hashlib
|
2
|
+
import os
|
3
|
+
import subprocess
|
4
|
+
import tempfile
|
5
|
+
from typing import List
|
6
|
+
|
7
|
+
from loguru import logger
|
8
|
+
from rich.console import Console
|
9
|
+
from rich.panel import Panel
|
10
|
+
from rich.syntax import Syntax
|
11
|
+
from rich.table import Table
|
12
|
+
|
13
|
+
from autocoder_nano.edit.code.modification_ranker import CodeModificationRanker
|
14
|
+
from autocoder_nano.edit.text import TextSimilarity
|
15
|
+
from autocoder_nano.git_utils import commit_changes
|
16
|
+
from autocoder_nano.llm_client import AutoLLM
|
17
|
+
from autocoder_nano.llm_prompt import prompt
|
18
|
+
from autocoder_nano.llm_types import AutoCoderArgs, PathAndCode, MergeCodeWithoutEffect, CodeGenerateResult, \
|
19
|
+
CommitResult
|
20
|
+
|
21
|
+
|
22
|
+
console = Console()
|
23
|
+
|
24
|
+
|
25
|
+
def git_print_commit_info(commit_result: CommitResult):
    """Print a summary panel for a commit, followed by one diff panel per file.

    :param commit_result: object exposing ``commit_hash``, ``commit_message``,
        ``changed_files`` (joined with newlines) and ``diffs`` (iterated via
        ``.items()`` as file -> diff text; skipped when falsy).
    """
    summary = Table(
        title="Commit Information (Use /revert to revert this commit)", show_header=True, header_style="bold magenta"
    )
    summary.add_column("Attribute", style="cyan", no_wrap=True)
    summary.add_column("Value", style="green")

    rows = (
        ("Commit Hash", commit_result.commit_hash),
        ("Commit Message", commit_result.commit_message),
        ("Changed Files", "\n".join(commit_result.changed_files)),
    )
    for attribute, value in rows:
        summary.add_row(attribute, value)

    console.print(Panel(summary, expand=False, border_style="green", title="Git Commit Summary"))

    # Nothing more to show when the commit carries no per-file diffs.
    if not commit_result.diffs:
        return

    for file, diff in commit_result.diffs.items():
        console.print(f"\n[bold blue]File: {file}[/bold blue]")
        highlighted = Syntax(diff, "diff", theme="monokai", line_numbers=True)
        console.print(Panel(highlighted, expand=False, border_style="yellow", title="File Diff"))
|
47
|
+
|
48
|
+
|
49
|
+
class CodeAutoMergeEditBlock:
|
50
|
+
def __init__(self, args: AutoCoderArgs, llm: AutoLLM, fence_0: str = "```", fence_1: str = "```"):
    """Store the configuration and point the LLM client at ``args.code_model``.

    :param args: run configuration; ``code_model`` is read here.
    :param llm: LLM client; its default model name is set to ``args.code_model``.
    :param fence_0: prefix that opens a fenced edit block.
    :param fence_1: prefix that closes a fenced edit block.
    """
    llm.setup_default_model_name(args.code_model)
    self.llm = llm
    self.args = args
    self.fence_0, self.fence_1 = fence_0, fence_1
|
56
|
+
|
57
|
+
@staticmethod
def run_pylint(code: str) -> tuple[bool, str]:
    """
    Lint ``code`` with a restricted pylint rule set.

    pylint is invoked with ``--disable=all`` and then
    ``--enable=E0001,W0311,W0312``, i.e. only:
      - E0001: syntax error
      - W0311: bad indentation
      - W0312: mixed indentation

    :param code: Python source text to check.
    :return: ``(True, "")`` when pylint exits with status 0; otherwise
        ``(False, message)`` where ``message`` is pylint's stdout (falling
        back to stderr), or a description of why pylint could not be run.
    """
    # delete=False because pylint reopens the file by path; it is removed in
    # the ``finally`` below. UTF-8 is forced because the default source
    # encoding for Python files is UTF-8, whereas mode="w" alone would use
    # the platform locale encoding.
    temp_file = tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False, encoding="utf-8")
    temp_file_path = temp_file.name

    try:
        with temp_file:
            temp_file.write(code)
        result = subprocess.run(
            ["pylint", "--disable=all", "--enable=E0001,W0311,W0312", temp_file_path],
            capture_output=True,
            text=True,
            check=False,
        )
    except (OSError, subprocess.SubprocessError) as e:
        # With check=False, CalledProcessError is never raised; the realistic
        # failure is OSError (e.g. the pylint executable is not installed).
        error_message = f"运行 Pylint 时发生错误: {str(e)}"
        logger.error(error_message)
        return False, error_message
    finally:
        # Remove the scratch file on every path, including unexpected errors.
        if os.path.exists(temp_file_path):
            os.unlink(temp_file_path)

    if result.returncode != 0:
        error_message = result.stdout.strip() or result.stderr.strip()
        logger.warning(f"Pylint 检查代码失败: {error_message}")
        return False, error_message
    return True, ""
|
90
|
+
|
91
|
+
def parse_whole_text(self, text: str) -> List[PathAndCode]:
    """
    Extract fenced SEARCH/REPLACE edit blocks from ``text``.

    Two layouts are recognised.

    two_line_mode (the path is on the line after the opening fence):

    ```python
    ##File: /project/path/src/autocoder/index/index.py
    <<<<<<< SEARCH
    =======
    >>>>>>> REPLACE
    ```

    one_line_mode (the path follows ``:`` on the opening fence line):

    ```python:/project/path/src/autocoder/index/index.py
    <<<<<<< SEARCH
    =======
    >>>>>>> REPLACE
    ```

    A block is closed by a line starting with ``self.fence_1`` whose
    preceding line contains ``>>>>>>> REPLACE``. A block that is never
    closed is discarded.

    :param text: raw text to scan (split on ``"\\n"``).
    :return: one ``PathAndCode`` per closed block, in order of appearance;
        ``content`` excludes the fence lines and the ``##File:`` line.
    """
    HEAD = "<<<<<<< SEARCH"
    UPDATED = ">>>>>>> REPLACE"
    lines = text.split("\n")
    lines_len = len(lines)
    in_block = False
    block = []
    path_and_code_list = []
    # two_line_mode or one_line_mode
    current_editblock_mode = "two_line_mode"
    current_editblock_path = None

    def guard(_index):
        # True when there is a following line to peek at.
        return _index + 1 < lines_len

    def start_marker(_line, _index):
        # NOTE: records the mode/path of the block being opened, so it must
        # only be called while no block is open.
        nonlocal current_editblock_mode
        nonlocal current_editblock_path
        if _line.startswith(self.fence_0) and guard(_index) and ":" in _line and lines[_index + 1].startswith(HEAD):
            current_editblock_mode = "one_line_mode"
            current_editblock_path = _line.split(":", 1)[1].strip()
            return True
        if _line.startswith(self.fence_0) and guard(_index) and lines[_index + 1].startswith("##File:"):
            current_editblock_mode = "two_line_mode"
            current_editblock_path = None
            return True
        return False

    def end_marker(_line, _index):
        return _line.startswith(self.fence_1) and UPDATED in lines[_index - 1]

    for index, line in enumerate(lines):
        # Check the state first: evaluating start_marker() while a block is
        # open would overwrite that block's mode/path if its body happens to
        # contain a line that looks like an opening fence.
        if not in_block and start_marker(line, index):
            in_block = True
        elif in_block and end_marker(line, index):
            in_block = False
            if block:
                if current_editblock_mode == "two_line_mode":
                    # block[0] is the "##File: <path>" line.
                    path = block[0].split(":", 1)[1].strip()
                    content = "\n".join(block[1:])
                else:
                    path = current_editblock_path
                    content = "\n".join(block)
                block = []
                path_and_code_list.append(PathAndCode(path=path, content=content))
        elif in_block:
            block.append(line)

    return path_and_code_list
|
159
|
+
|
160
|
+
def get_edits(self, content: str):
    """Split every parsed edit block into its SEARCH and REPLACE halves.

    :param content: raw text handed to ``self.parse_whole_text``.
    :return: list of ``(path, search_text, replace_text)`` tuples, one per
        parsed block. Marker lines are matched after ``strip()`` and are not
        included in either text.
    """
    search_marker = "<<<<<<< SEARCH"
    divider_marker = "======="
    replace_marker = ">>>>>>> REPLACE"

    triples = []
    for edit in self.parse_whole_text(content):
        search_lines, replace_lines = [], []
        # Two independent flags rather than one state: a SEARCH marker does
        # not switch off REPLACE collection.
        collecting_search = collecting_replace = False
        for raw_line in edit.content.splitlines():
            marker = raw_line.strip()
            if marker == search_marker:
                collecting_search = True
            elif marker == divider_marker:
                collecting_search, collecting_replace = False, True
            elif marker == replace_marker:
                collecting_search = collecting_replace = False
            else:
                if collecting_search:
                    search_lines.append(raw_line)
                if collecting_replace:
                    replace_lines.append(raw_line)
        triples.append((edit.path, "\n".join(search_lines), "\n".join(replace_lines)))
    return triples
|
190
|
+
|
191
|
+
# The docstring below is runtime text, not documentation: it carries the
# {{ source_dir }} / {{ error }} placeholders and is presumably rendered by
# @prompt() (confirm in autocoder_nano.llm_prompt). Do not reword it.
@prompt()
def git_require_msg(self, source_dir: str, error: str) -> str:
    """
    auto_merge only works for git repositories.

    Try to use git init in the source directory.

    ```shell
    cd {{ source_dir }}
    git init .
    ```

    Then try to run auto-coder again.
    Error: {{ error }}
    """
|
206
|
+
|
207
|
+
def _merge_code_without_effect(self, content: str) -> MergeCodeWithoutEffect:
    """
    Dry-run the merge: no git operations, no lint check, no file writes.

    Existing files are read (once each) so that several blocks targeting the
    same file are applied on top of one another in memory.

    :param content: raw text containing the edit blocks.
    :return: ``MergeCodeWithoutEffect`` with
        - ``success_blocks``: ``(file_path, new_content)`` for every file seen,
          holding its content after all applicable blocks;
        - ``failed_blocks``: ``(file_path, head, update)`` for every block that
          left the file content unchanged.
    """
    pending_contents = {}
    rejected = []

    for file_path, head, update in self.get_edits(content):
        # A file that does not exist yet simply takes the replacement text.
        if not os.path.exists(file_path):
            pending_contents[file_path] = update
            continue

        if file_path not in pending_contents:
            with open(file_path, "r") as f:
                pending_contents[file_path] = f.read()
        current = pending_contents[file_path]

        # Exact match first; an empty SEARCH part means "append".
        if head:
            candidate = current.replace(head, update, 1)
        else:
            candidate = current + "\n" + update

        # Fall back to the most similar window when nothing matched exactly.
        if candidate == current and head:
            similarity, best_window = TextSimilarity(head, current).get_best_matching_window()
            if similarity > self.args.editblock_similarity:
                candidate = current.replace(best_window, update, 1)

        if candidate == current:
            rejected.append((file_path, head, update))
        else:
            pending_contents[file_path] = candidate

    return MergeCodeWithoutEffect(
        success_blocks=list(pending_contents.items()),
        failed_blocks=rejected,
    )
|
256
|
+
|
257
|
+
def choose_best_choice(self, generate_result: CodeGenerateResult) -> CodeGenerateResult:
    """Pick a single candidate from ``generate_result``.

    A lone candidate is returned untouched. Otherwise the candidates are
    ranked and the first one whose dry-run merge has no failed blocks wins;
    if every candidate has failed blocks, the top-ranked one is used.

    :return: the input itself (single candidate) or a new
        ``CodeGenerateResult`` holding exactly one content/conversation pair.
    """
    if len(generate_result.contents) == 1:
        logger.info("仅有一个候选结果,跳过排序")
        return generate_result

    ranked = CodeModificationRanker(args=self.args, llm=self.llm).rank_modifications(generate_result)

    # Lazily walk the ranking so the dry-run stops at the first clean candidate.
    clean_pick = next(
        (
            (candidate, conversation)
            for candidate, conversation in zip(ranked.contents, ranked.conversations)
            if not self._merge_code_without_effect(candidate).failed_blocks
        ),
        None,
    )
    if clean_pick is None:
        clean_pick = (ranked.contents[0], ranked.conversations[0])

    chosen_content, chosen_conversation = clean_pick
    return CodeGenerateResult(contents=[chosen_content], conversations=[chosen_conversation])
|
272
|
+
|
273
|
+
def _merge_code(self, content: str, force_skip_git: bool = False):
    """
    Apply the edit blocks in ``content`` to disk, optionally bracketed by git commits.

    Steps:
      1. Resolve every block in memory (exact match, then similarity match
         against ``self.args.editblock_similarity``).
      2. If any block could not be matched, report and abort without writing.
      3. Run the pylint check on every resulting ``.py`` file (warning only).
      4. Unless git is skipped, commit the current state, write the files,
         then commit again and print the commit summary.

    :param content: raw text containing the edit blocks.
    :param force_skip_git: when True, no commit is made regardless of
        ``self.args.skip_commit``.
    :return: None.
    """
    # Use a context manager so the handle is closed deterministically.
    with open(self.args.file) as request_file:
        file_content = request_file.read()
    # The hash and base name of args.file are only used to label the commits.
    md5 = hashlib.md5(file_content.encode("utf-8")).hexdigest()
    file_name = os.path.basename(self.args.file)

    codes = self.get_edits(content)
    changes_to_make = []
    changes_made = False
    unmerged_blocks = []

    # First, resolve all blocks in memory.
    file_content_mapping = {}
    for file_path, head, update in codes:
        if not os.path.exists(file_path):
            changes_to_make.append((file_path, None, update))
            file_content_mapping[file_path] = update
            changes_made = True
            continue

        if file_path not in file_content_mapping:
            with open(file_path, "r") as f:
                file_content_mapping[file_path] = f.read()
        existing_content = file_content_mapping[file_path]
        new_content = (
            existing_content.replace(head, update, 1)
            if head
            else existing_content + "\n" + update
        )
        if new_content != existing_content:
            changes_to_make.append((file_path, existing_content, new_content))
            file_content_mapping[file_path] = new_content
            changes_made = True
        else:
            # If the SEARCH block is not found exactly, fall back to the most
            # similar window of the file.
            similarity, best_window = TextSimilarity(head, existing_content).get_best_matching_window()
            if similarity > self.args.editblock_similarity:
                new_content = existing_content.replace(best_window, update, 1)
                # NOTE(review): a block whose replacement leaves the content
                # unchanged is neither applied nor reported as unmerged.
                if new_content != existing_content:
                    changes_to_make.append((file_path, existing_content, new_content))
                    file_content_mapping[file_path] = new_content
                    changes_made = True
            else:
                unmerged_blocks.append((file_path, head, update, similarity))

    if unmerged_blocks:
        # NOTE(review): with a request_id and events enabled this path returns
        # without any report; no event is emitted from this function.
        if self.args.request_id and not self.args.skip_events:
            return
        logger.warning(f"发现 {len(unmerged_blocks)} 个未合并的代码块,更改将不会应用,请手动检查这些代码块后重试。")
        self._print_unmerged_blocks(unmerged_blocks)
        return

    # lint check
    # NOTE(review): the warning text says the change is not applied, but
    # execution continues and the files are written below. Confirm which of
    # the two is intended before changing either.
    for file_path, new_content in file_content_mapping.items():
        if file_path.endswith(".py"):
            pylint_passed, error_message = self.run_pylint(new_content)
            if not pylint_passed:
                logger.warning(f"代码文件 {file_path} 的 Pylint 检查未通过,本次更改未应用。错误信息: {error_message}")

    use_git = not force_skip_git and not self.args.skip_commit

    # Snapshot the current state first so the change can be reverted.
    if changes_made and use_git:
        try:
            commit_changes(self.args.source_dir, f"auto_coder_pre_{file_name}_{md5}")
        except Exception as e:
            logger.error(
                self.git_require_msg(source_dir=self.args.source_dir, error=str(e))
            )
            return

    # Now, apply the changes.
    for file_path, new_content in file_content_mapping.items():
        # os.makedirs("") raises, so skip it for paths without a directory part.
        parent_dir = os.path.dirname(file_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        with open(file_path, "w") as f:
            f.write(new_content)

    if changes_made:
        if use_git:
            try:
                commit_result = commit_changes(self.args.source_dir, f"auto_coder_{file_name}_{md5}")
                git_print_commit_info(commit_result=commit_result)
            except Exception as e:
                logger.error(
                    self.git_require_msg(source_dir=self.args.source_dir, error=str(e))
                )
        logger.info(
            f"已在 {len(file_content_mapping.keys())} 个文件中合并更改,"
            f"完成 {len(changes_to_make)}/{len(codes)} 个代码块。"
        )
    else:
        logger.warning("未对任何文件进行更改。")
|
399
|
+
|
400
|
+
def merge_code(self, generate_result: CodeGenerateResult, force_skip_git: bool = False):
    """Choose the best candidate, merge its first content, and return the choice.

    :param generate_result: candidate generations to choose from.
    :param force_skip_git: forwarded unchanged to ``_merge_code``.
    :return: the result produced by ``choose_best_choice``.
    """
    best = self.choose_best_choice(generate_result)
    winning_content = best.contents[0]
    self._merge_code(winning_content, force_skip_git)
    return best
|
404
|
+
|
405
|
+
@staticmethod
def _print_unmerged_blocks(unmerged_blocks: List[tuple]):
    """Print each unmerged block's file, SEARCH text and REPLACE text, then a total.

    :param unmerged_blocks: ``(file_path, head, update, similarity)`` tuples.
    """

    def show_snippet(snippet):
        # Both halves of a block are rendered the same way.
        highlighted = Syntax(snippet, "python", theme="monokai", line_numbers=True)
        console.print(Panel(highlighted, expand=False))

    console.print("\n[bold red]未合并的代码块:[/bold red]")
    for file_path, head, update, similarity in unmerged_blocks:
        console.print(f"\n[bold blue]文件:[/bold blue] {file_path}")
        console.print(f"\n[bold green]搜索代码块(相似度:{similarity}):[/bold green]")
        show_snippet(head)
        console.print("\n[bold yellow]替换代码块:[/bold yellow]")
        show_snippet(update)
    total = len(unmerged_blocks)
    console.print(f"\n[bold red]未合并的代码块总数: {total}[/bold red]")
|
@@ -0,0 +1,90 @@
|
|
1
|
+
import traceback
|
2
|
+
|
3
|
+
from loguru import logger
|
4
|
+
|
5
|
+
from autocoder_nano.llm_client import AutoLLM
|
6
|
+
from autocoder_nano.llm_prompt import prompt
|
7
|
+
from autocoder_nano.llm_types import AutoCoderArgs, CodeGenerateResult, RankResult
|
8
|
+
|
9
|
+
|
10
|
+
class CodeModificationRanker:
    """Ask the configured LLM(s) to rank candidate code modifications."""

    def __init__(self, args: AutoCoderArgs, llm: AutoLLM):
        """Keep ``args`` and ``llm`` and select ``args.code_model`` as the default model."""
        llm.setup_default_model_name(args.code_model)
        self.args = args
        self.llm = llm
        # Every client in this list is queried once per ranking.
        self.llms = [llm]

    # The docstring below is the prompt text itself; keep it verbatim.
    @prompt()
    def _rank_modifications(self, s: CodeGenerateResult) -> str:
        """
        对一组代码修改进行质量评估并排序。

        下面是修改需求:

        <edit_requirement>
        {{ s.conversations[0][-2]["content"] }}
        </edit_requirement>

        下面是相应的代码修改:
        {% for content in s.contents %}
        <edit_block id="{{ loop.index0 }}">
        {{content}}
        </edit_block>
        {% endfor %}

        请输出如下格式的评估结果,只包含 JSON 数据:

        ```json
        {
        "rank_result": [id1, id2, id3] // id 为 edit_block 的 id,按质量从高到低排序
        }
        ```

        注意:
        1. 只输出前面要求的 Json 格式就好,不要输出其他内容,Json 需要使用 ```json ```包裹
        """

    def rank_modifications(self, generate_result: CodeGenerateResult) -> CodeGenerateResult:
        """Reorder the candidates of ``generate_result`` from best to worst.

        Each ranking contributes ``1 / (position + 1)`` to a candidate's score
        (position counted from 0); candidates are then sorted by total score.

        :return: a new ``CodeGenerateResult`` containing the ranked candidates,
            or ``generate_result`` unchanged when anything in the process raises.
        """
        import time
        from collections import defaultdict

        started_at = time.time()
        logger.info(f"开始对 {len(generate_result.contents)} 个候选结果进行排序")

        try:
            rankings = [
                self._rank_modifications.with_llm(llm).with_return_type(RankResult).run(generate_result).rank_result
                for llm in self.llms
            ]
            if not rankings:
                raise Exception("All ranking requests failed")

            # Accumulate a reciprocal-rank score per candidate id.
            candidate_scores = defaultdict(float)
            for ranking in rankings:
                for position, candidate_id in enumerate(ranking, start=1):
                    candidate_scores[candidate_id] += 1.0 / position

            # Highest score first.
            ordered_ids = sorted(candidate_scores, key=candidate_scores.get, reverse=True)

            elapsed = time.time() - started_at
            score_details = ", ".join(
                f"candidate {cid}: {candidate_scores[cid]:.2f}" for cid in ordered_ids
            )
            logger.info(
                f"排序完成,耗时 {elapsed:.2f} 秒,最佳候选索引: {ordered_ids[0]},评分详情: {score_details}"
            )

            return CodeGenerateResult(
                contents=[generate_result.contents[cid] for cid in ordered_ids],
                conversations=[generate_result.conversations[cid] for cid in ordered_ids],
            )

        except Exception as e:
            # Ranking is best-effort: fall back to the caller's original order.
            logger.error(f"排序过程失败: {str(e)}")
            logger.debug(traceback.format_exc())
            elapsed = time.time() - started_at
            logger.warning(f"排序失败,耗时 {elapsed:.2f} 秒,将使用原始顺序")
            return generate_result
|
@@ -0,0 +1,38 @@
|
|
1
|
+
from difflib import SequenceMatcher
|
2
|
+
|
3
|
+
|
4
|
+
class TextSimilarity:
    """
    Find the run of consecutive lines in ``text_b`` most similar to ``text_a``.

    A window as tall as ``text_a`` (in lines) slides over ``text_b``; each
    window is scored against ``text_a`` with ``difflib.SequenceMatcher``.
    """

    def __init__(self, text_a, text_b):
        """
        :param text_a: the text to look for.
        :param text_b: the text to search in.
        """
        self.text_a = text_a
        self.text_b = text_b
        self.lines_a = self._split_into_lines(text_a)
        self.lines_b = self._split_into_lines(text_b)
        self.m = len(self.lines_a)  # window height
        self.n = len(self.lines_b)

    @staticmethod
    def _split_into_lines(text):
        """Split on line boundaries, dropping the terminators."""
        return text.splitlines()

    @staticmethod
    def _levenshtein_ratio(s1, s2):
        """Return ``SequenceMatcher(None, s1, s2).ratio()``.

        Despite the name, this is difflib's matching-blocks ratio, not a
        Levenshtein distance.
        """
        return SequenceMatcher(None, s1, s2).ratio()

    def _window_text(self, start):
        """Return the exact slice of ``text_b`` covering lines ``start .. start + m - 1``.

        The original line terminators between the lines are kept so the result
        is always a substring of ``text_b``; the terminator after the last
        line is dropped. For ``\\n``-terminated text this equals
        ``"\\n".join(lines)``.
        """
        raw_lines = self.text_b.splitlines(keepends=True)
        stop = start + self.m
        chunk = "".join(raw_lines[start:stop])
        terminator_len = len(raw_lines[stop - 1]) - len(self.lines_b[stop - 1])
        return chunk[:len(chunk) - terminator_len] if terminator_len else chunk

    def get_best_matching_window(self):
        """
        :return: ``(similarity, window_text)`` for the best-scoring window;
            the first window wins ties. ``(0, "")`` when ``text_a`` is empty
            or has more lines than ``text_b``.
        """
        # An empty query would otherwise score 1.0 against an empty window,
        # and replacing "" in the target text inserts at position 0.
        if self.m == 0:
            return 0, ""

        needle = "\n".join(self.lines_a)  # loop-invariant
        best_similarity = 0
        best_start = None

        for start in range(self.n - self.m + 1):  # sliding window
            candidate = "\n".join(self.lines_b[start:start + self.m])
            similarity = self._levenshtein_ratio(needle, candidate)

            if similarity > best_similarity:
                best_similarity = similarity
                best_start = start
                if similarity == 1.0:
                    # Nothing can beat a perfect match, and ties keep the first.
                    break

        if best_start is None:
            return best_similarity, ""
        return best_similarity, self._window_text(best_start)
|
File without changes
|