auto-coder 0.1.172__py3-none-any.whl → 0.1.175__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of auto-coder might be problematic.

@@ -18,10 +18,15 @@ from loguru import logger
 from pydantic import BaseModel
 
 from autocoder.common import SourceCode
-from autocoder.rag.loaders import (extract_text_from_docx,
-                                   extract_text_from_excel,
-                                   extract_text_from_pdf,
-                                   extract_text_from_ppt)
+from autocoder.rag.loaders import (
+    extract_text_from_docx,
+    extract_text_from_excel,
+    extract_text_from_pdf,
+    extract_text_from_ppt,
+)
+from autocoder.rag import variable_holder
+from autocoder.rag.token_counter import count_tokens_worker, count_tokens
+from uuid import uuid4
 
 cache_lock = threading.Lock()
 
@@ -34,72 +39,62 @@ class AddOrUpdateEvent(BaseModel):
     file_infos: List[Tuple[str, str, float]]
 
 
-@ray.remote
-def process_file(file_info: Tuple[str, str, float]) -> List[SourceCode]:
+def process_file_in_multi_process(
+    file_info: Tuple[str, str, float]
+) -> List[SourceCode]:
     start_time = time.time()
     file_path, relative_path, _ = file_info
     try:
         if file_path.endswith(".pdf"):
             with open(file_path, "rb") as f:
                 content = extract_text_from_pdf(f.read())
-            v = [SourceCode(module_name=file_path, source_code=content)]
-        elif file_path.endswith(".docx"):
-            with open(file_path, "rb") as f:
-                content = extract_text_from_docx(f.read())
-            v = [SourceCode(module_name=f"##File: {file_path}", source_code=content)]
-        elif file_path.endswith(".xlsx") or file_path.endswith(".xls"):
-            sheets = extract_text_from_excel(file_path)
             v = [
                 SourceCode(
-                    module_name=f"##File: {file_path}#{sheet[0]}",
-                    source_code=sheet[1],
+                    module_name=file_path,
+                    source_code=content,
+                    tokens=count_tokens_worker(content),
                 )
-                for sheet in sheets
             ]
-        elif file_path.endswith(".pptx"):
-            slides = extract_text_from_ppt(file_path)
-            content = "".join(f"#{slide[0]}\n{slide[1]}\n\n" for slide in slides)
-            v = [SourceCode(module_name=f"##File: {file_path}", source_code=content)]
-        else:
-            with open(file_path, "r", encoding="utf-8") as f:
-                content = f.read()
-            v = [SourceCode(module_name=f"##File: {file_path}", source_code=content)]
-        logger.info(f"Load file {file_path} in {time.time() - start_time}")
-        return v
-    except Exception as e:
-        logger.error(f"Error processing file {file_path}: {str(e)}")
-        return []
-
-
-def process_file2(file_info: Tuple[str, str, float]) -> List[SourceCode]:
-    start_time = time.time()
-    file_path, relative_path, _ = file_info
-    try:
-        if file_path.endswith(".pdf"):
-            with open(file_path, "rb") as f:
-                content = extract_text_from_pdf(f.read())
-            v = [SourceCode(module_name=file_path, source_code=content)]
         elif file_path.endswith(".docx"):
             with open(file_path, "rb") as f:
                 content = extract_text_from_docx(f.read())
-            v = [SourceCode(module_name=f"##File: {file_path}", source_code=content)]
+            v = [
+                SourceCode(
+                    module_name=f"##File: {file_path}",
+                    source_code=content,
+                    tokens=count_tokens_worker(content),
+                )
+            ]
         elif file_path.endswith(".xlsx") or file_path.endswith(".xls"):
             sheets = extract_text_from_excel(file_path)
             v = [
                 SourceCode(
                     module_name=f"##File: {file_path}#{sheet[0]}",
                     source_code=sheet[1],
+                    tokens=count_tokens_worker(sheet[1]),
                 )
                 for sheet in sheets
             ]
         elif file_path.endswith(".pptx"):
             slides = extract_text_from_ppt(file_path)
             content = "".join(f"#{slide[0]}\n{slide[1]}\n\n" for slide in slides)
-            v = [SourceCode(module_name=f"##File: {file_path}", source_code=content)]
+            v = [
+                SourceCode(
+                    module_name=f"##File: {file_path}",
+                    source_code=content,
+                    tokens=count_tokens_worker(content),
+                )
+            ]
         else:
             with open(file_path, "r", encoding="utf-8") as f:
                 content = f.read()
-            v = [SourceCode(module_name=f"##File: {file_path}", source_code=content)]
+            v = [
+                SourceCode(
+                    module_name=f"##File: {file_path}",
+                    source_code=content,
+                    tokens=count_tokens_worker(content),
+                )
+            ]
         logger.info(f"Load file {file_path} in {time.time() - start_time}")
         return v
     except Exception as e:
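The refactor above folds the Ray-based process_file and the duplicated process_file2 into a single process_file_in_multi_process and records a token count on every SourceCode it returns. For orientation, a minimal sketch of the shape this diff appears to assume for SourceCode; the real model lives in autocoder.common and may define more fields, and the defaults below are assumptions inferred from the doc.tokens <= 0 check later in the diff:

# Hypothetical approximation of autocoder.common.SourceCode, inferred from the
# keyword arguments used in this diff; the real model may differ.
from typing import Any, Dict
from pydantic import BaseModel


class SourceCode(BaseModel):
    module_name: str
    source_code: str
    tokens: int = -1                # assumed default; the retriever treats tokens <= 0 as "unknown"
    metadata: Dict[str, Any] = {}   # carries original_docs / chunk_index for merged and split docs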
@@ -107,34 +102,59 @@ def process_file2(file_info: Tuple[str, str, float]) -> List[SourceCode]:
         return []
 
 
-def process_file3(file_path: str) -> List[SourceCode]:
+def process_file_local(file_path: str) -> List[SourceCode]:
     start_time = time.time()
     try:
         if file_path.endswith(".pdf"):
             with open(file_path, "rb") as f:
                 content = extract_text_from_pdf(f.read())
-            v = [SourceCode(module_name=file_path, source_code=content)]
+            v = [
+                SourceCode(
+                    module_name=file_path,
+                    source_code=content,
+                    tokens=count_tokens(content),
+                )
+            ]
         elif file_path.endswith(".docx"):
             with open(file_path, "rb") as f:
                 content = extract_text_from_docx(f.read())
-            v = [SourceCode(module_name=f"##File: {file_path}", source_code=content)]
+            v = [
+                SourceCode(
+                    module_name=f"##File: {file_path}",
+                    source_code=content,
+                    tokens=count_tokens(content),
+                )
+            ]
         elif file_path.endswith(".xlsx") or file_path.endswith(".xls"):
             sheets = extract_text_from_excel(file_path)
             v = [
                 SourceCode(
                     module_name=f"##File: {file_path}#{sheet[0]}",
                     source_code=sheet[1],
+                    tokens=count_tokens(sheet[1]),
                 )
                 for sheet in sheets
             ]
         elif file_path.endswith(".pptx"):
             slides = extract_text_from_ppt(file_path)
             content = "".join(f"#{slide[0]}\n{slide[1]}\n\n" for slide in slides)
-            v = [SourceCode(module_name=f"##File: {file_path}", source_code=content)]
+            v = [
+                SourceCode(
+                    module_name=f"##File: {file_path}",
+                    source_code=content,
+                    tokens=count_tokens(content),
+                )
+            ]
         else:
             with open(file_path, "r", encoding="utf-8") as f:
                 content = f.read()
-            v = [SourceCode(module_name=f"##File: {file_path}", source_code=content)]
+            v = [
+                SourceCode(
+                    module_name=f"##File: {file_path}",
+                    source_code=content,
+                    tokens=count_tokens(content),
+                )
+            ]
         logger.info(f"Load file {file_path} in {time.time() - start_time}")
         return v
     except Exception as e:
@@ -205,7 +225,7 @@ class AutoCoderRAGDocListener:
             self.update_cache(item)
 
     def update_cache(self, file_path):
-        source_code = process_file3(file_path)
+        source_code = process_file_local(file_path)
         self.cache[file_path] = {
             "file_path": file_path,
             "content": [c.model_dump() for c in source_code],
@@ -220,7 +240,9 @@ class AutoCoderRAGDocListener:
 
     def open_watch(self):
         logger.info(f"start monitor: {self.path}...")
-        for changes in watch(self.path, watch_filter=self.file_filter, stop_event=self.stop_event):
+        for changes in watch(
+            self.path, watch_filter=self.file_filter, stop_event=self.stop_event
+        ):
             for change in changes:
                 (action, path) = change
                 if action == Change.added or action == Change.modified:
@@ -290,7 +312,6 @@ class AutoCoderRAGAsyncUpdateQueue:
         self.thread.start()
         self.cache = self.read_cache()
 
-
    def _process_queue(self):
        while not self.stop_event.is_set():
            try:
@@ -324,8 +345,14 @@ class AutoCoderRAGAsyncUpdateQueue:
         # results = ray.get(
         #     [process_file.remote(file_info) for file_info in files_to_process]
         # )
-        with Pool(processes=os.cpu_count()) as pool:
-            results = pool.map(process_file2, files_to_process)
+        from autocoder.rag.token_counter import initialize_tokenizer
+
+        with Pool(
+            processes=os.cpu_count(),
+            initializer=initialize_tokenizer,
+            initargs=(variable_holder.TOKENIZER_PATH,),
+        ) as pool:
+            results = pool.map(process_file_in_multi_process, files_to_process)
 
         for file_info, result in zip(files_to_process, results):
             self.update_cache(file_info, result)
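The queue now counts tokens inside each worker process, so the tokenizer is loaded once per worker via a Pool initializer rather than per file. A minimal standalone sketch of the same pattern, assuming any tokenizers-compatible JSON file is available; the file path and sample texts below are placeholders:

# Minimal sketch of the per-worker initializer pattern shown above; not library code.
from multiprocessing import Pool

from tokenizers import Tokenizer

tokenizer_model = None  # set once per worker by the initializer


def initialize_tokenizer(tokenizer_path: str) -> None:
    global tokenizer_model
    tokenizer_model = Tokenizer.from_file(tokenizer_path)


def count_tokens_worker(text: str) -> int:
    # Runs inside a worker; uses the tokenizer loaded by initialize_tokenizer.
    return len(tokenizer_model.encode(text).ids)


if __name__ == "__main__":
    texts = ["hello world", "def foo(): pass"]
    with Pool(
        processes=2,
        initializer=initialize_tokenizer,
        initargs=("/path/to/tokenizer.json",),  # placeholder path
    ) as pool:
        print(pool.map(count_tokens_worker, texts))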
@@ -365,7 +392,7 @@ class AutoCoderRAGAsyncUpdateQueue:
         elif isinstance(file_list, AddOrUpdateEvent):
             for file_info in file_list.file_infos:
                 logger.info(f"{file_info[0]} is detected to be updated")
-                result = process_file2(file_info)
+                result = process_file_local(file_info)
                 self.update_cache(file_info, result)
 
         self.write_cache()
@@ -410,7 +437,9 @@ class AutoCoderRAGAsyncUpdateQueue:
             # Release the file lock
             fcntl.flock(lockf, fcntl.LOCK_UN)
 
-    def update_cache(self, file_info: Tuple[str, str, float], content: List[SourceCode]):
+    def update_cache(
+        self, file_info: Tuple[str, str, float], content: List[SourceCode]
+    ):
         file_path, relative_path, modify_time = file_info
         self.cache[file_path] = {
             "file_path": file_path,
@@ -485,11 +514,20 @@ class DocumentRetriever:
         required_exts: list,
         on_ray: bool = False,
         monitor_mode: bool = False,
+        single_file_token_limit: int = 60000,
+        disable_auto_window: bool = False,
     ) -> None:
         self.path = path
         self.ignore_spec = ignore_spec
         self.required_exts = required_exts
         self.monitor_mode = monitor_mode
+        self.single_file_token_limit = single_file_token_limit
+        self.disable_auto_window = disable_auto_window
+
+        # Files smaller than this are candidates for merging
+        self.small_file_token_limit = self.single_file_token_limit / 4
+        # Maximum size of a merged file
+        self.small_file_merge_limit = self.single_file_token_limit / 2
 
         self.on_ray = on_ray
         if self.on_ray:
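Based on the constructor shown above, the auto-window behaviour is now tunable at construction time. A hedged usage sketch follows; the argument values are illustrative and passing None for ignore_spec is an assumption, since the real call sites supply whatever ignore spec the project builds:

# Illustrative construction only; the values and ignore_spec=None are assumptions.
from autocoder.rag.document_retriever import DocumentRetriever

retriever = DocumentRetriever(
    path="/data/docs",
    ignore_spec=None,
    required_exts=[".md", ".py"],
    on_ray=False,
    monitor_mode=False,
    single_file_token_limit=60000,  # documents above this are split into chunks
    disable_auto_window=False,      # keep merging of small docs and splitting of large ones
)
# With these values, docs under 15000 tokens are buffered for merging and a merged
# batch is emitted once the buffer reaches 30000 tokens.
for doc in retriever.retrieve_documents():
    print(doc.module_name, doc.tokens)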
@@ -502,6 +540,13 @@ class DocumentRetriever:
                 path, ignore_spec, required_exts
             )
 
+        logger.info(f"DocumentRetriever initialized with:")
+        logger.info(f"  Path: {self.path}")
+        logger.info(f"  Diable auto window: {self.disable_auto_window} ")
+        logger.info(f"  Single file token limit: {self.single_file_token_limit}")
+        logger.info(f"  Small file token limit: {self.small_file_token_limit}")
+        logger.info(f"  Small file merge limit: {self.small_file_merge_limit}")
+
     def get_cache(self):
         if self.on_ray:
             return ray.get(self.cacher.get_cache.remote())
@@ -509,6 +554,102 @@ class DocumentRetriever:
         return self.cacher.get_cache()
 
     def retrieve_documents(self) -> Generator[SourceCode, None, None]:
+        logger.info("Starting document retrieval process")
+        waiting_list = []
+        waiting_tokens = 0
         for _, data in self.get_cache().items():
             for source_code in data["content"]:
-                yield SourceCode.model_validate(source_code)
+                doc = SourceCode.model_validate(source_code)
+                if self.disable_auto_window:
+                    yield doc
+                else:
+                    if doc.tokens <= 0:
+                        yield doc
+                    elif doc.tokens < self.small_file_token_limit:
+                        waiting_list, waiting_tokens = self._add_to_waiting_list(
+                            doc, waiting_list, waiting_tokens
+                        )
+                        if waiting_tokens >= self.small_file_merge_limit:
+                            yield from self._process_waiting_list(waiting_list)
+                            waiting_list = []
+                            waiting_tokens = 0
+                    elif doc.tokens > self.single_file_token_limit:
+                        yield from self._split_large_document(doc)
+                    else:
+                        yield doc
+        if waiting_list and not self.disable_auto_window:
+            yield from self._process_waiting_list(waiting_list)
+
+        logger.info("Document retrieval process completed")
+
+    def _add_to_waiting_list(
+        self, doc: SourceCode, waiting_list: List[SourceCode], waiting_tokens: int
+    ) -> Tuple[List[SourceCode], int]:
+        waiting_list.append(doc)
+        return waiting_list, waiting_tokens + doc.tokens
+
+    def _process_waiting_list(
+        self, waiting_list: List[SourceCode]
+    ) -> Generator[SourceCode, None, None]:
+        if len(waiting_list) == 1:
+            yield waiting_list[0]
+        elif len(waiting_list) > 1:
+            yield self._merge_documents(waiting_list)
+
+    def _merge_documents(self, docs: List[SourceCode]) -> SourceCode:
+        merged_content = "\n".join(
+            [f"#File: {doc.module_name}\n{doc.source_code}" for doc in docs]
+        )
+        merged_tokens = sum([doc.tokens for doc in docs])
+        merged_name = f"Merged_{len(docs)}_docs_{str(uuid4())}"
+        logger.info(
+            f"Merged {len(docs)} documents into {merged_name} (tokens: {merged_tokens})."
+        )
+        return SourceCode(
+            module_name=merged_name,
+            source_code=merged_content,
+            tokens=merged_tokens,
+            metadata={"original_docs": [doc.module_name for doc in docs]},
+        )
+
+    def _split_large_document(
+        self, doc: SourceCode
+    ) -> Generator[SourceCode, None, None]:
+        chunk_size = self.single_file_token_limit
+        total_chunks = (doc.tokens + chunk_size - 1) // chunk_size
+        logger.info(f"Splitting document {doc.module_name} into {total_chunks} chunks")
+        for i in range(0, doc.tokens, chunk_size):
+            chunk_content = doc.source_code[i : i + chunk_size]
+            chunk_tokens = min(chunk_size, doc.tokens - i)
+            chunk_name = f"{doc.module_name}#chunk{i//chunk_size+1}"
+            # logger.debug(f"  Created chunk: {chunk_name} (tokens: {chunk_tokens})")
+            yield SourceCode(
+                module_name=chunk_name,
+                source_code=chunk_content,
+                tokens=chunk_tokens,
+                metadata={
+                    "original_doc": doc.module_name,
+                    "chunk_index": i // chunk_size + 1,
+                },
+            )
+
+    def _split_document(
+        self, doc: SourceCode, token_limit: int
+    ) -> Generator[SourceCode, None, None]:
+        remaining_tokens = doc.tokens
+        chunk_number = 1
+        start_index = 0
+
+        while remaining_tokens > 0:
+            end_index = start_index + token_limit
+            chunk_content = doc.source_code[start_index:end_index]
+            chunk_tokens = min(token_limit, remaining_tokens)
+
+            chunk_name = f"{doc.module_name}#{chunk_number:06d}"
+            yield SourceCode(
+                module_name=chunk_name, source_code=chunk_content, tokens=chunk_tokens
+            )
+
+            start_index = end_index
+            remaining_tokens -= chunk_tokens
+            chunk_number += 1
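In plain terms, retrieve_documents now buffers documents smaller than a quarter of single_file_token_limit, flushes the buffer as one merged SourceCode once it reaches half of that limit, and splits anything larger than the limit into chunks. A toy walk-through of those thresholds, with numbers chosen purely for illustration:

# Toy illustration of the auto-window thresholds above; not library code.
single_file_token_limit = 60000
small_file_token_limit = single_file_token_limit / 4   # 15000: below this, queue for merging
small_file_merge_limit = single_file_token_limit / 2   # 30000: flush the merge buffer at this size

doc_tokens = [5000, 12000, 14000, 70000, 40000]
waiting = 0
for t in doc_tokens:
    if t < small_file_token_limit:
        waiting += t                 # 5000 -> 17000 -> 31000: the third doc triggers a merge
        if waiting >= small_file_merge_limit:
            print(f"emit one merged doc of ~{waiting} tokens")
            waiting = 0
    elif t > single_file_token_limit:
        chunks = (t + single_file_token_limit - 1) // single_file_token_limit
        print(f"split into {chunks} chunks")    # 70000 tokens -> 2 chunks
    else:
        print(f"emit as-is ({t} tokens)")       # 40000 tokens passes through unchanged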
@@ -13,16 +13,22 @@ from openai import OpenAI
 from rich.console import Console
 from rich.panel import Panel
 from rich.table import Table
-from rich.text import Text
+import statistics
 
 from autocoder.common import AutoCoderArgs, SourceCode
 from autocoder.rag.doc_filter import DocFilter
 from autocoder.rag.document_retriever import DocumentRetriever
-from autocoder.rag.relevant_utils import (DocRelevance, FilterDoc, TaskTiming,
-                                          parse_relevance)
+from autocoder.rag.relevant_utils import (
+    DocRelevance,
+    FilterDoc,
+    TaskTiming,
+    parse_relevance,
+)
 from autocoder.rag.token_checker import check_token_limit
 from autocoder.rag.token_counter import RemoteTokenCounter, TokenCounter
 from autocoder.rag.token_limiter import TokenLimiter
+from tokenizers import Tokenizer
+from autocoder.rag import variable_holder
 
 
 class LongContextRAG:
@@ -44,11 +50,26 @@ class LongContextRAG:
         self.path = path
         self.relevant_score = self.args.rag_doc_filter_relevance or 5
 
+        self.full_text_ratio = args.full_text_ratio
+        self.segment_ratio = args.segment_ratio
+        self.buff_ratio = 1 - self.full_text_ratio - self.segment_ratio
+
+        if self.buff_ratio < 0:
+            raise ValueError(
+                "The sum of full_text_ratio and segment_ratio must be less than or equal to 1.0"
+            )
+
+        self.full_text_limit = int(args.rag_context_window_limit * self.full_text_ratio)
+        self.segment_limit = int(args.rag_context_window_limit * self.segment_ratio)
+        self.buff_limit = int(args.rag_context_window_limit * self.buff_ratio)
+
         self.tokenizer = None
         self.tokenizer_path = tokenizer_path
         self.on_ray = False
 
         if self.tokenizer_path:
+            variable_holder.TOKENIZER_PATH = self.tokenizer_path
+            variable_holder.TOKENIZER_MODEL = Tokenizer.from_file(self.tokenizer_path)
             self.tokenizer = TokenCounter(self.tokenizer_path)
         else:
             if llm.is_model_exist("deepseek_tokenizer"):
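With the new ratios, the context window is carved into a full-text area, a segment area, and a buffer. A worked example of the arithmetic, assuming rag_context_window_limit=120000, full_text_ratio=0.5 and segment_ratio=0.25; these values are illustrative, not defaults confirmed by this diff:

# Illustrative numbers only; the real defaults are not shown in this diff.
rag_context_window_limit = 120000
full_text_ratio = 0.5
segment_ratio = 0.25
buff_ratio = 1 - full_text_ratio - segment_ratio      # 0.25; a negative value raises ValueError

full_text_limit = int(rag_context_window_limit * full_text_ratio)  # 60000
segment_limit = int(rag_context_window_limit * segment_ratio)      # 30000
buff_limit = int(rag_context_window_limit * buff_ratio)            # 30000
# Note: int() truncates, so ratios that are not exactly representable in binary
# floating point can come out a token short.

# DocumentRetriever then receives single_file_token_limit = full_text_limit - 100,
# so the full-text area can always hold at least one whole file.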
@@ -96,24 +117,41 @@ class LongContextRAG:
             self.required_exts,
             self.on_ray,
             self.monitor_mode,
+            ## Make sure the full-text area can hold at least one file
+            single_file_token_limit=self.full_text_limit - 100,
+            disable_auto_window=self.args.disable_auto_window
         )
 
         self.doc_filter = DocFilter(
             self.index_model, self.args, on_ray=self.on_ray, path=self.path
         )
-
-        # Check whether any file in the current directory exceeds 120k tokens, and print those that do
-        self.token_exceed_files = []
-        if self.tokenizer is not None:
-            self.token_exceed_files = check_token_limit(
-                count_tokens=self.count_tokens,
-                token_limit=self.token_limit,
-                retrieve_documents=self._retrieve_documents,
-                max_workers=self.args.index_filter_workers or 5,
-            )
+
+        doc_num = 0
+        token_num = 0
+        token_counts = []
+        for doc in self._retrieve_documents():
+            doc_num += 1
+            doc_tokens = doc.tokens
+            token_num += doc_tokens
+            token_counts.append(doc_tokens)
+
+        avg_tokens = statistics.mean(token_counts) if token_counts else 0
+        median_tokens = statistics.median(token_counts) if token_counts else 0
 
         logger.info(
-            f"Tokenizer path: {self.tokenizer_path} relevant_score: {self.relevant_score} token_limit: {self.token_limit}"
+            "RAG Configuration:\n"
+            f"  Total docs: {doc_num}\n"
+            f"  Total tokens: {token_num}\n"
+            f"  Tokenizer path: {self.tokenizer_path}\n"
+            f"  Relevant score: {self.relevant_score}\n"
+            f"  Token limit: {self.token_limit}\n"
+            f"  Full text limit: {self.full_text_limit}\n"
+            f"  Segment limit: {self.segment_limit}\n"
+            f"  Buff limit: {self.buff_limit}\n"
+            f"  Max doc tokens: {max(token_counts) if token_counts else 0}\n"
+            f"  Min doc tokens: {min(token_counts) if token_counts else 0}\n"
+            f"  Avg doc tokens: {avg_tokens:.2f}\n"
+            f"  Median doc tokens: {median_tokens:.2f}\n"
         )
 
     def count_tokens(self, text: str) -> int:
@@ -350,9 +388,15 @@ class LongContextRAG:
         query_table.add_row("Relevant docs", str(len(relevant_docs)))
 
         # Add relevant docs information
-        relevant_docs_info = "\n".join(
-            [f"- {doc.module_name}" for doc in relevant_docs]
-        )
+        relevant_docs_info = []
+        for doc in relevant_docs:
+            info = f"- {doc.module_name.replace(self.path,'',1)}"
+            if 'original_docs' in doc.metadata:
+                original_docs = ", ".join([doc.replace(self.path,"",1) for doc in doc.metadata['original_docs']])
+                info += f" (Original docs: {original_docs})"
+            relevant_docs_info.append(info)
+
+        relevant_docs_info = "\n".join(relevant_docs_info)
         query_table.add_row("Relevant docs list", relevant_docs_info)
 
         first_round_full_docs = []
@@ -363,7 +407,9 @@ class LongContextRAG:
 
         token_limiter = TokenLimiter(
             count_tokens=self.count_tokens,
-            token_limit=self.token_limit,
+            full_text_limit=self.full_text_limit,
+            segment_limit=self.segment_limit,
+            buff_limit=self.buff_limit,
             llm=self.llm,
         )
         final_relevant_docs = token_limiter.limit_tokens(
@@ -395,9 +441,18 @@ class LongContextRAG:
         )
 
         # Add relevant docs information
-        final_relevant_docs_info = "\n".join(
-            [f"- {doc.module_name}" for doc in relevant_docs]
-        )
+        final_relevant_docs_info = []
+        for doc in relevant_docs:
+            info = f"- {doc.module_name.replace(self.path,'',1)}"
+            if 'original_docs' in doc.metadata:
+                original_docs = ", ".join([doc.replace(self.path,"",1) for doc in doc.metadata['original_docs']])
+                info += f" (Original docs: {original_docs})"
+            if "chunk_ranges" in doc.metadata:
+                chunk_ranges = json.dumps(doc.metadata['chunk_ranges'],ensure_ascii=False)
+                info += f" (Chunk ranges: {chunk_ranges})"
+            final_relevant_docs_info.append(info)
+
+        final_relevant_docs_info = "\n".join(final_relevant_docs_info)
         query_table.add_row("Final Relevant docs list", final_relevant_docs_info)
 
         # Create a panel to contain the table
@@ -409,8 +464,10 @@ class LongContextRAG:
 
         # Log the panel using rich
         console.print(panel)
-
-        logger.info(f"Start to send to model {model}")
+
+        request_tokens = sum([doc.tokens for doc in relevant_docs])
+        target_model = model or self.llm.default_model_name
+        logger.info(f"Start to send to model {target_model} with {request_tokens} tokens")
 
         new_conversations = conversations[:-1] + [
             {
@@ -2,29 +2,46 @@ import time
 from loguru import logger
 from tokenizers import Tokenizer
 from multiprocessing import Pool, cpu_count
+from autocoder.rag.variable_holder import TOKENIZER_MODEL
+
 
 class RemoteTokenCounter:
-    def __init__(self,tokenizer) -> None:
+    def __init__(self, tokenizer) -> None:
         self.tokenizer = tokenizer
 
-    def count_tokens(self, text: str) -> int:
-        try:
+    def count_tokens(self, text: str) -> int:
+        try:
             v = self.tokenizer.chat_oai(
                 conversations=[{"role": "user", "content": text}]
-            )
+            )
             return int(v[0].output)
         except Exception as e:
             logger.error(f"Error counting tokens: {str(e)}")
             return -1
-
+
+
 def initialize_tokenizer(tokenizer_path):
-    global tokenizer_model
+    global tokenizer_model
     tokenizer_model = Tokenizer.from_file(tokenizer_path)
 
+
+def count_tokens(text: str) -> int:
+    try:
+        # start_time = time.time_ns()
+        encoded = TOKENIZER_MODEL.encode('{"role":"user","content":"' + text + '"}')
+        v = len(encoded.ids)
+        # elapsed_time = time.time_ns() - start_time
+        # logger.info(f"Token counting took {elapsed_time/1000000} ms")
+        return v
+    except Exception as e:
+        logger.error(f"Error counting tokens: {str(e)}")
+        return -1
+
+
 def count_tokens_worker(text: str) -> int:
     try:
         # start_time = time.time_ns()
-        encoded = tokenizer_model.encode('{"role":"user","content":"'+text+'"}')
+        encoded = tokenizer_model.encode('{"role":"user","content":"' + text + '"}')
         v = len(encoded.ids)
         # elapsed_time = time.time_ns() - start_time
         # logger.info(f"Token counting took {elapsed_time/1000000} ms")
@@ -33,11 +50,16 @@ def count_tokens_worker(text: str) -> int:
         logger.error(f"Error counting tokens: {str(e)}")
         return -1
 
+
 class TokenCounter:
     def __init__(self, tokenizer_path: str):
         self.tokenizer_path = tokenizer_path
         self.num_processes = cpu_count() - 1 if cpu_count() > 1 else 1
-        self.pool = Pool(processes=self.num_processes, initializer=initialize_tokenizer, initargs=(self.tokenizer_path,))
+        self.pool = Pool(
+            processes=self.num_processes,
+            initializer=initialize_tokenizer,
+            initargs=(self.tokenizer_path,),
+        )
 
     def count_tokens(self, text: str) -> int:
-        return self.pool.apply(count_tokens_worker, (text,))
+        return self.pool.apply(count_tokens_worker, (text,))
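The new module-level count_tokens counts against the process-global TOKENIZER_MODEL held in autocoder.rag.variable_holder, wrapping the text in a chat-style JSON envelope so the count approximates the message the model actually receives. A standalone sketch of the same idea using the tokenizers library directly; the tokenizer file path is a placeholder:

# Standalone sketch of the counting approach above; "/path/to/tokenizer.json"
# stands in for any HuggingFace tokenizers file.
from tokenizers import Tokenizer

tokenizer = Tokenizer.from_file("/path/to/tokenizer.json")


def count_tokens(text: str) -> int:
    # Wrap the text the same way the diff does, so the count reflects the
    # chat message envelope rather than the bare string.
    encoded = tokenizer.encode('{"role":"user","content":"' + text + '"}')
    return len(encoded.ids)


print(count_tokens("hello world"))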