vision-agent 0.2.203__py3-none-any.whl → 0.2.205__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vision_agent/agent/vision_agent_planner_v2.py +3 -1
- vision_agent/agent/vision_agent_prompts.py +1 -1
- vision_agent/tools/planner_tools.py +59 -31
- {vision_agent-0.2.203.dist-info → vision_agent-0.2.205.dist-info}/METADATA +1 -1
- {vision_agent-0.2.203.dist-info → vision_agent-0.2.205.dist-info}/RECORD +7 -7
- {vision_agent-0.2.203.dist-info → vision_agent-0.2.205.dist-info}/LICENSE +0 -0
- {vision_agent-0.2.203.dist-info → vision_agent-0.2.205.dist-info}/WHEEL +0 -0
@@ -367,8 +367,10 @@ def replace_interaction_with_obs(chat: List[AgentMessage]) -> List[AgentMessage]
|
|
367
367
|
response = json.loads(chat[i + 1].content)
|
368
368
|
function_name = response["function_name"]
|
369
369
|
tool_doc = get_tool_documentation(function_name)
|
370
|
+
if "box_threshold" in response:
|
371
|
+
tool_doc = f"Use the following function with box_threshold={response['box_threshold']}\n\n{tool_doc}"
|
370
372
|
new_chat.append(AgentMessage(role="observation", content=tool_doc))
|
371
|
-
except json.JSONDecodeError:
|
373
|
+
except (json.JSONDecodeError, KeyError):
|
372
374
|
raise ValueError(f"Invalid JSON in interaction response: {chat_i}")
|
373
375
|
else:
|
374
376
|
new_chat.append(chat_i)
|
@@ -1,3 +1,4 @@
|
|
1
|
+
import inspect
|
1
2
|
import logging
|
2
3
|
import shutil
|
3
4
|
import tempfile
|
@@ -63,12 +64,55 @@ def extract_tool_info(
|
|
63
64
|
return tool, tool_thoughts, tool_docstring, ""
|
64
65
|
|
65
66
|
|
67
|
+
def replace_box_threshold(code: str, functions: List[str], box_threshold: float) -> str:
|
68
|
+
class ReplaceBoxThresholdTransformer(cst.CSTTransformer):
|
69
|
+
def leave_Call(
|
70
|
+
self, original_node: cst.Call, updated_node: cst.Call
|
71
|
+
) -> cst.Call:
|
72
|
+
if (
|
73
|
+
isinstance(updated_node.func, cst.Name)
|
74
|
+
and updated_node.func.value in functions
|
75
|
+
) or (
|
76
|
+
isinstance(updated_node.func, cst.Attribute)
|
77
|
+
and updated_node.func.attr.value in functions
|
78
|
+
):
|
79
|
+
new_args = []
|
80
|
+
found = False
|
81
|
+
for arg in updated_node.args:
|
82
|
+
if arg.keyword and arg.keyword.value == "box_threshold":
|
83
|
+
new_arg = arg.with_changes(value=cst.Float(str(box_threshold)))
|
84
|
+
new_args.append(new_arg)
|
85
|
+
found = True
|
86
|
+
else:
|
87
|
+
new_args.append(arg)
|
88
|
+
|
89
|
+
if not found:
|
90
|
+
new_args.append(
|
91
|
+
cst.Arg(
|
92
|
+
keyword=cst.Name("box_threshold"),
|
93
|
+
value=cst.Float(str(box_threshold)),
|
94
|
+
equal=cst.AssignEqual(
|
95
|
+
whitespace_before=cst.SimpleWhitespace(""),
|
96
|
+
whitespace_after=cst.SimpleWhitespace(""),
|
97
|
+
),
|
98
|
+
)
|
99
|
+
)
|
100
|
+
return updated_node.with_changes(args=new_args)
|
101
|
+
return updated_node
|
102
|
+
|
103
|
+
tree = cst.parse_module(code)
|
104
|
+
transformer = ReplaceBoxThresholdTransformer()
|
105
|
+
new_tree = tree.visit(transformer)
|
106
|
+
return new_tree.code
|
107
|
+
|
108
|
+
|
66
109
|
def run_tool_testing(
|
67
110
|
task: str,
|
68
111
|
image_paths: List[str],
|
69
112
|
lmm: LMM,
|
70
113
|
exclude_tools: Optional[List[str]],
|
71
114
|
code_interpreter: CodeInterpreter,
|
115
|
+
process_code: Callable[[str], str] = lambda x: x,
|
72
116
|
) -> tuple[str, str, Execution]:
|
73
117
|
"""Helper function to generate and run tool testing code."""
|
74
118
|
query = lmm.generate(CATEGORIZE_TOOL_REQUEST.format(task=task))
|
@@ -101,6 +145,7 @@ def run_tool_testing(
|
|
101
145
|
code = extract_tag(response, "code") # type: ignore
|
102
146
|
if code is None:
|
103
147
|
raise ValueError(f"Could not extract code from response: {response}")
|
148
|
+
code = process_code(code)
|
104
149
|
tool_output = code_interpreter.exec_isolation(DefaultImports.prepend_imports(code))
|
105
150
|
tool_output_str = tool_output.text(include_results=False).strip()
|
106
151
|
|
@@ -119,6 +164,7 @@ def run_tool_testing(
|
|
119
164
|
media=str(image_paths),
|
120
165
|
)
|
121
166
|
code = extract_code(lmm.generate(prompt, media=image_paths)) # type: ignore
|
167
|
+
code = process_code(code)
|
122
168
|
tool_output = code_interpreter.exec_isolation(
|
123
169
|
DefaultImports.prepend_imports(code)
|
124
170
|
)
|
@@ -221,36 +267,7 @@ def get_tool_documentation(tool_name: str) -> str:
|
|
221
267
|
def get_tool_for_task_human_reviewer(
|
222
268
|
task: str, images: List[np.ndarray], exclude_tools: Optional[List[str]] = None
|
223
269
|
) -> None:
|
224
|
-
# NOTE: this
|
225
|
-
"""Given a task and one or more images this function will find a tool to accomplish
|
226
|
-
the jobs. It prints the tool documentation and thoughts on why it chose the tool.
|
227
|
-
|
228
|
-
It can produce tools for the following types of tasks:
|
229
|
-
- Object detection and counting
|
230
|
-
- Classification
|
231
|
-
- Segmentation
|
232
|
-
- OCR
|
233
|
-
- VQA
|
234
|
-
- Depth and pose estimation
|
235
|
-
- Video object tracking
|
236
|
-
|
237
|
-
Wait until the documentation is printed to use the function so you know what the
|
238
|
-
input and output signatures are.
|
239
|
-
|
240
|
-
Parameters:
|
241
|
-
task: str: The task to accomplish.
|
242
|
-
images: List[np.ndarray]: The images to use for the task.
|
243
|
-
exclude_tools: Optional[List[str]]: A list of tool names to exclude from the
|
244
|
-
recommendations. This is helpful if you are calling get_tool_for_task twice
|
245
|
-
and do not want the same tool recommended.
|
246
|
-
|
247
|
-
Returns:
|
248
|
-
The tool to use for the task is printed to stdout
|
249
|
-
|
250
|
-
Examples
|
251
|
-
--------
|
252
|
-
>>> get_tool_for_task("Give me an OCR model that can find 'hot chocolate' in the image", [image])
|
253
|
-
"""
|
270
|
+
# NOTE: this will have the same documentation as get_tool_for_task
|
254
271
|
lmm = AnthropicLMM()
|
255
272
|
|
256
273
|
with (
|
@@ -263,8 +280,19 @@ def get_tool_for_task_human_reviewer(
|
|
263
280
|
Image.fromarray(image).save(image_path)
|
264
281
|
image_paths.append(image_path)
|
265
282
|
|
283
|
+
tools = [
|
284
|
+
t.__name__
|
285
|
+
for t in T.TOOLS
|
286
|
+
if inspect.signature(t).parameters.get("box_threshold") # type: ignore
|
287
|
+
]
|
288
|
+
|
266
289
|
_, _, tool_output = run_tool_testing(
|
267
|
-
task, image_paths, lmm, exclude_tools, code_interpreter
|
290
|
+
task,
|
291
|
+
image_paths,
|
292
|
+
lmm,
|
293
|
+
exclude_tools,
|
294
|
+
code_interpreter,
|
295
|
+
process_code=lambda x: replace_box_threshold(x, tools, 0.05),
|
268
296
|
)
|
269
297
|
|
270
298
|
# need to re-display results for the outer notebook to see them
|
@@ -14,8 +14,8 @@ vision_agent/agent/vision_agent_coder_v2.py,sha256=nXbMsCLpKxTEi075ZE932227tW-lE
|
|
14
14
|
vision_agent/agent/vision_agent_planner.py,sha256=KWMA7XemcSmc_jn-MwdWz9wnKDtj-sYQ9tINi70_OoU,18583
|
15
15
|
vision_agent/agent/vision_agent_planner_prompts.py,sha256=Y3jz9HRf8fz9NLUseN7cTgZqewP0RazxR7vw1sPhcn0,6691
|
16
16
|
vision_agent/agent/vision_agent_planner_prompts_v2.py,sha256=Tzon3h5iZdHJglesk8GVS-2myNf5-fhf7HUbkpZWHQk,33143
|
17
|
-
vision_agent/agent/vision_agent_planner_v2.py,sha256=
|
18
|
-
vision_agent/agent/vision_agent_prompts.py,sha256=
|
17
|
+
vision_agent/agent/vision_agent_planner_v2.py,sha256=pAtGWkY-9fFgxgO2ioebvMvASwbJ-8bAvzRNp8Z0Odc,20437
|
18
|
+
vision_agent/agent/vision_agent_prompts.py,sha256=NtGdCfzzilCRtscKALC9FK55d1h4CBpMnbhLzg0PYlc,13772
|
19
19
|
vision_agent/agent/vision_agent_prompts_v2.py,sha256=-vCWat-ARlCOOOeIDIFhg-kcwRRwjTXYEwsvvqPeaCs,1972
|
20
20
|
vision_agent/agent/vision_agent_v2.py,sha256=6gGVV3FlL4NLzHRpjMqMz-fEP6f_JhwwOjUKczZ3TPA,10231
|
21
21
|
vision_agent/clients/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
@@ -28,7 +28,7 @@ vision_agent/lmm/lmm.py,sha256=x_nIyDNDZwq4-pfjnJTmcyyJZ2_B7TjkA5jZp88YVO8,17103
|
|
28
28
|
vision_agent/lmm/types.py,sha256=ZEXR_ptBL0ZwDMTDYkgxUCmSZFmBYPQd2jreNzr_8UY,221
|
29
29
|
vision_agent/tools/__init__.py,sha256=xuNt5e4syQH28Vr6EdjLmO9ni9i00yav9yqcPMUx1oo,2878
|
30
30
|
vision_agent/tools/meta_tools.py,sha256=TPeS7QWnc_PmmU_ndiDT03dXbQ5yDSP33E7U8cSj7Ls,28660
|
31
|
-
vision_agent/tools/planner_tools.py,sha256=
|
31
|
+
vision_agent/tools/planner_tools.py,sha256=MYYUN9WwEHkjFq_TF2rDVfOHOM0Ko460pxg970loojc,13423
|
32
32
|
vision_agent/tools/prompts.py,sha256=V1z4YJLXZuUl_iZ5rY0M5hHc_2tmMEUKr0WocXKGt4E,1430
|
33
33
|
vision_agent/tools/tool_utils.py,sha256=AT7rMcpKwZgIErfgfSvHS0gmtvd8KMHJoHnu5aMlgO0,10259
|
34
34
|
vision_agent/tools/tools.py,sha256=vavzmDuIBHI-g13RMDnr9NALfWpiIvJWkXhD0pnhCuk,87576
|
@@ -40,7 +40,7 @@ vision_agent/utils/image_utils.py,sha256=rRWcxKggPXIRXIY_XT9rZt30ECDRq8zq7FDeXRD
|
|
40
40
|
vision_agent/utils/sim.py,sha256=NZc9QGD6BTY5O29NVbHH7oxDePL_QMnylT1lYcDUn1Y,7437
|
41
41
|
vision_agent/utils/type_defs.py,sha256=BE12s3JNQy36QvauXHjwyeffVh5enfcvd4vTzSwvEZI,1384
|
42
42
|
vision_agent/utils/video.py,sha256=tRcGp4vEnaDycigL1hBO9k0FBPtDH35fCQciVr9GqYI,6013
|
43
|
-
vision_agent-0.2.203.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
44
|
-
vision_agent-0.2.
|
45
|
-
vision_agent-0.2.203.dist-info/WHEEL,sha256=7Z8_27uaHI_UZAc4Uox4PpBhQ9Y5_modZXWMxtUi4NU,88
|
46
|
-
vision_agent-0.2.203.dist-info/RECORD,,
|
43
|
+
vision_agent-0.2.205.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
44
|
+
vision_agent-0.2.205.dist-info/METADATA,sha256=BCcmFsPZJi6CHOTsNfAgqkHfz1oLowbZjdpQKAWvj94,19026
|
45
|
+
vision_agent-0.2.205.dist-info/WHEEL,sha256=7Z8_27uaHI_UZAc4Uox4PpBhQ9Y5_modZXWMxtUi4NU,88
|
46
|
+
vision_agent-0.2.205.dist-info/RECORD,,
|
File without changes
|
File without changes
|