auto-coder 0.1.305__py3-none-any.whl → 0.1.307__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (43)
  1. {auto_coder-0.1.305.dist-info → auto_coder-0.1.307.dist-info}/METADATA +1 -1
  2. {auto_coder-0.1.305.dist-info → auto_coder-0.1.307.dist-info}/RECORD +43 -38
  3. autocoder/agent/auto_demand_organizer.py +13 -20
  4. autocoder/agent/auto_filegroup.py +10 -16
  5. autocoder/agent/auto_learn_from_commit.py +25 -33
  6. autocoder/agent/auto_review_commit.py +15 -64
  7. autocoder/auto_coder.py +6 -8
  8. autocoder/auto_coder_runner.py +153 -8
  9. autocoder/chat_auto_coder.py +9 -1
  10. autocoder/chat_auto_coder_lang.py +552 -278
  11. autocoder/commands/auto_command.py +31 -7
  12. autocoder/common/__init__.py +6 -0
  13. autocoder/common/action_yml_file_manager.py +75 -37
  14. autocoder/common/auto_coder_lang.py +737 -401
  15. autocoder/common/code_auto_generate.py +104 -16
  16. autocoder/common/code_auto_generate_diff.py +101 -10
  17. autocoder/common/code_auto_generate_editblock.py +103 -9
  18. autocoder/common/code_auto_generate_strict_diff.py +99 -9
  19. autocoder/common/code_auto_merge.py +8 -0
  20. autocoder/common/code_auto_merge_diff.py +8 -0
  21. autocoder/common/code_auto_merge_editblock.py +7 -0
  22. autocoder/common/code_auto_merge_strict_diff.py +5 -0
  23. autocoder/common/code_modification_ranker.py +9 -3
  24. autocoder/common/command_completer.py +12 -0
  25. autocoder/common/command_generator.py +5 -4
  26. autocoder/common/git_utils.py +86 -63
  27. autocoder/common/stream_out_type.py +8 -1
  28. autocoder/common/utils_code_auto_generate.py +29 -3
  29. autocoder/dispacher/__init__.py +18 -19
  30. autocoder/dispacher/actions/action.py +0 -132
  31. autocoder/index/filter/quick_filter.py +6 -3
  32. autocoder/memory/__init__.py +7 -0
  33. autocoder/memory/active_context_manager.py +649 -0
  34. autocoder/memory/active_package.py +469 -0
  35. autocoder/memory/async_processor.py +161 -0
  36. autocoder/memory/directory_mapper.py +67 -0
  37. autocoder/utils/auto_coder_utils/chat_stream_out.py +5 -0
  38. autocoder/utils/project_structure.py +35 -1
  39. autocoder/version.py +1 -1
  40. {auto_coder-0.1.305.dist-info → auto_coder-0.1.307.dist-info}/LICENSE +0 -0
  41. {auto_coder-0.1.305.dist-info → auto_coder-0.1.307.dist-info}/WHEEL +0 -0
  42. {auto_coder-0.1.305.dist-info → auto_coder-0.1.307.dist-info}/entry_points.txt +0 -0
  43. {auto_coder-0.1.305.dist-info → auto_coder-0.1.307.dist-info}/top_level.txt +0 -0
@@ -6,12 +6,16 @@ from autocoder.utils.queue_communicate import queue_communicate, CommunicateEven
  from autocoder.common import sys_prompt
  from concurrent.futures import ThreadPoolExecutor
  import json
- from autocoder.common.utils_code_auto_generate import chat_with_continue
+ from autocoder.common.utils_code_auto_generate import chat_with_continue,stream_chat_with_continue,ChatWithContinueResult
+ from autocoder.utils.auto_coder_utils.chat_stream_out import stream_out
+ from autocoder.common.stream_out_type import CodeGenerateStreamOutType
+ from autocoder.common.auto_coder_lang import get_message_with_format
  from autocoder.common.printer import Printer
  from autocoder.rag.token_counter import count_tokens
  from autocoder.utils import llms as llm_utils
  from autocoder.common import SourceCodeList
  from autocoder.privacy.model_filter import ModelPathFilter
+ from autocoder.memory.active_context_manager import ActiveContextManager
  class CodeAutoGenerateStrictDiff:
      def __init__(
          self, llm: byzerllm.ByzerLLM, args: AutoCoderArgs, action=None
@@ -31,7 +35,7 @@ class CodeAutoGenerateStrictDiff:

      @byzerllm.prompt(llm=lambda self: self.llm)
      def multi_round_instruction(
-         self, instruction: str, content: str, context: str = ""
+         self, instruction: str, content: str, context: str = "", package_context: str = ""
      ) -> str:
          """
          If you need to generate code, then for each file that needs to change, write out a change similar to a unified diff, just like `diff -U0` would produce.
@@ -124,6 +128,13 @@ class CodeAutoGenerateStrictDiff:
          </files>
          {%- endif %}

+         {%- if package_context %}
+         Below is some information about the files above (including recent changes):
+         <package_context>
+         {{ package_context }}
+         </package_context>
+         {%- endif %}
+
          {%- if context %}
          <extra_context>
          {{ context }}
@@ -152,7 +163,7 @@ class CodeAutoGenerateStrictDiff:

      @byzerllm.prompt(llm=lambda self: self.llm)
      def single_round_instruction(
-         self, instruction: str, content: str, context: str = ""
+         self, instruction: str, content: str, context: str = "", package_context: str = ""
      ) -> str:
          """
          If you need to generate code, then for each file that needs to change, write out a change similar to a unified diff, just like `diff -U0` would produce.
@@ -248,6 +259,13 @@ class CodeAutoGenerateStrictDiff:
          </files>
          {%- endif %}

+         {%- if package_context %}
+         Below is some information about the files above (including recent changes):
+         <package_context>
+         {{ package_context }}
+         </package_context>
+         {%- endif %}
+
          {%- if context %}
          <extra_context>
          {{ context }}
@@ -278,9 +296,28 @@ class CodeAutoGenerateStrictDiff:
          llm_config = {"human_as_model": self.args.human_as_model}
          source_content = source_code_list.to_str()

+         # Gather package context information
+         package_context = ""
+
+         if self.args.enable_active_context:
+             # Initialize the active context manager
+             active_context_manager = ActiveContextManager(self.llm, self.args.source_dir)
+             # Load the active context information for the involved files
+             result = active_context_manager.load_active_contexts_for_files(
+                 [source.module_name for source in source_code_list.sources]
+             )
+             # Format the active context information as text
+             if result.contexts:
+                 package_context_parts = []
+                 for dir_path, context in result.contexts.items():
+                     package_context_parts.append(f"<package_info>{context.content}</package_info>")
+
+                 package_context = "\n".join(package_context_parts)
+
          if self.args.template == "common":
              init_prompt = self.single_round_instruction.prompt(
-                 instruction=query, content=source_content, context=self.args.context
+                 instruction=query, content=source_content, context=self.args.context,
+                 package_context=package_context
              )
          elif self.args.template == "auto_implement":
              init_prompt = self.auto_implement_function.prompt(
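The block above is the heart of the new `enable_active_context` path: it asks ActiveContextManager for per-directory summaries of the files being edited and folds them into a single `package_context` string for the prompt. Below is a minimal, self-contained sketch of that assembly step; `ActiveContext`, `LoadResult`, and `build_package_context` are simplified stand-ins for illustration, not the package's real classes.

```python
from dataclasses import dataclass, field
from typing import Dict


@dataclass
class ActiveContext:
    content: str  # human-readable summary of one package directory


@dataclass
class LoadResult:
    contexts: Dict[str, ActiveContext] = field(default_factory=dict)


def build_package_context(result: LoadResult) -> str:
    """Wrap each directory summary in <package_info> tags and join them."""
    if not result.contexts:
        return ""
    parts = [f"<package_info>{ctx.content}</package_info>"
             for _dir_path, ctx in result.contexts.items()]
    return "\n".join(parts)


if __name__ == "__main__":
    result = LoadResult(contexts={
        "autocoder/memory": ActiveContext("New ActiveContextManager module added."),
        "autocoder/common": ActiveContext("chat_with_continue now requires args."),
    })
    print(build_package_context(result))
```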
@@ -336,10 +373,39 @@ class CodeAutoGenerateStrictDiff:
              if model_names_list:
                  model_name = model_names_list[0]

-             for _ in range(self.generate_times_same_model):
+             for i in range(self.generate_times_same_model):
                  model_names.append(model_name)
-                 futures.append(executor.submit(
-                     chat_with_continue, llm=llm, conversations=conversations, llm_config=llm_config))
+                 if i==0:
+                     def job():
+                         stream_generator = stream_chat_with_continue(
+                             llm=llm,
+                             conversations=conversations,
+                             llm_config=llm_config,
+                             args=self.args
+                         )
+                         full_response, last_meta = stream_out(
+                             stream_generator,
+                             model_name=model_name,
+                             title=get_message_with_format(
+                                 "code_generate_title", model_name=model_name),
+                             args=self.args,
+                             extra_meta={
+                                 "stream_out_type": CodeGenerateStreamOutType.CODE_GENERATE.value
+                             })
+                         return ChatWithContinueResult(
+                             content=full_response,
+                             input_tokens_count=last_meta.input_tokens_count,
+                             generated_tokens_count=last_meta.generated_tokens_count
+                         )
+                     futures.append(executor.submit(job))
+                 else:
+                     futures.append(executor.submit(
+                         chat_with_continue,
+                         llm=llm,
+                         conversations=conversations,
+                         llm_config=llm_config,
+                         args=self.args
+                     ))

              temp_results = [future.result() for future in futures]
              for result in temp_results:
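The `if i==0` branch changes how the parallel candidates are produced: the first candidate is streamed to the terminal through `stream_out` while the remaining candidates run as plain blocking calls on the same executor. Here is a stripped-down sketch of that scheduling pattern, with the LLM calls replaced by dummy functions; it only illustrates the stream-first / run-the-rest-in-background idea, not the real API.

```python
import time
from concurrent.futures import ThreadPoolExecutor
from typing import List


def plain_generate(seed: int) -> str:
    """Stand-in for a non-streaming LLM call running in the background."""
    time.sleep(0.1)
    return f"candidate-{seed}"


def streamed_generate(seed: int) -> str:
    """Stand-in for the first candidate, whose chunks are printed as they arrive."""
    collected = []
    for i in range(3):
        chunk = f"candidate-{seed}-part{i} "
        print(chunk, end="")  # in auto-coder this rendering is done by stream_out(...)
        collected.append(chunk)
    print()
    return "".join(collected)


def generate_candidates(n: int) -> List[str]:
    futures = []
    with ThreadPoolExecutor(max_workers=n) as executor:
        for i in range(n):
            # Only the first candidate is streamed; the rest run silently in parallel.
            fn = streamed_generate if i == 0 else plain_generate
            futures.append(executor.submit(fn, i))
        return [f.result() for f in futures]


if __name__ == "__main__":
    print(generate_candidates(3))
```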
@@ -356,7 +422,12 @@ class CodeAutoGenerateStrictDiff:
                      conversations + [{"role": "assistant", "content": result}])
          else:
              for _ in range(self.args.human_model_num):
-                 single_result = chat_with_continue(llm=self.llms[0], conversations=conversations, llm_config=llm_config)
+                 single_result = chat_with_continue(
+                     llm=self.llms[0],
+                     conversations=conversations,
+                     llm_config=llm_config,
+                     args=self.args
+                 )
                  results.append(single_result.content)
                  input_tokens_count += single_result.input_tokens_count
                  generated_tokens_count += single_result.generated_tokens_count
@@ -404,9 +475,28 @@ class CodeAutoGenerateStrictDiff:
          result = []
          source_content = source_code_list.to_str()

+         # Gather package context information
+         package_context = ""
+
+         if self.args.enable_active_context:
+             # Initialize the active context manager
+             active_context_manager = ActiveContextManager(self.llm, self.args.source_dir)
+             # Load the active context information for the involved files
+             result = active_context_manager.load_active_contexts_for_files(
+                 [source.module_name for source in source_code_list.sources]
+             )
+             # Format the active context information as text
+             if result.contexts:
+                 package_context_parts = []
+                 for dir_path, context in result.contexts.items():
+                     package_context_parts.append(f"<package_info>{context.content}</package_info>")
+
+                 package_context = "\n".join(package_context_parts)
+
          if self.args.template == "common":
              init_prompt = self.multi_round_instruction.prompt(
-                 instruction=query, content=source_content, context=self.args.context
+                 instruction=query, content=source_content, context=self.args.context,
+                 package_context=package_context
              )
          elif self.args.template == "auto_implement":
              init_prompt = self.auto_implement_function.prompt(
@@ -11,6 +11,7 @@ from autocoder.common import files as FileUtils
  from autocoder.common.printer import Printer
  from autocoder.common.auto_coder_lang import get_message
  from autocoder.common.action_yml_file_manager import ActionYmlFileManager
+ from autocoder.memory.active_context_manager import ActiveContextManager

  class PathAndCode(pydantic.BaseModel):
      path: str
@@ -211,4 +212,11 @@ class CodeAutoMerge:
          if not update_yaml_success:
              self.printer.print_in_terminal("yaml_save_error", style="red", yaml_file=action_file_name)

+         if self.args.enable_active_context:
+             active_context_manager = ActiveContextManager(self.llm, self.args.source_dir)
+             task_id = active_context_manager.process_changes(self.args)
+             self.printer.print_in_terminal("active_context_background_task",
+                                            style="blue",
+                                            task_id=task_id)
+
          git_utils.print_commit_info(commit_result=commit_result)
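Each merge class now ends a successful commit by handing the change set to ActiveContextManager.process_changes, which runs in the background and immediately returns a task id that gets printed to the terminal. The sketch below is a simplified stand-in for that fire-and-forget pattern; BackgroundContextUpdater is an illustrative class, not the real ActiveContextManager.

```python
import threading
import uuid
from typing import Dict, List


class BackgroundContextUpdater:
    """Illustrative stand-in for the post-merge active-context background task."""

    def __init__(self) -> None:
        self.tasks: Dict[str, threading.Thread] = {}

    def process_changes(self, changed_files: List[str]) -> str:
        task_id = uuid.uuid4().hex[:8]
        worker = threading.Thread(
            target=self._summarize, args=(task_id, changed_files), daemon=True
        )
        self.tasks[task_id] = worker
        worker.start()  # the merge flow returns immediately with the task id
        return task_id

    def _summarize(self, task_id: str, changed_files: List[str]) -> None:
        # The real implementation regenerates per-directory context documents here.
        print(f"[{task_id}] updating active context for {len(changed_files)} files")


if __name__ == "__main__":
    updater = BackgroundContextUpdater()
    tid = updater.process_changes(["autocoder/common/git_utils.py"])
    print("background task:", tid)
    updater.tasks[tid].join()
```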
@@ -9,6 +9,7 @@ from autocoder.common.printer import Printer
  import hashlib
  from pathlib import Path
  from itertools import groupby
+ from autocoder.memory.active_context_manager import ActiveContextManager
  from autocoder.common.search_replace import (
      SearchTextNotUnique,
      all_preprocs,
@@ -593,6 +594,13 @@ class CodeAutoMergeDiff:
              update_yaml_success = action_yml_file_manager.update_yaml_field(action_file_name, "add_updated_urls", add_updated_urls)
              if not update_yaml_success:
                  self.printer.print_in_terminal("yaml_save_error", style="red", yaml_file=action_file_name)
+
+             if self.args.enable_active_context:
+                 active_context_manager = ActiveContextManager(self.llm, self.args.source_dir)
+                 task_id = active_context_manager.process_changes(self.args)
+                 self.printer.print_in_terminal("active_context_background_task",
+                                                style="blue",
+                                                task_id=task_id)

              git_utils.print_commit_info(commit_result=commit_result)
          else:
@@ -3,6 +3,7 @@ from byzerllm.utils.client import code_utils
  from autocoder.common import AutoCoderArgs, git_utils
  from autocoder.common.action_yml_file_manager import ActionYmlFileManager
  from autocoder.common.text import TextSimilarity
+ from autocoder.memory.active_context_manager import ActiveContextManager
  from autocoder.utils.queue_communicate import (
      queue_communicate,
      CommunicateEvent,
@@ -442,6 +443,12 @@ class CodeAutoMergeEditBlock:
              if not update_yaml_success:
                  self.printer.print_in_terminal("yaml_save_error", style="red", yaml_file=action_file_name)

+             if self.args.enable_active_context:
+                 active_context_manager = ActiveContextManager(self.llm, self.args.source_dir)
+                 task_id = active_context_manager.process_changes(self.args)
+                 self.printer.print_in_terminal("active_context_background_task",
+                                                style="blue",
+                                                task_id=task_id)
              git_utils.print_commit_info(commit_result=commit_result)
          except Exception as e:
              self.printer.print_str_in_terminal(
@@ -12,6 +12,7 @@ from pathlib import Path
  from autocoder.common.types import CodeGenerateResult, MergeCodeWithoutEffect
  from autocoder.common.code_modification_ranker import CodeModificationRanker
  from autocoder.common import files as FileUtils
+ from autocoder.memory.active_context_manager import ActiveContextManager

  class PathAndCode(pydantic.BaseModel):
      path: str
@@ -301,6 +302,10 @@ class CodeAutoMergeStrictDiff:
              if not update_yaml_success:
                  self.printer.print_in_terminal("yaml_save_error", style="red", yaml_file=action_file_name)

+             if self.args.enable_active_context:
+                 active_context_manager = ActiveContextManager(self.llm, self.args.source_dir)
+                 active_context_manager.process_changes(self.args)
+
              git_utils.print_commit_info(commit_result=commit_result)
          else:
              # Print diff blocks for review
@@ -13,6 +13,7 @@ from autocoder.utils.llms import get_llm_names, get_model_info
  from autocoder.common.types import CodeGenerateResult, MergeCodeWithoutEffect
  import os
  from autocoder.rag.token_counter import count_tokens
+ from autocoder.common.stream_out_type import CodeRankStreamOutType

  class RankResult(BaseModel):
      rank_result: List[int]
@@ -163,7 +164,8 @@ class CodeModificationRanker:
                          stream_chat_with_continue,
                          llm,
                          [{"role": "user", "content": query}],
-                         {}
+                         {},
+                         self.args
                      )
                  )
              else:
@@ -172,7 +174,8 @@ class CodeModificationRanker:
                          chat_with_continue,
                          llm,
                          [{"role": "user", "content": query}],
-                         {}
+                         {},
+                         self.args
                      )
                  )

@@ -209,7 +212,10 @@ class CodeModificationRanker:
                      model_name=model_name,
                      title=self.printer.get_message_from_key_with_format(
                          "rank_code_modification_title", model_name=model_name),
-                     args=self.args
+                     args=self.args,
+                     extra_meta={
+                         "stream_out_type": CodeRankStreamOutType.CODE_RANK.value
+                     }
                  )

                  if last_meta:
@@ -53,6 +53,9 @@ COMMANDS = {
      },
      "/shell": {
          "/chat": "",
+     },
+     "/active_context": {
+         "/list": ""
      }
  }

@@ -496,6 +499,15 @@ class CommandCompleter(Completer):
              if command.startswith(current_word):
                  yield Completion(command, start_position=-len(current_word))

+         elif words[0] == "/active_context":
+             new_text = text[len("/active_context"):]
+             parser = CommandTextParser(new_text, words[0])
+             parser.lib()
+             current_word = parser.current_word()
+             for command in parser.get_sub_commands():
+                 if command.startswith(current_word):
+                     yield Completion(command, start_position=-len(current_word))
+
          elif words[0] == "/conf":
              new_words = text[len("/conf"):].strip().split()
              is_at_space = text[-1] == " "
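Together, the two completer hunks register the new /active_context command (with a /list subcommand) and give it the same prefix-completion behaviour as the other top-level commands. A toy version of that lookup follows, using a plain dictionary instead of the real CommandTextParser; the complete helper is purely illustrative.

```python
COMMANDS = {
    "/shell": {"/chat": ""},
    "/active_context": {"/list": ""},
}


def complete(text: str) -> list:
    """Return completions for a partial input such as '/active_context /l'."""
    words = text.split()
    if not words:
        return sorted(COMMANDS)
    if words[0] not in COMMANDS:
        return [c for c in COMMANDS if c.startswith(words[0])]
    current = words[1] if len(words) > 1 else ""
    return [sub for sub in COMMANDS[words[0]] if sub.startswith(current)]


if __name__ == "__main__":
    print(complete("/act"))                 # ['/active_context']
    print(complete("/active_context /l"))   # ['/list']
```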
@@ -1,7 +1,7 @@
  import byzerllm
  from byzerllm.utils.client import code_utils
  from autocoder.utils.auto_coder_utils.chat_stream_out import stream_out
- from autocoder.common import detect_env
+ from autocoder.common import detect_env,AutoCoderArgs
  from autocoder.common import shells
  from autocoder.common.printer import Printer
  from typing import Dict,Union
@@ -57,9 +57,9 @@ def _generate_shell_script(user_input: str) -> str:
      }


- def generate_shell_script(user_input: str, llm: Union[byzerllm.ByzerLLM,byzerllm.SimpleByzerLLM]) -> str:
+ def generate_shell_script(args: AutoCoderArgs, llm: Union[byzerllm.ByzerLLM,byzerllm.SimpleByzerLLM]) -> str:
      # Fetch the prompt content
-     prompt = _generate_shell_script.prompt(user_input=user_input)
+     prompt = _generate_shell_script.prompt(user_input=args.query)
      if llm.get_sub_client("chat_model"):
          shell_llm = llm.get_sub_client("chat_model")
      else:
@@ -74,7 +74,8 @@ def generate_shell_script(user_input: str, llm: Union[byzerllm.ByzerLLM,byzerllm
      result, _ = stream_out(
          shell_llm.stream_chat_oai(conversations=conversations, delta_mode=True),
          model_name=llm.default_model_name,
-         title=title
+         title=title,
+         args=args
      )

      # Extract the code block
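generate_shell_script no longer receives a bare user_input string: it takes the whole AutoCoderArgs object, reads the request from args.query, and forwards args to stream_out so the renderer sees the same configuration. Below is a hedged sketch of what a call site looks like after the change; the Args dataclass is a minimal stand-in for AutoCoderArgs, and the body is reduced to the signature change being illustrated.

```python
from dataclasses import dataclass


@dataclass
class Args:
    """Minimal stand-in for AutoCoderArgs; only the fields used here."""
    query: str
    source_dir: str = "."


def generate_shell_script(args: Args, llm=None) -> str:
    # 0.1.305 accepted user_input: str; 0.1.307 accepts the args object and reads
    # args.query, so helpers further down (stream_out, printers) get the full config.
    prompt = f"Write a shell script that: {args.query}"
    return prompt  # the real function sends this prompt to the LLM and extracts a code block


if __name__ == "__main__":
    print(generate_shell_script(Args(query="list the ten largest files")))
```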
@@ -1,7 +1,8 @@
  import os
  from git import Repo, GitCommandError
+ import git
  from loguru import logger
- from typing import List, Optional, Dict
+ from typing import List, Optional, Dict, Any
  from pydantic import BaseModel
  import byzerllm
  from rich.console import Console
@@ -103,68 +104,84 @@ def get_current_branch(repo_path: str) -> str:
      return branch


- def revert_changes(repo_path: str, message: str) -> bool:
+ def revert_changes(repo_path: str, action_file_path: str) -> Optional[Any]:
+     '''
+     file_path is something like: auto_coder_000000002009_chat_action.yml or 000000002009_chat_action.yml
+     '''
      repo = get_repo(repo_path)
      if repo is None:
          logger.error("Repository is not initialized.")
          return False
-
+
+     commit_hash = None
+     # Walk the commits starting from the most recent one
+     for commit in repo.iter_commits():
+         if action_file_path in commit.message and not commit.message.startswith("<revert>"):
+             commit_hash = commit.hexsha
+             break
+
+     if commit_hash is None:
+         raise ValueError(f"File {action_file_path} not found in any commit")
+
+     # Try to resolve the specified commit
      try:
-         # Check whether the working directory has uncommitted changes
-         if repo.is_dirty():
-             logger.warning(
-                 "Working directory is dirty. please commit or stash your changes before reverting."
-             )
-             return False
-
-         # Locate the commit hash by its message
-         # --grep only searches the first line by default; -F treats the pattern as a fixed string rather than a regex
-         commit = repo.git.log("--all", f"--grep={message}", "-F", "--format=%H", "-n", "1")
-         if not commit:
-             logger.warning(f"No commit found with message: {message}")
-             return False
-
-         commit_hash = commit
-
-         # Get all commits from the specified commit up to HEAD
-         commits = list(repo.iter_commits(f"{commit_hash}..HEAD"))
-
-         if not commits:
-             repo.git.revert(commit, no_edit=True)
-             logger.info(f"Reverted single commit: {commit}")
+         commit = repo.commit(commit_hash)
+     except ValueError:
+         # If it is a short hash, try to match it
+         matching_commits = [c for c in repo.iter_commits() if c.hexsha.startswith(commit_hash)]
+         if not matching_commits:
+             raise ValueError(f"Commit {commit_hash} not found")
+         commit = matching_commits[0]
+
+     # Check that the working directory is clean
+     if repo.is_dirty():
+         raise ValueError("Working directory is dirty. please commit or stash your changes before reverting.")
+
+     try:
+         # Run git revert
+         # Use the -n option so no commit is created automatically; we commit manually instead
+         repo.git.revert(commit.hexsha, no_commit=True)
+
+         # Create a revert commit carrying the original commit's information
+         revert_message = f"<revert>{commit.message.strip()}\n{commit.hexsha}"
+         new_commit = repo.index.commit(
+             revert_message,
+             author=repo.active_branch.commit.author,
+             committer=repo.active_branch.commit.committer
+         )
+
+         # Build the information about the new commit
+         stats = new_commit.stats.total
+         new_commit_info = {
+             "new_commit_hash": new_commit.hexsha,
+             "new_commit_short_hash": new_commit.hexsha[:7],
+             "reverted_commit": {
+                 "hash": commit.hexsha,
+                 "short_hash": commit.hexsha[:7],
+                 "message": commit.message.strip()
+             },
+             "stats": {
+                 "insertions": stats["insertions"],
+                 "deletions": stats["deletions"],
+                 "files_changed": stats["files"]
+             }
+         }
+
+         return new_commit_info
+
+     except git.GitCommandError as e:
+         # If a Git command error occurs, try to restore the working directory
+         try:
+             repo.git.reset("--hard", "HEAD")
+         except:
+             pass  # If the restore fails, keep raising the original error
+
+         if "patch does not apply" in str(e):
+             raise Exception("Cannot revert: patch does not apply (likely due to conflicts)")
          else:
-             # Revert commits one by one, starting from the most recent
-             for commit in reversed(commits):
-                 try:
-                     repo.git.revert(commit.hexsha, no_commit=True)
-                     logger.info(f"Reverted changes from commit: {commit.hexsha}")
-                 except GitCommandError as e:
-                     logger.error(f"Error reverting commit {commit.hexsha}: {e}")
-                     repo.git.revert("--abort")
-                     return False
-
-             # Commit all of the revert changes
-             repo.git.commit(message=f"Reverted all changes up to {commit_hash}")
-
-             logger.info(f"Successfully reverted changes up to {commit_hash}")
-
-             ## this is a mark, chat_auto_coder.py need this
-             print(f"Successfully reverted changes", flush=True)
-
-             # # If there was a stash earlier, apply it now
-             # if stashed:
-             #     try:
-             #         repo.git.stash('pop')
-             #         logger.info("Applied stashed changes.")
-             #     except GitCommandError as e:
-             #         logger.error(f"Error applying stashed changes: {e}")
-             #         logger.info("Please manually apply the stashed changes.")
+             raise Exception(f"Git error during revert: {str(e)}")

-         return True
-
-     except GitCommandError as e:
-         logger.error(f"Error during revert operation: {e}")
-         return False
+     return None


  def revert_change(repo_path: str, message: str) -> bool:
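revert_changes now keys off the action YAML file name instead of a free-form commit message: it scans history for the newest non-revert commit that mentions the file, reverts it with `git revert -n`, commits with a `<revert>` prefix, and returns a dict of stats rather than a bool. The lookup convention is the part callers depend on most; here is a small, self-contained sketch of it over fake commit objects (FakeCommit and find_revert_target are illustrative names, not part of the package).

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class FakeCommit:
    hexsha: str
    message: str


def find_revert_target(commits: List[FakeCommit], action_file: str) -> Optional[str]:
    """Return the newest commit that mentions the action file and is not itself
    a revert commit (i.e. its message does not start with '<revert>')."""
    for commit in commits:  # newest first, like repo.iter_commits()
        if action_file in commit.message and not commit.message.startswith("<revert>"):
            return commit.hexsha
    return None


if __name__ == "__main__":
    history = [
        FakeCommit("c3", "<revert>auto_coder_000000002009_chat_action.yml\nc1"),
        FakeCommit("c2", "unrelated change"),
        FakeCommit("c1", "auto_coder_000000002009_chat_action.yml"),
    ]
    print(find_revert_target(history, "000000002009_chat_action.yml"))  # -> c1
```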
@@ -213,13 +230,18 @@ def get_uncommitted_changes(repo_path: str) -> str:
      # Handle unstaged changes
      for diff_item in diff_index:
          file_path = diff_item.a_path
-         diff_content = repo.git.diff(None, file_path)
-         if diff_item.new_file:
-             changes['new'].append((file_path, diff_content))
-         elif diff_item.deleted_file:
-             changes['deleted'].append((file_path, diff_content))
-         else:
-             changes['modified'].append((file_path, diff_content))
+         try:
+             diff_content = repo.git.diff(None, '--', file_path)
+             if diff_item.new_file:
+                 changes['new'].append((file_path, diff_content))
+             elif diff_item.deleted_file:
+                 changes['deleted'].append((file_path, diff_content))
+             else:
+                 changes['modified'].append((file_path, diff_content))
+         except GitCommandError as e:
+             logger.error(f"Error getting diff for file {file_path}: {e}")
+             # Move on to the next file instead of aborting the whole run
+             continue

      # Handle untracked files
      for file_path in untracked:
@@ -229,6 +251,7 @@ def get_uncommitted_changes(repo_path: str) -> str:
              changes['new'].append((file_path, f'+++ {file_path}\n{content}'))
          except Exception as e:
              logger.error(f"Error reading file {file_path}: {e}")
+             # Continue with the next file

      # Generate the markdown report
      report = ["# Git Changes Report\n"]
@@ -4,4 +4,11 @@ class AutoCommandStreamOutType(Enum):
      COMMAND_SUGGESTION = "command_suggestion"

  class IndexFilterStreamOutType(Enum):
-     FILE_NUMBER_LIST = "file_number_list"
+     FILE_NUMBER_LIST = "file_number_list"
+
+
+ class CodeGenerateStreamOutType(Enum):
+     CODE_GENERATE = "code_generate"
+
+ class CodeRankStreamOutType(Enum):
+     CODE_RANK = "code_rank"
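The two new enums let producers tag streamed chunks with their origin via extra_meta, so consumers can route generation output and ranking output differently. A small illustrative consumer is sketched below; route_chunk is a hypothetical helper, only the enum values come from the diff.

```python
from enum import Enum


class CodeGenerateStreamOutType(Enum):
    CODE_GENERATE = "code_generate"


class CodeRankStreamOutType(Enum):
    CODE_RANK = "code_rank"


def route_chunk(chunk: str, extra_meta: dict) -> str:
    """Prefix a streamed chunk based on the stream_out_type tag set by the producer."""
    kind = extra_meta.get("stream_out_type")
    if kind == CodeGenerateStreamOutType.CODE_GENERATE.value:
        return f"[generate] {chunk}"
    if kind == CodeRankStreamOutType.CODE_RANK.value:
        return f"[rank] {chunk}"
    return chunk


if __name__ == "__main__":
    meta = {"stream_out_type": CodeGenerateStreamOutType.CODE_GENERATE.value}
    print(route_chunk("def foo(): ...", meta))
```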
@@ -2,6 +2,8 @@ from byzerllm import ByzerLLM,SimpleByzerLLM
  from typing import Generator, List, Any, Union, Optional, Callable
  from pydantic import BaseModel
  from loguru import logger
+ from autocoder.common import AutoCoderArgs
+ from autocoder.common.auto_coder_lang import get_message_with_format

  class ChatWithContinueResult(BaseModel):
      content: str
@@ -9,7 +11,12 @@ class ChatWithContinueResult(BaseModel):
      generated_tokens_count: int


- def chat_with_continue(llm: Union[ByzerLLM,SimpleByzerLLM], conversations: List[dict], llm_config: dict) -> ChatWithContinueResult:
+ def chat_with_continue(
+     llm: Union[ByzerLLM,SimpleByzerLLM],
+     conversations: List[dict],
+     llm_config: dict,
+     args: AutoCoderArgs
+ ) -> ChatWithContinueResult:
      final_result = ChatWithContinueResult(content="", input_tokens_count=0, generated_tokens_count=0)
      v = llm.chat_oai(
          conversations=conversations, llm_config=llm_config)
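chat_with_continue now takes the AutoCoderArgs object so the continuation loop can respect args.generate_max_rounds, and every call site in this release passes args through. The sketch below reproduces just that loop against a fake LLM to show how the cap interacts with the "length" finish reason; Args and fake_llm are stand-ins, not the real API.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass
class Args:
    """Stand-in for AutoCoderArgs; only generate_max_rounds matters here."""
    generate_max_rounds: int = 5


def fake_llm(conversations: List[Dict]) -> Tuple[str, str]:
    """Return (text, finish_reason); pretend the output is truncated twice."""
    rounds = sum(1 for m in conversations if m["role"] == "assistant")
    return f"part{rounds} ", "length" if rounds < 2 else "stop"


def chat_with_continue(conversations: List[Dict], args: Args) -> str:
    content, count = "", 0
    convo = list(conversations)
    while True:
        text, finish_reason = fake_llm(convo)
        content += text
        convo.append({"role": "assistant", "content": text})
        count += 1
        if finish_reason != "length" or count >= args.generate_max_rounds:
            if count >= args.generate_max_rounds:
                print(f"warning: hit generate_max_rounds={args.generate_max_rounds}")
            break
        convo.append({"role": "user", "content": "continue"})
    return content


if __name__ == "__main__":
    print(chat_with_continue([{"role": "user", "content": "write code"}], Args()))
```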
@@ -32,6 +39,15 @@ def chat_with_continue(llm: Union[ByzerLLM,SimpleByzerLLM], conversations: List[
          final_result.input_tokens_count += metadata.get("input_tokens_count", 0)
          final_result.generated_tokens_count += metadata.get("generated_tokens_count", 0)
          count += 1
+
+         if count >= args.generate_max_rounds:
+             warning_message = get_message_with_format(
+                 "generate_max_rounds_reached",
+                 count=count,
+                 max_rounds=args.generate_max_rounds,
+                 generated_tokens=final_result.generated_tokens_count
+             )
+             logger.warning(warning_message)

      # if count >= 2:
      #     logger.info(f"The code generation is exceed the max length, continue to generate the code {count -1 } times")
@@ -41,7 +57,8 @@ def chat_with_continue(llm: Union[ByzerLLM,SimpleByzerLLM], conversations: List[
  def stream_chat_with_continue(
      llm: Union[ByzerLLM, SimpleByzerLLM],
      conversations: List[dict],
-     llm_config: dict
+     llm_config: dict,
+     args: AutoCoderArgs
  ) -> Generator[Any, None, None]:
      """
      Stream the output and keep generating continuations until the content is complete.
@@ -87,7 +104,16 @@ def stream_chat_with_continue(
          temp_conversations.append({"role": "assistant", "content": current_content})

          # Check whether generation needs to continue
-         if current_metadata.finish_reason != "length" or count >= 5:
+         if current_metadata.finish_reason != "length" or count >= args.generate_max_rounds:
+             if count >= args.generate_max_rounds:
+                 warning_message = get_message_with_format(
+                     "generate_max_rounds_reached",
+                     count=count,
+                     max_rounds=args.generate_max_rounds,
+                     generated_tokens=current_metadata.generated_tokens_count
+                 )
+                 logger.warning(warning_message)
              break

+
          count += 1
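The streaming variant applies the same configurable cap: instead of a hard-coded five continuation rounds, it stops once args.generate_max_rounds is reached and logs a warning. Below is a self-contained generator sketch of that stopping rule, with the model replaced by a fake stream; Args and fake_stream are stand-ins, not the package's API.

```python
from dataclasses import dataclass
from typing import Iterator, Tuple


@dataclass
class Args:
    """Stand-in for AutoCoderArgs; only generate_max_rounds is used."""
    generate_max_rounds: int = 3


def fake_stream(round_no: int) -> Iterator[Tuple[str, str]]:
    """Yield (chunk, finish_reason); every round ends truncated ('length')."""
    for i in range(2):
        yield f"r{round_no}c{i} ", "length"


def stream_chat_with_continue(args: Args) -> Iterator[str]:
    count = 0
    while True:
        finish_reason = "stop"
        for chunk, finish_reason in fake_stream(count):
            yield chunk  # the caller renders chunks as they arrive
        count += 1
        if finish_reason != "length" or count >= args.generate_max_rounds:
            if count >= args.generate_max_rounds:
                print(f"\nwarning: stopped after {count} rounds "
                      f"(generate_max_rounds={args.generate_max_rounds})")
            break


if __name__ == "__main__":
    print("".join(stream_chat_with_continue(Args())))
```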