auto-coder 0.1.172__py3-none-any.whl → 0.1.175__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {auto_coder-0.1.172.dist-info → auto_coder-0.1.175.dist-info}/METADATA +3 -1
- {auto_coder-0.1.172.dist-info → auto_coder-0.1.175.dist-info}/RECORD +26 -24
- autocoder/agent/designer.py +385 -0
- autocoder/auto_coder.py +32 -8
- autocoder/auto_coder_lang.py +2 -0
- autocoder/auto_coder_rag.py +41 -13
- autocoder/chat_auto_coder.py +144 -21
- autocoder/chat_auto_coder_lang.py +3 -0
- autocoder/command_args.py +12 -2
- autocoder/common/__init__.py +11 -1
- autocoder/common/command_completer.py +4 -0
- autocoder/common/command_generator.py +4 -5
- autocoder/lang.py +2 -0
- autocoder/pyproject/__init__.py +5 -1
- autocoder/rag/document_retriever.py +196 -55
- autocoder/rag/long_context_rag.py +80 -23
- autocoder/rag/token_counter.py +31 -9
- autocoder/rag/token_limiter.py +34 -9
- autocoder/rag/variable_holder.py +2 -0
- autocoder/suffixproject/__init__.py +5 -1
- autocoder/tsproject/__init__.py +5 -1
- autocoder/version.py +1 -1
- {auto_coder-0.1.172.dist-info → auto_coder-0.1.175.dist-info}/LICENSE +0 -0
- {auto_coder-0.1.172.dist-info → auto_coder-0.1.175.dist-info}/WHEEL +0 -0
- {auto_coder-0.1.172.dist-info → auto_coder-0.1.175.dist-info}/entry_points.txt +0 -0
- {auto_coder-0.1.172.dist-info → auto_coder-0.1.175.dist-info}/top_level.txt +0 -0
autocoder/rag/document_retriever.py
CHANGED

@@ -18,10 +18,15 @@ from loguru import logger
 from pydantic import BaseModel

 from autocoder.common import SourceCode
-from autocoder.rag.loaders import (
-
-
-
+from autocoder.rag.loaders import (
+    extract_text_from_docx,
+    extract_text_from_excel,
+    extract_text_from_pdf,
+    extract_text_from_ppt,
+)
+from autocoder.rag import variable_holder
+from autocoder.rag.token_counter import count_tokens_worker, count_tokens
+from uuid import uuid4

 cache_lock = threading.Lock()

@@ -34,72 +39,62 @@ class AddOrUpdateEvent(BaseModel):
     file_infos: List[Tuple[str, str, float]]


-
-
+def process_file_in_multi_process(
+    file_info: Tuple[str, str, float]
+) -> List[SourceCode]:
     start_time = time.time()
     file_path, relative_path, _ = file_info
     try:
         if file_path.endswith(".pdf"):
             with open(file_path, "rb") as f:
                 content = extract_text_from_pdf(f.read())
-            v = [SourceCode(module_name=file_path, source_code=content)]
-        elif file_path.endswith(".docx"):
-            with open(file_path, "rb") as f:
-                content = extract_text_from_docx(f.read())
-            v = [SourceCode(module_name=f"##File: {file_path}", source_code=content)]
-        elif file_path.endswith(".xlsx") or file_path.endswith(".xls"):
-            sheets = extract_text_from_excel(file_path)
             v = [
                 SourceCode(
-                    module_name=
-                    source_code=
+                    module_name=file_path,
+                    source_code=content,
+                    tokens=count_tokens_worker(content),
                 )
-                for sheet in sheets
             ]
-        elif file_path.endswith(".pptx"):
-            slides = extract_text_from_ppt(file_path)
-            content = "".join(f"#{slide[0]}\n{slide[1]}\n\n" for slide in slides)
-            v = [SourceCode(module_name=f"##File: {file_path}", source_code=content)]
-        else:
-            with open(file_path, "r", encoding="utf-8") as f:
-                content = f.read()
-            v = [SourceCode(module_name=f"##File: {file_path}", source_code=content)]
-        logger.info(f"Load file {file_path} in {time.time() - start_time}")
-        return v
-    except Exception as e:
-        logger.error(f"Error processing file {file_path}: {str(e)}")
-        return []
-
-
-def process_file2(file_info: Tuple[str, str, float]) -> List[SourceCode]:
-    start_time = time.time()
-    file_path, relative_path, _ = file_info
-    try:
-        if file_path.endswith(".pdf"):
-            with open(file_path, "rb") as f:
-                content = extract_text_from_pdf(f.read())
-            v = [SourceCode(module_name=file_path, source_code=content)]
         elif file_path.endswith(".docx"):
             with open(file_path, "rb") as f:
                 content = extract_text_from_docx(f.read())
-            v = [
+            v = [
+                SourceCode(
+                    module_name=f"##File: {file_path}",
+                    source_code=content,
+                    tokens=count_tokens_worker(content),
+                )
+            ]
         elif file_path.endswith(".xlsx") or file_path.endswith(".xls"):
             sheets = extract_text_from_excel(file_path)
             v = [
                 SourceCode(
                     module_name=f"##File: {file_path}#{sheet[0]}",
                     source_code=sheet[1],
+                    tokens=count_tokens_worker(sheet[1]),
                 )
                 for sheet in sheets
             ]
         elif file_path.endswith(".pptx"):
             slides = extract_text_from_ppt(file_path)
             content = "".join(f"#{slide[0]}\n{slide[1]}\n\n" for slide in slides)
-            v = [
+            v = [
+                SourceCode(
+                    module_name=f"##File: {file_path}",
+                    source_code=content,
+                    tokens=count_tokens_worker(content),
+                )
+            ]
         else:
             with open(file_path, "r", encoding="utf-8") as f:
                 content = f.read()
-            v = [
+            v = [
+                SourceCode(
+                    module_name=f"##File: {file_path}",
+                    source_code=content,
+                    tokens=count_tokens_worker(content),
+                )
+            ]
         logger.info(f"Load file {file_path} in {time.time() - start_time}")
         return v
     except Exception as e:
@@ -107,34 +102,59 @@ def process_file2(file_info: Tuple[str, str, float]) -> List[SourceCode]:
         return []


-def
+def process_file_local(file_path: str) -> List[SourceCode]:
     start_time = time.time()
     try:
         if file_path.endswith(".pdf"):
             with open(file_path, "rb") as f:
                 content = extract_text_from_pdf(f.read())
-            v = [
+            v = [
+                SourceCode(
+                    module_name=file_path,
+                    source_code=content,
+                    tokens=count_tokens(content),
+                )
+            ]
         elif file_path.endswith(".docx"):
             with open(file_path, "rb") as f:
                 content = extract_text_from_docx(f.read())
-            v = [
+            v = [
+                SourceCode(
+                    module_name=f"##File: {file_path}",
+                    source_code=content,
+                    tokens=count_tokens(content),
+                )
+            ]
         elif file_path.endswith(".xlsx") or file_path.endswith(".xls"):
             sheets = extract_text_from_excel(file_path)
             v = [
                 SourceCode(
                     module_name=f"##File: {file_path}#{sheet[0]}",
                     source_code=sheet[1],
+                    tokens=count_tokens(sheet[1]),
                 )
                 for sheet in sheets
             ]
         elif file_path.endswith(".pptx"):
             slides = extract_text_from_ppt(file_path)
             content = "".join(f"#{slide[0]}\n{slide[1]}\n\n" for slide in slides)
-            v = [
+            v = [
+                SourceCode(
+                    module_name=f"##File: {file_path}",
+                    source_code=content,
+                    tokens=count_tokens(content),
+                )
+            ]
         else:
             with open(file_path, "r", encoding="utf-8") as f:
                 content = f.read()
-            v = [
+            v = [
+                SourceCode(
+                    module_name=f"##File: {file_path}",
+                    source_code=content,
+                    tokens=count_tokens(content),
+                )
+            ]
         logger.info(f"Load file {file_path} in {time.time() - start_time}")
         return v
     except Exception as e:
@@ -205,7 +225,7 @@ class AutoCoderRAGDocListener:
             self.update_cache(item)

     def update_cache(self, file_path):
-        source_code =
+        source_code = process_file_local(file_path)
         self.cache[file_path] = {
             "file_path": file_path,
             "content": [c.model_dump() for c in source_code],
@@ -220,7 +240,9 @@ class AutoCoderRAGDocListener:

     def open_watch(self):
         logger.info(f"start monitor: {self.path}...")
-        for changes in watch(
+        for changes in watch(
+            self.path, watch_filter=self.file_filter, stop_event=self.stop_event
+        ):
             for change in changes:
                 (action, path) = change
                 if action == Change.added or action == Change.modified:
@@ -290,7 +312,6 @@ class AutoCoderRAGAsyncUpdateQueue:
         self.thread.start()
         self.cache = self.read_cache()

-
     def _process_queue(self):
         while not self.stop_event.is_set():
             try:
@@ -324,8 +345,14 @@ class AutoCoderRAGAsyncUpdateQueue:
         # results = ray.get(
         #     [process_file.remote(file_info) for file_info in files_to_process]
         # )
-
-
+        from autocoder.rag.token_counter import initialize_tokenizer
+
+        with Pool(
+            processes=os.cpu_count(),
+            initializer=initialize_tokenizer,
+            initargs=(variable_holder.TOKENIZER_PATH,),
+        ) as pool:
+            results = pool.map(process_file_in_multi_process, files_to_process)

         for file_info, result in zip(files_to_process, results):
             self.update_cache(file_info, result)
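Note on the hunk above: the commented-out Ray call is replaced by a local multiprocessing.Pool whose initializer loads the tokenizer once per worker process, so process_file_in_multi_process can call count_tokens_worker without reloading the tokenizer for every file. Below is a minimal standalone sketch of that initializer pattern; the names init_worker and handle_task are hypothetical and not part of auto-coder.

import os
from multiprocessing import Pool

_model = None  # per-process state, analogous to tokenizer_model in token_counter.py

def init_worker(model_path: str) -> None:
    # Runs once in each worker process, before it receives any task.
    global _model
    _model = f"model loaded from {model_path}"  # stand-in for Tokenizer.from_file(...)

def handle_task(item: str) -> str:
    # Tasks reuse the per-process _model instead of reloading it each time.
    return f"{_model}: processed {item}"

if __name__ == "__main__":
    with Pool(processes=os.cpu_count(), initializer=init_worker,
              initargs=("tokenizer.json",)) as pool:  # hypothetical path
        print(pool.map(handle_task, ["a.md", "b.pdf"]))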
@@ -365,7 +392,7 @@ class AutoCoderRAGAsyncUpdateQueue:
         elif isinstance(file_list, AddOrUpdateEvent):
             for file_info in file_list.file_infos:
                 logger.info(f"{file_info[0]} is detected to be updated")
-                result =
+                result = process_file_local(file_info)
                 self.update_cache(file_info, result)

         self.write_cache()
@@ -410,7 +437,9 @@ class AutoCoderRAGAsyncUpdateQueue:
         # Release the file lock
         fcntl.flock(lockf, fcntl.LOCK_UN)

-    def update_cache(
+    def update_cache(
+        self, file_info: Tuple[str, str, float], content: List[SourceCode]
+    ):
         file_path, relative_path, modify_time = file_info
         self.cache[file_path] = {
             "file_path": file_path,
@@ -485,11 +514,20 @@ class DocumentRetriever:
         required_exts: list,
         on_ray: bool = False,
         monitor_mode: bool = False,
+        single_file_token_limit: int = 60000,
+        disable_auto_window: bool = False,
     ) -> None:
         self.path = path
         self.ignore_spec = ignore_spec
         self.required_exts = required_exts
         self.monitor_mode = monitor_mode
+        self.single_file_token_limit = single_file_token_limit
+        self.disable_auto_window = disable_auto_window
+
+        # Files smaller than this will be merged
+        self.small_file_token_limit = self.single_file_token_limit / 4
+        # Maximum size of a merged file
+        self.small_file_merge_limit = self.single_file_token_limit / 2

         self.on_ray = on_ray
         if self.on_ray:
@@ -502,6 +540,13 @@ class DocumentRetriever:
             path, ignore_spec, required_exts
         )

+        logger.info(f"DocumentRetriever initialized with:")
+        logger.info(f"  Path: {self.path}")
+        logger.info(f"  Diable auto window: {self.disable_auto_window} ")
+        logger.info(f"  Single file token limit: {self.single_file_token_limit}")
+        logger.info(f"  Small file token limit: {self.small_file_token_limit}")
+        logger.info(f"  Small file merge limit: {self.small_file_merge_limit}")
+
     def get_cache(self):
         if self.on_ray:
             return ray.get(self.cacher.get_cache.remote())
@@ -509,6 +554,102 @@ class DocumentRetriever:
             return self.cacher.get_cache()

     def retrieve_documents(self) -> Generator[SourceCode, None, None]:
+        logger.info("Starting document retrieval process")
+        waiting_list = []
+        waiting_tokens = 0
         for _, data in self.get_cache().items():
             for source_code in data["content"]:
-
+                doc = SourceCode.model_validate(source_code)
+                if self.disable_auto_window:
+                    yield doc
+                else:
+                    if doc.tokens <= 0:
+                        yield doc
+                    elif doc.tokens < self.small_file_token_limit:
+                        waiting_list, waiting_tokens = self._add_to_waiting_list(
+                            doc, waiting_list, waiting_tokens
+                        )
+                        if waiting_tokens >= self.small_file_merge_limit:
+                            yield from self._process_waiting_list(waiting_list)
+                            waiting_list = []
+                            waiting_tokens = 0
+                    elif doc.tokens > self.single_file_token_limit:
+                        yield from self._split_large_document(doc)
+                    else:
+                        yield doc
+        if waiting_list and not self.disable_auto_window:
+            yield from self._process_waiting_list(waiting_list)
+
+        logger.info("Document retrieval process completed")
+
+    def _add_to_waiting_list(
+        self, doc: SourceCode, waiting_list: List[SourceCode], waiting_tokens: int
+    ) -> Tuple[List[SourceCode], int]:
+        waiting_list.append(doc)
+        return waiting_list, waiting_tokens + doc.tokens
+
+    def _process_waiting_list(
+        self, waiting_list: List[SourceCode]
+    ) -> Generator[SourceCode, None, None]:
+        if len(waiting_list) == 1:
+            yield waiting_list[0]
+        elif len(waiting_list) > 1:
+            yield self._merge_documents(waiting_list)
+
+    def _merge_documents(self, docs: List[SourceCode]) -> SourceCode:
+        merged_content = "\n".join(
+            [f"#File: {doc.module_name}\n{doc.source_code}" for doc in docs]
+        )
+        merged_tokens = sum([doc.tokens for doc in docs])
+        merged_name = f"Merged_{len(docs)}_docs_{str(uuid4())}"
+        logger.info(
+            f"Merged {len(docs)} documents into {merged_name} (tokens: {merged_tokens})."
+        )
+        return SourceCode(
+            module_name=merged_name,
+            source_code=merged_content,
+            tokens=merged_tokens,
+            metadata={"original_docs": [doc.module_name for doc in docs]},
+        )
+
+    def _split_large_document(
+        self, doc: SourceCode
+    ) -> Generator[SourceCode, None, None]:
+        chunk_size = self.single_file_token_limit
+        total_chunks = (doc.tokens + chunk_size - 1) // chunk_size
+        logger.info(f"Splitting document {doc.module_name} into {total_chunks} chunks")
+        for i in range(0, doc.tokens, chunk_size):
+            chunk_content = doc.source_code[i : i + chunk_size]
+            chunk_tokens = min(chunk_size, doc.tokens - i)
+            chunk_name = f"{doc.module_name}#chunk{i//chunk_size+1}"
+            # logger.debug(f"  Created chunk: {chunk_name} (tokens: {chunk_tokens})")
+            yield SourceCode(
+                module_name=chunk_name,
+                source_code=chunk_content,
+                tokens=chunk_tokens,
+                metadata={
+                    "original_doc": doc.module_name,
+                    "chunk_index": i // chunk_size + 1,
+                },
+            )
+
+    def _split_document(
+        self, doc: SourceCode, token_limit: int
+    ) -> Generator[SourceCode, None, None]:
+        remaining_tokens = doc.tokens
+        chunk_number = 1
+        start_index = 0
+
+        while remaining_tokens > 0:
+            end_index = start_index + token_limit
+            chunk_content = doc.source_code[start_index:end_index]
+            chunk_tokens = min(token_limit, remaining_tokens)
+
+            chunk_name = f"{doc.module_name}#{chunk_number:06d}"
+            yield SourceCode(
+                module_name=chunk_name, source_code=chunk_content, tokens=chunk_tokens
+            )
+
+            start_index = end_index
+            remaining_tokens -= chunk_tokens
+            chunk_number += 1
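The retrieve_documents rewrite above implements an "auto window" policy: documents under small_file_token_limit are buffered and merged once the buffer reaches small_file_merge_limit, documents over single_file_token_limit are split into chunks, and everything in between is yielded unchanged. A rough standalone sketch of that policy follows, using (name, tokens) pairs instead of SourceCode; the limits are derived from the 60000-token default shown above, and integer division is used for simplicity.

from typing import Iterator, List, Tuple

SINGLE_FILE_LIMIT = 60000                  # default single_file_token_limit above
SMALL_FILE_LIMIT = SINGLE_FILE_LIMIT // 4  # files below this are candidates for merging
MERGE_LIMIT = SINGLE_FILE_LIMIT // 2       # flush the merge buffer once it reaches this

Doc = Tuple[str, int]  # (name, token count) stand-in for SourceCode

def auto_window(docs: List[Doc]) -> Iterator[Doc]:
    waiting: List[Doc] = []
    waiting_tokens = 0
    for name, tokens in docs:
        if tokens < SMALL_FILE_LIMIT:
            waiting.append((name, tokens))
            waiting_tokens += tokens
            if waiting_tokens >= MERGE_LIMIT:
                yield (f"Merged_{len(waiting)}_docs", waiting_tokens)
                waiting, waiting_tokens = [], 0
        elif tokens > SINGLE_FILE_LIMIT:
            # Split into ceil(tokens / SINGLE_FILE_LIMIT) chunks, like _split_large_document.
            for i in range(0, tokens, SINGLE_FILE_LIMIT):
                yield (f"{name}#chunk{i // SINGLE_FILE_LIMIT + 1}",
                       min(SINGLE_FILE_LIMIT, tokens - i))
        else:
            yield (name, tokens)
    if waiting:  # flush whatever is still buffered at the end
        yield (f"Merged_{len(waiting)}_docs", waiting_tokens)

print(list(auto_window([("a.md", 14000), ("b.md", 14000), ("c.md", 5000), ("big.md", 130000)])))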
autocoder/rag/long_context_rag.py
CHANGED

@@ -13,16 +13,22 @@ from openai import OpenAI
 from rich.console import Console
 from rich.panel import Panel
 from rich.table import Table
-
+import statistics

 from autocoder.common import AutoCoderArgs, SourceCode
 from autocoder.rag.doc_filter import DocFilter
 from autocoder.rag.document_retriever import DocumentRetriever
-from autocoder.rag.relevant_utils import (
-
+from autocoder.rag.relevant_utils import (
+    DocRelevance,
+    FilterDoc,
+    TaskTiming,
+    parse_relevance,
+)
 from autocoder.rag.token_checker import check_token_limit
 from autocoder.rag.token_counter import RemoteTokenCounter, TokenCounter
 from autocoder.rag.token_limiter import TokenLimiter
+from tokenizers import Tokenizer
+from autocoder.rag import variable_holder


 class LongContextRAG:
@@ -44,11 +50,26 @@ class LongContextRAG:
         self.path = path
         self.relevant_score = self.args.rag_doc_filter_relevance or 5

+        self.full_text_ratio = args.full_text_ratio
+        self.segment_ratio = args.segment_ratio
+        self.buff_ratio = 1 - self.full_text_ratio - self.segment_ratio
+
+        if self.buff_ratio < 0:
+            raise ValueError(
+                "The sum of full_text_ratio and segment_ratio must be less than or equal to 1.0"
+            )
+
+        self.full_text_limit = int(args.rag_context_window_limit * self.full_text_ratio)
+        self.segment_limit = int(args.rag_context_window_limit * self.segment_ratio)
+        self.buff_limit = int(args.rag_context_window_limit * self.buff_ratio)
+
         self.tokenizer = None
         self.tokenizer_path = tokenizer_path
         self.on_ray = False

         if self.tokenizer_path:
+            variable_holder.TOKENIZER_PATH = self.tokenizer_path
+            variable_holder.TOKENIZER_MODEL = Tokenizer.from_file(self.tokenizer_path)
             self.tokenizer = TokenCounter(self.tokenizer_path)
         else:
             if llm.is_model_exist("deepseek_tokenizer"):
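The constructor changes above split the RAG context window into a full-text zone, a segment zone, and a buffer zone by ratio, and reject configurations whose ratios sum to more than 1. A small worked example with hypothetical values (the window size and ratios below are illustrative, not the package defaults):

rag_context_window_limit = 120_000  # hypothetical window size
full_text_ratio = 0.5               # hypothetical
segment_ratio = 0.25                # hypothetical
buff_ratio = 1 - full_text_ratio - segment_ratio  # 0.25; must not go negative

full_text_limit = int(rag_context_window_limit * full_text_ratio)  # 60000
segment_limit = int(rag_context_window_limit * segment_ratio)      # 30000
buff_limit = int(rag_context_window_limit * buff_ratio)            # 30000

# DocumentRetriever is then given single_file_token_limit = full_text_limit - 100,
# so at least one whole file can always fit in the full-text zone.
print(full_text_limit, segment_limit, buff_limit)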
@@ -96,24 +117,41 @@ class LongContextRAG:
             self.required_exts,
             self.on_ray,
             self.monitor_mode,
+            ## Make sure the full-text zone can hold at least one file
+            single_file_token_limit=self.full_text_limit - 100,
+            disable_auto_window=self.args.disable_auto_window
         )

         self.doc_filter = DocFilter(
             self.index_model, self.args, on_ray=self.on_ray, path=self.path
         )
-
-
-
-
-
-
-
-
-
-
+
+        doc_num = 0
+        token_num = 0
+        token_counts = []
+        for doc in self._retrieve_documents():
+            doc_num += 1
+            doc_tokens = doc.tokens
+            token_num += doc_tokens
+            token_counts.append(doc_tokens)
+
+        avg_tokens = statistics.mean(token_counts) if token_counts else 0
+        median_tokens = statistics.median(token_counts) if token_counts else 0

         logger.info(
-
+            "RAG Configuration:\n"
+            f"  Total docs: {doc_num}\n"
+            f"  Total tokens: {token_num}\n"
+            f"  Tokenizer path: {self.tokenizer_path}\n"
+            f"  Relevant score: {self.relevant_score}\n"
+            f"  Token limit: {self.token_limit}\n"
+            f"  Full text limit: {self.full_text_limit}\n"
+            f"  Segment limit: {self.segment_limit}\n"
+            f"  Buff limit: {self.buff_limit}\n"
+            f"  Max doc tokens: {max(token_counts) if token_counts else 0}\n"
+            f"  Min doc tokens: {min(token_counts) if token_counts else 0}\n"
+            f"  Avg doc tokens: {avg_tokens:.2f}\n"
+            f"  Median doc tokens: {median_tokens:.2f}\n"
         )

     def count_tokens(self, text: str) -> int:
@@ -350,9 +388,15 @@ class LongContextRAG:
             query_table.add_row("Relevant docs", str(len(relevant_docs)))

             # Add relevant docs information
-            relevant_docs_info =
-
-
+            relevant_docs_info = []
+            for doc in relevant_docs:
+                info = f"- {doc.module_name.replace(self.path,'',1)}"
+                if 'original_docs' in doc.metadata:
+                    original_docs = ", ".join([doc.replace(self.path,"",1) for doc in doc.metadata['original_docs']])
+                    info += f" (Original docs: {original_docs})"
+                relevant_docs_info.append(info)
+
+            relevant_docs_info = "\n".join(relevant_docs_info)
             query_table.add_row("Relevant docs list", relevant_docs_info)

             first_round_full_docs = []
@@ -363,7 +407,9 @@ class LongContextRAG:

         token_limiter = TokenLimiter(
             count_tokens=self.count_tokens,
-
+            full_text_limit=self.full_text_limit,
+            segment_limit=self.segment_limit,
+            buff_limit=self.buff_limit,
             llm=self.llm,
         )
         final_relevant_docs = token_limiter.limit_tokens(
@@ -395,9 +441,18 @@ class LongContextRAG:
             )

             # Add relevant docs information
-            final_relevant_docs_info =
-
-
+            final_relevant_docs_info = []
+            for doc in relevant_docs:
+                info = f"- {doc.module_name.replace(self.path,'',1)}"
+                if 'original_docs' in doc.metadata:
+                    original_docs = ", ".join([doc.replace(self.path,"",1) for doc in doc.metadata['original_docs']])
+                    info += f" (Original docs: {original_docs})"
+                if "chunk_ranges" in doc.metadata:
+                    chunk_ranges = json.dumps(doc.metadata['chunk_ranges'],ensure_ascii=False)
+                    info += f" (Chunk ranges: {chunk_ranges})"
+                final_relevant_docs_info.append(info)
+
+            final_relevant_docs_info = "\n".join(final_relevant_docs_info)
             query_table.add_row("Final Relevant docs list", final_relevant_docs_info)

             # Create a panel to contain the table
@@ -409,8 +464,10 @@ class LongContextRAG:

             # Log the panel using rich
             console.print(panel)
-
-
+
+        request_tokens = sum([doc.tokens for doc in relevant_docs])
+        target_model = model or self.llm.default_model_name
+        logger.info(f"Start to send to model {target_model} with {request_tokens} tokens")

         new_conversations = conversations[:-1] + [
             {
autocoder/rag/token_counter.py
CHANGED
@@ -2,29 +2,46 @@ import time
 from loguru import logger
 from tokenizers import Tokenizer
 from multiprocessing import Pool, cpu_count
+from autocoder.rag.variable_holder import TOKENIZER_MODEL
+

 class RemoteTokenCounter:
-    def __init__(self,tokenizer) -> None:
+    def __init__(self, tokenizer) -> None:
         self.tokenizer = tokenizer

-    def count_tokens(self, text: str) -> int:
-        try:
+    def count_tokens(self, text: str) -> int:
+        try:
             v = self.tokenizer.chat_oai(
                 conversations=[{"role": "user", "content": text}]
-            )
+            )
             return int(v[0].output)
         except Exception as e:
             logger.error(f"Error counting tokens: {str(e)}")
             return -1
-
+
+
 def initialize_tokenizer(tokenizer_path):
-    global tokenizer_model
+    global tokenizer_model
     tokenizer_model = Tokenizer.from_file(tokenizer_path)

+
+def count_tokens(text: str) -> int:
+    try:
+        # start_time = time.time_ns()
+        encoded = TOKENIZER_MODEL.encode('{"role":"user","content":"' + text + '"}')
+        v = len(encoded.ids)
+        # elapsed_time = time.time_ns() - start_time
+        # logger.info(f"Token counting took {elapsed_time/1000000} ms")
+        return v
+    except Exception as e:
+        logger.error(f"Error counting tokens: {str(e)}")
+        return -1
+
+
 def count_tokens_worker(text: str) -> int:
     try:
         # start_time = time.time_ns()
-        encoded = tokenizer_model.encode('{"role":"user","content":"'+text+'"}')
+        encoded = tokenizer_model.encode('{"role":"user","content":"' + text + '"}')
         v = len(encoded.ids)
         # elapsed_time = time.time_ns() - start_time
         # logger.info(f"Token counting took {elapsed_time/1000000} ms")
@@ -33,11 +50,16 @@ def count_tokens_worker(text: str) -> int:
         logger.error(f"Error counting tokens: {str(e)}")
         return -1

+
 class TokenCounter:
     def __init__(self, tokenizer_path: str):
         self.tokenizer_path = tokenizer_path
         self.num_processes = cpu_count() - 1 if cpu_count() > 1 else 1
-        self.pool = Pool(
+        self.pool = Pool(
+            processes=self.num_processes,
+            initializer=initialize_tokenizer,
+            initargs=(self.tokenizer_path,),
+        )

     def count_tokens(self, text: str) -> int:
-        return self.pool.apply(count_tokens_worker, (text,))
+        return self.pool.apply(count_tokens_worker, (text,))