vision-agent 0.2.14__py3-none-any.whl → 0.2.16__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
@@ -11,7 +11,10 @@ from PIL import Image
  from PIL.Image import Image as ImageType
  from scipy.spatial import distance # type: ignore

- from vision_agent.image_utils import (
+ from vision_agent.lmm import OpenAILMM
+ from vision_agent.tools.tool_utils import _send_inference_request
+ from vision_agent.utils import extract_frames_from_video
+ from vision_agent.utils.image_utils import (
  b64_to_pil,
  convert_to_b64,
  denormalize_bbox,
@@ -19,9 +22,6 @@ from vision_agent.image_utils import (
  normalize_bbox,
  rle_decode,
  )
- from vision_agent.lmm import OpenAILMM
- from vision_agent.tools.tool_utils import _send_inference_request
- from vision_agent.tools.video import extract_frames_from_video

  _LOGGER = logging.getLogger(__name__)

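For downstream code the practical effect of this reshuffle is that the image helpers now live under vision_agent.utils rather than at the package root. A hedged before/after sketch, assuming external code imported these helpers directly:

    # before (0.2.14)
    from vision_agent.image_utils import convert_to_b64, normalize_bbox
    # after (0.2.16)
    from vision_agent.utils.image_utils import convert_to_b64, normalize_bbox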
@@ -174,15 +174,15 @@ class GroundingDINO(Tool):
  """

  name = "grounding_dino_"
- description = "'grounding_dino_' is a tool that can detect and count objects given a text prompt such as category names or referring expressions. It returns a list and count of bounding boxes, label names and associated probability scores."
+ description = "'grounding_dino_' is a tool that can detect and count multiple objects given a text prompt such as category names or referring expressions. It returns a list and count of bounding boxes, label names and associated probability scores."
  usage = {
  "required_parameters": [
  {"name": "prompt", "type": "str"},
  {"name": "image", "type": "str"},
  ],
  "optional_parameters": [
- {"name": "box_threshold", "type": "float"},
- {"name": "iou_threshold", "type": "float"},
+ {"name": "box_threshold", "type": "float", "min": 0.1, "max": 0.5},
+ {"name": "iou_threshold", "type": "float", "min": 0.01, "max": 0.99},
  ],
  "examples": [
  {
@@ -209,7 +209,7 @@ class GroundingDINO(Tool):
  "prompt": "red shirt. green shirt",
  "image": "shirts.jpg",
  "box_threshold": 0.20,
- "iou_threshold": 0.75,
+ "iou_threshold": 0.20,
  },
  },
  ],
@@ -221,7 +221,7 @@ class GroundingDINO(Tool):
  prompt: str,
  image: Union[str, Path, ImageType],
  box_threshold: float = 0.20,
- iou_threshold: float = 0.75,
+ iou_threshold: float = 0.20,
  ) -> Dict:
  """Invoke the Grounding DINO model.

@@ -249,7 +249,7 @@ class GroundingDINO(Tool):
  data["scores"] = [round(score, 2) for score in data["scores"]]
  if "labels" in data:
  data["labels"] = list(data["labels"])
- data["size"] = (image_size[1], image_size[0])
+ data["image_size"] = image_size
  return data


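Taken together, these hunks change the GroundingDINO tool in two user-visible ways: the default iou_threshold drops from 0.75 to 0.20, and the result dictionary now exposes the untransposed size under "image_size" instead of a swapped "size" tuple. A minimal sketch of a call after this change; the instantiation and image path are illustrative, not taken from the package:

    grounding_dino = GroundingDINO()  # hypothetical, constructor arguments omitted
    dets = grounding_dino(
        prompt="red shirt. green shirt",
        image="shirts.jpg",
        box_threshold=0.20,
        iou_threshold=0.20,  # new default in 0.2.16; was 0.75
    )
    size = dets["image_size"]  # previously dets["size"] with reversed axes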
@@ -277,15 +277,15 @@ class GroundingSAM(Tool):
  """

  name = "grounding_sam_"
- description = "'grounding_sam_' is a tool that can detect and segment objects given a text prompt such as category names or referring expressions. It returns a list of bounding boxes, label names and masks file names and associated probability scores."
+ description = "'grounding_sam_' is a tool that can detect and segment multiple objects given a text prompt such as category names or referring expressions. It returns a list of bounding boxes, label names and masks file names and associated probability scores."
  usage = {
  "required_parameters": [
  {"name": "prompt", "type": "str"},
  {"name": "image", "type": "str"},
  ],
  "optional_parameters": [
- {"name": "box_threshold", "type": "float"},
- {"name": "iou_threshold", "type": "float"},
+ {"name": "box_threshold", "type": "float", "min": 0.1, "max": 0.5},
+ {"name": "iou_threshold", "type": "float", "min": 0.01, "max": 0.99},
  ],
  "examples": [
  {
@@ -312,7 +312,7 @@ class GroundingSAM(Tool):
  "prompt": "red shirt, green shirt",
  "image": "shirts.jpg",
  "box_threshold": 0.20,
- "iou_threshold": 0.75,
+ "iou_threshold": 0.20,
  },
  },
  ],
@@ -324,7 +324,7 @@ class GroundingSAM(Tool):
  prompt: str,
  image: Union[str, ImageType],
  box_threshold: float = 0.2,
- iou_threshold: float = 0.75,
+ iou_threshold: float = 0.2,
  ) -> Dict:
  """Invoke the Grounding SAM model.

@@ -353,6 +353,7 @@ class GroundingSAM(Tool):
  rle_decode(mask_rle=mask, shape=data["mask_shape"])
  for mask in data["masks"]
  ]
+ data["image_size"] = image_size
  data.pop("mask_shape", None)
  return data

@@ -422,7 +423,6 @@ class DINOv(Tool):
  request_data = {
  "prompt": prompt,
  "image": image_b64,
- "tool": "dinov",
  }
  data: Dict[str, Any] = _send_inference_request(request_data, "dinov")
  if "bboxes" in data:
@@ -435,6 +435,8 @@ class DINOv(Tool):
  for mask in data["masks"]
  ]
  data["labels"] = ["visual prompt" for _ in range(len(data["masks"]))]
+ mask_shape = data.pop("mask_shape", None)
+ data["image_size"] = (mask_shape[0], mask_shape[1]) if mask_shape else None
  return data


@@ -790,33 +792,49 @@ class Crop(Tool):
  return {"image": tmp.name}


- class BboxArea(Tool):
- r"""BboxArea returns the area of the bounding box in pixels normalized to 2 decimal places."""
+ class BboxStats(Tool):
+ r"""BboxStats returns the height, width and area of the bounding box in pixels to 2 decimal places."""

- name = "bbox_area_"
- description = "'bbox_area_' returns the area of the given bounding box in pixels normalized to 2 decimal places."
+ name = "bbox_stats_"
+ description = "'bbox_stats_' returns the height, width and area of the given bounding box in pixels to 2 decimal places."
  usage = {
- "required_parameters": [{"name": "bboxes", "type": "List[int]"}],
+ "required_parameters": [
+ {"name": "bboxes", "type": "List[int]"},
+ {"name": "image_size", "type": "Tuple[int]"},
+ ],
  "examples": [
  {
- "scenario": "If you want to calculate the area of the bounding box [0.2, 0.21, 0.34, 0.42]",
- "parameters": {"bboxes": [0.2, 0.21, 0.34, 0.42]},
- }
+ "scenario": "Calculate the width and height of the bounding box [0.2, 0.21, 0.34, 0.42]",
+ "parameters": {
+ "bboxes": [[0.2, 0.21, 0.34, 0.42]],
+ "image_size": (500, 1200),
+ },
+ },
+ {
+ "scenario": "Calculate the area of the bounding box [0.2, 0.21, 0.34, 0.42]",
+ "parameters": {
+ "bboxes": [[0.2, 0.21, 0.34, 0.42]],
+ "image_size": (640, 480),
+ },
+ },
  ],
  }

- def __call__(self, bboxes: List[Dict]) -> List[Dict]:
+ def __call__(
+ self, bboxes: List[List[int]], image_size: Tuple[int, int]
+ ) -> List[Dict]:
  areas = []
- for elt in bboxes:
- height, width = elt["size"]
- for label, bbox in zip(elt["labels"], elt["bboxes"]):
- x1, y1, x2, y2 = bbox
- areas.append(
- {
- "area": round((x2 - x1) * (y2 - y1) * width * height, 2),
- "label": label,
- }
- )
+ height, width = image_size
+ for bbox in bboxes:
+ x1, y1, x2, y2 = bbox
+ areas.append(
+ {
+ "width": round((x2 - x1) * width, 2),
+ "height": round((y2 - y1) * height, 2),
+ "area": round((x2 - x1) * (y2 - y1) * width * height, 2),
+ }
+ )
+
  return areas


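A quick worked example of the new BboxStats arithmetic, using the values from the usage example above (the instantiation itself is illustrative):

    stats = BboxStats()(bboxes=[[0.2, 0.21, 0.34, 0.42]], image_size=(500, 1200))
    # with (height, width) = (500, 1200) and the normalized box [0.2, 0.21, 0.34, 0.42]:
    #   width  = (0.34 - 0.20) * 1200 = 168.0 px
    #   height = (0.42 - 0.21) * 500  = 105.0 px
    #   area   = 168.0 * 105.0        = 17640.0 px^2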
@@ -1055,22 +1073,25 @@ class ExtractFrames(Tool):
  r"""Extract frames from a video."""

  name = "extract_frames_"
- description = "'extract_frames_' extracts frames from a video, returns a list of tuples (frame, timestamp), where timestamp is the relative time in seconds where the frame was captured. The frame is a local image file path."
+ description = "'extract_frames_' extracts frames from a video every 2 seconds, returns a list of tuples (frame, timestamp), where timestamp is the relative time in seconds where the frame was captured. The frame is a local image file path."
  usage = {
  "required_parameters": [{"name": "video_uri", "type": "str"}],
+ "optional_parameters": [{"name": "frames_every", "type": "float"}],
  "examples": [
  {
  "scenario": "Can you extract the frames from this video? Video: www.foobar.com/video?name=test.mp4",
  "parameters": {"video_uri": "www.foobar.com/video?name=test.mp4"},
  },
  {
- "scenario": "Can you extract the images from this video file? Video path: tests/data/test.mp4",
- "parameters": {"video_uri": "tests/data/test.mp4"},
+ "scenario": "Can you extract the images from this video file at every 2 seconds ? Video path: tests/data/test.mp4",
+ "parameters": {"video_uri": "tests/data/test.mp4", "frames_every": 2},
  },
  ],
  }

- def __call__(self, video_uri: str) -> List[Tuple[str, float]]:
+ def __call__(
+ self, video_uri: str, frames_every: float = 2
+ ) -> List[Tuple[str, float]]:
  """Extract frames from a video.


@@ -1080,7 +1101,7 @@ class ExtractFrames(Tool):
  Returns:
  a list of tuples containing the extracted frame and the timestamp in seconds. E.g. [(path_to_frame1, 0.0), (path_to_frame2, 0.5), ...]. The timestamp is the time in seconds from the start of the video. E.g. 12.125 means 12.125 seconds from the start of the video. The frames are sorted by the timestamp in ascending order.
  """
- frames = extract_frames_from_video(video_uri)
+ frames = extract_frames_from_video(video_uri, fps=round(1 / frames_every, 2))
  result = []
  _LOGGER.info(
  f"Extracted {len(frames)} frames from video {video_uri}. Temporarily saving them as images to disk for downstream tasks."
@@ -1183,7 +1204,7 @@ TOOLS = {
  AgentDINOv,
  ExtractFrames,
  Crop,
- BboxArea,
+ BboxStats,
  SegArea,
  ObjectDistance,
  BboxContains,
@@ -1,13 +1,19 @@
  import inspect
+ import io
+ import logging
  import tempfile
  from importlib import resources
- from typing import Any, Callable, Dict, List
+ from pathlib import Path
+ from typing import Any, Callable, Dict, List, Tuple, Union

  import numpy as np
+ import pandas as pd
+ import requests
  from PIL import Image, ImageDraw, ImageFont

- from vision_agent.image_utils import convert_to_b64, normalize_bbox
  from vision_agent.tools.tool_utils import _send_inference_request
+ from vision_agent.utils import extract_frames_from_video
+ from vision_agent.utils.image_utils import convert_to_b64, normalize_bbox, rle_decode

  COLORS = [
  (158, 218, 229),
@@ -31,6 +37,10 @@ COLORS = [
  (255, 127, 14),
  (31, 119, 180),
  ]
+ _API_KEY = "land_sk_WVYwP00xA3iXely2vuar6YUDZ3MJT9yLX6oW5noUkwICzYLiDV"
+ _OCR_URL = "https://app.landing.ai/ocr/v1/detect-text"
+ logging.basicConfig(level=logging.INFO)
+ _LOGGER = logging.getLogger(__name__)


  def grounding_dino(
@@ -39,23 +49,30 @@ def grounding_dino(
  box_threshold: float = 0.20,
  iou_threshold: float = 0.75,
  ) -> List[Dict[str, Any]]:
- """'grounding_dino' is a tool that can detect arbitrary objects with inputs such as
- category names or referring expressions.
+ """'grounding_dino' is a tool that can detect and count objects given a text prompt
+ such as category names or referring expressions. It returns a list and count of
+ bounding boxes, label names and associated probability scores.

  Parameters:
  prompt (str): The prompt to ground to the image.
  image (np.ndarray): The image to ground the prompt to.
- box_threshold (float, optional): The threshold for the box detection. Defaults to 0.20.
- iou_threshold (float, optional): The threshold for the Intersection over Union (IoU). Defaults to 0.75.
+ box_threshold (float, optional): The threshold for the box detection. Defaults
+ to 0.20.
+ iou_threshold (float, optional): The threshold for the Intersection over Union
+ (IoU). Defaults to 0.75.

  Returns:
  List[Dict[str, Any]]: A list of dictionaries containing the score, label, and
- bounding box of the detected objects with normalized coordinates.
+ bounding box of the detected objects with normalized coordinates
+ (x1, y1, x2, y2).

  Example
  -------
  >>> grounding_dino("car. dinosaur", image)
- [{'score': 0.99, 'label': 'dinosaur', 'bbox': [0.1, 0.11, 0.35, 0.4]}]
+ [
+ {'score': 0.99, 'label': 'dinosaur', 'bbox': [0.1, 0.11, 0.35, 0.4]},
+ {'score': 0.98, 'label': 'car', 'bbox': [0.2, 0.21, 0.45, 0.5},
+ ]
  """
  image_size = image.shape[:2]
  image_b64 = convert_to_b64(Image.fromarray(image))
@@ -78,6 +95,147 @@ def grounding_dino(
  return return_data


+ def grounding_sam(
+ prompt: str,
+ image: np.ndarray,
+ box_threshold: float = 0.20,
+ iou_threshold: float = 0.75,
+ ) -> List[Dict[str, Any]]:
+ """'grounding_sam' is a tool that can detect and segment objects given a text
+ prompt such as category names or referring expressions. It returns a list of
+ bounding boxes, label names and masks file names and associated probability scores.
+
+ Parameters:
+ prompt (str): The prompt to ground to the image.
+ image (np.ndarray): The image to ground the prompt to.
+ box_threshold (float, optional): The threshold for the box detection. Defaults
+ to 0.20.
+ iou_threshold (float, optional): The threshold for the Intersection over Union
+ (IoU). Defaults to 0.75.
+
+ Returns:
+ List[Dict[str, Any]]: A list of dictionaries containing the score, label,
+ bounding box, and mask of the detected objects with normalized coordinates
+ (x1, y1, x2, y2).
+
+ Example
+ -------
+ >>> grounding_sam("car. dinosaur", image)
+ [
+ {
+ 'score': 0.99,
+ 'label': 'dinosaur',
+ 'bbox': [0.1, 0.11, 0.35, 0.4],
+ 'mask': array([[0, 0, 0, ..., 0, 0, 0],
+ [0, 0, 0, ..., 0, 0, 0],
+ ...,
+ [0, 0, 0, ..., 0, 0, 0],
+ [0, 0, 0, ..., 0, 0, 0]], dtype=uint8),
+ },
+ ]
+ """
+ image_size = image.shape[:2]
+ image_b64 = convert_to_b64(Image.fromarray(image))
+ request_data = {
+ "prompt": prompt,
+ "image": image_b64,
+ "tool": "visual_grounding_segment",
+ "kwargs": {"box_threshold": box_threshold, "iou_threshold": iou_threshold},
+ }
+ data: Dict[str, Any] = _send_inference_request(request_data, "tools")
+ return_data = []
+ for i in range(len(data["bboxes"])):
+ return_data.append(
+ {
+ "score": round(data["scores"][i], 2),
+ "label": data["labels"][i],
+ "bbox": normalize_bbox(data["bboxes"][i], image_size),
+ "mask": rle_decode(mask_rle=data["masks"][i], shape=data["mask_shape"]),
+ }
+ )
+ return return_data
+
+
+ def extract_frames(
+ video_uri: Union[str, Path], fps: float = 0.5
+ ) -> List[Tuple[np.ndarray, float]]:
+ """'extract_frames' extracts frames from a video, returns a list of tuples (frame,
+ timestamp), where timestamp is the relative time in seconds where the frame was
+ captured. The frame is a local image file path.
+
+ Parameters:
+ video_uri (Union[str, Path]): The path to the video file.
+ fps (float, optional): The frame rate per second to extract the frames. Defaults
+ to 0.5.
+
+ Returns:
+ List[Tuple[np.ndarray, float]]: A list of tuples containing the extracted frame
+ and the timestamp in seconds.
+
+ Example
+ -------
+ >>> extract_frames("path/to/video.mp4")
+ [(frame1, 0.0), (frame2, 0.5), ...]
+ """
+
+ return extract_frames_from_video(str(video_uri), fps)
+
+
+ def ocr(image: np.ndarray) -> List[Dict[str, Any]]:
+ """'ocr' extracts text from an image. It returns a list of detected text, bounding
+ boxes, and confidence scores.
+
+ Parameters:
+ image (np.ndarray): The image to extract text from.
+
+ Returns:
+ List[Dict[str, Any]]: A list of dictionaries containing the detected text, bbox,
+ and confidence score.
+
+ Example
+ -------
+ >>> ocr(image)
+ [
+ {'label': 'some text', 'bbox': [0.1, 0.11, 0.35, 0.4], 'score': 0.99},
+ ]
+ """
+
+ pil_image = Image.fromarray(image).convert("RGB")
+ image_size = pil_image.size[::-1]
+ image_buffer = io.BytesIO()
+ pil_image.save(image_buffer, format="PNG")
+ buffer_bytes = image_buffer.getvalue()
+ image_buffer.close()
+
+ res = requests.post(
+ _OCR_URL,
+ files={"images": buffer_bytes},
+ data={"language": "en"},
+ headers={"contentType": "multipart/form-data", "apikey": _API_KEY},
+ )
+
+ if res.status_code != 200:
+ raise ValueError(f"OCR request failed with status code {res.status_code}")
+
+ data = res.json()
+ output = []
+ for det in data[0]:
+ label = det["text"]
+ box = [
+ det["location"][0]["x"],
+ det["location"][0]["y"],
+ det["location"][2]["x"],
+ det["location"][2]["y"],
+ ]
+ box = normalize_bbox(box, image_size)
+ output.append({"label": label, "bbox": box, "score": round(det["score"], 2)})
+
+ return output
+
+
+ # Utility and visualization functions
+
+
  def load_image(image_path: str) -> np.ndarray:
  """'load_image' is a utility function that loads an image from the given path.

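These hunks belong to a second module whose numbering restarts at line 1, and they introduce a function-style tool set alongside the existing class-based tools. A hedged end-to-end sketch of how the new functions might be combined; the paths are placeholders and the import statement is omitted because the file name is not shown in this extract:

    # import path omitted: the source file name is not visible in this diff
    image = load_image("examples/shirts.jpg")                # placeholder path
    detections = grounding_sam("red shirt. green shirt", image)
    # each entry: {"score": ..., "label": ..., "bbox": [x1, y1, x2, y2], "mask": np.ndarray}
    words = ocr(image)                                        # [{"label", "bbox", "score"}, ...]
    frames = extract_frames("examples/clip.mp4", fps=0.5)     # [(frame, timestamp), ...]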
@@ -117,24 +275,33 @@ def save_image(image: np.ndarray) -> str:
  return f.name


- def display_bounding_boxes(
+ def overlay_bounding_boxes(
  image: np.ndarray, bboxes: List[Dict[str, Any]]
  ) -> np.ndarray:
- """'display_bounding_boxes' is a utility function that displays bounding boxes on an image.
+ """'display_bounding_boxes' is a utility function that displays bounding boxes on
+ an image.

  Parameters:
  image (np.ndarray): The image to display the bounding boxes on.
- bboxes (List[Dict[str, Any]]): A list of dictionaries containing the bounding boxes.
+ bboxes (List[Dict[str, Any]]): A list of dictionaries containing the bounding
+ boxes.

  Returns:
- np.ndarray: The image with the bounding boxes displayed.
+ np.ndarray: The image with the bounding boxes, labels and scores displayed.

  Example
  -------
- >>> image_with_bboxes = display_bounding_boxes(image, [{'score': 0.99, 'label': 'dinosaur', 'bbox': [0.1, 0.11, 0.35, 0.4]}])
+ >>> image_with_bboxes = display_bounding_boxes(
+ image, [{'score': 0.99, 'label': 'dinosaur', 'bbox': [0.1, 0.11, 0.35, 0.4]}],
+ )
  """
  pil_image = Image.fromarray(image.astype(np.uint8))

+ if len(set([box["label"] for box in bboxes])) > len(COLORS):
+ _LOGGER.warning(
+ "Number of unique labels exceeds the number of available colors. Some labels may have the same color."
+ )
+
  color = {
  label: COLORS[i % len(COLORS)]
  for i, label in enumerate(set([box["label"] for box in bboxes]))
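The renamed overlay_bounding_boxes helper chains directly onto detector output; a brief assumed continuation of the sketch above:

    annotated = overlay_bounding_boxes(image, detections)
    # returns a copy of the image (np.ndarray) with boxes, labels and scores drawn on it;
    # a warning is logged if there are more unique labels than preset COLORS entries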
@@ -167,15 +334,109 @@ def display_bounding_boxes(
  return np.array(pil_image.convert("RGB"))


- def get_tool_documentation(funcs: List[Callable]) -> str:
+ def overlay_segmentation_masks(
+ image: np.ndarray, masks: List[Dict[str, Any]]
+ ) -> np.ndarray:
+ """'display_segmentation_masks' is a utility function that displays segmentation
+ masks.
+
+ Parameters:
+ image (np.ndarray): The image to display the masks on.
+ masks (List[Dict[str, Any]]): A list of dictionaries containing the masks.
+
+ Returns:
+ np.ndarray: The image with the masks displayed.
+
+ Example
+ -------
+ >>> image_with_masks = display_segmentation_masks(
+ image,
+ [{
+ 'score': 0.99,
+ 'label': 'dinosaur',
+ 'mask': array([[0, 0, 0, ..., 0, 0, 0],
+ [0, 0, 0, ..., 0, 0, 0],
+ ...,
+ [0, 0, 0, ..., 0, 0, 0],
+ [0, 0, 0, ..., 0, 0, 0]], dtype=uint8),
+ }],
+ )
+ """
+ pil_image = Image.fromarray(image.astype(np.uint8)).convert("RGBA")
+
+ if len(set([mask["label"] for mask in masks])) > len(COLORS):
+ _LOGGER.warning(
+ "Number of unique labels exceeds the number of available colors. Some labels may have the same color."
+ )
+
+ color = {
+ label: COLORS[i % len(COLORS)]
+ for i, label in enumerate(set([mask["label"] for mask in masks]))
+ }
+
+ for elt in masks:
+ mask = elt["mask"]
+ label = elt["label"]
+ np_mask = np.zeros((pil_image.size[1], pil_image.size[0], 4))
+ np_mask[mask > 0, :] = color[label] + (255 * 0.5,)
+ mask_img = Image.fromarray(np_mask.astype(np.uint8))
+ pil_image = Image.alpha_composite(pil_image, mask_img)
+ return np.array(pil_image.convert("RGB"))
+
+
+ def get_tool_documentation(funcs: List[Callable[..., Any]]) -> str:
  docstrings = ""
  for func in funcs:
- docstrings += f"{func.__name__}: {inspect.signature(func)}\n{func.__doc__}\n\n"
+ docstrings += f"{func.__name__}{inspect.signature(func)}:\n{func.__doc__}\n\n"

  return docstrings


- TOOLS_DOCSTRING = get_tool_documentation([load_image, grounding_dino])
+ def get_tool_descriptions(funcs: List[Callable[..., Any]]) -> str:
+ descriptions = ""
+ for func in funcs:
+ description = func.__doc__
+ if description is None:
+ description = ""
+
+ description = (
+ description[: description.find("Parameters:")].replace("\n", " ").strip()
+ )
+ description = " ".join(description.split())
+ descriptions += f"- {func.__name__}{inspect.signature(func)}: {description}\n"
+ return descriptions
+
+
+ def get_tools_df(funcs: List[Callable[..., Any]]) -> pd.DataFrame:
+ data: Dict[str, List[str]] = {"desc": [], "doc": []}
+
+ for func in funcs:
+ desc = func.__doc__
+ if desc is None:
+ desc = ""
+ desc = desc[: desc.find("Parameters:")].replace("\n", " ").strip()
+ desc = " ".join(desc.split())
+
+ doc = f"{func.__name__}{inspect.signature(func)}:\n{func.__doc__}"
+ data["desc"].append(desc)
+ data["doc"].append(doc)
+
+ return pd.DataFrame(data) # type: ignore
+
+
+ TOOLS = [
+ grounding_dino,
+ grounding_sam,
+ extract_frames,
+ ocr,
+ load_image,
+ save_image,
+ overlay_bounding_boxes,
+ overlay_segmentation_masks,
+ ]
+ TOOLS_DF = get_tools_df(TOOLS) # type: ignore
+ TOOL_DESCRIPTIONS = get_tool_descriptions(TOOLS) # type: ignore
+ TOOL_DOCSTRING = get_tool_documentation(TOOLS) # type: ignore
  UTILITIES_DOCSTRING = get_tool_documentation(
- [load_image, save_image, display_bounding_boxes]
+ [load_image, save_image, overlay_bounding_boxes]
  )
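get_tools_df and get_tool_descriptions build retrieval-friendly views over the new TOOLS list. A small sketch of what they expose, under the assumption the module has been imported; the exact string contents depend on the docstrings above:

    # one DataFrame row per tool: "desc" holds the docstring text before "Parameters:",
    # "doc" holds the function signature plus the full docstring
    row = TOOLS_DF.iloc[0]
    print(row["desc"])        # short summary, usable for tool retrieval
    print(TOOL_DESCRIPTIONS)  # "- grounding_dino(...): ..." style bullet list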
@@ -0,0 +1,3 @@
+ from .execute import Execute
+ from .sim import Sim
+ from .video import extract_frames_from_video
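This new __init__ re-exports Execute, Sim and extract_frames_from_video, which appears to back the `from vision_agent.utils import extract_frames_from_video` statement added in the first file. A minimal assumed usage:

    from vision_agent.utils import extract_frames_from_video
    frames = extract_frames_from_video("tests/data/test.mp4", fps=0.5)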