auto-coder 0.1.207__py3-none-any.whl → 0.1.209__py3-none-any.whl

This diff shows the contents of two publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release.

This version of auto-coder might be problematic.
Files changed (37)
  1. {auto_coder-0.1.207.dist-info → auto_coder-0.1.209.dist-info}/METADATA +4 -3
  2. {auto_coder-0.1.207.dist-info → auto_coder-0.1.209.dist-info}/RECORD +37 -34
  3. autocoder/agent/auto_demand_organizer.py +212 -0
  4. autocoder/agent/auto_guess_query.py +284 -0
  5. autocoder/auto_coder.py +64 -19
  6. autocoder/auto_coder_rag.py +6 -0
  7. autocoder/chat_auto_coder.py +119 -16
  8. autocoder/command_args.py +21 -5
  9. autocoder/common/__init__.py +7 -1
  10. autocoder/common/code_auto_generate.py +32 -10
  11. autocoder/common/code_auto_generate_diff.py +85 -47
  12. autocoder/common/code_auto_generate_editblock.py +50 -28
  13. autocoder/common/code_auto_generate_strict_diff.py +79 -45
  14. autocoder/common/code_auto_merge.py +51 -15
  15. autocoder/common/code_auto_merge_diff.py +55 -2
  16. autocoder/common/code_auto_merge_editblock.py +84 -14
  17. autocoder/common/code_auto_merge_strict_diff.py +69 -32
  18. autocoder/common/code_modification_ranker.py +100 -0
  19. autocoder/common/command_completer.py +6 -4
  20. autocoder/common/types.py +10 -2
  21. autocoder/dispacher/actions/action.py +141 -94
  22. autocoder/dispacher/actions/plugins/action_regex_project.py +35 -25
  23. autocoder/lang.py +9 -1
  24. autocoder/pyproject/__init__.py +4 -0
  25. autocoder/rag/cache/simple_cache.py +8 -2
  26. autocoder/rag/loaders/docx_loader.py +3 -2
  27. autocoder/rag/loaders/pdf_loader.py +3 -1
  28. autocoder/rag/long_context_rag.py +12 -2
  29. autocoder/rag/rag_entry.py +2 -2
  30. autocoder/rag/utils.py +14 -9
  31. autocoder/suffixproject/__init__.py +2 -0
  32. autocoder/tsproject/__init__.py +4 -0
  33. autocoder/version.py +1 -1
  34. {auto_coder-0.1.207.dist-info → auto_coder-0.1.209.dist-info}/LICENSE +0 -0
  35. {auto_coder-0.1.207.dist-info → auto_coder-0.1.209.dist-info}/WHEEL +0 -0
  36. {auto_coder-0.1.207.dist-info → auto_coder-0.1.209.dist-info}/entry_points.txt +0 -0
  37. {auto_coder-0.1.207.dist-info → auto_coder-0.1.209.dist-info}/top_level.txt +0 -0
autocoder/auto_coder.py CHANGED
@@ -207,7 +207,8 @@ def main(input_args: Optional[List[str]] = None):
         max_seq = max(seqs)
 
         new_seq = str(max_seq + 1).zfill(12)
-        prev_files = [f for f in action_files if int(get_old_seq(f)) < int(new_seq)]
+        prev_files = [f for f in action_files if int(
+            get_old_seq(f)) < int(new_seq)]
 
         if raw_args.from_yaml:
             # If --from_yaml is specified, copy content from the matching YAML file
@@ -278,9 +279,30 @@ def main(input_args: Optional[List[str]] = None):
     llm = byzerllm.ByzerLLM(verbose=args.print_request)
 
     if args.code_model:
-        code_model = byzerllm.ByzerLLM()
-        code_model.setup_default_model_name(args.code_model)
-        llm.setup_sub_client("code_model", code_model)
+        if "," in args.code_model:
+            # Multiple code models specified
+            model_names = args.code_model.split(",")
+            models = []
+            for _, model_name in enumerate(model_names):
+                code_model = byzerllm.ByzerLLM()
+                code_model.setup_default_model_name(model_name.strip())
+                models.append(code_model)
+            llm.setup_sub_client("code_model", models)
+        else:
+            # Single code model
+            code_model = byzerllm.ByzerLLM()
+            code_model.setup_default_model_name(args.code_model)
+            llm.setup_sub_client("code_model", code_model)
+
+    if args.generate_rerank_model:
+        generate_rerank_model = byzerllm.ByzerLLM()
+        generate_rerank_model.setup_default_model_name(args.generate_rerank_model)
+        llm.setup_sub_client("generate_rerank_model", generate_rerank_model)
+
+    if args.inference_model:
+        inference_model = byzerllm.ByzerLLM()
+        inference_model.setup_default_model_name(args.inference_model)
+        llm.setup_sub_client("inference_model", inference_model)
 
     if args.human_as_model:
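Note: `--code_model` now accepts a comma-separated list, and the resulting clients are registered under a single sub-client key. A minimal sketch of that wiring, using only the byzerllm calls visible in this hunk (the model names in the usage comment are hypothetical):

    import byzerllm

    def setup_code_models(llm, code_model: str) -> None:
        # A comma means several models: register a list under one key;
        # otherwise register the single client directly.
        if "," in code_model:
            models = []
            for model_name in code_model.split(","):
                client = byzerllm.ByzerLLM()
                client.setup_default_model_name(model_name.strip())
                models.append(client)
            llm.setup_sub_client("code_model", models)
        else:
            client = byzerllm.ByzerLLM()
            client.setup_default_model_name(code_model)
            llm.setup_sub_client("code_model", client)

    # Hypothetical usage:
    # setup_code_models(llm, "deepseek_chat,qwen_coder")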
 
@@ -386,18 +408,25 @@ def main(input_args: Optional[List[str]] = None):
 
         llm.add_event_callback(
             EventName.BEFORE_CALL_MODEL, intercept_callback)
-        code_model = llm.get_sub_client("code_model")
-        if code_model:
-            code_model.add_event_callback(
-                EventName.BEFORE_CALL_MODEL, intercept_callback
-            )
+
+        code_models = llm.get_sub_client("code_model")
+        if code_models:
+            if not isinstance(code_models, list):
+                code_models = [code_models]
+            for model in code_models:
+                model.add_event_callback(
+                    EventName.BEFORE_CALL_MODEL, intercept_callback
+                )
         # llm.add_event_callback(EventName.AFTER_CALL_MODEL, token_counter_interceptor)
 
-        code_model = llm.get_sub_client("code_model")
-        if code_model:
-            code_model.add_event_callback(
-                EventName.AFTER_CALL_MODEL, token_counter_interceptor
-            )
+        code_models = llm.get_sub_client("code_model")
+        if code_models:
+            if not isinstance(code_models, list):
+                code_models = [code_models]
+            for model in code_models:
+                model.add_event_callback(
+                    EventName.AFTER_CALL_MODEL, token_counter_interceptor
+                )
 
         llm.setup_template(model=args.model, template="auto")
         llm.setup_default_model_name(args.model)
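Because `get_sub_client("code_model")` may now return a single client or a list, both callback registrations normalize to a list before iterating. That normalize-then-iterate idiom as a reusable sketch (the helper names are ours):

    def as_client_list(sub_client):
        # get_sub_client can hand back None, one client, or a list of clients.
        if not sub_client:
            return []
        return sub_client if isinstance(sub_client, list) else [sub_client]

    def register_code_model_callback(llm, event_name, callback):
        # Attach the callback to every registered code model.
        for model in as_client_list(llm.get_sub_client("code_model")):
            model.add_event_callback(event_name, callback)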
@@ -493,7 +522,7 @@ def main(input_args: Optional[List[str]] = None):
         from autocoder.index.for_command import index_query_command
 
         index_query_command(args, llm)
-        return
+        return
 
     if raw_args.command == "agent":
         if raw_args.agent_command == "planner":
@@ -694,24 +723,39 @@ def main(input_args: Optional[List[str]] = None):
         memory_file = os.path.join(memory_dir, "chat_history.json")
         console = Console()
         if args.new_session:
-            chat_history = {"ask_conversation": []}
+            if os.path.exists(memory_file):
+                with open(memory_file, "r") as f:
+                    old_chat_history = json.load(f)
+                if "conversation_history" not in old_chat_history:
+                    old_chat_history["conversation_history"] = []
+                old_chat_history["conversation_history"].append(
+                    old_chat_history.get("ask_conversation", []))
+                chat_history = {"ask_conversation": [
+                ], "conversation_history": old_chat_history["conversation_history"]}
+            else:
+                chat_history = {"ask_conversation": [],
+                                "conversation_history": []}
             with open(memory_file, "w") as f:
                 json.dump(chat_history, f, ensure_ascii=False)
             console.print(
                 Panel(
-                    "New session started. Previous chat history has been cleared.",
+                    "New session started. Previous chat history has been archived.",
                     title="Session Status",
                     expand=False,
                     border_style="green",
                 )
             )
-            return
+            if not args.query:
+                return
 
         if os.path.exists(memory_file):
             with open(memory_file, "r") as f:
                 chat_history = json.load(f)
+            if "conversation_history" not in chat_history:
+                chat_history["conversation_history"] = []
         else:
-            chat_history = {"ask_conversation": []}
+            chat_history = {"ask_conversation": [],
+                            "conversation_history": []}
 
         chat_history["ask_conversation"].append(
             {"role": "user", "content": args.query}
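Rather than wiping history, a new session now folds the finished `ask_conversation` into a `conversation_history` list before resetting. The archival step as a standalone, stdlib-only sketch of the same JSON layout (the file path argument is illustrative):

    import json
    import os

    def archive_session(memory_file: str) -> dict:
        # Move the finished ask_conversation into conversation_history,
        # then start a fresh ask_conversation.
        if os.path.exists(memory_file):
            with open(memory_file, "r") as f:
                history = json.load(f)
            history.setdefault("conversation_history", [])
            history["conversation_history"].append(
                history.get("ask_conversation", []))
            chat_history = {"ask_conversation": [],
                            "conversation_history": history["conversation_history"]}
        else:
            chat_history = {"ask_conversation": [], "conversation_history": []}
        with open(memory_file, "w") as f:
            json.dump(chat_history, f, ensure_ascii=False)
        return chat_history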
@@ -937,6 +981,7 @@ def main(input_args: Optional[List[str]] = None):
 
         with open(memory_file, "w") as f:
             json.dump(chat_history, f, ensure_ascii=False)
+
         return
 
     else:
autocoder/auto_coder_rag.py CHANGED
@@ -314,6 +314,12 @@ def main(input_args: Optional[List[str]] = None):
         action="store_true",
         help="Whether to return responses without contexts. only works when pro plugin is installed",
     )
+    serve_parser.add_argument(
+        "--data_cells_max_num",
+        type=int,
+        default=2000,
+        help="Maximum number of data cells to process",
+    )
 
     serve_parser.add_argument(
         "--recall_model",
autocoder/chat_auto_coder.py CHANGED
@@ -59,6 +59,7 @@ import byzerllm
 from byzerllm.utils import format_str_jinja2
 from autocoder.chat_auto_coder_lang import get_message
 from autocoder.utils import operate_config_api
+from autocoder.agent.auto_guess_query import AutoGuessQuery
 
 
 class SymbolItem(BaseModel):
@@ -930,6 +931,14 @@ class CommandCompleter(Completer):
                     yield Completion(
                         lib_name, start_position=-len(current_word)
                     )
+        elif words[0] == "/coding":
+            new_text = text[len("/coding"):]
+            parser = CommandTextParser(new_text, words[0])
+            parser.lib()
+            current_word = parser.current_word()
+            for command in parser.get_sub_commands():
+                if command.startswith(current_word):
+                    yield Completion(command, start_position=-len(current_word))
 
         elif words[0] == "/conf":
            new_words = text[len("/conf"):].strip().split()
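Tab completion now also works after `/coding`, reusing the same CommandTextParser flow the completer already uses for `/lib`. A self-contained prompt_toolkit sketch of the idea, with a hard-coded sub-command list standing in for CommandTextParser (only `/apply` and `/next` are visible in this diff):

    from prompt_toolkit.completion import Completer, Completion

    class CodingCompleter(Completer):
        SUB_COMMANDS = ["/apply", "/next"]  # sub-commands visible in this diff

        def get_completions(self, document, complete_event):
            text = document.text_before_cursor
            if not text.startswith("/coding"):
                return
            words = text.split()
            current_word = words[-1] if len(words) > 1 else ""
            for command in self.SUB_COMMANDS:
                if command.startswith(current_word):
                    yield Completion(command, start_position=-len(current_word))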
@@ -1427,6 +1436,88 @@ def convert_yaml_to_config(yaml_file: str):
         setattr(args, key, value)
     return args
 
+
+def code_next(query: str):
+    conf = memory.get("conf", {})
+    yaml_config = {
+        "include_file": ["./base/base.yml"],
+        "auto_merge": conf.get("auto_merge", "editblock"),
+        "human_as_model": conf.get("human_as_model", "false") == "true",
+        "skip_build_index": conf.get("skip_build_index", "true") == "true",
+        "skip_confirm": conf.get("skip_confirm", "true") == "true",
+        "silence": conf.get("silence", "true") == "true",
+        "include_project_structure": conf.get("include_project_structure", "true")
+        == "true",
+    }
+    for key, value in conf.items():
+        converted_value = convert_config_value(key, value)
+        if converted_value is not None:
+            yaml_config[key] = converted_value
+
+    temp_yaml = os.path.join("actions", f"{uuid.uuid4()}.yml")
+    try:
+        with open(temp_yaml, "w") as f:
+            f.write(convert_yaml_config_to_str(yaml_config=yaml_config))
+        args = convert_yaml_to_config(temp_yaml)
+    finally:
+        if os.path.exists(temp_yaml):
+            os.remove(temp_yaml)
+
+    llm = byzerllm.ByzerLLM.from_default_model(
+        args.inference_model or args.model)
+
+    auto_guesser = AutoGuessQuery(
+        llm=llm,
+        project_dir=os.getcwd(),
+        skip_diff=True
+    )
+
+    predicted_tasks = auto_guesser.predict_next_tasks(
+        5, is_human_as_model=args.human_as_model)
+
+    if not predicted_tasks:
+        console = Console()
+        console.print(Panel("No task predictions available", style="yellow"))
+        return
+
+    console = Console()
+
+    # Create main panel for all predicted tasks
+    table = Table(show_header=True,
+                  header_style="bold magenta", show_lines=True)
+    table.add_column("Priority", style="cyan", width=8)
+    table.add_column("Task Description", style="green",
+                     width=40, overflow="fold")
+    table.add_column("Files", style="yellow", width=30, overflow="fold")
+    table.add_column("Reason", style="blue", width=30, overflow="fold")
+    table.add_column("Dependencies", style="magenta",
+                     width=30, overflow="fold")
+
+    for task in predicted_tasks:
+        # Format file paths to be more readable
+        file_list = "\n".join([os.path.relpath(f, os.getcwd())
+                               for f in task.urls])
+
+        # Format dependencies to be more readable
+        dependencies = "\n".join(
+            task.dependency_queries) if task.dependency_queries else "None"
+
+        table.add_row(
+            str(task.priority),
+            task.query,
+            file_list,
+            task.reason,
+            dependencies
+        )
+
+    console.print(Panel(
+        table,
+        title="[bold]Predicted Next Tasks[/bold]",
+        border_style="blue",
+        padding=(1, 2)  # Add more horizontal padding
+    ))
+
+
 def commit(query: str):
     def prepare_commit_yaml():
         auto_coder_main(["next", "chat_action"])
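code_next writes the current conf to a throwaway YAML, runs AutoGuessQuery over the project, and renders the predicted tasks with Rich. The rendering half is plain Rich and can be exercised on its own; a runnable sketch with dummy task data, where PredictedTask is our stand-in for AutoGuessQuery's task objects (field names mirror the diff):

    from dataclasses import dataclass, field
    from typing import List

    from rich.console import Console
    from rich.panel import Panel
    from rich.table import Table

    @dataclass
    class PredictedTask:  # stand-in for AutoGuessQuery's task objects
        priority: int
        query: str
        urls: List[str]
        reason: str
        dependency_queries: List[str] = field(default_factory=list)

    tasks = [PredictedTask(1, "Add unit tests for code_next", ["src/chat.py"],
                           "New code path is untested")]

    table = Table(show_header=True, header_style="bold magenta", show_lines=True)
    for name in ("Priority", "Task Description", "Files", "Reason", "Dependencies"):
        table.add_column(name, overflow="fold")
    for t in tasks:
        table.add_row(str(t.priority), t.query, "\n".join(t.urls), t.reason,
                      "\n".join(t.dependency_queries) or "None")

    Console().print(Panel(table, title="[bold]Predicted Next Tasks[/bold]",
                          border_style="blue", padding=(1, 2)))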
@@ -1435,15 +1526,15 @@ def commit(query: str):
 
     # no_diff = query.strip().startswith("/no_diff")
     # if no_diff:
-    #     query = query.replace("/no_diff", "", 1).strip()
+    #     query = query.replace("/no_diff", "", 1).strip()
 
     latest_yaml_file = get_last_yaml_file("actions")
-
+
     conf = memory.get("conf", {})
     current_files = memory["current_files"]["files"]
     execute_file = None
-
-    if latest_yaml_file:
+
+    if latest_yaml_file:
         try:
             execute_file = os.path.join("actions", latest_yaml_file)
             yaml_config = {
@@ -1469,26 +1560,30 @@ def commit(query: str):
             temp_yaml = os.path.join("actions", f"{uuid.uuid4()}.yml")
             try:
                 with open(temp_yaml, "w") as f:
-                    f.write(convert_yaml_config_to_str(yaml_config=yaml_config))
+                    f.write(convert_yaml_config_to_str(
+                        yaml_config=yaml_config))
                 args = convert_yaml_to_config(temp_yaml)
             finally:
                 if os.path.exists(temp_yaml):
                     os.remove(temp_yaml)
-
-            llm = byzerllm.ByzerLLM.from_default_model(args.code_model or args.model)
-            uncommitted_changes = git_utils.get_uncommitted_changes(".")
+
+            llm = byzerllm.ByzerLLM.from_default_model(
+                args.code_model or args.model)
+            uncommitted_changes = git_utils.get_uncommitted_changes(".")
             commit_message = git_utils.generate_commit_message.with_llm(
                 llm).run(uncommitted_changes)
-            memory["conversation"].append({"role": "user", "content": commit_message})
+            memory["conversation"].append(
+                {"role": "user", "content": commit_message})
             yaml_config["query"] = commit_message
-            yaml_content = convert_yaml_config_to_str(yaml_config=yaml_config)
+            yaml_content = convert_yaml_config_to_str(yaml_config=yaml_config)
             with open(os.path.join(execute_file), "w") as f:
-                f.write(yaml_content)
-
+                f.write(yaml_content)
+
             file_content = open(execute_file).read()
-            md5 = hashlib.md5(file_content.encode('utf-8')).hexdigest()
-            file_name = os.path.basename(execute_file)
-            commit_result = git_utils.commit_changes(".", f"auto_coder_{file_name}_{md5}")
+            md5 = hashlib.md5(file_content.encode('utf-8')).hexdigest()
+            file_name = os.path.basename(execute_file)
+            commit_result = git_utils.commit_changes(
+                ".", f"auto_coder_{file_name}_{md5}")
             git_utils.print_commit_info(commit_result=commit_result)
         except Exception as e:
             print(f"Failed to commit: {e}")
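The commit path hashes the executed YAML and embeds the digest in the commit tag, so every auto-generated commit is traceable to the exact action file. The naming scheme in isolation, runnable as-is (the example path is illustrative):

    import hashlib
    import os

    def commit_tag_for(execute_file: str) -> str:
        # auto_coder_<yaml file name>_<md5 of its content>
        with open(execute_file, "rb") as f:
            md5 = hashlib.md5(f.read()).hexdigest()
        return f"auto_coder_{os.path.basename(execute_file)}_{md5}"

    # Example (illustrative path):
    # print(commit_tag_for("actions/000000000042_chat_action.yml"))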
@@ -1502,6 +1597,14 @@ def coding(query: str):
     if is_apply:
         query = query.replace("/apply", "", 1).strip()
 
+    is_next = query.strip().startswith("/next")
+    if is_next:
+        query = query.replace("/next", "", 1).strip()
+
+    if is_next:
+        code_next(query)
+        return
+
     memory["conversation"].append({"role": "user", "content": query})
     conf = memory.get("conf", {})
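`/coding` now recognizes a `/next` prefix and routes to code_next before the normal coding flow runs. The prefix handling, extracted as a tiny runnable sketch (the helper name is ours):

    def strip_prefix(query: str, prefix: str):
        # Mirrors the /apply and /next handling in coding(): report whether
        # the prefix matched and return the remainder of the query.
        if query.strip().startswith(prefix):
            return True, query.replace(prefix, "", 1).strip()
        return False, query

    is_next, rest = strip_prefix("/next refine the task table", "/next")
    print(is_next, rest)  # True refine the task table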
 
@@ -2198,7 +2301,7 @@ def main():
         mode = memory["mode"]
         human_as_model = memory["conf"].get("human_as_model", "false")
         return (
-            f" Mode: {MODES[mode]} (ctl+k) | Human as Model: {human_as_model} (ctl+n)"
+            f" Mode: {MODES[mode]} (ctl+k) | Human as Model: {human_as_model} (ctl+n or /conf human_as_model:true/false)"
         )
 
     session = PromptSession(
autocoder/command_args.py CHANGED
@@ -102,7 +102,11 @@ def parse_args(input_args: Optional[List[str]] = None) -> AutoCoderArgs:
         "--print_request", action="store_true", help=desc["print_request"]
     )
     parser.add_argument("--code_model", default="", help=desc["code_model"])
-    parser.add_argument("--system_prompt", default="", help=desc["system_prompt"])
+    parser.add_argument("--generate_rerank_model", default="", help=desc["generate_rerank_model"])
+    parser.add_argument("--inference_model", default="",
+                        help="The name of the inference model to use. Default is empty")
+    parser.add_argument("--system_prompt", default="",
+                        help=desc["system_prompt"])
     parser.add_argument("--planner_model", default="",
                         help=desc["planner_model"])
     parser.add_argument(
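The CLI grows flags for the two new sub-clients alongside the existing `--code_model`. A minimal argparse sketch showing just these additions and how a comma-separated code-model value parses (prog and model names are illustrative):

    import argparse

    parser = argparse.ArgumentParser(prog="auto-coder")  # prog name illustrative
    parser.add_argument("--code_model", default="")
    parser.add_argument("--generate_rerank_model", default="")
    parser.add_argument("--inference_model", default="",
                        help="The name of the inference model to use. Default is empty")

    # Hypothetical model names, to show the comma-separated form:
    args = parser.parse_args(
        ["--code_model", "model_a,model_b", "--inference_model", "model_c"])
    print(args.code_model.split(","))  # ['model_a', 'model_b']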
@@ -111,6 +115,9 @@ def parse_args(input_args: Optional[List[str]] = None) -> AutoCoderArgs:
     parser.add_argument(
         "--human_as_model", action="store_true", help=desc["human_as_model"]
     )
+    parser.add_argument(
+        "--human_model_num", type=int, default=1, help=desc["human_model_num"]
+    )
     parser.add_argument("--urls", default="", help=desc["urls"])
     parser.add_argument(
         "--urls_use_model", action="store_true", help=desc["urls_use_model"]
@@ -129,6 +136,13 @@ def parse_args(input_args: Optional[List[str]] = None) -> AutoCoderArgs:
         "--search_engine_token", default="", help=desc["search_engine_token"]
     )
 
+    parser.add_argument(
+        "--generate_times_same_model",
+        type=int,
+        default=1,
+        help=desc["generate_times_same_model"],
+    )
+
     parser.add_argument(
         "--enable_rag_search",
         nargs="?",
@@ -196,7 +210,6 @@ def parse_args(input_args: Optional[List[str]] = None) -> AutoCoderArgs:
         help="是否静默执行,不打印任何信息。默认为False",
     )
 
-
     revert_parser = subparsers.add_parser("revert", help=desc["revert_desc"])
     revert_parser.add_argument("--file", help=desc["revert_desc"])
     revert_parser.add_argument(
@@ -554,7 +567,8 @@ def parse_args(input_args: Optional[List[str]] = None) -> AutoCoderArgs:
 
     read_project_parser.add_argument("--rag_token", default="", help="")
     read_project_parser.add_argument("--rag_url", default="", help="")
-    read_project_parser.add_argument("--rag_params_max_tokens", default=4096, help="")
+    read_project_parser.add_argument(
+        "--rag_params_max_tokens", default=4096, help="")
     read_project_parser.add_argument(
         "--rag_type", default="storage", help="RAG type, default is storage"
     )
@@ -644,7 +658,8 @@ def parse_args(input_args: Optional[List[str]] = None) -> AutoCoderArgs:
 
     auto_tool_parser.add_argument("--rag_token", default="", help="")
     auto_tool_parser.add_argument("--rag_url", default="", help="")
-    auto_tool_parser.add_argument("--rag_params_max_tokens", default=4096, help="")
+    auto_tool_parser.add_argument(
+        "--rag_params_max_tokens", default=4096, help="")
     auto_tool_parser.add_argument(
         "--rag_type", default="storage", help="RAG type, default is storage"
     )
@@ -710,7 +725,8 @@ def parse_args(input_args: Optional[List[str]] = None) -> AutoCoderArgs:
 
     planner_parser.add_argument("--rag_token", default="", help="")
     planner_parser.add_argument("--rag_url", default="", help="")
-    planner_parser.add_argument("--rag_params_max_tokens", default=4096, help="")
+    planner_parser.add_argument(
+        "--rag_params_max_tokens", default=4096, help="")
     planner_parser.add_argument(
         "--rag_type", default="storage", help="RAG type, default is storage"
     )
autocoder/common/__init__.py CHANGED
@@ -248,9 +248,12 @@ class AutoCoderArgs(pydantic.BaseModel):
     sd_model: Optional[str] = ""
     emb_model: Optional[str] = ""
     code_model: Optional[str] = ""
+    generate_rerank_model: Optional[str] = ""
+    inference_model: Optional[str] = ""
     system_prompt: Optional[str] = ""
-    text2voice_model: Optional[str] = ""
+    planner_model: Optional[str] = ""
     voice2text_model: Optional[str] = ""
+    text2voice_model: Optional[str] = ""
 
     skip_build_index: Optional[bool] = False
     skip_filter_index: Optional[bool] = False
@@ -291,6 +294,7 @@ class AutoCoderArgs(pydantic.BaseModel):
 
     auto_merge: Optional[Union[bool, str]] = False
     human_as_model: Optional[bool] = False
+    human_model_num: Optional[int] = 1
 
     image_file: Optional[str] = ""
     image_mode: Optional[str] = "direct"
@@ -341,6 +345,8 @@ class AutoCoderArgs(pydantic.BaseModel):
     inference_compute_precision: int = 64
     without_contexts: Optional[bool] = False
     skip_events: Optional[bool] = False
+    data_cells_max_num: Optional[int] = 2000
+    generate_times_same_model: Optional[int] = 1
 
     class Config:
         protected_namespaces = ()
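Every new knob lands on AutoCoderArgs as an optional field with a default, so existing configs keep validating. A trimmed sketch of just the fields this release adds, with the same Config the diff shows (the class name is ours):

    from typing import Optional
    import pydantic

    class AutoCoderArgsSketch(pydantic.BaseModel):
        # Fields added in 0.1.209, per this diff:
        generate_rerank_model: Optional[str] = ""
        inference_model: Optional[str] = ""
        human_model_num: Optional[int] = 1
        data_cells_max_num: Optional[int] = 2000
        generate_times_same_model: Optional[int] = 1

        class Config:
            # Same setting as the real model: disable pydantic's
            # protected "model_" namespace checks for these field names.
            protected_namespaces = ()

    print(AutoCoderArgsSketch().data_cells_max_num)  # 2000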
autocoder/common/code_auto_generate.py CHANGED
@@ -4,6 +4,8 @@ from autocoder.common import AutoCoderArgs
 import byzerllm
 from autocoder.utils.queue_communicate import queue_communicate, CommunicateEvent, CommunicateEventType
 from autocoder.common import sys_prompt
+from concurrent.futures import ThreadPoolExecutor
+from autocoder.common.types import CodeGenerateResult
 
 
 class CodeAutoGenerate:
@@ -12,13 +14,15 @@ class CodeAutoGenerate:
     ) -> None:
         self.llm = llm
         self.args = args
-        self.action = action
+        self.action = action
+        self.generate_times_same_model = args.generate_times_same_model
         if not self.llm:
             raise ValueError(
                 "Please provide a valid model instance to use for code generation."
             )
-        if self.llm.get_sub_client("code_model"):
-            self.llm = self.llm.get_sub_client("code_model")
+        self.llms = self.llm.get_sub_client("code_model") or [self.llm]
+        if not isinstance(self.llms, list):
+            self.llms = [self.llms]
 
     @byzerllm.prompt(llm=lambda self: self.llm)
     def auto_implement_function(self, instruction: str, content: str) -> str:
@@ -145,7 +149,7 @@ class CodeAutoGenerate:
 
     def single_round_run(
         self, query: str, source_content: str
-    ) -> Tuple[str, Dict[str, str]]:
+    ) -> Tuple[List[str], Dict[str, str]]:
         llm_config = {"human_as_model": self.args.human_as_model}
 
         if self.args.request_id and not self.args.skip_events:
@@ -178,9 +182,27 @@ class CodeAutoGenerate:
 
         conversations.append({"role": "user", "content": init_prompt})
 
-
-        t = self.llm.chat_oai(conversations=conversations, llm_config=llm_config)
-        conversations.append({"role": "assistant", "content": t[0].output})
+        conversations_list = []
+        results = []
+        if not self.args.human_as_model:
+            with ThreadPoolExecutor(max_workers=len(self.llms) * self.generate_times_same_model) as executor:
+                futures = []
+                for llm in self.llms:
+                    for _ in range(self.generate_times_same_model):
+                        futures.append(executor.submit(
+                            llm.chat_oai, conversations=conversations, llm_config=llm_config))
+                results = [future.result()[0].output for future in futures]
+                for result in results:
+                    conversations_list.append(
+                        conversations + [{"role": "assistant", "content": result}])
+        else:
+            results = []
+            conversations_list = []
+            for _ in range(self.args.human_model_num):
+                v = self.llms[0].chat_oai(
+                    conversations=conversations, llm_config=llm_config)
+                results.append(v[0].output)
+                conversations_list.append(conversations + [{"role": "assistant", "content": v[0].output}])
 
         if self.args.request_id and not self.args.skip_events:
             queue_communicate.send_event_no_wait(
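With multiple code models and `generate_times_same_model`, single_round_run now fans the same conversation out and gathers every candidate output (the new code_modification_ranker.py presumably chooses among them). The fan-out pattern in isolation, with stub callables standing in for ByzerLLM sub-clients:

    from concurrent.futures import ThreadPoolExecutor
    from typing import Callable, List

    def fan_out(clients: List[Callable[[str], str]], prompt: str,
                times_per_client: int) -> List[str]:
        # One task per (client, repetition); results keep submission order.
        with ThreadPoolExecutor(max_workers=len(clients) * times_per_client) as ex:
            futures = [ex.submit(c, prompt)
                       for c in clients for _ in range(times_per_client)]
            return [f.result() for f in futures]

    # Stub clients standing in for ByzerLLM sub-clients:
    clients = [lambda p: f"candidate-a: {p}", lambda p: f"candidate-b: {p}"]
    print(fan_out(clients, "refactor foo()", times_per_client=2))  # 4 candidates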
@@ -191,7 +213,7 @@ class CodeAutoGenerate:
                 ),
             )
 
-        return [t[0].output], conversations
+        return CodeGenerateResult(contents=results, conversations=conversations_list)
 
     def multi_round_run(
         self, query: str, source_content: str, max_steps: int = 10
@@ -246,6 +268,6 @@ class CodeAutoGenerate:
             or "/done" in t[0].output
             or "__EOF__" in t[0].output
         ):
-            return result, conversations
+            return CodeGenerateResult(contents=["\n\n".join(result)], conversations=[conversations])
 
-        return result, conversations
+        return CodeGenerateResult(contents=["\n\n".join(result)], conversations=[conversations])
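Both run paths now return a CodeGenerateResult instead of a bare tuple, keeping each candidate paired with the conversation that produced it. The class itself lives in autocoder/common/types.py (changed in this release but not shown here); the shape below is inferred from these call sites and may differ from the real definition:

    from typing import Any, Dict, List
    import pydantic

    class CodeGenerateResult(pydantic.BaseModel):
        # Inferred from the call sites above; the real class lives in
        # autocoder/common/types.py and may differ.
        contents: List[str]                         # one entry per candidate
        conversations: List[List[Dict[str, Any]]]   # matching chat transcripts

    result = CodeGenerateResult(
        contents=["<edited code>"],
        conversations=[[{"role": "user", "content": "query"},
                        {"role": "assistant", "content": "<edited code>"}]],
    )
    print(len(result.contents))  # 1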