vision-agent 0.2.199__py3-none-any.whl → 0.2.201__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vision_agent/agent/__init__.py +2 -1
- vision_agent/agent/agent.py +33 -0
- vision_agent/agent/agent_utils.py +47 -34
- vision_agent/agent/types.py +51 -0
- vision_agent/agent/vision_agent.py +20 -77
- vision_agent/agent/vision_agent_coder.py +0 -6
- vision_agent/agent/vision_agent_coder_v2.py +131 -43
- vision_agent/agent/vision_agent_planner.py +0 -6
- vision_agent/agent/vision_agent_planner_prompts_v2.py +1 -1
- vision_agent/agent/vision_agent_planner_v2.py +109 -50
- vision_agent/agent/vision_agent_prompts.py +4 -4
- vision_agent/agent/vision_agent_prompts_v2.py +46 -0
- vision_agent/agent/vision_agent_v2.py +215 -0
- vision_agent/tools/meta_tools.py +18 -94
- vision_agent/utils/execute.py +1 -1
- {vision_agent-0.2.199.dist-info → vision_agent-0.2.201.dist-info}/METADATA +1 -1
- {vision_agent-0.2.199.dist-info → vision_agent-0.2.201.dist-info}/RECORD +19 -16
- {vision_agent-0.2.199.dist-info → vision_agent-0.2.201.dist-info}/LICENSE +0 -0
- {vision_agent-0.2.199.dist-info → vision_agent-0.2.201.dist-info}/WHEEL +0 -0
@@ -6,19 +6,19 @@ from rich.console import Console
|
|
6
6
|
from rich.markup import escape
|
7
7
|
|
8
8
|
import vision_agent.tools as T
|
9
|
-
from vision_agent.agent import
|
9
|
+
from vision_agent.agent import AgentCoder, AgentPlanner
|
10
10
|
from vision_agent.agent.agent_utils import (
|
11
|
-
CodeContext,
|
12
11
|
DefaultImports,
|
13
|
-
PlanContext,
|
14
12
|
add_media_to_chat,
|
15
13
|
capture_media_from_exec,
|
14
|
+
convert_message_to_agentmessage,
|
16
15
|
extract_tag,
|
17
16
|
format_feedback,
|
18
17
|
format_plan_v2,
|
19
18
|
print_code,
|
20
19
|
strip_function_calls,
|
21
20
|
)
|
21
|
+
from vision_agent.agent.types import AgentMessage, CodeContext, PlanContext
|
22
22
|
from vision_agent.agent.vision_agent_coder_prompts_v2 import CODE, FIX_BUG, TEST
|
23
23
|
from vision_agent.agent.vision_agent_planner_v2 import VisionAgentPlannerV2
|
24
24
|
from vision_agent.lmm import LMM, AnthropicLMM
|
@@ -34,6 +34,12 @@ from vision_agent.utils.sim import Sim, load_cached_sim
|
|
34
34
|
_CONSOLE = Console()
|
35
35
|
|
36
36
|
|
37
|
+
def format_code_context(
|
38
|
+
code_context: CodeContext,
|
39
|
+
) -> str:
|
40
|
+
return f"<final_code>{code_context.code}</final_code>\n<final_test>{code_context.test}</final_test>"
|
41
|
+
|
42
|
+
|
37
43
|
def retrieve_tools(
|
38
44
|
plan: List[str],
|
39
45
|
tool_recommender: Sim,
|
@@ -49,46 +55,54 @@ def retrieve_tools(
|
|
49
55
|
|
50
56
|
def write_code(
|
51
57
|
coder: LMM,
|
52
|
-
chat: List[
|
58
|
+
chat: List[AgentMessage],
|
53
59
|
tool_docs: str,
|
54
60
|
plan: str,
|
55
61
|
) -> str:
|
56
62
|
chat = copy.deepcopy(chat)
|
57
|
-
if chat[-1]
|
63
|
+
if chat[-1].role != "user":
|
58
64
|
raise ValueError("Last chat message must be from the user.")
|
59
65
|
|
60
|
-
user_request = chat[-1]
|
66
|
+
user_request = chat[-1].content
|
61
67
|
prompt = CODE.format(
|
62
68
|
docstring=tool_docs,
|
63
69
|
question=user_request,
|
64
70
|
plan=plan,
|
65
71
|
)
|
66
|
-
|
67
|
-
|
68
|
-
|
72
|
+
response = cast(str, coder([{"role": "user", "content": prompt}], stream=False))
|
73
|
+
maybe_code = extract_tag(response, "code")
|
74
|
+
|
75
|
+
# if the response wasn't properly formatted with the code tags just retrun the response
|
76
|
+
if maybe_code is None:
|
77
|
+
return response
|
78
|
+
return maybe_code
|
69
79
|
|
70
80
|
|
71
81
|
def write_test(
|
72
82
|
tester: LMM,
|
73
|
-
chat: List[
|
83
|
+
chat: List[AgentMessage],
|
74
84
|
tool_util_docs: str,
|
75
85
|
code: str,
|
76
86
|
media_list: Optional[Sequence[Union[str, Path]]] = None,
|
77
87
|
) -> str:
|
78
88
|
chat = copy.deepcopy(chat)
|
79
|
-
if chat[-1]
|
89
|
+
if chat[-1].role != "user":
|
80
90
|
raise ValueError("Last chat message must be from the user.")
|
81
91
|
|
82
|
-
user_request = chat[-1]
|
92
|
+
user_request = chat[-1].content
|
83
93
|
prompt = TEST.format(
|
84
94
|
docstring=tool_util_docs,
|
85
95
|
question=user_request,
|
86
96
|
code=code,
|
87
97
|
media=media_list,
|
88
98
|
)
|
89
|
-
|
90
|
-
|
91
|
-
|
99
|
+
response = cast(str, tester([{"role": "user", "content": prompt}], stream=False))
|
100
|
+
maybe_code = extract_tag(response, "code")
|
101
|
+
|
102
|
+
# if the response wasn't properly formatted with the code tags just retrun the response
|
103
|
+
if maybe_code is None:
|
104
|
+
return response
|
105
|
+
return maybe_code
|
92
106
|
|
93
107
|
|
94
108
|
def debug_code(
|
@@ -170,12 +184,11 @@ def write_and_test_code(
|
|
170
184
|
coder: LMM,
|
171
185
|
tester: LMM,
|
172
186
|
debugger: LMM,
|
173
|
-
chat: List[
|
187
|
+
chat: List[AgentMessage],
|
174
188
|
plan: str,
|
175
189
|
tool_docs: str,
|
176
190
|
code_interpreter: CodeInterpreter,
|
177
191
|
media_list: List[Union[str, Path]],
|
178
|
-
update_callback: Callable[[Dict[str, Any]], None],
|
179
192
|
verbose: bool,
|
180
193
|
) -> CodeContext:
|
181
194
|
code = write_code(
|
@@ -226,14 +239,6 @@ def write_and_test_code(
|
|
226
239
|
f"[bold cyan]Code execution result after attempted fix:[/bold cyan] [yellow]{escape(result.text(include_logs=True))}[/yellow]"
|
227
240
|
)
|
228
241
|
|
229
|
-
update_callback(
|
230
|
-
{
|
231
|
-
"role": "assistant",
|
232
|
-
"content": f"<final_code>{DefaultImports.to_code_string()}\n{code}</final_code>\n<final_test>{DefaultImports.to_code_string()}\n{test}</final_test>",
|
233
|
-
"media": capture_media_from_exec(result),
|
234
|
-
}
|
235
|
-
)
|
236
|
-
|
237
242
|
return CodeContext(
|
238
243
|
code=f"{DefaultImports.to_code_string()}\n{code}",
|
239
244
|
test=f"{DefaultImports.to_code_string()}\n{test}",
|
@@ -242,10 +247,12 @@ def write_and_test_code(
|
|
242
247
|
)
|
243
248
|
|
244
249
|
|
245
|
-
class VisionAgentCoderV2(
|
250
|
+
class VisionAgentCoderV2(AgentCoder):
|
251
|
+
"""VisionAgentCoderV2 is an agent that will write vision code for you."""
|
252
|
+
|
246
253
|
def __init__(
|
247
254
|
self,
|
248
|
-
planner: Optional[
|
255
|
+
planner: Optional[AgentPlanner] = None,
|
249
256
|
coder: Optional[LMM] = None,
|
250
257
|
tester: Optional[LMM] = None,
|
251
258
|
debugger: Optional[LMM] = None,
|
@@ -254,6 +261,25 @@ class VisionAgentCoderV2(Agent):
|
|
254
261
|
code_sandbox_runtime: Optional[str] = None,
|
255
262
|
update_callback: Callable[[Dict[str, Any]], None] = lambda _: None,
|
256
263
|
) -> None:
|
264
|
+
"""Initialize the VisionAgentCoderV2.
|
265
|
+
|
266
|
+
Parameters:
|
267
|
+
planner (Optional[AgentPlanner]): The planner agent to use for generating
|
268
|
+
vision plans. If None, a default VisionAgentPlannerV2 will be used.
|
269
|
+
coder (Optional[LMM]): The language model to use for the coder agent. If
|
270
|
+
None, a default AnthropicLMM will be used.
|
271
|
+
tester (Optional[LMM]): The language model to use for the tester agent. If
|
272
|
+
None, a default AnthropicLMM will be used.
|
273
|
+
debugger (Optional[LMM]): The language model to use for the debugger agent.
|
274
|
+
tool_recommender (Optional[Union[str, Sim]]): The tool recommender to use.
|
275
|
+
verbose (bool): Whether to print out debug information.
|
276
|
+
code_sandbox_runtime (Optional[str]): The code sandbox runtime to use, can
|
277
|
+
be one of: None, "local" or "e2b". If None, it will read from the
|
278
|
+
environment variable CODE_SANDBOX_RUNTIME.
|
279
|
+
update_callback (Callable[[Dict[str, Any]], None]): The callback function
|
280
|
+
that will send back intermediate conversation messages.
|
281
|
+
"""
|
282
|
+
|
257
283
|
self.planner = (
|
258
284
|
planner
|
259
285
|
if planner is not None
|
@@ -290,20 +316,52 @@ class VisionAgentCoderV2(Agent):
|
|
290
316
|
self,
|
291
317
|
input: Union[str, List[Message]],
|
292
318
|
media: Optional[Union[str, Path]] = None,
|
293
|
-
) ->
|
294
|
-
|
295
|
-
|
296
|
-
|
297
|
-
input[
|
298
|
-
|
299
|
-
|
300
|
-
|
319
|
+
) -> str:
|
320
|
+
"""Generate vision code from a conversation.
|
321
|
+
|
322
|
+
Parameters:
|
323
|
+
input (Union[str, List[Message]]): The input to the agent. This can be a
|
324
|
+
string or a list of messages in the format of [{"role": "user",
|
325
|
+
"content": "describe your task here..."}, ...].
|
326
|
+
media (Optional[Union[str, Path]]): The path to the media file to use with
|
327
|
+
the input. This can be an image or video file.
|
328
|
+
|
329
|
+
Returns:
|
330
|
+
str: The generated code as a string.
|
331
|
+
"""
|
332
|
+
|
333
|
+
input_msg = convert_message_to_agentmessage(input, media)
|
334
|
+
return self.generate_code(input_msg).code
|
335
|
+
|
336
|
+
def generate_code(
|
337
|
+
self,
|
338
|
+
chat: List[AgentMessage],
|
339
|
+
max_steps: Optional[int] = None,
|
340
|
+
code_interpreter: Optional[CodeInterpreter] = None,
|
341
|
+
) -> CodeContext:
|
342
|
+
"""Generate vision code from a conversation.
|
343
|
+
|
344
|
+
Parameters:
|
345
|
+
chat (List[AgentMessage]): The input to the agent. This should be a list of
|
346
|
+
AgentMessage objects.
|
347
|
+
code_interpreter (Optional[CodeInterpreter]): The code interpreter to use.
|
348
|
+
|
349
|
+
Returns:
|
350
|
+
CodeContext: The generated code as a CodeContext object which includes the
|
351
|
+
code, test code, whether or not it was exceuted successfully, and the
|
352
|
+
execution result.
|
353
|
+
"""
|
354
|
+
|
301
355
|
chat = copy.deepcopy(chat)
|
302
|
-
with
|
303
|
-
self.code_sandbox_runtime
|
356
|
+
with (
|
357
|
+
CodeInterpreterFactory.new_instance(self.code_sandbox_runtime)
|
358
|
+
if code_interpreter is None
|
359
|
+
else code_interpreter
|
304
360
|
) as code_interpreter:
|
305
361
|
int_chat, orig_chat, _ = add_media_to_chat(chat, code_interpreter)
|
306
|
-
plan_context = self.planner.generate_plan(
|
362
|
+
plan_context = self.planner.generate_plan(
|
363
|
+
int_chat, max_steps=max_steps, code_interpreter=code_interpreter
|
364
|
+
)
|
307
365
|
code_context = self.generate_code_from_plan(
|
308
366
|
orig_chat,
|
309
367
|
plan_context,
|
@@ -313,13 +371,30 @@ class VisionAgentCoderV2(Agent):
|
|
313
371
|
|
314
372
|
def generate_code_from_plan(
|
315
373
|
self,
|
316
|
-
chat: List[
|
374
|
+
chat: List[AgentMessage],
|
317
375
|
plan_context: PlanContext,
|
318
376
|
code_interpreter: Optional[CodeInterpreter] = None,
|
319
377
|
) -> CodeContext:
|
378
|
+
"""Generate vision code from a conversation and a previously made plan. This
|
379
|
+
will skip the planning step and go straight to generating code.
|
380
|
+
|
381
|
+
Parameters:
|
382
|
+
chat (List[AgentMessage]): The input to the agent. This should be a list of
|
383
|
+
AgentMessage objects.
|
384
|
+
plan_context (PlanContext): The plan context that was previously generated.
|
385
|
+
code_interpreter (Optional[CodeInterpreter]): The code interpreter to use.
|
386
|
+
|
387
|
+
Returns:
|
388
|
+
CodeContext: The generated code as a CodeContext object which includes the
|
389
|
+
code, test code, whether or not it was exceuted successfully, and the
|
390
|
+
execution result.
|
391
|
+
"""
|
392
|
+
|
320
393
|
chat = copy.deepcopy(chat)
|
321
|
-
with
|
322
|
-
self.code_sandbox_runtime
|
394
|
+
with (
|
395
|
+
CodeInterpreterFactory.new_instance(self.code_sandbox_runtime)
|
396
|
+
if code_interpreter is None
|
397
|
+
else code_interpreter
|
323
398
|
) as code_interpreter:
|
324
399
|
int_chat, _, media_list = add_media_to_chat(chat, code_interpreter)
|
325
400
|
tool_docs = retrieve_tools(plan_context.instructions, self.tool_recommender)
|
@@ -331,10 +406,23 @@ class VisionAgentCoderV2(Agent):
|
|
331
406
|
plan=format_plan_v2(plan_context),
|
332
407
|
tool_docs=tool_docs,
|
333
408
|
code_interpreter=code_interpreter,
|
334
|
-
media_list=media_list,
|
335
|
-
update_callback=self.update_callback,
|
409
|
+
media_list=media_list,
|
336
410
|
verbose=self.verbose,
|
337
411
|
)
|
412
|
+
|
413
|
+
self.update_callback(
|
414
|
+
{
|
415
|
+
"role": "coder",
|
416
|
+
"content": format_code_context(code_context),
|
417
|
+
"media": capture_media_from_exec(code_context.test_result),
|
418
|
+
}
|
419
|
+
)
|
420
|
+
self.update_callback(
|
421
|
+
{
|
422
|
+
"role": "observation",
|
423
|
+
"content": code_context.test_result.text(),
|
424
|
+
}
|
425
|
+
)
|
338
426
|
return code_context
|
339
427
|
|
340
428
|
def log_progress(self, data: Dict[str, Any]) -> None:
|
@@ -391,12 +391,6 @@ class VisionAgentPlanner(Agent):
|
|
391
391
|
for chat_i in chat:
|
392
392
|
if "media" in chat_i:
|
393
393
|
for media in chat_i["media"]:
|
394
|
-
media = (
|
395
|
-
media
|
396
|
-
if type(media) is str
|
397
|
-
and media.startswith(("http", "https"))
|
398
|
-
else code_interpreter.upload_file(cast(str, media))
|
399
|
-
)
|
400
394
|
chat_i["content"] += f" Media name {media}" # type: ignore
|
401
395
|
media_list.append(str(media))
|
402
396
|
|
@@ -1,5 +1,6 @@
|
|
1
1
|
import copy
|
2
2
|
import logging
|
3
|
+
import time
|
3
4
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
4
5
|
from pathlib import Path
|
5
6
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
|
@@ -10,16 +11,17 @@ from rich.markup import escape
|
|
10
11
|
|
11
12
|
import vision_agent.tools as T
|
12
13
|
import vision_agent.tools.planner_tools as pt
|
13
|
-
from vision_agent.agent import
|
14
|
+
from vision_agent.agent import AgentPlanner
|
14
15
|
from vision_agent.agent.agent_utils import (
|
15
|
-
PlanContext,
|
16
16
|
add_media_to_chat,
|
17
17
|
capture_media_from_exec,
|
18
|
+
convert_message_to_agentmessage,
|
18
19
|
extract_json,
|
19
20
|
extract_tag,
|
20
21
|
print_code,
|
21
22
|
print_table,
|
22
23
|
)
|
24
|
+
from vision_agent.agent.types import AgentMessage, PlanContext
|
23
25
|
from vision_agent.agent.vision_agent_planner_prompts_v2 import (
|
24
26
|
CRITIQUE_PLAN,
|
25
27
|
EXAMPLE_PLAN1,
|
@@ -70,26 +72,24 @@ class DefaultPlanningImports:
|
|
70
72
|
|
71
73
|
|
72
74
|
def get_planning(
|
73
|
-
chat: List[
|
75
|
+
chat: List[AgentMessage],
|
74
76
|
) -> str:
|
75
77
|
chat = copy.deepcopy(chat)
|
76
78
|
planning = ""
|
77
79
|
for chat_i in chat:
|
78
|
-
if chat_i
|
79
|
-
planning += f"USER: {chat_i
|
80
|
-
elif chat_i
|
81
|
-
planning += f"OBSERVATION: {chat_i
|
82
|
-
elif chat_i
|
83
|
-
planning += f"
|
84
|
-
else:
|
85
|
-
raise ValueError(f"Unknown role: {chat_i['role']}")
|
80
|
+
if chat_i.role == "user":
|
81
|
+
planning += f"USER: {chat_i.content}\n\n"
|
82
|
+
elif chat_i.role == "observation":
|
83
|
+
planning += f"OBSERVATION: {chat_i.content}\n\n"
|
84
|
+
elif chat_i.role == "planner":
|
85
|
+
planning += f"AGENT: {chat_i.content}\n\n"
|
86
86
|
|
87
87
|
return planning
|
88
88
|
|
89
89
|
|
90
90
|
def run_planning(
|
91
|
-
chat: List[
|
92
|
-
media_list: List[str],
|
91
|
+
chat: List[AgentMessage],
|
92
|
+
media_list: List[Union[str, Path]],
|
93
93
|
model: LMM,
|
94
94
|
) -> str:
|
95
95
|
# only keep last 10 messages for planning
|
@@ -102,16 +102,16 @@ def run_planning(
|
|
102
102
|
)
|
103
103
|
|
104
104
|
message: Message = {"role": "user", "content": prompt}
|
105
|
-
if chat[-1]
|
106
|
-
message["media"] = chat[-1]
|
105
|
+
if chat[-1].role == "observation" and chat[-1].media is not None:
|
106
|
+
message["media"] = chat[-1].media
|
107
107
|
|
108
108
|
response = model.chat([message])
|
109
109
|
return cast(str, response)
|
110
110
|
|
111
111
|
|
112
112
|
def run_multi_trial_planning(
|
113
|
-
chat: List[
|
114
|
-
media_list: List[str],
|
113
|
+
chat: List[AgentMessage],
|
114
|
+
media_list: List[Union[str, Path]],
|
115
115
|
model: LMM,
|
116
116
|
) -> str:
|
117
117
|
planning = get_planning(chat)
|
@@ -123,8 +123,8 @@ def run_multi_trial_planning(
|
|
123
123
|
)
|
124
124
|
|
125
125
|
message: Message = {"role": "user", "content": prompt}
|
126
|
-
if chat[-1]
|
127
|
-
message["media"] = chat[-1]
|
126
|
+
if chat[-1].role == "observation" and chat[-1].media is not None:
|
127
|
+
message["media"] = chat[-1].media
|
128
128
|
|
129
129
|
responses = []
|
130
130
|
with ThreadPoolExecutor() as executor:
|
@@ -151,7 +151,9 @@ def run_multi_trial_planning(
|
|
151
151
|
return cast(str, responses[0])
|
152
152
|
|
153
153
|
|
154
|
-
def run_critic(
|
154
|
+
def run_critic(
|
155
|
+
chat: List[AgentMessage], media_list: List[Union[str, Path]], model: LMM
|
156
|
+
) -> Optional[str]:
|
155
157
|
planning = get_planning(chat)
|
156
158
|
prompt = CRITIQUE_PLAN.format(
|
157
159
|
planning=planning,
|
@@ -196,17 +198,19 @@ def response_safeguards(response: str) -> str:
|
|
196
198
|
def execute_code_action(
|
197
199
|
code: str,
|
198
200
|
code_interpreter: CodeInterpreter,
|
199
|
-
chat: List[
|
201
|
+
chat: List[AgentMessage],
|
200
202
|
model: LMM,
|
201
203
|
verbose: bool = False,
|
202
204
|
) -> Tuple[Execution, str, str]:
|
203
205
|
if verbose:
|
204
206
|
print_code("Code to Execute:", code)
|
207
|
+
start = time.time()
|
205
208
|
execution = code_interpreter.exec_cell(DefaultPlanningImports.prepend_imports(code))
|
209
|
+
end = time.time()
|
206
210
|
obs = execution.text(include_results=False).strip()
|
207
211
|
if verbose:
|
208
212
|
_CONSOLE.print(
|
209
|
-
f"[bold cyan]Code Execution Output:[/bold cyan] [yellow]{escape(obs)}[/yellow]"
|
213
|
+
f"[bold cyan]Code Execution Output ({end - start:.2f} sec):[/bold cyan] [yellow]{escape(obs)}[/yellow]"
|
210
214
|
)
|
211
215
|
|
212
216
|
count = 1
|
@@ -246,13 +250,13 @@ def find_and_replace_code(response: str, code: str) -> str:
|
|
246
250
|
def maybe_run_code(
|
247
251
|
code: Optional[str],
|
248
252
|
response: str,
|
249
|
-
chat: List[
|
250
|
-
media_list: List[str],
|
253
|
+
chat: List[AgentMessage],
|
254
|
+
media_list: List[Union[str, Path]],
|
251
255
|
model: LMM,
|
252
256
|
code_interpreter: CodeInterpreter,
|
253
257
|
verbose: bool = False,
|
254
|
-
) -> List[
|
255
|
-
return_chat: List[
|
258
|
+
) -> List[AgentMessage]:
|
259
|
+
return_chat: List[AgentMessage] = []
|
256
260
|
if code is not None:
|
257
261
|
code = code_safeguards(code)
|
258
262
|
execution, obs, code = execute_code_action(
|
@@ -262,30 +266,32 @@ def maybe_run_code(
|
|
262
266
|
# if we had to debug the code to fix an issue, replace the old code
|
263
267
|
# with the fixed code in the response
|
264
268
|
fixed_response = find_and_replace_code(response, code)
|
265
|
-
return_chat.append(
|
269
|
+
return_chat.append(
|
270
|
+
AgentMessage(role="planner", content=fixed_response, media=None)
|
271
|
+
)
|
266
272
|
|
267
273
|
media_data = capture_media_from_exec(execution)
|
268
|
-
int_chat_elt
|
274
|
+
int_chat_elt = AgentMessage(role="observation", content=obs, media=None)
|
269
275
|
if media_list:
|
270
|
-
int_chat_elt
|
276
|
+
int_chat_elt.media = cast(List[Union[str, Path]], media_data)
|
271
277
|
return_chat.append(int_chat_elt)
|
272
278
|
else:
|
273
|
-
return_chat.append(
|
279
|
+
return_chat.append(AgentMessage(role="planner", content=response, media=None))
|
274
280
|
return return_chat
|
275
281
|
|
276
282
|
|
277
283
|
def create_finalize_plan(
|
278
|
-
chat: List[
|
284
|
+
chat: List[AgentMessage],
|
279
285
|
model: LMM,
|
280
286
|
verbose: bool = False,
|
281
|
-
) -> Tuple[List[
|
287
|
+
) -> Tuple[List[AgentMessage], PlanContext]:
|
282
288
|
prompt = FINALIZE_PLAN.format(
|
283
289
|
planning=get_planning(chat),
|
284
290
|
excluded_tools=str([t.__name__ for t in pt.PLANNER_TOOLS]),
|
285
291
|
)
|
286
292
|
response = model.chat([{"role": "user", "content": prompt}])
|
287
293
|
plan_str = cast(str, response)
|
288
|
-
return_chat
|
294
|
+
return_chat = [AgentMessage(role="planner", content=plan_str, media=None)]
|
289
295
|
|
290
296
|
plan_json = extract_tag(plan_str, "json")
|
291
297
|
plan = (
|
@@ -305,7 +311,16 @@ def create_finalize_plan(
|
|
305
311
|
return return_chat, PlanContext(**plan)
|
306
312
|
|
307
313
|
|
308
|
-
|
314
|
+
def get_steps(chat: List[AgentMessage], max_steps: int) -> int:
|
315
|
+
for chat_elt in reversed(chat):
|
316
|
+
if "<count>" in chat_elt.content:
|
317
|
+
return int(extract_tag(chat_elt.content, "count")) # type: ignore
|
318
|
+
return max_steps
|
319
|
+
|
320
|
+
|
321
|
+
class VisionAgentPlannerV2(AgentPlanner):
|
322
|
+
"""VisionAgentPlannerV2 is a class that generates a plan to solve a vision task."""
|
323
|
+
|
309
324
|
def __init__(
|
310
325
|
self,
|
311
326
|
planner: Optional[LMM] = None,
|
@@ -317,6 +332,25 @@ class VisionAgentPlannerV2(Agent):
|
|
317
332
|
code_sandbox_runtime: Optional[str] = None,
|
318
333
|
update_callback: Callable[[Dict[str, Any]], None] = lambda _: None,
|
319
334
|
) -> None:
|
335
|
+
"""Initialize the VisionAgentPlannerV2.
|
336
|
+
|
337
|
+
Parameters:
|
338
|
+
planner (Optional[LMM]): The language model to use for planning. If None, a
|
339
|
+
default AnthropicLMM will be used.
|
340
|
+
critic (Optional[LMM]): The language model to use for critiquing the plan.
|
341
|
+
If None, a default AnthropicLMM will be used.
|
342
|
+
max_steps (int): The maximum number of steps to plan.
|
343
|
+
use_multi_trial_planning (bool): Whether to use multi-trial planning.
|
344
|
+
critique_steps (int): The number of steps between critiques. If critic steps
|
345
|
+
is larger than max_steps no critiques will be made.
|
346
|
+
verbose (bool): Whether to print out debug information.
|
347
|
+
code_sandbox_runtime (Optional[str]): The code sandbox runtime to use, can
|
348
|
+
be one of: None, "local" or "e2b". If None, it will read from the
|
349
|
+
environment variable CODE_SANDBOX_RUNTIME.
|
350
|
+
update_callback (Callable[[Dict[str, Any]], None]): The callback function
|
351
|
+
that will send back intermediate conversation messages.
|
352
|
+
"""
|
353
|
+
|
320
354
|
self.planner = (
|
321
355
|
planner
|
322
356
|
if planner is not None
|
@@ -339,20 +373,42 @@ class VisionAgentPlannerV2(Agent):
|
|
339
373
|
self,
|
340
374
|
input: Union[str, List[Message]],
|
341
375
|
media: Optional[Union[str, Path]] = None,
|
342
|
-
) ->
|
343
|
-
|
344
|
-
|
345
|
-
|
346
|
-
|
347
|
-
|
348
|
-
|
349
|
-
|
376
|
+
) -> str:
|
377
|
+
"""Generate a plan to solve a vision task.
|
378
|
+
|
379
|
+
Parameters:
|
380
|
+
input (Union[str, List[Message]]): The input to the agent. This can be a
|
381
|
+
string or a list of messages in the format of [{"role": "user",
|
382
|
+
"content": "describe your task here..."}, ...].
|
383
|
+
media (Optional[Union[str, Path]]): The path to the media file to use with
|
384
|
+
the input. This can be an image or video file.
|
385
|
+
|
386
|
+
Returns:
|
387
|
+
str: The generated plan as a string.
|
388
|
+
"""
|
389
|
+
|
390
|
+
input_msg = convert_message_to_agentmessage(input, media)
|
391
|
+
plan = self.generate_plan(input_msg)
|
392
|
+
return plan.plan
|
350
393
|
|
351
394
|
def generate_plan(
|
352
395
|
self,
|
353
|
-
chat: List[
|
396
|
+
chat: List[AgentMessage],
|
397
|
+
max_steps: Optional[int] = None,
|
354
398
|
code_interpreter: Optional[CodeInterpreter] = None,
|
355
399
|
) -> PlanContext:
|
400
|
+
"""Generate a plan to solve a vision task.
|
401
|
+
|
402
|
+
Parameters:
|
403
|
+
chat (List[AgentMessage]): The conversation messages to generate a plan for.
|
404
|
+
max_steps (Optional[int]): The maximum number of steps to plan.
|
405
|
+
code_interpreter (Optional[CodeInterpreter]): The code interpreter to use.
|
406
|
+
|
407
|
+
Returns:
|
408
|
+
PlanContext: The generated plan including the instructions and code snippets
|
409
|
+
needed to solve the task.
|
410
|
+
"""
|
411
|
+
|
356
412
|
if not chat:
|
357
413
|
raise ValueError("Chat cannot be empty")
|
358
414
|
|
@@ -360,13 +416,16 @@ class VisionAgentPlannerV2(Agent):
|
|
360
416
|
code_interpreter = code_interpreter or CodeInterpreterFactory.new_instance(
|
361
417
|
self.code_sandbox_runtime
|
362
418
|
)
|
419
|
+
max_steps = max_steps or self.max_steps
|
363
420
|
|
364
421
|
with code_interpreter:
|
365
422
|
critque_steps = 1
|
366
|
-
step = self.max_steps
|
367
423
|
finished = False
|
368
424
|
int_chat, _, media_list = add_media_to_chat(chat, code_interpreter)
|
369
|
-
|
425
|
+
|
426
|
+
step = get_steps(int_chat, max_steps)
|
427
|
+
if "<count>" not in int_chat[-1].content and step == max_steps:
|
428
|
+
int_chat[-1].content += f"\n<count>{step}</count>\n"
|
370
429
|
while step > 0 and not finished:
|
371
430
|
if self.use_multi_trial_planning:
|
372
431
|
response = run_multi_trial_planning(
|
@@ -402,29 +461,29 @@ class VisionAgentPlannerV2(Agent):
|
|
402
461
|
|
403
462
|
if critque_steps % self.critique_steps == 0:
|
404
463
|
critique = run_critic(int_chat, media_list, self.critic)
|
405
|
-
if critique is not None and int_chat[-1]
|
464
|
+
if critique is not None and int_chat[-1].role == "observation":
|
406
465
|
_CONSOLE.print(
|
407
466
|
f"[bold cyan]Critique:[/bold cyan] [red]{critique}[/red]"
|
408
467
|
)
|
409
468
|
critique_str = f"\n[critique]\n{critique}\n[end of critique]"
|
410
|
-
updated_chat[-1]
|
469
|
+
updated_chat[-1].content += critique_str
|
411
470
|
# if plan was critiqued, ensure we don't finish so we can
|
412
471
|
# respond to the critique
|
413
472
|
finished = False
|
414
473
|
|
415
474
|
critque_steps += 1
|
416
475
|
step -= 1
|
417
|
-
updated_chat[-1]
|
476
|
+
updated_chat[-1].content += f"\n<count>{step}</count>\n"
|
418
477
|
int_chat.extend(updated_chat)
|
419
478
|
for chat_elt in updated_chat:
|
420
|
-
self.update_callback(chat_elt)
|
479
|
+
self.update_callback(chat_elt.model_dump())
|
421
480
|
|
422
481
|
updated_chat, plan_context = create_finalize_plan(
|
423
482
|
int_chat, self.planner, self.verbose
|
424
483
|
)
|
425
484
|
int_chat.extend(updated_chat)
|
426
485
|
for chat_elt in updated_chat:
|
427
|
-
self.update_callback(chat_elt)
|
486
|
+
self.update_callback(chat_elt.model_dump())
|
428
487
|
|
429
488
|
return plan_context
|
430
489
|
|
@@ -55,10 +55,10 @@ generate_vision_code(artifacts, 'dog_detector.py', 'Can you write code to detect
|
|
55
55
|
|
56
56
|
OBSERVATION:
|
57
57
|
[Artifact dog_detector.py (5 lines total)]
|
58
|
-
0|from vision_agent.tools import load_image,
|
58
|
+
0|from vision_agent.tools import load_image, owl_v2_image
|
59
59
|
1|def detect_dogs(image_path: str):
|
60
60
|
2| image = load_image(image_path)
|
61
|
-
3| dogs =
|
61
|
+
3| dogs = owl_v2_image("dog", image)
|
62
62
|
4| return dogs
|
63
63
|
[End of artifact]
|
64
64
|
|
@@ -96,10 +96,10 @@ edit_vision_code(artifacts, 'dog_detector.py', ['Can you write code to detect do
|
|
96
96
|
|
97
97
|
OBSERVATION:
|
98
98
|
[Artifact dog_detector.py (5 lines total)]
|
99
|
-
0|from vision_agent.tools import load_image,
|
99
|
+
0|from vision_agent.tools import load_image, owl_v2_image
|
100
100
|
1|def detect_dogs(image_path: str):
|
101
101
|
2| image = load_image(image_path)
|
102
|
-
3| dogs =
|
102
|
+
3| dogs = owl_v2_image("dog", image, threshold=0.24)
|
103
103
|
4| return dogs
|
104
104
|
[End of artifact]
|
105
105
|
|