vision-agent 0.2.236__py3-none-any.whl → 0.2.238__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vision_agent/.sim_tools/df.csv +57 -80
- vision_agent/.sim_tools/embs.npy +0 -0
- vision_agent/agent/agent.py +2 -2
- vision_agent/agent/vision_agent.py +3 -2
- vision_agent/agent/vision_agent_coder.py +13 -19
- vision_agent/agent/vision_agent_coder_v2.py +17 -17
- vision_agent/agent/vision_agent_planner.py +16 -21
- vision_agent/agent/vision_agent_planner_prompts_v2.py +19 -20
- vision_agent/agent/vision_agent_planner_v2.py +29 -15
- vision_agent/agent/vision_agent_v2.py +12 -12
- vision_agent/clients/landing_public_api.py +1 -1
- vision_agent/configs/anthropic_openai_config.py +17 -3
- vision_agent/configs/config.py +17 -3
- vision_agent/lmm/__init__.py +0 -1
- vision_agent/lmm/lmm.py +4 -3
- vision_agent/models/__init__.py +11 -0
- vision_agent/{lmm/types.py → models/lmm_types.py} +4 -1
- vision_agent/sim/__init__.py +9 -0
- vision_agent/{utils → sim}/sim.py +3 -3
- vision_agent/tools/__init__.py +10 -23
- vision_agent/tools/meta_tools.py +4 -5
- vision_agent/tools/planner_tools.py +148 -37
- vision_agent/tools/tools.py +388 -302
- vision_agent/utils/__init__.py +0 -1
- vision_agent/{agent/agent_utils.py → utils/agent.py} +11 -2
- vision_agent/utils/image_utils.py +18 -7
- vision_agent/{tools/tool_utils.py → utils/tools.py} +1 -93
- vision_agent/utils/tools_doc.py +87 -0
- vision_agent/utils/video.py +15 -0
- vision_agent/utils/video_tracking.py +38 -5
- {vision_agent-0.2.236.dist-info → vision_agent-0.2.238.dist-info}/METADATA +2 -3
- vision_agent-0.2.238.dist-info/RECORD +55 -0
- vision_agent-0.2.236.dist-info/RECORD +0 -52
- /vision_agent/{agent/types.py → models/agent_types.py} +0 -0
- /vision_agent/{tools → models}/tools_types.py +0 -0
- {vision_agent-0.2.236.dist-info → vision_agent-0.2.238.dist-info}/LICENSE +0 -0
- {vision_agent-0.2.236.dist-info → vision_agent-0.2.238.dist-info}/WHEEL +0 -0
vision_agent/agent/vision_agent_v2.py CHANGED

```diff
@@ -4,23 +4,23 @@ from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
 
 from vision_agent.agent import Agent, AgentCoder, VisionAgentCoderV2
-from vision_agent.agent.…
-…
-)
-from vision_agent.agent.types import (
+from vision_agent.agent.vision_agent_coder_v2 import format_code_context
+from vision_agent.agent.vision_agent_prompts_v2 import CONVERSATION
+from vision_agent.configs import Config
+from vision_agent.lmm import LMM
+from vision_agent.models import (
     AgentMessage,
     CodeContext,
     InteractionContext,
+    Message,
     PlanContext,
 )
-from vision_agent.agent…
-…
+from vision_agent.utils.agent import (
+    add_media_to_chat,
+    convert_message_to_agentmessage,
+    extract_tag,
+    format_conversation,
+)
 from vision_agent.utils.execute import CodeInterpreter, CodeInterpreterFactory
 
 CONFIG = Config()
```
vision_agent/clients/landing_public_api.py CHANGED

```diff
@@ -5,7 +5,7 @@ from uuid import UUID
 from requests.exceptions import HTTPError
 
 from vision_agent.clients.http import BaseHTTP
-from vision_agent.…
+from vision_agent.models import BboxInputBase64, JobStatus, PromptTask
 from vision_agent.utils.exceptions import FineTuneModelNotFound
 from vision_agent.utils.type_defs import LandingaiAPIKey
```
vision_agent/configs/anthropic_openai_config.py CHANGED

```diff
@@ -96,13 +96,24 @@ class Config(BaseModel):
         }
     )
 
+    # for get_tool_for_task
+    od_judge: Type[LMM] = Field(default=AnthropicLMM)
+    od_judge_kwargs: dict = Field(
+        default_factory=lambda: {
+            "model_name": "claude-3-5-sonnet-20241022",
+            "temperature": 0.0,
+            "image_size": 512,
+        }
+    )
+
     # for suggestions module
-    suggester: Type[LMM] = Field(default=…
+    suggester: Type[LMM] = Field(default=OpenAILMM)
     suggester_kwargs: dict = Field(
         default_factory=lambda: {
-            "model_name": "…
+            "model_name": "o1",
             "temperature": 1.0,
-            "…
+            "image_detail": "high",
+            "image_size": 1024,
         }
     )
 
@@ -143,6 +154,9 @@ class Config(BaseModel):
     def create_tool_chooser(self) -> LMM:
         return self.tool_chooser(**self.tool_chooser_kwargs)
 
+    def create_od_judge(self) -> LMM:
+        return self.od_judge(**self.od_judge_kwargs)
+
     def create_suggester(self) -> LMM:
         return self.suggester(**self.suggester_kwargs)
```
vision_agent/configs/config.py CHANGED

```diff
@@ -96,13 +96,24 @@ class Config(BaseModel):
         }
     )
 
+    # for get_tool_for_task
+    od_judge: Type[LMM] = Field(default=AnthropicLMM)
+    od_judge_kwargs: dict = Field(
+        default_factory=lambda: {
+            "model_name": "claude-3-5-sonnet-20241022",
+            "temperature": 0.0,
+            "image_size": 512,
+        }
+    )
+
     # for suggestions module
-    suggester: Type[LMM] = Field(default=…
+    suggester: Type[LMM] = Field(default=OpenAILMM)
     suggester_kwargs: dict = Field(
         default_factory=lambda: {
-            "model_name": "…
+            "model_name": "o1",
             "temperature": 1.0,
-            "…
+            "image_detail": "high",
+            "image_size": 1024,
         }
     )
 
@@ -143,6 +154,9 @@ class Config(BaseModel):
     def create_tool_chooser(self) -> LMM:
         return self.tool_chooser(**self.tool_chooser_kwargs)
 
+    def create_od_judge(self) -> LMM:
+        return self.od_judge(**self.od_judge_kwargs)
+
     def create_suggester(self) -> LMM:
         return self.suggester(**self.suggester_kwargs)
```
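The new `od_judge`/`create_od_judge` pair follows the config's existing pattern: each role is a `Type[LMM]` field plus a kwargs dict, and a `create_*` method instantiates the configured class with the configured kwargs. A minimal self-contained sketch of that pattern (with a stand-in `FakeLMM` in place of the package's `AnthropicLMM`):

```python
from typing import Type

from pydantic import BaseModel, Field


class FakeLMM:
    """Stand-in for AnthropicLMM/OpenAILMM; only stores its kwargs."""

    def __init__(self, model_name: str, **kwargs: object) -> None:
        self.model_name = model_name
        self.kwargs = kwargs


class Config(BaseModel):
    od_judge: Type[FakeLMM] = Field(default=FakeLMM)
    od_judge_kwargs: dict = Field(
        default_factory=lambda: {"model_name": "claude-3-5-sonnet-20241022"}
    )

    def create_od_judge(self) -> FakeLMM:
        # Instantiate whatever class is configured, with its configured kwargs.
        return self.od_judge(**self.od_judge_kwargs)


judge = Config().create_od_judge()
print(judge.model_name)  # claude-3-5-sonnet-20241022
```

Overriding the field in a subclass (or in a different config module, as `anthropic_openai_config.py` vs `config.py` do) re-targets every consumer of `create_od_judge()` without touching call sites.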
vision_agent/lmm/__init__.py CHANGED
vision_agent/lmm/lmm.py CHANGED

```diff
@@ -9,10 +9,9 @@ import requests
 from anthropic.types import ImageBlockParam, MessageParam, TextBlockParam
 from openai import AzureOpenAI, OpenAI
 
+from vision_agent.models import Message
 from vision_agent.utils.image_utils import encode_media
 
-from .types import Message
-
 
 class LMM(ABC):
     @abstractmethod
@@ -64,7 +63,9 @@ class OpenAILMM(LMM):
         self.image_size = image_size
        self.image_detail = image_detail
         # o1 does not use max_tokens
-        if "max_tokens" not in kwargs and not …
+        if "max_tokens" not in kwargs and not (
+            model_name.startswith("o1") or model_name.startswith("o3")
+        ):
             kwargs["max_tokens"] = max_tokens
         if json_mode:
             kwargs["response_format"] = {"type": "json_object"}
```
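The widened guard now skips `max_tokens` for both `o1*` and `o3*` model names instead of just `o1`. A small sketch of the behavior, assuming the same kwargs flow as the constructor above:

```python
def build_request_kwargs(model_name: str, max_tokens: int = 4096, **kwargs: object) -> dict:
    merged = dict(kwargs)
    # o1/o3-style models do not accept max_tokens, so only set it for other models.
    if "max_tokens" not in merged and not (
        model_name.startswith("o1") or model_name.startswith("o3")
    ):
        merged["max_tokens"] = max_tokens
    return merged


assert "max_tokens" in build_request_kwargs("gpt-4o")
assert "max_tokens" not in build_request_kwargs("o1")
assert "max_tokens" not in build_request_kwargs("o3-mini")
```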
vision_agent/models/__init__.py ADDED

```diff
@@ -0,0 +1,11 @@
+from .agent_types import AgentMessage, CodeContext, InteractionContext, PlanContext
+from .lmm_types import Message, TextOrImage
+from .tools_types import (
+    BboxInput,
+    BboxInputBase64,
+    BoundingBoxes,
+    Florence2FtRequest,
+    JobStatus,
+    ODResponseData,
+    PromptTask,
+)
```
vision_agent/{lmm/types.py → models/lmm_types.py} RENAMED

```diff
@@ -1,7 +1,10 @@
 from pathlib import Path
 from typing import Dict, Sequence, Union
 
+import numpy as np
+from PIL.Image import Image as ImageType
+
 from vision_agent.utils.execute import Execution
 
-TextOrImage = Union[str, Sequence[Union[str, Path]]]
+TextOrImage = Union[str, Sequence[Union[str, Path, ImageType, np.ndarray]]]
 Message = Dict[str, Union[TextOrImage, Execution]]
```
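With the widened alias, a `Message` can carry in-memory frames directly instead of only file paths. A usage sketch (the keys mirror the `role`/`content`/`media` message shape used elsewhere in the package):

```python
import numpy as np
from PIL import Image

# Before this change, media had to be strings or Paths. Now PIL images and
# numpy arrays also satisfy TextOrImage.
frame = np.zeros((480, 640, 3), dtype=np.uint8)
pil_img = Image.new("RGB", (640, 480))

message = {
    "role": "user",
    "content": "What is in these images?",
    "media": [frame, pil_img, "photo.jpg"],
}
```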
vision_agent/{utils → sim}/sim.py RENAMED

```diff
@@ -12,17 +12,17 @@ import requests
 from openai import AzureOpenAI, OpenAI
 from scipy.spatial.distance import cosine  # type: ignore
 
-from vision_agent.tools.…
+from vision_agent.tools.tools import get_tools_df
+from vision_agent.utils.tools import (
     _LND_API_KEY,
     _create_requests_session,
     _LND_API_URL_v2,
 )
-from vision_agent.tools.tools import TOOLS_DF
 
 
 @lru_cache(maxsize=1)
 def get_tool_recommender() -> "Sim":
-    return load_cached_sim(…
+    return load_cached_sim(get_tools_df())
 
 
 @lru_cache(maxsize=512)
```
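Replacing the module-level `TOOLS_DF` constant with a `get_tools_df()` call behind `@lru_cache(maxsize=1)` defers building the tool index until the recommender is first used, and memoizes it afterwards. A minimal sketch of the pattern:

```python
from functools import lru_cache


def get_tools_df() -> list:
    # Stand-in for the real builder that assembles the tools dataframe.
    print("building tool index (runs once)")
    return ["tool doc 1", "tool doc 2"]


@lru_cache(maxsize=1)
def get_tool_recommender() -> list:
    return get_tools_df()


get_tool_recommender()  # builds the index
get_tool_recommender()  # served from cache, no rebuild
```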
vision_agent/tools/__init__.py CHANGED

```diff
@@ -12,17 +12,10 @@ from .meta_tools import (
     use_object_detection_fine_tuning,
     view_media_artifact,
 )
+from .planner_tools import judge_od_results
 from .prompts import CHOOSE_PARAMS, SYSTEM_PROMPT
-from .tool_utils import add_bboxes_from_masks, get_tool_descriptions_by_names
 from .tools import (
-    …
-    TOOL_DESCRIPTIONS,
-    TOOL_DOCSTRING,
-    TOOLS,
-    TOOLS_DF,
-    TOOLS_INFO,
-    UTIL_TOOLS,
-    UTILITIES_DOCSTRING,
+    activity_recognition,
     agentic_object_detection,
     agentic_sam2_instance_segmentation,
     agentic_sam2_video_tracking,
@@ -45,7 +38,11 @@ from .tools import (
     florence2_sam2_video_tracking,
     flux_image_inpainting,
     generate_pose_image,
-    …
+    get_tools,
+    get_tools_descriptions,
+    get_tools_df,
+    get_tools_docstring,
+    get_utilties_docstring,
     load_image,
     minimum_distance,
     ocr,
@@ -64,7 +61,6 @@ from .tools import (
     save_video,
     siglip_classification,
     template_match,
-    video_temporal_localization,
     vit_image_classification,
     vit_nsfw_classification,
 )
@@ -79,20 +75,11 @@ def register_tool(imports: Optional[List] = None) -> Callable:
     def decorator(tool: Callable) -> Callable:
         import inspect
 
-        from .tools import (  # noqa: F811
-            get_tool_descriptions,
-            get_tools_df,
-            get_tools_info,
-        )
-
         global TOOLS, TOOLS_DF, TOOL_DESCRIPTIONS, TOOL_DOCSTRING, TOOLS_INFO
+        from vision_agent.tools.tools import TOOLS
 
-        if tool not in TOOLS:
-            TOOLS.append(tool)
-            TOOLS_DF = get_tools_df(TOOLS)  # type: ignore
-            TOOL_DESCRIPTIONS = get_tool_descriptions(TOOLS)  # type: ignore
-            TOOL_DOCSTRING = get_tool_documentation(TOOLS)  # type: ignore
-            TOOLS_INFO = get_tools_info(TOOLS)  # type: ignore
+        if tool not in TOOLS:  # type: ignore
+            TOOLS.append(tool)  # type: ignore
 
         globals()[tool.__name__] = tool
         if imports is not None:
```
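The slimmed-down decorator only appends to the canonical `TOOLS` list in `tools.py`; the derived artifacts (dataframe, descriptions, docstrings) are now rebuilt on demand by the `get_tools_*` functions instead of being regenerated at registration time. A usage sketch with a made-up tool (the decorator signature is from the diff above; the tool itself is hypothetical):

```python
import numpy as np

from vision_agent.tools import register_tool


@register_tool(imports=["import numpy as np"])
def mean_brightness(image: np.ndarray) -> float:
    """'mean_brightness' returns the average pixel intensity of an image."""
    return float(image.mean())
```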
vision_agent/tools/meta_tools.py CHANGED

```diff
@@ -12,12 +12,11 @@ from IPython.display import display
 
 import vision_agent as va
 from vision_agent.clients.landing_public_api import LandingPublicAPI
-from vision_agent.…
-from vision_agent.tools.…
-from vision_agent.tools.tools import TOOL_DESCRIPTIONS
-from vision_agent.tools.tools_types import BboxInput, BboxInputBase64, PromptTask
+from vision_agent.models import BboxInput, BboxInputBase64, Message, PromptTask
+from vision_agent.tools.tools import get_tools_descriptions as _get_tool_descriptions
 from vision_agent.utils.execute import Execution, MimeType
 from vision_agent.utils.image_utils import convert_to_b64
+from vision_agent.utils.tools_doc import get_tool_documentation
 
 CURRENT_FILE = None
 CURRENT_LINE = 0
@@ -571,7 +570,7 @@ def get_tool_descriptions() -> str:
     """Returns a description of all the tools that `generate_vision_code` has access to.
     Helpful for answering questions about what types of vision tasks you can do with
     `generate_vision_code`."""
-    return …
+    return _get_tool_descriptions()
 
 
 def object_detection_fine_tuning(bboxes: List[Dict[str, Any]]) -> str:
```
vision_agent/tools/planner_tools.py CHANGED

```diff
@@ -1,5 +1,7 @@
 import inspect
 import logging
+import math
+import random
 import tempfile
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
@@ -10,7 +12,6 @@ from IPython.display import display
 from PIL import Image
 
 import vision_agent.tools as T
-from vision_agent.agent.agent_utils import DefaultImports, extract_json, extract_tag
 from vision_agent.agent.vision_agent_planner_prompts_v2 import (
     CATEGORIZE_TOOL_REQUEST,
     FINALIZE_PLAN,
@@ -21,6 +22,9 @@ from vision_agent.agent.vision_agent_planner_prompts_v2 import (
 )
 from vision_agent.configs import Config
 from vision_agent.lmm import LMM, AnthropicLMM
+from vision_agent.sim import get_tool_recommender
+from vision_agent.tools.tools import get_tools, get_tools_info
+from vision_agent.utils.agent import DefaultImports, extract_json, extract_tag
 from vision_agent.utils.execute import (
     CodeInterpreter,
     CodeInterpreterFactory,
@@ -28,12 +32,16 @@ from vision_agent.utils.execute import (
     MimeType,
 )
 from vision_agent.utils.image_utils import convert_to_b64
-from vision_agent.utils.…
+from vision_agent.utils.tools_doc import get_tool_documentation
+
+
+def get_tool_functions() -> Dict[str, Callable]:
+    return {tool.__name__: tool for tool in get_tools()}
+
+
+def get_load_tools_docstring() -> str:
+    return get_tool_documentation([T.load_image, T.extract_frames_and_timestamps])
 
-TOOL_FUNCTIONS = {tool.__name__: tool for tool in T.TOOLS}
-LOAD_TOOLS_DOCSTRING = T.get_tool_documentation(
-    [T.load_image, T.extract_frames_and_timestamps]
-)
 
 CONFIG = Config()
 _LOGGER = logging.getLogger(__name__)
```
```diff
@@ -50,6 +58,59 @@ def format_tool_output(tool_thoughts: str, tool_docstring: str) -> str:
     return return_str
 
 
+def judge_od_results(
+    prompt: str,
+    image: np.ndarray,
+    detections: List[Dict[str, Any]],
+) -> str:
+    """Given an image and the detections, this function will judge the results and
+    return the thoughts on the results.
+
+    Parameters:
+        prompt (str): The prompt that was used to generate the detections.
+        image (np.ndarray): The image that the detections were made on.
+        detections (List[Dict[str, Any]]): The detections made on the image.
+
+    Returns:
+        str: The thoughts on the results.
+    """
+
+    if not detections:
+        return "No detections found in the image."
+
+    od_judge = CONFIG.create_od_judge()
+    max_crop_size = (512, 512)
+
+    # Randomly sample up to 10 detections
+    num_samples = min(10, len(detections))
+    sampled_detections = random.sample(detections, num_samples)
+    crops = []
+    h, w = image.shape[:2]
+
+    for detection in sampled_detections:
+        if "bbox" not in detection:
+            continue
+        x1, y1, x2, y2 = detection["bbox"]
+        crop = image[int(y1 * h) : int(y2 * h), int(x1 * w) : int(x2 * w)]
+        if crop.shape[0] > max_crop_size[0] or crop.shape[1] > max_crop_size[1]:
+            crop = Image.fromarray(crop)  # type: ignore
+            crop.thumbnail(max_crop_size)  # type: ignore
+            crop = np.array(crop)
+        crops.append("data:image/png;base64," + convert_to_b64(crop))
+
+    sampled_detection_info = [
+        {"score": d["score"], "label": d["label"]} for d in sampled_detections
+    ]
+
+    prompt = f"""The user is trying to detect '{prompt}' in an image. You are shown 10 images which represent crops of the detected objets. Below is the detection labels and scores:
+{sampled_detection_info}
+
+Look over each of the cropped images and corresponding labels and scores. Provide a judgement on whether or not the results are correct. If the results are incorrect you can only suggest a different prompt or a threshold."""
+
+    response = cast(str, od_judge.generate(prompt, media=crops))
+    return response
+
+
 def run_multi_judge(
     tool_chooser: LMM,
     tool_docs_str: str,
```
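The crop arithmetic in `judge_od_results` assumes detections carry `bbox` as normalized `[x1, y1, x2, y2]` coordinates, so each corner is scaled by the image's width or height before slicing. A standalone sketch of that step:

```python
import numpy as np

image = np.zeros((480, 640, 3), dtype=np.uint8)
h, w = image.shape[:2]

bbox = [0.25, 0.25, 0.75, 0.75]  # normalized x1, y1, x2, y2
x1, y1, x2, y2 = bbox

# Scale to pixel coordinates, then slice rows (y) first, columns (x) second.
crop = image[int(y1 * h) : int(y2 * h), int(x1 * w) : int(x2 * w)]
print(crop.shape)  # (240, 320, 3)
```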
```diff
@@ -57,6 +118,7 @@ def run_multi_judge(
     code: str,
     tool_output_str: str,
     image_paths: List[str],
+    n_judges: int = 3,
 ) -> Tuple[Optional[Callable], str, str]:
     error_message = ""
     prompt = PICK_TOOL.format(
@@ -77,7 +139,7 @@ def run_multi_judge(
 
     responses = []
     with ThreadPoolExecutor() as executor:
-        futures = [executor.submit(run_judge) for _ in range(…
+        futures = [executor.submit(run_judge) for _ in range(n_judges)]
         for future in as_completed(futures):
             responses.append(future.result())
 
@@ -86,7 +148,7 @@ def run_multi_judge(
     for tool, tool_thoughts, tool_docstring in responses:
         if tool is not None:
             counts[tool.__name__] = counts.get(tool.__name__, 0) + 1
-            if counts[tool.__name__] >= 2:
+            if counts[tool.__name__] >= math.ceil(n_judges / 2):
                 return tool, tool_thoughts, tool_docstring
 
     if len(responses) == 0:
```
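The threshold change generalizes the old hard-coded `>= 2` (a majority of three) to a strict majority of any configured judge count. A small sketch of the early-exit vote:

```python
import math
from typing import List, Optional


def majority_winner(votes: List[str], n_judges: int = 3) -> Optional[str]:
    counts: dict = {}
    for vote in votes:
        counts[vote] = counts.get(vote, 0) + 1
        # Return as soon as any candidate reaches ceil(n/2) votes.
        if counts[vote] >= math.ceil(n_judges / 2):
            return vote
    return None


assert majority_winner(["ocr", "ocr"], n_judges=3) == "ocr"
assert majority_winner(["ocr", "florence2", "ocr"], n_judges=5) is None  # needs 3 of 5
```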
```diff
@@ -104,9 +166,12 @@ def extract_tool_info(
     tool_thoughts = tool_choice_context.get("thoughts", "")
     tool_docstring = ""
     tool = tool_choice_context.get("best_tool", None)
-    …
+    tools_info = get_tools_info()
+
+    tool_functions = get_tool_functions()
+    if tool in tool_functions:
+        tool = tool_functions[tool]
+        tool_docstring = tools_info[tool.__name__]
 
     return tool, tool_thoughts, tool_docstring, ""
 
```
```diff
@@ -153,6 +218,42 @@ def replace_box_threshold(code: str, functions: List[str], box_threshold: float)
     return new_tree.code
 
 
+def retrieve_tool_docs(lmm: LMM, task: str, exclude_tools: Optional[List[str]]) -> str:
+    query = cast(str, lmm.generate(CATEGORIZE_TOOL_REQUEST.format(task=task)))
+    categories_str = extract_tag(query, "category")
+    if categories_str is None:
+        categories = []
+    else:
+        categories = [e.strip() for e in categories_str.split(",")]
+
+    explanation = query.split("<category>")[0].strip()
+    if "</category>" in query:
+        explanation += " " + query.split("</category>")[1].strip()
+    explanation = explanation.strip()
+
+    sim = get_tool_recommender()
+
+    all_tool_docs = []
+    all_tool_doc_names = set()
+    exclude_tools = [] if exclude_tools is None else exclude_tools
+    for category in categories:
+        tool_docs = sim.top_k(category, k=3, thresh=0.3)
+
+        for tool_doc in tool_docs:
+            if (
+                tool_doc["name"] not in all_tool_doc_names
+                and tool_doc["name"] not in exclude_tools
+            ):
+                all_tool_docs.append(tool_doc)
+                all_tool_doc_names.add(tool_doc["name"])
+
+    tool_docs_str = explanation + "\n\n" + "\n".join([e["doc"] for e in all_tool_docs])
+    tool_docs_str += (
+        "\n" + get_load_tools_docstring() + get_tool_documentation([judge_od_results])
+    )
+    return tool_docs_str
+
+
 def run_tool_testing(
     task: str,
     image_paths: List[str],
```
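`retrieve_tool_docs` expects the categorization LLM to wrap its comma-separated categories in `<category>` tags; everything outside the tags becomes the explanation that prefixes the retrieved tool docs. A sketch of that parsing with a made-up reply:

```python
reply = (
    "These are detection-style tasks. "
    "<category>object detection, counting</category> "
    "Prefer an OD tool with a threshold."
)

categories_str = reply.split("<category>")[1].split("</category>")[0]
categories = [c.strip() for c in categories_str.split(",")]

explanation = reply.split("<category>")[0].strip()
if "</category>" in reply:
    explanation += " " + reply.split("</category>")[1].strip()

print(categories)   # ['object detection', 'counting']
print(explanation)  # 'These are detection-style tasks. Prefer an OD tool with a threshold.'
```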
```diff
@@ -162,22 +263,8 @@ def run_tool_testing(
     process_code: Callable[[str], str] = lambda x: x,
 ) -> tuple[str, str, Execution]:
     """Helper function to generate and run tool testing code."""
-    query = lmm.generate(CATEGORIZE_TOOL_REQUEST.format(task=task))
-    category = extract_tag(query, "category")  # type: ignore
-    if category is None:
-        query = task
-    else:
-        query = f"{category.strip()}. {task}"
 
-
-    if exclude_tools is not None and len(exclude_tools) > 0:
-        cleaned_tool_docs = []
-        for tool_doc in tool_docs:
-            if not tool_doc["name"] in exclude_tools:
-                cleaned_tool_docs.append(tool_doc)
-        tool_docs = cleaned_tool_docs
-    tool_docs_str = "\n".join([e["doc"] for e in tool_docs])
-    tool_docs_str += "\n" + LOAD_TOOLS_DOCSTRING
+    tool_docs_str = retrieve_tool_docs(lmm, task, exclude_tools)
 
     prompt = TEST_TOOLS.format(
         tool_docs=tool_docs_str,
```
```diff
@@ -281,6 +368,15 @@ def get_tool_for_task(
     tool_tester = CONFIG.create_tool_tester()
     tool_chooser = CONFIG.create_tool_chooser()
 
+    if isinstance(images, list):
+        if len(images) > 0 and isinstance(images[0], dict):
+            if all(["frame" in image for image in images]):
+                images = [image["frame"] for image in images]
+            else:
+                raise ValueError(
+                    f"Expected a list of numpy arrays or a dictionary of strings to lists of numpy arrays, got a list of dictionaries instead: {images}"
+                )
+
     if isinstance(images, list):
         images = {"image": images}
```
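This normalization lets callers pass the output of a frame extractor (a list of `{"frame": ..., "timestamp": ...}` dicts, the shape returned by `extract_frames_and_timestamps`) straight into `get_tool_for_task`. A standalone sketch:

```python
import numpy as np

images = [
    {"frame": np.zeros((2, 2, 3), dtype=np.uint8), "timestamp": 0.0},
    {"frame": np.ones((2, 2, 3), dtype=np.uint8), "timestamp": 1.0},
]

if isinstance(images, list) and len(images) > 0 and isinstance(images[0], dict):
    if all("frame" in image for image in images):
        images = [image["frame"] for image in images]  # unwrap to raw frames
    else:
        raise ValueError("expected frame dicts or plain numpy arrays")

print(len(images), type(images[0]))  # 2 <class 'numpy.ndarray'>
```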
```diff
@@ -295,24 +391,26 @@ def get_tool_for_task(
         Image.fromarray(image).save(image_path)
         image_paths.append(image_path)
 
+    # run no more than 3 images or else it overloads the LLM
+    image_paths = image_paths[:3]
     code, tool_docs_str, tool_output = run_tool_testing(
         task, image_paths, tool_tester, exclude_tools, code_interpreter
     )
     tool_output_str = tool_output.text(include_results=False).strip()
 
     _, tool_thoughts, tool_docstring = run_multi_judge(
-        tool_chooser, …
+        tool_chooser,
+        tool_docs_str,
+        task,
+        code,
+        tool_output_str,
+        image_paths,
+        n_judges=3,
     )
 
     print(format_tool_output(tool_thoughts, tool_docstring))
 
 
-def get_tool_documentation(tool_name: str) -> str:
-    # use same format as get_tool_for_task
-    tool_doc = T.TOOLS_DF[T.TOOLS_DF["name"] == tool_name]["doc"].values[0]
-    return format_tool_output("", tool_doc)
-
-
 def get_tool_for_task_human_reviewer(
     task: str,
     images: Union[Dict[str, List[np.ndarray]], List[np.ndarray]],
```
```diff
@@ -321,6 +419,15 @@ def get_tool_for_task_human_reviewer(
     # NOTE: this will have the same documentation as get_tool_for_task
     tool_tester = CONFIG.create_tool_tester()
 
+    if isinstance(images, list):
+        if len(images) > 0 and isinstance(images[0], dict):
+            if all(["frame" in image for image in images]):
+                images = [image["frame"] for image in images]
+            else:
+                raise ValueError(
+                    f"Expected a list of numpy arrays or a dictionary of strings to lists of numpy arrays, got a list of dictionaries instead: {images}"
+                )
+
     if isinstance(images, list):
         images = {"image": images}
```
```diff
@@ -335,10 +442,13 @@ def get_tool_for_task_human_reviewer(
         Image.fromarray(image).save(image_path)
         image_paths.append(image_path)
 
+    # run no more than 3 images or else it overloads the LLM
+    image_paths = image_paths[:3]
+
     tools = [
         t.__name__
-        for t in …
-        if inspect.signature(t).parameters.get("box_threshold")
+        for t in get_tools()
+        if inspect.signature(t).parameters.get("box_threshold")
     ]
 
     _, _, tool_output = run_tool_testing(
```
```diff
@@ -414,7 +524,8 @@ def suggestion(prompt: str, medias: List[np.ndarray]) -> None:
     a problem.
 
     Parameters:
-        prompt: str: The problem statement
+        prompt: str: The problem statement, provide a detailed description of the
+            problem you are trying to solve.
         medias: List[np.ndarray]: The images to use for the problem
     """
     try:
```
```diff
@@ -431,4 +542,4 @@ PLANNER_TOOLS = [
     suggestion,
     get_tool_for_task,
 ]
-PLANNER_DOCSTRING = …
+PLANNER_DOCSTRING = get_tool_documentation(PLANNER_TOOLS)  # type: ignore
```