vision-agent 0.2.14__py3-none-any.whl → 0.2.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vision_agent/agent/__init__.py +1 -0
- vision_agent/agent/agent_coder.py +33 -7
- vision_agent/agent/vision_agent.py +16 -14
- vision_agent/agent/vision_agent_v2.py +300 -0
- vision_agent/agent/vision_agent_v2_prompt.py +170 -0
- vision_agent/llm/llm.py +11 -3
- vision_agent/tools/__init__.py +3 -3
- vision_agent/tools/tool_utils.py +1 -1
- vision_agent/tools/tools.py +62 -41
- vision_agent/tools/tools_v2.py +278 -17
- vision_agent/utils/__init__.py +3 -0
- vision_agent/utils/execute.py +104 -0
- vision_agent/utils/sim.py +70 -0
- {vision_agent-0.2.14.dist-info → vision_agent-0.2.16.dist-info}/METADATA +4 -1
- vision_agent-0.2.16.dist-info/RECORD +34 -0
- vision_agent/agent/execution.py +0 -287
- vision_agent-0.2.14.dist-info/RECORD +0 -30
- /vision_agent/{image_utils.py → utils/image_utils.py} +0 -0
- /vision_agent/{type_defs.py → utils/type_defs.py} +0 -0
- /vision_agent/{tools → utils}/video.py +0 -0
- {vision_agent-0.2.14.dist-info → vision_agent-0.2.16.dist-info}/LICENSE +0 -0
- {vision_agent-0.2.14.dist-info → vision_agent-0.2.16.dist-info}/WHEEL +0 -0
vision_agent/tools/tools.py
CHANGED
@@ -11,7 +11,10 @@ from PIL import Image
 from PIL.Image import Image as ImageType
 from scipy.spatial import distance  # type: ignore

-from vision_agent.image_utils import (
+from vision_agent.lmm import OpenAILMM
+from vision_agent.tools.tool_utils import _send_inference_request
+from vision_agent.utils import extract_frames_from_video
+from vision_agent.utils.image_utils import (
     b64_to_pil,
     convert_to_b64,
     denormalize_bbox,
@@ -19,9 +22,6 @@ from vision_agent.image_utils import (
     normalize_bbox,
     rle_decode,
 )
-from vision_agent.lmm import OpenAILMM
-from vision_agent.tools.tool_utils import _send_inference_request
-from vision_agent.tools.video import extract_frames_from_video

 _LOGGER = logging.getLogger(__name__)

@@ -174,15 +174,15 @@ class GroundingDINO(Tool):
     """

     name = "grounding_dino_"
-    description = "'grounding_dino_' is a tool that can detect and count objects given a text prompt such as category names or referring expressions. It returns a list and count of bounding boxes, label names and associated probability scores."
+    description = "'grounding_dino_' is a tool that can detect and count multiple objects given a text prompt such as category names or referring expressions. It returns a list and count of bounding boxes, label names and associated probability scores."
     usage = {
         "required_parameters": [
             {"name": "prompt", "type": "str"},
             {"name": "image", "type": "str"},
         ],
         "optional_parameters": [
-            {"name": "box_threshold", "type": "float"},
-            {"name": "iou_threshold", "type": "float"},
+            {"name": "box_threshold", "type": "float", "min": 0.1, "max": 0.5},
+            {"name": "iou_threshold", "type": "float", "min": 0.01, "max": 0.99},
         ],
         "examples": [
             {
@@ -209,7 +209,7 @@ class GroundingDINO(Tool):
                     "prompt": "red shirt. green shirt",
                     "image": "shirts.jpg",
                     "box_threshold": 0.20,
-                    "iou_threshold": 0.
+                    "iou_threshold": 0.20,
                 },
             },
         ],
@@ -221,7 +221,7 @@ class GroundingDINO(Tool):
         prompt: str,
         image: Union[str, Path, ImageType],
         box_threshold: float = 0.20,
-        iou_threshold: float = 0.
+        iou_threshold: float = 0.20,
     ) -> Dict:
         """Invoke the Grounding DINO model.

@@ -249,7 +249,7 @@ class GroundingDINO(Tool):
             data["scores"] = [round(score, 2) for score in data["scores"]]
         if "labels" in data:
             data["labels"] = list(data["labels"])
-        data["
+        data["image_size"] = image_size
         return data


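GroundingDINO (and GroundingSAM below) now attach the original image size to their output, so a caller can convert the normalized boxes back to pixels. A minimal sketch of that conversion; the helper and the sample values are illustrative, not part of the package:

    # Hypothetical helper showing what the returned image_size enables;
    # the bbox values are made up, not real model output.
    def to_pixels(bbox, image_size):
        height, width = image_size
        x1, y1, x2, y2 = bbox
        return [x1 * width, y1 * height, x2 * width, y2 * height]

    result = {"bboxes": [[0.1, 0.2, 0.4, 0.5]], "image_size": (480, 640)}
    print(to_pixels(result["bboxes"][0], result["image_size"]))
    # [64.0, 96.0, 256.0, 240.0]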
@@ -277,15 +277,15 @@ class GroundingSAM(Tool):
     """

     name = "grounding_sam_"
-    description = "'grounding_sam_' is a tool that can detect and segment objects given a text prompt such as category names or referring expressions. It returns a list of bounding boxes, label names and masks file names and associated probability scores."
+    description = "'grounding_sam_' is a tool that can detect and segment multiple objects given a text prompt such as category names or referring expressions. It returns a list of bounding boxes, label names and masks file names and associated probability scores."
     usage = {
         "required_parameters": [
             {"name": "prompt", "type": "str"},
             {"name": "image", "type": "str"},
         ],
         "optional_parameters": [
-            {"name": "box_threshold", "type": "float"},
-            {"name": "iou_threshold", "type": "float"},
+            {"name": "box_threshold", "type": "float", "min": 0.1, "max": 0.5},
+            {"name": "iou_threshold", "type": "float", "min": 0.01, "max": 0.99},
         ],
         "examples": [
             {
@@ -312,7 +312,7 @@ class GroundingSAM(Tool):
                     "prompt": "red shirt, green shirt",
                     "image": "shirts.jpg",
                     "box_threshold": 0.20,
-                    "iou_threshold": 0.
+                    "iou_threshold": 0.20,
                 },
             },
         ],
@@ -324,7 +324,7 @@ class GroundingSAM(Tool):
         prompt: str,
         image: Union[str, ImageType],
         box_threshold: float = 0.2,
-        iou_threshold: float = 0.
+        iou_threshold: float = 0.2,
     ) -> Dict:
         """Invoke the Grounding SAM model.

@@ -353,6 +353,7 @@ class GroundingSAM(Tool):
             rle_decode(mask_rle=mask, shape=data["mask_shape"])
             for mask in data["masks"]
         ]
+        data["image_size"] = image_size
         data.pop("mask_shape", None)
         return data

@@ -422,7 +423,6 @@ class DINOv(Tool):
         request_data = {
             "prompt": prompt,
             "image": image_b64,
-            "tool": "dinov",
         }
         data: Dict[str, Any] = _send_inference_request(request_data, "dinov")
         if "bboxes" in data:
@@ -435,6 +435,8 @@ class DINOv(Tool):
                 for mask in data["masks"]
             ]
         data["labels"] = ["visual prompt" for _ in range(len(data["masks"]))]
+        mask_shape = data.pop("mask_shape", None)
+        data["image_size"] = (mask_shape[0], mask_shape[1]) if mask_shape else None
         return data


@@ -790,33 +792,49 @@ class Crop(Tool):
         return {"image": tmp.name}


-class
-    r"""
+class BboxStats(Tool):
+    r"""BboxStats returns the height, width and area of the bounding box in pixels to 2 decimal places."""

-    name = "
-    description = "'
+    name = "bbox_stats_"
+    description = "'bbox_stats_' returns the height, width and area of the given bounding box in pixels to 2 decimal places."
     usage = {
-        "required_parameters": [
+        "required_parameters": [
+            {"name": "bboxes", "type": "List[int]"},
+            {"name": "image_size", "type": "Tuple[int]"},
+        ],
         "examples": [
             {
-                "scenario": "
-                "parameters": {
-
+                "scenario": "Calculate the width and height of the bounding box [0.2, 0.21, 0.34, 0.42]",
+                "parameters": {
+                    "bboxes": [[0.2, 0.21, 0.34, 0.42]],
+                    "image_size": (500, 1200),
+                },
+            },
+            {
+                "scenario": "Calculate the area of the bounding box [0.2, 0.21, 0.34, 0.42]",
+                "parameters": {
+                    "bboxes": [[0.2, 0.21, 0.34, 0.42]],
+                    "image_size": (640, 480),
+                },
+            },
         ],
     }

-    def __call__(
+    def __call__(
+        self, bboxes: List[List[int]], image_size: Tuple[int, int]
+    ) -> List[Dict]:
         areas = []
-
-
-
-
-
-
-
-
-
-
+        height, width = image_size
+        for bbox in bboxes:
+            x1, y1, x2, y2 = bbox
+            areas.append(
+                {
+                    "width": round((x2 - x1) * width, 2),
+                    "height": round((y2 - y1) * height, 2),
+                    "area": round((x2 - x1) * (y2 - y1) * width * height, 2),
+                }
+            )
+
         return areas


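As a quick hand check of the new BboxStats arithmetic, running the values from the first usage example above (normalized box, image_size as (height, width)) gives:

    # Sketch of the bbox_stats_ arithmetic from the diff above, run by hand.
    bbox = [0.2, 0.21, 0.34, 0.42]
    height, width = (500, 1200)
    x1, y1, x2, y2 = bbox
    stats = {
        "width": round((x2 - x1) * width, 2),    # 0.14 * 1200 = 168.0
        "height": round((y2 - y1) * height, 2),  # 0.21 * 500  = 105.0
        "area": round((x2 - x1) * (y2 - y1) * width * height, 2),  # 17640.0
    }
    print(stats)  # {'width': 168.0, 'height': 105.0, 'area': 17640.0}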
@@ -1055,22 +1073,25 @@ class ExtractFrames(Tool):
     r"""Extract frames from a video."""

     name = "extract_frames_"
-    description = "'extract_frames_' extracts frames from a video, returns a list of tuples (frame, timestamp), where timestamp is the relative time in seconds where the frame was captured. The frame is a local image file path."
+    description = "'extract_frames_' extracts frames from a video every 2 seconds, returns a list of tuples (frame, timestamp), where timestamp is the relative time in seconds where the frame was captured. The frame is a local image file path."
     usage = {
         "required_parameters": [{"name": "video_uri", "type": "str"}],
+        "optional_parameters": [{"name": "frames_every", "type": "float"}],
         "examples": [
             {
                 "scenario": "Can you extract the frames from this video? Video: www.foobar.com/video?name=test.mp4",
                 "parameters": {"video_uri": "www.foobar.com/video?name=test.mp4"},
             },
             {
-                "scenario": "Can you extract the images from this video file? Video path: tests/data/test.mp4",
-                "parameters": {"video_uri": "tests/data/test.mp4"},
+                "scenario": "Can you extract the images from this video file at every 2 seconds ? Video path: tests/data/test.mp4",
+                "parameters": {"video_uri": "tests/data/test.mp4", "frames_every": 2},
             },
         ],
     }

-    def __call__(
+    def __call__(
+        self, video_uri: str, frames_every: float = 2
+    ) -> List[Tuple[str, float]]:
         """Extract frames from a video.


@@ -1080,7 +1101,7 @@ class ExtractFrames(Tool):
         Returns:
             a list of tuples containing the extracted frame and the timestamp in seconds. E.g. [(path_to_frame1, 0.0), (path_to_frame2, 0.5), ...]. The timestamp is the time in seconds from the start of the video. E.g. 12.125 means 12.125 seconds from the start of the video. The frames are sorted by the timestamp in ascending order.
         """
-        frames = extract_frames_from_video(video_uri)
+        frames = extract_frames_from_video(video_uri, fps=round(1 / frames_every, 2))
         result = []
         _LOGGER.info(
             f"Extracted {len(frames)} frames from video {video_uri}. Temporarily saving them as images to disk for downstream tasks."
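The new frames_every parameter (seconds between sampled frames) is passed to extract_frames_from_video as its reciprocal, i.e. as a frame rate. A quick check of that conversion:

    # frames_every (seconds between frames) -> fps passed to extract_frames_from_video
    for frames_every in (0.5, 1, 2, 4):
        fps = round(1 / frames_every, 2)
        print(f"one frame every {frames_every}s -> fps={fps}")
    # one frame every 0.5s -> fps=2.0
    # one frame every 1s -> fps=1.0
    # one frame every 2s -> fps=0.5
    # one frame every 4s -> fps=0.25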
@@ -1183,7 +1204,7 @@ TOOLS = {
             AgentDINOv,
             ExtractFrames,
             Crop,
-
+            BboxStats,
             SegArea,
             ObjectDistance,
             BboxContains,
vision_agent/tools/tools_v2.py
CHANGED
@@ -1,13 +1,19 @@
 import inspect
+import io
+import logging
 import tempfile
 from importlib import resources
-from
+from pathlib import Path
+from typing import Any, Callable, Dict, List, Tuple, Union

 import numpy as np
+import pandas as pd
+import requests
 from PIL import Image, ImageDraw, ImageFont

-from vision_agent.image_utils import convert_to_b64, normalize_bbox
 from vision_agent.tools.tool_utils import _send_inference_request
+from vision_agent.utils import extract_frames_from_video
+from vision_agent.utils.image_utils import convert_to_b64, normalize_bbox, rle_decode

 COLORS = [
     (158, 218, 229),
@@ -31,6 +37,10 @@ COLORS = [
     (255, 127, 14),
     (31, 119, 180),
 ]
+_API_KEY = "land_sk_WVYwP00xA3iXely2vuar6YUDZ3MJT9yLX6oW5noUkwICzYLiDV"
+_OCR_URL = "https://app.landing.ai/ocr/v1/detect-text"
+logging.basicConfig(level=logging.INFO)
+_LOGGER = logging.getLogger(__name__)


 def grounding_dino(
@@ -39,23 +49,30 @@ def grounding_dino(
     box_threshold: float = 0.20,
     iou_threshold: float = 0.75,
 ) -> List[Dict[str, Any]]:
-    """'grounding_dino' is a tool that can detect
-    category names or referring expressions.
+    """'grounding_dino' is a tool that can detect and count objects given a text prompt
+    such as category names or referring expressions. It returns a list and count of
+    bounding boxes, label names and associated probability scores.

     Parameters:
         prompt (str): The prompt to ground to the image.
         image (np.ndarray): The image to ground the prompt to.
-        box_threshold (float, optional): The threshold for the box detection. Defaults
-
+        box_threshold (float, optional): The threshold for the box detection. Defaults
+            to 0.20.
+        iou_threshold (float, optional): The threshold for the Intersection over Union
+            (IoU). Defaults to 0.75.

     Returns:
         List[Dict[str, Any]]: A list of dictionaries containing the score, label, and
-        bounding box of the detected objects with normalized coordinates
+        bounding box of the detected objects with normalized coordinates
+        (x1, y1, x2, y2).

     Example
     -------
     >>> grounding_dino("car. dinosaur", image)
-    [
+    [
+        {'score': 0.99, 'label': 'dinosaur', 'bbox': [0.1, 0.11, 0.35, 0.4]},
+        {'score': 0.98, 'label': 'car', 'bbox': [0.2, 0.21, 0.45, 0.5},
+    ]
     """
     image_size = image.shape[:2]
     image_b64 = convert_to_b64(Image.fromarray(image))
@@ -78,6 +95,147 @@ def grounding_dino(
     return return_data


+def grounding_sam(
+    prompt: str,
+    image: np.ndarray,
+    box_threshold: float = 0.20,
+    iou_threshold: float = 0.75,
+) -> List[Dict[str, Any]]:
+    """'grounding_sam' is a tool that can detect and segment objects given a text
+    prompt such as category names or referring expressions. It returns a list of
+    bounding boxes, label names and masks file names and associated probability scores.
+
+    Parameters:
+        prompt (str): The prompt to ground to the image.
+        image (np.ndarray): The image to ground the prompt to.
+        box_threshold (float, optional): The threshold for the box detection. Defaults
+            to 0.20.
+        iou_threshold (float, optional): The threshold for the Intersection over Union
+            (IoU). Defaults to 0.75.
+
+    Returns:
+        List[Dict[str, Any]]: A list of dictionaries containing the score, label,
+        bounding box, and mask of the detected objects with normalized coordinates
+        (x1, y1, x2, y2).
+
+    Example
+    -------
+    >>> grounding_sam("car. dinosaur", image)
+    [
+        {
+            'score': 0.99,
+            'label': 'dinosaur',
+            'bbox': [0.1, 0.11, 0.35, 0.4],
+            'mask': array([[0, 0, 0, ..., 0, 0, 0],
+                [0, 0, 0, ..., 0, 0, 0],
+                ...,
+                [0, 0, 0, ..., 0, 0, 0],
+                [0, 0, 0, ..., 0, 0, 0]], dtype=uint8),
+        },
+    ]
+    """
+    image_size = image.shape[:2]
+    image_b64 = convert_to_b64(Image.fromarray(image))
+    request_data = {
+        "prompt": prompt,
+        "image": image_b64,
+        "tool": "visual_grounding_segment",
+        "kwargs": {"box_threshold": box_threshold, "iou_threshold": iou_threshold},
+    }
+    data: Dict[str, Any] = _send_inference_request(request_data, "tools")
+    return_data = []
+    for i in range(len(data["bboxes"])):
+        return_data.append(
+            {
+                "score": round(data["scores"][i], 2),
+                "label": data["labels"][i],
+                "bbox": normalize_bbox(data["bboxes"][i], image_size),
+                "mask": rle_decode(mask_rle=data["masks"][i], shape=data["mask_shape"]),
+            }
+        )
+    return return_data
+
+
+def extract_frames(
+    video_uri: Union[str, Path], fps: float = 0.5
+) -> List[Tuple[np.ndarray, float]]:
+    """'extract_frames' extracts frames from a video, returns a list of tuples (frame,
+    timestamp), where timestamp is the relative time in seconds where the frame was
+    captured. The frame is a local image file path.
+
+    Parameters:
+        video_uri (Union[str, Path]): The path to the video file.
+        fps (float, optional): The frame rate per second to extract the frames. Defaults
+            to 0.5.
+
+    Returns:
+        List[Tuple[np.ndarray, float]]: A list of tuples containing the extracted frame
+        and the timestamp in seconds.
+
+    Example
+    -------
+    >>> extract_frames("path/to/video.mp4")
+    [(frame1, 0.0), (frame2, 0.5), ...]
+    """
+
+    return extract_frames_from_video(str(video_uri), fps)
+
+
+def ocr(image: np.ndarray) -> List[Dict[str, Any]]:
+    """'ocr' extracts text from an image. It returns a list of detected text, bounding
+    boxes, and confidence scores.
+
+    Parameters:
+        image (np.ndarray): The image to extract text from.
+
+    Returns:
+        List[Dict[str, Any]]: A list of dictionaries containing the detected text, bbox,
+        and confidence score.
+
+    Example
+    -------
+    >>> ocr(image)
+    [
+        {'label': 'some text', 'bbox': [0.1, 0.11, 0.35, 0.4], 'score': 0.99},
+    ]
+    """
+
+    pil_image = Image.fromarray(image).convert("RGB")
+    image_size = pil_image.size[::-1]
+    image_buffer = io.BytesIO()
+    pil_image.save(image_buffer, format="PNG")
+    buffer_bytes = image_buffer.getvalue()
+    image_buffer.close()
+
+    res = requests.post(
+        _OCR_URL,
+        files={"images": buffer_bytes},
+        data={"language": "en"},
+        headers={"contentType": "multipart/form-data", "apikey": _API_KEY},
+    )
+
+    if res.status_code != 200:
+        raise ValueError(f"OCR request failed with status code {res.status_code}")
+
+    data = res.json()
+    output = []
+    for det in data[0]:
+        label = det["text"]
+        box = [
+            det["location"][0]["x"],
+            det["location"][0]["y"],
+            det["location"][2]["x"],
+            det["location"][2]["y"],
+        ]
+        box = normalize_bbox(box, image_size)
+        output.append({"label": label, "bbox": box, "score": round(det["score"], 2)})
+
+    return output
+
+
+# Utility and visualization functions
+
+
 def load_image(image_path: str) -> np.ndarray:
     """'load_image' is a utility function that loads an image from the given path.

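Taken together, the new functional tools can be chained without the Tool-class plumbing of tools.py. A minimal sketch, assuming a placeholder image path and network access to the hosted inference and OCR endpoints:

    # Hypothetical end-to-end use of the new functional tools in tools_v2.
    # "workers.png" is a made-up path; replace with a real image to run this.
    from vision_agent.tools.tools_v2 import (
        grounding_sam,
        load_image,
        ocr,
        overlay_segmentation_masks,
        save_image,
    )

    image = load_image("workers.png")                       # np.ndarray
    detections = grounding_sam("person. hard hat", image)   # boxes, labels, masks
    text_regions = ocr(image)                                # detected text + bboxes

    viz = overlay_segmentation_masks(image, detections)
    print(save_image(viz))                                   # path of the saved overlay
    print(text_regions[:3])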
@@ -117,24 +275,33 @@ def save_image(image: np.ndarray) -> str:
     return f.name


-def
+def overlay_bounding_boxes(
     image: np.ndarray, bboxes: List[Dict[str, Any]]
 ) -> np.ndarray:
-    """'display_bounding_boxes' is a utility function that displays bounding boxes on
+    """'display_bounding_boxes' is a utility function that displays bounding boxes on
+    an image.

     Parameters:
         image (np.ndarray): The image to display the bounding boxes on.
-        bboxes (List[Dict[str, Any]]): A list of dictionaries containing the bounding
+        bboxes (List[Dict[str, Any]]): A list of dictionaries containing the bounding
+            boxes.

     Returns:
-        np.ndarray: The image with the bounding boxes displayed.
+        np.ndarray: The image with the bounding boxes, labels and scores displayed.

     Example
     -------
-    >>> image_with_bboxes = display_bounding_boxes(
+    >>> image_with_bboxes = display_bounding_boxes(
+        image, [{'score': 0.99, 'label': 'dinosaur', 'bbox': [0.1, 0.11, 0.35, 0.4]}],
+    )
     """
     pil_image = Image.fromarray(image.astype(np.uint8))

+    if len(set([box["label"] for box in bboxes])) > len(COLORS):
+        _LOGGER.warning(
+            "Number of unique labels exceeds the number of available colors. Some labels may have the same color."
+        )
+
     color = {
         label: COLORS[i % len(COLORS)]
         for i, label in enumerate(set([box["label"] for box in bboxes]))
@@ -167,15 +334,109 @@ def display_bounding_boxes(
     return np.array(pil_image.convert("RGB"))


-def
+def overlay_segmentation_masks(
+    image: np.ndarray, masks: List[Dict[str, Any]]
+) -> np.ndarray:
+    """'display_segmentation_masks' is a utility function that displays segmentation
+    masks.
+
+    Parameters:
+        image (np.ndarray): The image to display the masks on.
+        masks (List[Dict[str, Any]]): A list of dictionaries containing the masks.
+
+    Returns:
+        np.ndarray: The image with the masks displayed.
+
+    Example
+    -------
+    >>> image_with_masks = display_segmentation_masks(
+        image,
+        [{
+            'score': 0.99,
+            'label': 'dinosaur',
+            'mask': array([[0, 0, 0, ..., 0, 0, 0],
+                [0, 0, 0, ..., 0, 0, 0],
+                ...,
+                [0, 0, 0, ..., 0, 0, 0],
+                [0, 0, 0, ..., 0, 0, 0]], dtype=uint8),
+        }],
+    )
+    """
+    pil_image = Image.fromarray(image.astype(np.uint8)).convert("RGBA")
+
+    if len(set([mask["label"] for mask in masks])) > len(COLORS):
+        _LOGGER.warning(
+            "Number of unique labels exceeds the number of available colors. Some labels may have the same color."
+        )
+
+    color = {
+        label: COLORS[i % len(COLORS)]
+        for i, label in enumerate(set([mask["label"] for mask in masks]))
+    }
+
+    for elt in masks:
+        mask = elt["mask"]
+        label = elt["label"]
+        np_mask = np.zeros((pil_image.size[1], pil_image.size[0], 4))
+        np_mask[mask > 0, :] = color[label] + (255 * 0.5,)
+        mask_img = Image.fromarray(np_mask.astype(np.uint8))
+        pil_image = Image.alpha_composite(pil_image, mask_img)
+    return np.array(pil_image.convert("RGB"))
+
+
+def get_tool_documentation(funcs: List[Callable[..., Any]]) -> str:
     docstrings = ""
     for func in funcs:
-        docstrings += f"{func.__name__}
+        docstrings += f"{func.__name__}{inspect.signature(func)}:\n{func.__doc__}\n\n"

     return docstrings


-
+def get_tool_descriptions(funcs: List[Callable[..., Any]]) -> str:
+    descriptions = ""
+    for func in funcs:
+        description = func.__doc__
+        if description is None:
+            description = ""
+
+        description = (
+            description[: description.find("Parameters:")].replace("\n", " ").strip()
+        )
+        description = " ".join(description.split())
+        descriptions += f"- {func.__name__}{inspect.signature(func)}: {description}\n"
+    return descriptions
+
+
+def get_tools_df(funcs: List[Callable[..., Any]]) -> pd.DataFrame:
+    data: Dict[str, List[str]] = {"desc": [], "doc": []}
+
+    for func in funcs:
+        desc = func.__doc__
+        if desc is None:
+            desc = ""
+        desc = desc[: desc.find("Parameters:")].replace("\n", " ").strip()
+        desc = " ".join(desc.split())
+
+        doc = f"{func.__name__}{inspect.signature(func)}:\n{func.__doc__}"
+        data["desc"].append(desc)
+        data["doc"].append(doc)
+
+    return pd.DataFrame(data)  # type: ignore
+
+
+TOOLS = [
+    grounding_dino,
+    grounding_sam,
+    extract_frames,
+    ocr,
+    load_image,
+    save_image,
+    overlay_bounding_boxes,
+    overlay_segmentation_masks,
+]
+TOOLS_DF = get_tools_df(TOOLS)  # type: ignore
+TOOL_DESCRIPTIONS = get_tool_descriptions(TOOLS)  # type: ignore
+TOOL_DOCSTRING = get_tool_documentation(TOOLS)  # type: ignore
 UTILITIES_DOCSTRING = get_tool_documentation(
-    [load_image, save_image,
+    [load_image, save_image, overlay_bounding_boxes]
 )
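The description helpers slice each docstring at "Parameters:" and collapse whitespace, so the one-line summaries in TOOL_DESCRIPTIONS and TOOLS_DF come straight from the first docstring paragraph. A small illustration of that trimming with a toy function that is not part of the package:

    # Toy illustration of the trimming done by get_tool_descriptions / get_tools_df:
    # everything before "Parameters:" is kept and whitespace-collapsed.
    import inspect

    def fake_tool(x: int) -> int:
        """'fake_tool' doubles a number.

        Parameters:
            x (int): the number to double.
        """
        return 2 * x

    doc = fake_tool.__doc__ or ""
    summary = doc[: doc.find("Parameters:")].replace("\n", " ").strip()
    summary = " ".join(summary.split())
    print(f"- fake_tool{inspect.signature(fake_tool)}: {summary}")
    # - fake_tool(x: int) -> int: 'fake_tool' doubles a number.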