vision-agent 0.2.193__py3-none-any.whl → 0.2.196__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- vision_agent/.sim_tools/df.csv +640 -0
- vision_agent/.sim_tools/embs.npy +0 -0
- vision_agent/agent/__init__.py +2 -0
- vision_agent/agent/agent_utils.py +211 -3
- vision_agent/agent/vision_agent_coder.py +5 -113
- vision_agent/agent/vision_agent_coder_prompts_v2.py +119 -0
- vision_agent/agent/vision_agent_coder_v2.py +341 -0
- vision_agent/agent/vision_agent_planner.py +2 -2
- vision_agent/agent/vision_agent_planner_prompts.py +1 -1
- vision_agent/agent/vision_agent_planner_prompts_v2.py +748 -0
- vision_agent/agent/vision_agent_planner_v2.py +432 -0
- vision_agent/lmm/lmm.py +4 -0
- vision_agent/tools/__init__.py +2 -1
- vision_agent/tools/planner_tools.py +246 -0
- vision_agent/tools/tool_utils.py +65 -1
- vision_agent/tools/tools.py +76 -22
- vision_agent/utils/image_utils.py +12 -6
- vision_agent/utils/sim.py +65 -14
- {vision_agent-0.2.193.dist-info → vision_agent-0.2.196.dist-info}/METADATA +2 -1
- vision_agent-0.2.196.dist-info/RECORD +42 -0
- vision_agent-0.2.193.dist-info/RECORD +0 -35
- {vision_agent-0.2.193.dist-info → vision_agent-0.2.196.dist-info}/LICENSE +0 -0
- {vision_agent-0.2.193.dist-info → vision_agent-0.2.196.dist-info}/WHEEL +0 -0
vision_agent/agent/vision_agent_planner_v2.py
ADDED
@@ -0,0 +1,432 @@
+import copy
+import logging
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from pathlib import Path
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
+
+import numpy as np
+from rich.console import Console
+from rich.markup import escape
+
+import vision_agent.tools as T
+import vision_agent.tools.planner_tools as pt
+from vision_agent.agent import Agent
+from vision_agent.agent.agent_utils import (
+    PlanContext,
+    add_media_to_chat,
+    capture_media_from_exec,
+    extract_json,
+    extract_tag,
+    print_code,
+    print_table,
+)
+from vision_agent.agent.vision_agent_planner_prompts_v2 import (
+    CRITIQUE_PLAN,
+    EXAMPLE_PLAN1,
+    EXAMPLE_PLAN2,
+    FINALIZE_PLAN,
+    FIX_BUG,
+    PICK_PLAN,
+    PLAN,
+)
+from vision_agent.lmm import LMM, AnthropicLMM, Message
+from vision_agent.utils.execute import (
+    CodeInterpreter,
+    CodeInterpreterFactory,
+    Execution,
+)
+
+logging.basicConfig(level=logging.INFO)
+UTIL_DOCSTRING = T.get_tool_documentation(
+    [
+        T.load_image,
+        T.extract_frames_and_timestamps,
+        T.save_image,
+        T.save_video,
+        T.overlay_bounding_boxes,
+        T.overlay_segmentation_masks,
+    ]
+)
+PLANNING_TOOLS_DOCSTRING = UTIL_DOCSTRING + "\n" + pt.PLANNER_DOCSTRING
+_CONSOLE = Console()
+
+
+class DefaultPlanningImports:
+    imports = [
+        "import os",
+        "import numpy as np",
+        "import cv2",
+        "from typing import *",
+        "from vision_agent.tools import *",
+        "from vision_agent.tools.planner_tools import claude35_vqa, suggestion, get_tool_for_task",
+        "from pillow_heif import register_heif_opener",
+        "register_heif_opener()",
+        "import matplotlib.pyplot as plt",
+    ]
+
+    @staticmethod
+    def prepend_imports(code: str) -> str:
+        return "\n".join(DefaultPlanningImports.imports) + "\n\n" + code
+
+
+def get_planning(
+    chat: List[Message],
+) -> str:
+    chat = copy.deepcopy(chat)
+    planning = ""
+    for chat_i in chat:
+        if chat_i["role"] == "user":
+            planning += f"USER: {chat_i['content']}\n\n"
+        elif chat_i["role"] == "observation":
+            planning += f"OBSERVATION: {chat_i['content']}\n\n"
+        elif chat_i["role"] == "assistant":
+            planning += f"ASSISTANT: {chat_i['content']}\n\n"
+        else:
+            raise ValueError(f"Unknown role: {chat_i['role']}")
+
+    return planning
+
+
+def run_planning(
+    chat: List[Message],
+    media_list: List[str],
+    model: LMM,
+) -> str:
+    # only keep last 10 messages for planning
+    planning = get_planning(chat[-10:])
+    prompt = PLAN.format(
+        tool_desc=PLANNING_TOOLS_DOCSTRING,
+        examples=f"{EXAMPLE_PLAN1}\n{EXAMPLE_PLAN2}",
+        planning=planning,
+        media_list=str(media_list),
+    )
+
+    message: Message = {"role": "user", "content": prompt}
+    if chat[-1]["role"] == "observation" and "media" in chat[-1]:
+        message["media"] = chat[-1]["media"]
+
+    response = model.chat([message])
+    return cast(str, response)
+
+
+def run_multi_trial_planning(
+    chat: List[Message],
+    media_list: List[str],
+    model: LMM,
+) -> str:
+    planning = get_planning(chat)
+    prompt = PLAN.format(
+        tool_desc=PLANNING_TOOLS_DOCSTRING,
+        examples=EXAMPLE_PLAN1,
+        planning=planning,
+        media_list=str(media_list),
+    )
+
+    message: Message = {"role": "user", "content": prompt}
+    if chat[-1]["role"] == "observation" and "media" in chat[-1]:
+        message["media"] = chat[-1]["media"]
+
+    responses = []
+    with ThreadPoolExecutor() as executor:
+        futures = [
+            executor.submit(lambda: model.chat([message], temperature=1.0))
+            for _ in range(3)
+        ]
+        for future in as_completed(futures):
+            responses.append(future.result())
+
+    prompt = PICK_PLAN.format(
+        planning=planning,
+        response1=responses[0],
+        response2=responses[1],
+        response3=responses[2],
+    )
+    response = cast(str, model.chat([{"role": "user", "content": prompt}]))
+    json_str = extract_tag(response, "json")
+    if json_str:
+        json_data = extract_json(json_str)
+        best = np.argmax([int(json_data[f"response{k}"]) for k in [1, 2, 3]])
+        return cast(str, responses[best])
+    else:
+        return cast(str, responses[0])
+
+
+def run_critic(chat: List[Message], media_list: List[str], model: LMM) -> Optional[str]:
+    planning = get_planning(chat)
+    prompt = CRITIQUE_PLAN.format(
+        planning=planning,
+    )
+    message: Message = {"role": "user", "content": prompt}
+    if len(media_list) > 0:
+        message["media"] = media_list
+
+    response = cast(str, model.chat([message]))
+    score = extract_tag(response, "score")
+    thoughts = extract_tag(response, "thoughts")
+    if score is not None and thoughts is not None:
+        try:
+            fscore = float(score)
+            if fscore < 8:
+                return thoughts
+        except ValueError:
+            pass
+    return None
+
+
+def code_safeguards(code: str) -> str:
+    if "get_tool_for_task" in code:
+        lines = code.split("\n")
+        new_lines = []
+        for line in lines:
+            new_lines.append(line)
+            if "get_tool_for_task" in line:
+                break
+        code = "\n".join(new_lines)
+    return code
+
+
+def response_safeguards(response: str) -> str:
+    if "<execute_python>" in response:
+        response = response[
+            : response.index("</execute_python>") + len("</execute_python>")
+        ]
+    return response
+
+
+def execute_code_action(
+    code: str,
+    code_interpreter: CodeInterpreter,
+    chat: List[Message],
+    model: LMM,
+    verbose: bool = False,
+) -> Tuple[Execution, str, str]:
+    if verbose:
+        print_code("Code to Execute:", code)
+    execution = code_interpreter.exec_cell(DefaultPlanningImports.prepend_imports(code))
+    obs = execution.text(include_results=False).strip()
+    if verbose:
+        _CONSOLE.print(
+            f"[bold cyan]Code Execution Output:[/bold cyan] [yellow]{escape(obs)}[/yellow]"
+        )
+
+    count = 1
+    while not execution.success and count <= 3:
+        prompt = FIX_BUG.format(chat_history=get_planning(chat), code=code, error=obs)
+        response = cast(str, model.chat([{"role": "user", "content": prompt}]))
+        new_code = extract_tag(response, "code")
+        if not new_code:
+            continue
+        else:
+            code = new_code
+
+        execution = code_interpreter.exec_cell(
+            DefaultPlanningImports.prepend_imports(code)
+        )
+        obs = execution.text(include_results=False).strip()
+        if verbose:
+            print_code(f"Fixing Bug Round {count}:", code)
+            _CONSOLE.print(
+                f"[bold cyan]Code Execution Output:[/bold cyan] [yellow]{escape(obs)}[/yellow]"
+            )
+        count += 1
+
+    if obs.startswith("----- stdout -----\n"):
+        obs = obs[19:]
+    if obs.endswith("\n----- stderr -----"):
+        obs = obs[:-19]
+    return execution, obs, code
+
+
+def find_and_replace_code(response: str, code: str) -> str:
+    code_start = response.index("<execute_python>") + len("<execute_python>")
+    code_end = response.index("</execute_python>")
+    return response[:code_start] + code + response[code_end:]
+
+
+def maybe_run_code(
+    code: Optional[str],
+    response: str,
+    chat: List[Message],
+    media_list: List[str],
+    model: LMM,
+    code_interpreter: CodeInterpreter,
+    verbose: bool = False,
+) -> List[Message]:
+    return_chat: List[Message] = []
+    if code is not None:
+        code = code_safeguards(code)
+        execution, obs, code = execute_code_action(
+            code, code_interpreter, chat, model, verbose
+        )
+
+        # if we had to debug the code to fix an issue, replace the old code
+        # with the fixed code in the response
+        fixed_response = find_and_replace_code(response, code)
+        return_chat.append({"role": "assistant", "content": fixed_response})
+
+        media_data = capture_media_from_exec(execution)
+        int_chat_elt: Message = {"role": "observation", "content": obs}
+        if media_list:
+            int_chat_elt["media"] = media_data
+        return_chat.append(int_chat_elt)
+    else:
+        return_chat.append({"role": "assistant", "content": response})
+    return return_chat
+
+
+def create_finalize_plan(
+    chat: List[Message],
+    model: LMM,
+    verbose: bool = False,
+) -> Tuple[List[Message], PlanContext]:
+    prompt = FINALIZE_PLAN.format(
+        planning=get_planning(chat),
+        excluded_tools=str([t.__name__ for t in pt.PLANNER_TOOLS]),
+    )
+    response = model.chat([{"role": "user", "content": prompt}])
+    plan_str = cast(str, response)
+    return_chat: List[Message] = [{"role": "assistant", "content": plan_str}]
+
+    plan_json = extract_tag(plan_str, "json")
+    plan = (
+        extract_json(plan_json)
+        if plan_json is not None
+        else {"plan": plan_str, "instructions": [], "code": ""}
+    )
+    code_snippets = extract_tag(plan_str, "code")
+    plan["code"] = code_snippets if code_snippets is not None else ""
+    if verbose:
+        _CONSOLE.print(
+            f"[bold cyan]Final Plan:[/bold cyan] [magenta]{plan['plan']}[/magenta]"
+        )
+        print_table("Plan", ["Instructions"], [[p] for p in plan["instructions"]])
+        print_code("Plan Code", plan["code"])
+
+    return return_chat, PlanContext(**plan)
+
+
+class VisionAgentPlannerV2(Agent):
+    def __init__(
+        self,
+        planner: Optional[LMM] = None,
+        critic: Optional[LMM] = None,
+        max_steps: int = 10,
+        use_multi_trial_planning: bool = False,
+        critique_steps: int = 11,
+        verbose: bool = False,
+        code_sandbox_runtime: Optional[str] = None,
+        update_callback: Callable[[Dict[str, Any]], None] = lambda _: None,
+    ) -> None:
+        self.planner = (
+            planner
+            if planner is not None
+            else AnthropicLMM(model_name="claude-3-5-sonnet-20241022", temperature=0.0)
+        )
+        self.critic = (
+            critic
+            if critic is not None
+            else AnthropicLMM(model_name="claude-3-5-sonnet-20241022", temperature=0.0)
+        )
+        self.max_steps = max_steps
+        self.use_multi_trial_planning = use_multi_trial_planning
+        self.critique_steps = critique_steps
+
+        self.verbose = verbose
+        self.code_sandbox_runtime = code_sandbox_runtime
+        self.update_callback = update_callback
+
+    def __call__(
+        self,
+        input: Union[str, List[Message]],
+        media: Optional[Union[str, Path]] = None,
+    ) -> Union[str, List[Message]]:
+        if isinstance(input, str):
+            if media is not None:
+                input = [{"role": "user", "content": input, "media": [media]}]
+            else:
+                input = [{"role": "user", "content": input}]
+        plan = self.generate_plan(input)
+        return str(plan)
+
+    def generate_plan(
+        self,
+        chat: List[Message],
+        code_interpreter: Optional[CodeInterpreter] = None,
+    ) -> PlanContext:
+        if not chat:
+            raise ValueError("Chat cannot be empty")
+
+        chat = copy.deepcopy(chat)
+        code_interpreter = code_interpreter or CodeInterpreterFactory.new_instance(
+            self.code_sandbox_runtime
+        )
+
+        with code_interpreter:
+            critque_steps = 1
+            step = self.max_steps
+            finished = False
+            int_chat, _, media_list = add_media_to_chat(chat, code_interpreter)
+            int_chat[-1]["content"] += f"\n<count>{step}</count>\n"  # type: ignore
+            while step > 0 and not finished:
+                if self.use_multi_trial_planning:
+                    response = run_multi_trial_planning(
+                        int_chat, media_list, self.planner
+                    )
+                else:
+                    response = run_planning(int_chat, media_list, self.planner)
+
+                response = response_safeguards(response)
+                thinking = extract_tag(response, "thinking")
+                code = extract_tag(response, "execute_python")
+                finalize_plan = extract_tag(response, "finalize_plan")
+                finished = finalize_plan is not None
+
+                if self.verbose:
+                    _CONSOLE.print(
+                        f"[bold cyan]Step {step}:[/bold cyan] [green]{thinking}[/green]"
+                    )
+                    if finalize_plan is not None:
+                        _CONSOLE.print(
+                            f"[bold cyan]Finalizing Plan:[/bold cyan] [magenta]{finalize_plan}[/magenta]"
+                        )
+
+                updated_chat = maybe_run_code(
+                    code,
+                    response,
+                    int_chat,
+                    media_list,
+                    self.planner,
+                    code_interpreter,
+                    self.verbose,
+                )
+
+                if critque_steps % self.critique_steps == 0:
+                    critique = run_critic(int_chat, media_list, self.critic)
+                    if critique is not None and int_chat[-1]["role"] == "observation":
+                        _CONSOLE.print(
+                            f"[bold cyan]Critique:[/bold cyan] [red]{critique}[/red]"
+                        )
+                        critique_str = f"\n[critique]\n{critique}\n[end of critique]"
+                        updated_chat[-1]["content"] += critique_str  # type: ignore
+                        # if plan was critiqued, ensure we don't finish so we can
+                        # respond to the critique
+                        finished = False
+
+                critque_steps += 1
+                step -= 1
+                updated_chat[-1]["content"] += f"\n<count>{step}</count>\n"  # type: ignore
+                int_chat.extend(updated_chat)
+                for chat_elt in updated_chat:
+                    self.update_callback(chat_elt)
+
+            updated_chat, plan_context = create_finalize_plan(
+                int_chat, self.planner, self.verbose
+            )
+            int_chat.extend(updated_chat)
+            for chat_elt in updated_chat:
+                self.update_callback(chat_elt)
+
+        return plan_context
+
+    def log_progress(self, data: Dict[str, Any]) -> None:
+        pass
vision_agent/lmm/lmm.py
CHANGED
@@ -400,6 +400,8 @@ class AnthropicLMM(LMM):
             if "media" in msg:
                 for media_path in msg["media"]:
                     encoded_media = encode_media(media_path, resize=768)
+                    if encoded_media.startswith("data:image/png;base64,"):
+                        encoded_media = encoded_media[len("data:image/png;base64,") :]
                     content.append(
                         ImageBlockParam(
                             type="image",
@@ -447,6 +449,8 @@ class AnthropicLMM(LMM):
         if media:
             for m in media:
                 encoded_media = encode_media(m, resize=768)
+                if encoded_media.startswith("data:image/png;base64,"):
+                    encoded_media = encoded_media[len("data:image/png;base64,") :]
                 content.append(
                     ImageBlockParam(
                         type="image",
vision_agent/tools/__init__.py
CHANGED
@@ -13,7 +13,7 @@ from .meta_tools import (
     view_media_artifact,
 )
 from .prompts import CHOOSE_PARAMS, SYSTEM_PROMPT
-from .tool_utils import get_tool_descriptions_by_names
+from .tool_utils import add_bboxes_from_masks, get_tool_descriptions_by_names
 from .tools import (
     FUNCTION_TOOLS,
     TOOL_DESCRIPTIONS,
@@ -24,6 +24,7 @@ from .tools import (
     UTIL_TOOLS,
     UTILITIES_DOCSTRING,
     blip_image_caption,
+    claude35_text_extraction,
     clip,
     closest_box_distance,
     closest_mask_distance,