auto-coder 0.1.183__py3-none-any.whl → 0.1.184__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of auto-coder might be problematic.

@@ -4,15 +4,9 @@ import platform
 import time
 import traceback
 
-from watchfiles import Change, DefaultFilter, awatch, watch
-
-if platform.system() != "Windows":
-    import fcntl
-else:
-    fcntl = None
 import threading
 from multiprocessing import Pool
-from typing import Dict, Generator, List, Tuple
+from typing import Dict, Generator, List, Tuple, Any, Optional
 
 import ray
 from loguru import logger
@@ -28,465 +22,17 @@ from autocoder.rag.loaders import (
 from autocoder.rag.token_counter import count_tokens_worker, count_tokens
 from uuid import uuid4
 from autocoder.rag.variable_holder import VariableHolder
+from abc import ABC, abstractmethod
+from autocoder.rag.cache.base_cache import BaseCacheManager
+from autocoder.rag.cache.simple_cache import AutoCoderRAGAsyncUpdateQueue
+from autocoder.rag.cache.file_monitor_cache import AutoCoderRAGDocListener
+from autocoder.rag.cache.byzer_storage_cache import ByzerStorageCache
+from autocoder.rag.utils import process_file_in_multi_process, process_file_local
+from autocoder.common import AutoCoderArgs
 
 cache_lock = threading.Lock()
 
 
-class DeleteEvent(BaseModel):
-    file_paths: List[str]
-
-
-class AddOrUpdateEvent(BaseModel):
-    file_infos: List[Tuple[str, str, float]]
-
-
-def process_file_in_multi_process(
-    file_info: Tuple[str, str, float]
-) -> List[SourceCode]:
-    start_time = time.time()
-    file_path, relative_path, _ = file_info
-    try:
-        if file_path.endswith(".pdf"):
-            with open(file_path, "rb") as f:
-                content = extract_text_from_pdf(f.read())
-            v = [
-                SourceCode(
-                    module_name=file_path,
-                    source_code=content,
-                    tokens=count_tokens_worker(content),
-                )
-            ]
-        elif file_path.endswith(".docx"):
-            with open(file_path, "rb") as f:
-                content = extract_text_from_docx(f.read())
-            v = [
-                SourceCode(
-                    module_name=f"##File: {file_path}",
-                    source_code=content,
-                    tokens=count_tokens_worker(content),
-                )
-            ]
-        elif file_path.endswith(".xlsx") or file_path.endswith(".xls"):
-            sheets = extract_text_from_excel(file_path)
-            v = [
-                SourceCode(
-                    module_name=f"##File: {file_path}#{sheet[0]}",
-                    source_code=sheet[1],
-                    tokens=count_tokens_worker(sheet[1]),
-                )
-                for sheet in sheets
-            ]
-        elif file_path.endswith(".pptx"):
-            slides = extract_text_from_ppt(file_path)
-            content = "".join(f"#{slide[0]}\n{slide[1]}\n\n" for slide in slides)
-            v = [
-                SourceCode(
-                    module_name=f"##File: {file_path}",
-                    source_code=content,
-                    tokens=count_tokens_worker(content),
-                )
-            ]
-        else:
-            with open(file_path, "r", encoding="utf-8") as f:
-                content = f.read()
-            v = [
-                SourceCode(
-                    module_name=f"##File: {file_path}",
-                    source_code=content,
-                    tokens=count_tokens_worker(content),
-                )
-            ]
-        logger.info(f"Load file {file_path} in {time.time() - start_time}")
-        return v
-    except Exception as e:
-        logger.error(f"Error processing file {file_path}: {str(e)}")
-        return []
-
-
-def process_file_local(file_path: str) -> List[SourceCode]:
-    start_time = time.time()
-    try:
-        if file_path.endswith(".pdf"):
-            with open(file_path, "rb") as f:
-                content = extract_text_from_pdf(f.read())
-            v = [
-                SourceCode(
-                    module_name=file_path,
-                    source_code=content,
-                    tokens=count_tokens(content),
-                )
-            ]
-        elif file_path.endswith(".docx"):
-            with open(file_path, "rb") as f:
-                content = extract_text_from_docx(f.read())
-            v = [
-                SourceCode(
-                    module_name=f"##File: {file_path}",
-                    source_code=content,
-                    tokens=count_tokens(content),
-                )
-            ]
-        elif file_path.endswith(".xlsx") or file_path.endswith(".xls"):
-            sheets = extract_text_from_excel(file_path)
-            v = [
-                SourceCode(
-                    module_name=f"##File: {file_path}#{sheet[0]}",
-                    source_code=sheet[1],
-                    tokens=count_tokens(sheet[1]),
-                )
-                for sheet in sheets
-            ]
-        elif file_path.endswith(".pptx"):
-            slides = extract_text_from_ppt(file_path)
-            content = "".join(f"#{slide[0]}\n{slide[1]}\n\n" for slide in slides)
-            v = [
-                SourceCode(
-                    module_name=f"##File: {file_path}",
-                    source_code=content,
-                    tokens=count_tokens(content),
-                )
-            ]
-        else:
-            with open(file_path, "r", encoding="utf-8") as f:
-                content = f.read()
-            v = [
-                SourceCode(
-                    module_name=f"##File: {file_path}",
-                    source_code=content,
-                    tokens=count_tokens(content),
-                )
-            ]
-        logger.info(f"Load file {file_path} in {time.time() - start_time}")
-        return v
-    except Exception as e:
-        logger.error(f"Error processing file {file_path}: {str(e)}")
-        traceback.print_exc()
-        return []
-
-
-class AutoCoderRAGDocListener:
-    cache: Dict[str, Dict] = {}
-    ignore_dirs = [
-        "__pycache__",
-        ".git",
-        ".hg",
-        ".svn",
-        ".tox",
-        ".venv",
-        ".cache",
-        ".idea",
-        "node_modules",
-        ".mypy_cache",
-        ".pytest_cache",
-        ".hypothesis",
-    ]
-    ignore_entity_patterns = [
-        r"\.py[cod]$",
-        r"\.___jb_...___$",
-        r"\.sw.$",
-        "~$",
-        r"^\.\#",
-        r"^\.DS_Store$",
-        r"^flycheck_",
-        r"^test.*$",
-    ]
-
-    def __init__(self, path: str, ignore_spec, required_exts: List) -> None:
-        self.path = path
-        self.ignore_spec = ignore_spec
-        self.required_exts = required_exts
-        self.stop_event = threading.Event()
-
-        # connect list
-        self.ignore_entity_patterns.extend(self._load_ignore_file())
-        self.file_filter = DefaultFilter(
-            ignore_dirs=self.ignore_dirs,
-            ignore_paths=[],
-            ignore_entity_patterns=self.ignore_entity_patterns,
-        )
-        self.load_first()
-        # run open_watch in a new thread
-        self.watch_thread = threading.Thread(target=self.open_watch)
-        # make it a daemon thread so it exits automatically when the main program exits
-        self.watch_thread.daemon = True
-        # start the thread
-        self.watch_thread.start()
-
-    def stop(self):
-        self.stop_event.set()
-        self.watch_thread.join()
-
-    def __del__(self):
-        self.stop()
-
-    def load_first(self):
-        files_to_process = self.get_all_files()
-        if not files_to_process:
-            return
-        for item in files_to_process:
-            self.update_cache(item)
-
-    def update_cache(self, file_path):
-        source_code = process_file_local(file_path)
-        self.cache[file_path] = {
-            "file_path": file_path,
-            "content": [c.model_dump() for c in source_code],
-        }
-        logger.info(f"update cache: {file_path}")
-        logger.info(f"current cache: {self.cache.keys()}")
-
-    def remove_cache(self, file_path):
-        del self.cache[file_path]
-        logger.info(f"remove cache: {file_path}")
-        logger.info(f"current cache: {self.cache.keys()}")
-
-    def open_watch(self):
-        logger.info(f"start monitor: {self.path}...")
-        for changes in watch(
-            self.path, watch_filter=self.file_filter, stop_event=self.stop_event
-        ):
-            for change in changes:
-                (action, path) = change
-                if action == Change.added or action == Change.modified:
-                    self.update_cache(path)
-                elif action == Change.deleted:
-                    self.remove_cache(path)
-
-    def get_cache(self):
-        return self.cache
-
-    def _load_ignore_file(self):
-        serveignore_path = os.path.join(self.path, ".serveignore")
-        gitignore_path = os.path.join(self.path, ".gitignore")
-
-        if os.path.exists(serveignore_path):
-            with open(serveignore_path, "r") as ignore_file:
-                patterns = ignore_file.readlines()
-                return [pattern.strip() for pattern in patterns]
-        elif os.path.exists(gitignore_path):
-            with open(gitignore_path, "r") as ignore_file:
-                patterns = ignore_file.readlines()
-                return [pattern.strip() for pattern in patterns]
-        return []
-
-    def get_all_files(self) -> List[str]:
-        all_files = []
-        for root, dirs, files in os.walk(self.path):
-            dirs[:] = [d for d in dirs if not d.startswith(".")]
-
-            if self.ignore_spec:
-                relative_root = os.path.relpath(root, self.path)
-                dirs[:] = [
-                    d
-                    for d in dirs
-                    if not self.ignore_spec.match_file(os.path.join(relative_root, d))
-                ]
-                files = [
-                    f
-                    for f in files
-                    if not self.ignore_spec.match_file(os.path.join(relative_root, f))
-                ]
-
-            for file in files:
-                if self.required_exts and not any(
-                    file.endswith(ext) for ext in self.required_exts
-                ):
-                    continue
-
-                file_path = os.path.join(root, file)
-                absolute_path = os.path.abspath(file_path)
-                all_files.append(absolute_path)
-
-        return all_files
-
-
-class AutoCoderRAGAsyncUpdateQueue:
-    def __init__(self, path: str, ignore_spec, required_exts: list):
-        self.path = path
-        self.ignore_spec = ignore_spec
-        self.required_exts = required_exts
-        self.queue = []
-        self.cache = {}
-        self.lock = threading.Lock()
-        self.stop_event = threading.Event()
-        self.thread = threading.Thread(target=self._process_queue)
-        self.thread.daemon = True
-        self.thread.start()
-        self.cache = self.read_cache()
-
-    def _process_queue(self):
-        while not self.stop_event.is_set():
-            try:
-                self.process_queue()
-            except Exception as e:
-                logger.error(f"Error in process_queue: {e}")
-            time.sleep(1)  # avoid checking too frequently
-
-    def stop(self):
-        self.stop_event.set()
-        self.thread.join()
-
-    def __del__(self):
-        self.stop()
-
-    def load_first(self):
-        with self.lock:
-            if self.cache:
-                return
-            files_to_process = []
-            for file_info in self.get_all_files():
-                file_path, _, modify_time = file_info
-                if (
-                    file_path not in self.cache
-                    or self.cache[file_path]["modify_time"] < modify_time
-                ):
-                    files_to_process.append(file_info)
-            if not files_to_process:
-                return
-            # remote_process_file = ray.remote(process_file)
-            # results = ray.get(
-            #     [process_file.remote(file_info) for file_info in files_to_process]
-            # )
-            from autocoder.rag.token_counter import initialize_tokenizer
-
-            with Pool(
-                processes=os.cpu_count(),
-                initializer=initialize_tokenizer,
-                initargs=(VariableHolder.TOKENIZER_PATH,),
-            ) as pool:
-                results = pool.map(process_file_in_multi_process, files_to_process)
-
-            for file_info, result in zip(files_to_process, results):
-                self.update_cache(file_info, result)
-
-            self.write_cache()
-
-    def trigger_update(self):
-        logger.info("检查文件是否有更新.....")
-        files_to_process = []
-        current_files = set()
-        for file_info in self.get_all_files():
-            file_path, _, modify_time = file_info
-            current_files.add(file_path)
-            if (
-                file_path not in self.cache
-                or self.cache[file_path]["modify_time"] < modify_time
-            ):
-                files_to_process.append(file_info)
-
-        deleted_files = set(self.cache.keys()) - current_files
-        logger.info(f"files_to_process: {files_to_process}")
-        logger.info(f"deleted_files: {deleted_files}")
-        if deleted_files:
-            with self.lock:
-                self.queue.append(DeleteEvent(file_paths=deleted_files))
-        if files_to_process:
-            with self.lock:
-                self.queue.append(AddOrUpdateEvent(file_infos=files_to_process))
-
-    def process_queue(self):
-        while self.queue:
-            file_list = self.queue.pop(0)
-            if isinstance(file_list, DeleteEvent):
-                for item in file_list.file_paths:
-                    logger.info(f"{item} is detected to be removed")
-                    del self.cache[item]
-            elif isinstance(file_list, AddOrUpdateEvent):
-                for file_info in file_list.file_infos:
-                    logger.info(f"{file_info[0]} is detected to be updated")
-                    result = process_file_local(file_info[0])
-                    self.update_cache(file_info, result)
-
-            self.write_cache()
-
-    def read_cache(self) -> Dict[str, Dict]:
-        cache_dir = os.path.join(self.path, ".cache")
-        cache_file = os.path.join(cache_dir, "cache.jsonl")
-
-        if not os.path.exists(cache_dir):
-            os.makedirs(cache_dir)
-
-        cache = {}
-        if os.path.exists(cache_file):
-            with open(cache_file, "r") as f:
-                for line in f:
-                    data = json.loads(line)
-                    cache[data["file_path"]] = data
-        return cache
-
-    def write_cache(self):
-        cache_dir = os.path.join(self.path, ".cache")
-        cache_file = os.path.join(cache_dir, "cache.jsonl")
-
-        if not fcntl:
-            with open(cache_file, "w") as f:
-                for data in self.cache.values():
-                    json.dump(data, f, ensure_ascii=False)
-                    f.write("\n")
-        else:
-            lock_file = cache_file + ".lock"
-            with open(lock_file, "w") as lockf:
-                try:
-                    # acquire the file lock
-                    fcntl.flock(lockf, fcntl.LOCK_EX | fcntl.LOCK_NB)
-                    # write the cache file
-                    with open(cache_file, "w") as f:
-                        for data in self.cache.values():
-                            json.dump(data, f, ensure_ascii=False)
-                            f.write("\n")
-
-                finally:
-                    # release the file lock
-                    fcntl.flock(lockf, fcntl.LOCK_UN)
-
-    def update_cache(
-        self, file_info: Tuple[str, str, float], content: List[SourceCode]
-    ):
-        file_path, relative_path, modify_time = file_info
-        self.cache[file_path] = {
-            "file_path": file_path,
-            "relative_path": relative_path,
-            "content": [c.model_dump() for c in content],
-            "modify_time": modify_time,
-        }
-
-    def get_cache(self):
-        self.load_first()
-        self.trigger_update()
-        return self.cache
-
-    def get_all_files(self) -> List[Tuple[str, str, float]]:
-        all_files = []
-        for root, dirs, files in os.walk(self.path):
-            dirs[:] = [d for d in dirs if not d.startswith(".")]
-
-            if self.ignore_spec:
-                relative_root = os.path.relpath(root, self.path)
-                dirs[:] = [
-                    d
-                    for d in dirs
-                    if not self.ignore_spec.match_file(os.path.join(relative_root, d))
-                ]
-                files = [
-                    f
-                    for f in files
-                    if not self.ignore_spec.match_file(os.path.join(relative_root, f))
-                ]
-
-            for file in files:
-                if self.required_exts and not any(
-                    file.endswith(ext) for ext in self.required_exts
-                ):
-                    continue
-
-                file_path = os.path.join(root, file)
-                relative_path = os.path.relpath(file_path, self.path)
-                modify_time = os.path.getmtime(file_path)
-                all_files.append((file_path, relative_path, modify_time))
-
-        return all_files
-
-
 def get_or_create_actor(path: str, ignore_spec, required_exts: list, cacher={}):
     with cache_lock:
         # process the path name
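The net effect of this hunk is a refactor: the event models, the file loaders, and both cache implementations move out of this module and are imported from autocoder.rag.cache.* and autocoder.rag.utils instead. A minimal, illustrative sketch of the new import surface, assuming the relocated classes keep the constructor and call signatures visible in the deleted code above (the paths and extensions are hypothetical):

    from autocoder.rag.cache.simple_cache import AutoCoderRAGAsyncUpdateQueue
    from autocoder.rag.utils import process_file_local

    # Same shapes as the deleted in-module code: a polling cache over a directory,
    # and a single-file loader that returns a list of SourceCode objects.
    cacher = AutoCoderRAGAsyncUpdateQueue("/data/docs", None, [".md"])  # (path, ignore_spec, required_exts)
    cache = cacher.get_cache()                      # dict keyed by absolute file path
    docs = process_file_local("/data/docs/readme.md")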
@@ -508,7 +54,25 @@ def get_or_create_actor(path: str, ignore_spec, required_exts: list, cacher={}):
         return actor
 
 
-class DocumentRetriever:
+class BaseDocumentRetriever(ABC):
+    """Abstract base class for document retrieval."""
+
+    @abstractmethod
+    def get_cache(self, options: Optional[Dict[str, Any]] = None):
+        """Get cached documents."""
+        pass
+
+    @abstractmethod
+    def retrieve_documents(
+        self, options: Optional[Dict[str, Any]] = None
+    ) -> Generator[SourceCode, None, None]:
+        """Retrieve documents."""
+        pass
+
+
+class LocalDocumentRetriever(BaseDocumentRetriever):
+    """Local filesystem document retriever implementation."""
+
     def __init__(
         self,
         path: str,
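BaseDocumentRetriever is new in this release and pins down the retriever contract that LocalDocumentRetriever (the renamed DocumentRetriever) implements. A hypothetical subclass, just to show the two abstract methods and the cache shape the local implementation iterates over; InMemoryDocumentRetriever and its docs dict are invented for illustration:

    from typing import Any, Dict, Generator, Optional

    from autocoder.common import SourceCode
    from autocoder.rag.document_retriever import BaseDocumentRetriever


    class InMemoryDocumentRetriever(BaseDocumentRetriever):
        """Hypothetical retriever over an in-memory dict of path -> text."""

        def __init__(self, docs: Dict[str, str]) -> None:
            self.docs = docs

        def get_cache(self, options: Optional[Dict[str, Any]] = None):
            # Mirror the cache layout used elsewhere: path -> {"content": [SourceCode dumps]}
            return {
                path: {"content": [SourceCode(module_name=path, source_code=text, tokens=0).model_dump()]}
                for path, text in self.docs.items()
            }

        def retrieve_documents(
            self, options: Optional[Dict[str, Any]] = None
        ) -> Generator[SourceCode, None, None]:
            for data in self.get_cache(options).values():
                for source_code in data["content"]:
                    yield SourceCode.model_validate(source_code)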
@@ -518,11 +82,14 @@ class DocumentRetriever:
         monitor_mode: bool = False,
         single_file_token_limit: int = 60000,
         disable_auto_window: bool = False,
+        enable_hybrid_index: bool = False,
+        extra_params: Optional[AutoCoderArgs] = None,
     ) -> None:
         self.path = path
         self.ignore_spec = ignore_spec
         self.required_exts = required_exts
         self.monitor_mode = monitor_mode
+        self.enable_hybrid_index = enable_hybrid_index
         self.single_file_token_limit = single_file_token_limit
         self.disable_auto_window = disable_auto_window
 
@@ -535,12 +102,19 @@ class DocumentRetriever:
         if self.on_ray:
             self.cacher = get_or_create_actor(path, ignore_spec, required_exts)
         else:
-            if self.monitor_mode:
-                self.cacher = AutoCoderRAGDocListener(path, ignore_spec, required_exts)
-            else:
-                self.cacher = AutoCoderRAGAsyncUpdateQueue(
-                    path, ignore_spec, required_exts
+            if self.enable_hybrid_index:
+                self.cacher = ByzerStorageCache(
+                    path, ignore_spec, required_exts, extra_params
                 )
+            else:
+                if self.monitor_mode:
+                    self.cacher = AutoCoderRAGDocListener(
+                        path, ignore_spec, required_exts
+                    )
+                else:
+                    self.cacher = AutoCoderRAGAsyncUpdateQueue(
+                        path, ignore_spec, required_exts
+                    )
 
         logger.info(f"DocumentRetriever initialized with:")
         logger.info(f" Path: {self.path}")
@@ -548,18 +122,25 @@ class DocumentRetriever:
         logger.info(f" Single file token limit: {self.single_file_token_limit}")
         logger.info(f" Small file token limit: {self.small_file_token_limit}")
         logger.info(f" Small file merge limit: {self.small_file_merge_limit}")
+        logger.info(f" Enable hybrid index: {self.enable_hybrid_index}")
+        if extra_params:
+            logger.info(
+                f" Hybrid index max output tokens: {extra_params.hybrid_index_max_output_tokens}"
+            )
 
-    def get_cache(self):
+    def get_cache(self, options: Optional[Dict[str, Any]] = None):
         if self.on_ray:
             return ray.get(self.cacher.get_cache.remote())
         else:
-            return self.cacher.get_cache()
+            return self.cacher.get_cache(options=options)
 
-    def retrieve_documents(self) -> Generator[SourceCode, None, None]:
+    def retrieve_documents(
+        self, options: Optional[Dict[str, Any]] = None
+    ) -> Generator[SourceCode, None, None]:
         logger.info("Starting document retrieval process")
         waiting_list = []
         waiting_tokens = 0
-        for _, data in self.get_cache().items():
+        for _, data in self.get_cache(options=options).items():
             for source_code in data["content"]:
                 doc = SourceCode.model_validate(source_code)
                 if self.disable_auto_window:
@@ -579,7 +160,7 @@ class DocumentRetriever:
                     yield from self._split_large_document(doc)
                 else:
                     yield doc
-        if waiting_list and not self.disable_auto_window:
+        if waiting_list and not self.disable_auto_window:
             yield from self._process_waiting_list(waiting_list)
 
         logger.info("Document retrieval process completed")
@@ -594,7 +175,7 @@ class DocumentRetriever:
         self, waiting_list: List[SourceCode]
     ) -> Generator[SourceCode, None, None]:
         if len(waiting_list) == 1:
-            yield waiting_list[0]
+            yield waiting_list[0]
         elif len(waiting_list) > 1:
             yield self._merge_documents(waiting_list)
 
@@ -603,7 +184,7 @@ class DocumentRetriever:
             [f"#File: {doc.module_name}\n{doc.source_code}" for doc in docs]
         )
         merged_tokens = sum([doc.tokens for doc in docs])
-        merged_name = f"Merged_{len(docs)}_docs_{str(uuid4())}"
+        merged_name = f"Merged_{len(docs)}_docs_{str(uuid4())}"
         logger.info(
             f"Merged {len(docs)} documents into {merged_name} (tokens: {merged_tokens})."
         )
The remaining hunks are from a second module in the wheel, the one that defines LongContextRAG.

@@ -17,7 +17,7 @@ import traceback
 
 from autocoder.common import AutoCoderArgs, SourceCode
 from autocoder.rag.doc_filter import DocFilter
-from autocoder.rag.document_retriever import DocumentRetriever
+from autocoder.rag.document_retriever import LocalDocumentRetriever
 from autocoder.rag.relevant_utils import (
     DocRelevance,
     FilterDoc,
@@ -91,6 +91,7 @@ class LongContextRAG:
 
         # if open monitor mode
         self.monitor_mode = self.args.monitor_mode or False
+        self.enable_hybrid_index = self.args.enable_hybrid_index
         logger.info(f"Monitor mode: {self.monitor_mode}")
 
         if args.rag_url and args.rag_url.startswith("http://"):
@@ -117,7 +118,8 @@ class LongContextRAG:
         self.ignore_spec = self._load_ignore_file()
 
         self.token_limit = self.args.rag_context_window_limit or 120000
-        self.document_retriever = DocumentRetriever(
+        retriever_class = self._get_document_retriever_class()
+        self.document_retriever = retriever_class(
             self.path,
             self.ignore_spec,
             self.required_exts,
@@ -125,7 +127,9 @@ class LongContextRAG:
             self.monitor_mode,
             ## make sure the full-text area can hold at least one file
             single_file_token_limit=self.full_text_limit - 100,
-            disable_auto_window=self.args.disable_auto_window,
+            disable_auto_window=self.args.disable_auto_window,
+            enable_hybrid_index=self.args.enable_hybrid_index,
+            extra_params=self.args
         )
 
         self.doc_filter = DocFilter(
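LongContextRAG now reads the hybrid-index settings from AutoCoderArgs and hands the whole args object to the retriever as extra_params. A short sketch of the fields this diff actually touches, given an existing AutoCoderArgs instance named args (the instance and the values are hypothetical):

    # Fields of AutoCoderArgs referenced by the new code paths in this release.
    args.enable_hybrid_index = True              # routes the retriever cache to ByzerStorageCache
    args.hybrid_index_max_output_tokens = 10000  # logged by the retriever when extra_params is set
    args.monitor_mode = False                    # unchanged: watchfiles listener vs. polling queue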
@@ -204,6 +208,11 @@ class LongContextRAG:
         回答:
         """
 
+    def _get_document_retriever_class(self):
+        """Get the document retriever class based on configuration."""
+        # Default to LocalDocumentRetriever if not specified
+        return LocalDocumentRetriever
+
     def _load_ignore_file(self):
         serveignore_path = os.path.join(self.path, ".serveignore")
         gitignore_path = os.path.join(self.path, ".gitignore")
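_get_document_retriever_class is a small hook: it currently always returns LocalDocumentRetriever, but it gives subclasses one override point for swapping in another BaseDocumentRetriever implementation. A hypothetical override; MyDocumentRetriever is invented and would have to accept the same constructor arguments that LongContextRAG passes above:

    class MyLongContextRAG(LongContextRAG):
        def _get_document_retriever_class(self):
            # Return any BaseDocumentRetriever subclass compatible with the
            # (path, ignore_spec, required_exts, ...) call made in __init__.
            return MyDocumentRetriever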
@@ -216,8 +225,8 @@ class LongContextRAG:
             return pathspec.PathSpec.from_lines("gitwildmatch", ignore_file)
         return None
 
-    def _retrieve_documents(self) -> Generator[SourceCode, None, None]:
-        return self.document_retriever.retrieve_documents()
+    def _retrieve_documents(self, options: Optional[Dict[str, Any]] = None) -> Generator[SourceCode, None, None]:
+        return self.document_retriever.retrieve_documents(options=options)
 
     def build(self):
         pass
@@ -269,7 +278,8 @@ class LongContextRAG:
             return [SourceCode(module_name=f"RAG:{url}", source_code="".join(v))]
 
     def _filter_docs(self, conversations: List[Dict[str, str]]) -> List[FilterDoc]:
-        documents = self._retrieve_documents()
+        query = conversations[-1]["content"]
+        documents = self._retrieve_documents(options={"query": query})
         return self.doc_filter.filter_docs(
             conversations=conversations, documents=documents
         )
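With this change the last user turn becomes a retrieval hint: _filter_docs extracts it and forwards it to the cache layer through the new options argument (the Ray actor path still calls get_cache.remote() without options). An illustrative trace with a hypothetical conversation:

    conversations = [{"role": "user", "content": "How do I enable the hybrid index?"}]
    query = conversations[-1]["content"]   # what _filter_docs now extracts
    options = {"query": query}
    # _filter_docs -> _retrieve_documents(options)
    #              -> LocalDocumentRetriever.retrieve_documents(options=options)
    #              -> self.cacher.get_cache(options=options)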