vision-agent 0.2.202__tar.gz → 0.2.203__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. {vision_agent-0.2.202 → vision_agent-0.2.203}/PKG-INFO +1 -1
  2. {vision_agent-0.2.202 → vision_agent-0.2.203}/pyproject.toml +1 -1
  3. vision_agent-0.2.203/vision_agent/agent/README.md +89 -0
  4. {vision_agent-0.2.202 → vision_agent-0.2.203}/vision_agent/agent/agent.py +8 -3
  5. {vision_agent-0.2.202 → vision_agent-0.2.203}/vision_agent/agent/types.py +26 -1
  6. {vision_agent-0.2.202 → vision_agent-0.2.203}/vision_agent/agent/vision_agent_coder_v2.py +41 -4
  7. {vision_agent-0.2.202 → vision_agent-0.2.203}/vision_agent/agent/vision_agent_planner_v2.py +99 -20
  8. {vision_agent-0.2.202 → vision_agent-0.2.203}/vision_agent/agent/vision_agent_v2.py +72 -21
  9. {vision_agent-0.2.202 → vision_agent-0.2.203}/vision_agent/tools/planner_tools.py +166 -57
  10. {vision_agent-0.2.202 → vision_agent-0.2.203}/vision_agent/tools/tool_utils.py +6 -3
  11. {vision_agent-0.2.202 → vision_agent-0.2.203}/vision_agent/utils/execute.py +10 -3
  12. {vision_agent-0.2.202 → vision_agent-0.2.203}/LICENSE +0 -0
  13. {vision_agent-0.2.202 → vision_agent-0.2.203}/README.md +0 -0
  14. {vision_agent-0.2.202 → vision_agent-0.2.203}/vision_agent/.sim_tools/df.csv +0 -0
  15. {vision_agent-0.2.202 → vision_agent-0.2.203}/vision_agent/.sim_tools/embs.npy +0 -0
  16. {vision_agent-0.2.202 → vision_agent-0.2.203}/vision_agent/__init__.py +0 -0
  17. {vision_agent-0.2.202 → vision_agent-0.2.203}/vision_agent/agent/__init__.py +0 -0
  18. {vision_agent-0.2.202 → vision_agent-0.2.203}/vision_agent/agent/agent_utils.py +0 -0
  19. {vision_agent-0.2.202 → vision_agent-0.2.203}/vision_agent/agent/vision_agent.py +0 -0
  20. {vision_agent-0.2.202 → vision_agent-0.2.203}/vision_agent/agent/vision_agent_coder.py +0 -0
  21. {vision_agent-0.2.202 → vision_agent-0.2.203}/vision_agent/agent/vision_agent_coder_prompts.py +0 -0
  22. {vision_agent-0.2.202 → vision_agent-0.2.203}/vision_agent/agent/vision_agent_coder_prompts_v2.py +0 -0
  23. {vision_agent-0.2.202 → vision_agent-0.2.203}/vision_agent/agent/vision_agent_planner.py +0 -0
  24. {vision_agent-0.2.202 → vision_agent-0.2.203}/vision_agent/agent/vision_agent_planner_prompts.py +0 -0
  25. {vision_agent-0.2.202 → vision_agent-0.2.203}/vision_agent/agent/vision_agent_planner_prompts_v2.py +0 -0
  26. {vision_agent-0.2.202 → vision_agent-0.2.203}/vision_agent/agent/vision_agent_prompts.py +0 -0
  27. {vision_agent-0.2.202 → vision_agent-0.2.203}/vision_agent/agent/vision_agent_prompts_v2.py +0 -0
  28. {vision_agent-0.2.202 → vision_agent-0.2.203}/vision_agent/clients/__init__.py +0 -0
  29. {vision_agent-0.2.202 → vision_agent-0.2.203}/vision_agent/clients/http.py +0 -0
  30. {vision_agent-0.2.202 → vision_agent-0.2.203}/vision_agent/clients/landing_public_api.py +0 -0
  31. {vision_agent-0.2.202 → vision_agent-0.2.203}/vision_agent/fonts/__init__.py +0 -0
  32. {vision_agent-0.2.202 → vision_agent-0.2.203}/vision_agent/fonts/default_font_ch_en.ttf +0 -0
  33. {vision_agent-0.2.202 → vision_agent-0.2.203}/vision_agent/lmm/__init__.py +0 -0
  34. {vision_agent-0.2.202 → vision_agent-0.2.203}/vision_agent/lmm/lmm.py +0 -0
  35. {vision_agent-0.2.202 → vision_agent-0.2.203}/vision_agent/lmm/types.py +0 -0
  36. {vision_agent-0.2.202 → vision_agent-0.2.203}/vision_agent/tools/__init__.py +0 -0
  37. {vision_agent-0.2.202 → vision_agent-0.2.203}/vision_agent/tools/meta_tools.py +0 -0
  38. {vision_agent-0.2.202 → vision_agent-0.2.203}/vision_agent/tools/prompts.py +0 -0
  39. {vision_agent-0.2.202 → vision_agent-0.2.203}/vision_agent/tools/tools.py +0 -0
  40. {vision_agent-0.2.202 → vision_agent-0.2.203}/vision_agent/tools/tools_types.py +0 -0
  41. {vision_agent-0.2.202 → vision_agent-0.2.203}/vision_agent/utils/__init__.py +0 -0
  42. {vision_agent-0.2.202 → vision_agent-0.2.203}/vision_agent/utils/exceptions.py +0 -0
  43. {vision_agent-0.2.202 → vision_agent-0.2.203}/vision_agent/utils/image_utils.py +0 -0
  44. {vision_agent-0.2.202 → vision_agent-0.2.203}/vision_agent/utils/sim.py +0 -0
  45. {vision_agent-0.2.202 → vision_agent-0.2.203}/vision_agent/utils/type_defs.py +0 -0
  46. {vision_agent-0.2.202 → vision_agent-0.2.203}/vision_agent/utils/video.py +0 -0
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: vision-agent
-Version: 0.2.202
+Version: 0.2.203
 Summary: Toolset for Vision Agent
 Author: Landing AI
 Author-email: dev@landing.ai
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
 
 [tool.poetry]
 name = "vision-agent"
-version = "0.2.202"
+version = "0.2.203"
 description = "Toolset for Vision Agent"
 authors = ["Landing AI <dev@landing.ai>"]
 readme = "README.md"
@@ -0,0 +1,89 @@
+## V2 Agents
+
+This document gives an overview of the V2 agents, how they communicate, and how human-in-the-loop works.
+
+- `vision_agent_v2` - This is the conversation agent. It can take exactly one action and return one response. The actions are fixed JSON actions (meaning it does not execute code but instead returns JSON, and we execute the code on its behalf). This is so we can control the action arguments as well as pass around the notebook code interpreter.
+- `vision_agent_planner_v2` - This agent is responsible for planning. It can run for N turns; each turn it can take a code action (it can execute its own Python code and has access to `planner_tools`) to test a new potential step in the plan.
+- `vision_agent_coder_v2` - This agent is responsible for the final code. It can call the planner on its own, or it can take in the final `PlanContext` returned by `vision_agent_planner_v2` and use that to write and test the final code (see the usage sketch below).
+
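For example, a minimal usage sketch of calling the coder agent directly (the API key setup and `people.png` are illustrative assumptions, not part of the package):

```python
# Minimal sketch: calling VisionAgentCoderV2 directly returns the final code
# as a string (or the interaction prompt when human-in-the-loop pauses it).
from vision_agent.agent.vision_agent_coder_v2 import VisionAgentCoderV2

coder = VisionAgentCoderV2(verbose=True)
code = coder("Count the number of people in this image", media="people.png")
print(code)
```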
+### Communication
+The agents communicate through `AgentMessage` objects; the planner and coder return `PlanContext` and `CodeContext` objects respectively.
+```
+ _______________
+|VisionAgentV2|
+---------------
+      |                   ____________________
+      --(AgentMessage)-->|VisionAgentCoderV2|
+                         --------------------
+                               |                    ______________________
+                               --(AgentMessage)-->|VisionAgentPlannerV2|
+                                                  ----------------------
+ ____________________                     |
+|VisionAgentCoderV2|<----(PlanContext)-----
+--------------------
+ _______________        |
+|VisionAgentV2|<---(CodeContext)---
+---------------
+```
+
+#### AgentMessage and Contexts
+`AgentMessage` is a basic chat message with extended roles. The roles include the typical `user` and `assistant`, but also `conversation`, `planner` and `coder`, which come from `VisionAgentV2`, `VisionAgentPlannerV2` and `VisionAgentCoderV2` respectively; all three are types of `assistant` messages. `observation` messages contain the responses from Python code the planner executes internally.
+
+The `VisionAgentPlannerV2` returns a `PlanContext`, which contains a finalized version of the plan, including instructions and code snippets used during planning. `VisionAgentCoderV2` will then take in that `PlanContext` and return a `CodeContext`, which contains the final code and any additional information.
+
+
+#### Callbacks
+If you want to receive intermediate messages, you can use the `update_callback` argument in all of the `V2` constructors. This will asynchronously send `AgentMessage` objects to the callback function you provide; a minimal callback sketch is shown below. You can see an example of how to run this in `examples/chat/app.py`.
+
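A minimal sketch of such a callback. This assumes `VisionAgentV2` is exported from `vision_agent.agent` and that the callback receives each message's `model_dump()` dict (with `role`, `content` and `media` keys), which is what the V2 agents pass to it internally:

```python
from typing import Any, Dict

from vision_agent.agent import VisionAgentV2


def update_callback(message: Dict[str, Any]) -> None:
    # Collect or forward intermediate messages; here we just print a summary.
    print(f"[{message['role']}] {message['content'][:80]}")


agent = VisionAgentV2(verbose=True, update_callback=update_callback)
```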
+### Human-in-the-loop
+Human-in-the-loop is a feature that allows the user to interact with the agents at certain points in the conversation. This is handled by using the `interaction` and `interaction_response` roles in the `AgentMessage`. You can enable this feature by passing `hil=True` to `VisionAgentV2`. Currently you can only use human-in-the-loop if you are also using the `update_callback` to collect the messages and pass them back to `VisionAgentV2`.
+
+When the planner agent wants to interact with a human, it will return an `InteractionContext`, which will propagate back up to `VisionAgentV2` and then to the user. This exits the planner so it can ask for human input. If you collect the messages from `update_callback`, you will see that the last `AgentMessage` has a role of `interaction` and its content includes a JSON string surrounded by `<interaction>` tags:
+
+```
+AgentMessage(
+    role="interaction",
+    content="<interaction>{\"prompt\": \"Should I use owl_v2_image or countgd_counting?\"}</interaction>",
+    media=None,
+)
+```
+
+The user can then add an additional `AgentMessage` with the role `interaction_response` and the response they want to give:
+
+```
+AgentMessage(
+    role="interaction_response",
+    content="{\"function_name\": \"owl_v2_image\"}",
+    media=None,
+)
+```
+
+You can see an example of how this works in `examples/chat/chat-app/src/components/ChatSection.tsx` under the `handleSubmit` function; a sketch of the same round trip in Python follows.
+
+
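A sketch of that round trip in Python. The tool choice and image name are illustrative, and the pattern assumes you collect the callback messages yourself, as described above:

```python
from vision_agent.agent import VisionAgentV2
from vision_agent.agent.types import AgentMessage
from vision_agent.utils.execute import CodeInterpreterFactory

collected = []  # update_callback fills this with each message's model_dump() dict
agent = VisionAgentV2(hil=True, update_callback=lambda m: collected.append(m))
code_interpreter = CodeInterpreterFactory.new_instance(non_exiting=True)

chat = [AgentMessage(role="user", content="Count the cars in cars.png", media=None)]
agent.chat(chat, code_interpreter=code_interpreter)

if collected and collected[-1]["role"] == "interaction":
    # Show collected[-1]["content"] to the user, gather their tool choice, then
    # resume the chat with an interaction_response appended at the end.
    chat = [AgentMessage(**m) for m in collected]
    chat.append(
        AgentMessage(
            role="interaction_response",
            content='{"function_name": "owl_v2_image"}',
            media=None,
        )
    )
    agent.chat(chat, code_interpreter=code_interpreter)
```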
+#### Human-in-the-loop Caveats
+One issue with this approach is that the planner runs code in a notebook and relies on all of its previous executions. This means you cannot close or restart the internal notebook the agents use (usually named `code_interpreter` and created by `CodeInterpreterFactory.new_instance`). This is a problem because the notebook would otherwise close once you exit the `VisionAgentV2` call.
+
+To fix this, you must construct a notebook outside of the chat and ensure it has `non_exiting=True`:
+
+```python
+agent = VisionAgentV2(
+    verbose=True,
+    update_callback=update_callback,
+    hil=True,
+)
+
+code_interpreter = CodeInterpreterFactory.new_instance(non_exiting=True)
+agent.chat(
+    [
+        AgentMessage(
+            role="user",
+            content="Hello",
+            media=None
+        )
+    ],
+    code_interpreter=code_interpreter
+)
+```
+
+An example of this can be seen in `examples/chat/app.py`. Here the `code_interpreter` is constructed outside of the chat and passed in, so the notebook does not close when the chat ends or when it returns to ask for human feedback.
@@ -2,7 +2,12 @@ from abc import ABC, abstractmethod
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Union
 
-from vision_agent.agent.types import AgentMessage, CodeContext, PlanContext
+from vision_agent.agent.types import (
+    AgentMessage,
+    CodeContext,
+    InteractionContext,
+    PlanContext,
+)
 from vision_agent.lmm.types import Message
 from vision_agent.utils.execute import CodeInterpreter
 
@@ -31,7 +36,7 @@ class AgentCoder(Agent):
         chat: List[AgentMessage],
         max_steps: Optional[int] = None,
         code_interpreter: Optional[CodeInterpreter] = None,
-    ) -> CodeContext:
+    ) -> Union[CodeContext, InteractionContext]:
         pass
 
     @abstractmethod
@@ -51,5 +56,5 @@ class AgentPlanner(Agent):
         chat: List[AgentMessage],
         max_steps: Optional[int] = None,
         code_interpreter: Optional[CodeInterpreter] = None,
-    ) -> PlanContext:
+    ) -> Union[PlanContext, InteractionContext]:
         pass
@@ -17,12 +17,12 @@ class AgentMessage(BaseModel):
     interaction: An interaction between the user and the assistant. For example if the
         assistant wants to ask the user for help on a task, it could send an
         interaction message.
+    interaction_response: The user's response to an interaction message.
     conversation: Messages coming from the conversation agent, this is a type of
         assistant messages.
     planner: Messages coming from the planner agent, this is a type of assistant
         messages.
     coder: Messages coming from the coder agent, this is a type of assistant messages.
-
     """
 
     role: Union[
@@ -30,6 +30,7 @@ class AgentMessage(BaseModel):
         Literal["assistant"],  # planner, coder and conversation are of type assistant
         Literal["observation"],
         Literal["interaction"],
+        Literal["interaction_response"],
         Literal["conversation"],
         Literal["planner"],
         Literal["coder"],
@@ -39,13 +40,37 @@ class AgentMessage(BaseModel):
 
 
 class PlanContext(BaseModel):
+    """PlanContext is a data model that represents the context of a plan.
+
+    plan: A description of the overall plan.
+    instructions: A list of step-by-step instructions.
+    code: Code snippets that were used during planning.
+    """
+
     plan: str
     instructions: List[str]
     code: str
 
 
 class CodeContext(BaseModel):
+    """CodeContext is a data model that represents final code and test cases.
+
+    code: The final code that was written.
+    test: The test cases that were written.
+    success: A boolean value indicating whether the code passed the test cases.
+    test_result: The result of running the test cases.
+    """
+
     code: str
     test: str
     success: bool
     test_result: Execution
+
+
+class InteractionContext(BaseModel):
+    """InteractionContext is a data model that represents the context of an interaction.
+
+    chat: A list of messages exchanged between the user and the assistant.
+    """
+
+    chat: List[AgentMessage]
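The new context models are plain pydantic `BaseModel`s; a short illustrative sketch of constructing them (all values are hypothetical):

```python
from vision_agent.agent.types import AgentMessage, InteractionContext, PlanContext

plan = PlanContext(
    plan="Detect people with owl_v2_image, then count the detections",
    instructions=["Load the image", "Run owl_v2_image", "Count the boxes"],
    code="dets = owl_v2_image('person', image)",
)

# When the planner pauses for human input, it returns the chat so far instead:
interaction = InteractionContext(
    chat=[
        AgentMessage(
            role="interaction",
            content='<interaction>{"prompt": "Which tool should I use?"}</interaction>',
            media=None,
        )
    ]
)
```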
@@ -18,7 +18,12 @@ from vision_agent.agent.agent_utils import (
     print_code,
     strip_function_calls,
 )
-from vision_agent.agent.types import AgentMessage, CodeContext, PlanContext
+from vision_agent.agent.types import (
+    AgentMessage,
+    CodeContext,
+    InteractionContext,
+    PlanContext,
+)
 from vision_agent.agent.vision_agent_coder_prompts_v2 import CODE, FIX_BUG, TEST
 from vision_agent.agent.vision_agent_planner_v2 import VisionAgentPlannerV2
 from vision_agent.lmm import LMM, AnthropicLMM
@@ -257,6 +262,7 @@ class VisionAgentCoderV2(AgentCoder):
         tester: Optional[LMM] = None,
         debugger: Optional[LMM] = None,
         tool_recommender: Optional[Union[str, Sim]] = None,
+        hil: bool = False,
         verbose: bool = False,
         code_sandbox_runtime: Optional[str] = None,
         update_callback: Callable[[Dict[str, Any]], None] = lambda _: None,
@@ -272,6 +278,7 @@ class VisionAgentCoderV2(AgentCoder):
                 None, a default AnthropicLMM will be used.
             debugger (Optional[LMM]): The language model to use for the debugger agent.
             tool_recommender (Optional[Union[str, Sim]]): The tool recommender to use.
+            hil (bool): Whether to use human-in-the-loop mode.
             verbose (bool): Whether to print out debug information.
             code_sandbox_runtime (Optional[str]): The code sandbox runtime to use, can
                 be one of: None, "local" or "e2b". If None, it will read from the
@@ -283,8 +290,11 @@ class VisionAgentCoderV2(AgentCoder):
         self.planner = (
             planner
             if planner is not None
-            else VisionAgentPlannerV2(verbose=verbose, update_callback=update_callback)
+            else VisionAgentPlannerV2(
+                verbose=verbose, update_callback=update_callback, hil=hil
+            )
         )
+
         self.coder = (
             coder
             if coder is not None
@@ -311,6 +321,8 @@ class VisionAgentCoderV2(AgentCoder):
         self.verbose = verbose
         self.code_sandbox_runtime = code_sandbox_runtime
         self.update_callback = update_callback
+        if hasattr(self.planner, "update_callback"):
+            self.planner.update_callback = update_callback
 
     def __call__(
         self,
@@ -331,14 +343,17 @@ class VisionAgentCoderV2(AgentCoder):
         """
 
         input_msg = convert_message_to_agentmessage(input, media)
-        return self.generate_code(input_msg).code
+        code_or_interaction = self.generate_code(input_msg)
+        if isinstance(code_or_interaction, InteractionContext):
+            return code_or_interaction.chat[-1].content
+        return code_or_interaction.code
 
     def generate_code(
         self,
         chat: List[AgentMessage],
         max_steps: Optional[int] = None,
         code_interpreter: Optional[CodeInterpreter] = None,
-    ) -> CodeContext:
+    ) -> Union[CodeContext, InteractionContext]:
         """Generate vision code from a conversation.
 
         Parameters:
@@ -353,6 +368,11 @@ class VisionAgentCoderV2(AgentCoder):
         """
 
         chat = copy.deepcopy(chat)
+        if not chat or chat[-1].role not in {"user", "interaction_response"}:
+            raise ValueError(
+                f"Last chat message must be from the user or interaction_response, got {chat[-1].role}."
+            )
+
         with (
             CodeInterpreterFactory.new_instance(self.code_sandbox_runtime)
             if code_interpreter is None
@@ -362,6 +382,10 @@ class VisionAgentCoderV2(AgentCoder):
             plan_context = self.planner.generate_plan(
                 int_chat, max_steps=max_steps, code_interpreter=code_interpreter
             )
+            # the planner needs an interaction, so return before generating code
+            if isinstance(plan_context, InteractionContext):
+                return plan_context
+
             code_context = self.generate_code_from_plan(
                 orig_chat,
                 plan_context,
@@ -391,6 +415,19 @@ class VisionAgentCoderV2(AgentCoder):
         """
 
         chat = copy.deepcopy(chat)
+        if not chat or chat[-1].role not in {"user", "interaction_response"}:
+            raise ValueError(
+                f"Last chat message must be from the user or interaction_response, got {chat[-1].role}."
+            )
+
+        # we don't need the interaction response for generating code since it's
+        # already in the plan context
+        while chat and chat[-1].role != "user":
+            chat.pop()
+
+        if not chat:
+            raise ValueError("Chat must have at least one user message.")
+
         with (
             CodeInterpreterFactory.new_instance(self.code_sandbox_runtime)
             if code_interpreter is None
@@ -1,4 +1,5 @@
 import copy
+import json
 import logging
 import time
 from concurrent.futures import ThreadPoolExecutor, as_completed
@@ -21,7 +22,7 @@ from vision_agent.agent.agent_utils import (
     print_code,
     print_table,
 )
-from vision_agent.agent.types import AgentMessage, PlanContext
+from vision_agent.agent.types import AgentMessage, InteractionContext, PlanContext
 from vision_agent.agent.vision_agent_planner_prompts_v2 import (
     CRITIQUE_PLAN,
     EXAMPLE_PLAN1,
@@ -32,6 +33,7 @@ from vision_agent.agent.vision_agent_planner_prompts_v2 import (
     PLAN,
 )
 from vision_agent.lmm import LMM, AnthropicLMM, Message
+from vision_agent.tools.planner_tools import check_function_call, get_tool_documentation
 from vision_agent.utils.execute import (
     CodeInterpreter,
     CodeInterpreterFactory,
@@ -210,7 +212,7 @@ def execute_code_action(
     obs = execution.text(include_results=False).strip()
     if verbose:
         _CONSOLE.print(
-            f"[bold cyan]Code Execution Output ({end - start:.2f} sec):[/bold cyan] [yellow]{escape(obs)}[/yellow]"
+            f"[bold cyan]Code Execution Output ({end - start:.2f}s):[/bold cyan] [yellow]{escape(obs)}[/yellow]"
         )
 
     count = 1
@@ -247,6 +249,31 @@ def find_and_replace_code(response: str, code: str) -> str:
     return response[:code_start] + code + response[code_end:]
 
 
+def create_hil_response(
+    execution: Execution,
+) -> AgentMessage:
+    content = []
+    for result in execution.results:
+        for format in result.formats():
+            if format == "json":
+                data = result.json
+                try:
+                    data_dict: Dict[str, Any] = dict(data)  # type: ignore
+                    if (
+                        "request" in data_dict
+                        and "function_name" in data_dict["request"]
+                    ):
+                        content.append(data_dict)
+                except Exception:
+                    continue
+
+    return AgentMessage(
+        role="interaction",
+        content="<interaction>" + json.dumps(content) + "</interaction>",
+        media=None,
+    )
+
+
 def maybe_run_code(
     code: Optional[str],
     response: str,
@@ -254,6 +281,7 @@ def maybe_run_code(
     media_list: List[Union[str, Path]],
     model: LMM,
     code_interpreter: CodeInterpreter,
+    hil: bool = False,
     verbose: bool = False,
 ) -> List[AgentMessage]:
     return_chat: List[AgentMessage] = []
@@ -270,6 +298,11 @@ def maybe_run_code(
             AgentMessage(role="planner", content=fixed_response, media=None)
         )
 
+        # if we are running human-in-the-loop mode, send back an interaction message
+        # make sure we return code from planner and the hil response
+        if check_function_call(code, "get_tool_for_task") and hil:
+            return return_chat + [create_hil_response(execution)]
+
         media_data = capture_media_from_exec(execution)
         int_chat_elt = AgentMessage(role="observation", content=obs, media=None)
         if media_list:
@@ -285,6 +318,10 @@ def create_finalize_plan(
     model: LMM,
     verbose: bool = False,
 ) -> Tuple[List[AgentMessage], PlanContext]:
+    # if we're in the middle of an interaction, don't finalize the plan
+    if chat[-1].role == "interaction":
+        return [], PlanContext(plan="", instructions=[], code="")
+
     prompt = FINALIZE_PLAN.format(
         planning=get_planning(chat),
         excluded_tools=str([t.__name__ for t in pt.PLANNER_TOOLS]),
@@ -318,6 +355,27 @@ def get_steps(chat: List[AgentMessage], max_steps: int) -> int:
     return max_steps
 
 
+def replace_interaction_with_obs(chat: List[AgentMessage]) -> List[AgentMessage]:
+    chat = copy.deepcopy(chat)
+    new_chat = []
+
+    for i, chat_i in enumerate(chat):
+        if chat_i.role == "interaction" and (
+            i < len(chat) - 1 and chat[i + 1].role == "interaction_response"
+        ):
+            try:
+                response = json.loads(chat[i + 1].content)
+                function_name = response["function_name"]
+                tool_doc = get_tool_documentation(function_name)
+                new_chat.append(AgentMessage(role="observation", content=tool_doc))
+            except json.JSONDecodeError:
+                raise ValueError(f"Invalid JSON in interaction response: {chat_i}")
+        else:
+            new_chat.append(chat_i)
+
+    return new_chat
+
+
 class VisionAgentPlannerV2(AgentPlanner):
     """VisionAgentPlannerV2 is a class that generates a plan to solve a vision task."""
 
@@ -328,6 +386,7 @@ class VisionAgentPlannerV2(AgentPlanner):
         max_steps: int = 10,
         use_multi_trial_planning: bool = False,
         critique_steps: int = 11,
+        hil: bool = False,
         verbose: bool = False,
         code_sandbox_runtime: Optional[str] = None,
         update_callback: Callable[[Dict[str, Any]], None] = lambda _: None,
@@ -343,6 +402,7 @@ class VisionAgentPlannerV2(AgentPlanner):
             use_multi_trial_planning (bool): Whether to use multi-trial planning.
             critique_steps (int): The number of steps between critiques. If critic steps
                 is larger than max_steps no critiques will be made.
+            hil (bool): Whether to use human-in-the-loop mode.
             verbose (bool): Whether to print out debug information.
             code_sandbox_runtime (Optional[str]): The code sandbox runtime to use, can
                 be one of: None, "local" or "e2b". If None, it will read from the
@@ -365,6 +425,11 @@ class VisionAgentPlannerV2(AgentPlanner):
         self.use_multi_trial_planning = use_multi_trial_planning
         self.critique_steps = critique_steps
 
+        self.hil = hil
+        if self.hil:
+            DefaultPlanningImports.imports.append(
+                "from vision_agent.tools.planner_tools import get_tool_for_task_human_reviewer as get_tool_for_task"
+            )
         self.verbose = verbose
         self.code_sandbox_runtime = code_sandbox_runtime
         self.update_callback = update_callback
@@ -388,15 +453,17 @@ class VisionAgentPlannerV2(AgentPlanner):
         """
 
         input_msg = convert_message_to_agentmessage(input, media)
-        plan = self.generate_plan(input_msg)
-        return plan.plan
+        plan_or_interaction = self.generate_plan(input_msg)
+        if isinstance(plan_or_interaction, InteractionContext):
+            return plan_or_interaction.chat[-1].content
+        return plan_or_interaction.plan
 
     def generate_plan(
         self,
         chat: List[AgentMessage],
         max_steps: Optional[int] = None,
         code_interpreter: Optional[CodeInterpreter] = None,
-    ) -> PlanContext:
+    ) -> Union[PlanContext, InteractionContext]:
         """Generate a plan to solve a vision task.
 
         Parameters:
@@ -409,24 +476,30 @@ class VisionAgentPlannerV2(AgentPlanner):
             needed to solve the task.
         """
 
-        if not chat:
-            raise ValueError("Chat cannot be empty")
+        if not chat or chat[-1].role not in {"user", "interaction_response"}:
+            raise ValueError(
+                f"Last chat message must be from the user or interaction_response, got {chat[-1].role}."
+            )
 
         chat = copy.deepcopy(chat)
-        code_interpreter = code_interpreter or CodeInterpreterFactory.new_instance(
-            self.code_sandbox_runtime
-        )
         max_steps = max_steps or self.max_steps
 
-        with code_interpreter:
+        with (
+            CodeInterpreterFactory.new_instance(self.code_sandbox_runtime)
+            if code_interpreter is None
+            else code_interpreter
+        ) as code_interpreter:
             critque_steps = 1
             finished = False
+            interaction = False
             int_chat, _, media_list = add_media_to_chat(chat, code_interpreter)
+            int_chat = replace_interaction_with_obs(int_chat)
 
             step = get_steps(int_chat, max_steps)
             if "<count>" not in int_chat[-1].content and step == max_steps:
                 int_chat[-1].content += f"\n<count>{step}</count>\n"
-            while step > 0 and not finished:
+
+            while step > 0 and not finished and not interaction:
                 if self.use_multi_trial_planning:
                     response = run_multi_trial_planning(
                         int_chat, media_list, self.planner
@@ -456,8 +529,10 @@ class VisionAgentPlannerV2(AgentPlanner):
                     media_list,
                     self.planner,
                     code_interpreter,
-                    self.verbose,
+                    hil=self.hil,
+                    verbose=self.verbose,
                 )
+                interaction = updated_chat[-1].role == "interaction"
 
                 if critque_steps % self.critique_steps == 0:
                     critique = run_critic(int_chat, media_list, self.critic)
@@ -478,14 +553,18 @@ class VisionAgentPlannerV2(AgentPlanner):
                 for chat_elt in updated_chat:
                     self.update_callback(chat_elt.model_dump())
 
-            updated_chat, plan_context = create_finalize_plan(
-                int_chat, self.planner, self.verbose
-            )
-            int_chat.extend(updated_chat)
-            for chat_elt in updated_chat:
-                self.update_callback(chat_elt.model_dump())
+            context: Union[PlanContext, InteractionContext]
+            if interaction:
+                context = InteractionContext(chat=int_chat)
+            else:
+                updated_chat, context = create_finalize_plan(
+                    int_chat, self.planner, self.verbose
+                )
+                int_chat.extend(updated_chat)
+                for chat_elt in updated_chat:
+                    self.update_callback(chat_elt.model_dump())
 
-        return plan_context
+        return context
 
     def log_progress(self, data: Dict[str, Any]) -> None:
         pass
@@ -1,4 +1,5 @@
 import copy
+import json
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Union, cast
 
@@ -8,7 +9,12 @@ from vision_agent.agent.agent_utils import (
     convert_message_to_agentmessage,
     extract_tag,
 )
-from vision_agent.agent.types import AgentMessage, PlanContext
+from vision_agent.agent.types import (
+    AgentMessage,
+    CodeContext,
+    InteractionContext,
+    PlanContext,
+)
 from vision_agent.agent.vision_agent_coder_v2 import format_code_context
 from vision_agent.agent.vision_agent_prompts_v2 import CONVERSATION
 from vision_agent.lmm import LMM, AnthropicLMM
@@ -39,10 +45,24 @@ def run_conversation(agent: LMM, chat: List[AgentMessage]) -> str:
     return cast(str, response)
 
 
+def check_for_interaction(chat: List[AgentMessage]) -> bool:
+    return (
+        len(chat) > 2
+        and chat[-2].role == "interaction"
+        and chat[-1].role == "interaction_response"
+    )
+
+
 def extract_conversation_for_generate_code(
     chat: List[AgentMessage],
 ) -> List[AgentMessage]:
     chat = copy.deepcopy(chat)
+
+    # if we are in the middle of an interaction, return all the intermediate planning
+    # steps
+    if check_for_interaction(chat):
+        return chat
+
     extracted_chat = []
     for chat_i in chat:
         if chat_i.role == "user":
@@ -66,13 +86,20 @@ def maybe_run_action(
         # to the outside user via its update_callback, but we don't necessarily have
         # access to that update_callback here, so we re-create the message using
         # format_code_context.
-        code_context = coder.generate_code(
-            extracted_chat, code_interpreter=code_interpreter
-        )
-        return [
-            AgentMessage(role="coder", content=format_code_context(code_context)),
-            AgentMessage(role="observation", content=code_context.test_result.text()),
-        ]
+        context = coder.generate_code(extracted_chat, code_interpreter=code_interpreter)
+
+        if isinstance(context, CodeContext):
+            return [
+                AgentMessage(role="coder", content=format_code_context(context)),
+                AgentMessage(role="observation", content=context.test_result.text()),
+            ]
+        elif isinstance(context, InteractionContext):
+            return [
+                AgentMessage(
+                    role="interaction",
+                    content=json.dumps([elt.model_dump() for elt in context.chat]),
+                )
+            ]
     elif action == "edit_code":
         extracted_chat = extract_conversation_for_generate_code(chat)
         plan_context = PlanContext(
@@ -80,12 +107,12 @@ def maybe_run_action(
             instructions=[],
             code="",
         )
-        code_context = coder.generate_code_from_plan(
+        context = coder.generate_code_from_plan(
             extracted_chat, plan_context, code_interpreter=code_interpreter
         )
         return [
-            AgentMessage(role="coder", content=format_code_context(code_context)),
-            AgentMessage(role="observation", content=code_context.test_result.text()),
+            AgentMessage(role="coder", content=format_code_context(context)),
+            AgentMessage(role="observation", content=context.test_result.text()),
         ]
     elif action == "view_image":
         pass
@@ -102,6 +129,7 @@ class VisionAgentV2(Agent):
         self,
         agent: Optional[LMM] = None,
         coder: Optional[AgentCoder] = None,
+        hil: bool = False,
         verbose: bool = False,
         code_sandbox_runtime: Optional[str] = None,
         update_callback: Callable[[Dict[str, Any]], None] = lambda x: None,
@@ -113,6 +141,7 @@ class VisionAgentV2(Agent):
                 default AnthropicLMM will be used.
             coder (Optional[AgentCoder]): The coder agent to use for generating vision
                 code. If None, a default VisionAgentCoderV2 will be used.
+            hil (bool): Whether to use human-in-the-loop mode.
             verbose (bool): Whether to print out debug information.
             code_sandbox_runtime (Optional[str]): The code sandbox runtime to use, can
                 be one of: None, "local" or "e2b". If None, it will read from the
@@ -132,7 +161,9 @@ class VisionAgentV2(Agent):
         self.coder = (
             coder
             if coder is not None
-            else VisionAgentCoderV2(verbose=verbose, update_callback=update_callback)
+            else VisionAgentCoderV2(
+                verbose=verbose, update_callback=update_callback, hil=hil
+            )
         )
 
         self.verbose = verbose
@@ -169,6 +200,7 @@ class VisionAgentV2(Agent):
     def chat(
         self,
         chat: List[AgentMessage],
+        code_interpreter: Optional[CodeInterpreter] = None,
     ) -> List[AgentMessage]:
         """Conversational interface to the agent. This is the main method to use to
         interact with the agent. It takes in a list of messages and returns the agent's
@@ -177,28 +209,47 @@ class VisionAgentV2(Agent):
         Parameters:
             chat (List[AgentMessage]): The input to the agent. This should be a list of
                 AgentMessage objects.
+            code_interpreter (Optional[CodeInterpreter]): The code interpreter to use.
 
         Returns:
             List[AgentMessage]: The agent's response as a list of AgentMessage objects.
         """
 
+        chat = copy.deepcopy(chat)
+        if not chat or chat[-1].role not in {"user", "interaction_response"}:
+            raise ValueError(
+                f"Last chat message must be from the user or interaction_response, got {chat[-1].role}."
+            )
+
         return_chat = []
-        with CodeInterpreterFactory.new_instance(
-            self.code_sandbox_runtime
+        with (
+            CodeInterpreterFactory.new_instance(self.code_sandbox_runtime)
+            if code_interpreter is None
+            else code_interpreter
         ) as code_interpreter:
             int_chat, _, _ = add_media_to_chat(chat, code_interpreter)
-            response_context = run_conversation(self.agent, int_chat)
-            return_chat.append(
-                AgentMessage(role="conversation", content=response_context)
-            )
-            self.update_callback(return_chat[-1].model_dump())
 
-            action = extract_tag(response_context, "action")
+            # if we had an interaction and then received an observation from the user,
+            # go back into the same action to finish it.
+            action = None
+            if check_for_interaction(int_chat):
+                action = "generate_or_edit_vision_code"
+            else:
+                response_context = run_conversation(self.agent, int_chat)
+                return_chat.append(
+                    AgentMessage(role="conversation", content=response_context)
+                )
+                self.update_callback(return_chat[-1].model_dump())
+                action = extract_tag(response_context, "action")
 
             updated_chat = maybe_run_action(
                 self.coder, action, int_chat, code_interpreter=code_interpreter
             )
-            if updated_chat is not None:
+
+            # return an interaction early to get the user's feedback
+            if updated_chat is not None and updated_chat[-1].role == "interaction":
+                return_chat.extend(updated_chat)
+            elif updated_chat is not None and updated_chat[-1].role != "interaction":
                 # do not append updated_chat to return_chat because the observation
                 # from running the action will have already been added via the callbacks
                 obs_response_context = run_conversation(
@@ -3,7 +3,9 @@ import shutil
 import tempfile
 from typing import Any, Callable, Dict, List, Optional, Tuple, cast
 
+import libcst as cst
 import numpy as np
+from IPython.display import display
 from PIL import Image
 
 import vision_agent.tools as T
@@ -21,8 +23,13 @@ from vision_agent.agent.vision_agent_planner_prompts_v2 import (
     TEST_TOOLS_EXAMPLE1,
     TEST_TOOLS_EXAMPLE2,
 )
-from vision_agent.lmm import AnthropicLMM
-from vision_agent.utils.execute import CodeInterpreterFactory
+from vision_agent.lmm import LMM, AnthropicLMM
+from vision_agent.utils.execute import (
+    CodeInterpreter,
+    CodeInterpreterFactory,
+    Execution,
+    MimeType,
+)
 from vision_agent.utils.image_utils import convert_to_b64
 from vision_agent.utils.sim import load_cached_sim
 
@@ -33,6 +40,16 @@ _LOGGER = logging.getLogger(__name__)
 EXAMPLES = f"\n{TEST_TOOLS_EXAMPLE1}\n{TEST_TOOLS_EXAMPLE2}\n"
 
 
+def format_tool_output(tool_thoughts: str, tool_docstring: str) -> str:
+    return_str = "[get_tool_for_task output]\n"
+    if tool_thoughts.strip() != "":
+        return_str += f"{tool_thoughts}\n\n"
+    return_str += (
+        f"Tool Documentation:\n{tool_docstring}\n[end of get_tool_for_task output]\n"
+    )
+    return return_str
+
+
 def extract_tool_info(
     tool_choice_context: Dict[str, Any]
 ) -> Tuple[Optional[Callable], str, str, str]:
@@ -46,6 +63,70 @@ def extract_tool_info(
     return tool, tool_thoughts, tool_docstring, ""
 
 
+def run_tool_testing(
+    task: str,
+    image_paths: List[str],
+    lmm: LMM,
+    exclude_tools: Optional[List[str]],
+    code_interpreter: CodeInterpreter,
+) -> tuple[str, str, Execution]:
+    """Helper function to generate and run tool testing code."""
+    query = lmm.generate(CATEGORIZE_TOOL_REQUEST.format(task=task))
+    category = extract_tag(query, "category")  # type: ignore
+    if category is None:
+        category = task
+    else:
+        category = (
+            f"I need models from the {category.strip()} category of tools. {task}"
+        )
+
+    tool_docs = TOOL_RECOMMENDER.top_k(category, k=10, thresh=0.2)
+    if exclude_tools is not None and len(exclude_tools) > 0:
+        cleaned_tool_docs = []
+        for tool_doc in tool_docs:
+            if tool_doc["name"] not in exclude_tools:
+                cleaned_tool_docs.append(tool_doc)
+        tool_docs = cleaned_tool_docs
+    tool_docs_str = "\n".join([e["doc"] for e in tool_docs])
+
+    prompt = TEST_TOOLS.format(
+        tool_docs=tool_docs_str,
+        previous_attempts="",
+        user_request=task,
+        examples=EXAMPLES,
+        media=str(image_paths),
+    )
+
+    response = lmm.generate(prompt, media=image_paths)
+    code = extract_tag(response, "code")  # type: ignore
+    if code is None:
+        raise ValueError(f"Could not extract code from response: {response}")
+    tool_output = code_interpreter.exec_isolation(DefaultImports.prepend_imports(code))
+    tool_output_str = tool_output.text(include_results=False).strip()
+
+    count = 1
+    while (
+        not tool_output.success
+        or (len(tool_output.logs.stdout) == 0 and len(tool_output.logs.stderr) == 0)
+    ) and count <= 3:
+        if tool_output_str.strip() == "":
+            tool_output_str = "EMPTY"
+        prompt = TEST_TOOLS.format(
+            tool_docs=tool_docs_str,
+            previous_attempts=f"<code>\n{code}\n</code>\nTOOL OUTPUT\n{tool_output_str}",
+            user_request=task,
+            examples=EXAMPLES,
+            media=str(image_paths),
+        )
+        code = extract_code(lmm.generate(prompt, media=image_paths))  # type: ignore
+        tool_output = code_interpreter.exec_isolation(
+            DefaultImports.prepend_imports(code)
+        )
+        tool_output_str = tool_output.text(include_results=False).strip()
+        count += 1
+
+    return code, tool_docs_str, tool_output
+
+
 def get_tool_for_task(
     task: str, images: List[np.ndarray], exclude_tools: Optional[List[str]] = None
 ) -> None:
@@ -90,61 +171,11 @@ def get_tool_for_task(
             Image.fromarray(image).save(image_path)
             image_paths.append(image_path)
 
-        query = lmm.generate(CATEGORIZE_TOOL_REQUEST.format(task=task))
-        category = extract_tag(query, "category")  # type: ignore
-        if category is None:
-            category = task
-        else:
-            category = (
-                f"I need models from the {category.strip()} category of tools. {task}"
-            )
-
-        tool_docs = TOOL_RECOMMENDER.top_k(category, k=10, thresh=0.2)
-        if exclude_tools is not None and len(exclude_tools) > 0:
-            cleaned_tool_docs = []
-            for tool_doc in tool_docs:
-                if not tool_doc["name"] in exclude_tools:
-                    cleaned_tool_docs.append(tool_doc)
-            tool_docs = cleaned_tool_docs
-        tool_docs_str = "\n".join([e["doc"] for e in tool_docs])
-
-        prompt = TEST_TOOLS.format(
-            tool_docs=tool_docs_str,
-            previous_attempts="",
-            user_request=task,
-            examples=EXAMPLES,
-            media=str(image_paths),
-        )
-
-        response = lmm.generate(prompt, media=image_paths)
-        code = extract_tag(response, "code")  # type: ignore
-        if code is None:
-            raise ValueError(f"Could not extract code from response: {response}")
-        tool_output = code_interpreter.exec_isolation(
-            DefaultImports.prepend_imports(code)
+        code, tool_docs_str, tool_output = run_tool_testing(
+            task, image_paths, lmm, exclude_tools, code_interpreter
         )
         tool_output_str = tool_output.text(include_results=False).strip()
 
-        count = 1
-        while (
-            not tool_output.success
-            or (len(tool_output.logs.stdout) == 0 and len(tool_output.logs.stderr) == 0)
-        ) and count <= 3:
-            if tool_output_str.strip() == "":
-                tool_output_str = "EMPTY"
-            prompt = TEST_TOOLS.format(
-                tool_docs=tool_docs_str,
-                previous_attempts=f"<code>\n{code}\n</code>\nTOOL OUTPUT\n{tool_output_str}",
-                user_request=task,
-                examples=EXAMPLES,
-                media=str(image_paths),
-            )
-            code = extract_code(lmm.generate(prompt, media=image_paths))  # type: ignore
-            tool_output = code_interpreter.exec_isolation(
-                DefaultImports.prepend_imports(code)
-            )
-            tool_output_str = tool_output.text(include_results=False).strip()
-
 
         error_message = ""
         prompt = PICK_TOOL.format(
@@ -178,9 +209,87 @@ def get_tool_for_task(
         except Exception as e:
             _LOGGER.error(f"Error removing temp directory: {e}")
 
-    print(
-        f"[get_tool_for_task output]\n{tool_thoughts}\n\nTool Documentation:\n{tool_docstring}\n[end of get_tool_for_task output]\n"
-    )
+    print(format_tool_output(tool_thoughts, tool_docstring))
+
+
+def get_tool_documentation(tool_name: str) -> str:
+    # use same format as get_tool_for_task
+    tool_doc = T.TOOLS_DF[T.TOOLS_DF["name"] == tool_name]["doc"].values[0]
+    return format_tool_output("", tool_doc)
+
+
+def get_tool_for_task_human_reviewer(
+    task: str, images: List[np.ndarray], exclude_tools: Optional[List[str]] = None
+) -> None:
+    # NOTE: this should be the same documentation as get_tool_for_task
+    """Given a task and one or more images this function will find a tool to accomplish
+    the job. It prints the tool documentation and thoughts on why it chose the tool.
+
+    It can produce tools for the following types of tasks:
+        - Object detection and counting
+        - Classification
+        - Segmentation
+        - OCR
+        - VQA
+        - Depth and pose estimation
+        - Video object tracking
+
+    Wait until the documentation is printed to use the function so you know what the
+    input and output signatures are.
+
+    Parameters:
+        task: str: The task to accomplish.
+        images: List[np.ndarray]: The images to use for the task.
+        exclude_tools: Optional[List[str]]: A list of tool names to exclude from the
+            recommendations. This is helpful if you are calling get_tool_for_task twice
+            and do not want the same tool recommended.
+
+    Returns:
+        The tool to use for the task is printed to stdout
+
+    Examples
+    --------
+        >>> get_tool_for_task("Give me an OCR model that can find 'hot chocolate' in the image", [image])
+    """
+    lmm = AnthropicLMM()
+
+    with (
+        tempfile.TemporaryDirectory() as tmpdirname,
+        CodeInterpreterFactory.new_instance() as code_interpreter,
+    ):
+        image_paths = []
+        for i, image in enumerate(images[:3]):
+            image_path = f"{tmpdirname}/image_{i}.png"
+            Image.fromarray(image).save(image_path)
+            image_paths.append(image_path)
+
+        _, _, tool_output = run_tool_testing(
+            task, image_paths, lmm, exclude_tools, code_interpreter
+        )
+
+        # need to re-display results for the outer notebook to see them
+        for result in tool_output.results:
+            if "json" in result.formats():
+                display({MimeType.APPLICATION_JSON: result.json}, raw=True)
+
+
+def check_function_call(code: str, function_name: str) -> bool:
+    class FunctionCallVisitor(cst.CSTVisitor):
+        def __init__(self) -> None:
+            self.function_name = function_name
+            self.function_called = False
+
+        def visit_Call(self, node: cst.Call) -> None:
+            if (
+                isinstance(node.func, cst.Name)
+                and node.func.value == self.function_name
+            ):
+                self.function_called = True
+
+    tree = cst.parse_module(code)
+    visitor = FunctionCallVisitor()
+    tree.visit(visitor)
+    return visitor.function_called
 
 
 def finalize_plan(user_request: str, chain_of_thoughts: str) -> str:
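A quick usage sketch of `check_function_call`, which is what the planner uses to detect when a human-in-the-loop interaction is needed (the code string below is illustrative):

```python
from vision_agent.tools.planner_tools import check_function_call

code = "get_tool_for_task('count the people', [image])"
assert check_function_call(code, "get_tool_for_task")
assert not check_function_call(code, "owl_v2_image")
```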
@@ -213,6 +213,8 @@ def _call_post(
     files_in_b64 = None
     if files:
         files_in_b64 = [(file[0], b64encode(file[1]).decode("utf-8")) for file in files]
+
+    tool_call_trace = None
     try:
         if files is not None:
             response = session.post(url, data=payload, files=files)
@@ -250,9 +252,10 @@ def _call_post(
         tool_call_trace.response = result
         return result
     finally:
-        trace = tool_call_trace.model_dump()
-        trace["type"] = "tool_call"
-        display({MimeType.APPLICATION_JSON: trace}, raw=True)
+        if tool_call_trace is not None:
+            trace = tool_call_trace.model_dump()
+            trace["type"] = "tool_call"
+            display({MimeType.APPLICATION_JSON: trace}, raw=True)
 
 
 def filter_bboxes_by_threshold(
@@ -398,17 +398,20 @@ class CodeInterpreter(abc.ABC):
         self,
         timeout: int,
         remote_path: Optional[Union[str, Path]] = None,
+        non_exiting: bool = False,
         *args: Any,
         **kwargs: Any,
     ) -> None:
         self.timeout = timeout
         self.remote_path = Path(remote_path if remote_path is not None else WORKSPACE)
+        self.non_exiting = non_exiting
 
     def __enter__(self) -> Self:
         return self
 
     def __exit__(self, *exc_info: Any) -> None:
-        self.close()
+        if not self.non_exiting:
+            self.close()
 
     def close(self, *args: Any, **kwargs: Any) -> None:
         raise NotImplementedError()
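This flag is what the README's human-in-the-loop caveat relies on. A small sketch of the lifecycle it enables (assuming the interpreter's `exec_cell` method; only `non_exiting` is confirmed by this diff):

```python
from vision_agent.utils.execute import CodeInterpreterFactory

# With non_exiting=True, __exit__ skips close(), so kernel state survives
# the `with` block and can be reused across chat turns.
code_interpreter = CodeInterpreterFactory.new_instance(non_exiting=True)
with code_interpreter:
    code_interpreter.exec_cell("x = 1")
with code_interpreter:
    code_interpreter.exec_cell("print(x)")  # x is still defined
code_interpreter.close()  # explicit cleanup once the conversation is over
```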
@@ -571,8 +574,9 @@ class LocalCodeInterpreter(CodeInterpreter):
         self,
         timeout: int = _SESSION_TIMEOUT,
         remote_path: Optional[Union[str, Path]] = None,
+        non_exiting: bool = False,
     ) -> None:
-        super().__init__(timeout=timeout)
+        super().__init__(timeout=timeout, non_exiting=non_exiting)
         self.nb = nbformat.v4.new_notebook()
         # Set the notebook execution path to the remote path
         self.remote_path = Path(remote_path if remote_path is not None else WORKSPACE)
@@ -692,6 +696,7 @@ class CodeInterpreterFactory:
     def new_instance(
         code_sandbox_runtime: Optional[str] = None,
         remote_path: Optional[Union[str, Path]] = None,
+        non_exiting: bool = False,
     ) -> CodeInterpreter:
         if not code_sandbox_runtime:
             code_sandbox_runtime = os.getenv("CODE_SANDBOX_RUNTIME", "local")
@@ -702,7 +707,9 @@ class CodeInterpreterFactory:
             )
         elif code_sandbox_runtime == "local":
             instance = LocalCodeInterpreter(
-                timeout=_SESSION_TIMEOUT, remote_path=remote_path
+                timeout=_SESSION_TIMEOUT,
+                remote_path=remote_path,
+                non_exiting=non_exiting,
             )
         else:
             raise ValueError(