auto-coder 0.1.270-py3-none-any.whl → 0.1.272-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of auto-coder might be problematic.
- {auto_coder-0.1.270.dist-info → auto_coder-0.1.272.dist-info}/METADATA +2 -2
- {auto_coder-0.1.270.dist-info → auto_coder-0.1.272.dist-info}/RECORD +22 -21
- autocoder/auto_coder_runner.py +4 -4
- autocoder/commands/auto_command.py +33 -5
- autocoder/commands/tools.py +28 -15
- autocoder/common/auto_coder_lang.py +7 -3
- autocoder/common/auto_configure.py +1 -1
- autocoder/common/command_generator.py +3 -1
- autocoder/common/files.py +44 -10
- autocoder/common/shells.py +68 -0
- autocoder/index/filter/quick_filter.py +4 -3
- autocoder/rag/doc_filter.py +165 -59
- autocoder/rag/llm_wrapper.py +3 -1
- autocoder/rag/long_context_rag.py +196 -51
- autocoder/rag/relevant_utils.py +12 -1
- autocoder/rag/token_limiter.py +159 -18
- autocoder/rag/token_limiter_utils.py +13 -0
- autocoder/version.py +1 -1
- {auto_coder-0.1.270.dist-info → auto_coder-0.1.272.dist-info}/LICENSE +0 -0
- {auto_coder-0.1.270.dist-info → auto_coder-0.1.272.dist-info}/WHEEL +0 -0
- {auto_coder-0.1.270.dist-info → auto_coder-0.1.272.dist-info}/entry_points.txt +0 -0
- {auto_coder-0.1.270.dist-info → auto_coder-0.1.272.dist-info}/top_level.txt +0 -0
autocoder/rag/doc_filter.py
CHANGED
@@ -5,8 +5,9 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
 
 from autocoder.rag.relevant_utils import (
     parse_relevance,
-    FilterDoc,
+    FilterDoc,
     TaskTiming,
+    DocFilterResult
 )
 
 from autocoder.common import SourceCode, AutoCoderArgs
@@ -48,7 +49,6 @@ def _check_relevance_with_conversation(
     其中, <relevant> 是你认为文档中和问题的相关度,0-10之间的数字,数字越大表示相关度越高。
     """
 
-
 class DocFilter:
     def __init__(
         self,
@@ -62,40 +62,57 @@ class DocFilter:
             self.recall_llm = self.llm.get_sub_client("recall_model")
         else:
             self.recall_llm = self.llm
-
+
         self.args = args
         self.relevant_score = self.args.rag_doc_filter_relevance
         self.on_ray = on_ray
-        self.path = path
+        self.path = path
 
     def filter_docs(
         self, conversations: List[Dict[str, str]], documents: List[SourceCode]
-    ) ->
-        return self.filter_docs_with_threads(conversations, documents)
+    ) -> DocFilterResult:
+        return self.filter_docs_with_threads(conversations, documents)
 
     def filter_docs_with_threads(
         self, conversations: List[Dict[str, str]], documents: List[SourceCode]
-    ) ->
-
+    ) -> DocFilterResult:
+
+        start_time = time.time()
+        logger.info(f"=== DocFilter Starting ===")
+        logger.info(
+            f"Configuration: relevance_threshold={self.relevant_score}, thread_workers={self.args.index_filter_workers or 5}")
+
         rag_manager = RagConfigManager(path=self.path)
         rag_config = rag_manager.load_config()
-
-
+
+        documents = list(documents)
+        logger.info(f"Filtering {len(documents)} documents...")
+
+        submitted_tasks = 0
+        completed_tasks = 0
+        relevant_count = 0
+        model_name = self.recall_llm.default_model_name or "unknown"
+
         with ThreadPoolExecutor(
             max_workers=self.args.index_filter_workers or 5
         ) as executor:
             future_to_doc = {}
+
+            # 提交所有任务
             for doc in documents:
                 submit_time = time.time()
+                submitted_tasks += 1
 
                 def _run(conversations, docs):
                     submit_time_1 = time.time()
+                    meta = None
                     try:
                         llm = self.recall_llm
+                        meta_holder = byzerllm.MetaHolder()
 
                         v = (
                             _check_relevance_with_conversation.with_llm(
-                                llm)
+                                llm).with_meta(meta_holder)
                             .options({"llm_config": {"max_length": 10}})
                             .run(
                                 conversations=conversations,
@@ -103,14 +120,16 @@ class DocFilter:
                                 filter_config=rag_config.filter_config,
                             )
                         )
+
+                        meta = meta_holder.get_meta_model()
                     except Exception as e:
                         logger.error(
                             f"Error in _check_relevance_with_conversation: {str(e)}"
                         )
-                        return (None, submit_time_1, time.time())
+                        return (None, submit_time_1, time.time(), meta)
 
                     end_time_2 = time.time()
-                    return (v, submit_time_1, end_time_2)
+                    return (v, submit_time_1, end_time_2, meta)
 
                 m = executor.submit(
                     _run,
@@ -119,57 +138,144 @@ class DocFilter:
                 )
                 future_to_doc[m] = (doc, submit_time)
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                    and relevance.relevant_score >= self.relevant_score
-
-
-
+            logger.info(
+                f"Submitted {submitted_tasks} document filtering tasks to thread pool")
+
+            # 处理完成的任务
+            doc_filter_result = DocFilterResult(
+                docs=[],
+                raw_docs=[],
+                input_tokens_counts=[],
+                generated_tokens_counts=[],
+                durations=[],
+                model_name=model_name
+            )
+            relevant_docs = doc_filter_result.docs
+            for future in as_completed(list(future_to_doc.keys())):
+                try:
+                    doc, submit_time = future_to_doc[future]
+                    end_time = time.time()
+                    completed_tasks += 1
+                    progress_percent = (completed_tasks / len(documents)) * 100
+
+                    v, submit_time_1, end_time_2, meta = future.result()
+                    task_timing = TaskTiming(
+                        submit_time=submit_time,
+                        end_time=end_time,
+                        duration=end_time - submit_time,
+                        real_start_time=submit_time_1,
+                        real_end_time=end_time_2,
+                        real_duration=end_time_2 - submit_time_1,
+                    )
+
+                    relevance = parse_relevance(v)
+                    is_relevant = relevance and relevance.relevant_score >= self.relevant_score
+
+                    if is_relevant:
+                        relevant_count += 1
+                        status_text = f"RELEVANT (Score: {relevance.relevant_score:.1f})"
+                    else:
+                        score_text = f"{relevance.relevant_score:.1f}" if relevance else "N/A"
+                        status_text = f"NOT RELEVANT (Score: {score_text})"
+
+                    queue_time = task_timing.real_start_time - task_timing.submit_time
+
+                    input_tokens_count = meta.input_tokens_count if meta else 0
+                    generated_tokens_count = meta.generated_tokens_count if meta else 0
+
+                    logger.info(
+                        f"Document filtering [{progress_percent:.1f}%] - {completed_tasks}/{len(documents)}:"
+                        f"\n - File: {doc.module_name}"
+                        f"\n - Status: {status_text}"
+                        f"\n - Model: {model_name}"
+                        f"\n - Threshold: {self.relevant_score}"
+                        f"\n - Input tokens: {input_tokens_count}"
+                        f"\n - Generated tokens: {generated_tokens_count}"
+                        f"\n - Timing: Duration={task_timing.duration:.2f}s, Processing={task_timing.real_duration:.2f}s, Queue={queue_time:.2f}s"
+                        f"\n - Response: {v}"
+                    )
+
+                    if "rag" not in doc.metadata:
+                        doc.metadata["rag"] = {}
+                    doc.metadata["rag"]["recall"] = {
+                        "input_tokens_count": input_tokens_count,
+                        "generated_tokens_count": generated_tokens_count,
+                        "recall_model": model_name,
+                        "duration": task_timing.real_duration
+                    }
+
+                    doc_filter_result.input_tokens_counts.append(input_tokens_count)
+                    doc_filter_result.generated_tokens_counts.append(generated_tokens_count)
+                    doc_filter_result.durations.append(task_timing.real_duration)
+
+                    new_filter_doc = FilterDoc(
                         source_code=doc,
                         relevance=relevance,
                         task_timing=task_timing,
                     )
-
-
-
-
-
-
-
-
-
+
+                    doc_filter_result.raw_docs.append(new_filter_doc)
+
+                    if is_relevant:
+                        relevant_docs.append(
+                            new_filter_doc
+                        )
+                except Exception as exc:
+                    try:
+                        doc, submit_time = future_to_doc[future]
+                        completed_tasks += 1
+                        progress_percent = (
+                            completed_tasks / len(documents)) * 100
+                        logger.error(
+                            f"Document filtering [{progress_percent:.1f}%] - {completed_tasks}/{len(documents)}:"
+                            f"\n - File: {doc.module_name}"
+                            f"\n - Error: {exc}"
+                            f"\n - Duration: {time.time() - submit_time:.2f}s"
+                        )
+                        doc_filter_result.raw_docs.append(
+                            FilterDoc(
+                                source_code=doc,
+                                relevance=None,
+                                task_timing=TaskTiming(),
+                            )
+                        )
+                    except Exception as e:
+                        logger.error(
+                            f"Document filtering error in task tracking: {exc}"
+                        )
 
         # Sort relevant_docs by relevance score in descending order
         relevant_docs.sort(
             key=lambda x: x.relevance.relevant_score, reverse=True)
-
-
+
+        total_time = time.time() - start_time
+
+        avg_processing_time = sum(
+            doc.task_timing.real_duration for doc in relevant_docs) / len(relevant_docs) if relevant_docs else 0
+        avg_queue_time = sum(doc.task_timing.real_start_time -
+                             doc.task_timing.submit_time for doc in relevant_docs) / len(relevant_docs) if relevant_docs else 0
+
+        total_input_tokens = sum(doc_filter_result.input_tokens_counts)
+        total_generated_tokens = sum(doc_filter_result.generated_tokens_counts)
+
+        logger.info(
+            f"=== DocFilter Complete ==="
+            f"\n * Total time: {total_time:.2f}s"
+            f"\n * Documents processed: {completed_tasks}/{len(documents)}"
+            f"\n * Relevant documents: {relevant_count} (threshold: {self.relevant_score})"
+            f"\n * Average processing time: {avg_processing_time:.2f}s"
+            f"\n * Average queue time: {avg_queue_time:.2f}s"
+            f"\n * Total input tokens: {total_input_tokens}"
+            f"\n * Total generated tokens: {total_generated_tokens}"
+        )
+
+        if relevant_docs:
+            logger.info(
+                f"Top 5 relevant documents:"
+                + "".join([f"\n * {doc.source_code.module_name} (Score: {doc.relevance.relevant_score:.1f})"
+                           for doc in relevant_docs[:5]])
+            )
+        else:
+            logger.warning("No relevant documents found!")
+
+        return doc_filter_result
autocoder/rag/llm_wrapper.py
CHANGED
@@ -44,13 +44,15 @@ class LLWrapper:
         res,contexts = self.rag.stream_chat_oai(conversations,llm_config=llm_config)
         for t in res:
             yield (t,SingleOutputMeta(0,0))
+
 
     async def async_stream_chat_oai(self,conversations,
                                     model:Optional[str]=None,
                                     role_mapping=None,
                                     delta_mode=False,
                                     llm_config:Dict[str,Any]={}):
-        res,contexts = await asyncfy_with_semaphore(lambda: self.rag.stream_chat_oai(conversations,llm_config=llm_config))()
+        res,contexts = await asyncfy_with_semaphore(lambda: self.rag.stream_chat_oai(conversations,llm_config=llm_config))()
+        # res,contexts = await self.llm.async_stream_chat_oai(conversations,llm_config=llm_config)
         for t in res:
             yield (t,SingleOutputMeta(0,0))
 
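The only functional change here is the commented-out alternative left under the `asyncfy_with_semaphore` call. That byzerllm helper is what lets the blocking `self.rag.stream_chat_oai` call be awaited from the async method without stalling the event loop. A rough, self-contained sketch of the same pattern, using a hypothetical `asyncfy` helper rather than byzerllm's actual implementation:

```python
# Illustrative stand-in only; byzerllm's asyncfy_with_semaphore may differ in
# signature and behavior.
import asyncio
from typing import Any, Callable


def asyncfy(fn: Callable[..., Any], max_concurrency: int = 10) -> Callable[..., Any]:
    """Return an async callable that runs the blocking `fn` in a worker thread,
    allowing at most `max_concurrency` calls in flight at once."""
    semaphore = asyncio.Semaphore(max_concurrency)

    async def wrapper(*args: Any, **kwargs: Any) -> Any:
        async with semaphore:
            return await asyncio.to_thread(fn, *args, **kwargs)

    return wrapper


# Usage mirroring the diff: wrap a zero-argument lambda, then call and await it.
# res, contexts = await asyncfy(lambda: rag.stream_chat_oai(conversations))()
```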
autocoder/rag/long_context_rag.py
CHANGED
@@ -31,6 +31,8 @@ from tokenizers import Tokenizer
 from autocoder.rag.variable_holder import VariableHolder
 from importlib.metadata import version
 from autocoder.rag.stream_event import event_writer
+from autocoder.rag.relevant_utils import DocFilterResult
+from pydantic import BaseModel
 
 try:
     from autocoder_pro.rag.llm_compute import LLMComputeEngine
@@ -42,6 +44,24 @@ except ImportError:
     LLMComputeEngine = None
 
 
+class RecallStat(BaseModel):
+    total_input_tokens: int
+    total_generated_tokens: int
+    model_name: str = "unknown"
+class ChunkStat(BaseModel):
+    total_input_tokens: int
+    total_generated_tokens: int
+    model_name: str = "unknown"
+class AnswerStat(BaseModel):
+    total_input_tokens: int
+    total_generated_tokens: int
+    model_name: str = "unknown"
+
+class RAGStat(BaseModel):
+    recall_stat: RecallStat
+    chunk_stat: ChunkStat
+    answer_stat: AnswerStat
+
 class LongContextRAG:
     def __init__(
         self,
@@ -305,7 +325,7 @@ class LongContextRAG:
         url = ",".join(contexts)
         return [SourceCode(module_name=f"RAG:{url}", source_code="".join(v))]
 
-    def _filter_docs(self, conversations: List[Dict[str, str]]) ->
+    def _filter_docs(self, conversations: List[Dict[str, str]]) -> DocFilterResult:
         query = conversations[-1]["content"]
         documents = self._retrieve_documents(options={"query":query})
         return self.doc_filter.filter_docs(
@@ -439,7 +459,32 @@ class LongContextRAG:
 
         logger.info(f"Query: {query} only_contexts: {only_contexts}")
         start_time = time.time()
-
+
+        rag_stat = RAGStat(
+            recall_stat=RecallStat(
+                total_input_tokens=0,
+                total_generated_tokens=0,
+                model_name=self.llm.default_model_name,
+            ),
+            chunk_stat=ChunkStat(
+                total_input_tokens=0,
+                total_generated_tokens=0,
+                model_name=self.llm.default_model_name,
+            ),
+            answer_stat=AnswerStat(
+                total_input_tokens=0,
+                total_generated_tokens=0,
+                model_name=self.llm.default_model_name,
+            ),
+        )
+
+        doc_filter_result = self._filter_docs(conversations)
+
+        rag_stat.recall_stat.total_input_tokens += sum(doc_filter_result.input_tokens_counts)
+        rag_stat.recall_stat.total_generated_tokens += sum(doc_filter_result.generated_tokens_counts)
+        rag_stat.recall_stat.model_name = doc_filter_result.model_name
+
+        relevant_docs: List[FilterDoc] = doc_filter_result.docs
         filter_time = time.time() - start_time
 
         # Filter relevant_docs to only include those with is_relevant=True
@@ -469,17 +514,15 @@ class LongContextRAG:
         # 将 FilterDoc 转化为 SourceCode 方便后续的逻辑继续做处理
         relevant_docs = [doc.source_code for doc in relevant_docs]
 
-
+        logger.info(f"=== RAG Search Results ===")
+        logger.info(f"Query: {query}")
+        logger.info(f"Found relevant docs: {len(relevant_docs)}")
 
-        #
-        query_table = Table(title="Query Information", show_header=False)
-        query_table.add_row("Query", query)
-        query_table.add_row("Relevant docs", str(len(relevant_docs)))
-
-        # Add relevant docs information
+        # 记录相关文档信息
         relevant_docs_info = []
-        for doc in relevant_docs:
-
+        for i, doc in enumerate(relevant_docs):
+            doc_path = doc.module_name.replace(self.path, '', 1)
+            info = f"{i+1}. {doc_path}"
             if "original_docs" in doc.metadata:
                 original_docs = ", ".join(
                     [
@@ -490,8 +533,11 @@ class LongContextRAG:
                 info += f" (Original docs: {original_docs})"
             relevant_docs_info.append(info)
 
-
-
+        if relevant_docs_info:
+            logger.info(
+                f"Relevant documents list:"
+                + "".join([f"\n * {info}" for info in relevant_docs_info])
+            )
 
         first_round_full_docs = []
         second_round_extracted_docs = []
@@ -507,11 +553,18 @@ class LongContextRAG:
                 llm=self.llm,
                 disable_segment_reorder=self.args.disable_segment_reorder,
             )
-
+
+            token_limiter_result = token_limiter.limit_tokens(
                 relevant_docs=relevant_docs,
                 conversations=conversations,
                 index_filter_workers=self.args.index_filter_workers or 5,
            )
+
+            rag_stat.chunk_stat.total_input_tokens += sum(token_limiter_result.input_tokens_counts)
+            rag_stat.chunk_stat.total_generated_tokens += sum(token_limiter_result.generated_tokens_counts)
+            rag_stat.chunk_stat.model_name = token_limiter_result.model_name
+
+            final_relevant_docs = token_limiter_result.docs
             first_round_full_docs = token_limiter.first_round_full_docs
             second_round_extracted_docs = token_limiter.second_round_extracted_docs
             sencond_round_time = token_limiter.sencond_round_time
@@ -522,57 +575,64 @@ class LongContextRAG:
 
         logger.info(f"Finally send to model: {len(relevant_docs)}")
 
-
-
-
-
-            "
-
-
-            "
-
-        query_table.add_row(
-            "Second round time", f"{sencond_round_time:.2f} seconds"
+        # 记录分段处理的统计信息
+        logger.info(
+            f"=== Token Management ===\n"
+            f" * Only contexts: {only_contexts}\n"
+            f" * Filter time: {filter_time:.2f} seconds\n"
+            f" * Final relevant docs: {len(relevant_docs)}\n"
+            f" * First round full docs: {len(first_round_full_docs)}\n"
+            f" * Second round extracted docs: {len(second_round_extracted_docs)}\n"
+            f" * Second round time: {sencond_round_time:.2f} seconds"
         )
 
-        #
+        # 记录最终选择的文档详情
         final_relevant_docs_info = []
-        for doc in relevant_docs:
-
+        for i, doc in enumerate(relevant_docs):
+            doc_path = doc.module_name.replace(self.path, '', 1)
+            info = f"{i+1}. {doc_path}"
+
+            metadata_info = []
             if "original_docs" in doc.metadata:
                 original_docs = ", ".join(
                     [
-
-                        for
+                        od.replace(self.path, "", 1)
+                        for od in doc.metadata["original_docs"]
                     ]
                 )
-
+                metadata_info.append(f"Original docs: {original_docs}")
+
             if "chunk_ranges" in doc.metadata:
                 chunk_ranges = json.dumps(
                     doc.metadata["chunk_ranges"], ensure_ascii=False
                 )
-
+                metadata_info.append(f"Chunk ranges: {chunk_ranges}")
+
+            if "processing_time" in doc.metadata:
+                metadata_info.append(f"Processing time: {doc.metadata['processing_time']:.2f}s")
+
+            if metadata_info:
+                info += f" ({'; '.join(metadata_info)})"
+
             final_relevant_docs_info.append(info)
 
-
-
-
-
-        panel = Panel(
-            query_table,
-            title="RAG Search Results",
-            expand=False,
+        if final_relevant_docs_info:
+            logger.info(
+                f"Final documents to be sent to model:"
+                + "".join([f"\n * {info}" for info in final_relevant_docs_info])
         )
 
-        #
-        console.print(panel)
-
+        # 记录令牌统计
         request_tokens = sum([doc.tokens for doc in relevant_docs])
         target_model = model or self.llm.default_model_name
         logger.info(
-            f"
+            f"=== LLM Request ===\n"
+            f" * Target model: {target_model}\n"
+            f" * Total tokens: {request_tokens}"
        )
 
+        logger.info(f"Start to send to model {target_model} with {request_tokens} tokens")
+
         if LLMComputeEngine is not None and not self.args.disable_inference_enhance:
             llm_compute_engine = LLMComputeEngine(
                 llm=target_llm,
@@ -585,17 +645,22 @@ class LongContextRAG:
             new_conversations = llm_compute_engine.process_conversation(
                 conversations, query, [doc.source_code for doc in relevant_docs]
             )
-
-            return (
-                llm_compute_engine.stream_chat_oai(
+            chunks = llm_compute_engine.stream_chat_oai(
                 conversations=new_conversations,
                 model=model,
                 role_mapping=role_mapping,
                 llm_config=llm_config,
                 delta_mode=True,
-                )
-
-            )
+            )
+
+            def generate_chunks():
+                for chunk in chunks:
+                    yield chunk[0]
+                    if chunk[1] is not None:
+                        rag_stat.answer_stat.total_input_tokens += chunk[1].input_tokens_count
+                        rag_stat.answer_stat.total_generated_tokens += chunk[1].generated_tokens_count
+                self._print_rag_stats(rag_stat)
+            return generate_chunks(), context
 
         new_conversations = conversations[:-1] + [
             {
@@ -614,5 +679,85 @@ class LongContextRAG:
             llm_config=llm_config,
             delta_mode=True,
         )
+
+        def generate_chunks():
+            for chunk in chunks:
+                yield chunk[0]
+                if chunk[1] is not None:
+                    rag_stat.answer_stat.total_input_tokens += chunk[1].input_tokens_count
+                    rag_stat.answer_stat.total_generated_tokens += chunk[1].generated_tokens_count
+            self._print_rag_stats(rag_stat)
+        return generate_chunks(), context
+
+
 
-
+    def _print_rag_stats(self, rag_stat: RAGStat) -> None:
+        """打印RAG执行的详细统计信息"""
+        total_input_tokens = (
+            rag_stat.recall_stat.total_input_tokens +
+            rag_stat.chunk_stat.total_input_tokens +
+            rag_stat.answer_stat.total_input_tokens
+        )
+        total_generated_tokens = (
+            rag_stat.recall_stat.total_generated_tokens +
+            rag_stat.chunk_stat.total_generated_tokens +
+            rag_stat.answer_stat.total_generated_tokens
+        )
+        total_tokens = total_input_tokens + total_generated_tokens
+
+        # 避免除以零错误
+        if total_tokens == 0:
+            recall_percent = chunk_percent = answer_percent = 0
+        else:
+            recall_percent = (rag_stat.recall_stat.total_input_tokens + rag_stat.recall_stat.total_generated_tokens) / total_tokens * 100
+            chunk_percent = (rag_stat.chunk_stat.total_input_tokens + rag_stat.chunk_stat.total_generated_tokens) / total_tokens * 100
+            answer_percent = (rag_stat.answer_stat.total_input_tokens + rag_stat.answer_stat.total_generated_tokens) / total_tokens * 100
+
+        logger.info(
+            f"=== RAG 执行统计信息 ===\n"
+            f"总令牌使用: {total_tokens} 令牌\n"
+            f" * 输入令牌总数: {total_input_tokens}\n"
+            f" * 生成令牌总数: {total_generated_tokens}\n"
+            f"\n"
+            f"阶段统计:\n"
+            f" 1. 文档检索阶段:\n"
+            f"    - 模型: {rag_stat.recall_stat.model_name}\n"
+            f"    - 输入令牌: {rag_stat.recall_stat.total_input_tokens}\n"
+            f"    - 生成令牌: {rag_stat.recall_stat.total_generated_tokens}\n"
+            f"    - 阶段总计: {rag_stat.recall_stat.total_input_tokens + rag_stat.recall_stat.total_generated_tokens}\n"
+            f"\n"
+            f" 2. 文档分块阶段:\n"
+            f"    - 模型: {rag_stat.chunk_stat.model_name}\n"
+            f"    - 输入令牌: {rag_stat.chunk_stat.total_input_tokens}\n"
+            f"    - 生成令牌: {rag_stat.chunk_stat.total_generated_tokens}\n"
+            f"    - 阶段总计: {rag_stat.chunk_stat.total_input_tokens + rag_stat.chunk_stat.total_generated_tokens}\n"
+            f"\n"
+            f" 3. 答案生成阶段:\n"
+            f"    - 模型: {rag_stat.answer_stat.model_name}\n"
+            f"    - 输入令牌: {rag_stat.answer_stat.total_input_tokens}\n"
+            f"    - 生成令牌: {rag_stat.answer_stat.total_generated_tokens}\n"
+            f"    - 阶段总计: {rag_stat.answer_stat.total_input_tokens + rag_stat.answer_stat.total_generated_tokens}\n"
+            f"\n"
+            f"令牌分布百分比:\n"
+            f" - 文档检索: {recall_percent:.1f}%\n"
+            f" - 文档分块: {chunk_percent:.1f}%\n"
+            f" - 答案生成: {answer_percent:.1f}%\n"
+        )
+
+        # 记录原始统计数据,以便调试
+        logger.debug(f"RAG Stat 原始数据: {rag_stat}")
+
+        # 返回成本估算
+        estimated_cost = self._estimate_token_cost(total_input_tokens, total_generated_tokens)
+        if estimated_cost > 0:
+            logger.info(f"估计成本: 约 ${estimated_cost:.4f} 人民币")
+
+    def _estimate_token_cost(self, input_tokens: int, output_tokens: int) -> float:
+        """估算当前请求的令牌成本(人民币)"""
+        # 实际应用中,可以根据不同模型设置不同价格
+        input_cost_per_1m = 2.0/1000000  # 每百万输入令牌的成本
+        output_cost_per_1m = 8.0/100000  # 每百万输出令牌的成本
+
+        cost = (input_tokens * input_cost_per_1m / 1000000) + (output_tokens* output_cost_per_1m/1000000)
+        return cost
+
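Note that in both code paths above the answer-stage token counts are accumulated inside `generate_chunks()`, and `_print_rag_stats` runs only after the final chunk has been yielded, so the statistics appear once the caller fully drains the returned generator. A hypothetical caller-side sketch (names assumed, not part of the package) of consuming the `(generator, contexts)` pair returned by `stream_chat_oai`:

```python
# Hypothetical consumption sketch; `rag` stands for a configured LongContextRAG
# instance and `conversations` for an OpenAI-style message list.
def collect_answer(rag, conversations):
    chunks, contexts = rag.stream_chat_oai(conversations)
    parts = []
    for piece in chunks:
        # Each piece is a text delta; draining the generator also feeds the
        # answer-stage token stats and triggers the final stats log.
        parts.append(piece)
    return "".join(parts), contexts
```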