vision-agent 0.2.228__py3-none-any.whl → 0.2.230__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vision_agent/.sim_tools/df.csv +10 -8
- vision_agent/agent/agent_utils.py +10 -9
- vision_agent/agent/vision_agent.py +3 -4
- vision_agent/agent/vision_agent_coder_prompts.py +6 -6
- vision_agent/agent/vision_agent_coder_v2.py +41 -26
- vision_agent/agent/vision_agent_planner_prompts.py +6 -6
- vision_agent/agent/vision_agent_planner_prompts_v2.py +16 -50
- vision_agent/agent/vision_agent_planner_v2.py +10 -12
- vision_agent/agent/vision_agent_prompts.py +11 -11
- vision_agent/agent/vision_agent_prompts_v2.py +18 -3
- vision_agent/agent/vision_agent_v2.py +29 -30
- vision_agent/configs/__init__.py +1 -0
- vision_agent/configs/anthropic_config.py +150 -0
- vision_agent/configs/anthropic_openai_config.py +150 -0
- vision_agent/configs/config.py +150 -0
- vision_agent/configs/openai_config.py +160 -0
- vision_agent/lmm/__init__.py +1 -1
- vision_agent/lmm/lmm.py +63 -9
- vision_agent/tools/planner_tools.py +60 -40
- vision_agent/tools/tool_utils.py +1 -2
- vision_agent/tools/tools.py +10 -8
- vision_agent-0.2.230.dist-info/METADATA +156 -0
- {vision_agent-0.2.228.dist-info → vision_agent-0.2.230.dist-info}/RECORD +25 -20
- vision_agent-0.2.228.dist-info/METADATA +0 -562
- {vision_agent-0.2.228.dist-info → vision_agent-0.2.230.dist-info}/LICENSE +0 -0
- {vision_agent-0.2.228.dist-info → vision_agent-0.2.230.dist-info}/WHEEL +0 -0
vision_agent/configs/openai_config.py
ADDED
@@ -0,0 +1,160 @@
+from typing import Type
+
+from pydantic import BaseModel, Field
+
+from vision_agent.lmm import LMM, OpenAILMM
+
+
+class Config(BaseModel):
+    # for vision_agent_v2
+    agent: Type[LMM] = Field(default=OpenAILMM)
+    agent_kwargs: dict = Field(
+        default_factory=lambda: {
+            "model_name": "gpt-4o-2024-08-06",
+            "temperature": 0.0,
+            "image_size": 768,
+            "image_detail": "low",
+        }
+    )
+
+    # for vision_agent_planner_v2
+    planner: Type[LMM] = Field(default=OpenAILMM)
+    planner_kwargs: dict = Field(
+        default_factory=lambda: {
+            "model_name": "gpt-4o-2024-08-06",
+            "temperature": 0.0,
+            "image_size": 768,
+            "image_detail": "low",
+        }
+    )
+
+    # for vision_agent_planner_v2
+    summarizer: Type[LMM] = Field(default=OpenAILMM)
+    summarizer_kwargs: dict = Field(
+        default_factory=lambda: {
+            "model_name": "o1",
+            "temperature": 1.0,
+            "image_size": 768,
+        }
+    )
+
+    # for vision_agent_planner_v2
+    critic: Type[LMM] = Field(default=OpenAILMM)
+    critic_kwargs: dict = Field(
+        default_factory=lambda: {
+            "model_name": "gpt-4o-2024-08-06",
+            "temperature": 0.0,
+            "image_size": 768,
+            "image_detail": "low",
+        }
+    )
+
+    # for vision_agent_coder_v2
+    coder: Type[LMM] = Field(default=OpenAILMM)
+    coder_kwargs: dict = Field(
+        default_factory=lambda: {
+            "model_name": "gpt-4o-2024-08-06",
+            "temperature": 0.0,
+            "image_size": 768,
+            "image_detail": "low",
+        }
+    )
+
+    # for vision_agent_coder_v2
+    tester: Type[LMM] = Field(default=OpenAILMM)
+    tester_kwargs: dict = Field(
+        default_factory=lambda: {
+            "model_name": "gpt-4o-2024-08-06",
+            "temperature": 0.0,
+            "image_size": 768,
+            "image_detail": "low",
+        }
+    )
+
+    # for vision_agent_coder_v2
+    debugger: Type[LMM] = Field(default=OpenAILMM)
+    debugger_kwargs: dict = Field(
+        default_factory=lambda: {
+            "model_name": "gpt-4o-2024-08-06",
+            "temperature": 0.0,
+            "image_size": 768,
+            "image_detail": "low",
+        }
+    )
+
+    # for get_tool_for_task
+    tool_tester: Type[LMM] = Field(default=OpenAILMM)
+    tool_tester_kwargs: dict = Field(
+        default_factory=lambda: {
+            "model_name": "gpt-4o-2024-08-06",
+            "temperature": 0.0,
+            "image_size": 768,
+            "image_detail": "low",
+        }
+    )
+
+    # for get_tool_for_task
+    tool_chooser: Type[LMM] = Field(default=OpenAILMM)
+    tool_chooser_kwargs: dict = Field(
+        default_factory=lambda: {
+            "model_name": "gpt-4o-2024-08-06",
+            "temperature": 0.0,
+            "image_size": 768,
+            "image_detail": "low",
+        }
+    )
+
+    # for suggestions module
+    suggester: Type[LMM] = Field(default=OpenAILMM)
+    suggester_kwargs: dict = Field(
+        default_factory=lambda: {
+            "model_name": "gpt-4o-2024-08-06",
+            "temperature": 0.0,
+            "image_size": 768,
+            "image_detail": "low",
+        }
+    )
+
+    # for vqa module
+    vqa: Type[LMM] = Field(default=OpenAILMM)
+    vqa_kwargs: dict = Field(
+        default_factory=lambda: {
+            "model_name": "gpt-4o-2024-08-06",
+            "temperature": 0.0,
+            "image_size": 768,
+            "image_detail": "low",
+        }
+    )
+
+    def create_agent(self) -> LMM:
+        return self.agent(**self.agent_kwargs)
+
+    def create_planner(self) -> LMM:
+        return self.planner(**self.planner_kwargs)
+
+    def create_summarizer(self) -> LMM:
+        return self.summarizer(**self.summarizer_kwargs)
+
+    def create_critic(self) -> LMM:
+        return self.critic(**self.critic_kwargs)
+
+    def create_coder(self) -> LMM:
+        return self.coder(**self.coder_kwargs)
+
+    def create_tester(self) -> LMM:
+        return self.tester(**self.tester_kwargs)
+
+    def create_debugger(self) -> LMM:
+        return self.debugger(**self.debugger_kwargs)
+
+    def create_tool_tester(self) -> LMM:
+        return self.tool_tester(**self.tool_tester_kwargs)
+
+    def create_tool_chooser(self) -> LMM:
+        return self.tool_chooser(**self.tool_chooser_kwargs)
+
+    def create_suggester(self) -> LMM:
+        return self.suggester(**self.suggester_kwargs)
+
+    def create_vqa(self) -> LMM:
+        return self.vqa(**self.vqa_kwargs)
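The new config module centralizes which LMM backs each agent, planner, coder, and tool-selection role, and the create_* helpers instantiate the configured model from the stored kwargs. A minimal usage sketch, assuming only what the file above defines; the prompt and image path are placeholders:

    from vision_agent.configs import Config

    config = Config()
    coder = config.create_coder()  # OpenAILMM built from the coder_kwargs defaults
    vqa = config.create_vqa()

    # placeholder prompt and media, only to show the generate() call shape
    answer = vqa.generate("What objects are in this image?", media=["example.jpg"])
    print(answer)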
vision_agent/lmm/__init__.py
CHANGED
@@ -1,2 +1,2 @@
-from .lmm import LMM, AnthropicLMM, AzureOpenAILMM, OllamaLMM, OpenAILMM
+from .lmm import LMM, AnthropicLMM, AzureOpenAILMM, GoogleLMM, OllamaLMM, OpenAILMM
 from .types import Message
vision_agent/lmm/lmm.py
CHANGED
@@ -50,6 +50,8 @@ class OpenAILMM(LMM):
         api_key: Optional[str] = None,
         max_tokens: int = 4096,
         json_mode: bool = False,
+        image_size: int = 768,
+        image_detail: str = "low",
         **kwargs: Any,
     ):
         if not api_key:
@@ -59,7 +61,10 @@ class OpenAILMM(LMM):

         self.client = OpenAI(api_key=api_key)
         self.model_name = model_name
-
+        self.image_size = image_size
+        self.image_detail = image_detail
+        # o1 does not use max_tokens
+        if "max_tokens" not in kwargs and not model_name.startswith("o1"):
             kwargs["max_tokens"] = max_tokens
         if json_mode:
             kwargs["response_format"] = {"type": "json_object"}
@@ -94,7 +99,13 @@ class OpenAILMM(LMM):
             fixed_c["content"] = [{"type": "text", "text": c["content"]}] # type: ignore
             if "media" in c:
                 for media in c["media"]:
-
+                    resize = kwargs["resize"] if "resize" in kwargs else self.image_size
+                    image_detail = (
+                        kwargs["image_detail"]
+                        if "image_detail" in kwargs
+                        else self.image_detail
+                    )
+                    encoded_media = encode_media(cast(str, media), resize=resize)

                     fixed_c["content"].append( # type: ignore
                         {
@@ -106,7 +117,7 @@ class OpenAILMM(LMM):
                                     or encoded_media.startswith("data:image/")
                                     else f"data:image/png;base64,{encoded_media}"
                                 ),
-                                "detail":
+                                "detail": image_detail,
                             },
                         },
                     )
@@ -144,7 +155,13 @@ class OpenAILMM(LMM):
         ]
         if media and len(media) > 0:
             for m in media:
-
+                resize = kwargs["resize"] if "resize" in kwargs else None
+                image_detail = (
+                    kwargs["image_detail"]
+                    if "image_detail" in kwargs
+                    else self.image_detail
+                )
+                encoded_media = encode_media(m, resize=resize)
                 message[0]["content"].append(
                     {
                         "type": "image_url",
@@ -155,7 +172,7 @@ class OpenAILMM(LMM):
                             or encoded_media.startswith("data:image/")
                             else f"data:image/png;base64,{encoded_media}"
                         ),
-                        "detail":
+                        "detail": image_detail,
                     },
                 },
             )
@@ -186,6 +203,7 @@ class AzureOpenAILMM(OpenAILMM):
         azure_endpoint: Optional[str] = None,
         max_tokens: int = 4096,
         json_mode: bool = False,
+        image_detail: str = "low",
         **kwargs: Any,
     ):
         if not api_key:
@@ -208,6 +226,7 @@ class AzureOpenAILMM(OpenAILMM):
             azure_endpoint=azure_endpoint,
         )
         self.model_name = model_name
+        self.image_detail = image_detail

         if "max_tokens" not in kwargs:
             kwargs["max_tokens"] = max_tokens
@@ -225,6 +244,7 @@ class OllamaLMM(LMM):
         base_url: Optional[str] = "http://localhost:11434/api",
         json_mode: bool = False,
         num_ctx: int = 128_000,
+        image_size: int = 768,
         **kwargs: Any,
     ):
         """Initializes the Ollama LMM. kwargs are passed as 'options' to the model.
@@ -241,6 +261,7 @@ class OllamaLMM(LMM):

         self.url = base_url
         self.model_name = model_name
+        self.image_size = image_size
         self.kwargs = {"options": kwargs}

         if json_mode:
@@ -273,8 +294,9 @@ class OllamaLMM(LMM):
         fixed_chat = []
         for message in chat:
             if "media" in message:
+                resize = kwargs["resize"] if "resize" in kwargs else self.image_size
                 message["images"] = [
-                    encode_media(cast(str, m)) for m in message["media"]
+                    encode_media(cast(str, m), resize=resize) for m in message["media"]
                 ]
                 del message["media"]
             fixed_chat.append(message)
@@ -328,7 +350,8 @@ class OllamaLMM(LMM):

         if media and len(media) > 0:
             for m in media:
-
+                resize = kwargs["resize"] if "resize" in kwargs else self.image_size
+                data["images"].append(encode_media(m, resize=resize))

         tmp_kwargs = self.kwargs | kwargs
         data.update(tmp_kwargs)
@@ -370,9 +393,11 @@ class AnthropicLMM(LMM):
         api_key: Optional[str] = None,
         model_name: str = "claude-3-5-sonnet-20240620",
         max_tokens: int = 4096,
+        image_size: int = 768,
         **kwargs: Any,
     ):
         self.client = anthropic.Anthropic(api_key=api_key)
+        self.image_size = image_size
         self.model_name = model_name
         if "max_tokens" not in kwargs:
             kwargs["max_tokens"] = max_tokens
@@ -399,7 +424,8 @@ class AnthropicLMM(LMM):
             ]
             if "media" in msg:
                 for media_path in msg["media"]:
-
+                    resize = kwargs["resize"] if "resize" in kwargs else self.image_size
+                    encoded_media = encode_media(media_path, resize=resize)
                     if encoded_media.startswith("data:image/png;base64,"):
                         encoded_media = encoded_media[len("data:image/png;base64,") :]
                     content.append(
@@ -448,7 +474,8 @@ class AnthropicLMM(LMM):
         ]
         if media:
             for m in media:
-
+                resize = kwargs["resize"] if "resize" in kwargs else self.image_size
+                encoded_media = encode_media(m, resize=resize)
                 if encoded_media.startswith("data:image/png;base64,"):
                     encoded_media = encoded_media[len("data:image/png;base64,") :]
                 content.append(
@@ -486,3 +513,30 @@ class AnthropicLMM(LMM):
             return f()
         else:
             return cast(str, response.content[0].text)
+
+
+class GoogleLMM(OpenAILMM):
+    r"""An LMM class for the Google LMMs."""
+
+    def __init__(
+        self,
+        api_key: Optional[str] = None,
+        model_name: str = "gemini-2.0-flash-exp",
+        max_tokens: int = 4096,
+        image_detail: str = "low",
+        image_size: int = 768,
+        **kwargs: Any,
+    ):
+        base_url = "https://generativelanguage.googleapis.com/v1beta/openai/"
+        if not api_key:
+            api_key = os.environ.get("GEMINI_API_KEY")
+
+        self.client = OpenAI(api_key=api_key, base_url=base_url)
+
+        self.model_name = model_name
+        self.image_size = image_size
+        self.image_detail = image_detail
+
+        if "max_tokens" not in kwargs:
+            kwargs["max_tokens"] = max_tokens
+        self.kwargs = kwargs
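The new GoogleLMM rides on the OpenAILMM implementation through Gemini's OpenAI-compatible endpoint, so it is constructed and called like the other backends. A minimal sketch; the image path is a placeholder, and GEMINI_API_KEY must be set, per the constructor above:

    from vision_agent.lmm import GoogleLMM

    lmm = GoogleLMM()  # reads GEMINI_API_KEY, defaults to "gemini-2.0-flash-exp"
    caption = lmm.generate("Describe this image in one sentence.", media=["frame.png"])
    print(caption)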
vision_agent/tools/planner_tools.py
CHANGED
@@ -1,7 +1,7 @@
 import inspect
 import logging
-import shutil
 import tempfile
+from concurrent.futures import ThreadPoolExecutor, as_completed
 from typing import Any, Callable, Dict, List, Optional, Tuple, cast

 import libcst as cst
@@ -24,6 +24,7 @@ from vision_agent.agent.vision_agent_planner_prompts_v2 import (
     TEST_TOOLS_EXAMPLE1,
     TEST_TOOLS_EXAMPLE2,
 )
+from vision_agent.configs import Config
 from vision_agent.lmm import LMM, AnthropicLMM
 from vision_agent.utils.execute import (
     CodeInterpreter,
@@ -36,6 +37,7 @@ from vision_agent.utils.sim import get_tool_recommender

 TOOL_FUNCTIONS = {tool.__name__: tool for tool in T.TOOLS}

+CONFIG = Config()
 _LOGGER = logging.getLogger(__name__)
 EXAMPLES = f"\n{TEST_TOOLS_EXAMPLE1}\n{TEST_TOOLS_EXAMPLE2}\n"

@@ -50,6 +52,54 @@ def format_tool_output(tool_thoughts: str, tool_docstring: str) -> str:
     return return_str


+def run_multi_judge(
+    tool_chooser: LMM,
+    tool_docs_str: str,
+    task: str,
+    code: str,
+    tool_output_str: str,
+    image_paths: List[str],
+) -> Tuple[Optional[Callable], str, str]:
+    error_message = ""
+    prompt = PICK_TOOL.format(
+        tool_docs=tool_docs_str,
+        user_request=task,
+        context=f"<code>\n{code}\n</code>\n<tool_output>\n{tool_output_str}\n</tool_output>",
+        previous_attempts=error_message,
+    )
+
+    def run_judge() -> Tuple[Optional[Callable], str, str]:
+        response = tool_chooser.generate(prompt, media=image_paths, temperature=1.0)
+        tool_choice_context = extract_tag(response, "json") # type: ignore
+        tool_choice_context_dict = extract_json(tool_choice_context) # type: ignore
+        tool, tool_thoughts, tool_docstring, _ = extract_tool_info(
+            tool_choice_context_dict
+        )
+        return tool, tool_thoughts, tool_docstring
+
+    responses = []
+    with ThreadPoolExecutor() as executor:
+        futures = [executor.submit(run_judge) for _ in range(3)]
+        for future in as_completed(futures):
+            responses.append(future.result())
+
+    responses = [r for r in responses if r[0] is not None]
+    counts: Dict[str, int] = {}
+    for tool, tool_thoughts, tool_docstring in responses:
+        if tool is not None:
+            counts[tool.__name__] = counts.get(tool.__name__, 0) + 1
+            if counts[tool.__name__] >= 2:
+                return tool, tool_thoughts, tool_docstring
+
+    if len(responses) == 0:
+        return (
+            None,
+            "No tool could be found, please try again with a different prompt or image",
+            "",
+        )
+    return responses[0]
+
+
 def extract_tool_info(
     tool_choice_context: Dict[str, Any],
 ) -> Tuple[Optional[Callable], str, str, str]:
@@ -212,7 +262,8 @@ def get_tool_for_task(
     --------
     >>> get_tool_for_task("Give me an OCR model that can find 'hot chocolate' in the image", [image])
     """
-
+    tool_tester = CONFIG.create_tool_tester()
+    tool_chooser = CONFIG.create_tool_chooser()

     with (
         tempfile.TemporaryDirectory() as tmpdirname,
@@ -225,45 +276,14 @@ def get_tool_for_task(
         image_paths.append(image_path)

         code, tool_docs_str, tool_output = run_tool_testing(
-            task, image_paths,
+            task, image_paths, tool_tester, exclude_tools, code_interpreter
         )
         tool_output_str = tool_output.text(include_results=False).strip()

-
-
-            tool_docs=tool_docs_str,
-            user_request=task,
-            context=f"<code>\n{code}\n</code>\n<tool_output>\n{tool_output_str}\n</tool_output>",
-            previous_attempts=error_message,
-        )
-
-        response = lmm.generate(prompt, media=image_paths)
-        tool_choice_context = extract_tag(response, "json") # type: ignore
-        tool_choice_context_dict = extract_json(tool_choice_context) # type: ignore
-
-        tool, tool_thoughts, tool_docstring, error_message = extract_tool_info(
-            tool_choice_context_dict
+        _, tool_thoughts, tool_docstring = run_multi_judge(
+            tool_chooser, tool_docs_str, task, code, tool_output_str, image_paths
         )

-        count = 1
-        while tool is None and count <= 3:
-            prompt = PICK_TOOL.format(
-                tool_docs=tool_docs_str,
-                user_request=task,
-                context=f"<code>\n{code}\n</code>\n<tool_output>\n{tool_output_str}\n</tool_output>",
-                previous_attempts=error_message,
-            )
-            tool_choice_context_dict = extract_json(
-                lmm.generate(prompt, media=image_paths) # type: ignore
-            )
-            tool, tool_thoughts, tool_docstring, error_message = extract_tool_info(
-                tool_choice_context_dict
-            )
-        try:
-            shutil.rmtree(tmpdirname)
-        except Exception as e:
-            _LOGGER.error(f"Error removing temp directory: {e}")
-
     print(format_tool_output(tool_thoughts, tool_docstring))


@@ -277,7 +297,7 @@ def get_tool_for_task_human_reviewer(
     task: str, images: List[np.ndarray], exclude_tools: Optional[List[str]] = None
 ) -> None:
     # NOTE: this will have the same documentation as get_tool_for_task
-
+    tool_tester = CONFIG.create_tool_tester()

     with (
         tempfile.TemporaryDirectory() as tmpdirname,
@@ -298,7 +318,7 @@ def get_tool_for_task_human_reviewer(
         _, _, tool_output = run_tool_testing(
             task,
             image_paths,
-
+            tool_tester,
             exclude_tools,
             code_interpreter,
             process_code=lambda x: replace_box_threshold(x, tools, 0.05),
@@ -349,7 +369,7 @@ def claude35_vqa(prompt: str, medias: List[np.ndarray]) -> None:
         medias: List[np.ndarray]: The images to ask the question about, it could also
             be frames from a video. You can send up to 5 frames from a video.
     """
-
+    vqa = CONFIG.create_vqa()
     if isinstance(medias, np.ndarray):
         medias = [medias]
     if isinstance(medias, list) and len(medias) > 5:
@@ -358,7 +378,7 @@ def claude35_vqa(prompt: str, medias: List[np.ndarray]) -> None:
         "data:image/png;base64," + convert_to_b64(media) for media in medias
     ]

-    response = cast(str,
+    response = cast(str, vqa.generate(prompt, media=all_media_b64))
     print(f"[claude35_vqa output]\n{response}\n[end of claude35_vqa output]")

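The call shape of get_tool_for_task is unchanged (see its docstring example above); only the internals moved to config-driven models plus a three-judge majority vote. A minimal sketch with a synthetic placeholder image:

    import numpy as np

    from vision_agent.tools.planner_tools import get_tool_for_task

    image = np.zeros((512, 512, 3), dtype=np.uint8)  # placeholder image
    get_tool_for_task(
        "Give me an OCR model that can find 'hot chocolate' in the image", [image]
    )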
vision_agent/tools/tool_utils.py
CHANGED
@@ -72,8 +72,7 @@ def send_inference_request(

     response = _call_post(url, payload, session, files, function_name, is_form)

-
-    return response if "TOOL_ENDPOINT_AUTH" in os.environ else response["data"]
+    return response["data"]


 def send_task_inference_request(
vision_agent/tools/tools.py
CHANGED
@@ -595,14 +595,14 @@ def owlv2_sam2_video_tracking(
 def florence2_object_detection(
     prompt: str, image: np.ndarray, fine_tune_id: Optional[str] = None
 ) -> List[Dict[str, Any]]:
-    """'florence2_object_detection' is a tool that can detect multiple
-
-
-
-    confidence scores of 1.0.
+    """'florence2_object_detection' is a tool that can detect multiple objects given a
+    text prompt which can be object names or caption. You can optionally separate the
+    object names in the text with commas. It returns a list of bounding boxes with
+    normalized coordinates, label names and associated confidence scores of 1.0.

     Parameters:
-        prompt (str): The prompt to ground to the image.
+        prompt (str): The prompt to ground to the image. Use exclusive categories that
+            do not overlap such as 'person, car' and NOT 'person, athlete'.
         image (np.ndarray): The image to used to detect objects
         fine_tune_id (Optional[str]): If you have a fine-tuned model, you can pass the
             fine-tuned model ID here to use it.
@@ -681,7 +681,8 @@ def florence2_sam2_instance_segmentation(
         1.0.

     Parameters:
-        prompt (str): The prompt to ground to the image.
+        prompt (str): The prompt to ground to the image. Use exclusive categories that
+            do not overlap such as 'person, car' and NOT 'person, athlete'.
         image (np.ndarray): The image to ground the prompt to.
         fine_tune_id (Optional[str]): If you have a fine-tuned model, you can pass the
             fine-tuned model ID here to use it.
@@ -769,7 +770,8 @@ def florence2_sam2_video_tracking(
     is useful for tracking and counting without duplicating counts.

     Parameters:
-        prompt (str): The prompt to ground to the
+        prompt (str): The prompt to ground to the image. Use exclusive categories that
+            do not overlap such as 'person, car' and NOT 'person, athlete'.
         frames (List[np.ndarray]): The list of frames to ground the prompt to.
         chunk_length (Optional[int]): The number of frames to re-run florence2 to find
             new objects.
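The docstring edits above all steer callers toward exclusive, non-overlapping category prompts. A minimal sketch of a call that follows that guidance; the image path is a placeholder, and the result keys are assumed from the docstring's description of boxes, labels, and scores:

    import numpy as np
    from PIL import Image

    from vision_agent.tools.tools import florence2_object_detection

    image = np.array(Image.open("street.jpg"))  # placeholder path
    detections = florence2_object_detection("person, car", image)
    for det in detections:
        print(det["label"], det["bbox"], det["score"])  # scores are always 1.0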