vision-agent 0.2.204__py3-none-any.whl → 0.2.206__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vision_agent/agent/vision_agent_planner_v2.py +3 -1
- vision_agent/tools/planner_tools.py +71 -36
- {vision_agent-0.2.204.dist-info → vision_agent-0.2.206.dist-info}/METADATA +1 -1
- {vision_agent-0.2.204.dist-info → vision_agent-0.2.206.dist-info}/RECORD +6 -6
- {vision_agent-0.2.204.dist-info → vision_agent-0.2.206.dist-info}/LICENSE +0 -0
- {vision_agent-0.2.204.dist-info → vision_agent-0.2.206.dist-info}/WHEEL +0 -0
vision_agent/agent/vision_agent_planner_v2.py

@@ -367,8 +367,10 @@ def replace_interaction_with_obs(chat: List[AgentMessage]) -> List[AgentMessage]
                 response = json.loads(chat[i + 1].content)
                 function_name = response["function_name"]
                 tool_doc = get_tool_documentation(function_name)
+                if "box_threshold" in response:
+                    tool_doc = f"Use the following function with box_threshold={response['box_threshold']}\n\n{tool_doc}"
                 new_chat.append(AgentMessage(role="observation", content=tool_doc))
-            except json.JSONDecodeError:
+            except (json.JSONDecodeError, KeyError):
                 raise ValueError(f"Invalid JSON in interaction response: {chat_i}")
         else:
             new_chat.append(chat_i)
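For context, a hedged sketch of what the resulting observation text looks like when the interaction response carries a box_threshold; the function name and documentation string below are placeholders, not actual vision_agent tool docs:

```python
import json

# Illustrative interaction response and placeholder tool documentation.
raw = '{"function_name": "countgd_object_detection", "box_threshold": 0.1}'
response = json.loads(raw)
tool_doc = "countgd_object_detection(prompt, image, box_threshold=0.23) -> list[dict]"

if "box_threshold" in response:
    tool_doc = f"Use the following function with box_threshold={response['box_threshold']}\n\n{tool_doc}"

print(tool_doc)
# Use the following function with box_threshold=0.1
#
# countgd_object_detection(prompt, image, box_threshold=0.23) -> list[dict]
```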
vision_agent/tools/planner_tools.py

@@ -1,6 +1,8 @@
+import inspect
 import logging
 import shutil
 import tempfile
+from functools import lru_cache
 from typing import Any, Callable, Dict, List, Optional, Tuple, cast
 
 import libcst as cst
@@ -31,15 +33,19 @@ from vision_agent.utils.execute import (
     MimeType,
 )
 from vision_agent.utils.image_utils import convert_to_b64
-from vision_agent.utils.sim import load_cached_sim
+from vision_agent.utils.sim import Sim, load_cached_sim
 
 TOOL_FUNCTIONS = {tool.__name__: tool for tool in T.TOOLS}
-TOOL_RECOMMENDER = load_cached_sim(T.TOOLS_DF)
 
 _LOGGER = logging.getLogger(__name__)
 EXAMPLES = f"\n{TEST_TOOLS_EXAMPLE1}\n{TEST_TOOLS_EXAMPLE2}\n"
 
 
+@lru_cache(maxsize=1)
+def get_tool_recommender() -> Sim:
+    return load_cached_sim(T.TOOLS_DF)
+
+
 def format_tool_output(tool_thoughts: str, tool_docstring: str) -> str:
     return_str = "[get_tool_for_task output]\n"
     if tool_thoughts.strip() != "":
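The module-level TOOL_RECOMMENDER constant is replaced by an lru_cache(maxsize=1) accessor, so the similarity index is built lazily on first use rather than at import time, and later calls reuse the same object. A minimal sketch of the pattern, with a stand-in for load_cached_sim(T.TOOLS_DF):

```python
from functools import lru_cache

@lru_cache(maxsize=1)
def get_tool_recommender() -> object:
    # Stand-in for load_cached_sim(T.TOOLS_DF); the body runs only once.
    print("building tool similarity index...")
    return object()

first = get_tool_recommender()   # prints "building tool similarity index..."
second = get_tool_recommender()  # cache hit: no print, same object returned
assert first is second
```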
@@ -51,7 +57,7 @@ def format_tool_output(tool_thoughts: str, tool_docstring: str) -> str:
 
 
 def extract_tool_info(
-    tool_choice_context: Dict[str, Any]
+    tool_choice_context: Dict[str, Any],
 ) -> Tuple[Optional[Callable], str, str, str]:
     tool_thoughts = tool_choice_context.get("thoughts", "")
     tool_docstring = ""
@@ -63,12 +69,55 @@ def extract_tool_info(
     return tool, tool_thoughts, tool_docstring, ""
 
 
+def replace_box_threshold(code: str, functions: List[str], box_threshold: float) -> str:
+    class ReplaceBoxThresholdTransformer(cst.CSTTransformer):
+        def leave_Call(
+            self, original_node: cst.Call, updated_node: cst.Call
+        ) -> cst.Call:
+            if (
+                isinstance(updated_node.func, cst.Name)
+                and updated_node.func.value in functions
+            ) or (
+                isinstance(updated_node.func, cst.Attribute)
+                and updated_node.func.attr.value in functions
+            ):
+                new_args = []
+                found = False
+                for arg in updated_node.args:
+                    if arg.keyword and arg.keyword.value == "box_threshold":
+                        new_arg = arg.with_changes(value=cst.Float(str(box_threshold)))
+                        new_args.append(new_arg)
+                        found = True
+                    else:
+                        new_args.append(arg)
+
+                if not found:
+                    new_args.append(
+                        cst.Arg(
+                            keyword=cst.Name("box_threshold"),
+                            value=cst.Float(str(box_threshold)),
+                            equal=cst.AssignEqual(
+                                whitespace_before=cst.SimpleWhitespace(""),
+                                whitespace_after=cst.SimpleWhitespace(""),
+                            ),
+                        )
+                    )
+                return updated_node.with_changes(args=new_args)
+            return updated_node
+
+    tree = cst.parse_module(code)
+    transformer = ReplaceBoxThresholdTransformer()
+    new_tree = tree.visit(transformer)
+    return new_tree.code
+
+
 def run_tool_testing(
     task: str,
     image_paths: List[str],
     lmm: LMM,
     exclude_tools: Optional[List[str]],
     code_interpreter: CodeInterpreter,
+    process_code: Callable[[str], str] = lambda x: x,
 ) -> tuple[str, str, Execution]:
     """Helper function to generate and run tool testing code."""
     query = lmm.generate(CATEGORIZE_TOOL_REQUEST.format(task=task))
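Assuming vision-agent 0.2.206 is installed, a hedged usage sketch of the new replace_box_threshold helper (the tool name below is illustrative): it overrides an existing box_threshold keyword and appends one to matching calls that lack it.

```python
from vision_agent.tools.planner_tools import replace_box_threshold

code = (
    'dets = countgd_object_detection("person", image)\n'
    'more = countgd_object_detection("person", image, box_threshold=0.5)\n'
)
patched = replace_box_threshold(code, ["countgd_object_detection"], 0.05)
print(patched)
# Expected (roughly):
# dets = countgd_object_detection("person", image, box_threshold=0.05)
# more = countgd_object_detection("person", image, box_threshold=0.05)
```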
@@ -80,7 +129,7 @@ def run_tool_testing(
         f"I need models from the {category.strip()} category of tools. {task}"
     )
 
-    tool_docs =
+    tool_docs = get_tool_recommender().top_k(category, k=10, thresh=0.2)
     if exclude_tools is not None and len(exclude_tools) > 0:
         cleaned_tool_docs = []
         for tool_doc in tool_docs:
@@ -101,6 +150,7 @@ def run_tool_testing(
     code = extract_tag(response, "code") # type: ignore
     if code is None:
         raise ValueError(f"Could not extract code from response: {response}")
+    code = process_code(code)
     tool_output = code_interpreter.exec_isolation(DefaultImports.prepend_imports(code))
     tool_output_str = tool_output.text(include_results=False).strip()
 
@@ -119,6 +169,7 @@ def run_tool_testing(
             media=str(image_paths),
         )
         code = extract_code(lmm.generate(prompt, media=image_paths)) # type: ignore
+        code = process_code(code)
         tool_output = code_interpreter.exec_isolation(
             DefaultImports.prepend_imports(code)
         )
@@ -200,7 +251,9 @@ def get_tool_for_task(
            context=f"<code>\n{code}\n</code>\n<tool_output>\n{tool_output_str}\n</tool_output>",
            previous_attempts=error_message,
        )
-        tool_choice_context_dict = extract_json(
+        tool_choice_context_dict = extract_json(
+            lmm.generate(prompt, media=image_paths)  # type: ignore
+        )
         tool, tool_thoughts, tool_docstring, error_message = extract_tool_info(
             tool_choice_context_dict
         )
@@ -221,36 +274,7 @@ def get_tool_documentation(tool_name: str) -> str:
 def get_tool_for_task_human_reviewer(
     task: str, images: List[np.ndarray], exclude_tools: Optional[List[str]] = None
 ) -> None:
-    # NOTE: this
-    """Given a task and one or more images this function will find a tool to accomplish
-    the jobs. It prints the tool documentation and thoughts on why it chose the tool.
-
-    It can produce tools for the following types of tasks:
-        - Object detection and counting
-        - Classification
-        - Segmentation
-        - OCR
-        - VQA
-        - Depth and pose estimation
-        - Video object tracking
-
-    Wait until the documentation is printed to use the function so you know what the
-    input and output signatures are.
-
-    Parameters:
-        task: str: The task to accomplish.
-        images: List[np.ndarray]: The images to use for the task.
-        exclude_tools: Optional[List[str]]: A list of tool names to exclude from the
-            recommendations. This is helpful if you are calling get_tool_for_task twice
-            and do not want the same tool recommended.
-
-    Returns:
-        The tool to use for the task is printed to stdout
-
-    Examples
-    --------
-        >>> get_tool_for_task("Give me an OCR model that can find 'hot chocolate' in the image", [image])
-    """
+    # NOTE: this will have the same documentation as get_tool_for_task
     lmm = AnthropicLMM()
 
     with (
@@ -263,8 +287,19 @@ def get_tool_for_task_human_reviewer(
             Image.fromarray(image).save(image_path)
             image_paths.append(image_path)
 
+        tools = [
+            t.__name__
+            for t in T.TOOLS
+            if inspect.signature(t).parameters.get("box_threshold")  # type: ignore
+        ]
+
         _, _, tool_output = run_tool_testing(
-            task,
+            task,
+            image_paths,
+            lmm,
+            exclude_tools,
+            code_interpreter,
+            process_code=lambda x: replace_box_threshold(x, tools, 0.05),
         )
 
         # need to re-display results for the outer notebook to see them
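The human-reviewer path now collects only the tools whose signature exposes a box_threshold parameter and pins that threshold to 0.05 in the generated test code via process_code. A self-contained sketch of the signature-filtering idiom, using made-up stand-ins for T.TOOLS:

```python
import inspect

# Stand-ins for entries of T.TOOLS (illustrative, not real vision_agent tools).
def detect(prompt: str, image, box_threshold: float = 0.3): ...
def classify(prompt: str, image): ...

TOOLS = [detect, classify]
tools = [
    t.__name__
    for t in TOOLS
    if inspect.signature(t).parameters.get("box_threshold")
]
print(tools)  # ['detect'] -- only callables that accept box_threshold
```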
{vision_agent-0.2.204.dist-info → vision_agent-0.2.206.dist-info}/RECORD

@@ -14,7 +14,7 @@ vision_agent/agent/vision_agent_coder_v2.py,sha256=nXbMsCLpKxTEi075ZE932227tW-lE
 vision_agent/agent/vision_agent_planner.py,sha256=KWMA7XemcSmc_jn-MwdWz9wnKDtj-sYQ9tINi70_OoU,18583
 vision_agent/agent/vision_agent_planner_prompts.py,sha256=Y3jz9HRf8fz9NLUseN7cTgZqewP0RazxR7vw1sPhcn0,6691
 vision_agent/agent/vision_agent_planner_prompts_v2.py,sha256=Tzon3h5iZdHJglesk8GVS-2myNf5-fhf7HUbkpZWHQk,33143
-vision_agent/agent/vision_agent_planner_v2.py,sha256=
+vision_agent/agent/vision_agent_planner_v2.py,sha256=pAtGWkY-9fFgxgO2ioebvMvASwbJ-8bAvzRNp8Z0Odc,20437
 vision_agent/agent/vision_agent_prompts.py,sha256=NtGdCfzzilCRtscKALC9FK55d1h4CBpMnbhLzg0PYlc,13772
 vision_agent/agent/vision_agent_prompts_v2.py,sha256=-vCWat-ARlCOOOeIDIFhg-kcwRRwjTXYEwsvvqPeaCs,1972
 vision_agent/agent/vision_agent_v2.py,sha256=6gGVV3FlL4NLzHRpjMqMz-fEP6f_JhwwOjUKczZ3TPA,10231
@@ -28,7 +28,7 @@ vision_agent/lmm/lmm.py,sha256=x_nIyDNDZwq4-pfjnJTmcyyJZ2_B7TjkA5jZp88YVO8,17103
 vision_agent/lmm/types.py,sha256=ZEXR_ptBL0ZwDMTDYkgxUCmSZFmBYPQd2jreNzr_8UY,221
 vision_agent/tools/__init__.py,sha256=xuNt5e4syQH28Vr6EdjLmO9ni9i00yav9yqcPMUx1oo,2878
 vision_agent/tools/meta_tools.py,sha256=TPeS7QWnc_PmmU_ndiDT03dXbQ5yDSP33E7U8cSj7Ls,28660
-vision_agent/tools/planner_tools.py,sha256=
+vision_agent/tools/planner_tools.py,sha256=zlzyCv7tzSOs9W-MjsptaOeM-i4eoA6HxXQWuMc1KkY,13548
 vision_agent/tools/prompts.py,sha256=V1z4YJLXZuUl_iZ5rY0M5hHc_2tmMEUKr0WocXKGt4E,1430
 vision_agent/tools/tool_utils.py,sha256=AT7rMcpKwZgIErfgfSvHS0gmtvd8KMHJoHnu5aMlgO0,10259
 vision_agent/tools/tools.py,sha256=vavzmDuIBHI-g13RMDnr9NALfWpiIvJWkXhD0pnhCuk,87576
@@ -40,7 +40,7 @@ vision_agent/utils/image_utils.py,sha256=rRWcxKggPXIRXIY_XT9rZt30ECDRq8zq7FDeXRD
 vision_agent/utils/sim.py,sha256=NZc9QGD6BTY5O29NVbHH7oxDePL_QMnylT1lYcDUn1Y,7437
 vision_agent/utils/type_defs.py,sha256=BE12s3JNQy36QvauXHjwyeffVh5enfcvd4vTzSwvEZI,1384
 vision_agent/utils/video.py,sha256=tRcGp4vEnaDycigL1hBO9k0FBPtDH35fCQciVr9GqYI,6013
-vision_agent-0.2.
-vision_agent-0.2.
-vision_agent-0.2.
-vision_agent-0.2.
+vision_agent-0.2.206.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+vision_agent-0.2.206.dist-info/METADATA,sha256=3QLRuQR4YwcTTU1y6phpkl7hLXtCIKqxYlYjF1_oNzM,19026
+vision_agent-0.2.206.dist-info/WHEEL,sha256=7Z8_27uaHI_UZAc4Uox4PpBhQ9Y5_modZXWMxtUi4NU,88
+vision_agent-0.2.206.dist-info/RECORD,,
{vision_agent-0.2.204.dist-info → vision_agent-0.2.206.dist-info}/LICENSE: file without changes
{vision_agent-0.2.204.dist-info → vision_agent-0.2.206.dist-info}/WHEEL: file without changes