vision-agent 0.2.161__py3-none-any.whl → 0.2.162__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,32 +2,33 @@ import copy
2
2
  import logging
3
3
  import os
4
4
  import sys
5
- from json import JSONDecodeError
6
5
  from pathlib import Path
7
- from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, cast
6
+ from typing import Any, Callable, Dict, List, Optional, Sequence, Union, cast
8
7
 
9
- from rich.console import Console
10
- from rich.style import Style
11
- from rich.syntax import Syntax
12
8
  from tabulate import tabulate
13
9
 
14
10
  import vision_agent.tools as T
15
- from vision_agent.agent import Agent
11
+ from vision_agent.agent.agent import Agent
16
12
  from vision_agent.agent.agent_utils import (
13
+ DefaultImports,
17
14
  extract_code,
18
15
  extract_json,
16
+ format_memory,
17
+ print_code,
19
18
  remove_installs_from_code,
20
19
  )
21
20
  from vision_agent.agent.vision_agent_coder_prompts import (
22
21
  CODE,
23
22
  FIX_BUG,
24
23
  FULL_TASK,
25
- PICK_PLAN,
26
- PLAN,
27
- PREVIOUS_FAILED,
28
24
  SIMPLE_TEST,
29
- TEST_PLANS,
30
- USER_REQ,
25
+ )
26
+ from vision_agent.agent.vision_agent_planner import (
27
+ AnthropicVisionAgentPlanner,
28
+ AzureVisionAgentPlanner,
29
+ OllamaVisionAgentPlanner,
30
+ OpenAIVisionAgentPlanner,
31
+ PlanContext,
31
32
  )
32
33
  from vision_agent.lmm import (
33
34
  LMM,
@@ -40,241 +41,11 @@ from vision_agent.lmm import (
40
41
  from vision_agent.tools.meta_tools import get_diff
41
42
  from vision_agent.utils import CodeInterpreterFactory, Execution
42
43
  from vision_agent.utils.execute import CodeInterpreter
43
- from vision_agent.utils.image_utils import b64_to_pil
44
- from vision_agent.utils.sim import AzureSim, OllamaSim, Sim
45
- from vision_agent.utils.video import play_video
46
44
 
47
45
  logging.basicConfig(stream=sys.stdout)
48
46
  WORKSPACE = Path(os.getenv("WORKSPACE", ""))
49
47
  _LOGGER = logging.getLogger(__name__)
50
48
  _MAX_TABULATE_COL_WIDTH = 80
51
- _CONSOLE = Console()
52
-
53
-
54
- class DefaultImports:
55
- """Container for default imports used in the code execution."""
56
-
57
- common_imports = [
58
- "import os",
59
- "import numpy as np",
60
- "from vision_agent.tools import *",
61
- "from typing import *",
62
- "from pillow_heif import register_heif_opener",
63
- "register_heif_opener()",
64
- ]
65
-
66
- @staticmethod
67
- def to_code_string() -> str:
68
- return "\n".join(DefaultImports.common_imports + T.__new_tools__)
69
-
70
- @staticmethod
71
- def prepend_imports(code: str) -> str:
72
- """Run this method to prepend the default imports to the code.
73
- NOTE: be sure to run this method after the custom tools have been registered.
74
- """
75
- return DefaultImports.to_code_string() + "\n\n" + code
76
-
77
-
78
- def format_memory(memory: List[Dict[str, str]]) -> str:
79
- output_str = ""
80
- for i, m in enumerate(memory):
81
- output_str += f"### Feedback {i}:\n"
82
- output_str += f"Code {i}:\n```python\n{m['code']}```\n\n"
83
- output_str += f"Feedback {i}: {m['feedback']}\n\n"
84
- if "edits" in m:
85
- output_str += f"Edits {i}:\n{m['edits']}\n"
86
- output_str += "\n"
87
-
88
- return output_str
89
-
90
-
91
- def format_plans(plans: Dict[str, Any]) -> str:
92
- plan_str = ""
93
- for k, v in plans.items():
94
- plan_str += "\n" + f"{k}: {v['thoughts']}\n"
95
- plan_str += " -" + "\n -".join([e for e in v["instructions"]])
96
-
97
- return plan_str
98
-
99
-
100
- def write_plans(
101
- chat: List[Message],
102
- tool_desc: str,
103
- working_memory: str,
104
- model: LMM,
105
- ) -> Dict[str, Any]:
106
- chat = copy.deepcopy(chat)
107
- if chat[-1]["role"] != "user":
108
- raise ValueError("Last chat message must be from the user.")
109
-
110
- user_request = chat[-1]["content"]
111
- context = USER_REQ.format(user_request=user_request)
112
- prompt = PLAN.format(
113
- context=context,
114
- tool_desc=tool_desc,
115
- feedback=working_memory,
116
- )
117
- chat[-1]["content"] = prompt
118
- return extract_json(model(chat, stream=False)) # type: ignore
119
-
120
-
121
- def pick_plan(
122
- chat: List[Message],
123
- plans: Dict[str, Any],
124
- tool_info: str,
125
- model: LMM,
126
- code_interpreter: CodeInterpreter,
127
- media: List[str],
128
- log_progress: Callable[[Dict[str, Any]], None],
129
- verbosity: int = 0,
130
- max_retries: int = 3,
131
- ) -> Tuple[Dict[str, str], str]:
132
- log_progress(
133
- {
134
- "type": "log",
135
- "log_content": "Generating code to pick the best plan",
136
- "status": "started",
137
- }
138
- )
139
-
140
- chat = copy.deepcopy(chat)
141
- if chat[-1]["role"] != "user":
142
- raise ValueError("Last chat message must be from the user.")
143
-
144
- plan_str = format_plans(plans)
145
- prompt = TEST_PLANS.format(
146
- docstring=tool_info, plans=plan_str, previous_attempts="", media=media
147
- )
148
-
149
- code = extract_code(model(prompt, stream=False)) # type: ignore
150
- log_progress(
151
- {
152
- "type": "log",
153
- "log_content": "Executing code to test plans",
154
- "code": DefaultImports.prepend_imports(code),
155
- "status": "running",
156
- }
157
- )
158
- tool_output = code_interpreter.exec_isolation(DefaultImports.prepend_imports(code))
159
- # Because of the way we trace function calls the trace information ends up in the
160
- # results. We don't want to show this info to the LLM so we don't include it in the
161
- # tool_output_str.
162
- tool_output_str = tool_output.text(include_results=False).strip()
163
-
164
- if verbosity == 2:
165
- _print_code("Initial code and tests:", code)
166
- _LOGGER.info(f"Initial code execution result:\n{tool_output_str}")
167
-
168
- log_progress(
169
- {
170
- "type": "log",
171
- "log_content": (
172
- "Code execution succeeded"
173
- if tool_output.success
174
- else "Code execution failed"
175
- ),
176
- "code": DefaultImports.prepend_imports(code),
177
- # "payload": tool_output.to_json(),
178
- "status": "completed" if tool_output.success else "failed",
179
- }
180
- )
181
-
182
- # retry if the tool output is empty or code fails
183
- count = 0
184
- while (
185
- not tool_output.success
186
- or (len(tool_output.logs.stdout) == 0 and len(tool_output.logs.stderr) == 0)
187
- ) and count < max_retries:
188
- prompt = TEST_PLANS.format(
189
- docstring=tool_info,
190
- plans=plan_str,
191
- previous_attempts=PREVIOUS_FAILED.format(
192
- code=code, error="\n".join(tool_output_str.splitlines()[-50:])
193
- ),
194
- media=media,
195
- )
196
- log_progress(
197
- {
198
- "type": "log",
199
- "log_content": "Retrying code to test plans",
200
- "status": "running",
201
- "code": DefaultImports.prepend_imports(code),
202
- }
203
- )
204
- code = extract_code(model(prompt, stream=False)) # type: ignore
205
- tool_output = code_interpreter.exec_isolation(
206
- DefaultImports.prepend_imports(code)
207
- )
208
- log_progress(
209
- {
210
- "type": "log",
211
- "log_content": (
212
- "Code execution succeeded"
213
- if tool_output.success
214
- else "Code execution failed"
215
- ),
216
- "code": DefaultImports.prepend_imports(code),
217
- # "payload": tool_output.to_json(),
218
- "status": "completed" if tool_output.success else "failed",
219
- }
220
- )
221
- tool_output_str = tool_output.text(include_results=False).strip()
222
-
223
- if verbosity == 2:
224
- _print_code("Code and test after attempted fix:", code)
225
- _LOGGER.info(f"Code execution result after attempt {count + 1}")
226
- _LOGGER.info(f"{tool_output_str}")
227
-
228
- count += 1
229
-
230
- if verbosity >= 1:
231
- _print_code("Final code:", code)
232
-
233
- user_req = chat[-1]["content"]
234
- context = USER_REQ.format(user_request=user_req)
235
- # because the tool picker model gets the image as well, we have to be careful with
236
- # how much text we send it, so we truncate the tool output to 20,000 characters
237
- prompt = PICK_PLAN.format(
238
- context=context,
239
- plans=format_plans(plans),
240
- tool_output=tool_output_str[:20_000],
241
- )
242
- chat[-1]["content"] = prompt
243
-
244
- count = 0
245
- plan_thoughts = None
246
- while plan_thoughts is None and count < max_retries:
247
- try:
248
- plan_thoughts = extract_json(model(chat, stream=False)) # type: ignore
249
- except JSONDecodeError as e:
250
- _LOGGER.exception(
251
- f"Error while extracting JSON during picking best plan {str(e)}"
252
- )
253
- pass
254
- count += 1
255
-
256
- if (
257
- plan_thoughts is None
258
- or "best_plan" not in plan_thoughts
259
- or ("best_plan" in plan_thoughts and plan_thoughts["best_plan"] not in plans)
260
- ):
261
- _LOGGER.info(f"Failed to pick best plan. Using the first plan. {plan_thoughts}")
262
- plan_thoughts = {"best_plan": list(plans.keys())[0]}
263
-
264
- if "thoughts" not in plan_thoughts:
265
- plan_thoughts["thoughts"] = ""
266
-
267
- if verbosity >= 1:
268
- _LOGGER.info(f"Best plan:\n{plan_thoughts}")
269
- log_progress(
270
- {
271
- "type": "log",
272
- "log_content": "Picked best plan",
273
- "status": "completed",
274
- "payload": plans[plan_thoughts["best_plan"]],
275
- }
276
- )
277
- return plan_thoughts, "```python\n" + code + "\n```\n" + tool_output_str
278
49
 
279
50
 
280
51
  def write_code(
@@ -393,7 +164,7 @@ def write_and_test_code(
393
164
  }
394
165
  )
395
166
  if verbosity == 2:
396
- _print_code("Initial code and tests:", code, test)
167
+ print_code("Initial code and tests:", code, test)
397
168
  _LOGGER.info(
398
169
  f"Initial code execution result:\n{result.text(include_logs=True)}"
399
170
  )
@@ -418,7 +189,7 @@ def write_and_test_code(
418
189
  count += 1
419
190
 
420
191
  if verbosity >= 1:
421
- _print_code("Final code and tests:", code, test)
192
+ print_code("Final code and tests:", code, test)
422
193
 
423
194
  return {
424
195
  "code": code,
@@ -537,7 +308,7 @@ def debug_code(
537
308
  }
538
309
  )
539
310
  if verbosity == 2:
540
- _print_code("Code and test after attempted fix:", code, test)
311
+ print_code("Code and test after attempted fix:", code, test)
541
312
  _LOGGER.info(
542
313
  f"Reflection: {fixed_code_and_test['reflections']}\nCode execution result after attempted fix: {result.text(include_logs=True)}"
543
314
  )
@@ -545,62 +316,6 @@ def debug_code(
545
316
  return code, test, result
546
317
 
547
318
 
548
- def _print_code(title: str, code: str, test: Optional[str] = None) -> None:
549
- _CONSOLE.print(title, style=Style(bgcolor="dark_orange3", bold=True))
550
- _CONSOLE.print("=" * 30 + " Code " + "=" * 30)
551
- _CONSOLE.print(
552
- Syntax(
553
- DefaultImports.prepend_imports(code),
554
- "python",
555
- theme="gruvbox-dark",
556
- line_numbers=True,
557
- )
558
- )
559
- if test:
560
- _CONSOLE.print("=" * 30 + " Test " + "=" * 30)
561
- _CONSOLE.print(Syntax(test, "python", theme="gruvbox-dark", line_numbers=True))
562
-
563
-
564
- def retrieve_tools(
565
- plans: Dict[str, Dict[str, Any]],
566
- tool_recommender: Sim,
567
- log_progress: Callable[[Dict[str, Any]], None],
568
- verbosity: int = 0,
569
- ) -> Dict[str, str]:
570
- log_progress(
571
- {
572
- "type": "log",
573
- "log_content": ("Retrieving tools for each plan"),
574
- "status": "started",
575
- }
576
- )
577
- tool_info = []
578
- tool_desc = []
579
- tool_lists: Dict[str, List[Dict[str, str]]] = {}
580
- for k, plan in plans.items():
581
- tool_lists[k] = []
582
- for task in plan["instructions"]:
583
- tools = tool_recommender.top_k(task, k=2, thresh=0.3)
584
- tool_info.extend([e["doc"] for e in tools])
585
- tool_desc.extend([e["desc"] for e in tools])
586
- tool_lists[k].extend(
587
- {"description": e["desc"], "documentation": e["doc"]} for e in tools
588
- )
589
-
590
- if verbosity == 2:
591
- tool_desc_str = "\n".join(set(tool_desc))
592
- _LOGGER.info(f"Tools Description:\n{tool_desc_str}")
593
-
594
- tool_lists_unique = {}
595
- for k in tool_lists:
596
- tool_lists_unique[k] = "\n\n".join(
597
- set(e["documentation"] for e in tool_lists[k])
598
- )
599
- all_tools = "\n\n".join(set(tool_info))
600
- tool_lists_unique["all"] = all_tools
601
- return tool_lists_unique
602
-
603
-
604
319
  class VisionAgentCoder(Agent):
605
320
  """Vision Agent Coder is an agentic framework that can output code based on a user
606
321
  request. It can plan tasks, retrieve relevant tools, write code, write tests and
@@ -616,23 +331,22 @@ class VisionAgentCoder(Agent):
616
331
 
617
332
  def __init__(
618
333
  self,
619
- planner: Optional[LMM] = None,
334
+ planner: Optional[Agent] = None,
620
335
  coder: Optional[LMM] = None,
621
336
  tester: Optional[LMM] = None,
622
337
  debugger: Optional[LMM] = None,
623
- tool_recommender: Optional[Sim] = None,
624
338
  verbosity: int = 0,
625
339
  report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
626
- code_sandbox_runtime: Optional[str] = None,
340
+ code_interpreter: Optional[Union[str, CodeInterpreter]] = None,
627
341
  ) -> None:
628
342
  """Initialize the Vision Agent Coder.
629
343
 
630
344
  Parameters:
631
- planner (Optional[LMM]): The planner model to use. Defaults to AnthropicLMM.
345
+ planner (Optional[Agent]): The planner model to use. Defaults to
346
+ AnthropicVisionAgentPlanner.
632
347
  coder (Optional[LMM]): The coder model to use. Defaults to AnthropicLMM.
633
348
  tester (Optional[LMM]): The tester model to use. Defaults to AnthropicLMM.
634
349
  debugger (Optional[LMM]): The debugger model to use. Defaults to AnthropicLMM.
635
- tool_recommender (Optional[Sim]): The tool recommender model to use.
636
350
  verbosity (int): The verbosity level of the agent. Defaults to 0. 2 is the
637
351
  highest verbosity level which will output all intermediate debugging
638
352
  code.
@@ -641,14 +355,17 @@ class VisionAgentCoder(Agent):
641
355
  in a web application where multiple VisionAgentCoder instances are
642
356
  running in parallel. This callback ensures that progress updates are not
643
357
  mixed up.
644
- code_sandbox_runtime (Optional[str]): the code sandbox runtime to use. A
645
- code sandbox is used to run the generated code. It can be one of the
646
- following values: None, "local" or "e2b". If None, VisionAgentCoder
647
- will read the value from the environment variable CODE_SANDBOX_RUNTIME.
648
- If it's also None, the local python runtime environment will be used.
358
+ code_interpreter (Optional[Union[str, CodeInterpreter]]): For string values
359
+ it can be one of: None, "local" or "e2b". If None, it will read from
360
+ the environment variable "CODE_SANDBOX_RUNTIME". If a CodeInterpreter
361
+ object is provided it will use that.
649
362
  """
650
363
 
651
- self.planner = AnthropicLMM(temperature=0.0) if planner is None else planner
364
+ self.planner = (
365
+ AnthropicVisionAgentPlanner(verbosity=verbosity)
366
+ if planner is None
367
+ else planner
368
+ )
652
369
  self.coder = AnthropicLMM(temperature=0.0) if coder is None else coder
653
370
  self.tester = AnthropicLMM(temperature=0.0) if tester is None else tester
654
371
  self.debugger = AnthropicLMM(temperature=0.0) if debugger is None else debugger
@@ -656,21 +373,15 @@ class VisionAgentCoder(Agent):
656
373
  if self.verbosity > 0:
657
374
  _LOGGER.setLevel(logging.INFO)
658
375
 
659
- self.tool_recommender = (
660
- Sim(T.TOOLS_DF, sim_key="desc")
661
- if tool_recommender is None
662
- else tool_recommender
663
- )
664
376
  self.report_progress_callback = report_progress_callback
665
- self.code_sandbox_runtime = code_sandbox_runtime
377
+ self.code_interpreter = code_interpreter
666
378
 
667
379
  def __call__(
668
380
  self,
669
381
  input: Union[str, List[Message]],
670
382
  media: Optional[Union[str, Path]] = None,
671
383
  ) -> str:
672
- """Chat with VisionAgentCoder and return intermediate information regarding the
673
- task.
384
+ """Generate code based on a user request.
674
385
 
675
386
  Parameters:
676
387
  input (Union[str, List[Message]]): A conversation in the format of
@@ -686,46 +397,58 @@ class VisionAgentCoder(Agent):
686
397
  input = [{"role": "user", "content": input}]
687
398
  if media is not None:
688
399
  input[0]["media"] = [media]
689
- results = self.chat_with_workflow(input)
690
- results.pop("working_memory")
691
- return results["code"] # type: ignore
400
+ code_and_context = self.generate_code(input)
401
+ return code_and_context["code"] # type: ignore
692
402
 
693
- def chat_with_workflow(
403
+ def generate_code_from_plan(
694
404
  self,
695
405
  chat: List[Message],
696
- test_multi_plan: bool = True,
697
- display_visualization: bool = False,
698
- custom_tool_names: Optional[List[str]] = None,
406
+ plan_context: PlanContext,
407
+ code_interpreter: Optional[CodeInterpreter] = None,
699
408
  ) -> Dict[str, Any]:
700
- """Chat with VisionAgentCoder and return intermediate information regarding the
701
- task.
409
+ """Generates code and other intermediate outputs from a chat input and a plan.
410
+ The plan includes:
411
+ - plans: The plans generated by the planner.
412
+ - best_plan: The best plan selected by the planner.
413
+ - plan_thoughts: The thoughts of the planner, including any modifications
414
+ to the plan.
415
+ - tool_doc: The tool documentation for the best plan.
416
+ - tool_output: The tool output from the tools used by the best plan.
702
417
 
703
418
  Parameters:
704
- chat (List[Message]): A conversation
705
- in the format of:
706
- [{"role": "user", "content": "describe your task here..."}]
707
- or if it contains media files, it should be in the format of:
708
- [{"role": "user", "content": "describe your task here...", "media": ["image1.jpg", "image2.jpg"]}]
709
- test_multi_plan (bool): If True, it will test tools for multiple plans and
710
- pick the best one based off of the tool results. If False, it will go
711
- with the first plan.
712
- display_visualization (bool): If True, it opens a new window locally to
713
- show the image(s) created by visualization code (if there is any).
714
- custom_tool_names (List[str]): A list of custom tools for the agent to pick
715
- and use. If not provided, default to full tool set from vision_agent.tools.
419
+ chat (List[Message]): A conversation in the format of
420
+ [{"role": "user", "content": "describe your task here..."}].
421
+ plan_context (PlanContext): The context of the plan, including the plans,
422
+ best_plan, plan_thoughts, tool_doc, and tool_output.
423
+ code_interpreter (Optional[CodeInterpreter]): A code interpreter instance
424
+ supplied by the caller (generate_code passes its own). NOTE(review): the
425
+ body rebinds this name from self.code_interpreter, so it looks unused.
716
426
 
717
427
  Returns:
718
- Dict[str, Any]: A dictionary containing the code, test, test result, plan,
719
- and working memory of the agent.
428
+ Dict[str, Any]: A dictionary containing the code output by the
429
+ VisionAgentCoder and other intermediate outputs, including:
430
+ - status (str): Whether or not the agent completed or failed generating
431
+ the code.
432
+ - code (str): The code output by the VisionAgentCoder.
433
+ - test (str): The test output by the VisionAgentCoder.
434
+ - test_result (Execution): The result of the test execution.
435
+ - plans (Dict[str, Any]): The plans generated by the planner.
436
+ - plan_thoughts (str): The thoughts of the planner.
437
+ - working_memory (List[Dict[str, str]]): The working memory of the agent.
720
438
  """
721
-
722
439
  if not chat:
723
440
  raise ValueError("Chat cannot be empty.")
724
441
 
725
442
  # NOTE: each chat should have a dedicated code interpreter instance to avoid concurrency issues
726
- with CodeInterpreterFactory.new_instance(
727
- code_sandbox_runtime=self.code_sandbox_runtime
728
- ) as code_interpreter:
443
+ code_interpreter = (
444
+ self.code_interpreter
445
+ if self.code_interpreter is not None
446
+ and not isinstance(self.code_interpreter, str)
447
+ else CodeInterpreterFactory.new_instance(
448
+ code_sandbox_runtime=self.code_interpreter,
449
+ )
450
+ )
451
+ with code_interpreter:
729
452
  chat = copy.deepcopy(chat)
730
453
  media_list = []
731
454
  for chat_i in chat:
@@ -759,74 +482,22 @@ class VisionAgentCoder(Agent):
759
482
  code = ""
760
483
  test = ""
761
484
  working_memory: List[Dict[str, str]] = []
762
- results = {"code": "", "test": "", "plan": []}
763
- plan = []
764
- success = False
765
-
766
- plans = self._create_plans(
767
- int_chat, custom_tool_names, working_memory, self.planner
768
- )
769
-
770
- if test_multi_plan:
771
- self._log_plans(plans, self.verbosity)
772
-
773
- tool_infos = retrieve_tools(
774
- plans,
775
- self.tool_recommender,
776
- self.log_progress,
777
- self.verbosity,
778
- )
779
-
780
- if test_multi_plan:
781
- plan_thoughts, tool_output_str = pick_plan(
782
- int_chat,
783
- plans,
784
- tool_infos["all"],
785
- self.coder,
786
- code_interpreter,
787
- media_list,
788
- self.log_progress,
789
- verbosity=self.verbosity,
790
- )
791
- best_plan = plan_thoughts["best_plan"]
792
- plan_thoughts_str = plan_thoughts["thoughts"]
793
- else:
794
- best_plan = list(plans.keys())[0]
795
- tool_output_str = ""
796
- plan_thoughts_str = ""
797
-
798
- if best_plan in plans and best_plan in tool_infos:
799
- plan_i = plans[best_plan]
800
- tool_info = tool_infos[best_plan]
801
- else:
802
- if self.verbosity >= 1:
803
- _LOGGER.warning(
804
- f"Best plan {best_plan} not found in plans or tool_infos. Using the first plan and tool info."
805
- )
806
- k = list(plans.keys())[0]
807
- plan_i = plans[k]
808
- tool_info = tool_infos[k]
809
-
810
- self.log_progress(
811
- {
812
- "type": "log",
813
- "log_content": "Creating plans",
814
- "status": "completed",
815
- "payload": tool_info,
816
- }
817
- )
485
+ plan = plan_context.plans[plan_context.best_plan]
486
+ tool_doc = plan_context.tool_doc
487
+ tool_output_str = plan_context.tool_output
488
+ plan_thoughts_str = str(plan_context.plan_thoughts)
818
489
 
819
490
  if self.verbosity >= 1:
820
- plan_i_fixed = [{"instructions": e} for e in plan_i["instructions"]]
491
+ plan_fixed = [{"instructions": e} for e in plan["instructions"]]
821
492
  _LOGGER.info(
822
- f"Picked best plan:\n{tabulate(tabular_data=plan_i_fixed, headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"
493
+ f"Picked best plan:\n{tabulate(tabular_data=plan_fixed, headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"
823
494
  )
824
495
 
825
496
  results = write_and_test_code(
826
497
  chat=[{"role": c["role"], "content": c["content"]} for c in int_chat],
827
- plan=f"\n{plan_i['thoughts']}\n-"
828
- + "\n-".join([e for e in plan_i["instructions"]]),
829
- tool_info=tool_info,
498
+ plan=f"\n{plan['thoughts']}\n-"
499
+ + "\n-".join([e for e in plan["instructions"]]),
500
+ tool_info=tool_doc,
830
501
  tool_output=tool_output_str,
831
502
  plan_thoughts=plan_thoughts_str,
832
503
  tool_utils=T.UTILITIES_DOCSTRING,
@@ -842,64 +513,83 @@ class VisionAgentCoder(Agent):
842
513
  success = cast(bool, results["success"])
843
514
  code = remove_installs_from_code(cast(str, results["code"]))
844
515
  test = remove_installs_from_code(cast(str, results["test"]))
845
- working_memory.extend(results["working_memory"]) # type: ignore
846
- plan.append({"code": code, "test": test, "plan": plan_i})
516
+ working_memory.extend(results["working_memory"])
847
517
 
848
518
  execution_result = cast(Execution, results["test_result"])
849
519
 
850
- if display_visualization:
851
- for res in execution_result.results:
852
- if res.png:
853
- b64_to_pil(res.png).show()
854
- if res.mp4:
855
- play_video(res.mp4)
856
-
857
520
  return {
858
521
  "status": "completed" if success else "failed",
859
522
  "code": DefaultImports.prepend_imports(code),
860
523
  "test": test,
861
524
  "test_result": execution_result,
862
- "plans": plans,
525
+ "plans": plan_context.plans,
863
526
  "plan_thoughts": plan_thoughts_str,
864
527
  "working_memory": working_memory,
865
528
  }
866
529
 
867
- def log_progress(self, data: Dict[str, Any]) -> None:
868
- if self.report_progress_callback is not None:
869
- self.report_progress_callback(data)
870
-
871
- def _create_plans(
530
+ def generate_code(
872
531
  self,
873
- int_chat: List[Message],
874
- customized_tool_names: Optional[List[str]],
875
- working_memory: List[Dict[str, str]],
876
- planner: LMM,
532
+ chat: List[Message],
533
+ test_multi_plan: bool = True,
534
+ custom_tool_names: Optional[List[str]] = None,
877
535
  ) -> Dict[str, Any]:
878
- self.log_progress(
879
- {
880
- "type": "log",
881
- "log_content": "Creating plans",
882
- "status": "started",
883
- }
884
- )
885
- plans = write_plans(
886
- int_chat,
887
- T.get_tool_descriptions_by_names(
888
- customized_tool_names, T.FUNCTION_TOOLS, T.UTIL_TOOLS # type: ignore
889
- ),
890
- format_memory(working_memory),
891
- planner,
536
+ """Generates code and other intermediate outputs from a chat input.
537
+
538
+ Parameters:
539
+ chat (List[Message]): A conversation in the format of
540
+ [{"role": "user", "content": "describe your task here..."}].
541
+ test_multi_plan (bool): Whether to test multiple plans or just the best plan.
542
+ custom_tool_names (Optional[List[str]]): A list of custom tool names to use
543
+ for the planner.
544
+
545
+ Returns:
546
+ Dict[str, Any]: A dictionary containing the code output by the
547
+ VisionAgentCoder and other intermediate outputs, including:
548
+ - status (str): Whether or not the agent completed or failed generating
549
+ the code.
550
+ - code (str): The code output by the VisionAgentCoder.
551
+ - test (str): The test output by the VisionAgentCoder.
552
+ - test_result (Execution): The result of the test execution.
553
+ - plans (Dict[str, Any]): The plans generated by the planner.
554
+ - plan_thoughts (str): The thoughts of the planner.
555
+ - working_memory (List[Dict[str, str]]): The working memory of the agent.
556
+ """
557
+ if not chat:
558
+ raise ValueError("Chat cannot be empty.")
559
+
560
+ # NOTE: each chat should have a dedicated code interpreter instance to avoid concurrency issues
561
+ code_interpreter = (
562
+ self.code_interpreter
563
+ if self.code_interpreter is not None
564
+ and not isinstance(self.code_interpreter, str)
565
+ else CodeInterpreterFactory.new_instance(
566
+ code_sandbox_runtime=self.code_interpreter,
567
+ )
892
568
  )
893
- return plans
569
+ with code_interpreter:
570
+ plan_context = self.planner.generate_plan( # type: ignore
571
+ chat,
572
+ test_multi_plan=test_multi_plan,
573
+ custom_tool_names=custom_tool_names,
574
+ code_interpreter=code_interpreter,
575
+ )
894
576
 
895
- def _log_plans(self, plans: Dict[str, Any], verbosity: int) -> None:
896
- if verbosity >= 1:
897
- for p in plans:
898
- # tabulate will fail if the keys are not the same for all elements
899
- p_fixed = [{"instructions": e} for e in plans[p]["instructions"]]
900
- _LOGGER.info(
901
- f"\n{tabulate(tabular_data=p_fixed, headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"
902
- )
577
+ code_and_context = self.generate_code_from_plan(
578
+ chat,
579
+ plan_context,
580
+ code_interpreter=code_interpreter,
581
+ )
582
+ return code_and_context
583
+
584
+ def chat(self, chat: List[Message]) -> List[Message]:
585
+ chat = copy.deepcopy(chat)
586
+ code = self.generate_code(chat)
587
+ chat.append({"role": "agent", "content": code["code"]})
588
+ return chat
589
+
590
+ def log_progress(self, data: Dict[str, Any]) -> None:
591
+ if self.report_progress_callback is not None:
592
+ self.report_progress_callback(data)
903
593
 
904
594
 
905
595
  class OpenAIVisionAgentCoder(VisionAgentCoder):
@@ -907,17 +597,18 @@ class OpenAIVisionAgentCoder(VisionAgentCoder):
907
597
 
908
598
  def __init__(
909
599
  self,
910
- planner: Optional[LMM] = None,
600
+ planner: Optional[Agent] = None,
911
601
  coder: Optional[LMM] = None,
912
602
  tester: Optional[LMM] = None,
913
603
  debugger: Optional[LMM] = None,
914
- tool_recommender: Optional[Sim] = None,
915
604
  verbosity: int = 0,
916
605
  report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
917
- code_sandbox_runtime: Optional[str] = None,
606
+ code_interpreter: Optional[Union[str, CodeInterpreter]] = None,
918
607
  ) -> None:
919
608
  self.planner = (
920
- OpenAILMM(temperature=0.0, json_mode=True) if planner is None else planner
609
+ OpenAIVisionAgentPlanner(verbosity=verbosity)
610
+ if planner is None
611
+ else planner
921
612
  )
922
613
  self.coder = OpenAILMM(temperature=0.0) if coder is None else coder
923
614
  self.tester = OpenAILMM(temperature=0.0) if tester is None else tester
@@ -926,13 +617,8 @@ class OpenAIVisionAgentCoder(VisionAgentCoder):
926
617
  if self.verbosity > 0:
927
618
  _LOGGER.setLevel(logging.INFO)
928
619
 
929
- self.tool_recommender = (
930
- Sim(T.TOOLS_DF, sim_key="desc")
931
- if tool_recommender is None
932
- else tool_recommender
933
- )
934
620
  self.report_progress_callback = report_progress_callback
935
- self.code_sandbox_runtime = code_sandbox_runtime
621
+ self.code_interpreter = code_interpreter
936
622
 
937
623
 
938
624
  class AnthropicVisionAgentCoder(VisionAgentCoder):
@@ -940,17 +626,20 @@ class AnthropicVisionAgentCoder(VisionAgentCoder):
940
626
 
941
627
  def __init__(
942
628
  self,
943
- planner: Optional[LMM] = None,
629
+ planner: Optional[Agent] = None,
944
630
  coder: Optional[LMM] = None,
945
631
  tester: Optional[LMM] = None,
946
632
  debugger: Optional[LMM] = None,
947
- tool_recommender: Optional[Sim] = None,
948
633
  verbosity: int = 0,
949
634
  report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
950
- code_sandbox_runtime: Optional[str] = None,
635
+ code_interpreter: Optional[Union[str, CodeInterpreter]] = None,
951
636
  ) -> None:
952
637
  # NOTE: Claude doesn't have an official JSON mode
953
- self.planner = AnthropicLMM(temperature=0.0) if planner is None else planner
638
+ self.planner = (
639
+ AnthropicVisionAgentPlanner(verbosity=verbosity)
640
+ if planner is None
641
+ else planner
642
+ )
954
643
  self.coder = AnthropicLMM(temperature=0.0) if coder is None else coder
955
644
  self.tester = AnthropicLMM(temperature=0.0) if tester is None else tester
956
645
  self.debugger = AnthropicLMM(temperature=0.0) if debugger is None else debugger
@@ -958,15 +647,8 @@ class AnthropicVisionAgentCoder(VisionAgentCoder):
958
647
  if self.verbosity > 0:
959
648
  _LOGGER.setLevel(logging.INFO)
960
649
 
961
- # Anthropic does not offer any embedding models and instead recomends Voyage,
962
- # we're using OpenAI's embedder for now.
963
- self.tool_recommender = (
964
- Sim(T.TOOLS_DF, sim_key="desc")
965
- if tool_recommender is None
966
- else tool_recommender
967
- )
968
650
  self.report_progress_callback = report_progress_callback
969
- self.code_sandbox_runtime = code_sandbox_runtime
651
+ self.code_interpreter = code_interpreter
970
652
 
971
653
 
972
654
  class OllamaVisionAgentCoder(VisionAgentCoder):
@@ -988,17 +670,17 @@ class OllamaVisionAgentCoder(VisionAgentCoder):
988
670
 
989
671
  def __init__(
990
672
  self,
991
- planner: Optional[LMM] = None,
673
+ planner: Optional[Agent] = None,
992
674
  coder: Optional[LMM] = None,
993
675
  tester: Optional[LMM] = None,
994
676
  debugger: Optional[LMM] = None,
995
- tool_recommender: Optional[Sim] = None,
996
677
  verbosity: int = 0,
997
678
  report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
679
+ code_interpreter: Optional[Union[str, CodeInterpreter]] = None,
998
680
  ) -> None:
999
681
  super().__init__(
1000
682
  planner=(
1001
- OllamaLMM(model_name="llama3.1", temperature=0.0, json_mode=True)
683
+ OllamaVisionAgentPlanner(verbosity=verbosity)
1002
684
  if planner is None
1003
685
  else planner
1004
686
  ),
@@ -1017,13 +699,9 @@ class OllamaVisionAgentCoder(VisionAgentCoder):
1017
699
  if debugger is None
1018
700
  else debugger
1019
701
  ),
1020
- tool_recommender=(
1021
- OllamaSim(T.TOOLS_DF, sim_key="desc")
1022
- if tool_recommender is None
1023
- else tool_recommender
1024
- ),
1025
702
  verbosity=verbosity,
1026
703
  report_progress_callback=report_progress_callback,
704
+ code_interpreter=code_interpreter,
1027
705
  )
1028
706
 
1029
707
 
@@ -1043,22 +721,22 @@ class AzureVisionAgentCoder(VisionAgentCoder):
1043
721
 
1044
722
  def __init__(
1045
723
  self,
1046
- planner: Optional[LMM] = None,
724
+ planner: Optional[Agent] = None,
1047
725
  coder: Optional[LMM] = None,
1048
726
  tester: Optional[LMM] = None,
1049
727
  debugger: Optional[LMM] = None,
1050
- tool_recommender: Optional[Sim] = None,
1051
728
  verbosity: int = 0,
1052
729
  report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
730
+ code_interpreter: Optional[Union[str, CodeInterpreter]] = None,
1053
731
  ) -> None:
1054
732
  """Initialize the Vision Agent Coder.
1055
733
 
1056
734
  Parameters:
1057
- planner (Optional[LMM]): The planner model to use. Defaults to OpenAILMM.
735
+ planner (Optional[Agent]): The planner model to use. Defaults to
736
+ AzureVisionAgentPlanner.
1058
737
  coder (Optional[LMM]): The coder model to use. Defaults to OpenAILMM.
1059
738
  tester (Optional[LMM]): The tester model to use. Defaults to OpenAILMM.
1060
739
  debugger (Optional[LMM]): The debugger model to use. Defaults to AzureOpenAILMM.
1061
- tool_recommender (Optional[Sim]): The tool recommender model to use.
1062
740
  verbosity (int): The verbosity level of the agent. Defaults to 0. 2 is the
1063
741
  highest verbosity level which will output all intermediate debugging
1064
742
  code.
@@ -1069,7 +747,7 @@ class AzureVisionAgentCoder(VisionAgentCoder):
1069
747
  """
1070
748
  super().__init__(
1071
749
  planner=(
1072
- AzureOpenAILMM(temperature=0.0, json_mode=True)
750
+ AzureVisionAgentPlanner(verbosity=verbosity)
1073
751
  if planner is None
1074
752
  else planner
1075
753
  ),
@@ -1078,11 +756,7 @@ class AzureVisionAgentCoder(VisionAgentCoder):
1078
756
  debugger=(
1079
757
  AzureOpenAILMM(temperature=0.0) if debugger is None else debugger
1080
758
  ),
1081
- tool_recommender=(
1082
- AzureSim(T.TOOLS_DF, sim_key="desc")
1083
- if tool_recommender is None
1084
- else tool_recommender
1085
- ),
1086
759
  verbosity=verbosity,
1087
760
  report_progress_callback=report_progress_callback,
761
+ code_interpreter=code_interpreter,
1088
762
  )