vision-agent 0.2.199__py3-none-any.whl → 0.2.201__py3-none-any.whl
- vision_agent/agent/__init__.py +2 -1
- vision_agent/agent/agent.py +33 -0
- vision_agent/agent/agent_utils.py +47 -34
- vision_agent/agent/types.py +51 -0
- vision_agent/agent/vision_agent.py +20 -77
- vision_agent/agent/vision_agent_coder.py +0 -6
- vision_agent/agent/vision_agent_coder_v2.py +131 -43
- vision_agent/agent/vision_agent_planner.py +0 -6
- vision_agent/agent/vision_agent_planner_prompts_v2.py +1 -1
- vision_agent/agent/vision_agent_planner_v2.py +109 -50
- vision_agent/agent/vision_agent_prompts.py +4 -4
- vision_agent/agent/vision_agent_prompts_v2.py +46 -0
- vision_agent/agent/vision_agent_v2.py +215 -0
- vision_agent/tools/meta_tools.py +18 -94
- vision_agent/utils/execute.py +1 -1
- {vision_agent-0.2.199.dist-info → vision_agent-0.2.201.dist-info}/METADATA +1 -1
- {vision_agent-0.2.199.dist-info → vision_agent-0.2.201.dist-info}/RECORD +19 -16
- {vision_agent-0.2.199.dist-info → vision_agent-0.2.201.dist-info}/LICENSE +0 -0
- {vision_agent-0.2.199.dist-info → vision_agent-0.2.201.dist-info}/WHEEL +0 -0
vision_agent/agent/vision_agent_coder_v2.py

@@ -6,19 +6,19 @@ from rich.console import Console
 from rich.markup import escape
 
 import vision_agent.tools as T
-from vision_agent.agent import
+from vision_agent.agent import AgentCoder, AgentPlanner
 from vision_agent.agent.agent_utils import (
-    CodeContext,
     DefaultImports,
-    PlanContext,
     add_media_to_chat,
     capture_media_from_exec,
+    convert_message_to_agentmessage,
     extract_tag,
     format_feedback,
     format_plan_v2,
     print_code,
     strip_function_calls,
 )
+from vision_agent.agent.types import AgentMessage, CodeContext, PlanContext
 from vision_agent.agent.vision_agent_coder_prompts_v2 import CODE, FIX_BUG, TEST
 from vision_agent.agent.vision_agent_planner_v2 import VisionAgentPlannerV2
 from vision_agent.lmm import LMM, AnthropicLMM

@@ -34,6 +34,12 @@ from vision_agent.utils.sim import Sim, load_cached_sim
 _CONSOLE = Console()
 
 
+def format_code_context(
+    code_context: CodeContext,
+) -> str:
+    return f"<final_code>{code_context.code}</final_code>\n<final_test>{code_context.test}</final_test>"
+
+
 def retrieve_tools(
     plan: List[str],
     tool_recommender: Sim,

@@ -49,46 +55,54 @@ def retrieve_tools(
 
 def write_code(
     coder: LMM,
-    chat: List[
+    chat: List[AgentMessage],
     tool_docs: str,
     plan: str,
 ) -> str:
     chat = copy.deepcopy(chat)
-    if chat[-1]
+    if chat[-1].role != "user":
         raise ValueError("Last chat message must be from the user.")
 
-    user_request = chat[-1]
+    user_request = chat[-1].content
     prompt = CODE.format(
         docstring=tool_docs,
         question=user_request,
         plan=plan,
     )
-
-
-
+    response = cast(str, coder([{"role": "user", "content": prompt}], stream=False))
+    maybe_code = extract_tag(response, "code")
+
+    # if the response wasn't properly formatted with the code tags just retrun the response
+    if maybe_code is None:
+        return response
+    return maybe_code
 
 
 def write_test(
     tester: LMM,
-    chat: List[
+    chat: List[AgentMessage],
     tool_util_docs: str,
     code: str,
     media_list: Optional[Sequence[Union[str, Path]]] = None,
 ) -> str:
     chat = copy.deepcopy(chat)
-    if chat[-1]
+    if chat[-1].role != "user":
        raise ValueError("Last chat message must be from the user.")
 
-    user_request = chat[-1]
+    user_request = chat[-1].content
     prompt = TEST.format(
         docstring=tool_util_docs,
         question=user_request,
         code=code,
         media=media_list,
     )
-
-
-
+    response = cast(str, tester([{"role": "user", "content": prompt}], stream=False))
+    maybe_code = extract_tag(response, "code")
+
+    # if the response wasn't properly formatted with the code tags just retrun the response
+    if maybe_code is None:
+        return response
+    return maybe_code
 
 
 def debug_code(

@@ -170,12 +184,11 @@ def write_and_test_code(
     coder: LMM,
     tester: LMM,
     debugger: LMM,
-    chat: List[
+    chat: List[AgentMessage],
     plan: str,
     tool_docs: str,
     code_interpreter: CodeInterpreter,
     media_list: List[Union[str, Path]],
-    update_callback: Callable[[Dict[str, Any]], None],
     verbose: bool,
 ) -> CodeContext:
     code = write_code(

@@ -226,14 +239,6 @@ def write_and_test_code(
             f"[bold cyan]Code execution result after attempted fix:[/bold cyan] [yellow]{escape(result.text(include_logs=True))}[/yellow]"
         )
 
-    update_callback(
-        {
-            "role": "assistant",
-            "content": f"<final_code>{DefaultImports.to_code_string()}\n{code}</final_code>\n<final_test>{DefaultImports.to_code_string()}\n{test}</final_test>",
-            "media": capture_media_from_exec(result),
-        }
-    )
-
     return CodeContext(
         code=f"{DefaultImports.to_code_string()}\n{code}",
         test=f"{DefaultImports.to_code_string()}\n{test}",

@@ -242,10 +247,12 @@ def write_and_test_code(
     )
 
 
-class VisionAgentCoderV2(
+class VisionAgentCoderV2(AgentCoder):
+    """VisionAgentCoderV2 is an agent that will write vision code for you."""
+
     def __init__(
         self,
-        planner: Optional[
+        planner: Optional[AgentPlanner] = None,
         coder: Optional[LMM] = None,
         tester: Optional[LMM] = None,
         debugger: Optional[LMM] = None,

@@ -254,6 +261,25 @@ class VisionAgentCoderV2(Agent):
         code_sandbox_runtime: Optional[str] = None,
         update_callback: Callable[[Dict[str, Any]], None] = lambda _: None,
     ) -> None:
+        """Initialize the VisionAgentCoderV2.
+
+        Parameters:
+            planner (Optional[AgentPlanner]): The planner agent to use for generating
+                vision plans. If None, a default VisionAgentPlannerV2 will be used.
+            coder (Optional[LMM]): The language model to use for the coder agent. If
+                None, a default AnthropicLMM will be used.
+            tester (Optional[LMM]): The language model to use for the tester agent. If
+                None, a default AnthropicLMM will be used.
+            debugger (Optional[LMM]): The language model to use for the debugger agent.
+            tool_recommender (Optional[Union[str, Sim]]): The tool recommender to use.
+            verbose (bool): Whether to print out debug information.
+            code_sandbox_runtime (Optional[str]): The code sandbox runtime to use, can
+                be one of: None, "local" or "e2b". If None, it will read from the
+                environment variable CODE_SANDBOX_RUNTIME.
+            update_callback (Callable[[Dict[str, Any]], None]): The callback function
+                that will send back intermediate conversation messages.
+        """
+
         self.planner = (
             planner
             if planner is not None

@@ -290,20 +316,52 @@ class VisionAgentCoderV2(Agent):
         self,
         input: Union[str, List[Message]],
         media: Optional[Union[str, Path]] = None,
-    ) ->
-
-
-
-        input[
-
-
-
+    ) -> str:
+        """Generate vision code from a conversation.
+
+        Parameters:
+            input (Union[str, List[Message]]): The input to the agent. This can be a
+                string or a list of messages in the format of [{"role": "user",
+                "content": "describe your task here..."}, ...].
+            media (Optional[Union[str, Path]]): The path to the media file to use with
+                the input. This can be an image or video file.
+
+        Returns:
+            str: The generated code as a string.
+        """
+
+        input_msg = convert_message_to_agentmessage(input, media)
+        return self.generate_code(input_msg).code
+
+    def generate_code(
+        self,
+        chat: List[AgentMessage],
+        max_steps: Optional[int] = None,
+        code_interpreter: Optional[CodeInterpreter] = None,
+    ) -> CodeContext:
+        """Generate vision code from a conversation.
+
+        Parameters:
+            chat (List[AgentMessage]): The input to the agent. This should be a list of
+                AgentMessage objects.
+            code_interpreter (Optional[CodeInterpreter]): The code interpreter to use.
+
+        Returns:
+            CodeContext: The generated code as a CodeContext object which includes the
+                code, test code, whether or not it was exceuted successfully, and the
+                execution result.
+        """
+
         chat = copy.deepcopy(chat)
-        with
-            self.code_sandbox_runtime
+        with (
+            CodeInterpreterFactory.new_instance(self.code_sandbox_runtime)
+            if code_interpreter is None
+            else code_interpreter
         ) as code_interpreter:
             int_chat, orig_chat, _ = add_media_to_chat(chat, code_interpreter)
-            plan_context = self.planner.generate_plan(
+            plan_context = self.planner.generate_plan(
+                int_chat, max_steps=max_steps, code_interpreter=code_interpreter
+            )
             code_context = self.generate_code_from_plan(
                 orig_chat,
                 plan_context,

@@ -313,13 +371,30 @@ class VisionAgentCoderV2(Agent):
 
     def generate_code_from_plan(
         self,
-        chat: List[
+        chat: List[AgentMessage],
         plan_context: PlanContext,
         code_interpreter: Optional[CodeInterpreter] = None,
     ) -> CodeContext:
+        """Generate vision code from a conversation and a previously made plan. This
+        will skip the planning step and go straight to generating code.
+
+        Parameters:
+            chat (List[AgentMessage]): The input to the agent. This should be a list of
+                AgentMessage objects.
+            plan_context (PlanContext): The plan context that was previously generated.
+            code_interpreter (Optional[CodeInterpreter]): The code interpreter to use.
+
+        Returns:
+            CodeContext: The generated code as a CodeContext object which includes the
+                code, test code, whether or not it was exceuted successfully, and the
+                execution result.
+        """
+
         chat = copy.deepcopy(chat)
-        with
-            self.code_sandbox_runtime
+        with (
+            CodeInterpreterFactory.new_instance(self.code_sandbox_runtime)
+            if code_interpreter is None
+            else code_interpreter
         ) as code_interpreter:
             int_chat, _, media_list = add_media_to_chat(chat, code_interpreter)
             tool_docs = retrieve_tools(plan_context.instructions, self.tool_recommender)

@@ -331,10 +406,23 @@ class VisionAgentCoderV2(Agent):
                 plan=format_plan_v2(plan_context),
                 tool_docs=tool_docs,
                 code_interpreter=code_interpreter,
-                media_list=media_list,
-                update_callback=self.update_callback,
+                media_list=media_list,
                 verbose=self.verbose,
             )
+
+            self.update_callback(
+                {
+                    "role": "coder",
+                    "content": format_code_context(code_context),
+                    "media": capture_media_from_exec(code_context.test_result),
+                }
+            )
+            self.update_callback(
+                {
+                    "role": "observation",
+                    "content": code_context.test_result.text(),
+                }
+            )
             return code_context
 
     def log_progress(self, data: Dict[str, Any]) -> None:
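
Taken together, the coder changes swap the loosely typed chat dicts for AgentMessage objects and split the public surface into __call__ (string in, code string out) and generate_code (AgentMessage list in, CodeContext out). The sketch below is an illustration inferred from the signatures and docstrings in this diff, not code shipped with the wheel; the VisionAgentCoderV2 import path and the AgentMessage constructor arguments are assumptions based on how they are used above.

from vision_agent.agent.types import AgentMessage
from vision_agent.agent.vision_agent_coder_v2 import VisionAgentCoderV2  # assumed import path

coder = VisionAgentCoderV2(verbose=True)

# High-level call: a plain prompt plus optional media, returns only the generated code string.
code = coder("Count the dogs in the image", media="dogs.jpg")

# Lower-level call: AgentMessage objects in, CodeContext out (code, test, test_result).
chat = [AgentMessage(role="user", content="Count the dogs in the image", media=["dogs.jpg"])]
code_context = coder.generate_code(chat, max_steps=10)
print(code_context.code)
print(code_context.test)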
vision_agent/agent/vision_agent_planner.py

@@ -391,12 +391,6 @@ class VisionAgentPlanner(Agent):
         for chat_i in chat:
             if "media" in chat_i:
                 for media in chat_i["media"]:
-                    media = (
-                        media
-                        if type(media) is str
-                        and media.startswith(("http", "https"))
-                        else code_interpreter.upload_file(cast(str, media))
-                    )
                     chat_i["content"] += f" Media name {media}"  # type: ignore
                     media_list.append(str(media))
 
vision_agent/agent/vision_agent_planner_v2.py

@@ -1,5 +1,6 @@
 import copy
 import logging
+import time
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast

@@ -10,16 +11,17 @@ from rich.markup import escape
 
 import vision_agent.tools as T
 import vision_agent.tools.planner_tools as pt
-from vision_agent.agent import
+from vision_agent.agent import AgentPlanner
 from vision_agent.agent.agent_utils import (
-    PlanContext,
     add_media_to_chat,
     capture_media_from_exec,
+    convert_message_to_agentmessage,
     extract_json,
     extract_tag,
     print_code,
     print_table,
 )
+from vision_agent.agent.types import AgentMessage, PlanContext
 from vision_agent.agent.vision_agent_planner_prompts_v2 import (
     CRITIQUE_PLAN,
     EXAMPLE_PLAN1,

@@ -70,26 +72,24 @@ class DefaultPlanningImports:
 
 
 def get_planning(
-    chat: List[
+    chat: List[AgentMessage],
 ) -> str:
     chat = copy.deepcopy(chat)
     planning = ""
     for chat_i in chat:
-        if chat_i
-            planning += f"USER: {chat_i
-        elif chat_i
-            planning += f"OBSERVATION: {chat_i
-        elif chat_i
-            planning += f"
-        else:
-            raise ValueError(f"Unknown role: {chat_i['role']}")
+        if chat_i.role == "user":
+            planning += f"USER: {chat_i.content}\n\n"
+        elif chat_i.role == "observation":
+            planning += f"OBSERVATION: {chat_i.content}\n\n"
+        elif chat_i.role == "planner":
+            planning += f"AGENT: {chat_i.content}\n\n"
 
     return planning
 
 
 def run_planning(
-    chat: List[
-    media_list: List[str],
+    chat: List[AgentMessage],
+    media_list: List[Union[str, Path]],
     model: LMM,
 ) -> str:
     # only keep last 10 messages for planning

@@ -102,16 +102,16 @@ def run_planning(
     )
 
     message: Message = {"role": "user", "content": prompt}
-    if chat[-1]
-        message["media"] = chat[-1]
+    if chat[-1].role == "observation" and chat[-1].media is not None:
+        message["media"] = chat[-1].media
 
     response = model.chat([message])
     return cast(str, response)
 
 
 def run_multi_trial_planning(
-    chat: List[
-    media_list: List[str],
+    chat: List[AgentMessage],
+    media_list: List[Union[str, Path]],
     model: LMM,
 ) -> str:
     planning = get_planning(chat)

@@ -123,8 +123,8 @@ def run_multi_trial_planning(
     )
 
     message: Message = {"role": "user", "content": prompt}
-    if chat[-1]
-        message["media"] = chat[-1]
+    if chat[-1].role == "observation" and chat[-1].media is not None:
+        message["media"] = chat[-1].media
 
     responses = []
     with ThreadPoolExecutor() as executor:

@@ -151,7 +151,9 @@ def run_multi_trial_planning(
     return cast(str, responses[0])
 
 
-def run_critic(
+def run_critic(
+    chat: List[AgentMessage], media_list: List[Union[str, Path]], model: LMM
+) -> Optional[str]:
     planning = get_planning(chat)
     prompt = CRITIQUE_PLAN.format(
         planning=planning,

@@ -196,17 +198,19 @@ def response_safeguards(response: str) -> str:
 def execute_code_action(
     code: str,
     code_interpreter: CodeInterpreter,
-    chat: List[
+    chat: List[AgentMessage],
     model: LMM,
     verbose: bool = False,
 ) -> Tuple[Execution, str, str]:
     if verbose:
         print_code("Code to Execute:", code)
+    start = time.time()
     execution = code_interpreter.exec_cell(DefaultPlanningImports.prepend_imports(code))
+    end = time.time()
     obs = execution.text(include_results=False).strip()
     if verbose:
         _CONSOLE.print(
-            f"[bold cyan]Code Execution Output:[/bold cyan] [yellow]{escape(obs)}[/yellow]"
+            f"[bold cyan]Code Execution Output ({end - start:.2f} sec):[/bold cyan] [yellow]{escape(obs)}[/yellow]"
         )
 
     count = 1

@@ -246,13 +250,13 @@ def find_and_replace_code(response: str, code: str) -> str:
 def maybe_run_code(
     code: Optional[str],
     response: str,
-    chat: List[
-    media_list: List[str],
+    chat: List[AgentMessage],
+    media_list: List[Union[str, Path]],
     model: LMM,
     code_interpreter: CodeInterpreter,
     verbose: bool = False,
-) -> List[
-    return_chat: List[
+) -> List[AgentMessage]:
+    return_chat: List[AgentMessage] = []
     if code is not None:
         code = code_safeguards(code)
         execution, obs, code = execute_code_action(

@@ -262,30 +266,32 @@ def maybe_run_code(
         # if we had to debug the code to fix an issue, replace the old code
         # with the fixed code in the response
         fixed_response = find_and_replace_code(response, code)
-        return_chat.append(
+        return_chat.append(
+            AgentMessage(role="planner", content=fixed_response, media=None)
+        )
 
         media_data = capture_media_from_exec(execution)
-        int_chat_elt
+        int_chat_elt = AgentMessage(role="observation", content=obs, media=None)
         if media_list:
-            int_chat_elt
+            int_chat_elt.media = cast(List[Union[str, Path]], media_data)
         return_chat.append(int_chat_elt)
     else:
-        return_chat.append(
+        return_chat.append(AgentMessage(role="planner", content=response, media=None))
     return return_chat
 
 
 def create_finalize_plan(
-    chat: List[
+    chat: List[AgentMessage],
     model: LMM,
     verbose: bool = False,
-) -> Tuple[List[
+) -> Tuple[List[AgentMessage], PlanContext]:
     prompt = FINALIZE_PLAN.format(
         planning=get_planning(chat),
         excluded_tools=str([t.__name__ for t in pt.PLANNER_TOOLS]),
     )
     response = model.chat([{"role": "user", "content": prompt}])
     plan_str = cast(str, response)
-    return_chat
+    return_chat = [AgentMessage(role="planner", content=plan_str, media=None)]
 
     plan_json = extract_tag(plan_str, "json")
     plan = (

@@ -305,7 +311,16 @@ def create_finalize_plan(
     return return_chat, PlanContext(**plan)
 
 
-
+def get_steps(chat: List[AgentMessage], max_steps: int) -> int:
+    for chat_elt in reversed(chat):
+        if "<count>" in chat_elt.content:
+            return int(extract_tag(chat_elt.content, "count"))  # type: ignore
+    return max_steps
+
+
+class VisionAgentPlannerV2(AgentPlanner):
+    """VisionAgentPlannerV2 is a class that generates a plan to solve a vision task."""
+
     def __init__(
         self,
         planner: Optional[LMM] = None,

@@ -317,6 +332,25 @@ class VisionAgentPlannerV2(Agent):
         code_sandbox_runtime: Optional[str] = None,
         update_callback: Callable[[Dict[str, Any]], None] = lambda _: None,
     ) -> None:
+        """Initialize the VisionAgentPlannerV2.
+
+        Parameters:
+            planner (Optional[LMM]): The language model to use for planning. If None, a
+                default AnthropicLMM will be used.
+            critic (Optional[LMM]): The language model to use for critiquing the plan.
+                If None, a default AnthropicLMM will be used.
+            max_steps (int): The maximum number of steps to plan.
+            use_multi_trial_planning (bool): Whether to use multi-trial planning.
+            critique_steps (int): The number of steps between critiques. If critic steps
+                is larger than max_steps no critiques will be made.
+            verbose (bool): Whether to print out debug information.
+            code_sandbox_runtime (Optional[str]): The code sandbox runtime to use, can
+                be one of: None, "local" or "e2b". If None, it will read from the
+                environment variable CODE_SANDBOX_RUNTIME.
+            update_callback (Callable[[Dict[str, Any]], None]): The callback function
+                that will send back intermediate conversation messages.
+        """
+
         self.planner = (
             planner
             if planner is not None

@@ -339,20 +373,42 @@ class VisionAgentPlannerV2(Agent):
         self,
         input: Union[str, List[Message]],
         media: Optional[Union[str, Path]] = None,
-    ) ->
-
-
-
-
-
-
+    ) -> str:
+        """Generate a plan to solve a vision task.
+
+        Parameters:
+            input (Union[str, List[Message]]): The input to the agent. This can be a
+                string or a list of messages in the format of [{"role": "user",
+                "content": "describe your task here..."}, ...].
+            media (Optional[Union[str, Path]]): The path to the media file to use with
+                the input. This can be an image or video file.
+
+        Returns:
+            str: The generated plan as a string.
+        """
+
+        input_msg = convert_message_to_agentmessage(input, media)
+        plan = self.generate_plan(input_msg)
+        return plan.plan
 
     def generate_plan(
         self,
-        chat: List[
+        chat: List[AgentMessage],
+        max_steps: Optional[int] = None,
         code_interpreter: Optional[CodeInterpreter] = None,
     ) -> PlanContext:
+        """Generate a plan to solve a vision task.
+
+        Parameters:
+            chat (List[AgentMessage]): The conversation messages to generate a plan for.
+            max_steps (Optional[int]): The maximum number of steps to plan.
+            code_interpreter (Optional[CodeInterpreter]): The code interpreter to use.
+
+        Returns:
+            PlanContext: The generated plan including the instructions and code snippets
+                needed to solve the task.
+        """
+
         if not chat:
             raise ValueError("Chat cannot be empty")
 

@@ -360,13 +416,16 @@ class VisionAgentPlannerV2(Agent):
         code_interpreter = code_interpreter or CodeInterpreterFactory.new_instance(
             self.code_sandbox_runtime
         )
+        max_steps = max_steps or self.max_steps
 
         with code_interpreter:
             critque_steps = 1
-            step = self.max_steps
             finished = False
             int_chat, _, media_list = add_media_to_chat(chat, code_interpreter)
-
+
+            step = get_steps(int_chat, max_steps)
+            if "<count>" not in int_chat[-1].content and step == max_steps:
+                int_chat[-1].content += f"\n<count>{step}</count>\n"
             while step > 0 and not finished:
                 if self.use_multi_trial_planning:
                     response = run_multi_trial_planning(

@@ -402,29 +461,29 @@ class VisionAgentPlannerV2(Agent):
 
                 if critque_steps % self.critique_steps == 0:
                     critique = run_critic(int_chat, media_list, self.critic)
-                    if critique is not None and int_chat[-1]
+                    if critique is not None and int_chat[-1].role == "observation":
                         _CONSOLE.print(
                             f"[bold cyan]Critique:[/bold cyan] [red]{critique}[/red]"
                         )
                         critique_str = f"\n[critique]\n{critique}\n[end of critique]"
-                        updated_chat[-1]
+                        updated_chat[-1].content += critique_str
                         # if plan was critiqued, ensure we don't finish so we can
                         # respond to the critique
                         finished = False
 
                 critque_steps += 1
                 step -= 1
-                updated_chat[-1]
+                updated_chat[-1].content += f"\n<count>{step}</count>\n"
                 int_chat.extend(updated_chat)
                 for chat_elt in updated_chat:
-                    self.update_callback(chat_elt)
+                    self.update_callback(chat_elt.model_dump())
 
             updated_chat, plan_context = create_finalize_plan(
                 int_chat, self.planner, self.verbose
             )
             int_chat.extend(updated_chat)
             for chat_elt in updated_chat:
-                self.update_callback(chat_elt)
+                self.update_callback(chat_elt.model_dump())
 
             return plan_context
 
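
The planner changes mirror the coder: generate_plan now takes typed AgentMessage chats plus an optional max_steps, and intermediate messages reach update_callback as AgentMessage.model_dump() dicts rather than raw objects. A minimal usage sketch, assuming the import path and that PlanContext exposes the plan and instructions fields used elsewhere in this diff:

from typing import Any, Dict

from vision_agent.agent.types import AgentMessage
from vision_agent.agent.vision_agent_planner_v2 import VisionAgentPlannerV2  # assumed import path

def on_update(message: Dict[str, Any]) -> None:
    # chat elements arrive as model_dump() dicts with role, content, and media keys
    print(message["role"], str(message["content"])[:80])

planner = VisionAgentPlannerV2(verbose=True, update_callback=on_update)
chat = [AgentMessage(role="user", content="Detect the dogs in dogs.jpg", media=["dogs.jpg"])]
plan_context = planner.generate_plan(chat, max_steps=5)
print(plan_context.plan)
print(plan_context.instructions)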
vision_agent/agent/vision_agent_prompts.py

@@ -55,10 +55,10 @@ generate_vision_code(artifacts, 'dog_detector.py', 'Can you write code to detect
 
 OBSERVATION:
 [Artifact dog_detector.py (5 lines total)]
-0|from vision_agent.tools import load_image,
+0|from vision_agent.tools import load_image, owl_v2_image
 1|def detect_dogs(image_path: str):
 2| image = load_image(image_path)
-3| dogs =
+3| dogs = owl_v2_image("dog", image)
 4| return dogs
 [End of artifact]
 

@@ -96,10 +96,10 @@ edit_vision_code(artifacts, 'dog_detector.py', ['Can you write code to detect do
 
 OBSERVATION:
 [Artifact dog_detector.py (5 lines total)]
-0|from vision_agent.tools import load_image,
+0|from vision_agent.tools import load_image, owl_v2_image
 1|def detect_dogs(image_path: str):
 2| image = load_image(image_path)
-3| dogs =
+3| dogs = owl_v2_image("dog", image, threshold=0.24)
 4| return dogs
 [End of artifact]
 
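
For reference, the artifact the corrected prompt examples describe corresponds roughly to the following standalone script; the detection output format of owl_v2_image is not shown in this diff and is assumed to be a list of detections:

from vision_agent.tools import load_image, owl_v2_image

def detect_dogs(image_path: str):
    image = load_image(image_path)
    # the threshold can be tightened as in the edit_vision_code example (threshold=0.24)
    dogs = owl_v2_image("dog", image)
    return dogs

print(detect_dogs("dogs.jpg"))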