vision-agent 0.1.4__py3-none-any.whl → 0.1.6__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
vision_agent/agent/vision_agent.py CHANGED
@@ -8,7 +8,7 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
 from PIL import Image
 from tabulate import tabulate
 
-from vision_agent.image_utils import overlay_bboxes, overlay_masks
+from vision_agent.image_utils import overlay_bboxes, overlay_masks, overlay_heat_map
 from vision_agent.llm import LLM, OpenAILLM
 from vision_agent.lmm import LMM, OpenAILMM
 from vision_agent.tools import TOOLS
@@ -33,6 +33,7 @@ from .vision_agent_prompts import (
 
 logging.basicConfig(stream=sys.stdout)
 _LOGGER = logging.getLogger(__name__)
+_MAX_TABULATE_COL_WIDTH = 80
 
 
 def parse_json(s: str) -> Any:
@@ -335,7 +336,9 @@ def _handle_viz_tools(
 
     for param, call_result in zip(parameters, tool_result["call_results"]):
         # calls can fail, so we need to check if the call was successful
-        if not isinstance(call_result, dict) or "bboxes" not in call_result:
+        if not isinstance(call_result, dict) or (
+            "bboxes" not in call_result and "masks" not in call_result
+        ):
             return image_to_data
 
         # if the call was successful, then we can add the image data
@@ -348,11 +351,12 @@ def _handle_viz_tools(
                 "scores": [],
             }
 
-        image_to_data[image]["bboxes"].extend(call_result["bboxes"])
-        image_to_data[image]["labels"].extend(call_result["labels"])
-        image_to_data[image]["scores"].extend(call_result["scores"])
-        if "masks" in call_result:
-            image_to_data[image]["masks"].extend(call_result["masks"])
+        image_to_data[image]["bboxes"].extend(call_result.get("bboxes", []))
+        image_to_data[image]["labels"].extend(call_result.get("labels", []))
+        image_to_data[image]["scores"].extend(call_result.get("scores", []))
+        image_to_data[image]["masks"].extend(call_result.get("masks", []))
+        if "mask_shape" in call_result:
+            image_to_data[image]["mask_shape"] = call_result["mask_shape"]
 
     return image_to_data
 
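The switch from direct indexing to dict.get with an empty-list default is what lets the new counting tools flow through this function: their results carry masks (plus a count and heat map) but no "bboxes" key. A minimal sketch of the merge pattern, using a hypothetical masks-only call_result:

    # Hypothetical call_result; real values come from the inference API.
    image_data = {"bboxes": [], "labels": [], "scores": [], "masks": []}
    call_result = {"masks": ["<b64 rle>"], "mask_shape": [480, 640]}  # no "bboxes"

    # Old code: call_result["bboxes"] would raise KeyError here.
    for key in ("bboxes", "labels", "scores", "masks"):
        image_data[key].extend(call_result.get(key, []))
    if "mask_shape" in call_result:
        image_data["mask_shape"] = call_result["mask_shape"]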
@@ -366,6 +370,8 @@ def visualize_result(all_tool_results: List[Dict]) -> Sequence[Union[str, Path]]
             "grounding_dino_",
             "extract_frames_",
             "dinov_",
+            "zero_shot_counting_",
+            "visual_prompt_counting_",
         ]:
             continue
@@ -378,8 +384,11 @@ def visualize_result(all_tool_results: List[Dict]) -> Sequence[Union[str, Path]]
     for image_str in image_to_data:
         image_path = Path(image_str)
         image_data = image_to_data[image_str]
-        image = overlay_masks(image_path, image_data)
-        image = overlay_bboxes(image, image_data)
+        if "_counting_" in tool_result["tool_name"]:
+            image = overlay_heat_map(image_path, image_data)
+        else:
+            image = overlay_masks(image_path, image_data)
+            image = overlay_bboxes(image, image_data)
         with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
             image.save(f.name)
             visualized_images.append(f.name)
@@ -483,11 +492,21 @@ class VisionAgent(Agent):
         if image:
             question += f" Image name: {image}"
         if reference_data:
-            if not ("image" in reference_data and "mask" in reference_data):
+            if not (
+                "image" in reference_data
+                and ("mask" in reference_data or "bbox" in reference_data)
+            ):
                 raise ValueError(
-                    f"Reference data must contain 'image' and 'mask'. but got {reference_data}"
+                    f"Reference data must contain 'image' and a visual prompt which can be 'mask' or 'bbox', but got {reference_data}"
                 )
-            question += f" Reference image: {reference_data['image']}, Reference mask: {reference_data['mask']}"
+            visual_prompt_data = (
+                f"Reference mask: {reference_data['mask']}"
+                if "mask" in reference_data
+                else f"Reference bbox: {reference_data['bbox']}"
+            )
+            question += (
+                f" Reference image: {reference_data['image']}, {visual_prompt_data}"
+            )
 
         reflections = ""
         final_answer = ""
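Concretely, both of these reference_data shapes now pass validation (file names and coordinates are hypothetical); previously only the mask form was accepted:

    ref_with_mask = {"image": "ref.jpg", "mask": "ref_mask.png"}
    ref_with_bbox = {"image": "ref.jpg", "bbox": [0.1, 0.1, 0.4, 0.42]}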
@@ -530,7 +549,6 @@ class VisionAgent(Agent):
             final_answer = answer_summarize(
                 self.answer_model, question, answers, reflections
             )
-
             visualized_output = visualize_result(all_tool_results)
             all_tool_results.append({"visualized_output": visualized_output})
             if len(visualized_output) > 0:
@@ -614,7 +632,7 @@ class VisionAgent(Agent):
 
         self.log_progress(
             f"""Going to run the following tool(s) in sequence:
-{tabulate([tool_results], headers="keys", tablefmt="mixed_grid")}"""
+{tabulate(tabular_data=[tool_results], headers="keys", tablefmt="mixed_grid", maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"""
         )
 
     def parse_tool_results(result: Dict[str, Union[Dict, List]]) -> Any:
@@ -660,6 +678,6 @@ class VisionAgent(Agent):
         task_list = []
         self.log_progress(
             f"""Planned tasks:
-{tabulate(task_list, headers="keys", tablefmt="mixed_grid")}"""
+{tabulate(task_list, headers="keys", tablefmt="mixed_grid", maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"""
         )
         return task_list
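Both log calls now cap cell width at _MAX_TABULATE_COL_WIDTH (80), so a long tool description wraps instead of stretching the grid off-screen. A standalone sketch of the effect (the sample row is made up; maxcolwidths requires a reasonably recent python-tabulate):

    from tabulate import tabulate

    rows = [{"tool": "grounding_sam_", "description": "segments objects " * 20}]
    # Cells wrap at 80 characters rather than producing one very wide column.
    print(tabulate(rows, headers="keys", tablefmt="mixed_grid", maxcolwidths=80))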
vision_agent/image_utils.py CHANGED
@@ -4,7 +4,7 @@ import base64
 from importlib import resources
 from io import BytesIO
 from pathlib import Path
-from typing import Dict, Tuple, Union
+from typing import Dict, Tuple, Union, List
 
 import numpy as np
 from PIL import Image, ImageDraw, ImageFont
@@ -34,6 +34,35 @@ COLORS = [
 ]
 
 
+def normalize_bbox(
+    bbox: List[Union[int, float]], image_size: Tuple[int, ...]
+) -> List[float]:
+    r"""Normalize the bounding box coordinates to be between 0 and 1."""
+    x1, y1, x2, y2 = bbox
+    x1 = round(x1 / image_size[1], 2)
+    y1 = round(y1 / image_size[0], 2)
+    x2 = round(x2 / image_size[1], 2)
+    y2 = round(y2 / image_size[0], 2)
+    return [x1, y1, x2, y2]
+
+
+def rle_decode(mask_rle: str, shape: Tuple[int, int]) -> np.ndarray:
+    r"""Decode a run-length encoded mask. Returns numpy array, 1 - mask, 0 - background.
+
+    Parameters:
+        mask_rle: Run-length encoding as a string of space-separated "start length" pairs
+        shape: The (height, width) of the array to return
+    """
+    s = mask_rle.split()
+    starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])]
+    starts -= 1
+    ends = starts + lengths
+    img = np.zeros(shape[0] * shape[1], dtype=np.uint8)
+    for lo, hi in zip(starts, ends):
+        img[lo:hi] = 1
+    return img.reshape(shape)
+
+
 def b64_to_pil(b64_str: str) -> ImageType:
     r"""Convert a base64 string to a PIL Image.
 
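To make the RLE format concrete, a short worked example (the input string is made up): starts are 1-based offsets into the row-major flattened mask, each followed by a run length.

    # "2 3 10 2": 3 ones starting at position 2, then 2 ones starting at 10.
    mask = rle_decode("2 3 10 2", shape=(3, 4))
    # array([[0, 1, 1, 1],
    #        [0, 0, 0, 0],
    #        [0, 1, 1, 0]], dtype=uint8)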
@@ -86,6 +115,26 @@ def convert_to_b64(data: Union[str, Path, np.ndarray, ImageType]) -> str:
     return base64.b64encode(arr_bytes).decode("utf-8")
 
 
+def denormalize_bbox(
+    bbox: List[Union[int, float]], image_size: Tuple[int, ...]
+) -> List[float]:
+    r"""Denormalize the bounding box coordinates so that they are in absolute values."""
+
+    if len(bbox) != 4:
+        raise ValueError("Bounding box must be of length 4.")
+
+    arr = np.array(bbox)
+    if np.all((arr >= 0) & (arr <= 1)):
+        x1, y1, x2, y2 = bbox
+        x1 = round(x1 * image_size[1])
+        y1 = round(y1 * image_size[0])
+        x2 = round(x2 * image_size[1])
+        y2 = round(y2 * image_size[0])
+        return [x1, y1, x2, y2]
+    else:
+        return bbox
+
+
 def overlay_bboxes(
     image: Union[str, Path, np.ndarray, ImageType], bboxes: Dict
 ) -> ImageType:
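Note the guard: if any coordinate falls outside [0, 1], the box is assumed to already be in pixels and is returned unchanged. A worked example with a hypothetical (height=480, width=640) image:

    denormalize_bbox([0.1, 0.2, 0.5, 0.6], (480, 640))  # -> [64, 96, 320, 288]
    denormalize_bbox([64, 96, 320, 288], (480, 640))    # already absolute, returned as-is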
@@ -103,6 +152,9 @@ def overlay_bboxes(
     elif isinstance(image, np.ndarray):
         image = Image.fromarray(image)
 
+    if "bboxes" not in bboxes:
+        return image.convert("RGB")
+
     color = {
         label: COLORS[i % len(COLORS)] for i, label in enumerate(set(bboxes["labels"]))
     }
@@ -114,8 +166,6 @@ def overlay_bboxes(
         str(resources.files("vision_agent.fonts").joinpath("default_font_ch_en.ttf")),
         fontsize,
     )
-    if "bboxes" not in bboxes:
-        return image.convert("RGB")
 
     for label, box, scores in zip(bboxes["labels"], bboxes["bboxes"], bboxes["scores"]):
         box = [
@@ -150,11 +200,15 @@ def overlay_masks(
     elif isinstance(image, np.ndarray):
         image = Image.fromarray(image)
 
+    if "masks" not in masks:
+        return image.convert("RGB")
+
+    if "labels" not in masks:
+        masks["labels"] = [""] * len(masks["masks"])
+
     color = {
         label: COLORS[i % len(COLORS)] for i, label in enumerate(set(masks["labels"]))
     }
-    if "masks" not in masks:
-        return image.convert("RGB")
 
     for label, mask in zip(masks["labels"], masks["masks"]):
         if isinstance(mask, str):
@@ -164,3 +218,40 @@ def overlay_masks(
         mask_img = Image.fromarray(np_mask.astype(np.uint8))
         image = Image.alpha_composite(image.convert("RGBA"), mask_img)
     return image.convert("RGB")
+
+
+def overlay_heat_map(
+    image: Union[str, Path, np.ndarray, ImageType], masks: Dict, alpha: float = 0.8
+) -> ImageType:
+    r"""Plots a heat map on top of an image.
+
+    Parameters:
+        image: the input image
+        masks: the heat map to overlay
+        alpha: the transparency of the overlay
+
+    Returns:
+        The image with the heat map overlaid
+    """
+    if isinstance(image, (str, Path)):
+        image = Image.open(image)
+    elif isinstance(image, np.ndarray):
+        image = Image.fromarray(image)
+
+    if "masks" not in masks:
+        return image.convert("RGB")
+
+    # Only one heat map per image, so no need to loop through masks
+    image = image.convert("L")
+
+    if isinstance(masks["masks"][0], str):
+        mask = b64_to_pil(masks["masks"][0])
+
+    overlay = Image.new("RGBA", mask.size)
+    odraw = ImageDraw.Draw(overlay)
+    odraw.bitmap(
+        (0, 0), mask, fill=(255, 0, 0, round(alpha * 255))
+    )  # fill=(R, G, B, Alpha)
+    combined = Image.alpha_composite(image.convert("RGBA"), overlay.resize(image.size))
+
+    return combined.convert("RGB")
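A minimal usage sketch of the new helper (file name and base64 payload are placeholders); the base image is converted to grayscale so the red heat overlay stands out:

    # "masks" holds one base64-encoded heat map, matching what the counting
    # tools return; the string below is illustrative only.
    heat = {"masks": ["iVBORw0KGgo..."]}
    overlay_heat_map("tray.jpg", heat, alpha=0.6).save("tray_heatmap.png")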
vision_agent/llm/llm.py CHANGED
@@ -11,6 +11,7 @@ from vision_agent.tools import (
     SYSTEM_PROMPT,
     GroundingDINO,
     GroundingSAM,
+    ZeroShotCounting,
 )
 
 
@@ -127,6 +128,9 @@ class OpenAILLM(LLM):
 
         return lambda x: GroundingSAM()(**{"prompt": params["prompt"], "image": x})
 
+    def generate_zero_shot_counter(self, question: str) -> Callable:
+        return lambda x: ZeroShotCounting()(**{"image": x})
+
 
 class AzureOpenAILLM(OpenAILLM):
     def __init__(
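The new factory follows the same pattern as the GroundingSAM-based one visible in the context above: the question argument is accepted for interface symmetry but unused, since zero-shot counting needs no prompt. A hedged usage sketch (image name and output are illustrative):

    import vision_agent as va

    llm = va.llm.OpenAILLM()
    counter = llm.generate_zero_shot_counter("how many shirts are there?")
    counter("shirts.jpg")  # e.g. {"count": 12}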
vision_agent/lmm/lmm.py CHANGED
@@ -15,6 +15,7 @@ from vision_agent.tools import (
     SYSTEM_PROMPT,
     GroundingDINO,
     GroundingSAM,
+    ZeroShotCounting,
 )
 
 _LOGGER = logging.getLogger(__name__)
@@ -272,6 +273,9 @@ class OpenAILMM(LMM):
 
         return lambda x: GroundingSAM()(**{"prompt": params["prompt"], "image": x})
 
+    def generate_zero_shot_counter(self, question: str) -> Callable:
+        return lambda x: ZeroShotCounting()(**{"image": x})
+
 
 class AzureOpenAILMM(OpenAILMM):
     def __init__(
vision_agent/tools/__init__.py CHANGED
@@ -11,6 +11,8 @@ from .tools import ( # Counter,
     GroundingDINO,
     GroundingSAM,
     ImageCaption,
+    ZeroShotCounting,
+    VisualPromptCounting,
     SegArea,
     SegIoU,
     Tool,
vision_agent/tools/tools.py CHANGED
@@ -9,7 +9,13 @@ import requests
 from PIL import Image
 from PIL.Image import Image as ImageType
 
-from vision_agent.image_utils import convert_to_b64, get_image_size
+from vision_agent.image_utils import (
+    convert_to_b64,
+    get_image_size,
+    rle_decode,
+    normalize_bbox,
+    denormalize_bbox,
+)
 from vision_agent.tools.video import extract_frames_from_video
 from vision_agent.type_defs import LandingaiAPIKey
 
@@ -18,35 +24,6 @@ _LND_API_KEY = LandingaiAPIKey().api_key
 _LND_API_URL = "https://api.dev.landing.ai/v1/agent"
 
 
-def normalize_bbox(
-    bbox: List[Union[int, float]], image_size: Tuple[int, ...]
-) -> List[float]:
-    r"""Normalize the bounding box coordinates to be between 0 and 1."""
-    x1, y1, x2, y2 = bbox
-    x1 = round(x1 / image_size[1], 2)
-    y1 = round(y1 / image_size[0], 2)
-    x2 = round(x2 / image_size[1], 2)
-    y2 = round(y2 / image_size[0], 2)
-    return [x1, y1, x2, y2]
-
-
-def rle_decode(mask_rle: str, shape: Tuple[int, int]) -> np.ndarray:
-    r"""Decode a run-length encoded mask. Returns numpy array, 1 - mask, 0 - background.
-
-    Parameters:
-        mask_rle: Run-length encoding as a string of space-separated "start length" pairs
-        shape: The (height, width) of the array to return
-    """
-    s = mask_rle.split()
-    starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])]
-    starts -= 1
-    ends = starts + lengths
-    img = np.zeros(shape[0] * shape[1], dtype=np.uint8)
-    for lo, hi in zip(starts, ends):
-        img[lo:hi] = 1
-    return img.reshape(shape)
-
-
 class Tool(ABC):
     name: str
     description: str
@@ -250,7 +227,7 @@ class GroundingDINO(Tool):
             iou_threshold: the threshold for intersection over union used in nms algorithm. It will suppress the boxes which have iou greater than this threshold.
 
         Returns:
-            A list of dictionaries containing the labels, scores, and bboxes. Each dictionary contains the detection result for an image.
+            A dictionary containing the labels, scores, and bboxes, which is the detection result for the input image.
         """
         image_size = get_image_size(image)
         image_b64 = convert_to_b64(image)
@@ -346,7 +323,7 @@ class GroundingSAM(Tool):
             iou_threshold: the threshold for intersection over union used in nms algorithm. It will suppress the boxes which have iou greater than this threshold.
 
         Returns:
-            A list of dictionaries containing the labels, scores, bboxes and masks. Each dictionary contains the segmentation result for an image.
+            A dictionary containing the labels, scores, bboxes and masks for the input image.
         """
         image_size = get_image_size(image)
         image_b64 = convert_to_b64(image)
@@ -357,19 +334,15 @@ class GroundingSAM(Tool):
             "kwargs": {"box_threshold": box_threshold, "iou_threshold": iou_threshold},
         }
         data: Dict[str, Any] = _send_inference_request(request_data, "tools")
-        ret_pred: Dict[str, List] = {"labels": [], "bboxes": [], "masks": []}
         if "bboxes" in data:
-            ret_pred["bboxes"] = [
-                normalize_bbox(box, image_size) for box in data["bboxes"]
-            ]
+            data["bboxes"] = [normalize_bbox(box, image_size) for box in data["bboxes"]]
         if "masks" in data:
-            ret_pred["masks"] = [
+            data["masks"] = [
                 rle_decode(mask_rle=mask, shape=data["mask_shape"])
                 for mask in data["masks"]
             ]
-        ret_pred["labels"] = data["labels"]
-        ret_pred["scores"] = data["scores"]
-        return ret_pred
+        data.pop("mask_shape", None)
+        return data
 
 
 class DINOv(Tool):
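Returning data directly instead of rebuilding ret_pred means the response keeps whatever keys the inference service sent (minus mask_shape), and callers can no longer assume every key is present, which is exactly why _handle_viz_tools above moved to .get. A sketch of the resulting shape (values are illustrative, not real output):

    result = GroundingSAM()(prompt="shirt", image="shirts.jpg")
    # {
    #     "labels": ["shirt", "shirt"],
    #     "scores": [0.93, 0.87],
    #     "bboxes": [[0.12, 0.2, 0.43, 0.55], [0.5, 0.18, 0.88, 0.6]],  # normalized
    #     "masks": [...],  # RLE strings decoded to HxW numpy arrays
    # }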
@@ -493,6 +466,130 @@ class AgentGroundingSAM(GroundingSAM):
         return rets
 
 
+class ZeroShotCounting(Tool):
+    r"""ZeroShotCounting is a tool that can count the total number of instances of an
+    object of the same class present in an image, without a text or visual prompt.
+
+    Example
+    -------
+    >>> import vision_agent as va
+    >>> zshot_count = va.tools.ZeroShotCounting()
+    >>> zshot_count("image1.jpg")
+    {'count': 45}
+    """
+
+    name = "zero_shot_counting_"
+    description = "'zero_shot_counting_' is a tool that counts and returns the total number of instances of an object present in an image belonging to the same class, without a text or visual prompt."
+
+    usage = {
+        "required_parameters": [
+            {"name": "image", "type": "str"},
+        ],
+        "examples": [
+            {
+                "scenario": "Can you count the lids in the image? Image name: lids.jpg",
+                "parameters": {"image": "lids.jpg"},
+            },
+            {
+                "scenario": "Can you count the total number of objects in this image? Image name: tray.jpg",
+                "parameters": {"image": "tray.jpg"},
+            },
+            {
+                "scenario": "Can you build me an object counting tool? Image name: shirts.jpg",
+                "parameters": {
+                    "image": "shirts.jpg",
+                },
+            },
+        ],
+    }
+
+    # TODO: Add support for multiple input images, which aligns with the output type.
+    def __call__(self, image: Union[str, ImageType]) -> Dict:
+        """Invoke the zero-shot counting model.
+
+        Parameters:
+            image: the input image.
+
+        Returns:
+            A dictionary containing the key 'count' and the count as its value,
+            e.g. {'count': 12}
+        """
+        image_b64 = convert_to_b64(image)
+        data = {
+            "image": image_b64,
+            "tool": "zero_shot_counting",
+        }
+        return _send_inference_request(data, "tools")
+
+
+class VisualPromptCounting(Tool):
+    r"""VisualPromptCounting is a tool that can count the total number of instances of an
+    object of the same class present in an image, given a visual prompt in the form of
+    a bounding box.
+
+    Example
+    -------
+    >>> import vision_agent as va
+    >>> prompt_count = va.tools.VisualPromptCounting()
+    >>> prompt_count(image="image1.jpg", prompt="0.1, 0.1, 0.4, 0.42")
+    {'count': 23}
+    """
+
+    name = "visual_prompt_counting_"
+    description = "'visual_prompt_counting_' is a tool that counts and returns the total number of instances of an object present in an image belonging to the same class, given an example bounding box."
+
+    usage = {
+        "required_parameters": [
+            {"name": "image", "type": "str"},
+            {"name": "prompt", "type": "str"},
+        ],
+        "examples": [
+            {
+                "scenario": "Here is an example of a lid '0.1, 0.1, 0.14, 0.2', can you count the lids in the image? Image name: lids.jpg",
+                "parameters": {"image": "lids.jpg", "prompt": "0.1, 0.1, 0.14, 0.2"},
+            },
+            {
+                "scenario": "Can you count the total number of objects in this image? Image name: tray.jpg",
+                "parameters": {"image": "tray.jpg", "prompt": "0.1, 0.1, 0.2, 0.25"},
+            },
+            {
+                "scenario": "Can you build me a few-shot object counting tool? Image name: shirts.jpg",
+                "parameters": {
+                    "image": "shirts.jpg",
+                    "prompt": "0.1, 0.15, 0.2, 0.2",
+                },
+            },
+            {
+                "scenario": "Can you build me a counting tool based on an example prompt? Image name: shoes.jpg",
+                "parameters": {
+                    "image": "shoes.jpg",
+                    "prompt": "0.1, 0.1, 0.6, 0.65",
+                },
+            },
+        ],
+    }
+
+    # TODO: Add support for multiple input images, which aligns with the output type.
+    def __call__(self, image: Union[str, ImageType], prompt: str) -> Dict:
+        """Invoke the visual prompt counting model.
+
+        Parameters:
+            image: the input image.
+            prompt: the visual prompt, a bounding box given as "x1, y1, x2, y2".
+
+        Returns:
+            A dictionary containing the key 'count' and the count as its value,
+            e.g. {'count': 12}
+        """
+        image_size = get_image_size(image)
+        bbox = [float(x) for x in prompt.split(",")]
+        prompt = ", ".join(map(str, denormalize_bbox(bbox, image_size)))
+        image_b64 = convert_to_b64(image)
+
+        data = {
+            "image": image_b64,
+            "prompt": prompt,
+            "tool": "few_shot_counting",
+        }
+        return _send_inference_request(data, "tools")
+
+
 class Crop(Tool):
     r"""Crop crops an image given a bounding box and returns a file name of the cropped image."""
 
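One detail in VisualPromptCounting.__call__ is easy to miss: the normalized prompt box is converted to absolute pixel coordinates before the request is sent, which suggests the backend few_shot_counting tool expects pixels. A worked example on a hypothetical (height=500, width=400) image:

    bbox = [float(x) for x in "0.1, 0.1, 0.4, 0.42".split(",")]
    denormalize_bbox(bbox, (500, 400))  # -> [40, 50, 160, 210]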
@@ -643,6 +740,58 @@ class SegIoU(Tool):
         return cast(float, round(iou, 2))
 
 
+class BboxContains(Tool):
+    name = "bbox_contains_"
+    description = "Given a target bounding box and a region bounding box, 'bbox_contains_' returns the area of their intersection divided by the area of the target bounding box, which reflects the fraction of the target that overlaps with the region. This is a good tool for determining if the region object contains the target object."
+    usage = {
+        "required_parameters": [
+            {"name": "target", "type": "List[int]"},
+            {"name": "target_class", "type": "str"},
+            {"name": "region", "type": "List[int]"},
+            {"name": "region_class", "type": "str"},
+        ],
+        "examples": [
+            {
+                "scenario": "Determine if the dog is on the couch, bounding box of the dog: [0.2, 0.21, 0.34, 0.42], bounding box of the couch: [0.3, 0.31, 0.44, 0.52]",
+                "parameters": {
+                    "target": [0.2, 0.21, 0.34, 0.42],
+                    "target_class": "dog",
+                    "region": [0.3, 0.31, 0.44, 0.52],
+                    "region_class": "couch",
+                },
+            },
+            {
+                "scenario": "Check if the kid is in the pool, bounding box of the kid: [0.2, 0.21, 0.34, 0.42], bounding box of the pool: [0.3, 0.31, 0.44, 0.52]",
+                "parameters": {
+                    "target": [0.2, 0.21, 0.34, 0.42],
+                    "target_class": "kid",
+                    "region": [0.3, 0.31, 0.44, 0.52],
+                    "region_class": "pool",
+                },
+            },
+        ],
+    }
+
+    def __call__(
+        self, target: List[int], target_class: str, region: List[int], region_class: str
+    ) -> Dict[str, Union[str, float]]:
+        x1, y1, x2, y2 = target
+        x3, y3, x4, y4 = region
+        xA = max(x1, x3)
+        yA = max(y1, y3)
+        xB = min(x2, x4)
+        yB = min(y2, y4)
+        inter_area = max(0, xB - xA) * max(0, yB - yA)
+        boxa_area = (x2 - x1) * (y2 - y1)
+        iou = inter_area / float(boxa_area)
+        area = round(iou, 2)
+        return {
+            "target_class": target_class,
+            "region_class": region_class,
+            "intersection": area,
+        }
+
+
 class BoxDistance(Tool):
     name = "box_distance_"
     description = (
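Note that the returned 'intersection' is intersection over the target's own area, not a symmetric IoU, so it reaches 1.0 whenever the target lies entirely inside the region. Working through the dog/couch example from the usage block:

    # target [0.2, 0.21, 0.34, 0.42], region [0.3, 0.31, 0.44, 0.52]
    # overlap: xA=0.3, yA=0.31, xB=0.34, yB=0.42
    #   inter_area = 0.04 * 0.11 = 0.0044
    # target area = 0.14 * 0.21 = 0.0294
    # 0.0044 / 0.0294 ≈ 0.15 -> {"intersection": 0.15, ...}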
@@ -750,6 +899,8 @@ TOOLS = {
         ImageCaption,
         GroundingDINO,
         AgentGroundingSAM,
+        ZeroShotCounting,
+        VisualPromptCounting,
         AgentDINOv,
         ExtractFrames,
         Crop,
@@ -757,6 +908,7 @@ TOOLS = {
         SegArea,
         BboxIoU,
         SegIoU,
+        BboxContains,
         BoxDistance,
         Calculator,
     ]
vision_agent-0.1.6.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: vision-agent
-Version: 0.1.4
+Version: 0.1.6
 Summary: Toolset for Vision Agent
 Author: Landing AI
 Author-email: dev@landing.ai
@@ -41,7 +41,7 @@ Description-Content-Type: text/markdown
 
 Vision Agent is a library that helps you utilize agent frameworks for your vision tasks.
 Many current vision problems can easily take hours or days to solve, you need to find the
-right model, figure out how to use it, possibly write programming logic around it to
+right model, figure out how to use it, possibly write programming logic around it to
 accomplish the task you want or even more expensive, train your own model. Vision Agent
 aims to provide an in-seconds experience by allowing users to describe their problem in
 text and utilizing agent frameworks to solve the task for them. Check out our discord
@@ -138,6 +138,9 @@ you. For example:
 | BboxIoU | BboxIoU returns the intersection over union of two bounding boxes normalized to 2 decimal places. |
 | SegIoU | SegIoU returns the intersection over union of two segmentation masks normalized to 2 decimal places. |
 | ExtractFrames | ExtractFrames extracts frames with motion from a video. |
+| ExtractFrames | ExtractFrames extracts frames with motion from a video. |
+| ZeroShotCounting | ZeroShotCounting returns the total number of objects belonging to a single class in a given image |
+| VisualPromptCounting | VisualPromptCounting returns the total number of objects belonging to a single class given an image and visual prompt |
 
 
 It also has a basic set of calculate tools such as add, subtract, multiply and divide.
vision_agent-0.1.6.dist-info/RECORD CHANGED
@@ -5,7 +5,7 @@ vision_agent/agent/easytool.py,sha256=oMHnBg7YBtIPgqQUNcZgq7uMgpPThs99_UnO7ERkMV
 vision_agent/agent/easytool_prompts.py,sha256=zdQQw6WpXOmvwOMtlBlNKY5a3WNlr65dbUvMIGiqdeo,4526
 vision_agent/agent/reflexion.py,sha256=4gz30BuFMeGxSsTzoDV4p91yE0R8LISXp28IaOI6wdM,10506
 vision_agent/agent/reflexion_prompts.py,sha256=G7UAeNz_g2qCb2yN6OaIC7bQVUkda4m3z42EG8wAyfE,9342
-vision_agent/agent/vision_agent.py,sha256=QWIirRBB3ZPg3figWcf8-g9ltFydM1BDn75LbXWbep0,22735
+vision_agent/agent/vision_agent.py,sha256=MTxeV5_Sghqoe2aOW9EbNgiq61sVCcF3ZndJ7BZl6x0,23588
 vision_agent/agent/vision_agent_prompts.py,sha256=W3Z72FpUt71UIJSkjAcgtQqxeMqkYuATqHAN5fYY26c,7342
 vision_agent/data/__init__.py,sha256=YU-5g3LbEQ6a4drz0RLGTagXMVU2Z4Xr3RlfWE-R0jU,46
 vision_agent/data/data.py,sha256=Z2l76OrT0GgyuN52OeJqDitUcP0q1rhfdXd1of3GsVo,5128
@@ -13,17 +13,17 @@ vision_agent/emb/__init__.py,sha256=YmCkGrJBtXb6X6Z3lnKiFoQYKXMgHMJp8JJyMLVvqcI,
 vision_agent/emb/emb.py,sha256=la9lhEzk7jqUCjYYQ5oRgVNSnC9_EJBJIpE_B9c6PJo,1375
 vision_agent/fonts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 vision_agent/fonts/default_font_ch_en.ttf,sha256=1YM0Z3XqLDjSNbF7ihQFSAIUdjF9m1rtHiNC_6QosTE,1594400
-vision_agent/image_utils.py,sha256=qRN_Y1XXBm9EL6V53OZUq21h0spIa1J6X9YDbe6B87o,4805
+vision_agent/image_utils.py,sha256=Cg4aKO1tQiETT1gdsZ50XzORBtJnBFfMG2cKJyjaY6Q,7555
 vision_agent/llm/__init__.py,sha256=BoUm_zSAKnLlE8s-gKTSQugXDqVZKPqYlWwlTLdhcz4,48
-vision_agent/llm/llm.py,sha256=Jty_RHdqVmIM0Mm31JNk50c882Tx7hHtkmh0WyXeJd8,5016
+vision_agent/llm/llm.py,sha256=gwDQ9-p9wEn24xi1019e5jzTGQg4xWDSqBCsqIqGcU4,5168
 vision_agent/lmm/__init__.py,sha256=nnNeKD1k7q_4vLb1x51O_EUTYaBgGfeiCx5F433gr3M,67
-vision_agent/lmm/lmm.py,sha256=1E7e_S_0fOKnf6mSsEdkXvsIjGmhBGl5XW4By2jvhbY,10045
-vision_agent/tools/__init__.py,sha256=dkzk9amNzTEKULMB1xRJspqEGpzNPGuccWeXrv1xI0U,280
+vision_agent/lmm/lmm.py,sha256=FjxCuIk0KXuWnfY4orVmdyhJW2I4C6i5QNNEXk7gybk,10197
+vision_agent/tools/__init__.py,sha256=BlfxqbYkB0oODhnSmQg1UyzQm73AvvjCjrIiOWBIYDs,328
 vision_agent/tools/prompts.py,sha256=V1z4YJLXZuUl_iZ5rY0M5hHc_2tmMEUKr0WocXKGt4E,1430
-vision_agent/tools/tools.py,sha256=ybhCyutEGzHPKuR0Cu--Nb--KubjYvyzLEzVQYzIMTw,29148
+vision_agent/tools/tools.py,sha256=gCjHs5vJuGNBFsnJWFT7PX3wTyfHgtrgX1Eq9vqknN0,34979
 vision_agent/tools/video.py,sha256=xTElFSFp1Jw4ulOMnk81Vxsh-9dTxcWUO6P9fzEi3AM,7653
 vision_agent/type_defs.py,sha256=4LTnTL4HNsfYqCrDn9Ppjg9bSG2ZGcoKSSd9YeQf4Bw,1792
-vision_agent-0.1.4.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
-vision_agent-0.1.4.dist-info/METADATA,sha256=FyBYGPHgC0uV7uy7wph8yvdQpEWSACnGR96y6Jt-E6A,6233
-vision_agent-0.1.4.dist-info/WHEEL,sha256=7Z8_27uaHI_UZAc4Uox4PpBhQ9Y5_modZXWMxtUi4NU,88
-vision_agent-0.1.4.dist-info/RECORD,,
+vision_agent-0.1.6.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+vision_agent-0.1.6.dist-info/METADATA,sha256=Ig2tSKyeH8a2A8xZRq72M9XnKyi4_03UM4hDiFpT-eU,6574
+vision_agent-0.1.6.dist-info/WHEEL,sha256=7Z8_27uaHI_UZAc4Uox4PpBhQ9Y5_modZXWMxtUi4NU,88
+vision_agent-0.1.6.dist-info/RECORD,,