vision-agent 0.2.199__py3-none-any.whl → 0.2.201__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -6,19 +6,19 @@ from rich.console import Console
6
6
  from rich.markup import escape
7
7
 
8
8
  import vision_agent.tools as T
9
- from vision_agent.agent import Agent
9
+ from vision_agent.agent import AgentCoder, AgentPlanner
10
10
  from vision_agent.agent.agent_utils import (
11
- CodeContext,
12
11
  DefaultImports,
13
- PlanContext,
14
12
  add_media_to_chat,
15
13
  capture_media_from_exec,
14
+ convert_message_to_agentmessage,
16
15
  extract_tag,
17
16
  format_feedback,
18
17
  format_plan_v2,
19
18
  print_code,
20
19
  strip_function_calls,
21
20
  )
21
+ from vision_agent.agent.types import AgentMessage, CodeContext, PlanContext
22
22
  from vision_agent.agent.vision_agent_coder_prompts_v2 import CODE, FIX_BUG, TEST
23
23
  from vision_agent.agent.vision_agent_planner_v2 import VisionAgentPlannerV2
24
24
  from vision_agent.lmm import LMM, AnthropicLMM
@@ -34,6 +34,12 @@ from vision_agent.utils.sim import Sim, load_cached_sim
34
34
  _CONSOLE = Console()
35
35
 
36
36
 
37
+ def format_code_context(
38
+ code_context: CodeContext,
39
+ ) -> str:
40
+ return f"<final_code>{code_context.code}</final_code>\n<final_test>{code_context.test}</final_test>"
41
+
42
+
37
43
  def retrieve_tools(
38
44
  plan: List[str],
39
45
  tool_recommender: Sim,
@@ -49,46 +55,54 @@ def retrieve_tools(
49
55
 
50
56
  def write_code(
51
57
  coder: LMM,
52
- chat: List[Message],
58
+ chat: List[AgentMessage],
53
59
  tool_docs: str,
54
60
  plan: str,
55
61
  ) -> str:
56
62
  chat = copy.deepcopy(chat)
57
- if chat[-1]["role"] != "user":
63
+ if chat[-1].role != "user":
58
64
  raise ValueError("Last chat message must be from the user.")
59
65
 
60
- user_request = chat[-1]["content"]
66
+ user_request = chat[-1].content
61
67
  prompt = CODE.format(
62
68
  docstring=tool_docs,
63
69
  question=user_request,
64
70
  plan=plan,
65
71
  )
66
- chat[-1]["content"] = prompt
67
- response = coder(chat, stream=False)
68
- return extract_tag(response, "code") # type: ignore
72
+ response = cast(str, coder([{"role": "user", "content": prompt}], stream=False))
73
+ maybe_code = extract_tag(response, "code")
74
+
75
+ # if the response wasn't properly formatted with the code tags just return the response
76
+ if maybe_code is None:
77
+ return response
78
+ return maybe_code
69
79
 
70
80
 
71
81
  def write_test(
72
82
  tester: LMM,
73
- chat: List[Message],
83
+ chat: List[AgentMessage],
74
84
  tool_util_docs: str,
75
85
  code: str,
76
86
  media_list: Optional[Sequence[Union[str, Path]]] = None,
77
87
  ) -> str:
78
88
  chat = copy.deepcopy(chat)
79
- if chat[-1]["role"] != "user":
89
+ if chat[-1].role != "user":
80
90
  raise ValueError("Last chat message must be from the user.")
81
91
 
82
- user_request = chat[-1]["content"]
92
+ user_request = chat[-1].content
83
93
  prompt = TEST.format(
84
94
  docstring=tool_util_docs,
85
95
  question=user_request,
86
96
  code=code,
87
97
  media=media_list,
88
98
  )
89
- chat[-1]["content"] = prompt
90
- response = tester(chat, stream=False)
91
- return extract_tag(response, "code") # type: ignore
99
+ response = cast(str, tester([{"role": "user", "content": prompt}], stream=False))
100
+ maybe_code = extract_tag(response, "code")
101
+
102
+ # if the response wasn't properly formatted with the code tags just return the response
103
+ if maybe_code is None:
104
+ return response
105
+ return maybe_code
92
106
 
93
107
 
94
108
  def debug_code(
@@ -170,12 +184,11 @@ def write_and_test_code(
170
184
  coder: LMM,
171
185
  tester: LMM,
172
186
  debugger: LMM,
173
- chat: List[Message],
187
+ chat: List[AgentMessage],
174
188
  plan: str,
175
189
  tool_docs: str,
176
190
  code_interpreter: CodeInterpreter,
177
191
  media_list: List[Union[str, Path]],
178
- update_callback: Callable[[Dict[str, Any]], None],
179
192
  verbose: bool,
180
193
  ) -> CodeContext:
181
194
  code = write_code(
@@ -226,14 +239,6 @@ def write_and_test_code(
226
239
  f"[bold cyan]Code execution result after attempted fix:[/bold cyan] [yellow]{escape(result.text(include_logs=True))}[/yellow]"
227
240
  )
228
241
 
229
- update_callback(
230
- {
231
- "role": "assistant",
232
- "content": f"<final_code>{DefaultImports.to_code_string()}\n{code}</final_code>\n<final_test>{DefaultImports.to_code_string()}\n{test}</final_test>",
233
- "media": capture_media_from_exec(result),
234
- }
235
- )
236
-
237
242
  return CodeContext(
238
243
  code=f"{DefaultImports.to_code_string()}\n{code}",
239
244
  test=f"{DefaultImports.to_code_string()}\n{test}",
@@ -242,10 +247,12 @@ def write_and_test_code(
242
247
  )
243
248
 
244
249
 
245
- class VisionAgentCoderV2(Agent):
250
+ class VisionAgentCoderV2(AgentCoder):
251
+ """VisionAgentCoderV2 is an agent that will write vision code for you."""
252
+
246
253
  def __init__(
247
254
  self,
248
- planner: Optional[Agent] = None,
255
+ planner: Optional[AgentPlanner] = None,
249
256
  coder: Optional[LMM] = None,
250
257
  tester: Optional[LMM] = None,
251
258
  debugger: Optional[LMM] = None,
@@ -254,6 +261,25 @@ class VisionAgentCoderV2(Agent):
254
261
  code_sandbox_runtime: Optional[str] = None,
255
262
  update_callback: Callable[[Dict[str, Any]], None] = lambda _: None,
256
263
  ) -> None:
264
+ """Initialize the VisionAgentCoderV2.
265
+
266
+ Parameters:
267
+ planner (Optional[AgentPlanner]): The planner agent to use for generating
268
+ vision plans. If None, a default VisionAgentPlannerV2 will be used.
269
+ coder (Optional[LMM]): The language model to use for the coder agent. If
270
+ None, a default AnthropicLMM will be used.
271
+ tester (Optional[LMM]): The language model to use for the tester agent. If
272
+ None, a default AnthropicLMM will be used.
273
+ debugger (Optional[LMM]): The language model to use for the debugger agent.
274
+ tool_recommender (Optional[Union[str, Sim]]): The tool recommender to use.
275
+ verbose (bool): Whether to print out debug information.
276
+ code_sandbox_runtime (Optional[str]): The code sandbox runtime to use, can
277
+ be one of: None, "local" or "e2b". If None, it will read from the
278
+ environment variable CODE_SANDBOX_RUNTIME.
279
+ update_callback (Callable[[Dict[str, Any]], None]): The callback function
280
+ that will send back intermediate conversation messages.
281
+ """
282
+
257
283
  self.planner = (
258
284
  planner
259
285
  if planner is not None
@@ -290,20 +316,52 @@ class VisionAgentCoderV2(Agent):
290
316
  self,
291
317
  input: Union[str, List[Message]],
292
318
  media: Optional[Union[str, Path]] = None,
293
- ) -> Union[str, List[Message]]:
294
- if isinstance(input, str):
295
- input = [{"role": "user", "content": input}]
296
- if media is not None:
297
- input[0]["media"] = [media]
298
- return self.generate_code(input).code
299
-
300
- def generate_code(self, chat: List[Message]) -> CodeContext:
319
+ ) -> str:
320
+ """Generate vision code from a conversation.
321
+
322
+ Parameters:
323
+ input (Union[str, List[Message]]): The input to the agent. This can be a
324
+ string or a list of messages in the format of [{"role": "user",
325
+ "content": "describe your task here..."}, ...].
326
+ media (Optional[Union[str, Path]]): The path to the media file to use with
327
+ the input. This can be an image or video file.
328
+
329
+ Returns:
330
+ str: The generated code as a string.
331
+ """
332
+
333
+ input_msg = convert_message_to_agentmessage(input, media)
334
+ return self.generate_code(input_msg).code
335
+
336
+ def generate_code(
337
+ self,
338
+ chat: List[AgentMessage],
339
+ max_steps: Optional[int] = None,
340
+ code_interpreter: Optional[CodeInterpreter] = None,
341
+ ) -> CodeContext:
342
+ """Generate vision code from a conversation.
343
+
344
+ Parameters:
345
+ chat (List[AgentMessage]): The input to the agent. This should be a list of
346
+ AgentMessage objects.
347
+ code_interpreter (Optional[CodeInterpreter]): The code interpreter to use.
348
+
349
+ Returns:
350
+ CodeContext: The generated code as a CodeContext object which includes the
351
+ code, test code, whether or not it was executed successfully, and the
352
+ execution result.
353
+ """
354
+
301
355
  chat = copy.deepcopy(chat)
302
- with CodeInterpreterFactory.new_instance(
303
- self.code_sandbox_runtime
356
+ with (
357
+ CodeInterpreterFactory.new_instance(self.code_sandbox_runtime)
358
+ if code_interpreter is None
359
+ else code_interpreter
304
360
  ) as code_interpreter:
305
361
  int_chat, orig_chat, _ = add_media_to_chat(chat, code_interpreter)
306
- plan_context = self.planner.generate_plan(int_chat, code_interpreter) # type: ignore
362
+ plan_context = self.planner.generate_plan(
363
+ int_chat, max_steps=max_steps, code_interpreter=code_interpreter
364
+ )
307
365
  code_context = self.generate_code_from_plan(
308
366
  orig_chat,
309
367
  plan_context,
@@ -313,13 +371,30 @@ class VisionAgentCoderV2(Agent):
313
371
 
314
372
  def generate_code_from_plan(
315
373
  self,
316
- chat: List[Message],
374
+ chat: List[AgentMessage],
317
375
  plan_context: PlanContext,
318
376
  code_interpreter: Optional[CodeInterpreter] = None,
319
377
  ) -> CodeContext:
378
+ """Generate vision code from a conversation and a previously made plan. This
379
+ will skip the planning step and go straight to generating code.
380
+
381
+ Parameters:
382
+ chat (List[AgentMessage]): The input to the agent. This should be a list of
383
+ AgentMessage objects.
384
+ plan_context (PlanContext): The plan context that was previously generated.
385
+ code_interpreter (Optional[CodeInterpreter]): The code interpreter to use.
386
+
387
+ Returns:
388
+ CodeContext: The generated code as a CodeContext object which includes the
389
+ code, test code, whether or not it was exceuted successfully, and the
390
+ execution result.
391
+ """
392
+
320
393
  chat = copy.deepcopy(chat)
321
- with CodeInterpreterFactory.new_instance(
322
- self.code_sandbox_runtime
394
+ with (
395
+ CodeInterpreterFactory.new_instance(self.code_sandbox_runtime)
396
+ if code_interpreter is None
397
+ else code_interpreter
323
398
  ) as code_interpreter:
324
399
  int_chat, _, media_list = add_media_to_chat(chat, code_interpreter)
325
400
  tool_docs = retrieve_tools(plan_context.instructions, self.tool_recommender)
@@ -331,10 +406,23 @@ class VisionAgentCoderV2(Agent):
331
406
  plan=format_plan_v2(plan_context),
332
407
  tool_docs=tool_docs,
333
408
  code_interpreter=code_interpreter,
334
- media_list=media_list, # type: ignore
335
- update_callback=self.update_callback,
409
+ media_list=media_list,
336
410
  verbose=self.verbose,
337
411
  )
412
+
413
+ self.update_callback(
414
+ {
415
+ "role": "coder",
416
+ "content": format_code_context(code_context),
417
+ "media": capture_media_from_exec(code_context.test_result),
418
+ }
419
+ )
420
+ self.update_callback(
421
+ {
422
+ "role": "observation",
423
+ "content": code_context.test_result.text(),
424
+ }
425
+ )
338
426
  return code_context
339
427
 
340
428
  def log_progress(self, data: Dict[str, Any]) -> None:
@@ -391,12 +391,6 @@ class VisionAgentPlanner(Agent):
391
391
  for chat_i in chat:
392
392
  if "media" in chat_i:
393
393
  for media in chat_i["media"]:
394
- media = (
395
- media
396
- if type(media) is str
397
- and media.startswith(("http", "https"))
398
- else code_interpreter.upload_file(cast(str, media))
399
- )
400
394
  chat_i["content"] += f" Media name {media}" # type: ignore
401
395
  media_list.append(str(media))
402
396
 
@@ -389,7 +389,7 @@ for infos in obj_to_info:
389
389
  print(f"{len(objects_with_tape)} boxes with tape found")
390
390
  </execute_python>
391
391
 
392
- OBJERVATION:
392
+ OBSERVATION:
393
393
  3 boxes were tracked
394
394
  2 boxes with tape found
395
395
  <count>6</count>
@@ -1,5 +1,6 @@
1
1
  import copy
2
2
  import logging
3
+ import time
3
4
  from concurrent.futures import ThreadPoolExecutor, as_completed
4
5
  from pathlib import Path
5
6
  from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
@@ -10,16 +11,17 @@ from rich.markup import escape
10
11
 
11
12
  import vision_agent.tools as T
12
13
  import vision_agent.tools.planner_tools as pt
13
- from vision_agent.agent import Agent
14
+ from vision_agent.agent import AgentPlanner
14
15
  from vision_agent.agent.agent_utils import (
15
- PlanContext,
16
16
  add_media_to_chat,
17
17
  capture_media_from_exec,
18
+ convert_message_to_agentmessage,
18
19
  extract_json,
19
20
  extract_tag,
20
21
  print_code,
21
22
  print_table,
22
23
  )
24
+ from vision_agent.agent.types import AgentMessage, PlanContext
23
25
  from vision_agent.agent.vision_agent_planner_prompts_v2 import (
24
26
  CRITIQUE_PLAN,
25
27
  EXAMPLE_PLAN1,
@@ -70,26 +72,24 @@ class DefaultPlanningImports:
70
72
 
71
73
 
72
74
  def get_planning(
73
- chat: List[Message],
75
+ chat: List[AgentMessage],
74
76
  ) -> str:
75
77
  chat = copy.deepcopy(chat)
76
78
  planning = ""
77
79
  for chat_i in chat:
78
- if chat_i["role"] == "user":
79
- planning += f"USER: {chat_i['content']}\n\n"
80
- elif chat_i["role"] == "observation":
81
- planning += f"OBSERVATION: {chat_i['content']}\n\n"
82
- elif chat_i["role"] == "assistant":
83
- planning += f"ASSISTANT: {chat_i['content']}\n\n"
84
- else:
85
- raise ValueError(f"Unknown role: {chat_i['role']}")
80
+ if chat_i.role == "user":
81
+ planning += f"USER: {chat_i.content}\n\n"
82
+ elif chat_i.role == "observation":
83
+ planning += f"OBSERVATION: {chat_i.content}\n\n"
84
+ elif chat_i.role == "planner":
85
+ planning += f"AGENT: {chat_i.content}\n\n"
86
86
 
87
87
  return planning
88
88
 
89
89
 
90
90
  def run_planning(
91
- chat: List[Message],
92
- media_list: List[str],
91
+ chat: List[AgentMessage],
92
+ media_list: List[Union[str, Path]],
93
93
  model: LMM,
94
94
  ) -> str:
95
95
  # only keep last 10 messages for planning
@@ -102,16 +102,16 @@ def run_planning(
102
102
  )
103
103
 
104
104
  message: Message = {"role": "user", "content": prompt}
105
- if chat[-1]["role"] == "observation" and "media" in chat[-1]:
106
- message["media"] = chat[-1]["media"]
105
+ if chat[-1].role == "observation" and chat[-1].media is not None:
106
+ message["media"] = chat[-1].media
107
107
 
108
108
  response = model.chat([message])
109
109
  return cast(str, response)
110
110
 
111
111
 
112
112
  def run_multi_trial_planning(
113
- chat: List[Message],
114
- media_list: List[str],
113
+ chat: List[AgentMessage],
114
+ media_list: List[Union[str, Path]],
115
115
  model: LMM,
116
116
  ) -> str:
117
117
  planning = get_planning(chat)
@@ -123,8 +123,8 @@ def run_multi_trial_planning(
123
123
  )
124
124
 
125
125
  message: Message = {"role": "user", "content": prompt}
126
- if chat[-1]["role"] == "observation" and "media" in chat[-1]:
127
- message["media"] = chat[-1]["media"]
126
+ if chat[-1].role == "observation" and chat[-1].media is not None:
127
+ message["media"] = chat[-1].media
128
128
 
129
129
  responses = []
130
130
  with ThreadPoolExecutor() as executor:
@@ -151,7 +151,9 @@ def run_multi_trial_planning(
151
151
  return cast(str, responses[0])
152
152
 
153
153
 
154
- def run_critic(chat: List[Message], media_list: List[str], model: LMM) -> Optional[str]:
154
+ def run_critic(
155
+ chat: List[AgentMessage], media_list: List[Union[str, Path]], model: LMM
156
+ ) -> Optional[str]:
155
157
  planning = get_planning(chat)
156
158
  prompt = CRITIQUE_PLAN.format(
157
159
  planning=planning,
@@ -196,17 +198,19 @@ def response_safeguards(response: str) -> str:
196
198
  def execute_code_action(
197
199
  code: str,
198
200
  code_interpreter: CodeInterpreter,
199
- chat: List[Message],
201
+ chat: List[AgentMessage],
200
202
  model: LMM,
201
203
  verbose: bool = False,
202
204
  ) -> Tuple[Execution, str, str]:
203
205
  if verbose:
204
206
  print_code("Code to Execute:", code)
207
+ start = time.time()
205
208
  execution = code_interpreter.exec_cell(DefaultPlanningImports.prepend_imports(code))
209
+ end = time.time()
206
210
  obs = execution.text(include_results=False).strip()
207
211
  if verbose:
208
212
  _CONSOLE.print(
209
- f"[bold cyan]Code Execution Output:[/bold cyan] [yellow]{escape(obs)}[/yellow]"
213
+ f"[bold cyan]Code Execution Output ({end - start:.2f} sec):[/bold cyan] [yellow]{escape(obs)}[/yellow]"
210
214
  )
211
215
 
212
216
  count = 1
@@ -246,13 +250,13 @@ def find_and_replace_code(response: str, code: str) -> str:
246
250
  def maybe_run_code(
247
251
  code: Optional[str],
248
252
  response: str,
249
- chat: List[Message],
250
- media_list: List[str],
253
+ chat: List[AgentMessage],
254
+ media_list: List[Union[str, Path]],
251
255
  model: LMM,
252
256
  code_interpreter: CodeInterpreter,
253
257
  verbose: bool = False,
254
- ) -> List[Message]:
255
- return_chat: List[Message] = []
258
+ ) -> List[AgentMessage]:
259
+ return_chat: List[AgentMessage] = []
256
260
  if code is not None:
257
261
  code = code_safeguards(code)
258
262
  execution, obs, code = execute_code_action(
@@ -262,30 +266,32 @@ def maybe_run_code(
262
266
  # if we had to debug the code to fix an issue, replace the old code
263
267
  # with the fixed code in the response
264
268
  fixed_response = find_and_replace_code(response, code)
265
- return_chat.append({"role": "assistant", "content": fixed_response})
269
+ return_chat.append(
270
+ AgentMessage(role="planner", content=fixed_response, media=None)
271
+ )
266
272
 
267
273
  media_data = capture_media_from_exec(execution)
268
- int_chat_elt: Message = {"role": "observation", "content": obs}
274
+ int_chat_elt = AgentMessage(role="observation", content=obs, media=None)
269
275
  if media_list:
270
- int_chat_elt["media"] = media_data
276
+ int_chat_elt.media = cast(List[Union[str, Path]], media_data)
271
277
  return_chat.append(int_chat_elt)
272
278
  else:
273
- return_chat.append({"role": "assistant", "content": response})
279
+ return_chat.append(AgentMessage(role="planner", content=response, media=None))
274
280
  return return_chat
275
281
 
276
282
 
277
283
  def create_finalize_plan(
278
- chat: List[Message],
284
+ chat: List[AgentMessage],
279
285
  model: LMM,
280
286
  verbose: bool = False,
281
- ) -> Tuple[List[Message], PlanContext]:
287
+ ) -> Tuple[List[AgentMessage], PlanContext]:
282
288
  prompt = FINALIZE_PLAN.format(
283
289
  planning=get_planning(chat),
284
290
  excluded_tools=str([t.__name__ for t in pt.PLANNER_TOOLS]),
285
291
  )
286
292
  response = model.chat([{"role": "user", "content": prompt}])
287
293
  plan_str = cast(str, response)
288
- return_chat: List[Message] = [{"role": "assistant", "content": plan_str}]
294
+ return_chat = [AgentMessage(role="planner", content=plan_str, media=None)]
289
295
 
290
296
  plan_json = extract_tag(plan_str, "json")
291
297
  plan = (
@@ -305,7 +311,16 @@ def create_finalize_plan(
305
311
  return return_chat, PlanContext(**plan)
306
312
 
307
313
 
308
- class VisionAgentPlannerV2(Agent):
314
+ def get_steps(chat: List[AgentMessage], max_steps: int) -> int:
315
+ for chat_elt in reversed(chat):
316
+ if "<count>" in chat_elt.content:
317
+ return int(extract_tag(chat_elt.content, "count")) # type: ignore
318
+ return max_steps
319
+
320
+
321
+ class VisionAgentPlannerV2(AgentPlanner):
322
+ """VisionAgentPlannerV2 is a class that generates a plan to solve a vision task."""
323
+
309
324
  def __init__(
310
325
  self,
311
326
  planner: Optional[LMM] = None,
@@ -317,6 +332,25 @@ class VisionAgentPlannerV2(Agent):
317
332
  code_sandbox_runtime: Optional[str] = None,
318
333
  update_callback: Callable[[Dict[str, Any]], None] = lambda _: None,
319
334
  ) -> None:
335
+ """Initialize the VisionAgentPlannerV2.
336
+
337
+ Parameters:
338
+ planner (Optional[LMM]): The language model to use for planning. If None, a
339
+ default AnthropicLMM will be used.
340
+ critic (Optional[LMM]): The language model to use for critiquing the plan.
341
+ If None, a default AnthropicLMM will be used.
342
+ max_steps (int): The maximum number of steps to plan.
343
+ use_multi_trial_planning (bool): Whether to use multi-trial planning.
344
+ critique_steps (int): The number of steps between critiques. If critic steps
345
+ is larger than max_steps no critiques will be made.
346
+ verbose (bool): Whether to print out debug information.
347
+ code_sandbox_runtime (Optional[str]): The code sandbox runtime to use, can
348
+ be one of: None, "local" or "e2b". If None, it will read from the
349
+ environment variable CODE_SANDBOX_RUNTIME.
350
+ update_callback (Callable[[Dict[str, Any]], None]): The callback function
351
+ that will send back intermediate conversation messages.
352
+ """
353
+
320
354
  self.planner = (
321
355
  planner
322
356
  if planner is not None
@@ -339,20 +373,42 @@ class VisionAgentPlannerV2(Agent):
339
373
  self,
340
374
  input: Union[str, List[Message]],
341
375
  media: Optional[Union[str, Path]] = None,
342
- ) -> Union[str, List[Message]]:
343
- if isinstance(input, str):
344
- if media is not None:
345
- input = [{"role": "user", "content": input, "media": [media]}]
346
- else:
347
- input = [{"role": "user", "content": input}]
348
- plan = self.generate_plan(input)
349
- return str(plan)
376
+ ) -> str:
377
+ """Generate a plan to solve a vision task.
378
+
379
+ Parameters:
380
+ input (Union[str, List[Message]]): The input to the agent. This can be a
381
+ string or a list of messages in the format of [{"role": "user",
382
+ "content": "describe your task here..."}, ...].
383
+ media (Optional[Union[str, Path]]): The path to the media file to use with
384
+ the input. This can be an image or video file.
385
+
386
+ Returns:
387
+ str: The generated plan as a string.
388
+ """
389
+
390
+ input_msg = convert_message_to_agentmessage(input, media)
391
+ plan = self.generate_plan(input_msg)
392
+ return plan.plan
350
393
 
351
394
  def generate_plan(
352
395
  self,
353
- chat: List[Message],
396
+ chat: List[AgentMessage],
397
+ max_steps: Optional[int] = None,
354
398
  code_interpreter: Optional[CodeInterpreter] = None,
355
399
  ) -> PlanContext:
400
+ """Generate a plan to solve a vision task.
401
+
402
+ Parameters:
403
+ chat (List[AgentMessage]): The conversation messages to generate a plan for.
404
+ max_steps (Optional[int]): The maximum number of steps to plan.
405
+ code_interpreter (Optional[CodeInterpreter]): The code interpreter to use.
406
+
407
+ Returns:
408
+ PlanContext: The generated plan including the instructions and code snippets
409
+ needed to solve the task.
410
+ """
411
+
356
412
  if not chat:
357
413
  raise ValueError("Chat cannot be empty")
358
414
 
@@ -360,13 +416,16 @@ class VisionAgentPlannerV2(Agent):
360
416
  code_interpreter = code_interpreter or CodeInterpreterFactory.new_instance(
361
417
  self.code_sandbox_runtime
362
418
  )
419
+ max_steps = max_steps or self.max_steps
363
420
 
364
421
  with code_interpreter:
365
422
  critque_steps = 1
366
- step = self.max_steps
367
423
  finished = False
368
424
  int_chat, _, media_list = add_media_to_chat(chat, code_interpreter)
369
- int_chat[-1]["content"] += f"\n<count>{step}</count>\n" # type: ignore
425
+
426
+ step = get_steps(int_chat, max_steps)
427
+ if "<count>" not in int_chat[-1].content and step == max_steps:
428
+ int_chat[-1].content += f"\n<count>{step}</count>\n"
370
429
  while step > 0 and not finished:
371
430
  if self.use_multi_trial_planning:
372
431
  response = run_multi_trial_planning(
@@ -402,29 +461,29 @@ class VisionAgentPlannerV2(Agent):
402
461
 
403
462
  if critque_steps % self.critique_steps == 0:
404
463
  critique = run_critic(int_chat, media_list, self.critic)
405
- if critique is not None and int_chat[-1]["role"] == "observation":
464
+ if critique is not None and int_chat[-1].role == "observation":
406
465
  _CONSOLE.print(
407
466
  f"[bold cyan]Critique:[/bold cyan] [red]{critique}[/red]"
408
467
  )
409
468
  critique_str = f"\n[critique]\n{critique}\n[end of critique]"
410
- updated_chat[-1]["content"] += critique_str # type: ignore
469
+ updated_chat[-1].content += critique_str
411
470
  # if plan was critiqued, ensure we don't finish so we can
412
471
  # respond to the critique
413
472
  finished = False
414
473
 
415
474
  critque_steps += 1
416
475
  step -= 1
417
- updated_chat[-1]["content"] += f"\n<count>{step}</count>\n" # type: ignore
476
+ updated_chat[-1].content += f"\n<count>{step}</count>\n"
418
477
  int_chat.extend(updated_chat)
419
478
  for chat_elt in updated_chat:
420
- self.update_callback(chat_elt)
479
+ self.update_callback(chat_elt.model_dump())
421
480
 
422
481
  updated_chat, plan_context = create_finalize_plan(
423
482
  int_chat, self.planner, self.verbose
424
483
  )
425
484
  int_chat.extend(updated_chat)
426
485
  for chat_elt in updated_chat:
427
- self.update_callback(chat_elt)
486
+ self.update_callback(chat_elt.model_dump())
428
487
 
429
488
  return plan_context
430
489
 
@@ -55,10 +55,10 @@ generate_vision_code(artifacts, 'dog_detector.py', 'Can you write code to detect
55
55
 
56
56
  OBSERVATION:
57
57
  [Artifact dog_detector.py (5 lines total)]
58
- 0|from vision_agent.tools import load_image, owl_v2
58
+ 0|from vision_agent.tools import load_image, owl_v2_image
59
59
  1|def detect_dogs(image_path: str):
60
60
  2| image = load_image(image_path)
61
- 3| dogs = owl_v2("dog", image)
61
+ 3| dogs = owl_v2_image("dog", image)
62
62
  4| return dogs
63
63
  [End of artifact]
64
64
 
@@ -96,10 +96,10 @@ edit_vision_code(artifacts, 'dog_detector.py', ['Can you write code to detect do
96
96
 
97
97
  OBSERVATION:
98
98
  [Artifact dog_detector.py (5 lines total)]
99
- 0|from vision_agent.tools import load_image, owl_v2
99
+ 0|from vision_agent.tools import load_image, owl_v2_image
100
100
  1|def detect_dogs(image_path: str):
101
101
  2| image = load_image(image_path)
102
- 3| dogs = owl_v2("dog", image, threshold=0.24)
102
+ 3| dogs = owl_v2_image("dog", image, threshold=0.24)
103
103
  4| return dogs
104
104
  [End of artifact]
105
105