auto-coder 0.1.279__py3-none-any.whl → 0.1.281__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of auto-coder might be problematic.
- {auto_coder-0.1.279.dist-info → auto_coder-0.1.281.dist-info}/METADATA +1 -1
- {auto_coder-0.1.279.dist-info → auto_coder-0.1.281.dist-info}/RECORD +15 -13
- autocoder/auto_coder.py +2 -1
- autocoder/common/context_pruner.py +168 -206
- autocoder/index/entry.py +1 -1
- autocoder/rag/doc_filter.py +104 -29
- autocoder/rag/lang.py +50 -0
- autocoder/rag/long_context_rag.py +218 -102
- autocoder/rag/relevant_utils.py +10 -0
- autocoder/utils/stream_thinking.py +193 -0
- autocoder/version.py +1 -1
- {auto_coder-0.1.279.dist-info → auto_coder-0.1.281.dist-info}/LICENSE +0 -0
- {auto_coder-0.1.279.dist-info → auto_coder-0.1.281.dist-info}/WHEEL +0 -0
- {auto_coder-0.1.279.dist-info → auto_coder-0.1.281.dist-info}/entry_points.txt +0 -0
- {auto_coder-0.1.279.dist-info → auto_coder-0.1.281.dist-info}/top_level.txt +0 -0
--- a/autocoder/rag/long_context_rag.py
+++ b/autocoder/rag/long_context_rag.py
@@ -23,6 +23,8 @@ from autocoder.rag.relevant_utils import (
     FilterDoc,
     TaskTiming,
     parse_relevance,
+    ProgressUpdate,
+    DocFilterResult
 )
 from autocoder.rag.token_checker import check_token_limit
 from autocoder.rag.token_counter import RemoteTokenCounter, TokenCounter
@@ -34,14 +36,17 @@ from autocoder.rag.stream_event import event_writer
 from autocoder.rag.relevant_utils import DocFilterResult
 from pydantic import BaseModel
 from byzerllm.utils.types import SingleOutputMeta
+from autocoder.rag.lang import get_message_with_format_and_newline

-try:
+try:
     from autocoder_pro.rag.llm_compute import LLMComputeEngine
     pro_version = version("auto-coder-pro")
     autocoder_version = version("auto-coder")
-    logger.warning(
+    logger.warning(
+        f"auto-coder-pro({pro_version}) plugin is enabled in auto-coder.rag({autocoder_version})")
 except ImportError:
-    logger.warning(
+    logger.warning(
+        "Please install auto-coder-pro to enhance llm compute ability")
     LLMComputeEngine = None


@@ -49,20 +54,26 @@ class RecallStat(BaseModel):
     total_input_tokens: int
     total_generated_tokens: int
     model_name: str = "unknown"
+
+
 class ChunkStat(BaseModel):
     total_input_tokens: int
-    total_generated_tokens: int
+    total_generated_tokens: int
     model_name: str = "unknown"
+
+
 class AnswerStat(BaseModel):
     total_input_tokens: int
     total_generated_tokens: int
     model_name: str = "unknown"

+
 class RAGStat(BaseModel):
     recall_stat: RecallStat
     chunk_stat: ChunkStat
     answer_stat: AnswerStat

+
 class LongContextRAG:
     def __init__(
         self,
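Note: the hunk above only inserts blank lines between these Pydantic stat models; their fields are unchanged. As a self-contained sketch of how they compose (the `total_tokens` helper and the sample numbers are illustrative and not part of the package):

```python
# Illustrative only, not part of the diff: how the stat models above compose.
from pydantic import BaseModel


class RecallStat(BaseModel):
    total_input_tokens: int
    total_generated_tokens: int
    model_name: str = "unknown"


class ChunkStat(BaseModel):
    total_input_tokens: int
    total_generated_tokens: int
    model_name: str = "unknown"


class AnswerStat(BaseModel):
    total_input_tokens: int
    total_generated_tokens: int
    model_name: str = "unknown"


class RAGStat(BaseModel):
    recall_stat: RecallStat
    chunk_stat: ChunkStat
    answer_stat: AnswerStat


def total_tokens(stat: RAGStat) -> int:
    # Sum input + generated tokens across the recall, chunk and answer stages,
    # mirroring the totals computed later in _print_rag_stats.
    stages = (stat.recall_stat, stat.chunk_stat, stat.answer_stat)
    return sum(s.total_input_tokens + s.total_generated_tokens for s in stages)


stat = RAGStat(
    recall_stat=RecallStat(total_input_tokens=1200, total_generated_tokens=80),
    chunk_stat=ChunkStat(total_input_tokens=900, total_generated_tokens=40),
    answer_stat=AnswerStat(total_input_tokens=3000, total_generated_tokens=500),
)
print(total_tokens(stat))  # 5720
```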
@@ -86,7 +97,7 @@ class LongContextRAG:
         self.chunk_llm = self.llm.get_sub_client("chunk_model")

         self.args = args
-
+
         self.path = path
         self.relevant_score = self.args.rag_doc_filter_relevance or 5

@@ -99,8 +110,10 @@ class LongContextRAG:
             "The sum of full_text_ratio and segment_ratio must be less than or equal to 1.0"
         )

-        self.full_text_limit = int(
-
+        self.full_text_limit = int(
+            args.rag_context_window_limit * self.full_text_ratio)
+        self.segment_limit = int(
+            args.rag_context_window_limit * self.segment_ratio)
         self.buff_limit = int(args.rag_context_window_limit * self.buff_ratio)

         self.tokenizer = None
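Note: this hunk only rewraps the limit assignments, but it captures the context-window split: `full_text_ratio + segment_ratio` must not exceed 1.0, and each region's limit is the window size times its ratio. A standalone illustration (the ratio values are invented, and deriving `buff_ratio` as the remainder is an assumption made for this example):

```python
# Standalone illustration of the context-window split; ratios are example values,
# and treating buff_ratio as the remainder is an assumption for this sketch.
rag_context_window_limit = 120_000
full_text_ratio = 0.5
segment_ratio = 0.35
buff_ratio = 1.0 - full_text_ratio - segment_ratio

assert full_text_ratio + segment_ratio <= 1.0, \
    "The sum of full_text_ratio and segment_ratio must be less than or equal to 1.0"

full_text_limit = int(rag_context_window_limit * full_text_ratio)  # 60000 tokens
segment_limit = int(rag_context_window_limit * segment_ratio)      # 42000 tokens
buff_limit = int(rag_context_window_limit * buff_ratio)            # 18000 tokens
print(full_text_limit, segment_limit, buff_limit)
```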
@@ -109,7 +122,8 @@ class LongContextRAG:

         if self.tokenizer_path:
             VariableHolder.TOKENIZER_PATH = self.tokenizer_path
-            VariableHolder.TOKENIZER_MODEL = Tokenizer.from_file(
+            VariableHolder.TOKENIZER_MODEL = Tokenizer.from_file(
+                self.tokenizer_path)
             self.tokenizer = TokenCounter(self.tokenizer_path)
         else:
             if llm.is_model_exist("deepseek_tokenizer"):
@@ -161,9 +175,9 @@ class LongContextRAG:
             self.required_exts,
             self.on_ray,
             self.monitor_mode,
-
+            # 确保全文区至少能放下一个文件
             single_file_token_limit=self.full_text_limit - 100,
-            disable_auto_window=self.args.disable_auto_window,
+            disable_auto_window=self.args.disable_auto_window,
             enable_hybrid_index=self.args.enable_hybrid_index,
             extra_params=self.args
         )
@@ -224,14 +238,14 @@ class LongContextRAG:
         {% for msg in conversations %}
         [{{ msg.role }}]:
         {{ msg.content }}
-
+
         {% endfor %}
         </conversations>

         请根据提供的文档内容、用户对话历史以及最后一个问题,提取并总结文档中与问题相关的重要信息。
         如果文档中没有相关信息,请回复"该文档中没有与问题相关的信息"。
         提取的信息尽量保持和原文中的一样,并且只输出这些信息。
-        """
+        """

     @byzerllm.prompt()
     def _answer_question(
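Note: the docstring edited above is a Jinja template used by a `@byzerllm.prompt()` method; the hunk only trims trailing whitespace inside it. A rough sketch of what the conversations block renders to, using plain Jinja2 (in the package byzerllm's decorator performs the rendering; the sample messages are invented):

```python
# Rendering the conversations block with plain Jinja2, for illustration only.
from jinja2 import Template

template = Template(
    "<conversations>\n"
    "{% for msg in conversations %}\n"
    "[{{ msg.role }}]:\n"
    "{{ msg.content }}\n"
    "{% endfor %}\n"
    "</conversations>"
)

conversations = [
    {"role": "user", "content": "这个项目如何配置 tokenizer?"},
    {"role": "assistant", "content": "请查看 README 中的 tokenizer_path 说明。"},
]
print(template.render(conversations=conversations))
```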
@@ -266,26 +280,25 @@ class LongContextRAG:
         """Get the document retriever class based on configuration."""
         # Default to LocalDocumentRetriever if not specified
         return LocalDocumentRetriever
-
+
     def _load_ignore_file(self):
         serveignore_path = os.path.join(self.path, ".serveignore")
         gitignore_path = os.path.join(self.path, ".gitignore")

         if os.path.exists(serveignore_path):
-            with open(serveignore_path, "r",encoding="utf-8") as ignore_file:
+            with open(serveignore_path, "r", encoding="utf-8") as ignore_file:
                 return pathspec.PathSpec.from_lines("gitwildmatch", ignore_file)
         elif os.path.exists(gitignore_path):
-            with open(gitignore_path, "r",encoding="utf-8") as ignore_file:
+            with open(gitignore_path, "r", encoding="utf-8") as ignore_file:
                 return pathspec.PathSpec.from_lines("gitwildmatch", ignore_file)
         return None

-    def _retrieve_documents(self,options:Optional[Dict[str,Any]]=None) -> Generator[SourceCode, None, None]:
+    def _retrieve_documents(self, options: Optional[Dict[str, Any]] = None) -> Generator[SourceCode, None, None]:
         return self.document_retriever.retrieve_documents(options=options)

     def build(self):
         pass

-
     def search(self, query: str) -> List[SourceCode]:
         target_query = query
         only_contexts = False
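Note: `_load_ignore_file` returns a `pathspec.PathSpec` built with git-wildmatch rules from `.serveignore` or `.gitignore`; the hunk only normalizes spacing. A small sketch of how such a spec is typically applied (the patterns and paths are invented):

```python
# Sketch of applying a gitwildmatch PathSpec; patterns and paths are examples.
import pathspec

ignore_lines = [
    "node_modules/",
    "*.log",
    "build/",
]
spec = pathspec.PathSpec.from_lines("gitwildmatch", ignore_lines)

for rel_path in ["src/app.py", "server/debug.log", "node_modules/x/index.js"]:
    if spec.match_file(rel_path):
        print("ignored:", rel_path)
    else:
        print("kept:   ", rel_path)
```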
@@ -300,7 +313,8 @@ class LongContextRAG:
             only_contexts = True

         logger.info("Search from RAG.....")
-        logger.info(
+        logger.info(
+            f"Query: {target_query[0:100]}... only_contexts: {only_contexts}")

         if self.client:
             new_query = json.dumps(
@@ -316,7 +330,8 @@ class LongContextRAG:
             if not only_contexts:
                 return [SourceCode(module_name=f"RAG:{target_query}", source_code=v)]

-            json_lines = [json.loads(line)
+            json_lines = [json.loads(line)
+                          for line in v.split("\n") if line.strip()]
             return [SourceCode.model_validate(json_line) for json_line in json_lines]
         else:
             if only_contexts:
@@ -335,7 +350,7 @@ class LongContextRAG:

     def _filter_docs(self, conversations: List[Dict[str, str]]) -> DocFilterResult:
         query = conversations[-1]["content"]
-        documents = self._retrieve_documents(options={"query":query})
+        documents = self._retrieve_documents(options={"query": query})
         return self.doc_filter.filter_docs(
             conversations=conversations, documents=documents
         )
@@ -360,9 +375,8 @@ class LongContextRAG:
             logger.error(f"Error in stream_chat_oai: {str(e)}")
             traceback.print_exc()
             return ["出现错误,请稍后再试。"], []
-

-    def _stream_chatfrom_openai_sdk(self,response):
+    def _stream_chatfrom_openai_sdk(self, response):
         for chunk in response:
             if hasattr(chunk, "usage") and chunk.usage:
                 input_tokens_count = chunk.usage.prompt_tokens
@@ -386,9 +400,9 @@ class LongContextRAG:
                 reasoning_text = chunk.choices[0].delta.reasoning_content or ""

             last_meta = SingleOutputMeta(input_tokens_count=input_tokens_count,
-
-
-
+                                         generated_tokens_count=generated_tokens_count,
+                                         reasoning_content=reasoning_text,
+                                         finish_reason=chunk.choices[0].finish_reason)
             yield (content, last_meta)

     def _stream_chat_oai(
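Note: `_stream_chatfrom_openai_sdk` adapts an OpenAI-SDK streaming response into `(content, SingleOutputMeta)` tuples, reading `usage` counts and the optional `delta.reasoning_content`; the hunk reformats its argument wrapping. A simplified sketch of the same adapter pattern, with a stand-in `Meta` dataclass instead of byzerllm's `SingleOutputMeta` (the provider-specific `reasoning_content` attribute may be absent, hence the guarded access):

```python
# Simplified adapter: OpenAI-style stream -> (text, meta) tuples.
# Meta is a stand-in for byzerllm's SingleOutputMeta; reasoning_content is only
# populated by providers that expose reasoning deltas.
from dataclasses import dataclass
from typing import Optional


@dataclass
class Meta:
    input_tokens_count: int = 0
    generated_tokens_count: int = 0
    reasoning_content: str = ""
    finish_reason: Optional[str] = None


def stream_from_openai_sdk(response):
    input_tokens = 0
    generated_tokens = 0
    for chunk in response:
        # Usage arrives on the final chunk when stream usage reporting is enabled.
        if getattr(chunk, "usage", None):
            input_tokens = chunk.usage.prompt_tokens
            generated_tokens = chunk.usage.completion_tokens
        if not chunk.choices:
            continue
        delta = chunk.choices[0].delta
        content = getattr(delta, "content", None) or ""
        reasoning = getattr(delta, "reasoning_content", None) or ""
        yield content, Meta(
            input_tokens_count=input_tokens,
            generated_tokens_count=generated_tokens,
            reasoning_content=reasoning,
            finish_reason=chunk.choices[0].finish_reason,
        )
```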
@@ -398,7 +412,7 @@ class LongContextRAG:
         role_mapping=None,
         llm_config: Dict[str, Any] = {},
         extra_request_params: Dict[str, Any] = {}
-    ):
+    ):
         if self.client:
             model = model or self.args.model
             response = self.client.chat.completions.create(
@@ -407,8 +421,8 @@ class LongContextRAG:
                 stream=True,
                 max_tokens=self.args.rag_params_max_tokens,
                 extra_body=extra_request_params
-            )
-            return self._stream_chatfrom_openai_sdk(response), []
+            )
+            return self._stream_chatfrom_openai_sdk(response), []

         target_llm = self.llm
         if self.llm.get_sub_client("qa_model"):
@@ -422,7 +436,7 @@ class LongContextRAG:
             in query
             or "简要总结一下对话内容,用作后续的上下文提示 prompt,控制在 200 字以内"
             in query
-        ):
+        ):

             chunks = target_llm.stream_chat_oai(
                 conversations=conversations,
@@ -432,22 +446,24 @@ class LongContextRAG:
                 delta_mode=True,
                 extra_request_params=extra_request_params
             )
+
             def generate_chunks():
                 for chunk in chunks:
                     yield chunk
             return generate_chunks(), context
-
-        try:
+
+        try:
             request_params = json.loads(query)
-            if "request_id" in request_params:
+            if "request_id" in request_params:
                 request_id = request_params["request_id"]
                 index = request_params["index"]
-
-                file_path = event_writer.get_event_file_path(request_id)
-                logger.info(
+
+                file_path = event_writer.get_event_file_path(request_id)
+                logger.info(
+                    f"Get events for request_id: {request_id} index: {index} file_path: {file_path}")
                 events = []
                 if not os.path.exists(file_path):
-                    return [],context
+                    return [], context

                 with open(file_path, "r") as f:
                     for line in f:
@@ -455,8 +471,8 @@ class LongContextRAG:
                         if event["index"] >= index:
                             events.append(event)
                 return [json.dumps({
-                    "events": [event for event in events],
-                },ensure_ascii=False)], context
+                    "events": [event for event in events],
+                }, ensure_ascii=False)], context
         except json.JSONDecodeError:
             pass

@@ -465,7 +481,7 @@ class LongContextRAG:
                 llm=target_llm,
                 inference_enhance=not self.args.disable_inference_enhance,
                 inference_deep_thought=self.args.inference_deep_thought,
-                inference_slow_without_deep_thought=self.args.inference_slow_without_deep_thought,
+                inference_slow_without_deep_thought=self.args.inference_slow_without_deep_thought,
                 precision=self.args.inference_compute_precision,
                 data_cells_max_num=self.args.data_cells_max_num,
             )
@@ -474,14 +490,14 @@ class LongContextRAG:
                 conversations, query, []
             )
             chunks = llm_compute_engine.stream_chat_oai(
-
-
-
-
-
-
-
-
+                conversations=new_conversations,
+                model=model,
+                role_mapping=role_mapping,
+                llm_config=llm_config,
+                delta_mode=True,
+                extra_request_params=extra_request_params
+            )
+
             def generate_chunks():
                 for chunk in chunks:
                     yield chunk
@@ -491,7 +507,6 @@ class LongContextRAG:
                 context,
             )

-
         only_contexts = False
         try:
             v = json.loads(query)
@@ -504,7 +519,6 @@ class LongContextRAG:

         logger.info(f"Query: {query} only_contexts: {only_contexts}")
         start_time = time.time()
-

         rag_stat = RAGStat(
             recall_stat=RecallStat(
@@ -525,17 +539,62 @@ class LongContextRAG:
         )

         context = []
+
         def generate_sream():
             nonlocal context
-            doc_filter_result = self._filter_docs(conversations)

-
-
+            yield ("", SingleOutputMeta(input_tokens_count=0,
+                                        generated_tokens_count=0,
+                                        reasoning_content=get_message_with_format_and_newline(
+                                            "rag_searching_docs",
+                                            model=rag_stat.recall_stat.model_name
+                                        )
+                                        ))
+
+            doc_filter_result = DocFilterResult(
+                docs=[],
+                raw_docs=[],
+                input_tokens_counts=[],
+                generated_tokens_counts=[],
+                durations=[],
+                model_name=rag_stat.recall_stat.model_name
+            )
+            query = conversations[-1]["content"]
+            documents = self._retrieve_documents(options={"query": query})
+
+            # 使用带进度报告的过滤方法
+            for progress_update, result in self.doc_filter.filter_docs_with_progress(conversations, documents):
+                if result is not None:
+                    doc_filter_result = result
+                else:
+                    # 生成进度更新
+                    yield ("", SingleOutputMeta(
+                        input_tokens_count=rag_stat.recall_stat.total_input_tokens,
+                        generated_tokens_count=rag_stat.recall_stat.total_generated_tokens,
+                        reasoning_content=f"{progress_update.message} ({progress_update.completed}/{progress_update.total})"
+                    ))
+
+            rag_stat.recall_stat.total_input_tokens += sum(
+                doc_filter_result.input_tokens_counts)
+            rag_stat.recall_stat.total_generated_tokens += sum(
+                doc_filter_result.generated_tokens_counts)
             rag_stat.recall_stat.model_name = doc_filter_result.model_name

             relevant_docs: List[FilterDoc] = doc_filter_result.docs
             filter_time = time.time() - start_time

+            yield ("", SingleOutputMeta(input_tokens_count=rag_stat.recall_stat.total_input_tokens,
+                                        generated_tokens_count=rag_stat.recall_stat.total_generated_tokens,
+                                        reasoning_content=get_message_with_format_and_newline(
+                                            "rag_docs_filter_result",
+                                            filter_time=filter_time,
+                                            docs_num=len(relevant_docs),
+                                            input_tokens=rag_stat.recall_stat.total_input_tokens,
+                                            output_tokens=rag_stat.recall_stat.total_generated_tokens,
+                                            model=rag_stat.recall_stat.model_name
+                                        )
+                                        ))
+
             # Filter relevant_docs to only include those with is_relevant=True
             highly_relevant_docs = [
                 doc for doc in relevant_docs if doc.relevance.is_relevant
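Note: the rewritten `generate_sream()` consumes `doc_filter.filter_docs_with_progress(...)`, which yields `(progress_update, result)` pairs: progress updates carry a message plus completed/total counters while `result` is `None`, and the final pair carries the `DocFilterResult`. A minimal sketch of that producer/consumer contract (this is not the doc_filter implementation; the field sets mirror only what the hunk accesses, and the documents are invented):

```python
# Minimal sketch of the (progress_update, result) contract used above.
# ProgressUpdate / DocFilterResult fields mirror what the hunk accesses;
# the filtering itself is faked for illustration.
from dataclasses import dataclass, field
from typing import Iterator, List, Optional, Tuple


@dataclass
class ProgressUpdate:
    message: str
    completed: int
    total: int


@dataclass
class DocFilterResult:
    docs: List[str] = field(default_factory=list)
    input_tokens_counts: List[int] = field(default_factory=list)
    generated_tokens_counts: List[int] = field(default_factory=list)
    model_name: str = "unknown"


def filter_docs_with_progress(
    documents: List[str],
) -> Iterator[Tuple[Optional[ProgressUpdate], Optional[DocFilterResult]]]:
    result = DocFilterResult(model_name="example_model")
    total = len(documents)
    for i, doc in enumerate(documents, start=1):
        result.docs.append(doc)                 # pretend every doc is relevant
        result.input_tokens_counts.append(100)
        result.generated_tokens_counts.append(10)
        yield ProgressUpdate("filtering docs", i, total), None
    yield None, result                          # final pair carries the result


for progress, result in filter_docs_with_progress(["a.md", "b.md"]):
    if result is not None:
        print("done:", len(result.docs), "docs kept")
    else:
        print(f"{progress.message} ({progress.completed}/{progress.total})")
```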
@@ -543,7 +602,8 @@ class LongContextRAG:

             if highly_relevant_docs:
                 relevant_docs = highly_relevant_docs
-                logger.info(
+                logger.info(
+                    f"Found {len(relevant_docs)} highly relevant documents")

             logger.info(
                 f"Filter time: {filter_time:.2f} seconds with {len(relevant_docs)} docs"
@@ -553,7 +613,7 @@ class LongContextRAG:
                 final_docs = []
                 for doc in relevant_docs:
                     final_docs.append(doc.model_dump())
-                return [json.dumps(final_docs,ensure_ascii=False)], []
+                return [json.dumps(final_docs, ensure_ascii=False)], []

             if not relevant_docs:
                 return ["没有找到相关的文档来回答这个问题。"], []
@@ -588,6 +648,12 @@ class LongContextRAG:
                     + "".join([f"\n * {info}" for info in relevant_docs_info])
                 )

+            yield ("", SingleOutputMeta(generated_tokens_count=0,
+                                        reasoning_content=get_message_with_format_and_newline(
+                                            "dynamic_chunking_start",
+                                            model=rag_stat.chunk_stat.model_name
+                                        )
+                                        ))
             first_round_full_docs = []
             second_round_extracted_docs = []
             sencond_round_time = 0
@@ -602,17 +668,19 @@ class LongContextRAG:
                     llm=self.llm,
                     disable_segment_reorder=self.args.disable_segment_reorder,
                 )
-
+
                 token_limiter_result = token_limiter.limit_tokens(
                     relevant_docs=relevant_docs,
                     conversations=conversations,
                     index_filter_workers=self.args.index_filter_workers or 5,
                 )

-                rag_stat.chunk_stat.total_input_tokens += sum(
-
+                rag_stat.chunk_stat.total_input_tokens += sum(
+                    token_limiter_result.input_tokens_counts)
+                rag_stat.chunk_stat.total_generated_tokens += sum(
+                    token_limiter_result.generated_tokens_counts)
                 rag_stat.chunk_stat.model_name = token_limiter_result.model_name
-
+
                 final_relevant_docs = token_limiter_result.docs
                 first_round_full_docs = token_limiter.first_round_full_docs
                 second_round_extracted_docs = token_limiter.second_round_extracted_docs
@@ -623,24 +691,41 @@ class LongContextRAG:
                 relevant_docs = relevant_docs[: self.args.index_filter_file_num]

             logger.info(f"Finally send to model: {len(relevant_docs)}")
-
             # 记录分段处理的统计信息
             logger.info(
                 f"=== Token Management ===\n"
                 f" * Only contexts: {only_contexts}\n"
-                f" * Filter time: {filter_time:.2f} seconds\n"
+                f" * Filter time: {filter_time:.2f} seconds\n"
                 f" * Final relevant docs: {len(relevant_docs)}\n"
                 f" * First round full docs: {len(first_round_full_docs)}\n"
                 f" * Second round extracted docs: {len(second_round_extracted_docs)}\n"
                 f" * Second round time: {sencond_round_time:.2f} seconds"
             )

+            yield ("", SingleOutputMeta(generated_tokens_count=rag_stat.chunk_stat.total_generated_tokens + rag_stat.recall_stat.total_generated_tokens,
+                                        input_tokens_count=rag_stat.chunk_stat.total_input_tokens +
+                                        rag_stat.recall_stat.total_input_tokens,
+                                        reasoning_content=get_message_with_format_and_newline(
+                                            "dynamic_chunking_result",
+                                            model=rag_stat.chunk_stat.model_name,
+                                            docs_num=len(relevant_docs),
+                                            filter_time=filter_time,
+                                            sencond_round_time=sencond_round_time,
+                                            first_round_full_docs=len(
+                                                first_round_full_docs),
+                                            second_round_extracted_docs=len(
+                                                second_round_extracted_docs),
+                                            input_tokens=rag_stat.chunk_stat.total_input_tokens,
+                                            output_tokens=rag_stat.chunk_stat.total_generated_tokens
+                                        )
+                                        ))
+
             # 记录最终选择的文档详情
             final_relevant_docs_info = []
             for i, doc in enumerate(relevant_docs):
                 doc_path = doc.module_name.replace(self.path, '', 1)
                 info = f"{i+1}. {doc_path}"
-
+
                 metadata_info = []
                 if "original_docs" in doc.metadata:
                     original_docs = ", ".join(
@@ -650,26 +735,27 @@ class LongContextRAG:
                         ]
                     )
                     metadata_info.append(f"Original docs: {original_docs}")
-
+
                 if "chunk_ranges" in doc.metadata:
                     chunk_ranges = json.dumps(
                         doc.metadata["chunk_ranges"], ensure_ascii=False
                     )
                     metadata_info.append(f"Chunk ranges: {chunk_ranges}")
-
+
                 if "processing_time" in doc.metadata:
-                    metadata_info.append(
-
+                    metadata_info.append(
+                        f"Processing time: {doc.metadata['processing_time']:.2f}s")
+
                 if metadata_info:
                     info += f" ({'; '.join(metadata_info)})"
-
+
                 final_relevant_docs_info.append(info)

             if final_relevant_docs_info:
                 logger.info(
                     f"Final documents to be sent to model:"
                     + "".join([f"\n * {info}" for info in final_relevant_docs_info])
-
+                )

             # 记录令牌统计
             request_tokens = sum([doc.tokens for doc in relevant_docs])
@@ -680,7 +766,18 @@ class LongContextRAG:
                 f" * Total tokens: {request_tokens}"
             )

-            logger.info(
+            logger.info(
+                f"Start to send to model {target_model} with {request_tokens} tokens")
+
+            yield ("", SingleOutputMeta(input_tokens_count=rag_stat.recall_stat.total_input_tokens + rag_stat.chunk_stat.total_input_tokens,
+                                        generated_tokens_count=rag_stat.recall_stat.total_generated_tokens +
+                                        rag_stat.chunk_stat.total_generated_tokens,
+                                        reasoning_content=get_message_with_format_and_newline(
+                                            "send_to_model",
+                                            model=target_model,
+                                            tokens=request_tokens
+                                        )
+                                        ))

             if LLMComputeEngine is not None and not self.args.disable_inference_enhance:
                 llm_compute_engine = LLMComputeEngine(
@@ -692,33 +789,42 @@ class LongContextRAG:
                     debug=False,
                 )
                 new_conversations = llm_compute_engine.process_conversation(
-                    conversations, query, [
+                    conversations, query, [
+                        doc.source_code for doc in relevant_docs]
                 )
                 chunks = llm_compute_engine.stream_chat_oai(
-
-
-
-
-
-
-
+                    conversations=new_conversations,
+                    model=model,
+                    role_mapping=role_mapping,
+                    llm_config=llm_config,
+                    delta_mode=True,
+                )
+
                 for chunk in chunks:
-                    yield chunk
                     if chunk[1] is not None:
                         rag_stat.answer_stat.total_input_tokens += chunk[1].input_tokens_count
-                        rag_stat.answer_stat.total_generated_tokens += chunk[1].generated_tokens_count
-
-
+                        rag_stat.answer_stat.total_generated_tokens += chunk[1].generated_tokens_count
+                        chunk[1].input_tokens_count = rag_stat.recall_stat.total_input_tokens + \
+                            rag_stat.chunk_stat.total_input_tokens + \
+                            rag_stat.answer_stat.total_input_tokens
+                        chunk[1].generated_tokens_count = rag_stat.recall_stat.total_generated_tokens + \
+                            rag_stat.chunk_stat.total_generated_tokens + \
+                            rag_stat.answer_stat.total_generated_tokens
+                    yield chunk
+
+                self._print_rag_stats(rag_stat)
+            else:
                 new_conversations = conversations[:-1] + [
                     {
                         "role": "user",
                         "content": self._answer_question.prompt(
                             query=query,
-                            relevant_docs=[
+                            relevant_docs=[
+                                doc.source_code for doc in relevant_docs],
                         ),
                     }
                 ]
-
+
                 chunks = target_llm.stream_chat_oai(
                     conversations=new_conversations,
                     model=model,
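Note: both answer paths now rewrite each streamed chunk's meta before yielding it: per-chunk usage is first added to `answer_stat`, then the chunk's counters are replaced with the running totals across the recall, chunk and answer stages, so the consumer sees cumulative usage. A small illustration of that accounting (all numbers invented):

```python
# Illustration of the cumulative accounting applied to each streamed chunk.
recall = {"in": 1200, "out": 80}    # stand-in for rag_stat.recall_stat
chunking = {"in": 900, "out": 40}   # stand-in for rag_stat.chunk_stat
answer = {"in": 0, "out": 0}        # rag_stat.answer_stat, filled while streaming

per_chunk_usage = [(3000, 120), (0, 95), (0, 60)]  # (input, generated) per chunk

for chunk_in, chunk_out in per_chunk_usage:
    answer["in"] += chunk_in
    answer["out"] += chunk_out
    # What the consumer sees on this chunk's meta: totals across all stages.
    cumulative_in = recall["in"] + chunking["in"] + answer["in"]
    cumulative_out = recall["out"] + chunking["out"] + answer["out"]
    print(cumulative_in, cumulative_out)
```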
@@ -727,17 +833,23 @@ class LongContextRAG:
                     delta_mode=True,
                     extra_request_params=extra_request_params
                 )
-
+
                 for chunk in chunks:
-                    yield chunk
                     if chunk[1] is not None:
                         rag_stat.answer_stat.total_input_tokens += chunk[1].input_tokens_count
-                        rag_stat.answer_stat.total_generated_tokens += chunk[1].generated_tokens_count
-
+                        rag_stat.answer_stat.total_generated_tokens += chunk[1].generated_tokens_count
+                        chunk[1].input_tokens_count = rag_stat.recall_stat.total_input_tokens + \
+                            rag_stat.chunk_stat.total_input_tokens + \
+                            rag_stat.answer_stat.total_input_tokens
+                        chunk[1].generated_tokens_count = rag_stat.recall_stat.total_generated_tokens + \
+                            rag_stat.chunk_stat.total_generated_tokens + \
+                            rag_stat.answer_stat.total_generated_tokens
+
+                    yield chunk

-
-
-
+                self._print_rag_stats(rag_stat)
+
+        return generate_sream(), context

     def _print_rag_stats(self, rag_stat: RAGStat) -> None:
         """打印RAG执行的详细统计信息"""
@@ -748,19 +860,22 @@ class LongContextRAG:
        )
        total_generated_tokens = (
            rag_stat.recall_stat.total_generated_tokens +
-            rag_stat.chunk_stat.total_generated_tokens +
+            rag_stat.chunk_stat.total_generated_tokens +
            rag_stat.answer_stat.total_generated_tokens
        )
        total_tokens = total_input_tokens + total_generated_tokens
-
+
        # 避免除以零错误
        if total_tokens == 0:
            recall_percent = chunk_percent = answer_percent = 0
        else:
-            recall_percent = (rag_stat.recall_stat.total_input_tokens +
-
-
-
+            recall_percent = (rag_stat.recall_stat.total_input_tokens +
+                              rag_stat.recall_stat.total_generated_tokens) / total_tokens * 100
+            chunk_percent = (rag_stat.chunk_stat.total_input_tokens +
+                             rag_stat.chunk_stat.total_generated_tokens) / total_tokens * 100
+            answer_percent = (rag_stat.answer_stat.total_input_tokens +
+                              rag_stat.answer_stat.total_generated_tokens) / total_tokens * 100
+
        logger.info(
            f"=== RAG 执行统计信息 ===\n"
            f"总令牌使用: {total_tokens} 令牌\n"
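Note: the rewrapped lines in `_print_rag_stats` split total token usage into per-stage shares with a divide-by-zero guard. The same computation on invented numbers:

```python
# Stage-share computation mirroring _print_rag_stats (numbers invented).
recall_in, recall_out = 1200, 80
chunk_in, chunk_out = 900, 40
answer_in, answer_out = 3000, 500

total_tokens = recall_in + recall_out + chunk_in + chunk_out + answer_in + answer_out

if total_tokens == 0:
    recall_percent = chunk_percent = answer_percent = 0
else:
    recall_percent = (recall_in + recall_out) / total_tokens * 100
    chunk_percent = (chunk_in + chunk_out) / total_tokens * 100
    answer_percent = (answer_in + answer_out) / total_tokens * 100

print(f"{recall_percent:.1f}% {chunk_percent:.1f}% {answer_percent:.1f}%")
# 22.4% 16.4% 61.2%
```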
@@ -791,21 +906,22 @@ class LongContextRAG:
             f" - 文档分块: {chunk_percent:.1f}%\n"
             f" - 答案生成: {answer_percent:.1f}%\n"
         )
-
+
         # 记录原始统计数据,以便调试
         logger.debug(f"RAG Stat 原始数据: {rag_stat}")
-
+
         # 返回成本估算
-        estimated_cost = self._estimate_token_cost(
+        estimated_cost = self._estimate_token_cost(
+            total_input_tokens, total_generated_tokens)
         if estimated_cost > 0:
             logger.info(f"估计成本: 约 ${estimated_cost:.4f} 人民币")

     def _estimate_token_cost(self, input_tokens: int, output_tokens: int) -> float:
-        """估算当前请求的令牌成本(人民币)"""
+        """估算当前请求的令牌成本(人民币)"""
         # 实际应用中,可以根据不同模型设置不同价格
         input_cost_per_1m = 2.0/1000000  # 每百万输入令牌的成本
         output_cost_per_1m = 8.0/100000  # 每百万输出令牌的成本
-
-        cost = (input_tokens * input_cost_per_1m / 1000000) +
+
+        cost = (input_tokens * input_cost_per_1m / 1000000) + \
+            (output_tokens * output_cost_per_1m/1000000)
         return cost
-
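Note: `_estimate_token_cost` applies flat per-million-token prices (2.0 for input, 8.0 for output, labelled as RMB in the docstring). A quick worked example that applies the formula exactly as it appears in the hunk (token counts invented):

```python
# Worked example of the cost formula exactly as written in the hunk above.
input_cost_per_1m = 2.0 / 1000000    # cost per million input tokens
output_cost_per_1m = 8.0 / 100000    # cost per million output tokens

input_tokens = 5_000_000
output_tokens = 1_000_000

cost = (input_tokens * input_cost_per_1m / 1000000) + \
    (output_tokens * output_cost_per_1m / 1000000)
print(f"{cost:.6f}")  # 0.000090
```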