vision-agent 0.0.40__py3-none-any.whl → 0.0.41__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vision_agent/agent/vision_agent.py +73 -19
- vision_agent/image_utils.py +95 -5
- vision_agent/tools/__init__.py +14 -1
- vision_agent/tools/tools.py +123 -60
- vision_agent/tools/video.py +8 -4
- {vision_agent-0.0.40.dist-info → vision_agent-0.0.41.dist-info}/METADATA +1 -1
- {vision_agent-0.0.40.dist-info → vision_agent-0.0.41.dist-info}/RECORD +9 -9
- {vision_agent-0.0.40.dist-info → vision_agent-0.0.41.dist-info}/LICENSE +0 -0
- {vision_agent-0.0.40.dist-info → vision_agent-0.0.41.dist-info}/WHEEL +0 -0
vision_agent/agent/vision_agent.py CHANGED
@@ -1,11 +1,13 @@
 import json
 import logging
 import sys
+import tempfile
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 from tabulate import tabulate

+from vision_agent.image_utils import overlay_bboxes, overlay_masks
 from vision_agent.llm import LLM, OpenAILLM
 from vision_agent.lmm import LMM, OpenAILMM
 from vision_agent.tools import TOOLS
@@ -248,12 +250,12 @@ def retrieval(
     tools: Dict[int, Any],
     previous_log: str,
     reflections: str,
-) -> Tuple[
+) -> Tuple[Dict, str]:
     tool_id = choose_tool(
         model, question, {k: v["description"] for k, v in tools.items()}, reflections
     )
     if tool_id is None:
-        return
+        return {}, ""
     _LOGGER.info(f"\t(Tool ID, name): ({tool_id}, {tools[tool_id]['name']})")

     tool_instructions = tools[tool_id]
@@ -265,14 +267,12 @@ def retrieval(
     )
     _LOGGER.info(f"\tParameters: {parameters} for {tool_name}")
     if parameters is None:
-        return
-    tool_results = [
-        {"task": question, "tool_name": tool_name, "parameters": parameters}
-    ]
+        return {}, ""
+    tool_results = {"task": question, "tool_name": tool_name, "parameters": parameters}

     _LOGGER.info(
-        f"""Going to run the following
-{tabulate(tool_results, headers="keys", tablefmt="mixed_grid")}"""
+        f"""Going to run the following tool(s) in sequence:
+{tabulate([tool_results], headers="keys", tablefmt="mixed_grid")}"""
     )

     def parse_tool_results(result: Dict[str, Union[Dict, List]]) -> Any:
@@ -286,12 +286,10 @@ def retrieval(
             call_results.append(function_call(tools[tool_id]["class"], parameters))
         return call_results

-    call_results = []
-    for i, result in enumerate(tool_results):
-        call_results.extend(parse_tool_results(result))
-        tool_results[i]["call_results"] = call_results
+    call_results = parse_tool_results(tool_results)
+    tool_results["call_results"] = call_results

-    call_results_str =
+    call_results_str = str(call_results)
     _LOGGER.info(f"\tCall Results: {call_results_str}")
     return tool_results, call_results_str

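retrieval() now builds a single tool-result record and stringifies its call results directly, instead of accumulating a list of records. A minimal sketch of the resulting shape; the tool name, parameters, and detections below are invented for illustration:

    # Hypothetical record in the shape retrieval() now returns as its first
    # element; str() of the call results becomes the second element.
    tool_results = {
        "task": "find the dogs in the image",
        "tool_name": "grounding_dino_",
        "parameters": {"prompt": "dog", "image": "dogs.jpg"},
    }
    tool_results["call_results"] = [
        {"labels": ["dog"], "bboxes": [[0.1, 0.2, 0.4, 0.5]], "scores": [0.9]}
    ]
    call_results_str = str(tool_results["call_results"])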
@@ -335,7 +333,11 @@ def self_reflect(
         tool_results=str(tool_result),
         final_answer=final_answer,
     )
-    if
+    if (
+        issubclass(type(reflect_model), LMM)
+        and image is not None
+        and Path(image).suffix in [".jpg", ".jpeg", ".png"]
+    ):
         return reflect_model(prompt, image=image)  # type: ignore
     return reflect_model(prompt)

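With the expanded guard, the reflection model only receives an image when it is an LMM and the path carries a still-image extension; anything else (a video path, a frame list, or None) falls back to the text-only call. A standalone sketch of just the extension check, written with Optional so it runs on its own:

    from pathlib import Path
    from typing import Optional

    def has_reflectable_suffix(image: Optional[str]) -> bool:
        # Same allow-list as the guard above; video files and None fail it.
        return image is not None and Path(image).suffix in [".jpg", ".jpeg", ".png"]

    assert has_reflectable_suffix("dogs.png")
    assert not has_reflectable_suffix("clip.mp4")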
@@ -345,6 +347,56 @@ def parse_reflect(reflect: str) -> bool:
     return "finish" in reflect.lower() and len(reflect) < 100


+def visualize_result(all_tool_results: List[Dict]) -> List[str]:
+    image_to_data: Dict[str, Dict] = {}
+    for tool_result in all_tool_results:
+        if not tool_result["tool_name"] in ["grounding_sam_", "grounding_dino_"]:
+            continue
+
+        parameters = tool_result["parameters"]
+        # parameters can either be a dictionary or list, parameters can also be malformed
+        # because the LLM builds them
+        if isinstance(parameters, dict):
+            if "image" not in parameters:
+                continue
+            parameters = [parameters]
+        elif isinstance(tool_result["parameters"], list):
+            if (
+                len(tool_result["parameters"]) < 1
+                and "image" not in tool_result["parameters"][0]
+            ):
+                continue
+
+        for param, call_result in zip(parameters, tool_result["call_results"]):
+
+            # calls can fail, so we need to check if the call was successful
+            if not isinstance(call_result, dict):
+                continue
+            if "bboxes" not in call_result:
+                continue
+
+            # if the call was successful, then we can add the image data
+            image = param["image"]
+            if image not in image_to_data:
+                image_to_data[image] = {"bboxes": [], "masks": [], "labels": []}
+
+            image_to_data[image]["bboxes"].extend(call_result["bboxes"])
+            image_to_data[image]["labels"].extend(call_result["labels"])
+            if "masks" in call_result:
+                image_to_data[image]["masks"].extend(call_result["masks"])
+
+    visualized_images = []
+    for image in image_to_data:
+        image_path = Path(image)
+        image_data = image_to_data[image]
+        image = overlay_masks(image_path, image_data)
+        image = overlay_bboxes(image, image_data)
+        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
+            image.save(f.name)
+            visualized_images.append(f.name)
+    return visualized_images
+
+
 class VisionAgent(Agent):
     r"""Vision Agent is an agent framework that utilizes tools as well as self
     reflection to accomplish tasks, in particular vision tasks. Vision Agent is based
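To make the new visualize_result() flow concrete, here is a hedged sketch of feeding it one mocked grounding_dino_ result; the file name and detection values are invented, and the source image must exist on disk for the overlay step to run:

    # visualize_result() skips entries not from grounding_sam_/grounding_dino_,
    # then draws boxes (and masks, if present) per source image and returns
    # the temp-file .png paths it wrote.
    mock_results = [
        {
            "tool_name": "grounding_dino_",
            "parameters": {"prompt": "dog", "image": "dogs.jpg"},
            "call_results": [
                {"labels": ["dog"], "bboxes": [[0.1, 0.2, 0.4, 0.5]]}
            ],
        }
    ]
    paths = visualize_result(mock_results)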
@@ -389,7 +441,8 @@ class VisionAgent(Agent):
         """Invoke the vision agent.

         Parameters:
-            input: a prompt that describe the task or a conversation in the format of
+            input: a prompt that describe the task or a conversation in the format of
+                [{"role": "user", "content": "describe your task here..."}].
             image: the input image referenced in the prompt parameter.

         Returns:
@@ -436,9 +489,8 @@ class VisionAgent(Agent):
                 self.answer_model, task_str, call_results, previous_log, reflections
             )

-
-
-            all_tool_results.extend(tool_results)
+            tool_results["answer"] = answer
+            all_tool_results.append(tool_results)

             _LOGGER.info(f"\tAnswer: {answer}")
             answers.append({"task": task_str, "answer": answer})
@@ -448,13 +500,15 @@ class VisionAgent(Agent):
                 self.answer_model, question, answers, reflections
             )

+        visualized_images = visualize_result(all_tool_results)
+        all_tool_results.append({"visualized_images": visualized_images})
         reflection = self_reflect(
             self.reflect_model,
             question,
             self.tools,
             all_tool_results,
             final_answer,
-            image,
+            visualized_images[0] if len(visualized_images) > 0 else image,
         )
         _LOGGER.info(f"\tReflection: {reflection}")
         if parse_reflect(reflection):
vision_agent/image_utils.py CHANGED
@@ -3,15 +3,38 @@
 import base64
 from io import BytesIO
 from pathlib import Path
-from typing import Tuple, Union
+from typing import Dict, Tuple, Union

 import numpy as np
-from PIL import Image
+from PIL import Image, ImageDraw, ImageFont
 from PIL.Image import Image as ImageType

+COLORS = [
+    (158, 218, 229),
+    (219, 219, 141),
+    (23, 190, 207),
+    (188, 189, 34),
+    (199, 199, 199),
+    (247, 182, 210),
+    (127, 127, 127),
+    (227, 119, 194),
+    (196, 156, 148),
+    (197, 176, 213),
+    (140, 86, 75),
+    (148, 103, 189),
+    (255, 152, 150),
+    (152, 223, 138),
+    (214, 39, 40),
+    (44, 160, 44),
+    (255, 187, 120),
+    (174, 199, 232),
+    (255, 127, 14),
+    (31, 119, 180),
+]
+

 def b64_to_pil(b64_str: str) -> ImageType:
-    """Convert a base64 string to a PIL Image.
+    r"""Convert a base64 string to a PIL Image.

     Parameters:
         b64_str: the base64 encoded image
@@ -26,7 +49,7 @@ def b64_to_pil(b64_str: str) -> ImageType:


 def get_image_size(data: Union[str, Path, np.ndarray, ImageType]) -> Tuple[int, ...]:
-    """Get the size of an image.
+    r"""Get the size of an image.

     Parameters:
         data: the input image
@@ -41,7 +64,7 @@ def get_image_size(data: Union[str, Path, np.ndarray, ImageType]) -> Tuple[int,


 def convert_to_b64(data: Union[str, Path, np.ndarray, ImageType]) -> str:
-    """Convert an image to a base64 string.
+    r"""Convert an image to a base64 string.

     Parameters:
         data: the input image
@@ -60,3 +83,70 @@ def convert_to_b64(data: Union[str, Path, np.ndarray, ImageType]) -> str:
     else:
         arr_bytes = data.tobytes()
     return base64.b64encode(arr_bytes).decode("utf-8")
+
+
+def overlay_bboxes(
+    image: Union[str, Path, np.ndarray, ImageType], bboxes: Dict
+) -> ImageType:
+    r"""Plots bounding boxes on to an image.
+
+    Parameters:
+        image: the input image
+        bboxes: the bounding boxes to overlay
+
+    Returns:
+        The image with the bounding boxes overlayed
+    """
+    if isinstance(image, (str, Path)):
+        image = Image.open(image)
+    elif isinstance(image, np.ndarray):
+        image = Image.fromarray(image)
+
+    color = {label: COLORS[i % len(COLORS)] for i, label in enumerate(bboxes["labels"])}
+
+    draw = ImageDraw.Draw(image)
+    font = ImageFont.load_default()
+    width, height = image.size
+    if "bboxes" not in bboxes:
+        return image.convert("RGB")
+
+    for label, box in zip(bboxes["labels"], bboxes["bboxes"]):
+        box = [box[0] * width, box[1] * height, box[2] * width, box[3] * height]
+        draw.rectangle(box, outline=color[label], width=3)
+        label = f"{label}"
+        text_box = draw.textbbox((box[0], box[1]), text=label, font=font)
+        draw.rectangle(text_box, fill=color[label])
+        draw.text((text_box[0], text_box[1]), label, fill="black", font=font)
+    return image.convert("RGB")
+
+
+def overlay_masks(
+    image: Union[str, Path, np.ndarray, ImageType], masks: Dict, alpha: float = 0.5
+) -> ImageType:
+    r"""Plots masks on to an image.
+
+    Parameters:
+        image: the input image
+        masks: the masks to overlay
+        alpha: the transparency of the overlay
+
+    Returns:
+        The image with the masks overlayed
+    """
+    if isinstance(image, (str, Path)):
+        image = Image.open(image)
+    elif isinstance(image, np.ndarray):
+        image = Image.fromarray(image)
+
+    color = {label: COLORS[i % len(COLORS)] for i, label in enumerate(masks["labels"])}
+    if "masks" not in masks:
+        return image.convert("RGB")
+
+    for label, mask in zip(masks["labels"], masks["masks"]):
+        if isinstance(mask, str):
+            mask = np.array(Image.open(mask))
+        np_mask = np.zeros((image.size[1], image.size[0], 4))
+        np_mask[mask > 0, :] = color[label] + (255 * alpha,)
+        mask_img = Image.fromarray(np_mask.astype(np.uint8))
+        image = Image.alpha_composite(image.convert("RGBA"), mask_img)
+    return image.convert("RGB")
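A short usage sketch for the two new helpers; the image path and detection values are invented, and boxes are assumed to be normalized xyxy as in the drawing code above:

    from vision_agent.image_utils import overlay_bboxes, overlay_masks

    detections = {
        "labels": ["dog"],
        "bboxes": [[0.1, 0.2, 0.4, 0.5]],  # normalized, scaled by image size
        "masks": [],  # optional; mask file paths or binary arrays
    }
    img = overlay_masks("dogs.jpg", detections)  # returns an RGB PIL image
    img = overlay_bboxes(img, detections)
    img.save("dogs_annotated.png")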
vision_agent/tools/__init__.py CHANGED
@@ -1,2 +1,15 @@
 from .prompts import CHOOSE_PARAMS, SYSTEM_PROMPT
-from .tools import
+from .tools import (
+    CLIP,
+    TOOLS,
+    BboxArea,
+    BboxIoU,
+    Counter,
+    Crop,
+    ExtractFrames,
+    GroundingDINO,
+    GroundingSAM,
+    SegArea,
+    SegIoU,
+    Tool,
+)
vision_agent/tools/tools.py CHANGED
@@ -92,7 +92,7 @@ class CLIP(Tool):
     }

     # TODO: Add support for input multiple images, which aligns with the output type.
-    def __call__(self, prompt: List[str], image: Union[str, ImageType]) ->
+    def __call__(self, prompt: List[str], image: Union[str, ImageType]) -> Dict:
         """Invoke the CLIP model.

         Parameters:
@@ -122,7 +122,7 @@ class CLIP(Tool):
         rets = []
         for elt in resp_json["data"]:
             rets.append({"labels": prompt, "scores": [round(prob, 2) for prob in elt]})
-        return cast(
+        return cast(Dict, rets[0])


 class GroundingDINO(Tool):
@@ -168,7 +168,7 @@ class GroundingDINO(Tool):
     }

     # TODO: Add support for input multiple images, which aligns with the output type.
-    def __call__(self, prompt: str, image: Union[str, Path, ImageType]) ->
+    def __call__(self, prompt: str, image: Union[str, Path, ImageType]) -> Dict:
         """Invoke the Grounding DINO model.

         Parameters:
@@ -204,7 +204,7 @@ class GroundingDINO(Tool):
             if "scores" in elt:
                 elt["scores"] = [round(score, 2) for score in elt["scores"]]
             elt["size"] = (image_size[1], image_size[0])
-        return cast(
+        return cast(Dict, resp_data)


 class GroundingSAM(Tool):
@@ -259,7 +259,7 @@ class GroundingSAM(Tool):
     }

     # TODO: Add support for input multiple images, which aligns with the output type.
-    def __call__(self, prompt: List[str], image: Union[str, ImageType]) ->
+    def __call__(self, prompt: List[str], image: Union[str, ImageType]) -> Dict:
         """Invoke the Grounding SAM model.

         Parameters:
@@ -294,7 +294,7 @@ class GroundingSAM(Tool):
             ret_pred["labels"].append(pred["label_name"])
             ret_pred["bboxes"].append(normalize_bbox(pred["bbox"], image_size))
             ret_pred["masks"].append(mask)
-        return
+        return ret_pred


 class AgentGroundingSAM(GroundingSAM):
@@ -302,15 +302,14 @@ class AgentGroundingSAM(GroundingSAM):
     returns the file name. This makes it easier for agents to use.
     """

-    def __call__(self, prompt: List[str], image: Union[str, ImageType]) ->
+    def __call__(self, prompt: List[str], image: Union[str, ImageType]) -> Dict:
         rets = super().__call__(prompt, image)
-
-
-
-
-
-
-        ret["masks"] = mask_files
+        mask_files = []
+        for mask in rets["masks"]:
+            with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
+                Image.fromarray(mask * 255).save(tmp)
+            mask_files.append(tmp.name)
+        rets["masks"] = mask_files
         return rets


@@ -363,7 +362,7 @@ class Crop(Tool):
         ],
     }

-    def __call__(self, bbox: List[float], image: Union[str, Path]) ->
+    def __call__(self, bbox: List[float], image: Union[str, Path]) -> Dict:
         pil_image = Image.open(image)
         width, height = pil_image.size
         bbox = [
@@ -373,10 +372,10 @@ class Crop(Tool):
             int(bbox[3] * height),
         ]
         cropped_image = pil_image.crop(bbox)  # type: ignore
-        with tempfile.NamedTemporaryFile(suffix=".
+        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
             cropped_image.save(tmp.name)

-        return tmp.name
+        return {"image": tmp.name}


 class BboxArea(Tool):
@@ -388,7 +387,7 @@ class BboxArea(Tool):
         "required_parameters": [{"name": "bbox", "type": "List[int]"}],
         "examples": [
             {
-                "scenario": "If you want to calculate the area of the bounding box [0, 0,
+                "scenario": "If you want to calculate the area of the bounding box [0.2, 0.21, 0.34, 0.42]",
                 "parameters": {"bboxes": [0.2, 0.21, 0.34, 0.42]},
             }
         ],
@@ -430,6 +429,109 @@ class SegArea(Tool):
         return cast(float, round(np.sum(np_mask) / 255, 2))


+class BboxIoU(Tool):
+    name = "bbox_iou_"
+    description = (
+        "'bbox_iou_' returns the intersection over union of two bounding boxes."
+    )
+    usage = {
+        "required_parameters": [
+            {"name": "bbox1", "type": "List[int]"},
+            {"name": "bbox2", "type": "List[int]"},
+        ],
+        "examples": [
+            {
+                "scenario": "If you want to calculate the intersection over union of the bounding boxes [0.2, 0.21, 0.34, 0.42] and [0.3, 0.31, 0.44, 0.52]",
+                "parameters": {
+                    "bbox1": [0.2, 0.21, 0.34, 0.42],
+                    "bbox2": [0.3, 0.31, 0.44, 0.52],
+                },
+            }
+        ],
+    }
+
+    def __call__(self, bbox1: List[int], bbox2: List[int]) -> float:
+        x1, y1, x2, y2 = bbox1
+        x3, y3, x4, y4 = bbox2
+        xA = max(x1, x3)
+        yA = max(y1, y3)
+        xB = min(x2, x4)
+        yB = min(y2, y4)
+        inter_area = max(0, xB - xA) * max(0, yB - yA)
+        boxa_area = (x2 - x1) * (y2 - y1)
+        boxb_area = (x4 - x3) * (y4 - y3)
+        iou = inter_area / float(boxa_area + boxb_area - inter_area)
+        return round(iou, 2)
+
+
+class SegIoU(Tool):
+    name = "seg_iou_"
+    description = "'seg_iou_' returns the intersection over union of two segmentation masks given their segmentation mask files."
+    usage = {
+        "required_parameters": [
+            {"name": "mask1", "type": "str"},
+            {"name": "mask2", "type": "str"},
+        ],
+        "examples": [
+            {
+                "scenario": "If you want to calculate the intersection over union of the segmentation masks for mask_file1.jpg and mask_file2.jpg",
+                "parameters": {"mask1": "mask_file1.png", "mask2": "mask_file2.png"},
+            }
+        ],
+    }
+
+    def __call__(self, mask1: Union[str, Path], mask2: Union[str, Path]) -> float:
+        pil_mask1 = Image.open(str(mask1))
+        pil_mask2 = Image.open(str(mask2))
+        np_mask1 = np.clip(np.array(pil_mask1), 0, 1)
+        np_mask2 = np.clip(np.array(pil_mask2), 0, 1)
+        intersection = np.logical_and(np_mask1, np_mask2)
+        union = np.logical_or(np_mask1, np_mask2)
+        iou = np.sum(intersection) / np.sum(union)
+        return cast(float, round(iou, 2))
+
+
+class ExtractFrames(Tool):
+    r"""Extract frames from a video."""
+
+    name = "extract_frames_"
+    description = "'extract_frames_' extracts frames where there is motion detected in a video, returns a list of tuples (frame, timestamp), where timestamp is the relative time in seconds where the frame was captured. The frame is a local image file path."
+    usage = {
+        "required_parameters": [{"name": "video_uri", "type": "str"}],
+        "examples": [
+            {
+                "scenario": "Can you extract the frames from this video? Video: www.foobar.com/video?name=test.mp4",
+                "parameters": {"video_uri": "www.foobar.com/video?name=test.mp4"},
+            },
+            {
+                "scenario": "Can you extract the images from this video file? Video path: tests/data/test.mp4",
+                "parameters": {"video_uri": "tests/data/test.mp4"},
+            },
+        ],
+    }
+
+    def __call__(self, video_uri: str) -> List[Tuple[str, float]]:
+        """Extract frames from a video.
+
+
+        Parameters:
+            video_uri: the path to the video file or a url points to the video data
+
+        Returns:
+            a list of tuples containing the extracted frame and the timestamp in seconds. E.g. [(path_to_frame1, 0.0), (path_to_frame2, 0.5), ...]. The timestamp is the time in seconds from the start of the video. E.g. 12.125 means 12.125 seconds from the start of the video. The frames are sorted by the timestamp in ascending order.
+        """
+        frames = extract_frames_from_video(video_uri)
+        result = []
+        _LOGGER.info(
+            f"Extracted {len(frames)} frames from video {video_uri}. Temporarily saving them as images to disk for downstream tasks."
+        )
+        for frame, ts in frames:
+            with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
+                Image.fromarray(frame).save(tmp)
+                result.append((tmp.name, ts))
+        return result
+
+
 class Add(Tool):
     r"""Add returns the sum of all the arguments passed to it, normalized to 2 decimal places."""

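The IoU tools operate on the same normalized coordinates the detectors emit, so they can be sanity-checked by hand; a quick sketch using the boxes from the usage example above:

    iou = BboxIoU()([0.2, 0.21, 0.34, 0.42], [0.3, 0.31, 0.44, 0.52])
    # intersection = 0.04 * 0.11 = 0.0044; union = 0.0294 + 0.0294 - 0.0044
    # = 0.0544, so IoU = 0.0044 / 0.0544 = 0.0809..., rounded to 0.08
    assert iou == 0.08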
@@ -506,47 +608,6 @@ class Divide(Tool):
         return round(input[0] / input[1], 2)


-class ExtractFrames(Tool):
-    r"""Extract frames from a video."""
-
-    name = "extract_frames_"
-    description = "'extract_frames_' extract image frames from the input video, return a list of tuple (frame, timestamp), where the timestamp is the relative time in seconds of the frame occurred in the video, the frame is a local image file path that stores the frame."
-    usage = {
-        "required_parameters": [{"name": "video_uri", "type": "str"}],
-        "examples": [
-            {
-                "scenario": "Can you extract the frames from this video? Video: www.foobar.com/video?name=test.mp4",
-                "parameters": {"video_uri": "www.foobar.com/video?name=test.mp4"},
-            },
-            {
-                "scenario": "Can you extract the images from this video file? Video path: tests/data/test.mp4",
-                "parameters": {"video_uri": "tests/data/test.mp4"},
-            },
-        ],
-    }
-
-    def __call__(self, video_uri: str) -> list[tuple[str, float]]:
-        """Extract frames from a video.
-
-
-        Parameters:
-            video_uri: the path to the video file or a url points to the video data
-
-        Returns:
-            a list of tuples containing the extracted frame and the timestamp in seconds. E.g. [(path_to_frame1, 0.0), (path_to_frame2, 0.5), ...]. The timestamp is the time in seconds from the start of the video. E.g. 12.125 means 12.125 seconds from the start of the video. The frames are sorted by the timestamp in ascending order.
-        """
-        frames = extract_frames_from_video(video_uri)
-        result = []
-        _LOGGER.info(
-            f"Extracted {len(frames)} frames from video {video_uri}. Temporarily saving them as images to disk for downstream tasks."
-        )
-        for frame, ts in frames:
-            with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
-                Image.fromarray(frame).save(tmp)
-                result.append((tmp.name, ts))
-        return result
-
-
 TOOLS = {
     i: {"name": c.name, "description": c.description, "usage": c.usage, "class": c}
     for i, c in enumerate(
@@ -554,15 +615,17 @@ TOOLS = {
             CLIP,
             GroundingDINO,
             AgentGroundingSAM,
+            ExtractFrames,
             Counter,
             Crop,
             BboxArea,
             SegArea,
+            BboxIoU,
+            SegIoU,
             Add,
             Subtract,
             Multiply,
             Divide,
-            ExtractFrames,
         ]
     )
     if (hasattr(c, "name") and hasattr(c, "description") and hasattr(c, "usage"))
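TOOLS maps an integer id to each tool's name, description, usage, and class, which is what choose_tool() presents to the model; a small sketch of inspecting the registry:

    from vision_agent.tools import TOOLS

    for tool_id, info in TOOLS.items():
        # ids are positional, so they shift as tools are added or reordered
        print(tool_id, info["name"])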
vision_agent/tools/video.py CHANGED
@@ -22,12 +22,16 @@ def extract_frames_from_video(
     Parameters:
         video_uri: the path to the video file or a video file url
         fps: the frame rate per second to extract the frames
-        motion_detection_threshold: The threshold to detect motion between
-            A value between 0-1, which represents the percentage change
-
+        motion_detection_threshold: The threshold to detect motion between
+            changes/frames. A value between 0-1, which represents the percentage change
+            required for the frames to be considered in motion. For example, a lower
+            value means more frames will be extracted.

     Returns:
-        a list of tuples containing the extracted frame and the timestamp in seconds.
+        a list of tuples containing the extracted frame and the timestamp in seconds.
+        E.g. [(frame1, 0.0), (frame2, 0.5), ...]. The timestamp is the time in seconds
+        from the start of the video. E.g. 12.125 means 12.125 seconds from the start of
+        the video. The frames are sorted by the timestamp in ascending order.
     """
     with VideoFileClip(video_uri) as video:
         video_duration: float = video.duration
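A hedged usage sketch for the clarified docstring; the video path is the test fixture from the tool's own examples, and the fps and threshold values here are illustrative rather than the function's defaults:

    from vision_agent.tools.video import extract_frames_from_video

    # Lower motion_detection_threshold keeps more frames, per the new docs.
    frames = extract_frames_from_video(
        "tests/data/test.mp4", fps=2, motion_detection_threshold=0.06
    )
    for frame, ts in frames[:3]:
        print(frame.shape, ts)  # numpy frame, seconds from the video start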
{vision_agent-0.0.40.dist-info → vision_agent-0.0.41.dist-info}/RECORD CHANGED
@@ -5,22 +5,22 @@ vision_agent/agent/easytool.py,sha256=oMHnBg7YBtIPgqQUNcZgq7uMgpPThs99_UnO7ERkMV
 vision_agent/agent/easytool_prompts.py,sha256=uNp12LOFRLr3i2zLhNuLuyFms2-s8es2t6P6h76QDow,4493
 vision_agent/agent/reflexion.py,sha256=wzpptfALNZIh9Q5jgkK3imGL5LWjTW_n_Ypsvxdh07Q,10101
 vision_agent/agent/reflexion_prompts.py,sha256=UPGkt_qgHBMUY0VPVoF-BqhR0d_6WPjjrhbYLBYOtnQ,9342
-vision_agent/agent/vision_agent.py,sha256=
+vision_agent/agent/vision_agent.py,sha256=_K6yWJiU1j0EGe8cabB40K0HxUkdzF-_c8G2k5eQL8s,17469
 vision_agent/agent/vision_agent_prompts.py,sha256=otaDRsaHc7bqw_tgWTnu-eUcFeOzBFrn9sPU7_xr2VQ,6151
 vision_agent/data/__init__.py,sha256=YU-5g3LbEQ6a4drz0RLGTagXMVU2Z4Xr3RlfWE-R0jU,46
 vision_agent/data/data.py,sha256=pgtSGZdAnbQ8oGsuapLtFTMPajnCGDGekEXTnFuBwsY,5122
 vision_agent/emb/__init__.py,sha256=YmCkGrJBtXb6X6Z3lnKiFoQYKXMgHMJp8JJyMLVvqcI,75
 vision_agent/emb/emb.py,sha256=la9lhEzk7jqUCjYYQ5oRgVNSnC9_EJBJIpE_B9c6PJo,1375
-vision_agent/image_utils.py,sha256=
+vision_agent/image_utils.py,sha256=XiOLpHAvlk55URw6iG7hl1OY71FVRA9_25b650amZXA,4420
 vision_agent/llm/__init__.py,sha256=fBKsIjL4z08eA0QYx6wvhRe4Nkp2pJ4VrZK0-uUL5Ec,32
 vision_agent/llm/llm.py,sha256=d8A7jmLVGx5HzoiYJ75mTMU7dbD5-bOYeXYlHaay6WA,3957
 vision_agent/lmm/__init__.py,sha256=I8mbeNUajTfWVNqLsuFQVOaNBDlkIhYp9DFU8H4kB7g,51
 vision_agent/lmm/lmm.py,sha256=ARcbgkcyP83TbVVoXI9B-gtG0gJuTaG_MjcUGbams4U,8052
-vision_agent/tools/__init__.py,sha256=
+vision_agent/tools/__init__.py,sha256=AKN-T659HpwVearRnkCd6wWNoJ6K5kW9gAZwb8IQSLE,235
 vision_agent/tools/prompts.py,sha256=9RBbyqlNlExsGKlJ89Jkph83DAEJ8PCVGaHoNbyN7TM,1416
-vision_agent/tools/tools.py,sha256=
-vision_agent/tools/video.py,sha256=
-vision_agent-0.0.
-vision_agent-0.0.
-vision_agent-0.0.
-vision_agent-0.0.
+vision_agent/tools/tools.py,sha256=aMTBxxaXQp33HwplOS8xrgfbsTJ8e1pwO6byR7HcTJI,23447
+vision_agent/tools/video.py,sha256=40rscP8YvKN3lhZ4PDcOK4XbdFX2duCRpHY_krmBYKU,7476
+vision_agent-0.0.41.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+vision_agent-0.0.41.dist-info/METADATA,sha256=45hGAgKvEd7WjzrmbFVluki2t0O64UomaHtIrwLCknw,5324
+vision_agent-0.0.41.dist-info/WHEEL,sha256=7Z8_27uaHI_UZAc4Uox4PpBhQ9Y5_modZXWMxtUi4NU,88
+vision_agent-0.0.41.dist-info/RECORD,,
{vision_agent-0.0.40.dist-info → vision_agent-0.0.41.dist-info}/LICENSE: file without changes
{vision_agent-0.0.40.dist-info → vision_agent-0.0.41.dist-info}/WHEEL: file without changes