vision-agent 0.0.40__py3-none-any.whl → 0.0.42__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vision_agent/agent/vision_agent.py +88 -29
- vision_agent/image_utils.py +95 -5
- vision_agent/llm/llm.py +10 -7
- vision_agent/lmm/lmm.py +14 -3
- vision_agent/tools/__init__.py +14 -1
- vision_agent/tools/tools.py +123 -60
- vision_agent/tools/video.py +8 -4
- {vision_agent-0.0.40.dist-info → vision_agent-0.0.42.dist-info}/METADATA +1 -1
- {vision_agent-0.0.40.dist-info → vision_agent-0.0.42.dist-info}/RECORD +11 -11
- {vision_agent-0.0.40.dist-info → vision_agent-0.0.42.dist-info}/LICENSE +0 -0
- {vision_agent-0.0.40.dist-info → vision_agent-0.0.42.dist-info}/WHEEL +0 -0
vision_agent/agent/vision_agent.py CHANGED
@@ -1,11 +1,13 @@
 import json
 import logging
 import sys
+import tempfile
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 from tabulate import tabulate
 
+from vision_agent.image_utils import overlay_bboxes, overlay_masks
 from vision_agent.llm import LLM, OpenAILLM
 from vision_agent.lmm import LMM, OpenAILMM
 from vision_agent.tools import TOOLS
@@ -248,13 +250,12 @@ def retrieval(
     tools: Dict[int, Any],
     previous_log: str,
     reflections: str,
-) -> Tuple[
+) -> Tuple[Dict, str]:
     tool_id = choose_tool(
         model, question, {k: v["description"] for k, v in tools.items()}, reflections
     )
     if tool_id is None:
-        return
-    _LOGGER.info(f"\t(Tool ID, name): ({tool_id}, {tools[tool_id]['name']})")
+        return {}, ""
 
     tool_instructions = tools[tool_id]
     tool_usage = tool_instructions["usage"]
@@ -263,16 +264,13 @@ def retrieval(
     parameters = choose_parameter(
         model, question, tool_usage, previous_log, reflections
     )
-    _LOGGER.info(f"\tParameters: {parameters} for {tool_name}")
     if parameters is None:
-        return
-    tool_results =
-        {"task": question, "tool_name": tool_name, "parameters": parameters}
-    ]
+        return {}, ""
+    tool_results = {"task": question, "tool_name": tool_name, "parameters": parameters}
 
     _LOGGER.info(
-        f"""Going to run the following
-        {tabulate(tool_results, headers="keys", tablefmt="mixed_grid")}"""
+        f"""Going to run the following tool(s) in sequence:
+        {tabulate([tool_results], headers="keys", tablefmt="mixed_grid")}"""
     )
 
     def parse_tool_results(result: Dict[str, Union[Dict, List]]) -> Any:
@@ -286,13 +284,11 @@ def retrieval(
         call_results.append(function_call(tools[tool_id]["class"], parameters))
         return call_results
 
-    call_results =
-
-        call_results.extend(parse_tool_results(result))
-        tool_results[i]["call_results"] = call_results
+    call_results = parse_tool_results(tool_results)
+    tool_results["call_results"] = call_results
 
-    call_results_str =
-    _LOGGER.info(f"\tCall Results: {call_results_str}")
+    call_results_str = str(call_results)
+    # _LOGGER.info(f"\tCall Results: {call_results_str}")
     return tool_results, call_results_str
 
 
@@ -335,14 +331,70 @@ def self_reflect(
         tool_results=str(tool_result),
         final_answer=final_answer,
     )
-    if
+    if (
+        issubclass(type(reflect_model), LMM)
+        and image is not None
+        and Path(image).suffix in [".jpg", ".jpeg", ".png"]
+    ):
         return reflect_model(prompt, image=image)  # type: ignore
     return reflect_model(prompt)
 
 
 def parse_reflect(reflect: str) -> bool:
     # GPT-4V has a hard time following directions, so make the criteria less strict
-    return
+    return (
+        "finish" in reflect.lower() and len(reflect) < 100
+    ) or "finish" in reflect.lower()[-10:]
+
+
+def visualize_result(all_tool_results: List[Dict]) -> List[str]:
+    image_to_data: Dict[str, Dict] = {}
+    for tool_result in all_tool_results:
+        if not tool_result["tool_name"] in ["grounding_sam_", "grounding_dino_"]:
+            continue
+
+        parameters = tool_result["parameters"]
+        # parameters can either be a dictionary or list, parameters can also be malformed
+        # becaus the LLM builds them
+        if isinstance(parameters, dict):
+            if "image" not in parameters:
+                continue
+            parameters = [parameters]
+        elif isinstance(tool_result["parameters"], list):
+            if (
+                len(tool_result["parameters"]) < 1
+                and "image" not in tool_result["parameters"][0]
+            ):
+                continue
+
+        for param, call_result in zip(parameters, tool_result["call_results"]):
+
+            # calls can fail, so we need to check if the call was successful
+            if not isinstance(call_result, dict):
+                continue
+            if "bboxes" not in call_result:
+                continue
+
+            # if the call was successful, then we can add the image data
+            image = param["image"]
+            if image not in image_to_data:
+                image_to_data[image] = {"bboxes": [], "masks": [], "labels": []}
+
+            image_to_data[image]["bboxes"].extend(call_result["bboxes"])
+            image_to_data[image]["labels"].extend(call_result["labels"])
+            if "masks" in call_result:
+                image_to_data[image]["masks"].extend(call_result["masks"])
+
+    visualized_images = []
+    for image in image_to_data:
+        image_path = Path(image)
+        image_data = image_to_data[image]
+        image = overlay_masks(image_path, image_data)
+        image = overlay_bboxes(image, image_data)
+        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
+            image.save(f.name)
+            visualized_images.append(f.name)
+    return visualized_images
 
 
 class VisionAgent(Agent):
@@ -371,10 +423,16 @@ class VisionAgent(Agent):
         verbose: bool = False,
     ):
         self.task_model = (
-            OpenAILLM(json_mode=True)
+            OpenAILLM(json_mode=True, temperature=0.1)
+            if task_model is None
+            else task_model
+        )
+        self.answer_model = (
+            OpenAILLM(temperature=0.1) if answer_model is None else answer_model
+        )
+        self.reflect_model = (
+            OpenAILMM(temperature=0.1) if reflect_model is None else reflect_model
         )
-        self.answer_model = OpenAILLM() if answer_model is None else answer_model
-        self.reflect_model = OpenAILMM() if reflect_model is None else reflect_model
         self.max_retries = max_retries
 
         self.tools = TOOLS
@@ -389,7 +447,8 @@ class VisionAgent(Agent):
         """Invoke the vision agent.
 
         Parameters:
-            input: a prompt that describe the task or a conversation in the format of
+            input: a prompt that describe the task or a conversation in the format of
+                [{"role": "user", "content": "describe your task here..."}].
             image: the input image referenced in the prompt parameter.
 
         Returns:
@@ -413,7 +472,6 @@ class VisionAgent(Agent):
         for _ in range(self.max_retries):
             task_list = create_tasks(self.task_model, question, self.tools, reflections)
 
-            _LOGGER.info(f"Task Dependency: {task_list}")
             task_depend = {"Original Quesiton": question}
             previous_log = ""
             answers = []
@@ -424,7 +482,6 @@ class VisionAgent(Agent):
             for task in task_list:
                 task_str = task["task"]
                 previous_log = str(task_depend)
-                _LOGGER.info(f"\tSubtask: {task_str}")
                 tool_results, call_results = retrieval(
                     self.task_model,
                     task_str,
@@ -436,10 +493,10 @@ class VisionAgent(Agent):
                     self.answer_model, task_str, call_results, previous_log, reflections
                 )
 
-
-
-                all_tool_results.extend(tool_results)
+                tool_results["answer"] = answer
+                all_tool_results.append(tool_results)
 
+                _LOGGER.info(f"\tCall Result: {call_results}")
                 _LOGGER.info(f"\tAnswer: {answer}")
                 answers.append({"task": task_str, "answer": answer})
                 task_depend[task["id"]]["answer"] = answer  # type: ignore
@@ -448,15 +505,17 @@ class VisionAgent(Agent):
                 self.answer_model, question, answers, reflections
             )
 
+            visualized_images = visualize_result(all_tool_results)
+            all_tool_results.append({"visualized_images": visualized_images})
             reflection = self_reflect(
                 self.reflect_model,
                 question,
                 self.tools,
                 all_tool_results,
                 final_answer,
-                image,
+                visualized_images[0] if len(visualized_images) > 0 else image,
             )
-            _LOGGER.info(f"
+            _LOGGER.info(f"Reflection: {reflection}")
            if parse_reflect(reflection):
                break
            else:
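The changes above tighten `retrieval`'s return type, lower the default model temperatures to 0.1, and feed a `visualize_result` image into self-reflection. Below is a minimal usage sketch (not part of the diff) of the updated `VisionAgent`, assuming the class is importable from `vision_agent.agent.vision_agent`, that `OPENAI_API_KEY` is set, and that the image path is a placeholder:

```python
# Hedged sketch: per the diff above, default models are now OpenAILLM(json_mode=True,
# temperature=0.1), OpenAILLM(temperature=0.1) and OpenAILMM(temperature=0.1).
from vision_agent.agent.vision_agent import VisionAgent

agent = VisionAgent(verbose=True)

# Either a plain prompt with an image...
answer = agent("How many cereal boxes are in this image?", image="cereal.jpg")

# ...or a conversation in the format shown in the updated docstring.
answer = agent(
    [{"role": "user", "content": "How many cereal boxes are in this image?"}],
    image="cereal.jpg",
)
print(answer)
```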
vision_agent/image_utils.py CHANGED
@@ -3,15 +3,38 @@
 import base64
 from io import BytesIO
 from pathlib import Path
-from typing import Tuple, Union
+from typing import Dict, Tuple, Union
 
 import numpy as np
-from PIL import Image
+from PIL import Image, ImageDraw, ImageFont
 from PIL.Image import Image as ImageType
 
+COLORS = [
+    (158, 218, 229),
+    (219, 219, 141),
+    (23, 190, 207),
+    (188, 189, 34),
+    (199, 199, 199),
+    (247, 182, 210),
+    (127, 127, 127),
+    (227, 119, 194),
+    (196, 156, 148),
+    (197, 176, 213),
+    (140, 86, 75),
+    (148, 103, 189),
+    (255, 152, 150),
+    (152, 223, 138),
+    (214, 39, 40),
+    (44, 160, 44),
+    (255, 187, 120),
+    (174, 199, 232),
+    (255, 127, 14),
+    (31, 119, 180),
+]
+
 
 def b64_to_pil(b64_str: str) -> ImageType:
-    """Convert a base64 string to a PIL Image.
+    r"""Convert a base64 string to a PIL Image.
 
     Parameters:
         b64_str: the base64 encoded image
@@ -26,7 +49,7 @@ def b64_to_pil(b64_str: str) -> ImageType:
 
 
 def get_image_size(data: Union[str, Path, np.ndarray, ImageType]) -> Tuple[int, ...]:
-    """Get the size of an image.
+    r"""Get the size of an image.
 
     Parameters:
         data: the input image
@@ -41,7 +64,7 @@ def get_image_size(data: Union[str, Path, np.ndarray, ImageType]) -> Tuple[int,
 
 
 def convert_to_b64(data: Union[str, Path, np.ndarray, ImageType]) -> str:
-    """Convert an image to a base64 string.
+    r"""Convert an image to a base64 string.
 
     Parameters:
         data: the input image
@@ -60,3 +83,70 @@ def convert_to_b64(data: Union[str, Path, np.ndarray, ImageType]) -> str:
     else:
         arr_bytes = data.tobytes()
     return base64.b64encode(arr_bytes).decode("utf-8")
+
+
+def overlay_bboxes(
+    image: Union[str, Path, np.ndarray, ImageType], bboxes: Dict
+) -> ImageType:
+    r"""Plots bounding boxes on to an image.
+
+    Parameters:
+        image: the input image
+        bboxes: the bounding boxes to overlay
+
+    Returns:
+        The image with the bounding boxes overlayed
+    """
+    if isinstance(image, (str, Path)):
+        image = Image.open(image)
+    elif isinstance(image, np.ndarray):
+        image = Image.fromarray(image)
+
+    color = {label: COLORS[i % len(COLORS)] for i, label in enumerate(bboxes["labels"])}
+
+    draw = ImageDraw.Draw(image)
+    font = ImageFont.load_default()
+    width, height = image.size
+    if "bboxes" not in bboxes:
+        return image.convert("RGB")
+
+    for label, box in zip(bboxes["labels"], bboxes["bboxes"]):
+        box = [box[0] * width, box[1] * height, box[2] * width, box[3] * height]
+        draw.rectangle(box, outline=color[label], width=3)
+        label = f"{label}"
+        text_box = draw.textbbox((box[0], box[1]), text=label, font=font)
+        draw.rectangle(text_box, fill=color[label])
+        draw.text((text_box[0], text_box[1]), label, fill="black", font=font)
+    return image.convert("RGB")
+
+
+def overlay_masks(
+    image: Union[str, Path, np.ndarray, ImageType], masks: Dict, alpha: float = 0.5
+) -> ImageType:
+    r"""Plots masks on to an image.
+
+    Parameters:
+        image: the input image
+        masks: the masks to overlay
+        alpha: the transparency of the overlay
+
+    Returns:
+        The image with the masks overlayed
+    """
+    if isinstance(image, (str, Path)):
+        image = Image.open(image)
+    elif isinstance(image, np.ndarray):
+        image = Image.fromarray(image)
+
+    color = {label: COLORS[i % len(COLORS)] for i, label in enumerate(masks["labels"])}
+    if "masks" not in masks:
+        return image.convert("RGB")
+
+    for label, mask in zip(masks["labels"], masks["masks"]):
+        if isinstance(mask, str):
+            mask = np.array(Image.open(mask))
+        np_mask = np.zeros((image.size[1], image.size[0], 4))
+        np_mask[mask > 0, :] = color[label] + (255 * alpha,)
+        mask_img = Image.fromarray(np_mask.astype(np.uint8))
+        image = Image.alpha_composite(image.convert("RGBA"), mask_img)
+    return image.convert("RGB")
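A small sketch (not from the package) of driving the new helpers directly. The detection dict layout with normalized `bboxes`, `labels`, and an optional `masks` list mirrors what `visualize_result` passes in; the file names and labels are placeholders:

```python
from vision_agent.image_utils import overlay_bboxes, overlay_masks

detections = {
    "labels": ["cereal box", "cereal box"],
    "bboxes": [[0.2, 0.21, 0.34, 0.42], [0.3, 0.31, 0.44, 0.52]],  # normalized xyxy
    "masks": [],  # mask file paths (or arrays) would go here
}

img = overlay_masks("cereal.jpg", detections)   # no-op with an empty mask list
img = overlay_bboxes(img, detections)           # draws labeled boxes scaled to the image size
img.save("cereal_annotated.png")
```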
vision_agent/llm/llm.py CHANGED
@@ -1,6 +1,6 @@
 import json
 from abc import ABC, abstractmethod
-from typing import Callable, Dict, List, Mapping, Union, cast
+from typing import Any, Callable, Dict, List, Mapping, Union, cast
 
 from openai import OpenAI
 
@@ -31,30 +31,33 @@ class OpenAILLM(LLM):
     r"""An LLM class for any OpenAI LLM model."""
 
     def __init__(
-        self,
+        self,
+        model_name: str = "gpt-4-turbo-preview",
+        json_mode: bool = False,
+        **kwargs: Any
     ):
         self.model_name = model_name
         self.client = OpenAI()
-        self.
+        self.kwargs = kwargs
+        if json_mode:
+            self.kwargs["response_format"] = {"type": "json_object"}
 
     def generate(self, prompt: str) -> str:
-        kwargs = {"response_format": {"type": "json_object"}} if self.json_mode else {}
         response = self.client.chat.completions.create(
             model=self.model_name,
             messages=[
                 {"role": "user", "content": prompt},
             ],
-            **kwargs,
+            **self.kwargs,
         )
 
         return cast(str, response.choices[0].message.content)
 
     def chat(self, chat: List[Dict[str, str]]) -> str:
-        kwargs = {"response_format": {"type": "json_object"}} if self.json_mode else {}
         response = self.client.chat.completions.create(
             model=self.model_name,
             messages=chat,  # type: ignore
-            **kwargs,
+            **self.kwargs,
         )
 
         return cast(str, response.choices[0].message.content)
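With the constructor rewrite above, extra keyword arguments are stored on the instance and forwarded to every `chat.completions.create` call; `json_mode` now just seeds `response_format`. A hedged sketch (requires `OPENAI_API_KEY`):

```python
from vision_agent.llm import OpenAILLM

# temperature rides along via **kwargs and is passed to the OpenAI client on every call
llm = OpenAILLM(json_mode=True, temperature=0.1)
plan = llm.generate('Return a JSON object of the form {"tasks": [...]} for counting boxes.')
print(plan)
```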
vision_agent/lmm/lmm.py CHANGED
@@ -97,11 +97,15 @@ class OpenAILMM(LMM):
     r"""An LMM class for the OpenAI GPT-4 Vision model."""
 
     def __init__(
-        self,
+        self,
+        model_name: str = "gpt-4-vision-preview",
+        max_tokens: int = 1024,
+        **kwargs: Any,
     ):
         self.model_name = model_name
         self.max_tokens = max_tokens
         self.client = OpenAI()
+        self.kwargs = kwargs
 
     def __call__(
         self,
@@ -123,6 +127,13 @@ class OpenAILMM(LMM):
 
         if image:
             extension = Path(image).suffix
+            if extension.lower() == ".jpeg" or extension.lower() == ".jpg":
+                extension = "jpg"
+            elif extension.lower() == ".png":
+                extension = "png"
+            else:
+                raise ValueError(f"Unsupported image extension: {extension}")
+
             encoded_image = encode_image(image)
             fixed_chat[0]["content"].append(  # type: ignore
                 {
@@ -135,7 +146,7 @@ class OpenAILMM(LMM):
             )
 
         response = self.client.chat.completions.create(
-            model=self.model_name, messages=fixed_chat, max_tokens=self.max_tokens  # type: ignore
+            model=self.model_name, messages=fixed_chat, max_tokens=self.max_tokens, **self.kwargs  # type: ignore
        )
 
         return cast(str, response.choices[0].message.content)
@@ -163,7 +174,7 @@ class OpenAILMM(LMM):
         )
 
         response = self.client.chat.completions.create(
-            model=self.model_name, messages=message, max_tokens=self.max_tokens  # type: ignore
+            model=self.model_name, messages=message, max_tokens=self.max_tokens, **self.kwargs  # type: ignore
         )
         return cast(str, response.choices[0].message.content)
 
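The LMM gets the same `**kwargs` plumbing plus an up-front extension check. A hedged sketch of the new behaviour; the image paths are placeholders:

```python
from vision_agent.lmm import OpenAILMM

lmm = OpenAILMM(temperature=0.1)  # extra kwargs are forwarded to chat.completions.create
print(lmm("Describe this image", image="photo.png"))  # .jpg/.jpeg/.png pass the new check

try:
    lmm("Describe this image", image="scan.tiff")
except ValueError as err:
    print(err)  # "Unsupported image extension: .tiff"
```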
vision_agent/tools/__init__.py CHANGED
@@ -1,2 +1,15 @@
 from .prompts import CHOOSE_PARAMS, SYSTEM_PROMPT
-from .tools import
+from .tools import (
+    CLIP,
+    TOOLS,
+    BboxArea,
+    BboxIoU,
+    Counter,
+    Crop,
+    ExtractFrames,
+    GroundingDINO,
+    GroundingSAM,
+    SegArea,
+    SegIoU,
+    Tool,
+)
vision_agent/tools/tools.py CHANGED
@@ -92,7 +92,7 @@ class CLIP(Tool):
     }
 
     # TODO: Add support for input multiple images, which aligns with the output type.
-    def __call__(self, prompt: List[str], image: Union[str, ImageType]) ->
+    def __call__(self, prompt: List[str], image: Union[str, ImageType]) -> Dict:
         """Invoke the CLIP model.
 
         Parameters:
@@ -122,7 +122,7 @@ class CLIP(Tool):
         rets = []
         for elt in resp_json["data"]:
             rets.append({"labels": prompt, "scores": [round(prob, 2) for prob in elt]})
-        return cast(
+        return cast(Dict, rets[0])
 
 
 class GroundingDINO(Tool):
@@ -168,7 +168,7 @@ class GroundingDINO(Tool):
     }
 
     # TODO: Add support for input multiple images, which aligns with the output type.
-    def __call__(self, prompt: str, image: Union[str, Path, ImageType]) ->
+    def __call__(self, prompt: str, image: Union[str, Path, ImageType]) -> Dict:
         """Invoke the Grounding DINO model.
 
         Parameters:
@@ -204,7 +204,7 @@ class GroundingDINO(Tool):
             if "scores" in elt:
                 elt["scores"] = [round(score, 2) for score in elt["scores"]]
             elt["size"] = (image_size[1], image_size[0])
-        return cast(
+        return cast(Dict, resp_data)
 
 
 class GroundingSAM(Tool):
@@ -259,7 +259,7 @@ class GroundingSAM(Tool):
     }
 
     # TODO: Add support for input multiple images, which aligns with the output type.
-    def __call__(self, prompt: List[str], image: Union[str, ImageType]) ->
+    def __call__(self, prompt: List[str], image: Union[str, ImageType]) -> Dict:
         """Invoke the Grounding SAM model.
 
         Parameters:
@@ -294,7 +294,7 @@ class GroundingSAM(Tool):
             ret_pred["labels"].append(pred["label_name"])
             ret_pred["bboxes"].append(normalize_bbox(pred["bbox"], image_size))
             ret_pred["masks"].append(mask)
-        return
+        return ret_pred
 
 
 class AgentGroundingSAM(GroundingSAM):
@@ -302,15 +302,14 @@ class AgentGroundingSAM(GroundingSAM):
     returns the file name. This makes it easier for agents to use.
     """
 
-    def __call__(self, prompt: List[str], image: Union[str, ImageType]) ->
+    def __call__(self, prompt: List[str], image: Union[str, ImageType]) -> Dict:
         rets = super().__call__(prompt, image)
-
-
-
-
-
-
-        ret["masks"] = mask_files
+        mask_files = []
+        for mask in rets["masks"]:
+            with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
+                Image.fromarray(mask * 255).save(tmp)
+                mask_files.append(tmp.name)
+        rets["masks"] = mask_files
         return rets
 
 
@@ -363,7 +362,7 @@ class Crop(Tool):
         ],
     }
 
-    def __call__(self, bbox: List[float], image: Union[str, Path]) ->
+    def __call__(self, bbox: List[float], image: Union[str, Path]) -> Dict:
         pil_image = Image.open(image)
         width, height = pil_image.size
         bbox = [
@@ -373,10 +372,10 @@ class Crop(Tool):
             int(bbox[3] * height),
         ]
         cropped_image = pil_image.crop(bbox)  # type: ignore
-        with tempfile.NamedTemporaryFile(suffix=".
+        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
             cropped_image.save(tmp.name)
 
-        return tmp.name
+        return {"image": tmp.name}
 
 
 class BboxArea(Tool):
@@ -388,7 +387,7 @@ class BboxArea(Tool):
         "required_parameters": [{"name": "bbox", "type": "List[int]"}],
         "examples": [
             {
-                "scenario": "If you want to calculate the area of the bounding box [0, 0,
+                "scenario": "If you want to calculate the area of the bounding box [0.2, 0.21, 0.34, 0.42]",
                 "parameters": {"bboxes": [0.2, 0.21, 0.34, 0.42]},
             }
         ],
@@ -430,6 +429,109 @@ class SegArea(Tool):
         return cast(float, round(np.sum(np_mask) / 255, 2))
 
 
+class BboxIoU(Tool):
+    name = "bbox_iou_"
+    description = (
+        "'bbox_iou_' returns the intersection over union of two bounding boxes."
+    )
+    usage = {
+        "required_parameters": [
+            {"name": "bbox1", "type": "List[int]"},
+            {"name": "bbox2", "type": "List[int]"},
+        ],
+        "examples": [
+            {
+                "scenario": "If you want to calculate the intersection over union of the bounding boxes [0.2, 0.21, 0.34, 0.42] and [0.3, 0.31, 0.44, 0.52]",
+                "parameters": {
+                    "bbox1": [0.2, 0.21, 0.34, 0.42],
+                    "bbox2": [0.3, 0.31, 0.44, 0.52],
+                },
+            }
+        ],
+    }
+
+    def __call__(self, bbox1: List[int], bbox2: List[int]) -> float:
+        x1, y1, x2, y2 = bbox1
+        x3, y3, x4, y4 = bbox2
+        xA = max(x1, x3)
+        yA = max(y1, y3)
+        xB = min(x2, x4)
+        yB = min(y2, y4)
+        inter_area = max(0, xB - xA) * max(0, yB - yA)
+        boxa_area = (x2 - x1) * (y2 - y1)
+        boxb_area = (x4 - x3) * (y4 - y3)
+        iou = inter_area / float(boxa_area + boxb_area - inter_area)
+        return round(iou, 2)
+
+
+class SegIoU(Tool):
+    name = "seg_iou_"
+    description = "'seg_iou_' returns the intersection over union of two segmentation masks given their segmentation mask files."
+    usage = {
+        "required_parameters": [
+            {"name": "mask1", "type": "str"},
+            {"name": "mask2", "type": "str"},
+        ],
+        "examples": [
+            {
+                "scenario": "If you want to calculate the intersection over union of the segmentation masks for mask_file1.jpg and mask_file2.jpg",
+                "parameters": {"mask1": "mask_file1.png", "mask2": "mask_file2.png"},
+            }
+        ],
+    }
+
+    def __call__(self, mask1: Union[str, Path], mask2: Union[str, Path]) -> float:
+        pil_mask1 = Image.open(str(mask1))
+        pil_mask2 = Image.open(str(mask2))
+        np_mask1 = np.clip(np.array(pil_mask1), 0, 1)
+        np_mask2 = np.clip(np.array(pil_mask2), 0, 1)
+        intersection = np.logical_and(np_mask1, np_mask2)
+        union = np.logical_or(np_mask1, np_mask2)
+        iou = np.sum(intersection) / np.sum(union)
+        return cast(float, round(iou, 2))
+
+
+class ExtractFrames(Tool):
+    r"""Extract frames from a video."""
+
+    name = "extract_frames_"
+    description = "'extract_frames_' extracts frames where there is motion detected in a video, returns a list of tuples (frame, timestamp), where timestamp is the relative time in seconds where teh frame was captured. The frame is a local image file path."
+    usage = {
+        "required_parameters": [{"name": "video_uri", "type": "str"}],
+        "examples": [
+            {
+                "scenario": "Can you extract the frames from this video? Video: www.foobar.com/video?name=test.mp4",
+                "parameters": {"video_uri": "www.foobar.com/video?name=test.mp4"},
+            },
+            {
+                "scenario": "Can you extract the images from this video file? Video path: tests/data/test.mp4",
+                "parameters": {"video_uri": "tests/data/test.mp4"},
+            },
+        ],
+    }
+
+    def __call__(self, video_uri: str) -> List[Tuple[str, float]]:
+        """Extract frames from a video.
+
+
+        Parameters:
+            video_uri: the path to the video file or a url points to the video data
+
+        Returns:
+            a list of tuples containing the extracted frame and the timestamp in seconds. E.g. [(path_to_frame1, 0.0), (path_to_frame2, 0.5), ...]. The timestamp is the time in seconds from the start of the video. E.g. 12.125 means 12.125 seconds from the start of the video. The frames are sorted by the timestamp in ascending order.
+        """
+        frames = extract_frames_from_video(video_uri)
+        result = []
+        _LOGGER.info(
+            f"Extracted {len(frames)} frames from video {video_uri}. Temporarily saving them as images to disk for downstream tasks."
+        )
+        for frame, ts in frames:
+            with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
+                Image.fromarray(frame).save(tmp)
+                result.append((tmp.name, ts))
+        return result
+
+
 class Add(Tool):
     r"""Add returns the sum of all the arguments passed to it, normalized to 2 decimal places."""
 
@@ -506,47 +608,6 @@ class Divide(Tool):
         return round(input[0] / input[1], 2)
 
 
-class ExtractFrames(Tool):
-    r"""Extract frames from a video."""
-
-    name = "extract_frames_"
-    description = "'extract_frames_' extract image frames from the input video, return a list of tuple (frame, timestamp), where the timestamp is the relative time in seconds of the frame occurred in the video, the frame is a local image file path that stores the frame."
-    usage = {
-        "required_parameters": [{"name": "video_uri", "type": "str"}],
-        "examples": [
-            {
-                "scenario": "Can you extract the frames from this video? Video: www.foobar.com/video?name=test.mp4",
-                "parameters": {"video_uri": "www.foobar.com/video?name=test.mp4"},
-            },
-            {
-                "scenario": "Can you extract the images from this video file? Video path: tests/data/test.mp4",
-                "parameters": {"video_uri": "tests/data/test.mp4"},
-            },
-        ],
-    }
-
-    def __call__(self, video_uri: str) -> list[tuple[str, float]]:
-        """Extract frames from a video.
-
-
-        Parameters:
-            video_uri: the path to the video file or a url points to the video data
-
-        Returns:
-            a list of tuples containing the extracted frame and the timestamp in seconds. E.g. [(path_to_frame1, 0.0), (path_to_frame2, 0.5), ...]. The timestamp is the time in seconds from the start of the video. E.g. 12.125 means 12.125 seconds from the start of the video. The frames are sorted by the timestamp in ascending order.
-        """
-        frames = extract_frames_from_video(video_uri)
-        result = []
-        _LOGGER.info(
-            f"Extracted {len(frames)} frames from video {video_uri}. Temporarily saving them as images to disk for downstream tasks."
-        )
-        for frame, ts in frames:
-            with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
-                Image.fromarray(frame).save(tmp)
-                result.append((tmp.name, ts))
-        return result
-
-
 TOOLS = {
     i: {"name": c.name, "description": c.description, "usage": c.usage, "class": c}
     for i, c in enumerate(
@@ -554,15 +615,17 @@ TOOLS = {
             CLIP,
             GroundingDINO,
             AgentGroundingSAM,
+            ExtractFrames,
             Counter,
             Crop,
             BboxArea,
             SegArea,
+            BboxIoU,
+            SegIoU,
             Add,
             Subtract,
             Multiply,
             Divide,
-            ExtractFrames,
         ]
     )
     if (hasattr(c, "name") and hasattr(c, "description") and hasattr(c, "usage"))
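For reference, here is a standalone re-derivation of the IoU arithmetic that the new `BboxIoU` tool implements, checked against the boxes from its usage example (an independent sketch, not package code):

```python
def bbox_iou(bbox1, bbox2):
    # Intersection-over-union of two xyxy boxes (normalized or pixel coordinates).
    x1, y1, x2, y2 = bbox1
    x3, y3, x4, y4 = bbox2
    xA, yA = max(x1, x3), max(y1, y3)   # top-left corner of the overlap
    xB, yB = min(x2, x4), min(y2, y4)   # bottom-right corner of the overlap
    inter = max(0, xB - xA) * max(0, yB - yA)
    area1 = (x2 - x1) * (y2 - y1)
    area2 = (x4 - x3) * (y4 - y3)
    return round(inter / float(area1 + area2 - inter), 2)

print(bbox_iou([0.2, 0.21, 0.34, 0.42], [0.3, 0.31, 0.44, 0.52]))  # 0.08
```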
vision_agent/tools/video.py CHANGED
@@ -22,12 +22,16 @@ def extract_frames_from_video(
     Parameters:
         video_uri: the path to the video file or a video file url
         fps: the frame rate per second to extract the frames
-        motion_detection_threshold: The threshold to detect motion between
-            A value between 0-1, which represents the percentage change
-
+        motion_detection_threshold: The threshold to detect motion between
+            changes/frames. A value between 0-1, which represents the percentage change
+            required for the frames to be considered in motion. For example, a lower
+            value means more frames will be extracted.
 
     Returns:
-        a list of tuples containing the extracted frame and the timestamp in seconds.
+        a list of tuples containing the extracted frame and the timestamp in seconds.
+        E.g. [(frame1, 0.0), (frame2, 0.5), ...]. The timestamp is the time in seconds
+        from the start of the video. E.g. 12.125 means 12.125 seconds from the start of
+        the video. The frames are sorted by the timestamp in ascending order.
     """
     with VideoFileClip(video_uri) as video:
         video_duration: float = video.duration
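A hedged sketch of calling the documented function directly; `ExtractFrames` in tools.py invokes it with only `video_uri`, so `fps` and `motion_detection_threshold` are assumed to have defaults, and the sample path is a placeholder:

```python
from vision_agent.tools.video import extract_frames_from_video

frames = extract_frames_from_video("tests/data/test.mp4")
for frame, ts in frames:
    # frame is an image array, ts is seconds from the start of the video
    print(f"{ts:.3f}s -> frame shape {frame.shape}")
```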
{vision_agent-0.0.40.dist-info → vision_agent-0.0.42.dist-info}/RECORD CHANGED
@@ -5,22 +5,22 @@ vision_agent/agent/easytool.py,sha256=oMHnBg7YBtIPgqQUNcZgq7uMgpPThs99_UnO7ERkMV
 vision_agent/agent/easytool_prompts.py,sha256=uNp12LOFRLr3i2zLhNuLuyFms2-s8es2t6P6h76QDow,4493
 vision_agent/agent/reflexion.py,sha256=wzpptfALNZIh9Q5jgkK3imGL5LWjTW_n_Ypsvxdh07Q,10101
 vision_agent/agent/reflexion_prompts.py,sha256=UPGkt_qgHBMUY0VPVoF-BqhR0d_6WPjjrhbYLBYOtnQ,9342
-vision_agent/agent/vision_agent.py,sha256=
+vision_agent/agent/vision_agent.py,sha256=P2melU6XQCCiiL1C_4QsxGUaWbwahuJA90eIcQJTR4U,17449
 vision_agent/agent/vision_agent_prompts.py,sha256=otaDRsaHc7bqw_tgWTnu-eUcFeOzBFrn9sPU7_xr2VQ,6151
 vision_agent/data/__init__.py,sha256=YU-5g3LbEQ6a4drz0RLGTagXMVU2Z4Xr3RlfWE-R0jU,46
 vision_agent/data/data.py,sha256=pgtSGZdAnbQ8oGsuapLtFTMPajnCGDGekEXTnFuBwsY,5122
 vision_agent/emb/__init__.py,sha256=YmCkGrJBtXb6X6Z3lnKiFoQYKXMgHMJp8JJyMLVvqcI,75
 vision_agent/emb/emb.py,sha256=la9lhEzk7jqUCjYYQ5oRgVNSnC9_EJBJIpE_B9c6PJo,1375
-vision_agent/image_utils.py,sha256=
+vision_agent/image_utils.py,sha256=XiOLpHAvlk55URw6iG7hl1OY71FVRA9_25b650amZXA,4420
 vision_agent/llm/__init__.py,sha256=fBKsIjL4z08eA0QYx6wvhRe4Nkp2pJ4VrZK0-uUL5Ec,32
-vision_agent/llm/llm.py,sha256=
+vision_agent/llm/llm.py,sha256=l8ZVh6vCZOJBHfenfOoHwPySXEUQoNt_gbL14gkvu2g,3904
 vision_agent/lmm/__init__.py,sha256=I8mbeNUajTfWVNqLsuFQVOaNBDlkIhYp9DFU8H4kB7g,51
-vision_agent/lmm/lmm.py,sha256=
-vision_agent/tools/__init__.py,sha256=
+vision_agent/lmm/lmm.py,sha256=s_A3SKCoWm2biOt-gS9PXOsa9l-zrmR6mInLjAqam-A,8438
+vision_agent/tools/__init__.py,sha256=AKN-T659HpwVearRnkCd6wWNoJ6K5kW9gAZwb8IQSLE,235
 vision_agent/tools/prompts.py,sha256=9RBbyqlNlExsGKlJ89Jkph83DAEJ8PCVGaHoNbyN7TM,1416
-vision_agent/tools/tools.py,sha256=
-vision_agent/tools/video.py,sha256=
-vision_agent-0.0.
-vision_agent-0.0.
-vision_agent-0.0.
-vision_agent-0.0.
+vision_agent/tools/tools.py,sha256=aMTBxxaXQp33HwplOS8xrgfbsTJ8e1pwO6byR7HcTJI,23447
+vision_agent/tools/video.py,sha256=40rscP8YvKN3lhZ4PDcOK4XbdFX2duCRpHY_krmBYKU,7476
+vision_agent-0.0.42.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+vision_agent-0.0.42.dist-info/METADATA,sha256=r523uVvu-DsNoA-H-18O2JXF4J9G2nZ2cDSmjXUFq_M,5324
+vision_agent-0.0.42.dist-info/WHEEL,sha256=7Z8_27uaHI_UZAc4Uox4PpBhQ9Y5_modZXWMxtUi4NU,88
+vision_agent-0.0.42.dist-info/RECORD,,
{vision_agent-0.0.40.dist-info → vision_agent-0.0.42.dist-info}/LICENSE File without changes
{vision_agent-0.0.40.dist-info → vision_agent-0.0.42.dist-info}/WHEEL File without changes