auto-coder 0.1.374__py3-none-any.whl → 0.1.375__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of auto-coder might be problematic; see the registry's advisory page for details.

@@ -1,7 +1,7 @@
1
1
  import os
2
2
  import re
3
3
  import glob
4
- from typing import Dict, Any, Optional, List
4
+ from typing import Dict, Any, Optional, List, Union
5
5
  from autocoder.agent.base_agentic.tools.base_tool_resolver import BaseToolResolver
6
6
  from autocoder.agent.base_agentic.types import SearchFilesTool, ToolResult # Import ToolResult from types
7
7
  from loguru import logger
@@ -20,14 +20,54 @@ class SearchFilesToolResolver(BaseToolResolver):
20
20
  self.tool: SearchFilesTool = tool
21
21
  self.shadow_manager = self.agent.shadow_manager if self.agent else None
22
22
 
23
- def resolve(self) -> ToolResult:
24
- search_path_str = self.tool.path
25
- regex_pattern = self.tool.regex
26
- file_pattern = self.tool.file_pattern or "*"
27
- source_dir = self.args.source_dir or "."
28
- absolute_source_dir = os.path.abspath(source_dir)
29
- absolute_search_path = os.path.abspath(os.path.join(source_dir, search_path_str))
23
def search_in_dir(self, base_dir: str, regex_pattern: str, file_pattern: str, source_dir: str, is_shadow: bool = False, compiled_regex: Optional[re.Pattern] = None) -> List[Dict[str, Any]]:
    """Search files under ``base_dir`` for lines matching a regular expression.

    Candidate files come from a recursive glob of ``base_dir/**/<file_pattern>``.
    Paths for which ``should_ignore`` returns true are skipped, as are non-files.
    Each file is read as UTF-8, with undecodable bytes replaced.

    Args:
        base_dir: Directory to scan recursively.
        regex_pattern: Pattern source. It is compiled here when ``compiled_regex``
            is None, and it is included in the log message.
        file_pattern: Glob pattern for file names (e.g. ``"*.py"``).
        source_dir: Directory that the reported paths are made relative to.
        is_shadow: Whether ``base_dir`` lies in the shadow tree. If so, and a
            shadow manager is set, paths are mapped back with
            ``shadow_manager.from_shadow_path`` before being made relative.
        compiled_regex: Optional pre-compiled pattern, reused instead of
            compiling ``regex_pattern`` again.

    Returns:
        One dict per matching line, with these keys:
        - ``path``
        - ``line_number`` (1-based)
        - ``match_line`` (the stripped line text)
        - ``context`` (the match plus up to two lines on each side, each
          prefixed with ``"<line number>: "``)

    Raises:
        re.error: If ``compiled_regex`` is None and ``regex_pattern`` is invalid.
        Errors while reading or processing a single file are logged as
        warnings, and the rest of that file is skipped.
    """
    search_results = []
    search_glob_pattern = os.path.join(base_dir, "**", file_pattern)

    logger.info(f"Searching for regex '{regex_pattern}' in files matching '{file_pattern}' under '{base_dir}' (shadow: {is_shadow}) with ignore rules applied.")

    if compiled_regex is None:
        compiled_regex = re.compile(regex_pattern)

    for filepath in glob.glob(search_glob_pattern, recursive=True):
        if should_ignore(os.path.abspath(filepath)):
            continue
        if not os.path.isfile(filepath):
            continue

        try:
            with open(filepath, 'r', encoding='utf-8', errors='replace') as f:
                lines = f.readlines()

            # The reported path depends only on the file, not on the matching line.
            # Resolve it once per file (lazily, on the first match) rather than
            # re-running the shadow-path translation for every hit.
            relative_path = None
            for i, line in enumerate(lines):
                if not compiled_regex.search(line):
                    continue

                if relative_path is None:
                    if is_shadow and self.shadow_manager:
                        try:
                            abs_project_path = self.shadow_manager.from_shadow_path(filepath)
                            relative_path = os.path.relpath(abs_project_path, source_dir)
                        except Exception:
                            # If translation fails, report the shadow-tree path instead.
                            relative_path = os.path.relpath(filepath, source_dir)
                    else:
                        relative_path = os.path.relpath(filepath, source_dir)

                context_start = max(0, i - 2)
                context_end = min(len(lines), i + 3)
                context = "".join(f"{j + 1}: {lines[j]}" for j in range(context_start, context_end))

                search_results.append({
                    "path": relative_path,
                    "line_number": i + 1,
                    "match_line": line.strip(),
                    "context": context.strip(),
                })
        except Exception as e:
            logger.warning(f"Could not read or process file {filepath}: {e}")

    return search_results
68
+
69
+ def search_files_with_shadow(self, search_path_str: str, regex_pattern: str, file_pattern: str, source_dir: str, absolute_source_dir: str, absolute_search_path: str) -> Union[ToolResult, List[Dict[str, Any]]]:
70
+ """Search files using shadow manager for path translation"""
31
71
  # Security check
32
72
  if not absolute_search_path.startswith(absolute_source_dir):
33
73
  return ToolResult(success=False, message=f"Error: Access denied. Attempted to search outside the project directory: {search_path_str}")
@@ -54,58 +94,15 @@ class SearchFilesToolResolver(BaseToolResolver):
54
94
  try:
55
95
  compiled_regex = re.compile(regex_pattern)
56
96
 
57
- # Helper function to search in a directory
58
- def search_in_dir(base_dir, is_shadow=False):
59
- search_results = []
60
- search_glob_pattern = os.path.join(base_dir, "**", file_pattern)
61
-
62
- logger.info(f"Searching for regex '{regex_pattern}' in files matching '{file_pattern}' under '{base_dir}' (shadow: {is_shadow}) with ignore rules applied.")
63
-
64
- for filepath in glob.glob(search_glob_pattern, recursive=True):
65
- abs_path = os.path.abspath(filepath)
66
- if should_ignore(abs_path):
67
- continue
68
-
69
- if os.path.isfile(filepath):
70
- try:
71
- with open(filepath, 'r', encoding='utf-8', errors='replace') as f:
72
- lines = f.readlines()
73
- for i, line in enumerate(lines):
74
- if compiled_regex.search(line):
75
- context_start = max(0, i - 2)
76
- context_end = min(len(lines), i + 3)
77
- context = "".join([f"{j+1}: {lines[j]}" for j in range(context_start, context_end)])
78
-
79
- if is_shadow and self.shadow_manager:
80
- try:
81
- abs_project_path = self.shadow_manager.from_shadow_path(filepath)
82
- relative_path = os.path.relpath(abs_project_path, source_dir)
83
- except Exception:
84
- relative_path = os.path.relpath(filepath, source_dir)
85
- else:
86
- relative_path = os.path.relpath(filepath, source_dir)
87
-
88
- search_results.append({
89
- "path": relative_path,
90
- "line_number": i + 1,
91
- "match_line": line.strip(),
92
- "context": context.strip()
93
- })
94
- except Exception as e:
95
- logger.warning(f"Could not read or process file {filepath}: {e}")
96
- continue
97
-
98
- return search_results
99
-
100
97
  # Search in both directories and merge results
101
98
  shadow_results = []
102
99
  source_results = []
103
100
 
104
101
  if shadow_exists:
105
- shadow_results = search_in_dir(shadow_dir_path, is_shadow=True)
102
+ shadow_results = self.search_in_dir(shadow_dir_path, regex_pattern, file_pattern, source_dir, is_shadow=True, compiled_regex=compiled_regex)
106
103
 
107
104
  if os.path.exists(absolute_search_path) and os.path.isdir(absolute_search_path):
108
- source_results = search_in_dir(absolute_search_path, is_shadow=False)
105
+ source_results = self.search_in_dir(absolute_search_path, regex_pattern, file_pattern, source_dir, is_shadow=False, compiled_regex=compiled_regex)
109
106
 
110
107
  # Merge results, prioritizing shadow results
111
108
  # Create a dictionary for quick lookup
@@ -122,9 +119,34 @@ class SearchFilesToolResolver(BaseToolResolver):
122
119
  # Convert back to list
123
120
  merged_results = list(results_dict.values())
124
121
 
125
- message = f"Search completed. Found {len(merged_results)} matches."
126
- logger.info(message)
127
- return ToolResult(success=True, message=message, content=merged_results)
122
+ return merged_results
123
+
124
+ except re.error as e:
125
+ logger.error(f"Invalid regex pattern '{regex_pattern}': {e}")
126
+ return ToolResult(success=False, message=f"Invalid regex pattern: {e}")
127
+ except Exception as e:
128
+ logger.error(f"Error during file search: {str(e)}")
129
+ return ToolResult(success=False, message=f"An unexpected error occurred during search: {str(e)}")
130
+
131
+ def search_files_normal(self, search_path_str: str, regex_pattern: str, file_pattern: str, source_dir: str, absolute_source_dir: str, absolute_search_path: str) -> Union[ToolResult, List[Dict[str, Any]]]:
132
+ """Search files directly without using shadow manager"""
133
+ # Security check
134
+ if not absolute_search_path.startswith(absolute_source_dir):
135
+ return ToolResult(success=False, message=f"Error: Access denied. Attempted to search outside the project directory: {search_path_str}")
136
+
137
+ # Validate that the directory exists
138
+ if not os.path.exists(absolute_search_path):
139
+ return ToolResult(success=False, message=f"Error: Search path not found: {search_path_str}")
140
+ if not os.path.isdir(absolute_search_path):
141
+ return ToolResult(success=False, message=f"Error: Search path is not a directory: {search_path_str}")
142
+
143
+ try:
144
+ compiled_regex = re.compile(regex_pattern)
145
+
146
+ # Search in the directory
147
+ search_results = self.search_in_dir(absolute_search_path, regex_pattern, file_pattern, source_dir, is_shadow=False, compiled_regex=compiled_regex)
148
+
149
+ return search_results
128
150
 
129
151
  except re.error as e:
130
152
  logger.error(f"Invalid regex pattern '{regex_pattern}': {e}")
@@ -132,3 +154,26 @@ class SearchFilesToolResolver(BaseToolResolver):
132
154
  except Exception as e:
133
155
  logger.error(f"Error during file search: {str(e)}")
134
156
  return ToolResult(success=False, message=f"An unexpected error occurred during search: {str(e)}")
157
+
158
def resolve(self) -> ToolResult:
    """Execute the search_files tool call and return a ToolResult.

    Inputs are resolved as follows:
    - The search root is ``args.source_dir`` (default ``"."``).
    - The target is ``tool.path`` joined onto that root.
    - ``tool.file_pattern`` defaults to ``"*"``.

    The shadow-aware implementation is used when a shadow manager is present;
    otherwise the direct one is used. Both return either a list of match dicts
    or a ToolResult. A list is wrapped into a successful ToolResult with a
    match-count message; a ToolResult is passed through unchanged.
    """
    tool = self.tool
    root_dir = self.args.source_dir or "."
    call_args = (
        tool.path,
        tool.regex,
        tool.file_pattern or "*",
        root_dir,
        os.path.abspath(root_dir),
        os.path.abspath(os.path.join(root_dir, tool.path)),
    )

    search_impl = self.search_files_with_shadow if self.shadow_manager else self.search_files_normal
    outcome = search_impl(*call_args)

    # Implementations report failures by returning a ToolResult directly.
    if not isinstance(outcome, list):
        return outcome

    summary = f"Search completed. Found {len(outcome)} matches."
    logger.info(summary)
    return ToolResult(success=True, message=summary, content=outcome)
@@ -2,10 +2,15 @@ from typing import Optional, Dict, Any
2
2
  import os
3
3
  from loguru import logger
4
4
  from datetime import datetime
5
+ import typing
5
6
 
6
7
  from ..types import TalkToGroupTool, ToolResult
7
8
  from ..tools.base_tool_resolver import BaseToolResolver
8
9
  from ..agent_hub import AgentHub, Group
10
+ from autocoder.common import AutoCoderArgs
11
+
12
+ if typing.TYPE_CHECKING:
13
+ from ..base_agent import BaseAgent
9
14
 
10
15
 
11
16
  class TalkToGroupToolResolver(BaseToolResolver):
@@ -2,10 +2,15 @@ from typing import Optional, Dict, Any
2
2
  import os
3
3
  from loguru import logger
4
4
  from datetime import datetime
5
+ import typing
5
6
 
6
7
  from ..types import TalkToTool, ToolResult
7
8
  from ..tools.base_tool_resolver import BaseToolResolver
8
9
  from ..agent_hub import AgentHub
10
+ from autocoder.common import AutoCoderArgs
11
+
12
+ if typing.TYPE_CHECKING:
13
+ from ..base_agent import BaseAgent
9
14
 
10
15
 
11
16
  class TalkToToolResolver(BaseToolResolver):
@@ -1,9 +1,12 @@
1
1
  import os
2
- from typing import Dict, Any, Optional
2
+ from typing import Dict, Any, Optional,List
3
3
  from autocoder.agent.base_agentic.types import WriteToFileTool, ToolResult # Import ToolResult from types
4
4
  from autocoder.agent.base_agentic.tools.base_tool_resolver import BaseToolResolver
5
5
  from loguru import logger
6
6
  from autocoder.common import AutoCoderArgs
7
+ from autocoder.common.file_checkpoint.models import FileChange as CheckpointFileChange
8
+ from autocoder.common.file_checkpoint.manager import FileChangeManager as CheckpointFileChangeManager
9
+ from autocoder.linters.models import IssueSeverity, FileLintResult
7
10
  import typing
8
11
 
9
12
  if typing.TYPE_CHECKING:
@@ -13,9 +16,148 @@ class WriteToFileToolResolver(BaseToolResolver):
13
16
  def __init__(self, agent: Optional['BaseAgent'], tool: WriteToFileTool, args: AutoCoderArgs):
14
17
  super().__init__(agent, tool, args)
15
18
  self.tool: WriteToFileTool = tool # For type hinting
19
+ self.args = args
16
20
  self.shadow_manager = self.agent.shadow_manager if self.agent else None
21
+ self.shadow_linter = self.agent.shadow_linter if self.agent else None
22
+
23
def _filter_lint_issues(self, lint_result: FileLintResult, levels: Optional[List[IssueSeverity]] = None):
    """Return a copy of ``lint_result`` that keeps only issues of the given severities.

    Args:
        lint_result: Lint result for a single file. It is returned as-is when it
            is falsy or has no issues.
        levels: Severities to keep. When None (the default), ERROR and WARNING
            are kept.

    Returns:
        A copy of ``lint_result`` in which ``issues`` holds only the kept issues.
        ``error_count``, ``warning_count`` and ``info_count`` are recomputed from
        that list. The object passed in is not modified.
    """
    import copy  # local import: only this helper needs it

    if not lint_result or not lint_result.issues:
        return lint_result

    # Use a None sentinel rather than a list literal as the default: a list
    # default is built once at definition time and shared by every call.
    if levels is None:
        levels = [IssueSeverity.ERROR, IssueSeverity.WARNING]

    filtered_issues = [issue for issue in lint_result.issues if issue.severity in levels]

    # Work on a copy so the caller's object keeps its original issues and counts.
    # deepcopy is used rather than copy.copy: a shallow copy of some model types
    # may share the instance attribute dict with the original, which would let
    # the assignments below leak back into it.
    filtered_result = copy.deepcopy(lint_result)
    filtered_result.issues = filtered_issues

    filtered_result.error_count = sum(1 for issue in filtered_issues if issue.severity == IssueSeverity.ERROR)
    filtered_result.warning_count = sum(1 for issue in filtered_issues if issue.severity == IssueSeverity.WARNING)
    filtered_result.info_count = sum(1 for issue in filtered_issues if issue.severity == IssueSeverity.INFO)

    return filtered_result
53
+
54
def _format_lint_issues(self, lint_result:FileLintResult):
    """Render one file's lint issues as human-readable text, one line per issue.

    The severity label comes from ``issue.severity.value``:
    - 3 → error label
    - 2 → warning label
    - anything else → info label

    The location always includes the line number, and the column too when it
    is truthy.

    Args:
        lint_result: Lint result for a single file; its ``issues`` are iterated.

    Returns:
        str: The rendered lines joined with newlines ("" when there are no issues).
    """
    def render(issue):
        severity_value = issue.severity.value
        if severity_value == 3:
            label = "错误"
        elif severity_value == 2:
            label = "警告"
        else:
            label = "信息"

        position = issue.position
        location = f"第{position.line}行"
        if position.column:
            location = location + f", 第{position.column}列"

        return f" - [{label}] {location}: {issue.message} (规则: {issue.code})"

    return "\n".join(render(issue) for issue in lint_result.issues)
77
+
78
+
79
def write_file_normal(self, file_path: str, content: str, source_dir: str, abs_project_dir: str, abs_file_path: str) -> ToolResult:
    """Write ``content`` to ``abs_file_path`` (not via the shadow manager), then optionally lint it.

    The file is written through the agent's checkpoint manager when one exists;
    otherwise it is written directly. When an agent is present, the change is
    also recorded with ``agent.record_file_change`` as ``"added"``.

    Args:
        file_path: Path as given in the tool call. It is used as the checkpoint
            key, in log lines and as the returned message.
        content: Full text to write.
        source_dir: Project source directory. It is not used by this method.
        abs_project_dir: Absolute project root, used to compute the path
            recorded with the agent.
        abs_file_path: Absolute destination path; the caller has already
            checked that it lies inside the project.

    Returns:
        ToolResult: On success, ``content`` is a dict with the written
        ``content``. When ``args.enable_auto_fix_lint`` is set, the dict also
        holds ``lint_results`` (``has_issues``, ``issues``), and the message
        gets the lint summary appended. If writing raises, a failed ToolResult
        carrying the error text is returned.
    """
    try:
        os.makedirs(os.path.dirname(abs_file_path), exist_ok=True)

        if self.agent:
            rel_path = os.path.relpath(abs_file_path, abs_project_dir)
            self.agent.record_file_change(rel_path, "added", diff=None, content=content)

        if self.agent and self.agent.checkpoint_manager:
            # Route the write through the checkpoint manager together with the
            # current conversation, grouped by the event file.
            changes = {
                file_path: CheckpointFileChange(
                    file_path=file_path,
                    content=content,
                    is_deletion=False,
                    is_new=True,
                )
            }
            change_group_id = self.args.event_file

            self.agent.checkpoint_manager.apply_changes_with_conversation(
                changes=changes,
                conversations=self.agent.current_conversations,
                change_group_id=change_group_id,
                metadata={"event_file": self.args.event_file},
            )
        else:
            with open(abs_file_path, 'w', encoding='utf-8') as f:
                f.write(content)
            logger.info(f"Successfully wrote to file: {file_path}")

        message = f"{file_path}"
        result_content = {
            "content": content,
        }

        enable_lint = self.args.enable_auto_fix_lint
        if enable_lint:
            # Lint the validated absolute path: file_path is relative to the
            # project and would resolve against the process CWD instead.
            has_lint_issues, formatted_issues, lint_message = self._lint_written_file(abs_file_path)
            message = message + "\n" + lint_message
            result_content["lint_results"] = {
                "has_issues": has_lint_issues,
                "issues": formatted_issues if has_lint_issues else None,
            }
        else:
            logger.info("代码质量检查已禁用")

        return ToolResult(success=True, message=message, content=result_content)
    except Exception as e:
        logger.error(f"Error writing to file '{file_path}': {str(e)}")
        return ToolResult(success=False, message=f"An error occurred while writing to the file: {str(e)}")

def _lint_written_file(self, lint_path: str):
    """Lint a just-written file and summarise its ERROR/WARNING issues.

    Args:
        lint_path: Path handed to ``agent.linter.lint_file``.

    Returns:
        tuple: ``(has_lint_issues, formatted_issues, lint_message)``.
        ``formatted_issues`` and ``lint_message`` are "" when nothing is
        reported. A failure while linting is logged and reported through
        ``lint_message`` instead of being raised.
    """
    has_lint_issues = False
    formatted_issues = ""
    lint_message = ""
    try:
        # Check self.agent first: the resolver may be built without an agent,
        # and dereferencing None would otherwise be misreported as a lint failure.
        if self.agent and self.agent.linter:
            lint_results = self.agent.linter.lint_file(lint_path)
            if lint_results and lint_results.issues:
                # Keep only ERROR and WARNING issues.
                filtered_results = self._filter_lint_issues(lint_results)
                if filtered_results.issues:
                    has_lint_issues = True
                    formatted_issues = self._format_lint_issues(filtered_results)
                    lint_message = f"\n\n代码质量检查发现 {len(filtered_results.issues)} 个问题"
    except Exception as e:
        logger.error(f"Lint 检查失败: {str(e)}")
        lint_message = "\n\n尝试进行代码质量检查时出错。"
    return has_lint_issues, formatted_issues, lint_message
17
158
 
18
159
  def resolve(self) -> ToolResult:
160
+ """Resolve the write file tool by calling the appropriate implementation"""
19
161
  file_path = self.tool.path
20
162
  content = self.tool.content
21
163
  source_dir = self.args.source_dir or "."
@@ -25,34 +167,5 @@ class WriteToFileToolResolver(BaseToolResolver):
25
167
  # Security check: ensure the path is within the source directory
26
168
  if not abs_file_path.startswith(abs_project_dir):
27
169
  return ToolResult(success=False, message=f"Error: Access denied. Attempted to write file outside the project directory: {file_path}")
28
-
29
- try:
30
- if self.shadow_manager:
31
- shadow_path = self.shadow_manager.to_shadow_path(abs_file_path)
32
- # Ensure shadow directory exists
33
- os.makedirs(os.path.dirname(shadow_path), exist_ok=True)
34
- with open(shadow_path, 'w', encoding='utf-8') as f:
35
- f.write(content)
36
- logger.info(f"[Shadow] Successfully wrote shadow file: {shadow_path}")
37
-
38
- # 回调AgenticEdit,记录变更
39
- if self.agent:
40
- rel_path = os.path.relpath(abs_file_path, abs_project_dir)
41
- self.agent.record_file_change(rel_path, "added", diff=None, content=content)
42
-
43
- return ToolResult(success=True, message=f"Successfully wrote to file (shadow): {file_path}", content=content)
44
- else:
45
- # No shadow manager fallback to original file
46
- os.makedirs(os.path.dirname(abs_file_path), exist_ok=True)
47
- with open(abs_file_path, 'w', encoding='utf-8') as f:
48
- f.write(content)
49
- logger.info(f"Successfully wrote to file: {file_path}")
50
-
51
- if self.agent:
52
- rel_path = os.path.relpath(abs_file_path, abs_project_dir)
53
- self.agent.record_file_change(rel_path, "added", diff=None, content=content)
54
-
55
- return ToolResult(success=True, message=f"Successfully wrote to file: {file_path}", content=content)
56
- except Exception as e:
57
- logger.error(f"Error writing to file '{file_path}': {str(e)}")
58
- return ToolResult(success=False, message=f"An error occurred while writing to the file: {str(e)}")
170
+
171
+ return self.write_file_normal(file_path, content, source_dir, abs_project_dir, abs_file_path)
@@ -6,6 +6,8 @@ from typing import Optional, List
6
6
  import byzerllm
7
7
  from autocoder.rag.api_server import serve, ServerArgs
8
8
  from autocoder.rag.rag_entry import RAGFactory
9
+ from autocoder.rag.agentic_rag import AgenticRAG
10
+ from autocoder.rag.long_context_rag import LongContextRAG
9
11
  from autocoder.rag.llm_wrapper import LLWrapper
10
12
  from autocoder.common import AutoCoderArgs
11
13
  from autocoder.lang import lang_desc
@@ -301,6 +303,7 @@ def main(input_args: Optional[List[str]] = None):
301
303
  help="Document directory path, also used as the root directory for serving static files"
302
304
  )
303
305
  serve_parser.add_argument("--enable_local_image_host", action="store_true", help=" enable local image host for local Chat app")
306
+ serve_parser.add_argument("--agentic", action="store_true", help="使用 AgenticRAG 而不是 LongContextRAG")
304
307
  serve_parser.add_argument("--tokenizer_path", default=tokenizer_path, help="")
305
308
  serve_parser.add_argument(
306
309
  "--collections", default="", help="Collection name for indexing"
@@ -432,6 +435,18 @@ def main(input_args: Optional[List[str]] = None):
432
435
  help="The model used for embedding documents",
433
436
  )
434
437
 
438
+ serve_parser.add_argument(
439
+ "--agentic_model",
440
+ default="",
441
+ help="The model used for agentic operations",
442
+ )
443
+
444
+ serve_parser.add_argument(
445
+ "--context_prune_model",
446
+ default="",
447
+ help="The model used for context pruning",
448
+ )
449
+
435
450
  # Benchmark command
436
451
  benchmark_parser = subparsers.add_parser(
437
452
  "benchmark", help="Benchmark LLM client performance"
@@ -622,6 +637,18 @@ def main(input_args: Optional[List[str]] = None):
622
637
  emb_model.skip_nontext_check = True
623
638
  llm.setup_sub_client("emb_model", emb_model)
624
639
 
640
+ if args.agentic_model:
641
+ agentic_model = byzerllm.ByzerLLM()
642
+ agentic_model.setup_default_model_name(args.agentic_model)
643
+ agentic_model.skip_nontext_check = True
644
+ llm.setup_sub_client("agentic_model", agentic_model)
645
+
646
+ if args.context_prune_model:
647
+ context_prune_model = byzerllm.ByzerLLM()
648
+ context_prune_model.setup_default_model_name(args.context_prune_model)
649
+ context_prune_model.skip_nontext_check = True
650
+ llm.setup_sub_client("context_prune_model", context_prune_model)
651
+
625
652
  # 当启用hybrid_index时,检查必要的组件
626
653
  if auto_coder_args.enable_hybrid_index:
627
654
  if not args.emb_model and not llm.is_model_exist("emb"):
@@ -698,7 +725,7 @@ def main(input_args: Optional[List[str]] = None):
698
725
  "saas.max_output_tokens": model_info.get("max_output_tokens", 8096)
699
726
  }
700
727
  )
701
- llm.setup_sub_client("qa_model", qa_model)
728
+ llm.setup_sub_client("qa_model", qa_model)
702
729
 
703
730
  if args.emb_model:
704
731
  model_info = models_module.get_model_by_name(args.emb_model)
@@ -717,22 +744,52 @@ def main(input_args: Optional[List[str]] = None):
717
744
  )
718
745
  llm.setup_sub_client("emb_model", emb_model)
719
746
 
747
+ if args.agentic_model:
748
+ model_info = models_module.get_model_by_name(args.agentic_model)
749
+ agentic_model = byzerllm.SimpleByzerLLM(default_model_name=args.agentic_model)
750
+ agentic_model.deploy(
751
+ model_path="",
752
+ pretrained_model_type=model_info["model_type"],
753
+ udf_name=args.agentic_model,
754
+ infer_params={
755
+ "saas.base_url": model_info["base_url"],
756
+ "saas.api_key": model_info["api_key"],
757
+ "saas.model": model_info["model_name"],
758
+ "saas.is_reasoning": model_info["is_reasoning"],
759
+ "saas.max_output_tokens": model_info.get("max_output_tokens", 8096)
760
+ }
761
+ )
762
+ llm.setup_sub_client("agentic_model", agentic_model)
763
+
764
+ if args.context_prune_model:
765
+ model_info = models_module.get_model_by_name(args.context_prune_model)
766
+ context_prune_model = byzerllm.SimpleByzerLLM(default_model_name=args.context_prune_model)
767
+ context_prune_model.deploy(
768
+ model_path="",
769
+ pretrained_model_type=model_info["model_type"],
770
+ udf_name=args.context_prune_model,
771
+ infer_params={
772
+ "saas.base_url": model_info["base_url"],
773
+ "saas.api_key": model_info["api_key"],
774
+ "saas.model": model_info["model_name"],
775
+ "saas.is_reasoning": model_info["is_reasoning"],
776
+ "saas.max_output_tokens": model_info.get("max_output_tokens", 8096)
777
+ }
778
+ )
779
+ llm.setup_sub_client("context_prune_model", context_prune_model)
780
+
720
781
  if args.enable_hybrid_index:
721
782
  if not args.emb_model:
722
783
  raise Exception("When enable_hybrid_index is true, an 'emb' model must be specified")
723
784
 
724
- if server_args.doc_dir:
725
- auto_coder_args.rag_type = "simple"
785
+ if server_args.doc_dir:
726
786
  auto_coder_args.rag_build_name = generate_unique_name_from_path(server_args.doc_dir)
727
- rag = RAGFactory.get_rag(
728
- llm=llm,
729
- args=auto_coder_args,
730
- path=server_args.doc_dir,
731
- tokenizer_path=server_args.tokenizer_path,
732
- )
787
+ if args.agentic:
788
+ rag = AgenticRAG(llm=llm, args=auto_coder_args, path=server_args.doc_dir, tokenizer_path=server_args.tokenizer_path)
789
+ else:
790
+ rag = LongContextRAG(llm=llm, args=auto_coder_args, path=server_args.doc_dir, tokenizer_path=server_args.tokenizer_path)
733
791
  else:
734
- auto_coder_args.rag_build_name = generate_unique_name_from_path("")
735
- rag = RAGFactory.get_rag(llm=llm, args=auto_coder_args, path="")
792
+ raise Exception("doc_dir is required")
736
793
 
737
794
  llm_wrapper = LLWrapper(llm=llm, rag=rag)
738
795
  # Save service info