auto-coder: auto_coder-0.1.279-py3-none-any.whl → auto_coder-0.1.281-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of auto-coder might be problematic. Click here for more details.
- {auto_coder-0.1.279.dist-info → auto_coder-0.1.281.dist-info}/METADATA +1 -1
- {auto_coder-0.1.279.dist-info → auto_coder-0.1.281.dist-info}/RECORD +15 -13
- autocoder/auto_coder.py +2 -1
- autocoder/common/context_pruner.py +168 -206
- autocoder/index/entry.py +1 -1
- autocoder/rag/doc_filter.py +104 -29
- autocoder/rag/lang.py +50 -0
- autocoder/rag/long_context_rag.py +218 -102
- autocoder/rag/relevant_utils.py +10 -0
- autocoder/utils/stream_thinking.py +193 -0
- autocoder/version.py +1 -1
- {auto_coder-0.1.279.dist-info → auto_coder-0.1.281.dist-info}/LICENSE +0 -0
- {auto_coder-0.1.279.dist-info → auto_coder-0.1.281.dist-info}/WHEEL +0 -0
- {auto_coder-0.1.279.dist-info → auto_coder-0.1.281.dist-info}/entry_points.txt +0 -0
- {auto_coder-0.1.279.dist-info → auto_coder-0.1.281.dist-info}/top_level.txt +0 -0
autocoder/rag/doc_filter.py
CHANGED
|
@@ -1,13 +1,15 @@
|
|
|
1
1
|
import time
|
|
2
|
-
from typing import List, Dict, Optional
|
|
2
|
+
from typing import List, Dict, Optional, Generator, Tuple
|
|
3
3
|
from loguru import logger
|
|
4
4
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
5
|
+
from autocoder.rag.lang import get_message_with_format_and_newline
|
|
5
6
|
|
|
6
7
|
from autocoder.rag.relevant_utils import (
|
|
7
8
|
parse_relevance,
|
|
8
9
|
FilterDoc,
|
|
9
10
|
TaskTiming,
|
|
10
|
-
DocFilterResult
|
|
11
|
+
DocFilterResult,
|
|
12
|
+
ProgressUpdate
|
|
11
13
|
)
|
|
12
14
|
|
|
13
15
|
from autocoder.common import SourceCode, AutoCoderArgs
|
|
@@ -49,6 +51,7 @@ def _check_relevance_with_conversation(
|
|
|
49
51
|
其中, <relevant> 是你认为文档中和问题的相关度,0-10之间的数字,数字越大表示相关度越高。
|
|
50
52
|
"""
|
|
51
53
|
|
|
54
|
+
|
|
52
55
|
class DocFilter:
|
|
53
56
|
def __init__(
|
|
54
57
|
self,
|
|
@@ -73,10 +76,10 @@ class DocFilter:
|
|
|
73
76
|
) -> DocFilterResult:
|
|
74
77
|
return self.filter_docs_with_threads(conversations, documents)
|
|
75
78
|
|
|
76
|
-
def
|
|
79
|
+
def filter_docs_with_progress(
|
|
77
80
|
self, conversations: List[Dict[str, str]], documents: List[SourceCode]
|
|
78
|
-
) -> DocFilterResult:
|
|
79
|
-
|
|
81
|
+
) -> Generator[Tuple[ProgressUpdate, Optional[DocFilterResult]], None, DocFilterResult]:
|
|
82
|
+
"""使用线程过滤文档,同时产生进度更新"""
|
|
80
83
|
start_time = time.time()
|
|
81
84
|
logger.info(f"=== DocFilter Starting ===")
|
|
82
85
|
logger.info(
|
|
@@ -93,6 +96,16 @@ class DocFilter:
|
|
|
93
96
|
relevant_count = 0
|
|
94
97
|
model_name = self.recall_llm.default_model_name or "unknown"
|
|
95
98
|
|
|
99
|
+
doc_filter_result = DocFilterResult(
|
|
100
|
+
docs=[],
|
|
101
|
+
raw_docs=[],
|
|
102
|
+
input_tokens_counts=[],
|
|
103
|
+
generated_tokens_counts=[],
|
|
104
|
+
durations=[],
|
|
105
|
+
model_name=model_name
|
|
106
|
+
)
|
|
107
|
+
relevant_docs = doc_filter_result.docs
|
|
108
|
+
|
|
96
109
|
with ThreadPoolExecutor(
|
|
97
110
|
max_workers=self.args.index_filter_workers or 5
|
|
98
111
|
) as executor:
|
|
@@ -141,16 +154,19 @@ class DocFilter:
|
|
|
141
154
|
logger.info(
|
|
142
155
|
f"Submitted {submitted_tasks} document filtering tasks to thread pool")
|
|
143
156
|
|
|
157
|
+
# 发送初始进度更新
|
|
158
|
+
yield (ProgressUpdate(
|
|
159
|
+
phase="doc_filter",
|
|
160
|
+
completed=0,
|
|
161
|
+
total=len(documents),
|
|
162
|
+
relevant_count=0,
|
|
163
|
+
message=get_message_with_format_and_newline(
|
|
164
|
+
"doc_filter_start",
|
|
165
|
+
total=len(documents)
|
|
166
|
+
)
|
|
167
|
+
), None)
|
|
168
|
+
|
|
144
169
|
# 处理完成的任务
|
|
145
|
-
doc_filter_result = DocFilterResult(
|
|
146
|
-
docs=[],
|
|
147
|
-
raw_docs=[],
|
|
148
|
-
input_tokens_counts=[],
|
|
149
|
-
generated_tokens_counts=[],
|
|
150
|
-
durations=[],
|
|
151
|
-
model_name=model_name
|
|
152
|
-
)
|
|
153
|
-
relevant_docs = doc_filter_result.docs
|
|
154
170
|
for future in as_completed(list(future_to_doc.keys())):
|
|
155
171
|
try:
|
|
156
172
|
doc, submit_time = future_to_doc[future]
|
|
@@ -194,32 +210,50 @@ class DocFilter:
|
|
|
194
210
|
f"\n - Timing: Duration={task_timing.duration:.2f}s, Processing={task_timing.real_duration:.2f}s, Queue={queue_time:.2f}s"
|
|
195
211
|
f"\n - Response: {v}"
|
|
196
212
|
)
|
|
197
|
-
|
|
213
|
+
|
|
198
214
|
if "rag" not in doc.metadata:
|
|
199
215
|
doc.metadata["rag"] = {}
|
|
200
216
|
doc.metadata["rag"]["recall"] = {
|
|
201
217
|
"input_tokens_count": input_tokens_count,
|
|
202
218
|
"generated_tokens_count": generated_tokens_count,
|
|
203
219
|
"recall_model": model_name,
|
|
204
|
-
"duration": task_timing.real_duration
|
|
220
|
+
"duration": task_timing.real_duration
|
|
205
221
|
}
|
|
206
|
-
|
|
207
|
-
doc_filter_result.input_tokens_counts.append(
|
|
208
|
-
|
|
209
|
-
doc_filter_result.
|
|
210
|
-
|
|
222
|
+
|
|
223
|
+
doc_filter_result.input_tokens_counts.append(
|
|
224
|
+
input_tokens_count)
|
|
225
|
+
doc_filter_result.generated_tokens_counts.append(
|
|
226
|
+
generated_tokens_count)
|
|
227
|
+
doc_filter_result.durations.append(
|
|
228
|
+
task_timing.real_duration)
|
|
229
|
+
|
|
211
230
|
new_filter_doc = FilterDoc(
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
231
|
+
source_code=doc,
|
|
232
|
+
relevance=relevance,
|
|
233
|
+
task_timing=task_timing,
|
|
234
|
+
)
|
|
235
|
+
|
|
217
236
|
doc_filter_result.raw_docs.append(new_filter_doc)
|
|
218
237
|
|
|
219
238
|
if is_relevant:
|
|
220
239
|
relevant_docs.append(
|
|
221
240
|
new_filter_doc
|
|
222
241
|
)
|
|
242
|
+
|
|
243
|
+
# 产生进度更新
|
|
244
|
+
yield (ProgressUpdate(
|
|
245
|
+
phase="doc_filter",
|
|
246
|
+
completed=completed_tasks,
|
|
247
|
+
total=len(documents),
|
|
248
|
+
relevant_count=relevant_count,
|
|
249
|
+
message=get_message_with_format_and_newline(
|
|
250
|
+
"doc_filter_progress",
|
|
251
|
+
progress_percent=progress_percent,
|
|
252
|
+
relevant_count=relevant_count,
|
|
253
|
+
total=len(documents)
|
|
254
|
+
)
|
|
255
|
+
), None)
|
|
256
|
+
|
|
223
257
|
except Exception as exc:
|
|
224
258
|
try:
|
|
225
259
|
doc, submit_time = future_to_doc[future]
|
|
@@ -236,7 +270,7 @@ class DocFilter:
|
|
|
236
270
|
FilterDoc(
|
|
237
271
|
source_code=doc,
|
|
238
272
|
relevance=None,
|
|
239
|
-
task_timing=TaskTiming(),
|
|
273
|
+
task_timing=TaskTiming(),
|
|
240
274
|
)
|
|
241
275
|
)
|
|
242
276
|
except Exception as e:
|
|
@@ -244,6 +278,18 @@ class DocFilter:
|
|
|
244
278
|
f"Document filtering error in task tracking: {exc}"
|
|
245
279
|
)
|
|
246
280
|
|
|
281
|
+
# 报告错误进度
|
|
282
|
+
yield (ProgressUpdate(
|
|
283
|
+
phase="doc_filter",
|
|
284
|
+
completed=completed_tasks,
|
|
285
|
+
total=len(documents),
|
|
286
|
+
relevant_count=relevant_count,
|
|
287
|
+
message=get_message_with_format_and_newline(
|
|
288
|
+
"doc_filter_error",
|
|
289
|
+
error=str(exc)
|
|
290
|
+
)
|
|
291
|
+
), None)
|
|
292
|
+
|
|
247
293
|
# Sort relevant_docs by relevance score in descending order
|
|
248
294
|
relevant_docs.sort(
|
|
249
295
|
key=lambda x: x.relevance.relevant_score, reverse=True)
|
|
@@ -254,7 +300,7 @@ class DocFilter:
|
|
|
254
300
|
doc.task_timing.real_duration for doc in relevant_docs) / len(relevant_docs) if relevant_docs else 0
|
|
255
301
|
avg_queue_time = sum(doc.task_timing.real_start_time -
|
|
256
302
|
doc.task_timing.submit_time for doc in relevant_docs) / len(relevant_docs) if relevant_docs else 0
|
|
257
|
-
|
|
303
|
+
|
|
258
304
|
total_input_tokens = sum(doc_filter_result.input_tokens_counts)
|
|
259
305
|
total_generated_tokens = sum(doc_filter_result.generated_tokens_counts)
|
|
260
306
|
|
|
@@ -278,4 +324,33 @@ class DocFilter:
|
|
|
278
324
|
else:
|
|
279
325
|
logger.warning("No relevant documents found!")
|
|
280
326
|
|
|
281
|
-
|
|
327
|
+
# 返回最终结果
|
|
328
|
+
yield (ProgressUpdate(
|
|
329
|
+
phase="doc_filter",
|
|
330
|
+
completed=len(documents),
|
|
331
|
+
total=len(documents),
|
|
332
|
+
relevant_count=relevant_count,
|
|
333
|
+
message=get_message_with_format_and_newline(
|
|
334
|
+
"doc_filter_complete",
|
|
335
|
+
total_time=total_time,
|
|
336
|
+
relevant_count=relevant_count
|
|
337
|
+
)
|
|
338
|
+
), doc_filter_result)
|
|
339
|
+
|
|
340
|
+
def filter_docs_with_threads(
    self, conversations: List[Dict[str, str]], documents: List[SourceCode]
) -> DocFilterResult:
    """Filter documents and return only the final result.

    Compatibility wrapper around ``filter_docs_with_progress``: the
    progress updates it yields are discarded and the first non-``None``
    result payload is returned.

    Args:
        conversations: Conversation messages forwarded unchanged to
            ``filter_docs_with_progress``.
        documents: Candidate documents forwarded unchanged to
            ``filter_docs_with_progress``.

    Returns:
        The first non-``None`` result yielded by the progress generator,
        or an empty ``DocFilterResult`` if the generator finishes without
        yielding one.
    """
    updates = self.filter_docs_with_progress(conversations, documents)
    final_result = next(
        (payload for _progress, payload in updates if payload is not None),
        None,
    )
    if final_result is not None:
        return final_result

    # Emergency fallback: the generator is expected to always yield a final
    # result, so this point should not be reached.
    fallback_model = self.recall_llm.default_model_name or "unknown"
    return DocFilterResult(
        docs=[],
        raw_docs=[],
        input_tokens_counts=[],
        generated_tokens_counts=[],
        durations=[],
        model_name=fallback_model
    )
|
autocoder/rag/lang.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
import locale
|
|
2
|
+
from byzerllm.utils import format_str_jinja2
|
|
3
|
+
|
|
4
|
+
# Localized message templates, keyed by two-letter language code and then by
# message key. Placeholders use the ``{{name}}`` form and are filled in by
# ``format_str_jinja2`` in this module's formatting helpers.
#
# NOTE: the ``sencond_round_time`` placeholder is misspelled, but it is kept
# as-is because callers presumably pass the keyword under that exact name;
# renaming it here alone would leave the placeholder unfilled.
#
# NOTE: in ``doc_filter_progress`` the ``relevant_count`` placeholder carries
# the number of relevant documents found so far (not the number processed),
# so the wording describes it as such.
MESSAGES = {
    "en": {
        "rag_error_title": "RAG Error",
        "rag_error_message": "Failed to generate response: {{error}}",
        "rag_searching_docs": "Searching documents with {{model}}...",
        "rag_docs_filter_result": "{{model}} processed {{docs_num}} documents, cost {{filter_time}} seconds, input tokens: {{input_tokens}}, output tokens: {{output_tokens}}",
        "dynamic_chunking_start": "Dynamic chunking start with {{model}}",
        "dynamic_chunking_result": "Dynamic chunking result with {{model}}, first round cost {{first_round_time}} seconds, second round cost {{sencond_round_time}} seconds, input tokens: {{input_tokens}}, output tokens: {{output_tokens}}, first round full docs: {{first_round_full_docs}}, second round extracted docs: {{second_round_extracted_docs}}",
        "send_to_model": "Send to model {{model}} with {{tokens}} tokens",
        "doc_filter_start": "Document filtering start, total {{total}} documents",
        "doc_filter_progress": "Document filtering progress: {{progress_percent}}%, found {{relevant_count}} relevant of {{total}} documents",
        "doc_filter_error": "Document filtering error: {{error}}",
        "doc_filter_complete": "Document filtering complete, cost {{total_time}} seconds, found {{relevant_count}} relevant documents"
    },
    "zh": {
        "rag_error_title": "RAG 错误",
        "rag_error_message": "生成响应失败: {{error}}",
        "rag_searching_docs": "正在使用 {{model}} 搜索文档...",
        "rag_docs_filter_result": "{{model}} 处理了 {{docs_num}} 个文档, 耗时 {{filter_time}} 秒, 输入 tokens: {{input_tokens}}, 输出 tokens: {{output_tokens}}",
        "dynamic_chunking_start": "使用 {{model}} 进行动态分块",
        "dynamic_chunking_result": "使用 {{model}} 进行动态分块, 第一轮耗时 {{first_round_time}} 秒, 第二轮耗时 {{sencond_round_time}} 秒, 输入 tokens: {{input_tokens}}, 输出 tokens: {{output_tokens}}, 第一轮全量文档: {{first_round_full_docs}}, 第二轮提取文档: {{second_round_extracted_docs}}",
        "send_to_model": "发送给模型 {{model}} 的 tokens 数量预估为 {{tokens}}",
        "doc_filter_start": "开始过滤文档,共 {{total}} 个文档",
        "doc_filter_progress": "文档过滤进度:{{progress_percent}}%,已找到 {{relevant_count}} 个相关文档,共 {{total}} 个文档",
        "doc_filter_error": "文档过滤错误:{{error}}",
        "doc_filter_complete": "文档过滤完成,耗时 {{total_time}} 秒,找到 {{relevant_count}} 个相关文档"
    }
}
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def get_system_language():
    """Return the two-letter language code of the system default locale.

    Returns:
        The first two characters of the locale name reported by
        ``locale.getdefaultlocale()`` (for example ``"zh"`` for ``zh_CN``),
        or ``'en'`` when no locale name is available or the lookup fails.
    """
    try:
        locale_name = locale.getdefaultlocale()[0]
    except Exception:
        # Best-effort lookup: any failure falls back to English. Catching
        # ``Exception`` rather than using a bare ``except`` lets
        # KeyboardInterrupt and SystemExit propagate.
        return 'en'
    if not locale_name:
        # The locale name is None when no locale can be determined
        # (for example the "C" locale).
        return 'en'
    return locale_name[:2]
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def get_message(key):
    """Return the message template for ``key`` in the system language.

    The table for the language reported by ``get_system_language()`` is
    used when present, otherwise the English table. A key missing from the
    chosen table falls back to the English entry.

    Raises:
        KeyError: If ``key`` is not present in the English table (the
            English entry is always looked up as the fallback value).
    """
    language = get_system_language()
    english_table = MESSAGES['en']
    table = MESSAGES.get(language, english_table)
    english_text = english_table[key]
    return table.get(key, english_text)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def get_message_with_format(msg_key: str, **kwargs):
    """Look up the template for ``msg_key`` and render it with ``kwargs``.

    The template comes from ``get_message`` and is rendered by
    ``format_str_jinja2``; the keyword arguments are passed through to it
    unchanged.
    """
    template = get_message(msg_key)
    return format_str_jinja2(template, **kwargs)
|
|
48
|
+
|
|
49
|
+
def get_message_with_format_and_newline(msg_key: str, **kwargs):
    """Render the template for ``msg_key`` and append a newline.

    Same lookup and rendering as ``get_message_with_format``, with a single
    ``"\\n"`` appended to the rendered text.
    """
    template = get_message(msg_key)
    rendered = format_str_jinja2(template, **kwargs)
    return rendered + "\n"
|