vision-agent 0.0.40__py3-none-any.whl → 0.0.42__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
vision_agent/agent/vision_agent.py CHANGED
@@ -1,11 +1,13 @@
1
1
  import json
2
2
  import logging
3
3
  import sys
4
+ import tempfile
4
5
  from pathlib import Path
5
6
  from typing import Any, Callable, Dict, List, Optional, Tuple, Union
6
7
 
7
8
  from tabulate import tabulate
8
9
 
10
+ from vision_agent.image_utils import overlay_bboxes, overlay_masks
9
11
  from vision_agent.llm import LLM, OpenAILLM
10
12
  from vision_agent.lmm import LMM, OpenAILMM
11
13
  from vision_agent.tools import TOOLS
@@ -248,13 +250,12 @@ def retrieval(
248
250
  tools: Dict[int, Any],
249
251
  previous_log: str,
250
252
  reflections: str,
251
- ) -> Tuple[List[Dict], str]:
253
+ ) -> Tuple[Dict, str]:
252
254
  tool_id = choose_tool(
253
255
  model, question, {k: v["description"] for k, v in tools.items()}, reflections
254
256
  )
255
257
  if tool_id is None:
256
- return [{}], ""
257
- _LOGGER.info(f"\t(Tool ID, name): ({tool_id}, {tools[tool_id]['name']})")
258
+ return {}, ""
258
259
 
259
260
  tool_instructions = tools[tool_id]
260
261
  tool_usage = tool_instructions["usage"]
@@ -263,16 +264,13 @@ def retrieval(
263
264
  parameters = choose_parameter(
264
265
  model, question, tool_usage, previous_log, reflections
265
266
  )
266
- _LOGGER.info(f"\tParameters: {parameters} for {tool_name}")
267
267
  if parameters is None:
268
- return [{}], ""
269
- tool_results = [
270
- {"task": question, "tool_name": tool_name, "parameters": parameters}
271
- ]
268
+ return {}, ""
269
+ tool_results = {"task": question, "tool_name": tool_name, "parameters": parameters}
272
270
 
273
271
  _LOGGER.info(
274
- f"""Going to run the following {len(tool_results)} tool(s) in sequence:
275
- {tabulate(tool_results, headers="keys", tablefmt="mixed_grid")}"""
272
+ f"""Going to run the following tool(s) in sequence:
273
+ {tabulate([tool_results], headers="keys", tablefmt="mixed_grid")}"""
276
274
  )
277
275
 
278
276
  def parse_tool_results(result: Dict[str, Union[Dict, List]]) -> Any:
@@ -286,13 +284,11 @@ def retrieval(
286
284
  call_results.append(function_call(tools[tool_id]["class"], parameters))
287
285
  return call_results
288
286
 
289
- call_results = []
290
- for i, result in enumerate(tool_results):
291
- call_results.extend(parse_tool_results(result))
292
- tool_results[i]["call_results"] = call_results
287
+ call_results = parse_tool_results(tool_results)
288
+ tool_results["call_results"] = call_results
293
289
 
294
- call_results_str = "\n\n".join([str(e) for e in call_results if e is not None])
295
- _LOGGER.info(f"\tCall Results: {call_results_str}")
290
+ call_results_str = str(call_results)
291
+ # _LOGGER.info(f"\tCall Results: {call_results_str}")
296
292
  return tool_results, call_results_str
297
293
 
298
294
 
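
Note on the hunk above: retrieval now returns a single per-task record instead of a one-element list. A minimal sketch of the returned shape, with a hypothetical sub-task and illustrative detection values:

    tool_results = {
        "task": "find the dog in the image",                    # hypothetical sub-task
        "tool_name": "grounding_dino_",
        "parameters": {"prompt": "dog", "image": "dog.png"},    # hypothetical image path
        "call_results": [
            {"labels": ["dog"], "bboxes": [[0.1, 0.2, 0.45, 0.9]], "scores": [0.92]}
        ],
    }
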
@@ -335,14 +331,70 @@ def self_reflect(
335
331
  tool_results=str(tool_result),
336
332
  final_answer=final_answer,
337
333
  )
338
- if issubclass(type(reflect_model), LMM):
334
+ if (
335
+ issubclass(type(reflect_model), LMM)
336
+ and image is not None
337
+ and Path(image).suffix in [".jpg", ".jpeg", ".png"]
338
+ ):
339
339
  return reflect_model(prompt, image=image) # type: ignore
340
340
  return reflect_model(prompt)
341
341
 
342
342
 
343
343
  def parse_reflect(reflect: str) -> bool:
344
344
  # GPT-4V has a hard time following directions, so make the criteria less strict
345
- return "finish" in reflect.lower() and len(reflect) < 100
345
+ return (
346
+ "finish" in reflect.lower() and len(reflect) < 100
347
+ ) or "finish" in reflect.lower()[-10:]
348
+
349
+
350
+ def visualize_result(all_tool_results: List[Dict]) -> List[str]:
351
+ image_to_data: Dict[str, Dict] = {}
352
+ for tool_result in all_tool_results:
353
+ if not tool_result["tool_name"] in ["grounding_sam_", "grounding_dino_"]:
354
+ continue
355
+
356
+ parameters = tool_result["parameters"]
357
+ # parameters can either be a dictionary or list, parameters can also be malformed
358
+ # becaus the LLM builds them
359
+ if isinstance(parameters, dict):
360
+ if "image" not in parameters:
361
+ continue
362
+ parameters = [parameters]
363
+ elif isinstance(tool_result["parameters"], list):
364
+ if (
365
+ len(tool_result["parameters"]) < 1
366
+ and "image" not in tool_result["parameters"][0]
367
+ ):
368
+ continue
369
+
370
+ for param, call_result in zip(parameters, tool_result["call_results"]):
371
+
372
+ # calls can fail, so we need to check if the call was successful
373
+ if not isinstance(call_result, dict):
374
+ continue
375
+ if "bboxes" not in call_result:
376
+ continue
377
+
378
+ # if the call was successful, then we can add the image data
379
+ image = param["image"]
380
+ if image not in image_to_data:
381
+ image_to_data[image] = {"bboxes": [], "masks": [], "labels": []}
382
+
383
+ image_to_data[image]["bboxes"].extend(call_result["bboxes"])
384
+ image_to_data[image]["labels"].extend(call_result["labels"])
385
+ if "masks" in call_result:
386
+ image_to_data[image]["masks"].extend(call_result["masks"])
387
+
388
+ visualized_images = []
389
+ for image in image_to_data:
390
+ image_path = Path(image)
391
+ image_data = image_to_data[image]
392
+ image = overlay_masks(image_path, image_data)
393
+ image = overlay_bboxes(image, image_data)
394
+ with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
395
+ image.save(f.name)
396
+ visualized_images.append(f.name)
397
+ return visualized_images
346
398
 
347
399
 
348
400
  class VisionAgent(Agent):
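
A minimal sketch of driving the new visualize_result helper added above. The function is defined at module level in vision_agent/agent/vision_agent.py, so it is imported from that module path; the image file and detection values are illustrative:

    from vision_agent.agent.vision_agent import visualize_result

    all_tool_results = [
        {
            "task": "find the dog",
            "tool_name": "grounding_dino_",
            "parameters": {"prompt": "dog", "image": "dog.png"},  # hypothetical image
            "call_results": [
                {"labels": ["dog"], "bboxes": [[0.1, 0.2, 0.45, 0.9]], "scores": [0.92]}
            ],
        }
    ]
    overlay_paths = visualize_result(all_tool_results)  # temporary .png files with the boxes drawn
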
@@ -371,10 +423,16 @@ class VisionAgent(Agent):
371
423
  verbose: bool = False,
372
424
  ):
373
425
  self.task_model = (
374
- OpenAILLM(json_mode=True) if task_model is None else task_model
426
+ OpenAILLM(json_mode=True, temperature=0.1)
427
+ if task_model is None
428
+ else task_model
429
+ )
430
+ self.answer_model = (
431
+ OpenAILLM(temperature=0.1) if answer_model is None else answer_model
432
+ )
433
+ self.reflect_model = (
434
+ OpenAILMM(temperature=0.1) if reflect_model is None else reflect_model
375
435
  )
376
- self.answer_model = OpenAILLM() if answer_model is None else answer_model
377
- self.reflect_model = OpenAILMM() if reflect_model is None else reflect_model
378
436
  self.max_retries = max_retries
379
437
 
380
438
  self.tools = TOOLS
@@ -389,7 +447,8 @@ class VisionAgent(Agent):
389
447
  """Invoke the vision agent.
390
448
 
391
449
  Parameters:
392
- input: a prompt that describe the task or a conversation in the format of [{"role": "user", "content": "describe your task here..."}].
450
+ input: a prompt that describe the task or a conversation in the format of
451
+ [{"role": "user", "content": "describe your task here..."}].
393
452
  image: the input image referenced in the prompt parameter.
394
453
 
395
454
  Returns:
@@ -413,7 +472,6 @@ class VisionAgent(Agent):
413
472
  for _ in range(self.max_retries):
414
473
  task_list = create_tasks(self.task_model, question, self.tools, reflections)
415
474
 
416
- _LOGGER.info(f"Task Dependency: {task_list}")
417
475
  task_depend = {"Original Quesiton": question}
418
476
  previous_log = ""
419
477
  answers = []
@@ -424,7 +482,6 @@ class VisionAgent(Agent):
424
482
  for task in task_list:
425
483
  task_str = task["task"]
426
484
  previous_log = str(task_depend)
427
- _LOGGER.info(f"\tSubtask: {task_str}")
428
485
  tool_results, call_results = retrieval(
429
486
  self.task_model,
430
487
  task_str,
@@ -436,10 +493,10 @@ class VisionAgent(Agent):
436
493
  self.answer_model, task_str, call_results, previous_log, reflections
437
494
  )
438
495
 
439
- for tool_result in tool_results:
440
- tool_result["answer"] = answer
441
- all_tool_results.extend(tool_results)
496
+ tool_results["answer"] = answer
497
+ all_tool_results.append(tool_results)
442
498
 
499
+ _LOGGER.info(f"\tCall Result: {call_results}")
443
500
  _LOGGER.info(f"\tAnswer: {answer}")
444
501
  answers.append({"task": task_str, "answer": answer})
445
502
  task_depend[task["id"]]["answer"] = answer # type: ignore
@@ -448,15 +505,17 @@ class VisionAgent(Agent):
448
505
  self.answer_model, question, answers, reflections
449
506
  )
450
507
 
508
+ visualized_images = visualize_result(all_tool_results)
509
+ all_tool_results.append({"visualized_images": visualized_images})
451
510
  reflection = self_reflect(
452
511
  self.reflect_model,
453
512
  question,
454
513
  self.tools,
455
514
  all_tool_results,
456
515
  final_answer,
457
- image,
516
+ visualized_images[0] if len(visualized_images) > 0 else image,
458
517
  )
459
- _LOGGER.info(f"\tReflection: {reflection}")
518
+ _LOGGER.info(f"Reflection: {reflection}")
460
519
  if parse_reflect(reflection):
461
520
  break
462
521
  else:
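
For context, a sketch of constructing the agent with the new defaults shown above (the default task/answer/reflect models are now created with temperature=0.1). The image file is a hypothetical example, and running this requires an OpenAI API key plus access to the hosted tool endpoints:

    from vision_agent.agent.vision_agent import VisionAgent  # the package may also re-export this at a shorter path

    agent = VisionAgent(verbose=True)   # task/answer/reflect models default to temperature=0.1
    answer = agent("How many apples are in the image?", image="apples.jpg")  # hypothetical image
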
vision_agent/image_utils.py CHANGED
@@ -3,15 +3,38 @@
3
3
  import base64
4
4
  from io import BytesIO
5
5
  from pathlib import Path
6
- from typing import Tuple, Union
6
+ from typing import Dict, Tuple, Union
7
7
 
8
8
  import numpy as np
9
- from PIL import Image
9
+ from PIL import Image, ImageDraw, ImageFont
10
10
  from PIL.Image import Image as ImageType
11
11
 
12
+ COLORS = [
13
+ (158, 218, 229),
14
+ (219, 219, 141),
15
+ (23, 190, 207),
16
+ (188, 189, 34),
17
+ (199, 199, 199),
18
+ (247, 182, 210),
19
+ (127, 127, 127),
20
+ (227, 119, 194),
21
+ (196, 156, 148),
22
+ (197, 176, 213),
23
+ (140, 86, 75),
24
+ (148, 103, 189),
25
+ (255, 152, 150),
26
+ (152, 223, 138),
27
+ (214, 39, 40),
28
+ (44, 160, 44),
29
+ (255, 187, 120),
30
+ (174, 199, 232),
31
+ (255, 127, 14),
32
+ (31, 119, 180),
33
+ ]
34
+
12
35
 
13
36
  def b64_to_pil(b64_str: str) -> ImageType:
14
- """Convert a base64 string to a PIL Image.
37
+ r"""Convert a base64 string to a PIL Image.
15
38
 
16
39
  Parameters:
17
40
  b64_str: the base64 encoded image
@@ -26,7 +49,7 @@ def b64_to_pil(b64_str: str) -> ImageType:
26
49
 
27
50
 
28
51
  def get_image_size(data: Union[str, Path, np.ndarray, ImageType]) -> Tuple[int, ...]:
29
- """Get the size of an image.
52
+ r"""Get the size of an image.
30
53
 
31
54
  Parameters:
32
55
  data: the input image
@@ -41,7 +64,7 @@ def get_image_size(data: Union[str, Path, np.ndarray, ImageType]) -> Tuple[int,
41
64
 
42
65
 
43
66
  def convert_to_b64(data: Union[str, Path, np.ndarray, ImageType]) -> str:
44
- """Convert an image to a base64 string.
67
+ r"""Convert an image to a base64 string.
45
68
 
46
69
  Parameters:
47
70
  data: the input image
@@ -60,3 +83,70 @@ def convert_to_b64(data: Union[str, Path, np.ndarray, ImageType]) -> str:
60
83
  else:
61
84
  arr_bytes = data.tobytes()
62
85
  return base64.b64encode(arr_bytes).decode("utf-8")
86
+
87
+
88
+ def overlay_bboxes(
89
+ image: Union[str, Path, np.ndarray, ImageType], bboxes: Dict
90
+ ) -> ImageType:
91
+ r"""Plots bounding boxes on to an image.
92
+
93
+ Parameters:
94
+ image: the input image
95
+ bboxes: the bounding boxes to overlay
96
+
97
+ Returns:
98
+ The image with the bounding boxes overlayed
99
+ """
100
+ if isinstance(image, (str, Path)):
101
+ image = Image.open(image)
102
+ elif isinstance(image, np.ndarray):
103
+ image = Image.fromarray(image)
104
+
105
+ color = {label: COLORS[i % len(COLORS)] for i, label in enumerate(bboxes["labels"])}
106
+
107
+ draw = ImageDraw.Draw(image)
108
+ font = ImageFont.load_default()
109
+ width, height = image.size
110
+ if "bboxes" not in bboxes:
111
+ return image.convert("RGB")
112
+
113
+ for label, box in zip(bboxes["labels"], bboxes["bboxes"]):
114
+ box = [box[0] * width, box[1] * height, box[2] * width, box[3] * height]
115
+ draw.rectangle(box, outline=color[label], width=3)
116
+ label = f"{label}"
117
+ text_box = draw.textbbox((box[0], box[1]), text=label, font=font)
118
+ draw.rectangle(text_box, fill=color[label])
119
+ draw.text((text_box[0], text_box[1]), label, fill="black", font=font)
120
+ return image.convert("RGB")
121
+
122
+
123
+ def overlay_masks(
124
+ image: Union[str, Path, np.ndarray, ImageType], masks: Dict, alpha: float = 0.5
125
+ ) -> ImageType:
126
+ r"""Plots masks on to an image.
127
+
128
+ Parameters:
129
+ image: the input image
130
+ masks: the masks to overlay
131
+ alpha: the transparency of the overlay
132
+
133
+ Returns:
134
+ The image with the masks overlayed
135
+ """
136
+ if isinstance(image, (str, Path)):
137
+ image = Image.open(image)
138
+ elif isinstance(image, np.ndarray):
139
+ image = Image.fromarray(image)
140
+
141
+ color = {label: COLORS[i % len(COLORS)] for i, label in enumerate(masks["labels"])}
142
+ if "masks" not in masks:
143
+ return image.convert("RGB")
144
+
145
+ for label, mask in zip(masks["labels"], masks["masks"]):
146
+ if isinstance(mask, str):
147
+ mask = np.array(Image.open(mask))
148
+ np_mask = np.zeros((image.size[1], image.size[0], 4))
149
+ np_mask[mask > 0, :] = color[label] + (255 * alpha,)
150
+ mask_img = Image.fromarray(np_mask.astype(np.uint8))
151
+ image = Image.alpha_composite(image.convert("RGBA"), mask_img)
152
+ return image.convert("RGB")
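
A minimal sketch of the two new overlay helpers, assuming a local image file and illustrative, normalized xyxy boxes; the masks list may be left empty:

    from vision_agent.image_utils import overlay_bboxes, overlay_masks

    detections = {
        "labels": ["dog"],
        "bboxes": [[0.1, 0.2, 0.45, 0.9]],  # normalized xyxy, as the tools return them
        "masks": [],                         # mask file paths or arrays; empty is fine
    }
    img = overlay_masks("dog.png", detections)   # hypothetical image file
    img = overlay_bboxes(img, detections)
    img.save("dog_annotated.png")
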
vision_agent/llm/llm.py CHANGED
@@ -1,6 +1,6 @@
1
1
  import json
2
2
  from abc import ABC, abstractmethod
3
- from typing import Callable, Dict, List, Mapping, Union, cast
3
+ from typing import Any, Callable, Dict, List, Mapping, Union, cast
4
4
 
5
5
  from openai import OpenAI
6
6
 
@@ -31,30 +31,33 @@ class OpenAILLM(LLM):
31
31
  r"""An LLM class for any OpenAI LLM model."""
32
32
 
33
33
  def __init__(
34
- self, model_name: str = "gpt-4-turbo-preview", json_mode: bool = False
34
+ self,
35
+ model_name: str = "gpt-4-turbo-preview",
36
+ json_mode: bool = False,
37
+ **kwargs: Any
35
38
  ):
36
39
  self.model_name = model_name
37
40
  self.client = OpenAI()
38
- self.json_mode = json_mode
41
+ self.kwargs = kwargs
42
+ if json_mode:
43
+ self.kwargs["response_format"] = {"type": "json_object"}
39
44
 
40
45
  def generate(self, prompt: str) -> str:
41
- kwargs = {"response_format": {"type": "json_object"}} if self.json_mode else {}
42
46
  response = self.client.chat.completions.create(
43
47
  model=self.model_name,
44
48
  messages=[
45
49
  {"role": "user", "content": prompt},
46
50
  ],
47
- **kwargs, # type: ignore
51
+ **self.kwargs,
48
52
  )
49
53
 
50
54
  return cast(str, response.choices[0].message.content)
51
55
 
52
56
  def chat(self, chat: List[Dict[str, str]]) -> str:
53
- kwargs = {"response_format": {"type": "json_object"}} if self.json_mode else {}
54
57
  response = self.client.chat.completions.create(
55
58
  model=self.model_name,
56
59
  messages=chat, # type: ignore
57
- **kwargs,
60
+ **self.kwargs,
58
61
  )
59
62
 
60
63
  return cast(str, response.choices[0].message.content)
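
A short usage sketch of the reworked constructor: extra keyword arguments are stored and forwarded to chat.completions.create, and json_mode now simply adds response_format to them. Running it assumes OPENAI_API_KEY is set; the prompt is illustrative:

    from vision_agent.llm import OpenAILLM

    llm = OpenAILLM(json_mode=True, temperature=0.1)
    print(llm.generate('Respond with a JSON object of the form {"answer": <int>} for 6 * 7.'))
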
vision_agent/lmm/lmm.py CHANGED
@@ -97,11 +97,15 @@ class OpenAILMM(LMM):
97
97
  r"""An LMM class for the OpenAI GPT-4 Vision model."""
98
98
 
99
99
  def __init__(
100
- self, model_name: str = "gpt-4-vision-preview", max_tokens: int = 1024
100
+ self,
101
+ model_name: str = "gpt-4-vision-preview",
102
+ max_tokens: int = 1024,
103
+ **kwargs: Any,
101
104
  ):
102
105
  self.model_name = model_name
103
106
  self.max_tokens = max_tokens
104
107
  self.client = OpenAI()
108
+ self.kwargs = kwargs
105
109
 
106
110
  def __call__(
107
111
  self,
@@ -123,6 +127,13 @@ class OpenAILMM(LMM):
123
127
 
124
128
  if image:
125
129
  extension = Path(image).suffix
130
+ if extension.lower() == ".jpeg" or extension.lower() == ".jpg":
131
+ extension = "jpg"
132
+ elif extension.lower() == ".png":
133
+ extension = "png"
134
+ else:
135
+ raise ValueError(f"Unsupported image extension: {extension}")
136
+
126
137
  encoded_image = encode_image(image)
127
138
  fixed_chat[0]["content"].append( # type: ignore
128
139
  {
@@ -135,7 +146,7 @@ class OpenAILMM(LMM):
135
146
  )
136
147
 
137
148
  response = self.client.chat.completions.create(
138
- model=self.model_name, messages=fixed_chat, max_tokens=self.max_tokens # type: ignore
149
+ model=self.model_name, messages=fixed_chat, max_tokens=self.max_tokens, **self.kwargs # type: ignore
139
150
  )
140
151
 
141
152
  return cast(str, response.choices[0].message.content)
@@ -163,7 +174,7 @@ class OpenAILMM(LMM):
163
174
  )
164
175
 
165
176
  response = self.client.chat.completions.create(
166
- model=self.model_name, messages=message, max_tokens=self.max_tokens # type: ignore
177
+ model=self.model_name, messages=message, max_tokens=self.max_tokens, **self.kwargs # type: ignore
167
178
  )
168
179
  return cast(str, response.choices[0].message.content)
169
180
 
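
A sketch of the corresponding LMM changes: extra kwargs such as temperature ride along to the OpenAI call, and only .jpg/.jpeg/.png images are accepted before encoding. The file name is hypothetical, and OPENAI_API_KEY must be set for the client to initialize:

    from vision_agent.lmm import OpenAILMM

    lmm = OpenAILMM(temperature=0.1)
    try:
        lmm("Describe this image", image="scan.tiff")  # hypothetical non-jpg/png file
    except ValueError as err:
        print(err)  # Unsupported image extension: .tiff
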
vision_agent/tools/__init__.py CHANGED
@@ -1,2 +1,15 @@
1
1
  from .prompts import CHOOSE_PARAMS, SYSTEM_PROMPT
2
- from .tools import CLIP, TOOLS, Counter, Crop, GroundingDINO, GroundingSAM, Tool
2
+ from .tools import (
3
+ CLIP,
4
+ TOOLS,
5
+ BboxArea,
6
+ BboxIoU,
7
+ Counter,
8
+ Crop,
9
+ ExtractFrames,
10
+ GroundingDINO,
11
+ GroundingSAM,
12
+ SegArea,
13
+ SegIoU,
14
+ Tool,
15
+ )
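
With the expanded re-exports, the new tools can be imported directly from vision_agent.tools; a tiny sketch that prints their registry metadata (name and description are class attributes):

    from vision_agent.tools import BboxIoU, ExtractFrames, SegIoU

    for tool_cls in (BboxIoU, SegIoU, ExtractFrames):
        print(tool_cls.name, "->", tool_cls.description)
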
vision_agent/tools/tools.py CHANGED
@@ -92,7 +92,7 @@ class CLIP(Tool):
92
92
  }
93
93
 
94
94
  # TODO: Add support for input multiple images, which aligns with the output type.
95
- def __call__(self, prompt: List[str], image: Union[str, ImageType]) -> List[Dict]:
95
+ def __call__(self, prompt: List[str], image: Union[str, ImageType]) -> Dict:
96
96
  """Invoke the CLIP model.
97
97
 
98
98
  Parameters:
@@ -122,7 +122,7 @@ class CLIP(Tool):
122
122
  rets = []
123
123
  for elt in resp_json["data"]:
124
124
  rets.append({"labels": prompt, "scores": [round(prob, 2) for prob in elt]})
125
- return cast(List[Dict], rets)
125
+ return cast(Dict, rets[0])
126
126
 
127
127
 
128
128
  class GroundingDINO(Tool):
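
CLIP now returns a single dict rather than a one-element list. A sketch, assuming a local image file and access to the hosted inference endpoint; the scores in the comment are illustrative:

    from vision_agent.tools import CLIP

    result = CLIP()(prompt=["dog", "cat"], image="pet.png")  # hypothetical image
    print(result)  # e.g. {"labels": ["dog", "cat"], "scores": [0.87, 0.13]}
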
@@ -168,7 +168,7 @@ class GroundingDINO(Tool):
168
168
  }
169
169
 
170
170
  # TODO: Add support for input multiple images, which aligns with the output type.
171
- def __call__(self, prompt: str, image: Union[str, Path, ImageType]) -> List[Dict]:
171
+ def __call__(self, prompt: str, image: Union[str, Path, ImageType]) -> Dict:
172
172
  """Invoke the Grounding DINO model.
173
173
 
174
174
  Parameters:
@@ -204,7 +204,7 @@ class GroundingDINO(Tool):
204
204
  if "scores" in elt:
205
205
  elt["scores"] = [round(score, 2) for score in elt["scores"]]
206
206
  elt["size"] = (image_size[1], image_size[0])
207
- return cast(List[Dict], resp_data)
207
+ return cast(Dict, resp_data)
208
208
 
209
209
 
210
210
  class GroundingSAM(Tool):
@@ -259,7 +259,7 @@ class GroundingSAM(Tool):
259
259
  }
260
260
 
261
261
  # TODO: Add support for input multiple images, which aligns with the output type.
262
- def __call__(self, prompt: List[str], image: Union[str, ImageType]) -> List[Dict]:
262
+ def __call__(self, prompt: List[str], image: Union[str, ImageType]) -> Dict:
263
263
  """Invoke the Grounding SAM model.
264
264
 
265
265
  Parameters:
@@ -294,7 +294,7 @@ class GroundingSAM(Tool):
294
294
  ret_pred["labels"].append(pred["label_name"])
295
295
  ret_pred["bboxes"].append(normalize_bbox(pred["bbox"], image_size))
296
296
  ret_pred["masks"].append(mask)
297
- return [ret_pred]
297
+ return ret_pred
298
298
 
299
299
 
300
300
  class AgentGroundingSAM(GroundingSAM):
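
Likewise, GroundingSAM now returns one dict of parallel lists instead of wrapping it in a list. A sketch under the same assumptions (local image, hosted endpoint reachable):

    from vision_agent.tools import GroundingSAM

    pred = GroundingSAM()(prompt=["dog"], image="dog.png")  # hypothetical image
    print(pred["labels"], pred["bboxes"])                   # pred["masks"] holds the raw masks
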
@@ -302,15 +302,14 @@ class AgentGroundingSAM(GroundingSAM):
302
302
  returns the file name. This makes it easier for agents to use.
303
303
  """
304
304
 
305
- def __call__(self, prompt: List[str], image: Union[str, ImageType]) -> List[Dict]:
305
+ def __call__(self, prompt: List[str], image: Union[str, ImageType]) -> Dict:
306
306
  rets = super().__call__(prompt, image)
307
- for ret in rets:
308
- mask_files = []
309
- for mask in ret["masks"]:
310
- with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
311
- Image.fromarray(mask * 255).save(tmp)
312
- mask_files.append(tmp.name)
313
- ret["masks"] = mask_files
307
+ mask_files = []
308
+ for mask in rets["masks"]:
309
+ with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
310
+ Image.fromarray(mask * 255).save(tmp)
311
+ mask_files.append(tmp.name)
312
+ rets["masks"] = mask_files
314
313
  return rets
315
314
 
316
315
 
@@ -363,7 +362,7 @@ class Crop(Tool):
363
362
  ],
364
363
  }
365
364
 
366
- def __call__(self, bbox: List[float], image: Union[str, Path]) -> str:
365
+ def __call__(self, bbox: List[float], image: Union[str, Path]) -> Dict:
367
366
  pil_image = Image.open(image)
368
367
  width, height = pil_image.size
369
368
  bbox = [
@@ -373,10 +372,10 @@ class Crop(Tool):
373
372
  int(bbox[3] * height),
374
373
  ]
375
374
  cropped_image = pil_image.crop(bbox) # type: ignore
376
- with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
375
+ with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
377
376
  cropped_image.save(tmp.name)
378
377
 
379
- return tmp.name
378
+ return {"image": tmp.name}
380
379
 
381
380
 
382
381
  class BboxArea(Tool):
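
Crop now hands back a dict holding the path of the cropped temporary .png instead of a bare string, matching the {"image": ...} parameter shape the other tools consume. A sketch with a hypothetical input image and normalized box:

    from vision_agent.tools import Crop

    out = Crop()(bbox=[0.1, 0.1, 0.6, 0.6], image="photo.png")  # hypothetical image
    print(out["image"])  # path to the cropped temporary .png file
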
@@ -388,7 +387,7 @@ class BboxArea(Tool):
388
387
  "required_parameters": [{"name": "bbox", "type": "List[int]"}],
389
388
  "examples": [
390
389
  {
391
- "scenario": "If you want to calculate the area of the bounding box [0, 0, 100, 100]",
390
+ "scenario": "If you want to calculate the area of the bounding box [0.2, 0.21, 0.34, 0.42]",
392
391
  "parameters": {"bboxes": [0.2, 0.21, 0.34, 0.42]},
393
392
  }
394
393
  ],
@@ -430,6 +429,109 @@ class SegArea(Tool):
430
429
  return cast(float, round(np.sum(np_mask) / 255, 2))
431
430
 
432
431
 
432
+ class BboxIoU(Tool):
433
+ name = "bbox_iou_"
434
+ description = (
435
+ "'bbox_iou_' returns the intersection over union of two bounding boxes."
436
+ )
437
+ usage = {
438
+ "required_parameters": [
439
+ {"name": "bbox1", "type": "List[int]"},
440
+ {"name": "bbox2", "type": "List[int]"},
441
+ ],
442
+ "examples": [
443
+ {
444
+ "scenario": "If you want to calculate the intersection over union of the bounding boxes [0.2, 0.21, 0.34, 0.42] and [0.3, 0.31, 0.44, 0.52]",
445
+ "parameters": {
446
+ "bbox1": [0.2, 0.21, 0.34, 0.42],
447
+ "bbox2": [0.3, 0.31, 0.44, 0.52],
448
+ },
449
+ }
450
+ ],
451
+ }
452
+
453
+ def __call__(self, bbox1: List[int], bbox2: List[int]) -> float:
454
+ x1, y1, x2, y2 = bbox1
455
+ x3, y3, x4, y4 = bbox2
456
+ xA = max(x1, x3)
457
+ yA = max(y1, y3)
458
+ xB = min(x2, x4)
459
+ yB = min(y2, y4)
460
+ inter_area = max(0, xB - xA) * max(0, yB - yA)
461
+ boxa_area = (x2 - x1) * (y2 - y1)
462
+ boxb_area = (x4 - x3) * (y4 - y3)
463
+ iou = inter_area / float(boxa_area + boxb_area - inter_area)
464
+ return round(iou, 2)
465
+
466
+
467
+ class SegIoU(Tool):
468
+ name = "seg_iou_"
469
+ description = "'seg_iou_' returns the intersection over union of two segmentation masks given their segmentation mask files."
470
+ usage = {
471
+ "required_parameters": [
472
+ {"name": "mask1", "type": "str"},
473
+ {"name": "mask2", "type": "str"},
474
+ ],
475
+ "examples": [
476
+ {
477
+ "scenario": "If you want to calculate the intersection over union of the segmentation masks for mask_file1.jpg and mask_file2.jpg",
478
+ "parameters": {"mask1": "mask_file1.png", "mask2": "mask_file2.png"},
479
+ }
480
+ ],
481
+ }
482
+
483
+ def __call__(self, mask1: Union[str, Path], mask2: Union[str, Path]) -> float:
484
+ pil_mask1 = Image.open(str(mask1))
485
+ pil_mask2 = Image.open(str(mask2))
486
+ np_mask1 = np.clip(np.array(pil_mask1), 0, 1)
487
+ np_mask2 = np.clip(np.array(pil_mask2), 0, 1)
488
+ intersection = np.logical_and(np_mask1, np_mask2)
489
+ union = np.logical_or(np_mask1, np_mask2)
490
+ iou = np.sum(intersection) / np.sum(union)
491
+ return cast(float, round(iou, 2))
492
+
493
+
494
+ class ExtractFrames(Tool):
495
+ r"""Extract frames from a video."""
496
+
497
+ name = "extract_frames_"
498
+ description = "'extract_frames_' extracts frames where there is motion detected in a video, returns a list of tuples (frame, timestamp), where timestamp is the relative time in seconds where teh frame was captured. The frame is a local image file path."
499
+ usage = {
500
+ "required_parameters": [{"name": "video_uri", "type": "str"}],
501
+ "examples": [
502
+ {
503
+ "scenario": "Can you extract the frames from this video? Video: www.foobar.com/video?name=test.mp4",
504
+ "parameters": {"video_uri": "www.foobar.com/video?name=test.mp4"},
505
+ },
506
+ {
507
+ "scenario": "Can you extract the images from this video file? Video path: tests/data/test.mp4",
508
+ "parameters": {"video_uri": "tests/data/test.mp4"},
509
+ },
510
+ ],
511
+ }
512
+
513
+ def __call__(self, video_uri: str) -> List[Tuple[str, float]]:
514
+ """Extract frames from a video.
515
+
516
+
517
+ Parameters:
518
+ video_uri: the path to the video file or a url points to the video data
519
+
520
+ Returns:
521
+ a list of tuples containing the extracted frame and the timestamp in seconds. E.g. [(path_to_frame1, 0.0), (path_to_frame2, 0.5), ...]. The timestamp is the time in seconds from the start of the video. E.g. 12.125 means 12.125 seconds from the start of the video. The frames are sorted by the timestamp in ascending order.
522
+ """
523
+ frames = extract_frames_from_video(video_uri)
524
+ result = []
525
+ _LOGGER.info(
526
+ f"Extracted {len(frames)} frames from video {video_uri}. Temporarily saving them as images to disk for downstream tasks."
527
+ )
528
+ for frame, ts in frames:
529
+ with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
530
+ Image.fromarray(frame).save(tmp)
531
+ result.append((tmp.name, ts))
532
+ return result
533
+
534
+
433
535
  class Add(Tool):
434
536
  r"""Add returns the sum of all the arguments passed to it, normalized to 2 decimal places."""
435
537
 
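
A worked example of the new IoU tools above. The boxes come from the tool's own usage example, and the expected value follows from the intersection/union arithmetic in the code (0.0044 / 0.0544, rounded to 0.08). SegIoU is called the same way with two saved mask files:

    from vision_agent.tools import BboxIoU, SegIoU

    print(BboxIoU()([0.2, 0.21, 0.34, 0.42], [0.3, 0.31, 0.44, 0.52]))  # 0.08

    # With two binary mask images on disk (hypothetical paths):
    # print(SegIoU()("mask_file1.png", "mask_file2.png"))
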
@@ -506,47 +608,6 @@ class Divide(Tool):
506
608
  return round(input[0] / input[1], 2)
507
609
 
508
610
 
509
- class ExtractFrames(Tool):
510
- r"""Extract frames from a video."""
511
-
512
- name = "extract_frames_"
513
- description = "'extract_frames_' extract image frames from the input video, return a list of tuple (frame, timestamp), where the timestamp is the relative time in seconds of the frame occurred in the video, the frame is a local image file path that stores the frame."
514
- usage = {
515
- "required_parameters": [{"name": "video_uri", "type": "str"}],
516
- "examples": [
517
- {
518
- "scenario": "Can you extract the frames from this video? Video: www.foobar.com/video?name=test.mp4",
519
- "parameters": {"video_uri": "www.foobar.com/video?name=test.mp4"},
520
- },
521
- {
522
- "scenario": "Can you extract the images from this video file? Video path: tests/data/test.mp4",
523
- "parameters": {"video_uri": "tests/data/test.mp4"},
524
- },
525
- ],
526
- }
527
-
528
- def __call__(self, video_uri: str) -> list[tuple[str, float]]:
529
- """Extract frames from a video.
530
-
531
-
532
- Parameters:
533
- video_uri: the path to the video file or a url points to the video data
534
-
535
- Returns:
536
- a list of tuples containing the extracted frame and the timestamp in seconds. E.g. [(path_to_frame1, 0.0), (path_to_frame2, 0.5), ...]. The timestamp is the time in seconds from the start of the video. E.g. 12.125 means 12.125 seconds from the start of the video. The frames are sorted by the timestamp in ascending order.
537
- """
538
- frames = extract_frames_from_video(video_uri)
539
- result = []
540
- _LOGGER.info(
541
- f"Extracted {len(frames)} frames from video {video_uri}. Temporarily saving them as images to disk for downstream tasks."
542
- )
543
- for frame, ts in frames:
544
- with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
545
- Image.fromarray(frame).save(tmp)
546
- result.append((tmp.name, ts))
547
- return result
548
-
549
-
550
611
  TOOLS = {
551
612
  i: {"name": c.name, "description": c.description, "usage": c.usage, "class": c}
552
613
  for i, c in enumerate(
@@ -554,15 +615,17 @@ TOOLS = {
554
615
  CLIP,
555
616
  GroundingDINO,
556
617
  AgentGroundingSAM,
618
+ ExtractFrames,
557
619
  Counter,
558
620
  Crop,
559
621
  BboxArea,
560
622
  SegArea,
623
+ BboxIoU,
624
+ SegIoU,
561
625
  Add,
562
626
  Subtract,
563
627
  Multiply,
564
628
  Divide,
565
- ExtractFrames,
566
629
  ]
567
630
  )
568
631
  if (hasattr(c, "name") and hasattr(c, "description") and hasattr(c, "usage"))
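
Since ExtractFrames, BboxIoU and SegIoU are now registered, the TOOLS mapping the agent plans against can be inspected directly; a small sketch listing the registry:

    from vision_agent.tools import TOOLS

    for tool_id, entry in TOOLS.items():
        print(tool_id, entry["name"])  # now includes extract_frames_, bbox_iou_ and seg_iou_
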
vision_agent/tools/video.py CHANGED
@@ -22,12 +22,16 @@ def extract_frames_from_video(
22
22
  Parameters:
23
23
  video_uri: the path to the video file or a video file url
24
24
  fps: the frame rate per second to extract the frames
25
- motion_detection_threshold: The threshold to detect motion between changes/frames.
26
- A value between 0-1, which represents the percentage change required for the frames to be considered in motion.
27
- For example, a lower value means more frames will be extracted.
25
+ motion_detection_threshold: The threshold to detect motion between
26
+ changes/frames. A value between 0-1, which represents the percentage change
27
+ required for the frames to be considered in motion. For example, a lower
28
+ value means more frames will be extracted.
28
29
 
29
30
  Returns:
30
- a list of tuples containing the extracted frame and the timestamp in seconds. E.g. [(frame1, 0.0), (frame2, 0.5), ...]. The timestamp is the time in seconds from the start of the video. E.g. 12.125 means 12.125 seconds from the start of the video. The frames are sorted by the timestamp in ascending order.
31
+ a list of tuples containing the extracted frame and the timestamp in seconds.
32
+ E.g. [(frame1, 0.0), (frame2, 0.5), ...]. The timestamp is the time in seconds
33
+ from the start of the video. E.g. 12.125 means 12.125 seconds from the start of
34
+ the video. The frames are sorted by the timestamp in ascending order.
31
35
  """
32
36
  with VideoFileClip(video_uri) as video:
33
37
  video_duration: float = video.duration
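
A sketch of the frame-extraction tool that wraps extract_frames_from_video. The video path is taken from the tool's own usage example, frames are saved to temporary .png files (previously .jpg), and moviepy must be able to open the file:

    from vision_agent.tools import ExtractFrames

    frames = ExtractFrames()("tests/data/test.mp4")
    for path, ts in frames[:3]:
        print(f"{ts:.3f}s -> {path}")
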
vision_agent-0.0.40.dist-info/METADATA → vision_agent-0.0.42.dist-info/METADATA
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: vision-agent
3
- Version: 0.0.40
3
+ Version: 0.0.42
4
4
  Summary: Toolset for Vision Agent
5
5
  Author: Landing AI
6
6
  Author-email: dev@landing.ai
vision_agent-0.0.40.dist-info/RECORD → vision_agent-0.0.42.dist-info/RECORD
@@ -5,22 +5,22 @@ vision_agent/agent/easytool.py,sha256=oMHnBg7YBtIPgqQUNcZgq7uMgpPThs99_UnO7ERkMV
5
5
  vision_agent/agent/easytool_prompts.py,sha256=uNp12LOFRLr3i2zLhNuLuyFms2-s8es2t6P6h76QDow,4493
6
6
  vision_agent/agent/reflexion.py,sha256=wzpptfALNZIh9Q5jgkK3imGL5LWjTW_n_Ypsvxdh07Q,10101
7
7
  vision_agent/agent/reflexion_prompts.py,sha256=UPGkt_qgHBMUY0VPVoF-BqhR0d_6WPjjrhbYLBYOtnQ,9342
8
- vision_agent/agent/vision_agent.py,sha256=AS8-2mKg476X6ydopcT_Ike3GCmSlzbwYaw-yuHCPl0,15262
8
+ vision_agent/agent/vision_agent.py,sha256=P2melU6XQCCiiL1C_4QsxGUaWbwahuJA90eIcQJTR4U,17449
9
9
  vision_agent/agent/vision_agent_prompts.py,sha256=otaDRsaHc7bqw_tgWTnu-eUcFeOzBFrn9sPU7_xr2VQ,6151
10
10
  vision_agent/data/__init__.py,sha256=YU-5g3LbEQ6a4drz0RLGTagXMVU2Z4Xr3RlfWE-R0jU,46
11
11
  vision_agent/data/data.py,sha256=pgtSGZdAnbQ8oGsuapLtFTMPajnCGDGekEXTnFuBwsY,5122
12
12
  vision_agent/emb/__init__.py,sha256=YmCkGrJBtXb6X6Z3lnKiFoQYKXMgHMJp8JJyMLVvqcI,75
13
13
  vision_agent/emb/emb.py,sha256=la9lhEzk7jqUCjYYQ5oRgVNSnC9_EJBJIpE_B9c6PJo,1375
14
- vision_agent/image_utils.py,sha256=D5H-GN35Bz3u1Fq_JfYQVjNzAmZjJl138wma5fRtVjA,1684
14
+ vision_agent/image_utils.py,sha256=XiOLpHAvlk55URw6iG7hl1OY71FVRA9_25b650amZXA,4420
15
15
  vision_agent/llm/__init__.py,sha256=fBKsIjL4z08eA0QYx6wvhRe4Nkp2pJ4VrZK0-uUL5Ec,32
16
- vision_agent/llm/llm.py,sha256=d8A7jmLVGx5HzoiYJ75mTMU7dbD5-bOYeXYlHaay6WA,3957
16
+ vision_agent/llm/llm.py,sha256=l8ZVh6vCZOJBHfenfOoHwPySXEUQoNt_gbL14gkvu2g,3904
17
17
  vision_agent/lmm/__init__.py,sha256=I8mbeNUajTfWVNqLsuFQVOaNBDlkIhYp9DFU8H4kB7g,51
18
- vision_agent/lmm/lmm.py,sha256=ARcbgkcyP83TbVVoXI9B-gtG0gJuTaG_MjcUGbams4U,8052
19
- vision_agent/tools/__init__.py,sha256=aX0pU3pXU1V0Cj9FzYCvdsX76TAglFMHx59kNhXHbPs,131
18
+ vision_agent/lmm/lmm.py,sha256=s_A3SKCoWm2biOt-gS9PXOsa9l-zrmR6mInLjAqam-A,8438
19
+ vision_agent/tools/__init__.py,sha256=AKN-T659HpwVearRnkCd6wWNoJ6K5kW9gAZwb8IQSLE,235
20
20
  vision_agent/tools/prompts.py,sha256=9RBbyqlNlExsGKlJ89Jkph83DAEJ8PCVGaHoNbyN7TM,1416
21
- vision_agent/tools/tools.py,sha256=2mmomPDbldXRpw3q5zAcazKJMjAGd0Jl9ak9JykHQYI,21211
22
- vision_agent/tools/video.py,sha256=KV_Wcat7DDGxpHSaGBu7s4lj4crlYaUu4YKpCO_86k4,7440
23
- vision_agent-0.0.40.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
24
- vision_agent-0.0.40.dist-info/METADATA,sha256=uPMyB4VrvlIs6R7yGCoRkC2Bf9Zemc6wQF03BjgBFgs,5324
25
- vision_agent-0.0.40.dist-info/WHEEL,sha256=7Z8_27uaHI_UZAc4Uox4PpBhQ9Y5_modZXWMxtUi4NU,88
26
- vision_agent-0.0.40.dist-info/RECORD,,
21
+ vision_agent/tools/tools.py,sha256=aMTBxxaXQp33HwplOS8xrgfbsTJ8e1pwO6byR7HcTJI,23447
22
+ vision_agent/tools/video.py,sha256=40rscP8YvKN3lhZ4PDcOK4XbdFX2duCRpHY_krmBYKU,7476
23
+ vision_agent-0.0.42.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
24
+ vision_agent-0.0.42.dist-info/METADATA,sha256=r523uVvu-DsNoA-H-18O2JXF4J9G2nZ2cDSmjXUFq_M,5324
25
+ vision_agent-0.0.42.dist-info/WHEEL,sha256=7Z8_27uaHI_UZAc4Uox4PpBhQ9Y5_modZXWMxtUi4NU,88
26
+ vision_agent-0.0.42.dist-info/RECORD,,