vision-agent 0.2.192__py3-none-any.whl → 0.2.195__py3-none-any.whl

@@ -0,0 +1,341 @@
+import copy
+from pathlib import Path
+from typing import Any, Callable, Dict, List, Optional, Sequence, Union, cast
+
+from rich.console import Console
+from rich.markup import escape
+
+import vision_agent.tools as T
+from vision_agent.agent import Agent
+from vision_agent.agent.agent_utils import (
+    CodeContext,
+    DefaultImports,
+    PlanContext,
+    add_media_to_chat,
+    capture_media_from_exec,
+    extract_tag,
+    format_feedback,
+    format_plan_v2,
+    print_code,
+    strip_function_calls,
+)
+from vision_agent.agent.vision_agent_coder_prompts_v2 import CODE, FIX_BUG, TEST
+from vision_agent.agent.vision_agent_planner_v2 import VisionAgentPlannerV2
+from vision_agent.lmm import LMM, AnthropicLMM
+from vision_agent.lmm.types import Message
+from vision_agent.tools.meta_tools import get_diff
+from vision_agent.utils.execute import (
+    CodeInterpreter,
+    CodeInterpreterFactory,
+    Execution,
+)
+from vision_agent.utils.sim import Sim, load_cached_sim
+
+_CONSOLE = Console()
+
+
+def retrieve_tools(
+    plan: List[str],
+    tool_recommender: Sim,
+) -> str:
+    tool_docs = []
+    for inst in plan:
+        tools = tool_recommender.top_k(inst, k=1, thresh=0.3)
+        tool_docs.extend([e["doc"] for e in tools])
+
+    tool_docs_str = "\n\n".join(set(tool_docs))
+    return tool_docs_str
+
+
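For context: retrieve_tools embeds each plan instruction and keeps only the single closest tool docstring per instruction (k=1, similarity threshold 0.3), de-duplicating docs before joining them. A minimal sketch of calling it from this module, reusing the same load_cached_sim/TOOLS_DF pair the constructor below falls back to:

    import vision_agent.tools as T
    from vision_agent.utils.sim import load_cached_sim

    # Embedding index over the tool docstrings shipped with the package.
    tool_recommender = load_cached_sim(T.TOOLS_DF)

    # One de-duplicated docstring per plan instruction, joined with blank lines.
    tool_docs = retrieve_tools(
        ["detect all people in the image", "count the detected people"],
        tool_recommender,
    )
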
+def write_code(
+    coder: LMM,
+    chat: List[Message],
+    tool_docs: str,
+    plan: str,
+) -> str:
+    chat = copy.deepcopy(chat)
+    if chat[-1]["role"] != "user":
+        raise ValueError("Last chat message must be from the user.")
+
+    user_request = chat[-1]["content"]
+    prompt = CODE.format(
+        docstring=tool_docs,
+        question=user_request,
+        plan=plan,
+    )
+    chat[-1]["content"] = prompt
+    response = coder(chat, stream=False)
+    return extract_tag(response, "code")  # type: ignore
+
+
+def write_test(
+    tester: LMM,
+    chat: List[Message],
+    tool_util_docs: str,
+    code: str,
+    media_list: Optional[Sequence[Union[str, Path]]] = None,
+) -> str:
+    chat = copy.deepcopy(chat)
+    if chat[-1]["role"] != "user":
+        raise ValueError("Last chat message must be from the user.")
+
+    user_request = chat[-1]["content"]
+    prompt = TEST.format(
+        docstring=tool_util_docs,
+        question=user_request,
+        code=code,
+        media=media_list,
+    )
+    chat[-1]["content"] = prompt
+    response = tester(chat, stream=False)
+    return extract_tag(response, "code")  # type: ignore
+
+
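Both write_code and write_test deep-copy the incoming chat and overwrite the final user message with a fully formatted prompt, so the caller's history is never mutated. The message shape they expect, inferred from __call__ further down (only the role/content/media keys are assumed):

    chat: List[Message] = [
        {
            "role": "user",
            "content": "Count the number of cars in the image",
            "media": ["cars.jpg"],  # local paths or URLs for this turn
        }
    ]
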
+def debug_code(
+    debugger: LMM,
+    tool_docs: str,
+    plan: str,
+    code: str,
+    test: str,
+    result: Execution,
+    debug_info: str,
+    verbose: bool,
+) -> tuple[str, str, str]:
+    fixed_code = None
+    fixed_test = None
+    thoughts = ""
+    success = False
+    count = 0
+    while not success and count < 3:
+        try:
+            # LLMs write worse code when it's in JSON, so we have it write JSON
+            # followed by code each wrapped in markdown blocks.
+            fixed_code_and_test_str = debugger(
+                FIX_BUG.format(
+                    docstring=tool_docs,
+                    plan=plan,
+                    code=code,
+                    tests=test,
+                    # Because of the way we trace function calls the trace information
+                    # ends up in the results. We don't want to show this info to the
+                    # LLM so we don't include it in the tool_output_str.
+                    result="\n".join(
+                        result.text(include_results=False).splitlines()[-50:]
+                    ),
+                    debug=debug_info,
+                ),
+                stream=False,
+            )
+            fixed_code_and_test_str = cast(str, fixed_code_and_test_str)
+            thoughts_tag = extract_tag(fixed_code_and_test_str, "thoughts")
+            thoughts = thoughts_tag if thoughts_tag is not None else ""
+            fixed_code = extract_tag(fixed_code_and_test_str, "code")
+            fixed_test = extract_tag(fixed_code_and_test_str, "test")
+
+            success = not (fixed_code is None and fixed_test is None)
+
+        except Exception as e:
+            _CONSOLE.print(f"[bold red]Error while extracting JSON:[/bold red] {e}")
+
+        count += 1
+
+    old_code = code
+    old_test = test
+
+    if fixed_code is not None and fixed_code.strip() != "":
+        code = fixed_code
+    if fixed_test is not None and fixed_test.strip() != "":
+        test = fixed_test
+
+    debug_info_i = format_feedback(
+        [
+            {
+                "code": f"{code}\n{test}",
+                "feedback": thoughts,
+                "edits": get_diff(f"{old_code}\n{old_test}", f"{code}\n{test}"),
+            }
+        ]
+    )
+    debug_info += f"\n{debug_info_i}"
+
+    if verbose:
+        _CONSOLE.print(
+            f"[bold cyan]Thoughts on attempted fix:[/bold cyan] [green]{thoughts}[/green]"
+        )
+
+    return code, test, debug_info
+
+
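debug_code retries the debugger LMM up to three times, succeeding once at least one of the <code> or <test> tags can be extracted. Only non-empty fixes replace the previous code and test, and each attempt's diff is appended to debug_info via format_feedback so later rounds see the full fix history. A hypothetical response that would satisfy the extract_tag calls above:

    fixed_code_and_test_str = '''
    <thoughts>The test asserts on a key the code never returns; align the key names.</thoughts>
    <code>
    def count_people(dets):
        return len(dets)
    </code>
    <test>
    assert count_people([{"label": "person"}]) == 1
    </test>
    '''
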
+def write_and_test_code(
+    coder: LMM,
+    tester: LMM,
+    debugger: LMM,
+    chat: List[Message],
+    plan: str,
+    tool_docs: str,
+    code_interpreter: CodeInterpreter,
+    media_list: List[Union[str, Path]],
+    update_callback: Callable[[Dict[str, Any]], None],
+    verbose: bool,
+) -> CodeContext:
+    code = write_code(
+        coder=coder,
+        chat=chat,
+        tool_docs=tool_docs,
+        plan=plan,
+    )
+    code = strip_function_calls(code)
+    test = write_test(
+        tester=tester,
+        chat=chat,
+        tool_util_docs=T.UTILITIES_DOCSTRING,
+        code=code,
+        media_list=media_list,
+    )
+    if verbose:
+        print_code("Code:", code)
+        print_code("Test:", test)
+    result = code_interpreter.exec_isolation(
+        f"{DefaultImports.to_code_string()}\n{code}\n{test}"
+    )
+    if verbose:
+        _CONSOLE.print(
+            f"[bold cyan]Code execution result:[/bold cyan] [yellow]{escape(result.text(include_logs=True))}[/yellow]"
+        )
+
+    count = 0
+    debug_info = ""
+    while (not result.success or len(result.logs.stdout) == 0) and count < 3:
+        code, test, debug_info = debug_code(
+            debugger,
+            T.UTILITIES_DOCSTRING + "\n" + tool_docs,
+            plan,
+            code,
+            test,
+            result,
+            debug_info,
+            verbose,
+        )
+        result = code_interpreter.exec_isolation(
+            f"{DefaultImports.to_code_string()}\n{code}\n{test}"
+        )
+        count += 1
+        if verbose:
+            print_code("Code and test after attempted fix:", code, test)
+            _CONSOLE.print(
+                f"[bold cyan]Code execution result after attempted fix:[/bold cyan] [yellow]{escape(result.text(include_logs=True))}[/yellow]"
+            )
+
+    update_callback(
+        {
+            "role": "assistant",
+            "content": f"<final_code>{DefaultImports.to_code_string()}\n{code}</final_code>\n<final_test>{DefaultImports.to_code_string()}\n{test}</final_test>",
+            "media": capture_media_from_exec(result),
+        }
+    )
+
+    return CodeContext(
+        code=f"{DefaultImports.to_code_string()}\n{code}",
+        test=f"{DefaultImports.to_code_string()}\n{test}",
+        success=result.success,
+        test_result=result,
+    )
+
+
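Each round executes the candidate solution and its test as a single isolated script, prefixed with the agent's default imports; note the debug loop also fires when stdout is empty, so generated tests are expected to print their results. Schematically (a sketch, not package code):

    script = "\n".join(
        [
            DefaultImports.to_code_string(),  # shared imports for generated code
            code,                             # candidate solution
            test,                             # test that must print output
        ]
    )
    result = code_interpreter.exec_isolation(script)
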
+class VisionAgentCoderV2(Agent):
+    def __init__(
+        self,
+        planner: Optional[Agent] = None,
+        coder: Optional[LMM] = None,
+        tester: Optional[LMM] = None,
+        debugger: Optional[LMM] = None,
+        tool_recommender: Optional[Union[str, Sim]] = None,
+        verbose: bool = False,
+        code_sandbox_runtime: Optional[str] = None,
+        update_callback: Callable[[Dict[str, Any]], None] = lambda _: None,
+    ) -> None:
+        self.planner = (
+            planner
+            if planner is not None
+            else VisionAgentPlannerV2(verbose=verbose, update_callback=update_callback)
+        )
+        self.coder = (
+            coder
+            if coder is not None
+            else AnthropicLMM(model_name="claude-3-5-sonnet-20241022", temperature=0.0)
+        )
+        self.tester = (
+            tester
+            if tester is not None
+            else AnthropicLMM(model_name="claude-3-5-sonnet-20241022", temperature=0.0)
+        )
+        self.debugger = (
+            debugger
+            if debugger is not None
+            else AnthropicLMM(model_name="claude-3-5-sonnet-20241022", temperature=0.0)
+        )
+        if tool_recommender is not None:
+            if isinstance(tool_recommender, str):
+                self.tool_recommender = Sim.load(tool_recommender)
+            elif isinstance(tool_recommender, Sim):
+                self.tool_recommender = tool_recommender
+        else:
+            self.tool_recommender = load_cached_sim(T.TOOLS_DF)
+
+        self.verbose = verbose
+        self.code_sandbox_runtime = code_sandbox_runtime
+        self.update_callback = update_callback
+
+    def __call__(
+        self,
+        input: Union[str, List[Message]],
+        media: Optional[Union[str, Path]] = None,
+    ) -> Union[str, List[Message]]:
+        if isinstance(input, str):
+            input = [{"role": "user", "content": input}]
+        if media is not None:
+            input[0]["media"] = [media]
+        return self.generate_code(input).code
+
+    def generate_code(self, chat: List[Message]) -> CodeContext:
+        chat = copy.deepcopy(chat)
+        with CodeInterpreterFactory.new_instance(
+            self.code_sandbox_runtime
+        ) as code_interpreter:
+            int_chat, orig_chat, _ = add_media_to_chat(chat, code_interpreter)
+            plan_context = self.planner.generate_plan(int_chat, code_interpreter)  # type: ignore
+            code_context = self.generate_code_from_plan(
+                orig_chat,
+                plan_context,
+                code_interpreter,
+            )
+        return code_context
+
+    def generate_code_from_plan(
+        self,
+        chat: List[Message],
+        plan_context: PlanContext,
+        code_interpreter: Optional[CodeInterpreter] = None,
+    ) -> CodeContext:
+        chat = copy.deepcopy(chat)
+        # Reuse the caller's interpreter when one is passed in (e.g. from
+        # generate_code above) instead of shadowing it with a fresh sandbox.
+        code_interpreter = (
+            code_interpreter
+            if code_interpreter is not None
+            else CodeInterpreterFactory.new_instance(self.code_sandbox_runtime)
+        )
+        with code_interpreter:
+            int_chat, _, media_list = add_media_to_chat(chat, code_interpreter)
+            tool_docs = retrieve_tools(plan_context.instructions, self.tool_recommender)
+            code_context = write_and_test_code(
+                coder=self.coder,
+                tester=self.tester,
+                debugger=self.debugger,
+                chat=int_chat,
+                plan=format_plan_v2(plan_context),
+                tool_docs=tool_docs,
+                code_interpreter=code_interpreter,
+                media_list=media_list,  # type: ignore
+                update_callback=self.update_callback,
+                verbose=self.verbose,
+            )
+        return code_context
+
+    def log_progress(self, data: Dict[str, Any]) -> None:
+        pass
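Taken together, a minimal end-to-end sketch of the new class (module path assumed from the file naming above; the defaults require Anthropic credentials since coder, tester, and debugger all fall back to AnthropicLMM):

    from vision_agent.agent.vision_agent_coder_v2 import VisionAgentCoderV2

    agent = VisionAgentCoderV2(verbose=True)

    # Returns the final generated code as a string; media attaches to the first turn.
    code = agent("Count the number of cars in this image", media="cars.jpg")

    # Or keep the full CodeContext (code, test, success flag, execution result):
    ctx = agent.generate_code(
        [{"role": "user", "content": "Count the number of cars", "media": ["cars.jpg"]}]
    )
    print(ctx.success)
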
@@ -14,7 +14,7 @@ from vision_agent.agent.agent_utils import (
     DefaultImports,
     extract_code,
     extract_json,
-    format_memory,
+    format_feedback,
     format_plans,
     print_code,
 )
@@ -423,7 +423,7 @@ class VisionAgentPlanner(Agent):
             T.get_tool_descriptions_by_names(
                 custom_tool_names, T.FUNCTION_TOOLS, T.UTIL_TOOLS  # type: ignore
             ),
-            format_memory(working_memory),
+            format_feedback(working_memory),
             self.planner,
         )
         if self.verbosity >= 1:
@@ -190,7 +190,7 @@ PICK_PLAN = """
 1. Re-read the user request, plans, tool outputs and examine the image.
 2. Solve the problem yourself given the image and pick the most accurate plan that matches your solution the best.
 3. Add modifications to improve the plan including: changing a tool, adding thresholds, string matching.
-3. Output a JSON object with the following format:
+4. Output a JSON object with the following format:
 {{
 "predicted_answer": str # the answer you would expect from the best plan
 "thoughts": str # your thought process for choosing the best plan over other plans and any modifications you made