auto-coder 0.1.278__py3-none-any.whl → 0.1.280__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of auto-coder has been flagged as potentially problematic.

@@ -23,6 +23,8 @@ from autocoder.rag.relevant_utils import (
     FilterDoc,
     TaskTiming,
     parse_relevance,
+    ProgressUpdate,
+    DocFilterResult
 )
 from autocoder.rag.token_checker import check_token_limit
 from autocoder.rag.token_counter import RemoteTokenCounter, TokenCounter
@@ -33,14 +35,18 @@ from importlib.metadata import version
 from autocoder.rag.stream_event import event_writer
 from autocoder.rag.relevant_utils import DocFilterResult
 from pydantic import BaseModel
+from byzerllm.utils.types import SingleOutputMeta
+from autocoder.rag.lang import get_message_with_format_and_newline

-try:
+try:
     from autocoder_pro.rag.llm_compute import LLMComputeEngine
     pro_version = version("auto-coder-pro")
     autocoder_version = version("auto-coder")
-    logger.warning(f"auto-coder-pro({pro_version}) plugin is enabled in auto-coder.rag({autocoder_version})")
+    logger.warning(
+        f"auto-coder-pro({pro_version}) plugin is enabled in auto-coder.rag({autocoder_version})")
 except ImportError:
-    logger.warning("Please install auto-coder-pro to enhance llm compute ability")
+    logger.warning(
+        "Please install auto-coder-pro to enhance llm compute ability")
     LLMComputeEngine = None


@@ -48,20 +54,26 @@ class RecallStat(BaseModel):
     total_input_tokens: int
     total_generated_tokens: int
     model_name: str = "unknown"
+
+
 class ChunkStat(BaseModel):
     total_input_tokens: int
-    total_generated_tokens: int
+    total_generated_tokens: int
     model_name: str = "unknown"
+
+
 class AnswerStat(BaseModel):
     total_input_tokens: int
     total_generated_tokens: int
     model_name: str = "unknown"

+
 class RAGStat(BaseModel):
     recall_stat: RecallStat
     chunk_stat: ChunkStat
     answer_stat: AnswerStat

+
 class LongContextRAG:
     def __init__(
         self,
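
The stat models touched above are small Pydantic models, one per pipeline stage (recall, chunking, answering). The sketch below is illustrative only and is not code from the package; it shows how RAGStat nests them and how a total can be aggregated, using just the field names visible in the diff.

# Illustrative only: field names mirror the diff, nothing here is package code.
from pydantic import BaseModel


class RecallStat(BaseModel):
    total_input_tokens: int
    total_generated_tokens: int
    model_name: str = "unknown"


class ChunkStat(BaseModel):
    total_input_tokens: int
    total_generated_tokens: int
    model_name: str = "unknown"


class AnswerStat(BaseModel):
    total_input_tokens: int
    total_generated_tokens: int
    model_name: str = "unknown"


class RAGStat(BaseModel):
    recall_stat: RecallStat
    chunk_stat: ChunkStat
    answer_stat: AnswerStat


stat = RAGStat(
    recall_stat=RecallStat(total_input_tokens=0, total_generated_tokens=0),
    chunk_stat=ChunkStat(total_input_tokens=0, total_generated_tokens=0),
    answer_stat=AnswerStat(total_input_tokens=0, total_generated_tokens=0),
)
# Totals are simply summed across the three stages.
total_tokens = (
    stat.recall_stat.total_input_tokens + stat.recall_stat.total_generated_tokens
    + stat.chunk_stat.total_input_tokens + stat.chunk_stat.total_generated_tokens
    + stat.answer_stat.total_input_tokens + stat.answer_stat.total_generated_tokens
)
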
@@ -85,7 +97,7 @@ class LongContextRAG:
         self.chunk_llm = self.llm.get_sub_client("chunk_model")

         self.args = args
-
+
         self.path = path
         self.relevant_score = self.args.rag_doc_filter_relevance or 5

@@ -98,8 +110,10 @@ class LongContextRAG:
             "The sum of full_text_ratio and segment_ratio must be less than or equal to 1.0"
         )

-        self.full_text_limit = int(args.rag_context_window_limit * self.full_text_ratio)
-        self.segment_limit = int(args.rag_context_window_limit * self.segment_ratio)
+        self.full_text_limit = int(
+            args.rag_context_window_limit * self.full_text_ratio)
+        self.segment_limit = int(
+            args.rag_context_window_limit * self.segment_ratio)
         self.buff_limit = int(args.rag_context_window_limit * self.buff_ratio)

         self.tokenizer = None
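
The constructor splits the RAG context window into full-text, segment, and buffer regions by ratio. A quick worked sketch of that arithmetic follows; the values are made up and the buff_ratio derivation is an assumption, since the diff only shows how the resulting limits are used.

# Illustrative arithmetic only (made-up values; buff_ratio as remainder is assumed).
rag_context_window_limit = 120_000
full_text_ratio = 0.7
segment_ratio = 0.2
buff_ratio = 1.0 - full_text_ratio - segment_ratio  # assumed remainder, 0.1

assert full_text_ratio + segment_ratio <= 1.0

full_text_limit = int(rag_context_window_limit * full_text_ratio)  # 84000
segment_limit = int(rag_context_window_limit * segment_ratio)      # 24000
buff_limit = int(rag_context_window_limit * buff_ratio)            # 12000

# The document retriever later reserves full_text_limit - 100 tokens per file,
# so at least one whole file fits in the full-text region.
single_file_token_limit = full_text_limit - 100                     # 83900
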
@@ -108,7 +122,8 @@ class LongContextRAG:

         if self.tokenizer_path:
             VariableHolder.TOKENIZER_PATH = self.tokenizer_path
-            VariableHolder.TOKENIZER_MODEL = Tokenizer.from_file(self.tokenizer_path)
+            VariableHolder.TOKENIZER_MODEL = Tokenizer.from_file(
+                self.tokenizer_path)
             self.tokenizer = TokenCounter(self.tokenizer_path)
         else:
             if llm.is_model_exist("deepseek_tokenizer"):
@@ -160,9 +175,9 @@ class LongContextRAG:
             self.required_exts,
             self.on_ray,
             self.monitor_mode,
-            ## 确保全文区至少能放下一个文件
+            # 确保全文区至少能放下一个文件
             single_file_token_limit=self.full_text_limit - 100,
-            disable_auto_window=self.args.disable_auto_window,
+            disable_auto_window=self.args.disable_auto_window,
             enable_hybrid_index=self.args.enable_hybrid_index,
             extra_params=self.args
         )
@@ -223,14 +238,14 @@ class LongContextRAG:
         {% for msg in conversations %}
         [{{ msg.role }}]:
         {{ msg.content }}
-
+
         {% endfor %}
         </conversations>

         请根据提供的文档内容、用户对话历史以及最后一个问题,提取并总结文档中与问题相关的重要信息。
         如果文档中没有相关信息,请回复"该文档中没有与问题相关的信息"。
         提取的信息尽量保持和原文中的一样,并且只输出这些信息。
-        """
+        """

     @byzerllm.prompt()
     def _answer_question(
@@ -265,20 +280,20 @@ class LongContextRAG:
         """Get the document retriever class based on configuration."""
         # Default to LocalDocumentRetriever if not specified
         return LocalDocumentRetriever
-
+
     def _load_ignore_file(self):
         serveignore_path = os.path.join(self.path, ".serveignore")
         gitignore_path = os.path.join(self.path, ".gitignore")

         if os.path.exists(serveignore_path):
-            with open(serveignore_path, "r",encoding="utf-8") as ignore_file:
+            with open(serveignore_path, "r", encoding="utf-8") as ignore_file:
                 return pathspec.PathSpec.from_lines("gitwildmatch", ignore_file)
         elif os.path.exists(gitignore_path):
-            with open(gitignore_path, "r",encoding="utf-8") as ignore_file:
+            with open(gitignore_path, "r", encoding="utf-8") as ignore_file:
                 return pathspec.PathSpec.from_lines("gitwildmatch", ignore_file)
         return None

-    def _retrieve_documents(self,options:Optional[Dict[str,Any]]=None) -> Generator[SourceCode, None, None]:
+    def _retrieve_documents(self, options: Optional[Dict[str, Any]] = None) -> Generator[SourceCode, None, None]:
         return self.document_retriever.retrieve_documents(options=options)

     def build(self):
@@ -298,7 +313,8 @@ class LongContextRAG:
         only_contexts = True

         logger.info("Search from RAG.....")
-        logger.info(f"Query: {target_query[0:100]}... only_contexts: {only_contexts}")
+        logger.info(
+            f"Query: {target_query[0:100]}... only_contexts: {only_contexts}")

         if self.client:
             new_query = json.dumps(
@@ -314,7 +330,8 @@ class LongContextRAG:
             if not only_contexts:
                 return [SourceCode(module_name=f"RAG:{target_query}", source_code=v)]

-            json_lines = [json.loads(line) for line in v.split("\n") if line.strip()]
+            json_lines = [json.loads(line)
+                          for line in v.split("\n") if line.strip()]
             return [SourceCode.model_validate(json_line) for json_line in json_lines]
         else:
             if only_contexts:
@@ -333,7 +350,7 @@ class LongContextRAG:

     def _filter_docs(self, conversations: List[Dict[str, str]]) -> DocFilterResult:
         query = conversations[-1]["content"]
-        documents = self._retrieve_documents(options={"query":query})
+        documents = self._retrieve_documents(options={"query": query})
         return self.doc_filter.filter_docs(
             conversations=conversations, documents=documents
         )
@@ -344,6 +361,7 @@ class LongContextRAG:
         model: Optional[str] = None,
         role_mapping=None,
         llm_config: Dict[str, Any] = {},
+        extra_request_params: Dict[str, Any] = {}
     ):
         try:
             return self._stream_chat_oai(
@@ -351,18 +369,49 @@ class LongContextRAG:
                 model=model,
                 role_mapping=role_mapping,
                 llm_config=llm_config,
+                extra_request_params=extra_request_params
             )
         except Exception as e:
             logger.error(f"Error in stream_chat_oai: {str(e)}")
             traceback.print_exc()
             return ["出现错误,请稍后再试。"], []

+    def _stream_chatfrom_openai_sdk(self, response):
+        for chunk in response:
+            if hasattr(chunk, "usage") and chunk.usage:
+                input_tokens_count = chunk.usage.prompt_tokens
+                generated_tokens_count = chunk.usage.completion_tokens
+            else:
+                input_tokens_count = 0
+                generated_tokens_count = 0
+
+            if not chunk.choices:
+                if last_meta:
+                    yield ("", SingleOutputMeta(input_tokens_count=input_tokens_count,
+                                                generated_tokens_count=generated_tokens_count,
+                                                reasoning_content="",
+                                                finish_reason=last_meta.finish_reason))
+                continue
+
+            content = chunk.choices[0].delta.content or ""
+
+            reasoning_text = ""
+            if hasattr(chunk.choices[0].delta, "reasoning_content"):
+                reasoning_text = chunk.choices[0].delta.reasoning_content or ""
+
+            last_meta = SingleOutputMeta(input_tokens_count=input_tokens_count,
+                                         generated_tokens_count=generated_tokens_count,
+                                         reasoning_content=reasoning_text,
+                                         finish_reason=chunk.choices[0].finish_reason)
+            yield (content, last_meta)
+
     def _stream_chat_oai(
         self,
         conversations,
         model: Optional[str] = None,
         role_mapping=None,
         llm_config: Dict[str, Any] = {},
+        extra_request_params: Dict[str, Any] = {}
     ):
         if self.client:
             model = model or self.args.model
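
With this hunk, stream_chat_oai accepts extra_request_params and the stream yields (text, SingleOutputMeta) tuples instead of bare strings, with the OpenAI SDK path wrapped by _stream_chatfrom_openai_sdk. The sketch below is an assumption about caller-side usage, not code from the package; rag stands for an already-constructed LongContextRAG instance.

# Hedged caller-side sketch (not package code).
def consume_stream(rag, conversations):
    gen, context = rag.stream_chat_oai(
        conversations=conversations,
        extra_request_params={},  # forwarded as extra_body on the OpenAI SDK path
    )
    answer_parts = []
    for item in gen:
        if isinstance(item, str):
            # a few early-return paths still yield plain strings
            answer_parts.append(item)
            continue
        text, meta = item
        if meta is not None and meta.reasoning_content:
            # progress and reasoning updates arrive with empty answer text
            print(f"[progress] {meta.reasoning_content.strip()}")
        if text:
            answer_parts.append(text)
    return "".join(answer_parts), context
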
@@ -370,130 +419,182 @@ class LongContextRAG:
                 model=model,
                 messages=conversations,
                 stream=True,
-                max_tokens=self.args.rag_params_max_tokens
+                max_tokens=self.args.rag_params_max_tokens,
+                extra_body=extra_request_params
             )
+            return self._stream_chatfrom_openai_sdk(response), []

-            def response_generator():
-                for chunk in response:
-                    if chunk.choices[0].delta.content is not None:
-                        yield chunk.choices[0].delta.content
+        target_llm = self.llm
+        if self.llm.get_sub_client("qa_model"):
+            target_llm = self.llm.get_sub_client("qa_model")

-            return response_generator(), []
-        else:
-
-            target_llm = self.llm
-            if self.llm.get_sub_client("qa_model"):
-                target_llm = self.llm.get_sub_client("qa_model")
+        query = conversations[-1]["content"]
+        context = []

-            query = conversations[-1]["content"]
-            context = []
+        if (
+            "使用四到五个字直接返回这句话的简要主题,不要解释、不要标点、不要语气词、不要多余文本,不要加粗,如果没有主题"
+            in query
+            or "简要总结一下对话内容,用作后续的上下文提示 prompt,控制在 200 字以内"
+            in query
+        ):

-            if (
-                "使用四到五个字直接返回这句话的简要主题,不要解释、不要标点、不要语气词、不要多余文本,不要加粗,如果没有主题"
-                in query
-                or "简要总结一下对话内容,用作后续的上下文提示 prompt,控制在 200 字以内"
-                in query
-            ):
+            chunks = target_llm.stream_chat_oai(
+                conversations=conversations,
+                model=model,
+                role_mapping=role_mapping,
+                llm_config=llm_config,
+                delta_mode=True,
+                extra_request_params=extra_request_params
+            )

-                chunks = target_llm.stream_chat_oai(
-                    conversations=conversations,
-                    model=model,
-                    role_mapping=role_mapping,
-                    llm_config=llm_config,
-                    delta_mode=True,
-                )
-                return (chunk[0] for chunk in chunks), context
-
-            try:
-                request_params = json.loads(query)
-                if "request_id" in request_params:
-                    request_id = request_params["request_id"]
-                    index = request_params["index"]
-
-                    file_path = event_writer.get_event_file_path(request_id)
-                    logger.info(f"Get events for request_id: {request_id} index: {index} file_path: {file_path}")
-                    events = []
-                    if not os.path.exists(file_path):
-                        return [],context
-
-                    with open(file_path, "r") as f:
-                        for line in f:
-                            event = json.loads(line)
-                            if event["index"] >= index:
-                                events.append(event)
-                    return [json.dumps({
-                        "events": [event for event in events],
-                    },ensure_ascii=False)], context
-            except json.JSONDecodeError:
-                pass
-
-            if self.args.without_contexts and LLMComputeEngine is not None:
-                llm_compute_engine = LLMComputeEngine(
-                    llm=target_llm,
-                    inference_enhance=not self.args.disable_inference_enhance,
-                    inference_deep_thought=self.args.inference_deep_thought,
-                    inference_slow_without_deep_thought=self.args.inference_slow_without_deep_thought,
-                    precision=self.args.inference_compute_precision,
-                    data_cells_max_num=self.args.data_cells_max_num,
-                )
-                conversations = conversations[:-1]
-                new_conversations = llm_compute_engine.process_conversation(
-                    conversations, query, []
-                )
+            def generate_chunks():
+                for chunk in chunks:
+                    yield chunk
+            return generate_chunks(), context

-                return (
-                    llm_compute_engine.stream_chat_oai(
-                        conversations=new_conversations,
-                        model=model,
-                        role_mapping=role_mapping,
-                        llm_config=llm_config,
-                        delta_mode=True,
-                    ),
-                    context,
-                )
+        try:
+            request_params = json.loads(query)
+            if "request_id" in request_params:
+                request_id = request_params["request_id"]
+                index = request_params["index"]

+                file_path = event_writer.get_event_file_path(request_id)
+                logger.info(
+                    f"Get events for request_id: {request_id} index: {index} file_path: {file_path}")
+                events = []
+                if not os.path.exists(file_path):
+                    return [], context
+
+                with open(file_path, "r") as f:
+                    for line in f:
+                        event = json.loads(line)
+                        if event["index"] >= index:
+                            events.append(event)
+                return [json.dumps({
+                    "events": [event for event in events],
+                }, ensure_ascii=False)], context
+        except json.JSONDecodeError:
+            pass
+
+        if self.args.without_contexts and LLMComputeEngine is not None:
+            llm_compute_engine = LLMComputeEngine(
+                llm=target_llm,
+                inference_enhance=not self.args.disable_inference_enhance,
+                inference_deep_thought=self.args.inference_deep_thought,
+                inference_slow_without_deep_thought=self.args.inference_slow_without_deep_thought,
+                precision=self.args.inference_compute_precision,
+                data_cells_max_num=self.args.data_cells_max_num,
+            )
+            conversations = conversations[:-1]
+            new_conversations = llm_compute_engine.process_conversation(
+                conversations, query, []
+            )
+            chunks = llm_compute_engine.stream_chat_oai(
+                conversations=new_conversations,
+                model=model,
+                role_mapping=role_mapping,
+                llm_config=llm_config,
+                delta_mode=True,
+                extra_request_params=extra_request_params
+            )
+
+            def generate_chunks():
+                for chunk in chunks:
+                    yield chunk

-            only_contexts = False
-            try:
-                v = json.loads(query)
-                if "only_contexts" in v:
-                    query = v["query"]
-                    only_contexts = v["only_contexts"]
-                    conversations[-1]["content"] = query
-            except json.JSONDecodeError:
-                pass
-
-            logger.info(f"Query: {query} only_contexts: {only_contexts}")
-            start_time = time.time()
-
-
-            rag_stat = RAGStat(
-                recall_stat=RecallStat(
-                    total_input_tokens=0,
-                    total_generated_tokens=0,
-                    model_name=self.recall_llm.default_model_name,
-                ),
-                chunk_stat=ChunkStat(
-                    total_input_tokens=0,
-                    total_generated_tokens=0,
-                    model_name=self.chunk_llm.default_model_name,
-                ),
-                answer_stat=AnswerStat(
-                    total_input_tokens=0,
-                    total_generated_tokens=0,
-                    model_name=self.qa_llm.default_model_name,
-                ),
+            return (
+                generate_chunks(),
+                context,
             )

-            doc_filter_result = self._filter_docs(conversations)
+        only_contexts = False
+        try:
+            v = json.loads(query)
+            if "only_contexts" in v:
+                query = v["query"]
+                only_contexts = v["only_contexts"]
+                conversations[-1]["content"] = query
+        except json.JSONDecodeError:
+            pass
+
+        logger.info(f"Query: {query} only_contexts: {only_contexts}")
+        start_time = time.time()
+
+        rag_stat = RAGStat(
+            recall_stat=RecallStat(
+                total_input_tokens=0,
+                total_generated_tokens=0,
+                model_name=self.recall_llm.default_model_name,
+            ),
+            chunk_stat=ChunkStat(
+                total_input_tokens=0,
+                total_generated_tokens=0,
+                model_name=self.chunk_llm.default_model_name,
+            ),
+            answer_stat=AnswerStat(
+                total_input_tokens=0,
+                total_generated_tokens=0,
+                model_name=self.qa_llm.default_model_name,
+            ),
+        )

-            rag_stat.recall_stat.total_input_tokens += sum(doc_filter_result.input_tokens_counts)
-            rag_stat.recall_stat.total_generated_tokens += sum(doc_filter_result.generated_tokens_counts)
+        context = []
+
+        def generate_sream():
+            nonlocal context
+
+            yield ("", SingleOutputMeta(input_tokens_count=0,
+                                        generated_tokens_count=0,
+                                        reasoning_content=get_message_with_format_and_newline(
+                                            "rag_searching_docs",
+                                            model=rag_stat.recall_stat.model_name
+                                        )
+                                        ))
+
+            doc_filter_result = DocFilterResult(
+                docs=[],
+                raw_docs=[],
+                input_tokens_counts=[],
+                generated_tokens_counts=[],
+                durations=[],
+                model_name=rag_stat.recall_stat.model_name
+            )
+            query = conversations[-1]["content"]
+            documents = self._retrieve_documents(options={"query": query})
+
+            # 使用带进度报告的过滤方法
+            for progress_update, result in self.doc_filter.filter_docs_with_progress(conversations, documents):
+                if result is not None:
+                    doc_filter_result = result
+                else:
+                    # 生成进度更新
+                    yield ("", SingleOutputMeta(
+                        input_tokens_count=rag_stat.recall_stat.total_input_tokens,
+                        generated_tokens_count=rag_stat.recall_stat.total_generated_tokens,
+                        reasoning_content=f"{progress_update.message} ({progress_update.completed}/{progress_update.total})"
+                    ))
+
+            rag_stat.recall_stat.total_input_tokens += sum(
+                doc_filter_result.input_tokens_counts)
+            rag_stat.recall_stat.total_generated_tokens += sum(
+                doc_filter_result.generated_tokens_counts)
             rag_stat.recall_stat.model_name = doc_filter_result.model_name

             relevant_docs: List[FilterDoc] = doc_filter_result.docs
             filter_time = time.time() - start_time

+            yield ("", SingleOutputMeta(input_tokens_count=rag_stat.recall_stat.total_input_tokens,
+                                        generated_tokens_count=rag_stat.recall_stat.total_generated_tokens,
+                                        reasoning_content=get_message_with_format_and_newline(
+                                            "rag_docs_filter_result",
+                                            filter_time=filter_time,
+                                            docs_num=len(relevant_docs),
+                                            input_tokens=rag_stat.recall_stat.total_input_tokens,
+                                            output_tokens=rag_stat.recall_stat.total_generated_tokens,
+                                            model=rag_stat.recall_stat.model_name
+                                        )
+                                        ))
+
             # Filter relevant_docs to only include those with is_relevant=True
             highly_relevant_docs = [
                 doc for doc in relevant_docs if doc.relevance.is_relevant
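
The rewritten body above wraps the whole pipeline in a generate_sream() generator and switches from filter_docs() to filter_docs_with_progress(), which yields (ProgressUpdate, DocFilterResult-or-None) pairs. The stand-in below is not the package's implementation; it only illustrates that protocol, using the message/completed/total fields shown in the diff.

# Illustrative stand-in for the progress protocol (not package code).
from dataclasses import dataclass
from typing import Iterator, List, Optional, Tuple


@dataclass
class FakeProgress:  # mirrors the ProgressUpdate fields used in the diff
    message: str
    completed: int
    total: int


def fake_filter_with_progress(docs: List[str]) -> Iterator[Tuple[FakeProgress, Optional[List[str]]]]:
    kept: List[str] = []
    for i, doc in enumerate(docs, start=1):
        if "rag" in doc:                      # placeholder relevance check
            kept.append(doc)
        yield FakeProgress("filtering docs", i, len(docs)), None   # progress only
    yield FakeProgress("done", len(docs), len(docs)), kept         # final result


for progress, result in fake_filter_with_progress(["rag intro", "changelog"]):
    if result is None:
        print(f"{progress.message} ({progress.completed}/{progress.total})")
    else:
        print("kept:", result)
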
@@ -501,7 +602,8 @@ class LongContextRAG:

             if highly_relevant_docs:
                 relevant_docs = highly_relevant_docs
-                logger.info(f"Found {len(relevant_docs)} highly relevant documents")
+                logger.info(
+                    f"Found {len(relevant_docs)} highly relevant documents")

             logger.info(
                 f"Filter time: {filter_time:.2f} seconds with {len(relevant_docs)} docs"
@@ -511,7 +613,7 @@ class LongContextRAG:
                 final_docs = []
                 for doc in relevant_docs:
                     final_docs.append(doc.model_dump())
-                return [json.dumps(final_docs,ensure_ascii=False)], []
+                return [json.dumps(final_docs, ensure_ascii=False)], []

             if not relevant_docs:
                 return ["没有找到相关的文档来回答这个问题。"], []
@@ -546,6 +648,12 @@ class LongContextRAG:
                 + "".join([f"\n * {info}" for info in relevant_docs_info])
             )

+            yield ("", SingleOutputMeta(generated_tokens_count=0,
+                                        reasoning_content=get_message_with_format_and_newline(
+                                            "dynamic_chunking_start",
+                                            model=rag_stat.chunk_stat.model_name
+                                        )
+                                        ))
             first_round_full_docs = []
             second_round_extracted_docs = []
             sencond_round_time = 0
@@ -560,17 +668,19 @@ class LongContextRAG:
                     llm=self.llm,
                     disable_segment_reorder=self.args.disable_segment_reorder,
                 )
-
+
                 token_limiter_result = token_limiter.limit_tokens(
                     relevant_docs=relevant_docs,
                     conversations=conversations,
                     index_filter_workers=self.args.index_filter_workers or 5,
                 )

-                rag_stat.chunk_stat.total_input_tokens += sum(token_limiter_result.input_tokens_counts)
-                rag_stat.chunk_stat.total_generated_tokens += sum(token_limiter_result.generated_tokens_counts)
+                rag_stat.chunk_stat.total_input_tokens += sum(
+                    token_limiter_result.input_tokens_counts)
+                rag_stat.chunk_stat.total_generated_tokens += sum(
+                    token_limiter_result.generated_tokens_counts)
                 rag_stat.chunk_stat.model_name = token_limiter_result.model_name
-
+
                 final_relevant_docs = token_limiter_result.docs
                 first_round_full_docs = token_limiter.first_round_full_docs
                 second_round_extracted_docs = token_limiter.second_round_extracted_docs
@@ -581,24 +691,41 @@ class LongContextRAG:
                 relevant_docs = relevant_docs[: self.args.index_filter_file_num]

             logger.info(f"Finally send to model: {len(relevant_docs)}")
-
             # 记录分段处理的统计信息
             logger.info(
                 f"=== Token Management ===\n"
                 f" * Only contexts: {only_contexts}\n"
-                f" * Filter time: {filter_time:.2f} seconds\n"
+                f" * Filter time: {filter_time:.2f} seconds\n"
                 f" * Final relevant docs: {len(relevant_docs)}\n"
                 f" * First round full docs: {len(first_round_full_docs)}\n"
                 f" * Second round extracted docs: {len(second_round_extracted_docs)}\n"
                 f" * Second round time: {sencond_round_time:.2f} seconds"
             )

+            yield ("", SingleOutputMeta(generated_tokens_count=rag_stat.chunk_stat.total_generated_tokens + rag_stat.recall_stat.total_generated_tokens,
+                                        input_tokens_count=rag_stat.chunk_stat.total_input_tokens +
+                                        rag_stat.recall_stat.total_input_tokens,
+                                        reasoning_content=get_message_with_format_and_newline(
+                                            "dynamic_chunking_result",
+                                            model=rag_stat.chunk_stat.model_name,
+                                            docs_num=len(relevant_docs),
+                                            filter_time=filter_time,
+                                            sencond_round_time=sencond_round_time,
+                                            first_round_full_docs=len(
+                                                first_round_full_docs),
+                                            second_round_extracted_docs=len(
+                                                second_round_extracted_docs),
+                                            input_tokens=rag_stat.chunk_stat.total_input_tokens,
+                                            output_tokens=rag_stat.chunk_stat.total_generated_tokens
+                                        )
+                                        ))
+
             # 记录最终选择的文档详情
             final_relevant_docs_info = []
             for i, doc in enumerate(relevant_docs):
                 doc_path = doc.module_name.replace(self.path, '', 1)
                 info = f"{i+1}. {doc_path}"
-
+
                 metadata_info = []
                 if "original_docs" in doc.metadata:
                     original_docs = ", ".join(
@@ -608,26 +735,27 @@ class LongContextRAG:
                         ]
                     )
                     metadata_info.append(f"Original docs: {original_docs}")
-
+
                 if "chunk_ranges" in doc.metadata:
                     chunk_ranges = json.dumps(
                         doc.metadata["chunk_ranges"], ensure_ascii=False
                     )
                     metadata_info.append(f"Chunk ranges: {chunk_ranges}")
-
+
                 if "processing_time" in doc.metadata:
-                    metadata_info.append(f"Processing time: {doc.metadata['processing_time']:.2f}s")
-
+                    metadata_info.append(
+                        f"Processing time: {doc.metadata['processing_time']:.2f}s")
+
                 if metadata_info:
                     info += f" ({'; '.join(metadata_info)})"
-
+
                 final_relevant_docs_info.append(info)

             if final_relevant_docs_info:
                 logger.info(
                     f"Final documents to be sent to model:"
                     + "".join([f"\n * {info}" for info in final_relevant_docs_info])
-                )
+                )

             # 记录令牌统计
             request_tokens = sum([doc.tokens for doc in relevant_docs])
@@ -638,7 +766,18 @@ class LongContextRAG:
                 f" * Total tokens: {request_tokens}"
             )

-            logger.info(f"Start to send to model {target_model} with {request_tokens} tokens")
+            logger.info(
+                f"Start to send to model {target_model} with {request_tokens} tokens")
+
+            yield ("", SingleOutputMeta(input_tokens_count=rag_stat.recall_stat.total_input_tokens + rag_stat.chunk_stat.total_input_tokens,
+                                        generated_tokens_count=rag_stat.recall_stat.total_generated_tokens +
+                                        rag_stat.chunk_stat.total_generated_tokens,
+                                        reasoning_content=get_message_with_format_and_newline(
+                                            "send_to_model",
+                                            model=target_model,
+                                            tokens=request_tokens
+                                        )
+                                        ))

             if LLMComputeEngine is not None and not self.args.disable_inference_enhance:
                 llm_compute_engine = LLMComputeEngine(
@@ -650,53 +789,66 @@ class LongContextRAG:
                     debug=False,
                 )
                 new_conversations = llm_compute_engine.process_conversation(
-                    conversations, query, [doc.source_code for doc in relevant_docs]
+                    conversations, query, [
+                        doc.source_code for doc in relevant_docs]
                 )
                 chunks = llm_compute_engine.stream_chat_oai(
-                    conversations=new_conversations,
-                    model=model,
-                    role_mapping=role_mapping,
-                    llm_config=llm_config,
-                    delta_mode=True,
-                )
-
-                def generate_chunks():
-                    for chunk in chunks:
-                        yield chunk[0]
-                        if chunk[1] is not None:
-                            rag_stat.answer_stat.total_input_tokens += chunk[1].input_tokens_count
-                            rag_stat.answer_stat.total_generated_tokens += chunk[1].generated_tokens_count
-                    self._print_rag_stats(rag_stat)
-                return generate_chunks(), context
-
-            new_conversations = conversations[:-1] + [
-                {
-                    "role": "user",
-                    "content": self._answer_question.prompt(
-                        query=query,
-                        relevant_docs=[doc.source_code for doc in relevant_docs],
-                    ),
-                }
-            ]
-
-            chunks = target_llm.stream_chat_oai(
-                conversations=new_conversations,
-                model=model,
-                role_mapping=role_mapping,
-                llm_config=llm_config,
-                delta_mode=True,
-            )
-
-            def generate_chunks():
+                    conversations=new_conversations,
+                    model=model,
+                    role_mapping=role_mapping,
+                    llm_config=llm_config,
+                    delta_mode=True,
+                )
+
                 for chunk in chunks:
-                    yield chunk[0]
                     if chunk[1] is not None:
                         rag_stat.answer_stat.total_input_tokens += chunk[1].input_tokens_count
-                        rag_stat.answer_stat.total_generated_tokens += chunk[1].generated_tokens_count
-                self._print_rag_stats(rag_stat)
-            return generate_chunks(), context
-
-
+                        rag_stat.answer_stat.total_generated_tokens += chunk[1].generated_tokens_count
+                        chunk[1].input_tokens_count = rag_stat.recall_stat.total_input_tokens + \
+                            rag_stat.chunk_stat.total_input_tokens + \
+                            rag_stat.answer_stat.total_input_tokens
+                        chunk[1].generated_tokens_count = rag_stat.recall_stat.total_generated_tokens + \
+                            rag_stat.chunk_stat.total_generated_tokens + \
+                            rag_stat.answer_stat.total_generated_tokens
+                    yield chunk
+
+                self._print_rag_stats(rag_stat)
+            else:
+                new_conversations = conversations[:-1] + [
+                    {
+                        "role": "user",
+                        "content": self._answer_question.prompt(
+                            query=query,
+                            relevant_docs=[
+                                doc.source_code for doc in relevant_docs],
+                        ),
+                    }
+                ]
+
+                chunks = target_llm.stream_chat_oai(
+                    conversations=new_conversations,
+                    model=model,
+                    role_mapping=role_mapping,
+                    llm_config=llm_config,
+                    delta_mode=True,
+                    extra_request_params=extra_request_params
+                )
+
+                for chunk in chunks:
+                    if chunk[1] is not None:
+                        rag_stat.answer_stat.total_input_tokens += chunk[1].input_tokens_count
+                        rag_stat.answer_stat.total_generated_tokens += chunk[1].generated_tokens_count
+                        chunk[1].input_tokens_count = rag_stat.recall_stat.total_input_tokens + \
+                            rag_stat.chunk_stat.total_input_tokens + \
+                            rag_stat.answer_stat.total_input_tokens
+                        chunk[1].generated_tokens_count = rag_stat.recall_stat.total_generated_tokens + \
+                            rag_stat.chunk_stat.total_generated_tokens + \
+                            rag_stat.answer_stat.total_generated_tokens
+                    yield chunk
+
+                self._print_rag_stats(rag_stat)
+
+        return generate_sream(), context

     def _print_rag_stats(self, rag_stat: RAGStat) -> None:
         """打印RAG执行的详细统计信息"""
@@ -707,19 +859,22 @@ class LongContextRAG:
         )
         total_generated_tokens = (
             rag_stat.recall_stat.total_generated_tokens +
-            rag_stat.chunk_stat.total_generated_tokens +
+            rag_stat.chunk_stat.total_generated_tokens +
            rag_stat.answer_stat.total_generated_tokens
         )
         total_tokens = total_input_tokens + total_generated_tokens
-
+
         # 避免除以零错误
         if total_tokens == 0:
             recall_percent = chunk_percent = answer_percent = 0
         else:
-            recall_percent = (rag_stat.recall_stat.total_input_tokens + rag_stat.recall_stat.total_generated_tokens) / total_tokens * 100
-            chunk_percent = (rag_stat.chunk_stat.total_input_tokens + rag_stat.chunk_stat.total_generated_tokens) / total_tokens * 100
-            answer_percent = (rag_stat.answer_stat.total_input_tokens + rag_stat.answer_stat.total_generated_tokens) / total_tokens * 100
-
+            recall_percent = (rag_stat.recall_stat.total_input_tokens +
+                              rag_stat.recall_stat.total_generated_tokens) / total_tokens * 100
+            chunk_percent = (rag_stat.chunk_stat.total_input_tokens +
+                             rag_stat.chunk_stat.total_generated_tokens) / total_tokens * 100
+            answer_percent = (rag_stat.answer_stat.total_input_tokens +
+                              rag_stat.answer_stat.total_generated_tokens) / total_tokens * 100
+
         logger.info(
             f"=== RAG 执行统计信息 ===\n"
             f"总令牌使用: {total_tokens} 令牌\n"
@@ -750,21 +905,22 @@ class LongContextRAG:
             f" - 文档分块: {chunk_percent:.1f}%\n"
             f" - 答案生成: {answer_percent:.1f}%\n"
         )
-
+
         # 记录原始统计数据,以便调试
         logger.debug(f"RAG Stat 原始数据: {rag_stat}")
-
+
         # 返回成本估算
-        estimated_cost = self._estimate_token_cost(total_input_tokens, total_generated_tokens)
+        estimated_cost = self._estimate_token_cost(
+            total_input_tokens, total_generated_tokens)
         if estimated_cost > 0:
             logger.info(f"估计成本: 约 ${estimated_cost:.4f} 人民币")

     def _estimate_token_cost(self, input_tokens: int, output_tokens: int) -> float:
-        """估算当前请求的令牌成本(人民币)"""
+        """估算当前请求的令牌成本(人民币)"""
         # 实际应用中,可以根据不同模型设置不同价格
         input_cost_per_1m = 2.0/1000000  # 每百万输入令牌的成本
         output_cost_per_1m = 8.0/100000  # 每百万输出令牌的成本
-
-        cost = (input_tokens * input_cost_per_1m / 1000000) + (output_tokens* output_cost_per_1m/1000000)
+
+        cost = (input_tokens * input_cost_per_1m / 1000000) + \
+            (output_tokens * output_cost_per_1m/1000000)
         return cost
-