auto-coder 0.1.259__py3-none-any.whl → 0.1.261__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of auto-coder might be problematic; review the advisory details before upgrading.

Files changed (36)
  1. {auto_coder-0.1.259.dist-info → auto_coder-0.1.261.dist-info}/METADATA +1 -1
  2. {auto_coder-0.1.259.dist-info → auto_coder-0.1.261.dist-info}/RECORD +36 -27
  3. autocoder/agent/auto_review_commit.py +51 -24
  4. autocoder/auto_coder.py +24 -1
  5. autocoder/chat_auto_coder.py +377 -399
  6. autocoder/chat_auto_coder_lang.py +20 -0
  7. autocoder/commands/__init__.py +0 -0
  8. autocoder/commands/auto_command.py +1174 -0
  9. autocoder/commands/tools.py +533 -0
  10. autocoder/common/__init__.py +8 -0
  11. autocoder/common/auto_coder_lang.py +61 -8
  12. autocoder/common/auto_configure.py +304 -0
  13. autocoder/common/code_auto_merge.py +2 -2
  14. autocoder/common/code_auto_merge_diff.py +2 -2
  15. autocoder/common/code_auto_merge_editblock.py +2 -2
  16. autocoder/common/code_auto_merge_strict_diff.py +2 -2
  17. autocoder/common/code_modification_ranker.py +8 -7
  18. autocoder/common/command_completer.py +557 -0
  19. autocoder/common/conf_validator.py +245 -0
  20. autocoder/common/conversation_pruner.py +131 -0
  21. autocoder/common/git_utils.py +82 -1
  22. autocoder/common/index_import_export.py +101 -0
  23. autocoder/common/result_manager.py +115 -0
  24. autocoder/common/shells.py +22 -6
  25. autocoder/common/utils_code_auto_generate.py +2 -2
  26. autocoder/dispacher/actions/action.py +45 -4
  27. autocoder/dispacher/actions/plugins/action_regex_project.py +13 -1
  28. autocoder/index/filter/quick_filter.py +22 -7
  29. autocoder/utils/auto_coder_utils/chat_stream_out.py +13 -6
  30. autocoder/utils/project_structure.py +15 -0
  31. autocoder/utils/thread_utils.py +4 -0
  32. autocoder/version.py +1 -1
  33. {auto_coder-0.1.259.dist-info → auto_coder-0.1.261.dist-info}/LICENSE +0 -0
  34. {auto_coder-0.1.259.dist-info → auto_coder-0.1.261.dist-info}/WHEEL +0 -0
  35. {auto_coder-0.1.259.dist-info → auto_coder-0.1.261.dist-info}/entry_points.txt +0 -0
  36. {auto_coder-0.1.259.dist-info → auto_coder-0.1.261.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,245 @@
1
+ from typing import Any
2
+ from autocoder.utils import llms as llms_utils
3
+ from autocoder.common.auto_coder_lang import get_message_with_format
4
+
5
class ConfigValidationError(Exception):
    """Raised when a configuration value fails validation."""

    def __init__(self, message: str):
        super().__init__(message)
        # Keep the raw message accessible as an attribute for callers
        # that want it without going through str(exc).
        self.message = message
9
+
10
class ConfigValidator:
    """Validates configuration key/value pairs against CONFIG_SPEC.

    Each spec entry may declare:
      - type: expected Python type, or a list/tuple of candidate types
      - allowed: enumeration of permitted values
      - min / max: numeric bounds (either may be absent)
      - default: documented default (not applied by validate())
      - description: human-readable description (user-facing, kept in Chinese)
    """

    CONFIG_SPEC = {
        # Core configuration entries
        "auto_merge": {
            "type": str,
            "allowed": ["editblock", "diff", "wholefile"],
            "default": "editblock",
            "description": "代码合并方式(editblock/diff/wholefile)"
        },
        "editblock_similarity": {
            "type": float,
            "min": 0.0,
            "max": 1.0,
            "default": 0.9,
            "description": "代码块相似度阈值(0-1)"
        },
        "generate_times_same_model": {
            "type": int,
            "min": 1,
            "max": 5,
            "default": 1,
            "description": "同模型生成次数(1-5)"
        },
        "skip_filter_index": {
            "type": bool,
            "default": False,
            "description": "是否跳过根据用户的query自动查找上下文"
        },
        "skip_build_index": {
            "type": bool,
            "default": True,
            # Fix: the original description read "是否自动构建索引", which says
            # the opposite of what this skip_* flag actually controls.
            "description": "是否跳过自动构建索引"
        },
        "enable_global_memory": {
            "type": bool,
            "default": True,
            "description": "是否开启全局记忆"
        },
        "rank_times_same_model": {
            "type": int,
            "min": 1,
            "max": 3,
            "default": 1,
            "description": "相同模型重排序次数"
        },
        "human_as_model": {
            "type": bool,
            "default": False,
            "description": "是否以人类作为模型"
        },
        "skip_confirm": {
            "type": bool,
            "default": True,
            "description": "是否跳过确认步骤"
        },
        "silence": {
            "type": bool,
            "default": True,
            "description": "是否静默模式"
        },
        "include_project_structure": {
            "type": bool,
            "default": True,
            "description": "是否包含项目结构"
        },
        "product_mode": {
            "type": str,
            "allowed": ["lite", "pro"],
            "default": "lite",
            "description": "产品模式(lite/pro)"
        },
        "model": {
            "type": str,
            "default": "v3_chat",
            "description": "默认模型名称"
        },
        "chat_model": {
            "type": str,
            "default": "r1_chat",
            "description": "聊天模型名称"
        },
        "code_model": {
            "type": str,
            "default": "v3_chat",
            "description": "代码生成模型名称"
        },
        "index_filter_model": {
            "type": str,
            "default": "r1_chat",
            "description": "索引过滤模型名称"
        },
        "generate_rerank_model": {
            "type": str,
            "default": "r1_chat",
            "description": "生成重排序模型名称"
        },
        "emb_model": {
            "type": str,
            "default": "v3_chat",
            "description": "嵌入模型名称"
        },
        "vl_model": {
            "type": str,
            "default": "v3_chat",
            "description": "视觉语言模型名称"
        },
        "designer_model": {
            "type": str,
            "default": "v3_chat",
            "description": "设计模型名称"
        },
        "sd_model": {
            "type": str,
            "default": "v3_chat",
            "description": "稳定扩散模型名称"
        },
        "voice2text_model": {
            "type": str,
            "default": "v3_chat",
            "description": "语音转文本模型名称"
        },
        "commit_model": {
            "type": str,
            "default": "v3_chat",
            "description": "提交信息生成模型名称"
        }
    }

    # Keys whose value must name an existing model in lite mode.
    # Fix: the original list also contained "rank_times_same_model", which is
    # an integer repetition count, not a model name — validating it against
    # the model registry would always fail in lite mode.
    _MODEL_NAME_KEYS = (
        "chat_model", "code_model", "index_filter_model",
        "generate_rerank_model", "emb_model", "vl_model",
        "designer_model", "sd_model", "voice2text_model",
        "commit_model", "model",
    )

    @classmethod
    def validate(cls, key: str, value: Any, product_mode: str) -> Any:
        """Validate and convert *value* for configuration key *key*.

        Unknown keys are deliberately accepted (returns None) so that new
        or rare keys do not break configuration loading.

        Raises:
            ConfigValidationError: on type, range or enum violations, or
                (lite mode) when a referenced model is not registered.
        """
        spec = cls.CONFIG_SPEC.get(key)
        if not spec:
            # Unknown keys are not rejected (previously a commented-out raise).
            return None

        # --- type conversion and validation -----------------------------
        try:
            if isinstance(spec['type'], (list, tuple)):
                # Multiple candidate types: first successful conversion wins.
                for type_ in spec['type']:
                    try:
                        if type_ == bool:
                            return cls.validate_boolean(value)
                        converted_value = type_(value)
                        break
                    except ValueError:
                        continue
                else:
                    types_str = ', '.join(t.__name__ for t in spec['type'])
                    raise ConfigValidationError(
                        # Fix: dropped the pointless f-prefix on the constant key.
                        get_message_with_format("invalid_type_value",
                                                value=value,
                                                types=types_str)
                    )
            else:
                # Booleans need special parsing ("true"/"1"/"yes", ...).
                if spec['type'] == bool:
                    return cls.validate_boolean(value)
                converted_value = spec['type'](value)
        except ValueError:
            if isinstance(spec['type'], (list, tuple)):
                type_name = ', '.join(t.__name__ for t in spec['type'])
            else:
                type_name = spec['type'].__name__
            raise ConfigValidationError(
                get_message_with_format("invalid_type_value",
                                        value=value,
                                        types=type_name)
            )

        # --- range checks ------------------------------------------------
        # Fix: use .get() when building the message so a spec that declares
        # only one bound cannot raise KeyError while reporting the error.
        if 'min' in spec and converted_value < spec['min']:
            raise ConfigValidationError(
                get_message_with_format("value_out_of_range",
                                        value=converted_value,
                                        min=spec.get('min'),
                                        max=spec.get('max'))
            )

        if 'max' in spec and converted_value > spec['max']:
            raise ConfigValidationError(
                get_message_with_format("value_out_of_range",
                                        value=converted_value,
                                        min=spec.get('min'),
                                        max=spec.get('max'))
            )

        # --- enum check ---------------------------------------------------
        if 'allowed' in spec and converted_value not in spec['allowed']:
            raise ConfigValidationError(
                get_message_with_format("invalid_enum_value",
                                        value=converted_value,
                                        allowed=', '.join(map(str, spec['allowed'])))
            )

        # --- model existence check (lite mode only) -----------------------
        if product_mode == "lite" and key in cls._MODEL_NAME_KEYS:
            if not llms_utils.get_model_info(converted_value, product_mode):
                raise ConfigValidationError(
                    get_message_with_format("model_not_found", model=converted_value)
                )

        return converted_value

    @staticmethod
    def validate_boolean(value) -> bool:
        """Parse a boolean-ish value.

        Accepts real bools (pass-through) and the strings
        "true"/"1"/"yes" / "false"/"0"/"no" (case-insensitive).
        """
        # Fix: the original assumed a str and crashed on actual bools.
        if isinstance(value, bool):
            return value
        text = str(value).lower()
        if text in ("true", "1", "yes"):
            return True
        if text in ("false", "0", "no"):
            return False
        raise ConfigValidationError(
            get_message_with_format("invalid_boolean_value", value=value)
        )

    @classmethod
    def get_config_docs(cls) -> str:
        """Render a human-readable summary of every known configuration key."""
        docs = ["可用配置项:"]
        for key, spec in cls.CONFIG_SPEC.items():
            # Fix: support list/tuple type specs, which the original
            # crashed on (list has no __name__).
            if isinstance(spec['type'], (list, tuple)):
                type_name = ', '.join(t.__name__ for t in spec['type'])
            else:
                type_name = spec['type'].__name__
            desc = [
                f"- {key}: {spec['description']}",
                f"  类型: {type_name}",
                f"  默认值: {spec['default']}"
            ]
            if "allowed" in spec:
                desc.append(f"  允许值: {', '.join(spec['allowed'])}")
            if "min" in spec and "max" in spec:
                desc.append(f"  取值范围: {spec['min']}~{spec['max']}")
            docs.append("\n".join(desc))
        return "\n\n".join(docs)
@@ -0,0 +1,131 @@
1
+ from typing import List, Dict, Any, Union
2
+ import json
3
+ from pydantic import BaseModel
4
+ import byzerllm
5
+ from autocoder.common.printer import Printer
6
+ from autocoder.utils.llms import count_tokens
7
+ from loguru import logger
8
+
9
+ class PruneStrategy(BaseModel):
10
+ name: str
11
+ description: str
12
+ config: Dict[str, Any] = {"safe_zone_tokens": 0, "group_size": 4}
13
+
14
+ class ConversationPruner:
15
+ def __init__(self, llm: Union[byzerllm.ByzerLLM, byzerllm.SimpleByzerLLM],
16
+ safe_zone_tokens: int = 500, group_size: int = 4):
17
+ self.llm = llm
18
+ self.printer = Printer()
19
+ self.strategies = {
20
+ "summarize": PruneStrategy(
21
+ name="summarize",
22
+ description="对早期对话进行分组摘要,保留关键信息",
23
+ config={"safe_zone_tokens": safe_zone_tokens, "group_size": group_size}
24
+ ),
25
+ "truncate": PruneStrategy(
26
+ name="truncate",
27
+ description="分组截断最早的部分对话",
28
+ config={"safe_zone_tokens": safe_zone_tokens, "group_size": group_size}
29
+ ),
30
+ "hybrid": PruneStrategy(
31
+ name="hybrid",
32
+ description="先尝试分组摘要,如果仍超限则分组截断",
33
+ config={"safe_zone_tokens": safe_zone_tokens, "group_size": group_size}
34
+ )
35
+ }
36
+
37
+ def get_available_strategies(self) -> List[Dict[str, Any]]:
38
+ """获取所有可用策略"""
39
+ return [strategy.dict() for strategy in self.strategies.values()]
40
+
41
+ def prune_conversations(self, conversations: List[Dict[str, Any]],
42
+ strategy_name: str = "summarize") -> List[Dict[str, Any]]:
43
+ """
44
+ 根据策略修剪对话
45
+ Args:
46
+ conversations: 原始对话列表
47
+ strategy_name: 策略名称
48
+ Returns:
49
+ 修剪后的对话列表
50
+ """
51
+ current_tokens = count_tokens(json.dumps(conversations, ensure_ascii=False))
52
+ if current_tokens <= self.args.conversation_prune_safe_zone_tokens:
53
+ return conversations
54
+
55
+ strategy = self.strategies.get(strategy_name, self.strategies["summarize"])
56
+
57
+ if strategy.name == "summarize":
58
+ return self._summarize_prune(conversations, strategy.config)
59
+ elif strategy.name == "truncate":
60
+ return self._truncate_prune.with_llm(self.llm).run(conversations)
61
+ elif strategy.name == "hybrid":
62
+ pruned = self._summarize_prune(conversations, strategy.config)
63
+ if count_tokens(json.dumps(pruned, ensure_ascii=False)) > self.args.conversation_prune_safe_zone_tokens:
64
+ return self._truncate_prune(pruned)
65
+ return pruned
66
+ else:
67
+ logger.warning(f"Unknown strategy: {strategy_name}, using summarize instead")
68
+ return self._summarize_prune(conversations, strategy.config)
69
+
70
+ def _summarize_prune(self, conversations: List[Dict[str, Any]],
71
+ config: Dict[str, Any]) -> List[Dict[str, Any]]:
72
+ """摘要式剪枝"""
73
+ safe_zone_tokens = config.get("safe_zone_tokens", 50*1024)
74
+ group_size = config.get("group_size", 4)
75
+ processed_conversations = conversations.copy()
76
+
77
+ while True:
78
+ current_tokens = count_tokens(json.dumps(processed_conversations, ensure_ascii=False))
79
+ if current_tokens <= safe_zone_tokens:
80
+ break
81
+
82
+ # 找到要处理的对话组
83
+ early_conversations = processed_conversations[:-group_size]
84
+ recent_conversations = processed_conversations[-group_size:]
85
+
86
+ if not early_conversations:
87
+ break
88
+
89
+ # 生成当前组的摘要
90
+ group_summary = self._generate_summary.with_llm(self.llm).run(early_conversations[-group_size:])
91
+
92
+ # 更新对话历史
93
+ processed_conversations = early_conversations[:-group_size] + [
94
+ {"role": "user", "content": f"历史对话摘要:\n{group_summary}"},
95
+ {"role": "assistant", "content": f"收到"}
96
+ ] + recent_conversations
97
+
98
+ return processed_conversations
99
+
100
+ @byzerllm.prompt()
101
+ def _generate_summary(self, conversations: List[Dict[str, Any]]) -> str:
102
+ '''
103
+ 请用中文将以下对话浓缩为要点,保留关键决策和技术细节:
104
+
105
+ <history_conversations>
106
+ {{conversations}}
107
+ </history_conversations>
108
+ '''
109
+ return {
110
+ "conversations": json.dumps(conversations, ensure_ascii=False)
111
+ }
112
+
113
+ def _truncate_prune(self, conversations: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
114
+ """截断式剪枝"""
115
+ safe_zone_tokens = self.strategies["truncate"].config.get("safe_zone_tokens", 0)
116
+ group_size = self.strategies["truncate"].config.get("group_size", 4)
117
+ processed_conversations = conversations.copy()
118
+
119
+ while True:
120
+ current_tokens = count_tokens(json.dumps(processed_conversations, ensure_ascii=False))
121
+ if current_tokens <= safe_zone_tokens:
122
+ break
123
+
124
+ # 如果剩余对话不足一组,直接返回
125
+ if len(processed_conversations) <= group_size:
126
+ return []
127
+
128
+ # 移除最早的一组对话
129
+ processed_conversations = processed_conversations[group_size:]
130
+
131
+ return processed_conversations
@@ -1,7 +1,7 @@
1
1
  import os
2
2
  from git import Repo, GitCommandError
3
3
  from loguru import logger
4
- from typing import List, Optional
4
+ from typing import List, Optional, Dict
5
5
  from pydantic import BaseModel
6
6
  import byzerllm
7
7
  from rich.console import Console
@@ -11,6 +11,16 @@ from rich.table import Table
11
11
  from rich.text import Text
12
12
 
13
13
 
14
class FileChange(BaseModel):
    """Before/after snapshot of a single file touched by a commit."""
    file_path: str
    # Content in the parent commit; None when the file was newly added.
    before: Optional[str] = None
    # Content in the commit itself; None when the file was deleted.
    after: Optional[str] = None
18
+
19
class CommitChangesResult(BaseModel):
    """Outcome of looking up a commit's changes by its commit message."""
    success: bool
    # Mapping of file path -> FileChange; empty when success is False.
    changes: Dict[str, FileChange] = {}
    # Populated only on failure (commit not found, git error, root commit).
    error_message: Optional[str] = None
23
+
14
24
  class CommitResult(BaseModel):
15
25
  success: bool
16
26
  commit_message: Optional[str] = None
@@ -605,6 +615,77 @@ def generate_commit_message(changes_report: str) -> str:
605
615
  请输出commit message, 不要输出任何其他内容.
606
616
  '''
607
617
 
618
def get_commit_by_message(repo_path: str, message: str):
    """Look up the newest commit (across all refs) whose message matches *message*.

    Returns the Commit object, or None when nothing matches or git errors out.
    """
    repository = get_repo(repo_path)
    try:
        matched_hash = repository.git.log(
            "--all", f"--grep={message}", "--format=%H", "-n", "1"
        )
        if matched_hash:
            return repository.commit(matched_hash.strip())
        return None
    except GitCommandError as e:
        logger.error(f"Error finding commit: {e}")
        return None
628
+
629
def get_changes_by_commit_message(repo_path: str, message: str) -> CommitChangesResult:
    """
    Find the per-file changes introduced by the commit whose message matches *message*.

    Args:
        repo_path: Git repository path; falls back to the current directory when empty
        message: commit message to search for

    Returns:
        CommitChangesResult: before/after file contents keyed by file path
    """
    try:
        # Fix: resolve the repository path once and use it consistently.
        # The original resolved `repo` from os.getcwd() when repo_path was
        # empty, but still passed the empty repo_path to get_commit_by_message.
        resolved_path = repo_path or os.getcwd()
        repo = get_repo(resolved_path)
        commit = get_commit_by_message(resolved_path, message)

        if not commit:
            return CommitChangesResult(success=False, error_message="Commit not found")

        changes = {}

        # Diff the commit against its first parent.
        for diff_item in commit.parents[0].diff(commit):
            file_path = diff_item.a_path if diff_item.a_path else diff_item.b_path

            # Content before the change; None when the file was added.
            before_content = None
            try:
                if diff_item.a_blob:
                    before_content = repo.git.show(f"{commit.parents[0].hexsha}:{file_path}")
            except GitCommandError:
                pass  # the file may be newly added

            # Content after the change; None when the file was deleted.
            after_content = None
            try:
                if diff_item.b_blob:
                    after_content = repo.git.show(f"{commit.hexsha}:{file_path}")
            except GitCommandError:
                pass  # the file may have been deleted

            changes[file_path] = FileChange(
                file_path=file_path,
                before=before_content,
                after=after_content
            )

        return CommitChangesResult(success=True, changes=changes)

    except GitCommandError as e:
        logger.error(f"Error retrieving changes: {e}")
        return CommitChangesResult(success=False, error_message=str(e))
    except IndexError:
        # A root commit has no parent to diff against.
        return CommitChangesResult(success=False, error_message="Initial commit has no parent")
    except Exception as e:
        logger.error(f"Unexpected error: {e}")
        return CommitChangesResult(success=False, error_message=str(e))
688
+
608
689
  def print_commit_info(commit_result: CommitResult):
609
690
  console = Console()
610
691
  table = Table(
@@ -0,0 +1,101 @@
1
+ import os
2
+ import json
3
+ import shutil
4
+ from loguru import logger
5
+ from autocoder.common.printer import Printer
6
+
7
+
8
def export_index(project_root: str, export_path: str) -> bool:
    """
    Export index.json with absolute paths converted to relative paths.

    Args:
        project_root: Project root directory
        export_path: Path to export the index file

    Returns:
        bool: True if successful, False otherwise
    """
    # Fix: the docstring above was originally placed after the first
    # statement, where it was just a discarded string expression.
    printer = Printer()
    try:
        index_path = os.path.join(project_root, ".auto-coder", "index.json")
        if not os.path.exists(index_path):
            printer.print_in_terminal("index_not_found", path=index_path)
            return False

        # Read the index; keys are absolute file paths.
        with open(index_path, "r", encoding="utf-8") as f:
            index_data = json.load(f)

        # Convert absolute paths to project-relative paths.
        converted_data = {}
        for abs_path, data in index_data.items():
            try:
                rel_path = os.path.relpath(abs_path, project_root)
                data["module_name"] = rel_path
                converted_data[rel_path] = data
            except ValueError:
                # relpath can fail (e.g. different drives on Windows);
                # keep the absolute path rather than dropping the entry.
                printer.print_in_terminal("index_convert_path_fail", path=abs_path)
                converted_data[abs_path] = data

        # Write to the export location.
        export_file = os.path.join(export_path, "index.json")
        os.makedirs(export_path, exist_ok=True)
        with open(export_file, "w", encoding="utf-8") as f:
            json.dump(converted_data, f, indent=2)

        return True

    except Exception as e:
        printer.print_in_terminal("index_error", error=str(e))
        return False
52
+
53
def import_index(project_root: str, import_path: str) -> bool:
    """
    Import index.json with relative paths converted to absolute paths.

    Args:
        project_root: Project root directory
        import_path: Path containing the index file to import

    Returns:
        bool: True if successful, False otherwise
    """
    # Fix: the docstring above was originally placed after the first
    # statement, where it was just a discarded string expression.
    printer = Printer()
    try:
        import_file = os.path.join(import_path, "index.json")
        if not os.path.exists(import_file):
            printer.print_in_terminal("index_not_found", path=import_file)
            return False

        # Read the exported index; keys are project-relative paths.
        with open(import_file, "r", encoding="utf-8") as f:
            index_data = json.load(f)

        # Convert relative paths back to absolute paths.
        converted_data = {}
        for rel_path, data in index_data.items():
            try:
                abs_path = os.path.join(project_root, rel_path)
                data["module_name"] = abs_path
                converted_data[abs_path] = data
            except Exception:
                printer.print_in_terminal("index_convert_path_fail", path=rel_path)
                converted_data[rel_path] = data

        # Back up any existing index before overwriting it.
        index_path = os.path.join(project_root, ".auto-coder", "index.json")
        if os.path.exists(index_path):
            backup_path = index_path + ".bak"
            shutil.copy2(index_path, backup_path)
            printer.print_in_terminal("index_backup_success", path=backup_path)

        # Write the new index.
        with open(index_path, "w", encoding="utf-8") as f:
            json.dump(converted_data, f, indent=2)

        return True

    except Exception as e:
        printer.print_in_terminal("index_error", error=str(e))
        return False
@@ -0,0 +1,115 @@
1
+ import os
2
+ import json
3
+ import time
4
+ from typing import List, Dict, Any, Optional
5
+ from pydantic import BaseModel, Field
6
+
7
class ResultItem(BaseModel):
    """Data model for a single result record stored in results.jsonl."""
    content: str = Field(..., description="结果内容")
    meta: Dict[str, Any] = Field(default_factory=dict, description="元数据信息")
    # Unix timestamp assigned at creation time.
    time: int = Field(default_factory=lambda: int(time.time()), description="记录时间戳")

    class Config:
        # Allow non-pydantic types inside `meta` values.
        arbitrary_types_allowed = True
15
+
16
class ResultManager:
    """Manages an append-only results.jsonl file under .auto-coder/results."""

    def __init__(self, source_dir: Optional[str] = None):
        """
        Initialize the result manager.

        Args:
            source_dir: optional source directory; defaults to the current directory
        """
        self.source_dir = source_dir or os.getcwd()
        self.result_dir = os.path.join(self.source_dir, ".auto-coder", "results")
        self.result_file = os.path.join(self.result_dir, "results.jsonl")
        os.makedirs(self.result_dir, exist_ok=True)

    def append(self, content: str, meta: Optional[Dict[str, Any]] = None) -> ResultItem:
        """
        Append a new result record.

        Args:
            content: result content
            meta: optional metadata

        Returns:
            ResultItem: the newly created record
        """
        result_item = ResultItem(
            content=content,
            meta=meta or {},
        )

        with open(self.result_file, "a", encoding="utf-8") as f:
            f.write(result_item.model_dump_json() + "\n")

        return result_item

    def add_result(self, content: str, meta: Optional[Dict[str, Any]] = None) -> ResultItem:
        """Alias for append(), kept for API compatibility."""
        return self.append(content, meta)

    def get_last(self) -> Optional[ResultItem]:
        """
        Return the most recent record, or None when the file is empty/missing.
        """
        if not os.path.exists(self.result_file):
            return None

        with open(self.result_file, "r", encoding="utf-8") as f:
            lines = f.readlines()
        # Fix: scan backwards skipping blank lines instead of blindly
        # parsing lines[-1], which breaks on a trailing blank line.
        for line in reversed(lines):
            line = line.strip()
            if line:
                return ResultItem.model_validate_json(line)
        return None

    def get_all(self) -> List[ResultItem]:
        """
        Return every record in insertion order.
        """
        if not os.path.exists(self.result_file):
            return []

        results = []
        with open(self.result_file, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if line:  # skip blank lines
                    results.append(ResultItem.model_validate_json(line))
        return results

    def get_by_time_range(self,
                          start_time: Optional[int] = None,
                          end_time: Optional[int] = None) -> List[ResultItem]:
        """
        Return records whose timestamps fall within [start_time, end_time].

        Args:
            start_time: inclusive lower bound (Unix timestamp), or None for no bound
            end_time: inclusive upper bound (Unix timestamp), or None for no bound

        Returns:
            List[ResultItem]: the matching records
        """
        results = []
        for item in self.get_all():
            # Fix: compare against None explicitly so a bound of 0 is honored
            # (the original truthiness test silently ignored start_time=0).
            if start_time is not None and item.time < start_time:
                continue
            if end_time is not None and item.time > end_time:
                continue
            results.append(item)
        return results

    def clear(self) -> None:
        """Delete the backing file, removing all records."""
        if os.path.exists(self.result_file):
            os.remove(self.result_file)