vision-agent 0.2.204__py3-none-any.whl → 0.2.206__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -367,8 +367,10 @@ def replace_interaction_with_obs(chat: List[AgentMessage]) -> List[AgentMessage]
367
367
  response = json.loads(chat[i + 1].content)
368
368
  function_name = response["function_name"]
369
369
  tool_doc = get_tool_documentation(function_name)
370
+ if "box_threshold" in response:
371
+ tool_doc = f"Use the following function with box_threshold={response['box_threshold']}\n\n{tool_doc}"
370
372
  new_chat.append(AgentMessage(role="observation", content=tool_doc))
371
- except json.JSONDecodeError:
373
+ except (json.JSONDecodeError, KeyError):
372
374
  raise ValueError(f"Invalid JSON in interaction response: {chat_i}")
373
375
  else:
374
376
  new_chat.append(chat_i)
@@ -1,6 +1,8 @@
1
+ import inspect
1
2
  import logging
2
3
  import shutil
3
4
  import tempfile
5
+ from functools import lru_cache
4
6
  from typing import Any, Callable, Dict, List, Optional, Tuple, cast
5
7
 
6
8
  import libcst as cst
@@ -31,15 +33,19 @@ from vision_agent.utils.execute import (
31
33
  MimeType,
32
34
  )
33
35
  from vision_agent.utils.image_utils import convert_to_b64
34
- from vision_agent.utils.sim import load_cached_sim
36
+ from vision_agent.utils.sim import Sim, load_cached_sim
35
37
 
36
38
  TOOL_FUNCTIONS = {tool.__name__: tool for tool in T.TOOLS}
37
- TOOL_RECOMMENDER = load_cached_sim(T.TOOLS_DF)
38
39
 
39
40
  _LOGGER = logging.getLogger(__name__)
40
41
  EXAMPLES = f"\n{TEST_TOOLS_EXAMPLE1}\n{TEST_TOOLS_EXAMPLE2}\n"
41
42
 
42
43
 
44
+ @lru_cache(maxsize=1)
45
+ def get_tool_recommender() -> Sim:
46
+ return load_cached_sim(T.TOOLS_DF)
47
+
48
+
43
49
  def format_tool_output(tool_thoughts: str, tool_docstring: str) -> str:
44
50
  return_str = "[get_tool_for_task output]\n"
45
51
  if tool_thoughts.strip() != "":
@@ -51,7 +57,7 @@ def format_tool_output(tool_thoughts: str, tool_docstring: str) -> str:
51
57
 
52
58
 
53
59
  def extract_tool_info(
54
- tool_choice_context: Dict[str, Any]
60
+ tool_choice_context: Dict[str, Any],
55
61
  ) -> Tuple[Optional[Callable], str, str, str]:
56
62
  tool_thoughts = tool_choice_context.get("thoughts", "")
57
63
  tool_docstring = ""
@@ -63,12 +69,55 @@ def extract_tool_info(
63
69
  return tool, tool_thoughts, tool_docstring, ""
64
70
 
65
71
 
72
+ def replace_box_threshold(code: str, functions: List[str], box_threshold: float) -> str:
73
+ class ReplaceBoxThresholdTransformer(cst.CSTTransformer):
74
+ def leave_Call(
75
+ self, original_node: cst.Call, updated_node: cst.Call
76
+ ) -> cst.Call:
77
+ if (
78
+ isinstance(updated_node.func, cst.Name)
79
+ and updated_node.func.value in functions
80
+ ) or (
81
+ isinstance(updated_node.func, cst.Attribute)
82
+ and updated_node.func.attr.value in functions
83
+ ):
84
+ new_args = []
85
+ found = False
86
+ for arg in updated_node.args:
87
+ if arg.keyword and arg.keyword.value == "box_threshold":
88
+ new_arg = arg.with_changes(value=cst.Float(str(box_threshold)))
89
+ new_args.append(new_arg)
90
+ found = True
91
+ else:
92
+ new_args.append(arg)
93
+
94
+ if not found:
95
+ new_args.append(
96
+ cst.Arg(
97
+ keyword=cst.Name("box_threshold"),
98
+ value=cst.Float(str(box_threshold)),
99
+ equal=cst.AssignEqual(
100
+ whitespace_before=cst.SimpleWhitespace(""),
101
+ whitespace_after=cst.SimpleWhitespace(""),
102
+ ),
103
+ )
104
+ )
105
+ return updated_node.with_changes(args=new_args)
106
+ return updated_node
107
+
108
+ tree = cst.parse_module(code)
109
+ transformer = ReplaceBoxThresholdTransformer()
110
+ new_tree = tree.visit(transformer)
111
+ return new_tree.code
112
+
113
+
66
114
  def run_tool_testing(
67
115
  task: str,
68
116
  image_paths: List[str],
69
117
  lmm: LMM,
70
118
  exclude_tools: Optional[List[str]],
71
119
  code_interpreter: CodeInterpreter,
120
+ process_code: Callable[[str], str] = lambda x: x,
72
121
  ) -> tuple[str, str, Execution]:
73
122
  """Helper function to generate and run tool testing code."""
74
123
  query = lmm.generate(CATEGORIZE_TOOL_REQUEST.format(task=task))
@@ -80,7 +129,7 @@ def run_tool_testing(
80
129
  f"I need models from the {category.strip()} category of tools. {task}"
81
130
  )
82
131
 
83
- tool_docs = TOOL_RECOMMENDER.top_k(category, k=10, thresh=0.2)
132
+ tool_docs = get_tool_recommender().top_k(category, k=10, thresh=0.2)
84
133
  if exclude_tools is not None and len(exclude_tools) > 0:
85
134
  cleaned_tool_docs = []
86
135
  for tool_doc in tool_docs:
@@ -101,6 +150,7 @@ def run_tool_testing(
101
150
  code = extract_tag(response, "code") # type: ignore
102
151
  if code is None:
103
152
  raise ValueError(f"Could not extract code from response: {response}")
153
+ code = process_code(code)
104
154
  tool_output = code_interpreter.exec_isolation(DefaultImports.prepend_imports(code))
105
155
  tool_output_str = tool_output.text(include_results=False).strip()
106
156
 
@@ -119,6 +169,7 @@ def run_tool_testing(
119
169
  media=str(image_paths),
120
170
  )
121
171
  code = extract_code(lmm.generate(prompt, media=image_paths)) # type: ignore
172
+ code = process_code(code)
122
173
  tool_output = code_interpreter.exec_isolation(
123
174
  DefaultImports.prepend_imports(code)
124
175
  )
@@ -200,7 +251,9 @@ def get_tool_for_task(
200
251
  context=f"<code>\n{code}\n</code>\n<tool_output>\n{tool_output_str}\n</tool_output>",
201
252
  previous_attempts=error_message,
202
253
  )
203
- tool_choice_context_dict = extract_json(lmm.generate(prompt, media=image_paths)) # type: ignore
254
+ tool_choice_context_dict = extract_json(
255
+ lmm.generate(prompt, media=image_paths) # type: ignore
256
+ )
204
257
  tool, tool_thoughts, tool_docstring, error_message = extract_tool_info(
205
258
  tool_choice_context_dict
206
259
  )
@@ -221,36 +274,7 @@ def get_tool_documentation(tool_name: str) -> str:
221
274
  def get_tool_for_task_human_reviewer(
222
275
  task: str, images: List[np.ndarray], exclude_tools: Optional[List[str]] = None
223
276
  ) -> None:
224
- # NOTE: this should be the same documentation as get_tool_for_task
225
- """Given a task and one or more images this function will find a tool to accomplish
226
- the jobs. It prints the tool documentation and thoughts on why it chose the tool.
227
-
228
- It can produce tools for the following types of tasks:
229
- - Object detection and counting
230
- - Classification
231
- - Segmentation
232
- - OCR
233
- - VQA
234
- - Depth and pose estimation
235
- - Video object tracking
236
-
237
- Wait until the documentation is printed to use the function so you know what the
238
- input and output signatures are.
239
-
240
- Parameters:
241
- task: str: The task to accomplish.
242
- images: List[np.ndarray]: The images to use for the task.
243
- exclude_tools: Optional[List[str]]: A list of tool names to exclude from the
244
- recommendations. This is helpful if you are calling get_tool_for_task twice
245
- and do not want the same tool recommended.
246
-
247
- Returns:
248
- The tool to use for the task is printed to stdout
249
-
250
- Examples
251
- --------
252
- >>> get_tool_for_task("Give me an OCR model that can find 'hot chocolate' in the image", [image])
253
- """
277
+ # NOTE: this will have the same documentation as get_tool_for_task
254
278
  lmm = AnthropicLMM()
255
279
 
256
280
  with (
@@ -263,8 +287,19 @@ def get_tool_for_task_human_reviewer(
263
287
  Image.fromarray(image).save(image_path)
264
288
  image_paths.append(image_path)
265
289
 
290
+ tools = [
291
+ t.__name__
292
+ for t in T.TOOLS
293
+ if inspect.signature(t).parameters.get("box_threshold") # type: ignore
294
+ ]
295
+
266
296
  _, _, tool_output = run_tool_testing(
267
- task, image_paths, lmm, exclude_tools, code_interpreter
297
+ task,
298
+ image_paths,
299
+ lmm,
300
+ exclude_tools,
301
+ code_interpreter,
302
+ process_code=lambda x: replace_box_threshold(x, tools, 0.05),
268
303
  )
269
304
 
270
305
  # need to re-display results for the outer notebook to see them
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: vision-agent
3
- Version: 0.2.204
3
+ Version: 0.2.206
4
4
  Summary: Toolset for Vision Agent
5
5
  Author: Landing AI
6
6
  Author-email: dev@landing.ai
@@ -14,7 +14,7 @@ vision_agent/agent/vision_agent_coder_v2.py,sha256=nXbMsCLpKxTEi075ZE932227tW-lE
14
14
  vision_agent/agent/vision_agent_planner.py,sha256=KWMA7XemcSmc_jn-MwdWz9wnKDtj-sYQ9tINi70_OoU,18583
15
15
  vision_agent/agent/vision_agent_planner_prompts.py,sha256=Y3jz9HRf8fz9NLUseN7cTgZqewP0RazxR7vw1sPhcn0,6691
16
16
  vision_agent/agent/vision_agent_planner_prompts_v2.py,sha256=Tzon3h5iZdHJglesk8GVS-2myNf5-fhf7HUbkpZWHQk,33143
17
- vision_agent/agent/vision_agent_planner_v2.py,sha256=3sXW4A-GZ5Bg2rGheuIYspAu_N2e00Sii1f_1HJS934,20255
17
+ vision_agent/agent/vision_agent_planner_v2.py,sha256=pAtGWkY-9fFgxgO2ioebvMvASwbJ-8bAvzRNp8Z0Odc,20437
18
18
  vision_agent/agent/vision_agent_prompts.py,sha256=NtGdCfzzilCRtscKALC9FK55d1h4CBpMnbhLzg0PYlc,13772
19
19
  vision_agent/agent/vision_agent_prompts_v2.py,sha256=-vCWat-ARlCOOOeIDIFhg-kcwRRwjTXYEwsvvqPeaCs,1972
20
20
  vision_agent/agent/vision_agent_v2.py,sha256=6gGVV3FlL4NLzHRpjMqMz-fEP6f_JhwwOjUKczZ3TPA,10231
@@ -28,7 +28,7 @@ vision_agent/lmm/lmm.py,sha256=x_nIyDNDZwq4-pfjnJTmcyyJZ2_B7TjkA5jZp88YVO8,17103
28
28
  vision_agent/lmm/types.py,sha256=ZEXR_ptBL0ZwDMTDYkgxUCmSZFmBYPQd2jreNzr_8UY,221
29
29
  vision_agent/tools/__init__.py,sha256=xuNt5e4syQH28Vr6EdjLmO9ni9i00yav9yqcPMUx1oo,2878
30
30
  vision_agent/tools/meta_tools.py,sha256=TPeS7QWnc_PmmU_ndiDT03dXbQ5yDSP33E7U8cSj7Ls,28660
31
- vision_agent/tools/planner_tools.py,sha256=amaycM_REQ4cwZCaKSyIWr-6ExqlHGEVs3PuIXjf-9M,12373
31
+ vision_agent/tools/planner_tools.py,sha256=zlzyCv7tzSOs9W-MjsptaOeM-i4eoA6HxXQWuMc1KkY,13548
32
32
  vision_agent/tools/prompts.py,sha256=V1z4YJLXZuUl_iZ5rY0M5hHc_2tmMEUKr0WocXKGt4E,1430
33
33
  vision_agent/tools/tool_utils.py,sha256=AT7rMcpKwZgIErfgfSvHS0gmtvd8KMHJoHnu5aMlgO0,10259
34
34
  vision_agent/tools/tools.py,sha256=vavzmDuIBHI-g13RMDnr9NALfWpiIvJWkXhD0pnhCuk,87576
@@ -40,7 +40,7 @@ vision_agent/utils/image_utils.py,sha256=rRWcxKggPXIRXIY_XT9rZt30ECDRq8zq7FDeXRD
40
40
  vision_agent/utils/sim.py,sha256=NZc9QGD6BTY5O29NVbHH7oxDePL_QMnylT1lYcDUn1Y,7437
41
41
  vision_agent/utils/type_defs.py,sha256=BE12s3JNQy36QvauXHjwyeffVh5enfcvd4vTzSwvEZI,1384
42
42
  vision_agent/utils/video.py,sha256=tRcGp4vEnaDycigL1hBO9k0FBPtDH35fCQciVr9GqYI,6013
43
- vision_agent-0.2.204.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
44
- vision_agent-0.2.204.dist-info/METADATA,sha256=cuwR0b_QsTgq_dle_aATNpcNC-XGl78sLY11dS9OGbg,19026
45
- vision_agent-0.2.204.dist-info/WHEEL,sha256=7Z8_27uaHI_UZAc4Uox4PpBhQ9Y5_modZXWMxtUi4NU,88
46
- vision_agent-0.2.204.dist-info/RECORD,,
43
+ vision_agent-0.2.206.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
44
+ vision_agent-0.2.206.dist-info/METADATA,sha256=3QLRuQR4YwcTTU1y6phpkl7hLXtCIKqxYlYjF1_oNzM,19026
45
+ vision_agent-0.2.206.dist-info/WHEEL,sha256=7Z8_27uaHI_UZAc4Uox4PpBhQ9Y5_modZXWMxtUi4NU,88
46
+ vision_agent-0.2.206.dist-info/RECORD,,