vision-agent 0.2.204__py3-none-any.whl → 0.2.206__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- vision_agent/agent/vision_agent_planner_v2.py +3 -1
- vision_agent/tools/planner_tools.py +71 -36
- {vision_agent-0.2.204.dist-info → vision_agent-0.2.206.dist-info}/METADATA +1 -1
- {vision_agent-0.2.204.dist-info → vision_agent-0.2.206.dist-info}/RECORD +6 -6
- {vision_agent-0.2.204.dist-info → vision_agent-0.2.206.dist-info}/LICENSE +0 -0
- {vision_agent-0.2.204.dist-info → vision_agent-0.2.206.dist-info}/WHEEL +0 -0
@@ -367,8 +367,10 @@ def replace_interaction_with_obs(chat: List[AgentMessage]) -> List[AgentMessage]
|
|
367
367
|
response = json.loads(chat[i + 1].content)
|
368
368
|
function_name = response["function_name"]
|
369
369
|
tool_doc = get_tool_documentation(function_name)
|
370
|
+
if "box_threshold" in response:
|
371
|
+
tool_doc = f"Use the following function with box_threshold={response['box_threshold']}\n\n{tool_doc}"
|
370
372
|
new_chat.append(AgentMessage(role="observation", content=tool_doc))
|
371
|
-
except json.JSONDecodeError:
|
373
|
+
except (json.JSONDecodeError, KeyError):
|
372
374
|
raise ValueError(f"Invalid JSON in interaction response: {chat_i}")
|
373
375
|
else:
|
374
376
|
new_chat.append(chat_i)
|
@@ -1,6 +1,8 @@
|
|
1
|
+
import inspect
|
1
2
|
import logging
|
2
3
|
import shutil
|
3
4
|
import tempfile
|
5
|
+
from functools import lru_cache
|
4
6
|
from typing import Any, Callable, Dict, List, Optional, Tuple, cast
|
5
7
|
|
6
8
|
import libcst as cst
|
@@ -31,15 +33,19 @@ from vision_agent.utils.execute import (
|
|
31
33
|
MimeType,
|
32
34
|
)
|
33
35
|
from vision_agent.utils.image_utils import convert_to_b64
|
34
|
-
from vision_agent.utils.sim import load_cached_sim
|
36
|
+
from vision_agent.utils.sim import Sim, load_cached_sim
|
35
37
|
|
36
38
|
TOOL_FUNCTIONS = {tool.__name__: tool for tool in T.TOOLS}
|
37
|
-
TOOL_RECOMMENDER = load_cached_sim(T.TOOLS_DF)
|
38
39
|
|
39
40
|
_LOGGER = logging.getLogger(__name__)
|
40
41
|
EXAMPLES = f"\n{TEST_TOOLS_EXAMPLE1}\n{TEST_TOOLS_EXAMPLE2}\n"
|
41
42
|
|
42
43
|
|
44
|
+
@lru_cache(maxsize=1)
|
45
|
+
def get_tool_recommender() -> Sim:
|
46
|
+
return load_cached_sim(T.TOOLS_DF)
|
47
|
+
|
48
|
+
|
43
49
|
def format_tool_output(tool_thoughts: str, tool_docstring: str) -> str:
|
44
50
|
return_str = "[get_tool_for_task output]\n"
|
45
51
|
if tool_thoughts.strip() != "":
|
@@ -51,7 +57,7 @@ def format_tool_output(tool_thoughts: str, tool_docstring: str) -> str:
|
|
51
57
|
|
52
58
|
|
53
59
|
def extract_tool_info(
|
54
|
-
tool_choice_context: Dict[str, Any]
|
60
|
+
tool_choice_context: Dict[str, Any],
|
55
61
|
) -> Tuple[Optional[Callable], str, str, str]:
|
56
62
|
tool_thoughts = tool_choice_context.get("thoughts", "")
|
57
63
|
tool_docstring = ""
|
@@ -63,12 +69,55 @@ def extract_tool_info(
|
|
63
69
|
return tool, tool_thoughts, tool_docstring, ""
|
64
70
|
|
65
71
|
|
72
|
+
def replace_box_threshold(code: str, functions: List[str], box_threshold: float) -> str:
|
73
|
+
class ReplaceBoxThresholdTransformer(cst.CSTTransformer):
|
74
|
+
def leave_Call(
|
75
|
+
self, original_node: cst.Call, updated_node: cst.Call
|
76
|
+
) -> cst.Call:
|
77
|
+
if (
|
78
|
+
isinstance(updated_node.func, cst.Name)
|
79
|
+
and updated_node.func.value in functions
|
80
|
+
) or (
|
81
|
+
isinstance(updated_node.func, cst.Attribute)
|
82
|
+
and updated_node.func.attr.value in functions
|
83
|
+
):
|
84
|
+
new_args = []
|
85
|
+
found = False
|
86
|
+
for arg in updated_node.args:
|
87
|
+
if arg.keyword and arg.keyword.value == "box_threshold":
|
88
|
+
new_arg = arg.with_changes(value=cst.Float(str(box_threshold)))
|
89
|
+
new_args.append(new_arg)
|
90
|
+
found = True
|
91
|
+
else:
|
92
|
+
new_args.append(arg)
|
93
|
+
|
94
|
+
if not found:
|
95
|
+
new_args.append(
|
96
|
+
cst.Arg(
|
97
|
+
keyword=cst.Name("box_threshold"),
|
98
|
+
value=cst.Float(str(box_threshold)),
|
99
|
+
equal=cst.AssignEqual(
|
100
|
+
whitespace_before=cst.SimpleWhitespace(""),
|
101
|
+
whitespace_after=cst.SimpleWhitespace(""),
|
102
|
+
),
|
103
|
+
)
|
104
|
+
)
|
105
|
+
return updated_node.with_changes(args=new_args)
|
106
|
+
return updated_node
|
107
|
+
|
108
|
+
tree = cst.parse_module(code)
|
109
|
+
transformer = ReplaceBoxThresholdTransformer()
|
110
|
+
new_tree = tree.visit(transformer)
|
111
|
+
return new_tree.code
|
112
|
+
|
113
|
+
|
66
114
|
def run_tool_testing(
|
67
115
|
task: str,
|
68
116
|
image_paths: List[str],
|
69
117
|
lmm: LMM,
|
70
118
|
exclude_tools: Optional[List[str]],
|
71
119
|
code_interpreter: CodeInterpreter,
|
120
|
+
process_code: Callable[[str], str] = lambda x: x,
|
72
121
|
) -> tuple[str, str, Execution]:
|
73
122
|
"""Helper function to generate and run tool testing code."""
|
74
123
|
query = lmm.generate(CATEGORIZE_TOOL_REQUEST.format(task=task))
|
@@ -80,7 +129,7 @@ def run_tool_testing(
|
|
80
129
|
f"I need models from the {category.strip()} category of tools. {task}"
|
81
130
|
)
|
82
131
|
|
83
|
-
tool_docs =
|
132
|
+
tool_docs = get_tool_recommender().top_k(category, k=10, thresh=0.2)
|
84
133
|
if exclude_tools is not None and len(exclude_tools) > 0:
|
85
134
|
cleaned_tool_docs = []
|
86
135
|
for tool_doc in tool_docs:
|
@@ -101,6 +150,7 @@ def run_tool_testing(
|
|
101
150
|
code = extract_tag(response, "code") # type: ignore
|
102
151
|
if code is None:
|
103
152
|
raise ValueError(f"Could not extract code from response: {response}")
|
153
|
+
code = process_code(code)
|
104
154
|
tool_output = code_interpreter.exec_isolation(DefaultImports.prepend_imports(code))
|
105
155
|
tool_output_str = tool_output.text(include_results=False).strip()
|
106
156
|
|
@@ -119,6 +169,7 @@ def run_tool_testing(
|
|
119
169
|
media=str(image_paths),
|
120
170
|
)
|
121
171
|
code = extract_code(lmm.generate(prompt, media=image_paths)) # type: ignore
|
172
|
+
code = process_code(code)
|
122
173
|
tool_output = code_interpreter.exec_isolation(
|
123
174
|
DefaultImports.prepend_imports(code)
|
124
175
|
)
|
@@ -200,7 +251,9 @@ def get_tool_for_task(
|
|
200
251
|
context=f"<code>\n{code}\n</code>\n<tool_output>\n{tool_output_str}\n</tool_output>",
|
201
252
|
previous_attempts=error_message,
|
202
253
|
)
|
203
|
-
tool_choice_context_dict = extract_json(
|
254
|
+
tool_choice_context_dict = extract_json(
|
255
|
+
lmm.generate(prompt, media=image_paths) # type: ignore
|
256
|
+
)
|
204
257
|
tool, tool_thoughts, tool_docstring, error_message = extract_tool_info(
|
205
258
|
tool_choice_context_dict
|
206
259
|
)
|
@@ -221,36 +274,7 @@ def get_tool_documentation(tool_name: str) -> str:
|
|
221
274
|
def get_tool_for_task_human_reviewer(
|
222
275
|
task: str, images: List[np.ndarray], exclude_tools: Optional[List[str]] = None
|
223
276
|
) -> None:
|
224
|
-
# NOTE: this
|
225
|
-
"""Given a task and one or more images this function will find a tool to accomplish
|
226
|
-
the jobs. It prints the tool documentation and thoughts on why it chose the tool.
|
227
|
-
|
228
|
-
It can produce tools for the following types of tasks:
|
229
|
-
- Object detection and counting
|
230
|
-
- Classification
|
231
|
-
- Segmentation
|
232
|
-
- OCR
|
233
|
-
- VQA
|
234
|
-
- Depth and pose estimation
|
235
|
-
- Video object tracking
|
236
|
-
|
237
|
-
Wait until the documentation is printed to use the function so you know what the
|
238
|
-
input and output signatures are.
|
239
|
-
|
240
|
-
Parameters:
|
241
|
-
task: str: The task to accomplish.
|
242
|
-
images: List[np.ndarray]: The images to use for the task.
|
243
|
-
exclude_tools: Optional[List[str]]: A list of tool names to exclude from the
|
244
|
-
recommendations. This is helpful if you are calling get_tool_for_task twice
|
245
|
-
and do not want the same tool recommended.
|
246
|
-
|
247
|
-
Returns:
|
248
|
-
The tool to use for the task is printed to stdout
|
249
|
-
|
250
|
-
Examples
|
251
|
-
--------
|
252
|
-
>>> get_tool_for_task("Give me an OCR model that can find 'hot chocolate' in the image", [image])
|
253
|
-
"""
|
277
|
+
# NOTE: this will have the same documentation as get_tool_for_task
|
254
278
|
lmm = AnthropicLMM()
|
255
279
|
|
256
280
|
with (
|
@@ -263,8 +287,19 @@ def get_tool_for_task_human_reviewer(
|
|
263
287
|
Image.fromarray(image).save(image_path)
|
264
288
|
image_paths.append(image_path)
|
265
289
|
|
290
|
+
tools = [
|
291
|
+
t.__name__
|
292
|
+
for t in T.TOOLS
|
293
|
+
if inspect.signature(t).parameters.get("box_threshold") # type: ignore
|
294
|
+
]
|
295
|
+
|
266
296
|
_, _, tool_output = run_tool_testing(
|
267
|
-
task,
|
297
|
+
task,
|
298
|
+
image_paths,
|
299
|
+
lmm,
|
300
|
+
exclude_tools,
|
301
|
+
code_interpreter,
|
302
|
+
process_code=lambda x: replace_box_threshold(x, tools, 0.05),
|
268
303
|
)
|
269
304
|
|
270
305
|
# need to re-display results for the outer notebook to see them
|
@@ -14,7 +14,7 @@ vision_agent/agent/vision_agent_coder_v2.py,sha256=nXbMsCLpKxTEi075ZE932227tW-lE
|
|
14
14
|
vision_agent/agent/vision_agent_planner.py,sha256=KWMA7XemcSmc_jn-MwdWz9wnKDtj-sYQ9tINi70_OoU,18583
|
15
15
|
vision_agent/agent/vision_agent_planner_prompts.py,sha256=Y3jz9HRf8fz9NLUseN7cTgZqewP0RazxR7vw1sPhcn0,6691
|
16
16
|
vision_agent/agent/vision_agent_planner_prompts_v2.py,sha256=Tzon3h5iZdHJglesk8GVS-2myNf5-fhf7HUbkpZWHQk,33143
|
17
|
-
vision_agent/agent/vision_agent_planner_v2.py,sha256=
|
17
|
+
vision_agent/agent/vision_agent_planner_v2.py,sha256=pAtGWkY-9fFgxgO2ioebvMvASwbJ-8bAvzRNp8Z0Odc,20437
|
18
18
|
vision_agent/agent/vision_agent_prompts.py,sha256=NtGdCfzzilCRtscKALC9FK55d1h4CBpMnbhLzg0PYlc,13772
|
19
19
|
vision_agent/agent/vision_agent_prompts_v2.py,sha256=-vCWat-ARlCOOOeIDIFhg-kcwRRwjTXYEwsvvqPeaCs,1972
|
20
20
|
vision_agent/agent/vision_agent_v2.py,sha256=6gGVV3FlL4NLzHRpjMqMz-fEP6f_JhwwOjUKczZ3TPA,10231
|
@@ -28,7 +28,7 @@ vision_agent/lmm/lmm.py,sha256=x_nIyDNDZwq4-pfjnJTmcyyJZ2_B7TjkA5jZp88YVO8,17103
|
|
28
28
|
vision_agent/lmm/types.py,sha256=ZEXR_ptBL0ZwDMTDYkgxUCmSZFmBYPQd2jreNzr_8UY,221
|
29
29
|
vision_agent/tools/__init__.py,sha256=xuNt5e4syQH28Vr6EdjLmO9ni9i00yav9yqcPMUx1oo,2878
|
30
30
|
vision_agent/tools/meta_tools.py,sha256=TPeS7QWnc_PmmU_ndiDT03dXbQ5yDSP33E7U8cSj7Ls,28660
|
31
|
-
vision_agent/tools/planner_tools.py,sha256=
|
31
|
+
vision_agent/tools/planner_tools.py,sha256=zlzyCv7tzSOs9W-MjsptaOeM-i4eoA6HxXQWuMc1KkY,13548
|
32
32
|
vision_agent/tools/prompts.py,sha256=V1z4YJLXZuUl_iZ5rY0M5hHc_2tmMEUKr0WocXKGt4E,1430
|
33
33
|
vision_agent/tools/tool_utils.py,sha256=AT7rMcpKwZgIErfgfSvHS0gmtvd8KMHJoHnu5aMlgO0,10259
|
34
34
|
vision_agent/tools/tools.py,sha256=vavzmDuIBHI-g13RMDnr9NALfWpiIvJWkXhD0pnhCuk,87576
|
@@ -40,7 +40,7 @@ vision_agent/utils/image_utils.py,sha256=rRWcxKggPXIRXIY_XT9rZt30ECDRq8zq7FDeXRD
|
|
40
40
|
vision_agent/utils/sim.py,sha256=NZc9QGD6BTY5O29NVbHH7oxDePL_QMnylT1lYcDUn1Y,7437
|
41
41
|
vision_agent/utils/type_defs.py,sha256=BE12s3JNQy36QvauXHjwyeffVh5enfcvd4vTzSwvEZI,1384
|
42
42
|
vision_agent/utils/video.py,sha256=tRcGp4vEnaDycigL1hBO9k0FBPtDH35fCQciVr9GqYI,6013
|
43
|
-
vision_agent-0.2.
|
44
|
-
vision_agent-0.2.
|
45
|
-
vision_agent-0.2.
|
46
|
-
vision_agent-0.2.
|
43
|
+
vision_agent-0.2.206.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
44
|
+
vision_agent-0.2.206.dist-info/METADATA,sha256=3QLRuQR4YwcTTU1y6phpkl7hLXtCIKqxYlYjF1_oNzM,19026
|
45
|
+
vision_agent-0.2.206.dist-info/WHEEL,sha256=7Z8_27uaHI_UZAc4Uox4PpBhQ9Y5_modZXWMxtUi4NU,88
|
46
|
+
vision_agent-0.2.206.dist-info/RECORD,,
|
File without changes
|
File without changes
|