vision-agent 0.1.3__py3-none-any.whl → 0.1.5__py3-none-any.whl
This diff compares the contents of publicly released package versions as published to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
- vision_agent/agent/vision_agent.py +23 -4
- vision_agent/image_utils.py +3 -1
- vision_agent/tools/__init__.py +1 -0
- vision_agent/tools/tools.py +157 -9
- {vision_agent-0.1.3.dist-info → vision_agent-0.1.5.dist-info}/METADATA +1 -1
- {vision_agent-0.1.3.dist-info → vision_agent-0.1.5.dist-info}/RECORD +8 -8
- {vision_agent-0.1.3.dist-info → vision_agent-0.1.5.dist-info}/LICENSE +0 -0
- {vision_agent-0.1.3.dist-info → vision_agent-0.1.5.dist-info}/WHEEL +0 -0
vision_agent/agent/vision_agent.py
CHANGED
@@ -33,6 +33,7 @@ from .vision_agent_prompts import (
 
 logging.basicConfig(stream=sys.stdout)
 _LOGGER = logging.getLogger(__name__)
+_MAX_TABULATE_COL_WIDTH = 80
 
 
 def parse_json(s: str) -> Any:
@@ -365,6 +366,7 @@ def visualize_result(all_tool_results: List[Dict]) -> Sequence[Union[str, Path]]
             "grounding_sam_",
             "grounding_dino_",
             "extract_frames_",
+            "dinov_",
         ]:
             continue
 
@@ -444,6 +446,7 @@ class VisionAgent(Agent):
         self,
         input: Union[List[Dict[str, str]], str],
         image: Optional[Union[str, Path]] = None,
+        reference_data: Optional[Dict[str, str]] = None,
         visualize_output: Optional[bool] = False,
     ) -> str:
         """Invoke the vision agent.
@@ -458,7 +461,12 @@ class VisionAgent(Agent):
         """
         if isinstance(input, str):
             input = [{"role": "user", "content": input}]
-        return self.chat(
+        return self.chat(
+            input,
+            image=image,
+            visualize_output=visualize_output,
+            reference_data=reference_data,
+        )
 
     def log_progress(self, description: str) -> None:
         _LOGGER.info(description)
@@ -469,11 +477,18 @@ class VisionAgent(Agent):
         self,
         chat: List[Dict[str, str]],
         image: Optional[Union[str, Path]] = None,
+        reference_data: Optional[Dict[str, str]] = None,
         visualize_output: Optional[bool] = False,
     ) -> Tuple[str, List[Dict]]:
         question = chat[0]["content"]
         if image:
             question += f" Image name: {image}"
+        if reference_data:
+            if not ("image" in reference_data and "mask" in reference_data):
+                raise ValueError(
+                    f"Reference data must contain 'image' and 'mask'. but got {reference_data}"
+                )
+            question += f" Reference image: {reference_data['image']}, Reference mask: {reference_data['mask']}"
 
         reflections = ""
         final_answer = ""
@@ -555,10 +570,14 @@ class VisionAgent(Agent):
         self,
         chat: List[Dict[str, str]],
         image: Optional[Union[str, Path]] = None,
+        reference_data: Optional[Dict[str, str]] = None,
         visualize_output: Optional[bool] = False,
     ) -> str:
         answer, _ = self.chat_with_workflow(
-            chat,
+            chat,
+            image=image,
+            visualize_output=visualize_output,
+            reference_data=reference_data,
         )
         return answer
 
@@ -596,7 +615,7 @@ class VisionAgent(Agent):
 
         self.log_progress(
             f"""Going to run the following tool(s) in sequence:
-{tabulate([tool_results], headers="keys", tablefmt="mixed_grid")}"""
+{tabulate(tabular_data=[tool_results], headers="keys", tablefmt="mixed_grid", maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"""
        )
 
         def parse_tool_results(result: Dict[str, Union[Dict, List]]) -> Any:
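The maxcolwidths argument added above (and applied again in the next hunk) caps each column at _MAX_TABULATE_COL_WIDTH characters, so verbose tool outputs wrap instead of stretching the logged table. A minimal sketch of the effect, assuming a tabulate release that supports maxcolwidths and the mixed_grid format (0.9.x):

from tabulate import tabulate

# One row with an oversized cell, similar to a verbose tool result.
rows = [{"tool": "grounding_sam_", "output": "x" * 200}]

# Passing a single int applies the cap to every column; the
# 200-character cell wraps at 80 characters per line.
print(tabulate(rows, headers="keys", tablefmt="mixed_grid", maxcolwidths=80))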
@@ -642,6 +661,6 @@ class VisionAgent(Agent):
         task_list = []
         self.log_progress(
             f"""Planned tasks:
-{tabulate(task_list, headers="keys", tablefmt="mixed_grid")}"""
+{tabulate(task_list, headers="keys", tablefmt="mixed_grid", maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"""
         )
         return task_list
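Taken together, these hunks thread an optional reference_data dict from __call__ through chat into chat_with_workflow, appending the reference image and mask paths to the question. A hedged usage sketch: the file names are placeholders, and the import path is an assumption based on the package layout shown in RECORD below.

from vision_agent.agent import VisionAgent

agent = VisionAgent()

# reference_data must supply both 'image' and 'mask' keys, otherwise
# chat_with_workflow raises the ValueError shown in the hunk above.
answer = agent(
    "Find all balloons similar to the masked one.",  # hypothetical prompt
    image="input.jpg",                               # hypothetical file
    reference_data={"image": "balloon.jpg", "mask": "balloon_mask.jpg"},
)
print(answer)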
vision_agent/image_utils.py
CHANGED
@@ -103,7 +103,9 @@ def overlay_bboxes(
     elif isinstance(image, np.ndarray):
         image = Image.fromarray(image)
 
-    color = {
+    color = {
+        label: COLORS[i % len(COLORS)] for i, label in enumerate(set(bboxes["labels"]))
+    }
 
     width, height = image.size
     fontsize = max(12, int(min(width, height) / 40))
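The reflowed comprehension assigns each distinct label its own palette color, cycling through COLORS with the modulo when labels outnumber colors. A minimal sketch of the same pattern with a stand-in palette (COLORS below is illustrative, not the package's actual palette):

# Stand-in palette; vision_agent defines its own COLORS list.
COLORS = [(255, 0, 0), (0, 255, 0), (0, 0, 255)]

labels = ["dog", "cat", "bird", "fish"]

# One color per distinct label; the modulo wraps around the palette,
# so the fourth label reuses the first color. Assignment order
# follows set iteration order.
color = {label: COLORS[i % len(COLORS)] for i, label in enumerate(set(labels))}
print(color)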
vision_agent/tools/__init__.py
CHANGED
vision_agent/tools/tools.py
CHANGED
@@ -250,7 +250,7 @@ class GroundingDINO(Tool):
             iou_threshold: the threshold for intersection over union used in nms algorithm. It will suppress the boxes which have iou greater than this threshold.
 
         Returns:
-            A
+            A dictionary containing the labels, scores, and bboxes, which is the detection result for the input image.
         """
         image_size = get_image_size(image)
         image_b64 = convert_to_b64(image)
@@ -346,7 +346,7 @@ class GroundingSAM(Tool):
             iou_threshold: the threshold for intersection over union used in nms algorithm. It will suppress the boxes which have iou greater than this threshold.
 
         Returns:
-            A
+            A dictionary containing the labels, scores, bboxes and masks for the input image.
         """
         image_size = get_image_size(image)
         image_b64 = convert_to_b64(image)
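Both docstrings now state the return shape explicitly. A hedged sketch of consuming such a result; the dict literal is illustrative rather than real model output, and the bboxes are the normalized xmin, ymin, xmax, ymax corners the tools produce:

# Illustrative detection result in the documented shape.
result = {
    "labels": ["balloon", "balloon"],
    "scores": [0.91, 0.64],
    "bboxes": [[0.12, 0.10, 0.34, 0.42], [0.55, 0.20, 0.71, 0.48]],
}

for label, score, bbox in zip(result["labels"], result["scores"], result["bboxes"]):
    print(f"{label}: {score:.2f} at {bbox}")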
@@ -357,19 +357,113 @@
             "kwargs": {"box_threshold": box_threshold, "iou_threshold": iou_threshold},
         }
         data: Dict[str, Any] = _send_inference_request(request_data, "tools")
-        ret_pred: Dict[str, List] = {"labels": [], "bboxes": [], "masks": []}
         if "bboxes" in data:
-
-
+            data["bboxes"] = [normalize_bbox(box, image_size) for box in data["bboxes"]]
+        if "masks" in data:
+            data["masks"] = [
+                rle_decode(mask_rle=mask, shape=data["mask_shape"])
+                for mask in data["masks"]
+            ]
+        data.pop("mask_shape", None)
+        return data
+
+
+class DINOv(Tool):
+    r"""DINOv is a tool that can detect and segment similar objects with the given input masks.
+
+    Example
+    -------
+        >>> import vision_agent as va
+        >>> t = va.tools.DINOv()
+        >>> t(prompt=[{"mask":"balloon_mask.jpg", "image": "balloon.jpg"}], image="balloon.jpg")
+        [{'scores': [0.512, 0.212],
+        'masks': [array([[0, 0, 0, ..., 0, 0, 0],
+            ...,
+            [0, 0, 0, ..., 0, 0, 0]], dtype=uint8)},
+        array([[0, 0, 0, ..., 0, 0, 0],
+            ...,
+            [1, 1, 1, ..., 1, 1, 1]], dtype=uint8)]}]
+    """
+
+    name = "dinov_"
+    description = "'dinov_' is a tool that can detect and segment similar objects given a reference segmentation mask."
+    usage = {
+        "required_parameters": [
+            {"name": "prompt", "type": "List[Dict[str, str]]"},
+            {"name": "image", "type": "str"},
+        ],
+        "examples": [
+            {
+                "scenario": "Can you find all the balloons in this image that is similar to the provided masked area? Image name: input.jpg Reference image: balloon.jpg Reference mask: balloon_mask.jpg",
+                "parameters": {
+                    "prompt": [
+                        {"mask": "balloon_mask.jpg", "image": "balloon.jpg"},
+                    ],
+                    "image": "input.jpg",
+                },
+            },
+            {
+                "scenario": "Detect all the objects in this image that are similar to the provided mask. Image name: original.jpg Reference image: mask.png Reference mask: background.png",
+                "parameters": {
+                    "prompt": [
+                        {"mask": "mask.png", "image": "background.png"},
+                    ],
+                    "image": "original.jpg",
+                },
+            },
+        ],
+    }
+
+    def __call__(
+        self, prompt: List[Dict[str, str]], image: Union[str, ImageType]
+    ) -> Dict:
+        """Invoke the DINOv model.
+
+        Parameters:
+            prompt: a list of visual prompts in the form of {'mask': 'MASK_FILE_PATH', 'image': 'IMAGE_FILE_PATH'}.
+            image: the input image to segment.
+
+        Returns:
+            A dictionary of the below keys: 'scores', 'masks' and 'mask_shape', which stores a list of detected segmentation masks and its scores.
+        """
+        image_b64 = convert_to_b64(image)
+        for p in prompt:
+            p["mask"] = convert_to_b64(p["mask"])
+            p["image"] = convert_to_b64(p["image"])
+        request_data = {
+            "prompt": prompt,
+            "image": image_b64,
+            "tool": "dinov",
+        }
+        data: Dict[str, Any] = _send_inference_request(request_data, "dinov")
+        if "bboxes" in data:
+            data["bboxes"] = [
+                normalize_bbox(box, data["mask_shape"]) for box in data["bboxes"]
             ]
         if "masks" in data:
-
+            data["masks"] = [
                 rle_decode(mask_rle=mask, shape=data["mask_shape"])
                 for mask in data["masks"]
             ]
-
-
-
+        data["labels"] = ["visual prompt" for _ in range(len(data["masks"]))]
+        return data
+
+
+class AgentDINOv(DINOv):
+    def __call__(
+        self,
+        prompt: List[Dict[str, str]],
+        image: Union[str, ImageType],
+    ) -> Dict:
+        rets = super().__call__(prompt, image)
+        mask_files = []
+        for mask in rets["masks"]:
+            with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
+                file_name = Path(tmp.name).with_suffix(".mask.png")
+                Image.fromarray(mask * 255).save(file_name)
+                mask_files.append(str(file_name))
+        rets["masks"] = mask_files
+        return rets
 
 
 class AgentGroundingSAM(GroundingSAM):
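The new code paths decode server-side masks with rle_decode(mask_rle=..., shape=data["mask_shape"]). The function body is not part of this diff; below is a minimal sketch of the conventional run-length decoding such a helper usually performs (1-indexed start/length pairs over the flattened image; an assumption, not the package's verified implementation):

import numpy as np

def rle_decode_sketch(mask_rle: str, shape: tuple) -> np.ndarray:
    # "start length start length ..." pairs, 1-indexed over the
    # flattened image, marking runs of foreground pixels.
    tokens = mask_rle.split()
    starts = np.asarray(tokens[0::2], dtype=int) - 1
    lengths = np.asarray(tokens[1::2], dtype=int)
    mask = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for start, length in zip(starts, lengths):
        mask[start:start + length] = 1
    return mask.reshape(shape)

# A 3x3 mask with one run of three pixels starting at pixel 2.
print(rle_decode_sketch("2 3", (3, 3)))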
@@ -545,6 +639,58 @@ class SegIoU(Tool):
         return cast(float, round(iou, 2))
 
 
+class BboxContains(Tool):
+    name = "bbox_contains_"
+    description = "Given two bounding boxes, a target bounding box and a region bounding box, 'bbox_contains_' returns the intersection of the two bounding boxes over the target bounding box, reflects the percentage area of the target bounding box overlaps with the region bounding box. This is a good tool for determining if the region object contains the target object."
+    usage = {
+        "required_parameters": [
+            {"name": "target", "type": "List[int]"},
+            {"name": "target_class", "type": "str"},
+            {"name": "region", "type": "List[int]"},
+            {"name": "region_class", "type": "str"},
+        ],
+        "examples": [
+            {
+                "scenario": "Determine if the dog on the couch, bounding box of the dog: [0.2, 0.21, 0.34, 0.42], bounding box of the couch: [0.3, 0.31, 0.44, 0.52]",
+                "parameters": {
+                    "target": [0.2, 0.21, 0.34, 0.42],
+                    "target_class": "dog",
+                    "region": [0.3, 0.31, 0.44, 0.52],
+                    "region_class": "couch",
+                },
+            },
+            {
+                "scenario": "Check if the kid is in the pool? bounding box of the kid: [0.2, 0.21, 0.34, 0.42], bounding box of the pool: [0.3, 0.31, 0.44, 0.52]",
+                "parameters": {
+                    "target": [0.2, 0.21, 0.34, 0.42],
+                    "target_class": "kid",
+                    "region": [0.3, 0.31, 0.44, 0.52],
+                    "region_class": "pool",
+                },
+            },
+        ],
+    }
+
+    def __call__(
+        self, target: List[int], target_class: str, region: List[int], region_class: str
+    ) -> Dict[str, Union[str, float]]:
+        x1, y1, x2, y2 = target
+        x3, y3, x4, y4 = region
+        xA = max(x1, x3)
+        yA = max(y1, y3)
+        xB = min(x2, x4)
+        yB = min(y2, y4)
+        inter_area = max(0, xB - xA) * max(0, yB - yA)
+        boxa_area = (x2 - x1) * (y2 - y1)
+        iou = inter_area / float(boxa_area)
+        area = round(iou, 2)
+        return {
+            "target_class": target_class,
+            "region_class": region_class,
+            "intersection": area,
+        }
+
+
 class BoxDistance(Tool):
     name = "box_distance_"
     description = (
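Plugging the dog/couch boxes from the usage examples into the new __call__ makes the metric concrete: it is intersection area over the target's own area (an asymmetric containment ratio), not a symmetric IoU.

target = [0.2, 0.21, 0.34, 0.42]  # dog
region = [0.3, 0.31, 0.44, 0.52]  # couch

x_a, y_a = max(0.2, 0.3), max(0.21, 0.31)           # 0.30, 0.31
x_b, y_b = min(0.34, 0.44), min(0.42, 0.52)         # 0.34, 0.42
inter_area = max(0, x_b - x_a) * max(0, y_b - y_a)  # 0.04 * 0.11 = 0.0044
target_area = (0.34 - 0.2) * (0.42 - 0.21)          # 0.14 * 0.21 = 0.0294
print(round(inter_area / target_area, 2))           # 0.15: about 15% of the dog box lies inside the couch box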
@@ -652,12 +798,14 @@ TOOLS = {
             ImageCaption,
             GroundingDINO,
             AgentGroundingSAM,
+            AgentDINOv,
             ExtractFrames,
             Crop,
             BboxArea,
             SegArea,
             BboxIoU,
             SegIoU,
+            BboxContains,
             BoxDistance,
             Calculator,
         ]
{vision_agent-0.1.3.dist-info → vision_agent-0.1.5.dist-info}/RECORD
CHANGED
@@ -5,7 +5,7 @@ vision_agent/agent/easytool.py,sha256=oMHnBg7YBtIPgqQUNcZgq7uMgpPThs99_UnO7ERkMV
 vision_agent/agent/easytool_prompts.py,sha256=zdQQw6WpXOmvwOMtlBlNKY5a3WNlr65dbUvMIGiqdeo,4526
 vision_agent/agent/reflexion.py,sha256=4gz30BuFMeGxSsTzoDV4p91yE0R8LISXp28IaOI6wdM,10506
 vision_agent/agent/reflexion_prompts.py,sha256=G7UAeNz_g2qCb2yN6OaIC7bQVUkda4m3z42EG8wAyfE,9342
-vision_agent/agent/vision_agent.py,sha256=
+vision_agent/agent/vision_agent.py,sha256=Deuj28hqRq4wHnD08pU_7fok_EicvlGnDoINYh5hw1k,22853
 vision_agent/agent/vision_agent_prompts.py,sha256=W3Z72FpUt71UIJSkjAcgtQqxeMqkYuATqHAN5fYY26c,7342
 vision_agent/data/__init__.py,sha256=YU-5g3LbEQ6a4drz0RLGTagXMVU2Z4Xr3RlfWE-R0jU,46
 vision_agent/data/data.py,sha256=Z2l76OrT0GgyuN52OeJqDitUcP0q1rhfdXd1of3GsVo,5128
@@ -13,17 +13,17 @@ vision_agent/emb/__init__.py,sha256=YmCkGrJBtXb6X6Z3lnKiFoQYKXMgHMJp8JJyMLVvqcI,
 vision_agent/emb/emb.py,sha256=la9lhEzk7jqUCjYYQ5oRgVNSnC9_EJBJIpE_B9c6PJo,1375
 vision_agent/fonts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 vision_agent/fonts/default_font_ch_en.ttf,sha256=1YM0Z3XqLDjSNbF7ihQFSAIUdjF9m1rtHiNC_6QosTE,1594400
-vision_agent/image_utils.py,sha256=
+vision_agent/image_utils.py,sha256=qRN_Y1XXBm9EL6V53OZUq21h0spIa1J6X9YDbe6B87o,4805
 vision_agent/llm/__init__.py,sha256=BoUm_zSAKnLlE8s-gKTSQugXDqVZKPqYlWwlTLdhcz4,48
 vision_agent/llm/llm.py,sha256=Jty_RHdqVmIM0Mm31JNk50c882Tx7hHtkmh0WyXeJd8,5016
 vision_agent/lmm/__init__.py,sha256=nnNeKD1k7q_4vLb1x51O_EUTYaBgGfeiCx5F433gr3M,67
 vision_agent/lmm/lmm.py,sha256=1E7e_S_0fOKnf6mSsEdkXvsIjGmhBGl5XW4By2jvhbY,10045
-vision_agent/tools/__init__.py,sha256=
+vision_agent/tools/__init__.py,sha256=dkzk9amNzTEKULMB1xRJspqEGpzNPGuccWeXrv1xI0U,280
 vision_agent/tools/prompts.py,sha256=V1z4YJLXZuUl_iZ5rY0M5hHc_2tmMEUKr0WocXKGt4E,1430
-vision_agent/tools/tools.py,sha256=
+vision_agent/tools/tools.py,sha256=WIodfggPkz_2LSWn_Kqm9uvQUtCgKy3jmMoPVTwf1bA,31181
 vision_agent/tools/video.py,sha256=xTElFSFp1Jw4ulOMnk81Vxsh-9dTxcWUO6P9fzEi3AM,7653
 vision_agent/type_defs.py,sha256=4LTnTL4HNsfYqCrDn9Ppjg9bSG2ZGcoKSSd9YeQf4Bw,1792
-vision_agent-0.1.
-vision_agent-0.1.
-vision_agent-0.1.
-vision_agent-0.1.
+vision_agent-0.1.5.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+vision_agent-0.1.5.dist-info/METADATA,sha256=ubzhbZW7oT9sIaIkuM6QObXINZGz5Zcvgjdp7sUcsJE,6233
+vision_agent-0.1.5.dist-info/WHEEL,sha256=7Z8_27uaHI_UZAc4Uox4PpBhQ9Y5_modZXWMxtUi4NU,88
+vision_agent-0.1.5.dist-info/RECORD,,
{vision_agent-0.1.3.dist-info → vision_agent-0.1.5.dist-info}/LICENSE
File without changes
{vision_agent-0.1.3.dist-info → vision_agent-0.1.5.dist-info}/WHEEL
File without changes
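The rewritten RECORD entries follow the standard wheel format: the file path, "sha256=" plus an unpadded urlsafe-base64 SHA-256 digest of the file contents, and the size in bytes. A small sketch that reproduces an entry for any file:

import base64
import hashlib
from pathlib import Path

def record_entry(path: str) -> str:
    # path,sha256=<urlsafe-b64 digest with '=' padding stripped>,<size>
    data = Path(path).read_bytes()
    digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest())
    return f"{path},sha256={digest.rstrip(b'=').decode()},{len(data)}"

# e.g. record_entry("vision_agent/tools/tools.py")
# -> "vision_agent/tools/tools.py,sha256=WIodfg...,31181"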