auto-coder 0.1.279-py3-none-any.whl → 0.1.281-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of auto-coder might be problematic.

@@ -23,6 +23,8 @@ from autocoder.rag.relevant_utils import (
  FilterDoc,
  TaskTiming,
  parse_relevance,
+ ProgressUpdate,
+ DocFilterResult
  )
  from autocoder.rag.token_checker import check_token_limit
  from autocoder.rag.token_counter import RemoteTokenCounter, TokenCounter
@@ -34,14 +36,17 @@ from autocoder.rag.stream_event import event_writer
  from autocoder.rag.relevant_utils import DocFilterResult
  from pydantic import BaseModel
  from byzerllm.utils.types import SingleOutputMeta
+ from autocoder.rag.lang import get_message_with_format_and_newline

- try:
+ try:
  from autocoder_pro.rag.llm_compute import LLMComputeEngine
  pro_version = version("auto-coder-pro")
  autocoder_version = version("auto-coder")
- logger.warning(f"auto-coder-pro({pro_version}) plugin is enabled in auto-coder.rag({autocoder_version})")
+ logger.warning(
+ f"auto-coder-pro({pro_version}) plugin is enabled in auto-coder.rag({autocoder_version})")
  except ImportError:
- logger.warning("Please install auto-coder-pro to enhance llm compute ability")
+ logger.warning(
+ "Please install auto-coder-pro to enhance llm compute ability")
  LLMComputeEngine = None


@@ -49,20 +54,26 @@ class RecallStat(BaseModel):
  total_input_tokens: int
  total_generated_tokens: int
  model_name: str = "unknown"
+
+
  class ChunkStat(BaseModel):
  total_input_tokens: int
- total_generated_tokens: int
+ total_generated_tokens: int
  model_name: str = "unknown"
+
+
  class AnswerStat(BaseModel):
  total_input_tokens: int
  total_generated_tokens: int
  model_name: str = "unknown"

+
  class RAGStat(BaseModel):
  recall_stat: RecallStat
  chunk_stat: ChunkStat
  answer_stat: AnswerStat

+
  class LongContextRAG:
  def __init__(
  self,
@@ -86,7 +97,7 @@ class LongContextRAG:
  self.chunk_llm = self.llm.get_sub_client("chunk_model")

  self.args = args
-
+
  self.path = path
  self.relevant_score = self.args.rag_doc_filter_relevance or 5

@@ -99,8 +110,10 @@ class LongContextRAG:
  "The sum of full_text_ratio and segment_ratio must be less than or equal to 1.0"
  )

- self.full_text_limit = int(args.rag_context_window_limit * self.full_text_ratio)
- self.segment_limit = int(args.rag_context_window_limit * self.segment_ratio)
+ self.full_text_limit = int(
+ args.rag_context_window_limit * self.full_text_ratio)
+ self.segment_limit = int(
+ args.rag_context_window_limit * self.segment_ratio)
  self.buff_limit = int(args.rag_context_window_limit * self.buff_ratio)

  self.tokenizer = None
@@ -109,7 +122,8 @@ class LongContextRAG:

  if self.tokenizer_path:
  VariableHolder.TOKENIZER_PATH = self.tokenizer_path
- VariableHolder.TOKENIZER_MODEL = Tokenizer.from_file(self.tokenizer_path)
+ VariableHolder.TOKENIZER_MODEL = Tokenizer.from_file(
+ self.tokenizer_path)
  self.tokenizer = TokenCounter(self.tokenizer_path)
  else:
  if llm.is_model_exist("deepseek_tokenizer"):
@@ -161,9 +175,9 @@ class LongContextRAG:
  self.required_exts,
  self.on_ray,
  self.monitor_mode,
- ## 确保全文区至少能放下一个文件
+ # 确保全文区至少能放下一个文件
  single_file_token_limit=self.full_text_limit - 100,
- disable_auto_window=self.args.disable_auto_window,
+ disable_auto_window=self.args.disable_auto_window,
  enable_hybrid_index=self.args.enable_hybrid_index,
  extra_params=self.args
  )
@@ -224,14 +238,14 @@ class LongContextRAG:
  {% for msg in conversations %}
  [{{ msg.role }}]:
  {{ msg.content }}
-
+
  {% endfor %}
  </conversations>

  请根据提供的文档内容、用户对话历史以及最后一个问题,提取并总结文档中与问题相关的重要信息。
  如果文档中没有相关信息,请回复"该文档中没有与问题相关的信息"。
  提取的信息尽量保持和原文中的一样,并且只输出这些信息。
- """
+ """

  @byzerllm.prompt()
  def _answer_question(
@@ -266,26 +280,25 @@ class LongContextRAG:
  """Get the document retriever class based on configuration."""
  # Default to LocalDocumentRetriever if not specified
  return LocalDocumentRetriever
-
+
  def _load_ignore_file(self):
  serveignore_path = os.path.join(self.path, ".serveignore")
  gitignore_path = os.path.join(self.path, ".gitignore")

  if os.path.exists(serveignore_path):
- with open(serveignore_path, "r",encoding="utf-8") as ignore_file:
+ with open(serveignore_path, "r", encoding="utf-8") as ignore_file:
  return pathspec.PathSpec.from_lines("gitwildmatch", ignore_file)
  elif os.path.exists(gitignore_path):
- with open(gitignore_path, "r",encoding="utf-8") as ignore_file:
+ with open(gitignore_path, "r", encoding="utf-8") as ignore_file:
  return pathspec.PathSpec.from_lines("gitwildmatch", ignore_file)
  return None

- def _retrieve_documents(self,options:Optional[Dict[str,Any]]=None) -> Generator[SourceCode, None, None]:
+ def _retrieve_documents(self, options: Optional[Dict[str, Any]] = None) -> Generator[SourceCode, None, None]:
  return self.document_retriever.retrieve_documents(options=options)

  def build(self):
  pass

-
  def search(self, query: str) -> List[SourceCode]:
  target_query = query
  only_contexts = False
@@ -300,7 +313,8 @@ class LongContextRAG:
  only_contexts = True

  logger.info("Search from RAG.....")
- logger.info(f"Query: {target_query[0:100]}... only_contexts: {only_contexts}")
+ logger.info(
+ f"Query: {target_query[0:100]}... only_contexts: {only_contexts}")

  if self.client:
  new_query = json.dumps(
@@ -316,7 +330,8 @@ class LongContextRAG:
  if not only_contexts:
  return [SourceCode(module_name=f"RAG:{target_query}", source_code=v)]

- json_lines = [json.loads(line) for line in v.split("\n") if line.strip()]
+ json_lines = [json.loads(line)
+ for line in v.split("\n") if line.strip()]
  return [SourceCode.model_validate(json_line) for json_line in json_lines]
  else:
  if only_contexts:
@@ -335,7 +350,7 @@ class LongContextRAG:

  def _filter_docs(self, conversations: List[Dict[str, str]]) -> DocFilterResult:
  query = conversations[-1]["content"]
- documents = self._retrieve_documents(options={"query":query})
+ documents = self._retrieve_documents(options={"query": query})
  return self.doc_filter.filter_docs(
  conversations=conversations, documents=documents
  )
@@ -360,9 +375,8 @@ class LongContextRAG:
  logger.error(f"Error in stream_chat_oai: {str(e)}")
  traceback.print_exc()
  return ["出现错误,请稍后再试。"], []
-

- def _stream_chatfrom_openai_sdk(self,response):
+ def _stream_chatfrom_openai_sdk(self, response):
  for chunk in response:
  if hasattr(chunk, "usage") and chunk.usage:
  input_tokens_count = chunk.usage.prompt_tokens
@@ -386,9 +400,9 @@ class LongContextRAG:
  reasoning_text = chunk.choices[0].delta.reasoning_content or ""

  last_meta = SingleOutputMeta(input_tokens_count=input_tokens_count,
- generated_tokens_count=generated_tokens_count,
- reasoning_content=reasoning_text,
- finish_reason=chunk.choices[0].finish_reason)
+ generated_tokens_count=generated_tokens_count,
+ reasoning_content=reasoning_text,
+ finish_reason=chunk.choices[0].finish_reason)
  yield (content, last_meta)

  def _stream_chat_oai(
@@ -398,7 +412,7 @@ class LongContextRAG:
  role_mapping=None,
  llm_config: Dict[str, Any] = {},
  extra_request_params: Dict[str, Any] = {}
- ):
+ ):
  if self.client:
  model = model or self.args.model
  response = self.client.chat.completions.create(
@@ -407,8 +421,8 @@ class LongContextRAG:
  stream=True,
  max_tokens=self.args.rag_params_max_tokens,
  extra_body=extra_request_params
- )
- return self._stream_chatfrom_openai_sdk(response), []
+ )
+ return self._stream_chatfrom_openai_sdk(response), []

  target_llm = self.llm
  if self.llm.get_sub_client("qa_model"):
@@ -422,7 +436,7 @@ class LongContextRAG:
  in query
  or "简要总结一下对话内容,用作后续的上下文提示 prompt,控制在 200 字以内"
  in query
- ):
+ ):

  chunks = target_llm.stream_chat_oai(
  conversations=conversations,
@@ -432,22 +446,24 @@ class LongContextRAG:
  delta_mode=True,
  extra_request_params=extra_request_params
  )
+
  def generate_chunks():
  for chunk in chunks:
  yield chunk
  return generate_chunks(), context
-
- try:
+
+ try:
  request_params = json.loads(query)
- if "request_id" in request_params:
+ if "request_id" in request_params:
  request_id = request_params["request_id"]
  index = request_params["index"]
-
- file_path = event_writer.get_event_file_path(request_id)
- logger.info(f"Get events for request_id: {request_id} index: {index} file_path: {file_path}")
+
+ file_path = event_writer.get_event_file_path(request_id)
+ logger.info(
+ f"Get events for request_id: {request_id} index: {index} file_path: {file_path}")
  events = []
  if not os.path.exists(file_path):
- return [],context
+ return [], context

  with open(file_path, "r") as f:
  for line in f:
@@ -455,8 +471,8 @@ class LongContextRAG:
  if event["index"] >= index:
  events.append(event)
  return [json.dumps({
- "events": [event for event in events],
- },ensure_ascii=False)], context
+ "events": [event for event in events],
+ }, ensure_ascii=False)], context
  except json.JSONDecodeError:
  pass

@@ -465,7 +481,7 @@ class LongContextRAG:
  llm=target_llm,
  inference_enhance=not self.args.disable_inference_enhance,
  inference_deep_thought=self.args.inference_deep_thought,
- inference_slow_without_deep_thought=self.args.inference_slow_without_deep_thought,
+ inference_slow_without_deep_thought=self.args.inference_slow_without_deep_thought,
  precision=self.args.inference_compute_precision,
  data_cells_max_num=self.args.data_cells_max_num,
  )
@@ -474,14 +490,14 @@ class LongContextRAG:
  conversations, query, []
  )
  chunks = llm_compute_engine.stream_chat_oai(
- conversations=new_conversations,
- model=model,
- role_mapping=role_mapping,
- llm_config=llm_config,
- delta_mode=True,
- extra_request_params=extra_request_params
- )
-
+ conversations=new_conversations,
+ model=model,
+ role_mapping=role_mapping,
+ llm_config=llm_config,
+ delta_mode=True,
+ extra_request_params=extra_request_params
+ )
+
  def generate_chunks():
  for chunk in chunks:
  yield chunk
@@ -491,7 +507,6 @@ class LongContextRAG:
  context,
  )

-
  only_contexts = False
  try:
  v = json.loads(query)
@@ -504,7 +519,6 @@ class LongContextRAG:

  logger.info(f"Query: {query} only_contexts: {only_contexts}")
  start_time = time.time()
-

  rag_stat = RAGStat(
  recall_stat=RecallStat(
@@ -525,17 +539,62 @@ class LongContextRAG:
  )

  context = []
+
  def generate_sream():
  nonlocal context
- doc_filter_result = self._filter_docs(conversations)

- rag_stat.recall_stat.total_input_tokens += sum(doc_filter_result.input_tokens_counts)
- rag_stat.recall_stat.total_generated_tokens += sum(doc_filter_result.generated_tokens_counts)
+ yield ("", SingleOutputMeta(input_tokens_count=0,
+ generated_tokens_count=0,
+ reasoning_content=get_message_with_format_and_newline(
+ "rag_searching_docs",
+ model=rag_stat.recall_stat.model_name
+ )
+ ))
+
+ doc_filter_result = DocFilterResult(
+ docs=[],
+ raw_docs=[],
+ input_tokens_counts=[],
+ generated_tokens_counts=[],
+ durations=[],
+ model_name=rag_stat.recall_stat.model_name
+ )
+ query = conversations[-1]["content"]
+ documents = self._retrieve_documents(options={"query": query})
+
+ # 使用带进度报告的过滤方法
+ for progress_update, result in self.doc_filter.filter_docs_with_progress(conversations, documents):
+ if result is not None:
+ doc_filter_result = result
+ else:
+ # 生成进度更新
+ yield ("", SingleOutputMeta(
+ input_tokens_count=rag_stat.recall_stat.total_input_tokens,
+ generated_tokens_count=rag_stat.recall_stat.total_generated_tokens,
+ reasoning_content=f"{progress_update.message} ({progress_update.completed}/{progress_update.total})"
+ ))
+
+ rag_stat.recall_stat.total_input_tokens += sum(
+ doc_filter_result.input_tokens_counts)
+ rag_stat.recall_stat.total_generated_tokens += sum(
+ doc_filter_result.generated_tokens_counts)
  rag_stat.recall_stat.model_name = doc_filter_result.model_name

  relevant_docs: List[FilterDoc] = doc_filter_result.docs
  filter_time = time.time() - start_time

+ yield ("", SingleOutputMeta(input_tokens_count=rag_stat.recall_stat.total_input_tokens,
+ generated_tokens_count=rag_stat.recall_stat.total_generated_tokens,
+ reasoning_content=get_message_with_format_and_newline(
+ "rag_docs_filter_result",
+ filter_time=filter_time,
+ docs_num=len(relevant_docs),
+ input_tokens=rag_stat.recall_stat.total_input_tokens,
+ output_tokens=rag_stat.recall_stat.total_generated_tokens,
+ model=rag_stat.recall_stat.model_name
+ )
+ ))
+
  # Filter relevant_docs to only include those with is_relevant=True
  highly_relevant_docs = [
  doc for doc in relevant_docs if doc.relevance.is_relevant
@@ -543,7 +602,8 @@ class LongContextRAG:

  if highly_relevant_docs:
  relevant_docs = highly_relevant_docs
- logger.info(f"Found {len(relevant_docs)} highly relevant documents")
+ logger.info(
+ f"Found {len(relevant_docs)} highly relevant documents")

  logger.info(
  f"Filter time: {filter_time:.2f} seconds with {len(relevant_docs)} docs"
@@ -553,7 +613,7 @@ class LongContextRAG:
  final_docs = []
  for doc in relevant_docs:
  final_docs.append(doc.model_dump())
- return [json.dumps(final_docs,ensure_ascii=False)], []
+ return [json.dumps(final_docs, ensure_ascii=False)], []

  if not relevant_docs:
  return ["没有找到相关的文档来回答这个问题。"], []
@@ -588,6 +648,12 @@ class LongContextRAG:
  + "".join([f"\n * {info}" for info in relevant_docs_info])
  )

+ yield ("", SingleOutputMeta(generated_tokens_count=0,
+ reasoning_content=get_message_with_format_and_newline(
+ "dynamic_chunking_start",
+ model=rag_stat.chunk_stat.model_name
+ )
+ ))
  first_round_full_docs = []
  second_round_extracted_docs = []
  sencond_round_time = 0
@@ -602,17 +668,19 @@ class LongContextRAG:
  llm=self.llm,
  disable_segment_reorder=self.args.disable_segment_reorder,
  )
-
+
  token_limiter_result = token_limiter.limit_tokens(
  relevant_docs=relevant_docs,
  conversations=conversations,
  index_filter_workers=self.args.index_filter_workers or 5,
  )

- rag_stat.chunk_stat.total_input_tokens += sum(token_limiter_result.input_tokens_counts)
- rag_stat.chunk_stat.total_generated_tokens += sum(token_limiter_result.generated_tokens_counts)
+ rag_stat.chunk_stat.total_input_tokens += sum(
+ token_limiter_result.input_tokens_counts)
+ rag_stat.chunk_stat.total_generated_tokens += sum(
+ token_limiter_result.generated_tokens_counts)
  rag_stat.chunk_stat.model_name = token_limiter_result.model_name
-
+
  final_relevant_docs = token_limiter_result.docs
  first_round_full_docs = token_limiter.first_round_full_docs
  second_round_extracted_docs = token_limiter.second_round_extracted_docs
@@ -623,24 +691,41 @@ class LongContextRAG:
  relevant_docs = relevant_docs[: self.args.index_filter_file_num]

  logger.info(f"Finally send to model: {len(relevant_docs)}")
-
  # 记录分段处理的统计信息
  logger.info(
  f"=== Token Management ===\n"
  f" * Only contexts: {only_contexts}\n"
- f" * Filter time: {filter_time:.2f} seconds\n"
+ f" * Filter time: {filter_time:.2f} seconds\n"
  f" * Final relevant docs: {len(relevant_docs)}\n"
  f" * First round full docs: {len(first_round_full_docs)}\n"
  f" * Second round extracted docs: {len(second_round_extracted_docs)}\n"
  f" * Second round time: {sencond_round_time:.2f} seconds"
  )

+ yield ("", SingleOutputMeta(generated_tokens_count=rag_stat.chunk_stat.total_generated_tokens + rag_stat.recall_stat.total_generated_tokens,
+ input_tokens_count=rag_stat.chunk_stat.total_input_tokens +
+ rag_stat.recall_stat.total_input_tokens,
+ reasoning_content=get_message_with_format_and_newline(
+ "dynamic_chunking_result",
+ model=rag_stat.chunk_stat.model_name,
+ docs_num=len(relevant_docs),
+ filter_time=filter_time,
+ sencond_round_time=sencond_round_time,
+ first_round_full_docs=len(
+ first_round_full_docs),
+ second_round_extracted_docs=len(
+ second_round_extracted_docs),
+ input_tokens=rag_stat.chunk_stat.total_input_tokens,
+ output_tokens=rag_stat.chunk_stat.total_generated_tokens
+ )
+ ))
+
  # 记录最终选择的文档详情
  final_relevant_docs_info = []
  for i, doc in enumerate(relevant_docs):
  doc_path = doc.module_name.replace(self.path, '', 1)
  info = f"{i+1}. {doc_path}"
-
+
  metadata_info = []
  if "original_docs" in doc.metadata:
  original_docs = ", ".join(
@@ -650,26 +735,27 @@ class LongContextRAG:
  ]
  )
  metadata_info.append(f"Original docs: {original_docs}")
-
+
  if "chunk_ranges" in doc.metadata:
  chunk_ranges = json.dumps(
  doc.metadata["chunk_ranges"], ensure_ascii=False
  )
  metadata_info.append(f"Chunk ranges: {chunk_ranges}")
-
+
  if "processing_time" in doc.metadata:
- metadata_info.append(f"Processing time: {doc.metadata['processing_time']:.2f}s")
-
+ metadata_info.append(
+ f"Processing time: {doc.metadata['processing_time']:.2f}s")
+
  if metadata_info:
  info += f" ({'; '.join(metadata_info)})"
-
+
  final_relevant_docs_info.append(info)

  if final_relevant_docs_info:
  logger.info(
  f"Final documents to be sent to model:"
  + "".join([f"\n * {info}" for info in final_relevant_docs_info])
- )
+ )

  # 记录令牌统计
  request_tokens = sum([doc.tokens for doc in relevant_docs])
@@ -680,7 +766,18 @@ class LongContextRAG:
  f" * Total tokens: {request_tokens}"
  )

- logger.info(f"Start to send to model {target_model} with {request_tokens} tokens")
+ logger.info(
+ f"Start to send to model {target_model} with {request_tokens} tokens")
+
+ yield ("", SingleOutputMeta(input_tokens_count=rag_stat.recall_stat.total_input_tokens + rag_stat.chunk_stat.total_input_tokens,
+ generated_tokens_count=rag_stat.recall_stat.total_generated_tokens +
+ rag_stat.chunk_stat.total_generated_tokens,
+ reasoning_content=get_message_with_format_and_newline(
+ "send_to_model",
+ model=target_model,
+ tokens=request_tokens
+ )
+ ))

  if LLMComputeEngine is not None and not self.args.disable_inference_enhance:
  llm_compute_engine = LLMComputeEngine(
@@ -692,33 +789,42 @@ class LongContextRAG:
  debug=False,
  )
  new_conversations = llm_compute_engine.process_conversation(
- conversations, query, [doc.source_code for doc in relevant_docs]
+ conversations, query, [
+ doc.source_code for doc in relevant_docs]
  )
  chunks = llm_compute_engine.stream_chat_oai(
- conversations=new_conversations,
- model=model,
- role_mapping=role_mapping,
- llm_config=llm_config,
- delta_mode=True,
- )
-
+ conversations=new_conversations,
+ model=model,
+ role_mapping=role_mapping,
+ llm_config=llm_config,
+ delta_mode=True,
+ )
+
  for chunk in chunks:
- yield chunk
  if chunk[1] is not None:
  rag_stat.answer_stat.total_input_tokens += chunk[1].input_tokens_count
- rag_stat.answer_stat.total_generated_tokens += chunk[1].generated_tokens_count
- self._print_rag_stats(rag_stat)
- else:
+ rag_stat.answer_stat.total_generated_tokens += chunk[1].generated_tokens_count
+ chunk[1].input_tokens_count = rag_stat.recall_stat.total_input_tokens + \
+ rag_stat.chunk_stat.total_input_tokens + \
+ rag_stat.answer_stat.total_input_tokens
+ chunk[1].generated_tokens_count = rag_stat.recall_stat.total_generated_tokens + \
+ rag_stat.chunk_stat.total_generated_tokens + \
+ rag_stat.answer_stat.total_generated_tokens
+ yield chunk
+
+ self._print_rag_stats(rag_stat)
+ else:
  new_conversations = conversations[:-1] + [
  {
  "role": "user",
  "content": self._answer_question.prompt(
  query=query,
- relevant_docs=[doc.source_code for doc in relevant_docs],
+ relevant_docs=[
+ doc.source_code for doc in relevant_docs],
  ),
  }
  ]
-
+
  chunks = target_llm.stream_chat_oai(
  conversations=new_conversations,
  model=model,
@@ -727,17 +833,23 @@ class LongContextRAG:
  delta_mode=True,
  extra_request_params=extra_request_params
  )
-
+
  for chunk in chunks:
- yield chunk
  if chunk[1] is not None:
  rag_stat.answer_stat.total_input_tokens += chunk[1].input_tokens_count
- rag_stat.answer_stat.total_generated_tokens += chunk[1].generated_tokens_count
- self._print_rag_stats(rag_stat)
+ rag_stat.answer_stat.total_generated_tokens += chunk[1].generated_tokens_count
+ chunk[1].input_tokens_count = rag_stat.recall_stat.total_input_tokens + \
+ rag_stat.chunk_stat.total_input_tokens + \
+ rag_stat.answer_stat.total_input_tokens
+ chunk[1].generated_tokens_count = rag_stat.recall_stat.total_generated_tokens + \
+ rag_stat.chunk_stat.total_generated_tokens + \
+ rag_stat.answer_stat.total_generated_tokens
+
+ yield chunk

- return generate_sream(),context
-
-
+ self._print_rag_stats(rag_stat)
+
+ return generate_sream(), context

  def _print_rag_stats(self, rag_stat: RAGStat) -> None:
  """打印RAG执行的详细统计信息"""
@@ -748,19 +860,22 @@ class LongContextRAG:
  )
  total_generated_tokens = (
  rag_stat.recall_stat.total_generated_tokens +
- rag_stat.chunk_stat.total_generated_tokens +
+ rag_stat.chunk_stat.total_generated_tokens +
  rag_stat.answer_stat.total_generated_tokens
  )
  total_tokens = total_input_tokens + total_generated_tokens
-
+
  # 避免除以零错误
  if total_tokens == 0:
  recall_percent = chunk_percent = answer_percent = 0
  else:
- recall_percent = (rag_stat.recall_stat.total_input_tokens + rag_stat.recall_stat.total_generated_tokens) / total_tokens * 100
- chunk_percent = (rag_stat.chunk_stat.total_input_tokens + rag_stat.chunk_stat.total_generated_tokens) / total_tokens * 100
- answer_percent = (rag_stat.answer_stat.total_input_tokens + rag_stat.answer_stat.total_generated_tokens) / total_tokens * 100
-
+ recall_percent = (rag_stat.recall_stat.total_input_tokens +
+ rag_stat.recall_stat.total_generated_tokens) / total_tokens * 100
+ chunk_percent = (rag_stat.chunk_stat.total_input_tokens +
+ rag_stat.chunk_stat.total_generated_tokens) / total_tokens * 100
+ answer_percent = (rag_stat.answer_stat.total_input_tokens +
+ rag_stat.answer_stat.total_generated_tokens) / total_tokens * 100
+
  logger.info(
  f"=== RAG 执行统计信息 ===\n"
  f"总令牌使用: {total_tokens} 令牌\n"
@@ -791,21 +906,22 @@ class LongContextRAG:
  f" - 文档分块: {chunk_percent:.1f}%\n"
  f" - 答案生成: {answer_percent:.1f}%\n"
  )
-
+
  # 记录原始统计数据,以便调试
  logger.debug(f"RAG Stat 原始数据: {rag_stat}")
-
+
  # 返回成本估算
- estimated_cost = self._estimate_token_cost(total_input_tokens, total_generated_tokens)
+ estimated_cost = self._estimate_token_cost(
+ total_input_tokens, total_generated_tokens)
  if estimated_cost > 0:
  logger.info(f"估计成本: 约 ${estimated_cost:.4f} 人民币")

  def _estimate_token_cost(self, input_tokens: int, output_tokens: int) -> float:
- """估算当前请求的令牌成本(人民币)"""
+ """估算当前请求的令牌成本(人民币)"""
  # 实际应用中,可以根据不同模型设置不同价格
  input_cost_per_1m = 2.0/1000000 # 每百万输入令牌的成本
  output_cost_per_1m = 8.0/100000 # 每百万输出令牌的成本
-
- cost = (input_tokens * input_cost_per_1m / 1000000) + (output_tokens* output_cost_per_1m/1000000)
+
+ cost = (input_tokens * input_cost_per_1m / 1000000) + \
+ (output_tokens * output_cost_per_1m/1000000)
  return cost
-