vision-agent 0.0.39__py3-none-any.whl → 0.0.41__py3-none-any.whl

This diff compares the contents of publicly available package versions that have been released to one of the supported registries, as they appear in those registries. It is provided for informational purposes only.
--- vision_agent/agent/vision_agent.py (0.0.39)
+++ vision_agent/agent/vision_agent.py (0.0.41)
@@ -1,9 +1,13 @@
  import json
  import logging
  import sys
+ import tempfile
  from pathlib import Path
  from typing import Any, Callable, Dict, List, Optional, Tuple, Union

+ from tabulate import tabulate
+
+ from vision_agent.image_utils import overlay_bboxes, overlay_masks
  from vision_agent.llm import LLM, OpenAILLM
  from vision_agent.lmm import LMM, OpenAILMM
  from vision_agent.tools import TOOLS
@@ -246,12 +250,12 @@ def retrieval(
      tools: Dict[int, Any],
      previous_log: str,
      reflections: str,
- ) -> Tuple[List[Dict], str]:
+ ) -> Tuple[Dict, str]:
      tool_id = choose_tool(
          model, question, {k: v["description"] for k, v in tools.items()}, reflections
      )
      if tool_id is None:
-         return [{}], ""
+         return {}, ""
      _LOGGER.info(f"\t(Tool ID, name): ({tool_id}, {tools[tool_id]['name']})")

      tool_instructions = tools[tool_id]
@@ -263,10 +267,13 @@ def retrieval(
      )
      _LOGGER.info(f"\tParameters: {parameters} for {tool_name}")
      if parameters is None:
-         return [{}], ""
-     tool_results = [
-         {"task": question, "tool_name": tool_name, "parameters": parameters}
-     ]
+         return {}, ""
+     tool_results = {"task": question, "tool_name": tool_name, "parameters": parameters}
+
+     _LOGGER.info(
+         f"""Going to run the following tool(s) in sequence:
+ {tabulate([tool_results], headers="keys", tablefmt="mixed_grid")}"""
+     )

      def parse_tool_results(result: Dict[str, Union[Dict, List]]) -> Any:
          call_results: List[Any] = []
@@ -279,12 +286,10 @@ def retrieval(
                  call_results.append(function_call(tools[tool_id]["class"], parameters))
          return call_results

-     call_results = []
-     for i, result in enumerate(tool_results):
-         call_results.extend(parse_tool_results(result))
-         tool_results[i]["call_results"] = call_results
+     call_results = parse_tool_results(tool_results)
+     tool_results["call_results"] = call_results

-     call_results_str = "\n\n".join([str(e) for e in call_results if e is not None])
+     call_results_str = str(call_results)
      _LOGGER.info(f"\tCall Results: {call_results_str}")
      return tool_results, call_results_str

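The retrieval changes above collapse the list-of-dicts tool_results into a single dict per retrieval step, which is what makes the new tabulate logging possible. A minimal sketch of how that log line renders, with hypothetical task and tool values:

    from tabulate import tabulate

    tool_results = {
        "task": "count the dogs in the image",  # hypothetical task
        "tool_name": "grounding_dino_",
        "parameters": {"prompt": "dog", "image": "dogs.png"},
    }
    # tabulate takes a list of rows; headers="keys" turns the dict keys into columns
    print(tabulate([tool_results], headers="keys", tablefmt="mixed_grid"))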
@@ -298,8 +303,6 @@ def create_tasks(
          {k: v["description"] for k, v in tools.items()},
          reflections,
      )
-
-     _LOGGER.info(f"Tasks: {tasks}")
      if tasks is not None:
          task_list = [{"task": task, "id": i + 1} for i, task in enumerate(tasks)]
          task_list = task_topology(task_model, question, task_list)
@@ -309,6 +312,10 @@ def create_tasks(
              _LOGGER.error(f"Failed topological_sort on: {task_list}")
      else:
          task_list = []
+     _LOGGER.info(
+         f"""Planned tasks:
+ {tabulate(task_list, headers="keys", tablefmt="mixed_grid")}"""
+     )
      return task_list


@@ -326,7 +333,11 @@ def self_reflect(
          tool_results=str(tool_result),
          final_answer=final_answer,
      )
-     if issubclass(type(reflect_model), LMM):
+     if (
+         issubclass(type(reflect_model), LMM)
+         and image is not None
+         and Path(image).suffix in [".jpg", ".jpeg", ".png"]
+     ):
          return reflect_model(prompt, image=image)  # type: ignore
      return reflect_model(prompt)

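The widened condition above only attaches the image to the reflection call when the model is multimodal and the path looks like a still image, so non-image inputs (for example a video file) fall back to text-only reflection. A minimal sketch of the guard, with made-up paths:

    from pathlib import Path
    from typing import Optional

    def can_reflect_on_image(image: Optional[str]) -> bool:
        # mirrors the suffix check added to self_reflect above
        return image is not None and Path(image).suffix in [".jpg", ".jpeg", ".png"]

    for image in ["frame.png", "photo.jpeg", "clip.mp4", None]:
        print(image, can_reflect_on_image(image))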
@@ -336,6 +347,56 @@ def parse_reflect(reflect: str) -> bool:
      return "finish" in reflect.lower() and len(reflect) < 100


+ def visualize_result(all_tool_results: List[Dict]) -> List[str]:
+     image_to_data: Dict[str, Dict] = {}
+     for tool_result in all_tool_results:
+         if tool_result["tool_name"] not in ["grounding_sam_", "grounding_dino_"]:
+             continue
+
+         parameters = tool_result["parameters"]
+         # parameters can be either a dictionary or a list, and they can also be
+         # malformed because the LLM builds them
+         if isinstance(parameters, dict):
+             if "image" not in parameters:
+                 continue
+             parameters = [parameters]
+         elif isinstance(tool_result["parameters"], list):
+             if (
+                 len(tool_result["parameters"]) < 1
+                 or "image" not in tool_result["parameters"][0]
+             ):
+                 continue
+
+         for param, call_result in zip(parameters, tool_result["call_results"]):
+
+             # calls can fail, so we need to check if the call was successful
+             if not isinstance(call_result, dict):
+                 continue
+             if "bboxes" not in call_result:
+                 continue
+
+             # if the call was successful, then we can add the image data
+             image = param["image"]
+             if image not in image_to_data:
+                 image_to_data[image] = {"bboxes": [], "masks": [], "labels": []}
+
+             image_to_data[image]["bboxes"].extend(call_result["bboxes"])
+             image_to_data[image]["labels"].extend(call_result["labels"])
+             if "masks" in call_result:
+                 image_to_data[image]["masks"].extend(call_result["masks"])
+
+     visualized_images = []
+     for image in image_to_data:
+         image_path = Path(image)
+         image_data = image_to_data[image]
+         image = overlay_masks(image_path, image_data)
+         image = overlay_bboxes(image, image_data)
+         with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
+             image.save(f.name)
+             visualized_images.append(f.name)
+     return visualized_images
+
+
  class VisionAgent(Agent):
      r"""Vision Agent is an agent framework that utilizes tools as well as self
      reflection to accomplish tasks, in particular vision tasks. Vision Agent is based
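The new visualize_result only plots results from the two detection tools and expects normalized xyxy boxes. A hedged sketch of one input entry it can visualize; every value below is illustrative:

    all_tool_results = [
        {
            "task": "find the dogs",  # hypothetical
            "tool_name": "grounding_dino_",
            "parameters": {"prompt": "dog", "image": "dogs.png"},
            "call_results": [
                {
                    "labels": ["dog", "dog"],
                    "bboxes": [[0.1, 0.2, 0.4, 0.6], [0.5, 0.1, 0.9, 0.7]],
                    "scores": [0.98, 0.95],
                }
            ],
        }
    ]
    # visualize_result(all_tool_results) would overlay both boxes on dogs.png
    # and return a list containing one temporary .png path.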
@@ -380,7 +441,8 @@ class VisionAgent(Agent):
          """Invoke the vision agent.

          Parameters:
-             input: a prompt that describe the task or a conversation in the format of [{"role": "user", "content": "describe your task here..."}].
+             input: a prompt that describes the task or a conversation in the format of
+                 [{"role": "user", "content": "describe your task here..."}].
              image: the input image referenced in the prompt parameter.

          Returns:
@@ -427,9 +489,8 @@ class VisionAgent(Agent):
                  self.answer_model, task_str, call_results, previous_log, reflections
              )

-             for tool_result in tool_results:
-                 tool_result["answer"] = answer
-             all_tool_results.extend(tool_results)
+             tool_results["answer"] = answer
+             all_tool_results.append(tool_results)

              _LOGGER.info(f"\tAnswer: {answer}")
              answers.append({"task": task_str, "answer": answer})
@@ -439,13 +500,15 @@
              self.answer_model, question, answers, reflections
          )

+         visualized_images = visualize_result(all_tool_results)
+         all_tool_results.append({"visualized_images": visualized_images})
          reflection = self_reflect(
              self.reflect_model,
              question,
              self.tools,
              all_tool_results,
              final_answer,
-             image,
+             visualized_images[0] if len(visualized_images) > 0 else image,
          )
          _LOGGER.info(f"\tReflection: {reflection}")
          if parse_reflect(reflection):
--- vision_agent/image_utils.py (0.0.39)
+++ vision_agent/image_utils.py (0.0.41)
@@ -3,15 +3,38 @@
  import base64
  from io import BytesIO
  from pathlib import Path
- from typing import Tuple, Union
+ from typing import Dict, Tuple, Union

  import numpy as np
- from PIL import Image
+ from PIL import Image, ImageDraw, ImageFont
  from PIL.Image import Image as ImageType

+ COLORS = [
+     (158, 218, 229),
+     (219, 219, 141),
+     (23, 190, 207),
+     (188, 189, 34),
+     (199, 199, 199),
+     (247, 182, 210),
+     (127, 127, 127),
+     (227, 119, 194),
+     (196, 156, 148),
+     (197, 176, 213),
+     (140, 86, 75),
+     (148, 103, 189),
+     (255, 152, 150),
+     (152, 223, 138),
+     (214, 39, 40),
+     (44, 160, 44),
+     (255, 187, 120),
+     (174, 199, 232),
+     (255, 127, 14),
+     (31, 119, 180),
+ ]
+

  def b64_to_pil(b64_str: str) -> ImageType:
-     """Convert a base64 string to a PIL Image.
+     r"""Convert a base64 string to a PIL Image.

      Parameters:
          b64_str: the base64 encoded image
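The COLORS palette above is consumed by the overlay helpers further down, which assign colors by enumeration order and wrap around with a modulo. A small illustration, with the palette truncated to three entries:

    COLORS = [(158, 218, 229), (219, 219, 141), (23, 190, 207)]

    labels = ["dog", "cat", "dog", "bird"]
    color = {label: COLORS[i % len(COLORS)] for i, label in enumerate(labels)}
    # duplicate labels overwrite earlier entries, so "dog" ends up with COLORS[2]
    print(color)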
@@ -26,7 +49,7 @@ def b64_to_pil(b64_str: str) -> ImageType:


  def get_image_size(data: Union[str, Path, np.ndarray, ImageType]) -> Tuple[int, ...]:
-     """Get the size of an image.
+     r"""Get the size of an image.

      Parameters:
          data: the input image
@@ -41,7 +64,7 @@ def get_image_size(data: Union[str, Path, np.ndarray, ImageType]) -> Tuple[int,


  def convert_to_b64(data: Union[str, Path, np.ndarray, ImageType]) -> str:
-     """Convert an image to a base64 string.
+     r"""Convert an image to a base64 string.

      Parameters:
          data: the input image
@@ -60,3 +83,70 @@ def convert_to_b64(data: Union[str, Path, np.ndarray, ImageType]) -> str:
      else:
          arr_bytes = data.tobytes()
      return base64.b64encode(arr_bytes).decode("utf-8")
+
+
+ def overlay_bboxes(
+     image: Union[str, Path, np.ndarray, ImageType], bboxes: Dict
+ ) -> ImageType:
+     r"""Plots bounding boxes onto an image.
+
+     Parameters:
+         image: the input image
+         bboxes: the bounding boxes to overlay
+
+     Returns:
+         The image with the bounding boxes overlaid
+     """
+     if isinstance(image, (str, Path)):
+         image = Image.open(image)
+     elif isinstance(image, np.ndarray):
+         image = Image.fromarray(image)
+
+     color = {label: COLORS[i % len(COLORS)] for i, label in enumerate(bboxes["labels"])}
+
+     draw = ImageDraw.Draw(image)
+     font = ImageFont.load_default()
+     width, height = image.size
+     if "bboxes" not in bboxes:
+         return image.convert("RGB")
+
+     for label, box in zip(bboxes["labels"], bboxes["bboxes"]):
+         box = [box[0] * width, box[1] * height, box[2] * width, box[3] * height]
+         draw.rectangle(box, outline=color[label], width=3)
+         label = f"{label}"
+         text_box = draw.textbbox((box[0], box[1]), text=label, font=font)
+         draw.rectangle(text_box, fill=color[label])
+         draw.text((text_box[0], text_box[1]), label, fill="black", font=font)
+     return image.convert("RGB")
+
+
+ def overlay_masks(
+     image: Union[str, Path, np.ndarray, ImageType], masks: Dict, alpha: float = 0.5
+ ) -> ImageType:
+     r"""Plots masks onto an image.
+
+     Parameters:
+         image: the input image
+         masks: the masks to overlay
+         alpha: the transparency of the overlay
+
+     Returns:
+         The image with the masks overlaid
+     """
+     if isinstance(image, (str, Path)):
+         image = Image.open(image)
+     elif isinstance(image, np.ndarray):
+         image = Image.fromarray(image)
+
+     color = {label: COLORS[i % len(COLORS)] for i, label in enumerate(masks["labels"])}
+     if "masks" not in masks:
+         return image.convert("RGB")
+
+     for label, mask in zip(masks["labels"], masks["masks"]):
+         if isinstance(mask, str):
+             mask = np.array(Image.open(mask))
+         np_mask = np.zeros((image.size[1], image.size[0], 4))
+         np_mask[mask > 0, :] = color[label] + (255 * alpha,)
+         mask_img = Image.fromarray(np_mask.astype(np.uint8))
+         image = Image.alpha_composite(image.convert("RGBA"), mask_img)
+     return image.convert("RGB")
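A hedged usage sketch of the two new helpers, chained the same way visualize_result chains them; the image path and detections are made up, and the bboxes are normalized xyxy values as the detection tools return them:

    from pathlib import Path

    from vision_agent.image_utils import overlay_bboxes, overlay_masks

    detections = {
        "labels": ["dog"],
        "bboxes": [[0.1, 0.2, 0.4, 0.6]],
        "masks": [],  # with no masks, overlay_masks just converts to RGB
    }
    image = overlay_masks(Path("dogs.png"), detections)
    image = overlay_bboxes(image, detections)
    image.save("dogs_annotated.png")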
--- vision_agent/tools/__init__.py (0.0.39)
+++ vision_agent/tools/__init__.py (0.0.41)
@@ -1,2 +1,15 @@
  from .prompts import CHOOSE_PARAMS, SYSTEM_PROMPT
- from .tools import CLIP, TOOLS, Counter, Crop, GroundingDINO, GroundingSAM, Tool
+ from .tools import (
+     CLIP,
+     TOOLS,
+     BboxArea,
+     BboxIoU,
+     Counter,
+     Crop,
+     ExtractFrames,
+     GroundingDINO,
+     GroundingSAM,
+     SegArea,
+     SegIoU,
+     Tool,
+ )
--- vision_agent/tools/tools.py (0.0.39)
+++ vision_agent/tools/tools.py (0.0.41)
@@ -92,7 +92,7 @@ class CLIP(Tool):
      }

      # TODO: Add support for input multiple images, which aligns with the output type.
-     def __call__(self, prompt: List[str], image: Union[str, ImageType]) -> List[Dict]:
+     def __call__(self, prompt: List[str], image: Union[str, ImageType]) -> Dict:
          """Invoke the CLIP model.

          Parameters:
@@ -122,7 +122,7 @@ class CLIP(Tool):
          rets = []
          for elt in resp_json["data"]:
              rets.append({"labels": prompt, "scores": [round(prob, 2) for prob in elt]})
-         return cast(List[Dict], rets)
+         return cast(Dict, rets[0])


  class GroundingDINO(Tool):
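The return-type change above ripples to callers: in 0.0.39 CLIP returned a one-element list, in 0.0.41 it returns the dict directly. A hedged before/after sketch, assuming the hosted tool endpoint is reachable and dogs.png exists locally:

    from vision_agent.tools import CLIP

    clip = CLIP()
    result = clip(prompt=["dog", "cat"], image="dogs.png")
    # 0.0.39 shape: [{"labels": [...], "scores": [...]}] -> result[0]["scores"]
    # 0.0.41 shape: {"labels": [...], "scores": [...]}   -> result["scores"]
    print(result["scores"])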
@@ -168,7 +168,7 @@ class GroundingDINO(Tool):
      }

      # TODO: Add support for input multiple images, which aligns with the output type.
-     def __call__(self, prompt: str, image: Union[str, Path, ImageType]) -> List[Dict]:
+     def __call__(self, prompt: str, image: Union[str, Path, ImageType]) -> Dict:
          """Invoke the Grounding DINO model.

          Parameters:
@@ -204,7 +204,7 @@ class GroundingDINO(Tool):
              if "scores" in elt:
                  elt["scores"] = [round(score, 2) for score in elt["scores"]]
              elt["size"] = (image_size[1], image_size[0])
-         return cast(List[Dict], resp_data)
+         return cast(Dict, resp_data)


  class GroundingSAM(Tool):
@@ -259,7 +259,7 @@ class GroundingSAM(Tool):
      }

      # TODO: Add support for input multiple images, which aligns with the output type.
-     def __call__(self, prompt: List[str], image: Union[str, ImageType]) -> List[Dict]:
+     def __call__(self, prompt: List[str], image: Union[str, ImageType]) -> Dict:
          """Invoke the Grounding SAM model.

          Parameters:
@@ -294,7 +294,7 @@ class GroundingSAM(Tool):
              ret_pred["labels"].append(pred["label_name"])
              ret_pred["bboxes"].append(normalize_bbox(pred["bbox"], image_size))
              ret_pred["masks"].append(mask)
-         return [ret_pred]
+         return ret_pred


  class AgentGroundingSAM(GroundingSAM):
@@ -302,15 +302,14 @@ class AgentGroundingSAM(GroundingSAM):
      returns the file name. This makes it easier for agents to use.
      """

-     def __call__(self, prompt: List[str], image: Union[str, ImageType]) -> List[Dict]:
+     def __call__(self, prompt: List[str], image: Union[str, ImageType]) -> Dict:
          rets = super().__call__(prompt, image)
-         for ret in rets:
-             mask_files = []
-             for mask in ret["masks"]:
-                 with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
-                     Image.fromarray(mask * 255).save(tmp)
-                 mask_files.append(tmp.name)
-             ret["masks"] = mask_files
+         mask_files = []
+         for mask in rets["masks"]:
+             with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
+                 Image.fromarray(mask * 255).save(tmp)
+             mask_files.append(tmp.name)
+         rets["masks"] = mask_files
          return rets


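The mask files above switch from .jpg to .png. A plausible reason is that masks are binary arrays, and JPEG's lossy compression smears hard 0/255 boundaries while PNG round-trips them exactly; a self-contained illustration:

    import tempfile

    import numpy as np
    from PIL import Image

    mask = np.zeros((64, 64), dtype=np.uint8)
    mask[16:48, 16:48] = 1
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        Image.fromarray(mask * 255).save(tmp)
    reloaded = np.array(Image.open(tmp.name))
    assert set(np.unique(reloaded)) <= {0, 255}  # exact values survive PNG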
@@ -363,7 +362,7 @@ class Crop(Tool):
          ],
      }

-     def __call__(self, bbox: List[float], image: Union[str, Path]) -> str:
+     def __call__(self, bbox: List[float], image: Union[str, Path]) -> Dict:
          pil_image = Image.open(image)
          width, height = pil_image.size
          bbox = [
@@ -373,10 +372,10 @@ class Crop(Tool):
              int(bbox[3] * height),
          ]
          cropped_image = pil_image.crop(bbox)  # type: ignore
-         with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
+         with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
              cropped_image.save(tmp.name)

-         return tmp.name
+         return {"image": tmp.name}


  class BboxArea(Tool):
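Crop now returns {"image": path} instead of a bare path, matching the "image" parameter other tools expect. The denormalization it performs is plain scaling; a worked example with illustrative sizes:

    width, height = 640, 480
    bbox = [0.2, 0.21, 0.34, 0.42]  # normalized [x1, y1, x2, y2]
    pixel_bbox = [
        int(bbox[0] * width),   # 128
        int(bbox[1] * height),  # 100
        int(bbox[2] * width),   # 217
        int(bbox[3] * height),  # 201
    ]
    print(pixel_bbox)  # [128, 100, 217, 201]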
@@ -388,7 +387,7 @@ class BboxArea(Tool):
          "required_parameters": [{"name": "bbox", "type": "List[int]"}],
          "examples": [
              {
-                 "scenario": "If you want to calculate the area of the bounding box [0, 0, 100, 100]",
+                 "scenario": "If you want to calculate the area of the bounding box [0.2, 0.21, 0.34, 0.42]",
                  "parameters": {"bboxes": [0.2, 0.21, 0.34, 0.42]},
              }
          ],
@@ -430,6 +429,109 @@ class SegArea(Tool):
          return cast(float, round(np.sum(np_mask) / 255, 2))


+ class BboxIoU(Tool):
+     name = "bbox_iou_"
+     description = (
+         "'bbox_iou_' returns the intersection over union of two bounding boxes."
+     )
+     usage = {
+         "required_parameters": [
+             {"name": "bbox1", "type": "List[int]"},
+             {"name": "bbox2", "type": "List[int]"},
+         ],
+         "examples": [
+             {
+                 "scenario": "If you want to calculate the intersection over union of the bounding boxes [0.2, 0.21, 0.34, 0.42] and [0.3, 0.31, 0.44, 0.52]",
+                 "parameters": {
+                     "bbox1": [0.2, 0.21, 0.34, 0.42],
+                     "bbox2": [0.3, 0.31, 0.44, 0.52],
+                 },
+             }
+         ],
+     }
+
+     def __call__(self, bbox1: List[int], bbox2: List[int]) -> float:
+         x1, y1, x2, y2 = bbox1
+         x3, y3, x4, y4 = bbox2
+         xA = max(x1, x3)
+         yA = max(y1, y3)
+         xB = min(x2, x4)
+         yB = min(y2, y4)
+         inter_area = max(0, xB - xA) * max(0, yB - yA)
+         boxa_area = (x2 - x1) * (y2 - y1)
+         boxb_area = (x4 - x3) * (y4 - y3)
+         iou = inter_area / float(boxa_area + boxb_area - inter_area)
+         return round(iou, 2)
+
+
+ class SegIoU(Tool):
+     name = "seg_iou_"
+     description = "'seg_iou_' returns the intersection over union of two segmentation masks given their segmentation mask files."
+     usage = {
+         "required_parameters": [
+             {"name": "mask1", "type": "str"},
+             {"name": "mask2", "type": "str"},
+         ],
+         "examples": [
+             {
+                 "scenario": "If you want to calculate the intersection over union of the segmentation masks for mask_file1.png and mask_file2.png",
+                 "parameters": {"mask1": "mask_file1.png", "mask2": "mask_file2.png"},
+             }
+         ],
+     }
+
+     def __call__(self, mask1: Union[str, Path], mask2: Union[str, Path]) -> float:
+         pil_mask1 = Image.open(str(mask1))
+         pil_mask2 = Image.open(str(mask2))
+         np_mask1 = np.clip(np.array(pil_mask1), 0, 1)
+         np_mask2 = np.clip(np.array(pil_mask2), 0, 1)
+         intersection = np.logical_and(np_mask1, np_mask2)
+         union = np.logical_or(np_mask1, np_mask2)
+         iou = np.sum(intersection) / np.sum(union)
+         return cast(float, round(iou, 2))
+
+
+ class ExtractFrames(Tool):
+     r"""Extract frames from a video."""
+
+     name = "extract_frames_"
+     description = "'extract_frames_' extracts frames where there is motion detected in a video and returns a list of tuples (frame, timestamp), where timestamp is the relative time in seconds where the frame was captured. The frame is a local image file path."
+     usage = {
+         "required_parameters": [{"name": "video_uri", "type": "str"}],
+         "examples": [
+             {
+                 "scenario": "Can you extract the frames from this video? Video: www.foobar.com/video?name=test.mp4",
+                 "parameters": {"video_uri": "www.foobar.com/video?name=test.mp4"},
+             },
+             {
+                 "scenario": "Can you extract the images from this video file? Video path: tests/data/test.mp4",
+                 "parameters": {"video_uri": "tests/data/test.mp4"},
+             },
+         ],
+     }
+
+     def __call__(self, video_uri: str) -> List[Tuple[str, float]]:
+         """Extract frames from a video.
+
+
+         Parameters:
+             video_uri: the path to the video file or a URL pointing to the video data
+
+         Returns:
+             a list of tuples containing the extracted frame and the timestamp in seconds. E.g. [(path_to_frame1, 0.0), (path_to_frame2, 0.5), ...]. The timestamp is the time in seconds from the start of the video. E.g. 12.125 means 12.125 seconds from the start of the video. The frames are sorted by the timestamp in ascending order.
+         """
+         frames = extract_frames_from_video(video_uri)
+         result = []
+         _LOGGER.info(
+             f"Extracted {len(frames)} frames from video {video_uri}. Temporarily saving them as images to disk for downstream tasks."
+         )
+         for frame, ts in frames:
+             with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
+                 Image.fromarray(frame).save(tmp)
+             result.append((tmp.name, ts))
+         return result
+
+
  class Add(Tool):
      r"""Add returns the sum of all the arguments passed to it, normalized to 2 decimal places."""

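A worked example of the arithmetic in BboxIoU.__call__, using the boxes from its usage example; it behaves the same for normalized or pixel coordinates:

    bbox1 = [0.2, 0.21, 0.34, 0.42]
    bbox2 = [0.3, 0.31, 0.44, 0.52]
    xA, yA = max(bbox1[0], bbox2[0]), max(bbox1[1], bbox2[1])  # 0.30, 0.31
    xB, yB = min(bbox1[2], bbox2[2]), min(bbox1[3], bbox2[3])  # 0.34, 0.42
    inter = max(0, xB - xA) * max(0, yB - yA)                  # 0.04 * 0.11 = 0.0044
    area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])      # 0.0294
    area2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])      # 0.0294
    iou = inter / (area1 + area2 - inter)                      # 0.0044 / 0.0544
    print(round(iou, 2))                                       # 0.08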
@@ -506,47 +608,6 @@ class Divide(Tool):
          return round(input[0] / input[1], 2)


- class ExtractFrames(Tool):
-     r"""Extract frames from a video."""
-
-     name = "extract_frames_"
-     description = "'extract_frames_' extract image frames from the input video, return a list of tuple (frame, timestamp), where the timestamp is the relative time in seconds of the frame occurred in the video, the frame is a local image file path that stores the frame."
-     usage = {
-         "required_parameters": [{"name": "video_uri", "type": "str"}],
-         "examples": [
-             {
-                 "scenario": "Can you extract the frames from this video? Video: www.foobar.com/video?name=test.mp4",
-                 "parameters": {"video_uri": "www.foobar.com/video?name=test.mp4"},
-             },
-             {
-                 "scenario": "Can you extract the images from this video file? Video path: tests/data/test.mp4",
-                 "parameters": {"video_uri": "tests/data/test.mp4"},
-             },
-         ],
-     }
-
-     def __call__(self, video_uri: str) -> list[tuple[str, float]]:
-         """Extract frames from a video.
-
-
-         Parameters:
-             video_uri: the path to the video file or a url points to the video data
-
-         Returns:
-             a list of tuples containing the extracted frame and the timestamp in seconds. E.g. [(path_to_frame1, 0.0), (path_to_frame2, 0.5), ...]. The timestamp is the time in seconds from the start of the video. E.g. 12.125 means 12.125 seconds from the start of the video. The frames are sorted by the timestamp in ascending order.
-         """
-         frames = extract_frames_from_video(video_uri)
-         result = []
-         _LOGGER.info(
-             f"Extracted {len(frames)} frames from video {video_uri}. Temporarily saving them as images to disk for downstream tasks."
-         )
-         for frame, ts in frames:
-             with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
-                 Image.fromarray(frame).save(tmp)
-             result.append((tmp.name, ts))
-         return result
-
-
  TOOLS = {
      i: {"name": c.name, "description": c.description, "usage": c.usage, "class": c}
      for i, c in enumerate(
@@ -554,15 +615,17 @@ TOOLS = {
              CLIP,
              GroundingDINO,
              AgentGroundingSAM,
+             ExtractFrames,
              Counter,
              Crop,
              BboxArea,
              SegArea,
+             BboxIoU,
+             SegIoU,
              Add,
              Subtract,
              Multiply,
              Divide,
-             ExtractFrames,
          ]
      )
      if (hasattr(c, "name") and hasattr(c, "description") and hasattr(c, "usage"))
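One consequence of the reordering above: tool IDs are assigned by list position, so moving ExtractFrames from the end to right after AgentGroundingSAM renumbers every tool behind it. A minimal sketch of the registry pattern, with stand-in classes:

    class GroundingDINO:  # stand-in, not the real tool
        name, description, usage = "grounding_dino_", "detects objects", {}

    class ExtractFrames:  # stand-in, not the real tool
        name, description, usage = "extract_frames_", "extracts frames", {}

    TOOLS = {
        i: {"name": c.name, "description": c.description, "usage": c.usage, "class": c}
        for i, c in enumerate([GroundingDINO, ExtractFrames])
        if (hasattr(c, "name") and hasattr(c, "description") and hasattr(c, "usage"))
    }
    print({i: t["name"] for i, t in TOOLS.items()})  # {0: 'grounding_dino_', 1: 'extract_frames_'}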
--- vision_agent/tools/video.py (0.0.39)
+++ vision_agent/tools/video.py (0.0.41)
@@ -22,12 +22,16 @@ def extract_frames_from_video(
      Parameters:
          video_uri: the path to the video file or a video file url
          fps: the frame rate per second to extract the frames
-         motion_detection_threshold: The threshold to detect motion between changes/frames.
-             A value between 0-1, which represents the percentage change required for the frames to be considered in motion.
-             For example, a lower value means more frames will be extracted.
+         motion_detection_threshold: The threshold to detect motion between
+             changes/frames. A value between 0-1, which represents the percentage change
+             required for the frames to be considered in motion. For example, a lower
+             value means more frames will be extracted.

      Returns:
-         a list of tuples containing the extracted frame and the timestamp in seconds. E.g. [(frame1, 0.0), (frame2, 0.5), ...]. The timestamp is the time in seconds from the start of the video. E.g. 12.125 means 12.125 seconds from the start of the video. The frames are sorted by the timestamp in ascending order.
+         a list of tuples containing the extracted frame and the timestamp in seconds.
+         E.g. [(frame1, 0.0), (frame2, 0.5), ...]. The timestamp is the time in seconds
+         from the start of the video. E.g. 12.125 means 12.125 seconds from the start of
+         the video. The frames are sorted by the timestamp in ascending order.
      """
      with VideoFileClip(video_uri) as video:
          video_duration: float = video.duration
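A hedged usage sketch of the function documented above; the keyword name comes from the docstring, while the threshold value and the test path are assumptions:

    from vision_agent.tools.video import extract_frames_from_video

    frames = extract_frames_from_video(
        "tests/data/test.mp4",
        motion_detection_threshold=0.1,  # lower threshold -> more frames kept
    )
    for frame, timestamp in frames[:3]:
        # here frames are still numpy arrays; the ExtractFrames tool is what
        # writes them to temporary .png files
        print(frame.shape, timestamp)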
--- vision_agent-0.0.39.dist-info/METADATA
+++ vision_agent-0.0.41.dist-info/METADATA
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: vision-agent
- Version: 0.0.39
+ Version: 0.0.41
  Summary: Toolset for Vision Agent
  Author: Landing AI
  Author-email: dev@landing.ai
@@ -18,6 +18,7 @@ Requires-Dist: pandas (>=2.0.0,<3.0.0)
  Requires-Dist: pillow (>=10.0.0,<11.0.0)
  Requires-Dist: requests (>=2.0.0,<3.0.0)
  Requires-Dist: sentence-transformers (>=2.0.0,<3.0.0)
+ Requires-Dist: tabulate (>=0.9.0,<0.10.0)
  Requires-Dist: torch (>=2.1.0,<2.2.0)
  Requires-Dist: tqdm (>=4.64.0,<5.0.0)
  Requires-Dist: typing_extensions (>=4.0.0,<5.0.0)
--- vision_agent-0.0.39.dist-info/RECORD
+++ vision_agent-0.0.41.dist-info/RECORD
@@ -5,22 +5,22 @@ vision_agent/agent/easytool.py,sha256=oMHnBg7YBtIPgqQUNcZgq7uMgpPThs99_UnO7ERkMV
  vision_agent/agent/easytool_prompts.py,sha256=uNp12LOFRLr3i2zLhNuLuyFms2-s8es2t6P6h76QDow,4493
  vision_agent/agent/reflexion.py,sha256=wzpptfALNZIh9Q5jgkK3imGL5LWjTW_n_Ypsvxdh07Q,10101
  vision_agent/agent/reflexion_prompts.py,sha256=UPGkt_qgHBMUY0VPVoF-BqhR0d_6WPjjrhbYLBYOtnQ,9342
- vision_agent/agent/vision_agent.py,sha256=JPoY92M5xNaViLdNf4d1oqAX00QUuQxk-gcc9jIlfqA,14981
+ vision_agent/agent/vision_agent.py,sha256=_K6yWJiU1j0EGe8cabB40K0HxUkdzF-_c8G2k5eQL8s,17469
  vision_agent/agent/vision_agent_prompts.py,sha256=otaDRsaHc7bqw_tgWTnu-eUcFeOzBFrn9sPU7_xr2VQ,6151
  vision_agent/data/__init__.py,sha256=YU-5g3LbEQ6a4drz0RLGTagXMVU2Z4Xr3RlfWE-R0jU,46
  vision_agent/data/data.py,sha256=pgtSGZdAnbQ8oGsuapLtFTMPajnCGDGekEXTnFuBwsY,5122
  vision_agent/emb/__init__.py,sha256=YmCkGrJBtXb6X6Z3lnKiFoQYKXMgHMJp8JJyMLVvqcI,75
  vision_agent/emb/emb.py,sha256=la9lhEzk7jqUCjYYQ5oRgVNSnC9_EJBJIpE_B9c6PJo,1375
- vision_agent/image_utils.py,sha256=D5H-GN35Bz3u1Fq_JfYQVjNzAmZjJl138wma5fRtVjA,1684
+ vision_agent/image_utils.py,sha256=XiOLpHAvlk55URw6iG7hl1OY71FVRA9_25b650amZXA,4420
  vision_agent/llm/__init__.py,sha256=fBKsIjL4z08eA0QYx6wvhRe4Nkp2pJ4VrZK0-uUL5Ec,32
  vision_agent/llm/llm.py,sha256=d8A7jmLVGx5HzoiYJ75mTMU7dbD5-bOYeXYlHaay6WA,3957
  vision_agent/lmm/__init__.py,sha256=I8mbeNUajTfWVNqLsuFQVOaNBDlkIhYp9DFU8H4kB7g,51
  vision_agent/lmm/lmm.py,sha256=ARcbgkcyP83TbVVoXI9B-gtG0gJuTaG_MjcUGbams4U,8052
- vision_agent/tools/__init__.py,sha256=aX0pU3pXU1V0Cj9FzYCvdsX76TAglFMHx59kNhXHbPs,131
+ vision_agent/tools/__init__.py,sha256=AKN-T659HpwVearRnkCd6wWNoJ6K5kW9gAZwb8IQSLE,235
  vision_agent/tools/prompts.py,sha256=9RBbyqlNlExsGKlJ89Jkph83DAEJ8PCVGaHoNbyN7TM,1416
- vision_agent/tools/tools.py,sha256=2mmomPDbldXRpw3q5zAcazKJMjAGd0Jl9ak9JykHQYI,21211
- vision_agent/tools/video.py,sha256=KV_Wcat7DDGxpHSaGBu7s4lj4crlYaUu4YKpCO_86k4,7440
- vision_agent-0.0.39.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
- vision_agent-0.0.39.dist-info/METADATA,sha256=_jugEQnOeNbLa3kSSo0zTn2bII3Rh5dfop9qyMWXPfw,5282
- vision_agent-0.0.39.dist-info/WHEEL,sha256=7Z8_27uaHI_UZAc4Uox4PpBhQ9Y5_modZXWMxtUi4NU,88
- vision_agent-0.0.39.dist-info/RECORD,,
+ vision_agent/tools/tools.py,sha256=aMTBxxaXQp33HwplOS8xrgfbsTJ8e1pwO6byR7HcTJI,23447
+ vision_agent/tools/video.py,sha256=40rscP8YvKN3lhZ4PDcOK4XbdFX2duCRpHY_krmBYKU,7476
+ vision_agent-0.0.41.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+ vision_agent-0.0.41.dist-info/METADATA,sha256=45hGAgKvEd7WjzrmbFVluki2t0O64UomaHtIrwLCknw,5324
+ vision_agent-0.0.41.dist-info/WHEEL,sha256=7Z8_27uaHI_UZAc4Uox4PpBhQ9Y5_modZXWMxtUi4NU,88
+ vision_agent-0.0.41.dist-info/RECORD,,