vision-agent 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vision_agent/agent/easytool_prompts.py +9 -1
- vision_agent/agent/reflexion.py +14 -6
- vision_agent/agent/vision_agent.py +34 -10
- vision_agent/agent/vision_agent_prompts.py +5 -4
- vision_agent/data/data.py +2 -2
- vision_agent/image_utils.py +3 -1
- vision_agent/lmm/lmm.py +58 -44
- vision_agent/tools/__init__.py +3 -2
- vision_agent/tools/tools.py +132 -30
- vision_agent/tools/video.py +14 -12
- {vision_agent-0.1.2.dist-info → vision_agent-0.1.4.dist-info}/METADATA +1 -1
- {vision_agent-0.1.2.dist-info → vision_agent-0.1.4.dist-info}/RECORD +14 -14
- {vision_agent-0.1.2.dist-info → vision_agent-0.1.4.dist-info}/LICENSE +0 -0
- {vision_agent-0.1.2.dist-info → vision_agent-0.1.4.dist-info}/WHEEL +0 -0
vision_agent/agent/easytool_prompts.py
CHANGED
@@ -56,6 +56,7 @@ Example 2: {{"Parameters":[{{"input": [1,2,3]}}, {{"input": [2,3,4]}}]}}
 
 These are logs of previous questions and answers:
 {previous_log}
+
 This is the current user's question: {question}
 This is the API tool documentation: {tool_usage}
 Output: """
@@ -67,15 +68,22 @@ Please note that:
 2. We will not show the API response to the user, thus you need to make full use of the response and give the information in the response that can satisfy the user's question in as much detail as possible.
 3. If the API tool does not provide useful information in the response, please answer with your knowledge.
 4. The question may have dependencies on answers of other questions, so we will provide logs of previous questions and answers.
+
 These are logs of previous questions and answers:
 {previous_log}
+
 This is the user's question: {question}
+
 This is the response output by the API tool:
 {call_results}
+
 We will not show the API response to the user, thus you need to make full use of the response and give the information in the response that can satisfy the user's question in as much detail as possible.
 Output: """
 
 ANSWER_SUMMARIZE = """We break down a complex user's problems into simple subtasks and provide answers to each simple subtask. You need to organize these answers to each subtask and form a self-consistent final answer to the user's question.
 This is the user's question: {question}
-These are subtasks and their answers: {answers}
+
+These are subtasks and their answers:
+{answers}
+
 Final answer: """
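The placeholders in these templates are filled in with str.format by the agent code before the prompt is sent to the model. A minimal sketch of that step using the reworked ANSWER_SUMMARIZE layout above (the question and answers values are made up; only the placeholder names come from the diff):

ANSWER_SUMMARIZE = (
    "We break down a complex user's problems into simple subtasks and provide answers "
    "to each simple subtask. You need to organize these answers to each subtask and "
    "form a self-consistent final answer to the user's question.\n"
    "This is the user's question: {question}\n"
    "\n"
    "These are subtasks and their answers:\n"
    "{answers}\n"
    "\n"
    "Final answer: "
)

prompt = ANSWER_SUMMARIZE.format(
    question="How many balloons are in the image?",
    answers="Subtask 1: grounding_dino_ detected 3 balloons in input.jpg.",
)
print(prompt)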
vision_agent/agent/reflexion.py
CHANGED
@@ -238,12 +238,20 @@ class Reflexion(Agent):
                     self._build_agent_prompt(question, reflections, scratchpad)
                 )
             )
-        return format_step(
-            self.action_agent(
-                self._build_agent_prompt(question, reflections, scratchpad),
-                image,
+        elif isinstance(self.action_agent, LMM):
+            return format_step(
+                self.action_agent(
+                    self._build_agent_prompt(question, reflections, scratchpad),
+                    images=[image] if image is not None else None,
+                )
+            )
+        elif isinstance(self.action_agent, Agent):
+            return format_step(
+                self.action_agent(
+                    self._build_agent_prompt(question, reflections, scratchpad),
+                    image=image,
+                )
             )
-        )
 
     def prompt_reflection(
         self,
@@ -261,7 +269,7 @@ class Reflexion(Agent):
         return format_step(
             self.self_reflect_model(
                 self._build_reflect_prompt(question, context, scratchpad),
-                image,
+                images=[image] if image is not None else None,
             )
         )
 
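The new elif chain means the call shape now depends on the action agent's type: an LLM gets the prompt only, an LMM gets the prompt plus an images list, and a nested Agent gets a single image keyword. A self-contained sketch of that dispatch with stand-in classes (FakeLLM and FakeLMM are illustrative only, not the package's real LLM/LMM implementations):

from typing import List, Optional

class FakeLLM:
    def __call__(self, prompt: str) -> str:
        return f"llm({prompt!r})"

class FakeLMM:
    def __call__(self, prompt: str, images: Optional[List[str]] = None) -> str:
        return f"lmm({prompt!r}, images={images})"

def prompt_action_agent(action_agent, prompt: str, image: Optional[str]) -> str:
    # Mirrors the branching above: text-only models get the bare prompt,
    # multimodal models get an images list (or None when no image was given).
    if isinstance(action_agent, FakeLLM):
        return action_agent(prompt)
    return action_agent(prompt, images=[image] if image is not None else None)

print(prompt_action_agent(FakeLLM(), "Is the baby on the bed?", "frame.png"))
print(prompt_action_agent(FakeLMM(), "Is the baby on the bed?", "frame.png"))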
vision_agent/agent/vision_agent.py
CHANGED
@@ -3,7 +3,7 @@ import logging
 import sys
 import tempfile
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
 
 from PIL import Image
 from tabulate import tabulate
@@ -264,7 +264,7 @@ def self_reflect(
     tools: Dict[int, Any],
     tool_result: List[Dict],
     final_answer: str,
-    image: Optional[Union[str, Path]] = None,
+    images: Optional[Sequence[Union[str, Path]]] = None,
 ) -> str:
     prompt = VISION_AGENT_REFLECTION.format(
         question=question,
@@ -275,10 +275,10 @@ def self_reflect(
     )
     if (
         issubclass(type(reflect_model), LMM)
-        and image is not None
-        and Path(image).suffix in [".jpg", ".jpeg", ".png"]
+        and images is not None
+        and all([Path(image).suffix in [".jpg", ".jpeg", ".png"] for image in images])
     ):
-        return reflect_model(prompt, image)
+        return reflect_model(prompt, images=images)  # type: ignore
     return reflect_model(prompt)
 
 
@@ -357,7 +357,7 @@ def _handle_viz_tools(
     return image_to_data
 
 
-def visualize_result(all_tool_results: List[Dict]) -> List[str]:
+def visualize_result(all_tool_results: List[Dict]) -> Sequence[Union[str, Path]]:
     image_to_data: Dict[str, Dict] = {}
     for tool_result in all_tool_results:
         # only handle bbox/mask tools or frame extraction
@@ -365,6 +365,7 @@ def visualize_result(all_tool_results: List[Dict]) -> List[str]:
             "grounding_sam_",
             "grounding_dino_",
             "extract_frames_",
+            "dinov_",
         ]:
             continue
 
@@ -407,7 +408,7 @@ class VisionAgent(Agent):
         task_model: Optional[Union[LLM, LMM]] = None,
         answer_model: Optional[Union[LLM, LMM]] = None,
         reflect_model: Optional[Union[LLM, LMM]] = None,
-        max_retries: int =
+        max_retries: int = 3,
         verbose: bool = False,
         report_progress_callback: Optional[Callable[[str], None]] = None,
     ):
@@ -444,6 +445,7 @@ class VisionAgent(Agent):
         self,
         input: Union[List[Dict[str, str]], str],
         image: Optional[Union[str, Path]] = None,
+        reference_data: Optional[Dict[str, str]] = None,
         visualize_output: Optional[bool] = False,
     ) -> str:
         """Invoke the vision agent.
@@ -458,7 +460,12 @@ class VisionAgent(Agent):
         """
         if isinstance(input, str):
             input = [{"role": "user", "content": input}]
-        return self.chat(input, image=image, visualize_output=visualize_output)
+        return self.chat(
+            input,
+            image=image,
+            visualize_output=visualize_output,
+            reference_data=reference_data,
+        )
 
     def log_progress(self, description: str) -> None:
         _LOGGER.info(description)
@@ -469,11 +476,18 @@ class VisionAgent(Agent):
         self,
         chat: List[Dict[str, str]],
         image: Optional[Union[str, Path]] = None,
+        reference_data: Optional[Dict[str, str]] = None,
         visualize_output: Optional[bool] = False,
     ) -> Tuple[str, List[Dict]]:
         question = chat[0]["content"]
         if image:
             question += f" Image name: {image}"
+        if reference_data:
+            if not ("image" in reference_data and "mask" in reference_data):
+                raise ValueError(
+                    f"Reference data must contain 'image' and 'mask'. but got {reference_data}"
+                )
+            question += f" Reference image: {reference_data['image']}, Reference mask: {reference_data['mask']}"
 
         reflections = ""
         final_answer = ""
@@ -519,13 +533,19 @@ class VisionAgent(Agent):
 
             visualized_output = visualize_result(all_tool_results)
             all_tool_results.append({"visualized_output": visualized_output})
+            if len(visualized_output) > 0:
+                reflection_images = visualized_output
+            elif image is not None:
+                reflection_images = [image]
+            else:
+                reflection_images = None
             reflection = self_reflect(
                 self.reflect_model,
                 question,
                 self.tools,
                 all_tool_results,
                 final_answer,
-                image,
+                reflection_images,
             )
             self.log_progress(f"Reflection: {reflection}")
             parsed_reflection = parse_reflect(reflection)
@@ -549,10 +569,14 @@ class VisionAgent(Agent):
         self,
         chat: List[Dict[str, str]],
         image: Optional[Union[str, Path]] = None,
+        reference_data: Optional[Dict[str, str]] = None,
         visualize_output: Optional[bool] = False,
     ) -> str:
         answer, _ = self.chat_with_workflow(
-            chat, image=image, visualize_output=visualize_output
+            chat,
+            image=image,
+            visualize_output=visualize_output,
+            reference_data=reference_data,
        )
         return answer
 
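Together with the DINOv tool added in tools.py below, the new reference_data argument lets a caller hand the agent a reference image/mask pair, which chat_with_workflow validates and appends to the question text. A hedged usage sketch (the constructor call and file names are illustrative; only the keyword arguments and the required 'image'/'mask' keys come from the diff):

import vision_agent as va

agent = va.agent.VisionAgent(verbose=True)  # assumed construction; defaults per the diff
answer = agent(
    "Can you find all the balloons in this image that are similar to the masked area?",
    image="input.jpg",  # hypothetical paths
    reference_data={"image": "balloon.jpg", "mask": "balloon_mask.jpg"},
    visualize_output=False,
)
print(answer)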
vision_agent/agent/vision_agent_prompts.py
CHANGED
@@ -1,11 +1,11 @@
-VISION_AGENT_REFLECTION = """You are an advanced reasoning agent that can improve based on self-refection. You will be given a previous reasoning trial in which you were given the user's question, the available tools that the agent has, the decomposed tasks and tools that the agent used to answer the question and the final answer the agent provided. You may also receive an image with the visualized bounding boxes or masks with their associated labels and scores from the tools used.
+VISION_AGENT_REFLECTION = """You are an advanced reasoning agent that can improve based on self-refection. You will be given a previous reasoning trial in which you were given the user's question, the available tools that the agent has, the decomposed tasks and tools that the agent used to answer the question, the tool usage for each of the tools used and the final answer the agent provided. You may also receive an image with the visualized bounding boxes or masks with their associated labels and scores from the tools used.
 
 Please note that:
 1. You must ONLY output parsible JSON format. If the agents output was correct set "Finish" to true, else set "Finish" to false. An example output looks like:
 {{"Finish": true, "Reflection": "The agent's answer was correct."}}
-2. You must utilize the image with the visualized bounding boxes or masks and determine if the tools were used correctly or
-3. If the agent's answer was incorrect, you must diagnose
-{{"Finish": false, "Reflection": "I can see from
+2. You must utilize the image with the visualized bounding boxes or masks and determine if the tools were used correctly or if the tools were used incorrectly or the wrong tools were used.
+3. If the agent's answer was incorrect, you must diagnose the reason for failure and devise a new concise and concrete plan that aims to mitigate the same failure with the tools available. An example output looks like:
+{{"Finish": false, "Reflection": "I can see from the visualized bounding boxes that the agent's answer was incorrect because the grounding_dino_ tool produced false positive predictions. The agent should use the following tools with the following parameters:
 Step 1: Use 'grounding_dino_' with a 'prompt' of 'baby. bed' and a 'box_threshold' of 0.7 to reduce the false positives.
 Step 2: Use 'box_iou_' with the baby bounding box and the bed bounding box to determine if the baby is on the bed or not."}}
 4. If the task cannot be completed with the existing tools or by adjusting the parameters, set "Finish" to true.
@@ -140,4 +140,5 @@ These are subtasks and their answers:
 
 This is a reflection from a previous failed attempt:
 {reflections}
+
 Final answer: """
vision_agent/data/data.py
CHANGED
@@ -63,9 +63,9 @@ class DataStore:
 
         self.df[name] = self.df["image_paths"].progress_apply(  # type: ignore
             lambda x: (
-                func(self.lmm.generate(prompt, image=x))
+                func(self.lmm.generate(prompt, images=[x]))
                 if func
-                else self.lmm.generate(prompt, image=x)
+                else self.lmm.generate(prompt, images=[x])
             )
         )
         return self
vision_agent/image_utils.py
CHANGED
@@ -103,7 +103,9 @@ def overlay_bboxes(
     elif isinstance(image, np.ndarray):
         image = Image.fromarray(image)
 
-    color = {label: COLORS[i % len(COLORS)] for i, label in enumerate(set(bboxes["labels"]))}
+    color = {
+        label: COLORS[i % len(COLORS)] for i, label in enumerate(set(bboxes["labels"]))
+    }
 
     width, height = image.size
     fontsize = max(12, int(min(width, height) / 40))
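The reflowed comprehension simply cycles a fixed palette over the distinct labels. A tiny standalone sketch with a made-up three-colour palette (image_utils.py defines its own COLORS list):

COLORS = [(255, 0, 0), (0, 255, 0), (0, 0, 255)]  # hypothetical palette
labels = ["baby", "bed", "baby", "pillow", "lamp"]

# One colour per distinct label, wrapping around when the palette runs out.
color = {label: COLORS[i % len(COLORS)] for i, label in enumerate(set(labels))}
print(color)  # four distinct labels, so the first palette colour is reused once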
vision_agent/lmm/lmm.py
CHANGED
@@ -30,12 +30,16 @@ def encode_image(image: Union[str, Path]) -> str:
 
 class LMM(ABC):
     @abstractmethod
-    def generate(self, prompt: str, image: Optional[Union[str, Path]] = None) -> str:
+    def generate(
+        self, prompt: str, images: Optional[List[Union[str, Path]]] = None
+    ) -> str:
         pass
 
     @abstractmethod
     def chat(
-        self, chat: List[Dict[str, str]], image: Optional[Union[str, Path]] = None
+        self,
+        chat: List[Dict[str, str]],
+        images: Optional[List[Union[str, Path]]] = None,
     ) -> str:
         pass
 
@@ -43,7 +47,7 @@ class LMM(ABC):
     def __call__(
         self,
         input: Union[str, List[Dict[str, str]]],
-        image: Optional[Union[str, Path]] = None,
+        images: Optional[List[Union[str, Path]]] = None,
     ) -> str:
         pass
 
@@ -57,27 +61,29 @@ class LLaVALMM(LMM):
     def __call__(
         self,
         input: Union[str, List[Dict[str, str]]],
-        image: Optional[Union[str, Path]] = None,
+        images: Optional[List[Union[str, Path]]] = None,
     ) -> str:
         if isinstance(input, str):
-            return self.generate(input, image)
-        return self.chat(input, image)
+            return self.generate(input, images)
+        return self.chat(input, images)
 
     def chat(
-        self, chat: List[Dict[str, str]], image: Optional[Union[str, Path]] = None
+        self,
+        chat: List[Dict[str, str]],
+        images: Optional[List[Union[str, Path]]] = None,
     ) -> str:
         raise NotImplementedError("Chat not supported for LLaVA")
 
     def generate(
         self,
         prompt: str,
-        image: Optional[Union[str, Path]] = None,
+        images: Optional[List[Union[str, Path]]] = None,
         temperature: float = 0.1,
         max_new_tokens: int = 1500,
     ) -> str:
         data = {"prompt": prompt}
-        if image:
-            data["image"] = encode_image(image)
+        if images and len(images) > 0:
+            data["image"] = encode_image(images[0])
         data["temperature"] = temperature  # type: ignore
         data["max_new_tokens"] = max_new_tokens  # type: ignore
         res = requests.post(
@@ -121,14 +127,16 @@ class OpenAILMM(LMM):
     def __call__(
         self,
         input: Union[str, List[Dict[str, str]]],
-        image: Optional[Union[str, Path]] = None,
+        images: Optional[List[Union[str, Path]]] = None,
     ) -> str:
         if isinstance(input, str):
-            return self.generate(input, image)
-        return self.chat(input, image)
+            return self.generate(input, images)
+        return self.chat(input, images)
 
     def chat(
-        self, chat: List[Dict[str, str]], image: Optional[Union[str, Path]] = None
+        self,
+        chat: List[Dict[str, str]],
+        images: Optional[List[Union[str, Path]]] = None,
     ) -> str:
         fixed_chat = []
         for c in chat:
@@ -136,25 +144,26 @@ class OpenAILMM(LMM):
             fixed_c["content"] = [{"type": "text", "text": c["content"]}]  # type: ignore
             fixed_chat.append(fixed_c)
 
-        if image:
-            extension = Path(image).suffix
-            if extension.lower() == ".jpeg" or extension.lower() == ".jpg":
-                extension = "jpg"
-            elif extension.lower() == ".png":
-                extension = "png"
-            else:
-                raise ValueError(f"Unsupported image extension: {extension}")
-
-            encoded_image = encode_image(image)
-            fixed_chat[0]["content"].append(  # type: ignore
-                {
-                    "type": "image_url",
-                    "image_url": {
-                        "url": f"data:image/{extension};base64,{encoded_image}",
-                        "detail": "low",
+        if images and len(images) > 0:
+            for image in images:
+                extension = Path(image).suffix
+                if extension.lower() == ".jpeg" or extension.lower() == ".jpg":
+                    extension = "jpg"
+                elif extension.lower() == ".png":
+                    extension = "png"
+                else:
+                    raise ValueError(f"Unsupported image extension: {extension}")
+
+                encoded_image = encode_image(image)
+                fixed_chat[0]["content"].append(  # type: ignore
+                    {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": f"data:image/{extension};base64,{encoded_image}",
+                            "detail": "low",
+                        },
                     },
-                },
-            )
+                )
 
         response = self.client.chat.completions.create(
             model=self.model_name, messages=fixed_chat, **self.kwargs  # type: ignore
@@ -162,7 +171,11 @@ class OpenAILMM(LMM):
 
         return cast(str, response.choices[0].message.content)
 
-    def generate(self, prompt: str, image: Optional[Union[str, Path]] = None) -> str:
+    def generate(
+        self,
+        prompt: str,
+        images: Optional[List[Union[str, Path]]] = None,
+    ) -> str:
         message: List[Dict[str, Any]] = [
             {
                 "role": "user",
@@ -171,18 +184,19 @@ class OpenAILMM(LMM):
             ],
             }
         ]
-        if image:
-            extension = Path(image).suffix
-            encoded_image = encode_image(image)
-            message[0]["content"].append(
-                {
-                    "type": "image_url",
-                    "image_url": {
-                        "url": f"data:image/{extension};base64,{encoded_image}",
-                        "detail": "low",
+        if images and len(images) > 0:
+            for image in images:
+                extension = Path(image).suffix
+                encoded_image = encode_image(image)
+                message[0]["content"].append(
+                    {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": f"data:image/{extension};base64,{encoded_image}",
+                            "detail": "low",
+                        },
                     },
-                },
-            )
+                )
 
         response = self.client.chat.completions.create(
             model=self.model_name, messages=message, **self.kwargs  # type: ignore
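After this change every LMM entry point accepts an images list instead of a single image, and OpenAILMM attaches each file as a base64 data URL. A hedged usage sketch (the import path is assumed from the package layout, the file paths are hypothetical, and an OpenAI API key is assumed to be configured):

from vision_agent.lmm import OpenAILMM  # the class lives in vision_agent/lmm/lmm.py

lmm = OpenAILMM()  # constructor arguments assumed to default sensibly

# generate: one prompt plus any number of .jpg/.jpeg/.png images
caption = lmm.generate(
    "Describe what changed between these two frames.",
    images=["frame_001.png", "frame_002.png"],
)

# chat: the encoded images are appended to the first message's content
reply = lmm.chat(
    [{"role": "user", "content": "Is the baby on the bed?"}],
    images=["baby_bed.jpg"],
)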
vision_agent/tools/__init__.py
CHANGED
vision_agent/tools/tools.py
CHANGED
@@ -1,7 +1,6 @@
 import logging
 import tempfile
 from abc import ABC
-from collections import Counter as CounterClass
 from pathlib import Path
 from typing import Any, Dict, List, Tuple, Union, cast
 
@@ -373,6 +372,104 @@ class GroundingSAM(Tool):
         return ret_pred
 
 
+class DINOv(Tool):
+    r"""DINOv is a tool that can detect and segment similar objects with the given input masks.
+
+    Example
+    -------
+        >>> import vision_agent as va
+        >>> t = va.tools.DINOv()
+        >>> t(prompt=[{"mask":"balloon_mask.jpg", "image": "balloon.jpg"}], image="balloon.jpg"])
+        [{'scores': [0.512, 0.212],
+        'masks': [array([[0, 0, 0, ..., 0, 0, 0],
+            ...,
+            [0, 0, 0, ..., 0, 0, 0]], dtype=uint8)},
+        array([[0, 0, 0, ..., 0, 0, 0],
+            ...,
+            [1, 1, 1, ..., 1, 1, 1]], dtype=uint8)]}]
+    """
+
+    name = "dinov_"
+    description = "'dinov_' is a tool that can detect and segment similar objects given a reference segmentation mask."
+    usage = {
+        "required_parameters": [
+            {"name": "prompt", "type": "List[Dict[str, str]]"},
+            {"name": "image", "type": "str"},
+        ],
+        "examples": [
+            {
+                "scenario": "Can you find all the balloons in this image that is similar to the provided masked area? Image name: input.jpg Reference image: balloon.jpg Reference mask: balloon_mask.jpg",
+                "parameters": {
+                    "prompt": [
+                        {"mask": "balloon_mask.jpg", "image": "balloon.jpg"},
+                    ],
+                    "image": "input.jpg",
+                },
+            },
+            {
+                "scenario": "Detect all the objects in this image that are similar to the provided mask. Image name: original.jpg Reference image: mask.png Reference mask: background.png",
+                "parameters": {
+                    "prompt": [
+                        {"mask": "mask.png", "image": "background.png"},
+                    ],
+                    "image": "original.jpg",
+                },
+            },
+        ],
+    }
+
+    def __call__(
+        self, prompt: List[Dict[str, str]], image: Union[str, ImageType]
+    ) -> Dict:
+        """Invoke the DINOv model.
+
+        Parameters:
+            prompt: a list of visual prompts in the form of {'mask': 'MASK_FILE_PATH', 'image': 'IMAGE_FILE_PATH'}.
+            image: the input image to segment.
+
+        Returns:
+            A dictionary of the below keys: 'scores', 'masks' and 'mask_shape', which stores a list of detected segmentation masks and its scores.
+        """
+        image_b64 = convert_to_b64(image)
+        for p in prompt:
+            p["mask"] = convert_to_b64(p["mask"])
+            p["image"] = convert_to_b64(p["image"])
+        request_data = {
+            "prompt": prompt,
+            "image": image_b64,
+            "tool": "dinov",
+        }
+        data: Dict[str, Any] = _send_inference_request(request_data, "dinov")
+        if "bboxes" in data:
+            data["bboxes"] = [
+                normalize_bbox(box, data["mask_shape"]) for box in data["bboxes"]
+            ]
+        if "masks" in data:
+            data["masks"] = [
+                rle_decode(mask_rle=mask, shape=data["mask_shape"])
+                for mask in data["masks"]
+            ]
+        data["labels"] = ["visual prompt" for _ in range(len(data["masks"]))]
+        return data
+
+
+class AgentDINOv(DINOv):
+    def __call__(
+        self,
+        prompt: List[Dict[str, str]],
+        image: Union[str, ImageType],
+    ) -> Dict:
+        rets = super().__call__(prompt, image)
+        mask_files = []
+        for mask in rets["masks"]:
+            with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
+                file_name = Path(tmp.name).with_suffix(".mask.png")
+                Image.fromarray(mask * 255).save(file_name)
+                mask_files.append(str(file_name))
+        rets["masks"] = mask_files
+        return rets
+
+
 class AgentGroundingSAM(GroundingSAM):
     r"""AgentGroundingSAM is the same as GroundingSAM but it saves the masks as files
     returns the file name. This makes it easier for agents to use.
@@ -396,33 +493,6 @@ class AgentGroundingSAM(GroundingSAM):
         return rets
 
 
-class Counter(Tool):
-    r"""Counter detects and counts the number of objects in an image given an input such as a category name or referring expression."""
-
-    name = "counter_"
-    description = "'counter_' detects and counts the number of objects in an image given an input such as a category name or referring expression. It returns a dictionary containing the labels and their counts."
-    usage = {
-        "required_parameters": [
-            {"name": "prompt", "type": "str"},
-            {"name": "image", "type": "str"},
-        ],
-        "examples": [
-            {
-                "scenario": "Can you count the number of cars in this image? Image name image.jpg",
-                "parameters": {"prompt": "car", "image": "image.jpg"},
-            },
-            {
-                "scenario": "Can you count the number of people? Image name: people.png",
-                "parameters": {"prompt": "person", "image": "people.png"},
-            },
-        ],
-    }
-
-    def __call__(self, prompt: str, image: Union[str, ImageType]) -> Dict:
-        resp = GroundingDINO()(prompt, image)
-        return dict(CounterClass(resp["labels"]))
-
-
 class Crop(Tool):
     r"""Crop crops an image given a bounding box and returns a file name of the cropped image."""
 
@@ -573,11 +643,42 @@ class SegIoU(Tool):
         return cast(float, round(iou, 2))
 
 
+class BoxDistance(Tool):
+    name = "box_distance_"
+    description = (
+        "'box_distance_' returns the minimum distance between two bounding boxes."
+    )
+    usage = {
+        "required_parameters": [
+            {"name": "bbox1", "type": "List[int]"},
+            {"name": "bbox2", "type": "List[int]"},
+        ],
+        "examples": [
+            {
+                "scenario": "If you want to calculate the distance between the bounding boxes [0.2, 0.21, 0.34, 0.42] and [0.3, 0.31, 0.44, 0.52]",
+                "parameters": {
+                    "bbox1": [0.2, 0.21, 0.34, 0.42],
+                    "bbox2": [0.3, 0.31, 0.44, 0.52],
+                },
+            }
+        ],
+    }
+
+    def __call__(self, bbox1: List[int], bbox2: List[int]) -> float:
+        x11, y11, x12, y12 = bbox1
+        x21, y21, x22, y22 = bbox2
+
+        horizontal_dist = np.max([0, x21 - x12, x11 - x22])
+        vertical_dist = np.max([0, y21 - y12, y11 - y22])
+
+        return cast(float, round(np.sqrt(horizontal_dist**2 + vertical_dist**2), 2))
+
+
 class ExtractFrames(Tool):
     r"""Extract frames from a video."""
 
     name = "extract_frames_"
-    description = "'extract_frames_' extracts frames
+    description = "'extract_frames_' extracts frames from a video, returns a list of tuples (frame, timestamp), where timestamp is the relative time in seconds where the frame was captured. The frame is a local image file path."
     usage = {
         "required_parameters": [{"name": "video_uri", "type": "str"}],
         "examples": [
@@ -649,13 +750,14 @@ TOOLS = {
         ImageCaption,
         GroundingDINO,
         AgentGroundingSAM,
+        AgentDINOv,
        ExtractFrames,
-        Counter,
         Crop,
         BboxArea,
         SegArea,
         BboxIoU,
         SegIoU,
+        BoxDistance,
         Calculator,
     ]
 )
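The box_distance_ computation is just the per-axis gap between the two boxes (zero when they overlap on that axis) combined with Pythagoras. A standalone sketch of the same formula with made-up coordinates:

import math
from typing import List

def box_distance(bbox1: List[float], bbox2: List[float]) -> float:
    # Same arithmetic as BoxDistance.__call__, without the numpy/cast plumbing.
    x11, y11, x12, y12 = bbox1
    x21, y21, x22, y22 = bbox2
    horizontal_dist = max(0.0, x21 - x12, x11 - x22)
    vertical_dist = max(0.0, y21 - y12, y11 - y22)
    return round(math.hypot(horizontal_dist, vertical_dist), 2)

print(box_distance([0.2, 0.21, 0.34, 0.42], [0.3, 0.31, 0.44, 0.52]))  # 0.0, the boxes overlap
print(box_distance([0.1, 0.1, 0.2, 0.2], [0.5, 0.2, 0.6, 0.3]))        # 0.3, horizontal gap only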
vision_agent/tools/video.py
CHANGED
@@ -15,7 +15,7 @@ _CLIP_LENGTH = 30.0
 
 
 def extract_frames_from_video(
-    video_uri: str, fps:
+    video_uri: str, fps: float = 0.5, motion_detection_threshold: float = 0.0
 ) -> List[Tuple[np.ndarray, float]]:
     """Extract frames from a video
 
@@ -25,7 +25,8 @@ def extract_frames_from_video(
         motion_detection_threshold: The threshold to detect motion between
             changes/frames. A value between 0-1, which represents the percentage change
             required for the frames to be considered in motion. For example, a lower
-            value means more frames will be extracted.
+            value means more frames will be extracted. A non-positive value will disable
+            motion detection and extract all frames.
 
     Returns:
         a list of tuples containing the extracted frame and the timestamp in seconds.
@@ -119,18 +120,19 @@ def _extract_frames_by_clip(
         total=processable_frames, desc=f"Extracting frames from clip {start}-{end}"
     )
     for i, frame in enumerate(clip.iter_frames(fps=fps, dtype="uint8")):
-        curr_processed_frame = _preprocess_frame(frame)
         total_count += 1
         pbar.update(1)
-        # Skip the frame if it is similar to the previous one
-        if prev_processed_frame is not None and _similar_frame(
-            prev_processed_frame,
-            curr_processed_frame,
-            threshold=motion_detection_threshold,
-        ):
-            skipped_count += 1
-            continue
-        prev_processed_frame = curr_processed_frame
+        if motion_detection_threshold > 0:
+            curr_processed_frame = _preprocess_frame(frame)
+            # Skip the frame if it is similar to the previous one
+            if prev_processed_frame is not None and _similar_frame(
+                prev_processed_frame,
+                curr_processed_frame,
+                threshold=motion_detection_threshold,
+            ):
+                skipped_count += 1
+                continue
+            prev_processed_frame = curr_processed_frame
         ts = round(clip.reader.pos / source_fps, 3)
         frames.append((frame, ts))
 
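A small usage sketch of the updated signature (the video path is hypothetical; fps=0.5 and a disabled motion filter are now the defaults):

from vision_agent.tools.video import extract_frames_from_video  # assumed import path

# Sample one frame every two seconds; motion_detection_threshold <= 0 keeps every sampled frame.
frames = extract_frames_from_video("clip.mp4", fps=0.5, motion_detection_threshold=0.0)
for frame, ts in frames:
    print(frame.shape, f"captured at {ts:.3f}s")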
|
@@ -2,28 +2,28 @@ vision_agent/__init__.py,sha256=wD1cssVTAJ55uTViNfBGooqJUV0p9fmVAuTMHHrmUBU,229
|
|
2
2
|
vision_agent/agent/__init__.py,sha256=B4JVrbY4IRVCJfjmrgvcp7h1mTUEk8MZvL0Zmej4Ka0,127
|
3
3
|
vision_agent/agent/agent.py,sha256=X7kON-g9ePUKumCDaYfQNBX_MEFE-ax5PnRp7-Cc5Wo,529
|
4
4
|
vision_agent/agent/easytool.py,sha256=oMHnBg7YBtIPgqQUNcZgq7uMgpPThs99_UnO7ERkMVg,11511
|
5
|
-
vision_agent/agent/easytool_prompts.py,sha256=
|
6
|
-
vision_agent/agent/reflexion.py,sha256=
|
5
|
+
vision_agent/agent/easytool_prompts.py,sha256=zdQQw6WpXOmvwOMtlBlNKY5a3WNlr65dbUvMIGiqdeo,4526
|
6
|
+
vision_agent/agent/reflexion.py,sha256=4gz30BuFMeGxSsTzoDV4p91yE0R8LISXp28IaOI6wdM,10506
|
7
7
|
vision_agent/agent/reflexion_prompts.py,sha256=G7UAeNz_g2qCb2yN6OaIC7bQVUkda4m3z42EG8wAyfE,9342
|
8
|
-
vision_agent/agent/vision_agent.py,sha256=
|
9
|
-
vision_agent/agent/vision_agent_prompts.py,sha256=
|
8
|
+
vision_agent/agent/vision_agent.py,sha256=QWIirRBB3ZPg3figWcf8-g9ltFydM1BDn75LbXWbep0,22735
|
9
|
+
vision_agent/agent/vision_agent_prompts.py,sha256=W3Z72FpUt71UIJSkjAcgtQqxeMqkYuATqHAN5fYY26c,7342
|
10
10
|
vision_agent/data/__init__.py,sha256=YU-5g3LbEQ6a4drz0RLGTagXMVU2Z4Xr3RlfWE-R0jU,46
|
11
|
-
vision_agent/data/data.py,sha256=
|
11
|
+
vision_agent/data/data.py,sha256=Z2l76OrT0GgyuN52OeJqDitUcP0q1rhfdXd1of3GsVo,5128
|
12
12
|
vision_agent/emb/__init__.py,sha256=YmCkGrJBtXb6X6Z3lnKiFoQYKXMgHMJp8JJyMLVvqcI,75
|
13
13
|
vision_agent/emb/emb.py,sha256=la9lhEzk7jqUCjYYQ5oRgVNSnC9_EJBJIpE_B9c6PJo,1375
|
14
14
|
vision_agent/fonts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
15
15
|
vision_agent/fonts/default_font_ch_en.ttf,sha256=1YM0Z3XqLDjSNbF7ihQFSAIUdjF9m1rtHiNC_6QosTE,1594400
|
16
|
-
vision_agent/image_utils.py,sha256=
|
16
|
+
vision_agent/image_utils.py,sha256=qRN_Y1XXBm9EL6V53OZUq21h0spIa1J6X9YDbe6B87o,4805
|
17
17
|
vision_agent/llm/__init__.py,sha256=BoUm_zSAKnLlE8s-gKTSQugXDqVZKPqYlWwlTLdhcz4,48
|
18
18
|
vision_agent/llm/llm.py,sha256=Jty_RHdqVmIM0Mm31JNk50c882Tx7hHtkmh0WyXeJd8,5016
|
19
19
|
vision_agent/lmm/__init__.py,sha256=nnNeKD1k7q_4vLb1x51O_EUTYaBgGfeiCx5F433gr3M,67
|
20
|
-
vision_agent/lmm/lmm.py,sha256=
|
21
|
-
vision_agent/tools/__init__.py,sha256=
|
20
|
+
vision_agent/lmm/lmm.py,sha256=1E7e_S_0fOKnf6mSsEdkXvsIjGmhBGl5XW4By2jvhbY,10045
|
21
|
+
vision_agent/tools/__init__.py,sha256=dkzk9amNzTEKULMB1xRJspqEGpzNPGuccWeXrv1xI0U,280
|
22
22
|
vision_agent/tools/prompts.py,sha256=V1z4YJLXZuUl_iZ5rY0M5hHc_2tmMEUKr0WocXKGt4E,1430
|
23
|
-
vision_agent/tools/tools.py,sha256=
|
24
|
-
vision_agent/tools/video.py,sha256=
|
23
|
+
vision_agent/tools/tools.py,sha256=ybhCyutEGzHPKuR0Cu--Nb--KubjYvyzLEzVQYzIMTw,29148
|
24
|
+
vision_agent/tools/video.py,sha256=xTElFSFp1Jw4ulOMnk81Vxsh-9dTxcWUO6P9fzEi3AM,7653
|
25
25
|
vision_agent/type_defs.py,sha256=4LTnTL4HNsfYqCrDn9Ppjg9bSG2ZGcoKSSd9YeQf4Bw,1792
|
26
|
-
vision_agent-0.1.
|
27
|
-
vision_agent-0.1.
|
28
|
-
vision_agent-0.1.
|
29
|
-
vision_agent-0.1.
|
26
|
+
vision_agent-0.1.4.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
27
|
+
vision_agent-0.1.4.dist-info/METADATA,sha256=FyBYGPHgC0uV7uy7wph8yvdQpEWSACnGR96y6Jt-E6A,6233
|
28
|
+
vision_agent-0.1.4.dist-info/WHEEL,sha256=7Z8_27uaHI_UZAc4Uox4PpBhQ9Y5_modZXWMxtUi4NU,88
|
29
|
+
vision_agent-0.1.4.dist-info/RECORD,,
|
File without changes
|
File without changes
|