vision-agent 0.2.236__py3-none-any.whl → 0.2.237__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vision_agent/.sim_tools/df.csv +57 -80
- vision_agent/.sim_tools/embs.npy +0 -0
- vision_agent/agent/agent.py +2 -2
- vision_agent/agent/vision_agent.py +3 -2
- vision_agent/agent/vision_agent_coder.py +13 -19
- vision_agent/agent/vision_agent_coder_v2.py +17 -17
- vision_agent/agent/vision_agent_planner.py +16 -21
- vision_agent/agent/vision_agent_planner_prompts_v2.py +19 -20
- vision_agent/agent/vision_agent_planner_v2.py +29 -15
- vision_agent/agent/vision_agent_v2.py +12 -12
- vision_agent/clients/landing_public_api.py +1 -1
- vision_agent/configs/config.py +17 -3
- vision_agent/lmm/__init__.py +0 -1
- vision_agent/lmm/lmm.py +4 -3
- vision_agent/models/__init__.py +11 -0
- vision_agent/{lmm/types.py → models/lmm_types.py} +4 -1
- vision_agent/sim/__init__.py +8 -0
- vision_agent/{utils → sim}/sim.py +3 -3
- vision_agent/tools/__init__.py +10 -23
- vision_agent/tools/meta_tools.py +4 -5
- vision_agent/tools/planner_tools.py +127 -37
- vision_agent/tools/tools.py +388 -302
- vision_agent/utils/__init__.py +0 -1
- vision_agent/{agent/agent_utils.py → utils/agent.py} +11 -2
- vision_agent/utils/image_utils.py +18 -7
- vision_agent/{tools/tool_utils.py → utils/tools.py} +1 -93
- vision_agent/utils/tools_doc.py +87 -0
- vision_agent/utils/video.py +15 -0
- vision_agent/utils/video_tracking.py +38 -5
- {vision_agent-0.2.236.dist-info → vision_agent-0.2.237.dist-info}/METADATA +2 -2
- vision_agent-0.2.237.dist-info/RECORD +55 -0
- vision_agent-0.2.236.dist-info/RECORD +0 -52
- /vision_agent/{agent/types.py → models/agent_types.py} +0 -0
- /vision_agent/{tools → models}/tools_types.py +0 -0
- {vision_agent-0.2.236.dist-info → vision_agent-0.2.237.dist-info}/LICENSE +0 -0
- {vision_agent-0.2.236.dist-info → vision_agent-0.2.237.dist-info}/WHEEL +0 -0
vision_agent/agent/vision_agent_v2.py CHANGED
@@ -4,23 +4,23 @@ from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast

 from vision_agent.agent import Agent, AgentCoder, VisionAgentCoderV2
-from vision_agent.agent.agent_utils import (
-    add_media_to_chat,
-    convert_message_to_agentmessage,
-    extract_tag,
-    format_conversation,
-)
-from vision_agent.agent.types import (
+from vision_agent.agent.vision_agent_coder_v2 import format_code_context
+from vision_agent.agent.vision_agent_prompts_v2 import CONVERSATION
+from vision_agent.configs import Config
+from vision_agent.lmm import LMM
+from vision_agent.models import (
     AgentMessage,
     CodeContext,
     InteractionContext,
+    Message,
     PlanContext,
 )
-from vision_agent.agent.vision_agent_coder_v2 import format_code_context
-from vision_agent.agent.vision_agent_prompts_v2 import CONVERSATION
-from vision_agent.configs import Config
-from vision_agent.lmm import LMM
-from vision_agent.lmm.types import Message
+from vision_agent.utils.agent import (
+    add_media_to_chat,
+    convert_message_to_agentmessage,
+    extract_tag,
+    format_conversation,
+)
 from vision_agent.utils.execute import CodeInterpreter, CodeInterpreterFactory

 CONFIG = Config()

vision_agent/clients/landing_public_api.py CHANGED
@@ -5,7 +5,7 @@ from uuid import UUID
 from requests.exceptions import HTTPError

 from vision_agent.clients.http import BaseHTTP
-from vision_agent.tools.tools_types import BboxInputBase64, JobStatus, PromptTask
+from vision_agent.models import BboxInputBase64, JobStatus, PromptTask
 from vision_agent.utils.exceptions import FineTuneModelNotFound
 from vision_agent.utils.type_defs import LandingaiAPIKey

vision_agent/configs/config.py CHANGED
@@ -96,13 +96,24 @@ class Config(BaseModel):
         }
     )

+    # for get_tool_for_task
+    od_judge: Type[LMM] = Field(default=AnthropicLMM)
+    od_judge_kwargs: dict = Field(
+        default_factory=lambda: {
+            "model_name": "claude-3-5-sonnet-20241022",
+            "temperature": 0.0,
+            "image_size": 512,
+        }
+    )
+
     # for suggestions module
-    suggester: Type[LMM] = Field(default=
+    suggester: Type[LMM] = Field(default=OpenAILMM)
     suggester_kwargs: dict = Field(
         default_factory=lambda: {
-            "model_name": "
+            "model_name": "o1",
             "temperature": 1.0,
-            "
+            "image_detail": "high",
+            "image_size": 1024,
         }
     )

@@ -143,6 +154,9 @@ class Config(BaseModel):
     def create_tool_chooser(self) -> LMM:
         return self.tool_chooser(**self.tool_chooser_kwargs)

+    def create_od_judge(self) -> LMM:
+        return self.od_judge(**self.od_judge_kwargs)
+
     def create_suggester(self) -> LMM:
         return self.suggester(**self.suggester_kwargs)

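Usage sketch (not part of the diff; assumes the default config above and an Anthropic API key in the environment):

    from vision_agent.configs import Config

    CONFIG = Config()
    # Instantiates AnthropicLMM with the od_judge_kwargs defaults:
    # model_name="claude-3-5-sonnet-20241022", temperature=0.0, image_size=512
    od_judge = CONFIG.create_od_judge()
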
vision_agent/lmm/__init__.py CHANGED

vision_agent/lmm/lmm.py CHANGED
@@ -9,10 +9,9 @@ import requests
 from anthropic.types import ImageBlockParam, MessageParam, TextBlockParam
 from openai import AzureOpenAI, OpenAI

+from vision_agent.models import Message
 from vision_agent.utils.image_utils import encode_media

-from .types import Message
-

 class LMM(ABC):
     @abstractmethod

@@ -64,7 +63,9 @@ class OpenAILMM(LMM):
         self.image_size = image_size
         self.image_detail = image_detail
         # o1 does not use max_tokens
-        if "max_tokens" not in kwargs and not model_name.startswith("o1"):
+        if "max_tokens" not in kwargs and not (
+            model_name.startswith("o1") or model_name.startswith("o3")
+        ):
             kwargs["max_tokens"] = max_tokens
         if json_mode:
             kwargs["response_format"] = {"type": "json_object"}

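Effect of the widened guard, as a sketch (assumes OPENAI_API_KEY is set; the model names are illustrative):

    from vision_agent.lmm import OpenAILMM

    chat = OpenAILMM(model_name="gpt-4o")  # max_tokens is injected into kwargs
    o3 = OpenAILMM(model_name="o3-mini")   # max_tokens is omitted; the o1/o3
                                           # reasoning models do not accept it
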
vision_agent/models/__init__.py ADDED
@@ -0,0 +1,11 @@
+from .agent_types import AgentMessage, CodeContext, InteractionContext, PlanContext
+from .lmm_types import Message, TextOrImage
+from .tools_types import (
+    BboxInput,
+    BboxInputBase64,
+    BoundingBoxes,
+    Florence2FtRequest,
+    JobStatus,
+    ODResponseData,
+    PromptTask,
+)
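The new package gives the shared types a single import root, e.g. (sketch):

    # 0.2.236: from vision_agent.lmm.types import Message
    #          from vision_agent.tools.tools_types import PromptTask
    # 0.2.237:
    from vision_agent.models import AgentMessage, Message, PromptTask
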
vision_agent/{lmm/types.py → models/lmm_types.py} RENAMED
@@ -1,7 +1,10 @@
 from pathlib import Path
 from typing import Dict, Sequence, Union

+import numpy as np
+from PIL.Image import Image as ImageType
+
 from vision_agent.utils.execute import Execution

-TextOrImage = Union[str, Sequence[Union[str, Path]]]
+TextOrImage = Union[str, Sequence[Union[str, Path, ImageType, np.ndarray]]]
 Message = Dict[str, Union[TextOrImage, Execution]]
vision_agent/{utils → sim}/sim.py RENAMED
@@ -12,17 +12,17 @@ import requests
 from openai import AzureOpenAI, OpenAI
 from scipy.spatial.distance import cosine  # type: ignore

-from vision_agent.tools.tool_utils import (
+from vision_agent.tools.tools import get_tools_df
+from vision_agent.utils.tools import (
     _LND_API_KEY,
     _create_requests_session,
     _LND_API_URL_v2,
 )
-from vision_agent.tools.tools import TOOLS_DF


 @lru_cache(maxsize=1)
 def get_tool_recommender() -> "Sim":
-    return load_cached_sim(TOOLS_DF)
+    return load_cached_sim(get_tools_df())


 @lru_cache(maxsize=512)
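Replacing the module-level TOOLS_DF import with a call to get_tools_df() makes the tool dataframe lazy, while lru_cache(maxsize=1) keeps the recommender a process-wide singleton. A lookup sketch (the k/thresh values mirror retrieve_tool_docs further down in this diff):

    from vision_agent.sim import get_tool_recommender

    sim = get_tool_recommender()  # built on first call, cached afterwards
    docs = sim.top_k("object detection", k=3, thresh=0.3)
    print([d["name"] for d in docs])
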
vision_agent/tools/__init__.py CHANGED
@@ -12,17 +12,10 @@ from .meta_tools import (
     use_object_detection_fine_tuning,
     view_media_artifact,
 )
+from .planner_tools import judge_od_results
 from .prompts import CHOOSE_PARAMS, SYSTEM_PROMPT
-from .tool_utils import add_bboxes_from_masks, get_tool_descriptions_by_names
 from .tools import (
-
-    TOOL_DESCRIPTIONS,
-    TOOL_DOCSTRING,
-    TOOLS,
-    TOOLS_DF,
-    TOOLS_INFO,
-    UTIL_TOOLS,
-    UTILITIES_DOCSTRING,
+    activity_recognition,
     agentic_object_detection,
     agentic_sam2_instance_segmentation,
     agentic_sam2_video_tracking,

@@ -45,7 +38,11 @@ from .tools import (
     florence2_sam2_video_tracking,
     flux_image_inpainting,
     generate_pose_image,
-
+    get_tools,
+    get_tools_descriptions,
+    get_tools_df,
+    get_tools_docstring,
+    get_utilties_docstring,
     load_image,
     minimum_distance,
     ocr,

@@ -64,7 +61,6 @@ from .tools import (
     save_video,
     siglip_classification,
     template_match,
-    video_temporal_localization,
     vit_image_classification,
     vit_nsfw_classification,
 )

@@ -79,20 +75,11 @@ def register_tool(imports: Optional[List] = None) -> Callable:
     def decorator(tool: Callable) -> Callable:
         import inspect

-        from .tools import (  # noqa: F811
-            get_tool_descriptions,
-            get_tools_df,
-            get_tools_info,
-        )
-
         global TOOLS, TOOLS_DF, TOOL_DESCRIPTIONS, TOOL_DOCSTRING, TOOLS_INFO
+        from vision_agent.tools.tools import TOOLS

-        if tool not in TOOLS:
-            TOOLS.append(tool)
-            TOOLS_DF = get_tools_df(TOOLS)  # type: ignore
-            TOOL_DESCRIPTIONS = get_tool_descriptions(TOOLS)  # type: ignore
-            TOOL_DOCSTRING = get_tool_documentation(TOOLS)  # type: ignore
-            TOOLS_INFO = get_tools_info(TOOLS)  # type: ignore
+        if tool not in TOOLS:  # type: ignore
+            TOOLS.append(tool)  # type: ignore

         globals()[tool.__name__] = tool
         if imports is not None:

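After this rewrite, register_tool only appends to the canonical TOOLS list; the derived dataframes and docstrings are recomputed on demand by the new get_tools_* helpers instead of being patched as module globals. Registering a hypothetical custom tool:

    import numpy as np

    import vision_agent.tools as T

    @T.register_tool(imports=["import numpy as np"])
    def mean_brightness(image: np.ndarray) -> float:
        """'mean_brightness' returns the average pixel intensity of an image."""
        return float(image.mean())
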
vision_agent/tools/meta_tools.py CHANGED
@@ -12,12 +12,11 @@ from IPython.display import display

 import vision_agent as va
 from vision_agent.clients.landing_public_api import LandingPublicAPI
-from vision_agent.lmm.types import Message
-from vision_agent.tools.tool_utils import get_tool_documentation
-from vision_agent.tools.tools import TOOL_DESCRIPTIONS
-from vision_agent.tools.tools_types import BboxInput, BboxInputBase64, PromptTask
+from vision_agent.models import BboxInput, BboxInputBase64, Message, PromptTask
+from vision_agent.tools.tools import get_tools_descriptions as _get_tool_descriptions
 from vision_agent.utils.execute import Execution, MimeType
 from vision_agent.utils.image_utils import convert_to_b64
+from vision_agent.utils.tools_doc import get_tool_documentation

 CURRENT_FILE = None
 CURRENT_LINE = 0

@@ -571,7 +570,7 @@ def get_tool_descriptions() -> str:
     """Returns a description of all the tools that `generate_vision_code` has access to.
     Helpful for answering questions about what types of vision tasks you can do with
     `generate_vision_code`."""
-    return TOOL_DESCRIPTIONS
+    return _get_tool_descriptions()


 def object_detection_fine_tuning(bboxes: List[Dict[str, Any]]) -> str:
vision_agent/tools/planner_tools.py CHANGED
@@ -1,5 +1,7 @@
 import inspect
 import logging
+import math
+import random
 import tempfile
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast

@@ -10,7 +12,6 @@ from IPython.display import display
 from PIL import Image

 import vision_agent.tools as T
-from vision_agent.agent.agent_utils import DefaultImports, extract_json, extract_tag
 from vision_agent.agent.vision_agent_planner_prompts_v2 import (
     CATEGORIZE_TOOL_REQUEST,
     FINALIZE_PLAN,

@@ -21,6 +22,9 @@ from vision_agent.agent.vision_agent_planner_prompts_v2 import (
 )
 from vision_agent.configs import Config
 from vision_agent.lmm import LMM, AnthropicLMM
+from vision_agent.sim import get_tool_recommender
+from vision_agent.tools.tools import get_tools, get_tools_info
+from vision_agent.utils.agent import DefaultImports, extract_json, extract_tag
 from vision_agent.utils.execute import (
     CodeInterpreter,
     CodeInterpreterFactory,

@@ -28,12 +32,16 @@ from vision_agent.utils.execute import (
     MimeType,
 )
 from vision_agent.utils.image_utils import convert_to_b64
-from vision_agent.utils.sim import get_tool_recommender
+from vision_agent.utils.tools_doc import get_tool_documentation
+
+
+def get_tool_functions() -> Dict[str, Callable]:
+    return {tool.__name__: tool for tool in get_tools()}
+
+
+def get_load_tools_docstring() -> str:
+    return get_tool_documentation([T.load_image, T.extract_frames_and_timestamps])

-TOOL_FUNCTIONS = {tool.__name__: tool for tool in T.TOOLS}
-LOAD_TOOLS_DOCSTRING = T.get_tool_documentation(
-    [T.load_image, T.extract_frames_and_timestamps]
-)

 CONFIG = Config()
 _LOGGER = logging.getLogger(__name__)
@@ -50,6 +58,59 @@ def format_tool_output(tool_thoughts: str, tool_docstring: str) -> str:
     return return_str


+def judge_od_results(
+    prompt: str,
+    image: np.ndarray,
+    detections: List[Dict[str, Any]],
+) -> str:
+    """Given an image and the detections, this function will judge the results and
+    return the thoughts on the results.
+
+    Parameters:
+        prompt (str): The prompt that was used to generate the detections.
+        image (np.ndarray): The image that the detections were made on.
+        detections (List[Dict[str, Any]]): The detections made on the image.
+
+    Returns:
+        str: The thoughts on the results.
+    """
+
+    if not detections:
+        return "No detections found in the image."
+
+    od_judge = CONFIG.create_od_judge()
+    max_crop_size = (512, 512)
+
+    # Randomly sample up to 10 detections
+    num_samples = min(10, len(detections))
+    sampled_detections = random.sample(detections, num_samples)
+    crops = []
+    h, w = image.shape[:2]
+
+    for detection in sampled_detections:
+        if "bbox" not in detection:
+            continue
+        x1, y1, x2, y2 = detection["bbox"]
+        crop = image[int(y1 * h) : int(y2 * h), int(x1 * w) : int(x2 * w)]
+        if crop.shape[0] > max_crop_size[0] or crop.shape[1] > max_crop_size[1]:
+            crop = Image.fromarray(crop)  # type: ignore
+            crop.thumbnail(max_crop_size)  # type: ignore
+            crop = np.array(crop)
+        crops.append("data:image/png;base64," + convert_to_b64(crop))
+
+    sampled_detection_info = [
+        {"score": d["score"], "label": d["label"]} for d in sampled_detections
+    ]
+
+    prompt = f"""The user is trying to detect '{prompt}' in an image. You are shown 10 images which represent crops of the detected objects. Below are the detection labels and scores:
+{sampled_detection_info}
+
+Look over each of the cropped images and corresponding labels and scores. Provide a judgement on whether or not the results are correct. If the results are incorrect you can only suggest a different prompt or a threshold."""
+
+    response = cast(str, od_judge.generate(prompt, media=crops))
+    return response
+
+
 def run_multi_judge(
     tool_chooser: LMM,
     tool_docs_str: str,
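Calling sketch (hypothetical values; the crop arithmetic above implies bbox is a normalized [x1, y1, x2, y2], and each detection needs score and label keys):

    import numpy as np

    from vision_agent.tools import judge_od_results

    image = np.zeros((480, 640, 3), dtype=np.uint8)
    detections = [
        {"label": "person", "score": 0.91, "bbox": [0.10, 0.20, 0.45, 0.85]},
    ]
    print(judge_od_results("person", image, detections))  # requires an Anthropic key
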
@@ -57,6 +118,7 @@ def run_multi_judge(
     code: str,
     tool_output_str: str,
     image_paths: List[str],
+    n_judges: int = 3,
 ) -> Tuple[Optional[Callable], str, str]:
     error_message = ""
     prompt = PICK_TOOL.format(

@@ -77,7 +139,7 @@ def run_multi_judge(

     responses = []
     with ThreadPoolExecutor() as executor:
-        futures = [executor.submit(run_judge) for _ in range(3)]
+        futures = [executor.submit(run_judge) for _ in range(n_judges)]
         for future in as_completed(futures):
             responses.append(future.result())

@@ -86,7 +148,7 @@ def run_multi_judge(
     for tool, tool_thoughts, tool_docstring in responses:
         if tool is not None:
             counts[tool.__name__] = counts.get(tool.__name__, 0) + 1
-            if counts[tool.__name__] >= 2:
+            if counts[tool.__name__] >= math.ceil(n_judges / 2):
                 return tool, tool_thoughts, tool_docstring

     if len(responses) == 0:
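With the default n_judges=3, math.ceil(3 / 2) == 2, so the vote threshold matches the previously hard-coded >= 2 while generalizing to other judge counts (five judges would require three agreeing votes).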
@@ -104,9 +166,12 @@ def extract_tool_info(
     tool_thoughts = tool_choice_context.get("thoughts", "")
     tool_docstring = ""
     tool = tool_choice_context.get("best_tool", None)
-
-
-
+    tools_info = get_tools_info()
+
+    tool_functions = get_tool_functions()
+    if tool in tool_functions:
+        tool = tool_functions[tool]
+        tool_docstring = tools_info[tool.__name__]

     return tool, tool_thoughts, tool_docstring, ""

@@ -153,6 +218,42 @@ def replace_box_threshold(code: str, functions: List[str], box_threshold: float)
     return new_tree.code


+def retrieve_tool_docs(lmm: LMM, task: str, exclude_tools: Optional[List[str]]) -> str:
+    query = cast(str, lmm.generate(CATEGORIZE_TOOL_REQUEST.format(task=task)))
+    categories_str = extract_tag(query, "category")
+    if categories_str is None:
+        categories = []
+    else:
+        categories = [e.strip() for e in categories_str.split(",")]
+
+    explanation = query.split("<category>")[0].strip()
+    if "</category>" in query:
+        explanation += " " + query.split("</category>")[1].strip()
+    explanation = explanation.strip()
+
+    sim = get_tool_recommender()
+
+    all_tool_docs = []
+    all_tool_doc_names = set()
+    exclude_tools = [] if exclude_tools is None else exclude_tools
+    for category in categories:
+        tool_docs = sim.top_k(category, k=3, thresh=0.3)
+
+        for tool_doc in tool_docs:
+            if (
+                tool_doc["name"] not in all_tool_doc_names
+                and tool_doc["name"] not in exclude_tools
+            ):
+                all_tool_docs.append(tool_doc)
+                all_tool_doc_names.add(tool_doc["name"])
+
+    tool_docs_str = explanation + "\n\n" + "\n".join([e["doc"] for e in all_tool_docs])
+    tool_docs_str += (
+        "\n" + get_load_tools_docstring() + get_tool_documentation([judge_od_results])
+    )
+    return tool_docs_str
+
+
 def run_tool_testing(
     task: str,
     image_paths: List[str],
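The parsing above expects the LMM to wrap a comma-separated category list in a <category> tag, with any explanation outside the tag. A self-contained sketch of that format (hypothetical response text):

    query = (
        "The task involves finding and counting people.\n"
        "<category>object detection, counting</category>"
    )
    categories_str = query.split("<category>")[1].split("</category>")[0]
    categories = [c.strip() for c in categories_str.split(",")]
    assert categories == ["object detection", "counting"]
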
@@ -162,22 +263,8 @@
     process_code: Callable[[str], str] = lambda x: x,
 ) -> tuple[str, str, Execution]:
     """Helper function to generate and run tool testing code."""
-    query = lmm.generate(CATEGORIZE_TOOL_REQUEST.format(task=task))
-    category = extract_tag(query, "category")  # type: ignore
-    if category is None:
-        query = task
-    else:
-        query = f"{category.strip()}. {task}"

-
-    if exclude_tools is not None and len(exclude_tools) > 0:
-        cleaned_tool_docs = []
-        for tool_doc in tool_docs:
-            if not tool_doc["name"] in exclude_tools:
-                cleaned_tool_docs.append(tool_doc)
-        tool_docs = cleaned_tool_docs
-    tool_docs_str = "\n".join([e["doc"] for e in tool_docs])
-    tool_docs_str += "\n" + LOAD_TOOLS_DOCSTRING
+    tool_docs_str = retrieve_tool_docs(lmm, task, exclude_tools)

     prompt = TEST_TOOLS.format(
         tool_docs=tool_docs_str,
@@ -295,24 +382,26 @@ def get_tool_for_task(
         Image.fromarray(image).save(image_path)
         image_paths.append(image_path)

+    # run no more than 3 images or else it overloads the LLM
+    image_paths = image_paths[:3]
     code, tool_docs_str, tool_output = run_tool_testing(
         task, image_paths, tool_tester, exclude_tools, code_interpreter
     )
     tool_output_str = tool_output.text(include_results=False).strip()

     _, tool_thoughts, tool_docstring = run_multi_judge(
-        tool_chooser, tool_docs_str, task, code, tool_output_str, image_paths
+        tool_chooser,
+        tool_docs_str,
+        task,
+        code,
+        tool_output_str,
+        image_paths,
+        n_judges=3,
     )

     print(format_tool_output(tool_thoughts, tool_docstring))


-def get_tool_documentation(tool_name: str) -> str:
-    # use same format as get_tool_for_task
-    tool_doc = T.TOOLS_DF[T.TOOLS_DF["name"] == tool_name]["doc"].values[0]
-    return format_tool_output("", tool_doc)
-
-
 def get_tool_for_task_human_reviewer(
     task: str,
     images: Union[Dict[str, List[np.ndarray]], List[np.ndarray]],
@@ -337,8 +426,8 @@ def get_tool_for_task_human_reviewer(

     tools = [
         t.__name__
-        for t in T.TOOLS
-        if inspect.signature(t).parameters.get("box_threshold")
+        for t in get_tools()
+        if inspect.signature(t).parameters.get("box_threshold")
     ]

     _, _, tool_output = run_tool_testing(
@@ -414,7 +503,8 @@ def suggestion(prompt: str, medias: List[np.ndarray]) -> None:
     a problem.

     Parameters:
-        prompt: str: The problem statement
+        prompt: str: The problem statement, provide a detailed description of the
+            problem you are trying to solve.
         medias: List[np.ndarray]: The images to use for the problem
     """
     try:
@@ -431,4 +521,4 @@ PLANNER_TOOLS = [
     suggestion,
     get_tool_for_task,
 ]
-PLANNER_DOCSTRING =
+PLANNER_DOCSTRING = get_tool_documentation(PLANNER_TOOLS)  # type: ignore