vision-agent 0.2.236__py3-none-any.whl → 0.2.238__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. vision_agent/.sim_tools/df.csv +57 -80
  2. vision_agent/.sim_tools/embs.npy +0 -0
  3. vision_agent/agent/agent.py +2 -2
  4. vision_agent/agent/vision_agent.py +3 -2
  5. vision_agent/agent/vision_agent_coder.py +13 -19
  6. vision_agent/agent/vision_agent_coder_v2.py +17 -17
  7. vision_agent/agent/vision_agent_planner.py +16 -21
  8. vision_agent/agent/vision_agent_planner_prompts_v2.py +19 -20
  9. vision_agent/agent/vision_agent_planner_v2.py +29 -15
  10. vision_agent/agent/vision_agent_v2.py +12 -12
  11. vision_agent/clients/landing_public_api.py +1 -1
  12. vision_agent/configs/anthropic_openai_config.py +17 -3
  13. vision_agent/configs/config.py +17 -3
  14. vision_agent/lmm/__init__.py +0 -1
  15. vision_agent/lmm/lmm.py +4 -3
  16. vision_agent/models/__init__.py +11 -0
  17. vision_agent/{lmm/types.py → models/lmm_types.py} +4 -1
  18. vision_agent/sim/__init__.py +9 -0
  19. vision_agent/{utils → sim}/sim.py +3 -3
  20. vision_agent/tools/__init__.py +10 -23
  21. vision_agent/tools/meta_tools.py +4 -5
  22. vision_agent/tools/planner_tools.py +148 -37
  23. vision_agent/tools/tools.py +388 -302
  24. vision_agent/utils/__init__.py +0 -1
  25. vision_agent/{agent/agent_utils.py → utils/agent.py} +11 -2
  26. vision_agent/utils/image_utils.py +18 -7
  27. vision_agent/{tools/tool_utils.py → utils/tools.py} +1 -93
  28. vision_agent/utils/tools_doc.py +87 -0
  29. vision_agent/utils/video.py +15 -0
  30. vision_agent/utils/video_tracking.py +38 -5
  31. {vision_agent-0.2.236.dist-info → vision_agent-0.2.238.dist-info}/METADATA +2 -3
  32. vision_agent-0.2.238.dist-info/RECORD +55 -0
  33. vision_agent-0.2.236.dist-info/RECORD +0 -52
  34. /vision_agent/{agent/types.py → models/agent_types.py} +0 -0
  35. /vision_agent/{tools → models}/tools_types.py +0 -0
  36. {vision_agent-0.2.236.dist-info → vision_agent-0.2.238.dist-info}/LICENSE +0 -0
  37. {vision_agent-0.2.236.dist-info → vision_agent-0.2.238.dist-info}/WHEEL +0 -0
vision_agent/agent/vision_agent_v2.py CHANGED
@@ -4,23 +4,23 @@ from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
 
 from vision_agent.agent import Agent, AgentCoder, VisionAgentCoderV2
-from vision_agent.agent.agent_utils import (
-    add_media_to_chat,
-    convert_message_to_agentmessage,
-    extract_tag,
-    format_conversation,
-)
-from vision_agent.agent.types import (
+from vision_agent.agent.vision_agent_coder_v2 import format_code_context
+from vision_agent.agent.vision_agent_prompts_v2 import CONVERSATION
+from vision_agent.configs import Config
+from vision_agent.lmm import LMM
+from vision_agent.models import (
     AgentMessage,
     CodeContext,
     InteractionContext,
+    Message,
     PlanContext,
 )
-from vision_agent.agent.vision_agent_coder_v2 import format_code_context
-from vision_agent.agent.vision_agent_prompts_v2 import CONVERSATION
-from vision_agent.configs import Config
-from vision_agent.lmm import LMM
-from vision_agent.lmm.types import Message
+from vision_agent.utils.agent import (
+    add_media_to_chat,
+    convert_message_to_agentmessage,
+    extract_tag,
+    format_conversation,
+)
 from vision_agent.utils.execute import CodeInterpreter, CodeInterpreterFactory
 
 CONFIG = Config()
vision_agent/clients/landing_public_api.py CHANGED
@@ -5,7 +5,7 @@ from uuid import UUID
 from requests.exceptions import HTTPError
 
 from vision_agent.clients.http import BaseHTTP
-from vision_agent.tools.tools_types import BboxInputBase64, JobStatus, PromptTask
+from vision_agent.models import BboxInputBase64, JobStatus, PromptTask
 from vision_agent.utils.exceptions import FineTuneModelNotFound
 from vision_agent.utils.type_defs import LandingaiAPIKey
 
vision_agent/configs/anthropic_openai_config.py CHANGED
@@ -96,13 +96,24 @@ class Config(BaseModel):
         }
     )
 
+    # for get_tool_for_task
+    od_judge: Type[LMM] = Field(default=AnthropicLMM)
+    od_judge_kwargs: dict = Field(
+        default_factory=lambda: {
+            "model_name": "claude-3-5-sonnet-20241022",
+            "temperature": 0.0,
+            "image_size": 512,
+        }
+    )
+
     # for suggestions module
-    suggester: Type[LMM] = Field(default=AnthropicLMM)
+    suggester: Type[LMM] = Field(default=OpenAILMM)
     suggester_kwargs: dict = Field(
         default_factory=lambda: {
-            "model_name": "claude-3-5-sonnet-20241022",
+            "model_name": "o1",
             "temperature": 1.0,
-            "image_size": 768,
+            "image_detail": "high",
+            "image_size": 1024,
         }
     )
 
@@ -143,6 +154,9 @@ class Config(BaseModel):
     def create_tool_chooser(self) -> LMM:
         return self.tool_chooser(**self.tool_chooser_kwargs)
 
+    def create_od_judge(self) -> LMM:
+        return self.od_judge(**self.od_judge_kwargs)
+
     def create_suggester(self) -> LMM:
         return self.suggester(**self.suggester_kwargs)
 
vision_agent/configs/config.py CHANGED
@@ -96,13 +96,24 @@ class Config(BaseModel):
         }
     )
 
+    # for get_tool_for_task
+    od_judge: Type[LMM] = Field(default=AnthropicLMM)
+    od_judge_kwargs: dict = Field(
+        default_factory=lambda: {
+            "model_name": "claude-3-5-sonnet-20241022",
+            "temperature": 0.0,
+            "image_size": 512,
+        }
+    )
+
     # for suggestions module
-    suggester: Type[LMM] = Field(default=AnthropicLMM)
+    suggester: Type[LMM] = Field(default=OpenAILMM)
     suggester_kwargs: dict = Field(
         default_factory=lambda: {
-            "model_name": "claude-3-5-sonnet-20241022",
+            "model_name": "o1",
             "temperature": 1.0,
-            "image_size": 768,
+            "image_detail": "high",
+            "image_size": 1024,
         }
     )
 
@@ -143,6 +154,9 @@ class Config(BaseModel):
     def create_tool_chooser(self) -> LMM:
         return self.tool_chooser(**self.tool_chooser_kwargs)
 
+    def create_od_judge(self) -> LMM:
+        return self.od_judge(**self.od_judge_kwargs)
+
     def create_suggester(self) -> LMM:
         return self.suggester(**self.suggester_kwargs)
 
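Both config variants gain the same od_judge hooks, following the existing pattern of a model class plus kwargs exposed through a create_* factory. A minimal sketch of how the new factory might be exercised (assuming the API key required by the configured LMM is set in the environment):

    from vision_agent.configs import Config

    config = Config()
    # builds an AnthropicLMM from od_judge_kwargs above
    # (claude-3-5-sonnet-20241022, temperature 0.0, image_size 512)
    od_judge = config.create_od_judge()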
vision_agent/lmm/__init__.py CHANGED
@@ -1,2 +1 @@
 from .lmm import LMM, AnthropicLMM, AzureOpenAILMM, GoogleLMM, OllamaLMM, OpenAILMM
-from .types import Message
vision_agent/lmm/lmm.py CHANGED
@@ -9,10 +9,9 @@ import requests
 from anthropic.types import ImageBlockParam, MessageParam, TextBlockParam
 from openai import AzureOpenAI, OpenAI
 
+from vision_agent.models import Message
 from vision_agent.utils.image_utils import encode_media
 
-from .types import Message
-
 
 class LMM(ABC):
     @abstractmethod
@@ -64,7 +63,9 @@ class OpenAILMM(LMM):
         self.image_size = image_size
         self.image_detail = image_detail
         # o1 does not use max_tokens
-        if "max_tokens" not in kwargs and not model_name.startswith("o1"):
+        if "max_tokens" not in kwargs and not (
+            model_name.startswith("o1") or model_name.startswith("o3")
+        ):
            kwargs["max_tokens"] = max_tokens
        if json_mode:
            kwargs["response_format"] = {"type": "json_object"}
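With the widened guard, any model name starting with "o1" or "o3" now avoids the max_tokens default that those reasoning models reject. A hedged sketch ("o3-mini" and the image path are illustrative, not values taken from this diff):

    from vision_agent.lmm import OpenAILMM

    # no max_tokens default is injected for o1*/o3* model names
    lmm = OpenAILMM(model_name="o3-mini")
    response = lmm.generate("Describe this image.", media=["photo.jpg"])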
vision_agent/models/__init__.py ADDED
@@ -0,0 +1,11 @@
+from .agent_types import AgentMessage, CodeContext, InteractionContext, PlanContext
+from .lmm_types import Message, TextOrImage
+from .tools_types import (
+    BboxInput,
+    BboxInputBase64,
+    BoundingBoxes,
+    Florence2FtRequest,
+    JobStatus,
+    ODResponseData,
+    PromptTask,
+)
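The new vision_agent.models package consolidates types that previously lived in vision_agent.agent.types, vision_agent.lmm.types, and vision_agent.tools.tools_types. Downstream code written against 0.2.236 would migrate roughly like this:

    # 0.2.236
    # from vision_agent.lmm.types import Message
    # from vision_agent.tools.tools_types import JobStatus, PromptTask

    # 0.2.238
    from vision_agent.models import JobStatus, Message, PromptTask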
vision_agent/{lmm/types.py → models/lmm_types.py} RENAMED
@@ -1,7 +1,10 @@
 from pathlib import Path
 from typing import Dict, Sequence, Union
 
+import numpy as np
+from PIL.Image import Image as ImageType
+
 from vision_agent.utils.execute import Execution
 
-TextOrImage = Union[str, Sequence[Union[str, Path]]]
+TextOrImage = Union[str, Sequence[Union[str, Path, ImageType, np.ndarray]]]
 Message = Dict[str, Union[TextOrImage, Execution]]
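Since TextOrImage now admits PIL images and numpy arrays, a Message can carry in-memory frames directly instead of file paths. A sketch, assuming the role/content/media message shape used elsewhere in the package:

    import numpy as np

    frame = np.zeros((480, 640, 3), dtype=np.uint8)  # placeholder RGB frame
    message = {"role": "user", "content": "What is in this image?", "media": [frame]}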
vision_agent/sim/__init__.py ADDED
@@ -0,0 +1,9 @@
+from .sim import (
+    AzureSim,
+    OllamaSim,
+    Sim,
+    StellaSim,
+    get_tool_recommender,
+    load_cached_sim,
+    load_sim,
+)
vision_agent/{utils → sim}/sim.py RENAMED
@@ -12,17 +12,17 @@ import requests
 from openai import AzureOpenAI, OpenAI
 from scipy.spatial.distance import cosine  # type: ignore
 
-from vision_agent.tools.tool_utils import (
+from vision_agent.tools.tools import get_tools_df
+from vision_agent.utils.tools import (
     _LND_API_KEY,
     _create_requests_session,
     _LND_API_URL_v2,
 )
-from vision_agent.tools.tools import TOOLS_DF
 
 
 @lru_cache(maxsize=1)
 def get_tool_recommender() -> "Sim":
-    return load_cached_sim(TOOLS_DF)
+    return load_cached_sim(get_tools_df())
 
 
 @lru_cache(maxsize=512)
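get_tool_recommender is still memoized, but the tool dataframe is now built lazily through get_tools_df() rather than read from the TOOLS_DF module global at import time, which removes the import-time dependency of sim on tools. Usage is unchanged; a sketch:

    from vision_agent.sim import get_tool_recommender

    sim = get_tool_recommender()
    # k and thresh mirror the values planner_tools passes to top_k
    matches = sim.top_k("object detection", k=3, thresh=0.3)
    print([m["name"] for m in matches])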
vision_agent/tools/__init__.py CHANGED
@@ -12,17 +12,10 @@ from .meta_tools import (
     use_object_detection_fine_tuning,
     view_media_artifact,
 )
+from .planner_tools import judge_od_results
 from .prompts import CHOOSE_PARAMS, SYSTEM_PROMPT
-from .tool_utils import add_bboxes_from_masks, get_tool_descriptions_by_names
 from .tools import (
-    FUNCTION_TOOLS,
-    TOOL_DESCRIPTIONS,
-    TOOL_DOCSTRING,
-    TOOLS,
-    TOOLS_DF,
-    TOOLS_INFO,
-    UTIL_TOOLS,
-    UTILITIES_DOCSTRING,
+    activity_recognition,
     agentic_object_detection,
     agentic_sam2_instance_segmentation,
     agentic_sam2_video_tracking,
@@ -45,7 +38,11 @@ from .tools import (
     florence2_sam2_video_tracking,
     flux_image_inpainting,
     generate_pose_image,
-    get_tool_documentation,
+    get_tools,
+    get_tools_descriptions,
+    get_tools_df,
+    get_tools_docstring,
+    get_utilties_docstring,
     load_image,
     minimum_distance,
     ocr,
@@ -64,7 +61,6 @@ from .tools import (
     save_video,
     siglip_classification,
     template_match,
-    video_temporal_localization,
     vit_image_classification,
     vit_nsfw_classification,
 )
@@ -79,20 +75,11 @@ def register_tool(imports: Optional[List] = None) -> Callable:
     def decorator(tool: Callable) -> Callable:
         import inspect
 
-        from .tools import (  # noqa: F811
-            get_tool_descriptions,
-            get_tools_df,
-            get_tools_info,
-        )
-
         global TOOLS, TOOLS_DF, TOOL_DESCRIPTIONS, TOOL_DOCSTRING, TOOLS_INFO
+        from vision_agent.tools.tools import TOOLS
 
-        if tool not in TOOLS:
-            TOOLS.append(tool)
-            TOOLS_DF = get_tools_df(TOOLS)  # type: ignore
-            TOOL_DESCRIPTIONS = get_tool_descriptions(TOOLS)  # type: ignore
-            TOOL_DOCSTRING = get_tool_documentation(TOOLS)  # type: ignore
-            TOOLS_INFO = get_tools_info(TOOLS)  # type: ignore
+        if tool not in TOOLS:  # type: ignore
+            TOOLS.append(tool)  # type: ignore
 
         globals()[tool.__name__] = tool
         if imports is not None:
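register_tool now only appends to the canonical TOOLS list in vision_agent.tools.tools; descriptions, the dataframe, and docstrings are derived on demand by the new get_tools_* helpers rather than reassigned as module globals on every registration. A hedged registration sketch (the tool itself is hypothetical):

    import numpy as np

    from vision_agent.tools import register_tool

    @register_tool(imports=["import numpy as np"])
    def mean_brightness(image: np.ndarray) -> float:
        """'mean_brightness' returns the mean pixel intensity of an image."""
        return float(np.mean(image))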
vision_agent/tools/meta_tools.py CHANGED
@@ -12,12 +12,11 @@ from IPython.display import display
 
 import vision_agent as va
 from vision_agent.clients.landing_public_api import LandingPublicAPI
-from vision_agent.lmm.types import Message
-from vision_agent.tools.tool_utils import get_tool_documentation
-from vision_agent.tools.tools import TOOL_DESCRIPTIONS
-from vision_agent.tools.tools_types import BboxInput, BboxInputBase64, PromptTask
+from vision_agent.models import BboxInput, BboxInputBase64, Message, PromptTask
+from vision_agent.tools.tools import get_tools_descriptions as _get_tool_descriptions
 from vision_agent.utils.execute import Execution, MimeType
 from vision_agent.utils.image_utils import convert_to_b64
+from vision_agent.utils.tools_doc import get_tool_documentation
 
 CURRENT_FILE = None
 CURRENT_LINE = 0
@@ -571,7 +570,7 @@ def get_tool_descriptions() -> str:
     """Returns a description of all the tools that `generate_vision_code` has access to.
     Helpful for answering questions about what types of vision tasks you can do with
     `generate_vision_code`."""
-    return TOOL_DESCRIPTIONS
+    return _get_tool_descriptions()
 
 
 def object_detection_fine_tuning(bboxes: List[Dict[str, Any]]) -> str:
vision_agent/tools/planner_tools.py CHANGED
@@ -1,5 +1,7 @@
 import inspect
 import logging
+import math
+import random
 import tempfile
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
@@ -10,7 +12,6 @@ from IPython.display import display
 from PIL import Image
 
 import vision_agent.tools as T
-from vision_agent.agent.agent_utils import DefaultImports, extract_json, extract_tag
 from vision_agent.agent.vision_agent_planner_prompts_v2 import (
     CATEGORIZE_TOOL_REQUEST,
     FINALIZE_PLAN,
@@ -21,6 +22,9 @@ from vision_agent.agent.vision_agent_planner_prompts_v2 import (
 )
 from vision_agent.configs import Config
 from vision_agent.lmm import LMM, AnthropicLMM
+from vision_agent.sim import get_tool_recommender
+from vision_agent.tools.tools import get_tools, get_tools_info
+from vision_agent.utils.agent import DefaultImports, extract_json, extract_tag
 from vision_agent.utils.execute import (
     CodeInterpreter,
     CodeInterpreterFactory,
@@ -28,12 +32,16 @@ from vision_agent.utils.execute import (
     MimeType,
 )
 from vision_agent.utils.image_utils import convert_to_b64
-from vision_agent.utils.sim import get_tool_recommender
+from vision_agent.utils.tools_doc import get_tool_documentation
+
+
+def get_tool_functions() -> Dict[str, Callable]:
+    return {tool.__name__: tool for tool in get_tools()}
+
+
+def get_load_tools_docstring() -> str:
+    return get_tool_documentation([T.load_image, T.extract_frames_and_timestamps])
 
-TOOL_FUNCTIONS = {tool.__name__: tool for tool in T.TOOLS}
-LOAD_TOOLS_DOCSTRING = T.get_tool_documentation(
-    [T.load_image, T.extract_frames_and_timestamps]
-)
 
 CONFIG = Config()
 _LOGGER = logging.getLogger(__name__)
@@ -50,6 +58,59 @@ def format_tool_output(tool_thoughts: str, tool_docstring: str) -> str:
     return return_str
 
 
+def judge_od_results(
+    prompt: str,
+    image: np.ndarray,
+    detections: List[Dict[str, Any]],
+) -> str:
+    """Given an image and the detections, this function will judge the results and
+    return the thoughts on the results.
+
+    Parameters:
+        prompt (str): The prompt that was used to generate the detections.
+        image (np.ndarray): The image that the detections were made on.
+        detections (List[Dict[str, Any]]): The detections made on the image.
+
+    Returns:
+        str: The thoughts on the results.
+    """
+
+    if not detections:
+        return "No detections found in the image."
+
+    od_judge = CONFIG.create_od_judge()
+    max_crop_size = (512, 512)
+
+    # Randomly sample up to 10 detections
+    num_samples = min(10, len(detections))
+    sampled_detections = random.sample(detections, num_samples)
+    crops = []
+    h, w = image.shape[:2]
+
+    for detection in sampled_detections:
+        if "bbox" not in detection:
+            continue
+        x1, y1, x2, y2 = detection["bbox"]
+        crop = image[int(y1 * h) : int(y2 * h), int(x1 * w) : int(x2 * w)]
+        if crop.shape[0] > max_crop_size[0] or crop.shape[1] > max_crop_size[1]:
+            crop = Image.fromarray(crop)  # type: ignore
+            crop.thumbnail(max_crop_size)  # type: ignore
+            crop = np.array(crop)
+        crops.append("data:image/png;base64," + convert_to_b64(crop))
+
+    sampled_detection_info = [
+        {"score": d["score"], "label": d["label"]} for d in sampled_detections
+    ]
+
+    prompt = f"""The user is trying to detect '{prompt}' in an image. You are shown 10 images which represent crops of the detected objets. Below is the detection labels and scores:
+{sampled_detection_info}
+
+Look over each of the cropped images and corresponding labels and scores. Provide a judgement on whether or not the results are correct. If the results are incorrect you can only suggest a different prompt or a threshold."""
+
+    response = cast(str, od_judge.generate(prompt, media=crops))
+    return response
+
+
 def run_multi_judge(
     tool_chooser: LMM,
     tool_docs_str: str,
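judge_od_results is also re-exported from vision_agent.tools (see the __init__.py hunk above), and its documentation is appended to every retrieved tool doc string in retrieve_tool_docs below, so planner-generated code can ask for a critique of detector output. A hypothetical round trip, using countgd_object_detection as the detector and a placeholder image path:

    import vision_agent.tools as T
    from vision_agent.tools.planner_tools import judge_od_results

    image = T.load_image("street.jpg")
    # detections are dicts with "label", "score", and normalized "bbox" keys
    detections = T.countgd_object_detection("car", image)
    print(judge_od_results("car", image, detections))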
@@ -57,6 +118,7 @@ def run_multi_judge(
     code: str,
     tool_output_str: str,
     image_paths: List[str],
+    n_judges: int = 3,
 ) -> Tuple[Optional[Callable], str, str]:
     error_message = ""
     prompt = PICK_TOOL.format(
@@ -77,7 +139,7 @@
 
     responses = []
     with ThreadPoolExecutor() as executor:
-        futures = [executor.submit(run_judge) for _ in range(3)]
+        futures = [executor.submit(run_judge) for _ in range(n_judges)]
         for future in as_completed(futures):
             responses.append(future.result())
 
@@ -86,7 +148,7 @@
     for tool, tool_thoughts, tool_docstring in responses:
         if tool is not None:
             counts[tool.__name__] = counts.get(tool.__name__, 0) + 1
-            if counts[tool.__name__] >= 2:
+            if counts[tool.__name__] >= math.ceil(n_judges / 2):
                 return tool, tool_thoughts, tool_docstring
 
     if len(responses) == 0:
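With the default n_judges=3, the majority threshold math.ceil(3 / 2) = 2 matches the old hard-coded check, so behavior is unchanged; the parameter simply generalizes the early exit (five judges would require three agreeing votes).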
@@ -104,9 +166,12 @@ def extract_tool_info(
     tool_thoughts = tool_choice_context.get("thoughts", "")
     tool_docstring = ""
     tool = tool_choice_context.get("best_tool", None)
-    if tool in TOOL_FUNCTIONS:
-        tool = TOOL_FUNCTIONS[tool]
-        tool_docstring = T.TOOLS_INFO[tool.__name__]
+    tools_info = get_tools_info()
+
+    tool_functions = get_tool_functions()
+    if tool in tool_functions:
+        tool = tool_functions[tool]
+        tool_docstring = tools_info[tool.__name__]
 
     return tool, tool_thoughts, tool_docstring, ""
 
@@ -153,6 +218,42 @@ def replace_box_threshold(code: str, functions: List[str], box_threshold: float)
     return new_tree.code
 
 
+def retrieve_tool_docs(lmm: LMM, task: str, exclude_tools: Optional[List[str]]) -> str:
+    query = cast(str, lmm.generate(CATEGORIZE_TOOL_REQUEST.format(task=task)))
+    categories_str = extract_tag(query, "category")
+    if categories_str is None:
+        categories = []
+    else:
+        categories = [e.strip() for e in categories_str.split(",")]
+
+    explanation = query.split("<category>")[0].strip()
+    if "</category>" in query:
+        explanation += " " + query.split("</category>")[1].strip()
+    explanation = explanation.strip()
+
+    sim = get_tool_recommender()
+
+    all_tool_docs = []
+    all_tool_doc_names = set()
+    exclude_tools = [] if exclude_tools is None else exclude_tools
+    for category in categories:
+        tool_docs = sim.top_k(category, k=3, thresh=0.3)
+
+        for tool_doc in tool_docs:
+            if (
+                tool_doc["name"] not in all_tool_doc_names
+                and tool_doc["name"] not in exclude_tools
+            ):
+                all_tool_docs.append(tool_doc)
+                all_tool_doc_names.add(tool_doc["name"])
+
+    tool_docs_str = explanation + "\n\n" + "\n".join([e["doc"] for e in all_tool_docs])
+    tool_docs_str += (
+        "\n" + get_load_tools_docstring() + get_tool_documentation([judge_od_results])
+    )
+    return tool_docs_str
+
+
 def run_tool_testing(
     task: str,
     image_paths: List[str],
@@ -162,22 +263,8 @@
     process_code: Callable[[str], str] = lambda x: x,
 ) -> tuple[str, str, Execution]:
     """Helper function to generate and run tool testing code."""
-    query = lmm.generate(CATEGORIZE_TOOL_REQUEST.format(task=task))
-    category = extract_tag(query, "category")  # type: ignore
-    if category is None:
-        query = task
-    else:
-        query = f"{category.strip()}. {task}"
 
-    tool_docs = get_tool_recommender().top_k(query, k=5, thresh=0.3)
-    if exclude_tools is not None and len(exclude_tools) > 0:
-        cleaned_tool_docs = []
-        for tool_doc in tool_docs:
-            if not tool_doc["name"] in exclude_tools:
-                cleaned_tool_docs.append(tool_doc)
-        tool_docs = cleaned_tool_docs
-    tool_docs_str = "\n".join([e["doc"] for e in tool_docs])
-    tool_docs_str += "\n" + LOAD_TOOLS_DOCSTRING
+    tool_docs_str = retrieve_tool_docs(lmm, task, exclude_tools)
 
     prompt = TEST_TOOLS.format(
         tool_docs=tool_docs_str,
@@ -281,6 +368,15 @@ def get_tool_for_task(
     tool_tester = CONFIG.create_tool_tester()
     tool_chooser = CONFIG.create_tool_chooser()
 
+    if isinstance(images, list):
+        if len(images) > 0 and isinstance(images[0], dict):
+            if all(["frame" in image for image in images]):
+                images = [image["frame"] for image in images]
+            else:
+                raise ValueError(
+                    f"Expected a list of numpy arrays or a dictionary of strings to lists of numpy arrays, got a list of dictionaries instead: {images}"
+                )
+
 
     if isinstance(images, list):
         images = {"image": images}
@@ -295,24 +391,26 @@
             Image.fromarray(image).save(image_path)
             image_paths.append(image_path)
 
+    # run no more than 3 images or else it overloads the LLM
+    image_paths = image_paths[:3]
     code, tool_docs_str, tool_output = run_tool_testing(
         task, image_paths, tool_tester, exclude_tools, code_interpreter
     )
     tool_output_str = tool_output.text(include_results=False).strip()
 
     _, tool_thoughts, tool_docstring = run_multi_judge(
-        tool_chooser, tool_docs_str, task, code, tool_output_str, image_paths
+        tool_chooser,
+        tool_docs_str,
+        task,
+        code,
+        tool_output_str,
+        image_paths,
+        n_judges=3,
     )
 
     print(format_tool_output(tool_thoughts, tool_docstring))
 
 
-def get_tool_documentation(tool_name: str) -> str:
-    # use same format as get_tool_for_task
-    tool_doc = T.TOOLS_DF[T.TOOLS_DF["name"] == tool_name]["doc"].values[0]
-    return format_tool_output("", tool_doc)
-
-
 def get_tool_for_task_human_reviewer(
     task: str,
     images: Union[Dict[str, List[np.ndarray]], List[np.ndarray]],
@@ -321,6 +419,15 @@ def get_tool_for_task_human_reviewer(
     # NOTE: this will have the same documentation as get_tool_for_task
     tool_tester = CONFIG.create_tool_tester()
 
+    if isinstance(images, list):
+        if len(images) > 0 and isinstance(images[0], dict):
+            if all(["frame" in image for image in images]):
+                images = [image["frame"] for image in images]
+            else:
+                raise ValueError(
+                    f"Expected a list of numpy arrays or a dictionary of strings to lists of numpy arrays, got a list of dictionaries instead: {images}"
+                )
+
 
     if isinstance(images, list):
         images = {"image": images}
@@ -335,10 +442,13 @@
             Image.fromarray(image).save(image_path)
             image_paths.append(image_path)
 
+    # run no more than 3 images or else it overloads the LLM
+    image_paths = image_paths[:3]
+
     tools = [
         t.__name__
-        for t in T.TOOLS
-        if inspect.signature(t).parameters.get("box_threshold")  # type: ignore
+        for t in get_tools()
+        if inspect.signature(t).parameters.get("box_threshold")
     ]
 
     _, _, tool_output = run_tool_testing(
@@ -414,7 +524,8 @@ def suggestion(prompt: str, medias: List[np.ndarray]) -> None:
     a problem.
 
     Parameters:
-        prompt: str: The problem statement.
+        prompt: str: The problem statement, provide a detailed description of the
+            problem you are trying to solve.
         medias: List[np.ndarray]: The images to use for the problem
     """
    try:
@@ -431,4 +542,4 @@ PLANNER_TOOLS = [
     suggestion,
     get_tool_for_task,
 ]
-PLANNER_DOCSTRING = T.get_tool_documentation(PLANNER_TOOLS)  # type: ignore
+PLANNER_DOCSTRING = get_tool_documentation(PLANNER_TOOLS)  # type: ignore