vision-agent 0.0.40__tar.gz → 0.0.41__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. {vision_agent-0.0.40 → vision_agent-0.0.41}/PKG-INFO +1 -1
  2. {vision_agent-0.0.40 → vision_agent-0.0.41}/pyproject.toml +1 -1
  3. {vision_agent-0.0.40 → vision_agent-0.0.41}/vision_agent/agent/vision_agent.py +73 -19
  4. vision_agent-0.0.41/vision_agent/image_utils.py +152 -0
  5. vision_agent-0.0.41/vision_agent/tools/__init__.py +15 -0
  6. {vision_agent-0.0.40 → vision_agent-0.0.41}/vision_agent/tools/tools.py +123 -60
  7. {vision_agent-0.0.40 → vision_agent-0.0.41}/vision_agent/tools/video.py +8 -4
  8. vision_agent-0.0.40/vision_agent/image_utils.py +0 -62
  9. vision_agent-0.0.40/vision_agent/tools/__init__.py +0 -2
  10. {vision_agent-0.0.40 → vision_agent-0.0.41}/LICENSE +0 -0
  11. {vision_agent-0.0.40 → vision_agent-0.0.41}/README.md +0 -0
  12. {vision_agent-0.0.40 → vision_agent-0.0.41}/vision_agent/__init__.py +0 -0
  13. {vision_agent-0.0.40 → vision_agent-0.0.41}/vision_agent/agent/__init__.py +0 -0
  14. {vision_agent-0.0.40 → vision_agent-0.0.41}/vision_agent/agent/agent.py +0 -0
  15. {vision_agent-0.0.40 → vision_agent-0.0.41}/vision_agent/agent/easytool.py +0 -0
  16. {vision_agent-0.0.40 → vision_agent-0.0.41}/vision_agent/agent/easytool_prompts.py +0 -0
  17. {vision_agent-0.0.40 → vision_agent-0.0.41}/vision_agent/agent/reflexion.py +0 -0
  18. {vision_agent-0.0.40 → vision_agent-0.0.41}/vision_agent/agent/reflexion_prompts.py +0 -0
  19. {vision_agent-0.0.40 → vision_agent-0.0.41}/vision_agent/agent/vision_agent_prompts.py +0 -0
  20. {vision_agent-0.0.40 → vision_agent-0.0.41}/vision_agent/data/__init__.py +0 -0
  21. {vision_agent-0.0.40 → vision_agent-0.0.41}/vision_agent/data/data.py +0 -0
  22. {vision_agent-0.0.40 → vision_agent-0.0.41}/vision_agent/emb/__init__.py +0 -0
  23. {vision_agent-0.0.40 → vision_agent-0.0.41}/vision_agent/emb/emb.py +0 -0
  24. {vision_agent-0.0.40 → vision_agent-0.0.41}/vision_agent/llm/__init__.py +0 -0
  25. {vision_agent-0.0.40 → vision_agent-0.0.41}/vision_agent/llm/llm.py +0 -0
  26. {vision_agent-0.0.40 → vision_agent-0.0.41}/vision_agent/lmm/__init__.py +0 -0
  27. {vision_agent-0.0.40 → vision_agent-0.0.41}/vision_agent/lmm/lmm.py +0 -0
  28. {vision_agent-0.0.40 → vision_agent-0.0.41}/vision_agent/tools/prompts.py +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: vision-agent
- Version: 0.0.40
+ Version: 0.0.41
  Summary: Toolset for Vision Agent
  Author: Landing AI
  Author-email: dev@landing.ai
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"

  [tool.poetry]
  name = "vision-agent"
- version = "0.0.40"
+ version = "0.0.41"
  description = "Toolset for Vision Agent"
  authors = ["Landing AI <dev@landing.ai>"]
  readme = "README.md"
@@ -1,11 +1,13 @@
  import json
  import logging
  import sys
+ import tempfile
  from pathlib import Path
  from typing import Any, Callable, Dict, List, Optional, Tuple, Union

  from tabulate import tabulate

+ from vision_agent.image_utils import overlay_bboxes, overlay_masks
  from vision_agent.llm import LLM, OpenAILLM
  from vision_agent.lmm import LMM, OpenAILMM
  from vision_agent.tools import TOOLS
@@ -248,12 +250,12 @@ def retrieval(
      tools: Dict[int, Any],
      previous_log: str,
      reflections: str,
- ) -> Tuple[List[Dict], str]:
+ ) -> Tuple[Dict, str]:
      tool_id = choose_tool(
          model, question, {k: v["description"] for k, v in tools.items()}, reflections
      )
      if tool_id is None:
-         return [{}], ""
+         return {}, ""
      _LOGGER.info(f"\t(Tool ID, name): ({tool_id}, {tools[tool_id]['name']})")

      tool_instructions = tools[tool_id]
@@ -265,14 +267,12 @@ def retrieval(
      )
      _LOGGER.info(f"\tParameters: {parameters} for {tool_name}")
      if parameters is None:
-         return [{}], ""
-     tool_results = [
-         {"task": question, "tool_name": tool_name, "parameters": parameters}
-     ]
+         return {}, ""
+     tool_results = {"task": question, "tool_name": tool_name, "parameters": parameters}

      _LOGGER.info(
-         f"""Going to run the following {len(tool_results)} tool(s) in sequence:
- {tabulate(tool_results, headers="keys", tablefmt="mixed_grid")}"""
+         f"""Going to run the following tool(s) in sequence:
+ {tabulate([tool_results], headers="keys", tablefmt="mixed_grid")}"""
      )

      def parse_tool_results(result: Dict[str, Union[Dict, List]]) -> Any:
@@ -286,12 +286,10 @@ def retrieval(
          call_results.append(function_call(tools[tool_id]["class"], parameters))
          return call_results

-     call_results = []
-     for i, result in enumerate(tool_results):
-         call_results.extend(parse_tool_results(result))
-         tool_results[i]["call_results"] = call_results
+     call_results = parse_tool_results(tool_results)
+     tool_results["call_results"] = call_results

-     call_results_str = "\n\n".join([str(e) for e in call_results if e is not None])
+     call_results_str = str(call_results)
      _LOGGER.info(f"\tCall Results: {call_results_str}")
      return tool_results, call_results_str

@@ -335,7 +333,11 @@ def self_reflect(
          tool_results=str(tool_result),
          final_answer=final_answer,
      )
-     if issubclass(type(reflect_model), LMM):
+     if (
+         issubclass(type(reflect_model), LMM)
+         and image is not None
+         and Path(image).suffix in [".jpg", ".jpeg", ".png"]
+     ):
          return reflect_model(prompt, image=image)  # type: ignore
      return reflect_model(prompt)

@@ -345,6 +347,56 @@ def parse_reflect(reflect: str) -> bool:
      return "finish" in reflect.lower() and len(reflect) < 100


+ def visualize_result(all_tool_results: List[Dict]) -> List[str]:
+     image_to_data: Dict[str, Dict] = {}
+     for tool_result in all_tool_results:
+         if not tool_result["tool_name"] in ["grounding_sam_", "grounding_dino_"]:
+             continue
+
+         parameters = tool_result["parameters"]
+         # parameters can either be a dictionary or list, parameters can also be malformed
+         # because the LLM builds them
+         if isinstance(parameters, dict):
+             if "image" not in parameters:
+                 continue
+             parameters = [parameters]
+         elif isinstance(tool_result["parameters"], list):
+             if (
+                 len(tool_result["parameters"]) < 1
+                 and "image" not in tool_result["parameters"][0]
+             ):
+                 continue
+
+         for param, call_result in zip(parameters, tool_result["call_results"]):
+
+             # calls can fail, so we need to check if the call was successful
+             if not isinstance(call_result, dict):
+                 continue
+             if "bboxes" not in call_result:
+                 continue
+
+             # if the call was successful, then we can add the image data
+             image = param["image"]
+             if image not in image_to_data:
+                 image_to_data[image] = {"bboxes": [], "masks": [], "labels": []}
+
+             image_to_data[image]["bboxes"].extend(call_result["bboxes"])
+             image_to_data[image]["labels"].extend(call_result["labels"])
+             if "masks" in call_result:
+                 image_to_data[image]["masks"].extend(call_result["masks"])
+
+     visualized_images = []
+     for image in image_to_data:
+         image_path = Path(image)
+         image_data = image_to_data[image]
+         image = overlay_masks(image_path, image_data)
+         image = overlay_bboxes(image, image_data)
+         with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
+             image.save(f.name)
+             visualized_images.append(f.name)
+     return visualized_images
+
+
  class VisionAgent(Agent):
      r"""Vision Agent is an agent framework that utilizes tools as well as self
      reflection to accomplish tasks, in particular vision tasks. Vision Agent is based
@@ -389,7 +441,8 @@ class VisionAgent(Agent):
          """Invoke the vision agent.

          Parameters:
-             input: a prompt that describe the task or a conversation in the format of [{"role": "user", "content": "describe your task here..."}].
+             input: a prompt that describe the task or a conversation in the format of
+                 [{"role": "user", "content": "describe your task here..."}].
              image: the input image referenced in the prompt parameter.

          Returns:
@@ -436,9 +489,8 @@ class VisionAgent(Agent):
                  self.answer_model, task_str, call_results, previous_log, reflections
              )

-             for tool_result in tool_results:
-                 tool_result["answer"] = answer
-             all_tool_results.extend(tool_results)
+             tool_results["answer"] = answer
+             all_tool_results.append(tool_results)

              _LOGGER.info(f"\tAnswer: {answer}")
              answers.append({"task": task_str, "answer": answer})
@@ -448,13 +500,15 @@ class VisionAgent(Agent):
              self.answer_model, question, answers, reflections
          )

+         visualized_images = visualize_result(all_tool_results)
+         all_tool_results.append({"visualized_images": visualized_images})
          reflection = self_reflect(
              self.reflect_model,
              question,
              self.tools,
              all_tool_results,
              final_answer,
-             image,
+             visualized_images[0] if len(visualized_images) > 0 else image,
          )
          _LOGGER.info(f"\tReflection: {reflection}")
          if parse_reflect(reflection):
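
Note: after this change, each task produces a single tool_results record (a dict rather than a one-element list). A minimal sketch of the record that visualize_result() consumes; the field names come from the hunks above, while the values are hypothetical:

    tool_results = {
        "task": "Count the dogs in the image",                 # hypothetical task
        "tool_name": "grounding_dino_",
        "parameters": {"prompt": "dog", "image": "dogs.png"},  # hypothetical inputs
        "call_results": [
            {
                "labels": ["dog", "dog"],
                "bboxes": [[0.1, 0.2, 0.4, 0.5], [0.5, 0.1, 0.9, 0.6]],  # normalized xyxy
            }
        ],
        "answer": "There are 2 dogs.",
    }

visualize_result() groups such records by image, overlays the masks and bounding boxes, and the first rendered image is passed to self_reflect() in place of the raw input image.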
@@ -0,0 +1,152 @@
+ """Utility functions for image processing."""
+
+ import base64
+ from io import BytesIO
+ from pathlib import Path
+ from typing import Dict, Tuple, Union
+
+ import numpy as np
+ from PIL import Image, ImageDraw, ImageFont
+ from PIL.Image import Image as ImageType
+
+ COLORS = [
+     (158, 218, 229),
+     (219, 219, 141),
+     (23, 190, 207),
+     (188, 189, 34),
+     (199, 199, 199),
+     (247, 182, 210),
+     (127, 127, 127),
+     (227, 119, 194),
+     (196, 156, 148),
+     (197, 176, 213),
+     (140, 86, 75),
+     (148, 103, 189),
+     (255, 152, 150),
+     (152, 223, 138),
+     (214, 39, 40),
+     (44, 160, 44),
+     (255, 187, 120),
+     (174, 199, 232),
+     (255, 127, 14),
+     (31, 119, 180),
+ ]
+
+
+ def b64_to_pil(b64_str: str) -> ImageType:
+     r"""Convert a base64 string to a PIL Image.
+
+     Parameters:
+         b64_str: the base64 encoded image
+
+     Returns:
+         The decoded PIL Image
+     """
+     # , can't be encoded in b64 data so must be part of prefix
+     if "," in b64_str:
+         b64_str = b64_str.split(",")[1]
+     return Image.open(BytesIO(base64.b64decode(b64_str)))
+
+
+ def get_image_size(data: Union[str, Path, np.ndarray, ImageType]) -> Tuple[int, ...]:
+     r"""Get the size of an image.
+
+     Parameters:
+         data: the input image
+
+     Returns:
+         The size of the image in the form (height, width)
+     """
+     if isinstance(data, (str, Path)):
+         data = Image.open(data)
+
+     return data.size[::-1] if isinstance(data, Image.Image) else data.shape[:2]
+
+
+ def convert_to_b64(data: Union[str, Path, np.ndarray, ImageType]) -> str:
+     r"""Convert an image to a base64 string.
+
+     Parameters:
+         data: the input image
+
+     Returns:
+         The base64 encoded image
+     """
+     if data is None:
+         raise ValueError(f"Invalid input image: {data}. Input image can't be None.")
+     if isinstance(data, (str, Path)):
+         data = Image.open(data)
+     if isinstance(data, Image.Image):
+         buffer = BytesIO()
+         data.save(buffer, format="PNG")
+         return base64.b64encode(buffer.getvalue()).decode("utf-8")
+     else:
+         arr_bytes = data.tobytes()
+         return base64.b64encode(arr_bytes).decode("utf-8")
+
+
+ def overlay_bboxes(
+     image: Union[str, Path, np.ndarray, ImageType], bboxes: Dict
+ ) -> ImageType:
+     r"""Plots bounding boxes on to an image.
+
+     Parameters:
+         image: the input image
+         bboxes: the bounding boxes to overlay
+
+     Returns:
+         The image with the bounding boxes overlayed
+     """
+     if isinstance(image, (str, Path)):
+         image = Image.open(image)
+     elif isinstance(image, np.ndarray):
+         image = Image.fromarray(image)
+
+     color = {label: COLORS[i % len(COLORS)] for i, label in enumerate(bboxes["labels"])}
+
+     draw = ImageDraw.Draw(image)
+     font = ImageFont.load_default()
+     width, height = image.size
+     if "bboxes" not in bboxes:
+         return image.convert("RGB")
+
+     for label, box in zip(bboxes["labels"], bboxes["bboxes"]):
+         box = [box[0] * width, box[1] * height, box[2] * width, box[3] * height]
+         draw.rectangle(box, outline=color[label], width=3)
+         label = f"{label}"
+         text_box = draw.textbbox((box[0], box[1]), text=label, font=font)
+         draw.rectangle(text_box, fill=color[label])
+         draw.text((text_box[0], text_box[1]), label, fill="black", font=font)
+     return image.convert("RGB")
+
+
+ def overlay_masks(
+     image: Union[str, Path, np.ndarray, ImageType], masks: Dict, alpha: float = 0.5
+ ) -> ImageType:
+     r"""Plots masks on to an image.
+
+     Parameters:
+         image: the input image
+         masks: the masks to overlay
+         alpha: the transparency of the overlay
+
+     Returns:
+         The image with the masks overlayed
+     """
+     if isinstance(image, (str, Path)):
+         image = Image.open(image)
+     elif isinstance(image, np.ndarray):
+         image = Image.fromarray(image)
+
+     color = {label: COLORS[i % len(COLORS)] for i, label in enumerate(masks["labels"])}
+     if "masks" not in masks:
+         return image.convert("RGB")
+
+     for label, mask in zip(masks["labels"], masks["masks"]):
+         if isinstance(mask, str):
+             mask = np.array(Image.open(mask))
+         np_mask = np.zeros((image.size[1], image.size[0], 4))
+         np_mask[mask > 0, :] = color[label] + (255 * alpha,)
+         mask_img = Image.fromarray(np_mask.astype(np.uint8))
+         image = Image.alpha_composite(image.convert("RGBA"), mask_img)
+     return image.convert("RGB")
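
Note: a hypothetical usage sketch of the new helpers (not part of the package); the file names and the detections dict are illustrative stand-ins:

    from vision_agent.image_utils import overlay_bboxes, overlay_masks

    detections = {
        "labels": ["cat"],
        "bboxes": [[0.1, 0.1, 0.6, 0.8]],  # normalized xyxy; scaled by the image size internally
        "masks": [],                       # optional mask file paths or arrays
    }
    img = overlay_masks("image.png", detections)  # masks first, as visualize_result() does
    img = overlay_bboxes(img, detections)
    img.save("image_annotated.png")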
@@ -0,0 +1,15 @@
+ from .prompts import CHOOSE_PARAMS, SYSTEM_PROMPT
+ from .tools import (
+     CLIP,
+     TOOLS,
+     BboxArea,
+     BboxIoU,
+     Counter,
+     Crop,
+     ExtractFrames,
+     GroundingDINO,
+     GroundingSAM,
+     SegArea,
+     SegIoU,
+     Tool,
+ )
@@ -92,7 +92,7 @@ class CLIP(Tool):
      }

      # TODO: Add support for input multiple images, which aligns with the output type.
-     def __call__(self, prompt: List[str], image: Union[str, ImageType]) -> List[Dict]:
+     def __call__(self, prompt: List[str], image: Union[str, ImageType]) -> Dict:
          """Invoke the CLIP model.

          Parameters:
@@ -122,7 +122,7 @@ class CLIP(Tool):
          rets = []
          for elt in resp_json["data"]:
              rets.append({"labels": prompt, "scores": [round(prob, 2) for prob in elt]})
-         return cast(List[Dict], rets)
+         return cast(Dict, rets[0])


  class GroundingDINO(Tool):
@@ -168,7 +168,7 @@ class GroundingDINO(Tool):
      }

      # TODO: Add support for input multiple images, which aligns with the output type.
-     def __call__(self, prompt: str, image: Union[str, Path, ImageType]) -> List[Dict]:
+     def __call__(self, prompt: str, image: Union[str, Path, ImageType]) -> Dict:
          """Invoke the Grounding DINO model.

          Parameters:
@@ -204,7 +204,7 @@ class GroundingDINO(Tool):
              if "scores" in elt:
                  elt["scores"] = [round(score, 2) for score in elt["scores"]]
              elt["size"] = (image_size[1], image_size[0])
-         return cast(List[Dict], resp_data)
+         return cast(Dict, resp_data)


  class GroundingSAM(Tool):
@@ -259,7 +259,7 @@ class GroundingSAM(Tool):
      }

      # TODO: Add support for input multiple images, which aligns with the output type.
-     def __call__(self, prompt: List[str], image: Union[str, ImageType]) -> List[Dict]:
+     def __call__(self, prompt: List[str], image: Union[str, ImageType]) -> Dict:
          """Invoke the Grounding SAM model.

          Parameters:
@@ -294,7 +294,7 @@ class GroundingSAM(Tool):
              ret_pred["labels"].append(pred["label_name"])
              ret_pred["bboxes"].append(normalize_bbox(pred["bbox"], image_size))
              ret_pred["masks"].append(mask)
-         return [ret_pred]
+         return ret_pred


  class AgentGroundingSAM(GroundingSAM):
@@ -302,15 +302,14 @@ class AgentGroundingSAM(GroundingSAM):
      returns the file name. This makes it easier for agents to use.
      """

-     def __call__(self, prompt: List[str], image: Union[str, ImageType]) -> List[Dict]:
+     def __call__(self, prompt: List[str], image: Union[str, ImageType]) -> Dict:
          rets = super().__call__(prompt, image)
-         for ret in rets:
-             mask_files = []
-             for mask in ret["masks"]:
-                 with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
-                     Image.fromarray(mask * 255).save(tmp)
-                     mask_files.append(tmp.name)
-             ret["masks"] = mask_files
+         mask_files = []
+         for mask in rets["masks"]:
+             with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
+                 Image.fromarray(mask * 255).save(tmp)
+                 mask_files.append(tmp.name)
+         rets["masks"] = mask_files
          return rets


@@ -363,7 +362,7 @@ class Crop(Tool):
          ],
      }

-     def __call__(self, bbox: List[float], image: Union[str, Path]) -> str:
+     def __call__(self, bbox: List[float], image: Union[str, Path]) -> Dict:
          pil_image = Image.open(image)
          width, height = pil_image.size
          bbox = [
@@ -373,10 +372,10 @@ class Crop(Tool):
              int(bbox[3] * height),
          ]
          cropped_image = pil_image.crop(bbox)  # type: ignore
-         with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
+         with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
              cropped_image.save(tmp.name)

-         return tmp.name
+         return {"image": tmp.name}


  class BboxArea(Tool):
  class BboxArea(Tool):
@@ -388,7 +387,7 @@ class BboxArea(Tool):
388
387
  "required_parameters": [{"name": "bbox", "type": "List[int]"}],
389
388
  "examples": [
390
389
  {
391
- "scenario": "If you want to calculate the area of the bounding box [0, 0, 100, 100]",
390
+ "scenario": "If you want to calculate the area of the bounding box [0.2, 0.21, 0.34, 0.42]",
392
391
  "parameters": {"bboxes": [0.2, 0.21, 0.34, 0.42]},
393
392
  }
394
393
  ],
@@ -430,6 +429,109 @@ class SegArea(Tool):
          return cast(float, round(np.sum(np_mask) / 255, 2))


+ class BboxIoU(Tool):
+     name = "bbox_iou_"
+     description = (
+         "'bbox_iou_' returns the intersection over union of two bounding boxes."
+     )
+     usage = {
+         "required_parameters": [
+             {"name": "bbox1", "type": "List[int]"},
+             {"name": "bbox2", "type": "List[int]"},
+         ],
+         "examples": [
+             {
+                 "scenario": "If you want to calculate the intersection over union of the bounding boxes [0.2, 0.21, 0.34, 0.42] and [0.3, 0.31, 0.44, 0.52]",
+                 "parameters": {
+                     "bbox1": [0.2, 0.21, 0.34, 0.42],
+                     "bbox2": [0.3, 0.31, 0.44, 0.52],
+                 },
+             }
+         ],
+     }
+
+     def __call__(self, bbox1: List[int], bbox2: List[int]) -> float:
+         x1, y1, x2, y2 = bbox1
+         x3, y3, x4, y4 = bbox2
+         xA = max(x1, x3)
+         yA = max(y1, y3)
+         xB = min(x2, x4)
+         yB = min(y2, y4)
+         inter_area = max(0, xB - xA) * max(0, yB - yA)
+         boxa_area = (x2 - x1) * (y2 - y1)
+         boxb_area = (x4 - x3) * (y4 - y3)
+         iou = inter_area / float(boxa_area + boxb_area - inter_area)
+         return round(iou, 2)
+
+
+ class SegIoU(Tool):
+     name = "seg_iou_"
+     description = "'seg_iou_' returns the intersection over union of two segmentation masks given their segmentation mask files."
+     usage = {
+         "required_parameters": [
+             {"name": "mask1", "type": "str"},
+             {"name": "mask2", "type": "str"},
+         ],
+         "examples": [
+             {
+                 "scenario": "If you want to calculate the intersection over union of the segmentation masks for mask_file1.jpg and mask_file2.jpg",
+                 "parameters": {"mask1": "mask_file1.png", "mask2": "mask_file2.png"},
+             }
+         ],
+     }
+
+     def __call__(self, mask1: Union[str, Path], mask2: Union[str, Path]) -> float:
+         pil_mask1 = Image.open(str(mask1))
+         pil_mask2 = Image.open(str(mask2))
+         np_mask1 = np.clip(np.array(pil_mask1), 0, 1)
+         np_mask2 = np.clip(np.array(pil_mask2), 0, 1)
+         intersection = np.logical_and(np_mask1, np_mask2)
+         union = np.logical_or(np_mask1, np_mask2)
+         iou = np.sum(intersection) / np.sum(union)
+         return cast(float, round(iou, 2))
+
+
+ class ExtractFrames(Tool):
+     r"""Extract frames from a video."""
+
+     name = "extract_frames_"
+     description = "'extract_frames_' extracts frames where there is motion detected in a video, returns a list of tuples (frame, timestamp), where timestamp is the relative time in seconds where the frame was captured. The frame is a local image file path."
+     usage = {
+         "required_parameters": [{"name": "video_uri", "type": "str"}],
+         "examples": [
+             {
+                 "scenario": "Can you extract the frames from this video? Video: www.foobar.com/video?name=test.mp4",
+                 "parameters": {"video_uri": "www.foobar.com/video?name=test.mp4"},
+             },
+             {
+                 "scenario": "Can you extract the images from this video file? Video path: tests/data/test.mp4",
+                 "parameters": {"video_uri": "tests/data/test.mp4"},
+             },
+         ],
+     }
+
+     def __call__(self, video_uri: str) -> List[Tuple[str, float]]:
+         """Extract frames from a video.
+
+
+         Parameters:
+             video_uri: the path to the video file or a url points to the video data
+
+         Returns:
+             a list of tuples containing the extracted frame and the timestamp in seconds. E.g. [(path_to_frame1, 0.0), (path_to_frame2, 0.5), ...]. The timestamp is the time in seconds from the start of the video. E.g. 12.125 means 12.125 seconds from the start of the video. The frames are sorted by the timestamp in ascending order.
+         """
+         frames = extract_frames_from_video(video_uri)
+         result = []
+         _LOGGER.info(
+             f"Extracted {len(frames)} frames from video {video_uri}. Temporarily saving them as images to disk for downstream tasks."
+         )
+         for frame, ts in frames:
+             with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
+                 Image.fromarray(frame).save(tmp)
+             result.append((tmp.name, ts))
+         return result
+
+
  class Add(Tool):
      r"""Add returns the sum of all the arguments passed to it, normalized to 2 decimal places."""

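Note: a quick sanity check of the new bbox_iou_ math (illustrative, not from the package). IoU = intersection / (area1 + area2 - intersection); for the example boxes in the usage above the overlap is 0.04 x 0.11 = 0.0044 against a union of 0.0544, so:

    from vision_agent.tools import BboxIoU

    iou = BboxIoU()(bbox1=[0.2, 0.21, 0.34, 0.42], bbox2=[0.3, 0.31, 0.44, 0.52])
    print(iou)  # 0.08
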
@@ -506,47 +608,6 @@ class Divide(Tool):
      return round(input[0] / input[1], 2)


- class ExtractFrames(Tool):
-     r"""Extract frames from a video."""
-
-     name = "extract_frames_"
-     description = "'extract_frames_' extract image frames from the input video, return a list of tuple (frame, timestamp), where the timestamp is the relative time in seconds of the frame occurred in the video, the frame is a local image file path that stores the frame."
-     usage = {
-         "required_parameters": [{"name": "video_uri", "type": "str"}],
-         "examples": [
-             {
-                 "scenario": "Can you extract the frames from this video? Video: www.foobar.com/video?name=test.mp4",
-                 "parameters": {"video_uri": "www.foobar.com/video?name=test.mp4"},
-             },
-             {
-                 "scenario": "Can you extract the images from this video file? Video path: tests/data/test.mp4",
-                 "parameters": {"video_uri": "tests/data/test.mp4"},
-             },
-         ],
-     }
-
-     def __call__(self, video_uri: str) -> list[tuple[str, float]]:
-         """Extract frames from a video.
-
-
-         Parameters:
-             video_uri: the path to the video file or a url points to the video data
-
-         Returns:
-             a list of tuples containing the extracted frame and the timestamp in seconds. E.g. [(path_to_frame1, 0.0), (path_to_frame2, 0.5), ...]. The timestamp is the time in seconds from the start of the video. E.g. 12.125 means 12.125 seconds from the start of the video. The frames are sorted by the timestamp in ascending order.
-         """
-         frames = extract_frames_from_video(video_uri)
-         result = []
-         _LOGGER.info(
-             f"Extracted {len(frames)} frames from video {video_uri}. Temporarily saving them as images to disk for downstream tasks."
-         )
-         for frame, ts in frames:
-             with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
-                 Image.fromarray(frame).save(tmp)
-             result.append((tmp.name, ts))
-         return result
-
-
  TOOLS = {
      i: {"name": c.name, "description": c.description, "usage": c.usage, "class": c}
      for i, c in enumerate(
@@ -554,15 +615,17 @@ TOOLS = {
          CLIP,
          GroundingDINO,
          AgentGroundingSAM,
+         ExtractFrames,
          Counter,
          Crop,
          BboxArea,
          SegArea,
+         BboxIoU,
+         SegIoU,
          Add,
          Subtract,
          Multiply,
          Divide,
-         ExtractFrames,
      ]
  )
  if (hasattr(c, "name") and hasattr(c, "description") and hasattr(c, "usage"))
@@ -22,12 +22,16 @@ def extract_frames_from_video(
      Parameters:
          video_uri: the path to the video file or a video file url
          fps: the frame rate per second to extract the frames
-         motion_detection_threshold: The threshold to detect motion between changes/frames.
-             A value between 0-1, which represents the percentage change required for the frames to be considered in motion.
-             For example, a lower value means more frames will be extracted.
+         motion_detection_threshold: The threshold to detect motion between
+             changes/frames. A value between 0-1, which represents the percentage change
+             required for the frames to be considered in motion. For example, a lower
+             value means more frames will be extracted.

      Returns:
-         a list of tuples containing the extracted frame and the timestamp in seconds. E.g. [(frame1, 0.0), (frame2, 0.5), ...]. The timestamp is the time in seconds from the start of the video. E.g. 12.125 means 12.125 seconds from the start of the video. The frames are sorted by the timestamp in ascending order.
+         a list of tuples containing the extracted frame and the timestamp in seconds.
+         E.g. [(frame1, 0.0), (frame2, 0.5), ...]. The timestamp is the time in seconds
+         from the start of the video. E.g. 12.125 means 12.125 seconds from the start of
+         the video. The frames are sorted by the timestamp in ascending order.
      """
      with VideoFileClip(video_uri) as video:
          video_duration: float = video.duration
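
Note: a hypothetical call of the relocated extract_frames_ tool (illustrative; "tests/data/test.mp4" is the path from the tool's own usage examples, and the returned temp-file names are made up):

    from vision_agent.tools import ExtractFrames

    frames = ExtractFrames()("tests/data/test.mp4")
    # e.g. [("/tmp/tmpab12cd34.png", 0.0), ("/tmp/tmpef56gh78.png", 0.5), ...]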
@@ -1,62 +0,0 @@
- """Utility functions for image processing."""
-
- import base64
- from io import BytesIO
- from pathlib import Path
- from typing import Tuple, Union
-
- import numpy as np
- from PIL import Image
- from PIL.Image import Image as ImageType
-
-
- def b64_to_pil(b64_str: str) -> ImageType:
-     """Convert a base64 string to a PIL Image.
-
-     Parameters:
-         b64_str: the base64 encoded image
-
-     Returns:
-         The decoded PIL Image
-     """
-     # , can't be encoded in b64 data so must be part of prefix
-     if "," in b64_str:
-         b64_str = b64_str.split(",")[1]
-     return Image.open(BytesIO(base64.b64decode(b64_str)))
-
-
- def get_image_size(data: Union[str, Path, np.ndarray, ImageType]) -> Tuple[int, ...]:
-     """Get the size of an image.
-
-     Parameters:
-         data: the input image
-
-     Returns:
-         The size of the image in the form (height, width)
-     """
-     if isinstance(data, (str, Path)):
-         data = Image.open(data)
-
-     return data.size[::-1] if isinstance(data, Image.Image) else data.shape[:2]
-
-
- def convert_to_b64(data: Union[str, Path, np.ndarray, ImageType]) -> str:
-     """Convert an image to a base64 string.
-
-     Parameters:
-         data: the input image
-
-     Returns:
-         The base64 encoded image
-     """
-     if data is None:
-         raise ValueError(f"Invalid input image: {data}. Input image can't be None.")
-     if isinstance(data, (str, Path)):
-         data = Image.open(data)
-     if isinstance(data, Image.Image):
-         buffer = BytesIO()
-         data.save(buffer, format="PNG")
-         return base64.b64encode(buffer.getvalue()).decode("utf-8")
-     else:
-         arr_bytes = data.tobytes()
-         return base64.b64encode(arr_bytes).decode("utf-8")
@@ -1,2 +0,0 @@
- from .prompts import CHOOSE_PARAMS, SYSTEM_PROMPT
- from .tools import CLIP, TOOLS, Counter, Crop, GroundingDINO, GroundingSAM, Tool