vision-agent 0.2.161__py3-none-any.whl → 0.2.163__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -2,32 +2,35 @@ import copy
2
2
  import logging
3
3
  import os
4
4
  import sys
5
- from json import JSONDecodeError
6
5
  from pathlib import Path
7
- from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, cast
6
+ from typing import Any, Callable, Dict, List, Optional, Sequence, Union, cast
8
7
 
9
- from rich.console import Console
10
- from rich.style import Style
11
- from rich.syntax import Syntax
8
+ from redbaron import RedBaron # type: ignore
12
9
  from tabulate import tabulate
13
10
 
14
11
  import vision_agent.tools as T
15
- from vision_agent.agent import Agent
12
+ from vision_agent.agent.agent import Agent
16
13
  from vision_agent.agent.agent_utils import (
14
+ _MAX_TABULATE_COL_WIDTH,
15
+ DefaultImports,
17
16
  extract_code,
18
- extract_json,
17
+ extract_tag,
18
+ format_memory,
19
+ print_code,
19
20
  remove_installs_from_code,
20
21
  )
21
22
  from vision_agent.agent.vision_agent_coder_prompts import (
22
23
  CODE,
23
24
  FIX_BUG,
24
25
  FULL_TASK,
25
- PICK_PLAN,
26
- PLAN,
27
- PREVIOUS_FAILED,
28
26
  SIMPLE_TEST,
29
- TEST_PLANS,
30
- USER_REQ,
27
+ )
28
+ from vision_agent.agent.vision_agent_planner import (
29
+ AnthropicVisionAgentPlanner,
30
+ AzureVisionAgentPlanner,
31
+ OllamaVisionAgentPlanner,
32
+ OpenAIVisionAgentPlanner,
33
+ PlanContext,
31
34
  )
32
35
  from vision_agent.lmm import (
33
36
  LMM,
@@ -40,241 +43,48 @@ from vision_agent.lmm import (
40
43
  from vision_agent.tools.meta_tools import get_diff
41
44
  from vision_agent.utils import CodeInterpreterFactory, Execution
42
45
  from vision_agent.utils.execute import CodeInterpreter
43
- from vision_agent.utils.image_utils import b64_to_pil
44
- from vision_agent.utils.sim import AzureSim, OllamaSim, Sim
45
- from vision_agent.utils.video import play_video
46
46
 
47
47
  logging.basicConfig(stream=sys.stdout)
48
48
  WORKSPACE = Path(os.getenv("WORKSPACE", ""))
49
49
  _LOGGER = logging.getLogger(__name__)
50
- _MAX_TABULATE_COL_WIDTH = 80
51
- _CONSOLE = Console()
52
-
53
-
54
- class DefaultImports:
55
- """Container for default imports used in the code execution."""
56
-
57
- common_imports = [
58
- "import os",
59
- "import numpy as np",
60
- "from vision_agent.tools import *",
61
- "from typing import *",
62
- "from pillow_heif import register_heif_opener",
63
- "register_heif_opener()",
64
- ]
65
-
66
- @staticmethod
67
- def to_code_string() -> str:
68
- return "\n".join(DefaultImports.common_imports + T.__new_tools__)
69
-
70
- @staticmethod
71
- def prepend_imports(code: str) -> str:
72
- """Run this method to prepend the default imports to the code.
73
- NOTE: be sure to run this method after the custom tools have been registered.
74
- """
75
- return DefaultImports.to_code_string() + "\n\n" + code
76
-
77
-
78
- def format_memory(memory: List[Dict[str, str]]) -> str:
79
- output_str = ""
80
- for i, m in enumerate(memory):
81
- output_str += f"### Feedback {i}:\n"
82
- output_str += f"Code {i}:\n```python\n{m['code']}```\n\n"
83
- output_str += f"Feedback {i}: {m['feedback']}\n\n"
84
- if "edits" in m:
85
- output_str += f"Edits {i}:\n{m['edits']}\n"
86
- output_str += "\n"
87
-
88
- return output_str
89
-
90
-
91
- def format_plans(plans: Dict[str, Any]) -> str:
92
- plan_str = ""
93
- for k, v in plans.items():
94
- plan_str += "\n" + f"{k}: {v['thoughts']}\n"
95
- plan_str += " -" + "\n -".join([e for e in v["instructions"]])
96
-
97
- return plan_str
98
-
99
-
100
- def write_plans(
101
- chat: List[Message],
102
- tool_desc: str,
103
- working_memory: str,
104
- model: LMM,
105
- ) -> Dict[str, Any]:
106
- chat = copy.deepcopy(chat)
107
- if chat[-1]["role"] != "user":
108
- raise ValueError("Last chat message must be from the user.")
109
-
110
- user_request = chat[-1]["content"]
111
- context = USER_REQ.format(user_request=user_request)
112
- prompt = PLAN.format(
113
- context=context,
114
- tool_desc=tool_desc,
115
- feedback=working_memory,
116
- )
117
- chat[-1]["content"] = prompt
118
- return extract_json(model(chat, stream=False)) # type: ignore
119
-
120
-
121
- def pick_plan(
122
- chat: List[Message],
123
- plans: Dict[str, Any],
124
- tool_info: str,
125
- model: LMM,
126
- code_interpreter: CodeInterpreter,
127
- media: List[str],
128
- log_progress: Callable[[Dict[str, Any]], None],
129
- verbosity: int = 0,
130
- max_retries: int = 3,
131
- ) -> Tuple[Dict[str, str], str]:
132
- log_progress(
133
- {
134
- "type": "log",
135
- "log_content": "Generating code to pick the best plan",
136
- "status": "started",
137
- }
138
- )
139
50
 
140
- chat = copy.deepcopy(chat)
141
- if chat[-1]["role"] != "user":
142
- raise ValueError("Last chat message must be from the user.")
143
51
 
144
- plan_str = format_plans(plans)
145
- prompt = TEST_PLANS.format(
146
- docstring=tool_info, plans=plan_str, previous_attempts="", media=media
147
- )
148
-
149
- code = extract_code(model(prompt, stream=False)) # type: ignore
150
- log_progress(
151
- {
152
- "type": "log",
153
- "log_content": "Executing code to test plans",
154
- "code": DefaultImports.prepend_imports(code),
155
- "status": "running",
156
- }
157
- )
158
- tool_output = code_interpreter.exec_isolation(DefaultImports.prepend_imports(code))
159
- # Because of the way we trace function calls the trace information ends up in the
160
- # results. We don't want to show this info to the LLM so we don't include it in the
161
- # tool_output_str.
162
- tool_output_str = tool_output.text(include_results=False).strip()
163
-
164
- if verbosity == 2:
165
- _print_code("Initial code and tests:", code)
166
- _LOGGER.info(f"Initial code execution result:\n{tool_output_str}")
167
-
168
- log_progress(
169
- {
170
- "type": "log",
171
- "log_content": (
172
- "Code execution succeeded"
173
- if tool_output.success
174
- else "Code execution failed"
175
- ),
176
- "code": DefaultImports.prepend_imports(code),
177
- # "payload": tool_output.to_json(),
178
- "status": "completed" if tool_output.success else "failed",
179
- }
180
- )
181
-
182
- # retry if the tool output is empty or code fails
183
- count = 0
184
- while (
185
- not tool_output.success
186
- or (len(tool_output.logs.stdout) == 0 and len(tool_output.logs.stderr) == 0)
187
- ) and count < max_retries:
188
- prompt = TEST_PLANS.format(
189
- docstring=tool_info,
190
- plans=plan_str,
191
- previous_attempts=PREVIOUS_FAILED.format(
192
- code=code, error="\n".join(tool_output_str.splitlines()[-50:])
193
- ),
194
- media=media,
195
- )
196
- log_progress(
197
- {
198
- "type": "log",
199
- "log_content": "Retrying code to test plans",
200
- "status": "running",
201
- "code": DefaultImports.prepend_imports(code),
202
- }
203
- )
204
- code = extract_code(model(prompt, stream=False)) # type: ignore
205
- tool_output = code_interpreter.exec_isolation(
206
- DefaultImports.prepend_imports(code)
207
- )
208
- log_progress(
209
- {
210
- "type": "log",
211
- "log_content": (
212
- "Code execution succeeded"
213
- if tool_output.success
214
- else "Code execution failed"
215
- ),
216
- "code": DefaultImports.prepend_imports(code),
217
- # "payload": tool_output.to_json(),
218
- "status": "completed" if tool_output.success else "failed",
219
- }
220
- )
221
- tool_output_str = tool_output.text(include_results=False).strip()
222
-
223
- if verbosity == 2:
224
- _print_code("Code and test after attempted fix:", code)
225
- _LOGGER.info(f"Code execution result after attempt {count + 1}")
226
- _LOGGER.info(f"{tool_output_str}")
227
-
228
- count += 1
229
-
230
- if verbosity >= 1:
231
- _print_code("Final code:", code)
232
-
233
- user_req = chat[-1]["content"]
234
- context = USER_REQ.format(user_request=user_req)
235
- # because the tool picker model gets the image as well, we have to be careful with
236
- # how much text we send it, so we truncate the tool output to 20,000 characters
237
- prompt = PICK_PLAN.format(
238
- context=context,
239
- plans=format_plans(plans),
240
- tool_output=tool_output_str[:20_000],
241
- )
242
- chat[-1]["content"] = prompt
243
-
244
- count = 0
245
- plan_thoughts = None
246
- while plan_thoughts is None and count < max_retries:
247
- try:
248
- plan_thoughts = extract_json(model(chat, stream=False)) # type: ignore
249
- except JSONDecodeError as e:
250
- _LOGGER.exception(
251
- f"Error while extracting JSON during picking best plan {str(e)}"
252
- )
253
- pass
254
- count += 1
255
-
256
- if (
257
- plan_thoughts is None
258
- or "best_plan" not in plan_thoughts
259
- or ("best_plan" in plan_thoughts and plan_thoughts["best_plan"] not in plans)
260
- ):
261
- _LOGGER.info(f"Failed to pick best plan. Using the first plan. {plan_thoughts}")
262
- plan_thoughts = {"best_plan": list(plans.keys())[0]}
263
-
264
- if "thoughts" not in plan_thoughts:
265
- plan_thoughts["thoughts"] = ""
266
-
267
- if verbosity >= 1:
268
- _LOGGER.info(f"Best plan:\n{plan_thoughts}")
269
- log_progress(
270
- {
271
- "type": "log",
272
- "log_content": "Picked best plan",
273
- "status": "completed",
274
- "payload": plans[plan_thoughts["best_plan"]],
275
- }
276
- )
277
- return plan_thoughts, "```python\n" + code + "\n```\n" + tool_output_str
52
+ def strip_function_calls(code: str, exclusions: Optional[List[str]] = None) -> str:
53
+ """This will strip out all code that calls functions except for functions included
54
+ in exclusions.
55
+ """
56
+ if exclusions is None:
57
+ exclusions = []
58
+
59
+ red = RedBaron(code)
60
+ nodes_to_remove = []
61
+ for node in red:
62
+ if node.type == "def":
63
+ continue
64
+ elif node.type == "import" or node.type == "from_import":
65
+ continue
66
+ elif node.type == "call":
67
+ if node.value and node.value[0].value in exclusions:
68
+ continue
69
+ nodes_to_remove.append(node)
70
+ elif node.type == "atomtrailers":
71
+ if node[0].value in exclusions:
72
+ continue
73
+ nodes_to_remove.append(node)
74
+ elif node.type == "assignment":
75
+ if node.value.type == "call" or node.value.type == "atomtrailers":
76
+ func_name = node.value[0].value
77
+ if func_name in exclusions:
78
+ continue
79
+ nodes_to_remove.append(node)
80
+ elif node.type == "endl":
81
+ continue
82
+ else:
83
+ nodes_to_remove.append(node)
84
+ for node in nodes_to_remove:
85
+ node.parent.remove(node)
86
+ cleaned_code = red.dumps().strip()
87
+ return cleaned_code if isinstance(cleaned_code, str) else code
278
88
 
279
89
 
280
90
  def write_code(
@@ -359,6 +169,7 @@ def write_and_test_code(
359
169
  plan_thoughts,
360
170
  format_memory(working_memory),
361
171
  )
172
+ code = strip_function_calls(code)
362
173
  test = write_test(
363
174
  tester, chat, tool_utils, code, format_memory(working_memory), media
364
175
  )
@@ -393,7 +204,7 @@ def write_and_test_code(
393
204
  }
394
205
  )
395
206
  if verbosity == 2:
396
- _print_code("Initial code and tests:", code, test)
207
+ print_code("Initial code and tests:", code, test)
397
208
  _LOGGER.info(
398
209
  f"Initial code execution result:\n{result.text(include_logs=True)}"
399
210
  )
@@ -418,7 +229,7 @@ def write_and_test_code(
418
229
  count += 1
419
230
 
420
231
  if verbosity >= 1:
421
- _print_code("Final code and tests:", code, test)
232
+ print_code("Final code and tests:", code, test)
422
233
 
423
234
  return {
424
235
  "code": code,
@@ -449,7 +260,9 @@ def debug_code(
449
260
  }
450
261
  )
451
262
 
452
- fixed_code_and_test = {"code": "", "test": "", "reflections": ""}
263
+ fixed_code = None
264
+ fixed_test = None
265
+ thoughts = ""
453
266
  success = False
454
267
  count = 0
455
268
  while not success and count < 3:
@@ -472,21 +285,16 @@ def debug_code(
472
285
  stream=False,
473
286
  )
474
287
  fixed_code_and_test_str = cast(str, fixed_code_and_test_str)
475
- fixed_code_and_test = extract_json(fixed_code_and_test_str)
476
- code = extract_code(fixed_code_and_test_str)
477
- if (
478
- "which_code" in fixed_code_and_test
479
- and fixed_code_and_test["which_code"] == "test"
480
- ):
481
- fixed_code_and_test["code"] = ""
482
- fixed_code_and_test["test"] = code
483
- else: # for everything else always assume it's updating code
484
- fixed_code_and_test["code"] = code
485
- fixed_code_and_test["test"] = ""
486
- if "which_code" in fixed_code_and_test:
487
- del fixed_code_and_test["which_code"]
488
-
489
- success = True
288
+ thoughts_tag = extract_tag(fixed_code_and_test_str, "thoughts")
289
+ thoughts = thoughts_tag if thoughts_tag is not None else ""
290
+ fixed_code = extract_tag(fixed_code_and_test_str, "code")
291
+ fixed_test = extract_tag(fixed_code_and_test_str, "test")
292
+
293
+ if fixed_code is None and fixed_test is None:
294
+ success = False
295
+ else:
296
+ success = True
297
+
490
298
  except Exception as e:
491
299
  _LOGGER.exception(f"Error while extracting JSON: {e}")
492
300
 
@@ -495,15 +303,15 @@ def debug_code(
495
303
  old_code = code
496
304
  old_test = test
497
305
 
498
- if fixed_code_and_test["code"].strip() != "":
499
- code = fixed_code_and_test["code"]
500
- if fixed_code_and_test["test"].strip() != "":
501
- test = fixed_code_and_test["test"]
306
+ if fixed_code is not None and fixed_code.strip() != "":
307
+ code = fixed_code
308
+ if fixed_test is not None and fixed_test.strip() != "":
309
+ test = fixed_test
502
310
 
503
311
  new_working_memory.append(
504
312
  {
505
313
  "code": f"{code}\n{test}",
506
- "feedback": fixed_code_and_test["reflections"],
314
+ "feedback": thoughts,
507
315
  "edits": get_diff(f"{old_code}\n{old_test}", f"{code}\n{test}"),
508
316
  }
509
317
  )
@@ -537,70 +345,14 @@ def debug_code(
537
345
  }
538
346
  )
539
347
  if verbosity == 2:
540
- _print_code("Code and test after attempted fix:", code, test)
348
+ print_code("Code and test after attempted fix:", code, test)
541
349
  _LOGGER.info(
542
- f"Reflection: {fixed_code_and_test['reflections']}\nCode execution result after attempted fix: {result.text(include_logs=True)}"
350
+ f"Reflection: {thoughts}\nCode execution result after attempted fix: {result.text(include_logs=True)}"
543
351
  )
544
352
 
545
353
  return code, test, result
546
354
 
547
355
 
548
- def _print_code(title: str, code: str, test: Optional[str] = None) -> None:
549
- _CONSOLE.print(title, style=Style(bgcolor="dark_orange3", bold=True))
550
- _CONSOLE.print("=" * 30 + " Code " + "=" * 30)
551
- _CONSOLE.print(
552
- Syntax(
553
- DefaultImports.prepend_imports(code),
554
- "python",
555
- theme="gruvbox-dark",
556
- line_numbers=True,
557
- )
558
- )
559
- if test:
560
- _CONSOLE.print("=" * 30 + " Test " + "=" * 30)
561
- _CONSOLE.print(Syntax(test, "python", theme="gruvbox-dark", line_numbers=True))
562
-
563
-
564
- def retrieve_tools(
565
- plans: Dict[str, Dict[str, Any]],
566
- tool_recommender: Sim,
567
- log_progress: Callable[[Dict[str, Any]], None],
568
- verbosity: int = 0,
569
- ) -> Dict[str, str]:
570
- log_progress(
571
- {
572
- "type": "log",
573
- "log_content": ("Retrieving tools for each plan"),
574
- "status": "started",
575
- }
576
- )
577
- tool_info = []
578
- tool_desc = []
579
- tool_lists: Dict[str, List[Dict[str, str]]] = {}
580
- for k, plan in plans.items():
581
- tool_lists[k] = []
582
- for task in plan["instructions"]:
583
- tools = tool_recommender.top_k(task, k=2, thresh=0.3)
584
- tool_info.extend([e["doc"] for e in tools])
585
- tool_desc.extend([e["desc"] for e in tools])
586
- tool_lists[k].extend(
587
- {"description": e["desc"], "documentation": e["doc"]} for e in tools
588
- )
589
-
590
- if verbosity == 2:
591
- tool_desc_str = "\n".join(set(tool_desc))
592
- _LOGGER.info(f"Tools Description:\n{tool_desc_str}")
593
-
594
- tool_lists_unique = {}
595
- for k in tool_lists:
596
- tool_lists_unique[k] = "\n\n".join(
597
- set(e["documentation"] for e in tool_lists[k])
598
- )
599
- all_tools = "\n\n".join(set(tool_info))
600
- tool_lists_unique["all"] = all_tools
601
- return tool_lists_unique
602
-
603
-
604
356
  class VisionAgentCoder(Agent):
605
357
  """Vision Agent Coder is an agentic framework that can output code based on a user
606
358
  request. It can plan tasks, retrieve relevant tools, write code, write tests and
@@ -616,23 +368,22 @@ class VisionAgentCoder(Agent):
616
368
 
617
369
  def __init__(
618
370
  self,
619
- planner: Optional[LMM] = None,
371
+ planner: Optional[Agent] = None,
620
372
  coder: Optional[LMM] = None,
621
373
  tester: Optional[LMM] = None,
622
374
  debugger: Optional[LMM] = None,
623
- tool_recommender: Optional[Sim] = None,
624
375
  verbosity: int = 0,
625
376
  report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
626
- code_sandbox_runtime: Optional[str] = None,
377
+ code_interpreter: Optional[Union[str, CodeInterpreter]] = None,
627
378
  ) -> None:
628
379
  """Initialize the Vision Agent Coder.
629
380
 
630
381
  Parameters:
631
- planner (Optional[LMM]): The planner model to use. Defaults to AnthropicLMM.
382
+ planner (Optional[Agent]): The planner model to use. Defaults to
383
+ AnthropicVisionAgentPlanner.
632
384
  coder (Optional[LMM]): The coder model to use. Defaults to AnthropicLMM.
633
385
  tester (Optional[LMM]): The tester model to use. Defaults to AnthropicLMM.
634
386
  debugger (Optional[LMM]): The debugger model to use. Defaults to AnthropicLMM.
635
- tool_recommender (Optional[Sim]): The tool recommender model to use.
636
387
  verbosity (int): The verbosity level of the agent. Defaults to 0. 2 is the
637
388
  highest verbosity level which will output all intermediate debugging
638
389
  code.
@@ -641,14 +392,17 @@ class VisionAgentCoder(Agent):
641
392
  in a web application where multiple VisionAgentCoder instances are
642
393
  running in parallel. This callback ensures that the progress are not
643
394
  mixed up.
644
- code_sandbox_runtime (Optional[str]): the code sandbox runtime to use. A
645
- code sandbox is used to run the generated code. It can be one of the
646
- following values: None, "local" or "e2b". If None, VisionAgentCoder
647
- will read the value from the environment variable CODE_SANDBOX_RUNTIME.
648
- If it's also None, the local python runtime environment will be used.
395
+ code_interpreter (Optional[Union[str, CodeInterpreter]]): For string values
396
+ it can be one of: None, "local" or "e2b". If None, it will read from
397
+ the environment variable "CODE_SANDBOX_RUNTIME". If a CodeInterpreter
398
+ object is provided it will use that.
649
399
  """
650
400
 
651
- self.planner = AnthropicLMM(temperature=0.0) if planner is None else planner
401
+ self.planner = (
402
+ AnthropicVisionAgentPlanner(verbosity=verbosity)
403
+ if planner is None
404
+ else planner
405
+ )
652
406
  self.coder = AnthropicLMM(temperature=0.0) if coder is None else coder
653
407
  self.tester = AnthropicLMM(temperature=0.0) if tester is None else tester
654
408
  self.debugger = AnthropicLMM(temperature=0.0) if debugger is None else debugger
@@ -656,21 +410,15 @@ class VisionAgentCoder(Agent):
656
410
  if self.verbosity > 0:
657
411
  _LOGGER.setLevel(logging.INFO)
658
412
 
659
- self.tool_recommender = (
660
- Sim(T.TOOLS_DF, sim_key="desc")
661
- if tool_recommender is None
662
- else tool_recommender
663
- )
664
413
  self.report_progress_callback = report_progress_callback
665
- self.code_sandbox_runtime = code_sandbox_runtime
414
+ self.code_interpreter = code_interpreter
666
415
 
667
416
  def __call__(
668
417
  self,
669
418
  input: Union[str, List[Message]],
670
419
  media: Optional[Union[str, Path]] = None,
671
420
  ) -> str:
672
- """Chat with VisionAgentCoder and return intermediate information regarding the
673
- task.
421
+ """Generate code based on a user request.
674
422
 
675
423
  Parameters:
676
424
  input (Union[str, List[Message]]): A conversation in the format of
@@ -686,46 +434,58 @@ class VisionAgentCoder(Agent):
686
434
  input = [{"role": "user", "content": input}]
687
435
  if media is not None:
688
436
  input[0]["media"] = [media]
689
- results = self.chat_with_workflow(input)
690
- results.pop("working_memory")
691
- return results["code"] # type: ignore
437
+ code_and_context = self.generate_code(input)
438
+ return code_and_context["code"] # type: ignore
692
439
 
693
- def chat_with_workflow(
440
+ def generate_code_from_plan(
694
441
  self,
695
442
  chat: List[Message],
696
- test_multi_plan: bool = True,
697
- display_visualization: bool = False,
698
- custom_tool_names: Optional[List[str]] = None,
443
+ plan_context: PlanContext,
444
+ code_interpreter: Optional[CodeInterpreter] = None,
699
445
  ) -> Dict[str, Any]:
700
- """Chat with VisionAgentCoder and return intermediate information regarding the
701
- task.
446
+ """Generates code and other intermediate outputs from a chat input and a plan.
447
+ The plan includes:
448
+ - plans: The plans generated by the planner.
449
+ - best_plan: The best plan selected by the planner.
450
+ - plan_thoughts: The thoughts of the planner, including any modifications
451
+ to the plan.
452
+ - tool_doc: The tool documentation for the best plan.
453
+ - tool_output: The tool output from the tools used by the best plan.
702
454
 
703
455
  Parameters:
704
- chat (List[Message]): A conversation
705
- in the format of:
706
- [{"role": "user", "content": "describe your task here..."}]
707
- or if it contains media files, it should be in the format of:
708
- [{"role": "user", "content": "describe your task here...", "media": ["image1.jpg", "image2.jpg"]}]
709
- test_multi_plan (bool): If True, it will test tools for multiple plans and
710
- pick the best one based off of the tool results. If False, it will go
711
- with the first plan.
712
- display_visualization (bool): If True, it opens a new window locally to
713
- show the image(s) created by visualization code (if there is any).
714
- custom_tool_names (List[str]): A list of custom tools for the agent to pick
715
- and use. If not provided, default to full tool set from vision_agent.tools.
456
+ chat (List[Message]): A conversation in the format of
457
+ [{"role": "user", "content": "describe your task here..."}].
458
+ plan_context (PlanContext): The context of the plan, including the plans,
459
+ best_plan, plan_thoughts, tool_doc, and tool_output.
460
+ test_multi_plan (bool): Whether to test multiple plans or just the best plan.
461
+ custom_tool_names (Optional[List[str]]): A list of custom tool names to use
462
+ for the planner.
716
463
 
717
464
  Returns:
718
- Dict[str, Any]: A dictionary containing the code, test, test result, plan,
719
- and working memory of the agent.
465
+ Dict[str, Any]: A dictionary containing the code output by the
466
+ VisionAgentCoder and other intermediate outputs. include:
467
+ - status (str): Whether or not the agent completed or failed generating
468
+ the code.
469
+ - code (str): The code output by the VisionAgentCoder.
470
+ - test (str): The test output by the VisionAgentCoder.
471
+ - test_result (Execution): The result of the test execution.
472
+ - plans (Dict[str, Any]): The plans generated by the planner.
473
+ - plan_thoughts (str): The thoughts of the planner.
474
+ - working_memory (List[Dict[str, str]]): The working memory of the agent.
720
475
  """
721
-
722
476
  if not chat:
723
477
  raise ValueError("Chat cannot be empty.")
724
478
 
725
479
  # NOTE: each chat should have a dedicated code interpreter instance to avoid concurrency issues
726
- with CodeInterpreterFactory.new_instance(
727
- code_sandbox_runtime=self.code_sandbox_runtime
728
- ) as code_interpreter:
480
+ code_interpreter = (
481
+ self.code_interpreter
482
+ if self.code_interpreter is not None
483
+ and not isinstance(self.code_interpreter, str)
484
+ else CodeInterpreterFactory.new_instance(
485
+ code_sandbox_runtime=self.code_interpreter,
486
+ )
487
+ )
488
+ with code_interpreter:
729
489
  chat = copy.deepcopy(chat)
730
490
  media_list = []
731
491
  for chat_i in chat:
@@ -759,74 +519,22 @@ class VisionAgentCoder(Agent):
759
519
  code = ""
760
520
  test = ""
761
521
  working_memory: List[Dict[str, str]] = []
762
- results = {"code": "", "test": "", "plan": []}
763
- plan = []
764
- success = False
765
-
766
- plans = self._create_plans(
767
- int_chat, custom_tool_names, working_memory, self.planner
768
- )
769
-
770
- if test_multi_plan:
771
- self._log_plans(plans, self.verbosity)
772
-
773
- tool_infos = retrieve_tools(
774
- plans,
775
- self.tool_recommender,
776
- self.log_progress,
777
- self.verbosity,
778
- )
779
-
780
- if test_multi_plan:
781
- plan_thoughts, tool_output_str = pick_plan(
782
- int_chat,
783
- plans,
784
- tool_infos["all"],
785
- self.coder,
786
- code_interpreter,
787
- media_list,
788
- self.log_progress,
789
- verbosity=self.verbosity,
790
- )
791
- best_plan = plan_thoughts["best_plan"]
792
- plan_thoughts_str = plan_thoughts["thoughts"]
793
- else:
794
- best_plan = list(plans.keys())[0]
795
- tool_output_str = ""
796
- plan_thoughts_str = ""
797
-
798
- if best_plan in plans and best_plan in tool_infos:
799
- plan_i = plans[best_plan]
800
- tool_info = tool_infos[best_plan]
801
- else:
802
- if self.verbosity >= 1:
803
- _LOGGER.warning(
804
- f"Best plan {best_plan} not found in plans or tool_infos. Using the first plan and tool info."
805
- )
806
- k = list(plans.keys())[0]
807
- plan_i = plans[k]
808
- tool_info = tool_infos[k]
809
-
810
- self.log_progress(
811
- {
812
- "type": "log",
813
- "log_content": "Creating plans",
814
- "status": "completed",
815
- "payload": tool_info,
816
- }
817
- )
522
+ plan = plan_context.plans[plan_context.best_plan]
523
+ tool_doc = plan_context.tool_doc
524
+ tool_output_str = plan_context.tool_output
525
+ plan_thoughts_str = str(plan_context.plan_thoughts)
818
526
 
819
527
  if self.verbosity >= 1:
820
- plan_i_fixed = [{"instructions": e} for e in plan_i["instructions"]]
528
+ plan_fixed = [{"instructions": e} for e in plan["instructions"]]
821
529
  _LOGGER.info(
822
- f"Picked best plan:\n{tabulate(tabular_data=plan_i_fixed, headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"
530
+ f"Picked best plan:\n{tabulate(tabular_data=plan_fixed, headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"
823
531
  )
824
532
 
825
533
  results = write_and_test_code(
826
534
  chat=[{"role": c["role"], "content": c["content"]} for c in int_chat],
827
- plan=f"\n{plan_i['thoughts']}\n-"
828
- + "\n-".join([e for e in plan_i["instructions"]]),
829
- tool_info=tool_info,
535
+ plan=f"\n{plan['thoughts']}\n-"
536
+ + "\n-".join([e for e in plan["instructions"]]),
537
+ tool_info=tool_doc,
830
538
  tool_output=tool_output_str,
831
539
  plan_thoughts=plan_thoughts_str,
832
540
  tool_utils=T.UTILITIES_DOCSTRING,
@@ -842,64 +550,82 @@ class VisionAgentCoder(Agent):
842
550
  success = cast(bool, results["success"])
843
551
  code = remove_installs_from_code(cast(str, results["code"]))
844
552
  test = remove_installs_from_code(cast(str, results["test"]))
845
- working_memory.extend(results["working_memory"]) # type: ignore
846
- plan.append({"code": code, "test": test, "plan": plan_i})
847
-
553
+ working_memory.extend(results["working_memory"])
848
554
  execution_result = cast(Execution, results["test_result"])
849
555
 
850
- if display_visualization:
851
- for res in execution_result.results:
852
- if res.png:
853
- b64_to_pil(res.png).show()
854
- if res.mp4:
855
- play_video(res.mp4)
856
-
857
556
  return {
858
557
  "status": "completed" if success else "failed",
859
558
  "code": DefaultImports.prepend_imports(code),
860
559
  "test": test,
861
560
  "test_result": execution_result,
862
- "plans": plans,
561
+ "plans": plan_context.plans,
863
562
  "plan_thoughts": plan_thoughts_str,
864
563
  "working_memory": working_memory,
865
564
  }
866
565
 
867
- def log_progress(self, data: Dict[str, Any]) -> None:
868
- if self.report_progress_callback is not None:
869
- self.report_progress_callback(data)
870
-
871
- def _create_plans(
566
+ def generate_code(
872
567
  self,
873
- int_chat: List[Message],
874
- customized_tool_names: Optional[List[str]],
875
- working_memory: List[Dict[str, str]],
876
- planner: LMM,
568
+ chat: List[Message],
569
+ test_multi_plan: bool = True,
570
+ custom_tool_names: Optional[List[str]] = None,
877
571
  ) -> Dict[str, Any]:
878
- self.log_progress(
879
- {
880
- "type": "log",
881
- "log_content": "Creating plans",
882
- "status": "started",
883
- }
884
- )
885
- plans = write_plans(
886
- int_chat,
887
- T.get_tool_descriptions_by_names(
888
- customized_tool_names, T.FUNCTION_TOOLS, T.UTIL_TOOLS # type: ignore
889
- ),
890
- format_memory(working_memory),
891
- planner,
572
+ """Generates code and other intermediate outputs from a chat input.
573
+
574
+ Parameters:
575
+ chat (List[Message]): A conversation in the format of
576
+ [{"role": "user", "content": "describe your task here..."}].
577
+ test_multi_plan (bool): Whether to test multiple plans or just the best plan.
578
+ custom_tool_names (Optional[List[str]]): A list of custom tool names to use
579
+ for the planner.
580
+
581
+ Returns:
582
+ Dict[str, Any]: A dictionary containing the code output by the
583
+ VisionAgentCoder and other intermediate outputs. include:
584
+ - status (str): Whether or not the agent completed or failed generating
585
+ the code.
586
+ - code (str): The code output by the VisionAgentCoder.
587
+ - test (str): The test output by the VisionAgentCoder.
588
+ - test_result (Execution): The result of the test execution.
589
+ - plans (Dict[str, Any]): The plans generated by the planner.
590
+ - plan_thoughts (str): The thoughts of the planner.
591
+ - working_memory (List[Dict[str, str]]): The working memory of the agent.
592
+ """
593
+ if not chat:
594
+ raise ValueError("Chat cannot be empty.")
595
+
596
+ # NOTE: each chat should have a dedicated code interpreter instance to avoid concurrency issues
597
+ code_interpreter = (
598
+ self.code_interpreter
599
+ if self.code_interpreter is not None
600
+ and not isinstance(self.code_interpreter, str)
601
+ else CodeInterpreterFactory.new_instance(
602
+ code_sandbox_runtime=self.code_interpreter,
603
+ )
892
604
  )
893
- return plans
605
+ with code_interpreter:
606
+ plan_context = self.planner.generate_plan( # type: ignore
607
+ chat,
608
+ test_multi_plan=test_multi_plan,
609
+ custom_tool_names=custom_tool_names,
610
+ code_interpreter=code_interpreter,
611
+ )
894
612
 
895
- def _log_plans(self, plans: Dict[str, Any], verbosity: int) -> None:
896
- if verbosity >= 1:
897
- for p in plans:
898
- # tabulate will fail if the keys are not the same for all elements
899
- p_fixed = [{"instructions": e} for e in plans[p]["instructions"]]
900
- _LOGGER.info(
901
- f"\n{tabulate(tabular_data=p_fixed, headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"
902
- )
613
+ code_and_context = self.generate_code_from_plan(
614
+ chat,
615
+ plan_context,
616
+ code_interpreter=code_interpreter,
617
+ )
618
+ return code_and_context
619
+
620
+ def chat(self, chat: List[Message]) -> List[Message]:
621
+ chat = copy.deepcopy(chat)
622
+ code = self.generate_code(chat)
623
+ chat.append({"role": "agent", "content": code["code"]})
624
+ return chat
625
+
626
+ def log_progress(self, data: Dict[str, Any]) -> None:
627
+ if self.report_progress_callback is not None:
628
+ self.report_progress_callback(data)
903
629
 
904
630
 
905
631
  class OpenAIVisionAgentCoder(VisionAgentCoder):
@@ -907,17 +633,18 @@ class OpenAIVisionAgentCoder(VisionAgentCoder):
907
633
 
908
634
  def __init__(
909
635
  self,
910
- planner: Optional[LMM] = None,
636
+ planner: Optional[Agent] = None,
911
637
  coder: Optional[LMM] = None,
912
638
  tester: Optional[LMM] = None,
913
639
  debugger: Optional[LMM] = None,
914
- tool_recommender: Optional[Sim] = None,
915
640
  verbosity: int = 0,
916
641
  report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
917
- code_sandbox_runtime: Optional[str] = None,
642
+ code_interpreter: Optional[Union[str, CodeInterpreter]] = None,
918
643
  ) -> None:
919
644
  self.planner = (
920
- OpenAILMM(temperature=0.0, json_mode=True) if planner is None else planner
645
+ OpenAIVisionAgentPlanner(verbosity=verbosity)
646
+ if planner is None
647
+ else planner
921
648
  )
922
649
  self.coder = OpenAILMM(temperature=0.0) if coder is None else coder
923
650
  self.tester = OpenAILMM(temperature=0.0) if tester is None else tester
@@ -926,13 +653,8 @@ class OpenAIVisionAgentCoder(VisionAgentCoder):
926
653
  if self.verbosity > 0:
927
654
  _LOGGER.setLevel(logging.INFO)
928
655
 
929
- self.tool_recommender = (
930
- Sim(T.TOOLS_DF, sim_key="desc")
931
- if tool_recommender is None
932
- else tool_recommender
933
- )
934
656
  self.report_progress_callback = report_progress_callback
935
- self.code_sandbox_runtime = code_sandbox_runtime
657
+ self.code_interpreter = code_interpreter
936
658
 
937
659
 
938
660
  class AnthropicVisionAgentCoder(VisionAgentCoder):
@@ -940,17 +662,20 @@ class AnthropicVisionAgentCoder(VisionAgentCoder):
940
662
 
941
663
  def __init__(
942
664
  self,
943
- planner: Optional[LMM] = None,
665
+ planner: Optional[Agent] = None,
944
666
  coder: Optional[LMM] = None,
945
667
  tester: Optional[LMM] = None,
946
668
  debugger: Optional[LMM] = None,
947
- tool_recommender: Optional[Sim] = None,
948
669
  verbosity: int = 0,
949
670
  report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
950
- code_sandbox_runtime: Optional[str] = None,
671
+ code_interpreter: Optional[Union[str, CodeInterpreter]] = None,
951
672
  ) -> None:
952
673
  # NOTE: Claude doesn't have an official JSON mode
953
- self.planner = AnthropicLMM(temperature=0.0) if planner is None else planner
674
+ self.planner = (
675
+ AnthropicVisionAgentPlanner(verbosity=verbosity)
676
+ if planner is None
677
+ else planner
678
+ )
954
679
  self.coder = AnthropicLMM(temperature=0.0) if coder is None else coder
955
680
  self.tester = AnthropicLMM(temperature=0.0) if tester is None else tester
956
681
  self.debugger = AnthropicLMM(temperature=0.0) if debugger is None else debugger
@@ -958,15 +683,8 @@ class AnthropicVisionAgentCoder(VisionAgentCoder):
958
683
  if self.verbosity > 0:
959
684
  _LOGGER.setLevel(logging.INFO)
960
685
 
961
- # Anthropic does not offer any embedding models and instead recomends Voyage,
962
- # we're using OpenAI's embedder for now.
963
- self.tool_recommender = (
964
- Sim(T.TOOLS_DF, sim_key="desc")
965
- if tool_recommender is None
966
- else tool_recommender
967
- )
968
686
  self.report_progress_callback = report_progress_callback
969
- self.code_sandbox_runtime = code_sandbox_runtime
687
+ self.code_interpreter = code_interpreter
970
688
 
971
689
 
972
690
  class OllamaVisionAgentCoder(VisionAgentCoder):
@@ -988,17 +706,17 @@ class OllamaVisionAgentCoder(VisionAgentCoder):
988
706
 
989
707
  def __init__(
990
708
  self,
991
- planner: Optional[LMM] = None,
709
+ planner: Optional[Agent] = None,
992
710
  coder: Optional[LMM] = None,
993
711
  tester: Optional[LMM] = None,
994
712
  debugger: Optional[LMM] = None,
995
- tool_recommender: Optional[Sim] = None,
996
713
  verbosity: int = 0,
997
714
  report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
715
+ code_interpreter: Optional[Union[str, CodeInterpreter]] = None,
998
716
  ) -> None:
999
717
  super().__init__(
1000
718
  planner=(
1001
- OllamaLMM(model_name="llama3.1", temperature=0.0, json_mode=True)
719
+ OllamaVisionAgentPlanner(verbosity=verbosity)
1002
720
  if planner is None
1003
721
  else planner
1004
722
  ),
@@ -1017,13 +735,9 @@ class OllamaVisionAgentCoder(VisionAgentCoder):
1017
735
  if debugger is None
1018
736
  else debugger
1019
737
  ),
1020
- tool_recommender=(
1021
- OllamaSim(T.TOOLS_DF, sim_key="desc")
1022
- if tool_recommender is None
1023
- else tool_recommender
1024
- ),
1025
738
  verbosity=verbosity,
1026
739
  report_progress_callback=report_progress_callback,
740
+ code_interpreter=code_interpreter,
1027
741
  )
1028
742
 
1029
743
 
@@ -1043,22 +757,22 @@ class AzureVisionAgentCoder(VisionAgentCoder):
1043
757
 
1044
758
  def __init__(
1045
759
  self,
1046
- planner: Optional[LMM] = None,
760
+ planner: Optional[Agent] = None,
1047
761
  coder: Optional[LMM] = None,
1048
762
  tester: Optional[LMM] = None,
1049
763
  debugger: Optional[LMM] = None,
1050
- tool_recommender: Optional[Sim] = None,
1051
764
  verbosity: int = 0,
1052
765
  report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
766
+ code_interpreter: Optional[Union[str, CodeInterpreter]] = None,
1053
767
  ) -> None:
1054
768
  """Initialize the Vision Agent Coder.
1055
769
 
1056
770
  Parameters:
1057
- planner (Optional[LMM]): The planner model to use. Defaults to OpenAILMM.
771
+ planner (Optional[Agent]): The planner model to use. Defaults to
772
+ AzureVisionAgentPlanner.
1058
773
  coder (Optional[LMM]): The coder model to use. Defaults to OpenAILMM.
1059
774
  tester (Optional[LMM]): The tester model to use. Defaults to OpenAILMM.
1060
775
  debugger (Optional[LMM]): The debugger model to
1061
- tool_recommender (Optional[Sim]): The tool recommender model to use.
1062
776
  verbosity (int): The verbosity level of the agent. Defaults to 0. 2 is the
1063
777
  highest verbosity level which will output all intermediate debugging
1064
778
  code.
@@ -1069,7 +783,7 @@ class AzureVisionAgentCoder(VisionAgentCoder):
1069
783
  """
1070
784
  super().__init__(
1071
785
  planner=(
1072
- AzureOpenAILMM(temperature=0.0, json_mode=True)
786
+ AzureVisionAgentPlanner(verbosity=verbosity)
1073
787
  if planner is None
1074
788
  else planner
1075
789
  ),
@@ -1078,11 +792,7 @@ class AzureVisionAgentCoder(VisionAgentCoder):
1078
792
  debugger=(
1079
793
  AzureOpenAILMM(temperature=0.0) if debugger is None else debugger
1080
794
  ),
1081
- tool_recommender=(
1082
- AzureSim(T.TOOLS_DF, sim_key="desc")
1083
- if tool_recommender is None
1084
- else tool_recommender
1085
- ),
1086
795
  verbosity=verbosity,
1087
796
  report_progress_callback=report_progress_callback,
797
+ code_interpreter=code_interpreter,
1088
798
  )