vision-agent 0.2.229__py3-none-any.whl → 0.2.231__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vision_agent/.sim_tools/df.csv +10 -8
- vision_agent/agent/agent_utils.py +10 -9
- vision_agent/agent/types.py +1 -0
- vision_agent/agent/vision_agent.py +3 -4
- vision_agent/agent/vision_agent_coder_prompts.py +6 -6
- vision_agent/agent/vision_agent_coder_v2.py +41 -26
- vision_agent/agent/vision_agent_planner_prompts.py +6 -6
- vision_agent/agent/vision_agent_planner_prompts_v2.py +16 -50
- vision_agent/agent/vision_agent_planner_v2.py +11 -12
- vision_agent/agent/vision_agent_prompts.py +11 -11
- vision_agent/agent/vision_agent_prompts_v2.py +18 -3
- vision_agent/agent/vision_agent_v2.py +29 -30
- vision_agent/configs/__init__.py +1 -0
- vision_agent/configs/anthropic_config.py +150 -0
- vision_agent/configs/anthropic_openai_config.py +150 -0
- vision_agent/configs/config.py +150 -0
- vision_agent/configs/openai_config.py +160 -0
- vision_agent/lmm/__init__.py +1 -1
- vision_agent/lmm/lmm.py +63 -9
- vision_agent/tools/__init__.py +4 -4
- vision_agent/tools/planner_tools.py +74 -48
- vision_agent/tools/tool_utils.py +3 -0
- vision_agent/tools/tools.py +49 -31
- vision_agent/utils/sim.py +33 -12
- vision_agent-0.2.231.dist-info/METADATA +148 -0
- vision_agent-0.2.231.dist-info/RECORD +52 -0
- vision_agent-0.2.229.dist-info/METADATA +0 -562
- vision_agent-0.2.229.dist-info/RECORD +0 -47
- {vision_agent-0.2.229.dist-info → vision_agent-0.2.231.dist-info}/LICENSE +0 -0
- {vision_agent-0.2.229.dist-info → vision_agent-0.2.231.dist-info}/WHEEL +0 -0
@@ -0,0 +1,160 @@
+from typing import Type
+
+from pydantic import BaseModel, Field
+
+from vision_agent.lmm import LMM, OpenAILMM
+
+
+class Config(BaseModel):
+    # for vision_agent_v2
+    agent: Type[LMM] = Field(default=OpenAILMM)
+    agent_kwargs: dict = Field(
+        default_factory=lambda: {
+            "model_name": "gpt-4o-2024-08-06",
+            "temperature": 0.0,
+            "image_size": 768,
+            "image_detail": "low",
+        }
+    )
+
+    # for vision_agent_planner_v2
+    planner: Type[LMM] = Field(default=OpenAILMM)
+    planner_kwargs: dict = Field(
+        default_factory=lambda: {
+            "model_name": "gpt-4o-2024-08-06",
+            "temperature": 0.0,
+            "image_size": 768,
+            "image_detail": "low",
+        }
+    )
+
+    # for vision_agent_planner_v2
+    summarizer: Type[LMM] = Field(default=OpenAILMM)
+    summarizer_kwargs: dict = Field(
+        default_factory=lambda: {
+            "model_name": "o1",
+            "temperature": 1.0,
+            "image_size": 768,
+        }
+    )
+
+    # for vision_agent_planner_v2
+    critic: Type[LMM] = Field(default=OpenAILMM)
+    critic_kwargs: dict = Field(
+        default_factory=lambda: {
+            "model_name": "gpt-4o-2024-08-06",
+            "temperature": 0.0,
+            "image_size": 768,
+            "image_detail": "low",
+        }
+    )
+
+    # for vision_agent_coder_v2
+    coder: Type[LMM] = Field(default=OpenAILMM)
+    coder_kwargs: dict = Field(
+        default_factory=lambda: {
+            "model_name": "gpt-4o-2024-08-06",
+            "temperature": 0.0,
+            "image_size": 768,
+            "image_detail": "low",
+        }
+    )
+
+    # for vision_agent_coder_v2
+    tester: Type[LMM] = Field(default=OpenAILMM)
+    tester_kwargs: dict = Field(
+        default_factory=lambda: {
+            "model_name": "gpt-4o-2024-08-06",
+            "temperature": 0.0,
+            "image_size": 768,
+            "image_detail": "low",
+        }
+    )
+
+    # for vision_agent_coder_v2
+    debugger: Type[LMM] = Field(default=OpenAILMM)
+    debugger_kwargs: dict = Field(
+        default_factory=lambda: {
+            "model_name": "gpt-4o-2024-08-06",
+            "temperature": 0.0,
+            "image_size": 768,
+            "image_detail": "low",
+        }
+    )
+
+    # for get_tool_for_task
+    tool_tester: Type[LMM] = Field(default=OpenAILMM)
+    tool_tester_kwargs: dict = Field(
+        default_factory=lambda: {
+            "model_name": "gpt-4o-2024-08-06",
+            "temperature": 0.0,
+            "image_size": 768,
+            "image_detail": "low",
+        }
+    )
+
+    # for get_tool_for_task
+    tool_chooser: Type[LMM] = Field(default=OpenAILMM)
+    tool_chooser_kwargs: dict = Field(
+        default_factory=lambda: {
+            "model_name": "gpt-4o-2024-08-06",
+            "temperature": 1.0,
+            "image_size": 768,
+            "image_detail": "low",
+        }
+    )
+
+    # for suggestions module
+    suggester: Type[LMM] = Field(default=OpenAILMM)
+    suggester_kwargs: dict = Field(
+        default_factory=lambda: {
+            "model_name": "gpt-4o-2024-08-06",
+            "temperature": 1.0,
+            "image_size": 768,
+            "image_detail": "low",
+        }
+    )
+
+    # for vqa module
+    vqa: Type[LMM] = Field(default=OpenAILMM)
+    vqa_kwargs: dict = Field(
+        default_factory=lambda: {
+            "model_name": "gpt-4o-2024-08-06",
+            "temperature": 0.0,
+            "image_size": 768,
+            "image_detail": "low",
+        }
+    )
+
+    def create_agent(self) -> LMM:
+        return self.agent(**self.agent_kwargs)
+
+    def create_planner(self) -> LMM:
+        return self.planner(**self.planner_kwargs)
+
+    def create_summarizer(self) -> LMM:
+        return self.summarizer(**self.summarizer_kwargs)
+
+    def create_critic(self) -> LMM:
+        return self.critic(**self.critic_kwargs)
+
+    def create_coder(self) -> LMM:
+        return self.coder(**self.coder_kwargs)
+
+    def create_tester(self) -> LMM:
+        return self.tester(**self.tester_kwargs)
+
+    def create_debugger(self) -> LMM:
+        return self.debugger(**self.debugger_kwargs)
+
+    def create_tool_tester(self) -> LMM:
+        return self.tool_tester(**self.tool_tester_kwargs)
+
+    def create_tool_chooser(self) -> LMM:
+        return self.tool_chooser(**self.tool_chooser_kwargs)
+
+    def create_suggester(self) -> LMM:
+        return self.suggester(**self.suggester_kwargs)
+
+    def create_vqa(self) -> LMM:
+        return self.vqa(**self.vqa_kwargs)
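For orientation, here is a minimal usage sketch of the new Config class. It mirrors the `CONFIG = Config()` pattern that appears in `planner_tools.py` later in this diff; the AnthropicLMM override is a hypothetical illustration, not code from the package.

```python
# Minimal sketch (not part of the diff) of how Config is consumed.
from vision_agent.configs import Config
from vision_agent.lmm import AnthropicLMM

config = Config()
planner = config.create_planner()  # OpenAILMM built from planner_kwargs
vqa = config.create_vqa()          # OpenAILMM("gpt-4o-2024-08-06", ...)

# Config is a pydantic model, so individual roles can be swapped out,
# e.g. running the planner on Anthropic while keeping the other defaults.
custom = Config(
    planner=AnthropicLMM,
    planner_kwargs={"model_name": "claude-3-5-sonnet-20240620", "temperature": 0.0},
)
custom_planner = custom.create_planner()
```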
vision_agent/lmm/__init__.py
CHANGED
@@ -1,2 +1,2 @@
-from .lmm import LMM, AnthropicLMM, AzureOpenAILMM, OllamaLMM, OpenAILMM
+from .lmm import LMM, AnthropicLMM, AzureOpenAILMM, GoogleLMM, OllamaLMM, OpenAILMM
 from .types import Message
vision_agent/lmm/lmm.py
CHANGED
@@ -50,6 +50,8 @@ class OpenAILMM(LMM):
         api_key: Optional[str] = None,
         max_tokens: int = 4096,
         json_mode: bool = False,
+        image_size: int = 768,
+        image_detail: str = "low",
         **kwargs: Any,
     ):
         if not api_key:
@@ -59,7 +61,10 @@ class OpenAILMM(LMM):
 
         self.client = OpenAI(api_key=api_key)
         self.model_name = model_name
-        if "max_tokens" not in kwargs:
+        self.image_size = image_size
+        self.image_detail = image_detail
+        # o1 does not use max_tokens
+        if "max_tokens" not in kwargs and not model_name.startswith("o1"):
             kwargs["max_tokens"] = max_tokens
         if json_mode:
             kwargs["response_format"] = {"type": "json_object"}
@@ -94,7 +99,13 @@ class OpenAILMM(LMM):
             fixed_c["content"] = [{"type": "text", "text": c["content"]}]  # type: ignore
             if "media" in c:
                 for media in c["media"]:
-
+                    resize = kwargs["resize"] if "resize" in kwargs else self.image_size
+                    image_detail = (
+                        kwargs["image_detail"]
+                        if "image_detail" in kwargs
+                        else self.image_detail
+                    )
+                    encoded_media = encode_media(cast(str, media), resize=resize)
 
                     fixed_c["content"].append(  # type: ignore
                         {
@@ -106,7 +117,7 @@ class OpenAILMM(LMM):
                                 or encoded_media.startswith("data:image/")
                                 else f"data:image/png;base64,{encoded_media}"
                             ),
-                            "detail": "low",
+                            "detail": image_detail,
                         },
                     },
                 )
@@ -144,7 +155,13 @@ class OpenAILMM(LMM):
         ]
         if media and len(media) > 0:
             for m in media:
-
+                resize = kwargs["resize"] if "resize" in kwargs else None
+                image_detail = (
+                    kwargs["image_detail"]
+                    if "image_detail" in kwargs
+                    else self.image_detail
+                )
+                encoded_media = encode_media(m, resize=resize)
                 message[0]["content"].append(
                     {
                         "type": "image_url",
@@ -155,7 +172,7 @@ class OpenAILMM(LMM):
                                 or encoded_media.startswith("data:image/")
                                 else f"data:image/png;base64,{encoded_media}"
                             ),
-                            "detail": "low",
+                            "detail": image_detail,
                         },
                     },
                 )
@@ -186,6 +203,7 @@ class AzureOpenAILMM(OpenAILMM):
         azure_endpoint: Optional[str] = None,
         max_tokens: int = 4096,
         json_mode: bool = False,
+        image_detail: str = "low",
         **kwargs: Any,
     ):
         if not api_key:
@@ -208,6 +226,7 @@ class AzureOpenAILMM(OpenAILMM):
             azure_endpoint=azure_endpoint,
         )
         self.model_name = model_name
+        self.image_detail = image_detail
 
         if "max_tokens" not in kwargs:
             kwargs["max_tokens"] = max_tokens
@@ -225,6 +244,7 @@ class OllamaLMM(LMM):
         base_url: Optional[str] = "http://localhost:11434/api",
         json_mode: bool = False,
         num_ctx: int = 128_000,
+        image_size: int = 768,
         **kwargs: Any,
     ):
         """Initializes the Ollama LMM. kwargs are passed as 'options' to the model.
@@ -241,6 +261,7 @@ class OllamaLMM(LMM):
 
         self.url = base_url
         self.model_name = model_name
+        self.image_size = image_size
         self.kwargs = {"options": kwargs}
 
         if json_mode:
@@ -273,8 +294,9 @@ class OllamaLMM(LMM):
         fixed_chat = []
         for message in chat:
             if "media" in message:
+                resize = kwargs["resize"] if "resize" in kwargs else self.image_size
                 message["images"] = [
-                    encode_media(cast(str, m)) for m in message["media"]
+                    encode_media(cast(str, m), resize=resize) for m in message["media"]
                 ]
                 del message["media"]
             fixed_chat.append(message)
@@ -328,7 +350,8 @@ class OllamaLMM(LMM):
 
         if media and len(media) > 0:
             for m in media:
-
+                resize = kwargs["resize"] if "resize" in kwargs else self.image_size
+                data["images"].append(encode_media(m, resize=resize))
 
         tmp_kwargs = self.kwargs | kwargs
         data.update(tmp_kwargs)
@@ -370,9 +393,11 @@ class AnthropicLMM(LMM):
         api_key: Optional[str] = None,
         model_name: str = "claude-3-5-sonnet-20240620",
         max_tokens: int = 4096,
+        image_size: int = 768,
         **kwargs: Any,
     ):
         self.client = anthropic.Anthropic(api_key=api_key)
+        self.image_size = image_size
         self.model_name = model_name
         if "max_tokens" not in kwargs:
             kwargs["max_tokens"] = max_tokens
@@ -399,7 +424,8 @@ class AnthropicLMM(LMM):
             ]
             if "media" in msg:
                 for media_path in msg["media"]:
-
+                    resize = kwargs["resize"] if "resize" in kwargs else self.image_size
+                    encoded_media = encode_media(media_path, resize=resize)
                     if encoded_media.startswith("data:image/png;base64,"):
                         encoded_media = encoded_media[len("data:image/png;base64,") :]
                     content.append(
@@ -448,7 +474,8 @@ class AnthropicLMM(LMM):
         ]
         if media:
             for m in media:
-
+                resize = kwargs["resize"] if "resize" in kwargs else self.image_size
+                encoded_media = encode_media(m, resize=resize)
                 if encoded_media.startswith("data:image/png;base64,"):
                     encoded_media = encoded_media[len("data:image/png;base64,") :]
                 content.append(
@@ -486,3 +513,30 @@ class AnthropicLMM(LMM):
             return f()
         else:
             return cast(str, response.content[0].text)
+
+
+class GoogleLMM(OpenAILMM):
+    r"""An LMM class for the Google LMMs."""
+
+    def __init__(
+        self,
+        api_key: Optional[str] = None,
+        model_name: str = "gemini-2.0-flash-exp",
+        max_tokens: int = 4096,
+        image_detail: str = "low",
+        image_size: int = 768,
+        **kwargs: Any,
+    ):
+        base_url = "https://generativelanguage.googleapis.com/v1beta/openai/"
+        if not api_key:
+            api_key = os.environ.get("GEMINI_API_KEY")
+
+        self.client = OpenAI(api_key=api_key, base_url=base_url)
+
+        self.model_name = model_name
+        self.image_size = image_size
+        self.image_detail = image_detail
+
+        if "max_tokens" not in kwargs:
+            kwargs["max_tokens"] = max_tokens
+        self.kwargs = kwargs
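A short sketch of the new constructor options, based only on the signatures shown in the hunks above (the image path is a placeholder and API keys are read from the environment):

```python
# Sketch of the new image-handling knobs on the LMM classes.
from vision_agent.lmm import GoogleLMM, OpenAILMM

# image_size is forwarded to encode_media as a resize target and image_detail
# becomes the "detail" field of each image_url part sent to the API.
lmm = OpenAILMM(image_size=768, image_detail="low")
print(lmm.generate("Describe this image", media=["example.png"]))

# GoogleLMM reuses the OpenAI-compatible client against the Gemini endpoint
# and falls back to the GEMINI_API_KEY environment variable when no api_key
# is passed.
gemini = GoogleLMM(model_name="gemini-2.0-flash-exp")
print(gemini.generate("Describe this image", media=["example.png"]))
```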
vision_agent/tools/__init__.py
CHANGED
@@ -23,6 +23,9 @@ from .tools import (
     TOOLS_INFO,
     UTIL_TOOLS,
     UTILITIES_DOCSTRING,
+    agentic_object_detection,
+    agentic_sam2_instance_segmentation,
+    agentic_sam2_video_tracking,
     claude35_text_extraction,
     closest_box_distance,
     closest_mask_distance,
@@ -30,6 +33,7 @@ from .tools import (
     countgd_sam2_instance_segmentation,
     countgd_sam2_video_tracking,
     countgd_visual_prompt_object_detection,
+    custom_object_detection,
     depth_anything_v2,
     detr_segmentation,
     document_extraction,
@@ -63,10 +67,6 @@ from .tools import (
     video_temporal_localization,
     vit_image_classification,
     vit_nsfw_classification,
-    custom_object_detection,
-    agentic_object_detection,
-    agentic_sam2_instance_segmentation,
-    agentic_sam2_video_tracking,
 )
 
 __new_tools__ = [
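This change only moves the agentic detection tools and custom_object_detection into alphabetical order within the public import list. A hedged usage sketch; the (prompt, image) calling convention is assumed from the other detection tools in this module, and the image is a placeholder:

```python
# Sketch of importing one of the re-exported tools; a real call requires the
# usual vision-agent API credentials.
import numpy as np

from vision_agent.tools import agentic_object_detection

image = np.zeros((512, 512, 3), dtype=np.uint8)  # placeholder image
detections = agentic_object_detection("person", image)
print(detections)
```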
vision_agent/tools/planner_tools.py
CHANGED
@@ -1,7 +1,7 @@
 import inspect
 import logging
-import shutil
 import tempfile
+from concurrent.futures import ThreadPoolExecutor, as_completed
 from typing import Any, Callable, Dict, List, Optional, Tuple, cast
 
 import libcst as cst
@@ -10,12 +10,7 @@ from IPython.display import display
 from PIL import Image
 
 import vision_agent.tools as T
-from vision_agent.agent.agent_utils import (
-    DefaultImports,
-    extract_code,
-    extract_json,
-    extract_tag,
-)
+from vision_agent.agent.agent_utils import DefaultImports, extract_json, extract_tag
 from vision_agent.agent.vision_agent_planner_prompts_v2 import (
     CATEGORIZE_TOOL_REQUEST,
     FINALIZE_PLAN,
@@ -24,6 +19,7 @@ from vision_agent.agent.vision_agent_planner_prompts_v2 import (
     TEST_TOOLS_EXAMPLE1,
     TEST_TOOLS_EXAMPLE2,
 )
+from vision_agent.configs import Config
 from vision_agent.lmm import LMM, AnthropicLMM
 from vision_agent.utils.execute import (
     CodeInterpreter,
@@ -35,7 +31,11 @@ from vision_agent.utils.image_utils import convert_to_b64
 from vision_agent.utils.sim import get_tool_recommender
 
 TOOL_FUNCTIONS = {tool.__name__: tool for tool in T.TOOLS}
+LOAD_TOOLS_DOCSTRING = T.get_tool_documentation(
+    [T.load_image, T.extract_frames_and_timestamps]
+)
 
+CONFIG = Config()
 _LOGGER = logging.getLogger(__name__)
 EXAMPLES = f"\n{TEST_TOOLS_EXAMPLE1}\n{TEST_TOOLS_EXAMPLE2}\n"
 
@@ -50,6 +50,54 @@ def format_tool_output(tool_thoughts: str, tool_docstring: str) -> str:
     return return_str
 
 
+def run_multi_judge(
+    tool_chooser: LMM,
+    tool_docs_str: str,
+    task: str,
+    code: str,
+    tool_output_str: str,
+    image_paths: List[str],
+) -> Tuple[Optional[Callable], str, str]:
+    error_message = ""
+    prompt = PICK_TOOL.format(
+        tool_docs=tool_docs_str,
+        user_request=task,
+        context=f"<code>\n{code}\n</code>\n<tool_output>\n{tool_output_str}\n</tool_output>",
+        previous_attempts=error_message,
+    )
+
+    def run_judge() -> Tuple[Optional[Callable], str, str]:
+        response = tool_chooser.generate(prompt, media=image_paths, temperature=1.0)
+        tool_choice_context = extract_tag(response, "json")  # type: ignore
+        tool_choice_context_dict = extract_json(tool_choice_context)  # type: ignore
+        tool, tool_thoughts, tool_docstring, _ = extract_tool_info(
+            tool_choice_context_dict
+        )
+        return tool, tool_thoughts, tool_docstring
+
+    responses = []
+    with ThreadPoolExecutor() as executor:
+        futures = [executor.submit(run_judge) for _ in range(3)]
+        for future in as_completed(futures):
+            responses.append(future.result())
+
+    responses = [r for r in responses if r[0] is not None]
+    counts: Dict[str, int] = {}
+    for tool, tool_thoughts, tool_docstring in responses:
+        if tool is not None:
+            counts[tool.__name__] = counts.get(tool.__name__, 0) + 1
+            if counts[tool.__name__] >= 2:
+                return tool, tool_thoughts, tool_docstring
+
+    if len(responses) == 0:
+        return (
+            None,
+            "No tool could be found, please try again with a different prompt or image",
+            "",
+        )
+    return responses[0]
+
+
 def extract_tool_info(
     tool_choice_context: Dict[str, Any],
 ) -> Tuple[Optional[Callable], str, str, str]:
@@ -129,6 +177,7 @@ def run_tool_testing(
         cleaned_tool_docs.append(tool_doc)
     tool_docs = cleaned_tool_docs
     tool_docs_str = "\n".join([e["doc"] for e in tool_docs])
+    tool_docs_str += "\n" + LOAD_TOOLS_DOCSTRING
 
     prompt = TEST_TOOLS.format(
         tool_docs=tool_docs_str,
@@ -167,8 +216,15 @@ def run_tool_testing(
         examples=EXAMPLES,
         media=str(image_paths),
     )
-
-    code =
+    response = cast(str, lmm.generate(prompt, media=image_paths))
+    code = extract_tag(response, "code")
+    if code is None:
+        code = response
+
+    try:
+        code = process_code(code)
+    except Exception as e:
+        _LOGGER.error(f"Error processing code: {e}")
     tool_output = code_interpreter.exec_isolation(
         DefaultImports.prepend_imports(code)
     )
@@ -212,7 +268,8 @@ def get_tool_for_task(
     --------
     >>> get_tool_for_task("Give me an OCR model that can find 'hot chocolate' in the image", [image])
     """
-
+    tool_tester = CONFIG.create_tool_tester()
+    tool_chooser = CONFIG.create_tool_chooser()
 
     with (
         tempfile.TemporaryDirectory() as tmpdirname,
@@ -225,45 +282,14 @@ def get_tool_for_task(
             image_paths.append(image_path)
 
         code, tool_docs_str, tool_output = run_tool_testing(
-            task, image_paths,
+            task, image_paths, tool_tester, exclude_tools, code_interpreter
         )
         tool_output_str = tool_output.text(include_results=False).strip()
 
-
-
-            tool_docs=tool_docs_str,
-            user_request=task,
-            context=f"<code>\n{code}\n</code>\n<tool_output>\n{tool_output_str}\n</tool_output>",
-            previous_attempts=error_message,
-        )
-
-        response = lmm.generate(prompt, media=image_paths)
-        tool_choice_context = extract_tag(response, "json")  # type: ignore
-        tool_choice_context_dict = extract_json(tool_choice_context)  # type: ignore
-
-        tool, tool_thoughts, tool_docstring, error_message = extract_tool_info(
-            tool_choice_context_dict
+        _, tool_thoughts, tool_docstring = run_multi_judge(
+            tool_chooser, tool_docs_str, task, code, tool_output_str, image_paths
         )
 
-        count = 1
-        while tool is None and count <= 3:
-            prompt = PICK_TOOL.format(
-                tool_docs=tool_docs_str,
-                user_request=task,
-                context=f"<code>\n{code}\n</code>\n<tool_output>\n{tool_output_str}\n</tool_output>",
-                previous_attempts=error_message,
-            )
-            tool_choice_context_dict = extract_json(
-                lmm.generate(prompt, media=image_paths)  # type: ignore
-            )
-            tool, tool_thoughts, tool_docstring, error_message = extract_tool_info(
-                tool_choice_context_dict
-            )
-        try:
-            shutil.rmtree(tmpdirname)
-        except Exception as e:
-            _LOGGER.error(f"Error removing temp directory: {e}")
-
     print(format_tool_output(tool_thoughts, tool_docstring))
 
 
@@ -277,7 +303,7 @@ def get_tool_for_task_human_reviewer(
     task: str, images: List[np.ndarray], exclude_tools: Optional[List[str]] = None
 ) -> None:
     # NOTE: this will have the same documentation as get_tool_for_task
-
+    tool_tester = CONFIG.create_tool_tester()
 
     with (
         tempfile.TemporaryDirectory() as tmpdirname,
@@ -298,7 +324,7 @@ def get_tool_for_task_human_reviewer(
         _, _, tool_output = run_tool_testing(
             task,
             image_paths,
-
+            tool_tester,
             exclude_tools,
             code_interpreter,
             process_code=lambda x: replace_box_threshold(x, tools, 0.05),
@@ -349,7 +375,7 @@ def claude35_vqa(prompt: str, medias: List[np.ndarray]) -> None:
         medias: List[np.ndarray]: The images to ask the question about, it could also
             be frames from a video. You can send up to 5 frames from a video.
     """
-
+    vqa = CONFIG.create_vqa()
     if isinstance(medias, np.ndarray):
         medias = [medias]
     if isinstance(medias, list) and len(medias) > 5:
@@ -358,7 +384,7 @@ def claude35_vqa(prompt: str, medias: List[np.ndarray]) -> None:
         "data:image/png;base64," + convert_to_b64(media) for media in medias
     ]
 
-    response = cast(str,
+    response = cast(str, vqa.generate(prompt, media=all_media_b64))
     print(f"[claude35_vqa output]\n{response}\n[end of claude35_vqa output]")
 
 
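The largest behavioral change above is in get_tool_for_task: the old sequential retry loop (up to three attempts with error feedback) is replaced by run_multi_judge, which runs three tool-choice judgments in parallel and accepts the first tool that two judges agree on. A generic sketch of that voting pattern, separate from the package code:

```python
# Generic best-of-three sketch: run the same judgment concurrently, return the
# first choice two judges agree on, otherwise fall back to the first valid one.
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Callable, Dict, List, Optional, Tuple


def majority_vote(
    run_judge: Callable[[], Tuple[Optional[str], str]], n: int = 3
) -> Tuple[Optional[str], str]:
    responses: List[Tuple[Optional[str], str]] = []
    with ThreadPoolExecutor() as executor:
        futures = [executor.submit(run_judge) for _ in range(n)]
        for future in as_completed(futures):
            responses.append(future.result())

    # Drop failed judgments and look for a two-vote agreement.
    responses = [r for r in responses if r[0] is not None]
    counts: Dict[str, int] = {}
    for choice, reasoning in responses:
        counts[choice] = counts.get(choice, 0) + 1
        if counts[choice] >= 2:
            return choice, reasoning

    if not responses:
        return None, "no judge produced a valid answer"
    return responses[0]
```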
vision_agent/tools/tool_utils.py
CHANGED
@@ -318,6 +318,9 @@ def single_nms(
 def nms(
     all_preds: List[List[Dict[str, Any]]], iou_threshold: float
 ) -> List[List[Dict[str, Any]]]:
+    if not isinstance(all_preds[0], List):
+        all_preds = [all_preds]
+
     return_preds = []
     for frame_preds in all_preds:
         frame_preds = single_nms(frame_preds, iou_threshold)