autocoder-nano 0.1.30__py3-none-any.whl → 0.1.33__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34):
  1. autocoder_nano/agent/agent_base.py +4 -4
  2. autocoder_nano/agent/agentic_edit.py +1584 -0
  3. autocoder_nano/agent/agentic_edit_tools/__init__.py +28 -0
  4. autocoder_nano/agent/agentic_edit_tools/ask_followup_question_tool.py +51 -0
  5. autocoder_nano/agent/agentic_edit_tools/attempt_completion_tool.py +36 -0
  6. autocoder_nano/agent/agentic_edit_tools/base_tool_resolver.py +31 -0
  7. autocoder_nano/agent/agentic_edit_tools/execute_command_tool.py +65 -0
  8. autocoder_nano/agent/agentic_edit_tools/list_code_definition_names_tool.py +78 -0
  9. autocoder_nano/agent/agentic_edit_tools/list_files_tool.py +123 -0
  10. autocoder_nano/agent/agentic_edit_tools/list_package_info_tool.py +42 -0
  11. autocoder_nano/agent/agentic_edit_tools/plan_mode_respond_tool.py +35 -0
  12. autocoder_nano/agent/agentic_edit_tools/read_file_tool.py +73 -0
  13. autocoder_nano/agent/agentic_edit_tools/replace_in_file_tool.py +148 -0
  14. autocoder_nano/agent/agentic_edit_tools/search_files_tool.py +135 -0
  15. autocoder_nano/agent/agentic_edit_tools/write_to_file_tool.py +57 -0
  16. autocoder_nano/agent/agentic_edit_types.py +151 -0
  17. autocoder_nano/auto_coder_nano.py +145 -91
  18. autocoder_nano/git_utils.py +63 -1
  19. autocoder_nano/llm_client.py +170 -3
  20. autocoder_nano/llm_types.py +53 -14
  21. autocoder_nano/rules/rules_learn.py +221 -0
  22. autocoder_nano/templates.py +1 -1
  23. autocoder_nano/utils/formatted_log_utils.py +128 -0
  24. autocoder_nano/utils/printer_utils.py +5 -4
  25. autocoder_nano/utils/shell_utils.py +85 -0
  26. autocoder_nano/version.py +1 -1
  27. {autocoder_nano-0.1.30.dist-info → autocoder_nano-0.1.33.dist-info}/METADATA +3 -2
  28. {autocoder_nano-0.1.30.dist-info → autocoder_nano-0.1.33.dist-info}/RECORD +33 -16
  29. autocoder_nano/agent/new/auto_new_project.py +0 -278
  30. /autocoder_nano/{agent/new → rules}/__init__.py +0 -0
  31. {autocoder_nano-0.1.30.dist-info → autocoder_nano-0.1.33.dist-info}/LICENSE +0 -0
  32. {autocoder_nano-0.1.30.dist-info → autocoder_nano-0.1.33.dist-info}/WHEEL +0 -0
  33. {autocoder_nano-0.1.30.dist-info → autocoder_nano-0.1.33.dist-info}/entry_points.txt +0 -0
  34. {autocoder_nano-0.1.30.dist-info → autocoder_nano-0.1.33.dist-info}/top_level.txt +0 -0
@@ -9,12 +9,15 @@ import textwrap
9
9
  import time
10
10
  import uuid
11
11
 
12
+ from autocoder_nano.agent.agentic_edit import AgenticEdit
13
+ from autocoder_nano.agent.agentic_edit_types import AgenticEditRequest
12
14
  from autocoder_nano.edit import Dispacher
13
15
  from autocoder_nano.helper import show_help
14
16
  from autocoder_nano.index.entry import build_index_and_filter_files
15
17
  from autocoder_nano.index.index_manager import IndexManager
16
18
  from autocoder_nano.index.symbols_utils import extract_symbols
17
19
  from autocoder_nano.llm_client import AutoLLM
20
+ from autocoder_nano.rules.rules_learn import AutoRulesLearn
18
21
  from autocoder_nano.version import __version__
19
22
  from autocoder_nano.llm_types import *
20
23
  from autocoder_nano.llm_prompt import prompt, extract_code
@@ -55,7 +58,8 @@ base_persist_dir = os.path.join(project_root, ".auto-coder", "plugins", "chat-au
55
58
  # ".vscode", ".idea", ".hg"]
56
59
  commands = [
57
60
  "/add_files", "/remove_files", "/list_files", "/conf", "/coding", "/chat", "/revert", "/index/query",
58
- "/index/build", "/exclude_dirs", "/exclude_files", "/help", "/shell", "/exit", "/mode", "/models", "/commit", "/new"
61
+ "/index/build", "/exclude_dirs", "/exclude_files", "/help", "/shell", "/exit", "/mode", "/models", "/commit",
62
+ "/rules", "/auto"
59
63
  ]
60
64
 
61
65
  memory = {
@@ -185,12 +189,7 @@ COMMANDS = {
185
189
  "/remove_files": {"/all": ""},
186
190
  "/coding": {"/apply": ""},
187
191
  "/chat": {"/history": "", "/new": "", "/review": ""},
188
- "/models": {
189
- "/add_model": "",
190
- "/remove": "",
191
- "/list": "",
192
- "/check": ""
193
- },
192
+ "/models": {"/add_model": "", "/remove": "", "/list": "", "/check": ""},
194
193
  "/help": {
195
194
  "/add_files": "",
196
195
  "/remove_files": "",
@@ -202,7 +201,8 @@ COMMANDS = {
202
201
  "/models": ""
203
202
  },
204
203
  "/exclude_files": {"/list": "", "/drop": ""},
205
- "/exclude_dirs": {}
204
+ "/exclude_dirs": {},
205
+ "/rules": {"/list": "", "/show": "", "/remove": "", "/analyze": "", "/commit": ""}
206
206
  }
207
207
 
208
208
 
@@ -707,6 +707,15 @@ class CommandCompleter(Completer):
707
707
  if command.startswith(current_word):
708
708
  yield Completion(command, start_position=-len(current_word))
709
709
 
710
+ elif words[0] == "/rules":
711
+ new_text = text[len("/rules"):]
712
+ parser = CommandTextParser(new_text, words[0])
713
+ parser.add_files()
714
+ current_word = parser.current_word()
715
+ for command in parser.get_sub_commands():
716
+ if command.startswith(current_word):
717
+ yield Completion(command, start_position=-len(current_word))
718
+
710
719
  elif words[0] == "/conf":
711
720
  new_words = text[len("/conf"):].strip().split()
712
721
  is_at_space = text[-1] == " "
@@ -1259,6 +1268,7 @@ def init_project():
1259
1268
  return
1260
1269
  os.makedirs(os.path.join(args.source_dir, "actions"), exist_ok=True)
1261
1270
  os.makedirs(os.path.join(args.source_dir, ".auto-coder"), exist_ok=True)
1271
+ os.makedirs(os.path.join(args.source_dir, ".auto-coder", "autocoderrules"), exist_ok=True)
1262
1272
  source_dir = os.path.abspath(args.source_dir)
1263
1273
  create_actions(
1264
1274
  source_dir=source_dir,
@@ -1307,7 +1317,7 @@ def load_include_files(config, base_path, max_depth=10, current_depth=0):
1307
1317
 
1308
1318
  for include_file in include_files:
1309
1319
  abs_include_path = resolve_include_path(base_path, include_file)
1310
- printer.print_text(f"正在加载 Include file: {abs_include_path}", style="green")
1320
+ # printer.print_text(f"正在加载 Include file: {abs_include_path}", style="green")
1311
1321
  with open(abs_include_path, "r") as f:
1312
1322
  include_config = yaml.safe_load(f)
1313
1323
  if not include_config:
@@ -1369,14 +1379,9 @@ def coding(query: str, llm: AutoLLM):
1369
1379
 
1370
1380
  memory["conversation"].append({"role": "user", "content": query})
1371
1381
  conf = memory.get("conf", {})
1372
-
1373
1382
  current_files = memory["current_files"]["files"]
1374
- current_groups = memory["current_files"].get("current_groups", [])
1375
- groups = memory["current_files"].get("groups", {})
1376
- groups_info = memory["current_files"].get("groups_info", {})
1377
1383
 
1378
1384
  prepare_chat_yaml() # 复制上一个序号的 yaml 文件, 生成一个新的聊天 yaml 文件
1379
-
1380
1385
  latest_yaml_file = get_last_yaml_file(os.path.join(args.source_dir, "actions"))
1381
1386
 
1382
1387
  if latest_yaml_file:
@@ -1398,19 +1403,6 @@ def coding(query: str, llm: AutoLLM):
1398
1403
  yaml_config["urls"] = current_files
1399
1404
  yaml_config["query"] = query
1400
1405
 
1401
- if current_groups:
1402
- active_groups_context = "下面是对上面文件按分组给到的一些描述,当用户的需求正好匹配描述的时候,参考描述来做修改:\n"
1403
- for group in current_groups:
1404
- group_files = groups.get(group, [])
1405
- query_prefix = groups_info.get(group, {}).get("query_prefix", "")
1406
- active_groups_context += f"组名: {group}\n"
1407
- active_groups_context += f"文件列表:\n"
1408
- for file in group_files:
1409
- active_groups_context += f"- {file}\n"
1410
- active_groups_context += f"组描述: {query_prefix}\n\n"
1411
-
1412
- yaml_config["context"] = active_groups_context + "\n"
1413
-
1414
1406
  if is_apply:
1415
1407
  memory_dir = os.path.join(args.source_dir, ".auto-coder", "memory")
1416
1408
  os.makedirs(memory_dir, exist_ok=True)
@@ -1441,6 +1433,19 @@ def coding(query: str, llm: AutoLLM):
1441
1433
  yaml_config["context"] += f"你: {conv['content']}\n"
1442
1434
  yaml_config["context"] += "</history>\n"
1443
1435
 
1436
+ if args.enable_rules:
1437
+ rules_dir_path = os.path.join(project_root, ".auto-coder", "autocoderrules")
1438
+ printer.print_text("已开启 Rules 模式", style="green")
1439
+ yaml_config["context"] += f"下面是我们对代码进行深入分析,提取具有通用价值的功能模式和设计模式,可在其他需求中复用的Rules\n"
1440
+ yaml_config["context"] += "你在编写代码时可以参考以下Rules\n"
1441
+ yaml_config["context"] += "<rules>\n"
1442
+ for rules_name in os.listdir(rules_dir_path):
1443
+ printer.print_text(f"正在加载 Rules:{rules_name}", style="green")
1444
+ rules_file_path = os.path.join(rules_dir_path, rules_name)
1445
+ with open(rules_file_path, "r") as fp:
1446
+ yaml_config["context"] += f"{fp.read()}\n"
1447
+ yaml_config["context"] += "</rules>\n"
1448
+
1444
1449
  yaml_config["file"] = latest_yaml_file
1445
1450
  yaml_content = convert_yaml_config_to_str(yaml_config=yaml_config)
1446
1451
  execute_file = os.path.join(args.source_dir, "actions", latest_yaml_file)
@@ -1572,6 +1577,20 @@ def commit_info(query: str, llm: AutoLLM):
1572
1577
  os.remove(execute_file)
1573
1578
 
1574
1579
 
1580
def agentic_edit(query: str, llm: AutoLLM):
    """Handle a single ``/auto`` request with the agentic editor.

    The session configuration is first synced via ``update_config_to_args``.
    An ``AgenticEdit`` is then built with an empty source list and no prior
    conversation history. Finally the stripped query is wrapped in an
    ``AgenticEditRequest`` and passed to ``run_in_terminal``.
    """
    update_config_to_args(query=query, delete_execute_file=True)

    editor = AgenticEdit(
        args=args,
        llm=llm,
        files=SourceCodeList([]),
        history_conversation=[],
    )
    editor.run_in_terminal(AgenticEditRequest(user_input=query.strip()))
1592
+
1593
+
1575
1594
  @prompt()
1576
1595
  def _generate_shell_script(user_input: str) -> str:
1577
1596
  """
@@ -2224,71 +2243,100 @@ def configure_project_model():
2224
2243
  )
2225
2244
 
2226
2245
 
2227
- # def new_project(query, llm):
2228
- # console.print(f"正在基于你的需求 {query} 构建项目 ...", style="bold green")
2229
- # env_info = detect_env()
2230
- # project = BuildNewProject(args=args, llm=llm,
2231
- # chat_model=memory["conf"]["chat_model"],
2232
- # code_model=memory["conf"]["code_model"])
2233
- #
2234
- # console.print(f"正在完善项目需求 ...", style="bold green")
2235
- #
2236
- # information = project.build_project_information(query, env_info, args.project_type)
2237
- # if not information:
2238
- # raise Exception(f"项目需求未正常生成 .")
2239
- #
2240
- # table = Table(title=f"{query}")
2241
- # table.add_column("需求说明", style="cyan")
2242
- # table.add_row(f"{information[:50]}...")
2243
- # console.print(table)
2244
- #
2245
- # console.print(f"正在完善项目架构 ...", style="bold green")
2246
- # architecture = project.build_project_architecture(query, env_info, args.project_type, information)
2247
- #
2248
- # console.print(f"正在构建项目索引 ...", style="bold green")
2249
- # index_file_list = project.build_project_index(query, env_info, args.project_type, information, architecture)
2250
- #
2251
- # table = Table(title=f"索引列表")
2252
- # table.add_column("路径", style="cyan")
2253
- # table.add_column("用途", style="cyan")
2254
- # for index_file in index_file_list.file_list:
2255
- # table.add_row(index_file.file_path, index_file.purpose)
2256
- # console.print(table)
2257
- #
2258
- # for index_file in index_file_list.file_list:
2259
- # full_path = os.path.join(args.source_dir, index_file.file_path)
2260
- #
2261
- # # 获取目录路径
2262
- # full_dir_path = os.path.dirname(full_path)
2263
- # if not os.path.exists(full_dir_path):
2264
- # os.makedirs(full_dir_path)
2265
- #
2266
- # console.print(f"正在编码: {full_path} ...", style="bold green")
2267
- # code = project.build_single_code(query, env_info, args.project_type, information, architecture, index_file)
2268
- #
2269
- # with open(full_path, "w") as fp:
2270
- # fp.write(code)
2271
- #
2272
- # # 生成 readme
2273
- # readme_context = information + architecture
2274
- # readme_path = os.path.join(args.source_dir, "README.md")
2275
- # with open(readme_path, "w") as fp:
2276
- # fp.write(readme_context)
2277
- #
2278
- # console.print(f"项目构建完成", style="bold green")
2246
def _rules_file_path(rules_dir_path: str, rules_name: str):
    """Return the path of rule file ``rules_name`` inside ``rules_dir_path``.

    Returns None when ``rules_name`` is not a plain file name (empty, ``.``,
    ``..``, or containing a path component). User input therefore can never
    address a file outside the rules directory.
    """
    if not rules_name or rules_name in (".", "..") or os.path.basename(rules_name) != rules_name:
        return None
    return os.path.join(rules_dir_path, rules_name)


def rules(query_args: List[str], llm: AutoLLM):
    """Handle the ``/rules`` command family.

    Sub-commands:
        /rules /list                - list the rule files
        /rules /show <name>         - print a rule file's content
        /rules /remove <name>       - delete a rule file
        /rules /analyze             - derive rules from the current active files
        /rules /commit <commit_id>  - derive rules from a specific commit

    Args:
        query_args: whitespace-split tokens following ``/rules``. The first is
            the sub-command; the second is its argument where one is required.
        llm: client handed to ``AutoRulesLearn`` for /analyze and /commit.
    """
    update_config_to_args(query="", delete_execute_file=True)
    rules_dir_path = os.path.join(project_root, ".auto-coder", "autocoderrules")
    # Normally created by init_project; ensure it exists so listing/writing
    # below never fails on a missing directory.
    os.makedirs(rules_dir_path, exist_ok=True)

    sub_command = query_args[0] if query_args else ""
    # Reject a missing argument up front instead of raising IndexError.
    if sub_command in ("/show", "/remove", "/commit") and len(query_args) < 2:
        printer.print_text(f"/rules {sub_command} 缺少参数", style="yellow")
        return

    if sub_command == "/list":
        rules_names = sorted(
            name for name in os.listdir(rules_dir_path)
            if os.path.isfile(os.path.join(rules_dir_path, name))
        )
        printer.print_table_compact(
            data=[[name] for name in rules_names],
            title="Rules 列表",
            headers=["Rules 文件"],
            center=True
        )

    elif sub_command in ("/remove", "/show"):
        rules_name = query_args[1].strip()
        rules_path = _rules_file_path(rules_dir_path, rules_name)
        if rules_path is None:
            printer.print_text(f"非法的 Rules 文件名[{rules_name}]", style="yellow")
        elif not os.path.isfile(rules_path):
            printer.print_text(f"Rules 文件[{rules_name}]不存在", style="yellow")
        elif sub_command == "/remove":
            os.remove(rules_path)
            printer.print_text(f"Rules 文件[{rules_name}]移除成功", style="green")
        else:
            # Rule files are written as UTF-8; read them back the same way.
            with open(rules_path, "r", encoding="utf-8") as fp:
                printer.print_markdown(text=fp.read(), panel=True)

    elif sub_command == "/commit":
        commit_id = query_args[1].strip()
        auto_learn = AutoRulesLearn(llm=llm, args=args)
        try:
            result = auto_learn.analyze_commit_changes(commit_id=commit_id, conversations=[])
            rules_file = os.path.join(rules_dir_path, f"rules-commit-{uuid.uuid4()}.md")
            with open(rules_file, "w", encoding="utf-8") as f:
                f.write(result)
            printer.print_text(f"代码变更[{commit_id}]生成 Rules 成功", style="green")
        except Exception as e:  # command boundary: report LLM/git failures, keep the REPL alive
            printer.print_text(f"代码变更[{commit_id}]生成 Rules 失败: {e}", style="red")

    elif sub_command == "/analyze":
        files = memory.get("current_files", {}).get("files", [])
        if not files:
            printer.print_text("当前无活跃文件用于生成 Rules", style="yellow")
            return

        sources = SourceCodeList([])
        for file in files:
            try:
                with open(file, "r", encoding="utf-8") as f:
                    sources.sources.append(SourceCode(module_name=file, source_code=f.read()))
            except (OSError, UnicodeDecodeError) as e:
                printer.print_text(f"读取文件生成 Rules 失败: {e}", style="yellow")
        if not sources.sources:
            # Nothing readable: do not spend an LLM call on empty input.
            printer.print_text("当前活跃文件均无法读取, 无法生成 Rules", style="yellow")
            return

        auto_learn = AutoRulesLearn(llm=llm, args=args)
        try:
            result = auto_learn.analyze_modules(sources=sources, conversations=[])
            rules_file = os.path.join(rules_dir_path, f"rules-modules-{uuid.uuid4()}.md")
            with open(rules_file, "w", encoding="utf-8") as f:
                f.write(result)
            printer.print_text(f"活跃文件[Files:{len(sources.sources)}]生成 Rules 成功", style="green")
        except Exception as e:  # command boundary: report LLM failures, keep the REPL alive
            printer.print_text(f"活跃文件生成 Rules 失败: {e}", style="red")

    else:
        printer.print_text(f"未知的 /rules 子命令: {sub_command}", style="yellow")

    completer.refresh_files()
2279
2324
 
2280
2325
 
2281
2326
  def is_old_version():
2282
- """
2283
- __version__ = "0.1.26" 开始使用兼容 AutoCoder 的 chat_model, code_model 参数
2284
- 不再使用 current_chat_model 和 current_chat_model
2285
- """
2327
+ # "0.1.26" 开始使用兼容 AutoCoder 的 chat_model, code_model 参数
2328
+ # 不再使用 current_chat_model current_chat_model
2286
2329
  if 'current_chat_model' in memory['conf'] and 'current_code_model' in memory['conf']:
2287
- printer.print_text(f"您当前使用的版本偏低 {__version__}, 正在进行配置兼容性处理", style="yellow")
2330
+ printer.print_text(f"0.1.26 新增 chat_model, code_model 参数, 正在进行配置兼容性处理", style="yellow")
2288
2331
  memory['conf']['chat_model'] = memory['conf']['current_chat_model']
2289
2332
  memory['conf']['code_model'] = memory['conf']['current_code_model']
2290
2333
  del memory['conf']['current_chat_model']
2291
2334
  del memory['conf']['current_code_model']
2335
+ # "0.1.31" 在 .auto-coder 目录中新增 autocoderrules 目录
2336
+ rules_dir_path = os.path.join(project_root, ".auto-coder", "autocoderrules")
2337
+ if not os.path.exists(rules_dir_path):
2338
+ printer.print_text(f"0.1.31 .auto-coder 目录中新增 autocoderrules 目录, 正在进行配置兼容性处理", style="yellow")
2339
+ os.makedirs(rules_dir_path, exist_ok=True)
2292
2340
 
2293
2341
 
2294
2342
  def main():
@@ -2453,6 +2501,12 @@ def main():
2453
2501
  elif user_input.startswith("/commit"):
2454
2502
  query = user_input[len("/commit"):].strip()
2455
2503
  commit_info(query, auto_llm)
2504
+ elif user_input.startswith("/rules"):
2505
+ query_args = user_input[len("/rules"):].strip().split()
2506
+ if not query_args:
2507
+ printer.print_text("Please enter your request.", style="yellow")
2508
+ continue
2509
+ rules(query_args=query_args, llm=auto_llm)
2456
2510
  elif user_input.startswith("/help"):
2457
2511
  query = user_input[len("/help"):].strip()
2458
2512
  show_help(query)
@@ -2461,15 +2515,15 @@ def main():
2461
2515
  elif user_input.startswith("/coding"):
2462
2516
  query = user_input[len("/coding"):].strip()
2463
2517
  if not query:
2464
- print("\033[91mPlease enter your request.\033[0m")
2518
+ printer.print_text("Please enter your request.", style="yellow")
2465
2519
  continue
2466
2520
  coding(query=query, llm=auto_llm)
2467
- # elif user_input.startswith("/new"):
2468
- # query = user_input[len("/new"):].strip()
2469
- # if not query:
2470
- # print("\033[91mPlease enter your request.\033[0m")
2471
- # continue
2472
- # new_project(query=query, llm=auto_llm)
2521
+ elif user_input.startswith("/auto"):
2522
+ query = user_input[len("/auto"):].strip()
2523
+ if not query:
2524
+ print("\033[91mPlease enter your request.\033[0m")
2525
+ continue
2526
+ agentic_edit(query=query, llm=auto_llm)
2473
2527
  elif user_input.startswith("/chat"):
2474
2528
  query = user_input[len("/chat"):].strip()
2475
2529
  if not query:
@@ -1,4 +1,5 @@
1
1
  import os
2
+ from typing import Tuple, List, Dict, Optional
2
3
 
3
4
  from autocoder_nano.llm_prompt import prompt
4
5
  from git import Repo, GitCommandError
@@ -540,4 +541,65 @@ def generate_commit_message(changes_report: str) -> str:
540
541
  {{ changes_report }}
541
542
 
542
543
  请输出commit message, 不要输出任何其他内容.
543
- '''
544
+ '''
545
+
546
+
547
def _show_file_at(repo, revision: str, file_path: str):
    """Return ``file_path``'s content at ``revision`` via ``git show``, or None if git cannot produce it."""
    try:
        return repo.git.show(f"{revision}:{file_path}")
    except GitCommandError:
        return None


def get_commit_changes(
        repo_path: str, commit_id: str
) -> Tuple[List[Tuple[str, List[str], Dict[str, Tuple[Optional[str], Optional[str]]]]], Optional[str]]:
    """Collect the file-level changes introduced by one commit.

    Args:
        repo_path: path passed to ``get_repo``.
        commit_id: any revision accepted by ``repo.commit``.

    Returns:
        ``(changes_list, error)``.

        On success, ``changes_list`` holds exactly one tuple
        ``(commit_message, modified_files, changes)``. ``changes`` maps each
        file path to ``(before_content, after_content)``; a side is None when
        the file does not exist there (added/deleted). ``error`` is None.

        On failure, the error is printed, ``changes_list`` is empty and
        ``error`` is a short description.
    """
    changes_with_query = []
    try:
        repo = get_repo(repo_path)
        commit = repo.commit(commit_id)
        modified_files: List[str] = []
        changes: Dict[str, Tuple[Optional[str], Optional[str]]] = {}

        if not commit.parents:
            # Root commit: every blob in the tree is new, nothing existed before.
            for item in commit.tree.traverse():
                if item.type == "blob":
                    modified_files.append(item.path)
                    changes[item.path] = (None, repo.git.show(f"{commit.hexsha}:{item.path}"))
        else:
            parent = commit.parents[0]
            for diff_item in parent.diff(commit):
                file_path = diff_item.a_path if diff_item.a_path else diff_item.b_path
                modified_files.append(file_path)
                # Read each side from its own path. For renames a_path and
                # b_path differ, and looking up the new content under the old
                # path would fail.
                before_content = (
                    _show_file_at(repo, parent.hexsha, diff_item.a_path or file_path)
                    if diff_item.a_blob else None
                )
                after_content = (
                    _show_file_at(repo, commit.hexsha, diff_item.b_path or file_path)
                    if diff_item.b_blob else None
                )
                changes[file_path] = (before_content, after_content)

        # The commit message serves as the "query" describing the change.
        changes_with_query.append((commit.message, modified_files, changes))
    except GitCommandError as e:
        printer.print_text(f"git_command_error: {e}.", style="red")
        return changes_with_query, f"git_command_error: {e}"
    except Exception as e:
        printer.print_text(f"get_commit_changes_error: {e}.", style="red")
        return changes_with_query, f"get_commit_changes_error: {e}"

    return changes_with_query, None
@@ -1,10 +1,10 @@
1
- from typing import List
1
+ from typing import List, Generator, Any, Optional, Dict, Union
2
2
 
3
3
  # from loguru import logger
4
4
  from openai import OpenAI, Stream
5
5
  from openai.types.chat import ChatCompletionChunk, ChatCompletion
6
6
 
7
- from autocoder_nano.llm_types import LLMRequest, LLMResponse
7
+ from autocoder_nano.llm_types import LLMRequest, LLMResponse, AutoCoderArgs, SingleOutputMeta
8
8
  from autocoder_nano.utils.printer_utils import Printer
9
9
 
10
10
 
@@ -53,6 +53,126 @@ class AutoLLM:
53
53
  res = self._query(model, request, stream=True)
54
54
  return res
55
55
 
56
def stream_chat_ai_ex(
        self, conversations, model: Optional[str] = None, role_mapping=None, delta_mode: bool = False,
        is_reasoning: bool = False, llm_config: Optional[dict] = None
):
    """Stream a chat completion, yielding ``(text, SingleOutputMeta)`` tuples.

    Args:
        conversations: message list sent to the completion endpoint.
        model: key into ``self.sub_clients``; defaults to ``self.default_model_name``.
        role_mapping: accepted for interface compatibility; not used.
        delta_mode: if True, each yielded text (and ``reasoning_content``) is
            only that chunk's fragment. Otherwise both are the full text
            accumulated so far.
        is_reasoning: if True, no sampling parameters are sent explicitly
            (only what ``llm_config`` contains) and ``HTTP-Referer``/``X-Title``
            headers are added.
        llm_config: extra keyword arguments for ``client.chat.completions.create``.
            They are applied last, so e.g. ``temperature``/``max_tokens``/``top_p``
            override the request defaults.

    Yields:
        One tuple per streamed chunk. The trailing usage-only chunk (no
        ``choices``) is yielded with the token counts and the previous chunk's
        ``finish_reason``.
    """
    llm_config = dict(llm_config) if llm_config else {}
    if not model:
        model = self.default_model_name

    client: OpenAI = self.sub_clients[model]["client"]
    model_name = self.sub_clients[model]["model_name"]

    request = LLMRequest(
        model=model_name,
        messages=conversations,
        stream=True
    )

    if is_reasoning:
        create_kwargs = {
            "messages": request.messages,
            "model": request.model,
            "extra_headers": {
                "HTTP-Referer": "https://auto-coder.chat",
                "X-Title": "auto-coder-nano"
            },
        }
    else:
        create_kwargs = {
            "messages": conversations,
            "model": model_name,
            "temperature": request.temperature,
            "max_tokens": request.max_tokens,
            "top_p": request.top_p,
        }
    create_kwargs["stream"] = request.stream
    create_kwargs["stream_options"] = {"include_usage": True}
    # Merge instead of passing both explicitly and via **llm_config: the old
    # form raised "got multiple values for keyword argument" whenever
    # llm_config contained one of the explicit keys.
    create_kwargs.update(llm_config)
    response = client.chat.completions.create(**create_kwargs)

    text_so_far = ""
    reasoning_so_far = ""
    last_meta = None
    for chunk in response:
        usage = getattr(chunk, "usage", None)
        input_tokens_count = usage.prompt_tokens if usage else 0
        generated_tokens_count = usage.completion_tokens if usage else 0

        if not chunk.choices:
            # Usage-only chunk (requested via include_usage): report the token
            # counts, reusing the finish_reason of the last content chunk.
            if last_meta:
                yield (
                    "" if delta_mode else text_so_far,
                    SingleOutputMeta(
                        input_tokens_count=input_tokens_count,
                        generated_tokens_count=generated_tokens_count,
                        reasoning_content="" if delta_mode else reasoning_so_far,
                        finish_reason=last_meta.finish_reason,
                    ),
                )
            continue

        delta = chunk.choices[0].delta
        content = delta.content or ""
        reasoning_text = getattr(delta, "reasoning_content", None) or ""
        if not delta_mode:
            text_so_far += content
            reasoning_so_far += reasoning_text

        # Tracked in both modes. Previously cumulative mode never set it, so
        # its trailing usage chunk (token counts) was silently dropped.
        last_meta = SingleOutputMeta(
            input_tokens_count=input_tokens_count,
            generated_tokens_count=generated_tokens_count,
            reasoning_content=reasoning_text if delta_mode else reasoning_so_far,
            finish_reason=chunk.choices[0].finish_reason,
        )
        yield (content if delta_mode else text_so_far), last_meta
175
+
56
176
  def chat_ai(self, conversations, model=None) -> LLMResponse:
57
177
  # conversations = [{"role": "user", "content": prompt_str}] deepseek-chat
58
178
  if not model and not self.default_model_name:
@@ -129,4 +249,51 @@ class AutoLLM:
129
249
  "model": res.model,
130
250
  "created": res.created
131
251
  }
132
- )
252
+ )
253
+
254
+
255
def stream_chat_with_continue(
        llm: AutoLLM, conversations: List[dict], llm_config: dict, args: AutoCoderArgs
) -> Generator[Any, None, None]:
    """Stream a completion, automatically continuing while the model stops on length.

    Each round streams ``llm.stream_chat_ai_ex`` in delta mode and yields
    ``(content_fragment, metadata)`` for every chunk. If a round ends with
    ``finish_reason == "length"``, the round's text is appended as an
    assistant message and another round is requested. This repeats up to
    ``args.generate_max_rounds`` continuations.

    The yielded metadata is a single object owned by this generator and
    updated in place. ``finish_reason``/``reasoning_content`` mirror the
    latest chunk. The token counts are the sum, over all rounds, of the last
    metadata seen in each round.

    Args:
        llm: client providing ``stream_chat_ai_ex``.
        conversations: initial messages; not mutated (a copy is extended).
        llm_config: extra completion parameters forwarded on every round
            (None is treated as empty).
        args: supplies ``chat_model`` and ``generate_max_rounds``.
    """
    import copy  # local import keeps this block self-contained

    rounds_done = 0
    temp_conversations = list(conversations)
    current_metadata = None
    last_meta_per_round = {}
    while True:
        stream_generator = llm.stream_chat_ai_ex(
            conversations=temp_conversations,
            model=args.chat_model,
            delta_mode=True,
            llm_config={**(llm_config or {})}
        )

        round_content = ""
        received_chunk = False
        for content, meta in stream_generator:
            received_chunk = True
            round_content += content
            last_meta_per_round[rounds_done] = meta
            if current_metadata is None:
                # Copy rather than alias: the aggregate is mutated below and
                # must not overwrite the per-round metadata it is summed from.
                current_metadata = copy.copy(meta)
            else:
                current_metadata.finish_reason = meta.finish_reason
                current_metadata.reasoning_content = meta.reasoning_content

            current_metadata.generated_tokens_count = sum(
                m.generated_tokens_count for m in last_meta_per_round.values()
            )
            current_metadata.input_tokens_count = sum(
                m.input_tokens_count for m in last_meta_per_round.values()
            )
            yield content, current_metadata

        if not received_chunk:
            # An empty stream carries no finish_reason; stop instead of
            # crashing (first round) or re-requesting on stale state.
            break

        # Feed what was generated back so the next round continues from it.
        temp_conversations.append({"role": "assistant", "content": round_content})

        if current_metadata.finish_reason != "length" or rounds_done >= args.generate_max_rounds:
            if rounds_done >= args.generate_max_rounds:
                printer.print_text(f"LLM生成达到的最大次数, 当前次数:{rounds_done}, 最大次数: {args.generate_max_rounds}, "
                                   f"Tokens: {current_metadata.generated_tokens_count}", style="yellow")
            break
        rounds_done += 1