vision-agent 0.2.192__py3-none-any.whl → 0.2.195__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- vision_agent/.sim_tools/df.csv +640 -0
- vision_agent/.sim_tools/embs.npy +0 -0
- vision_agent/agent/__init__.py +2 -0
- vision_agent/agent/agent_utils.py +211 -3
- vision_agent/agent/vision_agent_coder.py +5 -113
- vision_agent/agent/vision_agent_coder_prompts_v2.py +119 -0
- vision_agent/agent/vision_agent_coder_v2.py +341 -0
- vision_agent/agent/vision_agent_planner.py +2 -2
- vision_agent/agent/vision_agent_planner_prompts.py +1 -1
- vision_agent/agent/vision_agent_planner_prompts_v2.py +748 -0
- vision_agent/agent/vision_agent_planner_v2.py +432 -0
- vision_agent/lmm/lmm.py +4 -0
- vision_agent/tools/__init__.py +2 -1
- vision_agent/tools/planner_tools.py +246 -0
- vision_agent/tools/tool_utils.py +65 -1
- vision_agent/tools/tools.py +98 -35
- vision_agent/utils/image_utils.py +12 -6
- vision_agent/utils/sim.py +65 -14
- {vision_agent-0.2.192.dist-info → vision_agent-0.2.195.dist-info}/METADATA +1 -1
- vision_agent-0.2.195.dist-info/RECORD +42 -0
- vision_agent-0.2.192.dist-info/RECORD +0 -35
- {vision_agent-0.2.192.dist-info → vision_agent-0.2.195.dist-info}/LICENSE +0 -0
- {vision_agent-0.2.192.dist-info → vision_agent-0.2.195.dist-info}/WHEEL +0 -0
@@ -0,0 +1,432 @@
|
|
1
|
+
import copy
|
2
|
+
import logging
|
3
|
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
4
|
+
from pathlib import Path
|
5
|
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
|
6
|
+
|
7
|
+
import numpy as np
|
8
|
+
from rich.console import Console
|
9
|
+
from rich.markup import escape
|
10
|
+
|
11
|
+
import vision_agent.tools as T
|
12
|
+
import vision_agent.tools.planner_tools as pt
|
13
|
+
from vision_agent.agent import Agent
|
14
|
+
from vision_agent.agent.agent_utils import (
|
15
|
+
PlanContext,
|
16
|
+
add_media_to_chat,
|
17
|
+
capture_media_from_exec,
|
18
|
+
extract_json,
|
19
|
+
extract_tag,
|
20
|
+
print_code,
|
21
|
+
print_table,
|
22
|
+
)
|
23
|
+
from vision_agent.agent.vision_agent_planner_prompts_v2 import (
|
24
|
+
CRITIQUE_PLAN,
|
25
|
+
EXAMPLE_PLAN1,
|
26
|
+
EXAMPLE_PLAN2,
|
27
|
+
FINALIZE_PLAN,
|
28
|
+
FIX_BUG,
|
29
|
+
PICK_PLAN,
|
30
|
+
PLAN,
|
31
|
+
)
|
32
|
+
from vision_agent.lmm import LMM, AnthropicLMM, Message
|
33
|
+
from vision_agent.utils.execute import (
|
34
|
+
CodeInterpreter,
|
35
|
+
CodeInterpreterFactory,
|
36
|
+
Execution,
|
37
|
+
)
|
38
|
+
|
39
|
+
# NOTE(review): this configures the root logger at import time, which also
# affects any application importing this module; kept as-is.
logging.basicConfig(level=logging.INFO)

# Documentation for the general image/video helper tools available to the planner.
UTIL_DOCSTRING = T.get_tool_documentation(
    [
        T.load_image,
        T.extract_frames_and_timestamps,
        T.save_image,
        T.save_video,
        T.overlay_bounding_boxes,
        T.overlay_segmentation_masks,
    ]
)
# Utility tool docs followed by the planner-specific tool docs; this string is
# what gets inserted as ``tool_desc`` into the PLAN prompt.
PLANNING_TOOLS_DOCSTRING = "\n".join((UTIL_DOCSTRING, pt.PLANNER_DOCSTRING))
# Shared rich console used for all verbose output in this module.
_CONSOLE = Console()
|
52
|
+
|
53
|
+
|
54
|
+
class DefaultPlanningImports:
    """Import preamble prepended to every code cell the planner executes."""

    # One line of Python source per entry. They run in order, so
    # ``register_heif_opener()`` comes after the import that provides it.
    imports = [
        "import os",
        "import numpy as np",
        "import cv2",
        "from typing import *",
        "from vision_agent.tools import *",
        "from vision_agent.tools.planner_tools import claude35_vqa, suggestion, get_tool_for_task",
        "from pillow_heif import register_heif_opener",
        "register_heif_opener()",
        "import matplotlib.pyplot as plt",
    ]

    @staticmethod
    def prepend_imports(code: str) -> str:
        """Return ``code`` preceded by the import lines and one blank line."""
        header = "\n".join(DefaultPlanningImports.imports)
        return f"{header}\n\n{code}"
|
70
|
+
|
71
|
+
|
72
|
+
def get_planning(
    chat: List[Message],
) -> str:
    """Render a chat history as a plain-text transcript for planner prompts.

    Each message becomes ``"<ROLE>: <content>"`` followed by a blank line, where
    ROLE is ``USER``, ``OBSERVATION`` or ``ASSISTANT``.

    Args:
        chat: Messages whose ``"role"`` is ``"user"``, ``"observation"`` or
            ``"assistant"`` and which carry a ``"content"`` field. The list is
            only read, never modified.

    Returns:
        The concatenated transcript (an empty string for an empty chat).

    Raises:
        ValueError: If a message has any other role.
    """
    # The input is only read, so no defensive deep copy is needed; copying
    # duplicated every message (including any media lists) on each call.
    prefixes = {
        "user": "USER",
        "observation": "OBSERVATION",
        "assistant": "ASSISTANT",
    }
    parts = []
    for chat_i in chat:
        role = chat_i["role"]
        if role not in prefixes:
            raise ValueError(f"Unknown role: {role}")
        parts.append(f"{prefixes[role]}: {chat_i['content']}\n\n")
    # join avoids the quadratic cost of repeated string concatenation.
    return "".join(parts)
|
88
|
+
|
89
|
+
|
90
|
+
def run_planning(
    chat: List[Message],
    media_list: List[str],
    model: LMM,
) -> str:
    """Ask the planner model for its next step using a single sample.

    Only the 10 most recent messages are rendered into the ``PLAN`` prompt.
    If the final message is an observation that carries media, that media is
    attached to the request.

    Args:
        chat: Conversation so far; must be non-empty (the last message is read).
        media_list: Media paths, listed in the prompt via ``str()``.
        model: LMM that produces the response.

    Returns:
        The raw model response.
    """
    recent_planning = get_planning(chat[-10:])
    content = PLAN.format(
        tool_desc=PLANNING_TOOLS_DOCSTRING,
        examples=f"{EXAMPLE_PLAN1}\n{EXAMPLE_PLAN2}",
        planning=recent_planning,
        media_list=str(media_list),
    )
    message: Message = {"role": "user", "content": content}

    final_msg = chat[-1]
    if final_msg["role"] == "observation" and "media" in final_msg:
        message["media"] = final_msg["media"]

    return cast(str, model.chat([message]))
|
110
|
+
|
111
|
+
|
112
|
+
def run_multi_trial_planning(
    chat: List[Message],
    media_list: List[str],
    model: LMM,
) -> str:
    """Sample three candidate planner steps and let the model pick the best.

    Three responses to the same ``PLAN`` prompt are sampled concurrently at
    temperature 1.0. A ``PICK_PLAN`` prompt then asks the model to score them,
    and the response with the highest score is returned.

    Args:
        chat: Conversation so far; must be non-empty (the last message is read).
        media_list: Media paths, listed in the prompt via ``str()``.
        model: LMM used both for sampling and for picking.

    Returns:
        The chosen candidate. This is the first collected candidate when the
        picker returns no JSON, or returns JSON without usable integer scores.
    """
    planning = get_planning(chat)
    prompt = PLAN.format(
        tool_desc=PLANNING_TOOLS_DOCSTRING,
        examples=EXAMPLE_PLAN1,
        planning=planning,
        media_list=str(media_list),
    )

    message: Message = {"role": "user", "content": prompt}
    if chat[-1]["role"] == "observation" and "media" in chat[-1]:
        message["media"] = chat[-1]["media"]

    # The model calls are I/O bound, so sample the candidates on threads.
    # Each submission gets its own message list.
    responses = []
    with ThreadPoolExecutor() as executor:
        futures = [
            executor.submit(model.chat, [message], temperature=1.0)
            for _ in range(3)
        ]
        for future in as_completed(futures):
            responses.append(future.result())

    prompt = PICK_PLAN.format(
        planning=planning,
        response1=responses[0],
        response2=responses[1],
        response3=responses[2],
    )
    response = cast(str, model.chat([{"role": "user", "content": prompt}]))

    best = 0
    json_str = extract_tag(response, "json")
    if json_str:
        try:
            json_data = extract_json(json_str)
            scores = [int(json_data[f"response{k}"]) for k in (1, 2, 3)]
            best = int(np.argmax(scores))
        except (KeyError, TypeError, ValueError):
            # Malformed or incomplete scores previously crashed the step.
            # Fall back to the first candidate, as when no JSON is returned.
            logging.getLogger(__name__).warning(
                "Could not parse plan scores; using first candidate."
            )
            best = 0
    return cast(str, responses[best])
|
152
|
+
|
153
|
+
|
154
|
+
def run_critic(chat: List[Message], media_list: List[str], model: LMM) -> Optional[str]:
    """Ask the critic model to review the planning so far.

    The rendered chat is filled into ``CRITIQUE_PLAN``, and any media paths are
    attached. The reply is read for ``<score>`` and ``<thoughts>`` tags.

    Returns:
        The critic's thoughts when the score parses as a number below 8.
        Otherwise ``None``: the score is good, a tag is missing, or the score
        cannot be parsed.
    """
    message: Message = {
        "role": "user",
        "content": CRITIQUE_PLAN.format(planning=get_planning(chat)),
    }
    if len(media_list) != 0:
        message["media"] = media_list

    reply = cast(str, model.chat([message]))
    score = extract_tag(reply, "score")
    thoughts = extract_tag(reply, "thoughts")
    if score is None or thoughts is None:
        return None

    try:
        numeric_score = float(score)
    except ValueError:
        # An unparseable score is treated as "no critique" (best effort).
        return None
    return thoughts if numeric_score < 8 else None
|
174
|
+
|
175
|
+
|
176
|
+
def code_safeguards(code: str) -> str:
    """Cut planner code off right after the first line mentioning ``get_tool_for_task``.

    Lines after that one are dropped; code that never mentions it is returned
    unchanged. NOTE: presumably this keeps the planner from acting on a tool
    choice in the same cell that requests it — confirm with the prompt design.
    """
    lines = code.split("\n")
    for idx, line in enumerate(lines):
        if "get_tool_for_task" in line:
            return "\n".join(lines[: idx + 1])
    return code
|
186
|
+
|
187
|
+
|
188
|
+
def response_safeguards(response: str) -> str:
    """Drop everything after the first ``<execute_python>...</execute_python>`` block.

    Text following the first closed code block (for example a second code
    block or a ``<finalize_plan>`` tag in the same reply) is discarded, so only
    the first code block and what precedes it are kept.

    Responses without an ``<execute_python>`` tag, or whose code block is
    never closed, are returned unchanged.
    """
    open_tag = "<execute_python>"
    close_tag = "</execute_python>"
    start = response.find(open_tag)
    if start == -1:
        return response
    # Only a closing tag *after* the opening one counts. Previously str.index
    # raised ValueError on an unclosed block, and a stray earlier closing tag
    # would cut the code block away entirely.
    end = response.find(close_tag, start + len(open_tag))
    if end == -1:
        return response
    return response[: end + len(close_tag)]
|
194
|
+
|
195
|
+
|
196
|
+
def execute_code_action(
    code: str,
    code_interpreter: CodeInterpreter,
    chat: List[Message],
    model: LMM,
    verbose: bool = False,
) -> Tuple[Execution, str, str]:
    """Run planner code, asking the model to fix it up to three times on failure.

    The code runs with ``DefaultPlanningImports`` prepended. While execution
    fails, the model is given the ``FIX_BUG`` prompt (chat history, code and
    error text). Any ``<code>`` block it returns replaces the code and is
    re-executed.

    Args:
        code: Code to execute.
        code_interpreter: Interpreter to run the code in.
        chat: Conversation so far, used as context for bug fixing.
        model: LMM used to propose fixes.
        verbose: Print the code and execution output.

    Returns:
        ``(execution, obs, code)``:
        - ``execution``: the last execution result.
        - ``obs``: its text output, with the leading stdout marker and trailing
          stderr marker trimmed.
        - ``code``: the code that produced ``execution``, which differs from
          the input if a fix was applied.
    """
    stdout_header = "----- stdout -----\n"
    stderr_footer = "\n----- stderr -----"

    if verbose:
        print_code("Code to Execute:", code)
    execution = code_interpreter.exec_cell(DefaultPlanningImports.prepend_imports(code))
    obs = execution.text(include_results=False).strip()
    if verbose:
        _CONSOLE.print(
            f"[bold cyan]Code Execution Output:[/bold cyan] [yellow]{escape(obs)}[/yellow]"
        )

    # At most three repair attempts. Every attempt counts, including replies
    # with no <code> block. Previously such a reply skipped the counter
    # increment, which could loop forever.
    for count in range(1, 4):
        if execution.success:
            break
        prompt = FIX_BUG.format(chat_history=get_planning(chat), code=code, error=obs)
        response = cast(str, model.chat([{"role": "user", "content": prompt}]))
        new_code = extract_tag(response, "code")
        if not new_code:
            continue
        code = new_code

        execution = code_interpreter.exec_cell(
            DefaultPlanningImports.prepend_imports(code)
        )
        obs = execution.text(include_results=False).strip()
        if verbose:
            print_code(f"Fixing Bug Round {count}:", code)
            _CONSOLE.print(
                f"[bold cyan]Code Execution Output:[/bold cyan] [yellow]{escape(obs)}[/yellow]"
            )

    # Trim the section markers produced by execution.text() so the observation
    # contains only the program output.
    if obs.startswith(stdout_header):
        obs = obs[len(stdout_header) :]
    if obs.endswith(stderr_footer):
        obs = obs[: -len(stderr_footer)]
    return execution, obs, code
|
238
|
+
|
239
|
+
|
240
|
+
def find_and_replace_code(response: str, code: str) -> str:
    """Replace the contents of the first ``<execute_python>`` block with ``code``.

    The text around the block, including the tags themselves, is preserved.

    Raises:
        ValueError: If ``response`` has no ``<execute_python>`` tag, or no
            ``</execute_python>`` tag after it.
    """
    open_tag = "<execute_python>"
    code_start = response.index(open_tag) + len(open_tag)
    # Search for the closing tag only after the opening one. A stray earlier
    # closing tag would otherwise give code_end < code_start, and the splice
    # would duplicate part of the response.
    code_end = response.index("</execute_python>", code_start)
    return response[:code_start] + code + response[code_end:]
|
244
|
+
|
245
|
+
|
246
|
+
def maybe_run_code(
    code: Optional[str],
    response: str,
    chat: List[Message],
    media_list: List[str],
    model: LMM,
    code_interpreter: CodeInterpreter,
    verbose: bool = False,
) -> List[Message]:
    """Execute the planner's code, if any, and build the messages for this step.

    Args:
        code: Code extracted from the ``<execute_python>`` tag, or ``None``.
        response: The planner's full response text.
        chat: Conversation so far, used as context when fixing bugs.
        media_list: Media paths from the original input.
        model: LMM used to repair failing code.
        code_interpreter: Interpreter the code runs in.
        verbose: Print code and execution output.

    Returns:
        ``[assistant]`` when ``code`` is ``None``. Otherwise
        ``[assistant, observation]``: the assistant message contains the code
        that was actually executed, and the observation holds its output.
    """
    if code is None:
        return [{"role": "assistant", "content": response}]

    guarded_code = code_safeguards(code)
    execution, obs, ran_code = execute_code_action(
        guarded_code, code_interpreter, chat, model, verbose
    )
    # If debugging changed the code, record the version that actually ran.
    assistant_msg: Message = {
        "role": "assistant",
        "content": find_and_replace_code(response, ran_code),
    }

    media_data = capture_media_from_exec(execution)
    observation: Message = {"role": "observation", "content": obs}
    # NOTE(review): this is gated on the caller's media_list, not on whether the
    # execution produced any media_data; confirm that is intended.
    if media_list:
        observation["media"] = media_data
    return [assistant_msg, observation]
|
275
|
+
|
276
|
+
|
277
|
+
def create_finalize_plan(
    chat: List[Message],
    model: LMM,
    verbose: bool = False,
) -> Tuple[List[Message], PlanContext]:
    """Ask the model to condense the planning conversation into a final plan.

    The ``FINALIZE_PLAN`` prompt receives the rendered chat and the names of
    the planner-only tools as ``excluded_tools``. The reply's ``<json>`` tag
    becomes the plan dict. If that tag is absent, the whole reply is used as
    the plan text, with no instructions. The reply's ``<code>`` tag, or ``""``,
    is stored as the plan's code.

    Args:
        chat: Full planning conversation.
        model: LMM that writes the final plan.
        verbose: Print the plan, its instructions and its code.

    Returns:
        A one-element list holding the model's reply as an assistant message,
        and the resulting :class:`PlanContext`.
    """
    planning = get_planning(chat)
    excluded_tools = str([tool.__name__ for tool in pt.PLANNER_TOOLS])
    prompt = FINALIZE_PLAN.format(planning=planning, excluded_tools=excluded_tools)
    plan_str = cast(str, model.chat([{"role": "user", "content": prompt}]))
    return_chat: List[Message] = [{"role": "assistant", "content": plan_str}]

    plan_json = extract_tag(plan_str, "json")
    if plan_json is None:
        plan = {"plan": plan_str, "instructions": [], "code": ""}
    else:
        plan = extract_json(plan_json)

    code_snippets = extract_tag(plan_str, "code")
    plan["code"] = "" if code_snippets is None else code_snippets

    if verbose:
        _CONSOLE.print(
            f"[bold cyan]Final Plan:[/bold cyan] [magenta]{plan['plan']}[/magenta]"
        )
        instruction_rows = [[p] for p in plan["instructions"]]
        print_table("Plan", ["Instructions"], instruction_rows)
        print_code("Plan Code", plan["code"])

    return return_chat, PlanContext(**plan)
|
306
|
+
|
307
|
+
|
308
|
+
class VisionAgentPlannerV2(Agent):
    """Iterative planner that explores a vision task by executing code.

    At each step the planner LMM responds with ``<execute_python>`` code, which
    is run and its output fed back as an observation, and/or a
    ``<finalize_plan>`` tag. A critic LMM is consulted every ``critique_steps``
    steps. When the planner finalizes or the step budget runs out, the
    conversation is condensed into a :class:`PlanContext`.
    """

    def __init__(
        self,
        planner: Optional[LMM] = None,
        critic: Optional[LMM] = None,
        max_steps: int = 10,
        use_multi_trial_planning: bool = False,
        critique_steps: int = 11,
        verbose: bool = False,
        code_sandbox_runtime: Optional[str] = None,
        update_callback: Callable[[Dict[str, Any]], None] = lambda _: None,
    ) -> None:
        """
        Args:
            planner: LMM that proposes each step. Defaults to
                ``AnthropicLMM("claude-3-5-sonnet-20241022")`` at temperature 0.
            critic: LMM that critiques the plan. Same default as ``planner``.
            max_steps: Maximum number of planning steps.
            use_multi_trial_planning: If True, sample three candidate steps and
                let the planner pick one, instead of sampling once.
            critique_steps: Run the critic on every ``critique_steps``-th step.
                Values <= 0 disable the critic.
            verbose: Print progress to the console.
            code_sandbox_runtime: Runtime passed to ``CodeInterpreterFactory``
                when ``generate_plan`` is not given an interpreter.
            update_callback: Called with every message appended to the chat.
        """
        self.planner = (
            planner
            if planner is not None
            else AnthropicLMM(model_name="claude-3-5-sonnet-20241022", temperature=0.0)
        )
        self.critic = (
            critic
            if critic is not None
            else AnthropicLMM(model_name="claude-3-5-sonnet-20241022", temperature=0.0)
        )
        self.max_steps = max_steps
        self.use_multi_trial_planning = use_multi_trial_planning
        self.critique_steps = critique_steps

        self.verbose = verbose
        self.code_sandbox_runtime = code_sandbox_runtime
        self.update_callback = update_callback

    def __call__(
        self,
        input: Union[str, List[Message]],
        media: Optional[Union[str, Path]] = None,
    ) -> Union[str, List[Message]]:
        """Plan for ``input`` and return the plan as a string.

        A string ``input`` is wrapped as a single user message, with ``media``
        attached when given. ``media`` is ignored when ``input`` is already a
        message list.
        """
        if isinstance(input, str):
            if media is not None:
                input = [{"role": "user", "content": input, "media": [media]}]
            else:
                input = [{"role": "user", "content": input}]
        plan = self.generate_plan(input)
        return str(plan)

    def generate_plan(
        self,
        chat: List[Message],
        code_interpreter: Optional[CodeInterpreter] = None,
    ) -> PlanContext:
        """Run the planning loop and return the finalized plan.

        Args:
            chat: Conversation to plan for; must be non-empty. It is deep-copied,
                so the caller's list is not modified.
            code_interpreter: Interpreter that runs the planner's code. If
                omitted, one is created from ``code_sandbox_runtime``. Either
                way it is entered as a context manager for the whole loop.

        Returns:
            The :class:`PlanContext` produced by ``create_finalize_plan``.

        Raises:
            ValueError: If ``chat`` is empty.
        """
        if not chat:
            raise ValueError("Chat cannot be empty")

        chat = copy.deepcopy(chat)
        code_interpreter = code_interpreter or CodeInterpreterFactory.new_instance(
            self.code_sandbox_runtime
        )

        with code_interpreter:
            critique_step = 1
            step = self.max_steps
            finished = False
            int_chat, _, media_list = add_media_to_chat(chat, code_interpreter)
            # The remaining step budget is shown to the planner in <count> tags.
            int_chat[-1]["content"] += f"\n<count>{step}</count>\n"  # type: ignore
            while step > 0 and not finished:
                if self.use_multi_trial_planning:
                    response = run_multi_trial_planning(
                        int_chat, media_list, self.planner
                    )
                else:
                    response = run_planning(int_chat, media_list, self.planner)

                response = response_safeguards(response)
                thinking = extract_tag(response, "thinking")
                code = extract_tag(response, "execute_python")
                finalize_plan = extract_tag(response, "finalize_plan")
                finished = finalize_plan is not None

                if self.verbose:
                    # Model text is escaped so "[...]" in it is not parsed as
                    # rich markup.
                    _CONSOLE.print(
                        f"[bold cyan]Step {step}:[/bold cyan] "
                        f"[green]{escape(str(thinking))}[/green]"
                    )
                    if finalize_plan is not None:
                        _CONSOLE.print(
                            "[bold cyan]Finalizing Plan:[/bold cyan] "
                            f"[magenta]{escape(finalize_plan)}[/magenta]"
                        )

                updated_chat = maybe_run_code(
                    code,
                    response,
                    int_chat,
                    media_list,
                    self.planner,
                    code_interpreter,
                    self.verbose,
                )

                # critique_steps <= 0 disables the critic; 0 previously raised
                # ZeroDivisionError here.
                if (
                    self.critique_steps > 0
                    and critique_step % self.critique_steps == 0
                ):
                    # The critic sees int_chat, which does not yet include this
                    # step's messages. The critique is appended to this step's
                    # last message, so the planner reads it on the next turn.
                    critique = run_critic(int_chat, media_list, self.critic)
                    if critique is not None and int_chat[-1]["role"] == "observation":
                        if self.verbose:
                            _CONSOLE.print(
                                "[bold cyan]Critique:[/bold cyan] "
                                f"[red]{escape(critique)}[/red]"
                            )
                        critique_str = f"\n[critique]\n{critique}\n[end of critique]"
                        updated_chat[-1]["content"] += critique_str  # type: ignore
                        # If the plan was critiqued, don't finish yet, so the
                        # planner can respond to the critique.
                        finished = False

                critique_step += 1
                step -= 1
                updated_chat[-1]["content"] += f"\n<count>{step}</count>\n"  # type: ignore
                int_chat.extend(updated_chat)
                for chat_elt in updated_chat:
                    self.update_callback(chat_elt)

            updated_chat, plan_context = create_finalize_plan(
                int_chat, self.planner, self.verbose
            )
            int_chat.extend(updated_chat)
            for chat_elt in updated_chat:
                self.update_callback(chat_elt)

            return plan_context

    def log_progress(self, data: Dict[str, Any]) -> None:
        """No-op here; messages are reported through ``update_callback`` instead.

        NOTE(review): presumably present to satisfy the ``Agent`` interface.
        """
        pass
|
vision_agent/lmm/lmm.py
CHANGED
@@ -400,6 +400,8 @@ class AnthropicLMM(LMM):
|
|
400
400
|
if "media" in msg:
|
401
401
|
for media_path in msg["media"]:
|
402
402
|
encoded_media = encode_media(media_path, resize=768)
|
403
|
+
if encoded_media.startswith("data:image/png;base64,"):
|
404
|
+
encoded_media = encoded_media[len("data:image/png;base64,") :]
|
403
405
|
content.append(
|
404
406
|
ImageBlockParam(
|
405
407
|
type="image",
|
@@ -447,6 +449,8 @@ class AnthropicLMM(LMM):
|
|
447
449
|
if media:
|
448
450
|
for m in media:
|
449
451
|
encoded_media = encode_media(m, resize=768)
|
452
|
+
if encoded_media.startswith("data:image/png;base64,"):
|
453
|
+
encoded_media = encoded_media[len("data:image/png;base64,") :]
|
450
454
|
content.append(
|
451
455
|
ImageBlockParam(
|
452
456
|
type="image",
|
vision_agent/tools/__init__.py
CHANGED
@@ -13,7 +13,7 @@ from .meta_tools import (
|
|
13
13
|
view_media_artifact,
|
14
14
|
)
|
15
15
|
from .prompts import CHOOSE_PARAMS, SYSTEM_PROMPT
|
16
|
-
from .tool_utils import get_tool_descriptions_by_names
|
16
|
+
from .tool_utils import add_bboxes_from_masks, get_tool_descriptions_by_names
|
17
17
|
from .tools import (
|
18
18
|
FUNCTION_TOOLS,
|
19
19
|
TOOL_DESCRIPTIONS,
|
@@ -24,6 +24,7 @@ from .tools import (
|
|
24
24
|
UTIL_TOOLS,
|
25
25
|
UTILITIES_DOCSTRING,
|
26
26
|
blip_image_caption,
|
27
|
+
claude35_text_extraction,
|
27
28
|
clip,
|
28
29
|
closest_box_distance,
|
29
30
|
closest_mask_distance,
|