vision-agent 0.2.193__py3-none-any.whl → 0.2.196__py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
@@ -0,0 +1,432 @@
+import copy
+import logging
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from pathlib import Path
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
+
+import numpy as np
+from rich.console import Console
+from rich.markup import escape
+
+import vision_agent.tools as T
+import vision_agent.tools.planner_tools as pt
+from vision_agent.agent import Agent
+from vision_agent.agent.agent_utils import (
+    PlanContext,
+    add_media_to_chat,
+    capture_media_from_exec,
+    extract_json,
+    extract_tag,
+    print_code,
+    print_table,
+)
+from vision_agent.agent.vision_agent_planner_prompts_v2 import (
+    CRITIQUE_PLAN,
+    EXAMPLE_PLAN1,
+    EXAMPLE_PLAN2,
+    FINALIZE_PLAN,
+    FIX_BUG,
+    PICK_PLAN,
+    PLAN,
+)
+from vision_agent.lmm import LMM, AnthropicLMM, Message
+from vision_agent.utils.execute import (
+    CodeInterpreter,
+    CodeInterpreterFactory,
+    Execution,
+)
+
+logging.basicConfig(level=logging.INFO)
+UTIL_DOCSTRING = T.get_tool_documentation(
+    [
+        T.load_image,
+        T.extract_frames_and_timestamps,
+        T.save_image,
+        T.save_video,
+        T.overlay_bounding_boxes,
+        T.overlay_segmentation_masks,
+    ]
+)
+PLANNING_TOOLS_DOCSTRING = UTIL_DOCSTRING + "\n" + pt.PLANNER_DOCSTRING
+_CONSOLE = Console()
+
+
+class DefaultPlanningImports:
+    imports = [
+        "import os",
+        "import numpy as np",
+        "import cv2",
+        "from typing import *",
+        "from vision_agent.tools import *",
+        "from vision_agent.tools.planner_tools import claude35_vqa, suggestion, get_tool_for_task",
+        "from pillow_heif import register_heif_opener",
+        "register_heif_opener()",
+        "import matplotlib.pyplot as plt",
+    ]
+
+    @staticmethod
+    def prepend_imports(code: str) -> str:
+        return "\n".join(DefaultPlanningImports.imports) + "\n\n" + code
+
+
+def get_planning(
+    chat: List[Message],
+) -> str:
+    chat = copy.deepcopy(chat)
+    planning = ""
+    for chat_i in chat:
+        if chat_i["role"] == "user":
+            planning += f"USER: {chat_i['content']}\n\n"
+        elif chat_i["role"] == "observation":
+            planning += f"OBSERVATION: {chat_i['content']}\n\n"
+        elif chat_i["role"] == "assistant":
+            planning += f"ASSISTANT: {chat_i['content']}\n\n"
+        else:
+            raise ValueError(f"Unknown role: {chat_i['role']}")
+
+    return planning
+
+
+def run_planning(
+    chat: List[Message],
+    media_list: List[str],
+    model: LMM,
+) -> str:
+    # only keep last 10 messages for planning
+    planning = get_planning(chat[-10:])
+    prompt = PLAN.format(
+        tool_desc=PLANNING_TOOLS_DOCSTRING,
+        examples=f"{EXAMPLE_PLAN1}\n{EXAMPLE_PLAN2}",
+        planning=planning,
+        media_list=str(media_list),
+    )
+
+    message: Message = {"role": "user", "content": prompt}
+    if chat[-1]["role"] == "observation" and "media" in chat[-1]:
+        message["media"] = chat[-1]["media"]
+
+    response = model.chat([message])
+    return cast(str, response)
+
+
+def run_multi_trial_planning(
+    chat: List[Message],
+    media_list: List[str],
+    model: LMM,
+) -> str:
+    planning = get_planning(chat)
+    prompt = PLAN.format(
+        tool_desc=PLANNING_TOOLS_DOCSTRING,
+        examples=EXAMPLE_PLAN1,
+        planning=planning,
+        media_list=str(media_list),
+    )
+
+    message: Message = {"role": "user", "content": prompt}
+    if chat[-1]["role"] == "observation" and "media" in chat[-1]:
+        message["media"] = chat[-1]["media"]
+
+    responses = []
+    with ThreadPoolExecutor() as executor:
+        futures = [
+            executor.submit(lambda: model.chat([message], temperature=1.0))
+            for _ in range(3)
+        ]
+        for future in as_completed(futures):
+            responses.append(future.result())
+
+    prompt = PICK_PLAN.format(
+        planning=planning,
+        response1=responses[0],
+        response2=responses[1],
+        response3=responses[2],
+    )
+    response = cast(str, model.chat([{"role": "user", "content": prompt}]))
+    json_str = extract_tag(response, "json")
+    if json_str:
+        json_data = extract_json(json_str)
+        best = np.argmax([int(json_data[f"response{k}"]) for k in [1, 2, 3]])
+        return cast(str, responses[best])
+    else:
+        return cast(str, responses[0])
+
+
+def run_critic(chat: List[Message], media_list: List[str], model: LMM) -> Optional[str]:
+    planning = get_planning(chat)
+    prompt = CRITIQUE_PLAN.format(
+        planning=planning,
+    )
+    message: Message = {"role": "user", "content": prompt}
+    if len(media_list) > 0:
+        message["media"] = media_list
+
+    response = cast(str, model.chat([message]))
+    score = extract_tag(response, "score")
+    thoughts = extract_tag(response, "thoughts")
+    if score is not None and thoughts is not None:
+        try:
+            fscore = float(score)
+            if fscore < 8:
+                return thoughts
+        except ValueError:
+            pass
+    return None
+
+
+def code_safeguards(code: str) -> str:
+    if "get_tool_for_task" in code:
+        lines = code.split("\n")
+        new_lines = []
+        for line in lines:
+            new_lines.append(line)
+            if "get_tool_for_task" in line:
+                break
+        code = "\n".join(new_lines)
+    return code
+
+
+def response_safeguards(response: str) -> str:
+    if "<execute_python>" in response:
+        response = response[
+            : response.index("</execute_python>") + len("</execute_python>")
+        ]
+    return response
+
+
+def execute_code_action(
+    code: str,
+    code_interpreter: CodeInterpreter,
+    chat: List[Message],
+    model: LMM,
+    verbose: bool = False,
+) -> Tuple[Execution, str, str]:
+    if verbose:
+        print_code("Code to Execute:", code)
+    execution = code_interpreter.exec_cell(DefaultPlanningImports.prepend_imports(code))
+    obs = execution.text(include_results=False).strip()
+    if verbose:
+        _CONSOLE.print(
+            f"[bold cyan]Code Execution Output:[/bold cyan] [yellow]{escape(obs)}[/yellow]"
+        )
+
+    count = 1
+    while not execution.success and count <= 3:
+        prompt = FIX_BUG.format(chat_history=get_planning(chat), code=code, error=obs)
+        response = cast(str, model.chat([{"role": "user", "content": prompt}]))
+        new_code = extract_tag(response, "code")
+        if not new_code:
+            count += 1
+            continue
+        code = new_code
+
+        execution = code_interpreter.exec_cell(
+            DefaultPlanningImports.prepend_imports(code)
+        )
+        obs = execution.text(include_results=False).strip()
+        if verbose:
+            print_code(f"Fixing Bug Round {count}:", code)
+            _CONSOLE.print(
+                f"[bold cyan]Code Execution Output:[/bold cyan] [yellow]{escape(obs)}[/yellow]"
+            )
+        count += 1
+
+    if obs.startswith("----- stdout -----\n"):
+        obs = obs[19:]
+    if obs.endswith("\n----- stderr -----"):
+        obs = obs[:-19]
+    return execution, obs, code
+
+
+def find_and_replace_code(response: str, code: str) -> str:
+    code_start = response.index("<execute_python>") + len("<execute_python>")
+    code_end = response.index("</execute_python>")
+    return response[:code_start] + code + response[code_end:]
+
+
+def maybe_run_code(
+    code: Optional[str],
+    response: str,
+    chat: List[Message],
+    media_list: List[str],
+    model: LMM,
+    code_interpreter: CodeInterpreter,
+    verbose: bool = False,
+) -> List[Message]:
+    return_chat: List[Message] = []
+    if code is not None:
+        code = code_safeguards(code)
+        execution, obs, code = execute_code_action(
+            code, code_interpreter, chat, model, verbose
+        )
+
+        # if we had to debug the code to fix an issue, replace the old code
+        # with the fixed code in the response
+        fixed_response = find_and_replace_code(response, code)
+        return_chat.append({"role": "assistant", "content": fixed_response})
+
+        media_data = capture_media_from_exec(execution)
+        int_chat_elt: Message = {"role": "observation", "content": obs}
+        if media_list:
+            int_chat_elt["media"] = media_data
+        return_chat.append(int_chat_elt)
+    else:
+        return_chat.append({"role": "assistant", "content": response})
+    return return_chat
+
+
+def create_finalize_plan(
+    chat: List[Message],
+    model: LMM,
+    verbose: bool = False,
+) -> Tuple[List[Message], PlanContext]:
+    prompt = FINALIZE_PLAN.format(
+        planning=get_planning(chat),
+        excluded_tools=str([t.__name__ for t in pt.PLANNER_TOOLS]),
+    )
+    response = model.chat([{"role": "user", "content": prompt}])
+    plan_str = cast(str, response)
+    return_chat: List[Message] = [{"role": "assistant", "content": plan_str}]
+
+    plan_json = extract_tag(plan_str, "json")
+    plan = (
+        extract_json(plan_json)
+        if plan_json is not None
+        else {"plan": plan_str, "instructions": [], "code": ""}
+    )
+    code_snippets = extract_tag(plan_str, "code")
+    plan["code"] = code_snippets if code_snippets is not None else ""
+    if verbose:
+        _CONSOLE.print(
+            f"[bold cyan]Final Plan:[/bold cyan] [magenta]{plan['plan']}[/magenta]"
+        )
+        print_table("Plan", ["Instructions"], [[p] for p in plan["instructions"]])
+        print_code("Plan Code", plan["code"])
+
+    return return_chat, PlanContext(**plan)
+
+
+class VisionAgentPlannerV2(Agent):
+    def __init__(
+        self,
+        planner: Optional[LMM] = None,
+        critic: Optional[LMM] = None,
+        max_steps: int = 10,
+        use_multi_trial_planning: bool = False,
+        critique_steps: int = 11,
+        verbose: bool = False,
+        code_sandbox_runtime: Optional[str] = None,
+        update_callback: Callable[[Dict[str, Any]], None] = lambda _: None,
+    ) -> None:
+        self.planner = (
+            planner
+            if planner is not None
+            else AnthropicLMM(model_name="claude-3-5-sonnet-20241022", temperature=0.0)
+        )
+        self.critic = (
+            critic
+            if critic is not None
+            else AnthropicLMM(model_name="claude-3-5-sonnet-20241022", temperature=0.0)
+        )
+        self.max_steps = max_steps
+        self.use_multi_trial_planning = use_multi_trial_planning
+        self.critique_steps = critique_steps
+
+        self.verbose = verbose
+        self.code_sandbox_runtime = code_sandbox_runtime
+        self.update_callback = update_callback
+
+    def __call__(
+        self,
+        input: Union[str, List[Message]],
+        media: Optional[Union[str, Path]] = None,
+    ) -> Union[str, List[Message]]:
+        if isinstance(input, str):
+            if media is not None:
+                input = [{"role": "user", "content": input, "media": [media]}]
+            else:
+                input = [{"role": "user", "content": input}]
+        plan = self.generate_plan(input)
+        return str(plan)
+
+    def generate_plan(
+        self,
+        chat: List[Message],
+        code_interpreter: Optional[CodeInterpreter] = None,
+    ) -> PlanContext:
+        if not chat:
+            raise ValueError("Chat cannot be empty")
+
+        chat = copy.deepcopy(chat)
+        code_interpreter = code_interpreter or CodeInterpreterFactory.new_instance(
+            self.code_sandbox_runtime
+        )
+
+        with code_interpreter:
+            critique_counter = 1
+            step = self.max_steps
+            finished = False
+            int_chat, _, media_list = add_media_to_chat(chat, code_interpreter)
+            int_chat[-1]["content"] += f"\n<count>{step}</count>\n"  # type: ignore
+            while step > 0 and not finished:
+                if self.use_multi_trial_planning:
+                    response = run_multi_trial_planning(
+                        int_chat, media_list, self.planner
+                    )
+                else:
+                    response = run_planning(int_chat, media_list, self.planner)
+
+                response = response_safeguards(response)
+                thinking = extract_tag(response, "thinking")
+                code = extract_tag(response, "execute_python")
+                finalize_plan = extract_tag(response, "finalize_plan")
+                finished = finalize_plan is not None
+
+                if self.verbose:
+                    _CONSOLE.print(
+                        f"[bold cyan]Step {step}:[/bold cyan] [green]{thinking}[/green]"
+                    )
+                    if finalize_plan is not None:
+                        _CONSOLE.print(
+                            f"[bold cyan]Finalizing Plan:[/bold cyan] [magenta]{finalize_plan}[/magenta]"
+                        )
+
+                updated_chat = maybe_run_code(
+                    code,
+                    response,
+                    int_chat,
+                    media_list,
+                    self.planner,
+                    code_interpreter,
+                    self.verbose,
+                )
+
+                if critique_counter % self.critique_steps == 0:
+                    critique = run_critic(int_chat, media_list, self.critic)
+                    if critique is not None and int_chat[-1]["role"] == "observation":
+                        _CONSOLE.print(
+                            f"[bold cyan]Critique:[/bold cyan] [red]{critique}[/red]"
+                        )
+                        critique_str = f"\n[critique]\n{critique}\n[end of critique]"
+                        updated_chat[-1]["content"] += critique_str  # type: ignore
+                        # if plan was critiqued, ensure we don't finish so we can
+                        # respond to the critique
+                        finished = False
+
+                critique_counter += 1
+                step -= 1
+                updated_chat[-1]["content"] += f"\n<count>{step}</count>\n"  # type: ignore
+                int_chat.extend(updated_chat)
+                for chat_elt in updated_chat:
+                    self.update_callback(chat_elt)
+
+            updated_chat, plan_context = create_finalize_plan(
+                int_chat, self.planner, self.verbose
+            )
+            int_chat.extend(updated_chat)
+            for chat_elt in updated_chat:
+                self.update_callback(chat_elt)
+
+        return plan_context
+
+    def log_progress(self, data: Dict[str, Any]) -> None:
+        pass
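
The new module above implements VisionAgentPlannerV2 as a plan, act, observe loop: each step prompts the planner LMM, runs any <execute_python> code in a sandboxed interpreter (retrying up to three times via FIX_BUG on failure), feeds the observation back into the chat, periodically asks the critic model to score the plan, and finalizes once the model emits a <finalize_plan> tag or the step budget runs out. Below is a minimal usage sketch under stated assumptions: the module path vision_agent.agent.vision_agent_planner_v2 is inferred from the imports above, "people.jpg" is a placeholder image path, and the default AnthropicLMM planner and critic require an Anthropic API key in the environment.

    from vision_agent.agent.vision_agent_planner_v2 import VisionAgentPlannerV2  # inferred module path

    planner = VisionAgentPlannerV2(verbose=True)
    plan_context = planner.generate_plan(
        [
            {
                "role": "user",
                "content": "Count the people in this image",
                "media": ["people.jpg"],  # placeholder image path
            }
        ]
    )
    print(plan_context.plan)          # high-level plan text
    print(plan_context.instructions)  # step-by-step instructions
    print(plan_context.code)          # code extracted from the finalized plan

Calling the instance directly, as in planner("Count the people in this image", media="people.jpg"), wraps the same flow through __call__ and returns the plan context as a string.
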
vision_agent/lmm/lmm.py CHANGED
@@ -400,6 +400,8 @@ class AnthropicLMM(LMM):
             if "media" in msg:
                 for media_path in msg["media"]:
                     encoded_media = encode_media(media_path, resize=768)
+                    if encoded_media.startswith("data:image/png;base64,"):
+                        encoded_media = encoded_media[len("data:image/png;base64,") :]
                     content.append(
                         ImageBlockParam(
                             type="image",
@@ -447,6 +449,8 @@ class AnthropicLMM(LMM):
         if media:
             for m in media:
                 encoded_media = encode_media(m, resize=768)
+                if encoded_media.startswith("data:image/png;base64,"):
+                    encoded_media = encoded_media[len("data:image/png;base64,") :]
                 content.append(
                     ImageBlockParam(
                         type="image",
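
Both hunks apply the same two-line guard inside AnthropicLMM: encode_media may hand back a PNG data URL rather than bare base64, while the Anthropic image block carries its media type in a separate field and expects only the raw base64 payload in data. The guard strips the prefix before the ImageBlockParam is built. A standalone sketch of the same string handling follows; the helper name is illustrative and not part of the library.

    PNG_PREFIX = "data:image/png;base64,"

    def strip_data_url_prefix(encoded_media: str) -> str:
        # Keep only the raw base64 payload; the media type travels separately.
        if encoded_media.startswith(PNG_PREFIX):
            return encoded_media[len(PNG_PREFIX):]
        return encoded_media

    assert strip_data_url_prefix("data:image/png;base64,iVBORw0KGgo=") == "iVBORw0KGgo="
    assert strip_data_url_prefix("iVBORw0KGgo=") == "iVBORw0KGgo="
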
@@ -13,7 +13,7 @@ from .meta_tools import (
     view_media_artifact,
 )
 from .prompts import CHOOSE_PARAMS, SYSTEM_PROMPT
-from .tool_utils import get_tool_descriptions_by_names
+from .tool_utils import add_bboxes_from_masks, get_tool_descriptions_by_names
 from .tools import (
     FUNCTION_TOOLS,
     TOOL_DESCRIPTIONS,
@@ -24,6 +24,7 @@ from .tools import (
     UTIL_TOOLS,
     UTILITIES_DOCSTRING,
     blip_image_caption,
+    claude35_text_extraction,
     clip,
     closest_box_distance,
     closest_mask_distance,
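
These last two hunks appear to come from the tools package __init__ (the file header is missing from this diff view): they re-export add_bboxes_from_masks from tool_utils and claude35_text_extraction from tools. Neither implementation is shown in this diff. Judging by its name, add_bboxes_from_masks likely derives box coordinates from segmentation masks; the sketch below illustrates only that general idea with a hypothetical helper, and the real function's signature and output format may differ.

    import numpy as np

    def bbox_from_mask(mask: np.ndarray) -> tuple:
        # Tightest (x_min, y_min, x_max, y_max) box around the True pixels.
        ys, xs = np.nonzero(mask)
        if ys.size == 0:
            return (0, 0, 0, 0)  # empty mask, no region to box
        return (int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max()))

    mask = np.zeros((8, 8), dtype=bool)
    mask[2:5, 3:7] = True
    print(bbox_from_mask(mask))  # (3, 2, 6, 4)
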