auto-coder 0.1.278__py3-none-any.whl → 0.1.280__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of auto-coder might be problematic.
- {auto_coder-0.1.278.dist-info → auto_coder-0.1.280.dist-info}/METADATA +2 -2
- {auto_coder-0.1.278.dist-info → auto_coder-0.1.280.dist-info}/RECORD +14 -12
- autocoder/rag/api_server.py +1 -3
- autocoder/rag/doc_filter.py +104 -29
- autocoder/rag/lang.py +50 -0
- autocoder/rag/llm_wrapper.py +70 -40
- autocoder/rag/long_context_rag.py +353 -197
- autocoder/rag/relevant_utils.py +10 -0
- autocoder/utils/stream_thinking.py +193 -0
- autocoder/version.py +1 -1
- {auto_coder-0.1.278.dist-info → auto_coder-0.1.280.dist-info}/LICENSE +0 -0
- {auto_coder-0.1.278.dist-info → auto_coder-0.1.280.dist-info}/WHEEL +0 -0
- {auto_coder-0.1.278.dist-info → auto_coder-0.1.280.dist-info}/entry_points.txt +0 -0
- {auto_coder-0.1.278.dist-info → auto_coder-0.1.280.dist-info}/top_level.txt +0 -0
@@ -23,6 +23,8 @@ from autocoder.rag.relevant_utils import (
     FilterDoc,
     TaskTiming,
     parse_relevance,
+    ProgressUpdate,
+    DocFilterResult
 )
 from autocoder.rag.token_checker import check_token_limit
 from autocoder.rag.token_counter import RemoteTokenCounter, TokenCounter
@@ -33,14 +35,18 @@ from importlib.metadata import version
 from autocoder.rag.stream_event import event_writer
 from autocoder.rag.relevant_utils import DocFilterResult
 from pydantic import BaseModel
+from byzerllm.utils.types import SingleOutputMeta
+from autocoder.rag.lang import get_message_with_format_and_newline
 
-try:
+try:
     from autocoder_pro.rag.llm_compute import LLMComputeEngine
     pro_version = version("auto-coder-pro")
     autocoder_version = version("auto-coder")
-    logger.warning(f"auto-coder-pro({pro_version}) plugin is enabled in auto-coder.rag({autocoder_version})")
+    logger.warning(
+        f"auto-coder-pro({pro_version}) plugin is enabled in auto-coder.rag({autocoder_version})")
 except ImportError:
-    logger.warning("Please install auto-coder-pro to enhance llm compute ability")
+    logger.warning(
+        "Please install auto-coder-pro to enhance llm compute ability")
     LLMComputeEngine = None
 
 
@@ -48,20 +54,26 @@ class RecallStat(BaseModel):
     total_input_tokens: int
     total_generated_tokens: int
     model_name: str = "unknown"
+
+
 class ChunkStat(BaseModel):
     total_input_tokens: int
-    total_generated_tokens: int
+    total_generated_tokens: int
     model_name: str = "unknown"
+
+
 class AnswerStat(BaseModel):
     total_input_tokens: int
     total_generated_tokens: int
     model_name: str = "unknown"
 
+
 class RAGStat(BaseModel):
     recall_stat: RecallStat
     chunk_stat: ChunkStat
     answer_stat: AnswerStat
 
+
 class LongContextRAG:
     def __init__(
         self,
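The added blank lines are PEP 8 spacing; the substantive piece is the per-stage token accounting that the rest of this diff feeds. A minimal sketch of how these nested pydantic models accumulate (a single StageStat with defaulted fields stands in for RecallStat/ChunkStat/AnswerStat, which declare the same field names without defaults):

    from pydantic import BaseModel

    class StageStat(BaseModel):
        # stands in for RecallStat / ChunkStat / AnswerStat above
        total_input_tokens: int = 0
        total_generated_tokens: int = 0
        model_name: str = "unknown"

    class RAGStatSketch(BaseModel):
        recall_stat: StageStat
        chunk_stat: StageStat
        answer_stat: StageStat

    stat = RAGStatSketch(recall_stat=StageStat(), chunk_stat=StageStat(),
                         answer_stat=StageStat())
    stat.recall_stat.total_input_tokens += 1200  # updated as each stage runs
    print(stat.recall_stat.model_name)  # "unknown"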
@@ -85,7 +97,7 @@ class LongContextRAG:
         self.chunk_llm = self.llm.get_sub_client("chunk_model")
 
         self.args = args
-
+
         self.path = path
         self.relevant_score = self.args.rag_doc_filter_relevance or 5
 
@@ -98,8 +110,10 @@ class LongContextRAG:
             "The sum of full_text_ratio and segment_ratio must be less than or equal to 1.0"
         )
 
-        self.full_text_limit = int(args.rag_context_window_limit * self.full_text_ratio)
-        self.segment_limit = int(args.rag_context_window_limit * self.segment_ratio)
+        self.full_text_limit = int(
+            args.rag_context_window_limit * self.full_text_ratio)
+        self.segment_limit = int(
+            args.rag_context_window_limit * self.segment_ratio)
         self.buff_limit = int(args.rag_context_window_limit * self.buff_ratio)
 
         self.tokenizer = None
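The rewrapped assignments split one context window into three fixed budgets whose ratios must sum to at most 1.0. A worked example of the arithmetic, with an illustrative window size and ratios (the concrete numbers are assumptions, not defaults from the diff):

    rag_context_window_limit = 120_000          # illustrative window size
    full_text_ratio, segment_ratio = 0.5, 0.35  # must sum to <= 1.0
    buff_ratio = 1.0 - full_text_ratio - segment_ratio

    full_text_limit = int(rag_context_window_limit * full_text_ratio)  # 60000: whole files
    segment_limit = int(rag_context_window_limit * segment_ratio)      # 42000: extracted chunks
    buff_limit = int(rag_context_window_limit * buff_ratio)            # 18000: slack for the prompt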
@@ -108,7 +122,8 @@ class LongContextRAG:
 
         if self.tokenizer_path:
             VariableHolder.TOKENIZER_PATH = self.tokenizer_path
-            VariableHolder.TOKENIZER_MODEL = Tokenizer.from_file(self.tokenizer_path)
+            VariableHolder.TOKENIZER_MODEL = Tokenizer.from_file(
+                self.tokenizer_path)
             self.tokenizer = TokenCounter(self.tokenizer_path)
         else:
             if llm.is_model_exist("deepseek_tokenizer"):
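Tokenizer.from_file here is the Hugging Face tokenizers API; a minimal sketch of the call being wrapped (the tokenizer.json path is hypothetical):

    from tokenizers import Tokenizer

    tokenizer = Tokenizer.from_file("/models/deepseek_tokenizer/tokenizer.json")  # hypothetical path
    encoding = tokenizer.encode("long context RAG")
    print(len(encoding.ids))  # token count used against the window budgets above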
@@ -160,9 +175,9 @@ class LongContextRAG:
             self.required_exts,
             self.on_ray,
             self.monitor_mode,
-
+            # 确保全文区至少能放下一个文件
             single_file_token_limit=self.full_text_limit - 100,
-            disable_auto_window=self.args.disable_auto_window,
+            disable_auto_window=self.args.disable_auto_window,
             enable_hybrid_index=self.args.enable_hybrid_index,
             extra_params=self.args
         )
@@ -223,14 +238,14 @@ class LongContextRAG:
         {% for msg in conversations %}
         [{{ msg.role }}]:
         {{ msg.content }}
-
+
         {% endfor %}
         </conversations>
 
         请根据提供的文档内容、用户对话历史以及最后一个问题,提取并总结文档中与问题相关的重要信息。
         如果文档中没有相关信息,请回复"该文档中没有与问题相关的信息"。
         提取的信息尽量保持和原文中的一样,并且只输出这些信息。
-        """
+        """
 
         @byzerllm.prompt()
         def _answer_question(
@@ -265,20 +280,20 @@ class LongContextRAG:
         """Get the document retriever class based on configuration."""
         # Default to LocalDocumentRetriever if not specified
         return LocalDocumentRetriever
-
+
     def _load_ignore_file(self):
         serveignore_path = os.path.join(self.path, ".serveignore")
         gitignore_path = os.path.join(self.path, ".gitignore")
 
         if os.path.exists(serveignore_path):
-            with open(serveignore_path, "r",encoding="utf-8") as ignore_file:
+            with open(serveignore_path, "r", encoding="utf-8") as ignore_file:
                 return pathspec.PathSpec.from_lines("gitwildmatch", ignore_file)
         elif os.path.exists(gitignore_path):
-            with open(gitignore_path, "r",encoding="utf-8") as ignore_file:
+            with open(gitignore_path, "r", encoding="utf-8") as ignore_file:
                 return pathspec.PathSpec.from_lines("gitwildmatch", ignore_file)
         return None
 
-    def _retrieve_documents(self,options:Optional[Dict[str,Any]]=None) -> Generator[SourceCode, None, None]:
+    def _retrieve_documents(self, options: Optional[Dict[str, Any]] = None) -> Generator[SourceCode, None, None]:
         return self.document_retriever.retrieve_documents(options=options)
 
     def build(self):
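Both branches hand the ignore file to pathspec's gitwildmatch dialect, so a .serveignore follows the same semantics as a .gitignore. A small sketch of the resulting matching behavior (patterns are illustrative):

    import pathspec

    lines = ["*.log", "build/", "!build/keep.txt"]  # illustrative .serveignore content
    spec = pathspec.PathSpec.from_lines("gitwildmatch", lines)

    print(spec.match_file("server.log"))      # True
    print(spec.match_file("build/out.bin"))   # True
    print(spec.match_file("build/keep.txt"))  # False: negated pattern wins
    print(spec.match_file("src/main.py"))     # False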
@@ -298,7 +313,8 @@ class LongContextRAG:
             only_contexts = True
 
         logger.info("Search from RAG.....")
-        logger.info(f"Query: {target_query[0:100]}... only_contexts: {only_contexts}")
+        logger.info(
+            f"Query: {target_query[0:100]}... only_contexts: {only_contexts}")
 
         if self.client:
             new_query = json.dumps(
@@ -314,7 +330,8 @@ class LongContextRAG:
             if not only_contexts:
                 return [SourceCode(module_name=f"RAG:{target_query}", source_code=v)]
 
-            json_lines = [json.loads(line) for line in v.split("\n") if line.strip()]
+            json_lines = [json.loads(line)
+                          for line in v.split("\n") if line.strip()]
             return [SourceCode.model_validate(json_line) for json_line in json_lines]
         else:
             if only_contexts:
@@ -333,7 +350,7 @@ class LongContextRAG:
 
     def _filter_docs(self, conversations: List[Dict[str, str]]) -> DocFilterResult:
         query = conversations[-1]["content"]
-        documents = self._retrieve_documents(options={"query":query})
+        documents = self._retrieve_documents(options={"query": query})
         return self.doc_filter.filter_docs(
             conversations=conversations, documents=documents
         )
@@ -344,6 +361,7 @@ class LongContextRAG:
         model: Optional[str] = None,
         role_mapping=None,
         llm_config: Dict[str, Any] = {},
+        extra_request_params: Dict[str, Any] = {}
     ):
         try:
             return self._stream_chat_oai(
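extra_request_params is threaded from this public signature down to the OpenAI SDK call as extra_body (see the hunks below). A sketch of what a caller gains, assuming an OpenAI-compatible auto-coder.rag endpoint; the endpoint URL and the extra field shown are hypothetical:

    from openai import OpenAI

    client = OpenAI(base_url="http://127.0.0.1:8000/v1", api_key="xxxx")  # illustrative endpoint
    response = client.chat.completions.create(
        model="deepseek_chat",
        messages=[{"role": "user", "content": "how is the hybrid index configured?"}],
        stream=True,
        extra_body={"only_contexts": True},  # hypothetical extra field merged into the request body
    )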
@@ -351,18 +369,49 @@ class LongContextRAG:
                 model=model,
                 role_mapping=role_mapping,
                 llm_config=llm_config,
+                extra_request_params=extra_request_params
             )
         except Exception as e:
             logger.error(f"Error in stream_chat_oai: {str(e)}")
             traceback.print_exc()
             return ["出现错误,请稍后再试。"], []
 
+    def _stream_chatfrom_openai_sdk(self, response):
+        for chunk in response:
+            if hasattr(chunk, "usage") and chunk.usage:
+                input_tokens_count = chunk.usage.prompt_tokens
+                generated_tokens_count = chunk.usage.completion_tokens
+            else:
+                input_tokens_count = 0
+                generated_tokens_count = 0
+
+            if not chunk.choices:
+                if last_meta:
+                    yield ("", SingleOutputMeta(input_tokens_count=input_tokens_count,
+                                                generated_tokens_count=generated_tokens_count,
+                                                reasoning_content="",
+                                                finish_reason=last_meta.finish_reason))
+                continue
+
+            content = chunk.choices[0].delta.content or ""
+
+            reasoning_text = ""
+            if hasattr(chunk.choices[0].delta, "reasoning_content"):
+                reasoning_text = chunk.choices[0].delta.reasoning_content or ""
+
+            last_meta = SingleOutputMeta(input_tokens_count=input_tokens_count,
+                                         generated_tokens_count=generated_tokens_count,
+                                         reasoning_content=reasoning_text,
+                                         finish_reason=chunk.choices[0].finish_reason)
+            yield (content, last_meta)
+
     def _stream_chat_oai(
         self,
         conversations,
         model: Optional[str] = None,
         role_mapping=None,
         llm_config: Dict[str, Any] = {},
+        extra_request_params: Dict[str, Any] = {}
     ):
         if self.client:
             model = model or self.args.model
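The new helper adapts raw OpenAI SDK chunks into byzerllm-style (text, SingleOutputMeta) pairs, carrying reasoning deltas and usage alongside the content. A sketch of a consumer, assuming only the meta fields set above; note that last_meta is first bound when a chunk with choices arrives, so a usage-only chunk is expected to come last in the stream:

    def print_stream(pairs):
        # pairs: the (text, meta) tuples yielded by _stream_chatfrom_openai_sdk
        for text, meta in pairs:
            if meta.reasoning_content:
                print(meta.reasoning_content, end="")   # model "thinking" deltas
            if text:
                print(text, end="")                     # answer deltas
            if meta.finish_reason:
                print(f"\n[{meta.finish_reason}] in={meta.input_tokens_count} "
                      f"out={meta.generated_tokens_count}")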
@@ -370,130 +419,182 @@ class LongContextRAG:
                 model=model,
                 messages=conversations,
                 stream=True,
-                max_tokens=self.args.rag_params_max_tokens
+                max_tokens=self.args.rag_params_max_tokens,
+                extra_body=extra_request_params
             )
+            return self._stream_chatfrom_openai_sdk(response), []
 
-
-
-
-                    yield chunk.choices[0].delta.content
+        target_llm = self.llm
+        if self.llm.get_sub_client("qa_model"):
+            target_llm = self.llm.get_sub_client("qa_model")
 
-
-
-
-        target_llm = self.llm
-        if self.llm.get_sub_client("qa_model"):
-            target_llm = self.llm.get_sub_client("qa_model")
+        query = conversations[-1]["content"]
+        context = []
 
-
-
+        if (
+            "使用四到五个字直接返回这句话的简要主题,不要解释、不要标点、不要语气词、不要多余文本,不要加粗,如果没有主题"
+            in query
+            or "简要总结一下对话内容,用作后续的上下文提示 prompt,控制在 200 字以内"
+            in query
+        ):
 
-
-
-
-
-
-
+            chunks = target_llm.stream_chat_oai(
+                conversations=conversations,
+                model=model,
+                role_mapping=role_mapping,
+                llm_config=llm_config,
+                delta_mode=True,
+                extra_request_params=extra_request_params
+            )
 
-
-
-
-
-                llm_config=llm_config,
-                delta_mode=True,
-            )
-            return (chunk[0] for chunk in chunks), context
-
-        try:
-            request_params = json.loads(query)
-            if "request_id" in request_params:
-                request_id = request_params["request_id"]
-                index = request_params["index"]
-
-                file_path = event_writer.get_event_file_path(request_id)
-                logger.info(f"Get events for request_id: {request_id} index: {index} file_path: {file_path}")
-                events = []
-                if not os.path.exists(file_path):
-                    return [],context
-
-                with open(file_path, "r") as f:
-                    for line in f:
-                        event = json.loads(line)
-                        if event["index"] >= index:
-                            events.append(event)
-                return [json.dumps({
-                    "events": [event for event in events],
-                },ensure_ascii=False)], context
-        except json.JSONDecodeError:
-            pass
-
-        if self.args.without_contexts and LLMComputeEngine is not None:
-            llm_compute_engine = LLMComputeEngine(
-                llm=target_llm,
-                inference_enhance=not self.args.disable_inference_enhance,
-                inference_deep_thought=self.args.inference_deep_thought,
-                inference_slow_without_deep_thought=self.args.inference_slow_without_deep_thought,
-                precision=self.args.inference_compute_precision,
-                data_cells_max_num=self.args.data_cells_max_num,
-            )
-            conversations = conversations[:-1]
-            new_conversations = llm_compute_engine.process_conversation(
-                conversations, query, []
-            )
+            def generate_chunks():
+                for chunk in chunks:
+                    yield chunk
+            return generate_chunks(), context
 
-
-
-
-
-                llm_config=llm_config,
-                delta_mode=True,
-            ),
-            context,
-        )
+        try:
+            request_params = json.loads(query)
+            if "request_id" in request_params:
+                request_id = request_params["request_id"]
+                index = request_params["index"]
 
+                file_path = event_writer.get_event_file_path(request_id)
+                logger.info(
+                    f"Get events for request_id: {request_id} index: {index} file_path: {file_path}")
+                events = []
+                if not os.path.exists(file_path):
+                    return [], context
+
+                with open(file_path, "r") as f:
+                    for line in f:
+                        event = json.loads(line)
+                        if event["index"] >= index:
+                            events.append(event)
+                return [json.dumps({
+                    "events": [event for event in events],
+                }, ensure_ascii=False)], context
+        except json.JSONDecodeError:
+            pass
+
+        if self.args.without_contexts and LLMComputeEngine is not None:
+            llm_compute_engine = LLMComputeEngine(
+                llm=target_llm,
+                inference_enhance=not self.args.disable_inference_enhance,
+                inference_deep_thought=self.args.inference_deep_thought,
+                inference_slow_without_deep_thought=self.args.inference_slow_without_deep_thought,
+                precision=self.args.inference_compute_precision,
+                data_cells_max_num=self.args.data_cells_max_num,
+            )
+            conversations = conversations[:-1]
+            new_conversations = llm_compute_engine.process_conversation(
+                conversations, query, []
+            )
+            chunks = llm_compute_engine.stream_chat_oai(
+                conversations=new_conversations,
+                model=model,
+                role_mapping=role_mapping,
+                llm_config=llm_config,
+                delta_mode=True,
+                extra_request_params=extra_request_params
+            )
+
+            def generate_chunks():
+                for chunk in chunks:
+                    yield chunk
 
-
-
-
-            if "only_contexts" in v:
-                query = v["query"]
-                only_contexts = v["only_contexts"]
-                conversations[-1]["content"] = query
-        except json.JSONDecodeError:
-            pass
-
-        logger.info(f"Query: {query} only_contexts: {only_contexts}")
-        start_time = time.time()
-
-
-        rag_stat = RAGStat(
-            recall_stat=RecallStat(
-                total_input_tokens=0,
-                total_generated_tokens=0,
-                model_name=self.recall_llm.default_model_name,
-            ),
-            chunk_stat=ChunkStat(
-                total_input_tokens=0,
-                total_generated_tokens=0,
-                model_name=self.chunk_llm.default_model_name,
-            ),
-            answer_stat=AnswerStat(
-                total_input_tokens=0,
-                total_generated_tokens=0,
-                model_name=self.qa_llm.default_model_name,
-            ),
+            return (
+                generate_chunks(),
+                context,
             )
 
-
+        only_contexts = False
+        try:
+            v = json.loads(query)
+            if "only_contexts" in v:
+                query = v["query"]
+                only_contexts = v["only_contexts"]
+                conversations[-1]["content"] = query
+        except json.JSONDecodeError:
+            pass
+
+        logger.info(f"Query: {query} only_contexts: {only_contexts}")
+        start_time = time.time()
+
+        rag_stat = RAGStat(
+            recall_stat=RecallStat(
+                total_input_tokens=0,
+                total_generated_tokens=0,
+                model_name=self.recall_llm.default_model_name,
+            ),
+            chunk_stat=ChunkStat(
+                total_input_tokens=0,
+                total_generated_tokens=0,
+                model_name=self.chunk_llm.default_model_name,
+            ),
+            answer_stat=AnswerStat(
+                total_input_tokens=0,
+                total_generated_tokens=0,
+                model_name=self.qa_llm.default_model_name,
+            ),
+        )
 
-
-
+        context = []
+
+        def generate_sream():
+            nonlocal context
+
+            yield ("", SingleOutputMeta(input_tokens_count=0,
+                                        generated_tokens_count=0,
+                                        reasoning_content=get_message_with_format_and_newline(
+                                            "rag_searching_docs",
+                                            model=rag_stat.recall_stat.model_name
+                                        )
+                                        ))
+
+            doc_filter_result = DocFilterResult(
+                docs=[],
+                raw_docs=[],
+                input_tokens_counts=[],
+                generated_tokens_counts=[],
+                durations=[],
+                model_name=rag_stat.recall_stat.model_name
+            )
+            query = conversations[-1]["content"]
+            documents = self._retrieve_documents(options={"query": query})
+
+            # 使用带进度报告的过滤方法
+            for progress_update, result in self.doc_filter.filter_docs_with_progress(conversations, documents):
+                if result is not None:
+                    doc_filter_result = result
+                else:
+                    # 生成进度更新
+                    yield ("", SingleOutputMeta(
+                        input_tokens_count=rag_stat.recall_stat.total_input_tokens,
+                        generated_tokens_count=rag_stat.recall_stat.total_generated_tokens,
+                        reasoning_content=f"{progress_update.message} ({progress_update.completed}/{progress_update.total})"
+                    ))
+
+            rag_stat.recall_stat.total_input_tokens += sum(
+                doc_filter_result.input_tokens_counts)
+            rag_stat.recall_stat.total_generated_tokens += sum(
+                doc_filter_result.generated_tokens_counts)
             rag_stat.recall_stat.model_name = doc_filter_result.model_name
 
             relevant_docs: List[FilterDoc] = doc_filter_result.docs
             filter_time = time.time() - start_time
 
+            yield ("", SingleOutputMeta(input_tokens_count=rag_stat.recall_stat.total_input_tokens,
+                                        generated_tokens_count=rag_stat.recall_stat.total_generated_tokens,
+                                        reasoning_content=get_message_with_format_and_newline(
+                                            "rag_docs_filter_result",
+                                            filter_time=filter_time,
+                                            docs_num=len(relevant_docs),
+                                            input_tokens=rag_stat.recall_stat.total_input_tokens,
+                                            output_tokens=rag_stat.recall_stat.total_generated_tokens,
+                                            model=rag_stat.recall_stat.model_name
+                                        )
+                                        ))
+
             # Filter relevant_docs to only include those with is_relevant=True
             highly_relevant_docs = [
                 doc for doc in relevant_docs if doc.relevance.is_relevant
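The loop above relies on a simple protocol: filter_docs_with_progress yields (progress, None) pairs while documents are being filtered and attaches the DocFilterResult to a final pair, which lets the caller forward progress frames as reasoning_content while waiting. A minimal self-contained sketch of that protocol (all names below are illustrative stand-ins, not the package's classes):

    from dataclasses import dataclass
    from typing import Iterator, List, Optional, Tuple

    @dataclass
    class Progress:  # mirrors the ProgressUpdate fields used above
        message: str
        completed: int
        total: int

    def filter_with_progress(docs: List[str]) -> Iterator[Tuple[Progress, Optional[List[str]]]]:
        kept = []
        for i, doc in enumerate(docs, 1):
            if "relevant" in doc:
                kept.append(doc)
            yield Progress("filtering docs", i, len(docs)), None  # progress frame, no result yet
        yield Progress("done", len(docs), len(docs)), kept        # final frame carries the result

    final = None
    for progress, result in filter_with_progress(["one relevant doc", "noise"]):
        if result is not None:
            final = result
        else:
            print(f"{progress.message} ({progress.completed}/{progress.total})")
    print(final)  # ['one relevant doc']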
@@ -501,7 +602,8 @@ class LongContextRAG:
 
             if highly_relevant_docs:
                 relevant_docs = highly_relevant_docs
-                logger.info(f"Found {len(relevant_docs)} highly relevant documents")
+                logger.info(
+                    f"Found {len(relevant_docs)} highly relevant documents")
 
             logger.info(
                 f"Filter time: {filter_time:.2f} seconds with {len(relevant_docs)} docs"
@@ -511,7 +613,7 @@ class LongContextRAG:
                 final_docs = []
                 for doc in relevant_docs:
                     final_docs.append(doc.model_dump())
-                return [json.dumps(final_docs,ensure_ascii=False)], []
+                return [json.dumps(final_docs, ensure_ascii=False)], []
 
             if not relevant_docs:
                 return ["没有找到相关的文档来回答这个问题。"], []
@@ -546,6 +648,12 @@ class LongContextRAG:
                 + "".join([f"\n * {info}" for info in relevant_docs_info])
             )
 
+            yield ("", SingleOutputMeta(generated_tokens_count=0,
+                                        reasoning_content=get_message_with_format_and_newline(
+                                            "dynamic_chunking_start",
+                                            model=rag_stat.chunk_stat.model_name
+                                        )
+                                        ))
             first_round_full_docs = []
             second_round_extracted_docs = []
             sencond_round_time = 0
@@ -560,17 +668,19 @@ class LongContextRAG:
                     llm=self.llm,
                     disable_segment_reorder=self.args.disable_segment_reorder,
                 )
-
+
                 token_limiter_result = token_limiter.limit_tokens(
                     relevant_docs=relevant_docs,
                     conversations=conversations,
                     index_filter_workers=self.args.index_filter_workers or 5,
                 )
 
-                rag_stat.chunk_stat.total_input_tokens += sum(token_limiter_result.input_tokens_counts)
-                rag_stat.chunk_stat.total_generated_tokens += sum(token_limiter_result.generated_tokens_counts)
+                rag_stat.chunk_stat.total_input_tokens += sum(
+                    token_limiter_result.input_tokens_counts)
+                rag_stat.chunk_stat.total_generated_tokens += sum(
+                    token_limiter_result.generated_tokens_counts)
                 rag_stat.chunk_stat.model_name = token_limiter_result.model_name
-
+
                 final_relevant_docs = token_limiter_result.docs
                 first_round_full_docs = token_limiter.first_round_full_docs
                 second_round_extracted_docs = token_limiter.second_round_extracted_docs
@@ -581,24 +691,41 @@ class LongContextRAG:
                 relevant_docs = relevant_docs[: self.args.index_filter_file_num]
 
             logger.info(f"Finally send to model: {len(relevant_docs)}")
-
             # 记录分段处理的统计信息
             logger.info(
                 f"=== Token Management ===\n"
                 f" * Only contexts: {only_contexts}\n"
-                f" * Filter time: {filter_time:.2f} seconds\n"
+                f" * Filter time: {filter_time:.2f} seconds\n"
                 f" * Final relevant docs: {len(relevant_docs)}\n"
                 f" * First round full docs: {len(first_round_full_docs)}\n"
                 f" * Second round extracted docs: {len(second_round_extracted_docs)}\n"
                 f" * Second round time: {sencond_round_time:.2f} seconds"
             )
 
+            yield ("", SingleOutputMeta(generated_tokens_count=rag_stat.chunk_stat.total_generated_tokens + rag_stat.recall_stat.total_generated_tokens,
+                                        input_tokens_count=rag_stat.chunk_stat.total_input_tokens +
+                                        rag_stat.recall_stat.total_input_tokens,
+                                        reasoning_content=get_message_with_format_and_newline(
+                                            "dynamic_chunking_result",
+                                            model=rag_stat.chunk_stat.model_name,
+                                            docs_num=len(relevant_docs),
+                                            filter_time=filter_time,
+                                            sencond_round_time=sencond_round_time,
+                                            first_round_full_docs=len(
+                                                first_round_full_docs),
+                                            second_round_extracted_docs=len(
+                                                second_round_extracted_docs),
+                                            input_tokens=rag_stat.chunk_stat.total_input_tokens,
+                                            output_tokens=rag_stat.chunk_stat.total_generated_tokens
+                                        )
+                                        ))
+
             # 记录最终选择的文档详情
             final_relevant_docs_info = []
             for i, doc in enumerate(relevant_docs):
                 doc_path = doc.module_name.replace(self.path, '', 1)
                 info = f"{i+1}. {doc_path}"
-
+
                 metadata_info = []
                 if "original_docs" in doc.metadata:
                     original_docs = ", ".join(
@@ -608,26 +735,27 @@ class LongContextRAG:
                         ]
                     )
                     metadata_info.append(f"Original docs: {original_docs}")
-
+
                 if "chunk_ranges" in doc.metadata:
                     chunk_ranges = json.dumps(
                         doc.metadata["chunk_ranges"], ensure_ascii=False
                     )
                     metadata_info.append(f"Chunk ranges: {chunk_ranges}")
-
+
                 if "processing_time" in doc.metadata:
-                    metadata_info.append(f"Processing time: {doc.metadata['processing_time']:.2f}s")
+                    metadata_info.append(
+                        f"Processing time: {doc.metadata['processing_time']:.2f}s")
+
                 if metadata_info:
                     info += f" ({'; '.join(metadata_info)})"
-
+
                 final_relevant_docs_info.append(info)
 
             if final_relevant_docs_info:
                 logger.info(
                     f"Final documents to be sent to model:"
                     + "".join([f"\n * {info}" for info in final_relevant_docs_info])
-
+                )
 
             # 记录令牌统计
             request_tokens = sum([doc.tokens for doc in relevant_docs])
@@ -638,7 +766,18 @@ class LongContextRAG:
                 f" * Total tokens: {request_tokens}"
             )
 
-            logger.info(f"Start to send to model {target_model} with {request_tokens} tokens")
+            logger.info(
+                f"Start to send to model {target_model} with {request_tokens} tokens")
+
+            yield ("", SingleOutputMeta(input_tokens_count=rag_stat.recall_stat.total_input_tokens + rag_stat.chunk_stat.total_input_tokens,
+                                        generated_tokens_count=rag_stat.recall_stat.total_generated_tokens +
+                                        rag_stat.chunk_stat.total_generated_tokens,
+                                        reasoning_content=get_message_with_format_and_newline(
+                                            "send_to_model",
+                                            model=target_model,
+                                            tokens=request_tokens
+                                        )
+                                        ))
 
             if LLMComputeEngine is not None and not self.args.disable_inference_enhance:
                 llm_compute_engine = LLMComputeEngine(
@@ -650,53 +789,66 @@ class LongContextRAG:
                     debug=False,
                 )
                 new_conversations = llm_compute_engine.process_conversation(
-                    conversations, query, [doc.source_code for doc in relevant_docs]
+                    conversations, query, [
+                        doc.source_code for doc in relevant_docs]
                 )
                 chunks = llm_compute_engine.stream_chat_oai(
-
-
-
-
-
-
-
-                def generate_chunks():
-                    for chunk in chunks:
-                        yield chunk[0]
-                        if chunk[1] is not None:
-                            rag_stat.answer_stat.total_input_tokens += chunk[1].input_tokens_count
-                            rag_stat.answer_stat.total_generated_tokens += chunk[1].generated_tokens_count
-                self._print_rag_stats(rag_stat)
-                return generate_chunks(), context
-
-                new_conversations = conversations[:-1] + [
-                    {
-                        "role": "user",
-                        "content": self._answer_question.prompt(
-                            query=query,
-                            relevant_docs=[doc.source_code for doc in relevant_docs],
-                        ),
-                    }
-                ]
-
-                chunks = target_llm.stream_chat_oai(
-                    conversations=new_conversations,
-                    model=model,
-                    role_mapping=role_mapping,
-                    llm_config=llm_config,
-                    delta_mode=True,
-                )
-
-                def generate_chunks():
+                    conversations=new_conversations,
+                    model=model,
+                    role_mapping=role_mapping,
+                    llm_config=llm_config,
+                    delta_mode=True,
+                )
+
                 for chunk in chunks:
-                    yield chunk[0]
                     if chunk[1] is not None:
                         rag_stat.answer_stat.total_input_tokens += chunk[1].input_tokens_count
-                        rag_stat.answer_stat.total_generated_tokens += chunk[1].generated_tokens_count
-
-
-
-
+                        rag_stat.answer_stat.total_generated_tokens += chunk[1].generated_tokens_count
+                        chunk[1].input_tokens_count = rag_stat.recall_stat.total_input_tokens + \
+                            rag_stat.chunk_stat.total_input_tokens + \
+                            rag_stat.answer_stat.total_input_tokens
+                        chunk[1].generated_tokens_count = rag_stat.recall_stat.total_generated_tokens + \
+                            rag_stat.chunk_stat.total_generated_tokens + \
+                            rag_stat.answer_stat.total_generated_tokens
+                    yield chunk
+
+                self._print_rag_stats(rag_stat)
+            else:
+                new_conversations = conversations[:-1] + [
+                    {
+                        "role": "user",
+                        "content": self._answer_question.prompt(
+                            query=query,
+                            relevant_docs=[
+                                doc.source_code for doc in relevant_docs],
+                        ),
+                    }
+                ]
+
+                chunks = target_llm.stream_chat_oai(
+                    conversations=new_conversations,
+                    model=model,
+                    role_mapping=role_mapping,
+                    llm_config=llm_config,
+                    delta_mode=True,
+                    extra_request_params=extra_request_params
+                )
+
+                for chunk in chunks:
+                    if chunk[1] is not None:
+                        rag_stat.answer_stat.total_input_tokens += chunk[1].input_tokens_count
+                        rag_stat.answer_stat.total_generated_tokens += chunk[1].generated_tokens_count
+                        chunk[1].input_tokens_count = rag_stat.recall_stat.total_input_tokens + \
+                            rag_stat.chunk_stat.total_input_tokens + \
+                            rag_stat.answer_stat.total_input_tokens
+                        chunk[1].generated_tokens_count = rag_stat.recall_stat.total_generated_tokens + \
+                            rag_stat.chunk_stat.total_generated_tokens + \
+                            rag_stat.answer_stat.total_generated_tokens
+                    yield chunk
+
+                self._print_rag_stats(rag_stat)
+
+        return generate_sream(), context
 
     def _print_rag_stats(self, rag_stat: RAGStat) -> None:
         """打印RAG执行的详细统计信息"""
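In both branches the answer loop now overwrites each chunk's per-call counters with running totals across recall, chunking, and answering, so a client watching the stream sees whole-request usage rather than answer-stage usage alone. A small sketch of that accumulate-then-overwrite step (SimpleNamespace stands in for SingleOutputMeta; all numbers are illustrative):

    from types import SimpleNamespace

    recall_in, chunk_in = 3_000, 8_000  # tokens already spent before answering
    answer_in = 0
    metas = [SimpleNamespace(input_tokens_count=5_000),
             SimpleNamespace(input_tokens_count=7_000)]

    for meta in metas:
        answer_in += meta.input_tokens_count                        # accumulate the answer stage
        meta.input_tokens_count = recall_in + chunk_in + answer_in  # report the grand total
        print(meta.input_tokens_count)                              # 16000, then 23000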
@@ -707,19 +859,22 @@ class LongContextRAG:
         )
         total_generated_tokens = (
             rag_stat.recall_stat.total_generated_tokens +
-            rag_stat.chunk_stat.total_generated_tokens +
+            rag_stat.chunk_stat.total_generated_tokens +
             rag_stat.answer_stat.total_generated_tokens
         )
         total_tokens = total_input_tokens + total_generated_tokens
-
+
         # 避免除以零错误
         if total_tokens == 0:
             recall_percent = chunk_percent = answer_percent = 0
         else:
-            recall_percent = (rag_stat.recall_stat.total_input_tokens +
-                              rag_stat.recall_stat.total_generated_tokens) / total_tokens * 100
-            chunk_percent = (rag_stat.chunk_stat.total_input_tokens + rag_stat.chunk_stat.total_generated_tokens) / total_tokens * 100
-            answer_percent = (rag_stat.answer_stat.total_input_tokens + rag_stat.answer_stat.total_generated_tokens) / total_tokens * 100
+            recall_percent = (rag_stat.recall_stat.total_input_tokens +
+                              rag_stat.recall_stat.total_generated_tokens) / total_tokens * 100
+            chunk_percent = (rag_stat.chunk_stat.total_input_tokens +
+                             rag_stat.chunk_stat.total_generated_tokens) / total_tokens * 100
+            answer_percent = (rag_stat.answer_stat.total_input_tokens +
+                              rag_stat.answer_stat.total_generated_tokens) / total_tokens * 100
+
         logger.info(
             f"=== RAG 执行统计信息 ===\n"
             f"总令牌使用: {total_tokens} 令牌\n"
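A worked pass through the stage-share arithmetic above, with illustrative token counts:

    recall_tokens = 3_000 + 150    # input + generated during doc filtering
    chunk_tokens = 8_000 + 400     # during dynamic chunking
    answer_tokens = 12_000 + 900   # during answer generation
    total_tokens = recall_tokens + chunk_tokens + answer_tokens  # 24450

    recall_percent = recall_tokens / total_tokens * 100  # ~12.9
    chunk_percent = chunk_tokens / total_tokens * 100    # ~34.4
    answer_percent = answer_tokens / total_tokens * 100  # ~52.8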
@@ -750,21 +905,22 @@ class LongContextRAG:
             f" - 文档分块: {chunk_percent:.1f}%\n"
             f" - 答案生成: {answer_percent:.1f}%\n"
         )
-
+
         # 记录原始统计数据,以便调试
         logger.debug(f"RAG Stat 原始数据: {rag_stat}")
-
+
         # 返回成本估算
-        estimated_cost = self._estimate_token_cost(total_input_tokens, total_generated_tokens)
+        estimated_cost = self._estimate_token_cost(
+            total_input_tokens, total_generated_tokens)
         if estimated_cost > 0:
             logger.info(f"估计成本: 约 ${estimated_cost:.4f} 人民币")
 
     def _estimate_token_cost(self, input_tokens: int, output_tokens: int) -> float:
-        """估算当前请求的令牌成本(人民币)"""
+        """估算当前请求的令牌成本(人民币)"""
         # 实际应用中,可以根据不同模型设置不同价格
         input_cost_per_1m = 2.0/1000000  # 每百万输入令牌的成本
         output_cost_per_1m = 8.0/100000  # 每百万输出令牌的成本
-
-        cost = (input_tokens * input_cost_per_1m / 1000000) + (output_tokens * output_cost_per_1m/1000000)
+
+        cost = (input_tokens * input_cost_per_1m / 1000000) + \
+            (output_tokens * output_cost_per_1m/1000000)
         return cost
-
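Evaluating _estimate_token_cost as released, with illustrative counts. Note that the two constants use different denominators (1,000,000 vs 100,000) even though both comments say "cost per million tokens" (每百万...令牌的成本), and that the expression divides by 1,000,000 a second time, so the result comes out vanishingly small:

    input_cost_per_1m = 2.0 / 1000000
    output_cost_per_1m = 8.0 / 100000

    input_tokens, output_tokens = 23_000, 1_450  # illustrative
    cost = (input_tokens * input_cost_per_1m / 1000000) + \
        (output_tokens * output_cost_per_1m / 1000000)
    print(f"{cost:.9f}")  # 0.000000162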