auto-coder 0.1.195-py3-none-any.whl → 0.1.197-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of auto-coder might be problematic.

autocoder/index/index.py CHANGED
@@ -24,6 +24,11 @@ from rich.panel import Panel
 from rich.text import Text
 
 from loguru import logger
+from autocoder.utils.queue_communicate import (
+    queue_communicate,
+    CommunicateEvent,
+    CommunicateEventType,
+)
 
 
 class IndexItem(pydantic.BaseModel):
@@ -39,10 +44,12 @@ class TargetFile(pydantic.BaseModel):
         ..., description="The reason why the file is the target file"
     )
 
+
 class VerifyFileRelevance(pydantic.BaseModel):
     relevant_score: int
     reason: str
 
+
 class FileList(pydantic.BaseModel):
     file_list: List[TargetFile]
 
@@ -93,7 +100,7 @@ class IndexManager:
            "reason": "这是相关的原因..."
        }
        ```
-        """
+        """
 
     @byzerllm.prompt()
     def _get_related_files(self, indices: str, file_paths: str) -> str:
@@ -127,7 +134,7 @@ class IndexManager:
        1. 找到的文件名必须出现在上面的文件列表中
        2. 原因控制在20字以内
        3. 如果没有相关的文件,输出如下 json 即可:
-
+
        ```json
        {"file_list": []}
        ```
@@ -231,12 +238,14 @@ class IndexManager:
                )
                symbols = []
                for chunk in chunks:
-                    chunk_symbols = self.get_all_file_symbols.with_llm(self.index_llm).run(source.module_name, chunk)
+                    chunk_symbols = self.get_all_file_symbols.with_llm(
+                        self.index_llm).run(source.module_name, chunk)
                    time.sleep(self.anti_quota_limit)
                    symbols.append(chunk_symbols)
                symbols = "\n".join(symbols)
            else:
-                symbols = self.get_all_file_symbols.with_llm(self.index_llm).run(source.module_name, source_code)
+                symbols = self.get_all_file_symbols.with_llm(
+                    self.index_llm).run(source.module_name, source_code)
                time.sleep(self.anti_quota_limit)
 
        logger.info(
@@ -286,7 +295,7 @@ class IndexManager:
            v = source.source_code.splitlines()
            new_v = []
            for line in v:
-                new_v.append(line[line.find(":") :])
+                new_v.append(line[line.find(":"):])
            source_code = "\n".join(new_v)
 
        md5 = hashlib.md5(source_code.encode("utf-8")).hexdigest()
@@ -299,7 +308,8 @@ class IndexManager:
        counter = 0
        num_files = len(wait_to_build_files)
        total_files = len(self.sources)
-        logger.info(f"Total Files: {total_files}, Need to Build Index: {num_files}")
+        logger.info(
+            f"Total Files: {total_files}, Need to Build Index: {num_files}")
 
        futures = [
            executor.submit(self.build_index_for_single_source, source)
@@ -366,13 +376,13 @@ class IndexManager:
            item_str = f"##{item.module_name}\n{symbols_str}\n\n"
 
            if skip_symbols:
-                item_str = f"{item.module_name}\n"
+                item_str = f"{item.module_name}\n"
 
            if len(current_chunk) > self.args.filter_batch_size:
                yield "".join(current_chunk)
-                current_chunk = [item_str]
+                current_chunk = [item_str]
            else:
-                current_chunk.append(item_str)
+                current_chunk.append(item_str)
 
        if current_chunk:
            yield "".join(current_chunk)
@@ -405,7 +415,8 @@ class IndexManager:
        lock = threading.Lock()
 
        def process_chunk(chunk, chunk_count):
-            result = self._get_related_files.with_llm(self.llm).with_return_type(FileList).run(chunk, "\n".join(file_paths))
+            result = self._get_related_files.with_llm(self.llm).with_return_type(
+                FileList).run(chunk, "\n".join(file_paths))
            if result is not None:
                with lock:
                    all_results.extend(result.file_list)
@@ -419,7 +430,7 @@ class IndexManager:
            futures = []
            chunk_count = 0
            for chunk in self._get_meta_str(
-                max_chunk_size= -1
+                max_chunk_size=-1
            ):
                future = executor.submit(process_chunk, chunk, chunk_count)
                futures.append(future)
@@ -428,7 +439,8 @@ class IndexManager:
            for future in as_completed(futures):
                future.result()
 
-        all_results = list({file.file_path: file for file in all_results}.values())
+        all_results = list(
+            {file.file_path: file for file in all_results}.values())
        return FileList(file_list=all_results)
 
    def _query_index_with_thread(self, query, func):
@@ -439,7 +451,8 @@ class IndexManager:
 
        def process_chunk(chunk):
            nonlocal completed_threads
-            result = self._get_target_files_by_query.with_llm(self.llm).with_return_type(FileList).run(chunk, query)
+            result = self._get_target_files_by_query.with_llm(
+                self.llm).with_return_type(FileList).run(chunk, query)
            if result is not None:
                with lock:
                    all_results.extend(result.file_list)
@@ -469,24 +482,27 @@ class IndexManager:
        def w():
            return self._get_meta_str(
                skip_symbols=False,
-                max_chunk_size= -1,
+                max_chunk_size=-1,
                includes=[SymbolType.USAGE],
            )
-
-        temp_result, total_threads, completed_threads = self._query_index_with_thread(query, w)
+
+        temp_result, total_threads, completed_threads = self._query_index_with_thread(
+            query, w)
        all_results.extend(temp_result)
 
        if self.args.index_filter_level >= 1:
 
            def w():
                return self._get_meta_str(
-                    skip_symbols=False, max_chunk_size= -1
+                    skip_symbols=False, max_chunk_size=-1
                )
 
-            temp_result, total_threads, completed_threads = self._query_index_with_thread(query, w)
+            temp_result, total_threads, completed_threads = self._query_index_with_thread(
+                query, w)
            all_results.extend(temp_result)
 
-        all_results = list({file.file_path: file for file in all_results}.values())
+        all_results = list(
+            {file.file_path: file for file in all_results}.values())
        # Limit the number of files based on index_filter_file_num
        limited_results = all_results[: self.args.index_filter_file_num]
        return FileList(file_list=limited_results)
@@ -518,7 +534,7 @@ class IndexManager:
            ]
        }
        ```
-
+
        如果没有找到,返回如下 json 即可:
 
        ```json
@@ -585,6 +601,15 @@ def build_index_and_filter_files(
 
    if not args.skip_build_index and llm:
        # Phase 2: Build index
+        if args.request_id and not args.skip_events:
+            queue_communicate.send_event(
+                request_id=args.request_id,
+                event=CommunicateEvent(
+                    event_type=CommunicateEventType.CODE_INDEX_BUILD_START.value,
+                    data=json.dumps({"total_files": len(sources)})
+                )
+            )
+
        logger.info("Phase 2: Building index for all files...")
        phase_start = time.monotonic()
        index_manager = IndexManager(llm=llm, sources=sources, args=args)
@@ -592,9 +617,31 @@ def build_index_and_filter_files(
        stats["indexed_files"] = len(index_data) if index_data else 0
        stats["timings"]["build_index"] = time.monotonic() - phase_start
 
+        if args.request_id and not args.skip_events:
+            queue_communicate.send_event(
+                request_id=args.request_id,
+                event=CommunicateEvent(
+                    event_type=CommunicateEventType.CODE_INDEX_BUILD_END.value,
+                    data=json.dumps({
+                        "indexed_files": stats["indexed_files"],
+                        "build_index_time": stats["timings"]["build_index"],
+                    })
+                )
+            )
+
    if not args.skip_filter_index:
+        if args.request_id and not args.skip_events:
+            queue_communicate.send_event(
+                request_id=args.request_id,
+                event=CommunicateEvent(
+                    event_type=CommunicateEventType.CODE_INDEX_FILTER_START.value,
+                    data=json.dumps({})
+                )
+            )
        # Phase 3: Level 1 filtering - Query-based
-        logger.info("Phase 3: Performing Level 1 filtering (query-based)...")
+        logger.info(
+            "Phase 3: Performing Level 1 filtering (query-based)...")
+
        phase_start = time.monotonic()
        target_files = index_manager.get_target_files_by_query(args.query)
 
@@ -607,7 +654,16 @@ def build_index_and_filter_files(
 
        # Phase 4: Level 2 filtering - Related files
        if target_files is not None and args.index_filter_level >= 2:
-            logger.info("Phase 4: Performing Level 2 filtering (related files)...")
+            logger.info(
+                "Phase 4: Performing Level 2 filtering (related files)...")
+            if args.request_id and not args.skip_events:
+                queue_communicate.send_event(
+                    request_id=args.request_id,
+                    event=CommunicateEvent(
+                        event_type=CommunicateEventType.CODE_INDEX_FILTER_START.value,
+                        data=json.dumps({})
+                    )
+                )
            phase_start = time.monotonic()
            related_files = index_manager.get_related_files(
                [file.file_path for file in target_files.file_list]
@@ -617,7 +673,8 @@ def build_index_and_filter_files(
                file_path = file.file_path.strip()
                final_files[get_file_path(file_path)] = file
            stats["level2_filtered"] = len(related_files.file_list)
-            stats["timings"]["level2_filter"] = time.monotonic() - phase_start
+            stats["timings"]["level2_filter"] = time.monotonic() - \
+                phase_start
 
            if not final_files:
                logger.warning("No related files found, using all files")
@@ -632,37 +689,40 @@ def build_index_and_filter_files(
        phase_start = time.monotonic()
        verified_files = {}
        temp_files = list(final_files.values())
-
+
        def verify_single_file(file: TargetFile):
            for source in sources:
                if source.module_name == file.file_path:
                    file_content = source.source_code
                    try:
                        result = index_manager.verify_file_relevance.with_llm(llm).with_return_type(VerifyFileRelevance).run(
-                            file_content=file_content,
+                            file_content=file_content,
                            query=args.query
-                        )
+                        )
                        if result.relevant_score >= args.verify_file_relevance_score:
                            return file.file_path, TargetFile(
                                file_path=file.file_path,
                                reason=f"Score:{result.relevant_score}, {result.reason}"
                            )
                    except Exception as e:
-                        logger.warning(f"Failed to verify file {file.file_path}: {str(e)}")
+                        logger.warning(
+                            f"Failed to verify file {file.file_path}: {str(e)}")
            return None
 
        with ThreadPoolExecutor(max_workers=args.index_filter_workers) as executor:
-            futures = [executor.submit(verify_single_file, file) for file in temp_files]
+            futures = [executor.submit(verify_single_file, file)
+                       for file in temp_files]
            for future in as_completed(futures):
                result = future.result()
                if result:
                    file_path, target_file = result
                    verified_files[file_path] = target_file
                    time.sleep(args.anti_quota_limit)
-
+
        stats["verified_files"] = len(verified_files)
-        stats["timings"]["relevance_verification"] = time.monotonic() - phase_start
-
+        stats["timings"]["relevance_verification"] = time.monotonic() - \
+            phase_start
+
        final_files = verified_files if verified_files else final_files
 
        def display_table_and_get_selections(data):
@@ -714,12 +774,13 @@ def build_index_and_filter_files(
 
        console.print(panel)
 
-        # Phase 6: File selection and limitation
+        # Phase 6: File selection and limitation
        logger.info("Phase 6: Processing file selection and limits...")
        phase_start = time.monotonic()
-
+
        if args.index_filter_file_num > 0:
-            logger.info(f"Limiting files from {len(final_files)} to {args.index_filter_file_num}")
+            logger.info(
+                f"Limiting files from {len(final_files)} to {args.index_filter_file_num}")
 
        if args.skip_confirm:
            final_filenames = [file.file_path for file in final_files.values()]
@@ -735,17 +796,17 @@ def build_index_and_filter_files(
            )
            final_filenames = []
        else:
-            final_filenames = display_table_and_get_selections(target_files_data)
-
+            final_filenames = display_table_and_get_selections(
+                target_files_data)
+
        if args.index_filter_file_num > 0:
            final_filenames = final_filenames[: args.index_filter_file_num]
-
+
        stats["timings"]["file_selection"] = time.monotonic() - phase_start
 
        # Phase 7: Display results and prepare output
        logger.info("Phase 7: Preparing final output...")
-        phase_start = time.monotonic()
-
+        phase_start = time.monotonic()
        try:
            print_selected(
                [
@@ -772,26 +833,52 @@ def build_index_and_filter_files(
                depulicated_sources.add(file.module_name)
                source_code += f"##File: {file.module_name}\n"
                source_code += f"{file.source_code}\n\n"
-
+
+        if args.request_id and not args.skip_events:
+            queue_communicate.send_event(
+                request_id=args.request_id,
+                event=CommunicateEvent(
+                    event_type=CommunicateEventType.CODE_INDEX_FILTER_FILE_SELECTED.value,
+                    data=json.dumps([
+                        (file.file_path, file.reason)
+                        for file in final_files.values()
+                        if file.file_path in depulicated_sources
+                    ])
+                )
+            )
+
        stats["final_files"] = len(depulicated_sources)
        stats["timings"]["prepare_output"] = time.monotonic() - phase_start
-
+
        # Calculate total time and print summary
        total_time = time.monotonic() - total_start_time
        stats["timings"]["total"] = total_time
-
+
        # Print final statistics
        logger.info("\n=== Build Index and Filter Files Summary ===")
        logger.info(f"Total files in project: {stats['total_files']}")
        logger.info(f"Files indexed: {stats['indexed_files']}")
        logger.info(f"Files after Level 1 filter: {stats['level1_filtered']}")
        logger.info(f"Files after Level 2 filter: {stats['level2_filtered']}")
-        logger.info(f"Files after relevance verification: {stats.get('verified_files', 0)}")
+        logger.info(
+            f"Files after relevance verification: {stats.get('verified_files', 0)}")
        logger.info(f"Final files selected: {stats['final_files']}")
        logger.info("\nTime breakdown:")
        for phase, duration in stats["timings"].items():
            logger.info(f"  - {phase}: {duration:.2f}s")
        logger.info(f"Total execution time: {total_time:.2f}s")
        logger.info("==========================================\n")
-
+
+    if args.request_id and not args.skip_events:
+        queue_communicate.send_event(
+            request_id=args.request_id,
+            event=CommunicateEvent(
+                event_type=CommunicateEventType.CODE_INDEX_FILTER_END.value,
+                data=json.dumps({
+                    "filtered_files": stats["final_files"],
+                    "filter_time": stats['level1_filtered'] + stats['level2_filtered'] + stats.get('verified_files', 0)
+                })
+            )
+        )
+
    return source_code
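
The recurring addition in index.py is a guarded progress notification: whenever `args.request_id` is set and `args.skip_events` is false, a `CommunicateEvent` is pushed through `queue_communicate.send_event` at the start and end of index building and filtering. Below is a minimal sketch of that pattern, assuming the imports shown in the diff; the `notify` helper is illustrative and not part of the package itself.

```python
# Minimal sketch of the notification pattern used in the new index.py code.
# `notify` is an illustrative helper, not an auto-coder API.
import json

from autocoder.utils.queue_communicate import (
    queue_communicate,
    CommunicateEvent,
    CommunicateEventType,
)


def notify(args, event_type: CommunicateEventType, payload: dict) -> None:
    """Send a progress event only when a request_id is set and events are enabled."""
    if args.request_id and not args.skip_events:
        queue_communicate.send_event(
            request_id=args.request_id,
            event=CommunicateEvent(
                event_type=event_type.value,
                data=json.dumps(payload),
            ),
        )


# Usage mirroring the diff: emitted before and after the index build.
# notify(args, CommunicateEventType.CODE_INDEX_BUILD_START, {"total_files": len(sources)})
# ... build index ...
# notify(args, CommunicateEventType.CODE_INDEX_BUILD_END, {"indexed_files": indexed})
```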
autocoder/lang.py CHANGED
@@ -71,7 +71,12 @@ lang_desc = {
        "base_dir": "Alternative path for /~/.auto-coder to store or retrieve text embeddings. Using /~/.auto-coder/ if not specified",
        "editblock_similarity": "The similarity threshold of TextSimilarity when merging edit blocks. Default is 0.9",
        "include_project_structure": "Whether to include the project directory structure in the code generation prompt. Default is False",
-        "filter_batch_size": "The batch size used when filtering files. Default is 5"
+        "filter_batch_size": "The batch size used when filtering files. Default is 5",
+        "skip_events": "Skip sending events during execution. Default is False",
+        "rag_url": "The URL of the RAG service. Default is empty",
+        "rag_token": "The token for the RAG service. Default is empty",
+        "rag_type": "RAG type (simple/storage), default is storage",
+        "rag_params_max_tokens": "The maximum number of tokens for RAG parameters. Default is 4096",
    },
    "zh": {
        "request_id": "Request ID",
@@ -145,6 +150,11 @@ lang_desc = {
        "base_dir": "用于替代byzerllm中/~/.auto-coder的路径存放或读取向量化后的文本。不指定则使用默认路径",
        "editblock_similarity": "合并编辑块时TextSimilarity的相似度阈值。默认为0.9",
        "include_project_structure": "在生成代码的提示中是否包含项目目录结构。默认为False",
-        "filter_batch_size": "文件过滤时使用的批处理大小。默认为5"
+        "filter_batch_size": "文件过滤时使用的批处理大小。默认为5",
+        "skip_events": "在执行过程中跳过事件发送。默认为False",
+        "rag_url": "RAG服务的URL",
+        "rag_token": "RAG服务的令牌",
+        "rag_type": "RAG类型(simple/storage),默认是storage",
+        "rag_params_max_tokens": "RAG参数的最大token数。默认为4096",
    }
}
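
The lang.py change only registers help text for the new options (skip_events, rag_url, rag_token, rag_type, rag_params_max_tokens) in both locales. A minimal sketch of how such a table can be consulted, assuming lang_desc is a plain {locale: {option_name: description}} dict as the surrounding context suggests:

```python
# Hypothetical lookup over lang_desc; assumes the nested-dict shape shown in the diff.
from autocoder.lang import lang_desc


def describe_option(name: str, locale: str = "en") -> str:
    """Return the help text for an option, falling back to the English entry."""
    table = lang_desc.get(locale, lang_desc["en"])
    return table.get(name, lang_desc["en"].get(name, ""))


print(describe_option("rag_params_max_tokens"))        # English description
print(describe_option("skip_events", locale="zh"))     # Chinese description
```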
@@ -14,6 +14,9 @@ from autocoder.common.search import Search, SearchEngine
 from loguru import logger
 import re
 from pydantic import BaseModel, Field
+from rich.console import Console
+import json
+from autocoder.utils.queue_communicate import queue_communicate, CommunicateEvent, CommunicateEventType
 
 
 class RegPattern(BaseModel):
@@ -228,11 +231,37 @@ class PyProject:
    def get_rag_source_codes(self):
        if not self.args.enable_rag_search and not self.args.enable_rag_context:
            return []
+
+        if self.args.request_id and not self.args.skip_events:
+            _ = queue_communicate.send_event(
+                request_id=self.args.request_id,
+                event=CommunicateEvent(
+                    event_type=CommunicateEventType.CODE_RAG_SEARCH_START.value,
+                    data=json.dumps({},ensure_ascii=False)
+                )
+            )
+        else:
+            console = Console()
+            console.print(f"\n[bold blue]Starting RAG search for:[/bold blue] {self.args.query}")
+
        from autocoder.rag.rag_entry import RAGFactory
        rag = RAGFactory.get_rag(self.llm, self.args, "")
        docs = rag.search(self.args.query)
        for doc in docs:
            doc.tag = "RAG"
+
+        if self.args.request_id and not self.args.skip_events:
+            _ = queue_communicate.send_event(
+                request_id=self.args.request_id,
+                event=CommunicateEvent(
+                    event_type=CommunicateEventType.CODE_RAG_SEARCH_END.value,
+                    data=json.dumps({},ensure_ascii=False)
+                )
+            )
+        else:
+            console = Console()
+            console.print(f"[bold green]Found {len(docs)} relevant documents[/bold green]")
+
        return docs
 
    def get_search_source_codes(self):
@@ -261,7 +290,7 @@ class PyProject:
        return temp + []
 
    def get_source_codes(self) -> Generator[SourceCode, None, None]:
-        for root, dirs, files in os.walk(self.directory):
+        for root, dirs, files in os.walk(self.directory,followlinks=True):
            dirs[:] = [d for d in dirs if d not in self.default_exclude_dirs]
            for file in files:
                file_path = os.path.join(root, file)
@@ -364,7 +364,7 @@ class ByzerStorageCache(BaseCacheManager):
 
    def get_all_files(self) -> List[Tuple[str, str, float]]:
        all_files = []
-        for root, dirs, files in os.walk(self.path):
+        for root, dirs, files in os.walk(self.path,followlinks=True):
            dirs[:] = [d for d in dirs if not d.startswith(".")]
 
            if self.ignore_spec:
@@ -117,7 +117,7 @@ class AutoCoderRAGDocListener(BaseCacheManager):
 
    def get_all_files(self) -> List[str]:
        all_files = []
-        for root, dirs, files in os.walk(self.path):
+        for root, dirs, files in os.walk(self.path,followlinks=True):
            dirs[:] = [d for d in dirs if not d.startswith(".")]
 
            if self.ignore_spec:
@@ -174,7 +174,7 @@ class AutoCoderRAGAsyncUpdateQueue(BaseCacheManager):
 
    def get_all_files(self) -> List[Tuple[str, str, float]]:
        all_files = []
-        for root, dirs, files in os.walk(self.path):
+        for root, dirs, files in os.walk(self.path,followlinks=True):
            dirs[:] = [d for d in dirs if not d.startswith(".")]
 
            if self.ignore_spec:
@@ -184,7 +184,7 @@ class AutoCoderRAGAsyncUpdateQueue(BaseCacheManager):
                    for d in dirs
                    if not self.ignore_spec.match_file(os.path.join(relative_root, d))
                ]
-                files = [
+                files[:] = [
                    f
                    for f in files
                    if not self.ignore_spec.match_file(os.path.join(relative_root, f))
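
Several project and cache scanners now call os.walk with followlinks=True so that symlinked directories are traversed, and the async update queue mutates `files` in place instead of rebinding it, matching the existing `dirs[:]` style. A small self-contained sketch of the same traversal pattern, with a hypothetical ignore list:

```python
import os

EXCLUDE_DIRS = {".git", "__pycache__", "node_modules"}  # hypothetical ignore list


def walk_project(path: str):
    """Yield file paths, following symlinks and pruning excluded directories."""
    for root, dirs, files in os.walk(path, followlinks=True):
        # Slice assignment mutates the list os.walk holds, so pruned
        # directories are never descended into.
        dirs[:] = [d for d in dirs if d not in EXCLUDE_DIRS and not d.startswith(".")]
        for name in files:
            yield os.path.join(root, name)


# Example: print(list(walk_project(".")))
```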
@@ -260,6 +260,7 @@ class LongContextRAG:
            response = self.client.chat.completions.create(
                messages=[{"role": "user", "content": new_query}],
                model=self.args.model,
+                max_tokens=self.args.rag_params_max_tokens,
            )
            v = response.choices[0].message.content
            if not only_contexts:
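
The LongContextRAG change caps answer length by passing the new rag_params_max_tokens setting (default 4096 per the lang.py descriptions) into the chat completion call. A minimal sketch of the same call against an OpenAI-compatible client, with placeholder endpoint and model names:

```python
# Sketch of the capped completion call; base_url, api_key and model are placeholders.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="placeholder")

response = client.chat.completions.create(
    model="placeholder-model",
    messages=[{"role": "user", "content": "Summarize the retrieved documents."}],
    max_tokens=4096,  # mirrors args.rag_params_max_tokens
)
print(response.choices[0].message.content)
```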
@@ -9,6 +9,9 @@ import byzerllm
 from autocoder.common.search import Search, SearchEngine
 from loguru import logger
 from pydantic import BaseModel, Field
+from rich.console import Console
+import json
+from autocoder.utils.queue_communicate import queue_communicate, CommunicateEvent, CommunicateEventType
 
 
 class RegPattern(BaseModel):
@@ -90,7 +93,7 @@ class RegexProject:
        return SourceCode(module_name=module_name, source_code=source_code)
 
    def get_source_codes(self) -> Generator[SourceCode, None, None]:
-        for root, dirs, files in os.walk(self.directory):
+        for root, dirs, files in os.walk(self.directory,followlinks=True):
            for file in files:
                file_path = os.path.join(root, file)
                if self.is_regex_match(file_path):
@@ -117,11 +120,37 @@ class RegexProject:
    def get_rag_source_codes(self):
        if not self.args.enable_rag_search and not self.args.enable_rag_context:
            return []
+
+        if self.args.request_id and not self.args.skip_events:
+            _ = queue_communicate.send_event(
+                request_id=self.args.request_id,
+                event=CommunicateEvent(
+                    event_type=CommunicateEventType.CODE_RAG_SEARCH_START.value,
+                    data=json.dumps({},ensure_ascii=False)
+                )
+            )
+        else:
+            console = Console()
+            console.print(f"\n[bold blue]Starting RAG search for:[/bold blue] {self.args.query}")
+
        from autocoder.rag.rag_entry import RAGFactory
        rag = RAGFactory.get_rag(self.llm, self.args, "")
        docs = rag.search(self.args.query)
        for doc in docs:
            doc.tag = "RAG"
+
+        if self.args.request_id and not self.args.skip_events:
+            _ = queue_communicate.send_event(
+                request_id=self.args.request_id,
+                event=CommunicateEvent(
+                    event_type=CommunicateEventType.CODE_RAG_SEARCH_END.value,
+                    data=json.dumps({},ensure_ascii=False)
+                )
+            )
+        else:
+            console = Console()
+            console.print(f"[bold green]Found {len(docs)} relevant documents[/bold green]")
+
        return docs
 
    def get_search_source_codes(self):
@@ -9,6 +9,9 @@ from autocoder.common.search import Search, SearchEngine
 from loguru import logger
 import re
 from pydantic import BaseModel, Field
+from rich.console import Console
+import json
+from autocoder.utils.queue_communicate import queue_communicate, CommunicateEvent, CommunicateEventType
 
 
 class RegPattern(BaseModel):
@@ -128,7 +131,7 @@ class SuffixProject:
        return SourceCode(module_name=module_name, source_code=source_code)
 
    def get_source_codes(self) -> Generator[SourceCode, None, None]:
-        for root, dirs, files in os.walk(self.directory):
+        for root, dirs, files in os.walk(self.directory,followlinks=True):
            dirs[:] = [d for d in dirs if d not in self.default_exclude_dirs]
            for file in files:
                file_path = os.path.join(root, file)
@@ -157,11 +160,37 @@ class SuffixProject:
    def get_rag_source_codes(self):
        if not self.args.enable_rag_search and not self.args.enable_rag_context:
            return []
+
+        if self.args.request_id and not self.args.skip_events:
+            _ = queue_communicate.send_event(
+                request_id=self.args.request_id,
+                event=CommunicateEvent(
+                    event_type=CommunicateEventType.CODE_RAG_SEARCH_START.value,
+                    data=json.dumps({},ensure_ascii=False)
+                )
+            )
+        else:
+            console = Console()
+            console.print(f"\n[bold blue]Starting RAG search for:[/bold blue] {self.args.query}")
+
        from autocoder.rag.rag_entry import RAGFactory
        rag = RAGFactory.get_rag(self.llm, self.args, "")
        docs = rag.search(self.args.query)
        for doc in docs:
            doc.tag = "RAG"
+
+        if self.args.request_id and not self.args.skip_events:
+            _ = queue_communicate.send_event(
+                request_id=self.args.request_id,
+                event=CommunicateEvent(
+                    event_type=CommunicateEventType.CODE_RAG_SEARCH_END.value,
+                    data=json.dumps({},ensure_ascii=False)
+                )
+            )
+        else:
+            console = Console()
+            console.print(f"[bold green]Found {len(docs)} relevant documents[/bold green]")
+
        return docs
 
    def get_search_source_codes(self):
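
PyProject, RegexProject, and SuffixProject all gain the identical wrapper around get_rag_source_codes: emit CODE_RAG_SEARCH_START and CODE_RAG_SEARCH_END events when a request_id is present and events are not skipped, otherwise print progress to a rich Console. A condensed sketch of that shared behaviour follows; the standalone function is illustrative, since in the package the logic is inlined in each class.

```python
import json
from rich.console import Console
from autocoder.utils.queue_communicate import (
    queue_communicate,
    CommunicateEvent,
    CommunicateEventType,
)


def announce_rag_search(args, started: bool, doc_count: int = 0) -> None:
    """Report RAG search progress either as an event or on the console."""
    if args.request_id and not args.skip_events:
        event_type = (
            CommunicateEventType.CODE_RAG_SEARCH_START
            if started
            else CommunicateEventType.CODE_RAG_SEARCH_END
        )
        queue_communicate.send_event(
            request_id=args.request_id,
            event=CommunicateEvent(
                event_type=event_type.value,
                data=json.dumps({}, ensure_ascii=False),
            ),
        )
    else:
        console = Console()
        if started:
            console.print(f"\n[bold blue]Starting RAG search for:[/bold blue] {args.query}")
        else:
            console.print(f"[bold green]Found {doc_count} relevant documents[/bold green]")
```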