auto-coder 0.1.270__py3-none-any.whl → 0.1.272__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of auto-coder might be problematic.

@@ -5,8 +5,9 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
 
 from autocoder.rag.relevant_utils import (
     parse_relevance,
-    FilterDoc,
+    FilterDoc,
     TaskTiming,
+    DocFilterResult
 )
 
 from autocoder.common import SourceCode, AutoCoderArgs
@@ -48,7 +49,6 @@ def _check_relevance_with_conversation(
    Here, <relevant> is the relevance you judge between the document and the question: a number from 0 to 10, where a larger number means higher relevance.
    """
 
-
 class DocFilter:
     def __init__(
         self,
@@ -62,40 +62,57 @@ class DocFilter:
62
62
  self.recall_llm = self.llm.get_sub_client("recall_model")
63
63
  else:
64
64
  self.recall_llm = self.llm
65
-
65
+
66
66
  self.args = args
67
67
  self.relevant_score = self.args.rag_doc_filter_relevance
68
68
  self.on_ray = on_ray
69
- self.path = path
69
+ self.path = path
70
70
 
71
71
  def filter_docs(
72
72
  self, conversations: List[Dict[str, str]], documents: List[SourceCode]
73
- ) -> List[FilterDoc]:
74
- return self.filter_docs_with_threads(conversations, documents)
73
+ ) -> DocFilterResult:
74
+ return self.filter_docs_with_threads(conversations, documents)
75
75
 
76
76
  def filter_docs_with_threads(
77
77
  self, conversations: List[Dict[str, str]], documents: List[SourceCode]
78
- ) -> List[FilterDoc]:
79
-
78
+ ) -> DocFilterResult:
79
+
80
+ start_time = time.time()
81
+ logger.info(f"=== DocFilter Starting ===")
82
+ logger.info(
83
+ f"Configuration: relevance_threshold={self.relevant_score}, thread_workers={self.args.index_filter_workers or 5}")
84
+
80
85
  rag_manager = RagConfigManager(path=self.path)
81
86
  rag_config = rag_manager.load_config()
82
- documents = list(documents)
83
- logger.info(f"Filtering {len(documents)} documents....")
87
+
88
+ documents = list(documents)
89
+ logger.info(f"Filtering {len(documents)} documents...")
90
+
91
+ submitted_tasks = 0
92
+ completed_tasks = 0
93
+ relevant_count = 0
94
+ model_name = self.recall_llm.default_model_name or "unknown"
95
+
84
96
  with ThreadPoolExecutor(
85
97
  max_workers=self.args.index_filter_workers or 5
86
98
  ) as executor:
87
99
  future_to_doc = {}
100
+
101
+ # 提交所有任务
88
102
  for doc in documents:
89
103
  submit_time = time.time()
104
+ submitted_tasks += 1
90
105
 
91
106
  def _run(conversations, docs):
92
107
  submit_time_1 = time.time()
108
+ meta = None
93
109
  try:
94
110
  llm = self.recall_llm
111
+ meta_holder = byzerllm.MetaHolder()
95
112
 
96
113
  v = (
97
114
  _check_relevance_with_conversation.with_llm(
98
- llm)
115
+ llm).with_meta(meta_holder)
99
116
  .options({"llm_config": {"max_length": 10}})
100
117
  .run(
101
118
  conversations=conversations,
@@ -103,14 +120,16 @@ class DocFilter:
                                 filter_config=rag_config.filter_config,
                             )
                         )
+
+                        meta = meta_holder.get_meta_model()
                     except Exception as e:
                         logger.error(
                             f"Error in _check_relevance_with_conversation: {str(e)}"
                         )
-                        return (None, submit_time_1, time.time())
+                        return (None, submit_time_1, time.time(), meta)
 
                     end_time_2 = time.time()
-                    return (v, submit_time_1, end_time_2)
+                    return (v, submit_time_1, end_time_2, meta)
 
                 m = executor.submit(
                     _run,
@@ -119,57 +138,144 @@ class DocFilter:
                 )
                 future_to_doc[m] = (doc, submit_time)
 
-            relevant_docs = []
-            for future in as_completed(list(future_to_doc.keys())):
-                try:
-                    doc, submit_time = future_to_doc[future]
-                    end_time = time.time()
-                    v, submit_time_1, end_time_2 = future.result()
-                    task_timing = TaskTiming(
-                        submit_time=submit_time,
-                        end_time=end_time,
-                        duration=end_time - submit_time,
-                        real_start_time=submit_time_1,
-                        real_end_time=end_time_2,
-                        real_duration=end_time_2 - submit_time_1,
-                    )
-
-                    relevance = parse_relevance(v)
-                    logger.info(
-                        f"Document filtering progress:\n"
-                        f"  - File: {doc.module_name}\n"
-                        f"  - Relevance: {'Relevant' if relevance and relevance.is_relevant else 'Not Relevant'}\n"
-                        f"  - Score: {relevance.relevant_score if relevance else 'N/A'}\n"
-                        f"  - Score Threshold: {self.relevant_score}\n"
-                        f"  - Raw Response: {v}\n"
-                        f"  - Timing:\n"
-                        f"    * Total Duration: {task_timing.duration:.2f}s\n"
-                        f"    * Real Duration: {task_timing.real_duration:.2f}s\n"
-                        f"    * Queue Time: {(task_timing.real_start_time - task_timing.submit_time):.2f}s"
-                    )
-                    if (
-                        relevance
-                        # and relevance.is_relevant
-                        and relevance.relevant_score >= self.relevant_score
-                    ):
-                        relevant_docs.append(
-                            FilterDoc(
+            logger.info(
+                f"Submitted {submitted_tasks} document filtering tasks to thread pool")
+
+            # Process the completed tasks
+            doc_filter_result = DocFilterResult(
+                docs=[],
+                raw_docs=[],
+                input_tokens_counts=[],
+                generated_tokens_counts=[],
+                durations=[],
+                model_name=model_name
+            )
+            relevant_docs = doc_filter_result.docs
+            for future in as_completed(list(future_to_doc.keys())):
+                try:
+                    doc, submit_time = future_to_doc[future]
+                    end_time = time.time()
+                    completed_tasks += 1
+                    progress_percent = (completed_tasks / len(documents)) * 100
+
+                    v, submit_time_1, end_time_2, meta = future.result()
+                    task_timing = TaskTiming(
+                        submit_time=submit_time,
+                        end_time=end_time,
+                        duration=end_time - submit_time,
+                        real_start_time=submit_time_1,
+                        real_end_time=end_time_2,
+                        real_duration=end_time_2 - submit_time_1,
+                    )
+
+                    relevance = parse_relevance(v)
+                    is_relevant = relevance and relevance.relevant_score >= self.relevant_score
+
+                    if is_relevant:
+                        relevant_count += 1
+                        status_text = f"RELEVANT (Score: {relevance.relevant_score:.1f})"
+                    else:
+                        score_text = f"{relevance.relevant_score:.1f}" if relevance else "N/A"
+                        status_text = f"NOT RELEVANT (Score: {score_text})"
+
+                    queue_time = task_timing.real_start_time - task_timing.submit_time
+
+                    input_tokens_count = meta.input_tokens_count if meta else 0
+                    generated_tokens_count = meta.generated_tokens_count if meta else 0
+
+                    logger.info(
+                        f"Document filtering [{progress_percent:.1f}%] - {completed_tasks}/{len(documents)}:"
+                        f"\n  - File: {doc.module_name}"
+                        f"\n  - Status: {status_text}"
+                        f"\n  - Model: {model_name}"
+                        f"\n  - Threshold: {self.relevant_score}"
+                        f"\n  - Input tokens: {input_tokens_count}"
+                        f"\n  - Generated tokens: {generated_tokens_count}"
+                        f"\n  - Timing: Duration={task_timing.duration:.2f}s, Processing={task_timing.real_duration:.2f}s, Queue={queue_time:.2f}s"
+                        f"\n  - Response: {v}"
+                    )
+
+                    if "rag" not in doc.metadata:
+                        doc.metadata["rag"] = {}
+                    doc.metadata["rag"]["recall"] = {
+                        "input_tokens_count": input_tokens_count,
+                        "generated_tokens_count": generated_tokens_count,
+                        "recall_model": model_name,
+                        "duration": task_timing.real_duration
+                    }
+
+                    doc_filter_result.input_tokens_counts.append(input_tokens_count)
+                    doc_filter_result.generated_tokens_counts.append(generated_tokens_count)
+                    doc_filter_result.durations.append(task_timing.real_duration)
+
+                    new_filter_doc = FilterDoc(
                         source_code=doc,
                         relevance=relevance,
                         task_timing=task_timing,
                     )
-                        )
-                except Exception as exc:
-                    try:
-                        doc, submit_time = future_to_doc[future]
-                        logger.error(
-                            f"Filtering document generated an exception (doc: {doc.module_name}): {exc}")
-                    except Exception as e:
-                        logger.error(
-                            f"Filtering document generated an exception: {exc}")
+
+                    doc_filter_result.raw_docs.append(new_filter_doc)
+
+                    if is_relevant:
+                        relevant_docs.append(
+                            new_filter_doc
+                        )
+                except Exception as exc:
+                    try:
+                        doc, submit_time = future_to_doc[future]
+                        completed_tasks += 1
+                        progress_percent = (
+                            completed_tasks / len(documents)) * 100
+                        logger.error(
+                            f"Document filtering [{progress_percent:.1f}%] - {completed_tasks}/{len(documents)}:"
+                            f"\n  - File: {doc.module_name}"
+                            f"\n  - Error: {exc}"
+                            f"\n  - Duration: {time.time() - submit_time:.2f}s"
+                        )
+                        doc_filter_result.raw_docs.append(
+                            FilterDoc(
+                                source_code=doc,
+                                relevance=None,
+                                task_timing=TaskTiming(),
+                            )
+                        )
+                    except Exception as e:
+                        logger.error(
+                            f"Document filtering error in task tracking: {exc}"
+                        )
 
         # Sort relevant_docs by relevance score in descending order
         relevant_docs.sort(
             key=lambda x: x.relevance.relevant_score, reverse=True)
-        return relevant_docs
-
+
+        total_time = time.time() - start_time
+
+        avg_processing_time = sum(
+            doc.task_timing.real_duration for doc in relevant_docs) / len(relevant_docs) if relevant_docs else 0
+        avg_queue_time = sum(doc.task_timing.real_start_time -
+                             doc.task_timing.submit_time for doc in relevant_docs) / len(relevant_docs) if relevant_docs else 0
+
+        total_input_tokens = sum(doc_filter_result.input_tokens_counts)
+        total_generated_tokens = sum(doc_filter_result.generated_tokens_counts)
+
+        logger.info(
+            f"=== DocFilter Complete ==="
+            f"\n  * Total time: {total_time:.2f}s"
+            f"\n  * Documents processed: {completed_tasks}/{len(documents)}"
+            f"\n  * Relevant documents: {relevant_count} (threshold: {self.relevant_score})"
+            f"\n  * Average processing time: {avg_processing_time:.2f}s"
+            f"\n  * Average queue time: {avg_queue_time:.2f}s"
+            f"\n  * Total input tokens: {total_input_tokens}"
+            f"\n  * Total generated tokens: {total_generated_tokens}"
+        )
+
+        if relevant_docs:
+            logger.info(
+                f"Top 5 relevant documents:"
+                + "".join([f"\n  * {doc.source_code.module_name} (Score: {doc.relevance.relevant_score:.1f})"
+                           for doc in relevant_docs[:5]])
+            )
+        else:
+            logger.warning("No relevant documents found!")
+
+        return doc_filter_result
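With this hunk, `filter_docs_with_threads` stops returning a bare `List[FilterDoc]` and instead returns a `DocFilterResult` that carries per-document token accounting alongside the docs. As a rough orientation, here is a minimal sketch of how a caller might consume that result. The `DocFilterResultSketch` model below is a hypothetical stand-in inferred solely from the fields constructed in this hunk (`docs`, `raw_docs`, `input_tokens_counts`, `generated_tokens_counts`, `durations`, `model_name`); the real class lives in `autocoder.rag.relevant_utils`, whose definition is not part of this diff.

    from typing import Any, List
    from pydantic import BaseModel

    class DocFilterResultSketch(BaseModel):
        """Hypothetical stand-in; the real DocFilterResult is in autocoder.rag.relevant_utils."""
        docs: List[Any] = []                    # relevant FilterDoc entries, sorted by score
        raw_docs: List[Any] = []                # every processed FilterDoc, relevant or not
        input_tokens_counts: List[int] = []     # one entry per recall LLM call
        generated_tokens_counts: List[int] = []
        durations: List[float] = []
        model_name: str = "unknown"

    def summarize(result: DocFilterResultSketch) -> dict:
        # Aggregate the per-document counters the same way the callers in this diff do.
        return {
            "relevant_docs": len(result.docs),
            "processed_docs": len(result.raw_docs),
            "input_tokens": sum(result.input_tokens_counts),
            "generated_tokens": sum(result.generated_tokens_counts),
            "total_recall_seconds": sum(result.durations),
            "recall_model": result.model_name,
        }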
@@ -44,13 +44,15 @@ class LLWrapper:
         res,contexts = self.rag.stream_chat_oai(conversations,llm_config=llm_config)
         for t in res:
             yield (t,SingleOutputMeta(0,0))
+
 
     async def async_stream_chat_oai(self,conversations,
                                     model:Optional[str]=None,
                                     role_mapping=None,
                                     delta_mode=False,
                                     llm_config:Dict[str,Any]={}):
-        res,contexts = await asyncfy_with_semaphore(lambda: self.rag.stream_chat_oai(conversations,llm_config=llm_config))()
+        res,contexts = await asyncfy_with_semaphore(lambda: self.rag.stream_chat_oai(conversations,llm_config=llm_config))()
+        # res,contexts = await self.llm.async_stream_chat_oai(conversations,llm_config=llm_config)
         for t in res:
             yield (t,SingleOutputMeta(0,0))
 
@@ -31,6 +31,8 @@ from tokenizers import Tokenizer
 from autocoder.rag.variable_holder import VariableHolder
 from importlib.metadata import version
 from autocoder.rag.stream_event import event_writer
+from autocoder.rag.relevant_utils import DocFilterResult
+from pydantic import BaseModel
 
 try:
     from autocoder_pro.rag.llm_compute import LLMComputeEngine
@@ -42,6 +44,24 @@ except ImportError:
     LLMComputeEngine = None
 
 
+class RecallStat(BaseModel):
+    total_input_tokens: int
+    total_generated_tokens: int
+    model_name: str = "unknown"
+class ChunkStat(BaseModel):
+    total_input_tokens: int
+    total_generated_tokens: int
+    model_name: str = "unknown"
+class AnswerStat(BaseModel):
+    total_input_tokens: int
+    total_generated_tokens: int
+    model_name: str = "unknown"
+
+class RAGStat(BaseModel):
+    recall_stat: RecallStat
+    chunk_stat: ChunkStat
+    answer_stat: AnswerStat
+
 class LongContextRAG:
     def __init__(
         self,
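The four Pydantic models added here give the pipeline a per-stage token ledger: recall, chunking, and answer generation each record input tokens, generated tokens, and the model used. A minimal sketch of how they compose, mirroring the zero-initialized construction and `sum(...)` accumulation that appear later in this diff; the sample counts are invented for illustration:

    # Assumes the RecallStat/ChunkStat/AnswerStat/RAGStat models defined above.
    rag_stat = RAGStat(
        recall_stat=RecallStat(total_input_tokens=0, total_generated_tokens=0),
        chunk_stat=ChunkStat(total_input_tokens=0, total_generated_tokens=0),
        answer_stat=AnswerStat(total_input_tokens=0, total_generated_tokens=0),
    )

    # Hypothetical per-document counts, as a recall pass might report them.
    input_tokens_counts = [120, 95, 210]
    generated_tokens_counts = [4, 4, 4]

    rag_stat.recall_stat.total_input_tokens += sum(input_tokens_counts)
    rag_stat.recall_stat.total_generated_tokens += sum(generated_tokens_counts)
    rag_stat.recall_stat.model_name = "recall_model"

    stage_total = (rag_stat.recall_stat.total_input_tokens
                   + rag_stat.recall_stat.total_generated_tokens)
    print(stage_total)  # 437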
@@ -305,7 +325,7 @@ class LongContextRAG:
         url = ",".join(contexts)
         return [SourceCode(module_name=f"RAG:{url}", source_code="".join(v))]
 
-    def _filter_docs(self, conversations: List[Dict[str, str]]) -> List[FilterDoc]:
+    def _filter_docs(self, conversations: List[Dict[str, str]]) -> DocFilterResult:
         query = conversations[-1]["content"]
         documents = self._retrieve_documents(options={"query":query})
         return self.doc_filter.filter_docs(
@@ -439,7 +459,32 @@ class LongContextRAG:
 
         logger.info(f"Query: {query} only_contexts: {only_contexts}")
         start_time = time.time()
-        relevant_docs: List[FilterDoc] = self._filter_docs(conversations)
+
+        rag_stat = RAGStat(
+            recall_stat=RecallStat(
+                total_input_tokens=0,
+                total_generated_tokens=0,
+                model_name=self.llm.default_model_name,
+            ),
+            chunk_stat=ChunkStat(
+                total_input_tokens=0,
+                total_generated_tokens=0,
+                model_name=self.llm.default_model_name,
+            ),
+            answer_stat=AnswerStat(
+                total_input_tokens=0,
+                total_generated_tokens=0,
+                model_name=self.llm.default_model_name,
+            ),
+        )
+
+        doc_filter_result = self._filter_docs(conversations)
+
+        rag_stat.recall_stat.total_input_tokens += sum(doc_filter_result.input_tokens_counts)
+        rag_stat.recall_stat.total_generated_tokens += sum(doc_filter_result.generated_tokens_counts)
+        rag_stat.recall_stat.model_name = doc_filter_result.model_name
+
+        relevant_docs: List[FilterDoc] = doc_filter_result.docs
         filter_time = time.time() - start_time
 
         # Filter relevant_docs to only include those with is_relevant=True
@@ -469,17 +514,15 @@ class LongContextRAG:
         # Convert FilterDoc to SourceCode so downstream logic can continue processing
         relevant_docs = [doc.source_code for doc in relevant_docs]
 
-        console = Console()
+        logger.info(f"=== RAG Search Results ===")
+        logger.info(f"Query: {query}")
+        logger.info(f"Found relevant docs: {len(relevant_docs)}")
 
-        # Create a table for the query information
-        query_table = Table(title="Query Information", show_header=False)
-        query_table.add_row("Query", query)
-        query_table.add_row("Relevant docs", str(len(relevant_docs)))
-
-        # Add relevant docs information
+        # Log relevant document info
         relevant_docs_info = []
-        for doc in relevant_docs:
-            info = f"- {doc.module_name.replace(self.path,'',1)}"
+        for i, doc in enumerate(relevant_docs):
+            doc_path = doc.module_name.replace(self.path, '', 1)
+            info = f"{i+1}. {doc_path}"
             if "original_docs" in doc.metadata:
                 original_docs = ", ".join(
                     [
@@ -490,8 +533,11 @@ class LongContextRAG:
                 info += f" (Original docs: {original_docs})"
             relevant_docs_info.append(info)
 
-        relevant_docs_info = "\n".join(relevant_docs_info)
-        query_table.add_row("Relevant docs list", relevant_docs_info)
+        if relevant_docs_info:
+            logger.info(
+                f"Relevant documents list:"
+                + "".join([f"\n  * {info}" for info in relevant_docs_info])
+            )
 
         first_round_full_docs = []
         second_round_extracted_docs = []
@@ -507,11 +553,18 @@ class LongContextRAG:
                 llm=self.llm,
                 disable_segment_reorder=self.args.disable_segment_reorder,
             )
-            final_relevant_docs = token_limiter.limit_tokens(
+
+            token_limiter_result = token_limiter.limit_tokens(
                 relevant_docs=relevant_docs,
                 conversations=conversations,
                 index_filter_workers=self.args.index_filter_workers or 5,
             )
+
+            rag_stat.chunk_stat.total_input_tokens += sum(token_limiter_result.input_tokens_counts)
+            rag_stat.chunk_stat.total_generated_tokens += sum(token_limiter_result.generated_tokens_counts)
+            rag_stat.chunk_stat.model_name = token_limiter_result.model_name
+
+            final_relevant_docs = token_limiter_result.docs
             first_round_full_docs = token_limiter.first_round_full_docs
             second_round_extracted_docs = token_limiter.second_round_extracted_docs
             sencond_round_time = token_limiter.sencond_round_time
@@ -522,57 +575,64 @@ class LongContextRAG:
 
         logger.info(f"Finally send to model: {len(relevant_docs)}")
 
-        query_table.add_row("Only contexts", str(only_contexts))
-        query_table.add_row("Filter time", f"{filter_time:.2f} seconds")
-        query_table.add_row("Final relevant docs", str(len(relevant_docs)))
-        query_table.add_row(
-            "first_round_full_docs", str(len(first_round_full_docs))
-        )
-        query_table.add_row(
-            "second_round_extracted_docs", str(len(second_round_extracted_docs))
-        )
-        query_table.add_row(
-            "Second round time", f"{sencond_round_time:.2f} seconds"
+        # Log the segmentation statistics
+        logger.info(
+            f"=== Token Management ===\n"
+            f"  * Only contexts: {only_contexts}\n"
+            f"  * Filter time: {filter_time:.2f} seconds\n"
+            f"  * Final relevant docs: {len(relevant_docs)}\n"
+            f"  * First round full docs: {len(first_round_full_docs)}\n"
+            f"  * Second round extracted docs: {len(second_round_extracted_docs)}\n"
+            f"  * Second round time: {sencond_round_time:.2f} seconds"
         )
 
-        # Add relevant docs information
+        # Log details of the finally selected documents
         final_relevant_docs_info = []
-        for doc in relevant_docs:
-            info = f"- {doc.module_name.replace(self.path,'',1)}"
+        for i, doc in enumerate(relevant_docs):
+            doc_path = doc.module_name.replace(self.path, '', 1)
+            info = f"{i+1}. {doc_path}"
+
+            metadata_info = []
             if "original_docs" in doc.metadata:
                 original_docs = ", ".join(
                     [
-                        doc.replace(self.path, "", 1)
-                        for doc in doc.metadata["original_docs"]
+                        od.replace(self.path, "", 1)
+                        for od in doc.metadata["original_docs"]
                     ]
                 )
-                info += f" (Original docs: {original_docs})"
+                metadata_info.append(f"Original docs: {original_docs}")
+
             if "chunk_ranges" in doc.metadata:
                 chunk_ranges = json.dumps(
                     doc.metadata["chunk_ranges"], ensure_ascii=False
                 )
-                info += f" (Chunk ranges: {chunk_ranges})"
+                metadata_info.append(f"Chunk ranges: {chunk_ranges}")
+
+            if "processing_time" in doc.metadata:
+                metadata_info.append(f"Processing time: {doc.metadata['processing_time']:.2f}s")
+
+            if metadata_info:
+                info += f" ({'; '.join(metadata_info)})"
+
             final_relevant_docs_info.append(info)
 
-        final_relevant_docs_info = "\n".join(final_relevant_docs_info)
-        query_table.add_row("Final Relevant docs list", final_relevant_docs_info)
-
-        # Create a panel to contain the table
-        panel = Panel(
-            query_table,
-            title="RAG Search Results",
-            expand=False,
+        if final_relevant_docs_info:
+            logger.info(
+                f"Final documents to be sent to model:"
+                + "".join([f"\n  * {info}" for info in final_relevant_docs_info])
             )
 
-        # Log the panel using rich
-        console.print(panel)
-
+        # Log token statistics
         request_tokens = sum([doc.tokens for doc in relevant_docs])
         target_model = model or self.llm.default_model_name
         logger.info(
-            f"Start to send to model {target_model} with {request_tokens} tokens"
+            f"=== LLM Request ===\n"
+            f"  * Target model: {target_model}\n"
+            f"  * Total tokens: {request_tokens}"
         )
 
+        logger.info(f"Start to send to model {target_model} with {request_tokens} tokens")
+
         if LLMComputeEngine is not None and not self.args.disable_inference_enhance:
             llm_compute_engine = LLMComputeEngine(
                 llm=target_llm,
@@ -585,17 +645,22 @@ class LongContextRAG:
             new_conversations = llm_compute_engine.process_conversation(
                 conversations, query, [doc.source_code for doc in relevant_docs]
             )
-
-            return (
-                llm_compute_engine.stream_chat_oai(
+            chunks = llm_compute_engine.stream_chat_oai(
                 conversations=new_conversations,
                 model=model,
                 role_mapping=role_mapping,
                 llm_config=llm_config,
                 delta_mode=True,
-                ),
-                context,
-            )
+            )
+
+            def generate_chunks():
+                for chunk in chunks:
+                    yield chunk[0]
+                    if chunk[1] is not None:
+                        rag_stat.answer_stat.total_input_tokens += chunk[1].input_tokens_count
+                        rag_stat.answer_stat.total_generated_tokens += chunk[1].generated_tokens_count
+                self._print_rag_stats(rag_stat)
+            return generate_chunks(), context
 
         new_conversations = conversations[:-1] + [
             {
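Both answer-stage branches in this release share one pattern: the `(text, meta)` chunk stream from the model is wrapped in a local generator that yields only the text, folds each chunk's token counts into `rag_stat` as a side effect, and prints the accumulated stats once the stream is exhausted. Below is a self-contained sketch of that pattern; `FakeMeta` and the literal chunk list are stand-ins for the real stream, whose meta objects this diff only shows exposing `input_tokens_count` and `generated_tokens_count`:

    class FakeMeta:
        """Stand-in for the per-chunk meta object used in this diff."""
        def __init__(self, input_tokens_count: int, generated_tokens_count: int):
            self.input_tokens_count = input_tokens_count
            self.generated_tokens_count = generated_tokens_count

    chunks = [("Hel", FakeMeta(50, 1)), ("lo", FakeMeta(0, 1)), ("!", None)]
    totals = {"input": 0, "generated": 0}

    def generate_chunks():
        for text, meta in chunks:
            yield text                    # the caller only ever sees the text
            if meta is not None:          # token accounting happens as a side effect
                totals["input"] += meta.input_tokens_count
                totals["generated"] += meta.generated_tokens_count
        print(f"stats: {totals}")         # fires only once the stream is fully consumed

    text = "".join(generate_chunks())     # prints stats: {'input': 50, 'generated': 2}
    print(text)                           # Hello!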
@@ -614,5 +679,85 @@ class LongContextRAG:
             llm_config=llm_config,
             delta_mode=True,
         )
+
+        def generate_chunks():
+            for chunk in chunks:
+                yield chunk[0]
+                if chunk[1] is not None:
+                    rag_stat.answer_stat.total_input_tokens += chunk[1].input_tokens_count
+                    rag_stat.answer_stat.total_generated_tokens += chunk[1].generated_tokens_count
+            self._print_rag_stats(rag_stat)
+        return generate_chunks(), context
+
+
 
-        return (chunk[0] for chunk in chunks), context
+    def _print_rag_stats(self, rag_stat: RAGStat) -> None:
+        """Print detailed statistics for the RAG execution"""
+        total_input_tokens = (
+            rag_stat.recall_stat.total_input_tokens +
+            rag_stat.chunk_stat.total_input_tokens +
+            rag_stat.answer_stat.total_input_tokens
+        )
+        total_generated_tokens = (
+            rag_stat.recall_stat.total_generated_tokens +
+            rag_stat.chunk_stat.total_generated_tokens +
+            rag_stat.answer_stat.total_generated_tokens
+        )
+        total_tokens = total_input_tokens + total_generated_tokens
+
+        # Avoid division-by-zero errors
+        if total_tokens == 0:
+            recall_percent = chunk_percent = answer_percent = 0
+        else:
+            recall_percent = (rag_stat.recall_stat.total_input_tokens + rag_stat.recall_stat.total_generated_tokens) / total_tokens * 100
+            chunk_percent = (rag_stat.chunk_stat.total_input_tokens + rag_stat.chunk_stat.total_generated_tokens) / total_tokens * 100
+            answer_percent = (rag_stat.answer_stat.total_input_tokens + rag_stat.answer_stat.total_generated_tokens) / total_tokens * 100
+
+        logger.info(
+            f"=== RAG Execution Statistics ===\n"
+            f"Total token usage: {total_tokens} tokens\n"
+            f"  * Total input tokens: {total_input_tokens}\n"
+            f"  * Total generated tokens: {total_generated_tokens}\n"
+            f"\n"
+            f"Stage statistics:\n"
+            f"  1. Document retrieval stage:\n"
+            f"     - Model: {rag_stat.recall_stat.model_name}\n"
+            f"     - Input tokens: {rag_stat.recall_stat.total_input_tokens}\n"
+            f"     - Generated tokens: {rag_stat.recall_stat.total_generated_tokens}\n"
+            f"     - Stage total: {rag_stat.recall_stat.total_input_tokens + rag_stat.recall_stat.total_generated_tokens}\n"
+            f"\n"
+            f"  2. Document chunking stage:\n"
+            f"     - Model: {rag_stat.chunk_stat.model_name}\n"
+            f"     - Input tokens: {rag_stat.chunk_stat.total_input_tokens}\n"
+            f"     - Generated tokens: {rag_stat.chunk_stat.total_generated_tokens}\n"
+            f"     - Stage total: {rag_stat.chunk_stat.total_input_tokens + rag_stat.chunk_stat.total_generated_tokens}\n"
+            f"\n"
+            f"  3. Answer generation stage:\n"
+            f"     - Model: {rag_stat.answer_stat.model_name}\n"
+            f"     - Input tokens: {rag_stat.answer_stat.total_input_tokens}\n"
+            f"     - Generated tokens: {rag_stat.answer_stat.total_generated_tokens}\n"
+            f"     - Stage total: {rag_stat.answer_stat.total_input_tokens + rag_stat.answer_stat.total_generated_tokens}\n"
+            f"\n"
+            f"Token distribution percentages:\n"
+            f"  - Document retrieval: {recall_percent:.1f}%\n"
+            f"  - Document chunking: {chunk_percent:.1f}%\n"
+            f"  - Answer generation: {answer_percent:.1f}%\n"
+        )
+
+        # Log the raw statistics for debugging
+        logger.debug(f"RAG Stat raw data: {rag_stat}")
+
+        # Return the cost estimate
+        estimated_cost = self._estimate_token_cost(total_input_tokens, total_generated_tokens)
+        if estimated_cost > 0:
+            logger.info(f"Estimated cost: about ${estimated_cost:.4f} CNY")
+
+    def _estimate_token_cost(self, input_tokens: int, output_tokens: int) -> float:
+        """Estimate the token cost of the current request (CNY)"""
+        # In real applications, different prices can be set per model
+        input_cost_per_1m = 2.0/1000000  # cost per million input tokens
+        output_cost_per_1m = 8.0/100000  # cost per million output tokens
+
+        cost = (input_tokens * input_cost_per_1m / 1000000) + (output_tokens* output_cost_per_1m/1000000)
+        return cost
+
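A worked example tracing `_estimate_token_cost` exactly as written. Note that the constants are already divided by 10^6 (or 10^5) at definition time and then divided by 10^6 again when applied, so the effective rates come out several orders of magnitude below the per-million prices the comments describe; the numbers below simply replay the arithmetic in the diff:

    def estimate_token_cost(input_tokens: int, output_tokens: int) -> float:
        # Constants and formula copied verbatim from the diff above.
        input_cost_per_1m = 2.0 / 1000000
        output_cost_per_1m = 8.0 / 100000
        return (input_tokens * input_cost_per_1m / 1000000) + (output_tokens * output_cost_per_1m / 1000000)

    # One million tokens in each direction:
    print(estimate_token_cost(1_000_000, 1_000_000))
    # -> 2e-06 + 8e-05 = 8.2e-05, i.e. about 0.0001 CNY for two million tokens.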