vision-agent 0.1.4__py3-none-any.whl → 0.1.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vision_agent/agent/vision_agent.py +33 -15
- vision_agent/image_utils.py +96 -5
- vision_agent/llm/llm.py +4 -0
- vision_agent/lmm/lmm.py +4 -0
- vision_agent/tools/__init__.py +2 -0
- vision_agent/tools/tools.py +192 -40
- {vision_agent-0.1.4.dist-info → vision_agent-0.1.6.dist-info}/METADATA +5 -2
- {vision_agent-0.1.4.dist-info → vision_agent-0.1.6.dist-info}/RECORD +10 -10
- {vision_agent-0.1.4.dist-info → vision_agent-0.1.6.dist-info}/LICENSE +0 -0
- {vision_agent-0.1.4.dist-info → vision_agent-0.1.6.dist-info}/WHEEL +0 -0
vision_agent/agent/vision_agent.py
CHANGED
@@ -8,7 +8,7 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
 from PIL import Image
 from tabulate import tabulate
 
-from vision_agent.image_utils import overlay_bboxes, overlay_masks
+from vision_agent.image_utils import overlay_bboxes, overlay_masks, overlay_heat_map
 from vision_agent.llm import LLM, OpenAILLM
 from vision_agent.lmm import LMM, OpenAILMM
 from vision_agent.tools import TOOLS
@@ -33,6 +33,7 @@ from .vision_agent_prompts import (
 
 logging.basicConfig(stream=sys.stdout)
 _LOGGER = logging.getLogger(__name__)
+_MAX_TABULATE_COL_WIDTH = 80
 
 
 def parse_json(s: str) -> Any:
@@ -335,7 +336,9 @@ def _handle_viz_tools(
 
     for param, call_result in zip(parameters, tool_result["call_results"]):
         # calls can fail, so we need to check if the call was successful
-        if not isinstance(call_result, dict) or "bboxes" not in call_result:
+        if not isinstance(call_result, dict) or (
+            "bboxes" not in call_result and "masks" not in call_result
+        ):
             return image_to_data
 
         # if the call was successful, then we can add the image data
@@ -348,11 +351,12 @@ def _handle_viz_tools(
             "scores": [],
         }
 
-        image_to_data[image]["bboxes"].extend(call_result["bboxes"])
-        image_to_data[image]["labels"].extend(call_result["labels"])
-        image_to_data[image]["scores"].extend(call_result["scores"])
-
-
+        image_to_data[image]["bboxes"].extend(call_result.get("bboxes", []))
+        image_to_data[image]["labels"].extend(call_result.get("labels", []))
+        image_to_data[image]["scores"].extend(call_result.get("scores", []))
+        image_to_data[image]["masks"].extend(call_result.get("masks", []))
+        if "mask_shape" in call_result:
+            image_to_data[image]["mask_shape"] = call_result["mask_shape"]
 
     return image_to_data
 
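The switch from direct indexing to `dict.get` above lets a tool that returns only boxes, or only masks, merge into the shared accumulator without raising. A minimal sketch of that behavior (values are made up):

```python
# Detection-only result: no "masks" key, so the masks accumulator is untouched.
call_result = {"bboxes": [[0.1, 0.1, 0.4, 0.4]], "labels": ["cat"], "scores": [0.9]}

image_data = {"bboxes": [], "labels": [], "scores": [], "masks": []}
for key in ("bboxes", "labels", "scores", "masks"):
    # .get(key, []) turns a missing key into a no-op instead of a KeyError
    image_data[key].extend(call_result.get(key, []))

print(image_data["labels"])  # ['cat']
print(image_data["masks"])   # []
```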
@@ -366,6 +370,8 @@ def visualize_result(all_tool_results: List[Dict]) -> Sequence[Union[str, Path]]:
             "grounding_dino_",
             "extract_frames_",
             "dinov_",
+            "zero_shot_counting_",
+            "visual_prompt_counting_",
         ]:
             continue
 
@@ -378,8 +384,11 @@ def visualize_result(all_tool_results: List[Dict]) -> Sequence[Union[str, Path]]:
     for image_str in image_to_data:
         image_path = Path(image_str)
         image_data = image_to_data[image_str]
-        image = overlay_masks(image_path, image_data)
-        image = overlay_bboxes(image, image_data)
+        if "_counting_" in tool_result["tool_name"]:
+            image = overlay_heat_map(image_path, image_data)
+        else:
+            image = overlay_masks(image_path, image_data)
+            image = overlay_bboxes(image, image_data)
         with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
             image.save(f.name)
             visualized_images.append(f.name)
@@ -483,11 +492,21 @@ class VisionAgent(Agent):
         if image:
            question += f" Image name: {image}"
         if reference_data:
-            if not ("image" in reference_data and "mask" in reference_data):
+            if not (
+                "image" in reference_data
+                and ("mask" in reference_data or "bbox" in reference_data)
+            ):
                 raise ValueError(
-                    f"Reference data must contain 'image' and 'mask'. but got {reference_data}"
+                    f"Reference data must contain 'image' and a visual prompt which can be 'mask' or 'bbox'. but got {reference_data}"
                 )
-            question += f" Reference image: {reference_data['image']}, Reference mask: {reference_data['mask']}"
+            visual_prompt_data = (
+                f"Reference mask: {reference_data['mask']}"
+                if "mask" in reference_data
+                else f"Reference bbox: {reference_data['bbox']}"
+            )
+            question += (
+                f" Reference image: {reference_data['image']}, {visual_prompt_data}"
+            )
 
         reflections = ""
         final_answer = ""
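With the relaxed check above, `reference_data` accepts either prompt type. Illustrative values (file names are hypothetical):

```python
# Accepted: an image plus a mask prompt, or an image plus a bbox prompt.
ok_mask = {"image": "ref.jpg", "mask": "ref_mask.png"}
ok_bbox = {"image": "ref.jpg", "bbox": [0.1, 0.2, 0.4, 0.5]}

# Still rejected by the check above (no 'mask' or 'bbox' key):
# the agent method raises ValueError for this input.
missing_prompt = {"image": "ref.jpg"}
```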
@@ -530,7 +549,6 @@ class VisionAgent(Agent):
             final_answer = answer_summarize(
                 self.answer_model, question, answers, reflections
             )
-
             visualized_output = visualize_result(all_tool_results)
             all_tool_results.append({"visualized_output": visualized_output})
             if len(visualized_output) > 0:
@@ -614,7 +632,7 @@ class VisionAgent(Agent):
 
         self.log_progress(
             f"""Going to run the following tool(s) in sequence:
-{tabulate([tool_results], headers="keys", tablefmt="mixed_grid")}"""
+{tabulate(tabular_data=[tool_results], headers="keys", tablefmt="mixed_grid", maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"""
         )
 
         def parse_tool_results(result: Dict[str, Union[Dict, List]]) -> Any:
@@ -660,6 +678,6 @@ class VisionAgent(Agent):
         task_list = []
         self.log_progress(
             f"""Planned tasks:
-{tabulate(task_list, headers="keys", tablefmt="mixed_grid")}"""
+{tabulate(task_list, headers="keys", tablefmt="mixed_grid", maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"""
         )
         return task_list
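For context, `maxcolwidths` (available in tabulate 0.9+) wraps long cells instead of letting one wide value stretch the whole log table. A small sketch with made-up rows:

```python
from tabulate import tabulate

rows = [{"task": "Count the number of jars on the top shelf of the pantry",
         "tool": "zero_shot_counting_"}]

# Without maxcolwidths the first column renders as one very long line;
# capping it wraps the cell, which is what the agent log now does via
# _MAX_TABULATE_COL_WIDTH = 80.
print(tabulate(rows, headers="keys", tablefmt="mixed_grid", maxcolwidths=80))
```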
vision_agent/image_utils.py
CHANGED
@@ -4,7 +4,7 @@ import base64
 from importlib import resources
 from io import BytesIO
 from pathlib import Path
-from typing import Dict, Tuple, Union
+from typing import Dict, Tuple, Union, List
 
 import numpy as np
 from PIL import Image, ImageDraw, ImageFont
@@ -34,6 +34,35 @@ COLORS = [
 ]
 
 
+def normalize_bbox(
+    bbox: List[Union[int, float]], image_size: Tuple[int, ...]
+) -> List[float]:
+    r"""Normalize the bounding box coordinates to be between 0 and 1."""
+    x1, y1, x2, y2 = bbox
+    x1 = round(x1 / image_size[1], 2)
+    y1 = round(y1 / image_size[0], 2)
+    x2 = round(x2 / image_size[1], 2)
+    y2 = round(y2 / image_size[0], 2)
+    return [x1, y1, x2, y2]
+
+
+def rle_decode(mask_rle: str, shape: Tuple[int, int]) -> np.ndarray:
+    r"""Decode a run-length encoded mask. Returns numpy array, 1 - mask, 0 - background.
+
+    Parameters:
+        mask_rle: Run-length as string formated (start length)
+        shape: The (height, width) of array to return
+    """
+    s = mask_rle.split()
+    starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])]
+    starts -= 1
+    ends = starts + lengths
+    img = np.zeros(shape[0] * shape[1], dtype=np.uint8)
+    for lo, hi in zip(starts, ends):
+        img[lo:hi] = 1
+    return img.reshape(shape)
+
+
 def b64_to_pil(b64_str: str) -> ImageType:
     r"""Convert a base64 string to a PIL Image.
 
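To make the moved helpers concrete, here is a tiny round trip with toy values (after this release both functions are importable from `vision_agent.image_utils`):

```python
from vision_agent.image_utils import normalize_bbox, rle_decode

# RLE "2 3" means: starting at 1-indexed flat pixel 2, the next 3 pixels are mask.
print(rle_decode(mask_rle="2 3", shape=(2, 3)))
# [[0 1 1]
#  [1 0 0]]

# image_size is (height, width): x coordinates divide by width, y by height.
print(normalize_bbox([50, 20, 150, 80], image_size=(100, 200)))
# [0.25, 0.2, 0.75, 0.8]
```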
@@ -86,6 +115,26 @@ def convert_to_b64(data: Union[str, Path, np.ndarray, ImageType]) -> str:
     return base64.b64encode(arr_bytes).decode("utf-8")
 
 
+def denormalize_bbox(
+    bbox: List[Union[int, float]], image_size: Tuple[int, ...]
+) -> List[float]:
+    r"""DeNormalize the bounding box coordinates so that they are in absolute values."""
+
+    if len(bbox) != 4:
+        raise ValueError("Bounding box must be of length 4.")
+
+    arr = np.array(bbox)
+    if np.all((arr >= 0) & (arr <= 1)):
+        x1, y1, x2, y2 = bbox
+        x1 = round(x1 * image_size[1])
+        y1 = round(y1 * image_size[0])
+        x2 = round(x2 * image_size[1])
+        y2 = round(y2 * image_size[0])
+        return [x1, y1, x2, y2]
+    else:
+        return bbox
+
+
 def overlay_bboxes(
     image: Union[str, Path, np.ndarray, ImageType], bboxes: Dict
 ) -> ImageType:
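`denormalize_bbox` is the inverse, with a heuristic: a box whose four coordinates all lie in [0, 1] is treated as normalized and scaled up; anything else is assumed to already be in pixels and passed through. A quick sketch:

```python
from vision_agent.image_utils import denormalize_bbox, normalize_bbox

size = (480, 640)  # (height, width)
norm = normalize_bbox([64, 48, 320, 240], size)
print(norm)                          # [0.1, 0.1, 0.5, 0.5]
print(denormalize_bbox(norm, size))  # [64, 48, 320, 240]

# Coordinates outside [0, 1] are assumed to be absolute already:
print(denormalize_bbox([64, 48, 320, 240], size))  # [64, 48, 320, 240]
```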
@@ -103,6 +152,9 @@ def overlay_bboxes(
     elif isinstance(image, np.ndarray):
         image = Image.fromarray(image)
 
+    if "bboxes" not in bboxes:
+        return image.convert("RGB")
+
     color = {
         label: COLORS[i % len(COLORS)] for i, label in enumerate(set(bboxes["labels"]))
     }
@@ -114,8 +166,6 @@ def overlay_bboxes(
         str(resources.files("vision_agent.fonts").joinpath("default_font_ch_en.ttf")),
         fontsize,
     )
-    if "bboxes" not in bboxes:
-        return image.convert("RGB")
 
     for label, box, scores in zip(bboxes["labels"], bboxes["bboxes"], bboxes["scores"]):
         box = [
@@ -150,11 +200,15 @@ def overlay_masks(
     elif isinstance(image, np.ndarray):
         image = Image.fromarray(image)
 
+    if "masks" not in masks:
+        return image.convert("RGB")
+
+    if "labels" not in masks:
+        masks["labels"] = [""] * len(masks["masks"])
+
     color = {
         label: COLORS[i % len(COLORS)] for i, label in enumerate(set(masks["labels"]))
     }
-    if "masks" not in masks:
-        return image.convert("RGB")
 
     for label, mask in zip(masks["labels"], masks["masks"]):
         if isinstance(mask, str):
@@ -164,3 +218,40 @@ def overlay_masks(
         mask_img = Image.fromarray(np_mask.astype(np.uint8))
         image = Image.alpha_composite(image.convert("RGBA"), mask_img)
     return image.convert("RGB")
+
+
+def overlay_heat_map(
+    image: Union[str, Path, np.ndarray, ImageType], masks: Dict, alpha: float = 0.8
+) -> ImageType:
+    r"""Plots heat map on to an image.
+
+    Parameters:
+        image: the input image
+        masks: the heatmap to overlay
+        alpha: the transparency of the overlay
+
+    Returns:
+        The image with the heatmap overlayed
+    """
+    if isinstance(image, (str, Path)):
+        image = Image.open(image)
+    elif isinstance(image, np.ndarray):
+        image = Image.fromarray(image)
+
+    if "masks" not in masks:
+        return image.convert("RGB")
+
+    # Only one heat map per image, so no need to loop through masks
+    image = image.convert("L")
+
+    if isinstance(masks["masks"][0], str):
+        mask = b64_to_pil(masks["masks"][0])
+
+    overlay = Image.new("RGBA", mask.size)
+    odraw = ImageDraw.Draw(overlay)
+    odraw.bitmap(
+        (0, 0), mask, fill=(255, 0, 0, round(alpha * 255))
+    )  # fill=(R, G, B, Alpha)
+    combined = Image.alpha_composite(image.convert("RGBA"), overlay.resize(image.size))
+
+    return combined.convert("RGB")
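A hedged usage sketch for `overlay_heat_map`: the `masks` dict mirrors the `{"masks": [...]}` shape the counting tools return, holding a single base64-encoded grayscale heat map. The photo path and heat values below are made up:

```python
import numpy as np
from PIL import Image
from vision_agent.image_utils import convert_to_b64, overlay_heat_map

# Fake heat map: brighter pixels mark likely object locations.
heat = (np.random.rand(64, 64) * 255).astype(np.uint8)
mask_b64 = convert_to_b64(Image.fromarray(heat, mode="L"))

# Per the implementation above, the base image is converted to grayscale
# and the map is drawn in red with the given alpha, then composited back.
out = overlay_heat_map("photo.jpg", {"masks": [mask_b64]}, alpha=0.5)
out.save("photo_heat.png")
```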
vision_agent/llm/llm.py
CHANGED
@@ -11,6 +11,7 @@ from vision_agent.tools import (
     SYSTEM_PROMPT,
     GroundingDINO,
     GroundingSAM,
+    ZeroShotCounting,
 )
 
 
@@ -127,6 +128,9 @@ class OpenAILLM(LLM):
 
         return lambda x: GroundingSAM()(**{"prompt": params["prompt"], "image": x})
 
+    def generate_zero_shot_counter(self, question: str) -> Callable:
+        return lambda x: ZeroShotCounting()(**{"image": x})
+
 
 class AzureOpenAILLM(OpenAILLM):
     def __init__(
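A sketch of the new convenience method (needs an OpenAI key and a Landing AI key at runtime; the image name is made up):

```python
from vision_agent.llm import OpenAILLM

llm = OpenAILLM()
# The question is accepted for interface symmetry with the other generators;
# the returned callable simply runs ZeroShotCounting on whatever image it gets.
counter = llm.generate_zero_shot_counter("how many bottles are on the shelf?")
print(counter("bottles.jpg"))  # e.g. {'count': 12}
```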
vision_agent/lmm/lmm.py
CHANGED
@@ -15,6 +15,7 @@ from vision_agent.tools import (
     SYSTEM_PROMPT,
     GroundingDINO,
     GroundingSAM,
+    ZeroShotCounting,
 )
 
 _LOGGER = logging.getLogger(__name__)
@@ -272,6 +273,9 @@ class OpenAILMM(LMM):
 
         return lambda x: GroundingSAM()(**{"prompt": params["prompt"], "image": x})
 
+    def generate_zero_shot_counter(self, question: str) -> Callable:
+        return lambda x: ZeroShotCounting()(**{"image": x})
+
 
 class AzureOpenAILMM(OpenAILMM):
     def __init__(
vision_agent/tools/__init__.py
CHANGED
vision_agent/tools/tools.py
CHANGED
@@ -9,7 +9,13 @@ import requests
 from PIL import Image
 from PIL.Image import Image as ImageType
 
-from vision_agent.image_utils import convert_to_b64, get_image_size
+from vision_agent.image_utils import (
+    convert_to_b64,
+    get_image_size,
+    rle_decode,
+    normalize_bbox,
+    denormalize_bbox,
+)
 from vision_agent.tools.video import extract_frames_from_video
 from vision_agent.type_defs import LandingaiAPIKey
 
@@ -18,35 +24,6 @@ _LND_API_KEY = LandingaiAPIKey().api_key
 _LND_API_URL = "https://api.dev.landing.ai/v1/agent"
 
 
-def normalize_bbox(
-    bbox: List[Union[int, float]], image_size: Tuple[int, ...]
-) -> List[float]:
-    r"""Normalize the bounding box coordinates to be between 0 and 1."""
-    x1, y1, x2, y2 = bbox
-    x1 = round(x1 / image_size[1], 2)
-    y1 = round(y1 / image_size[0], 2)
-    x2 = round(x2 / image_size[1], 2)
-    y2 = round(y2 / image_size[0], 2)
-    return [x1, y1, x2, y2]
-
-
-def rle_decode(mask_rle: str, shape: Tuple[int, int]) -> np.ndarray:
-    r"""Decode a run-length encoded mask. Returns numpy array, 1 - mask, 0 - background.
-
-    Parameters:
-        mask_rle: Run-length as string formated (start length)
-        shape: The (height, width) of array to return
-    """
-    s = mask_rle.split()
-    starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])]
-    starts -= 1
-    ends = starts + lengths
-    img = np.zeros(shape[0] * shape[1], dtype=np.uint8)
-    for lo, hi in zip(starts, ends):
-        img[lo:hi] = 1
-    return img.reshape(shape)
-
-
 class Tool(ABC):
     name: str
     description: str
@@ -250,7 +227,7 @@ class GroundingDINO(Tool):
             iou_threshold: the threshold for intersection over union used in nms algorithm. It will suppress the boxes which have iou greater than this threshold.
 
         Returns:
-            A
+            A dictionary containing the labels, scores, and bboxes, which is the detection result for the input image.
         """
         image_size = get_image_size(image)
         image_b64 = convert_to_b64(image)
@@ -346,7 +323,7 @@ class GroundingSAM(Tool):
             iou_threshold: the threshold for intersection over union used in nms algorithm. It will suppress the boxes which have iou greater than this threshold.
 
         Returns:
-            A
+            A dictionary containing the labels, scores, bboxes and masks for the input image.
         """
         image_size = get_image_size(image)
         image_b64 = convert_to_b64(image)
@@ -357,19 +334,15 @@ class GroundingSAM(Tool):
             "kwargs": {"box_threshold": box_threshold, "iou_threshold": iou_threshold},
         }
         data: Dict[str, Any] = _send_inference_request(request_data, "tools")
-        ret_pred: Dict[str, List] = {"labels": [], "bboxes": [], "masks": []}
         if "bboxes" in data:
-            ret_pred["bboxes"] = [
-                normalize_bbox(box, image_size) for box in data["bboxes"]
-            ]
+            data["bboxes"] = [normalize_bbox(box, image_size) for box in data["bboxes"]]
         if "masks" in data:
-            ret_pred["masks"] = [
+            data["masks"] = [
                 rle_decode(mask_rle=mask, shape=data["mask_shape"])
                 for mask in data["masks"]
             ]
-
-
-        return ret_pred
+        data.pop("mask_shape", None)
+        return data
 
 
 class DINOv(Tool):
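The effect of returning the service payload `data` directly, sketched against a fabricated response (real masks are full image size; `mask_shape` is consumed by `rle_decode` and then dropped):

```python
from vision_agent.image_utils import normalize_bbox, rle_decode

data = {  # fabricated stand-in for the inference service response
    "labels": ["jar"],
    "scores": [0.88],
    "bboxes": [[50, 20, 150, 80]],
    "masks": ["2 3"],
    "mask_shape": (2, 3),
}
image_size = (100, 200)  # (height, width)

if "bboxes" in data:
    data["bboxes"] = [normalize_bbox(box, image_size) for box in data["bboxes"]]
if "masks" in data:
    data["masks"] = [rle_decode(mask_rle=m, shape=data["mask_shape"]) for m in data["masks"]]
data.pop("mask_shape", None)

print(data["bboxes"])          # [[0.25, 0.2, 0.75, 0.8]]
print(data["masks"][0].shape)  # (2, 3)
```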
@@ -493,6 +466,130 @@ class AgentGroundingSAM(GroundingSAM):
         return rets
 
 
+class ZeroShotCounting(Tool):
+    r"""ZeroShotCounting is a tool that can count total number of instances of an object
+    present in an image belonging to same class without a text or visual prompt.
+
+    Example
+    -------
+        >>> import vision_agent as va
+        >>> zshot_count = va.tools.ZeroShotCounting()
+        >>> zshot_count("image1.jpg")
+        {'count': 45}
+    """
+
+    name = "zero_shot_counting_"
+    description = "'zero_shot_counting_' is a tool that counts and returns the total number of instances of an object present in an image belonging to the same class without a text or visual prompt."
+
+    usage = {
+        "required_parameters": [
+            {"name": "image", "type": "str"},
+        ],
+        "examples": [
+            {
+                "scenario": "Can you count the lids in the image ? Image name: lids.jpg",
+                "parameters": {"image": "lids.jpg"},
+            },
+            {
+                "scenario": "Can you count the total number of objects in this image ? Image name: tray.jpg",
+                "parameters": {"image": "tray.jpg"},
+            },
+            {
+                "scenario": "Can you build me an object counting tool ? Image name: shirts.jpg",
+                "parameters": {
+                    "image": "shirts.jpg",
+                },
+            },
+        ],
+    }
+
+    # TODO: Add support for input multiple images, which aligns with the output type.
+    def __call__(self, image: Union[str, ImageType]) -> Dict:
+        """Invoke the Image captioning model.
+
+        Parameters:
+            image: the input image.
+
+        Returns:
+            A dictionary containing the key 'count' and the count as value. E.g. {count: 12}
+        """
+        image_b64 = convert_to_b64(image)
+        data = {
+            "image": image_b64,
+            "tool": "zero_shot_counting",
+        }
+        return _send_inference_request(data, "tools")
+
+
+class VisualPromptCounting(Tool):
+    r"""VisualPromptCounting is a tool that can count total number of instances of an object
+    present in an image belonging to same class with help of an visual prompt which is a bounding box.
+
+    Example
+    -------
+        >>> import vision_agent as va
+        >>> prompt_count = va.tools.VisualPromptCounting()
+        >>> prompt_count(image="image1.jpg", prompt="0.1, 0.1, 0.4, 0.42")
+        {'count': 23}
+    """
+
+    name = "visual_prompt_counting_"
+    description = "'visual_prompt_counting_' is a tool that can count and return total number of instances of an object present in an image belonging to the same class given an example bounding box."
+
+    usage = {
+        "required_parameters": [
+            {"name": "image", "type": "str"},
+            {"name": "prompt", "type": "str"},
+        ],
+        "examples": [
+            {
+                "scenario": "Here is an example of a lid '0.1, 0.1, 0.14, 0.2', Can you count the lids in the image ? Image name: lids.jpg",
+                "parameters": {"image": "lids.jpg", "prompt": "0.1, 0.1, 0.14, 0.2"},
+            },
+            {
+                "scenario": "Can you count the total number of objects in this image ? Image name: tray.jpg",
+                "parameters": {"image": "tray.jpg", "prompt": "0.1, 0.1, 0.2, 0.25"},
+            },
+            {
+                "scenario": "Can you build me a few shot object counting tool ? Image name: shirts.jpg",
+                "parameters": {
+                    "image": "shirts.jpg",
+                    "prompt": "0.1, 0.15, 0.2, 0.2",
+                },
+            },
+            {
+                "scenario": "Can you build me a counting tool based on an example prompt ? Image name: shoes.jpg",
+                "parameters": {
+                    "image": "shoes.jpg",
+                    "prompt": "0.1, 0.1, 0.6, 0.65",
+                },
+            },
+        ],
+    }
+
+    # TODO: Add support for input multiple images, which aligns with the output type.
+    def __call__(self, image: Union[str, ImageType], prompt: str) -> Dict:
+        """Invoke the Image captioning model.
+
+        Parameters:
+            image: the input image.
+
+        Returns:
+            A dictionary containing the key 'count' and the count as value. E.g. {count: 12}
+        """
+        image_size = get_image_size(image)
+        bbox = [float(x) for x in prompt.split(",")]
+        prompt = ", ".join(map(str, denormalize_bbox(bbox, image_size)))
+        image_b64 = convert_to_b64(image)
+
+        data = {
+            "image": image_b64,
+            "prompt": prompt,
+            "tool": "few_shot_counting",
+        }
+        return _send_inference_request(data, "tools")
+
+
 class Crop(Tool):
     r"""Crop crops an image given a bounding box and returns a file name of the cropped image."""
 
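Usage sketch for the two new counting tools, following their docstrings (image names are placeholders; both calls hit the Landing AI tools endpoint, so a valid API key is required):

```python
import vision_agent as va

# No prompt at all: the model counts the dominant object class on its own.
zshot = va.tools.ZeroShotCounting()
print(zshot(image="tray.jpg"))  # e.g. {'count': 45}

# One example box, given as a normalized "x1, y1, x2, y2" string; it is
# denormalized against the image size before being sent to the service.
prompt_count = va.tools.VisualPromptCounting()
print(prompt_count(image="tray.jpg", prompt="0.1, 0.1, 0.4, 0.42"))  # e.g. {'count': 23}
```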
@@ -643,6 +740,58 @@ class SegIoU(Tool):
         return cast(float, round(iou, 2))
 
 
+class BboxContains(Tool):
+    name = "bbox_contains_"
+    description = "Given two bounding boxes, a target bounding box and a region bounding box, 'bbox_contains_' returns the intersection of the two bounding boxes over the target bounding box, reflects the percentage area of the target bounding box overlaps with the region bounding box. This is a good tool for determining if the region object contains the target object."
+    usage = {
+        "required_parameters": [
+            {"name": "target", "type": "List[int]"},
+            {"name": "target_class", "type": "str"},
+            {"name": "region", "type": "List[int]"},
+            {"name": "region_class", "type": "str"},
+        ],
+        "examples": [
+            {
+                "scenario": "Determine if the dog on the couch, bounding box of the dog: [0.2, 0.21, 0.34, 0.42], bounding box of the couch: [0.3, 0.31, 0.44, 0.52]",
+                "parameters": {
+                    "target": [0.2, 0.21, 0.34, 0.42],
+                    "target_class": "dog",
+                    "region": [0.3, 0.31, 0.44, 0.52],
+                    "region_class": "couch",
+                },
+            },
+            {
+                "scenario": "Check if the kid is in the pool? bounding box of the kid: [0.2, 0.21, 0.34, 0.42], bounding box of the pool: [0.3, 0.31, 0.44, 0.52]",
+                "parameters": {
+                    "target": [0.2, 0.21, 0.34, 0.42],
+                    "target_class": "kid",
+                    "region": [0.3, 0.31, 0.44, 0.52],
+                    "region_class": "pool",
+                },
+            },
+        ],
+    }
+
+    def __call__(
+        self, target: List[int], target_class: str, region: List[int], region_class: str
+    ) -> Dict[str, Union[str, float]]:
+        x1, y1, x2, y2 = target
+        x3, y3, x4, y4 = region
+        xA = max(x1, x3)
+        yA = max(y1, y3)
+        xB = min(x2, x4)
+        yB = min(y2, y4)
+        inter_area = max(0, xB - xA) * max(0, yB - yA)
+        boxa_area = (x2 - x1) * (y2 - y1)
+        iou = inter_area / float(boxa_area)
+        area = round(iou, 2)
+        return {
+            "target_class": target_class,
+            "region_class": region_class,
+            "intersection": area,
+        }
+
+
 class BoxDistance(Tool):
     name = "box_distance_"
     description = (
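Despite the `iou` variable name, `BboxContains` computes intersection over the *target* area, not IoU. Working through the dog-on-couch example from the usage block:

```python
target = [0.2, 0.21, 0.34, 0.42]  # dog
region = [0.3, 0.31, 0.44, 0.52]  # couch

inter_w = max(0, min(0.34, 0.44) - max(0.2, 0.3))    # 0.04
inter_h = max(0, min(0.42, 0.52) - max(0.21, 0.31))  # 0.11
target_area = (0.34 - 0.2) * (0.42 - 0.21)           # 0.0294

# intersection / target area, not intersection / union
print(round(inter_w * inter_h / target_area, 2))  # 0.15 -> 15% of the dog is inside the couch box
```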
@@ -750,6 +899,8 @@ TOOLS = {
             ImageCaption,
             GroundingDINO,
             AgentGroundingSAM,
+            ZeroShotCounting,
+            VisualPromptCounting,
             AgentDINOv,
             ExtractFrames,
             Crop,
@@ -757,6 +908,7 @@ TOOLS = {
             SegArea,
             BboxIoU,
             SegIoU,
+            BboxContains,
             BoxDistance,
             Calculator,
         ]
{vision_agent-0.1.4.dist-info → vision_agent-0.1.6.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: vision-agent
-Version: 0.1.4
+Version: 0.1.6
 Summary: Toolset for Vision Agent
 Author: Landing AI
 Author-email: dev@landing.ai
@@ -41,7 +41,7 @@ Description-Content-Type: text/markdown
 
 Vision Agent is a library that helps you utilize agent frameworks for your vision tasks.
 Many current vision problems can easily take hours or days to solve, you need to find the
-right model, figure out how to use it, possibly write programming logic around it to 
+right model, figure out how to use it, possibly write programming logic around it to
 accomplish the task you want or even more expensive, train your own model. Vision Agent
 aims to provide an in-seconds experience by allowing users to describe their problem in
 text and utilizing agent frameworks to solve the task for them. Check out our discord
@@ -138,6 +138,9 @@ you. For example:
 | BboxIoU | BboxIoU returns the intersection over union of two bounding boxes normalized to 2 decimal places. |
 | SegIoU | SegIoU returns the intersection over union of two segmentation masks normalized to 2 decimal places. |
 | ExtractFrames | ExtractFrames extracts frames with motion from a video. |
+| ExtractFrames | ExtractFrames extracts frames with motion from a video. |
+| ZeroShotCounting | ZeroShotCounting returns the total number of objects belonging to a single class in a given image |
+| VisualPromptCounting | VisualPromptCounting returns the total number of objects belonging to a single class given an image and visual prompt |
 
 
 It also has a basic set of calculate tools such as add, subtract, multiply and divide.
{vision_agent-0.1.4.dist-info → vision_agent-0.1.6.dist-info}/RECORD
CHANGED
@@ -5,7 +5,7 @@ vision_agent/agent/easytool.py,sha256=oMHnBg7YBtIPgqQUNcZgq7uMgpPThs99_UnO7ERkMV
 vision_agent/agent/easytool_prompts.py,sha256=zdQQw6WpXOmvwOMtlBlNKY5a3WNlr65dbUvMIGiqdeo,4526
 vision_agent/agent/reflexion.py,sha256=4gz30BuFMeGxSsTzoDV4p91yE0R8LISXp28IaOI6wdM,10506
 vision_agent/agent/reflexion_prompts.py,sha256=G7UAeNz_g2qCb2yN6OaIC7bQVUkda4m3z42EG8wAyfE,9342
-vision_agent/agent/vision_agent.py,sha256=
+vision_agent/agent/vision_agent.py,sha256=MTxeV5_Sghqoe2aOW9EbNgiq61sVCcF3ZndJ7BZl6x0,23588
 vision_agent/agent/vision_agent_prompts.py,sha256=W3Z72FpUt71UIJSkjAcgtQqxeMqkYuATqHAN5fYY26c,7342
 vision_agent/data/__init__.py,sha256=YU-5g3LbEQ6a4drz0RLGTagXMVU2Z4Xr3RlfWE-R0jU,46
 vision_agent/data/data.py,sha256=Z2l76OrT0GgyuN52OeJqDitUcP0q1rhfdXd1of3GsVo,5128
@@ -13,17 +13,17 @@ vision_agent/emb/__init__.py,sha256=YmCkGrJBtXb6X6Z3lnKiFoQYKXMgHMJp8JJyMLVvqcI,
 vision_agent/emb/emb.py,sha256=la9lhEzk7jqUCjYYQ5oRgVNSnC9_EJBJIpE_B9c6PJo,1375
 vision_agent/fonts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 vision_agent/fonts/default_font_ch_en.ttf,sha256=1YM0Z3XqLDjSNbF7ihQFSAIUdjF9m1rtHiNC_6QosTE,1594400
-vision_agent/image_utils.py,sha256=
+vision_agent/image_utils.py,sha256=Cg4aKO1tQiETT1gdsZ50XzORBtJnBFfMG2cKJyjaY6Q,7555
 vision_agent/llm/__init__.py,sha256=BoUm_zSAKnLlE8s-gKTSQugXDqVZKPqYlWwlTLdhcz4,48
-vision_agent/llm/llm.py,sha256=
+vision_agent/llm/llm.py,sha256=gwDQ9-p9wEn24xi1019e5jzTGQg4xWDSqBCsqIqGcU4,5168
 vision_agent/lmm/__init__.py,sha256=nnNeKD1k7q_4vLb1x51O_EUTYaBgGfeiCx5F433gr3M,67
-vision_agent/lmm/lmm.py,sha256=
-vision_agent/tools/__init__.py,sha256=
+vision_agent/lmm/lmm.py,sha256=FjxCuIk0KXuWnfY4orVmdyhJW2I4C6i5QNNEXk7gybk,10197
+vision_agent/tools/__init__.py,sha256=BlfxqbYkB0oODhnSmQg1UyzQm73AvvjCjrIiOWBIYDs,328
 vision_agent/tools/prompts.py,sha256=V1z4YJLXZuUl_iZ5rY0M5hHc_2tmMEUKr0WocXKGt4E,1430
-vision_agent/tools/tools.py,sha256=
+vision_agent/tools/tools.py,sha256=gCjHs5vJuGNBFsnJWFT7PX3wTyfHgtrgX1Eq9vqknN0,34979
 vision_agent/tools/video.py,sha256=xTElFSFp1Jw4ulOMnk81Vxsh-9dTxcWUO6P9fzEi3AM,7653
 vision_agent/type_defs.py,sha256=4LTnTL4HNsfYqCrDn9Ppjg9bSG2ZGcoKSSd9YeQf4Bw,1792
-vision_agent-0.1.4.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
-vision_agent-0.1.4.dist-info/METADATA,sha256=
-vision_agent-0.1.4.dist-info/WHEEL,sha256=7Z8_27uaHI_UZAc4Uox4PpBhQ9Y5_modZXWMxtUi4NU,88
-vision_agent-0.1.4.dist-info/RECORD,,
+vision_agent-0.1.6.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+vision_agent-0.1.6.dist-info/METADATA,sha256=Ig2tSKyeH8a2A8xZRq72M9XnKyi4_03UM4hDiFpT-eU,6574
+vision_agent-0.1.6.dist-info/WHEEL,sha256=7Z8_27uaHI_UZAc4Uox4PpBhQ9Y5_modZXWMxtUi4NU,88
+vision_agent-0.1.6.dist-info/RECORD,,
{vision_agent-0.1.4.dist-info → vision_agent-0.1.6.dist-info}/LICENSE
File without changes
{vision_agent-0.1.4.dist-info → vision_agent-0.1.6.dist-info}/WHEEL
File without changes