vision-agent 0.2.30__py3-none-any.whl → 0.2.32__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,778 +1,456 @@
+ import copy
  import json
  import logging
  import sys
- import tempfile
  from pathlib import Path
- from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
+ from typing import Any, Callable, Dict, List, Optional, Union, cast

- from PIL import Image
+ from rich.console import Console
+ from rich.syntax import Syntax
  from tabulate import tabulate

- from vision_agent.agent.agent import Agent
- from vision_agent.agent.easytool_prompts import (
-     ANSWER_GENERATE,
-     ANSWER_SUMMARIZE,
-     CHOOSE_PARAMETER,
-     CHOOSE_TOOL,
-     TASK_DECOMPOSE,
-     TASK_TOPOLOGY,
- )
+ from vision_agent.agent import Agent
  from vision_agent.agent.vision_agent_prompts import (
-     ANSWER_GENERATE_DEPENDS,
-     ANSWER_SUMMARIZE_DEPENDS,
-     CHOOSE_PARAMETER_DEPENDS,
-     CHOOSE_TOOL_DEPENDS,
-     TASK_DECOMPOSE_DEPENDS,
-     VISION_AGENT_REFLECTION,
+     CODE,
+     FEEDBACK,
+     FIX_BUG,
+     FULL_TASK,
+     PLAN,
+     REFLECT,
+     SIMPLE_TEST,
+     USER_REQ,
  )
  from vision_agent.llm import LLM, OpenAILLM
  from vision_agent.lmm import LMM, OpenAILMM
- from vision_agent.tools import TOOLS
- from vision_agent.utils.image_utils import (
-     convert_to_b64,
-     overlay_bboxes,
-     overlay_heat_map,
-     overlay_masks,
- )
+ from vision_agent.tools import TOOL_DESCRIPTIONS, TOOLS_DF, UTILITIES_DOCSTRING
+ from vision_agent.utils import Execute
+ from vision_agent.utils.sim import Sim

  logging.basicConfig(stream=sys.stdout)
  _LOGGER = logging.getLogger(__name__)
  _MAX_TABULATE_COL_WIDTH = 80
+ _EXECUTE = Execute(600)
+ _CONSOLE = Console()


- def parse_json(s: str) -> Any:
-     s = (
-         s.replace(": True", ": true")
-         .replace(": False", ": false")
-         .replace(":True", ": true")
-         .replace(":False", ": false")
-         .replace("```", "")
-         .strip()
-     )
-     return json.loads(s)
-
-
- def change_name(name: str) -> str:
-     change_list = ["from", "class", "return", "false", "true", "id", "and", "", "ID"]
-     if name in change_list:
-         name = "is_" + name.lower()
-     return name
-
-
- def format_tools(tools: Dict[int, Any]) -> str:
-     # Format this way so it's clear what the ID's are
-     tool_str = ""
-     for key in tools:
-         tool_str += f"ID: {key} - {tools[key]}\n"
-     return tool_str
-
-
- def format_tool_usage(tools: Dict[int, Any], tool_result: List[Dict]) -> str:
-     usage = []
-     name_to_usage = {v["name"]: v["usage"] for v in tools.values()}
-     for tool_res in tool_result:
-         if "tool_name" in tool_res:
-             usage.append((tool_res["tool_name"], name_to_usage[tool_res["tool_name"]]))
-
-     usage_str = ""
-     for tool_name, tool_usage in usage:
-         usage_str += f"{tool_name} - {tool_usage}\n"
-     return usage_str
-
-
- def topological_sort(tasks: List[Dict]) -> List[Dict]:
-     in_degree = {task["id"]: 0 for task in tasks}
-     for task in tasks:
-         for dep in task["dep"]:
-             if dep in in_degree:
-                 in_degree[task["id"]] += 1
-
-     queue = [task for task in tasks if in_degree[task["id"]] == 0]
-     sorted_order = []
-
-     while queue:
-         current = queue.pop(0)
-         sorted_order.append(current)
-
-         for task in tasks:
-             if current["id"] in task["dep"]:
-                 in_degree[task["id"]] -= 1
-                 if in_degree[task["id"]] == 0:
-                     queue.append(task)
-
-     if len(sorted_order) != len(tasks):
-         completed_ids = set([task["id"] for task in sorted_order])
-         remaining_tasks = [task for task in tasks if task["id"] not in completed_ids]
-         sorted_order.extend(remaining_tasks)
-     return sorted_order
-
-
- def task_decompose(
-     model: Union[LLM, LMM, Agent],
-     question: str,
-     tools: Dict[int, Any],
-     reflections: str,
- ) -> Optional[Dict]:
-     if reflections:
-         prompt = TASK_DECOMPOSE_DEPENDS.format(
-             question=question, tools=format_tools(tools), reflections=reflections
-         )
-     else:
-         prompt = TASK_DECOMPOSE.format(question=question, tools=format_tools(tools))
-     tries = 0
-     str_result = ""
-     while True:
-         try:
-             str_result = model(prompt)
-             result = parse_json(str_result)
-             return result["Tasks"]  # type: ignore
-         except Exception:
-             if tries > 10:
-                 _LOGGER.error(f"Failed task_decompose on: {str_result}")
-                 return None
-             tries += 1
-             continue
-
-
- def task_topology(
-     model: Union[LLM, LMM, Agent], question: str, task_list: List[Dict]
- ) -> List[Dict[str, Any]]:
-     prompt = TASK_TOPOLOGY.format(question=question, task_list=task_list)
-     tries = 0
-     str_result = ""
-     while True:
-         try:
-             str_result = model(prompt)
-             result = parse_json(str_result)
-             for elt in result["Tasks"]:
-                 if isinstance(elt["dep"], str):
-                     elt["dep"] = [int(dep) for dep in elt["dep"].split(",")]
-                 elif isinstance(elt["dep"], int):
-                     elt["dep"] = [elt["dep"]]
-                 elif isinstance(elt["dep"], list):
-                     elt["dep"] = [int(dep) for dep in elt["dep"]]
-             return result["Tasks"]  # type: ignore
-         except Exception:
-             if tries > 10:
-                 _LOGGER.error(f"Failed task_topology on: {str_result}")
-                 return task_list
-             tries += 1
-             continue
-
-
- def choose_tool(
-     model: Union[LLM, LMM, Agent],
-     question: str,
-     tools: Dict[int, Any],
-     reflections: str,
- ) -> Optional[int]:
-     if reflections:
-         prompt = CHOOSE_TOOL_DEPENDS.format(
-             question=question, tools=format_tools(tools), reflections=reflections
-         )
-     else:
-         prompt = CHOOSE_TOOL.format(question=question, tools=format_tools(tools))
-     tries = 0
-     str_result = ""
-     while True:
-         try:
-             str_result = model(prompt)
-             result = parse_json(str_result)
-             return result["ID"]  # type: ignore
-         except Exception:
-             if tries > 10:
-                 _LOGGER.error(f"Failed choose_tool on: {str_result}")
-                 return None
-             tries += 1
-             continue
-
-
- def choose_parameter(
-     model: Union[LLM, LMM, Agent],
-     question: str,
-     tool_usage: Dict,
-     previous_log: str,
-     reflections: str,
- ) -> Optional[Any]:
-     # TODO: should format tool_usage
-     if reflections:
-         prompt = CHOOSE_PARAMETER_DEPENDS.format(
-             question=question,
-             tool_usage=tool_usage,
-             previous_log=previous_log,
-             reflections=reflections,
-         )
-     else:
-         prompt = CHOOSE_PARAMETER.format(
-             question=question, tool_usage=tool_usage, previous_log=previous_log
-         )
-     tries = 0
-     str_result = ""
-     while True:
-         try:
-             str_result = model(prompt)
-             result = parse_json(str_result)
-             return result["Parameters"]
-         except Exception:
-             if tries > 10:
-                 _LOGGER.error(f"Failed choose_parameter on: {str_result}")
-                 return None
-             tries += 1
-             continue
-
-
- def answer_generate(
-     model: Union[LLM, LMM, Agent],
-     question: str,
-     call_results: str,
-     previous_log: str,
-     reflections: str,
- ) -> str:
-     if reflections:
-         prompt = ANSWER_GENERATE_DEPENDS.format(
-             question=question,
-             call_results=call_results,
-             previous_log=previous_log,
-             reflections=reflections,
-         )
-     else:
-         prompt = ANSWER_GENERATE.format(
-             question=question, call_results=call_results, previous_log=previous_log
+ def format_memory(memory: List[Dict[str, str]]) -> str:
+     return FEEDBACK.format(
+         feedback="\n".join(
+             [
+                 f"### Feedback {i}:\nCode: ```python\n{m['code']}\n```\nFeedback: {m['feedback']}\n"
+                 for i, m in enumerate(memory)
+             ]
          )
-     return model(prompt)
+     )


- def answer_summarize(
-     model: Union[LLM, LMM, Agent], question: str, answers: List[Dict], reflections: str
- ) -> str:
-     if reflections:
-         prompt = ANSWER_SUMMARIZE_DEPENDS.format(
-             question=question, answers=answers, reflections=reflections
-         )
+ def extract_code(code: str) -> str:
+     if "\n```python" in code:
+         start = "\n```python"
+     elif "```python" in code:
+         start = "```python"
      else:
-         prompt = ANSWER_SUMMARIZE.format(question=question, answers=answers)
-     return model(prompt)
+         return code
+
+     code = code[code.find(start) + len(start) :]
+     code = code[: code.find("```")]
+     if code.startswith("python\n"):
+         code = code[len("python\n") :]
+     return code


- def function_call(tool: Callable, parameters: Dict[str, Any]) -> Any:
+ def extract_json(json_str: str) -> Dict[str, Any]:
      try:
-         return tool()(**parameters)
-     except Exception as e:
-         _LOGGER.error(f"Failed function_call on: {e}")
-         # return error message so it can self-correct
-         return str(e)
-
-
- def self_reflect(
-     reflect_model: Union[LLM, LMM],
-     question: str,
-     tools: Dict[int, Any],
-     tool_result: List[Dict],
-     final_answer: str,
-     images: Optional[Sequence[Union[str, Path]]] = None,
- ) -> str:
-     prompt = VISION_AGENT_REFLECTION.format(
-         question=question,
-         tools=format_tools({k: v["description"] for k, v in tools.items()}),
-         tool_usage=format_tool_usage(tools, tool_result),
-         tool_results=str(tool_result),
-         final_answer=final_answer,
+         json_dict = json.loads(json_str)
+     except json.JSONDecodeError:
+         if "```json" in json_str:
+             json_str = json_str[json_str.find("```json") + len("```json") :]
+             json_str = json_str[: json_str.find("```")]
+         elif "```" in json_str:
+             json_str = json_str[json_str.find("```") + len("```") :]
+             # get the last ``` not one from an intermediate string
+             json_str = json_str[: json_str.find("}```")]
+         json_dict = json.loads(json_str)
+     return json_dict  # type: ignore
+
+
+ def write_plan(
+     chat: List[Dict[str, str]],
+     tool_desc: str,
+     working_memory: str,
+     model: Union[LLM, LMM],
+     media: Optional[List[Union[str, Path]]] = None,
+ ) -> List[Dict[str, str]]:
+     chat = copy.deepcopy(chat)
+     if chat[-1]["role"] != "user":
+         raise ValueError("Last chat message must be from the user.")
+
+     user_request = chat[-1]["content"]
+     context = USER_REQ.format(user_request=user_request)
+     prompt = PLAN.format(context=context, tool_desc=tool_desc, feedback=working_memory)
+     chat[-1]["content"] = prompt
+     if isinstance(model, OpenAILMM):
+         return extract_json(model.chat(chat, images=media))["plan"]  # type: ignore
+     else:
+         return extract_json(model.chat(chat))["plan"]  # type: ignore
+
+
+ def reflect(
+     chat: List[Dict[str, str]],
+     plan: str,
+     code: str,
+     model: LLM,
+ ) -> Dict[str, Union[str, bool]]:
+     chat = copy.deepcopy(chat)
+     if chat[-1]["role"] != "user":
+         raise ValueError("Last chat message must be from the user.")
+
+     user_request = chat[-1]["content"]
+     context = USER_REQ.format(user_request=user_request)
+     prompt = REFLECT.format(context=context, plan=plan, code=code)
+     chat[-1]["content"] = prompt
+     return extract_json(model.chat(chat))
+
+
+ def write_and_test_code(
+     task: str,
+     tool_info: str,
+     tool_utils: str,
+     working_memory: str,
+     coder: LLM,
+     tester: LLM,
+     debugger: LLM,
+     log_progress: Callable[[Dict[str, Any]], None],
+     verbosity: int = 0,
+     max_retries: int = 3,
+     input_media: Optional[Union[str, Path]] = None,
+ ) -> Dict[str, Any]:
+     code = extract_code(
+         coder(CODE.format(docstring=tool_info, question=task, feedback=working_memory))
+     )
+     test = extract_code(
+         tester(
+             SIMPLE_TEST.format(
+                 docstring=tool_utils,
+                 question=task,
+                 code=code,
+                 feedback=working_memory,
+                 media=input_media,
+             )
+         )
      )
-     if (
-         issubclass(type(reflect_model), LMM)
-         and images is not None
-         and all([Path(image).suffix in [".jpg", ".jpeg", ".png"] for image in images])
-     ):
-         return reflect_model(prompt, images=images)  # type: ignore
-     return reflect_model(prompt)

+     success, result = _EXECUTE.run_isolation(f"{code}\n{test}")
+     if verbosity == 2:
+         _LOGGER.info("Initial code and tests:")
+         log_progress(
+             {
+                 "log": "Code:",
+                 "code": code,
+             }
+         )
+         log_progress(
+             {
+                 "log": "Test:",
+                 "code": test,
+             }
+         )
+         _CONSOLE.print(
+             Syntax(f"{code}\n{test}", "python", theme="gruvbox-dark", line_numbers=True)
+         )
+         log_progress(
+             {
+                 "log": "Result:",
+                 "result": result,
+             }
+         )
+         _LOGGER.info(f"Initial result: {result}")
+
+     count = 0
+     new_working_memory = []
+     while not success and count < max_retries:
+         fixed_code_and_test = extract_json(
+             debugger(
+                 FIX_BUG.format(
+                     code=code, tests=test, result=result, feedback=working_memory
+                 )
+             )
+         )
+         if fixed_code_and_test["code"].strip() != "":
+             code = extract_code(fixed_code_and_test["code"])
+         if fixed_code_and_test["test"].strip() != "":
+             test = extract_code(fixed_code_and_test["test"])
+         new_working_memory.append(
+             {"code": f"{code}\n{test}", "feedback": fixed_code_and_test["reflections"]}
+         )

- def parse_reflect(reflect: str) -> Any:
-     reflect = reflect.strip()
-     try:
-         return parse_json(reflect)
-     except Exception:
-         _LOGGER.error(f"Failed parse json reflection: {reflect}")
-         # LMMs have a hard time following directions, so make the criteria less strict
-         finish = (
-             "finish" in reflect.lower() and len(reflect) < 100
-         ) or "finish" in reflect.lower()[-10:]
-         return {"Finish": finish, "Reflection": reflect}
-
-
- def _handle_extract_frames(
-     image_to_data: Dict[str, Dict], tool_result: Dict
- ) -> Dict[str, Dict]:
-     image_to_data = image_to_data.copy()
-     # handle extract_frames_ case, useful if it extracts frames but doesn't do
-     # any following processing
-     for video_file_output in tool_result["call_results"]:
-         # When the video tool is run with wrong parameters, exit the loop
-         if not isinstance(video_file_output, tuple) or len(video_file_output) < 2:
-             break
-         for frame, _ in video_file_output:
-             image = frame
-             if image not in image_to_data:
-                 image_to_data[image] = {
-                     "bboxes": [],
-                     "masks": [],
-                     "heat_map": [],
-                     "labels": [],
-                     "scores": [],
+         success, result = _EXECUTE.run_isolation(f"{code}\n{test}")
+         if verbosity == 2:
+             log_progress(
+                 {
+                     "log": f"Debug attempt {count + 1}, reflection:",
+                     "result": fixed_code_and_test["reflections"],
                  }
-     return image_to_data
-
-
- def _handle_viz_tools(
-     image_to_data: Dict[str, Dict], tool_result: Dict
- ) -> Dict[str, Dict]:
-     image_to_data = image_to_data.copy()
-
-     # handle grounding_sam_ and grounding_dino_
-     parameters = tool_result["parameters"]
-     # parameters can either be a dictionary or list, parameters can also be malformed
-     # becaus the LLM builds them
-     if isinstance(parameters, dict):
-         if "image" not in parameters:
-             return image_to_data
-         parameters = [parameters]
-     elif isinstance(tool_result["parameters"], list):
-         if len(tool_result["parameters"]) < 1 or (
-             "image" not in tool_result["parameters"][0]
-         ):
-             return image_to_data
-
-     for param, call_result in zip(parameters, tool_result["call_results"]):
-         # Calls can fail, so we need to check if the call was successful. It can either:
-         # 1. return a str or some error that's not a dictionary
-         # 2. return a dictionary but not have the necessary keys
-
-         if not isinstance(call_result, dict) or (
-             "bboxes" not in call_result
-             and "mask" not in call_result
-             and "heat_map" not in call_result
-         ):
-             return image_to_data
-
-         # if the call was successful, then we can add the image data
-         image = param["image"]
-         if image not in image_to_data:
-             image_to_data[image] = {
-                 "bboxes": [],
-                 "masks": [],
-                 "heat_map": [],
-                 "labels": [],
-                 "scores": [],
-             }
+             )
+             _LOGGER.info(
+                 f"Debug attempt {count + 1}, reflection: {fixed_code_and_test['reflections']}"
+             )
+             _CONSOLE.print(
+                 Syntax(
+                     f"{code}\n{test}", "python", theme="gruvbox-dark", line_numbers=True
+                 )
+             )
+             log_progress(
+                 {
+                     "log": "Debug result:",
+                     "result": result,
+                 }
+             )
+             _LOGGER.info(f"Debug result: {result}")
+         count += 1

-         image_to_data[image]["bboxes"].extend(call_result.get("bboxes", []))
-         image_to_data[image]["labels"].extend(call_result.get("labels", []))
-         image_to_data[image]["scores"].extend(call_result.get("scores", []))
-         image_to_data[image]["masks"].extend(call_result.get("masks", []))
-         # only single heatmap is returned
-         if "heat_map" in call_result:
-             image_to_data[image]["heat_map"].append(call_result["heat_map"])
-         if "mask_shape" in call_result:
-             image_to_data[image]["mask_shape"] = call_result["mask_shape"]
-
-     return image_to_data
-
-
- def sample_n_evenly_spaced(lst: Sequence, n: int) -> Sequence:
-     if n <= 0:
-         return []
-     elif len(lst) == 0:
-         return []
-     elif n == 1:
-         return [lst[0]]
-     elif n >= len(lst):
-         return lst
-
-     spacing = (len(lst) - 1) / (n - 1)
-     return [lst[round(spacing * i)] for i in range(n)]
-
-
- def visualize_result(all_tool_results: List[Dict]) -> Sequence[Union[str, Path]]:
-     image_to_data: Dict[str, Dict] = {}
-     for tool_result in all_tool_results:
-         # only handle bbox/mask tools or frame extraction
-         if tool_result["tool_name"] not in [
-             "grounding_sam_",
-             "grounding_dino_",
-             "extract_frames_",
-             "dinov_",
-             "zero_shot_counting_",
-             "visual_prompt_counting_",
-             "ocr_",
-         ]:
-             continue
-
-         if tool_result["tool_name"] == "extract_frames_":
-             image_to_data = _handle_extract_frames(image_to_data, tool_result)
-         else:
-             image_to_data = _handle_viz_tools(image_to_data, tool_result)
-
-     visualized_images = []
-     for image_str in image_to_data:
-         image_path = Path(image_str)
-         image_data = image_to_data[image_str]
-         if "_counting_" in tool_result["tool_name"]:
-             image = overlay_heat_map(image_path, image_data)
-         else:
-             image = overlay_masks(image_path, image_data)
-             image = overlay_bboxes(image, image_data)
-         with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
-             image.save(f.name)
-             visualized_images.append(f.name)
-     return visualized_images
+     if verbosity >= 1:
+         _LOGGER.info("Final code and tests:")
+         _CONSOLE.print(
+             Syntax(f"{code}\n{test}", "python", theme="gruvbox-dark", line_numbers=True)
+         )
+         _LOGGER.info(f"Final Result: {result}")
+
+     return {
+         "code": code,
+         "test": test,
+         "success": success,
+         "test_result": result,
+         "working_memory": new_working_memory,
+     }
+
+
+ def retrieve_tools(
+     plan: List[Dict[str, str]],
+     tool_recommender: Sim,
+     log_progress: Callable[[Dict[str, Any]], None],
+     verbosity: int = 0,
+ ) -> str:
+     tool_info = []
+     tool_desc = []
+     for task in plan:
+         tools = tool_recommender.top_k(task["instructions"], k=2, thresh=0.3)
+         tool_info.extend([e["doc"] for e in tools])
+         tool_desc.extend([e["desc"] for e in tools])
+     if verbosity == 2:
+         log_progress(
+             {
+                 "log": "Retrieved tools:",
+                 "tools": tool_desc,
+             }
+         )
+         _LOGGER.info(f"Tools: {tool_desc}")
+     tool_info_set = set(tool_info)
+     return "\n\n".join(tool_info_set)


  class VisionAgent(Agent):
-     r"""Vision Agent is an agent framework that utilizes tools as well as self
-     reflection to accomplish tasks, in particular vision tasks. Vision Agent is based
-     off of EasyTool https://arxiv.org/abs/2401.06201 and Reflexion
-     https://arxiv.org/abs/2303.11366 where it will attempt to complete a task and then
-     reflect on whether or not it was able to accomplish the task based off of the plan
-     and final results, if not it will redo the task with this newly added reflection.
+ """Vision Agent is an agentic framework that can output code based on a user
254
+ request. It can plan tasks, retrieve relevant tools, write code, write tests and
255
+ reflect on failed test cases to debug code. It is inspired by AgentCoder
256
+ https://arxiv.org/abs/2312.13010 and Data Interpeter
257
+ https://arxiv.org/abs/2402.18679

      Example
      -------
-         >>> from vision_agent.agent import VisionAgent
+         >>> from vision_agent import VisionAgent
          >>> agent = VisionAgent()
-         >>> resp = agent("If red tomatoes cost $5 each and yellow tomatoes cost $2.50 each, what is the total cost of all the tomatoes in the image?", image="tomatoes.jpg")
-         >>> print(resp)
-         "The total cost is $57.50."
+         >>> code = agent("What percentage of the area of the jar is filled with coffee beans?", media="jar.jpg")
      """

      def __init__(
          self,
-         task_model: Optional[Union[LLM, LMM]] = None,
-         answer_model: Optional[Union[LLM, LMM]] = None,
-         reflect_model: Optional[Union[LLM, LMM]] = None,
-         max_retries: int = 2,
-         verbose: bool = False,
+         planner: Optional[LLM] = None,
+         coder: Optional[LLM] = None,
+         tester: Optional[LLM] = None,
+         debugger: Optional[LLM] = None,
+         tool_recommender: Optional[Sim] = None,
+         verbosity: int = 0,
          report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
-     ):
-         """VisionAgent constructor.
+     ) -> None:
+         """Initialize the Vision Agent.

          Parameters:
-             task_model: the model to use for task decomposition.
-             answer_model: the model to use for reasoning and concluding the answer.
-             reflect_model: the model to use for self reflection.
-             max_retries: maximum number of retries to attempt to complete the task.
-             verbose: whether to print more logs.
-             report_progress_callback: a callback to report the progress of the agent. This is useful for streaming logs in a web application where multiple VisionAgent instances are running in parallel. This callback ensures that the progress are not mixed up.
+             planner (Optional[LLM]): The planner model to use. Defaults to OpenAILLM.
+             coder (Optional[LLM]): The coder model to use. Defaults to OpenAILLM.
+             tester (Optional[LLM]): The tester model to use. Defaults to OpenAILLM.
+             debugger (Optional[LLM]): The debugger model to use. Defaults to OpenAILLM.
+             tool_recommender (Optional[Sim]): The tool recommender model to use.
+             verbosity (int): The verbosity level of the agent. Defaults to 0. 2 is the
+                 highest verbosity level which will output all intermediate debugging
+                 code.
+             report_progress_callback: a callback to report the progress of the agent.
+                 This is useful for streaming logs in a web application where multiple
+                 VisionAgent instances are running in parallel. This callback ensures
+                 that the progress is not mixed up.
          """
-         self.task_model = (
-             OpenAILLM(model_name="gpt-4-turbo", json_mode=True, temperature=0.0)
-             if task_model is None
-             else task_model
+
+         self.planner = (
+             OpenAILLM(temperature=0.0, json_mode=True) if planner is None else planner
          )
-         self.answer_model = (
-             OpenAILLM(model_name="gpt-4-turbo", temperature=0.0)
-             if answer_model is None
-             else answer_model
+         self.coder = OpenAILLM(temperature=0.0) if coder is None else coder
+         self.tester = OpenAILLM(temperature=0.0) if tester is None else tester
+         self.debugger = (
+             OpenAILLM(temperature=0.0, json_mode=True) if debugger is None else debugger
          )
-         self.reflect_model = (
-             OpenAILMM(model_name="gpt-4-turbo", json_mode=True, temperature=0.0)
-             if reflect_model is None
-             else reflect_model
+
+         self.tool_recommender = (
+             Sim(TOOLS_DF, sim_key="desc")
+             if tool_recommender is None
+             else tool_recommender
          )
-         self.max_retries = max_retries
-         self.tools = TOOLS
+         self.verbosity = verbosity
+         self.max_retries = 2
          self.report_progress_callback = report_progress_callback
-         if verbose:
-             _LOGGER.setLevel(logging.INFO)

      def __call__(
          self,
          input: Union[List[Dict[str, str]], str],
-         image: Optional[Union[str, Path]] = None,
-         reference_data: Optional[Dict[str, str]] = None,
-         visualize_output: Optional[bool] = False,
-         self_reflection: Optional[bool] = True,
+         media: Optional[Union[str, Path]] = None,
      ) -> str:
-         """Invoke the vision agent.
+         """Chat with Vision Agent and return intermediate information regarding the task.

          Parameters:
-             chat: A conversation in the format of
+             input (Union[List[Dict[str, str]], str]): A conversation in the format of
                  [{"role": "user", "content": "describe your task here..."}].
-             image: The input image referenced in the chat parameter.
-             reference_data: A dictionary containing the reference image, mask or bounding
-                 box in the format of:
-                 {"image": "image.jpg", "mask": "mask.jpg", "bbox": [0.1, 0.2, 0.1, 0.2]}
-                 where the bounding box coordinates are normalized.
-             visualize_output: Whether to visualize the output.
-             self_reflection: boolean to enable and disable self reflection.
+             media (Optional[Union[str, Path]]): The media file to be used in the task.

          Returns:
-             The result of the vision agent in text.
+             str: The code output by the Vision Agent.
          """
+
          if isinstance(input, str):
              input = [{"role": "user", "content": input}]
-         return self.chat(
-             input,
-             image=image,
-             visualize_output=visualize_output,
-             reference_data=reference_data,
-             self_reflection=self_reflection,
-         )
-
-     def log_progress(self, data: Dict[str, Any]) -> None:
-         _LOGGER.info(data)
-         if self.report_progress_callback:
-             self.report_progress_callback(data)
-
-     def _report_visualization_via_callback(
-         self, images: Sequence[Union[str, Path]]
-     ) -> None:
-         """This is intended for streaming the visualization images via the callback to the client side."""
-         if self.report_progress_callback:
-             self.report_progress_callback({"log": "<VIZ>"})
-             if images:
-                 for img in images:
-                     self.report_progress_callback(
-                         {"log": f"<IMG>base:64{convert_to_b64(img)}</IMG>"}
-                     )
-             self.report_progress_callback({"log": "</VIZ>"})
+         results = self.chat_with_workflow(input, media)
+         results.pop("working_memory")
+         return results  # type: ignore

      def chat_with_workflow(
          self,
          chat: List[Dict[str, str]],
-         image: Optional[Union[str, Path]] = None,
-         reference_data: Optional[Dict[str, str]] = None,
-         visualize_output: Optional[bool] = False,
-         self_reflection: Optional[bool] = True,
-     ) -> Tuple[str, List[Dict]]:
-         """Chat with the vision agent and return the final answer and all tool results.
+         media: Optional[Union[str, Path]] = None,
+         self_reflection: bool = False,
+     ) -> Dict[str, Any]:
+         """Chat with Vision Agent and return intermediate information regarding the task.

          Parameters:
-             chat: A conversation in the format of
+             chat (List[Dict[str, str]]): A conversation in the format of
                  [{"role": "user", "content": "describe your task here..."}].
-             image: The input image referenced in the chat parameter.
-             reference_data: A dictionary containing the reference image, mask or bounding
-                 box in the format of:
-                 {"image": "image.jpg", "mask": "mask.jpg", "bbox": [0.1, 0.2, 0.1, 0.2]}
-                 where the bounding box coordinates are normalized.
-             visualize_output: Whether to visualize the output.
-             self_reflection: boolean to enable and disable self reflection.
+             media (Optional[Union[str, Path]]): The media file to be used in the task.
+             self_reflection (bool): Whether to reflect on the task and debug the code.

          Returns:
-             A tuple where the first item is the final answer and the second item is a
-             list of all the tool results. The last item in the tool results also
-             contains the visualized output.
+             Dict[str, Any]: A dictionary containing the code, test, test result, plan,
+                 and working memory of the agent.
          """
-         if len(chat) == 0:
-             raise ValueError("Input cannot be empty.")
-
-         question = chat[0]["content"]
-         if image:
-             question += f" Image name: {image}"
-         if reference_data:
-             question += (
-                 f" Reference image: {reference_data['image']}"
-                 if "image" in reference_data
-                 else ""
-             )
-             question += (
-                 f" Reference mask: {reference_data['mask']}"
-                 if "mask" in reference_data
-                 else ""
-             )
-             question += (
-                 f" Reference bbox: {reference_data['bbox']}"
-                 if "bbox" in reference_data
-                 else ""
-             )
-
-         reflections = ""
-         final_answer = ""
-         all_tool_results: List[Dict] = []

-         for _ in range(self.max_retries):
-             task_list = self.create_tasks(
-                 self.task_model, question, self.tools, reflections
+         if len(chat) == 0:
+             raise ValueError("Chat cannot be empty.")
+
+         if media is not None:
+             for chat_i in chat:
+                 if chat_i["role"] == "user":
+                     chat_i["content"] += f" Image name {media}"
+
+         code = ""
+         test = ""
+         working_memory: List[Dict[str, str]] = []
+         results = {"code": "", "test": "", "plan": []}
+         plan = []
+         success = False
+         retries = 0
+
+         while not success and retries < self.max_retries:
+             plan_i = write_plan(
+                 chat,
+                 TOOL_DESCRIPTIONS,
+                 format_memory(working_memory),
+                 self.planner,
+                 media=[media] if media else None,
              )
-
-             task_depend = {"Original Question": question}
-             previous_log = ""
-             answers = []
-             for task in task_list:
-                 task_depend[task["id"]] = {"task": task["task"], "answer": "", "call_result": ""}  # type: ignore
-             all_tool_results = []
-
-             for task in task_list:
-                 task_str = task["task"]
-                 previous_log = str(task_depend)
-                 tool_results, call_results = self.retrieval(
-                     self.task_model,
-                     task_str,
-                     self.tools,
-                     previous_log,
-                     reflections,
-                 )
-                 answer = answer_generate(
-                     self.answer_model, task_str, call_results, previous_log, reflections
+             plan_i_str = "\n-".join([e["instructions"] for e in plan_i])
+             if self.verbosity >= 1:
+                 self.log_progress(
+                     {
+                         "log": "Going to run the following plan(s) in sequence:\n",
+                         "plan": plan_i,
+                     }
                  )

-                 tool_results["answer"] = answer
-                 all_tool_results.append(tool_results)
+                 _LOGGER.info(
+                     f"""
+ {tabulate(tabular_data=plan_i, headers="keys", tablefmt="mixed_grid", maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"""
+                 )

-                 self.log_progress({"log": f"\tCall Result: {call_results}"})
-                 self.log_progress({"log": f"\tAnswer: {answer}"})
-                 answers.append({"task": task_str, "answer": answer})
-                 task_depend[task["id"]]["answer"] = answer  # type: ignore
-                 task_depend[task["id"]]["call_result"] = call_results  # type: ignore
-             final_answer = answer_summarize(
-                 self.answer_model, question, answers, reflections
+             tool_info = retrieve_tools(
+                 plan_i,
+                 self.tool_recommender,
+                 self.log_progress,
+                 self.verbosity,
              )
-             visualized_output = visualize_result(all_tool_results)
-             all_tool_results.append({"visualized_output": visualized_output})
-             if len(visualized_output) > 0:
-                 reflection_images = sample_n_evenly_spaced(visualized_output, 3)
-             elif image is not None:
-                 reflection_images = [image]
-             else:
-                 reflection_images = None
+             results = write_and_test_code(
+                 FULL_TASK.format(user_request=chat[0]["content"], subtasks=plan_i_str),
+                 tool_info,
+                 UTILITIES_DOCSTRING,
+                 format_memory(working_memory),
+                 self.coder,
+                 self.tester,
+                 self.debugger,
+                 self.log_progress,
+                 verbosity=self.verbosity,
+                 input_media=media,
+             )
+             success = cast(bool, results["success"])
+             code = cast(str, results["code"])
+             test = cast(str, results["test"])
+             working_memory.extend(results["working_memory"])  # type: ignore
+             plan.append({"code": code, "test": test, "plan": plan_i})

              if self_reflection:
-                 reflection = self_reflect(
-                     self.reflect_model,
-                     question,
-                     self.tools,
-                     all_tool_results,
-                     final_answer,
-                     reflection_images,
-                 )
-                 self.log_progress({"log": f"Reflection: {reflection}"})
-                 parsed_reflection = parse_reflect(reflection)
-                 if parsed_reflection["Finish"]:
-                     break
-                 else:
-                     reflections += "\n" + parsed_reflection["Reflection"]
-             else:
-                 self.log_progress(
-                     {"log": "Self Reflection skipped based on user request."}
+                 reflection = reflect(
+                     chat,
+                     FULL_TASK.format(
+                         user_request=chat[0]["content"], subtasks=plan_i_str
+                     ),
+                     code,
+                     self.planner,
                  )
-                 break
-         # '<ANSWER>' is a symbol to indicate the end of the chat, which is useful for streaming logs.
-         self.log_progress(
-             {
-                 "log": f"The Vision Agent has concluded this chat. <ANSWER>{final_answer}</ANSWER>"
-             }
-         )
-
-         if visualize_output:
-             viz_images: Sequence[Union[str, Path]] = all_tool_results[-1][
-                 "visualized_output"
-             ]
-             self._report_visualization_via_callback(viz_images)
-             for img in viz_images:
-                 Image.open(img).show()
-
-         return final_answer, all_tool_results
-
-     def chat(
-         self,
-         chat: List[Dict[str, str]],
-         image: Optional[Union[str, Path]] = None,
-         reference_data: Optional[Dict[str, str]] = None,
-         visualize_output: Optional[bool] = False,
-         self_reflection: Optional[bool] = True,
-     ) -> str:
-         answer, _ = self.chat_with_workflow(
-             chat,
-             image=image,
-             visualize_output=visualize_output,
-             reference_data=reference_data,
-             self_reflection=self_reflection,
-         )
-         return answer
-
-     def retrieval(
-         self,
-         model: Union[LLM, LMM, Agent],
-         question: str,
-         tools: Dict[int, Any],
-         previous_log: str,
-         reflections: str,
-     ) -> Tuple[Dict, str]:
-         tool_id = choose_tool(
-             model,
-             question,
-             {k: v["description"] for k, v in tools.items()},
-             reflections,
-         )
-         if tool_id is None:
-             return {}, ""
-
-         tool_instructions = tools[tool_id]
-         tool_usage = tool_instructions["usage"]
-         tool_name = tool_instructions["name"]
+                 if self.verbosity > 0:
+                     self.log_progress(
+                         {
+                             "log": "Reflection:",
+                             "reflection": reflection,
+                         }
+                     )
+                     _LOGGER.info(f"Reflection: {reflection}")
+                 feedback = cast(str, reflection["feedback"])
+                 success = cast(bool, reflection["success"])
+                 working_memory.append({"code": f"{code}\n{test}", "feedback": feedback})

-         parameters = choose_parameter(
-             model, question, tool_usage, previous_log, reflections
-         )
-         if parameters is None:
-             return {}, ""
-         tool_results = {
-             "task": question,
-             "tool_name": tool_name,
-             "parameters": parameters,
-         }
+             retries += 1

          self.log_progress(
              {
-                 "log": f"""Going to run the following tool(s) in sequence:
- {tabulate(tabular_data=[tool_results], headers="keys", tablefmt="mixed_grid", maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"""
+                 "log": f"Vision Agent has concluded this chat.\nSuccess: {success}",
+                 "finished": True,
              }
          )

-         def parse_tool_results(result: Dict[str, Union[Dict, List]]) -> Any:
-             call_results: List[Any] = []
-             if isinstance(result["parameters"], Dict):
-                 call_results.append(
-                     function_call(tools[tool_id]["class"], result["parameters"])
-                 )
-             elif isinstance(result["parameters"], List):
-                 for parameters in result["parameters"]:
-                     call_results.append(
-                         function_call(tools[tool_id]["class"], parameters)
-                     )
-             return call_results
-
-         call_results = parse_tool_results(tool_results)
-         tool_results["call_results"] = call_results
-
-         call_results_str = str(call_results)
-         return tool_results, call_results_str
+         return {
+             "code": code,
+             "test": test,
+             "test_result": results["test_result"],
+             "plan": plan,
+             "working_memory": working_memory,
+         }

-     def create_tasks(
-         self,
-         task_model: Union[LLM, LMM],
-         question: str,
-         tools: Dict[int, Any],
-         reflections: str,
-     ) -> List[Dict]:
-         tasks = task_decompose(
-             task_model,
-             question,
-             {k: v["description"] for k, v in tools.items()},
-             reflections,
-         )
-         if tasks is not None:
-             task_list = [{"task": task, "id": i + 1} for i, task in enumerate(tasks)]
-             task_list = task_topology(task_model, question, task_list)
-             try:
-                 task_list = topological_sort(task_list)
-             except Exception:
-                 _LOGGER.error(f"Failed topological_sort on: {task_list}")
-         else:
-             task_list = []
-         self.log_progress(
-             {
-                 "log": "Planned tasks:",
-                 "plan": task_list,
-             }
-         )
-         return task_list
+     def log_progress(self, data: Dict[str, Any]) -> None:
+         if self.report_progress_callback is not None:
+             self.report_progress_callback(data)
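
Taken together, this release replaces the EasyTool-style tool-calling loop with a plan/code/test/debug workflow that returns generated code rather than a text answer. As a rough orientation only, here is a minimal sketch of driving the new interface; it is based solely on the signatures and docstrings visible in this diff, and the media file, configured OpenAI credentials, and the printed keys are assumptions:

    from vision_agent import VisionAgent

    agent = VisionAgent(verbosity=1)
    # chat_with_workflow returns a dict with "code", "test", "test_result",
    # "plan", and "working_memory" instead of a free-text answer.
    results = agent.chat_with_workflow(
        [{"role": "user", "content": "Count the coffee beans in the jar."}],
        media="jar.jpg",
    )
    print(results["code"])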