vision-agent 0.2.161__py3-none-any.whl → 0.2.163__py3-none-any.whl

This diff shows the changes between two publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
@@ -2,32 +2,35 @@ import copy
2
2
  import logging
3
3
  import os
4
4
  import sys
5
- from json import JSONDecodeError
6
5
  from pathlib import Path
7
- from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, cast
6
+ from typing import Any, Callable, Dict, List, Optional, Sequence, Union, cast
8
7
 
9
- from rich.console import Console
10
- from rich.style import Style
11
- from rich.syntax import Syntax
8
+ from redbaron import RedBaron # type: ignore
12
9
  from tabulate import tabulate
13
10
 
14
11
  import vision_agent.tools as T
15
- from vision_agent.agent import Agent
12
+ from vision_agent.agent.agent import Agent
16
13
  from vision_agent.agent.agent_utils import (
14
+ _MAX_TABULATE_COL_WIDTH,
15
+ DefaultImports,
17
16
  extract_code,
18
- extract_json,
17
+ extract_tag,
18
+ format_memory,
19
+ print_code,
19
20
  remove_installs_from_code,
20
21
  )
21
22
  from vision_agent.agent.vision_agent_coder_prompts import (
22
23
  CODE,
23
24
  FIX_BUG,
24
25
  FULL_TASK,
25
- PICK_PLAN,
26
- PLAN,
27
- PREVIOUS_FAILED,
28
26
  SIMPLE_TEST,
29
- TEST_PLANS,
30
- USER_REQ,
27
+ )
28
+ from vision_agent.agent.vision_agent_planner import (
29
+ AnthropicVisionAgentPlanner,
30
+ AzureVisionAgentPlanner,
31
+ OllamaVisionAgentPlanner,
32
+ OpenAIVisionAgentPlanner,
33
+ PlanContext,
31
34
  )
32
35
  from vision_agent.lmm import (
33
36
  LMM,
@@ -40,241 +43,48 @@ from vision_agent.lmm import (
40
43
  from vision_agent.tools.meta_tools import get_diff
41
44
  from vision_agent.utils import CodeInterpreterFactory, Execution
42
45
  from vision_agent.utils.execute import CodeInterpreter
43
- from vision_agent.utils.image_utils import b64_to_pil
44
- from vision_agent.utils.sim import AzureSim, OllamaSim, Sim
45
- from vision_agent.utils.video import play_video
46
46
 
47
47
  logging.basicConfig(stream=sys.stdout)
48
48
  WORKSPACE = Path(os.getenv("WORKSPACE", ""))
49
49
  _LOGGER = logging.getLogger(__name__)
50
- _MAX_TABULATE_COL_WIDTH = 80
51
- _CONSOLE = Console()
52
-
53
-
54
- class DefaultImports:
55
- """Container for default imports used in the code execution."""
56
-
57
- common_imports = [
58
- "import os",
59
- "import numpy as np",
60
- "from vision_agent.tools import *",
61
- "from typing import *",
62
- "from pillow_heif import register_heif_opener",
63
- "register_heif_opener()",
64
- ]
65
-
66
- @staticmethod
67
- def to_code_string() -> str:
68
- return "\n".join(DefaultImports.common_imports + T.__new_tools__)
69
-
70
- @staticmethod
71
- def prepend_imports(code: str) -> str:
72
- """Run this method to prepend the default imports to the code.
73
- NOTE: be sure to run this method after the custom tools have been registered.
74
- """
75
- return DefaultImports.to_code_string() + "\n\n" + code
76
-
77
-
78
- def format_memory(memory: List[Dict[str, str]]) -> str:
79
- output_str = ""
80
- for i, m in enumerate(memory):
81
- output_str += f"### Feedback {i}:\n"
82
- output_str += f"Code {i}:\n```python\n{m['code']}```\n\n"
83
- output_str += f"Feedback {i}: {m['feedback']}\n\n"
84
- if "edits" in m:
85
- output_str += f"Edits {i}:\n{m['edits']}\n"
86
- output_str += "\n"
87
-
88
- return output_str
89
-
90
-
91
- def format_plans(plans: Dict[str, Any]) -> str:
92
- plan_str = ""
93
- for k, v in plans.items():
94
- plan_str += "\n" + f"{k}: {v['thoughts']}\n"
95
- plan_str += " -" + "\n -".join([e for e in v["instructions"]])
96
-
97
- return plan_str
98
-
99
-
100
- def write_plans(
101
- chat: List[Message],
102
- tool_desc: str,
103
- working_memory: str,
104
- model: LMM,
105
- ) -> Dict[str, Any]:
106
- chat = copy.deepcopy(chat)
107
- if chat[-1]["role"] != "user":
108
- raise ValueError("Last chat message must be from the user.")
109
-
110
- user_request = chat[-1]["content"]
111
- context = USER_REQ.format(user_request=user_request)
112
- prompt = PLAN.format(
113
- context=context,
114
- tool_desc=tool_desc,
115
- feedback=working_memory,
116
- )
117
- chat[-1]["content"] = prompt
118
- return extract_json(model(chat, stream=False)) # type: ignore
119
-
120
-
121
- def pick_plan(
122
- chat: List[Message],
123
- plans: Dict[str, Any],
124
- tool_info: str,
125
- model: LMM,
126
- code_interpreter: CodeInterpreter,
127
- media: List[str],
128
- log_progress: Callable[[Dict[str, Any]], None],
129
- verbosity: int = 0,
130
- max_retries: int = 3,
131
- ) -> Tuple[Dict[str, str], str]:
132
- log_progress(
133
- {
134
- "type": "log",
135
- "log_content": "Generating code to pick the best plan",
136
- "status": "started",
137
- }
138
- )
139
50
 
140
- chat = copy.deepcopy(chat)
141
- if chat[-1]["role"] != "user":
142
- raise ValueError("Last chat message must be from the user.")
143
51
 
144
- plan_str = format_plans(plans)
145
- prompt = TEST_PLANS.format(
146
- docstring=tool_info, plans=plan_str, previous_attempts="", media=media
147
- )
148
-
149
- code = extract_code(model(prompt, stream=False)) # type: ignore
150
- log_progress(
151
- {
152
- "type": "log",
153
- "log_content": "Executing code to test plans",
154
- "code": DefaultImports.prepend_imports(code),
155
- "status": "running",
156
- }
157
- )
158
- tool_output = code_interpreter.exec_isolation(DefaultImports.prepend_imports(code))
159
- # Because of the way we trace function calls the trace information ends up in the
160
- # results. We don't want to show this info to the LLM so we don't include it in the
161
- # tool_output_str.
162
- tool_output_str = tool_output.text(include_results=False).strip()
163
-
164
- if verbosity == 2:
165
- _print_code("Initial code and tests:", code)
166
- _LOGGER.info(f"Initial code execution result:\n{tool_output_str}")
167
-
168
- log_progress(
169
- {
170
- "type": "log",
171
- "log_content": (
172
- "Code execution succeeded"
173
- if tool_output.success
174
- else "Code execution failed"
175
- ),
176
- "code": DefaultImports.prepend_imports(code),
177
- # "payload": tool_output.to_json(),
178
- "status": "completed" if tool_output.success else "failed",
179
- }
180
- )
181
-
182
- # retry if the tool output is empty or code fails
183
- count = 0
184
- while (
185
- not tool_output.success
186
- or (len(tool_output.logs.stdout) == 0 and len(tool_output.logs.stderr) == 0)
187
- ) and count < max_retries:
188
- prompt = TEST_PLANS.format(
189
- docstring=tool_info,
190
- plans=plan_str,
191
- previous_attempts=PREVIOUS_FAILED.format(
192
- code=code, error="\n".join(tool_output_str.splitlines()[-50:])
193
- ),
194
- media=media,
195
- )
196
- log_progress(
197
- {
198
- "type": "log",
199
- "log_content": "Retrying code to test plans",
200
- "status": "running",
201
- "code": DefaultImports.prepend_imports(code),
202
- }
203
- )
204
- code = extract_code(model(prompt, stream=False)) # type: ignore
205
- tool_output = code_interpreter.exec_isolation(
206
- DefaultImports.prepend_imports(code)
207
- )
208
- log_progress(
209
- {
210
- "type": "log",
211
- "log_content": (
212
- "Code execution succeeded"
213
- if tool_output.success
214
- else "Code execution failed"
215
- ),
216
- "code": DefaultImports.prepend_imports(code),
217
- # "payload": tool_output.to_json(),
218
- "status": "completed" if tool_output.success else "failed",
219
- }
220
- )
221
- tool_output_str = tool_output.text(include_results=False).strip()
222
-
223
- if verbosity == 2:
224
- _print_code("Code and test after attempted fix:", code)
225
- _LOGGER.info(f"Code execution result after attempt {count + 1}")
226
- _LOGGER.info(f"{tool_output_str}")
227
-
228
- count += 1
229
-
230
- if verbosity >= 1:
231
- _print_code("Final code:", code)
232
-
233
- user_req = chat[-1]["content"]
234
- context = USER_REQ.format(user_request=user_req)
235
- # because the tool picker model gets the image as well, we have to be careful with
236
- # how much text we send it, so we truncate the tool output to 20,000 characters
237
- prompt = PICK_PLAN.format(
238
- context=context,
239
- plans=format_plans(plans),
240
- tool_output=tool_output_str[:20_000],
241
- )
242
- chat[-1]["content"] = prompt
243
-
244
- count = 0
245
- plan_thoughts = None
246
- while plan_thoughts is None and count < max_retries:
247
- try:
248
- plan_thoughts = extract_json(model(chat, stream=False)) # type: ignore
249
- except JSONDecodeError as e:
250
- _LOGGER.exception(
251
- f"Error while extracting JSON during picking best plan {str(e)}"
252
- )
253
- pass
254
- count += 1
255
-
256
- if (
257
- plan_thoughts is None
258
- or "best_plan" not in plan_thoughts
259
- or ("best_plan" in plan_thoughts and plan_thoughts["best_plan"] not in plans)
260
- ):
261
- _LOGGER.info(f"Failed to pick best plan. Using the first plan. {plan_thoughts}")
262
- plan_thoughts = {"best_plan": list(plans.keys())[0]}
263
-
264
- if "thoughts" not in plan_thoughts:
265
- plan_thoughts["thoughts"] = ""
266
-
267
- if verbosity >= 1:
268
- _LOGGER.info(f"Best plan:\n{plan_thoughts}")
269
- log_progress(
270
- {
271
- "type": "log",
272
- "log_content": "Picked best plan",
273
- "status": "completed",
274
- "payload": plans[plan_thoughts["best_plan"]],
275
- }
276
- )
277
- return plan_thoughts, "```python\n" + code + "\n```\n" + tool_output_str
52
+ def strip_function_calls(code: str, exclusions: Optional[List[str]] = None) -> str:
53
+ """This will strip out all code that calls functions except for functions included
54
+ in exclusions.
55
+ """
56
+ if exclusions is None:
57
+ exclusions = []
58
+
59
+ red = RedBaron(code)
60
+ nodes_to_remove = []
61
+ for node in red:
62
+ if node.type == "def":
63
+ continue
64
+ elif node.type == "import" or node.type == "from_import":
65
+ continue
66
+ elif node.type == "call":
67
+ if node.value and node.value[0].value in exclusions:
68
+ continue
69
+ nodes_to_remove.append(node)
70
+ elif node.type == "atomtrailers":
71
+ if node[0].value in exclusions:
72
+ continue
73
+ nodes_to_remove.append(node)
74
+ elif node.type == "assignment":
75
+ if node.value.type == "call" or node.value.type == "atomtrailers":
76
+ func_name = node.value[0].value
77
+ if func_name in exclusions:
78
+ continue
79
+ nodes_to_remove.append(node)
80
+ elif node.type == "endl":
81
+ continue
82
+ else:
83
+ nodes_to_remove.append(node)
84
+ for node in nodes_to_remove:
85
+ node.parent.remove(node)
86
+ cleaned_code = red.dumps().strip()
87
+ return cleaned_code if isinstance(cleaned_code, str) else code
278
88
 
279
89
 
280
90
  def write_code(
@@ -359,6 +169,7 @@ def write_and_test_code(
359
169
  plan_thoughts,
360
170
  format_memory(working_memory),
361
171
  )
172
+ code = strip_function_calls(code)
362
173
  test = write_test(
363
174
  tester, chat, tool_utils, code, format_memory(working_memory), media
364
175
  )
@@ -393,7 +204,7 @@ def write_and_test_code(
393
204
  }
394
205
  )
395
206
  if verbosity == 2:
396
- _print_code("Initial code and tests:", code, test)
207
+ print_code("Initial code and tests:", code, test)
397
208
  _LOGGER.info(
398
209
  f"Initial code execution result:\n{result.text(include_logs=True)}"
399
210
  )
@@ -418,7 +229,7 @@ def write_and_test_code(
418
229
  count += 1
419
230
 
420
231
  if verbosity >= 1:
421
- _print_code("Final code and tests:", code, test)
232
+ print_code("Final code and tests:", code, test)
422
233
 
423
234
  return {
424
235
  "code": code,
@@ -449,7 +260,9 @@ def debug_code(
449
260
  }
450
261
  )
451
262
 
452
- fixed_code_and_test = {"code": "", "test": "", "reflections": ""}
263
+ fixed_code = None
264
+ fixed_test = None
265
+ thoughts = ""
453
266
  success = False
454
267
  count = 0
455
268
  while not success and count < 3:
@@ -472,21 +285,16 @@ def debug_code(
472
285
  stream=False,
473
286
  )
474
287
  fixed_code_and_test_str = cast(str, fixed_code_and_test_str)
475
- fixed_code_and_test = extract_json(fixed_code_and_test_str)
476
- code = extract_code(fixed_code_and_test_str)
477
- if (
478
- "which_code" in fixed_code_and_test
479
- and fixed_code_and_test["which_code"] == "test"
480
- ):
481
- fixed_code_and_test["code"] = ""
482
- fixed_code_and_test["test"] = code
483
- else: # for everything else always assume it's updating code
484
- fixed_code_and_test["code"] = code
485
- fixed_code_and_test["test"] = ""
486
- if "which_code" in fixed_code_and_test:
487
- del fixed_code_and_test["which_code"]
488
-
489
- success = True
288
+ thoughts_tag = extract_tag(fixed_code_and_test_str, "thoughts")
289
+ thoughts = thoughts_tag if thoughts_tag is not None else ""
290
+ fixed_code = extract_tag(fixed_code_and_test_str, "code")
291
+ fixed_test = extract_tag(fixed_code_and_test_str, "test")
292
+
293
+ if fixed_code is None and fixed_test is None:
294
+ success = False
295
+ else:
296
+ success = True
297
+
490
298
  except Exception as e:
491
299
  _LOGGER.exception(f"Error while extracting JSON: {e}")
492
300
 
@@ -495,15 +303,15 @@ def debug_code(
495
303
  old_code = code
496
304
  old_test = test
497
305
 
498
- if fixed_code_and_test["code"].strip() != "":
499
- code = fixed_code_and_test["code"]
500
- if fixed_code_and_test["test"].strip() != "":
501
- test = fixed_code_and_test["test"]
306
+ if fixed_code is not None and fixed_code.strip() != "":
307
+ code = fixed_code
308
+ if fixed_test is not None and fixed_test.strip() != "":
309
+ test = fixed_test
502
310
 
503
311
  new_working_memory.append(
504
312
  {
505
313
  "code": f"{code}\n{test}",
506
- "feedback": fixed_code_and_test["reflections"],
314
+ "feedback": thoughts,
507
315
  "edits": get_diff(f"{old_code}\n{old_test}", f"{code}\n{test}"),
508
316
  }
509
317
  )
@@ -537,70 +345,14 @@ def debug_code(
537
345
  }
538
346
  )
539
347
  if verbosity == 2:
540
- _print_code("Code and test after attempted fix:", code, test)
348
+ print_code("Code and test after attempted fix:", code, test)
541
349
  _LOGGER.info(
542
- f"Reflection: {fixed_code_and_test['reflections']}\nCode execution result after attempted fix: {result.text(include_logs=True)}"
350
+ f"Reflection: {thoughts}\nCode execution result after attempted fix: {result.text(include_logs=True)}"
543
351
  )
544
352
 
545
353
  return code, test, result
546
354
 
547
355
 
548
- def _print_code(title: str, code: str, test: Optional[str] = None) -> None:
549
- _CONSOLE.print(title, style=Style(bgcolor="dark_orange3", bold=True))
550
- _CONSOLE.print("=" * 30 + " Code " + "=" * 30)
551
- _CONSOLE.print(
552
- Syntax(
553
- DefaultImports.prepend_imports(code),
554
- "python",
555
- theme="gruvbox-dark",
556
- line_numbers=True,
557
- )
558
- )
559
- if test:
560
- _CONSOLE.print("=" * 30 + " Test " + "=" * 30)
561
- _CONSOLE.print(Syntax(test, "python", theme="gruvbox-dark", line_numbers=True))
562
-
563
-
564
- def retrieve_tools(
565
- plans: Dict[str, Dict[str, Any]],
566
- tool_recommender: Sim,
567
- log_progress: Callable[[Dict[str, Any]], None],
568
- verbosity: int = 0,
569
- ) -> Dict[str, str]:
570
- log_progress(
571
- {
572
- "type": "log",
573
- "log_content": ("Retrieving tools for each plan"),
574
- "status": "started",
575
- }
576
- )
577
- tool_info = []
578
- tool_desc = []
579
- tool_lists: Dict[str, List[Dict[str, str]]] = {}
580
- for k, plan in plans.items():
581
- tool_lists[k] = []
582
- for task in plan["instructions"]:
583
- tools = tool_recommender.top_k(task, k=2, thresh=0.3)
584
- tool_info.extend([e["doc"] for e in tools])
585
- tool_desc.extend([e["desc"] for e in tools])
586
- tool_lists[k].extend(
587
- {"description": e["desc"], "documentation": e["doc"]} for e in tools
588
- )
589
-
590
- if verbosity == 2:
591
- tool_desc_str = "\n".join(set(tool_desc))
592
- _LOGGER.info(f"Tools Description:\n{tool_desc_str}")
593
-
594
- tool_lists_unique = {}
595
- for k in tool_lists:
596
- tool_lists_unique[k] = "\n\n".join(
597
- set(e["documentation"] for e in tool_lists[k])
598
- )
599
- all_tools = "\n\n".join(set(tool_info))
600
- tool_lists_unique["all"] = all_tools
601
- return tool_lists_unique
602
-
603
-
604
356
  class VisionAgentCoder(Agent):
605
357
  """Vision Agent Coder is an agentic framework that can output code based on a user
606
358
  request. It can plan tasks, retrieve relevant tools, write code, write tests and
@@ -616,23 +368,22 @@ class VisionAgentCoder(Agent):
616
368
 
617
369
  def __init__(
618
370
  self,
619
- planner: Optional[LMM] = None,
371
+ planner: Optional[Agent] = None,
620
372
  coder: Optional[LMM] = None,
621
373
  tester: Optional[LMM] = None,
622
374
  debugger: Optional[LMM] = None,
623
- tool_recommender: Optional[Sim] = None,
624
375
  verbosity: int = 0,
625
376
  report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
626
- code_sandbox_runtime: Optional[str] = None,
377
+ code_interpreter: Optional[Union[str, CodeInterpreter]] = None,
627
378
  ) -> None:
628
379
  """Initialize the Vision Agent Coder.
629
380
 
630
381
  Parameters:
631
- planner (Optional[LMM]): The planner model to use. Defaults to AnthropicLMM.
382
+ planner (Optional[Agent]): The planner model to use. Defaults to
383
+ AnthropicVisionAgentPlanner.
632
384
  coder (Optional[LMM]): The coder model to use. Defaults to AnthropicLMM.
633
385
  tester (Optional[LMM]): The tester model to use. Defaults to AnthropicLMM.
634
386
  debugger (Optional[LMM]): The debugger model to use. Defaults to AnthropicLMM.
635
- tool_recommender (Optional[Sim]): The tool recommender model to use.
636
387
  verbosity (int): The verbosity level of the agent. Defaults to 0. 2 is the
637
388
  highest verbosity level which will output all intermediate debugging
638
389
  code.
@@ -641,14 +392,17 @@ class VisionAgentCoder(Agent):
641
392
  in a web application where multiple VisionAgentCoder instances are
642
393
  running in parallel. This callback ensures that the progress are not
643
394
  mixed up.
644
- code_sandbox_runtime (Optional[str]): the code sandbox runtime to use. A
645
- code sandbox is used to run the generated code. It can be one of the
646
- following values: None, "local" or "e2b". If None, VisionAgentCoder
647
- will read the value from the environment variable CODE_SANDBOX_RUNTIME.
648
- If it's also None, the local python runtime environment will be used.
395
+ code_interpreter (Optional[Union[str, CodeInterpreter]]): For string values
396
+ it can be one of: None, "local" or "e2b". If None, it will read from
397
+ the environment variable "CODE_SANDBOX_RUNTIME". If a CodeInterpreter
398
+ object is provided it will use that.
649
399
  """
650
400
 
651
- self.planner = AnthropicLMM(temperature=0.0) if planner is None else planner
401
+ self.planner = (
402
+ AnthropicVisionAgentPlanner(verbosity=verbosity)
403
+ if planner is None
404
+ else planner
405
+ )
652
406
  self.coder = AnthropicLMM(temperature=0.0) if coder is None else coder
653
407
  self.tester = AnthropicLMM(temperature=0.0) if tester is None else tester
654
408
  self.debugger = AnthropicLMM(temperature=0.0) if debugger is None else debugger
@@ -656,21 +410,15 @@ class VisionAgentCoder(Agent):
656
410
  if self.verbosity > 0:
657
411
  _LOGGER.setLevel(logging.INFO)
658
412
 
659
- self.tool_recommender = (
660
- Sim(T.TOOLS_DF, sim_key="desc")
661
- if tool_recommender is None
662
- else tool_recommender
663
- )
664
413
  self.report_progress_callback = report_progress_callback
665
- self.code_sandbox_runtime = code_sandbox_runtime
414
+ self.code_interpreter = code_interpreter
666
415
 
667
416
  def __call__(
668
417
  self,
669
418
  input: Union[str, List[Message]],
670
419
  media: Optional[Union[str, Path]] = None,
671
420
  ) -> str:
672
- """Chat with VisionAgentCoder and return intermediate information regarding the
673
- task.
421
+ """Generate code based on a user request.
674
422
 
675
423
  Parameters:
676
424
  input (Union[str, List[Message]]): A conversation in the format of
@@ -686,46 +434,58 @@ class VisionAgentCoder(Agent):
686
434
  input = [{"role": "user", "content": input}]
687
435
  if media is not None:
688
436
  input[0]["media"] = [media]
689
- results = self.chat_with_workflow(input)
690
- results.pop("working_memory")
691
- return results["code"] # type: ignore
437
+ code_and_context = self.generate_code(input)
438
+ return code_and_context["code"] # type: ignore
692
439
 
693
- def chat_with_workflow(
440
+ def generate_code_from_plan(
694
441
  self,
695
442
  chat: List[Message],
696
- test_multi_plan: bool = True,
697
- display_visualization: bool = False,
698
- custom_tool_names: Optional[List[str]] = None,
443
+ plan_context: PlanContext,
444
+ code_interpreter: Optional[CodeInterpreter] = None,
699
445
  ) -> Dict[str, Any]:
700
- """Chat with VisionAgentCoder and return intermediate information regarding the
701
- task.
446
+ """Generates code and other intermediate outputs from a chat input and a plan.
447
+ The plan includes:
448
+ - plans: The plans generated by the planner.
449
+ - best_plan: The best plan selected by the planner.
450
+ - plan_thoughts: The thoughts of the planner, including any modifications
451
+ to the plan.
452
+ - tool_doc: The tool documentation for the best plan.
453
+ - tool_output: The tool output from the tools used by the best plan.
702
454
 
703
455
  Parameters:
704
- chat (List[Message]): A conversation
705
- in the format of:
706
- [{"role": "user", "content": "describe your task here..."}]
707
- or if it contains media files, it should be in the format of:
708
- [{"role": "user", "content": "describe your task here...", "media": ["image1.jpg", "image2.jpg"]}]
709
- test_multi_plan (bool): If True, it will test tools for multiple plans and
710
- pick the best one based off of the tool results. If False, it will go
711
- with the first plan.
712
- display_visualization (bool): If True, it opens a new window locally to
713
- show the image(s) created by visualization code (if there is any).
714
- custom_tool_names (List[str]): A list of custom tools for the agent to pick
715
- and use. If not provided, default to full tool set from vision_agent.tools.
456
+ chat (List[Message]): A conversation in the format of
457
+ [{"role": "user", "content": "describe your task here..."}].
458
+ plan_context (PlanContext): The context of the plan, including the plans,
459
+ best_plan, plan_thoughts, tool_doc, and tool_output.
460
+ test_multi_plan (bool): Whether to test multiple plans or just the best plan.
461
+ custom_tool_names (Optional[List[str]]): A list of custom tool names to use
462
+ for the planner.
716
463
 
717
464
  Returns:
718
- Dict[str, Any]: A dictionary containing the code, test, test result, plan,
719
- and working memory of the agent.
465
+ Dict[str, Any]: A dictionary containing the code output by the
466
+ VisionAgentCoder and other intermediate outputs. include:
467
+ - status (str): Whether or not the agent completed or failed generating
468
+ the code.
469
+ - code (str): The code output by the VisionAgentCoder.
470
+ - test (str): The test output by the VisionAgentCoder.
471
+ - test_result (Execution): The result of the test execution.
472
+ - plans (Dict[str, Any]): The plans generated by the planner.
473
+ - plan_thoughts (str): The thoughts of the planner.
474
+ - working_memory (List[Dict[str, str]]): The working memory of the agent.
720
475
  """
721
-
722
476
  if not chat:
723
477
  raise ValueError("Chat cannot be empty.")
724
478
 
725
479
  # NOTE: each chat should have a dedicated code interpreter instance to avoid concurrency issues
726
- with CodeInterpreterFactory.new_instance(
727
- code_sandbox_runtime=self.code_sandbox_runtime
728
- ) as code_interpreter:
480
+ code_interpreter = (
481
+ self.code_interpreter
482
+ if self.code_interpreter is not None
483
+ and not isinstance(self.code_interpreter, str)
484
+ else CodeInterpreterFactory.new_instance(
485
+ code_sandbox_runtime=self.code_interpreter,
486
+ )
487
+ )
488
+ with code_interpreter:
729
489
  chat = copy.deepcopy(chat)
730
490
  media_list = []
731
491
  for chat_i in chat:
@@ -759,74 +519,22 @@ class VisionAgentCoder(Agent):
759
519
  code = ""
760
520
  test = ""
761
521
  working_memory: List[Dict[str, str]] = []
762
- results = {"code": "", "test": "", "plan": []}
763
- plan = []
764
- success = False
765
-
766
- plans = self._create_plans(
767
- int_chat, custom_tool_names, working_memory, self.planner
768
- )
769
-
770
- if test_multi_plan:
771
- self._log_plans(plans, self.verbosity)
772
-
773
- tool_infos = retrieve_tools(
774
- plans,
775
- self.tool_recommender,
776
- self.log_progress,
777
- self.verbosity,
778
- )
779
-
780
- if test_multi_plan:
781
- plan_thoughts, tool_output_str = pick_plan(
782
- int_chat,
783
- plans,
784
- tool_infos["all"],
785
- self.coder,
786
- code_interpreter,
787
- media_list,
788
- self.log_progress,
789
- verbosity=self.verbosity,
790
- )
791
- best_plan = plan_thoughts["best_plan"]
792
- plan_thoughts_str = plan_thoughts["thoughts"]
793
- else:
794
- best_plan = list(plans.keys())[0]
795
- tool_output_str = ""
796
- plan_thoughts_str = ""
797
-
798
- if best_plan in plans and best_plan in tool_infos:
799
- plan_i = plans[best_plan]
800
- tool_info = tool_infos[best_plan]
801
- else:
802
- if self.verbosity >= 1:
803
- _LOGGER.warning(
804
- f"Best plan {best_plan} not found in plans or tool_infos. Using the first plan and tool info."
805
- )
806
- k = list(plans.keys())[0]
807
- plan_i = plans[k]
808
- tool_info = tool_infos[k]
809
-
810
- self.log_progress(
811
- {
812
- "type": "log",
813
- "log_content": "Creating plans",
814
- "status": "completed",
815
- "payload": tool_info,
816
- }
817
- )
522
+ plan = plan_context.plans[plan_context.best_plan]
523
+ tool_doc = plan_context.tool_doc
524
+ tool_output_str = plan_context.tool_output
525
+ plan_thoughts_str = str(plan_context.plan_thoughts)
818
526
 
819
527
  if self.verbosity >= 1:
820
- plan_i_fixed = [{"instructions": e} for e in plan_i["instructions"]]
528
+ plan_fixed = [{"instructions": e} for e in plan["instructions"]]
821
529
  _LOGGER.info(
822
- f"Picked best plan:\n{tabulate(tabular_data=plan_i_fixed, headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"
530
+ f"Picked best plan:\n{tabulate(tabular_data=plan_fixed, headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"
823
531
  )
824
532
 
825
533
  results = write_and_test_code(
826
534
  chat=[{"role": c["role"], "content": c["content"]} for c in int_chat],
827
- plan=f"\n{plan_i['thoughts']}\n-"
828
- + "\n-".join([e for e in plan_i["instructions"]]),
829
- tool_info=tool_info,
535
+ plan=f"\n{plan['thoughts']}\n-"
536
+ + "\n-".join([e for e in plan["instructions"]]),
537
+ tool_info=tool_doc,
830
538
  tool_output=tool_output_str,
831
539
  plan_thoughts=plan_thoughts_str,
832
540
  tool_utils=T.UTILITIES_DOCSTRING,
@@ -842,64 +550,82 @@ class VisionAgentCoder(Agent):
842
550
  success = cast(bool, results["success"])
843
551
  code = remove_installs_from_code(cast(str, results["code"]))
844
552
  test = remove_installs_from_code(cast(str, results["test"]))
845
- working_memory.extend(results["working_memory"]) # type: ignore
846
- plan.append({"code": code, "test": test, "plan": plan_i})
847
-
553
+ working_memory.extend(results["working_memory"])
848
554
  execution_result = cast(Execution, results["test_result"])
849
555
 
850
- if display_visualization:
851
- for res in execution_result.results:
852
- if res.png:
853
- b64_to_pil(res.png).show()
854
- if res.mp4:
855
- play_video(res.mp4)
856
-
857
556
  return {
858
557
  "status": "completed" if success else "failed",
859
558
  "code": DefaultImports.prepend_imports(code),
860
559
  "test": test,
861
560
  "test_result": execution_result,
862
- "plans": plans,
561
+ "plans": plan_context.plans,
863
562
  "plan_thoughts": plan_thoughts_str,
864
563
  "working_memory": working_memory,
865
564
  }
866
565
 
867
- def log_progress(self, data: Dict[str, Any]) -> None:
868
- if self.report_progress_callback is not None:
869
- self.report_progress_callback(data)
870
-
871
- def _create_plans(
566
+ def generate_code(
872
567
  self,
873
- int_chat: List[Message],
874
- customized_tool_names: Optional[List[str]],
875
- working_memory: List[Dict[str, str]],
876
- planner: LMM,
568
+ chat: List[Message],
569
+ test_multi_plan: bool = True,
570
+ custom_tool_names: Optional[List[str]] = None,
877
571
  ) -> Dict[str, Any]:
878
- self.log_progress(
879
- {
880
- "type": "log",
881
- "log_content": "Creating plans",
882
- "status": "started",
883
- }
884
- )
885
- plans = write_plans(
886
- int_chat,
887
- T.get_tool_descriptions_by_names(
888
- customized_tool_names, T.FUNCTION_TOOLS, T.UTIL_TOOLS # type: ignore
889
- ),
890
- format_memory(working_memory),
891
- planner,
572
+ """Generates code and other intermediate outputs from a chat input.
573
+
574
+ Parameters:
575
+ chat (List[Message]): A conversation in the format of
576
+ [{"role": "user", "content": "describe your task here..."}].
577
+ test_multi_plan (bool): Whether to test multiple plans or just the best plan.
578
+ custom_tool_names (Optional[List[str]]): A list of custom tool names to use
579
+ for the planner.
580
+
581
+ Returns:
582
+ Dict[str, Any]: A dictionary containing the code output by the
583
+ VisionAgentCoder and other intermediate outputs. include:
584
+ - status (str): Whether or not the agent completed or failed generating
585
+ the code.
586
+ - code (str): The code output by the VisionAgentCoder.
587
+ - test (str): The test output by the VisionAgentCoder.
588
+ - test_result (Execution): The result of the test execution.
589
+ - plans (Dict[str, Any]): The plans generated by the planner.
590
+ - plan_thoughts (str): The thoughts of the planner.
591
+ - working_memory (List[Dict[str, str]]): The working memory of the agent.
592
+ """
593
+ if not chat:
594
+ raise ValueError("Chat cannot be empty.")
595
+
596
+ # NOTE: each chat should have a dedicated code interpreter instance to avoid concurrency issues
597
+ code_interpreter = (
598
+ self.code_interpreter
599
+ if self.code_interpreter is not None
600
+ and not isinstance(self.code_interpreter, str)
601
+ else CodeInterpreterFactory.new_instance(
602
+ code_sandbox_runtime=self.code_interpreter,
603
+ )
892
604
  )
893
- return plans
605
+ with code_interpreter:
606
+ plan_context = self.planner.generate_plan( # type: ignore
607
+ chat,
608
+ test_multi_plan=test_multi_plan,
609
+ custom_tool_names=custom_tool_names,
610
+ code_interpreter=code_interpreter,
611
+ )
894
612
 
895
- def _log_plans(self, plans: Dict[str, Any], verbosity: int) -> None:
896
- if verbosity >= 1:
897
- for p in plans:
898
- # tabulate will fail if the keys are not the same for all elements
899
- p_fixed = [{"instructions": e} for e in plans[p]["instructions"]]
900
- _LOGGER.info(
901
- f"\n{tabulate(tabular_data=p_fixed, headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"
902
- )
613
+ code_and_context = self.generate_code_from_plan(
614
+ chat,
615
+ plan_context,
616
+ code_interpreter=code_interpreter,
617
+ )
618
+ return code_and_context
619
+
620
+ def chat(self, chat: List[Message]) -> List[Message]:
621
+ chat = copy.deepcopy(chat)
622
+ code = self.generate_code(chat)
623
+ chat.append({"role": "agent", "content": code["code"]})
624
+ return chat
625
+
626
+ def log_progress(self, data: Dict[str, Any]) -> None:
627
+ if self.report_progress_callback is not None:
628
+ self.report_progress_callback(data)
903
629
 
904
630
 
905
631
  class OpenAIVisionAgentCoder(VisionAgentCoder):
@@ -907,17 +633,18 @@ class OpenAIVisionAgentCoder(VisionAgentCoder):
907
633
 
908
634
  def __init__(
909
635
  self,
910
- planner: Optional[LMM] = None,
636
+ planner: Optional[Agent] = None,
911
637
  coder: Optional[LMM] = None,
912
638
  tester: Optional[LMM] = None,
913
639
  debugger: Optional[LMM] = None,
914
- tool_recommender: Optional[Sim] = None,
915
640
  verbosity: int = 0,
916
641
  report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
917
- code_sandbox_runtime: Optional[str] = None,
642
+ code_interpreter: Optional[Union[str, CodeInterpreter]] = None,
918
643
  ) -> None:
919
644
  self.planner = (
920
- OpenAILMM(temperature=0.0, json_mode=True) if planner is None else planner
645
+ OpenAIVisionAgentPlanner(verbosity=verbosity)
646
+ if planner is None
647
+ else planner
921
648
  )
922
649
  self.coder = OpenAILMM(temperature=0.0) if coder is None else coder
923
650
  self.tester = OpenAILMM(temperature=0.0) if tester is None else tester
@@ -926,13 +653,8 @@ class OpenAIVisionAgentCoder(VisionAgentCoder):
926
653
  if self.verbosity > 0:
927
654
  _LOGGER.setLevel(logging.INFO)
928
655
 
929
- self.tool_recommender = (
930
- Sim(T.TOOLS_DF, sim_key="desc")
931
- if tool_recommender is None
932
- else tool_recommender
933
- )
934
656
  self.report_progress_callback = report_progress_callback
935
- self.code_sandbox_runtime = code_sandbox_runtime
657
+ self.code_interpreter = code_interpreter
936
658
 
937
659
 
938
660
  class AnthropicVisionAgentCoder(VisionAgentCoder):
@@ -940,17 +662,20 @@ class AnthropicVisionAgentCoder(VisionAgentCoder):
940
662
 
941
663
  def __init__(
942
664
  self,
943
- planner: Optional[LMM] = None,
665
+ planner: Optional[Agent] = None,
944
666
  coder: Optional[LMM] = None,
945
667
  tester: Optional[LMM] = None,
946
668
  debugger: Optional[LMM] = None,
947
- tool_recommender: Optional[Sim] = None,
948
669
  verbosity: int = 0,
949
670
  report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
950
- code_sandbox_runtime: Optional[str] = None,
671
+ code_interpreter: Optional[Union[str, CodeInterpreter]] = None,
951
672
  ) -> None:
952
673
  # NOTE: Claude doesn't have an official JSON mode
953
- self.planner = AnthropicLMM(temperature=0.0) if planner is None else planner
674
+ self.planner = (
675
+ AnthropicVisionAgentPlanner(verbosity=verbosity)
676
+ if planner is None
677
+ else planner
678
+ )
954
679
  self.coder = AnthropicLMM(temperature=0.0) if coder is None else coder
955
680
  self.tester = AnthropicLMM(temperature=0.0) if tester is None else tester
956
681
  self.debugger = AnthropicLMM(temperature=0.0) if debugger is None else debugger
@@ -958,15 +683,8 @@ class AnthropicVisionAgentCoder(VisionAgentCoder):
958
683
  if self.verbosity > 0:
959
684
  _LOGGER.setLevel(logging.INFO)
960
685
 
961
- # Anthropic does not offer any embedding models and instead recomends Voyage,
962
- # we're using OpenAI's embedder for now.
963
- self.tool_recommender = (
964
- Sim(T.TOOLS_DF, sim_key="desc")
965
- if tool_recommender is None
966
- else tool_recommender
967
- )
968
686
  self.report_progress_callback = report_progress_callback
969
- self.code_sandbox_runtime = code_sandbox_runtime
687
+ self.code_interpreter = code_interpreter
970
688
 
971
689
 
972
690
  class OllamaVisionAgentCoder(VisionAgentCoder):
@@ -988,17 +706,17 @@ class OllamaVisionAgentCoder(VisionAgentCoder):
988
706
 
989
707
  def __init__(
990
708
  self,
991
- planner: Optional[LMM] = None,
709
+ planner: Optional[Agent] = None,
992
710
  coder: Optional[LMM] = None,
993
711
  tester: Optional[LMM] = None,
994
712
  debugger: Optional[LMM] = None,
995
- tool_recommender: Optional[Sim] = None,
996
713
  verbosity: int = 0,
997
714
  report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
715
+ code_interpreter: Optional[Union[str, CodeInterpreter]] = None,
998
716
  ) -> None:
999
717
  super().__init__(
1000
718
  planner=(
1001
- OllamaLMM(model_name="llama3.1", temperature=0.0, json_mode=True)
719
+ OllamaVisionAgentPlanner(verbosity=verbosity)
1002
720
  if planner is None
1003
721
  else planner
1004
722
  ),
@@ -1017,13 +735,9 @@ class OllamaVisionAgentCoder(VisionAgentCoder):
1017
735
  if debugger is None
1018
736
  else debugger
1019
737
  ),
1020
- tool_recommender=(
1021
- OllamaSim(T.TOOLS_DF, sim_key="desc")
1022
- if tool_recommender is None
1023
- else tool_recommender
1024
- ),
1025
738
  verbosity=verbosity,
1026
739
  report_progress_callback=report_progress_callback,
740
+ code_interpreter=code_interpreter,
1027
741
  )
1028
742
 
1029
743
 
@@ -1043,22 +757,22 @@ class AzureVisionAgentCoder(VisionAgentCoder):
1043
757
 
1044
758
  def __init__(
1045
759
  self,
1046
- planner: Optional[LMM] = None,
760
+ planner: Optional[Agent] = None,
1047
761
  coder: Optional[LMM] = None,
1048
762
  tester: Optional[LMM] = None,
1049
763
  debugger: Optional[LMM] = None,
1050
- tool_recommender: Optional[Sim] = None,
1051
764
  verbosity: int = 0,
1052
765
  report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
766
+ code_interpreter: Optional[Union[str, CodeInterpreter]] = None,
1053
767
  ) -> None:
1054
768
  """Initialize the Vision Agent Coder.
1055
769
 
1056
770
  Parameters:
1057
- planner (Optional[LMM]): The planner model to use. Defaults to OpenAILMM.
771
+ planner (Optional[Agent]): The planner model to use. Defaults to
772
+ AzureVisionAgentPlanner.
1058
773
  coder (Optional[LMM]): The coder model to use. Defaults to OpenAILMM.
1059
774
  tester (Optional[LMM]): The tester model to use. Defaults to OpenAILMM.
1060
775
  debugger (Optional[LMM]): The debugger model to
1061
- tool_recommender (Optional[Sim]): The tool recommender model to use.
1062
776
  verbosity (int): The verbosity level of the agent. Defaults to 0. 2 is the
1063
777
  highest verbosity level which will output all intermediate debugging
1064
778
  code.
@@ -1069,7 +783,7 @@ class AzureVisionAgentCoder(VisionAgentCoder):
1069
783
  """
1070
784
  super().__init__(
1071
785
  planner=(
1072
- AzureOpenAILMM(temperature=0.0, json_mode=True)
786
+ AzureVisionAgentPlanner(verbosity=verbosity)
1073
787
  if planner is None
1074
788
  else planner
1075
789
  ),
@@ -1078,11 +792,7 @@ class AzureVisionAgentCoder(VisionAgentCoder):
1078
792
  debugger=(
1079
793
  AzureOpenAILMM(temperature=0.0) if debugger is None else debugger
1080
794
  ),
1081
- tool_recommender=(
1082
- AzureSim(T.TOOLS_DF, sim_key="desc")
1083
- if tool_recommender is None
1084
- else tool_recommender
1085
- ),
1086
795
  verbosity=verbosity,
1087
796
  report_progress_callback=report_progress_callback,
797
+ code_interpreter=code_interpreter,
1088
798
  )